├── data
│   └── .placeholder
├── src
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── datamodules
│   │   │   ├── __init__.py
│   │   │   ├── simple.py
│   │   │   └── cnn_dm.py
│   │   └── datasets
│   │       ├── __init__.py
│   │       ├── alternator.py
│   │       └── generation.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── optimizers
│   │   │   ├── __init__.py
│   │   │   └── radam.py
│   │   └── factories.py
│   ├── scripts
│   │   ├── __init__.py
│   │   └── model
│   │       ├── __init__.py
│   │       ├── train.py
│   │       └── translate.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   └── commons.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   ├── best_checkpoint.py
│   │   └── generation.py
│   └── pl_modules
│       ├── __init__.py
│       ├── generative_models
│       │   ├── __init__.py
│       │   ├── generative_model.py
│       │   └── bart.py
│       ├── utils.py
│       └── generative_pl_module.py
├── experiments
│   └── .placeholder
├── configurations
│   ├── hydra-train
│   │   ├── device
│   │   │   ├── cpu.yaml
│   │   │   ├── cuda.yaml
│   │   │   └── cuda_amp.yaml
│   │   ├── callbacks
│   │   │   ├── default.yaml
│   │   │   └── summarization.yaml
│   │   ├── model
│   │   │   └── bart.yaml
│   │   ├── data
│   │   │   ├── cnn_dm.yaml
│   │   │   └── simple.yaml
│   │   ├── root.yaml
│   │   └── training
│   │       └── default.yaml
│   └── generation
│       ├── beam.yaml
│       └── sample.yaml
├── requirements.txt
├── setup.sh
├── .gitignore
└── README.md
/data/.placeholder:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/optim/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/.placeholder:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/pl_modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/data/datamodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/optim/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/scripts/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configurations/hydra-train/device/cpu.yaml:
--------------------------------------------------------------------------------
1 | {}
--------------------------------------------------------------------------------
/src/pl_modules/generative_models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configurations/hydra-train/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | callbacks: []
--------------------------------------------------------------------------------
/configurations/hydra-train/device/cuda.yaml:
--------------------------------------------------------------------------------
1 | gpus:
2 | - 0
3 | precision: 32
4 | amp_level: 'O0'
--------------------------------------------------------------------------------
/configurations/hydra-train/device/cuda_amp.yaml:
--------------------------------------------------------------------------------
1 | gpus:
2 | - 0
3 | precision: 16
4 | amp_level: 'O1'
--------------------------------------------------------------------------------
/configurations/generation/beam.yaml:
--------------------------------------------------------------------------------
1 | num_beams: 5
2 | min_length: 5
3 | max_length: 100
4 | length_penalty: 1.0
5 | repetition_penalty: 1.0
--------------------------------------------------------------------------------
/src/utils/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def get_project_logger(module_name: str) -> logging.Logger:
5 | return logging.getLogger(f"grid_seq2seq.{module_name}")
6 |
--------------------------------------------------------------------------------
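
A quick usage sketch (not part of the repository): modules build their logger through this helper so that everything lives under the `grid_seq2seq` root logger, which the `verbose: [ grid_seq2seq ]` entry in `root.yaml` switches to verbose at training time.

```python
# Illustrative sketch only: how project modules obtain and use their logger.
from src.utils.logging import get_project_logger

logger = get_project_logger(__name__)  # -> logger named "grid_seq2seq.<module name>"
logger.info("dataset initialized")     # visible once Hydra (or logging.basicConfig) sets up handlers
```
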
/configurations/generation/sample.yaml:
--------------------------------------------------------------------------------
1 | num_beams: 1
2 | do_sample: True
3 | top_p: 0.8
4 | temperature: 1.0
5 | min_length: 25
6 | max_length: 200
7 | length_penalty: 1.0
8 | repetition_penalty: 1.0
--------------------------------------------------------------------------------
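
These generation YAMLs are plain keyword arguments forwarded to the underlying `generate` call. A minimal sketch of how they are consumed, mirroring `src/scripts/model/translate.py`; `module` below is assumed to be an already-loaded `GenerativePLModule`.

```python
# Sketch, assuming `module` is an already-loaded GenerativePLModule.
from omegaconf import OmegaConf

generation_params = OmegaConf.load("configurations/generation/sample.yaml")
module.load_generation_params(generation_params)  # later splatted into bart_model.generate(**params)
module.enable_generation_mode()                   # subsequent forwards now generate instead of teacher-forcing
```
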
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.2.3
2 | transformers==4.3.3
3 | hydra-core==1.1.0-dev4
4 | wandb==0.10.21
5 | # summarization dependencies
6 | datasets==1.4.1
7 | nltk==3.5
8 | rouge-score==0.0.4
9 | # black
10 | black==20.8b1
11 |
--------------------------------------------------------------------------------
/configurations/hydra-train/model/bart.yaml:
--------------------------------------------------------------------------------
1 | _target_: 'src.pl_modules.generative_pl_module.GenerativePLModule'
2 | generative_model:
3 | _target_: 'src.pl_modules.generative_models.bart.BartGenerativeModel'
4 | bart_model: 'facebook/bart-large'
5 | dropout: 0.1
6 |   label_smoothing: 0.1 # 0.0 for no label smoothing
--------------------------------------------------------------------------------
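
A sketch of how this model config becomes a module. It is essentially the call made in `train.py`; `_recursive_=False` matters because `GenerativePLModule` expects the nested `generative_model` entry as a config and instantiates it itself. `conf` is assumed to be the composed Hydra config.

```python
# Sketch, assuming `conf` is the composed Hydra config.
import hydra

pl_module = hydra.utils.instantiate(
    conf.model,                               # -> GenerativePLModule
    optim_conf=conf.training.trainer.optim,   # optimizer factory config, instantiated lazily later
    _recursive_=False,                        # keep `generative_model` as a config dict
)
```
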
/configurations/hydra-train/data/cnn_dm.yaml:
--------------------------------------------------------------------------------
1 | datamodule:
2 | _target_: 'src.data.datamodules.cnn_dm.CNNDMDataModule'
3 | data_dir: 'data/cnn-dm'
4 | num_workers: 0
5 | dataset:
6 | _target_: 'src.data.datasets.generation.ParallelDataset.from_file'
7 | max_tokens_per_batch: 1300
8 | min_length: 5
9 | max_length: 500
10 | truncate: true
11 | section_size: 10000
12 |
--------------------------------------------------------------------------------
/configurations/hydra-train/callbacks/summarization.yaml:
--------------------------------------------------------------------------------
1 | callbacks:
2 | - _target_: "src.callbacks.generation.TextGenerationCallback"
3 | generation_callbacks:
4 | rouge:
5 | _target_: 'src.callbacks.generation.RougeGenerationCallback'
6 | generations:
7 | - name: "beam1"
8 | glob_translate_path: "data/cnn-dm/validation.tsv"
9 | generation_param_conf_path: "configurations/generation/beam.yaml"
10 | num_sequences: 1
11 | token_batch_size: 800
12 | limit: 1000
13 | enabled_generation_callbacks:
14 | - 'rouge'
--------------------------------------------------------------------------------
/configurations/hydra-train/root.yaml:
--------------------------------------------------------------------------------
1 | project_name: grid-seq2seq
2 | exp_name: ???
3 | exp_folder: ./experiments/${exp_name}
4 |
5 | hydra:
6 | # customize working dir
7 | run:
8 | dir: ./experiments/${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
9 | # customize logging
10 | verbose: [ grid_seq2seq ]
11 | job_logging:
12 | formatters:
13 | simple:
14 | format: '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
15 | root:
16 | level: WARN
17 |
18 | # defaults
19 | defaults:
20 | - callbacks: default
21 | - data: null
22 | - device: cpu
23 | - model: bart
24 | - training: default
25 |
--------------------------------------------------------------------------------
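
As a sketch, the composed configuration can be inspected offline with Hydra's compose API; the group names below match the `defaults` list above, and `exp_name` must be supplied because it is declared as `???` (mandatory).

```python
# Sketch using Hydra's compose API (top-level in Hydra 1.1, under hydra.experimental in 1.0);
# config_path is resolved relative to the calling file.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="configurations/hydra-train"):
    conf = compose(
        config_name="root",
        overrides=["exp_name=debug", "data=simple", "device=cpu"],
    )
print(OmegaConf.to_yaml(conf))
```
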
/configurations/hydra-train/data/simple.yaml:
--------------------------------------------------------------------------------
1 | datamodule:
2 | _target_: 'src.data.datamodules.simple.SimpleDataModule'
3 | num_workers: 0
4 | train_dataset:
5 | _target_: 'src.data.datasets.generation.ParallelDataset.from_file'
6 | path: null
7 | max_tokens_per_batch: 1300
8 | min_length: 5
9 | max_length: 500
10 | truncate: true
11 | section_size: 10000
12 | validation_dataset:
13 | _target_: 'src.data.datasets.generation.ParallelDataset.from_file'
14 | path: null
15 | max_tokens_per_batch: 1300
16 | min_length: 5
17 | max_length: 500
18 | truncate: true
19 | section_size: 10000
20 |
--------------------------------------------------------------------------------
/configurations/hydra-train/training/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | # reproducibility
4 | reprod:
5 | seed: 12
6 |
7 | # optimization
8 | trainer:
9 | gradient_acc_steps: 4
10 | gradient_clip_value: 10.0
11 | max_steps: 100000
12 | val_check_interval: 1000
13 | patience: 5
14 | optim:
15 | _target_: "src.optim.factories.TorchFactory"
16 | optimizer:
17 | _target_: torch.optim.Adam
18 | lr: 1e-5
19 | checkpoint:
20 | filename: '{val_loss:.4f}'
21 | monitor: '-val_loss'
22 | save_top_k: 5
23 | save_last: true
24 |
25 | # logger
26 | logger:
27 | _target_: 'pytorch_lightning.loggers.wandb.WandbLogger'
28 | name: ${exp_name}
29 | project: ${project_name}
30 |
31 |
--------------------------------------------------------------------------------
/src/callbacks/best_checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 |
5 | from pytorch_lightning.callbacks import ModelCheckpoint
6 |
7 |
8 | class ModelCheckpointWithBest(ModelCheckpoint):
9 |     """A callback that explicitly saves the best checkpoint as best.ckpt.
10 |
11 | Note that the best checkpoint is duplicated, rather than linked, in best.ckpt
12 | """
13 |
14 | CHECKPOINT_NAME_BEST = 'best.ckpt'
15 |
16 | def on_validation_end(self, trainer, pl_module):
17 | super().on_validation_end(trainer, pl_module)
18 | if self.best_model_path == "":
19 | return
20 | orig_best = Path(self.best_model_path)
21 | shutil.copyfile(orig_best, orig_best.parent / self.CHECKPOINT_NAME_BEST)
22 |
--------------------------------------------------------------------------------
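
A minimal usage sketch; in the project the monitor, mode and dirpath come from the training config in `train.py`, so the values below are placeholders.

```python
# Sketch: after every validation, the best checkpoint is also copied to checkpoints/best.ckpt.
import pytorch_lightning as pl
from src.callbacks.best_checkpoint import ModelCheckpointWithBest

checkpoint_callback = ModelCheckpointWithBest(
    monitor="val_loss", mode="min", dirpath="checkpoints", save_top_k=5, save_last=True
)
trainer = pl.Trainer(callbacks=[checkpoint_callback])
```
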
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # setup conda
4 | source ~/miniconda3/etc/profile.d/conda.sh
5 |
6 | # create conda env
7 | read -rp "Enter environment name: " env_name
8 | read -rp "Enter python version (e.g. 3.7): " python_version
9 | conda create -yn "$env_name" python="$python_version"
10 | conda activate "$env_name"
11 |
12 | # install torch
13 | read -rp "Enter cuda version (e.g. 10.1 or none to avoid installing cuda support): " cuda_version
14 | if [ "$cuda_version" == "none" ]; then
15 | conda install -y pytorch torchvision cpuonly -c pytorch
16 | else
17 | conda install -y pytorch torchvision cudatoolkit=$cuda_version -c pytorch
18 | fi
19 |
20 | # install python requirements
21 | pip install -r requirements.txt
22 |
23 | # login into wandb
24 | read -rp "Enter wandb key: " wandb_key
25 | wandb login "$wandb_key"
26 |
--------------------------------------------------------------------------------
/src/utils/commons.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from typing import Optional, List
3 |
4 | import numpy as np
5 |
6 | from src.utils.logging import get_project_logger
7 |
8 | logger = get_project_logger(__name__)
9 |
10 |
11 | def execute_bash_command(command: str) -> Optional[str]:
12 | command_result = subprocess.run(command, shell=True, capture_output=True)
13 | try:
14 | command_result.check_returncode()
15 | return command_result.stdout.decode("utf-8")
16 | except subprocess.CalledProcessError:
17 | logger.warning(f"failed executing command: {command}")
18 | logger.warning(f"return code was: {command_result.returncode}")
19 | logger.warning(f'stdout was: {command_result.stdout.decode("utf-8")}')
20 |         logger.warning(f'stderr was: {command_result.stderr.decode("utf-8")}')
21 | return None
22 |
23 |
24 | def chunks(l: List, k: int) -> List[List]:
25 | assert k >= 1
26 | return [l[i : i + k] for i in range(0, len(l), k)]
27 |
28 |
29 | def flatten(lst: List[list]) -> list:
30 | return [_e for sub_l in lst for _e in sub_l]
31 |
32 |
33 | def add_noise_to_value(value: int, noise_param: float):
34 | noise_value = value * noise_param
35 | noise = np.random.uniform(-noise_value, noise_value)
36 | return max(1, value + noise)
37 |
--------------------------------------------------------------------------------
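
A few illustrative calls against these helpers (expected results shown as comments):

```python
# Illustrative calls; expected results shown as comments.
from src.utils.commons import chunks, flatten, execute_bash_command

chunks([1, 2, 3, 4, 5], 2)       # [[1, 2], [3, 4], [5]]
flatten([[1, 2], [3], []])       # [1, 2, 3]
execute_bash_command("echo hi")  # "hi\n"; returns None (and logs warnings) on failure
```
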
/src/data/datasets/alternator.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import random
4 | from typing import Iterator, List, Dict
5 |
6 | from torch.utils.data import IterableDataset
7 |
8 |
9 | class AlternatorIterableDataset(IterableDataset):
10 | """IterableDataset that allows for batch alternation, according to specified probabilities, among multiple IterableDataset-s.
11 |
12 | Iteration continues until all datasets are exhausted, restarting all those that finish until this condition is met.
13 | """
14 |
15 | def __init__(self, datasets: List[Dict], p: List[float], **kwargs):
16 | assert len(datasets) == len(p)
17 | self.datasets: List[IterableDataset] = [
18 | hydra.utils.instantiate(d, **kwargs) for d in datasets
19 | ]
20 | self.p = p
21 | assert np.isclose(sum(self.p), 1.0)
22 |
23 | def __iter__(self) -> Iterator:
24 |
25 | done = [False for _ in self.datasets]
26 | iterators = [iter(d) for d in self.datasets]
27 |
28 | while True:
29 |
30 | i = random.choices(list(range(len(self.datasets))), weights=self.p, k=1)[0]
31 |
32 | try:
33 | batch = next(iterators[i])
34 | except StopIteration:
35 | done[i] = True
36 | if all(done):
37 | break
38 | iterators[i] = iter(self.datasets[i])
39 | batch = next(iterators[i])
40 |
41 | yield batch
42 |
--------------------------------------------------------------------------------
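
A usage sketch: note that `datasets` is a list of Hydra configs (`_target_` dicts) rather than instances, and shared kwargs such as the tokenizer are forwarded to every instantiation; the probabilities must sum to 1. The tsv paths and `tokenizer` below are placeholders.

```python
# Sketch, assuming `tokenizer` is a loaded PreTrainedTokenizer and the tsv paths exist.
from src.data.datasets.alternator import AlternatorIterableDataset

dataset_confs = [
    {"_target_": "src.data.datasets.generation.ParallelDataset.from_file", "path": "data/a.tsv"},
    {"_target_": "src.data.datasets.generation.ParallelDataset.from_file", "path": "data/b.tsv"},
]
alternator = AlternatorIterableDataset(dataset_confs, p=[0.7, 0.3], tokenizer=tokenizer)
for batch in alternator:  # batches drawn from the two datasets with probability 0.7 / 0.3
    ...
```
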
/src/pl_modules/utils.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import hydra
3 | import torch
4 | from omegaconf import OmegaConf
5 |
6 |
7 | def label_smoothed_nll_loss(lprobs, target, epsilon, padding_mask: torch.Tensor):
8 | """
9 | Inspired by https://github.com/huggingface/transformers/blob/5148f433097915f30864bf0ca6090656fecefbb8/examples/seq2seq/utils.py
10 |
11 |     With one change, however: using mean rather than sum (nll_loss = nll...)
12 |
13 | """
14 |
15 | assert target.dim() == padding_mask.dim()
16 | if target.dim() == lprobs.dim() - 1:
17 | target = target.unsqueeze(-1)
18 | padding_mask = padding_mask.unsqueeze(-1)
19 |
20 | # compute nll loss
21 | nll_loss = -lprobs.gather(dim=-1, index=target)
22 | nll_loss.masked_fill_(~padding_mask, 0.0)
23 |
24 | # compute smooth loss
25 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
26 | smooth_loss.masked_fill_(~padding_mask, 0.0)
27 |
28 | nll_loss = nll_loss.sum() / padding_mask.sum()
29 | smooth_loss = smooth_loss.sum() / padding_mask.sum()
30 |
31 | eps_i = epsilon / lprobs.size(-1)
32 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
33 |
34 | return loss, nll_loss
35 |
36 |
37 | def load_pl_module_from_checkpoint(checkpoint_path: str):
38 | """
39 | Load a PL module from a checkpoint path only. Infer the model to load from the dumped hydra conf
40 |
41 | Args:
42 | checkpoint_path (str):
43 |
44 | Returns:
45 | pl.LightningModule
46 |
47 | """
48 |
49 | # find hydra config path
50 | hydra_config_path = "/".join(checkpoint_path.split("/")[:-2])
51 |
52 | # load hydra config
53 | conf = OmegaConf.load(f"{hydra_config_path}/.hydra/config.yaml")
54 |
55 | # instantiate and return
56 | return hydra.utils.instantiate(
57 | {"_target_": f'{conf["model"]["_target_"]}.load_from_checkpoint'},
58 | checkpoint_path=checkpoint_path,
59 | )
60 |
--------------------------------------------------------------------------------
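
A dummy-tensor sketch of `label_smoothed_nll_loss`; the shapes mirror the flattened call made in `bart.py` (log-probs of shape `(batch*seq, vocab)`, targets and padding mask of shape `(batch*seq,)`).

```python
# Dummy-tensor sketch of label_smoothed_nll_loss.
import torch
from src.pl_modules.utils import label_smoothed_nll_loss

lprobs = torch.log_softmax(torch.randn(6, 10), dim=-1)             # (batch*seq, vocab)
target = torch.randint(0, 10, (6,))                                # (batch*seq,)
padding_mask = torch.tensor([True, True, True, True, False, False])
loss, nll_loss = label_smoothed_nll_loss(lprobs, target, epsilon=0.1, padding_mask=padding_mask)
```
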
/src/data/datamodules/simple.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Optional, Dict
2 |
3 | import hydra
4 | import pytorch_lightning as pl
5 | from torch.utils.data import DataLoader
6 | from transformers import PreTrainedTokenizer
7 |
8 | from src.utils.logging import get_project_logger
9 |
10 | logger = get_project_logger(__name__)
11 |
12 |
13 | class SimpleDataModule(pl.LightningDataModule):
14 |     """Simple datamodule, useful when datamodules are not really needed (e.g. training with ParallelDataset on an already-downloaded tsv file)"""
15 |
16 | def __init__(
17 | self,
18 | tokenizer: PreTrainedTokenizer,
19 | train_dataset: Optional[Dict] = None,
20 | validation_dataset: Optional[Dict] = None,
21 | test_dataset: Optional[Dict] = None,
22 | num_workers: int = 0,
23 | ):
24 | super().__init__()
25 | self.tokenizer = tokenizer
26 | self.train_dataset_conf = train_dataset
27 | self.validation_dataset_conf = validation_dataset
28 | self.test_dataset_conf = test_dataset
29 | self.num_workers = num_workers
30 | self.train_dataset, self.validation_dataset, self.test_dataset = None, None, None
31 |
32 | def prepare_data(self, *args, **kwargs):
33 | pass
34 |
35 | def setup(self, stage: Optional[str] = None):
36 | if stage == "fit":
37 | assert self.train_dataset_conf is not None and self.validation_dataset_conf is not None
38 | self.train_dataset = hydra.utils.instantiate(
39 | self.train_dataset_conf, tokenizer=self.tokenizer, _recursive_=False
40 | )
41 | self.validation_dataset = hydra.utils.instantiate(
42 | self.validation_dataset_conf, tokenizer=self.tokenizer, _recursive_=False
43 | )
44 | else:
45 | assert self.test_dataset_conf is not None
46 | self.test_dataset = hydra.utils.instantiate(
47 | self.test_dataset_conf, tokenizer=self.tokenizer, _recursive_=False
48 | )
49 |
50 | def train_dataloader(self, *args, **kwargs) -> DataLoader:
51 | return DataLoader(self.train_dataset, batch_size=None, num_workers=self.num_workers)
52 |
53 | def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
54 | return DataLoader(self.validation_dataset, batch_size=None, num_workers=self.num_workers)
55 |
56 | def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
57 | return DataLoader(self.test_dataset, batch_size=None, num_workers=self.num_workers)
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # ide stuff
132 | .idea
133 | .vscode
134 |
135 | # custom stuff
136 | data/*
137 | experiments/*
138 | !.placeholder
--------------------------------------------------------------------------------
/src/pl_modules/generative_models/generative_model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Any, Union
3 |
4 | import torch
5 | from torch import nn
6 |
7 |
8 | @dataclass
9 | class TAGenerativeModelOutput:
10 | """Dataclass storing the returned fields of a teacher-forced forward.
11 |
12 | Args:
13 | loss (torch.Tensor): loss upon which the backprop should be performed (affected by regularizers such as label smoothing)
14 | plain_loss (torch.Tensor): plain loss unaffected by regularizers
15 | logits (torch.Tensor)
16 | predictions (torch.Tensor)
17 | """
18 |
19 | loss: torch.Tensor
20 | plain_loss: torch.Tensor
21 | logits: torch.Tensor
22 | predictions: torch.Tensor
23 |
24 |
25 | @dataclass
26 | class GenGenerativeModelOutput:
27 | """Dataclass storing the returned fields of a generative forward.
28 |
29 | Args:
30 | generation (torch.Tensor): tensor of shape (batch_size, num_sequences, sequence_length) storing the generated sequences
31 | raw (torch.Tensor): raw object storing details of the decoding processes. For example, when using HuggingFace Transformers'
32 | models, it can be used to store attentions, generation scores, ...
33 | """
34 |
35 | generation: torch.Tensor
36 | raw: Any
37 |
38 |
39 | class GenerativeModel(nn.Module):
40 | """Abstract class denoting the interface of a generative model.
41 |
42 | This class essentially alternates between two states, teacher-forced or generative, and the current state regulates
43 |     the forward method. The teacher-forced state is the one expected to be active at training time: given the input, the
44 | model will compute and return a teacher-forced loss. Conversely, the generative state is for generation: the model will
45 | generate sequences conditioned on the provided input.
46 |
47 | Besides the forward method, GenerativeModel offers hooks to switch between the two modes and an additional one to
48 | load generation-time parameters.
49 | """
50 |
51 | def __init__(self):
52 | super().__init__()
53 | self.generation_mode = False
54 | self.generation_params = None
55 |
56 | def enable_generation_mode(self):
57 | """Enables generation mode for the model."""
58 | self.generation_mode = True
59 |
60 | def disable_generation_mode(self):
61 | """Disables generation mode for the model."""
62 | self.generation_mode = False
63 |
64 | def load_generation_params(self, generation_params: Dict[str, Any]):
65 | """Loads and stores the given generation params, that will be used in the next generation forwards"""
66 | self.generation_params = generation_params
67 |
68 | def forward(self, *args, **kwargs) -> Union[TAGenerativeModelOutput, GenGenerativeModelOutput]:
69 | raise NotImplementedError
70 |
--------------------------------------------------------------------------------
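
A toy subclass sketch (not the project's model, which is `BartGenerativeModel` below) showing how `generation_mode` selects which output dataclass `forward` returns.

```python
# Toy subclass sketch, only to illustrate the two-state contract of GenerativeModel.
import torch

from src.pl_modules.generative_models.generative_model import (
    GenerativeModel,
    TAGenerativeModelOutput,
    GenGenerativeModelOutput,
)


class EchoGenerativeModel(GenerativeModel):
    def forward(self, source: torch.Tensor, **kwargs):
        if self.generation_mode:
            # "generate" by echoing the source back as one sequence per sample
            return GenGenerativeModelOutput(generation=source.unsqueeze(1), raw=None)
        logits = torch.zeros(*source.shape, 8)  # (batch, seq, dummy vocab of size 8)
        loss = torch.zeros(())
        return TAGenerativeModelOutput(
            loss=loss, plain_loss=loss, logits=logits, predictions=logits.argmax(-1)
        )
```
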
/src/pl_modules/generative_pl_module.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 | import hydra
4 | import pytorch_lightning as pl
5 | import torch
6 |
7 |
8 | class GenerativePLModule(pl.LightningModule):
9 | """Standard LightningModule wrapping a GenerativeModel."""
10 |
11 | def __init__(self, generative_model: Dict, optim_conf: Dict):
12 |
13 | super().__init__()
14 | self.save_hyperparameters()
15 |
16 | # generative model & optim
17 | self.generative_model = hydra.utils.instantiate(generative_model)
18 | self._optim_conf = optim_conf
19 |
20 | # metrics
21 | self.train_acc = pl.metrics.Accuracy()
22 | self.val_acc = pl.metrics.Accuracy()
23 |
24 | @property
25 | def tokenizer(self):
26 | return self.generative_model.tokenizer
27 |
28 | def enable_generation_mode(self):
29 | self.generative_model.enable_generation_mode()
30 |
31 | def disable_generation_mode(self):
32 | self.generative_model.disable_generation_mode()
33 |
34 | def load_generation_params(self, generation_params: Dict):
35 | self.generative_model.load_generation_params(generation_params)
36 |
37 | def forward(self, *args, **kwargs):
38 | return self.generative_model(*args, **kwargs)
39 |
40 | def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
41 |
42 | forward_output = self.forward(**batch)
43 |
44 | # loss
45 | self.log("train_loss", forward_output.loss, prog_bar=True, on_step=False, on_epoch=True)
46 |
47 | # perplexity
48 | self.log(
49 | "train_ppl", torch.exp(forward_output.plain_loss), prog_bar=True, on_step=False, on_epoch=True
50 | )
51 |
52 | # accuracy
53 | padding_mask = batch["target_padding_mask"][:, 1:].reshape(-1)
54 | train_acc = self.train_acc(
55 | forward_output.predictions.view(-1)[padding_mask],
56 | batch["target"][:, 1:].reshape(-1)[padding_mask],
57 | )
58 | self.log("train_accuracy", train_acc, prog_bar=True, on_step=False, on_epoch=True)
59 |
60 | # return
61 | return forward_output.loss
62 |
63 | def validation_step(self, batch: Dict[str, Any], batch_idx: int):
64 |
65 | forward_output = self.forward(**batch)
66 |
67 | # loss
68 | self.log("val_loss", forward_output.loss, prog_bar=True, on_step=False, on_epoch=True)
69 |
70 | # perplexity
71 | self.log(
72 | "val_ppl", torch.exp(forward_output.plain_loss), prog_bar=True, on_step=False, on_epoch=True
73 | )
74 |
75 | # accuracy
76 | padding_mask = batch["target_padding_mask"][:, 1:].reshape(-1)
77 | val_acc = self.val_acc(
78 | forward_output.predictions.view(-1)[padding_mask],
79 | batch["target"][:, 1:].reshape(-1)[padding_mask],
80 | )
81 | self.log("val_accuracy", val_acc, prog_bar=True, on_step=False, on_epoch=True)
82 |
83 | # return
84 | return forward_output.loss
85 |
86 | def configure_optimizers(self):
87 | return hydra.utils.instantiate(self._optim_conf, _recursive_=False)(module=self)
88 |
--------------------------------------------------------------------------------
/src/data/datamodules/cnn_dm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union, List, Optional, Dict
3 |
4 | import hydra
5 | import pytorch_lightning as pl
6 | from datasets import load_dataset
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 | from transformers import PreTrainedTokenizer
10 |
11 | from src.utils.logging import get_project_logger
12 |
13 | logger = get_project_logger(__name__)
14 |
15 |
16 | class CNNDMDataModule(pl.LightningDataModule):
17 | def __init__(
18 | self, tokenizer: PreTrainedTokenizer, data_dir: str, dataset: Dict, num_workers: int = 0
19 | ):
20 | super().__init__()
21 | self.tokenizer = tokenizer
22 | self.data_dir = data_dir
23 | self.num_workers = num_workers
24 | self.dataset = dataset
25 | self.train, self.val, self.test = None, None, None
26 |
27 | def prepare_data(self, *args, **kwargs):
28 |
29 | # check if download is needed
30 | if os.path.exists(self.data_dir):
31 | if os.path.exists(f"{self.data_dir}/LOCK"):
32 | return
33 | else:
34 | logger.warning(
35 | "Dataset folder already exists, but written dataset seems incomplete. Overwriting it"
36 | )
37 | else:
38 | os.mkdir(self.data_dir)
39 |
40 | # download and save cnn-datamodule to disk
41 | logger.info("Downloading CNN-DM corpus")
42 | dataset = load_dataset("cnn_dailymail", "3.0.0")
43 | for split in ["train", "validation", "test"]:
44 | with open(f"{self.data_dir}/{split}.tsv", "w") as f:
45 | for sample in tqdm(dataset[split], desc=f"Writing CNN-DM {split} split to disk"):
46 | article, summary = (
47 | sample["article"].replace("\n", ". "),
48 | sample["highlights"].replace("\n", ". "),
49 | )
50 | f.write(f"{article}\t{summary}\n")
51 |
52 | # dump lock to skip future re-downloads
53 | with open(f"{self.data_dir}/LOCK", "w") as _:
54 | pass
55 |
56 | def setup(self, stage: Optional[str] = None):
57 | if stage == "fit":
58 | self.train = hydra.utils.instantiate(
59 | self.dataset, path=f"{self.data_dir}/train.tsv", tokenizer=self.tokenizer
60 | )
61 | self.val = hydra.utils.instantiate(
62 | self.dataset, path=f"{self.data_dir}/validation.tsv", tokenizer=self.tokenizer
63 | )
64 | else:
65 | self.test = hydra.utils.instantiate(
66 | self.dataset, path=f"{self.data_dir}/test.tsv", tokenizer=self.tokenizer
67 | )
68 |
69 | def train_dataloader(self, *args, **kwargs) -> DataLoader:
70 | return DataLoader(self.train, batch_size=None, num_workers=self.num_workers)
71 |
72 | def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
73 | return DataLoader(self.val, batch_size=None, num_workers=self.num_workers)
74 |
75 | def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
76 | return DataLoader(self.test, batch_size=None, num_workers=self.num_workers)
77 |
--------------------------------------------------------------------------------
/src/optim/factories.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import hydra
4 | import torch
5 | from omegaconf import DictConfig
6 |
7 | from src.optim.optimizers.radam import RAdam
8 |
9 |
10 | class Factory:
11 | """Factory interface that allows for simple instantiation of optimizers and schedulers for PyTorch Lightning.
12 |
13 | This class is essentially a work-around for lazy instantiation:
14 | * all params but for the module to be optimized are received in __init__
15 | * the actual instantiation of optimizers and schedulers takes place in the __call__ method, where the module to be
16 | optimized is provided
17 |
18 |     __call__ will be invoked in the configure_optimizers hooks of LightningModule-s and its return object directly returned.
19 | As such, the return type of __call__ can be any of those allowed by configure_optimizers, namely:
20 | * Single optimizer
21 | * List or Tuple - List of optimizers
22 | * Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict)
23 | * Dictionary, with an ‘optimizer’ key, and (optionally) a ‘lr_scheduler’ key whose value is a single LR scheduler or lr_dict
24 | * Tuple of dictionaries as described, with an optional ‘frequency’ key
25 | * None - Fit will run without any optimizer
26 | """
27 |
28 | def __call__(self, module: torch.nn.Module):
29 | raise NotImplementedError
30 |
31 |
32 | class TorchFactory(Factory):
33 | """Simple factory wrapping standard PyTorch optimizers and schedulers."""
34 |
35 | # todo add scheduler support as well
36 |
37 | def __init__(self, optimizer: DictConfig):
38 | self.optimizer = optimizer
39 |
40 | def __call__(self, module: torch.nn.Module):
41 | return hydra.utils.instantiate(self.optimizer, params=module.parameters())
42 |
43 |
44 | class RAdamWithDecayFactory(Factory):
45 | """Factory for RAdam optimizer."""
46 |
47 | def __init__(self, lr: float, weight_decay: float, no_decay_params: Optional[List[str]]):
48 | self.lr = lr
49 | self.weight_decay = weight_decay
50 | self.no_decay_params = no_decay_params
51 |
52 | def __call__(self, module: torch.nn.Module):
53 |
54 | if self.no_decay_params is not None:
55 |
56 | optimizer_grouped_parameters = [
57 | {
58 | "params": [
59 | p
60 | for n, p in module.named_parameters()
61 | if not any(nd in n for nd in self.no_decay_params)
62 | ],
63 | "weight_decay": self.weight_decay,
64 | },
65 | {
66 | "params": [
67 | p
68 | for n, p in module.named_parameters()
69 | if any(nd in n for nd in self.no_decay_params)
70 | ],
71 | "weight_decay": 0.0,
72 | },
73 | ]
74 |
75 | else:
76 |
77 | optimizer_grouped_parameters = [
78 | {"params": module.parameters(), "weight_decay": self.weight_decay}
79 | ]
80 |
81 | optimizer = RAdam(optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay)
82 |
83 | return optimizer
84 |
--------------------------------------------------------------------------------
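
A sketch of the lazy-instantiation flow: the factory is built from configuration without the module, then called with it later. This is what `GenerativePLModule.configure_optimizers` does, except that there the factory itself is also instantiated through Hydra.

```python
# Sketch: build a TorchFactory from a plain config and apply it to a module.
import torch
from omegaconf import OmegaConf
from src.optim.factories import TorchFactory

factory = TorchFactory(OmegaConf.create({"_target_": "torch.optim.Adam", "lr": 1e-5}))
module = torch.nn.Linear(4, 2)  # any nn.Module; in practice the LightningModule itself
optimizer = factory(module)     # -> torch.optim.Adam over module.parameters()
```
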
/src/scripts/model/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import hydra
4 | import omegaconf
5 | import pytorch_lightning as pl
6 | from pytorch_lightning.callbacks import EarlyStopping
7 |
8 | from src.callbacks.best_checkpoint import ModelCheckpointWithBest
9 |
10 |
11 | def train(conf: omegaconf.DictConfig) -> None:
12 |
13 | # reproducibility
14 | pl.seed_everything(conf.training.reprod.seed)
15 |
16 | # main module declaration
17 | pl_module = hydra.utils.instantiate(
18 | conf.model, optim_conf=conf.training.trainer.optim, _recursive_=False
19 | )
20 |
21 | # data_module declaration
22 | pl_data_module = hydra.utils.instantiate(
23 | conf.data.datamodule,
24 | tokenizer=pl_module.tokenizer, # todo bad coupling towards huggingface
25 | _recursive_=False,
26 | )
27 |
28 | # callbacks
29 |
30 | callbacks = []
31 |
32 | # callbacks: checkpoint and early stopping
33 |
34 | monitor = conf.training.trainer.checkpoint.monitor
35 | assert monitor[0] in ["-", "+"]
36 | mode = "min" if monitor[0] == "-" else "max"
37 | monitor = monitor[1:]
38 |
39 | callbacks.append(
40 | ModelCheckpointWithBest(
41 | monitor=monitor,
42 | mode=mode,
43 |             dirpath="checkpoints",
44 | filename=conf.training.trainer.checkpoint.filename,
45 | save_top_k=conf.training.trainer.checkpoint.save_top_k,
46 | save_last=conf.training.trainer.checkpoint.save_last,
47 | verbose=True,
48 | )
49 | )
50 |
51 | patience = conf.training.trainer.patience
52 | if patience is not None:
53 | callbacks.append(
54 | EarlyStopping(monitor=monitor, mode=mode, patience=conf.training.trainer.patience)
55 | )
56 |
57 | # custom callbacks
58 |
59 | for callback in conf.callbacks.callbacks:
60 | callbacks.append(hydra.utils.instantiate(callback))
61 |
62 | # instantiate trainer logger
63 | logger = hydra.utils.instantiate(conf.training.logger)
64 |
65 | # trainer
66 | trainer = pl.Trainer(
67 | **conf.device,
68 | accumulate_grad_batches=conf.training.trainer.gradient_acc_steps,
69 | gradient_clip_val=conf.training.trainer.gradient_clip_value,
70 | max_steps=conf.training.trainer.max_steps,
71 | val_check_interval=conf.training.trainer.val_check_interval,
72 | logger=logger,
73 | callbacks=callbacks,
74 | )
75 |
76 | # module fit
77 | trainer.fit(pl_module, datamodule=pl_data_module)
78 |
79 |
80 | @hydra.main(config_path="../../../configurations/hydra-train", config_name="root")
81 | def main(conf: omegaconf.DictConfig):
82 |
83 | # fix paths
84 |
85 | def fix(conf):
86 | if type(conf) == list or type(conf) == omegaconf.listconfig.ListConfig:
87 | for i in range(len(conf)):
88 | conf[i] = fix(conf[i])
89 | return conf
90 | elif type(conf) == dict or type(conf) == omegaconf.dictconfig.DictConfig:
91 | for k, v in conf.items():
92 | conf[k] = fix(v)
93 | return conf
94 | elif type(conf) == str:
95 | if "/" in conf and os.path.exists(
96 | hydra.utils.to_absolute_path(conf[: conf.rindex("/")])
97 | ):
98 | return hydra.utils.to_absolute_path(conf)
99 | else:
100 | return conf
101 | elif type(conf) in [float, int, bool]:
102 | return conf
103 | else:
104 | raise ValueError(f"Unexpected type {type(conf)}: {conf}")
105 |
106 | fix(conf)
107 |
108 | # actual train
109 | train(conf)
110 |
111 |
112 | if __name__ == "__main__":
113 | main()
114 |
--------------------------------------------------------------------------------
/src/pl_modules/generative_models/bart.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | from transformers import AutoTokenizer, BartForConditionalGeneration, AutoConfig
6 |
7 | from src.pl_modules.generative_models.generative_model import (
8 | GenerativeModel,
9 | TAGenerativeModelOutput,
10 | GenGenerativeModelOutput,
11 | )
12 | from src.pl_modules.utils import label_smoothed_nll_loss
13 |
14 |
15 | class BartGenerativeModel(GenerativeModel):
16 | def __init__(self, bart_model: str, dropout: float, label_smoothing: float):
17 |
18 | super().__init__()
19 |
20 | self.tokenizer = AutoTokenizer.from_pretrained(bart_model, add_prefix_space=True)
21 | self.config = AutoConfig.from_pretrained(bart_model, dropout=dropout)
22 | self.bart_model = BartForConditionalGeneration.from_pretrained(
23 | bart_model, config=self.config
24 | )
25 | self._dropout = dropout
26 | self._label_smoothing = label_smoothing
27 |
28 | # metrics
29 | self.train_acc = pl.metrics.Accuracy()
30 | self.val_acc = pl.metrics.Accuracy()
31 |
32 | def forward(
33 | self,
34 | source: torch.Tensor,
35 | source_padding_mask: torch.Tensor,
36 | target: torch.Tensor,
37 | target_padding_mask: torch.Tensor,
38 | num_sequences: int = 1,
39 | **kwargs
40 | ) -> Union[TAGenerativeModelOutput, GenGenerativeModelOutput]:
41 |
42 | if (
43 | target.shape[1] > 1
44 | ): # training-phase: "target" is provided and we can use the "teacher-forcing" strategy.
45 |
46 | assert (
47 | not self.generation_mode
48 | ), 'The "target" is not empty but the GenerativeModel is in "generation mode"'
49 |
50 | # build target&labels
51 | decoder_input_ids = target[:, :-1].contiguous()
52 | decoder_padding_mask = target_padding_mask[:, :-1].contiguous()
53 | labels = target[:, 1:].contiguous()
54 | labels_padding_mask = target_padding_mask[:, 1:].contiguous()
55 | labels[~labels_padding_mask] = -100
56 |
57 | # actual forward
58 | result = self.bart_model(
59 | input_ids=source,
60 | attention_mask=source_padding_mask,
61 | decoder_input_ids=decoder_input_ids,
62 | decoder_attention_mask=decoder_padding_mask,
63 | labels=labels,
64 | )
65 | logits = result[1]
66 |
67 | # compute loss with label smoothing
68 | labels[~labels_padding_mask] = self.tokenizer.pad_token_id
69 | log_probs = torch.log_softmax(logits, dim=-1)
70 | smoothed_loss, nll_loss = label_smoothed_nll_loss(
71 | log_probs.view(-1, log_probs.shape[2]),
72 | labels.view(-1),
73 | self._label_smoothing,
74 | padding_mask=labels_padding_mask.view(-1),
75 | )
76 |
77 | # return
78 | return TAGenerativeModelOutput(loss=smoothed_loss, plain_loss=nll_loss, logits=logits, predictions=logits.argmax(-1))
79 |
80 |         else:  # autoregressive-phase: the only token in target is "begin of sequence" (<s> for BART).
81 |
82 | assert self.generation_mode
83 |
84 | assert target.shape[1] == 1
85 | assert len(set(target[:, 0].tolist())) == 1
86 |
87 | generation = self.bart_model.generate(
88 | input_ids=source,
89 | attention_mask=source_padding_mask,
90 | num_return_sequences=num_sequences,
91 | decoder_start_token_id=target[0][0].item(),
92 | return_dict_in_generate=True,
93 | **self.generation_params
94 | )
95 |
96 | return GenGenerativeModelOutput(
97 | generation=generation.sequences[:, 1:].reshape(source.shape[0], num_sequences, -1),
98 | raw=generation,
99 | )
100 |
--------------------------------------------------------------------------------
/src/optim/optimizers/radam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class RAdam(Optimizer):
7 | def __init__(
8 | self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True
9 | ):
10 | if not 0.0 <= lr:
11 | raise ValueError("Invalid learning rate: {}".format(lr))
12 | if not 0.0 <= eps:
13 | raise ValueError("Invalid epsilon value: {}".format(eps))
14 | if not 0.0 <= betas[0] < 1.0:
15 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
16 | if not 0.0 <= betas[1] < 1.0:
17 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
18 |
19 | self.degenerated_to_sgd = degenerated_to_sgd
20 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
21 | for param in params:
22 | if "betas" in param and (
23 | param["betas"][0] != betas[0] or param["betas"][1] != betas[1]
24 | ):
25 | param["buffer"] = [[None, None, None] for _ in range(10)]
26 | defaults = dict(
27 | lr=lr,
28 | betas=betas,
29 | eps=eps,
30 | weight_decay=weight_decay,
31 | buffer=[[None, None, None] for _ in range(10)],
32 | )
33 | super(RAdam, self).__init__(params, defaults)
34 |
35 | def __setstate__(self, state):
36 | super(RAdam, self).__setstate__(state)
37 |
38 | def step(self, closure=None):
39 |
40 | loss = None
41 | if closure is not None:
42 | loss = closure()
43 |
44 | for group in self.param_groups:
45 |
46 | for p in group["params"]:
47 | if p.grad is None:
48 | continue
49 | grad = p.grad.data.float()
50 | if grad.is_sparse:
51 | raise RuntimeError("RAdam does not support sparse gradients")
52 |
53 | p_data_fp32 = p.data.float()
54 |
55 | state = self.state[p]
56 |
57 | if len(state) == 0:
58 | state["step"] = 0
59 | state["exp_avg"] = torch.zeros_like(p_data_fp32)
60 | state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
61 | else:
62 | state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
63 | state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
64 |
65 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
66 | beta1, beta2 = group["betas"]
67 |
68 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
69 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
70 |
71 | state["step"] += 1
72 | buffered = group["buffer"][int(state["step"] % 10)]
73 | if state["step"] == buffered[0]:
74 | N_sma, step_size = buffered[1], buffered[2]
75 | else:
76 | buffered[0] = state["step"]
77 | beta2_t = beta2 ** state["step"]
78 | N_sma_max = 2 / (1 - beta2) - 1
79 | N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
80 | buffered[1] = N_sma
81 |
82 | # more conservative since it's an approximated value
83 | if N_sma >= 5:
84 | step_size = math.sqrt(
85 | (1 - beta2_t)
86 | * (N_sma - 4)
87 | / (N_sma_max - 4)
88 | * (N_sma - 2)
89 | / N_sma
90 | * N_sma_max
91 | / (N_sma_max - 2)
92 | ) / (1 - beta1 ** state["step"])
93 | elif self.degenerated_to_sgd:
94 | step_size = 1.0 / (1 - beta1 ** state["step"])
95 | else:
96 | step_size = -1
97 | buffered[2] = step_size
98 |
99 | # more conservative since it's an approximated value
100 | if N_sma >= 5:
101 | if group["weight_decay"] != 0:
102 | p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
103 | denom = exp_avg_sq.sqrt().add_(group["eps"])
104 | p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom)
105 | p.data.copy_(p_data_fp32)
106 | elif step_size > 0:
107 | if group["weight_decay"] != 0:
108 | p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
109 | p_data_fp32.add_(-step_size * group["lr"], exp_avg)
110 | p.data.copy_(p_data_fp32)
111 |
112 | return loss
113 |
--------------------------------------------------------------------------------
/src/callbacks/generation.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import itertools
3 | from multiprocessing import Pool
4 | from typing import List, Dict, Any, Tuple, Callable, Optional
5 |
6 | import pytorch_lightning as pl
7 | import torch
8 | import wandb
9 | from datasets import load_metric
10 |
11 | from src.scripts.model.translate import translate
12 |
13 | from src.utils.logging import get_project_logger
14 |
15 | logger = get_project_logger(__name__)
16 |
17 |
18 | class GenerationCallback:
19 | def __call__(
20 | self,
21 | name: str,
22 | translations: List[Tuple[str, List[str], Optional[str]]],
23 | module: pl.LightningModule,
24 | ):
25 | raise NotImplementedError
26 |
27 |
28 | class RougeGenerationCallback(GenerationCallback):
29 | def __init__(self):
30 | self.rouge = load_metric("rouge")
31 |
32 | def __call__(
33 | self,
34 | name: str,
35 | translations: List[Tuple[str, List[str], Optional[str]]],
36 | module: pl.LightningModule,
37 | ):
38 | assert all(t[2] is not None for t in translations)
39 | results = self.rouge.compute(
40 | predictions=[t[1][0] for t in translations], references=[t[2] for t in translations]
41 | )
42 | for k, v in results.items():
43 | module.log(
44 | f"val_{name}_{k}", v.mid.fmeasure, prog_bar=True, on_step=False, on_epoch=True
45 | )
46 |
47 |
48 | class TextGenerationCallback(pl.Callback):
49 | def __init__(
50 | self, generation_callbacks: Dict[str, GenerationCallback], generations: List[Dict[str, Any]]
51 | ):
52 | self._epoch = 0
53 | self.generation_callbacks = generation_callbacks
54 | self.generations_confs = []
55 | for g in generations:
56 | self.generations_confs.append(
57 | (
58 | g["name"],
59 | g["glob_translate_path"],
60 | g["generation_param_conf_path"],
61 | g["num_sequences"],
62 | g["token_batch_size"],
63 | g["limit"],
64 | g["enabled_generation_callbacks"],
65 | )
66 | )
67 |
68 | def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
69 |
70 | wandb_table = wandb.Table(columns=["Configuration", "Source", "Input", "Pred", "Gold"])
71 | logger.info("Executing translation callback")
72 |
73 | for (
74 | name,
75 | glob_translate_path,
76 | generation_param_conf_path,
77 | num_sequences,
78 | token_batch_size,
79 | limit,
80 | enabled_generation_callbacks,
81 | ) in self.generations_confs:
82 |
83 | translation_pairs = []
84 |
85 | # translate
86 |
87 | for translation_path in glob.iglob(glob_translate_path):
88 |
89 | logger.info(
90 | f"Translating translation path {translation_path} for configuration {name}"
91 | )
92 | source_type = translation_path.split("/")[-1][:-4]
93 |
94 | with open(translation_path) as f:
95 |
96 | # read sources
97 | iterator = map(lambda l: l.strip(), f)
98 |
99 | # do only a dry run on first epoch (correspond to sanity check run)
100 | if self._epoch == 0:
101 | iterator = itertools.islice(iterator, 5)
102 |
103 | # apply limit
104 | if limit != -1:
105 | iterator = itertools.islice(iterator, limit)
106 |
107 | for i, (source, sample_translations, gold_output) in enumerate(
108 | translate(
109 | pl_module,
110 | iterator,
111 | num_sequences=num_sequences,
112 | generation_param_conf_path=generation_param_conf_path,
113 | token_batch_size=token_batch_size,
114 | )
115 | ):
116 | if i % 100 == 0:
117 | logger.debug(
118 | f"Translating translation path {translation_path} for configuration {name}: {i} lines translated"
119 | )
120 |
121 | for translation in sample_translations:
122 | wandb_table.add_data(
123 | name, source_type, source, translation, gold_output
124 | )
125 |
126 | translation_pairs.append((source, sample_translations, gold_output))
127 |
128 | if self._epoch == 0:
129 | # do only a dry run on first epoch (correspond to sanity check run)
130 | break
131 |
132 | # run callbacks
133 |
134 | for callback in enabled_generation_callbacks:
135 | self.generation_callbacks[callback](name, translation_pairs, pl_module)
136 |
137 | if self._epoch > 0:
138 | trainer.logger.experiment.log({"translations": wandb_table})
139 |
140 | logger.info("Translation callback completed")
141 | self._epoch += 1
142 |
--------------------------------------------------------------------------------
/src/scripts/model/translate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import List, Iterable, Tuple, Optional
3 |
4 | import pytorch_lightning as pl
5 | import torch
6 | from omegaconf import OmegaConf
7 | from torch.cuda.amp import autocast
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 |
11 | from src.data.datasets.generation import ParallelDataset
12 | from src.pl_modules.utils import load_pl_module_from_checkpoint
13 |
14 |
15 | def translate(
16 | module: pl.LightningModule,
17 | sources: Iterable[str],
18 | num_sequences: int,
19 | generation_param_conf_path: str,
20 | token_batch_size: int = 1024,
21 | progress_bar: bool = False,
22 | ) -> Iterable[Tuple[str, List[str], Optional[str]]]:
23 |
24 | module.enable_generation_mode()
25 | module.load_generation_params(OmegaConf.load(generation_param_conf_path))
26 |
27 | # todo only works on single gpu
28 | device = next(module.parameters()).device
29 |
30 | # todo unnecessary coupling toward ParallelDataset
31 | # rather, should find and instantiate the appropriate GenerativeDataset
32 | dataset = ParallelDataset.from_lines(
33 | sources,
34 | tokenizer=module.tokenizer,
35 | for_inference=True,
36 | max_tokens_per_batch=token_batch_size,
37 | drop_last_batch=False,
38 | )
39 | dataloader = DataLoader(dataset, batch_size=None, num_workers=0)
40 |
41 | iterator = dataloader
42 | if progress_bar:
43 | iterator = tqdm(iterator, desc="Translating")
44 |
45 | for batch in iterator:
46 |
47 | # translate
48 | with autocast(enabled=True):
49 | with torch.no_grad():
50 | batch_generations = module(
51 | **{k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()},
52 | num_sequences=num_sequences,
53 | )
54 | batch_input = batch["text_source"]
55 | batch_gold_output = batch["text_target"]
56 | batch_generations = batch_generations.generation
57 |
58 | # generate
59 | for sample_input, sample_generations, sample_gold_output in zip(
60 | batch_input, batch_generations, batch_gold_output
61 | ):
62 | decoded_sample_generations = []
63 | for sample_generation in module.tokenizer.batch_decode(
64 | sample_generations, clean_up_tokenization_spaces=False
65 | ):
66 | if module.tokenizer.eos_token in sample_generation:
67 | sample_generation = sample_generation[
68 | : sample_generation.index(module.tokenizer.eos_token)
69 | + len(module.tokenizer.eos_token)
70 | ]
71 | decoded_sample_generations.append(sample_generation)
72 |
73 | yield sample_input, decoded_sample_generations, sample_gold_output
74 |
75 | module.disable_generation_mode()
76 |
77 |
78 | def interactive_main(
79 | model_checkpoint_path: str,
80 | num_sequences: int,
81 | generation_param_conf_path: str,
82 | cuda_device: int,
83 | ):
84 |
85 | model = load_pl_module_from_checkpoint(model_checkpoint_path)
86 | model.to(torch.device(cuda_device if cuda_device != -1 else "cpu"))
87 | model.eval()
88 | model.freeze()
89 |
90 | while True:
91 | source = input("Enter source text: ").strip()
92 | _, predictions, _ = next(
93 | translate(
94 | model,
95 | [source],
96 | num_sequences=num_sequences,
97 | generation_param_conf_path=generation_param_conf_path,
98 | )
99 | )
100 | for i, prediction in enumerate(predictions):
101 | print(f"\t# prediction-{i}: \t{prediction}")
102 |
103 |
104 | def file_main(
105 | model_checkpoint_path: str,
106 | input_path: str,
107 | output_path: str,
108 | num_sequences: int,
109 | generation_param_conf_path: str,
110 | cuda_device: int,
111 | token_batch_size: int,
112 | ):
113 |
114 | model = load_pl_module_from_checkpoint(model_checkpoint_path)
115 | model.to(torch.device(cuda_device if cuda_device != -1 else "cpu"))
116 | model.eval()
117 | model.freeze()
118 |
119 | with open(input_path) as fi, open(output_path, "w") as fo:
120 | for source, sample_translations, _ in translate(
121 | model,
122 | map(lambda l: l.strip(), fi),
123 | num_sequences=num_sequences,
124 | generation_param_conf_path=generation_param_conf_path,
125 | token_batch_size=token_batch_size,
126 | progress_bar=True,
127 | ):
128 | for translation in sample_translations:
129 | fo.write(f"{source}\t{translation.strip()}\n")
130 |
131 |
132 | def main():
133 | args = parse_args()
134 | if args.t:
135 | interactive_main(
136 | args.model_checkpoint,
137 | num_sequences=args.n,
138 | generation_param_conf_path=args.g,
139 | cuda_device=args.cuda_device,
140 | )
141 | else:
142 | file_main(
143 | args.model_checkpoint,
144 | args.f,
145 | args.o,
146 | num_sequences=args.n,
147 | generation_param_conf_path=args.g,
148 | cuda_device=args.cuda_device,
149 | token_batch_size=args.token_batch_size,
150 | )
151 |
152 |
153 | def parse_args():
154 | parser = argparse.ArgumentParser()
155 | parser.add_argument("model_checkpoint", type=str, help="Path to pl_modules checkpoint")
156 | parser.add_argument("-n", type=int, default=1, help="Num sequences")
157 | parser.add_argument("-g", type=str, help="Path to generation conf")
158 | parser.add_argument("--cuda-device", type=int, default=-1, help="Cuda device")
159 | # interactive params
160 | parser.add_argument("-t", action="store_true", help="Interactive mode")
161 | # generation params
162 | parser.add_argument("-f", type=str, default=None, help="Input file")
163 | parser.add_argument("-o", type=str, default=None, help="Output file")
164 | parser.add_argument("--token-batch-size", type=int, default=128, help="Token batch size")
165 | # return
166 | return parser.parse_args()
167 |
168 |
169 | if __name__ == "__main__":
170 | main()
171 |
--------------------------------------------------------------------------------
/src/data/datasets/generation.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Iterator, Dict, List, Tuple, Iterable, Callable, Optional
3 |
4 | import torch
5 | from torch.nn.utils.rnn import pad_sequence
6 | from torch.utils.data import IterableDataset
7 | from transformers import PreTrainedTokenizer
8 |
9 | from src.utils.commons import add_noise_to_value, chunks, flatten
10 | from src.utils.logging import get_project_logger
11 |
12 | logger = get_project_logger(__name__)
13 |
14 |
15 | class GenerativeDataset(IterableDataset):
16 | """Abstract interface, adding two additional classmethods to the IterableDataset interface.
17 |
18 | These methods are needed for the translate script to work.
19 | """
20 |
21 | @classmethod
22 | def from_lines(cls, lines: Iterable[str], **kwargs):
23 | raise NotImplementedError
24 |
25 | @classmethod
26 | def from_file(cls, path: str, **kwargs):
27 | raise NotImplementedError
28 |
29 |
30 | class ParallelDataset(GenerativeDataset):
31 | """Dataset useful when dealing with conditioned generation tasks (e.g. MT).
32 |
33 |     The most relevant feature of this class is that its input is fed as the output of a closure: from_lines and from_file
34 |     set up a closure whose execution yields an Iterable[str]. This allows for the same code to be fully shared across the
35 | two modalities (and additional benefits in more complex scenarios).
36 | """
37 |
38 | @classmethod
39 | def from_lines(cls, lines: Iterable[str], **kwargs):
40 | return cls(lambda: lines, **kwargs)
41 |
42 | @classmethod
43 | def from_file(cls, path: str, **kwargs):
44 | def r():
45 | with open(path) as f:
46 | for line in f:
47 | yield line.strip()
48 |
49 | return cls(r, **kwargs)
50 |
51 | def __init__(
52 | self,
53 | iterator_generator: Callable[[], Iterable],
54 | tokenizer: PreTrainedTokenizer,
55 | for_inference: bool = False,
56 | max_tokens_per_batch: int = 1024,
57 | min_length: int = -1,
58 | max_length: int = -1,
59 | truncate: bool = False,
60 | section_size: int = 50_000,
61 | drop_last_batch: bool = False,
62 | prebatch: bool = False,
63 | ):
64 | self.iterator_generator = iterator_generator
65 | self.tokenizer = tokenizer
66 | self.for_inference = for_inference
67 | self.max_tokens_per_batch = max_tokens_per_batch
68 | self.min_length = min_length
69 | self.max_length = max_length
70 | self.truncate = truncate
71 | self.section_size = section_size if prebatch else 1
72 | self.drop_last_batch = drop_last_batch
73 | self.prebatch = prebatch
74 |
75 | def _generate_dataset(self):
76 | def prebatch_ds(ds: List[Tuple[List[int], List[int], str, Optional[str]]]):
77 | ds = sorted(
78 | ds, key=lambda x: add_noise_to_value(len(x[0]) + len(x[1]), noise_param=0.1)
79 | )
80 | ds = list(chunks(ds, 512))
81 | random.shuffle(ds)
82 | return flatten(ds)
83 |
84 | logger.info("Initting dataset")
85 |
86 | discarded_due_to_min_length = 0
87 | discarded_due_to_max_length = 0
88 |
89 | read_samples = 0
90 | ds = []
91 |
92 | for line in self.iterator_generator():
93 |
94 | if read_samples % 10_000 == 0:
95 | logger.info(f"{read_samples} entries added to dataset")
96 |
97 | line = line.strip()
98 | parts = line.split("\t")
99 | if self.for_inference:
100 | source = parts[0]
101 | target = parts[1] if len(parts) == 2 else None
102 | else:
103 | source = parts[0]
104 | target = parts[1]
105 |
106 | # encode
107 | text_source, text_target = source, target
108 | if not self.for_inference:
109 | sample = self.tokenizer.prepare_seq2seq_batch([source], tgt_texts=[target])
110 | source = sample["input_ids"][0]
111 | target = sample["labels"][0]
112 | else:
113 | sample = self.tokenizer.prepare_seq2seq_batch([source])
114 | source = sample["input_ids"][0]
115 | target = [self.tokenizer.bos_token_id]
116 |
117 | # truncate if requested
118 | if self.truncate:
119 | source = source[: self.max_length]
120 | target = target[: self.max_length]
121 |
122 | # check min length
123 | if self.min_length != -1 and (
124 | len(source) < self.min_length
125 | or (len(target) < self.min_length and not self.for_inference)
126 | ):
127 | discarded_due_to_min_length += 1
128 | if discarded_due_to_min_length % 1_000 == 0:
129 | logger.warning(
130 | f"{discarded_due_to_min_length} samples have been discarded due to being shorter than minimum length {self.min_length}"
131 | )
132 | continue
133 |
134 | # check max length
135 | if self.max_length != -1 and (
136 | len(source) > self.max_length or len(target) > self.max_length
137 | ):
138 | discarded_due_to_max_length += 1
139 | if discarded_due_to_max_length % 1_000 == 0:
140 | logger.warning(
141 | f"{discarded_due_to_max_length} samples have been discarded due to being longer than maximum length {self.max_length}"
142 | )
143 | continue
144 |
145 | ds.append((source, target, text_source, text_target))
146 | if len(ds) == self.section_size:
147 | if self.prebatch:
148 | ds = prebatch_ds(ds)
149 | yield from ds
150 | ds = []
151 |
152 | read_samples += 1
153 |
154 | if len(ds) > 0:
155 | if self.prebatch:
156 | ds = prebatch_ds(ds)
157 | yield from ds
158 |
159 | if discarded_due_to_min_length > 0:
160 | logger.warning(
161 | f"{discarded_due_to_min_length} samples have been discarded due to being shorter than minimum length {self.min_length}"
162 | )
163 |
164 | if discarded_due_to_max_length > 0:
165 | logger.warning(
166 | f"{discarded_due_to_max_length} samples have been discarded due to being longer than maximum length {self.max_length}"
167 | )
168 |
169 | def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
170 |
171 | dataset = self._generate_dataset()
172 |
173 | batch = []
174 | ct = 0
175 |
176 | for sample in dataset:
177 |
178 | sample_tokens = len(sample[0]) + len(sample[1])
179 |
180 | if (
181 | max(ct, sample_tokens) * (len(batch) + 1) > self.max_tokens_per_batch
182 | and len(batch) > 0
183 | ):
184 | yield self.prepare_output_batch(batch)
185 | batch = []
186 | ct = 0
187 |
188 | batch.append(sample)
189 | ct = max(ct, sample_tokens)
190 |
191 |         # the last batch might be too short and cause issues (NaN if we are using amp), hence the drop_last_batch option
192 | if not self.drop_last_batch and len(batch) > 0:
193 | yield self.prepare_output_batch(batch)
194 |
195 | def prepare_output_batch(
196 | self, batch: List[Tuple[List[int], List[int], str, Optional[str]]]
197 | ) -> Dict[str, torch.Tensor]:
198 |
199 | try:
200 | pad_token_id = self.tokenizer.pad_token_id
201 | except:
202 | pad_token_id = 0
203 |
204 | # build source
205 | source = pad_sequence(
206 | [torch.tensor(e[0]) for e in batch], batch_first=True, padding_value=pad_token_id
207 | )
208 | source_padding_mask = pad_sequence(
209 | [torch.full((len(e[0]),), fill_value=True) for e in batch],
210 | batch_first=True,
211 | padding_value=False,
212 | )
213 |
214 | # build target
215 | target = pad_sequence(
216 | [torch.tensor(e[1]) for e in batch], batch_first=True, padding_value=pad_token_id
217 | )
218 | target_padding_mask = pad_sequence(
219 | [torch.full((len(e[1]),), fill_value=True) for e in batch],
220 | batch_first=True,
221 | padding_value=False,
222 | )
223 |
224 | # return
225 | return {
226 | "source": source,
227 | "source_padding_mask": source_padding_mask,
228 | "target": target,
229 | "target_padding_mask": target_padding_mask,
230 | "text_source": [e[2] for e in batch],
231 | "text_target": [e[3] for e in batch],
232 | }
233 |
--------------------------------------------------------------------------------
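
A usage sketch of the token-budget batching, assuming a BART tokenizer is available; it mirrors how `translate.py` builds its inference dataset, but with tab-separated source/target lines. Each yielded item is already a padded batch, which is why the `DataLoader` uses `batch_size=None`.

```python
# Sketch, assuming network access to download the facebook/bart-large tokenizer.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from src.data.datasets.generation import ParallelDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
lines = ["a source sentence\tits target sentence", "another source\tanother target"]
dataset = ParallelDataset.from_lines(lines, tokenizer=tokenizer, max_tokens_per_batch=256)
for batch in DataLoader(dataset, batch_size=None):
    print(batch["source"].shape, batch["target"].shape, batch["text_source"])
```
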
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Grid - Seq2Seq
3 |
4 |
5 |