├── images ├── voicebox.png └── translated.png ├── voicebox_pytorch ├── __init__.py ├── optimizer.py ├── data.py ├── attend.py ├── trainer.py └── voicebox_pytorch.py ├── LICENSE ├── setup.py ├── .github └── workflows │ └── python-publish.yml ├── .gitignore └── README.md /images/voicebox.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/voicebox-pytorch/HEAD/images/voicebox.png -------------------------------------------------------------------------------- /images/translated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/voicebox-pytorch/HEAD/images/translated.png -------------------------------------------------------------------------------- /voicebox_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from voicebox_pytorch.voicebox_pytorch import ( 2 | Transformer, 3 | EncodecVoco, 4 | VoiceBox, 5 | DurationPredictor, 6 | ConditionalFlowMatcherWrapper, 7 | ) 8 | 9 | from voicebox_pytorch.trainer import ( 10 | VoiceBoxTrainer 11 | ) 12 | 13 | from spear_tts_pytorch import TextToSemantic # re-exported so users can import it from voicebox_pytorch, as the README usage does 14 | 15 | from audiolm_pytorch import HubertWithKmeans # re-exported for the same reason 16 | -------------------------------------------------------------------------------- /voicebox_pytorch/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import AdamW, Adam 2 | 3 | def separate_weight_decayable_params(params): # split params into (wd_params, no_wd_params); tensors with ndim below two are not decayed 4 | wd_params, no_wd_params = [], [] 5 | for param in params: 6 | param_list = no_wd_params if param.ndim < 2 else wd_params 7 | param_list.append(param) 8 | return wd_params, no_wd_params 9 | 10 | def get_optimizer( # Adam when wd <= 0, otherwise AdamW, optionally with a separate no-decay group 11 | params, 12 | lr = 1e-4, 13 | wd = 1e-2, 14 | betas = (0.9, 0.99), 15 | eps = 1e-8, 16 | filter_by_requires_grad = False, 17 | group_wd_params = True 18 | ): 19 | has_wd = wd > 0 20 | 21 | if
filter_by_requires_grad: 22 | params = list(filter(lambda t: t.requires_grad, params)) # keep only trainable params 23 | 24 | if group_wd_params and has_wd: 25 | wd_params, no_wd_params = separate_weight_decayable_params(params) 26 | 27 | params = [ 28 | {'params': wd_params}, 29 | {'params': no_wd_params, 'weight_decay': 0}, # per-group override of the AdamW weight_decay below 30 | ] 31 | 32 | if not has_wd: 33 | return Adam(params, lr = lr, betas = betas, eps = eps) # wd <= 0: plain Adam, no weight decay 34 | 35 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'voicebox-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.5.0', 7 | license='MIT', 8 | description = 'Voicebox - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', # NOTE(review): no long_description is passed, so this has no effect; confirm intent 12 | url = 'https://github.com/lucidrains/voicebox-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'text to speech' 17 | ], 18 | install_requires=[ 19 | 'accelerate', 20 | 'audiolm-pytorch>=1.2.28', 21 | 'naturalspeech2-pytorch>=0.1.8', 22 | 'beartype', 23 | 'einops>=0.6.1', 24 | 'gateloop-transformer>=0.2.4', 25 | 'spear-tts-pytorch>=0.4.0', 26 | 'torch>=2.0', 27 | 'torchdiffeq', 28 | 'torchode', 29 | 'vocos' 30 | ], 31 | classifiers=[ 32 | 'Development Status :: 4 - Beta', 33 | 'Intended Audience :: Developers', 34 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 35 | 'License :: OSI Approved :: MIT License', 36 | 'Programming Language :: Python :: 3.6', # NOTE(review): install_requires needs torch>=2.0, which does not appear to support Python 3.6; confirm and update this classifier 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation.
10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /voicebox_pytorch/data.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from functools import wraps 3 | 4 | from einops import rearrange 5 | 6 | from beartype import beartype 7 | from beartype.door import is_bearable 8 | from beartype.typing import Optional, Tuple, Union 9 | 10 | import torch 11 | from torch.nn.utils.rnn import pad_sequence 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | import torchaudio 15 | 16 | # utilities 17 | 18 | def exists(val): 19 | return val is not None 20 | 21 | def cast_tuple(val, length = 1): 22 | return val if isinstance(val, tuple) else ((val,) * length) 23 | 24 | # dataset functions 25 | 26 | class AudioDataset(Dataset): 27 | @beartype 28 | def __init__( 29 | self, 30 | folder, 31 | audio_extension = ".flac" 32 | ): 33 | super().__init__() 34 | path = Path(folder) 35 | assert path.exists(), 'folder does not exist' 36 | 37 | self.audio_extension = audio_extension 38 | 39 | files = list(path.glob(f'**/*{audio_extension}')) 40 | assert len(files) > 0, 'no files found' 41 | 42 | self.files = files 43 | 44 | def __len__(self): 45 | return len(self.files) 46 | 47 | def __getitem__(self, idx): 48 | file = self.files[idx] 49 | 50 | 
wave, _ = torchaudio.load(file) 51 | wave = rearrange(wave, '1 ... -> ...') 52 | 53 | return wave 54 | 55 | # dataloader functions 56 | 57 | def collate_one_or_multiple_tensors(fn): 58 | @wraps(fn) 59 | def inner(data): 60 | is_one_data = not isinstance(data[0], tuple) 61 | 62 | if is_one_data: 63 | data = fn(data) 64 | return (data,) 65 | 66 | outputs = [] 67 | for datum in zip(*data): 68 | if is_bearable(datum, Tuple[str, ...]): 69 | output = list(datum) 70 | else: 71 | output = fn(datum) 72 | 73 | outputs.append(output) 74 | 75 | return tuple(outputs) 76 | 77 | return inner 78 | 79 | @collate_one_or_multiple_tensors 80 | def curtail_to_shortest_collate(data): 81 | min_len = min(*[datum.shape[0] for datum in data]) 82 | data = [datum[:min_len] for datum in data] 83 | return torch.stack(data) 84 | 85 | @collate_one_or_multiple_tensors 86 | def pad_to_longest_fn(data): 87 | return pad_sequence(data, batch_first = True) 88 | 89 | def get_dataloader(ds, pad_to_longest = True, **kwargs): 90 | collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate 91 | return DataLoader(ds, collate_fn = collate_fn, **kwargs) 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | -------------------------------------------------------------------------------- /voicebox_pytorch/attend.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from packaging import version 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch import nn, einsum 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange, reduce 10 | 11 | # constants 12 | 13 | FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) # keyword flags for torch.backends.cuda.sdp_kernel, used in flash_attn 14 | 15 | # helpers 16 | 17 | def exists(val): 18 | return val is not None 19 | 20 | def default(val, d): 21 | return val if exists(val) else d 22 | 23 | def once(fn): # run the wrapped one-argument fn only on its first call; later calls return None 24 | called = False 25 | @wraps(fn) 26 | def inner(x): 27 | nonlocal called 28 | if called: 29 | return 30 | called = True 31 | return fn(x) 32 | return inner 33 | 34 | print_once = once(print) 35 | 36 | # main class 37 | 38 | class Attend(nn.Module): 39 | def __init__( 40 | self, 41 | dropout = 0., 42 | flash = False, 43 | scale = None 44 | ): 45 | super().__init__() 46 | self.dropout = dropout 47 | self.attn_dropout = nn.Dropout(dropout) 48 | 49 | self.scale = scale 50 | 51 | self.flash = flash 52 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 53 | 54 | # determine efficient attention configs for cuda and cpu 55 | 56 | self.cpu_config = FlashAttentionConfig(True, True, True) # allow every sdpa backend for cpu inputs 57 | self.cuda_config = None 58 | 59 | if not torch.cuda.is_available() or not flash: 60 | return 61 | 62 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 63 | 64 | if device_properties.major == 8 and device_properties.minor == 0: # compute capability 8.0 is treated as an A100 here 65 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') 66 | self.cuda_config = FlashAttentionConfig(True, False, False) 67 | else: 68 |
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') 69 | self.cuda_config = FlashAttentionConfig(False, True, True) 70 | 71 | def flash_attn(self, q, k, v, mask = None): 72 | _, heads, q_len, dim_head, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device # requires q to be 4-D: (batch, heads, q_len, dim_head) 73 | 74 | # if scale is given, divide by the default scale that sdpa uses 75 | 76 | if exists(self.scale): 77 | q = q * (self.scale / (dim_head ** -0.5)) 78 | 79 | # Check if mask exists and expand to compatible shape 80 | # The mask is B L, so it would have to be expanded to B H N L 81 | 82 | if exists(mask): 83 | mask = mask.expand(-1, heads, q_len, -1) # from forward(), a 2-D mask arrives here as (b, 1, 1, j) 84 | 85 | # Check if there is a compatible device for flash attention 86 | 87 | config = self.cuda_config if is_cuda else self.cpu_config 88 | 89 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale 90 | 91 | with torch.backends.cuda.sdp_kernel(**config._asdict()): # NOTE(review): newer torch deprecates this context manager in favor of torch.nn.attention.sdpa_kernel; confirm the supported torch range 92 | out = F.scaled_dot_product_attention( 93 | q, k, v, 94 | attn_mask = mask, 95 | dropout_p = self.dropout if self.training else 0.
96 | ) 97 | 98 | return out 99 | 100 | def forward(self, q, k, v, mask = None): 101 | """ 102 | einstein notation 103 | b - batch 104 | h - heads 105 | n, i, j - sequence length (base sequence length, source, target) 106 | d - feature dimension 107 | """ 108 | 109 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device # q_len, k_len and device are unused in this method 110 | 111 | scale = default(self.scale, q.shape[-1] ** -0.5) 112 | 113 | if exists(mask) and mask.ndim != 4: 114 | mask = rearrange(mask, 'b j -> b 1 1 j') # broadcast the key padding mask over heads and queries 115 | 116 | if self.flash: 117 | return self.flash_attn(q, k, v, mask = mask) 118 | 119 | # similarity 120 | 121 | sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale 122 | 123 | # key padding mask 124 | 125 | if exists(mask): 126 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # most negative finite value instead of -inf, presumably to keep fully masked rows NaN-free 127 | 128 | # attention 129 | 130 | attn = sim.softmax(dim=-1) 131 | attn = self.attn_dropout(attn) 132 | 133 | # aggregate values 134 | 135 | out = einsum(f"b h i j, b h j d -> b h i d", attn, v) 136 | 137 | return out 138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Voicebox - Pytorch 4 | 5 | Implementation of Voicebox, new SOTA Text-to-Speech model from MetaAI, in Pytorch. Press release 6 | 7 | In this work, we will use rotary embeddings. The authors seem unaware that ALiBi cannot be straightforwardly used for bidirectional models. 8 | 9 | The paper also addresses the issue with time embedding incorrectly subjected to relative distances (they concat the time embedding along the frame dimension of the audio tokens). This repository will use adaptive normalization, as applied successfully in Paella 10 | 11 | Update: Recommend you just use E2 TTS instead of this work 12 | 13 | ## Appreciation 14 | 15 | - Translated for awarding me the Imminent Grant to advance the state of open sourced text-to-speech solutions.
This project was started and will be completed under this grant. 16 | 17 | - StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence. 18 | 19 | - Bryan Chiang for the ongoing code review, sharing his expertise on TTS, and pointing me to an open sourced implementation of conditional flow matching 20 | 21 | - Manmay for getting the repository started with the alignment code 22 | 23 | - @chenht2010 for finding a bug with rotary positions, and for validating that the code in the repository converges 24 | 25 | - Lucas Newman for (yet again) pull requesting all the training code for Spear-TTS conditioned Voicebox training! 26 | 27 | - Lucas Newman has demonstrated that the whole system works with Spear-TTS conditioning. Training converges even better than Soundstorm 28 | 29 | ## Install 30 | 31 | ```bash 32 | $ pip install voicebox-pytorch 33 | ``` 34 | 35 | ## Usage 36 | 37 | Training and sampling with `TextToSemantic` module from SpearTTS 38 | 39 | ```python 40 | import torch 41 | 42 | from voicebox_pytorch import ( 43 | VoiceBox, 44 | EncodecVoco, 45 | ConditionalFlowMatcherWrapper, 46 | HubertWithKmeans, 47 | TextToSemantic 48 | ) 49 | 50 | # https://github.com/facebookresearch/fairseq/tree/main/examples/hubert 51 | 52 | wav2vec = HubertWithKmeans( 53 | checkpoint_path = '/path/to/hubert/checkpoint.pt', 54 | kmeans_path = '/path/to/hubert/kmeans.bin' 55 | ) 56 | 57 | text_to_semantic = TextToSemantic( 58 | wav2vec = wav2vec, 59 | dim = 512, 60 | source_depth = 1, 61 | target_depth = 1, 62 | use_openai_tokenizer = True 63 | ) 64 | 65 | text_to_semantic.load('/path/to/trained/spear-tts/model.pt') 66 | 67 | model = VoiceBox( 68 | dim = 512, 69 | audio_enc_dec = EncodecVoco(), 70 | num_cond_tokens = 500, 71 | depth = 2, 72 | dim_head = 64, 73 | heads = 16 74 | ) 75 | 76 | cfm_wrapper = ConditionalFlowMatcherWrapper( 77 | voicebox = model, 78 | text_to_semantic = text_to_semantic 79 | 
) 80 | 81 | # mock data 82 | 83 | audio = torch.randn(2, 12000) 84 | 85 | # train 86 | 87 | loss = cfm_wrapper(audio) 88 | loss.backward() 89 | 90 | # after much training 91 | 92 | texts = [ 93 | 'the rain in spain falls mainly in the plains', 94 | 'she sells sea shells by the seashore' 95 | ] 96 | 97 | cond = torch.randn(2, 12000) 98 | sampled = cfm_wrapper.sample(cond = cond, texts = texts) # (2, 1,