├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── images
│   ├── translated.png
│   └── voicebox.png
├── setup.py
└── voicebox_pytorch
    ├── __init__.py
    ├── attend.py
    ├── data.py
    ├── optimizer.py
    ├── trainer.py
    └── voicebox_pytorch.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Voicebox - Pytorch
4 |
5 | Implementation of Voicebox, the new SOTA text-to-speech model from MetaAI, in Pytorch.
6 |
7 | In this work, we will use rotary embeddings. The authors seem unaware that ALiBi cannot be straightforwardly used for bidirectional models.
8 |
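Below is a condensed sketch of how rotary embeddings are applied to the queries and keys (the full versions live in `voicebox_pytorch/voicebox_pytorch.py` as `RotaryEmbedding` and `apply_rotary_pos_emb`); the relative position falls out of the dot product between the rotated queries and keys, which is what makes the scheme usable for a bidirectional model.

```python
import torch

def rotate_half(x):
    # split the head dimension into two halves and rotate them
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(freqs, t):
    # freqs: (seq_len, dim_head), t: (batch, heads, seq_len, dim_head)
    return t * freqs.cos() + rotate_half(t) * freqs.sin()

dim_head, seq_len = 64, 128

# frequencies per position, same construction as RotaryEmbedding (theta = 50000)
inv_freq = 1.0 / (50000 ** (torch.arange(0, dim_head, 2).float() / dim_head))
positions = torch.arange(seq_len).float()
freqs = torch.einsum('i, j -> i j', positions, inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)  # (seq_len, dim_head)

q = torch.randn(2, 8, seq_len, dim_head)
k = torch.randn(2, 8, seq_len, dim_head)

q, k = map(lambda t: apply_rotary_pos_emb(freqs, t), (q, k))
```
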
9 | The paper also addresses the issue of the time embedding being incorrectly subjected to relative positional distances (they concatenate the time embedding along the frame dimension of the audio tokens). This repository instead uses adaptive normalization, as applied successfully in Paella.
10 |
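For reference, here is a condensed sketch of the adaptive RMSNorm used for time conditioning (the full module is `AdaptiveRMSNorm` in `voicebox_pytorch/voicebox_pytorch.py`): the flow-time embedding predicts a per-channel gain and bias applied after normalization, so the time signal never interacts with the relative positions of the audio frames.

```python
import torch
from torch import nn
import torch.nn.functional as F

class AdaptiveRMSNorm(nn.Module):
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.to_gamma = nn.Linear(cond_dim, dim)
        self.to_beta = nn.Linear(cond_dim, dim)

        # initialize so the module starts out as a plain RMSNorm
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.ones_(self.to_gamma.bias)
        nn.init.zeros_(self.to_beta.weight)
        nn.init.zeros_(self.to_beta.bias)

    def forward(self, x, cond):
        # x: (batch, frames, dim) audio features, cond: (batch, cond_dim) time embedding
        normed = F.normalize(x, dim = -1) * self.scale
        gamma = self.to_gamma(cond).unsqueeze(1)
        beta = self.to_beta(cond).unsqueeze(1)
        return normed * gamma + beta

x = torch.randn(2, 1024, 512)   # audio frame features
t_emb = torch.randn(2, 512)     # embedding of the flow matching time t
out = AdaptiveRMSNorm(512, 512)(x, t_emb)  # (2, 1024, 512)
```
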
11 | Update: Recommend you just use E2 TTS instead of this work
12 |
13 | ## Appreciation
14 |
15 | - Translated for awarding me the Imminent Grant to advance the state of open sourced text-to-speech solutions. This project was started and will be completed under this grant.
16 |
17 | - StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
18 |
19 | - Bryan Chiang for the ongoing code review, sharing his expertise on TTS, and pointing me to an open sourced implementation of conditional flow matching
20 |
21 | - Manmay for getting the repository started with the alignment code
22 |
23 | - @chenht2010 for finding a bug with rotary positions, and for validating that the code in the repository converges
24 |
25 | - Lucas Newman for (yet again) pull requesting all the training code for Spear-TTS conditioned Voicebox training!
26 |
27 | - Lucas Newman has demonstrated that the whole system works with Spear-TTS conditioning. Training converges even better than Soundstorm
28 |
29 | ## Install
30 |
31 | ```bash
32 | $ pip install voicebox-pytorch
33 | ```
34 |
35 | ## Usage
36 |
37 | Training and sampling with the `TextToSemantic` module from SpearTTS
38 |
39 | ```python
40 | import torch
41 |
42 | from voicebox_pytorch import (
43 | VoiceBox,
44 | EncodecVoco,
45 | ConditionalFlowMatcherWrapper,
46 | HubertWithKmeans,
47 | TextToSemantic
48 | )
49 |
50 | # https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
51 |
52 | wav2vec = HubertWithKmeans(
53 | checkpoint_path = '/path/to/hubert/checkpoint.pt',
54 | kmeans_path = '/path/to/hubert/kmeans.bin'
55 | )
56 |
57 | text_to_semantic = TextToSemantic(
58 | wav2vec = wav2vec,
59 | dim = 512,
60 | source_depth = 1,
61 | target_depth = 1,
62 | use_openai_tokenizer = True
63 | )
64 |
65 | text_to_semantic.load('/path/to/trained/spear-tts/model.pt')
66 |
67 | model = VoiceBox(
68 | dim = 512,
69 | audio_enc_dec = EncodecVoco(),
70 | num_cond_tokens = 500,
71 | depth = 2,
72 | dim_head = 64,
73 | heads = 16
74 | )
75 |
76 | cfm_wrapper = ConditionalFlowMatcherWrapper(
77 | voicebox = model,
78 | text_to_semantic = text_to_semantic
79 | )
80 |
81 | # mock data
82 |
83 | audio = torch.randn(2, 12000)
84 |
85 | # train
86 |
87 | loss = cfm_wrapper(audio)
88 | loss.backward()
89 |
90 | # after much training
91 |
92 | texts = [
93 | 'the rain in spain falls mainly in the plains',
94 | 'she sells sea shells by the seashore'
95 | ]
96 |
97 | cond = torch.randn(2, 12000)
98 | sampled = cfm_wrapper.sample(cond = cond, texts = texts) # (2, 1, <audio length>)
99 | ```
100 |
101 | For unconditional training, `condition_on_text` on `VoiceBox` must be set to `False`
102 |
103 | ```python
104 | import torch
105 | from voicebox_pytorch import (
106 | VoiceBox,
107 | ConditionalFlowMatcherWrapper
108 | )
109 |
110 | model = VoiceBox(
111 | dim = 512,
112 | num_cond_tokens = 500,
113 | depth = 2,
114 | dim_head = 64,
115 | heads = 16,
116 | condition_on_text = False
117 | )
118 |
119 | cfm_wrapper = ConditionalFlowMatcherWrapper(
120 | voicebox = model
121 | )
122 |
123 | # mock data
124 |
125 | x = torch.randn(2, 1024, 512)
126 |
127 | # train
128 |
129 | loss = cfm_wrapper(x)
130 |
131 | loss.backward()
132 |
133 | # after much training
134 |
135 | cond = torch.randn(2, 1024, 512)
136 |
137 | sampled = cfm_wrapper.sample(cond = cond) # (2, 1024, 512)
138 | ```
139 |
140 | ## Todo
141 |
142 | - [x] read and internalize original flow matching paper
143 | - [x] basic loss
144 | - [x] get neural ode working with torchdyn
145 | - [x] get basic mask generation logic with the p_drop of 0.2-0.3 for ICL
146 | - [x] take care of p_drop, different between voicebox and duration model
147 | - [x] support torchdiffeq and torchode
148 | - [x] switch to adaptive rmsnorm for time conditioning
149 | - [x] add encodec / voco for starters
150 | - [x] setup training and sampling with raw audio, if `audio_enc_dec` is passed in
151 | - [x] integrate with log mel spec / encodec - vocos
152 | - [x] spear-tts-integration
153 | - [x] basic accelerate trainer - thanks to @lucasnewman!
154 |
155 | - [ ] cleanup NS2 aligner class and then setup duration predictor training
156 | - [ ] figure out the correct settings for `MelVoco` encode, as the reconstructed audio is longer in length
157 | - [ ] calculate how many seconds corresponds to each frame and add as property on `AudioEncoderDecoder` - when sampling, allow for specifying in seconds
158 |
159 | ## Citations
160 |
161 | ```bibtex
162 | @article{Le2023VoiceboxTM,
163 | title = {Voicebox: Text-Guided Multilingual Universal Speech Generation at Scale},
164 | author = {Matt Le and Apoorv Vyas and Bowen Shi and Brian Karrer and Leda Sari and Rashel Moritz and Mary Williamson and Vimal Manohar and Yossi Adi and Jay Mahadeokar and Wei-Ning Hsu},
165 | journal = {ArXiv},
166 | year = {2023},
167 | volume = {abs/2306.15687},
168 | url = {https://api.semanticscholar.org/CorpusID:259275061}
169 | }
170 | ```
171 |
172 | ```bibtex
173 | @inproceedings{dao2022flashattention,
174 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
175 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
176 | booktitle = {Advances in Neural Information Processing Systems},
177 | year = {2022}
178 | }
179 | ```
180 |
181 | ```bibtex
182 | @misc{torchdiffeq,
183 | author = {Chen, Ricky T. Q.},
184 | title = {torchdiffeq},
185 | year = {2018},
186 | url = {https://github.com/rtqichen/torchdiffeq},
187 | }
188 | ```
189 |
190 | ```bibtex
191 | @inproceedings{lienen2022torchode,
192 | title = {torchode: A Parallel {ODE} Solver for PyTorch},
193 | author = {Marten Lienen and Stephan G{\"u}nnemann},
194 | booktitle = {The Symbiosis of Deep Learning and Differential Equations II, NeurIPS},
195 | year = {2022},
196 | url = {https://openreview.net/forum?id=uiKVKTiUYB0}
197 | }
198 | ```
199 |
200 | ```bibtex
201 | @article{siuzdak2023vocos,
202 | title = {Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis},
203 | author = {Siuzdak, Hubert},
204 | journal = {arXiv preprint arXiv:2306.00814},
205 | year = {2023}
206 | }
207 | ```
208 |
209 | ```bibtex
210 | @misc{darcet2023vision,
211 | title = {Vision Transformers Need Registers},
212 | author = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
213 | year = {2023},
214 | eprint = {2309.16588},
215 | archivePrefix = {arXiv},
216 | primaryClass = {cs.CV}
217 | }
218 | ```
219 |
220 | ```bibtex
221 | @inproceedings{Dehghani2023ScalingVT,
222 | title = {Scaling Vision Transformers to 22 Billion Parameters},
223 | author = {Mostafa Dehghani and Josip Djolonga and Basil Mustafa and Piotr Padlewski and Jonathan Heek and Justin Gilmer and Andreas Steiner and Mathilde Caron and Robert Geirhos and Ibrahim M. Alabdulmohsin and Rodolphe Jenatton and Lucas Beyer and Michael Tschannen and Anurag Arnab and Xiao Wang and Carlos Riquelme and Matthias Minderer and Joan Puigcerver and Utku Evci and Manoj Kumar and Sjoerd van Steenkiste and Gamaleldin F. Elsayed and Aravindh Mahendran and Fisher Yu and Avital Oliver and Fantine Huot and Jasmijn Bastings and Mark Collier and Alexey A. Gritsenko and Vighnesh Birodkar and Cristina Nader Vasconcelos and Yi Tay and Thomas Mensink and Alexander Kolesnikov and Filip Paveti'c and Dustin Tran and Thomas Kipf and Mario Luvci'c and Xiaohua Zhai and Daniel Keysers and Jeremiah Harmsen and Neil Houlsby},
224 | booktitle = {International Conference on Machine Learning},
225 | year = {2023},
226 | url = {https://api.semanticscholar.org/CorpusID:256808367}
227 | }
228 | ```
229 |
230 | ```bibtex
231 | @inproceedings{Katsch2023GateLoopFD,
232 | title = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
233 | author = {Tobias Katsch},
234 | year = {2023},
235 | url = {https://api.semanticscholar.org/CorpusID:265018962}
236 | }
237 | ```
238 |
--------------------------------------------------------------------------------
/images/translated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/voicebox-pytorch/d115a997452f278190a2634be500a3db0da5db15/images/translated.png
--------------------------------------------------------------------------------
/images/voicebox.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/voicebox-pytorch/d115a997452f278190a2634be500a3db0da5db15/images/voicebox.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'voicebox-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.5.0',
7 | license='MIT',
8 | description = 'Voicebox - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/voicebox-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'text to speech'
17 | ],
18 | install_requires=[
19 | 'accelerate',
20 | 'audiolm-pytorch>=1.2.28',
21 | 'naturalspeech2-pytorch>=0.1.8',
22 | 'beartype',
23 | 'einops>=0.6.1',
24 | 'gateloop-transformer>=0.2.4',
25 | 'spear-tts-pytorch>=0.4.0',
26 | 'torch>=2.0',
27 | 'torchdiffeq',
28 | 'torchode',
29 | 'vocos'
30 | ],
31 | classifiers=[
32 | 'Development Status :: 4 - Beta',
33 | 'Intended Audience :: Developers',
34 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
35 | 'License :: OSI Approved :: MIT License',
36 | 'Programming Language :: Python :: 3.6',
37 | ],
38 | )
39 |
--------------------------------------------------------------------------------
/voicebox_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from voicebox_pytorch.voicebox_pytorch import (
2 | Transformer,
3 | EncodecVoco,
4 | VoiceBox,
5 | DurationPredictor,
6 | ConditionalFlowMatcherWrapper,
7 | )
8 |
9 | from voicebox_pytorch.trainer import (
10 | VoiceBoxTrainer
11 | )
12 |
13 | from spear_tts_pytorch import TextToSemantic
14 |
15 | from audiolm_pytorch import HubertWithKmeans
16 |
--------------------------------------------------------------------------------
/voicebox_pytorch/attend.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from packaging import version
3 | from collections import namedtuple
4 |
5 | import torch
6 | from torch import nn, einsum
7 | import torch.nn.functional as F
8 |
9 | from einops import rearrange, reduce
10 |
11 | # constants
12 |
13 | FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
14 |
15 | # helpers
16 |
17 | def exists(val):
18 | return val is not None
19 |
20 | def default(val, d):
21 | return val if exists(val) else d
22 |
23 | def once(fn):
24 | called = False
25 | @wraps(fn)
26 | def inner(x):
27 | nonlocal called
28 | if called:
29 | return
30 | called = True
31 | return fn(x)
32 | return inner
33 |
34 | print_once = once(print)
35 |
36 | # main class
37 |
38 | class Attend(nn.Module):
39 | def __init__(
40 | self,
41 | dropout = 0.,
42 | flash = False,
43 | scale = None
44 | ):
45 | super().__init__()
46 | self.dropout = dropout
47 | self.attn_dropout = nn.Dropout(dropout)
48 |
49 | self.scale = scale
50 |
51 | self.flash = flash
52 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
53 |
54 | # determine efficient attention configs for cuda and cpu
55 |
56 | self.cpu_config = FlashAttentionConfig(True, True, True)
57 | self.cuda_config = None
58 |
59 | if not torch.cuda.is_available() or not flash:
60 | return
61 |
62 | device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
63 |
64 | if device_properties.major == 8 and device_properties.minor == 0:
65 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
66 | self.cuda_config = FlashAttentionConfig(True, False, False)
67 | else:
68 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
69 | self.cuda_config = FlashAttentionConfig(False, True, True)
70 |
71 | def flash_attn(self, q, k, v, mask = None):
72 | _, heads, q_len, dim_head, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
73 |
74 | # if scale is given, divide by the default scale that sdpa uses
75 |
76 | if exists(self.scale):
77 | q = q * (self.scale / (dim_head ** -0.5))
78 |
79 | # Check if mask exists and expand to compatible shape
80 | # The mask is B L, so it would have to be expanded to B H N L
81 |
82 | if exists(mask):
83 | mask = mask.expand(-1, heads, q_len, -1)
84 |
85 | # Check if there is a compatible device for flash attention
86 |
87 | config = self.cuda_config if is_cuda else self.cpu_config
88 |
89 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
90 |
91 | with torch.backends.cuda.sdp_kernel(**config._asdict()):
92 | out = F.scaled_dot_product_attention(
93 | q, k, v,
94 | attn_mask = mask,
95 | dropout_p = self.dropout if self.training else 0.
96 | )
97 |
98 | return out
99 |
100 | def forward(self, q, k, v, mask = None):
101 | """
102 | einstein notation
103 | b - batch
104 | h - heads
105 | n, i, j - sequence length (base sequence length, source, target)
106 | d - feature dimension
107 | """
108 |
109 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
110 |
111 | scale = default(self.scale, q.shape[-1] ** -0.5)
112 |
113 | if exists(mask) and mask.ndim != 4:
114 | mask = rearrange(mask, 'b j -> b 1 1 j')
115 |
116 | if self.flash:
117 | return self.flash_attn(q, k, v, mask = mask)
118 |
119 | # similarity
120 |
121 | sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
122 |
123 | # key padding mask
124 |
125 | if exists(mask):
126 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
127 |
128 | # attention
129 |
130 | attn = sim.softmax(dim=-1)
131 | attn = self.attn_dropout(attn)
132 |
133 | # aggregate values
134 |
135 | out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
136 |
137 | return out
138 |
--------------------------------------------------------------------------------
/voicebox_pytorch/data.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from functools import wraps
3 |
4 | from einops import rearrange
5 |
6 | from beartype import beartype
7 | from beartype.door import is_bearable
8 | from beartype.typing import Optional, Tuple, Union
9 |
10 | import torch
11 | from torch.nn.utils.rnn import pad_sequence
12 | from torch.utils.data import Dataset, DataLoader
13 |
14 | import torchaudio
15 |
16 | # utilities
17 |
18 | def exists(val):
19 | return val is not None
20 |
21 | def cast_tuple(val, length = 1):
22 | return val if isinstance(val, tuple) else ((val,) * length)
23 |
24 | # dataset functions
25 |
26 | class AudioDataset(Dataset):
27 | @beartype
28 | def __init__(
29 | self,
30 | folder,
31 | audio_extension = ".flac"
32 | ):
33 | super().__init__()
34 | path = Path(folder)
35 | assert path.exists(), 'folder does not exist'
36 |
37 | self.audio_extension = audio_extension
38 |
39 | files = list(path.glob(f'**/*{audio_extension}'))
40 | assert len(files) > 0, 'no files found'
41 |
42 | self.files = files
43 |
44 | def __len__(self):
45 | return len(self.files)
46 |
47 | def __getitem__(self, idx):
48 | file = self.files[idx]
49 |
50 | wave, _ = torchaudio.load(file)
51 | wave = rearrange(wave, '1 ... -> ...')
52 |
53 | return wave
54 |
55 | # dataloader functions
56 |
57 | def collate_one_or_multiple_tensors(fn):
58 | @wraps(fn)
59 | def inner(data):
60 | is_one_data = not isinstance(data[0], tuple)
61 |
62 | if is_one_data:
63 | data = fn(data)
64 | return (data,)
65 |
66 | outputs = []
67 | for datum in zip(*data):
68 | if is_bearable(datum, Tuple[str, ...]):
69 | output = list(datum)
70 | else:
71 | output = fn(datum)
72 |
73 | outputs.append(output)
74 |
75 | return tuple(outputs)
76 |
77 | return inner
78 |
79 | @collate_one_or_multiple_tensors
80 | def curtail_to_shortest_collate(data):
81 | min_len = min([datum.shape[0] for datum in data])
82 | data = [datum[:min_len] for datum in data]
83 | return torch.stack(data)
84 |
85 | @collate_one_or_multiple_tensors
86 | def pad_to_longest_fn(data):
87 | return pad_sequence(data, batch_first = True)
88 |
89 | def get_dataloader(ds, pad_to_longest = True, **kwargs):
90 | collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate
91 | return DataLoader(ds, collate_fn = collate_fn, **kwargs)
92 |
--------------------------------------------------------------------------------
/voicebox_pytorch/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import AdamW, Adam
2 |
3 | def separate_weight_decayable_params(params):
4 | wd_params, no_wd_params = [], []
5 | for param in params:
6 | param_list = no_wd_params if param.ndim < 2 else wd_params
7 | param_list.append(param)
8 | return wd_params, no_wd_params
9 |
10 | def get_optimizer(
11 | params,
12 | lr = 1e-4,
13 | wd = 1e-2,
14 | betas = (0.9, 0.99),
15 | eps = 1e-8,
16 | filter_by_requires_grad = False,
17 | group_wd_params = True
18 | ):
19 | has_wd = wd > 0
20 |
21 | if filter_by_requires_grad:
22 | params = list(filter(lambda t: t.requires_grad, params))
23 |
24 | if group_wd_params and has_wd:
25 | wd_params, no_wd_params = separate_weight_decayable_params(params)
26 |
27 | params = [
28 | {'params': wd_params},
29 | {'params': no_wd_params, 'weight_decay': 0},
30 | ]
31 |
32 | if not has_wd:
33 | return Adam(params, lr = lr, betas = betas, eps = eps)
34 |
35 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
36 |
--------------------------------------------------------------------------------
/voicebox_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | import re
2 | from pathlib import Path
3 | from shutil import rmtree
4 | from functools import partial
5 | from contextlib import nullcontext
6 |
7 | from beartype import beartype
8 |
9 | import torch
10 | from torch import nn
11 | from torch.optim.lr_scheduler import CosineAnnealingLR
12 | from torch.utils.data import Dataset, random_split
13 |
14 | from voicebox_pytorch.voicebox_pytorch import ConditionalFlowMatcherWrapper
15 | from voicebox_pytorch.data import get_dataloader
16 | from voicebox_pytorch.optimizer import get_optimizer
17 |
18 | from accelerate import Accelerator, DistributedType
19 | from accelerate.utils import DistributedDataParallelKwargs
20 |
21 | # helpers
22 |
23 | def exists(val):
24 | return val is not None
25 |
26 | def noop(*args, **kwargs):
27 | pass
28 |
29 | def cycle(dl):
30 | while True:
31 | for data in dl:
32 | yield data
33 |
34 | def cast_tuple(t):
35 | return t if isinstance(t, (tuple, list)) else (t,)
36 |
37 | def yes_or_no(question):
38 | answer = input(f'{question} (y/n) ')
39 | return answer.lower() in ('yes', 'y')
40 |
41 | def accum_log(log, new_logs):
42 | for key, new_value in new_logs.items():
43 | old_value = log.get(key, 0.)
44 | log[key] = old_value + new_value
45 | return log
46 |
47 | def checkpoint_num_steps(checkpoint_path):
48 | """Returns the number of steps trained from a checkpoint based on the filename.
49 |
50 | Filename format assumed to be something like "/path/to/voicebox.20000.pt" which is
51 | for 20k train steps. Returns 20000 in that case.
52 | """
53 | results = re.findall(r'\d+', str(checkpoint_path))
54 |
55 | if len(results) == 0:
56 | return 0
57 |
58 | return int(results[-1])
59 |
60 | class VoiceBoxTrainer(nn.Module):
61 | @beartype
62 | def __init__(
63 | self,
64 | cfm_wrapper: ConditionalFlowMatcherWrapper,
65 | *,
66 | batch_size,
67 | dataset: Dataset,
68 | num_train_steps = None,
69 | num_warmup_steps = None,
70 | num_epochs = None,
71 | lr = 3e-4,
72 | initial_lr = 1e-5,
73 | grad_accum_every = 1,
74 | wd = 0.,
75 | max_grad_norm = 0.5,
76 | valid_frac = 0.05,
77 | random_split_seed = 42,
78 | log_every = 10,
79 | save_results_every = 100,
80 | save_model_every = 1000,
81 | results_folder = './results',
82 | force_clear_prev_results = None,
83 | split_batches = False,
84 | drop_last = False,
85 | accelerate_kwargs: dict = dict(),
86 | ):
87 | super().__init__()
88 |
89 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
90 |
91 | self.accelerator = Accelerator(
92 | kwargs_handlers = [ddp_kwargs],
93 | split_batches = split_batches,
94 | **accelerate_kwargs
95 | )
96 |
97 | self.cfm_wrapper = cfm_wrapper
98 |
99 | self.register_buffer('steps', torch.Tensor([0]))
100 |
101 | self.batch_size = batch_size
102 | self.grad_accum_every = grad_accum_every
103 |
104 | # optimizer
105 |
106 | self.optim = get_optimizer(
107 | cfm_wrapper.parameters(),
108 | lr = lr,
109 | wd = wd
110 | )
111 |
112 | self.lr = lr
113 | self.initial_lr = initial_lr
114 |
115 |
116 | # max grad norm
117 |
118 | self.max_grad_norm = max_grad_norm
119 |
120 | # create dataset
121 |
122 | self.ds = dataset
123 |
124 | # split for validation
125 |
126 | if valid_frac > 0:
127 | train_size = int((1 - valid_frac) * len(self.ds))
128 | valid_size = len(self.ds) - train_size
129 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
130 | self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
131 | else:
132 | self.valid_ds = self.ds
133 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')
134 |
135 | assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
136 | assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'
137 |
138 | assert exists(num_train_steps) or exists(num_epochs), 'either num_train_steps or num_epochs must be specified'
139 |
140 | if exists(num_epochs):
141 | self.num_train_steps = len(dataset) // batch_size * num_epochs
142 | else:
143 | self.num_train_steps = num_train_steps
144 | self.scheduler = CosineAnnealingLR(self.optim, T_max=self.num_train_steps)
145 | self.num_warmup_steps = num_warmup_steps if exists(num_warmup_steps) else 0
146 |
147 | # dataloader
148 |
149 | self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
150 | self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
151 |
152 | # prepare with accelerator
153 |
154 | (
155 | self.cfm_wrapper,
156 | self.optim,
157 | self.scheduler,
158 | self.dl
159 | ) = self.accelerator.prepare(
160 | self.cfm_wrapper,
161 | self.optim,
162 | self.scheduler,
163 | self.dl
164 | )
165 |
166 | # dataloader iterators
167 |
168 | self.dl_iter = cycle(self.dl)
169 | self.valid_dl_iter = cycle(self.valid_dl)
170 |
171 | self.log_every = log_every
172 | self.save_model_every = save_model_every
173 | self.save_results_every = save_results_every
174 |
175 | self.results_folder = Path(results_folder)
176 |
177 | if self.is_main and (force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'))):
178 | rmtree(str(self.results_folder))
179 |
180 | self.results_folder.mkdir(parents = True, exist_ok = True)
181 |
182 | hps = {
183 | "num_train_steps": self.num_train_steps,
184 | "num_warmup_steps": self.num_warmup_steps,
185 | "learning_rate": self.lr,
186 | "initial_learning_rate": self.initial_lr,
187 | "wd": wd
188 | }
189 | self.accelerator.init_trackers("voicebox", config=hps)
190 |
191 | def save(self, path):
192 | pkg = dict(
193 | model = self.accelerator.get_state_dict(self.cfm_wrapper),
194 | optim = self.optim.state_dict(),
195 | scheduler = self.scheduler.state_dict()
196 | )
197 | torch.save(pkg, path)
198 |
199 | def load(self, path):
200 | cfm_wrapper = self.accelerator.unwrap_model(self.cfm_wrapper)
201 | pkg = cfm_wrapper.load(path)
202 |
203 | self.optim.load_state_dict(pkg['optim'])
204 | self.scheduler.load_state_dict(pkg['scheduler'])
205 |
206 | # + 1 to start from the next step and avoid overwriting the last checkpoint
207 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
208 |
209 | def print(self, msg):
210 | self.accelerator.print(msg)
211 |
212 | def generate(self, *args, **kwargs):
213 | return self.cfm_wrapper.generate(*args, **kwargs)
214 |
215 | @property
216 | def device(self):
217 | return self.accelerator.device
218 |
219 | @property
220 | def is_distributed(self):
221 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
222 |
223 | @property
224 | def is_main(self):
225 | return self.accelerator.is_main_process
226 |
227 | @property
228 | def is_local_main(self):
229 | return self.accelerator.is_local_main_process
230 |
231 | def warmup(self, step):
232 | if step < self.num_warmup_steps:
233 | return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
234 | else:
235 | return self.lr
236 |
237 | def train_step(self):
238 | steps = int(self.steps.item())
239 |
240 | self.cfm_wrapper.train()
241 |
242 | # adjust the lr according to the schedule
243 |
244 | if steps < self.num_warmup_steps:
245 | # apply warmup
246 |
247 | lr = self.warmup(steps)
248 | for param_group in self.optim.param_groups:
249 | param_group['lr'] = lr
250 | else:
251 | # after warmup period, start to apply lr annealing
252 |
253 | self.scheduler.step()
254 |
255 | # logs
256 |
257 | logs = {}
258 |
259 | # training step
260 |
261 | for grad_accum_step in range(self.grad_accum_every):
262 | is_last = grad_accum_step == (self.grad_accum_every - 1)
263 | context = partial(self.accelerator.no_sync, self.cfm_wrapper) if not is_last else nullcontext
264 |
265 | wave, = next(self.dl_iter)
266 |
267 | with self.accelerator.autocast(), context():
268 | loss = self.cfm_wrapper(wave)
269 |
270 | self.accelerator.backward(loss / self.grad_accum_every)
271 |
272 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
273 |
274 | if exists(self.max_grad_norm):
275 | self.accelerator.clip_grad_norm_(self.cfm_wrapper.parameters(), self.max_grad_norm)
276 |
277 | self.optim.step()
278 | self.optim.zero_grad()
279 |
280 | # log
281 |
282 | if not steps % self.log_every:
283 | self.print(f"{steps}: loss: {logs['loss']:0.3f}")
284 |
285 | self.accelerator.log({"train_loss": logs['loss']}, step=steps)
286 |
287 | # sample results every so often
288 |
289 | self.accelerator.wait_for_everyone()
290 |
291 | if self.is_main and not (steps % self.save_results_every):
292 | wave, = next(self.valid_dl_iter)
293 | unwrapped_model = self.accelerator.unwrap_model(self.cfm_wrapper)
294 |
295 | with torch.inference_mode():
296 | unwrapped_model.eval()
297 |
298 | wave = wave.to(unwrapped_model.device)
299 | valid_loss = unwrapped_model(wave)
300 |
301 | self.print(f'{steps}: valid loss {valid_loss:0.3f}')
302 | self.accelerator.log({"valid_loss": valid_loss}, step=steps)
303 |
304 | # save model every so often
305 |
306 | if self.is_main and not (steps % self.save_model_every):
307 | model_path = str(self.results_folder / f'voicebox.{steps}.pt')
308 | self.save(model_path)
309 |
310 | self.print(f'{steps}: saving model to {str(self.results_folder)}')
311 |
312 | self.steps += 1
313 | return logs
314 |
315 | def train(self, log_fn = noop):
316 | while self.steps < self.num_train_steps:
317 | logs = self.train_step()
318 | log_fn(logs)
319 |
320 | self.print('training complete')
321 | self.accelerator.end_training()
322 |
--------------------------------------------------------------------------------
/voicebox_pytorch/voicebox_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 | from random import random
4 | from functools import partial
5 | from pathlib import Path
6 |
7 | import torch
8 | from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
9 | from torch.nn import Module
10 | import torch.nn.functional as F
11 | from torch.cuda.amp import autocast
12 |
13 | import torchode as to
14 | from torchdiffeq import odeint
15 |
16 | from beartype import beartype
17 | from beartype.typing import Tuple, Optional, List, Union
18 |
19 | from einops.layers.torch import Rearrange
20 | from einops import rearrange, repeat, reduce, pack, unpack
21 |
22 | from voicebox_pytorch.attend import Attend
23 |
24 | from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss, BinLoss, maximum_path
25 | from naturalspeech2_pytorch.utils.tokenizer import Tokenizer
26 | from naturalspeech2_pytorch.naturalspeech2_pytorch import generate_mask_from_repeats
27 |
28 | from audiolm_pytorch import EncodecWrapper
29 | from spear_tts_pytorch import TextToSemantic
30 |
31 | from gateloop_transformer import SimpleGateLoopLayer as GateLoop
32 |
33 | import torchaudio.transforms as T
34 | from torchaudio.functional import DB_to_amplitude, resample
35 |
36 | from vocos import Vocos
37 |
38 | LOGGER = logging.getLogger(__file__)
39 |
40 | # helper functions
41 |
42 | def exists(val):
43 | return val is not None
44 |
45 | def identity(t):
46 | return t
47 |
48 | def default(val, d):
49 | return val if exists(val) else d
50 |
51 | def divisible_by(num, den):
52 | return (num % den) == 0
53 |
54 | def is_odd(n):
55 | return not divisible_by(n, 2)
56 |
57 | def coin_flip():
58 | return random() < 0.5
59 |
60 | def pack_one(t, pattern):
61 | return pack([t], pattern)
62 |
63 | def unpack_one(t, ps, pattern):
64 | return unpack(t, ps, pattern)[0]
65 |
66 | # tensor helpers
67 |
68 | def prob_mask_like(shape, prob, device):
69 | if prob == 1:
70 | return torch.ones(shape, device = device, dtype = torch.bool)
71 | elif prob == 0:
72 | return torch.zeros(shape, device = device, dtype = torch.bool)
73 | else:
74 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
75 |
76 | def reduce_masks_with_and(*masks):
77 | masks = [*filter(exists, masks)]
78 |
79 | if len(masks) == 0:
80 | return None
81 |
82 | mask, *rest_masks = masks
83 |
84 | for rest_mask in rest_masks:
85 | mask = mask & rest_mask
86 |
87 | return mask
88 |
89 | def interpolate_1d(t, length, mode = 'bilinear'):
90 | " pytorch does not offer interpolation 1d, so hack by converting to 2d "
91 |
92 | dtype = t.dtype
93 | t = t.float()
94 |
95 | implicit_one_channel = t.ndim == 2
96 | if implicit_one_channel:
97 | t = rearrange(t, 'b n -> b 1 n')
98 |
99 | t = rearrange(t, 'b d n -> b d n 1')
100 | t = F.interpolate(t, (length, 1), mode = mode)
101 | t = rearrange(t, 'b d n 1 -> b d n')
102 |
103 | if implicit_one_channel:
104 | t = rearrange(t, 'b 1 n -> b n')
105 |
106 | t = t.to(dtype)
107 | return t
108 |
109 | def curtail_or_pad(t, target_length):
110 | length = t.shape[-2]
111 |
112 | if length > target_length:
113 | t = t[..., :target_length, :]
114 | elif length < target_length:
115 | t = F.pad(t, (0, 0, 0, target_length - length), value = 0.)
116 |
117 | return t
118 |
119 | # mask construction helpers
120 |
121 | def mask_from_start_end_indices(
122 | seq_len: int,
123 | start: Tensor,
124 | end: Tensor
125 | ):
126 | assert start.shape == end.shape
127 | device = start.device
128 |
129 | seq = torch.arange(seq_len, device = device, dtype = torch.long)
130 | seq = seq.reshape(*((-1,) * start.ndim), seq_len)
131 | seq = seq.expand(*start.shape, seq_len)
132 |
133 | mask = seq >= start[..., None].long()
134 | mask &= seq < end[..., None].long()
135 | return mask
136 |
137 | def mask_from_frac_lengths(
138 | seq_len: int,
139 | frac_lengths: Tensor
140 | ):
141 | device = frac_lengths.device
142 |
143 | lengths = (frac_lengths * seq_len).long()
144 | max_start = seq_len - lengths
145 |
146 | rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
147 | start = (max_start * rand).clamp(min = 0)
148 | end = start + lengths
149 |
150 | return mask_from_start_end_indices(seq_len, start, end)
151 |
152 | # sinusoidal positions
153 |
154 | class LearnedSinusoidalPosEmb(Module):
155 | """ used by @crowsonkb """
156 |
157 | def __init__(self, dim):
158 | super().__init__()
159 | assert divisible_by(dim, 2)
160 | half_dim = dim // 2
161 | self.weights = nn.Parameter(torch.randn(half_dim))
162 |
163 | def forward(self, x):
164 | x = rearrange(x, 'b -> b 1')
165 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
166 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
167 | return fouriered
168 |
169 | # rotary positional embeddings
170 | # https://arxiv.org/abs/2104.09864
171 |
172 | class RotaryEmbedding(Module):
173 | def __init__(self, dim, theta = 50000):
174 | super().__init__()
175 | inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
176 | self.register_buffer("inv_freq", inv_freq)
177 |
178 | @property
179 | def device(self):
180 | return self.inv_freq.device
181 |
182 | @autocast(enabled = False)
183 | @beartype
184 | def forward(self, t: Union[int, Tensor]):
185 | if not torch.is_tensor(t):
186 | t = torch.arange(t, device = self.device)
187 |
188 | t = t.type_as(self.inv_freq)
189 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
190 | freqs = torch.cat((freqs, freqs), dim = -1)
191 | return freqs
192 |
193 | def rotate_half(x):
194 | x1, x2 = x.chunk(2, dim = -1)
195 | return torch.cat((-x2, x1), dim = -1)
196 |
197 | @autocast(enabled = False)
198 | def apply_rotary_pos_emb(pos, t):
199 | return t * pos.cos() + rotate_half(t) * pos.sin()
200 |
201 | # convolutional positional generating module
202 |
203 | class ConvPositionEmbed(Module):
204 | def __init__(
205 | self,
206 | dim,
207 | *,
208 | kernel_size,
209 | groups = None
210 | ):
211 | super().__init__()
212 | assert is_odd(kernel_size)
213 | groups = default(groups, dim) # full depthwise conv by default
214 |
215 | self.dw_conv1d = nn.Sequential(
216 | nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
217 | nn.GELU()
218 | )
219 |
220 | def forward(self, x, mask = None):
221 |
222 | if exists(mask):
223 | mask = mask[..., None]
224 | x = x.masked_fill(~mask, 0.)
225 |
226 | x = rearrange(x, 'b n c -> b c n')
227 | x = self.dw_conv1d(x)
228 | out = rearrange(x, 'b c n -> b n c')
229 |
230 | if exists(mask):
231 | out = out.masked_fill(~mask, 0.)
232 |
233 | return out
234 |
235 | # norms
236 |
237 | class RMSNorm(Module):
238 | def __init__(
239 | self,
240 | dim
241 | ):
242 | super().__init__()
243 | self.scale = dim ** 0.5
244 | self.gamma = nn.Parameter(torch.ones(dim))
245 |
246 | def forward(self, x):
247 | return F.normalize(x, dim = -1) * self.scale * self.gamma
248 |
249 | class AdaptiveRMSNorm(Module):
250 | def __init__(
251 | self,
252 | dim,
253 | cond_dim = None
254 | ):
255 | super().__init__()
256 | cond_dim = default(cond_dim, dim)
257 | self.scale = dim ** 0.5
258 |
259 | self.to_gamma = nn.Linear(cond_dim, dim)
260 | self.to_beta = nn.Linear(cond_dim, dim)
261 |
262 | # init to identity
263 |
264 | nn.init.zeros_(self.to_gamma.weight)
265 | nn.init.ones_(self.to_gamma.bias)
266 |
267 | nn.init.zeros_(self.to_beta.weight)
268 | nn.init.zeros_(self.to_beta.bias)
269 |
270 | def forward(self, x, *, cond):
271 | normed = F.normalize(x, dim = -1) * self.scale
272 |
273 | gamma, beta = self.to_gamma(cond), self.to_beta(cond)
274 | gamma, beta = map(lambda t: rearrange(t, 'b d -> b 1 d'), (gamma, beta))
275 |
276 | return normed * gamma + beta
277 |
278 | # attention
279 |
280 | class MultiheadRMSNorm(Module):
281 | def __init__(self, dim, heads):
282 | super().__init__()
283 | self.scale = dim ** 0.5
284 | self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
285 |
286 | def forward(self, x):
287 | return F.normalize(x, dim = -1) * self.gamma * self.scale
288 |
289 | class Attention(Module):
290 | def __init__(
291 | self,
292 | dim,
293 | dim_head = 64,
294 | heads = 8,
295 | dropout = 0,
296 | flash = False,
297 | qk_norm = False,
298 | qk_norm_scale = 10
299 | ):
300 | super().__init__()
301 | self.heads = heads
302 | dim_inner = dim_head * heads
303 |
304 | scale = qk_norm_scale if qk_norm else None
305 |
306 | self.attend = Attend(dropout, flash = flash, scale = scale)
307 |
308 | self.qk_norm = qk_norm
309 |
310 | if qk_norm:
311 | self.q_norm = MultiheadRMSNorm(dim_head, heads = heads)
312 | self.k_norm = MultiheadRMSNorm(dim_head, heads = heads)
313 |
314 | self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
315 | self.to_out = nn.Linear(dim_inner, dim, bias = False)
316 |
317 | def forward(self, x, mask = None, rotary_emb = None):
318 | h = self.heads
319 |
320 | q, k, v = self.to_qkv(x).chunk(3, dim = -1)
321 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
322 |
323 | if self.qk_norm:
324 | q = self.q_norm(q)
325 | k = self.k_norm(k)
326 |
327 | if exists(rotary_emb):
328 | q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))
329 |
330 | out = self.attend(q, k, v, mask = mask)
331 |
332 | out = rearrange(out, 'b h n d -> b n (h d)')
333 | return self.to_out(out)
334 |
335 | # feedforward
336 |
337 | class GEGLU(Module):
338 | def forward(self, x):
339 | x, gate = x.chunk(2, dim = -1)
340 | return F.gelu(gate) * x
341 |
342 | def FeedForward(dim, mult = 4, dropout = 0.):
343 | dim_inner = int(dim * mult * 2 / 3)
344 | return nn.Sequential(
345 | nn.Linear(dim, dim_inner * 2),
346 | GEGLU(),
347 | nn.Dropout(dropout),
348 | nn.Linear(dim_inner, dim)
349 | )
350 |
351 | # transformer
352 |
353 | class Transformer(Module):
354 | def __init__(
355 | self,
356 | dim,
357 | *,
358 | depth,
359 | dim_head = 64,
360 | heads = 8,
361 | ff_mult = 4,
362 | attn_dropout = 0.,
363 | ff_dropout = 0.,
364 | num_register_tokens = 0,
365 | attn_flash = False,
366 | adaptive_rmsnorm = False,
367 | adaptive_rmsnorm_cond_dim_in = None,
368 | use_unet_skip_connection = False,
369 | skip_connect_scale = None,
370 | attn_qk_norm = False,
371 | use_gateloop_layers = False,
372 | gateloop_use_jax = False,
373 | ):
374 | super().__init__()
375 | assert divisible_by(depth, 2)
376 | self.layers = nn.ModuleList([])
377 |
378 | self.rotary_emb = RotaryEmbedding(dim = dim_head)
379 |
380 | self.num_register_tokens = num_register_tokens
381 | self.has_register_tokens = num_register_tokens > 0
382 |
383 | if self.has_register_tokens:
384 | self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
385 |
386 | if adaptive_rmsnorm:
387 | rmsnorm_klass = partial(AdaptiveRMSNorm, cond_dim = adaptive_rmsnorm_cond_dim_in)
388 | else:
389 | rmsnorm_klass = RMSNorm
390 |
391 | self.skip_connect_scale = default(skip_connect_scale, 2 ** -0.5)
392 |
393 | for ind in range(depth):
394 | layer = ind + 1
395 | has_skip = use_unet_skip_connection and layer > (depth // 2)
396 |
397 | self.layers.append(nn.ModuleList([
398 | nn.Linear(dim * 2, dim) if has_skip else None,
399 | GateLoop(dim = dim, use_jax_associative_scan = gateloop_use_jax, post_ln = True) if use_gateloop_layers else None,
400 | rmsnorm_klass(dim = dim),
401 | Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, qk_norm = attn_qk_norm),
402 | rmsnorm_klass(dim = dim),
403 | FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
404 | ]))
405 |
406 | self.final_norm = RMSNorm(dim)
407 |
408 | @property
409 | def device(self):
410 | return next(self.parameters()).device
411 |
412 | def forward(
413 | self,
414 | x,
415 | mask = None,
416 | adaptive_rmsnorm_cond = None
417 | ):
418 | batch, seq_len, *_ = x.shape
419 |
420 | # add register tokens to the left
421 |
422 | if self.has_register_tokens:
423 | register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
424 |
425 | x, ps = pack([register_tokens, x], 'b * d')
426 |
427 | if exists(mask):
428 | mask = F.pad(mask, (self.num_register_tokens, 0), value = True)
429 |
430 | # keep track of skip connections
431 |
432 | skip_connects = []
433 |
434 | # rotary embeddings
435 |
436 | positions = seq_len
437 |
438 | if self.has_register_tokens:
439 | main_positions = torch.arange(seq_len, device = self.device, dtype = torch.long)
440 | register_positions = torch.full((self.num_register_tokens,), -10000, device = self.device, dtype = torch.long)
441 | positions = torch.cat((register_positions, main_positions))
442 |
443 | rotary_emb = self.rotary_emb(positions)
444 |
445 | # adaptive rmsnorm
446 |
447 | rmsnorm_kwargs = dict()
448 | if exists(adaptive_rmsnorm_cond):
449 | rmsnorm_kwargs = dict(cond = adaptive_rmsnorm_cond)
450 |
451 | # going through the attention layers
452 |
453 | for skip_combiner, maybe_gateloop, attn_prenorm, attn, ff_prenorm, ff in self.layers:
454 |
455 | # in the paper, they use a u-net like skip connection
456 | # unclear how much this helps, as no ablations or further numbers given besides a brief one-two sentence mention
457 |
458 | if not exists(skip_combiner):
459 | skip_connects.append(x)
460 | else:
461 | skip_connect = skip_connects.pop() * self.skip_connect_scale
462 | x = torch.cat((x, skip_connect), dim = -1)
463 | x = skip_combiner(x)
464 |
465 | if exists(maybe_gateloop):
466 | x = maybe_gateloop(x) + x
467 |
468 | attn_input = attn_prenorm(x, **rmsnorm_kwargs)
469 | x = attn(attn_input, mask = mask, rotary_emb = rotary_emb) + x
470 |
471 | ff_input = ff_prenorm(x, **rmsnorm_kwargs)
472 | x = ff(ff_input) + x
473 |
474 | # remove the register tokens
475 |
476 | if self.has_register_tokens:
477 | _, x = unpack(x, ps, 'b * d')
478 |
479 | return self.final_norm(x)
480 |
481 | # encoder decoders
482 |
483 | class AudioEncoderDecoder(nn.Module):
484 | pass
485 |
486 | class MelVoco(AudioEncoderDecoder):
487 | def __init__(
488 | self,
489 | *,
490 | log = True,
491 | n_mels = 100,
492 | sampling_rate = 24000,
493 | f_max = 8000,
494 | n_fft = 1024,
495 | win_length = 640,
496 | hop_length = 160,
497 | pretrained_vocos_path = 'charactr/vocos-mel-24khz'
498 | ):
499 | super().__init__()
500 | self.log = log
501 | self.n_mels = n_mels
502 | self.n_fft = n_fft
503 | self.f_max = f_max
504 | self.win_length = win_length
505 | self.hop_length = hop_length
506 | self.sampling_rate = sampling_rate
507 |
508 | self.vocos = Vocos.from_pretrained(pretrained_vocos_path)
509 |
510 | @property
511 | def downsample_factor(self):
512 | raise NotImplementedError
513 |
514 | @property
515 | def latent_dim(self):
516 | return self.n_mels
517 |
518 | def encode(self, audio):
519 | stft_transform = T.Spectrogram(
520 | n_fft = self.n_fft,
521 | win_length = self.win_length,
522 | hop_length = self.hop_length,
523 | window_fn = torch.hann_window
524 | )
525 |
526 | spectrogram = stft_transform(audio)
527 |
528 | mel_transform = T.MelScale(
529 | n_mels = self.n_mels,
530 | sample_rate = self.sampling_rate,
531 | n_stft = self.n_fft // 2 + 1,
532 | f_max = self.f_max
533 | )
534 |
535 | mel = mel_transform(spectrogram)
536 |
537 | if self.log:
538 | mel = T.AmplitudeToDB()(mel)
539 |
540 | mel = rearrange(mel, 'b d n -> b n d')
541 | return mel
542 |
543 | def decode(self, mel):
544 | mel = rearrange(mel, 'b n d -> b d n')
545 |
546 | if self.log:
547 | mel = DB_to_amplitude(mel, ref = 1., power = 0.5)
548 |
549 | return self.vocos.decode(mel)
550 |
551 | class EncodecVoco(AudioEncoderDecoder):
552 | def __init__(
553 | self,
554 | *,
555 | sampling_rate = 24000,
556 | pretrained_vocos_path = 'charactr/vocos-encodec-24khz',
557 | bandwidth_id = 2
558 | ):
559 | super().__init__()
560 | self.sampling_rate = sampling_rate
561 | self.encodec = EncodecWrapper()
562 | self.vocos = Vocos.from_pretrained(pretrained_vocos_path)
563 |
564 | self.register_buffer('bandwidth_id', torch.tensor([bandwidth_id]))
565 |
566 | @property
567 | def downsample_factor(self):
568 | return self.encodec.downsample_factor
569 |
570 | @property
571 | def latent_dim(self):
572 | return self.encodec.codebook_dim
573 |
574 | def encode(self, audio):
575 | encoded_audio, _, _ = self.encodec(audio, return_encoded = True)
576 | return encoded_audio
577 |
578 | def decode_to_codes(self, latents):
579 | _, codes, _ = self.encodec.rq(latents)
580 | codes = rearrange(codes, 'b n q -> b q n')
581 | return codes
582 |
583 | def decode(self, latents):
584 | codes = self.decode_to_codes(latents)
585 |
586 | all_audios = []
587 | for code in codes:
588 | features = self.vocos.codes_to_features(code)
589 | audio = self.vocos.decode(features, bandwidth_id = self.bandwidth_id)
590 | all_audios.append(audio)
591 |
592 | return torch.stack(all_audios)
593 |
594 | # both duration and main denoising model are transformers
595 |
596 | class DurationPredictor(Module):
597 | @beartype
598 | def __init__(
599 | self,
600 | *,
601 | audio_enc_dec: Optional[AudioEncoderDecoder] = None,
602 | tokenizer: Optional[Tokenizer] = None,
603 | num_phoneme_tokens: Optional[int] = None,
604 | dim_phoneme_emb = 512,
605 | dim = 512,
606 | depth = 10,
607 | dim_head = 64,
608 | heads = 8,
609 | ff_mult = 4,
610 | ff_dropout = 0.,
611 | conv_pos_embed_kernel_size = 31,
612 | conv_pos_embed_groups = None,
613 | attn_dropout = 0,
614 | attn_flash = False,
615 | attn_qk_norm = True,
616 | use_gateloop_layers = False,
617 | p_drop_prob = 0.2, # p_drop in paper
618 | frac_lengths_mask: Tuple[float, float] = (0.1, 1.),
619 | aligner_kwargs: dict = dict(dim_in = 80, attn_channels = 80)
620 | ):
621 | super().__init__()
622 |
623 | # audio encoder / decoder
624 |
625 | self.audio_enc_dec = audio_enc_dec
626 |
627 | if exists(audio_enc_dec) and dim != audio_enc_dec.latent_dim:
628 | self.proj_in = nn.Linear(audio_enc_dec.latent_dim, dim)
629 | else:
630 | self.proj_in = nn.Identity()
631 |
632 | # phoneme related
633 |
634 | assert not (exists(tokenizer) and exists(num_phoneme_tokens)), 'if a phoneme tokenizer was passed into duration module, number of phoneme tokens does not need to be specified'
635 |
636 | if not exists(tokenizer) and not exists(num_phoneme_tokens):
637 | tokenizer = Tokenizer() # default to english phonemes with espeak
638 |
639 | if exists(tokenizer):
640 | num_phoneme_tokens = tokenizer.vocab_size
641 |
642 | self.tokenizer = tokenizer
643 |
644 | self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens, dim_phoneme_emb)
645 |
646 | self.p_drop_prob = p_drop_prob
647 | self.frac_lengths_mask = frac_lengths_mask
648 |
649 | self.to_embed = nn.Linear(dim + dim_phoneme_emb, dim)
650 |
651 | self.null_cond = nn.Parameter(torch.zeros(dim), requires_grad = False)
652 |
653 | self.conv_embed = ConvPositionEmbed(
654 | dim = dim,
655 | kernel_size = conv_pos_embed_kernel_size,
656 | groups = conv_pos_embed_groups
657 | )
658 |
659 | self.transformer = Transformer(
660 | dim = dim,
661 | depth = depth,
662 | dim_head = dim_head,
663 | heads = heads,
664 | ff_mult = ff_mult,
665 | ff_dropout = ff_dropout,
666 | attn_dropout=attn_dropout,
667 | attn_flash = attn_flash,
668 | attn_qk_norm = attn_qk_norm,
669 | use_gateloop_layers = use_gateloop_layers
670 | )
671 |
672 | self.to_pred = nn.Sequential(
673 | nn.Linear(dim, 1),
674 | Rearrange('... 1 -> ...')
675 | )
676 |
677 | # aligner related
678 |
679 | # if we are using mel spec with 80 channels, we need to set attn_channels to 80
680 | # dim_in assuming we have spec with 80 channels
681 |
682 | self.aligner = Aligner(dim_hidden = dim_phoneme_emb, **aligner_kwargs)
683 | self.align_loss = ForwardSumLoss()
684 |
685 | @property
686 | def device(self):
687 | return next(self.parameters()).device
688 |
689 | def align_phoneme_ids_with_durations(self, phoneme_ids, durations):
690 | repeat_mask = generate_mask_from_repeats(durations.clamp(min = 1))
691 | aligned_phoneme_ids = einsum('b i, b i j -> b j', phoneme_ids.float(), repeat_mask.float()).long()
692 | return aligned_phoneme_ids
693 |
694 | @torch.inference_mode()
695 | @beartype
696 | def forward_with_cond_scale(
697 | self,
698 | *args,
699 | texts: Optional[List[str]] = None,
700 | phoneme_ids = None,
701 | cond_scale = 1.,
702 | return_aligned_phoneme_ids = False,
703 | **kwargs
704 | ):
705 | if exists(texts):
706 | phoneme_ids = self.tokenizer.texts_to_tensor_ids(texts)
707 |
708 | forward_kwargs = dict(
709 | return_aligned_phoneme_ids = False,
710 | phoneme_ids = phoneme_ids
711 | )
712 |
713 | durations = self.forward(*args, cond_drop_prob = 0., **forward_kwargs, **kwargs)
714 |
715 | if cond_scale == 1.:
716 | if not return_aligned_phoneme_ids:
717 | return durations
718 |
719 | return durations, self.align_phoneme_ids_with_durations(phoneme_ids, durations)
720 |
721 | null_durations = self.forward(*args, cond_drop_prob = 1., **forward_kwargs, **kwargs)
722 | scaled_durations = null_durations + (durations - null_durations) * cond_scale
723 |
724 | if not return_aligned_phoneme_ids:
725 | return scaled_durations
726 |
727 | return scaled_durations, self.align_phoneme_ids_with_durations(phoneme_ids, scaled_durations)
728 |
729 | @beartype
730 | def forward_aligner(
731 | self,
732 | x: FloatTensor, # (b, t, c)
733 | x_mask: IntTensor, # (b, 1, t)
734 | y: FloatTensor, # (b, t, c)
735 | y_mask: IntTensor # (b, 1, t)
736 | ) -> Tuple[
737 | FloatTensor, # alignment_hard: (b, t)
738 | FloatTensor, # alignment_soft: (b, tx, ty)
739 | FloatTensor, # alignment_logprob: (b, 1, ty, tx)
740 | BoolTensor # alignment_mas: (b, tx, ty)
741 | ]:
742 | attn_mask = rearrange(x_mask, 'b 1 t -> b 1 t 1') * rearrange(y_mask, 'b 1 t -> b 1 1 t')
743 | alignment_soft, alignment_logprob = self.aligner(rearrange(y, 'b t c -> b c t'), x, x_mask)
744 |
745 | assert not torch.isnan(alignment_soft).any()
746 |
747 | alignment_mas = maximum_path(
748 | rearrange(alignment_soft, 'b 1 t1 t2 -> b t2 t1').contiguous(),
749 | rearrange(attn_mask, 'b 1 t1 t2 -> b t1 t2').contiguous()
750 | )
751 |
752 | alignment_hard = torch.sum(alignment_mas, -1).float()
753 | alignment_soft = rearrange(alignment_soft, 'b 1 t1 t2 -> b t2 t1')
754 | return alignment_hard, alignment_soft, alignment_logprob, alignment_mas
755 |
756 | @beartype
757 | def forward(
758 | self,
759 | *,
760 | cond,
761 | texts: Optional[List[str]] = None,
762 | phoneme_ids = None,
763 | cond_drop_prob = 0.,
764 | target = None,
765 | cond_mask = None,
766 | mel = None,
767 | phoneme_len = None,
768 | mel_len = None,
769 | phoneme_mask = None,
770 | mel_mask = None,
771 | self_attn_mask = None,
772 | return_aligned_phoneme_ids = False
773 | ):
774 | batch, seq_len, cond_dim = cond.shape
775 |
776 | cond = self.proj_in(cond)
777 |
778 | # text to phonemes, if tokenizer is given
779 |
780 | if not exists(phoneme_ids):
781 | assert exists(self.tokenizer)
782 | phoneme_ids = self.tokenizer.texts_to_tensor_ids(texts)
783 |
784 | # construct mask if not given
785 |
786 | if not exists(cond_mask):
787 | if coin_flip():
788 | frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
789 | cond_mask = mask_from_frac_lengths(seq_len, frac_lengths)
790 | else:
791 | cond_mask = prob_mask_like((batch, seq_len), self.p_drop_prob, self.device)
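            # note: half the time (coin_flip) a random fraction of the sequence, drawn from
            # frac_lengths_mask, is masked; otherwise each position is dropped independently
            # with probability p_drop_prob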
792 |
793 | cond = cond * rearrange(~cond_mask, '... -> ... 1')
794 |
795 | # classifier free guidance
796 |
797 | if cond_drop_prob > 0.:
798 | cond_drop_mask = prob_mask_like(cond.shape[:1], cond_drop_prob, cond.device)
799 |
800 | cond = torch.where(
801 | rearrange(cond_drop_mask, '... -> ... 1 1'),
802 | self.null_cond,
803 | cond
804 | )
805 |
806 | # phoneme id of -1 is padding
807 |
808 | if not exists(self_attn_mask):
809 | self_attn_mask = phoneme_ids != -1
810 |
811 | phoneme_ids = phoneme_ids.clamp(min = 0)
812 |
813 | # get phoneme embeddings
814 |
815 | phoneme_emb = self.to_phoneme_emb(phoneme_ids)
816 |
817 | # force condition to be same length as input phonemes
818 |
819 | cond = curtail_or_pad(cond, phoneme_ids.shape[-1])
820 |
821 | # combine audio, phoneme, conditioning
822 |
823 | embed = torch.cat((phoneme_emb, cond), dim = -1)
824 | x = self.to_embed(embed)
825 |
826 | x = self.conv_embed(x, mask = self_attn_mask) + x
827 |
828 | x = self.transformer(
829 | x,
830 | mask = self_attn_mask
831 | )
832 |
833 | durations = self.to_pred(x)
834 |
835 | if not self.training:
836 | if not return_aligned_phoneme_ids:
837 | return durations
838 |
839 | return durations, self.align_phoneme_ids_with_durations(phoneme_ids, durations)
840 |
841 | # aligner
842 | # use alignment_hard to oversample phonemes
843 |         # the duration predictor is trained to predict durations at the masked positions, with alignment_hard from the aligner as the target (the unmasked durations act as conditioning)
844 |
845 |         assert all([exists(el) for el in (phoneme_len, mel_len, phoneme_mask, mel_mask)]), 'need to pass phoneme_len, mel_len, phoneme_mask, and mel_mask to train the duration predictor module'
846 |
847 | alignment_hard, _, alignment_logprob, _ = self.forward_aligner(phoneme_emb, phoneme_mask, mel, mel_mask)
848 | target = alignment_hard
849 |
850 | if exists(self_attn_mask):
851 | loss_mask = cond_mask & self_attn_mask
852 | else:
853 | loss_mask = self_attn_mask
854 |
855 | if not exists(loss_mask):
856 |             return F.l1_loss(durations, target)
857 |
858 |         loss = F.l1_loss(durations, target, reduction = 'none')
859 | loss = loss.masked_fill(~loss_mask, 0.)
860 |
861 | # masked mean
862 |
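        # note: clamping the denominator guards against an all-False loss mask, so the
        # division below cannot blow up when nothing is left to average over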
863 | num = reduce(loss, 'b n -> b', 'sum')
864 | den = loss_mask.sum(dim = -1).clamp(min = 1e-5)
865 | loss = num / den
866 | loss = loss.mean()
867 |
868 | if not return_aligned_phoneme_ids:
869 | return loss
870 |
871 |         # aligner loss
872 |
873 | align_loss = self.align_loss(alignment_logprob, phoneme_len, mel_len)
874 | loss = loss + align_loss
875 |
876 | return loss
877 |
878 | class VoiceBox(Module):
879 | def __init__(
880 | self,
881 | *,
882 | num_cond_tokens = None,
883 | audio_enc_dec: Optional[AudioEncoderDecoder] = None,
884 | dim_in = None,
885 | dim_cond_emb = 1024,
886 | dim = 1024,
887 | depth = 24,
888 | dim_head = 64,
889 | heads = 16,
890 | ff_mult = 4,
891 | ff_dropout = 0.,
892 | time_hidden_dim = None,
893 | conv_pos_embed_kernel_size = 31,
894 | conv_pos_embed_groups = None,
895 | attn_dropout = 0,
896 | attn_flash = False,
897 | attn_qk_norm = True,
898 | use_gateloop_layers = False,
899 | num_register_tokens = 16,
900 | p_drop_prob = 0.3, # p_drop in paper
901 | frac_lengths_mask: Tuple[float, float] = (0.7, 1.),
902 | condition_on_text = True
903 | ):
904 | super().__init__()
905 | dim_in = default(dim_in, dim)
906 |
907 | time_hidden_dim = default(time_hidden_dim, dim * 4)
908 |
909 | self.audio_enc_dec = audio_enc_dec
910 |
911 | if exists(audio_enc_dec) and dim != audio_enc_dec.latent_dim:
912 | self.proj_in = nn.Linear(audio_enc_dec.latent_dim, dim)
913 | else:
914 | self.proj_in = nn.Identity()
915 |
916 | self.sinu_pos_emb = nn.Sequential(
917 | LearnedSinusoidalPosEmb(dim),
918 | nn.Linear(dim, time_hidden_dim),
919 | nn.SiLU()
920 | )
921 |
922 | assert not (condition_on_text and not exists(num_cond_tokens)), 'number of conditioning tokens must be specified (whether phonemes or semantic token ids) if training conditional voicebox'
923 |
924 | if not condition_on_text:
925 | dim_cond_emb = 0
926 |
927 | self.dim_cond_emb = dim_cond_emb
928 | self.condition_on_text = condition_on_text
929 | self.num_cond_tokens = num_cond_tokens
930 |
931 | if condition_on_text:
932 | self.null_cond_id = num_cond_tokens # use last phoneme token as null token for CFG
933 | self.to_cond_emb = nn.Embedding(num_cond_tokens + 1, dim_cond_emb)
934 |
935 | self.p_drop_prob = p_drop_prob
936 | self.frac_lengths_mask = frac_lengths_mask
937 |
938 | self.to_embed = nn.Linear(dim_in * 2 + dim_cond_emb, dim)
939 |
940 | self.null_cond = nn.Parameter(torch.zeros(dim_in), requires_grad = False)
941 |
942 | self.conv_embed = ConvPositionEmbed(
943 | dim = dim,
944 | kernel_size = conv_pos_embed_kernel_size,
945 | groups = conv_pos_embed_groups
946 | )
947 |
948 | self.transformer = Transformer(
949 | dim = dim,
950 | depth = depth,
951 | dim_head = dim_head,
952 | heads = heads,
953 | ff_mult = ff_mult,
954 | ff_dropout = ff_dropout,
955 |             attn_dropout = attn_dropout,
956 | attn_flash = attn_flash,
957 | attn_qk_norm = attn_qk_norm,
958 | num_register_tokens = num_register_tokens,
959 | adaptive_rmsnorm = True,
960 | adaptive_rmsnorm_cond_dim_in = time_hidden_dim,
961 | use_gateloop_layers = use_gateloop_layers
962 | )
963 |
964 | dim_out = audio_enc_dec.latent_dim if exists(audio_enc_dec) else dim_in
965 |
966 | self.to_pred = nn.Linear(dim, dim_out, bias = False)
967 |
968 | @property
969 | def device(self):
970 | return next(self.parameters()).device
971 |
972 | @torch.inference_mode()
973 | def forward_with_cond_scale(
974 | self,
975 | *args,
976 | cond_scale = 1.,
977 | **kwargs
978 | ):
979 | logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
980 |
981 | if cond_scale == 1.:
982 | return logits
983 |
984 | null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
985 | return null_logits + (logits - null_logits) * cond_scale
986 |
987 | def forward(
988 | self,
989 | x,
990 | *,
991 | times,
992 | cond_token_ids,
993 | self_attn_mask = None,
994 | cond_drop_prob = 0.1,
995 | target = None,
996 | cond = None,
997 | cond_mask = None
998 | ):
999 | # project in, in case codebook dim is not equal to model dimensions
1000 |
1001 | x = self.proj_in(x)
1002 |
1003 | cond = default(cond, target)
1004 |
1005 | if exists(cond):
1006 | cond = self.proj_in(cond)
1007 |
1008 | # shapes
1009 |
1010 | batch, seq_len, cond_dim = cond.shape
1011 | assert cond_dim == x.shape[-1]
1012 |
1013 | # auto manage shape of times, for odeint times
1014 |
1015 | if times.ndim == 0:
1016 | times = repeat(times, '-> b', b = cond.shape[0])
1017 |
1018 | if times.ndim == 1 and times.shape[0] == 1:
1019 | times = repeat(times, '1 -> b', b = cond.shape[0])
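        # note: during ode sampling the solver passes a scalar (or length-1) t, while training
        # passes one time per batch element; the two branches above broadcast t to shape (b,)
        # so the sinusoidal time embedding downstream always sees a per-sample vector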
1020 |
1021 | # construct conditioning mask if not given
1022 |
1023 | if self.training:
1024 | if not exists(cond_mask):
1025 | frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
1026 | cond_mask = mask_from_frac_lengths(seq_len, frac_lengths)
1027 | else:
1028 | if not exists(cond_mask):
1029 | cond_mask = torch.ones((batch, seq_len), device = cond.device, dtype = torch.bool)
1030 |
1031 | cond_mask_with_pad_dim = rearrange(cond_mask, '... -> ... 1')
1032 |
1033 | # as described in section 3.2
1034 |
1035 | cond = cond * ~cond_mask_with_pad_dim
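        # note: masked positions are zeroed out of the conditioning, so the model only sees the
        # surrounding (unmasked) audio frames as context and must infill the masked region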
1036 |
1037 | # classifier free guidance
1038 |
1039 | cond_ids = cond_token_ids
1040 |
1041 | if cond_drop_prob > 0.:
1042 | cond_drop_mask = prob_mask_like(cond.shape[:1], cond_drop_prob, self.device)
1043 |
1044 | cond = torch.where(
1045 | rearrange(cond_drop_mask, '... -> ... 1 1'),
1046 | self.null_cond,
1047 | cond
1048 | )
1049 |
1050 | cond_ids = torch.where(
1051 | rearrange(cond_drop_mask, '... -> ... 1'),
1052 | self.null_cond_id,
1053 | cond_token_ids
1054 | )
1055 |
1056 | # phoneme or semantic conditioning embedding
1057 |
1058 | cond_emb = None
1059 |
1060 | if self.condition_on_text:
1061 | cond_emb = self.to_cond_emb(cond_ids)
1062 |
1063 | cond_emb_length = cond_emb.shape[-2]
1064 | if cond_emb_length != seq_len:
1065 | cond_emb = rearrange(cond_emb, 'b n d -> b d n')
1066 | cond_emb = interpolate_1d(cond_emb, seq_len)
1067 | cond_emb = rearrange(cond_emb, 'b d n -> b n d')
1068 |
1069 | if exists(self_attn_mask):
1070 | self_attn_mask = interpolate_1d(self_attn_mask, seq_len)
1071 |
1072 | # concat source signal, semantic / phoneme conditioning embed, and conditioning
1073 | # and project
1074 |
1075 | to_concat = [*filter(exists, (x, cond_emb, cond))]
1076 | embed = torch.cat(to_concat, dim = -1)
1077 |
1078 | x = self.to_embed(embed)
1079 |
1080 | x = self.conv_embed(x, mask = self_attn_mask) + x
1081 |
1082 | time_emb = self.sinu_pos_emb(times)
1083 |
1084 | # attend
1085 |
1086 | x = self.transformer(
1087 | x,
1088 | mask = self_attn_mask,
1089 | adaptive_rmsnorm_cond = time_emb
1090 | )
1091 |
1092 | x = self.to_pred(x)
1093 |
1094 | # if no target passed in, just return logits
1095 |
1096 | if not exists(target):
1097 | return x
1098 |
1099 | loss_mask = reduce_masks_with_and(cond_mask, self_attn_mask)
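        # note: the loss is only taken over positions that were masked out of the conditioning
        # (and, when a self attention mask is given, positions that are not padding)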
1100 |
1101 | if not exists(loss_mask):
1102 | return F.mse_loss(x, target)
1103 |
1104 | loss = F.mse_loss(x, target, reduction = 'none')
1105 |
1106 | loss = reduce(loss, 'b n d -> b n', 'mean')
1107 | loss = loss.masked_fill(~loss_mask, 0.)
1108 |
1109 | # masked mean
1110 |
1111 | num = reduce(loss, 'b n -> b', 'sum')
1112 | den = loss_mask.sum(dim = -1).clamp(min = 1e-5)
1113 | loss = num / den
1114 |
1115 | return loss.mean()
1116 |
1117 | # wrapper for the CNF
1118 |
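# shape convention assumed by the check below: raw audio comes in as (batch, time) or
# (batch, 1, time), whereas encoded latents are (batch, frames, dim)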
1119 | def is_probably_audio_from_shape(t):
1120 | return exists(t) and (t.ndim == 2 or (t.ndim == 3 and t.shape[1] == 1))
1121 |
1122 | class ConditionalFlowMatcherWrapper(Module):
1123 | @beartype
1124 | def __init__(
1125 | self,
1126 | voicebox: VoiceBox,
1127 | text_to_semantic: Optional[TextToSemantic] = None,
1128 | duration_predictor: Optional[DurationPredictor] = None,
1129 | sigma = 0.,
1130 | ode_atol = 1e-5,
1131 | ode_rtol = 1e-5,
1132 | use_torchode = False,
1133 | torchdiffeq_ode_method = 'midpoint', # use midpoint for torchdiffeq, as in paper
1134 | torchode_method_klass = to.Tsit5, # use tsit5 for torchode, as torchode does not have midpoint (recommended by Bryan @b-chiang)
1135 | cond_drop_prob = 0.
1136 | ):
1137 | super().__init__()
1138 | self.sigma = sigma
1139 |
1140 | self.voicebox = voicebox
1141 | self.condition_on_text = voicebox.condition_on_text
1142 |
1143 | assert not (not self.condition_on_text and exists(text_to_semantic)), 'TextToSemantic should not be passed in if not conditioning on text'
1144 | assert not (exists(text_to_semantic) and not exists(text_to_semantic.wav2vec)), 'the wav2vec module must exist on the TextToSemantic, if being used to condition on text'
1145 |
1146 | self.text_to_semantic = text_to_semantic
1147 | self.duration_predictor = duration_predictor
1148 |
1149 | if self.condition_on_text and (exists(text_to_semantic) or exists(duration_predictor)):
1150 | assert exists(text_to_semantic) ^ exists(duration_predictor), 'you should use either TextToSemantic from Spear-TTS, or DurationPredictor for the text / phoneme to audio alignment, but not both'
1151 |
1152 | self.cond_drop_prob = cond_drop_prob
1153 |
1154 | self.use_torchode = use_torchode
1155 | self.torchode_method_klass = torchode_method_klass
1156 |
1157 | self.odeint_kwargs = dict(
1158 | atol = ode_atol,
1159 | rtol = ode_rtol,
1160 | method = torchdiffeq_ode_method
1161 | )
1162 |
1163 | @property
1164 | def device(self):
1165 | return next(self.parameters()).device
1166 |
1167 | def load(self, path, strict = True):
1168 | # return pkg so the trainer can access it
1169 | path = Path(path)
1170 | assert path.exists()
1171 | pkg = torch.load(str(path), map_location = 'cpu')
1172 | self.load_state_dict(pkg['model'], strict = strict)
1173 | return pkg
1174 |
1175 | @torch.inference_mode()
1176 | def sample(
1177 | self,
1178 | *,
1179 | cond = None,
1180 | texts: Optional[List[str]] = None,
1181 | text_token_ids: Optional[Tensor] = None,
1182 | semantic_token_ids: Optional[Tensor] = None,
1183 | phoneme_ids: Optional[Tensor] = None,
1184 | cond_mask = None,
1185 | steps = 3,
1186 | cond_scale = 1.,
1187 | decode_to_audio = True,
1188 | decode_to_codes = False,
1189 | max_semantic_token_ids = 2048,
1190 | spec_decode = False,
1191 | spec_decode_gamma = 5 # could be higher, since speech is probably easier than text, needs to be tested
1192 | ):
1193 | # take care of condition as raw audio
1194 |
1195 | cond_is_raw_audio = is_probably_audio_from_shape(cond)
1196 |
1197 | if cond_is_raw_audio:
1198 | assert exists(self.voicebox.audio_enc_dec)
1199 |
1200 | self.voicebox.audio_enc_dec.eval()
1201 | cond = self.voicebox.audio_enc_dec.encode(cond)
1202 |
1203 | # setup text conditioning, either coming from duration model (as phoneme ids)
1204 |         # or coming from the text-to-semantic module from the spear-tts paper (as semantic ids)
1205 |
1206 | num_cond_inputs = sum([*map(exists, (texts, text_token_ids, semantic_token_ids, phoneme_ids))])
1207 | assert num_cond_inputs <= 1
1208 |
1209 | self_attn_mask = None
1210 | cond_token_ids = None
1211 |
1212 | if self.condition_on_text:
1213 | if exists(self.text_to_semantic) or exists(semantic_token_ids):
1214 | assert not exists(phoneme_ids)
1215 |
1216 | if not exists(semantic_token_ids):
1217 | self.text_to_semantic.eval()
1218 |
1219 | semantic_token_ids, self_attn_mask = self.text_to_semantic.generate(
1220 | source = default(text_token_ids, texts),
1221 | source_type = 'text',
1222 | target_type = 'speech',
1223 | max_length = max_semantic_token_ids,
1224 | return_target_mask = True,
1225 | spec_decode = spec_decode,
1226 | spec_decode_gamma = spec_decode_gamma
1227 | )
1228 |
1229 | cond_token_ids = semantic_token_ids
1230 |
1231 | elif exists(self.duration_predictor):
1232 | self.duration_predictor.eval()
1233 |
1234 | durations, aligned_phoneme_ids = self.duration_predictor.forward_with_cond_scale(
1235 | cond = cond,
1236 | texts = texts,
1237 | phoneme_ids = phoneme_ids,
1238 | return_aligned_phoneme_ids = True
1239 | )
1240 |
1241 | cond_token_ids = aligned_phoneme_ids
1242 |
1243 | cond_tokens_seq_len = cond_token_ids.shape[-1]
1244 | cond_target_length = cond_tokens_seq_len
1245 |
1246 | if exists(cond):
1247 | if exists(self.text_to_semantic):
1248 | # calculate the correct conditioning length for text to semantic
1249 | # based on the sampling freqs of wav2vec and audio-enc-dec, as well as downsample factor
1250 | # (cond_time x cond_sampling_freq / cond_downsample_factor) == (audio_time x audio_sampling_freq / audio_downsample_factor)
1251 | wav2vec = self.text_to_semantic.wav2vec
1252 | audio_enc_dec = self.voicebox.audio_enc_dec
1253 |
1254 | cond_target_length = (cond_tokens_seq_len * wav2vec.target_sample_hz / wav2vec.downsample_factor) / (audio_enc_dec.sampling_rate / audio_enc_dec.downsample_factor)
1255 | cond_target_length = math.ceil(cond_target_length)
1256 |
1257 | cond = curtail_or_pad(cond, cond_target_length)
1258 | else:
1259 | cond = torch.zeros((cond_token_ids.shape[0], cond_target_length, self.voicebox.audio_enc_dec.latent_dim), device = self.device)
1260 | else:
1261 | assert num_cond_inputs == 0, 'no conditioning inputs should be given if not conditioning on text'
1262 |
1263 | shape = cond.shape
1264 | batch = shape[0]
1265 |
1266 | # neural ode
1267 |
1268 | self.voicebox.eval()
1269 |
1270 | def fn(t, x, *, packed_shape = None):
1271 | if exists(packed_shape):
1272 | x = unpack_one(x, packed_shape, 'b *')
1273 |
1274 | out = self.voicebox.forward_with_cond_scale(
1275 | x,
1276 | times = t,
1277 | cond_token_ids = cond_token_ids,
1278 | cond = cond,
1279 | cond_scale = cond_scale,
1280 | cond_mask = cond_mask,
1281 | self_attn_mask = self_attn_mask
1282 | )
1283 |
1284 | if exists(packed_shape):
1285 | out = rearrange(out, 'b ... -> b (...)')
1286 |
1287 | return out
1288 |
1289 | y0 = torch.randn_like(cond)
1290 | t = torch.linspace(0, 1, steps, device = self.device)
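        # note: sampling integrates the learned vector field from gaussian noise at t = 0 towards
        # data at t = 1; `steps` sets the time grid handed to the solver (for the fixed-step
        # midpoint method this also fixes the number of solver steps), and the last point of the
        # trajectory is taken as the sample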
1291 |
1292 | if not self.use_torchode:
1293 | LOGGER.debug('sampling with torchdiffeq')
1294 |
1295 | trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
1296 | sampled = trajectory[-1]
1297 | else:
1298 | LOGGER.debug('sampling with torchode')
1299 |
1300 | t = repeat(t, 'n -> b n', b = batch)
1301 | y0, packed_shape = pack_one(y0, 'b *')
1302 |
1303 | fn = partial(fn, packed_shape = packed_shape)
1304 |
1305 | term = to.ODETerm(fn)
1306 | step_method = self.torchode_method_klass(term = term)
1307 |
1308 | step_size_controller = to.IntegralController(
1309 | atol = self.odeint_kwargs['atol'],
1310 | rtol = self.odeint_kwargs['rtol'],
1311 | term = term
1312 | )
1313 |
1314 | solver = to.AutoDiffAdjoint(step_method, step_size_controller)
1315 | jit_solver = torch.compile(solver)
1316 |
1317 | init_value = to.InitialValueProblem(y0 = y0, t_eval = t)
1318 |
1319 | sol = jit_solver.solve(init_value)
1320 |
1321 | sampled = sol.ys[:, -1]
1322 | sampled = unpack_one(sampled, packed_shape, 'b *')
1323 |
1324 | if decode_to_codes and exists(self.voicebox.audio_enc_dec):
1325 | return self.voicebox.audio_enc_dec.decode_to_codes(sampled)
1326 |
1327 | if not decode_to_audio or not exists(self.voicebox.audio_enc_dec):
1328 | return sampled
1329 |
1330 | return self.voicebox.audio_enc_dec.decode(sampled)
1331 |
1332 | def forward(
1333 | self,
1334 | x1,
1335 | *,
1336 | mask = None,
1337 | semantic_token_ids = None,
1338 | phoneme_ids = None,
1339 | cond = None,
1340 | cond_mask = None,
1341 | input_sampling_rate = None # will assume it to be the same as the audio encoder decoder sampling rate, if not given. if given, will resample
1342 | ):
1343 | """
1344 | following eq (5) (6) in https://arxiv.org/pdf/2306.15687.pdf
1345 | """
1346 |
1347 | batch, seq_len, dtype, σ = *x1.shape[:2], x1.dtype, self.sigma
1348 |
1349 | # if raw audio is given, convert if audio encoder / decoder was passed in
1350 |
1351 | input_is_raw_audio, cond_is_raw_audio = map(is_probably_audio_from_shape, (x1, cond))
1352 |
1353 | if input_is_raw_audio:
1354 | raw_audio = x1
1355 |
1356 | if any([input_is_raw_audio, cond_is_raw_audio]):
1357 | assert exists(self.voicebox.audio_enc_dec), 'audio_enc_dec must be set on VoiceBox to train directly on raw audio'
1358 |
1359 | audio_enc_dec_sampling_rate = self.voicebox.audio_enc_dec.sampling_rate
1360 | input_sampling_rate = default(input_sampling_rate, audio_enc_dec_sampling_rate)
1361 |
1362 | with torch.no_grad():
1363 | self.voicebox.audio_enc_dec.eval()
1364 |
1365 | if input_is_raw_audio:
1366 | x1 = resample(x1, input_sampling_rate, audio_enc_dec_sampling_rate)
1367 | x1 = self.voicebox.audio_enc_dec.encode(x1)
1368 |
1369 | if exists(cond) and cond_is_raw_audio:
1370 | cond = resample(cond, input_sampling_rate, audio_enc_dec_sampling_rate)
1371 | cond = self.voicebox.audio_enc_dec.encode(cond)
1372 |
1373 | # setup text conditioning, either coming from duration model (as phoneme ids)
1374 | # or from text-to-semantic module, semantic ids encoded with wav2vec (hubert usually)
1375 |
1376 | assert self.condition_on_text or not (exists(semantic_token_ids) or exists(phoneme_ids)), 'semantic or phoneme ids should not be passed in if not conditioning on text'
1377 |
1378 | cond_token_ids = None
1379 |
1380 | if self.condition_on_text:
1381 | if exists(self.text_to_semantic) or exists(semantic_token_ids):
1382 | assert not exists(phoneme_ids), 'phoneme ids are not needed for conditioning with spear-tts text-to-semantic'
1383 |
1384 | if not exists(semantic_token_ids):
1385 | assert input_is_raw_audio
1386 | wav2vec = self.text_to_semantic.wav2vec
1387 | wav2vec_input = resample(raw_audio, input_sampling_rate, wav2vec.target_sample_hz)
1388 | semantic_token_ids = wav2vec(wav2vec_input).clone()
1389 |
1390 | cond_token_ids = semantic_token_ids
1391 | else:
1392 | assert exists(phoneme_ids)
1393 | cond_token_ids = phoneme_ids
1394 |
1395 | # main conditional flow logic is below
1396 |
1397 | # x0 is gaussian noise
1398 |
1399 | x0 = torch.randn_like(x1)
1400 |
1401 | # random times
1402 |
1403 | times = torch.rand((batch,), dtype = dtype, device = self.device)
1404 | t = rearrange(times, 'b -> b 1 1')
1405 |
1406 | # sample xt (w in the paper)
1407 |
1408 | w = (1 - (1 - σ) * t) * x0 + t * x1
1409 |
1410 | flow = x1 - (1 - σ) * x0
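        # a sketch of the conditional flow above (assuming σ here plays the role of σ_min in the paper):
        #   ψ_t(x0) = (1 - (1 - σ) t) * x0 + t * x1    (the noised sample w)
        #   u_t     = x1 - (1 - σ) * x0                (the target vector field `flow`)
        # the network is trained to regress u_t given w and t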
1411 |
1412 | # predict
1413 |
1414 | self.voicebox.train()
1415 |
1416 | loss = self.voicebox(
1417 | w,
1418 | cond = cond,
1419 | cond_mask = cond_mask,
1420 | times = times,
1421 | target = flow,
1422 | self_attn_mask = mask,
1423 | cond_token_ids = cond_token_ids,
1424 | cond_drop_prob = self.cond_drop_prob
1425 | )
1426 |
1427 | return loss
1428 |
--------------------------------------------------------------------------------