├── e2-tts.png
├── e2_tts_pytorch
│   ├── __init__.py
│   ├── trainer.py
│   └── e2_tts.py
├── train_example.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── pyproject.toml
├── README.md
└── .gitignore

/e2-tts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wetdog/e2-tts-pytorch/main/e2-tts.png
--------------------------------------------------------------------------------
/e2_tts_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from e2_tts_pytorch.e2_tts import (
2 |     Transformer,
3 |     DurationPredictor,
4 |     E2TTS,
5 | )
6 |
7 | from e2_tts_pytorch.trainer import (
8 |     E2Trainer
9 | )
--------------------------------------------------------------------------------
/train_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from e2_tts_pytorch import E2TTS, DurationPredictor
3 |
4 | from torch.optim import Adam
5 | from datasets import load_dataset
6 |
7 | from e2_tts_pytorch.trainer import (
8 |     HFDataset,
9 |     E2Trainer
10 | )
11 |
12 | duration_predictor = DurationPredictor(
13 |     transformer = dict(
14 |         dim = 512,
15 |         depth = 6,
16 |     )
17 | )
18 |
19 | e2tts = E2TTS(
20 |     duration_predictor = duration_predictor,
21 |     transformer = dict(
22 |         dim = 512,
23 |         depth = 12,
24 |     ),
25 | )
26 |
27 | train_dataset = HFDataset(load_dataset("MushanW/GLOBE")["train"])
28 |
29 | optimizer = Adam(e2tts.parameters(), lr=7.5e-5)
30 |
31 | trainer = E2Trainer(
32 |     e2tts,
33 |     optimizer,
34 |     num_warmup_steps=20000,
35 |     grad_accumulation_steps = 1,
36 |     checkpoint_path = 'e2tts.pt',
37 |     log_file = 'e2tts.txt'
38 | )
39 |
40 | epochs = 10
41 | batch_size = 32
42 |
43 | trainer.train(train_dataset, epochs, batch_size, save_step=1000)
44 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "e2-tts-pytorch" 3 | version = "1.0.5" 4 | description = "E2-TTS in Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.8" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'transformers', 15 | 'attention mechanism', 16 | 'text to speech' 17 | ] 18 | classifiers=[ 19 | 'Development Status :: 4 - Beta', 20 | 'Intended Audience :: Developers', 21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Programming Language :: Python :: 3.8', 24 | ] 25 | 26 | dependencies = [ 27 | 'accelerate>=0.33.0', 28 | 'beartype', 29 | 'einops>=0.8.0', 30 | 'einx>=0.3.0', 31 | 'ema-pytorch>=0.5.2', 32 | 'g2p-en', 33 | 'jaxtyping', 34 | 'loguru', 35 | 'tensorboard', 36 | 'torch>=2.0', 37 | 'torchdiffeq', 38 | 'torchaudio>=2.3.1', 39 | 'tqdm>=4.65.0', 40 | 'vocos', 41 | 'x-transformers>=1.31.14', 42 | ] 43 | 44 | [project.urls] 45 | Homepage = "https://pypi.org/project/e2-tts-pytorch/" 46 | Repository = "https://github.com/lucidrains/e2-tts-pytorch" 47 | 48 | [project.optional-dependencies] 49 | examples = ["datasets"] 50 | 51 | [build-system] 52 | requires = ["hatchling"] 53 | build-backend = "hatchling.build" 54 | 55 | [tool.hatch.metadata] 56 | allow-direct-references = true 57 | 58 | [tool.hatch.build.targets.wheel] 59 | packages = ["e2_tts_pytorch"] 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## E2 TTS - Pytorch 5 | 6 | Implementation of E2-TTS, Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS, in Pytorch 7 | 8 | The repository differs from the paper in that it uses a multistream transformer for text and audio, with conditioning done every transformer block in the E2 manner. 
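The text stream has its own width and depth: `dim_text` defaults to half of the audio dimension and `text_depth` to the full audio depth, and the text stream cross-conditions the audio stream in each of those blocks. Below is a minimal sketch of configuring both streams explicitly; the hyperparameter values are illustrative only, not the settings used in the paper.

```python
import torch

from e2_tts_pytorch import E2TTS

e2tts = E2TTS(
    transformer = dict(
        dim = 512,         # audio stream width
        depth = 12,        # number of speech transformer blocks (must be even)
        dim_text = 256,    # text stream width, defaults to dim // 2
        text_depth = 6,    # text blocks; cross conditioning happens in the first `text_depth` layers
        text_heads = 8,
    ),
)

mel = torch.randn(2, 1024, 100)    # (batch, mel frames, mel channels)
text = ['Hello', 'Goodbye']

loss, cond, pred = e2tts(mel, text = text)
loss.backward()
```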
9 | 10 | It also includes an improvisation that was proven out by Manmay, where the text is simply interpolated to the length of the audio for conditioning. You can try this by setting `interpolated_text = True` on `E2TTS` 11 | 12 | ## Appreciation 13 | 14 | - Manmay for contributing working end-to-end training code! 15 | 16 | - Lucas Newman for the code contributions, helpful feedback, and for sharing the first set of positive experiments! 17 | 18 | - Jing for sharing the second positive result with a multilingual (English + Chinese) dataset! 19 | 20 | - Coice and Manmay for reporting the third and fourth successful runs. Farewell alignment engineering 21 | 22 | ## Install 23 | 24 | ```bash 25 | $ pip install e2-tts-pytorch 26 | ``` 27 | 28 | ## Usage 29 | 30 | ```python 31 | import torch 32 | 33 | from e2_tts_pytorch import ( 34 | E2TTS, 35 | DurationPredictor 36 | ) 37 | 38 | duration_predictor = DurationPredictor( 39 | transformer = dict( 40 | dim = 512, 41 | depth = 8, 42 | ) 43 | ) 44 | 45 | mel = torch.randn(2, 1024, 100) 46 | text = ['Hello', 'Goodbye'] 47 | 48 | loss = duration_predictor(mel, text = text) 49 | loss.backward() 50 | 51 | e2tts = E2TTS( 52 | duration_predictor = duration_predictor, 53 | transformer = dict( 54 | dim = 512, 55 | depth = 8 56 | ), 57 | ) 58 | 59 | out = e2tts(mel, text = text) 60 | out.loss.backward() 61 | 62 | sampled = e2tts.sample(mel[:, :5], text = text) 63 | 64 | ``` 65 | 66 | ## Citations 67 | 68 | ```bibtex 69 | @inproceedings{Eskimez2024E2TE, 70 | title = {E2 TTS: Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS}, 71 | author = {Sefik Emre Eskimez and Xiaofei Wang and Manthan Thakker and Canrun Li and Chung-Hsien Tsai and Zhen Xiao and Hemin Yang and Zirun Zhu and Min Tang and Xu Tan and Yanqing Liu and Sheng Zhao and Naoyuki Kanda}, 72 | year = {2024}, 73 | url = {https://api.semanticscholar.org/CorpusID:270738197} 74 | } 75 | ``` 76 | 77 | ```bibtex 78 | @inproceedings{Darcet2023VisionTN, 79 | title = {Vision Transformers Need Registers}, 80 | author = {Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski}, 81 | year = {2023}, 82 | url = {https://api.semanticscholar.org/CorpusID:263134283} 83 | } 84 | ``` 85 | 86 | ```bibtex 87 | @article{Bao2022AllAW, 88 | title = {All are Worth Words: A ViT Backbone for Diffusion Models}, 89 | author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu}, 90 | journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 91 | year = {2022}, 92 | pages = {22669-22679}, 93 | url = {https://api.semanticscholar.org/CorpusID:253581703} 94 | } 95 | ``` 96 | 97 | ```bibtex 98 | @article{Burtsev2021MultiStreamT, 99 | title = {Multi-Stream Transformers}, 100 | author = {Mikhail S. 
Burtsev and Anna Rumshisky}, 101 | journal = {ArXiv}, 102 | year = {2021}, 103 | volume = {abs/2107.10342}, 104 | url = {https://api.semanticscholar.org/CorpusID:236171087} 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /e2_tts_pytorch/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from tqdm import tqdm 5 | import matplotlib 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader, Dataset 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.optim.lr_scheduler import LinearLR, SequentialLR 14 | 15 | import torchaudio 16 | 17 | from einops import rearrange 18 | 19 | from accelerate import Accelerator 20 | from accelerate.utils import DistributedDataParallelKwargs 21 | 22 | from ema_pytorch import EMA 23 | 24 | from loguru import logger 25 | 26 | from e2_tts_pytorch.e2_tts import ( 27 | E2TTS, 28 | DurationPredictor, 29 | MelSpec 30 | ) 31 | 32 | def exists(v): 33 | return v is not None 34 | 35 | def default(v, d): 36 | return v if exists(v) else d 37 | 38 | # plot spectrogram 39 | def plot_spectrogram(spectrogram): 40 | fig, ax = plt.subplots(figsize=(10, 4)) 41 | im = ax.imshow(spectrogram.T, aspect="auto", origin="lower", interpolation="none") 42 | plt.colorbar(im, ax=ax) 43 | plt.xlabel("Frames") 44 | plt.ylabel("Channels") 45 | plt.tight_layout() 46 | 47 | fig.canvas.draw() 48 | plt.close() 49 | return fig 50 | 51 | # collation 52 | 53 | def collate_fn(batch): 54 | mel_specs = [item['mel_spec'].squeeze(0) for item in batch] 55 | mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs]) 56 | max_mel_length = mel_lengths.amax() 57 | 58 | padded_mel_specs = [] 59 | for spec in mel_specs: 60 | padding = (0, max_mel_length - spec.size(-1)) 61 | padded_spec = F.pad(spec, padding, value = 0) 62 | padded_mel_specs.append(padded_spec) 63 | 64 | mel_specs = torch.stack(padded_mel_specs) 65 | 66 | text = [item['text'] for item in batch] 67 | text_lengths = torch.LongTensor([len(item) for item in text]) 68 | 69 | return dict( 70 | mel = mel_specs, 71 | mel_lengths = mel_lengths, 72 | text = text, 73 | text_lengths = text_lengths, 74 | ) 75 | 76 | # dataset 77 | 78 | class HFDataset(Dataset): 79 | def __init__( 80 | self, 81 | hf_dataset: Dataset, 82 | target_sample_rate = 24_000, 83 | hop_length = 256 84 | ): 85 | self.data = hf_dataset 86 | 
self.target_sample_rate = target_sample_rate 87 | self.hop_length = hop_length 88 | self.mel_spectrogram = MelSpec(sampling_rate=target_sample_rate) 89 | 90 | def __len__(self): 91 | return len(self.data) 92 | 93 | def __getitem__(self, index): 94 | row = self.data[index] 95 | audio = row['audio']['array'] 96 | 97 | logger.info(f"Audio shape: {audio.shape}") 98 | 99 | sample_rate = row['audio']['sampling_rate'] 100 | duration = audio.shape[-1] / sample_rate 101 | 102 | if duration > 20 or duration < 0.3: 103 | logger.warning(f"Skipping due to duration out of bound: {duration}") 104 | return self.__getitem__((index + 1) % len(self.data)) 105 | 106 | audio_tensor = torch.from_numpy(audio).float() 107 | 108 | if sample_rate != self.target_sample_rate: 109 | resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) 110 | audio_tensor = resampler(audio_tensor) 111 | 112 | audio_tensor = rearrange(audio_tensor, 't -> 1 t') 113 | 114 | mel_spec = self.mel_spectrogram(audio_tensor) 115 | 116 | mel_spec = rearrange(mel_spec, '1 d t -> d t') 117 | 118 | text = row['transcript'] 119 | 120 | return dict( 121 | mel_spec = mel_spec, 122 | text = text, 123 | ) 124 | 125 | # trainer 126 | 127 | class E2Trainer: 128 | def __init__( 129 | self, 130 | model: E2TTS, 131 | optimizer, 132 | num_warmup_steps=20000, 133 | grad_accumulation_steps=1, 134 | duration_predictor: DurationPredictor | None = None, 135 | checkpoint_path = None, 136 | log_file = "logs.txt", 137 | max_grad_norm = 1.0, 138 | sample_rate = 22050, 139 | tensorboard_log_dir = 'runs/e2_tts_experiment', 140 | accelerate_kwargs: dict = dict(), 141 | ema_kwargs: dict = dict() 142 | ): 143 | logger.add(log_file) 144 | 145 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True) 146 | 147 | self.accelerator = Accelerator( 148 | log_with = "all", 149 | kwargs_handlers = [ddp_kwargs], 150 | gradient_accumulation_steps = grad_accumulation_steps, 151 | **accelerate_kwargs 152 | ) 153 | 154 | self.target_sample_rate = sample_rate 155 | 156 | self.model = model 157 | 158 | if self.is_main: 159 | self.ema_model = EMA( 160 | model, 161 | include_online_model = False, 162 | **ema_kwargs 163 | ) 164 | 165 | self.ema_model.to(self.accelerator.device) 166 | 167 | self.duration_predictor = duration_predictor 168 | self.optimizer = optimizer 169 | self.num_warmup_steps = num_warmup_steps 170 | self.checkpoint_path = default(checkpoint_path, 'model.pth') 171 | self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate) 172 | 173 | self.model, self.optimizer = self.accelerator.prepare( 174 | self.model, self.optimizer 175 | ) 176 | self.max_grad_norm = max_grad_norm 177 | 178 | self.writer = SummaryWriter(log_dir=tensorboard_log_dir) 179 | 180 | @property 181 | def is_main(self): 182 | return self.accelerator.is_main_process 183 | 184 | def save_checkpoint(self, step, finetune=False): 185 | self.accelerator.wait_for_everyone() 186 | if self.is_main: 187 | checkpoint = dict( 188 | model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(), 189 | optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(), 190 | ema_model_state_dict = self.ema_model.state_dict(), 191 | scheduler_state_dict = self.scheduler.state_dict(), 192 | step = step 193 | ) 194 | 195 | self.accelerator.save(checkpoint, self.checkpoint_path) 196 | 197 | def load_checkpoint(self): 198 | if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path): 199 | return 0 200 | 201 | checkpoint = 
torch.load(self.checkpoint_path) 202 | self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict']) 203 | self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict']) 204 | 205 | if self.is_main: 206 | self.ema_model.load_state_dict(checkpoint['ema_model_state_dict']) 207 | 208 | if self.scheduler: 209 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 210 | return checkpoint['step'] 211 | 212 | def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=1000): 213 | 214 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers, pin_memory=True) 215 | total_steps = len(train_dataloader) * epochs 216 | decay_steps = total_steps - self.num_warmup_steps 217 | warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=self.num_warmup_steps) 218 | decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) 219 | self.scheduler = SequentialLR(self.optimizer, 220 | schedulers=[warmup_scheduler, decay_scheduler], 221 | milestones=[self.num_warmup_steps]) 222 | train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) 223 | start_step = self.load_checkpoint() 224 | global_step = start_step 225 | 226 | for epoch in range(epochs): 227 | self.model.train() 228 | progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}", unit="step", disable=not self.accelerator.is_local_main_process) 229 | epoch_loss = 0.0 230 | 231 | for batch in progress_bar: 232 | with self.accelerator.accumulate(self.model): 233 | text_inputs = batch['text'] 234 | mel_spec = rearrange(batch['mel'], 'b d n -> b n d') 235 | mel_lengths = batch["mel_lengths"] 236 | 237 | if self.duration_predictor is not None: 238 | dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations')) 239 | self.writer.add_scalar('duration loss', dur_loss.item(), global_step) 240 | 241 | loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths) 242 | self.accelerator.backward(loss) 243 | 244 | if self.max_grad_norm > 0 and self.accelerator.sync_gradients: 245 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 246 | 247 | self.optimizer.step() 248 | self.scheduler.step() 249 | self.optimizer.zero_grad() 250 | 251 | if self.is_main: 252 | self.ema_model.update() 253 | 254 | if self.accelerator.is_local_main_process: 255 | logger.info(f"step {global_step+1}: loss = {loss.item():.4f}") 256 | self.writer.add_scalar('loss', loss.item(), global_step) 257 | self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step) 258 | 259 | global_step += 1 260 | epoch_loss += loss.item() 261 | progress_bar.set_postfix(loss=loss.item()) 262 | 263 | if global_step % save_step == 0: 264 | self.save_checkpoint(global_step) 265 | self.writer.add_figure("mel/target", plot_spectrogram(mel_spec[0,:,:].detach().cpu().numpy()), global_step) 266 | self.writer.add_figure("mel/mask", plot_spectrogram(cond[0,:,:].detach().cpu().numpy()), global_step) 267 | self.writer.add_figure("mel/prediction", plot_spectrogram(pred[0,:,:].detach().cpu().numpy()), global_step) 268 | 269 | epoch_loss /= len(train_dataloader) 270 | if self.accelerator.is_local_main_process: 271 | logger.info(f"epoch {epoch+1}/{epochs} - average loss = {epoch_loss:.4f}") 272 | self.writer.add_scalar('epoch average loss', epoch_loss, epoch) 273 | 274 | 
self.writer.close() 275 | -------------------------------------------------------------------------------- /e2_tts_pytorch/e2_tts.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | dt - dimension text 9 | """ 10 | 11 | from __future__ import annotations 12 | 13 | from pathlib import Path 14 | from random import random 15 | from functools import partial 16 | from itertools import zip_longest 17 | from collections import namedtuple 18 | 19 | from typing import Literal, Callable 20 | 21 | import jaxtyping 22 | from beartype import beartype 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | from torch import nn, tensor, Tensor, from_numpy 27 | from torch.nn import Module, ModuleList, Sequential, Linear 28 | from torch.nn.utils.rnn import pad_sequence 29 | 30 | import torchaudio 31 | from torchaudio.functional import DB_to_amplitude, resample 32 | from torchdiffeq import odeint 33 | 34 | import einx 35 | from einops.layers.torch import Rearrange 36 | from einops import einsum, rearrange, repeat, reduce, pack, unpack 37 | 38 | from x_transformers import ( 39 | Attention, 40 | FeedForward, 41 | RMSNorm, 42 | AdaptiveRMSNorm, 43 | ) 44 | 45 | from x_transformers.x_transformers import RotaryEmbedding 46 | 47 | from vocos import Vocos 48 | 49 | pad_sequence = partial(pad_sequence, batch_first = True) 50 | 51 | # constants 52 | 53 | class TorchTyping: 54 | def __init__(self, abstract_dtype): 55 | self.abstract_dtype = abstract_dtype 56 | 57 | def __getitem__(self, shapes: str): 58 | return self.abstract_dtype[Tensor, shapes] 59 | 60 | Float = TorchTyping(jaxtyping.Float) 61 | Int = TorchTyping(jaxtyping.Int) 62 | Bool = TorchTyping(jaxtyping.Bool) 63 | 64 | E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred']) 65 | 66 | # helpers 67 | 68 | def exists(v): 69 | return v is not None 70 | 71 | def default(v, d): 72 | return v if exists(v) else d 73 | 74 | def divisible_by(num, den): 75 | return (num % den) == 0 76 | 77 | class Identity(Module): 78 | def forward(self, x, **kwargs): 79 | return x 80 | 81 | # simple utf-8 tokenizer, since paper went character based 82 | 83 | def list_str_to_tensor( 84 | text: list[str], 85 | padding_value = -1 86 | ) -> Int['b nt']: 87 | 88 | list_tensors = [tensor([*bytes(t, 'UTF-8')]) for t in text] 89 | padded_tensor = pad_sequence(list_tensors, padding_value = -1) 90 | return padded_tensor 91 | 92 | # simple english phoneme-based tokenizer 93 | 94 | from g2p_en import G2p 95 | 96 | def get_g2p_en_encode(): 97 | g2p = G2p() 98 | 99 | # used by @lucasnewman successfully here 100 | # https://github.com/lucasnewman/e2-tts-pytorch/blob/ljspeech-test/e2_tts_pytorch/e2_tts.py 101 | 102 | phoneme_to_index = g2p.p2idx 103 | num_phonemes = len(phoneme_to_index) 104 | 105 | extended_chars = [' ', ',', '.', '-', '!', '?', '\'', '"', '...', '..', '. .', '. . .', '. . . .', '. . . . .', '. ...', '... .', '.. 
..'] 106 | num_extended_chars = len(extended_chars) 107 | 108 | extended_chars_dict = {p: (num_phonemes + i) for i, p in enumerate(extended_chars)} 109 | phoneme_to_index = {**phoneme_to_index, **extended_chars_dict} 110 | 111 | def encode( 112 | text: list[str], 113 | padding_value = -1 114 | ) -> Int['b nt']: 115 | 116 | phonemes = [g2p(t) for t in text] 117 | list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes] 118 | padded_tensor = pad_sequence(list_tensors, padding_value = -1) 119 | return padded_tensor 120 | 121 | return encode, (num_phonemes + num_extended_chars) 122 | 123 | # tensor helpers 124 | 125 | def log(t, eps = 1e-5): 126 | return t.clamp(min = eps).log() 127 | 128 | def lens_to_mask( 129 | t: Int['b'], 130 | length: int | None = None 131 | ) -> Bool['b n']: 132 | 133 | if not exists(length): 134 | length = t.amax() 135 | 136 | seq = torch.arange(length, device = t.device) 137 | return einx.less('n, b -> b n', seq, t) 138 | 139 | def mask_from_start_end_indices( 140 | seq_len: Int['b'], 141 | start: Int['b'], 142 | end: Int['b'] 143 | ): 144 | max_seq_len = seq_len.max().item() 145 | seq = torch.arange(max_seq_len, device = start.device).long() 146 | return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end) 147 | 148 | def mask_from_frac_lengths( 149 | seq_len: Int['b'], 150 | frac_lengths: Float['b'], 151 | max_length: int | None = None 152 | ): 153 | lengths = (frac_lengths * seq_len).long() 154 | max_start = seq_len - lengths 155 | 156 | rand = torch.rand_like(frac_lengths) 157 | start = (max_start * rand).long().clamp(min = 0) 158 | end = start + lengths 159 | 160 | out = mask_from_start_end_indices(seq_len, start, end) 161 | 162 | if exists(max_length): 163 | out = pad_to_length(out, max_length) 164 | 165 | return out 166 | 167 | def maybe_masked_mean( 168 | t: Float['b n d'], 169 | mask: Bool['b n'] | None = None 170 | ) -> Float['b d']: 171 | 172 | if not exists(mask): 173 | return t.mean(dim = 1) 174 | 175 | t = einx.where('b n, b n d, -> b n d', mask, t, 0.) 
176 | num = reduce(t, 'b n d -> b d', 'sum') 177 | den = reduce(mask.float(), 'b n -> b', 'sum') 178 | 179 | return einx.divide('b d, b -> b d', num, den.clamp(min = 1.)) 180 | 181 | def pad_to_length( 182 | t: Tensor, 183 | length: int, 184 | value = None 185 | ): 186 | seq_len = t.shape[-1] 187 | if length > seq_len: 188 | t = F.pad(t, (0, length - seq_len), value = value) 189 | 190 | return t[..., :length] 191 | 192 | def interpolate_1d( 193 | x: Tensor, 194 | length: int, 195 | mode = 'bilinear' 196 | ): 197 | x = rearrange(x, 'n d -> 1 d n 1') 198 | x = F.interpolate(x, (length, 1), mode = mode) 199 | return rearrange(x, '1 d n 1 -> n d') 200 | 201 | # to mel spec 202 | 203 | class MelSpec(Module): 204 | def __init__( 205 | self, 206 | filter_length = 1024, 207 | hop_length = 256, 208 | win_length = 1024, 209 | n_mel_channels = 100, 210 | sampling_rate = 24_000, 211 | normalize = False, 212 | power = 1, 213 | norm = None, 214 | center = True, 215 | ): 216 | super().__init__() 217 | self.n_mel_channels = n_mel_channels 218 | self.sampling_rate = sampling_rate 219 | 220 | self.mel_stft = torchaudio.transforms.MelSpectrogram( 221 | sample_rate = sampling_rate, 222 | n_fft = filter_length, 223 | win_length = win_length, 224 | hop_length = hop_length, 225 | n_mels = n_mel_channels, 226 | power = power, 227 | center = center, 228 | normalized = normalize, 229 | norm = norm, 230 | ) 231 | 232 | self.register_buffer('dummy', tensor(0), persistent = False) 233 | 234 | def forward(self, inp): 235 | if len(inp.shape) == 3: 236 | inp = rearrange(inp, 'b 1 nw -> b nw') 237 | 238 | assert len(inp.shape) == 2 239 | 240 | if self.dummy.device != inp.device: 241 | self.to(inp.device) 242 | 243 | mel = self.mel_stft(inp) 244 | mel = log(mel) 245 | return mel 246 | 247 | # adaln zero from DiT paper 248 | 249 | class AdaLNZero(Module): 250 | def __init__( 251 | self, 252 | dim, 253 | dim_condition = None, 254 | init_bias_value = -2. 
255 | ): 256 | super().__init__() 257 | dim_condition = default(dim_condition, dim) 258 | self.to_gamma = nn.Linear(dim_condition, dim) 259 | 260 | nn.init.zeros_(self.to_gamma.weight) 261 | nn.init.constant_(self.to_gamma.bias, init_bias_value) 262 | 263 | def forward(self, x, *, condition): 264 | if condition.ndim == 2: 265 | condition = rearrange(condition, 'b d -> b 1 d') 266 | 267 | gamma = self.to_gamma(condition).sigmoid() 268 | return x * gamma 269 | 270 | # random projection fourier embedding 271 | 272 | class RandomFourierEmbed(Module): 273 | def __init__(self, dim): 274 | super().__init__() 275 | assert divisible_by(dim, 2) 276 | self.register_buffer('weights', torch.randn(dim // 2)) 277 | 278 | def forward(self, x): 279 | freqs = einx.multiply('i, j -> i j', x, self.weights) * 2 * torch.pi 280 | fourier_embed, _ = pack((x, freqs.sin(), freqs.cos()), 'b *') 281 | return fourier_embed 282 | 283 | # character embedding 284 | 285 | class CharacterEmbed(Module): 286 | def __init__( 287 | self, 288 | dim, 289 | num_embeds = 256, 290 | ): 291 | super().__init__() 292 | self.dim = dim 293 | self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token' 294 | 295 | def forward( 296 | self, 297 | text: Int['b nt'], 298 | max_seq_len: int, 299 | **kwargs 300 | ) -> Float['b n d']: 301 | 302 | text = text + 1 # shift all other token ids up by 1 and use 0 as filler token 303 | 304 | text = text[:, :max_seq_len] # just curtail if character tokens are more than the mel spec tokens, one of the edge cases the paper did not address 305 | text = pad_to_length(text, max_seq_len, value = 0) 306 | 307 | return self.embed(text) 308 | 309 | class InterpolatedCharacterEmbed(Module): 310 | def __init__( 311 | self, 312 | dim, 313 | num_embeds = 256, 314 | ): 315 | super().__init__() 316 | self.dim = dim 317 | self.embed = nn.Embedding(num_embeds, dim) 318 | 319 | self.abs_pos_mlp = Sequential( 320 | Rearrange('... -> ... 
1'), 321 | Linear(1, dim), 322 | nn.SiLU(), 323 | Linear(dim, dim) 324 | ) 325 | 326 | def forward( 327 | self, 328 | text: Int['b nt'], 329 | max_seq_len: int, 330 | mask: Bool['b n'] | None = None 331 | ) -> Float['b n d']: 332 | 333 | device = text.device 334 | seq = torch.arange(max_seq_len, device = device) 335 | 336 | mask = default(mask, (None,)) 337 | 338 | interp_embeds = [] 339 | interp_abs_positions = [] 340 | 341 | for one_text, one_mask in zip_longest(text, mask): 342 | 343 | valid_text = one_text >= 0 344 | one_text = one_text[valid_text] 345 | one_text_embed = self.embed(one_text) 346 | 347 | # save the absolute positions 348 | 349 | text_seq_len = one_text.shape[0] 350 | 351 | # determine audio sequence length from mask 352 | 353 | audio_seq_len = max_seq_len 354 | if exists(one_mask): 355 | audio_seq_len = one_mask.sum().long().item() 356 | 357 | # interpolate text embedding to audio embedding length 358 | 359 | interp_text_embed = interpolate_1d(one_text_embed, audio_seq_len) 360 | interp_abs_pos = torch.linspace(0, text_seq_len, audio_seq_len, device = device) 361 | 362 | interp_embeds.append(interp_text_embed) 363 | interp_abs_positions.append(interp_abs_pos) 364 | 365 | interp_embeds = pad_sequence(interp_embeds) 366 | interp_abs_positions = pad_sequence(interp_abs_positions) 367 | 368 | interp_embeds = F.pad(interp_embeds, (0, 0, 0, max_seq_len - interp_embeds.shape[-2])) 369 | interp_abs_positions = pad_to_length(interp_abs_positions, max_seq_len) 370 | 371 | # pass interp absolute positions through mlp for implicit positions 372 | 373 | interp_embeds = interp_embeds + self.abs_pos_mlp(interp_abs_positions) 374 | 375 | if exists(mask): 376 | interp_embeds = einx.where('b n, b n d, -> b n d', mask, interp_embeds, 0.) 377 | 378 | return interp_embeds 379 | 380 | # text audio cross conditioning in multistream setup 381 | 382 | class TextAudioCrossCondition(Module): 383 | def __init__( 384 | self, 385 | dim, 386 | dim_text, 387 | cond_audio_to_text = True, 388 | ): 389 | super().__init__() 390 | self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False) 391 | nn.init.zeros_(self.text_to_audio.weight) 392 | 393 | self.cond_audio_to_text = cond_audio_to_text 394 | 395 | if cond_audio_to_text: 396 | self.audio_to_text = nn.Linear(dim + dim_text, dim_text, bias = False) 397 | nn.init.zeros_(self.audio_to_text.weight) 398 | 399 | def forward( 400 | self, 401 | audio: Float['b n d'], 402 | text: Float['b n dt'] 403 | ): 404 | audio_text, _ = pack((audio, text), 'b n *') 405 | 406 | text_cond = self.text_to_audio(audio_text) 407 | audio_cond = self.audio_to_text(audio_text) if self.cond_audio_to_text else 0. 
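        # both projections are zero-initialized above, so the cross-stream conditioning starts out as a no-op and is learned during training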
408 | 409 | return audio + text_cond, text + audio_cond 410 | 411 | # attention and transformer backbone 412 | # for use in both e2tts as well as duration module 413 | 414 | class Transformer(Module): 415 | @beartype 416 | def __init__( 417 | self, 418 | *, 419 | dim, 420 | dim_text = None, # will default to half of audio dimension 421 | depth = 8, 422 | heads = 8, 423 | dim_head = 64, 424 | ff_mult = 4, 425 | text_depth = None, 426 | text_heads = None, 427 | text_dim_head = None, 428 | text_ff_mult = None, 429 | cond_on_time = True, 430 | abs_pos_emb = True, 431 | max_seq_len = 8192, 432 | dropout = 0.1, 433 | num_registers = 32, 434 | attn_kwargs: dict = dict( 435 | gate_value_heads = True, 436 | softclamp_logits = True, 437 | ), 438 | ff_kwargs: dict = dict(), 439 | ): 440 | super().__init__() 441 | assert divisible_by(depth, 2), 'depth needs to be even' 442 | 443 | # absolute positional embedding 444 | 445 | self.max_seq_len = max_seq_len 446 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if abs_pos_emb else None 447 | 448 | self.dim = dim 449 | 450 | dim_text = default(dim_text, dim // 2) 451 | self.dim_text = dim_text 452 | 453 | text_heads = default(text_heads, heads) 454 | text_dim_head = default(text_dim_head, dim_head) 455 | text_ff_mult = default(text_ff_mult, ff_mult) 456 | text_depth = default(text_depth, depth) 457 | 458 | assert 1 <= text_depth <= depth, 'must have at least 1 layer of text conditioning, but less than total number of speech layers' 459 | 460 | self.depth = depth 461 | self.layers = ModuleList([]) 462 | 463 | # registers 464 | 465 | self.num_registers = num_registers 466 | self.registers = nn.Parameter(torch.zeros(num_registers, dim)) 467 | nn.init.normal_(self.registers, std = 0.02) 468 | 469 | self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text)) 470 | nn.init.normal_(self.text_registers, std = 0.02) 471 | 472 | # rotary embedding 473 | 474 | self.rotary_emb = RotaryEmbedding(dim_head) 475 | self.text_rotary_emb = RotaryEmbedding(dim_head) 476 | 477 | # time conditioning 478 | # will use adaptive rmsnorm 479 | 480 | self.cond_on_time = cond_on_time 481 | rmsnorm_klass = RMSNorm if not cond_on_time else AdaptiveRMSNorm 482 | postbranch_klass = Identity if not cond_on_time else partial(AdaLNZero, dim = dim) 483 | 484 | self.time_cond_mlp = Identity() 485 | 486 | if cond_on_time: 487 | self.time_cond_mlp = Sequential( 488 | RandomFourierEmbed(dim), 489 | Linear(dim + 1, dim), 490 | nn.SiLU() 491 | ) 492 | 493 | for ind in range(depth): 494 | is_later_half = ind >= (depth // 2) 495 | has_text = ind < text_depth 496 | 497 | # speech related 498 | 499 | attn_norm = rmsnorm_klass(dim) 500 | attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs) 501 | attn_adaln_zero = postbranch_klass() 502 | 503 | ff_norm = rmsnorm_klass(dim) 504 | ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs) 505 | ff_adaln_zero = postbranch_klass() 506 | 507 | skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None 508 | 509 | speech_modules = ModuleList([ 510 | skip_proj, 511 | attn_norm, 512 | attn, 513 | attn_adaln_zero, 514 | ff_norm, 515 | ff, 516 | ff_adaln_zero, 517 | ]) 518 | 519 | text_modules = None 520 | 521 | if has_text: 522 | # text related 523 | 524 | text_attn_norm = RMSNorm(dim_text) 525 | text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, **attn_kwargs) 526 | 527 | text_ff_norm = RMSNorm(dim_text) 528 
| text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs) 529 | 530 | # cross condition 531 | 532 | is_last = ind == (text_depth - 1) 533 | 534 | cross_condition = TextAudioCrossCondition(dim = dim, dim_text = dim_text, cond_audio_to_text = not is_last) 535 | 536 | text_modules = ModuleList([ 537 | text_attn_norm, 538 | text_attn, 539 | text_ff_norm, 540 | text_ff, 541 | cross_condition 542 | ]) 543 | 544 | self.layers.append(ModuleList([ 545 | speech_modules, 546 | text_modules 547 | ])) 548 | 549 | self.final_norm = RMSNorm(dim) 550 | 551 | def forward( 552 | self, 553 | x: Float['b n d'], 554 | times: Float['b'] | Float[''] | None = None, 555 | mask: Bool['b n'] | None = None, 556 | text_embed: Float['b n dt'] | None = None, 557 | ): 558 | batch, seq_len, device = *x.shape[:2], x.device 559 | 560 | assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa' 561 | 562 | # handle absolute positions if needed 563 | 564 | if exists(self.abs_pos_emb): 565 | assert seq_len <= self.max_seq_len, f'{seq_len} exceeds the set `max_seq_len` ({self.max_seq_len}) on Transformer' 566 | seq = torch.arange(seq_len, device = device) 567 | x = x + self.abs_pos_emb(seq) 568 | 569 | # handle adaptive rmsnorm kwargs 570 | 571 | norm_kwargs = dict() 572 | 573 | if exists(times): 574 | if times.ndim == 0: 575 | times = repeat(times, ' -> b', b = batch) 576 | 577 | times = self.time_cond_mlp(times) 578 | norm_kwargs.update(condition = times) 579 | 580 | # register tokens 581 | 582 | registers = repeat(self.registers, 'r d -> b r d', b = batch) 583 | x, registers_packed_shape = pack((registers, x), 'b * d') 584 | 585 | if exists(mask): 586 | mask = F.pad(mask, (self.num_registers, 0), value = True) 587 | 588 | # rotary embedding 589 | 590 | rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2]) 591 | 592 | # text related 593 | 594 | if exists(text_embed): 595 | text_rotary_pos_emb = self.text_rotary_emb.forward_from_seq_len(x.shape[-2]) 596 | 597 | text_registers = repeat(self.text_registers, 'r d -> b r d', b = batch) 598 | text_embed, _ = pack((text_registers, text_embed), 'b * d') 599 | 600 | # skip connection related stuff 601 | 602 | skips = [] 603 | 604 | # go through the layers 605 | 606 | for ind, (speech_modules, text_modules) in enumerate(self.layers): 607 | layer = ind + 1 608 | 609 | ( 610 | maybe_skip_proj, 611 | attn_norm, 612 | attn, 613 | maybe_attn_adaln_zero, 614 | ff_norm, 615 | ff, 616 | maybe_ff_adaln_zero 617 | ) = speech_modules 618 | 619 | # smaller text transformer 620 | 621 | if exists(text_embed) and exists(text_modules): 622 | 623 | ( 624 | text_attn_norm, 625 | text_attn, 626 | text_ff_norm, 627 | text_ff, 628 | cross_condition 629 | ) = text_modules 630 | 631 | text_embed = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask) + text_embed 632 | 633 | text_embed = text_ff(text_ff_norm(text_embed)) + text_embed 634 | 635 | x, text_embed = cross_condition(x, text_embed) 636 | 637 | # skip connection logic 638 | 639 | is_first_half = layer <= (self.depth // 2) 640 | is_later_half = not is_first_half 641 | 642 | if is_first_half: 643 | skips.append(x) 644 | 645 | if is_later_half: 646 | skip = skips.pop() 647 | x = torch.cat((x, skip), dim = -1) 648 | x = maybe_skip_proj(x) 649 | 650 | # attention and feedforward blocks 651 | 652 | attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask) 653 | 
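            # when `cond_on_time` is set, AdaLNZero gates the branch output with a sigmoid of the time condition (DiT style) before the residual add; otherwise it is a plain Identity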
654 | x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs) 655 | 656 | ff_out = ff(ff_norm(x, **norm_kwargs)) 657 | 658 | x = x + maybe_ff_adaln_zero(ff_out, **norm_kwargs) 659 | 660 | assert len(skips) == 0 661 | 662 | _, x = unpack(x, registers_packed_shape, 'b * d') 663 | 664 | return self.final_norm(x) 665 | 666 | # main classes 667 | 668 | class DurationPredictor(Module): 669 | @beartype 670 | def __init__( 671 | self, 672 | transformer: dict | Transformer, 673 | num_channels = None, 674 | mel_spec_kwargs: dict = dict(), 675 | char_embed_kwargs: dict = dict(), 676 | text_num_embeds = None, 677 | tokenizer: ( 678 | Literal['char_utf8', 'phoneme_en'] | 679 | Callable[[list[str]], Int['b nt']] 680 | ) = 'char_utf8' 681 | ): 682 | super().__init__() 683 | 684 | if isinstance(transformer, dict): 685 | transformer = Transformer( 686 | **transformer, 687 | cond_on_time = False 688 | ) 689 | 690 | # mel spec 691 | 692 | self.mel_spec = MelSpec(**mel_spec_kwargs) 693 | self.num_channels = default(num_channels, self.mel_spec.n_mel_channels) 694 | 695 | self.transformer = transformer 696 | 697 | dim = transformer.dim 698 | dim_text = transformer.dim_text 699 | 700 | self.dim = dim 701 | 702 | self.proj_in = Linear(self.num_channels, self.dim) 703 | 704 | # tokenizer and text embed 705 | 706 | if callable(tokenizer): 707 | assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function' 708 | self.tokenizer = tokenizer 709 | elif tokenizer == 'char_utf8': 710 | text_num_embeds = 256 711 | self.tokenizer = list_str_to_tensor 712 | elif tokenizer == 'phoneme_en': 713 | self.tokenizer, text_num_embeds = get_g2p_en_encode() 714 | else: 715 | raise ValueError(f'unknown tokenizer string {tokenizer}') 716 | 717 | self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs) 718 | 719 | # to prediction 720 | 721 | self.to_pred = Sequential( 722 | Linear(dim, 1, bias = False), 723 | nn.Softplus(), 724 | Rearrange('... 
1 -> ...') 725 | ) 726 | 727 | def forward( 728 | self, 729 | x: Float['b n d'] | Float['b nw'], 730 | *, 731 | text: Int['b nt'] | list[str] | None = None, 732 | lens: Int['b'] | None = None, 733 | return_loss = True 734 | ): 735 | # raw wave 736 | 737 | if x.ndim == 2: 738 | x = self.mel_spec(x) 739 | x = rearrange(x, 'b d n -> b n d') 740 | assert x.shape[-1] == self.dim 741 | 742 | x = self.proj_in(x) 743 | 744 | batch, seq_len, device = *x.shape[:2], x.device 745 | 746 | # text 747 | 748 | text_embed = None 749 | 750 | if exists(text): 751 | if isinstance(text, list): 752 | text = list_str_to_tensor(text).to(device) 753 | assert text.shape[0] == batch 754 | 755 | text_embed = self.embed_text(text, seq_len) 756 | 757 | # handle lengths (duration) 758 | 759 | if not exists(lens): 760 | lens = torch.full((batch,), seq_len, device = device) 761 | 762 | mask = lens_to_mask(lens, length = seq_len) 763 | 764 | # if returning a loss, mask out randomly from an index and have it predict the duration 765 | 766 | if return_loss: 767 | rand_frac_index = x.new_zeros(batch).uniform_(0, 1) 768 | rand_index = (rand_frac_index * lens).long() 769 | 770 | seq = torch.arange(seq_len, device = device) 771 | mask &= einx.less('n, b -> b n', seq, rand_index) 772 | 773 | # attending 774 | 775 | x = self.transformer( 776 | x, 777 | mask = mask, 778 | text_embed = text_embed, 779 | ) 780 | 781 | x = maybe_masked_mean(x, mask) 782 | 783 | pred = self.to_pred(x) 784 | 785 | # return the prediction if not returning loss 786 | 787 | if not return_loss: 788 | return pred 789 | 790 | # loss 791 | 792 | return F.mse_loss(pred, lens.float()) 793 | 794 | class E2TTS(Module): 795 | 796 | @beartype 797 | def __init__( 798 | self, 799 | transformer: dict | Transformer = None, 800 | duration_predictor: dict | DurationPredictor | None = None, 801 | odeint_kwargs: dict = dict( 802 | atol = 1e-5, 803 | rtol = 1e-5, 804 | method = 'midpoint' 805 | ), 806 | cond_drop_prob = 0.25, 807 | num_channels = None, 808 | mel_spec_module: Module | None = None, 809 | char_embed_kwargs: dict = dict(), 810 | mel_spec_kwargs: dict = dict(), 811 | frac_lengths_mask: tuple[float, float] = (0.7, 1.), 812 | concat_cond = False, 813 | interpolated_text = False, 814 | text_num_embeds: int | None = None, 815 | tokenizer: ( 816 | Literal['char_utf8', 'phoneme_en'] | 817 | Callable[[list[str]], Int['b nt']] 818 | ) = 'char_utf8', 819 | use_vocos = True, 820 | pretrained_vocos_path = 'charactr/vocos-mel-24khz', 821 | sampling_rate: int | None = None 822 | ): 823 | super().__init__() 824 | 825 | if isinstance(transformer, dict): 826 | transformer = Transformer( 827 | **transformer, 828 | cond_on_time = True 829 | ) 830 | 831 | if isinstance(duration_predictor, dict): 832 | duration_predictor = DurationPredictor(**duration_predictor) 833 | 834 | self.transformer = transformer 835 | 836 | dim = transformer.dim 837 | dim_text = transformer.dim_text 838 | 839 | self.dim = dim 840 | self.dim_text = dim_text 841 | 842 | self.frac_lengths_mask = frac_lengths_mask 843 | 844 | self.duration_predictor = duration_predictor 845 | 846 | # sampling 847 | 848 | self.odeint_kwargs = odeint_kwargs 849 | 850 | # mel spec 851 | 852 | self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) 853 | num_channels = default(num_channels, self.mel_spec.n_mel_channels) 854 | 855 | self.num_channels = num_channels 856 | self.sampling_rate = default(sampling_rate, getattr(self.mel_spec, 'sampling_rate', None)) 857 | 858 | # whether to concat condition and project rather 
than project both and sum 859 | 860 | self.concat_cond = concat_cond 861 | 862 | if concat_cond: 863 | self.proj_in = nn.Linear(num_channels * 2, dim) 864 | else: 865 | self.proj_in = nn.Linear(num_channels, dim) 866 | self.cond_proj_in = nn.Linear(num_channels, dim) 867 | 868 | # to prediction 869 | 870 | self.to_pred = Linear(dim, num_channels) 871 | 872 | # tokenizer and text embed 873 | 874 | if callable(tokenizer): 875 | assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function' 876 | self.tokenizer = tokenizer 877 | elif tokenizer == 'char_utf8': 878 | text_num_embeds = 256 879 | self.tokenizer = list_str_to_tensor 880 | elif tokenizer == 'phoneme_en': 881 | self.tokenizer, text_num_embeds = get_g2p_en_encode() 882 | else: 883 | raise ValueError(f'unknown tokenizer string {tokenizer}') 884 | 885 | self.cond_drop_prob = cond_drop_prob 886 | 887 | # text embedding 888 | 889 | text_embed_klass = CharacterEmbed if not interpolated_text else InterpolatedCharacterEmbed 890 | 891 | self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs) 892 | 893 | # default vocos for mel -> audio 894 | 895 | self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None 896 | 897 | @property 898 | def device(self): 899 | return next(self.parameters()).device 900 | 901 | def transformer_with_pred_head( 902 | self, 903 | x: Float['b n d'], 904 | cond: Float['b n d'], 905 | times: Float['b'], 906 | mask: Bool['b n'] | None = None, 907 | text: Int['b nt'] | None = None, 908 | drop_text_cond: bool | None = None 909 | ): 910 | seq_len = x.shape[-2] 911 | drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob) 912 | 913 | if self.concat_cond: 914 | # concat condition, given as using voicebox-like scheme 915 | x = torch.cat((cond, x), dim = -1) 916 | 917 | x = self.proj_in(x) 918 | 919 | if not self.concat_cond: 920 | # an alternative is to simply sum the condition 921 | # seems to work fine 922 | 923 | cond = self.cond_proj_in(cond) 924 | x = x + cond 925 | 926 | # whether to use a text embedding 927 | 928 | text_embed = None 929 | if exists(text) and not drop_text_cond: 930 | text_embed = self.embed_text(text, seq_len, mask = mask) 931 | 932 | # attend 933 | 934 | attended = self.transformer( 935 | x, 936 | times = times, 937 | mask = mask, 938 | text_embed = text_embed 939 | ) 940 | 941 | return self.to_pred(attended) 942 | 943 | def cfg_transformer_with_pred_head( 944 | self, 945 | *args, 946 | cfg_strength: float = 1., 947 | **kwargs, 948 | ): 949 | 950 | pred = self.transformer_with_pred_head(*args, drop_text_cond = False, **kwargs) 951 | 952 | if cfg_strength < 1e-5: 953 | return pred 954 | 955 | null_pred = self.transformer_with_pred_head(*args, drop_text_cond = True, **kwargs) 956 | 957 | return pred + (pred - null_pred) * cfg_strength 958 | 959 | @torch.no_grad() 960 | def sample( 961 | self, 962 | cond: Float['b n d'] | Float['b nw'], 963 | *, 964 | text: Int['b nt'] | list[str] | None = None, 965 | lens: Int['b'] | None = None, 966 | duration: int | Int['b'] | None = None, 967 | steps = 32, 968 | cfg_strength = 1., # they used a classifier free guidance strength of 1. 
969 | max_duration = 4096, # in case the duration predictor goes haywire 970 | vocoder: Callable[[Float['b d n']], list[Float['_']]] | None = None, 971 | return_raw_output: bool | None = None, 972 | save_to_filename: str | None = None 973 | ) -> ( 974 | Float['b n d'], 975 | list[Float['_']] 976 | ): 977 | self.eval() 978 | 979 | # raw wave 980 | 981 | if cond.ndim == 2: 982 | cond = self.mel_spec(cond) 983 | cond = rearrange(cond, 'b d n -> b n d') 984 | assert cond.shape[-1] == self.num_channels 985 | 986 | batch, cond_seq_len, device = *cond.shape[:2], cond.device 987 | 988 | if not exists(lens): 989 | lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long) 990 | 991 | # text 992 | 993 | if isinstance(text, list): 994 | text = self.tokenizer(text).to(device) 995 | assert text.shape[0] == batch 996 | 997 | if exists(text): 998 | text_lens = (text != -1).sum(dim = -1) 999 | lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters 1000 | 1001 | # duration 1002 | 1003 | cond_mask = lens_to_mask(lens) 1004 | 1005 | if exists(duration): 1006 | if isinstance(duration, int): 1007 | duration = torch.full((batch,), duration, device = device, dtype = torch.long) 1008 | 1009 | elif exists(self.duration_predictor): 1010 | duration = self.duration_predictor(cond, text = text, lens = lens, return_loss = False).long() 1011 | 1012 | duration = torch.maximum(lens + 1, duration) # just add one token so something is generated 1013 | duration = duration.clamp(max = max_duration) 1014 | 1015 | assert duration.shape[0] == batch 1016 | 1017 | max_duration = duration.amax() 1018 | 1019 | cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.) 1020 | cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False) 1021 | cond_mask = rearrange(cond_mask, '... -> ... 
1') 1022 | 1023 | mask = lens_to_mask(duration) 1024 | 1025 | # neural ode 1026 | 1027 | def fn(t, x): 1028 | # at each step, conditioning is fixed 1029 | 1030 | step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) 1031 | 1032 | # predict flow 1033 | 1034 | return self.cfg_transformer_with_pred_head( 1035 | x, 1036 | step_cond, 1037 | times = t, 1038 | text = text, 1039 | mask = mask, 1040 | cfg_strength = cfg_strength 1041 | ) 1042 | 1043 | y0 = torch.randn_like(cond) 1044 | t = torch.linspace(0, 1, steps, device = self.device) 1045 | 1046 | trajectory = odeint(fn, y0, t, **self.odeint_kwargs) 1047 | sampled = trajectory[-1] 1048 | 1049 | out = sampled 1050 | 1051 | out = torch.where(cond_mask, cond, out) 1052 | 1053 | # able to return raw untransformed output, if not using mel rep 1054 | 1055 | if exists(return_raw_output) and return_raw_output: 1056 | return out 1057 | 1058 | # take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on 1059 | 1060 | if exists(vocoder): 1061 | assert not exists(self.vocos), '`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling' 1062 | out = rearrange(out, 'b n d -> b d n') 1063 | out = vocoder(out) 1064 | 1065 | elif exists(self.vocos): 1066 | 1067 | audio = [] 1068 | for mel, one_mask in zip(out, mask): 1069 | one_out = DB_to_amplitude(mel[one_mask], ref = 1., power = 0.5) 1070 | 1071 | one_out = rearrange(one_out, 'n d -> 1 d n') 1072 | one_audio = self.vocos.decode(one_out) 1073 | one_audio = rearrange(one_audio, '1 nw -> nw') 1074 | audio.append(one_audio) 1075 | 1076 | out = audio 1077 | 1078 | if exists(save_to_filename): 1079 | assert exists(vocoder) or exists(self.vocos) 1080 | assert exists(self.sampling_rate) 1081 | 1082 | path = Path(save_to_filename) 1083 | parent_path = path.parents[0] 1084 | parent_path.mkdir(exist_ok = True, parents = True) 1085 | 1086 | for ind, one_audio in enumerate(out): 1087 | one_audio = rearrange(one_audio, 'nw -> 1 nw') 1088 | save_path = str(parent_path / f'{ind + 1}.{path.name}') 1089 | torchaudio.save(save_path, one_audio.detach().cpu(), sample_rate = self.sampling_rate) 1090 | 1091 | return out 1092 | 1093 | def forward( 1094 | self, 1095 | inp: Float['b n d'] | Float['b nw'], # mel or raw wave 1096 | *, 1097 | text: Int['b nt'] | list[str] | None = None, 1098 | times: Int['b'] | None = None, 1099 | lens: Int['b'] | None = None, 1100 | ): 1101 | # handle raw wave 1102 | 1103 | if inp.ndim == 2: 1104 | inp = self.mel_spec(inp) 1105 | inp = rearrange(inp, 'b d n -> b n d') 1106 | assert inp.shape[-1] == self.num_channels 1107 | 1108 | batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device 1109 | 1110 | # handle text as string 1111 | 1112 | if isinstance(text, list): 1113 | text = self.tokenizer(text).to(device) 1114 | assert text.shape[0] == batch 1115 | 1116 | # lens and mask 1117 | 1118 | if not exists(lens): 1119 | lens = torch.full((batch,), seq_len, device = device) 1120 | 1121 | mask = lens_to_mask(lens, length = seq_len) 1122 | 1123 | # get a random span to mask out for training conditionally 1124 | 1125 | frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask) 1126 | rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length = seq_len) 1127 | 1128 | if exists(mask): 1129 | rand_span_mask &= mask 1130 | 1131 | # mel is x1 1132 | 1133 | x1 = inp 1134 | 1135 | # main conditional flow training logic 1136 | # just ~5 loc 1137 | 1138 | # x0 is gaussian 
noise 1139 | 1140 | x0 = torch.randn_like(x1) 1141 | 1142 | # t is random times from above 1143 | 1144 | times = torch.rand((batch,), dtype = dtype, device = self.device) 1145 | t = rearrange(times, 'b -> b 1 1') 1146 | 1147 | # sample xt (w in the paper) 1148 | 1149 | w = (1. - t) * x0 + t * x1 1150 | 1151 | flow = x1 - x0 1152 | 1153 | # only predict what is within the random mask span for infilling 1154 | 1155 | cond = einx.where( 1156 | 'b n, b n d, b n d -> b n d', 1157 | rand_span_mask, 1158 | torch.zeros_like(x1), x1 1159 | ) 1160 | 1161 | # transformer and prediction head 1162 | 1163 | pred = self.transformer_with_pred_head(w, cond, times = times, text = text, mask = mask) 1164 | 1165 | # flow matching loss 1166 | 1167 | loss = F.mse_loss(pred, flow, reduction = 'none') 1168 | 1169 | loss = loss[rand_span_mask].mean() 1170 | 1171 | return E2TTSReturn(loss, cond, pred) 1172 | --------------------------------------------------------------------------------
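As a closing note, here is a minimal inference sketch built on the `sample` method above. It assumes the default `use_vocos = True` (mels are decoded to waveforms with the pretrained Vocos checkpoint) and the default 24 kHz mel configuration; the weight-loading line and the output filename are placeholders.

```python
import torch

from e2_tts_pytorch import E2TTS

e2tts = E2TTS(
    transformer = dict(dim = 512, depth = 8),
)

# placeholder: load your own trained weights here, e.g. from an E2Trainer checkpoint
# e2tts.load_state_dict(torch.load('e2tts.pt')['model_state_dict'])

cond = torch.randn(1, 256, 100)    # reference mel: (batch, frames, mel channels)
text = ['The quick brown fox jumps over the lazy dog.']

audio = e2tts.sample(
    cond,
    text = text,
    duration = 1024,                  # target length in mel frames, or leave unset and attach a duration predictor
    steps = 32,                       # number of ODE solver steps
    cfg_strength = 1.,                # classifier free guidance strength used in the paper
    save_to_filename = 'sample.wav'   # decoded with vocos and written out via torchaudio
)
```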