├── e2-tts.png
├── e2_tts_pytorch
│   ├── __init__.py
│   ├── trainer.py
│   └── e2_tts.py
├── train_example.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── pyproject.toml
├── README.md
└── .gitignore
/e2-tts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wetdog/e2-tts-pytorch/main/e2-tts.png
--------------------------------------------------------------------------------
/e2_tts_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from e2_tts_pytorch.e2_tts import (
2 | Transformer,
3 | DurationPredictor,
4 | E2TTS,
5 | )
6 |
7 | from e2_tts_pytorch.trainer import (
8 | E2Trainer
9 | )
--------------------------------------------------------------------------------
/train_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from e2_tts_pytorch import E2TTS, DurationPredictor
3 |
4 | from torch.optim import Adam
5 | from datasets import load_dataset
6 |
7 | from e2_tts_pytorch.trainer import (
8 | HFDataset,
9 | E2Trainer
10 | )
11 |
12 | duration_predictor = DurationPredictor(
13 | transformer = dict(
14 | dim = 512,
15 | depth = 6,
16 | )
17 | )
18 |
19 | e2tts = E2TTS(
20 | duration_predictor = duration_predictor,
21 | transformer = dict(
22 | dim = 512,
23 | depth = 12,
24 | # note: this Transformer always uses concat skip connections, so no skip_connect_type argument is needed
25 | ),
26 | )
27 |
28 | train_dataset = HFDataset(load_dataset("MushanW/GLOBE")["train"])
29 |
30 | optimizer = Adam(e2tts.parameters(), lr=7.5e-5)
31 |
32 | trainer = E2Trainer(
33 | e2tts,
34 | optimizer,
35 | num_warmup_steps=20000,
36 | grad_accumulation_steps = 1,
37 | checkpoint_path = 'e2tts.pt',
38 | log_file = 'e2tts.txt'
39 | )
40 |
41 | epochs = 10
42 | batch_size = 32
43 |
44 | trainer.train(train_dataset, epochs, batch_size, save_step=1000)
45 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "e2-tts-pytorch"
3 | version = "1.0.5"
4 | description = "E2-TTS in Pytorch"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.8"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'transformers',
15 | 'attention mechanism',
16 | 'text to speech'
17 | ]
18 | classifiers=[
19 | 'Development Status :: 4 - Beta',
20 | 'Intended Audience :: Developers',
21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
22 | 'License :: OSI Approved :: MIT License',
23 | 'Programming Language :: Python :: 3.8',
24 | ]
25 |
26 | dependencies = [
27 | 'accelerate>=0.33.0',
28 | 'beartype',
29 | 'einops>=0.8.0',
30 | 'einx>=0.3.0',
31 | 'ema-pytorch>=0.5.2',
32 | 'g2p-en',
33 | 'jaxtyping',
34 | 'loguru',
35 | 'tensorboard',
36 | 'torch>=2.0',
37 | 'torchdiffeq',
38 | 'torchaudio>=2.3.1',
39 | 'tqdm>=4.65.0',
40 | 'vocos',
41 | 'x-transformers>=1.31.14',
42 | ]
43 |
44 | [project.urls]
45 | Homepage = "https://pypi.org/project/e2-tts-pytorch/"
46 | Repository = "https://github.com/lucidrains/e2-tts-pytorch"
47 |
48 | [project.optional-dependencies]
49 | examples = ["datasets"]
50 |
51 | [build-system]
52 | requires = ["hatchling"]
53 | build-backend = "hatchling.build"
54 |
55 | [tool.hatch.metadata]
56 | allow-direct-references = true
57 |
58 | [tool.hatch.build.targets.wheel]
59 | packages = ["e2_tts_pytorch"]
60 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | ## E2 TTS - Pytorch
5 |
6 | Implementation of E2-TTS, Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS, in Pytorch
7 |
8 | The repository differs from the paper in that it uses a multistream transformer for text and audio, with conditioning done at every transformer block in the E2 manner.
9 |
10 | It also includes an improvisation that was proven out by Manmay, where the text is simply interpolated to the length of the audio for conditioning. You can try this by setting `interpolated_text = True` on `E2TTS`, as in the sketch below.
11 |
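A minimal sketch of both of these options together (the constructor values here are illustrative, not taken from the paper; see the usage section below for a full example):

```python
from e2_tts_pytorch import E2TTS

e2tts = E2TTS(
    transformer = dict(
        dim = 512,        # audio stream width
        depth = 12,       # must be even, for the skip connections
        dim_text = 256,   # text stream width, defaults to half the audio dimension
        text_depth = 6    # only the first text_depth blocks carry the text stream
    ),
    interpolated_text = True  # interpolate text embeddings to the audio length
)
```
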
12 | ## Appreciation
13 |
14 | - Manmay for contributing working end-to-end training code!
15 |
16 | - Lucas Newman for the code contributions, helpful feedback, and for sharing the first set of positive experiments!
17 |
18 | - Jing for sharing the second positive result with a multilingual (English + Chinese) dataset!
19 |
20 | - Coice and Manmay for reporting the third and fourth successful runs. Farewell, alignment engineering.
21 |
22 | ## Install
23 |
24 | ```bash
25 | $ pip install e2-tts-pytorch
26 | ```
27 |
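The `train_example.py` script additionally needs Hugging Face `datasets`, which is declared as the `examples` extra in `pyproject.toml`:

```bash
$ pip install 'e2-tts-pytorch[examples]'
```
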
28 | ## Usage
29 |
30 | ```python
31 | import torch
32 |
33 | from e2_tts_pytorch import (
34 | E2TTS,
35 | DurationPredictor
36 | )
37 |
38 | duration_predictor = DurationPredictor(
39 | transformer = dict(
40 | dim = 512,
41 | depth = 8,
42 | )
43 | )
44 |
45 | mel = torch.randn(2, 1024, 100)
46 | text = ['Hello', 'Goodbye']
47 |
48 | loss = duration_predictor(mel, text = text)
49 | loss.backward()
50 |
51 | e2tts = E2TTS(
52 | duration_predictor = duration_predictor,
53 | transformer = dict(
54 | dim = 512,
55 | depth = 8
56 | ),
57 | )
58 |
59 | out = e2tts(mel, text = text)
60 | out.loss.backward()
61 |
62 | sampled = e2tts.sample(mel[:, :5], text = text)
63 |
64 | ```
65 |
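`E2TTS.sample` also exposes the number of ODE steps, the classifier free guidance strength, and an optional output path that decodes the sampled mel with the bundled Vocos vocoder. A minimal sketch, reusing `e2tts`, `mel`, and `text` from above (the filename is just an example):

```python
audio = e2tts.sample(
    mel[:, :5],
    text = text,
    steps = 32,                      # ODE solver steps
    cfg_strength = 1.,               # classifier free guidance strength, as in the paper
    save_to_filename = 'sample.wav'  # writes one decoded wav per batch element
)
```
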
66 | ## Citations
67 |
68 | ```bibtex
69 | @inproceedings{Eskimez2024E2TE,
70 | title = {E2 TTS: Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS},
71 | author = {Sefik Emre Eskimez and Xiaofei Wang and Manthan Thakker and Canrun Li and Chung-Hsien Tsai and Zhen Xiao and Hemin Yang and Zirun Zhu and Min Tang and Xu Tan and Yanqing Liu and Sheng Zhao and Naoyuki Kanda},
72 | year = {2024},
73 | url = {https://api.semanticscholar.org/CorpusID:270738197}
74 | }
75 | ```
76 |
77 | ```bibtex
78 | @inproceedings{Darcet2023VisionTN,
79 | title = {Vision Transformers Need Registers},
80 | author = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
81 | year = {2023},
82 | url = {https://api.semanticscholar.org/CorpusID:263134283}
83 | }
84 | ```
85 |
86 | ```bibtex
87 | @article{Bao2022AllAW,
88 | title = {All are Worth Words: A ViT Backbone for Diffusion Models},
89 | author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu},
90 | journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
91 | year = {2022},
92 | pages = {22669-22679},
93 | url = {https://api.semanticscholar.org/CorpusID:253581703}
94 | }
95 | ```
96 |
97 | ```bibtex
98 | @article{Burtsev2021MultiStreamT,
99 | title = {Multi-Stream Transformers},
100 | author = {Mikhail S. Burtsev and Anna Rumshisky},
101 | journal = {ArXiv},
102 | year = {2021},
103 | volume = {abs/2107.10342},
104 | url = {https://api.semanticscholar.org/CorpusID:236171087}
105 | }
106 | ```
107 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/e2_tts_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from tqdm import tqdm
5 | import matplotlib
6 | matplotlib.use("Agg")
7 | import matplotlib.pylab as plt
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.data import DataLoader, Dataset
12 | from torch.utils.tensorboard import SummaryWriter
13 | from torch.optim.lr_scheduler import LinearLR, SequentialLR
14 |
15 | import torchaudio
16 |
17 | from einops import rearrange
18 |
19 | from accelerate import Accelerator
20 | from accelerate.utils import DistributedDataParallelKwargs
21 |
22 | from ema_pytorch import EMA
23 |
24 | from loguru import logger
25 |
26 | from e2_tts_pytorch.e2_tts import (
27 | E2TTS,
28 | DurationPredictor,
29 | MelSpec
30 | )
31 |
32 | def exists(v):
33 | return v is not None
34 |
35 | def default(v, d):
36 | return v if exists(v) else d
37 |
38 | # plot spectrogram
39 | def plot_spectrogram(spectrogram):
40 | fig, ax = plt.subplots(figsize=(10, 4))
41 | im = ax.imshow(spectrogram.T, aspect="auto", origin="lower", interpolation="none")
42 | plt.colorbar(im, ax=ax)
43 | plt.xlabel("Frames")
44 | plt.ylabel("Channels")
45 | plt.tight_layout()
46 |
47 | fig.canvas.draw()
48 | plt.close()
49 | return fig
50 |
51 | # collation
52 |
53 | def collate_fn(batch):
54 | mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
55 | mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
56 | max_mel_length = mel_lengths.amax()
57 |
58 | padded_mel_specs = []
59 | for spec in mel_specs:
60 | padding = (0, max_mel_length - spec.size(-1))
61 | padded_spec = F.pad(spec, padding, value = 0)
62 | padded_mel_specs.append(padded_spec)
63 |
64 | mel_specs = torch.stack(padded_mel_specs)
65 |
66 | text = [item['text'] for item in batch]
67 | text_lengths = torch.LongTensor([len(item) for item in text])
68 |
69 | return dict(
70 | mel = mel_specs,
71 | mel_lengths = mel_lengths,
72 | text = text,
73 | text_lengths = text_lengths,
74 | )
75 |
76 | # dataset
77 |
78 | class HFDataset(Dataset):
79 | def __init__(
80 | self,
81 | hf_dataset: Dataset,
82 | target_sample_rate = 24_000,
83 | hop_length = 256
84 | ):
85 | self.data = hf_dataset
86 | self.target_sample_rate = target_sample_rate
87 | self.hop_length = hop_length
88 | self.mel_spectrogram = MelSpec(sampling_rate=target_sample_rate)
89 |
90 | def __len__(self):
91 | return len(self.data)
92 |
93 | def __getitem__(self, index):
94 | row = self.data[index]
95 | audio = row['audio']['array']
96 |
97 | logger.info(f"Audio shape: {audio.shape}")
98 |
99 | sample_rate = row['audio']['sampling_rate']
100 | duration = audio.shape[-1] / sample_rate
101 |
102 | if duration > 20 or duration < 0.3:
103 | logger.warning(f"Skipping due to duration out of bound: {duration}")
104 | return self.__getitem__((index + 1) % len(self.data))
105 |
106 | audio_tensor = torch.from_numpy(audio).float()
107 |
108 | if sample_rate != self.target_sample_rate:
109 | resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
110 | audio_tensor = resampler(audio_tensor)
111 |
112 | audio_tensor = rearrange(audio_tensor, 't -> 1 t')
113 |
114 | mel_spec = self.mel_spectrogram(audio_tensor)
115 |
116 | mel_spec = rearrange(mel_spec, '1 d t -> d t')
117 |
118 | text = row['transcript']
119 |
120 | return dict(
121 | mel_spec = mel_spec,
122 | text = text,
123 | )
124 |
125 | # trainer
126 |
127 | class E2Trainer:
128 | def __init__(
129 | self,
130 | model: E2TTS,
131 | optimizer,
132 | num_warmup_steps=20000,
133 | grad_accumulation_steps=1,
134 | duration_predictor: DurationPredictor | None = None,
135 | checkpoint_path = None,
136 | log_file = "logs.txt",
137 | max_grad_norm = 1.0,
138 | sample_rate = 22050,
139 | tensorboard_log_dir = 'runs/e2_tts_experiment',
140 | accelerate_kwargs: dict = dict(),
141 | ema_kwargs: dict = dict()
142 | ):
143 | logger.add(log_file)
144 |
145 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
146 |
147 | self.accelerator = Accelerator(
148 | log_with = "all",
149 | kwargs_handlers = [ddp_kwargs],
150 | gradient_accumulation_steps = grad_accumulation_steps,
151 | **accelerate_kwargs
152 | )
153 |
154 | self.target_sample_rate = sample_rate
155 |
156 | self.model = model
157 |
158 | if self.is_main:
159 | self.ema_model = EMA(
160 | model,
161 | include_online_model = False,
162 | **ema_kwargs
163 | )
164 |
165 | self.ema_model.to(self.accelerator.device)
166 |
167 | self.duration_predictor = duration_predictor
168 | self.optimizer = optimizer
169 | self.num_warmup_steps = num_warmup_steps
170 | self.checkpoint_path = default(checkpoint_path, 'model.pth')
171 | self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate)
172 |
173 | self.model, self.optimizer = self.accelerator.prepare(
174 | self.model, self.optimizer
175 | )
176 | self.max_grad_norm = max_grad_norm
177 |
178 | self.writer = SummaryWriter(log_dir=tensorboard_log_dir)
179 |
180 | @property
181 | def is_main(self):
182 | return self.accelerator.is_main_process
183 |
184 | def save_checkpoint(self, step, finetune=False):
185 | self.accelerator.wait_for_everyone()
186 | if self.is_main:
187 | checkpoint = dict(
188 | model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
189 | optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
190 | ema_model_state_dict = self.ema_model.state_dict(),
191 | scheduler_state_dict = self.scheduler.state_dict(),
192 | step = step
193 | )
194 |
195 | self.accelerator.save(checkpoint, self.checkpoint_path)
196 |
197 | def load_checkpoint(self):
198 | if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path):
199 | return 0
200 |
201 | checkpoint = torch.load(self.checkpoint_path)
202 | self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
203 | self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
204 |
205 | if self.is_main:
206 | self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
207 |
208 | if self.scheduler:
209 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
210 | return checkpoint['step']
211 |
212 | def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=1000):
213 |
214 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers, pin_memory=True)
215 | total_steps = len(train_dataloader) * epochs
216 | decay_steps = total_steps - self.num_warmup_steps
217 | warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=self.num_warmup_steps)
218 | decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
219 | self.scheduler = SequentialLR(self.optimizer,
220 | schedulers=[warmup_scheduler, decay_scheduler],
221 | milestones=[self.num_warmup_steps])
222 | train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler)
223 | start_step = self.load_checkpoint()
224 | global_step = start_step
225 |
226 | for epoch in range(epochs):
227 | self.model.train()
228 | progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
229 | epoch_loss = 0.0
230 |
231 | for batch in progress_bar:
232 | with self.accelerator.accumulate(self.model):
233 | text_inputs = batch['text']
234 | mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
235 | mel_lengths = batch["mel_lengths"]
236 |
237 | if self.duration_predictor is not None:
238 | dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
239 | self.writer.add_scalar('duration loss', dur_loss.item(), global_step)
240 |
241 | loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
242 | self.accelerator.backward(loss)
243 |
244 | if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
245 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
246 |
247 | self.optimizer.step()
248 | self.scheduler.step()
249 | self.optimizer.zero_grad()
250 |
251 | if self.is_main:
252 | self.ema_model.update()
253 |
254 | if self.accelerator.is_local_main_process:
255 | logger.info(f"step {global_step+1}: loss = {loss.item():.4f}")
256 | self.writer.add_scalar('loss', loss.item(), global_step)
257 | self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
258 |
259 | global_step += 1
260 | epoch_loss += loss.item()
261 | progress_bar.set_postfix(loss=loss.item())
262 |
263 | if global_step % save_step == 0:
264 | self.save_checkpoint(global_step)
265 | self.writer.add_figure("mel/target", plot_spectrogram(mel_spec[0,:,:].detach().cpu().numpy()), global_step)
266 | self.writer.add_figure("mel/mask", plot_spectrogram(cond[0,:,:].detach().cpu().numpy()), global_step)
267 | self.writer.add_figure("mel/prediction", plot_spectrogram(pred[0,:,:].detach().cpu().numpy()), global_step)
268 |
269 | epoch_loss /= len(train_dataloader)
270 | if self.accelerator.is_local_main_process:
271 | logger.info(f"epoch {epoch+1}/{epochs} - average loss = {epoch_loss:.4f}")
272 | self.writer.add_scalar('epoch average loss', epoch_loss, epoch)
273 |
274 | self.writer.close()
275 |
--------------------------------------------------------------------------------
/e2_tts_pytorch/e2_tts.py:
--------------------------------------------------------------------------------
1 | """
2 | ein notation:
3 | b - batch
4 | n - sequence
5 | nt - text sequence
6 | nw - raw wave length
7 | d - dimension
8 | dt - dimension text
9 | """
10 |
11 | from __future__ import annotations
12 |
13 | from pathlib import Path
14 | from random import random
15 | from functools import partial
16 | from itertools import zip_longest
17 | from collections import namedtuple
18 |
19 | from typing import Literal, Callable
20 |
21 | import jaxtyping
22 | from beartype import beartype
23 |
24 | import torch
25 | import torch.nn.functional as F
26 | from torch import nn, tensor, Tensor, from_numpy
27 | from torch.nn import Module, ModuleList, Sequential, Linear
28 | from torch.nn.utils.rnn import pad_sequence
29 |
30 | import torchaudio
31 | from torchaudio.functional import DB_to_amplitude, resample
32 | from torchdiffeq import odeint
33 |
34 | import einx
35 | from einops.layers.torch import Rearrange
36 | from einops import einsum, rearrange, repeat, reduce, pack, unpack
37 |
38 | from x_transformers import (
39 | Attention,
40 | FeedForward,
41 | RMSNorm,
42 | AdaptiveRMSNorm,
43 | )
44 |
45 | from x_transformers.x_transformers import RotaryEmbedding
46 |
47 | from vocos import Vocos
48 |
49 | pad_sequence = partial(pad_sequence, batch_first = True)
50 |
51 | # constants
52 |
53 | class TorchTyping:
54 | def __init__(self, abstract_dtype):
55 | self.abstract_dtype = abstract_dtype
56 |
57 | def __getitem__(self, shapes: str):
58 | return self.abstract_dtype[Tensor, shapes]
59 |
60 | Float = TorchTyping(jaxtyping.Float)
61 | Int = TorchTyping(jaxtyping.Int)
62 | Bool = TorchTyping(jaxtyping.Bool)
63 |
64 | E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred'])
65 |
66 | # helpers
67 |
68 | def exists(v):
69 | return v is not None
70 |
71 | def default(v, d):
72 | return v if exists(v) else d
73 |
74 | def divisible_by(num, den):
75 | return (num % den) == 0
76 |
77 | class Identity(Module):
78 | def forward(self, x, **kwargs):
79 | return x
80 |
81 | # simple utf-8 tokenizer, since paper went character based
82 |
83 | def list_str_to_tensor(
84 | text: list[str],
85 | padding_value = -1
86 | ) -> Int['b nt']:
87 |
88 | list_tensors = [tensor([*bytes(t, 'UTF-8')]) for t in text]
89 | padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
90 | return padded_tensor
91 |
92 | # simple english phoneme-based tokenizer
93 |
94 | from g2p_en import G2p
95 |
96 | def get_g2p_en_encode():
97 | g2p = G2p()
98 |
99 | # used by @lucasnewman successfully here
100 | # https://github.com/lucasnewman/e2-tts-pytorch/blob/ljspeech-test/e2_tts_pytorch/e2_tts.py
101 |
102 | phoneme_to_index = g2p.p2idx
103 | num_phonemes = len(phoneme_to_index)
104 |
105 | extended_chars = [' ', ',', '.', '-', '!', '?', '\'', '"', '...', '..', '. .', '. . .', '. . . .', '. . . . .', '. ...', '... .', '.. ..']
106 | num_extended_chars = len(extended_chars)
107 |
108 | extended_chars_dict = {p: (num_phonemes + i) for i, p in enumerate(extended_chars)}
109 | phoneme_to_index = {**phoneme_to_index, **extended_chars_dict}
110 |
111 | def encode(
112 | text: list[str],
113 | padding_value = -1
114 | ) -> Int['b nt']:
115 |
116 | phonemes = [g2p(t) for t in text]
117 | list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes]
118 | padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
119 | return padded_tensor
120 |
121 | return encode, (num_phonemes + num_extended_chars)
122 |
123 | # tensor helpers
124 |
125 | def log(t, eps = 1e-5):
126 | return t.clamp(min = eps).log()
127 |
128 | def lens_to_mask(
129 | t: Int['b'],
130 | length: int | None = None
131 | ) -> Bool['b n']:
132 |
133 | if not exists(length):
134 | length = t.amax()
135 |
136 | seq = torch.arange(length, device = t.device)
137 | return einx.less('n, b -> b n', seq, t)
138 |
139 | def mask_from_start_end_indices(
140 | seq_len: Int['b'],
141 | start: Int['b'],
142 | end: Int['b']
143 | ):
144 | max_seq_len = seq_len.max().item()
145 | seq = torch.arange(max_seq_len, device = start.device).long()
146 | return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
147 |
148 | def mask_from_frac_lengths(
149 | seq_len: Int['b'],
150 | frac_lengths: Float['b'],
151 | max_length: int | None = None
152 | ):
153 | lengths = (frac_lengths * seq_len).long()
154 | max_start = seq_len - lengths
155 |
156 | rand = torch.rand_like(frac_lengths)
157 | start = (max_start * rand).long().clamp(min = 0)
158 | end = start + lengths
159 |
160 | out = mask_from_start_end_indices(seq_len, start, end)
161 |
162 | if exists(max_length):
163 | out = pad_to_length(out, max_length)
164 |
165 | return out
166 |
167 | def maybe_masked_mean(
168 | t: Float['b n d'],
169 | mask: Bool['b n'] | None = None
170 | ) -> Float['b d']:
171 |
172 | if not exists(mask):
173 | return t.mean(dim = 1)
174 |
175 | t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
176 | num = reduce(t, 'b n d -> b d', 'sum')
177 | den = reduce(mask.float(), 'b n -> b', 'sum')
178 |
179 | return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
180 |
181 | def pad_to_length(
182 | t: Tensor,
183 | length: int,
184 | value = None
185 | ):
186 | seq_len = t.shape[-1]
187 | if length > seq_len:
188 | t = F.pad(t, (0, length - seq_len), value = value)
189 |
190 | return t[..., :length]
191 |
192 | def interpolate_1d(
193 | x: Tensor,
194 | length: int,
195 | mode = 'bilinear'
196 | ):
197 | x = rearrange(x, 'n d -> 1 d n 1')
198 | x = F.interpolate(x, (length, 1), mode = mode)
199 | return rearrange(x, '1 d n 1 -> n d')
200 |
201 | # to mel spec
202 |
203 | class MelSpec(Module):
204 | def __init__(
205 | self,
206 | filter_length = 1024,
207 | hop_length = 256,
208 | win_length = 1024,
209 | n_mel_channels = 100,
210 | sampling_rate = 24_000,
211 | normalize = False,
212 | power = 1,
213 | norm = None,
214 | center = True,
215 | ):
216 | super().__init__()
217 | self.n_mel_channels = n_mel_channels
218 | self.sampling_rate = sampling_rate
219 |
220 | self.mel_stft = torchaudio.transforms.MelSpectrogram(
221 | sample_rate = sampling_rate,
222 | n_fft = filter_length,
223 | win_length = win_length,
224 | hop_length = hop_length,
225 | n_mels = n_mel_channels,
226 | power = power,
227 | center = center,
228 | normalized = normalize,
229 | norm = norm,
230 | )
231 |
232 | self.register_buffer('dummy', tensor(0), persistent = False)
233 |
234 | def forward(self, inp):
235 | if len(inp.shape) == 3:
236 | inp = rearrange(inp, 'b 1 nw -> b nw')
237 |
238 | assert len(inp.shape) == 2
239 |
240 | if self.dummy.device != inp.device:
241 | self.to(inp.device)
242 |
243 | mel = self.mel_stft(inp)
244 | mel = log(mel)
245 | return mel
246 |
247 | # adaln zero from DiT paper
248 |
249 | class AdaLNZero(Module):
250 | def __init__(
251 | self,
252 | dim,
253 | dim_condition = None,
254 | init_bias_value = -2.
255 | ):
256 | super().__init__()
257 | dim_condition = default(dim_condition, dim)
258 | self.to_gamma = nn.Linear(dim_condition, dim)
259 |
260 | nn.init.zeros_(self.to_gamma.weight)
261 | nn.init.constant_(self.to_gamma.bias, init_bias_value)
262 |
263 | def forward(self, x, *, condition):
264 | if condition.ndim == 2:
265 | condition = rearrange(condition, 'b d -> b 1 d')
266 |
267 | gamma = self.to_gamma(condition).sigmoid()
268 | return x * gamma
269 |
270 | # random projection fourier embedding
271 |
272 | class RandomFourierEmbed(Module):
273 | def __init__(self, dim):
274 | super().__init__()
275 | assert divisible_by(dim, 2)
276 | self.register_buffer('weights', torch.randn(dim // 2))
277 |
278 | def forward(self, x):
279 | freqs = einx.multiply('i, j -> i j', x, self.weights) * 2 * torch.pi
280 | fourier_embed, _ = pack((x, freqs.sin(), freqs.cos()), 'b *')
281 | return fourier_embed
282 |
283 | # character embedding
284 |
285 | class CharacterEmbed(Module):
286 | def __init__(
287 | self,
288 | dim,
289 | num_embeds = 256,
290 | ):
291 | super().__init__()
292 | self.dim = dim
293 | self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'
294 |
295 | def forward(
296 | self,
297 | text: Int['b nt'],
298 | max_seq_len: int,
299 | **kwargs
300 | ) -> Float['b n d']:
301 |
302 | text = text + 1 # shift all other token ids up by 1 and use 0 as filler token
303 |
304 | text = text[:, :max_seq_len] # just curtail if character tokens are more than the mel spec tokens, one of the edge cases the paper did not address
305 | text = pad_to_length(text, max_seq_len, value = 0)
306 |
307 | return self.embed(text)
308 |
309 | class InterpolatedCharacterEmbed(Module):
310 | def __init__(
311 | self,
312 | dim,
313 | num_embeds = 256,
314 | ):
315 | super().__init__()
316 | self.dim = dim
317 | self.embed = nn.Embedding(num_embeds, dim)
318 |
319 | self.abs_pos_mlp = Sequential(
320 | Rearrange('... -> ... 1'),
321 | Linear(1, dim),
322 | nn.SiLU(),
323 | Linear(dim, dim)
324 | )
325 |
326 | def forward(
327 | self,
328 | text: Int['b nt'],
329 | max_seq_len: int,
330 | mask: Bool['b n'] | None = None
331 | ) -> Float['b n d']:
332 |
333 | device = text.device
334 | seq = torch.arange(max_seq_len, device = device)
335 |
336 | mask = default(mask, (None,))
337 |
338 | interp_embeds = []
339 | interp_abs_positions = []
340 |
341 | for one_text, one_mask in zip_longest(text, mask):
342 |
343 | valid_text = one_text >= 0
344 | one_text = one_text[valid_text]
345 | one_text_embed = self.embed(one_text)
346 |
347 | # save the absolute positions
348 |
349 | text_seq_len = one_text.shape[0]
350 |
351 | # determine audio sequence length from mask
352 |
353 | audio_seq_len = max_seq_len
354 | if exists(one_mask):
355 | audio_seq_len = one_mask.sum().long().item()
356 |
357 | # interpolate text embedding to audio embedding length
358 |
359 | interp_text_embed = interpolate_1d(one_text_embed, audio_seq_len)
360 | interp_abs_pos = torch.linspace(0, text_seq_len, audio_seq_len, device = device)
361 |
362 | interp_embeds.append(interp_text_embed)
363 | interp_abs_positions.append(interp_abs_pos)
364 |
365 | interp_embeds = pad_sequence(interp_embeds)
366 | interp_abs_positions = pad_sequence(interp_abs_positions)
367 |
368 | interp_embeds = F.pad(interp_embeds, (0, 0, 0, max_seq_len - interp_embeds.shape[-2]))
369 | interp_abs_positions = pad_to_length(interp_abs_positions, max_seq_len)
370 |
371 | # pass interp absolute positions through mlp for implicit positions
372 |
373 | interp_embeds = interp_embeds + self.abs_pos_mlp(interp_abs_positions)
374 |
375 | if exists(mask):
376 | interp_embeds = einx.where('b n, b n d, -> b n d', mask, interp_embeds, 0.)
377 |
378 | return interp_embeds
379 |
380 | # text audio cross conditioning in multistream setup
381 |
382 | class TextAudioCrossCondition(Module):
383 | def __init__(
384 | self,
385 | dim,
386 | dim_text,
387 | cond_audio_to_text = True,
388 | ):
389 | super().__init__()
390 | self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False)
391 | nn.init.zeros_(self.text_to_audio.weight)
392 |
393 | self.cond_audio_to_text = cond_audio_to_text
394 |
395 | if cond_audio_to_text:
396 | self.audio_to_text = nn.Linear(dim + dim_text, dim_text, bias = False)
397 | nn.init.zeros_(self.audio_to_text.weight)
398 |
399 | def forward(
400 | self,
401 | audio: Float['b n d'],
402 | text: Float['b n dt']
403 | ):
404 | audio_text, _ = pack((audio, text), 'b n *')
405 |
406 | text_cond = self.text_to_audio(audio_text)
407 | audio_cond = self.audio_to_text(audio_text) if self.cond_audio_to_text else 0.
408 |
409 | return audio + text_cond, text + audio_cond
410 |
411 | # attention and transformer backbone
412 | # for use in both e2tts as well as duration module
413 |
414 | class Transformer(Module):
415 | @beartype
416 | def __init__(
417 | self,
418 | *,
419 | dim,
420 | dim_text = None, # will default to half of audio dimension
421 | depth = 8,
422 | heads = 8,
423 | dim_head = 64,
424 | ff_mult = 4,
425 | text_depth = None,
426 | text_heads = None,
427 | text_dim_head = None,
428 | text_ff_mult = None,
429 | cond_on_time = True,
430 | abs_pos_emb = True,
431 | max_seq_len = 8192,
432 | dropout = 0.1,
433 | num_registers = 32,
434 | attn_kwargs: dict = dict(
435 | gate_value_heads = True,
436 | softclamp_logits = True,
437 | ),
438 | ff_kwargs: dict = dict(),
439 | ):
440 | super().__init__()
441 | assert divisible_by(depth, 2), 'depth needs to be even'
442 |
443 | # absolute positional embedding
444 |
445 | self.max_seq_len = max_seq_len
446 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if abs_pos_emb else None
447 |
448 | self.dim = dim
449 |
450 | dim_text = default(dim_text, dim // 2)
451 | self.dim_text = dim_text
452 |
453 | text_heads = default(text_heads, heads)
454 | text_dim_head = default(text_dim_head, dim_head)
455 | text_ff_mult = default(text_ff_mult, ff_mult)
456 | text_depth = default(text_depth, depth)
457 |
458 | assert 1 <= text_depth <= depth, 'must have at least 1 layer of text conditioning, but less than total number of speech layers'
459 |
460 | self.depth = depth
461 | self.layers = ModuleList([])
462 |
463 | # registers
464 |
465 | self.num_registers = num_registers
466 | self.registers = nn.Parameter(torch.zeros(num_registers, dim))
467 | nn.init.normal_(self.registers, std = 0.02)
468 |
469 | self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text))
470 | nn.init.normal_(self.text_registers, std = 0.02)
471 |
472 | # rotary embedding
473 |
474 | self.rotary_emb = RotaryEmbedding(dim_head)
475 | self.text_rotary_emb = RotaryEmbedding(dim_head)
476 |
477 | # time conditioning
478 | # will use adaptive rmsnorm
479 |
480 | self.cond_on_time = cond_on_time
481 | rmsnorm_klass = RMSNorm if not cond_on_time else AdaptiveRMSNorm
482 | postbranch_klass = Identity if not cond_on_time else partial(AdaLNZero, dim = dim)
483 |
484 | self.time_cond_mlp = Identity()
485 |
486 | if cond_on_time:
487 | self.time_cond_mlp = Sequential(
488 | RandomFourierEmbed(dim),
489 | Linear(dim + 1, dim),
490 | nn.SiLU()
491 | )
492 |
493 | for ind in range(depth):
494 | is_later_half = ind >= (depth // 2)
495 | has_text = ind < text_depth
496 |
497 | # speech related
498 |
499 | attn_norm = rmsnorm_klass(dim)
500 | attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs)
501 | attn_adaln_zero = postbranch_klass()
502 |
503 | ff_norm = rmsnorm_klass(dim)
504 | ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs)
505 | ff_adaln_zero = postbranch_klass()
506 |
507 | skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None
508 |
509 | speech_modules = ModuleList([
510 | skip_proj,
511 | attn_norm,
512 | attn,
513 | attn_adaln_zero,
514 | ff_norm,
515 | ff,
516 | ff_adaln_zero,
517 | ])
518 |
519 | text_modules = None
520 |
521 | if has_text:
522 | # text related
523 |
524 | text_attn_norm = RMSNorm(dim_text)
525 | text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, **attn_kwargs)
526 |
527 | text_ff_norm = RMSNorm(dim_text)
528 | text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs)
529 |
530 | # cross condition
531 |
532 | is_last = ind == (text_depth - 1)
533 |
534 | cross_condition = TextAudioCrossCondition(dim = dim, dim_text = dim_text, cond_audio_to_text = not is_last)
535 |
536 | text_modules = ModuleList([
537 | text_attn_norm,
538 | text_attn,
539 | text_ff_norm,
540 | text_ff,
541 | cross_condition
542 | ])
543 |
544 | self.layers.append(ModuleList([
545 | speech_modules,
546 | text_modules
547 | ]))
548 |
549 | self.final_norm = RMSNorm(dim)
550 |
551 | def forward(
552 | self,
553 | x: Float['b n d'],
554 | times: Float['b'] | Float[''] | None = None,
555 | mask: Bool['b n'] | None = None,
556 | text_embed: Float['b n dt'] | None = None,
557 | ):
558 | batch, seq_len, device = *x.shape[:2], x.device
559 |
560 | assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa'
561 |
562 | # handle absolute positions if needed
563 |
564 | if exists(self.abs_pos_emb):
565 | assert seq_len <= self.max_seq_len, f'{seq_len} exceeds the set `max_seq_len` ({self.max_seq_len}) on Transformer'
566 | seq = torch.arange(seq_len, device = device)
567 | x = x + self.abs_pos_emb(seq)
568 |
569 | # handle adaptive rmsnorm kwargs
570 |
571 | norm_kwargs = dict()
572 |
573 | if exists(times):
574 | if times.ndim == 0:
575 | times = repeat(times, ' -> b', b = batch)
576 |
577 | times = self.time_cond_mlp(times)
578 | norm_kwargs.update(condition = times)
579 |
580 | # register tokens
581 |
582 | registers = repeat(self.registers, 'r d -> b r d', b = batch)
583 | x, registers_packed_shape = pack((registers, x), 'b * d')
584 |
585 | if exists(mask):
586 | mask = F.pad(mask, (self.num_registers, 0), value = True)
587 |
588 | # rotary embedding
589 |
590 | rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2])
591 |
592 | # text related
593 |
594 | if exists(text_embed):
595 | text_rotary_pos_emb = self.text_rotary_emb.forward_from_seq_len(x.shape[-2])
596 |
597 | text_registers = repeat(self.text_registers, 'r d -> b r d', b = batch)
598 | text_embed, _ = pack((text_registers, text_embed), 'b * d')
599 |
600 | # skip connection related stuff
601 |
602 | skips = []
603 |
604 | # go through the layers
605 |
606 | for ind, (speech_modules, text_modules) in enumerate(self.layers):
607 | layer = ind + 1
608 |
609 | (
610 | maybe_skip_proj,
611 | attn_norm,
612 | attn,
613 | maybe_attn_adaln_zero,
614 | ff_norm,
615 | ff,
616 | maybe_ff_adaln_zero
617 | ) = speech_modules
618 |
619 | # smaller text transformer
620 |
621 | if exists(text_embed) and exists(text_modules):
622 |
623 | (
624 | text_attn_norm,
625 | text_attn,
626 | text_ff_norm,
627 | text_ff,
628 | cross_condition
629 | ) = text_modules
630 |
631 | text_embed = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask) + text_embed
632 |
633 | text_embed = text_ff(text_ff_norm(text_embed)) + text_embed
634 |
635 | x, text_embed = cross_condition(x, text_embed)
636 |
637 | # skip connection logic
638 |
639 | is_first_half = layer <= (self.depth // 2)
640 | is_later_half = not is_first_half
641 |
642 | if is_first_half:
643 | skips.append(x)
644 |
645 | if is_later_half:
646 | skip = skips.pop()
647 | x = torch.cat((x, skip), dim = -1)
648 | x = maybe_skip_proj(x)
649 |
650 | # attention and feedforward blocks
651 |
652 | attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask)
653 |
654 | x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs)
655 |
656 | ff_out = ff(ff_norm(x, **norm_kwargs))
657 |
658 | x = x + maybe_ff_adaln_zero(ff_out, **norm_kwargs)
659 |
660 | assert len(skips) == 0
661 |
662 | _, x = unpack(x, registers_packed_shape, 'b * d')
663 |
664 | return self.final_norm(x)
665 |
666 | # main classes
667 |
668 | class DurationPredictor(Module):
669 | @beartype
670 | def __init__(
671 | self,
672 | transformer: dict | Transformer,
673 | num_channels = None,
674 | mel_spec_kwargs: dict = dict(),
675 | char_embed_kwargs: dict = dict(),
676 | text_num_embeds = None,
677 | tokenizer: (
678 | Literal['char_utf8', 'phoneme_en'] |
679 | Callable[[list[str]], Int['b nt']]
680 | ) = 'char_utf8'
681 | ):
682 | super().__init__()
683 |
684 | if isinstance(transformer, dict):
685 | transformer = Transformer(
686 | **transformer,
687 | cond_on_time = False
688 | )
689 |
690 | # mel spec
691 |
692 | self.mel_spec = MelSpec(**mel_spec_kwargs)
693 | self.num_channels = default(num_channels, self.mel_spec.n_mel_channels)
694 |
695 | self.transformer = transformer
696 |
697 | dim = transformer.dim
698 | dim_text = transformer.dim_text
699 |
700 | self.dim = dim
701 |
702 | self.proj_in = Linear(self.num_channels, self.dim)
703 |
704 | # tokenizer and text embed
705 |
706 | if callable(tokenizer):
707 | assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
708 | self.tokenizer = tokenizer
709 | elif tokenizer == 'char_utf8':
710 | text_num_embeds = 256
711 | self.tokenizer = list_str_to_tensor
712 | elif tokenizer == 'phoneme_en':
713 | self.tokenizer, text_num_embeds = get_g2p_en_encode()
714 | else:
715 | raise ValueError(f'unknown tokenizer string {tokenizer}')
716 |
717 | self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
718 |
719 | # to prediction
720 |
721 | self.to_pred = Sequential(
722 | Linear(dim, 1, bias = False),
723 | nn.Softplus(),
724 | Rearrange('... 1 -> ...')
725 | )
726 |
727 | def forward(
728 | self,
729 | x: Float['b n d'] | Float['b nw'],
730 | *,
731 | text: Int['b nt'] | list[str] | None = None,
732 | lens: Int['b'] | None = None,
733 | return_loss = True
734 | ):
735 | # raw wave
736 |
737 | if x.ndim == 2:
738 | x = self.mel_spec(x)
739 | x = rearrange(x, 'b d n -> b n d')
740 | assert x.shape[-1] == self.num_channels
741 |
742 | x = self.proj_in(x)
743 |
744 | batch, seq_len, device = *x.shape[:2], x.device
745 |
746 | # text
747 |
748 | text_embed = None
749 |
750 | if exists(text):
751 | if isinstance(text, list):
752 | text = list_str_to_tensor(text).to(device)
753 | assert text.shape[0] == batch
754 |
755 | text_embed = self.embed_text(text, seq_len)
756 |
757 | # handle lengths (duration)
758 |
759 | if not exists(lens):
760 | lens = torch.full((batch,), seq_len, device = device)
761 |
762 | mask = lens_to_mask(lens, length = seq_len)
763 |
764 | # if returning a loss, mask out randomly from an index and have it predict the duration
765 |
766 | if return_loss:
767 | rand_frac_index = x.new_zeros(batch).uniform_(0, 1)
768 | rand_index = (rand_frac_index * lens).long()
769 |
770 | seq = torch.arange(seq_len, device = device)
771 | mask &= einx.less('n, b -> b n', seq, rand_index)
772 |
773 | # attending
774 |
775 | x = self.transformer(
776 | x,
777 | mask = mask,
778 | text_embed = text_embed,
779 | )
780 |
781 | x = maybe_masked_mean(x, mask)
782 |
783 | pred = self.to_pred(x)
784 |
785 | # return the prediction if not returning loss
786 |
787 | if not return_loss:
788 | return pred
789 |
790 | # loss
791 |
792 | return F.mse_loss(pred, lens.float())
793 |
794 | class E2TTS(Module):
795 |
796 | @beartype
797 | def __init__(
798 | self,
799 | transformer: dict | Transformer | None = None,
800 | duration_predictor: dict | DurationPredictor | None = None,
801 | odeint_kwargs: dict = dict(
802 | atol = 1e-5,
803 | rtol = 1e-5,
804 | method = 'midpoint'
805 | ),
806 | cond_drop_prob = 0.25,
807 | num_channels = None,
808 | mel_spec_module: Module | None = None,
809 | char_embed_kwargs: dict = dict(),
810 | mel_spec_kwargs: dict = dict(),
811 | frac_lengths_mask: tuple[float, float] = (0.7, 1.),
812 | concat_cond = False,
813 | interpolated_text = False,
814 | text_num_embeds: int | None = None,
815 | tokenizer: (
816 | Literal['char_utf8', 'phoneme_en'] |
817 | Callable[[list[str]], Int['b nt']]
818 | ) = 'char_utf8',
819 | use_vocos = True,
820 | pretrained_vocos_path = 'charactr/vocos-mel-24khz',
821 | sampling_rate: int | None = None
822 | ):
823 | super().__init__()
824 |
825 | if isinstance(transformer, dict):
826 | transformer = Transformer(
827 | **transformer,
828 | cond_on_time = True
829 | )
830 |
831 | if isinstance(duration_predictor, dict):
832 | duration_predictor = DurationPredictor(**duration_predictor)
833 |
834 | self.transformer = transformer
835 |
836 | dim = transformer.dim
837 | dim_text = transformer.dim_text
838 |
839 | self.dim = dim
840 | self.dim_text = dim_text
841 |
842 | self.frac_lengths_mask = frac_lengths_mask
843 |
844 | self.duration_predictor = duration_predictor
845 |
846 | # sampling
847 |
848 | self.odeint_kwargs = odeint_kwargs
849 |
850 | # mel spec
851 |
852 | self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
853 | num_channels = default(num_channels, self.mel_spec.n_mel_channels)
854 |
855 | self.num_channels = num_channels
856 | self.sampling_rate = default(sampling_rate, getattr(self.mel_spec, 'sampling_rate', None))
857 |
858 | # whether to concat condition and project rather than project both and sum
859 |
860 | self.concat_cond = concat_cond
861 |
862 | if concat_cond:
863 | self.proj_in = nn.Linear(num_channels * 2, dim)
864 | else:
865 | self.proj_in = nn.Linear(num_channels, dim)
866 | self.cond_proj_in = nn.Linear(num_channels, dim)
867 |
868 | # to prediction
869 |
870 | self.to_pred = Linear(dim, num_channels)
871 |
872 | # tokenizer and text embed
873 |
874 | if callable(tokenizer):
875 | assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
876 | self.tokenizer = tokenizer
877 | elif tokenizer == 'char_utf8':
878 | text_num_embeds = 256
879 | self.tokenizer = list_str_to_tensor
880 | elif tokenizer == 'phoneme_en':
881 | self.tokenizer, text_num_embeds = get_g2p_en_encode()
882 | else:
883 | raise ValueError(f'unknown tokenizer string {tokenizer}')
884 |
885 | self.cond_drop_prob = cond_drop_prob
886 |
887 | # text embedding
888 |
889 | text_embed_klass = CharacterEmbed if not interpolated_text else InterpolatedCharacterEmbed
890 |
891 | self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
892 |
893 | # default vocos for mel -> audio
894 |
895 | self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None
896 |
897 | @property
898 | def device(self):
899 | return next(self.parameters()).device
900 |
901 | def transformer_with_pred_head(
902 | self,
903 | x: Float['b n d'],
904 | cond: Float['b n d'],
905 | times: Float['b'],
906 | mask: Bool['b n'] | None = None,
907 | text: Int['b nt'] | None = None,
908 | drop_text_cond: bool | None = None
909 | ):
910 | seq_len = x.shape[-2]
911 | drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob)
912 |
913 | if self.concat_cond:
914 | # concat condition, given as using voicebox-like scheme
915 | x = torch.cat((cond, x), dim = -1)
916 |
917 | x = self.proj_in(x)
918 |
919 | if not self.concat_cond:
920 | # an alternative is to simply sum the condition
921 | # seems to work fine
922 |
923 | cond = self.cond_proj_in(cond)
924 | x = x + cond
925 |
926 | # whether to use a text embedding
927 |
928 | text_embed = None
929 | if exists(text) and not drop_text_cond:
930 | text_embed = self.embed_text(text, seq_len, mask = mask)
931 |
932 | # attend
933 |
934 | attended = self.transformer(
935 | x,
936 | times = times,
937 | mask = mask,
938 | text_embed = text_embed
939 | )
940 |
941 | return self.to_pred(attended)
942 |
943 | def cfg_transformer_with_pred_head(
944 | self,
945 | *args,
946 | cfg_strength: float = 1.,
947 | **kwargs,
948 | ):
949 |
950 | pred = self.transformer_with_pred_head(*args, drop_text_cond = False, **kwargs)
951 |
952 | if cfg_strength < 1e-5:
953 | return pred
954 |
955 | null_pred = self.transformer_with_pred_head(*args, drop_text_cond = True, **kwargs)
956 |
957 | return pred + (pred - null_pred) * cfg_strength
958 |
959 | @torch.no_grad()
960 | def sample(
961 | self,
962 | cond: Float['b n d'] | Float['b nw'],
963 | *,
964 | text: Int['b nt'] | list[str] | None = None,
965 | lens: Int['b'] | None = None,
966 | duration: int | Int['b'] | None = None,
967 | steps = 32,
968 | cfg_strength = 1., # they used a classifier free guidance strength of 1.
969 | max_duration = 4096, # in case the duration predictor goes haywire
970 | vocoder: Callable[[Float['b d n']], list[Float['_']]] | None = None,
971 | return_raw_output: bool | None = None,
972 | save_to_filename: str | None = None
973 | ) -> (
974 | Float['b n d'],
975 | list[Float['_']]
976 | ):
977 | self.eval()
978 |
979 | # raw wave
980 |
981 | if cond.ndim == 2:
982 | cond = self.mel_spec(cond)
983 | cond = rearrange(cond, 'b d n -> b n d')
984 | assert cond.shape[-1] == self.num_channels
985 |
986 | batch, cond_seq_len, device = *cond.shape[:2], cond.device
987 |
988 | if not exists(lens):
989 | lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
990 |
991 | # text
992 |
993 | if isinstance(text, list):
994 | text = self.tokenizer(text).to(device)
995 | assert text.shape[0] == batch
996 |
997 | if exists(text):
998 | text_lens = (text != -1).sum(dim = -1)
999 | lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
1000 |
1001 | # duration
1002 |
1003 | cond_mask = lens_to_mask(lens)
1004 |
1005 | if exists(duration):
1006 | if isinstance(duration, int):
1007 | duration = torch.full((batch,), duration, device = device, dtype = torch.long)
1008 |
1009 | elif exists(self.duration_predictor):
1010 | duration = self.duration_predictor(cond, text = text, lens = lens, return_loss = False).long()
1011 |
1012 | duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
1013 | duration = duration.clamp(max = max_duration)
1014 |
1015 | assert duration.shape[0] == batch
1016 |
1017 | max_duration = duration.amax()
1018 |
1019 | cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
1020 | cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
1021 | cond_mask = rearrange(cond_mask, '... -> ... 1')
1022 |
1023 | mask = lens_to_mask(duration)
1024 |
1025 | # neural ode
1026 |
1027 | def fn(t, x):
1028 | # at each step, conditioning is fixed
1029 |
1030 | step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
1031 |
1032 | # predict flow
1033 |
1034 | return self.cfg_transformer_with_pred_head(
1035 | x,
1036 | step_cond,
1037 | times = t,
1038 | text = text,
1039 | mask = mask,
1040 | cfg_strength = cfg_strength
1041 | )
1042 |
1043 | y0 = torch.randn_like(cond)
1044 | t = torch.linspace(0, 1, steps, device = self.device)
1045 |
1046 | trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
1047 | sampled = trajectory[-1]
1048 |
1049 | out = sampled
1050 |
1051 | out = torch.where(cond_mask, cond, out)
1052 |
1053 | # able to return raw untransformed output, if not using mel rep
1054 |
1055 | if exists(return_raw_output) and return_raw_output:
1056 | return out
1057 |
1058 | # take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on
1059 |
1060 | if exists(vocoder):
1061 | assert not exists(self.vocos), '`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling'
1062 | out = rearrange(out, 'b n d -> b d n')
1063 | out = vocoder(out)
1064 |
1065 | elif exists(self.vocos):
1066 |
1067 | audio = []
1068 | for mel, one_mask in zip(out, mask):
1069 | one_out = DB_to_amplitude(mel[one_mask], ref = 1., power = 0.5)
1070 |
1071 | one_out = rearrange(one_out, 'n d -> 1 d n')
1072 | one_audio = self.vocos.decode(one_out)
1073 | one_audio = rearrange(one_audio, '1 nw -> nw')
1074 | audio.append(one_audio)
1075 |
1076 | out = audio
1077 |
1078 | if exists(save_to_filename):
1079 | assert exists(vocoder) or exists(self.vocos)
1080 | assert exists(self.sampling_rate)
1081 |
1082 | path = Path(save_to_filename)
1083 | parent_path = path.parents[0]
1084 | parent_path.mkdir(exist_ok = True, parents = True)
1085 |
1086 | for ind, one_audio in enumerate(out):
1087 | one_audio = rearrange(one_audio, 'nw -> 1 nw')
1088 | save_path = str(parent_path / f'{ind + 1}.{path.name}')
1089 | torchaudio.save(save_path, one_audio.detach().cpu(), sample_rate = self.sampling_rate)
1090 |
1091 | return out
1092 |
1093 | def forward(
1094 | self,
1095 | inp: Float['b n d'] | Float['b nw'], # mel or raw wave
1096 | *,
1097 | text: Int['b nt'] | list[str] | None = None,
1098 | times: Int['b'] | None = None,
1099 | lens: Int['b'] | None = None,
1100 | ):
1101 | # handle raw wave
1102 |
1103 | if inp.ndim == 2:
1104 | inp = self.mel_spec(inp)
1105 | inp = rearrange(inp, 'b d n -> b n d')
1106 | assert inp.shape[-1] == self.num_channels
1107 |
1108 | batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device
1109 |
1110 | # handle text as string
1111 |
1112 | if isinstance(text, list):
1113 | text = self.tokenizer(text).to(device)
1114 | assert text.shape[0] == batch
1115 |
1116 | # lens and mask
1117 |
1118 | if not exists(lens):
1119 | lens = torch.full((batch,), seq_len, device = device)
1120 |
1121 | mask = lens_to_mask(lens, length = seq_len)
1122 |
1123 | # get a random span to mask out for training conditionally
1124 |
1125 | frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
1126 | rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length = seq_len)
1127 |
1128 | if exists(mask):
1129 | rand_span_mask &= mask
1130 |
1131 | # mel is x1
1132 |
1133 | x1 = inp
1134 |
1135 | # main conditional flow training logic
1136 | # just ~5 loc
1137 |
1138 | # x0 is gaussian noise
1139 |
1140 | x0 = torch.randn_like(x1)
1141 |
1142 | # t is random times from above
1143 |
1144 | times = torch.rand((batch,), dtype = dtype, device = self.device)
1145 | t = rearrange(times, 'b -> b 1 1')
1146 |
1147 | # sample xt (w in the paper)
1148 |
1149 | w = (1. - t) * x0 + t * x1
1150 |
1151 | flow = x1 - x0
1152 |
1153 | # only predict what is within the random mask span for infilling
1154 |
1155 | cond = einx.where(
1156 | 'b n, b n d, b n d -> b n d',
1157 | rand_span_mask,
1158 | torch.zeros_like(x1), x1
1159 | )
1160 |
1161 | # transformer and prediction head
1162 |
1163 | pred = self.transformer_with_pred_head(w, cond, times = times, text = text, mask = mask)
1164 |
1165 | # flow matching loss
1166 |
1167 | loss = F.mse_loss(pred, flow, reduction = 'none')
1168 |
1169 | loss = loss[rand_span_mask].mean()
1170 |
1171 | return E2TTSReturn(loss, cond, pred)
1172 |
--------------------------------------------------------------------------------