├── .gitignore ├── LICENSE ├── README.md ├── defaults.ini ├── example ├── V2A_sample-1.mp4 ├── V2A_sample-2.mp4 ├── V2A_sample-3.mp4 ├── V2M_sample-1.mp4 ├── V2M_sample-2.mp4 └── V2M_sample-3.mp4 ├── pyproject.toml ├── run_gradio.py ├── setup.py └── stable_audio_tools ├── __init__.py ├── data ├── __init__.py ├── dataset.py └── utils.py ├── inference ├── __init__.py ├── generation.py ├── sampling.py └── utils.py ├── interface ├── __init__.py └── gradio.py ├── models ├── __init__.py ├── adp.py ├── autoencoders.py ├── blocks.py ├── bottleneck.py ├── codebook_patterns.py ├── conditioners.py ├── diffusion.py ├── discriminators.py ├── dit.py ├── factory.py ├── lm.py ├── local_attention.py ├── pqmf.py ├── pretrained.py ├── pretransforms.py ├── temptransformer.py ├── transformer.py ├── utils.py └── wavelets.py └── training ├── __init__.py ├── autoencoders.py ├── diffusion.py ├── factory.py ├── lm.py ├── losses ├── __init__.py ├── auraloss.py └── losses.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | 153 | 154 | # PyCharm 155 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 156 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 157 | # and can be added to the global gitignore or merged into this file. For a more nuclear 158 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 159 | #.idea/ 160 | 161 | *.ckpt 162 | *.wav 163 | # *.mp4 164 | *.mp3 165 | *.jsonl 166 | wandb/* 167 | 168 | 169 | 170 | 171 | model/ 172 | logs/ 173 | log/ 174 | saved_ckpt/ 175 | wandb/ 176 | demo_result/ 177 | model/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🎧 AudioX: Diffusion Transformer for Anything-to-Audio Generation 2 | 3 | [](https://arxiv.org/abs/2503.10522) 4 | [](https://zeyuet.github.io/AudioX/) 5 | [](https://huggingface.co/HKUSTAudio/AudioX) 6 | [](https://huggingface.co/spaces/Zeyue7/AudioX) 7 | 8 | --- 9 | 10 | **This is the official repository for "[AudioX: Diffusion Transformer for Anything-to-Audio Generation](https://arxiv.org/pdf/2503.10522)".** 11 | 12 | 13 | ## 📺 Demo Video 14 | 15 | https://github.com/user-attachments/assets/0d8dd927-ff0f-4b35-ab1f-b3c3915017be 16 | 17 | --- 18 | 19 | 20 | ## ✨ Abstract 21 | 22 | Audio and music generation have emerged as crucial tasks in many applications, yet existing approaches face significant limitations: they operate in isolation without unified capabilities across modalities, suffer from scarce high-quality, multi-modal training data, and struggle to effectively integrate diverse inputs. In this work, we propose AudioX, a unified Diffusion Transformer model for Anything-to-Audio and Music Generation. Unlike previous domain-specific models, AudioX can generate both general audio and music with high quality, while offering flexible natural language control and seamless processing of various modalities including text, video, image, music, and audio. Its key innovation is a multi-modal masked training strategy that masks inputs across modalities and forces the model to learn from masked inputs, yielding robust and unified cross-modal representations. To address data scarcity, we curate two comprehensive datasets: vggsound-caps with 190K audio captions based on the VGGSound dataset, and V2M-caps with 6 million music captions derived from the V2M dataset. 
Extensive experiments demonstrate that AudioX not only matches or outperforms state-of-the-art specialized models, but also offers remarkable versatility in handling diverse input modalities and generation tasks within a unified architecture. 23 | 24 | 25 | ## ✨ Teaser 26 | 27 |
[Teaser figure]
(a) Overview of AudioX, illustrating its capabilities across various tasks. (b) Radar chart comparing the performance of different methods across multiple benchmarks. AudioX demonstrates superior Inception Scores (IS) across a diverse set of datasets in audio and music generation tasks.
31 | 32 | 33 | ## ✨ Method 34 | 35 |
[Framework figure]
Overview of the AudioX Framework.
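As described in the abstract, the framework's key training idea is a multi-modal masked strategy: conditioning inputs from different modalities are randomly masked during training so the model must learn from whatever remains. The snippet below is only a minimal sketch of that idea; the dictionary keys, tensor shapes, and drop probability are illustrative assumptions, not the repository's actual implementation.

```python
import torch

def mask_modalities(cond_embs, p_drop=0.5):
    # Randomly zero out each conditioning embedding (video / text / audio) so the
    # model is forced to build cross-modal representations from what is left.
    # p_drop is an illustrative value, not the one used by AudioX.
    masked = {}
    for name, emb in cond_embs.items():
        keep = bool(torch.rand(()) >= p_drop)
        masked[name] = emb if keep else torch.zeros_like(emb)
    return masked

# Hypothetical embedding shapes: [batch, tokens, dim]
cond_embs = {
    "video_prompt": torch.randn(1, 20, 768),
    "text_prompt": torch.randn(1, 16, 768),
    "audio_prompt": torch.randn(1, 215, 64),
}
masked_conds = mask_modalities(cond_embs)
```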
39 | 40 | 41 | 42 | ## Code 43 | 44 | 45 | ### 🛠️ Environment Setup 46 | 47 | ```bash 48 | git clone https://github.com/ZeyueT/AudioX.git 49 | cd AudioX 50 | conda create -n AudioX python=3.8.20 51 | conda activate AudioX 52 | pip install git+https://github.com/ZeyueT/AudioX.git 53 | conda install -c conda-forge ffmpeg libsndfile 54 | 55 | ``` 56 | 57 | ## 🪄 Pretrained Checkpoints 58 | 59 | Download the pretrained model from 🤗 [AudioX on Hugging Face](https://huggingface.co/HKUSTAudio/AudioX): 60 | 61 | ```bash 62 | mkdir -p model 63 | wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/model.ckpt -O model/model.ckpt 64 | wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/config.json -O model/config.json 65 | ``` 66 | 67 | ### 🤗 Gradio Demo 68 | 69 | To launch the Gradio demo locally, run: 70 | 71 | ```bash 72 | python3 run_gradio.py \ 73 | --model-config model/config.json \ 74 | --share 75 | ``` 76 | 77 | 78 | ### 🎯 Prompt Configuration Examples 79 | 80 | | Task | `video_path` | `text_prompt` | `audio_path` | 81 | |:---------------------|:-------------------|:----------------------------------------------|:-------------| 82 | | Text-to-Audio (T2A) | `None` | `"Typing on a keyboard"` | `None` | 83 | | Text-to-Music (T2M) | `None` | `"A music with piano and violin"` | `None` | 84 | | Video-to-Audio (V2A) | `"video_path.mp4"` | `"Generate general audio for the video"` | `None` | 85 | | Video-to-Music (V2M) | `"video_path.mp4"` | `"Generate music for the video"` | `None` | 86 | | TV-to-Audio (TV2A) | `"video_path.mp4"` | `"Ocean waves crashing with people laughing"` | `None` | 87 | | TV-to-Music (TV2M) | `"video_path.mp4"` | `"Generate music with piano instrument"` | `None` | 88 | 89 | ### 🖥️ Script Inference 90 | 91 | ```python 92 | import torch 93 | import torchaudio 94 | from einops import rearrange 95 | from stable_audio_tools import get_pretrained_model 96 | from stable_audio_tools.inference.generation import generate_diffusion_cond 97 | from stable_audio_tools.data.utils import read_video, merge_video_audio 98 | from stable_audio_tools.data.utils import load_and_process_audio 99 | import os 100 | 101 | device = "cuda" if torch.cuda.is_available() else "cpu" 102 | 103 | # Download model 104 | model, model_config = get_pretrained_model("HKUSTAudio/AudioX") 105 | sample_rate = model_config["sample_rate"] 106 | sample_size = model_config["sample_size"] 107 | target_fps = model_config["video_fps"] 108 | seconds_start = 0 109 | seconds_total = 10 110 | 111 | model = model.to(device) 112 | 113 | # for video-to-music generation 114 | video_path = "example/V2M_sample-1.mp4" 115 | text_prompt = "Generate music for the video" 116 | audio_path = None 117 | 118 | video_tensor = read_video(video_path, seek_time=0, duration=seconds_total, target_fps=target_fps) 119 | audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total) 120 | 121 | conditioning = [{ 122 | "video_prompt": [video_tensor.unsqueeze(0)], 123 | "text_prompt": text_prompt, 124 | "audio_prompt": audio_tensor.unsqueeze(0), 125 | "seconds_start": seconds_start, 126 | "seconds_total": seconds_total 127 | }] 128 | 129 | # Generate stereo audio 130 | output = generate_diffusion_cond( 131 | model, 132 | steps=250, 133 | cfg_scale=7, 134 | conditioning=conditioning, 135 | sample_size=sample_size, 136 | sigma_min=0.3, 137 | sigma_max=500, 138 | sampler_type="dpmpp-3m-sde", 139 | device=device 140 | ) 141 | 142 | # Rearrange audio batch to a single sequence 143 | output = rearrange(output, "b d n 
-> d (b n)") 144 | 145 | # Peak normalize, clip, convert to int16, and save to file 146 | output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() 147 | torchaudio.save("output.wav", output, sample_rate) 148 | 149 | if video_path is not None and os.path.exists(video_path): 150 | merge_video_audio(video_path, "output.wav", "output.mp4", 0, seconds_total) 151 | 152 | ``` 153 | 154 | 155 | ## 🚀 Citation 156 | 157 | If you find our work useful, please consider citing: 158 | 159 | ``` 160 | @article{tian2025audiox, 161 | title={AudioX: Diffusion Transformer for Anything-to-Audio Generation}, 162 | author={Tian, Zeyue and Jin, Yizhu and Liu, Zhaoyang and Yuan, Ruibin and Tan, Xu and Chen, Qifeng and Xue, Wei and Guo, Yike}, 163 | journal={arXiv preprint arXiv:2503.10522}, 164 | year={2025} 165 | } 166 | ``` 167 | 168 | ## 📭 Contact 169 | 170 | If you have any comments or questions, feel free to contact Zeyue Tian(ztianad@connect.ust.hk). 171 | 172 | ## License 173 | 174 | Please follow [CC-BY-NC](./LICENSE). 175 | -------------------------------------------------------------------------------- /defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | #name of the run 5 | name = stable_audio_tools 6 | 7 | # the batch size 8 | batch_size = 8 9 | 10 | # number of GPUs to use for training 11 | num_gpus = 1 12 | 13 | # number of nodes to use for training 14 | num_nodes = 1 15 | 16 | # Multi-GPU strategy for PyTorch Lightning 17 | strategy = "" 18 | 19 | # Precision to use for training 20 | precision = "16-mixed" 21 | 22 | # number of CPU workers for the DataLoader 23 | num_workers = 8 24 | 25 | # the random seed 26 | seed = 42 27 | 28 | # Batches for gradient accumulation 29 | accum_batches = 1 30 | 31 | # Number of steps between checkpoints 32 | checkpoint_every = 10000 33 | 34 | # trainer checkpoint file to restart training from 35 | ckpt_path = '' 36 | 37 | # model checkpoint file to start a new training run from 38 | pretrained_ckpt_path = '' 39 | 40 | # Checkpoint path for the pretransform model if needed 41 | pretransform_ckpt_path = '' 42 | 43 | # configuration model specifying model hyperparameters 44 | model_config = '' 45 | 46 | # configuration for datasets 47 | dataset_config = '' 48 | 49 | # directory to save the checkpoints in 50 | save_dir = '' 51 | 52 | # gradient_clip_val passed into PyTorch Lightning Trainer 53 | gradient_clip_val = 0.0 54 | 55 | # remove the weight norm from the pretransform model 56 | remove_pretransform_weight_norm = '' -------------------------------------------------------------------------------- /example/V2A_sample-1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeyueT/AudioX/9f673e2499c1127f9f8c8060e5a68b084f8e63fb/example/V2A_sample-1.mp4 -------------------------------------------------------------------------------- /example/V2A_sample-2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeyueT/AudioX/9f673e2499c1127f9f8c8060e5a68b084f8e63fb/example/V2A_sample-2.mp4 -------------------------------------------------------------------------------- /example/V2A_sample-3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeyueT/AudioX/9f673e2499c1127f9f8c8060e5a68b084f8e63fb/example/V2A_sample-3.mp4 
-------------------------------------------------------------------------------- /example/V2M_sample-1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeyueT/AudioX/9f673e2499c1127f9f8c8060e5a68b084f8e63fb/example/V2M_sample-1.mp4 -------------------------------------------------------------------------------- /example/V2M_sample-2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeyueT/AudioX/9f673e2499c1127f9f8c8060e5a68b084f8e63fb/example/V2M_sample-2.mp4 -------------------------------------------------------------------------------- /example/V2M_sample-3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeyueT/AudioX/9f673e2499c1127f9f8c8060e5a68b084f8e63fb/example/V2M_sample-3.mp4 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /run_gradio.py: -------------------------------------------------------------------------------- 1 | from stable_audio_tools import get_pretrained_model 2 | from stable_audio_tools.interface.gradio import create_ui 3 | import json 4 | 5 | import torch 6 | 7 | def main(args): 8 | torch.manual_seed(42) 9 | 10 | interface = create_ui( 11 | model_config_path = args.model_config, 12 | ckpt_path=args.ckpt_path, 13 | pretrained_name=args.pretrained_name, 14 | pretransform_ckpt_path=args.pretransform_ckpt_path, 15 | model_half=args.model_half 16 | ) 17 | interface.queue() 18 | interface.launch(share=args.share, auth=(args.username, args.password) if args.username is not None else None) 19 | 20 | if __name__ == "__main__": 21 | import argparse 22 | parser = argparse.ArgumentParser(description='Run gradio interface') 23 | parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False) 24 | parser.add_argument('--model-config', type=str, help='Path to model config', required=False) 25 | parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False) 26 | parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False) 27 | parser.add_argument('--share', action='store_true', help='Create a publicly shareable link', required=False) 28 | parser.add_argument('--username', type=str, help='Gradio username', required=False) 29 | parser.add_argument('--password', type=str, help='Gradio password', required=False) 30 | parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False) 31 | args = parser.parse_args() 32 | main(args) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='AudioX', 5 | version='0.1.0', 6 | url='https://github.com/ZeyueT/AudioX.git', 7 | author='AudioX, HKUST', 8 | description='Training and inference tools for generative audio models from AudioX', 9 | packages=find_packages(), 10 | install_requires=[ 11 | 'aeiou', 12 | 'alias-free-torch==0.0.6', 13 | 'auraloss==0.4.0', 14 | 
'descript-audio-codec==1.0.0', 15 | 'decord==0.6.0', 16 | 'einops', 17 | 'einops_exts', 18 | 'ema-pytorch==0.2.3', 19 | 'encodec==0.1.1', 20 | 'gradio==4.44.1', 21 | 'gradio_client==1.3.0', 22 | 'huggingface_hub', 23 | 'importlib-resources==5.12.0', 24 | 'k-diffusion==0.1.1', 25 | 'laion-clap==1.1.6', 26 | 'local-attention==1.8.6', 27 | 'pandas==2.0.2', 28 | 'pedalboard==0.9.14', 29 | 'prefigure==0.0.9', 30 | 'pytorch_lightning==2.4.0', 31 | 'PyWavelets==1.4.1', 32 | 'safetensors', 33 | 'sentencepiece==0.1.99', 34 | 'torch>=2.0.1', 35 | 'torchaudio>=2.0.2', 36 | 'torchmetrics==0.11.4', 37 | 'tqdm', 38 | 'transformers', 39 | 'v-diffusion-pytorch==0.0.2', 40 | 'vector-quantize-pytorch==1.9.14', 41 | 'wandb', 42 | 'webdataset==0.2.48', 43 | 'x-transformers<1.27.0', 44 | ], 45 | 46 | ) -------------------------------------------------------------------------------- /stable_audio_tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.factory import create_model_from_config, create_model_from_config_path 2 | from .models.pretrained import get_pretrained_model -------------------------------------------------------------------------------- /stable_audio_tools/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeyueT/AudioX/9f673e2499c1127f9f8c8060e5a68b084f8e63fb/stable_audio_tools/data/__init__.py -------------------------------------------------------------------------------- /stable_audio_tools/data/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | import torchaudio 5 | from torch import nn 6 | from typing import Tuple 7 | import os 8 | import subprocess as sp 9 | from PIL import Image 10 | from torchvision import transforms 11 | from decord import VideoReader, cpu 12 | import torch.nn.functional as F 13 | class PadCrop(nn.Module): 14 | def __init__(self, n_samples, randomize=True): 15 | super().__init__() 16 | self.n_samples = n_samples 17 | self.randomize = randomize 18 | 19 | def __call__(self, signal): 20 | n, s = signal.shape 21 | start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() 22 | end = start + self.n_samples 23 | output = signal.new_zeros([n, self.n_samples]) 24 | output[:, :min(s, self.n_samples)] = signal[:, start:end] 25 | return output 26 | 27 | 28 | class PadCrop_Normalized_T(nn.Module): 29 | 30 | def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): 31 | super().__init__() 32 | self.n_samples = n_samples 33 | self.sample_rate = sample_rate 34 | self.randomize = randomize 35 | 36 | def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]: 37 | n_channels, n_samples = source.shape 38 | 39 | # Calculate the duration of the audio in seconds 40 | total_duration = n_samples // self.sample_rate 41 | 42 | # If the audio is shorter than the desired length, pad it 43 | upper_bound = max(0, n_samples - self.n_samples) 44 | 45 | # If randomize is False, always start at the beginning of the audio 46 | offset = 0 47 | 48 | if self.randomize and n_samples > self.n_samples: 49 | valid_offsets = [ 50 | i * self.sample_rate for i in range(0, total_duration, 10) 51 | if i * self.sample_rate + self.n_samples <= n_samples and 52 | (total_duration <= 20 or total_duration - i >= 15) 53 | ] 54 | if valid_offsets: 55 | offset = random.choice(valid_offsets) 56 | 57 | # Calculate the start and end
times of the chunk 58 | t_start = offset / (upper_bound + self.n_samples) 59 | t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) 60 | 61 | # Create the chunk 62 | chunk = source.new_zeros([n_channels, self.n_samples]) 63 | 64 | # Copy the audio into the chunk 65 | chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] 66 | 67 | # Calculate the start and end times of the chunk in seconds 68 | seconds_start = math.floor(offset / self.sample_rate) 69 | seconds_total = math.ceil(n_samples / self.sample_rate) 70 | 71 | # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't 72 | padding_mask = torch.zeros([self.n_samples]) 73 | padding_mask[:min(n_samples, self.n_samples)] = 1 74 | 75 | return ( 76 | chunk, 77 | t_start, 78 | t_end, 79 | seconds_start, 80 | seconds_total, 81 | padding_mask 82 | ) 83 | 84 | 85 | class PhaseFlipper(nn.Module): 86 | "Randomly invert the phase of a signal" 87 | def __init__(self, p=0.5): 88 | super().__init__() 89 | self.p = p 90 | def __call__(self, signal): 91 | return -signal if (random.random() < self.p) else signal 92 | 93 | class Mono(nn.Module): 94 | def __call__(self, signal): 95 | return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal 96 | 97 | class Stereo(nn.Module): 98 | def __call__(self, signal): 99 | signal_shape = signal.shape 100 | # Check if it's mono 101 | if len(signal_shape) == 1: # s -> 2, s 102 | signal = signal.unsqueeze(0).repeat(2, 1) 103 | elif len(signal_shape) == 2: 104 | if signal_shape[0] == 1: #1, s -> 2, s 105 | signal = signal.repeat(2, 1) 106 | elif signal_shape[0] > 2: #?, s -> 2,s 107 | signal = signal[:2, :] 108 | 109 | return signal 110 | 111 | 112 | def adjust_video_duration(video_tensor, duration, target_fps): 113 | current_duration = video_tensor.shape[0] 114 | target_duration = duration * target_fps 115 | if current_duration > target_duration: 116 | video_tensor = video_tensor[:target_duration] 117 | elif current_duration < target_duration: 118 | last_frame = video_tensor[-1:] 119 | repeat_times = target_duration - current_duration 120 | video_tensor = torch.cat((video_tensor, last_frame.repeat(repeat_times, 1, 1, 1)), dim=0) 121 | return video_tensor 122 | 123 | def read_video(filepath, seek_time=0., duration=-1, target_fps=2): 124 | if filepath is None: 125 | return torch.zeros((int(duration * target_fps), 3, 224, 224)) 126 | 127 | ext = os.path.splitext(filepath)[1].lower() 128 | if ext in ['.jpg', '.jpeg', '.png']: 129 | resize_transform = transforms.Resize((224, 224)) 130 | image = Image.open(filepath).convert("RGB") 131 | frame = transforms.ToTensor()(image).unsqueeze(0) 132 | frame = resize_transform(frame) 133 | target_frames = int(duration * target_fps) 134 | frame = frame.repeat(int(math.ceil(target_frames / frame.shape[0])), 1, 1, 1)[:target_frames] 135 | assert frame.shape[0] == target_frames, f"The shape of frame is {frame.shape}" 136 | return frame 137 | 138 | vr = VideoReader(filepath, ctx=cpu(0)) 139 | fps = vr.get_avg_fps() 140 | total_frames = len(vr) 141 | 142 | seek_frame = int(seek_time * fps) 143 | if duration > 0: 144 | total_frames_to_read = int(target_fps * duration) 145 | frame_interval = int(math.ceil(fps / target_fps)) 146 | end_frame = min(seek_frame + total_frames_to_read * frame_interval, total_frames) 147 | frame_ids = list(range(seek_frame, end_frame, frame_interval)) 148 | else: 149 | frame_interval = int(math.ceil(fps / target_fps)) 150 | frame_ids = list(range(0, 
total_frames, frame_interval)) 151 | 152 | frames = vr.get_batch(frame_ids).asnumpy() 153 | frames = torch.from_numpy(frames).permute(0, 3, 1, 2) 154 | 155 | if frames.shape[2] != 224 or frames.shape[3] != 224: 156 | resize_transform = transforms.Resize((224, 224)) 157 | frames = resize_transform(frames) 158 | 159 | video_tensor = adjust_video_duration(frames, duration, target_fps) 160 | assert video_tensor.shape[0] == duration * target_fps, f"The shape of video_tensor is {video_tensor.shape}" 161 | return video_tensor 162 | 163 | def merge_video_audio(video_path, audio_path, output_path, start_time, duration): 164 | command = [ 165 | 'ffmpeg', 166 | '-y', 167 | '-ss', str(start_time), 168 | '-t', str(duration), 169 | '-i', video_path, 170 | '-i', audio_path, 171 | '-c:v', 'copy', 172 | '-c:a', 'aac', 173 | '-map', '0:v:0', 174 | '-map', '1:a:0', 175 | '-shortest', 176 | '-strict', 'experimental', 177 | output_path 178 | ] 179 | 180 | try: 181 | sp.run(command, check=True) 182 | print(f"Successfully merged audio and video into {output_path}") 183 | return output_path 184 | except sp.CalledProcessError as e: 185 | print(f"Error merging audio and video: {e}") 186 | return None 187 | 188 | def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total): 189 | if audio_path is None: 190 | return torch.zeros((2, int(sample_rate * seconds_total))) 191 | audio_tensor, sr = torchaudio.load(audio_path) 192 | start_index = int(sample_rate * seconds_start) 193 | target_length = int(sample_rate * seconds_total) 194 | end_index = start_index + target_length 195 | audio_tensor = audio_tensor[:, start_index:end_index] 196 | if audio_tensor.shape[1] < target_length: 197 | pad_length = target_length - audio_tensor.shape[1] 198 | audio_tensor = F.pad(audio_tensor, (pad_length, 0)) 199 | return audio_tensor -------------------------------------------------------------------------------- /stable_audio_tools/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeyueT/AudioX/9f673e2499c1127f9f8c8060e5a68b084f8e63fb/stable_audio_tools/inference/__init__.py -------------------------------------------------------------------------------- /stable_audio_tools/inference/generation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import typing as tp 4 | import math 5 | from torchaudio import transforms as T 6 | 7 | from .utils import prepare_audio 8 | from .sampling import sample, sample_k, sample_rf 9 | from ..data.utils import PadCrop 10 | 11 | def generate_diffusion_uncond( 12 | model, 13 | steps: int = 250, 14 | batch_size: int = 1, 15 | sample_size: int = 2097152, 16 | seed: int = -1, 17 | device: str = "cuda", 18 | init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, 19 | init_noise_level: float = 1.0, 20 | return_latents = False, 21 | **sampler_kwargs 22 | ) -> torch.Tensor: 23 | 24 | # The length of the output in audio samples 25 | audio_sample_size = sample_size 26 | 27 | # If this is latent diffusion, change sample_size instead to the downsampled latent size 28 | if model.pretransform is not None: 29 | sample_size = sample_size // model.pretransform.downsampling_ratio 30 | 31 | # Seed 32 | # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. 
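    # Illustrative usage with hypothetical values: generate_diffusion_uncond(model, steps=100, sample_size=2097152, seed=12345)
    # returns a [batch_size, channels, samples] tensor. Fixing the seed reproduces the same output across runs,
    # while seed=-1 draws a fresh 32-bit seed, printed below so a run can be reproduced later.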
33 | seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) 34 | # seed = 777 35 | print(seed) 36 | torch.manual_seed(seed) 37 | # Define the initial noise immediately after setting the seed 38 | noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) 39 | 40 | if init_audio is not None: 41 | # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. 42 | in_sr, init_audio = init_audio 43 | 44 | io_channels = model.io_channels 45 | 46 | # For latent models, set the io_channels to the autoencoder's io_channels 47 | if model.pretransform is not None: 48 | io_channels = model.pretransform.io_channels 49 | 50 | # Prepare the initial audio for use by the model 51 | init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) 52 | 53 | # For latent models, encode the initial audio into latents 54 | if model.pretransform is not None: 55 | init_audio = model.pretransform.encode(init_audio) 56 | 57 | init_audio = init_audio.repeat(batch_size, 1, 1) 58 | else: 59 | # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. 60 | init_audio = None 61 | init_noise_level = None 62 | 63 | # Inpainting mask 64 | 65 | if init_audio is not None: 66 | # variations 67 | sampler_kwargs["sigma_max"] = init_noise_level 68 | mask = None 69 | else: 70 | mask = None 71 | 72 | # Now the generative AI part: 73 | 74 | diff_objective = model.diffusion_objective 75 | 76 | if diff_objective == "v": 77 | # k-diffusion denoising process go! 78 | sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device) 79 | elif diff_objective == "rectified_flow": 80 | sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device) 81 | 82 | # Denoising process done. 83 | # If this is latent diffusion, decode latents back into audio 84 | if model.pretransform is not None and not return_latents: 85 | sampled = model.pretransform.decode(sampled) 86 | 87 | # Return audio 88 | return sampled 89 | 90 | 91 | def generate_diffusion_cond( 92 | model, 93 | steps: int = 250, 94 | cfg_scale=6, 95 | conditioning: dict = None, 96 | conditioning_tensors: tp.Optional[dict] = None, 97 | negative_conditioning: dict = None, 98 | negative_conditioning_tensors: tp.Optional[dict] = None, 99 | batch_size: int = 1, 100 | sample_size: int = 2097152, 101 | sample_rate: int = 48000, 102 | seed: int = -1, 103 | device: str = "cuda", 104 | init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, 105 | init_noise_level: float = 1.0, 106 | mask_args: dict = None, 107 | return_latents = False, 108 | **sampler_kwargs 109 | ) -> torch.Tensor: 110 | """ 111 | Generate audio from a prompt using a diffusion model. 112 | 113 | Args: 114 | model: The diffusion model to use for generation. 115 | steps: The number of diffusion steps to use. 116 | cfg_scale: Classifier-free guidance scale 117 | conditioning: A dictionary of conditioning parameters to use for generation. 118 | conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation. 119 | batch_size: The batch size to use for generation. 120 | sample_size: The length of the audio to generate, in samples. 
121 | sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly) 122 | seed: The random seed to use for generation, or -1 to use a random seed. 123 | device: The device to use for generation. 124 | init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation. 125 | init_noise_level: The noise level to use when generating from an initial audio sample. 126 | return_latents: Whether to return the latents used for generation instead of the decoded audio. 127 | **sampler_kwargs: Additional keyword arguments to pass to the sampler. 128 | """ 129 | 130 | # The length of the output in audio samples 131 | audio_sample_size = sample_size 132 | 133 | # If this is latent diffusion, change sample_size instead to the downsampled latent size 134 | if model.pretransform is not None: 135 | sample_size = sample_size // model.pretransform.downsampling_ratio 136 | 137 | # Seed 138 | # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. 139 | seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) 140 | # seed = 777 141 | # print(seed) 142 | torch.manual_seed(seed) 143 | # Define the initial noise immediately after setting the seed 144 | noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) 145 | 146 | torch.backends.cuda.matmul.allow_tf32 = False 147 | torch.backends.cudnn.allow_tf32 = False 148 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 149 | torch.backends.cudnn.benchmark = False 150 | 151 | # Conditioning 152 | assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors" 153 | if conditioning_tensors is None: 154 | conditioning_tensors = model.conditioner(conditioning, device) 155 | conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors) 156 | 157 | if negative_conditioning is not None or negative_conditioning_tensors is not None: 158 | 159 | if negative_conditioning_tensors is None: 160 | negative_conditioning_tensors = model.conditioner(negative_conditioning, device) 161 | 162 | negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True) 163 | else: 164 | negative_conditioning_tensors = {} 165 | 166 | if init_audio is not None: 167 | # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. 168 | in_sr, init_audio = init_audio 169 | 170 | io_channels = model.io_channels 171 | 172 | # For latent models, set the io_channels to the autoencoder's io_channels 173 | if model.pretransform is not None: 174 | io_channels = model.pretransform.io_channels 175 | 176 | # Prepare the initial audio for use by the model 177 | init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) 178 | 179 | # For latent models, encode the initial audio into latents 180 | if model.pretransform is not None: 181 | init_audio = model.pretransform.encode(init_audio) 182 | 183 | init_audio = init_audio.repeat(batch_size, 1, 1) 184 | else: 185 | # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. 
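        # (init_noise_level and mask_args only apply when init_audio is supplied, so they are cleared here as well.)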
186 | init_audio = None 187 | init_noise_level = None 188 | mask_args = None 189 | 190 | # Inpainting mask 191 | if init_audio is not None and mask_args is not None: 192 | # Cut and paste init_audio according to cropfrom, pastefrom, pasteto 193 | # This is helpful for forward and reverse outpainting 194 | cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size) 195 | pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size) 196 | pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size) 197 | assert pastefrom < pasteto, "Paste From should be less than Paste To" 198 | croplen = pasteto - pastefrom 199 | if cropfrom + croplen > sample_size: 200 | croplen = sample_size - cropfrom 201 | cropto = cropfrom + croplen 202 | pasteto = pastefrom + croplen 203 | cutpaste = init_audio.new_zeros(init_audio.shape) 204 | cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto] 205 | #print(cropfrom, cropto, pastefrom, pasteto) 206 | init_audio = cutpaste 207 | # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args 208 | mask = build_mask(sample_size, mask_args) 209 | mask = mask.to(device) 210 | elif init_audio is not None and mask_args is None: 211 | # variations 212 | sampler_kwargs["sigma_max"] = init_noise_level 213 | mask = None 214 | else: 215 | mask = None 216 | 217 | model_dtype = next(model.model.parameters()).dtype 218 | noise = noise.type(model_dtype) 219 | conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()} 220 | # Now the generative AI part: 221 | # k-diffusion denoising process go! 222 | 223 | diff_objective = model.diffusion_objective 224 | 225 | if diff_objective == "v": 226 | # k-diffusion denoising process go! 227 | sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) 228 | 229 | elif diff_objective == "rectified_flow": 230 | 231 | if "sigma_min" in sampler_kwargs: 232 | del sampler_kwargs["sigma_min"] 233 | 234 | if "sampler_type" in sampler_kwargs: 235 | del sampler_kwargs["sampler_type"] 236 | 237 | sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) 238 | 239 | # v-diffusion: 240 | del noise 241 | del conditioning_tensors 242 | del conditioning_inputs 243 | torch.cuda.empty_cache() 244 | # Denoising process done. 
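    # At this point `sampled` is still in the model's working space: for latent-diffusion models it holds
    # latents at the pretransform's downsampled length and must be decoded back to audio below.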
245 | # If this is latent diffusion, decode latents back into audio 246 | 247 | if model.pretransform is not None and not return_latents: 248 | #cast sampled latents to pretransform dtype 249 | sampled = sampled.to(next(model.pretransform.parameters()).dtype) 250 | sampled = model.pretransform.decode(sampled) 251 | 252 | return sampled 253 | 254 | # builds a softmask given the parameters 255 | # returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio, 256 | # and anything between is a mixture of old/new 257 | # ideally 0.5 is half/half mixture but i haven't figured this out yet 258 | def build_mask(sample_size, mask_args): 259 | maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size) 260 | maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size) 261 | softnessL = round(mask_args["softnessL"]/100.0 * sample_size) 262 | softnessR = round(mask_args["softnessR"]/100.0 * sample_size) 263 | marination = mask_args["marination"] 264 | # use hann windows for softening the transition (i don't know if this is correct) 265 | hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL] 266 | hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:] 267 | # build the mask. 268 | mask = torch.zeros((sample_size)) 269 | mask[maskstart:maskend] = 1 270 | mask[maskstart:maskstart+softnessL] = hannL 271 | mask[maskend-softnessR:maskend] = hannR 272 | # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds 273 | if marination > 0: 274 | mask = mask * (1-marination) 275 | return mask 276 | -------------------------------------------------------------------------------- /stable_audio_tools/inference/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from tqdm import trange, tqdm 4 | 5 | import k_diffusion as K 6 | 7 | # Define the noise schedule and sampling loop 8 | def get_alphas_sigmas(t): 9 | """Returns the scaling factors for the clean image (alpha) and for the 10 | noise (sigma), given a timestep.""" 11 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) 12 | 13 | def alpha_sigma_to_t(alpha, sigma): 14 | """Returns a timestep, given the scaling factors for the clean image and for 15 | the noise.""" 16 | return torch.atan2(sigma, alpha) / math.pi * 2 17 | 18 | def t_to_alpha_sigma(t): 19 | """Returns the scaling factors for the clean image and for the noise, given 20 | a timestep.""" 21 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) 22 | 23 | 24 | @torch.no_grad() 25 | def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args): 26 | """Draws samples from a model given starting noise. 
Euler method""" 27 | 28 | # Make tensor of ones to broadcast the single t values 29 | ts = x.new_ones([x.shape[0]]) 30 | 31 | # Create the noise schedule 32 | t = torch.linspace(sigma_max, 0, steps + 1) 33 | 34 | #alphas, sigmas = 1-t, t 35 | 36 | for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])): 37 | # Broadcast the current timestep to the correct shape 38 | t_curr_tensor = t_curr * torch.ones( 39 | (x.shape[0],), dtype=x.dtype, device=x.device 40 | ) 41 | dt = t_prev - t_curr # we solve backwards in our formulation 42 | x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc) 43 | 44 | # If we are on the last timestep, output the denoised image 45 | return x 46 | 47 | @torch.no_grad() 48 | def sample(model, x, steps, eta, **extra_args): 49 | """Draws samples from a model given starting noise. v-diffusion""" 50 | ts = x.new_ones([x.shape[0]]) 51 | 52 | # Create the noise schedule 53 | t = torch.linspace(1, 0, steps + 1)[:-1] 54 | 55 | alphas, sigmas = get_alphas_sigmas(t) 56 | 57 | # The sampling loop 58 | for i in trange(steps): 59 | 60 | # Get the model output (v, the predicted velocity) 61 | with torch.cuda.amp.autocast(): 62 | v = model(x, ts * t[i], **extra_args).float() 63 | 64 | # Predict the noise and the denoised image 65 | pred = x * alphas[i] - v * sigmas[i] 66 | eps = x * sigmas[i] + v * alphas[i] 67 | 68 | # If we are not on the last timestep, compute the noisy image for the 69 | # next timestep. 70 | if i < steps - 1: 71 | # If eta > 0, adjust the scaling factor for the predicted noise 72 | # downward according to the amount of additional noise to add 73 | ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ 74 | (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() 75 | adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() 76 | 77 | # Recombine the predicted noise and predicted denoised image in the 78 | # correct proportions for the next step 79 | x = pred * alphas[i + 1] + eps * adjusted_sigma 80 | 81 | # Add the correct amount of fresh noise 82 | if eta: 83 | x += torch.randn_like(x) * ddim_sigma 84 | 85 | # If we are on the last timestep, output the denoised image 86 | return pred 87 | 88 | # Soft mask inpainting is just shrinking hard (binary) mask inpainting 89 | # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step 90 | def get_bmask(i, steps, mask): 91 | strength = (i+1)/(steps) 92 | # convert to binary mask 93 | bmask = torch.where(mask<=strength,1,0) 94 | return bmask 95 | 96 | def make_cond_model_fn(model, cond_fn): 97 | def cond_model_fn(x, sigma, **kwargs): 98 | with torch.enable_grad(): 99 | x = x.detach().requires_grad_() 100 | denoised = model(x, sigma, **kwargs) 101 | cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() 102 | cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) 103 | return cond_denoised 104 | return cond_model_fn 105 | 106 | # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion 107 | # init_data is init_audio as latents (if this is latent diffusion) 108 | # For sampling, set both init_data and mask to None 109 | # For variations, set init_data 110 | # For inpainting, set both init_data & mask 111 | def sample_k( 112 | model_fn, 113 | noise, 114 | init_data=None, 115 | mask=None, 116 | steps=100, 117 | sampler_type="dpmpp-2m-sde", 118 | sigma_min=0.5, 119 | sigma_max=50, 120 | rho=1.0, device="cuda", 121 | callback=None, 122 | cond_fn=None, 123 | **extra_args 124 | ): 125 | 126 | 
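    # K.external.VDenoiser wraps a v-objective model so it exposes the sigma-based denoiser
    # interface that the K.sampling.* samplers used below expect.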
denoiser = K.external.VDenoiser(model_fn) 127 | 128 | if cond_fn is not None: 129 | denoiser = make_cond_model_fn(denoiser, cond_fn) 130 | 131 | # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has 132 | sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device) 133 | # Scale the initial noise by sigma 134 | noise = noise * sigmas[0] 135 | 136 | wrapped_callback = callback 137 | 138 | 139 | if mask is None and init_data is not None: 140 | # VARIATION (no inpainting) 141 | # set the initial latent to the init_data, and noise it with initial sigma 142 | 143 | x = init_data + noise 144 | 145 | elif mask is not None and init_data is not None: 146 | # INPAINTING 147 | bmask = get_bmask(0, steps, mask) 148 | # initial noising 149 | input_noised = init_data + noise 150 | # set the initial latent to a mix of init_data and noise, based on step 0's binary mask 151 | x = input_noised * bmask + noise * (1-bmask) 152 | # define the inpainting callback function (Note: side effects, it mutates x) 153 | # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105 154 | # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 155 | # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)` 156 | def inpainting_callback(args): 157 | i = args["i"] 158 | x = args["x"] 159 | sigma = args["sigma"] 160 | #denoised = args["denoised"] 161 | # noise the init_data input with this step's appropriate amount of noise 162 | input_noised = init_data + torch.randn_like(init_data) * sigma 163 | # shrinking hard mask 164 | bmask = get_bmask(i, steps, mask) 165 | # mix input_noise with x, using binary mask 166 | new_x = input_noised * bmask + x * (1-bmask) 167 | # mutate x 168 | x[:,:,:] = new_x[:,:,:] 169 | # wrap together the inpainting callback and the user-submitted callback. 
170 | if callback is None: 171 | wrapped_callback = inpainting_callback 172 | else: 173 | wrapped_callback = lambda args: (inpainting_callback(args), callback(args)) 174 | else: 175 | # SAMPLING 176 | # set the initial latent to noise 177 | x = noise 178 | # x = noise 179 | 180 | with torch.cuda.amp.autocast(): 181 | if sampler_type == "k-heun": 182 | return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 183 | elif sampler_type == "k-lms": 184 | return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 185 | elif sampler_type == "k-dpmpp-2s-ancestral": 186 | return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 187 | elif sampler_type == "k-dpm-2": 188 | return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 189 | elif sampler_type == "k-dpm-fast": 190 | return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args) 191 | elif sampler_type == "k-dpm-adaptive": 192 | return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args) 193 | elif sampler_type == "dpmpp-2m-sde": 194 | return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 195 | elif sampler_type == "dpmpp-3m-sde": 196 | return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) 197 | 198 | # Uses discrete Euler sampling for rectified flow models 199 | # init_data is init_audio as latents (if this is latent diffusion) 200 | # For sampling, set both init_data and mask to None 201 | # For variations, set init_data 202 | # For inpainting, set both init_data & mask 203 | def sample_rf( 204 | model_fn, 205 | noise, 206 | init_data=None, 207 | steps=100, 208 | sigma_max=1, 209 | device="cuda", 210 | callback=None, 211 | cond_fn=None, 212 | **extra_args 213 | ): 214 | 215 | if sigma_max > 1: 216 | sigma_max = 1 217 | 218 | if cond_fn is not None: 219 | model_fn = make_cond_model_fn(model_fn, cond_fn) 220 | 221 | wrapped_callback = callback 222 | 223 | if init_data is not None: 224 | # VARIATION (no inpainting) 225 | # Interpolate the init data and the noise for init audio 226 | x = init_data * (1 - sigma_max) + noise * sigma_max 227 | else: 228 | # SAMPLING 229 | # set the initial latent to noise 230 | x = noise 231 | 232 | with torch.cuda.amp.autocast(): 233 | # TODO: Add callback support 234 | #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args) 235 | return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args) -------------------------------------------------------------------------------- /stable_audio_tools/inference/utils.py: -------------------------------------------------------------------------------- 1 | from ..data.utils import PadCrop 2 | 3 | from torchaudio import transforms as T 4 | 5 | def set_audio_channels(audio, target_channels): 6 | if target_channels == 1: 7 | # Convert to mono 8 | audio = audio.mean(1, keepdim=True) 9 | elif target_channels == 2: 10 | # Convert to stereo 11 | if audio.shape[1] == 1: 12 | audio = audio.repeat(1, 2, 1) 13 | elif audio.shape[1] > 2: 14 | audio = audio[:, :2, :] 15 |
return audio 16 | 17 | def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): 18 | 19 | audio = audio.to(device) 20 | 21 | if in_sr != target_sr: 22 | resample_tf = T.Resample(in_sr, target_sr).to(device) 23 | audio = resample_tf(audio) 24 | 25 | audio = PadCrop(target_length, randomize=False)(audio) 26 | 27 | # Add batch dimension 28 | if audio.dim() == 1: 29 | audio = audio.unsqueeze(0).unsqueeze(0) 30 | elif audio.dim() == 2: 31 | audio = audio.unsqueeze(0) 32 | 33 | audio = set_audio_channels(audio, target_channels) 34 | 35 | return audio -------------------------------------------------------------------------------- /stable_audio_tools/interface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeyueT/AudioX/9f673e2499c1127f9f8c8060e5a68b084f8e63fb/stable_audio_tools/interface/__init__.py -------------------------------------------------------------------------------- /stable_audio_tools/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import create_model_from_config, create_model_from_config_path -------------------------------------------------------------------------------- /stable_audio_tools/models/blocks.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from torch.backends.cuda import sdp_kernel 9 | from packaging import version 10 | 11 | from dac.nn.layers import Snake1d 12 | 13 | class ResidualBlock(nn.Module): 14 | def __init__(self, main, skip=None): 15 | super().__init__() 16 | self.main = nn.Sequential(*main) 17 | self.skip = skip if skip else nn.Identity() 18 | 19 | def forward(self, input): 20 | return self.main(input) + self.skip(input) 21 | 22 | class ResConvBlock(ResidualBlock): 23 | def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): 24 | skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) 25 | super().__init__([ 26 | nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), 27 | nn.GroupNorm(1, c_mid), 28 | Snake1d(c_mid) if use_snake else nn.GELU(), 29 | nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), 30 | nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), 31 | (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), 32 | ], skip) 33 | 34 | class SelfAttention1d(nn.Module): 35 | def __init__(self, c_in, n_head=1, dropout_rate=0.): 36 | super().__init__() 37 | assert c_in % n_head == 0 38 | self.norm = nn.GroupNorm(1, c_in) 39 | self.n_head = n_head 40 | self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) 41 | self.out_proj = nn.Conv1d(c_in, c_in, 1) 42 | self.dropout = nn.Dropout(dropout_rate, inplace=True) 43 | 44 | self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') 45 | 46 | if not self.use_flash: 47 | return 48 | 49 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 50 | 51 | if device_properties.major == 8 and device_properties.minor == 0: 52 | # Use flash attention for A100 GPUs 53 | self.sdp_kernel_config = (True, False, False) 54 | else: 55 | # Don't use flash attention for other GPUs 56 | self.sdp_kernel_config = (False, True, True) 57 | 58 | def forward(self, input): 
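        # input is [batch, channels, sequence]; qkv_proj expands it to 3*c channels, which are
        # reshaped below into per-head queries, keys and values.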
59 | n, c, s = input.shape 60 | qkv = self.qkv_proj(self.norm(input)) 61 | qkv = qkv.view( 62 | [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) 63 | q, k, v = qkv.chunk(3, dim=1) 64 | scale = k.shape[3]**-0.25 65 | 66 | if self.use_flash: 67 | with sdp_kernel(*self.sdp_kernel_config): 68 | y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) 69 | else: 70 | att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) 71 | y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) 72 | 73 | 74 | return input + self.dropout(self.out_proj(y)) 75 | 76 | class SkipBlock(nn.Module): 77 | def __init__(self, *main): 78 | super().__init__() 79 | self.main = nn.Sequential(*main) 80 | 81 | def forward(self, input): 82 | return torch.cat([self.main(input), input], dim=1) 83 | 84 | class FourierFeatures(nn.Module): 85 | def __init__(self, in_features, out_features, std=1.): 86 | super().__init__() 87 | assert out_features % 2 == 0 88 | self.weight = nn.Parameter(torch.randn( 89 | [out_features // 2, in_features]) * std) 90 | 91 | def forward(self, input): 92 | f = 2 * math.pi * input @ self.weight.T 93 | return torch.cat([f.cos(), f.sin()], dim=-1) 94 | 95 | def expand_to_planes(input, shape): 96 | return input[..., None].repeat([1, 1, shape[2]]) 97 | 98 | _kernels = { 99 | 'linear': 100 | [1 / 8, 3 / 8, 3 / 8, 1 / 8], 101 | 'cubic': 102 | [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 103 | 0.43359375, 0.11328125, -0.03515625, -0.01171875], 104 | 'lanczos3': 105 | [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, 106 | -0.066637322306633, 0.13550527393817902, 0.44638532400131226, 107 | 0.44638532400131226, 0.13550527393817902, -0.066637322306633, 108 | -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] 109 | } 110 | 111 | class Downsample1d(nn.Module): 112 | def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): 113 | super().__init__() 114 | self.pad_mode = pad_mode 115 | kernel_1d = torch.tensor(_kernels[kernel]) 116 | self.pad = kernel_1d.shape[0] // 2 - 1 117 | self.register_buffer('kernel', kernel_1d) 118 | self.channels_last = channels_last 119 | 120 | def forward(self, x): 121 | if self.channels_last: 122 | x = x.permute(0, 2, 1) 123 | x = F.pad(x, (self.pad,) * 2, self.pad_mode) 124 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) 125 | indices = torch.arange(x.shape[1], device=x.device) 126 | weight[indices, indices] = self.kernel.to(weight) 127 | x = F.conv1d(x, weight, stride=2) 128 | if self.channels_last: 129 | x = x.permute(0, 2, 1) 130 | return x 131 | 132 | 133 | class Upsample1d(nn.Module): 134 | def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): 135 | super().__init__() 136 | self.pad_mode = pad_mode 137 | kernel_1d = torch.tensor(_kernels[kernel]) * 2 138 | self.pad = kernel_1d.shape[0] // 2 - 1 139 | self.register_buffer('kernel', kernel_1d) 140 | self.channels_last = channels_last 141 | 142 | def forward(self, x): 143 | if self.channels_last: 144 | x = x.permute(0, 2, 1) 145 | x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) 146 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) 147 | indices = torch.arange(x.shape[1], device=x.device) 148 | weight[indices, indices] = self.kernel.to(weight) 149 | x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) 150 | if self.channels_last: 151 | x = x.permute(0, 2, 1) 152 | return x 153 | 154 | def Downsample1d_2( 155 | in_channels: int, 
out_channels: int, factor: int, kernel_multiplier: int = 2 156 | ) -> nn.Module: 157 | assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" 158 | 159 | return nn.Conv1d( 160 | in_channels=in_channels, 161 | out_channels=out_channels, 162 | kernel_size=factor * kernel_multiplier + 1, 163 | stride=factor, 164 | padding=factor * (kernel_multiplier // 2), 165 | ) 166 | 167 | 168 | def Upsample1d_2( 169 | in_channels: int, out_channels: int, factor: int, use_nearest: bool = False 170 | ) -> nn.Module: 171 | 172 | if factor == 1: 173 | return nn.Conv1d( 174 | in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 175 | ) 176 | 177 | if use_nearest: 178 | return nn.Sequential( 179 | nn.Upsample(scale_factor=factor, mode="nearest"), 180 | nn.Conv1d( 181 | in_channels=in_channels, 182 | out_channels=out_channels, 183 | kernel_size=3, 184 | padding=1, 185 | ), 186 | ) 187 | else: 188 | return nn.ConvTranspose1d( 189 | in_channels=in_channels, 190 | out_channels=out_channels, 191 | kernel_size=factor * 2, 192 | stride=factor, 193 | padding=factor // 2 + factor % 2, 194 | output_padding=factor % 2, 195 | ) 196 | 197 | def zero_init(layer): 198 | nn.init.zeros_(layer.weight) 199 | if layer.bias is not None: 200 | nn.init.zeros_(layer.bias) 201 | return layer 202 | 203 | def rms_norm(x, scale, eps): 204 | dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) 205 | mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) 206 | scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) 207 | return x * scale.to(x.dtype) 208 | 209 | #rms_norm = torch.compile(rms_norm) 210 | 211 | class AdaRMSNorm(nn.Module): 212 | def __init__(self, features, cond_features, eps=1e-6): 213 | super().__init__() 214 | self.eps = eps 215 | self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) 216 | 217 | def extra_repr(self): 218 | return f"eps={self.eps}," 219 | 220 | def forward(self, x, cond): 221 | return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) 222 | 223 | def normalize(x, eps=1e-4): 224 | dim = list(range(1, x.ndim)) 225 | n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) 226 | alpha = np.sqrt(n.numel() / x.numel()) 227 | return x / torch.add(eps, n, alpha=alpha) 228 | 229 | class ForcedWNConv1d(nn.Module): 230 | def __init__(self, in_channels, out_channels, kernel_size=1): 231 | super().__init__() 232 | self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) 233 | 234 | def forward(self, x): 235 | if self.training: 236 | with torch.no_grad(): 237 | self.weight.copy_(normalize(self.weight)) 238 | 239 | fan_in = self.weight[0].numel() 240 | 241 | w = normalize(self.weight) / math.sqrt(fan_in) 242 | 243 | return F.conv1d(x, w, padding='same') 244 | 245 | # Kernels 246 | 247 | use_compile = True 248 | 249 | def compile(function, *args, **kwargs): 250 | if not use_compile: 251 | return function 252 | try: 253 | return torch.compile(function, *args, **kwargs) 254 | except RuntimeError: 255 | return function 256 | 257 | 258 | @compile 259 | def linear_geglu(x, weight, bias=None): 260 | x = x @ weight.mT 261 | if bias is not None: 262 | x = x + bias 263 | x, gate = x.chunk(2, dim=-1) 264 | return x * F.gelu(gate) 265 | 266 | 267 | @compile 268 | def rms_norm(x, scale, eps): 269 | dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) 270 | mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) 271 | scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) 272 | return x * scale.to(x.dtype) 
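# --- Editorial usage sketch (illustrative; not part of the original file) ---
# rms_norm normalizes over the last dimension and multiplies by a learned scale, and
# LinearGEGLU (defined below) doubles its internal width and gates one half with GELU.
# For a hypothetical activation tensor of shape (batch, seq, dim):
#
#   x = torch.randn(2, 16, 512)
#   y = rms_norm(x, torch.ones(512), eps=1e-6)    # same shape as x
#   z = LinearGEGLU(512, 1024)(x)                 # shape (2, 16, 1024)
# ----------------------------------------------------------------------------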
273 | 274 | # Layers 275 | 276 | class LinearGEGLU(nn.Linear): 277 | def __init__(self, in_features, out_features, bias=True): 278 | super().__init__(in_features, out_features * 2, bias=bias) 279 | self.out_features = out_features 280 | 281 | def forward(self, x): 282 | return linear_geglu(x, self.weight, self.bias) 283 | 284 | 285 | class RMSNorm(nn.Module): 286 | def __init__(self, shape, fix_scale = False, eps=1e-6): 287 | super().__init__() 288 | self.eps = eps 289 | 290 | if fix_scale: 291 | self.register_buffer("scale", torch.ones(shape)) 292 | else: 293 | self.scale = nn.Parameter(torch.ones(shape)) 294 | 295 | def extra_repr(self): 296 | return f"shape={tuple(self.scale.shape)}, eps={self.eps}" 297 | 298 | def forward(self, x): 299 | return rms_norm(x, self.scale, self.eps) 300 | 301 | def snake_beta(x, alpha, beta): 302 | return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) 303 | 304 | # try: 305 | # snake_beta = torch.compile(snake_beta) 306 | # except RuntimeError: 307 | # pass 308 | 309 | # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license 310 | # License available in LICENSES/LICENSE_NVIDIA.txt 311 | class SnakeBeta(nn.Module): 312 | 313 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): 314 | super(SnakeBeta, self).__init__() 315 | self.in_features = in_features 316 | 317 | # initialize alpha 318 | self.alpha_logscale = alpha_logscale 319 | if self.alpha_logscale: # log scale alphas initialized to zeros 320 | self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) 321 | self.beta = nn.Parameter(torch.zeros(in_features) * alpha) 322 | else: # linear scale alphas initialized to ones 323 | self.alpha = nn.Parameter(torch.ones(in_features) * alpha) 324 | self.beta = nn.Parameter(torch.ones(in_features) * alpha) 325 | 326 | self.alpha.requires_grad = alpha_trainable 327 | self.beta.requires_grad = alpha_trainable 328 | 329 | self.no_div_by_zero = 0.000000001 330 | 331 | def forward(self, x): 332 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 333 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 334 | if self.alpha_logscale: 335 | alpha = torch.exp(alpha) 336 | beta = torch.exp(beta) 337 | x = snake_beta(x, alpha, beta) 338 | 339 | return x -------------------------------------------------------------------------------- /stable_audio_tools/models/bottleneck.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from einops import rearrange 7 | from vector_quantize_pytorch import ResidualVQ, FSQ 8 | from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ 9 | 10 | class Bottleneck(nn.Module): 11 | def __init__(self, is_discrete: bool = False): 12 | super().__init__() 13 | 14 | self.is_discrete = is_discrete 15 | 16 | def encode(self, x, return_info=False, **kwargs): 17 | raise NotImplementedError 18 | 19 | def decode(self, x): 20 | raise NotImplementedError 21 | 22 | class DiscreteBottleneck(Bottleneck): 23 | def __init__(self, num_quantizers, codebook_size, tokens_id): 24 | super().__init__(is_discrete=True) 25 | 26 | self.num_quantizers = num_quantizers 27 | self.codebook_size = codebook_size 28 | self.tokens_id = tokens_id 29 | 30 | def decode_tokens(self, codes, **kwargs): 31 | raise NotImplementedError 32 | 33 | class TanhBottleneck(Bottleneck): 34 | def __init__(self): 35 | 
super().__init__(is_discrete=False) 36 | self.tanh = nn.Tanh() 37 | 38 | def encode(self, x, return_info=False): 39 | info = {} 40 | 41 | x = torch.tanh(x) 42 | 43 | if return_info: 44 | return x, info 45 | else: 46 | return x 47 | 48 | def decode(self, x): 49 | return x 50 | 51 | def vae_sample(mean, scale): 52 | stdev = nn.functional.softplus(scale) + 1e-4 53 | var = stdev * stdev 54 | logvar = torch.log(var) 55 | latents = torch.randn_like(mean) * stdev + mean 56 | 57 | kl = (mean * mean + var - logvar - 1).sum(1).mean() 58 | 59 | return latents, kl 60 | 61 | class VAEBottleneck(Bottleneck): 62 | def __init__(self): 63 | super().__init__(is_discrete=False) 64 | 65 | def encode(self, x, return_info=False, **kwargs): 66 | info = {} 67 | 68 | mean, scale = x.chunk(2, dim=1) 69 | 70 | x, kl = vae_sample(mean, scale) 71 | 72 | info["kl"] = kl 73 | 74 | if return_info: 75 | return x, info 76 | else: 77 | return x 78 | 79 | def decode(self, x): 80 | return x 81 | 82 | def compute_mean_kernel(x, y): 83 | kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] 84 | return torch.exp(-kernel_input).mean() 85 | 86 | def compute_mmd(latents): 87 | latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) 88 | noise = torch.randn_like(latents_reshaped) 89 | 90 | latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) 91 | noise_kernel = compute_mean_kernel(noise, noise) 92 | latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) 93 | 94 | mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel 95 | return mmd.mean() 96 | 97 | class WassersteinBottleneck(Bottleneck): 98 | def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False): 99 | super().__init__(is_discrete=False) 100 | 101 | self.noise_augment_dim = noise_augment_dim 102 | self.bypass_mmd = bypass_mmd 103 | 104 | def encode(self, x, return_info=False): 105 | info = {} 106 | 107 | if self.training and return_info: 108 | if self.bypass_mmd: 109 | mmd = torch.tensor(0.0) 110 | else: 111 | mmd = compute_mmd(x) 112 | 113 | info["mmd"] = mmd 114 | 115 | if return_info: 116 | return x, info 117 | 118 | return x 119 | 120 | def decode(self, x): 121 | 122 | if self.noise_augment_dim > 0: 123 | noise = torch.randn(x.shape[0], self.noise_augment_dim, 124 | x.shape[-1]).type_as(x) 125 | x = torch.cat([x, noise], dim=1) 126 | 127 | return x 128 | 129 | class L2Bottleneck(Bottleneck): 130 | def __init__(self): 131 | super().__init__(is_discrete=False) 132 | 133 | def encode(self, x, return_info=False): 134 | info = {} 135 | 136 | x = F.normalize(x, dim=1) 137 | 138 | if return_info: 139 | return x, info 140 | else: 141 | return x 142 | 143 | def decode(self, x): 144 | return F.normalize(x, dim=1) 145 | 146 | class RVQBottleneck(DiscreteBottleneck): 147 | def __init__(self, **quantizer_kwargs): 148 | super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") 149 | self.quantizer = ResidualVQ(**quantizer_kwargs) 150 | self.num_quantizers = quantizer_kwargs["num_quantizers"] 151 | 152 | def encode(self, x, return_info=False, **kwargs): 153 | info = {} 154 | 155 | x = rearrange(x, "b c n -> b n c") 156 | x, indices, loss = self.quantizer(x) 157 | x = rearrange(x, "b n c -> b c n") 158 | 159 | info["quantizer_indices"] = indices 160 | info["quantizer_loss"] = loss.mean() 161 | 162 | if return_info: 163 | return x, info 164 | else: 165 | return x 166 | 167 | def decode(self, x): 168 | return x 
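    # (Editorial usage sketch; not part of the original file.) A bottleneck sits between the
    # autoencoder's encoder and decoder. For this RVQ variant, a hypothetical latent of shape
    # (batch, dim, frames) round-trips roughly as:
    #
    #   bottleneck = RVQBottleneck(dim=128, codebook_size=1024, num_quantizers=8)
    #   z = torch.randn(2, 128, 256)
    #   q, info = bottleneck.encode(z, return_info=True)          # quantized latents + indices/loss
    #   recon = bottleneck.decode_tokens(info["quantizer_indices"])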
169 | 170 | def decode_tokens(self, codes, **kwargs): 171 | latents = self.quantizer.get_outputs_from_indices(codes) 172 | 173 | return self.decode(latents, **kwargs) 174 | 175 | class RVQVAEBottleneck(DiscreteBottleneck): 176 | def __init__(self, **quantizer_kwargs): 177 | super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") 178 | self.quantizer = ResidualVQ(**quantizer_kwargs) 179 | self.num_quantizers = quantizer_kwargs["num_quantizers"] 180 | 181 | def encode(self, x, return_info=False): 182 | info = {} 183 | 184 | x, kl = vae_sample(*x.chunk(2, dim=1)) 185 | 186 | info["kl"] = kl 187 | 188 | x = rearrange(x, "b c n -> b n c") 189 | x, indices, loss = self.quantizer(x) 190 | x = rearrange(x, "b n c -> b c n") 191 | 192 | info["quantizer_indices"] = indices 193 | info["quantizer_loss"] = loss.mean() 194 | 195 | if return_info: 196 | return x, info 197 | else: 198 | return x 199 | 200 | def decode(self, x): 201 | return x 202 | 203 | def decode_tokens(self, codes, **kwargs): 204 | latents = self.quantizer.get_outputs_from_indices(codes) 205 | 206 | return self.decode(latents, **kwargs) 207 | 208 | class DACRVQBottleneck(DiscreteBottleneck): 209 | def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs): 210 | super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") 211 | self.quantizer = DACResidualVQ(**quantizer_kwargs) 212 | self.num_quantizers = quantizer_kwargs["n_codebooks"] 213 | self.quantize_on_decode = quantize_on_decode 214 | self.noise_augment_dim = noise_augment_dim 215 | 216 | def encode(self, x, return_info=False, **kwargs): 217 | info = {} 218 | 219 | info["pre_quantizer"] = x 220 | 221 | if self.quantize_on_decode: 222 | return x, info if return_info else x 223 | 224 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) 225 | 226 | output = { 227 | "z": z, 228 | "codes": codes, 229 | "latents": latents, 230 | "vq/commitment_loss": commitment_loss, 231 | "vq/codebook_loss": codebook_loss, 232 | } 233 | 234 | output["vq/commitment_loss"] /= self.num_quantizers 235 | output["vq/codebook_loss"] /= self.num_quantizers 236 | 237 | info.update(output) 238 | 239 | if return_info: 240 | return output["z"], info 241 | 242 | return output["z"] 243 | 244 | def decode(self, x): 245 | 246 | if self.quantize_on_decode: 247 | x = self.quantizer(x)[0] 248 | 249 | if self.noise_augment_dim > 0: 250 | noise = torch.randn(x.shape[0], self.noise_augment_dim, 251 | x.shape[-1]).type_as(x) 252 | x = torch.cat([x, noise], dim=1) 253 | 254 | return x 255 | 256 | def decode_tokens(self, codes, **kwargs): 257 | latents, _, _ = self.quantizer.from_codes(codes) 258 | 259 | return self.decode(latents, **kwargs) 260 | 261 | class DACRVQVAEBottleneck(DiscreteBottleneck): 262 | def __init__(self, quantize_on_decode=False, **quantizer_kwargs): 263 | super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") 264 | self.quantizer = DACResidualVQ(**quantizer_kwargs) 265 | self.num_quantizers = quantizer_kwargs["n_codebooks"] 266 | self.quantize_on_decode = quantize_on_decode 267 | 268 | def encode(self, x, return_info=False, n_quantizers: int = None): 269 | info = {} 270 | 271 | mean, scale = x.chunk(2, dim=1) 272 | 273 | x, kl = vae_sample(mean, scale) 274 | 275 | info["pre_quantizer"] = x 276 
| info["kl"] = kl 277 | 278 | if self.quantize_on_decode: 279 | return x, info if return_info else x 280 | 281 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) 282 | 283 | output = { 284 | "z": z, 285 | "codes": codes, 286 | "latents": latents, 287 | "vq/commitment_loss": commitment_loss, 288 | "vq/codebook_loss": codebook_loss, 289 | } 290 | 291 | output["vq/commitment_loss"] /= self.num_quantizers 292 | output["vq/codebook_loss"] /= self.num_quantizers 293 | 294 | info.update(output) 295 | 296 | if return_info: 297 | return output["z"], info 298 | 299 | return output["z"] 300 | 301 | def decode(self, x): 302 | 303 | if self.quantize_on_decode: 304 | x = self.quantizer(x)[0] 305 | 306 | return x 307 | 308 | def decode_tokens(self, codes, **kwargs): 309 | latents, _, _ = self.quantizer.from_codes(codes) 310 | 311 | return self.decode(latents, **kwargs) 312 | 313 | class FSQBottleneck(DiscreteBottleneck): 314 | def __init__(self, noise_augment_dim=0, **kwargs): 315 | super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices") 316 | 317 | self.noise_augment_dim = noise_augment_dim 318 | 319 | self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]) 320 | 321 | def encode(self, x, return_info=False): 322 | info = {} 323 | 324 | orig_dtype = x.dtype 325 | x = x.float() 326 | 327 | x = rearrange(x, "b c n -> b n c") 328 | x, indices = self.quantizer(x) 329 | x = rearrange(x, "b n c -> b c n") 330 | 331 | x = x.to(orig_dtype) 332 | 333 | # Reorder indices to match the expected format 334 | indices = rearrange(indices, "b n q -> b q n") 335 | 336 | info["quantizer_indices"] = indices 337 | 338 | if return_info: 339 | return x, info 340 | else: 341 | return x 342 | 343 | def decode(self, x): 344 | 345 | if self.noise_augment_dim > 0: 346 | noise = torch.randn(x.shape[0], self.noise_augment_dim, 347 | x.shape[-1]).type_as(x) 348 | x = torch.cat([x, noise], dim=1) 349 | 350 | return x 351 | 352 | def decode_tokens(self, tokens, **kwargs): 353 | latents = self.quantizer.indices_to_codes(tokens) 354 | 355 | return self.decode(latents, **kwargs) -------------------------------------------------------------------------------- /stable_audio_tools/models/discriminators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from functools import reduce 6 | import typing as tp 7 | from einops import rearrange 8 | from audiotools import AudioSignal, STFTParams 9 | from dac.model.discriminator import WNConv1d, WNConv2d 10 | 11 | def get_hinge_losses(score_real, score_fake): 12 | gen_loss = -score_fake.mean() 13 | dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean() 14 | return dis_loss, gen_loss 15 | 16 | class EncodecDiscriminator(nn.Module): 17 | 18 | def __init__(self, *args, **kwargs): 19 | super().__init__() 20 | 21 | from encodec.msstftd import MultiScaleSTFTDiscriminator 22 | 23 | self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs) 24 | 25 | def forward(self, x): 26 | logits, features = self.discriminators(x) 27 | return logits, features 28 | 29 | def loss(self, x, y): 30 | feature_matching_distance = 0. 31 | logits_true, feature_true = self.forward(x) 32 | logits_fake, feature_fake = self.forward(y) 33 | 34 | dis_loss = torch.tensor(0.) 
35 | adv_loss = torch.tensor(0.) 36 | 37 | for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)): 38 | 39 | feature_matching_distance = feature_matching_distance + sum( 40 | map( 41 | lambda x, y: abs(x - y).mean(), 42 | scale_true, 43 | scale_fake, 44 | )) / len(scale_true) 45 | 46 | _dis, _adv = get_hinge_losses( 47 | logits_true[i], 48 | logits_fake[i], 49 | ) 50 | 51 | dis_loss = dis_loss + _dis 52 | adv_loss = adv_loss + _adv 53 | 54 | return dis_loss, adv_loss, feature_matching_distance 55 | 56 | # Discriminators from oobleck 57 | 58 | IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]] 59 | 60 | TensorDict = tp.Dict[str, torch.Tensor] 61 | 62 | class SharedDiscriminatorConvNet(nn.Module): 63 | 64 | def __init__( 65 | self, 66 | in_size: int, 67 | convolution: tp.Union[nn.Conv1d, nn.Conv2d], 68 | out_size: int = 1, 69 | capacity: int = 32, 70 | n_layers: int = 4, 71 | kernel_size: int = 15, 72 | stride: int = 4, 73 | activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(), 74 | normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm, 75 | ) -> None: 76 | super().__init__() 77 | channels = [in_size] 78 | channels += list(capacity * 2**np.arange(n_layers)) 79 | 80 | if isinstance(stride, int): 81 | stride = n_layers * [stride] 82 | 83 | net = [] 84 | for i in range(n_layers): 85 | if isinstance(kernel_size, int): 86 | pad = kernel_size // 2 87 | s = stride[i] 88 | else: 89 | pad = kernel_size[0] // 2 90 | s = (stride[i], 1) 91 | 92 | net.append( 93 | normalization( 94 | convolution( 95 | channels[i], 96 | channels[i + 1], 97 | kernel_size, 98 | stride=s, 99 | padding=pad, 100 | ))) 101 | net.append(activation()) 102 | 103 | net.append(convolution(channels[-1], out_size, 1)) 104 | 105 | self.net = nn.ModuleList(net) 106 | 107 | def forward(self, x) -> IndividualDiscriminatorOut: 108 | features = [] 109 | for layer in self.net: 110 | x = layer(x) 111 | if isinstance(layer, nn.modules.conv._ConvNd): 112 | features.append(x) 113 | score = x.reshape(x.shape[0], -1).mean(-1) 114 | return score, features 115 | 116 | 117 | class MultiScaleDiscriminator(nn.Module): 118 | 119 | def __init__(self, 120 | in_channels: int, 121 | n_scales: int, 122 | **conv_kwargs) -> None: 123 | super().__init__() 124 | layers = [] 125 | for _ in range(n_scales): 126 | layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs)) 127 | self.layers = nn.ModuleList(layers) 128 | 129 | def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: 130 | score = 0 131 | features = [] 132 | for layer in self.layers: 133 | s, f = layer(x) 134 | score = score + s 135 | features.extend(f) 136 | x = nn.functional.avg_pool1d(x, 2) 137 | return score, features 138 | 139 | class MultiPeriodDiscriminator(nn.Module): 140 | 141 | def __init__(self, 142 | in_channels: int, 143 | periods: tp.Sequence[int], 144 | **conv_kwargs) -> None: 145 | super().__init__() 146 | layers = [] 147 | self.periods = periods 148 | 149 | for _ in periods: 150 | layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs)) 151 | 152 | self.layers = nn.ModuleList(layers) 153 | 154 | def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: 155 | score = 0 156 | features = [] 157 | for layer, n in zip(self.layers, self.periods): 158 | s, f = layer(self.fold(x, n)) 159 | score = score + s 160 | features.extend(f) 161 | return score, features 162 | 163 | def fold(self, x: torch.Tensor, n: int) -> torch.Tensor: 164 | pad = (n - 
(x.shape[-1] % n)) % n 165 | x = nn.functional.pad(x, (0, pad)) 166 | return x.reshape(*x.shape[:2], -1, n) 167 | 168 | 169 | class MultiDiscriminator(nn.Module): 170 | """ 171 | Individual discriminators should take a single tensor as input (NxB C T) and 172 | return a tuple composed of a score tensor (NxB) and a Sequence of Features 173 | Sequence[NxB C' T']. 174 | """ 175 | 176 | def __init__(self, discriminator_list: tp.Sequence[nn.Module], 177 | keys: tp.Sequence[str]) -> None: 178 | super().__init__() 179 | self.discriminators = nn.ModuleList(discriminator_list) 180 | self.keys = keys 181 | 182 | def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict: 183 | features = features.chunk(len(self.keys), 0) 184 | return {k: features[i] for i, k in enumerate(self.keys)} 185 | 186 | @staticmethod 187 | def concat_dicts(dict_a, dict_b): 188 | out_dict = {} 189 | keys = set(list(dict_a.keys()) + list(dict_b.keys())) 190 | for k in keys: 191 | out_dict[k] = [] 192 | if k in dict_a: 193 | if isinstance(dict_a[k], list): 194 | out_dict[k].extend(dict_a[k]) 195 | else: 196 | out_dict[k].append(dict_a[k]) 197 | if k in dict_b: 198 | if isinstance(dict_b[k], list): 199 | out_dict[k].extend(dict_b[k]) 200 | else: 201 | out_dict[k].append(dict_b[k]) 202 | return out_dict 203 | 204 | @staticmethod 205 | def sum_dicts(dict_a, dict_b): 206 | out_dict = {} 207 | keys = set(list(dict_a.keys()) + list(dict_b.keys())) 208 | for k in keys: 209 | out_dict[k] = 0. 210 | if k in dict_a: 211 | out_dict[k] = out_dict[k] + dict_a[k] 212 | if k in dict_b: 213 | out_dict[k] = out_dict[k] + dict_b[k] 214 | return out_dict 215 | 216 | def forward(self, inputs: TensorDict) -> TensorDict: 217 | discriminator_input = torch.cat([inputs[k] for k in self.keys], 0) 218 | all_scores = [] 219 | all_features = [] 220 | 221 | for discriminator in self.discriminators: 222 | score, features = discriminator(discriminator_input) 223 | scores = self.unpack_tensor_to_dict(score) 224 | scores = {f"score_{k}": scores[k] for k in scores.keys()} 225 | all_scores.append(scores) 226 | 227 | features = map(self.unpack_tensor_to_dict, features) 228 | features = reduce(self.concat_dicts, features) 229 | features = {f"features_{k}": features[k] for k in features.keys()} 230 | all_features.append(features) 231 | 232 | all_scores = reduce(self.sum_dicts, all_scores) 233 | all_features = reduce(self.concat_dicts, all_features) 234 | 235 | inputs.update(all_scores) 236 | inputs.update(all_features) 237 | 238 | return inputs 239 | 240 | class OobleckDiscriminator(nn.Module): 241 | 242 | def __init__( 243 | self, 244 | in_channels=1, 245 | ): 246 | super().__init__() 247 | 248 | multi_scale_discriminator = MultiScaleDiscriminator( 249 | in_channels=in_channels, 250 | n_scales=3, 251 | ) 252 | 253 | multi_period_discriminator = MultiPeriodDiscriminator( 254 | in_channels=in_channels, 255 | periods=[2, 3, 5, 7, 11] 256 | ) 257 | 258 | # multi_resolution_discriminator = MultiScaleSTFTDiscriminator( 259 | # filters=32, 260 | # in_channels = in_channels, 261 | # out_channels = 1, 262 | # n_ffts = [2048, 1024, 512, 256, 128], 263 | # hop_lengths = [512, 256, 128, 64, 32], 264 | # win_lengths = [2048, 1024, 512, 256, 128] 265 | # ) 266 | 267 | self.multi_discriminator = MultiDiscriminator( 268 | [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator], 269 | ["reals", "fakes"] 270 | ) 271 | 272 | def loss(self, reals, fakes): 273 | inputs = { 274 | "reals": reals, 275 | "fakes": fakes, 276 | } 277 | 278 | 
inputs = self.multi_discriminator(inputs) 279 | 280 | scores_real = inputs["score_reals"] 281 | scores_fake = inputs["score_fakes"] 282 | 283 | features_real = inputs["features_reals"] 284 | features_fake = inputs["features_fakes"] 285 | 286 | dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake) 287 | 288 | feature_matching_distance = torch.tensor(0.) 289 | 290 | for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)): 291 | 292 | feature_matching_distance = feature_matching_distance + sum( 293 | map( 294 | lambda real, fake: abs(real - fake).mean(), 295 | scale_real, 296 | scale_fake, 297 | )) / len(scale_real) 298 | 299 | return dis_loss, gen_loss, feature_matching_distance 300 | 301 | 302 | ## Discriminators from Descript Audio Codec repo 303 | ## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt 304 | class MPD(nn.Module): 305 | def __init__(self, period, channels=1): 306 | super().__init__() 307 | 308 | self.period = period 309 | self.convs = nn.ModuleList( 310 | [ 311 | WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)), 312 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), 313 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), 314 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), 315 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), 316 | ] 317 | ) 318 | self.conv_post = WNConv2d( 319 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 320 | ) 321 | 322 | def pad_to_period(self, x): 323 | t = x.shape[-1] 324 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 325 | return x 326 | 327 | def forward(self, x): 328 | fmap = [] 329 | 330 | x = self.pad_to_period(x) 331 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 332 | 333 | for layer in self.convs: 334 | x = layer(x) 335 | fmap.append(x) 336 | 337 | x = self.conv_post(x) 338 | fmap.append(x) 339 | 340 | return fmap 341 | 342 | 343 | class MSD(nn.Module): 344 | def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1): 345 | super().__init__() 346 | 347 | self.convs = nn.ModuleList( 348 | [ 349 | WNConv1d(channels, 16, 15, 1, padding=7), 350 | WNConv1d(16, 64, 41, 4, groups=4, padding=20), 351 | WNConv1d(64, 256, 41, 4, groups=16, padding=20), 352 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20), 353 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), 354 | WNConv1d(1024, 1024, 5, 1, padding=2), 355 | ] 356 | ) 357 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 358 | self.sample_rate = sample_rate 359 | self.rate = rate 360 | 361 | def forward(self, x): 362 | x = AudioSignal(x, self.sample_rate) 363 | x.resample(self.sample_rate // self.rate) 364 | x = x.audio_data 365 | 366 | fmap = [] 367 | 368 | for l in self.convs: 369 | x = l(x) 370 | fmap.append(x) 371 | x = self.conv_post(x) 372 | fmap.append(x) 373 | 374 | return fmap 375 | 376 | 377 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 378 | 379 | 380 | class MRD(nn.Module): 381 | def __init__( 382 | self, 383 | window_length: int, 384 | hop_factor: float = 0.25, 385 | sample_rate: int = 44100, 386 | bands: list = BANDS, 387 | channels: int = 1 388 | ): 389 | """Complex multi-band spectrogram discriminator. 390 | Parameters 391 | ---------- 392 | window_length : int 393 | Window length of STFT. 394 | hop_factor : float, optional 395 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 
396 | sample_rate : int, optional 397 | Sampling rate of audio in Hz, by default 44100 398 | bands : list, optional 399 | Bands to run discriminator over. 400 | """ 401 | super().__init__() 402 | 403 | self.window_length = window_length 404 | self.hop_factor = hop_factor 405 | self.sample_rate = sample_rate 406 | self.stft_params = STFTParams( 407 | window_length=window_length, 408 | hop_length=int(window_length * hop_factor), 409 | match_stride=True, 410 | ) 411 | 412 | self.channels = channels 413 | 414 | n_fft = window_length // 2 + 1 415 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 416 | self.bands = bands 417 | 418 | ch = 32 419 | convs = lambda: nn.ModuleList( 420 | [ 421 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), 422 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 423 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 424 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 425 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), 426 | ] 427 | ) 428 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 429 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) 430 | 431 | def spectrogram(self, x): 432 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) 433 | x = torch.view_as_real(x.stft()) 434 | x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels) 435 | # Split into bands 436 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 437 | return x_bands 438 | 439 | def forward(self, x): 440 | x_bands = self.spectrogram(x) 441 | fmap = [] 442 | 443 | x = [] 444 | for band, stack in zip(x_bands, self.band_convs): 445 | for layer in stack: 446 | band = layer(band) 447 | fmap.append(band) 448 | x.append(band) 449 | 450 | x = torch.cat(x, dim=-1) 451 | x = self.conv_post(x) 452 | fmap.append(x) 453 | 454 | return fmap 455 | 456 | 457 | class DACDiscriminator(nn.Module): 458 | def __init__( 459 | self, 460 | channels: int = 1, 461 | rates: list = [], 462 | periods: list = [2, 3, 5, 7, 11], 463 | fft_sizes: list = [2048, 1024, 512], 464 | sample_rate: int = 44100, 465 | bands: list = BANDS, 466 | ): 467 | """Discriminator that combines multiple discriminators. 468 | 469 | Parameters 470 | ---------- 471 | rates : list, optional 472 | sampling rates (in Hz) to run MSD at, by default [] 473 | If empty, MSD is not used. 
474 | periods : list, optional 475 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] 476 | fft_sizes : list, optional 477 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] 478 | sample_rate : int, optional 479 | Sampling rate of audio in Hz, by default 44100 480 | bands : list, optional 481 | Bands to run MRD at, by default `BANDS` 482 | """ 483 | super().__init__() 484 | discs = [] 485 | discs += [MPD(p, channels=channels) for p in periods] 486 | discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates] 487 | discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes] 488 | self.discriminators = nn.ModuleList(discs) 489 | 490 | def preprocess(self, y): 491 | # Remove DC offset 492 | y = y - y.mean(dim=-1, keepdims=True) 493 | # Peak normalize the volume of input audio 494 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 495 | return y 496 | 497 | def forward(self, x): 498 | x = self.preprocess(x) 499 | fmaps = [d(x) for d in self.discriminators] 500 | return fmaps 501 | 502 | class DACGANLoss(nn.Module): 503 | """ 504 | Computes a discriminator loss, given a discriminator on 505 | generated waveforms/spectrograms compared to ground truth 506 | waveforms/spectrograms. Computes the loss for both the 507 | discriminator and the generator in separate functions. 508 | """ 509 | 510 | def __init__(self, **discriminator_kwargs): 511 | super().__init__() 512 | self.discriminator = DACDiscriminator(**discriminator_kwargs) 513 | 514 | def forward(self, fake, real): 515 | d_fake = self.discriminator(fake) 516 | d_real = self.discriminator(real) 517 | return d_fake, d_real 518 | 519 | def discriminator_loss(self, fake, real): 520 | d_fake, d_real = self.forward(fake.clone().detach(), real) 521 | 522 | loss_d = 0 523 | for x_fake, x_real in zip(d_fake, d_real): 524 | loss_d += torch.mean(x_fake[-1] ** 2) 525 | loss_d += torch.mean((1 - x_real[-1]) ** 2) 526 | return loss_d 527 | 528 | def generator_loss(self, fake, real): 529 | d_fake, d_real = self.forward(fake, real) 530 | 531 | loss_g = 0 532 | for x_fake in d_fake: 533 | loss_g += torch.mean((1 - x_fake[-1]) ** 2) 534 | 535 | loss_feature = 0 536 | 537 | for i in range(len(d_fake)): 538 | for j in range(len(d_fake[i]) - 1): 539 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) 540 | return loss_g, loss_feature 541 | 542 | def loss(self, fake, real): 543 | gen_loss, feature_distance = self.generator_loss(fake, real) 544 | dis_loss = self.discriminator_loss(fake, real) 545 | 546 | return dis_loss, gen_loss, feature_distance -------------------------------------------------------------------------------- /stable_audio_tools/models/dit.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import torch 4 | 5 | from einops import rearrange 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from x_transformers import ContinuousTransformerWrapper, Encoder 9 | 10 | from .blocks import FourierFeatures 11 | from .transformer import ContinuousTransformer 12 | 13 | class DiffusionTransformer(nn.Module): 14 | def __init__(self, 15 | io_channels=32, 16 | patch_size=1, 17 | embed_dim=768, 18 | cond_token_dim=0, 19 | project_cond_tokens=True, 20 | global_cond_dim=0, 21 | project_global_cond=True, 22 | input_concat_dim=0, 23 | prepend_cond_dim=0, 24 | depth=12, 25 | num_heads=8, 26 | transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = 
"x-transformers", 27 | global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", 28 | **kwargs): 29 | 30 | super().__init__() 31 | 32 | self.cond_token_dim = cond_token_dim 33 | 34 | # Timestep embeddings 35 | timestep_features_dim = 256 36 | 37 | self.timestep_features = FourierFeatures(1, timestep_features_dim) 38 | 39 | self.to_timestep_embed = nn.Sequential( 40 | nn.Linear(timestep_features_dim, embed_dim, bias=True), 41 | nn.SiLU(), 42 | nn.Linear(embed_dim, embed_dim, bias=True), 43 | ) 44 | 45 | if cond_token_dim > 0: 46 | # Conditioning tokens 47 | cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim 48 | self.to_cond_embed = nn.Sequential( 49 | nn.Linear(cond_token_dim, cond_embed_dim, bias=False), 50 | nn.SiLU(), 51 | nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) 52 | ) 53 | else: 54 | cond_embed_dim = 0 55 | 56 | if global_cond_dim > 0: 57 | # Global conditioning 58 | global_embed_dim = global_cond_dim if not project_global_cond else embed_dim 59 | self.to_global_embed = nn.Sequential( 60 | nn.Linear(global_cond_dim, global_embed_dim, bias=False), 61 | nn.SiLU(), 62 | nn.Linear(global_embed_dim, global_embed_dim, bias=False) 63 | ) 64 | 65 | if prepend_cond_dim > 0: 66 | # Prepend conditioning 67 | self.to_prepend_embed = nn.Sequential( 68 | nn.Linear(prepend_cond_dim, embed_dim, bias=False), 69 | nn.SiLU(), 70 | nn.Linear(embed_dim, embed_dim, bias=False) 71 | ) 72 | 73 | self.input_concat_dim = input_concat_dim 74 | 75 | dim_in = io_channels + self.input_concat_dim 76 | 77 | self.patch_size = patch_size 78 | 79 | # Transformer 80 | 81 | self.transformer_type = transformer_type 82 | 83 | self.global_cond_type = global_cond_type 84 | 85 | if self.transformer_type == "x-transformers": 86 | self.transformer = ContinuousTransformerWrapper( 87 | dim_in=dim_in * patch_size, 88 | dim_out=io_channels * patch_size, 89 | max_seq_len=0, #Not relevant without absolute positional embeds 90 | attn_layers = Encoder( 91 | dim=embed_dim, 92 | depth=depth, 93 | heads=num_heads, 94 | attn_flash = True, 95 | cross_attend = cond_token_dim > 0, 96 | dim_context=None if cond_embed_dim == 0 else cond_embed_dim, 97 | zero_init_branch_output=True, 98 | use_abs_pos_emb = False, 99 | rotary_pos_emb=True, 100 | ff_swish = True, 101 | ff_glu = True, 102 | **kwargs 103 | ) 104 | ) 105 | 106 | elif self.transformer_type == "continuous_transformer": 107 | 108 | global_dim = None 109 | 110 | if self.global_cond_type == "adaLN": 111 | # The global conditioning is projected to the embed_dim already at this point 112 | global_dim = embed_dim 113 | 114 | self.transformer = ContinuousTransformer( 115 | dim=embed_dim, 116 | depth=depth, 117 | dim_heads=embed_dim // num_heads, 118 | dim_in=dim_in * patch_size, 119 | dim_out=io_channels * patch_size, 120 | cross_attend = cond_token_dim > 0, 121 | cond_token_dim = cond_embed_dim, 122 | global_cond_dim=global_dim, 123 | **kwargs 124 | ) 125 | 126 | else: 127 | raise ValueError(f"Unknown transformer type: {self.transformer_type}") 128 | 129 | self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) 130 | nn.init.zeros_(self.preprocess_conv.weight) 131 | self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) 132 | nn.init.zeros_(self.postprocess_conv.weight) 133 | 134 | def _forward( 135 | self, 136 | x, 137 | t, 138 | mask=None, 139 | cross_attn_cond=None, 140 | cross_attn_cond_mask=None, 141 | input_concat_cond=None, 142 | global_embed=None, 143 | prepend_cond=None, 144 | prepend_cond_mask=None, 145 | 
return_info=False, 146 | **kwargs): 147 | 148 | if cross_attn_cond is not None: 149 | cross_attn_cond = self.to_cond_embed(cross_attn_cond) # MLP endecoder, shape: [1, 130, 768] 150 | 151 | if global_embed is not None: 152 | # Project the global conditioning to the embedding dimension 153 | global_embed = self.to_global_embed(global_embed) 154 | 155 | prepend_inputs = None 156 | prepend_mask = None 157 | prepend_length = 0 158 | if prepend_cond is not None: 159 | # Project the prepend conditioning to the embedding dimension 160 | prepend_cond = self.to_prepend_embed(prepend_cond) 161 | 162 | prepend_inputs = prepend_cond 163 | if prepend_cond_mask is not None: 164 | prepend_mask = prepend_cond_mask 165 | 166 | if input_concat_cond is not None: 167 | 168 | # Interpolate input_concat_cond to the same length as x 169 | if input_concat_cond.shape[2] != x.shape[2]: 170 | input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') 171 | 172 | x = torch.cat([x, input_concat_cond], dim=1) 173 | 174 | # Get the batch of timestep embeddings 175 | timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) 176 | 177 | # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists 178 | if global_embed is not None: 179 | global_embed = global_embed + timestep_embed 180 | else: 181 | global_embed = timestep_embed 182 | 183 | # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer 184 | if self.global_cond_type == "prepend": # True 185 | if prepend_inputs is None: # True 186 | # Prepend inputs are just the global embed, and the mask is all ones 187 | prepend_inputs = global_embed.unsqueeze(1) # [1, 1, 1536] 188 | prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) 189 | else: 190 | # Prepend inputs are the prepend conditioning + the global embed 191 | prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) 192 | prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) 193 | 194 | prepend_length = prepend_inputs.shape[1] # 1 195 | 196 | x = self.preprocess_conv(x) + x # [1, 64, 1024] 197 | 198 | x = rearrange(x, "b c t -> b t c") # [1, 1024, 64] 199 | 200 | extra_args = {} 201 | 202 | if self.global_cond_type == "adaLN": # 'prepend' 203 | extra_args["global_cond"] = global_embed 204 | 205 | if self.patch_size > 1: # self.patch_size==1 206 | x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) 207 | 208 | if self.transformer_type == "x-transformers": 209 | output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs) 210 | elif self.transformer_type == "continuous_transformer": 211 | 212 | output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) 213 | 214 | if return_info: 215 | output, info = output 216 | elif self.transformer_type == "mm_transformer": 217 | output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs) 218 | 219 | output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] 220 | 221 | if self.patch_size > 1: 222 | output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) 223 | 224 | output 
= self.postprocess_conv(output) + output 225 | 226 | if return_info: 227 | return output, info 228 | 229 | return output 230 | 231 | def forward( 232 | self, 233 | x, 234 | t, 235 | cross_attn_cond=None, 236 | cross_attn_cond_mask=None, 237 | negative_cross_attn_cond=None, 238 | negative_cross_attn_mask=None, 239 | input_concat_cond=None, 240 | global_embed=None, 241 | negative_global_embed=None, 242 | prepend_cond=None, 243 | prepend_cond_mask=None, 244 | cfg_scale=1.0, 245 | cfg_dropout_prob=0.0, 246 | causal=False, 247 | scale_phi=0.0, 248 | mask=None, 249 | return_info=False, 250 | **kwargs): 251 | 252 | assert causal == False, "Causal mode is not supported for DiffusionTransformer" 253 | 254 | if cross_attn_cond_mask is not None: 255 | cross_attn_cond_mask = cross_attn_cond_mask.bool() 256 | 257 | cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention 258 | 259 | if prepend_cond_mask is not None: 260 | prepend_cond_mask = prepend_cond_mask.bool() 261 | 262 | # CFG dropout 263 | if cfg_dropout_prob > 0.0: 264 | if cross_attn_cond is not None: 265 | null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) 266 | dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) 267 | cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) 268 | 269 | if prepend_cond is not None: 270 | null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) 271 | dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) 272 | prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) 273 | 274 | 275 | if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None): 276 | # Classifier-free guidance 277 | # Concatenate conditioned and unconditioned inputs on the batch dimension 278 | batch_inputs = torch.cat([x, x], dim=0) 279 | batch_timestep = torch.cat([t, t], dim=0) 280 | 281 | if global_embed is not None: 282 | batch_global_cond = torch.cat([global_embed, global_embed], dim=0) 283 | else: 284 | batch_global_cond = None 285 | 286 | if input_concat_cond is not None: 287 | batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) 288 | else: 289 | batch_input_concat_cond = None 290 | 291 | batch_cond = None 292 | batch_cond_masks = None 293 | 294 | # Handle CFG for cross-attention conditioning 295 | if cross_attn_cond is not None: 296 | 297 | null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) 298 | 299 | # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning 300 | if negative_cross_attn_cond is not None: 301 | 302 | # If there's a negative cross-attention mask, set the masked tokens to the null embed 303 | if negative_cross_attn_mask is not None: 304 | negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) 305 | 306 | negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) 307 | 308 | batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) 309 | 310 | else: 311 | batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) 312 | 313 | if cross_attn_cond_mask is not None: 314 | batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) 315 | 316 | batch_prepend_cond = None 317 | batch_prepend_cond_mask 
= None 318 | 319 | if prepend_cond is not None: 320 | 321 | null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) 322 | 323 | batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) 324 | 325 | if prepend_cond_mask is not None: 326 | batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) 327 | 328 | 329 | if mask is not None: 330 | batch_masks = torch.cat([mask, mask], dim=0) 331 | else: 332 | batch_masks = None 333 | 334 | batch_output = self._forward( 335 | batch_inputs, 336 | batch_timestep, 337 | cross_attn_cond=batch_cond, 338 | cross_attn_cond_mask=batch_cond_masks, 339 | mask = batch_masks, 340 | input_concat_cond=batch_input_concat_cond, 341 | global_embed = batch_global_cond, 342 | prepend_cond = batch_prepend_cond, 343 | prepend_cond_mask = batch_prepend_cond_mask, 344 | return_info = return_info, 345 | **kwargs) 346 | 347 | if return_info: 348 | batch_output, info = batch_output 349 | 350 | cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) 351 | cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale 352 | 353 | # CFG Rescale 354 | if scale_phi != 0.0: 355 | cond_out_std = cond_output.std(dim=1, keepdim=True) 356 | out_cfg_std = cfg_output.std(dim=1, keepdim=True) 357 | output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output 358 | else: 359 | output = cfg_output 360 | 361 | if return_info: 362 | return output, info 363 | 364 | return output 365 | 366 | else: 367 | return self._forward( 368 | x, 369 | t, 370 | cross_attn_cond=cross_attn_cond, 371 | cross_attn_cond_mask=cross_attn_cond_mask, 372 | input_concat_cond=input_concat_cond, 373 | global_embed=global_embed, 374 | prepend_cond=prepend_cond, 375 | prepend_cond_mask=prepend_cond_mask, 376 | mask=mask, 377 | return_info=return_info, 378 | **kwargs 379 | ) -------------------------------------------------------------------------------- /stable_audio_tools/models/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def create_model_from_config(model_config): 4 | model_type = model_config.get('model_type', None) 5 | 6 | assert model_type is not None, 'model_type must be specified in model config' 7 | 8 | if model_type == 'autoencoder': 9 | from .autoencoders import create_autoencoder_from_config 10 | return create_autoencoder_from_config(model_config) 11 | elif model_type == 'diffusion_uncond': 12 | from .diffusion import create_diffusion_uncond_from_config 13 | return create_diffusion_uncond_from_config(model_config) 14 | elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior": 15 | from .diffusion import create_diffusion_cond_from_config 16 | return create_diffusion_cond_from_config(model_config) 17 | elif model_type == 'diffusion_autoencoder': 18 | from .autoencoders import create_diffAE_from_config 19 | return create_diffAE_from_config(model_config) 20 | elif model_type == 'lm': 21 | from .lm import create_audio_lm_from_config 22 | return create_audio_lm_from_config(model_config) 23 | else: 24 | raise NotImplementedError(f'Unknown model type: {model_type}') 25 | 26 | def create_model_from_config_path(model_config_path): 27 | with open(model_config_path) as f: 28 | model_config = json.load(f) 29 | 30 | return create_model_from_config(model_config) 31 | 32 | def create_pretransform_from_config(pretransform_config, sample_rate): 33 | pretransform_type = pretransform_config.get('type', 
None) 34 | 35 | assert pretransform_type is not None, 'type must be specified in pretransform config' 36 | 37 | if pretransform_type == 'autoencoder': 38 | from .autoencoders import create_autoencoder_from_config 39 | from .pretransforms import AutoencoderPretransform 40 | 41 | # Create fake top-level config to pass sample rate to autoencoder constructor 42 | # This is a bit of a hack but it keeps us from re-defining the sample rate in the config 43 | autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} 44 | autoencoder = create_autoencoder_from_config(autoencoder_config) 45 | 46 | scale = pretransform_config.get("scale", 1.0) 47 | model_half = pretransform_config.get("model_half", False) 48 | iterate_batch = pretransform_config.get("iterate_batch", False) 49 | chunked = pretransform_config.get("chunked", False) 50 | 51 | pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) 52 | elif pretransform_type == 'wavelet': 53 | from .pretransforms import WaveletPretransform 54 | 55 | wavelet_config = pretransform_config["config"] 56 | channels = wavelet_config["channels"] 57 | levels = wavelet_config["levels"] 58 | wavelet = wavelet_config["wavelet"] 59 | 60 | pretransform = WaveletPretransform(channels, levels, wavelet) 61 | elif pretransform_type == 'pqmf': 62 | from .pretransforms import PQMFPretransform 63 | pqmf_config = pretransform_config["config"] 64 | pretransform = PQMFPretransform(**pqmf_config) 65 | elif pretransform_type == 'dac_pretrained': 66 | from .pretransforms import PretrainedDACPretransform 67 | pretrained_dac_config = pretransform_config["config"] 68 | pretransform = PretrainedDACPretransform(**pretrained_dac_config) 69 | elif pretransform_type == "audiocraft_pretrained": 70 | from .pretransforms import AudiocraftCompressionPretransform 71 | 72 | audiocraft_config = pretransform_config["config"] 73 | pretransform = AudiocraftCompressionPretransform(**audiocraft_config) 74 | else: 75 | raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') 76 | 77 | enable_grad = pretransform_config.get('enable_grad', False) 78 | pretransform.enable_grad = enable_grad 79 | 80 | pretransform.eval().requires_grad_(pretransform.enable_grad) 81 | 82 | return pretransform 83 | 84 | def create_bottleneck_from_config(bottleneck_config): 85 | bottleneck_type = bottleneck_config.get('type', None) 86 | 87 | assert bottleneck_type is not None, 'type must be specified in bottleneck config' 88 | 89 | if bottleneck_type == 'tanh': 90 | from .bottleneck import TanhBottleneck 91 | bottleneck = TanhBottleneck() 92 | elif bottleneck_type == 'vae': 93 | from .bottleneck import VAEBottleneck 94 | bottleneck = VAEBottleneck() 95 | elif bottleneck_type == 'rvq': 96 | from .bottleneck import RVQBottleneck 97 | 98 | quantizer_params = { 99 | "dim": 128, 100 | "codebook_size": 1024, 101 | "num_quantizers": 8, 102 | "decay": 0.99, 103 | "kmeans_init": True, 104 | "kmeans_iters": 50, 105 | "threshold_ema_dead_code": 2, 106 | } 107 | 108 | quantizer_params.update(bottleneck_config["config"]) 109 | 110 | bottleneck = RVQBottleneck(**quantizer_params) 111 | elif bottleneck_type == "dac_rvq": 112 | from .bottleneck import DACRVQBottleneck 113 | 114 | bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) 115 | 116 | elif bottleneck_type == 'rvq_vae': 117 | from .bottleneck import RVQVAEBottleneck 118 | 119 | quantizer_params = { 120 | "dim": 128, 121 | "codebook_size": 1024, 122 | 
"num_quantizers": 8, 123 | "decay": 0.99, 124 | "kmeans_init": True, 125 | "kmeans_iters": 50, 126 | "threshold_ema_dead_code": 2, 127 | } 128 | 129 | quantizer_params.update(bottleneck_config["config"]) 130 | 131 | bottleneck = RVQVAEBottleneck(**quantizer_params) 132 | 133 | elif bottleneck_type == 'dac_rvq_vae': 134 | from .bottleneck import DACRVQVAEBottleneck 135 | bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) 136 | elif bottleneck_type == 'l2_norm': 137 | from .bottleneck import L2Bottleneck 138 | bottleneck = L2Bottleneck() 139 | elif bottleneck_type == "wasserstein": 140 | from .bottleneck import WassersteinBottleneck 141 | bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) 142 | elif bottleneck_type == "fsq": 143 | from .bottleneck import FSQBottleneck 144 | bottleneck = FSQBottleneck(**bottleneck_config["config"]) 145 | else: 146 | raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') 147 | 148 | requires_grad = bottleneck_config.get('requires_grad', True) 149 | if not requires_grad: 150 | for param in bottleneck.parameters(): 151 | param.requires_grad = False 152 | 153 | return bottleneck 154 | -------------------------------------------------------------------------------- /stable_audio_tools/models/local_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | from .blocks import AdaRMSNorm 7 | from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm 8 | 9 | def checkpoint(function, *args, **kwargs): 10 | kwargs.setdefault("use_reentrant", False) 11 | return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) 12 | 13 | # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py 14 | class ContinuousLocalTransformer(nn.Module): 15 | def __init__( 16 | self, 17 | *, 18 | dim, 19 | depth, 20 | dim_in = None, 21 | dim_out = None, 22 | causal = False, 23 | local_attn_window_size = 64, 24 | heads = 8, 25 | ff_mult = 2, 26 | cond_dim = 0, 27 | cross_attn_cond_dim = 0, 28 | **kwargs 29 | ): 30 | super().__init__() 31 | 32 | dim_head = dim//heads 33 | 34 | self.layers = nn.ModuleList([]) 35 | 36 | self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity() 37 | 38 | self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity() 39 | 40 | self.local_attn_window_size = local_attn_window_size 41 | 42 | self.cond_dim = cond_dim 43 | 44 | self.cross_attn_cond_dim = cross_attn_cond_dim 45 | 46 | self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) 47 | 48 | for _ in range(depth): 49 | 50 | self.layers.append(nn.ModuleList([ 51 | AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), 52 | Attention( 53 | dim=dim, 54 | dim_heads=dim_head, 55 | causal=causal, 56 | zero_init_output=True, 57 | natten_kernel_size=local_attn_window_size, 58 | ), 59 | Attention( 60 | dim=dim, 61 | dim_heads=dim_head, 62 | dim_context = cross_attn_cond_dim, 63 | zero_init_output=True 64 | ) if self.cross_attn_cond_dim > 0 else nn.Identity(), 65 | AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), 66 | FeedForward(dim = dim, mult = ff_mult, no_bias=True) 67 | ])) 68 | 69 | def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): 70 | 71 | x = checkpoint(self.project_in, x) 72 | 73 | if prepend_cond is not None: 
74 | x = torch.cat([prepend_cond, x], dim=1) 75 | 76 | pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) 77 | 78 | for attn_norm, attn, xattn, ff_norm, ff in self.layers: 79 | 80 | residual = x 81 | if cond is not None: 82 | x = checkpoint(attn_norm, x, cond) 83 | else: 84 | x = checkpoint(attn_norm, x) 85 | 86 | x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual 87 | 88 | if cross_attn_cond is not None: 89 | x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x 90 | 91 | residual = x 92 | 93 | if cond is not None: 94 | x = checkpoint(ff_norm, x, cond) 95 | else: 96 | x = checkpoint(ff_norm, x) 97 | 98 | x = checkpoint(ff, x) + residual 99 | 100 | return checkpoint(self.project_out, x) 101 | 102 | class TransformerDownsampleBlock1D(nn.Module): 103 | def __init__( 104 | self, 105 | in_channels, 106 | embed_dim = 768, 107 | depth = 3, 108 | heads = 12, 109 | downsample_ratio = 2, 110 | local_attn_window_size = 64, 111 | **kwargs 112 | ): 113 | super().__init__() 114 | 115 | self.downsample_ratio = downsample_ratio 116 | 117 | self.transformer = ContinuousLocalTransformer( 118 | dim=embed_dim, 119 | depth=depth, 120 | heads=heads, 121 | local_attn_window_size=local_attn_window_size, 122 | **kwargs 123 | ) 124 | 125 | self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() 126 | 127 | self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) 128 | 129 | 130 | def forward(self, x): 131 | 132 | x = checkpoint(self.project_in, x) 133 | 134 | # Compute 135 | x = self.transformer(x) 136 | 137 | # Trade sequence length for channels 138 | x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio) 139 | 140 | # Project back to embed dim 141 | x = checkpoint(self.project_down, x) 142 | 143 | return x 144 | 145 | class TransformerUpsampleBlock1D(nn.Module): 146 | def __init__( 147 | self, 148 | in_channels, 149 | embed_dim, 150 | depth = 3, 151 | heads = 12, 152 | upsample_ratio = 2, 153 | local_attn_window_size = 64, 154 | **kwargs 155 | ): 156 | super().__init__() 157 | 158 | self.upsample_ratio = upsample_ratio 159 | 160 | self.transformer = ContinuousLocalTransformer( 161 | dim=embed_dim, 162 | depth=depth, 163 | heads=heads, 164 | local_attn_window_size = local_attn_window_size, 165 | **kwargs 166 | ) 167 | 168 | self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() 169 | 170 | self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) 171 | 172 | def forward(self, x): 173 | 174 | # Project to embed dim 175 | x = checkpoint(self.project_in, x) 176 | 177 | # Project to increase channel dim 178 | x = checkpoint(self.project_up, x) 179 | 180 | # Trade channels for sequence length 181 | x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) 182 | 183 | # Compute 184 | x = self.transformer(x) 185 | 186 | return x 187 | 188 | 189 | class TransformerEncoder1D(nn.Module): 190 | def __init__( 191 | self, 192 | in_channels, 193 | out_channels, 194 | embed_dims = [96, 192, 384, 768], 195 | heads = [12, 12, 12, 12], 196 | depths = [3, 3, 3, 3], 197 | ratios = [2, 2, 2, 2], 198 | local_attn_window_size = 64, 199 | **kwargs 200 | ): 201 | super().__init__() 202 | 203 | layers = [] 204 | 205 | for layer in range(len(depths)): 206 | prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] 207 | 208 | layers.append( 209 | TransformerDownsampleBlock1D( 210 
| in_channels = prev_dim, 211 | embed_dim = embed_dims[layer], 212 | heads = heads[layer], 213 | depth = depths[layer], 214 | downsample_ratio = ratios[layer], 215 | local_attn_window_size = local_attn_window_size, 216 | **kwargs 217 | ) 218 | ) 219 | 220 | self.layers = nn.Sequential(*layers) 221 | 222 | self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) 223 | self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) 224 | 225 | def forward(self, x): 226 | x = rearrange(x, "b c n -> b n c") 227 | x = checkpoint(self.project_in, x) 228 | x = self.layers(x) 229 | x = checkpoint(self.project_out, x) 230 | x = rearrange(x, "b n c -> b c n") 231 | 232 | return x 233 | 234 | 235 | class TransformerDecoder1D(nn.Module): 236 | def __init__( 237 | self, 238 | in_channels, 239 | out_channels, 240 | embed_dims = [768, 384, 192, 96], 241 | heads = [12, 12, 12, 12], 242 | depths = [3, 3, 3, 3], 243 | ratios = [2, 2, 2, 2], 244 | local_attn_window_size = 64, 245 | **kwargs 246 | ): 247 | 248 | super().__init__() 249 | 250 | layers = [] 251 | 252 | for layer in range(len(depths)): 253 | prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] 254 | 255 | layers.append( 256 | TransformerUpsampleBlock1D( 257 | in_channels = prev_dim, 258 | embed_dim = embed_dims[layer], 259 | heads = heads[layer], 260 | depth = depths[layer], 261 | upsample_ratio = ratios[layer], 262 | local_attn_window_size = local_attn_window_size, 263 | **kwargs 264 | ) 265 | ) 266 | 267 | self.layers = nn.Sequential(*layers) 268 | 269 | self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) 270 | self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) 271 | 272 | def forward(self, x): 273 | x = rearrange(x, "b c n -> b n c") 274 | x = checkpoint(self.project_in, x) 275 | x = self.layers(x) 276 | x = checkpoint(self.project_out, x) 277 | x = rearrange(x, "b n c -> b c n") 278 | return x -------------------------------------------------------------------------------- /stable_audio_tools/models/pqmf.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | from scipy.optimize import fmin 7 | from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord 8 | 9 | class PQMF(nn.Module): 10 | """ 11 | Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. 12 | Uses polyphase representation which is computationally more efficient for real-time. 13 | 14 | Parameters: 15 | - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. 16 | - num_bands (int): Number of desired frequency bands. It must be a power of 2. 17 | """ 18 | 19 | def __init__(self, attenuation, num_bands): 20 | super(PQMF, self).__init__() 21 | 22 | # Ensure num_bands is a power of 2 23 | is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) 24 | assert is_power_of_2, "'num_bands' must be a power of 2." 
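        # Illustrative round trip (a sketch, not from the original source; the
        # sample length 16384 is arbitrary):
        #
        #   pqmf = PQMF(attenuation=100, num_bands=16)
        #   bands = pqmf(torch.randn(1, 1, 16384))   # -> (1, 1, 16, 1024)
        #   audio = pqmf.inverse(bands)              # -> (1, 1, 16384), close to the input up to filter delay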
25 | 26 | # Create the prototype filter 27 | prototype_filter = design_prototype_filter(attenuation, num_bands) 28 | filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) 29 | padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) 30 | 31 | # Register filters and settings 32 | self.register_buffer("filter_bank", padded_filter_bank) 33 | self.register_buffer("prototype", prototype_filter) 34 | self.num_bands = num_bands 35 | 36 | def forward(self, signal): 37 | """Decompose the signal into multiple frequency bands.""" 38 | # If signal is not a pytorch tensor of Batch x Channels x Length, convert it 39 | signal = prepare_signal_dimensions(signal) 40 | # The signal length must be a multiple of num_bands. Pad it with zeros. 41 | signal = pad_signal(signal, self.num_bands) 42 | # run it 43 | signal = polyphase_analysis(signal, self.filter_bank) 44 | return apply_alias_cancellation(signal) 45 | 46 | def inverse(self, bands): 47 | """Reconstruct the original signal from the frequency bands.""" 48 | bands = apply_alias_cancellation(bands) 49 | return polyphase_synthesis(bands, self.filter_bank) 50 | 51 | 52 | def prepare_signal_dimensions(signal): 53 | """ 54 | Rearrange signal into Batch x Channels x Length. 55 | 56 | Parameters 57 | ---------- 58 | signal : torch.Tensor or numpy.ndarray 59 | The input signal. 60 | 61 | Returns 62 | ------- 63 | torch.Tensor 64 | Preprocessed signal tensor. 65 | """ 66 | # Convert numpy to torch tensor 67 | if isinstance(signal, np.ndarray): 68 | signal = torch.from_numpy(signal) 69 | 70 | # Ensure tensor 71 | if not isinstance(signal, torch.Tensor): 72 | raise ValueError("Input should be either a numpy array or a PyTorch tensor.") 73 | 74 | # Modify dimension of signal to Batch x Channels x Length 75 | if signal.dim() == 1: 76 | # This is just a mono signal. Unsqueeze to 1 x 1 x Length 77 | signal = signal.unsqueeze(0).unsqueeze(0) 78 | elif signal.dim() == 2: 79 | # This is a multi-channel signal (e.g. stereo) 80 | # Rearrange so that larger dimension (Length) is last 81 | if signal.shape[0] > signal.shape[1]: 82 | signal = signal.T 83 | # Unsqueeze to 1 x Channels x Length 84 | signal = signal.unsqueeze(0) 85 | return signal 86 | 87 | def pad_signal(signal, num_bands): 88 | """ 89 | Pads the signal to make its length divisible by the given number of bands. 90 | 91 | Parameters 92 | ---------- 93 | signal : torch.Tensor 94 | The input signal tensor, where the last dimension represents the signal length. 95 | 96 | num_bands : int 97 | The number of bands by which the signal length should be divisible. 98 | 99 | Returns 100 | ------- 101 | torch.Tensor 102 | The padded signal tensor. If the original signal length was already divisible 103 | by num_bands, returns the original signal unchanged. 104 | """ 105 | remainder = signal.shape[-1] % num_bands 106 | if remainder > 0: 107 | padding_size = num_bands - remainder 108 | signal = nn.functional.pad(signal, (0, padding_size)) 109 | return signal 110 | 111 | def generate_modulated_filter_bank(prototype_filter, num_bands): 112 | """ 113 | Generate a QMF bank of cosine modulated filters based on a given prototype filter. 114 | 115 | Parameters 116 | ---------- 117 | prototype_filter : torch.Tensor 118 | The prototype filter used as the basis for modulation. 119 | num_bands : int 120 | The number of desired subbands or filters. 121 | 122 | Returns 123 | ------- 124 | torch.Tensor 125 | A bank of cosine modulated filters. 126 | """ 127 | 128 | # Initialize indices for modulation. 
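    # The lines below construct all num_bands filters at once. Each filter k is
    # the prototype p[n] modulated by a cosine:
    #     h_k[n] = 2 * p[n] * cos((2k + 1) * pi / (2 * num_bands) * n + phi_k)
    # where n is the symmetric time index built below and phi_k = (-1)^k * pi / 4
    # is the phase offset that allows aliasing between adjacent bands to cancel
    # during reconstruction.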
129 | subband_indices = torch.arange(num_bands).reshape(-1, 1) 130 | 131 | # Calculate the length of the prototype filter. 132 | filter_length = prototype_filter.shape[-1] 133 | 134 | # Generate symmetric time indices centered around zero. 135 | time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) 136 | 137 | # Calculate phase offsets to ensure orthogonality between subbands. 138 | phase_offsets = (-1)**subband_indices * np.pi / 4 139 | 140 | # Compute the cosine modulation function. 141 | modulation = torch.cos( 142 | (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets 143 | ) 144 | 145 | # Apply modulation to the prototype filter. 146 | modulated_filters = 2 * prototype_filter * modulation 147 | 148 | return modulated_filters 149 | 150 | 151 | def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): 152 | """ 153 | Design a lowpass filter using the Kaiser window. 154 | 155 | Parameters 156 | ---------- 157 | angular_cutoff : float 158 | The angular frequency cutoff of the filter. 159 | attenuation : float 160 | The desired stopband attenuation in decibels (dB). 161 | filter_length : int, optional 162 | Desired length of the filter. If not provided, it's computed based on the given specs. 163 | 164 | Returns 165 | ------- 166 | ndarray 167 | The designed lowpass filter coefficients. 168 | """ 169 | 170 | estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) 171 | 172 | # Ensure the estimated length is odd. 173 | estimated_length = 2 * (estimated_length // 2) + 1 174 | 175 | if filter_length is None: 176 | filter_length = estimated_length 177 | 178 | return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) 179 | 180 | 181 | def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): 182 | """ 183 | Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 184 | 185 | Parameters 186 | ---------- 187 | angular_cutoff : float 188 | Angular frequency cutoff of the filter. 189 | attenuation : float 190 | Desired stopband attenuation in dB. 191 | num_bands : int 192 | Number of bands for the multiband filter system. 193 | filter_length : int, optional 194 | Desired length of the filter. 195 | 196 | Returns 197 | ------- 198 | float 199 | The computed objective (loss) value for the given filter specs. 200 | """ 201 | 202 | filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) 203 | convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") 204 | 205 | return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) 206 | 207 | 208 | def design_prototype_filter(attenuation, num_bands, filter_length=None): 209 | """ 210 | Design the optimal prototype filter for a multiband system given the desired specs. 211 | 212 | Parameters 213 | ---------- 214 | attenuation : float 215 | The desired stopband attenuation in dB. 216 | num_bands : int 217 | Number of bands for the multiband filter system. 218 | filter_length : int, optional 219 | Desired length of the filter. If not provided, it's computed based on the given specs. 220 | 221 | Returns 222 | ------- 223 | ndarray 224 | The optimal prototype filter coefficients. 
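    Notes
    -----
    The cutoff is found with scipy.optimize.fmin, starting from 1 / num_bands
    and minimizing the objective of evaluate_filter_objective: the largest
    aliasing term of the filter's autocorrelation sampled every 2 * num_bands
    lags (excluding the zero-lag term), following the criterion of the
    referenced paper.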
225 | """ 226 | 227 | optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), 228 | 1 / num_bands, disp=0)[0] 229 | 230 | prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) 231 | return torch.tensor(prototype_filter, dtype=torch.float32) 232 | 233 | def pad_to_nearest_power_of_two(x): 234 | """ 235 | Pads the input tensor 'x' on both sides such that its last dimension 236 | becomes the nearest larger power of two. 237 | 238 | Parameters: 239 | ----------- 240 | x : torch.Tensor 241 | The input tensor to be padded. 242 | 243 | Returns: 244 | -------- 245 | torch.Tensor 246 | The padded tensor. 247 | """ 248 | current_length = x.shape[-1] 249 | target_length = 2**math.ceil(math.log2(current_length)) 250 | 251 | total_padding = target_length - current_length 252 | left_padding = total_padding // 2 253 | right_padding = total_padding - left_padding 254 | 255 | return nn.functional.pad(x, (left_padding, right_padding)) 256 | 257 | def apply_alias_cancellation(x): 258 | """ 259 | Applies alias cancellation by inverting the sign of every 260 | second element of every second row, starting from the second 261 | row's first element in a tensor. 262 | 263 | This operation helps ensure that the aliasing introduced in 264 | each band during the decomposition will be counteracted during 265 | the reconstruction. 266 | 267 | Parameters: 268 | ----------- 269 | x : torch.Tensor 270 | The input tensor. 271 | 272 | Returns: 273 | -------- 274 | torch.Tensor 275 | Tensor with specific elements' sign inverted for alias cancellation. 276 | """ 277 | 278 | # Create a mask of the same shape as 'x', initialized with all ones 279 | mask = torch.ones_like(x) 280 | 281 | # Update specific elements in the mask to -1 to perform inversion 282 | mask[..., 1::2, ::2] = -1 283 | 284 | # Apply the mask to the input tensor 'x' 285 | return x * mask 286 | 287 | def ensure_odd_length(tensor): 288 | """ 289 | Pads the last dimension of a tensor to ensure its size is odd. 290 | 291 | Parameters: 292 | ----------- 293 | tensor : torch.Tensor 294 | Input tensor whose last dimension might need padding. 295 | 296 | Returns: 297 | -------- 298 | torch.Tensor 299 | The original tensor if its last dimension was already odd, 300 | or the padded tensor with an odd-sized last dimension. 301 | """ 302 | 303 | last_dim_size = tensor.shape[-1] 304 | 305 | if last_dim_size % 2 == 0: 306 | tensor = nn.functional.pad(tensor, (0, 1)) 307 | 308 | return tensor 309 | 310 | def polyphase_analysis(signal, filter_bank): 311 | """ 312 | Applies the polyphase method to efficiently analyze the signal using a filter bank. 313 | 314 | Parameters: 315 | ----------- 316 | signal : torch.Tensor 317 | Input signal tensor with shape (Batch x Channels x Length). 318 | 319 | filter_bank : torch.Tensor 320 | Filter bank tensor with shape (Bands x Length). 321 | 322 | Returns: 323 | -------- 324 | torch.Tensor 325 | Signal split into sub-bands. (Batch x Channels x Bands x Length) 326 | """ 327 | 328 | num_bands = filter_bank.shape[0] 329 | num_channels = signal.shape[1] 330 | 331 | # Rearrange signal for polyphase processing. 332 | # Also combine Batch x Channel into one dimension for now. 
333 | #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) 334 | signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) 335 | 336 | # Rearrange the filter bank for matching signal shape 337 | filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) 338 | 339 | # Apply convolution with appropriate padding to maintain spatial dimensions 340 | padding = filter_bank.shape[-1] // 2 341 | filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) 342 | 343 | # Truncate the last dimension post-convolution to adjust the output shape 344 | filtered_signal = filtered_signal[..., :-1] 345 | # Rearrange the first dimension back into Batch x Channels 346 | filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels) 347 | 348 | return filtered_signal 349 | 350 | def polyphase_synthesis(signal, filter_bank): 351 | """ 352 | Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. 353 | 354 | Parameters 355 | ---------- 356 | signal : torch.Tensor 357 | Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). 358 | 359 | filter_bank : torch.Tensor 360 | Analysis filter bank (shape: Bands x Length). 361 | 362 | should_rearrange : bool, optional 363 | Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. 364 | 365 | Returns 366 | ------- 367 | torch.Tensor 368 | Reconstructed signal (shape: Batch x Channels X Length) 369 | """ 370 | 371 | num_bands = filter_bank.shape[0] 372 | num_channels = signal.shape[1] 373 | 374 | # Rearrange the filter bank 375 | filter_bank = filter_bank.flip(-1) 376 | filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) 377 | 378 | # Combine Batch x Channels into one dimension for now. 
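    # Synthesis mirrors the analysis step: the band signals are convolved with
    # the time-reversed filter bank at the decimated rate, then the final
    # rearrange interleaves the num_bands outputs back into a full-rate signal,
    # which is scaled by num_bands and trimmed to drop the filter transient.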
379 | signal = rearrange(signal, "b c n t -> (b c) n t") 380 | 381 | # Apply convolution with appropriate padding 382 | padding_amount = filter_bank.shape[-1] // 2 + 1 383 | reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) 384 | 385 | # Scale the result 386 | reconstructed_signal = reconstructed_signal[..., :-1] * num_bands 387 | 388 | # Reorganize the output and truncate 389 | reconstructed_signal = reconstructed_signal.flip(1) 390 | reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) 391 | reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] 392 | 393 | return reconstructed_signal -------------------------------------------------------------------------------- /stable_audio_tools/models/pretrained.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .factory import create_model_from_config 4 | from .utils import load_ckpt_state_dict 5 | 6 | from huggingface_hub import hf_hub_download 7 | 8 | def get_pretrained_model(name: str): 9 | 10 | model_config_path = hf_hub_download(name, filename="config.json", repo_type='model') 11 | 12 | with open(model_config_path) as f: 13 | model_config = json.load(f) 14 | 15 | model = create_model_from_config(model_config) 16 | 17 | # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file 18 | try: 19 | model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') 20 | except Exception as e: 21 | model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') 22 | 23 | model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) 24 | 25 | return model, model_config -------------------------------------------------------------------------------- /stable_audio_tools/models/pretransforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | 5 | class Pretransform(nn.Module): 6 | def __init__(self, enable_grad, io_channels, is_discrete): 7 | super().__init__() 8 | 9 | self.is_discrete = is_discrete 10 | self.io_channels = io_channels 11 | self.encoded_channels = None 12 | self.downsampling_ratio = None 13 | 14 | self.enable_grad = enable_grad 15 | 16 | def encode(self, x): 17 | raise NotImplementedError 18 | 19 | def decode(self, z): 20 | raise NotImplementedError 21 | 22 | def tokenize(self, x): 23 | raise NotImplementedError 24 | 25 | def decode_tokens(self, tokens): 26 | raise NotImplementedError 27 | 28 | class AutoencoderPretransform(Pretransform): 29 | def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): 30 | super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) 31 | self.model = model 32 | self.model.requires_grad_(False).eval() 33 | self.scale=scale 34 | self.downsampling_ratio = model.downsampling_ratio 35 | self.io_channels = model.io_channels 36 | self.sample_rate = model.sample_rate 37 | 38 | self.model_half = model_half 39 | self.iterate_batch = iterate_batch 40 | 41 | self.encoded_channels = model.latent_dim 42 | 43 | self.chunked = chunked 44 | self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None 45 | self.codebook_size = model.bottleneck.codebook_size if 
model.bottleneck is not None and model.bottleneck.is_discrete else None 46 | 47 | if self.model_half: 48 | self.model.half() 49 | 50 | def encode(self, x, **kwargs): 51 | 52 | if self.model_half: 53 | x = x.half() 54 | self.model.to(torch.float16) 55 | 56 | encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) 57 | 58 | if self.model_half: 59 | encoded = encoded.float() 60 | 61 | return encoded / self.scale 62 | 63 | def decode(self, z, **kwargs): 64 | z = z * self.scale 65 | 66 | if self.model_half: 67 | z = z.half() 68 | self.model.to(torch.float16) 69 | 70 | decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) 71 | 72 | if self.model_half: 73 | decoded = decoded.float() 74 | 75 | return decoded 76 | 77 | def tokenize(self, x, **kwargs): 78 | assert self.model.is_discrete, "Cannot tokenize with a continuous model" 79 | 80 | _, info = self.model.encode(x, return_info = True, **kwargs) 81 | 82 | return info[self.model.bottleneck.tokens_id] 83 | 84 | def decode_tokens(self, tokens, **kwargs): 85 | assert self.model.is_discrete, "Cannot decode tokens with a continuous model" 86 | 87 | return self.model.decode_tokens(tokens, **kwargs) 88 | 89 | def load_state_dict(self, state_dict, strict=True): 90 | self.model.load_state_dict(state_dict, strict=strict) 91 | 92 | class WaveletPretransform(Pretransform): 93 | def __init__(self, channels, levels, wavelet): 94 | super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) 95 | 96 | from .wavelets import WaveletEncode1d, WaveletDecode1d 97 | 98 | self.encoder = WaveletEncode1d(channels, levels, wavelet) 99 | self.decoder = WaveletDecode1d(channels, levels, wavelet) 100 | 101 | self.downsampling_ratio = 2 ** levels 102 | self.io_channels = channels 103 | self.encoded_channels = channels * self.downsampling_ratio 104 | 105 | def encode(self, x): 106 | return self.encoder(x) 107 | 108 | def decode(self, z): 109 | return self.decoder(z) 110 | 111 | class PQMFPretransform(Pretransform): 112 | def __init__(self, attenuation=100, num_bands=16): 113 | # TODO: Fix PQMF to take in in-channels 114 | super().__init__(enable_grad=False, io_channels=1, is_discrete=False) 115 | from .pqmf import PQMF 116 | self.pqmf = PQMF(attenuation, num_bands) 117 | 118 | 119 | def encode(self, x): 120 | # x is (Batch x Channels x Time) 121 | x = self.pqmf.forward(x) 122 | # pqmf.forward returns (Batch x Channels x Bands x Time) 123 | # but Pretransform needs Batch x Channels x Time 124 | # so concatenate channels and bands into one axis 125 | return rearrange(x, "b c n t -> b (c n) t") 126 | 127 | def decode(self, x): 128 | # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) 129 | x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) 130 | # returns (Batch x Channels x Time) 131 | return self.pqmf.inverse(x) 132 | 133 | class PretrainedDACPretransform(Pretransform): 134 | def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): 135 | super().__init__(enable_grad=False, io_channels=1, is_discrete=True) 136 | 137 | import dac 138 | 139 | model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) 140 | 141 | self.model = dac.DAC.load(model_path) 142 | 143 | self.quantize_on_decode = quantize_on_decode 144 | 145 | if model_type == "44khz": 146 | self.downsampling_ratio = 512 147 | else: 148 | self.downsampling_ratio = 320 149 | 
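        # Hop sizes of the pretrained DAC codecs: the 44.1 kHz model emits one
        # latent frame per 512 samples, the other variants one per 320 samples.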
150 | self.io_channels = 1 151 | 152 | self.scale = scale 153 | 154 | self.chunked = chunked 155 | 156 | self.encoded_channels = self.model.latent_dim 157 | 158 | self.num_quantizers = self.model.n_codebooks 159 | 160 | self.codebook_size = self.model.codebook_size 161 | 162 | def encode(self, x): 163 | 164 | latents = self.model.encoder(x) 165 | 166 | if self.quantize_on_decode: 167 | output = latents 168 | else: 169 | z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) 170 | output = z 171 | 172 | if self.scale != 1.0: 173 | output = output / self.scale 174 | 175 | return output 176 | 177 | def decode(self, z): 178 | 179 | if self.scale != 1.0: 180 | z = z * self.scale 181 | 182 | if self.quantize_on_decode: 183 | z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) 184 | 185 | return self.model.decode(z) 186 | 187 | def tokenize(self, x): 188 | return self.model.encode(x)[1] 189 | 190 | def decode_tokens(self, tokens): 191 | latents = self.model.quantizer.from_codes(tokens) 192 | return self.model.decode(latents) 193 | 194 | class AudiocraftCompressionPretransform(Pretransform): 195 | def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): 196 | super().__init__(enable_grad=False, io_channels=1, is_discrete=True) 197 | 198 | try: 199 | from audiocraft.models import CompressionModel 200 | except ImportError: 201 | raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") 202 | 203 | self.model = CompressionModel.get_pretrained(model_type) 204 | 205 | self.quantize_on_decode = quantize_on_decode 206 | 207 | self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) 208 | 209 | self.sample_rate = self.model.sample_rate 210 | 211 | self.io_channels = self.model.channels 212 | 213 | self.scale = scale 214 | 215 | #self.encoded_channels = self.model.latent_dim 216 | 217 | self.num_quantizers = self.model.num_codebooks 218 | 219 | self.codebook_size = self.model.cardinality 220 | 221 | self.model.to(torch.float16).eval().requires_grad_(False) 222 | 223 | def encode(self, x): 224 | 225 | assert False, "Audiocraft compression models do not support continuous encoding" 226 | 227 | # latents = self.model.encoder(x) 228 | 229 | # if self.quantize_on_decode: 230 | # output = latents 231 | # else: 232 | # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) 233 | # output = z 234 | 235 | # if self.scale != 1.0: 236 | # output = output / self.scale 237 | 238 | # return output 239 | 240 | def decode(self, z): 241 | 242 | assert False, "Audiocraft compression models do not support continuous decoding" 243 | 244 | # if self.scale != 1.0: 245 | # z = z * self.scale 246 | 247 | # if self.quantize_on_decode: 248 | # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) 249 | 250 | # return self.model.decode(z) 251 | 252 | def tokenize(self, x): 253 | with torch.cuda.amp.autocast(enabled=False): 254 | return self.model.encode(x.to(torch.float16))[0] 255 | 256 | def decode_tokens(self, tokens): 257 | with torch.cuda.amp.autocast(enabled=False): 258 | return self.model.decode(tokens) 259 | -------------------------------------------------------------------------------- /stable_audio_tools/models/temptransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from 
einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | class Residual(nn.Module): 9 | def __init__(self, fn): 10 | super().__init__() 11 | self.fn = fn 12 | def forward(self, x, **kwargs): 13 | return self.fn(x, **kwargs) + x 14 | 15 | class SA_PreNorm(nn.Module): 16 | def __init__(self, dim, fn): 17 | super().__init__() 18 | self.norm = nn.LayerNorm(dim) 19 | self.fn = fn 20 | def forward(self, x, **kwargs): 21 | return self.fn(self.norm(x), **kwargs) 22 | 23 | class SA_FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.Linear(dim, hidden_dim), 28 | nn.GELU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_dim, dim), 31 | nn.Dropout(dropout) 32 | ) 33 | def forward(self, x): 34 | return self.net(x) 35 | 36 | class SA_Attention(nn.Module): 37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 38 | super().__init__() 39 | inner_dim = dim_head * heads 40 | project_out = not (heads == 1 and dim_head == dim) 41 | 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | 45 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 46 | 47 | self.to_out = nn.Sequential( 48 | nn.Linear(inner_dim, dim), 49 | nn.Dropout(dropout) 50 | ) if project_out else nn.Identity() 51 | 52 | def forward(self, x): 53 | b, n, _, h = *x.shape, self.heads 54 | qkv = self.to_qkv(x).chunk(3, dim = -1) 55 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 56 | 57 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 58 | 59 | attn = dots.softmax(dim=-1) 60 | 61 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 62 | out = rearrange(out, 'b h n d -> b n (h d)') 63 | out = self.to_out(out) 64 | return out 65 | 66 | 67 | class ReAttention(nn.Module): 68 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 69 | super().__init__() 70 | inner_dim = dim_head * heads 71 | self.heads = heads 72 | self.scale = dim_head ** -0.5 73 | 74 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 75 | 76 | self.reattn_weights = nn.Parameter(torch.randn(heads, heads)) 77 | 78 | self.reattn_norm = nn.Sequential( 79 | Rearrange('b h i j -> b i j h'), 80 | nn.LayerNorm(heads), 81 | Rearrange('b i j h -> b h i j') 82 | ) 83 | 84 | self.to_out = nn.Sequential( 85 | nn.Linear(inner_dim, dim), 86 | nn.Dropout(dropout) 87 | ) 88 | 89 | def forward(self, x): 90 | b, n, _, h = *x.shape, self.heads 91 | qkv = self.to_qkv(x).chunk(3, dim = -1) 92 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 93 | 94 | # attention 95 | 96 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 97 | attn = dots.softmax(dim=-1) 98 | 99 | # re-attention 100 | 101 | attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights) 102 | attn = self.reattn_norm(attn) 103 | 104 | # aggregate and out 105 | 106 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 107 | out = rearrange(out, 'b h n d -> b n (h d)') 108 | out = self.to_out(out) 109 | return out 110 | 111 | class LeFF(nn.Module): 112 | 113 | def __init__(self, dim = 192, scale = 4, depth_kernel = 3): 114 | super().__init__() 115 | 116 | scale_dim = dim*scale 117 | self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim), 118 | Rearrange('b n c -> b c n'), 119 | nn.BatchNorm1d(scale_dim), 120 | nn.GELU(), 121 | Rearrange('b c (h w) -> b c h w', h=14, w=14) 122 | ) 123 | 124 | self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, 
padding=1, groups=scale_dim, bias=False), 125 | nn.BatchNorm2d(scale_dim), 126 | nn.GELU(), 127 | Rearrange('b c h w -> b (h w) c', h=14, w=14) 128 | ) 129 | 130 | self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim), 131 | Rearrange('b n c -> b c n'), 132 | nn.BatchNorm1d(dim), 133 | nn.GELU(), 134 | Rearrange('b c n -> b n c') 135 | ) 136 | 137 | def forward(self, x): 138 | x = self.up_proj(x) 139 | x = self.depth_conv(x) 140 | x = self.down_proj(x) 141 | return x 142 | 143 | 144 | class LCAttention(nn.Module): 145 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 146 | super().__init__() 147 | inner_dim = dim_head * heads 148 | project_out = not (heads == 1 and dim_head == dim) 149 | 150 | self.heads = heads 151 | self.scale = dim_head ** -0.5 152 | 153 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 154 | 155 | self.to_out = nn.Sequential( 156 | nn.Linear(inner_dim, dim), 157 | nn.Dropout(dropout) 158 | ) if project_out else nn.Identity() 159 | 160 | def forward(self, x): 161 | b, n, _, h = *x.shape, self.heads 162 | qkv = self.to_qkv(x).chunk(3, dim = -1) 163 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 164 | q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query 165 | 166 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 167 | 168 | attn = dots.softmax(dim=-1) 169 | 170 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 171 | out = rearrange(out, 'b h n d -> b n (h d)') 172 | out = self.to_out(out) 173 | return out 174 | 175 | class SA_Transformer(nn.Module): 176 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 177 | super().__init__() 178 | self.layers = nn.ModuleList([]) 179 | self.norm = nn.LayerNorm(dim) 180 | for _ in range(depth): 181 | self.layers.append(nn.ModuleList([ 182 | SA_PreNorm(dim, SA_Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 183 | SA_PreNorm(dim, SA_FeedForward(dim, mlp_dim, dropout = dropout)) 184 | ])) 185 | 186 | def forward(self, x): 187 | for attn, ff in self.layers: 188 | x = attn(x) + x 189 | x = ff(x) + x 190 | return self.norm(x) -------------------------------------------------------------------------------- /stable_audio_tools/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file 3 | 4 | from torch.nn.utils import remove_weight_norm 5 | import warnings 6 | warnings.simplefilter(action='ignore', category=FutureWarning) 7 | 8 | 9 | def load_ckpt_state_dict(ckpt_path): 10 | if ckpt_path.endswith(".safetensors"): 11 | state_dict = load_file(ckpt_path) 12 | else: 13 | state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] 14 | 15 | return state_dict 16 | 17 | def remove_weight_norm_from_model(model): 18 | for module in model.modules(): 19 | if hasattr(module, "weight"): 20 | print(f"Removing weight norm from {module}") 21 | remove_weight_norm(module) 22 | 23 | return model 24 | 25 | # Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license 26 | # License can be found in LICENSES/LICENSE_META.txt 27 | 28 | def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): 29 | """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. 30 | 31 | Args: 32 | input (torch.Tensor): The input tensor containing probabilities. 
33 | num_samples (int): Number of samples to draw. 34 | replacement (bool): Whether to draw with replacement or not. 35 | Keywords args: 36 | generator (torch.Generator): A pseudorandom number generator for sampling. 37 | Returns: 38 | torch.Tensor: Last dimension contains num_samples indices 39 | sampled from the multinomial probability distribution 40 | located in the last dimension of tensor input. 41 | """ 42 | 43 | if num_samples == 1: 44 | q = torch.empty_like(input).exponential_(1, generator=generator) 45 | return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) 46 | 47 | input_ = input.reshape(-1, input.shape[-1]) 48 | output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) 49 | output = output_.reshape(*list(input.shape[:-1]), -1) 50 | return output 51 | 52 | 53 | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: 54 | """Sample next token from top K values along the last dimension of the input probs tensor. 55 | 56 | Args: 57 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 58 | k (int): The k in “top-k”. 59 | Returns: 60 | torch.Tensor: Sampled tokens. 61 | """ 62 | top_k_value, _ = torch.topk(probs, k, dim=-1) 63 | min_value_top_k = top_k_value[..., [-1]] 64 | probs *= (probs >= min_value_top_k).float() 65 | probs.div_(probs.sum(dim=-1, keepdim=True)) 66 | next_token = multinomial(probs, num_samples=1) 67 | return next_token 68 | 69 | 70 | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: 71 | """Sample next token from top P probabilities along the last dimension of the input probs tensor. 72 | 73 | Args: 74 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 75 | p (int): The p in “top-p”. 76 | Returns: 77 | torch.Tensor: Sampled tokens. 
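    Example (illustrative; ``logits`` stands for any tensor of unnormalized scores):
        probs = torch.softmax(logits, dim=-1)    # (batch, seq, vocab)
        token = sample_top_p(probs, p=0.9)       # (batch, seq, 1) sampled indices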
78 | """ 79 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 80 | probs_sum = torch.cumsum(probs_sort, dim=-1) 81 | mask = probs_sum - probs_sort > p 82 | probs_sort *= (~mask).float() 83 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 84 | next_token = multinomial(probs_sort, num_samples=1) 85 | next_token = torch.gather(probs_idx, -1, next_token) 86 | return next_token 87 | 88 | def next_power_of_two(n): 89 | return 2 ** (n - 1).bit_length() 90 | 91 | def next_multiple_of_64(n): 92 | return ((n + 63) // 64) * 64 -------------------------------------------------------------------------------- /stable_audio_tools/models/wavelets.py: -------------------------------------------------------------------------------- 1 | """The 1D discrete wavelet transform for PyTorch.""" 2 | 3 | from einops import rearrange 4 | import pywt 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from typing import Literal 9 | 10 | 11 | def get_filter_bank(wavelet): 12 | filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank) 13 | if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): 14 | filt = filt[:, 1:] 15 | return filt 16 | 17 | class WaveletEncode1d(nn.Module): 18 | def __init__(self, 19 | channels, 20 | levels, 21 | wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): 22 | super().__init__() 23 | self.wavelet = wavelet 24 | self.channels = channels 25 | self.levels = levels 26 | filt = get_filter_bank(wavelet) 27 | assert filt.shape[-1] % 2 == 1 28 | kernel = filt[:2, None] 29 | kernel = torch.flip(kernel, dims=(-1,)) 30 | index_i = torch.repeat_interleave(torch.arange(2), channels) 31 | index_j = torch.tile(torch.arange(channels), (2,)) 32 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 33 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 34 | self.register_buffer("kernel", kernel_final) 35 | 36 | def forward(self, x): 37 | for i in range(self.levels): 38 | low, rest = x[:, : self.channels], x[:, self.channels :] 39 | pad = self.kernel.shape[-1] // 2 40 | low = F.pad(low, (pad, pad), "reflect") 41 | low = F.conv1d(low, self.kernel, stride=2) 42 | rest = rearrange( 43 | rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels 44 | ) 45 | x = torch.cat([low, rest], dim=1) 46 | return x 47 | 48 | 49 | class WaveletDecode1d(nn.Module): 50 | def __init__(self, 51 | channels, 52 | levels, 53 | wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): 54 | super().__init__() 55 | self.wavelet = wavelet 56 | self.channels = channels 57 | self.levels = levels 58 | filt = get_filter_bank(wavelet) 59 | assert filt.shape[-1] % 2 == 1 60 | kernel = filt[2:, None] 61 | index_i = torch.repeat_interleave(torch.arange(2), channels) 62 | index_j = torch.tile(torch.arange(channels), (2,)) 63 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 64 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 65 | self.register_buffer("kernel", kernel_final) 66 | 67 | def forward(self, x): 68 | for i in range(self.levels): 69 | low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] 70 | pad = self.kernel.shape[-1] // 2 + 2 71 | low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) 72 | low = F.pad(low, (pad, pad), "reflect") 73 | low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) 74 | low = F.conv_transpose1d( 75 | low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 76 | 
) 77 | low = low[..., pad - 1 : -pad] 78 | rest = rearrange( 79 | rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels 80 | ) 81 | x = torch.cat([low, rest], dim=1) 82 | return x -------------------------------------------------------------------------------- /stable_audio_tools/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import create_training_wrapper_from_config, create_demo_callback_from_config 2 | -------------------------------------------------------------------------------- /stable_audio_tools/training/autoencoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import wandb 4 | from einops import rearrange 5 | from safetensors.torch import save_file, save_model 6 | from ema_pytorch import EMA 7 | from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss 8 | import pytorch_lightning as pl 9 | from ..models.autoencoders import AudioAutoencoder 10 | from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss 11 | from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck 12 | from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss 13 | from .utils import create_optimizer_from_config, create_scheduler_from_config 14 | 15 | 16 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 17 | from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image 18 | 19 | class AutoencoderTrainingWrapper(pl.LightningModule): 20 | def __init__( 21 | self, 22 | autoencoder: AudioAutoencoder, 23 | lr: float = 1e-4, 24 | warmup_steps: int = 0, 25 | encoder_freeze_on_warmup: bool = False, 26 | sample_rate=48000, 27 | loss_config: dict = None, 28 | optimizer_configs: dict = None, 29 | use_ema: bool = True, 30 | ema_copy = None, 31 | force_input_mono = False, 32 | latent_mask_ratio = 0.0, 33 | teacher_model: AudioAutoencoder = None 34 | ): 35 | super().__init__() 36 | 37 | self.automatic_optimization = False 38 | 39 | self.autoencoder = autoencoder 40 | 41 | self.warmed_up = False 42 | self.warmup_steps = warmup_steps 43 | self.encoder_freeze_on_warmup = encoder_freeze_on_warmup 44 | self.lr = lr 45 | 46 | self.force_input_mono = force_input_mono 47 | 48 | self.teacher_model = teacher_model 49 | 50 | if optimizer_configs is None: 51 | optimizer_configs ={ 52 | "autoencoder": { 53 | "optimizer": { 54 | "type": "AdamW", 55 | "config": { 56 | "lr": lr, 57 | "betas": (.8, .99) 58 | } 59 | } 60 | }, 61 | "discriminator": { 62 | "optimizer": { 63 | "type": "AdamW", 64 | "config": { 65 | "lr": lr, 66 | "betas": (.8, .99) 67 | } 68 | } 69 | } 70 | 71 | } 72 | 73 | self.optimizer_configs = optimizer_configs 74 | 75 | if loss_config is None: 76 | scales = [2048, 1024, 512, 256, 128, 64, 32] 77 | hop_sizes = [] 78 | win_lengths = [] 79 | overlap = 0.75 80 | for s in scales: 81 | hop_sizes.append(int(s * (1 - overlap))) 82 | win_lengths.append(s) 83 | 84 | loss_config = { 85 | "discriminator": { 86 | "type": "encodec", 87 | "config": { 88 | "n_ffts": scales, 89 | "hop_lengths": hop_sizes, 90 | "win_lengths": win_lengths, 91 | "filters": 32 92 | }, 93 | "weights": { 94 | "adversarial": 0.1, 95 | "feature_matching": 5.0, 96 | } 97 | }, 98 | "spectral": { 99 | "type": "mrstft", 100 | "config": { 101 | "fft_sizes": scales, 102 | "hop_sizes": hop_sizes, 103 | "win_lengths": 
win_lengths, 104 | "perceptual_weighting": True 105 | }, 106 | "weights": { 107 | "mrstft": 1.0, 108 | } 109 | }, 110 | "time": { 111 | "type": "l1", 112 | "config": {}, 113 | "weights": { 114 | "l1": 0.0, 115 | } 116 | } 117 | } 118 | 119 | self.loss_config = loss_config 120 | 121 | # Spectral reconstruction loss 122 | 123 | stft_loss_args = loss_config['spectral']['config'] 124 | 125 | if self.autoencoder.out_channels == 2: 126 | self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) 127 | self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) 128 | else: 129 | self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) 130 | 131 | # Discriminator 132 | 133 | if loss_config['discriminator']['type'] == 'oobleck': 134 | self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) 135 | elif loss_config['discriminator']['type'] == 'encodec': 136 | self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config']) 137 | elif loss_config['discriminator']['type'] == 'dac': 138 | self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config']) 139 | 140 | self.gen_loss_modules = [] 141 | 142 | # Adversarial and feature matching losses 143 | self.gen_loss_modules += [ 144 | ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'), 145 | ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'), 146 | ] 147 | 148 | if self.teacher_model is not None: 149 | # Distillation losses 150 | 151 | stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25 152 | self.gen_loss_modules += [ 153 | AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss 154 | AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder 155 | AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder 156 | AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder 157 | ] 158 | 159 | else: 160 | 161 | # Reconstruction loss 162 | self.gen_loss_modules += [ 163 | AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), 164 | ] 165 | 166 | if self.autoencoder.out_channels == 2: 167 | 168 | # Add left and right channel reconstruction losses in addition to the sum and difference 169 | self.gen_loss_modules += [ 170 | AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2), 171 | AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2), 172 | ] 173 | 174 | self.gen_loss_modules += [ 175 | AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), 176 | ] 177 | 178 | if self.loss_config['time']['weights']['l1'] > 0.0: 179 
| self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss')) 180 | 181 | if self.autoencoder.bottleneck is not None: 182 | self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config) 183 | 184 | self.losses_gen = MultiLoss(self.gen_loss_modules) 185 | 186 | self.disc_loss_modules = [ 187 | ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), 188 | ] 189 | 190 | self.losses_disc = MultiLoss(self.disc_loss_modules) 191 | 192 | # Set up EMA for model weights 193 | self.autoencoder_ema = None 194 | 195 | self.use_ema = use_ema 196 | 197 | if self.use_ema: 198 | self.autoencoder_ema = EMA( 199 | self.autoencoder, 200 | ema_model=ema_copy, 201 | beta=0.9999, 202 | power=3/4, 203 | update_every=1, 204 | update_after_step=1 205 | ) 206 | 207 | self.latent_mask_ratio = latent_mask_ratio 208 | 209 | def configure_optimizers(self): 210 | 211 | opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters()) 212 | opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters()) 213 | 214 | if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']: 215 | sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen) 216 | sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc) 217 | return [opt_gen, opt_disc], [sched_gen, sched_disc] 218 | 219 | return [opt_gen, opt_disc] 220 | 221 | def training_step(self, batch, batch_idx): 222 | reals, _ = batch 223 | 224 | # Remove extra dimension added by WebDataset 225 | if reals.ndim == 4 and reals.shape[0] == 1: 226 | reals = reals[0] 227 | 228 | if self.global_step >= self.warmup_steps: 229 | self.warmed_up = True 230 | 231 | loss_info = {} 232 | 233 | loss_info["reals"] = reals 234 | 235 | encoder_input = reals 236 | 237 | if self.force_input_mono and encoder_input.shape[1] > 1: 238 | encoder_input = encoder_input.mean(dim=1, keepdim=True) 239 | 240 | loss_info["encoder_input"] = encoder_input 241 | 242 | data_std = encoder_input.std() 243 | 244 | if self.warmed_up and self.encoder_freeze_on_warmup: 245 | with torch.no_grad(): 246 | latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) 247 | else: 248 | latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) 249 | 250 | loss_info["latents"] = latents 251 | 252 | loss_info.update(encoder_info) 253 | 254 | # Encode with teacher model for distillation 255 | if self.teacher_model is not None: 256 | with torch.no_grad(): 257 | teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) 258 | loss_info['teacher_latents'] = teacher_latents 259 | 260 | if self.latent_mask_ratio > 0.0: 261 | mask = torch.rand_like(latents) < self.latent_mask_ratio 262 | latents = torch.where(mask, torch.zeros_like(latents), latents) 263 | 264 | decoded = self.autoencoder.decode(latents) 265 | 266 | loss_info["decoded"] = decoded 267 | 268 | if self.autoencoder.out_channels == 2: 269 | loss_info["decoded_left"] = decoded[:, 0:1, :] 270 | loss_info["decoded_right"] = decoded[:, 1:2, :] 271 | loss_info["reals_left"] = reals[:, 0:1, :] 272 | loss_info["reals_right"] = reals[:, 1:2, :] 273 | 274 | # Distillation 275 | if self.teacher_model is not None: 276 | with 
torch.no_grad(): 277 | teacher_decoded = self.teacher_model.decode(teacher_latents) 278 | own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher 279 | teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model 280 | 281 | loss_info['teacher_decoded'] = teacher_decoded 282 | loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded 283 | loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded 284 | 285 | 286 | if self.warmed_up: 287 | loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded) 288 | else: 289 | loss_dis = torch.tensor(0.).to(reals) 290 | loss_adv = torch.tensor(0.).to(reals) 291 | feature_matching_distance = torch.tensor(0.).to(reals) 292 | 293 | loss_info["loss_dis"] = loss_dis 294 | loss_info["loss_adv"] = loss_adv 295 | loss_info["feature_matching_distance"] = feature_matching_distance 296 | 297 | opt_gen, opt_disc = self.optimizers() 298 | 299 | lr_schedulers = self.lr_schedulers() 300 | 301 | sched_gen = None 302 | sched_disc = None 303 | 304 | if lr_schedulers is not None: 305 | sched_gen, sched_disc = lr_schedulers 306 | 307 | # Train the discriminator 308 | if self.global_step % 2 and self.warmed_up: 309 | loss, losses = self.losses_disc(loss_info) 310 | 311 | log_dict = { 312 | 'train/disc_lr': opt_disc.param_groups[0]['lr'] 313 | } 314 | 315 | opt_disc.zero_grad() 316 | self.manual_backward(loss) 317 | opt_disc.step() 318 | 319 | if sched_disc is not None: 320 | # sched step every step 321 | sched_disc.step() 322 | 323 | # Train the generator 324 | else: 325 | 326 | loss, losses = self.losses_gen(loss_info) 327 | 328 | if self.use_ema: 329 | self.autoencoder_ema.update() 330 | 331 | opt_gen.zero_grad() 332 | self.manual_backward(loss) 333 | opt_gen.step() 334 | 335 | if sched_gen is not None: 336 | # scheduler step every step 337 | sched_gen.step() 338 | 339 | log_dict = { 340 | 'train/loss': loss.detach(), 341 | 'train/latent_std': latents.std().detach(), 342 | 'train/data_std': data_std.detach(), 343 | 'train/gen_lr': opt_gen.param_groups[0]['lr'] 344 | } 345 | 346 | for loss_name, loss_value in losses.items(): 347 | log_dict[f'train/{loss_name}'] = loss_value.detach() 348 | 349 | self.log_dict(log_dict, prog_bar=True, on_step=True) 350 | 351 | return loss 352 | 353 | def export_model(self, path, use_safetensors=False): 354 | if self.autoencoder_ema is not None: 355 | model = self.autoencoder_ema.ema_model 356 | else: 357 | model = self.autoencoder 358 | 359 | if use_safetensors: 360 | save_model(model, path) 361 | else: 362 | torch.save({"state_dict": model.state_dict()}, path) 363 | 364 | 365 | class AutoencoderDemoCallback(pl.Callback): 366 | def __init__( 367 | self, 368 | demo_dl, 369 | demo_every=2000, 370 | sample_size=65536, 371 | sample_rate=48000 372 | ): 373 | super().__init__() 374 | self.demo_every = demo_every 375 | self.demo_samples = sample_size 376 | self.demo_dl = iter(demo_dl) 377 | self.sample_rate = sample_rate 378 | self.last_demo_step = -1 379 | 380 | @rank_zero_only 381 | @torch.no_grad() 382 | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): 383 | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: 384 | return 385 | 386 | self.last_demo_step = trainer.global_step 387 | 388 | module.eval() 389 | 390 | try: 391 | demo_reals, _ = next(self.demo_dl) 392 | 393 | # Remove extra dimension added by 
WebDataset 394 | if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: 395 | demo_reals = demo_reals[0] 396 | 397 | encoder_input = demo_reals 398 | 399 | encoder_input = encoder_input.to(module.device) 400 | 401 | if module.force_input_mono: 402 | encoder_input = encoder_input.mean(dim=1, keepdim=True) 403 | 404 | demo_reals = demo_reals.to(module.device) 405 | 406 | with torch.no_grad(): 407 | if module.use_ema: 408 | 409 | latents = module.autoencoder_ema.ema_model.encode(encoder_input) 410 | 411 | fakes = module.autoencoder_ema.ema_model.decode(latents) 412 | else: 413 | latents = module.autoencoder.encode(encoder_input) 414 | 415 | fakes = module.autoencoder.decode(latents) 416 | 417 | #Interleave reals and fakes 418 | reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') 419 | 420 | # Put the demos together 421 | reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') 422 | 423 | log_dict = {} 424 | 425 | filename = f'recon_{trainer.global_step:08}.wav' 426 | reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() 427 | torchaudio.save(filename, reals_fakes, self.sample_rate) 428 | 429 | log_dict[f'recon'] = wandb.Audio(filename, 430 | sample_rate=self.sample_rate, 431 | caption=f'Reconstructed') 432 | 433 | log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) 434 | log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) 435 | 436 | log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) 437 | 438 | trainer.logger.experiment.log(log_dict) 439 | except Exception as e: 440 | print(f'{type(e).__name__}: {e}') 441 | raise e 442 | finally: 443 | module.train() 444 | 445 | def create_loss_modules_from_bottleneck(bottleneck, loss_config): 446 | losses = [] 447 | 448 | if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): 449 | try: 450 | kl_weight = loss_config['bottleneck']['weights']['kl'] 451 | except: 452 | kl_weight = 1e-6 453 | 454 | kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') 455 | losses.append(kl_loss) 456 | 457 | if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): 458 | quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') 459 | losses.append(quantizer_loss) 460 | 461 | if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): 462 | codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') 463 | commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') 464 | losses.append(codebook_loss) 465 | losses.append(commitment_loss) 466 | 467 | if isinstance(bottleneck, WassersteinBottleneck): 468 | try: 469 | mmd_weight = loss_config['bottleneck']['weights']['mmd'] 470 | except: 471 | mmd_weight = 100 472 | 473 | mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') 474 | losses.append(mmd_loss) 475 | 476 | return losses -------------------------------------------------------------------------------- /stable_audio_tools/training/factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from ..models.factory import create_model_from_config 4 | 5 | def create_training_wrapper_from_config(model_config, model): 6 | model_type = model_config.get('model_type', None) 7 | assert model_type is not None, 'model_type must be specified in model 
config' 8 | 9 | training_config = model_config.get('training', None) 10 | assert training_config is not None, 'training config must be specified in model config' 11 | 12 | if model_type == 'autoencoder': 13 | from .autoencoders import AutoencoderTrainingWrapper 14 | 15 | ema_copy = None 16 | 17 | if training_config.get("use_ema", False): 18 | ema_copy = create_model_from_config(model_config) 19 | ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once 20 | # Copy each weight to the ema copy 21 | for name, param in model.state_dict().items(): 22 | if isinstance(param, Parameter): 23 | # backwards compatibility for serialized parameters 24 | param = param.data 25 | ema_copy.state_dict()[name].copy_(param) 26 | 27 | use_ema = training_config.get("use_ema", False) 28 | 29 | latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0) 30 | 31 | teacher_model = training_config.get("teacher_model", None) 32 | if teacher_model is not None: 33 | teacher_model = create_model_from_config(teacher_model) 34 | teacher_model = teacher_model.eval().requires_grad_(False) 35 | 36 | teacher_model_ckpt = training_config.get("teacher_model_ckpt", None) 37 | if teacher_model_ckpt is not None: 38 | teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"]) 39 | else: 40 | raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") 41 | 42 | return AutoencoderTrainingWrapper( 43 | model, 44 | lr=training_config["learning_rate"], 45 | warmup_steps=training_config.get("warmup_steps", 0), 46 | encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), 47 | sample_rate=model_config["sample_rate"], 48 | loss_config=training_config.get("loss_configs", None), 49 | optimizer_configs=training_config.get("optimizer_configs", None), 50 | use_ema=use_ema, 51 | ema_copy=ema_copy if use_ema else None, 52 | force_input_mono=training_config.get("force_input_mono", False), 53 | latent_mask_ratio=latent_mask_ratio, 54 | teacher_model=teacher_model 55 | ) 56 | elif model_type == 'diffusion_uncond': 57 | from .diffusion import DiffusionUncondTrainingWrapper 58 | return DiffusionUncondTrainingWrapper( 59 | model, 60 | lr=training_config["learning_rate"], 61 | pre_encoded=training_config.get("pre_encoded", False), 62 | ) 63 | elif model_type == 'diffusion_cond': 64 | from .diffusion import DiffusionCondTrainingWrapper 65 | return DiffusionCondTrainingWrapper( 66 | model, 67 | lr=training_config.get("learning_rate", None), 68 | mask_padding=training_config.get("mask_padding", False), 69 | mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0), 70 | use_ema = training_config.get("use_ema", True), 71 | log_loss_info=training_config.get("log_loss_info", False), 72 | optimizer_configs=training_config.get("optimizer_configs", None), 73 | pre_encoded=training_config.get("pre_encoded", False), 74 | cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), 75 | timestep_sampler = training_config.get("timestep_sampler", "uniform") 76 | ) 77 | elif model_type == 'diffusion_prior': 78 | from .diffusion import DiffusionPriorTrainingWrapper 79 | from ..models.diffusion_prior import PriorType 80 | 81 | ema_copy = create_model_from_config(model_config) 82 | 83 | # Copy each weight to the ema copy 84 | for name, param in model.state_dict().items(): 85 | if isinstance(param, Parameter): 86 | # backwards compatibility for serialized parameters 87 | param = param.data 88 | 
ema_copy.state_dict()[name].copy_(param) 89 | 90 | prior_type = training_config.get("prior_type", "mono_stereo") 91 | 92 | if prior_type == "mono_stereo": 93 | prior_type_enum = PriorType.MonoToStereo 94 | else: 95 | raise ValueError(f"Unknown prior type: {prior_type}") 96 | 97 | return DiffusionPriorTrainingWrapper( 98 | model, 99 | lr=training_config["learning_rate"], 100 | ema_copy=ema_copy, 101 | prior_type=prior_type_enum, 102 | log_loss_info=training_config.get("log_loss_info", False), 103 | use_reconstruction_loss=training_config.get("use_reconstruction_loss", False), 104 | ) 105 | elif model_type == 'diffusion_cond_inpaint': 106 | from .diffusion import DiffusionCondInpaintTrainingWrapper 107 | return DiffusionCondInpaintTrainingWrapper( 108 | model, 109 | lr=training_config.get("learning_rate", None), 110 | max_mask_segments = training_config.get("max_mask_segments", 10), 111 | log_loss_info=training_config.get("log_loss_info", False), 112 | optimizer_configs=training_config.get("optimizer_configs", None), 113 | use_ema=training_config.get("use_ema", True), 114 | pre_encoded=training_config.get("pre_encoded", False), 115 | cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), 116 | timestep_sampler = training_config.get("timestep_sampler", "uniform") 117 | ) 118 | elif model_type == 'diffusion_autoencoder': 119 | from .diffusion import DiffusionAutoencoderTrainingWrapper 120 | 121 | ema_copy = create_model_from_config(model_config) 122 | 123 | # Copy each weight to the ema copy 124 | for name, param in model.state_dict().items(): 125 | if isinstance(param, Parameter): 126 | # backwards compatibility for serialized parameters 127 | param = param.data 128 | ema_copy.state_dict()[name].copy_(param) 129 | 130 | return DiffusionAutoencoderTrainingWrapper( 131 | model, 132 | ema_copy=ema_copy, 133 | lr=training_config["learning_rate"], 134 | use_reconstruction_loss=training_config.get("use_reconstruction_loss", False) 135 | ) 136 | elif model_type == 'lm': 137 | from .lm import AudioLanguageModelTrainingWrapper 138 | 139 | ema_copy = create_model_from_config(model_config) 140 | 141 | for name, param in model.state_dict().items(): 142 | if isinstance(param, Parameter): 143 | # backwards compatibility for serialized parameters 144 | param = param.data 145 | ema_copy.state_dict()[name].copy_(param) 146 | 147 | return AudioLanguageModelTrainingWrapper( 148 | model, 149 | ema_copy=ema_copy, 150 | lr=training_config.get("learning_rate", None), 151 | use_ema=training_config.get("use_ema", False), 152 | optimizer_configs=training_config.get("optimizer_configs", None), 153 | pre_encoded=training_config.get("pre_encoded", False), 154 | ) 155 | 156 | else: 157 | raise NotImplementedError(f'Unknown model type: {model_type}') 158 | 159 | def create_demo_callback_from_config(model_config, **kwargs): 160 | model_type = model_config.get('model_type', None) 161 | assert model_type is not None, 'model_type must be specified in model config' 162 | 163 | training_config = model_config.get('training', None) 164 | assert training_config is not None, 'training config must be specified in model config' 165 | 166 | demo_config = training_config.get("demo", {}) 167 | 168 | if model_type == 'autoencoder': 169 | from .autoencoders import AutoencoderDemoCallback 170 | return AutoencoderDemoCallback( 171 | demo_every=demo_config.get("demo_every", 2000), 172 | sample_size=model_config["sample_size"], 173 | sample_rate=model_config["sample_rate"], 174 | **kwargs 175 | ) 176 | elif model_type == 
'diffusion_uncond': 177 | from .diffusion import DiffusionUncondDemoCallback 178 | return DiffusionUncondDemoCallback( 179 | demo_every=demo_config.get("demo_every", 2000), 180 | demo_steps=demo_config.get("demo_steps", 250), 181 | sample_rate=model_config["sample_rate"] 182 | ) 183 | elif model_type == "diffusion_autoencoder": 184 | from .diffusion import DiffusionAutoencoderDemoCallback 185 | return DiffusionAutoencoderDemoCallback( 186 | demo_every=demo_config.get("demo_every", 2000), 187 | demo_steps=demo_config.get("demo_steps", 250), 188 | sample_size=model_config["sample_size"], 189 | sample_rate=model_config["sample_rate"], 190 | **kwargs 191 | ) 192 | elif model_type == "diffusion_prior": 193 | from .diffusion import DiffusionPriorDemoCallback 194 | return DiffusionPriorDemoCallback( 195 | demo_every=demo_config.get("demo_every", 2000), 196 | demo_steps=demo_config.get("demo_steps", 250), 197 | sample_size=model_config["sample_size"], 198 | sample_rate=model_config["sample_rate"], 199 | **kwargs 200 | ) 201 | elif model_type == "diffusion_cond": 202 | from .diffusion import DiffusionCondDemoCallback 203 | 204 | return DiffusionCondDemoCallback( 205 | demo_every=demo_config.get("demo_every", 2000), 206 | sample_size=model_config["sample_size"], 207 | sample_rate=model_config["sample_rate"], 208 | demo_steps=demo_config.get("demo_steps", 250), 209 | num_demos=demo_config["num_demos"], 210 | demo_cfg_scales=demo_config["demo_cfg_scales"], 211 | demo_conditioning=demo_config.get("demo_cond", {}), 212 | demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False), 213 | display_audio_cond=demo_config.get("display_audio_cond", False), 214 | ) 215 | elif model_type == "diffusion_cond_inpaint": 216 | from .diffusion import DiffusionCondInpaintDemoCallback 217 | 218 | return DiffusionCondInpaintDemoCallback( 219 | demo_every=demo_config.get("demo_every", 2000), 220 | sample_size=model_config["sample_size"], 221 | sample_rate=model_config["sample_rate"], 222 | demo_steps=demo_config.get("demo_steps", 250), 223 | demo_cfg_scales=demo_config["demo_cfg_scales"], 224 | **kwargs 225 | ) 226 | 227 | elif model_type == "lm": 228 | from .lm import AudioLanguageModelDemoCallback 229 | 230 | return AudioLanguageModelDemoCallback( 231 | demo_every=demo_config.get("demo_every", 2000), 232 | sample_size=model_config["sample_size"], 233 | sample_rate=model_config["sample_rate"], 234 | demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]), 235 | demo_conditioning=demo_config.get("demo_cond", None), 236 | num_demos=demo_config.get("num_demos", 8), 237 | **kwargs 238 | ) 239 | else: 240 | raise NotImplementedError(f'Unknown model type: {model_type}') -------------------------------------------------------------------------------- /stable_audio_tools/training/lm.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import sys, gc 3 | import random 4 | import torch 5 | import torchaudio 6 | import typing as tp 7 | import wandb 8 | 9 | from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image 10 | from ema_pytorch import EMA 11 | from einops import rearrange 12 | from safetensors.torch import save_file 13 | from torch import optim 14 | from torch.nn import functional as F 15 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 16 | 17 | from ..models.lm import AudioLanguageModelWrapper 18 | from .utils import create_optimizer_from_config, create_scheduler_from_config 19 | 20 | class 
AudioLanguageModelTrainingWrapper(pl.LightningModule): 21 | def __init__( 22 | self, 23 | model: AudioLanguageModelWrapper, 24 | lr = 1e-4, 25 | use_ema=False, 26 | ema_copy=None, 27 | optimizer_configs: dict = None, 28 | pre_encoded=False 29 | ): 30 | super().__init__() 31 | 32 | self.model = model 33 | 34 | self.model.pretransform.requires_grad_(False) 35 | 36 | self.model_ema = None 37 | if use_ema: 38 | self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) 39 | 40 | assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" 41 | 42 | if optimizer_configs is None: 43 | optimizer_configs = { 44 | "lm": { 45 | "optimizer": { 46 | "type": "AdamW", 47 | "config": { 48 | "lr": lr, 49 | "betas": (0.9, 0.95), 50 | "weight_decay": 0.1 51 | } 52 | } 53 | } 54 | } 55 | else: 56 | if lr is not None: 57 | print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") 58 | 59 | self.optimizer_configs = optimizer_configs 60 | 61 | self.pre_encoded = pre_encoded 62 | 63 | def configure_optimizers(self): 64 | lm_opt_config = self.optimizer_configs['lm'] 65 | opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) 66 | 67 | if "scheduler" in lm_opt_config: 68 | sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) 69 | sched_lm_config = { 70 | "scheduler": sched_lm, 71 | "interval": "step" 72 | } 73 | return [opt_lm], [sched_lm_config] 74 | 75 | return [opt_lm] 76 | 77 | # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license 78 | # License can be found in LICENSES/LICENSE_META.txt 79 | 80 | def _compute_cross_entropy( 81 | self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor 82 | ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: 83 | """Compute cross entropy between multi-codebook targets and model's logits. 84 | The cross entropy is computed per codebook to provide codebook-level cross entropy. 85 | Valid timesteps for each of the codebook are pulled from the mask, where invalid 86 | timesteps are set to 0. 87 | 88 | Args: 89 | logits (torch.Tensor): Model's logits of shape [B, K, T, card]. 90 | targets (torch.Tensor): Target codes, of shape [B, K, T]. 91 | mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. 92 | Returns: 93 | ce (torch.Tensor): Cross entropy averaged over the codebooks 94 | ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). 
95 | """ 96 | B, K, T = targets.shape 97 | assert logits.shape[:-1] == targets.shape 98 | assert mask.shape == targets.shape 99 | ce = torch.zeros([], device=targets.device) 100 | ce_per_codebook: tp.List[torch.Tensor] = [] 101 | for k in range(K): 102 | logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] 103 | targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] 104 | mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] 105 | ce_targets = targets_k[mask_k] 106 | ce_logits = logits_k[mask_k] 107 | q_ce = F.cross_entropy(ce_logits, ce_targets) 108 | ce += q_ce 109 | ce_per_codebook.append(q_ce.detach()) 110 | # average cross entropy across codebooks 111 | ce = ce / K 112 | return ce, ce_per_codebook 113 | 114 | def training_step(self, batch, batch_idx): 115 | reals, metadata = batch 116 | 117 | if reals.ndim == 4 and reals.shape[0] == 1: 118 | reals = reals[0] 119 | 120 | if not self.pre_encoded: 121 | codes = self.model.pretransform.tokenize(reals) 122 | else: 123 | codes = reals 124 | 125 | padding_masks = [] 126 | for md in metadata: 127 | if md["padding_mask"].ndim == 1: 128 | padding_masks.append(md["padding_mask"]) 129 | else: 130 | padding_masks.append(md["padding_mask"][0]) 131 | 132 | padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) 133 | 134 | # Interpolate padding masks to the same length as the codes 135 | padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool() 136 | 137 | condition_tensors = None 138 | 139 | # If the model is conditioned, get the conditioning tensors 140 | if self.model.conditioner is not None: 141 | condition_tensors = self.model.conditioner(metadata, self.device) 142 | 143 | lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) 144 | 145 | logits = lm_output.logits # [b, k, t, c] 146 | logits_mask = lm_output.mask # [b, k, t] 147 | 148 | logits_mask = logits_mask & padding_masks 149 | 150 | cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask) 151 | 152 | loss = cross_entropy 153 | 154 | log_dict = { 155 | 'train/loss': loss.detach(), 156 | 'train/cross_entropy': cross_entropy.detach(), 157 | 'train/perplexity': torch.exp(cross_entropy).detach(), 158 | 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] 159 | } 160 | 161 | for k, ce_q in enumerate(cross_entropy_per_codebook): 162 | log_dict[f'cross_entropy_q{k + 1}'] = ce_q 163 | log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) 164 | 165 | self.log_dict(log_dict, prog_bar=True, on_step=True) 166 | return loss 167 | 168 | def on_before_zero_grad(self, *args, **kwargs): 169 | if self.model_ema is not None: 170 | self.model_ema.update() 171 | 172 | def export_model(self, path, use_safetensors=False): 173 | 174 | model = self.model_ema.ema_model if self.model_ema is not None else self.model 175 | 176 | if use_safetensors: 177 | save_file(model.state_dict(), path) 178 | else: 179 | torch.save({"state_dict": model.state_dict()}, path) 180 | 181 | 182 | class AudioLanguageModelDemoCallback(pl.Callback): 183 | def __init__(self, 184 | demo_every=2000, 185 | num_demos=8, 186 | sample_size=65536, 187 | sample_rate=48000, 188 | demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, 189 | demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], 190 | **kwargs 191 | ): 192 | super().__init__() 193 | 194 | self.demo_every = demo_every 195 | self.num_demos = 
num_demos 196 | self.demo_samples = sample_size 197 | self.sample_rate = sample_rate 198 | self.last_demo_step = -1 199 | self.demo_conditioning = demo_conditioning 200 | self.demo_cfg_scales = demo_cfg_scales 201 | 202 | @rank_zero_only 203 | @torch.no_grad() 204 | def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): 205 | 206 | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: 207 | return 208 | 209 | module.eval() 210 | 211 | print(f"Generating demo") 212 | self.last_demo_step = trainer.global_step 213 | 214 | demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio 215 | 216 | #demo_reals = batch[0][:self.num_demos] 217 | 218 | # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: 219 | # demo_reals = demo_reals[0] 220 | 221 | #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) 222 | 223 | ##Limit to first 50 tokens 224 | #demo_reals_tokens = demo_reals_tokens[:, :, :50] 225 | 226 | try: 227 | print("Getting conditioning") 228 | 229 | for cfg_scale in self.demo_cfg_scales: 230 | 231 | model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model 232 | 233 | print(f"Generating demo for cfg scale {cfg_scale}") 234 | fakes = model.generate_audio( 235 | batch_size=self.num_demos, 236 | max_gen_len=demo_length_tokens, 237 | conditioning=self.demo_conditioning, 238 | #init_data = demo_reals_tokens, 239 | cfg_scale=cfg_scale, 240 | temp=1.0, 241 | top_p=0.95 242 | ) 243 | 244 | # Put the demos together 245 | fakes = rearrange(fakes, 'b d n -> d (b n)') 246 | 247 | log_dict = {} 248 | 249 | filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' 250 | fakes = fakes / fakes.abs().max() 251 | fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() 252 | torchaudio.save(filename, fakes, self.sample_rate) 253 | 254 | log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, 255 | sample_rate=self.sample_rate, 256 | caption=f'Reconstructed') 257 | 258 | log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) 259 | 260 | trainer.logger.experiment.log(log_dict) 261 | 262 | except Exception as e: 263 | raise e 264 | finally: 265 | gc.collect() 266 | torch.cuda.empty_cache() 267 | module.train() -------------------------------------------------------------------------------- /stable_audio_tools/training/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * -------------------------------------------------------------------------------- /stable_audio_tools/training/losses/losses.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | from torch.nn import functional as F 4 | from torch import nn 5 | import torch 6 | class LossModule(nn.Module): 7 | def __init__(self, name: str, weight: float = 1.0): 8 | super().__init__() 9 | 10 | self.name = name 11 | self.weight = weight 12 | 13 | def forward(self, info, *args, **kwargs): 14 | raise NotImplementedError 15 | 16 | class ValueLoss(LossModule): 17 | def __init__(self, key: str, name, weight: float = 1.0): 18 | super().__init__(name=name, weight=weight) 19 | 20 | self.key = key 21 | 22 | def forward(self, info): 23 | return self.weight * info[self.key] 24 | 25 | class L1Loss(LossModule): 26 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str 
= None, name: str = 'l1_loss'): 27 | super().__init__(name=name, weight=weight) 28 | 29 | self.key_a = key_a 30 | self.key_b = key_b 31 | 32 | self.mask_key = mask_key 33 | 34 | def forward(self, info): 35 | l1_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none') 36 | 37 | if self.mask_key is not None and self.mask_key in info: 38 | l1_loss = l1_loss[info[self.mask_key]] 39 | 40 | l1_loss = l1_loss.mean() 41 | 42 | return self.weight * l1_loss 43 | 44 | class MSELoss(LossModule): 45 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'): 46 | super().__init__(name=name, weight=weight) 47 | 48 | self.key_a = key_a 49 | self.key_b = key_b 50 | 51 | self.mask_key = mask_key 52 | 53 | def forward(self, info): 54 | mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none') 55 | 56 | if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None: 57 | mask = info[self.mask_key] 58 | 59 | if mask.ndim == 2 and mse_loss.ndim == 3: 60 | mask = mask.unsqueeze(1) 61 | 62 | if mask.shape[1] != mse_loss.shape[1]: 63 | mask = mask.repeat(1, mse_loss.shape[1], 1) 64 | 65 | mse_loss = mse_loss[mask] 66 | 67 | mse_loss = mse_loss.mean() 68 | 69 | return self.weight * mse_loss 70 | 71 | class AuralossLoss(LossModule): 72 | def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1): 73 | super().__init__(name, weight) 74 | 75 | self.auraloss_module = auraloss_module 76 | 77 | self.input_key = input_key 78 | self.target_key = target_key 79 | 80 | def forward(self, info): 81 | loss = self.auraloss_module(info[self.input_key], info[self.target_key]) 82 | 83 | return self.weight * loss 84 | 85 | class MultiLoss(nn.Module): 86 | def __init__(self, losses: tp.List[LossModule]): 87 | super().__init__() 88 | 89 | self.losses = nn.ModuleList(losses) 90 | 91 | def forward(self, info): 92 | total_loss = 0 93 | 94 | losses = {} 95 | 96 | for loss_module in self.losses: 97 | module_loss = loss_module(info) 98 | total_loss += module_loss 99 | losses[loss_module.name] = module_loss 100 | 101 | return total_loss, losses -------------------------------------------------------------------------------- /stable_audio_tools/training/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | def get_rank(): 5 | """Get rank of current process.""" 6 | 7 | 8 | 9 | if "SLURM_PROCID" in os.environ: 10 | return int(os.environ["SLURM_PROCID"]) 11 | 12 | if not torch.distributed.is_available() or not torch.distributed.is_initialized(): 13 | return 0 14 | 15 | return torch.distributed.get_rank() 16 | 17 | class InverseLR(torch.optim.lr_scheduler._LRScheduler): 18 | """Implements an inverse decay learning rate schedule with an optional exponential 19 | warmup. When last_epoch=-1, sets initial lr as lr. 20 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 21 | (1 / 2)**power of its original value. 22 | Args: 23 | optimizer (Optimizer): Wrapped optimizer. 24 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 25 | power (float): Exponential factor of learning rate decay. Default: 1. 26 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 27 | Default: 0. 28 | final_lr (float): The final learning rate. Default: 0. 29 | last_epoch (int): The index of last epoch. Default: -1.
30 | verbose (bool): If ``True``, prints a message to stdout for 31 | each update. Default: ``False``. 32 | """ 33 | 34 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., 35 | last_epoch=-1, verbose=False): 36 | self.inv_gamma = inv_gamma 37 | self.power = power 38 | if not 0. <= warmup < 1: 39 | raise ValueError('Invalid value for warmup') 40 | self.warmup = warmup 41 | self.final_lr = final_lr 42 | super().__init__(optimizer, last_epoch, verbose) 43 | 44 | def get_lr(self): 45 | if not self._get_lr_called_within_step: 46 | import warnings 47 | warnings.warn("To get the last learning rate computed by the scheduler, " 48 | "please use `get_last_lr()`.") 49 | 50 | return self._get_closed_form_lr() 51 | 52 | def _get_closed_form_lr(self): 53 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 54 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 55 | return [warmup * max(self.final_lr, base_lr * lr_mult) 56 | for base_lr in self.base_lrs] 57 | 58 | def copy_state_dict(model, state_dict): 59 | """Load state_dict to model, but only for keys that match exactly. 60 | 61 | Args: 62 | model (nn.Module): model to load state_dict. 63 | state_dict (OrderedDict): state_dict to load. 64 | """ 65 | model_state_dict = model.state_dict() 66 | for key in state_dict: 67 | if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape: 68 | if isinstance(state_dict[key], torch.nn.Parameter): 69 | # backwards compatibility for serialized parameters 70 | state_dict[key] = state_dict[key].data 71 | model_state_dict[key] = state_dict[key] 72 | 73 | model.load_state_dict(model_state_dict, strict=False) 74 | 75 | def create_optimizer_from_config(optimizer_config, parameters): 76 | """Create optimizer from config. 77 | 78 | Args: 79 | parameters (iterable): parameters to optimize. 80 | optimizer_config (dict): optimizer config. 81 | 82 | Returns: 83 | torch.optim.Optimizer: optimizer. 84 | """ 85 | 86 | optimizer_type = optimizer_config["type"] 87 | 88 | if optimizer_type == "FusedAdam": 89 | from deepspeed.ops.adam import FusedAdam 90 | optimizer = FusedAdam(parameters, **optimizer_config["config"]) 91 | else: 92 | optimizer_fn = getattr(torch.optim, optimizer_type) 93 | optimizer = optimizer_fn(parameters, **optimizer_config["config"]) 94 | return optimizer 95 | 96 | def create_scheduler_from_config(scheduler_config, optimizer): 97 | """Create scheduler from config. 98 | 99 | Args: 100 | scheduler_config (dict): scheduler config. 101 | optimizer (torch.optim.Optimizer): optimizer. 102 | 103 | Returns: 104 | torch.optim.lr_scheduler._LRScheduler: scheduler. 105 | """ 106 | if scheduler_config["type"] == "InverseLR": 107 | scheduler_fn = InverseLR 108 | else: 109 | scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) 110 | scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) 111 | return scheduler --------------------------------------------------------------------------------
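
A minimal sketch of how the helpers in training/utils.py above are intended to be wired together, assuming a config dict shaped like the "optimizer"/"scheduler" blocks that the training wrappers pass to create_optimizer_from_config and create_scheduler_from_config. The numeric values and the stand-in model below are illustrative placeholders, not defaults taken from this repository.

import torch
from stable_audio_tools.training.utils import create_optimizer_from_config, create_scheduler_from_config

# Hypothetical config following the {"type": ..., "config": {...}} shape expected above.
opt_config = {"type": "AdamW", "config": {"lr": 1e-4, "betas": (0.9, 0.95), "weight_decay": 1e-3}}
sched_config = {"type": "InverseLR", "config": {"inv_gamma": 1_000_000, "power": 0.5, "warmup": 0.99}}

model = torch.nn.Linear(16, 16)  # stand-in for an actual model
optimizer = create_optimizer_from_config(opt_config, model.parameters())  # resolves torch.optim.AdamW
scheduler = create_scheduler_from_config(sched_config, optimizer)         # resolves the custom InverseLR

for step in range(3):
    optimizer.step()   # in real training this follows loss.backward()
    scheduler.step()   # warmup factor 1 - warmup**(step + 1), then decay (1 + step / inv_gamma)**-power

With warmup=0.99 the learning rate ramps up over roughly the first few hundred steps before the inverse-power decay takes over, matching the closed-form expression in InverseLR._get_closed_form_lr.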