├── .python-version ├── valle ├── bin │ ├── __init__.py │ ├── display_manifest_statistics.py │ ├── add_tokens.py │ ├── infer.py │ └── tokenizer.py ├── modules │ ├── __init__.py │ ├── scheduler.py │ ├── embedding.py │ ├── activation.py │ └── transformer.py ├── version.py ├── __init__.py ├── data │ ├── __init__.py │ ├── hebrew_normalizer.py │ ├── collation.py │ ├── dataset.py │ ├── input_strategies.py │ ├── fbank.py │ ├── hebrew_root_tokenizer.py │ └── datamodule.py ├── models │ ├── macros.py │ ├── visualizer.py │ ├── __init__.py │ └── transformer.py ├── utils │ ├── __init__.py │ ├── icefall.py │ └── symbol_table.py └── tests │ ├── scaling_test.py │ ├── data │ └── tokenizer_test.py │ └── valle_test.py ├── .gitignore ├── imgs └── model.jpg ├── speakers ├── geek.wav ├── osim.wav ├── shaul.wav └── speakers.yaml ├── example.csv ├── utils.py ├── pyproject.toml ├── README.md └── infer.py /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 -------------------------------------------------------------------------------- /valle/bin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /valle/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pt 3 | out/ -------------------------------------------------------------------------------- /valle/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0.dev+git.8327cf0.clean" -------------------------------------------------------------------------------- /valle/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import data, models, modules, utils 2 | -------------------------------------------------------------------------------- /imgs/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slp-rl/HebTTS/HEAD/imgs/model.jpg -------------------------------------------------------------------------------- /speakers/geek.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slp-rl/HebTTS/HEAD/speakers/geek.wav -------------------------------------------------------------------------------- /speakers/osim.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slp-rl/HebTTS/HEAD/speakers/osim.wav -------------------------------------------------------------------------------- /speakers/shaul.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slp-rl/HebTTS/HEAD/speakers/shaul.wav -------------------------------------------------------------------------------- /example.csv: -------------------------------------------------------------------------------- 1 | filename|text 2 | 0.wav,שלום מה שלומכם איך אתם מרגישים 3 | 1.wav,שלום עולם מה המצב איך הולך -------------------------------------------------------------------------------- /valle/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datamodule import * 2 | from .tokenizer import * 3 | from .collation import * 4 | from .hebrew_root_tokenizer import AlefBERTRootTokenizer 5 | -------------------------------------------------------------------------------- /valle/models/macros.py: -------------------------------------------------------------------------------- 1 | # Text 2 | NUM_TEXT_TOKENS = 512 3 | 4 | # Audio 5 | NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins 6 | NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band 7 | 8 | 9 | # Speaker 10 | NUM_SPEAKER_CLASSES = 4096 11 | SPEAKER_EMBEDDING_DIM = 64 12 | -------------------------------------------------------------------------------- /speakers/speakers.yaml: -------------------------------------------------------------------------------- 1 | osim: 2 | text-prompt: "אני משתדל מאוד שכל חסות תהיה מיוחדת ושונה" 3 | audio-prompt: "osim.wav" 4 | 5 | geek: 6 | text-prompt: "מביא הרבה מהסיבות" 7 | audio-prompt: "geek.wav" 8 | 9 | shaul: 10 | text-prompt: "ובשביל להבין למה מחיר הדלק כל כך עלה, צריך לחזור שנתיים אחורנית." 
11 | audio-prompt: "shaul.wav" 12 | -------------------------------------------------------------------------------- /valle/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .symbol_table import SymbolTable 5 | from .icefall import make_pad_mask 6 | SymbolTable = SymbolTable 7 | 8 | 9 | class Transpose(nn.Identity): 10 | """(N, T, D) -> (N, D, T)""" 11 | 12 | def forward(self, input: torch.Tensor) -> torch.Tensor: 13 | return input.transpose(1, 2) 14 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class AttributeDict(dict): 4 | def __getattr__(self, key): 5 | if key in self: 6 | return self[key] 7 | raise AttributeError(f"No such attribute '{key}'") 8 | 9 | def __setattr__(self, key, value): 10 | self[key] = value 11 | 12 | def __delattr__(self, key): 13 | if key in self: 14 | del self[key] 15 | return 16 | raise AttributeError(f"No such attribute '{key}'") -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "hebtts" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | dependencies = [ 8 | "audiocraft>=1.3.0", 9 | "encodec>=0.1.1", 10 | "gdown>=5.2.0", 11 | "lhotse", 12 | "librosa>=0.11.0", 13 | "matplotlib>=3.10.3", 14 | "numpy<2", 15 | "omegaconf>=2.3.0", 16 | "phonemizer>=3.3.0", 17 | "torch>=2.1.0", 18 | "torchaudio>=2.1.0", 19 | "torchmetrics>=1.7.2", 20 | ] 21 | 22 | [tool.uv.sources] 23 | lhotse = { git = "https://github.com/lhotse-speech/lhotse" } 24 | -------------------------------------------------------------------------------- /valle/utils/icefall.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: 4 | """ 5 | Args: 6 | lengths: 7 | A 1-D tensor containing sentence lengths. 8 | max_len: 9 | The length of masks. 10 | Returns: 11 | Return a 2-D bool tensor, where masked positions 12 | are filled with `True` and non-masked positions are 13 | filled with `False`. 14 | 15 | >>> lengths = torch.tensor([1, 3, 2, 5]) 16 | >>> make_pad_mask(lengths) 17 | tensor([[False, True, True, True, True], 18 | [False, False, False, True, True], 19 | [False, False, True, True, True], 20 | [False, False, False, False, False]]) 21 | """ 22 | assert lengths.ndim == 1, lengths.ndim 23 | max_len = max(max_len, lengths.max()) 24 | n = lengths.size(0) 25 | seq_range = torch.arange(0, max_len, device=lengths.device) 26 | expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) 27 | 28 | return expaned_lengths >= lengths.unsqueeze(-1) -------------------------------------------------------------------------------- /valle/bin/display_manifest_statistics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) 3 | # Copyright 2023 (authors: Feiteng Li) 4 | # 5 | # See ../../../../LICENSE for clarification regarding multiple authors 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """ 20 | This file displays duration statistics of utterances in the manifests. 21 | You can use the displayed value to choose minimum/maximum duration 22 | to remove short and long utterances during the training. 23 | """ 24 | 25 | import argparse 26 | from pathlib import Path 27 | 28 | from lhotse import load_manifest_lazy 29 | 30 | 31 | def get_args(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument( 34 | "--manifest-dir", 35 | type=Path, 36 | default=Path("data/tokenized"), 37 | help="Path to the tokenized manifests.", 38 | ) 39 | return parser.parse_args() 40 | 41 | 42 | def main(): 43 | args = get_args() 44 | manifest_dir = args.manifest_dir or Path("data/tokenized") 45 | for part in ["train", "dev", "test"]: 46 | print(f"## {part}") 47 | cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz") 48 | cuts.describe() 49 | print("\n") 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /valle/tests/scaling_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | 4 | import unittest 5 | 6 | import numpy as np 7 | import torch 8 | from icefall.utils import AttributeDict 9 | 10 | from valle.models import NUM_MEL_BINS, get_model 11 | 12 | 13 | class TestModel(unittest.TestCase): 14 | @classmethod 15 | def setUpClass(cls): 16 | cls.devices = [torch.device("cpu")] 17 | if torch.cuda.is_available(): 18 | cls.devices.append(torch.device("cuda", 0)) 19 | if torch.cuda.device_count() > 1: 20 | torch.cuda.set_device(1) 21 | cls.devices.append(torch.device("cuda", 1)) 22 | 23 | def test_scaling_transformer(self): 24 | params = AttributeDict() 25 | params.decoder_dim = 64 26 | params.nhead = 4 27 | params.num_decoder_layers = 4 28 | 29 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 30 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4])) 31 | x_lens[-1] = 8 32 | 33 | y = torch.from_numpy( 34 | np.random.random((4, 16, NUM_MEL_BINS)).astype(np.float32) 35 | ) 36 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 37 | y_lens[-1] = 16 38 | 39 | params.model_name = "Transformer" 40 | params.norm_first = False 41 | params.add_prenet = False 42 | params.scaling_xformers = True 43 | 44 | for device in self.devices: 45 | # Transformer 46 | model = get_model(params) 47 | num_param = sum([p.numel() for p in model.parameters()]) 48 | 49 | model.to(device) 50 | x = x.to(device) 51 | x_lens = x_lens.to(device) 52 | y = y.to(device) 53 | y_lens = y_lens.to(device) 54 | 55 | # Training 56 | codes, loss, metrics = model(x, x_lens, y, y_lens) 57 | # Inference 58 | model.eval() 59 | codes = model.inference(x[-1:], x_lens[-1:]) 60 | params.add_prenet = False 61 | 62 | 63 | if __name__ == "__main__": 64 | unittest.main() 65 | -------------------------------------------------------------------------------- /valle/bin/add_tokens.py: -------------------------------------------------------------------------------- 1 | from lhotse import 
CutSet 2 | from pathlib import Path 3 | from transformers import BertTokenizer 4 | from tqdm import tqdm 5 | import logging 6 | from valle.utils import SymbolTable 7 | 8 | 9 | 10 | from valle.data import ( 11 | TextTokenizer, 12 | tokenize_text 13 | ) 14 | 15 | 16 | def append_chars_subwords(output_dir): 17 | chars_tokenizer = TextTokenizer(backend="english_chars") 18 | subwords_tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") 19 | 20 | cut_names = [ 21 | "cuts_dev.jsonl.gz", 22 | "cuts_train.jsonl.gz", 23 | "cuts_test.jsonl.gz" 24 | ] 25 | 26 | word_unique_symbols = set() 27 | char_unique_symbols = set() 28 | 29 | for cut_name in cut_names: 30 | cut_path = output_dir / cut_name 31 | 32 | print(f"tokenizing cut: {cut_path}") 33 | 34 | cuts = CutSet.from_file(cut_path) 35 | new_cut_list = list() 36 | 37 | for c in tqdm(cuts, "tokenizing text"): 38 | text = c.supervisions[0].text 39 | 40 | char_tokens = tokenize_text(chars_tokenizer, text=text) 41 | word_tokens = subwords_tokenizer.tokenize(text) 42 | 43 | word_unique_symbols.update(word_tokens) 44 | char_unique_symbols.update(char_tokens) 45 | 46 | c.supervisions[0].custom["tokens"]["char"] = char_tokens 47 | c.supervisions[0].custom["tokens"]["word"] = word_tokens 48 | """ 49 | PHONEMES TOKENS ARE CALLED TEXT DUE TO LEGACY! 50 | """ 51 | 52 | new_cut_list.append(c) 53 | 54 | new_cut_set = CutSet.from_cuts(new_cut_list) 55 | new_cut_set.to_file(cut_path) 56 | 57 | 58 | # Symbol tables 59 | unique_chars = SymbolTable() 60 | unique_words = SymbolTable() 61 | 62 | for char in sorted(list(char_unique_symbols)): 63 | unique_chars.add(char) 64 | logging.info(f"{len(unique_chars)} unique chars: {unique_chars}") 65 | unique_chars_file = f"{output_dir}/unique_chars_tokens.k2symbols" 66 | unique_chars.to_file(unique_chars_file) 67 | 68 | for word in sorted(list(word_unique_symbols)): 69 | unique_words.add(word) 70 | logging.info(f"{len(unique_words)} unique words: {unique_words}") 71 | unique_words_file = f"{output_dir}/unique_words_tokens.k2symbols" 72 | unique_words.to_file(unique_words_file) 73 | 74 | 75 | if __name__ == '__main__': 76 | cuts_dev_path = "/cs/labs/adiyoss/amitroth/valle/examples/libritts/data/tokenized" 77 | 78 | append_chars_subwords(Path(cuts_dev_path)) 79 | -------------------------------------------------------------------------------- /valle/modules/scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # See ../../../../LICENSE for clarification regarding multiple authors 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
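# Note on the schedule implemented below: `calc_lr` is the Noam ("Attention Is
# All You Need") warmup rule, so with the extra `base_lr` multiplier applied in
# `NoamScheduler.get_lr`, the learning rate at optimizer step `step` is
#
#     lr(step) = base_lr * dim_embed ** -0.5
#                * min(step ** -0.5, step * warmup_steps ** -1.5)
#
# i.e. it grows linearly for the first `warmup_steps` steps, peaks at
# base_lr * (dim_embed * warmup_steps) ** -0.5, and then decays proportionally
# to step ** -0.5.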
17 | 18 | 19 | import torch 20 | 21 | from valle.modules.optim import Eden 22 | 23 | 24 | def calc_lr(step, dim_embed, warmup_steps): 25 | return dim_embed ** (-0.5) * min( 26 | step ** (-0.5), step * warmup_steps ** (-1.5) 27 | ) 28 | 29 | 30 | class NoamScheduler(torch.optim.lr_scheduler._LRScheduler): 31 | def __init__( 32 | self, 33 | base_lr: float, 34 | optimizer: torch.optim.Optimizer, 35 | dim_embed: int, 36 | warmup_steps: int, 37 | last_epoch: int = -1, 38 | verbose: bool = False, 39 | ) -> None: 40 | 41 | self.dim_embed = dim_embed 42 | self.base_lr = base_lr 43 | self.warmup_steps = warmup_steps 44 | self.num_param_groups = len(optimizer.param_groups) 45 | 46 | super().__init__(optimizer, last_epoch, verbose) 47 | 48 | def get_lr(self) -> float: 49 | lr = self.base_lr * calc_lr( 50 | self._step_count, self.dim_embed, self.warmup_steps 51 | ) 52 | return [lr] * self.num_param_groups 53 | 54 | def set_step(self, step: int): 55 | self._step_count = step 56 | 57 | 58 | def get_scheduler(params, optimizer): 59 | if params.scheduler_name.lower() == "eden": 60 | # TODO add it to params 61 | scheduler = Eden(optimizer, lr_batches=params.lr_batches, lr_epochs=params.lr_epochs, warmup_batches=params.warmup_steps) 62 | elif params.scheduler_name.lower() == "noam": 63 | scheduler = NoamScheduler( 64 | params.base_lr, 65 | optimizer, 66 | params.decoder_dim, 67 | warmup_steps=params.warmup_steps, 68 | ) 69 | # scheduler.set_step(params.start_batch or params.batch_idx_train) 70 | elif params.scheduler_name.lower() == "cosine": 71 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 72 | T_max=params.warmup_steps, 73 | optimizer=optimizer, 74 | eta_min=params.base_lr, 75 | ) 76 | else: 77 | raise NotImplementedError(f"{params.scheduler_name}") 78 | 79 | return scheduler 80 | -------------------------------------------------------------------------------- /valle/data/hebrew_normalizer.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from num2words import num2words 3 | 4 | 5 | class HebrewNormalizer: 6 | HEBREW_CHARS = "אבגדהוזחטיכלמנסעפצקרשת" 7 | 8 | SUF_REPLACE = { 9 | 'ף': 'פ', 10 | 'ץ': 'צ', 11 | 'ך': 'כ', 12 | 'ן': 'נ', 13 | 'ם': 'מ', 14 | } 15 | 16 | date_formats = [ 17 | "%d.%m.%Y", 18 | "%d/%m/%Y", 19 | 20 | ] 21 | 22 | def __init__(self, unk_token='~', punctuation=".!?,'\"\'" + '״' + '׳'): 23 | self.unk_token = unk_token 24 | self.allowed_chars = HebrewNormalizer.HEBREW_CHARS + punctuation 25 | 26 | def __call__(self, text): 27 | normalized_text = whitespace_split(text) 28 | normalized_text = self.normalize(normalized_text) 29 | return normalized_text 30 | 31 | def normalize(self, text): 32 | res = list() 33 | for word in text: 34 | date = HebrewNormalizer.get_date(word) 35 | 36 | # switch to classify word 37 | if date: 38 | res += HebrewNormalizer.date_to_hebrew(date).split(" ") 39 | elif word.isdigit(): 40 | res += num2words(int(word), lang='he').split(" ") 41 | elif self.unk_token is not None: 42 | res.append(self.remove_unknown_chars(word)) 43 | else: 44 | res.append(word) 45 | 46 | return res 47 | 48 | def remove_unknown_chars(self, word): 49 | """ 50 | remove unknown chars and "suf" characters 51 | """ 52 | norm_word = "" 53 | for char in word: 54 | if char in HebrewNormalizer.SUF_REPLACE.keys(): 55 | norm_word += HebrewNormalizer.SUF_REPLACE[char] 56 | elif char not in self.allowed_chars: 57 | norm_word += self.unk_token 58 | else: 59 | norm_word += char 60 | 61 | return norm_word 62 | 63 | 
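    # Summary of the normalization flow implemented by this class (see the
    # __main__ demo at the bottom of the file): `__call__` whitespace-splits the
    # input, `normalize` expands dates matching `date_formats` and bare digits
    # via num2words, and `remove_unknown_chars` maps final-form letters
    # (ף ץ ך ן ם) to their base forms and replaces any character outside
    # HEBREW_CHARS + punctuation (e.g. Latin letters) with the unk token '~'.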
@staticmethod 64 | def get_date(date_string): 65 | for date_format in HebrewNormalizer.date_formats: 66 | try: 67 | parsed_date = datetime.strptime(date_string, date_format) 68 | return parsed_date 69 | except ValueError: 70 | pass 71 | return False 72 | 73 | @staticmethod 74 | def date_to_hebrew(date: datetime): 75 | """ 76 | converts dates to hebrew using num2words 77 | currently the gender is wrong 78 | """ 79 | month = [ 80 | "ינואר", "פברואר", "מרץ", "אפריל", "מאי", "יוני", "יולי", "אוגוסט", "ספטמבר", "אוקטובר", "נובמבר", "דצמבר" 81 | ] 82 | 83 | return f"{num2words(date.day, lang='he')} ב{month[date.month - 1]}, {num2words(date.year, lang='he')}" 84 | 85 | 86 | def whitespace_split(text): 87 | """ 88 | strip and split 89 | """ 90 | text = text.strip() 91 | if not text: 92 | return [] 93 | tokens = text.split() 94 | return tokens 95 | 96 | 97 | if __name__ == '__main__': 98 | norm = HebrewNormalizer() 99 | texts = [ 100 | "היום הולדת שלי ב 21.12.2002", 101 | "יש לי 4 בננות ו 45 ענבים", 102 | "עברית english" 103 | ] 104 | 105 | for text in texts: 106 | print(norm(text)) 107 | -------------------------------------------------------------------------------- /valle/modules/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
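# Note: `SinePositionalEmbedding` below precomputes the standard transformer
# sinusoidal table
#
#     PE[pos, 2i]     = sin(pos / 10000 ** (2i / dim_model))
#     PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / dim_model))
#
# for an initial 4000 positions and lazily extends it in `extend_pe` when a
# longer sequence is seen. The input is scaled by sqrt(dim_model) when
# `scale=True`, and the positional term is mixed in through `alpha`, which is
# learnable only when `alpha=True`.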
14 | 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class TokenEmbedding(nn.Module): 22 | def __init__( 23 | self, 24 | dim_model: int, 25 | vocab_size: int, 26 | dropout: float = 0.0, 27 | ): 28 | super().__init__() 29 | 30 | self.vocab_size = vocab_size # 512 31 | self.dim_model = dim_model # 1024 32 | 33 | self.dropout = torch.nn.Dropout(p=dropout) 34 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) 35 | 36 | @property 37 | def weight(self) -> torch.Tensor: 38 | return self.word_embeddings.weight 39 | 40 | def embedding(self, index: int) -> torch.Tensor: 41 | return self.word_embeddings.weight[index : index + 1] 42 | 43 | def forward(self, x: torch.Tensor): 44 | X = self.word_embeddings(x) 45 | X = self.dropout(X) 46 | return X 47 | 48 | 49 | class SinePositionalEmbedding(nn.Module): 50 | def __init__( 51 | self, 52 | dim_model: int, 53 | dropout: float = 0.0, 54 | scale: bool = False, 55 | alpha: bool = False, 56 | ): 57 | super().__init__() 58 | self.dim_model = dim_model 59 | self.x_scale = math.sqrt(dim_model) if scale else 1.0 60 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 61 | self.dropout = torch.nn.Dropout(p=dropout) 62 | 63 | self.reverse = False 64 | self.pe = None 65 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 66 | 67 | def extend_pe(self, x): 68 | """Reset the positional encodings.""" 69 | if self.pe is not None: 70 | if self.pe.size(1) >= x.size(1): 71 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 72 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 73 | return 74 | pe = torch.zeros(x.size(1), self.dim_model) 75 | if self.reverse: 76 | position = torch.arange( 77 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 78 | ).unsqueeze(1) 79 | else: 80 | position = torch.arange( 81 | 0, x.size(1), dtype=torch.float32 82 | ).unsqueeze(1) 83 | div_term = torch.exp( 84 | torch.arange(0, self.dim_model, 2, dtype=torch.float32) 85 | * -(math.log(10000.0) / self.dim_model) 86 | ) 87 | pe[:, 0::2] = torch.sin(position * div_term) 88 | pe[:, 1::2] = torch.cos(position * div_term) 89 | pe = pe.unsqueeze(0) 90 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 91 | 92 | def forward(self, x: torch.Tensor) -> torch.Tensor: 93 | self.extend_pe(x) 94 | output = x.unsqueeze(-1) if x.ndim == 2 else x 95 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] 96 | return self.dropout(output) 97 | -------------------------------------------------------------------------------- /valle/tests/data/tokenizer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Zhao Ming) 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
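# Note: these tests exercise the phonemizer-based backends of
# `valle.data.TextTokenizer`. The "espeak" case relies on the `phonemizer`
# dependency's espeak backend, which typically needs an espeak-ng install on
# the system, and its assertion only compares a prefix of the produced phoneme
# sequence (`phonemized[0][:len(_target)]`).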
15 | 16 | 17 | import unittest 18 | 19 | from valle.data import TextTokenizer 20 | 21 | 22 | class TestTextTokenizer(unittest.TestCase): 23 | def test_espeak(self): 24 | text_tokenizer = TextTokenizer(backend="espeak") 25 | 26 | for (_input, _target) in [ 27 | ("The two parties, the sheep and the wolves, met each other.", 28 | ['ð', 'ə', '_', 't', 'uː', '_', 'p', 'ɑːɹ', 'ɾ',]), # 'i', 'z', ',', '_', 'ð'] 29 | ("Mother! dear father! do you hear me?", 30 | ['m', 'ʌ', 'ð', 'ɚ', '!', '_', 'd', 'ɪɹ', '_', 'f', 'ɑː', 'ð', 'ɚ', '!']), 31 | ("\"Whoever thou art,\" She exclaimed, suddenly seizing Rodolfo's hand,", 32 | ['"', 'h', 'uː', 'ɛ', 'v', 'ɚ', '_', 'ð', 'aʊ', '_', 'ɑːɹ', 't', ',', '"', '_', 'ʃ', 'iː', 33 | '_', 'ɛ', 'k', 's', 'k', 'l', 'eɪ', 'm', 'd', ',', '_', 's', 'ʌ', 'd', 'ə', 'n', 'l', 'i', 34 | '_', 's', 'iː', 'z', 'ɪ', 'ŋ', '_', 'ɹ', 'ə', 'd', 'ɑː', 'l', 'f', 'oʊ', 'z', '_', 'h', 35 | 'æ', 'n', 'd', ',']) 36 | ]: 37 | phonemized = text_tokenizer(_input) 38 | self.assertEqual(phonemized[0][:len(_target)], _target) 39 | 40 | def test_pypinyin(self): 41 | text_tokenizer = TextTokenizer(backend="pypinyin") 42 | 43 | for (_input, _target) in [ 44 | ("你好这是测试", 45 | ["ni3", '-', "hao3", '-', "zhe4", '-', "shi4", '-', "ce4", '-', "shi4"]), 46 | ("\"你好\", 这是测试.", 47 | ["\"", "ni3", '-', "hao3", "\"", ",", '_', "zhe4", '-', "shi4", '-', "ce4", '-', "shi4", "."]), 48 | ("此项 工作 还能 怎么 改进", 49 | ['ci3', '-', 'xiang4', '_', 'gong1', '-', 'zuo4', '_', 50 | 'hai2', '-', 'neng2', '_', 'zen3', '-', 'me5', '_', 'gai3', '-', 'jin4']), # AISHELL 51 | ]: 52 | phonemized = text_tokenizer(_input) 53 | self.assertEqual(phonemized[0], _target) 54 | 55 | def test_pypinyin_initials_finals(self): 56 | text_tokenizer = TextTokenizer(backend="pypinyin_initials_finals") 57 | 58 | for (_input, _target) in [ 59 | ("你好这是测试", 60 | ["n", "i3", "-", "h", "ao3", "-", "zh", "e4", "-", "sh", "i4", "-", "c", "e4", "-", "sh", "i4"], 61 | ), 62 | ("\"你好.这是测试.", 63 | ["\"", "n", "i3", "-", "h", "ao3", ".", "zh", "e4", "-", "sh", "i4", "-", "c", "e4", "-", "sh", "i4", "."], 64 | ), 65 | ("\"你好. 这是测试.", 66 | ["\"", "n", "i3", "-", "h", "ao3", ".", "_", "zh", "e4", "-", "sh", "i4", "-", "c", "e4", "-", "sh", "i4", "."], 67 | ), 68 | ("此项 工作 还能 怎么 改进", ['c', 'i3', '-', 'x', 'iang4', '_', 'g', 'ong1', '-', 'z', 'uo4', '_', 69 | 'h', 'ai2', '-', 'n', 'eng2', '_', 'z', 'en3', '-', 'm', 'e5', '_', 70 | 'g', 'ai3', '-', 'j', 'in4']), # AISHELL 71 | ]: 72 | phonemized = text_tokenizer(_input) 73 | self.assertListEqual(phonemized[0], _target) 74 | 75 | 76 | if __name__ == "__main__": 77 | unittest.main() 78 | -------------------------------------------------------------------------------- /valle/models/visualizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # See ../../../../LICENSE for clarification regarding multiple authors 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
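# Note: `visualize` below saves one PNG per utterance (up to `limit` per batch)
# to `{output_dir}/{utt_id}.png`, stacking three panels: the encoder output,
# the (last) decoder output, and the ground-truth target features. The color
# range is chosen per output type: [0, 1024] for EnCodec code indices and
# [-6, 0] for float32 fbank features; the red vertical line marks the unpadded
# text / audio length.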
17 | 18 | 19 | from typing import Dict, List, Tuple, Union 20 | 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | import torch 24 | 25 | 26 | def visualize( 27 | predicts: Tuple[torch.Tensor], 28 | batch: Dict[str, Union[List, torch.Tensor]], 29 | output_dir: str, 30 | limit: int = 4, 31 | ) -> None: 32 | text_tokens = batch["text_tokens"].to("cpu").detach().numpy() 33 | text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() 34 | audio_features = batch["audio_features"].to("cpu").detach().numpy() 35 | audio_features_lens = ( 36 | batch["audio_features_lens"].to("cpu").detach().numpy() 37 | ) 38 | assert text_tokens.ndim == 2 39 | 40 | utt_ids, texts = batch["utt_id"], batch["text"] 41 | 42 | encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() 43 | decoder_outputs = predicts[1] 44 | if isinstance(decoder_outputs, list): 45 | decoder_outputs = decoder_outputs[-1] 46 | decoder_outputs = ( 47 | decoder_outputs.to("cpu").type(torch.float32).detach().numpy() 48 | ) 49 | 50 | vmin, vmax = 0, 1024 # Encodec 51 | if decoder_outputs.dtype == np.float32: 52 | vmin, vmax = -6, 0 # Fbank 53 | 54 | num_figures = 3 55 | for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): 56 | _ = plt.figure(figsize=(14, 8 * num_figures)) 57 | 58 | S = text_tokens_lens[b] 59 | T = audio_features_lens[b] 60 | 61 | # encoder 62 | plt.subplot(num_figures, 1, 1) 63 | plt.title(f"Text: {text}") 64 | plt.imshow( 65 | X=np.transpose(encoder_outputs[b]), 66 | cmap=plt.get_cmap("jet"), 67 | aspect="auto", 68 | interpolation="nearest", 69 | ) 70 | plt.gca().invert_yaxis() 71 | plt.axvline(x=S - 0.4, linewidth=2, color="r") 72 | plt.xlabel("Encoder Output") 73 | plt.colorbar() 74 | 75 | # decoder 76 | plt.subplot(num_figures, 1, 2) 77 | plt.imshow( 78 | X=np.transpose(decoder_outputs[b]), 79 | cmap=plt.get_cmap("jet"), 80 | aspect="auto", 81 | interpolation="nearest", 82 | vmin=vmin, 83 | vmax=vmax, 84 | ) 85 | plt.gca().invert_yaxis() 86 | plt.axvline(x=T - 0.4, linewidth=2, color="r") 87 | plt.xlabel("Decoder Output") 88 | plt.colorbar() 89 | 90 | # target 91 | plt.subplot(num_figures, 1, 3) 92 | plt.imshow( 93 | X=np.transpose(audio_features[b]), 94 | cmap=plt.get_cmap("jet"), 95 | aspect="auto", 96 | interpolation="nearest", 97 | vmin=vmin, 98 | vmax=vmax, 99 | ) 100 | plt.gca().invert_yaxis() 101 | plt.axvline(x=T - 0.4, linewidth=2, color="r") 102 | plt.xlabel("Decoder Target") 103 | plt.colorbar() 104 | 105 | plt.savefig(f"{output_dir}/{utt_id}.png") 106 | plt.close() 107 | -------------------------------------------------------------------------------- /valle/data/collation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from valle.utils import SymbolTable 8 | 9 | 10 | class TextTokenCollater: 11 | """Collate list of text tokens 12 | 13 | Map sentences to integers. Sentences are padded to equal length. 14 | Beginning and end-of-sequence symbols can be added. 15 | 16 | Example: 17 | >>> token_collater = TextTokenCollater(text_tokens) 18 | >>> tokens_batch, tokens_lens = token_collater(text) 19 | 20 | Returns: 21 | tokens_batch: IntTensor of shape (B, L) 22 | B: batch dimension, number of input sentences 23 | L: length of the longest sentence 24 | tokens_lens: IntTensor of shape (B,) 25 | Length of each sentence after adding and 26 | but before padding. 
27 | """ 28 | 29 | def __init__( 30 | self, 31 | text_tokens: List[str], 32 | add_eos: bool = True, 33 | add_bos: bool = True, 34 | pad_symbol: str = "", 35 | bos_symbol: str = "", 36 | eos_symbol: str = "", 37 | ): 38 | self.pad_symbol = pad_symbol 39 | 40 | self.add_eos = add_eos 41 | self.add_bos = add_bos 42 | 43 | self.bos_symbol = bos_symbol 44 | self.eos_symbol = eos_symbol 45 | 46 | unique_tokens = ( 47 | [pad_symbol] 48 | + ([bos_symbol] if add_bos else []) 49 | + ([eos_symbol] if add_eos else []) 50 | + sorted(text_tokens) 51 | ) 52 | 53 | self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} 54 | self.idx2token = [token for token in unique_tokens] 55 | 56 | def index( 57 | self, tokens_list: List[str] 58 | ) -> Tuple[torch.Tensor, torch.Tensor]: 59 | seqs, seq_lens = [], [] 60 | for tokens in tokens_list: 61 | assert ( 62 | all([True if s in self.token2idx else False for s in tokens]) 63 | is True 64 | ) 65 | seq = ( 66 | ([self.bos_symbol] if self.add_bos else []) 67 | + list(tokens) 68 | + ([self.eos_symbol] if self.add_eos else []) 69 | ) 70 | seqs.append(seq) 71 | seq_lens.append(len(seq)) 72 | 73 | max_len = max(seq_lens) 74 | for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)): 75 | seq.extend([self.pad_symbol] * (max_len - seq_len)) 76 | 77 | tokens = torch.from_numpy( 78 | np.array( 79 | [[self.token2idx[token] for token in seq] for seq in seqs], 80 | dtype=np.int64, 81 | ) 82 | ) 83 | tokens_lens = torch.IntTensor(seq_lens) 84 | 85 | return tokens, tokens_lens 86 | 87 | def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: 88 | tokens_seqs = [[p for p in text] for text in texts] 89 | max_len = len(max(tokens_seqs, key=len)) 90 | 91 | seqs = [ 92 | ([self.bos_symbol] if self.add_bos else []) 93 | + list(seq) 94 | + ([self.eos_symbol] if self.add_eos else []) 95 | + [self.pad_symbol] * (max_len - len(seq)) 96 | for seq in tokens_seqs 97 | ] 98 | 99 | tokens_batch = torch.from_numpy( 100 | np.array( 101 | [[self.token2idx[token] for token in seq] for seq in seqs], 102 | dtype=np.int64, 103 | ) 104 | ) 105 | 106 | tokens_lens = torch.IntTensor( 107 | [ 108 | len(seq) + int(self.add_eos) + int(self.add_bos) 109 | for seq in tokens_seqs 110 | ] 111 | ) 112 | 113 | return tokens_batch, tokens_lens 114 | 115 | 116 | def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: 117 | text_tokens_path = Path(text_tokens_file) 118 | unique_tokens = SymbolTable.from_file(text_tokens_path) 119 | collater = TextTokenCollater( 120 | unique_tokens.symbols, add_bos=True, add_eos=True 121 | ) 122 | return collater 123 | 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Language Modeling Approach to Diacritic-Free Hebrew TTS (Interspeech 2024) 2 | 3 | Inference code and model weights for the paper "A Language Modeling Approach to Diacritic-Free Hebrew TTS" (Interspeech 4 | 2024). 5 | 6 |

7 | 8 | 9 | 10 | 11 | 12 |
13 |
14 | ![](imgs/model.jpg)
15 |
16 | ___
17 | **Abstract:** We tackle the task of text-to-speech (TTS) in Hebrew. Traditional Hebrew contains Diacritics (`Niqqud'),
18 | which dictate the way individuals should pronounce given words, however, modern Hebrew rarely uses them. The lack of
19 | diacritics in modern Hebrew results in readers expected to conclude the correct pronunciation and understand which
20 | phonemes to use based on the context. This imposes a fundamental challenge on TTS systems to accurately map between
21 | text-to-speech. In this study, we propose to adopt a language modeling Diacritics-Free TTS approach, for the task of
22 | Hebrew TTS. The language model (LM) operates on discrete speech representations and is conditioned on a word-piece
23 | tokenizer. We optimize the proposed method using in-the-wild weakly supervised recordings and compare it to several
24 | diacritic based Hebrew TTS systems. Results suggest the proposed method is superior to the evaluated baselines
25 | considering both content preservation and naturalness of the generated speech.
26 |
27 | ## Try it out!
28 | You can try our model in the [Google Colab](https://colab.research.google.com/drive/1f3-6Dqbna9_hI5C9V4qTIG05dixW-r72?usp=sharing) demo.
29 |
30 | ## Installation
31 |
32 | ```bash
33 | git clone https://github.com/slp-rl/HebTTS.git
34 | ```
35 |
36 | We publish our checkpoint
37 | on [Google Drive](https://drive.google.com/file/d/11NoOJzMLRX9q1C_Q4sX0w2b9miiDjGrv/view?usp=share_link).
38 | The AR model was trained for 1.2M steps and the NAR model for 200K steps on [HebDB](https://pages.cs.huji.ac.il/adiyoss-lab/HebDB/).
39 |
40 | ```bash
41 | pip install uv
42 | uv sync
43 | uv run gdown 11NoOJzMLRX9q1C_Q4sX0w2b9miiDjGrv
44 | ```
45 |
46 | ## Inference
47 |
48 | You can play with the model using different speakers and text prompts.
49 |
50 | ```
51 | uv run infer.py --checkpoint checkpoint.pt --output-dir ./out --text "היי מה קורה"
52 | ```
53 |
54 | You can specify the additional arguments
55 | `--speaker` and `--top-k`.
56 |
57 | ## Create multiple samples from a CSV
58 |
59 | ```console
60 | uv run infer.py --checkpoint checkpoint.pt --output-dir ./out --csv_path ./example.csv
61 | ```
62 |
63 | ### Multi Band Diffusion
64 |
65 | > [!TIP]
66 | > We allow using the new Multi Band Diffusion (MBD) vocoder for generating better quality audio.
67 | > Install audiocraft and set the `--mbd True` flag.
68 |
69 | ### Text
70 |
71 | You can concatenate text prompts using `|`, or specify the path of a text file separated by `\n`, if writing Hebrew in
72 | the terminal is inconvenient.
73 |
74 | ```text
75 | תגידו גנבו לכם פעם את האוטו ופשוט ידעתם שאין טעם להגיש תלונה במשטרה
76 | היי מה קורה
77 | בראשית היתה חללית מסוג נחתת
78 | ```
79 |
80 | and run
81 |
82 | ```
83 | uv run python infer.py --checkpoint checkpoint.pt --output-dir ./out --text example.txt
84 | ```
85 |
86 | ### Speakers
87 |
88 | You can use the speakers defined in `speakers.yaml`, or add additional speakers by
89 | specifying wav files and their transcriptions in the same format.
90 |
91 | ```
92 | --speaker shaul
93 | ```
94 |
95 | ## Citation
96 |
97 | ```bibtex
98 | @article{roth2024language,
99 |   title={A Language Modeling Approach to Diacritic-Free Hebrew TTS},
100 |   author={Roth, Amit and Turetzky, Arnon and Adi, Yossi},
101 |   journal={arXiv preprint arXiv:2407.12206},
102 |   year={2024}
103 | }
104 | ```
105 |
106 | ## Acknowledgments
107 | - Model code inside `valle` is based on the implementation of [Feiteng Li](https://github.com/lifeiteng/vall-e).
108 | -------------------------------------------------------------------------------- /valle/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """ 18 | modified from lhoste.dataset.speech_synthesis.py 19 | """ 20 | 21 | from typing import Callable, Dict, List, Sequence, Union 22 | 23 | import torch 24 | from lhotse import validate 25 | from lhotse.cut import CutSet 26 | from lhotse.dataset.collation import collate_audio 27 | from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures 28 | from lhotse.utils import ifnone 29 | 30 | from valle.data.collation import TextTokenCollater 31 | 32 | 33 | class SpeechSynthesisDataset(torch.utils.data.Dataset): 34 | """ 35 | The PyTorch Dataset for the speech synthesis(e.g. TTS) task. 36 | Each item in this dataset is a dict of: 37 | 38 | .. code-block:: 39 | 40 | { 41 | 'audio': (B x NumSamples) float tensor 42 | 'audio_lens': (B, ) int tensor 43 | 'text': str 44 | 'audio_features': (B x NumFrames x NumFeatures) float tensor 45 | 'audio_features_lens': (B, ) int tensor 46 | 'text_tokens': (B x NumTextTokens) long tensor 47 | 'text_tokens_lens': (B, ) int tensor 48 | } 49 | """ 50 | 51 | def __init__( 52 | self, 53 | text_token_collater: TextTokenCollater, 54 | cut_transforms: List[Callable[[CutSet], CutSet]] = None, 55 | feature_input_strategy: BatchIO = PrecomputedFeatures(), 56 | feature_transforms: Union[Sequence[Callable], Callable] = None, 57 | token = "text", 58 | ) -> None: 59 | super().__init__() 60 | 61 | self.text_token_collater = text_token_collater 62 | self.cut_transforms = ifnone(cut_transforms, []) 63 | self.feature_input_strategy = feature_input_strategy 64 | self.token = token 65 | 66 | if feature_transforms is None: 67 | feature_transforms = [] 68 | elif not isinstance(feature_transforms, Sequence): 69 | feature_transforms = [feature_transforms] 70 | 71 | assert all( 72 | isinstance(transform, Callable) for transform in feature_transforms 73 | ), "Feature transforms must be Callable" 74 | self.feature_transforms = feature_transforms 75 | 76 | def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: 77 | validate_for_tts(cuts) 78 | 79 | for transform in self.cut_transforms: 80 | cuts = transform(cuts) 81 | 82 | if False: # not used 83 | audio, audio_lens = collate_audio(cuts) 84 | else: # for sharing tokenized features in different machines 85 | audio, audio_lens = None, None 86 | 87 | audio_features, audio_features_lens = self.feature_input_strategy(cuts) 88 | 89 | for transform in self.feature_transforms: 90 | audio_features = transform(audio_features) 91 | 92 | text_tokens, text_tokens_lens = self.text_token_collater( 93 | [cut.supervisions[0].custom["tokens"][self.token] for cut in cuts] 94 | ) 95 | 96 | return { 97 | "utt_id": [cut.id for cut in 
cuts], 98 | "text": [cut.supervisions[0].text for cut in cuts], 99 | "audio": audio, 100 | "audio_lens": audio_lens, 101 | "audio_features": audio_features, 102 | "audio_features_lens": audio_features_lens, 103 | "text_tokens": text_tokens, 104 | "text_tokens_lens": text_tokens_lens, 105 | } 106 | 107 | 108 | def validate_for_tts(cuts: CutSet) -> None: 109 | validate(cuts) 110 | for cut in cuts: 111 | assert ( 112 | len(cut.supervisions) == 1 113 | ), "Only the Cuts with single supervision are supported." 114 | -------------------------------------------------------------------------------- /valle/models/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.nn as nn 4 | 5 | from .macros import ( 6 | NUM_AUDIO_TOKENS, 7 | NUM_MEL_BINS, 8 | NUM_SPEAKER_CLASSES, 9 | NUM_TEXT_TOKENS, 10 | SPEAKER_EMBEDDING_DIM, 11 | ) 12 | from .transformer import Transformer 13 | from .valle import VALLE, VALLF, VALLE_ALEPHBERT, VALLE_ALEPHBERT_CONCAT 14 | from .visualizer import visualize 15 | 16 | 17 | def add_model_arguments(parser: argparse.ArgumentParser): 18 | parser.add_argument( 19 | "--model-name", 20 | type=str, 21 | default="VALL-E", 22 | help="VALL-E, VALL-F, Transformer.", 23 | ) 24 | parser.add_argument( 25 | "--decoder-dim", 26 | type=int, 27 | default=1024, 28 | help="Embedding dimension in the decoder model.", 29 | ) 30 | parser.add_argument( 31 | "--nhead", 32 | type=int, 33 | default=16, 34 | help="Number of attention heads in the Decoder layers.", 35 | ) 36 | parser.add_argument( 37 | "--num-decoder-layers", 38 | type=int, 39 | default=12, 40 | help="Number of Decoder layers.", 41 | ) 42 | parser.add_argument( 43 | "--scale-factor", 44 | type=float, 45 | default=1.0, 46 | help="Model scale factor which will be assigned different meanings in different models.", 47 | ) 48 | parser.add_argument( 49 | "--norm-first", 50 | default=True, 51 | help="Pre or Post Normalization.", 52 | ) 53 | parser.add_argument( 54 | "--add-prenet", 55 | default=False, 56 | help="Whether add PreNet after Inputs.", 57 | ) 58 | 59 | # VALL-E & F 60 | parser.add_argument( 61 | "--prefix-mode", 62 | type=int, 63 | default=0, 64 | help="The mode for how to prefix VALL-E NAR Decoder, " 65 | "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.", 66 | ) 67 | parser.add_argument( 68 | "--share-embedding", 69 | default=True, 70 | help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.", 71 | ) 72 | parser.add_argument( 73 | "--prepend-bos", 74 | default=False, 75 | help="Whether prepend to the acoustic tokens -> AR Decoder inputs.", 76 | ) 77 | parser.add_argument( 78 | "--num-quantizers", 79 | type=int, 80 | default=8, 81 | help="Number of Audio/Semantic quantization layers.", 82 | ) 83 | 84 | parser.add_argument( 85 | "--num-text-tokens", 86 | type=int, 87 | default=512, 88 | help="number of text tokens for token embedding" 89 | ) 90 | 91 | # Transformer 92 | parser.add_argument( 93 | "--scaling-xformers", 94 | type=str2bool, 95 | default=False, 96 | help="Apply Reworked Conformer scaling on Transformers.", 97 | ) 98 | 99 | 100 | def get_model(params) -> nn.Module: 101 | if params.model_name.lower() in ["vall-f", "vallf"]: 102 | model = VALLF( 103 | params.decoder_dim, 104 | params.nhead, 105 | params.num_decoder_layers, 106 | norm_first=params.norm_first, 107 | add_prenet=params.add_prenet, 108 | prefix_mode=params.prefix_mode, 109 | 
share_embedding=params.share_embedding, 110 | nar_scale_factor=params.scale_factor, 111 | prepend_bos=params.prepend_bos, 112 | num_quantizers=params.num_quantizers, 113 | num_text_tokens=params.num_text_tokens, 114 | ) 115 | elif params.model_name.lower() in ["vall-e", "valle"]: 116 | model = VALLE( 117 | params.decoder_dim, 118 | params.nhead, 119 | params.num_decoder_layers, 120 | norm_first=params.norm_first, 121 | add_prenet=params.add_prenet, 122 | prefix_mode=params.prefix_mode, 123 | share_embedding=params.share_embedding, 124 | nar_scale_factor=params.scale_factor, 125 | prepend_bos=params.prepend_bos, 126 | num_quantizers=params.num_quantizers, 127 | num_text_tokens=params.num_text_tokens 128 | ) 129 | elif params.model_name.lower() in ["valle-alephbert", "alephbert"]: 130 | model = VALLE_ALEPHBERT( 131 | params.decoder_dim, 132 | params.nhead, 133 | params.num_decoder_layers, 134 | norm_first=params.norm_first, 135 | add_prenet=params.add_prenet, 136 | prefix_mode=params.prefix_mode, 137 | share_embedding=params.share_embedding, 138 | nar_scale_factor=params.scale_factor, 139 | prepend_bos=params.prepend_bos, 140 | num_quantizers=params.num_quantizers, 141 | ) 142 | 143 | elif params.model_name.lower() in ["valle-alephbert-concat", "alephbert-concat"]: 144 | model = VALLE_ALEPHBERT_CONCAT( 145 | params.decoder_dim, 146 | params.nhead, 147 | params.num_decoder_layers, 148 | norm_first=params.norm_first, 149 | add_prenet=params.add_prenet, 150 | prefix_mode=params.prefix_mode, 151 | share_embedding=params.share_embedding, 152 | nar_scale_factor=params.scale_factor, 153 | prepend_bos=params.prepend_bos, 154 | num_quantizers=params.num_quantizers, 155 | num_text_tokens=params.num_text_tokens 156 | ) 157 | else: 158 | assert params.model_name in ["Transformer"] 159 | model = Transformer( 160 | params.decoder_dim, 161 | params.nhead, 162 | params.num_decoder_layers, 163 | norm_first=params.norm_first, 164 | add_prenet=params.add_prenet, 165 | scaling_xformers=params.scaling_xformers, 166 | ) 167 | 168 | return model 169 | -------------------------------------------------------------------------------- /valle/data/input_strategies.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from concurrent.futures import ThreadPoolExecutor 4 | from typing import Tuple, Type 5 | 6 | from lhotse import CutSet 7 | from lhotse.dataset.collation import collate_features 8 | from lhotse.dataset.input_strategies import ( 9 | ExecutorType, 10 | PrecomputedFeatures, 11 | _get_executor, 12 | ) 13 | from lhotse.utils import fastcopy 14 | 15 | 16 | class PromptedFeatures: 17 | def __init__(self, prompts, features): 18 | self.prompts = prompts 19 | self.features = features 20 | 21 | def to(self, device): 22 | return PromptedFeatures( 23 | self.prompts.to(device), self.features.to(device) 24 | ) 25 | 26 | def sum(self): 27 | return self.features.sum() 28 | 29 | @property 30 | def ndim(self): 31 | return self.features.ndim 32 | 33 | @property 34 | def data(self): 35 | return (self.prompts, self.features) 36 | 37 | 38 | class PromptedPrecomputedFeatures(PrecomputedFeatures): 39 | """ 40 | :class:`InputStrategy` that reads pre-computed features, whose manifests 41 | are attached to cuts, from disk. 42 | 43 | It automatically pads the feature matrices with pre or post feature. 44 | 45 | .. 
automethod:: __call__ 46 | """ 47 | 48 | def __init__( 49 | self, 50 | dataset: str, 51 | cuts: CutSet, 52 | num_workers: int = 0, 53 | executor_type: Type[ExecutorType] = ThreadPoolExecutor, 54 | ) -> None: 55 | super(PromptedPrecomputedFeatures, self).__init__( 56 | num_workers, executor_type 57 | ) 58 | 59 | self.utt2neighbors = defaultdict(lambda: []) 60 | 61 | if dataset.lower() == "libritts": 62 | # 909_131041_000013_000002 63 | # 909_131041_000013_000003 64 | speaker2utts = defaultdict(lambda: []) 65 | 66 | utt2cut = {} 67 | for cut in cuts: 68 | speaker = cut.supervisions[0].speaker 69 | speaker2utts[speaker].append(cut.id) 70 | utt2cut[cut.id] = cut 71 | 72 | for spk in speaker2utts: 73 | uttids = sorted(speaker2utts[spk]) 74 | # Using the property of sorted keys to find previous utterance 75 | # The keys has structure speaker_book_x_y e.g. 1089_134691_000004_000001 76 | if len(uttids) == 1: 77 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]]) 78 | continue 79 | 80 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1])) 81 | utt2postutt = dict(zip(uttids[:-1], uttids[1:])) 82 | 83 | for utt in utt2prevutt: 84 | self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]]) 85 | 86 | for utt in utt2postutt: 87 | self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]]) 88 | elif dataset.lower() == "ljspeech": 89 | utt2cut = {} 90 | uttids = [] 91 | for cut in cuts: 92 | uttids.append(cut.id) 93 | utt2cut[cut.id] = cut 94 | 95 | if len(uttids) == 1: 96 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]]) 97 | else: 98 | # Using the property of sorted keys to find previous utterance 99 | # The keys has structure: LJ001-0010 100 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1])) 101 | utt2postutt = dict(zip(uttids[:-1], uttids[1:])) 102 | 103 | for utt in utt2postutt: 104 | postutt = utt2postutt[utt] 105 | if utt[:5] == postutt[:5]: 106 | self.utt2neighbors[utt].append(utt2cut[postutt]) 107 | 108 | for utt in utt2prevutt: 109 | prevutt = utt2prevutt[utt] 110 | if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]: 111 | self.utt2neighbors[utt].append(utt2cut[prevutt]) 112 | else: 113 | raise ValueError 114 | 115 | def __call__( 116 | self, cuts: CutSet 117 | ) -> Tuple[PromptedFeatures, PromptedFeatures]: 118 | """ 119 | Reads the pre-computed features from disk/other storage. 120 | The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``. 121 | 122 | :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding. 
123 | """ 124 | features, features_lens = collate_features( 125 | cuts, 126 | executor=_get_executor( 127 | self.num_workers, executor_type=self._executor_type 128 | ), 129 | ) 130 | 131 | prompts_cuts = [] 132 | for k, cut in enumerate(cuts): 133 | prompts_cut = random.choice(self.utt2neighbors[cut.id]) 134 | prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}")) 135 | 136 | mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0]) 137 | # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate( 138 | # max_duration=mini_duration, 139 | # offset_type="random", 140 | # preserve_id=True, 141 | # ) 142 | prompts_cuts = CutSet( 143 | cuts={k: cut for k, cut in enumerate(prompts_cuts)} 144 | ).truncate( 145 | max_duration=mini_duration, 146 | offset_type="random", 147 | preserve_id=False, 148 | ) 149 | 150 | prompts, prompts_lens = collate_features( 151 | prompts_cuts, 152 | executor=_get_executor( 153 | self.num_workers, executor_type=self._executor_type 154 | ), 155 | ) 156 | 157 | return PromptedFeatures(prompts, features), PromptedFeatures( 158 | prompts_lens, features_lens 159 | ) 160 | -------------------------------------------------------------------------------- /valle/data/fbank.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
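# Note on the feature configuration below: `BigVGANFbank` targets the NVIDIA
# BigVGAN `bigvgan_24khz_100band` setup -- 24 kHz audio, a 1024-sample window
# and FFT with a 256-sample hop (frame_shift = 256 / 24000 s), and 100 mel
# bins spanning 0-12 kHz. Magnitudes are log-compressed via
# `dynamic_range_compression_torch` (log of the mel spectrogram clamped at
# 1e-5); the same features are displayed with a [-6, 0] range in
# valle/models/visualizer.py.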
16 | 17 | 18 | from dataclasses import asdict, dataclass 19 | from typing import Any, Dict, Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | from lhotse.features.base import FeatureExtractor 24 | from lhotse.utils import EPSILON, Seconds, compute_num_frames 25 | from librosa.filters import mel as librosa_mel_fn 26 | 27 | 28 | @dataclass 29 | class BigVGANFbankConfig: 30 | # Spectogram-related part 31 | # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them 32 | frame_length: Seconds = 1024 / 24000.0 33 | frame_shift: Seconds = 256 / 24000.0 34 | remove_dc_offset: bool = True 35 | round_to_power_of_two: bool = True 36 | 37 | # Fbank-related part 38 | low_freq: float = 0.0 39 | high_freq: float = 12000.0 40 | num_mel_bins: int = 100 41 | use_energy: bool = False 42 | 43 | def to_dict(self) -> Dict[str, Any]: 44 | return asdict(self) 45 | 46 | @staticmethod 47 | def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig": 48 | return BigVGANFbankConfig(**data) 49 | 50 | 51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 52 | return torch.log(torch.clamp(x, min=clip_val) * C) 53 | 54 | 55 | def spectral_normalize_torch(magnitudes): 56 | output = dynamic_range_compression_torch(magnitudes) 57 | return output 58 | 59 | 60 | # https://github.com/NVIDIA/BigVGAN 61 | # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz 62 | class BigVGANFbank(FeatureExtractor): 63 | name = "fbank" 64 | config_type = BigVGANFbankConfig 65 | 66 | def __init__(self, config: Optional[Any] = None): 67 | super(BigVGANFbank, self).__init__(config) 68 | sampling_rate = 24000 69 | self.mel_basis = torch.from_numpy( 70 | librosa_mel_fn( 71 | sampling_rate, 72 | 1024, 73 | self.config.num_mel_bins, 74 | self.config.low_freq, 75 | self.config.high_freq, 76 | ).astype(np.float32) 77 | ) 78 | self.hann_window = torch.hann_window(1024) 79 | 80 | def _feature_fn(self, samples, **kwargs): 81 | win_length, n_fft = 1024, 1024 82 | hop_size = 256 83 | if True: 84 | sampling_rate = 24000 85 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12) 86 | expected_num_frames = compute_num_frames( 87 | duration=duration, 88 | frame_shift=self.frame_shift, 89 | sampling_rate=sampling_rate, 90 | ) 91 | pad_size = ( 92 | (expected_num_frames - 1) * hop_size 93 | + win_length 94 | - samples.shape[-1] 95 | ) 96 | assert pad_size >= 0 97 | 98 | y = torch.nn.functional.pad( 99 | samples, 100 | (0, pad_size), 101 | mode="constant", 102 | ) 103 | else: 104 | y = torch.nn.functional.pad( 105 | samples, 106 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 107 | mode="reflect", 108 | ) 109 | 110 | y = y.squeeze(1) 111 | 112 | # complex tensor as default, then use view_as_real for future pytorch compatibility 113 | spec = torch.stft( 114 | y, 115 | n_fft, 116 | hop_length=hop_size, 117 | win_length=win_length, 118 | window=self.hann_window, 119 | center=False, 120 | pad_mode="reflect", 121 | normalized=False, 122 | onesided=True, 123 | return_complex=True, 124 | ) 125 | spec = torch.view_as_real(spec) 126 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 127 | 128 | spec = torch.matmul(self.mel_basis, spec) 129 | spec = spectral_normalize_torch(spec) 130 | 131 | return spec.transpose(2, 1).squeeze(0) 132 | 133 | def extract( 134 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int 135 | ) -> np.ndarray: 136 | assert sampling_rate == 24000 137 | params = asdict(self.config) 138 | 
params.update({"sample_frequency": sampling_rate, "snip_edges": False}) 139 | params["frame_shift"] *= 1000.0 140 | params["frame_length"] *= 1000.0 141 | if not isinstance(samples, torch.Tensor): 142 | samples = torch.from_numpy(samples) 143 | # Torchaudio Kaldi feature extractors expect the channel dimension to be first. 144 | if len(samples.shape) == 1: 145 | samples = samples.unsqueeze(0) 146 | features = self._feature_fn(samples, **params).to(torch.float32) 147 | return features.numpy() 148 | 149 | @property 150 | def frame_shift(self) -> Seconds: 151 | return self.config.frame_shift 152 | 153 | def feature_dim(self, sampling_rate: int) -> int: 154 | return self.config.num_mel_bins 155 | 156 | @staticmethod 157 | def mix( 158 | features_a: np.ndarray, 159 | features_b: np.ndarray, 160 | energy_scaling_factor_b: float, 161 | ) -> np.ndarray: 162 | return np.log( 163 | np.maximum( 164 | # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0) 165 | EPSILON, 166 | np.exp(features_a) 167 | + energy_scaling_factor_b * np.exp(features_b), 168 | ) 169 | ) 170 | 171 | @staticmethod 172 | def compute_energy(features: np.ndarray) -> float: 173 | return float(np.sum(np.exp(features))) 174 | 175 | 176 | def get_fbank_extractor() -> BigVGANFbank: 177 | return BigVGANFbank(BigVGANFbankConfig()) 178 | 179 | 180 | if __name__ == "__main__": 181 | extractor = BigVGANFbank(BigVGANFbankConfig()) 182 | 183 | samples = torch.from_numpy(np.random.random([1000]).astype(np.float32)) 184 | samples = torch.clip(samples, -1.0, 1.0) 185 | fbank = extractor.extract(samples, 24000.0) 186 | print(f"fbank {fbank.shape}") 187 | 188 | from scipy.io.wavfile import read 189 | 190 | MAX_WAV_VALUE = 32768.0 191 | 192 | sampling_rate, samples = read( 193 | "egs/libritts/prompts/5639_40744_000000_000002.wav" 194 | ) 195 | print(f"samples: [{samples.min()}, {samples.max()}]") 196 | fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000) 197 | print(f"fbank {fbank.shape}") 198 | 199 | import matplotlib.pyplot as plt 200 | 201 | _ = plt.figure(figsize=(18, 10)) 202 | plt.imshow( 203 | X=fbank.transpose(1, 0), 204 | cmap=plt.get_cmap("jet"), 205 | aspect="auto", 206 | interpolation="nearest", 207 | ) 208 | plt.gca().invert_yaxis() 209 | plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png") 210 | plt.close() 211 | 212 | print("fbank test PASS!") 213 | -------------------------------------------------------------------------------- /valle/bin/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Phonemize Text and EnCodec Audio. 17 | 18 | Usage example: 19 | python3 bin/infer.py \ 20 | --decoder-dim 128 --nhead 4 --num-decoder-layers 4 --model-name valle \ 21 | --text-prompts "Go to her." 
\ 22 | --audio-prompts ./prompts/61_70970_000007_000001.wav \ 23 | --output-dir infer/demo_valle_epoch20 \ 24 | --checkpoint exp/valle_nano_v2/epoch-20.pt 25 | 26 | """ 27 | import argparse 28 | import logging 29 | import os 30 | import sys 31 | from pathlib import Path 32 | 33 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" 34 | 35 | print(sys.path) 36 | 37 | import torch 38 | import torchaudio 39 | from icefall.utils import AttributeDict, str2bool 40 | 41 | from valle.data import ( 42 | AudioTokenizer, 43 | DacAudioTokenizer, 44 | TextTokenizer, 45 | tokenize_audio, 46 | tokenize_text, 47 | ) 48 | from valle.data.collation import get_text_token_collater 49 | from valle.models import get_model 50 | 51 | 52 | def get_args(): 53 | parser = argparse.ArgumentParser() 54 | 55 | parser.add_argument( 56 | "--text-prompts", 57 | type=str, 58 | default="", 59 | help="Text prompts which are separated by |.", 60 | ) 61 | 62 | parser.add_argument( 63 | "--audio-prompts", 64 | type=str, 65 | default="", 66 | help="Audio prompts which are separated by | and should be aligned with --text-prompts.", 67 | ) 68 | 69 | parser.add_argument( 70 | "--text", 71 | type=str, 72 | default="To get up and running quickly just follow the steps below.", 73 | help="Text to be synthesized.", 74 | ) 75 | 76 | # model 77 | # add_model_arguments(parser) 78 | # parser.add_argument( 79 | # "--text-tokens", 80 | # type=str, 81 | # default="data/tokenized/unique_text_tokens.k2symbols", 82 | # help="Path to the unique text tokens file.", 83 | # ) 84 | 85 | parser.add_argument( 86 | "--audio-tokenizer", 87 | type=str, 88 | default="encodec", 89 | help="Path to the unique text tokens file.", 90 | ) 91 | 92 | parser.add_argument( 93 | "--text-extractor", 94 | type=str, 95 | default="espeak", 96 | help="espeak or pypinyin or pypinyin_initials_finals", 97 | ) 98 | 99 | parser.add_argument( 100 | "--checkpoint", 101 | type=str, 102 | default="exp/vallf_nano_full/checkpoint-100000.pt", 103 | help="Path to the saved checkpoint.", 104 | ) 105 | 106 | parser.add_argument( 107 | "--output-dir", 108 | type=Path, 109 | default=Path("infer/demo"), 110 | help="Path to the tokenized files.", 111 | ) 112 | 113 | parser.add_argument( 114 | "--top-k", 115 | type=int, 116 | default=-100, 117 | help="Whether AR Decoder do top_k(if > 0) sampling.", 118 | ) 119 | 120 | parser.add_argument( 121 | "--temperature", 122 | type=float, 123 | default=1.0, 124 | help="The temperature of AR Decoder top_k sampling.", 125 | ) 126 | 127 | parser.add_argument( 128 | "--continual", 129 | type=str2bool, 130 | default=False, 131 | help="Do continual task.", 132 | ) 133 | 134 | parser.add_argument( 135 | "--num-text-tokens", 136 | type=int, 137 | default="512", 138 | help="token type", 139 | ) 140 | 141 | return parser.parse_args() 142 | 143 | 144 | def load_model(checkpoint, device): 145 | if not checkpoint: 146 | return None 147 | 148 | checkpoint = torch.load(checkpoint, map_location=device) 149 | 150 | args = AttributeDict(checkpoint) 151 | # args.num_text_tokens = 512 152 | model = get_model(args) 153 | 154 | missing_keys, unexpected_keys = model.load_state_dict( 155 | checkpoint["model"], strict=True 156 | ) 157 | assert not missing_keys 158 | model.to(device) 159 | model.eval() 160 | 161 | text_tokens = args.text_tokens 162 | 163 | return model, text_tokens 164 | 165 | 166 | @torch.no_grad() 167 | def main(): 168 | args = get_args() 169 | text_tokenizer = TextTokenizer(backend=args.text_extractor) 170 | 171 | device = torch.device("cpu") 172 
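# --- Editor's note: --top-k and --temperature above steer sampling in the AR decoder
# (a value <= 0 disables the top-k filter, hence the -100 default). The function
# below is a generic reference implementation of that scheme, not the repository's
# own sampling code. ---
import torch

def sample_top_k(logits: torch.Tensor, top_k: int = -100, temperature: float = 1.0) -> torch.Tensor:
    # logits: (..., vocab_size); returns sampled token ids of shape (..., 1).
    logits = logits / max(temperature, 1e-5)
    if top_k > 0:
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    flat = probs.reshape(-1, probs.shape[-1])
    return torch.multinomial(flat, num_samples=1).reshape(*probs.shape[:-1], 1)

ids = sample_top_k(torch.randn(2, 1024), top_k=40, temperature=1.0)   # -> (2, 1) token ids
# --- end editor's note ---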
| if torch.cuda.is_available(): 173 | device = torch.device("cuda", 0) 174 | model, text_tokens = load_model(args.checkpoint, device) 175 | text_tokens = "/cs/labs/adiyoss/amitroth/valle/examples/libritts/data/tokenized/unique_text_tokens.k2symbols" 176 | text_collater = get_text_token_collater(text_tokens) 177 | 178 | if args.audio_tokenizer == "encodec": 179 | print("using encodec") 180 | audio_tokenizer = AudioTokenizer() 181 | else: 182 | assert args.audio_tokenizer == "dac" 183 | print("using dac") 184 | audio_tokenizer = DacAudioTokenizer() 185 | 186 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 187 | 188 | text_prompts = " ".join(args.text_prompts.split("|")) 189 | 190 | audio_prompts = [] 191 | if args.audio_prompts: 192 | for n, audio_file in enumerate(args.audio_prompts.split("|")): 193 | encoded_frames = tokenize_audio(audio_tokenizer, audio_file) 194 | if False: 195 | samples = audio_tokenizer.decode(encoded_frames) 196 | torchaudio.save( 197 | f"{args.output_dir}/p{n}.wav", samples[0], 24000 198 | ) 199 | 200 | audio_prompts.append(encoded_frames[0][0]) 201 | 202 | assert len(args.text_prompts.split("|")) == len(audio_prompts) 203 | audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) 204 | audio_prompts = audio_prompts.to(device) 205 | 206 | if os.path.isfile(args.text): # for demos 207 | # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py 208 | with open(args.text) as f: 209 | for line in f: 210 | fields = line.strip().split("\t") 211 | assert len(fields) == 4 212 | prompt_text, prompt_audio, text, audio_path = fields 213 | logging.info(f"synthesize text: {text}") 214 | text_tokens, text_tokens_lens = text_collater( 215 | [ 216 | tokenize_text( 217 | text_tokenizer, text=f"{prompt_text} {text}".strip() 218 | ) 219 | ] 220 | ) 221 | _, enroll_x_lens = text_collater( 222 | [ 223 | tokenize_text( 224 | text_tokenizer, text=f"{prompt_text}".strip() 225 | ) 226 | ] 227 | ) 228 | 229 | audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio) 230 | audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device) 231 | 232 | # synthesis 233 | encoded_frames = model.inference( 234 | text_tokens.to(device), 235 | text_tokens_lens.to(device), 236 | audio_prompts, 237 | enroll_x_lens=enroll_x_lens, 238 | top_k=args.top_k, 239 | temperature=args.temperature, 240 | ) 241 | 242 | samples = audio_tokenizer.decode( 243 | [(encoded_frames.transpose(2, 1), None)] 244 | ) 245 | # store 246 | torchaudio.save(audio_path, samples[0].cpu(), 24000) 247 | return 248 | 249 | for n, text in enumerate(args.text.split("|")): 250 | logging.info(f"synthesize text: {text}") 251 | text_tokens, text_tokens_lens = text_collater( 252 | [ 253 | tokenize_text( 254 | text_tokenizer, text=f"{text_prompts} {text}".strip() 255 | ) 256 | ] 257 | ) 258 | 259 | # synthesis 260 | if args.continual: 261 | assert text == "" 262 | encoded_frames = model.continual( 263 | text_tokens.to(device), 264 | text_tokens_lens.to(device), 265 | audio_prompts, 266 | ) 267 | else: 268 | enroll_x_lens = None 269 | if text_prompts: 270 | _, enroll_x_lens = text_collater( 271 | [ 272 | tokenize_text( 273 | text_tokenizer, text=f"{text_prompts}".strip() 274 | ) 275 | ] 276 | ) 277 | encoded_frames = model.inference( 278 | text_tokens.to(device), 279 | text_tokens_lens.to(device), 280 | audio_prompts, 281 | enroll_x_lens=enroll_x_lens, 282 | top_k=args.top_k, 283 | temperature=args.temperature, 284 | ) 285 | 286 | if audio_prompts != []: 287 | samples = audio_tokenizer.decode( 288 | 
[(encoded_frames.transpose(2, 1), None)] 289 | ) 290 | # store 291 | torchaudio.save( 292 | f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000 293 | ) 294 | else: # Transformer 295 | pass 296 | 297 | 298 | torch.set_num_threads(1) 299 | torch.set_num_interop_threads(1) 300 | torch._C._jit_set_profiling_executor(False) 301 | torch._C._jit_set_profiling_mode(False) 302 | torch._C._set_graph_executor_optimize(False) 303 | if __name__ == "__main__": 304 | formatter = ( 305 | "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" 306 | ) 307 | logging.basicConfig(format=formatter, level=logging.INFO) 308 | main() 309 | -------------------------------------------------------------------------------- /valle/tests/valle_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import unittest 17 | 18 | import numpy as np 19 | import torch 20 | from icefall.utils import AttributeDict 21 | from torchmetrics.classification import MulticlassAccuracy 22 | 23 | from valle.data.input_strategies import PromptedFeatures 24 | from valle.models import NUM_MEL_BINS, get_model 25 | 26 | 27 | class TestModel(unittest.TestCase): 28 | @classmethod 29 | def setUpClass(cls): 30 | cls.devices = [torch.device("cpu")] 31 | if torch.cuda.is_available(): 32 | cls.devices.append(torch.device("cuda", 0)) 33 | if torch.cuda.device_count() > 1: 34 | torch.cuda.set_device(1) 35 | cls.devices.append(torch.device("cuda", 1)) 36 | 37 | def test_vallf(self): 38 | params = AttributeDict() 39 | params.decoder_dim = 64 40 | params.nhead = 16 41 | params.num_decoder_layers = 4 42 | 43 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 44 | x_lens = torch.from_numpy(np.random.randint(6, 8, size=[4])) 45 | x_lens[-1] = 8 46 | enroll_x_lens = torch.from_numpy(np.random.randint(2, 4, size=[4])) 47 | 48 | y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8])) 49 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 50 | y_lens[-1] = 16 51 | 52 | params.norm_first = True 53 | params.add_prenet = False 54 | params.model_name = "VALL-F" 55 | params.share_embedding = True 56 | params.scale_factor = 1.0 57 | params.prepend_bos = True 58 | params.num_quantizers = 1 59 | 60 | for device in self.devices: 61 | for mode in [0, 1, 2]: 62 | params.prefix_mode = mode 63 | # VALL-E 64 | model = get_model(params) 65 | 66 | # VALL-F 67 | model.to(device) 68 | x = x.to(device) 69 | x_lens = x_lens.to(device) 70 | y = y.to(device) 71 | y_lens = y_lens.to(device) 72 | 73 | # Training 74 | for train_stage in [0, 1, 2]: 75 | codes, loss, metrics = model( 76 | x, x_lens, y, y_lens, train_stage=train_stage 77 | ) 78 | 79 | # Inference 80 | model.eval() 81 | codes = model.inference( 82 | x[-1:], 83 | x_lens[-1:], 84 | y[-1:], 85 | enroll_x_lens=enroll_x_lens[-1:], 86 | ) 87 | 88 | params.prepend_bos = not params.prepend_bos 89 | 
params.num_quantizers += 1 90 | 91 | def test_valle(self): 92 | params = AttributeDict() 93 | params.decoder_dim = 64 94 | params.nhead = 16 95 | params.num_decoder_layers = 4 96 | 97 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 98 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4])) 99 | x_lens[-1] = 8 100 | enroll_x_lens = torch.from_numpy(np.random.randint(1, 3, size=[4])) 101 | 102 | y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8])) 103 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 104 | y_lens[-1] = 16 105 | 106 | params.norm_first = False 107 | params.add_prenet = True 108 | params.model_name = "VALL-E" 109 | params.share_embedding = True 110 | params.scale_factor = 1.0 111 | params.prepend_bos = False 112 | params.num_quantizers = 8 113 | 114 | for device in self.devices: 115 | for mode in [0, 1, 2]: 116 | params.prefix_mode = mode 117 | # VALL-E 118 | model = get_model(params) 119 | model.to(device) 120 | x = x.to(device) 121 | x_lens = x_lens.to(device) 122 | y = y.to(device) 123 | y_lens = y_lens.to(device) 124 | 125 | # Training 126 | codes, loss, metrics = model(x, x_lens, y, y_lens) 127 | # Inference 128 | model.eval() 129 | codes = model.inference( 130 | x[-1:], x_lens[-1:], y[-1:], enroll_x_lens=enroll_x_lens 131 | ) 132 | params.scale_factor = 0.5 133 | 134 | params.prepend_bos = not params.prepend_bos 135 | params.num_quantizers -= 1 136 | 137 | def test_vallef_prefix4(self): 138 | params = AttributeDict() 139 | params.decoder_dim = 64 140 | params.nhead = 16 141 | params.num_decoder_layers = 4 142 | 143 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 144 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4])) 145 | x_lens[-1] = 8 146 | enroll_x_lens = torch.from_numpy(np.random.randint(1, 3, size=[4])) 147 | 148 | y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8])) 149 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 150 | y_lens[-1] = 16 151 | 152 | prompts = torch.from_numpy(np.random.randint(0, 1000, size=[4, 12, 8])) 153 | prompts_lens = torch.from_numpy(np.random.randint(12, 13, size=[4])) 154 | 155 | params.norm_first = False 156 | params.add_prenet = True 157 | params.share_embedding = False 158 | params.scale_factor = 1.0 159 | params.prepend_bos = False 160 | params.num_quantizers = 8 161 | 162 | for device in self.devices: 163 | for model_name in ["VALL-E", "VALL-F"]: 164 | for mode in [4]: 165 | params.prefix_mode = mode 166 | params.model_name = model_name 167 | # VALL-E 168 | model = get_model(params) 169 | model.to(device) 170 | x = x.to(device) 171 | x_lens = x_lens.to(device) 172 | y = y.to(device) 173 | 174 | _y = PromptedFeatures(prompts, y).to(device) 175 | _y_lens = PromptedFeatures(prompts_lens, y_lens).to(device) 176 | 177 | # Training 178 | codes, loss, metrics = model(x, x_lens, _y, _y_lens) 179 | # Inference 180 | model.eval() 181 | codes = model.inference( 182 | x[-1:], x_lens[-1:], y[-1:], enroll_x_lens=enroll_x_lens 183 | ) 184 | 185 | def test_topmetric(self): 186 | metric_top10 = MulticlassAccuracy(1024, top_k=10, average="micro") 187 | metric_top1 = MulticlassAccuracy(1024, top_k=1, average="micro") 188 | batch_size, seq_len = 4, 16 189 | targets = np.random.randint(0, 1000, size=[batch_size, seq_len]) 190 | logits = np.random.random([batch_size, 1024, seq_len]).astype( 191 | np.float32 192 | ) 193 | 194 | larger_logits = np.clip(logits, -1.0, 1.0) 195 | smaller_logits = np.clip(logits, -1.0, 1.0) 196 | for b in range(batch_size): 
197 | for t in range(seq_len): 198 | assert targets[b, t] >= 0 199 | larger_logits[b, targets[b, t], t] = 2.0 200 | smaller_logits[b, targets[b, t], t] = -2.0 201 | 202 | targets = torch.from_numpy(targets) 203 | larger_logits = torch.from_numpy(larger_logits) 204 | smaller_logits = torch.from_numpy(smaller_logits) 205 | 206 | for device in self.devices: 207 | metric_top10.to(device) 208 | metric_top1.to(device) 209 | targets = targets.to(device) 210 | 211 | one = metric_top10(larger_logits.to(device), targets) 212 | assert one.cpu().item() == 1.0, one.cpu().item() 213 | 214 | zero = metric_top1(smaller_logits.to(device), targets) 215 | assert zero.cpu().item() == 0.0, zero.cpu().item() 216 | 217 | half = metric_top1( 218 | torch.concat( 219 | [smaller_logits.to(device), larger_logits.to(device)], dim=2 220 | ), 221 | torch.concat([targets, targets], dim=1), 222 | ) 223 | assert half.cpu().item() == 0.5, half.cpu().item() 224 | 225 | def test_transformer(self): 226 | params = AttributeDict() 227 | params.decoder_dim = 64 228 | params.nhead = 4 229 | params.num_decoder_layers = 4 230 | 231 | x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8])) 232 | x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4])) 233 | x_lens[-1] = 8 234 | 235 | y = torch.from_numpy( 236 | np.random.random((4, 16, NUM_MEL_BINS)).astype(np.float32) 237 | ) 238 | y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4])) 239 | y_lens[-1] = 16 240 | 241 | params.model_name = "Transformer" 242 | params.norm_first = False 243 | params.add_prenet = True 244 | params.scaling_xformers = False 245 | 246 | for device in self.devices: 247 | # Transformer 248 | model = get_model(params) 249 | num_param = sum([p.numel() for p in model.parameters()]) 250 | 251 | model.to(device) 252 | x = x.to(device) 253 | x_lens = x_lens.to(device) 254 | y = y.to(device) 255 | y_lens = y_lens.to(device) 256 | 257 | # Training 258 | codes, loss, metrics = model(x, x_lens, y, y_lens) 259 | # Inference 260 | model.eval() 261 | codes = model.inference(x[-1:], x_lens[-1:]) 262 | params.add_prenet = False 263 | 264 | params.scaling_xformers = not params.scaling_xformers 265 | 266 | 267 | if __name__ == "__main__": 268 | unittest.main() 269 | -------------------------------------------------------------------------------- /valle/utils/symbol_table.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) 2 | # 3 | # See ../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from dataclasses import dataclass 18 | from dataclasses import field 19 | from typing import Dict 20 | from typing import Generic 21 | from typing import List 22 | from typing import Optional 23 | from typing import TypeVar 24 | from typing import Union 25 | 26 | Symbol = TypeVar('Symbol') 27 | 28 | 29 | # Disable __repr__ otherwise it could freeze e.g. Jupyter. 
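# --- Editor's note: illustrative round-trip through the SymbolTable defined just
# below (a sketch using only the API in this file; the phone-like symbols are
# invented for the example). ---
from valle.utils import SymbolTable

table = SymbolTable.from_str("<eps> 0\nSH 1\nAH 2")
assert table["SH"] == 1 and table[2] == "AH"
table.add("N")                    # auto-assigns the next free id, here 3
assert table.to_str() == "<eps> 0\nSH 1\nAH 2\nN 3\n"
# --- end editor's note ---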
30 | @dataclass(repr=False) 31 | class SymbolTable(Generic[Symbol]): 32 | '''SymbolTable that maps symbol IDs, found on the FSA arcs to 33 | actual objects. These objects can be arbitrary Python objects 34 | that can serve as keys in a dictionary (i.e. they need to be 35 | hashable and immutable). 36 | 37 | The SymbolTable can only be read to/written from disk if the 38 | symbols are strings. 39 | ''' 40 | _id2sym: Dict[int, Symbol] = field(default_factory=dict) 41 | '''Map an integer to a symbol. 42 | ''' 43 | 44 | _sym2id: Dict[Symbol, int] = field(default_factory=dict) 45 | '''Map a symbol to an integer. 46 | ''' 47 | 48 | _next_available_id: int = 1 49 | '''A helper internal field that helps adding new symbols 50 | to the table efficiently. 51 | ''' 52 | 53 | eps: Symbol = '' 54 | '''Null symbol, always mapped to index 0. 55 | ''' 56 | 57 | def __post_init__(self): 58 | for idx, sym in self._id2sym.items(): 59 | assert self._sym2id[sym] == idx 60 | assert idx >= 0 61 | 62 | for sym, idx in self._sym2id.items(): 63 | assert idx >= 0 64 | assert self._id2sym[idx] == sym 65 | 66 | if 0 not in self._id2sym: 67 | self._id2sym[0] = self.eps 68 | self._sym2id[self.eps] = 0 69 | else: 70 | assert self._id2sym[0] == self.eps 71 | assert self._sym2id[self.eps] == 0 72 | 73 | self._next_available_id = max(self._id2sym) + 1 74 | 75 | @staticmethod 76 | def from_str(s: str) -> 'SymbolTable': 77 | '''Build a symbol table from a string. 78 | 79 | The string consists of lines. Every line has two fields separated 80 | by space(s), tab(s) or both. The first field is the symbol and the 81 | second the integer id of the symbol. 82 | 83 | Args: 84 | s: 85 | The input string with the format described above. 86 | Returns: 87 | An instance of :class:`SymbolTable`. 88 | ''' 89 | id2sym: Dict[int, str] = dict() 90 | sym2id: Dict[str, int] = dict() 91 | 92 | for line in s.split('\n'): 93 | fields = line.split() 94 | if len(fields) == 0: 95 | continue # skip empty lines 96 | assert len(fields) == 2, \ 97 | f'Expect a line with 2 fields. Given: {len(fields)}' 98 | sym, idx = fields[0], int(fields[1]) 99 | assert sym not in sym2id, f'Duplicated symbol {sym}' 100 | assert idx not in id2sym, f'Duplicated id {idx}' 101 | id2sym[idx] = sym 102 | sym2id[sym] = idx 103 | 104 | eps = id2sym.get(0, '') 105 | 106 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps) 107 | 108 | @staticmethod 109 | def from_file(filename: str) -> 'SymbolTable': 110 | '''Build a symbol table from file. 111 | 112 | Every line in the symbol table file has two fields separated by 113 | space(s), tab(s) or both. The following is an example file: 114 | 115 | .. code-block:: 116 | 117 | 0 118 | a 1 119 | b 2 120 | c 3 121 | 122 | Args: 123 | filename: 124 | Name of the symbol table file. Its format is documented above. 125 | 126 | Returns: 127 | An instance of :class:`SymbolTable`. 128 | 129 | ''' 130 | with open(filename, 'r', encoding='utf-8') as f: 131 | return SymbolTable.from_str(f.read().strip()) 132 | 133 | def to_str(self) -> str: 134 | ''' 135 | Returns: 136 | Return a string representation of this object. You can pass 137 | it to the method ``from_str`` to recreate an identical object. 138 | ''' 139 | s = '' 140 | for idx, symbol in sorted(self._id2sym.items()): 141 | s += f'{symbol} {idx}\n' 142 | return s 143 | 144 | def to_file(self, filename: str): 145 | '''Serialize the SymbolTable to a file. 146 | 147 | Every line in the symbol table file has two fields separated by 148 | space(s), tab(s) or both. 
The following is an example file: 149 | 150 | .. code-block:: 151 | 152 | 0 153 | a 1 154 | b 2 155 | c 3 156 | 157 | Args: 158 | filename: 159 | Name of the symbol table file. Its format is documented above. 160 | ''' 161 | with open(filename, 'w', encoding='utf-8') as f: 162 | for idx, symbol in sorted(self._id2sym.items()): 163 | print(symbol, idx, file=f) 164 | 165 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int: 166 | '''Add a new symbol to the SymbolTable. 167 | 168 | Args: 169 | symbol: 170 | The symbol to be added. 171 | index: 172 | Optional int id to which the symbol should be assigned. 173 | If it is not available, a ValueError will be raised. 174 | 175 | Returns: 176 | The int id to which the symbol has been assigned. 177 | ''' 178 | # Already in the table? Return its ID. 179 | if symbol in self._sym2id: 180 | return self._sym2id[symbol] 181 | # Specific ID not provided - use next available. 182 | if index is None: 183 | index = self._next_available_id 184 | # Specific ID provided but not available. 185 | if index in self._id2sym: 186 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - " 187 | f"already occupied by {self._id2sym[index]}") 188 | self._sym2id[symbol] = index 189 | self._id2sym[index] = symbol 190 | 191 | # Update next available ID if needed 192 | if self._next_available_id <= index: 193 | self._next_available_id = index + 1 194 | 195 | return index 196 | 197 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]: 198 | '''Get a symbol for an id or get an id for a symbol 199 | 200 | Args: 201 | k: 202 | If it is an id, it tries to find the symbol corresponding 203 | to the id; if it is a symbol, it tries to find the id 204 | corresponding to the symbol. 205 | 206 | Returns: 207 | An id or a symbol depending on the given `k`. 208 | ''' 209 | if isinstance(k, int): 210 | return self._id2sym[k] 211 | else: 212 | return self._sym2id[k] 213 | 214 | def merge(self, other: 'SymbolTable') -> 'SymbolTable': 215 | '''Create a union of two SymbolTables. 216 | Raises an AssertionError if the same IDs are occupied by 217 | different symbols. 218 | 219 | Args: 220 | other: 221 | A symbol table to merge with ``self``. 222 | 223 | Returns: 224 | A new symbol table. 
225 | ''' 226 | self._check_compatible(other) 227 | 228 | id2sym = {**self._id2sym, **other._id2sym} 229 | sym2id = {**self._sym2id, **other._sym2id} 230 | 231 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps) 232 | 233 | def _check_compatible(self, other: 'SymbolTable') -> None: 234 | # Epsilon compatibility 235 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \ 236 | f'{self.eps} != {other.eps}' 237 | # IDs compatibility 238 | common_ids = set(self._id2sym).intersection(other._id2sym) 239 | for idx in common_ids: 240 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \ 241 | f'self[idx] = "{self[idx]}", ' \ 242 | f'other[idx] = "{other[idx]}"' 243 | # Symbols compatibility 244 | common_symbols = set(self._sym2id).intersection(other._sym2id) 245 | for sym in common_symbols: 246 | assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \ 247 | f'self[sym] = "{self[sym]}", ' \ 248 | f'other[sym] = "{other[sym]}"' 249 | 250 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]: 251 | return self.get(item) 252 | 253 | def __contains__(self, item: Union[int, Symbol]) -> bool: 254 | if isinstance(item, int): 255 | return item in self._id2sym 256 | else: 257 | return item in self._sym2id 258 | 259 | def __len__(self) -> int: 260 | return len(self._id2sym) 261 | 262 | def __eq__(self, other: 'SymbolTable') -> bool: 263 | if len(self) != len(other): 264 | return False 265 | 266 | for s in self.symbols: 267 | if self[s] != other[s]: 268 | return False 269 | 270 | return True 271 | 272 | @property 273 | def ids(self) -> List[int]: 274 | '''Returns a list of integer IDs corresponding to the symbols. 275 | ''' 276 | ans = list(self._id2sym.keys()) 277 | ans.sort() 278 | return ans 279 | 280 | @property 281 | def symbols(self) -> List[Symbol]: 282 | '''Returns a list of symbols (e.g., strings) corresponding to 283 | the integer IDs. 284 | ''' 285 | ans = list(self._sym2id.keys()) 286 | ans.sort() 287 | return ans 288 | -------------------------------------------------------------------------------- /valle/bin/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Phonemize Text and EnCodec Audio. 
17 | 18 | Usage example: 19 | python3 bin/tokenizer.py \ 20 | --src_dir ./data/manifests --output_dir ./data/tokenized 21 | 22 | """ 23 | import argparse 24 | import logging 25 | import os 26 | from pathlib import Path 27 | 28 | import torch 29 | import torch.multiprocessing 30 | from icefall.utils import get_executor 31 | from lhotse import CutSet, NumpyHdf5Writer 32 | from lhotse.recipes.utils import read_manifests_if_cached 33 | from tqdm.auto import tqdm 34 | 35 | from valle.data import ( 36 | AudioTokenConfig, 37 | AudioTokenConfigDac, 38 | DacAudioTokenizer, 39 | AudioTokenExtractor, 40 | TextTokenizer, 41 | tokenize_text, 42 | ) 43 | from valle.data.fbank import get_fbank_extractor 44 | from valle.utils import SymbolTable 45 | 46 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" 47 | 48 | 49 | # Torch's multithreaded behavior needs to be disabled or 50 | # it wastes a lot of CPU and slow things down. 51 | # Do this outside of main() in case it needs to take effect 52 | # even when we are not invoking the main (e.g. when spawning subprocesses). 53 | torch.set_num_threads(1) 54 | torch.set_num_interop_threads(1) 55 | torch.multiprocessing.set_sharing_strategy("file_system") 56 | 57 | 58 | def get_args(): 59 | parser = argparse.ArgumentParser() 60 | 61 | parser.add_argument( 62 | "--src-dir", 63 | type=Path, 64 | default=Path("data/manifests"), 65 | help="Path to the manifest files", 66 | ) 67 | parser.add_argument( 68 | "--output-dir", 69 | type=Path, 70 | default=Path("data/tokenized"), 71 | help="Path to the tokenized files", 72 | ) 73 | parser.add_argument( 74 | "--text-extractor", 75 | type=str, 76 | default="espeak", 77 | help="espeak or pypinyin or pypinyin_initials_finals", 78 | ) 79 | parser.add_argument( 80 | "--audio-extractor", 81 | type=str, 82 | default="Encodec", 83 | help="Encodec or Fbank", 84 | ) 85 | parser.add_argument( 86 | "--dataset-parts", 87 | type=str, 88 | default="all", 89 | help="Space separated dataset parts", 90 | ) 91 | parser.add_argument( 92 | "--prefix", 93 | type=str, 94 | default="libritts", 95 | help="prefix of the manifest file", 96 | ) 97 | parser.add_argument( 98 | "--suffix", 99 | type=str, 100 | default="jsonl.gz", 101 | help="suffix of the manifest file", 102 | ) 103 | parser.add_argument( 104 | "--batch-duration", 105 | type=float, 106 | default=400.0, 107 | help="The maximum number of audio seconds in a batch." 
108 | "Determines batch size dynamically.", 109 | ) 110 | 111 | return parser.parse_args() 112 | 113 | 114 | def main(): 115 | args = get_args() 116 | 117 | dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip() 118 | if dataset_parts == "all": # LibriTTS 119 | dataset_parts = [ 120 | "dev-clean", 121 | "dev-other", 122 | "test-clean", 123 | "test-other", 124 | "train-clean-100", 125 | "train-clean-360", 126 | "train-other-500", 127 | ] 128 | else: 129 | dataset_parts = dataset_parts.replace("-p", "").strip().split(" ") 130 | 131 | assert len(dataset_parts) >= 1 132 | 133 | manifests = read_manifests_if_cached( 134 | dataset_parts=dataset_parts, 135 | output_dir=args.src_dir, 136 | prefix=args.prefix, 137 | suffix=args.suffix, 138 | types=["recordings", "supervisions", "cuts"], 139 | ) 140 | 141 | text_tokenizer = None 142 | if args.text_extractor: 143 | text_tokenizer = TextTokenizer(backend=args.text_extractor) 144 | 145 | audio_extractor = None 146 | if args.audio_extractor: 147 | if args.audio_extractor == "Encodec": 148 | audio_extractor = AudioTokenExtractor(AudioTokenConfig()) 149 | elif args.audio_extractor.lower() == "dac": 150 | audio_extractor = AudioTokenExtractor(AudioTokenConfigDac(), audio_tokenizer=DacAudioTokenizer) 151 | 152 | else: 153 | assert args.audio_extractor == "Fbank" 154 | audio_extractor = get_fbank_extractor() 155 | 156 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 157 | unique_symbols = set() 158 | num_jobs = min(32, os.cpu_count()) 159 | logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}") 160 | 161 | prefix = args.prefix 162 | if prefix and not prefix.endswith("_"): 163 | prefix = f"{prefix}_" 164 | with get_executor() as ex: 165 | for partition, m in manifests.items(): 166 | logging.info( 167 | f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" 168 | ) 169 | try: 170 | cut_set = CutSet.from_manifests( 171 | # recordings=m["recordings"], 172 | supervisions=m["supervisions"], 173 | ) 174 | except Exception: 175 | cut_set = m["cuts"] 176 | 177 | # AudioTokenizer 178 | if args.audio_extractor and False: 179 | if args.audio_extractor == "Encodec": 180 | storage_path = ( 181 | f"{args.output_dir}/{args.prefix}_encodec_{partition}" 182 | ) 183 | elif args.audio_extractor == "dac": 184 | storage_path = ( 185 | f"{args.output_dir}/{args.prefix}_dac_{partition}" 186 | ) 187 | else: 188 | storage_path = ( 189 | f"{args.output_dir}/{args.prefix}_fbank_{partition}" 190 | ) 191 | 192 | if args.prefix.lower() in ["ljspeech", "aishell", "baker"]: 193 | cut_set = cut_set.resample(24000) 194 | # https://github.com/lifeiteng/vall-e/issues/90 195 | # if args.prefix == "aishell": 196 | # # NOTE: the loudness of aishell audio files is around -33 197 | # # The best way is datamodule --on-the-fly-feats --enable-audio-aug 198 | # cut_set = cut_set.normalize_loudness( 199 | # target=-20.0, affix_id=True 200 | # ) 201 | 202 | if args.audio_extractor.lower() in ["dac"]: 203 | cut_set = cut_set.resample(16000) 204 | 205 | with torch.no_grad(): 206 | if ( 207 | torch.cuda.is_available() 208 | and args.audio_extractor == "Encodec" or args.audio_extractor == "dac" 209 | ): 210 | cut_set = cut_set.compute_and_store_features_batch( 211 | extractor=audio_extractor, 212 | storage_path=storage_path, 213 | num_workers=num_jobs, 214 | batch_duration=args.batch_duration, 215 | collate=False, 216 | overwrite=True, 217 | storage_type=NumpyHdf5Writer, 218 | ) 219 | else: 220 | cut_set = cut_set.compute_and_store_features( 221 
| extractor=audio_extractor, 222 | storage_path=storage_path, 223 | num_jobs=num_jobs if ex is None else 64, 224 | executor=ex, 225 | storage_type=NumpyHdf5Writer, 226 | ) 227 | 228 | # TextTokenizer 229 | if args.text_extractor: 230 | if ( 231 | args.prefix == "baker" 232 | and args.text_extractor == "labeled_pinyin" 233 | ): 234 | for c in tqdm(cut_set): 235 | phonemes = c.supervisions[0].custom["tokens"]["text"] 236 | unique_symbols.update(phonemes) 237 | else: 238 | for c in tqdm(cut_set): 239 | if args.prefix == "ljspeech": 240 | text = c.supervisions[0].custom["normalized_text"] 241 | text = text.replace("”", '"').replace("“", '"') 242 | phonemes = tokenize_text(text_tokenizer, text=text) 243 | elif args.prefix == "aishell": 244 | phonemes = tokenize_text( 245 | text_tokenizer, text=c.supervisions[0].text 246 | ) 247 | c.supervisions[0].custom = {} 248 | else: 249 | assert args.prefix == "libritts" 250 | phonemes = tokenize_text( 251 | text_tokenizer, text=c.supervisions[0].text 252 | ) 253 | c.supervisions[0].custom["tokens"] = {"text": phonemes} 254 | unique_symbols.update(phonemes) 255 | 256 | # cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}" 257 | # cut_set.to_file(f"{args.output_dir}/{cuts_filename}") 258 | 259 | if args.text_extractor: 260 | unique_phonemes = SymbolTable() 261 | for s in sorted(list(unique_symbols)): 262 | unique_phonemes.add(s) 263 | logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}") 264 | 265 | unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols" 266 | unique_phonemes.to_file(unique_phonemes_file) 267 | 268 | 269 | if __name__ == "__main__": 270 | formatter = ( 271 | "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" 272 | ) 273 | logging.basicConfig(format=formatter, level=logging.INFO) 274 | main() 275 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | from omegaconf import OmegaConf 5 | import argparse 6 | from pathlib import Path 7 | import csv 8 | 9 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" 10 | 11 | from utils import AttributeDict 12 | 13 | 14 | from valle.data import AudioTokenizer, tokenize_audio 15 | from valle.data.collation import get_text_token_collater 16 | from valle.models import get_model 17 | from valle.data.hebrew_root_tokenizer import AlefBERTRootTokenizer, replace_chars 18 | 19 | 20 | def load_model(checkpoint, device): 21 | if not checkpoint: 22 | return None 23 | 24 | checkpoint = torch.load(checkpoint, map_location=device) 25 | 26 | args = AttributeDict(checkpoint) 27 | model = get_model(args) 28 | 29 | missing_keys, unexpected_keys = model.load_state_dict( 30 | checkpoint["model"], strict=True 31 | ) 32 | assert not missing_keys 33 | model.to(device) 34 | model.eval() 35 | 36 | text_tokens = args.text_tokens_path 37 | 38 | return model, text_tokens 39 | 40 | 41 | def prepare_inference(checkpoint_path, args, prompt_audio): 42 | device = torch.device("cpu") 43 | if torch.cuda.is_available(): 44 | device = torch.device("cuda", 0) 45 | 46 | model, text_tokens = load_model(checkpoint_path, device) 47 | text_collater = get_text_token_collater(args.tokens_file) 48 | audio_tokenizer = AudioTokenizer(mbd=args.mbd) 49 | alef_bert_tokenizer = AlefBERTRootTokenizer(vocab_file=args.vocab_file) 50 | 51 | audio_prompts = [] 52 | encoded_frames = tokenize_audio(audio_tokenizer, 
prompt_audio) 53 | audio_prompts.append(encoded_frames[0][0]) 54 | audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1).to(device) 55 | 56 | return device, model, text_collater, audio_tokenizer, alef_bert_tokenizer, audio_prompts 57 | 58 | def infer_texts( 59 | texts_with_filenames, 60 | output_dir, 61 | prompt_text, 62 | device, 63 | model, 64 | text_collater, 65 | audio_tokenizer, 66 | alef_bert_tokenizer, 67 | audio_prompts, 68 | top_k=50, 69 | temperature=1, 70 | args=None, 71 | ): 72 | Path(output_dir).mkdir(parents=True, exist_ok=True) 73 | 74 | for filename, text in texts_with_filenames: 75 | text_without_space = [replace_chars(f"{prompt_text} {text}").strip().replace(" ", "_")] 76 | tokens = alef_bert_tokenizer._tokenize(text_without_space) 77 | prompt_text_without_space = [replace_chars(f"{prompt_text}").strip().replace(" ", "_")] 78 | prompt_tokens = alef_bert_tokenizer._tokenize(prompt_text_without_space) 79 | 80 | text_tokens, text_tokens_lens = text_collater([tokens]) 81 | _, enroll_x_lens = text_collater([prompt_tokens]) 82 | 83 | encoded_frames = model.inference( 84 | text_tokens.to(device), 85 | text_tokens_lens.to(device), 86 | audio_prompts, 87 | enroll_x_lens=enroll_x_lens, 88 | top_k=top_k, 89 | temperature=temperature, 90 | ) 91 | 92 | audio_path = Path(output_dir) / filename 93 | 94 | if args.mbd: 95 | samples = audio_tokenizer.mbd_decode(encoded_frames.transpose(2, 1)) 96 | else: 97 | samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)]) 98 | 99 | torchaudio.save(audio_path.as_posix(), samples[0].cpu(), 24000) 100 | 101 | def infer(checkpoint_path, output_dir, texts, prompt_text, prompt_audio, top_k=50, temperature=1, args=None): 102 | 103 | device = torch.device("cpu") 104 | if torch.cuda.is_available(): 105 | device = torch.device("cuda", 0) 106 | 107 | model, text_tokens = load_model(checkpoint_path, device) 108 | text_collater = get_text_token_collater(args.tokens_file) 109 | 110 | audio_tokenizer = AudioTokenizer(mbd=args.mbd) 111 | 112 | Path(output_dir).mkdir(parents=True, exist_ok=True) 113 | alef_bert_tokenizer = AlefBERTRootTokenizer(vocab_file=args.vocab_file) 114 | texts = texts.split("|") 115 | 116 | audio_prompts = list() 117 | encoded_frames = tokenize_audio(audio_tokenizer, prompt_audio) 118 | audio_prompts.append(encoded_frames[0][0]) 119 | audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) 120 | audio_prompts = audio_prompts.to(device) 121 | 122 | for n, text in enumerate(texts): 123 | text_without_space = [replace_chars(f"{prompt_text} {text}").strip().replace(" ", "_")] 124 | tokens = alef_bert_tokenizer._tokenize(text_without_space) 125 | prompt_text_without_space = [replace_chars(f"{prompt_text}").strip().replace(" ", "_")] 126 | prompt_tokens = alef_bert_tokenizer._tokenize(prompt_text_without_space) 127 | 128 | text_tokens, text_tokens_lens = text_collater( 129 | [ 130 | tokens 131 | ] 132 | ) 133 | _, enroll_x_lens = text_collater( 134 | [ 135 | prompt_tokens 136 | ] 137 | ) 138 | 139 | # synthesis 140 | encoded_frames = model.inference( 141 | text_tokens.to(device), 142 | text_tokens_lens.to(device), 143 | audio_prompts, 144 | enroll_x_lens=enroll_x_lens, 145 | top_k=top_k, 146 | temperature=temperature, 147 | ) 148 | 149 | audio_path = f"{output_dir}/sample_{n}.wav" 150 | 151 | if args.mbd: 152 | samples = audio_tokenizer.mbd_decode( 153 | encoded_frames.transpose(2, 1) 154 | ) 155 | else: 156 | samples = audio_tokenizer.decode( 157 | [(encoded_frames.transpose(2, 1), None)] 158 | ) 159 | 
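# --- Editor's note: a hypothetical convenience wrapper (not repository code)
# summarising the two decode branches used above. model.inference returns codes
# shaped (batch, frames, num_quantizers); both decoders expect
# (batch, num_quantizers, frames), hence the transpose(2, 1). ---
import torch
import torchaudio

def decode_and_save(audio_tokenizer, encoded_frames: torch.Tensor, path: str, use_mbd: bool) -> None:
    codes = encoded_frames.transpose(2, 1)                 # -> (batch, n_q, frames)
    if use_mbd:
        samples = audio_tokenizer.mbd_decode(codes)        # multi-band diffusion decoder
    else:
        samples = audio_tokenizer.decode([(codes, None)])  # plain EnCodec decoder
    torchaudio.save(path, samples[0].cpu(), 24000)
# --- end editor's note ---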
160 | torchaudio.save(audio_path, samples[0].cpu(), 24000) 161 | 162 | 163 | def get_args(): 164 | parser = argparse.ArgumentParser() 165 | 166 | parser.add_argument( 167 | "--csv_path", 168 | type=str, 169 | default=None, 170 | help="Optional CSV path with columns filename|text. Overrides --text if provided." 171 | ) 172 | 173 | parser.add_argument( 174 | "--speaker", 175 | type=str, 176 | default="osim", 177 | help="A speaker from speakers.yaml", 178 | ) 179 | 180 | parser.add_argument( 181 | "--mbd", 182 | type=bool, 183 | default=False, 184 | help="use of multi band diffusion", 185 | ) 186 | 187 | parser.add_argument( 188 | "--text", 189 | type=str, 190 | help="Text to be synthesized.", 191 | required=False 192 | ) 193 | 194 | parser.add_argument( 195 | "--speaker-yaml", 196 | type=str, 197 | default="speakers/speakers.yaml", 198 | help="speaker yaml path", 199 | ) 200 | 201 | parser.add_argument( 202 | "--vocab-file", 203 | type=str, 204 | default="tokenizer/vocab.txt", 205 | help="vocab file for AlephBert" 206 | ) 207 | 208 | parser.add_argument( 209 | "--tokens-file", 210 | type=str, 211 | default="tokenizer/unique_words_tokens_all.k2symbols", 212 | help="tokens file path" 213 | ) 214 | 215 | parser.add_argument( 216 | "--checkpoint", 217 | type=str, 218 | default="ckpt.pt", 219 | help="Path to the saved checkpoint.", 220 | ) 221 | 222 | parser.add_argument( 223 | "--output-dir", 224 | type=Path, 225 | help="Path to the inferred wavs.", 226 | required=True 227 | ) 228 | 229 | parser.add_argument( 230 | "--top-k", 231 | type=int, 232 | default=40, 233 | help="top k sampling", 234 | ) 235 | 236 | 237 | args = parser.parse_args() 238 | 239 | if args.csv_path is None and not args.text: 240 | parser.error("argument --text is required if --csv_path is not provided") 241 | 242 | return args 243 | 244 | 245 | def jupyter_demo(text, speaker): 246 | device = torch.device("cpu") 247 | if torch.cuda.is_available(): 248 | device = torch.device("cuda", 0) 249 | 250 | model, text_tokens = load_model(CHECKPOINT_PATH, device) 251 | text_collater = get_text_token_collater(TOKENS_FILE) 252 | 253 | audio_tokenizer = AudioTokenizer(mbd=True) 254 | 255 | alef_bert_tokenizer = AlefBERTRootTokenizer(vocab_file=VOCAB_PATH) 256 | 257 | speaker_yaml = OmegaConf.load(SPEAKER_PATH) 258 | 259 | try: 260 | speaker = speaker_yaml[speaker] 261 | except: 262 | print(f"Invalid speaker {speaker}. 
Should be defined at speakers.yaml.") 263 | 264 | audio_prompt = str(Path(SPEAKER_PATH).parent / speaker["audio-prompt"]) 265 | 266 | audio_prompts = list() 267 | encoded_frames = tokenize_audio(audio_tokenizer, audio_prompt) 268 | audio_prompts.append(encoded_frames[0][0]) 269 | audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) 270 | audio_prompts = audio_prompts.to(device) 271 | 272 | text_without_space = [replace_chars(f"{speaker['text-prompt']} {text}").strip().replace(" ", "_")] 273 | tokens = alef_bert_tokenizer._tokenize(text_without_space) 274 | prompt_text_without_space = [replace_chars(f"{speaker['text-prompt']}").strip().replace(" ", "_")] 275 | prompt_tokens = alef_bert_tokenizer._tokenize(prompt_text_without_space) 276 | 277 | text_tokens, text_tokens_lens = text_collater( 278 | [ 279 | tokens 280 | ] 281 | ) 282 | _, enroll_x_lens = text_collater( 283 | [ 284 | prompt_tokens 285 | ] 286 | ) 287 | 288 | # synthesis 289 | encoded_frames = model.inference( 290 | text_tokens.to(device), 291 | text_tokens_lens.to(device), 292 | audio_prompts, 293 | enroll_x_lens=enroll_x_lens, 294 | top_k=50, 295 | temperature=1, 296 | ) 297 | 298 | 299 | samples = audio_tokenizer.mbd_decode( 300 | encoded_frames 301 | ) 302 | 303 | torchaudio.save("out.wav", samples[0].cpu(), 24000) 304 | 305 | 306 | 307 | if __name__ == '__main__': 308 | args = get_args() 309 | speaker_yaml = OmegaConf.load(args.speaker_yaml) 310 | 311 | try: 312 | speaker = speaker_yaml[args.speaker] 313 | except: 314 | print(f"Invalid speaker {args.speaker}. Should be defined at speakers.yaml.") 315 | 316 | audio_prompt = str(Path(args.speaker_yaml).parent / speaker["audio-prompt"]) 317 | 318 | device, model, text_collater, audio_tokenizer, alef_bert_tokenizer, audio_prompts = prepare_inference( 319 | args.checkpoint, args, audio_prompt 320 | ) 321 | 322 | if args.csv_path: 323 | import csv 324 | with open(args.csv_path, "r", encoding="utf-8") as f: 325 | reader = csv.DictReader(f, delimiter=",") 326 | texts_with_filenames = [(row["filename"], row["text"]) for row in reader] 327 | else: 328 | if os.path.exists(args.text): 329 | with open(args.text, "r") as f: 330 | texts = f.read().splitlines() 331 | texts_with_filenames = [(f"sample_{i}", t) for i, t in enumerate(texts)] 332 | else: 333 | texts_with_filenames = [("sample_0", args.text)] 334 | 335 | infer_texts( 336 | texts_with_filenames, 337 | args.output_dir, 338 | speaker["text-prompt"], 339 | device, 340 | model, 341 | text_collater, 342 | audio_tokenizer, 343 | alef_bert_tokenizer, 344 | audio_prompts, 345 | top_k=args.top_k, 346 | temperature=1, 347 | args=args, 348 | ) 349 | 350 | -------------------------------------------------------------------------------- /valle/models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
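# --- Editor's note: the Transformer model defined below builds its autoregressive
# target mask with torch.triu(..., diagonal=1); a standalone illustration of that
# convention, where True marks positions the decoder is NOT allowed to attend to. ---
import torch

T = 4
tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
assert bool(tgt_mask[1, 2]) and not bool(tgt_mask[2, 1])   # step 1 cannot see step 2; step 2 sees step 1
# --- end editor's note ---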
14 | 15 | from functools import partial 16 | from typing import Any, Dict, List, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from valle.utils import make_pad_mask 22 | from torchmetrics.classification import BinaryAccuracy 23 | 24 | from valle.models.valle import Transpose 25 | from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding 26 | from valle.modules.scaling import BalancedDoubleSwish, ScaledLinear 27 | from valle.modules.transformer import ( 28 | BalancedBasicNorm, 29 | IdentityNorm, 30 | TransformerDecoderLayer, 31 | TransformerEncoder, 32 | TransformerEncoderLayer, 33 | ) 34 | 35 | from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS 36 | from .visualizer import visualize 37 | 38 | IdentityNorm = IdentityNorm 39 | 40 | 41 | class Transformer(nn.Module): 42 | """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding) 43 | Neural Speech Synthesis with Transformer Network 44 | https://arxiv.org/abs/1809.08895 45 | """ 46 | 47 | def __init__( 48 | self, 49 | d_model: int, 50 | nhead: int, 51 | num_layers: int, 52 | norm_first: bool = True, 53 | add_prenet: bool = False, 54 | scaling_xformers: bool = False, 55 | ): 56 | """ 57 | Args: 58 | d_model: 59 | The number of expected features in the input (required). 60 | nhead: 61 | The number of heads in the multiheadattention models (required). 62 | num_layers: 63 | The number of sub-decoder-layers in the decoder (required). 64 | """ 65 | super().__init__() 66 | self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x 67 | 68 | if add_prenet: 69 | self.encoder_prenet = nn.Sequential( 70 | Transpose(), 71 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 72 | nn.BatchNorm1d(d_model), 73 | nn.ReLU(), 74 | nn.Dropout(0.5), 75 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 76 | nn.BatchNorm1d(d_model), 77 | nn.ReLU(), 78 | nn.Dropout(0.5), 79 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 80 | nn.BatchNorm1d(d_model), 81 | nn.ReLU(), 82 | nn.Dropout(0.5), 83 | Transpose(), 84 | nn.Linear(d_model, d_model), 85 | ) 86 | 87 | self.decoder_prenet = nn.Sequential( 88 | nn.Linear(NUM_MEL_BINS, 256), 89 | nn.ReLU(), 90 | nn.Dropout(0.5), 91 | nn.Linear(256, 256), 92 | nn.ReLU(), 93 | nn.Dropout(0.5), 94 | nn.Linear(256, d_model), 95 | ) 96 | 97 | assert scaling_xformers is False # TODO: update this block 98 | else: 99 | self.encoder_prenet = nn.Identity() 100 | if scaling_xformers: 101 | self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model) 102 | else: 103 | self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model) 104 | 105 | self.encoder_position = SinePositionalEmbedding( 106 | d_model, 107 | dropout=0.1, 108 | scale=False, 109 | ) 110 | self.decoder_position = SinePositionalEmbedding( 111 | d_model, dropout=0.1, scale=False 112 | ) 113 | 114 | if scaling_xformers: 115 | self.encoder = TransformerEncoder( 116 | TransformerEncoderLayer( 117 | d_model, 118 | nhead, 119 | dim_feedforward=d_model * 4, 120 | dropout=0.1, 121 | batch_first=True, 122 | norm_first=norm_first, 123 | linear1_self_attention_cls=ScaledLinear, 124 | linear2_self_attention_cls=partial( 125 | ScaledLinear, initial_scale=0.01 126 | ), 127 | linear1_feedforward_cls=ScaledLinear, 128 | linear2_feedforward_cls=partial( 129 | ScaledLinear, initial_scale=0.01 130 | ), 131 | activation=partial( 132 | BalancedDoubleSwish, 133 | channel_dim=-1, 134 | max_abs=10.0, 135 | min_prob=0.25, 136 | ), 137 | layer_norm_cls=IdentityNorm, 
138 | ), 139 | num_layers=num_layers, 140 | norm=BalancedBasicNorm(d_model) if norm_first else None, 141 | ) 142 | 143 | self.decoder = nn.TransformerDecoder( 144 | TransformerDecoderLayer( 145 | d_model, 146 | nhead, 147 | dim_feedforward=d_model * 4, 148 | dropout=0.1, 149 | batch_first=True, 150 | norm_first=norm_first, 151 | linear1_self_attention_cls=ScaledLinear, 152 | linear2_self_attention_cls=partial( 153 | ScaledLinear, initial_scale=0.01 154 | ), 155 | linear1_feedforward_cls=ScaledLinear, 156 | linear2_feedforward_cls=partial( 157 | ScaledLinear, initial_scale=0.01 158 | ), 159 | activation=partial( 160 | BalancedDoubleSwish, 161 | channel_dim=-1, 162 | max_abs=10.0, 163 | min_prob=0.25, 164 | ), 165 | layer_norm_cls=IdentityNorm, 166 | ), 167 | num_layers=num_layers, 168 | norm=BalancedBasicNorm(d_model) if norm_first else None, 169 | ) 170 | 171 | self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS) 172 | self.stop_layer = nn.Linear(d_model, 1) 173 | else: 174 | self.encoder = nn.TransformerEncoder( 175 | nn.TransformerEncoderLayer( 176 | d_model, 177 | nhead, 178 | dim_feedforward=d_model * 4, 179 | activation=F.relu, 180 | dropout=0.1, 181 | batch_first=True, 182 | norm_first=norm_first, 183 | ), 184 | num_layers=num_layers, 185 | norm=nn.LayerNorm(d_model) if norm_first else None, 186 | ) 187 | 188 | self.decoder = nn.TransformerDecoder( 189 | nn.TransformerDecoderLayer( 190 | d_model, 191 | nhead, 192 | dim_feedforward=d_model * 4, 193 | activation=F.relu, 194 | dropout=0.1, 195 | batch_first=True, 196 | norm_first=norm_first, 197 | ), 198 | num_layers=num_layers, 199 | norm=nn.LayerNorm(d_model) if norm_first else None, 200 | ) 201 | 202 | self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS) 203 | self.stop_layer = nn.Linear(d_model, 1) 204 | 205 | self.stop_accuracy_metric = BinaryAccuracy( 206 | threshold=0.5, multidim_average="global" 207 | ) 208 | 209 | # self.apply(self._init_weights) 210 | 211 | # def _init_weights(self, module): 212 | # if isinstance(module, (nn.Linear)): 213 | # module.weight.data.normal_(mean=0.0, std=0.02) 214 | # if isinstance(module, nn.Linear) and module.bias is not None: 215 | # module.bias.data.zero_() 216 | # elif isinstance(module, nn.LayerNorm): 217 | # module.bias.data.zero_() 218 | # module.weight.data.fill_(1.0) 219 | # elif isinstance(module, nn.Embedding): 220 | # module.weight.data.normal_(mean=0.0, std=0.02) 221 | 222 | def forward( 223 | self, 224 | x: torch.Tensor, 225 | x_lens: torch.Tensor, 226 | y: torch.Tensor, 227 | y_lens: torch.Tensor, 228 | reduction: str = "sum", 229 | train_stage: int = 0, 230 | **kwargs, 231 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: 232 | """ 233 | Args: 234 | x: 235 | A 2-D tensor of shape (N, S). 236 | x_lens: 237 | A 1-D tensor of shape (N,). It contains the number of tokens in `x` 238 | before padding. 239 | y: 240 | A 3-D tensor of shape (N, T, 8). 241 | y_lens: 242 | A 1-D tensor of shape (N,). It contains the number of tokens in `x` 243 | before padding. 244 | train_stage: 245 | Not used in this model. 246 | Returns: 247 | Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy. 
248 | """ 249 | del train_stage 250 | 251 | assert x.ndim == 2, x.shape 252 | assert x_lens.ndim == 1, x_lens.shape 253 | assert y.ndim == 3, y.shape 254 | assert y_lens.ndim == 1, y_lens.shape 255 | 256 | assert torch.all(x_lens > 0) 257 | 258 | # NOTE: x has been padded in TextTokenCollater 259 | x_mask = make_pad_mask(x_lens).to(x.device) 260 | 261 | x = self.text_embedding(x) 262 | x = self.encoder_prenet(x) 263 | x = self.encoder_position(x) 264 | x = self.encoder(x, src_key_padding_mask=x_mask) 265 | 266 | total_loss, metrics = 0.0, {} 267 | 268 | y_mask = make_pad_mask(y_lens).to(y.device) 269 | y_mask_float = y_mask.type(torch.float32) 270 | data_mask = 1.0 - y_mask_float.unsqueeze(-1) 271 | 272 | # Training 273 | # AR Decoder 274 | def pad_y(y): 275 | y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach() 276 | # inputs, targets 277 | return y[:, :-1], y[:, 1:] 278 | 279 | y, targets = pad_y(y * data_mask) # mask padding as zeros 280 | 281 | y_emb = self.decoder_prenet(y) 282 | y_pos = self.decoder_position(y_emb) 283 | 284 | y_len = y_lens.max() 285 | tgt_mask = torch.triu( 286 | torch.ones(y_len, y_len, device=y.device, dtype=torch.bool), 287 | diagonal=1, 288 | ) 289 | y_dec = self.decoder( 290 | y_pos, 291 | x, 292 | tgt_mask=tgt_mask, 293 | memory_key_padding_mask=x_mask, 294 | ) 295 | 296 | predict = self.predict_layer(y_dec) 297 | # loss 298 | total_loss = F.mse_loss(predict, targets, reduction=reduction) 299 | 300 | logits = self.stop_layer(y_dec).squeeze(-1) 301 | stop_loss = F.binary_cross_entropy_with_logits( 302 | logits, 303 | y_mask_float.detach(), 304 | weight=1.0 + y_mask_float.detach() * 4.0, 305 | reduction=reduction, 306 | ) 307 | metrics["stop_loss"] = stop_loss.detach() 308 | 309 | stop_accuracy = self.stop_accuracy_metric( 310 | (torch.sigmoid(logits) >= 0.5).type(torch.int64), 311 | y_mask.type(torch.int64), 312 | ) 313 | # icefall MetricsTracker.norm_items() 314 | metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type( 315 | torch.float32 316 | ) 317 | 318 | return ((x, predict), total_loss + 100.0 * stop_loss, metrics) 319 | 320 | def inference( 321 | self, 322 | x: torch.Tensor, 323 | x_lens: torch.Tensor, 324 | y: Any = None, 325 | **kwargs, 326 | ) -> torch.Tensor: 327 | """ 328 | Args: 329 | x: 330 | A 2-D tensor of shape (1, S). 331 | x_lens: 332 | A 1-D tensor of shape (1,). It contains the number of tokens in `x` 333 | before padding. 334 | Returns: 335 | Return the predicted audio code matrix and cross-entropy loss. 
336 | """ 337 | assert x.ndim == 2, x.shape 338 | assert x_lens.ndim == 1, x_lens.shape 339 | 340 | assert torch.all(x_lens > 0) 341 | 342 | x_mask = make_pad_mask(x_lens).to(x.device) 343 | 344 | x = self.text_embedding(x) 345 | x = self.encoder_prenet(x) 346 | x = self.encoder_position(x) 347 | x = self.encoder(x, src_key_padding_mask=x_mask) 348 | 349 | x_mask = make_pad_mask(x_lens).to(x.device) 350 | 351 | # AR Decoder 352 | # TODO: Managing decoder steps avoid repetitive computation 353 | y = torch.zeros( 354 | [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device 355 | ) 356 | while True: 357 | y_emb = self.decoder_prenet(y) 358 | y_pos = self.decoder_position(y_emb) 359 | 360 | tgt_mask = torch.triu( 361 | torch.ones( 362 | y.shape[1], y.shape[1], device=y.device, dtype=torch.bool 363 | ), 364 | diagonal=1, 365 | ) 366 | 367 | y_dec = self.decoder( 368 | y_pos, 369 | x, 370 | tgt_mask=tgt_mask, 371 | memory_mask=None, 372 | memory_key_padding_mask=x_mask, 373 | ) 374 | predict = self.predict_layer(y_dec[:, -1:]) 375 | 376 | logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5 377 | if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()): 378 | print( 379 | f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]" 380 | ) 381 | break 382 | 383 | y = torch.concat([y, predict], dim=1) 384 | 385 | return y[:, 1:] 386 | 387 | def visualize( 388 | self, 389 | predicts: Tuple[torch.Tensor], 390 | batch: Dict[str, Union[List, torch.Tensor]], 391 | output_dir: str, 392 | limit: int = 4, 393 | ) -> None: 394 | visualize(predicts, batch, output_dir, limit=limit) 395 | -------------------------------------------------------------------------------- /valle/data/hebrew_root_tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers.tokenization_utils import PreTrainedTokenizer 2 | from transformers.utils import logging 3 | from typing import List, Optional 4 | from itertools import chain 5 | import collections 6 | import os 7 | from functools import lru_cache 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 12 | 13 | 14 | def load_vocab(vocab_file): 15 | """Loads a vocabulary file into a dictionary.""" 16 | vocab = collections.OrderedDict() 17 | with open(vocab_file, "r", encoding="utf-8") as reader: 18 | tokens = reader.readlines() 19 | for index, token in enumerate(tokens): 20 | token = token.rstrip("\n") 21 | vocab[token] = index 22 | return vocab 23 | 24 | 25 | suf_replace = { 26 | 'ף': 'פ', 27 | 'ץ': 'צ', 28 | 'ך': 'כ', 29 | 'ן': 'נ', 30 | 'ם': 'מ', 31 | } 32 | 33 | punctuation = { 34 | ',': '', 35 | '.': '', 36 | '?': '', 37 | '-': '', 38 | '"': '' 39 | } 40 | 41 | def replace_chars(text): 42 | text = ''.join(suf_replace.get(c, c) for c in text) 43 | text = ''.join(punctuation.get(c, c) for c in text) 44 | return text 45 | 46 | def whitespace_tokenize(text): 47 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 48 | text = replace_chars(text) 49 | text = text.strip() 50 | if not text: 51 | return [] 52 | tokens = text.split() 53 | return tokens 54 | 55 | 56 | class Piece: 57 | def __init__(self, piece, idxs): 58 | self.text = piece 59 | self.idxs = idxs 60 | 61 | def __add__(self, other): 62 | assert isinstance(other, Piece) 63 | return Piece(str(self) + str(other), self.idxs + other.idxs) 64 | 65 | def __str__(self): 66 | return self.text 67 | 68 | 69 | class Structre: 70 | def __init__(self, structre=None, idxs=[], 
length=0, head=None, tail=None): 71 | self.text = structre if structre else '#' * length 72 | self.idxs = idxs 73 | 74 | def __add__(self, piece): 75 | assert isinstance(piece, Piece) 76 | res = '' 77 | last = 0 78 | for i, p in zip(piece.idxs, str(piece)): 79 | # i = i - 1 80 | res += self.text[last:i] + p 81 | last = i + 1 82 | res += self.text[last:] 83 | return Structre(res, sorted(self.idxs + piece.idxs)) 84 | 85 | def __str__(self): 86 | return self.text 87 | 88 | 89 | from collections import OrderedDict, defaultdict, Counter, namedtuple 90 | 91 | 92 | class Word: 93 | def __init__(self, word, count=None): 94 | self.word = word 95 | self.count = count 96 | self.pieces = OrderedDict({i: Piece(p, [i, ]) for i, p in enumerate(word 97 | )}) 98 | self.structre = Structre(length=len(word)) 99 | 100 | def _pairs(self): 101 | pieces = list(self.pieces.values()) 102 | pieces_pairs = [a + b for a, b in zip(pieces, pieces[1:])] 103 | struct_pairs = [self.structre + p for p in pieces if '_' not in p.text] 104 | return struct_pairs + pieces_pairs 105 | 106 | def make_pairs(self): 107 | self.pairs_list = self._pairs() 108 | self.pairs = defaultdict(list) 109 | for pair in self.pairs_list: 110 | self.pairs[str(pair)].append(pair) 111 | 112 | def join(self, pair): 113 | joined = set() 114 | if pair in self.pairs: 115 | for instance in self.pairs[pair]: 116 | idxs = set(instance.idxs) 117 | if not idxs & joined: 118 | if isinstance(instance, Structre): 119 | self.structre = instance 120 | start = 0 121 | else: 122 | self.pieces[instance.idxs[0]] = instance 123 | start = 1 124 | for idx in instance.idxs[start:]: 125 | if idx in self.pieces: 126 | del self.pieces[idx] 127 | joined |= idxs 128 | 129 | def _sub_words(self): 130 | if self.structre.idxs: 131 | yield self.structre 132 | for piece in self.pieces.values(): 133 | yield piece 134 | 135 | def __repr__(self): 136 | return str([str(v) for v in sorted(self._sub_words(), key=lambda x: x.idxs)]) 137 | 138 | def tokenized_iter(self): 139 | last = 0 140 | structre = str(self.structre) 141 | for i, piece in self.pieces.items(): 142 | if i > last: 143 | for c in structre[last:i]: 144 | yield c 145 | yield str(piece) 146 | p_beg, p_end = piece.idxs[0], piece.idxs[-1] 147 | last = p_end + 1 148 | if any(p_beg < s_idx < p_end for s_idx in self.structre.idxs): 149 | yield structre[p_beg:last] 150 | 151 | if last < len(structre): 152 | yield structre[last:] 153 | 154 | 155 | class AlefBERTRootTokenizer(PreTrainedTokenizer): 156 | r""" 157 | Construct a BERT tokenizer. Based on WordPiece. 158 | 159 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 160 | Users should refer to this superclass for more information regarding those methods. 161 | 162 | Args: 163 | vocab_file (:obj:`str`): 164 | File containing the vocabulary. 165 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 166 | Whether or not to lowercase the input when tokenizing. 167 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 168 | Whether or not to do basic tokenization before WordPiece. 169 | never_split (:obj:`Iterable`, `optional`): 170 | Collection of tokens which will never be split during tokenization. Only has an effect when 171 | :obj:`do_basic_tokenize=True` 172 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): 173 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 174 | token instead. 
175 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): 176 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences 177 | for sequence classification or for a text and a question for question answering. 178 | It is also used as the last token of a sequence built with special tokens. 179 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): 180 | The token used for padding, for example when batching sequences of different lengths. 181 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): 182 | The classifier token which is used when doing sequence classification (classification of the whole 183 | sequence instead of per-token classification). It is the first token of the sequence when built with 184 | special tokens. 185 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): 186 | The token used for masking values. This is the token used when training this model with masked language 187 | modeling. This is the token which the model will try to predict. 188 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 189 | Whether or not to tokenize Chinese characters. 190 | 191 | This should likely be deactivated for Japanese (see this `issue 192 | `__). 193 | strip_accents: (:obj:`bool`, `optional`): 194 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 195 | value for :obj:`lowercase` (as in the original BERT). 196 | """ 197 | 198 | def __init__( 199 | self, 200 | vocab_file, 201 | unk_token="[UNK]", 202 | sep_token="[SEP]", 203 | pad_token="[PAD]", 204 | cls_token="[CLS]", 205 | mask_token="[MASK]", 206 | **kwargs 207 | ): 208 | self.vocab = load_vocab(vocab_file) 209 | 210 | super().__init__( 211 | unk_token=unk_token, 212 | sep_token=sep_token, 213 | pad_token=pad_token, 214 | cls_token=cls_token, 215 | mask_token=mask_token, 216 | **kwargs, 217 | ) 218 | 219 | if not os.path.isfile(vocab_file): 220 | raise ValueError( 221 | "Can't find a vocabulary file at path '{}'".format(vocab_file) 222 | ) 223 | 224 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) 225 | self.model_max_length = 512 226 | self.cache = dict() 227 | self.special_tokens = {unk_token, sep_token, pad_token, cls_token, mask_token} 228 | 229 | @property 230 | def vocab_size(self): 231 | return len(self.vocab) 232 | 233 | def get_vocab(self): 234 | return dict(self.vocab, **self.added_tokens_encoder) 235 | 236 | def _tokenize_word(self, w): 237 | if w in self.special_tokens: 238 | return [w] 239 | cached = self.cache.get(w) 240 | if cached: 241 | return cached 242 | word = Word(w) 243 | while True: 244 | min_rank = float('inf') 245 | min_pair = None 246 | word.make_pairs() 247 | for pair in word.pairs_list: 248 | pair = str(pair) 249 | rank = self.vocab[pair] if pair in self.vocab else float('inf') 250 | if rank < min_rank: 251 | min_pair = pair 252 | min_rank = rank 253 | if min_rank == float('inf') or not min_pair: 254 | break 255 | word.join(min_pair) 256 | res = list(word.tokenized_iter()) 257 | self.cache[w] = res 258 | return res 259 | 260 | def _tokenize(self, text): 261 | split_tokens = list(chain(*(self._tokenize_word(word) for word in whitespace_tokenize(text)))) 262 | if not split_tokens: 263 | print('tokenizer issue') 264 | return split_tokens 265 | 266 | def _convert_token_to_id(self, token): 267 | """ Converts a token (str) in an id using the vocab. 
""" 268 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 269 | 270 | def _convert_id_to_token(self, index): 271 | """Converts an index (integer) in a token (str) using the vocab.""" 272 | return self.ids_to_tokens.get(index, self.unk_token) 273 | 274 | def convert_tokens_to_string(self, tokens): 275 | """ Converts a sequence of tokens (string) in a single string. """ 276 | print(tokens) 277 | raise NotImplemented 278 | # out_string = " ".join(tokens).replace(" ##", "").strip() 279 | # return out_string 280 | 281 | def build_inputs_with_special_tokens( 282 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 283 | ) -> List[int]: 284 | """ 285 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 286 | by concatenating and adding special tokens. 287 | A BERT sequence has the following format: 288 | 289 | - single sequence: ``[CLS] X [SEP]`` 290 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 291 | 292 | Args: 293 | token_ids_0 (:obj:`List[int]`): 294 | List of IDs to which the special tokens will be added. 295 | token_ids_1 (:obj:`List[int]`, `optional`): 296 | Optional second list of IDs for sequence pairs. 297 | 298 | Returns: 299 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 300 | """ 301 | if token_ids_1 is None: 302 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 303 | cls = [self.cls_token_id] 304 | sep = [self.sep_token_id] 305 | return cls + token_ids_0 + sep + token_ids_1 + sep 306 | 307 | def get_special_tokens_mask( 308 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, 309 | already_has_special_tokens: bool = False 310 | ) -> List[int]: 311 | """ 312 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 313 | special tokens using the tokenizer ``prepare_for_model`` method. 314 | 315 | Args: 316 | token_ids_0 (:obj:`List[int]`): 317 | List of IDs. 318 | token_ids_1 (:obj:`List[int]`, `optional`): 319 | Optional second list of IDs for sequence pairs. 320 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 321 | Whether or not the token list is already formatted with special tokens for the model. 322 | 323 | Returns: 324 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 325 | """ 326 | 327 | if already_has_special_tokens: 328 | if token_ids_1 is not None: 329 | raise ValueError( 330 | "You should not supply a second sequence if the provided sequence of " 331 | "ids is already formated with special tokens for the model." 
332 | ) 333 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 334 | 335 | if token_ids_1 is not None: 336 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 337 | return [1] + ([0] * len(token_ids_0)) + [1] 338 | 339 | def get_end_of_word_mask(self, text_0, text_1=None): 340 | words_lens_0 = [len(self._tokenize_word(word)) for word in whitespace_tokenize(text_0)] 341 | res_0 = [] 342 | for l in words_lens_0: 343 | res_0 += ([0] * (l - 1)) + [1] 344 | if text_1: 345 | words_lens_1 = [len(self._tokenize_word(word)) for word in whitespace_tokenize(text_1)] 346 | res_1 = [] 347 | for l in words_lens_1: 348 | res_1 += ([0] * (l - 1)) + [1] 349 | return [1] + res_0 + [1] + res_1 + [1] 350 | return [1] + res_0 + [1] 351 | 352 | def __call__(self, text_0, text_1=None, end_of_word=False, *argv, **kwargs): 353 | res = super().__call__(text_0, text_1, *argv, **kwargs) 354 | if end_of_word: 355 | res['end_of_word_mask'] = self.get_end_of_word_mask(text_0, text_1) 356 | return res 357 | 358 | def create_token_type_ids_from_sequences( 359 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 360 | ) -> List[int]: 361 | """ 362 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. 363 | A BERT sequence pair mask has the following format: 364 | 365 | :: 366 | 367 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 368 | | first sequence | second sequence | 369 | 370 | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). 371 | 372 | Args: 373 | token_ids_0 (:obj:`List[int]`): 374 | List of IDs. 375 | token_ids_1 (:obj:`List[int]`, `optional`): 376 | Optional second list of IDs for sequence pairs. 377 | 378 | Returns: 379 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 380 | sequence(s). 381 | """ 382 | sep = [self.sep_token_id] 383 | cls = [self.cls_token_id] 384 | if token_ids_1 is None: 385 | return len(cls + token_ids_0 + sep) * [0] 386 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 387 | 388 | def save_vocabulary(self, vocab_path, filename_prefix=''): 389 | """ 390 | Save the vocabulary (copy original file) and special tokens file to a directory. 391 | 392 | Args: 393 | vocab_path (:obj:`str`): 394 | The directory in which to save the vocabulary. 395 | 396 | Returns: 397 | :obj:`Tuple(str)`: Paths to the files saved. 398 | """ 399 | index = 0 400 | if os.path.isdir(vocab_path): 401 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"]) 402 | else: 403 | vocab_file = vocab_path 404 | with open(vocab_file, "w", encoding="utf-8") as writer: 405 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 406 | if index != token_index: 407 | logger.warning( 408 | "Saving vocabulary to {}: vocabulary indices are not consecutive." 
409 | " Please check that the vocabulary is not corrupted!".format(vocab_file) 410 | ) 411 | index = token_index 412 | writer.write(token + "\n") 413 | index += 1 414 | return (vocab_file,) 415 | 416 | 417 | if __name__ == '__main__': 418 | tokenizer = AlefBERTRootTokenizer(vocab_file="/cs/labs/adiyoss/amitroth/valle/scripts/vocab.txt") 419 | print("done") -------------------------------------------------------------------------------- /valle/data/datamodule.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | import argparse 19 | import inspect 20 | import logging 21 | from functools import lru_cache 22 | from pathlib import Path 23 | from typing import Any, Dict, Optional 24 | 25 | import torch 26 | from lhotse import CutSet, load_manifest_lazy 27 | 28 | from lhotse.dataset import ( 29 | CutConcatenate, 30 | DynamicBucketingSampler, 31 | PrecomputedFeatures, 32 | SimpleCutSampler, 33 | SpecAugment, 34 | ) 35 | 36 | from lhotse.dataset.input_strategies import OnTheFlyFeatures 37 | from lhotse.utils import fix_random_seed 38 | from torch.utils.data import DataLoader 39 | 40 | from valle.data.collation import get_text_token_collater 41 | from valle.data.dataset import SpeechSynthesisDataset 42 | from valle.data.fbank import get_fbank_extractor 43 | from valle.data.input_strategies import PromptedPrecomputedFeatures 44 | 45 | PrecomputedFeatures = PrecomputedFeatures 46 | 47 | 48 | class _SeedWorkers: 49 | def __init__(self, seed: int): 50 | self.seed = seed 51 | 52 | def __call__(self, worker_id: int): 53 | fix_random_seed(self.seed + worker_id) 54 | 55 | 56 | def _get_input_strategy(input_strategy, dataset, cuts): 57 | if input_strategy == "PromptedPrecomputedFeatures": 58 | return PromptedPrecomputedFeatures(dataset, cuts) 59 | 60 | return eval(input_strategy)() 61 | 62 | 63 | class TtsDataModule: 64 | """ 65 | DataModule for VALL-E TTS experiments. 66 | It assumes there is always one train and valid dataloader. 67 | 68 | It contains all the common data pipeline modules used in TTS 69 | experiments, e.g.: 70 | - dynamic batch size, 71 | - bucketing samplers, 72 | - cut concatenation[not used & tested yet], 73 | - augmentation[not used & tested yet], 74 | - on-the-fly feature extraction[not used & tested yet] 75 | 76 | This class should be derived for specific corpora used in TTS tasks. 
77 | """ 78 | 79 | def __init__(self, args: argparse.Namespace): 80 | self.args = args 81 | 82 | @classmethod 83 | def add_arguments(cls, parser: argparse.ArgumentParser): 84 | group = parser.add_argument_group( 85 | title="TTS data related options", 86 | description="These options are used for the preparation of " 87 | "PyTorch DataLoaders from Lhotse CutSet's -- they control the " 88 | "effective batch sizes, sampling strategies, applied data " 89 | "augmentations, etc.", 90 | ) 91 | group.add_argument( 92 | "--manifest-dir", 93 | type=Path, 94 | # default=Path("/cs/labs/adiyoss/amitroth/valle/examples/osim_geek/data/tokenized"), 95 | help="Path to directory with train/valid/test cuts.", 96 | ) # todo - change here default 97 | group.add_argument( 98 | "--max-duration", 99 | type=int, 100 | default=40.0, 101 | help="Maximum pooled recordings duration (seconds) in a " 102 | "single batch. You can reduce it if it causes CUDA OOM.", 103 | ) 104 | group.add_argument( 105 | "--bucketing-sampler", 106 | default=True, 107 | help="When enabled, the batches will come from buckets of " 108 | "similar duration (saves padding frames).", 109 | ) 110 | group.add_argument( 111 | "--num-buckets", 112 | type=int, 113 | default=10, 114 | help="The number of buckets for the DynamicBucketingSampler" 115 | "(you might want to increase it for larger datasets).", 116 | ) 117 | group.add_argument( 118 | "--concatenate-cuts", 119 | default=False, 120 | help="When enabled, utterances (cuts) will be concatenated " 121 | "to minimize the amount of padding.", 122 | ) 123 | group.add_argument( 124 | "--duration-factor", 125 | type=float, 126 | default=1.0, 127 | help="Determines the maximum duration of a concatenated cut " 128 | "relative to the duration of the longest cut in a batch.", 129 | ) 130 | group.add_argument( 131 | "--gap", 132 | type=float, 133 | default=0.1, 134 | help="The amount of padding (in seconds) inserted between " 135 | "concatenated cuts. This padding is filled with noise when " 136 | "noise augmentation is used.", 137 | ) 138 | group.add_argument( 139 | "--on-the-fly-feats", 140 | default=False, 141 | help="When enabled, use on-the-fly cut mixing and feature " 142 | "extraction. Will drop existing precomputed feature manifests " 143 | "if available.", 144 | ) 145 | group.add_argument( 146 | "--shuffle", 147 | default=True, 148 | help="When enabled (=default), the examples will be " 149 | "shuffled for each epoch.", 150 | ) 151 | group.add_argument( 152 | "--buffer-size", 153 | type=int, 154 | default=100000, 155 | help="How many cuts (or cut pairs, triplets) we hold at any time across all of the buckets." 156 | "Increasing ``max_duration`` (batch_size) or ``num_buckets`` might require increasing this number." 157 | "It will result in larger memory usage.", 158 | ) 159 | group.add_argument( 160 | "--shuffle-buffer-size", 161 | type=int, 162 | default=100000, 163 | help="How many cuts (or cut pairs, triplets) are being held in memory" 164 | "a buffer used for streaming shuffling. Larger number means better randomness at the cost" 165 | "of higher memory usage.", 166 | ) 167 | group.add_argument( 168 | "--drop-last", 169 | default=False, 170 | help="Whether to drop last batch. 
Used by sampler.", 171 | ) 172 | group.add_argument( 173 | "--return-cuts", 174 | default=True, 175 | help="When enabled, each batch will have the " 176 | "field: batch['supervisions']['cut'] with the cuts that " 177 | "were used to construct it.", 178 | ) 179 | 180 | group.add_argument( 181 | "--num-workers", 182 | type=int, 183 | default=8, 184 | help="The number of training dataloader workers that " 185 | "collect the batches.", 186 | ) 187 | 188 | group.add_argument( 189 | "--enable-spec-aug", 190 | default=False, 191 | help="When enabled, use SpecAugment for training dataset.", 192 | ) 193 | 194 | group.add_argument( 195 | "--spec-aug-time-warp-factor", 196 | type=int, 197 | default=80, 198 | help="Used only when --enable-spec-aug is True. " 199 | "It specifies the factor for time warping in SpecAugment. " 200 | "Larger values mean more warping. " 201 | "A value less than 1 means to disable time warp.", 202 | ) 203 | 204 | group.add_argument( 205 | "--input-strategy", 206 | type=str, 207 | default="PrecomputedFeatures", 208 | help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures", 209 | ) 210 | 211 | group.add_argument( 212 | "--dataset", 213 | type=str, 214 | # default="libritts", 215 | help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.", 216 | ) 217 | 218 | parser.add_argument( 219 | "--text-tokens-path", 220 | type=str, 221 | default="None", 222 | help="Path to the unique text tokens file", 223 | ) # TODO change it 224 | 225 | parser.add_argument( 226 | "--sampling-rate", 227 | type=int, 228 | default=24000, 229 | help="""Audio sampling rate.""", 230 | ) 231 | 232 | group.add_argument( 233 | "--text-tokens", 234 | type=str, 235 | default="text", 236 | help="which token type to use", 237 | ) # todo - change here default 238 | 239 | group.add_argument( 240 | "--audio-tokens", 241 | type=str, 242 | default="encodec", 243 | help="which token type to use", 244 | ) # todo - change here default 245 | 246 | def train_dataloaders( 247 | self, 248 | cuts_train: CutSet, 249 | sampler_state_dict: Optional[Dict[str, Any]] = None, 250 | ) -> DataLoader: 251 | """ 252 | Args: 253 | cuts_train: 254 | CutSet for training. 255 | sampler_state_dict: 256 | The state dict for the training sampler. 257 | """ 258 | transforms = [] 259 | 260 | if self.args.concatenate_cuts: 261 | logging.info( 262 | f"Using cut concatenation with duration factor " 263 | f"{self.args.duration_factor} and gap {self.args.gap}." 264 | ) 265 | # Cut concatenation should be the first transform in the list, 266 | # so that if we e.g. mix noise in, it will fill the gaps between 267 | # different utterances. 268 | transforms = [ 269 | CutConcatenate( 270 | duration_factor=self.args.duration_factor, gap=self.args.gap 271 | ) 272 | ] + transforms 273 | 274 | input_transforms = [] 275 | if self.args.enable_spec_aug: 276 | logging.info("Enable SpecAugment") 277 | logging.info( 278 | f"Time warp factor: {self.args.spec_aug_time_warp_factor}" 279 | ) 280 | # Set the value of num_frame_masks according to Lhotse's version. 281 | # In different Lhotse's versions, the default of num_frame_masks is 282 | # different. 
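            # Concretely: probe the default of SpecAugment.__init__'s
            # `num_frame_masks` with inspect.signature; if the installed Lhotse
            # defaults it to 1 (newer API), request 2 masks, otherwise keep 10.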
283 | num_frame_masks = 10 284 | num_frame_masks_parameter = inspect.signature( 285 | SpecAugment.__init__ 286 | ).parameters["num_frame_masks"] 287 | if num_frame_masks_parameter.default == 1: 288 | num_frame_masks = 2 289 | logging.info(f"Num frame mask: {num_frame_masks}") 290 | input_transforms.append( 291 | SpecAugment( 292 | time_warp_factor=self.args.spec_aug_time_warp_factor, 293 | num_frame_masks=num_frame_masks, 294 | features_mask_size=27, 295 | num_feature_masks=2, 296 | frames_mask_size=100, 297 | ) 298 | ) 299 | else: 300 | logging.info("Disable SpecAugment") 301 | 302 | logging.info("About to create train dataset") 303 | if self.args.on_the_fly_feats: 304 | # NOTE: the PerturbSpeed transform should be added only if we 305 | # remove it from data prep stage. 306 | # Add on-the-fly speed perturbation; since originally it would 307 | # have increased epoch size by 3, we will apply prob 2/3 and use 308 | # 3x more epochs. 309 | # Speed perturbation probably should come first before 310 | # concatenation, but in principle the transforms order doesn't have 311 | # to be strict (e.g. could be randomized) 312 | # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa 313 | # Drop feats to be on the safe side. 314 | train = SpeechSynthesisDataset( 315 | get_text_token_collater(self.args.text_tokens_path), 316 | cut_transforms=transforms, 317 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()), 318 | feature_transforms=input_transforms, 319 | token=self.args.text_tokens 320 | ) 321 | else: 322 | train = SpeechSynthesisDataset( 323 | get_text_token_collater(self.args.text_tokens_path), 324 | feature_input_strategy=_get_input_strategy( 325 | self.args.input_strategy, self.args.dataset, cuts_train 326 | ), 327 | cut_transforms=transforms, 328 | feature_transforms=input_transforms, 329 | token=self.args.text_tokens 330 | ) 331 | 332 | # @amitroth added this line to prevent oom from long tokens 333 | if True: 334 | # length_before_filter = len(cuts_train.to_eager()) 335 | 336 | def drop_c(c): 337 | # print(c.supervisions[0].custom is not None) and (len(c.supervisions[0].custom['tokens']['char']) < 190) 338 | return len(c.supervisions[0].custom['tokens']['char']) < 200 339 | 340 | 341 | cuts_train = cuts_train.filter(drop_c) 342 | logging.info("filtered") 343 | print("filtered nby length!") 344 | # print(f"filtered! {len(cuts_train.to_eager())} - {length_before_filter}") 345 | 346 | # print(f"filtered! {len(cuts_train.to_eager())}") 347 | logging.info("here") 348 | print("here!") 349 | 350 | if self.args.bucketing_sampler: 351 | logging.info("Using DynamicBucketingSampler") 352 | train_sampler = DynamicBucketingSampler( 353 | cuts_train, 354 | max_duration=self.args.max_duration, 355 | shuffle=self.args.shuffle, 356 | buffer_size=self.args.buffer_size, 357 | shuffle_buffer_size=self.args.shuffle_buffer_size, 358 | quadratic_duration=10, 359 | num_cuts_for_bins_estimate=10000, 360 | drop_last=True, 361 | ) 362 | else: 363 | logging.info( 364 | "Using SimpleCutSampler and sort by duraton(ascending=True)." 
365 | ) 366 | cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True) 367 | train_sampler = SimpleCutSampler( 368 | cuts_train, 369 | max_duration=self.args.max_duration, 370 | shuffle=self.args.shuffle, 371 | ) 372 | logging.info("About to create train dataloader") 373 | 374 | if sampler_state_dict is not None: 375 | logging.info("Loading sampler state dict") 376 | train_sampler.load_state_dict(sampler_state_dict) 377 | 378 | # 'seed' is derived from the current random state, which will have 379 | # previously been set in the main process. 380 | seed = torch.randint(0, 100000, ()).item() 381 | worker_init_fn = _SeedWorkers(seed) 382 | 383 | train_dl = DataLoader( 384 | train, 385 | sampler=train_sampler, 386 | batch_size=None, 387 | num_workers=self.args.num_workers, 388 | persistent_workers=False, 389 | worker_init_fn=worker_init_fn, 390 | ) 391 | 392 | return train_dl 393 | 394 | def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: 395 | logging.info("About to create dev dataset") 396 | if self.args.on_the_fly_feats: 397 | validate = SpeechSynthesisDataset( 398 | get_text_token_collater(self.args.text_tokens_path), 399 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()), 400 | cut_transforms=[], 401 | token=self.args.text_tokens 402 | ) 403 | else: 404 | validate = SpeechSynthesisDataset( 405 | get_text_token_collater(self.args.text_tokens_path), 406 | feature_input_strategy=_get_input_strategy( 407 | self.args.input_strategy, self.args.dataset, cuts_valid 408 | ), 409 | cut_transforms=[], 410 | token=self.args.text_tokens 411 | ) 412 | 413 | if True: 414 | # length_before_filter = len(cuts_train.to_eager()) 415 | 416 | def drop_c(c): 417 | # print(c.supervisions[0].custom is not None) and (len(c.supervisions[0].custom['tokens']['char']) < 190) 418 | return len(c.supervisions[0].custom['tokens']['char']) < 200 419 | 420 | 421 | cuts_valid = cuts_valid.filter(drop_c) 422 | logging.info("filtered") 423 | print("filtered nby length for valid!") 424 | # print(f"filtered! {len(cuts_train.to_eager())} - {length_before_filter}") 425 | 426 | 427 | valid_sampler = DynamicBucketingSampler( 428 | cuts_valid, 429 | max_duration=self.args.max_duration, 430 | shuffle=False, 431 | drop_last=True, 432 | ) 433 | logging.info("About to create dev dataloader") 434 | valid_dl = DataLoader( 435 | validate, 436 | sampler=valid_sampler, 437 | batch_size=None, 438 | num_workers=4, 439 | persistent_workers=False, 440 | ) 441 | 442 | return valid_dl 443 | 444 | def test_dataloaders(self, cuts: CutSet) -> DataLoader: 445 | logging.debug("About to create test dataset") 446 | test = SpeechSynthesisDataset( 447 | get_text_token_collater(self.args.text_tokens_path), 448 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()) 449 | if self.args.on_the_fly_feats 450 | else _get_input_strategy( 451 | self.args.input_strategy, self.args.dataset, cuts 452 | ), 453 | cut_transforms=[], 454 | token=self.args.text_tokens 455 | ) 456 | 457 | # @amitroth added this line to prevent oom from long tokens 458 | if True: 459 | # length_before_filter = len(cuts_train.to_eager()) 460 | 461 | def drop_c(c): 462 | # print(c.supervisions[0].custom is not None) and (len(c.supervisions[0].custom['tokens']['char']) < 190) 463 | return len(c.supervisions[0].custom['tokens']['char']) < 200 464 | 465 | cuts = cuts.filter(drop_c) 466 | logging.info("filtered") 467 | print("filtered by length for test!") 468 | # print(f"filtered! 
{len(cuts_train.to_eager())} - {length_before_filter}") 469 | 470 | sampler = DynamicBucketingSampler( 471 | cuts, 472 | max_duration=self.args.max_duration, 473 | shuffle=False, 474 | drop_last=True, 475 | ) 476 | logging.debug("About to create test dataloader") 477 | test_dl = DataLoader( 478 | test, 479 | batch_size=None, 480 | sampler=sampler, 481 | num_workers=self.args.num_workers, 482 | ) 483 | return test_dl 484 | 485 | @lru_cache() 486 | def train_cuts(self) -> CutSet: 487 | logging.info("About to get train cuts") 488 | return load_manifest_lazy( 489 | self.args.manifest_dir / "cuts_train.jsonl.gz" 490 | ) 491 | 492 | @lru_cache() 493 | def dev_cuts(self) -> CutSet: 494 | logging.info("About to get dev cuts") 495 | return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") 496 | 497 | @lru_cache() 498 | def test_cuts(self) -> CutSet: 499 | logging.info("About to get test cuts") 500 | return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz") 501 | -------------------------------------------------------------------------------- /valle/modules/activation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Linear, Module 6 | from torch.nn import functional as F 7 | from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ 8 | from torch.nn.modules.linear import NonDynamicallyQuantizableLinear 9 | from torch.nn.parameter import Parameter 10 | 11 | 12 | class MultiheadAttention(Module): 13 | r"""Allows the model to jointly attend to information 14 | from different representation subspaces as described in the paper: 15 | `Attention Is All You Need `_. 16 | 17 | Multi-Head Attention is defined as: 18 | 19 | .. math:: 20 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 21 | 22 | where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. 23 | 24 | ``forward()`` will use a special optimized implementation if all of the following 25 | conditions are met: 26 | 27 | - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This 28 | restriction will be loosened in the future.) 29 | - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` 30 | - training is disabled (using ``.eval()``) 31 | - dropout is 0 32 | - ``add_bias_kv`` is ``False`` 33 | - ``add_zero_attn`` is ``False`` 34 | - ``batch_first`` is ``True`` and the input is batched 35 | - ``kdim`` and ``vdim`` are equal to ``embed_dim`` 36 | - at most one of ``key_padding_mask`` or ``attn_mask`` is passed 37 | - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` 38 | nor ``attn_mask`` is passed 39 | 40 | If the optimized implementation is in use, a 41 | `NestedTensor `_ can be passed for 42 | ``query``/``key``/``value`` to represent padding more efficiently than using a 43 | padding mask. In this case, a `NestedTensor `_ 44 | will be returned, and an additional speedup proportional to the fraction of the input 45 | that is padding can be expected. 46 | 47 | Args: 48 | embed_dim: Total dimension of the model. 49 | num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split 50 | across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). 51 | dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). 
52 | bias: If specified, adds bias to input / output projection layers. Default: ``True``. 53 | add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. 54 | add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. 55 | Default: ``False``. 56 | kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). 57 | vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). 58 | batch_first: If ``True``, then the input and output tensors are provided 59 | as (batch, seq, feature). Default: ``False`` (seq, batch, feature). 60 | 61 | Examples:: 62 | 63 | >>> # xdoctest: +SKIP 64 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 65 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 66 | 67 | """ 68 | __constants__ = ["batch_first"] 69 | bias_k: Optional[torch.Tensor] 70 | bias_v: Optional[torch.Tensor] 71 | 72 | def __init__( 73 | self, 74 | embed_dim, 75 | num_heads, 76 | dropout=0.0, 77 | bias=True, 78 | add_bias_kv=False, 79 | add_zero_attn=False, 80 | kdim=None, 81 | vdim=None, 82 | batch_first=False, 83 | linear1_cls=Linear, 84 | linear2_cls=Linear, 85 | device=None, 86 | dtype=None, 87 | ) -> None: 88 | factory_kwargs = {"device": device, "dtype": dtype} 89 | super(MultiheadAttention, self).__init__() 90 | self.embed_dim = embed_dim 91 | self.kdim = kdim if kdim is not None else embed_dim 92 | self.vdim = vdim if vdim is not None else embed_dim 93 | self._qkv_same_embed_dim = ( 94 | self.kdim == embed_dim and self.vdim == embed_dim 95 | ) 96 | 97 | self.num_heads = num_heads 98 | self.dropout = dropout 99 | self.batch_first = batch_first 100 | self.head_dim = embed_dim // num_heads 101 | assert ( 102 | self.head_dim * num_heads == self.embed_dim 103 | ), "embed_dim must be divisible by num_heads" 104 | 105 | if add_bias_kv: 106 | self.bias_k = Parameter( 107 | torch.empty((1, 1, embed_dim), **factory_kwargs) 108 | ) 109 | self.bias_v = Parameter( 110 | torch.empty((1, 1, embed_dim), **factory_kwargs) 111 | ) 112 | else: 113 | self.bias_k = self.bias_v = None 114 | 115 | if linear1_cls == Linear: 116 | if not self._qkv_same_embed_dim: 117 | self.q_proj_weight = Parameter( 118 | torch.empty((embed_dim, embed_dim), **factory_kwargs) 119 | ) 120 | self.k_proj_weight = Parameter( 121 | torch.empty((embed_dim, self.kdim), **factory_kwargs) 122 | ) 123 | self.v_proj_weight = Parameter( 124 | torch.empty((embed_dim, self.vdim), **factory_kwargs) 125 | ) 126 | self.register_parameter("in_proj_weight", None) 127 | else: 128 | self.in_proj_weight = Parameter( 129 | torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) 130 | ) 131 | self.register_parameter("q_proj_weight", None) 132 | self.register_parameter("k_proj_weight", None) 133 | self.register_parameter("v_proj_weight", None) 134 | 135 | if bias: 136 | self.in_proj_bias = Parameter( 137 | torch.empty(3 * embed_dim, **factory_kwargs) 138 | ) 139 | else: 140 | self.register_parameter("in_proj_bias", None) 141 | self.out_proj = NonDynamicallyQuantizableLinear( 142 | embed_dim, embed_dim, bias=bias, **factory_kwargs 143 | ) 144 | 145 | self._reset_parameters() 146 | else: 147 | if not self._qkv_same_embed_dim: 148 | raise NotImplementedError 149 | else: 150 | self.in_proj_linear = linear1_cls( 151 | embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs 152 | ) 153 | self.in_proj_weight = self.in_proj_linear.weight 154 | 155 | self.register_parameter("q_proj_weight", None) 156 
| self.register_parameter("k_proj_weight", None) 157 | self.register_parameter("v_proj_weight", None) 158 | 159 | if bias: 160 | self.in_proj_bias = self.in_proj_linear.bias 161 | else: 162 | self.register_parameter("in_proj_bias", None) 163 | 164 | self.out_proj = linear2_cls( 165 | embed_dim, embed_dim, bias=bias, **factory_kwargs 166 | ) 167 | 168 | if self.bias_k is not None: 169 | xavier_normal_(self.bias_k) 170 | if self.bias_v is not None: 171 | xavier_normal_(self.bias_v) 172 | 173 | self.add_zero_attn = add_zero_attn 174 | 175 | def _reset_parameters(self): 176 | if self._qkv_same_embed_dim: 177 | xavier_uniform_(self.in_proj_weight) 178 | else: 179 | xavier_uniform_(self.q_proj_weight) 180 | xavier_uniform_(self.k_proj_weight) 181 | xavier_uniform_(self.v_proj_weight) 182 | 183 | if self.in_proj_bias is not None: 184 | constant_(self.in_proj_bias, 0.0) 185 | constant_(self.out_proj.bias, 0.0) 186 | 187 | if self.bias_k is not None: 188 | xavier_normal_(self.bias_k) 189 | if self.bias_v is not None: 190 | xavier_normal_(self.bias_v) 191 | 192 | def __setstate__(self, state): 193 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 194 | if "_qkv_same_embed_dim" not in state: 195 | state["_qkv_same_embed_dim"] = True 196 | 197 | super(MultiheadAttention, self).__setstate__(state) 198 | 199 | def forward( 200 | self, 201 | query: Tensor, 202 | key: Tensor, 203 | value: Tensor, 204 | key_padding_mask: Optional[Tensor] = None, 205 | need_weights: bool = True, 206 | attn_mask: Optional[Tensor] = None, 207 | average_attn_weights: bool = True, 208 | ) -> Tuple[Tensor, Optional[Tensor]]: 209 | r""" 210 | Args: 211 | query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` 212 | or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, 213 | :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. 214 | Queries are compared against key-value pairs to produce the output. 215 | See "Attention Is All You Need" for more details. 216 | key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` 217 | or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, 218 | :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. 219 | See "Attention Is All You Need" for more details. 220 | value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when 221 | ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source 222 | sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. 223 | See "Attention Is All You Need" for more details. 224 | key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` 225 | to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. 226 | Binary and byte masks are supported. 227 | For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for 228 | the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. 229 | need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. 230 | Default: ``True``. 
231 | attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape 232 | :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, 233 | :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be 234 | broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. 235 | Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the 236 | corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the 237 | corresponding position is not allowed to attend. For a float mask, the mask values will be added to 238 | the attention weight. 239 | average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across 240 | heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an 241 | effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) 242 | 243 | Outputs: 244 | - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, 245 | :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, 246 | where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the 247 | embedding dimension ``embed_dim``. 248 | - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, 249 | returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or 250 | :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and 251 | :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per 252 | head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. 253 | 254 | .. note:: 255 | `batch_first` argument is ignored for unbatched inputs. 256 | """ 257 | is_batched = query.dim() == 3 258 | if key_padding_mask is not None: 259 | _kpm_dtype = key_padding_mask.dtype 260 | if _kpm_dtype != torch.bool and not torch.is_floating_point( 261 | key_padding_mask 262 | ): 263 | raise AssertionError( 264 | "only bool and floating types of key_padding_mask are supported" 265 | ) 266 | why_not_fast_path = "" 267 | if not is_batched: 268 | why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" 269 | elif query is not key or key is not value: 270 | # When lifting this restriction, don't forget to either 271 | # enforce that the dtypes all match or test cases where 272 | # they don't! 273 | why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" 274 | elif ( 275 | self.in_proj_bias is not None 276 | and query.dtype != self.in_proj_bias.dtype 277 | ): 278 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" 279 | elif ( 280 | self.in_proj_weight is not None 281 | and query.dtype != self.in_proj_weight.dtype 282 | ): 283 | # this case will fail anyway, but at least they'll get a useful error message. 
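            # (`why_not_fast_path` records the first blocking condition found;
            # only if it is still an empty string after all checks does forward()
            # dispatch to the fused torch._native_multi_head_attention kernel.)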
284 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" 285 | elif self.training: 286 | why_not_fast_path = "training is enabled" 287 | elif not self.batch_first: 288 | why_not_fast_path = "batch_first was not True" 289 | elif self.bias_k is not None: 290 | why_not_fast_path = "self.bias_k was not None" 291 | elif self.bias_v is not None: 292 | why_not_fast_path = "self.bias_v was not None" 293 | elif self.dropout: 294 | why_not_fast_path = f"dropout was {self.dropout}, required zero" 295 | elif self.add_zero_attn: 296 | why_not_fast_path = "add_zero_attn was enabled" 297 | elif not self._qkv_same_embed_dim: 298 | why_not_fast_path = "_qkv_same_embed_dim was not True" 299 | elif attn_mask is not None: 300 | why_not_fast_path = "attn_mask was not None" 301 | elif query.is_nested and key_padding_mask is not None: 302 | why_not_fast_path = ( 303 | "key_padding_mask is not supported with NestedTensor input" 304 | ) 305 | elif self.num_heads % 2 == 1: 306 | why_not_fast_path = "num_heads is odd" 307 | elif torch.is_autocast_enabled(): 308 | why_not_fast_path = "autocast is enabled" 309 | 310 | if not why_not_fast_path: 311 | tensor_args = ( 312 | query, 313 | key, 314 | value, 315 | self.in_proj_weight, 316 | self.in_proj_bias, 317 | self.out_proj.weight, 318 | self.out_proj.bias, 319 | ) 320 | # We have to use list comprehensions below because TorchScript does not support 321 | # generator expressions. 322 | if torch.overrides.has_torch_function(tensor_args): 323 | why_not_fast_path = "some Tensor argument has_torch_function" 324 | elif not all( 325 | [ 326 | (x is None or x.is_cuda or "cpu" in str(x.device)) 327 | for x in tensor_args 328 | ] 329 | ): 330 | why_not_fast_path = ( 331 | "some Tensor argument is neither CUDA nor CPU" 332 | ) 333 | elif torch.is_grad_enabled() and any( 334 | [x is not None and x.requires_grad for x in tensor_args] 335 | ): 336 | why_not_fast_path = ( 337 | "grad is enabled and at least one of query or the " 338 | "input/output projection weights or biases requires_grad" 339 | ) 340 | if not why_not_fast_path: 341 | return torch._native_multi_head_attention( 342 | query, 343 | key, 344 | value, 345 | self.embed_dim, 346 | self.num_heads, 347 | self.in_proj_weight, 348 | self.in_proj_bias, 349 | self.out_proj.weight, 350 | self.out_proj.bias, 351 | key_padding_mask 352 | if key_padding_mask is not None 353 | else attn_mask, 354 | need_weights, 355 | average_attn_weights, 356 | 1 357 | if key_padding_mask is not None 358 | else 0 359 | if attn_mask is not None 360 | else None, 361 | ) 362 | 363 | any_nested = query.is_nested or key.is_nested or value.is_nested 364 | assert not any_nested, ( 365 | "MultiheadAttention does not support NestedTensor outside of its fast path. 
" 366 | + f"The fast path was not hit because {why_not_fast_path}" 367 | ) 368 | 369 | if self.batch_first and is_batched: 370 | # make sure that the transpose op does not affect the "is" property 371 | if key is value: 372 | if query is key: 373 | query = key = value = query.transpose(1, 0) 374 | else: 375 | query, key = [x.transpose(1, 0) for x in (query, key)] 376 | value = key 377 | else: 378 | query, key, value = [ 379 | x.transpose(1, 0) for x in (query, key, value) 380 | ] 381 | 382 | if not self._qkv_same_embed_dim: 383 | attn_output, attn_output_weights = F.multi_head_attention_forward( 384 | query, 385 | key, 386 | value, 387 | self.embed_dim, 388 | self.num_heads, 389 | self.in_proj_weight, 390 | self.in_proj_bias, 391 | self.bias_k, 392 | self.bias_v, 393 | self.add_zero_attn, 394 | self.dropout, 395 | self.out_proj.weight, 396 | self.out_proj.bias, 397 | training=self.training, 398 | key_padding_mask=key_padding_mask, 399 | need_weights=need_weights, 400 | attn_mask=attn_mask, 401 | use_separate_proj_weight=True, 402 | q_proj_weight=self.q_proj_weight, 403 | k_proj_weight=self.k_proj_weight, 404 | v_proj_weight=self.v_proj_weight, 405 | average_attn_weights=average_attn_weights, 406 | ) 407 | else: 408 | attn_output, attn_output_weights = F.multi_head_attention_forward( 409 | query, 410 | key, 411 | value, 412 | self.embed_dim, 413 | self.num_heads, 414 | self.in_proj_weight, 415 | self.in_proj_bias, 416 | self.bias_k, 417 | self.bias_v, 418 | self.add_zero_attn, 419 | self.dropout, 420 | self.out_proj.weight, 421 | self.out_proj.bias, 422 | training=self.training, 423 | key_padding_mask=key_padding_mask, 424 | need_weights=need_weights, 425 | attn_mask=attn_mask, 426 | average_attn_weights=average_attn_weights, 427 | ) 428 | if self.batch_first and is_batched: 429 | return attn_output.transpose(1, 0), attn_output_weights 430 | else: 431 | return attn_output, attn_output_weights 432 | -------------------------------------------------------------------------------- /valle/modules/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numbers 3 | from functools import partial 4 | from typing import Any, Callable, List, Optional, Tuple, Union 5 | 6 | import torch 7 | from torch import Tensor, nn 8 | from torch.nn import functional as F 9 | 10 | from .activation import MultiheadAttention 11 | from .scaling import ActivationBalancer, BalancedDoubleSwish 12 | from .scaling import BasicNorm as _BasicNorm 13 | 14 | _shape_t = Union[int, List[int], torch.Size] 15 | 16 | 17 | class LayerNorm(nn.Module): 18 | __constants__ = ["normalized_shape", "eps", "elementwise_affine"] 19 | normalized_shape: Tuple[int, ...] 
20 | eps: float 21 | elementwise_affine: bool 22 | 23 | def __init__( 24 | self, 25 | normalized_shape: _shape_t, 26 | eps: float = 1e-5, 27 | elementwise_affine: bool = True, 28 | device=None, 29 | dtype=None, 30 | ) -> None: 31 | factory_kwargs = {"device": device, "dtype": dtype} 32 | super(LayerNorm, self).__init__() 33 | if isinstance(normalized_shape, numbers.Integral): 34 | # mypy error: incompatible types in assignment 35 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 36 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 37 | self.eps = eps 38 | self.elementwise_affine = elementwise_affine 39 | if self.elementwise_affine: 40 | self.weight = nn.Parameter( 41 | torch.empty(self.normalized_shape, **factory_kwargs) 42 | ) 43 | self.bias = nn.Parameter( 44 | torch.empty(self.normalized_shape, **factory_kwargs) 45 | ) 46 | else: 47 | self.register_parameter("weight", None) 48 | self.register_parameter("bias", None) 49 | 50 | self.reset_parameters() 51 | 52 | def reset_parameters(self) -> None: 53 | if self.elementwise_affine: 54 | nn.init.ones_(self.weight) 55 | nn.init.zeros_(self.bias) 56 | 57 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 58 | if isinstance(input, tuple): 59 | input, embedding = input 60 | return ( 61 | F.layer_norm( 62 | input, 63 | self.normalized_shape, 64 | self.weight, 65 | self.bias, 66 | self.eps, 67 | ), 68 | embedding, 69 | ) 70 | 71 | assert embedding is None 72 | return F.layer_norm( 73 | input, self.normalized_shape, self.weight, self.bias, self.eps 74 | ) 75 | 76 | def extra_repr(self) -> str: 77 | return ( 78 | "{normalized_shape}, eps={eps}, " 79 | "elementwise_affine={elementwise_affine}".format(**self.__dict__) 80 | ) 81 | 82 | 83 | class AdaptiveLayerNorm(nn.Module): 84 | r"""Adaptive Layer Normalization""" 85 | 86 | def __init__(self, d_model, norm) -> None: 87 | super(AdaptiveLayerNorm, self).__init__() 88 | self.project_layer = nn.Linear(d_model, 2 * d_model) 89 | self.norm = norm 90 | self.d_model = d_model 91 | self.eps = self.norm.eps 92 | 93 | def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: 94 | if isinstance(input, tuple): 95 | input, embedding = input 96 | weight, bias = torch.split( 97 | self.project_layer(embedding), 98 | split_size_or_sections=self.d_model, 99 | dim=-1, 100 | ) 101 | return (weight * self.norm(input) + bias, embedding) 102 | 103 | weight, bias = torch.split( 104 | self.project_layer(embedding), 105 | split_size_or_sections=self.d_model, 106 | dim=-1, 107 | ) 108 | return weight * self.norm(input) + bias 109 | 110 | 111 | class BasicNorm(_BasicNorm): 112 | def __init__( 113 | self, 114 | d_model: int, 115 | eps: float = 1e-5, 116 | device=None, 117 | dtype=None, 118 | ): 119 | super(BasicNorm, self).__init__(d_model, eps=eps) 120 | 121 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 122 | if isinstance(input, tuple): 123 | input, embedding = input 124 | return ( 125 | super(BasicNorm, self).forward(input), 126 | embedding, 127 | ) 128 | 129 | assert embedding is None 130 | return super(BasicNorm, self).forward(input) 131 | 132 | 133 | class BalancedBasicNorm(nn.Module): 134 | def __init__( 135 | self, 136 | d_model: int, 137 | eps: float = 1e-5, 138 | device=None, 139 | dtype=None, 140 | ): 141 | super(BalancedBasicNorm, self).__init__() 142 | self.balancer = ActivationBalancer( 143 | d_model, 144 | channel_dim=-1, 145 | min_positive=0.45, 146 | max_positive=0.55, 147 | max_abs=6.0, 148 | ) 149 | self.norm = 
BasicNorm(d_model, eps, device=device, dtype=dtype) 150 | 151 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 152 | if isinstance(input, tuple): 153 | input, embedding = input 154 | return self.norm((self.balancer(input), embedding)) 155 | 156 | assert embedding is None 157 | return self.norm(self.balancer(input)) 158 | 159 | 160 | class IdentityNorm(nn.Module): 161 | def __init__( 162 | self, 163 | d_model: int, 164 | eps: float = 1e-5, 165 | device=None, 166 | dtype=None, 167 | ) -> None: 168 | super(IdentityNorm, self).__init__() 169 | 170 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 171 | if isinstance(input, tuple): 172 | return input 173 | 174 | assert embedding is None 175 | return input 176 | 177 | 178 | class TransformerEncoderLayer(nn.Module): 179 | __constants__ = ["batch_first", "norm_first"] 180 | 181 | def __init__( 182 | self, 183 | d_model: int, 184 | nhead: int, 185 | dim_feedforward: int = 2048, 186 | dropout: float = 0.1, 187 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 188 | batch_first: bool = False, 189 | norm_first: bool = False, 190 | device=None, 191 | dtype=None, 192 | linear1_self_attention_cls: nn.Module = nn.Linear, 193 | linear2_self_attention_cls: nn.Module = nn.Linear, 194 | linear1_feedforward_cls: nn.Module = nn.Linear, 195 | linear2_feedforward_cls: nn.Module = nn.Linear, 196 | layer_norm_cls: nn.Module = LayerNorm, 197 | layer_norm_eps: float = 1e-5, 198 | adaptive_layer_norm=False, 199 | ) -> None: 200 | factory_kwargs = {"device": device, "dtype": dtype} 201 | super(TransformerEncoderLayer, self).__init__() 202 | self.self_attn = MultiheadAttention( 203 | d_model, 204 | nhead, 205 | dropout=dropout, 206 | batch_first=batch_first, 207 | linear1_cls=linear1_self_attention_cls, 208 | linear2_cls=linear2_self_attention_cls, 209 | **factory_kwargs, 210 | ) 211 | 212 | # Implementation of Feedforward model 213 | self.linear1 = linear1_feedforward_cls( 214 | d_model, dim_feedforward, **factory_kwargs 215 | ) 216 | self.dropout = nn.Dropout(dropout) 217 | self.linear2 = linear2_feedforward_cls( 218 | dim_feedforward, d_model, **factory_kwargs 219 | ) 220 | 221 | self.norm_first = norm_first 222 | self.dropout1 = nn.Dropout(dropout) 223 | self.dropout2 = nn.Dropout(dropout) 224 | 225 | # Legacy string support for activation function. 226 | if isinstance(activation, str): 227 | activation = _get_activation_fn(activation) 228 | elif isinstance(activation, partial): 229 | activation = activation(d_model) 230 | elif activation == BalancedDoubleSwish: 231 | activation = BalancedDoubleSwish(d_model) 232 | 233 | # # We can't test self.activation in forward() in TorchScript, 234 | # # so stash some information about it instead. 
235 | # if activation is F.relu or isinstance(activation, torch.nn.ReLU): 236 | # self.activation_relu_or_gelu = 1 237 | # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): 238 | # self.activation_relu_or_gelu = 2 239 | # else: 240 | # self.activation_relu_or_gelu = 0 241 | self.activation = activation 242 | 243 | norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) 244 | if layer_norm_cls == IdentityNorm: 245 | norm2 = BalancedBasicNorm( 246 | d_model, eps=layer_norm_eps, **factory_kwargs 247 | ) 248 | else: 249 | norm2 = layer_norm_cls( 250 | d_model, eps=layer_norm_eps, **factory_kwargs 251 | ) 252 | 253 | if adaptive_layer_norm: 254 | self.norm1 = AdaptiveLayerNorm(d_model, norm1) 255 | self.norm2 = AdaptiveLayerNorm(d_model, norm2) 256 | else: 257 | self.norm1 = norm1 258 | self.norm2 = norm2 259 | 260 | def __setstate__(self, state): 261 | super(TransformerEncoderLayer, self).__setstate__(state) 262 | if not hasattr(self, "activation"): 263 | self.activation = F.relu 264 | 265 | def forward( 266 | self, 267 | src: Tensor, 268 | src_mask: Optional[Tensor] = None, 269 | src_key_padding_mask: Optional[Tensor] = None, 270 | ) -> Tensor: 271 | r"""Pass the input through the encoder layer. 272 | 273 | Args: 274 | src: the sequence to the encoder layer (required). 275 | src_mask: the mask for the src sequence (optional). 276 | src_key_padding_mask: the mask for the src keys per batch (optional). 277 | 278 | Shape: 279 | see the docs in Transformer class. 280 | """ 281 | x, stage_embedding = src, None 282 | is_src_tuple = False 283 | if isinstance(src, tuple): 284 | x, stage_embedding = src 285 | is_src_tuple = True 286 | 287 | if src_key_padding_mask is not None: 288 | _skpm_dtype = src_key_padding_mask.dtype 289 | if _skpm_dtype != torch.bool and not torch.is_floating_point( 290 | src_key_padding_mask 291 | ): 292 | raise AssertionError( 293 | "only bool and floating types of key_padding_mask are supported" 294 | ) 295 | 296 | if self.norm_first: 297 | x = x + self._sa_block( 298 | self.norm1(x, stage_embedding), 299 | src_mask, 300 | src_key_padding_mask, 301 | ) 302 | x = x + self._ff_block(self.norm2(x, stage_embedding)) 303 | else: 304 | x = self.norm1( 305 | x + self._sa_block(x, src_mask, src_key_padding_mask), 306 | stage_embedding, 307 | ) 308 | x = self.norm2(x + self._ff_block(x), stage_embedding) 309 | 310 | if is_src_tuple: 311 | return (x, stage_embedding) 312 | return x 313 | 314 | # self-attention block 315 | def _sa_block( 316 | self, 317 | x: Tensor, 318 | attn_mask: Optional[Tensor], 319 | key_padding_mask: Optional[Tensor], 320 | ) -> Tensor: 321 | x = self.self_attn( 322 | x, 323 | x, 324 | x, 325 | attn_mask=attn_mask, 326 | key_padding_mask=key_padding_mask, 327 | need_weights=False, 328 | )[0] 329 | return self.dropout1(x) 330 | 331 | # feed forward block 332 | def _ff_block(self, x: Tensor) -> Tensor: 333 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 334 | return self.dropout2(x) 335 | 336 | 337 | class TransformerEncoder(nn.Module): 338 | r"""TransformerEncoder is a stack of N encoder layers. Users can build the 339 | BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. 340 | 341 | Args: 342 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 343 | num_layers: the number of sub-encoder-layers in the encoder (required). 344 | norm: the layer normalization component (optional). 
345 |         enable_nested_tensor: if True, input will automatically convert to nested tensor
346 |             (and convert back on output). This will improve the overall performance of
347 |             TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
348 | 
349 |     Examples::
350 |         >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
351 |         >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
352 |         >>> src = torch.rand(10, 32, 512)
353 |         >>> out = transformer_encoder(src)
354 |     """
355 |     __constants__ = ["norm"]
356 | 
357 |     def __init__(self, encoder_layer, num_layers, norm=None):
358 |         super(TransformerEncoder, self).__init__()
359 |         self.layers = _get_clones(encoder_layer, num_layers)
360 |         self.num_layers = num_layers
361 |         self.norm = norm
362 | 
363 |     def forward(
364 |         self,
365 |         src: Tensor,
366 |         mask: Optional[Tensor] = None,
367 |         src_key_padding_mask: Optional[Tensor] = None,
368 |         return_layer_states: bool = False,
369 |     ) -> Tensor:
370 |         r"""Pass the input through the encoder layers in turn.
371 | 
372 |         Args:
373 |             src: the sequence to the encoder (required).
374 |             mask: the mask for the src sequence (optional).
375 |             src_key_padding_mask: the mask for the src keys per batch (optional).
376 |             return_layer_states: return layers' state (optional).
377 | 
378 |         Shape:
379 |             see the docs in Transformer class.
380 |         """
381 |         if return_layer_states:
382 |             layer_states = []  # layers' output
383 |             output = src
384 |             for mod in self.layers:
385 |                 output = mod(
386 |                     output,
387 |                     src_mask=mask,
388 |                     src_key_padding_mask=src_key_padding_mask,
389 |                 )
390 |                 layer_states.append(output[0])
391 | 
392 |             if self.norm is not None:
393 |                 output = self.norm(output)
394 | 
395 |             return layer_states, output
396 | 
397 |         output = src
398 |         for mod in self.layers:
399 |             output = mod(
400 |                 output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
401 |             )
402 | 
403 |         if self.norm is not None:
404 |             output = self.norm(output)
405 | 
406 |         return output
407 | 
408 | 
409 | class TransformerDecoderLayer(nn.Module):
410 |     __constants__ = ["batch_first", "norm_first"]
411 | 
412 |     def __init__(
413 |         self,
414 |         d_model: int,
415 |         nhead: int,
416 |         dim_feedforward: int = 2048,
417 |         dropout: float = 0.1,
418 |         activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
419 |         linear1_self_attention_cls: nn.Module = nn.Linear,
420 |         linear2_self_attention_cls: nn.Module = nn.Linear,
421 |         linear1_feedforward_cls: nn.Module = nn.Linear,
422 |         linear2_feedforward_cls: nn.Module = nn.Linear,
423 |         batch_first: bool = False,
424 |         norm_first: bool = False,
425 |         device=None,
426 |         dtype=None,
427 |         layer_norm_cls: nn.Module = LayerNorm,
428 |         layer_norm_eps: float = 1e-5,
429 |         adaptive_layer_norm=False,
430 |     ) -> None:
431 |         factory_kwargs = {"device": device, "dtype": dtype}
432 |         super(TransformerDecoderLayer, self).__init__()
433 |         self.self_attn = MultiheadAttention(
434 |             d_model,
435 |             nhead,
436 |             dropout=dropout,
437 |             batch_first=batch_first,
438 |             linear1_cls=linear1_self_attention_cls,
439 |             linear2_cls=linear2_self_attention_cls,
440 |             **factory_kwargs,
441 |         )
442 |         self.multihead_attn = MultiheadAttention(
443 |             d_model,
444 |             nhead,
445 |             dropout=dropout,
446 |             batch_first=batch_first,
447 |             linear1_cls=linear1_self_attention_cls,
448 |             linear2_cls=linear2_self_attention_cls,
449 |             **factory_kwargs,
450 |         )
451 |         # Implementation of Feedforward model
452 |         self.linear1 = linear1_feedforward_cls(
453 |             d_model, dim_feedforward, **factory_kwargs
454 |         )
455 |         self.dropout = nn.Dropout(dropout)
456 |         self.linear2 = linear2_feedforward_cls(
457 |             dim_feedforward, d_model, **factory_kwargs
458 |         )
459 | 
460 |         self.norm_first = norm_first
461 |         self.dropout1 = nn.Dropout(dropout)
462 |         self.dropout2 = nn.Dropout(dropout)
463 |         self.dropout3 = nn.Dropout(dropout)
464 | 
465 |         # Legacy string support for activation function.
466 |         if isinstance(activation, str):
467 |             self.activation = _get_activation_fn(activation)
468 |         elif isinstance(activation, partial):
469 |             self.activation = activation(d_model)
470 |         elif activation == BalancedDoubleSwish:
471 |             self.activation = BalancedDoubleSwish(d_model)
472 |         else:
473 |             self.activation = activation
474 | 
475 |         if adaptive_layer_norm:
476 |             norm1 = layer_norm_cls(
477 |                 d_model, eps=layer_norm_eps, **factory_kwargs
478 |             )
479 |             norm2 = layer_norm_cls(
480 |                 d_model, eps=layer_norm_eps, **factory_kwargs
481 |             )
482 |             norm3 = layer_norm_cls(
483 |                 d_model, eps=layer_norm_eps, **factory_kwargs
484 |             )
485 | 
486 |             self.norm1 = AdaptiveLayerNorm(d_model, norm1)
487 |             self.norm2 = AdaptiveLayerNorm(d_model, norm2)
488 |             self.norm3 = AdaptiveLayerNorm(d_model, norm3)
489 |         else:
490 |             self.norm1 = layer_norm_cls(
491 |                 d_model, eps=layer_norm_eps, **factory_kwargs
492 |             )
493 |             self.norm2 = layer_norm_cls(
494 |                 d_model, eps=layer_norm_eps, **factory_kwargs
495 |             )
496 |             if layer_norm_cls == IdentityNorm:
497 |                 self.norm3 = BalancedBasicNorm(
498 |                     d_model, eps=layer_norm_eps, **factory_kwargs
499 |                 )
500 |             else:
501 |                 self.norm3 = layer_norm_cls(
502 |                     d_model, eps=layer_norm_eps, **factory_kwargs
503 |                 )
504 | 
505 |     def forward(
506 |         self,
507 |         tgt: Tensor,
508 |         memory: Tensor,
509 |         tgt_mask: Optional[Tensor] = None,
510 |         memory_mask: Optional[Tensor] = None,
511 |         tgt_key_padding_mask: Optional[Tensor] = None,
512 |         memory_key_padding_mask: Optional[Tensor] = None,
513 |     ) -> Tensor:
514 |         r"""Pass the inputs (and mask) through the decoder layer.
515 | 
516 |         Args:
517 |             tgt: the sequence to the decoder layer (required).
518 |             memory: the sequence from the last layer of the encoder (required).
519 |             tgt_mask: the mask for the tgt sequence (optional).
520 |             memory_mask: the mask for the memory sequence (optional).
521 |             tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
522 |             memory_key_padding_mask: the mask for the memory keys per batch (optional).
523 | 
524 |         Shape:
525 |             see the docs in Transformer class.
526 |         """
527 |         tgt_is_tuple = False
528 |         if isinstance(tgt, tuple):
529 |             x, stage_embedding = tgt
530 |             tgt_is_tuple = True
531 |         else:
532 |             x, stage_embedding = tgt, None
533 | 
534 |         if self.norm_first:
535 |             x = x + self._sa_block(
536 |                 self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
537 |             )
538 |             x = x + self._mha_block(
539 |                 self.norm2(x, stage_embedding),
540 |                 memory,
541 |                 memory_mask,
542 |                 memory_key_padding_mask,
543 |             )
544 |             x = x + self._ff_block(self.norm3(x, stage_embedding))
545 |         else:
546 |             x = self.norm1(
547 |                 x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
548 |                 stage_embedding,
549 |             )
550 |             x = self.norm2(
551 |                 x
552 |                 + self._mha_block(
553 |                     x, memory, memory_mask, memory_key_padding_mask
554 |                 ),
555 |                 stage_embedding,
556 |             )
557 |             x = self.norm3(x + self._ff_block(x), stage_embedding)
558 | 
559 |         if tgt_is_tuple:
560 |             return (x, stage_embedding)
561 |         return x
562 | 
563 |     # self-attention block
564 |     def _sa_block(
565 |         self,
566 |         x: Tensor,
567 |         attn_mask: Optional[Tensor],
568 |         key_padding_mask: Optional[Tensor],
569 |     ) -> Tensor:
570 |         x = self.self_attn(
571 |             x,
572 |             x,
573 |             x,
574 |             attn_mask=attn_mask,
575 |             key_padding_mask=key_padding_mask,
576 |             need_weights=False,
577 |         )[0]
578 |         return self.dropout1(x)
579 | 
580 |     # multihead attention block
581 |     def _mha_block(
582 |         self,
583 |         x: Tensor,
584 |         mem: Tensor,
585 |         attn_mask: Optional[Tensor],
586 |         key_padding_mask: Optional[Tensor],
587 |     ) -> Tensor:
588 |         x = self.multihead_attn(
589 |             x,
590 |             mem,
591 |             mem,
592 |             attn_mask=attn_mask,
593 |             key_padding_mask=key_padding_mask,
594 |             need_weights=False,
595 |         )[0]
596 |         return self.dropout2(x)
597 | 
598 |     # feed forward block
599 |     def _ff_block(self, x: Tensor) -> Tensor:
600 |         x = self.linear2(self.dropout(self.activation(self.linear1(x))))
601 |         return self.dropout3(x)
602 | 
603 | 
604 | def _get_clones(module, N):
605 |     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
606 | 
607 | 
608 | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
609 |     if activation == "relu":
610 |         return F.relu
611 |     elif activation == "gelu":
612 |         return F.gelu
613 | 
614 |     raise RuntimeError(
615 |         "activation should be relu/gelu, not {}".format(activation)
616 |     )
617 | 
--------------------------------------------------------------------------------
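The file above only defines building blocks; nothing in it runs them end to end. The snippet below is a minimal usage sketch, assuming this file is importable as valle.modules.transformer and that the repository's MultiheadAttention and default LayerNorm accept the same tensor layouts as their PyTorch counterparts; the shapes and the padding mask are illustrative, not taken from the repository.

    import torch

    from valle.modules.transformer import (
        TransformerDecoderLayer,
        TransformerEncoder,
        TransformerEncoderLayer,
    )

    d_model, nhead = 512, 8

    # Pre-norm encoder layer operating on (batch, time, feature) tensors.
    encoder_layer = TransformerEncoderLayer(
        d_model=d_model, nhead=nhead, batch_first=True, norm_first=True
    )
    encoder = TransformerEncoder(encoder_layer, num_layers=6)

    src = torch.rand(4, 100, d_model)                     # (batch, time, d_model)
    padding_mask = torch.zeros(4, 100, dtype=torch.bool)  # True marks padded frames
    padding_mask[:, 80:] = True

    memory = encoder(src, src_key_padding_mask=padding_mask)  # -> (4, 100, 512)

    # A single decoder layer cross-attending over the encoder output ("memory").
    decoder_layer = TransformerDecoderLayer(
        d_model=d_model, nhead=nhead, batch_first=True, norm_first=True
    )
    tgt = torch.rand(4, 50, d_model)
    out = decoder_layer(tgt, memory)                      # -> (4, 50, 512)

When the layers are built with adaptive_layer_norm=True, their forward methods instead expect (x, stage_embedding) tuples and return tuples, as handled by the isinstance(..., tuple) branches above.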