├── audiolm_pytorch ├── version.py ├── utils.py ├── __init__.py ├── optimizer.py ├── vq_wav2vec.py ├── t5.py ├── hubert_kmeans.py ├── data.py ├── encodec.py ├── soundstream.py └── trainer.py ├── audiolm.png ├── LICENSE ├── .github └── workflows │ └── python-publish.yml ├── setup.py ├── .gitignore ├── README.md └── audiolm_pytorch_demo.ipynb /audiolm_pytorch/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.1.4' 2 | -------------------------------------------------------------------------------- /audiolm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/audiolm-pytorch/main/audiolm.png -------------------------------------------------------------------------------- /audiolm_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | # functions 4 | 5 | def round_down_nearest_multiple(num, divisor): 6 | return num // divisor * divisor 7 | 8 | def curtail_to_multiple(t, mult, from_left = False): 9 | data_len = t.shape[-1] 10 | rounded_seq_len = round_down_nearest_multiple(data_len, mult) 11 | seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None) 12 | return t[..., seq_slice] 13 | 14 | # base class 15 | 16 | class AudioConditionerBase(nn.Module): 17 | pass 18 | -------------------------------------------------------------------------------- /audiolm_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from packaging import version 3 | 4 | if version.parse(torch.__version__) >= version.parse('2.0.0'): 5 | from einops._torch_specific import allow_ops_in_compiled_graph 6 | allow_ops_in_compiled_graph() 7 | 8 | from audiolm_pytorch.audiolm_pytorch import AudioLM 9 | from audiolm_pytorch.soundstream import SoundStream, AudioLMSoundStream, MusicLMSoundStream 10 | from audiolm_pytorch.encodec import EncodecWrapper 11 | 12 | from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer 13 | from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper, SemanticTransformerWrapper 14 | 15 | from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec 16 | from audiolm_pytorch.hubert_kmeans import HubertWithKmeans 17 | 18 | from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer, CoarseTransformerTrainer 19 | 20 | from audiolm_pytorch.audiolm_pytorch import get_embeds 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /audiolm_pytorch/optimizer.py: -------------------------------------------------------------------------------- 1 | from lion_pytorch import Lion 2 | from torch.optim import AdamW, Adam 3 | 4 | def separate_weight_decayable_params(params): 5 | wd_params, no_wd_params = [], [] 6 | for param in params: 7 | param_list = no_wd_params if param.ndim < 2 else wd_params 8 | param_list.append(param) 9 | return wd_params, no_wd_params 10 | 11 | def get_optimizer( 12 | params, 13 | lr = 1e-4, 14 | wd = 1e-2, 15 | betas = (0.9, 0.99), 16 | eps = 1e-8, 17 | filter_by_requires_grad = False, 18 | group_wd_params = True, 19 | use_lion = False, 20 | **kwargs 21 | ): 22 | has_wd = wd > 0 23 | 24 | if filter_by_requires_grad: 25 | params = list(filter(lambda t: t.requires_grad, params)) 26 | 27 | if group_wd_params and has_wd: 28 | wd_params, no_wd_params = separate_weight_decayable_params(params) 29 | 30 | params = [ 31 | {'params': wd_params}, 32 | {'params': no_wd_params, 'weight_decay': 0}, 33 | ] 34 | 35 | if use_lion: 36 | return Lion(params, lr = lr, betas = betas, weight_decay = wd) 37 | 38 | if not has_wd: 39 | return Adam(params, lr = lr, betas = betas, eps = eps) 40 | 41 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | exec(open('audiolm_pytorch/version.py').read()) 3 | 4 | setup( 5 | name = 'audiolm-pytorch', 6 | packages = find_packages(exclude=[]), 7 | version = 
__version__, 8 | license='MIT', 9 | description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | long_description_content_type = 'text/markdown', 13 | url = 'https://github.com/lucidrains/audiolm-pytorch', 14 | keywords = [ 15 | 'artificial intelligence', 16 | 'deep learning', 17 | 'transformers', 18 | 'attention mechanism', 19 | 'audio generation' 20 | ], 21 | install_requires=[ 22 | 'accelerate', 23 | 'beartype', 24 | 'einops>=0.6.1', 25 | 'ema-pytorch>=0.2.2', 26 | 'encodec', 27 | 'fairseq', 28 | 'joblib', 29 | 'lion-pytorch', 30 | 'local-attention>=1.8.4', 31 | 'scikit-learn', 32 | 'sentencepiece', 33 | 'torch>=1.12', 34 | 'torchaudio', 35 | 'transformers', 36 | 'tqdm', 37 | 'vector-quantize-pytorch>=1.5.14' 38 | ], 39 | classifiers=[ 40 | 'Development Status :: 4 - Beta', 41 | 'Intended Audience :: Developers', 42 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 43 | 'License :: OSI Approved :: MIT License', 44 | 'Programming Language :: Python :: 3.6', 45 | ], 46 | ) 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Pycharm 132 | .idea/ 133 | -------------------------------------------------------------------------------- /audiolm_pytorch/vq_wav2vec.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange 6 | 7 | import fairseq 8 | 9 | from torchaudio.functional import resample 10 | 11 | from audiolm_pytorch.utils import curtail_to_multiple 12 | 13 | import logging 14 | logging.root.setLevel(logging.ERROR) 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | class FairseqVQWav2Vec(nn.Module): 20 | """ 21 | checkpoint path can be found at https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/README.md#vq-wav2vec 22 | specifically download the kmeans model for now 23 | 24 | $ wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt 25 | """ 26 | 27 | def __init__( 28 | self, 29 | checkpoint_path, 30 | target_sample_hz = 24000, 31 | seq_len_multiple_of = None 32 | ): 33 | super().__init__() 34 | self.target_sample_hz = target_sample_hz 35 | self.seq_len_multiple_of = seq_len_multiple_of 36 | 37 | path = Path(checkpoint_path) 38 | assert path.exists(), f'path {checkpoint_path} does not exist' 39 | 40 | checkpoint = torch.load(checkpoint_path) 41 | load_model_input = {checkpoint_path: checkpoint} 42 | model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) 43 | 44 | self.model = model[0] 45 | self.model.eval() 46 | 47 | assert hasattr(self.model, 'vector_quantizer') and hasattr(self.model.vector_quantizer, 'embedding'), 'the vq wav2vec model does not seem to be valid' 48 | 49 | @property 50 | def groups(self): 51 | return self.model.vector_quantizer.groups 52 | 53 | @property 54 | def codebook_size(self): 55 | return self.model.vector_quantizer.embedding.shape[0] 56 | 57 | @torch.no_grad() 58 | def forward( 59 | self, 60 | wav_input, 61 | flatten = True, 62 | input_sample_hz = None 63 | ): 64 | if exists(input_sample_hz): 65 | wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz) 66 | 67 | if exists(self.seq_len_multiple_of): 68 | wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) 69 | 70 | embed = self.model.feature_extractor(wav_input) 71 | _, codebook_indices = self.model.vector_quantizer.forward_idx(embed) 72 | 73 | if not flatten: 74 | return codebook_indices 75 | 76 | return rearrange(codebook_indices, 'b ... 
-> b (...)')
77 | 
--------------------------------------------------------------------------------
/audiolm_pytorch/t5.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 | from transformers import T5Tokenizer, T5EncoderModel, T5Config
4 | 
5 | from beartype import beartype
6 | from beartype.typing import Union, List
7 | 
8 | # less warning messages since only using encoder
9 | 
10 | transformers.logging.set_verbosity_error()
11 | 
12 | # helper functions
13 | 
14 | def exists(val):
15 |     return val is not None
16 | 
17 | # config
18 | 
19 | MAX_LENGTH = 256
20 | 
21 | DEFAULT_T5_NAME = 'google/t5-v1_1-base'
22 | 
23 | T5_CONFIGS = {}
24 | 
25 | # singleton globals
26 | 
27 | def get_tokenizer(name):
28 |     tokenizer = T5Tokenizer.from_pretrained(name)
29 |     return tokenizer
30 | 
31 | def get_model(name):
32 |     model = T5EncoderModel.from_pretrained(name)
33 |     return model
34 | 
35 | def get_model_and_tokenizer(name):
36 |     global T5_CONFIGS
37 | 
38 |     if name not in T5_CONFIGS:
39 |         T5_CONFIGS[name] = dict()
40 | 
41 |     if "model" not in T5_CONFIGS[name]:
42 |         T5_CONFIGS[name]["model"] = get_model(name)
43 | 
44 |     if "tokenizer" not in T5_CONFIGS[name]:
45 |         T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
46 | 
47 |     return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
48 | 
49 | def get_encoded_dim(name):
50 |     if name not in T5_CONFIGS:
51 |         config = T5Config.from_pretrained(name)
52 |         T5_CONFIGS[name] = dict(config = config)
53 | 
54 |     elif "config" in T5_CONFIGS[name]:
55 |         config = T5_CONFIGS[name]["config"]
56 | 
57 |     elif "model" in T5_CONFIGS[name]:
58 |         config = T5_CONFIGS[name]["model"].config
59 | 
60 |     else:
61 |         raise ValueError(f'unknown t5 name {name}')
62 | 
63 |     return config.d_model
64 | 
65 | # encoding text
66 | 
67 | @beartype
68 | def t5_encode_text(
69 |     texts: Union[str, List[str]],
70 |     name = DEFAULT_T5_NAME,
71 |     output_device = None
72 | ):
73 |     if isinstance(texts, str):
74 |         texts = [texts]
75 | 
76 |     t5, tokenizer = get_model_and_tokenizer(name)
77 | 
78 |     if torch.cuda.is_available():
79 |         t5 = t5.cuda()
80 | 
81 |     device = next(t5.parameters()).device
82 | 
83 |     encoded = tokenizer.batch_encode_plus(
84 |         texts,
85 |         return_tensors = 'pt',
86 |         padding = 'longest',
87 |         max_length = MAX_LENGTH,
88 |         truncation = True
89 |     )
90 | 
91 |     input_ids = encoded.input_ids.to(device)
92 |     attn_mask = encoded.attention_mask.to(device)
93 | 
94 |     t5.eval()
95 | 
96 |     with torch.no_grad():
97 |         output = t5(input_ids = input_ids, attention_mask = attn_mask)
98 |         encoded_text = output.last_hidden_state.detach()
99 | 
100 |     attn_mask = attn_mask[..., None].bool()
101 | 
102 |     if not exists(output_device):
103 |         encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
104 |         return encoded_text
105 | 
106 |     encoded_text = encoded_text.to(output_device)
107 |     attn_mask = attn_mask.to(output_device)
108 | 
109 |     encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
110 | return encoded_text 111 | -------------------------------------------------------------------------------- /audiolm_pytorch/hubert_kmeans.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange, pack, unpack 6 | 7 | import joblib 8 | 9 | import fairseq 10 | 11 | from torchaudio.functional import resample 12 | 13 | from audiolm_pytorch.utils import curtail_to_multiple 14 | 15 | import logging 16 | logging.root.setLevel(logging.ERROR) 17 | 18 | def exists(val): 19 | return val is not None 20 | 21 | def default(val, d): 22 | return val if exists(val) else d 23 | 24 | class HubertWithKmeans(nn.Module): 25 | """ 26 | checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert 27 | or you can train your own 28 | """ 29 | 30 | def __init__( 31 | self, 32 | checkpoint_path, 33 | kmeans_path, 34 | target_sample_hz = 16000, 35 | seq_len_multiple_of = None, 36 | output_layer = 9 37 | ): 38 | super().__init__() 39 | self.target_sample_hz = target_sample_hz 40 | self.seq_len_multiple_of = seq_len_multiple_of 41 | self.output_layer = output_layer 42 | 43 | model_path = Path(checkpoint_path) 44 | kmeans_path = Path(kmeans_path) 45 | 46 | assert model_path.exists(), f'path {checkpoint_path} does not exist' 47 | assert kmeans_path.exists(), f'path {kmeans_path} does not exist' 48 | 49 | checkpoint = torch.load(checkpoint_path) 50 | load_model_input = {checkpoint_path: checkpoint} 51 | model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) 52 | 53 | self.model = model[0] 54 | self.model.eval() 55 | 56 | kmeans = joblib.load(kmeans_path) 57 | self.kmeans = kmeans 58 | 59 | @property 60 | def groups(self): 61 | return 1 62 | 63 | @property 64 | def codebook_size(self): 65 | return self.kmeans.n_clusters 66 | 67 | @torch.no_grad() 68 | def forward( 69 | self, 70 | wav_input, 71 | flatten = True, 72 | input_sample_hz = None 73 | ): 74 | device = wav_input.device 75 | 76 | if exists(input_sample_hz): 77 | wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz) 78 | 79 | if exists(self.seq_len_multiple_of): 80 | wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) 81 | 82 | embed = self.model( 83 | wav_input, 84 | features_only = True, 85 | mask = False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code 86 | output_layer = self.output_layer 87 | ) 88 | 89 | embed, packed_shape = pack([embed['x']], '* d') 90 | 91 | codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy()) 92 | 93 | codebook_indices = torch.from_numpy(codebook_indices).to(device).long() 94 | 95 | if flatten: 96 | return codebook_indices 97 | 98 | codebook_indices, = unpack(codebook_indices, packed_shape, '*') 99 | return codebook_indices 100 | -------------------------------------------------------------------------------- /audiolm_pytorch/data.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from functools import partial, wraps 3 | 4 | from beartype import beartype 5 | from beartype.typing import Tuple, Union, Optional 6 | from beartype.door import is_bearable 7 | 8 | import torchaudio 9 | from torchaudio.functional import resample 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | from torch.utils.data import 
Dataset, DataLoader 15 | 16 | from audiolm_pytorch.utils import curtail_to_multiple 17 | 18 | from einops import rearrange 19 | 20 | # helper functions 21 | 22 | def exists(val): 23 | return val is not None 24 | 25 | def cast_tuple(val, length = 1): 26 | return val if isinstance(val, tuple) else ((val,) * length) 27 | 28 | # type 29 | 30 | OptionalIntOrTupleInt = Optional[Union[int, Tuple[Optional[int], ...]]] 31 | 32 | # dataset functions 33 | 34 | class SoundDataset(Dataset): 35 | @beartype 36 | def __init__( 37 | self, 38 | folder, 39 | exts = ['flac', 'wav', 'mp3', 'webm'], 40 | max_length: OptionalIntOrTupleInt = None, 41 | target_sample_hz: OptionalIntOrTupleInt = None, 42 | seq_len_multiple_of: OptionalIntOrTupleInt = None 43 | ): 44 | super().__init__() 45 | path = Path(folder) 46 | assert path.exists(), 'folder does not exist' 47 | 48 | files = [file for ext in exts for file in path.glob(f'**/*.{ext}')] 49 | assert len(files) > 0, 'no sound files found' 50 | 51 | self.files = files 52 | 53 | self.target_sample_hz = cast_tuple(target_sample_hz) 54 | num_outputs = len(self.target_sample_hz) 55 | 56 | self.max_length = cast_tuple(max_length, num_outputs) 57 | self.seq_len_multiple_of = cast_tuple(seq_len_multiple_of, num_outputs) 58 | 59 | assert len(self.max_length) == len(self.target_sample_hz) == len(self.seq_len_multiple_of) 60 | 61 | def __len__(self): 62 | return len(self.files) 63 | 64 | def __getitem__(self, idx): 65 | file = self.files[idx] 66 | 67 | data, sample_hz = torchaudio.load(file) 68 | 69 | assert data.numel() > 0, f'one of your audio file ({file}) is empty. please remove it from your folder' 70 | 71 | if data.shape[0] > 1: 72 | # the audio has more than 1 channel, convert to mono 73 | data = torch.mean(data, dim=0).unsqueeze(0) 74 | 75 | num_outputs = len(self.target_sample_hz) 76 | data = cast_tuple(data, num_outputs) 77 | 78 | # resample if target_sample_hz is not None in the tuple 79 | 80 | data_tuple = tuple((resample(d, sample_hz, target_sample_hz) if exists(target_sample_hz) else d) for d, target_sample_hz in zip(data, self.target_sample_hz)) 81 | 82 | output = [] 83 | 84 | # process each of the data resample at different frequencies individually 85 | 86 | for data, max_length, seq_len_multiple_of in zip(data_tuple, self.max_length, self.seq_len_multiple_of): 87 | audio_length = data.size(1) 88 | 89 | # pad or curtail 90 | 91 | if audio_length > max_length: 92 | max_start = audio_length - max_length 93 | start = torch.randint(0, max_start, (1, )) 94 | data = data[:, start:start + max_length] 95 | 96 | else: 97 | data = F.pad(data, (0, max_length - audio_length), 'constant') 98 | 99 | data = rearrange(data, '1 ... 
-> ...') 100 | 101 | if exists(max_length): 102 | data = data[:max_length] 103 | 104 | if exists(seq_len_multiple_of): 105 | data = curtail_to_multiple(data, seq_len_multiple_of) 106 | 107 | output.append(data.float()) 108 | 109 | # cast from list to tuple 110 | 111 | output = tuple(output) 112 | 113 | # return only one audio, if only one target resample freq 114 | 115 | if num_outputs == 1: 116 | return output[0] 117 | 118 | return output 119 | 120 | # dataloader functions 121 | 122 | def collate_one_or_multiple_tensors(fn): 123 | @wraps(fn) 124 | def inner(data): 125 | is_one_data = not isinstance(data[0], tuple) 126 | 127 | if is_one_data: 128 | data = torch.stack(data) 129 | return (data,) 130 | 131 | outputs = [] 132 | for datum in zip(*data): 133 | if is_bearable(datum, Tuple[str, ...]): 134 | output = list(datum) 135 | else: 136 | output = fn(datum) 137 | 138 | outputs.append(output) 139 | 140 | return tuple(outputs) 141 | 142 | return inner 143 | 144 | @collate_one_or_multiple_tensors 145 | def curtail_to_shortest_collate(data): 146 | min_len = min(*[datum.shape[0] for datum in data]) 147 | data = [datum[:min_len] for datum in data] 148 | return torch.stack(data) 149 | 150 | @collate_one_or_multiple_tensors 151 | def pad_to_longest_fn(data): 152 | return pad_sequence(data, batch_first = True) 153 | 154 | def get_dataloader(ds, pad_to_longest = True, **kwargs): 155 | collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate 156 | return DataLoader(ds, collate_fn = collate_fn, **kwargs) 157 | -------------------------------------------------------------------------------- /audiolm_pytorch/encodec.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from einops import rearrange, pack, unpack 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from vector_quantize_pytorch import ResidualVQ 8 | 9 | from encodec import EncodecModel 10 | from encodec.utils import _linear_overlap_add 11 | 12 | class EncodecWrapper(nn.Module): 13 | """ 14 | Support pretrained 24kHz Encodec by Meta AI, if you want to skip training SoundStream. 15 | 16 | TODO: 17 | - see if we need to keep the scaled version and somehow persist the scale factors for when we need to decode? Right 18 | now I'm just setting self.model.normalize = False to sidestep all of that 19 | - see if we can use the 48kHz model, which is specifically for music. Right now we're using the 24kHz model because 20 | that's what was used in MusicLM and avoids any resampling issues. 21 | - 22 | 23 | """ 24 | def __init__( 25 | self, 26 | target_sample_hz = 24000, 27 | strides = (2, 4, 5, 8), 28 | num_quantizers = 8, 29 | ): 30 | super().__init__() 31 | # Instantiate a pretrained EnCodec model 32 | self.model = EncodecModel.encodec_model_24khz() 33 | self.model.normalize = False # this means we don't need to scale codes e.g. when running model.encode(wav) 34 | 35 | # bandwidth affects num quantizers used: https://github.com/facebookresearch/encodec/pull/41 36 | self.model.set_target_bandwidth(6.0) 37 | assert num_quantizers == 8, "assuming 8 quantizers for now, see bandwidth comment above" 38 | 39 | # Fields that SoundStream has that get used externally. We replicate them here. 
40 | self.target_sample_hz = target_sample_hz 41 | assert self.target_sample_hz == 24000, "haven't done anything with non-24kHz yet" 42 | 43 | self.codebook_dim = 128 44 | self.rq_groups = 1 45 | self.num_quantizers = num_quantizers 46 | self.strides = strides # used in seq_len_multiple_of 47 | 48 | # cross entropy loss to indices passed in on l2 distance logits introduced in vector-quantize-pytorch 1.2.2 49 | 50 | self.rq = ResidualVQ( 51 | dim = 128, 52 | codebook_size = 1024, 53 | num_quantizers = 8 54 | ) 55 | 56 | # copy codebook over to ResidualVQ for cross entropy loss logic from naturalspeech2 57 | # luckily, it seems Meta AI basically used my ResidualVQ code verbatim. makes porting it over easy 58 | 59 | for encodec_rq_layer, rq_layer in zip(self.model.quantizer.vq.layers, self.rq.layers): 60 | encodec_codebook = dict(encodec_rq_layer._codebook.named_buffers()).get('embed') 61 | vq_codebook = dict(rq_layer._codebook.named_buffers()).get('embed') 62 | 63 | encodec_codebook = rearrange(encodec_codebook, '... -> 1 ...') 64 | vq_codebook.copy_(encodec_codebook) 65 | 66 | @property 67 | def seq_len_multiple_of(self): 68 | return reduce(lambda x, y: x * y, self.strides) 69 | 70 | def forward( 71 | self, 72 | x, 73 | return_encoded = False, 74 | **kwargs 75 | ): 76 | 77 | x, ps = pack([x], '* n') 78 | 79 | # kwargs for stuff like return_encoded=True, which SoundStream uses but Encodec doesn't 80 | assert not self.model.training, "Encodec is pretrained and should never be called outside eval mode." 81 | # Unlike in the Encodec sample code in its README, x has already been resampled so we don't need to call 82 | # convert_audio and unsqueeze. The convert_audio function also doesn't play nicely with batches. 83 | 84 | # b = batch, t = timesteps, 1 channel for the 24kHz model, 2 channels for the 48kHz model 85 | wav = rearrange(x, f'b t -> b {self.model.channels} t') 86 | 87 | # Extract discrete codes from EnCodec 88 | with torch.no_grad(): 89 | encoded_frames = self.model.encode(wav) 90 | # encoded_frames is a list of (frame, scale) tuples. Scale is a scalar but we don't use it. Frame is a tensor 91 | # of shape [batch, num_quantizers, num_samples_per_frame]. We want to concatenate the frames to get all the 92 | # timesteps concatenated. 93 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [batch, num_quantizers, timesteps] 94 | # transformer code that uses codec expects codes to be [batch, timesteps, num_quantizers] 95 | codes = rearrange(codes, 'b q n -> b n q') # result: [batch, timesteps, num_quantizers] 96 | # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that. 97 | 98 | # allow for returning of sum of quantized embeddings 99 | 100 | emb = None 101 | 102 | if return_encoded: 103 | emb = self.get_emb_from_indices(codes) 104 | emb, = unpack(emb, ps, '* n c') 105 | 106 | codes, = unpack(codes, ps, '* n q') 107 | 108 | return emb, codes, None 109 | 110 | def decode_from_codebook_indices(self, quantized_indices): 111 | # Input: batch x num tokens x num quantizers 112 | # Output: batch x 1 x num samples 113 | 114 | assert self.model.sample_rate == 24000,\ 115 | "if changing to 48kHz, that model segments its audio into lengths of 1.0 second with 1% overlap, whereas " \ 116 | "the 24kHz doesn't segment at all. this means the frame decode logic might change; this is a reminder to " \ 117 | "double check that." 
118 |         # Since 24kHz pretrained doesn't do any segmenting, we have all the frames already (1 frame = 1 token in quantized_indices)
119 | 
120 |         # The following code is hacked in from self.model.decode() (Encodec version 0.1.1) where we skip the part about
121 |         # scaling.
122 |         # Shape: 1 x (num_frames * stride product). 1 because we have 1 frame (because no segmenting)
123 |         frames = self._decode_frame(quantized_indices)
124 |         result = _linear_overlap_add(frames, self.model.segment_stride or 1)
125 |         # TODO: I'm not overly pleased with this because when this function gets called, we just rearrange the result
126 |         # back to b n anyways, but we'll keep this as a temporary hack just to make things work for now
127 |         return rearrange(result, 'b n -> b 1 n')
128 | 
129 |     def get_emb_from_indices(self, indices):
130 |         codes = rearrange(indices, 'b t q -> q b t')
131 |         emb = self.model.quantizer.decode(codes)
132 |         return rearrange(emb, 'b c n -> b n c')
133 | 
134 |     def decode(self, emb):
135 |         emb = rearrange(emb, 'b n c -> b c n')
136 |         return self.model.decoder(emb)
137 | 
138 |     def _decode_frame(self, quantized_indices):
139 |         # The following code is hacked in from self.model._decode_frame() (Encodec version 0.1.1) where we assume we've
140 |         # already unwrapped the EncodedFrame
141 |         # Input: batch x num tokens x num quantizers
142 |         # Output: batch x new_num_samples, where new_num_samples is num_frames * stride product (may be slightly
143 |         # larger than original num samples as a result, because the last frame might not be "fully filled" with samples
144 |         # if num_samples doesn't divide perfectly).
145 |         # num_frames == the number of acoustic tokens you have, one token per frame
146 |         codes = rearrange(quantized_indices, 'b t q -> q b t')
147 |         emb = self.model.quantizer.decode(codes)
148 |         # emb shape: batch x self.model.quantizer.dimension x T. Note self.model.quantizer.dimension is the embedding dimension
149 |         return self.model.decoder(emb)
150 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## AudioLM - Pytorch
4 | 
5 | Implementation of AudioLM, a Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
6 | 
7 | It also extends the work to allow for conditioning with classifier-free guidance, using T5. This makes text-to-audio and TTS possible, which is not offered in the paper. Yes, this means VALL-E can be trained from this repository. It is essentially the same.
8 | 
9 | Please join us on Discord if you are interested in replicating this work in the open.
10 | 
11 | This repository now also contains an MIT-licensed version of SoundStream. It is also compatible with EnCodec, which is also [MIT-licensed](https://github.com/facebookresearch/encodec/commit/349b72939f57cb3bc7b60906c0ee8228c849485d) at the time of writing.
12 | 
13 | Update: AudioLM was essentially used to 'solve' music generation in the new MusicLM
14 | 
15 | In the future, this movie clip would no longer make any sense. You would just prompt an AI instead.
16 | 
17 | ## Appreciation
18 | 
19 | - Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research
20 | 
21 | - 🤗 Huggingface for their amazing accelerate and transformers libraries
22 | 
23 | - MetaAI for Fairseq and the liberal license
24 | 
25 | - @eonglints and Joseph for offering their professional advice and expertise as well as pull requests!
26 | 27 | - @djqualia, @yigityu, @inspirit, and @BlackFox1197 for helping with the debugging of soundstream 28 | 29 | - Allen and LWprogramming for reviewing the code and submitting bug fixes! 30 | 31 | - Ilya for finding an issue with multi-scale discriminator downsampling and for soundstream trainer improvements 32 | 33 | - Andrey for identifying a missing loss in soundstream and guiding me through the proper mel spectrogram hyperparameters 34 | 35 | - Alejandro and Ilya for sharing their results with training soundstream, and for working through a few issues with the local attention positional embeddings 36 | 37 | - LWprogramming for adding Encodec compatibility! 38 | 39 | - @YoungloLee for identifying a big bug in the 1d causal convolution for soundstream related to padding not accounting for strides! 40 | 41 | - Hayden for pointing out some discrepancies in the multi-scale discriminator for Soundstream 42 | 43 | ## Install 44 | 45 | ```bash 46 | $ pip install audiolm-pytorch 47 | ``` 48 | 49 | ## Usage 50 | 51 | ### SoundStream & Encodec 52 | 53 | There are two options for the neural codec. If you want to use the pretrained 24kHz Encodec, just create an Encodec object as follows: 54 | ```python 55 | from audiolm_pytorch import EncodecWrapper 56 | encodec = EncodecWrapper() 57 | # Now you can use the encodec variable in the same way you'd use the soundstream variables below. 58 | ``` 59 | 60 | Otherwise, to stay more true to the original paper, you can use `SoundStream`. First, `SoundStream` needs to be trained on a large corpus of audio data 61 | 62 | ```python 63 | from audiolm_pytorch import SoundStream, SoundStreamTrainer 64 | 65 | soundstream = SoundStream( 66 | codebook_size = 1024, 67 | rq_num_quantizers = 8, 68 | rq_groups = 2, # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765 69 | attn_window_size = 128, # local attention receptive field at bottleneck 70 | attn_depth = 2 # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better 71 | ) 72 | 73 | trainer = SoundStreamTrainer( 74 | soundstream, 75 | folder = '/path/to/audio/files', 76 | batch_size = 4, 77 | grad_accum_every = 8, # effective batch size of 32 78 | data_max_length_seconds = 2, # train on 2 second audio 79 | num_train_steps = 1_000_000 80 | ).cuda() 81 | 82 | trainer.train() 83 | 84 | # after a lot of training, you can test the autoencoding as so 85 | 86 | audio = torch.randn(10080).cuda() 87 | recons = soundstream(audio, return_recons_only = True) # (1, 10080) - 1 channel 88 | ``` 89 | 90 | You can also use soundstreams that are specific to `AudioLM` and `MusicLM` by importing `AudioLMSoundStream` and `MusicLMSoundStream` respectively 91 | 92 | ```python 93 | from audiolm_pytorch import AudioLMSoundStream, MusicLMSoundStream 94 | 95 | soundstream = AudioLMSoundStream(...) # say you want the hyperparameters as in Audio LM paper 96 | 97 | # rest is the same as above 98 | ``` 99 | 100 | As of version `0.17.0`, you can now invoke the class method on `SoundStream` to load from checkpoint files, without having to remember your configurations. 
101 | 102 | ```python 103 | from audiolm_pytorch import SoundStream 104 | 105 | soundstream = SoundStream.init_and_load_from('./path/to/checkpoint.pt') 106 | ``` 107 | 108 | ### Hierarchical Transformers 109 | 110 | Then three separate transformers (`SemanticTransformer`, `CoarseTransformer`, `FineTransformer`) need to be trained 111 | 112 | 113 | ex. `SemanticTransformer` 114 | 115 | ```python 116 | import torch 117 | from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer 118 | 119 | # hubert checkpoints can be downloaded at 120 | # https://github.com/facebookresearch/fairseq/tree/main/examples/hubert 121 | 122 | wav2vec = HubertWithKmeans( 123 | checkpoint_path = './hubert/hubert_base_ls960.pt', 124 | kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin' 125 | ) 126 | 127 | semantic_transformer = SemanticTransformer( 128 | num_semantic_tokens = wav2vec.codebook_size, 129 | dim = 1024, 130 | depth = 6 131 | ).cuda() 132 | 133 | 134 | trainer = SemanticTransformerTrainer( 135 | transformer = semantic_transformer, 136 | wav2vec = wav2vec, 137 | folder ='/path/to/audio/files', 138 | batch_size = 1, 139 | data_max_length = 320 * 32, 140 | num_train_steps = 1 141 | ) 142 | 143 | trainer.train() 144 | ``` 145 | 146 | ex. `CoarseTransformer` 147 | 148 | ```python 149 | import torch 150 | from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerTrainer 151 | 152 | wav2vec = HubertWithKmeans( 153 | checkpoint_path = './hubert/hubert_base_ls960.pt', 154 | kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin' 155 | ) 156 | 157 | soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt') 158 | 159 | coarse_transformer = CoarseTransformer( 160 | num_semantic_tokens = wav2vec.codebook_size, 161 | codebook_size = 1024, 162 | num_coarse_quantizers = 3, 163 | dim = 512, 164 | depth = 6 165 | ) 166 | 167 | trainer = CoarseTransformerTrainer( 168 | transformer = coarse_transformer, 169 | codec = soundstream, 170 | wav2vec = wav2vec, 171 | folder = '/path/to/audio/files', 172 | batch_size = 1, 173 | data_max_length = 320 * 32, 174 | num_train_steps = 1_000_000 175 | ) 176 | 177 | trainer.train() 178 | ``` 179 | 180 | ex. 
`FineTransformer` 181 | 182 | ```python 183 | import torch 184 | from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer 185 | 186 | soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt') 187 | 188 | fine_transformer = FineTransformer( 189 | num_coarse_quantizers = 3, 190 | num_fine_quantizers = 5, 191 | codebook_size = 1024, 192 | dim = 512, 193 | depth = 6 194 | ) 195 | 196 | trainer = FineTransformerTrainer( 197 | transformer = fine_transformer, 198 | codec = soundstream, 199 | folder = '/path/to/audio/files', 200 | batch_size = 1, 201 | data_max_length = 320 * 32, 202 | num_train_steps = 1_000_000 203 | ) 204 | 205 | trainer.train() 206 | ``` 207 | 208 | All together now 209 | 210 | ```python 211 | from audiolm_pytorch import AudioLM 212 | 213 | audiolm = AudioLM( 214 | wav2vec = wav2vec, 215 | codec = soundstream, 216 | semantic_transformer = semantic_transformer, 217 | coarse_transformer = coarse_transformer, 218 | fine_transformer = fine_transformer 219 | ) 220 | 221 | generated_wav = audiolm(batch_size = 1) 222 | 223 | # or with priming 224 | 225 | generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8)) 226 | 227 | # or with text condition, if given 228 | 229 | generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells']) 230 | 231 | ``` 232 | 233 | ## Text Conditioned Audio Synthesis 234 | 235 | Update: Looks like this will work, given 'VALL-E' 236 | 237 | ex. Semantic Transformer 238 | 239 | ```python 240 | import torch 241 | from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer 242 | 243 | wav2vec = HubertWithKmeans( 244 | checkpoint_path = './hubert/hubert_base_ls960.pt', 245 | kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin' 246 | ) 247 | 248 | semantic_transformer = SemanticTransformer( 249 | num_semantic_tokens = 500, 250 | dim = 1024, 251 | depth = 6, 252 | has_condition = True, # this will have to be set to True 253 | cond_as_self_attn_prefix = True # whether to condition as prefix to self attention, instead of cross attention, as was done in 'VALL-E' paper 254 | ).cuda() 255 | 256 | # mock text video dataset (as an example) 257 | 258 | # you will have to extend your own from `Dataset`, and return an audio tensor as well as a string (the audio description) in any order (the framework will autodetect and route it into the transformer) 259 | 260 | from torch.utils.data import Dataset 261 | 262 | class MockTextAudioDataset(Dataset): 263 | def __init__(self, length = 100, audio_length = 320 * 32): 264 | super().__init__() 265 | self.audio_length = audio_length 266 | self.len = length 267 | 268 | def __len__(self): 269 | return self.len 270 | 271 | def __getitem__(self, idx): 272 | mock_audio = torch.randn(self.audio_length) 273 | mock_caption = 'audio caption' 274 | return mock_caption, mock_audio 275 | 276 | dataset = MockTextAudioDataset() 277 | 278 | # instantiate semantic transformer trainer and train 279 | 280 | trainer = SemanticTransformerTrainer( 281 | transformer = semantic_transformer, 282 | wav2vec = wav2vec, 283 | dataset = dataset, 284 | batch_size = 4, 285 | grad_accum_every = 8, 286 | data_max_length = 320 * 32, 287 | num_train_steps = 1_000_000 288 | ) 289 | 290 | trainer.train() 291 | 292 | # after much training above 293 | 294 | sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_size = 1, max_length = 2) # (1, < 128) - may terminate early if it detects [eos] 295 | 296 | ``` 
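For completeness, below is a minimal sketch of what a real (non-mock) text-audio dataset could look like, assuming each `<name>.wav` file sits next to a `<name>.txt` caption with the same stem. The folder layout, the class name and the resampling/cropping choices are illustrative assumptions and not part of the library; the only requirement, as noted above, is that `__getitem__` returns an audio tensor and a caption string in either order.

```python
from pathlib import Path

import torchaudio
from torchaudio.functional import resample
from torch.utils.data import Dataset

class FolderTextAudioDataset(Dataset):
    def __init__(self, folder, target_sample_hz = 16000, max_length = 320 * 32):
        super().__init__()
        self.files = sorted(Path(folder).glob('**/*.wav'))
        assert len(self.files) > 0, 'no wav files found'

        self.target_sample_hz = target_sample_hz  # 16 kHz matches HubertWithKmeans' default
        self.max_length = max_length

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        wav_path = self.files[idx]

        # caption is assumed to live in a sibling .txt file with the same stem
        caption = wav_path.with_suffix('.txt').read_text().strip()

        audio, sample_hz = torchaudio.load(wav_path)
        audio = audio.mean(dim = 0)                                 # mix down to mono
        audio = resample(audio, sample_hz, self.target_sample_hz)   # resample to the wav2vec sample rate
        audio = audio[:self.max_length]                             # crop to the training length

        # order does not matter - the trainer autodetects which element is the caption and which is the audio
        return caption, audio
```

Such a dataset can then be handed to `SemanticTransformerTrainer` through the same `dataset` argument used for `MockTextAudioDataset` above.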
297 | 
298 | ## Multi-GPU
299 | 
300 | Because all the trainer classes use 🤗 Accelerate, you can easily do multi-GPU training by using the `accelerate` command, as follows.
301 | 
302 | At the project root
303 | 
304 | ```bash
305 | $ accelerate config
306 | ```
307 | 
308 | Then, in the same directory
309 | 
310 | ```bash
311 | $ accelerate launch train.py
312 | ```
313 | 
314 | ## Todo
315 | 
316 | - [x] complete CoarseTransformer
317 | - [x] use fairseq vq-wav2vec for embeddings
318 | - [x] add conditioning
319 | - [x] add classifier free guidance
320 | - [x] add unique consecutive for
321 | - [x] incorporate ability to use hubert intermediate features as semantic tokens, recommended by eonglints
322 | - [x] accommodate variable lengthed audio, bring in eos token
323 | - [x] make sure unique consecutive works with coarse transformer
324 | - [x] handle, when generating semantic tokens, that the last logits may not necessarily be the last in the sequence given unique consecutive processing
325 | - [x] pretty printing all discriminator losses to log
326 | - [x] complete sampling code for both Coarse and Fine Transformers, which will be tricky
327 | - [x] make sure full inference with or without prompting works on the `AudioLM` class
328 | - [x] complete full training code for soundstream, taking care of discriminator training
329 | - [x] add efficient gradient penalty for discriminators for soundstream
330 | - [x] wire up sample hz from sound dataset -> transformers, and have proper resampling during training - think about whether to allow for dataset to have sound files of varying sample hz or enforce the same sample hz
331 | - [x] full transformer training code for all three transformers
332 | - [x] refactor so semantic transformer has a wrapper that handles unique consecutives as well as wav to hubert or vq-wav2vec
333 | - [x] simply not self attend to eos token on the prompting side (semantic for coarse transformer, coarse for fine transformer)
334 | - [x] add structured dropout from forgetful causal masking, far better than traditional dropouts
335 | - [x] figure out how to suppress logging in fairseq
336 | - [x] assert that all three transformers passed into audiolm are compatible
337 | - [x] allow for specialized relative positional embeddings in fine transformer based on absolute matching positions of quantizers between coarse and fine
338 | - [x] allow for grouped residual vq in soundstream (use `GroupedResidualVQ` from vector-quantize-pytorch lib), from hifi-codec
339 | 
340 | - [ ] redo the positional embeddings in the presence of groups in residual vq
341 | - [ ] test with speech synthesis for starters
342 | - [ ] cli tool, something like `audiolm generate ` and save generated wav file to local directory
343 | - [ ] return a list of waves in the case of variable lengthed audio
344 | - [ ] just take care of the edge case in coarse transformer text conditioned training, where the raw wave is resampled at different frequencies.
autodetermine how to route based on length 345 | 346 | ## Citations 347 | 348 | ```bibtex 349 | @inproceedings{Borsos2022AudioLMAL, 350 | title = {AudioLM: a Language Modeling Approach to Audio Generation}, 351 | author = {Zal{\'a}n Borsos and Rapha{\"e}l Marinier and Damien Vincent and Eugene Kharitonov and Olivier Pietquin and Matthew Sharifi and Olivier Teboul and David Grangier and Marco Tagliasacchi and Neil Zeghidour}, 352 | year = {2022} 353 | } 354 | ``` 355 | 356 | ```bibtex 357 | @misc{https://doi.org/10.48550/arxiv.2107.03312, 358 | title = {SoundStream: An End-to-End Neural Audio Codec}, 359 | author = {Zeghidour, Neil and Luebs, Alejandro and Omran, Ahmed and Skoglund, Jan and Tagliasacchi, Marco}, 360 | publisher = {arXiv}, 361 | url = {https://arxiv.org/abs/2107.03312}, 362 | year = {2021} 363 | } 364 | ``` 365 | 366 | ```bibtex 367 | @misc{shazeer2020glu, 368 | title = {GLU Variants Improve Transformer}, 369 | author = {Noam Shazeer}, 370 | year = {2020}, 371 | url = {https://arxiv.org/abs/2002.05202} 372 | } 373 | ``` 374 | 375 | ```bibtex 376 | @article{Shazeer2019FastTD, 377 | title = {Fast Transformer Decoding: One Write-Head is All You Need}, 378 | author = {Noam M. Shazeer}, 379 | journal = {ArXiv}, 380 | year = {2019}, 381 | volume = {abs/1911.02150} 382 | } 383 | ``` 384 | 385 | ```bibtex 386 | @article{Ho2022ClassifierFreeDG, 387 | title = {Classifier-Free Diffusion Guidance}, 388 | author = {Jonathan Ho}, 389 | journal = {ArXiv}, 390 | year = {2022}, 391 | volume = {abs/2207.12598} 392 | } 393 | ``` 394 | 395 | ```bibtex 396 | @misc{crowson2022, 397 | author = {Katherine Crowson}, 398 | url = {https://twitter.com/rivershavewings} 399 | } 400 | ``` 401 | 402 | ```bibtex 403 | @misc{ding2021cogview, 404 | title = {CogView: Mastering Text-to-Image Generation via Transformers}, 405 | author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang}, 406 | year = {2021}, 407 | eprint = {2105.13290}, 408 | archivePrefix = {arXiv}, 409 | primaryClass = {cs.CV} 410 | } 411 | ``` 412 | 413 | ```bibtex 414 | @article{Liu2022FCMFC, 415 | title = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners}, 416 | author = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel}, 417 | journal = {ArXiv}, 418 | year = {2022}, 419 | volume = {abs/2210.13432} 420 | } 421 | ``` 422 | 423 | ```bibtex 424 | @inproceedings{anonymous2022normformer, 425 | title = {NormFormer: Improved Transformer Pretraining with Extra Normalization}, 426 | author = {Anonymous}, 427 | booktitle = {Submitted to The Tenth International Conference on Learning Representations }, 428 | year = {2022}, 429 | url = {https://openreview.net/forum?id=GMYWzWztDx5}, 430 | note = {under review} 431 | } 432 | ``` 433 | 434 | ```bibtex 435 | @article{Li2021LocalViTBL, 436 | title = {LocalViT: Bringing Locality to Vision Transformers}, 437 | author = {Yawei Li and K. 
Zhang and Jie Cao and Radu Timofte and Luc Van Gool}, 438 | journal = {ArXiv}, 439 | year = {2021}, 440 | volume = {abs/2104.05707} 441 | } 442 | ``` 443 | 444 | ```bibtex 445 | @misc{liu2021swin, 446 | title = {Swin Transformer V2: Scaling Up Capacity and Resolution}, 447 | author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo}, 448 | year = {2021}, 449 | eprint = {2111.09883}, 450 | archivePrefix = {arXiv}, 451 | primaryClass = {cs.CV} 452 | } 453 | ``` 454 | 455 | ```bibtex 456 | @inproceedings{Ma2022MegaMA, 457 | title = {Mega: Moving Average Equipped Gated Attention}, 458 | author = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer}, 459 | year = {2022} 460 | } 461 | ``` 462 | 463 | ```bibtex 464 | @misc{gilmer2023intriguing 465 | title = {Intriguing Properties of Transformer Training Instabilities}, 466 | author = {Justin Gilmer, Andrea Schioppa, and Jeremy Cohen}, 467 | year = {2023}, 468 | status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams} 469 | } 470 | ``` 471 | 472 | ```bibtex 473 | @article{Defossez2022HighFN, 474 | title = {High Fidelity Neural Audio Compression}, 475 | author = {Alexandre D'efossez and Jade Copet and Gabriel Synnaeve and Yossi Adi}, 476 | journal = {ArXiv}, 477 | year = {2022}, 478 | volume = {abs/2210.13438} 479 | } 480 | ``` 481 | 482 | ```bibtex 483 | @article{Hu2017SqueezeandExcitationN, 484 | title = {Squeeze-and-Excitation Networks}, 485 | author = {Jie Hu and Li Shen and Gang Sun}, 486 | journal = {2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 487 | year = {2017}, 488 | pages = {7132-7141} 489 | } 490 | ``` 491 | 492 | ```bibtex 493 | @inproceedings{Yang2023HiFiCodecGV, 494 | title = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec}, 495 | author = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou}, 496 | year = {2023} 497 | } 498 | ``` 499 | -------------------------------------------------------------------------------- /audiolm_pytorch_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 12, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/" 9 | }, 10 | "id": "n337KoD2om3L", 11 | "outputId": "97ada0c6-f21c-483e-d63d-08abddd49004" 12 | }, 13 | "outputs": [ 14 | { 15 | "name": "stdout", 16 | "output_type": "stream", 17 | "text": [ 18 | "Mon Jan 30 20:47:47 2023 \n", 19 | "+-----------------------------------------------------------------------------+\n", 20 | "| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |\n", 21 | "|-------------------------------+----------------------+----------------------+\n", 22 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 23 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 24 | "| | | MIG M. 
|\n", 25 | "|===============================+======================+======================|\n", 26 | "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", 27 | "| N/A 73C P0 32W / 70W | 10692MiB / 15360MiB | 0% Default |\n", 28 | "| | | N/A |\n", 29 | "+-------------------------------+----------------------+----------------------+\n", 30 | " \n", 31 | "+-----------------------------------------------------------------------------+\n", 32 | "| Processes: |\n", 33 | "| GPU GI CI PID Type Process name GPU Memory |\n", 34 | "| ID ID Usage |\n", 35 | "|=============================================================================|\n", 36 | "| 0 N/A N/A 5896 C 10689MiB |\n", 37 | "+-----------------------------------------------------------------------------+\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "!nvidia-smi\n", 43 | "\n", 44 | "# If this doesn't work, there's no GPU available or detected" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 13, 50 | "metadata": { 51 | "colab": { 52 | "base_uri": "https://localhost:8080/" 53 | }, 54 | "id": "TLJAcUHpvmp4", 55 | "outputId": "95bcda95-a484-40c6-e5a7-47f4378759a8" 56 | }, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", 63 | "Requirement already satisfied: audiolm-pytorch in /usr/local/lib/python3.8/dist-packages (0.7.5)\n", 64 | "Requirement already satisfied: ema-pytorch in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.1.4)\n", 65 | "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.1.97)\n", 66 | "Requirement already satisfied: beartype in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.12.0)\n", 67 | "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.0.2)\n", 68 | "Requirement already satisfied: torchaudio in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.13.1+cu116)\n", 69 | "Requirement already satisfied: joblib in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.2.0)\n", 70 | "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.13.1+cu116)\n", 71 | "Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (4.26.0)\n", 72 | "Requirement already satisfied: Mega-pytorch in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.0.12)\n", 73 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (4.64.1)\n", 74 | "Requirement already satisfied: accelerate in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.15.0)\n", 75 | "Requirement already satisfied: vector-quantize-pytorch>=0.10.15 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.10.15)\n", 76 | "Requirement already satisfied: einops>=0.6 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.6.0)\n", 77 | "Requirement already satisfied: local-attention>=1.5.7 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.5.8)\n", 78 | "Requirement already satisfied: fairseq in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.12.2)\n", 79 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->audiolm-pytorch) (4.4.0)\n", 80 | "Requirement 
already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (21.3)\n", 81 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (6.0)\n", 82 | "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (1.21.6)\n", 83 | "Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (5.4.8)\n", 84 | "Requirement already satisfied: bitarray in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2.6.2)\n", 85 | "Requirement already satisfied: regex in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2022.6.2)\n", 86 | "Requirement already satisfied: hydra-core<1.1,>=1.0.7 in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (1.0.7)\n", 87 | "Requirement already satisfied: cffi in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (1.15.1)\n", 88 | "Requirement already satisfied: cython in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (0.29.33)\n", 89 | "Requirement already satisfied: omegaconf<2.1 in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2.0.6)\n", 90 | "Requirement already satisfied: sacrebleu>=1.4.12 in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2.3.1)\n", 91 | "Requirement already satisfied: scipy in /usr/local/lib/python3.8/dist-packages (from Mega-pytorch->audiolm-pytorch) (1.7.3)\n", 92 | "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn->audiolm-pytorch) (3.1.0)\n", 93 | "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (2.25.1)\n", 94 | "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (0.12.0)\n", 95 | "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (0.13.2)\n", 96 | "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (3.9.0)\n", 97 | "Requirement already satisfied: antlr4-python3-runtime==4.8 in /usr/local/lib/python3.8/dist-packages (from hydra-core<1.1,>=1.0.7->fairseq->audiolm-pytorch) (4.8)\n", 98 | "Requirement already satisfied: importlib-resources in /usr/local/lib/python3.8/dist-packages (from hydra-core<1.1,>=1.0.7->fairseq->audiolm-pytorch) (5.10.2)\n", 99 | "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->accelerate->audiolm-pytorch) (3.0.9)\n", 100 | "Requirement already satisfied: portalocker in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (2.7.0)\n", 101 | "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (0.8.10)\n", 102 | "Requirement already satisfied: colorama in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (0.4.6)\n", 103 | "Requirement already satisfied: lxml in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (4.9.2)\n", 104 | "Requirement already satisfied: pycparser in 
/usr/local/lib/python3.8/dist-packages (from cffi->fairseq->audiolm-pytorch) (2.21)\n", 105 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (2.10)\n", 106 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (2022.12.7)\n", 107 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (1.24.3)\n", 108 | "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (4.0.0)\n", 109 | "Requirement already satisfied: zipp>=3.1.0 in /usr/local/lib/python3.8/dist-packages (from importlib-resources->hydra-core<1.1,>=1.0.7->fairseq->audiolm-pytorch) (3.11.0)\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "!pip install audiolm-pytorch" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": { 120 | "id": "xuNcsDJsvQwh" 121 | }, 122 | "source": [ 123 | "## Setup\n", 124 | "\n", 125 | "Includes:\n", 126 | "\n", 127 | "- How to generate a placeholder dataset if you haven't already, just the basics to run \"training\" e2e on a tiny dataset\n", 128 | "- How to download a dataset from OpenSLR" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": { 134 | "id": "jBxNK5cKW--_" 135 | }, 136 | "source": [ 137 | "### Imports & paths" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 14, 143 | "metadata": { 144 | "id": "OrNeKngVVM0L" 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "# imports\n", 149 | "import math\n", 150 | "import wave\n", 151 | "import struct\n", 152 | "import os\n", 153 | "import urllib.request\n", 154 | "import tarfile\n", 155 | "from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM\n", 156 | "from torch import nn\n", 157 | "import torch\n", 158 | "import torchaudio\n", 159 | "\n", 160 | "\n", 161 | "# define all dataset paths, checkpoints, etc\n", 162 | "dataset_folder = \"placeholder_dataset\"\n", 163 | "soundstream_ckpt = \"results/soundstream.8.pt\" # this can change depending on number of steps\n", 164 | "hubert_ckpt = 'hubert/hubert_base_ls960.pt'\n", 165 | "hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' # listed in row \"HuBERT Base (~95M params)\", column Quantizer" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": { 171 | "id": "pA56YODZXBtf" 172 | }, 173 | "source": [ 174 | "### Data" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 15, 180 | "metadata": { 181 | "id": "6nnPceFWwedh" 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "# Placeholder data generation\n", 186 | "def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=44100.0):\n", 187 | " # code adapted from https://stackoverflow.com/a/33913403\n", 188 | " audio = []\n", 189 | " num_samples = duration_ms * (sample_rate / 1000.0)\n", 190 | " for x in range(int(num_samples)):\n", 191 | " audio.append(volume * math.sin(2 * math.pi * freq * (x / sample_rate)))\n", 192 | " return audio\n", 193 | "\n", 194 | "def save_wav(file_name, audio, sample_rate=44100.0):\n", 195 | " # Open up a wav file\n", 196 | 
" wav_file=wave.open(file_name,\"w\")\n", 197 | " # wav params\n", 198 | " nchannels = 1\n", 199 | " sampwidth = 2\n", 200 | " # 44100 is the industry standard sample rate - CD quality. If you need to\n", 201 | " # save on file size you can adjust it downwards. The stanard for low quality\n", 202 | " # is 8000 or 8kHz.\n", 203 | " nframes = len(audio)\n", 204 | " comptype = \"NONE\"\n", 205 | " compname = \"not compressed\"\n", 206 | " wav_file.setparams((nchannels, sampwidth, sample_rate, nframes, comptype, compname))\n", 207 | " # WAV files here are using short, 16 bit, signed integers for the \n", 208 | " # sample size. So we multiply the floating point data we have by 32767, the\n", 209 | " # maximum value for a short integer. NOTE: It is theortically possible to\n", 210 | " # use the floating point -1.0 to 1.0 data directly in a WAV file but not\n", 211 | " # obvious how to do that using the wave module in python.\n", 212 | " for sample in audio:\n", 213 | " wav_file.writeframes(struct.pack('h', int( sample * 32767.0 )))\n", 214 | " wav_file.close()\n", 215 | " return\n", 216 | "\n", 217 | "def make_placeholder_dataset():\n", 218 | " # Make a placeholder dataset with a few .wav files that you can \"train\" on, just to verify things work e2e\n", 219 | " if os.path.isdir(dataset_folder):\n", 220 | " return\n", 221 | " os.makedirs(dataset_folder)\n", 222 | " save_wav(f\"{dataset_folder}/example.wav\", get_sinewave())\n", 223 | " save_wav(f\"{dataset_folder}/example2.wav\", get_sinewave(duration_ms=500))\n", 224 | " os.makedirs(f\"{dataset_folder}/subdirectory\")\n", 225 | " save_wav(f\"{dataset_folder}/subdirectory/example.wav\", get_sinewave(freq=330.0))\n", 226 | "\n", 227 | "make_placeholder_dataset()" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 16, 233 | "metadata": { 234 | "id": "jwYCbFpHvmRI" 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "# Get actual dataset. Uncomment this if you want to try training on real data\n", 239 | "\n", 240 | "# full dataset: https://www.openslr.org/12\n", 241 | "# We'll use https://us.openslr.org/resources/12/dev-clean.tar.gz development set, \"clean\" speech.\n", 242 | "# We *should* train on, well, training, but this is just to demo running things end-to-end at all so I just picked a small clean set.\n", 243 | "\n", 244 | "# url = \"https://us.openslr.org/resources/12/dev-clean.tar.gz\"\n", 245 | "# filename = \"dev-clean\"\n", 246 | "# filename_targz = filename + \".tar.gz\"\n", 247 | "# if not os.path.isfile(filename_targz):\n", 248 | "# urllib.request.urlretrieve(url, filename_targz)\n", 249 | "# if not os.path.isdir(filename):\n", 250 | "# # open file\n", 251 | "# with tarfile.open(filename_targz) as t:\n", 252 | "# t.extractall(filename)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": { 258 | "id": "PYcI0aXEwuxR" 259 | }, 260 | "source": [ 261 | "## Training\n", 262 | "\n", 263 | "Now that we have a dataset, we can train AudioLM.\n", 264 | "\n", 265 | "**Note**: do NOT type \"y\" to overwrite previous experiments/ checkpoints when running through the cells here unless you're ready to the entire results folder! Otherwise you will end up erasing things (e.g. you train SoundStream first, and if you choose \"overwrite\" then you lose the SoundStream checkpoint when you then train SemanticTransformer)." 
266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": { 271 | "id": "T7GiyBcBWiZV" 272 | }, 273 | "source": [ 274 | "### SoundStream" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 17, 280 | "metadata": { 281 | "colab": { 282 | "base_uri": "https://localhost:8080/" 283 | }, 284 | "id": "nGU0OZiOwPEO", 285 | "outputId": "21dd959c-6458-4477-8403-cf810166f38d" 286 | }, 287 | "outputs": [ 288 | { 289 | "name": "stdout", 290 | "output_type": "stream", 291 | "text": [ 292 | "training with dataset of 2 samples and validating with randomly splitted 1 samples\n", 293 | "0: soundstream total loss: 167.262, soundstream recon loss: 1.123 | discr (scale 1) loss: 2.003 | discr (scale 0.5) loss: 1.999 | discr (scale 0.25) loss: 1.999\n", 294 | "0: saving to results\n", 295 | "0: saving model to results\n", 296 | "1: soundstream total loss: 182.282, soundstream recon loss: 1.389 | discr (scale 1) loss: 1.938 | discr (scale 0.5) loss: 1.928 | discr (scale 0.25) loss: 1.928\n", 297 | "2: soundstream total loss: 196.668, soundstream recon loss: 1.450 | discr (scale 1) loss: 1.845 | discr (scale 0.5) loss: 1.842 | discr (scale 0.25) loss: 1.843\n", 298 | "2: saving to results\n", 299 | "3: soundstream total loss: 216.329, soundstream recon loss: 1.451 | discr (scale 1) loss: 1.751 | discr (scale 0.5) loss: 1.750 | discr (scale 0.25) loss: 1.757\n", 300 | "4: soundstream total loss: 206.804, soundstream recon loss: 1.167 | discr (scale 1) loss: 1.671 | discr (scale 0.5) loss: 1.706 | discr (scale 0.25) loss: 1.724\n", 301 | "4: saving to results\n", 302 | "4: saving model to results\n", 303 | "5: soundstream total loss: 195.325, soundstream recon loss: 0.929 | discr (scale 1) loss: 1.348 | discr (scale 0.5) loss: 1.372 | discr (scale 0.25) loss: 1.482\n", 304 | "6: soundstream total loss: 245.195, soundstream recon loss: 1.054 | discr (scale 1) loss: 1.060 | discr (scale 0.5) loss: 1.244 | discr (scale 0.25) loss: 1.288\n", 305 | "6: saving to results\n", 306 | "7: soundstream total loss: 245.724, soundstream recon loss: 0.970 | discr (scale 1) loss: 1.092 | discr (scale 0.5) loss: 1.358 | discr (scale 0.25) loss: 1.079\n", 307 | "8: soundstream total loss: 202.707, soundstream recon loss: 0.786 | discr (scale 1) loss: 0.733 | discr (scale 0.5) loss: 0.687 | discr (scale 0.25) loss: 0.790\n", 308 | "8: saving to results\n", 309 | "8: saving model to results\n", 310 | "training complete\n" 311 | ] 312 | } 313 | ], 314 | "source": [ 315 | "soundstream = SoundStream(\n", 316 | " codebook_size = 1024,\n", 317 | " rq_num_quantizers = 8,\n", 318 | ")\n", 319 | "\n", 320 | "trainer = SoundStreamTrainer(\n", 321 | " soundstream,\n", 322 | " folder = dataset_folder,\n", 323 | " batch_size = 4,\n", 324 | " grad_accum_every = 8, # effective batch size of 32\n", 325 | " data_max_length = 320 * 32,\n", 326 | " save_results_every = 2,\n", 327 | " save_model_every = 4,\n", 328 | " num_train_steps = 9\n", 329 | ").cuda()\n", 330 | "# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes\n", 331 | "# adjusting save_*_every variables for the same reason\n", 332 | "\n", 333 | "trainer.train()" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": { 339 | "id": "lqjN28L4Wc5Q" 340 | }, 341 | "source": [ 342 | "### SemanticTransformer" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 18, 348 | "metadata": { 349 | "colab": { 350 | "base_uri": "https://localhost:8080/" 351 | }, 352 | "id": 
"qgd962eSvDzS", 353 | "outputId": "b0550cde-0c8b-4a39-f896-f6f813f50f8c" 354 | }, 355 | "outputs": [ 356 | { 357 | "name": "stderr", 358 | "output_type": "stream", 359 | "text": [ 360 | "/usr/local/lib/python3.8/dist-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator MiniBatchKMeans from version 0.24.0 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n", 361 | "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n", 362 | " warnings.warn(\n" 363 | ] 364 | }, 365 | { 366 | "name": "stdout", 367 | "output_type": "stream", 368 | "text": [ 369 | "training with dataset of 2 samples and validating with randomly splitted 1 samples\n", 370 | "do you want to clear previous experiment checkpoints and results? (y/n) n\n", 371 | "0: loss: 6.648584365844727\n", 372 | "0: valid loss 5.763116359710693\n", 373 | "0: saving model to results\n", 374 | "training complete\n" 375 | ] 376 | } 377 | ], 378 | "source": [ 379 | "# hubert checkpoints can be downloaded at\n", 380 | "# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert\n", 381 | "if not os.path.isdir(\"hubert\"):\n", 382 | " os.makedirs(\"hubert\")\n", 383 | "if not os.path.isfile(hubert_ckpt):\n", 384 | " hubert_ckpt_download = f\"https://dl.fbaipublicfiles.com/{hubert_ckpt}\"\n", 385 | " urllib.request.urlretrieve(hubert_ckpt_download, f\"./{hubert_ckpt}\")\n", 386 | "if not os.path.isfile(hubert_quantizer):\n", 387 | " hubert_quantizer_download = f\"https://dl.fbaipublicfiles.com/{hubert_quantizer}\"\n", 388 | " urllib.request.urlretrieve(hubert_quantizer_download, f\"./{hubert_quantizer}\")\n", 389 | "\n", 390 | "wav2vec = HubertWithKmeans(\n", 391 | " checkpoint_path = f'./{hubert_ckpt}',\n", 392 | " kmeans_path = f'./{hubert_quantizer}'\n", 393 | ")\n", 394 | "\n", 395 | "semantic_transformer = SemanticTransformer(\n", 396 | " num_semantic_tokens = wav2vec.codebook_size,\n", 397 | " dim = 1024,\n", 398 | " depth = 6\n", 399 | ").cuda()\n", 400 | "\n", 401 | "\n", 402 | "trainer = SemanticTransformerTrainer(\n", 403 | " transformer = semantic_transformer,\n", 404 | " wav2vec = wav2vec,\n", 405 | " folder = dataset_folder,\n", 406 | " batch_size = 1,\n", 407 | " data_max_length = 320 * 32,\n", 408 | " num_train_steps = 1\n", 409 | ")\n", 410 | "\n", 411 | "trainer.train()" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "4eEvIzhEWwRz" 418 | }, 419 | "source": [ 420 | "### CoarseTransformer" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 19, 426 | "metadata": { 427 | "colab": { 428 | "base_uri": "https://localhost:8080/" 429 | }, 430 | "id": "1LeWmaNHzzY9", 431 | "outputId": "7e7ecb3b-f59e-4d18-c8c9-64762e9b43fc" 432 | }, 433 | "outputs": [ 434 | { 435 | "name": "stderr", 436 | "output_type": "stream", 437 | "text": [ 438 | "/usr/local/lib/python3.8/dist-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator MiniBatchKMeans from version 0.24.0 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. 
For more info please refer to:\n", 439 | "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n", 440 | " warnings.warn(\n" 441 | ] 442 | }, 443 | { 444 | "name": "stdout", 445 | "output_type": "stream", 446 | "text": [ 447 | "training with dataset of 2 samples and validating with randomly splitted 1 samples\n", 448 | "do you want to clear previous experiment checkpoints and results? (y/n) n\n", 449 | "0: loss: 63.983970642089844\n", 450 | "0: valid loss 63.398582458496094\n", 451 | "0: saving model to results\n", 452 | "1: loss: 65.85967254638672\n", 453 | "2: loss: 62.4722900390625\n", 454 | "2: valid loss 50.01605987548828\n", 455 | "3: loss: 11.735434532165527\n", 456 | "4: loss: 3.976104497909546\n", 457 | "4: valid loss 46.094608306884766\n", 458 | "4: saving model to results\n", 459 | "5: loss: 58.27140426635742\n", 460 | "6: loss: 41.68347930908203\n", 461 | "6: valid loss 45.54595184326172\n", 462 | "7: loss: 2.2387890815734863\n", 463 | "8: loss: 0.4718627631664276\n", 464 | "8: valid loss 39.10848617553711\n", 465 | "8: saving model to results\n", 466 | "training complete\n" 467 | ] 468 | } 469 | ], 470 | "source": [ 471 | "wav2vec = HubertWithKmeans(\n", 472 | " checkpoint_path = f'./{hubert_ckpt}',\n", 473 | " kmeans_path = f'./{hubert_quantizer}'\n", 474 | ")\n", 475 | "\n", 476 | "soundstream = SoundStream(\n", 477 | " codebook_size = 1024,\n", 478 | " rq_num_quantizers = 8,\n", 479 | ")\n", 480 | "\n", 481 | "soundstream.load(f\"./{soundstream_ckpt}\")\n", 482 | "\n", 483 | "coarse_transformer = CoarseTransformer(\n", 484 | " num_semantic_tokens = wav2vec.codebook_size,\n", 485 | " codebook_size = 1024,\n", 486 | " num_coarse_quantizers = 3,\n", 487 | " dim = 512,\n", 488 | " depth = 6\n", 489 | ")\n", 490 | "\n", 491 | "trainer = CoarseTransformerTrainer(\n", 492 | " transformer = coarse_transformer,\n", 493 | " codec = soundstream,\n", 494 | " wav2vec = wav2vec,\n", 495 | " folder = dataset_folder,\n", 496 | " batch_size = 1,\n", 497 | " data_max_length = 320 * 32,\n", 498 | " save_results_every = 2,\n", 499 | " save_model_every = 4,\n", 500 | " num_train_steps = 9\n", 501 | ")\n", 502 | "# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes\n", 503 | "# adjusting save_*_every variables for the same reason\n", 504 | "\n", 505 | "trainer.train()" 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "metadata": { 511 | "id": "fRvj7qOJWzmw" 512 | }, 513 | "source": [ 514 | "### FineTransformer" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 20, 520 | "metadata": { 521 | "colab": { 522 | "base_uri": "https://localhost:8080/" 523 | }, 524 | "id": "ZRaEhRRKWg8F", 525 | "outputId": "7cc166c4-c8e9-45ef-8293-8f5381c2d3af" 526 | }, 527 | "outputs": [ 528 | { 529 | "name": "stdout", 530 | "output_type": "stream", 531 | "text": [ 532 | "training with dataset of 2 samples and validating with randomly splitted 1 samples\n", 533 | "do you want to clear previous experiment checkpoints and results? 
(y/n) n\n", 534 | "0: loss: 70.90608215332031\n", 535 | "0: valid loss 65.99951171875\n", 536 | "0: saving model to results\n", 537 | "1: loss: 43.6014289855957\n", 538 | "2: loss: 8.300681114196777\n", 539 | "3: loss: 61.23375701904297\n", 540 | "4: loss: 63.34052276611328\n", 541 | "5: loss: 2.010118246078491\n", 542 | "6: loss: 56.52588653564453\n", 543 | "7: loss: 0.5423888564109802\n", 544 | "8: loss: 0.005095238331705332\n", 545 | "training complete\n" 546 | ] 547 | } 548 | ], 549 | "source": [ 550 | "soundstream = SoundStream(\n", 551 | " codebook_size = 1024,\n", 552 | " rq_num_quantizers = 8,\n", 553 | ")\n", 554 | "\n", 555 | "soundstream.load(f\"./{soundstream_ckpt}\")\n", 556 | "\n", 557 | "fine_transformer = FineTransformer(\n", 558 | " num_coarse_quantizers = 3,\n", 559 | " num_fine_quantizers = 5,\n", 560 | " codebook_size = 1024,\n", 561 | " dim = 512,\n", 562 | " depth = 6\n", 563 | ")\n", 564 | "\n", 565 | "trainer = FineTransformerTrainer(\n", 566 | " transformer = fine_transformer,\n", 567 | " codec = soundstream,\n", 568 | " folder = dataset_folder,\n", 569 | " batch_size = 1,\n", 570 | " data_max_length = 320 * 32,\n", 571 | " num_train_steps = 9\n", 572 | ")\n", 573 | "# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes\n", 574 | "# adjusting save_*_every variables for the same reason\n", 575 | "\n", 576 | "trainer.train()" 577 | ] 578 | }, 579 | { 580 | "cell_type": "markdown", 581 | "metadata": { 582 | "id": "QoHgkgA3XKXH" 583 | }, 584 | "source": [ 585 | "## Inference" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 21, 591 | "metadata": { 592 | "colab": { 593 | "base_uri": "https://localhost:8080/" 594 | }, 595 | "id": "rzghrux5WinW", 596 | "outputId": "9dd39f7f-0046-4a5f-826e-a442345987af" 597 | }, 598 | "outputs": [ 599 | { 600 | "name": "stderr", 601 | "output_type": "stream", 602 | "text": [ 603 | "generating semantic: 0%| | 10/2048 [00:00<00:25, 78.55it/s]\n", 604 | "generating coarse: 100%|██████████| 512/512 [00:14<00:00, 34.83it/s]\n", 605 | "generating fine: 100%|██████████| 512/512 [02:56<00:00, 2.91it/s]\n" 606 | ] 607 | } 608 | ], 609 | "source": [ 610 | "# Everything together\n", 611 | "audiolm = AudioLM(\n", 612 | " wav2vec = wav2vec,\n", 613 | " codec = soundstream,\n", 614 | " semantic_transformer = semantic_transformer,\n", 615 | " coarse_transformer = coarse_transformer,\n", 616 | " fine_transformer = fine_transformer\n", 617 | ")\n", 618 | "\n", 619 | "generated_wav = audiolm(batch_size = 1)" 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "execution_count": 22, 625 | "metadata": { 626 | "id": "4rQPHTSRngEr" 627 | }, 628 | "outputs": [], 629 | "source": [ 630 | "output_path = \"out.wav\"\n", 631 | "sample_rate = 44100\n", 632 | "torchaudio.save(output_path, generated_wav.cpu(), sample_rate)" 633 | ] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "execution_count": 22, 638 | "metadata": { 639 | "id": "is9wLY_ncDYK" 640 | }, 641 | "outputs": [], 642 | "source": [] 643 | } 644 | ], 645 | "metadata": { 646 | "accelerator": "GPU", 647 | "colab": { 648 | "provenance": [] 649 | }, 650 | "gpuClass": "standard", 651 | "kernelspec": { 652 | "display_name": "Python 3 (ipykernel)", 653 | "language": "python", 654 | "name": "python3" 655 | }, 656 | "language_info": { 657 | "codemirror_mode": { 658 | "name": "ipython", 659 | "version": 3 660 | }, 661 | "file_extension": ".py", 662 | "mimetype": "text/x-python", 663 | "name": "python", 664 | "nbconvert_exporter": 
"python", 665 | "pygments_lexer": "ipython3", 666 | "version": "3.9.13" 667 | } 668 | }, 669 | "nbformat": 4, 670 | "nbformat_minor": 1 671 | } 672 | -------------------------------------------------------------------------------- /audiolm_pytorch/soundstream.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from itertools import cycle 3 | from pathlib import Path 4 | 5 | from functools import partial, wraps 6 | from itertools import zip_longest 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, einsum 11 | from torch.autograd import grad as torch_grad 12 | import torch.nn.functional as F 13 | from torch.linalg import vector_norm 14 | 15 | import torchaudio.transforms as T 16 | from torchaudio.functional import resample 17 | 18 | from einops import rearrange, reduce, pack, unpack 19 | 20 | from vector_quantize_pytorch import GroupedResidualVQ 21 | 22 | from local_attention import LocalMHA 23 | from local_attention.transformer import FeedForward, DynamicPositionBias 24 | 25 | from audiolm_pytorch.utils import curtail_to_multiple 26 | 27 | from audiolm_pytorch.version import __version__ 28 | from packaging import version 29 | parsed_version = version.parse(__version__) 30 | 31 | import pickle 32 | 33 | # helper functions 34 | 35 | def exists(val): 36 | return val is not None 37 | 38 | def default(val, d): 39 | return val if exists(val) else d 40 | 41 | def cast_tuple(t, l = 1): 42 | return ((t,) * l) if not isinstance(t, tuple) else t 43 | 44 | def filter_by_keys(fn, d): 45 | return {k: v for k, v in d.items() if fn(k)} 46 | 47 | def map_keys(fn, d): 48 | return {fn(k): v for k, v in d.items()} 49 | 50 | # gan losses 51 | 52 | def log(t, eps = 1e-20): 53 | return torch.log(t.clamp(min = eps)) 54 | 55 | def hinge_discr_loss(fake, real): 56 | return (F.relu(1 + fake) + F.relu(1 - real)).mean() 57 | 58 | def hinge_gen_loss(fake): 59 | return -fake.mean() 60 | 61 | def leaky_relu(p = 0.1): 62 | return nn.LeakyReLU(p) 63 | 64 | def gradient_penalty(wave, output, weight = 10): 65 | batch_size, device = wave.shape[0], wave.device 66 | 67 | gradients = torch_grad( 68 | outputs = output, 69 | inputs = wave, 70 | grad_outputs = torch.ones_like(output), 71 | create_graph = True, 72 | retain_graph = True, 73 | only_inputs = True 74 | )[0] 75 | 76 | gradients = rearrange(gradients, 'b ... 
-> b (...)') 77 | return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean() 78 | 79 | # better sequential 80 | 81 | def Sequential(*mods): 82 | return nn.Sequential(*filter(exists, mods)) 83 | 84 | # discriminators 85 | 86 | class MultiScaleDiscriminator(nn.Module): 87 | def __init__( 88 | self, 89 | channels = 16, 90 | layers = 4, 91 | groups = (4, 16, 64, 256), 92 | chan_max = 1024, 93 | input_channels = 1 94 | ): 95 | super().__init__() 96 | self.init_conv = nn.Conv1d(input_channels, channels, 15, padding = 7) 97 | self.conv_layers = nn.ModuleList([]) 98 | 99 | curr_channels = channels 100 | 101 | for _, group in zip(range(layers), groups): 102 | chan_out = min(curr_channels * 4, chan_max) 103 | 104 | self.conv_layers.append(nn.Sequential( 105 | nn.Conv1d(curr_channels, chan_out, 41, stride = 4, padding = 20, groups = group), 106 | leaky_relu() 107 | )) 108 | 109 | curr_channels = chan_out 110 | 111 | self.final_conv = nn.Sequential( 112 | nn.Conv1d(curr_channels, curr_channels, 5, padding = 2), 113 | leaky_relu(), 114 | nn.Conv1d(curr_channels, 1, 3, padding = 1), 115 | ) 116 | 117 | def forward( 118 | self, 119 | x, 120 | return_intermediates = False 121 | ): 122 | x = self.init_conv(x) 123 | intermediates = [] 124 | 125 | for layer in self.conv_layers: 126 | x = layer(x) 127 | intermediates.append(x) 128 | 129 | out = self.final_conv(x) 130 | 131 | if not return_intermediates: 132 | return out 133 | 134 | return out, intermediates 135 | 136 | # autoregressive squeeze excitation 137 | # https://arxiv.org/abs/1709.01507 138 | 139 | class SqueezeExcite(nn.Module): 140 | def __init__(self, dim, reduction_factor = 4, dim_minimum = 8): 141 | super().__init__() 142 | dim_inner = max(dim_minimum, dim // reduction_factor) 143 | self.net = nn.Sequential( 144 | nn.Conv1d(dim, dim_inner, 1), 145 | nn.SiLU(), 146 | nn.Conv1d(dim_inner, dim, 1), 147 | nn.Sigmoid() 148 | ) 149 | 150 | def forward(self, x): 151 | seq, device = x.shape[-2], x.device 152 | 153 | # cumulative mean - since it is autoregressive 154 | 155 | cum_sum = x.cumsum(dim = -2) 156 | denom = torch.arange(1, seq + 1, device = device).float() 157 | cum_mean = cum_sum / rearrange(denom, 'n -> n 1') 158 | 159 | # glu gate 160 | 161 | gate = self.net(cum_mean) 162 | 163 | return x * gate 164 | 165 | # complex stft discriminator 166 | 167 | class ModReLU(nn.Module): 168 | """ 169 | https://arxiv.org/abs/1705.09792 170 | https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801 171 | """ 172 | def __init__(self): 173 | super().__init__() 174 | self.b = nn.Parameter(torch.tensor(0.)) 175 | 176 | def forward(self, x): 177 | return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x)) 178 | 179 | class ComplexConv2d(nn.Module): 180 | def __init__( 181 | self, 182 | dim, 183 | dim_out, 184 | kernel_size, 185 | stride = 1, 186 | padding = 0 187 | ): 188 | super().__init__() 189 | conv = nn.Conv2d(dim, dim_out, kernel_size, dtype = torch.complex64) 190 | self.weight = nn.Parameter(torch.view_as_real(conv.weight)) 191 | self.bias = nn.Parameter(torch.view_as_real(conv.bias)) 192 | 193 | self.stride = stride 194 | self.padding = padding 195 | 196 | def forward(self, x): 197 | weight, bias = map(torch.view_as_complex, (self.weight, self.bias)) 198 | 199 | x = x.to(weight.dtype) 200 | return F.conv2d(x, weight, bias, stride = self.stride, padding = self.padding) 201 | 202 | def ComplexSTFTResidualUnit(chan_in, chan_out, strides): 203 | kernel_sizes = tuple(map(lambda t: t + 2, strides)) 204 | paddings = 
tuple(map(lambda t: t // 2, kernel_sizes)) 205 | 206 | return nn.Sequential( 207 | Residual(Sequential( 208 | ComplexConv2d(chan_in, chan_in, 3, padding = 1), 209 | ModReLU(), 210 | ComplexConv2d(chan_in, chan_in, 3, padding = 1) 211 | )), 212 | ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) 213 | ) 214 | 215 | class ComplexSTFTDiscriminator(nn.Module): 216 | def __init__( 217 | self, 218 | *, 219 | channels = 32, 220 | strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), 221 | chan_mults = (1, 2, 4, 4, 8, 8), 222 | input_channels = 1, 223 | n_fft = 1024, 224 | hop_length = 256, 225 | win_length = 1024, 226 | stft_normalized = False, 227 | logits_abs = True 228 | ): 229 | super().__init__() 230 | self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3) 231 | 232 | layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) 233 | layer_channels = (channels, *layer_channels) 234 | layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) 235 | 236 | curr_channels = channels 237 | 238 | self.layers = nn.ModuleList([]) 239 | 240 | for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): 241 | self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride)) 242 | 243 | self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 244 | 245 | # stft settings 246 | 247 | self.stft_normalized = stft_normalized 248 | 249 | self.n_fft = n_fft 250 | self.hop_length = hop_length 251 | self.win_length = win_length 252 | 253 | # how to output the logits into real space 254 | self.logits_abs = logits_abs 255 | 256 | def forward(self, x, return_intermediates = False): 257 | x = rearrange(x, 'b 1 n -> b n') 258 | 259 | ''' 260 | reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: 261 | The STFT-based discriminator is illustrated in Figure 4 262 | and operates on a single scale, computing the STFT with a 263 | window length of W = 1024 samples and a hop length of 264 | H = 256 samples 265 | ''' 266 | 267 | x = torch.stft( 268 | x, 269 | self.n_fft, 270 | hop_length = self.hop_length, 271 | win_length = self.win_length, 272 | normalized = self.stft_normalized, 273 | return_complex = True 274 | ) 275 | 276 | x = rearrange(x, 'b ... 
-> b 1 ...') 277 | 278 | intermediates = [] 279 | 280 | x = self.init_conv(x) 281 | 282 | intermediates.append(x) 283 | 284 | for layer in self.layers: 285 | x = layer(x) 286 | intermediates.append(x) 287 | 288 | complex_logits = self.final_conv(x) 289 | 290 | if self.logits_abs: 291 | complex_logits = complex_logits.abs() 292 | else: 293 | complex_logits = torch.view_as_real(complex_logits) 294 | 295 | if not return_intermediates: 296 | return complex_logits 297 | 298 | return complex_logits, intermediates 299 | 300 | # sound stream 301 | 302 | class Residual(nn.Module): 303 | def __init__(self, fn): 304 | super().__init__() 305 | self.fn = fn 306 | 307 | def forward(self, x, **kwargs): 308 | return self.fn(x, **kwargs) + x 309 | 310 | class CausalConv1d(nn.Module): 311 | def __init__(self, chan_in, chan_out, kernel_size, pad_mode = 'reflect', **kwargs): 312 | super().__init__() 313 | kernel_size = kernel_size 314 | dilation = kwargs.get('dilation', 1) 315 | stride = kwargs.get('stride', 1) 316 | self.pad_mode = pad_mode 317 | self.causal_padding = dilation * (kernel_size - 1) + (1 - stride) 318 | 319 | self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs) 320 | 321 | def forward(self, x): 322 | x = F.pad(x, (self.causal_padding, 0), mode = self.pad_mode) 323 | return self.conv(x) 324 | 325 | class CausalConvTranspose1d(nn.Module): 326 | def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs): 327 | super().__init__() 328 | self.upsample_factor = stride 329 | self.padding = kernel_size - 1 330 | self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs) 331 | 332 | def forward(self, x): 333 | n = x.shape[-1] 334 | 335 | out = self.conv(x) 336 | out = out[..., :(n * self.upsample_factor)] 337 | 338 | return out 339 | 340 | def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False, pad_mode = 'reflect'): 341 | return Residual(Sequential( 342 | CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation, pad_mode = pad_mode), 343 | nn.ELU(), 344 | CausalConv1d(chan_out, chan_out, 1, pad_mode = pad_mode), 345 | nn.ELU(), 346 | SqueezeExcite(chan_out) if squeeze_excite else None 347 | )) 348 | 349 | def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'): 350 | it = cycle(cycle_dilations) 351 | residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode) 352 | 353 | return nn.Sequential( 354 | residual_unit(chan_in, chan_in, next(it)), 355 | residual_unit(chan_in, chan_in, next(it)), 356 | residual_unit(chan_in, chan_in, next(it)), 357 | CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride) 358 | ) 359 | 360 | def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'): 361 | even_stride = (stride % 2 == 0) 362 | padding = (stride + (0 if even_stride else 1)) // 2 363 | output_padding = 0 if even_stride else 1 364 | 365 | residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode) 366 | 367 | it = cycle(cycle_dilations) 368 | return nn.Sequential( 369 | CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride), 370 | residual_unit(chan_out, chan_out, next(it)), 371 | residual_unit(chan_out, chan_out, next(it)), 372 | residual_unit(chan_out, chan_out, next(it)), 373 | ) 374 | 375 | class LocalTransformer(nn.Module): 376 | def __init__( 377 | self, 378 | *, 379 | dim, 380 | depth, 381 | heads, 382 | window_size, 383 | 
dynamic_pos_bias = False, 384 | **kwargs 385 | ): 386 | super().__init__() 387 | self.window_size = window_size 388 | self.layers = nn.ModuleList([]) 389 | 390 | self.pos_bias = None 391 | if dynamic_pos_bias: 392 | self.pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads) 393 | 394 | for _ in range(depth): 395 | self.layers.append(nn.ModuleList([ 396 | LocalMHA(dim = dim, heads = heads, qk_rmsnorm = True, window_size = window_size, use_rotary_pos_emb = not dynamic_pos_bias, use_xpos = True, **kwargs), 397 | FeedForward(dim = dim) 398 | ])) 399 | 400 | def forward(self, x): 401 | w = self.window_size 402 | 403 | attn_bias = self.pos_bias(w, w * 2) if exists(self.pos_bias) else None 404 | 405 | for attn, ff in self.layers: 406 | x = attn(x, attn_bias = attn_bias) + x 407 | x = ff(x) + x 408 | 409 | return x 410 | 411 | class FiLM(nn.Module): 412 | def __init__(self, dim, dim_cond): 413 | super().__init__() 414 | self.to_cond = nn.Linear(dim_cond, dim * 2) 415 | 416 | def forward(self, x, cond): 417 | gamma, beta = self.to_cond(cond).chunk(2, dim = -1) 418 | return x * gamma + beta 419 | 420 | class SoundStream(nn.Module): 421 | def __init__( 422 | self, 423 | *, 424 | channels = 32, 425 | strides = (2, 4, 5, 8), 426 | channel_mults = (2, 4, 8, 16), 427 | codebook_dim = 512, 428 | codebook_size = 1024, 429 | rq_num_quantizers = 8, 430 | rq_commitment_weight = 1., 431 | rq_ema_decay = 0.95, 432 | rq_quantize_dropout_multiple_of = 1, 433 | rq_groups = 1, 434 | rq_stochastic_sample_codes = False, 435 | rq_kwargs: dict = {}, 436 | input_channels = 1, 437 | discr_multi_scales = (1, 0.5, 0.25), 438 | stft_normalized = False, 439 | enc_cycle_dilations = (1, 3, 9), 440 | dec_cycle_dilations = (1, 3, 9), 441 | multi_spectral_window_powers_of_two = tuple(range(6, 12)), 442 | multi_spectral_n_ffts = 512, 443 | multi_spectral_n_mels = 64, 444 | recon_loss_weight = 1., 445 | multi_spectral_recon_loss_weight = 1e-5, 446 | adversarial_loss_weight = 1., 447 | feature_loss_weight = 100, 448 | quantize_dropout_cutoff_index = 1, 449 | target_sample_hz = 16000, 450 | use_local_attn = True, 451 | attn_window_size = 128, 452 | attn_dim_head = 64, 453 | attn_heads = 8, 454 | attn_depth = 1, 455 | attn_xpos_scale_base = None, 456 | attn_dynamic_pos_bias = False, 457 | squeeze_excite = False, 458 | complex_stft_discr_logits_abs = True, 459 | pad_mode = 'reflect', 460 | stft_discriminator: Optional[nn.Module] = None # can pass in own stft discriminator 461 | ): 462 | super().__init__() 463 | 464 | # for autosaving the config 465 | 466 | _locals = locals() 467 | _locals.pop('self', None) 468 | _locals.pop('__class__', None) 469 | self._configs = pickle.dumps(_locals) 470 | 471 | # rest of the class 472 | 473 | self.target_sample_hz = target_sample_hz # for resampling on the fly 474 | 475 | self.single_channel = input_channels == 1 476 | self.strides = strides 477 | 478 | layer_channels = tuple(map(lambda t: t * channels, channel_mults)) 479 | layer_channels = (channels, *layer_channels) 480 | chan_in_out_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) 481 | 482 | encoder_blocks = [] 483 | 484 | for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides): 485 | encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite, pad_mode)) 486 | 487 | self.encoder = nn.Sequential( 488 | CausalConv1d(input_channels, channels, 7, pad_mode = pad_mode), 489 | *encoder_blocks, 490 | CausalConv1d(layer_channels[-1], codebook_dim, 3, pad_mode = pad_mode) 
491 | ) 492 | 493 | attn_kwargs = dict( 494 | dim = codebook_dim, 495 | dim_head = attn_dim_head, 496 | heads = attn_heads, 497 | depth = attn_depth, 498 | window_size = attn_window_size, 499 | xpos_scale_base = attn_xpos_scale_base, 500 | dynamic_pos_bias = attn_dynamic_pos_bias, 501 | prenorm = True, 502 | causal = True 503 | ) 504 | 505 | self.encoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None 506 | 507 | self.encoder_film = FiLM(codebook_dim, dim_cond = 2) 508 | 509 | self.num_quantizers = rq_num_quantizers 510 | 511 | self.codebook_dim = codebook_dim 512 | self.codebook_size = codebook_size 513 | 514 | self.rq_groups = rq_groups 515 | 516 | self.rq = GroupedResidualVQ( 517 | dim = codebook_dim, 518 | num_quantizers = rq_num_quantizers, 519 | codebook_size = codebook_size, 520 | groups = rq_groups, 521 | decay = rq_ema_decay, 522 | commitment_weight = rq_commitment_weight, 523 | quantize_dropout_multiple_of = rq_quantize_dropout_multiple_of, 524 | kmeans_init = True, 525 | threshold_ema_dead_code = 2, 526 | quantize_dropout = True, 527 | quantize_dropout_cutoff_index = quantize_dropout_cutoff_index, 528 | stochastic_sample_codes = rq_stochastic_sample_codes, 529 | **rq_kwargs 530 | ) 531 | 532 | self.decoder_film = FiLM(codebook_dim, dim_cond = 2) 533 | 534 | self.decoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None 535 | 536 | decoder_blocks = [] 537 | 538 | for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): 539 | decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite, pad_mode)) 540 | 541 | self.decoder = nn.Sequential( 542 | CausalConv1d(codebook_dim, layer_channels[-1], 7, pad_mode = pad_mode), 543 | *decoder_blocks, 544 | CausalConv1d(channels, input_channels, 7, pad_mode = pad_mode) 545 | ) 546 | 547 | # discriminators 548 | 549 | self.discr_multi_scales = discr_multi_scales 550 | self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) 551 | discr_rel_factors = [int(s1 / s2) for s1, s2 in zip(discr_multi_scales[:-1], discr_multi_scales[1:])] 552 | self.downsamples = nn.ModuleList([nn.Identity()] + [nn.AvgPool1d(2 * factor, stride = factor, padding = factor) for factor in discr_rel_factors]) 553 | 554 | self.stft_discriminator = stft_discriminator 555 | 556 | if not exists(self.stft_discriminator): 557 | self.stft_discriminator = ComplexSTFTDiscriminator( 558 | stft_normalized = stft_normalized, 559 | logits_abs = complex_stft_discr_logits_abs # whether to output as abs() or use view_as_real 560 | ) 561 | 562 | # multi spectral reconstruction 563 | 564 | self.mel_spec_transforms = nn.ModuleList([]) 565 | self.mel_spec_recon_alphas = [] 566 | 567 | num_transforms = len(multi_spectral_window_powers_of_two) 568 | multi_spectral_n_ffts = cast_tuple(multi_spectral_n_ffts, num_transforms) 569 | multi_spectral_n_mels = cast_tuple(multi_spectral_n_mels, num_transforms) 570 | 571 | for powers, n_fft, n_mels in zip_longest(multi_spectral_window_powers_of_two, multi_spectral_n_ffts, multi_spectral_n_mels): 572 | win_length = 2 ** powers 573 | alpha = (win_length / 2) ** 0.5 574 | 575 | calculated_n_fft = default(max(n_fft, win_length), win_length) # @AndreyBocharnikov said this is usually win length, but overridable 576 | 577 | # if any audio experts have an opinion about these settings, please submit a PR 578 | 579 | melspec_transform = T.MelSpectrogram( 580 | sample_rate = target_sample_hz, 581 | n_fft = 
calculated_n_fft, 582 | win_length = win_length, 583 | hop_length = win_length // 4, 584 | n_mels = n_mels, 585 | normalized = stft_normalized 586 | ) 587 | 588 | self.mel_spec_transforms.append(melspec_transform) 589 | self.mel_spec_recon_alphas.append(alpha) 590 | 591 | # loss weights 592 | 593 | self.recon_loss_weight = recon_loss_weight 594 | self.multi_spectral_recon_loss_weight = multi_spectral_recon_loss_weight 595 | self.adversarial_loss_weight = adversarial_loss_weight 596 | self.feature_loss_weight = feature_loss_weight 597 | 598 | self.register_buffer('zero', torch.tensor([0.]), persistent = False) 599 | 600 | @property 601 | def device(self): 602 | return next(self.parameters()).device 603 | 604 | @property 605 | def configs(self): 606 | return pickle.loads(self._configs) 607 | 608 | def decode_from_codebook_indices(self, quantized_indices): 609 | quantized_indices = rearrange(quantized_indices, 'b n (g q) -> g b n q', g = self.rq_groups) 610 | 611 | codes = self.rq.get_codes_from_indices(quantized_indices) 612 | x = reduce(codes, 'g q b n d -> b n (g d)', 'sum') 613 | 614 | return self.decode(x) 615 | 616 | def decode(self, x, quantize = False): 617 | if quantize: 618 | x, *_ = self.rq(x) 619 | 620 | x = self.decoder_attn(x) 621 | x = rearrange(x, 'b n c -> b c n') 622 | return self.decoder(x) 623 | 624 | def save(self, path): 625 | path = Path(path) 626 | pkg = dict( 627 | model = self.state_dict(), 628 | config = self._configs, 629 | version = __version__ 630 | ) 631 | 632 | torch.save(pkg, str(path)) 633 | 634 | @classmethod 635 | def init_and_load_from(cls, path, strict = True): 636 | path = Path(path) 637 | assert path.exists() 638 | pkg = torch.load(str(path), map_location = 'cpu') 639 | 640 | assert 'config' in pkg, 'model configs were not found in this saved checkpoint' 641 | 642 | config = pickle.loads(pkg['config']) 643 | soundstream = cls(**config) 644 | soundstream.load(path, strict = strict) 645 | return soundstream 646 | 647 | def load(self, path, strict = True): 648 | path = Path(path) 649 | assert path.exists() 650 | pkg = torch.load(str(path), map_location = 'cpu') 651 | 652 | # check version 653 | 654 | if 'version' in pkg and version.parse(pkg['version']) < parsed_version: 655 | print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})') 656 | 657 | has_ema = 'ema_model' in pkg 658 | model_pkg = pkg['ema_model'] if has_ema else pkg['model'] 659 | 660 | if has_ema: 661 | model_pkg = filter_by_keys(lambda k: k.startswith('ema_model.'), model_pkg) 662 | model_pkg = map_keys(lambda k: k[len('ema_model.'):], model_pkg) 663 | 664 | self.load_state_dict(model_pkg, strict = strict) 665 | 666 | def load_from_trainer_saved_obj(self, path): 667 | path = Path(path) 668 | assert path.exists() 669 | obj = torch.load(str(path)) 670 | self.load_state_dict(obj['model']) 671 | 672 | def non_discr_parameters(self): 673 | return [ 674 | *self.encoder.parameters(), 675 | *self.decoder.parameters(), 676 | *(self.encoder_attn.parameters() if exists(self.encoder_attn) else []), 677 | *(self.decoder_attn.parameters() if exists(self.decoder_attn) else []), 678 | *self.encoder_film.parameters(), 679 | *self.decoder_film.parameters() 680 | ] 681 | 682 | @property 683 | def seq_len_multiple_of(self): 684 | return functools.reduce(lambda x, y: x * y, self.strides) 685 | 686 | def process_input( 687 | self, 688 | x, 689 | input_sample_hz = None, 690 | curtail_from_left = False 691 | ): 692 | x, ps = pack([x], '* n') 693 | 694 | if 
exists(input_sample_hz): 695 | x = resample(x, input_sample_hz, self.target_sample_hz) 696 | 697 | x = curtail_to_multiple(x, self.seq_len_multiple_of, from_left = curtail_from_left) 698 | 699 | if x.ndim == 2: 700 | x = rearrange(x, 'b n -> b 1 n') 701 | 702 | return x, ps 703 | 704 | def forward( 705 | self, 706 | x, 707 | target = None, 708 | is_denoising = None, # if you want to learn film conditioners that teach the soundstream to denoise - target would need to be passed in above 709 | return_encoded = False, 710 | return_discr_loss = False, 711 | return_discr_losses_separately = False, 712 | return_loss_breakdown = False, 713 | return_recons_only = False, 714 | input_sample_hz = None, 715 | apply_grad_penalty = False, 716 | curtail_from_left = False 717 | ): 718 | assert not (exists(is_denoising) and not exists(target)) 719 | 720 | process_input = partial(self.process_input, input_sample_hz = input_sample_hz, curtail_from_left = curtail_from_left) 721 | 722 | x, ps = process_input(x) 723 | 724 | if exists(target): 725 | target, _ = process_input(target) 726 | 727 | orig_x = x.clone() 728 | 729 | x = self.encoder(x) 730 | 731 | x = rearrange(x, 'b c n -> b n c') 732 | 733 | if exists(self.encoder_attn): 734 | x = self.encoder_attn(x) 735 | 736 | if exists(is_denoising): 737 | denoise_input = torch.tensor([is_denoising, not is_denoising], dtype = x.dtype, device = self.device) # [1, 0] for denoise, [0, 1] for not denoising 738 | x = self.encoder_film(x, denoise_input) 739 | 740 | x, indices, commit_loss = self.rq(x) 741 | 742 | if return_encoded: 743 | indices = rearrange(indices, 'g b n q -> b n (g q)') 744 | return x, indices, commit_loss 745 | 746 | if exists(is_denoising): 747 | x = self.decoder_film(x, denoise_input) 748 | 749 | if exists(self.decoder_attn): 750 | x = self.decoder_attn(x) 751 | 752 | x = rearrange(x, 'b n c -> b c n') 753 | 754 | recon_x = self.decoder(x) 755 | 756 | if return_recons_only: 757 | recon_x, = unpack(recon_x, ps, '* c n') 758 | return recon_x 759 | 760 | # multi-scale discriminator loss 761 | 762 | if return_discr_loss: 763 | real, fake = orig_x, recon_x.detach() 764 | 765 | stft_discr_loss = None 766 | stft_grad_penalty = None 767 | discr_losses = [] 768 | discr_grad_penalties = [] 769 | 770 | if self.single_channel: 771 | real, fake = orig_x.clone(), recon_x.detach() 772 | stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real.requires_grad_(), fake)) 773 | stft_discr_loss = hinge_discr_loss(stft_fake_logits, stft_real_logits) 774 | 775 | if apply_grad_penalty: 776 | stft_grad_penalty = gradient_penalty(real, stft_discr_loss) 777 | 778 | scaled_real, scaled_fake = real, fake 779 | for discr, downsample in zip(self.discriminators, self.downsamples): 780 | scaled_real, scaled_fake = map(downsample, (scaled_real, scaled_fake)) 781 | 782 | real_logits, fake_logits = map(discr, (scaled_real.requires_grad_(), scaled_fake)) 783 | one_discr_loss = hinge_discr_loss(fake_logits, real_logits) 784 | 785 | discr_losses.append(one_discr_loss) 786 | if apply_grad_penalty: 787 | discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss)) 788 | 789 | if not return_discr_losses_separately: 790 | all_discr_losses = torch.stack(discr_losses).mean() 791 | 792 | if exists(stft_discr_loss): 793 | all_discr_losses = all_discr_losses + stft_discr_loss 794 | 795 | if exists(stft_grad_penalty): 796 | all_discr_losses = all_discr_losses + stft_grad_penalty 797 | 798 | return all_discr_losses 799 | 800 | # return a list of discriminator losses 
with List[Tuple[str, Tensor]] 801 | 802 | discr_losses_pkg = [] 803 | 804 | discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)]) 805 | 806 | discr_losses_pkg.extend([(f'scale_grad_penalty:{scale}', discr_grad_penalty) for scale, discr_grad_penalty in zip(self.discr_multi_scales, discr_grad_penalties)]) 807 | 808 | if exists(stft_discr_loss): 809 | discr_losses_pkg.append(('stft', stft_discr_loss)) 810 | 811 | if exists(stft_grad_penalty): 812 | discr_losses_pkg.append(('stft_grad_penalty', stft_grad_penalty)) 813 | 814 | return discr_losses_pkg 815 | 816 | # recon loss 817 | 818 | target = default(target, orig_x) # target can also be passed in, in the case of denoising 819 | 820 | recon_loss = F.mse_loss(target, recon_x) 821 | 822 | # multispectral recon loss - eq (4) and (5) in https://arxiv.org/abs/2107.03312 823 | 824 | multi_spectral_recon_loss = self.zero 825 | 826 | if self.multi_spectral_recon_loss_weight > 0: 827 | for mel_transform, alpha in zip(self.mel_spec_transforms, self.mel_spec_recon_alphas): 828 | orig_mel, recon_mel = map(mel_transform, (orig_x, recon_x)) 829 | log_orig_mel, log_recon_mel = map(log, (orig_mel, recon_mel)) 830 | 831 | l1_mel_loss = (orig_mel - recon_mel).abs().sum(dim = -2).mean() 832 | l2_log_mel_loss = alpha * vector_norm(log_orig_mel - log_recon_mel, dim = -2).mean() 833 | 834 | multi_spectral_recon_loss = multi_spectral_recon_loss + l1_mel_loss + l2_log_mel_loss 835 | 836 | # adversarial loss 837 | 838 | adversarial_losses = [] 839 | 840 | discr_intermediates = [] 841 | 842 | # adversarial loss for multi-scale discriminators 843 | 844 | real, fake = orig_x, recon_x 845 | 846 | # features from stft 847 | 848 | (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake)) 849 | discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates)) 850 | 851 | scaled_real, scaled_fake = real, fake 852 | for discr, downsample in zip(self.discriminators, self.downsamples): 853 | scaled_real, scaled_fake = map(downsample, (scaled_real, scaled_fake)) 854 | 855 | (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake)) 856 | 857 | discr_intermediates.append((real_intermediates, fake_intermediates)) 858 | 859 | one_adversarial_loss = hinge_gen_loss(fake_logits) 860 | adversarial_losses.append(one_adversarial_loss) 861 | 862 | feature_losses = [] 863 | 864 | for real_intermediates, fake_intermediates in discr_intermediates: 865 | losses = [F.l1_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)] 866 | feature_losses.extend(losses) 867 | 868 | feature_loss = torch.stack(feature_losses).mean() 869 | 870 | # adversarial loss for stft discriminator 871 | 872 | adversarial_losses.append(hinge_gen_loss(stft_fake_logits)) 873 | adversarial_loss = torch.stack(adversarial_losses).mean() 874 | 875 | # sum commitment loss 876 | 877 | all_commitment_loss = commit_loss.sum() 878 | 879 | total_loss = recon_loss * self.recon_loss_weight + multi_spectral_recon_loss * self.multi_spectral_recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss 880 | 881 | if return_loss_breakdown: 882 | return total_loss, (recon_loss, 
multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) 883 | 884 | return total_loss 885 | 886 | # some default soundstreams 887 | 888 | def AudioLMSoundStream( 889 | strides = (2, 4, 5, 8), 890 | target_sample_hz = 16000, 891 | rq_num_quantizers = 12, 892 | **kwargs 893 | ): 894 | return SoundStream( 895 | strides = strides, 896 | target_sample_hz = target_sample_hz, 897 | rq_num_quantizers = rq_num_quantizers, 898 | **kwargs 899 | ) 900 | 901 | def MusicLMSoundStream( 902 | strides = (3, 4, 5, 8), 903 | target_sample_hz = 24000, 904 | rq_num_quantizers = 12, 905 | **kwargs 906 | ): 907 | return SoundStream( 908 | strides = strides, 909 | target_sample_hz = target_sample_hz, 910 | rq_num_quantizers = rq_num_quantizers, 911 | **kwargs 912 | ) 913 | -------------------------------------------------------------------------------- /audiolm_pytorch/trainer.py: -------------------------------------------------------------------------------- 1 | import re 2 | from math import sqrt 3 | import copy 4 | from random import choice 5 | from pathlib import Path 6 | from shutil import rmtree 7 | 8 | from beartype.typing import Union, List, Optional, Tuple 9 | from typing_extensions import Annotated 10 | 11 | from beartype import beartype 12 | from beartype.door import is_bearable 13 | from beartype.vale import Is 14 | 15 | import torch 16 | import torchaudio 17 | from torch import nn 18 | from torch.utils.data import Dataset, DataLoader, random_split 19 | 20 | from einops import rearrange 21 | 22 | from audiolm_pytorch.optimizer import get_optimizer 23 | 24 | from ema_pytorch import EMA 25 | 26 | from audiolm_pytorch.soundstream import SoundStream 27 | from audiolm_pytorch.encodec import EncodecWrapper 28 | 29 | from audiolm_pytorch.audiolm_pytorch import ( 30 | SemanticTransformer, 31 | SemanticTransformerWrapper, 32 | CoarseTransformer, 33 | CoarseTransformerWrapper, 34 | FineTransformer, 35 | FineTransformerWrapper, 36 | FairseqVQWav2Vec, 37 | HubertWithKmeans 38 | ) 39 | 40 | from audiolm_pytorch.data import SoundDataset, get_dataloader 41 | from audiolm_pytorch.utils import AudioConditionerBase 42 | 43 | from audiolm_pytorch.version import __version__ 44 | from packaging import version 45 | 46 | from accelerate import Accelerator 47 | from accelerate.utils import DistributedDataParallelKwargs 48 | 49 | # constants 50 | 51 | DEFAULT_SAMPLE_RATE = 16000 52 | 53 | # for automatically routing data emitted from a dataset to keywords of the transformer wrappers 54 | 55 | DATASET_FIELD_TYPE_CONFIG = dict( 56 | raw_wave = Annotated[ 57 | torch.Tensor, 58 | Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}] 59 | ], 60 | text = List[str], 61 | text_embeds = Annotated[ 62 | torch.Tensor, 63 | Is[lambda t: t.dtype == torch.float and t.ndim == 3] 64 | ], 65 | ) 66 | 67 | # helpers 68 | 69 | def exists(val): 70 | return val is not None 71 | 72 | def noop(*args, **kwargs): 73 | pass 74 | 75 | def cycle(dl): 76 | while True: 77 | for data in dl: 78 | yield data 79 | 80 | def cast_tuple(t): 81 | return t if isinstance(t, (tuple, list)) else (t,) 82 | 83 | def yes_or_no(question): 84 | answer = input(f'{question} (y/n) ') 85 | return answer.lower() in ('yes', 'y') 86 | 87 | def accum_log(log, new_logs): 88 | for key, new_value in new_logs.items(): 89 | old_value = log.get(key, 0.) 
90 | log[key] = old_value + new_value 91 | return log 92 | 93 | # auto data to module keyword argument routing functions 94 | 95 | def has_duplicates(tup): 96 | counts = dict() 97 | for el in tup: 98 | if el not in counts: 99 | counts[el] = 0 100 | counts[el] += 1 101 | return any(filter(lambda count: count > 1, counts.values())) 102 | 103 | def determine_types(data, config): 104 | output = [] 105 | for el in data: 106 | for name, data_type in config.items(): 107 | if is_bearable(el, data_type): 108 | output.append(name) 109 | break 110 | else: 111 | raise TypeError(f'unable to determine type of {data}') 112 | 113 | return tuple(output) 114 | 115 | def checkpoint_num_steps(checkpoint_path): 116 | """Returns the number of steps trained from a checkpoint based on the filename. 117 | 118 | Filename format assumed to be something like "/path/to/semantic.transformer.20000.pt" which is 119 | for 20k train steps. Returns 20000 in that case. 120 | """ 121 | return int(re.findall(r'\d+', str(checkpoint_path))[-1]) 122 | 123 | # main trainer class 124 | 125 | class SoundStreamTrainer(nn.Module): 126 | @beartype 127 | def __init__( 128 | self, 129 | soundstream: SoundStream, 130 | *, 131 | num_train_steps: int, 132 | batch_size: int, 133 | data_max_length: int = None, 134 | data_max_length_seconds: Union[int, float] = None, 135 | folder: str = None, 136 | train_dataloader: DataLoader = None, 137 | val_dataloader: DataLoader = None, 138 | lr: float = 2e-4, 139 | grad_accum_every: int = 4, 140 | wd: float = 0., 141 | max_grad_norm: float = 0.5, 142 | discr_max_grad_norm: float = None, 143 | save_results_every: int = 100, 144 | save_model_every: int= 1000, 145 | log_losses_every: int= 1, 146 | results_folder: str = './results', 147 | valid_frac: float = 0.05, 148 | random_split_seed: int = 42, 149 | use_ema: bool = True, 150 | ema_beta: float = 0.995, 151 | ema_update_after_step: int = 500, 152 | ema_update_every: int = 10, 153 | apply_grad_penalty_every: int = 4, 154 | dl_num_workers: int = 0, 155 | accelerator: Accelerator = None, 156 | accelerate_kwargs: dict = dict(), 157 | use_lion: bool = False, 158 | force_clear_prev_results: bool = None # set to True | False to skip the prompt 159 | ): 160 | """ 161 | Initialize with a SoundStream instance and either a folder containing audio data or 162 | train/val DataLoader instances. 
163 | """ 164 | super().__init__() 165 | 166 | if accelerator: 167 | self.accelerator = accelerator 168 | assert len(accelerate_kwargs) == 0 169 | else: 170 | kwargs = DistributedDataParallelKwargs(find_unused_parameters = True) 171 | self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs) 172 | 173 | self.soundstream = soundstream 174 | 175 | self.use_ema = use_ema 176 | if self.use_ema: 177 | self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every) 178 | 179 | self.register_buffer('steps', torch.Tensor([0])) 180 | 181 | self.num_train_steps = num_train_steps 182 | self.batch_size = batch_size 183 | self.grad_accum_every = grad_accum_every 184 | 185 | hyperparameters = { 186 | "num_train_steps": num_train_steps, 187 | "batch_size": batch_size, 188 | "gradient_accum_every": grad_accum_every, 189 | "learning_rate": lr, 190 | "target_sample_hz": soundstream.target_sample_hz, 191 | } 192 | 193 | # optimizers 194 | 195 | self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd) 196 | 197 | for discr_optimizer_key, discr in self.multiscale_discriminator_iter(): 198 | one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd) 199 | setattr(self, discr_optimizer_key, one_multiscale_discr_optimizer) 200 | 201 | self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd, use_lion = use_lion) 202 | 203 | # max grad norm 204 | 205 | self.max_grad_norm = max_grad_norm 206 | self.discr_max_grad_norm = discr_max_grad_norm 207 | 208 | if folder is None: 209 | assert train_dataloader is not None 210 | assert val_dataloader is not None 211 | self.dl = train_dataloader 212 | self.valid_dl = val_dataloader 213 | else: 214 | assert train_dataloader is None 215 | assert val_dataloader is None 216 | 217 | # create dataset 218 | 219 | if exists(data_max_length_seconds): 220 | assert not exists(data_max_length) 221 | data_max_length = int(data_max_length_seconds * soundstream.target_sample_hz) 222 | else: 223 | assert exists(data_max_length) 224 | 225 | hyperparameters['data_max_length'] = data_max_length 226 | 227 | self.ds = SoundDataset( 228 | folder, 229 | max_length = data_max_length, 230 | target_sample_hz = soundstream.target_sample_hz, 231 | seq_len_multiple_of = soundstream.seq_len_multiple_of 232 | ) 233 | 234 | # split for validation 235 | 236 | if valid_frac > 0: 237 | train_size = int((1 - valid_frac) * len(self.ds)) 238 | valid_size = len(self.ds) - train_size 239 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) 240 | self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') 241 | else: 242 | self.valid_ds = self.ds 243 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') 244 | 245 | # dataloader 246 | 247 | self.dl = get_dataloader(self.ds, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True) 248 | 249 | self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True) 250 | 251 | # prepare with accelerator 252 | 253 | ( 254 | self.soundstream, 255 | self.optim, 256 | self.discr_optim, 257 | self.dl 258 | ) = self.accelerator.prepare( 259 | self.soundstream, 260 | self.optim, 261 | self.discr_optim, 262 | self.dl 263 | ) 264 | 265 | # prepare the 
multiscale discriminators with accelerator 266 | 267 | for name, _ in self.multiscale_discriminator_iter(): 268 | optimizer = getattr(self, name) 269 | optimizer = self.accelerator.prepare(optimizer) 270 | setattr(self, name, optimizer) 271 | 272 | # dataloader iterators 273 | 274 | self.dl_iter = cycle(self.dl) 275 | self.valid_dl_iter = cycle(self.valid_dl) 276 | 277 | self.save_model_every = save_model_every 278 | self.save_results_every = save_results_every 279 | self.log_losses_every = log_losses_every 280 | 281 | self.apply_grad_penalty_every = apply_grad_penalty_every 282 | 283 | self.results_folder = Path(results_folder) 284 | 285 | if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')): 286 | rmtree(str(self.results_folder)) 287 | 288 | self.results_folder.mkdir(parents = True, exist_ok = True) 289 | 290 | # Initialize experiment trackers if an external Accelerator is not passed in 291 | if not accelerator: 292 | self.accelerator.init_trackers("soundstream", config=hyperparameters) 293 | 294 | def set_model_as_ema_model_(self): 295 | """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """ 296 | assert self.use_ema 297 | self.ema_soundstream.ema_model.load_state_dict(self.soundstream.state_dict()) 298 | 299 | def save(self, path): 300 | pkg = dict( 301 | model = self.accelerator.get_state_dict(self.soundstream), 302 | optim = self.optim.state_dict(), 303 | config = self.unwrapped_soundstream._configs, 304 | discr_optim = self.discr_optim.state_dict(), 305 | version = __version__ 306 | ) 307 | 308 | if self.use_ema: 309 | pkg['ema_model'] = self.ema_soundstream.state_dict() 310 | 311 | for key, _ in self.multiscale_discriminator_iter(): 312 | discr_optim = getattr(self, key) 313 | pkg[key] = discr_optim.state_dict() 314 | 315 | torch.save(pkg, path) 316 | 317 | @property 318 | def unwrapped_soundstream(self): 319 | return self.accelerator.unwrap_model(self.soundstream) 320 | 321 | def load(self, path): 322 | path = Path(path) 323 | assert path.exists() 324 | pkg = torch.load(str(path), map_location = 'cpu') 325 | 326 | # if loading from old version, make a hacky guess 327 | 328 | if len(pkg.keys()) > 20: 329 | self.unwrapped_soundstream.load_state_dict(pkg) 330 | 331 | if self.use_ema: 332 | self.ema_soundstream.ema_model.load_state_dict(pkg) 333 | return 334 | 335 | # check version 336 | 337 | if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__): 338 | print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch') 339 | 340 | # otherwise load things normally 341 | 342 | self.unwrapped_soundstream.load_state_dict(pkg['model']) 343 | 344 | if self.use_ema: 345 | assert 'ema_model' in pkg 346 | self.ema_soundstream.load_state_dict(pkg['ema_model']) 347 | 348 | self.optim.load_state_dict(pkg['optim']) 349 | self.discr_optim.load_state_dict(pkg['discr_optim']) 350 | 351 | for key, _ in self.multiscale_discriminator_iter(): 352 | discr_optim = getattr(self, key) 353 | discr_optim.load_state_dict(pkg[key]) 354 | # + 1 to start from the next step and avoid overwriting the last checkpoint 355 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device) 356 | 357 | def multiscale_discriminator_iter(self): 358 | for ind, discr in enumerate(self.unwrapped_soundstream.discriminators): 359 | yield 
f'multiscale_discr_optimizer_{ind}', discr 360 | 361 | def multiscale_discriminator_optim_iter(self): 362 | for name, _ in self.multiscale_discriminator_iter(): 363 | yield name, getattr(self, name) 364 | 365 | def print(self, msg): 366 | self.accelerator.print(msg) 367 | 368 | @property 369 | def device(self): 370 | return self.accelerator.device 371 | 372 | @property 373 | def is_distributed(self): 374 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 375 | 376 | @property 377 | def is_main(self): 378 | return self.accelerator.is_main_process 379 | 380 | @property 381 | def is_local_main(self): 382 | return self.accelerator.is_local_main_process 383 | 384 | def train_step(self): 385 | device = self.device 386 | 387 | steps = int(self.steps.item()) 388 | apply_grad_penalty = self.apply_grad_penalty_every > 0 and not (steps % self.apply_grad_penalty_every) 389 | log_losses = self.log_losses_every > 0 and not (steps % self.log_losses_every) 390 | 391 | self.soundstream.train() 392 | 393 | # logs 394 | 395 | logs = {} 396 | 397 | # update vae (generator) 398 | 399 | for _ in range(self.grad_accum_every): 400 | wave, = next(self.dl_iter) 401 | wave = wave.to(device) 402 | 403 | loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True) 404 | 405 | self.accelerator.backward(loss / self.grad_accum_every) 406 | 407 | accum_log(logs, dict( 408 | loss = loss.item() / self.grad_accum_every, 409 | recon_loss = recon_loss.item() / self.grad_accum_every, 410 | )) 411 | 412 | if log_losses: 413 | accum_log(logs, dict( 414 | multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every, 415 | adversarial_loss = adversarial_loss.item() / self.grad_accum_every, 416 | feature_loss = feature_loss.item() / self.grad_accum_every, 417 | all_commitment_loss = all_commitment_loss.item() / self.grad_accum_every, 418 | )) 419 | 420 | if exists(self.max_grad_norm): 421 | self.accelerator.clip_grad_norm_(self.soundstream.parameters(), self.max_grad_norm) 422 | 423 | self.optim.step() 424 | self.optim.zero_grad() 425 | 426 | # update discriminator 427 | 428 | self.discr_optim.zero_grad() 429 | 430 | for name, multiscale_discr_optim in self.multiscale_discriminator_optim_iter(): 431 | multiscale_discr_optim.zero_grad() 432 | 433 | for _ in range(self.grad_accum_every): 434 | wave, = next(self.dl_iter) 435 | wave = wave.to(device) 436 | 437 | discr_losses = self.soundstream( 438 | wave, 439 | apply_grad_penalty = apply_grad_penalty, 440 | return_discr_loss = True, 441 | return_discr_losses_separately = True 442 | ) 443 | 444 | for name, discr_loss in discr_losses: 445 | self.accelerator.backward(discr_loss / self.grad_accum_every, retain_graph = True) 446 | accum_log(logs, {name: discr_loss.item() / self.grad_accum_every}) 447 | 448 | if exists(self.discr_max_grad_norm): 449 | self.accelerator.clip_grad_norm_(self.soundstream.stft_discriminator.parameters(), self.discr_max_grad_norm) 450 | 451 | # gradient step for all discriminators 452 | 453 | self.discr_optim.step() 454 | 455 | for name, multiscale_discr_optim in self.multiscale_discriminator_optim_iter(): 456 | multiscale_discr_optim.step() 457 | 458 | # build pretty printed losses 459 | 460 | losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}" 461 | if log_losses: 462 | self.accelerator.log({ 463 | "total_loss": 
logs['loss'], 464 | "recon_loss": logs['recon_loss'], 465 | "multi_spectral_recon_loss": logs['multi_spectral_recon_loss'], 466 | "adversarial_loss": logs['adversarial_loss'], 467 | "feature_loss": logs['feature_loss'], 468 | "all_commitment_loss": logs['all_commitment_loss'], 469 | "stft_discr_loss": logs['stft'] 470 | }, step=steps) 471 | 472 | for key, loss in logs.items(): 473 | if not key.startswith('scale:'): 474 | continue 475 | _, scale_factor = key.split(':') 476 | 477 | losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}" 478 | if log_losses: 479 | self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps) 480 | 481 | # log 482 | 483 | self.print(losses_str) 484 | 485 | # update exponential moving averaged generator 486 | 487 | self.accelerator.wait_for_everyone() 488 | 489 | if self.is_main and self.use_ema: 490 | self.ema_soundstream.update() 491 | 492 | # sample results every so often 493 | 494 | self.accelerator.wait_for_everyone() 495 | 496 | if self.is_main and not (steps % self.save_results_every): 497 | models = [(self.unwrapped_soundstream, str(steps))] 498 | if self.use_ema: 499 | models.append((self.ema_soundstream.ema_model if self.use_ema else self.unwrapped_soundstream, f'{steps}.ema')) 500 | 501 | wave, = next(self.valid_dl_iter) 502 | wave = wave.to(device) 503 | 504 | for model, label in models: 505 | model.eval() 506 | 507 | with torch.no_grad(): 508 | recons = model(wave, return_recons_only = True) 509 | 510 | for ind, recon in enumerate(recons.unbind(dim = 0)): 511 | filename = str(self.results_folder / f'sample_{label}.flac') 512 | torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz) 513 | 514 | self.print(f'{steps}: saving to {str(self.results_folder)}') 515 | 516 | # save model every so often 517 | 518 | self.accelerator.wait_for_everyone() 519 | 520 | if self.is_main and not (steps % self.save_model_every): 521 | model_path = str(self.results_folder / f'soundstream.{steps}.pt') 522 | self.save(model_path) 523 | 524 | self.print(f'{steps}: saving model to {str(self.results_folder)}') 525 | 526 | self.steps += 1 527 | return logs 528 | 529 | def train(self, log_fn = noop): 530 | 531 | while self.steps < self.num_train_steps: 532 | logs = self.train_step() 533 | log_fn(logs) 534 | 535 | self.print('training complete') 536 | 537 | # semantic transformer trainer 538 | 539 | class SemanticTransformerTrainer(nn.Module): 540 | @beartype 541 | def __init__( 542 | self, 543 | wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], 544 | transformer: SemanticTransformer, 545 | *, 546 | num_train_steps, 547 | batch_size, 548 | audio_conditioner: Optional[AudioConditionerBase] = None, 549 | dataset: Optional[Dataset] = None, 550 | data_max_length = None, 551 | data_max_length_seconds = None, 552 | folder = None, 553 | lr = 3e-4, 554 | grad_accum_every = 1, 555 | wd = 0., 556 | max_grad_norm = 0.5, 557 | valid_frac = 0.05, 558 | random_split_seed = 42, 559 | save_results_every = 100, 560 | save_model_every = 1000, 561 | results_folder = './results', 562 | accelerate_kwargs: dict = dict(), 563 | force_clear_prev_results = None 564 | ): 565 | super().__init__() 566 | self.accelerator = Accelerator(**accelerate_kwargs) 567 | 568 | self.wav2vec = wav2vec 569 | self.transformer = transformer 570 | self.audio_conditioner = audio_conditioner 571 | 572 | self.train_wrapper = SemanticTransformerWrapper( 573 | wav2vec = wav2vec, 574 | transformer = transformer, 575 | audio_conditioner = 
audio_conditioner 576 | ) 577 | 578 | self.register_buffer('steps', torch.Tensor([0])) 579 | 580 | self.num_train_steps = num_train_steps 581 | self.batch_size = batch_size 582 | self.grad_accum_every = grad_accum_every 583 | 584 | # optimizers 585 | 586 | self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd) 587 | 588 | # max grad norm 589 | 590 | self.max_grad_norm = max_grad_norm 591 | 592 | # create dataset 593 | 594 | self.ds = dataset 595 | if not exists(self.ds): 596 | assert exists(folder), 'folder must be passed in, if not passing in a custom dataset for text conditioned audio synthesis training' 597 | 598 | assert not (exists(data_max_length) and exists(data_max_length_seconds)) 599 | 600 | if exists(data_max_length_seconds): 601 | data_max_length = data_max_length_seconds * wav2vec.target_sample_hz 602 | 603 | self.ds = SoundDataset( 604 | folder, 605 | max_length = data_max_length, 606 | target_sample_hz = wav2vec.target_sample_hz, 607 | seq_len_multiple_of = wav2vec.seq_len_multiple_of 608 | ) 609 | 610 | self.ds_fields = None 611 | 612 | # split for validation 613 | 614 | if valid_frac > 0: 615 | train_size = int((1 - valid_frac) * len(self.ds)) 616 | valid_size = len(self.ds) - train_size 617 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) 618 | self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') 619 | else: 620 | self.valid_ds = self.ds 621 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') 622 | 623 | # dataloader 624 | 625 | self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True) 626 | 627 | self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True) 628 | 629 | # prepare with accelerator 630 | 631 | ( 632 | self.train_wrapper, 633 | self.optim, 634 | self.dl, 635 | self.valid_dl 636 | ) = self.accelerator.prepare( 637 | self.train_wrapper, 638 | self.optim, 639 | self.dl, 640 | self.valid_dl 641 | ) 642 | 643 | # dataloader iterators 644 | 645 | self.dl_iter = cycle(self.dl) 646 | self.valid_dl_iter = cycle(self.valid_dl) 647 | 648 | self.save_model_every = save_model_every 649 | self.save_results_every = save_results_every 650 | 651 | self.results_folder = Path(results_folder) 652 | 653 | if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')): 654 | rmtree(str(self.results_folder)) 655 | 656 | self.results_folder.mkdir(parents = True, exist_ok = True) 657 | 658 | hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr} 659 | self.accelerator.init_trackers("semantic", config=hps) 660 | 661 | def save(self, path): 662 | pkg = dict( 663 | model = self.accelerator.get_state_dict(self.transformer), 664 | optim = self.optim.state_dict(), 665 | version = __version__ 666 | ) 667 | torch.save(pkg, path) 668 | 669 | def load(self, path): 670 | path = Path(path) 671 | assert path.exists() 672 | pkg = torch.load(str(path), map_location = 'cpu') 673 | 674 | # check version 675 | 676 | if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__): 677 | print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch') 678 | 679 | transformer = 
self.accelerator.unwrap_model(self.transformer) 680 | transformer.load_state_dict(pkg['model']) 681 | self.optim.load_state_dict(pkg['optim']) 682 | # + 1 to start from the next step and avoid overwriting the last checkpoint 683 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device) 684 | 685 | 686 | def print(self, msg): 687 | self.accelerator.print(msg) 688 | 689 | def generate(self, *args, **kwargs): 690 | return self.train_wrapper.generate(*args, **kwargs) 691 | 692 | @property 693 | def device(self): 694 | return self.accelerator.device 695 | 696 | @property 697 | def is_distributed(self): 698 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 699 | 700 | @property 701 | def is_main(self): 702 | return self.accelerator.is_main_process 703 | 704 | @property 705 | def is_local_main(self): 706 | return self.accelerator.is_local_main_process 707 | 708 | def data_tuple_to_kwargs(self, data): 709 | if not exists(self.ds_fields): 710 | self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG) 711 | assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names' 712 | 713 | return dict(zip(self.ds_fields, data)) 714 | 715 | def train_step(self): 716 | device = self.device 717 | 718 | steps = int(self.steps.item()) 719 | 720 | self.transformer.train() 721 | 722 | # logs 723 | 724 | logs = {} 725 | 726 | # update transformer 727 | 728 | for _ in range(self.grad_accum_every): 729 | data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter)) 730 | 731 | loss = self.train_wrapper(**data_kwargs, return_loss = True) 732 | 733 | self.accelerator.backward(loss / self.grad_accum_every) 734 | 735 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) 736 | 737 | if exists(self.max_grad_norm): 738 | self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm) 739 | 740 | self.optim.step() 741 | self.optim.zero_grad() 742 | 743 | # log 744 | 745 | self.print(f"{steps}: loss: {logs['loss']}") 746 | self.accelerator.log({"train_loss": logs['loss']}, step=steps) 747 | 748 | # sample results every so often 749 | 750 | if self.is_main and not (steps % self.save_results_every): 751 | data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter)) 752 | 753 | with torch.no_grad(): 754 | self.train_wrapper.eval() 755 | valid_loss = self.train_wrapper(**data_kwargs, return_loss = True) 756 | 757 | self.print(f'{steps}: valid loss {valid_loss}') 758 | self.accelerator.log({"valid_loss": valid_loss}, step=steps) 759 | 760 | # save model every so often 761 | 762 | if self.is_main and not (steps % self.save_model_every): 763 | model_path = str(self.results_folder / f'semantic.transformer.{steps}.pt') 764 | self.save(model_path) 765 | 766 | self.print(f'{steps}: saving model to {str(self.results_folder)}') 767 | 768 | self.steps += 1 769 | return logs 770 | 771 | def train(self, log_fn = noop): 772 | 773 | while self.steps < self.num_train_steps: 774 | logs = self.train_step() 775 | log_fn(logs) 776 | 777 | self.print('training complete') 778 | 779 | # coarse transformer trainer 780 | 781 | class CoarseTransformerTrainer(nn.Module): 782 | @beartype 783 | def __init__( 784 | self, 785 | transformer: CoarseTransformer, 786 | codec: Union[SoundStream, EncodecWrapper], 787 | wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], 788 | *, 789 | num_train_steps, 790 | batch_size, 791 | audio_conditioner: Optional[AudioConditionerBase] = None, 792 | dataset:
Optional[Dataset] = None, 793 | ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'), 794 | data_max_length = None, 795 | data_max_length_seconds = None, 796 | folder = None, 797 | lr = 3e-4, 798 | grad_accum_every = 1, 799 | wd = 0., 800 | max_grad_norm = 0.5, 801 | valid_frac = 0.05, 802 | random_split_seed = 42, 803 | save_results_every = 100, 804 | save_model_every = 1000, 805 | results_folder = './results', 806 | accelerate_kwargs: dict = dict(), 807 | force_clear_prev_results = None 808 | ): 809 | super().__init__() 810 | self.accelerator = Accelerator(**accelerate_kwargs) 811 | 812 | self.transformer = transformer 813 | self.codec = codec 814 | self.wav2vec = wav2vec 815 | self.audio_conditioner = audio_conditioner 816 | 817 | self.train_wrapper = CoarseTransformerWrapper( 818 | codec = codec, 819 | wav2vec = wav2vec, 820 | transformer = transformer, 821 | audio_conditioner = audio_conditioner 822 | ) 823 | 824 | self.register_buffer('steps', torch.Tensor([0])) 825 | 826 | self.num_train_steps = num_train_steps 827 | self.batch_size = batch_size 828 | self.grad_accum_every = grad_accum_every 829 | 830 | # optimizers 831 | 832 | self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd) 833 | 834 | # max grad norm 835 | 836 | self.max_grad_norm = max_grad_norm 837 | 838 | # create dataset 839 | 840 | self.ds = dataset 841 | 842 | if not exists(self.ds): 843 | assert exists(folder), 'folder must be passed in, if not passing in a custom dataset for text conditioned audio synthesis training' 844 | 845 | assert not (exists(data_max_length) and exists(data_max_length_seconds)) 846 | 847 | if exists(data_max_length_seconds): 848 | data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, codec.target_sample_hz)) 849 | 850 | self.ds = SoundDataset( 851 | folder, 852 | max_length = data_max_length, 853 | target_sample_hz = ( 854 | wav2vec.target_sample_hz, 855 | codec.target_sample_hz 856 | ), # need 2 waves resampled differently here 857 | seq_len_multiple_of = codec.seq_len_multiple_of 858 | ) 859 | 860 | self.ds_fields = ds_fields 861 | 862 | # split for validation 863 | 864 | if valid_frac > 0: 865 | train_size = int((1 - valid_frac) * len(self.ds)) 866 | valid_size = len(self.ds) - train_size 867 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) 868 | self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') 869 | else: 870 | self.valid_ds = self.ds 871 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') 872 | 873 | # dataloader 874 | 875 | self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True) 876 | 877 | self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True) 878 | 879 | # prepare with accelerator 880 | 881 | ( 882 | self.transformer, 883 | self.optim, 884 | self.dl, 885 | self.valid_dl 886 | ) = self.accelerator.prepare( 887 | self.transformer, 888 | self.optim, 889 | self.dl, 890 | self.valid_dl 891 | ) 892 | 893 | # dataloader iterators 894 | 895 | self.dl_iter = cycle(self.dl) 896 | self.valid_dl_iter = cycle(self.valid_dl) 897 | 898 | self.save_model_every = save_model_every 899 | self.save_results_every = save_results_every 900 | 901 | self.results_folder = Path(results_folder) 902 | 903 | if self.is_main and force_clear_prev_results is True or (not 
exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')): 904 | rmtree(str(self.results_folder)) 905 | 906 | self.results_folder.mkdir(parents = True, exist_ok = True) 907 | 908 | hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr} 909 | self.accelerator.init_trackers("coarse", config=hps) 910 | 911 | self.train_wrapper.to(self.device) 912 | 913 | def save(self, path): 914 | pkg = dict( 915 | model = self.accelerator.get_state_dict(self.transformer), 916 | optim = self.optim.state_dict(), 917 | version = __version__ 918 | ) 919 | torch.save(pkg, path) 920 | 921 | def load(self, path): 922 | path = Path(path) 923 | assert path.exists() 924 | pkg = torch.load(str(path), map_location = 'cpu') 925 | 926 | # check version 927 | 928 | if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__): 929 | print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch') 930 | 931 | transformer = self.accelerator.unwrap_model(self.transformer) 932 | transformer.load_state_dict(pkg['model']) 933 | 934 | self.optim.load_state_dict(pkg['optim']) 935 | # + 1 to start from the next step and avoid overwriting the last checkpoint 936 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device) 937 | 938 | 939 | def print(self, msg): 940 | self.accelerator.print(msg) 941 | 942 | def generate(self, *args, **kwargs): 943 | return self.train_wrapper.generate(*args, **kwargs) 944 | 945 | @property 946 | def device(self): 947 | return self.accelerator.device 948 | 949 | @property 950 | def is_distributed(self): 951 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 952 | 953 | @property 954 | def is_main(self): 955 | return self.accelerator.is_main_process 956 | 957 | @property 958 | def is_local_main(self): 959 | return self.accelerator.is_local_main_process 960 | 961 | def train_step(self): 962 | device = self.device 963 | 964 | steps = int(self.steps.item()) 965 | 966 | self.transformer.train() 967 | 968 | # logs 969 | 970 | logs = {} 971 | 972 | # update transformer 973 | 974 | for _ in range(self.grad_accum_every): 975 | data_kwargs = dict(zip(self.ds_fields, next(self.dl_iter))) 976 | 977 | loss = self.train_wrapper( 978 | **data_kwargs, 979 | return_loss = True 980 | ) 981 | 982 | self.accelerator.backward(loss / self.grad_accum_every) 983 | 984 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) 985 | 986 | if exists(self.max_grad_norm): 987 | self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm) 988 | 989 | self.optim.step() 990 | self.optim.zero_grad() 991 | 992 | # log 993 | 994 | self.print(f"{steps}: loss: {logs['loss']}") 995 | self.accelerator.log({"train_loss": logs['loss']}, step=steps) 996 | 997 | # sample results every so often 998 | 999 | if self.is_main and not (steps % self.save_results_every): 1000 | data_kwargs = dict(zip(self.ds_fields, next(self.valid_dl_iter))) 1001 | 1002 | with torch.no_grad(): 1003 | self.train_wrapper.eval() 1004 | 1005 | valid_loss = self.train_wrapper( 1006 | **data_kwargs, 1007 | return_loss = True 1008 | ) 1009 | 1010 | self.print(f'{steps}: valid loss {valid_loss}') 1011 | self.accelerator.log({"valid_loss": valid_loss}, step=steps) 1012 | 1013 | # save model every so often 1014 | 1015 | if self.is_main and not (steps % self.save_model_every):
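            # checkpoints are named coarse.transformer.<steps>.pt; checkpoint_num_steps() reads that trailing
            # step count back out of the filename when load() resumes training from such a checkpoint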
1016 | model_path = str(self.results_folder / f'coarse.transformer.{steps}.pt') 1017 | self.save(model_path) 1018 | 1019 | self.print(f'{steps}: saving model to {str(self.results_folder)}') 1020 | 1021 | self.steps += 1 1022 | return logs 1023 | 1024 | def train(self, log_fn = noop): 1025 | 1026 | while self.steps < self.num_train_steps: 1027 | logs = self.train_step() 1028 | log_fn(logs) 1029 | 1030 | self.print('training complete') 1031 | 1032 | # fine transformer trainer 1033 | 1034 | class FineTransformerTrainer(nn.Module): 1035 | @beartype 1036 | def __init__( 1037 | self, 1038 | transformer: FineTransformer, 1039 | codec: Union[SoundStream, EncodecWrapper], 1040 | *, 1041 | num_train_steps, 1042 | batch_size, 1043 | audio_conditioner: Optional[AudioConditionerBase] = None, 1044 | dataset: Optional[Dataset] = None, 1045 | data_max_length = None, 1046 | data_max_length_seconds = None, 1047 | dataset_normalize = False, 1048 | folder = None, 1049 | lr = 3e-4, 1050 | grad_accum_every = 1, 1051 | wd = 0., 1052 | max_grad_norm = 0.5, 1053 | valid_frac = 0.05, 1054 | random_split_seed = 42, 1055 | save_results_every = 100, 1056 | save_model_every = 1000, 1057 | results_folder = './results', 1058 | accelerate_kwargs: dict = dict(), 1059 | force_clear_prev_results = None 1060 | ): 1061 | super().__init__() 1062 | self.accelerator = Accelerator(**accelerate_kwargs) 1063 | 1064 | self.transformer = transformer 1065 | self.codec = codec 1066 | self.audio_conditioner = audio_conditioner 1067 | 1068 | self.train_wrapper = FineTransformerWrapper( 1069 | codec = codec, 1070 | transformer = transformer, 1071 | audio_conditioner = audio_conditioner 1072 | ) 1073 | 1074 | self.register_buffer('steps', torch.Tensor([0])) 1075 | 1076 | self.num_train_steps = num_train_steps 1077 | self.batch_size = batch_size 1078 | self.grad_accum_every = grad_accum_every 1079 | 1080 | # optimizers 1081 | 1082 | self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd) 1083 | 1084 | # max grad norm 1085 | 1086 | self.max_grad_norm = max_grad_norm 1087 | 1088 | # create dataset 1089 | 1090 | self.ds = dataset 1091 | 1092 | if not exists(self.ds): 1093 | assert exists(folder), 'folder must be passed in, if not passing in a custom dataset for text conditioned audio synthesis training' 1094 | 1095 | assert not (exists(data_max_length) and exists(data_max_length_seconds)) 1096 | 1097 | if exists(data_max_length_seconds): 1098 | data_max_length = data_max_length_seconds * codec.target_sample_hz 1099 | 1100 | self.ds = SoundDataset( 1101 | folder, 1102 | max_length = data_max_length, 1103 | target_sample_hz = codec.target_sample_hz, 1104 | seq_len_multiple_of = codec.seq_len_multiple_of 1105 | ) 1106 | 1107 | self.ds_fields = None 1108 | 1109 | # split for validation 1110 | 1111 | if valid_frac > 0: 1112 | train_size = int((1 - valid_frac) * len(self.ds)) 1113 | valid_size = len(self.ds) - train_size 1114 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) 1115 | self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') 1116 | else: 1117 | self.valid_ds = self.ds 1118 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') 1119 | 1120 | # dataloader 1121 | 1122 | self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True) 1123 | 1124 | self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, 
shuffle = True) 1125 | 1126 | # prepare with accelerator 1127 | 1128 | ( 1129 | self.transformer, 1130 | self.optim, 1131 | self.dl, 1132 | self.valid_dl 1133 | ) = self.accelerator.prepare( 1134 | self.transformer, 1135 | self.optim, 1136 | self.dl, 1137 | self.valid_dl 1138 | ) 1139 | 1140 | # dataloader iterators 1141 | 1142 | self.dl_iter = cycle(self.dl) 1143 | self.valid_dl_iter = cycle(self.valid_dl) 1144 | 1145 | self.save_model_every = save_model_every 1146 | self.save_results_every = save_results_every 1147 | 1148 | self.results_folder = Path(results_folder) 1149 | 1150 | if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')): 1151 | rmtree(str(self.results_folder)) 1152 | 1153 | self.results_folder.mkdir(parents = True, exist_ok = True) 1154 | 1155 | hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr} 1156 | self.accelerator.init_trackers("fine", config=hps) 1157 | 1158 | self.train_wrapper.to(self.device) 1159 | 1160 | def save(self, path): 1161 | pkg = dict( 1162 | model = self.accelerator.get_state_dict(self.transformer), 1163 | optim = self.optim.state_dict(), 1164 | version = __version__ 1165 | ) 1166 | torch.save(pkg, path) 1167 | 1168 | def load(self, path): 1169 | path = Path(path) 1170 | assert path.exists() 1171 | pkg = torch.load(str(path), map_location = 'cpu') 1172 | 1173 | # check version 1174 | 1175 | if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__): 1176 | print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch') 1177 | 1178 | transformer = self.accelerator.unwrap_model(self.transformer) 1179 | transformer.load_state_dict(pkg['model']) 1180 | 1181 | self.optim.load_state_dict(pkg['optim']) 1182 | # + 1 to start from the next step and avoid overwriting the last checkpoint 1183 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device) 1184 | 1185 | 1186 | def print(self, msg): 1187 | self.accelerator.print(msg) 1188 | 1189 | def generate(self, *args, **kwargs): 1190 | return self.train_wrapper.generate(*args, **kwargs) 1191 | 1192 | @property 1193 | def device(self): 1194 | return self.accelerator.device 1195 | 1196 | @property 1197 | def is_distributed(self): 1198 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 1199 | 1200 | @property 1201 | def is_main(self): 1202 | return self.accelerator.is_main_process 1203 | 1204 | @property 1205 | def is_local_main(self): 1206 | return self.accelerator.is_local_main_process 1207 | 1208 | def data_tuple_to_kwargs(self, data): 1209 | if not exists(self.ds_fields): 1210 | self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG) 1211 | assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names' 1212 | 1213 | return dict(zip(self.ds_fields, data)) 1214 | 1215 | def train_step(self): 1216 | device = self.device 1217 | 1218 | steps = int(self.steps.item()) 1219 | 1220 | self.transformer.train() 1221 | 1222 | # logs 1223 | 1224 | logs = {} 1225 | 1226 | # update transformer 1227 | 1228 | for _ in range(self.grad_accum_every): 1229 | data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter)) 1230 | loss = self.train_wrapper(**data_kwargs, return_loss = True) 1231 | 1232 | self.accelerator.backward(loss / self.grad_accum_every) 1233 | 1234 |
accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) 1235 | 1236 | if exists(self.max_grad_norm): 1237 | self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm) 1238 | 1239 | self.optim.step() 1240 | self.optim.zero_grad() 1241 | 1242 | # log 1243 | 1244 | self.print(f"{steps}: loss: {logs['loss']}") 1245 | self.accelerator.log({"train_loss": logs['loss']}, step=steps) 1246 | 1247 | # sample results every so often 1248 | 1249 | if self.is_main and not (steps % self.save_results_every): 1250 | data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter)) 1251 | 1252 | with torch.no_grad(): 1253 | self.train_wrapper.eval() 1254 | valid_loss = self.train_wrapper(**data_kwargs, return_loss = True) 1255 | 1256 | self.print(f'{steps}: valid loss {valid_loss}') 1257 | self.accelerator.log({"valid_loss": valid_loss}, step=steps) 1258 | 1259 | # save model every so often 1260 | 1261 | if self.is_main and not (steps % self.save_model_every): 1262 | model_path = str(self.results_folder / f'fine.transformer.{steps}.pt') 1263 | self.save(model_path) 1264 | 1265 | self.print(f'{steps}: saving model to {str(self.results_folder)}') 1266 | 1267 | self.steps += 1 1268 | return logs 1269 | 1270 | def train(self, log_fn = noop): 1271 | 1272 | while self.steps < self.num_train_steps: 1273 | logs = self.train_step() 1274 | log_fn(logs) 1275 | 1276 | self.print('training complete') 1277 | --------------------------------------------------------------------------------
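A usage sketch for the transformer trainers defined above, assuming only the package's public exports; the HuBERT checkpoint paths, model sizes, batch sizes and step counts are illustrative placeholders, and EncodecWrapper stands in for a trained SoundStream purely to keep the sketch self-contained:

    from audiolm_pytorch import (
        HubertWithKmeans, EncodecWrapper,
        SemanticTransformer, SemanticTransformerTrainer,
        CoarseTransformer, CoarseTransformerTrainer,
        FineTransformer, FineTransformerTrainer
    )

    # semantic tokens come from HuBERT + k-means (checkpoint paths are placeholders)
    wav2vec = HubertWithKmeans(
        checkpoint_path = './hubert/hubert_base_ls960.pt',
        kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
    )

    # acoustic tokens come from a neural codec; the pretrained Encodec wrapper needs no training
    codec = EncodecWrapper()

    # stage 1: semantic transformer over HuBERT tokens
    semantic_transformer = SemanticTransformer(
        num_semantic_tokens = wav2vec.codebook_size,
        dim = 1024,
        depth = 6
    )

    SemanticTransformerTrainer(
        transformer = semantic_transformer,
        wav2vec = wav2vec,
        folder = '/path/to/audio/files',
        batch_size = 4,
        data_max_length_seconds = 2,
        num_train_steps = 100_000
    ).train()

    # stage 2: coarse transformer, semantic tokens -> first few codec quantizer levels
    coarse_transformer = CoarseTransformer(
        num_semantic_tokens = wav2vec.codebook_size,
        codebook_size = 1024,
        num_coarse_quantizers = 3,
        dim = 512,
        depth = 6
    )

    CoarseTransformerTrainer(
        transformer = coarse_transformer,
        codec = codec,
        wav2vec = wav2vec,
        folder = '/path/to/audio/files',
        batch_size = 4,
        data_max_length_seconds = 2,
        num_train_steps = 100_000
    ).train()

    # stage 3: fine transformer, coarse codec tokens -> remaining quantizer levels
    fine_transformer = FineTransformer(
        num_coarse_quantizers = 3,
        num_fine_quantizers = 5,
        codebook_size = 1024,
        dim = 512,
        depth = 6
    )

    FineTransformerTrainer(
        transformer = fine_transformer,
        codec = codec,
        folder = '/path/to/audio/files',
        batch_size = 4,
        data_max_length_seconds = 2,
        num_train_steps = 100_000
    ).train()

Each trainer writes periodic checkpoints into its results_folder (semantic.transformer.<steps>.pt and so on), and calling load() on a fresh trainer instance resumes from the step count encoded in the checkpoint filename.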