├── src ├── dataset │ ├── __init__.py │ ├── ittb_dataset.py │ ├── flores_dataset.py │ ├── wmt_dataset.py │ └── iwslt_dataset.py ├── ipi │ ├── __init__.py │ ├── decoders │ │ ├── __init__.py │ │ ├── beam_search.py │ │ ├── autoregressive.py │ │ ├── mt_decoding.py │ │ ├── gs_jacobi.py │ │ ├── jacobi.py │ │ └── hybrid_jacobi.py │ ├── stopping_condition.py │ └── initializer.py ├── utils │ ├── __init__.py │ ├── bench_scorer.py │ ├── bleu_calculator.py │ ├── utils.py │ └── beam_search.py ├── viz │ ├── __init__.py │ ├── dependecy_graph.py │ └── visualize.py ├── __init__.py └── bench.py ├── assets ├── ddg.png └── ipi.png ├── requirements.txt ├── .gitignore ├── conf └── config.yaml ├── exp ├── tab1.sh ├── tab2.sh └── flops_calculator.py ├── README.MD ├── main.py └── LICENSE /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/ipi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/viz/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/ipi/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/ddg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/teelinsan/parallel-decoding/HEAD/assets/ddg.png -------------------------------------------------------------------------------- /assets/ipi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/teelinsan/parallel-decoding/HEAD/assets/ipi.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch~=1.13.1 2 | transformers==4.19.2 3 | datasets==2.1.0 4 | matplotlib>=3.5.2 5 | pytorch-lightning==1.6.3 6 | nltk>=3.7 7 | sacrebleu==2.1.0 8 | tqdm>=4.64.0 9 | python-dotenv==0.20.0 10 | gitpython==3.1.27 11 | more-itertools==8.12.0 12 | hydra-core==1.2.0 13 | plotly>=5.8.2 14 | #sentencepiece==0.1.96 15 | sacremoses==0.0.49 16 | sentence-transformers 17 | kaleido==0.2.1 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | fertility/ 2 | /benchmark/.ipynb_checkpoints/ 3 | /venv/ 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | bin/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | benchmark/ipi/__pycache__/ 27 | 28 | # Installer logs 29 | pip-log.txt 30 | pip-delete-this-directory.txt 31 | 32 | # Unit test / coverage reports 33 | .tox/ 34 | .coverage 35 | .cache 36 | nosetests.xml 37 | coverage.xml 38 | 39 | # Translations 40 | *.mo 41 | 
42 | # Mr Developer 43 | .mr.developer.cfg 44 | .project 45 | .pydevproject 46 | 47 | # Rope 48 | .ropeproject 49 | 50 | # Django stuff: 51 | *.log 52 | *.pot 53 | 54 | # Sphinx documentation 55 | docs/_build/ 56 | .installed.cfg 57 | bin 58 | develop-eggs 59 | dist 60 | downloads 61 | eggs 62 | parts 63 | src/*.egg-info 64 | lib 65 | lib64 66 | 67 | *.pyc 68 | *.pyo 69 | 70 | -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | device: "cpu" 2 | src_lang: "en" 3 | tgt_lang: "de" 4 | task: "benchmark" 5 | 6 | model: 7 | src_lang: ${src_lang} 8 | tgt_lang: ${tgt_lang} 9 | model_name: "Helsinki-NLP/opus-mt-${src_lang}-${tgt_lang}" 10 | device: ${device} 11 | use_logits_preprocessor: True 12 | 13 | dataset: 14 | data_dir: "" 15 | src_lang: ${src_lang} 16 | tgt_lang: ${tgt_lang} 17 | name: "wmt" 18 | version: "14" 19 | split: 'test' 20 | subset: 21 | use_subset: False 22 | start: 10 23 | end: 10 24 | 25 | bench: 26 | result_dir: "" 27 | n_runs: 1 28 | device: ${device} 29 | 30 | beam_search: 31 | result_dir: "" 32 | batch_size: 1 33 | num_beams: 15 34 | device: ${device} 35 | 36 | initializer: 37 | use_initializer: True 38 | path_lexicon: null 39 | use_init: False 40 | src_lang: ${src_lang} 41 | tgt_lang: ${tgt_lang} 42 | device: ${device} 43 | 44 | decoder: 45 | gs_jaco_blocks: 3 46 | use_cache: True 47 | # Specify the decoders to use 48 | decoders: ['autoregressive', 'jacobi', 'gs_jacobi', 'h_jacobi', 'beam_search'] 49 | device: ${device} 50 | 51 | 52 | sample: 53 | path: "" -------------------------------------------------------------------------------- /src/ipi/decoders/beam_search.py: -------------------------------------------------------------------------------- 1 | from src.ipi.decoders.mt_decoding import MTDecoder 2 | 3 | 4 | class BeamSearchDecoder(MTDecoder): 5 | def __init__(self, tokenizer, model, initializer, num_beams, early_stopping, **kwargs): 6 | super().__init__(tokenizer, model, initializer, **kwargs) 7 | 8 | self.name = "beam_search" 9 | self.acronym = "b" 10 | 11 | self.num_beams = num_beams 12 | self.early_stopping = early_stopping 13 | 14 | def decode(self, input_ids, attention_mask, *args, **kwargs): 15 | if self.is_mbart: 16 | with self.tokenizer.as_target_tokenizer(): 17 | try: 18 | lang_id = self.tokenizer.cur_lang_code_id 19 | except: 20 | lang_id = self.tokenizer.cur_lang_id 21 | beam_output = self.model.generate( 22 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 23 | num_beams=self.num_beams, 24 | early_stopping=self.early_stopping, 25 | forced_bos_token_id=lang_id, 26 | ) 27 | else: 28 | beam_output = self.model.generate( 29 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 30 | num_beams=self.num_beams, 31 | early_stopping=self.early_stopping, 32 | ) 33 | 34 | return beam_output, 0 35 | 36 | def compute_decode_kwargs(self, *args, **kwargs): 37 | return {} -------------------------------------------------------------------------------- /src/ipi/stopping_condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def limit_past_key_values(past_key_values, limit): 5 | new_list = [] 6 | for elem in past_key_values: 7 | new_elem = list(elem) 8 | new_elem[0] = elem[0][:, :, :limit, :] 9 | new_elem[1] = elem[1][:, :, :limit, :] 10 | new_list.append(tuple(new_elem)) 11 | return tuple(new_list) 12 | 13 | 14 | def stopping_criterion(past_tensor, 
current_tensor, eos=None): 15 | assert past_tensor.shape == current_tensor.shape 16 | if torch.equal(past_tensor, current_tensor): 17 | tensor = current_tensor 18 | if eos is not None: 19 | if eos in current_tensor[0]: 20 | pos = (current_tensor[0] == eos).nonzero(as_tuple=True)[0] 21 | if pos.shape[0] > 1: 22 | pos = pos[0].item() 23 | else: 24 | pos = pos.item() 25 | return True, tensor, pos 26 | else: 27 | return True, tensor, -1 28 | return True, tensor 29 | else: 30 | if eos is not None: 31 | return False, current_tensor, False 32 | else: 33 | return False, current_tensor 34 | 35 | 36 | def check_stop_cond(tensor, eos): 37 | if eos in tensor[0]: 38 | pos = (tensor[0] == eos).nonzero(as_tuple=True)[0] 39 | if pos.shape[0] > 1: 40 | pos = pos[0].item() 41 | else: 42 | pos = pos.item() 43 | return pos 44 | else: 45 | return -1 46 | -------------------------------------------------------------------------------- /src/utils/bench_scorer.py: -------------------------------------------------------------------------------- 1 | from src.utils.bleu_calculator import BleuEvaluator, BleuValues 2 | 3 | 4 | class Scorer(object): 5 | def __init__(self, name, acronym): 6 | self.name = name 7 | self.acronym = acronym 8 | 9 | self.bleu_scorer = BleuEvaluator() 10 | 11 | # number of sentences 12 | self.i = 0 13 | 14 | # param bleu score 15 | self.predictions = [] 16 | self.references = [] 17 | 18 | # benchmark values 19 | self.tot_mean_time = 0 20 | self.tot_mean_iter = 0 21 | 22 | # inline values 23 | self.current_init = None 24 | self.current_transl = None 25 | self.current_time = None 26 | self.current_iter = None 27 | 28 | def update_metrics(self, time, iter, translation, gold, init): 29 | self.tot_mean_time += (time - self.tot_mean_time) / (self.i + 1) 30 | self.tot_mean_iter += (iter - self.tot_mean_iter) / (self.i + 1) 31 | 32 | self.predictions.append(translation) 33 | self.references.append([gold]) 34 | 35 | self.current_init = init 36 | self.current_transl = translation 37 | self.current_time = time 38 | self.current_iter = iter 39 | 40 | self.i += 1 41 | 42 | def compute_bleu_score(self): 43 | bleu_score = self.bleu_scorer.final_score( 44 | model_predictions=self.predictions, 45 | gold_references=self.references 46 | ) 47 | 48 | bleu_dict = { 49 | 'score': bleu_score['score'], 50 | 'counts': str(bleu_score['counts']), 51 | 'totals': str(bleu_score['totals']), 52 | 'precisions': str(['%.2f' % prec for prec in bleu_score['precisions']]), 53 | 'bp': bleu_score['bp'], 54 | 'sys_len': bleu_score['sys_len'], 55 | 'ref_len': bleu_score['ref_len'] 56 | } 57 | 58 | return BleuValues(**bleu_dict) -------------------------------------------------------------------------------- /src/dataset/ittb_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import load_dataset 3 | from torch.utils.data.dataset import Dataset 4 | 5 | from src.utils.utils import clean_text 6 | 7 | 8 | class Ittb(Dataset): 9 | def __init__( 10 | self, 11 | src_lan, 12 | tgt_lan, 13 | hugginface_tokenizer=None, 14 | split: str = None, 15 | ): 16 | self.src_lan = src_lan 17 | self.tgt_lan = tgt_lan 18 | self.name = "ittb" 19 | self.max_length = 511 20 | 21 | assert ( 22 | src_lan == "en" or src_lan == "hi" 23 | ), "Ittb: src_lan must be either en or hi" 24 | assert ( 25 | tgt_lan == "en" or tgt_lan == "hi" 26 | ), "Ittb: tgt_lan must be either en or hi" 27 | assert src_lan != tgt_lan, "Ittb: src_lan and tgt_lan cannot be the same" 28 | 29 | 
self.translation_dataset = load_dataset("cfilt/iitb-english-hindi", split=split) 30 | 31 | with torch.no_grad(): 32 | self.tokenizer = hugginface_tokenizer 33 | 34 | def collate_fn(self, batch): 35 | 36 | batch_source = [b[0] for b in batch] 37 | batch_target = [b[1] for b in batch] 38 | 39 | encoded_source = self.tokenizer( 40 | batch_source, 41 | padding=True, 42 | return_tensors="pt", 43 | ) 44 | encoded_target = self.tokenizer( 45 | batch_target, 46 | padding=True, 47 | return_tensors="pt", 48 | ) 49 | 50 | return { 51 | "source": { 52 | "input_ids": encoded_source["input_ids"], 53 | "attention_mask": encoded_source["attention_mask"], 54 | "sentences": batch_source, 55 | }, 56 | "target": { 57 | "input_ids": encoded_target["input_ids"], 58 | "attention_mask": encoded_target["attention_mask"], 59 | "sentences": batch_target, 60 | }, 61 | } 62 | 63 | def __len__(self): 64 | return len(self.translation_dataset) 65 | 66 | def __getitem__(self, idx: int): 67 | source = self.translation_dataset["translation"][idx][self.src_lan] 68 | target = self.translation_dataset["translation"][idx][self.tgt_lan] 69 | 70 | return source, target 71 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import dotenv 7 | import git 8 | 9 | pylogger = logging.getLogger(__name__) 10 | 11 | 12 | def get_env(env_name: str, default: Optional[str] = None) -> str: 13 | """Safely read an environment variable. 14 | Raises errors if it is not defined or it is empty. 15 | :param env_name: the name of the environment variable 16 | :param default: the default (optional) value for the environment variable 17 | :return: the value of the environment variable 18 | """ 19 | if env_name not in os.environ: 20 | if default is None: 21 | message = f"{env_name} not defined and no default value is present!" 22 | pylogger.error(message) 23 | raise KeyError(message) 24 | return default 25 | 26 | env_value: str = os.environ[env_name] 27 | if not env_value: 28 | if default is None: 29 | message = ( 30 | f"{env_name} has yet to be configured and no default value is present!" 31 | ) 32 | pylogger.error(message) 33 | raise ValueError(message) 34 | return default 35 | 36 | return env_value 37 | 38 | 39 | def load_envs(env_file: Optional[str] = None) -> None: 40 | """Load all the environment variables defined in the `env_file`. 41 | This is equivalent to `. env_file` in bash. 42 | It is possible to define all the system specific variables in the `env_file`. 43 | :param env_file: the file that defines the environment variables to use. If None 44 | it searches for a `.env` file in the project. 
45 | """ 46 | dotenv.load_dotenv(dotenv_path=env_file, override=True) 47 | 48 | 49 | # Load environment variables 50 | load_envs() 51 | 52 | 53 | if "PROJECT_ROOT" not in os.environ: 54 | try: 55 | PROJECT_ROOT = Path( 56 | git.Repo(Path.cwd(), search_parent_directories=True).working_dir 57 | ) 58 | except git.exc.InvalidGitRepositoryError: 59 | PROJECT_ROOT = Path.cwd() 60 | 61 | pylogger.debug(f"Inferred project root: {PROJECT_ROOT}") 62 | os.environ["PROJECT_ROOT"] = str(PROJECT_ROOT) 63 | else: 64 | PROJECT_ROOT: Path = Path(os.environ["PROJECT_ROOT"]) 65 | 66 | __all__ = ["__version__", "PROJECT_ROOT"] 67 | -------------------------------------------------------------------------------- /src/dataset/flores_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import datasets 4 | 5 | from src.utils.utils import retrieve_map_languages_flores, clean_text 6 | import typing as t 7 | 8 | 9 | class Flores(Dataset): 10 | def __init__( 11 | self, 12 | src_lan: str = "ro", 13 | tgt_lan: str = "en", 14 | hugginface_tokenizer=None, 15 | split: str = None, 16 | ): 17 | self.name = "flores" 18 | self.max_length = 511 19 | self.src_lan = retrieve_map_languages_flores(src_lan).lower()[:3] 20 | self.tgt_lan = retrieve_map_languages_flores(tgt_lan).lower()[:3] 21 | 22 | if "test" in split: 23 | split = "dev" + split 24 | 25 | self.translation_dataset_src = datasets.load_dataset( 26 | "gsarti/flores_101", self.src_lan, split=split 27 | ) 28 | self.translation_dataset_tgt = datasets.load_dataset( 29 | "gsarti/flores_101", self.tgt_lan, split=split 30 | ) 31 | 32 | with torch.no_grad(): 33 | self.tokenizer = hugginface_tokenizer 34 | 35 | def collate_fn(self, batch): 36 | 37 | batch_source = [b[0] for b in batch] 38 | batch_target = [b[1] for b in batch] 39 | 40 | encoded_source = self.tokenizer( 41 | batch_source, 42 | padding=True, 43 | return_tensors="pt", 44 | ) 45 | encoded_target = self.tokenizer( 46 | batch_target, 47 | padding=True, 48 | return_tensors="pt", 49 | ) 50 | 51 | return { 52 | "source": { 53 | "input_ids": encoded_source["input_ids"], 54 | "attention_mask": encoded_source["attention_mask"], 55 | "sentences": batch_source, 56 | }, 57 | "target": { 58 | "input_ids": encoded_target["input_ids"], 59 | "attention_mask": encoded_target["attention_mask"], 60 | "sentences": batch_target, 61 | }, 62 | } 63 | 64 | def __len__(self) -> int: 65 | return self.translation_dataset_src.num_rows 66 | 67 | def __getitem__(self, idx: int) -> t.Tuple[str, str]: 68 | source = str(self.translation_dataset_src.data.column(6)[idx]) 69 | target = str(self.translation_dataset_tgt.data.column(6)[idx]) 70 | 71 | return source, target 72 | -------------------------------------------------------------------------------- /exp/tab1.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | python3 ../main.py src_lang="en" tgt_lang="de" device="cpu" dataset.name="wmt" dataset.version="14" model.model_name="Helsinki-NLP/opus-mt-en-de" 4 | python3 ../main.py src_lang="de" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="14" model.model_name="Helsinki-NLP/opus-mt-en-de" 5 | python3 ../main.py src_lang="en" tgt_lang="de" device="cpu" dataset.name="wmt" dataset.version="14" model.model_name="teelinsan/opus-mt-eng-deu" 6 | python3 ../main.py src_lang="de" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="14" model.model_name="teelinsan/opus-mt-eng-deu" 7 | python3 ../main.py src_lang="ro" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="16" model.model_name="Helsinki-NLP/opus-mt-roa-en" 8 | python3 ../main.py src_lang="en" tgt_lang="ro" device="cpu" dataset.name="wmt" dataset.version="16" model.model_name="Helsinki-NLP/opus-mt-roa-en" 9 | 10 | python3 ../main.py src_lang="en" tgt_lang="de" device="cpu" dataset.name="wmt" dataset.version="14" task="beam_search" model.model_name="teelinsan/opus-mt-eng-deu" 11 | python3 ../main.py src_lang="de" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="14" task="beam_search" 12 | python3 ../main.py src_lang="en" tgt_lang="ro" device="cpu" dataset.name="wmt" dataset.version="16" task="beam_search" 13 | python3 ../main.py src_lang="ro" tgt_lang="en" device="cpu" dataset.name="wmt" dataset.version="16" task="beam_search" model.model_name="Helsinki-NLP/opus-mt-roa-en" 14 | 15 | 16 | python3 ../main.py src_lang="en" tgt_lang="de" device="cuda" dataset.name="wmt" dataset.version="14" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 17 | python3 ../main.py src_lang="de" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="14" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 18 | python3 ../main.py src_lang="ro" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="16" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 19 | python3 ../main.py src_lang="en" tgt_lang="ro" device="cuda" dataset.name="wmt" dataset.version="16" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 20 | 21 | python3 ../main.py src_lang="en" tgt_lang="de" device="cuda" dataset.name="wmt" dataset.version="14" task="beam_search" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 22 | python3 ../main.py src_lang="de" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="14" task="beam_search" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 23 | python3 ../main.py src_lang="en" tgt_lang="ro" device="cuda" dataset.name="wmt" dataset.version="16" task="beam_search" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 24 | python3 ../main.py src_lang="ro" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="16" task="beam_search" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 25 | -------------------------------------------------------------------------------- /src/ipi/initializer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nltk.tokenize.toktok import ToktokTokenizer 3 | from nltk.tokenize.treebank import TreebankWordDetokenizer 4 | from transformers import ( 5 | M2M100Tokenizer, 6 | MarianTokenizer, 7 | MBart50Tokenizer, 8 | MBartTokenizer, 9 | ) 10 | 11 | 12 | class Initializer(object): 13 | def __init__( 14 | self, 15 | src_len, 16 | tgt_len, 17 | hugginface_tokenizer, 18 | use_init=True, 19 | device="cpu", 20 | ): 21 | 22 | 
self.src_len = src_len 23 | self.tgt_len = tgt_len 24 | self.tokenizer = hugginface_tokenizer 25 | 26 | self.pad_token_id = self.tokenizer.pad_token_id 27 | self.tokenizer_nltk = ToktokTokenizer() 28 | self.detokenizer_nltk = TreebankWordDetokenizer() 29 | self.use_init = use_init 30 | self.device = device 31 | 32 | def init_translation(self, tgt_len=None): 33 | final_translation = "" 34 | with self.tokenizer.as_target_tokenizer(): 35 | if isinstance(self.tokenizer, MBartTokenizer): 36 | tgt_tensor = self.tokenizer( 37 | final_translation, return_tensors="pt", padding=True 38 | ).data["input_ids"] 39 | if tgt_tensor.shape[-1] == 2: 40 | tgt_tensor = tgt_tensor[:, :1] 41 | elif isinstance(self.tokenizer, MarianTokenizer): 42 | bos = torch.tensor([self.pad_token_id]).unsqueeze(0) 43 | tgt_tensor = bos 44 | elif isinstance(self.tokenizer, MBart50Tokenizer) or isinstance( 45 | self.tokenizer, M2M100Tokenizer 46 | ): 47 | bos = torch.tensor([self.tokenizer.eos_token_id]).unsqueeze(0) 48 | tgt_tensor = self.tokenizer( 49 | final_translation, return_tensors="pt", padding=True 50 | ).data["input_ids"] 51 | tgt_tensor = torch.cat([bos, tgt_tensor], dim=-1) 52 | else: 53 | bos = torch.tensor([self.tokenizer.bos_token_id]).unsqueeze(0) 54 | tgt_tensor = self.tokenizer( 55 | final_translation, return_tensors="pt", padding=True 56 | ).data["input_ids"] 57 | tgt_tensor = torch.cat([bos, tgt_tensor], dim=-1) 58 | if tgt_len is not None: 59 | tgt_tensor = self.trim_length(tgt_tensor, tgt_len) 60 | return tgt_tensor.to(self.device), final_translation 61 | 62 | def trim_length(self, tgt_tensor, tgt_len): 63 | last_elem = tgt_tensor[:, -1].unsqueeze(0) 64 | if tgt_tensor.shape[-1] > tgt_len: 65 | return torch.cat([tgt_tensor[..., : tgt_len - 1], last_elem], dim=-1) 66 | elif tgt_tensor.shape[-1] < tgt_len: 67 | delta = tgt_len - tgt_tensor.shape[-1] - 1 68 | init_tensor = torch.tensor( 69 | [self.pad_token_id] * delta, dtype=tgt_tensor.dtype 70 | ).unsqueeze(0) 71 | return_tensor = torch.cat([tgt_tensor, init_tensor, last_elem], dim=-1) 72 | return return_tensor 73 | else: 74 | return tgt_tensor 75 | -------------------------------------------------------------------------------- /exp/tab2.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | python3 ../main.py src_lang="en" tgt_lang="fi" device="cuda" dataset.name="wmt" dataset.version="17" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 4 | python3 ../main.py src_lang="fi" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="17" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 5 | python3 ../main.py src_lang="en" tgt_lang="hi" device="cuda" dataset.name="ittb" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 6 | python3 ../main.py src_lang="hi" tgt_lang="en" device="cuda" dataset.name="ittb" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 7 | python3 ../main.py src_lang="en" tgt_lang="vi" device="cuda" dataset.name="iwslt" dataset.version="15" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 8 | python3 ../main.py src_lang="vi" tgt_lang="en" device="cuda" dataset.name="iwslt" dataset.version="15" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 9 | python3 ../main.py src_lang="en" tgt_lang="it" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 10 | python3 ../main.py src_lang="it" tgt_lang="en" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 11 | python3 ../main.py src_lang="en" tgt_lang="fr" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 12 | python3 ../main.py src_lang="fr" tgt_lang="en" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" 13 | 14 | 15 | 16 | python3 ../main.py src_lang="en" tgt_lang="fi" device="cuda" dataset.name="wmt" dataset.version="17" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 17 | python3 ../main.py src_lang="fi" tgt_lang="en" device="cuda" dataset.name="wmt" dataset.version="17" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 18 | python3 ../main.py src_lang="en" tgt_lang="hi" device="cuda" dataset.name="ittb" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 19 | python3 ../main.py src_lang="hi" tgt_lang="en" device="cuda" dataset.name="ittb" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 20 | python3 ../main.py src_lang="en" tgt_lang="vi" device="cuda" dataset.name="iwslt" dataset.version="15" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 21 | python3 ../main.py src_lang="vi" tgt_lang="en" device="cuda" dataset.name="iwslt" dataset.version="15" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 22 | python3 ../main.py src_lang="en" tgt_lang="it" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 23 | python3 ../main.py src_lang="it" tgt_lang="en" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 24 | python3 ../main.py src_lang="en" tgt_lang="fr" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 25 | python3 ../main.py src_lang="fr" tgt_lang="en" device="cuda" dataset.name="flores" model.model_name="facebook/mbart-large-50-many-to-many-mmt" task="beam_search" 26 | 27 | -------------------------------------------------------------------------------- /src/dataset/wmt_dataset.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | import 
datasets 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from src.utils.utils import clean_text 8 | 9 | 10 | class Wmt(Dataset): 11 | """ 12 | Wmt machine translation dataset reader 13 | 14 | Input: 15 | - version -> the dataset version dataset, by default '16' (dataset-16) 16 | - src_lan -> the source language, by default 'ro' (Romanian) 17 | - tgt_lan -> the target language, by default 'en' (English) 18 | - tokenizer_model -> the tokenizer model 19 | - split -> if not None, allows to split the dataset in following set: ['train', 'test', 'validation'] 20 | - concat -> if not None, make possible the concatenation of the specified set. 21 | Note: It works only if split is None 22 | It can be: ['train', 'test', 'validation'] 23 | """ 24 | 25 | def __init__( 26 | self, 27 | version: str = "16", 28 | src_lan: str = "ro", 29 | tgt_lan: str = "en", 30 | hugginface_tokenizer=None, 31 | split: str = None, 32 | ): 33 | self.src_lan = src_lan 34 | self.tgt_lan = tgt_lan 35 | self.tokenizer_model = hugginface_tokenizer 36 | self.max_length = 511 37 | 38 | if src_lan == "en": 39 | source2target = "{}-{}".format(self.tgt_lan, self.src_lan) 40 | else: 41 | source2target = "{}-{}".format(self.src_lan, self.tgt_lan) 42 | 43 | if version == "19" and "test" in split: 44 | split = "validation" 45 | 46 | version = f"wmt{version}" 47 | 48 | self.name = version 49 | 50 | try: 51 | self.translation_dataset = datasets.load_dataset( 52 | version, source2target, split=split 53 | ) 54 | except: 55 | raise ValueError( 56 | f"{version} can read only the pairs cs-en, en-cs, de-en, en-de," 57 | f" fi-en, en-fi, ro-en, en-ro, ru-en, en-ru, tr-en, en-tr" 58 | ) 59 | 60 | with torch.no_grad(): 61 | self.tokenizer = hugginface_tokenizer 62 | 63 | def collate_fn(self, batch): 64 | 65 | batch_source = [b[0] for b in batch] 66 | batch_target = [b[1] for b in batch] 67 | 68 | encoded_source = self.tokenizer( 69 | batch_source, 70 | padding=True, 71 | return_tensors="pt", 72 | ) 73 | encoded_target = self.tokenizer( 74 | batch_target, 75 | padding=True, 76 | return_tensors="pt", 77 | ) 78 | 79 | return { 80 | "source": { 81 | "input_ids": encoded_source["input_ids"], 82 | "attention_mask": encoded_source["attention_mask"], 83 | "sentences": batch_source, 84 | }, 85 | "target": { 86 | "input_ids": encoded_target["input_ids"], 87 | "attention_mask": encoded_target["attention_mask"], 88 | "sentences": batch_target, 89 | }, 90 | } 91 | 92 | def __len__(self) -> int: 93 | return len(self.translation_dataset) 94 | 95 | def __getitem__(self, idx: int) -> t.Tuple[str, str]: 96 | sample = self.translation_dataset[idx] 97 | source = sample["translation"][self.src_lan] 98 | target = sample["translation"][self.tgt_lan] 99 | 100 | return source, target 101 | -------------------------------------------------------------------------------- /src/ipi/decoders/autoregressive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.ipi.decoders.mt_decoding import MTDecoder 4 | 5 | 6 | class AutoregressiveDecoder(MTDecoder): 7 | def __init__(self, tokenizer, model, initializer, **kwargs): 8 | super().__init__(tokenizer, model, initializer, **kwargs) 9 | 10 | self.name = "autoregressive" 11 | self.acronym = "a" 12 | 13 | @torch.no_grad() 14 | def decode(self, input_ids, attention_mask, init_tensor=None, logits_preprocessor=None, *args, **kwargs): 15 | 16 | index = 0 17 | 18 | if init_tensor is None: 19 | init_tensor = torch.tensor( 20 | [self.pad_token_id], 
device=self.device 21 | ).unsqueeze(0) 22 | elif self.is_mbart: 23 | output = self.model( 24 | input_ids, 25 | attention_mask, 26 | decoder_input_ids=init_tensor[:, 0].unsqueeze(0), 27 | use_cache=True, 28 | ) 29 | encoder_last_hidden_state = output.encoder_last_hidden_state 30 | past_key_values = output.past_key_values 31 | index += 1 32 | total_res = torch.tensor( 33 | [[init_tensor[:, 0], init_tensor[:, 1]]], device=self.device 34 | ) 35 | init_tensor = init_tensor[:, 1].unsqueeze(0) 36 | else: 37 | init_tensor = init_tensor[:, 0].unsqueeze(0) 38 | 39 | total_res = init_tensor.clone() 40 | while True: 41 | if self.use_cache and index > 0: 42 | if index == 1024: 43 | print(total_res) 44 | output = self.model( 45 | None, 46 | attention_mask, 47 | decoder_input_ids=init_tensor, 48 | encoder_outputs=(encoder_last_hidden_state, None, None), 49 | use_cache=True, 50 | past_key_values=past_key_values, 51 | ) 52 | else: 53 | output = self.model( 54 | input_ids, 55 | attention_mask, 56 | decoder_input_ids=init_tensor, 57 | use_cache=True, 58 | ) 59 | encoder_last_hidden_state = output.encoder_last_hidden_state 60 | past_key_values = output.past_key_values 61 | logits = output.logits 62 | if logits_preprocessor is not None: 63 | logits = logits_preprocessor(total_res, logits[:,-1,:]) 64 | else: 65 | logits = logits[:,-1,:] 66 | max_value = torch.argmax(logits, dim=-1) 67 | last = max_value 68 | init_tensor = last.unsqueeze(0) 69 | total_res = torch.cat((total_res, init_tensor), dim=1) 70 | 71 | index += 1 72 | if last[0].item() == self.eos_token_id or index == self.model.config.max_length - 1: 73 | break 74 | return total_res, index 75 | 76 | def initialize(self): 77 | if self.initializer is not None: 78 | init_tensor, _ = self.initializer.init_translation() 79 | else: 80 | init_tensor = None 81 | 82 | return init_tensor 83 | 84 | def compute_decode_kwargs(self, input_ids, *args, **kwargs): 85 | init_tensor = self.initialize() 86 | logits_preprocessor = self.generate_logits_preprocessor(input_ids) 87 | 88 | return { 89 | "init_tensor": init_tensor.clone(), 90 | "logits_preprocessor": logits_preprocessor 91 | } 92 | 93 | -------------------------------------------------------------------------------- /src/utils/bleu_calculator.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import datasets 4 | import pandas as pd 5 | from tabulate import tabulate 6 | 7 | 8 | class BleuValues: 9 | def __init__(self, **entries): 10 | self.__dict__.update(entries) 11 | 12 | 13 | class BleuEvaluator(object): 14 | def __init__(self): 15 | self.metric = datasets.load_metric("sacrebleu") 16 | 17 | def add_element(self, model_predictions, gold_references): 18 | self.metric.add(predictions=model_predictions, references=gold_references) 19 | 20 | def add_batch(self, predictions, references): 21 | self.metric.add_batch(predictions=predictions, references=references) 22 | 23 | def final_score(self, model_predictions, gold_references): 24 | return self.metric.compute(predictions=model_predictions, references=gold_references) 25 | 26 | 27 | class BleuCalculator: 28 | def __init__( 29 | self, 30 | dataset, 31 | result_dir, 32 | ): 33 | self.dataset = dataset 34 | self.result_dir = result_dir 35 | 36 | @staticmethod 37 | def read_csv(path_csv): 38 | csv_reader = pd.read_csv(path_csv, sep="\t", header=0) 39 | return { 40 | k: v 41 | for k, v in zip( 42 | csv_reader["#sentence"].tolist(), csv_reader["times"].tolist() 43 | ) 44 | } 45 | 46 | def _retrieve_files(self): 47 | 
file2data = dict() 48 | for root, dirs, files in os.walk(self.result_dir): 49 | if any(map(lambda x: "trans" in x, files)) and "initrans" not in root: 50 | trans_files_name = list(filter(lambda x: ("trans" in x), files))[0] 51 | data = self.read_csv(path_csv=os.path.join(root, trans_files_name)) 52 | file2data.update({trans_files_name.split(".")[-2]: data}) 53 | return file2data 54 | 55 | def _load_dataset(self): 56 | return {i: x[1] for i, x in enumerate(self.dataset)} 57 | 58 | @staticmethod 59 | def _match_indices(method, gold): 60 | new_gold = dict() 61 | for k in gold: 62 | if k in method: 63 | new_gold.update({k: gold[k]}) 64 | 65 | return new_gold 66 | 67 | @staticmethod 68 | def _bleu_score_formatter(bleu_score): 69 | 70 | bleu_dict = { 71 | "score": bleu_score["score"], 72 | "counts": str(bleu_score["counts"]), 73 | "totals": str(bleu_score["totals"]), 74 | "precisions": str(["%.2f" % prec for prec in bleu_score["precisions"]]), 75 | "bp": bleu_score["bp"], 76 | "sys_len": bleu_score["sys_len"], 77 | "ref_len": bleu_score["ref_len"], 78 | } 79 | 80 | return BleuValues(**bleu_dict) 81 | 82 | def write_report(self, file2score): 83 | print("Writing report...") 84 | 85 | # Table for the Bleu score 86 | header = ["Metrics"] + [m[0] for m in file2score] 87 | 88 | bleu_table = tabulate( 89 | [ 90 | ["Score"] + [b[1].score for b in file2score], 91 | ["Counts"] + [b[1].counts for b in file2score], 92 | ["Totals"] + [b[1].totals for b in file2score], 93 | ["Precisions"] + [b[1].precisions for b in file2score], 94 | ["Bp"] + [b[1].bp for b in file2score], 95 | ["Sys_len"] + [b[1].sys_len for b in file2score], 96 | ["ref_len"] + [b[1].ref_len for b in file2score], 97 | ], 98 | headers=header, 99 | tablefmt="rst", 100 | ) 101 | 102 | with open(os.path.join(self.result_dir, "bleu_report.txt"), mode="w") as report: 103 | report.write(f"Bleu Score\n{bleu_table}\n\n") 104 | 105 | def _compute_bleu_score(self, name, translations, gold): 106 | scorer = BleuEvaluator() 107 | 108 | translations = list(translations.values()) 109 | gold = list(gold.values()) 110 | 111 | gold = [[g] for g in gold] 112 | 113 | # for t, g in zip(translations, gold): 114 | # scorer.add_element(t, [g]) 115 | 116 | score_value = scorer.final_score(translations, gold) 117 | return name, self._bleu_score_formatter(score_value) 118 | 119 | def compute_bleu_score(self): 120 | file2data = self._retrieve_files() 121 | gold = self._load_dataset() 122 | 123 | if "trans_beam" in file2data: 124 | beam_search = self._match_indices( 125 | file2data["trans_gs_jacobi"], file2data["trans_beam"] 126 | ) 127 | file2data["trans_beam"] = beam_search 128 | gold = self._match_indices(file2data["trans_gs_jacobi"], gold) 129 | 130 | file2score = [ 131 | self._compute_bleu_score(file, file2data[file], gold) for file in file2data 132 | ] 133 | 134 | self.write_report(file2score) 135 | -------------------------------------------------------------------------------- /src/ipi/decoders/mt_decoding.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | from transformers import MBartForConditionalGeneration 5 | 6 | from src.ipi import stopping_condition as sc 7 | from src.utils.utils import get_logits_preprocessor 8 | 9 | PREC_GOLD_AUTOREGRESSIVE: Dict[str, Optional[torch.Tensor]] = {"input_ids": None, "gold": None} 10 | PREC_LOGISTS_PREPROCESSOR: Dict[str, Optional[torch.Tensor]] = {"input_ids": None, "logists": None} 11 | 12 | class MTDecoder: 13 | def __init__( 
14 | self, 15 | tokenizer, 16 | model, 17 | initializer, 18 | use_cache: bool = True, 19 | use_logits_preprocessor: bool = True, 20 | device: str = "cuda", 21 | **kwargs 22 | ): 23 | self.tokenizer = tokenizer 24 | 25 | self.initializer = initializer 26 | 27 | self.pad_token_id = self.tokenizer.pad_token_id 28 | self.eos_token_id = self.tokenizer.eos_token_id 29 | 30 | self.use_cache = use_cache 31 | self.device = device 32 | 33 | with torch.no_grad(): 34 | self.model = model 35 | self.model_name = self.model.name_or_path 36 | self.model.eval() 37 | 38 | self.max_length = min(self.tokenizer.model_max_length, 511) 39 | 40 | self.use_logits_preprocessor = use_logits_preprocessor 41 | 42 | self.is_mbart = isinstance(self.model, MBartForConditionalGeneration) 43 | 44 | def decode(self, input_ids, attention_mask, *args, **kwargs): 45 | pass 46 | 47 | def compute_decode_kwargs(self, *args, **kwargs): 48 | pass 49 | 50 | def initialize(self, **kwargs): 51 | pass 52 | 53 | def generate_gold_autoregressive(self, input_ids, attention_mask): 54 | 55 | global PREC_GOLD_AUTOREGRESSIVE 56 | 57 | if PREC_GOLD_AUTOREGRESSIVE['input_ids'] is None or not torch.equal(input_ids, PREC_GOLD_AUTOREGRESSIVE['input_ids']): 58 | if self.is_mbart: 59 | with self.tokenizer.as_target_tokenizer(): 60 | try: 61 | lang_id = self.tokenizer.cur_lang_code_id 62 | except: 63 | lang_id = self.tokenizer.cur_lang_id 64 | gold_autoregressive = self.model.generate( 65 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 66 | num_beams=1, 67 | do_sample=False, 68 | use_cache=False, 69 | forced_bos_token_id=lang_id, 70 | ) 71 | else: 72 | gold_autoregressive = self.model.generate( 73 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 74 | num_beams=1, 75 | do_sample=False, 76 | use_cache=False, 77 | ) 78 | gold_autoregressive = gold_autoregressive[:, : self.max_length] 79 | 80 | PREC_GOLD_AUTOREGRESSIVE['input_ids'] = input_ids 81 | PREC_GOLD_AUTOREGRESSIVE['gold'] = gold_autoregressive 82 | 83 | return PREC_GOLD_AUTOREGRESSIVE['gold'] 84 | 85 | def generate_logits_preprocessor(self, input_ids): 86 | 87 | global PREC_LOGISTS_PREPROCESSOR 88 | 89 | if self.use_logits_preprocessor: 90 | if PREC_LOGISTS_PREPROCESSOR['input_ids'] is None or not torch.equal(input_ids, PREC_LOGISTS_PREPROCESSOR['input_ids']): 91 | logits_preprocessor = get_logits_preprocessor( 92 | model=self.model, 93 | input_ids=input_ids, 94 | eos_token_id=self.eos_token_id 95 | ) 96 | 97 | PREC_LOGISTS_PREPROCESSOR['input_ids'] = input_ids 98 | PREC_LOGISTS_PREPROCESSOR['logists'] = logits_preprocessor 99 | else: 100 | return None 101 | 102 | return PREC_LOGISTS_PREPROCESSOR['logists'] 103 | 104 | @staticmethod 105 | def stopping_criterion(past_tensor, current_tensor, eos=None): 106 | return sc.stopping_criterion(past_tensor, current_tensor, eos) 107 | 108 | @staticmethod 109 | def limit_past_key_values(past_key_values, limit): 110 | return sc.limit_past_key_values(past_key_values, limit) 111 | 112 | @staticmethod 113 | def trig_eos(tensor, eos_token_id, init_tensor, base_value): 114 | if tensor[:, 0].item() == eos_token_id: 115 | return init_tensor[:, : base_value + 1] 116 | else: 117 | return None 118 | 119 | 120 | def generate_target( 121 | tokenizer, 122 | model, 123 | input_ids: torch.Tensor, 124 | attention_mask: torch.Tensor, 125 | is_mbart: bool, 126 | decoding_method: str = "greedy", 127 | remove_padding: bool = False, 128 | ): 129 | if decoding_method == "greedy": 130 | if is_mbart: 131 | with tokenizer.as_target_tokenizer(): 132 | 
gold_output = model.generate( 133 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 134 | num_beams=1, 135 | do_sample=False, 136 | forced_bos_token_id=tokenizer.cur_lang_code_id, 137 | ) 138 | else: 139 | gold_output = model.generate( 140 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 141 | num_beams=1, 142 | do_sample=False, 143 | ) 144 | else: 145 | raise NotImplementedError() 146 | 147 | if remove_padding: 148 | sample_lengths = (gold_output != tokenizer.pad_token_id).sum(dim=1) 149 | gold_output = [ 150 | sample[:length] for sample, length in zip(gold_output, sample_lengths) 151 | ] 152 | 153 | return gold_output 154 | 155 | -------------------------------------------------------------------------------- /src/ipi/decoders/gs_jacobi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.ipi.decoders.mt_decoding import MTDecoder 4 | from more_itertools import sliced 5 | 6 | 7 | class GSJacobiDecoder(MTDecoder): 8 | def __init__(self, tokenizer, model, initializer, gs_jaco_blocks, init_mode, **kwargs): 9 | super().__init__(tokenizer, model, initializer, **kwargs) 10 | 11 | self.name = "gs_jacobi" 12 | self.acronym = "g" 13 | 14 | self.gs_jaco_blocks = gs_jaco_blocks 15 | self.init_mode = init_mode 16 | 17 | @torch.no_grad() 18 | def decode( 19 | self, input_ids, attention_mask, target_len, gold_target, init_tensor=None, logits_preprocessor=None, *args, **kwargs 20 | ): 21 | key_cache = 1 22 | if init_tensor is None: 23 | init_tensor = torch.tensor( 24 | [self.pad_token_id] * target_len, device=self.device 25 | ) 26 | blocks = list(sliced(init_tensor, self.gs_jaco_blocks)) 27 | init_tensor = init_tensor.unsqueeze(0) 28 | total_past_key_values = None 29 | elif self.is_mbart: 30 | output = self.model( 31 | input_ids, 32 | attention_mask, 33 | decoder_input_ids=init_tensor[:, 0].unsqueeze(0), 34 | use_cache=True, 35 | ) 36 | encoder_last_hidden_state = output.encoder_last_hidden_state 37 | total_past_key_values = output.past_key_values 38 | delta = 1 39 | # total_res = init_tensor.to(self.device) 40 | init_tensor = init_tensor[:, 1:] 41 | blocks = list(sliced(init_tensor.squeeze(0), self.gs_jaco_blocks)) 42 | key_cache = 2 43 | else: 44 | init_tensor = init_tensor 45 | blocks = list(sliced(init_tensor.squeeze(0), self.gs_jaco_blocks)) 46 | total_past_key_values = None 47 | 48 | iteration_saved = 0 49 | base_value = 0 50 | 51 | for blocco in blocks: 52 | max_len = blocco.shape[-1] 53 | blocco_usr = init_tensor[:, base_value : base_value + max_len] 54 | 55 | for index in range(max_len): 56 | old_blocco = blocco_usr.detach().clone() 57 | blocco_usr_new = blocco_usr[:, index:] 58 | if base_value == 0 and index == 0 and not self.is_mbart: 59 | output = self.model( 60 | input_ids, 61 | attention_mask, 62 | decoder_input_ids=blocco_usr_new, 63 | use_cache=True, 64 | past_key_values=total_past_key_values, 65 | ) 66 | encoder_last_hidden_state = output.encoder_last_hidden_state 67 | else: 68 | output = self.model( 69 | input_ids, 70 | attention_mask, 71 | decoder_input_ids=blocco_usr_new, 72 | encoder_outputs=(encoder_last_hidden_state, None, None), 73 | use_cache=True, 74 | past_key_values=total_past_key_values, 75 | ) 76 | 77 | total_past_key_values = self.limit_past_key_values( 78 | output.past_key_values, 79 | base_value + index + key_cache, 80 | ) 81 | logits = output.logits 82 | max_value = torch.argmax(logits, dim=-1) 83 | 84 | if logits_preprocessor is not None: 85 | logits_new = 
logits_preprocessor(init_tensor[:, :base_value + index +1], logits[:, 0, :]) 86 | max_value_new = torch.argmax(logits_new, dim=-1) 87 | max_value[:, 0] = max_value_new 88 | 89 | 90 | if ( 91 | max_value.shape[-1] 92 | == init_tensor[ 93 | :, base_value + index + 1 : base_value + max_len + 1 94 | ].shape[-1] 95 | ): 96 | init_tensor[ 97 | :, base_value + index + 1 : base_value + max_len + 1 98 | ] = max_value[:, :] 99 | else: 100 | # If last block remove the last token after EOS 101 | init_tensor[ 102 | :, base_value + index + 1 : base_value + max_len + 1 103 | ] = max_value[:, :-1] 104 | 105 | stop_condition, _ = self.stopping_criterion(old_blocco, blocco_usr) 106 | 107 | if stop_condition and index + 1 != max_len: 108 | total_past_key_values = self.limit_past_key_values( 109 | output.past_key_values, 110 | base_value + max_len + 1, 111 | ) 112 | iteration_saved += max_len - index - 1 113 | break 114 | base_value += max_len 115 | return init_tensor, (gold_target.shape[-1] - 1) - iteration_saved 116 | 117 | def initialize(self, input_ids, gold_autoregressive): 118 | if self.initializer is not None: 119 | if self.init_mode == "overprov": 120 | m = int(input_ids.shape[-1] + 10 / 100 * input_ids.shape[-1]) 121 | else: 122 | m = gold_autoregressive.shape[-1] 123 | init_tensor, _ = self.initializer.init_translation(m) 124 | else: 125 | init_tensor = None 126 | 127 | return init_tensor 128 | 129 | def compute_decode_kwargs(self, input_ids, attention_mask, **kwargs): 130 | gold_autoregressive = self.generate_gold_autoregressive(input_ids, attention_mask) 131 | init_tensor = self.initialize(input_ids, gold_autoregressive) 132 | logits_preprocessor = self.generate_logits_preprocessor(input_ids) 133 | 134 | return{ 135 | "init_tensor": init_tensor.clone(), 136 | "gold_target": gold_autoregressive, 137 | "target_len": gold_autoregressive.shape[-1], 138 | "logits_preprocessor": logits_preprocessor 139 | } 140 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Accelerating Transformer Inference for Translation via Parallel Decoding 4 | 5 | [![Paper](http://img.shields.io/badge/paper-ArXiv-B31B1B.svg)](https://arxiv.org/abs/2305.10427) 6 | [![Conference](http://img.shields.io/badge/ACL-2023-c92828.svg)](https://aclanthology.org/2023.acl-long.689/) 7 | 8 |
9 | 
 10 | 
 11 | This is the code repository for the paper "Accelerating Transformer Inference for Translation via Parallel Decoding", accepted at the [ACL 2023 main conference](https://aclanthology.org/2023.acl-long.689/). 
 12 | 
 13 | The paper proposes three Parallel Decoding methods to speed up existing autoregressive machine translation models: **Jacobi Decoding**, **GS-Jacobi Decoding**, and **Hybrid GS-Jacobi Decoding**. A minimal sketch of the Jacobi iteration is shown after the figure below. 
 14 | 
 15 | This code is not production-ready and should be used for research purposes only. 
 16 | 
 17 | **Paper**: https://arxiv.org/abs/2305.10427 
 18 | 
 19 | 
20 | <img src="assets/ipi.png" alt="drawing"/> 
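As a rough illustration of the idea — this is *not* the repository's implementation, which lives in `src/ipi/decoders/` and adds KV caching, logits preprocessing, and the block/hybrid variants — a plain-`transformers` sketch of the Jacobi iteration looks like the snippet below. The model name and the fixed target length are assumptions made only for this example:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ro")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ro").eval()

src = tok("How are you?", return_tensors="pt")
max_len = 32  # assumed upper bound on the target length

# Initialize every target position with the pad token (Marian's decoder start token).
y = torch.full((1, max_len), tok.pad_token_id, dtype=torch.long)

with torch.no_grad():
    for _ in range(max_len):  # worst case degenerates to autoregressive decoding
        logits = model(**src, decoder_input_ids=y).logits
        y_new = y.clone()
        y_new[:, 1:] = logits.argmax(-1)[:, :-1]  # refresh all positions in parallel
        if torch.equal(y_new, y):  # fixed point reached: same output as greedy decoding
            break
        y = y_new

# Cut at the first EOS before decoding.
out = y[0]
eos = (out == tok.eos_token_id).nonzero()
if len(eos) > 0:
    out = out[: eos[0, 0] + 1]
print(tok.decode(out, skip_special_tokens=True))
```

Each Jacobi step refines the whole target sequence at once, and the iteration stops when the sequence no longer changes, which yields the same translation as standard greedy decoding.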
22 | 
 23 | 
 24 | ## Reproduce the results 
 25 | To reproduce the benchmark values for en-ro WMT16: 
 26 | 1. Install all the requirements with `pip install -r requirements.txt` 
 27 | 2. Run the following to retrieve the benchmark values: 
 28 | ``` 
 29 | python3 main.py src_lang="en" tgt_lang="ro" device="cpu" dataset.name="wmt" dataset.version="16" task="benchmark" bench.result_dir="[your path]" model.model_name="Helsinki-NLP/opus-mt-en-ro" 
 30 | ``` 
 31 | Iteration speedups and BLEU results should be easy to reproduce. Time speedups depend on the underlying hardware and software being able to run the computation in parallel without introducing overhead. Please follow the experimental setting proposed in the paper. The easiest way is to use a fresh virtual machine; instructions are provided in the Scaling Experiments section of this README. 
 32 | 
 33 | ## Datasets 
 34 | All the datasets are available via `HuggingFace` datasets, so they will be downloaded automatically. 
 35 | However, `Iwslt` needs to be downloaded manually. In particular, you have to download `2015-01/texts` for `Iwslt15` and `2017-01-trnted/texts` for `Iwslt17`. 
 36 | Once downloaded, specify the path as a parameter in the Python command by adding `dataset.data_dir=[your path]` (you can also set it manually in `conf/config.yaml`). 
 37 | 
 38 | ## Table 1 and Table 5 Experiments - Parallel Decoding Algorithms 
 39 | To reproduce the results in Table 1, run the following command: 
 40 | ``` 
 41 | /bin/bash ./exp/tab1.sh 
 42 | ``` 
 43 | Before running, update the `result_dir` path in `tab1.sh` or in the config file `conf/config.yaml`. 
 44 | 
 45 | ## Table 2 and Table 6 Experiments - Cross Languages 
 46 | To reproduce the results in Table 2, run the following command: 
 47 | ``` 
 48 | /bin/bash ./exp/tab2.sh 
 49 | ``` 
 50 | Before running, update the `result_dir` path in `tab2.sh` or in the config file `conf/config.yaml`. 
 51 | 
 52 | ## Figure 3 and Figure 5 - Scaling Experiments 
 53 | To reproduce the scaling experiments, you need a Google Cloud `c2d-standard-XX` machine, where XX is the number of cores. Then run the command given in the "Reproduce the results" section of this README. 
 54 | To ease the process, we provide the command to launch the virtual machine with the gcloud CLI. 
 55 | 
 56 | ``` 
 57 | gcloud compute instances create instance-1 --zone=us-central1-a --machine-type=c2d-standard-8 --network-interface=network-tier=PREMIUM,subnet=default --maintenance-policy=MIGRATE --provisioning-model=STANDARD --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --create-disk=boot=yes,device-name=instance-1,image=projects/ubuntu-os-cloud/global/images/ubuntu-2004-focal-v20221015,mode=rw,size=30 --no-shielded-secure-boot --shielded-vtpm --shielded-integrity-monitoring --reservation-affinity=any 
 58 | ``` 
 59 | ## Table 3 Experiments - FLOPs calculator 
 60 | 
 61 | To reproduce the FLOPs calculation in Table 3, simply run the script: 
 62 | ``` 
 63 | python3 ./exp/flops_calculator.py 
 64 | ``` 
 65 | 
 66 | ## Dependency Graph Visualizer (DDGviz) 
 67 | 
 68 | To run the Dependency Graph Visualizer (DDGviz), execute the command: 
 69 | ``` 
 70 | PYTHONPATH=. python3 ./src/viz/visualize.py 
 71 | ``` 
 72 | You can select the examples to visualize with the parameter `--examples [list of ids in the dataset]`. 
The dataset and source/target languages can be selected with the corresponding arguments; use `--help` for more info. 
 73 | The output DDGviz visualization will be saved in `iteration_matrix//images`. 
 74 | 
 75 | 
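For instance, to plot the graphs for the first two sentences of the dataset you could run something like the following; the exact argument format is an assumption — check `--help` for the supported syntax:
```
PYTHONPATH=. python3 ./src/viz/visualize.py --examples 0 1
```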
76 | <img src="assets/ddg.png" alt="drawing"/> 
78 | 79 | 80 | ## Citation 81 | 82 | If you use this code please cite: 83 | 84 | ```bibtex 85 | @inproceedings{santilli-etal-2023-accelerating, 86 | title = "Accelerating Transformer Inference for Translation via Parallel Decoding", 87 | author = "Santilli, Andrea and 88 | Severino, Silvio and 89 | Postolache, Emilian and 90 | Maiorca, Valentino and 91 | Mancusi, Michele and 92 | Marin, Riccardo and 93 | Rodola, Emanuele", 94 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 95 | month = jul, 96 | year = "2023", 97 | address = "Toronto, Canada", 98 | publisher = "Association for Computational Linguistics", 99 | url = "https://aclanthology.org/2023.acl-long.689", 100 | pages = "12336--12355", 101 | abstract = "Autoregressive decoding limits the efficiency of transformers for Machine Translation (MT). The community proposed specific network architectures and learning-based methods to solve this issue, which are expensive and require changes to the MT model, trading inference speed at the cost of the translation quality. In this paper, we propose to address the problem from the point of view of decoding algorithms, as a less explored but rather compelling direction. We propose to reframe the standard greedy autoregressive decoding of MT with a parallel formulation leveraging Jacobi and Gauss-Seidel fixed-point iteration methods for fast inference. This formulation allows to speed up existing models without training or modifications while retaining translation quality. We present three parallel decoding algorithms and test them on different languages and models showing how the parallelization introduces a speedup up to 38{\%} w.r.t. the standard autoregressive decoding and nearly 2x when scaling the method on parallel resources. 
Finally, we introduce a decoding dependency graph visualizer (DDGviz) that let us see how the model has learned the conditional dependence between tokens and inspect the decoding procedure.", 102 | } 103 | ``` 104 | -------------------------------------------------------------------------------- /src/ipi/decoders/jacobi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.ipi.decoders.mt_decoding import MTDecoder 4 | from src.viz.dependecy_graph import DecodingDependencyGraph 5 | 6 | 7 | class JacobiDecoder(MTDecoder): 8 | def __init__(self, tokenizer, model, initializer, **kwargs): 9 | super().__init__(tokenizer, model, initializer, **kwargs) 10 | 11 | self.name = "jacobi" 12 | self.acronym = "j" 13 | 14 | def generate_target( 15 | self, 16 | input_ids: torch.Tensor, 17 | attention_mask: torch.Tensor, 18 | is_mbart: bool, 19 | decoding_method: str = "greedy", 20 | remove_padding: bool = False, 21 | ): 22 | if decoding_method == "greedy": 23 | if is_mbart: 24 | with self.tokenizer.as_target_tokenizer(): 25 | gold_output = self.model.generate( 26 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 27 | num_beams=1, 28 | do_sample=False, 29 | forced_bos_token_id=self.tokenizer.cur_lang_code_id, 30 | ) 31 | else: 32 | gold_output = self.model.generate( 33 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 34 | num_beams=1, 35 | do_sample=False, 36 | ) 37 | else: 38 | raise NotImplementedError() 39 | 40 | if remove_padding: 41 | sample_lengths = (gold_output != self.tokenizer.pad_token_id).sum(dim=1) 42 | gold_output = [ 43 | sample[:length] for sample, length in zip(gold_output, sample_lengths) 44 | ] 45 | 46 | return gold_output 47 | 48 | 49 | @torch.no_grad() 50 | def decode( 51 | self, 52 | input_ids, 53 | attention_mask, 54 | target_len=None, 55 | gold_target=None, 56 | init_tensor=None, 57 | compute_ddg: bool = False, 58 | logits_preprocessor=None, 59 | *args, 60 | **kwargs 61 | ): 62 | max_length = target_len 63 | str_index = 0 64 | key_cache = 0 65 | if compute_ddg: 66 | if gold_target is None: 67 | gold_target = self.generate_target( 68 | input_ids=input_ids, 69 | attention_mask=attention_mask, 70 | is_mbart=self.is_mbart, 71 | ) 72 | ddg = DecodingDependencyGraph( 73 | model=self.model, tokenizer=self.tokenizer, gold_target=gold_target 74 | ) 75 | 76 | max_length = gold_target.shape[-1] 77 | 78 | if init_tensor is None: 79 | init_tensor = torch.tensor( 80 | [self.pad_token_id] * input_ids.size(0) * max_length, 81 | device=self.device, 82 | ).reshape(input_ids.size(0), max_length) 83 | elif self.is_mbart: 84 | if init_tensor.shape[0] == 1: 85 | decoder_input_ids = init_tensor[:, 0].unsqueeze(0) 86 | else: 87 | decoder_input_ids = init_tensor[:, 0].unsqueeze(-1) 88 | output = self.model( 89 | input_ids, 90 | attention_mask, 91 | decoder_input_ids=decoder_input_ids, 92 | use_cache=True, 93 | ) 94 | encoder_last_hidden_state = output.encoder_last_hidden_state 95 | past_key_values = output.past_key_values 96 | str_index = 1 97 | total_res = init_tensor 98 | init_tensor = init_tensor[:, 1:] 99 | key_cache = 1 100 | 101 | output_probs = init_tensor.clone().float() 102 | 103 | for index in range(str_index, max_length): 104 | if self.use_cache and index > 0: 105 | old_init_tensor = total_res.detach().clone() 106 | init_tensor = total_res[:, index:] 107 | output = self.model( 108 | input_ids, 109 | attention_mask, 110 | decoder_input_ids=init_tensor, 111 | encoder_outputs=(encoder_last_hidden_state, 
None, None), 112 | use_cache=True, 113 | past_key_values=self.limit_past_key_values(past_key_values, index + key_cache), 114 | ) 115 | else: 116 | old_init_tensor = init_tensor.detach().clone() 117 | output = self.model( 118 | input_ids, 119 | attention_mask, 120 | decoder_input_ids=init_tensor, 121 | use_cache=True, 122 | ) 123 | encoder_last_hidden_state = output.encoder_last_hidden_state 124 | past_key_values = output.past_key_values 125 | logits = output.logits 126 | max_index = torch.argmax(logits, dim=-1) 127 | max_value, max_i = torch.max(torch.softmax(logits, dim=-1), dim=-1) 128 | if index > 0 and logits_preprocessor is not None: 129 | logits_new = logits_preprocessor(total_res[:, : index + 1], logits[:, 0, :]) 130 | max_value_new = torch.argmax(logits_new, dim=-1) 131 | max_index[:, 0] = max_value_new 132 | if self.use_cache and index > 0: 133 | init_tensor = max_index 134 | total_res = torch.cat( 135 | (total_res[:, : index + 1], init_tensor[:, :-1]), dim=1 136 | ) 137 | else: 138 | init_tensor[:, index + 1 :] = max_index[:, index:-1] 139 | total_res = init_tensor 140 | 141 | output_probs[:, index + 1 :] = max_value[:, index:-1] 142 | 143 | stop_condition, return_tensor = self.stopping_criterion( 144 | old_init_tensor, total_res 145 | ) 146 | 147 | if compute_ddg: 148 | ddg.insert_one_element( 149 | total_res, gold_target, output_probs=output_probs 150 | ) 151 | 152 | if stop_condition: 153 | break 154 | 155 | if compute_ddg: 156 | return return_tensor, index, ddg 157 | 158 | return return_tensor, index 159 | 160 | 161 | def initialize(self, init_transl): 162 | if self.initializer is not None: 163 | init_tensor, _ = self.initializer.init_translation(init_transl.shape[-1]) 164 | else: 165 | init_tensor = None 166 | 167 | return init_tensor 168 | 169 | def compute_decode_kwargs(self, input_ids, attention_mask, **kwargs): 170 | 171 | gold_autoregressive = self.generate_gold_autoregressive(input_ids, attention_mask) 172 | init_tensor = self.initialize(init_transl=gold_autoregressive) 173 | logits_preprocessor = self.generate_logits_preprocessor(input_ids) 174 | 175 | return { 176 | "init_tensor": init_tensor.clone(), 177 | "target_len": gold_autoregressive.shape[-1], 178 | "gold_target": gold_autoregressive, 179 | "logits_preprocessor": logits_preprocessor 180 | } 181 | -------------------------------------------------------------------------------- /src/viz/dependecy_graph.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import plotly.express as px 5 | import torch 6 | from transformers import MBart50Tokenizer 7 | 8 | 9 | class DecodingDependencyGraph(object): 10 | def __init__(self, model, tokenizer, gold_target: torch.Tensor): 11 | self.model = model 12 | self.tokenizer = tokenizer 13 | self.init_matrix = [] 14 | self.output_probs = [] 15 | self.gold_target = gold_target 16 | 17 | if isinstance(tokenizer, MBart50Tokenizer): 18 | self.is_mbart = True 19 | self.starting_index = 2 20 | else: 21 | self.is_mbart = False 22 | self.starting_index = 1 23 | 24 | def insert_one_element( 25 | self, current_tensor, gold_tensor, index=None, output_probs=None 26 | ): 27 | self.init_matrix.append( 28 | ( 29 | current_tensor[:, self.starting_index :] 30 | == gold_tensor[:, self.starting_index :] 31 | ) 32 | ) 33 | if output_probs is not None: 34 | self.output_probs.append(output_probs) 35 | 36 | def finalize_matrices(self) -> Tuple[torch.Tensor, torch.Tensor]: 37 | return ( 38 | 
torch.stack(self.init_matrix).permute(1, 0, 2).cpu(), 39 | torch.stack(self.output_probs).permute(1, 0, 2).cpu(), 40 | ) 41 | 42 | def _create_labels(self, sample_index: int, method): 43 | sample_target = self.gold_target[sample_index, :].squeeze(0) 44 | 45 | if method == "decoded_ids": 46 | labels = self.tokenizer.convert_ids_to_tokens(sample_target) 47 | labels = [ 48 | f"{i}:{id}" 49 | for i, id in zip( 50 | labels[self.starting_index - 1 :], 51 | sample_target[self.starting_index - 1 :], 52 | ) 53 | ] 54 | elif method == "basic": 55 | labels = [f"{i}" for i in sample_target[self.starting_index - 1 :].tolist()] 56 | 57 | return labels 58 | 59 | def pretty_print( 60 | self, sample_index: int, sentence_id: str, method="basic", x_remap="text" 61 | ): 62 | labels = self._create_labels(sample_index=sample_index, method=method) 63 | iteration_matrix, probability_matrix = self.finalize_matrices() 64 | iteration_matrix = iteration_matrix[sample_index, :, :].int().numpy() 65 | probability_matrix = probability_matrix[sample_index, :, 1:].numpy() 66 | 67 | mask: np.ndarray = np.zeros_like(iteration_matrix) 68 | i, j = 0, 0 69 | while i < iteration_matrix.shape[0] and j < iteration_matrix.shape[1]: 70 | if iteration_matrix[i, j]: 71 | mask[i, j] += 1 72 | j += 1 73 | else: 74 | mask[i, j:] = iteration_matrix[i, j:] 75 | i += 1 76 | 77 | probability_matrix = mask * probability_matrix 78 | 79 | fig = px.imshow( 80 | iteration_matrix + mask, 81 | binary_compression_level=0, 82 | title=f"Decoding Dependency Graph for sentence {sentence_id}", 83 | color_continuous_scale="Viridis", 84 | ) 85 | 86 | fig.update_xaxes( 87 | tickmode="array", 88 | tickvals=list(range(len(labels[1:]))), 89 | ticktext=[ 90 | self.tokenizer.convert_ids_to_tokens([x])[0] if x_remap else str(x) 91 | for x in labels[1:] 92 | ], 93 | tickangle=45, 94 | ) 95 | 96 | fig.update_traces( 97 | text=[ 98 | [f"{xy:.2f}" if xy > 0 else "" for xy in x] for x in probability_matrix 99 | ], 100 | texttemplate="%{text}", 101 | ) 102 | 103 | fig.update_layout( 104 | font=dict(family="Courier New, monospace", size=22, color="Black"), 105 | showlegend=False, 106 | coloraxis_showscale=False, 107 | ) 108 | 109 | fig.show() 110 | 111 | def plot_confusion_matrix( 112 | self, cm, target_names, title="Confusion matrix", cmap=None, normalize=True 113 | ): 114 | """ 115 | given a sklearn confusion matrix (cm), make a nice plot 116 | 117 | Arguments 118 | --------- 119 | cm: confusion matrix from sklearn.metrics.confusion_matrix 120 | 121 | target_names: given classification classes such as [0, 1, 2] 122 | the class names, for example: ['high', 'medium', 'low'] 123 | 124 | title: the text to display at the top of the matrix 125 | 126 | cmap: the gradient of the values displayed from matplotlib.pyplot.cm 127 | see http://matplotlib.org/examples/color/colormaps_reference.html 128 | plt.get_cmap('jet') or plt.cm.Blues 129 | 130 | normalize: If False, plot the raw numbers 131 | If True, plot the proportions 132 | 133 | Usage 134 | ----- 135 | plot_confusion_matrix(cm = cm, # confusion matrix created by 136 | # sklearn.metrics.confusion_matrix 137 | normalize = True, # show proportions 138 | target_names = y_labels_vals, # list of names of the classes 139 | title = best_estimator_name) # title of graph 140 | 141 | Citiation 142 | --------- 143 | http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html 144 | 145 | """ 146 | import itertools 147 | 148 | import matplotlib.pyplot as plt 149 | import numpy as np 150 | 151 | accuracy = 
np.trace(cm) / np.sum(cm).astype("float") 152 | misclass = 1 - accuracy 153 | 154 | if cmap is None: 155 | cmap = plt.get_cmap("Blues") 156 | 157 | plt.figure(figsize=(8, 6)) 158 | plt.imshow(cm, interpolation="nearest", cmap=cmap) 159 | plt.title(title) 160 | plt.colorbar() 161 | 162 | if target_names is not None: 163 | tick_marks = np.arange(len(target_names)) 164 | plt.xticks(tick_marks, target_names, rotation=45) 165 | plt.yticks(tick_marks, target_names) 166 | 167 | if normalize: 168 | cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] 169 | 170 | thresh = cm.max() / 1.5 if normalize else cm.max() / 2 171 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 172 | if normalize: 173 | plt.text( 174 | j, 175 | i, 176 | "{:0.4f}".format(cm[i, j]), 177 | horizontalalignment="center", 178 | color="white" if cm[i, j] > thresh else "black", 179 | ) 180 | else: 181 | plt.text( 182 | j, 183 | i, 184 | "{:,}".format(cm[i, j]), 185 | horizontalalignment="center", 186 | color="white" if cm[i, j] > thresh else "black", 187 | ) 188 | 189 | plt.tight_layout() 190 | plt.ylabel("True label") 191 | plt.xlabel( 192 | "Predicted label\naccuracy={:0.4f}; misclass={:0.4f}".format( 193 | accuracy, misclass 194 | ) 195 | ) 196 | plt.show() 197 | -------------------------------------------------------------------------------- /src/ipi/decoders/hybrid_jacobi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from more_itertools import sliced 3 | 4 | from src.ipi.decoders.mt_decoding import MTDecoder 5 | 6 | 7 | class HybridJacobiDecoder(MTDecoder): 8 | def __init__(self, tokenizer, model, gs_jaco_blocks, init_mode, initializer, percent=300, **kwargs): 9 | super().__init__(tokenizer, model, initializer, **kwargs) 10 | 11 | self.name = "hybrid_jacobi" 12 | self.acronym = "h" 13 | 14 | self.gs_jaco_blocks = gs_jaco_blocks 15 | 16 | self.init_mode = init_mode 17 | self.percent = percent 18 | 19 | @torch.no_grad() 20 | def decode( 21 | self, input_ids, attention_mask, target_len, gold_target, init_tensor=None, logits_preprocessor=None, *args, **kwargs 22 | ): 23 | key_cache = 1 24 | if init_tensor is None: 25 | init_tensor = torch.tensor( 26 | [self.pad_token_id] * target_len, device=self.device 27 | ) 28 | blocks = list(sliced(init_tensor, self.gs_jaco_blocks)) 29 | init_tensor = init_tensor.unsqueeze(0) 30 | total_past_key_values = None 31 | elif self.is_mbart: 32 | output = self.model( 33 | input_ids, 34 | attention_mask, 35 | decoder_input_ids=init_tensor[:, 0].unsqueeze(0), 36 | use_cache=True, 37 | ) 38 | encoder_last_hidden_state = output.encoder_last_hidden_state 39 | total_past_key_values = output.past_key_values 40 | init_tensor = init_tensor[:, 1:] 41 | blocks = list(sliced(init_tensor.squeeze(0), self.gs_jaco_blocks)) 42 | key_cache = 2 43 | else: 44 | init_tensor = init_tensor 45 | blocks = list(sliced(init_tensor.squeeze(0), self.gs_jaco_blocks)) 46 | total_past_key_values = None 47 | 48 | iteration_saved = 0 49 | base_value = 0 50 | 51 | for blocco in blocks: 52 | max_len = blocco.shape[-1] 53 | blocco_usr = init_tensor[:, base_value : base_value + max_len] 54 | for index in range(max_len): 55 | old_blocco = blocco_usr.detach().clone() 56 | trig = self.trig_eos( 57 | old_blocco, self.eos_token_id, init_tensor, base_value 58 | ) 59 | if trig is not None: 60 | return trig, (gold_target.shape[-1] - 1) - iteration_saved 61 | blocco_usr_new = blocco_usr[:, index:] 62 | if base_value == 0 and index == 0 and not self.is_mbart: 63 
| output = self.model( 64 | input_ids, 65 | attention_mask, 66 | decoder_input_ids=blocco_usr_new, 67 | use_cache=True, 68 | past_key_values=total_past_key_values, 69 | ) 70 | encoder_last_hidden_state = output.encoder_last_hidden_state 71 | else: 72 | output = self.model( 73 | input_ids, 74 | attention_mask, 75 | decoder_input_ids=blocco_usr_new, 76 | encoder_outputs=(encoder_last_hidden_state, None, None), 77 | use_cache=True, 78 | past_key_values=total_past_key_values, 79 | ) 80 | 81 | total_past_key_values = self.limit_past_key_values( 82 | output.past_key_values, 83 | base_value + index + key_cache, 84 | ) 85 | 86 | logits = output.logits 87 | max_value = torch.argmax(logits, dim=-1) 88 | 89 | if logits_preprocessor is not None: 90 | logits_new = logits_preprocessor(init_tensor[:, :base_value + index + 1], logits[:, 0, :]) 91 | max_value_new = torch.argmax(logits_new, dim=-1) 92 | max_value[:,0] = max_value_new 93 | 94 | if ( 95 | max_value.shape[-1] 96 | == init_tensor[ 97 | :, base_value + index + 1 : base_value + max_len + 1 98 | ].shape[-1] 99 | ): 100 | init_tensor[ 101 | :, base_value + index + 1 : base_value + max_len + 1 102 | ] = max_value[:, :] 103 | else: 104 | # If last block remove the last token after EOS 105 | init_tensor[ 106 | :, base_value + index + 1 : base_value + max_len + 1 107 | ] = max_value[:, :-1] 108 | 109 | stop_condition, _, eos_cond = self.stopping_criterion( 110 | old_blocco, blocco_usr, eos=self.eos_token_id 111 | ) 112 | 113 | if stop_condition: 114 | if eos_cond >= 0: 115 | return ( 116 | init_tensor[:, : base_value + eos_cond + 1], 117 | (gold_target.shape[-1] - 1) - iteration_saved, 118 | ) 119 | if index + 1 != max_len: 120 | iteration_saved += max_len - index - 1 121 | total_past_key_values = self.limit_past_key_values( 122 | output.past_key_values, 123 | base_value + max_len + 1, 124 | ) 125 | break 126 | 127 | base_value += max_len 128 | 129 | total_res, total_iter = ( 130 | init_tensor, 131 | (gold_target.shape[-1] - 1) - iteration_saved, 132 | ) 133 | 134 | init_tensor = init_tensor[:, -1].clone().unsqueeze(0) 135 | 136 | #Go autoregressive until [EOS] 137 | while True and base_value != self.model.config.max_length - 1: 138 | index = 0 139 | output = self.model( 140 | input_ids, 141 | attention_mask, 142 | decoder_input_ids=init_tensor, 143 | encoder_outputs=(encoder_last_hidden_state, None, None), 144 | use_cache=True, 145 | past_key_values=total_past_key_values, 146 | ) 147 | encoder_last_hidden_state = output.encoder_last_hidden_state 148 | total_past_key_values = output.past_key_values 149 | logits = output.logits 150 | max_value = torch.argmax(logits, dim=-1) 151 | last = max_value[:, -1] 152 | if self.use_cache: 153 | init_tensor = last.unsqueeze(0) 154 | total_res = torch.cat((total_res, init_tensor), dim=1) 155 | 156 | index += 1 157 | if last[0].item() == self.eos_token_id: 158 | break 159 | return total_res, index + total_iter 160 | 161 | def initialize(self, input_ids, gold_autoregressive): 162 | if self.initializer is not None: 163 | if self.init_mode == "under": 164 | len = max(3, input_ids.shape[-1] - self.percent / 100 * input_ids.shape[-1]) 165 | m = int(len) 166 | elif self.init_mode == "over": 167 | len = input_ids.shape[-1] + self.percent / 100 * input_ids.shape[-1] 168 | m = int(len) 169 | elif self.init_mode == "fixed": 170 | m = 511 171 | else: 172 | m = gold_autoregressive.shape[-1] 173 | init_tensor, _ = self.initializer.init_translation(m) 174 | else: 175 | init_tensor = None 176 | 177 | return init_tensor 178 | 179 | 
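# Illustration of the init-length heuristic in `initialize` above, with hypothetical numbers: for a 10-token source and percent=300, init_mode="under" gives m = max(3, 10 - 3.0 * 10) = 3, init_mode="over" gives m = int(10 + 3.0 * 10) = 40, and init_mode="fixed" always uses m = 511; any other value falls back to the length of the autoregressive (gold) translation.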
def compute_decode_kwargs(self, input_ids, attention_mask, **kwargs): 180 | gold_autoregressive = self.generate_gold_autoregressive(input_ids, attention_mask) 181 | init_tensor = self.initialize(input_ids, gold_autoregressive) 182 | logits_preprocessor = self.generate_logits_preprocessor(input_ids) 183 | 184 | return { 185 | "init_tensor": init_tensor.clone(), 186 | "gold_target": gold_autoregressive, 187 | "target_len": gold_autoregressive.shape[-1], 188 | "logits_preprocessor": logits_preprocessor 189 | } -------------------------------------------------------------------------------- /src/dataset/iwslt_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as t 3 | 4 | import datasets 5 | import torch 6 | from datasets.utils.download_manager import DownloadManager 7 | from torch.utils.data.dataset import Dataset 8 | 9 | from src.utils.utils import clean_text 10 | 11 | 12 | class Iwslt(Dataset): 13 | def __init__( 14 | self, 15 | version: str = "17", 16 | src_lan: str = "en", 17 | tgt_lan: str = "ro", 18 | data_dir: str = None, 19 | hugginface_tokenizer=None, 20 | split: str = None, 21 | ): 22 | self.version = version 23 | self.src_lan = src_lan 24 | self.tgt_lan = tgt_lan 25 | self.max_length = 511 26 | 27 | self.dl = DownloadManager() 28 | 29 | self.name = f"iwslt{self.version}" 30 | 31 | self.version2folder = { 32 | "15": os.path.join(data_dir, "2015-01/texts"), 33 | "17": os.path.join(data_dir, "2017-01-trnted/texts"), 34 | } 35 | self.version2years = { 36 | "15": {"train_and_test": [2010, 2011, 2012, 2013], "dev": [2010]}, 37 | "17": { 38 | "train_and_test": [2010, 2011, 2012, 2013, 2014, 2015], 39 | "dev": [2010], 40 | }, 41 | } 42 | 43 | data_file = f"{self.version2folder[version]}/{src_lan}/{tgt_lan}/{src_lan}-{tgt_lan}.tgz" 44 | 45 | splitted_generators = self._split_generators(data_file) 46 | self.translation_dataset = self.load_dataset(splitted_generators, split=split) 47 | 48 | with torch.no_grad(): 49 | self.tokenizer = hugginface_tokenizer 50 | 51 | def load_dataset( 52 | self, 53 | splitted_generators: t.List[datasets.SplitGenerator], 54 | split: str, 55 | ) -> t.List[t.Dict]: 56 | splitted_generators = self.concat_dataset(splitted_generators, split) 57 | 58 | return list( 59 | self._generate_examples( 60 | source_files=splitted_generators.gen_kwargs["source_files"], 61 | target_files=splitted_generators.gen_kwargs["target_files"], 62 | ) 63 | ) 64 | 65 | @staticmethod 66 | def concat_dataset( 67 | splitted_generators: t.List[datasets.SplitGenerator], 68 | split: str, 69 | ) -> datasets.SplitGenerator: 70 | split2ix = {"train": 0, "test": 1, "validation": 2} 71 | assert ( 72 | split in split2ix 73 | ), "Iwslt: split must be one of train, test or validation" 74 | if split is not None: 75 | return splitted_generators[split2ix[split]] 76 | 77 | def _split_generators(self, data_file: str) -> t.List[datasets.SplitGenerator]: 78 | """Returns SplitGenerators.""" 79 | pair = f"{self.src_lan}-{self.tgt_lan}" 80 | dl_dir = self.dl.extract(data_file) 81 | data_dir = os.path.join(dl_dir, f"{self.src_lan}-{self.tgt_lan}") 82 | 83 | years = self.version2years[self.version]["train_and_test"] 84 | dev = self.version2years[self.version]["dev"] 85 | 86 | return [ 87 | datasets.SplitGenerator( 88 | name=datasets.Split.TRAIN, 89 | # These kwargs will be passed to _generate_examples 90 | gen_kwargs={ 91 | "source_files": [ 92 | os.path.join( 93 | data_dir, 94 | f"train.tags.{pair}.{self.src_lan}", 95 | ) 96 | ], 97 |
"target_files": [ 98 | os.path.join( 99 | data_dir, 100 | f"train.tags.{pair}.{self.tgt_lan}", 101 | ) 102 | ], 103 | "split": "train", 104 | }, 105 | ), 106 | datasets.SplitGenerator( 107 | name=datasets.Split.TEST, 108 | # These kwargs will be passed to _generate_examples 109 | gen_kwargs={ 110 | "source_files": [ 111 | os.path.join( 112 | data_dir, 113 | f"IWSLT{self.version}.TED.tst{year}.{pair}.{self.src_lan}.xml", 114 | ) 115 | for year in years 116 | ], 117 | "target_files": [ 118 | os.path.join( 119 | data_dir, 120 | f"IWSLT{self.version}.TED.tst{year}.{pair}.{self.tgt_lan}.xml", 121 | ) 122 | for year in years 123 | ], 124 | "split": "test", 125 | }, 126 | ), 127 | datasets.SplitGenerator( 128 | name=datasets.Split.VALIDATION, 129 | # These kwargs will be passed to _generate_examples 130 | gen_kwargs={ 131 | "source_files": [ 132 | os.path.join( 133 | data_dir, 134 | f"IWSLT{self.version}.TED.dev{year}.{pair}.{self.src_lan}.xml", 135 | ) 136 | for year in dev 137 | ], 138 | "target_files": [ 139 | os.path.join( 140 | data_dir, 141 | f"IWSLT{self.version}.TED.dev{year}.{pair}.{self.tgt_lan}.xml", 142 | ) 143 | for year in dev 144 | ], 145 | "split": "validation", 146 | }, 147 | ), 148 | ] 149 | 150 | def _generate_examples( 151 | self, source_files: t.List[str], target_files: t.List[str] 152 | ) -> t.List[t.Dict]: 153 | """Yields examples.""" 154 | for source_file, target_file in zip(source_files, target_files): 155 | with open(source_file, "r", encoding="utf-8") as sf: 156 | with open(target_file, "r", encoding="utf-8") as tf: 157 | for source_row, target_row in zip(sf, tf): 158 | source_row = source_row.strip() 159 | target_row = target_row.strip() 160 | 161 | if source_row.startswith("<"): 162 | if source_row.startswith("<seg"): 163 | # <seg id="1">.....</seg> 164 | # Very simple code instead of regex or xml parsing 165 | part1 = source_row.split(">")[1] 166 | source_row = part1.split("<")[0] 167 | part1 = target_row.split(">")[1] 168 | target_row = part1.split("<")[0] 169 | 170 | source_row = source_row.strip() 171 | target_row = target_row.strip() 172 | else: 173 | continue 174 | 175 | yield { 176 | "translation": { 177 | self.src_lan: source_row, 178 | self.tgt_lan: target_row, 179 | } 180 | } 181 | 182 | def collate_fn(self, batch): 183 | 184 | batch_source = [b[0] for b in batch] 185 | batch_target = [b[1] for b in batch] 186 | 187 | encoded_source = self.tokenizer( 188 | batch_source, 189 | padding=True, 190 | return_tensors="pt", 191 | ) 192 | encoded_target = self.tokenizer( 193 | batch_target, 194 | padding=True, 195 | return_tensors="pt", 196 | ) 197 | 198 | return { 199 | "source": { 200 | "input_ids": encoded_source["input_ids"], 201 | "attention_mask": encoded_source["attention_mask"], 202 | "sentences": batch_source, 203 | }, 204 | "target": { 205 | "input_ids": encoded_target["input_ids"], 206 | "attention_mask": encoded_target["attention_mask"], 207 | "sentences": batch_target, 208 | }, 209 | } 210 | 211 | def __len__(self) -> int: 212 | return len(self.translation_dataset) 213 | 214 | def __getitem__(self, idx: int) -> t.Tuple[str, str]: 215 | sample = self.translation_dataset[idx] 216 | source = sample["translation"][self.src_lan] 217 | target = sample["translation"][self.tgt_lan] 218 | 219 | return source, target 220 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from transformers import ( 3 | AutoModelForSeq2SeqLM, 4 | AutoTokenizer, 5 |
MBartForConditionalGeneration, 6 | ) 7 | 8 | from src.bench import MTBenchmarker 9 | from src.dataset.flores_dataset import Flores 10 | from src.dataset.ittb_dataset import Ittb 11 | from src.dataset.iwslt_dataset import Iwslt 12 | from src.dataset.wmt_dataset import Wmt 13 | from src.ipi.decoders.autoregressive import AutoregressiveDecoder 14 | from src.ipi.decoders.beam_search import BeamSearchDecoder 15 | from src.ipi.decoders.gs_jacobi import GSJacobiDecoder 16 | from src.ipi.decoders.hybrid_jacobi import HybridJacobiDecoder 17 | from src.ipi.decoders.jacobi import JacobiDecoder 18 | from src.ipi.initializer import Initializer 19 | from src.ipi.decoders.mt_decoding import MTDecoder 20 | from src.utils.beam_search import BeamSearcher 21 | from src.utils.utils import retrieve_samples 22 | 23 | 24 | def load_tokenizer(cfg): 25 | # MBart 26 | mapping_dict = { 27 | "ar": "ar_AR", 28 | "cs": "cs_CZ", 29 | "de": "de_DE", 30 | "en": "en_XX", 31 | "es": "es_XX", 32 | "et": "et_EE", 33 | "fi": "fi_FI", 34 | "fr": "fr_XX", 35 | "gu": "gu_IN", 36 | "hi": "hi_IN", 37 | "it": "it_IT", 38 | "ja": "ja_XX", 39 | "kk": "kk_KZ", 40 | "ko": "ko_KR", 41 | "lt": "lt_LT", 42 | "lv": "lv_LV", 43 | "my": "my_MM", 44 | "ne": "ne_NP", 45 | "nl": "nl_XX", 46 | "ro": "ro_RO", 47 | "ru": "ru_RU", 48 | "si": "si_LK", 49 | "tr": "tr_TR", 50 | "vi": "vi_VN", 51 | "zh": "zh_CN", 52 | "af": "af_ZA", 53 | "az": "az_AZ", 54 | "bn": "bn_IN", 55 | "fa": "fa_IR", 56 | "he": "he_IL", 57 | "hr": "hr_HR", 58 | "id": "id_ID", 59 | "ka": "ka_GE", 60 | "km": "km_KH", 61 | "mk": "mk_MK", 62 | "ml": "ml_IN", 63 | "mn": "mn_MN", 64 | "mr": "mr_IN", 65 | "pl": "pl_PL", 66 | "ps": "ps_AF", 67 | "pt": "pt_XX", 68 | "sv": "sv_SE", 69 | "sw": "sw_KE", 70 | "ta": "ta_IN", 71 | "te": "te_IN", 72 | "th": "th_TH", 73 | "tl": "tl_XX", 74 | "uk": "uk_UA", 75 | "ur": "ur_PK", 76 | "xh": "xh_ZA", 77 | "gl": "gl_ES", 78 | "sl": "sl_SI", 79 | } 80 | 81 | if "mbart" in cfg.model_name: 82 | tokenizer = AutoTokenizer.from_pretrained( 83 | cfg.model_name, 84 | src_lang=mapping_dict[cfg.src_lang], 85 | tgt_lang=mapping_dict[cfg.tgt_lang], 86 | use_fast=False, 87 | ) 88 | else: 89 | print(cfg.model_name) 90 | tokenizer = AutoTokenizer.from_pretrained(cfg.model_name) 91 | 92 | return tokenizer 93 | 94 | 95 | def load_model(cfg): 96 | if "mbart" in cfg.model_name: 97 | model = MBartForConditionalGeneration.from_pretrained(cfg.model_name).to( 98 | cfg.device 99 | ) 100 | else: 101 | model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name).to(cfg.device) 102 | 103 | return model 104 | 105 | 106 | def load_dataset(tokenizer, cfg): 107 | # Wmt-xx-xx-xx 108 | if cfg.name == "wmt": 109 | split = cfg.split 110 | if cfg.subset.use_subset: 111 | split = f"{cfg.split}[{cfg.subset.start}:{cfg.subset.end + 1}]" 112 | 113 | dataset = Wmt( 114 | version=cfg.version, 115 | src_lan=cfg.src_lang, 116 | tgt_lan=cfg.tgt_lang, 117 | hugginface_tokenizer=tokenizer, 118 | split=split, 119 | ) 120 | # Iwsltxx-xx-xx 121 | elif cfg.name == "iwslt": 122 | dataset = Iwslt( 123 | data_dir=cfg.data_dir, 124 | version=str(cfg.version), 125 | src_lan=cfg.src_lang, 126 | tgt_lan=cfg.tgt_lang, 127 | hugginface_tokenizer=tokenizer, 128 | split=cfg.split, 129 | ) 130 | elif cfg.name == "ittb": 131 | dataset = Ittb( 132 | src_lan=cfg.src_lang, 133 | tgt_lan=cfg.tgt_lang, 134 | hugginface_tokenizer=tokenizer, 135 | split=cfg.split, 136 | ) 137 | elif cfg.name == "flores": 138 | dataset = Flores( 139 | src_lan=cfg.src_lang, 140 | tgt_lan=cfg.tgt_lang, 141 | hugginface_tokenizer=tokenizer, 
142 | split=cfg.split, 143 | ) 144 | else: 145 | raise ValueError(f"{cfg.dataset.name} is not yet implemented") 146 | 147 | return dataset 148 | 149 | 150 | def load_initializer(tokenizer, cfg): 151 | if cfg.use_initializer: 152 | initializer = Initializer( 153 | src_len=cfg.src_lang, 154 | tgt_len=cfg.tgt_lang, 155 | hugginface_tokenizer=tokenizer, 156 | use_init=cfg.use_init, 157 | device=cfg.device, 158 | ) 159 | 160 | else: 161 | initializer = None 162 | 163 | return initializer 164 | 165 | 166 | def compute_beam_search(cfg, model, dataset): 167 | initializer = load_initializer(dataset.tokenizer, cfg.initializer) 168 | 169 | decoder = MTDecoder( 170 | tokenizer=dataset.tokenizer, 171 | model=model, 172 | use_cache=cfg.decoder.use_cache, 173 | gs_jaco_blocks=cfg.decoder.gs_jaco_blocks, 174 | device=cfg.device, 175 | initializer=initializer 176 | ) 177 | 178 | beam_searcher = BeamSearcher( 179 | model=model, 180 | dataset=dataset, 181 | initializer=initializer, 182 | decoder=decoder, 183 | batch_size=cfg.beam_search.batch_size, 184 | num_beams=cfg.beam_search.num_beams, 185 | device=cfg.beam_search.device, 186 | no_repeat_ngram_size=2, 187 | early_stopping=True, 188 | result_dir=cfg.beam_search.result_dir, 189 | ) 190 | 191 | beam_searcher.compute_beam_search(cfg) 192 | 193 | 194 | def load_decoders(cfg, tokenizer, model, initializer): 195 | decoders = [] 196 | for decoder in cfg.decoder.decoders: 197 | if decoder == "autoregressive": 198 | dec = AutoregressiveDecoder( 199 | tokenizer=tokenizer, 200 | model=model, 201 | initializer=initializer, 202 | use_cache=cfg.decoder.use_cache, 203 | device=cfg.decoder.device 204 | ) 205 | elif decoder == "jacobi": 206 | dec = JacobiDecoder( 207 | tokenizer=tokenizer, 208 | model=model, 209 | initializer=initializer, 210 | use_cache=cfg.decoder.use_cache, 211 | device=cfg.decoder.device 212 | ) 213 | elif decoder == "gs_jacobi": 214 | dec = GSJacobiDecoder( 215 | tokenizer=tokenizer, 216 | model=model, 217 | initializer=initializer, 218 | gs_jaco_blocks=cfg.decoder.gs_jaco_blocks, 219 | init_mode="", 220 | use_cache=cfg.decoder.use_cache, 221 | device=cfg.decoder.device 222 | ) 223 | elif decoder == "h_jacobi": 224 | dec = HybridJacobiDecoder( 225 | tokenizer=tokenizer, 226 | model=model, 227 | initializer=initializer, 228 | init_mode="fixed", 229 | gs_jaco_blocks=cfg.decoder.gs_jaco_blocks, 230 | use_cache=cfg.decoder.use_cache, 231 | device=cfg.decoder.device 232 | ) 233 | elif decoder == "beam_search": 234 | dec = BeamSearchDecoder( 235 | tokenizer=tokenizer, 236 | model=model, 237 | initializer=initializer, 238 | num_beams=cfg.beam_search.num_beams, 239 | early_stopping=True, 240 | ) 241 | else: 242 | raise ValueError(f"{decoder} decoder have not been implemented yet.") 243 | 244 | decoders.append(dec) 245 | 246 | return decoders 247 | 248 | 249 | def compute_benchmark(cfg, tokenizer, dataset, model): 250 | initializer = load_initializer(tokenizer, cfg.initializer) 251 | decoders = load_decoders(cfg, tokenizer, model, initializer) 252 | 253 | benchmarker = MTBenchmarker( 254 | dataset=dataset, 255 | decoders=decoders, 256 | src_lang=cfg.model.src_lang, 257 | tgt_lang=cfg.model.tgt_lang, 258 | result_dir=cfg.bench.result_dir, 259 | device=cfg.bench.device, 260 | debug=True, 261 | ) 262 | benchmarker.compute_total_time() 263 | 264 | 265 | @hydra.main(config_path="conf", config_name="config", version_base="1.1") 266 | def main(cfg): 267 | tokenizer = load_tokenizer(cfg.model) 268 | 269 | model = load_model(cfg.model) 270 | 271 | dataset = 
load_dataset(tokenizer, cfg.dataset) 272 | 273 | if "benchmark" in cfg.task: 274 | compute_benchmark(cfg, tokenizer, dataset, model) 275 | elif "beam_search" in cfg.task: 276 | compute_beam_search(cfg, model, dataset) 277 | elif "sample" in cfg.task: 278 | retrieve_samples(cfg.sample.path, dataset) 279 | else: 280 | raise ValueError(f"{cfg.task} is not yet implemented") 281 | 282 | 283 | if __name__ == "__main__": 284 | main() 285 | -------------------------------------------------------------------------------- /src/viz/visualize.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import hashlib 3 | import json 4 | from pathlib import Path 5 | 6 | import argparse 7 | import datasets 8 | import numpy as np 9 | import plotly.express as px 10 | import torch.utils.data 11 | from datasets import load_from_disk, Dataset 12 | from transformers import ( 13 | AutoModelForSeq2SeqLM, 14 | AutoTokenizer, 15 | PreTrainedTokenizer, 16 | PreTrainedModel, 17 | ) 18 | 19 | from src import PROJECT_ROOT 20 | from src.ipi.decoders.jacobi import JacobiDecoder 21 | from src.ipi.initializer import Initializer 22 | from src.ipi.decoders.mt_decoding import MTDecoder, generate_target 23 | 24 | 25 | def add_iteration_matrix( 26 | sample: dict, src_lang: str, tgt_lang: str, model, tokenizer, initializer, decoder 27 | ): 28 | source = sample["translation"][src_lang] 29 | target = sample["translation"][tgt_lang] 30 | 31 | encoded_source = tokenizer( 32 | source, 33 | return_tensors="pt", 34 | ) 35 | encoded_target = tokenizer( 36 | target, 37 | return_tensors="pt", 38 | ) 39 | 40 | x = { 41 | "source": { 42 | "input_ids": encoded_source["input_ids"], 43 | "attention_mask": encoded_source["attention_mask"], 44 | "sentences": source, 45 | }, 46 | "target": { 47 | "input_ids": encoded_target["input_ids"], 48 | "attention_mask": encoded_target["attention_mask"], 49 | "sentences": target, 50 | }, 51 | } 52 | 53 | input_ids = encoded_source["input_ids"].to(device) 54 | attention_mask = encoded_source["attention_mask"].to(device) 55 | gold_output = generate_target( 56 | model=model, 57 | tokenizer=tokenizer, 58 | input_ids=input_ids.to(device), 59 | attention_mask=attention_mask.to(device), 60 | is_mbart=False, 61 | ) 62 | 63 | init_tensor, _ = initializer.init_translation(gold_output.shape[-1]) 64 | 65 | return_tensor, index, ddg = decoder.decode( 66 | input_ids=input_ids, 67 | attention_mask=attention_mask, 68 | init_tensor=init_tensor, 69 | compute_ddg=True, 70 | ) 71 | 72 | iteration_matrix, probability_matrix = ddg.finalize_matrices() 73 | parallel_steps = iteration_matrix.shape[1] 74 | gold_steps = gold_output.shape[1] 75 | 76 | return dict( 77 | gold_output=gold_output[0], 78 | iteration_matrix=iteration_matrix[0], 79 | probability_matrix=probability_matrix[0], 80 | parallel_steps=parallel_steps, 81 | gold_steps=gold_steps, 82 | score=gold_steps - parallel_steps, 83 | ) 84 | 85 | 86 | def enrich_dataset( 87 | run_info: dict, 88 | device: str, 89 | src_lang: str, 90 | tgt_lang: str, 91 | dataset_name: str, 92 | dataset_key: str, 93 | tokenizer: PreTrainedTokenizer, 94 | model: PreTrainedModel, 95 | force_recompute: bool = False, 96 | ) -> Dataset: 97 | # MarianMT 98 | run_dir: Path = run_info["run_dir"] 99 | 100 | dataset_path: Path = run_dir / "dataset" 101 | 102 | if run_dir.exists() and not force_recompute: 103 | dataset = load_from_disk(str(run_dir / "dataset")) 104 | else: 105 | initializer = Initializer(src_lang, tgt_lang, tokenizer, use_init=False) 106 | 
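# The Jacobi decoder built below (use_cache=False) re-runs the parallel refinement on each test sentence so that DDGviz can record, at every iteration, which target positions already agree with the greedy (gold) output.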
model.eval() 107 | 108 | # decoder = MTDecoder( 109 | # tokenizer=tokenizer, 110 | # model=model, 111 | # use_cache=False, 112 | # gs_jaco_blocks=5, 113 | # device=device, 114 | # initializer=initializer 115 | # ) 116 | 117 | decoder = JacobiDecoder( 118 | tokenizer=tokenizer, 119 | model=model, 120 | initializer=initializer, 121 | use_cache=False, 122 | device=device 123 | ) 124 | 125 | dataset = datasets.load_dataset( 126 | dataset_name, 127 | dataset_key, 128 | split="test[0:10]", 129 | data_dir=str(PROJECT_ROOT / "hf_data"), 130 | cache_dir=str(PROJECT_ROOT / "hf_cache"), 131 | ).map( 132 | function=functools.partial( 133 | add_iteration_matrix, 134 | src_lang=src_lang, 135 | tgt_lang=tgt_lang, 136 | model=model, 137 | initializer=initializer, 138 | decoder=decoder, 139 | tokenizer=tokenizer, 140 | ) 141 | ) 142 | dataset.save_to_disk(str(dataset_path)) 143 | 144 | json.dump( 145 | run_info, 146 | fp=(run_dir / "run_info.json").open("w", encoding="utf-8"), 147 | indent=4, 148 | default=lambda x: str(x) 149 | ) 150 | 151 | return dataset 152 | 153 | 154 | def draw(sample: dict, tokenizer: PreTrainedTokenizer, starting_index: int): 155 | labels = [f"{i}" for i in sample["gold_output"][starting_index - 1 :]] 156 | iteration_matrix = torch.as_tensor(sample["iteration_matrix"]) 157 | probability_matrix = torch.as_tensor(sample["probability_matrix"]) 158 | iteration_matrix = iteration_matrix[:, :].int().numpy() 159 | probability_matrix = probability_matrix[:, 1:].numpy() 160 | 161 | mask: np.ndarray = np.zeros_like(iteration_matrix) 162 | i, j = 0, 0 163 | while i < iteration_matrix.shape[0] and j < iteration_matrix.shape[1]: 164 | if iteration_matrix[i, j]: 165 | mask[i, j] += 1 166 | j += 1 167 | else: 168 | mask[i, j:] = iteration_matrix[i, j:] 169 | i += 1 170 | 171 | probability_matrix = mask * probability_matrix 172 | 173 | fig = px.imshow( 174 | iteration_matrix + mask, 175 | binary_compression_level=0, 176 | # title=f"Decoding Dependency Graph for sentence {sentence_id}", 177 | color_continuous_scale="Viridis", 178 | ) 179 | 180 | fig.update_xaxes( 181 | tickmode="array", 182 | tickvals=list(range(len(labels[1:]))), 183 | ticktext=[tokenizer.convert_ids_to_tokens([x])[0] for x in labels[1:]], 184 | tickangle=45, 185 | ) 186 | 187 | fig.update_traces( 188 | text=[[f"{xy:.2f}" if xy > 0 else "" for xy in x] for x in probability_matrix], 189 | texttemplate="%{text}", 190 | ) 191 | 192 | fig.update_layout( 193 | font=dict(family="Courier New, monospace", size=22, color="Black"), 194 | showlegend=False, 195 | coloraxis_showscale=False, 196 | margin=dict(l=0, r=0, b=0, t=0), 197 | paper_bgcolor="rgba(0,0,0,0)", 198 | plot_bgcolor="rgba(0,0,0,0)", 199 | ) 200 | 201 | return fig 202 | 203 | 204 | if __name__ == "__main__": 205 | 206 | 207 | parser = argparse.ArgumentParser(description='Dependency Graph Visualizer (DDGviz)') 208 | 209 | parser.add_argument('--src', default="ro", help='src language') 210 | parser.add_argument('--tgt', default="en", help='target language') 211 | parser.add_argument('--dataset', default="wmt16", help='Dataset name') 212 | # parser.add_argument('--examples', default=[1566, 960], help='Examples to print with DDGviz') 213 | parser.add_argument('--examples', default=[1, 4], help='Examples to print with DDGviz') 214 | 215 | args = parser.parse_args() 216 | 217 | device = "cpu" 218 | src_lang = args.src 219 | tgt_lang = args.tgt 220 | dataset_name: str = args.dataset 221 | dataset_key: str = f"{src_lang}-{tgt_lang}" 222 | langs = {src_lang, tgt_lang} 223 | 
examples_to_print = args.examples 224 | 225 | 226 | if "en" in langs: 227 | dataset_key = ( 228 | f"{src_lang}-{tgt_lang}" if "en" == tgt_lang else f"{tgt_lang}-{src_lang}" 229 | ) 230 | 231 | model_src_lang = src_lang 232 | if model_src_lang == "ro": 233 | model_src_lang: str = "roa" 234 | 235 | dataset_split: str = "test" 236 | model_name: str = f"Helsinki-NLP/opus-mt-{model_src_lang}-{tgt_lang}" 237 | # model_name: str = "zannabethl/opus-mt-en-ro-finetuned-en-to-ro" 238 | 239 | tokenizer = AutoTokenizer.from_pretrained(model_name) 240 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) 241 | 242 | run_info = dict( 243 | model_name=model_name, 244 | source_lang=src_lang, 245 | target_lang=tgt_lang, 246 | dataset_name=dataset_name, 247 | dataset_key=dataset_key, 248 | split=dataset_split, 249 | ) 250 | 251 | run_info_hash: str = hashlib.md5( 252 | json.dumps(run_info).encode(encoding="utf-8") 253 | ).hexdigest() 254 | run_dir: Path = PROJECT_ROOT / "iteration_matrix" / run_info_hash 255 | 256 | run_info["run_dir"] = run_dir 257 | 258 | dataset = ( 259 | enrich_dataset( 260 | run_info=run_info, 261 | device=device, 262 | src_lang=src_lang, 263 | tgt_lang=tgt_lang, 264 | dataset_name=dataset_name, 265 | dataset_key=dataset_key, 266 | tokenizer=tokenizer, 267 | model=model, 268 | ) 269 | .map(function=lambda x, i: {"index": i}, with_indices=True) 270 | .select(examples_to_print) 271 | ) 272 | 273 | starting_index: int = 1 274 | images_dir: Path = run_dir / "images" 275 | images_dir.mkdir(exist_ok=True) 276 | for sample in dataset: 277 | fig = draw(sample=sample, tokenizer=tokenizer, starting_index=1) 278 | # fig.show() 279 | print(f"Index: {sample['index']}") 280 | print(f"Translations: {sample['translation']}") 281 | print( 282 | f"Gold output: {tokenizer.decode(sample['gold_output'], skip_special_tokens=True)}" 283 | ) 284 | print() 285 | # input() 286 | # continue 287 | fig.write_image(images_dir / f"{sample['index']}.png", width=1800, height=1500) 288 | x = {x: sample[x] for x in ("translation", "gold_output", "index", "score")} 289 | x["gold_output"] = tokenizer.decode(sample["gold_output"]) 290 | 291 | (images_dir / f"{sample['index']}.json").write_text( 292 | json.dumps(x, indent=4, sort_keys=True), encoding="utf-8" 293 | ) 294 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import random 4 | import re 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from torch.nn.modules import linear 10 | import datasets 11 | 12 | def makedirs(path): 13 | if path.endswith((".tsv", ".csv", ".txt")): 14 | path = "/".join(path.split("/")[:-1]) 15 | 16 | if not os.path.exists(path): 17 | os.makedirs(path) 18 | 19 | def check_zero_division(a, b): 20 | return "na" if b == 0 else round(a / b, 3) 21 | 22 | def get_logits_preprocessor(model, input_ids, eos_token_id): 23 | logits_preprocessor = model._get_logits_processor( 24 | repetition_penalty=None, 25 | no_repeat_ngram_size=None, 26 | encoder_no_repeat_ngram_size=None, 27 | input_ids_seq_length=1, 28 | encoder_input_ids=input_ids, 29 | bad_words_ids=None, 30 | min_length=0, 31 | max_length=model.config.max_length, 32 | eos_token_id=eos_token_id, 33 | forced_bos_token_id=None, 34 | forced_eos_token_id=None, 35 | prefix_allowed_tokens_fn=None, 36 | num_beams=1, 37 | num_beam_groups=1, 38 | diversity_penalty=None, 39 | remove_invalid_values=None, 40 
| exponential_decay_length_penalty=None, 41 | logits_processor=[], 42 | renormalize_logits=None 43 | ) 44 | return logits_preprocessor 45 | 46 | 47 | 48 | def retrieve_model_name(model_name): 49 | if "opus" in model_name: 50 | return "opus" 51 | if "mbart" in model_name: 52 | if "50" in model_name: 53 | return "mbart_50" 54 | return "mbart" 55 | 56 | 57 | def seed_everything(seed: int): 58 | random.seed(seed) 59 | os.environ["PYTHONHASHSEED"] = str(seed) 60 | np.random.seed(seed) 61 | torch.manual_seed(seed) 62 | torch.cuda.manual_seed(seed) 63 | torch.backends.cudnn.deterministic = True 64 | torch.backends.cudnn.benchmark = True 65 | 66 | 67 | def retrieve_map_languages_flores(lan): 68 | lang_map = { 69 | "ab": "Abkhaz", 70 | "aa": "Afar", 71 | "af": "Afrikaans", 72 | "ak": "Akan", 73 | "sq": "Albanian", 74 | "am": "Amharic", 75 | "ar": "Arabic", 76 | "an": "Aragonese", 77 | "hy": "Armenian", 78 | "as": "Assamese", 79 | "av": "Avaric", 80 | "ae": "Avestan", 81 | "ay": "Aymara", 82 | "az": "Azerbaijani", 83 | "bm": "Bambara", 84 | "ba": "Bashkir", 85 | "eu": "Basque", 86 | "be": "Belarusian", 87 | "bn": "Bengali", 88 | "bh": "Bihari", 89 | "bi": "Bislama", 90 | "bs": "Bosnian", 91 | "br": "Breton", 92 | "bg": "Bulgarian", 93 | "my": "Burmese", 94 | "ca": "Catalan", 95 | "ch": "Chamorro", 96 | "ce": "Chechen", 97 | "ny": "Chichewa", 98 | "zh": "Chinese", 99 | "cv": "Chuvash", 100 | "kw": "Cornish", 101 | "co": "Corsican", 102 | "cr": "Cree", 103 | "hr": "Croatian", 104 | "cs": "Czech", 105 | "da": "Danish", 106 | "dv": "Divehi", 107 | "nl": "Dutch", 108 | "dz": "Dzongkha", 109 | "en": "English", 110 | "eo": "Esperanto", 111 | "et": "Estonian", 112 | "ee": "Ewe", 113 | "fo": "Faroese", 114 | "fj": "Fijian", 115 | "fi": "Finnish", 116 | "fr": "Franch", 117 | "ff": "Fula", 118 | "gl": "Galician", 119 | "ka": "Georgian", 120 | "de": "German", 121 | "el": "Greek", 122 | "gn": "Guaraní", 123 | "gu": "Gujarati", 124 | "ht": "Haitian", 125 | "ha": "Hausa", 126 | "he": "Hebrew", 127 | "hz": "Herero", 128 | "hi": "Hindi", 129 | "ho": "Hiri Motu", 130 | "hu": "Hungarian", 131 | "ia": "Interlingua", 132 | "id": "Indonesian", 133 | "ie": "Interlingue", 134 | "ga": "Irish", 135 | "ig": "Igbo", 136 | "ik": "Inupiaq", 137 | "io": "Ido", 138 | "is": "Icelandic", 139 | "it": "Italian", 140 | "iu": "Inuktitut", 141 | "ja": "Japanese", 142 | "jv": "Javanese", 143 | "kl": "Kalaallisut", 144 | "kn": "Kannada", 145 | "kr": "Kanuri", 146 | "ks": "Kashmiri", 147 | "kk": "Kazakh", 148 | "km": "Khmer", 149 | "ki": "Kikuyu", 150 | "rw": "Kinyarwanda", 151 | "ky": "Kirghiz", 152 | "kv": "Komi", 153 | "kg": "Kongo", 154 | "ko": "Korean", 155 | "ku": "Kurdish", 156 | "kj": "Kwanyama", 157 | "la": "Latin", 158 | "lb": "Luxembourgish", 159 | "lg": "Luganda", 160 | "li": "Limburgish", 161 | "ln": "Lingala", 162 | "lo": "Lao", 163 | "lt": "Lithuanian", 164 | "lu": "Luba-Katanga", 165 | "lv": "Latvian", 166 | "gv": "Manx", 167 | "mk": "Macedonian", 168 | "mg": "Malagasy", 169 | "ms": "Malay", 170 | "ml": "Malayalam", 171 | "mt": "Maltese", 172 | "mi": "Māori", 173 | "mr": "Marathi", 174 | "mh": "Marshallese", 175 | "mn": "Mongolian", 176 | "na": "Nauru", 177 | "nv": "Navajo", 178 | "nb": "Norwegian Bokmål", 179 | "nd": "North Ndebele", 180 | "ne": "Nepali", 181 | "ng": "Ndonga", 182 | "nn": "Norwegian Nynorsk", 183 | "no": "Norwegian", 184 | "ii": "Nuosu", 185 | "nr": "South Ndebele", 186 | "oc": "Occitan", 187 | "oj": "Ojibwe", 188 | "cu": "Old Church Slavonic", 189 | "om": "Oromo", 190 | "or": "Oriya", 191 | "os": 
"Ossetian", 192 | "pa": "Panjabi", 193 | "pi": "Pāli", 194 | "fa": "Persian", 195 | "pl": "Polish", 196 | "ps": "Pashto", 197 | "pt": "Portuguese", 198 | "qu": "Quechua", 199 | "rm": "Romansh", 200 | "rn": "Kirundi", 201 | "ro": "Romanian", 202 | "ru": "Russian", 203 | "sa": "Sanskrit", 204 | "sc": "Sardinian", 205 | "sd": "Sindhi", 206 | "se": "Northern Sami", 207 | "sm": "Samoan", 208 | "sg": "Sango", 209 | "sr": "Serbian", 210 | "gd": "Scottish Gaelic", 211 | "sn": "Shona", 212 | "si": "Sinhala", 213 | "sk": "Slovak", 214 | "sl": "Slovene", 215 | "so": "Somali", 216 | "st": "Southern Sotho", 217 | "es": "Spanish", 218 | "su": "Sundanese", 219 | "sw": "Swahili", 220 | "ss": "Swati", 221 | "sv": "Swedish", 222 | "ta": "Tamil", 223 | "te": "Telugu", 224 | "tg": "Tajik", 225 | "th": "Thai", 226 | "ti": "Tigrinya", 227 | "bo": "Tibetan", 228 | "tk": "Turkmen", 229 | "tl": "Tagalog", 230 | "tn": "Tswana", 231 | "to": "Tonga", 232 | "tr": "Turkish", 233 | "ts": "Tsonga", 234 | "tt": "Tatar", 235 | "tw": "Twi", 236 | "ty": "Tahitian", 237 | "ug": "Uighur", 238 | "uk": "Ukrainian", 239 | "ur": "Urdu", 240 | "uz": "Uzbek", 241 | "ve": "Venda", 242 | "vi": "Vietnamese", 243 | "vo": "Volapük", 244 | "wa": "Walloon", 245 | "cy": "Welsh", 246 | "wo": "Wolof", 247 | "fy": "Western Frisian", 248 | "xh": "Xhosa", 249 | "yi": "Yiddish", 250 | "yo": "Yoruba", 251 | "za": "Zhuang", 252 | "zu": "Zulu", 253 | } 254 | 255 | return lang_map[lan] 256 | 257 | def read_csv(path_csv): 258 | csv_reader = pd.read_csv(path_csv, sep="\t", header=0) 259 | return csv_reader["times"].tolist() 260 | 261 | 262 | def retrieve_samples(path, dataset): 263 | sacrebleu = datasets.load_metric("sacrebleu") 264 | decoders = ["autoregressive", "jacobi", "gs_jacobi", "aw_jacobi"] 265 | 266 | decoder2file = { 267 | "autoregressive": {"time": [], "trans": []}, 268 | "jacobi": {"time": [], "trans": []}, 269 | "gs_jacobi": {"time": [], "trans": []}, 270 | "aw_jacobi": {"time": [], "trans": []}, 271 | } 272 | 273 | for folder in decoders: 274 | subf = os.path.join(path, folder) 275 | 276 | for root, dirs, files in os.walk(subf): 277 | for filename in files: 278 | if "time" in filename: 279 | times = read_csv(os.path.join(root, filename)) 280 | decoder2file[folder]['time'].extend(times) 281 | if 'trans' in filename: 282 | trans = read_csv(os.path.join(root, filename)) 283 | decoder2file[folder]['trans'].extend(trans) 284 | 285 | diff_times = np.array([auto-aw for auto, aw in zip(decoder2file['autoregressive']['time'], decoder2file['aw_jacobi']['time'])]) 286 | indices = diff_times.argsort()[::-1] 287 | 288 | for i in indices: 289 | target = dataset[int(i)][1] 290 | source = dataset[int(i)][0] 291 | print("gold", "nd", source) 292 | print("gold", "nd", target) 293 | for d in decoders: 294 | prediction = decoder2file[d]['trans'][i] 295 | results = sacrebleu.compute(predictions=[prediction], references=[[target]]) 296 | print(d, decoder2file[d]['time'][i], round(results['score'], 3), prediction) 297 | 298 | input() 299 | continue 300 | 301 | def clean_text(text): 302 | return re.sub("[!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~£−€¿]+", " ", text) 303 | 304 | 305 | def read_wrong_beam_translations(path): 306 | csv_reader = pd.read_csv(path, sep="\t", header=0) 307 | idx = csv_reader['i'].tolist() 308 | bleus = csv_reader['bleu'].tolist() 309 | beams = csv_reader['beam'].tolist() 310 | autos = csv_reader['auto'].tolist() 311 | tgts = csv_reader['tgt'].tolist() 312 | 313 | for x in zip(idx, bleus, beams, autos, tgts): 314 | print(f"{x[0]} {x[1]}\nBeam: 
{x[2]}\nAuto: {x[3]}\nTgt: {x[4]}") 315 | input() 316 | continue 317 | 318 | 319 | def write_sentences(path, data): 320 | 321 | with open(path, 'w') as out_file: 322 | output = csv.writer(out_file, delimiter="\n") 323 | output.writerow(data) 324 | -------------------------------------------------------------------------------- /src/utils/beam_search.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | import torch 5 | import time 6 | from tabulate import tabulate 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | from transformers import MBartForConditionalGeneration, M2M100ForConditionalGeneration 10 | import numpy as np 11 | 12 | from src.utils.bench_scorer import Scorer 13 | from src.utils.utils import makedirs, retrieve_model_name, write_sentences, get_logits_preprocessor 14 | 15 | 16 | class BeamSearcher(object): 17 | def __init__( 18 | self, 19 | dataset, 20 | model, 21 | initializer, 22 | decoder, 23 | num_beams, 24 | no_repeat_ngram_size, 25 | early_stopping, 26 | batch_size, 27 | device, 28 | result_dir, 29 | ): 30 | self.num_beams = num_beams 31 | self.no_repeat_ngram_size = no_repeat_ngram_size 32 | self.early_stopping = early_stopping 33 | self.device = device 34 | self.result_dir = result_dir 35 | 36 | self.dataset = dataset 37 | self.initializer = initializer 38 | self.decoder = decoder 39 | 40 | self.tokenizer = dataset.tokenizer 41 | self.model = model 42 | model.eval().to(self.device) 43 | 44 | self.model_name = self.model.name_or_path 45 | print("model name in beam_search", self.model_name) 46 | 47 | self.dataloader = DataLoader( 48 | dataset, collate_fn=dataset.collate_fn, batch_size=1, shuffle=False 49 | ) 50 | 51 | if isinstance( 52 | self.model, MBartForConditionalGeneration 53 | ) or isinstance( 54 | self.model, M2M100ForConditionalGeneration 55 | ): 56 | self.is_mbart = True 57 | else: 58 | self.is_mbart = False 59 | 60 | self.exp_dir = self._retrieve_exp_dir() 61 | 62 | def _synchronize(self): 63 | if self.device == "cuda": 64 | torch.cuda.synchronize() 65 | 66 | def _retrieve_exp_dir(self): 67 | file_name = self._retrieve_file_name() 68 | 69 | exp_dir = os.path.join(self.result_dir, 'beam_search', file_name) 70 | makedirs(exp_dir) 71 | 72 | return exp_dir 73 | 74 | def _retrieve_file_name(self): 75 | model_name = retrieve_model_name(self.model_name.split("/")[1]) 76 | lang = f"{self.dataset.src_lan}_{self.dataset.tgt_lan}" 77 | return f"{model_name}/{self.dataset.name}/{lang}" 78 | 79 | def _beam_search(self, input_ids, attention_mask): 80 | if self.is_mbart: 81 | with self.tokenizer.as_target_tokenizer(): 82 | try: 83 | lang_id = self.tokenizer.cur_lang_code_id 84 | except: 85 | lang_id = self.tokenizer.cur_lang_id 86 | beam_output = self.model.generate( 87 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 88 | num_beams=self.num_beams, 89 | early_stopping=self.early_stopping, 90 | # no_repeat_ngram_size=self.no_repeat_ngram_size, 91 | forced_bos_token_id=lang_id, 92 | # do_sample=False, 93 | # use_cache=True, 94 | ) 95 | else: 96 | beam_output = self.model.generate( 97 | **{"input_ids": input_ids, "attention_mask": attention_mask}, 98 | num_beams=self.num_beams, 99 | early_stopping=self.early_stopping, 100 | # no_repeat_ngram_size=self.no_repeat_ngram_size, 101 | # do_sample=False, 102 | # use_cache=True, 103 | ) 104 | 105 | return beam_output 106 | 107 | 108 | def _bench_time(self, input_ids, attention_mask): 109 | sample_time = [] 110 | for _ in range(1): 
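# Note: only a single timed run is taken per sentence (range(1)), so the np.var(sample_time) computed below is always 0; widening the range would average the beam-search timing over repeated runs.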
111 | start = time.perf_counter() 112 | self._synchronize() 113 | beam_output = self._beam_search(input_ids, attention_mask) 114 | self._synchronize() 115 | end = time.perf_counter() 116 | sample_time.append(end - start) 117 | 118 | sample_mean = np.average(sample_time) 119 | sample_variance = np.var(sample_time) 120 | 121 | return sample_mean, sample_variance, beam_output 122 | 123 | def _auto_time(self, input_ids, attention_mask, logits_preprocessor=None): 124 | 125 | if self.initializer is not None: 126 | init_tensor, _ = self.initializer.init_translation() 127 | else: 128 | init_tensor = None 129 | 130 | sample_time = [] 131 | for _ in range(1): 132 | init_new = init_tensor.clone() 133 | start = time.perf_counter() 134 | self._synchronize() 135 | auto_output, _ = self.decoder.autoregressive( 136 | input_ids, attention_mask, init_tensor=init_new, logits_preprocessor=logits_preprocessor 137 | ) 138 | self._synchronize() 139 | end = time.perf_counter() 140 | sample_time.append(end - start) 141 | 142 | sample_mean = np.average(sample_time) 143 | sample_variance = np.var(sample_time) 144 | 145 | return sample_mean, sample_variance, auto_output 146 | 147 | def compute_beam_search(self, cfg): 148 | 149 | beam_scorer, auto_scorer = Scorer(), Scorer() 150 | worst_beam_translations = [] 151 | 152 | pbar = tqdm(self.dataloader, desc="Computing Beam Search...") 153 | for x in pbar: 154 | input_ids = x["source"]["input_ids"].to(self.device) 155 | attention_mask = x["source"]["attention_mask"].to(self.device) 156 | 157 | tgt_text = x['target']['sentences'] 158 | 159 | if cfg.model.use_logits_preprocessor: 160 | logits_preprocessor = get_logits_preprocessor(self.decoder, input_ids) 161 | else: 162 | logits_preprocessor = None 163 | 164 | mean_beam, var_beam, beam_output = self._bench_time(input_ids, attention_mask) 165 | mean_auto, var_auto, auto_output = self._auto_time(input_ids, attention_mask, logits_preprocessor) 166 | 167 | translation_beam = self.tokenizer.batch_decode( 168 | beam_output, skip_special_tokens=True 169 | ) 170 | translation_auto = self.tokenizer.batch_decode( 171 | auto_output, skip_special_tokens=True 172 | ) 173 | 174 | beam_scorer.update_metrics(mean_beam, var_beam, 0, translation_beam[0], tgt_text[0]) 175 | auto_scorer.update_metrics(mean_auto, var_auto, 0, translation_auto[0], tgt_text[0]) 176 | 177 | worst_beam_translations.extend( 178 | self._compute_tmp_bleu( 179 | translation_beam[0], 180 | translation_auto[0], 181 | tgt_text[0], 182 | beam_scorer.i 183 | ) 184 | ) 185 | 186 | self.write_report(beam_scorer, auto_scorer, worst_beam_translations) 187 | 188 | def write_report(self, beam_scorer, auto_scorer, worst_beam_translations): 189 | print("Writing report...") 190 | 191 | beam_score = beam_scorer.compute_bleu_score() 192 | auto_score = auto_scorer.compute_bleu_score() 193 | 194 | worst_beam_translations.sort(key=lambda x: x[1]) 195 | 196 | # Table for the test info 197 | info_table = tabulate([ 198 | ['Model', self.model_name], 199 | ['Dataset', self.dataset.name], 200 | ['Languages', f"{self.dataset.src_lan}-{self.dataset.tgt_lan}"], 201 | ['Device', self.device], 202 | ['Sentences', beam_scorer.i], 203 | ['Num Beams', self.num_beams], 204 | ['No Rep Ngram Size', self.no_repeat_ngram_size], 205 | ['Early Stopping', self.early_stopping], 206 | ['Do Sample', False], 207 | ['Use Cache', True], 208 | ], headers=['Info', 'Value'], tablefmt='grid') 209 | 210 | # Table for the Time 211 | time_table = tabulate([ 212 | ['Time', beam_scorer.tot_mean_time, 
auto_scorer.tot_mean_time, (auto_scorer.tot_mean_time / beam_scorer.tot_mean_time)], 213 | ['Iter', beam_scorer.tot_mean_iter, auto_scorer.tot_mean_iter, 1], 214 | ['Var', beam_scorer.tot_var_time, auto_scorer.tot_var_time, 1], 215 | ], headers=['Metrics', 'Beam', 'Auto', 'Speedup'], tablefmt='grid') 216 | 217 | # Table for the Bleu score 218 | bleu_table = tabulate([ 219 | ['Score', beam_score.score, auto_score.score], 220 | ['Counts', beam_score.counts, auto_score.counts], 221 | ['Totals', beam_score.totals, auto_score.totals], 222 | ['Precisions', beam_score.precisions, auto_score.precisions], 223 | ['Bp', beam_score.bp, auto_score.bp], 224 | ['Sys_len', beam_score.sys_len, auto_score.sys_len], 225 | ['ref_len', beam_score.ref_len, auto_score.ref_len], 226 | ], headers=['Metrics', 'Beam Search', 'Auto'], tablefmt='rst') 227 | 228 | with open(os.path.join(self.exp_dir, "report.txt"), mode='w') as report: 229 | report.write(f"Test Info\n{info_table}\n\n") 230 | report.write(f"Time\n{time_table}\n\n") 231 | report.write(f"Bleu Score\n{bleu_table}\n\n") 232 | 233 | with open(os.path.join(self.exp_dir, "worst_beam_translation.csv"), 'w') as csvfile: 234 | writer = csv.writer(csvfile, delimiter='\t') 235 | writer.writerow(['i', 'bleu', 'beam', 'auto', 'tgt']) 236 | for sample in worst_beam_translations: 237 | writer.writerow(list(sample)) 238 | 239 | write_sentences(os.path.join(self.exp_dir, "beam.txt"), beam_scorer.predictions) 240 | write_sentences(os.path.join(self.exp_dir, "auto.txt"), auto_scorer.predictions) 241 | write_sentences(os.path.join(self.exp_dir, "reference.txt"), sum(auto_scorer.references, [])) 242 | 243 | def _compute_tmp_bleu(self, translation_beam, translation_auto, tgt_text, i): 244 | 245 | beam_tmp_scorer, auto_tmp_scorer = Scorer(), Scorer() 246 | 247 | beam_tmp_scorer.update_metrics(0, 0, 0, translation_beam, tgt_text) 248 | auto_tmp_scorer.update_metrics(0, 0, 0, translation_auto, tgt_text) 249 | 250 | beam_score = beam_tmp_scorer.compute_bleu_score() 251 | auto_score = auto_tmp_scorer.compute_bleu_score() 252 | 253 | if beam_score.score < auto_score.score: 254 | return [(i, beam_score.score, translation_beam, translation_auto, tgt_text)] 255 | else: 256 | return [] 257 | 258 | 259 | 260 | -------------------------------------------------------------------------------- /src/bench.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import logging 3 | import os 4 | import sys 5 | import time 6 | 7 | import torch 8 | from tabulate import tabulate 9 | from torch.utils.data import DataLoader, Dataset 10 | from tqdm import tqdm 11 | 12 | from src.utils import utils 13 | from src.utils.bench_scorer import Scorer 14 | from src.utils.utils import retrieve_model_name, check_zero_division 15 | 16 | logging.basicConfig( 17 | format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s", 18 | datefmt="%Y-%m-%d %H:%M:%S", 19 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 20 | stream=sys.stdout, 21 | ) 22 | logger = logging.getLogger("bench") 23 | 24 | 25 | class MTBenchmarker(object): 26 | def __init__( 27 | self, 28 | dataset: Dataset, 29 | decoders, 30 | src_lang, 31 | tgt_lang, 32 | compare_to: str = "autoregressive", 33 | result_dir: str = None, 34 | device: str = "cuda", 35 | debug: bool = False, 36 | ): 37 | self.dataset = dataset 38 | self.decoders = decoders 39 | self.device = device 40 | self.debug = debug 41 | 42 | self.compare_to = compare_to 43 | 44 | self.src_lang = src_lang 45 | 
self.tgt_lang = tgt_lang 46 | self.model_name = self.decoders[0].model_name.split("/")[1] 47 | 48 | self.dataloader = DataLoader( 49 | dataset, collate_fn=dataset.collate_fn, batch_size=1, shuffle=False 50 | ) 51 | 52 | self.result_dir = result_dir 53 | self.exp_dir = self._retrieve_exp_dir() 54 | 55 | def _synchronize(self): 56 | if self.device == "cuda": 57 | torch.cuda.synchronize() 58 | 59 | def _retrieve_exp_dir(self): 60 | file_name = self._retrieve_file_name() 61 | 62 | exp_dir = os.path.join(self.result_dir, "benchmark", file_name) 63 | utils.makedirs(exp_dir) 64 | 65 | return exp_dir 66 | 67 | def _retrieve_file_name(self): 68 | model_name = retrieve_model_name(self.model_name) 69 | lang = f"{self.src_lang}_{self.tgt_lang}" 70 | return f"{model_name}/{self.device}/{self.dataset.name}/{lang}" 71 | 72 | @staticmethod 73 | def _write_on_file(path, item, i): 74 | 75 | utils.makedirs(path) 76 | with open(path, "a") as file: 77 | writer = csv.writer(file, delimiter="\t") 78 | # write header 79 | if os.stat(path).st_size == 0: 80 | writer.writerow(["i", "item"]) 81 | writer.writerow([i, item]) 82 | 83 | def _compute_info_table(self, best_alg, i, grid): 84 | # Table for the test info 85 | info_table = tabulate([ 86 | [ 87 | self.decoders[0].model_name, 88 | self.dataset.name, 89 | self.device, 90 | best_alg, 91 | i, 92 | f"{self.src_lang}-{self.tgt_lang}", 93 | ] 94 | ], 95 | headers=[ 96 | "Model", 97 | "Dataset", 98 | "Device", 99 | "Best Algorithm", 100 | "Sentences", 101 | "Languages", 102 | ], 103 | tablefmt=grid, 104 | ) 105 | 106 | return info_table 107 | 108 | def _compute_time_table(self, scorers, grid): 109 | 110 | if len(scorers) == 1: 111 | name, scorer = list(scorers.items())[0] 112 | times_table = tabulate([ 113 | ["Time", scorer.tot_mean_time], 114 | ["Iteration", scorer.tot_mean_iter], 115 | ], 116 | headers=["Metrics", name], tablefmt=grid, 117 | ) 118 | else: 119 | tests_header = ["Metrics"] + [name for name in scorers] 120 | 121 | comp_scorer = scorers.get(self.compare_to) 122 | time_speedup = [check_zero_division(comp_scorer.tot_mean_time, scorer.tot_mean_time) for scorer in scorers.values()] 123 | iter_speedup = [check_zero_division(comp_scorer.tot_mean_iter, scorer.tot_mean_iter) for scorer in scorers.values()] 124 | 125 | times_table = tabulate( 126 | [ 127 | ["Speedup T"] + time_speedup, 128 | ["Speedup I"] + iter_speedup, 129 | ["Time"] + [scorer.tot_mean_time for scorer in scorers.values()], 130 | ["Iter"] + [scorer.tot_mean_iter for scorer in scorers.values()], 131 | ], 132 | headers=tests_header, tablefmt=grid, 133 | ) 134 | 135 | return times_table 136 | 137 | def _compute_bleu_table(self, scorers, grid): 138 | 139 | bleu_scores = [scorer.compute_bleu_score() for scorer in scorers.values()] 140 | 141 | # Table for the Bleu score 142 | bleu_table = tabulate([ 143 | ['Score'] + [score.score for score in bleu_scores], 144 | ['Counts'] + [score.counts for score in bleu_scores], 145 | ['Totals'] + [score.totals for score in bleu_scores], 146 | ['Precisions'] + [score.precisions for score in bleu_scores], 147 | ['Bp'] + [score.bp for score in bleu_scores], 148 | ['Sys_len'] + [score.sys_len for score in bleu_scores], 149 | ['ref_len'] + [score.ref_len for score in bleu_scores], 150 | ], headers=["Metrics"] + [name for name in scorers], tablefmt=grid) 151 | 152 | return bleu_table 153 | 154 | def write_report(self, i, scorers, best_algorithm): 155 | print("Writing report...") 156 | 157 | # Compute best algorithm 158 | best_alg = max(best_algorithm, 
key=best_algorithm.get) 159 | info_table_txt = self._compute_info_table(best_alg, i, grid="grid") 160 | info_table_tex = self._compute_info_table(best_alg, i, grid="latex") 161 | 162 | # Table for the benchmark times 163 | times_table_txt = self._compute_time_table(scorers, grid="rst") 164 | times_table_tex = self._compute_time_table(scorers, grid="latex") 165 | 166 | # Table for the bleu score 167 | bleu_table_txt = self._compute_bleu_table(scorers, grid="rst") 168 | 169 | print(self.exp_dir) 170 | 171 | with open(os.path.join(self.exp_dir, "report.txt"), mode="w") as report: 172 | report.write(f"Test Info\n{info_table_txt}\n\n") 173 | report.write(f"Benchmark\n{times_table_txt}\n\n") 174 | report.write(f"Bleu\n{bleu_table_txt}\n\n") 175 | 176 | with open(os.path.join(self.exp_dir, "latex.txt"), mode="w") as report: 177 | report.write(f"Test Info\n{info_table_tex}\n\n") 178 | report.write(f"Benchmark\n{times_table_tex}\n\n") 179 | 180 | def write_inline(self, i: int, scorers, best_alg): 181 | 182 | for name, scorer in scorers.items(): 183 | # Write times 184 | path = os.path.join(self.exp_dir, name, f"{name}.tsv") 185 | self._write_on_file(path, scorer.current_time, i) 186 | # Write iterations 187 | path = os.path.join(self.exp_dir, name, f"iter_{name}.tsv") 188 | self._write_on_file(path, scorer.current_iter, i) 189 | # Write translations 190 | path = os.path.join(self.exp_dir, name, f"trans_{name}.tsv") 191 | self._write_on_file(path, scorer.current_transl, i) 192 | # Write initializations 193 | path = os.path.join(self.exp_dir, name, f"init_{name}.tsv") 194 | self._write_on_file(path, scorer.current_init, i) 195 | 196 | # Write mean 197 | path = os.path.join(self.exp_dir, "meanvar.tsv") 198 | utils.makedirs(path) 199 | with open(path, "a") as file: 200 | writer = csv.writer(file, delimiter="\t") 201 | 202 | # Write header 203 | if os.stat(path).st_size == 0: 204 | header = ["#sentence"] + [f"mean_{name}" for name in scorers] + ["best_alg"] 205 | writer.writerow(header) 206 | 207 | row = [i] + [scorer.current_time for scorer in scorers.values()] + [best_alg] 208 | writer.writerow(row) 209 | 210 | @staticmethod 211 | def _compute_best_algorithm(scorers): 212 | 213 | best = scorers.get(min(scorers, key=lambda x: scorers[x].current_time)).acronym 214 | 215 | return best 216 | 217 | def _compute_postfix(self, scorers, best_algorithms, curr_alg): 218 | 219 | best_alg = max(best_algorithms, key=best_algorithms.get) 220 | 221 | if len(scorers) == 1: 222 | _, scorer = list(scorers.items())[0] 223 | postfix = { 224 | scorer.acronym: ( 225 | round(scorer.tot_mean_time, 3), 226 | round(scorer.tot_mean_iter, 3) 227 | ), 228 | "ca": curr_alg, 229 | "ba": best_alg, 230 | } 231 | 232 | else: 233 | comp_scorer = scorers.get(self.compare_to) 234 | 235 | postfix = { 236 | scorer.acronym: ( 237 | check_zero_division(comp_scorer.tot_mean_time, scorer.tot_mean_time), 238 | check_zero_division(comp_scorer.tot_mean_iter, scorer.tot_mean_iter) 239 | ) for name, scorer in scorers.items() if name != self.compare_to 240 | } 241 | 242 | postfix.update({"ca": curr_alg, "ba": best_alg}) 243 | 244 | return postfix 245 | 246 | def compute_total_time(self): 247 | 248 | i = 0 249 | scorers = {decoder.name: Scorer(decoder.name, decoder.acronym) for decoder in self.decoders} 250 | best_algorithms = {decoder.acronym: 0 for decoder in self.decoders} 251 | 252 | pbar = tqdm(self.dataloader, desc="Computing Benchmark...") 253 | for x in pbar: 254 | 255 | try: 256 | input_ids = x["source"]["input_ids"].to(self.device) 257 | 
attention_mask = x["source"]["attention_mask"].to(self.device) 258 | 259 | tgt_text = x['target']['sentences'] 260 | 261 | for decoder, name in zip(self.decoders, scorers): 262 | kwargs = decoder.compute_decode_kwargs(input_ids, attention_mask) 263 | 264 | start = time.perf_counter() 265 | self._synchronize() 266 | trans, iter = decoder.decode(input_ids, attention_mask, **kwargs) 267 | self._synchronize() 268 | end = time.perf_counter() 269 | sample_time = end - start 270 | 271 | init_tensor = kwargs["init_tensor"] if "init_tensor" in kwargs else "" 272 | 273 | trans = self.dataset.tokenizer.batch_decode( 274 | trans, skip_special_tokens=True 275 | )[0] 276 | 277 | scorers[name].update_metrics(sample_time, iter, trans, tgt_text[0], init_tensor) 278 | 279 | best_alg = self._compute_best_algorithm(scorers) 280 | best_algorithms[best_alg] += 1 281 | 282 | # Update tqdm bar 283 | postfix_pbar = self._compute_postfix(scorers, best_algorithms, best_alg) 284 | pbar.set_postfix(postfix_pbar) 285 | 286 | self.write_inline(i, scorers, best_alg) 287 | except Exception as e: 288 | if self.debug: 289 | raise e 290 | else: 291 | logger.error(e) 292 | return 293 | 294 | i += 1 295 | 296 | self.write_report(i, scorers, best_algorithms) 297 | -------------------------------------------------------------------------------- /exp/flops_calculator.py: -------------------------------------------------------------------------------- 1 | """Computes the flops needed for training/running transformer networks. Adapted from https://github.com/google-research/electra/blob/master/flops_computation.py """ 2 | 3 | import collections 4 | 5 | # We checked this code with TensorFlow's FLOPs counting, although we had to 6 | # correct for this issue: https://github.com/tensorflow/tensorflow/issues/22071 7 | # Assumptions going into the FLOPs counting 8 | # - An "operation" is a mathematical operation, not a machine instruction. So 9 | # an "exp" takes one op like an add, even though in practice an exp 10 | # might be slower. This is not too bad an assumption because 11 | # matrix-multiplies dominate the compute for most models, so minor details 12 | # about activation functions don't matter too much. Similarly, we count 13 | # matrix-multiplies as 2*m*n flops instead of m*n, as one might 14 | # if considering fused multiply-add ops. 15 | # - Backward pass takes the same number of FLOPs as forward pass. Not exactly 16 | # right (e.g., for softmax cross entropy loss the backward pass is faster). 17 | # Importantly, it really is the same for matrix-multiplies, which is most of 18 | # the compute anyway. 19 | # - We assume "dense" embedding lookups (i.e., multiplication by a one-hot 20 | # vector). On some hardware accelerators, these dense operations are 21 | # actually faster than sparse lookups. 22 | # Please open a github issue if you spot a problem with this code! 23 | 24 | # I am not sure if the below constants are 100% right, but they are only applied 25 | # to O(hidden_size) activations, which is generally a lot less compute than the 26 | # matrix-multiplies, which are O(hidden_size^2), so they don't affect the total 27 | # number of FLOPs much.
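# --- Editor's illustration (not part of the original script): a minimal worked
# example of the 2*m*n counting convention described above, using hypothetical
# sizes s=512, h=768, i=3072. A dense layer mapping an (s, h) activation through
# an (h, i) weight matrix is counted as 2*s*h*i FLOPs (one multiply and one add
# per weight application) plus s*i for the bias:
#   2 * 512 * 768 * 3072 + 512 * 3072 = 2_417_491_968  (~2.4 GFLOPs)
# The per-position constants below (dropout, layer norm, GELU, softmax) are
# O(hidden_size) corrections on top of these matrix-multiply terms.
def _example_dense_layer_flops(s=512, h=768, i=3072):
    """Hypothetical helper, for illustration only: dense-layer FLOPs under the 2*m*n rule."""
    return 2 * s * h * i + s * i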
28 | 29 | # random number, >=, multiply activations by dropout mask, multiply activations 30 | # by correction (1 / (1 - dropout_rate)) 31 | DROPOUT_FLOPS = 4 32 | 33 | # compute mean activation (sum), compute variance of activation 34 | # (square and sum), bias (add), scale (multiply) 35 | LAYER_NORM_FLOPS = 5 36 | 37 | # GELU: 0.5 * x * (1 + tanh(sqrt(2 / np.pi) * (x + 0.044715 * pow(x, 3)))) 38 | ACTIVATION_FLOPS = 8 39 | 40 | # max/subtract (for stability), exp, sum, divide 41 | SOFTMAX_FLOPS = 5 42 | 43 | 44 | class TransformerHparams(object): 45 | """Computes the train/inference FLOPs for transformers.""" 46 | 47 | def __init__(self, h, l, s=512, v=30522, e=None, i=None, heads=None, 48 | head_size=None, output_frac=0.15625, sparse_embed_lookup=False, 49 | decoder=False): 50 | self.h = h # hidden size 51 | self.l = l # number of layers 52 | self.s = s # sequence length 53 | self.v = v # vocab size 54 | self.e = h if e is None else e # embedding size 55 | self.i = h * 4 if i is None else i # intermediate size 56 | self.kqv = h if head_size is None else head_size * heads # attn proj sizes 57 | self.heads = max(h // 64, 1) if heads is None else heads # attention heads 58 | self.output_frac = output_frac # percent of tokens using an output softmax 59 | self.sparse_embed_lookup = sparse_embed_lookup # sparse embedding lookups 60 | self.decoder = decoder # decoder has extra attn to encoder states 61 | 62 | def get_block_flops(self): 63 | """Get the forward-pass FLOPs for a single transformer block.""" 64 | attn_mul = 2 if self.decoder else 1 65 | block_flops = dict( 66 | kqv=3 * 2 * self.h * self.kqv * attn_mul, 67 | kqv_bias=3 * self.kqv * attn_mul, 68 | attention_scores=2 * self.kqv * self.s * attn_mul, 69 | attn_softmax=SOFTMAX_FLOPS * self.s * self.heads * attn_mul, 70 | attention_dropout=DROPOUT_FLOPS * self.s * self.heads * attn_mul, 71 | attention_scale=self.s * self.heads * attn_mul, 72 | attention_weighted_avg_values=2 * self.h * self.s * attn_mul, 73 | attn_output=2 * self.h * self.h * attn_mul, 74 | attn_output_bias=self.h * attn_mul, 75 | attn_output_dropout=DROPOUT_FLOPS * self.h * attn_mul, 76 | attn_output_residual=self.h * attn_mul, 77 | attn_output_layer_norm=LAYER_NORM_FLOPS * attn_mul, 78 | intermediate=2 * self.h * self.i, 79 | intermediate_act=ACTIVATION_FLOPS * self.i, 80 | intermediate_bias=self.i, 81 | output=2 * self.h * self.i, 82 | output_bias=self.h, 83 | output_dropout=DROPOUT_FLOPS * self.h, 84 | output_residual=self.h, 85 | output_layer_norm=LAYER_NORM_FLOPS * self.h, 86 | ) 87 | return sum(block_flops.values()) * self.s 88 | 89 | def get_embedding_flops(self, output=False): 90 | """Get the forward-pass FLOPs for the transformer inputs or output softmax.""" 91 | embedding_flops = {} 92 | if output or (not self.sparse_embed_lookup): 93 | embedding_flops["main_multiply"] = 2 * self.e * self.v 94 | # input embedding post-processing 95 | if not output: 96 | embedding_flops.update(dict( 97 | tok_type_and_position=2 * self.e * (self.s + 2), 98 | add_tok_type_and_position=2 * self.e, 99 | emb_layer_norm=LAYER_NORM_FLOPS * self.e, 100 | emb_dropout=DROPOUT_FLOPS * self.e 101 | )) 102 | # projection layer if e != h 103 | if self.e != self.h or output: 104 | embedding_flops.update(dict( 105 | hidden_kernel=2 * self.h * self.e, 106 | hidden_bias=self.e if output else self.h 107 | )) 108 | # extra hidden layer and output softmax 109 | if output: 110 | embedding_flops.update(dict( 111 | hidden_activation=ACTIVATION_FLOPS * self.e, 112 | hidden_layernorm=LAYER_NORM_FLOPS *
self.e, 113 | output_softmax=SOFTMAX_FLOPS * self.v, 114 | output_target_word=2 * self.v 115 | )) 116 | return self.output_frac * sum(embedding_flops.values()) * self.s 117 | return sum(embedding_flops.values()) * self.s 118 | 119 | def get_binary_classification_flops(self): 120 | classification_flops = dict( 121 | hidden=2 * self.h * self.h, 122 | hidden_bias=self.h, 123 | hidden_act=ACTIVATION_FLOPS * self.h, 124 | logits=2 * self.h 125 | ) 126 | return sum(classification_flops.values()) * self.s 127 | 128 | def get_train_flops(self, batch_size, train_steps, discriminator=False, use_backprop=True): 129 | """Get the FLOPs for pre-training the transformer.""" 130 | # 2* for forward/backward pass 131 | if use_backprop: 132 | mult = 2 133 | else: 134 | mult = 1 135 | return mult * batch_size * train_steps * ( 136 | (self.l * self.get_block_flops()) + 137 | self.get_embedding_flops(output=False) + 138 | (self.get_binary_classification_flops() if discriminator else 139 | self.get_embedding_flops(output=True)) 140 | ) 141 | 142 | def get_infer_flops(self): 143 | """Get the FLOPs for running inference with the transformer on a 144 | classification task.""" 145 | return ((self.l * self.get_block_flops()) + 146 | self.get_embedding_flops(output=False) + 147 | self.get_binary_classification_flops()) 148 | 149 | 150 | def get_electra_train_flops( 151 | h_d, l_d, h_g, l_g, batch_size, train_steps, tied_embeddings, 152 | e=None, s=512, output_frac=0.15625): 153 | """Get the FLOPs needed for pre-training ELECTRA.""" 154 | if e is None: 155 | e = h_d 156 | disc = TransformerHparams( 157 | h_d, l_d, s=s, e=e, 158 | output_frac=output_frac).get_train_flops(batch_size, train_steps, True) 159 | gen = TransformerHparams( 160 | h_g, l_g, s=s, e=e if tied_embeddings else None, 161 | output_frac=output_frac).get_train_flops(batch_size, train_steps) 162 | return disc + gen 163 | 164 | 165 | 166 | def calculate_LASSFlops(prior_flop): 167 | vq_vaeflops = 0 168 | 169 | MODEL_FLOPS = collections.OrderedDict([ 170 | # These runtimes were computed with tensorflow FLOPs counting instead of the 171 | # script, as the neural architectures are quite different. 172 | # 768648884 words in LM1b benchmark, 10 epochs with batch size 20, 173 | # seq length 128, 568093262680 FLOPs per example. 174 | ("elmo", 2 * 10 * 768648884 * 568093262680 / (20.0 * 128)), 175 | # 15064773691518 is FLOPs for forward pass on 32 examples. 
176 | # Therefore 2 * steps * batch_size * 15064773691518 / 32 is XLNet compute 177 | ("xlnet", 2 * 500000 * 8192 * 15064773691518 / 32.0), 178 | 179 | # Runtimes computed with the script 180 | ("gpt", TransformerHparams(768, 12, v=40000, output_frac=1.0).get_train_flops( 181 | 128, 960800)), 182 | 183 | ("jukebox", TransformerHparams(1024, 48, s=8192, v=2048, output_frac=1.0, heads=1).get_infer_flops()), 184 | 185 | 186 | ("bert_small", TransformerHparams(256, 12, e=128, s=128).get_train_flops(128, 1.45e6)), 187 | ("bert_base", TransformerHparams(768, 12).get_train_flops(256, 1e6)), 188 | ("bert_large", TransformerHparams(1024, 24).get_train_flops(256, 1e6)), 189 | ("electra_small", get_electra_train_flops(256, 12, 64, 12, 128, 1e6, True, s=128, e=128)), 190 | ("electra_base", get_electra_train_flops(768, 12, 256, 12, 256, 766000, True)), 191 | ("electra_400k", get_electra_train_flops(1024, 24, 256, 24, 2048, 400000, True)), 192 | ("electra_1.75M", get_electra_train_flops(1024, 24, 256, 24, 2048, 1750000, True)), 193 | 194 | # RoBERTa, ALBERT, and T5 have minor architectural differences from 195 | # BERT/ELECTRA, but I believe they don't significantly affect the runtime, 196 | # so we use this script for those models as well. 197 | ("roberta", TransformerHparams(1024, 24, v=50265).get_train_flops(8000, 500000)), 198 | ("albert", TransformerHparams(4096, 12, v=30000, e=128).get_train_flops( 199 | 4096, 1.5e6)), 200 | ("t5_11b", TransformerHparams( 201 | 1024, # hidden size 202 | 24, # layers 203 | v=32000, # vocab size 204 | i=65536, # ff intermediate hidden size 205 | heads=128, head_size=128, # heads/head size 206 | output_frac=0.0 # encoder has no output softmax 207 | ).get_train_flops(2048, 1e6) + # 1M steps with batch size 2048 208 | TransformerHparams( 209 | 1024, 210 | 24, 211 | v=32000, 212 | i=65536, 213 | heads=128, head_size=128, 214 | output_frac=1.0, # decoder has output softmax for all positions 215 | decoder=True 216 | ).get_train_flops(2048, 1e6)), 217 | ("Shallow", TransformerHparams(512, 13, i=2048).get_train_flops(250, 300000)), 218 | ("Shallow + KD", TransformerHparams(512, 13, i=2048).get_train_flops(250, 300000) + TransformerHparams(512, 13, i=2048, 219 | heads=16).get_train_flops( 220 | 250, 300000, use_backprop=False)), 221 | ("DSLP", TransformerHparams(512, 12, i=2048).get_train_flops(250, 300000) + TransformerHparams(512, 12, i=2048, 222 | heads=16).get_train_flops( 223 | 250, 300000, use_backprop=False)), 224 | ("F-VAE", TransformerHparams(512, 12, i=2048).get_train_flops(250, 300000) + TransformerHparams(512, 12, i=2048, 225 | heads=16).get_train_flops( 226 | 250, 300000, use_backprop=False)), 227 | ("DisCo", TransformerHparams(512, 12, i=2048).get_train_flops(250, 300000) + TransformerHparams(1024, 24, i=4096, heads=16).get_train_flops(250, 300000, use_backprop=False)), 228 | ("SUNDAE", TransformerHparams(512, 12, i=2048).get_train_flops(4096, 10e6)), 229 | 230 | ]) 231 | 232 | 233 | def main(): 234 | for k, v in MODEL_FLOPS.items(): 235 | print(k, v) 236 | 237 | 238 | if __name__ == "__main__": 239 | main() 240 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023-2024 Andrea Santilli, Silvio Severino 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------