├── data └── .placeholder ├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── datamodules │ │ ├── __init__.py │ │ ├── simple.py │ │ └── cnn_dm.py │ └── datasets │ │ ├── __init__.py │ │ ├── alternator.py │ │ └── generation.py ├── optim │ ├── __init__.py │ ├── optimizers │ │ ├── __init__.py │ │ └── radam.py │ └── factories.py ├── scripts │ ├── __init__.py │ └── model │ │ ├── __init__.py │ │ ├── train.py │ │ └── translate.py ├── utils │ ├── __init__.py │ ├── logging.py │ └── commons.py ├── callbacks │ ├── __init__.py │ ├── best_checkpoint.py │ └── generation.py └── pl_modules │ ├── __init__.py │ ├── generative_models │ ├── __init__.py │ ├── generative_model.py │ └── bart.py │ ├── utils.py │ └── generative_pl_module.py ├── experiments └── .placeholder ├── configurations ├── hydra-train │ ├── device │ │ ├── cpu.yaml │ │ ├── cuda.yaml │ │ └── cuda_amp.yaml │ ├── callbacks │ │ ├── default.yaml │ │ └── summarization.yaml │ ├── model │ │ └── bart.yaml │ ├── data │ │ ├── cnn_dm.yaml │ │ └── simple.yaml │ ├── root.yaml │ └── training │ │ └── default.yaml └── generation │ ├── beam.yaml │ └── sample.yaml ├── requirements.txt ├── setup.sh ├── .gitignore └── README.md /data/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/optim/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/pl_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/optim/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/scripts/model/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configurations/hydra-train/device/cpu.yaml: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /src/pl_modules/generative_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configurations/hydra-train/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | callbacks: [] -------------------------------------------------------------------------------- /configurations/hydra-train/device/cuda.yaml: -------------------------------------------------------------------------------- 1 | gpus: 2 | - 0 3 | precision: 32 4 | amp_level: 'O0' -------------------------------------------------------------------------------- /configurations/hydra-train/device/cuda_amp.yaml: -------------------------------------------------------------------------------- 1 | gpus: 2 | - 0 3 | precision: 16 4 | amp_level: 'O1' -------------------------------------------------------------------------------- /configurations/generation/beam.yaml: -------------------------------------------------------------------------------- 1 | num_beams: 5 2 | min_length: 5 3 | max_length: 100 4 | length_penalty: 1.0 5 | repetition_penalty: 1.0 -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def get_project_logger(module_name: str) -> logging.Logger: 5 | return logging.getLogger(f"grid_seq2seq.{module_name}") 6 | -------------------------------------------------------------------------------- /configurations/generation/sample.yaml: -------------------------------------------------------------------------------- 1 | num_beams: 1 2 | do_sample: True 3 | top_p: 0.8 4 | temperature: 1.0 5 | min_length: 25 6 | max_length: 200 7 | length_penalty: 1.0 8 | repetition_penalty: 1.0 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==1.2.3 2 | transformers==4.3.3 3 | hydra-core==1.1.0-dev4 4 | wandb==0.10.21 5 | # summarization dependencies 6 | datasets==1.4.1 7 | nltk==3.5 8 | rouge-score==0.0.4 9 | # black 10 | black==20.8b1 11 | -------------------------------------------------------------------------------- /configurations/hydra-train/model/bart.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'src.pl_modules.generative_pl_module.GenerativePLModule' 2 | generative_model: 3 | _target_: 'src.pl_modules.generative_models.bart.BartGenerativeModel' 4 | bart_model: 'facebook/bart-large' 5 | dropout: 0.1 6 | label_smoothing: 0.1 # 0.0 for non label smoothing -------------------------------------------------------------------------------- /configurations/hydra-train/data/cnn_dm.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: 'src.data.datamodules.cnn_dm.CNNDMDataModule' 3 | data_dir: 'data/cnn-dm' 4 | num_workers: 0 5 | dataset: 6 | _target_: 
'src.data.datasets.generation.ParallelDataset.from_file' 7 | max_tokens_per_batch: 1300 8 | min_length: 5 9 | max_length: 500 10 | truncate: true 11 | section_size: 10000 12 | -------------------------------------------------------------------------------- /configurations/hydra-train/callbacks/summarization.yaml: -------------------------------------------------------------------------------- 1 | callbacks: 2 | - _target_: "src.callbacks.generation.TextGenerationCallback" 3 | generation_callbacks: 4 | rouge: 5 | _target_: 'src.callbacks.generation.RougeGenerationCallback' 6 | generations: 7 | - name: "beam1" 8 | glob_translate_path: "data/cnn-dm/validation.tsv" 9 | generation_param_conf_path: "configurations/generation/beam.yaml" 10 | num_sequences: 1 11 | token_batch_size: 800 12 | limit: 1000 13 | enabled_generation_callbacks: 14 | - 'rouge' -------------------------------------------------------------------------------- /configurations/hydra-train/root.yaml: -------------------------------------------------------------------------------- 1 | project_name: grid-seq2seq 2 | exp_name: ??? 3 | exp_folder: ./experiments/${exp_name} 4 | 5 | hydra: 6 | # customize working dir 7 | run: 8 | dir: ./experiments/${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 9 | # customize logging 10 | verbose: [ grid_seq2seq ] 11 | job_logging: 12 | formatters: 13 | simple: 14 | format: '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 15 | root: 16 | level: WARN 17 | 18 | # defaults 19 | defaults: 20 | - callbacks: default 21 | - data: null 22 | - device: cpu 23 | - model: bart 24 | - training: default 25 | -------------------------------------------------------------------------------- /configurations/hydra-train/data/simple.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: 'src.data.datamodules.simple.SimpleDataModule' 3 | num_workers: 0 4 | train_dataset: 5 | _target_: 'src.data.datasets.generation.ParallelDataset.from_file' 6 | path: null 7 | max_tokens_per_batch: 1300 8 | min_length: 5 9 | max_length: 500 10 | truncate: true 11 | section_size: 10000 12 | validation_dataset: 13 | _target_: 'src.data.datasets.generation.ParallelDataset.from_file' 14 | path: null 15 | max_tokens_per_batch: 1300 16 | min_length: 5 17 | max_length: 500 18 | truncate: true 19 | section_size: 10000 20 | -------------------------------------------------------------------------------- /configurations/hydra-train/training/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | # reproducibility 4 | reprod: 5 | seed: 12 6 | 7 | # optimization 8 | trainer: 9 | gradient_acc_steps: 4 10 | gradient_clip_value: 10.0 11 | max_steps: 100000 12 | val_check_interval: 1000 13 | patience: 5 14 | optim: 15 | _target_: "src.optim.factories.TorchFactory" 16 | optimizer: 17 | _target_: torch.optim.Adam 18 | lr: 1e-5 19 | checkpoint: 20 | filename: '{val_loss:.4f}' 21 | monitor: '-val_loss' 22 | save_top_k: 5 23 | save_last: true 24 | 25 | # logger 26 | logger: 27 | _target_: 'pytorch_lightning.loggers.wandb.WandbLogger' 28 | name: ${exp_name} 29 | project: ${project_name} 30 | 31 | -------------------------------------------------------------------------------- /src/callbacks/best_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | from pytorch_lightning.callbacks import ModelCheckpoint 6 | 7 | 8 | 
class ModelCheckpointWithBest(ModelCheckpoint): 9 | """A callback that explictly saves the best checkpoint with best.ckpt. 10 | 11 | Note that the best checkpoint is duplicated, rather than linked, in best.ckpt 12 | """ 13 | 14 | CHECKPOINT_NAME_BEST = 'best.ckpt' 15 | 16 | def on_validation_end(self, trainer, pl_module): 17 | super().on_validation_end(trainer, pl_module) 18 | if self.best_model_path == "": 19 | return 20 | orig_best = Path(self.best_model_path) 21 | shutil.copyfile(orig_best, orig_best.parent / self.CHECKPOINT_NAME_BEST) 22 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # setup conda 4 | source ~/miniconda3/etc/profile.d/conda.sh 5 | 6 | # create conda env 7 | read -rp "Enter environment name: " env_name 8 | read -rp "Enter python version (e.g. 3.7): " python_version 9 | conda create -yn "$env_name" python="$python_version" 10 | conda activate "$env_name" 11 | 12 | # install torch 13 | read -rp "Enter cuda version (e.g. 10.1 or none to avoid installing cuda support): " cuda_version 14 | if [ "$cuda_version" == "none" ]; then 15 | conda install -y pytorch torchvision cpuonly -c pytorch 16 | else 17 | conda install -y pytorch torchvision cudatoolkit=$cuda_version -c pytorch 18 | fi 19 | 20 | # install python requirements 21 | pip install -r requirements.txt 22 | 23 | # login into wandb 24 | read -p "Enter wandb key: " wandb_key 25 | wandb login $wandb_key 26 | -------------------------------------------------------------------------------- /src/utils/commons.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from typing import Optional, List 3 | 4 | import numpy as np 5 | 6 | from src.utils.logging import get_project_logger 7 | 8 | logger = get_project_logger(__name__) 9 | 10 | 11 | def execute_bash_command(command: str) -> Optional[str]: 12 | command_result = subprocess.run(command, shell=True, capture_output=True) 13 | try: 14 | command_result.check_returncode() 15 | return command_result.stdout.decode("utf-8") 16 | except subprocess.CalledProcessError: 17 | logger.warning(f"failed executing command: {command}") 18 | logger.warning(f"return code was: {command_result.returncode}") 19 | logger.warning(f'stdout was: {command_result.stdout.decode("utf-8")}') 20 | logger.warning(f'stderr code was: {command_result.stderr.decode("utf-8")}') 21 | return None 22 | 23 | 24 | def chunks(l: List, k: int) -> List[List]: 25 | assert k >= 1 26 | return [l[i : i + k] for i in range(0, len(l), k)] 27 | 28 | 29 | def flatten(lst: List[list]) -> list: 30 | return [_e for sub_l in lst for _e in sub_l] 31 | 32 | 33 | def add_noise_to_value(value: int, noise_param: float): 34 | noise_value = value * noise_param 35 | noise = np.random.uniform(-noise_value, noise_value) 36 | return max(1, value + noise) 37 | -------------------------------------------------------------------------------- /src/data/datasets/alternator.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import random 4 | from typing import Iterator, List, Dict 5 | 6 | from torch.utils.data import IterableDataset 7 | 8 | 9 | class AlternatorIterableDataset(IterableDataset): 10 | """IterableDataset that allows for batch alternation, according to specified probabilities, among multiple IterableDataset-s. 
11 | 12 | Iteration continues until all datasets are exhausted, restarting all those that finish until this condition is met. 13 | """ 14 | 15 | def __init__(self, datasets: List[Dict], p: List[float], **kwargs): 16 | assert len(datasets) == len(p) 17 | self.datasets: List[IterableDataset] = [ 18 | hydra.utils.instantiate(d, **kwargs) for d in datasets 19 | ] 20 | self.p = p 21 | assert np.isclose(sum(self.p), 1.0) 22 | 23 | def __iter__(self) -> Iterator: 24 | 25 | done = [False for _ in self.datasets] 26 | iterators = [iter(d) for d in self.datasets] 27 | 28 | while True: 29 | 30 | i = random.choices(list(range(len(self.datasets))), weights=self.p, k=1)[0] 31 | 32 | try: 33 | batch = next(iterators[i]) 34 | except StopIteration: 35 | done[i] = True 36 | if all(done): 37 | break 38 | iterators[i] = iter(self.datasets[i]) 39 | batch = next(iterators[i]) 40 | 41 | yield batch 42 | -------------------------------------------------------------------------------- /src/pl_modules/utils.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import hydra 3 | import torch 4 | from omegaconf import OmegaConf 5 | 6 | 7 | def label_smoothed_nll_loss(lprobs, target, epsilon, padding_mask: torch.Tensor): 8 | """ 9 | Inspired by https://github.com/huggingface/transformers/blob/5148f433097915f30864bf0ca6090656fecefbb8/examples/seq2seq/utils.py 10 | 11 | With a change however: using mean rather than sum ( nll_loss = nll...) 12 | 13 | """ 14 | 15 | assert target.dim() == padding_mask.dim() 16 | if target.dim() == lprobs.dim() - 1: 17 | target = target.unsqueeze(-1) 18 | padding_mask = padding_mask.unsqueeze(-1) 19 | 20 | # compute nll loss 21 | nll_loss = -lprobs.gather(dim=-1, index=target) 22 | nll_loss.masked_fill_(~padding_mask, 0.0) 23 | 24 | # compute smooth loss 25 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 26 | smooth_loss.masked_fill_(~padding_mask, 0.0) 27 | 28 | nll_loss = nll_loss.sum() / padding_mask.sum() 29 | smooth_loss = smooth_loss.sum() / padding_mask.sum() 30 | 31 | eps_i = epsilon / lprobs.size(-1) 32 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 33 | 34 | return loss, nll_loss 35 | 36 | 37 | def load_pl_module_from_checkpoint(checkpoint_path: str): 38 | """ 39 | Load a PL module from a checkpoint path only. Infer the model to load from the dumped hydra conf 40 | 41 | Args: 42 | checkpoint_path (str): 43 | 44 | Returns: 45 | pl.LightningModule 46 | 47 | """ 48 | 49 | # find hydra config path 50 | hydra_config_path = "/".join(checkpoint_path.split("/")[:-2]) 51 | 52 | # load hydra config 53 | conf = OmegaConf.load(f"{hydra_config_path}/.hydra/config.yaml") 54 | 55 | # instantiate and return 56 | return hydra.utils.instantiate( 57 | {"_target_": f'{conf["model"]["_target_"]}.load_from_checkpoint'}, 58 | checkpoint_path=checkpoint_path, 59 | ) 60 | -------------------------------------------------------------------------------- /src/data/datamodules/simple.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Optional, Dict 2 | 3 | import hydra 4 | import pytorch_lightning as pl 5 | from torch.utils.data import DataLoader 6 | from transformers import PreTrainedTokenizer 7 | 8 | from src.utils.logging import get_project_logger 9 | 10 | logger = get_project_logger(__name__) 11 | 12 | 13 | class SimpleDataModule(pl.LightningDataModule): 14 | """Simple datamodule, useful when datamodules are not really needed (e.g. 
train using ParallelDataset on a already-downlaoded tsv file)""" 15 | 16 | def __init__( 17 | self, 18 | tokenizer: PreTrainedTokenizer, 19 | train_dataset: Optional[Dict] = None, 20 | validation_dataset: Optional[Dict] = None, 21 | test_dataset: Optional[Dict] = None, 22 | num_workers: int = 0, 23 | ): 24 | super().__init__() 25 | self.tokenizer = tokenizer 26 | self.train_dataset_conf = train_dataset 27 | self.validation_dataset_conf = validation_dataset 28 | self.test_dataset_conf = test_dataset 29 | self.num_workers = num_workers 30 | self.train_dataset, self.validation_dataset, self.test_dataset = None, None, None 31 | 32 | def prepare_data(self, *args, **kwargs): 33 | pass 34 | 35 | def setup(self, stage: Optional[str] = None): 36 | if stage == "fit": 37 | assert self.train_dataset_conf is not None and self.validation_dataset_conf is not None 38 | self.train_dataset = hydra.utils.instantiate( 39 | self.train_dataset_conf, tokenizer=self.tokenizer, _recursive_=False 40 | ) 41 | self.validation_dataset = hydra.utils.instantiate( 42 | self.validation_dataset_conf, tokenizer=self.tokenizer, _recursive_=False 43 | ) 44 | else: 45 | assert self.test_dataset_conf is not None 46 | self.test_dataset = hydra.utils.instantiate( 47 | self.test_dataset_conf, tokenizer=self.tokenizer, _recursive_=False 48 | ) 49 | 50 | def train_dataloader(self, *args, **kwargs) -> DataLoader: 51 | return DataLoader(self.train_dataset, batch_size=None, num_workers=self.num_workers) 52 | 53 | def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: 54 | return DataLoader(self.validation_dataset, batch_size=None, num_workers=self.num_workers) 55 | 56 | def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: 57 | return DataLoader(self.test_dataset, batch_size=None, num_workers=self.num_workers) 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # ide stuff 132 | .idea 133 | .vscode 134 | 135 | # custom stuff 136 | data/* 137 | experiments/* 138 | !.placeholder -------------------------------------------------------------------------------- /src/pl_modules/generative_models/generative_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Any, Union 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | @dataclass 9 | class TAGenerativeModelOutput: 10 | """Dataclass storing the returned fields of a teacher-forced forward. 11 | 12 | Args: 13 | loss (torch.Tensor): loss upon which the backprop should be performed (affected by regularizers such as label smoothing) 14 | plain_loss (torch.Tensor): plain loss unaffected by regularizers 15 | logits (torch.Tensor) 16 | predictions (torch.Tensor) 17 | """ 18 | 19 | loss: torch.Tensor 20 | plain_loss: torch.Tensor 21 | logits: torch.Tensor 22 | predictions: torch.Tensor 23 | 24 | 25 | @dataclass 26 | class GenGenerativeModelOutput: 27 | """Dataclass storing the returned fields of a generative forward. 28 | 29 | Args: 30 | generation (torch.Tensor): tensor of shape (batch_size, num_sequences, sequence_length) storing the generated sequences 31 | raw (torch.Tensor): raw object storing details of the decoding processes. For example, when using HuggingFace Transformers' 32 | models, it can be used to store attentions, generation scores, ... 33 | """ 34 | 35 | generation: torch.Tensor 36 | raw: Any 37 | 38 | 39 | class GenerativeModel(nn.Module): 40 | """Abstract class denoting the interface of a generative model. 
41 | 42 | This class essentially alternates between two states, teacher-forced or generative, and the current state regulates 43 | the forward method. The teacher-forced state is the one expected to be activate at training time: given the input, the 44 | model will compute and return a teacher-forced loss. Conversely, the generative state is for generation: the model will 45 | generate sequences conditioned on the provided input. 46 | 47 | Besides the forward method, GenerativeModel offers hooks to switch between the two modes and an additional one to 48 | load generation-time parameters. 49 | """ 50 | 51 | def __init__(self): 52 | super().__init__() 53 | self.generation_mode = False 54 | self.generation_params = None 55 | 56 | def enable_generation_mode(self): 57 | """Enables generation mode for the model.""" 58 | self.generation_mode = True 59 | 60 | def disable_generation_mode(self): 61 | """Disables generation mode for the model.""" 62 | self.generation_mode = False 63 | 64 | def load_generation_params(self, generation_params: Dict[str, Any]): 65 | """Loads and stores the given generation params, that will be used in the next generation forwards""" 66 | self.generation_params = generation_params 67 | 68 | def forward(self, *args, **kwargs) -> Union[TAGenerativeModelOutput, GenGenerativeModelOutput]: 69 | raise NotImplementedError 70 | -------------------------------------------------------------------------------- /src/pl_modules/generative_pl_module.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import hydra 4 | import pytorch_lightning as pl 5 | import torch 6 | 7 | 8 | class GenerativePLModule(pl.LightningModule): 9 | """Standard LightningModule wrapping a GenerativeModel.""" 10 | 11 | def __init__(self, generative_model: Dict, optim_conf: Dict): 12 | 13 | super().__init__() 14 | self.save_hyperparameters() 15 | 16 | # generative model & optim 17 | self.generative_model = hydra.utils.instantiate(generative_model) 18 | self._optim_conf = optim_conf 19 | 20 | # metrics 21 | self.train_acc = pl.metrics.Accuracy() 22 | self.val_acc = pl.metrics.Accuracy() 23 | 24 | @property 25 | def tokenizer(self): 26 | return self.generative_model.tokenizer 27 | 28 | def enable_generation_mode(self): 29 | self.generative_model.enable_generation_mode() 30 | 31 | def disable_generation_mode(self): 32 | self.generative_model.disable_generation_mode() 33 | 34 | def load_generation_params(self, generation_params: Dict): 35 | self.generative_model.load_generation_params(generation_params) 36 | 37 | def forward(self, *args, **kwargs): 38 | return self.generative_model(*args, **kwargs) 39 | 40 | def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: 41 | 42 | forward_output = self.forward(**batch) 43 | 44 | # loss 45 | self.log("train_loss", forward_output.loss, prog_bar=True, on_step=False, on_epoch=True) 46 | 47 | # perplexity 48 | self.log( 49 | "train_ppl", torch.exp(forward_output.plain_loss), prog_bar=True, on_step=False, on_epoch=True 50 | ) 51 | 52 | # accuracy 53 | padding_mask = batch["target_padding_mask"][:, 1:].reshape(-1) 54 | train_acc = self.train_acc( 55 | forward_output.predictions.view(-1)[padding_mask], 56 | batch["target"][:, 1:].reshape(-1)[padding_mask], 57 | ) 58 | self.log("train_accuracy", train_acc, prog_bar=True, on_step=False, on_epoch=True) 59 | 60 | # return 61 | return forward_output.loss 62 | 63 | def validation_step(self, batch: Dict[str, Any], batch_idx: int): 64 | 65 | 
forward_output = self.forward(**batch) 66 | 67 | # loss 68 | self.log("val_loss", forward_output.loss, prog_bar=True, on_step=False, on_epoch=True) 69 | 70 | # perplexity 71 | self.log( 72 | "val_ppl", torch.exp(forward_output.plain_loss), prog_bar=True, on_step=False, on_epoch=True 73 | ) 74 | 75 | # accuracy 76 | padding_mask = batch["target_padding_mask"][:, 1:].reshape(-1) 77 | val_acc = self.val_acc( 78 | forward_output.predictions.view(-1)[padding_mask], 79 | batch["target"][:, 1:].reshape(-1)[padding_mask], 80 | ) 81 | self.log("val_accuracy", val_acc, prog_bar=True, on_step=False, on_epoch=True) 82 | 83 | # return 84 | return forward_output.loss 85 | 86 | def configure_optimizers(self): 87 | return hydra.utils.instantiate(self._optim_conf, _recursive_=False)(module=self) 88 | -------------------------------------------------------------------------------- /src/data/datamodules/cnn_dm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, List, Optional, Dict 3 | 4 | import hydra 5 | import pytorch_lightning as pl 6 | from datasets import load_dataset 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | from transformers import PreTrainedTokenizer 10 | 11 | from src.utils.logging import get_project_logger 12 | 13 | logger = get_project_logger(__name__) 14 | 15 | 16 | class CNNDMDataModule(pl.LightningDataModule): 17 | def __init__( 18 | self, tokenizer: PreTrainedTokenizer, data_dir: str, dataset: Dict, num_workers: int = 0 19 | ): 20 | super().__init__() 21 | self.tokenizer = tokenizer 22 | self.data_dir = data_dir 23 | self.num_workers = num_workers 24 | self.dataset = dataset 25 | self.train, self.val, self.test = None, None, None 26 | 27 | def prepare_data(self, *args, **kwargs): 28 | 29 | # check if download is needed 30 | if os.path.exists(self.data_dir): 31 | if os.path.exists(f"{self.data_dir}/LOCK"): 32 | return 33 | else: 34 | logger.warning( 35 | "Dataset folder already exists, but written dataset seems incomplete. Overwriting it" 36 | ) 37 | else: 38 | os.mkdir(self.data_dir) 39 | 40 | # download and save cnn-datamodule to disk 41 | logger.info("Downloading CNN-DM corpus") 42 | dataset = load_dataset("cnn_dailymail", "3.0.0") 43 | for split in ["train", "validation", "test"]: 44 | with open(f"{self.data_dir}/{split}.tsv", "w") as f: 45 | for sample in tqdm(dataset[split], desc=f"Writing CNN-DM {split} split to disk"): 46 | article, summary = ( 47 | sample["article"].replace("\n", ". "), 48 | sample["highlights"].replace("\n", ". 
"), 49 | ) 50 | f.write(f"{article}\t{summary}\n") 51 | 52 | # dump lock to skip future re-downloads 53 | with open(f"{self.data_dir}/LOCK", "w") as _: 54 | pass 55 | 56 | def setup(self, stage: Optional[str] = None): 57 | if stage == "fit": 58 | self.train = hydra.utils.instantiate( 59 | self.dataset, path=f"{self.data_dir}/train.tsv", tokenizer=self.tokenizer 60 | ) 61 | self.val = hydra.utils.instantiate( 62 | self.dataset, path=f"{self.data_dir}/validation.tsv", tokenizer=self.tokenizer 63 | ) 64 | else: 65 | self.test = hydra.utils.instantiate( 66 | self.dataset, path=f"{self.data_dir}/test.tsv", tokenizer=self.tokenizer 67 | ) 68 | 69 | def train_dataloader(self, *args, **kwargs) -> DataLoader: 70 | return DataLoader(self.train, batch_size=None, num_workers=self.num_workers) 71 | 72 | def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: 73 | return DataLoader(self.val, batch_size=None, num_workers=self.num_workers) 74 | 75 | def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: 76 | return DataLoader(self.test, batch_size=None, num_workers=self.num_workers) 77 | -------------------------------------------------------------------------------- /src/optim/factories.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import hydra 4 | import torch 5 | from omegaconf import DictConfig 6 | 7 | from src.optim.optimizers.radam import RAdam 8 | 9 | 10 | class Factory: 11 | """Factory interface that allows for simple instantiation of optimizers and schedulers for PyTorch Lightning. 12 | 13 | This class is essentially a work-around for lazy instantiation: 14 | * all params but for the module to be optimized are received in __init__ 15 | * the actual instantiation of optimizers and schedulers takes place in the __call__ method, where the module to be 16 | optimized is provided 17 | 18 | __call__ will be invoked in the configure_optimizers hooks of LighiningModule-s and its return object directly returned. 
19 | As such, the return type of __call__ can be any of those allowed by configure_optimizers, namely: 20 | * Single optimizer 21 | * List or Tuple - List of optimizers 22 | * Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict) 23 | * Dictionary, with an ‘optimizer’ key, and (optionally) a ‘lr_scheduler’ key whose value is a single LR scheduler or lr_dict 24 | * Tuple of dictionaries as described, with an optional ‘frequency’ key 25 | * None - Fit will run without any optimizer 26 | """ 27 | 28 | def __call__(self, module: torch.nn.Module): 29 | raise NotImplementedError 30 | 31 | 32 | class TorchFactory(Factory): 33 | """Simple factory wrapping standard PyTorch optimizers and schedulers.""" 34 | 35 | # todo add scheduler support as well 36 | 37 | def __init__(self, optimizer: DictConfig): 38 | self.optimizer = optimizer 39 | 40 | def __call__(self, module: torch.nn.Module): 41 | return hydra.utils.instantiate(self.optimizer, params=module.parameters()) 42 | 43 | 44 | class RAdamWithDecayFactory(Factory): 45 | """Factory for RAdam optimizer.""" 46 | 47 | def __init__(self, lr: float, weight_decay: float, no_decay_params: Optional[List[str]]): 48 | self.lr = lr 49 | self.weight_decay = weight_decay 50 | self.no_decay_params = no_decay_params 51 | 52 | def __call__(self, module: torch.nn.Module): 53 | 54 | if self.no_decay_params is not None: 55 | 56 | optimizer_grouped_parameters = [ 57 | { 58 | "params": [ 59 | p 60 | for n, p in module.named_parameters() 61 | if not any(nd in n for nd in self.no_decay_params) 62 | ], 63 | "weight_decay": self.weight_decay, 64 | }, 65 | { 66 | "params": [ 67 | p 68 | for n, p in module.named_parameters() 69 | if any(nd in n for nd in self.no_decay_params) 70 | ], 71 | "weight_decay": 0.0, 72 | }, 73 | ] 74 | 75 | else: 76 | 77 | optimizer_grouped_parameters = [ 78 | {"params": module.parameters(), "weight_decay": self.weight_decay} 79 | ] 80 | 81 | optimizer = RAdam(optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay) 82 | 83 | return optimizer 84 | -------------------------------------------------------------------------------- /src/scripts/model/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | import omegaconf 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.callbacks import EarlyStopping 7 | 8 | from src.callbacks.best_checkpoint import ModelCheckpointWithBest 9 | 10 | 11 | def train(conf: omegaconf.DictConfig) -> None: 12 | 13 | # reproducibility 14 | pl.seed_everything(conf.training.reprod.seed) 15 | 16 | # main module declaration 17 | pl_module = hydra.utils.instantiate( 18 | conf.model, optim_conf=conf.training.trainer.optim, _recursive_=False 19 | ) 20 | 21 | # data_module declaration 22 | pl_data_module = hydra.utils.instantiate( 23 | conf.data.datamodule, 24 | tokenizer=pl_module.tokenizer, # todo bad coupling towards huggingface 25 | _recursive_=False, 26 | ) 27 | 28 | # callbacks 29 | 30 | callbacks = [] 31 | 32 | # callbacks: checkpoint and early stopping 33 | 34 | monitor = conf.training.trainer.checkpoint.monitor 35 | assert monitor[0] in ["-", "+"] 36 | mode = "min" if monitor[0] == "-" else "max" 37 | monitor = monitor[1:] 38 | 39 | callbacks.append( 40 | ModelCheckpointWithBest( 41 | monitor=monitor, 42 | mode=mode, 43 | dirpath=f"checkpoints", 44 | filename=conf.training.trainer.checkpoint.filename, 45 | save_top_k=conf.training.trainer.checkpoint.save_top_k, 46 | 
save_last=conf.training.trainer.checkpoint.save_last, 47 | verbose=True, 48 | ) 49 | ) 50 | 51 | patience = conf.training.trainer.patience 52 | if patience is not None: 53 | callbacks.append( 54 | EarlyStopping(monitor=monitor, mode=mode, patience=conf.training.trainer.patience) 55 | ) 56 | 57 | # custom callbacks 58 | 59 | for callback in conf.callbacks.callbacks: 60 | callbacks.append(hydra.utils.instantiate(callback)) 61 | 62 | # instantiate trainer logger 63 | logger = hydra.utils.instantiate(conf.training.logger) 64 | 65 | # trainer 66 | trainer = pl.Trainer( 67 | **conf.device, 68 | accumulate_grad_batches=conf.training.trainer.gradient_acc_steps, 69 | gradient_clip_val=conf.training.trainer.gradient_clip_value, 70 | max_steps=conf.training.trainer.max_steps, 71 | val_check_interval=conf.training.trainer.val_check_interval, 72 | logger=logger, 73 | callbacks=callbacks, 74 | ) 75 | 76 | # module fit 77 | trainer.fit(pl_module, datamodule=pl_data_module) 78 | 79 | 80 | @hydra.main(config_path="../../../configurations/hydra-train", config_name="root") 81 | def main(conf: omegaconf.DictConfig): 82 | 83 | # fix paths 84 | 85 | def fix(conf): 86 | if type(conf) == list or type(conf) == omegaconf.listconfig.ListConfig: 87 | for i in range(len(conf)): 88 | conf[i] = fix(conf[i]) 89 | return conf 90 | elif type(conf) == dict or type(conf) == omegaconf.dictconfig.DictConfig: 91 | for k, v in conf.items(): 92 | conf[k] = fix(v) 93 | return conf 94 | elif type(conf) == str: 95 | if "/" in conf and os.path.exists( 96 | hydra.utils.to_absolute_path(conf[: conf.rindex("/")]) 97 | ): 98 | return hydra.utils.to_absolute_path(conf) 99 | else: 100 | return conf 101 | elif type(conf) in [float, int, bool]: 102 | return conf 103 | else: 104 | raise ValueError(f"Unexpected type {type(conf)}: {conf}") 105 | 106 | fix(conf) 107 | 108 | # actual train 109 | train(conf) 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /src/pl_modules/generative_models/bart.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from transformers import AutoTokenizer, BartForConditionalGeneration, AutoConfig 6 | 7 | from src.pl_modules.generative_models.generative_model import ( 8 | GenerativeModel, 9 | TAGenerativeModelOutput, 10 | GenGenerativeModelOutput, 11 | ) 12 | from src.pl_modules.utils import label_smoothed_nll_loss 13 | 14 | 15 | class BartGenerativeModel(GenerativeModel): 16 | def __init__(self, bart_model: str, dropout: float, label_smoothing: float): 17 | 18 | super().__init__() 19 | 20 | self.tokenizer = AutoTokenizer.from_pretrained(bart_model, add_prefix_space=True) 21 | self.config = AutoConfig.from_pretrained(bart_model, dropout=dropout) 22 | self.bart_model = BartForConditionalGeneration.from_pretrained( 23 | bart_model, config=self.config 24 | ) 25 | self._dropout = dropout 26 | self._label_smoothing = label_smoothing 27 | 28 | # metrics 29 | self.train_acc = pl.metrics.Accuracy() 30 | self.val_acc = pl.metrics.Accuracy() 31 | 32 | def forward( 33 | self, 34 | source: torch.Tensor, 35 | source_padding_mask: torch.Tensor, 36 | target: torch.Tensor, 37 | target_padding_mask: torch.Tensor, 38 | num_sequences: int = 1, 39 | **kwargs 40 | ) -> Union[TAGenerativeModelOutput, GenGenerativeModelOutput]: 41 | 42 | if ( 43 | target.shape[1] > 1 44 | ): # training-phase: "target" is provided and we can use 
the "teacher-forcing" strategy. 45 | 46 | assert ( 47 | not self.generation_mode 48 | ), 'The "target" is not empty but the GenerativeModel is in "generation mode"' 49 | 50 | # build target&labels 51 | decoder_input_ids = target[:, :-1].contiguous() 52 | decoder_padding_mask = target_padding_mask[:, :-1].contiguous() 53 | labels = target[:, 1:].contiguous() 54 | labels_padding_mask = target_padding_mask[:, 1:].contiguous() 55 | labels[~labels_padding_mask] = -100 56 | 57 | # actual forward 58 | result = self.bart_model( 59 | input_ids=source, 60 | attention_mask=source_padding_mask, 61 | decoder_input_ids=decoder_input_ids, 62 | decoder_attention_mask=decoder_padding_mask, 63 | labels=labels, 64 | ) 65 | logits = result[1] 66 | 67 | # compute loss with label smoothing 68 | labels[~labels_padding_mask] = self.tokenizer.pad_token_id 69 | log_probs = torch.log_softmax(logits, dim=-1) 70 | smoothed_loss, nll_loss = label_smoothed_nll_loss( 71 | log_probs.view(-1, log_probs.shape[2]), 72 | labels.view(-1), 73 | self._label_smoothing, 74 | padding_mask=labels_padding_mask.view(-1), 75 | ) 76 | 77 | # return 78 | return TAGenerativeModelOutput(loss=smoothed_loss, plain_loss=nll_loss, logits=logits, predictions=logits.argmax(-1)) 79 | 80 | else: # autoregressive-phase: the only token in target is "begin of sequence" ( for bart). 81 | 82 | assert self.generation_mode 83 | 84 | assert target.shape[1] == 1 85 | assert len(set(target[:, 0].tolist())) == 1 86 | 87 | generation = self.bart_model.generate( 88 | input_ids=source, 89 | attention_mask=source_padding_mask, 90 | num_return_sequences=num_sequences, 91 | decoder_start_token_id=target[0][0].item(), 92 | return_dict_in_generate=True, 93 | **self.generation_params 94 | ) 95 | 96 | return GenGenerativeModelOutput( 97 | generation=generation.sequences[:, 1:].reshape(source.shape[0], num_sequences, -1), 98 | raw=generation, 99 | ) 100 | -------------------------------------------------------------------------------- /src/optim/optimizers/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class RAdam(Optimizer): 7 | def __init__( 8 | self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True 9 | ): 10 | if not 0.0 <= lr: 11 | raise ValueError("Invalid learning rate: {}".format(lr)) 12 | if not 0.0 <= eps: 13 | raise ValueError("Invalid epsilon value: {}".format(eps)) 14 | if not 0.0 <= betas[0] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 16 | if not 0.0 <= betas[1] < 1.0: 17 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 18 | 19 | self.degenerated_to_sgd = degenerated_to_sgd 20 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 21 | for param in params: 22 | if "betas" in param and ( 23 | param["betas"][0] != betas[0] or param["betas"][1] != betas[1] 24 | ): 25 | param["buffer"] = [[None, None, None] for _ in range(10)] 26 | defaults = dict( 27 | lr=lr, 28 | betas=betas, 29 | eps=eps, 30 | weight_decay=weight_decay, 31 | buffer=[[None, None, None] for _ in range(10)], 32 | ) 33 | super(RAdam, self).__init__(params, defaults) 34 | 35 | def __setstate__(self, state): 36 | super(RAdam, self).__setstate__(state) 37 | 38 | def step(self, closure=None): 39 | 40 | loss = None 41 | if closure is not None: 42 | loss = closure() 43 | 44 | for group in self.param_groups: 45 | 46 | for p 
in group["params"]: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data.float() 50 | if grad.is_sparse: 51 | raise RuntimeError("RAdam does not support sparse gradients") 52 | 53 | p_data_fp32 = p.data.float() 54 | 55 | state = self.state[p] 56 | 57 | if len(state) == 0: 58 | state["step"] = 0 59 | state["exp_avg"] = torch.zeros_like(p_data_fp32) 60 | state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) 61 | else: 62 | state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) 63 | state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) 64 | 65 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 66 | beta1, beta2 = group["betas"] 67 | 68 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 69 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 70 | 71 | state["step"] += 1 72 | buffered = group["buffer"][int(state["step"] % 10)] 73 | if state["step"] == buffered[0]: 74 | N_sma, step_size = buffered[1], buffered[2] 75 | else: 76 | buffered[0] = state["step"] 77 | beta2_t = beta2 ** state["step"] 78 | N_sma_max = 2 / (1 - beta2) - 1 79 | N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) 80 | buffered[1] = N_sma 81 | 82 | # more conservative since it's an approximated value 83 | if N_sma >= 5: 84 | step_size = math.sqrt( 85 | (1 - beta2_t) 86 | * (N_sma - 4) 87 | / (N_sma_max - 4) 88 | * (N_sma - 2) 89 | / N_sma 90 | * N_sma_max 91 | / (N_sma_max - 2) 92 | ) / (1 - beta1 ** state["step"]) 93 | elif self.degenerated_to_sgd: 94 | step_size = 1.0 / (1 - beta1 ** state["step"]) 95 | else: 96 | step_size = -1 97 | buffered[2] = step_size 98 | 99 | # more conservative since it's an approximated value 100 | if N_sma >= 5: 101 | if group["weight_decay"] != 0: 102 | p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) 103 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 104 | p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom) 105 | p.data.copy_(p_data_fp32) 106 | elif step_size > 0: 107 | if group["weight_decay"] != 0: 108 | p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) 109 | p_data_fp32.add_(-step_size * group["lr"], exp_avg) 110 | p.data.copy_(p_data_fp32) 111 | 112 | return loss 113 | -------------------------------------------------------------------------------- /src/callbacks/generation.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import itertools 3 | from multiprocessing import Pool 4 | from typing import List, Dict, Any, Tuple, Callable, Optional 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | import wandb 9 | from datasets import load_metric 10 | 11 | from src.scripts.model.translate import translate 12 | 13 | from src.utils.logging import get_project_logger 14 | 15 | logger = get_project_logger(__name__) 16 | 17 | 18 | class GenerationCallback: 19 | def __call__( 20 | self, 21 | name: str, 22 | translations: List[Tuple[str, List[str], Optional[str]]], 23 | module: pl.LightningModule, 24 | ): 25 | raise NotImplementedError 26 | 27 | 28 | class RougeGenerationCallback(GenerationCallback): 29 | def __init__(self): 30 | self.rouge = load_metric("rouge") 31 | 32 | def __call__( 33 | self, 34 | name: str, 35 | translations: List[Tuple[str, List[str], Optional[str]]], 36 | module: pl.LightningModule, 37 | ): 38 | assert all(t[2] is not None for t in translations) 39 | results = self.rouge.compute( 40 | predictions=[t[1][0] for t in translations], references=[t[2] for t in translations] 41 | ) 42 | for k, v in results.items(): 43 | module.log( 44 | 
f"val_{name}_{k}", v.mid.fmeasure, prog_bar=True, on_step=False, on_epoch=True 45 | ) 46 | 47 | 48 | class TextGenerationCallback(pl.Callback): 49 | def __init__( 50 | self, generation_callbacks: Dict[str, GenerationCallback], generations: List[Dict[str, Any]] 51 | ): 52 | self._epoch = 0 53 | self.generation_callbacks = generation_callbacks 54 | self.generations_confs = [] 55 | for g in generations: 56 | self.generations_confs.append( 57 | ( 58 | g["name"], 59 | g["glob_translate_path"], 60 | g["generation_param_conf_path"], 61 | g["num_sequences"], 62 | g["token_batch_size"], 63 | g["limit"], 64 | g["enabled_generation_callbacks"], 65 | ) 66 | ) 67 | 68 | def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 69 | 70 | wandb_table = wandb.Table(columns=["Configuration", "Source", "Input", "Pred", "Gold"]) 71 | logger.info("Executing translation callback") 72 | 73 | for ( 74 | name, 75 | glob_translate_path, 76 | generation_param_conf_path, 77 | num_sequences, 78 | token_batch_size, 79 | limit, 80 | enabled_generation_callbacks, 81 | ) in self.generations_confs: 82 | 83 | translation_pairs = [] 84 | 85 | # translate 86 | 87 | for translation_path in glob.iglob(glob_translate_path): 88 | 89 | logger.info( 90 | f"Translating translation path {translation_path} for configuration {name}" 91 | ) 92 | source_type = translation_path.split("/")[-1][:-4] 93 | 94 | with open(translation_path) as f: 95 | 96 | # read sources 97 | iterator = map(lambda l: l.strip(), f) 98 | 99 | # do only a dry run on first epoch (correspond to sanity check run) 100 | if self._epoch == 0: 101 | iterator = itertools.islice(iterator, 5) 102 | 103 | # apply limit 104 | if limit != -1: 105 | iterator = itertools.islice(iterator, limit) 106 | 107 | for i, (source, sample_translations, gold_output) in enumerate( 108 | translate( 109 | pl_module, 110 | iterator, 111 | num_sequences=num_sequences, 112 | generation_param_conf_path=generation_param_conf_path, 113 | token_batch_size=token_batch_size, 114 | ) 115 | ): 116 | if i % 100 == 0: 117 | logger.debug( 118 | f"Translating translation path {translation_path} for configuration {name}: {i} lines translated" 119 | ) 120 | 121 | for translation in sample_translations: 122 | wandb_table.add_data( 123 | name, source_type, source, translation, gold_output 124 | ) 125 | 126 | translation_pairs.append((source, sample_translations, gold_output)) 127 | 128 | if self._epoch == 0: 129 | # do only a dry run on first epoch (correspond to sanity check run) 130 | break 131 | 132 | # run callbacks 133 | 134 | for callback in enabled_generation_callbacks: 135 | self.generation_callbacks[callback](name, translation_pairs, pl_module) 136 | 137 | if self._epoch > 0: 138 | trainer.logger.experiment.log({"translations": wandb_table}) 139 | 140 | logger.info("Translation callback completed") 141 | self._epoch += 1 142 | -------------------------------------------------------------------------------- /src/scripts/model/translate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import List, Iterable, Tuple, Optional 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from omegaconf import OmegaConf 7 | from torch.cuda.amp import autocast 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from src.data.datasets.generation import ParallelDataset 12 | from src.pl_modules.utils import load_pl_module_from_checkpoint 13 | 14 | 15 | def translate( 16 | module: 
pl.LightningModule, 17 | sources: Iterable[str], 18 | num_sequences: int, 19 | generation_param_conf_path: str, 20 | token_batch_size: int = 1024, 21 | progress_bar: bool = False, 22 | ) -> Iterable[Tuple[str, List[str], Optional[str]]]: 23 | 24 | module.enable_generation_mode() 25 | module.load_generation_params(OmegaConf.load(generation_param_conf_path)) 26 | 27 | # todo only works on single gpu 28 | device = next(module.parameters()).device 29 | 30 | # todo unnecessary coupling toward ParallelDataset 31 | # rather, should find and instantiate the appropriate GenerativeDataset 32 | dataset = ParallelDataset.from_lines( 33 | sources, 34 | tokenizer=module.tokenizer, 35 | for_inference=True, 36 | max_tokens_per_batch=token_batch_size, 37 | drop_last_batch=False, 38 | ) 39 | dataloader = DataLoader(dataset, batch_size=None, num_workers=0) 40 | 41 | iterator = dataloader 42 | if progress_bar: 43 | iterator = tqdm(iterator, desc="Translating") 44 | 45 | for batch in iterator: 46 | 47 | # translate 48 | with autocast(enabled=True): 49 | with torch.no_grad(): 50 | batch_generations = module( 51 | **{k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}, 52 | num_sequences=num_sequences, 53 | ) 54 | batch_input = batch["text_source"] 55 | batch_gold_output = batch["text_target"] 56 | batch_generations = batch_generations.generation 57 | 58 | # generate 59 | for sample_input, sample_generations, sample_gold_output in zip( 60 | batch_input, batch_generations, batch_gold_output 61 | ): 62 | decoded_sample_generations = [] 63 | for sample_generation in module.tokenizer.batch_decode( 64 | sample_generations, clean_up_tokenization_spaces=False 65 | ): 66 | if module.tokenizer.eos_token in sample_generation: 67 | sample_generation = sample_generation[ 68 | : sample_generation.index(module.tokenizer.eos_token) 69 | + len(module.tokenizer.eos_token) 70 | ] 71 | decoded_sample_generations.append(sample_generation) 72 | 73 | yield sample_input, decoded_sample_generations, sample_gold_output 74 | 75 | module.disable_generation_mode() 76 | 77 | 78 | def interactive_main( 79 | model_checkpoint_path: str, 80 | num_sequences: int, 81 | generation_param_conf_path: str, 82 | cuda_device: int, 83 | ): 84 | 85 | model = load_pl_module_from_checkpoint(model_checkpoint_path) 86 | model.to(torch.device(cuda_device if cuda_device != -1 else "cpu")) 87 | model.eval() 88 | model.freeze() 89 | 90 | while True: 91 | source = input("Enter source text: ").strip() 92 | _, predictions, _ = next( 93 | translate( 94 | model, 95 | [source], 96 | num_sequences=num_sequences, 97 | generation_param_conf_path=generation_param_conf_path, 98 | ) 99 | ) 100 | for i, prediction in enumerate(predictions): 101 | print(f"\t# prediction-{i}: \t{prediction}") 102 | 103 | 104 | def file_main( 105 | model_checkpoint_path: str, 106 | input_path: str, 107 | output_path: str, 108 | num_sequences: int, 109 | generation_param_conf_path: str, 110 | cuda_device: int, 111 | token_batch_size: int, 112 | ): 113 | 114 | model = load_pl_module_from_checkpoint(model_checkpoint_path) 115 | model.to(torch.device(cuda_device if cuda_device != -1 else "cpu")) 116 | model.eval() 117 | model.freeze() 118 | 119 | with open(input_path) as fi, open(output_path, "w") as fo: 120 | for source, sample_translations, _ in translate( 121 | model, 122 | map(lambda l: l.strip(), fi), 123 | num_sequences=num_sequences, 124 | generation_param_conf_path=generation_param_conf_path, 125 | token_batch_size=token_batch_size, 126 | progress_bar=True, 127 | ): 128 
| for translation in sample_translations: 129 | fo.write(f"{source}\t{translation.strip()}\n") 130 | 131 | 132 | def main(): 133 | args = parse_args() 134 | if args.t: 135 | interactive_main( 136 | args.model_checkpoint, 137 | num_sequences=args.n, 138 | generation_param_conf_path=args.g, 139 | cuda_device=args.cuda_device, 140 | ) 141 | else: 142 | file_main( 143 | args.model_checkpoint, 144 | args.f, 145 | args.o, 146 | num_sequences=args.n, 147 | generation_param_conf_path=args.g, 148 | cuda_device=args.cuda_device, 149 | token_batch_size=args.token_batch_size, 150 | ) 151 | 152 | 153 | def parse_args(): 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument("model_checkpoint", type=str, help="Path to pl_modules checkpoint") 156 | parser.add_argument("-n", type=int, default=1, help="Num sequences") 157 | parser.add_argument("-g", type=str, help="Path to generation conf") 158 | parser.add_argument("--cuda-device", type=int, default=-1, help="Cuda device") 159 | # interactive params 160 | parser.add_argument("-t", action="store_true", help="Interactive mode") 161 | # generation params 162 | parser.add_argument("-f", type=str, default=None, help="Input file") 163 | parser.add_argument("-o", type=str, default=None, help="Output file") 164 | parser.add_argument("--token-batch-size", type=int, default=128, help="Token batch size") 165 | # return 166 | return parser.parse_args() 167 | 168 | 169 | if __name__ == "__main__": 170 | main() 171 | -------------------------------------------------------------------------------- /src/data/datasets/generation.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Iterator, Dict, List, Tuple, Iterable, Callable, Optional 3 | 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from torch.utils.data import IterableDataset 7 | from transformers import PreTrainedTokenizer 8 | 9 | from src.utils.commons import add_noise_to_value, chunks, flatten 10 | from src.utils.logging import get_project_logger 11 | 12 | logger = get_project_logger(__name__) 13 | 14 | 15 | class GenerativeDataset(IterableDataset): 16 | """Abstract interface, adding two additional classmethods to the IterableDataset interface. 17 | 18 | These methods are needed for the translate script to work. 19 | """ 20 | 21 | @classmethod 22 | def from_lines(cls, lines: Iterable[str], **kwargs): 23 | raise NotImplementedError 24 | 25 | @classmethod 26 | def from_file(cls, path: str, **kwargs): 27 | raise NotImplementedError 28 | 29 | 30 | class ParallelDataset(GenerativeDataset): 31 | """Dataset useful when dealing with conditioned generation tasks (e.g. MT). 32 | 33 | The most relevant feature of this class is that its input is fed as the output a closure: from_lines and from_file 34 | setup a closure whose execution yields an Iterable[str]. This allows for the same code to be fully shared across the 35 | two modalities (and additional benefits in more complex scenarios). 
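    A minimal usage sketch (illustrative; assumes ``tokenizer`` is an already-loaded HuggingFace tokenizer
    and that the tsv path exists, e.g. after CNNDMDataModule.prepare_data has run):

        from torch.utils.data import DataLoader

        dataset = ParallelDataset.from_file(
            "data/cnn-dm/train.tsv", tokenizer=tokenizer, max_tokens_per_batch=1024
        )
        loader = DataLoader(dataset, batch_size=None, num_workers=0)
        for batch in loader:
            ...  # dict with "source"/"target" tensors, their padding masks and the raw text fields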
36 | """ 37 | 38 | @classmethod 39 | def from_lines(cls, lines: Iterable[str], **kwargs): 40 | return cls(lambda: lines, **kwargs) 41 | 42 | @classmethod 43 | def from_file(cls, path: str, **kwargs): 44 | def r(): 45 | with open(path) as f: 46 | for line in f: 47 | yield line.strip() 48 | 49 | return cls(r, **kwargs) 50 | 51 | def __init__( 52 | self, 53 | iterator_generator: Callable[[], Iterable], 54 | tokenizer: PreTrainedTokenizer, 55 | for_inference: bool = False, 56 | max_tokens_per_batch: int = 1024, 57 | min_length: int = -1, 58 | max_length: int = -1, 59 | truncate: bool = False, 60 | section_size: int = 50_000, 61 | drop_last_batch: bool = False, 62 | prebatch: bool = False, 63 | ): 64 | self.iterator_generator = iterator_generator 65 | self.tokenizer = tokenizer 66 | self.for_inference = for_inference 67 | self.max_tokens_per_batch = max_tokens_per_batch 68 | self.min_length = min_length 69 | self.max_length = max_length 70 | self.truncate = truncate 71 | self.section_size = section_size if prebatch else 1 72 | self.drop_last_batch = drop_last_batch 73 | self.prebatch = prebatch 74 | 75 | def _generate_dataset(self): 76 | def prebatch_ds(ds: List[Tuple[List[int], List[int], str, Optional[str]]]): 77 | ds = sorted( 78 | ds, key=lambda x: add_noise_to_value(len(x[0]) + len(x[1]), noise_param=0.1) 79 | ) 80 | ds = list(chunks(ds, 512)) 81 | random.shuffle(ds) 82 | return flatten(ds) 83 | 84 | logger.info("Initting dataset") 85 | 86 | discarded_due_to_min_length = 0 87 | discarded_due_to_max_length = 0 88 | 89 | read_samples = 0 90 | ds = [] 91 | 92 | for line in self.iterator_generator(): 93 | 94 | if read_samples % 10_000 == 0: 95 | logger.info(f"{read_samples} entries added to dataset") 96 | 97 | line = line.strip() 98 | parts = line.split("\t") 99 | if self.for_inference: 100 | source = parts[0] 101 | target = parts[1] if len(parts) == 2 else None 102 | else: 103 | source = parts[0] 104 | target = parts[1] 105 | 106 | # encode 107 | text_source, text_target = source, target 108 | if not self.for_inference: 109 | sample = self.tokenizer.prepare_seq2seq_batch([source], tgt_texts=[target]) 110 | source = sample["input_ids"][0] 111 | target = sample["labels"][0] 112 | else: 113 | sample = self.tokenizer.prepare_seq2seq_batch([source]) 114 | source = sample["input_ids"][0] 115 | target = [self.tokenizer.bos_token_id] 116 | 117 | # truncate if requested 118 | if self.truncate: 119 | source = source[: self.max_length] 120 | target = target[: self.max_length] 121 | 122 | # check min length 123 | if self.min_length != -1 and ( 124 | len(source) < self.min_length 125 | or (len(target) < self.min_length and not self.for_inference) 126 | ): 127 | discarded_due_to_min_length += 1 128 | if discarded_due_to_min_length % 1_000 == 0: 129 | logger.warning( 130 | f"{discarded_due_to_min_length} samples have been discarded due to being shorter than minimum length {self.min_length}" 131 | ) 132 | continue 133 | 134 | # check max length 135 | if self.max_length != -1 and ( 136 | len(source) > self.max_length or len(target) > self.max_length 137 | ): 138 | discarded_due_to_max_length += 1 139 | if discarded_due_to_max_length % 1_000 == 0: 140 | logger.warning( 141 | f"{discarded_due_to_max_length} samples have been discarded due to being longer than maximum length {self.max_length}" 142 | ) 143 | continue 144 | 145 | ds.append((source, target, text_source, text_target)) 146 | if len(ds) == self.section_size: 147 | if self.prebatch: 148 | ds = prebatch_ds(ds) 149 | yield from ds 150 | ds = [] 151 | 152 
| read_samples += 1 153 | 154 | if len(ds) > 0: 155 | if self.prebatch: 156 | ds = prebatch_ds(ds) 157 | yield from ds 158 | 159 | if discarded_due_to_min_length > 0: 160 | logger.warning( 161 | f"{discarded_due_to_min_length} samples have been discarded due to being shorter than minimum length {self.min_length}" 162 | ) 163 | 164 | if discarded_due_to_max_length > 0: 165 | logger.warning( 166 | f"{discarded_due_to_max_length} samples have been discarded due to being longer than maximum length {self.max_length}" 167 | ) 168 | 169 | def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: 170 | 171 | dataset = self._generate_dataset() 172 | 173 | batch = [] 174 | ct = 0 175 | 176 | for sample in dataset: 177 | 178 | sample_tokens = len(sample[0]) + len(sample[1]) 179 | 180 | if ( 181 | max(ct, sample_tokens) * (len(batch) + 1) > self.max_tokens_per_batch 182 | and len(batch) > 0 183 | ): 184 | yield self.prepare_output_batch(batch) 185 | batch = [] 186 | ct = 0 187 | 188 | batch.append(sample) 189 | ct = max(ct, sample_tokens) 190 | 191 | # drop last cause might be too short and result in issues (nan if we are using amp) 192 | if not self.drop_last_batch and len(batch) > 0: 193 | yield self.prepare_output_batch(batch) 194 | 195 | def prepare_output_batch( 196 | self, batch: List[Tuple[List[int], List[int], str, Optional[str]]] 197 | ) -> Dict[str, torch.Tensor]: 198 | 199 | try: 200 | pad_token_id = self.tokenizer.pad_token_id 201 | except: 202 | pad_token_id = 0 203 | 204 | # build source 205 | source = pad_sequence( 206 | [torch.tensor(e[0]) for e in batch], batch_first=True, padding_value=pad_token_id 207 | ) 208 | source_padding_mask = pad_sequence( 209 | [torch.full((len(e[0]),), fill_value=True) for e in batch], 210 | batch_first=True, 211 | padding_value=False, 212 | ) 213 | 214 | # build target 215 | target = pad_sequence( 216 | [torch.tensor(e[1]) for e in batch], batch_first=True, padding_value=pad_token_id 217 | ) 218 | target_padding_mask = pad_sequence( 219 | [torch.full((len(e[1]),), fill_value=True) for e in batch], 220 | batch_first=True, 221 | padding_value=False, 222 | ) 223 | 224 | # return 225 | return { 226 | "source": source, 227 | "source_padding_mask": source_padding_mask, 228 | "target": target, 229 | "target_padding_mask": target_padding_mask, 230 | "text_source": [e[2] for e in batch], 231 | "text_target": [e[3] for e in batch], 232 | } 233 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Grid - Seq2Seq

*Badges: PyTorch · Lightning · Config: hydra · Code style: black*

A simple template to bootstrap your Seq2Seq project

Grid-Seq2Seq (Get RID of boilerplate code for Seq2Seq) is a simple and generic template to bootstrap your next [PyTorch](https://pytorch.org) project on generative
Natural Language Processing models.
Quickly kickstart your work and get the power of:
* PyTorch and PyTorch Lightning (no comments needed here)
* Hydra, allowing you to configure every aspect of your training in readable .yaml files
* A modular skeleton that supports several frequent use cases when dealing with generative NLP models:
  * HuggingFace [Transformers](https://github.com/huggingface/transformers) integration
  * Easily swap between generative models
  * Monitor autoregressive metrics such as BLEU/ROUGE during training
  * Log generations inside [WandB](https://wandb.ai) tables during training
  * Interactive and file-based translation

If you are in a rush and have **2 minutes only**, be sure to check out the [quick tour](#quick-tour).
When all else fails, read the - rest of the - instructions (a.k.a. the [FAQs](#faq-and-use-cases)).

If you use this template, please add the following shield to your README:

[![](https://img.shields.io/badge/-Grid--Seq2Seq-blueviolet?style=for-the-badge&logo=github)](https://github.com/poccio/nlp-gen)

## Quick Tour

A quite frequent scenario in the NLP community these days: you have a parallel dataset (usually in TSV format),
representing some revolutionary way to use the latest advances in generative models, and you want to fine-tune a pretrained
encoder-decoder such as Bart on it. You are in the right place!

If you simply run the following:
```bash
PYTHONPATH=$(pwd) python src/scripts/model/train.py \
  exp_name=<experiment name> \
  project_name=<project name> \
  data=simple \
  data.datamodule.train_dataset.path=<path to train dataset> \
  data.datamodule.validation_dataset.path=<path to validation dataset> \
  device=cuda
```
you get your Bart model, magically fine-tuned on the provided parallel dataset, with everything logged on WandB
under `<project name>/<experiment name>`. Once the training finishes (and eons have passed),
you can use the interactive translate script to debug/present simple demos:
```bash
PYTHONPATH=$(pwd) python src/scripts/model/translate.py \
  -t -n 1 -g configurations/generation/beam.yaml
```
Note that the script also supports the common file-based translation.

However, **the main purpose of Grid-Seq2Seq is providing a template** to bootstrap your own project.
As such, it is quite easy to change the various components that make up the experiment to match your requirements:
* Generative models are wrapped behind the `GenerativeModel` interface. As long as your models implement this interface,
  you can easily swap among different models
* You can implement your own Datasets/Datamodules and, as long as they are compliant with your generative model, everything will
  work transparently

Additionally, we are quite proud of the callbacks system. It allows you to log, inside nice WandB tables, the generations of your model as the training
progresses and, furthermore, to use those generations to compute both referenced and unreferenced metrics.

## Grid-Seq2Seq is your library for ...

* :rocket: Quick prototyping of generative models
* :skull: Modular skeleton to bootstrap your own project
* :telephone: Callback system!
  * Use referenced metrics such as BLEU/ROUGE as validation metrics (see the sketch right after this list)
  * Check the generations of your model as training goes on

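For instance, a callback computing corpus-level BLEU over the generated validation outputs could look roughly like the sketch below. This is a minimal, hypothetical example: the class name, constructor and call signature are illustrative assumptions, while the actual contract to follow is the one used by the callbacks in `src/callbacks/generation.py`.

```python
from typing import List

from nltk.translate.bleu_score import corpus_bleu


class BleuGenerationCallback:
    """Hypothetical sketch of a referenced-metric callback (names and signature are assumptions)."""

    def __init__(self, name: str = "validation_bleu"):
        self.name = name

    def __call__(self, predictions: List[str], references: List[str], pl_module) -> None:
        # corpus_bleu expects, for each sample, a list of tokenized references and a tokenized hypothesis
        score = corpus_bleu(
            [[reference.split()] for reference in references],
            [prediction.split() for prediction in predictions],
        )
        # logging through the LightningModule makes the metric available to loggers and checkpointing
        pl_module.log(self.name, score, prog_bar=True)
```

Once implemented, the callback just needs to be listed in a Hydra callbacks configuration file (see the FAQ at the end of this README).
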
## Template Structure

```bash
.
├── configurations
│   ├── generation              # hydra generation config files
│   └── hydra-train             # hydra config files for training
│       ├── root.yaml
│       ├── callbacks
│       ├── data
│       ├── device
│       ├── model
│       └── training
├── data                        # data folder
├── experiments                 # experiments folder
├── src
│   ├── callbacks               # callbacks code
│   ├── data                    # datasets and lightning datamodules
│   ├── optim                   # optimizers' instantiation code and custom optimizers
│   │   ├── factories.py
│   │   └── optimizers
│   ├── pl_modules              # lightning modules
│   │   └── generative_models   # supported generative models wrapped behind an interface
│   ├── scripts
│   │   ├── model
│   │   │   ├── train.py        # training script
│   │   │   └── translate.py    # translation script (both interactive and file mode supported)
│   └── utils
├── README.md
├── requirements.txt
└── setup.sh                    # bash script to auto-setup the env
```

## Setup Env

To neatly set up the whole environment needed for the project (interpreter, requirements, runtime dependencies, ...),
we provide a bash script that automatically configures everything. The only actual requirement is having [conda](https://docs.conda.io/projects/conda/en/latest/index.html)
installed. Once it is, just run the script and follow the prompts (desired Python version, desired CUDA version, ...) to quickly set everything up:

```bash
bash setup.sh
```

## Usage Examples

### Train a Summarization Model in 10 seconds

We chose [Summarization with Bart](https://www.aclweb.org/anthology/2020.acl-main.703.pdf) on the CNN/DailyMail dataset as our implemented working example.

As we mentioned earlier, everything is quite modular and, in order to carry out your given experiment, you just need to:
* Implement the various building blocks (mainly the model and datamodule to use)
* Write the hydra configuration files
* Tell the training script to use them

In this case, we already took care of the first two steps; thus, we can directly jump to the actual training
(if you are not familiar with Hydra, we recommend reading the Hydra [intro tutorials](https://hydra.cc/docs/tutorials/intro) to quickly get acquainted with it),
launching the training script with arguments that instruct Hydra which components to use:

```bash
PYTHONPATH=$(pwd) python src/scripts/model/train.py \
  exp_name=bart-sum \
  project_name=<project name> \
  data=cnn_dm \
  model=bart \
  callbacks=summarization \
  device=cuda_amp
```

Once the training finally finishes (and eons have likely passed), check out the translation script. In particular, besides
the common file-based mode, it features an interactive mode that can be useful for debugging/demos with your colleagues:

```bash
PYTHONPATH=$(pwd) python src/scripts/model/translate.py \
  -t -n 1 -g configurations/generation/beam.yaml
```

## FAQ and Use Cases

**Q**: I want to use another Generative Model. How?

**A**: It depends. If your model is part of *HuggingFace Transformers*, then you're golden. You just need to wrap it
behind the GenerativeModel interface and, if needed, write a suitable matching Dataset (or override some parts such as the
encoding). Consider the case of adding GPT2:
* You need to write a Dataset tailored for causal language modelling
* You need to write your GenerativeModel (a minimal sketch is shown below)

Once you do, everything (callbacks, training, translation scripts) will work seamlessly.

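To make the *HuggingFace Transformers* case concrete, here is a rough, hypothetical sketch of what a GPT2 wrapper could look like. The class name and the `forward`/`generate` signatures below are illustrative assumptions; the actual contract is the `GenerativeModel` interface shipped with the template.

```python
import torch
from torch import nn
from transformers import AutoModelForCausalLM


class GPT2GenerativeModel(nn.Module):
    """Hypothetical sketch: the real GenerativeModel interface may expose different methods."""

    def __init__(self, model_name: str = "gpt2"):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # for causal LMs, HuggingFace computes the (shifted) language modelling loss for us
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return output.loss

    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generation_params) -> torch.Tensor:
        # decoding (beam search, sampling, ...) is delegated to the HuggingFace generate utility,
        # e.g. driven by the parameters in configurations/generation/*.yaml
        return self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generation_params)
```
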
Conversely, if your model is not part of *HuggingFace Transformers*, you may need to refactor part of the code: for example,
we currently have an explicit coupling in the training script to the *HuggingFace Transformers* Tokenizer object.
We welcome contributions in this direction.

**Q**: I want to monitor BLEU during training. How?

**A**: Check how we log ROUGE (`src.callbacks.generation.RougeGenerationCallback`); a BLEU callback is essentially identical
(the callback sketch in the library overview above follows the same idea).
Once you have implemented your GenerationCallback, you just need to add it to your Hydra callbacks configuration file.

## Contributors

* [Edoardo Barba](https://github.com/edobobo)
* [Luigi Procopio](https://github.com/poccio)

--------------------------------------------------------------------------------