├── images
│   ├── voicebox.png
│   └── translated.png
├── voicebox_pytorch
│   ├── __init__.py
│   ├── optimizer.py
│   ├── data.py
│   ├── attend.py
│   ├── trainer.py
│   └── voicebox_pytorch.py
├── LICENSE
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
└── README.md
/images/voicebox.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/voicebox-pytorch/HEAD/images/voicebox.png
--------------------------------------------------------------------------------
/images/translated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/voicebox-pytorch/HEAD/images/translated.png
--------------------------------------------------------------------------------
/voicebox_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from voicebox_pytorch.voicebox_pytorch import (
2 | Transformer,
3 | EncodecVoco,
4 | VoiceBox,
5 | DurationPredictor,
6 | ConditionalFlowMatcherWrapper,
7 | )
8 |
9 | from voicebox_pytorch.trainer import (
10 | VoiceBoxTrainer
11 | )
12 |
13 | from spear_tts_pytorch import TextToSemantic
14 |
15 | from audiolm_pytorch import HubertWithKmeans
16 |
--------------------------------------------------------------------------------
/voicebox_pytorch/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import AdamW, Adam
2 |
3 | def separate_weight_decayable_params(params):
4 | wd_params, no_wd_params = [], []
5 | for param in params:
6 | param_list = no_wd_params if param.ndim < 2 else wd_params
7 | param_list.append(param)
8 | return wd_params, no_wd_params
9 |
10 | def get_optimizer(
11 | params,
12 | lr = 1e-4,
13 | wd = 1e-2,
14 | betas = (0.9, 0.99),
15 | eps = 1e-8,
16 | filter_by_requires_grad = False,
17 | group_wd_params = True
18 | ):
19 | has_wd = wd > 0
20 |
21 | if filter_by_requires_grad:
22 | params = list(filter(lambda t: t.requires_grad, params))
23 |
24 | if group_wd_params and has_wd:
25 | wd_params, no_wd_params = separate_weight_decayable_params(params)
26 |
27 | params = [
28 | {'params': wd_params},
29 | {'params': no_wd_params, 'weight_decay': 0},
30 | ]
31 |
32 | if not has_wd:
33 | return Adam(params, lr = lr, betas = betas, eps = eps)
34 |
35 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
36 |
--------------------------------------------------------------------------------
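
A quick usage sketch for the optimizer factory above (the toy model and hyperparameter values are illustrative, not from the repository):

```python
import torch.nn as nn
from voicebox_pytorch.optimizer import get_optimizer

# toy model - any nn.Module's parameters work the same way
model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))

# with wd > 0 and group_wd_params = True, matrix weights (ndim >= 2)
# receive weight decay, while biases and norm gains (ndim < 2) do not
optimizer = get_optimizer(
    model.parameters(),
    lr = 1e-4,
    wd = 1e-2,
    filter_by_requires_grad = True
)
```

Setting `wd = 0` skips the grouping entirely and returns plain `Adam` instead of `AdamW`.
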
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'voicebox-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.5.0',
7 | license='MIT',
8 | description = 'Voicebox - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/voicebox-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'text to speech'
17 | ],
18 | install_requires=[
19 | 'accelerate',
20 | 'audiolm-pytorch>=1.2.28',
21 | 'naturalspeech2-pytorch>=0.1.8',
22 | 'beartype',
23 | 'einops>=0.6.1',
24 | 'gateloop-transformer>=0.2.4',
25 | 'spear-tts-pytorch>=0.4.0',
26 | 'torch>=2.0',
27 | 'torchdiffeq',
28 | 'torchode',
29 | 'vocos'
30 | ],
31 | classifiers=[
32 | 'Development Status :: 4 - Beta',
33 | 'Intended Audience :: Developers',
34 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
35 | 'License :: OSI Approved :: MIT License',
36 | 'Programming Language :: Python :: 3.6',
37 | ],
38 | )
39 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/voicebox_pytorch/data.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from functools import wraps
3 |
4 | from einops import rearrange
5 |
6 | from beartype import beartype
7 | from beartype.door import is_bearable
8 | from beartype.typing import Optional, Tuple, Union
9 |
10 | import torch
11 | from torch.nn.utils.rnn import pad_sequence
12 | from torch.utils.data import Dataset, DataLoader
13 |
14 | import torchaudio
15 |
16 | # utilities
17 |
18 | def exists(val):
19 | return val is not None
20 |
21 | def cast_tuple(val, length = 1):
22 | return val if isinstance(val, tuple) else ((val,) * length)
23 |
24 | # dataset functions
25 |
26 | class AudioDataset(Dataset):
27 | @beartype
28 | def __init__(
29 | self,
30 | folder,
31 | audio_extension = ".flac"
32 | ):
33 | super().__init__()
34 | path = Path(folder)
35 | assert path.exists(), 'folder does not exist'
36 |
37 | self.audio_extension = audio_extension
38 |
39 | files = list(path.glob(f'**/*{audio_extension}'))
40 | assert len(files) > 0, 'no files found'
41 |
42 | self.files = files
43 |
44 | def __len__(self):
45 | return len(self.files)
46 |
47 | def __getitem__(self, idx):
48 | file = self.files[idx]
49 |
50 | wave, _ = torchaudio.load(file)
51 | wave = rearrange(wave, '1 ... -> ...')
52 |
53 | return wave
54 |
55 | # dataloader functions
56 |
57 | def collate_one_or_multiple_tensors(fn):
58 | @wraps(fn)
59 | def inner(data):
60 | is_one_data = not isinstance(data[0], tuple)
61 |
62 | if is_one_data:
63 | data = fn(data)
64 | return (data,)
65 |
66 | outputs = []
67 | for datum in zip(*data):
68 | if is_bearable(datum, Tuple[str, ...]):
69 | output = list(datum)
70 | else:
71 | output = fn(datum)
72 |
73 | outputs.append(output)
74 |
75 | return tuple(outputs)
76 |
77 | return inner
78 |
79 | @collate_one_or_multiple_tensors
80 | def curtail_to_shortest_collate(data):
81 |     min_len = min(datum.shape[0] for datum in data)
82 | data = [datum[:min_len] for datum in data]
83 | return torch.stack(data)
84 |
85 | @collate_one_or_multiple_tensors
86 | def pad_to_longest_fn(data):
87 | return pad_sequence(data, batch_first = True)
88 |
89 | def get_dataloader(ds, pad_to_longest = True, **kwargs):
90 | collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate
91 | return DataLoader(ds, collate_fn = collate_fn, **kwargs)
92 |
--------------------------------------------------------------------------------
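
To illustrate how the dataset and collation functions above fit together, here is a minimal sketch (the audio folder path and batch size are hypothetical):

```python
from voicebox_pytorch.data import AudioDataset, get_dataloader

# hypothetical folder containing .flac files
ds = AudioDataset('/path/to/audio', audio_extension = '.flac')

# pad_to_longest = True right-pads each batch to its longest clip;
# pad_to_longest = False truncates every clip to the shortest one instead
dl = get_dataloader(ds, pad_to_longest = True, batch_size = 4, shuffle = True)

# the collate wrapper returns a tuple, here containing a single tensor
wave, = next(iter(dl))  # shape: (4, <samples in longest clip of the batch>)
```
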
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
--------------------------------------------------------------------------------
/voicebox_pytorch/attend.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from packaging import version
3 | from collections import namedtuple
4 |
5 | import torch
6 | from torch import nn, einsum
7 | import torch.nn.functional as F
8 |
9 | from einops import rearrange, reduce
10 |
11 | # constants
12 |
13 | FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
14 |
15 | # helpers
16 |
17 | def exists(val):
18 | return val is not None
19 |
20 | def default(val, d):
21 | return val if exists(val) else d
22 |
23 | def once(fn):
24 | called = False
25 | @wraps(fn)
26 | def inner(x):
27 | nonlocal called
28 | if called:
29 | return
30 | called = True
31 | return fn(x)
32 | return inner
33 |
34 | print_once = once(print)
35 |
36 | # main class
37 |
38 | class Attend(nn.Module):
39 | def __init__(
40 | self,
41 | dropout = 0.,
42 | flash = False,
43 | scale = None
44 | ):
45 | super().__init__()
46 | self.dropout = dropout
47 | self.attn_dropout = nn.Dropout(dropout)
48 |
49 | self.scale = scale
50 |
51 | self.flash = flash
52 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
53 |
54 | # determine efficient attention configs for cuda and cpu
55 |
56 | self.cpu_config = FlashAttentionConfig(True, True, True)
57 | self.cuda_config = None
58 |
59 | if not torch.cuda.is_available() or not flash:
60 | return
61 |
62 | device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
63 |
64 | if device_properties.major == 8 and device_properties.minor == 0:
65 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
66 | self.cuda_config = FlashAttentionConfig(True, False, False)
67 | else:
68 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
69 | self.cuda_config = FlashAttentionConfig(False, True, True)
70 |
71 | def flash_attn(self, q, k, v, mask = None):
72 | _, heads, q_len, dim_head, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
73 |
74 | # if scale is given, divide by the default scale that sdpa uses
75 |
76 | if exists(self.scale):
77 | q = q * (self.scale / (dim_head ** -0.5))
78 |
79 | # Check if mask exists and expand to compatible shape
80 | # The mask is B L, so it would have to be expanded to B H N L
81 |
82 | if exists(mask):
83 | mask = mask.expand(-1, heads, q_len, -1)
84 |
85 | # Check if there is a compatible device for flash attention
86 |
87 | config = self.cuda_config if is_cuda else self.cpu_config
88 |
89 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
90 |
91 | with torch.backends.cuda.sdp_kernel(**config._asdict()):
92 | out = F.scaled_dot_product_attention(
93 | q, k, v,
94 | attn_mask = mask,
95 | dropout_p = self.dropout if self.training else 0.
96 | )
97 |
98 | return out
99 |
100 | def forward(self, q, k, v, mask = None):
101 | """
102 | einstein notation
103 | b - batch
104 | h - heads
105 | n, i, j - sequence length (base sequence length, source, target)
106 | d - feature dimension
107 | """
108 |
109 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
110 |
111 | scale = default(self.scale, q.shape[-1] ** -0.5)
112 |
113 | if exists(mask) and mask.ndim != 4:
114 | mask = rearrange(mask, 'b j -> b 1 1 j')
115 |
116 | if self.flash:
117 | return self.flash_attn(q, k, v, mask = mask)
118 |
119 | # similarity
120 |
121 |         sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
122 |
123 | # key padding mask
124 |
125 | if exists(mask):
126 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
127 |
128 | # attention
129 |
130 | attn = sim.softmax(dim=-1)
131 | attn = self.attn_dropout(attn)
132 |
133 | # aggregate values
134 |
135 |         out = einsum("b h i j, b h j d -> b h i d", attn, v)
136 |
137 | return out
138 |
--------------------------------------------------------------------------------
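
As a standalone sketch of the attention module above (tensor shapes are illustrative):

```python
import torch
from voicebox_pytorch.attend import Attend

attend = Attend(dropout = 0., flash = False)  # flash = True requires pytorch >= 2.0

# shapes follow the module's einstein notation: (batch, heads, seq len, dim head)
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# key padding mask of shape (batch, seq len); True marks attendable positions
mask = torch.ones(1, 1024).bool()

out = attend(q, k, v, mask = mask)  # (1, 8, 1024, 64)
```
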
/README.md:
--------------------------------------------------------------------------------
1 | <img src="./images/voicebox.png"></img>
2 |
3 | ## Voicebox - Pytorch
4 |
5 | Implementation of Voicebox, the new SOTA text-to-speech model from MetaAI, in Pytorch.
6 |
7 | In this work, we will use rotary embeddings. The authors seem unaware that ALiBi cannot be straightforwardly used for bidirectional models.
8 |
9 | The paper also has the issue of the time embedding being incorrectly subjected to relative distances (they concatenate the time embedding along the frame dimension of the audio tokens). This repository will instead use adaptive normalization, as applied successfully in Paella.
10 |
11 | Update: I recommend just using E2 TTS instead of this work.
12 |
13 | ## Appreciation
14 |
15 | - ![Translated](./images/translated.png) for awarding me the Imminent Grant to advance the state of open sourced text-to-speech solutions. This project was started and will be completed under this grant.
16 |
17 | - StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
18 |
19 | - Bryan Chiang for the ongoing code review, sharing his expertise on TTS, and pointing me to an open sourced implementation of conditional flow matching
20 |
21 | - Manmay for getting the repository started with the alignment code
22 |
23 | - @chenht2010 for finding a bug with rotary positions, and for validating that the code in the repository converges
24 |
25 | - Lucas Newman for (yet again) pull requesting all the training code for Spear-TTS conditioned Voicebox training!
26 |
27 | - Lucas Newman has demonstrated that the whole system works with Spear-TTS conditioning. Training converges even better than Soundstorm.
28 |
29 | ## Install
30 |
31 | ```bash
32 | $ pip install voicebox-pytorch
33 | ```
34 |
35 | ## Usage
36 |
37 | Training and sampling with the `TextToSemantic` module from Spear-TTS
38 |
39 | ```python
40 | import torch
41 |
42 | from voicebox_pytorch import (
43 | VoiceBox,
44 | EncodecVoco,
45 | ConditionalFlowMatcherWrapper,
46 | HubertWithKmeans,
47 | TextToSemantic
48 | )
49 |
50 | # https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
51 |
52 | wav2vec = HubertWithKmeans(
53 | checkpoint_path = '/path/to/hubert/checkpoint.pt',
54 | kmeans_path = '/path/to/hubert/kmeans.bin'
55 | )
56 |
57 | text_to_semantic = TextToSemantic(
58 | wav2vec = wav2vec,
59 | dim = 512,
60 | source_depth = 1,
61 | target_depth = 1,
62 | use_openai_tokenizer = True
63 | )
64 |
65 | text_to_semantic.load('/path/to/trained/spear-tts/model.pt')
66 |
67 | model = VoiceBox(
68 | dim = 512,
69 | audio_enc_dec = EncodecVoco(),
70 | num_cond_tokens = 500,
71 | depth = 2,
72 | dim_head = 64,
73 | heads = 16
74 | )
75 |
76 | cfm_wrapper = ConditionalFlowMatcherWrapper(
77 | voicebox = model,
78 | text_to_semantic = text_to_semantic
79 | )
80 |
81 | # mock data
82 |
83 | audio = torch.randn(2, 12000)
84 |
85 | # train
86 |
87 | loss = cfm_wrapper(audio)
88 | loss.backward()
89 |
90 | # after much training
91 |
92 | texts = [
93 | 'the rain in spain falls mainly in the plains',
94 | 'she sells sea shells by the seashore'
95 | ]
96 |
97 | cond = torch.randn(2, 12000)
98 | sampled = cfm_wrapper.sample(cond = cond, texts = texts) # (2, 1,