├── stable_audio_tools ├── data │ ├── __init__.py │ └── utils.py ├── inference │ ├── __init__.py │ └── utils.py ├── interface │ ├── __init__.py │ └── interfaces │ │ └── __init__.py ├── models │ ├── __init__.py │ ├── pretrained.py │ ├── wavelets.py │ ├── arc.py │ ├── inpainting.py │ ├── fsq.py │ ├── utils.py │ ├── convnext.py │ ├── lm_backbone.py │ ├── encodec.py │ ├── factory.py │ ├── local_attention.py │ └── pretransforms.py ├── training │ ├── __init__.py │ ├── losses │ │ ├── __init__.py │ │ ├── metrics.py │ │ ├── utils.py │ │ ├── semantic.py │ │ └── losses.py │ ├── utils.py │ └── lm.py ├── __init__.py └── configs │ ├── dataset_configs │ ├── custom_metadata │ │ └── custom_md_example.py │ ├── s3_wds_example.json │ └── local_training_example.json │ └── model_configs │ ├── dance_diffusion │ ├── dance_diffusion_base.json │ ├── dance_diffusion_base_16k.json │ ├── dance_diffusion_base_44k.json │ └── dance_diffusion_large.json │ ├── autoencoders │ ├── dac_2048_32_vae.json │ ├── encodec_musicgen_rvq.json │ ├── stable_audio_1_0_vae.json │ └── stable_audio_2_0_vae.json │ └── txt2audio │ ├── stable_audio_1_0.json │ └── stable_audio_2_0.json ├── pyproject.toml ├── scripts └── ds_zero_to_pl_ckpt.py ├── LICENSE ├── LICENSES ├── LICENSE_ADP.txt ├── LICENSE_DESCRIPT.txt ├── LICENSE_XTRANSFORMERS.txt ├── LICENSE_NVIDIA.txt ├── LICENSE_META.txt └── LICENSE_AEIOU.txt ├── setup.py ├── defaults.ini ├── run_gradio.py ├── docs ├── pretransforms.md ├── pre_encoding.md ├── datasets.md ├── conditioning.md └── diffusion.md ├── .gitignore ├── unwrap_model.py ├── train.py ├── README.md └── pre_encode.py /stable_audio_tools/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stable_audio_tools/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stable_audio_tools/interface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stable_audio_tools/interface/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /stable_audio_tools/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import create_model_from_config, create_model_from_config_path -------------------------------------------------------------------------------- /stable_audio_tools/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import create_training_wrapper_from_config, create_demo_callback_from_config 2 | -------------------------------------------------------------------------------- /stable_audio_tools/training/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * 2 | from .metrics import * 3 | from .semantic import * 4 | from .utils import * 5 | 
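The top-level `stable_audio_tools/__init__.py` shown in the next section re-exports the model factory helpers (`create_model_from_config`, `get_pretrained_model`). A minimal usage sketch follows; the config filename and the Hugging Face repo name (`stabilityai/stable-audio-open-1.0`, referenced in `docs/pre_encoding.md`) are illustrative placeholders, not required values:

```python
import json

from stable_audio_tools import create_model_from_config, get_pretrained_model

# Build a model from a local model config JSON (placeholder path)
with open("model_config.json") as f:
    model_config = json.load(f)
model = create_model_from_config(model_config)

# Or fetch config + weights from the Hugging Face Hub,
# as implemented in stable_audio_tools/models/pretrained.py
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
```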
-------------------------------------------------------------------------------- /stable_audio_tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.factory import create_model_from_config, create_model_from_config_path 2 | from .models.pretrained import get_pretrained_model -------------------------------------------------------------------------------- /stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py: -------------------------------------------------------------------------------- 1 | def get_custom_metadata(info, audio): 2 | 3 | # Use relative path as the prompt 4 | return {"prompt": info["relpath"]} -------------------------------------------------------------------------------- /stable_audio_tools/configs/dataset_configs/s3_wds_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "s3", 3 | "datasets": [ 4 | { 5 | "id": "s3-test", 6 | "s3_path": "s3://my-bucket/datasets/webdataset/audio/" 7 | } 8 | ], 9 | "random_crop": true 10 | } -------------------------------------------------------------------------------- /stable_audio_tools/configs/dataset_configs/local_training_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "audio_dir", 3 | "datasets": [ 4 | { 5 | "id": "my_audio", 6 | "path": "/path/to/audio/dataset/", 7 | "custom_metadata_module": "/path/to/custom_metadata/custom_md_example.py" 8 | } 9 | ], 10 | "random_crop": true 11 | } -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "diffusion_uncond", 3 | "sample_size": 65536, 4 | "sample_rate": 48000, 5 | "model": { 6 | "type": "DAU1d", 7 | "config": { 8 | "n_attn_layers": 5 9 | } 10 | }, 11 | "training": { 12 | "learning_rate": 1e-4, 13 | "demo": { 14 | "demo_every": 2000, 15 | "demo_steps": 250 16 | } 17 | } 18 | } -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "diffusion_uncond", 3 | "sample_size": 65536, 4 | "sample_rate": 16000, 5 | "model": { 6 | "type": "DAU1d", 7 | "config": { 8 | "n_attn_layers": 5 9 | } 10 | }, 11 | "training": { 12 | "learning_rate": 1e-4, 13 | "demo": { 14 | "demo_every": 2000, 15 | "demo_steps": 250 16 | } 17 | } 18 | } -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "diffusion_uncond", 3 | "sample_size": 65536, 4 | "sample_rate": 44100, 5 | "model": { 6 | "type": "DAU1d", 7 | "config": { 8 | "n_attn_layers": 5 9 | } 10 | }, 11 | "training": { 12 | "learning_rate": 4e-5, 13 | "demo": { 14 | "demo_every": 2000, 15 | "demo_steps": 250 16 | } 17 | } 18 | } -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"model_type": "diffusion_uncond", 3 | "sample_size": 131072, 4 | "sample_rate": 48000, 5 | "model": { 6 | "type": "DAU1d", 7 | "config": { 8 | "n_attn_layers": 5 9 | } 10 | }, 11 | "training": { 12 | "learning_rate": 1e-4, 13 | "demo": { 14 | "demo_every": 2000, 15 | "demo_steps": 250 16 | } 17 | } 18 | } -------------------------------------------------------------------------------- /scripts/ds_zero_to_pl_ckpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict 3 | 4 | if __name__ == "__main__": 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--save_path", type=str, help="Path to the zero checkpoint") 8 | parser.add_argument("--output_path", type=str, help="Path to the output checkpoint", default="lightning_model.pt") 9 | args = parser.parse_args() 10 | 11 | # lightning deepspeed has saved a directory instead of a file 12 | save_path = args.save_path 13 | output_path = args.output_path 14 | convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) -------------------------------------------------------------------------------- /stable_audio_tools/models/pretrained.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .factory import create_model_from_config 4 | from .utils import load_ckpt_state_dict 5 | 6 | from huggingface_hub import hf_hub_download 7 | 8 | def get_pretrained_model(name: str): 9 | 10 | model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model') 11 | 12 | with open(model_config_path) as f: 13 | model_config = json.load(f) 14 | 15 | model = create_model_from_config(model_config) 16 | 17 | # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file 18 | try: 19 | model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') 20 | except Exception as e: 21 | model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') 22 | 23 | model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) 24 | 25 | return model, model_config -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/stable_audio_tools/training/losses/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from pypesq import pesq  # assumed PESQ backend matching pesq(ref, deg, fs); pesq() is otherwise undefined below and no PESQ package is pinned in setup.py
4 | from torch.nn import functional as F
5 | from torch import nn
6 | 
7 | ### Metrics are loss-like functions that do not backpropagate gradients.
8 | 
9 | class PESQMetric(nn.Module):
10 |     def __init__(self, sample_rate: int):
11 |         super().__init__()
12 |         self.resampler = (
13 |             torchaudio.transforms.Resample(sample_rate, 16000)
14 |             if sample_rate != 16000 else None)
15 | 
16 |     def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
17 |         if self.resampler is not None:
18 |             inputs = self.resampler(inputs)
19 |             targets = self.resampler(targets)
20 | 
21 |         inputs_np = inputs.cpu().numpy().astype("float64")
22 |         targets_np = targets.cpu().numpy().astype("float64")
23 |         batch_size = targets.shape[0]
24 | 
25 |         # Compute the average PESQ score across the batch.
26 |         val_pesq = (1.0 / batch_size) * sum(
27 |             pesq(targets_np[i].reshape(-1), inputs_np[i].reshape(-1), 16000)
28 |             for i in range(batch_size))
29 |         return val_pesq
--------------------------------------------------------------------------------
/LICENSES/LICENSE_ADP.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 archinet.ai
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
-------------------------------------------------------------------------------- /LICENSES/LICENSE_DESCRIPT.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-present, Descript 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /LICENSES/LICENSE_XTRANSFORMERS.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /LICENSES/LICENSE_NVIDIA.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 NVIDIA CORPORATION. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /LICENSES/LICENSE_META.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /stable_audio_tools/inference/utils.py: -------------------------------------------------------------------------------- 1 | from ..data.utils import PadCrop 2 | 3 | from torchaudio import transforms as T 4 | 5 | def set_audio_channels(audio, target_channels): 6 | # Add channel dim if it's missing 7 | if audio.dim() == 2: 8 | audio = audio.unsqueeze(1) 9 | 10 | if target_channels == 1: 11 | # Convert to mono 12 | audio = audio.mean(1, keepdim=True) 13 | elif target_channels == 2: 14 | # Convert to stereo 15 | if audio.shape[1] == 1: 16 | audio = audio.repeat(1, 2, 1) 17 | elif audio.shape[1] > 2: 18 | audio = audio[:, :2, :] 19 | return audio 20 | 21 | def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): 22 | 23 | audio = audio.to(device) 24 | 25 | if in_sr != target_sr: 26 | resample_tf = T.Resample(in_sr, target_sr).to(device) 27 | audio = resample_tf(audio) 28 | 29 | audio = PadCrop(target_length, randomize=False)(audio) 30 | 31 | # Add batch dimension 32 | if audio.dim() == 1: 33 | audio = audio.unsqueeze(0).unsqueeze(0) 34 | elif audio.dim() == 2: 35 | audio = audio.unsqueeze(0) 36 | 37 | audio = set_audio_channels(audio, target_channels) 38 | 39 | return audio -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='stable-audio-tools', 5 | version='0.0.19', 6 | url='https://github.com/Stability-AI/stable-audio-tools.git', 7 | author='Stability AI', 8 | description='Training and inference tools for generative audio models from Stability AI', 9 | packages=find_packages(), 10 | install_requires=[ 11 | 'alias-free-torch==0.0.6', 12 | 'auraloss==0.4.0', 13 | 'descript-audio-codec==1.0.0', 14 | 'einops', 15 | 'einops-exts', 16 | 'ema-pytorch==0.2.3', 17 | 'encodec==0.1.1', 18 | 'gradio>=5.20.0', 19 | 'huggingface_hub', 20 | 'importlib-resources==5.12.0', 21 | 'k-diffusion==0.1.1', 22 | 'laion-clap==1.1.4', 23 | 'local-attention==1.8.6', 24 | 'pandas==2.0.2', 25 | 'prefigure==0.0.9', 26 | 'pytorch_lightning==2.1.0', 27 | 'PyWavelets==1.4.1', 28 | 'safetensors', 29 | 'sentencepiece==0.1.99', 30 | 'torch>=2.5.1', 31 | 'torchaudio>=2.5.1', 32 | 'torchmetrics==0.11.4', 33 | 'tqdm', 34 | 'transformers', 35 | 'v-diffusion-pytorch==0.0.2', 36 | 'vector-quantize-pytorch==1.14.41', 37 | 'wandb==0.15.4', 38 | 'webdataset==0.2.100' 39 | ], 40 | ) 41 | -------------------------------------------------------------------------------- /defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | #name of the run 5 | name = stable_audio_tools 6 | 7 | # name of the project 8 | project = None 9 | 10 | # the batch size 11 | batch_size = 4 12 | 13 | # If `true`, attempts to resume training from latest checkpoint. 14 | # In this case, each run must have unique config filename. 15 | recover = false 16 | 17 | # Save top K model checkpoints during training. 
18 | save_top_k = -1 19 | 20 | # number of nodes to use for training 21 | num_nodes = 1 22 | 23 | # Multi-GPU strategy for PyTorch Lightning 24 | strategy = "auto" 25 | 26 | # Precision to use for training 27 | precision = "16-mixed" 28 | 29 | # number of CPU workers for the DataLoader 30 | num_workers = 6 31 | 32 | # the random seed 33 | seed = 42 34 | 35 | # Batches for gradient accumulation 36 | accum_batches = 1 37 | 38 | # Number of steps between checkpoints 39 | checkpoint_every = 10000 40 | 41 | # Number of steps between validation runs 42 | val_every = -1 43 | 44 | # trainer checkpoint file to restart training from 45 | ckpt_path = '' 46 | 47 | # model checkpoint file to start a new training run from 48 | pretrained_ckpt_path = '' 49 | 50 | # Checkpoint path for the pretransform model if needed 51 | pretransform_ckpt_path = '' 52 | 53 | # configuration model specifying model hyperparameters 54 | model_config = '' 55 | 56 | # configuration for datasets 57 | dataset_config = '' 58 | 59 | # configuration for validation datasets 60 | val_dataset_config = '' 61 | 62 | # directory to save the checkpoints in 63 | save_dir = '' 64 | 65 | # gradient_clip_val passed into PyTorch Lightning Trainer 66 | gradient_clip_val = 0.0 67 | 68 | # remove the weight norm from the pretransform model 69 | remove_pretransform_weight_norm = '' 70 | 71 | # Logger type to use 72 | logger = 'wandb' 73 | -------------------------------------------------------------------------------- /run_gradio.py: -------------------------------------------------------------------------------- 1 | from stable_audio_tools import get_pretrained_model 2 | from stable_audio_tools.interface.gradio import create_ui 3 | import json 4 | 5 | import torch 6 | 7 | def main(args): 8 | torch.manual_seed(42) 9 | 10 | interface = create_ui( 11 | model_config_path = args.model_config, 12 | ckpt_path=args.ckpt_path, 13 | pretrained_name=args.pretrained_name, 14 | pretransform_ckpt_path=args.pretransform_ckpt_path, 15 | model_half=args.model_half, 16 | gradio_title=args.title 17 | ) 18 | interface.queue() 19 | interface.launch(share=args.share, auth=(args.username, args.password) if args.username is not None else None) 20 | 21 | if __name__ == "__main__": 22 | import argparse 23 | parser = argparse.ArgumentParser(description='Run gradio interface') 24 | parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False) 25 | parser.add_argument('--model-config', type=str, help='Path to model config', required=False) 26 | parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False) 27 | parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False) 28 | parser.add_argument('--share', action='store_true', help='Create a publicly shareable link', required=False) 29 | parser.add_argument('--username', type=str, help='Gradio username', required=False) 30 | parser.add_argument('--password', type=str, help='Gradio password', required=False) 31 | parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False, default=True) 32 | parser.add_argument('--title', type=str, help='Display Title top of Gradio', required=False) 33 | args = parser.parse_args() 34 | main(args) -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_type": "autoencoder", 3 | "sample_size": 65536, 4 | "sample_rate": 44100, 5 | "audio_channels": 1, 6 | "model": { 7 | "encoder": { 8 | "type": "dac", 9 | "config": { 10 | "latent_dim": 64, 11 | "d_model": 128, 12 | "strides": [4, 8, 8, 8] 13 | } 14 | }, 15 | "decoder": { 16 | "type": "dac", 17 | "config": { 18 | "latent_dim": 32, 19 | "channels": 1536, 20 | "rates": [8, 8, 8, 4] 21 | } 22 | }, 23 | "bottleneck": { 24 | "type": "vae" 25 | }, 26 | "latent_dim": 32, 27 | "downsampling_ratio": 2048, 28 | "io_channels": 1 29 | }, 30 | "training": { 31 | "learning_rate": 1e-4, 32 | "warmup_steps": 0, 33 | "use_ema": false, 34 | "loss_configs": { 35 | "discriminator": { 36 | "type": "encodec", 37 | "config": { 38 | "filters": 32, 39 | "n_ffts": [2048, 1024, 512, 256, 128, 64, 32], 40 | "hop_lengths": [512, 256, 128, 64, 32, 16, 8], 41 | "win_lengths": [2048, 1024, 512, 256, 128, 64, 32] 42 | }, 43 | "weights": { 44 | "adversarial": 0.1, 45 | "feature_matching": 5.0 46 | } 47 | }, 48 | "spectral": { 49 | "type": "mrstft", 50 | "config": { 51 | "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], 52 | "hop_sizes": [512, 256, 128, 64, 32, 16, 8], 53 | "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], 54 | "perceptual_weighting": true 55 | }, 56 | "weights": { 57 | "mrstft": 1.0 58 | } 59 | }, 60 | "time": { 61 | "type": "l1", 62 | "weights": { 63 | "l1": 0.0 64 | } 65 | } 66 | }, 67 | "demo": { 68 | "demo_every": 2000 69 | } 70 | } 71 | } -------------------------------------------------------------------------------- /docs/pretransforms.md: -------------------------------------------------------------------------------- 1 | # Pretransforms 2 | Many models require some fixed transform to be applied to the input audio before the audio is passed in to the trainable layers of the model, as well as a corresponding inverse transform to be applied to the outputs of the model. We refer to these as "pretransforms". 3 | 4 | At the moment, `stable-audio-tools` supports two pretransforms, frozen autoencoders for latent diffusion models and wavelet decompositions. 5 | 6 | Pretransforms have a similar interface to autoencoders with "encode" and "decode" functions defined for each pretransform. 7 | 8 | ## Autoencoder pretransform 9 | To define a model with an autoencoder pretransform, you can define the "pretransform" property in the model config, with the `type` property set to `autoencoder`. The `config` property should be an autoencoder model definition. 10 | 11 | Example: 12 | ```json 13 | "pretransform": { 14 | "type": "autoencoder", 15 | "config": { 16 | "encoder": { 17 | ... 18 | }, 19 | "decoder": { 20 | ... 21 | } 22 | ...normal autoencoder configuration 23 | } 24 | } 25 | ``` 26 | 27 | ### Latent rescaling 28 | The original [Latent Diffusion paper](https://arxiv.org/abs/2112.10752) found that rescaling the latent series to unit variance before performing diffusion improved quality. To this end, we expose a `scale` property on autoencoder pretransforms that will take care of this rescaling. The scale should be set to the original standard deviation of the latents, which can be determined experimentally, or by looking at the `latent_std` value during training. The pretransform code will divide by this scale factor in the `encode` function and multiply by this scale in the `decode` function. 29 | 30 | ## Wavelet pretransform 31 | `stable-audio-tools` also exposes wavelet decomposition as a pretransform. 
Wavelet decomposition is a quick way to trade off sequence length for channels in autoencoders, while maintaining a multi-band implicit bias. 32 | 33 | Wavelet pretransforms take the following properties: 34 | 35 | - `channels` 36 | - The number of input and output audio channels for the wavelet transform 37 | - `levels` 38 | - The number of successive wavelet decompositions to perform. Each level doubles the channel count and halves the sequence length 39 | - `wavelet` 40 | - The specific wavelet from [PyWavelets](https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html) to use, currently limited to `"bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"` 41 | 42 | ## Future work 43 | We hope to add more filters and transforms to this list, including PQMF and STFT transforms. -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "autoencoder", 3 | "sample_size": 32000, 4 | "sample_rate": 32000, 5 | "audio_channels": 1, 6 | "model": { 7 | "encoder": { 8 | "type": "seanet", 9 | "config": { 10 | "channels": 1, 11 | "dimension": 128, 12 | "n_filters": 64, 13 | "ratios": [4, 4, 5, 8], 14 | "n_residual_layers": 1, 15 | "dilation_base": 2, 16 | "lstm": 2, 17 | "norm": "weight_norm" 18 | } 19 | }, 20 | "decoder": { 21 | "type": "seanet", 22 | "config": { 23 | "channels": 1, 24 | "dimension": 128, 25 | "n_filters": 64, 26 | "ratios": [4, 4, 5, 8], 27 | "n_residual_layers": 1, 28 | "dilation_base": 2, 29 | "lstm": 2, 30 | "norm": "weight_norm" 31 | } 32 | }, 33 | "bottleneck": { 34 | "type": "rvq", 35 | "config": { 36 | "num_quantizers": 4, 37 | "codebook_size": 2048, 38 | "dim": 128, 39 | "decay": 0.99, 40 | "threshold_ema_dead_code": 2 41 | } 42 | }, 43 | "latent_dim": 128, 44 | "downsampling_ratio": 640, 45 | "io_channels": 1 46 | }, 47 | "training": { 48 | "learning_rate": 1e-4, 49 | "warmup_steps": 0, 50 | "use_ema": true, 51 | "loss_configs": { 52 | "discriminator": { 53 | "type": "encodec", 54 | "config": { 55 | "filters": 32, 56 | "n_ffts": [2048, 1024, 512, 256, 128], 57 | "hop_lengths": [512, 256, 128, 64, 32], 58 | "win_lengths": [2048, 1024, 512, 256, 128] 59 | }, 60 | "weights": { 61 | "adversarial": 0.1, 62 | "feature_matching": 5.0 63 | } 64 | }, 65 | "spectral": { 66 | "type": "mrstft", 67 | "config": { 68 | "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], 69 | "hop_sizes": [512, 256, 128, 64, 32, 16, 8], 70 | "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], 71 | "perceptual_weighting": true 72 | }, 73 | "weights": { 74 | "mrstft": 1.0 75 | } 76 | }, 77 | "time": { 78 | "type": "l1", 79 | "weights": { 80 | "l1": 0.0 81 | } 82 | } 83 | }, 84 | "demo": { 85 | "demo_every": 2000 86 | } 87 | } 88 | } -------------------------------------------------------------------------------- /stable_audio_tools/models/wavelets.py: -------------------------------------------------------------------------------- 1 | """The 1D discrete wavelet transform for PyTorch.""" 2 | 3 | from einops import rearrange 4 | import pywt 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from typing import Literal 9 | 10 | 11 | def get_filter_bank(wavelet): 12 | filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank) 13 | if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): 14 | filt = filt[:, 1:] 15 | return filt 16 | 17 | class 
WaveletEncode1d(nn.Module): 18 | def __init__(self, 19 | channels, 20 | levels, 21 | wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): 22 | super().__init__() 23 | self.wavelet = wavelet 24 | self.channels = channels 25 | self.levels = levels 26 | filt = get_filter_bank(wavelet) 27 | assert filt.shape[-1] % 2 == 1 28 | kernel = filt[:2, None] 29 | kernel = torch.flip(kernel, dims=(-1,)) 30 | index_i = torch.repeat_interleave(torch.arange(2), channels) 31 | index_j = torch.tile(torch.arange(channels), (2,)) 32 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 33 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 34 | self.register_buffer("kernel", kernel_final) 35 | 36 | def forward(self, x): 37 | for i in range(self.levels): 38 | low, rest = x[:, : self.channels], x[:, self.channels :] 39 | pad = self.kernel.shape[-1] // 2 40 | low = F.pad(low, (pad, pad), "reflect") 41 | low = F.conv1d(low, self.kernel, stride=2) 42 | rest = rearrange( 43 | rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels 44 | ) 45 | x = torch.cat([low, rest], dim=1) 46 | return x 47 | 48 | 49 | class WaveletDecode1d(nn.Module): 50 | def __init__(self, 51 | channels, 52 | levels, 53 | wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): 54 | super().__init__() 55 | self.wavelet = wavelet 56 | self.channels = channels 57 | self.levels = levels 58 | filt = get_filter_bank(wavelet) 59 | assert filt.shape[-1] % 2 == 1 60 | kernel = filt[2:, None] 61 | index_i = torch.repeat_interleave(torch.arange(2), channels) 62 | index_j = torch.tile(torch.arange(channels), (2,)) 63 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 64 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 65 | self.register_buffer("kernel", kernel_final) 66 | 67 | def forward(self, x): 68 | for i in range(self.levels): 69 | low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] 70 | pad = self.kernel.shape[-1] // 2 + 2 71 | low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) 72 | low = F.pad(low, (pad, pad), "reflect") 73 | low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) 74 | low = F.conv_transpose1d( 75 | low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 76 | ) 77 | low = low[..., pad - 1 : -pad] 78 | rest = rearrange( 79 | rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels 80 | ) 81 | x = torch.cat([low, rest], dim=1) 82 | return x -------------------------------------------------------------------------------- /stable_audio_tools/training/losses/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch import nn 4 | from einops import rearrange 5 | 6 | class DynamicLossWeighting(nn.Module): 7 | def __init__(self, init_val = 1.0): 8 | super().__init__() 9 | self.loss_weight = nn.Parameter(torch.tensor(init_val)) 10 | def forward(self, loss): 11 | return loss / torch.exp(self.loss_weight) + self.loss_weight 12 | 13 | def flat_pairwise_sq_distance(x, y): 14 | """ 15 | Compute pairwise squared Euclidean distances for flat tensors. 
16 | Args: 17 | x: Tensor of shape (B, D) 18 | y: Tensor of shape (B, D) 19 | Returns: 20 | dist_sq: Tensor of shape (B, B) 21 | """ 22 | x_norm = (x ** 2).mean(dim=1, keepdim=True) # (B, 1) 23 | y_norm = (y ** 2).mean(dim=1, keepdim=True).transpose(0, 1) # (1, B) 24 | return (x_norm + y_norm - 2.0 * torch.mm(x, y.t())/ x.shape[1] ) 25 | 26 | def multi_bandwidth_kernel_2d(x, y, bandwidths): 27 | """ 28 | Compute the sum of Gaussian kernels (with different bandwidths) between two flat tensors. 29 | Args: 30 | x: Tensor of shape (B, D) 31 | y: Tensor of shape (B, D) 32 | bandwidths: Iterable of scalar bandwidth values. 33 | Returns: 34 | kernel: Tensor of shape (B, B) 35 | """ 36 | dist_sq = flat_pairwise_sq_distance(x, y).clip(min = 0.0) 37 | kernel_sum = 0.0 38 | for bw in bandwidths: 39 | #kernel_sum += torch.exp(-dist_sq / (2.0 * bw)) 40 | kernel_sum += (1/(1 + dist_sq / ( 2 * bw))).mean() 41 | return kernel_sum / len(bandwidths) 42 | 43 | def mmd_loss_flat(x, y, bandwidths): 44 | """ 45 | Compute the MMD loss between two flat sets of vectors. 46 | Args: 47 | x: Tensor of shape (B, D) 48 | y: Tensor of shape (B, D) 49 | bandwidths: Iterable of bandwidth values. 50 | Returns: 51 | loss: Scalar tensor representing the MMD loss. 52 | """ 53 | K_xx = multi_bandwidth_kernel_2d(x, x, bandwidths) 54 | K_yy = multi_bandwidth_kernel_2d(y, y, bandwidths) 55 | K_xy = multi_bandwidth_kernel_2d(x, y, bandwidths) 56 | loss = K_xx + K_yy - 2.0 * K_xy 57 | return loss 58 | 59 | def mmd(x, y, bandwidths =[1], dim = None): 60 | """ 61 | Compute the MMD loss along a chosen feature axis by collapsing all other dimensions. 62 | 63 | Args: 64 | x: Tensor of arbitrary shape. 65 | y: Tensor of the same shape as x. 66 | bandwidths: Iterable of scalar bandwidth values for the kernel. 67 | dim: The axis index that should be treated as the feature dimension. 68 | 69 | Returns: 70 | Scalar tensor representing the MMD loss computed on the flattened representations. 71 | """ 72 | if dim is None: 73 | dim_product = math.prod(x.shape[1:]) 74 | new_shape = (-1, dim_product) 75 | x_flat = x.reshape(new_shape) 76 | y_flat = y.reshape(new_shape) 77 | else: 78 | dims = list(range(x.dim())) 79 | dims.pop(dim) 80 | dims.append(dim) 81 | x_perm = x.permute(*dims) 82 | y_perm = y.permute(*dims) 83 | # Collapse all dimensions except the last one. 84 | new_shape = (-1, x_perm.size(-1)) 85 | x_flat = x_perm.reshape(new_shape) 86 | y_flat = y_perm.reshape(new_shape) 87 | return mmd_loss_flat(x_flat, y_flat, bandwidths) 88 | 89 | def grouped_mmd(x, y, bandwidths = [1], groups = 2): 90 | grouped_x = rearrange(x, '... (g f) t -> ... g (f t)', g = groups) 91 | grouped_y = rearrange(y, '... (g f) t -> ... 
g (f t)', g = groups) 92 | return mmd(grouped_x, grouped_y, bandwidths, dim = None) -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "autoencoder", 3 | "sample_size": 65536, 4 | "sample_rate": 44100, 5 | "audio_channels": 2, 6 | "model": { 7 | "encoder": { 8 | "type": "dac", 9 | "config": { 10 | "in_channels": 2, 11 | "latent_dim": 128, 12 | "d_model": 128, 13 | "strides": [4, 4, 8, 8] 14 | } 15 | }, 16 | "decoder": { 17 | "type": "dac", 18 | "config": { 19 | "out_channels": 2, 20 | "latent_dim": 64, 21 | "channels": 1536, 22 | "rates": [8, 8, 4, 4] 23 | } 24 | }, 25 | "bottleneck": { 26 | "type": "vae" 27 | }, 28 | "latent_dim": 64, 29 | "downsampling_ratio": 1024, 30 | "io_channels": 2 31 | }, 32 | "training": { 33 | "learning_rate": 1e-4, 34 | "warmup_steps": 0, 35 | "use_ema": true, 36 | "optimizer_configs": { 37 | "autoencoder": { 38 | "optimizer": { 39 | "type": "AdamW", 40 | "config": { 41 | "betas": [0.8, 0.99], 42 | "lr": 1e-4 43 | } 44 | }, 45 | "scheduler": { 46 | "type": "ExponentialLR", 47 | "config": { 48 | "gamma": 0.999996 49 | } 50 | } 51 | }, 52 | "discriminator": { 53 | "optimizer": { 54 | "type": "AdamW", 55 | "config": { 56 | "betas": [0.8, 0.99], 57 | "lr": 1e-4 58 | } 59 | }, 60 | "scheduler": { 61 | "type": "ExponentialLR", 62 | "config": { 63 | "gamma": 0.999996 64 | } 65 | } 66 | } 67 | }, 68 | "loss_configs": { 69 | "discriminator": { 70 | "type": "encodec", 71 | "config": { 72 | "filters": 32, 73 | "n_ffts": [2048, 1024, 512, 256, 128], 74 | "hop_lengths": [512, 256, 128, 64, 32], 75 | "win_lengths": [2048, 1024, 512, 256, 128] 76 | }, 77 | "weights": { 78 | "adversarial": 0.1, 79 | "feature_matching": 5.0 80 | } 81 | }, 82 | "spectral": { 83 | "type": "mrstft", 84 | "config": { 85 | "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], 86 | "hop_sizes": [512, 256, 128, 64, 32, 16, 8], 87 | "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], 88 | "perceptual_weighting": true 89 | }, 90 | "weights": { 91 | "mrstft": 1.0 92 | } 93 | }, 94 | "time": { 95 | "type": "l1", 96 | "weights": { 97 | "l1": 0.0 98 | } 99 | }, 100 | "bottleneck": { 101 | "type": "kl", 102 | "weights": { 103 | "kl": 1e-6 104 | } 105 | } 106 | }, 107 | "demo": { 108 | "demo_every": 2000 109 | } 110 | } 111 | } -------------------------------------------------------------------------------- /stable_audio_tools/models/arc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch import nn 5 | 6 | def get_relativistic_losses(score_real, score_fake): 7 | # Compute difference between real and fake scores 8 | diff = score_real - score_fake 9 | dis_loss = F.softplus(-diff).mean() 10 | gen_loss = F.softplus(diff).mean() 11 | return dis_loss, gen_loss 12 | 13 | class ConvDiscriminator(nn.Module): 14 | def __init__(self, channels, soft_clip_scale=None, loss_type="lsgan"): 15 | super().__init__() 16 | 17 | self.loss_type = loss_type 18 | 19 | self.layers = nn.Sequential( 20 | nn.Conv1d(kernel_size=4, in_channels=channels, out_channels=channels, stride=2, padding=1), # x2 downsampling 21 | nn.GroupNorm(num_groups=32, num_channels=channels), 22 | nn.SiLU(), 23 | nn.Conv1d(kernel_size=4, in_channels=channels, out_channels=channels, stride=2, padding=1), # x4 downsampling 24 | 
nn.GroupNorm(num_groups=32, num_channels=channels), 25 | nn.SiLU(), 26 | nn.Conv1d(kernel_size=4, in_channels=channels, out_channels=channels, stride=2, padding=1), # x8 downsampling 27 | nn.GroupNorm(num_groups=32, num_channels=channels), 28 | nn.SiLU(), 29 | nn.Conv1d(kernel_size=4, in_channels=channels, out_channels=channels, stride=2, padding=1), # x16 downsampling 30 | nn.GroupNorm(num_groups=32, num_channels=channels), 31 | nn.SiLU(), 32 | nn.Conv1d(kernel_size=4, in_channels=channels, out_channels=1, stride=1, padding=0), # to 1 channel for score 33 | ) 34 | 35 | self.soft_clip_scale = soft_clip_scale 36 | 37 | def forward(self, x): 38 | output = self.layers(x) 39 | 40 | if self.soft_clip_scale is not None: 41 | output = self.soft_clip_scale * torch.tanh(output/self.soft_clip_scale) 42 | 43 | return output 44 | 45 | def loss(self, reals, fakes, *args, **kwargs): 46 | real_scores = self(reals) 47 | fake_scores = self(fakes) 48 | 49 | loss_dis = loss_adv = 0 50 | 51 | if self.loss_type == "lsgan": 52 | # Calculate least-squares GAN losses 53 | loss_dis = torch.mean(fake_scores**2) + torch.mean ((1 - real_scores)**2) 54 | loss_adv = torch.mean((1 - fake_scores)**2) 55 | elif self.loss_type == "relativistic": 56 | 57 | diff = real_scores - fake_scores 58 | 59 | loss_dis = F.softplus(-diff).mean() 60 | loss_adv = F.softplus(diff).mean() 61 | 62 | return { 63 | "loss_dis": loss_dis, 64 | "loss_adv": loss_adv 65 | } 66 | 67 | class ConvNeXtDiscriminator(nn.Module): 68 | def __init__(self, loss_type="lsgan", *args, **kwargs): 69 | super().__init__() 70 | 71 | from .convnext import ConvNeXtEncoder 72 | 73 | self.encoder = ConvNeXtEncoder(*args, **kwargs) 74 | 75 | self.loss_type = loss_type 76 | 77 | def forward(self, x): 78 | return self.encoder(x) 79 | 80 | def loss(self, reals, fakes, *args, **kwargs): 81 | real_scores = self(reals) 82 | fake_scores = self(fakes) 83 | 84 | loss_dis = loss_adv = 0 85 | 86 | if self.loss_type == "lsgan": 87 | # Calculate least-squares GAN losses 88 | loss_dis = torch.mean(fake_scores**2) + torch.mean ((1 - real_scores)**2) 89 | loss_adv = torch.mean((1 - fake_scores)**2) 90 | elif self.loss_type == "relativistic": 91 | 92 | diff = real_scores - fake_scores 93 | 94 | loss_dis = F.softplus(-diff).mean() 95 | loss_adv = F.softplus(diff).mean() 96 | 97 | return { 98 | "loss_dis": loss_dis, 99 | "loss_adv": loss_adv 100 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | *.ckpt 163 | *.wav 164 | wandb/* -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "autoencoder", 3 | "sample_size": 65536, 4 | "sample_rate": 44100, 5 | "audio_channels": 2, 6 | "model": { 7 | "encoder": { 8 | "type": "oobleck", 9 | "config": { 10 | "in_channels": 2, 11 | "channels": 128, 12 | "c_mults": [1, 2, 4, 8, 16], 13 | "strides": [2, 4, 4, 8, 8], 14 | "latent_dim": 128, 15 | "use_snake": true 16 | } 17 | }, 18 | "decoder": { 19 | "type": "oobleck", 20 | "config": { 21 | "out_channels": 2, 22 | "channels": 128, 23 | "c_mults": [1, 2, 4, 8, 16], 24 | "strides": [2, 4, 4, 8, 8], 25 | "latent_dim": 64, 26 | "use_snake": true, 27 | "final_tanh": false 28 | } 29 | }, 30 | "bottleneck": { 31 | "type": "vae" 32 | }, 33 | "latent_dim": 64, 34 | "downsampling_ratio": 2048, 35 | "io_channels": 2 36 | }, 37 | "training": { 38 | "learning_rate": 1.5e-4, 39 | "warmup_steps": 0, 40 | "use_ema": true, 41 | "optimizer_configs": { 42 | "autoencoder": { 43 | "optimizer": { 44 | "type": "AdamW", 45 | "config": { 46 | "betas": [0.8, 0.99], 47 | "lr": 1.5e-4, 48 | "weight_decay": 1e-3 49 | } 50 | }, 51 | "scheduler": { 52 | "type": "InverseLR", 53 | "config": { 54 | "inv_gamma": 200000, 55 | "power": 0.5, 56 | "warmup": 0.999 57 | } 58 | } 59 | }, 60 | "discriminator": { 61 | "optimizer": { 62 | "type": "AdamW", 63 | "config": { 64 | "betas": [0.8, 0.99], 65 | "lr": 3e-4, 66 | "weight_decay": 1e-3 67 | } 68 | }, 69 | "scheduler": { 70 | "type": "InverseLR", 71 | "config": { 72 | "inv_gamma": 200000, 73 | "power": 0.5, 74 | "warmup": 0.999 75 | } 76 | } 77 | } 78 | }, 79 | "loss_configs": { 80 | "discriminator": { 81 | "type": "encodec", 82 | "config": { 83 | "filters": 64, 84 | "n_ffts": [2048, 1024, 512, 256, 128], 85 | "hop_lengths": [512, 256, 128, 64, 32], 86 | "win_lengths": [2048, 1024, 512, 256, 128] 87 | }, 88 | "weights": { 89 | "adversarial": 0.1, 90 | "feature_matching": 5.0 91 | } 92 | }, 93 | "spectral": { 94 | "type": "mrstft", 95 | "config": { 96 | "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], 97 | "hop_sizes": [512, 256, 128, 64, 32, 16, 8], 98 | "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], 99 | "perceptual_weighting": true 100 | }, 101 | "weights": { 102 | "mrstft": 1.0 103 | } 104 | }, 105 | "time": { 106 | "type": "l1", 107 | "weights": { 108 | "l1": 0.0 109 | } 110 | }, 111 | "bottleneck": { 112 | "type": "kl", 113 | "weights": { 114 | "kl": 1e-4 115 | } 116 | } 117 | }, 118 | "demo": { 119 | "demo_every": 2000 120 | } 121 | } 122 | } -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "diffusion_cond", 3 | "sample_size": 4194304, 4 | "sample_rate": 44100, 5 | "audio_channels": 2, 6 | "model": { 7 | "pretransform": { 8 | "type": "autoencoder", 9 | "iterate_batch": true, 10 | "config": { 11 | "encoder": { 12 | "type": "dac", 13 | "config": { 14 | "in_channels": 2, 15 | "latent_dim": 128, 16 | "d_model": 128, 17 | "strides": [4, 4, 8, 8] 18 | } 19 | }, 20 | "decoder": { 21 | "type": "dac", 22 | "config": { 23 | "out_channels": 2, 24 | "latent_dim": 64, 25 | "channels": 1536, 26 | "rates": [8, 8, 4, 4] 27 | } 28 | }, 29 | 
"bottleneck": { 30 | "type": "vae" 31 | }, 32 | "latent_dim": 64, 33 | "downsampling_ratio": 1024, 34 | "io_channels": 2 35 | } 36 | }, 37 | "conditioning": { 38 | "configs": [ 39 | { 40 | "id": "prompt", 41 | "type": "clap_text", 42 | "config": { 43 | "audio_model_type": "HTSAT-base", 44 | "enable_fusion": true, 45 | "clap_ckpt_path": "/path/to/clap.ckpt", 46 | "use_text_features": true, 47 | "feature_layer_ix": -2 48 | } 49 | }, 50 | { 51 | "id": "seconds_start", 52 | "type": "int", 53 | "config": { 54 | "min_val": 0, 55 | "max_val": 512 56 | } 57 | }, 58 | { 59 | "id": "seconds_total", 60 | "type": "int", 61 | "config": { 62 | "min_val": 0, 63 | "max_val": 512 64 | } 65 | } 66 | ], 67 | "cond_dim": 768 68 | }, 69 | "diffusion": { 70 | "type": "adp_cfg_1d", 71 | "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], 72 | "config": { 73 | "in_channels": 64, 74 | "context_embedding_features": 768, 75 | "context_embedding_max_length": 79, 76 | "channels": 256, 77 | "resnet_groups": 16, 78 | "kernel_multiplier_downsample": 2, 79 | "multipliers": [4, 4, 4, 5, 5], 80 | "factors": [1, 2, 2, 4], 81 | "num_blocks": [2, 2, 2, 2], 82 | "attentions": [1, 3, 3, 3, 3], 83 | "attention_heads": 16, 84 | "attention_multiplier": 4, 85 | "use_nearest_upsample": false, 86 | "use_skip_scale": true, 87 | "use_context_time": true 88 | } 89 | }, 90 | "io_channels": 64 91 | }, 92 | "training": { 93 | "learning_rate": 4e-5, 94 | "demo": { 95 | "demo_every": 2000, 96 | "demo_steps": 250, 97 | "num_demos": 4, 98 | "demo_cond": [ 99 | {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 95}, 100 | {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 90}, 101 | {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180}, 102 | {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 60} 103 | ], 104 | "demo_cfg_scales": [3, 6, 9] 105 | } 106 | } 107 | } -------------------------------------------------------------------------------- /docs/pre_encoding.md: -------------------------------------------------------------------------------- 1 | # Pre Encoding 2 | 3 | When training models on encoded latents from a frozen pre-trained autoencoder, the encoder is typically frozen. Because of that, it is common to pre-encode audio to latents and store them on disk instead of computing them on-the-fly during training. This can improve training throughput as well as free up GPU memory that would otherwise be used for encoding. 4 | 5 | ## Prerequisites 6 | 7 | To pre-encode audio to latents, you'll need a dataset config file, an autoencoder model config file, and an **unwrapped** autoencoder checkpoint file. 8 | 9 | **Note:** You can find a copy of the unwrapped VAE checkpoint (`vae_model.ckpt`) and config (`vae_config.json`) in the `stabilityai/stable-audio-open-1.0` Hugging Face [repo](https://huggingface.co/stabilityai/stable-audio-open-1.0). This is the same VAE used in `stable-audio-open-small`. 10 | 11 | ## Run the Pre Encoding Script 12 | 13 | To pre-encode latents from an autoencoder model, you can use `pre_encode.py`. This script will load a pre-trained autoencoder, encode the latents/tokens, and save them to disk in a format that can be easily loaded during training. 
14 | 15 | The `pre_encode.py` script accepts the following command line arguments: 16 | 17 | - `--model-config` 18 | - Path to model config 19 | - `--ckpt-path` 20 | - Path to **unwrapped** autoencoder model checkpoint 21 | - `--model-half` 22 | - If true, uses half precision for model weights 23 | - Optional 24 | - `--dataset-config` 25 | - Path to dataset config file 26 | - Required 27 | - `--output-path` 28 | - Path to output folder 29 | - Required 30 | - `--batch-size` 31 | - Batch size for processing 32 | - Optional, defaults to 1 33 | - `--sample-size` 34 | - Number of audio samples to pad/crop to for pre-encoding 35 | - Optional, defaults to 1320960 (~30 seconds) 36 | - `--is-discrete` 37 | - If true, treats the model as discrete, saving discrete tokens instead of continuous latents 38 | - Optional 39 | - `--num-nodes` 40 | - Number of nodes to use for distributed processing, if available. 41 | - Optional, defaults to 1 42 | - `--num-workers` 43 | - Number of dataloader workers 44 | - Optional, defaults to 4 45 | - `--strategy` 46 | - PyTorch Lightning strategy 47 | - Optional, defaults to 'auto' 48 | - `--limit-batches` 49 | - Limits the number of batches processed 50 | - Optional 51 | - `--shuffle` 52 | - If true, shuffles the dataset 53 | - Optional 54 | 55 | **Note:** When pre encoding, it's recommended to set `"drop_last": false` in your dataset config to ensure the last batch is processed even if it's not full. 56 | 57 | For example, if you wanted to encode latents with padding up to 30 seconds long in half precision, you could run the following: 58 | 59 | ```bash 60 | $ python3 ./pre_encode.py \ 61 | --model-config /path/to/model/config.json \ 62 | --ckpt-path /path/to/autoencoder/model.ckpt \ 63 | --model-half \ 64 | --dataset-config /path/to/dataset/config.json \ 65 | --output-path /path/to/output/dir \ 66 | --sample-size 1320960 \ 67 | ``` 68 | 69 | When you run the above, the `--output-path` directory will contain numbered subdirectories for each GPU process used to encode the latents, and a `details.json` file that keeps track of settings used when the script was run. 70 | 71 | Inside the numbered subdirectories, you will find the encoded latents as `.npy` files, along with associated `.json` metadata files. 72 | 73 | ```bash 74 | /path/to/output/dir/ 75 | ├── 0 76 | │ ├── 0000000000000.json 77 | │ ├── 0000000000000.npy 78 | │ ├── 0000000000001.json 79 | │ ├── 0000000000001.npy 80 | │ ├── 0000000000002.json 81 | │ ├── 0000000000002.npy 82 | ... 83 | └── details.json 84 | ``` 85 | 86 | ## Training on Pre Encoded Latents 87 | 88 | Once you have saved your latents to disk, you can use them to train a model by providing a dataset config file to `train.py` that points to the pre-encoded latents, specifying `"dataset_type"` is `"pre_encoded"`. Under the hood, this will configure a `stable_audio_tools.data.dataset.PreEncodedDataset`. For more information on configuring pre encoded datasets, see the [Pre Encoded Datasets](datasets.md#pre-encoded-datasets) section of the datasets docs. 
89 | 90 | The dataset config file should look something like this: 91 | 92 | ```json 93 | { 94 | "dataset_type": "pre_encoded", 95 | "datasets": [ 96 | { 97 | "id": "my_audio", 98 | "path": "/path/to/output/dir" 99 | } 100 | ], 101 | "random_crop": false 102 | } 103 | ``` 104 | 105 | In your diffusion model config, you'll also need to specify `pre_encoded: true` in the [`training` section](diffusion.md#training-configs) to tell the training wrapper to operate on pre encoded latents instead of audio. 106 | 107 | ```json 108 | "training": { 109 | "pre_encoded": true, 110 | ... 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /stable_audio_tools/models/inpainting.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from enum import Enum 4 | from typing import List, Optional, Tuple 5 | 6 | class MaskType(Enum): 7 | RANDOM_SEGMENTS = 0 8 | FULL_MASK = 1 9 | CAUSAL_MASK = 2 10 | 11 | def random_inpaint_mask( 12 | sequence: torch.Tensor, 13 | padding_masks: torch.Tensor, 14 | max_mask_segments: int = 10, 15 | mask_type_probabilities: Optional[List[float]] = None, 16 | ) -> Tuple[torch.Tensor, torch.Tensor]: 17 | """ 18 | Generates random inpainting masks for a batch of latent audio sequences. 19 | The output inpainting mask has 0 where data should be inpainted, and 1 where data is provided. 20 | 21 | Args: 22 | sequence: The input sequence tensor of shape (b, c, sequence_length). 23 | padding_masks: A tensor of shape (b, sequence_length) 24 | where 1 indicates real data latents and 0 indicates latents encoding silence padding. 25 | max_mask_segments: The maximum number of segments for the RANDOM_SEGMENTS mask type. 26 | mask_type_probabilities: A list of probabilities for choosing each mask type. 27 | The order should correspond to: 28 | [P(RANDOM_SEGMENTS), P(FULL_MASK), P(CAUSAL_MASK)]. 29 | If None, defaults to uniform probabilities. 30 | 31 | Returns: 32 | A tuple containing: 33 | - masked_sequence: The sequence with masks applied (original sequence where mask is 1, 34 | and usually 0 or a placeholder where mask is 0). 35 | - inpaint_mask: The generated inpainting mask tensor (0 for inpaint, 1 for keep). 36 | """ 37 | b, _, sequence_length = sequence.size() 38 | 39 | num_mask_types = len(MaskType) 40 | if mask_type_probabilities is None: 41 | mask_type_probabilities = [0.1, 0.8, 0.1] 42 | else: 43 | if len(mask_type_probabilities) != num_mask_types: 44 | raise ValueError( 45 | f"mask_type_probabilities must have {num_mask_types} elements, " 46 | f"one for each MaskType. Got {len(mask_type_probabilities)}." 47 | ) 48 | if not torch.isclose(torch.tensor(sum(mask_type_probabilities)), torch.tensor(1.0)): 49 | raise ValueError( 50 | f"mask_type_probabilities must sum to 1.0. 
" 51 | f"Current sum: {sum(mask_type_probabilities)}" 52 | ) 53 | 54 | output_masks_list = [] 55 | mask_types_to_sample = [mt.value for mt in MaskType] 56 | 57 | for i in range(b): 58 | padding_mask_single_item = padding_masks[i] 59 | real_sequence_length = (padding_mask_single_item == 1).sum().item() 60 | 61 | item_mask = torch.ones((1, 1, sequence_length), device=sequence.device, dtype=torch.float32) 62 | 63 | chosen_mask_value = random.choices(mask_types_to_sample, weights=mask_type_probabilities, k=1)[0] 64 | current_mask_type = MaskType(chosen_mask_value) 65 | 66 | if current_mask_type == MaskType.FULL_MASK: 67 | item_mask = torch.zeros((1, 1, sequence_length), device=sequence.device, dtype=torch.float32) 68 | elif real_sequence_length == 0: 69 | pass # item_mask remains all ones for RANDOM_SEGMENTS/CAUSAL_MASK on empty real data 70 | else: 71 | # Logic for RANDOM_SEGMENTS and CAUSAL_MASK when real_sequence_length > 0 72 | if current_mask_type == MaskType.RANDOM_SEGMENTS: 73 | num_segments = random.randint(1, max_mask_segments) 74 | # Max length for a single segment, based on average length 75 | max_len_per_segment_calc = max(1, real_sequence_length // num_segments) 76 | 77 | for _ in range(num_segments): 78 | segment_length = random.randint(1, max_len_per_segment_calc) 79 | 80 | if real_sequence_length - segment_length < 0: 81 | continue 82 | mask_start = random.randint(0, real_sequence_length - segment_length) 83 | item_mask[:, :, mask_start : mask_start + segment_length] = 0 84 | 85 | elif current_mask_type == MaskType.CAUSAL_MASK: 86 | # Keep a prefix of real data, inpaint the suffix. 87 | # The length of the unmasked prefix can be from 0 to real_sequence_length. 88 | unmasked_prefix_len = random.randint(0, real_sequence_length) 89 | 90 | if unmasked_prefix_len < real_sequence_length: 91 | item_mask[:, :, unmasked_prefix_len:real_sequence_length] = 0 92 | 93 | output_masks_list.append(item_mask) 94 | 95 | final_inpaint_mask = torch.cat(output_masks_list, dim=0).to(sequence.device) 96 | masked_sequence = sequence * final_inpaint_mask 97 | return masked_sequence, final_inpaint_mask -------------------------------------------------------------------------------- /stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "diffusion_cond", 3 | "sample_size": 12582912, 4 | "sample_rate": 44100, 5 | "audio_channels": 2, 6 | "model": { 7 | "pretransform": { 8 | "type": "autoencoder", 9 | "iterate_batch": true, 10 | "config": { 11 | "encoder": { 12 | "type": "oobleck", 13 | "config": { 14 | "in_channels": 2, 15 | "channels": 128, 16 | "c_mults": [1, 2, 4, 8, 16], 17 | "strides": [2, 4, 4, 8, 8], 18 | "latent_dim": 128, 19 | "use_snake": true 20 | } 21 | }, 22 | "decoder": { 23 | "type": "oobleck", 24 | "config": { 25 | "out_channels": 2, 26 | "channels": 128, 27 | "c_mults": [1, 2, 4, 8, 16], 28 | "strides": [2, 4, 4, 8, 8], 29 | "latent_dim": 64, 30 | "use_snake": true, 31 | "final_tanh": false 32 | } 33 | }, 34 | "bottleneck": { 35 | "type": "vae" 36 | }, 37 | "latent_dim": 64, 38 | "downsampling_ratio": 2048, 39 | "io_channels": 2 40 | } 41 | }, 42 | "conditioning": { 43 | "configs": [ 44 | { 45 | "id": "prompt", 46 | "type": "clap_text", 47 | "config": { 48 | "audio_model_type": "HTSAT-base", 49 | "enable_fusion": true, 50 | "clap_ckpt_path": "/path/to/clap.ckpt", 51 | "use_text_features": true, 52 | "feature_layer_ix": -2 53 | } 54 | }, 55 | { 56 | "id": 
"seconds_start", 57 | "type": "number", 58 | "config": { 59 | "min_val": 0, 60 | "max_val": 512 61 | } 62 | }, 63 | { 64 | "id": "seconds_total", 65 | "type": "number", 66 | "config": { 67 | "min_val": 0, 68 | "max_val": 512 69 | } 70 | } 71 | ], 72 | "cond_dim": 768 73 | }, 74 | "diffusion": { 75 | "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], 76 | "global_cond_ids": ["seconds_start", "seconds_total"], 77 | "type": "dit", 78 | "config": { 79 | "io_channels": 64, 80 | "embed_dim": 1536, 81 | "depth": 24, 82 | "num_heads": 24, 83 | "cond_token_dim": 768, 84 | "global_cond_dim": 1536, 85 | "project_cond_tokens": false, 86 | "transformer_type": "continuous_transformer" 87 | } 88 | }, 89 | "io_channels": 64 90 | }, 91 | "training": { 92 | "use_ema": true, 93 | "log_loss_info": false, 94 | "optimizer_configs": { 95 | "diffusion": { 96 | "optimizer": { 97 | "type": "AdamW", 98 | "config": { 99 | "lr": 5e-5, 100 | "betas": [0.9, 0.999], 101 | "weight_decay": 1e-3 102 | } 103 | }, 104 | "scheduler": { 105 | "type": "InverseLR", 106 | "config": { 107 | "inv_gamma": 1000000, 108 | "power": 0.5, 109 | "warmup": 0.99 110 | } 111 | } 112 | } 113 | }, 114 | "demo": { 115 | "demo_every": 2000, 116 | "demo_steps": 250, 117 | "num_demos": 4, 118 | "demo_cond": [ 119 | {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 80}, 120 | {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 250}, 121 | {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180}, 122 | {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 190} 123 | ], 124 | "demo_cfg_scales": [3, 6, 9] 125 | } 126 | } 127 | } -------------------------------------------------------------------------------- /unwrap_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from torch.nn.parameter import Parameter 4 | from stable_audio_tools.models import create_model_from_config 5 | 6 | if __name__ == '__main__': 7 | args = argparse.ArgumentParser() 8 | args.add_argument('--model-config', type=str, default=None) 9 | args.add_argument('--ckpt-path', type=str, default=None) 10 | args.add_argument('--name', type=str, default='exported_model') 11 | args.add_argument('--use-safetensors', action='store_true') 12 | 13 | args = args.parse_args() 14 | 15 | with open(args.model_config) as f: 16 | model_config = json.load(f) 17 | 18 | model = create_model_from_config(model_config) 19 | 20 | model_type = model_config.get('model_type', None) 21 | 22 | assert model_type is not None, 'model_type must be specified in model config' 23 | 24 | training_config = model_config.get('training', None) 25 | 26 | if model_type == 'autoencoder': 27 | from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper 28 | 29 | ema_copy = None 30 | 31 | if training_config.get("use_ema", False): 32 | from stable_audio_tools.models.factory import create_model_from_config 33 | ema_copy = create_model_from_config(model_config) 34 | ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once 35 | 36 | # Copy each weight to the ema copy 37 | for name, param in 
model.state_dict().items(): 38 | if isinstance(param, Parameter): 39 | # backwards compatibility for serialized parameters 40 | param = param.data 41 | ema_copy.state_dict()[name].copy_(param) 42 | 43 | use_ema = training_config.get("use_ema", False) 44 | 45 | training_wrapper = AutoencoderTrainingWrapper.load_from_checkpoint( 46 | args.ckpt_path, 47 | autoencoder=model, 48 | strict=False, 49 | loss_config=training_config["loss_configs"], 50 | use_ema=training_config["use_ema"], 51 | ema_copy=ema_copy if use_ema else None 52 | ) 53 | elif model_type == 'diffusion_uncond': 54 | from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper 55 | training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False) 56 | 57 | elif model_type == 'diffusion_autoencoder': 58 | from stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper 59 | 60 | ema_copy = create_model_from_config(model_config) 61 | 62 | for name, param in model.state_dict().items(): 63 | if isinstance(param, Parameter): 64 | # backwards compatibility for serialized parameters 65 | param = param.data 66 | ema_copy.state_dict()[name].copy_(param) 67 | 68 | training_wrapper = DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, ema_copy=ema_copy, strict=False) 69 | elif model_type in ['diffusion_cond', 'diffusion_cond_inpaint']: 70 | from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper 71 | 72 | use_ema = training_config.get("use_ema", True) 73 | 74 | training_wrapper = DiffusionCondTrainingWrapper.load_from_checkpoint( 75 | args.ckpt_path, 76 | model=model, 77 | use_ema=use_ema, 78 | lr=training_config.get("learning_rate", None), 79 | optimizer_configs=training_config.get("optimizer_configs", None), 80 | strict=False 81 | ) 82 | elif model_type == 'lm': 83 | from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper 84 | 85 | ema_copy = None 86 | 87 | if training_config.get("use_ema", False): 88 | 89 | ema_copy = create_model_from_config(model_config) 90 | 91 | for name, param in model.state_dict().items(): 92 | if isinstance(param, Parameter): 93 | # backwards compatibility for serialized parameters 94 | param = param.data 95 | ema_copy.state_dict()[name].copy_(param) 96 | 97 | training_wrapper = AudioLanguageModelTrainingWrapper.load_from_checkpoint( 98 | args.ckpt_path, 99 | model=model, 100 | strict=False, 101 | ema_copy=ema_copy, 102 | optimizer_configs=training_config.get("optimizer_configs", None) 103 | ) 104 | 105 | else: 106 | raise ValueError(f"Unknown model type {model_type}") 107 | 108 | print(f"Loaded model from {args.ckpt_path}") 109 | 110 | if args.use_safetensors: 111 | ckpt_path = f"{args.name}.safetensors" 112 | else: 113 | ckpt_path = f"{args.name}.ckpt" 114 | 115 | training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors) 116 | 117 | print(f"Exported model to {ckpt_path}") -------------------------------------------------------------------------------- /stable_audio_tools/models/fsq.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dithered Finite Scalar Quantization 3 | Code adapted from https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/finite_scalar_quantization.py 4 | """ 5 | 6 | from typing import List, Tuple 7 | import random 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import Module 12 | from torch import Tensor, 
int32 13 | from torch.amp import autocast 14 | 15 | from einops import rearrange 16 | 17 | 18 | def leaky_hard_clip(x: Tensor, alpha: float = 1e-3) -> Tensor: 19 | return (1-alpha) * torch.clamp(x, -1, 1) + alpha * x 20 | 21 | def round_ste(z: Tensor) -> Tensor: 22 | """Round with straight through gradients.""" 23 | zhat = z.round() 24 | return z + (zhat - z).detach() 25 | 26 | class DitheredFSQ(Module): 27 | def __init__( 28 | self, 29 | levels: List[int], 30 | dither_inference: bool = False, 31 | num_codebooks: int = 1, 32 | noise_dropout: float = 0.5, 33 | scale: float = 1.0, 34 | ): 35 | super().__init__() 36 | self.levels = levels 37 | 38 | _levels = torch.tensor(levels, dtype=torch.int64) 39 | self.register_buffer("_levels", _levels, persistent = False) 40 | 41 | _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int64) 42 | self.register_buffer("_basis", _basis, persistent = False) 43 | 44 | codebook_dim = len(levels) 45 | self.codebook_dim = codebook_dim 46 | 47 | self.codebook_size = _levels.prod().item() 48 | 49 | self.num_codebooks = num_codebooks 50 | 51 | self.dim = codebook_dim * num_codebooks 52 | 53 | self.dither_inference = dither_inference 54 | 55 | self.scale = scale 56 | 57 | half_l = self.scale * 2 / (self._levels - 1) 58 | self.register_buffer("half_l", half_l, persistent = False) 59 | 60 | self.allowed_dtypes = (torch.float32, torch.float64) 61 | 62 | self.noise_dropout = noise_dropout 63 | 64 | def quantize(self, z, skip_tanh: bool = False): 65 | if not skip_tanh: z = torch.tanh(z) 66 | 67 | if not self.training: 68 | quantized = self._scale_and_shift_inverse(round_ste(self._scale_and_shift(z))) 69 | else: 70 | quantized = z 71 | mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) 72 | quantized = torch.where(mask, quantized, self._scale_and_shift_inverse(round_ste(self._scale_and_shift(quantized)))) 73 | mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) 74 | quantized = torch.where(mask, quantized, z + (torch.rand_like(z) - 0.5) * self.half_l) 75 | 76 | return quantized 77 | 78 | def _scale_and_shift(self, z): 79 | level_indices = (z + 1 * self.scale) / self.half_l 80 | return level_indices 81 | 82 | def _scale_and_shift_inverse(self, level_indices): 83 | z = level_indices * self.half_l - 1 * self.scale 84 | return z 85 | 86 | def _indices_to_codes(self, indices): 87 | level_indices = self._indices_to_level_indices(indices) 88 | codes = self._scale_and_shift_inverse(level_indices) 89 | return codes 90 | 91 | def _codes_to_indices(self, zhat): 92 | zhat = self._scale_and_shift(zhat) 93 | zhat = zhat.round().to(torch.int64) 94 | out = (zhat * self._basis).sum(dim=-1) 95 | return out 96 | 97 | def _indices_to_level_indices(self, indices): 98 | indices = rearrange(indices, '... -> ... 1') 99 | codes_non_centered = (indices // self._basis) % self._levels 100 | return codes_non_centered 101 | 102 | def indices_to_codes(self, indices): 103 | # Expects input of batch x sequence x num_codebooks 104 | assert indices.shape[-1] == self.num_codebooks, f'expected last dimension of {self.num_codebooks} but found last dimension of {indices.shape[-1]}' 105 | codes = self._indices_to_codes(indices.to(torch.int64)) 106 | codes = rearrange(codes, '... c d -> ... 
(c d)') 107 | return codes 108 | 109 | @autocast(device_type="cuda", enabled = False) 110 | def forward(self, z, skip_tanh: bool = False): 111 | 112 | orig_dtype = z.dtype 113 | 114 | assert z.shape[-1] == self.dim, f'expected dimension of {self.num_codebooks * self.dim} but found dimension of {z.shape[-1]}' 115 | 116 | z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) 117 | 118 | # make sure allowed dtype before quantizing 119 | 120 | if z.dtype not in self.allowed_dtypes: 121 | z = z.to(torch.float64) 122 | 123 | codes = self.quantize(z, skip_tanh=skip_tanh) 124 | indices = self._codes_to_indices(codes) 125 | codes = rearrange(codes, 'b n c d -> b n (c d)') 126 | 127 | # cast codes back to original dtype 128 | 129 | if codes.dtype != orig_dtype: 130 | codes = codes.type(orig_dtype) 131 | 132 | # return quantized output and indices 133 | 134 | return codes, indices 135 | -------------------------------------------------------------------------------- /stable_audio_tools/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file 3 | 4 | from torch.nn.utils import remove_weight_norm 5 | 6 | def copy_state_dict(model, state_dict): 7 | """Load state_dict to model, but only for keys that match exactly. 8 | 9 | Args: 10 | model (nn.Module): model to load state_dict. 11 | state_dict (OrderedDict): state_dict to load. 12 | """ 13 | model_state_dict = model.state_dict() 14 | for key in state_dict: 15 | if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape: 16 | if isinstance(state_dict[key], torch.nn.Parameter): 17 | # backwards compatibility for serialized parameters 18 | state_dict[key] = state_dict[key].data 19 | model_state_dict[key] = state_dict[key] 20 | 21 | model.load_state_dict(model_state_dict, strict=False) 22 | 23 | def load_ckpt_state_dict(ckpt_path): 24 | if ckpt_path.endswith(".safetensors"): 25 | state_dict = load_file(ckpt_path) 26 | else: 27 | state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)["state_dict"] 28 | 29 | return state_dict 30 | 31 | def remove_weight_norm_from_model(model): 32 | for module in model.modules(): 33 | if hasattr(module, "weight"): 34 | print(f"Removing weight norm from {module}") 35 | remove_weight_norm(module) 36 | 37 | return model 38 | 39 | try: 40 | torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit) 41 | torch._dynamo.config.suppress_errors = True 42 | except Exception as e: 43 | pass 44 | 45 | # Get torch.compile flag from environment variable ENABLE_TORCH_COMPILE 46 | 47 | import os 48 | enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1" 49 | 50 | def compile(function, *args, **kwargs): 51 | 52 | if enable_torch_compile: 53 | try: 54 | return torch.compile(function, *args, **kwargs) 55 | except RuntimeError: 56 | return function 57 | 58 | return function 59 | 60 | # Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license 61 | # License can be found in LICENSES/LICENSE_META.txt 62 | 63 | def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): 64 | """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. 65 | 66 | Args: 67 | input (torch.Tensor): The input tensor containing probabilities. 68 | num_samples (int): Number of samples to draw. 
69 | replacement (bool): Whether to draw with replacement or not. 70 | Keywords args: 71 | generator (torch.Generator): A pseudorandom number generator for sampling. 72 | Returns: 73 | torch.Tensor: Last dimension contains num_samples indices 74 | sampled from the multinomial probability distribution 75 | located in the last dimension of tensor input. 76 | """ 77 | 78 | if num_samples == 1: 79 | q = torch.empty_like(input).exponential_(1, generator=generator) 80 | return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) 81 | 82 | input_ = input.reshape(-1, input.shape[-1]) 83 | output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) 84 | output = output_.reshape(*list(input.shape[:-1]), -1) 85 | return output 86 | 87 | 88 | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: 89 | """Sample next token from top K values along the last dimension of the input probs tensor. 90 | 91 | Args: 92 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 93 | k (int): The k in “top-k”. 94 | Returns: 95 | torch.Tensor: Sampled tokens. 96 | """ 97 | top_k_value, _ = torch.topk(probs, k, dim=-1) 98 | min_value_top_k = top_k_value[..., [-1]] 99 | probs *= (probs >= min_value_top_k).float() 100 | probs.div_(probs.sum(dim=-1, keepdim=True)) 101 | next_token = multinomial(probs, num_samples=1) 102 | return next_token 103 | 104 | 105 | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: 106 | """Sample next token from top P probabilities along the last dimension of the input probs tensor. 107 | 108 | Args: 109 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 110 | p (int): The p in “top-p”. 111 | Returns: 112 | torch.Tensor: Sampled tokens. 
113 | """ 114 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 115 | probs_sum = torch.cumsum(probs_sort, dim=-1) 116 | mask = probs_sum - probs_sort > p 117 | probs_sort *= (~mask).float() 118 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 119 | next_token = multinomial(probs_sort, num_samples=1) 120 | next_token = torch.gather(probs_idx, -1, next_token) 121 | return next_token 122 | 123 | def next_power_of_two(n): 124 | return 2 ** (n - 1).bit_length() 125 | 126 | def next_multiple_of_64(n): 127 | return ((n + 63) // 64) * 64 -------------------------------------------------------------------------------- /stable_audio_tools/data/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | 5 | from torch import nn 6 | from typing import Tuple 7 | 8 | from torchaudio import transforms as T 9 | 10 | class PadCrop(nn.Module): 11 | def __init__(self, n_samples, randomize=True): 12 | super().__init__() 13 | self.n_samples = n_samples 14 | self.randomize = randomize 15 | 16 | def __call__(self, signal): 17 | n, s = signal.shape 18 | start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() 19 | end = start + self.n_samples 20 | output = signal.new_zeros([n, self.n_samples]) 21 | output[:, :min(s, self.n_samples)] = signal[:, start:end] 22 | return output 23 | 24 | class PadCrop_Normalized_T(nn.Module): 25 | 26 | def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): 27 | 28 | super().__init__() 29 | 30 | self.n_samples = n_samples 31 | self.sample_rate = sample_rate 32 | self.randomize = randomize 33 | 34 | def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: 35 | 36 | n_channels, n_samples = source.shape 37 | 38 | # If the audio is shorter than the desired length, pad it 39 | upper_bound = max(0, n_samples - self.n_samples) 40 | 41 | # If randomize is False, always start at the beginning of the audio 42 | offset = 0 43 | if(self.randomize and n_samples > self.n_samples): 44 | offset = random.randint(0, upper_bound) 45 | 46 | # Calculate the start and end times of the chunk 47 | t_start = offset / (upper_bound + self.n_samples) 48 | t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) 49 | 50 | # Create the chunk 51 | chunk = source.new_zeros([n_channels, self.n_samples]) 52 | 53 | # Copy the audio into the chunk 54 | chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] 55 | 56 | # Calculate the start and end times of the chunk in seconds 57 | seconds_start = math.floor(offset / self.sample_rate) 58 | seconds_total = math.ceil(n_samples / self.sample_rate) 59 | 60 | # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't 61 | padding_mask = torch.zeros([self.n_samples]) 62 | padding_mask[:min(n_samples, self.n_samples)] = 1 63 | 64 | 65 | return ( 66 | chunk, 67 | t_start, 68 | t_end, 69 | seconds_start, 70 | seconds_total, 71 | padding_mask 72 | ) 73 | 74 | class PhaseFlipper(nn.Module): 75 | "Randomly invert the phase of a signal" 76 | def __init__(self, p=0.5): 77 | super().__init__() 78 | self.p = p 79 | def __call__(self, signal): 80 | return -signal if (random.random() < self.p) else signal 81 | 82 | class Mono(nn.Module): 83 | def __call__(self, signal): 84 | return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal 85 | 86 | class Stereo(nn.Module): 87 | def 
__call__(self, signal): 88 | signal_shape = signal.shape 89 | # Check if it's mono 90 | if len(signal_shape) == 1: # s -> 2, s 91 | signal = signal.unsqueeze(0).repeat(2, 1) 92 | elif len(signal_shape) == 2: 93 | if signal_shape[0] == 1: #1, s -> 2, s 94 | signal = signal.repeat(2, 1) 95 | elif signal_shape[0] > 2: #?, s -> 2,s 96 | signal = signal[:2, :] 97 | 98 | return signal 99 | 100 | class VolumeNorm(nn.Module): 101 | "Volume normalization and augmentation of a signal [LUFS standard]" 102 | def __init__(self, params=[-16, 2], sample_rate=16000, energy_threshold=1e-6): 103 | super().__init__() 104 | self.loudness = T.Loudness(sample_rate) 105 | self.value = params[0] 106 | self.gain_range = [-params[1], params[1]] 107 | self.energy_threshold = energy_threshold 108 | 109 | def __call__(self, signal): 110 | """ 111 | signal: torch.Tensor [channels, time] 112 | """ 113 | # avoid do normalisation for silence 114 | energy = torch.mean(signal**2) 115 | if energy < self.energy_threshold: 116 | return signal 117 | 118 | input_loudness = self.loudness(signal) 119 | # Generate a random target loudness within the specified range 120 | target_loudness = self.value + (torch.rand(1).item() * (self.gain_range[1] - self.gain_range[0]) + self.gain_range[0]) 121 | delta_loudness = target_loudness - input_loudness 122 | gain = torch.pow(10.0, delta_loudness / 20.0) 123 | output = gain * signal 124 | 125 | # Check for potentially clipped samples 126 | if torch.max(torch.abs(output)) >= 1.0: 127 | output = self.declip(output) 128 | 129 | return output 130 | 131 | def declip(self, signal): 132 | """ 133 | Declip the signal by scaling down if any samples are clipped 134 | """ 135 | max_val = torch.max(torch.abs(signal)) 136 | if max_val > 1.0: 137 | signal = signal / max_val 138 | signal *= 0.95 139 | return signal 140 | 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /docs/datasets.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | `stable-audio-tools` supports loading data from local file storage, as well as loading audio files and JSON files in the [WebDataset](https://github.com/webdataset/webdataset/tree/main/webdataset) format from Amazon S3 buckets. 3 | 4 | # Dataset configs 5 | To specify the dataset used for training, you must provide a dataset config JSON file to `train.py`. 6 | 7 | The dataset config consists of a `dataset_type` property specifying the type of data loader to use, a `datasets` array to provide multiple data sources, and a `random_crop` property, which decides if the cropped audio from the training samples is from a random place in the audio file, or always from the beginning. 8 | 9 | ## Local audio files 10 | To use a local directory of audio samples, set the `dataset_type` property in your dataset config to `"audio_dir"`, and provide a list of objects to the `datasets` property including the `path` property, which should be the path to your directory of audio samples. 11 | 12 | This will load all of the compatible audio files from the provided directory and all subdirectories. 
13 | 14 | ### Example config 15 | ```json 16 | { 17 | "dataset_type": "audio_dir", 18 | "datasets": [ 19 | { 20 | "id": "my_audio", 21 | "path": "/path/to/audio/dataset/" 22 | } 23 | ], 24 | "random_crop": true 25 | } 26 | ``` 27 | 28 | ## S3 WebDataset 29 | To load audio files and related metadata from .tar files in the WebDataset format hosted in Amazon S3 buckets, you can set the `dataset_type` property to `s3`, and provide the `datasets` parameter with a list of objects containing the AWS S3 path to the shared S3 bucket prefix of the WebDataset .tar files. The S3 bucket will be searched recursively given the path, and assumes any .tar files found contain audio files and corresponding JSON files where the related files differ only in file extension (e.g. "000001.flac", "000001.json", "00002.flac", "00002.json", etc.) 30 | 31 | ### Example config 32 | ```json 33 | { 34 | "dataset_type": "s3", 35 | "datasets": [ 36 | { 37 | "id": "s3-test", 38 | "s3_path": "s3://my-bucket/datasets/webdataset/audio/" 39 | } 40 | ], 41 | "random_crop": true 42 | } 43 | ``` 44 | 45 | ## Pre Encoded Datasets 46 | To use pre encoded latents created with the [pre encoding script](pre_encoding.md), set the `dataset_type` property to `"pre_encoded"`, and provide the path to the directory containing the pre encoded `.npy` latent files and corresponding `.json` metadata files. 47 | 48 | You can optionally specify a `latent_crop_length` in latent units (latent length = `audio_samples // 2048`) to crop the pre encoded latents to a smaller length than you encoded to. If not specified, uses the full pre encoded length. When `random_crop` is set to true, it will randomly crop from the sequence at your desired `latent_crop_length` while taking padding into account. 49 | 50 | **Note**: `random_crop` does not currently update `seconds_start`, so it will be inaccurate when used to train or fine-tune models with that condition (e.g. `stable-audio-open-1.0`), but can be used with models that do not use `seconds_start` (e.g. `stable-audio-open-small`). 51 | 52 | ### Example config 53 | ```json 54 | { 55 | "dataset_type": "pre_encoded", 56 | "datasets": [ 57 | { 58 | "id": "my_pre_encoded_audio", 59 | "path": "/path/to/pre_encoded/output/", 60 | "latent_crop_length": 512, 61 | "custom_metadata_module": "/path/to/custom_metadata.py" 62 | } 63 | ], 64 | "random_crop": true 65 | } 66 | ``` 67 | 68 | For information on creating pre encoded datasets, see [Pre Encoding](pre_encoding.md). 69 | 70 | # Custom metadata 71 | To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module to the dataset config. This metadata module should be a Python file that must contain a function called `get_custom_metadata` that takes in two parameters, `info`, and `audio`, and returns a dictionary. 72 | 73 | For local training, the `info` parameter will contain a few pieces of information about the loaded audio file, such as the path, and information about how the audio was cropped from the original training sample. For WebDataset datasets, it will also contain the metadata from the related JSON files. 74 | 75 | The `audio` parameter contains the audio sample that will be passed to the model at training time. This lets you analyze the audio for extra properties that you can then pass in as extra conditioning signals. 76 | 77 | The dictionary returned from the `get_custom_metadata` function will have its properties added to the `metadata` object used at training time. 
For more information on how conditioning works, please see the [Conditioning documentation](./conditioning.md) 78 | 79 | ## Example config and custom metadata module 80 | ```json 81 | { 82 | "dataset_type": "audio_dir", 83 | "datasets": [ 84 | { 85 | "id": "my_audio", 86 | "path": "/path/to/audio/dataset/", 87 | "custom_metadata_module": "/path/to/custom_metadata.py", 88 | } 89 | ], 90 | "random_crop": true 91 | } 92 | ``` 93 | 94 | `custom_metadata.py`: 95 | ```py 96 | def get_custom_metadata(info, audio): 97 | 98 | # Pass in the relative path of the audio file as the prompt 99 | return {"prompt": info["relpath"]} 100 | ``` -------------------------------------------------------------------------------- /stable_audio_tools/training/utils.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.loggers import WandbLogger, CometLogger 2 | from ..interface.aeiou import pca_point_cloud 3 | 4 | import wandb 5 | import torch 6 | import os 7 | 8 | def get_rank(): 9 | """Get rank of current process.""" 10 | 11 | print(os.environ.keys()) 12 | 13 | if "SLURM_PROCID" in os.environ: 14 | return int(os.environ["SLURM_PROCID"]) 15 | 16 | if not torch.distributed.is_available() or not torch.distributed.is_initialized(): 17 | return 0 18 | 19 | return torch.distributed.get_rank() 20 | 21 | class InverseLR(torch.optim.lr_scheduler._LRScheduler): 22 | """Implements an inverse decay learning rate schedule with an optional exponential 23 | warmup. When last_epoch=-1, sets initial lr as lr. 24 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 25 | (1 / 2)**power of its original value. 26 | Args: 27 | optimizer (Optimizer): Wrapped optimizer. 28 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 29 | power (float): Exponential factor of learning rate decay. Default: 1. 30 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 31 | Default: 0. 32 | final_lr (float): The final learning rate. Default: 0. 33 | last_epoch (int): The index of last epoch. Default: -1. 34 | """ 35 | 36 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., 37 | last_epoch=-1): 38 | self.inv_gamma = inv_gamma 39 | self.power = power 40 | if not 0. <= warmup < 1: 41 | raise ValueError('Invalid value for warmup') 42 | self.warmup = warmup 43 | self.final_lr = final_lr 44 | super().__init__(optimizer, last_epoch) 45 | 46 | def get_lr(self): 47 | if not self._get_lr_called_within_step: 48 | import warnings 49 | warnings.warn("To get the last learning rate computed by the scheduler, " 50 | "please use `get_last_lr()`.") 51 | 52 | return self._get_closed_form_lr() 53 | 54 | def _get_closed_form_lr(self): 55 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 56 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 57 | return [warmup * max(self.final_lr, base_lr * lr_mult) 58 | for base_lr in self.base_lrs] 59 | 60 | def create_optimizer_from_config(optimizer_config, parameters): 61 | """Create optimizer from config. 62 | 63 | Args: 64 | parameters (iterable): parameters to optimize. 65 | optimizer_config (dict): optimizer config. 66 | 67 | Returns: 68 | torch.optim.Optimizer: optimizer. 
69 | """ 70 | 71 | optimizer_type = optimizer_config["type"] 72 | 73 | if optimizer_type == "FusedAdam": 74 | from deepspeed.ops.adam import FusedAdam 75 | optimizer = FusedAdam(parameters, **optimizer_config["config"]) 76 | else: 77 | optimizer_fn = getattr(torch.optim, optimizer_type) 78 | optimizer = optimizer_fn(parameters, **optimizer_config["config"]) 79 | return optimizer 80 | 81 | def create_scheduler_from_config(scheduler_config, optimizer): 82 | """Create scheduler from config. 83 | 84 | Args: 85 | scheduler_config (dict): scheduler config. 86 | optimizer (torch.optim.Optimizer): optimizer. 87 | 88 | Returns: 89 | torch.optim.lr_scheduler._LRScheduler: scheduler. 90 | """ 91 | if scheduler_config["type"] == "InverseLR": 92 | scheduler_fn = InverseLR 93 | else: 94 | scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) 95 | scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) 96 | return scheduler 97 | 98 | def logger_project_name(logger) -> str: 99 | if isinstance(logger, WandbLogger): 100 | return logger.experiment.project 101 | elif isinstance(logger, CometLogger): 102 | return logger.name 103 | 104 | def log_metric(logger, key, value, step=None): 105 | from pytorch_lightning.loggers import WandbLogger, CometLogger 106 | if isinstance(logger, WandbLogger): 107 | logger.experiment.log({key: value}) 108 | elif isinstance(logger, CometLogger): 109 | logger.experiment.log_metrics({key: value}, step=step) 110 | 111 | def log_audio(logger, key, audio_path, sample_rate, caption=None): 112 | if isinstance(logger, WandbLogger): 113 | logger.experiment.log({key: wandb.Audio(audio_path, sample_rate=sample_rate, caption=caption)}) 114 | elif isinstance(logger, CometLogger): 115 | logger.experiment.log_audio(audio_path, file_name=key, sample_rate=sample_rate) 116 | 117 | def log_image(logger, key, img_data): 118 | if isinstance(logger, WandbLogger): 119 | logger.experiment.log({key: wandb.Image(img_data)}) 120 | elif isinstance(logger, CometLogger): 121 | logger.experiment.log_image(img_data, name=key) 122 | 123 | def log_point_cloud(logger, key, tokens, caption=None): 124 | if isinstance(logger, WandbLogger): 125 | point_cloud = pca_point_cloud(tokens) 126 | logger.experiment.log({key: point_cloud}) 127 | elif isinstance(logger, CometLogger): 128 | point_cloud = pca_point_cloud(tokens, rgb_float=True, output_type="points") 129 | #logger.experiment.log_points_3d(scene_name=key, points=point_cloud) -------------------------------------------------------------------------------- /stable_audio_tools/models/convnext.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn.utils import weight_norm 5 | 6 | def WNConv1d(*args, **kwargs): 7 | return weight_norm(nn.Conv1d(*args, **kwargs)) 8 | 9 | def WNConvTranspose1d(*args, **kwargs): 10 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 11 | 12 | def checkpoint(function, *args, **kwargs): 13 | kwargs.setdefault("use_reentrant", False) 14 | return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) 15 | 16 | class ConvNeXtBlock(nn.Module): 17 | 18 | def __init__(self, dim, kernel_size=7, mult=4, glu=False): 19 | super().__init__() 20 | padding = kernel_size // 2 21 | self.dw_conv = WNConv1d(dim, dim, kernel_size=kernel_size, padding=padding, groups=dim) 22 | 23 | self.glu = glu 24 | 25 | if glu: 26 | self.proj_up = WNConv1d(dim, dim * mult * 2, kernel_size=1) 27 | self.act = nn.SiLU() 28 | else: 
29 | self.proj_up = WNConv1d(dim, dim * mult, kernel_size=1) 30 | self.act = nn.GELU() 31 | 32 | self.proj_down = WNConv1d(dim * mult, dim, kernel_size=1) 33 | 34 | # Zero-init the last conv layer 35 | nn.init.zeros_(self.proj_down.weight) 36 | nn.init.zeros_(self.proj_down.bias) 37 | 38 | 39 | 40 | def forward(self, x): 41 | input = x 42 | x = self.dw_conv(x) 43 | x = self.proj_up(x) 44 | if self.glu: 45 | x, gate = x.chunk(2, dim=1) 46 | x = x * torch.sigmoid(gate) 47 | x = self.act(x) 48 | x = self.proj_down(x) 49 | 50 | return x + input 51 | 52 | class ConvNextEncoderBlock(nn.Module): 53 | def __init__(self, in_channels, out_channels, stride, num_blocks=3, conv_args = {}): 54 | super().__init__() 55 | 56 | self.layers = nn.ModuleList([ConvNeXtBlock(dim=in_channels, **conv_args) for _ in range(num_blocks)]) 57 | 58 | self.downsample = WNConv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) 59 | 60 | def forward(self, x): 61 | for layer in self.layers: 62 | x = checkpoint(layer, x) 63 | 64 | x = self.downsample(x) 65 | 66 | return x 67 | 68 | class ConvNextDecoderBlock(nn.Module): 69 | def __init__(self, in_channels, out_channels, stride, num_blocks=3, conv_args = {}): 70 | super().__init__() 71 | 72 | 73 | self.upsample = WNConvTranspose1d(in_channels=in_channels, 74 | out_channels=out_channels, 75 | kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) 76 | 77 | self.layers = nn.ModuleList([ConvNeXtBlock(dim=out_channels, **conv_args) for _ in range(num_blocks)]) 78 | 79 | def forward(self, x): 80 | x = self.upsample(x) 81 | for layer in self.layers: 82 | x = checkpoint(layer, x) 83 | return x 84 | 85 | class ConvNeXtEncoder(nn.Module): 86 | def __init__(self, 87 | in_channels=2, 88 | channels=128, 89 | latent_dim=32, 90 | c_mults = [1, 2, 4, 8], 91 | strides = [2, 4, 8, 8], 92 | num_blocks = None, 93 | conv_args = {}, 94 | ): 95 | super().__init__() 96 | 97 | c_mults = [1] + c_mults 98 | 99 | if num_blocks is None: 100 | num_blocks = [3] * (len(c_mults)-1) 101 | 102 | self.depth = len(c_mults) 103 | 104 | self.proj_in = WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) 105 | 106 | layers = [] 107 | 108 | for i in range(self.depth-1): 109 | layers += [ConvNextEncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], conv_args=conv_args, num_blocks=num_blocks[i])] 110 | 111 | layers += [ 112 | WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) 113 | ] 114 | 115 | self.layers = nn.Sequential(*layers) 116 | 117 | def forward(self, x, return_features=False): 118 | 119 | x = self.proj_in(x) 120 | 121 | return self.layers(x) 122 | 123 | 124 | class ConvNeXtDecoder(nn.Module): 125 | def __init__(self, 126 | out_channels=2, 127 | channels=128, 128 | latent_dim=32, 129 | c_mults = [1, 2, 4, 8], 130 | strides = [2, 4, 8, 8], 131 | conv_args = {}): 132 | super().__init__() 133 | 134 | c_mults = [1] + c_mults 135 | 136 | self.depth = len(c_mults) 137 | 138 | layers = [ 139 | WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3) 140 | ] 141 | 142 | for i in range(self.depth-1, 0, -1): 143 | layers += [ConvNextDecoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i-1]*channels, stride=strides[i-1], conv_args=conv_args)] 144 | 145 | layers += [WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, 
padding=3, bias=False)] 146 | 147 | self.layers = nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | return self.layers(x) -------------------------------------------------------------------------------- /stable_audio_tools/training/losses/semantic.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import audiotools 4 | import torch 5 | import torchaudio 6 | 7 | from einops import rearrange 8 | from torch.nn import functional as F 9 | from torch import nn 10 | 11 | def fold_channels_into_batch(x): 12 | x = rearrange(x, 'b c ... -> (b c) ...') 13 | return x 14 | 15 | class HubertLoss(nn.Module): 16 | def __init__(self, 17 | feature_ids: tp.Optional[tp.List[int]] = None, 18 | weight: float = 1.0, 19 | model_name: str = "HUBERT_LARGE" 20 | ): 21 | super().__init__() 22 | 23 | self.weight = weight 24 | self.feature_ids = feature_ids 25 | self.model_name = model_name 26 | 27 | # Load model based on the specified model name 28 | if self.model_name == "WAVLM_LARGE": 29 | bundle = torchaudio.pipelines.WAVLM_LARGE 30 | elif self.model_name == "HUBERT_LARGE": 31 | bundle = torchaudio.pipelines.HUBERT_LARGE 32 | elif self.model_name == "WAV2VEC2_LARGE_LV60K": 33 | bundle = torchaudio.pipelines.WAV2VEC2_LARGE_LV60K 34 | else: 35 | raise ValueError(f"Unsupported model_name: {self.model_name}") 36 | 37 | self.model = bundle.get_model() 38 | 39 | for param in self.model.parameters(): 40 | param.requires_grad = False 41 | 42 | def forward(self, x, y): 43 | x = fold_channels_into_batch(x) 44 | y = fold_channels_into_batch(y) 45 | 46 | conv_features = ( 47 | self.feature_ids is not None and 48 | len(self.feature_ids) == 1 and 49 | self.feature_ids[0] == -1) 50 | 51 | # Extract features from conv layers only. 52 | if conv_features: 53 | if self.model.normalize_waveform: 54 | x = nn.functional.layer_norm(x, x.shape) 55 | y = nn.functional.layer_norm(y, y.shape) 56 | x_list, _ = self.model.model.feature_extractor(x, None) 57 | y_list, _ = self.model.model.feature_extractor(y, None) 58 | x_list = [x_list] 59 | y_list = [y_list] 60 | else: 61 | x_list, _ = self.model.extract_features(x) 62 | y_list, _ = self.model.extract_features(y) 63 | 64 | loss = 0 65 | denom = 0 66 | for i, (x, y) in enumerate(zip(x_list, y_list)): 67 | if self.feature_ids is None or i in self.feature_ids or conv_features: 68 | loss += F.l1_loss(x, y) / (y.std() + 1e-5) 69 | denom += 1 70 | 71 | loss = loss / denom 72 | return self.weight * loss 73 | 74 | # Implementation taken from: 75 | # https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/nn/loss.py#L231 76 | class MelSpectrogramLoss(nn.Module): 77 | """Compute distance between mel spectrograms. Can be used 78 | in a multi-scale way. 
79 | 80 | Parameters 81 | ---------- 82 | n_mels : List[int] 83 | Number of mels per STFT, by default [150, 80], 84 | window_lengths : List[int], optional 85 | Length of each window of each STFT, by default [2048, 512] 86 | loss_fn : typing.Callable, optional 87 | How to compare each loss, by default nn.L1Loss() 88 | clamp_eps : float, optional 89 | Clamp on the log magnitude, below, by default 1e-5 90 | mag_weight : float, optional 91 | Weight of raw magnitude portion of loss, by default 1.0 92 | log_weight : float, optional 93 | Weight of log magnitude portion of loss, by default 1.0 94 | pow : float, optional 95 | Power to raise magnitude to before taking log, by default 2.0 96 | weight : float, optional 97 | Weight of this loss, by default 1.0 98 | """ 99 | 100 | def __init__(self, sample_rate: int, 101 | n_mels: tp.List[int], 102 | window_lengths: tp.List[int], 103 | loss_fn: tp.Callable = nn.L1Loss(), 104 | clamp_eps: float = 1e-5, 105 | mag_weight: float = 1.0, 106 | log_weight: float = 1.0, 107 | pow: float = 2.0, 108 | weight: float = 1.0, 109 | mel_fmin: tp.Optional[tp.List[float]] = None, 110 | mel_fmax: tp.Optional[tp.List[float]] = None, 111 | window_type: tp.Optional[str] = None, 112 | ): 113 | super().__init__() 114 | self.stft_params = [{ 115 | "window_length": w, 116 | "hop_length": w // 4, 117 | "window_type": window_type, 118 | } for w in window_lengths] 119 | 120 | self.sample_rate = sample_rate 121 | self.n_mels = n_mels 122 | self.loss_fn = loss_fn 123 | self.clamp_eps = clamp_eps 124 | self.log_weight = log_weight 125 | self.mag_weight = mag_weight 126 | self.weight = weight 127 | self.pow = pow 128 | 129 | self.mel_fmin = ( 130 | mel_fmin 131 | if mel_fmin is not None else 132 | [0.0 for _ in range(len(window_lengths))] 133 | ) 134 | self.mel_fmax = ( 135 | mel_fmax 136 | if mel_fmin is not None else 137 | [None for _ in range(len(window_lengths))] 138 | ) 139 | 140 | def forward(self, x: torch.Tensor, y: torch.Tensor): 141 | x = audiotools.AudioSignal(x, self.sample_rate) 142 | y = audiotools.AudioSignal(y, self.sample_rate) 143 | 144 | loss = 0.0 145 | for n_mels, fmin, fmax, params in zip( 146 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params, 147 | ): 148 | x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **params) 149 | y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **params) 150 | 151 | loss += self.log_weight * self.loss_fn( 152 | x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 153 | y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), 154 | ) 155 | loss += self.mag_weight * self.loss_fn(x_mels, y_mels) 156 | return loss 157 | -------------------------------------------------------------------------------- /stable_audio_tools/training/losses/losses.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import torch 3 | 4 | from torch.nn import functional as F 5 | from torch import nn 6 | from .utils import mmd 7 | 8 | class LossModule(nn.Module): 9 | def __init__(self, name: str, weight: float = 1.0, decay = 1.0): 10 | super().__init__() 11 | 12 | self.name = name 13 | self.decay = float(decay) 14 | self.master_weight = weight 15 | weight = torch.tensor(float(weight)) 16 | self.register_buffer('weight', weight) 17 | 18 | def decay_weight(self): 19 | if self.decay != 1.0: 20 | self.weight *= self.decay 21 | elif self.decay == 1.0 and self.weight != self.master_weight: 22 | self.weight = torch.tensor(self.master_weight, dtype=self.weight.dtype, 
device=self.weight.device) 23 | def forward(self, info, *args, **kwargs): 24 | raise NotImplementedError 25 | 26 | class ValueLoss(LossModule): 27 | def __init__(self, key: str, name, weight: float = 1.0, decay = 1.0): 28 | super().__init__(name=name, weight=weight, decay=decay) 29 | 30 | self.key = key 31 | 32 | def forward(self, info): 33 | self.decay_weight() 34 | return self.weight * info[self.key] 35 | 36 | class TargetValueLoss(LossModule): 37 | def __init__(self, key: str, target: float, name: str, weight: float = 1.0): 38 | super().__init__(name=name, weight=weight) 39 | 40 | self.key = key 41 | self.target = target 42 | 43 | def forward(self, info): 44 | self.decay_weight() 45 | return self.weight * (info[self.key] - self.target).abs() 46 | 47 | class L1Loss(LossModule): 48 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss', decay = 1.0): 49 | super().__init__(name=name, weight=weight, decay=decay) 50 | 51 | self.key_a = key_a 52 | self.key_b = key_b 53 | 54 | self.mask_key = mask_key 55 | 56 | def forward(self, info): 57 | l1_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none') 58 | 59 | if self.mask_key is not None and self.mask_key in info: 60 | l1_loss = l1_loss[info[self.mask_key]] 61 | 62 | l1_loss = l1_loss.mean() 63 | self.decay_weight() 64 | return self.weight * l1_loss 65 | 66 | class MSELoss(LossModule): 67 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss',decay = 1.0): 68 | super().__init__(name=name, weight=weight, decay=decay) 69 | 70 | self.key_a = key_a 71 | self.key_b = key_b 72 | 73 | self.mask_key = mask_key 74 | 75 | def forward(self, info): 76 | mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none') 77 | 78 | if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None: 79 | mask = info[self.mask_key] 80 | 81 | if mask.ndim == 2 and mse_loss.ndim == 3: 82 | mask = mask.unsqueeze(1) 83 | 84 | if mask.shape[1] != mse_loss.shape[1]: 85 | mask = mask.repeat(1, mse_loss.shape[1], 1) 86 | 87 | mse_loss = mse_loss[mask] 88 | 89 | mse_loss = mse_loss.mean() 90 | self.decay_weight() 91 | return self.weight * mse_loss 92 | 93 | class LossWithTarget(LossModule): 94 | def __init__(self, loss_module, input_key: str, target_key: str, name: str, weight: float = 1, decay = 1.0): 95 | super().__init__(name, weight, decay) 96 | 97 | self.loss_module = loss_module 98 | 99 | self.input_key = input_key 100 | self.target_key = target_key 101 | 102 | def forward(self, info): 103 | loss = self.loss_module(info[self.input_key], info[self.target_key]) 104 | self.decay_weight() 105 | return self.weight * loss 106 | 107 | class AuralossLoss(LossWithTarget): 108 | def __init__(self, loss_module, input_key: str, target_key: str, name: str, weight: float = 1, decay = 1.0): 109 | super().__init__(loss_module, input_key=input_key, target_key=target_key, name=name, weight=weight, decay=decay) 110 | def forward(self, info): 111 | loss = self.loss_module(info[self.target_key], info[self.input_key]) # Enforce wrong order of input and target until we find issue in Auraloss 112 | self.decay_weight() 113 | return self.weight * loss 114 | 115 | class MultiLoss(nn.Module): 116 | def __init__(self, losses: tp.List[LossModule]): 117 | super().__init__() 118 | 119 | self.losses = nn.ModuleList(losses) 120 | 121 | def forward(self, info): 122 | total_loss = 0 123 | 124 | losses = {} 125 | 126 | for loss_module in 
self.losses: 127 | module_loss = loss_module(info) 128 | total_loss += module_loss 129 | losses[loss_module.name] = module_loss 130 | 131 | return total_loss, losses 132 | 133 | class StereoImageLoss(LossModule): 134 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'stereo_image_loss', decay = 1.0): 135 | super().__init__(name=name, weight=weight, decay=decay) 136 | 137 | self.key_a = key_a 138 | self.key_b = key_b 139 | 140 | self.mask_key = mask_key 141 | 142 | def forward(self, info): 143 | loss = 0.5*(1 - F.cosine_similarity(info[self.key_a], info[self.key_b], dim=1)) 144 | 145 | if self.mask_key is not None and self.mask_key in info: 146 | loss = loss[info[self.mask_key]] 147 | 148 | loss = loss.mean() 149 | self.decay_weight() 150 | return self.weight * loss 151 | 152 | class TimeDomainMMDLoss(LossModule): 153 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, name: str = 'time_domain_mmd_loss', decay = 1.0): 154 | super().__init__(name=name, weight=weight, decay=decay) 155 | 156 | self.key_a = key_a 157 | self.key_b = key_b 158 | 159 | def forward(self, info): 160 | loss = mmd(info[self.key_a], info[self.key_b], bandwidths=[0.0001,0.001,0.01,0.1,1], dim=-1) 161 | self.decay_weight() 162 | return self.weight * loss -------------------------------------------------------------------------------- /stable_audio_tools/models/lm_backbone.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from x_transformers import ContinuousTransformerWrapper, Decoder 3 | 4 | from .transformer import ContinuousTransformer 5 | 6 | # Interface for backbone of a language model 7 | # Handles conditioning and cross-attention 8 | # Does not have to deal with patterns or quantizer heads 9 | class AudioLMBackbone(nn.Module): 10 | def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs): 11 | super().__init__() 12 | 13 | self.embed_dim = embed_dim 14 | self.use_generation_cache = use_generation_cache 15 | 16 | def forward( 17 | self, 18 | x, 19 | cross_attn_cond=None, 20 | prepend_cond=None, 21 | prepend_cond_mask=None, 22 | global_cond=None, 23 | use_cache=False, 24 | **kwargs 25 | ): 26 | raise NotImplementedError 27 | 28 | def reset_generation_cache( 29 | self, 30 | max_seq_len, 31 | batch_size, 32 | dtype=None 33 | ): 34 | pass 35 | 36 | def update_generation_cache( 37 | self, 38 | seqlen_offset 39 | ): 40 | pass 41 | 42 | class XTransformersAudioLMBackbone(AudioLMBackbone): 43 | def __init__(self, 44 | embed_dim: int, 45 | cross_attn_cond_dim: int = 0, 46 | prepend_cond_dim: int = 0, 47 | **kwargs): 48 | super().__init__(embed_dim=embed_dim) 49 | 50 | # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer 51 | self.model = ContinuousTransformerWrapper( 52 | dim_in=embed_dim, 53 | dim_out=embed_dim, 54 | max_seq_len=0, #Not relevant without absolute positional embeds, 55 | attn_layers=Decoder( 56 | dim=embed_dim, 57 | attn_flash = True, 58 | cross_attend = cross_attn_cond_dim > 0, 59 | zero_init_branch_output=True, 60 | use_abs_pos_emb = False, 61 | rotary_pos_emb=True, 62 | ff_swish = True, 63 | ff_glu = True, 64 | **kwargs 65 | ) 66 | ) 67 | 68 | if prepend_cond_dim > 0: 69 | # Prepend conditioning 70 | self.to_prepend_embed = nn.Sequential( 71 | nn.Linear(prepend_cond_dim, embed_dim, bias=False), 72 | nn.SiLU(), 73 | nn.Linear(embed_dim, embed_dim, bias=False) 74 | ) 75 | 76 | if cross_attn_cond_dim > 0: 77 | # Cross-attention 
conditioning 78 | self.to_cross_attn_embed = nn.Sequential( 79 | nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), 80 | nn.SiLU(), 81 | nn.Linear(embed_dim, embed_dim, bias=False) 82 | ) 83 | 84 | def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): 85 | 86 | prepend_length = 0 87 | if prepend_cond is not None: 88 | # Project the prepend conditioning to the embedding dimension 89 | prepend_cond = self.to_prepend_embed(prepend_cond) 90 | prepend_length = prepend_cond.shape[1] 91 | 92 | if prepend_cond_mask is not None: 93 | # Cast mask to bool 94 | prepend_cond_mask = prepend_cond_mask.bool() 95 | 96 | if cross_attn_cond is not None: 97 | # Project the cross-attention conditioning to the embedding dimension 98 | cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) 99 | 100 | return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] 101 | 102 | class ContinuousTransformerAudioLMBackbone(AudioLMBackbone): 103 | def __init__(self, 104 | embed_dim: int, 105 | cross_attn_cond_dim: int = 0, 106 | prepend_cond_dim: int = 0, 107 | project_cross_attn_cond: bool = False, 108 | **kwargs): 109 | super().__init__(embed_dim=embed_dim) 110 | 111 | # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer 112 | self.model = ContinuousTransformer( 113 | dim=embed_dim, 114 | dim_in=embed_dim, 115 | dim_out=embed_dim, 116 | cross_attend = cross_attn_cond_dim > 0, 117 | cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim, 118 | causal=True, 119 | **kwargs 120 | ) 121 | 122 | if prepend_cond_dim > 0: 123 | # Prepend conditioning 124 | self.to_prepend_embed = nn.Sequential( 125 | nn.Linear(prepend_cond_dim, embed_dim, bias=False), 126 | nn.SiLU(), 127 | nn.Linear(embed_dim, embed_dim, bias=False) 128 | ) 129 | 130 | if cross_attn_cond_dim > 0 and project_cross_attn_cond: 131 | # Cross-attention conditioning 132 | self.to_cross_attn_embed = nn.Sequential( 133 | nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), 134 | nn.SiLU(), 135 | nn.Linear(embed_dim, embed_dim, bias=False) 136 | ) 137 | else: 138 | self.to_cross_attn_embed = nn.Identity() 139 | 140 | def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): 141 | 142 | prepend_length = 0 143 | if prepend_cond is not None: 144 | # Project the prepend conditioning to the embedding dimension 145 | prepend_cond = self.to_prepend_embed(prepend_cond) 146 | prepend_length = prepend_cond.shape[1] 147 | 148 | if prepend_cond_mask is not None: 149 | # Cast mask to bool 150 | prepend_cond_mask = prepend_cond_mask.bool() 151 | 152 | if cross_attn_cond is not None: 153 | # Cast cross_attn_cond to same dtype as self.to_cross_attn_embed 154 | cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype) 155 | 156 | # Project the cross-attention conditioning to the embedding dimension 157 | cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) 158 | 159 | return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | import pytorch_lightning as pl 
5 | 6 | from typing import Dict, Optional, Union 7 | from prefigure.prefigure import get_all_args, push_wandb_config 8 | from stable_audio_tools.data.dataset import create_dataloader_from_config, fast_scandir 9 | from stable_audio_tools.models import create_model_from_config 10 | from stable_audio_tools.models.utils import copy_state_dict, load_ckpt_state_dict, remove_weight_norm_from_model 11 | from stable_audio_tools.training import create_training_wrapper_from_config, create_demo_callback_from_config 12 | 13 | class ExceptionCallback(pl.Callback): 14 | def on_exception(self, trainer, module, err): 15 | print(f'{type(err).__name__}: {err}') 16 | 17 | class ModelConfigEmbedderCallback(pl.Callback): 18 | def __init__(self, model_config): 19 | self.model_config = model_config 20 | 21 | def on_save_checkpoint(self, trainer, pl_module, checkpoint): 22 | checkpoint["model_config"] = self.model_config 23 | 24 | def main(): 25 | torch.multiprocessing.set_sharing_strategy('file_system') 26 | args = get_all_args() 27 | seed = args.seed 28 | 29 | # Set a different seed for each process if using SLURM 30 | if os.environ.get("SLURM_PROCID") is not None: 31 | seed += int(os.environ.get("SLURM_PROCID")) 32 | 33 | pl.seed_everything(seed, workers=True) 34 | 35 | #Get JSON config from args.model_config 36 | with open(args.model_config) as f: 37 | model_config = json.load(f) 38 | 39 | with open(args.dataset_config) as f: 40 | dataset_config = json.load(f) 41 | 42 | train_dl = create_dataloader_from_config( 43 | dataset_config, 44 | batch_size=args.batch_size, 45 | num_workers=args.num_workers, 46 | sample_rate=model_config["sample_rate"], 47 | sample_size=model_config["sample_size"], 48 | audio_channels=model_config.get("audio_channels", 2), 49 | ) 50 | 51 | val_dl = None 52 | val_dataset_config = None 53 | 54 | if args.val_dataset_config: 55 | with open(args.val_dataset_config) as f: 56 | val_dataset_config = json.load(f) 57 | 58 | val_dl = create_dataloader_from_config( 59 | val_dataset_config, 60 | batch_size=args.batch_size, 61 | num_workers=args.num_workers, 62 | sample_rate=model_config["sample_rate"], 63 | sample_size=model_config["sample_size"], 64 | audio_channels=model_config.get("audio_channels", 2), 65 | shuffle=False 66 | ) 67 | 68 | model = create_model_from_config(model_config) 69 | 70 | if args.pretrained_ckpt_path: 71 | copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path)) 72 | 73 | if args.remove_pretransform_weight_norm == "pre_load": 74 | remove_weight_norm_from_model(model.pretransform) 75 | 76 | if args.pretransform_ckpt_path: 77 | model.pretransform.load_state_dict(load_ckpt_state_dict(args.pretransform_ckpt_path)) 78 | 79 | # Remove weight_norm from the pretransform if specified 80 | if args.remove_pretransform_weight_norm == "post_load": 81 | remove_weight_norm_from_model(model.pretransform) 82 | 83 | training_wrapper = create_training_wrapper_from_config(model_config, model) 84 | 85 | exc_callback = ExceptionCallback() 86 | 87 | if args.logger == 'wandb': 88 | logger = pl.loggers.WandbLogger(project=args.name) 89 | logger.watch(training_wrapper) 90 | 91 | if args.save_dir and isinstance(logger.experiment.id, str): 92 | checkpoint_dir = os.path.join(args.save_dir, logger.experiment.project, logger.experiment.id, "checkpoints") 93 | else: 94 | checkpoint_dir = None 95 | elif args.logger == 'comet': 96 | logger = pl.loggers.CometLogger(project_name=args.name) 97 | if args.save_dir and isinstance(logger.version, str): 98 | checkpoint_dir = 
os.path.join(args.save_dir, logger.name, logger.version, "checkpoints") 99 | else: 100 | checkpoint_dir = args.save_dir if args.save_dir else None 101 | else: 102 | logger = None 103 | checkpoint_dir = args.save_dir if args.save_dir else None 104 | 105 | ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, save_top_k=-1) 106 | save_model_config_callback = ModelConfigEmbedderCallback(model_config) 107 | 108 | if args.val_dataset_config: 109 | demo_callback = create_demo_callback_from_config(model_config, demo_dl=val_dl) 110 | else: 111 | demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl) 112 | 113 | #Combine args and config dicts 114 | args_dict = vars(args) 115 | args_dict.update({"model_config": model_config}) 116 | args_dict.update({"dataset_config": dataset_config}) 117 | args_dict.update({"val_dataset_config": val_dataset_config}) 118 | 119 | if args.logger == 'wandb': 120 | push_wandb_config(logger, args_dict) 121 | elif args.logger == 'comet': 122 | logger.log_hyperparams(args_dict) 123 | 124 | #Set multi-GPU strategy if specified 125 | if args.strategy: 126 | if args.strategy == "deepspeed": 127 | from pytorch_lightning.strategies import DeepSpeedStrategy 128 | strategy = DeepSpeedStrategy(stage=2, 129 | contiguous_gradients=True, 130 | overlap_comm=True, 131 | reduce_scatter=True, 132 | reduce_bucket_size=5e8, 133 | allgather_bucket_size=5e8, 134 | load_full_weights=True) 135 | else: 136 | strategy = args.strategy 137 | else: 138 | strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto" 139 | 140 | val_args = {} 141 | 142 | if args.val_every > 0: 143 | val_args.update({ 144 | "check_val_every_n_epoch": None, 145 | "val_check_interval": args.val_every, 146 | }) 147 | 148 | trainer = pl.Trainer( 149 | devices="auto", 150 | accelerator="gpu", 151 | num_nodes = args.num_nodes, 152 | strategy=strategy, 153 | precision=args.precision, 154 | accumulate_grad_batches=args.accum_batches, 155 | callbacks=[ckpt_callback, demo_callback, exc_callback, save_model_config_callback], 156 | logger=logger, 157 | log_every_n_steps=1, 158 | max_epochs=10000000, 159 | default_root_dir=args.save_dir, 160 | gradient_clip_val=args.gradient_clip_val, 161 | reload_dataloaders_every_n_epochs = 0, 162 | num_sanity_val_steps=0, # If you need to debug validation, change this line 163 | **val_args 164 | ) 165 | 166 | trainer.fit(training_wrapper, train_dl, val_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None) 167 | 168 | if __name__ == '__main__': 169 | main() 170 | -------------------------------------------------------------------------------- /stable_audio_tools/models/encodec.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/encodec/blob/main/encodec/msstftd.py under MIT License 2 | # License can be found in LICENSES/LICENSE_META.txt 3 | 4 | """MS-STFT discriminator, provided here for reference.""" 5 | 6 | import typing as tp 7 | 8 | import torchaudio 9 | import torch 10 | from torch import nn 11 | from einops import rearrange 12 | 13 | from torch.nn.utils import weight_norm 14 | 15 | def checkpoint(function, *args, **kwargs): 16 | kwargs.setdefault("use_reentrant", False) 17 | return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) 18 | 19 | class NormConv2d(nn.Module): 20 | """Wrapper around Conv2d and normalization applied to this conv 21 | to provide a uniform interface across 
normalization approaches. 22 | """ 23 | def __init__(self, *args, **kwargs): 24 | super().__init__() 25 | self.conv = weight_norm(nn.Conv2d(*args, **kwargs)) 26 | 27 | def forward(self, x): 28 | return self.conv(x) 29 | 30 | FeatureMapType = tp.List[torch.Tensor] 31 | LogitsType = torch.Tensor 32 | DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] 33 | 34 | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): 35 | return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) 36 | 37 | class DiscriminatorSTFT(nn.Module): 38 | """STFT sub-discriminator. 39 | Args: 40 | filters (int): Number of filters in convolutions 41 | in_channels (int): Number of input channels. Default: 1 42 | out_channels (int): Number of output channels. Default: 1 43 | n_fft (int): Size of FFT for each scale. Default: 1024 44 | hop_length (int): Length of hop between STFT windows for each scale. Default: 256 45 | kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` 46 | stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` 47 | dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` 48 | win_length (int): Window size for each scale. Default: 1024 49 | normalized (bool): Whether to normalize by magnitude after stft. Default: True 50 | activation (str): Activation function. Default: `'LeakyReLU'` 51 | activation_params (dict): Parameters to provide to the activation function. 52 | growth (int): Growth factor for the filters. Default: 1 53 | """ 54 | def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, 55 | n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, 56 | filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], 57 | stride: tp.Tuple[int, int] = (1, 1), normalized: bool = True, 58 | activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}, spec_scale_pow = 0.0, **kwargs): 59 | super().__init__() 60 | assert len(kernel_size) == 2 61 | assert len(stride) == 2 62 | self.filters = filters 63 | self.in_channels = in_channels 64 | self.out_channels = out_channels 65 | self.n_fft = n_fft 66 | self.hop_length = hop_length 67 | self.win_length = win_length 68 | self.normalized = normalized 69 | self.activation = getattr(torch.nn, activation)(**activation_params) 70 | 71 | self.spec_transform = torchaudio.transforms.Spectrogram( 72 | n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, 73 | normalized=self.normalized, center=False, pad_mode=None, power=None) 74 | spec_channels = 2 * self.in_channels 75 | self.convs = nn.ModuleList() 76 | self.convs.append( 77 | NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) 78 | ) 79 | in_chs = min(filters_scale * self.filters, max_filters) 80 | for i, dilation in enumerate(dilations): 81 | out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) 82 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, 83 | dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)))) 84 | in_chs = out_chs 85 | out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) 86 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]), 87 | padding=get_2d_padding((kernel_size[0], kernel_size[0])))) 
88 | self.conv_post = NormConv2d(out_chs, self.out_channels, 89 | kernel_size=(kernel_size[0], kernel_size[0]), 90 | padding=get_2d_padding((kernel_size[0], kernel_size[0]))) 91 | 92 | self.spec_scale_pow = spec_scale_pow 93 | 94 | def forward(self, x: torch.Tensor): 95 | fmap = [] 96 | z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] 97 | if self.spec_scale_pow != 0.0: 98 | z = z * torch.pow(z.abs()+1e-6, self.spec_scale_pow) 99 | z = torch.cat([z.real, z.imag], dim=1) 100 | z = rearrange(z, 'b c w t -> b c t w') 101 | for i, layer in enumerate(self.convs): 102 | z = checkpoint(layer, z) 103 | z = self.activation(z) 104 | fmap.append(z) 105 | z = checkpoint(self.conv_post, z) 106 | return z, fmap 107 | 108 | class MultiScaleSTFTDiscriminator(nn.Module): 109 | """Multi-Scale STFT (MS-STFT) discriminator. 110 | Args: 111 | filters (int): Number of filters in convolutions 112 | in_channels (int): Number of input channels. Default: 1 113 | out_channels (int): Number of output channels. Default: 1 114 | n_ffts (Sequence[int]): Size of FFT for each scale 115 | hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale 116 | win_lengths (Sequence[int]): Window size for each scale 117 | **kwargs: additional args for STFTDiscriminator 118 | """ 119 | def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, 120 | n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128], 121 | win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs): 122 | super().__init__() 123 | assert len(n_ffts) == len(hop_lengths) == len(win_lengths) 124 | self.discriminators = nn.ModuleList([ 125 | DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, 126 | n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) 127 | for i in range(len(n_ffts)) 128 | ]) 129 | self.num_discriminators = len(self.discriminators) 130 | 131 | def forward(self, x: torch.Tensor) -> DiscriminatorOutput: 132 | logits = [] 133 | fmaps = [] 134 | for disc in self.discriminators: 135 | logit, fmap = disc(x) 136 | logits.append(logit) 137 | fmaps.append(fmap) 138 | return logits, fmaps 139 | -------------------------------------------------------------------------------- /stable_audio_tools/models/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def create_model_from_config(model_config): 4 | model_type = model_config.get('model_type', None) 5 | 6 | assert model_type is not None, 'model_type must be specified in model config' 7 | 8 | if model_type == 'autoencoder': 9 | from .autoencoders import create_autoencoder_from_config 10 | return create_autoencoder_from_config(model_config) 11 | elif model_type == 'diffusion_uncond': 12 | from .diffusion import create_diffusion_uncond_from_config 13 | return create_diffusion_uncond_from_config(model_config) 14 | elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint': 15 | from .diffusion import create_diffusion_cond_from_config 16 | return create_diffusion_cond_from_config(model_config) 17 | elif model_type == 'diffusion_autoencoder': 18 | from .autoencoders import create_diffAE_from_config 19 | return create_diffAE_from_config(model_config) 20 | elif model_type == 'lm': 21 | from .lm import create_audio_lm_from_config 22 | return create_audio_lm_from_config(model_config) 23 | else: 24 | raise NotImplementedError(f'Unknown model type: {model_type}') 25 | 26 | def 
create_model_from_config_path(model_config_path): 27 | with open(model_config_path) as f: 28 | model_config = json.load(f) 29 | 30 | return create_model_from_config(model_config) 31 | 32 | def create_pretransform_from_config(pretransform_config, sample_rate): 33 | pretransform_type = pretransform_config.get('type', None) 34 | 35 | assert pretransform_type is not None, 'type must be specified in pretransform config' 36 | 37 | if pretransform_type == 'autoencoder': 38 | from .autoencoders import create_autoencoder_from_config 39 | from .pretransforms import AutoencoderPretransform 40 | 41 | # Create fake top-level config to pass sample rate to autoencoder constructor 42 | # This is a bit of a hack but it keeps us from re-defining the sample rate in the config 43 | autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} 44 | autoencoder = create_autoencoder_from_config(autoencoder_config) 45 | 46 | scale = pretransform_config.get("scale", 1.0) 47 | model_half = pretransform_config.get("model_half", False) 48 | iterate_batch = pretransform_config.get("iterate_batch", False) 49 | chunked = pretransform_config.get("chunked", False) 50 | 51 | pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) 52 | elif pretransform_type == 'wavelet': 53 | from .pretransforms import WaveletPretransform 54 | 55 | wavelet_config = pretransform_config["config"] 56 | channels = wavelet_config["channels"] 57 | levels = wavelet_config["levels"] 58 | wavelet = wavelet_config["wavelet"] 59 | 60 | pretransform = WaveletPretransform(channels, levels, wavelet) 61 | elif pretransform_type == 'pqmf': 62 | from .pretransforms import PQMFPretransform 63 | pqmf_config = pretransform_config["config"] 64 | pretransform = PQMFPretransform(**pqmf_config) 65 | elif pretransform_type == 'dac_pretrained': 66 | from .pretransforms import PretrainedDACPretransform 67 | pretrained_dac_config = pretransform_config["config"] 68 | pretransform = PretrainedDACPretransform(**pretrained_dac_config) 69 | elif pretransform_type == "audiocraft_pretrained": 70 | from .pretransforms import AudiocraftCompressionPretransform 71 | 72 | audiocraft_config = pretransform_config["config"] 73 | pretransform = AudiocraftCompressionPretransform(**audiocraft_config) 74 | elif pretransform_type == "patched": 75 | from .pretransforms import PatchedPretransform 76 | 77 | patched_config = pretransform_config["config"] 78 | pretransform = PatchedPretransform(**patched_config) 79 | else: 80 | raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') 81 | 82 | enable_grad = pretransform_config.get('enable_grad', False) 83 | pretransform.enable_grad = enable_grad 84 | 85 | pretransform.eval().requires_grad_(pretransform.enable_grad) 86 | 87 | return pretransform 88 | 89 | def create_bottleneck_from_config(bottleneck_config): 90 | bottleneck_type = bottleneck_config.get('type', None) 91 | 92 | assert bottleneck_type is not None, 'type must be specified in bottleneck config' 93 | 94 | if bottleneck_type == 'tanh': 95 | from .bottleneck import TanhBottleneck 96 | bottleneck = TanhBottleneck(**bottleneck_config.get('config', {})) 97 | elif bottleneck_type == 'vae': 98 | from .bottleneck import VAEBottleneck 99 | bottleneck = VAEBottleneck() 100 | elif bottleneck_type == 'rvq': 101 | from .bottleneck import RVQBottleneck 102 | 103 | quantizer_params = { 104 | "dim": 128, 105 | "codebook_size": 1024, 106 | "num_quantizers": 8, 107 | "decay": 
0.99, 108 | "kmeans_init": True, 109 | "kmeans_iters": 50, 110 | "threshold_ema_dead_code": 2, 111 | } 112 | 113 | quantizer_params.update(bottleneck_config["config"]) 114 | 115 | bottleneck = RVQBottleneck(**quantizer_params) 116 | elif bottleneck_type == "dac_rvq": 117 | from .bottleneck import DACRVQBottleneck 118 | 119 | bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) 120 | 121 | elif bottleneck_type == 'rvq_vae': 122 | from .bottleneck import RVQVAEBottleneck 123 | 124 | quantizer_params = { 125 | "dim": 128, 126 | "codebook_size": 1024, 127 | "num_quantizers": 8, 128 | "decay": 0.99, 129 | "kmeans_init": True, 130 | "kmeans_iters": 50, 131 | "threshold_ema_dead_code": 2, 132 | } 133 | 134 | quantizer_params.update(bottleneck_config["config"]) 135 | 136 | bottleneck = RVQVAEBottleneck(**quantizer_params) 137 | 138 | elif bottleneck_type == 'dac_rvq_vae': 139 | from .bottleneck import DACRVQVAEBottleneck 140 | bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) 141 | elif bottleneck_type == 'l2_norm': 142 | from .bottleneck import L2Bottleneck 143 | bottleneck = L2Bottleneck() 144 | elif bottleneck_type == "wasserstein": 145 | from .bottleneck import WassersteinBottleneck 146 | bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) 147 | elif bottleneck_type == "fsq": 148 | from .bottleneck import FSQBottleneck 149 | bottleneck = FSQBottleneck(**bottleneck_config["config"]) 150 | elif bottleneck_type == "dithered_fsq": 151 | from .bottleneck import DitheredFSQBottleneck 152 | return DitheredFSQBottleneck(**bottleneck_config["config"]) 153 | else: 154 | raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') 155 | 156 | requires_grad = bottleneck_config.get('requires_grad', True) 157 | if not requires_grad: 158 | for param in bottleneck.parameters(): 159 | param.requires_grad = False 160 | 161 | return bottleneck 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stable-audio-tools 2 | Training and inference code for audio generation models 3 | 4 | # Install 5 | 6 | The library can be installed from PyPI with: 7 | ```bash 8 | $ pip install stable-audio-tools 9 | ``` 10 | 11 | To run the training scripts or inference code, you'll want to clone this repository, navigate to the root, and run: 12 | ```bash 13 | $ pip install . 14 | ``` 15 | 16 | # Requirements 17 | Requires PyTorch 2.5 or later for Flash Attention and Flex Attention support 18 | 19 | Development for the repo is done in Python 3.10 20 | 21 | # Interface 22 | 23 | A basic Gradio interface is provided to test out trained models. 
24 | 25 | For example, to create an interface for the [`stable-audio-open-1.0`](https://huggingface.co/stabilityai/stable-audio-open-1.0) model, once you've accepted the terms for the model on Hugging Face, you can run: 26 | ```bash 27 | $ python3 ./run_gradio.py --pretrained-name stabilityai/stable-audio-open-1.0 28 | ``` 29 | 30 | The `run_gradio.py` script accepts the following command line arguments: 31 | 32 | - `--pretrained-name` 33 | - Hugging Face repository name for a Stable Audio Tools model 34 | - Will prioritize `model.safetensors` over `model.ckpt` in the repo 35 | - Optional, used in place of `model-config` and `ckpt-path` when using pre-trained model checkpoints on Hugging Face 36 | - `--model-config` 37 | - Path to the model config file for a local model 38 | - `--ckpt-path` 39 | - Path to unwrapped model checkpoint file for a local model 40 | - `--pretransform-ckpt-path` 41 | - Path to an unwrapped pretransform checkpoint, replaces the pretransform in the model, useful for testing out fine-tuned decoders 42 | - Optional 43 | - `--share` 44 | - If true, a publicly shareable link will be created for the Gradio demo 45 | - Optional 46 | - `--username` and `--password` 47 | - Used together to set a login for the Gradio demo 48 | - Optional 49 | - `--model-half` 50 | - If true, the model weights will be cast to half-precision 51 | - Optional 52 | 53 | # Training 54 | 55 | ## Prerequisites 56 | Before starting your training run, you'll need a model config file, as well as a dataset config file. For more information about those, refer to the Configurations section below. 57 | 58 | The training code also requires a Weights & Biases account to log the training outputs and demos. Create an account and log in with: 59 | ```bash 60 | $ wandb login 61 | ``` 62 | 63 | ## Start training 64 | To start a training run, run the `train.py` script in the repo root with: 65 | ```bash 66 | $ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name harmonai_train 67 | ``` 68 | 69 | The `--name` parameter will set the project name for your Weights & Biases run. 70 | 71 | ## Training wrappers and model unwrapping 72 | `stable-audio-tools` uses PyTorch Lightning to facilitate multi-GPU and multi-node training. 73 | 74 | When a model is being trained, it is wrapped in a "training wrapper", which is a `pl.LightningModule` that contains all of the relevant objects needed only for training. That includes things like discriminators for autoencoders, EMA copies of models, and all of the optimizer states. 75 | 76 | The checkpoint files created during training include this training wrapper, which greatly increases the size of the checkpoint file. 77 | 78 | `unwrap_model.py` in the repo root will take in a wrapped model checkpoint and save a new checkpoint file including only the model itself. 79 | 80 | That can be run from the repo root with: 81 | ```bash 82 | $ python3 ./unwrap_model.py --model-config /path/to/model/config --ckpt-path /path/to/wrapped/ckpt --name model_unwrap 83 | ``` 84 | 85 | Unwrapped model checkpoints are required for: 86 | - Inference scripts 87 | - Using a model as a pretransform for another model (e.g. using an autoencoder model for latent diffusion) 88 | - Fine-tuning a pre-trained model with a modified configuration (i.e. partial initialization) 89 | 90 | ## Fine-tuning 91 | Fine-tuning a model involves continuing a training run from a pre-trained checkpoint.
92 | 93 | To continue a training run from a wrapped model checkpoint, you can pass in the checkpoint path to `train.py` with the `--ckpt-path` flag. 94 | 95 | To start a fresh training run using a pre-trained unwrapped model, you can pass in the unwrapped checkpoint to `train.py` with the `--pretrained-ckpt-path` flag. 96 | 97 | ## Additional training flags 98 | 99 | Additional optional flags for `train.py` include: 100 | - `--config-file` 101 | - The path to the defaults.ini file in the repo root, required if running `train.py` from a directory other than the repo root 102 | - `--pretransform-ckpt-path` 103 | - Used in various model types such as latent diffusion models to load a pre-trained autoencoder. Requires an unwrapped model checkpoint. 104 | - `--save-dir` 105 | - The directory in which to save the model checkpoints 106 | - `--checkpoint-every` 107 | - The number of steps between saved checkpoints. 108 | - *Default*: 10000 109 | - `--batch-size` 110 | - Number of samples per-GPU during training. Should be set as large as your GPU VRAM will allow. 111 | - *Default*: 8 112 | - `--num-gpus` 113 | - Number of GPUs per-node to use for training 114 | - *Default*: 1 115 | - `--num-nodes` 116 | - Number of GPU nodes being used for training 117 | - *Default*: 1 118 | - `--accum-batches` 119 | - Enables and sets the number of batches for gradient batch accumulation. Useful for increasing effective batch size when training on smaller GPUs. 120 | - `--strategy` 121 | - Multi-GPU strategy for distributed training. Setting to `deepspeed` will enable DeepSpeed ZeRO Stage 2. 122 | - *Default*: `ddp` if `--num_gpus` > 1, else None 123 | - `--precision` 124 | - floating-point precision to use during training 125 | - *Default*: 16 126 | - `--num-workers` 127 | - Number of CPU workers used by the data loader 128 | - `--seed` 129 | - RNG seed for PyTorch, helps with deterministic training 130 | 131 | # Configurations 132 | Training and inference code for `stable-audio-tools` is based around JSON configuration files that define model hyperparameters, training settings, and information about your training dataset. 133 | 134 | ## Model config 135 | The model config file defines all of the information needed to load a model for training or inference. It also contains the training configuration needed to fine-tune a model or train from scratch. 136 | 137 | The following properties are defined in the top level of the model configuration: 138 | 139 | - `model_type` 140 | - The type of model being defined, currently limited to one of `"autoencoder", "diffusion_uncond", "diffusion_cond", "diffusion_cond_inpaint", "diffusion_autoencoder", "lm"`. 141 | - `sample_size` 142 | - The length of the audio provided to the model during training, in samples. For diffusion models, this is also the raw audio sample length used for inference. 143 | - `sample_rate` 144 | - The sample rate of the audio provided to the model during training, and generated during inference, in Hz. 145 | - `audio_channels` 146 | - The number of channels of audio provided to the model during training, and generated during inference. Defaults to 2. Set to 1 for mono. 147 | - `model` 148 | - The specific configuration for the model being defined, varies based on `model_type` 149 | - `training` 150 | - The training configuration for the model, varies based on `model_type`. Provides parameters for training as well as demos. 
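These top-level properties are consumed directly by the factory functions and by the training and inference scripts. As a minimal sketch (assuming an installed `stable-audio-tools` and a hypothetical local config file `my_model_config.json`), a model can be built from a config like this:

```python
import json

from stable_audio_tools.models.factory import (
    create_model_from_config,
    create_model_from_config_path,
)

config_path = "my_model_config.json"  # hypothetical local model config

# Load the JSON yourself and build the model from the resulting dict...
with open(config_path) as f:
    model_config = json.load(f)
model = create_model_from_config(model_config)

# ...or let the factory read the file for you
model = create_model_from_config_path(config_path)

# The top-level properties described above are read straight from the dict,
# e.g. when building dataloaders in train.py
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]
audio_channels = model_config.get("audio_channels", 2)
```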
151 | 152 | ## Dataset config 153 | `stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md) 154 | 155 | # Todo 156 | - [ ] Add troubleshooting section 157 | - [ ] Add contribution guidelines 158 | -------------------------------------------------------------------------------- /pre_encode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import json 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import torch 9 | from torch.nn import functional as F 10 | 11 | from stable_audio_tools.data.dataset import create_dataloader_from_config 12 | from stable_audio_tools.models.factory import create_model_from_config 13 | from stable_audio_tools.models.pretrained import get_pretrained_model 14 | from stable_audio_tools.models.utils import load_ckpt_state_dict, copy_state_dict 15 | 16 | 17 | def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, model_half=False): 18 | if pretrained_name is not None: 19 | print(f"Loading pretrained model {pretrained_name}") 20 | model, model_config = get_pretrained_model(pretrained_name) 21 | 22 | elif model_config is not None and model_ckpt_path is not None: 23 | print(f"Creating model from config") 24 | model = create_model_from_config(model_config) 25 | 26 | print(f"Loading model checkpoint from {model_ckpt_path}") 27 | copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) 28 | 29 | model.eval().requires_grad_(False) 30 | 31 | if model_half: 32 | model.to(torch.float16) 33 | 34 | print("Done loading model") 35 | 36 | return model, model_config 37 | 38 | 39 | class PreEncodedLatentsInferenceWrapper(pl.LightningModule): 40 | def __init__( 41 | self, 42 | model, 43 | output_path, 44 | is_discrete=False, 45 | model_half=False, 46 | model_config=None, 47 | dataset_config=None, 48 | sample_size=1920000, 49 | args_dict=None 50 | ): 51 | super().__init__() 52 | self.save_hyperparameters(ignore=['model']) 53 | self.model = model 54 | self.output_path = Path(output_path) 55 | 56 | def prepare_data(self): 57 | # runs on rank 0 58 | self.output_path.mkdir(parents=True, exist_ok=True) 59 | details_path = self.output_path / "details.json" 60 | if not details_path.exists(): # Only save if it doesn't exist 61 | details = { 62 | "model_config": self.hparams.model_config, 63 | "dataset_config": self.hparams.dataset_config, 64 | "sample_size": self.hparams.sample_size, 65 | "args": self.hparams.args_dict 66 | } 67 | details_path.write_text(json.dumps(details)) 68 | 69 | def setup(self, stage=None): 70 | # runs on each device 71 | process_dir = self.output_path / str(self.global_rank) 72 | process_dir.mkdir(parents=True, exist_ok=True) 73 | 74 | def validation_step(self, batch, batch_idx): 75 | audio, metadata = batch 76 | 77 | if audio.ndim == 4 and audio.shape[0] == 1: 78 | audio = audio[0] 79 | 80 | if torch.cuda.is_available(): 81 | torch.cuda.empty_cache() 82 | gc.collect() 83 | 84 | if self.hparams.model_half: 85 | audio = audio.to(torch.float16) 86 | 87 | with torch.no_grad(): 88 | if not self.hparams.is_discrete: 89 | latents = self.model.encode(audio) 90 | else: 91 | _, info = self.model.encode(audio, return_info=True) 92 | latents = info[self.model.bottleneck.tokens_id] 93 | 94 | latents = latents.cpu().numpy() 95 | 96 | # Save each sample in the batch 
97 | for i, latent in enumerate(latents): 98 | latent_id = f"{self.global_rank:03d}{batch_idx:06d}{i:04d}" 99 | 100 | # Save latent as numpy file 101 | latent_path = self.output_path / str(self.global_rank) / f"{latent_id}.npy" 102 | with open(latent_path, "wb") as f: 103 | np.save(f, latent) 104 | 105 | md = metadata[i] 106 | padding_mask = F.interpolate( 107 | md["padding_mask"].unsqueeze(0).unsqueeze(1).float(), 108 | size=latent.shape[1], 109 | mode="nearest" 110 | ).squeeze().int() 111 | md["padding_mask"] = padding_mask.cpu().numpy().tolist() 112 | 113 | # Convert tensors in md to serializable types 114 | for k, v in md.items(): 115 | if isinstance(v, torch.Tensor): 116 | md[k] = v.cpu().numpy().tolist() 117 | 118 | # Save metadata to json file 119 | metadata_path = self.output_path / str(self.global_rank) / f"{latent_id}.json" 120 | with open(metadata_path, "w") as f: 121 | json.dump(md, f) 122 | 123 | def configure_optimizers(self): 124 | return None 125 | 126 | 127 | def main(args): 128 | with open(args.model_config) as f: 129 | model_config = json.load(f) 130 | 131 | with open(args.dataset_config) as f: 132 | dataset_config = json.load(f) 133 | 134 | model, model_config = load_model( 135 | model_config=model_config, 136 | model_ckpt_path=args.ckpt_path, 137 | model_half=args.model_half 138 | ) 139 | 140 | data_loader = create_dataloader_from_config( 141 | dataset_config, 142 | batch_size=args.batch_size, 143 | num_workers=args.num_workers, 144 | sample_rate=model_config["sample_rate"], 145 | sample_size=args.sample_size, 146 | audio_channels=model_config.get("audio_channels", 2), 147 | shuffle=args.shuffle 148 | ) 149 | 150 | pl_module = PreEncodedLatentsInferenceWrapper( 151 | model=model, 152 | output_path=args.output_path, 153 | is_discrete=args.is_discrete, 154 | model_half=args.model_half, 155 | model_config=args.model_config, 156 | dataset_config=args.dataset_config, 157 | sample_size=args.sample_size, 158 | args_dict=vars(args) 159 | ) 160 | 161 | trainer = pl.Trainer( 162 | accelerator="gpu", 163 | devices="auto", 164 | num_nodes = args.num_nodes, 165 | strategy=args.strategy, 166 | precision="16-true" if args.model_half else "32", 167 | max_steps=args.limit_batches if args.limit_batches else -1, 168 | logger=False, # Disable logging since we're just doing inference 169 | enable_checkpointing=False, 170 | ) 171 | trainer.validate(pl_module, data_loader) 172 | 173 | if __name__ == "__main__": 174 | parser = argparse.ArgumentParser(description='Encode audio dataset to VAE latents using PyTorch Lightning') 175 | parser.add_argument('--model-config', type=str, help='Path to model config', required=False) 176 | parser.add_argument('--ckpt-path', type=str, help='Path to unwrapped autoencoder model checkpoint', required=False) 177 | parser.add_argument('--model-half', action='store_true', help='Whether to use half precision') 178 | parser.add_argument('--dataset-config', type=str, help='Path to dataset config file', required=True) 179 | parser.add_argument('--output-path', type=str, help='Path to output folder', required=True) 180 | parser.add_argument('--batch-size', type=int, help='Batch size', default=1) 181 | parser.add_argument('--sample-size', type=int, help='Number of audio samples to pad/crop to', default=1320960) 182 | parser.add_argument('--is-discrete', action='store_true', help='Whether the model is discrete') 183 | parser.add_argument('--num-nodes', type=int, help='Number of GPU nodes', default=1) 184 | parser.add_argument('--num-workers', type=int, help='Number of 
dataloader workers', default=4) 185 | parser.add_argument('--strategy', type=str, help='PyTorch Lightning strategy', default='auto') 186 | parser.add_argument('--limit-batches', type=int, help='Limit number of batches (optional)', default=None) 187 | parser.add_argument('--shuffle', action='store_true', help='Shuffle dataset') 188 | args = parser.parse_args() 189 | main(args) -------------------------------------------------------------------------------- /docs/conditioning.md: -------------------------------------------------------------------------------- 1 | # Conditioning 2 | Conditioning, in the context of `stable-audio-tools` is the use of additional signals in a model that are used to add an additional level of control over the model's behavior. For example, we can condition the outputs of a diffusion model on a text prompt, creating a text-to-audio model. 3 | 4 | # Conditioning types 5 | There are a few different kinds of conditioning depending on the conditioning signal being used. 6 | 7 | ## Cross attention 8 | Cross attention is a type of conditioning that allows us to find correlations between two sequences of potentially different lengths. For example, cross attention allows us to find correlations between a sequence of features from a text encoder and a sequence of high-level audio features. 9 | 10 | Signals used for cross-attention conditioning should be of the shape `[batch, sequence, channels]`. 11 | 12 | ## Global conditioning 13 | Global conditioning is the use of a single n-dimensional tensor to provide conditioning information that pertains to the whole sequence being conditioned. For example, this could be the single embedding output of a CLAP model, or a learned class embedding. 14 | 15 | Signals used for global conditioning should be of the shape `[batch, channels]`. 16 | 17 | ## Prepend conditioning 18 | Prepend conditioning involves prepending the conditioning tokens to the data tokens in the model, allowing for the information to be interpreted through the model's self-attention mechanism. 19 | 20 | This kind of conditioning is currently only supported by Transformer-based models such as diffusion transformers. 21 | 22 | Signals used for prepend conditioning should be of the shape `[batch, sequence, channels]`. 23 | 24 | ## Input concatenation 25 | Input concatenation applies a spatial conditioning signal to the model that correlates in the sequence dimension with the model's input, and is of the same length. The conditioning signal will be concatenated with the model's input data along the channel dimension. This can be used for things like inpainting information, melody conditioning, or for creating a diffusion autoencoder. 26 | 27 | Signals used for input concatenation conditioning should be of the shape `[batch, channels, sequence]` and must be the same length as the model's input. 28 | 29 | # Conditioners and conditioning configs 30 | `stable-audio-tools` uses Conditioner modules to translate human-readable metadata such as text prompts or a number of seconds into tensors that the model can take as input. 31 | 32 | Each conditioner has a corresponding `id` that it expects to find in the conditioning dictionary provided during training or inference. Each conditioner takes in the relevant conditioning data and returns a tuple containing the corresponding tensor and a mask. 33 | 34 | The ConditionedDiffusionModelWrapper manages the translation between the user-provided metadata dictionary (e.g. 
`{"prompt": "a beautiful song", "seconds_start": 22, "seconds_total": 193}`) and the dictionary of different conditioning types that the model uses (e.g. `{"cross_attn_cond": ...}`). 35 | 36 | To apply conditioning to a model, you must provide a `conditioning` configuration in the model's config. At the moment, we only support conditioning diffusion models through the `diffusion_cond` model type. 37 | 38 | The `conditioning` configuration should contain a `configs` array, which allows you to define multiple conditioning signals. 39 | 40 | Each item in the `configs` array should define the `id` for the corresponding metadata, the type of conditioner to be used, and the config for that conditioner. 41 | 42 | The `cond_dim` property is used to enforce the same dimension on all conditioning inputs; however, that can be overridden with an explicit `output_dim` property on any of the individual configs. 43 | 44 | ## Example config 45 | ```json 46 | "conditioning": { 47 | "configs": [ 48 | { 49 | "id": "prompt", 50 | "type": "t5", 51 | "config": { 52 | "t5_model_name": "t5-base", 53 | "max_length": 77, 54 | "project_out": true 55 | } 56 | } 57 | ], 58 | "cond_dim": 768 59 | } 60 | ``` 61 | 62 | # Conditioners 63 | 64 | ## Text encoders 65 | 66 | ### `t5` 67 | This uses a frozen [T5](https://huggingface.co/docs/transformers/model_doc/t5) text encoder from the `transformers` library to encode text prompts into a sequence of text features. 68 | 69 | The `t5_model_name` property determines which T5 model is loaded from the `transformers` library. 70 | 71 | The `max_length` property determines the maximum number of tokens that the text encoder will take in, as well as the sequence length of the output text features. 72 | 73 | If you set `enable_grad` to `true`, the T5 model will be un-frozen and saved with the model checkpoint, allowing you to fine-tune the T5 model. 74 | 75 | T5 encodings are only compatible with cross attention conditioning. 76 | 77 | #### Example config 78 | ```json 79 | { 80 | "id": "prompt", 81 | "type": "t5", 82 | "config": { 83 | "t5_model_name": "t5-base", 84 | "max_length": 77, 85 | "project_out": true 86 | } 87 | } 88 | ``` 89 | 90 | ### `clap_text` 91 | This loads the text encoder from a [CLAP](https://github.com/LAION-AI/CLAP) model, which can provide either a sequence of text features, or a single multimodal text/audio embedding. 92 | 93 | The CLAP model must be provided with a local file path, set in the `clap_ckpt_path` property, along with the correct `audio_model_type` and `enable_fusion` properties for the provided model. 94 | 95 | If the `use_text_features` property is set to `true`, the conditioner output will be a sequence of text features, instead of a single multimodal embedding. This allows for more fine-grained text information to be used by the model, at the cost of losing the ability to prompt with CLAP audio embeddings. 96 | 97 | By default, if `use_text_features` is true, the last layer of the CLAP text encoder's features are returned. You can return the text features of earlier layers by specifying the index of the layer to return in the `feature_layer_ix` property. For example, you can return the text features of the next-to-last layer of the CLAP model by setting `feature_layer_ix` to `-2`. 98 | 99 | If you set `enable_grad` to `true`, the CLAP model will be un-frozen and saved with the model checkpoint, allowing you to fine-tune the CLAP model. 100 | 101 | CLAP text embeddings are compatible with global conditioning and cross attention conditioning.
If `use_text_features` is set to `true`, the features are not compatible with global conditioning. 102 | 103 | #### Example config 104 | ```json 105 | { 106 | "id": "prompt", 107 | "type": "clap_text", 108 | "config": { 109 | "clap_ckpt_path": "/path/to/clap/model.ckpt", 110 | "audio_model_type": "HTSAT-base", 111 | "enable_fusion": true, 112 | "use_text_features": true, 113 | "feature_layer_ix": -2 114 | } 115 | } 116 | ``` 117 | 118 | ## Number encoders 119 | 120 | ### `int` 121 | The IntConditioner takes in a list of integers in a given range, and returns a discrete learned embedding for each of those integers. 122 | 123 | The `min_val` and `max_val` properties set the range of the embedding values. Input integers are clamped to this range. 124 | 125 | This can be used for things like discrete timing embeddings, or learned class embeddings. 126 | 127 | Int embeddings are compatible with global conditioning and cross attention conditioning. 128 | 129 | #### Example config 130 | ```json 131 | { 132 | "id": "seconds_start", 133 | "type": "int", 134 | "config": { 135 | "min_val": 0, 136 | "max_val": 512 137 | } 138 | } 139 | ``` 140 | 141 | ### `number` 142 | The NumberConditioner takes in a list of floats in a given range, and returns a continuous Fourier embedding of the provided floats. 143 | 144 | The `min_val` and `max_val` properties set the range of the float values. This is the range used to normalize the input float values. 145 | 146 | Number embeddings are compatible with global conditioning and cross attention conditioning. 147 | 148 | #### Example config 149 | ```json 150 | { 151 | "id": "seconds_total", 152 | "type": "number", 153 | "config": { 154 | "min_val": 0, 155 | "max_val": 512 156 | } 157 | } 158 | ``` -------------------------------------------------------------------------------- /docs/diffusion.md: -------------------------------------------------------------------------------- 1 | # Diffusion 2 | 3 | Diffusion models learn to generate data by iteratively denoising it, starting from pure noise. 4 | 5 | # Model configs 6 | The model config file for a diffusion model should set the `model_type` to `diffusion_cond` if the model uses conditioning, or `diffusion_uncond` if it does not, and the `model` object should have the following properties: 7 | 8 | - `diffusion` 9 | - The configuration for the diffusion model itself. See below for more information on the diffusion model config 10 | - `pretransform` 11 | - The configuration of the diffusion model's [pretransform](pretransforms.md), such as an autoencoder for latent diffusion.
12 | - Optional 13 | - `conditioning` 14 | - The configuration of the various [conditioning](conditioning.md) modules for the diffusion model 15 | - Only required for `diffusion_cond` 16 | - `io_channels` 17 | - The base number of input/output channels for the diffusion model 18 | - Used by inference scripts to determine the shape of the noise to generate for the diffusion model 19 | 20 | # Diffusion configs 21 | - `type` 22 | - The underlying model type for the diffusion model backbone 23 | - For conditioned diffusion models, must be one of `dit` ([Diffusion Transformer](#diffusion-transformers-dit)), `DAU1d` ([Dance Diffusion U-Net](#dance-diffusion-u-net)), or `adp_cfg_1d` ([audio-diffusion-pytorch U-Net](#audio-diffusion-pytorch-u-net-adp)) 24 | - Unconditioned diffusion models can also use `adp_1d` 25 | - `cross_attention_cond_ids` 26 | - Conditioner ids for conditioning information to be used as cross-attention input 27 | - If multiple ids are specified, the conditioning tensors will be concatenated along the sequence dimension 28 | - `global_cond_ids` 29 | - Conditioner ids for conditioning information to be used as global conditioning input 30 | - If multiple ids are specified, the conditioning tensors will be concatenated along the channel dimension 31 | - `prepend_cond_ids` 32 | - Conditioner ids for conditioning information to be prepended to the model input 33 | - If multiple ids are specified, the conditioning tensors will be concatenated along the sequence dimension 34 | - Only works with diffusion transformer models 35 | - `input_concat_ids` 36 | - Conditioner ids for conditioning information to be concatenated to the model input 37 | - If multiple ids are specified, the conditioning tensors will be concatenated along the channel dimension 38 | - If the conditioning tensors are not the same length as the model input, they will be interpolated along the sequence dimension to be the same length. 39 | - The interpolation algorithm is model-dependent, but usually uses nearest-neighbor resampling. 40 | - `config` 41 | - The configuration for the model backbone itself 42 | - Model-dependent 43 | 44 | # Training configs 45 | The `training` config in the diffusion model config file should have the following properties: 46 | 47 | - `learning_rate` 48 | - The learning rate to use during training 49 | - Defaults to a constant learning rate, can be overridden with `optimizer_configs` 50 | - `use_ema` 51 | - If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights. 52 | - Optional. Default: `true` 53 | - `log_loss_info` 54 | - If true, additional diffusion loss info will be gathered across all GPUs and displayed during training 55 | - Optional. Default: `false` 56 | - `loss_configs` 57 | - Configurations for the loss function calculation 58 | - Optional 59 | - `optimizer_configs` 60 | - Configuration for optimizers and schedulers 61 | - Optional, overrides `learning_rate` 62 | - `demo` 63 | - Configuration for the demos during training, including conditioning information 64 | - `pre_encoded` 65 | - If true, indicates that the model should operate on [pre-encoded latents](pre_encoding.md) instead of raw audio 66 | - Required when training with [pre-encoded datasets](datasets.md#pre-encoded-datasets) 67 | - Optional. 
Default: `false` 68 | 69 | ## Example config 70 | ```json 71 | "training": { 72 | "use_ema": true, 73 | "log_loss_info": false, 74 | "optimizer_configs": { 75 | "diffusion": { 76 | "optimizer": { 77 | "type": "AdamW", 78 | "config": { 79 | "lr": 5e-5, 80 | "betas": [0.9, 0.999], 81 | "weight_decay": 1e-3 82 | } 83 | }, 84 | "scheduler": { 85 | "type": "InverseLR", 86 | "config": { 87 | "inv_gamma": 1000000, 88 | "power": 0.5, 89 | "warmup": 0.99 90 | } 91 | } 92 | } 93 | }, 94 | "demo": { ... } 95 | } 96 | ``` 97 | 98 | # Demo configs 99 | The `demo` config in the diffusion model training config should have the following properties: 100 | - `demo_every` 101 | - How many training steps between demos 102 | - `demo_steps` 103 | - Number of diffusion timesteps to run for the demos 104 | - `num_demos` 105 | - This is the number of examples to generate in each demo 106 | - `demo_cond` 107 | - For conditioned diffusion models, this is the conditioning metadata to provide to each example, provided as a list 108 | - NOTE: List must be the same length as `num_demos` 109 | - `demo_cfg_scales` 110 | - For conditioned diffusion models, this provides a list of classifier-free guidance (CFG) scales to render during the demos. This can be helpful to get an idea of how the model responds to different conditioning strengths as training continues. 111 | 112 | ## Example config 113 | ```json 114 | "demo": { 115 | "demo_every": 2000, 116 | "demo_steps": 250, 117 | "num_demos": 4, 118 | "demo_cond": [ 119 | {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 80}, 120 | {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 250}, 121 | {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180}, 122 | {"prompt": "A grand orchestral arrangement", "seconds_start": 0, "seconds_total": 190} 123 | ], 124 | "demo_cfg_scales": [3, 6, 9] 125 | } 126 | ``` 127 | 128 | # Model types 129 | 130 | A variety of different model types can be used as the underlying backbone for a diffusion model. At the moment, this includes variants of U-Net and Transformer models. 131 | 132 | ## Diffusion Transformers (DiT) 133 | 134 | Transformers tend to consistently outperform U-Nets in terms of model quality, but are much more memory- and compute-intensive and work best on shorter sequences such as latent encodings of audio. 135 | 136 | ### Continuous Transformer 137 | 138 | This is our custom implementation of a transformer model, based on the `x-transformers` implementation, but with efficiency improvements such as fused QKV layers, and Flash Attention 2 support. 139 | 140 | ### `x-transformers` 141 | 142 | This model type uses the `ContinuousTransformerWrapper` class from the https://github.com/lucidrains/x-transformers repository as the diffusion transformer backbone. 143 | 144 | `x-transformers` is a great baseline transformer implementation with lots of options for various experimental settings. 145 | It's great for testing out experimental features without implementing them yourself, but the implementations might not be fully optimized, and breaking changes may be introduced without much warning. 
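To tie the above together, here is a hypothetical sketch of a `diffusion` block for a `dit`-type model, written as an annotated Python dict (the shipped model configs are plain JSON). The conditioner ids and the backbone parameters under `config` are illustrative assumptions rather than a canonical schema:

```python
# Hypothetical `diffusion` block for a conditioned DiT model.
# The conditioner ids ("prompt", "seconds_start", "seconds_total") must match
# ids defined in the model's conditioning config; the keys under "config" are
# illustrative backbone hyperparameters, not an exact or exhaustive schema.
diffusion_block = {
    "type": "dit",
    # Text and timing features are attended to via cross-attention,
    # concatenated along the sequence dimension
    "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"],
    # Timing embeddings are also concatenated along the channel dimension
    # into a single global conditioning vector
    "global_cond_ids": ["seconds_start", "seconds_total"],
    "config": {
        "embed_dim": 1536,  # illustrative values
        "depth": 24,
        "num_heads": 24,
    },
}
```

In a full model config, this block sits under the `model` object alongside the `pretransform` and `conditioning` configurations described above.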
146 | 147 | ## Diffusion U-Net 148 | 149 | U-Nets use a hierarchical architecture to gradually downsample the input data before more heavy processing is performed, then upsample the data again, using skip connections to pass data across the downsampling "valley" (the "U" in the name) to the upsampling layer at the same resolution. 150 | 151 | ### audio-diffusion-pytorch U-Net (ADP) 152 | 153 | This model type uses a modified implementation of the `UNetCFG1D` class from version 0.0.94 of the `https://github.com/archinetai/audio-diffusion-pytorch` repo, with added Flash Attention support. 154 | 155 | ### Dance Diffusion U-Net 156 | 157 | This is a reimplementation of the U-Net used in [Dance Diffusion](https://github.com/Harmonai-org/sample-generator). It has minimal conditioning support, only really supporting global conditioning. Mostly used for unconditional diffusion models. -------------------------------------------------------------------------------- /stable_audio_tools/models/local_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | from .blocks import AdaRMSNorm 7 | from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm 8 | 9 | def checkpoint(function, *args, **kwargs): 10 | kwargs.setdefault("use_reentrant", False) 11 | return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) 12 | 13 | # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py 14 | class ContinuousLocalTransformer(nn.Module): 15 | def __init__( 16 | self, 17 | *, 18 | dim, 19 | depth, 20 | dim_in = None, 21 | dim_out = None, 22 | causal = False, 23 | local_attn_window_size = 64, 24 | heads = 8, 25 | ff_mult = 2, 26 | cond_dim = 0, 27 | cross_attn_cond_dim = 0, 28 | **kwargs 29 | ): 30 | super().__init__() 31 | 32 | dim_head = dim//heads 33 | 34 | self.layers = nn.ModuleList([]) 35 | 36 | self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity() 37 | 38 | self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity() 39 | 40 | self.local_attn_window_size = local_attn_window_size 41 | 42 | self.cond_dim = cond_dim 43 | 44 | self.cross_attn_cond_dim = cross_attn_cond_dim 45 | 46 | self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) 47 | 48 | for _ in range(depth): 49 | 50 | self.layers.append(nn.ModuleList([ 51 | AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), 52 | Attention( 53 | dim=dim, 54 | dim_heads=dim_head, 55 | causal=causal, 56 | zero_init_output=True, 57 | natten_kernel_size=local_attn_window_size, 58 | ), 59 | Attention( 60 | dim=dim, 61 | dim_heads=dim_head, 62 | dim_context = cross_attn_cond_dim, 63 | zero_init_output=True 64 | ) if self.cross_attn_cond_dim > 0 else nn.Identity(), 65 | AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), 66 | FeedForward(dim = dim, mult = ff_mult, no_bias=True) 67 | ])) 68 | 69 | def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): 70 | 71 | x = checkpoint(self.project_in, x) 72 | 73 | if prepend_cond is not None: 74 | x = torch.cat([prepend_cond, x], dim=1) 75 | 76 | pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) 77 | 78 | for attn_norm, attn, xattn, ff_norm, ff in self.layers: 79 | 80 | residual = x 81 | if cond is not None: 82 | x = checkpoint(attn_norm, x, cond) 83 | else: 84 | x = 
checkpoint(attn_norm, x) 85 | 86 | x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual 87 | 88 | if cross_attn_cond is not None: 89 | x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x 90 | 91 | residual = x 92 | 93 | if cond is not None: 94 | x = checkpoint(ff_norm, x, cond) 95 | else: 96 | x = checkpoint(ff_norm, x) 97 | 98 | x = checkpoint(ff, x) + residual 99 | 100 | return checkpoint(self.project_out, x) 101 | 102 | class TransformerDownsampleBlock1D(nn.Module): 103 | def __init__( 104 | self, 105 | in_channels, 106 | embed_dim = 768, 107 | depth = 3, 108 | heads = 12, 109 | downsample_ratio = 2, 110 | local_attn_window_size = 64, 111 | **kwargs 112 | ): 113 | super().__init__() 114 | 115 | self.downsample_ratio = downsample_ratio 116 | 117 | self.transformer = ContinuousLocalTransformer( 118 | dim=embed_dim, 119 | depth=depth, 120 | heads=heads, 121 | local_attn_window_size=local_attn_window_size, 122 | **kwargs 123 | ) 124 | 125 | self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() 126 | 127 | self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) 128 | 129 | 130 | def forward(self, x): 131 | 132 | x = checkpoint(self.project_in, x) 133 | 134 | # Compute 135 | x = self.transformer(x) 136 | 137 | # Trade sequence length for channels 138 | x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio) 139 | 140 | # Project back to embed dim 141 | x = checkpoint(self.project_down, x) 142 | 143 | return x 144 | 145 | class TransformerUpsampleBlock1D(nn.Module): 146 | def __init__( 147 | self, 148 | in_channels, 149 | embed_dim, 150 | depth = 3, 151 | heads = 12, 152 | upsample_ratio = 2, 153 | local_attn_window_size = 64, 154 | **kwargs 155 | ): 156 | super().__init__() 157 | 158 | self.upsample_ratio = upsample_ratio 159 | 160 | self.transformer = ContinuousLocalTransformer( 161 | dim=embed_dim, 162 | depth=depth, 163 | heads=heads, 164 | local_attn_window_size = local_attn_window_size, 165 | **kwargs 166 | ) 167 | 168 | self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() 169 | 170 | self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) 171 | 172 | def forward(self, x): 173 | 174 | # Project to embed dim 175 | x = checkpoint(self.project_in, x) 176 | 177 | # Project to increase channel dim 178 | x = checkpoint(self.project_up, x) 179 | 180 | # Trade channels for sequence length 181 | x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) 182 | 183 | # Compute 184 | x = self.transformer(x) 185 | 186 | return x 187 | 188 | 189 | class TransformerEncoder1D(nn.Module): 190 | def __init__( 191 | self, 192 | in_channels, 193 | out_channels, 194 | embed_dims = [96, 192, 384, 768], 195 | heads = [12, 12, 12, 12], 196 | depths = [3, 3, 3, 3], 197 | ratios = [2, 2, 2, 2], 198 | local_attn_window_size = 64, 199 | **kwargs 200 | ): 201 | super().__init__() 202 | 203 | layers = [] 204 | 205 | for layer in range(len(depths)): 206 | prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] 207 | 208 | layers.append( 209 | TransformerDownsampleBlock1D( 210 | in_channels = prev_dim, 211 | embed_dim = embed_dims[layer], 212 | heads = heads[layer], 213 | depth = depths[layer], 214 | downsample_ratio = ratios[layer], 215 | local_attn_window_size = local_attn_window_size, 216 | **kwargs 217 | ) 218 | ) 219 | 220 | self.layers = 
nn.Sequential(*layers) 221 | 222 | self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) 223 | self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) 224 | 225 | def forward(self, x): 226 | x = rearrange(x, "b c n -> b n c") 227 | x = checkpoint(self.project_in, x) 228 | x = self.layers(x) 229 | x = checkpoint(self.project_out, x) 230 | x = rearrange(x, "b n c -> b c n") 231 | 232 | return x 233 | 234 | 235 | class TransformerDecoder1D(nn.Module): 236 | def __init__( 237 | self, 238 | in_channels, 239 | out_channels, 240 | embed_dims = [768, 384, 192, 96], 241 | heads = [12, 12, 12, 12], 242 | depths = [3, 3, 3, 3], 243 | ratios = [2, 2, 2, 2], 244 | local_attn_window_size = 64, 245 | **kwargs 246 | ): 247 | 248 | super().__init__() 249 | 250 | layers = [] 251 | 252 | for layer in range(len(depths)): 253 | prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] 254 | 255 | layers.append( 256 | TransformerUpsampleBlock1D( 257 | in_channels = prev_dim, 258 | embed_dim = embed_dims[layer], 259 | heads = heads[layer], 260 | depth = depths[layer], 261 | upsample_ratio = ratios[layer], 262 | local_attn_window_size = local_attn_window_size, 263 | **kwargs 264 | ) 265 | ) 266 | 267 | self.layers = nn.Sequential(*layers) 268 | 269 | self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) 270 | self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) 271 | 272 | def forward(self, x): 273 | x = rearrange(x, "b c n -> b n c") 274 | x = checkpoint(self.project_in, x) 275 | x = self.layers(x) 276 | x = checkpoint(self.project_out, x) 277 | x = rearrange(x, "b n c -> b c n") 278 | return x -------------------------------------------------------------------------------- /stable_audio_tools/models/pretransforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | from torchaudio.transforms import Resample 5 | 6 | class Pretransform(nn.Module): 7 | def __init__(self, enable_grad, io_channels, is_discrete): 8 | super().__init__() 9 | 10 | self.is_discrete = is_discrete 11 | self.io_channels = io_channels 12 | self.encoded_channels = None 13 | self.downsampling_ratio = None 14 | 15 | self.enable_grad = enable_grad 16 | 17 | def encode(self, x): 18 | raise NotImplementedError 19 | 20 | def decode(self, z): 21 | raise NotImplementedError 22 | 23 | def tokenize(self, x): 24 | raise NotImplementedError 25 | 26 | def decode_tokens(self, tokens): 27 | raise NotImplementedError 28 | 29 | class AutoencoderPretransform(Pretransform): 30 | def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): 31 | super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) 32 | self.model = model 33 | self.model.requires_grad_(False).eval() 34 | self.scale=scale 35 | self.downsampling_ratio = model.downsampling_ratio 36 | self.io_channels = model.io_channels 37 | self.sample_rate = model.sample_rate 38 | 39 | self.model_half = model_half 40 | self.iterate_batch = iterate_batch 41 | 42 | self.encoded_channels = model.latent_dim 43 | 44 | self.chunked = chunked 45 | self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None 46 | self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None 47 | 48 | if 
self.model_half: 49 | self.model.half() 50 | 51 | def encode(self, x, **kwargs): 52 | 53 | if self.model_half: 54 | x = x.half() 55 | 56 | encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) 57 | 58 | if self.model_half: 59 | encoded = encoded.float() 60 | 61 | return encoded / self.scale 62 | 63 | def decode(self, z, **kwargs): 64 | z = z * self.scale 65 | 66 | if self.model_half: 67 | z = z.half() 68 | 69 | decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) 70 | 71 | if self.model_half: 72 | decoded = decoded.float() 73 | 74 | return decoded 75 | 76 | def tokenize(self, x, **kwargs): 77 | assert self.model.is_discrete, "Cannot tokenize with a continuous model" 78 | 79 | _, info = self.model.encode(x, return_info = True, **kwargs) 80 | 81 | return info[self.model.bottleneck.tokens_id] 82 | 83 | def decode_tokens(self, tokens, **kwargs): 84 | assert self.model.is_discrete, "Cannot decode tokens with a continuous model" 85 | 86 | return self.model.decode_tokens(tokens, **kwargs) 87 | 88 | def load_state_dict(self, state_dict, strict=True): 89 | self.model.load_state_dict(state_dict, strict=strict) 90 | 91 | class WaveletPretransform(Pretransform): 92 | def __init__(self, channels, levels, wavelet, **kwargs): 93 | super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) 94 | 95 | from .wavelets import WaveletEncode1d, WaveletDecode1d 96 | 97 | self.encoder = WaveletEncode1d(channels, levels, wavelet) 98 | self.decoder = WaveletDecode1d(channels, levels, wavelet) 99 | 100 | self.downsampling_ratio = 2 ** levels 101 | self.io_channels = channels 102 | self.encoded_channels = channels * self.downsampling_ratio 103 | 104 | def encode(self, x): 105 | x = self.encoder(x) 106 | return x 107 | 108 | def decode(self, z): 109 | return self.decoder(z) 110 | 111 | class PatchedPretransform(Pretransform): 112 | def __init__(self, channels, patch_size, oversampling = 1, **kwargs): 113 | super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) 114 | self.channels = channels 115 | self.patch_size = patch_size 116 | self.oversampling = oversampling 117 | 118 | self.downsampling_ratio = patch_size 119 | self.io_channels = channels 120 | self.encoded_channels = channels * patch_size 121 | 122 | if self.oversampling > 1: 123 | self.input_upsampler = Resample(1, self.oversampling) 124 | self.output_downsampler = Resample(self.oversampling, 1) 125 | 126 | def _pad(self, x): 127 | seq_len = x.shape[-1] 128 | pad_len = (self.patch_size - (seq_len % self.patch_size)) % self.patch_size 129 | if pad_len > 0: 130 | x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1) 131 | return x 132 | 133 | def encode(self, x): 134 | if self.oversampling > 1: 135 | x = self.input_upsampler(x) 136 | x = self._pad(x) 137 | x = rearrange(x, "b c (l h) -> b (c h) l", h=self.patch_size) 138 | return x 139 | def decode(self, z): 140 | x = rearrange(z, "b (c h) l -> b c (l h)", h=self.patch_size) 141 | if self.oversampling > 1: 142 | x = self.output_downsampler(x) 143 | return x 144 | 145 | class PQMFPretransform(Pretransform): 146 | def __init__(self, attenuation=100, num_bands=16, channels = 1): 147 | # TODO: Fix PQMF to take in in-channels 148 | super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) 149 | from .pqmf import PQMF 150 | self.pqmf = PQMF(attenuation, num_bands) 151 | 152 | 153 | def encode(self, x): 154 | # x is (Batch x Channels x Time) 155 | x = 
self.pqmf.forward(x) 156 | # pqmf.forward returns (Batch x Channels x Bands x Time) 157 | # but Pretransform needs Batch x Channels x Time 158 | # so concatenate channels and bands into one axis 159 | return rearrange(x, "b c n t -> b (c n) t") 160 | 161 | def decode(self, x): 162 | # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) 163 | x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) 164 | # returns (Batch x Channels x Time) 165 | return self.pqmf.inverse(x) 166 | 167 | class PretrainedDACPretransform(Pretransform): 168 | def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): 169 | super().__init__(enable_grad=False, io_channels=1, is_discrete=True) 170 | 171 | import dac 172 | 173 | model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) 174 | 175 | self.model = dac.DAC.load(model_path) 176 | 177 | self.quantize_on_decode = quantize_on_decode 178 | 179 | if model_type == "44khz": 180 | self.downsampling_ratio = 512 181 | else: 182 | self.downsampling_ratio = 320 183 | 184 | self.io_channels = 1 185 | 186 | self.scale = scale 187 | 188 | self.chunked = chunked 189 | 190 | self.encoded_channels = self.model.latent_dim 191 | 192 | self.num_quantizers = self.model.n_codebooks 193 | 194 | self.codebook_size = self.model.codebook_size 195 | 196 | def encode(self, x): 197 | 198 | latents = self.model.encoder(x) 199 | 200 | if self.quantize_on_decode: 201 | output = latents 202 | else: 203 | z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) 204 | output = z 205 | 206 | if self.scale != 1.0: 207 | output = output / self.scale 208 | 209 | return output 210 | 211 | def decode(self, z): 212 | 213 | if self.scale != 1.0: 214 | z = z * self.scale 215 | 216 | if self.quantize_on_decode: 217 | z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) 218 | 219 | return self.model.decode(z) 220 | 221 | def tokenize(self, x): 222 | return self.model.encode(x)[1] 223 | 224 | def decode_tokens(self, tokens): 225 | latents = self.model.quantizer.from_codes(tokens) 226 | return self.model.decode(latents) 227 | 228 | class AudiocraftCompressionPretransform(Pretransform): 229 | def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): 230 | super().__init__(enable_grad=False, io_channels=1, is_discrete=True) 231 | 232 | try: 233 | from audiocraft.models import CompressionModel 234 | except ImportError: 235 | raise ImportError("Audiocraft is not installed. 
Please install audiocraft to use Audiocraft models.") 236 | 237 | self.model = CompressionModel.get_pretrained(model_type) 238 | 239 | self.quantize_on_decode = quantize_on_decode 240 | 241 | self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) 242 | 243 | self.sample_rate = self.model.sample_rate 244 | 245 | self.io_channels = self.model.channels 246 | 247 | self.scale = scale 248 | 249 | #self.encoded_channels = self.model.latent_dim 250 | 251 | self.num_quantizers = self.model.num_codebooks 252 | 253 | self.codebook_size = self.model.cardinality 254 | 255 | self.model.to(torch.float16).eval().requires_grad_(False) 256 | 257 | def encode(self, x): 258 | 259 | assert False, "Audiocraft compression models do not support continuous encoding" 260 | 261 | # latents = self.model.encoder(x) 262 | 263 | # if self.quantize_on_decode: 264 | # output = latents 265 | # else: 266 | # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) 267 | # output = z 268 | 269 | # if self.scale != 1.0: 270 | # output = output / self.scale 271 | 272 | # return output 273 | 274 | def decode(self, z): 275 | 276 | assert False, "Audiocraft compression models do not support continuous decoding" 277 | 278 | # if self.scale != 1.0: 279 | # z = z * self.scale 280 | 281 | # if self.quantize_on_decode: 282 | # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) 283 | 284 | # return self.model.decode(z) 285 | 286 | def tokenize(self, x): 287 | with torch.amp.autocast("cuda", enabled=False): 288 | return self.model.encode(x.to(torch.float16))[0] 289 | 290 | def decode_tokens(self, tokens): 291 | with torch.amp.autocast("cuda", enabled=False): 292 | return self.model.decode(tokens) 293 | -------------------------------------------------------------------------------- /stable_audio_tools/training/lm.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import sys, gc 3 | import random 4 | import torch 5 | import torchaudio 6 | import typing as tp 7 | import wandb 8 | 9 | from ema_pytorch import EMA 10 | from einops import rearrange 11 | from safetensors.torch import save_file 12 | from torch import optim 13 | from torch.nn import functional as F 14 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 15 | 16 | from ..interface.aeiou import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image 17 | from ..models.lm import AudioLanguageModelWrapper 18 | from .utils import create_optimizer_from_config, create_scheduler_from_config, log_audio, log_image 19 | 20 | class AudioLanguageModelTrainingWrapper(pl.LightningModule): 21 | def __init__( 22 | self, 23 | model: AudioLanguageModelWrapper, 24 | lr = 1e-4, 25 | use_ema=False, 26 | ema_copy=None, 27 | optimizer_configs: dict = None, 28 | pre_encoded=False 29 | ): 30 | super().__init__() 31 | 32 | self.model = model 33 | 34 | self.model.pretransform.requires_grad_(False) 35 | 36 | self.model_ema = None 37 | if use_ema: 38 | self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) 39 | 40 | assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" 41 | 42 | if optimizer_configs is None: 43 | optimizer_configs = { 44 | "lm": { 45 | "optimizer": { 46 | "type": "AdamW", 47 | "config": { 48 | "lr": lr, 49 | "betas": (0.9, 0.95), 50 | "weight_decay": 0.1 51 | } 52 | } 53 | } 54 | } 55 | else: 56 | if lr is not None: 57 | 
print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") 58 | 59 | self.optimizer_configs = optimizer_configs 60 | 61 | self.pre_encoded = pre_encoded 62 | 63 | def configure_optimizers(self): 64 | lm_opt_config = self.optimizer_configs['lm'] 65 | opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) 66 | 67 | if "scheduler" in lm_opt_config: 68 | sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) 69 | sched_lm_config = { 70 | "scheduler": sched_lm, 71 | "interval": "step" 72 | } 73 | return [opt_lm], [sched_lm_config] 74 | 75 | return [opt_lm] 76 | 77 | # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license 78 | # License can be found in LICENSES/LICENSE_META.txt 79 | 80 | def _compute_cross_entropy( 81 | self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor 82 | ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: 83 | """Compute cross entropy between multi-codebook targets and model's logits. 84 | The cross entropy is computed per codebook to provide codebook-level cross entropy. 85 | Valid timesteps for each of the codebook are pulled from the mask, where invalid 86 | timesteps are set to 0. 87 | 88 | Args: 89 | logits (torch.Tensor): Model's logits of shape [B, K, T, card]. 90 | targets (torch.Tensor): Target codes, of shape [B, K, T]. 91 | mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. 92 | Returns: 93 | ce (torch.Tensor): Cross entropy averaged over the codebooks 94 | ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). 95 | """ 96 | B, K, T = targets.shape 97 | assert logits.shape[:-1] == targets.shape 98 | assert mask.shape == targets.shape 99 | ce = torch.zeros([], device=targets.device) 100 | ce_per_codebook: tp.List[torch.Tensor] = [] 101 | for k in range(K): 102 | logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] 103 | targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] 104 | mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] 105 | ce_targets = targets_k[mask_k] 106 | ce_logits = logits_k[mask_k] 107 | q_ce = F.cross_entropy(ce_logits, ce_targets) 108 | ce += q_ce 109 | ce_per_codebook.append(q_ce.detach()) 110 | # average cross entropy across codebooks 111 | ce = ce / K 112 | return ce, ce_per_codebook 113 | 114 | def training_step(self, batch, batch_idx): 115 | reals, metadata = batch 116 | 117 | if reals.ndim == 4 and reals.shape[0] == 1: 118 | reals = reals[0] 119 | 120 | if not self.pre_encoded: 121 | codes = self.model.pretransform.tokenize(reals) 122 | else: 123 | codes = reals 124 | 125 | padding_masks = [] 126 | for md in metadata: 127 | if md["padding_mask"].ndim == 1: 128 | padding_masks.append(md["padding_mask"]) 129 | else: 130 | padding_masks.append(md["padding_mask"][0]) 131 | 132 | padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) 133 | 134 | # Interpolate padding masks to the same length as the codes 135 | padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool() 136 | 137 | condition_tensors = None 138 | 139 | # If the model is conditioned, get the conditioning tensors 140 | if self.model.conditioner is not None: 141 | condition_tensors = self.model.conditioner(metadata, self.device) 142 | 143 | lm_output = 
self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) 144 | 145 | logits = lm_output.logits # [b, k, t, c] 146 | logits_mask = lm_output.mask # [b, k, t] 147 | 148 | logits_mask = logits_mask & padding_masks 149 | 150 | cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask) 151 | 152 | loss = cross_entropy 153 | 154 | log_dict = { 155 | 'train/loss': loss.detach(), 156 | 'train/cross_entropy': cross_entropy.detach(), 157 | 'train/perplexity': torch.exp(cross_entropy).detach(), 158 | 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] 159 | } 160 | 161 | for k, ce_q in enumerate(cross_entropy_per_codebook): 162 | log_dict[f'cross_entropy_q{k + 1}'] = ce_q 163 | log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) 164 | 165 | self.log_dict(log_dict, prog_bar=True, on_step=True) 166 | return loss 167 | 168 | def on_before_zero_grad(self, *args, **kwargs): 169 | if self.model_ema is not None: 170 | self.model_ema.update() 171 | 172 | def export_model(self, path, use_safetensors=False): 173 | 174 | model = self.model_ema.ema_model if self.model_ema is not None else self.model 175 | 176 | if use_safetensors: 177 | save_file(model.state_dict(), path) 178 | else: 179 | torch.save({"state_dict": model.state_dict()}, path) 180 | 181 | 182 | class AudioLanguageModelDemoCallback(pl.Callback): 183 | def __init__(self, 184 | demo_every=2000, 185 | num_demos=8, 186 | sample_size=65536, 187 | sample_rate=48000, 188 | demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, 189 | demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], 190 | **kwargs 191 | ): 192 | super().__init__() 193 | 194 | self.demo_every = demo_every 195 | self.num_demos = num_demos 196 | self.demo_samples = sample_size 197 | self.sample_rate = sample_rate 198 | self.last_demo_step = -1 199 | self.demo_conditioning = demo_conditioning 200 | self.demo_cfg_scales = demo_cfg_scales 201 | 202 | @rank_zero_only 203 | @torch.no_grad() 204 | def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): 205 | 206 | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: 207 | return 208 | 209 | module.eval() 210 | 211 | print(f"Generating demo") 212 | self.last_demo_step = trainer.global_step 213 | 214 | demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio 215 | 216 | #demo_reals = batch[0][:self.num_demos] 217 | 218 | # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: 219 | # demo_reals = demo_reals[0] 220 | 221 | #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) 222 | 223 | ##Limit to first 50 tokens 224 | #demo_reals_tokens = demo_reals_tokens[:, :, :50] 225 | 226 | try: 227 | print("Getting conditioning") 228 | 229 | for cfg_scale in self.demo_cfg_scales: 230 | 231 | model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model 232 | 233 | print(f"Generating demo for cfg scale {cfg_scale}") 234 | fakes = model.generate_audio( 235 | batch_size=self.num_demos, 236 | max_gen_len=demo_length_tokens, 237 | conditioning=self.demo_conditioning, 238 | #init_data = demo_reals_tokens, 239 | cfg_scale=cfg_scale, 240 | temp=1.0, 241 | top_p=0.95 242 | ) 243 | 244 | # Put the demos together 245 | fakes = rearrange(fakes, 'b d n -> d (b n)') 246 | 247 | filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' 248 | fakes = fakes / fakes.abs().max() 249 | fakes = 
fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() 250 | torchaudio.save(filename, fakes, self.sample_rate) 251 | 252 | log_audio( 253 | trainer.logger, f'demo_cfg_{cfg_scale}', filename, 254 | sample_rate=self.sample_rate, caption='Reconstructed') 255 | log_image( 256 | trainer.logger, f'demo_melspec_left_cfg_{cfg_scale}', 257 | audio_spectrogram_image(fakes)) 258 | 259 | except Exception as e: 260 | raise e 261 | finally: 262 | gc.collect() 263 | torch.cuda.empty_cache() 264 | module.train() 265 | -------------------------------------------------------------------------------- /LICENSES/LICENSE_AEIOU.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. --------------------------------------------------------------------------------