├── audiolm_pytorch
│   ├── version.py
│   ├── utils.py
│   ├── __init__.py
│   ├── optimizer.py
│   ├── vq_wav2vec.py
│   ├── t5.py
│   ├── hubert_kmeans.py
│   ├── data.py
│   ├── encodec.py
│   ├── soundstream.py
│   └── trainer.py
├── audiolm.png
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── setup.py
├── .gitignore
├── README.md
└── audiolm_pytorch_demo.ipynb
/audiolm_pytorch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.1.4'
2 |
--------------------------------------------------------------------------------
/audiolm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/audiolm-pytorch/main/audiolm.png
--------------------------------------------------------------------------------
/audiolm_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | # functions
4 |
5 | def round_down_nearest_multiple(num, divisor):
6 | return num // divisor * divisor
7 |
8 | def curtail_to_multiple(t, mult, from_left = False):
9 | data_len = t.shape[-1]
10 | rounded_seq_len = round_down_nearest_multiple(data_len, mult)
11 | seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None)
12 | return t[..., seq_slice]
13 |
14 | # base class
15 |
16 | class AudioConditionerBase(nn.Module):
17 | pass
18 |
--------------------------------------------------------------------------------
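
A minimal sketch of how `curtail_to_multiple` above behaves, assuming an arbitrary `(batch, samples)` tensor and a multiple of 320:

```python
import torch
from audiolm_pytorch.utils import curtail_to_multiple

t = torch.randn(2, 1003)                                     # (batch, samples)

# trims trailing samples so the length becomes a multiple of 320
print(curtail_to_multiple(t, 320).shape)                     # torch.Size([2, 960])

# from_left = True keeps the tail of the sequence instead of the head
print(curtail_to_multiple(t, 320, from_left = True).shape)   # torch.Size([2, 960])
```
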
/audiolm_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from packaging import version
3 |
4 | if version.parse(torch.__version__) >= version.parse('2.0.0'):
5 | from einops._torch_specific import allow_ops_in_compiled_graph
6 | allow_ops_in_compiled_graph()
7 |
8 | from audiolm_pytorch.audiolm_pytorch import AudioLM
9 | from audiolm_pytorch.soundstream import SoundStream, AudioLMSoundStream, MusicLMSoundStream
10 | from audiolm_pytorch.encodec import EncodecWrapper
11 |
12 | from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer
13 | from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper, SemanticTransformerWrapper
14 |
15 | from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
16 | from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
17 |
18 | from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer, CoarseTransformerTrainer
19 |
20 | from audiolm_pytorch.audiolm_pytorch import get_embeds
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/audiolm_pytorch/optimizer.py:
--------------------------------------------------------------------------------
1 | from lion_pytorch import Lion
2 | from torch.optim import AdamW, Adam
3 |
4 | def separate_weight_decayable_params(params):
5 | wd_params, no_wd_params = [], []
6 | for param in params:
7 | param_list = no_wd_params if param.ndim < 2 else wd_params
8 | param_list.append(param)
9 | return wd_params, no_wd_params
10 |
11 | def get_optimizer(
12 | params,
13 | lr = 1e-4,
14 | wd = 1e-2,
15 | betas = (0.9, 0.99),
16 | eps = 1e-8,
17 | filter_by_requires_grad = False,
18 | group_wd_params = True,
19 | use_lion = False,
20 | **kwargs
21 | ):
22 | has_wd = wd > 0
23 |
24 | if filter_by_requires_grad:
25 | params = list(filter(lambda t: t.requires_grad, params))
26 |
27 | if group_wd_params and has_wd:
28 | wd_params, no_wd_params = separate_weight_decayable_params(params)
29 |
30 | params = [
31 | {'params': wd_params},
32 | {'params': no_wd_params, 'weight_decay': 0},
33 | ]
34 |
35 | if use_lion:
36 | return Lion(params, lr = lr, betas = betas, weight_decay = wd)
37 |
38 | if not has_wd:
39 | return Adam(params, lr = lr, betas = betas, eps = eps)
40 |
41 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
42 |
--------------------------------------------------------------------------------
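
A brief sketch of `get_optimizer` above, assuming a toy `nn.Linear` just to show how parameters with fewer than two dimensions (e.g. biases) are routed into a zero-weight-decay group:

```python
from torch import nn
from audiolm_pytorch.optimizer import get_optimizer

model = nn.Linear(16, 4)   # 2-d weight gets weight decay, 1-d bias does not

opt = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)

print([len(group['params']) for group in opt.param_groups])    # [1, 1]
print([group['weight_decay'] for group in opt.param_groups])   # [0.01, 0]
```
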
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | exec(open('audiolm_pytorch/version.py').read())
3 |
4 | setup(
5 | name = 'audiolm-pytorch',
6 | packages = find_packages(exclude=[]),
7 | version = __version__,
8 | license='MIT',
9 | description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
10 | author = 'Phil Wang',
11 | author_email = 'lucidrains@gmail.com',
12 | long_description_content_type = 'text/markdown',
13 | url = 'https://github.com/lucidrains/audiolm-pytorch',
14 | keywords = [
15 | 'artificial intelligence',
16 | 'deep learning',
17 | 'transformers',
18 | 'attention mechanism',
19 | 'audio generation'
20 | ],
21 | install_requires=[
22 | 'accelerate',
23 | 'beartype',
24 | 'einops>=0.6.1',
25 | 'ema-pytorch>=0.2.2',
26 | 'encodec',
27 | 'fairseq',
28 | 'joblib',
29 | 'lion-pytorch',
30 | 'local-attention>=1.8.4',
31 | 'scikit-learn',
32 | 'sentencepiece',
33 | 'torch>=1.12',
34 | 'torchaudio',
35 | 'transformers',
36 | 'tqdm',
37 | 'vector-quantize-pytorch>=1.5.14'
38 | ],
39 | classifiers=[
40 | 'Development Status :: 4 - Beta',
41 | 'Intended Audience :: Developers',
42 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
43 | 'License :: OSI Approved :: MIT License',
44 | 'Programming Language :: Python :: 3.6',
45 | ],
46 | )
47 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Pycharm
132 | .idea/
133 |
--------------------------------------------------------------------------------
/audiolm_pytorch/vq_wav2vec.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from torch import nn
5 | from einops import rearrange
6 |
7 | import fairseq
8 |
9 | from torchaudio.functional import resample
10 |
11 | from audiolm_pytorch.utils import curtail_to_multiple
12 |
13 | import logging
14 | logging.root.setLevel(logging.ERROR)
15 |
16 | def exists(val):
17 | return val is not None
18 |
19 | class FairseqVQWav2Vec(nn.Module):
20 | """
21 | checkpoint path can be found at https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/README.md#vq-wav2vec
22 | specifically download the kmeans model for now
23 |
24 | $ wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt
25 | """
26 |
27 | def __init__(
28 | self,
29 | checkpoint_path,
30 | target_sample_hz = 24000,
31 | seq_len_multiple_of = None
32 | ):
33 | super().__init__()
34 | self.target_sample_hz = target_sample_hz
35 | self.seq_len_multiple_of = seq_len_multiple_of
36 |
37 | path = Path(checkpoint_path)
38 | assert path.exists(), f'path {checkpoint_path} does not exist'
39 |
40 | checkpoint = torch.load(checkpoint_path)
41 | load_model_input = {checkpoint_path: checkpoint}
42 | model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
43 |
44 | self.model = model[0]
45 | self.model.eval()
46 |
47 | assert hasattr(self.model, 'vector_quantizer') and hasattr(self.model.vector_quantizer, 'embedding'), 'the vq wav2vec model does not seem to be valid'
48 |
49 | @property
50 | def groups(self):
51 | return self.model.vector_quantizer.groups
52 |
53 | @property
54 | def codebook_size(self):
55 | return self.model.vector_quantizer.embedding.shape[0]
56 |
57 | @torch.no_grad()
58 | def forward(
59 | self,
60 | wav_input,
61 | flatten = True,
62 | input_sample_hz = None
63 | ):
64 | if exists(input_sample_hz):
65 | wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)
66 |
67 | if exists(self.seq_len_multiple_of):
68 | wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
69 |
70 | embed = self.model.feature_extractor(wav_input)
71 | _, codebook_indices = self.model.vector_quantizer.forward_idx(embed)
72 |
73 | if not flatten:
74 | return codebook_indices
75 |
76 | return rearrange(codebook_indices, 'b ... -> b (...)')
77 |
--------------------------------------------------------------------------------
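
A hedged usage sketch for `FairseqVQWav2Vec` above, assuming the kmeans checkpoint from the docstring has already been downloaded into the working directory:

```python
import torch
from audiolm_pytorch import FairseqVQWav2Vec

wav2vec = FairseqVQWav2Vec(
    checkpoint_path = './vq-wav2vec_kmeans.pt'
)

wav = torch.randn(1, 24000)                    # one second at the 24kHz target sample rate

ids = wav2vec(wav)                             # flattened codebook indices, shape (batch, frames * groups)
ids_per_group = wav2vec(wav, flatten = False)  # indices kept per quantizer group
```
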
/audiolm_pytorch/t5.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 | from transformers import T5Tokenizer, T5EncoderModel, T5Config
4 |
5 | from beartype import beartype
6 | from beartype.typing import Union, List
7 |
8 | # less warning messages since only using encoder
9 |
10 | transformers.logging.set_verbosity_error()
11 |
12 | # helper functions
13 |
14 | def exists(val):
15 | return val is not None
16 |
17 | # config
18 |
19 | MAX_LENGTH = 256
20 |
21 | DEFAULT_T5_NAME = 'google/t5-v1_1-base'
22 |
23 | T5_CONFIGS = {}
24 |
25 | # singleton globals
26 |
27 | def get_tokenizer(name):
28 | tokenizer = T5Tokenizer.from_pretrained(name)
29 | return tokenizer
30 |
31 | def get_model(name):
32 | model = T5EncoderModel.from_pretrained(name)
33 | return model
34 |
35 | def get_model_and_tokenizer(name):
36 | global T5_CONFIGS
37 |
38 | if name not in T5_CONFIGS:
39 | T5_CONFIGS[name] = dict()
40 |
41 | if "model" not in T5_CONFIGS[name]:
42 | T5_CONFIGS[name]["model"] = get_model(name)
43 |
44 | if "tokenizer" not in T5_CONFIGS[name]:
45 | T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
46 |
47 | return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
48 |
49 | def get_encoded_dim(name):
50 | if name not in T5_CONFIGS:
51 | config = T5Config.from_pretrained(name)
52 | T5_CONFIGS[name] = dict(config = config)
53 |
54 | elif "config" in T5_CONFIGS[name]:
55 | config = T5_CONFIGS[name]["config"]
56 |
57 | elif "model" in T5_CONFIGS[name]:
58 | config = T5_CONFIGS[name]["model"].config
59 |
60 | else:
61 | raise ValueError(f'unknown t5 name {name}')
62 |
63 | return config.d_model
64 |
65 | # encoding text
66 |
67 | @beartype
68 | def t5_encode_text(
69 | texts: Union[str, List[str]],
70 | name = DEFAULT_T5_NAME,
71 | output_device = None
72 | ):
73 | if isinstance(texts, str):
74 | texts = [texts]
75 |
76 | t5, tokenizer = get_model_and_tokenizer(name)
77 |
78 | if torch.cuda.is_available():
79 | t5 = t5.cuda()
80 |
81 | device = next(t5.parameters()).device
82 |
83 | encoded = tokenizer.batch_encode_plus(
84 | texts,
85 | return_tensors = 'pt',
86 | padding = 'longest',
87 | max_length = MAX_LENGTH,
88 | truncation = True
89 | )
90 |
91 | input_ids = encoded.input_ids.to(device)
92 | attn_mask = encoded.attention_mask.to(device)
93 |
94 | t5.eval()
95 |
96 | with torch.no_grad():
97 | output = t5(input_ids = input_ids, attention_mask = attn_mask)
98 | encoded_text = output.last_hidden_state.detach()
99 |
100 | attn_mask = attn_mask[..., None].bool()
101 |
102 | if not exists(output_device):
103 | encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
104 | return encoded_text
105 |
106 | encoded_text.to(output_device)
107 | attn_mask.to(output_device)
108 |
109 | encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
110 | return encoded_text
111 |
--------------------------------------------------------------------------------
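
A small sketch of the text-encoding helpers above, assuming the default `google/t5-v1_1-base` weights can be fetched from the Hugging Face hub:

```python
from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

print(get_encoded_dim(DEFAULT_T5_NAME))   # 768 for t5-v1_1-base

embeds = t5_encode_text([
    'chirping of birds',
    'the distant echos of bells'
])

# (batch, padded sequence length, d_model) - padded positions are zeroed out
print(embeds.shape)
```
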
/audiolm_pytorch/hubert_kmeans.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from torch import nn
5 | from einops import rearrange, pack, unpack
6 |
7 | import joblib
8 |
9 | import fairseq
10 |
11 | from torchaudio.functional import resample
12 |
13 | from audiolm_pytorch.utils import curtail_to_multiple
14 |
15 | import logging
16 | logging.root.setLevel(logging.ERROR)
17 |
18 | def exists(val):
19 | return val is not None
20 |
21 | def default(val, d):
22 | return val if exists(val) else d
23 |
24 | class HubertWithKmeans(nn.Module):
25 | """
26 | checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
27 | or you can train your own
28 | """
29 |
30 | def __init__(
31 | self,
32 | checkpoint_path,
33 | kmeans_path,
34 | target_sample_hz = 16000,
35 | seq_len_multiple_of = None,
36 | output_layer = 9
37 | ):
38 | super().__init__()
39 | self.target_sample_hz = target_sample_hz
40 | self.seq_len_multiple_of = seq_len_multiple_of
41 | self.output_layer = output_layer
42 |
43 | model_path = Path(checkpoint_path)
44 | kmeans_path = Path(kmeans_path)
45 |
46 | assert model_path.exists(), f'path {checkpoint_path} does not exist'
47 | assert kmeans_path.exists(), f'path {kmeans_path} does not exist'
48 |
49 | checkpoint = torch.load(checkpoint_path)
50 | load_model_input = {checkpoint_path: checkpoint}
51 | model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
52 |
53 | self.model = model[0]
54 | self.model.eval()
55 |
56 | kmeans = joblib.load(kmeans_path)
57 | self.kmeans = kmeans
58 |
59 | @property
60 | def groups(self):
61 | return 1
62 |
63 | @property
64 | def codebook_size(self):
65 | return self.kmeans.n_clusters
66 |
67 | @torch.no_grad()
68 | def forward(
69 | self,
70 | wav_input,
71 | flatten = True,
72 | input_sample_hz = None
73 | ):
74 | device = wav_input.device
75 |
76 | if exists(input_sample_hz):
77 | wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)
78 |
79 | if exists(self.seq_len_multiple_of):
80 | wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
81 |
82 | embed = self.model(
83 | wav_input,
84 | features_only = True,
85 | mask = False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
86 | output_layer = self.output_layer
87 | )
88 |
89 | embed, packed_shape = pack([embed['x']], '* d')
90 |
91 | codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
92 |
93 | codebook_indices = torch.from_numpy(codebook_indices).to(device).long()
94 |
95 | if flatten:
96 | return codebook_indices
97 |
98 | codebook_indices, = unpack(codebook_indices, packed_shape, '*')
99 | return codebook_indices
100 |
--------------------------------------------------------------------------------
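
A short sketch of calling `HubertWithKmeans` above directly, assuming the same HuBERT checkpoint and kmeans files used in the README examples have been downloaded:

```python
import torch
from audiolm_pytorch import HubertWithKmeans

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

wav = torch.randn(1, 16000)          # one second at the 16kHz target sample rate

semantic_token_ids = wav2vec(wav)    # ids drawn from the kmeans codebook (wav2vec.codebook_size clusters)
```
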
/audiolm_pytorch/data.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from functools import partial, wraps
3 |
4 | from beartype import beartype
5 | from beartype.typing import Tuple, Union, Optional
6 | from beartype.door import is_bearable
7 |
8 | import torchaudio
9 | from torchaudio.functional import resample
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | from torch.utils.data import Dataset, DataLoader
15 |
16 | from audiolm_pytorch.utils import curtail_to_multiple
17 |
18 | from einops import rearrange
19 |
20 | # helper functions
21 |
22 | def exists(val):
23 | return val is not None
24 |
25 | def cast_tuple(val, length = 1):
26 | return val if isinstance(val, tuple) else ((val,) * length)
27 |
28 | # type
29 |
30 | OptionalIntOrTupleInt = Optional[Union[int, Tuple[Optional[int], ...]]]
31 |
32 | # dataset functions
33 |
34 | class SoundDataset(Dataset):
35 | @beartype
36 | def __init__(
37 | self,
38 | folder,
39 | exts = ['flac', 'wav', 'mp3', 'webm'],
40 | max_length: OptionalIntOrTupleInt = None,
41 | target_sample_hz: OptionalIntOrTupleInt = None,
42 | seq_len_multiple_of: OptionalIntOrTupleInt = None
43 | ):
44 | super().__init__()
45 | path = Path(folder)
46 | assert path.exists(), 'folder does not exist'
47 |
48 | files = [file for ext in exts for file in path.glob(f'**/*.{ext}')]
49 | assert len(files) > 0, 'no sound files found'
50 |
51 | self.files = files
52 |
53 | self.target_sample_hz = cast_tuple(target_sample_hz)
54 | num_outputs = len(self.target_sample_hz)
55 |
56 | self.max_length = cast_tuple(max_length, num_outputs)
57 | self.seq_len_multiple_of = cast_tuple(seq_len_multiple_of, num_outputs)
58 |
59 | assert len(self.max_length) == len(self.target_sample_hz) == len(self.seq_len_multiple_of)
60 |
61 | def __len__(self):
62 | return len(self.files)
63 |
64 | def __getitem__(self, idx):
65 | file = self.files[idx]
66 |
67 | data, sample_hz = torchaudio.load(file)
68 |
69 | assert data.numel() > 0, f'one of your audio file ({file}) is empty. please remove it from your folder'
70 |
71 | if data.shape[0] > 1:
72 | # the audio has more than 1 channel, convert to mono
73 | data = torch.mean(data, dim=0).unsqueeze(0)
74 |
75 | num_outputs = len(self.target_sample_hz)
76 | data = cast_tuple(data, num_outputs)
77 |
78 | # resample if target_sample_hz is not None in the tuple
79 |
80 | data_tuple = tuple((resample(d, sample_hz, target_sample_hz) if exists(target_sample_hz) else d) for d, target_sample_hz in zip(data, self.target_sample_hz))
81 |
82 | output = []
83 |
84 | # process each of the data resample at different frequencies individually
85 |
86 | for data, max_length, seq_len_multiple_of in zip(data_tuple, self.max_length, self.seq_len_multiple_of):
87 | audio_length = data.size(1)
88 |
89 | # pad or curtail
90 |
91 | if audio_length > max_length:
92 | max_start = audio_length - max_length
93 | start = torch.randint(0, max_start, (1, ))
94 | data = data[:, start:start + max_length]
95 |
96 | else:
97 | data = F.pad(data, (0, max_length - audio_length), 'constant')
98 |
99 | data = rearrange(data, '1 ... -> ...')
100 |
101 | if exists(max_length):
102 | data = data[:max_length]
103 |
104 | if exists(seq_len_multiple_of):
105 | data = curtail_to_multiple(data, seq_len_multiple_of)
106 |
107 | output.append(data.float())
108 |
109 | # cast from list to tuple
110 |
111 | output = tuple(output)
112 |
113 | # return only one audio, if only one target resample freq
114 |
115 | if num_outputs == 1:
116 | return output[0]
117 |
118 | return output
119 |
120 | # dataloader functions
121 |
122 | def collate_one_or_multiple_tensors(fn):
123 | @wraps(fn)
124 | def inner(data):
125 | is_one_data = not isinstance(data[0], tuple)
126 |
127 | if is_one_data:
128 | data = torch.stack(data)
129 | return (data,)
130 |
131 | outputs = []
132 | for datum in zip(*data):
133 | if is_bearable(datum, Tuple[str, ...]):
134 | output = list(datum)
135 | else:
136 | output = fn(datum)
137 |
138 | outputs.append(output)
139 |
140 | return tuple(outputs)
141 |
142 | return inner
143 |
144 | @collate_one_or_multiple_tensors
145 | def curtail_to_shortest_collate(data):
146 | min_len = min(*[datum.shape[0] for datum in data])
147 | data = [datum[:min_len] for datum in data]
148 | return torch.stack(data)
149 |
150 | @collate_one_or_multiple_tensors
151 | def pad_to_longest_fn(data):
152 | return pad_sequence(data, batch_first = True)
153 |
154 | def get_dataloader(ds, pad_to_longest = True, **kwargs):
155 | collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate
156 | return DataLoader(ds, collate_fn = collate_fn, **kwargs)
157 |
--------------------------------------------------------------------------------
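
A minimal sketch of `SoundDataset` and `get_dataloader` above, assuming a hypothetical audio folder and two target sample rates (e.g. a codec rate and a HuBERT rate), which yields a tuple of differently-resampled waves per example:

```python
from audiolm_pytorch.data import SoundDataset, get_dataloader

dataset = SoundDataset(
    '/path/to/audio/files',
    max_length = (24000 * 2, 16000 * 2),      # 2 seconds at each rate
    target_sample_hz = (24000, 16000),
    seq_len_multiple_of = (320, None)
)

dl = get_dataloader(dataset, batch_size = 4, shuffle = True)

wave_for_codec, wave_for_wav2vec = next(iter(dl))   # each: (batch, samples)
```
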
/audiolm_pytorch/encodec.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | from einops import rearrange, pack, unpack
3 |
4 | import torch
5 | from torch import nn
6 |
7 | from vector_quantize_pytorch import ResidualVQ
8 |
9 | from encodec import EncodecModel
10 | from encodec.utils import _linear_overlap_add
11 |
12 | class EncodecWrapper(nn.Module):
13 | """
14 | Support pretrained 24kHz Encodec by Meta AI, if you want to skip training SoundStream.
15 |
16 | TODO:
17 | - see if we need to keep the scaled version and somehow persist the scale factors for when we need to decode? Right
18 | now I'm just setting self.model.normalize = False to sidestep all of that
19 | - see if we can use the 48kHz model, which is specifically for music. Right now we're using the 24kHz model because
20 | that's what was used in MusicLM and avoids any resampling issues.
21 | -
22 |
23 | """
24 | def __init__(
25 | self,
26 | target_sample_hz = 24000,
27 | strides = (2, 4, 5, 8),
28 | num_quantizers = 8,
29 | ):
30 | super().__init__()
31 | # Instantiate a pretrained EnCodec model
32 | self.model = EncodecModel.encodec_model_24khz()
33 | self.model.normalize = False # this means we don't need to scale codes e.g. when running model.encode(wav)
34 |
35 | # bandwidth affects num quantizers used: https://github.com/facebookresearch/encodec/pull/41
36 | self.model.set_target_bandwidth(6.0)
37 | assert num_quantizers == 8, "assuming 8 quantizers for now, see bandwidth comment above"
38 |
39 | # Fields that SoundStream has that get used externally. We replicate them here.
40 | self.target_sample_hz = target_sample_hz
41 | assert self.target_sample_hz == 24000, "haven't done anything with non-24kHz yet"
42 |
43 | self.codebook_dim = 128
44 | self.rq_groups = 1
45 | self.num_quantizers = num_quantizers
46 | self.strides = strides # used in seq_len_multiple_of
47 |
48 | # cross entropy loss to indices passed in on l2 distance logits introduced in vector-quantize-pytorch 1.2.2
49 |
50 | self.rq = ResidualVQ(
51 | dim = 128,
52 | codebook_size = 1024,
53 | num_quantizers = 8
54 | )
55 |
56 | # copy codebook over to ResidualVQ for cross entropy loss logic from naturalspeech2
57 | # luckily, it seems Meta AI basically used my ResidualVQ code verbatim. makes porting it over easy
58 |
59 | for encodec_rq_layer, rq_layer in zip(self.model.quantizer.vq.layers, self.rq.layers):
60 | encodec_codebook = dict(encodec_rq_layer._codebook.named_buffers()).get('embed')
61 | vq_codebook = dict(rq_layer._codebook.named_buffers()).get('embed')
62 |
63 | encodec_codebook = rearrange(encodec_codebook, '... -> 1 ...')
64 | vq_codebook.copy_(encodec_codebook)
65 |
66 | @property
67 | def seq_len_multiple_of(self):
68 | return reduce(lambda x, y: x * y, self.strides)
69 |
70 | def forward(
71 | self,
72 | x,
73 | return_encoded = False,
74 | **kwargs
75 | ):
76 |
77 | x, ps = pack([x], '* n')
78 |
79 | # kwargs for stuff like return_encoded=True, which SoundStream uses but Encodec doesn't
80 | assert not self.model.training, "Encodec is pretrained and should never be called outside eval mode."
81 | # Unlike in the Encodec sample code in its README, x has already been resampled so we don't need to call
82 | # convert_audio and unsqueeze. The convert_audio function also doesn't play nicely with batches.
83 |
84 | # b = batch, t = timesteps, 1 channel for the 24kHz model, 2 channels for the 48kHz model
85 | wav = rearrange(x, f'b t -> b {self.model.channels} t')
86 |
87 | # Extract discrete codes from EnCodec
88 | with torch.no_grad():
89 | encoded_frames = self.model.encode(wav)
90 | # encoded_frames is a list of (frame, scale) tuples. Scale is a scalar but we don't use it. Frame is a tensor
91 | # of shape [batch, num_quantizers, num_samples_per_frame]. We want to concatenate the frames to get all the
92 | # timesteps concatenated.
93 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [batch, num_quantizers, timesteps]
94 | # transformer code that uses codec expects codes to be [batch, timesteps, num_quantizers]
95 | codes = rearrange(codes, 'b q n -> b n q') # result: [batch, timesteps, num_quantizers]
96 | # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that.
97 |
98 | # allow for returning of sum of quantized embeddings
99 |
100 | emb = None
101 |
102 | if return_encoded:
103 | emb = self.get_emb_from_indices(codes)
104 | emb, = unpack(emb, ps, '* n c')
105 |
106 | codes, = unpack(codes, ps, '* n q')
107 |
108 | return emb, codes, None
109 |
110 | def decode_from_codebook_indices(self, quantized_indices):
111 | # Input: batch x num tokens x num quantizers
112 | # Output: batch x 1 x num samples
113 |
114 | assert self.model.sample_rate == 24000,\
115 | "if changing to 48kHz, that model segments its audio into lengths of 1.0 second with 1% overlap, whereas " \
116 | "the 24kHz doesn't segment at all. this means the frame decode logic might change; this is a reminder to " \
117 | "double check that."
118 | # Since 24kHz pretrained doesn't do any segmenting, we have all the frames already (1 frame = 1 token in quantized_indices)
119 |
120 | # The following code is hacked in from self.model.decode() (Encodec version 0.1.1) where we skip the part about
121 | # scaling.
122 | # Shape: 1 x (num_frames * stride product). 1 because we have 1 frame (because no segmenting)
123 | frames = self._decode_frame(quantized_indices)
124 | result = _linear_overlap_add(frames, self.model.segment_stride or 1)
125 | # TODO: I'm not overly pleased with this because when this function gets called, we just rearrange the result
126 | # back to b n anyways, but we'll keep this as a temporary hack just to make things work for now
127 | return rearrange(result, 'b n -> b 1 n')
128 |
129 | def get_emb_from_indices(self, indices):
130 | codes = rearrange(indices, 'b t q -> q b t')
131 | emb = self.model.quantizer.decode(codes)
132 | return rearrange(emb, 'b c n -> b n c')
133 |
134 | def decode(self, emb):
135 | emb = rearrange(emb, 'b n c -> b c n')
136 | return self.model.decoder(emb)
137 |
138 | def _decode_frame(self, quantized_indices):
139 | # The following code is hacked in from self.model._decode_frame() (Encodec version 0.1.1) where we assume we've
140 | # already unwrapped the EncodedFrame
141 | # Input: batch x num tokens x num quantizers
142 | # Output: batch x new_num_samples, where new_num_samples is num_frames * stride product (may be slightly
143 | # larger than original num samples as a result, because the last frame might not be "fully filled" with samples
144 | # if num_samples doesn't divide perfectly).
145 | # num_frames == the number of acoustic tokens you have, one token per frame
146 | codes = rearrange(quantized_indices, 'b t q -> q b t')
147 | emb = self.model.quantizer.decode(codes)
148 | # emb shape: batch x self.model.quantizer.dimension x T. Note self.model.quantizer.dimension is the embedding dimension
149 | return self.model.decoder(emb)
150 |
--------------------------------------------------------------------------------
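
A hedged round-trip sketch for `EncodecWrapper` above, assuming a batch of one random 24kHz waveform; `.eval()` is called explicitly since `forward` asserts the wrapped pretrained model is not in training mode:

```python
import torch
from audiolm_pytorch import EncodecWrapper

encodec = EncodecWrapper().eval()

wav = torch.randn(1, 24000)                            # (batch, samples) at 24kHz

emb, codes, _ = encodec(wav, return_encoded = True)    # codes: (batch, timesteps, num_quantizers)

recon = encodec.decode_from_codebook_indices(codes)    # (batch, 1, samples)
```
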
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## AudioLM - Pytorch
4 |
5 | Implementation of AudioLM, a Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
6 |
7 | It also extends the work with text conditioning via classifier-free guidance, using T5 as the text encoder. This allows one to do text-to-audio or TTS, which is not offered in the paper. Yes, this means VALL-E can be trained from this repository. It is essentially the same.
8 |
9 | Please join the Discord if you are interested in replicating this work in the open
10 |
11 | This repository now also contains an MIT-licensed version of SoundStream. It is also compatible with EnCodec, which is likewise [MIT-licensed](https://github.com/facebookresearch/encodec/commit/349b72939f57cb3bc7b60906c0ee8228c849485d) at the time of writing.
12 |
13 | Update: AudioLM was essentially used to 'solve' music generation in the new MusicLM
14 |
15 | In the future, this movie clip would no longer make any sense. You would just prompt an AI instead.
16 |
17 | ## Appreciation
18 |
19 | - Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research
20 |
21 | - 🤗 Huggingface for their amazing accelerate and transformers libraries
22 |
23 | - MetaAI for Fairseq and the liberal license
24 |
25 | - @eonglints and Joseph for offering their professional advice and expertise as well as pull requests!
26 |
27 | - @djqualia, @yigityu, @inspirit, and @BlackFox1197 for helping with the debugging of soundstream
28 |
29 | - Allen and LWprogramming for reviewing the code and submitting bug fixes!
30 |
31 | - Ilya for finding an issue with multi-scale discriminator downsampling and for soundstream trainer improvements
32 |
33 | - Andrey for identifying a missing loss in soundstream and guiding me through the proper mel spectrogram hyperparameters
34 |
35 | - Alejandro and Ilya for sharing their results with training soundstream, and for working through a few issues with the local attention positional embeddings
36 |
37 | - LWprogramming for adding Encodec compatibility!
38 |
39 | - @YoungloLee for identifying a big bug in the 1d causal convolution for soundstream related to padding not accounting for strides!
40 |
41 | - Hayden for pointing out some discrepancies in the multi-scale discriminator for Soundstream
42 |
43 | ## Install
44 |
45 | ```bash
46 | $ pip install audiolm-pytorch
47 | ```
48 |
49 | ## Usage
50 |
51 | ### SoundStream & Encodec
52 |
53 | There are two options for the neural codec. If you want to use the pretrained 24kHz Encodec, just create an Encodec object as follows:
54 | ```python
55 | from audiolm_pytorch import EncodecWrapper
56 | encodec = EncodecWrapper()
57 | # Now you can use the encodec variable in the same way you'd use the soundstream variables below.
58 | ```
59 |
60 | Otherwise, to stay more true to the original paper, you can use `SoundStream`. First, `SoundStream` needs to be trained on a large corpus of audio data
61 |
62 | ```python
63 | from audiolm_pytorch import SoundStream, SoundStreamTrainer
64 |
65 | soundstream = SoundStream(
66 | codebook_size = 1024,
67 | rq_num_quantizers = 8,
68 | rq_groups = 2, # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765
69 | attn_window_size = 128, # local attention receptive field at bottleneck
70 | attn_depth = 2 # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
71 | )
72 |
73 | trainer = SoundStreamTrainer(
74 | soundstream,
75 | folder = '/path/to/audio/files',
76 | batch_size = 4,
77 | grad_accum_every = 8, # effective batch size of 32
78 | data_max_length_seconds = 2, # train on 2 second audio
79 | num_train_steps = 1_000_000
80 | ).cuda()
81 |
82 | trainer.train()
83 |
84 | # after a lot of training, you can test the autoencoding as so
85 |
86 | audio = torch.randn(10080).cuda()
87 | recons = soundstream(audio, return_recons_only = True) # (1, 10080) - 1 channel
88 | ```
89 |
90 | You can also use soundstreams that are specific to `AudioLM` and `MusicLM` by importing `AudioLMSoundStream` and `MusicLMSoundStream` respectively
91 |
92 | ```python
93 | from audiolm_pytorch import AudioLMSoundStream, MusicLMSoundStream
94 |
95 | soundstream = AudioLMSoundStream(...) # say you want the hyperparameters as in Audio LM paper
96 |
97 | # rest is the same as above
98 | ```
99 |
100 | As of version `0.17.0`, you can now invoke the class method on `SoundStream` to load from checkpoint files, without having to remember your configurations.
101 |
102 | ```python
103 | from audiolm_pytorch import SoundStream
104 |
105 | soundstream = SoundStream.init_and_load_from('./path/to/checkpoint.pt')
106 | ```
107 |
108 | ### Hierarchical Transformers
109 |
110 | Then three separate transformers (`SemanticTransformer`, `CoarseTransformer`, `FineTransformer`) need to be trained
111 |
112 |
113 | ex. `SemanticTransformer`
114 |
115 | ```python
116 | import torch
117 | from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer
118 |
119 | # hubert checkpoints can be downloaded at
120 | # https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
121 |
122 | wav2vec = HubertWithKmeans(
123 | checkpoint_path = './hubert/hubert_base_ls960.pt',
124 | kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
125 | )
126 |
127 | semantic_transformer = SemanticTransformer(
128 | num_semantic_tokens = wav2vec.codebook_size,
129 | dim = 1024,
130 | depth = 6
131 | ).cuda()
132 |
133 |
134 | trainer = SemanticTransformerTrainer(
135 | transformer = semantic_transformer,
136 | wav2vec = wav2vec,
137 | folder ='/path/to/audio/files',
138 | batch_size = 1,
139 | data_max_length = 320 * 32,
140 | num_train_steps = 1
141 | )
142 |
143 | trainer.train()
144 | ```
145 |
146 | ex. `CoarseTransformer`
147 |
148 | ```python
149 | import torch
150 | from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerTrainer
151 |
152 | wav2vec = HubertWithKmeans(
153 | checkpoint_path = './hubert/hubert_base_ls960.pt',
154 | kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
155 | )
156 |
157 | soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt')
158 |
159 | coarse_transformer = CoarseTransformer(
160 | num_semantic_tokens = wav2vec.codebook_size,
161 | codebook_size = 1024,
162 | num_coarse_quantizers = 3,
163 | dim = 512,
164 | depth = 6
165 | )
166 |
167 | trainer = CoarseTransformerTrainer(
168 | transformer = coarse_transformer,
169 | codec = soundstream,
170 | wav2vec = wav2vec,
171 | folder = '/path/to/audio/files',
172 | batch_size = 1,
173 | data_max_length = 320 * 32,
174 | num_train_steps = 1_000_000
175 | )
176 |
177 | trainer.train()
178 | ```
179 |
180 | ex. `FineTransformer`
181 |
182 | ```python
183 | import torch
184 | from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer
185 |
186 | soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt')
187 |
188 | fine_transformer = FineTransformer(
189 | num_coarse_quantizers = 3,
190 | num_fine_quantizers = 5,
191 | codebook_size = 1024,
192 | dim = 512,
193 | depth = 6
194 | )
195 |
196 | trainer = FineTransformerTrainer(
197 | transformer = fine_transformer,
198 | codec = soundstream,
199 | folder = '/path/to/audio/files',
200 | batch_size = 1,
201 | data_max_length = 320 * 32,
202 | num_train_steps = 1_000_000
203 | )
204 |
205 | trainer.train()
206 | ```
207 |
208 | All together now
209 |
210 | ```python
211 | from audiolm_pytorch import AudioLM
212 |
213 | audiolm = AudioLM(
214 | wav2vec = wav2vec,
215 | codec = soundstream,
216 | semantic_transformer = semantic_transformer,
217 | coarse_transformer = coarse_transformer,
218 | fine_transformer = fine_transformer
219 | )
220 |
221 | generated_wav = audiolm(batch_size = 1)
222 |
223 | # or with priming
224 |
225 | generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))
226 |
227 | # or with text condition, if given
228 |
229 | generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells'])
230 |
231 | ```
232 |
233 | ## Text Conditioned Audio Synthesis
234 |
235 | Update: Looks like this will work, given 'VALL-E'
236 |
237 | ex. Semantic Transformer
238 |
239 | ```python
240 | import torch
241 | from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer
242 |
243 | wav2vec = HubertWithKmeans(
244 | checkpoint_path = './hubert/hubert_base_ls960.pt',
245 | kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
246 | )
247 |
248 | semantic_transformer = SemanticTransformer(
249 | num_semantic_tokens = 500,
250 | dim = 1024,
251 | depth = 6,
252 | has_condition = True, # this will have to be set to True
253 | cond_as_self_attn_prefix = True # whether to condition as prefix to self attention, instead of cross attention, as was done in 'VALL-E' paper
254 | ).cuda()
255 |
256 | # mock text audio dataset (as an example)
257 |
258 | # you will have to write your own `Dataset` subclass that returns an audio tensor as well as a string (the audio description), in either order (the framework will autodetect and route them into the transformer)
259 |
260 | from torch.utils.data import Dataset
261 |
262 | class MockTextAudioDataset(Dataset):
263 | def __init__(self, length = 100, audio_length = 320 * 32):
264 | super().__init__()
265 | self.audio_length = audio_length
266 | self.len = length
267 |
268 | def __len__(self):
269 | return self.len
270 |
271 | def __getitem__(self, idx):
272 | mock_audio = torch.randn(self.audio_length)
273 | mock_caption = 'audio caption'
274 | return mock_caption, mock_audio
275 |
276 | dataset = MockTextAudioDataset()
277 |
278 | # instantiate semantic transformer trainer and train
279 |
280 | trainer = SemanticTransformerTrainer(
281 | transformer = semantic_transformer,
282 | wav2vec = wav2vec,
283 | dataset = dataset,
284 | batch_size = 4,
285 | grad_accum_every = 8,
286 | data_max_length = 320 * 32,
287 | num_train_steps = 1_000_000
288 | )
289 |
290 | trainer.train()
291 |
292 | # after much training above
293 |
294 | sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_size = 1, max_length = 2) # (1, < 128) - may terminate early if it detects [eos]
295 |
296 | ```
297 |
298 | ## Multi-GPU
299 |
300 | Because all the trainer classes use 🤗 Accelerate, you can easily do multi-GPU training by using the `accelerate` command as follows
301 |
302 | At the project root
303 |
304 | ```bash
305 | $ accelerate config
306 | ```
307 |
308 | Then, in the same directory
309 |
310 | ```bash
311 | $ accelerate launch train.py
312 | ```
313 |
314 | ## Todo
315 |
316 | - [x] complete CoarseTransformer
317 | - [x] use fairseq vq-wav2vec for embeddings
318 | - [x] add conditioning
319 | - [x] add classifier free guidance
320 | - [x] add unique consecutive for
321 | - [x] incorporate ability to use hubert intermediate features as semantic tokens, recommended by eonglints
322 | - [x] accommodate variable lengthed audio, bring in eos token
323 | - [x] make sure unique consecutive works with coarse transformer
324 | - [x] pretty printing all discriminator losses to log
325 | - [x] handle the case when generating semantic tokens where the last logits may not necessarily be the last in the sequence, given unique consecutive processing
326 | - [x] complete sampling code for both Coarse and Fine Transformers, which will be tricky
327 | - [x] make sure full inference with or without prompting works on the `AudioLM` class
328 | - [x] complete full training code for soundstream, taking care of discriminator training
329 | - [x] add efficient gradient penalty for discriminators for soundstream
330 | - [x] wire up sample hz from sound dataset -> transformers, and have proper resampling during training - think about whether to allow the dataset to have sound files of varying sample hz, or to enforce the same sample hz
331 | - [x] full transformer training code for all three transformers
332 | - [x] refactor so semantic transformer has a wrapper that handles unique consecutives as well as wav to hubert or vq-wav2vec
333 | - [x] simply not self attend to eos token on the prompting side (semantic for coarse transformer, coarse for fine transformer)
334 | - [x] add structured dropout from forgetful causal masking, far better than traditional dropouts
335 | - [x] figure out how to suppress logging in fairseq
336 | - [x] assert that all three transformers passed into audiolm are compatible
337 | - [x] allow for specialized relative positional embeddings in fine transformer based on absolute matching positions of quantizers between coarse and fine
338 | - [x] allow for grouped residual vq in soundstream (use `GroupedResidualVQ` from vector-quantize-pytorch lib), from hifi-codec
339 |
340 | - [ ] redo the positional embeddings in the presence of groups in residual vq
341 | - [ ] test with speech synthesis for starters
342 | - [ ] cli tool, something like `audiolm generate ` and save generated wav file to local directory
343 | - [ ] return a list of waves in the case of variable lengthed audio
344 | - [ ] just take care of the edge case in coarse transformer text conditioned training, where the raw wave is resampled at different frequencies. autodetermine how to route based on length
345 |
346 | ## Citations
347 |
348 | ```bibtex
349 | @inproceedings{Borsos2022AudioLMAL,
350 | title = {AudioLM: a Language Modeling Approach to Audio Generation},
351 | author = {Zal{\'a}n Borsos and Rapha{\"e}l Marinier and Damien Vincent and Eugene Kharitonov and Olivier Pietquin and Matthew Sharifi and Olivier Teboul and David Grangier and Marco Tagliasacchi and Neil Zeghidour},
352 | year = {2022}
353 | }
354 | ```
355 |
356 | ```bibtex
357 | @misc{https://doi.org/10.48550/arxiv.2107.03312,
358 | title = {SoundStream: An End-to-End Neural Audio Codec},
359 | author = {Zeghidour, Neil and Luebs, Alejandro and Omran, Ahmed and Skoglund, Jan and Tagliasacchi, Marco},
360 | publisher = {arXiv},
361 | url = {https://arxiv.org/abs/2107.03312},
362 | year = {2021}
363 | }
364 | ```
365 |
366 | ```bibtex
367 | @misc{shazeer2020glu,
368 | title = {GLU Variants Improve Transformer},
369 | author = {Noam Shazeer},
370 | year = {2020},
371 | url = {https://arxiv.org/abs/2002.05202}
372 | }
373 | ```
374 |
375 | ```bibtex
376 | @article{Shazeer2019FastTD,
377 | title = {Fast Transformer Decoding: One Write-Head is All You Need},
378 | author = {Noam M. Shazeer},
379 | journal = {ArXiv},
380 | year = {2019},
381 | volume = {abs/1911.02150}
382 | }
383 | ```
384 |
385 | ```bibtex
386 | @article{Ho2022ClassifierFreeDG,
387 | title = {Classifier-Free Diffusion Guidance},
388 | author = {Jonathan Ho},
389 | journal = {ArXiv},
390 | year = {2022},
391 | volume = {abs/2207.12598}
392 | }
393 | ```
394 |
395 | ```bibtex
396 | @misc{crowson2022,
397 | author = {Katherine Crowson},
398 | url = {https://twitter.com/rivershavewings}
399 | }
400 | ```
401 |
402 | ```bibtex
403 | @misc{ding2021cogview,
404 | title = {CogView: Mastering Text-to-Image Generation via Transformers},
405 | author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
406 | year = {2021},
407 | eprint = {2105.13290},
408 | archivePrefix = {arXiv},
409 | primaryClass = {cs.CV}
410 | }
411 | ```
412 |
413 | ```bibtex
414 | @article{Liu2022FCMFC,
415 | title = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners},
416 | author = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
417 | journal = {ArXiv},
418 | year = {2022},
419 | volume = {abs/2210.13432}
420 | }
421 | ```
422 |
423 | ```bibtex
424 | @inproceedings{anonymous2022normformer,
425 | title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
426 | author = {Anonymous},
427 | booktitle = {Submitted to The Tenth International Conference on Learning Representations },
428 | year = {2022},
429 | url = {https://openreview.net/forum?id=GMYWzWztDx5},
430 | note = {under review}
431 | }
432 | ```
433 |
434 | ```bibtex
435 | @article{Li2021LocalViTBL,
436 | title = {LocalViT: Bringing Locality to Vision Transformers},
437 | author = {Yawei Li and K. Zhang and Jie Cao and Radu Timofte and Luc Van Gool},
438 | journal = {ArXiv},
439 | year = {2021},
440 | volume = {abs/2104.05707}
441 | }
442 | ```
443 |
444 | ```bibtex
445 | @misc{liu2021swin,
446 | title = {Swin Transformer V2: Scaling Up Capacity and Resolution},
447 | author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
448 | year = {2021},
449 | eprint = {2111.09883},
450 | archivePrefix = {arXiv},
451 | primaryClass = {cs.CV}
452 | }
453 | ```
454 |
455 | ```bibtex
456 | @inproceedings{Ma2022MegaMA,
457 | title = {Mega: Moving Average Equipped Gated Attention},
458 | author = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
459 | year = {2022}
460 | }
461 | ```
462 |
463 | ```bibtex
464 | @misc{gilmer2023intriguing,
465 | title = {Intriguing Properties of Transformer Training Instabilities},
466 | author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
467 | year = {2023},
468 | status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
469 | }
470 | ```
471 |
472 | ```bibtex
473 | @article{Defossez2022HighFN,
474 | title = {High Fidelity Neural Audio Compression},
475 | author = {Alexandre D'efossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
476 | journal = {ArXiv},
477 | year = {2022},
478 | volume = {abs/2210.13438}
479 | }
480 | ```
481 |
482 | ```bibtex
483 | @article{Hu2017SqueezeandExcitationN,
484 | title = {Squeeze-and-Excitation Networks},
485 | author = {Jie Hu and Li Shen and Gang Sun},
486 | journal = {2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition},
487 | year = {2017},
488 | pages = {7132-7141}
489 | }
490 | ```
491 |
492 | ```bibtex
493 | @inproceedings{Yang2023HiFiCodecGV,
494 | title = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
495 | author = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
496 | year = {2023}
497 | }
498 | ```
499 |
--------------------------------------------------------------------------------
/audiolm_pytorch_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 12,
6 | "metadata": {
7 | "colab": {
8 | "base_uri": "https://localhost:8080/"
9 | },
10 | "id": "n337KoD2om3L",
11 | "outputId": "97ada0c6-f21c-483e-d63d-08abddd49004"
12 | },
13 | "outputs": [
14 | {
15 | "name": "stdout",
16 | "output_type": "stream",
17 | "text": [
18 | "Mon Jan 30 20:47:47 2023 \n",
19 | "+-----------------------------------------------------------------------------+\n",
20 | "| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |\n",
21 | "|-------------------------------+----------------------+----------------------+\n",
22 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
23 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
24 | "| | | MIG M. |\n",
25 | "|===============================+======================+======================|\n",
26 | "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
27 | "| N/A 73C P0 32W / 70W | 10692MiB / 15360MiB | 0% Default |\n",
28 | "| | | N/A |\n",
29 | "+-------------------------------+----------------------+----------------------+\n",
30 | " \n",
31 | "+-----------------------------------------------------------------------------+\n",
32 | "| Processes: |\n",
33 | "| GPU GI CI PID Type Process name GPU Memory |\n",
34 | "| ID ID Usage |\n",
35 | "|=============================================================================|\n",
36 | "| 0 N/A N/A 5896 C 10689MiB |\n",
37 | "+-----------------------------------------------------------------------------+\n"
38 | ]
39 | }
40 | ],
41 | "source": [
42 | "!nvidia-smi\n",
43 | "\n",
44 | "# If this doesn't work, there's no GPU available or detected"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 13,
50 | "metadata": {
51 | "colab": {
52 | "base_uri": "https://localhost:8080/"
53 | },
54 | "id": "TLJAcUHpvmp4",
55 | "outputId": "95bcda95-a484-40c6-e5a7-47f4378759a8"
56 | },
57 | "outputs": [
58 | {
59 | "name": "stdout",
60 | "output_type": "stream",
61 | "text": [
62 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
63 | "Requirement already satisfied: audiolm-pytorch in /usr/local/lib/python3.8/dist-packages (0.7.5)\n",
64 | "Requirement already satisfied: ema-pytorch in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.1.4)\n",
65 | "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.1.97)\n",
66 | "Requirement already satisfied: beartype in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.12.0)\n",
67 | "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.0.2)\n",
68 | "Requirement already satisfied: torchaudio in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.13.1+cu116)\n",
69 | "Requirement already satisfied: joblib in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.2.0)\n",
70 | "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.13.1+cu116)\n",
71 | "Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (4.26.0)\n",
72 | "Requirement already satisfied: Mega-pytorch in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.0.12)\n",
73 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (4.64.1)\n",
74 | "Requirement already satisfied: accelerate in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.15.0)\n",
75 | "Requirement already satisfied: vector-quantize-pytorch>=0.10.15 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.10.15)\n",
76 | "Requirement already satisfied: einops>=0.6 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.6.0)\n",
77 | "Requirement already satisfied: local-attention>=1.5.7 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.5.8)\n",
78 | "Requirement already satisfied: fairseq in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.12.2)\n",
79 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->audiolm-pytorch) (4.4.0)\n",
80 | "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (21.3)\n",
81 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (6.0)\n",
82 | "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (1.21.6)\n",
83 | "Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (5.4.8)\n",
84 | "Requirement already satisfied: bitarray in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2.6.2)\n",
85 | "Requirement already satisfied: regex in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2022.6.2)\n",
86 | "Requirement already satisfied: hydra-core<1.1,>=1.0.7 in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (1.0.7)\n",
87 | "Requirement already satisfied: cffi in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (1.15.1)\n",
88 | "Requirement already satisfied: cython in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (0.29.33)\n",
89 | "Requirement already satisfied: omegaconf<2.1 in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2.0.6)\n",
90 | "Requirement already satisfied: sacrebleu>=1.4.12 in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2.3.1)\n",
91 | "Requirement already satisfied: scipy in /usr/local/lib/python3.8/dist-packages (from Mega-pytorch->audiolm-pytorch) (1.7.3)\n",
92 | "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn->audiolm-pytorch) (3.1.0)\n",
93 | "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (2.25.1)\n",
94 | "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (0.12.0)\n",
95 | "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (0.13.2)\n",
96 | "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (3.9.0)\n",
97 | "Requirement already satisfied: antlr4-python3-runtime==4.8 in /usr/local/lib/python3.8/dist-packages (from hydra-core<1.1,>=1.0.7->fairseq->audiolm-pytorch) (4.8)\n",
98 | "Requirement already satisfied: importlib-resources in /usr/local/lib/python3.8/dist-packages (from hydra-core<1.1,>=1.0.7->fairseq->audiolm-pytorch) (5.10.2)\n",
99 | "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->accelerate->audiolm-pytorch) (3.0.9)\n",
100 | "Requirement already satisfied: portalocker in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (2.7.0)\n",
101 | "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (0.8.10)\n",
102 | "Requirement already satisfied: colorama in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (0.4.6)\n",
103 | "Requirement already satisfied: lxml in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (4.9.2)\n",
104 | "Requirement already satisfied: pycparser in /usr/local/lib/python3.8/dist-packages (from cffi->fairseq->audiolm-pytorch) (2.21)\n",
105 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (2.10)\n",
106 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (2022.12.7)\n",
107 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (1.24.3)\n",
108 | "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (4.0.0)\n",
109 | "Requirement already satisfied: zipp>=3.1.0 in /usr/local/lib/python3.8/dist-packages (from importlib-resources->hydra-core<1.1,>=1.0.7->fairseq->audiolm-pytorch) (3.11.0)\n"
110 | ]
111 | }
112 | ],
113 | "source": [
114 | "!pip install audiolm-pytorch"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {
120 | "id": "xuNcsDJsvQwh"
121 | },
122 | "source": [
123 | "## Setup\n",
124 | "\n",
125 | "Includes:\n",
126 | "\n",
127 | "- How to generate a placeholder dataset if you don't already have one, with just enough data to run \"training\" end-to-end on a tiny dataset\n",
128 | "- How to download a dataset from OpenSLR"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {
134 | "id": "jBxNK5cKW--_"
135 | },
136 | "source": [
137 | "### Imports & paths"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 14,
143 | "metadata": {
144 | "id": "OrNeKngVVM0L"
145 | },
146 | "outputs": [],
147 | "source": [
148 | "# imports\n",
149 | "import math\n",
150 | "import wave\n",
151 | "import struct\n",
152 | "import os\n",
153 | "import urllib.request\n",
154 | "import tarfile\n",
155 | "from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM\n",
156 | "from torch import nn\n",
157 | "import torch\n",
158 | "import torchaudio\n",
159 | "\n",
160 | "\n",
161 | "# define all dataset paths, checkpoints, etc\n",
162 | "dataset_folder = \"placeholder_dataset\"\n",
163 | "soundstream_ckpt = \"results/soundstream.8.pt\" # this can change depending on number of steps\n",
164 | "hubert_ckpt = 'hubert/hubert_base_ls960.pt'\n",
165 | "hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' # listed in row \"HuBERT Base (~95M params)\", column Quantizer"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "metadata": {
171 | "id": "pA56YODZXBtf"
172 | },
173 | "source": [
174 | "### Data"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 15,
180 | "metadata": {
181 | "id": "6nnPceFWwedh"
182 | },
183 | "outputs": [],
184 | "source": [
185 | "# Placeholder data generation\n",
186 | "def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=44100.0):\n",
187 | " # code adapted from https://stackoverflow.com/a/33913403\n",
188 | " audio = []\n",
189 | " num_samples = duration_ms * (sample_rate / 1000.0)\n",
190 | " for x in range(int(num_samples)):\n",
191 | " audio.append(volume * math.sin(2 * math.pi * freq * (x / sample_rate)))\n",
192 | " return audio\n",
193 | "\n",
194 | "def save_wav(file_name, audio, sample_rate=44100.0):\n",
195 | " # Open up a wav file\n",
196 | " wav_file=wave.open(file_name,\"w\")\n",
197 | " # wav params\n",
198 | " nchannels = 1\n",
199 | " sampwidth = 2\n",
200 | " # 44100 is the industry standard sample rate - CD quality. If you need to\n",
201 | "    # save on file size you can adjust it downwards. The standard for low quality\n",
202 | " # is 8000 or 8kHz.\n",
203 | " nframes = len(audio)\n",
204 | " comptype = \"NONE\"\n",
205 | " compname = \"not compressed\"\n",
206 | " wav_file.setparams((nchannels, sampwidth, sample_rate, nframes, comptype, compname))\n",
207 | " # WAV files here are using short, 16 bit, signed integers for the \n",
208 | " # sample size. So we multiply the floating point data we have by 32767, the\n",
209 | "    # maximum value for a short integer. NOTE: It is theoretically possible to\n",
210 | " # use the floating point -1.0 to 1.0 data directly in a WAV file but not\n",
211 | " # obvious how to do that using the wave module in python.\n",
212 | " for sample in audio:\n",
213 | " wav_file.writeframes(struct.pack('h', int( sample * 32767.0 )))\n",
214 | " wav_file.close()\n",
215 | " return\n",
216 | "\n",
217 | "def make_placeholder_dataset():\n",
218 | " # Make a placeholder dataset with a few .wav files that you can \"train\" on, just to verify things work e2e\n",
219 | " if os.path.isdir(dataset_folder):\n",
220 | " return\n",
221 | " os.makedirs(dataset_folder)\n",
222 | " save_wav(f\"{dataset_folder}/example.wav\", get_sinewave())\n",
223 | " save_wav(f\"{dataset_folder}/example2.wav\", get_sinewave(duration_ms=500))\n",
224 | " os.makedirs(f\"{dataset_folder}/subdirectory\")\n",
225 | " save_wav(f\"{dataset_folder}/subdirectory/example.wav\", get_sinewave(freq=330.0))\n",
226 | "\n",
227 | "make_placeholder_dataset()"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 16,
233 | "metadata": {
234 | "id": "jwYCbFpHvmRI"
235 | },
236 | "outputs": [],
237 | "source": [
238 | "# Get actual dataset. Uncomment this if you want to try training on real data\n",
239 | "\n",
240 | "# full dataset: https://www.openslr.org/12\n",
241 | "# We'll use https://us.openslr.org/resources/12/dev-clean.tar.gz development set, \"clean\" speech.\n",
242 | "# We *should* train on the actual training split, but this is just to demo running things end-to-end, so I picked a small, clean set.\n",
243 | "\n",
244 | "# url = \"https://us.openslr.org/resources/12/dev-clean.tar.gz\"\n",
245 | "# filename = \"dev-clean\"\n",
246 | "# filename_targz = filename + \".tar.gz\"\n",
247 | "# if not os.path.isfile(filename_targz):\n",
248 | "# urllib.request.urlretrieve(url, filename_targz)\n",
249 | "# if not os.path.isdir(filename):\n",
250 | "# # open file\n",
251 | "# with tarfile.open(filename_targz) as t:\n",
252 | "# t.extractall(filename)"
253 | ]
254 | },
255 | {
256 | "cell_type": "markdown",
257 | "metadata": {
258 | "id": "PYcI0aXEwuxR"
259 | },
260 | "source": [
261 | "## Training\n",
262 | "\n",
263 | "Now that we have a dataset, we can train AudioLM.\n",
264 | "\n",
265 | "**Note**: do NOT type \"y\" to overwrite previous experiments/checkpoints when running through the cells here unless you're ready to wipe the entire results folder! Otherwise you will end up erasing things (e.g. you train SoundStream first, and if you choose \"overwrite\" then you lose the SoundStream checkpoint when you then train SemanticTransformer)."
266 | ]
267 | },
268 | {
269 | "cell_type": "markdown",
270 | "metadata": {
271 | "id": "T7GiyBcBWiZV"
272 | },
273 | "source": [
274 | "### SoundStream"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": 17,
280 | "metadata": {
281 | "colab": {
282 | "base_uri": "https://localhost:8080/"
283 | },
284 | "id": "nGU0OZiOwPEO",
285 | "outputId": "21dd959c-6458-4477-8403-cf810166f38d"
286 | },
287 | "outputs": [
288 | {
289 | "name": "stdout",
290 | "output_type": "stream",
291 | "text": [
292 | "training with dataset of 2 samples and validating with randomly splitted 1 samples\n",
293 | "0: soundstream total loss: 167.262, soundstream recon loss: 1.123 | discr (scale 1) loss: 2.003 | discr (scale 0.5) loss: 1.999 | discr (scale 0.25) loss: 1.999\n",
294 | "0: saving to results\n",
295 | "0: saving model to results\n",
296 | "1: soundstream total loss: 182.282, soundstream recon loss: 1.389 | discr (scale 1) loss: 1.938 | discr (scale 0.5) loss: 1.928 | discr (scale 0.25) loss: 1.928\n",
297 | "2: soundstream total loss: 196.668, soundstream recon loss: 1.450 | discr (scale 1) loss: 1.845 | discr (scale 0.5) loss: 1.842 | discr (scale 0.25) loss: 1.843\n",
298 | "2: saving to results\n",
299 | "3: soundstream total loss: 216.329, soundstream recon loss: 1.451 | discr (scale 1) loss: 1.751 | discr (scale 0.5) loss: 1.750 | discr (scale 0.25) loss: 1.757\n",
300 | "4: soundstream total loss: 206.804, soundstream recon loss: 1.167 | discr (scale 1) loss: 1.671 | discr (scale 0.5) loss: 1.706 | discr (scale 0.25) loss: 1.724\n",
301 | "4: saving to results\n",
302 | "4: saving model to results\n",
303 | "5: soundstream total loss: 195.325, soundstream recon loss: 0.929 | discr (scale 1) loss: 1.348 | discr (scale 0.5) loss: 1.372 | discr (scale 0.25) loss: 1.482\n",
304 | "6: soundstream total loss: 245.195, soundstream recon loss: 1.054 | discr (scale 1) loss: 1.060 | discr (scale 0.5) loss: 1.244 | discr (scale 0.25) loss: 1.288\n",
305 | "6: saving to results\n",
306 | "7: soundstream total loss: 245.724, soundstream recon loss: 0.970 | discr (scale 1) loss: 1.092 | discr (scale 0.5) loss: 1.358 | discr (scale 0.25) loss: 1.079\n",
307 | "8: soundstream total loss: 202.707, soundstream recon loss: 0.786 | discr (scale 1) loss: 0.733 | discr (scale 0.5) loss: 0.687 | discr (scale 0.25) loss: 0.790\n",
308 | "8: saving to results\n",
309 | "8: saving model to results\n",
310 | "training complete\n"
311 | ]
312 | }
313 | ],
314 | "source": [
315 | "soundstream = SoundStream(\n",
316 | " codebook_size = 1024,\n",
317 | " rq_num_quantizers = 8,\n",
318 | ")\n",
319 | "\n",
320 | "trainer = SoundStreamTrainer(\n",
321 | " soundstream,\n",
322 | " folder = dataset_folder,\n",
323 | " batch_size = 4,\n",
324 | " grad_accum_every = 8, # effective batch size of 32\n",
325 | " data_max_length = 320 * 32,\n",
326 | " save_results_every = 2,\n",
327 | " save_model_every = 4,\n",
328 | " num_train_steps = 9\n",
329 | ").cuda()\n",
330 | "# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes\n",
331 | "# adjusting save_*_every variables for the same reason\n",
332 | "\n",
333 | "trainer.train()"
334 | ]
335 | },
336 | {
337 | "cell_type": "markdown",
338 | "metadata": {
339 | "id": "lqjN28L4Wc5Q"
340 | },
341 | "source": [
342 | "### SemanticTransformer"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": 18,
348 | "metadata": {
349 | "colab": {
350 | "base_uri": "https://localhost:8080/"
351 | },
352 | "id": "qgd962eSvDzS",
353 | "outputId": "b0550cde-0c8b-4a39-f896-f6f813f50f8c"
354 | },
355 | "outputs": [
356 | {
357 | "name": "stderr",
358 | "output_type": "stream",
359 | "text": [
360 | "/usr/local/lib/python3.8/dist-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator MiniBatchKMeans from version 0.24.0 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
361 | "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n",
362 | " warnings.warn(\n"
363 | ]
364 | },
365 | {
366 | "name": "stdout",
367 | "output_type": "stream",
368 | "text": [
369 | "training with dataset of 2 samples and validating with randomly splitted 1 samples\n",
370 | "do you want to clear previous experiment checkpoints and results? (y/n) n\n",
371 | "0: loss: 6.648584365844727\n",
372 | "0: valid loss 5.763116359710693\n",
373 | "0: saving model to results\n",
374 | "training complete\n"
375 | ]
376 | }
377 | ],
378 | "source": [
379 | "# hubert checkpoints can be downloaded at\n",
380 | "# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert\n",
381 | "if not os.path.isdir(\"hubert\"):\n",
382 | " os.makedirs(\"hubert\")\n",
383 | "if not os.path.isfile(hubert_ckpt):\n",
384 | " hubert_ckpt_download = f\"https://dl.fbaipublicfiles.com/{hubert_ckpt}\"\n",
385 | " urllib.request.urlretrieve(hubert_ckpt_download, f\"./{hubert_ckpt}\")\n",
386 | "if not os.path.isfile(hubert_quantizer):\n",
387 | " hubert_quantizer_download = f\"https://dl.fbaipublicfiles.com/{hubert_quantizer}\"\n",
388 | " urllib.request.urlretrieve(hubert_quantizer_download, f\"./{hubert_quantizer}\")\n",
389 | "\n",
390 | "wav2vec = HubertWithKmeans(\n",
391 | " checkpoint_path = f'./{hubert_ckpt}',\n",
392 | " kmeans_path = f'./{hubert_quantizer}'\n",
393 | ")\n",
394 | "\n",
395 | "semantic_transformer = SemanticTransformer(\n",
396 | " num_semantic_tokens = wav2vec.codebook_size,\n",
397 | " dim = 1024,\n",
398 | " depth = 6\n",
399 | ").cuda()\n",
400 | "\n",
401 | "\n",
402 | "trainer = SemanticTransformerTrainer(\n",
403 | " transformer = semantic_transformer,\n",
404 | " wav2vec = wav2vec,\n",
405 | " folder = dataset_folder,\n",
406 | " batch_size = 1,\n",
407 | " data_max_length = 320 * 32,\n",
408 | " num_train_steps = 1\n",
409 | ")\n",
410 | "\n",
411 | "trainer.train()"
412 | ]
413 | },
414 | {
415 | "cell_type": "markdown",
416 | "metadata": {
417 | "id": "4eEvIzhEWwRz"
418 | },
419 | "source": [
420 | "### CoarseTransformer"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": 19,
426 | "metadata": {
427 | "colab": {
428 | "base_uri": "https://localhost:8080/"
429 | },
430 | "id": "1LeWmaNHzzY9",
431 | "outputId": "7e7ecb3b-f59e-4d18-c8c9-64762e9b43fc"
432 | },
433 | "outputs": [
434 | {
435 | "name": "stderr",
436 | "output_type": "stream",
437 | "text": [
438 | "/usr/local/lib/python3.8/dist-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator MiniBatchKMeans from version 0.24.0 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
439 | "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n",
440 | " warnings.warn(\n"
441 | ]
442 | },
443 | {
444 | "name": "stdout",
445 | "output_type": "stream",
446 | "text": [
447 | "training with dataset of 2 samples and validating with randomly splitted 1 samples\n",
448 | "do you want to clear previous experiment checkpoints and results? (y/n) n\n",
449 | "0: loss: 63.983970642089844\n",
450 | "0: valid loss 63.398582458496094\n",
451 | "0: saving model to results\n",
452 | "1: loss: 65.85967254638672\n",
453 | "2: loss: 62.4722900390625\n",
454 | "2: valid loss 50.01605987548828\n",
455 | "3: loss: 11.735434532165527\n",
456 | "4: loss: 3.976104497909546\n",
457 | "4: valid loss 46.094608306884766\n",
458 | "4: saving model to results\n",
459 | "5: loss: 58.27140426635742\n",
460 | "6: loss: 41.68347930908203\n",
461 | "6: valid loss 45.54595184326172\n",
462 | "7: loss: 2.2387890815734863\n",
463 | "8: loss: 0.4718627631664276\n",
464 | "8: valid loss 39.10848617553711\n",
465 | "8: saving model to results\n",
466 | "training complete\n"
467 | ]
468 | }
469 | ],
470 | "source": [
471 | "wav2vec = HubertWithKmeans(\n",
472 | " checkpoint_path = f'./{hubert_ckpt}',\n",
473 | " kmeans_path = f'./{hubert_quantizer}'\n",
474 | ")\n",
475 | "\n",
476 | "soundstream = SoundStream(\n",
477 | " codebook_size = 1024,\n",
478 | " rq_num_quantizers = 8,\n",
479 | ")\n",
480 | "\n",
481 | "soundstream.load(f\"./{soundstream_ckpt}\")\n",
482 | "\n",
483 | "coarse_transformer = CoarseTransformer(\n",
484 | " num_semantic_tokens = wav2vec.codebook_size,\n",
485 | " codebook_size = 1024,\n",
486 | " num_coarse_quantizers = 3,\n",
487 | " dim = 512,\n",
488 | " depth = 6\n",
489 | ")\n",
490 | "\n",
491 | "trainer = CoarseTransformerTrainer(\n",
492 | " transformer = coarse_transformer,\n",
493 | " codec = soundstream,\n",
494 | " wav2vec = wav2vec,\n",
495 | " folder = dataset_folder,\n",
496 | " batch_size = 1,\n",
497 | " data_max_length = 320 * 32,\n",
498 | " save_results_every = 2,\n",
499 | " save_model_every = 4,\n",
500 | " num_train_steps = 9\n",
501 | ")\n",
502 | "# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes\n",
503 | "# adjusting save_*_every variables for the same reason\n",
504 | "\n",
505 | "trainer.train()"
506 | ]
507 | },
508 | {
509 | "cell_type": "markdown",
510 | "metadata": {
511 | "id": "fRvj7qOJWzmw"
512 | },
513 | "source": [
514 | "### FineTransformer"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": 20,
520 | "metadata": {
521 | "colab": {
522 | "base_uri": "https://localhost:8080/"
523 | },
524 | "id": "ZRaEhRRKWg8F",
525 | "outputId": "7cc166c4-c8e9-45ef-8293-8f5381c2d3af"
526 | },
527 | "outputs": [
528 | {
529 | "name": "stdout",
530 | "output_type": "stream",
531 | "text": [
532 | "training with dataset of 2 samples and validating with randomly splitted 1 samples\n",
533 | "do you want to clear previous experiment checkpoints and results? (y/n) n\n",
534 | "0: loss: 70.90608215332031\n",
535 | "0: valid loss 65.99951171875\n",
536 | "0: saving model to results\n",
537 | "1: loss: 43.6014289855957\n",
538 | "2: loss: 8.300681114196777\n",
539 | "3: loss: 61.23375701904297\n",
540 | "4: loss: 63.34052276611328\n",
541 | "5: loss: 2.010118246078491\n",
542 | "6: loss: 56.52588653564453\n",
543 | "7: loss: 0.5423888564109802\n",
544 | "8: loss: 0.005095238331705332\n",
545 | "training complete\n"
546 | ]
547 | }
548 | ],
549 | "source": [
550 | "soundstream = SoundStream(\n",
551 | " codebook_size = 1024,\n",
552 | " rq_num_quantizers = 8,\n",
553 | ")\n",
554 | "\n",
555 | "soundstream.load(f\"./{soundstream_ckpt}\")\n",
556 | "\n",
557 | "fine_transformer = FineTransformer(\n",
558 | " num_coarse_quantizers = 3,\n",
559 | " num_fine_quantizers = 5,\n",
560 | " codebook_size = 1024,\n",
561 | " dim = 512,\n",
562 | " depth = 6\n",
563 | ")\n",
564 | "\n",
565 | "trainer = FineTransformerTrainer(\n",
566 | " transformer = fine_transformer,\n",
567 | " codec = soundstream,\n",
568 | " folder = dataset_folder,\n",
569 | " batch_size = 1,\n",
570 | " data_max_length = 320 * 32,\n",
571 | " num_train_steps = 9\n",
572 | ")\n",
573 | "# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes\n",
574 | "# adjusting save_*_every variables for the same reason\n",
575 | "\n",
576 | "trainer.train()"
577 | ]
578 | },
579 | {
580 | "cell_type": "markdown",
581 | "metadata": {
582 | "id": "QoHgkgA3XKXH"
583 | },
584 | "source": [
585 | "## Inference"
586 | ]
587 | },
588 | {
589 | "cell_type": "code",
590 | "execution_count": 21,
591 | "metadata": {
592 | "colab": {
593 | "base_uri": "https://localhost:8080/"
594 | },
595 | "id": "rzghrux5WinW",
596 | "outputId": "9dd39f7f-0046-4a5f-826e-a442345987af"
597 | },
598 | "outputs": [
599 | {
600 | "name": "stderr",
601 | "output_type": "stream",
602 | "text": [
603 | "generating semantic: 0%| | 10/2048 [00:00<00:25, 78.55it/s]\n",
604 | "generating coarse: 100%|██████████| 512/512 [00:14<00:00, 34.83it/s]\n",
605 | "generating fine: 100%|██████████| 512/512 [02:56<00:00, 2.91it/s]\n"
606 | ]
607 | }
608 | ],
609 | "source": [
610 | "# Everything together\n",
611 | "audiolm = AudioLM(\n",
612 | " wav2vec = wav2vec,\n",
613 | " codec = soundstream,\n",
614 | " semantic_transformer = semantic_transformer,\n",
615 | " coarse_transformer = coarse_transformer,\n",
616 | " fine_transformer = fine_transformer\n",
617 | ")\n",
618 | "\n",
619 | "generated_wav = audiolm(batch_size = 1)"
620 | ]
621 | },
622 | {
623 | "cell_type": "code",
624 | "execution_count": 22,
625 | "metadata": {
626 | "id": "4rQPHTSRngEr"
627 | },
628 | "outputs": [],
629 | "source": [
630 | "output_path = \"out.wav\"\n",
631 | "sample_rate = 16000 # the generated audio comes from the SoundStream codec, whose default target_sample_hz is 16000; saving at 44100 would mislabel the output\n",
632 | "torchaudio.save(output_path, generated_wav.cpu(), sample_rate)"
633 | ]
634 | },
635 | {
636 | "cell_type": "code",
637 | "execution_count": 22,
638 | "metadata": {
639 | "id": "is9wLY_ncDYK"
640 | },
641 | "outputs": [],
642 | "source": []
643 | }
644 | ],
645 | "metadata": {
646 | "accelerator": "GPU",
647 | "colab": {
648 | "provenance": []
649 | },
650 | "gpuClass": "standard",
651 | "kernelspec": {
652 | "display_name": "Python 3 (ipykernel)",
653 | "language": "python",
654 | "name": "python3"
655 | },
656 | "language_info": {
657 | "codemirror_mode": {
658 | "name": "ipython",
659 | "version": 3
660 | },
661 | "file_extension": ".py",
662 | "mimetype": "text/x-python",
663 | "name": "python",
664 | "nbconvert_exporter": "python",
665 | "pygments_lexer": "ipython3",
666 | "version": "3.9.13"
667 | }
668 | },
669 | "nbformat": 4,
670 | "nbformat_minor": 1
671 | }
672 |
--------------------------------------------------------------------------------
/audiolm_pytorch/soundstream.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from itertools import cycle
3 | from pathlib import Path
4 |
5 | from functools import partial, wraps
6 | from itertools import zip_longest
7 | from typing import Optional
8 |
9 | import torch
10 | from torch import nn, einsum
11 | from torch.autograd import grad as torch_grad
12 | import torch.nn.functional as F
13 | from torch.linalg import vector_norm
14 |
15 | import torchaudio.transforms as T
16 | from torchaudio.functional import resample
17 |
18 | from einops import rearrange, reduce, pack, unpack
19 |
20 | from vector_quantize_pytorch import GroupedResidualVQ
21 |
22 | from local_attention import LocalMHA
23 | from local_attention.transformer import FeedForward, DynamicPositionBias
24 |
25 | from audiolm_pytorch.utils import curtail_to_multiple
26 |
27 | from audiolm_pytorch.version import __version__
28 | from packaging import version
29 | parsed_version = version.parse(__version__)
30 |
31 | import pickle
32 |
33 | # helper functions
34 |
35 | def exists(val):
36 | return val is not None
37 |
38 | def default(val, d):
39 | return val if exists(val) else d
40 |
41 | def cast_tuple(t, l = 1):
42 | return ((t,) * l) if not isinstance(t, tuple) else t
43 |
44 | def filter_by_keys(fn, d):
45 | return {k: v for k, v in d.items() if fn(k)}
46 |
47 | def map_keys(fn, d):
48 | return {fn(k): v for k, v in d.items()}
49 |
50 | # gan losses
51 |
52 | def log(t, eps = 1e-20):
53 | return torch.log(t.clamp(min = eps))
54 |
55 | def hinge_discr_loss(fake, real):
56 | return (F.relu(1 + fake) + F.relu(1 - real)).mean()
57 |
58 | def hinge_gen_loss(fake):
59 | return -fake.mean()
60 |
61 | def leaky_relu(p = 0.1):
62 | return nn.LeakyReLU(p)
63 |
64 | def gradient_penalty(wave, output, weight = 10):
65 | batch_size, device = wave.shape[0], wave.device
66 |
67 | gradients = torch_grad(
68 | outputs = output,
69 | inputs = wave,
70 | grad_outputs = torch.ones_like(output),
71 | create_graph = True,
72 | retain_graph = True,
73 | only_inputs = True
74 | )[0]
75 |
76 | gradients = rearrange(gradients, 'b ... -> b (...)')
77 | return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean()
78 |
79 | # better sequential
80 |
81 | def Sequential(*mods):
82 | return nn.Sequential(*filter(exists, mods))
83 |
84 | # discriminators
85 |
86 | class MultiScaleDiscriminator(nn.Module):
87 | def __init__(
88 | self,
89 | channels = 16,
90 | layers = 4,
91 | groups = (4, 16, 64, 256),
92 | chan_max = 1024,
93 | input_channels = 1
94 | ):
95 | super().__init__()
96 | self.init_conv = nn.Conv1d(input_channels, channels, 15, padding = 7)
97 | self.conv_layers = nn.ModuleList([])
98 |
99 | curr_channels = channels
100 |
101 | for _, group in zip(range(layers), groups):
102 | chan_out = min(curr_channels * 4, chan_max)
103 |
104 | self.conv_layers.append(nn.Sequential(
105 | nn.Conv1d(curr_channels, chan_out, 41, stride = 4, padding = 20, groups = group),
106 | leaky_relu()
107 | ))
108 |
109 | curr_channels = chan_out
110 |
111 | self.final_conv = nn.Sequential(
112 | nn.Conv1d(curr_channels, curr_channels, 5, padding = 2),
113 | leaky_relu(),
114 | nn.Conv1d(curr_channels, 1, 3, padding = 1),
115 | )
116 |
117 | def forward(
118 | self,
119 | x,
120 | return_intermediates = False
121 | ):
122 | x = self.init_conv(x)
123 | intermediates = []
124 |
125 | for layer in self.conv_layers:
126 | x = layer(x)
127 | intermediates.append(x)
128 |
129 | out = self.final_conv(x)
130 |
131 | if not return_intermediates:
132 | return out
133 |
134 | return out, intermediates
135 |
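# editorial note (not in the original file): SoundStream further down instantiates one
# MultiScaleDiscriminator per entry of discr_multi_scales = (1, 0.5, 0.25) and feeds each
# one a progressively AvgPool1d-downsampled copy of the waveform, so the three critics
# judge the audio at full, half and quarter resolution.
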
136 | # autoregressive squeeze excitation
137 | # https://arxiv.org/abs/1709.01507
138 |
139 | class SqueezeExcite(nn.Module):
140 | def __init__(self, dim, reduction_factor = 4, dim_minimum = 8):
141 | super().__init__()
142 | dim_inner = max(dim_minimum, dim // reduction_factor)
143 | self.net = nn.Sequential(
144 | nn.Conv1d(dim, dim_inner, 1),
145 | nn.SiLU(),
146 | nn.Conv1d(dim_inner, dim, 1),
147 | nn.Sigmoid()
148 | )
149 |
150 | def forward(self, x):
151 | seq, device = x.shape[-2], x.device
152 |
153 | # cumulative mean - since it is autoregressive
154 |
155 | cum_sum = x.cumsum(dim = -2)
156 | denom = torch.arange(1, seq + 1, device = device).float()
157 | cum_mean = cum_sum / rearrange(denom, 'n -> n 1')
158 |
159 | # glu gate
160 |
161 | gate = self.net(cum_mean)
162 |
163 | return x * gate
164 |
165 | # complex stft discriminator
166 |
167 | class ModReLU(nn.Module):
168 | """
169 | https://arxiv.org/abs/1705.09792
170 | https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801
171 | """
172 | def __init__(self):
173 | super().__init__()
174 | self.b = nn.Parameter(torch.tensor(0.))
175 |
176 | def forward(self, x):
177 | return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x))
178 |
179 | class ComplexConv2d(nn.Module):
180 | def __init__(
181 | self,
182 | dim,
183 | dim_out,
184 | kernel_size,
185 | stride = 1,
186 | padding = 0
187 | ):
188 | super().__init__()
189 | conv = nn.Conv2d(dim, dim_out, kernel_size, dtype = torch.complex64)
190 | self.weight = nn.Parameter(torch.view_as_real(conv.weight))
191 | self.bias = nn.Parameter(torch.view_as_real(conv.bias))
192 |
193 | self.stride = stride
194 | self.padding = padding
195 |
196 | def forward(self, x):
197 | weight, bias = map(torch.view_as_complex, (self.weight, self.bias))
198 |
199 | x = x.to(weight.dtype)
200 | return F.conv2d(x, weight, bias, stride = self.stride, padding = self.padding)
201 |
202 | def ComplexSTFTResidualUnit(chan_in, chan_out, strides):
203 | kernel_sizes = tuple(map(lambda t: t + 2, strides))
204 | paddings = tuple(map(lambda t: t // 2, kernel_sizes))
205 |
206 | return nn.Sequential(
207 | Residual(Sequential(
208 | ComplexConv2d(chan_in, chan_in, 3, padding = 1),
209 | ModReLU(),
210 | ComplexConv2d(chan_in, chan_in, 3, padding = 1)
211 | )),
212 | ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings)
213 | )
214 |
215 | class ComplexSTFTDiscriminator(nn.Module):
216 | def __init__(
217 | self,
218 | *,
219 | channels = 32,
220 | strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
221 | chan_mults = (1, 2, 4, 4, 8, 8),
222 | input_channels = 1,
223 | n_fft = 1024,
224 | hop_length = 256,
225 | win_length = 1024,
226 | stft_normalized = False,
227 | logits_abs = True
228 | ):
229 | super().__init__()
230 | self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)
231 |
232 | layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
233 | layer_channels = (channels, *layer_channels)
234 | layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))
235 |
236 | curr_channels = channels
237 |
238 | self.layers = nn.ModuleList([])
239 |
240 | for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
241 | self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride))
242 |
243 | self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16
244 |
245 | # stft settings
246 |
247 | self.stft_normalized = stft_normalized
248 |
249 | self.n_fft = n_fft
250 | self.hop_length = hop_length
251 | self.win_length = win_length
252 |
253 | # how to output the logits into real space
254 | self.logits_abs = logits_abs
255 |
256 | def forward(self, x, return_intermediates = False):
257 | x = rearrange(x, 'b 1 n -> b n')
258 |
259 | '''
260 |         reference: from the SoundStream paper (https://arxiv.org/pdf/2107.03312.pdf):
261 | The STFT-based discriminator is illustrated in Figure 4
262 | and operates on a single scale, computing the STFT with a
263 | window length of W = 1024 samples and a hop length of
264 | H = 256 samples
265 | '''
266 |
267 | x = torch.stft(
268 | x,
269 | self.n_fft,
270 | hop_length = self.hop_length,
271 | win_length = self.win_length,
272 | normalized = self.stft_normalized,
273 | return_complex = True
274 | )
275 |
276 | x = rearrange(x, 'b ... -> b 1 ...')
277 |
278 | intermediates = []
279 |
280 | x = self.init_conv(x)
281 |
282 | intermediates.append(x)
283 |
284 | for layer in self.layers:
285 | x = layer(x)
286 | intermediates.append(x)
287 |
288 | complex_logits = self.final_conv(x)
289 |
290 | if self.logits_abs:
291 | complex_logits = complex_logits.abs()
292 | else:
293 | complex_logits = torch.view_as_real(complex_logits)
294 |
295 | if not return_intermediates:
296 | return complex_logits
297 |
298 | return complex_logits, intermediates
299 |
300 | # sound stream
301 |
302 | class Residual(nn.Module):
303 | def __init__(self, fn):
304 | super().__init__()
305 | self.fn = fn
306 |
307 | def forward(self, x, **kwargs):
308 | return self.fn(x, **kwargs) + x
309 |
310 | class CausalConv1d(nn.Module):
311 | def __init__(self, chan_in, chan_out, kernel_size, pad_mode = 'reflect', **kwargs):
312 | super().__init__()
313 | kernel_size = kernel_size
314 | dilation = kwargs.get('dilation', 1)
315 | stride = kwargs.get('stride', 1)
316 | self.pad_mode = pad_mode
317 | self.causal_padding = dilation * (kernel_size - 1) + (1 - stride)
318 |
319 | self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs)
320 |
321 | def forward(self, x):
322 | x = F.pad(x, (self.causal_padding, 0), mode = self.pad_mode)
323 | return self.conv(x)
324 |
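# editorial note (not in the original file): the left-only padding above keeps the
# convolution causal. with the default stride of 1 the padding is dilation * (kernel_size - 1),
# e.g. kernel_size = 7, dilation = 1 gives 6 samples of left padding, so output frame t never
# depends on inputs after t. illustrative shape check (assumed sizes):
# CausalConv1d(1, 16, 7)(torch.randn(1, 1, 320)).shape -> torch.Size([1, 16, 320])
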
325 | class CausalConvTranspose1d(nn.Module):
326 | def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs):
327 | super().__init__()
328 | self.upsample_factor = stride
329 | self.padding = kernel_size - 1
330 | self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs)
331 |
332 | def forward(self, x):
333 | n = x.shape[-1]
334 |
335 | out = self.conv(x)
336 | out = out[..., :(n * self.upsample_factor)]
337 |
338 | return out
339 |
340 | def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False, pad_mode = 'reflect'):
341 | return Residual(Sequential(
342 | CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation, pad_mode = pad_mode),
343 | nn.ELU(),
344 | CausalConv1d(chan_out, chan_out, 1, pad_mode = pad_mode),
345 | nn.ELU(),
346 | SqueezeExcite(chan_out) if squeeze_excite else None
347 | ))
348 |
349 | def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
350 | it = cycle(cycle_dilations)
351 | residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)
352 |
353 | return nn.Sequential(
354 | residual_unit(chan_in, chan_in, next(it)),
355 | residual_unit(chan_in, chan_in, next(it)),
356 | residual_unit(chan_in, chan_in, next(it)),
357 | CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride)
358 | )
359 |
360 | def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
361 | even_stride = (stride % 2 == 0)
362 | padding = (stride + (0 if even_stride else 1)) // 2
363 | output_padding = 0 if even_stride else 1
364 |
365 | residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)
366 |
367 | it = cycle(cycle_dilations)
368 | return nn.Sequential(
369 | CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride),
370 | residual_unit(chan_out, chan_out, next(it)),
371 | residual_unit(chan_out, chan_out, next(it)),
372 | residual_unit(chan_out, chan_out, next(it)),
373 | )
374 |
375 | class LocalTransformer(nn.Module):
376 | def __init__(
377 | self,
378 | *,
379 | dim,
380 | depth,
381 | heads,
382 | window_size,
383 | dynamic_pos_bias = False,
384 | **kwargs
385 | ):
386 | super().__init__()
387 | self.window_size = window_size
388 | self.layers = nn.ModuleList([])
389 |
390 | self.pos_bias = None
391 | if dynamic_pos_bias:
392 | self.pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads)
393 |
394 | for _ in range(depth):
395 | self.layers.append(nn.ModuleList([
396 | LocalMHA(dim = dim, heads = heads, qk_rmsnorm = True, window_size = window_size, use_rotary_pos_emb = not dynamic_pos_bias, use_xpos = True, **kwargs),
397 | FeedForward(dim = dim)
398 | ]))
399 |
400 | def forward(self, x):
401 | w = self.window_size
402 |
403 | attn_bias = self.pos_bias(w, w * 2) if exists(self.pos_bias) else None
404 |
405 | for attn, ff in self.layers:
406 | x = attn(x, attn_bias = attn_bias) + x
407 | x = ff(x) + x
408 |
409 | return x
410 |
411 | class FiLM(nn.Module):
412 | def __init__(self, dim, dim_cond):
413 | super().__init__()
414 | self.to_cond = nn.Linear(dim_cond, dim * 2)
415 |
416 | def forward(self, x, cond):
417 | gamma, beta = self.to_cond(cond).chunk(2, dim = -1)
418 | return x * gamma + beta
419 |
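# editorial note (not in the original file): FiLM predicts a per-channel scale (gamma) and
# shift (beta) from a conditioning vector. SoundStream below uses it with dim_cond = 2 and a
# condition of [1, 0] (denoising) or [0, 1] (not denoising), see the is_denoising flag in
# forward(). e.g. FiLM(dim = 512, dim_cond = 2)(x, cond) with x of shape (b, n, 512) and
# cond of shape (2,) broadcasts gamma and beta over batch and time.
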
420 | class SoundStream(nn.Module):
421 | def __init__(
422 | self,
423 | *,
424 | channels = 32,
425 | strides = (2, 4, 5, 8),
426 | channel_mults = (2, 4, 8, 16),
427 | codebook_dim = 512,
428 | codebook_size = 1024,
429 | rq_num_quantizers = 8,
430 | rq_commitment_weight = 1.,
431 | rq_ema_decay = 0.95,
432 | rq_quantize_dropout_multiple_of = 1,
433 | rq_groups = 1,
434 | rq_stochastic_sample_codes = False,
435 | rq_kwargs: dict = {},
436 | input_channels = 1,
437 | discr_multi_scales = (1, 0.5, 0.25),
438 | stft_normalized = False,
439 | enc_cycle_dilations = (1, 3, 9),
440 | dec_cycle_dilations = (1, 3, 9),
441 | multi_spectral_window_powers_of_two = tuple(range(6, 12)),
442 | multi_spectral_n_ffts = 512,
443 | multi_spectral_n_mels = 64,
444 | recon_loss_weight = 1.,
445 | multi_spectral_recon_loss_weight = 1e-5,
446 | adversarial_loss_weight = 1.,
447 | feature_loss_weight = 100,
448 | quantize_dropout_cutoff_index = 1,
449 | target_sample_hz = 16000,
450 | use_local_attn = True,
451 | attn_window_size = 128,
452 | attn_dim_head = 64,
453 | attn_heads = 8,
454 | attn_depth = 1,
455 | attn_xpos_scale_base = None,
456 | attn_dynamic_pos_bias = False,
457 | squeeze_excite = False,
458 | complex_stft_discr_logits_abs = True,
459 | pad_mode = 'reflect',
460 | stft_discriminator: Optional[nn.Module] = None # can pass in own stft discriminator
461 | ):
462 | super().__init__()
463 |
464 | # for autosaving the config
465 |
466 | _locals = locals()
467 | _locals.pop('self', None)
468 | _locals.pop('__class__', None)
469 | self._configs = pickle.dumps(_locals)
470 |
471 | # rest of the class
472 |
473 | self.target_sample_hz = target_sample_hz # for resampling on the fly
474 |
475 | self.single_channel = input_channels == 1
476 | self.strides = strides
477 |
478 | layer_channels = tuple(map(lambda t: t * channels, channel_mults))
479 | layer_channels = (channels, *layer_channels)
480 | chan_in_out_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))
481 |
482 | encoder_blocks = []
483 |
484 | for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides):
485 | encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite, pad_mode))
486 |
487 | self.encoder = nn.Sequential(
488 | CausalConv1d(input_channels, channels, 7, pad_mode = pad_mode),
489 | *encoder_blocks,
490 | CausalConv1d(layer_channels[-1], codebook_dim, 3, pad_mode = pad_mode)
491 | )
492 |
493 | attn_kwargs = dict(
494 | dim = codebook_dim,
495 | dim_head = attn_dim_head,
496 | heads = attn_heads,
497 | depth = attn_depth,
498 | window_size = attn_window_size,
499 | xpos_scale_base = attn_xpos_scale_base,
500 | dynamic_pos_bias = attn_dynamic_pos_bias,
501 | prenorm = True,
502 | causal = True
503 | )
504 |
505 | self.encoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None
506 |
507 | self.encoder_film = FiLM(codebook_dim, dim_cond = 2)
508 |
509 | self.num_quantizers = rq_num_quantizers
510 |
511 | self.codebook_dim = codebook_dim
512 | self.codebook_size = codebook_size
513 |
514 | self.rq_groups = rq_groups
515 |
516 | self.rq = GroupedResidualVQ(
517 | dim = codebook_dim,
518 | num_quantizers = rq_num_quantizers,
519 | codebook_size = codebook_size,
520 | groups = rq_groups,
521 | decay = rq_ema_decay,
522 | commitment_weight = rq_commitment_weight,
523 | quantize_dropout_multiple_of = rq_quantize_dropout_multiple_of,
524 | kmeans_init = True,
525 | threshold_ema_dead_code = 2,
526 | quantize_dropout = True,
527 | quantize_dropout_cutoff_index = quantize_dropout_cutoff_index,
528 | stochastic_sample_codes = rq_stochastic_sample_codes,
529 | **rq_kwargs
530 | )
531 |
532 | self.decoder_film = FiLM(codebook_dim, dim_cond = 2)
533 |
534 | self.decoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None
535 |
536 | decoder_blocks = []
537 |
538 | for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)):
539 | decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite, pad_mode))
540 |
541 | self.decoder = nn.Sequential(
542 | CausalConv1d(codebook_dim, layer_channels[-1], 7, pad_mode = pad_mode),
543 | *decoder_blocks,
544 | CausalConv1d(channels, input_channels, 7, pad_mode = pad_mode)
545 | )
546 |
547 | # discriminators
548 |
549 | self.discr_multi_scales = discr_multi_scales
550 | self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])
551 | discr_rel_factors = [int(s1 / s2) for s1, s2 in zip(discr_multi_scales[:-1], discr_multi_scales[1:])]
552 | self.downsamples = nn.ModuleList([nn.Identity()] + [nn.AvgPool1d(2 * factor, stride = factor, padding = factor) for factor in discr_rel_factors])
553 |
554 | self.stft_discriminator = stft_discriminator
555 |
556 | if not exists(self.stft_discriminator):
557 | self.stft_discriminator = ComplexSTFTDiscriminator(
558 | stft_normalized = stft_normalized,
559 | logits_abs = complex_stft_discr_logits_abs # whether to output as abs() or use view_as_real
560 | )
561 |
562 | # multi spectral reconstruction
563 |
564 | self.mel_spec_transforms = nn.ModuleList([])
565 | self.mel_spec_recon_alphas = []
566 |
567 | num_transforms = len(multi_spectral_window_powers_of_two)
568 | multi_spectral_n_ffts = cast_tuple(multi_spectral_n_ffts, num_transforms)
569 | multi_spectral_n_mels = cast_tuple(multi_spectral_n_mels, num_transforms)
570 |
571 | for powers, n_fft, n_mels in zip_longest(multi_spectral_window_powers_of_two, multi_spectral_n_ffts, multi_spectral_n_mels):
572 | win_length = 2 ** powers
573 | alpha = (win_length / 2) ** 0.5
574 |
575 | calculated_n_fft = default(max(n_fft, win_length), win_length) # @AndreyBocharnikov said this is usually win length, but overridable
576 |
577 | # if any audio experts have an opinion about these settings, please submit a PR
578 |
579 | melspec_transform = T.MelSpectrogram(
580 | sample_rate = target_sample_hz,
581 | n_fft = calculated_n_fft,
582 | win_length = win_length,
583 | hop_length = win_length // 4,
584 | n_mels = n_mels,
585 | normalized = stft_normalized
586 | )
587 |
588 | self.mel_spec_transforms.append(melspec_transform)
589 | self.mel_spec_recon_alphas.append(alpha)
590 |
591 | # loss weights
592 |
593 | self.recon_loss_weight = recon_loss_weight
594 | self.multi_spectral_recon_loss_weight = multi_spectral_recon_loss_weight
595 | self.adversarial_loss_weight = adversarial_loss_weight
596 | self.feature_loss_weight = feature_loss_weight
597 |
598 | self.register_buffer('zero', torch.tensor([0.]), persistent = False)
599 |
600 | @property
601 | def device(self):
602 | return next(self.parameters()).device
603 |
604 | @property
605 | def configs(self):
606 | return pickle.loads(self._configs)
607 |
608 | def decode_from_codebook_indices(self, quantized_indices):
609 | quantized_indices = rearrange(quantized_indices, 'b n (g q) -> g b n q', g = self.rq_groups)
610 |
611 | codes = self.rq.get_codes_from_indices(quantized_indices)
612 | x = reduce(codes, 'g q b n d -> b n (g d)', 'sum')
613 |
614 | return self.decode(x)
615 |
616 | def decode(self, x, quantize = False):
617 | if quantize:
618 | x, *_ = self.rq(x)
619 |
620 |         x = self.decoder_attn(x) if exists(self.decoder_attn) else x # guard when use_local_attn = False (mirrors forward)
621 | x = rearrange(x, 'b n c -> b c n')
622 | return self.decoder(x)
623 |
624 | def save(self, path):
625 | path = Path(path)
626 | pkg = dict(
627 | model = self.state_dict(),
628 | config = self._configs,
629 | version = __version__
630 | )
631 |
632 | torch.save(pkg, str(path))
633 |
634 | @classmethod
635 | def init_and_load_from(cls, path, strict = True):
636 | path = Path(path)
637 | assert path.exists()
638 | pkg = torch.load(str(path), map_location = 'cpu')
639 |
640 | assert 'config' in pkg, 'model configs were not found in this saved checkpoint'
641 |
642 | config = pickle.loads(pkg['config'])
643 | soundstream = cls(**config)
644 | soundstream.load(path, strict = strict)
645 | return soundstream
646 |
647 | def load(self, path, strict = True):
648 | path = Path(path)
649 | assert path.exists()
650 | pkg = torch.load(str(path), map_location = 'cpu')
651 |
652 | # check version
653 |
654 | if 'version' in pkg and version.parse(pkg['version']) < parsed_version:
655 | print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})')
656 |
657 | has_ema = 'ema_model' in pkg
658 | model_pkg = pkg['ema_model'] if has_ema else pkg['model']
659 |
660 | if has_ema:
661 | model_pkg = filter_by_keys(lambda k: k.startswith('ema_model.'), model_pkg)
662 | model_pkg = map_keys(lambda k: k[len('ema_model.'):], model_pkg)
663 |
664 | self.load_state_dict(model_pkg, strict = strict)
665 |
666 | def load_from_trainer_saved_obj(self, path):
667 | path = Path(path)
668 | assert path.exists()
669 | obj = torch.load(str(path))
670 | self.load_state_dict(obj['model'])
671 |
672 | def non_discr_parameters(self):
673 | return [
674 | *self.encoder.parameters(),
675 | *self.decoder.parameters(),
676 | *(self.encoder_attn.parameters() if exists(self.encoder_attn) else []),
677 | *(self.decoder_attn.parameters() if exists(self.decoder_attn) else []),
678 | *self.encoder_film.parameters(),
679 | *self.decoder_film.parameters()
680 | ]
681 |
682 | @property
683 | def seq_len_multiple_of(self):
684 | return functools.reduce(lambda x, y: x * y, self.strides)
685 |
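    # editorial note (not in the original file): with the default strides (2, 4, 5, 8) this
    # property is 2 * 4 * 5 * 8 = 320, i.e. one latent frame per 320 input samples. the demo
    # notebook's data_max_length = 320 * 32 therefore corresponds to 32 latent steps.
    # e.g. AudioLMSoundStream().seq_len_multiple_of -> 320
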
686 | def process_input(
687 | self,
688 | x,
689 | input_sample_hz = None,
690 | curtail_from_left = False
691 | ):
692 | x, ps = pack([x], '* n')
693 |
694 | if exists(input_sample_hz):
695 | x = resample(x, input_sample_hz, self.target_sample_hz)
696 |
697 | x = curtail_to_multiple(x, self.seq_len_multiple_of, from_left = curtail_from_left)
698 |
699 | if x.ndim == 2:
700 | x = rearrange(x, 'b n -> b 1 n')
701 |
702 | return x, ps
703 |
704 | def forward(
705 | self,
706 | x,
707 | target = None,
708 | is_denoising = None, # if you want to learn film conditioners that teach the soundstream to denoise - target would need to be passed in above
709 | return_encoded = False,
710 | return_discr_loss = False,
711 | return_discr_losses_separately = False,
712 | return_loss_breakdown = False,
713 | return_recons_only = False,
714 | input_sample_hz = None,
715 | apply_grad_penalty = False,
716 | curtail_from_left = False
717 | ):
718 | assert not (exists(is_denoising) and not exists(target))
719 |
720 | process_input = partial(self.process_input, input_sample_hz = input_sample_hz, curtail_from_left = curtail_from_left)
721 |
722 | x, ps = process_input(x)
723 |
724 | if exists(target):
725 | target, _ = process_input(target)
726 |
727 | orig_x = x.clone()
728 |
729 | x = self.encoder(x)
730 |
731 | x = rearrange(x, 'b c n -> b n c')
732 |
733 | if exists(self.encoder_attn):
734 | x = self.encoder_attn(x)
735 |
736 | if exists(is_denoising):
737 | denoise_input = torch.tensor([is_denoising, not is_denoising], dtype = x.dtype, device = self.device) # [1, 0] for denoise, [0, 1] for not denoising
738 | x = self.encoder_film(x, denoise_input)
739 |
740 | x, indices, commit_loss = self.rq(x)
741 |
742 | if return_encoded:
743 | indices = rearrange(indices, 'g b n q -> b n (g q)')
744 | return x, indices, commit_loss
745 |
746 | if exists(is_denoising):
747 | x = self.decoder_film(x, denoise_input)
748 |
749 | if exists(self.decoder_attn):
750 | x = self.decoder_attn(x)
751 |
752 | x = rearrange(x, 'b n c -> b c n')
753 |
754 | recon_x = self.decoder(x)
755 |
756 | if return_recons_only:
757 | recon_x, = unpack(recon_x, ps, '* c n')
758 | return recon_x
759 |
760 | # multi-scale discriminator loss
761 |
762 | if return_discr_loss:
763 | real, fake = orig_x, recon_x.detach()
764 |
765 | stft_discr_loss = None
766 | stft_grad_penalty = None
767 | discr_losses = []
768 | discr_grad_penalties = []
769 |
770 | if self.single_channel:
771 | real, fake = orig_x.clone(), recon_x.detach()
772 | stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real.requires_grad_(), fake))
773 | stft_discr_loss = hinge_discr_loss(stft_fake_logits, stft_real_logits)
774 |
775 | if apply_grad_penalty:
776 | stft_grad_penalty = gradient_penalty(real, stft_discr_loss)
777 |
778 | scaled_real, scaled_fake = real, fake
779 | for discr, downsample in zip(self.discriminators, self.downsamples):
780 | scaled_real, scaled_fake = map(downsample, (scaled_real, scaled_fake))
781 |
782 | real_logits, fake_logits = map(discr, (scaled_real.requires_grad_(), scaled_fake))
783 | one_discr_loss = hinge_discr_loss(fake_logits, real_logits)
784 |
785 | discr_losses.append(one_discr_loss)
786 | if apply_grad_penalty:
787 | discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss))
788 |
789 | if not return_discr_losses_separately:
790 | all_discr_losses = torch.stack(discr_losses).mean()
791 |
792 | if exists(stft_discr_loss):
793 | all_discr_losses = all_discr_losses + stft_discr_loss
794 |
795 | if exists(stft_grad_penalty):
796 | all_discr_losses = all_discr_losses + stft_grad_penalty
797 |
798 | return all_discr_losses
799 |
800 | # return a list of discriminator losses with List[Tuple[str, Tensor]]
801 |
802 | discr_losses_pkg = []
803 |
804 | discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)])
805 |
806 | discr_losses_pkg.extend([(f'scale_grad_penalty:{scale}', discr_grad_penalty) for scale, discr_grad_penalty in zip(self.discr_multi_scales, discr_grad_penalties)])
807 |
808 | if exists(stft_discr_loss):
809 | discr_losses_pkg.append(('stft', stft_discr_loss))
810 |
811 | if exists(stft_grad_penalty):
812 | discr_losses_pkg.append(('stft_grad_penalty', stft_grad_penalty))
813 |
814 | return discr_losses_pkg
815 |
816 | # recon loss
817 |
818 | target = default(target, orig_x) # target can also be passed in, in the case of denoising
819 |
820 | recon_loss = F.mse_loss(target, recon_x)
821 |
822 | # multispectral recon loss - eq (4) and (5) in https://arxiv.org/abs/2107.03312
823 |
824 | multi_spectral_recon_loss = self.zero
825 |
826 | if self.multi_spectral_recon_loss_weight > 0:
827 | for mel_transform, alpha in zip(self.mel_spec_transforms, self.mel_spec_recon_alphas):
828 | orig_mel, recon_mel = map(mel_transform, (orig_x, recon_x))
829 | log_orig_mel, log_recon_mel = map(log, (orig_mel, recon_mel))
830 |
831 | l1_mel_loss = (orig_mel - recon_mel).abs().sum(dim = -2).mean()
832 | l2_log_mel_loss = alpha * vector_norm(log_orig_mel - log_recon_mel, dim = -2).mean()
833 |
834 | multi_spectral_recon_loss = multi_spectral_recon_loss + l1_mel_loss + l2_log_mel_loss
835 |
836 | # adversarial loss
837 |
838 | adversarial_losses = []
839 |
840 | discr_intermediates = []
841 |
842 | # adversarial loss for multi-scale discriminators
843 |
844 | real, fake = orig_x, recon_x
845 |
846 | # features from stft
847 |
848 | (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake))
849 | discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates))
850 |
851 | scaled_real, scaled_fake = real, fake
852 | for discr, downsample in zip(self.discriminators, self.downsamples):
853 | scaled_real, scaled_fake = map(downsample, (scaled_real, scaled_fake))
854 |
855 | (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake))
856 |
857 | discr_intermediates.append((real_intermediates, fake_intermediates))
858 |
859 | one_adversarial_loss = hinge_gen_loss(fake_logits)
860 | adversarial_losses.append(one_adversarial_loss)
861 |
862 | feature_losses = []
863 |
864 | for real_intermediates, fake_intermediates in discr_intermediates:
865 | losses = [F.l1_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)]
866 | feature_losses.extend(losses)
867 |
868 | feature_loss = torch.stack(feature_losses).mean()
869 |
870 | # adversarial loss for stft discriminator
871 |
872 | adversarial_losses.append(hinge_gen_loss(stft_fake_logits))
873 | adversarial_loss = torch.stack(adversarial_losses).mean()
874 |
875 | # sum commitment loss
876 |
877 | all_commitment_loss = commit_loss.sum()
878 |
879 | total_loss = recon_loss * self.recon_loss_weight + multi_spectral_recon_loss * self.multi_spectral_recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss
880 |
881 | if return_loss_breakdown:
882 | return total_loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss)
883 |
884 | return total_loss
885 |
886 | # some default soundstreams
887 |
888 | def AudioLMSoundStream(
889 | strides = (2, 4, 5, 8),
890 | target_sample_hz = 16000,
891 | rq_num_quantizers = 12,
892 | **kwargs
893 | ):
894 | return SoundStream(
895 | strides = strides,
896 | target_sample_hz = target_sample_hz,
897 | rq_num_quantizers = rq_num_quantizers,
898 | **kwargs
899 | )
900 |
901 | def MusicLMSoundStream(
902 | strides = (3, 4, 5, 8),
903 | target_sample_hz = 24000,
904 | rq_num_quantizers = 12,
905 | **kwargs
906 | ):
907 | return SoundStream(
908 | strides = strides,
909 | target_sample_hz = target_sample_hz,
910 | rq_num_quantizers = rq_num_quantizers,
911 | **kwargs
912 | )
913 |
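# --- quick usage sketch (editorial addition, not part of the original module) ---
# a minimal smoke test of the calls used elsewhere in this repo; the dummy tensor
# sizes are assumptions, the keyword arguments come from SoundStream.forward above

if __name__ == '__main__':
    soundstream = AudioLMSoundStream()                           # 16 kHz codec, 12 residual quantizers
    audio = torch.randn(1, 320 * 32)                             # length is a multiple of the total stride (320)

    recons = soundstream(audio, return_recons_only = True)       # reconstructed waveform
    emb, indices, _ = soundstream(audio, return_encoded = True)  # continuous latents + codebook indices
    print(recons.shape, emb.shape, indices.shape)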
--------------------------------------------------------------------------------
/audiolm_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | import re
2 | from math import sqrt
3 | import copy
4 | from random import choice
5 | from pathlib import Path
6 | from shutil import rmtree
7 |
8 | from beartype.typing import Union, List, Optional, Tuple
9 | from typing_extensions import Annotated
10 |
11 | from beartype import beartype
12 | from beartype.door import is_bearable
13 | from beartype.vale import Is
14 |
15 | import torch
16 | import torchaudio
17 | from torch import nn
18 | from torch.utils.data import Dataset, DataLoader, random_split
19 |
20 | from einops import rearrange
21 |
22 | from audiolm_pytorch.optimizer import get_optimizer
23 |
24 | from ema_pytorch import EMA
25 |
26 | from audiolm_pytorch.soundstream import SoundStream
27 | from audiolm_pytorch.encodec import EncodecWrapper
28 |
29 | from audiolm_pytorch.audiolm_pytorch import (
30 | SemanticTransformer,
31 | SemanticTransformerWrapper,
32 | CoarseTransformer,
33 | CoarseTransformerWrapper,
34 | FineTransformer,
35 | FineTransformerWrapper,
36 | FairseqVQWav2Vec,
37 | HubertWithKmeans
38 | )
39 |
40 | from audiolm_pytorch.data import SoundDataset, get_dataloader
41 | from audiolm_pytorch.utils import AudioConditionerBase
42 |
43 | from audiolm_pytorch.version import __version__
44 | from packaging import version
45 |
46 | from accelerate import Accelerator
47 | from accelerate.utils import DistributedDataParallelKwargs
48 |
49 | # constants
50 |
51 | DEFAULT_SAMPLE_RATE = 16000
52 |
53 | # for automatically routing data emitted from a dataset to keywords of the transformer wrappers
54 |
55 | DATASET_FIELD_TYPE_CONFIG = dict(
56 | raw_wave = Annotated[
57 | torch.Tensor,
58 | Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}]
59 | ],
60 | text = List[str],
61 | text_embeds = Annotated[
62 | torch.Tensor,
63 | Is[lambda t: t.dtype == torch.float and t.ndim == 3]
64 | ],
65 | )
66 |
67 | # helpers
68 |
69 | def exists(val):
70 | return val is not None
71 |
72 | def noop(*args, **kwargs):
73 | pass
74 |
75 | def cycle(dl):
76 | while True:
77 | for data in dl:
78 | yield data
79 |
80 | def cast_tuple(t):
81 | return t if isinstance(t, (tuple, list)) else (t,)
82 |
83 | def yes_or_no(question):
84 | answer = input(f'{question} (y/n) ')
85 | return answer.lower() in ('yes', 'y')
86 |
87 | def accum_log(log, new_logs):
88 | for key, new_value in new_logs.items():
89 | old_value = log.get(key, 0.)
90 | log[key] = old_value + new_value
91 | return log
92 |
93 | # auto data to module keyword argument routing functions
94 |
95 | def has_duplicates(tup):
96 | counts = dict()
97 | for el in tup:
98 | if el not in counts:
99 | counts[el] = 0
100 | counts[el] += 1
101 | return any(filter(lambda count: count > 1, counts.values()))
102 |
103 | def determine_types(data, config):
104 | output = []
105 | for el in data:
106 | for name, data_type in config.items():
107 | if is_bearable(el, data_type):
108 | output.append(name)
109 | break
110 | else:
111 | raise TypeError(f'unable to determine type of {data}')
112 |
113 | return tuple(output)
114 |
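# editorial note (not in the original file): given DATASET_FIELD_TYPE_CONFIG above, a batch of
# (float waveform tensor, list of strings) resolves to ('raw_wave', 'text'), and those fields
# are then routed to the transformer wrappers under exactly those keyword names.
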
115 | def checkpoint_num_steps(checkpoint_path):
116 | """Returns the number of steps trained from a checkpoint based on the filename.
117 |
118 | Filename format assumed to be something like "/path/to/semantic.transformer.20000.pt" which is
119 | for 20k train steps. Returns 20000 in that case.
120 | """
121 | return int(re.findall(r'\d+', str(checkpoint_path))[-1])
122 |
123 | # main trainer class
124 |
125 | class SoundStreamTrainer(nn.Module):
126 | @beartype
127 | def __init__(
128 | self,
129 | soundstream: SoundStream,
130 | *,
131 | num_train_steps: int,
132 | batch_size: int,
133 | data_max_length: int = None,
134 | data_max_length_seconds: Union[int, float] = None,
135 | folder: str = None,
136 | train_dataloader: DataLoader = None,
137 | val_dataloader: DataLoader = None,
138 | lr: float = 2e-4,
139 | grad_accum_every: int = 4,
140 | wd: float = 0.,
141 | max_grad_norm: float = 0.5,
142 | discr_max_grad_norm: float = None,
143 | save_results_every: int = 100,
144 |         save_model_every: int = 1000,
145 |         log_losses_every: int = 1,
146 | results_folder: str = './results',
147 | valid_frac: float = 0.05,
148 | random_split_seed: int = 42,
149 | use_ema: bool = True,
150 | ema_beta: float = 0.995,
151 | ema_update_after_step: int = 500,
152 | ema_update_every: int = 10,
153 | apply_grad_penalty_every: int = 4,
154 | dl_num_workers: int = 0,
155 | accelerator: Accelerator = None,
156 | accelerate_kwargs: dict = dict(),
157 | use_lion: bool = False,
158 | force_clear_prev_results: bool = None # set to True | False to skip the prompt
159 | ):
160 | """
161 | Initialize with a SoundStream instance and either a folder containing audio data or
162 | train/val DataLoader instances.
163 | """
164 | super().__init__()
165 |
166 | if accelerator:
167 | self.accelerator = accelerator
168 | assert len(accelerate_kwargs) == 0
169 | else:
170 | kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
171 | self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs)
172 |
173 | self.soundstream = soundstream
174 |
175 | self.use_ema = use_ema
176 | if self.use_ema:
177 | self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every)
178 |
179 | self.register_buffer('steps', torch.Tensor([0]))
180 |
181 | self.num_train_steps = num_train_steps
182 | self.batch_size = batch_size
183 | self.grad_accum_every = grad_accum_every
184 |
185 | hyperparameters = {
186 | "num_train_steps": num_train_steps,
187 | "batch_size": batch_size,
188 | "gradient_accum_every": grad_accum_every,
189 | "learning_rate": lr,
190 | "target_sample_hz": soundstream.target_sample_hz,
191 | }
192 |
193 | # optimizers
194 |
195 | self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd)
196 |
197 | for discr_optimizer_key, discr in self.multiscale_discriminator_iter():
198 | one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd)
199 | setattr(self, discr_optimizer_key, one_multiscale_discr_optimizer)
200 |
201 | self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd, use_lion = use_lion)
202 |
203 | # max grad norm
204 |
205 | self.max_grad_norm = max_grad_norm
206 | self.discr_max_grad_norm = discr_max_grad_norm
207 |
208 | if folder is None:
209 | assert train_dataloader is not None
210 | assert val_dataloader is not None
211 | self.dl = train_dataloader
212 | self.valid_dl = val_dataloader
213 | else:
214 | assert train_dataloader is None
215 | assert val_dataloader is None
216 |
217 | # create dataset
218 |
219 | if exists(data_max_length_seconds):
220 | assert not exists(data_max_length)
221 | data_max_length = int(data_max_length_seconds * soundstream.target_sample_hz)
222 | else:
223 | assert exists(data_max_length)
224 |
225 | hyperparameters['data_max_length'] = data_max_length
226 |
227 | self.ds = SoundDataset(
228 | folder,
229 | max_length = data_max_length,
230 | target_sample_hz = soundstream.target_sample_hz,
231 | seq_len_multiple_of = soundstream.seq_len_multiple_of
232 | )
233 |
234 | # split for validation
235 |
236 | if valid_frac > 0:
237 | train_size = int((1 - valid_frac) * len(self.ds))
238 | valid_size = len(self.ds) - train_size
239 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
240 |             self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
241 | else:
242 | self.valid_ds = self.ds
243 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')
244 |
245 |             # dataloader
246 | 
247 |             self.dl = get_dataloader(self.ds, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True)
248 | 
249 |             self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True)
250 |
251 | # prepare with accelerator
252 |
253 | (
254 | self.soundstream,
255 | self.optim,
256 | self.discr_optim,
257 | self.dl
258 | ) = self.accelerator.prepare(
259 | self.soundstream,
260 | self.optim,
261 | self.discr_optim,
262 | self.dl
263 | )
264 |
265 |         # prepare the multiscale discriminator optimizers with accelerator
266 |
267 | for name, _ in self.multiscale_discriminator_iter():
268 | optimizer = getattr(self, name)
269 | optimizer = self.accelerator.prepare(optimizer)
270 | setattr(self, name, optimizer)
271 |
272 | # dataloader iterators
273 |
274 | self.dl_iter = cycle(self.dl)
275 | self.valid_dl_iter = cycle(self.valid_dl)
276 |
277 | self.save_model_every = save_model_every
278 | self.save_results_every = save_results_every
279 | self.log_losses_every = log_losses_every
280 |
281 | self.apply_grad_penalty_every = apply_grad_penalty_every
282 |
283 | self.results_folder = Path(results_folder)
284 |
285 |         if self.is_main and (force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'))):
286 | rmtree(str(self.results_folder))
287 |
288 | self.results_folder.mkdir(parents = True, exist_ok = True)
289 |
290 | # Initialize experiment trackers if an external Accelerator is not passed in
291 | if not accelerator:
292 | self.accelerator.init_trackers("soundstream", config=hyperparameters)
293 |
294 | def set_model_as_ema_model_(self):
295 | """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """
296 | assert self.use_ema
297 | self.ema_soundstream.ema_model.load_state_dict(self.soundstream.state_dict())
298 |
299 | def save(self, path):
300 | pkg = dict(
301 | model = self.accelerator.get_state_dict(self.soundstream),
302 | optim = self.optim.state_dict(),
303 | config = self.unwrapped_soundstream._configs,
304 | discr_optim = self.discr_optim.state_dict(),
305 | version = __version__
306 | )
307 |
308 | if self.use_ema:
309 | pkg['ema_model'] = self.ema_soundstream.state_dict()
310 |
311 | for key, _ in self.multiscale_discriminator_iter():
312 | discr_optim = getattr(self, key)
313 | pkg[key] = discr_optim.state_dict()
314 |
315 | torch.save(pkg, path)
316 |
317 | @property
318 | def unwrapped_soundstream(self):
319 | return self.accelerator.unwrap_model(self.soundstream)
320 |
321 | def load(self, path):
322 | path = Path(path)
323 | assert path.exists()
324 | pkg = torch.load(str(path), map_location = 'cpu')
325 |
326 | # if loading from old version, make a hacky guess
327 |
328 | if len(pkg.keys()) > 20:
329 | self.unwrapped_soundstream.load_state_dict(pkg)
330 |
331 | if self.use_ema:
332 | self.ema_soundstream.ema_model.load_state_dict(pkg)
333 | return
334 |
335 | # check version
336 |
337 | if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
338 | print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')
339 |
340 | # otherwise load things normally
341 |
342 | self.unwrapped_soundstream.load_state_dict(pkg['model'])
343 |
344 | if self.use_ema:
345 | assert 'ema_model' in pkg
346 | self.ema_soundstream.load_state_dict(pkg['ema_model'])
347 |
348 | self.optim.load_state_dict(pkg['optim'])
349 | self.discr_optim.load_state_dict(pkg['discr_optim'])
350 |
351 | for key, _ in self.multiscale_discriminator_iter():
352 | discr_optim = getattr(self, key)
353 | discr_optim.load_state_dict(pkg[key])
354 | # + 1 to start from the next step and avoid overwriting the last checkpoint
355 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
356 |
357 | def multiscale_discriminator_iter(self):
358 | for ind, discr in enumerate(self.unwrapped_soundstream.discriminators):
359 | yield f'multiscale_discr_optimizer_{ind}', discr
360 |
361 | def multiscale_discriminator_optim_iter(self):
362 | for name, _ in self.multiscale_discriminator_iter():
363 | yield name, getattr(self, name)
364 |
365 | def print(self, msg):
366 | self.accelerator.print(msg)
367 |
368 | @property
369 | def device(self):
370 | return self.accelerator.device
371 |
372 | @property
373 | def is_distributed(self):
374 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
375 |
376 | @property
377 | def is_main(self):
378 | return self.accelerator.is_main_process
379 |
380 | @property
381 | def is_local_main(self):
382 | return self.accelerator.is_local_main_process
383 |
384 | def train_step(self):
385 | device = self.device
386 |
387 | steps = int(self.steps.item())
388 | apply_grad_penalty = self.apply_grad_penalty_every > 0 and not (steps % self.apply_grad_penalty_every)
389 | log_losses = self.log_losses_every > 0 and not (steps % self.log_losses_every)
390 |
391 | self.soundstream.train()
392 |
393 | # logs
394 |
395 | logs = {}
396 |
397 |         # update generator (soundstream)
398 |
399 | for _ in range(self.grad_accum_every):
400 | wave, = next(self.dl_iter)
401 | wave = wave.to(device)
402 |
403 | loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True)
404 |
405 | self.accelerator.backward(loss / self.grad_accum_every)
406 |
407 | accum_log(logs, dict(
408 | loss = loss.item() / self.grad_accum_every,
409 | recon_loss = recon_loss.item() / self.grad_accum_every,
410 | ))
411 |
412 | if log_losses:
413 | accum_log(logs, dict(
414 | multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every,
415 | adversarial_loss = adversarial_loss.item() / self.grad_accum_every,
416 | feature_loss = feature_loss.item() / self.grad_accum_every,
417 | all_commitment_loss = all_commitment_loss.item() / self.grad_accum_every,
418 | ))
419 |
420 | if exists(self.max_grad_norm):
421 | self.accelerator.clip_grad_norm_(self.soundstream.parameters(), self.max_grad_norm)
422 |
423 | self.optim.step()
424 | self.optim.zero_grad()
425 |
426 | # update discriminator
427 |
428 | self.discr_optim.zero_grad()
429 |
430 | for name, multiscale_discr_optim in self.multiscale_discriminator_optim_iter():
431 | multiscale_discr_optim.zero_grad()
432 |
433 | for _ in range(self.grad_accum_every):
434 | wave, = next(self.dl_iter)
435 | wave = wave.to(device)
436 |
437 | discr_losses = self.soundstream(
438 | wave,
439 | apply_grad_penalty = apply_grad_penalty,
440 | return_discr_loss = True,
441 | return_discr_losses_separately = True
442 | )
443 |
444 | for name, discr_loss in discr_losses:
445 | self.accelerator.backward(discr_loss / self.grad_accum_every, retain_graph = True)
446 | accum_log(logs, {name: discr_loss.item() / self.grad_accum_every})
447 |
448 | if exists(self.discr_max_grad_norm):
449 | self.accelerator.clip_grad_norm_(self.soundstream.stft_discriminator.parameters(), self.discr_max_grad_norm)
450 |
451 | # gradient step for all discriminators
452 |
453 | self.discr_optim.step()
454 |
455 | for name, multiscale_discr_optim in self.multiscale_discriminator_optim_iter():
456 | multiscale_discr_optim.step()
457 |
458 | # build pretty printed losses
459 |
460 | losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}"
461 | if log_losses:
462 | self.accelerator.log({
463 | "total_loss": logs['loss'],
464 | "recon_loss": logs['recon_loss'],
465 | "multi_spectral_recon_loss": logs['multi_spectral_recon_loss'],
466 | "adversarial_loss": logs['adversarial_loss'],
467 | "feature_loss": logs['feature_loss'],
468 | "all_commitment_loss": logs['all_commitment_loss'],
469 | "stft_discr_loss": logs['stft']
470 | }, step=steps)
471 |
472 | for key, loss in logs.items():
473 | if not key.startswith('scale:'):
474 | continue
475 | _, scale_factor = key.split(':')
476 |
477 | losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}"
478 | if log_losses:
479 | self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps)
480 |
481 | # log
482 |
483 | self.print(losses_str)
484 |
485 | # update exponential moving averaged generator
486 |
487 | self.accelerator.wait_for_everyone()
488 |
489 | if self.is_main and self.use_ema:
490 | self.ema_soundstream.update()
491 |
492 | # sample results every so often
493 |
494 | self.accelerator.wait_for_everyone()
495 |
496 | if self.is_main and not (steps % self.save_results_every):
497 | models = [(self.unwrapped_soundstream, str(steps))]
498 | if self.use_ema:
499 |                 models.append((self.ema_soundstream.ema_model, f'{steps}.ema'))
500 |
501 | wave, = next(self.valid_dl_iter)
502 | wave = wave.to(device)
503 |
504 | for model, label in models:
505 | model.eval()
506 |
507 | with torch.no_grad():
508 | recons = model(wave, return_recons_only = True)
509 |
510 | for ind, recon in enumerate(recons.unbind(dim = 0)):
511 |                 filename = str(self.results_folder / f'sample_{label}.{ind}.flac')
512 | torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz)
513 |
514 | self.print(f'{steps}: saving to {str(self.results_folder)}')
515 |
516 | # save model every so often
517 |
518 | self.accelerator.wait_for_everyone()
519 |
520 | if self.is_main and not (steps % self.save_model_every):
521 | model_path = str(self.results_folder / f'soundstream.{steps}.pt')
522 | self.save(model_path)
523 |
524 | self.print(f'{steps}: saving model to {str(self.results_folder)}')
525 |
526 | self.steps += 1
527 | return logs
528 |
529 | def train(self, log_fn = noop):
530 |
531 | while self.steps < self.num_train_steps:
532 | logs = self.train_step()
533 | log_fn(logs)
534 |
535 | self.print('training complete')
536 |
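# minimal usage sketch for SoundStreamTrainer (the soundstream instance, folder path
# and hyperparameter values are placeholders, not defaults from this file):
#
#     trainer = SoundStreamTrainer(
#         soundstream,                        # a SoundStream instance
#         folder = '/path/to/audio',          # or pass train_dataloader / val_dataloader instead
#         batch_size = 4,
#         grad_accum_every = 8,
#         data_max_length_seconds = 2,
#         num_train_steps = 1_000_000
#     )
#
#     trainer.train()
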
537 | # semantic transformer trainer
538 |
539 | class SemanticTransformerTrainer(nn.Module):
540 | @beartype
541 | def __init__(
542 | self,
543 | wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
544 | transformer: SemanticTransformer,
545 | *,
546 | num_train_steps,
547 | batch_size,
548 | audio_conditioner: Optional[AudioConditionerBase] = None,
549 | dataset: Optional[Dataset] = None,
550 | data_max_length = None,
551 | data_max_length_seconds = None,
552 | folder = None,
553 | lr = 3e-4,
554 | grad_accum_every = 1,
555 | wd = 0.,
556 | max_grad_norm = 0.5,
557 | valid_frac = 0.05,
558 | random_split_seed = 42,
559 | save_results_every = 100,
560 | save_model_every = 1000,
561 | results_folder = './results',
562 | accelerate_kwargs: dict = dict(),
563 | force_clear_prev_results = None
564 | ):
565 | super().__init__()
566 | self.accelerator = Accelerator(**accelerate_kwargs)
567 |
568 | self.wav2vec = wav2vec
569 | self.transformer = transformer
570 | self.audio_conditioner = audio_conditioner
571 |
572 | self.train_wrapper = SemanticTransformerWrapper(
573 | wav2vec = wav2vec,
574 | transformer = transformer,
575 | audio_conditioner = audio_conditioner
576 | )
577 |
578 | self.register_buffer('steps', torch.Tensor([0]))
579 |
580 | self.num_train_steps = num_train_steps
581 | self.batch_size = batch_size
582 | self.grad_accum_every = grad_accum_every
583 |
584 | # optimizers
585 |
586 | self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd)
587 |
588 | # max grad norm
589 |
590 | self.max_grad_norm = max_grad_norm
591 |
592 | # create dataset
593 |
594 | self.ds = dataset
595 | if not exists(self.ds):
596 |             assert exists(folder), 'folder must be passed in if not passing in a custom dataset for text conditioned audio synthesis training'
597 |
598 | assert not (exists(data_max_length) and exists(data_max_length_seconds))
599 |
600 | if exists(data_max_length_seconds):
601 |                 data_max_length = int(data_max_length_seconds * wav2vec.target_sample_hz)
602 |
603 | self.ds = SoundDataset(
604 | folder,
605 | max_length = data_max_length,
606 | target_sample_hz = wav2vec.target_sample_hz,
607 | seq_len_multiple_of = wav2vec.seq_len_multiple_of
608 | )
609 |
610 | self.ds_fields = None
611 |
612 | # split for validation
613 |
614 | if valid_frac > 0:
615 | train_size = int((1 - valid_frac) * len(self.ds))
616 | valid_size = len(self.ds) - train_size
617 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
618 |             self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
619 | else:
620 | self.valid_ds = self.ds
621 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')
622 |
623 | # dataloader
624 |
625 | self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True)
626 |
627 | self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True)
628 |
629 | # prepare with accelerator
630 |
631 | (
632 | self.train_wrapper,
633 | self.optim,
634 | self.dl,
635 | self.valid_dl
636 | ) = self.accelerator.prepare(
637 | self.train_wrapper,
638 | self.optim,
639 | self.dl,
640 | self.valid_dl
641 | )
642 |
643 | # dataloader iterators
644 |
645 | self.dl_iter = cycle(self.dl)
646 | self.valid_dl_iter = cycle(self.valid_dl)
647 |
648 | self.save_model_every = save_model_every
649 | self.save_results_every = save_results_every
650 |
651 | self.results_folder = Path(results_folder)
652 |
653 |         if self.is_main and (force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'))):
654 | rmtree(str(self.results_folder))
655 |
656 | self.results_folder.mkdir(parents = True, exist_ok = True)
657 |
658 | hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
659 | self.accelerator.init_trackers("semantic", config=hps)
660 |
661 | def save(self, path):
662 | pkg = dict(
663 | model = self.accelerator.get_state_dict(self.transformer),
664 | optim = self.optim.state_dict(),
665 | version = __version__
666 | )
667 | torch.save(pkg, path)
668 |
669 | def load(self, path):
670 | path = Path(path)
671 | assert path.exists()
672 | pkg = torch.load(str(path), map_location = 'cpu')
673 |
674 | # check version
675 |
676 | if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
677 | print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')
678 |
679 | transformer = self.accelerator.unwrap_model(self.transformer)
680 | transformer.load_state_dict(pkg['model'])
681 | self.optim.load_state_dict(pkg['optim'])
682 | # + 1 to start from the next step and avoid overwriting the last checkpoint
683 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
684 |
685 |
686 | def print(self, msg):
687 | self.accelerator.print(msg)
688 |
689 | def generate(self, *args, **kwargs):
690 | return self.train_wrapper.generate(*args, **kwargs)
691 |
692 | @property
693 | def device(self):
694 | return self.accelerator.device
695 |
696 | @property
697 | def is_distributed(self):
698 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
699 |
700 | @property
701 | def is_main(self):
702 | return self.accelerator.is_main_process
703 |
704 | @property
705 | def is_local_main(self):
706 | return self.accelerator.is_local_main_process
707 |
708 | def data_tuple_to_kwargs(self, data):
709 | if not exists(self.ds_fields):
710 | self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
711 | assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'
712 |
713 | return dict(zip(self.ds_fields, data))
714 |
715 | def train_step(self):
716 | device = self.device
717 |
718 | steps = int(self.steps.item())
719 |
720 | self.transformer.train()
721 |
722 | # logs
723 |
724 | logs = {}
725 |
726 |         # update transformer
727 |
728 | for _ in range(self.grad_accum_every):
729 | data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
730 |
731 | loss = self.train_wrapper(**data_kwargs, return_loss = True)
732 |
733 | self.accelerator.backward(loss / self.grad_accum_every)
734 |
735 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
736 |
737 | if exists(self.max_grad_norm):
738 | self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)
739 |
740 | self.optim.step()
741 | self.optim.zero_grad()
742 |
743 | # log
744 |
745 | self.print(f"{steps}: loss: {logs['loss']}")
746 | self.accelerator.log({"train_loss": logs['loss']}, step=steps)
747 |
748 | # sample results every so often
749 |
750 | if self.is_main and not (steps % self.save_results_every):
751 | data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
752 |
753 | with torch.no_grad():
754 | self.train_wrapper.eval()
755 | valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)
756 |
757 | self.print(f'{steps}: valid loss {valid_loss}')
758 | self.accelerator.log({"valid_loss": valid_loss}, step=steps)
759 |
760 | # save model every so often
761 |
762 | if self.is_main and not (steps % self.save_model_every):
763 | model_path = str(self.results_folder / f'semantic.transformer.{steps}.pt')
764 | self.save(model_path)
765 |
766 | self.print(f'{steps}: saving model to {str(self.results_folder)}')
767 |
768 | self.steps += 1
769 | return logs
770 |
771 | def train(self, log_fn = noop):
772 |
773 | while self.steps < self.num_train_steps:
774 | logs = self.train_step()
775 | log_fn(logs)
776 |
777 | self.print('training complete')
778 |
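# minimal usage sketch for SemanticTransformerTrainer (the wav2vec / transformer
# instances, folder path and hyperparameter values are placeholders):
#
#     trainer = SemanticTransformerTrainer(
#         wav2vec = wav2vec,                  # e.g. a HubertWithKmeans instance
#         transformer = semantic_transformer,
#         folder = '/path/to/audio',
#         batch_size = 1,
#         data_max_length_seconds = 30,
#         num_train_steps = 10_000
#     )
#
#     trainer.train()
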
779 | # coarse transformer trainer
780 |
781 | class CoarseTransformerTrainer(nn.Module):
782 | @beartype
783 | def __init__(
784 | self,
785 | transformer: CoarseTransformer,
786 | codec: Union[SoundStream, EncodecWrapper],
787 | wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
788 | *,
789 | num_train_steps,
790 | batch_size,
791 | audio_conditioner: Optional[AudioConditionerBase] = None,
792 | dataset: Optional[Dataset] = None,
793 | ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'),
794 | data_max_length = None,
795 | data_max_length_seconds = None,
796 | folder = None,
797 | lr = 3e-4,
798 | grad_accum_every = 1,
799 | wd = 0.,
800 | max_grad_norm = 0.5,
801 | valid_frac = 0.05,
802 | random_split_seed = 42,
803 | save_results_every = 100,
804 | save_model_every = 1000,
805 | results_folder = './results',
806 | accelerate_kwargs: dict = dict(),
807 | force_clear_prev_results = None
808 | ):
809 | super().__init__()
810 | self.accelerator = Accelerator(**accelerate_kwargs)
811 |
812 | self.transformer = transformer
813 | self.codec = codec
814 | self.wav2vec = wav2vec
815 | self.audio_conditioner = audio_conditioner
816 |
817 | self.train_wrapper = CoarseTransformerWrapper(
818 | codec = codec,
819 | wav2vec = wav2vec,
820 | transformer = transformer,
821 | audio_conditioner = audio_conditioner
822 | )
823 |
824 | self.register_buffer('steps', torch.Tensor([0]))
825 |
826 | self.num_train_steps = num_train_steps
827 | self.batch_size = batch_size
828 | self.grad_accum_every = grad_accum_every
829 |
830 | # optimizers
831 |
832 | self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd)
833 |
834 | # max grad norm
835 |
836 | self.max_grad_norm = max_grad_norm
837 |
838 | # create dataset
839 |
840 | self.ds = dataset
841 |
842 | if not exists(self.ds):
843 |             assert exists(folder), 'folder must be passed in if not passing in a custom dataset for text conditioned audio synthesis training'
844 |
845 | assert not (exists(data_max_length) and exists(data_max_length_seconds))
846 |
847 | if exists(data_max_length_seconds):
848 |                 data_max_length = tuple(int(data_max_length_seconds * hz) for hz in (wav2vec.target_sample_hz, codec.target_sample_hz))
849 |
850 | self.ds = SoundDataset(
851 | folder,
852 | max_length = data_max_length,
853 | target_sample_hz = (
854 | wav2vec.target_sample_hz,
855 | codec.target_sample_hz
856 | ), # need 2 waves resampled differently here
857 | seq_len_multiple_of = codec.seq_len_multiple_of
858 | )
859 |
860 | self.ds_fields = ds_fields
861 |
862 | # split for validation
863 |
864 | if valid_frac > 0:
865 | train_size = int((1 - valid_frac) * len(self.ds))
866 | valid_size = len(self.ds) - train_size
867 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
868 |             self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
869 | else:
870 | self.valid_ds = self.ds
871 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')
872 |
873 | # dataloader
874 |
875 | self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True)
876 |
877 | self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True)
878 |
879 | # prepare with accelerator
880 |
881 | (
882 | self.transformer,
883 | self.optim,
884 | self.dl,
885 | self.valid_dl
886 | ) = self.accelerator.prepare(
887 | self.transformer,
888 | self.optim,
889 | self.dl,
890 | self.valid_dl
891 | )
892 |
893 | # dataloader iterators
894 |
895 | self.dl_iter = cycle(self.dl)
896 | self.valid_dl_iter = cycle(self.valid_dl)
897 |
898 | self.save_model_every = save_model_every
899 | self.save_results_every = save_results_every
900 |
901 | self.results_folder = Path(results_folder)
902 |
903 |         if self.is_main and (force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'))):
904 | rmtree(str(self.results_folder))
905 |
906 | self.results_folder.mkdir(parents = True, exist_ok = True)
907 |
908 | hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
909 | self.accelerator.init_trackers("coarse", config=hps)
910 |
911 | self.train_wrapper.to(self.device)
912 |
913 | def save(self, path):
914 | pkg = dict(
915 | model = self.accelerator.get_state_dict(self.transformer),
916 | optim = self.optim.state_dict(),
917 | version = __version__
918 | )
919 | torch.save(pkg, path)
920 |
921 | def load(self, path):
922 | path = Path(path)
923 | assert path.exists()
924 | pkg = torch.load(str(path), map_location = 'cpu')
925 |
926 | # check version
927 |
928 | if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
929 | print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')
930 |
931 | transformer = self.accelerator.unwrap_model(self.transformer)
932 | transformer.load_state_dict(pkg['model'])
933 |
934 | self.optim.load_state_dict(pkg['optim'])
935 | # + 1 to start from the next step and avoid overwriting the last checkpoint
936 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
937 |
938 |
939 | def print(self, msg):
940 | self.accelerator.print(msg)
941 |
942 | def generate(self, *args, **kwargs):
943 | return self.train_wrapper.generate(*args, **kwargs)
944 |
945 | @property
946 | def device(self):
947 | return self.accelerator.device
948 |
949 | @property
950 | def is_distributed(self):
951 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
952 |
953 | @property
954 | def is_main(self):
955 | return self.accelerator.is_main_process
956 |
957 | @property
958 | def is_local_main(self):
959 | return self.accelerator.is_local_main_process
960 |
961 | def train_step(self):
962 | device = self.device
963 |
964 | steps = int(self.steps.item())
965 |
966 | self.transformer.train()
967 |
968 | # logs
969 |
970 | logs = {}
971 |
972 |         # update transformer
973 |
974 | for _ in range(self.grad_accum_every):
975 | data_kwargs = dict(zip(self.ds_fields, next(self.dl_iter)))
976 |
977 | loss = self.train_wrapper(
978 | **data_kwargs,
979 | return_loss = True
980 | )
981 |
982 | self.accelerator.backward(loss / self.grad_accum_every)
983 |
984 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
985 |
986 | if exists(self.max_grad_norm):
987 | self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)
988 |
989 | self.optim.step()
990 | self.optim.zero_grad()
991 |
992 | # log
993 |
994 | self.print(f"{steps}: loss: {logs['loss']}")
995 | self.accelerator.log({"train_loss": logs['loss']}, step=steps)
996 |
997 | # sample results every so often
998 |
999 | if self.is_main and not (steps % self.save_results_every):
1000 | data_kwargs = dict(zip(self.ds_fields, next(self.valid_dl_iter)))
1001 |
1002 | with torch.no_grad():
1003 | self.train_wrapper.eval()
1004 |
1005 | valid_loss = self.train_wrapper(
1006 | **data_kwargs,
1007 | return_loss = True
1008 | )
1009 |
1010 | self.print(f'{steps}: valid loss {valid_loss}')
1011 | self.accelerator.log({"valid_loss": valid_loss}, step=steps)
1012 |
1013 | # save model every so often
1014 |
1015 | if self.is_main and not (steps % self.save_model_every):
1016 | model_path = str(self.results_folder / f'coarse.transformer.{steps}.pt')
1017 | self.save(model_path)
1018 |
1019 | self.print(f'{steps}: saving model to {str(self.results_folder)}')
1020 |
1021 | self.steps += 1
1022 | return logs
1023 |
1024 | def train(self, log_fn = noop):
1025 |
1026 | while self.steps < self.num_train_steps:
1027 | logs = self.train_step()
1028 | log_fn(logs)
1029 |
1030 | self.print('training complete')
1031 |
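# minimal usage sketch for CoarseTransformerTrainer (instances, folder path and
# hyperparameter values are placeholders):
#
#     trainer = CoarseTransformerTrainer(
#         transformer = coarse_transformer,
#         codec = codec,                      # a SoundStream or EncodecWrapper instance
#         wav2vec = wav2vec,
#         folder = '/path/to/audio',
#         batch_size = 1,
#         data_max_length_seconds = 10,
#         num_train_steps = 10_000
#     )
#
#     trainer.train()
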
1032 | # fine transformer trainer
1033 |
1034 | class FineTransformerTrainer(nn.Module):
1035 | @beartype
1036 | def __init__(
1037 | self,
1038 | transformer: FineTransformer,
1039 | codec: Union[SoundStream, EncodecWrapper],
1040 | *,
1041 | num_train_steps,
1042 | batch_size,
1043 | audio_conditioner: Optional[AudioConditionerBase] = None,
1044 | dataset: Optional[Dataset] = None,
1045 | data_max_length = None,
1046 | data_max_length_seconds = None,
1047 | dataset_normalize = False,
1048 | folder = None,
1049 | lr = 3e-4,
1050 | grad_accum_every = 1,
1051 | wd = 0.,
1052 | max_grad_norm = 0.5,
1053 | valid_frac = 0.05,
1054 | random_split_seed = 42,
1055 | save_results_every = 100,
1056 | save_model_every = 1000,
1057 | results_folder = './results',
1058 | accelerate_kwargs: dict = dict(),
1059 | force_clear_prev_results = None
1060 | ):
1061 | super().__init__()
1062 | self.accelerator = Accelerator(**accelerate_kwargs)
1063 |
1064 | self.transformer = transformer
1065 | self.codec = codec
1066 | self.audio_conditioner = audio_conditioner
1067 |
1068 | self.train_wrapper = FineTransformerWrapper(
1069 | codec = codec,
1070 | transformer = transformer,
1071 | audio_conditioner = audio_conditioner
1072 | )
1073 |
1074 | self.register_buffer('steps', torch.Tensor([0]))
1075 |
1076 | self.num_train_steps = num_train_steps
1077 | self.batch_size = batch_size
1078 | self.grad_accum_every = grad_accum_every
1079 |
1080 | # optimizers
1081 |
1082 | self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd)
1083 |
1084 | # max grad norm
1085 |
1086 | self.max_grad_norm = max_grad_norm
1087 |
1088 | # create dataset
1089 |
1090 | self.ds = dataset
1091 |
1092 | if not exists(self.ds):
1093 |             assert exists(folder), 'folder must be passed in if not passing in a custom dataset for text conditioned audio synthesis training'
1094 |
1095 | assert not (exists(data_max_length) and exists(data_max_length_seconds))
1096 |
1097 | if exists(data_max_length_seconds):
1098 |                 data_max_length = int(data_max_length_seconds * codec.target_sample_hz)
1099 |
1100 | self.ds = SoundDataset(
1101 | folder,
1102 | max_length = data_max_length,
1103 | target_sample_hz = codec.target_sample_hz,
1104 | seq_len_multiple_of = codec.seq_len_multiple_of
1105 | )
1106 |
1107 | self.ds_fields = None
1108 |
1109 | # split for validation
1110 |
1111 | if valid_frac > 0:
1112 | train_size = int((1 - valid_frac) * len(self.ds))
1113 | valid_size = len(self.ds) - train_size
1114 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
1115 |             self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
1116 | else:
1117 | self.valid_ds = self.ds
1118 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')
1119 |
1120 | # dataloader
1121 |
1122 | self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True)
1123 |
1124 | self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True)
1125 |
1126 | # prepare with accelerator
1127 |
1128 | (
1129 | self.transformer,
1130 | self.optim,
1131 | self.dl,
1132 | self.valid_dl
1133 | ) = self.accelerator.prepare(
1134 | self.transformer,
1135 | self.optim,
1136 | self.dl,
1137 | self.valid_dl
1138 | )
1139 |
1140 | # dataloader iterators
1141 |
1142 | self.dl_iter = cycle(self.dl)
1143 | self.valid_dl_iter = cycle(self.valid_dl)
1144 |
1145 | self.save_model_every = save_model_every
1146 | self.save_results_every = save_results_every
1147 |
1148 | self.results_folder = Path(results_folder)
1149 |
1150 |         if self.is_main and (force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'))):
1151 | rmtree(str(self.results_folder))
1152 |
1153 | self.results_folder.mkdir(parents = True, exist_ok = True)
1154 |
1155 | hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
1156 | self.accelerator.init_trackers("fine", config=hps)
1157 |
1158 | self.train_wrapper.to(self.device)
1159 |
1160 | def save(self, path):
1161 | pkg = dict(
1162 | model = self.accelerator.get_state_dict(self.transformer),
1163 | optim = self.optim.state_dict(),
1164 | version = __version__
1165 | )
1166 | torch.save(pkg, path)
1167 |
1168 | def load(self, path):
1169 | path = Path(path)
1170 | assert path.exists()
1171 | pkg = torch.load(str(path), map_location = 'cpu')
1172 |
1173 | # check version
1174 |
1175 | if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
1176 | print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')
1177 |
1178 | transformer = self.accelerator.unwrap_model(self.transformer)
1179 | transformer.load_state_dict(pkg['model'])
1180 |
1181 | self.optim.load_state_dict(pkg['optim'])
1182 | # + 1 to start from the next step and avoid overwriting the last checkpoint
1183 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
1184 |
1185 |
1186 | def print(self, msg):
1187 | self.accelerator.print(msg)
1188 |
1189 | def generate(self, *args, **kwargs):
1190 | return self.train_wrapper.generate(*args, **kwargs)
1191 |
1192 | @property
1193 | def device(self):
1194 | return self.accelerator.device
1195 |
1196 | @property
1197 | def is_distributed(self):
1198 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
1199 |
1200 | @property
1201 | def is_main(self):
1202 | return self.accelerator.is_main_process
1203 |
1204 | @property
1205 | def is_local_main(self):
1206 | return self.accelerator.is_local_main_process
1207 |
1208 | def data_tuple_to_kwargs(self, data):
1209 | if not exists(self.ds_fields):
1210 | self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
1211 | assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'
1212 |
1213 | return dict(zip(self.ds_fields, data))
1214 |
1215 | def train_step(self):
1216 | device = self.device
1217 |
1218 | steps = int(self.steps.item())
1219 |
1220 | self.transformer.train()
1221 |
1222 | # logs
1223 |
1224 | logs = {}
1225 |
1226 |         # update transformer
1227 |
1228 | for _ in range(self.grad_accum_every):
1229 | data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
1230 | loss = self.train_wrapper(**data_kwargs, return_loss = True)
1231 |
1232 | self.accelerator.backward(loss / self.grad_accum_every)
1233 |
1234 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
1235 |
1236 | if exists(self.max_grad_norm):
1237 | self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)
1238 |
1239 | self.optim.step()
1240 | self.optim.zero_grad()
1241 |
1242 | # log
1243 |
1244 | self.print(f"{steps}: loss: {logs['loss']}")
1245 | self.accelerator.log({"train_loss": logs['loss']}, step=steps)
1246 |
1247 | # sample results every so often
1248 |
1249 | if self.is_main and not (steps % self.save_results_every):
1250 | data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
1251 |
1252 | with torch.no_grad():
1253 | self.train_wrapper.eval()
1254 | valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)
1255 |
1256 | self.print(f'{steps}: valid loss {valid_loss}')
1257 | self.accelerator.log({"valid_loss": valid_loss}, step=steps)
1258 |
1259 | # save model every so often
1260 |
1261 | if self.is_main and not (steps % self.save_model_every):
1262 | model_path = str(self.results_folder / f'fine.transformer.{steps}.pt')
1263 | self.save(model_path)
1264 |
1265 | self.print(f'{steps}: saving model to {str(self.results_folder)}')
1266 |
1267 | self.steps += 1
1268 | return logs
1269 |
1270 | def train(self, log_fn = noop):
1271 |
1272 | while self.steps < self.num_train_steps:
1273 | logs = self.train_step()
1274 | log_fn(logs)
1275 |
1276 | self.print('training complete')
1277 |
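# minimal usage sketch for FineTransformerTrainer (instances, folder path and
# hyperparameter values are placeholders):
#
#     trainer = FineTransformerTrainer(
#         transformer = fine_transformer,
#         codec = codec,
#         folder = '/path/to/audio',
#         batch_size = 1,
#         data_max_length_seconds = 10,
#         num_train_steps = 10_000
#     )
#
#     trainer.train()
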
--------------------------------------------------------------------------------