├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── setup.py
├── spear-tts.png
└── spear_tts_pytorch
├── __init__.py
├── attend.py
├── data.py
├── distributed.py
├── spear_tts_pytorch.py
└── trainer.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <img src="./spear-tts.png" width="450px"></img>
2 | 
3 | ## Spear-TTS - Pytorch
4 |
5 | Implementation of <a href="https://arxiv.org/abs/2302.03540">Spear-TTS</a> - multi-speaker text-to-speech attention network, in Pytorch
6 |
7 | The text-to-semantic module built here will be used for conditioning in SoundStorm.
8 |
9 | ## Appreciation
10 |
11 | - Stability for their generous sponsorship to work on and open source cutting-edge artificial intelligence research
12 |
13 | - Lucas Newman for completing the backtranslation portion, as well as beam search decoding!
14 |
15 | - Lucas Newman for completing the final text to semantic transformer training code!
16 |
17 | ## Install
18 |
19 | ```bash
20 | $ pip install spear-tts-pytorch
21 | ```
22 |
23 | ## Usage
24 |
25 | ```python
26 | import torch
27 |
28 | from audiolm_pytorch import HubertWithKmeans
29 |
30 | from spear_tts_pytorch import (
31 | TextToSemantic,
32 | SemanticToTextDatasetGenerator,
33 | GeneratedAudioTextDataset,
34 | MockDataset
35 | )
36 |
37 | wav2vec = HubertWithKmeans(
38 | checkpoint_path = './hubert_base_ls960.pt',
39 | kmeans_path = './hubert_base_ls960_L9_km500.bin'
40 | )
41 |
42 | model = TextToSemantic(
43 | wav2vec = wav2vec,
44 | dim = 512,
45 | num_text_token_ids = 256,
46 | heads = 8,
47 | target_kv_heads = 2, # grouped query attention, for memory efficient decoding
48 | source_depth = 1,
49 | target_depth = 1
50 | )
51 |
52 | ds = MockDataset(10)
53 |
54 | dataset_generator = SemanticToTextDatasetGenerator(
55 | model = model,
56 | dataset = ds,
57 | folder = './output_folder'
58 | )
59 |
60 | dataset_generator(max_length = 2)
61 |
62 | generated_dataset = GeneratedAudioTextDataset(
63 | folder = './output_folder'
64 | )
65 |
66 | assert len(generated_dataset) == 10
67 | ```
68 |
69 | ## Todo
70 |
71 | - [x] add eos logic + generate, and hook up end-to-end generation in soundstorm
72 | - [x] add first speech-to-speech pretraining step, with reconstruction of 60% deleted tokens
73 | - [x] add dropouts for this project, as it is low-resource
74 | - [x] add total flexibility over which layers of the encoder / decoder to freeze during training
75 | - [x] add step for training on small speech -> text corpus and generating pseudo-labelled dataset + finetuning (thanks to @lucasnewman)
76 | - [x] add final step of finetuning on text -> speech + pseudolabelled dataset
77 | - [x] figure out the best way to store and manage the pseudo-labelled generated dataset
78 | - [x] batched beam search decoding
79 | - [x] allow for using rotary positions in decoder + flash attention, give Tri another citation
80 | - [x] integrate speculative decoding with some improvisation - done in the same model using an early exit strategy
81 |
82 | - [ ] add cached key / values for starter + single / grouped key values; make sure flash attention can support the specialized causal mask before flash attention 2 is in pytorch core
83 | - [ ] polish the audio-text generation workflow
84 | - [ ] concatenating the real audio-text dataset with the generated one -> or being able to convert the real audio-text dataset to the generated format
85 |
86 | ## Citations
87 |
88 | ```bibtex
89 | @misc{kharitonov2023speak,
90 | title = {Speak, Read and Prompt: High-Fidelity Text-to-Speech with Minimal Supervision},
91 | author = {Eugene Kharitonov and Damien Vincent and Zalán Borsos and Raphaël Marinier and Sertan Girgin and Olivier Pietquin and Matt Sharifi and Marco Tagliasacchi and Neil Zeghidour},
92 | year = {2023},
93 | eprint = {2302.03540},
94 | archivePrefix = {arXiv},
95 | primaryClass = {cs.SD}
96 | }
97 | ```
98 |
99 | ```bibtex
100 | @inproceedings{dao2022flashattention,
101 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
102 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
103 | booktitle = {Advances in Neural Information Processing Systems},
104 | year = {2022}
105 | }
106 | ```
107 |
108 | ```bibtex
109 | @misc{shi2023enhance,
110 | title = {Enhance audio generation controllability through representation similarity regularization},
111 | author = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
112 | year = {2023},
113 | eprint = {2309.08773},
114 | archivePrefix = {arXiv},
115 | primaryClass = {cs.SD}
116 | }
117 | ```
118 |
119 | ```bibtex
120 | @article{Ainslie2023GQATG,
121 | title = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
122 |   author  = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebr{\'o}n and Sumit K. Sanghai},
123 | journal = {ArXiv},
124 | year = {2023},
125 | volume = {abs/2305.13245},
126 | url = {https://api.semanticscholar.org/CorpusID:258833177}
127 | }
128 | ```
129 |
130 | ```bibtex
131 | @inproceedings{Leviathan2022FastIF,
132 | title = {Fast Inference from Transformers via Speculative Decoding},
133 |   author    = {Yaniv Leviathan and Matan Kalman and Yossi Matias},
134 | booktitle = {International Conference on Machine Learning},
135 | year = {2022},
136 | url = {https://api.semanticscholar.org/CorpusID:254096365}
137 | }
138 | ```
139 |
140 |
--------------------------------------------------------------------------------
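
A note on `target_kv_heads = 2` in the usage example above: it enables grouped query attention, where several query heads share a single key / value head, shrinking the key / value cache during decoding. Below is a minimal sketch of the head expansion it implies, mirroring the `repeat` call in `spear_tts_pytorch/attend.py` further down; the tensor shapes are illustrative assumptions.

```python
import torch
from einops import repeat

heads, kv_heads, seq, dim_head = 8, 2, 1024, 64

q = torch.randn(1, heads, seq, dim_head)     # 8 query heads
k = torch.randn(1, kv_heads, seq, dim_head)  # only 2 key heads kept in the cache
v = torch.randn(1, kv_heads, seq, dim_head)  # only 2 value heads kept in the cache

# each group of (heads // kv_heads) query heads shares one key / value head,
# so the kv cache is 4x smaller than with full multi-head attention
k, v = map(lambda t: repeat(t, 'b h n d -> b (g h) n d', g = heads // kv_heads), (k, v))

assert k.shape[1] == heads and v.shape[1] == heads
```
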
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'spear-tts-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.4.8',
7 | license='MIT',
8 | description = 'Spear-TTS - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/spear-tts-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'text-to-speech'
19 | ],
20 | install_requires=[
21 | 'audiolm-pytorch>=1.2.8',
22 | 'beartype',
23 | 'einops>=0.6.1',
24 | 'rotary-embedding-torch>=0.3.0',
25 | 'torch>=1.6',
26 | 'tqdm',
27 | 'x-clip>=0.12.2'
28 | ],
29 | classifiers=[
30 | 'Development Status :: 4 - Beta',
31 | 'Intended Audience :: Developers',
32 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
33 | 'License :: OSI Approved :: MIT License',
34 | 'Programming Language :: Python :: 3.6',
35 | ],
36 | )
37 |
--------------------------------------------------------------------------------
/spear-tts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/spear-tts-pytorch/0e6a63807f3b64f0e41ddc76fe2676fb93231f0f/spear-tts.png
--------------------------------------------------------------------------------
/spear_tts_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from spear_tts_pytorch.spear_tts_pytorch import (
2 | TextToSemantic,
3 | SpeechSpeechPretrainWrapper,
4 | SemanticToTextWrapper,
5 | TextToSemanticWrapper,
6 | SemanticToTextDatasetGenerator
7 | )
8 |
9 | from spear_tts_pytorch.trainer import (
10 | SpeechSpeechPretrainer,
11 | SemanticToTextTrainer,
12 | TextToSemanticTrainer
13 | )
14 |
15 | from spear_tts_pytorch.data import (
16 | GeneratedAudioTextDataset,
17 | MockDataset
18 | )
--------------------------------------------------------------------------------
/spear_tts_pytorch/attend.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 |
5 | from collections import namedtuple
6 | from functools import wraps
7 | from packaging import version
8 |
9 | from einops import rearrange, repeat
10 |
11 | # constants
12 |
13 | Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
14 |
15 | # helpers
16 |
17 | def exists(val):
18 | return val is not None
19 |
20 | def once(fn):
21 | called = False
22 | @wraps(fn)
23 | def inner(x):
24 | nonlocal called
25 | if called:
26 | return
27 | called = True
28 | return fn(x)
29 | return inner
30 |
31 | print_once = once(print)
32 |
33 | # main class
34 |
35 | class Attend(nn.Module):
36 | def __init__(
37 | self,
38 | dropout = 0.,
39 | causal = False,
40 | flash = False
41 | ):
42 | super().__init__()
43 | self.dropout = dropout
44 | self.attn_dropout = nn.Dropout(dropout)
45 |
46 | self.causal = causal
47 | self.register_buffer("mask", None, persistent=False)
48 |
49 | self.flash = flash
50 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
51 |
52 | # determine efficient attention configs for cuda and cpu
53 |
54 | self.cpu_config = Config(True, True, True)
55 | self.cuda_config = None
56 |
57 | if not torch.cuda.is_available() or not flash:
58 | return
59 |
60 | device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
61 |
62 | if device_properties.major == 8 and device_properties.minor == 0:
63 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
64 | self.cuda_config = Config(True, False, False)
65 | else:
66 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
67 | self.cuda_config = Config(False, True, True)
68 |
69 | def get_mask(self, i, j, device):
70 | n = max(i, j)
71 |
72 | if exists(self.mask) and self.mask.shape[-1] >= n:
73 | mask = self.mask[:n, :n]
74 | else:
75 | mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
76 | self.register_buffer("mask", mask, persistent = False)
77 |
78 |         return mask[-i:, :j]
79 |
80 | def flash_attn(self, q, k, v, mask = None):
81 | _, heads, q_len, _, k_len, causal, is_cuda, device = *q.shape, k.shape[-2], self.causal, q.is_cuda, q.device
82 |
83 | # Check if mask exists and expand to compatible shape
84 | # The mask is B L, so it would have to be expanded to B H N L
85 |
86 | if exists(mask):
87 | mask = rearrange(mask, 'b j -> b 1 1 j')
88 | mask = mask.expand(-1, heads, q_len, -1)
89 |
90 | # Check if there is a compatible device for flash attention
91 |
92 | config = self.cuda_config if is_cuda else self.cpu_config
93 |
94 |         # if q and k lengths differ (caching of key / values), and causal, manually fold the causal mask into the boolean attn mask, since is_causal is not supported for this case (flash attn 2 will support it eventually)
95 |
96 | row_is_entirely_masked = None
97 |
98 | if causal and q_len != k_len:
99 | causal_mask = self.get_mask(q_len, k_len, device = device)
100 |
101 | if exists(mask):
102 | mask = mask & ~causal_mask
103 | else:
104 | mask = ~causal_mask
105 |
106 | row_is_entirely_masked = ~mask.any(dim = -1)
107 | mask[..., 0] = mask[..., 0] | row_is_entirely_masked
108 |
109 | causal = False
110 |
111 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
112 |
113 | with torch.backends.cuda.sdp_kernel(**config._asdict()):
114 | out = F.scaled_dot_product_attention(
115 | q, k, v,
116 | attn_mask = mask,
117 | dropout_p = self.dropout if self.training else 0.,
118 | is_causal = causal
119 | )
120 |
121 | if exists(row_is_entirely_masked):
122 | out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
123 |
124 | return out
125 |
126 | def forward(self, q, k, v, mask = None):
127 | """
128 | einstein notation
129 | b - batch
130 | h - heads
131 | n, i, j - sequence length (base sequence length, source, target)
132 | d - feature dimension
133 | """
134 |
135 | n, device = q.shape[-2], q.device
136 | heads, kv_heads = q.shape[1], k.shape[1]
137 |
138 | if kv_heads < heads:
139 | k, v = map(lambda t: repeat(t, 'b h ... -> b (g h) ...', g = heads // kv_heads), (k, v))
140 |
141 | scale = q.shape[-1] ** -0.5
142 |
143 | if self.flash:
144 | return self.flash_attn(q, k, v, mask = mask)
145 |
146 | # similarity
147 |
148 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
149 |
150 | # key padding mask
151 |
152 | if exists(mask):
153 | mask = rearrange(mask, 'b j -> b 1 1 j')
154 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
155 |
156 | # causal mask
157 |
158 | if self.causal:
159 | i, j = sim.shape[-2:]
160 | causal_mask = self.get_mask(i, j, device)
161 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
162 |
163 | # attention
164 |
165 | attn = sim.softmax(dim = -1)
166 | attn = self.attn_dropout(attn)
167 |
168 | # aggregate values
169 |
170 | out = einsum("b h i j, b h j d -> b h i d", attn, v)
171 |
172 | return out
173 |
--------------------------------------------------------------------------------
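
A minimal sketch of exercising the `Attend` module above on its own, with a causal mask, a key padding mask, and grouped key / value heads; the tensor shapes are illustrative assumptions, and `flash = False` keeps it runnable without pytorch 2.0.

```python
import torch
from spear_tts_pytorch.attend import Attend

attend = Attend(causal = True, dropout = 0.1, flash = False)

q = torch.randn(2, 8, 16, 64)    # (batch, query heads, seq len, dim head)
k = torch.randn(2, 2, 16, 64)    # 2 kv heads, expanded to 8 inside forward (grouped query attention)
v = torch.randn(2, 2, 16, 64)

mask = torch.ones(2, 16).bool()  # key padding mask - (batch, seq len), True for valid positions

out = attend(q, k, v, mask = mask)
assert out.shape == (2, 8, 16, 64)
```
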
/spear_tts_pytorch/data.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 | from beartype import beartype
7 |
8 | # mock dataset
9 |
10 | class MockDataset(Dataset):
11 | def __init__(self, length: int):
12 | self.length = length
13 |
14 | def __len__(self):
15 | return self.length
16 |
17 | def __getitem__(self, ind):
18 | return torch.randn(1024)
19 |
20 | # generated audio-text dataset
21 |
22 | class GeneratedAudioTextDataset(Dataset):
23 | @beartype
24 | def __init__(
25 | self,
26 | folder: str,
27 | delimiter_id: int = -1
28 | ):
29 | self.folder = Path(folder)
30 | assert self.folder.exists() and self.folder.is_dir()
31 | self.paths = list(self.folder.glob('*.pt'))
32 | self.delimiter_id = delimiter_id
33 |
34 | def __len__(self):
35 | return len(self.paths)
36 |
37 | def __getitem__(self, ind):
38 | path = self.paths[ind]
39 | tensor = torch.load(str(path))
40 |
41 | delimiter_mask = tensor == self.delimiter_id
42 |         assert delimiter_mask.any(), f'delimiter ({self.delimiter_id}) not found in {str(path)}'
43 | 
44 |         # everything before the first delimiter is the semantic token ids, everything after is the text token ids
45 |         ind = (delimiter_mask.cumsum(dim = -1) == 0).sum().item()
46 | 
47 |         return tensor[:ind], tensor[(ind + 1):]
--------------------------------------------------------------------------------
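
For reference, a minimal sketch of writing a sample in the on-disk format `GeneratedAudioTextDataset` reads back, assuming the reconstructed `__getitem__` above: one flat tensor per `.pt` file, with the semantic token ids and the text token ids separated by a single delimiter id. The folder name, vocabulary sizes, and sequence lengths below are made-up values for illustration.

```python
import torch
from pathlib import Path

from spear_tts_pytorch import GeneratedAudioTextDataset

folder = Path('./example_folder')
folder.mkdir(exist_ok = True)

semantic_ids = torch.randint(0, 500, (120,))  # e.g. hubert k-means cluster ids
text_ids = torch.randint(0, 256, (40,))       # e.g. byte-level text ids
delimiter = torch.tensor([-1])                # matches the default delimiter_id

torch.save(torch.cat((semantic_ids, delimiter, text_ids)), str(folder / '0.pt'))

ds = GeneratedAudioTextDataset(folder = './example_folder')

semantic, text = ds[0]
assert semantic.shape == (120,) and text.shape == (40,)
```
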