├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── setup.py
├── spear-tts.png
└── spear_tts_pytorch
    ├── __init__.py
    ├── attend.py
    ├── data.py
    ├── distributed.py
    ├── spear_tts_pytorch.py
    └── trainer.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------


# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
#   .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------


## Spear-TTS - Pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch

The text-to-semantic module built here will be used by SoundStorm for conditioning.

## Appreciation

- Stability for their generous sponsorships to work on and open source cutting edge artificial intelligence research

- Lucas Newman for completing the backtranslation portion, as well as beam search decoding!

- Lucas Newman for completing the final text to semantic transformer training code!

## Install

```bash
$ pip install spear-tts-pytorch
```

## Usage

```python
import torch

from audiolm_pytorch import HubertWithKmeans

from spear_tts_pytorch import (
    TextToSemantic,
    SemanticToTextDatasetGenerator,
    GeneratedAudioTextDataset,
    MockDataset
)

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert_base_ls960.pt',
    kmeans_path = './hubert_base_ls960_L9_km500.bin'
)

model = TextToSemantic(
    wav2vec = wav2vec,
    dim = 512,
    num_text_token_ids = 256,
    heads = 8,
    target_kv_heads = 2, # grouped query attention, for memory efficient decoding
    source_depth = 1,
    target_depth = 1
)

ds = MockDataset(10)

dataset_generator = SemanticToTextDatasetGenerator(
    model = model,
    dataset = ds,
    folder = './output_folder'
)

dataset_generator(max_length = 2)

generated_dataset = GeneratedAudioTextDataset(
    folder = './output_folder'
)

assert len(generated_dataset) == 10
```

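Each generated example is stored as a single `.pt` tensor file, with a delimiter id marking the boundary between the two token sequences (see `GeneratedAudioTextDataset` in `spear_tts_pytorch/data.py`, which splits on `delimiter_id = -1` by default). A minimal sketch of writing such a file by hand - the token values and the semantic-ids-then-text-ids ordering are illustrative assumptions here, not a documented format:

```python
import torch

# hypothetical example file - values and ordering are assumptions for illustration

semantic_ids = torch.randint(0, 500, (100,)) # e.g. hubert k-means cluster ids
delimiter = torch.tensor([-1])               # matches the default delimiter_id
text_ids = torch.randint(0, 256, (20,))      # text token ids

torch.save(torch.cat((semantic_ids, delimiter, text_ids)), './output_folder/example.pt')
```
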
## Todo

- [x] add eos logic + generate, and hook up end-to-end generation in soundstorm
- [x] add first pretraining speech-to-speech with the reconstruction of 60% deleted tokens
- [x] add dropouts for this project, as low-resource
- [x] add total flexibility of which layers of encoder / decoder to freeze during training
- [x] add step for training on small speech -> text corpus and generating pseudo-labelled dataset + finetuning (thanks to @lucasnewman)
- [x] add final step of finetuning on text -> speech + pseudo-labelled dataset
- [x] figure out the best way to store and manage the pseudo-labelled generated dataset
- [x] batched beam search decoding
- [x] allow for using rotary positions in decoder + flash attention, give Tri another citation
- [x] integrate speculative decoding with some improvisation - done in same model using early exit strategy

- [ ] add cached key / values for starter + single / grouped key values, make sure flash attention can support specialized causal mask before flash attention 2 is in pytorch core
- [ ] polish the audio-text generation workflow
- [ ] concatting the real audio-text dataset with the generated one -> or being able to convert real audio-text dataset to generated

## Citations

```bibtex
@misc{kharitonov2023speak,
    title   = {Speak, Read and Prompt: High-Fidelity Text-to-Speech with Minimal Supervision},
    author  = {Eugene Kharitonov and Damien Vincent and Zalán Borsos and Raphaël Marinier and Sertan Girgin and Olivier Pietquin and Matt Sharifi and Marco Tagliasacchi and Neil Zeghidour},
    year    = {2023},
    eprint  = {2302.03540},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
```

```bibtex
@inproceedings{dao2022flashattention,
    title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year      = {2022}
}
```

```bibtex
@misc{shi2023enhance,
    title   = {Enhance audio generation controllability through representation similarity regularization},
    author  = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
    year    = {2023},
    eprint  = {2309.08773},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
```

```bibtex
@article{Ainslie2023GQATG,
    title   = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
    author  = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebr{\'o}n and Sumit K. Sanghai},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.13245},
    url     = {https://api.semanticscholar.org/CorpusID:258833177}
}
```

```bibtex
@inproceedings{Leviathan2022FastIF,
    title     = {Fast Inference from Transformers via Speculative Decoding},
    author    = {Yaniv Leviathan and Matan Kalman and Y. Matias},
    booktitle = {International Conference on Machine Learning},
    year      = {2022},
    url       = {https://api.semanticscholar.org/CorpusID:254096365}
}
```

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'spear-tts-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.4.8',
  license = 'MIT',
  description = 'Spear-TTS - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/spear-tts-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'text-to-speech'
  ],
  install_requires = [
    'audiolm-pytorch>=1.2.8',
    'beartype',
    'einops>=0.6.1',
    'rotary-embedding-torch>=0.3.0',
    'torch>=1.6',
    'tqdm',
    'x-clip>=0.12.2'
  ],
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

--------------------------------------------------------------------------------
/spear-tts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/spear-tts-pytorch/0e6a63807f3b64f0e41ddc76fe2676fb93231f0f/spear-tts.png

--------------------------------------------------------------------------------
/spear_tts_pytorch/__init__.py:
--------------------------------------------------------------------------------
from spear_tts_pytorch.spear_tts_pytorch import (
    TextToSemantic,
    SpeechSpeechPretrainWrapper,
    SemanticToTextWrapper,
    TextToSemanticWrapper,
    SemanticToTextDatasetGenerator
)

from spear_tts_pytorch.trainer import (
    SpeechSpeechPretrainer,
    SemanticToTextTrainer,
    TextToSemanticTrainer
)

from spear_tts_pytorch.data import (
    GeneratedAudioTextDataset,
    MockDataset
)
--------------------------------------------------------------------------------
/spear_tts_pytorch/attend.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange, repeat

# constants

Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, i, j, device):
        n = max(i, j)

        if exists(self.mask) and self.mask.shape[-1] >= n:
            mask = self.mask[:n, :n]
        else:
            mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
            self.register_buffer("mask", mask, persistent = False)

        return mask[-i:, :]

    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, causal, is_cuda, device = *q.shape, k.shape[-2], self.causal, q.is_cuda, q.device

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # if q and k lengths differ (caching of key/values), and causal, manually construct causal attn mask as float, as not supported (flash attn 2 will support this eventually)

        row_is_entirely_masked = None

        if causal and q_len != k_len:
            causal_mask = self.get_mask(q_len, k_len, device = device)

            if exists(mask):
                mask = mask & ~causal_mask
            else:
                mask = ~causal_mask

            row_is_entirely_masked = ~mask.any(dim = -1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked

            causal = False

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        if exists(row_is_entirely_masked):
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out

    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device = q.shape[-2], q.device
        heads, kv_heads = q.shape[1], k.shape[1]

        if kv_heads < heads:
            k, v = map(lambda t: repeat(t, 'b h ... -> b (g h) ...', g = heads // kv_heads), (k, v))

        scale = q.shape[-1] ** -0.5

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = self.get_mask(i, j, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out

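A quick sanity-check of `Attend` on its own, exercising the grouped-query path (fewer key/value heads than query heads, as `target_kv_heads` configures in the README example). Shapes follow the einstein notation documented in `forward`; the sizes below are arbitrary, for illustration only:

```python
import torch
from spear_tts_pytorch.attend import Attend

attend = Attend(causal = True, flash = False)

q = torch.randn(2, 8, 1024, 64)    # (batch, query heads, seq len, dim head)
k = torch.randn(2, 2, 1024, 64)    # 2 kv heads get repeated internally to match the 8 query heads
v = torch.randn(2, 2, 1024, 64)
mask = torch.ones(2, 1024).bool()  # key padding mask, (batch, seq len)

out = attend(q, k, v, mask = mask) # (2, 8, 1024, 64)
```
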
--------------------------------------------------------------------------------
/spear_tts_pytorch/data.py:
--------------------------------------------------------------------------------
from pathlib import Path

import torch
from torch.utils.data import Dataset

from beartype import beartype

# mock dataset

class MockDataset(Dataset):
    def __init__(self, length: int):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, ind):
        return torch.randn(1024)

# generated audio-text dataset

class GeneratedAudioTextDataset(Dataset):
    @beartype
    def __init__(
        self,
        folder: str,
        delimiter_id: int = -1
    ):
        self.folder = Path(folder)
        assert self.folder.exists() and self.folder.is_dir()
        self.paths = list(self.folder.glob('*.pt'))
        self.delimiter_id = delimiter_id

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, ind):
        path = self.paths[ind]
        tensor = torch.load(str(path))

        delimiter_mask = tensor == self.delimiter_id
        assert delimiter_mask.any(), f'delimeter (