├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── e2-tts.png ├── e2_tts_pytorch ├── __init__.py ├── e2_tts.py └── trainer.py ├── pyproject.toml └── train_example.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## E2 TTS - Pytorch 5 | 6 | Implementation of E2-TTS, Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS, in Pytorch 7 | 8 | The repository differs from the paper in that it uses a multistream transformer for text and audio, with conditioning done every transformer block in the E2 manner. 9 | 10 | It also includes an improvisation that was proven out by Manmay, where the text is simply interpolated to the length of the audio for conditioning. You can try this by setting `interpolated_text = True` on `E2TTS` 11 | 12 | ## Appreciation 13 | 14 | - Manmay for contributing working end-to-end training code! 15 | 16 | - Lucas Newman for the code contributions, helpful feedback, and for sharing the first set of positive experiments! 17 | 18 | - Jing for sharing the second positive result with a multilingual (English + Chinese) dataset! 19 | 20 | - Coice and Manmay for reporting the third and fourth successful runs. Farewell alignment engineering 21 | 22 | ## Install 23 | 24 | ```bash 25 | $ pip install e2-tts-pytorch 26 | ``` 27 | 28 | ## Usage 29 | 30 | ```python 31 | import torch 32 | 33 | from e2_tts_pytorch import ( 34 | E2TTS, 35 | DurationPredictor 36 | ) 37 | 38 | duration_predictor = DurationPredictor( 39 | transformer = dict( 40 | dim = 512, 41 | depth = 8, 42 | ) 43 | ) 44 | 45 | mel = torch.randn(2, 1024, 100) 46 | text = ['Hello', 'Goodbye'] 47 | 48 | loss = duration_predictor(mel, text = text) 49 | loss.backward() 50 | 51 | e2tts = E2TTS( 52 | duration_predictor = duration_predictor, 53 | transformer = dict( 54 | dim = 512, 55 | depth = 8 56 | ), 57 | ) 58 | 59 | out = e2tts(mel, text = text) 60 | out.loss.backward() 61 | 62 | sampled = e2tts.sample(mel[:, :5], text = text) 63 | 64 | ``` 65 | 66 | ## Related Works 67 | 68 | - [Nanospeech](https://github.com/lucasnewman/nanospeech) by [Lucas Newman](https://github.com/lucasnewman), which contains training code, working examples, as well as interoperable MLX version! 69 | 70 | ## Citations 71 | 72 | ```bibtex 73 | @inproceedings{Eskimez2024E2TE, 74 | title = {E2 TTS: Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS}, 75 | author = {Sefik Emre Eskimez and Xiaofei Wang and Manthan Thakker and Canrun Li and Chung-Hsien Tsai and Zhen Xiao and Hemin Yang and Zirun Zhu and Min Tang and Xu Tan and Yanqing Liu and Sheng Zhao and Naoyuki Kanda}, 76 | year = {2024}, 77 | url = {https://api.semanticscholar.org/CorpusID:270738197} 78 | } 79 | ``` 80 | 81 | ```bibtex 82 | @inproceedings{Darcet2023VisionTN, 83 | title = {Vision Transformers Need Registers}, 84 | author = {Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski}, 85 | year = {2023}, 86 | url = {https://api.semanticscholar.org/CorpusID:263134283} 87 | } 88 | ``` 89 | 90 | ```bibtex 91 | @article{Bao2022AllAW, 92 | title = {All are Worth Words: A ViT Backbone for Diffusion Models}, 93 | author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu}, 94 | journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 95 | year = {2022}, 96 | pages = {22669-22679}, 97 | url = {https://api.semanticscholar.org/CorpusID:253581703} 98 | } 99 | ``` 100 | 101 | ```bibtex 102 | @article{Burtsev2021MultiStreamT, 103 | title = {Multi-Stream Transformers}, 104 | author = {Mikhail S. 
Burtsev and Anna Rumshisky}, 105 | journal = {ArXiv}, 106 | year = {2021}, 107 | volume = {abs/2107.10342}, 108 | url = {https://api.semanticscholar.org/CorpusID:236171087} 109 | } 110 | ``` 111 | 112 | ```bibtex 113 | @inproceedings{Sadat2024EliminatingOA, 114 | title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models}, 115 | author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber}, 116 | year = {2024}, 117 | url = {https://api.semanticscholar.org/CorpusID:273098845} 118 | } 119 | ``` 120 | 121 | ```bibtex 122 | @article{Gulati2020ConformerCT, 123 | title = {Conformer: Convolution-augmented Transformer for Speech Recognition}, 124 | author = {Anmol Gulati and James Qin and Chung-Cheng Chiu and Niki Parmar and Yu Zhang and Jiahui Yu and Wei Han and Shibo Wang and Zhengdong Zhang and Yonghui Wu and Ruoming Pang}, 125 | journal = {ArXiv}, 126 | year = {2020}, 127 | volume = {abs/2005.08100}, 128 | url = {https://api.semanticscholar.org/CorpusID:218674528} 129 | } 130 | ``` 131 | 132 | ```bibtex 133 | @article{Yang2024ConsistencyFM, 134 | title = {Consistency Flow Matching: Defining Straight Flows with Velocity Consistency}, 135 | author = {Ling Yang and Zixiang Zhang and Zhilong Zhang and Xingchao Liu and Minkai Xu and Wentao Zhang and Chenlin Meng and Stefano Ermon and Bin Cui}, 136 | journal = {ArXiv}, 137 | year = {2024}, 138 | volume = {abs/2407.02398}, 139 | url = {https://api.semanticscholar.org/CorpusID:270878436} 140 | } 141 | ``` 142 | 143 | ```bibtex 144 | @article{Li2024SwitchEA, 145 | title = {Switch EMA: A Free Lunch for Better Flatness and Sharpness}, 146 | author = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li}, 147 | journal = {ArXiv}, 148 | year = {2024}, 149 | volume = {abs/2402.09240}, 150 | url = {https://api.semanticscholar.org/CorpusID:267657558} 151 | } 152 | ``` 153 | 154 | ```bibtex 155 | @inproceedings{Zhou2024ValueRL, 156 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers}, 157 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan}, 158 | year = {2024}, 159 | url = {https://api.semanticscholar.org/CorpusID:273532030} 160 | } 161 | ``` 162 | 163 | ```bibtex 164 | @inproceedings{Duvvuri2024LASERAW, 165 | title = {LASER: Attention with Exponential Transformation}, 166 | author = {Sai Surya Duvvuri and Inderjit S. 
Dhillon}, 167 | year = {2024}, 168 | url = {https://api.semanticscholar.org/CorpusID:273849947} 169 | } 170 | ``` 171 | 172 | ```bibtex 173 | @article{Zhu2024HyperConnections, 174 | title = {Hyper-Connections}, 175 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou}, 176 | journal = {ArXiv}, 177 | year = {2024}, 178 | volume = {abs/2409.19606}, 179 | url = {https://api.semanticscholar.org/CorpusID:272987528} 180 | } 181 | ``` 182 | 183 | ```bibtex 184 | @inproceedings{Lu2023MusicSS, 185 | title = {Music Source Separation with Band-Split RoPE Transformer}, 186 | author = {Wei-Tsung Lu and Ju-Chiang Wang and Qiuqiang Kong and Yun-Ning Hung}, 187 | year = {2023}, 188 | url = {https://api.semanticscholar.org/CorpusID:261556702} 189 | } 190 | ``` 191 | 192 | ```bibtex 193 | @inproceedings{Dong2025FANformerIL, 194 | title = {FANformer: Improving Large Language Models Through Effective Periodicity Modeling}, 195 | author = {Yi Dong and Ge Li and Xue Jiang and Yongding Tao and Kechi Zhang and Hao Zhu and Huanyu Liu and Jiazheng Ding and Jia Li and Jinliang Deng and Hong Mei}, 196 | year = {2025}, 197 | url = {https://api.semanticscholar.org/CorpusID:276724636} 198 | } 199 | ``` 200 | 201 | ```bibtex 202 | @article{Karras2024GuidingAD, 203 | title = {Guiding a Diffusion Model with a Bad Version of Itself}, 204 | author = {Tero Karras and Miika Aittala and Tuomas Kynk{\"a}{\"a}nniemi and Jaakko Lehtinen and Timo Aila and Samuli Laine}, 205 | journal = {ArXiv}, 206 | year = {2024}, 207 | volume = {abs/2406.02507}, 208 | url = {https://api.semanticscholar.org/CorpusID:270226598} 209 | } 210 | ``` 211 | -------------------------------------------------------------------------------- /e2-tts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/e2-tts-pytorch/1028716e6a4072f668716687d732a3b123dda4cf/e2-tts.png -------------------------------------------------------------------------------- /e2_tts_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from e2_tts_pytorch.e2_tts import ( 2 | Transformer, 3 | DurationPredictor, 4 | E2TTS, 5 | ) 6 | 7 | from e2_tts_pytorch.trainer import ( 8 | E2Trainer 9 | ) -------------------------------------------------------------------------------- /e2_tts_pytorch/e2_tts.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | f - frequency token dimension 6 | nt - text sequence 7 | nw - raw wave length 8 | d - dimension 9 | dt - dimension text 10 | """ 11 | 12 | from __future__ import annotations 13 | 14 | from pathlib import Path 15 | from random import random 16 | from functools import partial 17 | from itertools import zip_longest 18 | from collections import namedtuple 19 | 20 | from typing import Literal, Callable 21 | 22 | import jaxtyping 23 | from beartype import beartype 24 | 25 | import torch 26 | import torch.nn.functional as F 27 | from torch import nn, tensor, Tensor, from_numpy 28 | from torch.nn import Module, ModuleList, Sequential, Linear 29 | from torch.nn.utils.rnn import pad_sequence 30 | 31 | import torchaudio 32 | from torchaudio.functional import DB_to_amplitude 33 | from torchdiffeq import odeint 34 | 35 | import einx 36 | from einops.layers.torch import Rearrange, Reduce 37 | from einops import rearrange, repeat, reduce, einsum, pack, unpack 38 | 39 | 
from x_transformers import ( 40 | Attention, 41 | FeedForward, 42 | RMSNorm, 43 | AdaptiveRMSNorm, 44 | ) 45 | 46 | from x_transformers.x_transformers import RotaryEmbedding 47 | 48 | from hyper_connections import HyperConnections 49 | 50 | from hl_gauss_pytorch import HLGaussLayer 51 | 52 | from vocos import Vocos 53 | 54 | pad_sequence = partial(pad_sequence, batch_first = True) 55 | 56 | # constants 57 | 58 | class TorchTyping: 59 | def __init__(self, abstract_dtype): 60 | self.abstract_dtype = abstract_dtype 61 | 62 | def __getitem__(self, shapes: str): 63 | return self.abstract_dtype[Tensor, shapes] 64 | 65 | Float = TorchTyping(jaxtyping.Float) 66 | Int = TorchTyping(jaxtyping.Int) 67 | Bool = TorchTyping(jaxtyping.Bool) 68 | 69 | # named tuples 70 | 71 | LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency']) 72 | 73 | E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data', 'loss_breakdown']) 74 | 75 | # helpers 76 | 77 | def exists(v): 78 | return v is not None 79 | 80 | def default(v, d): 81 | return v if exists(v) else d 82 | 83 | def xnor(x, y): 84 | return not (x ^ y) 85 | 86 | def set_if_missing_key(d, key, value): 87 | if key in d: 88 | return 89 | 90 | d.update(**{key: value}) 91 | 92 | def l2norm(t): 93 | return F.normalize(t, dim = -1) 94 | 95 | def divisible_by(num, den): 96 | return (num % den) == 0 97 | 98 | def pack_one_with_inverse(x, pattern): 99 | packed, packed_shape = pack([x], pattern) 100 | 101 | def inverse(x, inverse_pattern = None): 102 | inverse_pattern = default(inverse_pattern, pattern) 103 | return unpack(x, packed_shape, inverse_pattern)[0] 104 | 105 | return packed, inverse 106 | 107 | class Identity(Module): 108 | def forward(self, x, **kwargs): 109 | return x 110 | 111 | # tensor helpers 112 | 113 | def project(x, y): 114 | x, inverse = pack_one_with_inverse(x, 'b *') 115 | y, _ = pack_one_with_inverse(y, 'b *') 116 | 117 | dtype = x.dtype 118 | x, y = x.double(), y.double() 119 | unit = F.normalize(y, dim = -1) 120 | 121 | parallel = (x * unit).sum(dim = -1, keepdim = True) * unit 122 | orthogonal = x - parallel 123 | 124 | return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype) 125 | 126 | # simple utf-8 tokenizer, since paper went character based 127 | 128 | def list_str_to_tensor( 129 | text: list[str], 130 | padding_value = -1 131 | ) -> Int['b nt']: 132 | 133 | list_tensors = [tensor([*bytes(t, 'UTF-8')]) for t in text] 134 | padded_tensor = pad_sequence(list_tensors, padding_value = -1) 135 | return padded_tensor 136 | 137 | # simple english phoneme-based tokenizer 138 | 139 | from g2p_en import G2p 140 | 141 | def get_g2p_en_encode(): 142 | g2p = G2p() 143 | 144 | # used by @lucasnewman successfully here 145 | # https://github.com/lucasnewman/e2-tts-pytorch/blob/ljspeech-test/e2_tts_pytorch/e2_tts.py 146 | 147 | phoneme_to_index = g2p.p2idx 148 | num_phonemes = len(phoneme_to_index) 149 | 150 | extended_chars = [' ', ',', '.', '-', '!', '?', '\'', '"', '...', '..', '. .', '. . .', '. . . .', '. . . . .', '. ...', '... .', '.. 
..'] 151 | num_extended_chars = len(extended_chars) 152 | 153 | extended_chars_dict = {p: (num_phonemes + i) for i, p in enumerate(extended_chars)} 154 | phoneme_to_index = {**phoneme_to_index, **extended_chars_dict} 155 | 156 | def encode( 157 | text: list[str], 158 | padding_value = -1 159 | ) -> Int['b nt']: 160 | 161 | phonemes = [g2p(t) for t in text] 162 | list_tensors = [tensor([phoneme_to_index[p] for p in one_phoneme]) for one_phoneme in phonemes] 163 | padded_tensor = pad_sequence(list_tensors, padding_value = -1) 164 | return padded_tensor 165 | 166 | return encode, (num_phonemes + num_extended_chars) 167 | 168 | # tensor helpers 169 | 170 | def log(t, eps = 1e-5): 171 | return t.clamp(min = eps).log() 172 | 173 | def lens_to_mask( 174 | t: Int['b'], 175 | length: int | None = None 176 | ) -> Bool['b n']: 177 | 178 | if not exists(length): 179 | length = t.amax() 180 | 181 | seq = torch.arange(length, device = t.device) 182 | return einx.less('n, b -> b n', seq, t) 183 | 184 | def mask_from_start_end_indices( 185 | seq_len: Int['b'], 186 | start: Int['b'], 187 | end: Int['b'] 188 | ): 189 | max_seq_len = seq_len.max().item() 190 | seq = torch.arange(max_seq_len, device = start.device).long() 191 | return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end) 192 | 193 | def mask_from_frac_lengths( 194 | seq_len: Int['b'], 195 | frac_lengths: Float['b'], 196 | max_length: int | None = None 197 | ): 198 | lengths = (frac_lengths * seq_len).long() 199 | max_start = seq_len - lengths 200 | 201 | rand = torch.rand_like(frac_lengths) 202 | start = (max_start * rand).long().clamp(min = 0) 203 | end = start + lengths 204 | 205 | out = mask_from_start_end_indices(seq_len, start, end) 206 | 207 | if exists(max_length): 208 | out = pad_to_length(out, max_length) 209 | 210 | return out 211 | 212 | def maybe_masked_mean( 213 | t: Float['b n d'], 214 | mask: Bool['b n'] | None = None 215 | ) -> Float['b d']: 216 | 217 | if not exists(mask): 218 | return t.mean(dim = 1) 219 | 220 | t = einx.where('b n, b n d, -> b n d', mask, t, 0.) 
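# masked positions were zeroed out above, so they contribute nothing to the sum
# below; the denominator counts only the valid timesteps (clamped to 1 to avoid
# division by zero when a row of the mask is entirely False)
#
# illustrative example (not part of the library):
#   t    = torch.ones(1, 3, 2)
#   mask = tensor([[True, True, False]])
#   maybe_masked_mean(t, mask)   # averages over the first 2 timesteps -> shape (1, 2)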
221 | num = reduce(t, 'b n d -> b d', 'sum') 222 | den = reduce(mask.float(), 'b n -> b', 'sum') 223 | 224 | return einx.divide('b d, b -> b d', num, den.clamp(min = 1.)) 225 | 226 | def pad_to_length( 227 | t: Tensor, 228 | length: int, 229 | value = None 230 | ): 231 | seq_len = t.shape[-1] 232 | if length > seq_len: 233 | t = F.pad(t, (0, length - seq_len), value = value) 234 | 235 | return t[..., :length] 236 | 237 | def interpolate_1d( 238 | x: Tensor, 239 | length: int, 240 | mode = 'bilinear' 241 | ): 242 | x = rearrange(x, 'n d -> 1 d n 1') 243 | x = F.interpolate(x, (length, 1), mode = mode) 244 | return rearrange(x, '1 d n 1 -> n d') 245 | 246 | # to mel spec 247 | 248 | class MelSpec(Module): 249 | def __init__( 250 | self, 251 | filter_length = 1024, 252 | hop_length = 256, 253 | win_length = 1024, 254 | n_mel_channels = 100, 255 | sampling_rate = 24_000, 256 | normalize = False, 257 | power = 1, 258 | norm = None, 259 | center = True, 260 | ): 261 | super().__init__() 262 | self.n_mel_channels = n_mel_channels 263 | self.sampling_rate = sampling_rate 264 | 265 | self.mel_stft = torchaudio.transforms.MelSpectrogram( 266 | sample_rate = sampling_rate, 267 | n_fft = filter_length, 268 | win_length = win_length, 269 | hop_length = hop_length, 270 | n_mels = n_mel_channels, 271 | power = power, 272 | center = center, 273 | normalized = normalize, 274 | norm = norm, 275 | ) 276 | 277 | self.register_buffer('dummy', tensor(0), persistent = False) 278 | 279 | def forward(self, inp): 280 | if len(inp.shape) == 3: 281 | inp = rearrange(inp, 'b 1 nw -> b nw') 282 | 283 | assert len(inp.shape) == 2 284 | 285 | if self.dummy.device != inp.device: 286 | self.to(inp.device) 287 | 288 | mel = self.mel_stft(inp) 289 | mel = log(mel) 290 | return mel 291 | 292 | # convolutional positional generating module 293 | # taken from https://github.com/lucidrains/voicebox-pytorch/blob/main/voicebox_pytorch/voicebox_pytorch.py#L203 294 | 295 | class DepthwiseConv(Module): 296 | def __init__( 297 | self, 298 | dim, 299 | *, 300 | kernel_size, 301 | groups = None 302 | ): 303 | super().__init__() 304 | assert not divisible_by(kernel_size, 2) 305 | groups = default(groups, dim) # full depthwise conv by default 306 | 307 | self.dw_conv1d = nn.Sequential( 308 | nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2), 309 | nn.SiLU() 310 | ) 311 | 312 | def forward( 313 | self, 314 | x, 315 | mask = None 316 | ): 317 | 318 | if exists(mask): 319 | x = einx.where('b n, b n d, -> b n d', mask, x, 0.) 320 | 321 | x = rearrange(x, 'b n c -> b c n') 322 | x = self.dw_conv1d(x) 323 | out = rearrange(x, 'b c n -> b n c') 324 | 325 | if exists(mask): 326 | out = einx.where('b n, b n d, -> b n d', mask, out, 0.) 327 | 328 | return out 329 | 330 | # adaln zero from DiT paper 331 | 332 | class AdaLNZero(Module): 333 | def __init__( 334 | self, 335 | dim, 336 | dim_condition = None, 337 | init_bias_value = -2. 
338 | ): 339 | super().__init__() 340 | dim_condition = default(dim_condition, dim) 341 | self.to_gamma = nn.Linear(dim_condition, dim) 342 | 343 | nn.init.zeros_(self.to_gamma.weight) 344 | nn.init.constant_(self.to_gamma.bias, init_bias_value) 345 | 346 | def forward(self, x, *, condition): 347 | if condition.ndim == 2: 348 | condition = rearrange(condition, 'b d -> b 1 d') 349 | 350 | gamma = self.to_gamma(condition).sigmoid() 351 | return x * gamma 352 | 353 | # random projection fourier embedding 354 | 355 | class RandomFourierEmbed(Module): 356 | def __init__(self, dim): 357 | super().__init__() 358 | assert divisible_by(dim, 2) 359 | self.register_buffer('weights', torch.randn(dim // 2)) 360 | 361 | def forward(self, x): 362 | freqs = einx.multiply('i, j -> i j', x, self.weights) * 2 * torch.pi 363 | fourier_embed, _ = pack((x, freqs.sin(), freqs.cos()), 'b *') 364 | return fourier_embed 365 | 366 | # linear with fourier embedded outputs 367 | 368 | class LinearFourierEmbed(Module): 369 | def __init__( 370 | self, 371 | dim, 372 | p = 0.5, # percentage of output dimension to fourier, they found 0.5 to be best (0.25 sin + 0.25 cos) 373 | ): 374 | super().__init__() 375 | assert p <= 1. 376 | 377 | dim_fourier = int(p * dim) 378 | dim_rest = dim - (dim_fourier * 2) 379 | 380 | self.linear = nn.Linear(dim, dim_fourier + dim_rest, bias = False) 381 | self.split_dims = (dim_fourier, dim_rest) 382 | 383 | def forward(self, x): 384 | hiddens = self.linear(x) 385 | fourier, rest = hiddens.split(self.split_dims, dim = -1) 386 | return torch.cat((fourier.sin(), fourier.cos(), rest), dim = -1) 387 | 388 | # character embedding 389 | 390 | class CharacterEmbed(Module): 391 | def __init__( 392 | self, 393 | dim, 394 | num_embeds = 256, 395 | ): 396 | super().__init__() 397 | self.dim = dim 398 | self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token' 399 | 400 | def forward( 401 | self, 402 | text: Int['b nt'], 403 | max_seq_len: int, 404 | **kwargs 405 | ) -> Float['b n d']: 406 | 407 | text = text + 1 # shift all other token ids up by 1 and use 0 as filler token 408 | 409 | text = text[:, :max_seq_len] # just curtail if character tokens are more than the mel spec tokens, one of the edge cases the paper did not address 410 | text = pad_to_length(text, max_seq_len, value = 0) 411 | 412 | return self.embed(text) 413 | 414 | class InterpolatedCharacterEmbed(Module): 415 | def __init__( 416 | self, 417 | dim, 418 | num_embeds = 256, 419 | ): 420 | super().__init__() 421 | self.dim = dim 422 | self.embed = nn.Embedding(num_embeds, dim) 423 | 424 | self.abs_pos_mlp = Sequential( 425 | Rearrange('... -> ... 
1'), 426 | Linear(1, dim), 427 | nn.SiLU(), 428 | Linear(dim, dim) 429 | ) 430 | 431 | def forward( 432 | self, 433 | text: Int['b nt'], 434 | max_seq_len: int, 435 | mask: Bool['b n'] | None = None 436 | ) -> Float['b n d']: 437 | 438 | device = text.device 439 | 440 | mask = default(mask, (None,)) 441 | 442 | interp_embeds = [] 443 | interp_abs_positions = [] 444 | 445 | for one_text, one_mask in zip_longest(text, mask): 446 | 447 | valid_text = one_text >= 0 448 | one_text = one_text[valid_text] 449 | one_text_embed = self.embed(one_text) 450 | 451 | # save the absolute positions 452 | 453 | text_seq_len = one_text.shape[0] 454 | 455 | # determine audio sequence length from mask 456 | 457 | audio_seq_len = max_seq_len 458 | if exists(one_mask): 459 | audio_seq_len = one_mask.sum().long().item() 460 | 461 | # interpolate text embedding to audio embedding length 462 | 463 | interp_text_embed = interpolate_1d(one_text_embed, audio_seq_len) 464 | interp_abs_pos = torch.linspace(0, text_seq_len, audio_seq_len, device = device) 465 | 466 | interp_embeds.append(interp_text_embed) 467 | interp_abs_positions.append(interp_abs_pos) 468 | 469 | interp_embeds = pad_sequence(interp_embeds) 470 | interp_abs_positions = pad_sequence(interp_abs_positions) 471 | 472 | interp_embeds = F.pad(interp_embeds, (0, 0, 0, max_seq_len - interp_embeds.shape[-2])) 473 | interp_abs_positions = pad_to_length(interp_abs_positions, max_seq_len) 474 | 475 | # pass interp absolute positions through mlp for implicit positions 476 | 477 | interp_embeds = interp_embeds + self.abs_pos_mlp(interp_abs_positions) 478 | 479 | if exists(mask): 480 | interp_embeds = einx.where('b n, b n d, -> b n d', mask, interp_embeds, 0.) 481 | 482 | return interp_embeds 483 | 484 | # text audio cross conditioning in multistream setup 485 | 486 | class TextAudioCrossCondition(Module): 487 | def __init__( 488 | self, 489 | dim, 490 | dim_text, 491 | cond_audio_to_text = True, 492 | ): 493 | super().__init__() 494 | self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False) 495 | nn.init.zeros_(self.text_to_audio.weight) 496 | 497 | self.cond_audio_to_text = cond_audio_to_text 498 | 499 | if cond_audio_to_text: 500 | self.audio_to_text = nn.Linear(dim + dim_text, dim_text, bias = False) 501 | nn.init.zeros_(self.audio_to_text.weight) 502 | 503 | def forward( 504 | self, 505 | audio: Float['b n d'], 506 | text: Float['b n dt'] 507 | ): 508 | audio_text, _ = pack((audio, text), 'b n *') 509 | 510 | text_cond = self.text_to_audio(audio_text) 511 | audio_cond = self.audio_to_text(audio_text) if self.cond_audio_to_text else 0. 
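# both linear projections are zero-initialized, so at the start of training the
# residual updates returned below are zero and each stream passes through unchanged;
# the text -> audio conditioning is always applied, while the audio -> text direction
# is turned off on the last text layer (cond_audio_to_text = False)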
512 | 513 | return audio + text_cond, text + audio_cond 514 | 515 | # attention and transformer backbone 516 | # for use in both e2tts as well as duration module 517 | 518 | class Transformer(Module): 519 | @beartype 520 | def __init__( 521 | self, 522 | *, 523 | dim, 524 | dim_text = None, # will default to half of audio dimension 525 | depth = 8, 526 | heads = 8, 527 | dim_head = 64, 528 | ff_mult = 4, 529 | text_depth = None, 530 | text_heads = None, 531 | text_dim_head = None, 532 | text_ff_mult = None, 533 | has_freq_axis = False, 534 | freq_heads = None, 535 | freq_dim_head = None, 536 | cond_on_time = True, 537 | abs_pos_emb = True, 538 | max_seq_len = 8192, 539 | kernel_size = 31, 540 | dropout = 0.1, 541 | num_registers = 32, 542 | scale_residual = False, 543 | attn_laser = False, 544 | attn_laser_softclamp_value = 15., 545 | attn_fourier_embed_input = False, 546 | attn_fourier_embed_input_frac = 0.25, # https://arxiv.org/abs/2502.21309 547 | num_residual_streams = 4, 548 | attn_kwargs: dict = dict( 549 | gate_value_heads = True, 550 | softclamp_logits = True, 551 | ), 552 | ff_kwargs: dict = dict(), 553 | ): 554 | super().__init__() 555 | assert divisible_by(depth, 2), 'depth needs to be even' 556 | 557 | # absolute positional embedding 558 | 559 | self.max_seq_len = max_seq_len 560 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if abs_pos_emb else None 561 | 562 | self.dim = dim 563 | 564 | # determine text related hparams 565 | 566 | dim_text = default(dim_text, dim // 2) 567 | self.dim_text = dim_text 568 | 569 | text_heads = default(text_heads, heads) 570 | text_dim_head = default(text_dim_head, dim_head) 571 | text_ff_mult = default(text_ff_mult, ff_mult) 572 | text_depth = default(text_depth, depth) 573 | 574 | assert 1 <= text_depth <= depth, 'must have at least 1 layer of text conditioning, but less than total number of speech layers' 575 | 576 | # determine maybe freq axis hparams 577 | 578 | freq_heads = default(freq_heads, heads) 579 | freq_dim_head = default(freq_dim_head, dim_head) 580 | 581 | self.has_freq_axis = has_freq_axis 582 | 583 | # layers 584 | 585 | self.depth = depth 586 | layers = [] 587 | 588 | # registers 589 | 590 | self.num_registers = num_registers 591 | self.registers = nn.Parameter(torch.zeros(num_registers, dim)) 592 | nn.init.normal_(self.registers, std = 0.02) 593 | 594 | self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text)) 595 | nn.init.normal_(self.text_registers, std = 0.02) 596 | 597 | # rotary embedding 598 | 599 | self.rotary_emb = RotaryEmbedding(dim_head) 600 | self.text_rotary_emb = RotaryEmbedding(text_dim_head) 601 | 602 | if has_freq_axis: 603 | self.freq_rotary_emb = RotaryEmbedding(freq_dim_head) 604 | 605 | # hyper connection related 606 | 607 | init_hyper_conn, self.hyper_conn_expand, self.hyper_conn_reduce = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 608 | 609 | hyper_conns = [] 610 | 611 | # time conditioning 612 | # will use adaptive rmsnorm 613 | 614 | self.cond_on_time = cond_on_time 615 | rmsnorm_klass = RMSNorm if not cond_on_time else AdaptiveRMSNorm 616 | postbranch_klass = Identity if not cond_on_time else partial(AdaLNZero, dim = dim) 617 | 618 | self.time_cond_mlp = Identity() 619 | 620 | if cond_on_time: 621 | self.time_cond_mlp = Sequential( 622 | RandomFourierEmbed(dim), 623 | Linear(dim + 1, dim), 624 | nn.SiLU() 625 | ) 626 | 627 | for ind in range(depth): 628 | is_first_block = ind == 0 629 | 630 | 
is_later_half = ind >= (depth // 2) 631 | has_text = ind < text_depth 632 | 633 | # speech related 634 | 635 | speech_conv = DepthwiseConv(dim, kernel_size = kernel_size) 636 | 637 | attn_norm = rmsnorm_klass(dim) 638 | 639 | attn_input_fourier_embed = LinearFourierEmbed(dim, p = attn_fourier_embed_input_frac) if attn_fourier_embed_input else nn.Identity() 640 | 641 | attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, laser_softclamp_value = attn_laser_softclamp_value, **attn_kwargs) 642 | 643 | attn_adaln_zero = postbranch_klass() 644 | 645 | ff_norm = rmsnorm_klass(dim) 646 | ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs) 647 | ff_adaln_zero = postbranch_klass() 648 | 649 | skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None 650 | 651 | freq_attn_norm = freq_attn = freq_attn_adaln_zero = None 652 | 653 | if has_freq_axis: 654 | freq_attn_norm = rmsnorm_klass(dim) 655 | freq_attn = Attention(dim = dim, heads = freq_heads, dim_head = freq_dim_head) 656 | freq_attn_adaln_zero = postbranch_klass() 657 | 658 | speech_modules = ModuleList([ 659 | skip_proj, 660 | speech_conv, 661 | attn_norm, 662 | attn, 663 | attn_input_fourier_embed, 664 | attn_adaln_zero, 665 | ff_norm, 666 | ff, 667 | ff_adaln_zero, 668 | freq_attn_norm, 669 | freq_attn, 670 | freq_attn_adaln_zero 671 | ]) 672 | 673 | speech_hyper_conns = ModuleList([ 674 | init_hyper_conn(dim = dim), # conv 675 | init_hyper_conn(dim = dim), # attn 676 | init_hyper_conn(dim = dim), # ff 677 | init_hyper_conn(dim = dim) if has_freq_axis else None 678 | ]) 679 | 680 | text_modules = None 681 | text_hyper_conns = None 682 | 683 | if has_text: 684 | # text related 685 | 686 | text_conv = DepthwiseConv(dim_text, kernel_size = kernel_size) 687 | 688 | text_attn_norm = RMSNorm(dim_text) 689 | text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, laser_softclamp_value = attn_laser_softclamp_value, **attn_kwargs) 690 | 691 | text_ff_norm = RMSNorm(dim_text) 692 | text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs) 693 | 694 | # cross condition 695 | 696 | is_last = ind == (text_depth - 1) 697 | 698 | cross_condition = TextAudioCrossCondition(dim = dim, dim_text = dim_text, cond_audio_to_text = not is_last) 699 | 700 | text_modules = ModuleList([ 701 | text_conv, 702 | text_attn_norm, 703 | text_attn, 704 | text_ff_norm, 705 | text_ff, 706 | cross_condition 707 | ]) 708 | 709 | text_hyper_conns = ModuleList([ 710 | init_hyper_conn(dim = dim_text), # conv 711 | init_hyper_conn(dim = dim_text), # attn 712 | init_hyper_conn(dim = dim_text), # ff 713 | ]) 714 | 715 | hyper_conns.append(ModuleList([ 716 | speech_hyper_conns, 717 | text_hyper_conns 718 | ])) 719 | 720 | layers.append(ModuleList([ 721 | speech_modules, 722 | text_modules 723 | ])) 724 | 725 | self.layers = ModuleList(layers) 726 | 727 | self.hyper_conns = ModuleList(hyper_conns) 728 | 729 | self.final_norm = RMSNorm(dim) 730 | 731 | def forward( 732 | self, 733 | x: Float['b n d'] | Float['b f n d'], 734 | times: Float['b'] | Float[''] | None = None, 735 | mask: Bool['b n'] | None = None, 736 | text_embed: Float['b n dt'] | None = None, 737 | ): 738 | orig_batch = x.shape[0] 739 | 740 | assert xnor(x.ndim == 4, self.has_freq_axis), '`has_freq_axis` must be set if 
passing in tensor with frequency dimension (4 ndims), and not set if passing in only 3' 741 | 742 | freq_seq_len = 1 743 | 744 | if self.has_freq_axis: 745 | freq_seq_len = x.shape[1] 746 | x = rearrange(x, 'b f n d -> (b f) n d') 747 | 748 | if exists(text_embed): 749 | text_embed = repeat(text_embed, 'b ... -> (b f) ...', f = freq_seq_len) 750 | 751 | if exists(mask): 752 | mask = repeat(mask, 'b ... -> (b f) ...', f = freq_seq_len) 753 | 754 | batch, seq_len, device = x.shape[0], x.shape[1], x.device 755 | 756 | assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa' 757 | 758 | # handle absolute positions if needed 759 | 760 | if exists(self.abs_pos_emb): 761 | assert seq_len <= self.max_seq_len, f'{seq_len} exceeds the set `max_seq_len` ({self.max_seq_len}) on Transformer' 762 | seq = torch.arange(seq_len, device = device) 763 | x = x + self.abs_pos_emb(seq) 764 | 765 | # register tokens 766 | 767 | registers = repeat(self.registers, 'r d -> b r d', b = batch) 768 | x, registers_packed_shape = pack((registers, x), 'b * d') 769 | 770 | if exists(mask): 771 | mask = F.pad(mask, (self.num_registers, 0), value = True) 772 | 773 | # handle adaptive rmsnorm kwargs 774 | 775 | norm_kwargs = dict() 776 | freq_norm_kwargs = dict() 777 | 778 | if exists(times): 779 | if times.ndim == 0: 780 | times = repeat(times, ' -> b', b = orig_batch) 781 | 782 | times = self.time_cond_mlp(times) 783 | 784 | if self.has_freq_axis: 785 | freq_times = repeat(times, 'b ... -> (b n) ...', n = x.shape[-2]) 786 | freq_norm_kwargs.update(condition = freq_times) 787 | 788 | times = repeat(times, 'b ... -> (b f) ...', f = freq_seq_len) 789 | norm_kwargs.update(condition = times) 790 | 791 | # rotary embedding 792 | 793 | rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2]) 794 | 795 | # text related 796 | 797 | if exists(text_embed): 798 | text_rotary_pos_emb = self.text_rotary_emb.forward_from_seq_len(x.shape[-2]) 799 | 800 | text_registers = repeat(self.text_registers, 'r d -> b r d', b = batch) 801 | text_embed, _ = pack((text_registers, text_embed), 'b * d') 802 | 803 | if self.has_freq_axis: 804 | freq_rotary_pos_emb = self.freq_rotary_emb.forward_from_seq_len(freq_seq_len) 805 | 806 | # skip connection related stuff 807 | 808 | skips = [] 809 | 810 | # value residual 811 | 812 | text_attn_first_values = None 813 | freq_attn_first_values = None 814 | attn_first_values = None 815 | 816 | # expand hyper connections 817 | 818 | x = self.hyper_conn_expand(x) 819 | 820 | if exists(text_embed): 821 | text_embed = self.hyper_conn_expand(text_embed) 822 | 823 | # go through the layers 824 | 825 | for ind, ((speech_modules, text_modules), (speech_residual_fns, text_residual_fns)) in enumerate(zip(self.layers, self.hyper_conns)): 826 | 827 | layer = ind + 1 828 | 829 | ( 830 | maybe_skip_proj, 831 | speech_conv, 832 | attn_norm, 833 | attn, 834 | attn_input_fourier_embed, 835 | maybe_attn_adaln_zero, 836 | ff_norm, 837 | ff, 838 | maybe_ff_adaln_zero, 839 | maybe_freq_attn_norm, 840 | maybe_freq_attn, 841 | maybe_freq_attn_adaln_zero 842 | ) = speech_modules 843 | 844 | ( 845 | conv_residual, 846 | attn_residual, 847 | ff_residual, 848 | maybe_freq_attn_residual 849 | ) = speech_residual_fns 850 | 851 | # smaller text transformer 852 | 853 | if exists(text_embed) and exists(text_modules): 854 | 855 | ( 856 | text_conv, 857 | text_attn_norm, 858 | text_attn, 859 | text_ff_norm, 860 | text_ff, 861 | cross_condition 862 | ) = text_modules 
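# the text mini-transformer mirrors the speech stream: depthwise conv, attention
# (reusing the value residual from the first text block), and feedforward, each
# wrapped in its own hyper-connection residual; afterwards the cross condition
# module exchanges information between the text and audio streams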
863 | 864 | ( 865 | text_conv_residual, 866 | text_attn_residual, 867 | text_ff_residual 868 | ) = text_residual_fns 869 | 870 | text_embed, add_residual = text_conv_residual(text_embed) 871 | text_embed = text_conv(text_embed, mask = mask) 872 | text_embed = add_residual(text_embed) 873 | 874 | text_embed, add_residual = text_attn_residual(text_embed) 875 | text_attn_out, text_attn_inter = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = text_attn_first_values) 876 | text_embed = add_residual(text_attn_out) 877 | 878 | text_attn_first_values = default(text_attn_first_values, text_attn_inter.values) 879 | 880 | text_embed, add_residual = text_ff_residual(text_embed) 881 | text_embed = text_ff(text_ff_norm(text_embed)) 882 | text_embed = add_residual(text_embed) 883 | x, text_embed = cross_condition(x, text_embed) 884 | 885 | # skip connection logic 886 | 887 | is_first_half = layer <= (self.depth // 2) 888 | is_later_half = not is_first_half 889 | 890 | if is_first_half: 891 | skips.append(x) 892 | 893 | if is_later_half: 894 | skip = skips.pop() 895 | x = torch.cat((x, skip), dim = -1) 896 | x = maybe_skip_proj(x) 897 | 898 | # position generating convolution 899 | 900 | x, add_residual = conv_residual(x) 901 | x = speech_conv(x, mask = mask) 902 | x = add_residual(x) 903 | 904 | # attention 905 | 906 | x, add_residual = attn_residual(x) 907 | 908 | x = attn_norm(x, **norm_kwargs) 909 | x = attn_input_fourier_embed(x) 910 | 911 | attn_out, attn_inter = attn(x, rotary_pos_emb = rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = attn_first_values) 912 | 913 | attn_out = maybe_attn_adaln_zero(attn_out, **norm_kwargs) 914 | x = add_residual(attn_out) 915 | 916 | attn_first_values = default(attn_first_values, attn_inter.values) 917 | 918 | # attention across frequency tokens, if needed 919 | 920 | if self.has_freq_axis: 921 | 922 | x, add_residual = maybe_freq_attn_residual(x) 923 | 924 | x = rearrange(x, '(b f) n d -> (b n) f d', b = orig_batch) 925 | 926 | attn_out, attn_inter = maybe_freq_attn(maybe_freq_attn_norm(x, **freq_norm_kwargs), rotary_pos_emb = freq_rotary_pos_emb, return_intermediates = True, value_residual = freq_attn_first_values) 927 | attn_out = maybe_freq_attn_adaln_zero(attn_out, **freq_norm_kwargs) 928 | 929 | attn_out = rearrange(attn_out, '(b n) f d -> (b f) n d', b = orig_batch) 930 | 931 | x = add_residual(attn_out) 932 | freq_attn_first_values = default(freq_attn_first_values, attn_inter.values) 933 | 934 | # feedforward 935 | 936 | x, add_residual = ff_residual(x) 937 | ff_out = ff(ff_norm(x, **norm_kwargs)) 938 | ff_out = maybe_ff_adaln_zero(ff_out, **norm_kwargs) 939 | x = add_residual(ff_out) 940 | 941 | assert len(skips) == 0 942 | 943 | _, x = unpack(x, registers_packed_shape, 'b * d') 944 | 945 | # sum all residual streams from hyper connections 946 | 947 | x = self.hyper_conn_reduce(x) 948 | 949 | if self.has_freq_axis: 950 | x = rearrange(x, '(b f) n d -> b f n d', f = freq_seq_len) 951 | 952 | return self.final_norm(x) 953 | 954 | # main classes 955 | 956 | class DurationPredictor(Module): 957 | @beartype 958 | def __init__( 959 | self, 960 | transformer: dict | Transformer, 961 | num_channels = None, 962 | mel_spec_kwargs: dict = dict(), 963 | char_embed_kwargs: dict = dict(), 964 | text_num_embeds = None, 965 | num_freq_tokens = 1, 966 | hl_gauss_loss: dict | None = None, 967 | use_regression = True, 968 | tokenizer: ( 969 | Literal['char_utf8', 
'phoneme_en'] | 970 | Callable[[list[str]], Int['b nt']] 971 | ) = 'char_utf8' 972 | ): 973 | super().__init__() 974 | 975 | # freq axis hparams 976 | 977 | assert num_freq_tokens > 0 978 | self.num_freq_tokens = num_freq_tokens 979 | self.has_freq_axis = num_freq_tokens > 1 980 | 981 | if isinstance(transformer, dict): 982 | set_if_missing_key(transformer, 'has_freq_axis', self.has_freq_axis) 983 | 984 | transformer = Transformer( 985 | **transformer, 986 | cond_on_time = False 987 | ) 988 | 989 | assert transformer.has_freq_axis == self.has_freq_axis 990 | 991 | # mel spec 992 | 993 | self.mel_spec = MelSpec(**mel_spec_kwargs) 994 | self.num_channels = default(num_channels, self.mel_spec.n_mel_channels) 995 | 996 | self.transformer = transformer 997 | 998 | dim = transformer.dim 999 | dim_text = transformer.dim_text 1000 | 1001 | self.dim = dim 1002 | 1003 | # projecting depends on whether frequency axis is needed 1004 | 1005 | if not self.has_freq_axis: 1006 | self.proj_in = Linear(self.num_channels, self.dim) 1007 | else: 1008 | self.proj_in = nn.Sequential( 1009 | Linear(self.num_channels, self.dim * num_freq_tokens), 1010 | Rearrange('b n (f d) -> b f n d', f = num_freq_tokens) 1011 | ) 1012 | 1013 | # tokenizer and text embed 1014 | 1015 | if callable(tokenizer): 1016 | assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function' 1017 | self.tokenizer = tokenizer 1018 | elif tokenizer == 'char_utf8': 1019 | text_num_embeds = 256 1020 | self.tokenizer = list_str_to_tensor 1021 | elif tokenizer == 'phoneme_en': 1022 | self.tokenizer, text_num_embeds = get_g2p_en_encode() 1023 | else: 1024 | raise ValueError(f'unknown tokenizer string {tokenizer}') 1025 | 1026 | self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs) 1027 | 1028 | # maybe reduce frequencies 1029 | 1030 | self.maybe_reduce_freq_axis = Reduce('b f n d -> b n d', 'mean') if self.has_freq_axis else nn.Identity() 1031 | 1032 | # to prediction 1033 | # applying https://arxiv.org/abs/2403.03950 1034 | 1035 | self.hl_gauss_layer = HLGaussLayer( 1036 | dim, 1037 | hl_gauss_loss = hl_gauss_loss, 1038 | use_regression = use_regression, 1039 | regress_activation = nn.Softplus() 1040 | ) 1041 | 1042 | def forward( 1043 | self, 1044 | x: Float['b n d'] | Float['b nw'], 1045 | *, 1046 | text: Int['b nt'] | list[str] | None = None, 1047 | lens: Int['b'] | None = None, 1048 | return_loss = True 1049 | ): 1050 | # raw wave 1051 | 1052 | if x.ndim == 2: 1053 | x = self.mel_spec(x) 1054 | x = rearrange(x, 'b d n -> b n d') 1055 | assert x.shape[-1] == self.dim 1056 | 1057 | x = self.proj_in(x) 1058 | 1059 | batch, seq_len, device = x.shape[0], x.shape[-2], x.device 1060 | 1061 | # text 1062 | 1063 | text_embed = None 1064 | 1065 | if exists(text): 1066 | if isinstance(text, list): 1067 | text = list_str_to_tensor(text).to(device) 1068 | assert text.shape[0] == batch 1069 | 1070 | text_embed = self.embed_text(text, seq_len) 1071 | 1072 | # handle lengths (duration) 1073 | 1074 | if not exists(lens): 1075 | lens = torch.full((batch,), seq_len, device = device) 1076 | 1077 | mask = lens_to_mask(lens, length = seq_len) 1078 | 1079 | # if returning a loss, mask out randomly from an index and have it predict the duration 1080 | 1081 | if return_loss: 1082 | rand_frac_index = x.new_zeros(batch).uniform_(0, 1) 1083 | rand_index = (rand_frac_index * lens).long() 1084 | 1085 | seq = torch.arange(seq_len, device = device) 1086 | mask &= einx.less('n, b -> b 
n', seq, rand_index) 1087 | 1088 | # attending 1089 | 1090 | embed = self.transformer( 1091 | x, 1092 | mask = mask, 1093 | text_embed = text_embed, 1094 | ) 1095 | 1096 | # maybe reduce freq 1097 | 1098 | embed = self.maybe_reduce_freq_axis(embed) 1099 | 1100 | # masked mean 1101 | 1102 | pooled_embed = maybe_masked_mean(embed, mask) 1103 | 1104 | # return the prediction if not returning loss 1105 | 1106 | if not return_loss: 1107 | return self.hl_gauss_layer(pooled_embed) 1108 | 1109 | # loss 1110 | 1111 | loss = self.hl_gauss_layer(pooled_embed, lens.float()) 1112 | 1113 | return loss 1114 | 1115 | class E2TTS(Module): 1116 | 1117 | @beartype 1118 | def __init__( 1119 | self, 1120 | transformer: dict | Transformer = None, 1121 | duration_predictor: dict | DurationPredictor | None = None, 1122 | odeint_kwargs: dict = dict( 1123 | atol = 1e-5, 1124 | rtol = 1e-5, 1125 | method = 'midpoint' 1126 | ), 1127 | cond_drop_prob = 0.25, 1128 | num_channels = None, 1129 | mel_spec_module: Module | None = None, 1130 | num_freq_tokens = 1, 1131 | char_embed_kwargs: dict = dict(), 1132 | mel_spec_kwargs: dict = dict(), 1133 | frac_lengths_mask: tuple[float, float] = (0.7, 1.), 1134 | concat_cond = False, 1135 | interpolated_text = False, 1136 | text_num_embeds: int | None = None, 1137 | tokenizer: ( 1138 | Literal['char_utf8', 'phoneme_en'] | 1139 | Callable[[list[str]], Int['b nt']] 1140 | ) = 'char_utf8', 1141 | use_vocos = True, 1142 | pretrained_vocos_path = 'charactr/vocos-mel-24khz', 1143 | sampling_rate: int | None = None, 1144 | velocity_consistency_weight = 0., 1145 | ): 1146 | super().__init__() 1147 | 1148 | # freq axis hparams 1149 | 1150 | assert num_freq_tokens > 0 1151 | self.num_freq_tokens = num_freq_tokens 1152 | self.has_freq_axis = num_freq_tokens > 1 1153 | 1154 | # set transformer 1155 | 1156 | if isinstance(transformer, dict): 1157 | set_if_missing_key(transformer, 'has_freq_axis', self.has_freq_axis) 1158 | 1159 | transformer = Transformer( 1160 | **transformer, 1161 | cond_on_time = True 1162 | ) 1163 | 1164 | assert transformer.has_freq_axis == self.has_freq_axis 1165 | self.transformer = transformer 1166 | 1167 | # duration predictor 1168 | 1169 | if isinstance(duration_predictor, dict): 1170 | duration_predictor = DurationPredictor(**duration_predictor) 1171 | 1172 | # hparams 1173 | 1174 | dim = transformer.dim 1175 | dim_text = transformer.dim_text 1176 | 1177 | self.dim = dim 1178 | self.dim_text = dim_text 1179 | 1180 | self.frac_lengths_mask = frac_lengths_mask 1181 | 1182 | self.duration_predictor = duration_predictor 1183 | 1184 | # sampling 1185 | 1186 | self.odeint_kwargs = odeint_kwargs 1187 | 1188 | # mel spec 1189 | 1190 | self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) 1191 | num_channels = default(num_channels, self.mel_spec.n_mel_channels) 1192 | 1193 | self.num_channels = num_channels 1194 | self.sampling_rate = default(sampling_rate, getattr(self.mel_spec, 'sampling_rate', None)) 1195 | 1196 | # whether to concat condition and project rather than project both and sum 1197 | 1198 | self.concat_cond = concat_cond 1199 | 1200 | if concat_cond: 1201 | self.proj_in = nn.Linear(num_channels * 2, dim * num_freq_tokens) 1202 | else: 1203 | self.proj_in = nn.Linear(num_channels, dim * num_freq_tokens) 1204 | self.cond_proj_in = nn.Linear(num_channels, dim * num_freq_tokens) 1205 | 1206 | # maybe split out frequency 1207 | 1208 | self.maybe_split_freq = Rearrange('b n (f d) -> b f n d', f = num_freq_tokens) if self.has_freq_axis else 
nn.Identity() 1209 | 1210 | self.maybe_reduce_freq = Reduce('b f n d -> b n d', 'mean') if self.has_freq_axis else nn.Identity() 1211 | 1212 | # to prediction 1213 | 1214 | self.to_pred = Linear(dim, num_channels) 1215 | 1216 | # tokenizer and text embed 1217 | 1218 | if callable(tokenizer): 1219 | assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function' 1220 | self.tokenizer = tokenizer 1221 | elif tokenizer == 'char_utf8': 1222 | text_num_embeds = 256 1223 | self.tokenizer = list_str_to_tensor 1224 | elif tokenizer == 'phoneme_en': 1225 | self.tokenizer, text_num_embeds = get_g2p_en_encode() 1226 | else: 1227 | raise ValueError(f'unknown tokenizer string {tokenizer}') 1228 | 1229 | self.cond_drop_prob = cond_drop_prob 1230 | 1231 | # text embedding 1232 | 1233 | text_embed_klass = CharacterEmbed if not interpolated_text else InterpolatedCharacterEmbed 1234 | 1235 | self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs) 1236 | 1237 | # weight for velocity consistency 1238 | 1239 | self.register_buffer('zero', tensor(0.), persistent = False) 1240 | self.velocity_consistency_weight = velocity_consistency_weight 1241 | 1242 | # default vocos for mel -> audio 1243 | 1244 | self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None 1245 | 1246 | @property 1247 | def device(self): 1248 | return next(self.parameters()).device 1249 | 1250 | def transformer_with_pred_head( 1251 | self, 1252 | x: Float['b n d'], 1253 | cond: Float['b n d'], 1254 | times: Float['b'], 1255 | mask: Bool['b n'] | None = None, 1256 | text: Int['b nt'] | None = None, 1257 | drop_text_cond: bool | None = None, 1258 | return_drop_text_cond = False 1259 | ): 1260 | seq_len = x.shape[-2] 1261 | drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob) 1262 | 1263 | if self.concat_cond: 1264 | # concat condition, given as using voicebox-like scheme 1265 | x = torch.cat((cond, x), dim = -1) 1266 | 1267 | x = self.proj_in(x) 1268 | x = self.maybe_split_freq(x) 1269 | 1270 | if not self.concat_cond: 1271 | # an alternative is to simply sum the condition 1272 | # seems to work fine 1273 | 1274 | cond = self.cond_proj_in(cond) 1275 | cond = self.maybe_split_freq(cond) 1276 | 1277 | x = x + cond 1278 | 1279 | # whether to use a text embedding 1280 | 1281 | text_embed = None 1282 | if exists(text) and not drop_text_cond: 1283 | text_embed = self.embed_text(text, seq_len, mask = mask) 1284 | 1285 | # attend 1286 | 1287 | embed = self.transformer( 1288 | x, 1289 | times = times, 1290 | mask = mask, 1291 | text_embed = text_embed 1292 | ) 1293 | 1294 | embed = self.maybe_reduce_freq(embed) 1295 | 1296 | pred = self.to_pred(embed) 1297 | 1298 | if not return_drop_text_cond: 1299 | return pred 1300 | 1301 | return pred, drop_text_cond 1302 | 1303 | def cfg_transformer_with_pred_head( 1304 | self, 1305 | *args, 1306 | cfg_strength: float = 1., 1307 | cfg_null_model: E2TTS | None = None, 1308 | remove_parallel_component: bool = True, 1309 | keep_parallel_frac: float = 0., 1310 | **kwargs, 1311 | ): 1312 | 1313 | pred = self.transformer_with_pred_head(*args, drop_text_cond = False, **kwargs) 1314 | 1315 | if cfg_strength < 1e-5: 1316 | return pred 1317 | 1318 | null_drop_text_cond = not exists(cfg_null_model) 1319 | cfg_null_model = default(cfg_null_model, self) 1320 | 1321 | null_pred = cfg_null_model.transformer_with_pred_head(*args, drop_text_cond = null_drop_text_cond, **kwargs) 1322 | 
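# classifier free guidance: the update direction is the difference between the text
# conditioned prediction and the "null" prediction (text dropped, or a separate
# weaker model when doing autoguidance); it is optionally stripped of its component
# parallel to the prediction before being scaled by cfg_strength and added back
#
# illustrative sketch (not part of the library), ignoring the parallel projection:
#   guided = pred + (pred - null_pred) * cfg_strength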
1323 | cfg_update = pred - null_pred 1324 | 1325 | if remove_parallel_component: 1326 | # https://arxiv.org/abs/2410.02416 1327 | parallel, orthogonal = project(cfg_update, pred) 1328 | cfg_update = orthogonal + parallel * keep_parallel_frac 1329 | 1330 | return pred + cfg_update * cfg_strength 1331 | 1332 | @torch.no_grad() 1333 | def sample( 1334 | self, 1335 | cond: Float['b n d'] | Float['b nw'], 1336 | *, 1337 | text: Int['b nt'] | list[str] | None = None, 1338 | lens: Int['b'] | None = None, 1339 | duration: int | Int['b'] | None = None, 1340 | steps = 32, 1341 | cfg_strength = 1., # they used a classifier free guidance strength of 1. 1342 | cfg_null_model: E2TTS | None = None, # for "autoguidance" from Karras et al. https://arxiv.org/abs/2406.02507 1343 | max_duration = 4096, # in case the duration predictor goes haywire 1344 | vocoder: Callable[[Float['b d n']], list[Float['_']]] | None = None, 1345 | return_raw_output: bool | None = None, 1346 | save_to_filename: str | None = None 1347 | ) -> ( 1348 | Float['b n d'], 1349 | list[Float['_']] 1350 | ): 1351 | self.eval() 1352 | 1353 | # raw wave 1354 | 1355 | if cond.ndim == 2: 1356 | cond = self.mel_spec(cond) 1357 | cond = rearrange(cond, 'b d n -> b n d') 1358 | assert cond.shape[-1] == self.num_channels 1359 | 1360 | batch, cond_seq_len, device = *cond.shape[:2], cond.device 1361 | 1362 | if not exists(lens): 1363 | lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long) 1364 | 1365 | # text 1366 | 1367 | if isinstance(text, list): 1368 | text = self.tokenizer(text).to(device) 1369 | assert text.shape[0] == batch 1370 | 1371 | if exists(text): 1372 | text_lens = (text != -1).sum(dim = -1) 1373 | lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters 1374 | 1375 | # duration 1376 | 1377 | cond_mask = lens_to_mask(lens) 1378 | 1379 | if exists(duration): 1380 | if isinstance(duration, int): 1381 | duration = torch.full((batch,), duration, device = device, dtype = torch.long) 1382 | 1383 | elif exists(self.duration_predictor): 1384 | duration = self.duration_predictor(cond, text = text, lens = lens, return_loss = False).long() 1385 | 1386 | duration = torch.maximum(lens + 1, duration) # just add one token so something is generated 1387 | duration = duration.clamp(max = max_duration) 1388 | 1389 | assert duration.shape[0] == batch 1390 | 1391 | max_duration = duration.amax() 1392 | 1393 | cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.) 1394 | cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False) 1395 | cond_mask = rearrange(cond_mask, '... -> ... 
1') 1396 | 1397 | mask = lens_to_mask(duration) 1398 | 1399 | # neural ode 1400 | 1401 | def fn(t, x): 1402 | # at each step, conditioning is fixed 1403 | 1404 | step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) 1405 | 1406 | # predict flow 1407 | 1408 | return self.cfg_transformer_with_pred_head( 1409 | x, 1410 | step_cond, 1411 | times = t, 1412 | text = text, 1413 | mask = mask, 1414 | cfg_strength = cfg_strength, 1415 | cfg_null_model = cfg_null_model 1416 | ) 1417 | 1418 | y0 = torch.randn_like(cond) 1419 | t = torch.linspace(0, 1, steps, device = self.device) 1420 | 1421 | trajectory = odeint(fn, y0, t, **self.odeint_kwargs) 1422 | sampled = trajectory[-1] 1423 | 1424 | out = sampled 1425 | 1426 | out = torch.where(cond_mask, cond, out) 1427 | 1428 | # able to return raw untransformed output, if not using mel rep 1429 | 1430 | if exists(return_raw_output) and return_raw_output: 1431 | return out 1432 | 1433 | # take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on 1434 | 1435 | if exists(vocoder): 1436 | assert not exists(self.vocos), '`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling' 1437 | out = rearrange(out, 'b n d -> b d n') 1438 | out = vocoder(out) 1439 | 1440 | elif exists(self.vocos): 1441 | 1442 | audio = [] 1443 | for mel, one_mask in zip(out, mask): 1444 | one_out = DB_to_amplitude(mel[one_mask], ref = 1., power = 0.5) 1445 | 1446 | one_out = rearrange(one_out, 'n d -> 1 d n') 1447 | one_audio = self.vocos.decode(one_out) 1448 | one_audio = rearrange(one_audio, '1 nw -> nw') 1449 | audio.append(one_audio) 1450 | 1451 | out = audio 1452 | 1453 | if exists(save_to_filename): 1454 | assert exists(vocoder) or exists(self.vocos) 1455 | assert exists(self.sampling_rate) 1456 | 1457 | path = Path(save_to_filename) 1458 | parent_path = path.parents[0] 1459 | parent_path.mkdir(exist_ok = True, parents = True) 1460 | 1461 | for ind, one_audio in enumerate(out): 1462 | one_audio = rearrange(one_audio, 'nw -> 1 nw') 1463 | save_path = str(parent_path / f'{ind + 1}.{path.name}') 1464 | torchaudio.save(save_path, one_audio.detach().cpu(), sample_rate = self.sampling_rate) 1465 | 1466 | return out 1467 | 1468 | def forward( 1469 | self, 1470 | inp: Float['b n d'] | Float['b nw'], # mel or raw wave 1471 | *, 1472 | text: Int['b nt'] | list[str] | None = None, 1473 | times: Int['b'] | None = None, 1474 | lens: Int['b'] | None = None, 1475 | velocity_consistency_model: E2TTS | None = None, 1476 | velocity_consistency_delta = 1e-5 1477 | ): 1478 | need_velocity_loss = exists(velocity_consistency_model) and self.velocity_consistency_weight > 0. 
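# velocity consistency: when a frozen reference model (typically the EMA copy) is passed in
# and `velocity_consistency_weight` is positive, an auxiliary loss matches this model's
# predicted flow at time t against the reference model's prediction at t + velocity_consistency_delta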
1479 | 1480 | # handle raw wave 1481 | 1482 | if inp.ndim == 2: 1483 | inp = self.mel_spec(inp) 1484 | inp = rearrange(inp, 'b d n -> b n d') 1485 | assert inp.shape[-1] == self.num_channels 1486 | 1487 | batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device 1488 | 1489 | # handle text as string 1490 | 1491 | if isinstance(text, list): 1492 | text = self.tokenizer(text).to(device) 1493 | assert text.shape[0] == batch 1494 | 1495 | # lens and mask 1496 | 1497 | if not exists(lens): 1498 | lens = torch.full((batch,), seq_len, device = device) 1499 | 1500 | mask = lens_to_mask(lens, length = seq_len) 1501 | 1502 | # get a random span to mask out for training conditionally 1503 | 1504 | frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask) 1505 | rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length = seq_len) 1506 | 1507 | if exists(mask): 1508 | rand_span_mask &= mask 1509 | 1510 | # mel is x1 1511 | 1512 | x1 = inp 1513 | 1514 | # main conditional flow training logic 1515 | # just ~5 loc 1516 | 1517 | # x0 is gaussian noise 1518 | 1519 | x0 = torch.randn_like(x1) 1520 | 1521 | # t is random times from above 1522 | 1523 | times = torch.rand((batch,), dtype = dtype, device = self.device) 1524 | t = rearrange(times, 'b -> b 1 1') 1525 | 1526 | # if need velocity consistency, make sure time does not exceed 1. 1527 | 1528 | if need_velocity_loss: 1529 | t = t * (1. - velocity_consistency_delta) 1530 | 1531 | # sample xt (w in the paper) 1532 | 1533 | w = (1. - t) * x0 + t * x1 1534 | 1535 | flow = x1 - x0 1536 | 1537 | # only predict what is within the random mask span for infilling 1538 | 1539 | cond = einx.where( 1540 | 'b n, b n d, b n d -> b n d', 1541 | rand_span_mask, 1542 | torch.zeros_like(x1), x1 1543 | ) 1544 | 1545 | # transformer and prediction head 1546 | 1547 | pred, did_drop_text_cond = self.transformer_with_pred_head( 1548 | w, 1549 | cond, 1550 | times = times, 1551 | text = text, 1552 | mask = mask, 1553 | return_drop_text_cond = True 1554 | ) 1555 | 1556 | # maybe velocity consistency loss 1557 | 1558 | velocity_loss = self.zero 1559 | 1560 | if need_velocity_loss: 1561 | 1562 | t_with_delta = t + velocity_consistency_delta 1563 | w_with_delta = (1. 
- t_with_delta) * x0 + t_with_delta * x1 1564 | 1565 | with torch.no_grad(): 1566 | ema_pred = velocity_consistency_model.transformer_with_pred_head( 1567 | w_with_delta, 1568 | cond, 1569 | times = times + velocity_consistency_delta, 1570 | text = text, 1571 | mask = mask, 1572 | drop_text_cond = did_drop_text_cond 1573 | ) 1574 | 1575 | velocity_loss = F.mse_loss(pred, ema_pred, reduction = 'none') 1576 | velocity_loss = velocity_loss[rand_span_mask].mean() 1577 | 1578 | # flow matching loss 1579 | 1580 | loss = F.mse_loss(pred, flow, reduction = 'none') 1581 | 1582 | loss = loss[rand_span_mask].mean() 1583 | 1584 | # total loss and get breakdown 1585 | 1586 | total_loss = ( 1587 | loss + 1588 | velocity_loss * self.velocity_consistency_weight 1589 | ) 1590 | 1591 | breakdown = LossBreakdown(loss, velocity_loss) 1592 | 1593 | # return total loss and bunch of intermediates 1594 | 1595 | return E2TTSReturn(total_loss, cond, pred, x0 + pred, breakdown) 1596 | -------------------------------------------------------------------------------- /e2_tts_pytorch/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from tqdm import tqdm 5 | import matplotlib 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.optim import Optimizer 12 | from torch.utils.data import DataLoader, Dataset 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torch.optim.lr_scheduler import LinearLR, SequentialLR 15 | 16 | import torchaudio 17 | 18 | from einops import rearrange 19 | 20 | from accelerate import Accelerator 21 | from accelerate.utils import DistributedDataParallelKwargs 22 | 23 | from adam_atan2_pytorch.adopt import Adopt 24 | 25 | from ema_pytorch import EMA 26 | 27 | from loguru import logger 28 | 29 | from e2_tts_pytorch.e2_tts import ( 30 | E2TTS, 31 | DurationPredictor, 32 | MelSpec 33 | ) 34 | 35 | def exists(v): 36 | return v is not None 37 | 38 | def default(v, d): 39 | return v if exists(v) else d 40 | 41 | def to_numpy(t): 42 | return t.detach().cpu().numpy() 43 | 44 | # plot spectrogram 45 | 46 | def plot_spectrogram(spectrogram): 47 | spectrogram = to_numpy(spectrogram) 48 | fig, ax = plt.subplots(figsize=(10, 4)) 49 | im = ax.imshow(spectrogram.T, aspect="auto", origin="lower", interpolation="none") 50 | plt.colorbar(im, ax=ax) 51 | plt.xlabel("Frames") 52 | plt.ylabel("Channels") 53 | plt.tight_layout() 54 | 55 | fig.canvas.draw() 56 | plt.close() 57 | return fig 58 | 59 | # collation 60 | 61 | def collate_fn(batch): 62 | mel_specs = [item['mel_spec'].squeeze(0) for item in batch] 63 | mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs]) 64 | max_mel_length = mel_lengths.amax() 65 | 66 | padded_mel_specs = [] 67 | for spec in mel_specs: 68 | padding = (0, max_mel_length - spec.size(-1)) 69 | padded_spec = F.pad(spec, padding, value = 0) 70 | padded_mel_specs.append(padded_spec) 71 | 72 | mel_specs = torch.stack(padded_mel_specs) 73 | 74 | text = [item['text'] for item in batch] 75 | text_lengths = torch.LongTensor([len(item) for item in text]) 76 | 77 | return dict( 78 | mel = mel_specs, 79 | mel_lengths = mel_lengths, 80 | text = text, 81 | text_lengths = text_lengths, 82 | ) 83 | 84 | # dataset 85 | 86 | class HFDataset(Dataset): 87 | def __init__( 88 | self, 89 | hf_dataset: Dataset, 90 | target_sample_rate = 24_000, 91 | hop_length = 256 92 | ): 93 | self.data = hf_dataset 94 | 
self.target_sample_rate = target_sample_rate 95 | self.hop_length = hop_length 96 | self.mel_spectrogram = MelSpec(sampling_rate=target_sample_rate) 97 | 98 | def __len__(self): 99 | return len(self.data) 100 | 101 | def __getitem__(self, index): 102 | row = self.data[index] 103 | audio = row['audio']['array'] 104 | 105 | logger.info(f"Audio shape: {audio.shape}") 106 | 107 | sample_rate = row['audio']['sampling_rate'] 108 | duration = audio.shape[-1] / sample_rate 109 | 110 | if duration > 20 or duration < 0.3: 111 | logger.warning(f"Skipping due to duration out of bound: {duration}") 112 | return self.__getitem__((index + 1) % len(self.data)) 113 | 114 | audio_tensor = torch.from_numpy(audio).float() 115 | 116 | if sample_rate != self.target_sample_rate: 117 | resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) 118 | audio_tensor = resampler(audio_tensor) 119 | 120 | audio_tensor = rearrange(audio_tensor, 't -> 1 t') 121 | 122 | mel_spec = self.mel_spectrogram(audio_tensor) 123 | 124 | mel_spec = rearrange(mel_spec, '1 d t -> d t') 125 | 126 | text = row['transcript'] 127 | 128 | return dict( 129 | mel_spec = mel_spec, 130 | text = text, 131 | ) 132 | 133 | # trainer 134 | 135 | class E2Trainer: 136 | def __init__( 137 | self, 138 | model: E2TTS, 139 | optimizer: Optimizer | None = None, 140 | learning_rate = 7.5e-5, 141 | num_warmup_steps = 20000, 142 | grad_accumulation_steps = 1, 143 | duration_predictor: DurationPredictor | None = None, 144 | checkpoint_path = None, 145 | log_file = "logs.txt", 146 | max_grad_norm = 1.0, 147 | sample_rate = 22050, 148 | tensorboard_log_dir = 'runs/e2_tts_experiment', 149 | accelerate_kwargs: dict = dict(), 150 | ema_kwargs: dict = dict(), 151 | use_switch_ema = False 152 | ): 153 | logger.add(log_file) 154 | 155 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True) 156 | 157 | self.accelerator = Accelerator( 158 | log_with = "all", 159 | kwargs_handlers = [ddp_kwargs], 160 | gradient_accumulation_steps = grad_accumulation_steps, 161 | **accelerate_kwargs 162 | ) 163 | 164 | self.target_sample_rate = sample_rate 165 | 166 | self.model = model 167 | 168 | self.need_velocity_consistent_loss = model.velocity_consistency_weight > 0. 
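# exponential moving average copy of the model - used as the frozen reference for the
# velocity consistency loss in the train loop, and optionally switched into the online
# model at the end of training when `use_switch_ema` is set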
169 | 170 | self.ema_model = EMA( 171 | model, 172 | include_online_model = False, 173 | **ema_kwargs 174 | ) 175 | 176 | self.use_switch_ema = use_switch_ema 177 | 178 | self.duration_predictor = duration_predictor 179 | 180 | # optimizer 181 | 182 | if not exists(optimizer): 183 | optimizer = Adopt(model.parameters(), lr = learning_rate) 184 | 185 | self.optimizer = optimizer 186 | 187 | self.num_warmup_steps = num_warmup_steps 188 | self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate) 189 | 190 | self.ema_model, self.model, self.optimizer = self.accelerator.prepare( 191 | self.ema_model, self.model, self.optimizer 192 | ) 193 | self.max_grad_norm = max_grad_norm 194 | 195 | self.checkpoint_path = default(checkpoint_path, 'model.pth') 196 | self.writer = SummaryWriter(log_dir=tensorboard_log_dir) 197 | 198 | @property 199 | def is_main(self): 200 | return self.accelerator.is_main_process 201 | 202 | def save_checkpoint(self, step, finetune=False): 203 | self.accelerator.wait_for_everyone() 204 | if self.is_main: 205 | checkpoint = dict( 206 | model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(), 207 | optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(), 208 | ema_model_state_dict = self.ema_model.state_dict(), 209 | scheduler_state_dict = self.scheduler.state_dict(), 210 | step = step 211 | ) 212 | 213 | self.accelerator.save(checkpoint, self.checkpoint_path) 214 | 215 | def load_checkpoint(self): 216 | if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path): 217 | return 0 218 | 219 | checkpoint = torch.load(self.checkpoint_path) 220 | self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict']) 221 | self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict']) 222 | 223 | if self.is_main: 224 | self.ema_model.load_state_dict(checkpoint['ema_model_state_dict']) 225 | 226 | if self.scheduler: 227 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 228 | return checkpoint['step'] 229 | 230 | def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=1000): 231 | 232 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers, pin_memory=True) 233 | total_steps = len(train_dataloader) * epochs 234 | decay_steps = total_steps - self.num_warmup_steps 235 | warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=self.num_warmup_steps) 236 | decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) 237 | self.scheduler = SequentialLR(self.optimizer, 238 | schedulers=[warmup_scheduler, decay_scheduler], 239 | milestones=[self.num_warmup_steps]) 240 | train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) 241 | start_step = self.load_checkpoint() 242 | global_step = start_step 243 | 244 | for epoch in range(epochs): 245 | self.model.train() 246 | progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}", unit="step", disable=not self.accelerator.is_local_main_process) 247 | epoch_loss = 0.0 248 | 249 | for batch in progress_bar: 250 | with self.accelerator.accumulate(self.model): 251 | text_inputs = batch['text'] 252 | mel_spec = rearrange(batch['mel'], 'b d n -> b n d') 253 | mel_lengths = batch["mel_lengths"] 254 | 255 | if exists(self.duration_predictor): 256 | dur_loss = 
self.duration_predictor(mel_spec, lens=batch.get('durations')) 257 | self.writer.add_scalar('duration loss', dur_loss.item(), global_step) 258 | 259 | velocity_consistency_model = None 260 | if self.need_velocity_consistent_loss and self.ema_model.initted: 261 | velocity_consistency_model = self.accelerator.unwrap_model(self.ema_model).ema_model 262 | 263 | loss, cond, pred, pred_data, breakdown = self.model( 264 | mel_spec, 265 | text=text_inputs, 266 | lens=mel_lengths, 267 | velocity_consistency_model=velocity_consistency_model 268 | ) 269 | 270 | self.accelerator.backward(loss) 271 | 272 | if self.max_grad_norm > 0 and self.accelerator.sync_gradients: 273 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 274 | 275 | self.optimizer.step() 276 | self.scheduler.step() 277 | self.optimizer.zero_grad() 278 | 279 | self.accelerator.unwrap_model(self.ema_model).update() 280 | 281 | if self.accelerator.is_local_main_process: 282 | logger.info(f"step {global_step+1}: loss = {loss.item():.4f}") 283 | self.writer.add_scalar('loss', loss.item(), global_step) 284 | self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step) 285 | 286 | global_step += 1 287 | epoch_loss += loss.item() 288 | progress_bar.set_postfix(loss=loss.item()) 289 | 290 | if global_step % save_step == 0: 291 | self.save_checkpoint(global_step) 292 | self.writer.add_figure("mel/target", plot_spectrogram(mel_spec[0,:,:]), global_step) 293 | self.writer.add_figure("mel/mask", plot_spectrogram(cond[0,:,:]), global_step) 294 | self.writer.add_figure("mel/prediction", plot_spectrogram(pred_data[0,:,:]), global_step) 295 | 296 | epoch_loss /= len(train_dataloader) 297 | if self.accelerator.is_local_main_process: 298 | logger.info(f"epoch {epoch+1}/{epochs} - average loss = {epoch_loss:.4f}") 299 | self.writer.add_scalar('epoch average loss', epoch_loss, epoch) 300 | 301 | if self.use_switch_ema: 302 | self.ema_model.update_model_with_ema() 303 | 304 | self.writer.close() 305 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "e2-tts-pytorch" 3 | version = "2.2.1" 4 | description = "E2-TTS in Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.8" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'transformers', 15 | 'attention mechanism', 16 | 'text to speech' 17 | ] 18 | classifiers=[ 19 | 'Development Status :: 4 - Beta', 20 | 'Intended Audience :: Developers', 21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Programming Language :: Python :: 3.8', 24 | ] 25 | 26 | dependencies = [ 27 | 'accelerate>=0.33.0', 28 | 'adam-atan2-pytorch>=0.1.12', 29 | 'beartype', 30 | 'einops>=0.8.0', 31 | 'einx>=0.3.0', 32 | 'ema-pytorch>=0.5.2', 33 | 'hl-gauss-pytorch>=0.1.7', 34 | 'hyper-connections>=0.0.10', 35 | 'g2p-en', 36 | 'jaxtyping', 37 | 'loguru', 38 | 'pydantic<2', 39 | 'tensorboard', 40 | 'torch>=2.0', 41 | 'torchdiffeq', 42 | 'torchaudio>=2.3.1', 43 | 'tqdm>=4.65.0', 44 | 'vocos', 45 | 'x-transformers>=1.42.23', 46 | ] 47 | 48 | [project.urls] 49 | Homepage = "https://pypi.org/project/e2-tts-pytorch/" 50 | Repository = "https://github.com/lucidrains/e2-tts-pytorch" 51 | 52 | [project.optional-dependencies] 53 | examples = 
["datasets"] 54 | 55 | [build-system] 56 | requires = ["hatchling"] 57 | build-backend = "hatchling.build" 58 | 59 | [tool.hatch.metadata] 60 | allow-direct-references = true 61 | 62 | [tool.hatch.build.targets.wheel] 63 | packages = ["e2_tts_pytorch"] 64 | -------------------------------------------------------------------------------- /train_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from e2_tts_pytorch import E2TTS, DurationPredictor 3 | 4 | from datasets import load_dataset 5 | 6 | from e2_tts_pytorch.trainer import ( 7 | HFDataset, 8 | E2Trainer 9 | ) 10 | 11 | duration_predictor = DurationPredictor( 12 | transformer = dict( 13 | dim = 512, 14 | depth = 6, 15 | ) 16 | ) 17 | 18 | e2tts = E2TTS( 19 | duration_predictor = duration_predictor, 20 | transformer = dict( 21 | dim = 512, 22 | depth = 12 23 | ), 24 | ) 25 | 26 | train_dataset = HFDataset(load_dataset("MushanW/GLOBE")["train"]) 27 | 28 | trainer = E2Trainer( 29 | e2tts, 30 | num_warmup_steps=20000, 31 | grad_accumulation_steps = 1, 32 | checkpoint_path = 'e2tts.pt', 33 | log_file = 'e2tts.txt' 34 | ) 35 | 36 | epochs = 10 37 | batch_size = 32 38 | 39 | trainer.train(train_dataset, epochs, batch_size, save_step=1000) 40 | --------------------------------------------------------------------------------