├── src
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── ittb_dataset.py
│   │   ├── flores_dataset.py
│   │   ├── wmt_dataset.py
│   │   └── iwslt_dataset.py
│   ├── ipi
│   │   ├── __init__.py
│   │   ├── decoders
│   │   │   ├── __init__.py
│   │   │   ├── beam_search.py
│   │   │   ├── autoregressive.py
│   │   │   ├── mt_decoding.py
│   │   │   ├── gs_jacobi.py
│   │   │   ├── jacobi.py
│   │   │   └── hybrid_jacobi.py
│   │   ├── stopping_condition.py
│   │   └── initializer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── bench_scorer.py
│   │   ├── bleu_calculator.py
│   │   ├── utils.py
│   │   └── beam_search.py
│   ├── viz
│   │   ├── __init__.py
│   │   ├── dependecy_graph.py
│   │   └── visualize.py
│   ├── __init__.py
│   └── bench.py
├── assets
│   ├── ddg.png
│   └── ipi.png
├── requirements.txt
├── .gitignore
├── conf
│   └── config.yaml
├── exp
│   ├── tab1.sh
│   ├── tab2.sh
│   └── flops_calculator.py
├── README.MD
├── main.py
└── LICENSE
/src/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/ipi/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/viz/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/ipi/decoders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/ddg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/teelinsan/parallel-decoding/HEAD/assets/ddg.png
--------------------------------------------------------------------------------
/assets/ipi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/teelinsan/parallel-decoding/HEAD/assets/ipi.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch~=1.13.1
2 | transformers==4.19.2
3 | datasets==2.1.0
4 | matplotlib>=3.5.2
5 | pytorch-lightning==1.6.3
6 | nltk>=3.7
7 | sacrebleu==2.1.0
8 | tqdm>=4.64.0
9 | python-dotenv==0.20.0
10 | gitpython==3.1.27
11 | more-itertools==8.12.0
12 | hydra-core==1.2.0
13 | plotly>=5.8.2
14 | #sentencepiece==0.1.96
15 | sacremoses==0.0.49
16 | sentence-transformers
17 | kaleido==0.2.1
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | fertility/
2 | /benchmark/.ipynb_checkpoints/
3 | /venv/
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | bin/
13 | build/
14 | develop-eggs/
15 | dist/
16 | eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 |
26 | benchmark/ipi/__pycache__/
27 |
28 | # Installer logs
29 | pip-log.txt
30 | pip-delete-this-directory.txt
31 |
32 | # Unit test / coverage reports
33 | .tox/
34 | .coverage
35 | .cache
36 | nosetests.xml
37 | coverage.xml
38 |
39 | # Translations
40 | *.mo
41 |
42 | # Mr Developer
43 | .mr.developer.cfg
44 | .project
45 | .pydevproject
46 |
47 | # Rope
48 | .ropeproject
49 |
50 | # Django stuff:
51 | *.log
52 | *.pot
53 |
54 | # Sphinx documentation
55 | docs/_build/
56 | .installed.cfg
57 | bin
58 | develop-eggs
59 | dist
60 | downloads
61 | eggs
62 | parts
63 | src/*.egg-info
64 | lib
65 | lib64
66 |
67 | *.pyc
68 | *.pyo
69 |
70 |
--------------------------------------------------------------------------------
/conf/config.yaml:
--------------------------------------------------------------------------------
1 | device: "cpu"
2 | src_lang: "en"
3 | tgt_lang: "de"
4 | task: "benchmark"
5 |
6 | model:
7 | src_lang: ${src_lang}
8 | tgt_lang: ${tgt_lang}
9 | model_name: "Helsinki-NLP/opus-mt-${src_lang}-${tgt_lang}"
10 | device: ${device}
11 | use_logits_preprocessor: True
12 |
13 | dataset:
14 | data_dir: ""
15 | src_lang: ${src_lang}
16 | tgt_lang: ${tgt_lang}
17 | name: "wmt"
18 | version: "14"
19 | split: 'test'
20 | subset:
21 | use_subset: False
22 | start: 10
23 | end: 10
24 |
25 | bench:
26 | result_dir: ""
27 | n_runs: 1
28 | device: ${device}
29 |
30 | beam_search:
31 | result_dir: ""
32 | batch_size: 1
33 | num_beams: 15
34 | device: ${device}
35 |
36 | initializer:
37 | use_initializer: True
38 | path_lexicon: null
39 | use_init: False
40 | src_lang: ${src_lang}
41 | tgt_lang: ${tgt_lang}
42 | device: ${device}
43 |
44 | decoder:
45 | gs_jaco_blocks: 3
46 | use_cache: True
47 | # Specify the decoders to use
48 | decoders: ['autoregressive', 'jacobi', 'gs_jacobi', 'h_jacobi', 'beam_search']
49 | device: ${device}
50 |
51 |
52 | sample:
53 | path: ""
--------------------------------------------------------------------------------
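
The config above is a Hydra configuration: `${src_lang}`, `${tgt_lang}` and `${device}` are interpolations resolved at load time, and every value can be overridden from the command line, as the scripts in `exp/` and the README commands do. A minimal sketch of how such a config is typically consumed, assuming `main.py` follows the standard Hydra entry-point pattern (the function name `run` is illustrative, not the repository's actual code):

```python
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="conf", config_name="config", version_base=None)
def run(cfg: DictConfig) -> None:
    # Interpolations such as model.model_name = "Helsinki-NLP/opus-mt-${src_lang}-${tgt_lang}"
    # are resolved here, e.g. to "Helsinki-NLP/opus-mt-en-de" with the defaults above.
    print(OmegaConf.to_yaml(cfg, resolve=True))


if __name__ == "__main__":
    run()
```

Overrides use dotted paths exactly as in the README and the `exp/` scripts, e.g. `python3 main.py src_lang="en" tgt_lang="ro" dataset.name="wmt" dataset.version="16"`.
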
/src/ipi/decoders/beam_search.py:
--------------------------------------------------------------------------------
1 | from src.ipi.decoders.mt_decoding import MTDecoder
2 |
3 |
4 | class BeamSearchDecoder(MTDecoder):
5 | def __init__(self, tokenizer, model, initializer, num_beams, early_stopping, **kwargs):
6 | super().__init__(tokenizer, model, initializer, **kwargs)
7 |
8 | self.name = "beam_search"
9 | self.acronym = "b"
10 |
11 | self.num_beams = num_beams
12 | self.early_stopping = early_stopping
13 |
14 | def decode(self, input_ids, attention_mask, *args, **kwargs):
15 | if self.is_mbart:
16 | with self.tokenizer.as_target_tokenizer():
17 | try:
18 | lang_id = self.tokenizer.cur_lang_code_id
19 | except:
20 | lang_id = self.tokenizer.cur_lang_id
21 | beam_output = self.model.generate(
22 | **{"input_ids": input_ids, "attention_mask": attention_mask},
23 | num_beams=self.num_beams,
24 | early_stopping=self.early_stopping,
25 | forced_bos_token_id=lang_id,
26 | )
27 | else:
28 | beam_output = self.model.generate(
29 | **{"input_ids": input_ids, "attention_mask": attention_mask},
30 | num_beams=self.num_beams,
31 | early_stopping=self.early_stopping,
32 | )
33 |
34 | return beam_output, 0
35 |
36 | def compute_decode_kwargs(self, *args, **kwargs):
37 | return {}
--------------------------------------------------------------------------------
/src/ipi/stopping_condition.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def limit_past_key_values(past_key_values, limit):
5 | new_list = []
6 | for elem in past_key_values:
7 | new_elem = list(elem)
8 | new_elem[0] = elem[0][:, :, :limit, :]
9 | new_elem[1] = elem[1][:, :, :limit, :]
10 | new_list.append(tuple(new_elem))
11 | return tuple(new_list)
12 |
13 |
14 | def stopping_criterion(past_tensor, current_tensor, eos=None):
15 | assert past_tensor.shape == current_tensor.shape
16 | if torch.equal(past_tensor, current_tensor):
17 | tensor = current_tensor
18 | if eos is not None:
19 | if eos in current_tensor[0]:
20 | pos = (current_tensor[0] == eos).nonzero(as_tuple=True)[0]
21 | if pos.shape[0] > 1:
22 | pos = pos[0].item()
23 | else:
24 | pos = pos.item()
25 | return True, tensor, pos
26 | else:
27 | return True, tensor, -1
28 | return True, tensor
29 | else:
30 | if eos is not None:
31 | return False, current_tensor, False
32 | else:
33 | return False, current_tensor
34 |
35 |
36 | def check_stop_cond(tensor, eos):
37 | if eos in tensor[0]:
38 | pos = (tensor[0] == eos).nonzero(as_tuple=True)[0]
39 | if pos.shape[0] > 1:
40 | pos = pos[0].item()
41 | else:
42 | pos = pos.item()
43 | return pos
44 | else:
45 | return -1
46 |
--------------------------------------------------------------------------------
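
For reference, a toy illustration of the fixed-point stopping rule above (the token values and the EOS id 2 are made up for the example):

```python
import torch

from src.ipi.stopping_condition import stopping_criterion

prev = torch.tensor([[65000, 17, 52, 2, 65000]])  # iterate at step k-1
curr = torch.tensor([[65000, 17, 52, 2, 65000]])  # iterate at step k, unchanged

# Equal consecutive iterates mean the parallel decoding has reached a fixed point.
stop, tensor, eos_pos = stopping_criterion(prev, curr, eos=2)
assert stop is True and eos_pos == 3  # index of the first (assumed) EOS token
```
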
/src/utils/bench_scorer.py:
--------------------------------------------------------------------------------
1 | from src.utils.bleu_calculator import BleuEvaluator, BleuValues
2 |
3 |
4 | class Scorer(object):
5 | def __init__(self, name, acronym):
6 | self.name = name
7 | self.acronym = acronym
8 |
9 | self.bleu_scorer = BleuEvaluator()
10 |
11 | # number of sentences
12 | self.i = 0
13 |
14 | # param bleu score
15 | self.predictions = []
16 | self.references = []
17 |
18 | # benchmark values
19 | self.tot_mean_time = 0
20 | self.tot_mean_iter = 0
21 |
22 | # inline values
23 | self.current_init = None
24 | self.current_transl = None
25 | self.current_time = None
26 | self.current_iter = None
27 |
28 | def update_metrics(self, time, iter, translation, gold, init):
29 | self.tot_mean_time += (time - self.tot_mean_time) / (self.i + 1)
30 | self.tot_mean_iter += (iter - self.tot_mean_iter) / (self.i + 1)
31 |
32 | self.predictions.append(translation)
33 | self.references.append([gold])
34 |
35 | self.current_init = init
36 | self.current_transl = translation
37 | self.current_time = time
38 | self.current_iter = iter
39 |
40 | self.i += 1
41 |
42 | def compute_bleu_score(self):
43 | bleu_score = self.bleu_scorer.final_score(
44 | model_predictions=self.predictions,
45 | gold_references=self.references
46 | )
47 |
48 | bleu_dict = {
49 | 'score': bleu_score['score'],
50 | 'counts': str(bleu_score['counts']),
51 | 'totals': str(bleu_score['totals']),
52 | 'precisions': str(['%.2f' % prec for prec in bleu_score['precisions']]),
53 | 'bp': bleu_score['bp'],
54 | 'sys_len': bleu_score['sys_len'],
55 | 'ref_len': bleu_score['ref_len']
56 | }
57 |
58 | return BleuValues(**bleu_dict)
--------------------------------------------------------------------------------
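
A small usage sketch of the scorer above (the numbers are made up, and `compute_bleu_score` needs the sacrebleu metric available to `datasets`). Note that `update_metrics` maintains running means incrementally, i.e. `mean += (x - mean) / n`, so only the prediction/reference lists for BLEU are kept per sentence:

```python
from src.utils.bench_scorer import Scorer

scorer = Scorer(name="jacobi", acronym="j")
scorer.update_metrics(time=0.50, iter=12, translation="ein Haus", gold="ein Haus", init="")
scorer.update_metrics(time=0.30, iter=10, translation="zwei Hunde", gold="zwei Hunde", init="")

assert abs(scorer.tot_mean_time - 0.40) < 1e-9  # running mean of the two timings
assert abs(scorer.tot_mean_iter - 11.0) < 1e-9  # running mean of the iteration counts
bleu = scorer.compute_bleu_score()              # corpus-level sacreBLEU over all sentences
```
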
/src/dataset/ittb_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datasets import load_dataset
3 | from torch.utils.data.dataset import Dataset
4 |
5 | from src.utils.utils import clean_text
6 |
7 |
8 | class Ittb(Dataset):
9 | def __init__(
10 | self,
11 | src_lan,
12 | tgt_lan,
13 | hugginface_tokenizer=None,
14 | split: str = None,
15 | ):
16 | self.src_lan = src_lan
17 | self.tgt_lan = tgt_lan
18 | self.name = "ittb"
19 | self.max_length = 511
20 |
21 | assert (
22 | src_lan == "en" or src_lan == "hi"
23 | ), "Ittb: src_lan must be either en or hi"
24 | assert (
25 | tgt_lan == "en" or tgt_lan == "hi"
26 | ), "Ittb: tgt_lan must be either en or hi"
27 | assert src_lan != tgt_lan, "Ittb: src_lan and tgt_lan cannot be the same"
28 |
29 | self.translation_dataset = load_dataset("cfilt/iitb-english-hindi", split=split)
30 |
31 | with torch.no_grad():
32 | self.tokenizer = hugginface_tokenizer
33 |
34 | def collate_fn(self, batch):
35 |
36 | batch_source = [b[0] for b in batch]
37 | batch_target = [b[1] for b in batch]
38 |
39 | encoded_source = self.tokenizer(
40 | batch_source,
41 | padding=True,
42 | return_tensors="pt",
43 | )
44 | encoded_target = self.tokenizer(
45 | batch_target,
46 | padding=True,
47 | return_tensors="pt",
48 | )
49 |
50 | return {
51 | "source": {
52 | "input_ids": encoded_source["input_ids"],
53 | "attention_mask": encoded_source["attention_mask"],
54 | "sentences": batch_source,
55 | },
56 | "target": {
57 | "input_ids": encoded_target["input_ids"],
58 | "attention_mask": encoded_target["attention_mask"],
59 | "sentences": batch_target,
60 | },
61 | }
62 |
63 | def __len__(self):
64 | return len(self.translation_dataset)
65 |
66 | def __getitem__(self, idx: int):
67 | source = self.translation_dataset["translation"][idx][self.src_lan]
68 | target = self.translation_dataset["translation"][idx][self.tgt_lan]
69 |
70 | return source, target
71 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | import dotenv
7 | import git
8 |
9 | pylogger = logging.getLogger(__name__)
10 |
11 |
12 | def get_env(env_name: str, default: Optional[str] = None) -> str:
13 | """Safely read an environment variable.
14 | Raises errors if it is not defined or it is empty.
15 | :param env_name: the name of the environment variable
16 | :param default: the default (optional) value for the environment variable
17 | :return: the value of the environment variable
18 | """
19 | if env_name not in os.environ:
20 | if default is None:
21 | message = f"{env_name} not defined and no default value is present!"
22 | pylogger.error(message)
23 | raise KeyError(message)
24 | return default
25 |
26 | env_value: str = os.environ[env_name]
27 | if not env_value:
28 | if default is None:
29 | message = (
30 | f"{env_name} has yet to be configured and no default value is present!"
31 | )
32 | pylogger.error(message)
33 | raise ValueError(message)
34 | return default
35 |
36 | return env_value
37 |
38 |
39 | def load_envs(env_file: Optional[str] = None) -> None:
40 | """Load all the environment variables defined in the `env_file`.
41 | This is equivalent to `. env_file` in bash.
42 | It is possible to define all the system specific variables in the `env_file`.
43 | :param env_file: the file that defines the environment variables to use. If None
44 | it searches for a `.env` file in the project.
45 | """
46 | dotenv.load_dotenv(dotenv_path=env_file, override=True)
47 |
48 |
49 | # Load environment variables
50 | load_envs()
51 |
52 |
53 | if "PROJECT_ROOT" not in os.environ:
54 | try:
55 | PROJECT_ROOT = Path(
56 | git.Repo(Path.cwd(), search_parent_directories=True).working_dir
57 | )
58 | except git.exc.InvalidGitRepositoryError:
59 | PROJECT_ROOT = Path.cwd()
60 |
61 | pylogger.debug(f"Inferred project root: {PROJECT_ROOT}")
62 | os.environ["PROJECT_ROOT"] = str(PROJECT_ROOT)
63 | else:
64 | PROJECT_ROOT: Path = Path(os.environ["PROJECT_ROOT"])
65 |
66 | __all__ = ["__version__", "PROJECT_ROOT"]
67 |
--------------------------------------------------------------------------------
/src/dataset/flores_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import datasets
4 |
5 | from src.utils.utils import retrieve_map_languages_flores, clean_text
6 | import typing as t
7 |
8 |
9 | class Flores(Dataset):
10 | def __init__(
11 | self,
12 | src_lan: str = "ro",
13 | tgt_lan: str = "en",
14 | hugginface_tokenizer=None,
15 | split: str = None,
16 | ):
17 | self.name = "flores"
18 | self.max_length = 511
19 | self.src_lan = retrieve_map_languages_flores(src_lan).lower()[:3]
20 | self.tgt_lan = retrieve_map_languages_flores(tgt_lan).lower()[:3]
21 |
22 | if "test" in split:
23 | split = "dev" + split
24 |
25 | self.translation_dataset_src = datasets.load_dataset(
26 | "gsarti/flores_101", self.src_lan, split=split
27 | )
28 | self.translation_dataset_tgt = datasets.load_dataset(
29 | "gsarti/flores_101", self.tgt_lan, split=split
30 | )
31 |
32 | with torch.no_grad():
33 | self.tokenizer = hugginface_tokenizer
34 |
35 | def collate_fn(self, batch):
36 |
37 | batch_source = [b[0] for b in batch]
38 | batch_target = [b[1] for b in batch]
39 |
40 | encoded_source = self.tokenizer(
41 | batch_source,
42 | padding=True,
43 | return_tensors="pt",
44 | )
45 | encoded_target = self.tokenizer(
46 | batch_target,
47 | padding=True,
48 | return_tensors="pt",
49 | )
50 |
51 | return {
52 | "source": {
53 | "input_ids": encoded_source["input_ids"],
54 | "attention_mask": encoded_source["attention_mask"],
55 | "sentences": batch_source,
56 | },
57 | "target": {
58 | "input_ids": encoded_target["input_ids"],
59 | "attention_mask": encoded_target["attention_mask"],
60 | "sentences": batch_target,
61 | },
62 | }
63 |
64 | def __len__(self) -> int:
65 | return self.translation_dataset_src.num_rows
66 |
67 | def __getitem__(self, idx: int) -> t.Tuple[str, str]:
68 | source = str(self.translation_dataset_src.data.column(6)[idx])
69 | target = str(self.translation_dataset_tgt.data.column(6)[idx])
70 |
71 | return source, target
72 |
--------------------------------------------------------------------------------
/exp/tab1.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | python3 ../main.py src_lang="en" tgt_lang="de" device="cpu" dataset.name="wmt" dataset.version="14" model.model_name="Helsinki-NLP/opus-mt-en-de"
4 | python3 ../main.py src_lang="de" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="14" model.model_name="Helsinki-NLP/opus-mt-en-de"
5 | python3 ../main.py src_lang="en" tgt_lang="de" device="cpu" dataset.name="wmt" dataset.version="14" model.model_name="teelinsan/opus-mt-eng-deu"
6 | python3 ../main.py src_lang="de" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="14" model.model_name="teelinsan/opus-mt-eng-deu"
7 | python3 ../main.py src_lang="ro" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="16" model.model_name="Helsinki-NLP/opus-mt-roa-en"
8 | python3 ../main.py src_lang="en" tgt_lang="ro" device="cpu" dataset.name="wmt" dataset.version="16" model.model_name="Helsinki-NLP/opus-mt-roa-en"
9 |
10 | python3 ../main.py src_lang="en" tgt_lang="de" device="cpu" dataset.name="wmt" dataset.version="14" task="beam_search" model.model_name="teelinsan/opus-mt-eng-deu"
11 | python3 ../main.py src_lang="de" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="14" task="beam_search"
12 | python3 ../main.py src_lang="en" tgt_lang="ro" device="cpu" dataset.name="wmt" dataset.version="16" task="beam_search"
13 | python3 ../main.py src_lang="ro" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="16" task="beam_search" model.model_name="Helsinki-NLP/opus-mt-roa-en"
14 |
15 |
16 | python3 ../main.py src_lang="en" tgt_lang="de" device="cuda" dataset.name="wmt" dataset.version="14" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
17 | python3 ../main.py src_lang="de" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="14" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
18 | python3 ../main.py src_lang="ro" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="16" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
19 | python3 ../main.py src_lang="en" tgt_lang="ro" device="cuda" dataset.name="wmt" dataset.version="16" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
20 |
21 | python3 ../main.py src_lang="en" tgt_lang="de" device="cuda" dataset.name="wmt" dataset.version="14" task="beam_search" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
22 | python3 ../main.py src_lang="de" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="14" task="beam_search" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
23 | python3 ../main.py src_lang="en" tgt_lang="ro" device="cuda" dataset.name="wmt" dataset.version="16" task="beam_search" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
24 | python3 ../main.py src_lang="ro" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="16" task="beam_search" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
25 |
--------------------------------------------------------------------------------
/src/ipi/initializer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from nltk.tokenize.toktok import ToktokTokenizer
3 | from nltk.tokenize.treebank import TreebankWordDetokenizer
4 | from transformers import (
5 | M2M100Tokenizer,
6 | MarianTokenizer,
7 | MBart50Tokenizer,
8 | MBartTokenizer,
9 | )
10 |
11 |
12 | class Initializer(object):
13 | def __init__(
14 | self,
15 | src_len,
16 | tgt_len,
17 | hugginface_tokenizer,
18 | use_init=True,
19 | device="cpu",
20 | ):
21 |
22 | self.src_len = src_len
23 | self.tgt_len = tgt_len
24 | self.tokenizer = hugginface_tokenizer
25 |
26 | self.pad_token_id = self.tokenizer.pad_token_id
27 | self.tokenizer_nltk = ToktokTokenizer()
28 | self.detokenizer_nltk = TreebankWordDetokenizer()
29 | self.use_init = use_init
30 | self.device = device
31 |
32 | def init_translation(self, tgt_len=None):
33 | final_translation = ""
34 | with self.tokenizer.as_target_tokenizer():
35 | if isinstance(self.tokenizer, MBartTokenizer):
36 | tgt_tensor = self.tokenizer(
37 | final_translation, return_tensors="pt", padding=True
38 | ).data["input_ids"]
39 | if tgt_tensor.shape[-1] == 2:
40 | tgt_tensor = tgt_tensor[:, :1]
41 | elif isinstance(self.tokenizer, MarianTokenizer):
42 | bos = torch.tensor([self.pad_token_id]).unsqueeze(0)
43 | tgt_tensor = bos
44 | elif isinstance(self.tokenizer, MBart50Tokenizer) or isinstance(
45 | self.tokenizer, M2M100Tokenizer
46 | ):
47 | bos = torch.tensor([self.tokenizer.eos_token_id]).unsqueeze(0)
48 | tgt_tensor = self.tokenizer(
49 | final_translation, return_tensors="pt", padding=True
50 | ).data["input_ids"]
51 | tgt_tensor = torch.cat([bos, tgt_tensor], dim=-1)
52 | else:
53 | bos = torch.tensor([self.tokenizer.bos_token_id]).unsqueeze(0)
54 | tgt_tensor = self.tokenizer(
55 | final_translation, return_tensors="pt", padding=True
56 | ).data["input_ids"]
57 | tgt_tensor = torch.cat([bos, tgt_tensor], dim=-1)
58 | if tgt_len is not None:
59 | tgt_tensor = self.trim_length(tgt_tensor, tgt_len)
60 | return tgt_tensor.to(self.device), final_translation
61 |
62 | def trim_length(self, tgt_tensor, tgt_len):
63 | last_elem = tgt_tensor[:, -1].unsqueeze(0)
64 | if tgt_tensor.shape[-1] > tgt_len:
65 | return torch.cat([tgt_tensor[..., : tgt_len - 1], last_elem], dim=-1)
66 | elif tgt_tensor.shape[-1] < tgt_len:
67 | delta = tgt_len - tgt_tensor.shape[-1] - 1
68 | init_tensor = torch.tensor(
69 | [self.pad_token_id] * delta, dtype=tgt_tensor.dtype
70 | ).unsqueeze(0)
71 | return_tensor = torch.cat([tgt_tensor, init_tensor, last_elem], dim=-1)
72 | return return_tensor
73 | else:
74 | return tgt_tensor
75 |
--------------------------------------------------------------------------------
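
As a usage sketch: with a MarianMT tokenizer the initializer simply produces a pad-filled initial guess of the requested length. The checkpoint below is one of those used in the experiments; downloading it requires network access and the sentencepiece dependency:

```python
from transformers import MarianTokenizer

from src.ipi.initializer import Initializer

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
initializer = Initializer(src_len=None, tgt_len=None, hugginface_tokenizer=tokenizer, device="cpu")

init_tensor, _ = initializer.init_translation(tgt_len=8)
# For MarianTokenizer the decoder start token is <pad>, so this is a (1, 8) tensor of
# pad_token_id used as the starting point of the parallel fixed-point iteration.
```
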
/exp/tab2.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | python3 ../main.py src_lang="en" tgt_lang="fi" device="cuda" dataset.name="wmt" dataset.version="17" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
4 | python3 ../main.py src_lang="fi" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="17" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
5 | python3 ../main.py src_lang="en" tgt_lang="hi" device="cuda" dataset.name="ittb" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
6 | python3 ../main.py src_lang="hi" tgt_lang="en" device="cuda" dataset.name="ittb" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
7 | python3 ../main.py src_lang="en" tgt_lang="vi" device="cuda" dataset.name="iwslt" dataset.version="15" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
8 | python3 ../main.py src_lang="vi" tgt_lang="en" device="cuda" dataset.name="iwslt" dataset.version="15" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
9 | python3 ../main.py src_lang="en" tgt_lang="it" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
10 | python3 ../main.py src_lang="it" tgt_lang="en" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
11 | python3 ../main.py src_lang="en" tgt_lang="fr" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
12 | python3 ../main.py src_lang="fr" tgt_lang="en" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt"
13 |
14 |
15 |
16 | python3 ../main.py src_lang="en" tgt_lang="fi" device="cuda" dataset.name="wmt" dataset.version="17" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
17 | python3 ../main.py src_lang="fi" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="17" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
18 | python3 ../main.py src_lang="en" tgt_lang="hi" device="cuda" dataset.name="ittb" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
19 | python3 ../main.py src_lang="hi" tgt_lang="en" device="cuda" dataset.name="ittb" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
20 | python3 ../main.py src_lang="en" tgt_lang="vi" device="cuda" dataset.name="iwslt" dataset.version="15" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
21 | python3 ../main.py src_lang="vi" tgt_lang="en" device="cuda" dataset.name="iwslt" dataset.version="15" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
22 | python3 ../main.py src_lang="en" tgt_lang="it" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
23 | python3 ../main.py src_lang="it" tgt_lang="en" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
24 | python3 ../main.py src_lang="en" tgt_lang="fr" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
25 | python3 ../main.py src_lang="fr" tgt_lang="en" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search"
26 |
27 |
--------------------------------------------------------------------------------
/src/dataset/wmt_dataset.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | import datasets
4 | import torch
5 | from torch.utils.data import Dataset
6 |
7 | from src.utils.utils import clean_text
8 |
9 |
10 | class Wmt(Dataset):
11 | """
12 | Wmt machine translation dataset reader
13 |
14 | Input:
15 | - version -> the dataset version dataset, by default '16' (dataset-16)
16 | - src_lan -> the source language, by default 'ro' (Romanian)
17 | - tgt_lan -> the target language, by default 'en' (English)
18 | - tokenizer_model -> the tokenizer model
19 | - split -> if not None, allows to split the dataset in following set: ['train', 'test', 'validation']
20 | - concat -> if not None, make possible the concatenation of the specified set.
21 | Note: It works only if split is None
22 | It can be: ['train', 'test', 'validation']
23 | """
24 |
25 | def __init__(
26 | self,
27 | version: str = "16",
28 | src_lan: str = "ro",
29 | tgt_lan: str = "en",
30 | hugginface_tokenizer=None,
31 | split: str = None,
32 | ):
33 | self.src_lan = src_lan
34 | self.tgt_lan = tgt_lan
35 | self.tokenizer_model = hugginface_tokenizer
36 | self.max_length = 511
37 |
38 | if src_lan == "en":
39 | source2target = "{}-{}".format(self.tgt_lan, self.src_lan)
40 | else:
41 | source2target = "{}-{}".format(self.src_lan, self.tgt_lan)
42 |
43 | if version == "19" and "test" in split:
44 | split = "validation"
45 |
46 | version = f"wmt{version}"
47 |
48 | self.name = version
49 |
50 | try:
51 | self.translation_dataset = datasets.load_dataset(
52 | version, source2target, split=split
53 | )
54 | except:
55 | raise ValueError(
56 | f"{version} can read only the pairs cs-en, en-cs, de-en, en-de,"
57 | f" fi-en, en-fi, ro-en, en-ro, ru-en, en-ru, tr-en, en-tr"
58 | )
59 |
60 | with torch.no_grad():
61 | self.tokenizer = hugginface_tokenizer
62 |
63 | def collate_fn(self, batch):
64 |
65 | batch_source = [b[0] for b in batch]
66 | batch_target = [b[1] for b in batch]
67 |
68 | encoded_source = self.tokenizer(
69 | batch_source,
70 | padding=True,
71 | return_tensors="pt",
72 | )
73 | encoded_target = self.tokenizer(
74 | batch_target,
75 | padding=True,
76 | return_tensors="pt",
77 | )
78 |
79 | return {
80 | "source": {
81 | "input_ids": encoded_source["input_ids"],
82 | "attention_mask": encoded_source["attention_mask"],
83 | "sentences": batch_source,
84 | },
85 | "target": {
86 | "input_ids": encoded_target["input_ids"],
87 | "attention_mask": encoded_target["attention_mask"],
88 | "sentences": batch_target,
89 | },
90 | }
91 |
92 | def __len__(self) -> int:
93 | return len(self.translation_dataset)
94 |
95 | def __getitem__(self, idx: int) -> t.Tuple[str, str]:
96 | sample = self.translation_dataset[idx]
97 | source = sample["translation"][self.src_lan]
98 | target = sample["translation"][self.tgt_lan]
99 |
100 | return source, target
101 |
--------------------------------------------------------------------------------
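
A sketch of how one of these dataset readers is typically wired to a `DataLoader` through its `collate_fn` (the checkpoint name is one used in `exp/tab1.sh`; instantiating the dataset triggers the WMT14 download from HuggingFace):

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from src.dataset.wmt_dataset import Wmt

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
dataset = Wmt(version="14", src_lan="en", tgt_lan="de", hugginface_tokenizer=tokenizer, split="test")
loader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn)

batch = next(iter(loader))
# batch["source"]["input_ids"] / ["attention_mask"] feed the encoder;
# batch["target"]["sentences"] holds the reference translations used for BLEU.
```
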
/src/ipi/decoders/autoregressive.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.ipi.decoders.mt_decoding import MTDecoder
4 |
5 |
6 | class AutoregressiveDecoder(MTDecoder):
7 | def __init__(self, tokenizer, model, initializer, **kwargs):
8 | super().__init__(tokenizer, model, initializer, **kwargs)
9 |
10 | self.name = "autoregressive"
11 | self.acronym = "a"
12 |
13 | @torch.no_grad()
14 | def decode(self, input_ids, attention_mask, init_tensor=None, logits_preprocessor=None, *args, **kwargs):
15 |
16 | index = 0
17 |
18 | if init_tensor is None:
19 | init_tensor = torch.tensor(
20 | [self.pad_token_id], device=self.device
21 | ).unsqueeze(0)
22 | elif self.is_mbart:
23 | output = self.model(
24 | input_ids,
25 | attention_mask,
26 | decoder_input_ids=init_tensor[:, 0].unsqueeze(0),
27 | use_cache=True,
28 | )
29 | encoder_last_hidden_state = output.encoder_last_hidden_state
30 | past_key_values = output.past_key_values
31 | index += 1
32 | total_res = torch.tensor(
33 | [[init_tensor[:, 0], init_tensor[:, 1]]], device=self.device
34 | )
35 | init_tensor = init_tensor[:, 1].unsqueeze(0)
36 | else:
37 | init_tensor = init_tensor[:, 0].unsqueeze(0)
38 |
39 | total_res = init_tensor.clone()
40 | while True:
41 | if self.use_cache and index > 0:
42 | if index == 1024:
43 | print(total_res)
44 | output = self.model(
45 | None,
46 | attention_mask,
47 | decoder_input_ids=init_tensor,
48 | encoder_outputs=(encoder_last_hidden_state, None, None),
49 | use_cache=True,
50 | past_key_values=past_key_values,
51 | )
52 | else:
53 | output = self.model(
54 | input_ids,
55 | attention_mask,
56 | decoder_input_ids=init_tensor,
57 | use_cache=True,
58 | )
59 | encoder_last_hidden_state = output.encoder_last_hidden_state
60 | past_key_values = output.past_key_values
61 | logits = output.logits
62 | if logits_preprocessor is not None:
63 | logits = logits_preprocessor(total_res, logits[:,-1,:])
64 | else:
65 | logits = logits[:,-1,:]
66 | max_value = torch.argmax(logits, dim=-1)
67 | last = max_value
68 | init_tensor = last.unsqueeze(0)
69 | total_res = torch.cat((total_res, init_tensor), dim=1)
70 |
71 | index += 1
72 | if last[0].item() == self.eos_token_id or index == self.model.config.max_length - 1:
73 | break
74 | return total_res, index
75 |
76 | def initialize(self):
77 | if self.initializer is not None:
78 | init_tensor, _ = self.initializer.init_translation()
79 | else:
80 | init_tensor = None
81 |
82 | return init_tensor
83 |
84 | def compute_decode_kwargs(self, input_ids, *args, **kwargs):
85 | init_tensor = self.initialize()
86 | logits_preprocessor = self.generate_logits_preprocessor(input_ids)
87 |
88 | return {
89 | "init_tensor": init_tensor.clone(),
90 | "logits_preprocessor": logits_preprocessor
91 | }
92 |
93 |
--------------------------------------------------------------------------------
/src/utils/bleu_calculator.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import datasets
4 | import pandas as pd
5 | from tabulate import tabulate
6 |
7 |
8 | class BleuValues:
9 | def __init__(self, **entries):
10 | self.__dict__.update(entries)
11 |
12 |
13 | class BleuEvaluator(object):
14 | def __init__(self):
15 | self.metric = datasets.load_metric("sacrebleu")
16 |
17 | def add_element(self, model_predictions, gold_references):
18 | self.metric.add(predictions=model_predictions, references=gold_references)
19 |
20 | def add_batch(self, predictions, references):
21 | self.metric.add_batch(predictions=predictions, references=references)
22 |
23 | def final_score(self, model_predictions, gold_references):
24 | return self.metric.compute(predictions=model_predictions, references=gold_references)
25 |
26 |
27 | class BleuCalculator:
28 | def __init__(
29 | self,
30 | dataset,
31 | result_dir,
32 | ):
33 | self.dataset = dataset
34 | self.result_dir = result_dir
35 |
36 | @staticmethod
37 | def read_csv(path_csv):
38 | csv_reader = pd.read_csv(path_csv, sep="\t", header=0)
39 | return {
40 | k: v
41 | for k, v in zip(
42 | csv_reader["#sentence"].tolist(), csv_reader["times"].tolist()
43 | )
44 | }
45 |
46 | def _retrieve_files(self):
47 | file2data = dict()
48 | for root, dirs, files in os.walk(self.result_dir):
49 | if any(map(lambda x: "trans" in x, files)) and "initrans" not in root:
50 | trans_files_name = list(filter(lambda x: ("trans" in x), files))[0]
51 | data = self.read_csv(path_csv=os.path.join(root, trans_files_name))
52 | file2data.update({trans_files_name.split(".")[-2]: data})
53 | return file2data
54 |
55 | def _load_dataset(self):
56 | return {i: x[1] for i, x in enumerate(self.dataset)}
57 |
58 | @staticmethod
59 | def _match_indices(method, gold):
60 | new_gold = dict()
61 | for k in gold:
62 | if k in method:
63 | new_gold.update({k: gold[k]})
64 |
65 | return new_gold
66 |
67 | @staticmethod
68 | def _bleu_score_formatter(bleu_score):
69 |
70 | bleu_dict = {
71 | "score": bleu_score["score"],
72 | "counts": str(bleu_score["counts"]),
73 | "totals": str(bleu_score["totals"]),
74 | "precisions": str(["%.2f" % prec for prec in bleu_score["precisions"]]),
75 | "bp": bleu_score["bp"],
76 | "sys_len": bleu_score["sys_len"],
77 | "ref_len": bleu_score["ref_len"],
78 | }
79 |
80 | return BleuValues(**bleu_dict)
81 |
82 | def write_report(self, file2score):
83 | print("Writing report...")
84 |
85 | # Table for the Bleu score
86 | header = ["Metrics"] + [m[0] for m in file2score]
87 |
88 | bleu_table = tabulate(
89 | [
90 | ["Score"] + [b[1].score for b in file2score],
91 | ["Counts"] + [b[1].counts for b in file2score],
92 | ["Totals"] + [b[1].totals for b in file2score],
93 | ["Precisions"] + [b[1].precisions for b in file2score],
94 | ["Bp"] + [b[1].bp for b in file2score],
95 | ["Sys_len"] + [b[1].sys_len for b in file2score],
96 | ["ref_len"] + [b[1].ref_len for b in file2score],
97 | ],
98 | headers=header,
99 | tablefmt="rst",
100 | )
101 |
102 | with open(os.path.join(self.result_dir, "bleu_report.txt"), mode="w") as report:
103 | report.write(f"Bleu Score\n{bleu_table}\n\n")
104 |
105 | def _compute_bleu_score(self, name, translations, gold):
106 | scorer = BleuEvaluator()
107 |
108 | translations = list(translations.values())
109 | gold = list(gold.values())
110 |
111 | gold = [[g] for g in gold]
112 |
113 | # for t, g in zip(translations, gold):
114 | # scorer.add_element(t, [g])
115 |
116 | score_value = scorer.final_score(translations, gold)
117 | return name, self._bleu_score_formatter(score_value)
118 |
119 | def compute_bleu_score(self):
120 | file2data = self._retrieve_files()
121 | gold = self._load_dataset()
122 |
123 | if "trans_beam" in file2data:
124 | beam_search = self._match_indices(
125 | file2data["trans_gs_jacobi"], file2data["trans_beam"]
126 | )
127 | file2data["trans_beam"] = beam_search
128 | gold = self._match_indices(file2data["trans_gs_jacobi"], gold)
129 |
130 | file2score = [
131 | self._compute_bleu_score(file, file2data[file], gold) for file in file2data
132 | ]
133 |
134 | self.write_report(file2score)
135 |
--------------------------------------------------------------------------------
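
The evaluator wraps the `sacrebleu` metric from `datasets` (with the pinned `datasets==2.1.0`, `load_metric` is still available). A minimal corpus-level sketch, with a made-up sentence:

```python
from src.utils.bleu_calculator import BleuEvaluator

evaluator = BleuEvaluator()
score = evaluator.final_score(
    model_predictions=["the cat sat on the mat"],
    gold_references=[["the cat sat on the mat"]],  # one list of references per prediction
)
print(score["score"])  # 100.0 for an exact match
```
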
/src/ipi/decoders/mt_decoding.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | import torch
4 | from transformers import MBartForConditionalGeneration
5 |
6 | from src.ipi import stopping_condition as sc
7 | from src.utils.utils import get_logits_preprocessor
8 |
9 | PREC_GOLD_AUTOREGRESSIVE: Dict[str, Optional[torch.Tensor]] = {"input_ids": None, "gold": None}
10 | PREC_LOGISTS_PREPROCESSOR: Dict[str, Optional[torch.Tensor]] = {"input_ids": None, "logists": None}
11 |
12 | class MTDecoder:
13 | def __init__(
14 | self,
15 | tokenizer,
16 | model,
17 | initializer,
18 | use_cache: bool = True,
19 | use_logits_preprocessor: bool = True,
20 | device: str = "cuda",
21 | **kwargs
22 | ):
23 | self.tokenizer = tokenizer
24 |
25 | self.initializer = initializer
26 |
27 | self.pad_token_id = self.tokenizer.pad_token_id
28 | self.eos_token_id = self.tokenizer.eos_token_id
29 |
30 | self.use_cache = use_cache
31 | self.device = device
32 |
33 | with torch.no_grad():
34 | self.model = model
35 | self.model_name = self.model.name_or_path
36 | self.model.eval()
37 |
38 | self.max_length = min(self.tokenizer.model_max_length, 511)
39 |
40 | self.use_logits_preprocessor = use_logits_preprocessor
41 |
42 | self.is_mbart = isinstance(self.model, MBartForConditionalGeneration)
43 |
44 | def decode(self, input_ids, attention_mask, *args, **kwargs):
45 | pass
46 |
47 | def compute_decode_kwargs(self, *args, **kwargs):
48 | pass
49 |
50 | def initialize(self, **kwargs):
51 | pass
52 |
53 | def generate_gold_autoregressive(self, input_ids, attention_mask):
54 |
55 | global PREC_GOLD_AUTOREGRESSIVE
56 |
57 | if PREC_GOLD_AUTOREGRESSIVE['input_ids'] is None or not torch.equal(input_ids, PREC_GOLD_AUTOREGRESSIVE['input_ids']):
58 | if self.is_mbart:
59 | with self.tokenizer.as_target_tokenizer():
60 | try:
61 | lang_id = self.tokenizer.cur_lang_code_id
62 | except:
63 | lang_id = self.tokenizer.cur_lang_id
64 | gold_autoregressive = self.model.generate(
65 | **{"input_ids": input_ids, "attention_mask": attention_mask},
66 | num_beams=1,
67 | do_sample=False,
68 | use_cache=False,
69 | forced_bos_token_id=lang_id,
70 | )
71 | else:
72 | gold_autoregressive = self.model.generate(
73 | **{"input_ids": input_ids, "attention_mask": attention_mask},
74 | num_beams=1,
75 | do_sample=False,
76 | use_cache=False,
77 | )
78 | gold_autoregressive = gold_autoregressive[:, : self.max_length]
79 |
80 | PREC_GOLD_AUTOREGRESSIVE['input_ids'] = input_ids
81 | PREC_GOLD_AUTOREGRESSIVE['gold'] = gold_autoregressive
82 |
83 | return PREC_GOLD_AUTOREGRESSIVE['gold']
84 |
85 | def generate_logits_preprocessor(self, input_ids):
86 |
87 | global PREC_LOGISTS_PREPROCESSOR
88 |
89 | if self.use_logits_preprocessor:
90 | if PREC_LOGISTS_PREPROCESSOR['input_ids'] is None or not torch.equal(input_ids, PREC_LOGISTS_PREPROCESSOR['input_ids']):
91 | logits_preprocessor = get_logits_preprocessor(
92 | model=self.model,
93 | input_ids=input_ids,
94 | eos_token_id=self.eos_token_id
95 | )
96 |
97 | PREC_LOGISTS_PREPROCESSOR['input_ids'] = input_ids
98 | PREC_LOGISTS_PREPROCESSOR['logists'] = logits_preprocessor
99 | else:
100 | return None
101 |
102 | return PREC_LOGISTS_PREPROCESSOR['logists']
103 |
104 | @staticmethod
105 | def stopping_criterion(past_tensor, current_tensor, eos=None):
106 | return sc.stopping_criterion(past_tensor, current_tensor, eos)
107 |
108 | @staticmethod
109 | def limit_past_key_values(past_key_values, limit):
110 | return sc.limit_past_key_values(past_key_values, limit)
111 |
112 | @staticmethod
113 | def trig_eos(tensor, eos_token_id, init_tensor, base_value):
114 | if tensor[:, 0].item() == eos_token_id:
115 | return init_tensor[:, : base_value + 1]
116 | else:
117 | return None
118 |
119 |
120 | def generate_target(
121 | tokenizer,
122 | model,
123 | input_ids: torch.Tensor,
124 | attention_mask: torch.Tensor,
125 | is_mbart: bool,
126 | decoding_method: str = "greedy",
127 | remove_padding: bool = False,
128 | ):
129 | if decoding_method == "greedy":
130 | if is_mbart:
131 | with tokenizer.as_target_tokenizer():
132 | gold_output = model.generate(
133 | **{"input_ids": input_ids, "attention_mask": attention_mask},
134 | num_beams=1,
135 | do_sample=False,
136 | forced_bos_token_id=tokenizer.cur_lang_code_id,
137 | )
138 | else:
139 | gold_output = model.generate(
140 | **{"input_ids": input_ids, "attention_mask": attention_mask},
141 | num_beams=1,
142 | do_sample=False,
143 | )
144 | else:
145 | raise NotImplementedError()
146 |
147 | if remove_padding:
148 | sample_lengths = (gold_output != tokenizer.pad_token_id).sum(dim=1)
149 | gold_output = [
150 | sample[:length] for sample, length in zip(gold_output, sample_lengths)
151 | ]
152 |
153 | return gold_output
154 |
155 |
--------------------------------------------------------------------------------
/src/ipi/decoders/gs_jacobi.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.ipi.decoders.mt_decoding import MTDecoder
4 | from more_itertools import sliced
5 |
6 |
7 | class GSJacobiDecoder(MTDecoder):
8 | def __init__(self, tokenizer, model, initializer, gs_jaco_blocks, init_mode, **kwargs):
9 | super().__init__(tokenizer, model, initializer, **kwargs)
10 |
11 | self.name = "gs_jacobi"
12 | self.acronym = "g"
13 |
14 | self.gs_jaco_blocks = gs_jaco_blocks
15 | self.init_mode = init_mode
16 |
17 | @torch.no_grad()
18 | def decode(
19 | self, input_ids, attention_mask, target_len, gold_target, init_tensor=None, logits_preprocessor=None, *args, **kwargs
20 | ):
21 | key_cache = 1
22 | if init_tensor is None:
23 | init_tensor = torch.tensor(
24 | [self.pad_token_id] * target_len, device=self.device
25 | )
26 | blocks = list(sliced(init_tensor, self.gs_jaco_blocks))
27 | init_tensor = init_tensor.unsqueeze(0)
28 | total_past_key_values = None
29 | elif self.is_mbart:
30 | output = self.model(
31 | input_ids,
32 | attention_mask,
33 | decoder_input_ids=init_tensor[:, 0].unsqueeze(0),
34 | use_cache=True,
35 | )
36 | encoder_last_hidden_state = output.encoder_last_hidden_state
37 | total_past_key_values = output.past_key_values
38 | delta = 1
39 | # total_res = init_tensor.to(self.device)
40 | init_tensor = init_tensor[:, 1:]
41 | blocks = list(sliced(init_tensor.squeeze(0), self.gs_jaco_blocks))
42 | key_cache = 2
43 | else:
44 | init_tensor = init_tensor
45 | blocks = list(sliced(init_tensor.squeeze(0), self.gs_jaco_blocks))
46 | total_past_key_values = None
47 |
48 | iteration_saved = 0
49 | base_value = 0
50 |
51 | for blocco in blocks:
52 | max_len = blocco.shape[-1]
53 | blocco_usr = init_tensor[:, base_value : base_value + max_len]
54 |
55 | for index in range(max_len):
56 | old_blocco = blocco_usr.detach().clone()
57 | blocco_usr_new = blocco_usr[:, index:]
58 | if base_value == 0 and index == 0 and not self.is_mbart:
59 | output = self.model(
60 | input_ids,
61 | attention_mask,
62 | decoder_input_ids=blocco_usr_new,
63 | use_cache=True,
64 | past_key_values=total_past_key_values,
65 | )
66 | encoder_last_hidden_state = output.encoder_last_hidden_state
67 | else:
68 | output = self.model(
69 | input_ids,
70 | attention_mask,
71 | decoder_input_ids=blocco_usr_new,
72 | encoder_outputs=(encoder_last_hidden_state, None, None),
73 | use_cache=True,
74 | past_key_values=total_past_key_values,
75 | )
76 |
77 | total_past_key_values = self.limit_past_key_values(
78 | output.past_key_values,
79 | base_value + index + key_cache,
80 | )
81 | logits = output.logits
82 | max_value = torch.argmax(logits, dim=-1)
83 |
84 | if logits_preprocessor is not None:
85 | logits_new = logits_preprocessor(init_tensor[:, :base_value + index +1], logits[:, 0, :])
86 | max_value_new = torch.argmax(logits_new, dim=-1)
87 | max_value[:, 0] = max_value_new
88 |
89 |
90 | if (
91 | max_value.shape[-1]
92 | == init_tensor[
93 | :, base_value + index + 1 : base_value + max_len + 1
94 | ].shape[-1]
95 | ):
96 | init_tensor[
97 | :, base_value + index + 1 : base_value + max_len + 1
98 | ] = max_value[:, :]
99 | else:
100 | # If last block remove the last token after EOS
101 | init_tensor[
102 | :, base_value + index + 1 : base_value + max_len + 1
103 | ] = max_value[:, :-1]
104 |
105 | stop_condition, _ = self.stopping_criterion(old_blocco, blocco_usr)
106 |
107 | if stop_condition and index + 1 != max_len:
108 | total_past_key_values = self.limit_past_key_values(
109 | output.past_key_values,
110 | base_value + max_len + 1,
111 | )
112 | iteration_saved += max_len - index - 1
113 | break
114 | base_value += max_len
115 | return init_tensor, (gold_target.shape[-1] - 1) - iteration_saved
116 |
117 | def initialize(self, input_ids, gold_autoregressive):
118 | if self.initializer is not None:
119 | if self.init_mode == "overprov":
120 | m = int(input_ids.shape[-1] + 10 / 100 * input_ids.shape[-1])
121 | else:
122 | m = gold_autoregressive.shape[-1]
123 | init_tensor, _ = self.initializer.init_translation(m)
124 | else:
125 | init_tensor = None
126 |
127 | return init_tensor
128 |
129 | def compute_decode_kwargs(self, input_ids, attention_mask, **kwargs):
130 | gold_autoregressive = self.generate_gold_autoregressive(input_ids, attention_mask)
131 | init_tensor = self.initialize(input_ids, gold_autoregressive)
132 | logits_preprocessor = self.generate_logits_preprocessor(input_ids)
133 |
134 | return{
135 | "init_tensor": init_tensor.clone(),
136 | "gold_target": gold_autoregressive,
137 | "target_len": gold_autoregressive.shape[-1],
138 | "logits_preprocessor": logits_preprocessor
139 | }
140 |
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Accelerating Transformer Inference for Translation via Parallel Decoding
4 |
5 | [arXiv](https://arxiv.org/abs/2305.10427)
6 | [ACL Anthology](https://aclanthology.org/2023.acl-long.689/)
7 |
8 |
9 |
10 |
11 | This is the code repository of the paper "Accelerating Transformer Inference for Translation via Parallel Decoding" accepted at [ACL 2023 main conference](https://aclanthology.org/2023.acl-long.689/).
12 |
13 | The paper proposes three Parallel Decoding methods to speed up existing autoregressive machine translation models: **Jacobi Decoding**, **GS-Jacobi Decoding**, and **Hybrid GS-Jacobi Decoding**.
14 |
15 | This code is not production-ready and should be used for research purposes only.
16 |
17 | **Paper**: https://arxiv.org/abs/2305.10427
18 |
19 |
20 |

21 |
22 |
23 |
24 | ## Reproduce the results
25 | To reproduce the benchmark values for en-ro WMT16:
26 | 1. Install all the requirements with `pip install -r requirements.txt`
27 | 2. Run the following command to retrieve the benchmark values:
28 | ```
29 | python3 main.py src_lang="en" tgt_lang="ro" device="cpu" dataset.name="wmt" dataset.version="16" task="benchmark" bench.result_dir="[your path]" model.model_name="Helsinki-NLP/opus-mt-en-ro"
30 | ```
31 | Iteration speedups and BLEU results should be easy to reproduce. Time speedups depend on the underlying hardware and software being able to run the computation in parallel without introducing overhead. Please follow the experimental setting proposed in the paper; the easiest way is to use a fresh virtual machine, for which we provide instructions in the Scaling Experiments section of this README.
32 |
33 | ## Datasets
34 | All the datasets are available via `HuggingFace` datasets, so they will be downloaded automatically.
35 | However, `Iwslt` needs to be downloaded manually. In particular, you have to download `2015-01/texts` for `Iwslt15` and
36 | `2017-01-trnted/texts` for `Iwslt17`. Once downloaded, specify the path as a parameter of the Python command by adding `dataset.data_dir=[your path]` (you can also set it manually in `conf/config.yaml`).
37 |
38 | ## Table 1 and Table 5 Experiments - Parallel Decoding Algorithms
39 | To reproduce the results in Table 1, run:
40 | ```
41 | /bin/bash ./exp/tab1.sh
42 | ```
43 | Before running, modify the result_dir path in `tab1.sh` or in the config file `conf/config.yaml`.
44 |
45 | ## Table 2 and Table 6 Experiments - Cross Languages
46 | To reproduce the results in Table 2, run:
47 | ```
48 | /bin/bash ./exp/tab2.sh
49 | ```
50 | Before running, modify the result_dir path in `tab2.sh` or in the config file `conf/config.yaml`.
51 |
52 | ## Figure 3 and Figure 5 - Scaling Experiments
53 | To reproduce the scaling experiments you need to use Google Cloud machines of type `c2d-standard-XX`, where XX is the number of cores used. Then run the command given in the "Reproduce the results" section of this README.
54 | To ease the process, we provide the command to launch the virtual machine using the gcloud CLI.
55 |
56 | ```
57 | gcloud compute instances create instance-1 --zone=us-central1-a --machine-type=c2d-standard-8 --network-interface=network-tier=PREMIUM,subnet=default --maintenance-policy=MIGRATE --provisioning-model=STANDARD --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --create-disk=boot=yes,device-name=instance-1,image=projects/ubuntu-os-cloud/global/images/ubuntu-2004-focal-v20221015,mode=rw,size=30 --no-shielded-secure-boot --shielded-vtpm --shielded-integrity-monitoring --reservation-affinity=any
58 | ```
59 | ## Table 3 Experiments - FLOPs calculator
60 |
61 | To reproduce the FLOPs calculation in Table 3 simply run the script:
62 | ```
63 | python3 ./exp/flops_calculator.py
64 | ```
65 |
66 | ## Dependency Graph Visualizer (DDGviz)
67 |
68 | To run the Dependency Graph Visualizer (DDGviz) execute the command:
69 | ```
70 | PYTHONPATH=. python3 ./src/viz/visualize.py
71 | ```
72 | It is possible to select the examples to visualize with the param `--examples [list of ids in the dataset]`. The dataset and the source/target languages can be selected with the corresponding arguments; use `--help` for more info.
73 | The output DDGviz visualization will be saved in `iteration_matrix//images`.
74 |
75 |
76 | ![](assets/ddg.png)
77 |
78 |
79 |
80 | ## Citation
81 |
82 | If you use this code please cite:
83 |
84 | ```bibtex
85 | @inproceedings{santilli-etal-2023-accelerating,
86 | title = "Accelerating Transformer Inference for Translation via Parallel Decoding",
87 | author = "Santilli, Andrea and
88 | Severino, Silvio and
89 | Postolache, Emilian and
90 | Maiorca, Valentino and
91 | Mancusi, Michele and
92 | Marin, Riccardo and
93 | Rodola, Emanuele",
94 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
95 | month = jul,
96 | year = "2023",
97 | address = "Toronto, Canada",
98 | publisher = "Association for Computational Linguistics",
99 | url = "https://aclanthology.org/2023.acl-long.689",
100 | pages = "12336--12355",
101 | abstract = "Autoregressive decoding limits the efficiency of transformers for Machine Translation (MT). The community proposed specific network architectures and learning-based methods to solve this issue, which are expensive and require changes to the MT model, trading inference speed at the cost of the translation quality. In this paper, we propose to address the problem from the point of view of decoding algorithms, as a less explored but rather compelling direction. We propose to reframe the standard greedy autoregressive decoding of MT with a parallel formulation leveraging Jacobi and Gauss-Seidel fixed-point iteration methods for fast inference. This formulation allows to speed up existing models without training or modifications while retaining translation quality. We present three parallel decoding algorithms and test them on different languages and models showing how the parallelization introduces a speedup up to 38{\%} w.r.t. the standard autoregressive decoding and nearly 2x when scaling the method on parallel resources. Finally, we introduce a decoding dependency graph visualizer (DDGviz) that let us see how the model has learned the conditional dependence between tokens and inspect the decoding procedure.",
102 | }
103 | ```
104 |
--------------------------------------------------------------------------------
/src/ipi/decoders/jacobi.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.ipi.decoders.mt_decoding import MTDecoder
4 | from src.viz.dependecy_graph import DecodingDependencyGraph
5 |
6 |
7 | class JacobiDecoder(MTDecoder):
8 | def __init__(self, tokenizer, model, initializer, **kwargs):
9 | super().__init__(tokenizer, model, initializer, **kwargs)
10 |
11 | self.name = "jacobi"
12 | self.acronym = "j"
13 |
14 | def generate_target(
15 | self,
16 | input_ids: torch.Tensor,
17 | attention_mask: torch.Tensor,
18 | is_mbart: bool,
19 | decoding_method: str = "greedy",
20 | remove_padding: bool = False,
21 | ):
22 | if decoding_method == "greedy":
23 | if is_mbart:
24 | with self.tokenizer.as_target_tokenizer():
25 | gold_output = self.model.generate(
26 | **{"input_ids": input_ids, "attention_mask": attention_mask},
27 | num_beams=1,
28 | do_sample=False,
29 | forced_bos_token_id=self.tokenizer.cur_lang_code_id,
30 | )
31 | else:
32 | gold_output = self.model.generate(
33 | **{"input_ids": input_ids, "attention_mask": attention_mask},
34 | num_beams=1,
35 | do_sample=False,
36 | )
37 | else:
38 | raise NotImplementedError()
39 |
40 | if remove_padding:
41 | sample_lengths = (gold_output != self.tokenizer.pad_token_id).sum(dim=1)
42 | gold_output = [
43 | sample[:length] for sample, length in zip(gold_output, sample_lengths)
44 | ]
45 |
46 | return gold_output
47 |
48 |
49 | @torch.no_grad()
50 | def decode(
51 | self,
52 | input_ids,
53 | attention_mask,
54 | target_len=None,
55 | gold_target=None,
56 | init_tensor=None,
57 | compute_ddg: bool = False,
58 | logits_preprocessor=None,
59 | *args,
60 | **kwargs
61 | ):
62 | max_length = target_len
63 | str_index = 0
64 | key_cache = 0
65 | if compute_ddg:
66 | if gold_target is None:
67 | gold_target = self.generate_target(
68 | input_ids=input_ids,
69 | attention_mask=attention_mask,
70 | is_mbart=self.is_mbart,
71 | )
72 | ddg = DecodingDependencyGraph(
73 | model=self.model, tokenizer=self.tokenizer, gold_target=gold_target
74 | )
75 |
76 | max_length = gold_target.shape[-1]
77 |
78 | if init_tensor is None:
79 | init_tensor = torch.tensor(
80 | [self.pad_token_id] * input_ids.size(0) * max_length,
81 | device=self.device,
82 | ).reshape(input_ids.size(0), max_length)
83 | elif self.is_mbart:
84 | if init_tensor.shape[0] == 1:
85 | decoder_input_ids = init_tensor[:, 0].unsqueeze(0)
86 | else:
87 | decoder_input_ids = init_tensor[:, 0].unsqueeze(-1)
88 | output = self.model(
89 | input_ids,
90 | attention_mask,
91 | decoder_input_ids=decoder_input_ids,
92 | use_cache=True,
93 | )
94 | encoder_last_hidden_state = output.encoder_last_hidden_state
95 | past_key_values = output.past_key_values
96 | str_index = 1
97 | total_res = init_tensor
98 | init_tensor = init_tensor[:, 1:]
99 | key_cache = 1
100 |
101 | output_probs = init_tensor.clone().float()
102 |
103 | for index in range(str_index, max_length):
104 | if self.use_cache and index > 0:
105 | old_init_tensor = total_res.detach().clone()
106 | init_tensor = total_res[:, index:]
107 | output = self.model(
108 | input_ids,
109 | attention_mask,
110 | decoder_input_ids=init_tensor,
111 | encoder_outputs=(encoder_last_hidden_state, None, None),
112 | use_cache=True,
113 | past_key_values=self.limit_past_key_values(past_key_values, index + key_cache),
114 | )
115 | else:
116 | old_init_tensor = init_tensor.detach().clone()
117 | output = self.model(
118 | input_ids,
119 | attention_mask,
120 | decoder_input_ids=init_tensor,
121 | use_cache=True,
122 | )
123 | encoder_last_hidden_state = output.encoder_last_hidden_state
124 | past_key_values = output.past_key_values
125 | logits = output.logits
126 | max_index = torch.argmax(logits, dim=-1)
127 | max_value, max_i = torch.max(torch.softmax(logits, dim=-1), dim=-1)
128 | if index > 0 and logits_preprocessor is not None:
129 | logits_new = logits_preprocessor(total_res[:, : index + 1], logits[:, 0, :])
130 | max_value_new = torch.argmax(logits_new, dim=-1)
131 | max_index[:, 0] = max_value_new
132 | if self.use_cache and index > 0:
133 | init_tensor = max_index
134 | total_res = torch.cat(
135 | (total_res[:, : index + 1], init_tensor[:, :-1]), dim=1
136 | )
137 | else:
138 | init_tensor[:, index + 1 :] = max_index[:, index:-1]
139 | total_res = init_tensor
140 |
141 | output_probs[:, index + 1 :] = max_value[:, index:-1]
142 |
143 | stop_condition, return_tensor = self.stopping_criterion(
144 | old_init_tensor, total_res
145 | )
146 |
147 | if compute_ddg:
148 | ddg.insert_one_element(
149 | total_res, gold_target, output_probs=output_probs
150 | )
151 |
152 | if stop_condition:
153 | break
154 |
155 | if compute_ddg:
156 | return return_tensor, index, ddg
157 |
158 | return return_tensor, index
159 |
160 |
161 | def initialize(self, init_transl):
162 | if self.initializer is not None:
163 | init_tensor, _ = self.initializer.init_translation(init_transl.shape[-1])
164 | else:
165 | init_tensor = None
166 |
167 | return init_tensor
168 |
169 | def compute_decode_kwargs(self, input_ids, attention_mask, **kwargs):
170 |
171 | gold_autoregressive = self.generate_gold_autoregressive(input_ids, attention_mask)
172 | init_tensor = self.initialize(init_transl=gold_autoregressive)
173 | logits_preprocessor = self.generate_logits_preprocessor(input_ids)
174 |
175 | return {
176 |             "init_tensor": init_tensor.clone() if init_tensor is not None else None,
177 | "target_len": gold_autoregressive.shape[-1],
178 | "gold_target": gold_autoregressive,
179 | "logits_preprocessor": logits_preprocessor
180 | }
181 |
--------------------------------------------------------------------------------
/src/viz/dependecy_graph.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 | import plotly.express as px
5 | import torch
6 | from transformers import MBart50Tokenizer
7 |
8 |
9 | class DecodingDependencyGraph(object):
10 | def __init__(self, model, tokenizer, gold_target: torch.Tensor):
11 | self.model = model
12 | self.tokenizer = tokenizer
13 | self.init_matrix = []
14 | self.output_probs = []
15 | self.gold_target = gold_target
16 |
17 | if isinstance(tokenizer, MBart50Tokenizer):
18 | self.is_mbart = True
19 | self.starting_index = 2
20 | else:
21 | self.is_mbart = False
22 | self.starting_index = 1
23 |
24 | def insert_one_element(
25 | self, current_tensor, gold_tensor, index=None, output_probs=None
26 | ):
27 | self.init_matrix.append(
28 | (
29 | current_tensor[:, self.starting_index :]
30 | == gold_tensor[:, self.starting_index :]
31 | )
32 | )
33 | if output_probs is not None:
34 | self.output_probs.append(output_probs)
35 |
36 | def finalize_matrices(self) -> Tuple[torch.Tensor, torch.Tensor]:
37 | return (
38 | torch.stack(self.init_matrix).permute(1, 0, 2).cpu(),
39 | torch.stack(self.output_probs).permute(1, 0, 2).cpu(),
40 | )
41 |
42 | def _create_labels(self, sample_index: int, method):
43 | sample_target = self.gold_target[sample_index, :].squeeze(0)
44 |
45 | if method == "decoded_ids":
46 | labels = self.tokenizer.convert_ids_to_tokens(sample_target)
47 | labels = [
48 | f"{i}:{id}"
49 | for i, id in zip(
50 | labels[self.starting_index - 1 :],
51 | sample_target[self.starting_index - 1 :],
52 | )
53 | ]
54 | elif method == "basic":
55 | labels = [f"{i}" for i in sample_target[self.starting_index - 1 :].tolist()]
56 |
57 | return labels
58 |
59 | def pretty_print(
60 | self, sample_index: int, sentence_id: str, method="basic", x_remap="text"
61 | ):
62 | labels = self._create_labels(sample_index=sample_index, method=method)
63 | iteration_matrix, probability_matrix = self.finalize_matrices()
64 | iteration_matrix = iteration_matrix[sample_index, :, :].int().numpy()
65 | probability_matrix = probability_matrix[sample_index, :, 1:].numpy()
66 |
67 | mask: np.ndarray = np.zeros_like(iteration_matrix)
68 |         i, j = 0, 0  # walk the staircase of tokens that already match the gold target
69 | while i < iteration_matrix.shape[0] and j < iteration_matrix.shape[1]:
70 | if iteration_matrix[i, j]:
71 | mask[i, j] += 1
72 | j += 1
73 | else:
74 | mask[i, j:] = iteration_matrix[i, j:]
75 | i += 1
76 |
77 | probability_matrix = mask * probability_matrix
78 |
79 | fig = px.imshow(
80 | iteration_matrix + mask,
81 | binary_compression_level=0,
82 | title=f"Decoding Dependency Graph for sentence {sentence_id}",
83 | color_continuous_scale="Viridis",
84 | )
85 |
86 | fig.update_xaxes(
87 | tickmode="array",
88 | tickvals=list(range(len(labels[1:]))),
89 | ticktext=[
90 | self.tokenizer.convert_ids_to_tokens([x])[0] if x_remap else str(x)
91 | for x in labels[1:]
92 | ],
93 | tickangle=45,
94 | )
95 |
96 | fig.update_traces(
97 | text=[
98 | [f"{xy:.2f}" if xy > 0 else "" for xy in x] for x in probability_matrix
99 | ],
100 | texttemplate="%{text}",
101 | )
102 |
103 | fig.update_layout(
104 | font=dict(family="Courier New, monospace", size=22, color="Black"),
105 | showlegend=False,
106 | coloraxis_showscale=False,
107 | )
108 |
109 | fig.show()
110 |
111 | def plot_confusion_matrix(
112 | self, cm, target_names, title="Confusion matrix", cmap=None, normalize=True
113 | ):
114 | """
115 | given a sklearn confusion matrix (cm), make a nice plot
116 |
117 | Arguments
118 | ---------
119 | cm: confusion matrix from sklearn.metrics.confusion_matrix
120 |
121 | target_names: given classification classes such as [0, 1, 2]
122 | the class names, for example: ['high', 'medium', 'low']
123 |
124 | title: the text to display at the top of the matrix
125 |
126 | cmap: the gradient of the values displayed from matplotlib.pyplot.cm
127 | see http://matplotlib.org/examples/color/colormaps_reference.html
128 | plt.get_cmap('jet') or plt.cm.Blues
129 |
130 | normalize: If False, plot the raw numbers
131 | If True, plot the proportions
132 |
133 | Usage
134 | -----
135 | plot_confusion_matrix(cm = cm, # confusion matrix created by
136 | # sklearn.metrics.confusion_matrix
137 | normalize = True, # show proportions
138 | target_names = y_labels_vals, # list of names of the classes
139 | title = best_estimator_name) # title of graph
140 |
141 |         Citation
142 | ---------
143 | http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
144 |
145 | """
146 | import itertools
147 |
148 | import matplotlib.pyplot as plt
149 | import numpy as np
150 |
151 | accuracy = np.trace(cm) / np.sum(cm).astype("float")
152 | misclass = 1 - accuracy
153 |
154 | if cmap is None:
155 | cmap = plt.get_cmap("Blues")
156 |
157 | plt.figure(figsize=(8, 6))
158 | plt.imshow(cm, interpolation="nearest", cmap=cmap)
159 | plt.title(title)
160 | plt.colorbar()
161 |
162 | if target_names is not None:
163 | tick_marks = np.arange(len(target_names))
164 | plt.xticks(tick_marks, target_names, rotation=45)
165 | plt.yticks(tick_marks, target_names)
166 |
167 | if normalize:
168 | cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
169 |
170 | thresh = cm.max() / 1.5 if normalize else cm.max() / 2
171 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
172 | if normalize:
173 | plt.text(
174 | j,
175 | i,
176 | "{:0.4f}".format(cm[i, j]),
177 | horizontalalignment="center",
178 | color="white" if cm[i, j] > thresh else "black",
179 | )
180 | else:
181 | plt.text(
182 | j,
183 | i,
184 | "{:,}".format(cm[i, j]),
185 | horizontalalignment="center",
186 | color="white" if cm[i, j] > thresh else "black",
187 | )
188 |
189 | plt.tight_layout()
190 | plt.ylabel("True label")
191 | plt.xlabel(
192 | "Predicted label\naccuracy={:0.4f}; misclass={:0.4f}".format(
193 | accuracy, misclass
194 | )
195 | )
196 | plt.show()
197 |
--------------------------------------------------------------------------------
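A minimal, self-contained sketch (not part of the repository) of how DecodingDependencyGraph accumulates its matrices: the decoding loop above calls insert_one_element once per parallel iteration and finalize_matrices at the end. The toy token ids below are invented; model and tokenizer are only needed for plotting, so None is enough here.

    import torch

    from src.viz.dependecy_graph import DecodingDependencyGraph

    # One sentence of 6 token ids acts as the gold (autoregressive) target.
    gold = torch.tensor([[2, 11, 12, 13, 14, 1]])
    ddg = DecodingDependencyGraph(model=None, tokenizer=None, gold_target=gold)

    # First parallel iteration: only the first tokens match the gold target.
    draft = torch.tensor([[2, 11, 0, 0, 0, 0]])
    ddg.insert_one_element(draft, gold, output_probs=torch.rand(1, 6))
    # Second iteration: every position has converged to the gold target.
    ddg.insert_one_element(gold, gold, output_probs=torch.rand(1, 6))

    iteration_matrix, probability_matrix = ddg.finalize_matrices()
    print(iteration_matrix.shape)  # (batch, iterations, target_len - starting_index)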
/src/ipi/decoders/hybrid_jacobi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from more_itertools import sliced
3 |
4 | from src.ipi.decoders.mt_decoding import MTDecoder
5 |
6 |
7 | class HybridJacobiDecoder(MTDecoder):
8 | def __init__(self, tokenizer, model, gs_jaco_blocks, init_mode, initializer, percent=300, **kwargs):
9 | super().__init__(tokenizer, model, initializer, **kwargs)
10 |
11 | self.name = "hybrid_jacobi"
12 | self.acronym = "h"
13 |
14 | self.gs_jaco_blocks = gs_jaco_blocks
15 |
16 | self.init_mode = init_mode
17 | self.percent = percent
18 |
19 | @torch.no_grad()
20 | def decode(
21 | self, input_ids, attention_mask, target_len, gold_target, init_tensor=None, logits_preprocessor=None, *args, **kwargs
22 | ):
23 | key_cache = 1
24 | if init_tensor is None:
25 | init_tensor = torch.tensor(
26 | [self.pad_token_id] * target_len, device=self.device
27 | )
28 | blocks = list(sliced(init_tensor, self.gs_jaco_blocks))
29 | init_tensor = init_tensor.unsqueeze(0)
30 | total_past_key_values = None
31 | elif self.is_mbart:
32 | output = self.model(
33 | input_ids,
34 | attention_mask,
35 | decoder_input_ids=init_tensor[:, 0].unsqueeze(0),
36 | use_cache=True,
37 | )
38 | encoder_last_hidden_state = output.encoder_last_hidden_state
39 | total_past_key_values = output.past_key_values
40 | init_tensor = init_tensor[:, 1:]
41 | blocks = list(sliced(init_tensor.squeeze(0), self.gs_jaco_blocks))
42 | key_cache = 2
 43 |         else:
 44 |             # init_tensor already has the right shape, just split it into blocks
 45 |             blocks = list(sliced(init_tensor.squeeze(0), self.gs_jaco_blocks))
 46 |             total_past_key_values = None
47 |
48 | iteration_saved = 0
49 | base_value = 0
50 |
51 | for blocco in blocks:
52 | max_len = blocco.shape[-1]
53 | blocco_usr = init_tensor[:, base_value : base_value + max_len]
54 | for index in range(max_len):
55 | old_blocco = blocco_usr.detach().clone()
56 | trig = self.trig_eos(
57 | old_blocco, self.eos_token_id, init_tensor, base_value
58 | )
59 | if trig is not None:
60 | return trig, (gold_target.shape[-1] - 1) - iteration_saved
61 | blocco_usr_new = blocco_usr[:, index:]
62 | if base_value == 0 and index == 0 and not self.is_mbart:
63 | output = self.model(
64 | input_ids,
65 | attention_mask,
66 | decoder_input_ids=blocco_usr_new,
67 | use_cache=True,
68 | past_key_values=total_past_key_values,
69 | )
70 | encoder_last_hidden_state = output.encoder_last_hidden_state
71 | else:
72 | output = self.model(
73 | input_ids,
74 | attention_mask,
75 | decoder_input_ids=blocco_usr_new,
76 | encoder_outputs=(encoder_last_hidden_state, None, None),
77 | use_cache=True,
78 | past_key_values=total_past_key_values,
79 | )
80 |
81 | total_past_key_values = self.limit_past_key_values(
82 | output.past_key_values,
83 | base_value + index + key_cache,
84 | )
85 |
86 | logits = output.logits
87 | max_value = torch.argmax(logits, dim=-1)
88 |
89 | if logits_preprocessor is not None:
90 | logits_new = logits_preprocessor(init_tensor[:, :base_value + index + 1], logits[:, 0, :])
91 | max_value_new = torch.argmax(logits_new, dim=-1)
92 | max_value[:,0] = max_value_new
93 |
94 | if (
95 | max_value.shape[-1]
96 | == init_tensor[
97 | :, base_value + index + 1 : base_value + max_len + 1
98 | ].shape[-1]
99 | ):
100 | init_tensor[
101 | :, base_value + index + 1 : base_value + max_len + 1
102 | ] = max_value[:, :]
103 | else:
104 | # If last block remove the last token after EOS
105 | init_tensor[
106 | :, base_value + index + 1 : base_value + max_len + 1
107 | ] = max_value[:, :-1]
108 |
109 | stop_condition, _, eos_cond = self.stopping_criterion(
110 | old_blocco, blocco_usr, eos=self.eos_token_id
111 | )
112 |
113 | if stop_condition:
114 | if eos_cond >= 0:
115 | return (
116 | init_tensor[:, : base_value + eos_cond + 1],
117 | (gold_target.shape[-1] - 1) - iteration_saved,
118 | )
119 | if index + 1 != max_len:
120 | iteration_saved += max_len - index - 1
121 | total_past_key_values = self.limit_past_key_values(
122 | output.past_key_values,
123 | base_value + max_len + 1,
124 | )
125 | break
126 |
127 | base_value += max_len
128 |
129 | total_res, total_iter = (
130 | init_tensor,
131 | (gold_target.shape[-1] - 1) - iteration_saved,
132 | )
133 |
134 | init_tensor = init_tensor[:, -1].clone().unsqueeze(0)
135 |
136 |         # Go autoregressive until [EOS]
137 |         while base_value != self.model.config.max_length - 1:
138 | index = 0
139 | output = self.model(
140 | input_ids,
141 | attention_mask,
142 | decoder_input_ids=init_tensor,
143 | encoder_outputs=(encoder_last_hidden_state, None, None),
144 | use_cache=True,
145 | past_key_values=total_past_key_values,
146 | )
147 | encoder_last_hidden_state = output.encoder_last_hidden_state
148 | total_past_key_values = output.past_key_values
149 | logits = output.logits
150 | max_value = torch.argmax(logits, dim=-1)
151 | last = max_value[:, -1]
152 | if self.use_cache:
153 | init_tensor = last.unsqueeze(0)
154 | total_res = torch.cat((total_res, init_tensor), dim=1)
155 |
156 | index += 1
157 | if last[0].item() == self.eos_token_id:
158 | break
159 | return total_res, index + total_iter
160 |
161 | def initialize(self, input_ids, gold_autoregressive):
162 | if self.initializer is not None:
163 | if self.init_mode == "under":
164 |                 length = max(3, input_ids.shape[-1] - self.percent / 100 * input_ids.shape[-1])
165 |                 m = int(length)
166 |             elif self.init_mode == "over":
167 |                 length = input_ids.shape[-1] + self.percent / 100 * input_ids.shape[-1]
168 |                 m = int(length)
169 | elif self.init_mode == "fixed":
170 | m = 511
171 | else:
172 | m = gold_autoregressive.shape[-1]
173 | init_tensor, _ = self.initializer.init_translation(m)
174 | else:
175 | init_tensor = None
176 |
177 | return init_tensor
178 |
179 | def compute_decode_kwargs(self, input_ids, attention_mask, **kwargs):
180 | gold_autoregressive = self.generate_gold_autoregressive(input_ids, attention_mask)
181 | init_tensor = self.initialize(input_ids, gold_autoregressive)
182 | logits_preprocessor = self.generate_logits_preprocessor(input_ids)
183 |
184 |         return {
185 |             "init_tensor": init_tensor.clone() if init_tensor is not None else None,
186 | "gold_target": gold_autoregressive,
187 | "target_len": gold_autoregressive.shape[-1],
188 | "logits_preprocessor": logits_preprocessor
189 | }
--------------------------------------------------------------------------------
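An illustrative end-to-end sketch of driving HybridJacobiDecoder outside the benchmarker. The checkpoint name and the literal hyper-parameters (gs_jaco_blocks=5, device="cpu") are assumptions; the wiring mirrors the h_jacobi branch of load_decoders in main.py below.

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    from src.ipi.decoders.hybrid_jacobi import HybridJacobiDecoder
    from src.ipi.initializer import Initializer

    model_name = "Helsinki-NLP/opus-mt-en-ro"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval()

    initializer = Initializer(
        src_len="en", tgt_len="ro", hugginface_tokenizer=tokenizer, use_init=False, device="cpu"
    )
    decoder = HybridJacobiDecoder(
        tokenizer=tokenizer,
        model=model,
        initializer=initializer,
        init_mode="fixed",
        gs_jaco_blocks=5,
        use_cache=True,
        device="cpu",
    )

    batch = tokenizer(["How are you?"], return_tensors="pt")
    with torch.no_grad():
        kwargs = decoder.compute_decode_kwargs(batch["input_ids"], batch["attention_mask"])
        output_ids, iterations = decoder.decode(
            batch["input_ids"], batch["attention_mask"], **kwargs
        )
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True), iterations)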
/src/dataset/iwslt_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import typing as t
3 |
4 | import datasets
5 | import torch
6 | from datasets.utils.download_manager import DownloadManager
7 | from torch.utils.data.dataset import Dataset
8 |
9 | from src.utils.utils import clean_text
10 |
11 |
12 | class Iwslt(Dataset):
13 | def __init__(
14 | self,
15 | version: str = "17",
16 | src_lan: str = "en",
17 | tgt_lan: str = "ro",
18 | data_dir: str = None,
19 | hugginface_tokenizer=None,
20 | split: str = None,
21 | ):
22 | self.version = version
23 | self.src_lan = src_lan
24 | self.tgt_lan = tgt_lan
25 | self.max_length = 511
26 |
27 | self.dl = DownloadManager()
28 |
29 | self.name = f"iwslt{self.version}"
30 |
31 | self.version2folder = {
32 | "15": os.path.join(data_dir, "2015-01/texts"),
33 | "17": os.path.join(data_dir, "2017-01-trnted/texts"),
34 | }
35 | self.version2years = {
36 | "15": {"train_and_test": [2010, 2011, 2012, 2013], "dev": [2010]},
37 | "17": {
38 | "train_and_test": [2010, 2011, 2012, 2013, 2014, 2015],
39 | "dev": [2010],
40 | },
41 | }
42 |
43 | data_file = f"{self.version2folder[version]}/{src_lan}/{tgt_lan}/{src_lan}-{tgt_lan}.tgz"
44 |
45 | splitted_generators = self._split_generators(data_file)
46 | self.translation_dataset = self.load_dataset(splitted_generators, split=split)
47 |
48 | with torch.no_grad():
49 | self.tokenizer = hugginface_tokenizer
50 |
51 | def load_dataset(
52 | self,
53 | splitted_generators: t.List[datasets.SplitGenerator],
54 | split: str,
55 | ) -> t.List[t.Dict]:
56 | splitted_generators = self.concat_dataset(splitted_generators, split)
57 |
58 | return list(
59 | self._generate_examples(
60 | source_files=splitted_generators.gen_kwargs["source_files"],
61 | target_files=splitted_generators.gen_kwargs["target_files"],
62 | )
63 | )
64 |
65 | @staticmethod
66 | def concat_dataset(
67 | splitted_generators: t.List[datasets.SplitGenerator],
68 | split: str,
69 | ) -> datasets.SplitGenerator:
70 | split2ix = {"train": 0, "test": 1, "validation": 2}
71 | assert (
72 | split in split2ix
73 |         ), "Iwslt: split must be either train, test, or validation"
74 | if split is not None:
75 | return splitted_generators[split2ix[split]]
76 |
77 | def _split_generators(self, data_file: str) -> t.List[datasets.SplitGenerator]:
78 | """Returns SplitGenerators."""
79 | pair = f"{self.src_lan}-{self.tgt_lan}"
80 | dl_dir = self.dl.extract(data_file)
81 | data_dir = os.path.join(dl_dir, f"{self.src_lan}-{self.tgt_lan}")
82 |
83 | years = self.version2years[self.version]["train_and_test"]
84 | dev = self.version2years[self.version]["dev"]
85 |
86 | return [
87 | datasets.SplitGenerator(
88 | name=datasets.Split.TRAIN,
89 | # These kwargs will be passed to _generate_examples
90 | gen_kwargs={
91 | "source_files": [
92 | os.path.join(
93 | data_dir,
94 | f"train.tags.{pair}.{self.src_lan}",
95 | )
96 | ],
97 | "target_files": [
98 | os.path.join(
99 | data_dir,
100 | f"train.tags.{pair}.{self.tgt_lan}",
101 | )
102 | ],
103 | "split": "train",
104 | },
105 | ),
106 | datasets.SplitGenerator(
107 | name=datasets.Split.TEST,
108 | # These kwargs will be passed to _generate_examples
109 | gen_kwargs={
110 | "source_files": [
111 | os.path.join(
112 | data_dir,
113 | f"IWSLT{self.version}.TED.tst{year}.{pair}.{self.src_lan}.xml",
114 | )
115 | for year in years
116 | ],
117 | "target_files": [
118 | os.path.join(
119 | data_dir,
120 | f"IWSLT{self.version}.TED.tst{year}.{pair}.{self.tgt_lan}.xml",
121 | )
122 | for year in years
123 | ],
124 | "split": "test",
125 | },
126 | ),
127 | datasets.SplitGenerator(
128 | name=datasets.Split.VALIDATION,
129 | # These kwargs will be passed to _generate_examples
130 | gen_kwargs={
131 | "source_files": [
132 | os.path.join(
133 | data_dir,
134 | f"IWSLT{self.version}.TED.dev{year}.{pair}.{self.src_lan}.xml",
135 | )
136 | for year in dev
137 | ],
138 | "target_files": [
139 | os.path.join(
140 | data_dir,
141 | f"IWSLT{self.version}.TED.dev{year}.{pair}.{self.tgt_lan}.xml",
142 | )
143 | for year in dev
144 | ],
145 | "split": "validation",
146 | },
147 | ),
148 | ]
149 |
150 | def _generate_examples(
151 | self, source_files: t.List[str], target_files: t.List[str]
152 | ) -> t.List[t.Dict]:
153 | """Yields examples."""
154 | for source_file, target_file in zip(source_files, target_files):
155 | with open(source_file, "r", encoding="utf-8") as sf:
156 | with open(target_file, "r", encoding="utf-8") as tf:
157 | for source_row, target_row in zip(sf, tf):
158 | source_row = source_row.strip()
159 | target_row = target_row.strip()
160 |
161 | if source_row.startswith("<"):
162 |                             if source_row.startswith("<seg id"):
163 |                                 # e.g. <seg id="1">.....</seg>
164 |                                 # Very simple code instead of regex or xml parsing
165 |                                 part1 = source_row.split(">")[1]
166 |                                 source_row = part1.split("<")[0]
167 |                                 part1 = target_row.split(">")[1]
168 |                                 target_row = part1.split("<")[0]
169 |
170 |                                 source_row = source_row.strip()
171 |                                 target_row = target_row.strip()
172 |                             else:
173 |                                 continue
174 |
175 | yield {
176 | "translation": {
177 | self.src_lan: source_row,
178 | self.tgt_lan: target_row,
179 | }
180 | }
181 |
182 | def collate_fn(self, batch):
183 |
184 | batch_source = [b[0] for b in batch]
185 | batch_target = [b[1] for b in batch]
186 |
187 | encoded_source = self.tokenizer(
188 | batch_source,
189 | padding=True,
190 | return_tensors="pt",
191 | )
192 | encoded_target = self.tokenizer(
193 | batch_target,
194 | padding=True,
195 | return_tensors="pt",
196 | )
197 |
198 | return {
199 | "source": {
200 | "input_ids": encoded_source["input_ids"],
201 | "attention_mask": encoded_source["attention_mask"],
202 | "sentences": batch_source,
203 | },
204 | "target": {
205 | "input_ids": encoded_target["input_ids"],
206 | "attention_mask": encoded_target["attention_mask"],
207 | "sentences": batch_target,
208 | },
209 | }
210 |
211 | def __len__(self) -> int:
212 | return len(self.translation_dataset)
213 |
214 | def __getitem__(self, idx: int) -> t.Tuple[str, str]:
215 | sample = self.translation_dataset[idx]
216 | source = sample["translation"][self.src_lan]
217 | target = sample["translation"][self.tgt_lan]
218 |
219 | return source, target
220 |
--------------------------------------------------------------------------------
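A usage sketch for the Iwslt wrapper with the collate_fn-driven DataLoader used by the benchmarker later in the dump. The data_dir value and the checkpoint name are assumptions: the directory must already contain the extracted IWSLT archive layout expected above (e.g. <data_dir>/2017-01-trnted/texts/en/ro/en-ro.tgz).

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from src.dataset.iwslt_dataset import Iwslt

    tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ro")  # assumed checkpoint
    dataset = Iwslt(
        version="17",
        src_lan="en",
        tgt_lan="ro",
        data_dir="data/iwslt",  # assumed local path holding the IWSLT archives
        hugginface_tokenizer=tokenizer,
        split="test",
    )

    loader = DataLoader(dataset, collate_fn=dataset.collate_fn, batch_size=1, shuffle=False)
    batch = next(iter(loader))
    print(batch["source"]["sentences"], batch["target"]["sentences"])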
/main.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from transformers import (
3 | AutoModelForSeq2SeqLM,
4 | AutoTokenizer,
5 | MBartForConditionalGeneration,
6 | )
7 |
8 | from src.bench import MTBenchmarker
9 | from src.dataset.flores_dataset import Flores
10 | from src.dataset.ittb_dataset import Ittb
11 | from src.dataset.iwslt_dataset import Iwslt
12 | from src.dataset.wmt_dataset import Wmt
13 | from src.ipi.decoders.autoregressive import AutoregressiveDecoder
14 | from src.ipi.decoders.beam_search import BeamSearchDecoder
15 | from src.ipi.decoders.gs_jacobi import GSJacobiDecoder
16 | from src.ipi.decoders.hybrid_jacobi import HybridJacobiDecoder
17 | from src.ipi.decoders.jacobi import JacobiDecoder
18 | from src.ipi.initializer import Initializer
19 | from src.ipi.decoders.mt_decoding import MTDecoder
20 | from src.utils.beam_search import BeamSearcher
21 | from src.utils.utils import retrieve_samples
22 |
23 |
24 | def load_tokenizer(cfg):
25 | # MBart
26 | mapping_dict = {
27 | "ar": "ar_AR",
28 | "cs": "cs_CZ",
29 | "de": "de_DE",
30 | "en": "en_XX",
31 | "es": "es_XX",
32 | "et": "et_EE",
33 | "fi": "fi_FI",
34 | "fr": "fr_XX",
35 | "gu": "gu_IN",
36 | "hi": "hi_IN",
37 | "it": "it_IT",
38 | "ja": "ja_XX",
39 | "kk": "kk_KZ",
40 | "ko": "ko_KR",
41 | "lt": "lt_LT",
42 | "lv": "lv_LV",
43 | "my": "my_MM",
44 | "ne": "ne_NP",
45 | "nl": "nl_XX",
46 | "ro": "ro_RO",
47 | "ru": "ru_RU",
48 | "si": "si_LK",
49 | "tr": "tr_TR",
50 | "vi": "vi_VN",
51 | "zh": "zh_CN",
52 | "af": "af_ZA",
53 | "az": "az_AZ",
54 | "bn": "bn_IN",
55 | "fa": "fa_IR",
56 | "he": "he_IL",
57 | "hr": "hr_HR",
58 | "id": "id_ID",
59 | "ka": "ka_GE",
60 | "km": "km_KH",
61 | "mk": "mk_MK",
62 | "ml": "ml_IN",
63 | "mn": "mn_MN",
64 | "mr": "mr_IN",
65 | "pl": "pl_PL",
66 | "ps": "ps_AF",
67 | "pt": "pt_XX",
68 | "sv": "sv_SE",
69 | "sw": "sw_KE",
70 | "ta": "ta_IN",
71 | "te": "te_IN",
72 | "th": "th_TH",
73 | "tl": "tl_XX",
74 | "uk": "uk_UA",
75 | "ur": "ur_PK",
76 | "xh": "xh_ZA",
77 | "gl": "gl_ES",
78 | "sl": "sl_SI",
79 | }
80 |
81 | if "mbart" in cfg.model_name:
82 | tokenizer = AutoTokenizer.from_pretrained(
83 | cfg.model_name,
84 | src_lang=mapping_dict[cfg.src_lang],
85 | tgt_lang=mapping_dict[cfg.tgt_lang],
86 | use_fast=False,
87 | )
88 | else:
89 | print(cfg.model_name)
90 | tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
91 |
92 | return tokenizer
93 |
94 |
95 | def load_model(cfg):
96 | if "mbart" in cfg.model_name:
97 | model = MBartForConditionalGeneration.from_pretrained(cfg.model_name).to(
98 | cfg.device
99 | )
100 | else:
101 | model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name).to(cfg.device)
102 |
103 | return model
104 |
105 |
106 | def load_dataset(tokenizer, cfg):
107 | # Wmt-xx-xx-xx
108 | if cfg.name == "wmt":
109 | split = cfg.split
110 | if cfg.subset.use_subset:
111 | split = f"{cfg.split}[{cfg.subset.start}:{cfg.subset.end + 1}]"
112 |
113 | dataset = Wmt(
114 | version=cfg.version,
115 | src_lan=cfg.src_lang,
116 | tgt_lan=cfg.tgt_lang,
117 | hugginface_tokenizer=tokenizer,
118 | split=split,
119 | )
120 |     # Iwslt-xx-xx-xx
121 | elif cfg.name == "iwslt":
122 | dataset = Iwslt(
123 | data_dir=cfg.data_dir,
124 | version=str(cfg.version),
125 | src_lan=cfg.src_lang,
126 | tgt_lan=cfg.tgt_lang,
127 | hugginface_tokenizer=tokenizer,
128 | split=cfg.split,
129 | )
130 | elif cfg.name == "ittb":
131 | dataset = Ittb(
132 | src_lan=cfg.src_lang,
133 | tgt_lan=cfg.tgt_lang,
134 | hugginface_tokenizer=tokenizer,
135 | split=cfg.split,
136 | )
137 | elif cfg.name == "flores":
138 | dataset = Flores(
139 | src_lan=cfg.src_lang,
140 | tgt_lan=cfg.tgt_lang,
141 | hugginface_tokenizer=tokenizer,
142 | split=cfg.split,
143 | )
144 | else:
145 |         raise ValueError(f"{cfg.name} is not yet implemented")
146 |
147 | return dataset
148 |
149 |
150 | def load_initializer(tokenizer, cfg):
151 | if cfg.use_initializer:
152 | initializer = Initializer(
153 | src_len=cfg.src_lang,
154 | tgt_len=cfg.tgt_lang,
155 | hugginface_tokenizer=tokenizer,
156 | use_init=cfg.use_init,
157 | device=cfg.device,
158 | )
159 |
160 | else:
161 | initializer = None
162 |
163 | return initializer
164 |
165 |
166 | def compute_beam_search(cfg, model, dataset):
167 | initializer = load_initializer(dataset.tokenizer, cfg.initializer)
168 |
169 | decoder = MTDecoder(
170 | tokenizer=dataset.tokenizer,
171 | model=model,
172 | use_cache=cfg.decoder.use_cache,
173 | gs_jaco_blocks=cfg.decoder.gs_jaco_blocks,
174 | device=cfg.device,
175 | initializer=initializer
176 | )
177 |
178 | beam_searcher = BeamSearcher(
179 | model=model,
180 | dataset=dataset,
181 | initializer=initializer,
182 | decoder=decoder,
183 | batch_size=cfg.beam_search.batch_size,
184 | num_beams=cfg.beam_search.num_beams,
185 | device=cfg.beam_search.device,
186 | no_repeat_ngram_size=2,
187 | early_stopping=True,
188 | result_dir=cfg.beam_search.result_dir,
189 | )
190 |
191 | beam_searcher.compute_beam_search(cfg)
192 |
193 |
194 | def load_decoders(cfg, tokenizer, model, initializer):
195 | decoders = []
196 | for decoder in cfg.decoder.decoders:
197 | if decoder == "autoregressive":
198 | dec = AutoregressiveDecoder(
199 | tokenizer=tokenizer,
200 | model=model,
201 | initializer=initializer,
202 | use_cache=cfg.decoder.use_cache,
203 | device=cfg.decoder.device
204 | )
205 | elif decoder == "jacobi":
206 | dec = JacobiDecoder(
207 | tokenizer=tokenizer,
208 | model=model,
209 | initializer=initializer,
210 | use_cache=cfg.decoder.use_cache,
211 | device=cfg.decoder.device
212 | )
213 | elif decoder == "gs_jacobi":
214 | dec = GSJacobiDecoder(
215 | tokenizer=tokenizer,
216 | model=model,
217 | initializer=initializer,
218 | gs_jaco_blocks=cfg.decoder.gs_jaco_blocks,
219 | init_mode="",
220 | use_cache=cfg.decoder.use_cache,
221 | device=cfg.decoder.device
222 | )
223 | elif decoder == "h_jacobi":
224 | dec = HybridJacobiDecoder(
225 | tokenizer=tokenizer,
226 | model=model,
227 | initializer=initializer,
228 | init_mode="fixed",
229 | gs_jaco_blocks=cfg.decoder.gs_jaco_blocks,
230 | use_cache=cfg.decoder.use_cache,
231 | device=cfg.decoder.device
232 | )
233 | elif decoder == "beam_search":
234 | dec = BeamSearchDecoder(
235 | tokenizer=tokenizer,
236 | model=model,
237 | initializer=initializer,
238 | num_beams=cfg.beam_search.num_beams,
239 | early_stopping=True,
240 | )
241 | else:
242 |             raise ValueError(f"{decoder} decoder has not been implemented yet.")
243 |
244 | decoders.append(dec)
245 |
246 | return decoders
247 |
248 |
249 | def compute_benchmark(cfg, tokenizer, dataset, model):
250 | initializer = load_initializer(tokenizer, cfg.initializer)
251 | decoders = load_decoders(cfg, tokenizer, model, initializer)
252 |
253 | benchmarker = MTBenchmarker(
254 | dataset=dataset,
255 | decoders=decoders,
256 | src_lang=cfg.model.src_lang,
257 | tgt_lang=cfg.model.tgt_lang,
258 | result_dir=cfg.bench.result_dir,
259 | device=cfg.bench.device,
260 | debug=True,
261 | )
262 | benchmarker.compute_total_time()
263 |
264 |
265 | @hydra.main(config_path="conf", config_name="config", version_base="1.1")
266 | def main(cfg):
267 | tokenizer = load_tokenizer(cfg.model)
268 |
269 | model = load_model(cfg.model)
270 |
271 | dataset = load_dataset(tokenizer, cfg.dataset)
272 |
273 | if "benchmark" in cfg.task:
274 | compute_benchmark(cfg, tokenizer, dataset, model)
275 | elif "beam_search" in cfg.task:
276 | compute_beam_search(cfg, model, dataset)
277 | elif "sample" in cfg.task:
278 | retrieve_samples(cfg.sample.path, dataset)
279 | else:
280 | raise ValueError(f"{cfg.task} is not yet implemented")
281 |
282 |
283 | if __name__ == "__main__":
284 | main()
285 |
--------------------------------------------------------------------------------
/src/viz/visualize.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import hashlib
3 | import json
4 | from pathlib import Path
5 |
6 | import argparse
7 | import datasets
8 | import numpy as np
9 | import plotly.express as px
10 | import torch.utils.data
11 | from datasets import load_from_disk, Dataset
12 | from transformers import (
13 | AutoModelForSeq2SeqLM,
14 | AutoTokenizer,
15 | PreTrainedTokenizer,
16 | PreTrainedModel,
17 | )
18 |
19 | from src import PROJECT_ROOT
20 | from src.ipi.decoders.jacobi import JacobiDecoder
21 | from src.ipi.initializer import Initializer
22 | from src.ipi.decoders.mt_decoding import MTDecoder, generate_target
23 |
24 |
25 | def add_iteration_matrix(
26 | sample: dict, src_lang: str, tgt_lang: str, model, tokenizer, initializer, decoder
27 | ):
28 | source = sample["translation"][src_lang]
29 | target = sample["translation"][tgt_lang]
30 |
31 | encoded_source = tokenizer(
32 | source,
33 | return_tensors="pt",
34 | )
35 | encoded_target = tokenizer(
36 | target,
37 | return_tensors="pt",
38 | )
39 |
40 | x = {
41 | "source": {
42 | "input_ids": encoded_source["input_ids"],
43 | "attention_mask": encoded_source["attention_mask"],
44 | "sentences": source,
45 | },
46 | "target": {
47 | "input_ids": encoded_target["input_ids"],
48 | "attention_mask": encoded_target["attention_mask"],
49 | "sentences": target,
50 | },
51 | }
52 |
53 | input_ids = encoded_source["input_ids"].to(device)
54 | attention_mask = encoded_source["attention_mask"].to(device)
55 | gold_output = generate_target(
56 | model=model,
57 | tokenizer=tokenizer,
58 | input_ids=input_ids.to(device),
59 | attention_mask=attention_mask.to(device),
60 | is_mbart=False,
61 | )
62 |
63 | init_tensor, _ = initializer.init_translation(gold_output.shape[-1])
64 |
65 | return_tensor, index, ddg = decoder.decode(
66 | input_ids=input_ids,
67 | attention_mask=attention_mask,
68 | init_tensor=init_tensor,
69 | compute_ddg=True,
70 | )
71 |
72 | iteration_matrix, probability_matrix = ddg.finalize_matrices()
73 | parallel_steps = iteration_matrix.shape[1]
74 | gold_steps = gold_output.shape[1]
75 |
76 | return dict(
77 | gold_output=gold_output[0],
78 | iteration_matrix=iteration_matrix[0],
79 | probability_matrix=probability_matrix[0],
80 | parallel_steps=parallel_steps,
81 | gold_steps=gold_steps,
82 | score=gold_steps - parallel_steps,
83 | )
84 |
85 |
86 | def enrich_dataset(
87 | run_info: dict,
88 | device: str,
89 | src_lang: str,
90 | tgt_lang: str,
91 | dataset_name: str,
92 | dataset_key: str,
93 | tokenizer: PreTrainedTokenizer,
94 | model: PreTrainedModel,
95 | force_recompute: bool = False,
96 | ) -> Dataset:
97 | # MarianMT
98 | run_dir: Path = run_info["run_dir"]
99 |
100 | dataset_path: Path = run_dir / "dataset"
101 |
102 | if run_dir.exists() and not force_recompute:
103 | dataset = load_from_disk(str(run_dir / "dataset"))
104 | else:
105 | initializer = Initializer(src_lang, tgt_lang, tokenizer, use_init=False)
106 | model.eval()
107 |
108 | # decoder = MTDecoder(
109 | # tokenizer=tokenizer,
110 | # model=model,
111 | # use_cache=False,
112 | # gs_jaco_blocks=5,
113 | # device=device,
114 | # initializer=initializer
115 | # )
116 |
117 | decoder = JacobiDecoder(
118 | tokenizer=tokenizer,
119 | model=model,
120 | initializer=initializer,
121 | use_cache=False,
122 | device=device
123 | )
124 |
125 | dataset = datasets.load_dataset(
126 | dataset_name,
127 | dataset_key,
128 | split="test[0:10]",
129 | data_dir=str(PROJECT_ROOT / "hf_data"),
130 | cache_dir=str(PROJECT_ROOT / "hf_cache"),
131 | ).map(
132 | function=functools.partial(
133 | add_iteration_matrix,
134 | src_lang=src_lang,
135 | tgt_lang=tgt_lang,
136 | model=model,
137 | initializer=initializer,
138 | decoder=decoder,
139 | tokenizer=tokenizer,
140 | )
141 | )
142 | dataset.save_to_disk(str(dataset_path))
143 |
144 | json.dump(
145 | run_info,
146 | fp=(run_dir / "run_info.json").open("w", encoding="utf-8"),
147 | indent=4,
148 | default=lambda x: str(x)
149 | )
150 |
151 | return dataset
152 |
153 |
154 | def draw(sample: dict, tokenizer: PreTrainedTokenizer, starting_index: int):
155 | labels = [f"{i}" for i in sample["gold_output"][starting_index - 1 :]]
156 | iteration_matrix = torch.as_tensor(sample["iteration_matrix"])
157 | probability_matrix = torch.as_tensor(sample["probability_matrix"])
158 | iteration_matrix = iteration_matrix[:, :].int().numpy()
159 | probability_matrix = probability_matrix[:, 1:].numpy()
160 |
161 | mask: np.ndarray = np.zeros_like(iteration_matrix)
162 |     i, j = 0, 0  # walk the staircase of tokens that already match the gold target
163 | while i < iteration_matrix.shape[0] and j < iteration_matrix.shape[1]:
164 | if iteration_matrix[i, j]:
165 | mask[i, j] += 1
166 | j += 1
167 | else:
168 | mask[i, j:] = iteration_matrix[i, j:]
169 | i += 1
170 |
171 | probability_matrix = mask * probability_matrix
172 |
173 | fig = px.imshow(
174 | iteration_matrix + mask,
175 | binary_compression_level=0,
176 | # title=f"Decoding Dependency Graph for sentence {sentence_id}",
177 | color_continuous_scale="Viridis",
178 | )
179 |
180 | fig.update_xaxes(
181 | tickmode="array",
182 | tickvals=list(range(len(labels[1:]))),
183 | ticktext=[tokenizer.convert_ids_to_tokens([x])[0] for x in labels[1:]],
184 | tickangle=45,
185 | )
186 |
187 | fig.update_traces(
188 | text=[[f"{xy:.2f}" if xy > 0 else "" for xy in x] for x in probability_matrix],
189 | texttemplate="%{text}",
190 | )
191 |
192 | fig.update_layout(
193 | font=dict(family="Courier New, monospace", size=22, color="Black"),
194 | showlegend=False,
195 | coloraxis_showscale=False,
196 | margin=dict(l=0, r=0, b=0, t=0),
197 | paper_bgcolor="rgba(0,0,0,0)",
198 | plot_bgcolor="rgba(0,0,0,0)",
199 | )
200 |
201 | return fig
202 |
203 |
204 | if __name__ == "__main__":
205 |
206 |
207 | parser = argparse.ArgumentParser(description='Dependency Graph Visualizer (DDGviz)')
208 |
209 | parser.add_argument('--src', default="ro", help='src language')
210 | parser.add_argument('--tgt', default="en", help='target language')
211 | parser.add_argument('--dataset', default="wmt16", help='Dataset name')
212 | # parser.add_argument('--examples', default=[1566, 960], help='Examples to print with DDGviz')
213 | parser.add_argument('--examples', default=[1, 4], help='Examples to print with DDGviz')
214 |
215 | args = parser.parse_args()
216 |
217 | device = "cpu"
218 | src_lang = args.src
219 | tgt_lang = args.tgt
220 | dataset_name: str = args.dataset
221 | dataset_key: str = f"{src_lang}-{tgt_lang}"
222 | langs = {src_lang, tgt_lang}
223 | examples_to_print = args.examples
224 |
225 |
226 | if "en" in langs:
227 | dataset_key = (
228 | f"{src_lang}-{tgt_lang}" if "en" == tgt_lang else f"{tgt_lang}-{src_lang}"
229 | )
230 |
231 | model_src_lang = src_lang
232 | if model_src_lang == "ro":
233 | model_src_lang: str = "roa"
234 |
235 | dataset_split: str = "test"
236 | model_name: str = f"Helsinki-NLP/opus-mt-{model_src_lang}-{tgt_lang}"
237 | # model_name: str = "zannabethl/opus-mt-en-ro-finetuned-en-to-ro"
238 |
239 | tokenizer = AutoTokenizer.from_pretrained(model_name)
240 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
241 |
242 | run_info = dict(
243 | model_name=model_name,
244 | source_lang=src_lang,
245 | target_lang=tgt_lang,
246 | dataset_name=dataset_name,
247 | dataset_key=dataset_key,
248 | split=dataset_split,
249 | )
250 |
251 | run_info_hash: str = hashlib.md5(
252 | json.dumps(run_info).encode(encoding="utf-8")
253 | ).hexdigest()
254 | run_dir: Path = PROJECT_ROOT / "iteration_matrix" / run_info_hash
255 |
256 | run_info["run_dir"] = run_dir
257 |
258 | dataset = (
259 | enrich_dataset(
260 | run_info=run_info,
261 | device=device,
262 | src_lang=src_lang,
263 | tgt_lang=tgt_lang,
264 | dataset_name=dataset_name,
265 | dataset_key=dataset_key,
266 | tokenizer=tokenizer,
267 | model=model,
268 | )
269 | .map(function=lambda x, i: {"index": i}, with_indices=True)
270 | .select(examples_to_print)
271 | )
272 |
273 | starting_index: int = 1
274 | images_dir: Path = run_dir / "images"
275 | images_dir.mkdir(exist_ok=True)
276 | for sample in dataset:
277 | fig = draw(sample=sample, tokenizer=tokenizer, starting_index=1)
278 | # fig.show()
279 | print(f"Index: {sample['index']}")
280 | print(f"Translations: {sample['translation']}")
281 | print(
282 | f"Gold output: {tokenizer.decode(sample['gold_output'], skip_special_tokens=True)}"
283 | )
284 | print()
285 | # input()
286 | # continue
287 | fig.write_image(images_dir / f"{sample['index']}.png", width=1800, height=1500)
288 | x = {x: sample[x] for x in ("translation", "gold_output", "index", "score")}
289 | x["gold_output"] = tokenizer.decode(sample["gold_output"])
290 |
291 | (images_dir / f"{sample['index']}.json").write_text(
292 | json.dumps(x, indent=4, sort_keys=True), encoding="utf-8"
293 | )
294 |
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import random
4 | import re
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | from torch.nn.modules import linear
10 | import datasets
11 |
12 | def makedirs(path):
13 | if path.endswith((".tsv", ".csv", ".txt")):
14 | path = "/".join(path.split("/")[:-1])
15 |
16 | if not os.path.exists(path):
17 | os.makedirs(path)
18 |
19 | def check_zero_division(a, b):
20 | return "na" if b == 0 else round(a / b, 3)
21 |
22 | def get_logits_preprocessor(model, input_ids, eos_token_id):
23 | logits_preprocessor = model._get_logits_processor(
24 | repetition_penalty=None,
25 | no_repeat_ngram_size=None,
26 | encoder_no_repeat_ngram_size=None,
27 | input_ids_seq_length=1,
28 | encoder_input_ids=input_ids,
29 | bad_words_ids=None,
30 | min_length=0,
31 | max_length=model.config.max_length,
32 | eos_token_id=eos_token_id,
33 | forced_bos_token_id=None,
34 | forced_eos_token_id=None,
35 | prefix_allowed_tokens_fn=None,
36 | num_beams=1,
37 | num_beam_groups=1,
38 | diversity_penalty=None,
39 | remove_invalid_values=None,
40 | exponential_decay_length_penalty=None,
41 | logits_processor=[],
42 | renormalize_logits=None
43 | )
44 | return logits_preprocessor
45 |
46 |
47 |
48 | def retrieve_model_name(model_name):
49 | if "opus" in model_name:
50 | return "opus"
51 | if "mbart" in model_name:
52 | if "50" in model_name:
53 | return "mbart_50"
54 | return "mbart"
55 |
56 |
57 | def seed_everything(seed: int):
58 | random.seed(seed)
59 | os.environ["PYTHONHASHSEED"] = str(seed)
60 | np.random.seed(seed)
61 | torch.manual_seed(seed)
62 | torch.cuda.manual_seed(seed)
63 | torch.backends.cudnn.deterministic = True
64 |     torch.backends.cudnn.benchmark = False  # benchmark mode picks non-deterministic kernels
65 |
66 |
67 | def retrieve_map_languages_flores(lan):
68 | lang_map = {
69 | "ab": "Abkhaz",
70 | "aa": "Afar",
71 | "af": "Afrikaans",
72 | "ak": "Akan",
73 | "sq": "Albanian",
74 | "am": "Amharic",
75 | "ar": "Arabic",
76 | "an": "Aragonese",
77 | "hy": "Armenian",
78 | "as": "Assamese",
79 | "av": "Avaric",
80 | "ae": "Avestan",
81 | "ay": "Aymara",
82 | "az": "Azerbaijani",
83 | "bm": "Bambara",
84 | "ba": "Bashkir",
85 | "eu": "Basque",
86 | "be": "Belarusian",
87 | "bn": "Bengali",
88 | "bh": "Bihari",
89 | "bi": "Bislama",
90 | "bs": "Bosnian",
91 | "br": "Breton",
92 | "bg": "Bulgarian",
93 | "my": "Burmese",
94 | "ca": "Catalan",
95 | "ch": "Chamorro",
96 | "ce": "Chechen",
97 | "ny": "Chichewa",
98 | "zh": "Chinese",
99 | "cv": "Chuvash",
100 | "kw": "Cornish",
101 | "co": "Corsican",
102 | "cr": "Cree",
103 | "hr": "Croatian",
104 | "cs": "Czech",
105 | "da": "Danish",
106 | "dv": "Divehi",
107 | "nl": "Dutch",
108 | "dz": "Dzongkha",
109 | "en": "English",
110 | "eo": "Esperanto",
111 | "et": "Estonian",
112 | "ee": "Ewe",
113 | "fo": "Faroese",
114 | "fj": "Fijian",
115 | "fi": "Finnish",
116 | "fr": "Franch",
117 | "ff": "Fula",
118 | "gl": "Galician",
119 | "ka": "Georgian",
120 | "de": "German",
121 | "el": "Greek",
122 | "gn": "Guaraní",
123 | "gu": "Gujarati",
124 | "ht": "Haitian",
125 | "ha": "Hausa",
126 | "he": "Hebrew",
127 | "hz": "Herero",
128 | "hi": "Hindi",
129 | "ho": "Hiri Motu",
130 | "hu": "Hungarian",
131 | "ia": "Interlingua",
132 | "id": "Indonesian",
133 | "ie": "Interlingue",
134 | "ga": "Irish",
135 | "ig": "Igbo",
136 | "ik": "Inupiaq",
137 | "io": "Ido",
138 | "is": "Icelandic",
139 | "it": "Italian",
140 | "iu": "Inuktitut",
141 | "ja": "Japanese",
142 | "jv": "Javanese",
143 | "kl": "Kalaallisut",
144 | "kn": "Kannada",
145 | "kr": "Kanuri",
146 | "ks": "Kashmiri",
147 | "kk": "Kazakh",
148 | "km": "Khmer",
149 | "ki": "Kikuyu",
150 | "rw": "Kinyarwanda",
151 | "ky": "Kirghiz",
152 | "kv": "Komi",
153 | "kg": "Kongo",
154 | "ko": "Korean",
155 | "ku": "Kurdish",
156 | "kj": "Kwanyama",
157 | "la": "Latin",
158 | "lb": "Luxembourgish",
159 | "lg": "Luganda",
160 | "li": "Limburgish",
161 | "ln": "Lingala",
162 | "lo": "Lao",
163 | "lt": "Lithuanian",
164 | "lu": "Luba-Katanga",
165 | "lv": "Latvian",
166 | "gv": "Manx",
167 | "mk": "Macedonian",
168 | "mg": "Malagasy",
169 | "ms": "Malay",
170 | "ml": "Malayalam",
171 | "mt": "Maltese",
172 | "mi": "Māori",
173 | "mr": "Marathi",
174 | "mh": "Marshallese",
175 | "mn": "Mongolian",
176 | "na": "Nauru",
177 | "nv": "Navajo",
178 | "nb": "Norwegian Bokmål",
179 | "nd": "North Ndebele",
180 | "ne": "Nepali",
181 | "ng": "Ndonga",
182 | "nn": "Norwegian Nynorsk",
183 | "no": "Norwegian",
184 | "ii": "Nuosu",
185 | "nr": "South Ndebele",
186 | "oc": "Occitan",
187 | "oj": "Ojibwe",
188 | "cu": "Old Church Slavonic",
189 | "om": "Oromo",
190 | "or": "Oriya",
191 | "os": "Ossetian",
192 | "pa": "Panjabi",
193 | "pi": "Pāli",
194 | "fa": "Persian",
195 | "pl": "Polish",
196 | "ps": "Pashto",
197 | "pt": "Portuguese",
198 | "qu": "Quechua",
199 | "rm": "Romansh",
200 | "rn": "Kirundi",
201 | "ro": "Romanian",
202 | "ru": "Russian",
203 | "sa": "Sanskrit",
204 | "sc": "Sardinian",
205 | "sd": "Sindhi",
206 | "se": "Northern Sami",
207 | "sm": "Samoan",
208 | "sg": "Sango",
209 | "sr": "Serbian",
210 | "gd": "Scottish Gaelic",
211 | "sn": "Shona",
212 | "si": "Sinhala",
213 | "sk": "Slovak",
214 | "sl": "Slovene",
215 | "so": "Somali",
216 | "st": "Southern Sotho",
217 | "es": "Spanish",
218 | "su": "Sundanese",
219 | "sw": "Swahili",
220 | "ss": "Swati",
221 | "sv": "Swedish",
222 | "ta": "Tamil",
223 | "te": "Telugu",
224 | "tg": "Tajik",
225 | "th": "Thai",
226 | "ti": "Tigrinya",
227 | "bo": "Tibetan",
228 | "tk": "Turkmen",
229 | "tl": "Tagalog",
230 | "tn": "Tswana",
231 | "to": "Tonga",
232 | "tr": "Turkish",
233 | "ts": "Tsonga",
234 | "tt": "Tatar",
235 | "tw": "Twi",
236 | "ty": "Tahitian",
237 | "ug": "Uighur",
238 | "uk": "Ukrainian",
239 | "ur": "Urdu",
240 | "uz": "Uzbek",
241 | "ve": "Venda",
242 | "vi": "Vietnamese",
243 | "vo": "Volapük",
244 | "wa": "Walloon",
245 | "cy": "Welsh",
246 | "wo": "Wolof",
247 | "fy": "Western Frisian",
248 | "xh": "Xhosa",
249 | "yi": "Yiddish",
250 | "yo": "Yoruba",
251 | "za": "Zhuang",
252 | "zu": "Zulu",
253 | }
254 |
255 | return lang_map[lan]
256 |
257 | def read_csv(path_csv):
258 | csv_reader = pd.read_csv(path_csv, sep="\t", header=0)
259 | return csv_reader["times"].tolist()
260 |
261 |
262 | def retrieve_samples(path, dataset):
263 | sacrebleu = datasets.load_metric("sacrebleu")
264 | decoders = ["autoregressive", "jacobi", "gs_jacobi", "aw_jacobi"]
265 |
266 | decoder2file = {
267 | "autoregressive": {"time": [], "trans": []},
268 | "jacobi": {"time": [], "trans": []},
269 | "gs_jacobi": {"time": [], "trans": []},
270 | "aw_jacobi": {"time": [], "trans": []},
271 | }
272 |
273 | for folder in decoders:
274 | subf = os.path.join(path, folder)
275 |
276 | for root, dirs, files in os.walk(subf):
277 | for filename in files:
278 | if "time" in filename:
279 | times = read_csv(os.path.join(root, filename))
280 | decoder2file[folder]['time'].extend(times)
281 | if 'trans' in filename:
282 | trans = read_csv(os.path.join(root, filename))
283 | decoder2file[folder]['trans'].extend(trans)
284 |
285 | diff_times = np.array([auto-aw for auto, aw in zip(decoder2file['autoregressive']['time'], decoder2file['aw_jacobi']['time'])])
286 | indices = diff_times.argsort()[::-1]
287 |
288 | for i in indices:
289 | target = dataset[int(i)][1]
290 | source = dataset[int(i)][0]
291 | print("gold", "nd", source)
292 | print("gold", "nd", target)
293 | for d in decoders:
294 | prediction = decoder2file[d]['trans'][i]
295 | results = sacrebleu.compute(predictions=[prediction], references=[[target]])
296 | print(d, decoder2file[d]['time'][i], round(results['score'], 3), prediction)
297 |
298 | input()
299 | continue
300 |
301 | def clean_text(text):
302 | return re.sub("[!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~£−€¿]+", " ", text)
303 |
304 |
305 | def read_wrong_beam_translations(path):
306 | csv_reader = pd.read_csv(path, sep="\t", header=0)
307 | idx = csv_reader['i'].tolist()
308 | bleus = csv_reader['bleu'].tolist()
309 | beams = csv_reader['beam'].tolist()
310 | autos = csv_reader['auto'].tolist()
311 | tgts = csv_reader['tgt'].tolist()
312 |
313 | for x in zip(idx, bleus, beams, autos, tgts):
314 | print(f"{x[0]} {x[1]}\nBeam: {x[2]}\nAuto: {x[3]}\nTgt: {x[4]}")
315 | input()
316 | continue
317 |
318 |
319 | def write_sentences(path, data):
320 |
321 | with open(path, 'w') as out_file:
322 | output = csv.writer(out_file, delimiter="\n")
323 | output.writerow(data)
324 |
--------------------------------------------------------------------------------
/src/utils/beam_search.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 |
4 | import torch
5 | import time
6 | from tabulate import tabulate
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 | from transformers import MBartForConditionalGeneration, M2M100ForConditionalGeneration
10 | import numpy as np
11 |
12 | from src.utils.bench_scorer import Scorer
13 | from src.utils.utils import makedirs, retrieve_model_name, write_sentences, get_logits_preprocessor
14 |
15 |
16 | class BeamSearcher(object):
17 | def __init__(
18 | self,
19 | dataset,
20 | model,
21 | initializer,
22 | decoder,
23 | num_beams,
24 | no_repeat_ngram_size,
25 | early_stopping,
26 | batch_size,
27 | device,
28 | result_dir,
29 | ):
30 | self.num_beams = num_beams
31 | self.no_repeat_ngram_size = no_repeat_ngram_size
32 | self.early_stopping = early_stopping
33 | self.device = device
34 | self.result_dir = result_dir
35 |
36 | self.dataset = dataset
37 | self.initializer = initializer
38 | self.decoder = decoder
39 |
40 | self.tokenizer = dataset.tokenizer
41 | self.model = model
42 | model.eval().to(self.device)
43 |
44 | self.model_name = self.model.name_or_path
45 | print("model name in beam_search", self.model_name)
46 |
47 | self.dataloader = DataLoader(
48 | dataset, collate_fn=dataset.collate_fn, batch_size=1, shuffle=False
49 | )
50 |
51 | if isinstance(
52 | self.model, MBartForConditionalGeneration
53 | ) or isinstance(
54 | self.model, M2M100ForConditionalGeneration
55 | ):
56 | self.is_mbart = True
57 | else:
58 | self.is_mbart = False
59 |
60 | self.exp_dir = self._retrieve_exp_dir()
61 |
62 | def _synchronize(self):
63 | if self.device == "cuda":
64 | torch.cuda.synchronize()
65 |
66 | def _retrieve_exp_dir(self):
67 | file_name = self._retrieve_file_name()
68 |
69 | exp_dir = os.path.join(self.result_dir, 'beam_search', file_name)
70 | makedirs(exp_dir)
71 |
72 | return exp_dir
73 |
74 | def _retrieve_file_name(self):
75 | model_name = retrieve_model_name(self.model_name.split("/")[1])
76 | lang = f"{self.dataset.src_lan}_{self.dataset.tgt_lan}"
77 | return f"{model_name}/{self.dataset.name}/{lang}"
78 |
79 | def _beam_search(self, input_ids, attention_mask):
80 | if self.is_mbart:
81 | with self.tokenizer.as_target_tokenizer():
82 | try:
83 | lang_id = self.tokenizer.cur_lang_code_id
84 |                 except AttributeError:
85 | lang_id = self.tokenizer.cur_lang_id
86 | beam_output = self.model.generate(
87 | **{"input_ids": input_ids, "attention_mask": attention_mask},
88 | num_beams=self.num_beams,
89 | early_stopping=self.early_stopping,
90 | # no_repeat_ngram_size=self.no_repeat_ngram_size,
91 | forced_bos_token_id=lang_id,
92 | # do_sample=False,
93 | # use_cache=True,
94 | )
95 | else:
96 | beam_output = self.model.generate(
97 | **{"input_ids": input_ids, "attention_mask": attention_mask},
98 | num_beams=self.num_beams,
99 | early_stopping=self.early_stopping,
100 | # no_repeat_ngram_size=self.no_repeat_ngram_size,
101 | # do_sample=False,
102 | # use_cache=True,
103 | )
104 |
105 | return beam_output
106 |
107 |
108 | def _bench_time(self, input_ids, attention_mask):
109 | sample_time = []
110 | for _ in range(1):
111 | start = time.perf_counter()
112 | self._synchronize()
113 | beam_output = self._beam_search(input_ids, attention_mask)
114 | self._synchronize()
115 | end = time.perf_counter()
116 | sample_time.append(end - start)
117 |
118 | sample_mean = np.average(sample_time)
119 | sample_variance = np.var(sample_time)
120 |
121 | return sample_mean, sample_variance, beam_output
122 |
123 | def _auto_time(self, input_ids, attention_mask, logits_preprocessor=None):
124 |
125 | if self.initializer is not None:
126 | init_tensor, _ = self.initializer.init_translation()
127 | else:
128 | init_tensor = None
129 |
130 | sample_time = []
131 | for _ in range(1):
132 | init_new = init_tensor.clone()
133 | start = time.perf_counter()
134 | self._synchronize()
135 | auto_output, _ = self.decoder.autoregressive(
136 | input_ids, attention_mask, init_tensor=init_new, logits_preprocessor=logits_preprocessor
137 | )
138 | self._synchronize()
139 | end = time.perf_counter()
140 | sample_time.append(end - start)
141 |
142 | sample_mean = np.average(sample_time)
143 | sample_variance = np.var(sample_time)
144 |
145 | return sample_mean, sample_variance, auto_output
146 |
147 | def compute_beam_search(self, cfg):
148 |
149 | beam_scorer, auto_scorer = Scorer(), Scorer()
150 | worst_beam_translations = []
151 |
152 | pbar = tqdm(self.dataloader, desc="Computing Beam Search...")
153 | for x in pbar:
154 | input_ids = x["source"]["input_ids"].to(self.device)
155 | attention_mask = x["source"]["attention_mask"].to(self.device)
156 |
157 | tgt_text = x['target']['sentences']
158 |
159 | if cfg.model.use_logits_preprocessor:
160 |                 logits_preprocessor = get_logits_preprocessor(self.decoder.model, input_ids, self.decoder.eos_token_id)
161 | else:
162 | logits_preprocessor = None
163 |
164 | mean_beam, var_beam, beam_output = self._bench_time(input_ids, attention_mask)
165 | mean_auto, var_auto, auto_output = self._auto_time(input_ids, attention_mask, logits_preprocessor)
166 |
167 | translation_beam = self.tokenizer.batch_decode(
168 | beam_output, skip_special_tokens=True
169 | )
170 | translation_auto = self.tokenizer.batch_decode(
171 | auto_output, skip_special_tokens=True
172 | )
173 |
174 | beam_scorer.update_metrics(mean_beam, var_beam, 0, translation_beam[0], tgt_text[0])
175 | auto_scorer.update_metrics(mean_auto, var_auto, 0, translation_auto[0], tgt_text[0])
176 |
177 | worst_beam_translations.extend(
178 | self._compute_tmp_bleu(
179 | translation_beam[0],
180 | translation_auto[0],
181 | tgt_text[0],
182 | beam_scorer.i
183 | )
184 | )
185 |
186 | self.write_report(beam_scorer, auto_scorer, worst_beam_translations)
187 |
188 | def write_report(self, beam_scorer, auto_scorer, worst_beam_translations):
189 | print("Writing report...")
190 |
191 | beam_score = beam_scorer.compute_bleu_score()
192 | auto_score = auto_scorer.compute_bleu_score()
193 |
194 | worst_beam_translations.sort(key=lambda x: x[1])
195 |
196 | # Table for the test info
197 | info_table = tabulate([
198 | ['Model', self.model_name],
199 | ['Dataset', self.dataset.name],
200 | ['Languages', f"{self.dataset.src_lan}-{self.dataset.tgt_lan}"],
201 | ['Device', self.device],
202 | ['Sentences', beam_scorer.i],
203 | ['Num Beams', self.num_beams],
204 | ['No Rep Ngram Size', self.no_repeat_ngram_size],
205 | ['Early Stopping', self.early_stopping],
206 | ['Do Sample', False],
207 | ['Use Cache', True],
208 | ], headers=['Info', 'Value'], tablefmt='grid')
209 |
210 | # Table for the Time
211 | time_table = tabulate([
212 | ['Time', beam_scorer.tot_mean_time, auto_scorer.tot_mean_time, (auto_scorer.tot_mean_time / beam_scorer.tot_mean_time)],
213 | ['Iter', beam_scorer.tot_mean_iter, auto_scorer.tot_mean_iter, 1],
214 | ['Var', beam_scorer.tot_var_time, auto_scorer.tot_var_time, 1],
215 | ], headers=['Metrics', 'Beam', 'Auto', 'Speedup'], tablefmt='grid')
216 |
217 | # Table for the Bleu score
218 | bleu_table = tabulate([
219 | ['Score', beam_score.score, auto_score.score],
220 | ['Counts', beam_score.counts, auto_score.counts],
221 | ['Totals', beam_score.totals, auto_score.totals],
222 | ['Precisions', beam_score.precisions, auto_score.precisions],
223 | ['Bp', beam_score.bp, auto_score.bp],
224 | ['Sys_len', beam_score.sys_len, auto_score.sys_len],
225 | ['ref_len', beam_score.ref_len, auto_score.ref_len],
226 | ], headers=['Metrics', 'Beam Search', 'Auto'], tablefmt='rst')
227 |
228 | with open(os.path.join(self.exp_dir, "report.txt"), mode='w') as report:
229 | report.write(f"Test Info\n{info_table}\n\n")
230 | report.write(f"Time\n{time_table}\n\n")
231 | report.write(f"Bleu Score\n{bleu_table}\n\n")
232 |
233 | with open(os.path.join(self.exp_dir, "worst_beam_translation.csv"), 'w') as csvfile:
234 | writer = csv.writer(csvfile, delimiter='\t')
235 | writer.writerow(['i', 'bleu', 'beam', 'auto', 'tgt'])
236 | for sample in worst_beam_translations:
237 | writer.writerow(list(sample))
238 |
239 | write_sentences(os.path.join(self.exp_dir, "beam.txt"), beam_scorer.predictions)
240 | write_sentences(os.path.join(self.exp_dir, "auto.txt"), auto_scorer.predictions)
241 | write_sentences(os.path.join(self.exp_dir, "reference.txt"), sum(auto_scorer.references, []))
242 |
243 | def _compute_tmp_bleu(self, translation_beam, translation_auto, tgt_text, i):
244 |
245 | beam_tmp_scorer, auto_tmp_scorer = Scorer(), Scorer()
246 |
247 | beam_tmp_scorer.update_metrics(0, 0, 0, translation_beam, tgt_text)
248 | auto_tmp_scorer.update_metrics(0, 0, 0, translation_auto, tgt_text)
249 |
250 | beam_score = beam_tmp_scorer.compute_bleu_score()
251 | auto_score = auto_tmp_scorer.compute_bleu_score()
252 |
253 | if beam_score.score < auto_score.score:
254 | return [(i, beam_score.score, translation_beam, translation_auto, tgt_text)]
255 | else:
256 | return []
257 |
258 |
259 |
260 |
--------------------------------------------------------------------------------
/src/bench.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import logging
3 | import os
4 | import sys
5 | import time
6 |
7 | import torch
8 | from tabulate import tabulate
9 | from torch.utils.data import DataLoader, Dataset
10 | from tqdm import tqdm
11 |
12 | from src.utils import utils
13 | from src.utils.bench_scorer import Scorer
14 | from src.utils.utils import retrieve_model_name, check_zero_division
15 |
16 | logging.basicConfig(
17 | format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
18 | datefmt="%Y-%m-%d %H:%M:%S",
19 | level=os.environ.get("LOGLEVEL", "INFO").upper(),
20 | stream=sys.stdout,
21 | )
22 | logger = logging.getLogger("bench")
23 |
24 |
25 | class MTBenchmarker(object):
26 | def __init__(
27 | self,
28 | dataset: Dataset,
29 | decoders,
30 | src_lang,
31 | tgt_lang,
32 | compare_to: str = "autoregressive",
33 | result_dir: str = None,
34 | device: str = "cuda",
35 | debug: bool = False,
36 | ):
37 | self.dataset = dataset
38 | self.decoders = decoders
39 | self.device = device
40 | self.debug = debug
41 |
42 | self.compare_to = compare_to
43 |
44 | self.src_lang = src_lang
45 | self.tgt_lang = tgt_lang
46 | self.model_name = self.decoders[0].model_name.split("/")[1]
47 |
48 | self.dataloader = DataLoader(
49 | dataset, collate_fn=dataset.collate_fn, batch_size=1, shuffle=False
50 | )
51 |
52 | self.result_dir = result_dir
53 | self.exp_dir = self._retrieve_exp_dir()
54 |
55 | def _synchronize(self):
56 | if self.device == "cuda":
57 | torch.cuda.synchronize()
58 |
59 | def _retrieve_exp_dir(self):
60 | file_name = self._retrieve_file_name()
61 |
62 | exp_dir = os.path.join(self.result_dir, "benchmark", file_name)
63 | utils.makedirs(exp_dir)
64 |
65 | return exp_dir
66 |
67 | def _retrieve_file_name(self):
68 | model_name = retrieve_model_name(self.model_name)
69 | lang = f"{self.src_lang}_{self.tgt_lang}"
70 | return f"{model_name}/{self.device}/{self.dataset.name}/{lang}"
71 |
72 | @staticmethod
73 | def _write_on_file(path, item, i):
74 |
75 | utils.makedirs(path)
76 | with open(path, "a") as file:
77 | writer = csv.writer(file, delimiter="\t")
78 | # write header
79 | if os.stat(path).st_size == 0:
80 | writer.writerow(["i", "item"])
81 | writer.writerow([i, item])
82 |
83 | def _compute_info_table(self, best_alg, i, grid):
84 | # Table for the test info
85 | info_table = tabulate([
86 | [
87 | self.decoders[0].model_name,
88 | self.dataset.name,
89 | self.device,
90 | best_alg,
91 | i,
92 | f"{self.src_lang}-{self.tgt_lang}",
93 | ]
94 | ],
95 | headers=[
96 | "Model",
97 | "Dataset",
98 | "Device",
99 | "Best Algorithm",
100 | "Sentences",
101 | "Languages",
102 | ],
103 | tablefmt=grid,
104 | )
105 |
106 | return info_table
107 |
108 | def _compute_time_table(self, scorers, grid):
109 |
110 | if len(scorers) == 1:
111 | name, scorer = list(scorers.items())[0]
112 | times_table = tabulate([
113 | ["Time", scorer.tot_mean_time],
114 | ["Iteration", scorer.tot_mean_iter],
115 | ],
116 | headers=["Metrics", name], tablefmt=grid,
117 | )
118 | else:
119 | tests_header = ["Metrics"] + [name for name in scorers]
120 |
121 | comp_scorer = scorers.get(self.compare_to)
122 | time_speedup = [check_zero_division(comp_scorer.tot_mean_time, scorer.tot_mean_time) for scorer in scorers.values()]
123 | iter_speedup = [check_zero_division(comp_scorer.tot_mean_iter, scorer.tot_mean_iter) for scorer in scorers.values()]
124 |
125 | times_table = tabulate(
126 | [
127 | ["Speedup T"] + time_speedup,
128 | ["Speedup I"] + iter_speedup,
129 | ["Time"] + [scorer.tot_mean_time for scorer in scorers.values()],
130 | ["Iter"] + [scorer.tot_mean_iter for scorer in scorers.values()],
131 | ],
132 | headers=tests_header, tablefmt=grid,
133 | )
134 |
135 | return times_table
136 |
137 | def _compute_bleu_table(self, scorers, grid):
138 |
139 | bleu_scores = [scorer.compute_bleu_score() for scorer in scorers.values()]
140 |
141 | # Table for the Bleu score
142 | bleu_table = tabulate([
143 | ['Score'] + [score.score for score in bleu_scores],
144 | ['Counts'] + [score.counts for score in bleu_scores],
145 | ['Totals'] + [score.totals for score in bleu_scores],
146 | ['Precisions'] + [score.precisions for score in bleu_scores],
147 | ['Bp'] + [score.bp for score in bleu_scores],
148 | ['Sys_len'] + [score.sys_len for score in bleu_scores],
149 | ['ref_len'] + [score.ref_len for score in bleu_scores],
150 | ], headers=["Metrics"] + [name for name in scorers], tablefmt=grid)
151 |
152 | return bleu_table
153 |
154 | def write_report(self, i, scorers, best_algorithm):
155 | print("Writing report...")
156 |
157 | # Compute best algorithm
158 | best_alg = max(best_algorithm, key=best_algorithm.get)
159 | info_table_txt = self._compute_info_table(best_alg, i, grid="grid")
160 | info_table_tex = self._compute_info_table(best_alg, i, grid="latex")
161 |
162 | # Table for the benchmark times
163 | times_table_txt = self._compute_time_table(scorers, grid="rst")
164 | times_table_tex = self._compute_time_table(scorers, grid="latex")
165 |
166 | # Table for the bleu score
167 | bleu_table_txt = self._compute_bleu_table(scorers, grid="rst")
168 |
169 | print(self.exp_dir)
170 |
171 | with open(os.path.join(self.exp_dir, "report.txt"), mode="w") as report:
172 | report.write(f"Test Info\n{info_table_txt}\n\n")
173 | report.write(f"Benchmark\n{times_table_txt}\n\n")
174 | report.write(f"Bleu\n{bleu_table_txt}\n\n")
175 |
176 | with open(os.path.join(self.exp_dir, "latex.txt"), mode="w") as report:
177 | report.write(f"Test Info\n{info_table_tex}\n\n")
178 | report.write(f"Benchmark\n{times_table_tex}\n\n")
179 |
180 | def write_inline(self, i: int, scorers, best_alg):
181 |
182 | for name, scorer in scorers.items():
183 | # Write times
184 | path = os.path.join(self.exp_dir, name, f"{name}.tsv")
185 | self._write_on_file(path, scorer.current_time, i)
186 | # Write iterations
187 | path = os.path.join(self.exp_dir, name, f"iter_{name}.tsv")
188 | self._write_on_file(path, scorer.current_iter, i)
189 | # Write translations
190 | path = os.path.join(self.exp_dir, name, f"trans_{name}.tsv")
191 | self._write_on_file(path, scorer.current_transl, i)
192 | # Write initializations
193 | path = os.path.join(self.exp_dir, name, f"init_{name}.tsv")
194 | self._write_on_file(path, scorer.current_init, i)
195 |
196 | # Write mean
197 | path = os.path.join(self.exp_dir, "meanvar.tsv")
198 | utils.makedirs(path)
199 | with open(path, "a") as file:
200 | writer = csv.writer(file, delimiter="\t")
201 |
202 | # Write header
203 | if os.stat(path).st_size == 0:
204 | header = ["#sentence"] + [f"mean_{name}" for name in scorers] + ["best_alg"]
205 | writer.writerow(header)
206 |
207 | row = [i] + [scorer.current_time for scorer in scorers.values()] + [best_alg]
208 | writer.writerow(row)
209 |
210 | @staticmethod
211 | def _compute_best_algorithm(scorers):
212 |
213 | best = scorers.get(min(scorers, key=lambda x: scorers[x].current_time)).acronym
214 |
215 | return best
216 |
217 | def _compute_postfix(self, scorers, best_algorithms, curr_alg):
218 |
219 | best_alg = max(best_algorithms, key=best_algorithms.get)
220 |
221 | if len(scorers) == 1:
222 | _, scorer = list(scorers.items())[0]
223 | postfix = {
224 | scorer.acronym: (
225 | round(scorer.tot_mean_time, 3),
226 | round(scorer.tot_mean_iter, 3)
227 | ),
228 | "ca": curr_alg,
229 | "ba": best_alg,
230 | }
231 |
232 | else:
233 | comp_scorer = scorers.get(self.compare_to)
234 |
235 | postfix = {
236 | scorer.acronym: (
237 | check_zero_division(comp_scorer.tot_mean_time, scorer.tot_mean_time),
238 | check_zero_division(comp_scorer.tot_mean_iter, scorer.tot_mean_iter)
239 | ) for name, scorer in scorers.items() if name != self.compare_to
240 | }
241 |
242 | postfix.update({"ca": curr_alg, "ba": best_alg})
243 |
244 | return postfix
245 |
246 | def compute_total_time(self):
247 |
248 | i = 0
249 | scorers = {decoder.name: Scorer(decoder.name, decoder.acronym) for decoder in self.decoders}
250 | best_algorithms = {decoder.acronym: 0 for decoder in self.decoders}
251 |
252 | pbar = tqdm(self.dataloader, desc="Computing Benchmark...")
253 | for x in pbar:
254 |
255 | try:
256 | input_ids = x["source"]["input_ids"].to(self.device)
257 | attention_mask = x["source"]["attention_mask"].to(self.device)
258 |
259 | tgt_text = x['target']['sentences']
260 |
261 | for decoder, name in zip(self.decoders, scorers):
262 | kwargs = decoder.compute_decode_kwargs(input_ids, attention_mask)
263 |
264 |                     self._synchronize()  # flush pending GPU work before starting the timer
265 |                     start = time.perf_counter()
266 | trans, iter = decoder.decode(input_ids, attention_mask, **kwargs)
267 | self._synchronize()
268 | end = time.perf_counter()
269 | sample_time = end - start
270 |
271 | init_tensor = kwargs["init_tensor"] if "init_tensor" in kwargs else ""
272 |
273 | trans = self.dataset.tokenizer.batch_decode(
274 | trans, skip_special_tokens=True
275 | )[0]
276 |
277 | scorers[name].update_metrics(sample_time, iter, trans, tgt_text[0], init_tensor)
278 |
279 | best_alg = self._compute_best_algorithm(scorers)
280 | best_algorithms[best_alg] += 1
281 |
282 | # Update tqdm bar
283 | postfix_pbar = self._compute_postfix(scorers, best_algorithms, best_alg)
284 | pbar.set_postfix(postfix_pbar)
285 |
286 | self.write_inline(i, scorers, best_alg)
287 | except Exception as e:
288 | if self.debug:
289 | raise e
290 | else:
291 | logger.error(e)
292 | return
293 |
294 | i += 1
295 |
296 | self.write_report(i, scorers, best_algorithms)
297 |
--------------------------------------------------------------------------------
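MTBenchmarker.compute_total_time above times each decoder per sentence with time.perf_counter, bracketing the decode call with torch.cuda.synchronize so that asynchronous CUDA kernels are fully flushed before each timestamp is read. A minimal standalone sketch of that timing pattern (decoder is a placeholder for any object exposing a decode method, like the decoder objects passed to MTBenchmarker; the function name is illustrative):

    import time
    import torch

    def time_decode(decoder, input_ids, attention_mask, device="cuda", **kwargs):
        # Flush any queued GPU work before starting the clock, and again before
        # reading the end time, so only decode() itself is measured.
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        output, n_iterations = decoder.decode(input_ids, attention_mask, **kwargs)
        if device == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        return output, n_iterations, elapsed

On CPU the synchronize calls are skipped, matching the _synchronize helper; the per-sentence times collected this way are what Scorer aggregates into tot_mean_time and the speedup rows of the report.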
/exp/flops_calculator.py:
--------------------------------------------------------------------------------
1 | """Computes the flops needed for training/running transformer networks. Adapted from https://github.com/google-research/electra/blob/master/flops_computation.py """
2 |
3 | import collections
4 |
5 | # We checked this code with TensorFlow's FLOPs counting, although we had to
6 | # correct for this issue: https://github.com/tensorflow/tensorflow/issues/22071
7 | # Assumptions going into the FLOPs counting:
8 | # - An "operation" is a mathematical operation, not a machine instruction. So
9 | #   an "exp" takes one op, like an add, even though in practice an exp
10 | #   might be slower. This is not too bad an assumption because
11 | #   matrix-multiplies dominate the compute for most models, so minor details
12 | #   about activation functions don't matter too much. Similarly, we count
13 | #   matrix-multiplies as 2*m*n FLOPs instead of m*n, as one might
14 | #   if considering fused multiply-add ops.
15 | # - Backward pass takes the same number of FLOPs as forward pass. Not exactly
16 | #   right (e.g., for softmax cross entropy loss the backward pass is faster).
17 | # Importantly, it really is the same for matrix-multiplies, which is most of
18 | # the compute anyway.
19 | # - We assume "dense" embedding lookups (i.e., multiplication by a one-hot
20 | # vector). On some hardware accelerators, these dense operations are
21 | # actually faster than sparse lookups.
22 | # Please open a github issue if you spot a problem with this code!
23 |
24 | # I am not sure if the below constants are 100% right, but they are only applied
25 | # to O(hidden_size) activations, which is generally a lot less compute than the
26 | # matrix-multiplies, which are O(hidden_size^2), so they don't affect the total
27 | # number of FLOPs much.
28 |
29 | # random number, >=, multiply activations by dropout mask, multiply activations
30 | # by correction (1 / (1 - dropout_rate))
31 | DROPOUT_FLOPS = 4
32 |
33 | # compute mean activation (sum), compute variance of activation
34 | # (square and sum), bias (add), scale (multiply)
35 | LAYER_NORM_FLOPS = 5
36 |
37 | # GELU: 0.5 * x * (1 + tanh(sqrt(2 / np.pi) * (x + 0.044715 * pow(x, 3))))
38 | ACTIVATION_FLOPS = 8
39 |
40 | # max/subtract (for stability), exp, sum, divide
41 | SOFTMAX_FLOPS = 5
42 |
43 |
44 | class TransformerHparams(object):
45 | """Computes the train/inference FLOPs for transformers."""
46 |
47 | def __init__(self, h, l, s=512, v=30522, e=None, i=None, heads=None,
48 | head_size=None, output_frac=0.15625, sparse_embed_lookup=False,
49 | decoder=False):
50 | self.h = h # hidden size
51 | self.l = l # number of layers
52 | self.s = s # sequence length
53 | self.v = v # vocab size
54 | self.e = h if e is None else e # embedding size
55 | self.i = h * 4 if i is None else i # intermediate size
56 | self.kqv = h if head_size is None else head_size * heads # attn proj sizes
57 | self.heads = max(h // 64, 1) if heads is None else heads # attention heads
58 | self.output_frac = output_frac # percent of tokens using an output softmax
59 | self.sparse_embed_lookup = sparse_embed_lookup # sparse embedding lookups
60 | self.decoder = decoder # decoder has extra attn to encoder states
61 |
62 | def get_block_flops(self):
63 | """Get the forward-pass FLOPs for a single transformer block."""
64 | attn_mul = 2 if self.decoder else 1
65 | block_flops = dict(
66 | kqv=3 * 2 * self.h * self.kqv * attn_mul,
67 | kqv_bias=3 * self.kqv * attn_mul,
68 | attention_scores=2 * self.kqv * self.s * attn_mul,
69 | attn_softmax=SOFTMAX_FLOPS * self.s * self.heads * attn_mul,
70 | attention_dropout=DROPOUT_FLOPS * self.s * self.heads * attn_mul,
71 | attention_scale=self.s * self.heads * attn_mul,
72 | attention_weighted_avg_values=2 * self.h * self.s * attn_mul,
73 | attn_output=2 * self.h * self.h * attn_mul,
74 | attn_output_bias=self.h * attn_mul,
75 | attn_output_dropout=DROPOUT_FLOPS * self.h * attn_mul,
76 | attn_output_residual=self.h * attn_mul,
77 | attn_output_layer_norm=LAYER_NORM_FLOPS * attn_mul,
78 | intermediate=2 * self.h * self.i,
79 | intermediate_act=ACTIVATION_FLOPS * self.i,
80 | intermediate_bias=self.i,
81 | output=2 * self.h * self.i,
82 | output_bias=self.h,
83 | output_dropout=DROPOUT_FLOPS * self.h,
84 | output_residual=self.h,
85 | output_layer_norm=LAYER_NORM_FLOPS * self.h,
86 | )
87 | return sum(block_flops.values()) * self.s
88 |
89 | def get_embedding_flops(self, output=False):
90 |     """Get the forward-pass FLOPs for the transformer inputs or output softmax."""
91 | embedding_flops = {}
92 | if output or (not self.sparse_embed_lookup):
93 | embedding_flops["main_multiply"] = 2 * self.e * self.v
94 | # input embedding post-processing
95 | if not output:
96 | embedding_flops.update(dict(
97 | tok_type_and_position=2 * self.e * (self.s + 2),
98 | add_tok_type_and_position=2 * self.e,
99 | emb_layer_norm=LAYER_NORM_FLOPS * self.e,
100 | emb_dropout=DROPOUT_FLOPS * self.e
101 | ))
102 | # projection layer if e != h
103 | if self.e != self.h or output:
104 | embedding_flops.update(dict(
105 | hidden_kernel=2 * self.h * self.e,
106 | hidden_bias=self.e if output else self.h
107 | ))
108 | # extra hidden layer and output softmax
109 | if output:
110 | embedding_flops.update(dict(
111 | hidden_activation=ACTIVATION_FLOPS * self.e,
112 | hidden_layernorm=LAYER_NORM_FLOPS * self.e,
113 | output_softmax=SOFTMAX_FLOPS * self.v,
114 | output_target_word=2 * self.v
115 | ))
116 | return self.output_frac * sum(embedding_flops.values()) * self.s
117 | return sum(embedding_flops.values()) * self.s
118 |
119 | def get_binary_classification_flops(self):
120 | classification_flops = dict(
121 | hidden=2 * self.h * self.h,
122 | hidden_bias=self.h,
123 | hidden_act=ACTIVATION_FLOPS * self.h,
124 | logits=2 * self.h
125 | )
126 | return sum(classification_flops.values()) * self.s
127 |
128 | def get_train_flops(self, batch_size, train_steps, discriminator=False, use_backprop=True):
129 | """Get the FLOPs for pre-training the transformer."""
130 | # 2* for forward/backward pass
131 | if use_backprop:
132 | mult = 2
133 | else:
134 | mult = 1
135 | return mult * batch_size * train_steps * (
136 | (self.l * self.get_block_flops()) +
137 | self.get_embedding_flops(output=False) +
138 | (self.get_binary_classification_flops() if discriminator else
139 | self.get_embedding_flops(output=True))
140 | )
141 |
142 | def get_infer_flops(self):
143 | """Get the FLOPs for running inference with the transformer on a
144 | classification task."""
145 | return ((self.l * self.get_block_flops()) +
146 | self.get_embedding_flops(output=False) +
147 | self.get_binary_classification_flops())
148 |
149 |
150 | def get_electra_train_flops(
151 | h_d, l_d, h_g, l_g, batch_size, train_steps, tied_embeddings,
152 | e=None, s=512, output_frac=0.15625):
153 | """Get the FLOPs needed for pre-training ELECTRA."""
154 | if e is None:
155 | e = h_d
156 | disc = TransformerHparams(
157 | h_d, l_d, s=s, e=e,
158 | output_frac=output_frac).get_train_flops(batch_size, train_steps, True)
159 | gen = TransformerHparams(
160 | h_g, l_g, s=s, e=e if tied_embeddings else None,
161 | output_frac=output_frac).get_train_flops(batch_size, train_steps)
162 | return disc + gen
163 |
164 |
165 |
166 | def calculate_LASSFlops(prior_flop):
167 |     vq_vaeflops = 0  # placeholder; the LASS-specific FLOPs computation is not implemented
168 |
169 | MODEL_FLOPS = collections.OrderedDict([
170 | # These runtimes were computed with tensorflow FLOPs counting instead of the
171 | # script, as the neural architectures are quite different.
172 | # 768648884 words in LM1b benchmark, 10 epochs with batch size 20,
173 | # seq length 128, 568093262680 FLOPs per example.
174 | ("elmo", 2 * 10 * 768648884 * 568093262680 / (20.0 * 128)),
175 | # 15064773691518 is FLOPs for forward pass on 32 examples.
176 | # Therefore 2 * steps * batch_size * 15064773691518 / 32 is XLNet compute
177 | ("xlnet", 2 * 500000 * 8192 * 15064773691518 / 32.0),
178 |
179 | # Runtimes computed with the script
180 | ("gpt", TransformerHparams(768, 12, v=40000, output_frac=1.0).get_train_flops(
181 | 128, 960800)),
182 |
183 | ("jukebox", TransformerHparams(1024, 48, s=8192, v=2048, output_frac=1.0, heads=1).get_infer_flops()),
184 |
185 |
186 | ("bert_small", TransformerHparams(256, 12, e=128, s=128).get_train_flops(128, 1.45e6)),
187 | ("bert_base", TransformerHparams(768, 12).get_train_flops(256, 1e6)),
188 | ("bert_large", TransformerHparams(1024, 24).get_train_flops(256, 1e6)),
189 | ("electra_small", get_electra_train_flops(256, 12, 64, 12, 128, 1e6, True, s=128, e=128)),
190 | ("electra_base", get_electra_train_flops(768, 12, 256, 12, 256, 766000, True)),
191 | ("electra_400k", get_electra_train_flops(1024, 24, 256, 24, 2048, 400000, True)),
192 | ("electra_1.75M", get_electra_train_flops(1024, 24, 256, 24, 2048, 1750000, True)),
193 |
194 |     # RoBERTa, ALBERT, and T5 have minor architectural differences from
195 |     # BERT/ELECTRA, but I believe they don't significantly affect the runtime,
196 |     # so we use this script for those models as well.
197 | ("roberta", TransformerHparams(1024, 24, v=50265).get_train_flops(8000, 500000)),
198 | ("albert", TransformerHparams(4096, 12, v=30000, e=128).get_train_flops(
199 | 4096, 1.5e6)),
200 | ("t5_11b", TransformerHparams(
201 | 1024, # hidden size
202 | 24, # layers
203 | v=32000, # vocab size
204 | i=65536, # ff intermediate hidden size
205 | heads=128, head_size=128, # heads/head size
206 | output_frac=0.0 # encoder has no output softmax
207 | ).get_train_flops(2048, 1e6) + # 1M steps with batch size 2048
208 | TransformerHparams(
209 | 1024,
210 | 24,
211 | v=32000,
212 | i=65536,
213 | heads=128, head_size=128,
214 | output_frac=1.0, # decoder has output softmax for all positions
215 | decoder=True
216 | ).get_train_flops(2048, 1e6)),
217 | ("Shallow", TransformerHparams(512, 13, i=2048).get_train_flops(250, 300000)),
218 | ("Shallow + KD", TransformerHparams(512, 13, i=2048).get_train_flops(250, 300000) + TransformerHparams(512, 13, i=2048,
219 | heads=16).get_train_flops(
220 | 250, 300000, use_backprop=False)),
221 | ("DSLP", TransformerHparams(512, 12, i=2048).get_train_flops(250, 300000) + TransformerHparams(512, 12, i=2048,
222 | heads=16).get_train_flops(
223 | 250, 300000, use_backprop=False)),
224 | ("F-VAE", TransformerHparams(512, 12, i=2048).get_train_flops(250, 300000) + TransformerHparams(512, 12, i=2048,
225 | heads=16).get_train_flops(
226 | 250, 300000, use_backprop=False)),
227 | ("DisCo", TransformerHparams(512, 12, i=2048).get_train_flops(250, 300000) + TransformerHparams(1024, 24, i=4096, heads=16).get_train_flops(250, 300000, use_backprop=False)),
228 | ("SUNDAE", TransformerHparams(512, 12, i=2048).get_train_flops(4096, 10e6)),
229 |
230 | ])
231 |
232 |
233 | def main():
234 | for k, v in MODEL_FLOPS.items():
235 | print(k, v)
236 |
237 |
238 | if __name__ == "__main__":
239 | main()
240 |
--------------------------------------------------------------------------------
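TransformerHparams above counts roughly 2*m*n FLOPs per m-by-n matrix multiply and sums per-block costs over the sequence length, so it can be used directly to compare configurations. A small usage sketch (the hidden size, layer count, and sequence length below are illustrative and are not the models benchmarked in this repository; it assumes the script's directory is on the import path):

    # Run alongside exp/flops_calculator.py so the import resolves.
    from flops_calculator import TransformerHparams

    # Same-sized encoder and decoder blocks: the decoder's extra attention over
    # encoder states roughly doubles the attention-related FLOPs (attn_mul = 2).
    enc = TransformerHparams(512, 6, s=128, v=32000, i=2048)
    dec = TransformerHparams(512, 6, s=128, v=32000, i=2048, decoder=True)

    print(f"encoder block FLOPs: {enc.get_block_flops():.3e}")
    print(f"decoder block FLOPs: {dec.get_block_flops():.3e}")

    # Forward-pass estimate for a classification-style inference over 6 layers.
    print(f"inference FLOPs: {enc.get_infer_flops():.3e}")

For training budgets, get_train_flops multiplies the per-example cost by batch size and step count (doubled when backprop is included), which is how the MODEL_FLOPS entries above are assembled.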
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023-2024 Andrea Santilli, Silvio Severino
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------