├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── e2-tts.png
├── e2_tts_pytorch
│   ├── __init__.py
│   ├── e2_tts.py
│   └── trainer.py
├── pyproject.toml
└── train_example.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | ## E2 TTS - Pytorch
5 |
6 | Implementation of E2-TTS, Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS, in Pytorch
7 |
8 | The repository differs from the paper in that it uses a multistream transformer for text and audio, with conditioning done at every transformer block in the E2 manner (see the condensed sketch below).
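
A condensed sketch of that per-block cross conditioning, adapted from `TextAudioCrossCondition` in `e2_tts_pytorch/e2_tts.py` (the `einops.pack` call is replaced by an equivalent `torch.cat` for brevity):

```python
import torch
from torch import nn

class TextAudioCrossCondition(nn.Module):
    def __init__(self, dim, dim_text, cond_audio_to_text = True):
        super().__init__()
        # zero-initialized projections, so each block starts out as a no-op for the conditioning
        self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False)
        nn.init.zeros_(self.text_to_audio.weight)

        self.cond_audio_to_text = cond_audio_to_text
        if cond_audio_to_text:
            self.audio_to_text = nn.Linear(dim + dim_text, dim_text, bias = False)
            nn.init.zeros_(self.audio_to_text.weight)

    def forward(self, audio, text):
        # audio: (b, n, dim) speech stream, text: (b, n, dim_text) text stream
        audio_text = torch.cat((audio, text), dim = -1)
        text_cond = self.text_to_audio(audio_text)
        audio_cond = self.audio_to_text(audio_text) if self.cond_audio_to_text else 0.
        return audio + text_cond, text + audio_cond
```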
9 |
10 | It also includes an improvisation proven out by Manmay, where the text is simply interpolated to the length of the audio for conditioning. You can try this by setting `interpolated_text = True` on `E2TTS`, as in the sketch below.
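
A minimal sketch of enabling it (same hyperparameters as the usage example below; the flag is the only change):

```python
from e2_tts_pytorch import E2TTS

e2tts = E2TTS(
    transformer = dict(
        dim = 512,
        depth = 8
    ),
    interpolated_text = True  # interpolate the text embedding to the audio length, rather than pad / curtail it
)
```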
11 |
12 | ## Appreciation
13 |
14 | - Manmay for contributing working end-to-end training code!
15 |
16 | - Lucas Newman for the code contributions, helpful feedback, and for sharing the first set of positive experiments!
17 |
18 | - Jing for sharing the second positive result with a multilingual (English + Chinese) dataset!
19 |
20 | - Coice and Manmay for reporting the third and fourth successful runs. Farewell, alignment engineering.
21 |
22 | ## Install
23 |
24 | ```bash
25 | $ pip install e2-tts-pytorch
26 | ```
27 |
28 | ## Usage
29 |
30 | ```python
31 | import torch
32 |
33 | from e2_tts_pytorch import (
34 | E2TTS,
35 | DurationPredictor
36 | )
37 |
38 | duration_predictor = DurationPredictor(
39 | transformer = dict(
40 | dim = 512,
41 | depth = 8,
42 | )
43 | )
44 |
45 | mel = torch.randn(2, 1024, 100)  # (batch, mel frames, mel channels)
46 | text = ['Hello', 'Goodbye']
47 |
48 | loss = duration_predictor(mel, text = text)
49 | loss.backward()
50 |
51 | e2tts = E2TTS(
52 | duration_predictor = duration_predictor,
53 | transformer = dict(
54 | dim = 512,
55 | depth = 8
56 | ),
57 | )
58 |
59 | out = e2tts(mel, text = text)
60 | out.loss.backward()
61 |
62 | sampled = e2tts.sample(mel[:, :5], text = text)
63 |
64 | ```
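
`E2TTS` loads a pretrained Vocos vocoder by default (`use_vocos = True`), so `sample` can also return waveforms and optionally write them to disk. A sketch of the optional keyword arguments, continuing the example above (the values shown are illustrative, not tuned):

```python
audio = e2tts.sample(
    mel[:, :5],                      # condition on the first few mel frames
    text = text,
    duration = 1024,                 # total length in mel frames; omit to fall back to the duration predictor
    steps = 32,                      # number of ode steps for the flow matching sampler
    cfg_strength = 1.,               # classifier free guidance strength
    save_to_filename = 'sample.wav'  # decodes with vocos and saves one file per batch element
)
```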
65 |
66 | ## Related Works
67 |
68 | - [Nanospeech](https://github.com/lucasnewman/nanospeech) by [Lucas Newman](https://github.com/lucasnewman), which contains training code, working examples, as well as an interoperable MLX version!
69 |
70 | ## Citations
71 |
72 | ```bibtex
73 | @inproceedings{Eskimez2024E2TE,
74 | title = {E2 TTS: Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS},
75 | author = {Sefik Emre Eskimez and Xiaofei Wang and Manthan Thakker and Canrun Li and Chung-Hsien Tsai and Zhen Xiao and Hemin Yang and Zirun Zhu and Min Tang and Xu Tan and Yanqing Liu and Sheng Zhao and Naoyuki Kanda},
76 | year = {2024},
77 | url = {https://api.semanticscholar.org/CorpusID:270738197}
78 | }
79 | ```
80 |
81 | ```bibtex
82 | @inproceedings{Darcet2023VisionTN,
83 | title = {Vision Transformers Need Registers},
84 |     author  = {Timoth{\'e}e Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
85 | year = {2023},
86 | url = {https://api.semanticscholar.org/CorpusID:263134283}
87 | }
88 | ```
89 |
90 | ```bibtex
91 | @article{Bao2022AllAW,
92 | title = {All are Worth Words: A ViT Backbone for Diffusion Models},
93 | author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu},
94 | journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
95 | year = {2022},
96 | pages = {22669-22679},
97 | url = {https://api.semanticscholar.org/CorpusID:253581703}
98 | }
99 | ```
100 |
101 | ```bibtex
102 | @article{Burtsev2021MultiStreamT,
103 | title = {Multi-Stream Transformers},
104 | author = {Mikhail S. Burtsev and Anna Rumshisky},
105 | journal = {ArXiv},
106 | year = {2021},
107 | volume = {abs/2107.10342},
108 | url = {https://api.semanticscholar.org/CorpusID:236171087}
109 | }
110 | ```
111 |
112 | ```bibtex
113 | @inproceedings{Sadat2024EliminatingOA,
114 | title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
115 | author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
116 | year = {2024},
117 | url = {https://api.semanticscholar.org/CorpusID:273098845}
118 | }
119 | ```
120 |
121 | ```bibtex
122 | @article{Gulati2020ConformerCT,
123 | title = {Conformer: Convolution-augmented Transformer for Speech Recognition},
124 | author = {Anmol Gulati and James Qin and Chung-Cheng Chiu and Niki Parmar and Yu Zhang and Jiahui Yu and Wei Han and Shibo Wang and Zhengdong Zhang and Yonghui Wu and Ruoming Pang},
125 | journal = {ArXiv},
126 | year = {2020},
127 | volume = {abs/2005.08100},
128 | url = {https://api.semanticscholar.org/CorpusID:218674528}
129 | }
130 | ```
131 |
132 | ```bibtex
133 | @article{Yang2024ConsistencyFM,
134 | title = {Consistency Flow Matching: Defining Straight Flows with Velocity Consistency},
135 | author = {Ling Yang and Zixiang Zhang and Zhilong Zhang and Xingchao Liu and Minkai Xu and Wentao Zhang and Chenlin Meng and Stefano Ermon and Bin Cui},
136 | journal = {ArXiv},
137 | year = {2024},
138 | volume = {abs/2407.02398},
139 | url = {https://api.semanticscholar.org/CorpusID:270878436}
140 | }
141 | ```
142 |
143 | ```bibtex
144 | @article{Li2024SwitchEA,
145 | title = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},
146 | author = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
147 | journal = {ArXiv},
148 | year = {2024},
149 | volume = {abs/2402.09240},
150 | url = {https://api.semanticscholar.org/CorpusID:267657558}
151 | }
152 | ```
153 |
154 | ```bibtex
155 | @inproceedings{Zhou2024ValueRL,
156 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
157 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
158 | year = {2024},
159 | url = {https://api.semanticscholar.org/CorpusID:273532030}
160 | }
161 | ```
162 |
163 | ```bibtex
164 | @inproceedings{Duvvuri2024LASERAW,
165 | title = {LASER: Attention with Exponential Transformation},
166 | author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
167 | year = {2024},
168 | url = {https://api.semanticscholar.org/CorpusID:273849947}
169 | }
170 | ```
171 |
172 | ```bibtex
173 | @article{Zhu2024HyperConnections,
174 | title = {Hyper-Connections},
175 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
176 | journal = {ArXiv},
177 | year = {2024},
178 | volume = {abs/2409.19606},
179 | url = {https://api.semanticscholar.org/CorpusID:272987528}
180 | }
181 | ```
182 |
183 | ```bibtex
184 | @inproceedings{Lu2023MusicSS,
185 | title = {Music Source Separation with Band-Split RoPE Transformer},
186 | author = {Wei-Tsung Lu and Ju-Chiang Wang and Qiuqiang Kong and Yun-Ning Hung},
187 | year = {2023},
188 | url = {https://api.semanticscholar.org/CorpusID:261556702}
189 | }
190 | ```
191 |
192 | ```bibtex
193 | @inproceedings{Dong2025FANformerIL,
194 | title = {FANformer: Improving Large Language Models Through Effective Periodicity Modeling},
195 | author = {Yi Dong and Ge Li and Xue Jiang and Yongding Tao and Kechi Zhang and Hao Zhu and Huanyu Liu and Jiazheng Ding and Jia Li and Jinliang Deng and Hong Mei},
196 | year = {2025},
197 | url = {https://api.semanticscholar.org/CorpusID:276724636}
198 | }
199 | ```
200 |
201 | ```bibtex
202 | @article{Karras2024GuidingAD,
203 | title = {Guiding a Diffusion Model with a Bad Version of Itself},
204 | author = {Tero Karras and Miika Aittala and Tuomas Kynk{\"a}{\"a}nniemi and Jaakko Lehtinen and Timo Aila and Samuli Laine},
205 | journal = {ArXiv},
206 | year = {2024},
207 | volume = {abs/2406.02507},
208 | url = {https://api.semanticscholar.org/CorpusID:270226598}
209 | }
210 | ```
211 |
--------------------------------------------------------------------------------
/e2-tts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/e2-tts-pytorch/1028716e6a4072f668716687d732a3b123dda4cf/e2-tts.png
--------------------------------------------------------------------------------
/e2_tts_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from e2_tts_pytorch.e2_tts import (
2 | Transformer,
3 | DurationPredictor,
4 | E2TTS,
5 | )
6 |
7 | from e2_tts_pytorch.trainer import (
8 | E2Trainer
9 | )
--------------------------------------------------------------------------------
/e2_tts_pytorch/e2_tts.py:
--------------------------------------------------------------------------------
1 | """
2 | ein notation:
3 | b - batch
4 | n - sequence
5 | f - frequency token dimension
6 | nt - text sequence
7 | nw - raw wave length
8 | d - dimension
9 | dt - dimension text
10 | """
11 |
12 | from __future__ import annotations
13 |
14 | from pathlib import Path
15 | from random import random
16 | from functools import partial
17 | from itertools import zip_longest
18 | from collections import namedtuple
19 |
20 | from typing import Literal, Callable
21 |
22 | import jaxtyping
23 | from beartype import beartype
24 |
25 | import torch
26 | import torch.nn.functional as F
27 | from torch import nn, tensor, Tensor, from_numpy
28 | from torch.nn import Module, ModuleList, Sequential, Linear
29 | from torch.nn.utils.rnn import pad_sequence
30 |
31 | import torchaudio
32 | from torchaudio.functional import DB_to_amplitude
33 | from torchdiffeq import odeint
34 |
35 | import einx
36 | from einops.layers.torch import Rearrange, Reduce
37 | from einops import rearrange, repeat, reduce, einsum, pack, unpack
38 |
39 | from x_transformers import (
40 | Attention,
41 | FeedForward,
42 | RMSNorm,
43 | AdaptiveRMSNorm,
44 | )
45 |
46 | from x_transformers.x_transformers import RotaryEmbedding
47 |
48 | from hyper_connections import HyperConnections
49 |
50 | from hl_gauss_pytorch import HLGaussLayer
51 |
52 | from vocos import Vocos
53 |
54 | pad_sequence = partial(pad_sequence, batch_first = True)
55 |
56 | # constants
57 |
58 | class TorchTyping:
59 | def __init__(self, abstract_dtype):
60 | self.abstract_dtype = abstract_dtype
61 |
62 | def __getitem__(self, shapes: str):
63 | return self.abstract_dtype[Tensor, shapes]
64 |
65 | Float = TorchTyping(jaxtyping.Float)
66 | Int = TorchTyping(jaxtyping.Int)
67 | Bool = TorchTyping(jaxtyping.Bool)
68 |
69 | # named tuples
70 |
71 | LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency'])
72 |
73 | E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data', 'loss_breakdown'])
74 |
75 | # helpers
76 |
77 | def exists(v):
78 | return v is not None
79 |
80 | def default(v, d):
81 | return v if exists(v) else d
82 |
83 | def xnor(x, y):
84 | return not (x ^ y)
85 |
86 | def set_if_missing_key(d, key, value):
87 | if key in d:
88 | return
89 |
90 | d.update(**{key: value})
91 |
92 | def l2norm(t):
93 | return F.normalize(t, dim = -1)
94 |
95 | def divisible_by(num, den):
96 | return (num % den) == 0
97 |
98 | def pack_one_with_inverse(x, pattern):
99 | packed, packed_shape = pack([x], pattern)
100 |
101 | def inverse(x, inverse_pattern = None):
102 | inverse_pattern = default(inverse_pattern, pattern)
103 | return unpack(x, packed_shape, inverse_pattern)[0]
104 |
105 | return packed, inverse
106 |
107 | class Identity(Module):
108 | def forward(self, x, **kwargs):
109 | return x
110 |
111 | # tensor helpers
112 |
113 | def project(x, y):
114 | x, inverse = pack_one_with_inverse(x, 'b *')
115 | y, _ = pack_one_with_inverse(y, 'b *')
116 |
117 | dtype = x.dtype
118 | x, y = x.double(), y.double()
119 | unit = F.normalize(y, dim = -1)
120 |
121 | parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
122 | orthogonal = x - parallel
123 |
124 | return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
125 |
126 | # simple utf-8 tokenizer, since paper went character based
127 |
128 | def list_str_to_tensor(
129 | text: list[str],
130 | padding_value = -1
131 | ) -> Int['b nt']:
132 |
133 | list_tensors = [tensor([*bytes(t, 'UTF-8')]) for t in text]
134 |     padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
135 | return padded_tensor
136 |
137 | # simple english phoneme-based tokenizer
138 |
139 | from g2p_en import G2p
140 |
141 | def get_g2p_en_encode():
142 | g2p = G2p()
143 |
144 | # used by @lucasnewman successfully here
145 | # https://github.com/lucasnewman/e2-tts-pytorch/blob/ljspeech-test/e2_tts_pytorch/e2_tts.py
146 |
147 | phoneme_to_index = g2p.p2idx
148 | num_phonemes = len(phoneme_to_index)
149 |
150 | extended_chars = [' ', ',', '.', '-', '!', '?', '\'', '"', '...', '..', '. .', '. . .', '. . . .', '. . . . .', '. ...', '... .', '.. ..']
151 | num_extended_chars = len(extended_chars)
152 |
153 | extended_chars_dict = {p: (num_phonemes + i) for i, p in enumerate(extended_chars)}
154 | phoneme_to_index = {**phoneme_to_index, **extended_chars_dict}
155 |
156 | def encode(
157 | text: list[str],
158 | padding_value = -1
159 | ) -> Int['b nt']:
160 |
161 | phonemes = [g2p(t) for t in text]
162 | list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes]
163 |         padded_tensor = pad_sequence(list_tensors, padding_value = padding_value)
164 | return padded_tensor
165 |
166 | return encode, (num_phonemes + num_extended_chars)
167 |
168 | # tensor helpers
169 |
170 | def log(t, eps = 1e-5):
171 | return t.clamp(min = eps).log()
172 |
173 | def lens_to_mask(
174 | t: Int['b'],
175 | length: int | None = None
176 | ) -> Bool['b n']:
177 |
178 | if not exists(length):
179 | length = t.amax()
180 |
181 | seq = torch.arange(length, device = t.device)
182 | return einx.less('n, b -> b n', seq, t)
183 |
184 | def mask_from_start_end_indices(
185 | seq_len: Int['b'],
186 | start: Int['b'],
187 | end: Int['b']
188 | ):
189 | max_seq_len = seq_len.max().item()
190 | seq = torch.arange(max_seq_len, device = start.device).long()
191 | return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
192 |
193 | def mask_from_frac_lengths(
194 | seq_len: Int['b'],
195 | frac_lengths: Float['b'],
196 | max_length: int | None = None
197 | ):
198 | lengths = (frac_lengths * seq_len).long()
199 | max_start = seq_len - lengths
200 |
201 | rand = torch.rand_like(frac_lengths)
202 | start = (max_start * rand).long().clamp(min = 0)
203 | end = start + lengths
204 |
205 | out = mask_from_start_end_indices(seq_len, start, end)
206 |
207 | if exists(max_length):
208 | out = pad_to_length(out, max_length)
209 |
210 | return out
211 |
212 | def maybe_masked_mean(
213 | t: Float['b n d'],
214 | mask: Bool['b n'] | None = None
215 | ) -> Float['b d']:
216 |
217 | if not exists(mask):
218 | return t.mean(dim = 1)
219 |
220 | t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
221 | num = reduce(t, 'b n d -> b d', 'sum')
222 | den = reduce(mask.float(), 'b n -> b', 'sum')
223 |
224 | return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
225 |
226 | def pad_to_length(
227 | t: Tensor,
228 | length: int,
229 | value = None
230 | ):
231 | seq_len = t.shape[-1]
232 | if length > seq_len:
233 | t = F.pad(t, (0, length - seq_len), value = value)
234 |
235 | return t[..., :length]
236 |
237 | def interpolate_1d(
238 | x: Tensor,
239 | length: int,
240 | mode = 'bilinear'
241 | ):
242 | x = rearrange(x, 'n d -> 1 d n 1')
243 | x = F.interpolate(x, (length, 1), mode = mode)
244 | return rearrange(x, '1 d n 1 -> n d')
245 |
246 | # to mel spec
247 |
248 | class MelSpec(Module):
249 | def __init__(
250 | self,
251 | filter_length = 1024,
252 | hop_length = 256,
253 | win_length = 1024,
254 | n_mel_channels = 100,
255 | sampling_rate = 24_000,
256 | normalize = False,
257 | power = 1,
258 | norm = None,
259 | center = True,
260 | ):
261 | super().__init__()
262 | self.n_mel_channels = n_mel_channels
263 | self.sampling_rate = sampling_rate
264 |
265 | self.mel_stft = torchaudio.transforms.MelSpectrogram(
266 | sample_rate = sampling_rate,
267 | n_fft = filter_length,
268 | win_length = win_length,
269 | hop_length = hop_length,
270 | n_mels = n_mel_channels,
271 | power = power,
272 | center = center,
273 | normalized = normalize,
274 | norm = norm,
275 | )
276 |
277 | self.register_buffer('dummy', tensor(0), persistent = False)
278 |
279 | def forward(self, inp):
280 | if len(inp.shape) == 3:
281 | inp = rearrange(inp, 'b 1 nw -> b nw')
282 |
283 | assert len(inp.shape) == 2
284 |
285 | if self.dummy.device != inp.device:
286 | self.to(inp.device)
287 |
288 | mel = self.mel_stft(inp)
289 | mel = log(mel)
290 | return mel
291 |
292 | # convolutional positional generating module
293 | # taken from https://github.com/lucidrains/voicebox-pytorch/blob/main/voicebox_pytorch/voicebox_pytorch.py#L203
294 |
295 | class DepthwiseConv(Module):
296 | def __init__(
297 | self,
298 | dim,
299 | *,
300 | kernel_size,
301 | groups = None
302 | ):
303 | super().__init__()
304 | assert not divisible_by(kernel_size, 2)
305 | groups = default(groups, dim) # full depthwise conv by default
306 |
307 | self.dw_conv1d = nn.Sequential(
308 | nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
309 | nn.SiLU()
310 | )
311 |
312 | def forward(
313 | self,
314 | x,
315 | mask = None
316 | ):
317 |
318 | if exists(mask):
319 | x = einx.where('b n, b n d, -> b n d', mask, x, 0.)
320 |
321 | x = rearrange(x, 'b n c -> b c n')
322 | x = self.dw_conv1d(x)
323 | out = rearrange(x, 'b c n -> b n c')
324 |
325 | if exists(mask):
326 | out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
327 |
328 | return out
329 |
330 | # adaln zero from DiT paper
331 |
332 | class AdaLNZero(Module):
333 | def __init__(
334 | self,
335 | dim,
336 | dim_condition = None,
337 | init_bias_value = -2.
338 | ):
339 | super().__init__()
340 | dim_condition = default(dim_condition, dim)
341 | self.to_gamma = nn.Linear(dim_condition, dim)
342 |
343 | nn.init.zeros_(self.to_gamma.weight)
344 | nn.init.constant_(self.to_gamma.bias, init_bias_value)
345 |
346 | def forward(self, x, *, condition):
347 | if condition.ndim == 2:
348 | condition = rearrange(condition, 'b d -> b 1 d')
349 |
350 | gamma = self.to_gamma(condition).sigmoid()
351 | return x * gamma
352 |
353 | # random projection fourier embedding
354 |
355 | class RandomFourierEmbed(Module):
356 | def __init__(self, dim):
357 | super().__init__()
358 | assert divisible_by(dim, 2)
359 | self.register_buffer('weights', torch.randn(dim // 2))
360 |
361 | def forward(self, x):
362 | freqs = einx.multiply('i, j -> i j', x, self.weights) * 2 * torch.pi
363 | fourier_embed, _ = pack((x, freqs.sin(), freqs.cos()), 'b *')
364 | return fourier_embed
365 |
366 | # linear with fourier embedded outputs
367 |
368 | class LinearFourierEmbed(Module):
369 | def __init__(
370 | self,
371 | dim,
372 | p = 0.5, # percentage of output dimension to fourier, they found 0.5 to be best (0.25 sin + 0.25 cos)
373 | ):
374 | super().__init__()
375 | assert p <= 1.
376 |
377 | dim_fourier = int(p * dim)
378 | dim_rest = dim - (dim_fourier * 2)
379 |
380 | self.linear = nn.Linear(dim, dim_fourier + dim_rest, bias = False)
381 | self.split_dims = (dim_fourier, dim_rest)
382 |
383 | def forward(self, x):
384 | hiddens = self.linear(x)
385 | fourier, rest = hiddens.split(self.split_dims, dim = -1)
386 | return torch.cat((fourier.sin(), fourier.cos(), rest), dim = -1)
387 |
388 | # character embedding
389 |
390 | class CharacterEmbed(Module):
391 | def __init__(
392 | self,
393 | dim,
394 | num_embeds = 256,
395 | ):
396 | super().__init__()
397 | self.dim = dim
398 | self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'
399 |
400 | def forward(
401 | self,
402 | text: Int['b nt'],
403 | max_seq_len: int,
404 | **kwargs
405 | ) -> Float['b n d']:
406 |
407 | text = text + 1 # shift all other token ids up by 1 and use 0 as filler token
408 |
409 | text = text[:, :max_seq_len] # just curtail if character tokens are more than the mel spec tokens, one of the edge cases the paper did not address
410 | text = pad_to_length(text, max_seq_len, value = 0)
411 |
412 | return self.embed(text)
413 |
414 | class InterpolatedCharacterEmbed(Module):
415 | def __init__(
416 | self,
417 | dim,
418 | num_embeds = 256,
419 | ):
420 | super().__init__()
421 | self.dim = dim
422 | self.embed = nn.Embedding(num_embeds, dim)
423 |
424 | self.abs_pos_mlp = Sequential(
425 | Rearrange('... -> ... 1'),
426 | Linear(1, dim),
427 | nn.SiLU(),
428 | Linear(dim, dim)
429 | )
430 |
431 | def forward(
432 | self,
433 | text: Int['b nt'],
434 | max_seq_len: int,
435 | mask: Bool['b n'] | None = None
436 | ) -> Float['b n d']:
437 |
438 | device = text.device
439 |
440 | mask = default(mask, (None,))
441 |
442 | interp_embeds = []
443 | interp_abs_positions = []
444 |
445 | for one_text, one_mask in zip_longest(text, mask):
446 |
447 | valid_text = one_text >= 0
448 | one_text = one_text[valid_text]
449 | one_text_embed = self.embed(one_text)
450 |
451 | # save the absolute positions
452 |
453 | text_seq_len = one_text.shape[0]
454 |
455 | # determine audio sequence length from mask
456 |
457 | audio_seq_len = max_seq_len
458 | if exists(one_mask):
459 | audio_seq_len = one_mask.sum().long().item()
460 |
461 | # interpolate text embedding to audio embedding length
462 |
463 | interp_text_embed = interpolate_1d(one_text_embed, audio_seq_len)
464 | interp_abs_pos = torch.linspace(0, text_seq_len, audio_seq_len, device = device)
465 |
466 | interp_embeds.append(interp_text_embed)
467 | interp_abs_positions.append(interp_abs_pos)
468 |
469 | interp_embeds = pad_sequence(interp_embeds)
470 | interp_abs_positions = pad_sequence(interp_abs_positions)
471 |
472 | interp_embeds = F.pad(interp_embeds, (0, 0, 0, max_seq_len - interp_embeds.shape[-2]))
473 | interp_abs_positions = pad_to_length(interp_abs_positions, max_seq_len)
474 |
475 | # pass interp absolute positions through mlp for implicit positions
476 |
477 | interp_embeds = interp_embeds + self.abs_pos_mlp(interp_abs_positions)
478 |
479 | if exists(mask):
480 | interp_embeds = einx.where('b n, b n d, -> b n d', mask, interp_embeds, 0.)
481 |
482 | return interp_embeds
483 |
484 | # text audio cross conditioning in multistream setup
485 |
486 | class TextAudioCrossCondition(Module):
487 | def __init__(
488 | self,
489 | dim,
490 | dim_text,
491 | cond_audio_to_text = True,
492 | ):
493 | super().__init__()
494 | self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False)
495 | nn.init.zeros_(self.text_to_audio.weight)
496 |
497 | self.cond_audio_to_text = cond_audio_to_text
498 |
499 | if cond_audio_to_text:
500 | self.audio_to_text = nn.Linear(dim + dim_text, dim_text, bias = False)
501 | nn.init.zeros_(self.audio_to_text.weight)
502 |
503 | def forward(
504 | self,
505 | audio: Float['b n d'],
506 | text: Float['b n dt']
507 | ):
508 | audio_text, _ = pack((audio, text), 'b n *')
509 |
510 | text_cond = self.text_to_audio(audio_text)
511 | audio_cond = self.audio_to_text(audio_text) if self.cond_audio_to_text else 0.
512 |
513 | return audio + text_cond, text + audio_cond
514 |
515 | # attention and transformer backbone
516 | # for use in both e2tts as well as duration module
517 |
518 | class Transformer(Module):
519 | @beartype
520 | def __init__(
521 | self,
522 | *,
523 | dim,
524 | dim_text = None, # will default to half of audio dimension
525 | depth = 8,
526 | heads = 8,
527 | dim_head = 64,
528 | ff_mult = 4,
529 | text_depth = None,
530 | text_heads = None,
531 | text_dim_head = None,
532 | text_ff_mult = None,
533 | has_freq_axis = False,
534 | freq_heads = None,
535 | freq_dim_head = None,
536 | cond_on_time = True,
537 | abs_pos_emb = True,
538 | max_seq_len = 8192,
539 | kernel_size = 31,
540 | dropout = 0.1,
541 | num_registers = 32,
542 | scale_residual = False,
543 | attn_laser = False,
544 | attn_laser_softclamp_value = 15.,
545 | attn_fourier_embed_input = False,
546 | attn_fourier_embed_input_frac = 0.25, # https://arxiv.org/abs/2502.21309
547 | num_residual_streams = 4,
548 | attn_kwargs: dict = dict(
549 | gate_value_heads = True,
550 | softclamp_logits = True,
551 | ),
552 | ff_kwargs: dict = dict(),
553 | ):
554 | super().__init__()
555 | assert divisible_by(depth, 2), 'depth needs to be even'
556 |
557 | # absolute positional embedding
558 |
559 | self.max_seq_len = max_seq_len
560 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if abs_pos_emb else None
561 |
562 | self.dim = dim
563 |
564 | # determine text related hparams
565 |
566 | dim_text = default(dim_text, dim // 2)
567 | self.dim_text = dim_text
568 |
569 | text_heads = default(text_heads, heads)
570 | text_dim_head = default(text_dim_head, dim_head)
571 | text_ff_mult = default(text_ff_mult, ff_mult)
572 | text_depth = default(text_depth, depth)
573 |
574 |         assert 1 <= text_depth <= depth, 'must have at least 1 layer of text conditioning, but no more than the total number of speech layers'
575 |
576 | # determine maybe freq axis hparams
577 |
578 | freq_heads = default(freq_heads, heads)
579 | freq_dim_head = default(freq_dim_head, dim_head)
580 |
581 | self.has_freq_axis = has_freq_axis
582 |
583 | # layers
584 |
585 | self.depth = depth
586 | layers = []
587 |
588 | # registers
589 |
590 | self.num_registers = num_registers
591 | self.registers = nn.Parameter(torch.zeros(num_registers, dim))
592 | nn.init.normal_(self.registers, std = 0.02)
593 |
594 | self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text))
595 | nn.init.normal_(self.text_registers, std = 0.02)
596 |
597 | # rotary embedding
598 |
599 | self.rotary_emb = RotaryEmbedding(dim_head)
600 | self.text_rotary_emb = RotaryEmbedding(text_dim_head)
601 |
602 | if has_freq_axis:
603 | self.freq_rotary_emb = RotaryEmbedding(freq_dim_head)
604 |
605 | # hyper connection related
606 |
607 | init_hyper_conn, self.hyper_conn_expand, self.hyper_conn_reduce = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
608 |
609 | hyper_conns = []
610 |
611 | # time conditioning
612 | # will use adaptive rmsnorm
613 |
614 | self.cond_on_time = cond_on_time
615 | rmsnorm_klass = RMSNorm if not cond_on_time else AdaptiveRMSNorm
616 | postbranch_klass = Identity if not cond_on_time else partial(AdaLNZero, dim = dim)
617 |
618 | self.time_cond_mlp = Identity()
619 |
620 | if cond_on_time:
621 | self.time_cond_mlp = Sequential(
622 | RandomFourierEmbed(dim),
623 | Linear(dim + 1, dim),
624 | nn.SiLU()
625 | )
626 |
627 | for ind in range(depth):
628 | is_first_block = ind == 0
629 |
630 | is_later_half = ind >= (depth // 2)
631 | has_text = ind < text_depth
632 |
633 | # speech related
634 |
635 | speech_conv = DepthwiseConv(dim, kernel_size = kernel_size)
636 |
637 | attn_norm = rmsnorm_klass(dim)
638 |
639 | attn_input_fourier_embed = LinearFourierEmbed(dim, p = attn_fourier_embed_input_frac) if attn_fourier_embed_input else nn.Identity()
640 |
641 | attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, laser_softclamp_value = attn_laser_softclamp_value, **attn_kwargs)
642 |
643 | attn_adaln_zero = postbranch_klass()
644 |
645 | ff_norm = rmsnorm_klass(dim)
646 | ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs)
647 | ff_adaln_zero = postbranch_klass()
648 |
649 | skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None
650 |
651 | freq_attn_norm = freq_attn = freq_attn_adaln_zero = None
652 |
653 | if has_freq_axis:
654 | freq_attn_norm = rmsnorm_klass(dim)
655 | freq_attn = Attention(dim = dim, heads = freq_heads, dim_head = freq_dim_head)
656 | freq_attn_adaln_zero = postbranch_klass()
657 |
658 | speech_modules = ModuleList([
659 | skip_proj,
660 | speech_conv,
661 | attn_norm,
662 | attn,
663 | attn_input_fourier_embed,
664 | attn_adaln_zero,
665 | ff_norm,
666 | ff,
667 | ff_adaln_zero,
668 | freq_attn_norm,
669 | freq_attn,
670 | freq_attn_adaln_zero
671 | ])
672 |
673 | speech_hyper_conns = ModuleList([
674 | init_hyper_conn(dim = dim), # conv
675 | init_hyper_conn(dim = dim), # attn
676 | init_hyper_conn(dim = dim), # ff
677 | init_hyper_conn(dim = dim) if has_freq_axis else None
678 | ])
679 |
680 | text_modules = None
681 | text_hyper_conns = None
682 |
683 | if has_text:
684 | # text related
685 |
686 | text_conv = DepthwiseConv(dim_text, kernel_size = kernel_size)
687 |
688 | text_attn_norm = RMSNorm(dim_text)
689 | text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, laser_softclamp_value = attn_laser_softclamp_value, **attn_kwargs)
690 |
691 | text_ff_norm = RMSNorm(dim_text)
692 | text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs)
693 |
694 | # cross condition
695 |
696 | is_last = ind == (text_depth - 1)
697 |
698 | cross_condition = TextAudioCrossCondition(dim = dim, dim_text = dim_text, cond_audio_to_text = not is_last)
699 |
700 | text_modules = ModuleList([
701 | text_conv,
702 | text_attn_norm,
703 | text_attn,
704 | text_ff_norm,
705 | text_ff,
706 | cross_condition
707 | ])
708 |
709 | text_hyper_conns = ModuleList([
710 | init_hyper_conn(dim = dim_text), # conv
711 | init_hyper_conn(dim = dim_text), # attn
712 | init_hyper_conn(dim = dim_text), # ff
713 | ])
714 |
715 | hyper_conns.append(ModuleList([
716 | speech_hyper_conns,
717 | text_hyper_conns
718 | ]))
719 |
720 | layers.append(ModuleList([
721 | speech_modules,
722 | text_modules
723 | ]))
724 |
725 | self.layers = ModuleList(layers)
726 |
727 | self.hyper_conns = ModuleList(hyper_conns)
728 |
729 | self.final_norm = RMSNorm(dim)
730 |
731 | def forward(
732 | self,
733 | x: Float['b n d'] | Float['b f n d'],
734 | times: Float['b'] | Float[''] | None = None,
735 | mask: Bool['b n'] | None = None,
736 | text_embed: Float['b n dt'] | None = None,
737 | ):
738 | orig_batch = x.shape[0]
739 |
740 | assert xnor(x.ndim == 4, self.has_freq_axis), '`has_freq_axis` must be set if passing in tensor with frequency dimension (4 ndims), and not set if passing in only 3'
741 |
742 | freq_seq_len = 1
743 |
744 | if self.has_freq_axis:
745 | freq_seq_len = x.shape[1]
746 | x = rearrange(x, 'b f n d -> (b f) n d')
747 |
748 | if exists(text_embed):
749 | text_embed = repeat(text_embed, 'b ... -> (b f) ...', f = freq_seq_len)
750 |
751 | if exists(mask):
752 | mask = repeat(mask, 'b ... -> (b f) ...', f = freq_seq_len)
753 |
754 | batch, seq_len, device = x.shape[0], x.shape[1], x.device
755 |
756 | assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa'
757 |
758 | # handle absolute positions if needed
759 |
760 | if exists(self.abs_pos_emb):
761 | assert seq_len <= self.max_seq_len, f'{seq_len} exceeds the set `max_seq_len` ({self.max_seq_len}) on Transformer'
762 | seq = torch.arange(seq_len, device = device)
763 | x = x + self.abs_pos_emb(seq)
764 |
765 | # register tokens
766 |
767 | registers = repeat(self.registers, 'r d -> b r d', b = batch)
768 | x, registers_packed_shape = pack((registers, x), 'b * d')
769 |
770 | if exists(mask):
771 | mask = F.pad(mask, (self.num_registers, 0), value = True)
772 |
773 | # handle adaptive rmsnorm kwargs
774 |
775 | norm_kwargs = dict()
776 | freq_norm_kwargs = dict()
777 |
778 | if exists(times):
779 | if times.ndim == 0:
780 | times = repeat(times, ' -> b', b = orig_batch)
781 |
782 | times = self.time_cond_mlp(times)
783 |
784 | if self.has_freq_axis:
785 | freq_times = repeat(times, 'b ... -> (b n) ...', n = x.shape[-2])
786 | freq_norm_kwargs.update(condition = freq_times)
787 |
788 | times = repeat(times, 'b ... -> (b f) ...', f = freq_seq_len)
789 | norm_kwargs.update(condition = times)
790 |
791 | # rotary embedding
792 |
793 | rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2])
794 |
795 | # text related
796 |
797 | if exists(text_embed):
798 | text_rotary_pos_emb = self.text_rotary_emb.forward_from_seq_len(x.shape[-2])
799 |
800 | text_registers = repeat(self.text_registers, 'r d -> b r d', b = batch)
801 | text_embed, _ = pack((text_registers, text_embed), 'b * d')
802 |
803 | if self.has_freq_axis:
804 | freq_rotary_pos_emb = self.freq_rotary_emb.forward_from_seq_len(freq_seq_len)
805 |
806 | # skip connection related stuff
807 |
808 | skips = []
809 |
810 | # value residual
811 |
812 | text_attn_first_values = None
813 | freq_attn_first_values = None
814 | attn_first_values = None
815 |
816 | # expand hyper connections
817 |
818 | x = self.hyper_conn_expand(x)
819 |
820 | if exists(text_embed):
821 | text_embed = self.hyper_conn_expand(text_embed)
822 |
823 | # go through the layers
824 |
825 | for ind, ((speech_modules, text_modules), (speech_residual_fns, text_residual_fns)) in enumerate(zip(self.layers, self.hyper_conns)):
826 |
827 | layer = ind + 1
828 |
829 | (
830 | maybe_skip_proj,
831 | speech_conv,
832 | attn_norm,
833 | attn,
834 | attn_input_fourier_embed,
835 | maybe_attn_adaln_zero,
836 | ff_norm,
837 | ff,
838 | maybe_ff_adaln_zero,
839 | maybe_freq_attn_norm,
840 | maybe_freq_attn,
841 | maybe_freq_attn_adaln_zero
842 | ) = speech_modules
843 |
844 | (
845 | conv_residual,
846 | attn_residual,
847 | ff_residual,
848 | maybe_freq_attn_residual
849 | ) = speech_residual_fns
850 |
851 | # smaller text transformer
852 |
853 | if exists(text_embed) and exists(text_modules):
854 |
855 | (
856 | text_conv,
857 | text_attn_norm,
858 | text_attn,
859 | text_ff_norm,
860 | text_ff,
861 | cross_condition
862 | ) = text_modules
863 |
864 | (
865 | text_conv_residual,
866 | text_attn_residual,
867 | text_ff_residual
868 | ) = text_residual_fns
869 |
870 | text_embed, add_residual = text_conv_residual(text_embed)
871 | text_embed = text_conv(text_embed, mask = mask)
872 | text_embed = add_residual(text_embed)
873 |
874 | text_embed, add_residual = text_attn_residual(text_embed)
875 | text_attn_out, text_attn_inter = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = text_attn_first_values)
876 | text_embed = add_residual(text_attn_out)
877 |
878 | text_attn_first_values = default(text_attn_first_values, text_attn_inter.values)
879 |
880 | text_embed, add_residual = text_ff_residual(text_embed)
881 | text_embed = text_ff(text_ff_norm(text_embed))
882 | text_embed = add_residual(text_embed)
883 | x, text_embed = cross_condition(x, text_embed)
884 |
885 | # skip connection logic
886 |
887 | is_first_half = layer <= (self.depth // 2)
888 | is_later_half = not is_first_half
889 |
890 | if is_first_half:
891 | skips.append(x)
892 |
893 | if is_later_half:
894 | skip = skips.pop()
895 | x = torch.cat((x, skip), dim = -1)
896 | x = maybe_skip_proj(x)
897 |
898 | # position generating convolution
899 |
900 | x, add_residual = conv_residual(x)
901 | x = speech_conv(x, mask = mask)
902 | x = add_residual(x)
903 |
904 | # attention
905 |
906 | x, add_residual = attn_residual(x)
907 |
908 | x = attn_norm(x, **norm_kwargs)
909 | x = attn_input_fourier_embed(x)
910 |
911 | attn_out, attn_inter = attn(x, rotary_pos_emb = rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = attn_first_values)
912 |
913 | attn_out = maybe_attn_adaln_zero(attn_out, **norm_kwargs)
914 | x = add_residual(attn_out)
915 |
916 | attn_first_values = default(attn_first_values, attn_inter.values)
917 |
918 | # attention across frequency tokens, if needed
919 |
920 | if self.has_freq_axis:
921 |
922 | x, add_residual = maybe_freq_attn_residual(x)
923 |
924 | x = rearrange(x, '(b f) n d -> (b n) f d', b = orig_batch)
925 |
926 | attn_out, attn_inter = maybe_freq_attn(maybe_freq_attn_norm(x, **freq_norm_kwargs), rotary_pos_emb = freq_rotary_pos_emb, return_intermediates = True, value_residual = freq_attn_first_values)
927 | attn_out = maybe_freq_attn_adaln_zero(attn_out, **freq_norm_kwargs)
928 |
929 | attn_out = rearrange(attn_out, '(b n) f d -> (b f) n d', b = orig_batch)
930 |
931 | x = add_residual(attn_out)
932 | freq_attn_first_values = default(freq_attn_first_values, attn_inter.values)
933 |
934 | # feedforward
935 |
936 | x, add_residual = ff_residual(x)
937 | ff_out = ff(ff_norm(x, **norm_kwargs))
938 | ff_out = maybe_ff_adaln_zero(ff_out, **norm_kwargs)
939 | x = add_residual(ff_out)
940 |
941 | assert len(skips) == 0
942 |
943 | _, x = unpack(x, registers_packed_shape, 'b * d')
944 |
945 | # sum all residual streams from hyper connections
946 |
947 | x = self.hyper_conn_reduce(x)
948 |
949 | if self.has_freq_axis:
950 | x = rearrange(x, '(b f) n d -> b f n d', f = freq_seq_len)
951 |
952 | return self.final_norm(x)
953 |
954 | # main classes
955 |
956 | class DurationPredictor(Module):
957 | @beartype
958 | def __init__(
959 | self,
960 | transformer: dict | Transformer,
961 | num_channels = None,
962 | mel_spec_kwargs: dict = dict(),
963 | char_embed_kwargs: dict = dict(),
964 | text_num_embeds = None,
965 | num_freq_tokens = 1,
966 | hl_gauss_loss: dict | None = None,
967 | use_regression = True,
968 | tokenizer: (
969 | Literal['char_utf8', 'phoneme_en'] |
970 | Callable[[list[str]], Int['b nt']]
971 | ) = 'char_utf8'
972 | ):
973 | super().__init__()
974 |
975 | # freq axis hparams
976 |
977 | assert num_freq_tokens > 0
978 | self.num_freq_tokens = num_freq_tokens
979 | self.has_freq_axis = num_freq_tokens > 1
980 |
981 | if isinstance(transformer, dict):
982 | set_if_missing_key(transformer, 'has_freq_axis', self.has_freq_axis)
983 |
984 | transformer = Transformer(
985 | **transformer,
986 | cond_on_time = False
987 | )
988 |
989 | assert transformer.has_freq_axis == self.has_freq_axis
990 |
991 | # mel spec
992 |
993 | self.mel_spec = MelSpec(**mel_spec_kwargs)
994 | self.num_channels = default(num_channels, self.mel_spec.n_mel_channels)
995 |
996 | self.transformer = transformer
997 |
998 | dim = transformer.dim
999 | dim_text = transformer.dim_text
1000 |
1001 | self.dim = dim
1002 |
1003 | # projecting depends on whether frequency axis is needed
1004 |
1005 | if not self.has_freq_axis:
1006 | self.proj_in = Linear(self.num_channels, self.dim)
1007 | else:
1008 | self.proj_in = nn.Sequential(
1009 | Linear(self.num_channels, self.dim * num_freq_tokens),
1010 | Rearrange('b n (f d) -> b f n d', f = num_freq_tokens)
1011 | )
1012 |
1013 | # tokenizer and text embed
1014 |
1015 | if callable(tokenizer):
1016 | assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
1017 | self.tokenizer = tokenizer
1018 | elif tokenizer == 'char_utf8':
1019 | text_num_embeds = 256
1020 | self.tokenizer = list_str_to_tensor
1021 | elif tokenizer == 'phoneme_en':
1022 | self.tokenizer, text_num_embeds = get_g2p_en_encode()
1023 | else:
1024 | raise ValueError(f'unknown tokenizer string {tokenizer}')
1025 |
1026 | self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
1027 |
1028 | # maybe reduce frequencies
1029 |
1030 | self.maybe_reduce_freq_axis = Reduce('b f n d -> b n d', 'mean') if self.has_freq_axis else nn.Identity()
1031 |
1032 | # to prediction
1033 | # applying https://arxiv.org/abs/2403.03950
1034 |
1035 | self.hl_gauss_layer = HLGaussLayer(
1036 | dim,
1037 | hl_gauss_loss = hl_gauss_loss,
1038 | use_regression = use_regression,
1039 | regress_activation = nn.Softplus()
1040 | )
1041 |
1042 | def forward(
1043 | self,
1044 | x: Float['b n d'] | Float['b nw'],
1045 | *,
1046 | text: Int['b nt'] | list[str] | None = None,
1047 | lens: Int['b'] | None = None,
1048 | return_loss = True
1049 | ):
1050 | # raw wave
1051 |
1052 | if x.ndim == 2:
1053 | x = self.mel_spec(x)
1054 | x = rearrange(x, 'b d n -> b n d')
1055 |             assert x.shape[-1] == self.num_channels
1056 |
1057 | x = self.proj_in(x)
1058 |
1059 | batch, seq_len, device = x.shape[0], x.shape[-2], x.device
1060 |
1061 | # text
1062 |
1063 | text_embed = None
1064 |
1065 | if exists(text):
1066 | if isinstance(text, list):
1067 | text = list_str_to_tensor(text).to(device)
1068 | assert text.shape[0] == batch
1069 |
1070 | text_embed = self.embed_text(text, seq_len)
1071 |
1072 | # handle lengths (duration)
1073 |
1074 | if not exists(lens):
1075 | lens = torch.full((batch,), seq_len, device = device)
1076 |
1077 | mask = lens_to_mask(lens, length = seq_len)
1078 |
1079 | # if returning a loss, mask out randomly from an index and have it predict the duration
1080 |
1081 | if return_loss:
1082 | rand_frac_index = x.new_zeros(batch).uniform_(0, 1)
1083 | rand_index = (rand_frac_index * lens).long()
1084 |
1085 | seq = torch.arange(seq_len, device = device)
1086 | mask &= einx.less('n, b -> b n', seq, rand_index)
1087 |
1088 | # attending
1089 |
1090 | embed = self.transformer(
1091 | x,
1092 | mask = mask,
1093 | text_embed = text_embed,
1094 | )
1095 |
1096 | # maybe reduce freq
1097 |
1098 | embed = self.maybe_reduce_freq_axis(embed)
1099 |
1100 | # masked mean
1101 |
1102 | pooled_embed = maybe_masked_mean(embed, mask)
1103 |
1104 | # return the prediction if not returning loss
1105 |
1106 | if not return_loss:
1107 | return self.hl_gauss_layer(pooled_embed)
1108 |
1109 | # loss
1110 |
1111 | loss = self.hl_gauss_layer(pooled_embed, lens.float())
1112 |
1113 | return loss
1114 |
1115 | class E2TTS(Module):
1116 |
1117 | @beartype
1118 | def __init__(
1119 | self,
1120 | transformer: dict | Transformer = None,
1121 | duration_predictor: dict | DurationPredictor | None = None,
1122 | odeint_kwargs: dict = dict(
1123 | atol = 1e-5,
1124 | rtol = 1e-5,
1125 | method = 'midpoint'
1126 | ),
1127 | cond_drop_prob = 0.25,
1128 | num_channels = None,
1129 | mel_spec_module: Module | None = None,
1130 | num_freq_tokens = 1,
1131 | char_embed_kwargs: dict = dict(),
1132 | mel_spec_kwargs: dict = dict(),
1133 | frac_lengths_mask: tuple[float, float] = (0.7, 1.),
1134 | concat_cond = False,
1135 | interpolated_text = False,
1136 | text_num_embeds: int | None = None,
1137 | tokenizer: (
1138 | Literal['char_utf8', 'phoneme_en'] |
1139 | Callable[[list[str]], Int['b nt']]
1140 | ) = 'char_utf8',
1141 | use_vocos = True,
1142 | pretrained_vocos_path = 'charactr/vocos-mel-24khz',
1143 | sampling_rate: int | None = None,
1144 | velocity_consistency_weight = 0.,
1145 | ):
1146 | super().__init__()
1147 |
1148 | # freq axis hparams
1149 |
1150 | assert num_freq_tokens > 0
1151 | self.num_freq_tokens = num_freq_tokens
1152 | self.has_freq_axis = num_freq_tokens > 1
1153 |
1154 | # set transformer
1155 |
1156 | if isinstance(transformer, dict):
1157 | set_if_missing_key(transformer, 'has_freq_axis', self.has_freq_axis)
1158 |
1159 | transformer = Transformer(
1160 | **transformer,
1161 | cond_on_time = True
1162 | )
1163 |
1164 | assert transformer.has_freq_axis == self.has_freq_axis
1165 | self.transformer = transformer
1166 |
1167 | # duration predictor
1168 |
1169 | if isinstance(duration_predictor, dict):
1170 | duration_predictor = DurationPredictor(**duration_predictor)
1171 |
1172 | # hparams
1173 |
1174 | dim = transformer.dim
1175 | dim_text = transformer.dim_text
1176 |
1177 | self.dim = dim
1178 | self.dim_text = dim_text
1179 |
1180 | self.frac_lengths_mask = frac_lengths_mask
1181 |
1182 | self.duration_predictor = duration_predictor
1183 |
1184 | # sampling
1185 |
1186 | self.odeint_kwargs = odeint_kwargs
1187 |
1188 | # mel spec
1189 |
1190 | self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
1191 | num_channels = default(num_channels, self.mel_spec.n_mel_channels)
1192 |
1193 | self.num_channels = num_channels
1194 | self.sampling_rate = default(sampling_rate, getattr(self.mel_spec, 'sampling_rate', None))
1195 |
1196 | # whether to concat condition and project rather than project both and sum
1197 |
1198 | self.concat_cond = concat_cond
1199 |
1200 | if concat_cond:
1201 | self.proj_in = nn.Linear(num_channels * 2, dim * num_freq_tokens)
1202 | else:
1203 | self.proj_in = nn.Linear(num_channels, dim * num_freq_tokens)
1204 | self.cond_proj_in = nn.Linear(num_channels, dim * num_freq_tokens)
1205 |
1206 | # maybe split out frequency
1207 |
1208 | self.maybe_split_freq = Rearrange('b n (f d) -> b f n d', f = num_freq_tokens) if self.has_freq_axis else nn.Identity()
1209 |
1210 | self.maybe_reduce_freq = Reduce('b f n d -> b n d', 'mean') if self.has_freq_axis else nn.Identity()
1211 |
1212 | # to prediction
1213 |
1214 | self.to_pred = Linear(dim, num_channels)
1215 |
1216 | # tokenizer and text embed
1217 |
1218 | if callable(tokenizer):
1219 | assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
1220 | self.tokenizer = tokenizer
1221 | elif tokenizer == 'char_utf8':
1222 | text_num_embeds = 256
1223 | self.tokenizer = list_str_to_tensor
1224 | elif tokenizer == 'phoneme_en':
1225 | self.tokenizer, text_num_embeds = get_g2p_en_encode()
1226 | else:
1227 | raise ValueError(f'unknown tokenizer string {tokenizer}')
1228 |
1229 | self.cond_drop_prob = cond_drop_prob
1230 |
1231 | # text embedding
1232 |
1233 | text_embed_klass = CharacterEmbed if not interpolated_text else InterpolatedCharacterEmbed
1234 |
1235 | self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
1236 |
1237 | # weight for velocity consistency
1238 |
1239 | self.register_buffer('zero', tensor(0.), persistent = False)
1240 | self.velocity_consistency_weight = velocity_consistency_weight
1241 |
1242 | # default vocos for mel -> audio
1243 |
1244 | self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None
1245 |
1246 | @property
1247 | def device(self):
1248 | return next(self.parameters()).device
1249 |
1250 | def transformer_with_pred_head(
1251 | self,
1252 | x: Float['b n d'],
1253 | cond: Float['b n d'],
1254 | times: Float['b'],
1255 | mask: Bool['b n'] | None = None,
1256 | text: Int['b nt'] | None = None,
1257 | drop_text_cond: bool | None = None,
1258 | return_drop_text_cond = False
1259 | ):
1260 | seq_len = x.shape[-2]
1261 | drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob)
1262 |
1263 | if self.concat_cond:
1264 | # concat condition, given as using voicebox-like scheme
1265 | x = torch.cat((cond, x), dim = -1)
1266 |
1267 | x = self.proj_in(x)
1268 | x = self.maybe_split_freq(x)
1269 |
1270 | if not self.concat_cond:
1271 | # an alternative is to simply sum the condition
1272 | # seems to work fine
1273 |
1274 | cond = self.cond_proj_in(cond)
1275 | cond = self.maybe_split_freq(cond)
1276 |
1277 | x = x + cond
1278 |
1279 | # whether to use a text embedding
1280 |
1281 | text_embed = None
1282 | if exists(text) and not drop_text_cond:
1283 | text_embed = self.embed_text(text, seq_len, mask = mask)
1284 |
1285 | # attend
1286 |
1287 | embed = self.transformer(
1288 | x,
1289 | times = times,
1290 | mask = mask,
1291 | text_embed = text_embed
1292 | )
1293 |
1294 | embed = self.maybe_reduce_freq(embed)
1295 |
1296 | pred = self.to_pred(embed)
1297 |
1298 | if not return_drop_text_cond:
1299 | return pred
1300 |
1301 | return pred, drop_text_cond
1302 |
1303 | def cfg_transformer_with_pred_head(
1304 | self,
1305 | *args,
1306 | cfg_strength: float = 1.,
1307 | cfg_null_model: E2TTS | None = None,
1308 | remove_parallel_component: bool = True,
1309 | keep_parallel_frac: float = 0.,
1310 | **kwargs,
1311 | ):
1312 |
1313 | pred = self.transformer_with_pred_head(*args, drop_text_cond = False, **kwargs)
1314 |
1315 | if cfg_strength < 1e-5:
1316 | return pred
1317 |
1318 | null_drop_text_cond = not exists(cfg_null_model)
1319 | cfg_null_model = default(cfg_null_model, self)
1320 |
1321 | null_pred = cfg_null_model.transformer_with_pred_head(*args, drop_text_cond = null_drop_text_cond, **kwargs)
1322 |
1323 | cfg_update = pred - null_pred
1324 |
1325 | if remove_parallel_component:
1326 | # https://arxiv.org/abs/2410.02416
1327 | parallel, orthogonal = project(cfg_update, pred)
1328 | cfg_update = orthogonal + parallel * keep_parallel_frac
1329 |
1330 | return pred + cfg_update * cfg_strength
1331 |
1332 | @torch.no_grad()
1333 | def sample(
1334 | self,
1335 | cond: Float['b n d'] | Float['b nw'],
1336 | *,
1337 | text: Int['b nt'] | list[str] | None = None,
1338 | lens: Int['b'] | None = None,
1339 | duration: int | Int['b'] | None = None,
1340 | steps = 32,
1341 | cfg_strength = 1., # they used a classifier free guidance strength of 1.
1342 | cfg_null_model: E2TTS | None = None, # for "autoguidance" from Karras et al. https://arxiv.org/abs/2406.02507
1343 | max_duration = 4096, # in case the duration predictor goes haywire
1344 | vocoder: Callable[[Float['b d n']], list[Float['_']]] | None = None,
1345 | return_raw_output: bool | None = None,
1346 | save_to_filename: str | None = None
1347 | ) -> (
1348 | Float['b n d'],
1349 | list[Float['_']]
1350 | ):
1351 | self.eval()
1352 |
1353 | # raw wave
1354 |
1355 | if cond.ndim == 2:
1356 | cond = self.mel_spec(cond)
1357 | cond = rearrange(cond, 'b d n -> b n d')
1358 | assert cond.shape[-1] == self.num_channels
1359 |
1360 | batch, cond_seq_len, device = *cond.shape[:2], cond.device
1361 |
1362 | if not exists(lens):
1363 | lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
1364 |
1365 | # text
1366 |
1367 | if isinstance(text, list):
1368 | text = self.tokenizer(text).to(device)
1369 | assert text.shape[0] == batch
1370 |
1371 | if exists(text):
1372 | text_lens = (text != -1).sum(dim = -1)
1373 | lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
1374 |
1375 | # duration
1376 |
1377 | cond_mask = lens_to_mask(lens)
1378 |
1379 | if exists(duration):
1380 | if isinstance(duration, int):
1381 | duration = torch.full((batch,), duration, device = device, dtype = torch.long)
1382 |
1383 | elif exists(self.duration_predictor):
1384 | duration = self.duration_predictor(cond, text = text, lens = lens, return_loss = False).long()
1385 |
1386 | duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
1387 | duration = duration.clamp(max = max_duration)
1388 |
1389 | assert duration.shape[0] == batch
1390 |
1391 | max_duration = duration.amax()
1392 |
1393 | cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
1394 | cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
1395 | cond_mask = rearrange(cond_mask, '... -> ... 1')
1396 |
1397 | mask = lens_to_mask(duration)
1398 |
1399 | # neural ode
1400 |
1401 | def fn(t, x):
1402 | # at each step, conditioning is fixed
1403 |
1404 | step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
1405 |
1406 | # predict flow
1407 |
1408 | return self.cfg_transformer_with_pred_head(
1409 | x,
1410 | step_cond,
1411 | times = t,
1412 | text = text,
1413 | mask = mask,
1414 | cfg_strength = cfg_strength,
1415 | cfg_null_model = cfg_null_model
1416 | )
1417 |
1418 | y0 = torch.randn_like(cond)
1419 | t = torch.linspace(0, 1, steps, device = self.device)
1420 |
1421 | trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
1422 | sampled = trajectory[-1]
1423 |
1424 | out = sampled
1425 |
1426 | out = torch.where(cond_mask, cond, out)
1427 |
1428 |         # optionally return the raw untransformed output, e.g. if not using a mel representation
1429 |
1430 | if exists(return_raw_output) and return_raw_output:
1431 | return out
1432 |
1433 | # take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on
1434 |
1435 | if exists(vocoder):
1436 | assert not exists(self.vocos), '`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling'
1437 | out = rearrange(out, 'b n d -> b d n')
1438 | out = vocoder(out)
1439 |
1440 | elif exists(self.vocos):
1441 |
1442 | audio = []
1443 | for mel, one_mask in zip(out, mask):
1444 | one_out = DB_to_amplitude(mel[one_mask], ref = 1., power = 0.5)
1445 |
1446 | one_out = rearrange(one_out, 'n d -> 1 d n')
1447 | one_audio = self.vocos.decode(one_out)
1448 | one_audio = rearrange(one_audio, '1 nw -> nw')
1449 | audio.append(one_audio)
1450 |
1451 | out = audio
1452 |
1453 | if exists(save_to_filename):
1454 | assert exists(vocoder) or exists(self.vocos)
1455 | assert exists(self.sampling_rate)
1456 |
1457 | path = Path(save_to_filename)
1458 | parent_path = path.parents[0]
1459 | parent_path.mkdir(exist_ok = True, parents = True)
1460 |
1461 | for ind, one_audio in enumerate(out):
1462 | one_audio = rearrange(one_audio, 'nw -> 1 nw')
1463 | save_path = str(parent_path / f'{ind + 1}.{path.name}')
1464 | torchaudio.save(save_path, one_audio.detach().cpu(), sample_rate = self.sampling_rate)
1465 |
1466 | return out
1467 |
1468 | def forward(
1469 | self,
1470 | inp: Float['b n d'] | Float['b nw'], # mel or raw wave
1471 | *,
1472 | text: Int['b nt'] | list[str] | None = None,
1473 | times: Int['b'] | None = None,
1474 | lens: Int['b'] | None = None,
1475 | velocity_consistency_model: E2TTS | None = None,
1476 | velocity_consistency_delta = 1e-5
1477 | ):
1478 | need_velocity_loss = exists(velocity_consistency_model) and self.velocity_consistency_weight > 0.
1479 |
1480 | # handle raw wave
1481 |
1482 | if inp.ndim == 2:
1483 | inp = self.mel_spec(inp)
1484 | inp = rearrange(inp, 'b d n -> b n d')
1485 | assert inp.shape[-1] == self.num_channels
1486 |
1487 | batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device
1488 |
1489 | # handle text as string
1490 |
1491 | if isinstance(text, list):
1492 | text = self.tokenizer(text).to(device)
1493 | assert text.shape[0] == batch
1494 |
1495 | # lens and mask
1496 |
1497 | if not exists(lens):
1498 | lens = torch.full((batch,), seq_len, device = device)
1499 |
1500 | mask = lens_to_mask(lens, length = seq_len)
1501 |
1502 | # get a random span to mask out for training conditionally
1503 |
1504 | frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
1505 | rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length = seq_len)
1506 |
1507 | if exists(mask):
1508 | rand_span_mask &= mask
1509 |
1510 | # mel is x1
1511 |
1512 | x1 = inp
1513 |
1514 | # main conditional flow training logic
1515 | # just ~5 loc
1516 |
1517 | # x0 is gaussian noise
1518 |
1519 | x0 = torch.randn_like(x1)
1520 |
1521 | # t is random times from above
1522 |
1523 | times = torch.rand((batch,), dtype = dtype, device = self.device)
1524 | t = rearrange(times, 'b -> b 1 1')
1525 |
1526 | # if need velocity consistency, make sure time does not exceed 1.
1527 |
1528 | if need_velocity_loss:
1529 | t = t * (1. - velocity_consistency_delta)
1530 |
1531 | # sample xt (w in the paper)
1532 |
1533 | w = (1. - t) * x0 + t * x1
1534 |
1535 | flow = x1 - x0
1536 |
1537 | # only predict what is within the random mask span for infilling
1538 |
1539 | cond = einx.where(
1540 | 'b n, b n d, b n d -> b n d',
1541 | rand_span_mask,
1542 | torch.zeros_like(x1), x1
1543 | )
1544 |
1545 | # transformer and prediction head
1546 |
1547 | pred, did_drop_text_cond = self.transformer_with_pred_head(
1548 | w,
1549 | cond,
1550 | times = times,
1551 | text = text,
1552 | mask = mask,
1553 | return_drop_text_cond = True
1554 | )
1555 |
1556 | # maybe velocity consistency loss
1557 |
1558 | velocity_loss = self.zero
1559 |
1560 | if need_velocity_loss:
1561 |
1562 | t_with_delta = t + velocity_consistency_delta
1563 | w_with_delta = (1. - t_with_delta) * x0 + t_with_delta * x1
1564 |
1565 | with torch.no_grad():
1566 | ema_pred = velocity_consistency_model.transformer_with_pred_head(
1567 | w_with_delta,
1568 | cond,
1569 | times = times + velocity_consistency_delta,
1570 | text = text,
1571 | mask = mask,
1572 | drop_text_cond = did_drop_text_cond
1573 | )
1574 |
1575 | velocity_loss = F.mse_loss(pred, ema_pred, reduction = 'none')
1576 | velocity_loss = velocity_loss[rand_span_mask].mean()
1577 |
1578 | # flow matching loss
1579 |
1580 | loss = F.mse_loss(pred, flow, reduction = 'none')
1581 |
1582 | loss = loss[rand_span_mask].mean()
1583 |
1584 |         # total loss and loss breakdown
1585 |
1586 | total_loss = (
1587 | loss +
1588 | velocity_loss * self.velocity_consistency_weight
1589 | )
1590 |
1591 | breakdown = LossBreakdown(loss, velocity_loss)
1592 |
1593 | # return total loss and bunch of intermediates
1594 |
1595 | return E2TTSReturn(total_loss, cond, pred, x0 + pred, breakdown)
1596 |
--------------------------------------------------------------------------------
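A minimal sketch of the two entry points defined in e2_tts.py above: calling the model directly for the conditional flow matching loss, and `E2TTS.sample` for inference. Raw waveforms are used so the mel channel count never has to be specified; the model dimensions, text, audio length and sampling duration are illustrative only, and the default tokenizer is assumed to accept plain strings.

import torch
from e2_tts_pytorch import E2TTS

e2tts = E2TTS(
    transformer = dict(
        dim = 512,
        depth = 4
    )
)

# training step: raw waves of shape (batch, samples) are converted to mel internally;
# the forward pass returns a namedtuple whose first element is the total loss

audio = torch.randn(2, 24_000)

loss, *_ = e2tts(audio, text = ['hello world', 'e2 tts'])
loss.backward()

# inference: condition on the same raw waves and request a fixed duration in mel frames

sampled = e2tts.sample(audio, text = ['hello world', 'e2 tts'], duration = 256, steps = 16)

--------------------------------------------------------------------------------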
/e2_tts_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from tqdm import tqdm
5 | import matplotlib
6 | matplotlib.use("Agg")
7 | import matplotlib.pylab as plt
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.optim import Optimizer
12 | from torch.utils.data import DataLoader, Dataset
13 | from torch.utils.tensorboard import SummaryWriter
14 | from torch.optim.lr_scheduler import LinearLR, SequentialLR
15 |
16 | import torchaudio
17 |
18 | from einops import rearrange
19 |
20 | from accelerate import Accelerator
21 | from accelerate.utils import DistributedDataParallelKwargs
22 |
23 | from adam_atan2_pytorch.adopt import Adopt
24 |
25 | from ema_pytorch import EMA
26 |
27 | from loguru import logger
28 |
29 | from e2_tts_pytorch.e2_tts import (
30 | E2TTS,
31 | DurationPredictor,
32 | MelSpec
33 | )
34 |
35 | def exists(v):
36 | return v is not None
37 |
38 | def default(v, d):
39 | return v if exists(v) else d
40 |
41 | def to_numpy(t):
42 | return t.detach().cpu().numpy()
43 |
44 | # plot spectrogram
45 |
46 | def plot_spectrogram(spectrogram):
47 | spectrogram = to_numpy(spectrogram)
48 | fig, ax = plt.subplots(figsize=(10, 4))
49 | im = ax.imshow(spectrogram.T, aspect="auto", origin="lower", interpolation="none")
50 | plt.colorbar(im, ax=ax)
51 | plt.xlabel("Frames")
52 | plt.ylabel("Channels")
53 | plt.tight_layout()
54 |
55 | fig.canvas.draw()
56 | plt.close()
57 | return fig
58 |
59 | # collation
60 |
61 | def collate_fn(batch):
62 | mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
63 | mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
64 | max_mel_length = mel_lengths.amax()
65 |
66 | padded_mel_specs = []
67 | for spec in mel_specs:
68 | padding = (0, max_mel_length - spec.size(-1))
69 | padded_spec = F.pad(spec, padding, value = 0)
70 | padded_mel_specs.append(padded_spec)
71 |
72 | mel_specs = torch.stack(padded_mel_specs)
73 |
74 | text = [item['text'] for item in batch]
75 | text_lengths = torch.LongTensor([len(item) for item in text])
76 |
77 | return dict(
78 | mel = mel_specs,
79 | mel_lengths = mel_lengths,
80 | text = text,
81 | text_lengths = text_lengths,
82 | )
83 |
84 | # dataset
85 |
86 | class HFDataset(Dataset):
87 | def __init__(
88 | self,
89 | hf_dataset: Dataset,
90 | target_sample_rate = 24_000,
91 | hop_length = 256
92 | ):
93 | self.data = hf_dataset
94 | self.target_sample_rate = target_sample_rate
95 | self.hop_length = hop_length
96 | self.mel_spectrogram = MelSpec(sampling_rate=target_sample_rate)
97 |
98 | def __len__(self):
99 | return len(self.data)
100 |
101 | def __getitem__(self, index):
102 | row = self.data[index]
103 | audio = row['audio']['array']
104 |
105 | logger.info(f"Audio shape: {audio.shape}")
106 |
107 | sample_rate = row['audio']['sampling_rate']
108 | duration = audio.shape[-1] / sample_rate
109 |
110 | if duration > 20 or duration < 0.3:
111 | logger.warning(f"Skipping due to duration out of bound: {duration}")
112 | return self.__getitem__((index + 1) % len(self.data))
113 |
114 | audio_tensor = torch.from_numpy(audio).float()
115 |
116 | if sample_rate != self.target_sample_rate:
117 | resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
118 | audio_tensor = resampler(audio_tensor)
119 |
120 | audio_tensor = rearrange(audio_tensor, 't -> 1 t')
121 |
122 | mel_spec = self.mel_spectrogram(audio_tensor)
123 |
124 | mel_spec = rearrange(mel_spec, '1 d t -> d t')
125 |
126 | text = row['transcript']
127 |
128 | return dict(
129 | mel_spec = mel_spec,
130 | text = text,
131 | )
132 |
133 | # trainer
134 |
135 | class E2Trainer:
136 | def __init__(
137 | self,
138 | model: E2TTS,
139 | optimizer: Optimizer | None = None,
140 | learning_rate = 7.5e-5,
141 | num_warmup_steps = 20000,
142 | grad_accumulation_steps = 1,
143 | duration_predictor: DurationPredictor | None = None,
144 | checkpoint_path = None,
145 | log_file = "logs.txt",
146 | max_grad_norm = 1.0,
147 | sample_rate = 22050,
148 | tensorboard_log_dir = 'runs/e2_tts_experiment',
149 | accelerate_kwargs: dict = dict(),
150 | ema_kwargs: dict = dict(),
151 | use_switch_ema = False
152 | ):
153 | logger.add(log_file)
154 |
155 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
156 |
157 | self.accelerator = Accelerator(
158 | log_with = "all",
159 | kwargs_handlers = [ddp_kwargs],
160 | gradient_accumulation_steps = grad_accumulation_steps,
161 | **accelerate_kwargs
162 | )
163 |
164 | self.target_sample_rate = sample_rate
165 |
166 | self.model = model
167 |
168 | self.need_velocity_consistent_loss = model.velocity_consistency_weight > 0.
169 |
170 | self.ema_model = EMA(
171 | model,
172 | include_online_model = False,
173 | **ema_kwargs
174 | )
175 |
176 | self.use_switch_ema = use_switch_ema
177 |
178 | self.duration_predictor = duration_predictor
179 |
180 | # optimizer
181 |
182 | if not exists(optimizer):
183 | optimizer = Adopt(model.parameters(), lr = learning_rate)
184 |
185 | self.optimizer = optimizer
186 |
187 | self.num_warmup_steps = num_warmup_steps
188 | self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate)
189 |
190 | self.ema_model, self.model, self.optimizer = self.accelerator.prepare(
191 | self.ema_model, self.model, self.optimizer
192 | )
193 | self.max_grad_norm = max_grad_norm
194 |
195 | self.checkpoint_path = default(checkpoint_path, 'model.pth')
196 | self.writer = SummaryWriter(log_dir=tensorboard_log_dir)
197 |
198 | @property
199 | def is_main(self):
200 | return self.accelerator.is_main_process
201 |
202 | def save_checkpoint(self, step, finetune=False):
203 | self.accelerator.wait_for_everyone()
204 | if self.is_main:
205 | checkpoint = dict(
206 | model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
207 | optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
208 | ema_model_state_dict = self.ema_model.state_dict(),
209 | scheduler_state_dict = self.scheduler.state_dict(),
210 | step = step
211 | )
212 |
213 | self.accelerator.save(checkpoint, self.checkpoint_path)
214 |
215 | def load_checkpoint(self):
216 | if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path):
217 | return 0
218 |
219 | checkpoint = torch.load(self.checkpoint_path)
220 | self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
221 | self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
222 |
223 | if self.is_main:
224 | self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
225 |
226 | if self.scheduler:
227 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
228 | return checkpoint['step']
229 |
230 | def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=1000):
231 |
232 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers, pin_memory=True)
233 | total_steps = len(train_dataloader) * epochs
234 | decay_steps = total_steps - self.num_warmup_steps
235 | warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=self.num_warmup_steps)
236 | decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
237 | self.scheduler = SequentialLR(self.optimizer,
238 | schedulers=[warmup_scheduler, decay_scheduler],
239 | milestones=[self.num_warmup_steps])
240 | train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler)
241 | start_step = self.load_checkpoint()
242 | global_step = start_step
243 |
244 | for epoch in range(epochs):
245 | self.model.train()
246 | progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
247 | epoch_loss = 0.0
248 |
249 | for batch in progress_bar:
250 | with self.accelerator.accumulate(self.model):
251 | text_inputs = batch['text']
252 | mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
253 | mel_lengths = batch["mel_lengths"]
254 |
255 | if exists(self.duration_predictor):
256 | dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
257 | self.writer.add_scalar('duration loss', dur_loss.item(), global_step)
258 |
259 | velocity_consistency_model = None
260 | if self.need_velocity_consistent_loss and self.ema_model.initted:
261 | velocity_consistency_model = self.accelerator.unwrap_model(self.ema_model).ema_model
262 |
263 |                     loss, cond, pred, pred_data, *_ = self.model(
264 | mel_spec,
265 | text=text_inputs,
266 | lens=mel_lengths,
267 | velocity_consistency_model=velocity_consistency_model
268 | )
269 |
270 | self.accelerator.backward(loss)
271 |
272 | if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
273 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
274 |
275 | self.optimizer.step()
276 | self.scheduler.step()
277 | self.optimizer.zero_grad()
278 |
279 | self.accelerator.unwrap_model(self.ema_model).update()
280 |
281 | if self.accelerator.is_local_main_process:
282 | logger.info(f"step {global_step+1}: loss = {loss.item():.4f}")
283 | self.writer.add_scalar('loss', loss.item(), global_step)
284 | self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
285 |
286 | global_step += 1
287 | epoch_loss += loss.item()
288 | progress_bar.set_postfix(loss=loss.item())
289 |
290 | if global_step % save_step == 0:
291 | self.save_checkpoint(global_step)
292 | self.writer.add_figure("mel/target", plot_spectrogram(mel_spec[0,:,:]), global_step)
293 | self.writer.add_figure("mel/mask", plot_spectrogram(cond[0,:,:]), global_step)
294 | self.writer.add_figure("mel/prediction", plot_spectrogram(pred_data[0,:,:]), global_step)
295 |
296 | epoch_loss /= len(train_dataloader)
297 | if self.accelerator.is_local_main_process:
298 | logger.info(f"epoch {epoch+1}/{epochs} - average loss = {epoch_loss:.4f}")
299 | self.writer.add_scalar('epoch average loss', epoch_loss, epoch)
300 |
301 | if self.use_switch_ema:
302 | self.ema_model.update_model_with_ema()
303 |
304 | self.writer.close()
305 |
--------------------------------------------------------------------------------
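The checkpoint written by `E2Trainer.save_checkpoint` is a plain dictionary of the unwrapped model, optimizer, EMA and scheduler state dicts plus the step counter, so a trained model can be restored outside the trainer roughly as sketched below. This assumes the `E2TTS` instance is rebuilt with the same hyperparameters used for training and loads the online weights; the EMA weights live under 'ema_model_state_dict', keyed by the EMA wrapper.

import torch
from e2_tts_pytorch import E2TTS

# rebuild the model with the same configuration used during training

e2tts = E2TTS(
    transformer = dict(
        dim = 512,
        depth = 12
    )
)

# keys follow E2Trainer.save_checkpoint

checkpoint = torch.load('e2tts.pt', map_location = 'cpu')
e2tts.load_state_dict(checkpoint['model_state_dict'])
e2tts.eval()

print(f"restored from step {checkpoint['step']}")

--------------------------------------------------------------------------------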
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "e2-tts-pytorch"
3 | version = "2.2.1"
4 | description = "E2-TTS in Pytorch"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.8"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'transformers',
15 | 'attention mechanism',
16 | 'text to speech'
17 | ]
18 | classifiers=[
19 | 'Development Status :: 4 - Beta',
20 | 'Intended Audience :: Developers',
21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
22 | 'License :: OSI Approved :: MIT License',
23 | 'Programming Language :: Python :: 3.8',
24 | ]
25 |
26 | dependencies = [
27 | 'accelerate>=0.33.0',
28 | 'adam-atan2-pytorch>=0.1.12',
29 | 'beartype',
30 | 'einops>=0.8.0',
31 | 'einx>=0.3.0',
32 | 'ema-pytorch>=0.5.2',
33 | 'hl-gauss-pytorch>=0.1.7',
34 | 'hyper-connections>=0.0.10',
35 | 'g2p-en',
36 | 'jaxtyping',
37 | 'loguru',
38 | 'pydantic<2',
39 | 'tensorboard',
40 | 'torch>=2.0',
41 | 'torchdiffeq',
42 | 'torchaudio>=2.3.1',
43 | 'tqdm>=4.65.0',
44 | 'vocos',
45 | 'x-transformers>=1.42.23',
46 | ]
47 |
48 | [project.urls]
49 | Homepage = "https://pypi.org/project/e2-tts-pytorch/"
50 | Repository = "https://github.com/lucidrains/e2-tts-pytorch"
51 |
52 | [project.optional-dependencies]
53 | examples = ["datasets"]
54 |
55 | [build-system]
56 | requires = ["hatchling"]
57 | build-backend = "hatchling.build"
58 |
59 | [tool.hatch.metadata]
60 | allow-direct-references = true
61 |
62 | [tool.hatch.build.targets.wheel]
63 | packages = ["e2_tts_pytorch"]
64 |
--------------------------------------------------------------------------------
/train_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from e2_tts_pytorch import E2TTS, DurationPredictor
3 |
4 | from datasets import load_dataset
5 |
6 | from e2_tts_pytorch.trainer import (
7 | HFDataset,
8 | E2Trainer
9 | )
10 |
11 | duration_predictor = DurationPredictor(
12 | transformer = dict(
13 | dim = 512,
14 | depth = 6,
15 | )
16 | )
17 |
18 | e2tts = E2TTS(
19 | duration_predictor = duration_predictor,
20 | transformer = dict(
21 | dim = 512,
22 | depth = 12
23 | ),
24 | )
25 |
26 | train_dataset = HFDataset(load_dataset("MushanW/GLOBE")["train"])
27 |
28 | trainer = E2Trainer(
29 | e2tts,
30 | num_warmup_steps=20000,
31 | grad_accumulation_steps = 1,
32 | checkpoint_path = 'e2tts.pt',
33 | log_file = 'e2tts.txt'
34 | )
35 |
36 | epochs = 10
37 | batch_size = 32
38 |
39 | trainer.train(train_dataset, epochs, batch_size, save_step=1000)
40 |
--------------------------------------------------------------------------------
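E2Trainer defaults to the Adopt optimizer from adam-atan2-pytorch; any torch.optim.Optimizer can be passed through the `optimizer` argument instead, in which case the `learning_rate` argument is ignored. A small variation on the example above, assuming plain AdamW with an illustrative learning rate:

from torch.optim import AdamW

trainer = E2Trainer(
    e2tts,
    optimizer = AdamW(e2tts.parameters(), lr = 7.5e-5),
    num_warmup_steps = 20000,
    grad_accumulation_steps = 1,
    checkpoint_path = 'e2tts.pt',
    log_file = 'e2tts.txt'
)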