├── ETTA ├── stable_audio_tools │ ├── data │ │ ├── __init__.py │ │ └── utils.py │ ├── inference │ │ ├── __init__.py │ │ └── utils.py │ ├── interface │ │ └── __init__.py │ ├── training │ │ ├── losses │ │ │ ├── __init__.py │ │ │ └── losses.py │ │ ├── __init__.py │ │ └── utils.py │ ├── models │ │ ├── __init__.py │ │ ├── pretrained.py │ │ ├── diffusion_prior.py │ │ ├── wavelets.py │ │ ├── utils.py │ │ ├── lm_backbone.py │ │ └── factory.py │ ├── __init__.py │ ├── configs │ │ ├── dataset_configs │ │ │ ├── etta_vae_training_example.json │ │ │ ├── custom_metadata │ │ │ │ ├── location_caption_pair.py │ │ │ │ └── custom_md_example.py │ │ │ └── etta_dit_training_example.json │ │ └── model_configs │ │ │ ├── autoencoders │ │ │ └── etta_vae.json │ │ │ └── txt2audio │ │ │ └── etta_dit.json │ └── utils │ │ └── addict.py ├── pyproject.toml ├── ETTA_main_results.png ├── AFSynthetic │ ├── AFSynthetic_overview.png │ └── README.md ├── run_gradio.sh ├── examples │ ├── location_caption_pair_valid_example.json │ └── location_caption_pair_train_example.json ├── scripts │ └── ds_zero_to_pl_ckpt.py ├── LICENSE ├── LICENSES │ ├── LICENSE_ADP.txt │ ├── LICENSE_DESCRIPT.txt │ ├── LICENSE_XTRANSFORMERS.txt │ ├── LICENSE_NVIDIA.txt │ ├── LICENSE_STABILITYAI.txt │ ├── LICENSE_ADDICT.txt │ └── LICENSE_META.txt ├── run_gradio.py ├── Dockerfile ├── setup.py ├── docs │ ├── pretransforms.md │ ├── datasets.md │ ├── diffusion.md │ └── conditioning.md ├── README.md └── unwrap_model.py ├── .gitattributes ├── UALM └── README.md ├── AudioFlamingo3 └── static │ ├── af1_radial.png │ ├── af3_sota.png │ ├── logo-no-bg.png │ ├── af3_radial-1.png │ └── af3_main_diagram-1.png ├── A2SB ├── corruption │ ├── __init__.py │ └── corruptions.py ├── audio_transforms │ └── __init__.py ├── DockerFile ├── configs │ ├── inference_files_upsampling.yaml │ ├── inference_files_inpainting.yaml │ ├── ensemble_2split_sampling.yaml │ ├── pretrain.yaml │ └── t_finetune_2split_0.0_0.5.yaml ├── SECURITY.md ├── plotting_utils.py ├── main.py ├── ensembled_inference.py ├── ensembled_inference_api.py ├── utils.py ├── CODE_OF_CONDUCT.md ├── audio_utils.py ├── inference │ ├── A2SB_upsample_api.py │ ├── A2SB_upsample_dataset.py │ └── A2SB_inpaint_dataset.py ├── LICENSE ├── modelcard.md ├── README.md ├── diffusion.py └── datasets │ └── datamodule.py └── .gitignore /ETTA/stable_audio_tools/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/interface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/training/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * -------------------------------------------------------------------------------- /ETTA/pyproject.toml: -------------------------------------------------------------------------------- 1 | 
[build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /ETTA/ETTA_main_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/audio-intelligence/main/ETTA/ETTA_main_results.png -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import create_model_from_config, create_model_from_config_path -------------------------------------------------------------------------------- /UALM/README.md: -------------------------------------------------------------------------------- 1 | # UALM: Unified Audio Language Model for Understanding, Generation and Reasoning 2 | 3 | Coming soon. -------------------------------------------------------------------------------- /AudioFlamingo3/static/af1_radial.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/audio-intelligence/main/AudioFlamingo3/static/af1_radial.png -------------------------------------------------------------------------------- /AudioFlamingo3/static/af3_sota.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/audio-intelligence/main/AudioFlamingo3/static/af3_sota.png -------------------------------------------------------------------------------- /AudioFlamingo3/static/logo-no-bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/audio-intelligence/main/AudioFlamingo3/static/logo-no-bg.png -------------------------------------------------------------------------------- /AudioFlamingo3/static/af3_radial-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/audio-intelligence/main/AudioFlamingo3/static/af3_radial-1.png -------------------------------------------------------------------------------- /ETTA/AFSynthetic/AFSynthetic_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/audio-intelligence/main/ETTA/AFSynthetic/AFSynthetic_overview.png -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import create_training_wrapper_from_config, create_demo_callback_from_config 2 | -------------------------------------------------------------------------------- /AudioFlamingo3/static/af3_main_diagram-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/audio-intelligence/main/AudioFlamingo3/static/af3_main_diagram-1.png -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.factory import create_model_from_config, create_model_from_config_path 2 | from .models.pretrained import get_pretrained_model -------------------------------------------------------------------------------- /ETTA/run_gradio.sh: 
-------------------------------------------------------------------------------- 1 | EXP_ROOT=/path/to/your/exp/root && \ 2 | MODEL_NAME=etta && \ 3 | CKPT_NAME=model_unwrap.ckpt && \ 4 | python run_gradio.py \ 5 | --ckpt-path $EXP_ROOT/$MODEL_NAME/$CKPT_NAME \ 6 | --share -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/configs/dataset_configs/etta_vae_training_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "audio_dir", 3 | "random_crop": true, 4 | "datasets": [ 5 | { 6 | "id": "name_of_audio_dataset", 7 | "path": "/path/to/name_of_audio_dataset/audio_dir" 8 | } 9 | ] 10 | } -------------------------------------------------------------------------------- /ETTA/examples/location_caption_pair_valid_example.json: -------------------------------------------------------------------------------- 1 | {"dataset": "name_of_valid_dataset", "location": "/path/to/audio_3.wav", "start": 0.0, "end": 10.0, "captions": "caption for audio_3.wav, 0 to 10 seconds"} 2 | {"dataset": "name_of_valid_dataset", "location": "/path/to/audio_4.wav", "start": 0.0, "end": 10.0, "captions": "caption for audio_4.wav, 0 to 10 seconds"} -------------------------------------------------------------------------------- /A2SB/corruption/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | 9 | -------------------------------------------------------------------------------- /A2SB/audio_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | 9 | -------------------------------------------------------------------------------- /A2SB/DockerFile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.2.2-cuda12.1-cudnn8-runtime 2 | RUN apt-get update --fix-missing && apt-get install -y \ 3 | curl wget\ 4 | ca-certificates \ 5 | git \ 6 | bzip2 \ 7 | libx11-6 \ 8 | && rm -rf /var/lib/apt/lists/* 9 | RUN pip3 install inflect rotary-embedding-torch moviepy lightning tensorboard matplotlib scipy librosa gradio jsonargparse[signatures] einops lmdb h5py torch-audiomentations torchaudio 10 | RUN conda install ffmpeg 11 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/configs/dataset_configs/custom_metadata/location_caption_pair.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 
3 | # modified from stable-audio-tools under the MIT license 4 | 5 | def get_custom_metadata(dataset, audio, info): 6 | # Use 'captions' metadata to prompt without modification 7 | info.update({ 8 | "prompt": info['metadata']['captions'], 9 | "prompt_global": info['metadata']['captions'] 10 | }) 11 | 12 | return audio, info -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # modified from stable-audio-tools under the MIT license 4 | 5 | def get_custom_metadata(dataset, audio, info): 6 | # Use relative path as the prompt. 7 | # NOTE: this is mostly for debugging purposes unless relpath contains valid metadata to be used as text prompt. 8 | info.update({ 9 | "prompt": info["relpath"], 10 | }) 11 | 12 | return audio, info -------------------------------------------------------------------------------- /ETTA/examples/location_caption_pair_train_example.json: -------------------------------------------------------------------------------- 1 | {"dataset": "name_of_train_dataset", "location": "/path/to/audio_1.wav", "start": 0.0, "end": 10.0, "captions": "caption for audio_1.wav, 0 to 10 seconds"} 2 | {"dataset": "name_of_train_dataset", "location": "/path/to/audio_1.wav", "start": 10.0, "end": 20.0, "captions": "caption for audio_1.wav, 10 to 20 seconds"} 3 | {"dataset": "name_of_train_dataset", "location": "/path/to/audio_2.wav", "start": 0.0, "end": 10.0, "captions": "caption for audio_2.wav, 0 to 10 seconds"} 4 | {"dataset": "name_of_train_dataset", "location": "/path/to/audio_2.wav", "start": 10.0, "end": 20.0, "captions": "caption for audio_2.wav, 10 to 20 seconds"} -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/configs/dataset_configs/etta_dit_training_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "location_caption_pair_manifests", 3 | "custom_metadata_module": "stable_audio_tools/configs/dataset_configs/custom_metadata/location_caption_pair.py", 4 | "custom_metadata_module_valid": "stable_audio_tools/configs/dataset_configs/custom_metadata/location_caption_pair.py", 5 | "random_crop": false, 6 | "datasets": [ 7 | { 8 | "id": "location_caption_pair_train_example", 9 | "path": "/path/to/examples/location_caption_pair_train_example.json" 10 | } 11 | ], 12 | "datasets_valid": [ 13 | { 14 | "id": "location_caption_pair_valid_example", 15 | "path": "/path/to/examples/location_caption_pair_valid_example.json" 16 | } 17 | ] 18 | } -------------------------------------------------------------------------------- /ETTA/scripts/ds_zero_to_pl_ckpt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 
3 | # modified from stable-audio-tools under the MIT license 4 | 5 | import argparse 6 | from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict 7 | 8 | if __name__ == "__main__": 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--save_path", type=str, help="Path to the zero checkpoint") 12 | parser.add_argument("--output_path", type=str, help="Path to the output checkpoint", default="lightning_model.pt") 13 | args = parser.parse_args() 14 | 15 | # lightning deepspeed has saved a directory instead of a file 16 | save_path = args.save_path 17 | output_path = args.output_path 18 | convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/models/pretrained.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # modified from stable-audio-tools under the MIT license 4 | 5 | import json 6 | 7 | from .factory import create_model_from_config 8 | from .utils import load_ckpt_state_dict 9 | 10 | from huggingface_hub import hf_hub_download 11 | 12 | def get_pretrained_model(name: str): 13 | 14 | model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model') 15 | 16 | with open(model_config_path) as f: 17 | model_config = json.load(f) 18 | 19 | model = create_model_from_config(model_config) 20 | 21 | # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file 22 | try: 23 | model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') 24 | except Exception as e: 25 | model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') 26 | 27 | model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) 28 | 29 | return model, model_config -------------------------------------------------------------------------------- /ETTA/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 NVIDIA CORPORATION. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /ETTA/LICENSES/LICENSE_ADP.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 archinet.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /ETTA/LICENSES/LICENSE_DESCRIPT.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-present, Descript 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /ETTA/LICENSES/LICENSE_XTRANSFORMERS.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /ETTA/LICENSES/LICENSE_NVIDIA.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 NVIDIA CORPORATION. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /ETTA/LICENSES/LICENSE_STABILITYAI.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ETTA/LICENSES/LICENSE_ADDICT.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Mats Julian Olsen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /ETTA/LICENSES/LICENSE_META.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /A2SB/configs/inference_files_upsampling.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | predict_filelist: 3 | - filepath: "dummy1.wav" 4 | output_subdir: "dummy1" 5 | - filepath: "dummy2.wav" 6 | output_subdir: "dummy2" 7 | 8 | transforms_gt: 9 | - class_path: audio_transforms.transforms.ComplexSpectrogram 10 | init_args: 11 | n_fft: 2048 12 | win_length: 2048 13 | hop_length: 512 14 | - class_path: audio_transforms.transforms.ComplexToMagInstPhase 15 | - class_path: audio_transforms.transforms.SpectrogramDropDCTerm 16 | - class_path: audio_transforms.transforms.PowerScaleSpectrogram 17 | init_args: 18 | power: 0.25 19 | channels: 20 | - 0 21 | 22 | transforms_aug: 23 | - class_path: corruption.corruptions.MultinomialInpaintMaskTransform 24 | init_args: 25 | p_upsample_mask: 1.0 26 | p_extension_mask: 0.0 27 | p_inpaint_mask: 0.0 28 | fill_noise_level: 0.5 29 | sampling_rate: 44100 30 | upsample_mask_kwargs: 31 | min_cutoff_freq: 2000 32 | max_cutoff_freq: 2000 33 | inpainting_mask_kwargs: 34 | min_inpainting_frac: 0.1013 35 | max_inpainting_frac: 0.1013 36 | is_random: false 37 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/inference/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 
3 | # modified from stable-audio-tools under the MIT license 4 | 5 | from ..data.utils import PadCrop 6 | 7 | from torchaudio import transforms as T 8 | 9 | def set_audio_channels(audio, target_channels): 10 | if target_channels == 1: 11 | # Convert to mono 12 | audio = audio.mean(1, keepdim=True) 13 | elif target_channels == 2: 14 | # Convert to stereo 15 | if audio.shape[1] == 1: 16 | audio = audio.repeat(1, 2, 1) 17 | elif audio.shape[1] > 2: 18 | audio = audio[:, :2, :] 19 | return audio 20 | 21 | def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): 22 | 23 | audio = audio.to(device) 24 | 25 | if in_sr != target_sr: 26 | resample_tf = T.Resample(in_sr, target_sr).to(device) 27 | audio = resample_tf(audio) 28 | 29 | audio = PadCrop(target_length, randomize=False)(audio) 30 | 31 | # Add batch dimension 32 | if audio.dim() == 1: 33 | audio = audio.unsqueeze(0).unsqueeze(0) 34 | elif audio.dim() == 2: 35 | audio = audio.unsqueeze(0) 36 | 37 | audio = set_audio_channels(audio, target_channels) 38 | 39 | return audio -------------------------------------------------------------------------------- /ETTA/run_gradio.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # modified from stable-audio-tools under the MIT license 4 | 5 | from stable_audio_tools.interface.gradio import create_ui 6 | 7 | import torch 8 | 9 | def main(args): 10 | torch.manual_seed(42) 11 | 12 | interface = create_ui( 13 | ckpt_path=args.ckpt_path, 14 | model_half=args.model_half 15 | ) 16 | interface.queue() 17 | interface.launch(server_name="0.0.0.0", server_port=7680, share=args.share, auth=(args.username, args.password) if args.username is not None else None) 18 | 19 | if __name__ == "__main__": 20 | import argparse 21 | parser = argparse.ArgumentParser(description='Run gradio interface') 22 | parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False) 23 | parser.add_argument('--share', action='store_true', help='Create a publicly shareable link', required=False) 24 | parser.add_argument('--username', type=str, help='Gradio username', required=False) 25 | parser.add_argument('--password', type=str, help='Gradio password', required=False) 26 | parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False) 27 | args = parser.parse_args() 28 | main(args) -------------------------------------------------------------------------------- /ETTA/Dockerfile: -------------------------------------------------------------------------------- 1 | ############################# 2 | # start of Dockerfile 3 | ############################# 4 | # start with latest official pytorch docker 5 | FROM pytorch/pytorch:2.7.1-cuda12.8-cudnn9-devel 6 | 7 | # to minimize node stress 8 | ENV MAX_JOBS=8 9 | ENV NINJA_JOBS=8 10 | 11 | # install essential apt packages first 12 | RUN apt-get update && apt-get install -y \ 13 | git build-essential ninja-build libsndfile1 ffmpeg espeak-ng sox libsox-fmt-all apt-transport-https wget curl gnupg libffi-dev 14 | 15 | # install pip and setuptools to latest 16 | RUN pip install -U pip setuptools 17 | 18 | # Set the working directory 19 | WORKDIR /app 20 | 21 | ############################# 22 | # install stable-audio-tools relatad pip packages with loose version checks 23 | ############################# 24 | RUN pip install flash-attn --no-build-isolation 25 | RUN pip install numpy 
soundfile pedalboard jupyter notebook packaging alias-free-torch auraloss descript-audio-codec einops einops-exts ema-pytorch encodec gradio huggingface_hub importlib-resources k-diffusion laion-clap local-attention pandas prefigure pytorch_lightning lightning pywavelets pypesq safetensors sentencepiece torchmetrics tqdm transformers v-diffusion-pytorch vector-quantize-pytorch wandb webdataset x-transformers diffusers["torch"] deepspeed 26 | 27 | ############################# 28 | # end of Dockerfile 29 | ############################# -------------------------------------------------------------------------------- /ETTA/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # modified from stable-audio-tools under the MIT license 4 | 5 | from setuptools import setup, find_packages 6 | 7 | setup( 8 | name='etta', 9 | version='0.0.1', 10 | url='https://github.com/NVIDIA/elucidated-text-to-audio', 11 | author='NVIDIA', 12 | description='Training and inference code for ETTA, built on top of stable-audio-tools from Stability AI', 13 | packages=find_packages(), 14 | install_requires=[ 15 | 'alias-free-torch>=0.0.6', 16 | 'auraloss>=0.4.0', 17 | 'descript-audio-codec>=1.0.0', 18 | 'diffusers', 19 | 'einops>=0.7.0', 20 | 'einops-exts>=0.0.4', 21 | 'ema-pytorch>=0.2.3', 22 | 'encodec>=0.1.1', 23 | 'gradio>=3.42.0', 24 | 'huggingface_hub', 25 | 'importlib-resources>=5.12.0', 26 | 'k-diffusion>=0.1.1', 27 | 'laion-clap>=1.1.4', 28 | 'local-attention>=1.8.6', 29 | 'notebook', 30 | 'pandas>=2.0.2', 31 | 'pedalboard>=0.7.4', 32 | 'prefigure>=0.0.9', 33 | 'pytorch_lightning>=2.1.0', 34 | 'PyWavelets>=1.4.1', 35 | 'safetensors', 36 | 'sentencepiece>=0.1.99', 37 | 'soundfile', 38 | 's3fs', 39 | 'torchmetrics>=0.11.4', 40 | 'tqdm', 41 | 'transformers', 42 | 'v-diffusion-pytorch>=0.0.2', 43 | 'vector-quantize-pytorch>=1.9.14', 44 | 'wandb>=0.15.4', 45 | 'webdataset>=0.2.48', 46 | 'x-transformers>=1.27.0', 47 | 'deepspeed' 48 | ], 49 | ) -------------------------------------------------------------------------------- /A2SB/configs/inference_files_inpainting.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | predict_filelist: 3 | - filepath: "dummy1.wav" 4 | output_subdir: "dummy1" 5 | - filepath: "dummy2.wav" 6 | output_subdir: "dummy2" 7 | 8 | transforms_gt: 9 | - class_path: audio_transforms.transforms.ComplexSpectrogram 10 | init_args: 11 | n_fft: 2048 12 | win_length: 2048 13 | hop_length: 512 14 | - class_path: audio_transforms.transforms.ComplexToMagInstPhase 15 | - class_path: audio_transforms.transforms.SpectrogramDropDCTerm 16 | - class_path: audio_transforms.transforms.PowerScaleSpectrogram 17 | init_args: 18 | power: 0.25 19 | channels: 20 | - 0 21 | 22 | transforms_aug: 23 | - class_path: corruption.corruptions.TimestampedSegmentInpaintMaskTransform 24 | init_args: 25 | start_time: 1.0 26 | end_time: 1.2 27 | hop_length: 512 28 | sampling_rate: 44100 29 | fill_noise_level: 0.5 30 | - class_path: corruption.corruptions.TimestampedSegmentInpaintMaskTransform 31 | init_args: 32 | start_time: 3.7 33 | end_time: 4.0 34 | hop_length: 512 35 | sampling_rate: 44100 36 | fill_noise_level: 0.5 37 | - class_path: corruption.corruptions.TimestampedSegmentInpaintMaskTransform 38 | init_args: 39 | start_time: 8.0 40 | end_time: 8.5 41 | hop_length: 512 42 | sampling_rate: 44100 43 | fill_noise_level: 0.5 44 | - class_path: 
corruption.corruptions.TimestampedSegmentInpaintMaskTransform 45 | init_args: 46 | start_time: 11.2 47 | end_time: 12.0 48 | hop_length: 512 49 | sampling_rate: 44100 50 | fill_noise_level: 0.5 51 | -------------------------------------------------------------------------------- /A2SB/SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization. 4 | 5 | If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub.** 6 | 7 | ## Reporting Potential Security Vulnerability in an NVIDIA Product 8 | 9 | To report a potential security vulnerability in any NVIDIA product: 10 | - Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html) 11 | - E-Mail: psirt@nvidia.com 12 | - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key) 13 | - Please include the following information: 14 | - Product/Driver name and version/branch that contains the vulnerability 15 | - Type of vulnerability (code execution, denial of service, buffer overflow, etc.) 16 | - Instructions to reproduce the vulnerability 17 | - Proof-of-concept or exploit code 18 | - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability 19 | 20 | While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information. 21 | 22 | ## NVIDIA Product Security 23 | 24 | For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security 25 | -------------------------------------------------------------------------------- /A2SB/plotting_utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | import matplotlib 9 | matplotlib.use("Agg") 10 | import matplotlib.pylab as plt 11 | import numpy as np 12 | from moviepy.video.io.bindings import mplfig_to_npimage 13 | import librosa 14 | 15 | 16 | def plot_spec_to_numpy(spectrogram, title='', sr=48000, hop_length=512, info=None, vmin=None, vmax=None, cmap='brg'): 17 | fig, ax = plt.subplots(figsize=(6, 4)) 18 | spec_db = librosa.amplitude_to_db(spectrogram, ref=np.max) 19 | 20 | img = librosa.display.specshow(spec_db, sr=sr, hop_length=hop_length, x_axis='frames', y_axis='linear', ax=ax) 21 | 22 | fig.colorbar(img, ax=ax) 23 | fig.tight_layout() 24 | 25 | fig.canvas.draw() 26 | fig.show() 27 | numpy_fig = mplfig_to_npimage(fig) 28 | 29 | return numpy_fig 30 | 31 | 32 | def plot_phase_to_numpy(phase, title='', sr=48000, hop_length=512, info=None, vmin=-np.pi, vmax=np.pi, cmap='hsv'): 33 | fig, ax = plt.subplots(figsize=(6, 4)) 34 | phase_np = phase.numpy() 35 | 36 | img = librosa.display.specshow(phase_np, sr=sr, hop_length=hop_length, x_axis='frames', y_axis='linear', cmap=cmap, ax=ax, vmin=vmin, vmax=vmax) 37 | 38 | cbar = fig.colorbar(img, ax=ax, format='%+2.0f rad') 39 | cbar.set_label('Phase (radians)') 40 | 41 | ax.set_title(title if title else 'Spectrogram Phase') 42 | fig.tight_layout() 43 | 44 | fig.canvas.draw() 45 | fig.show() 46 | numpy_fig = mplfig_to_npimage(fig) 47 | matplotlib.pyplot.close(fig) 48 | return numpy_fig 49 | -------------------------------------------------------------------------------- /A2SB/main.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | from lightning.pytorch.cli import LightningCLI 9 | from A2SB_lightning_module import STFTBridgeModel, LogValidationInpaintingSTFTCallback 10 | from datasets.datamodule import STFTAudioDataModule 11 | from lightning.pytorch.callbacks import ModelCheckpoint 12 | 13 | 14 | class InpaintingAudioSBLightningCLI(LightningCLI): 15 | def add_arguments_to_parser(self, parser): 16 | parser.add_lightning_class_args(ModelCheckpoint, "checkpoint_callback") 17 | parser.add_lightning_class_args(LogValidationInpaintingSTFTCallback, "validation_inpainting_callback") 18 | parser.set_defaults({"checkpoint_callback.filename": "latest-epoch_{epoch}-iter_{global_step:.0f}", 19 | "checkpoint_callback.monitor": "global_step", 20 | "checkpoint_callback.mode": "max", 21 | "checkpoint_callback.every_n_train_steps": 1000, 22 | "checkpoint_callback.dirpath": "/debug", 23 | "checkpoint_callback.save_top_k": -1, 24 | "checkpoint_callback.auto_insert_metric_name": False}) 25 | parser.link_arguments("checkpoint_callback.dirpath", "trainer.default_root_dir") 26 | 27 | # parser.link_arguments("data.fft_size", "model.fft_size") 28 | # parser.link_arguments("data.hop_size", "model.hop_size") 29 | # parser.link_arguments("data.win_length", "model.win_length") 30 | # parser.link_arguments("data.sampling_rate", "model.sampling_rate") 31 | 32 | cli = InpaintingAudioSBLightningCLI(STFTBridgeModel, STFTAudioDataModule, save_config_kwargs={"overwrite": True}) 33 | -------------------------------------------------------------------------------- /A2SB/ensembled_inference.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | from lightning.pytorch.cli import LightningCLI 9 | from A2SB_lightning_module import TimePartitionedPretrainedSTFTBridgeModel, LogValidationInpaintingSTFTCallback 10 | from datasets.datamodule import STFTAudioDataModule 11 | from lightning.pytorch.callbacks import ModelCheckpoint 12 | 13 | 14 | class InpaintingAudioSBLightningCLI(LightningCLI): 15 | def add_arguments_to_parser(self, parser): 16 | parser.add_lightning_class_args(ModelCheckpoint, "checkpoint_callback") 17 | parser.add_lightning_class_args(LogValidationInpaintingSTFTCallback, "validation_inpainting_callback") 18 | parser.set_defaults({"checkpoint_callback.filename": "latest-epoch_{epoch}-iter_{global_step:.0f}", 19 | "checkpoint_callback.monitor": "global_step", 20 | "checkpoint_callback.mode": "max", 21 | "checkpoint_callback.every_n_train_steps": 1000, 22 | "checkpoint_callback.dirpath": "/debug", 23 | "checkpoint_callback.save_top_k": -1, 24 | "checkpoint_callback.auto_insert_metric_name": False}) 25 | parser.link_arguments("checkpoint_callback.dirpath", "trainer.default_root_dir") 26 | 27 | # parser.link_arguments("data.fft_size", "model.fft_size") 28 | # parser.link_arguments("data.hop_size", "model.hop_size") 29 | # parser.link_arguments("data.win_length", "model.win_length") 30 | # parser.link_arguments("data.sampling_rate", "model.sampling_rate") 31 | if __name__ == '__main__': 32 | cli = InpaintingAudioSBLightningCLI(TimePartitionedPretrainedSTFTBridgeModel, STFTAudioDataModule, save_config_kwargs={"overwrite": True}) 33 | -------------------------------------------------------------------------------- /A2SB/ensembled_inference_api.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | from lightning.pytorch.cli import LightningCLI 9 | from A2SB_lightning_module_api import TimePartitionedPretrainedSTFTBridgeModel, LogValidationInpaintingSTFTCallback 10 | from datasets.datamodule import STFTAudioDataModule 11 | from lightning.pytorch.callbacks import ModelCheckpoint 12 | 13 | 14 | class InpaintingAudioSBLightningCLI(LightningCLI): 15 | def add_arguments_to_parser(self, parser): 16 | parser.add_lightning_class_args(ModelCheckpoint, "checkpoint_callback") 17 | parser.add_lightning_class_args(LogValidationInpaintingSTFTCallback, "validation_inpainting_callback") 18 | parser.set_defaults({"checkpoint_callback.filename": "latest-epoch_{epoch}-iter_{global_step:.0f}", 19 | "checkpoint_callback.monitor": "global_step", 20 | "checkpoint_callback.mode": "max", 21 | "checkpoint_callback.every_n_train_steps": 1000, 22 | "checkpoint_callback.dirpath": "/debug", 23 | "checkpoint_callback.save_top_k": -1, 24 | "checkpoint_callback.auto_insert_metric_name": False}) 25 | parser.link_arguments("checkpoint_callback.dirpath", "trainer.default_root_dir") 26 | 27 | # parser.link_arguments("data.fft_size", "model.fft_size") 28 | # parser.link_arguments("data.hop_size", "model.hop_size") 29 | # parser.link_arguments("data.win_length", "model.win_length") 30 | # parser.link_arguments("data.sampling_rate", "model.sampling_rate") 31 | if __name__ == '__main__': 32 | cli = InpaintingAudioSBLightningCLI(TimePartitionedPretrainedSTFTBridgeModel, STFTAudioDataModule, save_config_kwargs={"overwrite": True}) 33 | -------------------------------------------------------------------------------- /ETTA/docs/pretransforms.md: -------------------------------------------------------------------------------- 1 | # Pretransforms 2 | Many models require some fixed transform to be applied to the input audio before the audio is passed in to the trainable layers of the model, as well as a corresponding inverse transform to be applied to the outputs of the model. We refer to these as "pretransforms". 3 | 4 | At the moment, `stable-audio-tools` supports two pretransforms, frozen autoencoders for latent diffusion models and wavelet decompositions. 5 | 6 | Pretransforms have a similar interface to autoencoders with "encode" and "decode" functions defined for each pretransform. 7 | 8 | ## Autoencoder pretransform 9 | To define a model with an autoencoder pretransform, you can define the "pretransform" property in the model config, with the `type` property set to `autoencoder`. The `config` property should be an autoencoder model definition. 10 | 11 | Example: 12 | ```json 13 | "pretransform": { 14 | "type": "autoencoder", 15 | "config": { 16 | "encoder": { 17 | ... 18 | }, 19 | "decoder": { 20 | ... 21 | } 22 | ...normal autoencoder configuration 23 | } 24 | } 25 | ``` 26 | 27 | ### Latent rescaling 28 | The original [Latent Diffusion paper](https://arxiv.org/abs/2112.10752) found that rescaling the latent series to unit variance before performing diffusion improved quality. To this end, we expose a `scale` property on autoencoder pretransforms that will take care of this rescaling. The scale should be set to the original standard deviation of the latents, which can be determined experimentally, or by looking at the `latent_std` value during training. The pretransform code will divide by this scale factor in the `encode` function and multiply by this scale in the `decode` function. 
29 | 30 | ## Wavelet pretransform 31 | `stable-audio-tools` also exposes wavelet decomposition as a pretransform. Wavelet decomposition is a quick way to trade off sequence length for channels in autoencoders, while maintaining a multi-band implicit bias. 32 | 33 | Wavelet pretransforms take the following properties: 34 | 35 | - `channels` 36 | - The number of input and output audio channels for the wavelet transform 37 | - `levels` 38 | - The number of successive wavelet decompositions to perform. Each level doubles the channel count and halves the sequence length 39 | - `wavelet` 40 | - The specific wavelet from [PyWavelets](https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html) to use, currently limited to `"bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"` 41 | 42 | ## Future work 43 | We hope to add more filters and transforms to this list, including PQMF and STFT transforms. -------------------------------------------------------------------------------- /A2SB/utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import torch 9 | from torch import Tensor 10 | from torchaudio.transforms import Spectrogram 11 | 12 | def get_mask_from_lengths(lengths): 13 | """Constructs binary mask from a 1D torch tensor of input lengths 14 | 15 | Args: 16 | lengths (torch.tensor): 1D tensor 17 | Returns: 18 | mask (torch.tensor): num_sequences x max_length x 1 binary tensor 19 | """ 20 | max_len = torch.max(lengths).item() 21 | ids = torch.arange(0, max_len).to(lengths.device) 22 | mask = (ids < lengths.unsqueeze(1)).bool() 23 | return mask 24 | 25 | 26 | class SequenceLength: 27 | """Data structure for storing sequence lengths 28 | """ 29 | def __init__(self, lengths): 30 | self.lengths = lengths.long() 31 | self.mask = get_mask_from_lengths(lengths) 32 | 33 | 34 | def average_key_value(dict_list, key): 35 | """ 36 | Calculate the average value for a given key in a list of dictionaries. 37 | 38 | Parameters: 39 | dict_list (list): List of dictionaries 40 | key (str): The key whose values need to be averaged 41 | 42 | Returns: 43 | float: The average value 44 | """ 45 | if not dict_list: 46 | return 0 47 | 48 | total = sum(d[key] for d in dict_list if key in d) 49 | count = sum(1 for d in dict_list if key in d) 50 | 51 | return total / count if count != 0 else 0 52 | 53 | 54 | def find_middle_of_zero_segments(binary_array: torch.Tensor) -> torch.Tensor: 55 | """ 56 | Find the middle indices of all continuous segments of zeros in a binary array. 57 | 58 | Args: 59 | binary_array (torch.Tensor): A 1D binary tensor containing 0s and 1s. 60 | 61 | Returns: 62 | torch.Tensor: A tensor of middle indices for each segment of zeros. 
63 | """ 64 | if not torch.is_tensor(binary_array) or binary_array.ndim != 1: 65 | raise ValueError("Input must be a 1D tensor.") 66 | 67 | # Find transitions 68 | diff = torch.diff(binary_array, prepend=torch.tensor([1], dtype=binary_array.dtype).to(binary_array.device)) 69 | 70 | # Find start and end indices of zero segments 71 | start_indices = (diff == -1).nonzero(as_tuple=True)[0] 72 | end_indices = (diff == 1).nonzero(as_tuple=True)[0] - 1 73 | 74 | # Handle edge case: if array ends with a zero segment 75 | if binary_array[-1] == 0: 76 | end_indices = torch.cat([end_indices, torch.tensor([len(binary_array) - 1]).to(binary_array.device)]) 77 | 78 | # Compute the middle indices of zero segments 79 | middle_indices = ((start_indices + end_indices) / 2).int() 80 | 81 | return middle_indices 82 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/models/diffusion_prior.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # modified from stable-audio-tools under the MIT license 4 | 5 | from enum import Enum 6 | import typing as tp 7 | 8 | from .diffusion import ConditionedDiffusionModelWrapper 9 | from ..inference.generation import generate_diffusion_cond 10 | from ..inference.utils import prepare_audio 11 | 12 | import torch 13 | from torch.nn import functional as F 14 | from torchaudio import transforms as T 15 | 16 | # Define prior types enum 17 | class PriorType(Enum): 18 | MonoToStereo = 1 19 | 20 | class DiffusionPrior(ConditionedDiffusionModelWrapper): 21 | def __init__(self, *args, prior_type: PriorType=None, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.prior_type = prior_type 24 | 25 | class MonoToStereoDiffusionPrior(DiffusionPrior): 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs) 28 | 29 | def stereoize( 30 | self, 31 | audio: torch.Tensor, # (batch, channels, time) 32 | in_sr: int, 33 | steps: int, 34 | sampler_kwargs: dict = {}, 35 | ): 36 | """ 37 | Generate stereo audio from mono audio using a pre-trained diffusion prior 38 | 39 | Args: 40 | audio: The mono audio to convert to stereo 41 | in_sr: The sample rate of the input audio 42 | steps: The number of diffusion steps to run 43 | sampler_kwargs: Keyword arguments to pass to the diffusion sampler 44 | """ 45 | 46 | device = audio.device 47 | 48 | sample_rate = self.sample_rate 49 | 50 | # Resample input audio if necessary 51 | if in_sr != sample_rate: 52 | resample_tf = T.Resample(in_sr, sample_rate).to(audio.device) 53 | audio = resample_tf(audio) 54 | 55 | audio_length = audio.shape[-1] 56 | 57 | # Pad input audio to be compatible with the model 58 | min_length = self.min_input_length 59 | padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length 60 | 61 | # Pad input audio to be compatible with the model 62 | if padded_input_length > audio_length: 63 | audio = F.pad(audio, (0, padded_input_length - audio_length)) 64 | 65 | # Make audio mono, duplicate to stereo 66 | dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1) 67 | 68 | if self.pretransform is not None: 69 | dual_mono = self.pretransform.encode(dual_mono) 70 | 71 | conditioning = {"source": [dual_mono]} 72 | 73 | stereo_audio = generate_diffusion_cond( 74 | self, 75 | conditioning_tensors=conditioning, 76 | steps=steps, 77 | sample_size=padded_input_length, 78 | 
sample_rate=sample_rate, 79 | device=device, 80 | **sampler_kwargs, 81 | ) 82 | 83 | return stereo_audio -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/models/wavelets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # modified from stable-audio-tools under the MIT license 4 | 5 | """The 1D discrete wavelet transform for PyTorch.""" 6 | 7 | from einops import rearrange 8 | import pywt 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from typing import Literal 13 | 14 | 15 | def get_filter_bank(wavelet): 16 | filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank) 17 | if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): 18 | filt = filt[:, 1:] 19 | return filt 20 | 21 | class WaveletEncode1d(nn.Module): 22 | def __init__(self, 23 | channels, 24 | levels, 25 | wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): 26 | super().__init__() 27 | self.wavelet = wavelet 28 | self.channels = channels 29 | self.levels = levels 30 | filt = get_filter_bank(wavelet) 31 | assert filt.shape[-1] % 2 == 1 32 | kernel = filt[:2, None] 33 | kernel = torch.flip(kernel, dims=(-1,)) 34 | index_i = torch.repeat_interleave(torch.arange(2), channels) 35 | index_j = torch.tile(torch.arange(channels), (2,)) 36 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 37 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 38 | self.register_buffer("kernel", kernel_final) 39 | 40 | def forward(self, x): 41 | for i in range(self.levels): 42 | low, rest = x[:, : self.channels], x[:, self.channels :] 43 | pad = self.kernel.shape[-1] // 2 44 | low = F.pad(low, (pad, pad), "reflect") 45 | low = F.conv1d(low, self.kernel, stride=2) 46 | rest = rearrange( 47 | rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels 48 | ) 49 | x = torch.cat([low, rest], dim=1) 50 | return x 51 | 52 | 53 | class WaveletDecode1d(nn.Module): 54 | def __init__(self, 55 | channels, 56 | levels, 57 | wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): 58 | super().__init__() 59 | self.wavelet = wavelet 60 | self.channels = channels 61 | self.levels = levels 62 | filt = get_filter_bank(wavelet) 63 | assert filt.shape[-1] % 2 == 1 64 | kernel = filt[2:, None] 65 | index_i = torch.repeat_interleave(torch.arange(2), channels) 66 | index_j = torch.tile(torch.arange(channels), (2,)) 67 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 68 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 69 | self.register_buffer("kernel", kernel_final) 70 | 71 | def forward(self, x): 72 | for i in range(self.levels): 73 | low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] 74 | pad = self.kernel.shape[-1] // 2 + 2 75 | low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) 76 | low = F.pad(low, (pad, pad), "reflect") 77 | low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) 78 | low = F.conv_transpose1d( 79 | low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 80 | ) 81 | low = low[..., pad - 1 : -pad] 82 | rest = rearrange( 83 | rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels 84 | ) 85 | x = torch.cat([low, rest], dim=1) 86 | return x -------------------------------------------------------------------------------- 
/A2SB/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Overview 4 | 5 | Define the code of conduct followed and enforced for __diffusion audio restoration__. 6 | 7 | ### Intended audience 8 | 9 | Community | Developers | Project Leads 10 | 11 | ## Our Pledge 12 | 13 | In the interest of fostering an open and welcoming environment, we as 14 | contributors and maintainers pledge to making participation in our project and 15 | our community a harassment-free experience for everyone, regardless of age, body 16 | size, disability, ethnicity, sex characteristics, gender identity and expression, 17 | level of experience, education, socio-economic status, nationality, personal 18 | appearance, race, religion, or sexual identity and orientation. 19 | 20 | ## Our Standards 21 | 22 | Examples of behavior that contributes to creating a positive environment 23 | include: 24 | 25 | * Using welcoming and inclusive language 26 | * Being respectful of differing viewpoints and experiences 27 | * Gracefully accepting constructive criticism 28 | * Focusing on what is best for the community 29 | * Showing empathy towards other community members 30 | 31 | Examples of unacceptable behavior by participants include: 32 | 33 | * The use of sexualized language or imagery and unwelcome sexual attention or 34 | advances 35 | * Trolling, insulting/derogatory comments, and personal or political attacks 36 | * Public or private harassment 37 | * Publishing others' private information, such as a physical or electronic 38 | address, without explicit permission 39 | * Other conduct which could reasonably be considered inappropriate in a 40 | professional setting 41 | 42 | ## Our Responsibilities 43 | 44 | Project maintainers are responsible for clarifying the standards of acceptable 45 | behavior and are expected to take appropriate and fair corrective action in 46 | response to any instances of unacceptable behavior. 47 | 48 | Project maintainers have the right and responsibility to remove, edit, or 49 | reject comments, commits, code, wiki edits, issues, and other contributions 50 | that are not aligned to this Code of Conduct, or to ban temporarily or 51 | permanently any contributor for other behaviors that they deem inappropriate, 52 | threatening, offensive, or harmful. 53 | 54 | ## Scope 55 | 56 | This Code of Conduct applies both within project spaces and in public spaces 57 | when an individual is representing the project or its community. Examples of 58 | representing a project or community include using an official project e-mail 59 | address, posting via an official social media account, or acting as an appointed 60 | representative at an online or offline event. Representation of a project may be 61 | further defined and clarified by project maintainers. 62 | 63 | ## Enforcement 64 | 65 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 66 | reported by contacting GitHub_Conduct@nvidia.com. All complaints will be reviewed and 67 | investigated and will result in a response that is deemed necessary and appropriate 68 | to the circumstances. The project team is obligated to maintain confidentiality with 69 | regard to the reporter of an incident. Further details of specific enforcement policies 70 | may be posted separately. 
71 | 72 | Project maintainers who do not follow or enforce the Code of Conduct in good 73 | faith may face temporary or permanent repercussions as determined by other 74 | members of the project's leadership. 75 | 76 | ## Attribution 77 | 78 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 79 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 80 | 81 | [homepage]: https://www.contributor-covenant.org 82 | 83 | For answers to common questions about this code of conduct, see 84 | https://www.contributor-covenant.org/faq 85 | -------------------------------------------------------------------------------- /A2SB/configs/ensemble_2split_sampling.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.3.3 2 | seed_everything: 0 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp 6 | devices: 1 7 | num_nodes: 1 8 | precision: 32-true 9 | logger: null 10 | callbacks: null 11 | fast_dev_run: false 12 | max_epochs: null 13 | min_epochs: null 14 | max_steps: -1 15 | min_steps: null 16 | max_time: null 17 | limit_train_batches: null 18 | limit_val_batches: null 19 | limit_test_batches: null 20 | limit_predict_batches: null 21 | overfit_batches: 0.0 22 | val_check_interval: 1000 23 | check_val_every_n_epoch: 1 24 | num_sanity_val_steps: null 25 | log_every_n_steps: null 26 | enable_checkpointing: null 27 | enable_progress_bar: null 28 | enable_model_summary: null 29 | accumulate_grad_batches: 2 30 | gradient_clip_val: 0.5 31 | gradient_clip_algorithm: norm 32 | deterministic: null 33 | benchmark: null 34 | inference_mode: true 35 | use_distributed_sampler: true 36 | profiler: null 37 | detect_anomaly: false 38 | barebones: false 39 | plugins: 40 | - class_path: lightning.fabric.plugins.environments.SLURMEnvironment 41 | init_args: 42 | auto_requeue: true 43 | requeue_signal: null 44 | sync_batchnorm: false 45 | reload_dataloaders_every_n_epochs: 0 46 | model: 47 | vf_model: 48 | class_path: networks.AttnUNetF 49 | init_args: 50 | n_updown_levels: 5 51 | in_channels: 3 52 | hidden_channels: [128, 256, 512, 768, 1024, 2048] 53 | out_channels: 3 54 | emb_channels: 128 55 | rotary_dims: 16 56 | band_embedding_dim: 16 57 | n_attn_heads: 8 58 | attention_levels: 59 | - 3 60 | - 4 61 | use_attn_input_norm: true 62 | num_res_blocks: 2 63 | inv_transforms: 64 | - class_path: audio_transforms.transforms.PowerScaleSpectrogram 65 | init_args: 66 | power: 4 67 | channels: 68 | - 0 69 | eps: 1.0e-09 70 | - class_path: audio_transforms.transforms.SpectrogramAddDCTerm 71 | - class_path: audio_transforms.transforms.SVDFixMagInstPhase 72 | - class_path: audio_transforms.transforms.MagInstPhaseToComplex 73 | - class_path: audio_transforms.transforms.InverseComplexSpectrogram 74 | init_args: 75 | n_fft: 2048 76 | win_length: 2048 77 | hop_length: 512 78 | eps: 1.0e-09 79 | sampling_rate: 44100 80 | n_timestep_channels: 128 81 | use_ot_ode: false 82 | beta_max: 1.0 83 | pretrained_checkpoints: 84 | - PATH/TO/FIRST/SPLIT.ckpt 85 | - PATH/TO/SECOND/SPLIT.ckpt 86 | t_cutoffs: [ 0.5 ] 87 | data: 88 | mix_dataset_config: 89 | EVAL_DATASET_NAME: 90 | root_folder: PATH/TO/MANIFEST/FOLDER 91 | filename: 'EVAL_DATASET_NAME_manifest.csv' 92 | EVAL_DATASET_NAME_TWO: 93 | root_folder: PATH/TO/MANIFEST/FOLDER 94 | filename: 'EVAL_DATASET_NAME_TWO_manifest.csv' 95 | 96 | segment_length: 130560 97 | sampling_rate: 44100 98 | num_workers: 23 99 | batch_size: 4 100 | 101 | transforms_aug: [] 102 | 
transforms_aug_val: [] 103 | eval_transforms_aug: [] 104 | 105 | transforms_gt: 106 | - class_path: audio_transforms.transforms.ComplexSpectrogram 107 | init_args: 108 | n_fft: 2048 109 | win_length: 2048 110 | hop_length: 512 111 | eps: 1.0e-09 112 | - class_path: audio_transforms.transforms.ComplexToMagInstPhase 113 | - class_path: audio_transforms.transforms.SpectrogramDropDCTerm 114 | - class_path: audio_transforms.transforms.PowerScaleSpectrogram 115 | init_args: 116 | power: 0.25 117 | channels: 118 | - 0 119 | eps: 1.0e-09 120 | train_max_samples: null 121 | val_max_samples: 256 122 | predict_filelist: [] 123 | -------------------------------------------------------------------------------- /A2SB/audio_utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | 9 | import torch 10 | from torch import Tensor 11 | 12 | from typing import List, Optional, Tuple, Union 13 | 14 | 15 | def radian_to_SO2(rads: torch.Tensor): 16 | """ 17 | converts tensor of radians to tensor of SO(2) matrices 18 | for any inputs tensor of shape B x ..., the output will be B x ... x 2 x 2 19 | """ 20 | cos_theta = torch.cos(rads) 21 | sin_theta = torch.sin(rads) 22 | 23 | rot_m = torch.stack([cos_theta, -sin_theta, sin_theta, cos_theta], -1) 24 | rot_m = rot_m.view(*rot_m.shape[:-1], 2, 2) 25 | return rot_m 26 | 27 | 28 | def wav_to_stft(wav: torch.Tensor, fft_size: int, hop_size: int, win_length: int, drop_dc_term=True): 29 | """ 30 | Inputs: 31 | wav: N or B x N tensor of waveform sample values 32 | Returns: 33 | magnitude and phase tensors with or without a batch dimension depending on wav input 34 | Magnitudes: ... x H x W 35 | phase_R: ... 
x H x W x 2 x 2 (represented as SO2 matrices) 36 | """ 37 | 38 | stft_cmplx = torch.stft(wav, n_fft=fft_size, hop_length=hop_size, win_length=win_length, window=torch.hann_window(win_length), return_complex=True) 39 | if drop_dc_term: 40 | stft_cmplx = stft_cmplx[...,1:,:] 41 | 42 | magnitudes = stft_cmplx.abs() # abs on complex values should compute the vector norm 43 | stft_real = torch.view_as_real(stft_cmplx) 44 | phase = torch.atan2(stft_real[...,1],stft_real[...,0]) 45 | phase_R = radian_to_SO2(phase) 46 | return magnitudes, phase_R 47 | 48 | def phase_R_to_channels(stft_R): 49 | """ 50 | converts B x H x W x 2 x 2 to B x 4 x H x W 51 | """ 52 | if len(stft_R.shape) == 5: 53 | return stft_R.reshape(*stft_R.shape[:3], 4).permute(0,3,1,2) 54 | elif len(stft_R.shape) == 4: 55 | return phase_R_to_channels(stft_R.unsqueeze(0))[0] # add a batch dim for unbatched input, convert, then drop it 56 | else: 57 | print("unsupported dimensions") 58 | exit(1) 59 | 60 | 61 | def phase_channels_to_R(stft_channels): 62 | """ 63 | inverse transformation for phase_R_to_channels 64 | """ 65 | stft_R_flat = stft_channels.permute(0, 2, 3, 1) 66 | stft_R = stft_R_flat.reshape(*stft_R_flat.shape[:3], 2, 2) 67 | return stft_R 68 | 69 | 70 | def stft_mag_R_to_wav(stft_mag, stft_Rch, n_fft, hop_length, win_length, append_dc_term=True): 71 | """ 72 | stft_mag: B x 1 x H x L magnitudes 73 | stft_R: B x 4 x H x L flattened rotations 74 | """ 75 | stft_costheta = stft_Rch[:,0:1] 76 | stft_sintheta = stft_Rch[:,2:3] 77 | x_cos_theta = stft_costheta * stft_mag 78 | x_sin_theta = stft_sintheta * stft_mag 79 | 80 | new_stft_cmplx = torch.view_as_complex(torch.stack([x_cos_theta[:,0], x_sin_theta[:,0]], -1)) 81 | if append_dc_term: # if dc term was removed, add it back in with zero magnitude, zero phase 82 | _b, _h, _w = new_stft_cmplx.shape 83 | dct = torch.zeros(_b, 1, _w, device=new_stft_cmplx.device) 84 | new_stft_cmplx = torch.cat((dct, new_stft_cmplx), 1) 85 | wav_out = torch.istft(new_stft_cmplx, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=torch.hann_window(win_length).to(new_stft_cmplx.device)) 86 | return wav_out 87 | 88 | 89 | def _get_complex_dtype(real_dtype: torch.dtype): 90 | if real_dtype == torch.double: 91 | return torch.cdouble 92 | if real_dtype == torch.float: 93 | return torch.cfloat 94 | if real_dtype == torch.half: 95 | return torch.complex32 96 | raise ValueError(f"Unexpected dtype {real_dtype}") 97 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/data/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license.
3 | # modified from stable-audio-tools under the MIT license 4 | 5 | import math 6 | import random 7 | import torch 8 | 9 | from torch import nn 10 | from typing import Tuple 11 | 12 | class PadCrop(nn.Module): 13 | def __init__(self, n_samples, randomize=True): 14 | super().__init__() 15 | self.n_samples = n_samples 16 | self.randomize = randomize 17 | 18 | def __call__(self, signal): 19 | n, s = signal.shape 20 | start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() 21 | end = start + self.n_samples 22 | output = signal.new_zeros([n, self.n_samples]) 23 | output[:, :min(s, self.n_samples)] = signal[:, start:end] 24 | return output 25 | 26 | class PadCrop_Normalized_T(nn.Module): 27 | 28 | def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): 29 | 30 | super().__init__() 31 | 32 | self.n_samples = n_samples 33 | self.sample_rate = sample_rate 34 | self.randomize = randomize 35 | 36 | def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: 37 | 38 | n_channels, n_samples = source.shape 39 | 40 | # If the audio is shorter than the desired length, pad it 41 | upper_bound = max(0, n_samples - self.n_samples) 42 | 43 | # If randomize is False, always start at the beginning of the audio 44 | offset = 0 45 | if(self.randomize and n_samples > self.n_samples): 46 | offset = random.randint(0, upper_bound) 47 | 48 | # Calculate the start and end times of the chunk 49 | t_start = offset / (upper_bound + self.n_samples) 50 | t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) 51 | 52 | # Create the chunk 53 | chunk = source.new_zeros([n_channels, self.n_samples]) 54 | 55 | # Copy the audio into the chunk 56 | chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] 57 | 58 | # Calculate the start and end times of the chunk in seconds 59 | seconds_start = math.floor(offset / self.sample_rate) 60 | seconds_total = math.ceil(n_samples / self.sample_rate) 61 | 62 | # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't 63 | padding_mask = torch.zeros([self.n_samples]) 64 | padding_mask[:min(n_samples, self.n_samples)] = 1 65 | 66 | 67 | return ( 68 | chunk, 69 | t_start, 70 | t_end, 71 | seconds_start, 72 | seconds_total, 73 | padding_mask 74 | ) 75 | 76 | class PhaseFlipper(nn.Module): 77 | "Randomly invert the phase of a signal" 78 | def __init__(self, p=0.5): 79 | super().__init__() 80 | self.p = p 81 | def __call__(self, signal): 82 | return -signal if (random.random() < self.p) else signal 83 | 84 | class Mono(nn.Module): 85 | def __call__(self, signal): 86 | return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal 87 | 88 | class Stereo(nn.Module): 89 | def __call__(self, signal): 90 | signal_shape = signal.shape 91 | # Check if it's mono 92 | if len(signal_shape) == 1: # s -> 2, s 93 | signal = signal.unsqueeze(0).repeat(2, 1) 94 | elif len(signal_shape) == 2: 95 | if signal_shape[0] == 1: #1, s -> 2, s 96 | signal = signal.repeat(2, 1) 97 | elif signal_shape[0] > 2: #?, s -> 2,s 98 | signal = signal[:2, :] 99 | 100 | return signal 101 | -------------------------------------------------------------------------------- /ETTA/docs/datasets.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | `stable-audio-tools` supports loading data from local file storage, as well as loading audio files and JSON files in the 
[WebDataset](https://github.com/webdataset/webdataset/tree/main/webdataset) format from Amazon S3 buckets. 3 | 4 | # Dataset configs 5 | To specify the dataset used for training, you must provide a dataset config JSON file to `train.py`. 6 | 7 | The dataset config consists of a `dataset_type` property specifying the type of data loader to use, a `datasets` array to provide multiple data sources, and a `random_crop` property, which decides if the cropped audio from the training samples is from a random place in the audio file, or always from the beginning. 8 | 9 | ## Local audio files 10 | To use a local directory of audio samples, set the `dataset_type` property in your dataset config to `"audio_dir"`, and provide a list of objects to the `datasets` property including the `path` property, which should be the path to your directory of audio samples. 11 | 12 | This will load all of the compatible audio files from the provided directory and all subdirectories. 13 | 14 | ### Example config 15 | ```json 16 | { 17 | "dataset_type": "audio_dir", 18 | "datasets": [ 19 | { 20 | "id": "my_audio", 21 | "path": "/path/to/audio/dataset/" 22 | } 23 | ], 24 | "random_crop": true 25 | } 26 | ``` 27 | 28 | ## S3 WebDataset 29 | To load audio files and related metadata from .tar files in the WebDataset format hosted in Amazon S3 buckets, you can set the `dataset_type` property to `s3`, and provide the `datasets` parameter with a list of objects containing the AWS S3 path to the shared S3 bucket prefix of the WebDataset .tar files. The S3 bucket will be searched recursively given the path, and assumes any .tar files found contain audio files and corresponding JSON files where the related files differ only in file extension (e.g. "000001.flac", "000001.json", "00002.flac", "00002.json", etc.) 30 | 31 | ### Example config 32 | ```json 33 | { 34 | "dataset_type": "s3", 35 | "datasets": [ 36 | { 37 | "id": "s3-test", 38 | "s3_path": "s3://my-bucket/datasets/webdataset/audio/" 39 | } 40 | ], 41 | "random_crop": true 42 | } 43 | ``` 44 | 45 | # Custom metadata 46 | To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module to the dataset config. This metadata module should be a Python file that must contain a function called `get_custom_metadata` that takes in two parameters, `info`, and `audio`, and returns a dictionary. 47 | 48 | For local training, the `info` parameter will contain a few pieces of information about the loaded audio file, such as the path, and information about how the audio was cropped from the original training sample. For WebDataset datasets, it will also contain the metadata from the related JSON files. 49 | 50 | The `audio` parameter contains the audio sample that will be passed to the model at training time. This lets you analyze the audio for extra properties that you can then pass in as extra conditioning signals. 51 | 52 | The dictionary returned from the `get_custom_metadata` function will have its properties added to the `metadata` object used at training time. 
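As an illustration, a custom metadata module can compute extra properties directly from the audio tensor. The sketch below is illustrative only: `peak_level` is a hypothetical key name, and any key you return must match what your conditioning configuration expects.

```py
def get_custom_metadata(info, audio):
    # `audio` is the cropped training example (typically a torch.Tensor of shape [channels, samples]).
    # Derive a simple property from it and pass it along as extra metadata.
    peak_level = float(audio.abs().max())

    # "prompt" mirrors the example below; "peak_level" is a hypothetical extra key.
    return {
        "prompt": info["relpath"],
        "peak_level": peak_level,
    }
```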
For more information on how conditioning works, please see the [Conditioning documentation](./conditioning.md) 53 | 54 | ## Example config and custom metadata module 55 | ```json 56 | { 57 | "dataset_type": "audio_dir", 58 | "datasets": [ 59 | { 60 | "id": "my_audio", 61 | "path": "/path/to/audio/dataset/", 62 | "custom_metadata_module": "/path/to/custom_metadata.py", 63 | } 64 | ], 65 | "random_crop": true 66 | } 67 | ``` 68 | 69 | `custom_metadata.py`: 70 | ```py 71 | def get_custom_metadata(info, audio): 72 | 73 | # Pass in the relative path of the audio file as the prompt 74 | return {"prompt": info["relpath"]} 75 | ``` -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/training/losses/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # modified from stable-audio-tools under the MIT license 4 | 5 | import typing as tp 6 | 7 | from torch.nn import functional as F 8 | from torch import nn 9 | 10 | class LossModule(nn.Module): 11 | def __init__(self, name: str, weight: float = 1.0): 12 | super().__init__() 13 | 14 | self.name = name 15 | self.weight = weight 16 | 17 | def forward(self, info, *args, **kwargs): 18 | raise NotImplementedError 19 | 20 | class ValueLoss(LossModule): 21 | def __init__(self, key: str, name, weight: float = 1.0): 22 | super().__init__(name=name, weight=weight) 23 | 24 | self.key = key 25 | 26 | def forward(self, info): 27 | return self.weight * info[self.key] 28 | 29 | class L1Loss(LossModule): 30 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'): 31 | super().__init__(name=name, weight=weight) 32 | 33 | self.key_a = key_a 34 | self.key_b = key_b 35 | 36 | self.mask_key = mask_key 37 | 38 | def forward(self, info): 39 | mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none') 40 | 41 | if self.mask_key is not None and self.mask_key in info: 42 | mse_loss = mse_loss[info[self.mask_key]] 43 | 44 | mse_loss = mse_loss.mean() 45 | 46 | return self.weight * mse_loss 47 | 48 | class MSELoss(LossModule): 49 | def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'): 50 | super().__init__(name=name, weight=weight) 51 | 52 | self.key_a = key_a 53 | self.key_b = key_b 54 | 55 | self.mask_key = mask_key 56 | 57 | def forward(self, info): 58 | mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none') 59 | 60 | if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None: 61 | mask = info[self.mask_key] 62 | 63 | if mask.ndim == 2 and mse_loss.ndim == 3: 64 | mask = mask.unsqueeze(1) 65 | 66 | if mask.shape[1] != mse_loss.shape[1]: 67 | mask = mask.repeat(1, mse_loss.shape[1], 1) 68 | 69 | mse_loss = mse_loss[mask] 70 | 71 | mse_loss = mse_loss.mean() 72 | 73 | return self.weight * mse_loss 74 | 75 | 76 | class AuralossLoss(LossModule): 77 | def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1): 78 | super().__init__(name, weight) 79 | 80 | self.auraloss_module = auraloss_module 81 | 82 | self.input_key = input_key 83 | self.target_key = target_key 84 | 85 | def forward(self, info): 86 | loss = self.auraloss_module(info[self.input_key], info[self.target_key]) 87 | 88 | return self.weight * loss 89 | 90 | class MultiLoss(nn.Module): 91 | def __init__(self, losses: 
tp.List[LossModule]): 92 | super().__init__() 93 | 94 | self.losses = nn.ModuleList(losses) 95 | 96 | def forward( 97 | self, 98 | info, 99 | dynamic_weights = {} 100 | ): 101 | total_loss = 0 102 | 103 | losses = {} 104 | 105 | for loss_module in self.losses: 106 | module_loss = loss_module(info) 107 | 108 | # new impl for dynamic weighting of loss if passed in forward(). otherwise use original weight 109 | dynamic_weight = dynamic_weights.get(loss_module.name, None) 110 | if dynamic_weight is not None: 111 | total_loss += dynamic_weight * module_loss 112 | else: 113 | total_loss += module_loss 114 | 115 | losses[loss_module.name] = module_loss 116 | 117 | return total_loss, losses 118 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .DS_Store 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 
110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | 163 | .vscode/ 164 | tmp/ 165 | *.ckpt 166 | *.wav 167 | wandb/* 168 | .gradio/ 169 | 170 | stable_audio_tools/configs/dataset_configs/*filenames_cache*/* 171 | -------------------------------------------------------------------------------- /A2SB/inference/A2SB_upsample_api.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | 10 | # # If there is Error: mkl-service + Intel(R) MKL: MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library. 
11 | # os.environ["MKL_THREADING_LAYER"] = "GNU" 12 | # import numpy as np 13 | # os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 14 | 15 | import numpy as np 16 | import json 17 | import argparse 18 | import glob 19 | from subprocess import Popen, PIPE 20 | import yaml 21 | import time 22 | from datetime import datetime 23 | import shutil 24 | import csv 25 | from tqdm import tqdm 26 | 27 | import librosa 28 | import soundfile as sf 29 | 30 | 31 | def load_yaml(file_path): 32 | with open(file_path, 'r') as file: 33 | data = yaml.safe_load(file) 34 | return data 35 | 36 | 37 | def save_yaml(data, prefix="../configs/temp"): 38 | os.makedirs(os.path.dirname(prefix), exist_ok=True) 39 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 40 | rnd_num = np.random.rand() 41 | rnd_num = rnd_num - rnd_num % 0.000001 42 | file_name = f"{prefix}_{timestamp}_{rnd_num}.yaml" 43 | with open(file_name, 'w') as f: 44 | yaml.dump(data, f) 45 | return file_name 46 | 47 | 48 | def shell_run_cmd(cmd): 49 | print('running:', cmd) 50 | p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True) 51 | stdout, stderr = p.communicate() 52 | print(stdout) 53 | print(stderr) 54 | 55 | 56 | def compute_rolloff_freq(audio_file, roll_percent=0.99): 57 | y, sr = librosa.load(audio_file, sr=None) 58 | rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr, roll_percent=roll_percent)[0] 59 | rolloff = int(np.mean(rolloff)) 60 | print('99 percent rolloff:', rolloff) 61 | return rolloff 62 | 63 | 64 | def upsample_one_sample(audio_filename, output_audio_filename, predict_n_steps=50): 65 | 66 | assert output_audio_filename != audio_filename, "output filename cannot be input filename" 67 | 68 | inference_config = load_yaml('../configs/inference_files_upsampling.yaml') 69 | inference_config['data']['predict_filelist'] = [{ 70 | 'filepath': audio_filename, 71 | 'output_subdir': '.' 72 | }] 73 | 74 | cutoff_freq = compute_rolloff_freq(audio_filename, roll_percent=0.99) 75 | inference_config['data']['transforms_aug'][0]['init_args']['upsample_mask_kwargs'] = { 76 | 'min_cutoff_freq': cutoff_freq, 77 | 'max_cutoff_freq': cutoff_freq 78 | } 79 | temporary_yaml_file = save_yaml(inference_config) 80 | 81 | cmd = "cd ../; \ 82 | python ensembled_inference_api.py predict \ 83 | -c configs/ensemble_2split_sampling.yaml \ 84 | -c {} \ 85 | --model.predict_n_steps={} \ 86 | --model.output_audio_filename={}; \ 87 | cd inference/".format(temporary_yaml_file.replace('../', ''), predict_n_steps, output_audio_filename) 88 | shell_run_cmd(cmd) 89 | 90 | os.remove(temporary_yaml_file) 91 | 92 | 93 | def main(): 94 | parser = argparse.ArgumentParser(description='Description of your program') 95 | parser.add_argument('-f','--audio_filename', type=str, help='audio filename to be upsampled', required=True) 96 | parser.add_argument('-o','--output_audio_filename', type=str, help='path to save upsampled audio', required=True) 97 | parser.add_argument('-n','--predict_n_steps', type=int, help='number of sampling steps', default=50) 98 | args = parser.parse_args() 99 | 100 | upsample_one_sample(audio_filename=args.audio_filename, output_audio_filename=args.output_audio_filename, predict_n_steps=args.predict_n_steps) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | 106 | # python A2SB_upsample_api.py -f -o -n 107 | 108 | -------------------------------------------------------------------------------- /A2SB/LICENSE: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code License for A2SB 2 | 3 | 1. 
Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Software” means the original work of authorship made available under this License. 8 | 9 | “Work” means the Software and any additions to or derivative works of the Software that are made available under 10 | this License. 11 | 12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under 13 | U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include 14 | works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 15 | 16 | Works, including the Software, are “made available” under this License by including in or with the Work either 17 | (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 18 | 19 | 2. License Grant 20 | 21 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, 22 | worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly 23 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 24 | 25 | 3. Limitations 26 | 27 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you 28 | include a complete copy of this License with your distribution, and (c) you retain without modification any 29 | copyright, patent, trademark, or attribution notices that are present in the Work. 30 | 31 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and 32 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use 33 | limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works 34 | that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution 35 | requirements in Section 3.1) will continue to apply to the Work itself. 36 | 37 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use 38 | non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative 39 | works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 40 | 41 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, 42 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then 43 | your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 44 | 45 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, 46 | or trademarks, except as necessary to reproduce the notices described in this License. 47 | 48 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the 49 | grant in Section 2.1) will terminate immediately. 50 | 51 | 4. Disclaimer of Warranty. 52 | 53 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING 54 | WARRANTIES OR CONDITIONS OF M ERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. 
YOU 55 | BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 56 | 57 | 5. Limitation of Liability. 58 | 59 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING 60 | NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 61 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR 62 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR 63 | DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMM ERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN 64 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/configs/model_configs/autoencoders/etta_vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "autoencoder", 3 | "sample_size": 65536, 4 | "sample_rate": 44100, 5 | "audio_channels": 2, 6 | "precision": "32-true", 7 | "model": { 8 | "encoder": { 9 | "type": "oobleck", 10 | "config": { 11 | "in_channels": 2, 12 | "channels": 128, 13 | "c_mults": [1, 2, 4, 8, 16], 14 | "strides": [2, 4, 4, 8, 8], 15 | "latent_dim": 128, 16 | "use_snake": true, 17 | "causal": false, 18 | "padding_mode": "zeros" 19 | } 20 | }, 21 | "decoder": { 22 | "type": "oobleck", 23 | "config": { 24 | "out_channels": 2, 25 | "channels": 128, 26 | "c_mults": [1, 2, 4, 8, 16], 27 | "strides": [2, 4, 4, 8, 8], 28 | "latent_dim": 64, 29 | "use_snake": true, 30 | "final_tanh": false, 31 | "causal": false, 32 | "padding_mode": "zeros" 33 | } 34 | }, 35 | "bottleneck": { 36 | "type": "vae" 37 | }, 38 | "latent_dim": 64, 39 | "downsampling_ratio": 2048, 40 | "io_channels": 2 41 | }, 42 | "training": { 43 | "max_steps": 2800000, 44 | "learning_rate": 1.5e-4, 45 | "gradient_clip_val": 500, 46 | "warmup_steps": 0, 47 | "use_ema": true, 48 | "optimizer_configs": { 49 | "autoencoder": { 50 | "optimizer": { 51 | "type": "AdamW", 52 | "config": { 53 | "betas": [0.8, 0.99], 54 | "lr": 1.5e-4, 55 | "weight_decay": 1e-3 56 | } 57 | }, 58 | "scheduler": { 59 | "type": "InverseLR", 60 | "config": { 61 | "inv_gamma": 200000, 62 | "power": 0.5, 63 | "warmup": 0.999 64 | } 65 | } 66 | }, 67 | "discriminator": { 68 | "optimizer": { 69 | "type": "AdamW", 70 | "config": { 71 | "betas": [0.8, 0.99], 72 | "lr": 3.0e-4, 73 | "weight_decay": 1e-3 74 | } 75 | }, 76 | "scheduler": { 77 | "type": "InverseLR", 78 | "config": { 79 | "inv_gamma": 200000, 80 | "power": 0.5, 81 | "warmup": 0.999 82 | } 83 | } 84 | } 85 | }, 86 | "loss_configs": { 87 | "discriminator": { 88 | "type": "encodec", 89 | "config": { 90 | "filters": 64, 91 | "n_ffts": [2048, 1024, 512, 256, 128], 92 | "hop_lengths": [512, 256, 128, 64, 32], 93 | "win_lengths": [2048, 1024, 512, 256, 128] 94 | }, 95 | "weights": { 96 | "adversarial": 0.1, 97 | "feature_matching": 5.0 98 | } 99 | }, 100 | "spectral": { 101 | "type": "mrstft", 102 | "config": { 103 | "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], 104 | "hop_sizes": [512, 256, 128, 64, 32, 16, 8], 105 | "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], 106 | "perceptual_weighting": true 107 | }, 108 | "weights": { 109 | "mrstft": 1.0 110 | } 111 | }, 112 | "time": { 113 | "type": "l1", 114 | "weights": { 115 | "l1": 0.0 116 | } 117 | }, 118 | "bottleneck": { 119 | "type": "kl", 120 | "weights": { 121 | "kl": 1e-4 122 | } 123 | } 124 | }, 125 
| "demo": { 126 | "demo_every": 10000 127 | } 128 | } 129 | } -------------------------------------------------------------------------------- /A2SB/configs/pretrain.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.0.8 2 | seed_everything: true 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp 6 | devices: 8 7 | num_nodes: 4 8 | precision: bf16-mixed 9 | logger: null 10 | callbacks: null 11 | fast_dev_run: false 12 | max_epochs: null 13 | min_epochs: null 14 | max_steps: -1 15 | min_steps: null 16 | max_time: null 17 | limit_train_batches: null 18 | limit_val_batches: null 19 | limit_test_batches: null 20 | limit_predict_batches: null 21 | overfit_batches: 0.0 22 | val_check_interval: 1000 23 | check_val_every_n_epoch: 1 24 | num_sanity_val_steps: null 25 | log_every_n_steps: null 26 | enable_checkpointing: null 27 | enable_progress_bar: null 28 | enable_model_summary: null 29 | accumulate_grad_batches: 1 30 | gradient_clip_val: 0.5 31 | gradient_clip_algorithm: norm 32 | deterministic: null 33 | benchmark: null 34 | inference_mode: true 35 | use_distributed_sampler: true 36 | profiler: null 37 | detect_anomaly: false 38 | barebones: false 39 | plugins: 40 | - class_path: lightning.fabric.plugins.environments.SLURMEnvironment 41 | sync_batchnorm: false 42 | reload_dataloaders_every_n_epochs: 0 43 | default_root_dir: null 44 | 45 | 46 | model: 47 | vf_model: 48 | class_path: networks.AttnUNetF 49 | init_args: 50 | n_updown_levels: 5 51 | in_channels: 3 52 | hidden_channels: [128, 256, 512, 768, 1024, 2048] 53 | out_channels: 3 54 | emb_channels: 128 55 | band_embedding_dim: 16 56 | n_attn_heads: 8 57 | attention_levels: 58 | - 3 59 | - 4 60 | use_attn_input_norm: true 61 | num_res_blocks: 2 62 | inv_transforms: 63 | - class_path: audio_transforms.transforms.PowerScaleSpectrogram 64 | init_args: 65 | power: 4 66 | channels: 67 | - 0 68 | - class_path: audio_transforms.transforms.SpectrogramAddDCTerm 69 | - class_path: audio_transforms.transforms.MagInstPhaseToComplex 70 | - class_path: audio_transforms.transforms.InverseComplexSpectrogram 71 | init_args: 72 | n_fft: 2048 73 | win_length: 2048 74 | hop_length: 512 75 | sampling_rate: 44100 76 | n_timestep_channels: 128 77 | use_ot_ode: false 78 | beta_max: 1.0 79 | learning_rate: 0.00008 80 | 81 | data: 82 | mix_dataset_config: 83 | TRAIN_DATASET_NAME: 84 | root_folder: PATH/TO/MANIFEST/FOLDER 85 | filename: 'TRAIN_DATASET_NAME_manifest.csv' 86 | TRAIN_DATASET_NAME_TWO: 87 | root_folder: PATH/TO/MANIFEST/FOLDER 88 | filename: 'TRAIN_DATASET_NAME_TWO_manifest.csv' 89 | 90 | 91 | segment_length: 130560 92 | sampling_rate: 44100 93 | num_workers: 23 94 | batch_size: 10 95 | val_max_samples: 256 96 | 97 | transforms_aug: 98 | - class_path: corruption.corruptions.MultinomialInpaintMaskTransform 99 | init_args: 100 | p_upsample_mask: 0.5 101 | p_extension_mask: 0.0 102 | p_inpaint_mask: 0.5 103 | fill_noise_level: 0.5 104 | sampling_rate: 44100 105 | upsample_mask_kwargs: 106 | min_cutoff_freq: 2000 107 | max_cutoff_freq: 16000 108 | inpainting_mask_kwargs: 109 | min_inpainting_frac: 0.03378 110 | max_inpainting_frac: 0.5404 111 | is_random: true 112 | 113 | eval_transforms_aug: 114 | - class_path: corruption.corruptions.MultinomialInpaintMaskTransform 115 | init_args: 116 | p_upsample_mask: 0.5 117 | p_extension_mask: 0.0 118 | p_inpaint_mask: 0.5 119 | fill_noise_level: 0.5 120 | sampling_rate: 44100 121 | upsample_mask_kwargs: 122 | min_cutoff_freq: 4000 123 | max_cutoff_freq: 
4000 124 | inpainting_mask_kwargs: 125 | min_inpainting_frac: 0.1013 126 | max_inpainting_frac: 0.1013 127 | is_random: false 128 | 129 | transforms_gt: 130 | - class_path: audio_transforms.transforms.ComplexSpectrogram 131 | init_args: 132 | n_fft: 2048 133 | win_length: 2048 134 | hop_length: 512 135 | - class_path: audio_transforms.transforms.ComplexToMagInstPhase 136 | - class_path: audio_transforms.transforms.SpectrogramDropDCTerm 137 | - class_path: audio_transforms.transforms.PowerScaleSpectrogram 138 | init_args: 139 | power: 0.25 140 | channels: 141 | - 0 142 | 143 | checkpoint_callback: 144 | dirpath: exp/pretrain 145 | every_n_train_steps: 1000 146 | -------------------------------------------------------------------------------- /A2SB/configs/t_finetune_2split_0.0_0.5.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.0.8 2 | seed_everything: true 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp 6 | devices: 8 7 | num_nodes: 2 8 | precision: 32-true 9 | logger: null 10 | callbacks: null 11 | fast_dev_run: false 12 | max_epochs: null 13 | min_epochs: null 14 | max_steps: -1 15 | min_steps: null 16 | max_time: null 17 | limit_train_batches: null 18 | limit_val_batches: null 19 | limit_test_batches: null 20 | limit_predict_batches: null 21 | overfit_batches: 0.0 22 | val_check_interval: 1000 23 | check_val_every_n_epoch: 1 24 | num_sanity_val_steps: null 25 | log_every_n_steps: null 26 | enable_checkpointing: null 27 | enable_progress_bar: null 28 | enable_model_summary: null 29 | accumulate_grad_batches: 1 30 | gradient_clip_val: 0.5 31 | gradient_clip_algorithm: norm 32 | deterministic: null 33 | benchmark: null 34 | inference_mode: true 35 | use_distributed_sampler: true 36 | profiler: null 37 | detect_anomaly: false 38 | barebones: false 39 | plugins: 40 | - class_path: lightning.fabric.plugins.environments.SLURMEnvironment 41 | sync_batchnorm: false 42 | reload_dataloaders_every_n_epochs: 0 43 | default_root_dir: null 44 | 45 | 46 | model: 47 | train_t_min: 0.0 48 | train_t_max: 0.5 49 | vf_model: 50 | class_path: networks.AttnUNetF 51 | init_args: 52 | n_updown_levels: 5 53 | in_channels: 3 54 | hidden_channels: [128, 256, 512, 768, 1024, 2048] 55 | out_channels: 3 56 | emb_channels: 128 57 | band_embedding_dim: 16 58 | n_attn_heads: 8 59 | attention_levels: 60 | - 3 61 | - 4 62 | use_attn_input_norm: true 63 | num_res_blocks: 2 64 | inv_transforms: 65 | - class_path: audio_transforms.transforms.PowerScaleSpectrogram 66 | init_args: 67 | power: 4 68 | channels: 69 | - 0 70 | - class_path: audio_transforms.transforms.SpectrogramAddDCTerm 71 | - class_path: audio_transforms.transforms.MagInstPhaseToComplex 72 | - class_path: audio_transforms.transforms.InverseComplexSpectrogram 73 | init_args: 74 | n_fft: 2048 75 | win_length: 2048 76 | hop_length: 512 77 | sampling_rate: 44100 78 | n_timestep_channels: 128 79 | use_ot_ode: false 80 | beta_max: 1.0 81 | learning_rate: 0.00008 82 | 83 | data: 84 | mix_dataset_config: 85 | TRAIN_DATASET_NAME: 86 | root_folder: PATH/TO/MANIFEST/FOLDER 87 | filename: 'TRAIN_DATASET_NAME_manifest.csv' 88 | TRAIN_DATASET_NAME_TWO: 89 | root_folder: PATH/TO/MANIFEST/FOLDER 90 | filename: 'TRAIN_DATASET_NAME_TWO_manifest.csv' 91 | 92 | segment_length: 130560 93 | sampling_rate: 44100 94 | num_workers: 23 95 | batch_size: 4 96 | val_max_samples: 256 97 | 98 | transforms_aug: 99 | - class_path: corruption.corruptions.MultinomialInpaintMaskTransform 100 | init_args: 101 | p_upsample_mask: 
0.5 102 | p_extension_mask: 0.0 103 | p_inpaint_mask: 0.5 104 | fill_noise_level: 0.5 105 | sampling_rate: 44100 106 | upsample_mask_kwargs: 107 | min_cutoff_freq: 2000 108 | max_cutoff_freq: 16000 109 | inpainting_mask_kwargs: 110 | min_inpainting_frac: 0.03378 111 | max_inpainting_frac: 0.5404 112 | is_random: true 113 | 114 | eval_transforms_aug: 115 | - class_path: corruption.corruptions.MultinomialInpaintMaskTransform 116 | init_args: 117 | p_upsample_mask: 0.5 118 | p_extension_mask: 0.0 119 | p_inpaint_mask: 0.5 120 | fill_noise_level: 0.5 121 | sampling_rate: 44100 122 | upsample_mask_kwargs: 123 | min_cutoff_freq: 4000 124 | max_cutoff_freq: 4000 125 | inpainting_mask_kwargs: 126 | min_inpainting_frac: 0.1013 127 | max_inpainting_frac: 0.1013 128 | is_random: false 129 | 130 | transforms_gt: 131 | - class_path: audio_transforms.transforms.ComplexSpectrogram 132 | init_args: 133 | n_fft: 2048 134 | win_length: 2048 135 | hop_length: 512 136 | - class_path: audio_transforms.transforms.ComplexToMagInstPhase 137 | - class_path: audio_transforms.transforms.SpectrogramDropDCTerm 138 | - class_path: audio_transforms.transforms.PowerScaleSpectrogram 139 | init_args: 140 | power: 0.25 141 | channels: 142 | - 0 143 | 144 | checkpoint_callback: 145 | dirpath: exp/t_finetune_0.0_0.5 146 | every_n_train_steps: 1000 147 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # modified from stable-audio-tools under the MIT license 4 | 5 | import torch 6 | from safetensors.torch import load_file 7 | 8 | from torch.nn.utils import remove_weight_norm 9 | from torch.nn.utils.parametrize import remove_parametrizations 10 | 11 | def load_ckpt_state_dict(ckpt_path): 12 | if ckpt_path.endswith(".safetensors"): 13 | state_dict = load_file(ckpt_path) 14 | else: 15 | state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] 16 | 17 | return state_dict 18 | 19 | def remove_weight_norm_from_model(model): 20 | print(f"INFO: Removing all weight norm from model") 21 | for module in model.modules(): 22 | if hasattr(module, "parametrizations"): # for new WN implementation using parameterizations 23 | # print(f"Removing weight norm (parameterizations) from {module}") 24 | try: 25 | remove_parametrizations(module, "weight") 26 | except ValueError: 27 | print(f"[WARNING] No weight norm found in {module} with parameterizations. You can ignore this if you know that this module does not apply weight norm.") 28 | elif hasattr(module, "weight"): 29 | # print(f"Removing weight norm (legacy) from {module}") 30 | try: 31 | remove_weight_norm(module) 32 | except ValueError: 33 | print(f"[WARNING] No weight norm found in {module} with legacy method. You can ignore this if you know that this module does not apply weight norm.") 34 | 35 | return model 36 | 37 | # Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license 38 | # License can be found in LICENSES/LICENSE_META.txt 39 | 40 | def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): 41 | """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. 42 | 43 | Args: 44 | input (torch.Tensor): The input tensor containing probabilities. 
45 | num_samples (int): Number of samples to draw. 46 | replacement (bool): Whether to draw with replacement or not. 47 | Keywords args: 48 | generator (torch.Generator): A pseudorandom number generator for sampling. 49 | Returns: 50 | torch.Tensor: Last dimension contains num_samples indices 51 | sampled from the multinomial probability distribution 52 | located in the last dimension of tensor input. 53 | """ 54 | 55 | if num_samples == 1: 56 | q = torch.empty_like(input).exponential_(1, generator=generator) 57 | return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) 58 | 59 | input_ = input.reshape(-1, input.shape[-1]) 60 | output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) 61 | output = output_.reshape(*list(input.shape[:-1]), -1) 62 | return output 63 | 64 | 65 | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: 66 | """Sample next token from top K values along the last dimension of the input probs tensor. 67 | 68 | Args: 69 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 70 | k (int): The k in “top-k”. 71 | Returns: 72 | torch.Tensor: Sampled tokens. 73 | """ 74 | top_k_value, _ = torch.topk(probs, k, dim=-1) 75 | min_value_top_k = top_k_value[..., [-1]] 76 | probs *= (probs >= min_value_top_k).float() 77 | probs.div_(probs.sum(dim=-1, keepdim=True)) 78 | next_token = multinomial(probs, num_samples=1) 79 | return next_token 80 | 81 | 82 | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: 83 | """Sample next token from top P probabilities along the last dimension of the input probs tensor. 84 | 85 | Args: 86 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 87 | p (int): The p in “top-p”. 88 | Returns: 89 | torch.Tensor: Sampled tokens. 90 | """ 91 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 92 | probs_sum = torch.cumsum(probs_sort, dim=-1) 93 | mask = probs_sum - probs_sort > p 94 | probs_sort *= (~mask).float() 95 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 96 | next_token = multinomial(probs_sort, num_samples=1) 97 | next_token = torch.gather(probs_idx, -1, next_token) 98 | return next_token 99 | 100 | def next_power_of_two(n): 101 | return 2 ** (n - 1).bit_length() 102 | 103 | def next_multiple_of_64(n): 104 | return ((n + 63) // 64) * 64 -------------------------------------------------------------------------------- /ETTA/README.md: -------------------------------------------------------------------------------- 1 | # ETTA: Elucidating the Design Space of Text-to-Audio Models 2 | 3 | #### **Sang-gil Lee***, **Zhifeng Kong***, Arushi Goel, Sungwon Kim, Rafael Valle, Bryan Catanzaro (*Equal contribution.) 4 | 5 |
6 | 7 | 8 | 9 | 10 |
11 | 12 | 13 |
14 |
15 |
16 | 17 |
18 | 19 | 20 | ## Overview 21 | 22 | This repository contains model, training, and inference code implementation of [ETTA: Elucidating the Design Space of Text-to-Audio Models](https://arxiv.org/abs/2412.19351) (ICML 2025): 23 | 24 | * Synthetic audio caption generation pipeline is built on top of [Audio Flamingo](https://github.com/NVIDIA/audio-flamingo) from NVIDIA. See ```AFSynthetic/README.md``` for more details. 25 | 26 | * Text-to-audio model is built on top of [`stable-audio-tools`](https://github.com/Stability-AI/stable-audio-tools) from Stability AI. 27 | 28 | ## Installation 29 | 30 | The codebase has been tested on Python `3.11` and PyTorch `>=2.6.0`. Below is an example command to create a conda environment: 31 | 32 | ```shell 33 | # create and activate a conda environment 34 | conda create -n etta python=3.11 35 | conda activate etta 36 | # install pytorch 37 | pip install torch torchvision torchaudio 38 | # install flash-attn 39 | pip install flash-attn --no-build-isolation 40 | # install this repository in editable mode 41 | pip install -e . 42 | ``` 43 | 44 | Alternatively, we also provide an example `Dockerfile` based on an official PyTorch image to run the model on a Docker container. 45 | ```shell 46 | docker build --progress plain -t etta:latest . 47 | docker run --gpus all etta:latest 48 | ``` 49 | 50 | ## Inference Examples 51 | 52 | Below is an example ETTA inference script on a single GPU: 53 | ``` 54 | CUDA_VISIBLE_DEVICES=0 python inference_tta.py \ 55 | --text_prompt "A hip-hop track using sounds from a construction site—hammering nails as the beat, drilling sounds as scratches, and metal clanks as rhythm accents." "A saxophone that sounds like meowing of cat." \ 56 | --output_dir ./tmp \ 57 | --model_ckpt_path /path/to/your/etta/model_unwrap.ckpt \ 58 | --target_sample_rate 44100 \ 59 | --sampler_type euler \ 60 | --steps 100 \ 61 | --cfg_scale 3.5 \ 62 | --seconds_start 0 \ 63 | --seconds_total 10 \ 64 | --batch_size 4 65 | ``` 66 | 67 | 68 | ## Training Examples 69 | 70 | Below is an example command to train ETTA-VAE on 8 GPUs: 71 | ``` 72 | NUM_GPUS=8 && \ 73 | torchrun --nproc_per_node=$NUM_GPUS train.py \ 74 | --name DEBUG_etta_vae \ 75 | --dataset_config stable_audio_tools/configs/dataset_configs/etta_vae_training_example.json \ 76 | --model_config stable_audio_tools/configs/model_configs/autoencoders/etta_vae.json \ 77 | --save_dir tmp --ckpt_path last \ 78 | --enable_progress_bar true \ 79 | --seed 2025 \ 80 | --num_gpus $NUM_GPUS \ 81 | --batch_size 8 \ 82 | --params \ 83 | training.max_steps=2800000 \ 84 | training.loss_configs.bottleneck.weights.kl=0.0001 85 | ``` 86 | 87 | Below is an example command to train ETTA-DiT on 8 GPUs: 88 | ``` 89 | NUM_GPUS=8 && \ 90 | torchrun --nproc_per_node=$NUM_GPUS train.py \ 91 | --name DEBUG_etta_dit \ 92 | --dataset_config stable_audio_tools/configs/dataset_configs/etta_dit_training_example.json \ 93 | --model_config stable_audio_tools/configs/model_configs/txt2audio/etta_dit.json \ 94 | --save_dir tmp --ckpt_path last \ 95 | --enable_progress_bar true \ 96 | --seed 2025 \ 97 | --num_gpus $NUM_GPUS \ 98 | --batch_size 8 \ 99 | --params \ 100 | pretransform_ckpt_path=/path/to/etta_vae/model_unwrap_step_2800000.ckpt \ 101 | model.diffusion.config.depth=24 \ 102 | training.max_steps=1000000 103 | ``` 104 | 105 | Below is an example command to unwrap a trained model into `model_unwrap.ckpt`: 106 | ``` 107 | CKPT_DIR=/path/to/your/etta && 108 | python unwrap_model.py \ 109 | --model-config $CKPT_DIR/config.json \ 110 | 
--ckpt-path $CKPT_DIR/epoch=x-step=xxxxxx.ckpt \ 111 | --name $CKPT_DIR/model_unwrap 112 | ``` 113 | 114 | 115 | ## Citation 116 | ```bibtex 117 | @article{lee2024etta, 118 | title={ETTA: Elucidating the Design Space of Text-to-Audio Models}, 119 | author={Lee, Sang-gil and Kong, Zhifeng and Goel, Arushi and Kim, Sungwon and Valle, Rafael and Catanzaro, Bryan}, 120 | journal={arXiv preprint arXiv:2412.19351}, 121 | year={2024} 122 | } 123 | ``` 124 | 125 | ## Reference 126 | For more detail in `stable-audio-tools` from which we build ETTA upon, see [original README.md of stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools/blob/main/README.md). 127 | -------------------------------------------------------------------------------- /A2SB/modelcard.md: -------------------------------------------------------------------------------- 1 | # Model Overview 2 | 3 | ## Description: 4 | A2SB uses a UNet architecture to perform inpainting on an audio spectrogram. It can fill in missing frequency bands above 4kHz (bandwidth extension), or fill in short temporal slices (currently supporting filling in gaps of less than 1 second). This model is for non commercial use only. 5 | 6 | ### License/Terms of Use: 7 | The model is provided under the NVIDIA OneWay NonCommercial License. 8 | 9 | The code is under [NVIDIA Source Code License - Non Commercial](https://github.com/NVlabs/I2SB/blob/master/LICENSE). Some components are adapted from other sources. The training code is adapted from [I2SB](https://github.com/NVlabs/I2SB) under the [NVIDIA Source Code License - Non Commercial](https://github.com/NVlabs/I2SB/blob/master/LICENSE). The model architecture is adapted from [Improved Diffusion](https://github.com/openai/improved-diffusion/blob/main/LICENSE) under the MIT License. 10 | 11 | ### Deployment Geography: 12 | Global 13 | 14 | ### Use Case: 15 | Research purposes pertaining to audio enhancement and generative modeling, as well as for general creative use such as bandwidth extension and inpainting short segments of missing audio. 16 | 17 | ### Release Date: 18 | Github 06/27/2025 via github.com/NVIDIA/diffusion-audio-restoration 19 | 20 | ## Reference(s): 21 | - [project page](https://research.nvidia.com/labs/adlr/A2SB) 22 | - [technical report](https://arxiv.org/abs/2501.11311) 23 | - [I2SB](https://github.com/NVlabs/I2SB) 24 | - [Improved-Diffusion UNet Architecture](https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/unet.py) 25 | 26 | 27 | ## Model Architecture: 28 | **Architecture Type:** CNN with interleaved Self-Attention Layers 29 | 30 | **Network Architecture:** UNET 31 | 32 | 33 | 34 | ## Input: 35 | **Input Type(s):** Audio 36 | 37 | **Input Format(s):** WAV/MP3/FLAC 38 | 39 | **Input Parameters:** One-Dimensional (1D) 40 | 41 | **Other Properties Related to Input:** All audio assumed to be single-channeled, 44.1kHz. For editing, also provide frequency cutoff for bandwidth extension sampling (resample content above this frequency), or start/end time stamps for segment inpainting. 42 | 43 | ## Output: 44 | **Output Type(s):** Audio 45 | 46 | **Output Format(s):** WAV 47 | 48 | **Output Parameters:** One-Dimensional (1D) 49 | 50 | **Other Properties Related to Output:** Single-channeled 44.1kHz output file. Maximum audio output length is 1 hour. 51 | 52 | Our AI models are designed and/or optimized to run on NVIDIA GPU-accelerated systems. By leveraging NVIDIA’s hardware (e.g. 
GPU cores) and software frameworks (e.g., CUDA libraries), the model achieves faster training and inference times compared to CPU-only solutions. 53 | 54 | ## Software Integration: 55 | **Runtime Engine(s):** 56 | * [PyTorch-2.2.2+cuda12.1+cudnn8] 57 | 58 | 59 | **Supported Hardware Microarchitecture Compatibility:** 60 | * NVIDIA Ampere 61 | * NVIDIA Blackwell 62 | * NVIDIA Jetson 63 | * NVIDIA Hopper 64 | * NVIDIA Lovelace 65 | * NVIDIA Pascal 66 | * NVIDIA Turing 67 | * NVIDIA Volta 68 | 69 | 70 | **[Preferred/Supported] Operating System(s):** 71 | ['Linux'] 72 | 73 | ## Model Versions: 74 | v1 75 | 76 | # Training and Evaluation Datasets: 77 | 78 | ## Training Datasets: 79 | 80 | The property column below shows the total duration before license, quality, and sampling rate filtering. Our model training code ingests only raw audio samples -- no additional labels provided in the datasets listed below are used for training purposes. 81 | 82 | | DatasetName | Collection Method | Labeling Method | Properties | 83 | | ------ | ------ | ------ | ------ | 84 | | [FMA](https://github.com/mdeff/fma) | Human | N/A | 5257.0 hrs | 85 | | [Medleys-solos-DB](https://medleydb.weebly.com/) | Human | N/A | 17.8 hrs| 86 | | [MUSAN](https://www.openslr.org/17/) | Human | N/A | 42.6 hrs | 87 | | [Musical Instrument](https://www.kaggle.com/datasets/soumendraprasad/musical-instruments-sound-dataset) | Human| N/A | 16.2 hrs | 88 | | [MusicNet](https://zenodo.org/records/5120004) | Human | N/A | 34.5 hrs | 89 | | [Slakh](https://github.com/ethman/slakh-utils) | Hybrid | N/A | 118.3 hrs| 90 | | [FreeSound](https://freesound.org/) | Human | N/A | 4576.6 hrs| 91 | | [FSD50K](https://zenodo.org/records/4060432) | Human | N/A | 75.6 hrs| 92 | | [GTZAN](http://marsyas.info/index.html) | Human | N/A | 8.3 hrs| 93 | | [NSynth](https://magenta.tensorflow.org/datasets/nsynth) | Human | N/A | 340.0 hrs| 94 | 95 | 96 | ## Evaluation Datasets: 97 | | DatasetName | Collection Method | Labeling Method | Properties | 98 | | ------ | ------ | ------ | ------ | 99 | | [AAM: Artificial Audio Multitracks Dataset](https://zenodo.org/records/5794629) | Automated | N/A | 4 hrs | 100 | | [Maestro](https://magenta.tensorflow.org/datasets/maestro) | Human | N/A | 199.2 hrs | 101 | | [MTD](https://www.audiolabs-erlangen.de/resources/MIR/MTD) | Human | N/A | 0.9 hrs | 102 | | [CC-Mixter](https://members.loria.fr/ALiutkus/kam/) | Human | N/A | 3.2 hrs | 103 | 104 | 105 | ## Inference: 106 | **Engine:** PyTorch 107 | 108 | **Test Hardware:** 109 | * NVIDIA Ampere 110 | 111 | ## Ethical Considerations: 112 | NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. 113 | 114 | Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/). 
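As a practical illustration of the input constraints listed above (single-channel audio at 44.1 kHz, read from WAV/MP3/FLAC), the sketch below shows one way an arbitrary file could be converted to that format before restoration. It is illustrative only and not part of the released inference code; the output filename is an assumption.

```
# Illustrative preprocessing sketch (not part of the A2SB release): convert an
# arbitrary audio file into the single-channel, 44.1 kHz format described above.
import librosa
import soundfile as sf

def prepare_input(in_path: str, out_path: str = "input_44k_mono.wav") -> str:
    # librosa.load downmixes to mono and resamples when sr is given explicitly
    audio, _ = librosa.load(in_path, sr=44100, mono=True)
    sf.write(out_path, audio, 44100)
    return out_path
```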
-------------------------------------------------------------------------------- /ETTA/stable_audio_tools/configs/model_configs/txt2audio/etta_dit.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "diffusion_cond", 3 | "sample_size": 441000, 4 | "sample_rate": 44100, 5 | "audio_channels": 2, 6 | "precision": "bf16-mixed", 7 | "model": { 8 | "pretransform": { 9 | "type": "autoencoder", 10 | "iterate_batch": true, 11 | "config": { 12 | "encoder": { 13 | "type": "oobleck", 14 | "config": { 15 | "in_channels": 2, 16 | "channels": 128, 17 | "c_mults": [1, 2, 4, 8, 16], 18 | "strides": [2, 4, 4, 8, 8], 19 | "latent_dim": 128, 20 | "use_snake": true 21 | } 22 | }, 23 | "decoder": { 24 | "type": "oobleck", 25 | "config": { 26 | "out_channels": 2, 27 | "channels": 128, 28 | "c_mults": [1, 2, 4, 8, 16], 29 | "strides": [2, 4, 4, 8, 8], 30 | "latent_dim": 64, 31 | "use_snake": true, 32 | "final_tanh": false 33 | } 34 | }, 35 | "bottleneck": { 36 | "type": "vae" 37 | }, 38 | "latent_dim": 64, 39 | "downsampling_ratio": 2048, 40 | "io_channels": 2 41 | } 42 | }, 43 | "conditioning": { 44 | "configs": [ 45 | { 46 | "id": "prompt", 47 | "type": "t5", 48 | "config": { 49 | "t5_model_name": "t5-base", 50 | "max_length": 512 51 | } 52 | }, 53 | { 54 | "id": "seconds_start", 55 | "type": "number", 56 | "config": { 57 | "min_val": 0, 58 | "max_val": 512 59 | } 60 | }, 61 | { 62 | "id": "seconds_total", 63 | "type": "number", 64 | "config": { 65 | "min_val": 0, 66 | "max_val": 512 67 | } 68 | } 69 | ], 70 | "cond_dim": 768 71 | }, 72 | "diffusion": { 73 | "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], 74 | "global_cond_ids": ["seconds_start", "seconds_total"], 75 | "type": "dit", 76 | "diffusion_objective": "rectified_flow", 77 | "config": { 78 | "io_channels": 64, 79 | "embed_dim": 1536, 80 | "depth": 24, 81 | "num_heads": 24, 82 | "cond_token_dim": 768, 83 | "global_cond_dim": 1536, 84 | "project_cond_tokens": false, 85 | "transformer_type": "etta_transformer", 86 | "global_cond_type": "adaLN", 87 | "is_causal": false, 88 | "pos_emb_name": "rope", 89 | "rope_base": 16384, 90 | "use_flash_attention": true, 91 | "deterministic": false, 92 | "p_dropout": 0.1, 93 | "apply_norm_to_cond": true, 94 | "kernel_size": 1, 95 | "context_xattn": { 96 | "n_heads": 16, 97 | "d_heads": 768, 98 | "pos_emb_name": "" 99 | } 100 | } 101 | }, 102 | "io_channels": 64 103 | }, 104 | "training": { 105 | "max_steps": 1000000, 106 | "timestep_sampler": "logit_normal", 107 | "use_ema": true, 108 | "log_loss_info": false, 109 | "optimizer_configs": { 110 | "diffusion": { 111 | "optimizer": { 112 | "type": "AdamW", 113 | "config": { 114 | "lr": 1e-4, 115 | "betas": [0.9, 0.999], 116 | "weight_decay": 1e-3 117 | } 118 | }, 119 | "scheduler": { 120 | "type": "InverseLR", 121 | "config": { 122 | "inv_gamma": 1000000, 123 | "power": 0.5, 124 | "warmup": 0.99 125 | } 126 | } 127 | } 128 | }, 129 | "demo": { 130 | "demo_every": 10000, 131 | "demo_steps": 100, 132 | "num_demos": 4, 133 | "demo_cond": [ 134 | {"prompt": "A hip-hop track using sounds from a construction site—hammering nails as the beat, drilling sounds as scratches, and metal clanks as rhythm accents.", "seconds_start": 0, "seconds_total": 10}, 135 | {"prompt": "A saxophone that sounds like meowing of cat.", "seconds_start": 0, "seconds_total": 10}, 136 | {"prompt": "Dogs barking, birds chirping, and electronic dance music.", "seconds_start": 0, "seconds_total": 10}, 137 | {"prompt": "A 
soundscape with a choir of alarm siren from an ambulance car but to produce a lush and calm choir composition with sustained chords.", "seconds_start": 0, "seconds_total": 10} 138 | ], 139 | "demo_cfg_scales": [1, 3.5, 6] 140 | } 141 | } 142 | } -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/training/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # modified from stable-audio-tools under the MIT license 4 | 5 | import torch 6 | import os 7 | 8 | # to track gradient norm without clipping 9 | def gradient_norm(model): 10 | total_norm = 0.0 11 | for p in model.parameters(): 12 | if p.grad is not None: 13 | param_norm = p.grad.detach().data.norm(2) 14 | total_norm += param_norm.item() ** 2 15 | total_norm = total_norm ** (1. / 2) 16 | return total_norm 17 | 18 | def get_rank(): 19 | """Get rank of current process.""" 20 | 21 | print(os.environ.keys()) 22 | 23 | if "SLURM_PROCID" in os.environ: 24 | return int(os.environ["SLURM_PROCID"]) 25 | 26 | if not torch.distributed.is_available() or not torch.distributed.is_initialized(): 27 | return 0 28 | 29 | return torch.distributed.get_rank() 30 | 31 | class InverseLR(torch.optim.lr_scheduler._LRScheduler): 32 | """Implements an inverse decay learning rate schedule with an optional exponential 33 | warmup. When last_epoch=-1, sets initial lr as lr. 34 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 35 | (1 / 2)**power of its original value. 36 | Args: 37 | optimizer (Optimizer): Wrapped optimizer. 38 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 39 | power (float): Exponential factor of learning rate decay. Default: 1. 40 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 41 | Default: 0. 42 | final_lr (float): The final learning rate. Default: 0. 43 | last_epoch (int): The index of last epoch. Default: -1. 44 | """ 45 | 46 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., 47 | last_epoch=-1): 48 | self.inv_gamma = inv_gamma 49 | self.power = power 50 | if not 0. <= warmup < 1: 51 | raise ValueError('Invalid value for warmup') 52 | self.warmup = warmup 53 | self.final_lr = final_lr 54 | super().__init__(optimizer, last_epoch) 55 | 56 | def get_lr(self): 57 | if not self._get_lr_called_within_step: 58 | import warnings 59 | warnings.warn("To get the last learning rate computed by the scheduler, " 60 | "please use `get_last_lr()`.") 61 | 62 | return self._get_closed_form_lr() 63 | 64 | def _get_closed_form_lr(self): 65 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 66 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 67 | return [warmup * max(self.final_lr, base_lr * lr_mult) 68 | for base_lr in self.base_lrs] 69 | 70 | def copy_state_dict(model, state_dict): 71 | """Load state_dict to model, but only for keys that match exactly. 72 | 73 | Args: 74 | model (nn.Module): model to load state_dict. 75 | state_dict (OrderedDict): state_dict to load. 
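    Note: only keys that exist in the model with exactly matching tensor shapes
    are copied; mismatched or missing keys are skipped and reported in a warning,
    and loading proceeds with strict=False.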
76 | """ 77 | model_state_dict = model.state_dict() 78 | skipped_keys = [] 79 | 80 | for key in state_dict: 81 | if key in model_state_dict: 82 | if state_dict[key].shape == model_state_dict[key].shape: 83 | if isinstance(state_dict[key], torch.nn.Parameter): 84 | # Backwards compatibility for serialized parameters 85 | state_dict[key] = state_dict[key].data 86 | model_state_dict[key] = state_dict[key] 87 | else: 88 | skipped_keys.append((key, state_dict[key].shape, model_state_dict[key].shape)) 89 | else: 90 | skipped_keys.append((key, state_dict[key].shape, None)) 91 | 92 | model.load_state_dict(model_state_dict, strict=False) 93 | 94 | if skipped_keys: 95 | print("=====================================================================================================================") 96 | print("[WARNING(copy_state_dict)] The following keys were skipped due to shape mismatch or absence in the model's state_dict:") 97 | for key, state_shape, model_shape in skipped_keys: 98 | print(f" - {key}: loaded state shape {state_shape}, model state shape {model_shape}") 99 | print("=====================================================================================================================") 100 | 101 | def create_optimizer_from_config(optimizer_config, parameters): 102 | """Create optimizer from config. 103 | 104 | Args: 105 | parameters (iterable): parameters to optimize. 106 | optimizer_config (dict): optimizer config. 107 | 108 | Returns: 109 | torch.optim.Optimizer: optimizer. 110 | """ 111 | 112 | optimizer_type = optimizer_config["type"] 113 | 114 | if optimizer_type == "FusedAdam": 115 | from deepspeed.ops.adam import FusedAdam 116 | optimizer = FusedAdam(parameters, **optimizer_config["config"]) 117 | else: 118 | optimizer_fn = getattr(torch.optim, optimizer_type) 119 | optimizer = optimizer_fn(parameters, **optimizer_config["config"]) 120 | return optimizer 121 | 122 | def create_scheduler_from_config(scheduler_config, optimizer): 123 | """Create scheduler from config. 124 | 125 | Args: 126 | scheduler_config (dict): scheduler config. 127 | optimizer (torch.optim.Optimizer): optimizer. 128 | 129 | Returns: 130 | torch.optim.lr_scheduler._LRScheduler: scheduler. 131 | """ 132 | if scheduler_config["type"] == "InverseLR": 133 | scheduler_fn = InverseLR 134 | else: 135 | scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) 136 | scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) 137 | return scheduler -------------------------------------------------------------------------------- /A2SB/inference/A2SB_upsample_dataset.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | 10 | # # If there is Error: mkl-service + Intel(R) MKL: MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library. 
11 | # os.environ["MKL_THREADING_LAYER"] = "GNU" 12 | # import numpy as np 13 | # os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 14 | 15 | import numpy as np 16 | import json 17 | import argparse 18 | import glob 19 | from subprocess import Popen, PIPE 20 | import yaml 21 | import time 22 | from datetime import datetime 23 | import shutil 24 | import csv 25 | from tqdm import tqdm 26 | 27 | import librosa 28 | import soundfile as sf 29 | 30 | 31 | def read_standard_csv(root_folder, filename): 32 | all_files = { 33 | "train": [], 34 | "validation": [], 35 | "test": [] 36 | } 37 | 38 | with open(os.path.join(root_folder, filename)) as csvfile: 39 | reader = csv.reader(csvfile, delimiter=',', quotechar='"') 40 | next(reader) 41 | for row in reader: 42 | assert len(row) == 3 43 | split, audio_filename, duration = row 44 | split = split.strip() 45 | audio_filename = audio_filename.strip() 46 | duration = float(duration) 47 | sample_rate = None # 44100 # not used 48 | all_files[split].append((audio_filename, duration, sample_rate)) 49 | 50 | return all_files 51 | 52 | 53 | def load_yaml(file_path): 54 | with open(file_path, 'r') as file: 55 | data = yaml.safe_load(file) 56 | return data 57 | 58 | 59 | def save_yaml(data, prefix="../configs/temp"): 60 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 61 | rnd_num = np.random.rand() 62 | rnd_num = rnd_num - rnd_num % 0.000001 63 | file_name = f"{prefix}_{timestamp}_{rnd_num}.yaml" 64 | with open(file_name, 'w') as f: 65 | yaml.dump(data, f) 66 | return file_name 67 | 68 | 69 | def shell_run_cmd(cmd): 70 | print('running:', cmd) 71 | p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True) 72 | stdout, stderr = p.communicate() 73 | print(stdout) 74 | print(stderr) 75 | 76 | 77 | def upsample_one_sample(dataset_name, audio_filename, exp_root, exp_name, cutoff_freq): 78 | # get paths ready 79 | output_subdir = '_'.join(audio_filename.split('/')[-3:]) # get reasonably short filename 80 | output_subdir = '.'.join(output_subdir.split('.')[:-1]) # remove suffix 81 | output_dir = os.path.join(exp_root, exp_name, dataset_name, 'cutoff_freq={}'.format(cutoff_freq)) 82 | 83 | # Load, modify, and store yaml file for the specific file 84 | template_yaml_file = '../configs/inference_files_upsampling.yaml' 85 | inference_config = load_yaml(template_yaml_file) 86 | inference_config['data']['predict_filelist'] = [{ 87 | 'filepath': audio_filename, 88 | 'output_subdir': output_subdir 89 | }] 90 | inference_config['data']['transforms_aug'][0]['init_args']['upsample_mask_kwargs'] = { 91 | 'min_cutoff_freq': cutoff_freq, 92 | 'max_cutoff_freq': cutoff_freq 93 | } 94 | temporary_yaml_file = save_yaml(inference_config) 95 | 96 | # copy true file 97 | os.makedirs(os.path.join(output_dir, output_subdir), exist_ok=True) 98 | shutil.copy( 99 | audio_filename, 100 | os.path.join(output_dir, output_subdir, 'original.{}'.format(audio_filename.split('.')[-1])) 101 | ) 102 | orig_audio, orig_sr = librosa.load(audio_filename, sr=None) 103 | degraded_audio = librosa.resample(orig_audio, orig_sr=orig_sr, target_sr=cutoff_freq*2) 104 | degraded_audio = librosa.resample(degraded_audio, orig_sr=cutoff_freq*2, target_sr=orig_sr) 105 | sf.write(os.path.join(output_dir, output_subdir, 'degraded.{}'.format(audio_filename.split('.')[-1])), degraded_audio, orig_sr) 106 | 107 | # run upsampling command 108 | if not os.path.exists(os.path.join(output_dir, output_subdir, 'recon.wav')): 109 | cmd = "cd ../; \ 110 | python ensembled_inference.py predict \ 111 | -c configs/{}.yaml \ 112 | -c {} \ 
113 | --model.predict_output_dir={}; \ 114 | cd inference/".format(exp_name, temporary_yaml_file.replace('../', ''), output_dir) 115 | shell_run_cmd(cmd) 116 | else: 117 | print(audio_filename, ' - already upsampled') 118 | 119 | os.remove(temporary_yaml_file) 120 | 121 | 122 | def main(): 123 | parser = argparse.ArgumentParser(description='Description of your program') 124 | parser.add_argument('-dn','--dataset_name', help='dataset name', required=True) 125 | parser.add_argument('-exp','--exp_name', help='exp_name', required=True) 126 | parser.add_argument('-cf','--cutoff_freq', type=int, help='cutoff_freq', required=True) 127 | parser.add_argument('-start','--start', type=int, help='start', default=0) 128 | parser.add_argument('-end','--end', type=int, help='end', default=-1) 129 | args = parser.parse_args() 130 | 131 | manifest_root_folder = 'PATH/TO/MANIFEST/FOLDER' 132 | exp_root = './exp' 133 | 134 | dataset_name = args.dataset_name 135 | manifest_filename = '{}_manifest.csv'.format(dataset_name) 136 | all_files = read_standard_csv(manifest_root_folder, manifest_filename) 137 | 138 | exp_name = args.exp_name 139 | cutoff_freq = args.cutoff_freq 140 | 141 | start = args.start 142 | end = args.end if args.end > args.start else len(all_files['test']) 143 | 144 | for row in tqdm(all_files['test'][start:end]): 145 | (audio_filename, duration, sample_rate) = row 146 | upsample_one_sample(dataset_name, audio_filename, exp_root, exp_name, cutoff_freq) 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | 152 | -------------------------------------------------------------------------------- /A2SB/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of Audio-to-Audio Schrodinger Bridges 2 | 3 |
13 | 14 | Please refer to the following GitHub link for the full release: 15 | https://github.com/NVIDIA/diffusion-audio-restoration 16 | 17 | # Overview 18 | 19 | This repo contains the PyTorch implementation of [A2SB: Audio-to-Audio Schrodinger Bridges](https://arxiv.org/abs/2501.11311). A2SB is an audio restoration model tailored for high-res music at 44.1kHz. It is capable of both bandwidth extension (predicting high-frequency components) and inpainting (re-generating missing segments). Critically, A2SB is end-to-end without need of a vocoder to predict waveform outputs, and able to restore hour-long audio inputs. A2SB is capable of achieving state-of-the-art bandwidth extension and inpainting quality on several out-of-distribution music test sets. 20 | 21 | - We propose A2SB, a state-of-the-art, end-to-end, vocoder-free, and multi-task diffusion Schrodinger Bridge model for 44.1kHz high-res music restoration, using an effective factorized audio representation. 22 | 23 | - A2SB is the first long audio restoration model that could restore hour-long audio without 24 | boundary artifacts. 25 | 26 | 27 | # Usage 28 | 29 | ## Data preparation 30 | 31 | Prepare your data into a ```DATASET_NAME_manifest.csv``` file in the following format: 32 | ``` 33 | split,file_path,duration 34 | train,PATH/TO/AUDIO.wav,10.0 35 | ... 36 | validation,PATH/TO/AUDIO.wav,10.0 37 | ... 38 | test,PATH/TO/AUDIO.wav,10.0 39 | ... 40 | ``` 41 | You could have multiple manifests, one for each dataset, and you could use different audio formats as long as ```SoundFile``` supports it. After you prepare all of them, write down their paths and names in config files under ```configs/```. 42 | 43 | We train our models on the permissively licensed subsets of the following datasets: FMA, Medley-Solos-DB, MUSAN, Musical Instrument, MusicNet, Slakh, FreeSound, FSD50K, GTZAN, and NSynth. 44 | 45 | ## Training 46 | 47 | - For pretraining, the script is 48 | 49 | ```python main.py fit --config configs/pretrain.yaml``` 50 | 51 | - For T-finetuning, first copy the pretrained checkpoint to the T-finetune experiment folder as initialization. Then, T-finetuning resumes from this checkpoint. 52 | 53 | Here's an example of running T-finetuning of 2-splits. These 2 models will be trained separately. For the first split, run 54 | 55 | ```python main.py fit --config configs/t_finetune_2split_0.0_0.5.yaml``` 56 | 57 | For the second split, copy this config and modify ```model.train_t_min -> 0.5, model.train_t_max -> 1.0```, setup a different experiment name and path, and run training in a similar way. 58 | 59 | - Misc: you may need to adjust batch size, num devices, num nodes, and gradient accumulation in the configs based on your GPU configurations. 
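Referring back to the manifest format described under Data preparation, a manifest file can be generated with a short script along the lines of the sketch below. The audio root, the output filename, and the random split assignment are illustrative assumptions, not part of this release.

```
# Sketch: write DATASET_NAME_manifest.csv in the split,file_path,duration format
# shown above. Audio root, split ratios, and random assignment are assumptions.
import csv
import glob
import os
import random

import soundfile as sf

def build_manifest(audio_root: str, out_csv: str = "DATASET_NAME_manifest.csv") -> None:
    rows = []
    for path in glob.glob(os.path.join(audio_root, "**", "*.wav"), recursive=True):
        duration = sf.info(path).duration  # duration in seconds, read from the header
        split = random.choices(["train", "validation", "test"], weights=[0.9, 0.05, 0.05])[0]
        rows.append((split, path, round(duration, 3)))
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["split", "file_path", "duration"])
        writer.writerows(rows)
```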
60 | 61 | 62 | ## Inference 63 | 64 | - If you would like to run inference of the entire dataset, use 65 | ``` 66 | cd inference/ 67 | python A2SB_upsample_dataset.py -dn DATASET_NAME -exp ensemble_2split_sampling -cf 4000 68 | python A2SB_inpaint_dataset.py -dn DATASET_NAME -exp ensemble_2split_sampling -inp_len 0.3 -inp_every 5.0 69 | ``` 70 | 71 | - If you would like to run a simple bandwidth extension API for arbitrarily long audio with automatic rolloff frequency detection, use 72 | ``` 73 | cd inference/ 74 | python A2SB_upsample_api.py -f DEGRADED.wav -o RESTORED.wav -n N_STEPS 75 | ``` 76 | 77 | ## Requirements 78 | 79 | ``` 80 | numpy, scipy, matplotlib, jsonargparse, librosa, soundfile, torch, torchaudio, einops, pytorch_lightning, rotary_embedding_torch, ssr_eval 81 | ``` 82 | 83 | 84 | # Citation 85 | ``` 86 | @article{kong2025a2sb, 87 | title={A2SB: Audio-to-Audio Schrodinger Bridges}, 88 | author={Kong, Zhifeng and Shih, Kevin J and Nie, Weili and Vahdat, Arash and Lee, Sang-gil and Santos, Joao Felipe and Jukic, Ante and Valle, Rafael and Catanzaro, Bryan}, 89 | journal={arXiv preprint arXiv:2501.11311}, 90 | year={2025} 91 | } 92 | ``` 93 | 94 | # License/Terms of Use: 95 | The model is provided under the NVIDIA OneWay NonCommercial License. 96 | 97 | The code is under [NVIDIA Source Code License - Non Commercial](https://github.com/NVlabs/I2SB/blob/master/LICENSE). Some components are adapted from other sources. The training code is adapted from [I2SB](https://github.com/NVlabs/I2SB) under the [NVIDIA Source Code License - Non Commercial](https://github.com/NVlabs/I2SB/blob/master/LICENSE). The model architecture is adapted from [Improved Diffusion](https://github.com/openai/improved-diffusion/blob/main/LICENSE) under the MIT License. 98 | 99 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/). 100 | 101 | ## Ethical Considerations: 102 | NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. 103 | 104 | Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/). 105 | -------------------------------------------------------------------------------- /A2SB/diffusion.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | 9 | # Original source and license: 10 | # --------------------------------------------------------------- 11 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 12 | # 13 | # This work is licensed under the NVIDIA Source Code License 14 | # for I2SB. 
15 | # https://github.com/NVlabs/I2SB/blob/master/i2sb/diffusion.py 16 | 17 | 18 | import numpy as np 19 | from tqdm import tqdm 20 | 21 | import torch 22 | import torch.nn as nn 23 | from math import ceil 24 | from einops import rearrange 25 | 26 | 27 | def get_multidiffusion_vf(vf_model, x_t, t_emb, win_length=256, hop_length=128, batch_size=16): 28 | """ 29 | t_emb should be b x emb_dim but all embeddiengs should be for the same time step, as this code does not 30 | support heterogenous sampling schedulings 31 | """ 32 | b_size, num_channels, win_height, seq_len = x_t.shape 33 | counts = torch.zeros_like(x_t) 34 | num_hops = (seq_len - (win_length - hop_length))//hop_length 35 | l_idx = 0 36 | vf_t = torch.zeros_like(x_t) 37 | unfolder = torch.nn.Unfold([win_height, win_length], stride=hop_length) 38 | x_t_unfolded = unfolder(x_t) 39 | num_channels = x_t.shape[1] 40 | b_size = x_t.shape[0] 41 | x_t_unfolded = rearrange(x_t_unfolded, "b (c h w) l -> (b l) c h w", c=num_channels, 42 | h=win_height, w=win_length) 43 | # compute the vector fields in batches 44 | num_chunks = ceil(x_t_unfolded.shape[0]/batch_size) 45 | x_t_unfolded_chunked = torch.chunk(x_t_unfolded, num_chunks) 46 | vfields_out = [] 47 | t_emb_rpt = t_emb.repeat(num_hops, 1) 48 | t_emb_chunked = torch.chunk(t_emb_rpt, num_chunks) 49 | for b_chunk_idx in range(num_chunks): 50 | vfields_out.append(vf_model(x_t_unfolded_chunked[b_chunk_idx], t_emb_chunked[b_chunk_idx])) 51 | vfields = torch.cat(vfields_out, 0) 52 | vfields = rearrange(vfields, "(b l) c h w -> l b c h w", b=b_size, l=num_hops) 53 | 54 | for hop_idx in range(int(num_hops)): 55 | 56 | r_idx = l_idx + win_length 57 | counts[...,l_idx:r_idx]+=1 58 | curr_x_t = x_t[...,l_idx:r_idx] 59 | vf_out = vfields[hop_idx] #vf_model(curr_x_t, t_emb) 60 | 61 | vf_t[...,l_idx:r_idx] += vf_out 62 | l_idx += hop_length 63 | 64 | return vf_t / counts 65 | 66 | 67 | def multidiffusion_pad_inputs(input, win_length, hop_length, padding_constant=None): 68 | _b, _c, _h, width = input.shape 69 | if width <= win_length: # no hops 70 | to_pad = win_length - width 71 | else: 72 | pad_to = ceil((width - win_length)/ hop_length) * hop_length + win_length 73 | to_pad = pad_to - width 74 | 75 | if to_pad > 0: 76 | padding = input[..., :to_pad] 77 | if padding_constant is not None: 78 | padding = padding*0+padding_constant 79 | 80 | input_padded = torch.cat([input, padding], dim=-1) 81 | else: 82 | input_padded = input.clone() 83 | return input_padded 84 | 85 | 86 | def multidiffusion_unpad_outputs(output, original_width: int): 87 | return output[...,:original_width] 88 | 89 | 90 | def compute_gaussian_product_coef(sigma1, sigma2): 91 | """ Given p1 = N(x_t|x_0, sigma_1**2) and p2 = N(x_t|x_1, sigma_2**2) 92 | return p1 * p2 = N(x_t| coef1 * x0 + coef2 * x1, var) """ 93 | 94 | denom = sigma1**2 + sigma2**2 95 | coef1 = sigma2**2 / denom 96 | coef2 = sigma1**2 / denom 97 | var = (sigma1**2 * sigma2**2) / denom 98 | return coef1, coef2, var 99 | 100 | class Diffusion(nn.Module): 101 | def __init__(self, beta_min=1e-4, beta_max=0.3): 102 | super().__init__() 103 | # t = 0 (clean data), t=1 (corrputed posterior) 104 | self.beta_min = beta_min 105 | self.beta_max = beta_max 106 | 107 | def get_beta_t(self, t): 108 | # beta = sqrt(t)*beta/0.5 109 | if t <= 0.5: 110 | return t**2 * self.beta_max 111 | else: 112 | return (1-t)**2 * self.beta_max 113 | 114 | def get_int_beta_0_t(self, t): 115 | """ 116 | t: torch.tensor [0,1] 117 | """ 118 | beta_int = t.clone() 119 | full_integral = 2 * 
self.beta_max*(0.5**3)/3 120 | half_inds = t > 0.5 121 | beta_int[half_inds] = full_integral - 1/3*self.beta_max * ((1-t[half_inds])**3) 122 | beta_int[~half_inds] = 1/3*self.beta_max * (t[~half_inds]**3) 123 | return beta_int 124 | 125 | def get_std_fwd(self, t): 126 | return torch.sqrt(self.get_int_beta_0_t(t)) 127 | 128 | def get_std_rev(self, t): 129 | return torch.sqrt(self.get_int_beta_0_t(1-t)) 130 | 131 | def get_std_t(self, t): 132 | sigma_fwd = self.get_std_fwd(t) 133 | sigma_rev = self.get_std_rev(t) 134 | coef1, coef2, var = compute_gaussian_product_coef(sigma_fwd, sigma_rev) 135 | return torch.sqrt(var) 136 | 137 | def q_sample(self, t, x_0, x_1, ot_ode=False): 138 | """ Sample q(x_t | x_0, x_1), i.e. eq 11 """ 139 | sigma_fwd = self.get_std_fwd(t) 140 | sigma_rev = self.get_std_rev(t) 141 | 142 | coef1, coef2, var = compute_gaussian_product_coef(sigma_fwd, sigma_rev) 143 | while len(coef1.shape) < len(x_0.shape): 144 | coef1 = coef1[:, None] 145 | coef2 = coef2[:, None] 146 | var = var[:, None] 147 | x_t = coef1 * x_0 + coef2 * x_1 148 | std_sb_t = torch.sqrt(var) 149 | if not ot_ode: 150 | x_t += std_sb_t * torch.randn_like(x_t) 151 | return x_t.detach() 152 | 153 | def p_posterior(self, t_prev, t, x_t, x_0, ot_ode=False): 154 | assert t_prev < t 155 | std_t = self.get_std_fwd(t) 156 | std_t_prev = self.get_std_fwd(t_prev) 157 | std_delta = (std_t**2 - std_t_prev**2).sqrt() 158 | mu_x0, mu_xt, var = compute_gaussian_product_coef(std_t_prev, std_delta) 159 | x_t_prev = mu_x0 * x_0 + mu_xt * x_t 160 | 161 | if not ot_ode and t_prev > 0: 162 | x_t_prev = x_t_prev + var.sqrt() * torch.randn_like(x_t_prev) 163 | return x_t_prev 164 | 165 | def get_pred_x0(self, t, x_t, net_out): 166 | std_fwd_t = self.get_std_fwd(t) 167 | pred_x0 = x_t - std_fwd_t * net_out 168 | return pred_x0 169 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/models/lm_backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 
3 | # modified from stable-audio-tools under the MIT license 4 | 5 | from torch import nn 6 | from x_transformers import ContinuousTransformerWrapper, Decoder 7 | 8 | from .transformer import ContinuousTransformer 9 | 10 | # Interface for backbone of a language model 11 | # Handles conditioning and cross-attention 12 | # Does not have to deal with patterns or quantizer heads 13 | class AudioLMBackbone(nn.Module): 14 | def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs): 15 | super().__init__() 16 | 17 | self.embed_dim = embed_dim 18 | self.use_generation_cache = use_generation_cache 19 | 20 | def forward( 21 | self, 22 | x, 23 | cross_attn_cond=None, 24 | prepend_cond=None, 25 | prepend_cond_mask=None, 26 | global_cond=None, 27 | use_cache=False, 28 | **kwargs 29 | ): 30 | raise NotImplementedError 31 | 32 | def reset_generation_cache( 33 | self, 34 | max_seq_len, 35 | batch_size, 36 | dtype=None 37 | ): 38 | pass 39 | 40 | def update_generation_cache( 41 | self, 42 | seqlen_offset 43 | ): 44 | pass 45 | 46 | class XTransformersAudioLMBackbone(AudioLMBackbone): 47 | def __init__(self, 48 | embed_dim: int, 49 | cross_attn_cond_dim: int = 0, 50 | prepend_cond_dim: int = 0, 51 | **kwargs): 52 | super().__init__(embed_dim=embed_dim) 53 | 54 | # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer 55 | self.model = ContinuousTransformerWrapper( 56 | dim_in=embed_dim, 57 | dim_out=embed_dim, 58 | max_seq_len=0, #Not relevant without absolute positional embeds, 59 | attn_layers=Decoder( 60 | dim=embed_dim, 61 | attn_flash = True, 62 | cross_attend = cross_attn_cond_dim > 0, 63 | zero_init_branch_output=True, 64 | use_abs_pos_emb = False, 65 | rotary_pos_emb=True, 66 | ff_swish = True, 67 | ff_glu = True, 68 | **kwargs 69 | ) 70 | ) 71 | 72 | if prepend_cond_dim > 0: 73 | # Prepend conditioning 74 | self.to_prepend_embed = nn.Sequential( 75 | nn.Linear(prepend_cond_dim, embed_dim, bias=False), 76 | nn.SiLU(), 77 | nn.Linear(embed_dim, embed_dim, bias=False) 78 | ) 79 | 80 | if cross_attn_cond_dim > 0: 81 | # Cross-attention conditioning 82 | self.to_cross_attn_embed = nn.Sequential( 83 | nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), 84 | nn.SiLU(), 85 | nn.Linear(embed_dim, embed_dim, bias=False) 86 | ) 87 | 88 | def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): 89 | 90 | prepend_length = 0 91 | if prepend_cond is not None: 92 | # Project the prepend conditioning to the embedding dimension 93 | prepend_cond = self.to_prepend_embed(prepend_cond) 94 | prepend_length = prepend_cond.shape[1] 95 | 96 | if prepend_cond_mask is not None: 97 | # Cast mask to bool 98 | prepend_cond_mask = prepend_cond_mask.bool() 99 | 100 | if cross_attn_cond is not None: 101 | # Project the cross-attention conditioning to the embedding dimension 102 | cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) 103 | 104 | return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] 105 | 106 | class ContinuousTransformerAudioLMBackbone(AudioLMBackbone): 107 | def __init__(self, 108 | embed_dim: int, 109 | cross_attn_cond_dim: int = 0, 110 | prepend_cond_dim: int = 0, 111 | project_cross_attn_cond: bool = False, 112 | **kwargs): 113 | super().__init__(embed_dim=embed_dim) 114 | 115 | # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer 116 | self.model = 
ContinuousTransformer( 117 | dim=embed_dim, 118 | dim_in=embed_dim, 119 | dim_out=embed_dim, 120 | cross_attend = cross_attn_cond_dim > 0, 121 | cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim, 122 | causal=True, 123 | **kwargs 124 | ) 125 | 126 | if prepend_cond_dim > 0: 127 | # Prepend conditioning 128 | self.to_prepend_embed = nn.Sequential( 129 | nn.Linear(prepend_cond_dim, embed_dim, bias=False), 130 | nn.SiLU(), 131 | nn.Linear(embed_dim, embed_dim, bias=False) 132 | ) 133 | 134 | if cross_attn_cond_dim > 0 and project_cross_attn_cond: 135 | # Cross-attention conditioning 136 | self.to_cross_attn_embed = nn.Sequential( 137 | nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), 138 | nn.SiLU(), 139 | nn.Linear(embed_dim, embed_dim, bias=False) 140 | ) 141 | else: 142 | self.to_cross_attn_embed = nn.Identity() 143 | 144 | def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): 145 | 146 | prepend_length = 0 147 | if prepend_cond is not None: 148 | # Project the prepend conditioning to the embedding dimension 149 | prepend_cond = self.to_prepend_embed(prepend_cond) 150 | prepend_length = prepend_cond.shape[1] 151 | 152 | if prepend_cond_mask is not None: 153 | # Cast mask to bool 154 | prepend_cond_mask = prepend_cond_mask.bool() 155 | 156 | if cross_attn_cond is not None: 157 | # Cast cross_attn_cond to same dtype as self.to_cross_attn_embed 158 | cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype) 159 | 160 | # Project the cross-attention conditioning to the embedding dimension 161 | cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) 162 | 163 | return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] -------------------------------------------------------------------------------- /ETTA/AFSynthetic/README.md: -------------------------------------------------------------------------------- 1 | # Steps to create AFSynthetic Dataset 2 | 3 | ## Overview 4 | 5 | The AFSynthetic dataset contains 1.35M high quality captions generated with Audio Flamingo. 6 | 7 | AFSynthetic Overview 8 | 9 | ## Step 1: download audio datasets from the original sources. 10 | 11 | - The audio sources include AudioCaps, AudioSet, WavCaps, Laion-630K, and VGGSound. 12 | - It is ideal to transform all samples into ```.wav``` for faster seek. 13 | - Note: it is possible that some sources have been removed and some audio files are under restricted licenses. Make sure you follow the dataset and audio licenses. 14 | 15 | ## Step 2: generate raw captions 16 | 17 | ### AF-AudioSet 18 | - Our previous experimental dataset ```AF-AudioSet``` introduced in [Tango-AF](https://arxiv.org/pdf/2406.15487) is released [here](https://github.com/NVIDIA/audio-flamingo/tree/legacy_audio_flamingo_1/labeling_machine). 19 | 20 | ### ETTA (ICML 2025) version 21 | - Our ICML 2025 version of AFSynthetic in ETTA uses [Audio Flamingo](https://arxiv.org/pdf/2402.01831) (ICML 2024) for caption generation. See example inference script in [this python script](https://github.com/NVIDIA/audio-flamingo/blob/legacy_audio_flamingo_1/inference/inference_examples.py). 
22 | - Use the ```chat model``` in the [inference script](https://github.com/NVIDIA/audio-flamingo/blob/legacy_audio_flamingo_1/inference/inference_examples.py#L187) (L187), 23 | - Prepare data chunks in 10s segments -- replace L199-L218 in the inference script with the following: 24 | ``` 25 | # manifest is the list of samples 26 | # data_root is the root folder to store dataset 27 | audio_duration = 10.0 28 | items = [] 29 | for v in manifest: 30 | name = os.path.join(data_root, v["name"]) 31 | total_duration = librosa.get_duration(name) 32 | for audio_start_idx in range(int(total_duration // audio_duration)): 33 | items.append( 34 | { 35 | 'name': name, 36 | 'prefix': 'The task is dialog.', 37 | 'audio_start': audio_start_idx * audio_duration, 38 | 'audio_duration': audio_duration, 39 | 'dialogue': [{"user": "Can you briefly describe what you hear in this audio?"}] 40 | }, 41 | ) 42 | ``` 43 | - Change ```inference_kwargs``` to 44 | ``` 45 | inference_kwargs = { 46 | "do_sample": True, 47 | "top_k": 50, 48 | "top_p": 0.95, 49 | "num_return_sequences": 20 50 | } 51 | ``` 52 | - Use command ```outputs = main(config_file, data_root, checkpoint_path, items, inference_kwargs, is_dialogue=True, do_dialogue_last=True)``` 53 | 54 | ### Latest version 55 | - You are encouraged to use [Audio Flamingo 2](https://arxiv.org/pdf/2503.03983) (ICML 2025) for better captioning. 56 | - Chunk into 10s or 30s segments based on your needs. 57 | - The inference script example is [here](https://github.com/NVIDIA/audio-flamingo/tree/main/inference_HF_pretrained). 58 | 59 | ## Step 3: CLAP filtering 60 | 61 | ### ETTA (ICML 2025) version 62 | - Our ICML 2025 version of AFSynthetic in ETTA uses [Laion-CLAP](https://github.com/LAION-AI/CLAP) ```630k-audioset-fusion-best.pt``` for clap similarity computation. 63 | - The similarity is computed as cosine similarity between audio and text embeddings: 64 | ``` 65 | @torch.no_grad() 66 | def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs, duration=None, start=None): 67 | data = load_audio(audio_file, target_sr=48000, duration=duration, start=start) 68 | # compute audio embedding 69 | audio_data = data.reshape(1, -1) 70 | audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().cuda() 71 | audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True) 72 | text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True) 73 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) 74 | cos_similarity = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed) 75 | return cos_similarity.squeeze().cpu().numpy() 76 | ``` 77 | - Choose the captions with ```topk=3``` highest clap scores, and then further filter samples with clap score less than ```0.45```. 78 | 79 | 80 | ### Latest version 81 | - You are encouraged to use Audio Flamingo 2's AF-Clap model for clap similarity computation [example script](https://github.com/NVIDIA/audio-flamingo/blob/main/AFClap/afclap_similarity.py). 82 | - Choose a clap threshold between ```0.3 - 0.4``` depending on your quality requirements. 
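To make the ETTA-version selection rule above concrete (keep the top-3 captions by CLAP score, then drop anything below 0.45), a minimal sketch of the selection step is shown below. `compute_laionclap_text_audio_sim` is the helper defined above; the per-chunk candidate captions and the loaded CLAP model are assumed to come from the earlier steps.

```
# Sketch of the Step 3 selection rule: rank candidate captions by CLAP similarity,
# keep the top-k, then apply the score threshold. The candidate captions and
# `laionclap_model` are assumed to come from Steps 2 and 3 above.
import numpy as np

def select_captions(audio_file, captions, laionclap_model, start, duration,
                    topk=3, threshold=0.45):
    scores = np.atleast_1d(compute_laionclap_text_audio_sim(
        audio_file, laionclap_model, captions, duration=duration, start=start))
    top_idx = np.argsort(scores)[::-1][:topk]
    return [(captions[i], float(scores[i])) for i in top_idx if scores[i] >= threshold]
```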
83 | 84 | ## Step 4: extra filtering 85 | - Keyword filtering using the following condition 86 | ``` 87 | def is_low_quality(caption): 88 | keywords = [ 89 | "noise", "noisy", "unclear", "muffled", "indistinct", "inaudible", "distorted", 90 | "garbled", "unintelligible", "static", "interference", "echo", "background noise", 91 | "low volume", "choppy", "feedback", "crackling", "hissing", "fuzzy", "murmur", 92 | "buzzing", "scrambled", "faint", "broken up", "skipped", "irrelevant", 93 | "overlapping speech", "reverberation", "clipping", "sibilance", "popping", 94 | "unspecific", "gibberish", "unknown sounds", "vague", "ambiguous", "incoherent", 95 | "misheard", "uncertain", "distant", "irregular", "glitch", "skipping", "dropout", 96 | "artifact", "undermodulated", "overmodulated", "off-mic", "misinterpretation", "unreliable", "fluctuating", 97 | "low-quality", "low quality", "compromised", "substandard", "inferior", "deficient", "poor", 98 | "suboptimal", "flawed", "unsatisfactory", "inadequate", "faulty", 99 | "second-rate", "mediocre", "insufficient", "lacking", "imprecise" 100 | ] 101 | # if you do not want to include speech data, add the following keywords 102 | # keywords += ["speech", "voice", "man", "woman", "male", "female", "baby", "crying", "cries", "speaking", "speak", "speaks", "talk"] 103 | 104 | caption_lower = caption.lower() 105 | for keyword in keywords: 106 | if keyword.lower() in caption_lower: 107 | return True 108 | return False 109 | ``` 110 | - Sub-sample non-musical samples with ```audio_start > 600```. 111 | - (Optional) Audio quality filtering using [AudioBox Aesthetics scores](https://github.com/facebookresearch/audiobox-aesthetics). We recommend removing samples with ```CE <= 3.38 and PC <= 2.89``` based on our own manual inspections on a number of samples. 112 | - (Optional) Use an LLM to rephrase captions or shorten long captions. An example prompt is 113 | ``` 114 | Condense the following description of audio or music into a short caption around 10 to 20 words. Make sure all acoustic information is accurate. For example, a caption like 'People speak and a dog pants' or 'Rapid plastic clicking as a machine blows air loudly followed by a click then blowing air softly' is good.\n 115 | 116 | ``` 117 | -------------------------------------------------------------------------------- /A2SB/corruption/corruptions.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | 9 | import torch 10 | import numpy as np 11 | 12 | 13 | 14 | def mask_with_noise(x, mask, noise_level): 15 | return x * (1-mask) + mask * torch.randn_like(x) * noise_level 16 | 17 | 18 | class UpsampleMask: 19 | def __init__(self, min_cutoff_freq: int, max_cutoff_freq: int, sampling_rate: int, dc_dropped: bool=True): 20 | super().__init__() 21 | self.min_cutoff_freq = min_cutoff_freq 22 | self.max_cutoff_freq = max_cutoff_freq 23 | self.sampling_rate = sampling_rate 24 | self.dc_dropped = dc_dropped 25 | 26 | @staticmethod 27 | def get_upsample_mask(spec: torch.Tensor, min_cutoff_freq: int, max_cutoff_freq: int, sampling_rate: int, dc_dropped=True): 28 | """ 29 | input: 30 | spec: C x H x L spectrograms batched 31 | returns: 32 | mel_masked: C x H x L mel spectrograms batched, with areas filled in with white noise 33 | """ 34 | 35 | c, h, l = spec.shape 36 | inpaint_mask = torch.zeros(c, h, l).to(spec.device) 37 | # low = int(h * min_cutoff_freq / float(sampling_rate)) 38 | # high = min(int(h * max_cutoff_freq / float(sampling_rate)), h) 39 | if dc_dropped: 40 | n_fft = h * 2 41 | else: 42 | n_fft = (h - 1) * 2 43 | inpaint_mask = torch.zeros(c, h, l).to(spec.device) 44 | low = int(n_fft * min_cutoff_freq / float(sampling_rate)) 45 | high = min(int(n_fft * max_cutoff_freq / float(sampling_rate)), h) 46 | high = max(high, low + 1) # make sure high > low 47 | 48 | cutoff = torch.randint(low=low, high=high, size=[1]) 49 | 50 | inpaint_mask[:, cutoff[0]:, :] = 1 51 | return inpaint_mask 52 | 53 | def __call__(self, spec: torch.Tensor): 54 | return self.get_upsample_mask(spec, self.min_cutoff_freq, self.max_cutoff_freq, self.sampling_rate, self.dc_dropped) 55 | 56 | 57 | class ExtensionMask: 58 | def __init__(self, min_edge_distance=32): 59 | super().__init__() 60 | self.min_edge_distance = min_edge_distance 61 | 62 | @staticmethod 63 | def get_extension_mask(spec: torch.Tensor, min_edge_distance: int): 64 | """ 65 | input: 66 | spec: C x H x L spectrograms batched 67 | returns: 68 | mel_masked: C x H x L mel spectrograms batched, with areas filled in with white noise 69 | """ 70 | 71 | c, h, l = spec.shape 72 | inpaint_mask = torch.zeros(c, h, l).to(spec.device) 73 | mask_start_ind = torch.randint(low=min_edge_distance, high=l-min_edge_distance, size=[1]) 74 | 75 | if torch.randn(1) > 0: # to the right 76 | inpaint_mask[:, :, mask_start_ind[0]:] = 1 77 | else: # to the left 78 | inpaint_mask[:, :, :mask_start_ind[0]] = 1 79 | return inpaint_mask 80 | 81 | def __call__(self, spec: torch.Tensor): 82 | return self.get_extension_mask(spec, self.min_edge_distance) 83 | 84 | 85 | class InpaintMask: 86 | def __init__(self, min_inpainting_frac: float, max_inpainting_frac: float, is_random: bool): 87 | super().__init__() 88 | assert 0.0 <= min_inpainting_frac <= max_inpainting_frac <= 1.0 89 | self.min_inpainting_frac = min_inpainting_frac 90 | self.max_inpainting_frac = max_inpainting_frac 91 | self.is_random = is_random 92 | 93 | @staticmethod 94 | def get_inpainting_mask(spec: torch.Tensor, min_inpainting_frac, max_inpainting_frac, is_random): 95 | 96 | c, h, w = spec.shape 97 | # print('spec.shape', spec.shape) # torch.Size([3, 1024, 256]) 98 | inpaint_mask = torch.zeros(c, h, w).to(spec.device) 99 | 100 | random_variable_for_length = np.random.rand() 101 | inpainting_frac = random_variable_for_length * (max_inpainting_frac - min_inpainting_frac) + min_inpainting_frac 102 | if inpainting_frac == 0: 103 | # that is, 
min_inpainting_frac = max_inpainting_frac = 0.0 104 | return inpaint_mask 105 | 106 | if not is_random: 107 | inpainting_start_frac = 0.5 - inpainting_frac / 2.0 108 | else: 109 | inpainting_start_frac = np.random.rand() * (1.0 - inpainting_frac) 110 | 111 | inpainting_start = int(inpainting_start_frac * w) 112 | inpainting_end = int((inpainting_start_frac + inpainting_frac) * w) 113 | inpaint_mask[:, :, inpainting_start:inpainting_end] = 1 114 | return inpaint_mask 115 | 116 | def __call__(self, spec: torch.Tensor): 117 | return self.get_inpainting_mask(spec, self.min_inpainting_frac, self.max_inpainting_frac, self.is_random) 118 | 119 | 120 | class MultinomialInpaintMaskTransform: 121 | def __init__(self, p_upsample_mask=0.5, p_extension_mask=0.5, p_inpaint_mask=0.0, fill_noise_level=0.5, sampling_rate=22050, upsample_mask_kwargs={}, inpainting_mask_kwargs={}): 122 | """ 123 | TODO: include other parameters for individual transforms 124 | """ 125 | self.mask_fns = [UpsampleMask(sampling_rate=sampling_rate, **upsample_mask_kwargs), ExtensionMask(), InpaintMask(**inpainting_mask_kwargs)] 126 | self.mask_multinomial_probs = torch.Tensor([p_upsample_mask, p_extension_mask, p_inpaint_mask]) 127 | self.fill_noise_level = fill_noise_level 128 | 129 | 130 | def __call__(self, spec: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 131 | """ 132 | input: 133 | spec: spectrogram tensor of shape C x H x W 134 | Returns: 135 | masked_input: masked version of spec with noise-filled holes 136 | mask: C x H x W binary masked used 137 | """ 138 | 139 | mask_fn = self.mask_fns[torch.multinomial(self.mask_multinomial_probs, 1)] 140 | mask = mask_fn(spec) 141 | 142 | masked_and_noised_spec = mask_with_noise(spec, mask, self.fill_noise_level) 143 | 144 | return masked_and_noised_spec, mask 145 | 146 | 147 | class TimestampedSegmentInpaintMaskTransform: 148 | def __init__(self, start_time=0.5, end_time=1.0, hop_length=512, sampling_rate=44100, fill_noise_level=0.5): 149 | self.start_idx = int(sampling_rate/hop_length*start_time) 150 | self.end_idx = int(sampling_rate/hop_length*end_time) 151 | self.fill_noise_level = fill_noise_level 152 | 153 | def __call__(self, spec: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 154 | """ 155 | spec: C x H x W 156 | """ 157 | mask = torch.zeros_like(spec) 158 | mask[:, :, self.start_idx:self.end_idx] = 1 159 | masked_and_noised_spec = mask_with_noise(spec, mask, self.fill_noise_level) 160 | return masked_and_noised_spec, mask 161 | -------------------------------------------------------------------------------- /ETTA/unwrap_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 
3 | # modified from stable-audio-tools under the MIT license 4 | 5 | import argparse 6 | import json 7 | import torch 8 | from torch.nn.parameter import Parameter 9 | from stable_audio_tools.models import create_model_from_config 10 | from stable_audio_tools.models.utils import remove_weight_norm_from_model 11 | from stable_audio_tools.utils.addict import Dict as AttrDict 12 | from pprint import pprint 13 | 14 | if __name__ == '__main__': 15 | args = argparse.ArgumentParser() 16 | args.add_argument('--model-config', type=str, default=None) 17 | args.add_argument('--ckpt-path', type=str, default=None) 18 | args.add_argument('--name', type=str, default='exported_model') 19 | args.add_argument('--use-safetensors', action='store_true') 20 | 21 | args = args.parse_args() 22 | 23 | with open(args.model_config) as f: 24 | model_config = json.load(f) 25 | 26 | # convert it to AttrDict (dot-accessible dictionary) 27 | model_config = AttrDict(model_config) 28 | 29 | # to load config.json from experiment with potential overridden params 30 | if "model_config" in model_config.keys(): 31 | model_config = model_config["model_config"] 32 | pprint(model_config) 33 | 34 | model = create_model_from_config(model_config) 35 | 36 | # Remove weight_norm from the pretransform if specified (during training), to load WN-less weight properly from ckpt 37 | # either "pre_load" or "post_load" will work 38 | if model_config.get("remove_pretransform_weight_norm", ''): 39 | remove_weight_norm_from_model(model.pretransform) 40 | 41 | model_type = model_config.get('model_type', None) 42 | 43 | assert model_type is not None, 'model_type must be specified in model config' 44 | 45 | training_config = model_config.get('training', None) 46 | 47 | if model_type == 'autoencoder': 48 | from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper 49 | 50 | ema_copy = None 51 | 52 | if training_config.get("use_ema", False): 53 | from stable_audio_tools.models.factory import create_model_from_config 54 | ema_copy = create_model_from_config(model_config) 55 | # ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once 56 | 57 | # Copy each weight to the ema copy 58 | for name, param in model.state_dict().items(): 59 | if isinstance(param, Parameter): 60 | # backwards compatibility for serialized parameters 61 | param = param.data 62 | ema_copy.state_dict()[name].copy_(param) 63 | 64 | use_ema = training_config.get("use_ema", False) 65 | 66 | training_wrapper = AutoencoderTrainingWrapper.load_from_checkpoint( 67 | args.ckpt_path, 68 | autoencoder=model, 69 | strict=False, 70 | sample_rate=model.sample_rate, 71 | loss_config=training_config["loss_configs"], 72 | use_ema=training_config["use_ema"], 73 | ema_copy=ema_copy if use_ema else None 74 | ) 75 | elif model_type == 'diffusion_uncond': 76 | from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper 77 | training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False) 78 | 79 | elif model_type == 'diffusion_autoencoder': 80 | from stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper 81 | 82 | ema_copy = create_model_from_config(model_config) 83 | 84 | for name, param in model.state_dict().items(): 85 | if isinstance(param, Parameter): 86 | # backwards compatibility for serialized parameters 87 | param = param.data 88 | ema_copy.state_dict()[name].copy_(param) 89 | 90 | training_wrapper = 
DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, ema_copy=ema_copy, strict=False) 91 | elif model_type == 'diffusion_cond': 92 | from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper 93 | 94 | use_ema = training_config.get("use_ema", True) 95 | 96 | training_wrapper = DiffusionCondTrainingWrapper.load_from_checkpoint( 97 | args.ckpt_path, 98 | model=model, 99 | use_ema=use_ema, 100 | lr=training_config.get("learning_rate", None), 101 | optimizer_configs=training_config.get("optimizer_configs", None), 102 | strict=False 103 | ) 104 | elif model_type == 'diffusion_cond_inpaint': 105 | from stable_audio_tools.training.diffusion import DiffusionCondInpaintTrainingWrapper 106 | training_wrapper = DiffusionCondInpaintTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False) 107 | elif model_type == 'diffusion_prior': 108 | from stable_audio_tools.training.diffusion import DiffusionPriorTrainingWrapper 109 | 110 | ema_copy = create_model_from_config(model_config) 111 | 112 | for name, param in model.state_dict().items(): 113 | if isinstance(param, Parameter): 114 | # backwards compatibility for serialized parameters 115 | param = param.data 116 | ema_copy.state_dict()[name].copy_(param) 117 | 118 | training_wrapper = DiffusionPriorTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False, ema_copy=ema_copy) 119 | elif model_type == 'lm': 120 | from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper 121 | 122 | ema_copy = None 123 | 124 | if training_config.get("use_ema", False): 125 | 126 | ema_copy = create_model_from_config(model_config) 127 | 128 | for name, param in model.state_dict().items(): 129 | if isinstance(param, Parameter): 130 | # backwards compatibility for serialized parameters 131 | param = param.data 132 | ema_copy.state_dict()[name].copy_(param) 133 | 134 | training_wrapper = AudioLanguageModelTrainingWrapper.load_from_checkpoint( 135 | args.ckpt_path, 136 | model=model, 137 | strict=False, 138 | ema_copy=ema_copy, 139 | optimizer_configs=training_config.get("optimizer_configs", None) 140 | ) 141 | 142 | else: 143 | raise ValueError(f"Unknown model type {model_type}") 144 | 145 | print(f"Loaded model from {args.ckpt_path}") 146 | 147 | if args.use_safetensors: 148 | ckpt_path = f"{args.name}.safetensors" 149 | else: 150 | ckpt_path = f"{args.name}.ckpt" 151 | 152 | training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors) 153 | 154 | print(f"Exported model to {ckpt_path}") -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/models/factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 
3 | # modified from stable-audio-tools under the MIT license 4 | 5 | import json 6 | 7 | def create_model_from_config(model_config): 8 | model_type = model_config.get('model_type', None) 9 | 10 | assert model_type is not None, 'model_type must be specified in model config' 11 | 12 | if model_type == 'autoencoder': 13 | from .autoencoders import create_autoencoder_from_config 14 | return create_autoencoder_from_config(model_config) 15 | elif model_type == 'diffusion_uncond': 16 | from .diffusion import create_diffusion_uncond_from_config 17 | return create_diffusion_uncond_from_config(model_config) 18 | elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior": 19 | from .diffusion import create_diffusion_cond_from_config 20 | return create_diffusion_cond_from_config(model_config) 21 | elif model_type == 'diffusion_autoencoder': 22 | from .autoencoders import create_diffAE_from_config 23 | return create_diffAE_from_config(model_config) 24 | elif model_type == 'lm': 25 | from .lm import create_audio_lm_from_config 26 | return create_audio_lm_from_config(model_config) 27 | else: 28 | raise NotImplementedError(f'Unknown model type: {model_type}') 29 | 30 | def create_model_from_config_path(model_config_path): 31 | with open(model_config_path) as f: 32 | model_config = json.load(f) 33 | 34 | return create_model_from_config(model_config) 35 | 36 | def create_pretransform_from_config(pretransform_config, sample_rate): 37 | pretransform_type = pretransform_config.get('type', None) 38 | 39 | assert pretransform_type is not None, 'type must be specified in pretransform config' 40 | 41 | if pretransform_type == 'autoencoder': 42 | from .autoencoders import create_autoencoder_from_config 43 | from .pretransforms import AutoencoderPretransform 44 | 45 | # Create fake top-level config to pass sample rate to autoencoder constructor 46 | # This is a bit of a hack but it keeps us from re-defining the sample rate in the config 47 | autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} 48 | autoencoder = create_autoencoder_from_config(autoencoder_config) 49 | 50 | scale = pretransform_config.get("scale", 1.0) 51 | model_half = pretransform_config.get("model_half", False) 52 | iterate_batch = pretransform_config.get("iterate_batch", False) 53 | chunked = pretransform_config.get("chunked", False) 54 | 55 | pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) 56 | elif pretransform_type == 'wavelet': 57 | from .pretransforms import WaveletPretransform 58 | 59 | wavelet_config = pretransform_config["config"] 60 | channels = wavelet_config["channels"] 61 | levels = wavelet_config["levels"] 62 | wavelet = wavelet_config["wavelet"] 63 | 64 | pretransform = WaveletPretransform(channels, levels, wavelet) 65 | elif pretransform_type == 'pqmf': 66 | from .pretransforms import PQMFPretransform 67 | pqmf_config = pretransform_config["config"] 68 | pretransform = PQMFPretransform(**pqmf_config) 69 | elif pretransform_type == 'dac_pretrained': 70 | from .pretransforms import PretrainedDACPretransform 71 | pretrained_dac_config = pretransform_config["config"] 72 | pretransform = PretrainedDACPretransform(**pretrained_dac_config) 73 | elif pretransform_type == "audiocraft_pretrained": 74 | from .pretransforms import AudiocraftCompressionPretransform 75 | 76 | audiocraft_config = pretransform_config["config"] 77 | pretransform = 
AudiocraftCompressionPretransform(**audiocraft_config) 78 | else: 79 | raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') 80 | 81 | enable_grad = pretransform_config.get('enable_grad', False) 82 | pretransform.enable_grad = enable_grad 83 | 84 | pretransform.eval().requires_grad_(pretransform.enable_grad) 85 | 86 | return pretransform 87 | 88 | def create_bottleneck_from_config(bottleneck_config): 89 | bottleneck_type = bottleneck_config.get('type', None) 90 | 91 | assert bottleneck_type is not None, 'type must be specified in bottleneck config' 92 | 93 | if bottleneck_type == 'tanh': 94 | from .bottleneck import TanhBottleneck 95 | bottleneck = TanhBottleneck() 96 | elif bottleneck_type == 'vae': 97 | from .bottleneck import VAEBottleneck 98 | bottleneck = VAEBottleneck() 99 | elif bottleneck_type == 'rvq': 100 | from .bottleneck import RVQBottleneck 101 | 102 | quantizer_params = { 103 | "dim": 128, 104 | "codebook_size": 1024, 105 | "num_quantizers": 8, 106 | "decay": 0.99, 107 | "kmeans_init": True, 108 | "kmeans_iters": 50, 109 | "threshold_ema_dead_code": 2, 110 | } 111 | 112 | quantizer_params.update(bottleneck_config["config"]) 113 | 114 | bottleneck = RVQBottleneck(**quantizer_params) 115 | elif bottleneck_type == "dac_rvq": 116 | from .bottleneck import DACRVQBottleneck 117 | 118 | bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) 119 | 120 | elif bottleneck_type == 'rvq_vae': 121 | from .bottleneck import RVQVAEBottleneck 122 | 123 | quantizer_params = { 124 | "dim": 128, 125 | "codebook_size": 1024, 126 | "num_quantizers": 8, 127 | "decay": 0.99, 128 | "kmeans_init": True, 129 | "kmeans_iters": 50, 130 | "threshold_ema_dead_code": 2, 131 | } 132 | 133 | quantizer_params.update(bottleneck_config["config"]) 134 | 135 | bottleneck = RVQVAEBottleneck(**quantizer_params) 136 | 137 | elif bottleneck_type == 'dac_rvq_vae': 138 | from .bottleneck import DACRVQVAEBottleneck 139 | bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) 140 | elif bottleneck_type == 'l2_norm': 141 | from .bottleneck import L2Bottleneck 142 | bottleneck = L2Bottleneck() 143 | elif bottleneck_type == "wasserstein": 144 | from .bottleneck import WassersteinBottleneck 145 | bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) 146 | elif bottleneck_type == "fsq": 147 | from .bottleneck import FSQBottleneck 148 | bottleneck = FSQBottleneck(**bottleneck_config["config"]) 149 | else: 150 | raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') 151 | 152 | requires_grad = bottleneck_config.get('requires_grad', True) 153 | if not requires_grad: 154 | for param in bottleneck.parameters(): 155 | param.requires_grad = False 156 | 157 | return bottleneck 158 | -------------------------------------------------------------------------------- /ETTA/stable_audio_tools/utils/addict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | # copied and modified from https://github.com/mewwts/addict/blob/master/addict/addict.py under the MIT license. 4 | # with additional update_params() for convinence 5 | # LICENSE is in LICENCES directory. 
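# Usage sketch (illustrative only; the config keys and values below are hypothetical):
#
#   cfg = Dict({"model": {"lr": 1e-4}})
#   cfg.model.lr                                    # attribute-style access -> 0.0001
#   cfg.update_params(["model.lr=5e-5",             # overrides an existing key
#                      "training.use_ema=true",     # "true"/"false" strings become booleans
#                      "model.betas=(0.9,0.999)"])  # tuple literals are converted to lists
#
# Dotted keys create the intermediate nested Dicts if they do not exist yet.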
6 | 7 | import copy 8 | from typing import List 9 | import ast 10 | 11 | class Dict(dict): 12 | 13 | def __init__(__self, *args, **kwargs): 14 | object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) 15 | object.__setattr__(__self, '__key', kwargs.pop('__key', None)) 16 | object.__setattr__(__self, '__frozen', False) 17 | for arg in args: 18 | if not arg: 19 | continue 20 | elif isinstance(arg, dict): 21 | for key, val in arg.items(): 22 | __self[key] = __self._hook(val) 23 | elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): 24 | __self[arg[0]] = __self._hook(arg[1]) 25 | else: 26 | for key, val in iter(arg): 27 | __self[key] = __self._hook(val) 28 | 29 | for key, val in kwargs.items(): 30 | __self[key] = __self._hook(val) 31 | 32 | def __setattr__(self, name, value): 33 | if hasattr(self.__class__, name): 34 | raise AttributeError("'Dict' object attribute " 35 | "'{0}' is read-only".format(name)) 36 | else: 37 | self[name] = value 38 | 39 | def __setitem__(self, name, value): 40 | isFrozen = (hasattr(self, '__frozen') and 41 | object.__getattribute__(self, '__frozen')) 42 | if isFrozen and name not in super(Dict, self).keys(): 43 | raise KeyError(name) 44 | super(Dict, self).__setitem__(name, value) 45 | try: 46 | p = object.__getattribute__(self, '__parent') 47 | key = object.__getattribute__(self, '__key') 48 | except AttributeError: 49 | p = None 50 | key = None 51 | if p is not None: 52 | p[key] = self 53 | object.__delattr__(self, '__parent') 54 | object.__delattr__(self, '__key') 55 | 56 | def __add__(self, other): 57 | if not self.keys(): 58 | return other 59 | else: 60 | self_type = type(self).__name__ 61 | other_type = type(other).__name__ 62 | msg = "unsupported operand type(s) for +: '{}' and '{}'" 63 | raise TypeError(msg.format(self_type, other_type)) 64 | 65 | @classmethod 66 | def _hook(cls, item): 67 | if isinstance(item, dict): 68 | return cls(item) 69 | elif isinstance(item, (list, tuple)): 70 | return type(item)(cls._hook(elem) for elem in item) 71 | return item 72 | 73 | def __getattr__(self, item): 74 | return self.__getitem__(item) 75 | 76 | def __missing__(self, name): 77 | if object.__getattribute__(self, '__frozen'): 78 | raise KeyError(name) 79 | return self.__class__(__parent=self, __key=name) 80 | 81 | def __delattr__(self, name): 82 | del self[name] 83 | 84 | def to_dict(self): 85 | base = {} 86 | for key, value in self.items(): 87 | if isinstance(value, type(self)): 88 | base[key] = value.to_dict() 89 | elif isinstance(value, (list, tuple)): 90 | base[key] = type(value)( 91 | item.to_dict() if isinstance(item, type(self)) else 92 | item for item in value) 93 | else: 94 | base[key] = value 95 | return base 96 | 97 | def copy(self): 98 | return copy.copy(self) 99 | 100 | def deepcopy(self): 101 | return copy.deepcopy(self) 102 | 103 | def __deepcopy__(self, memo): 104 | other = self.__class__() 105 | memo[id(self)] = other 106 | for key, value in self.items(): 107 | other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) 108 | return other 109 | 110 | def update(self, *args, **kwargs): 111 | other = {} 112 | if args: 113 | if len(args) > 1: 114 | raise TypeError() 115 | other.update(args[0]) 116 | other.update(kwargs) 117 | for k, v in other.items(): 118 | if ((k not in self) or 119 | (not isinstance(self[k], dict)) or 120 | (not isinstance(v, dict))): 121 | self[k] = v 122 | else: 123 | self[k].update(v) 124 | 125 | def __getnewargs__(self): 126 | return tuple(self.items()) 127 | 128 | def __getstate__(self): 129 | return self 
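    # Pickle support: __getnewargs__ and __getstate__ above hand the mapping's contents
    # to pickle, and __setstate__ below restores them via update(); deep copies are
    # handled separately by __deepcopy__.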
130 | 131 | def __setstate__(self, state): 132 | self.update(state) 133 | 134 | def __or__(self, other): 135 | if not isinstance(other, (Dict, dict)): 136 | return NotImplemented 137 | new = Dict(self) 138 | new.update(other) 139 | return new 140 | 141 | def __ror__(self, other): 142 | if not isinstance(other, (Dict, dict)): 143 | return NotImplemented 144 | new = Dict(other) 145 | new.update(self) 146 | return new 147 | 148 | def __ior__(self, other): 149 | self.update(other) 150 | return self 151 | 152 | def setdefault(self, key, default=None): 153 | if key in self: 154 | return self[key] 155 | else: 156 | self[key] = default 157 | return default 158 | 159 | def freeze(self, shouldFreeze=True): 160 | object.__setattr__(self, '__frozen', shouldFreeze) 161 | for key, val in self.items(): 162 | if isinstance(val, Dict): 163 | val.freeze(shouldFreeze) 164 | 165 | def unfreeze(self): 166 | self.freeze(False) 167 | 168 | def str_to_bool(self, value): 169 | """Converts string to boolean if applicable.""" 170 | if value.lower() in ('true'): 171 | return True 172 | elif value.lower() in ('false'): 173 | return False 174 | return None 175 | 176 | def update_params(self, params: List[str]): 177 | """Overrides self contents with params.""" 178 | for param in params: 179 | # Split only on the first '=' to avoid issues with values containing '=' 180 | k, v = param.split("=", 1) # Split only at the first '=' 181 | boolean_value = self.str_to_bool(v) 182 | if boolean_value is None: # str_to_bool did not return a boolean, try other conversions 183 | try: 184 | v = ast.literal_eval(v) 185 | if isinstance(v, tuple): 186 | print(f"[INFO] converting {k}: {v} tuple value to list for compatibility") 187 | v = list(v) 188 | except (ValueError, SyntaxError): # Catch SyntaxError as well for malformed literals 189 | pass # v remains a string if it cannot be evaluated to a Python literal 190 | else: 191 | v = boolean_value # Use the boolean value returned by str_to_bool 192 | 193 | k_split = k.split('.') 194 | current = self 195 | for part in k_split[:-1]: # Navigate/create intermediate Dicts for nested keys 196 | if part not in current or not isinstance(current[part], Dict): 197 | current[part] = Dict() # Ensure nested dicts are Dict instances 198 | current = current[part] 199 | 200 | final_key = k_split[-1] 201 | if final_key in current: 202 | print(f"[INFO] overriding {final_key} with {v}") 203 | else: 204 | print(f"[WARNING] new param {final_key} with {v}") 205 | current[final_key] = v -------------------------------------------------------------------------------- /ETTA/docs/diffusion.md: -------------------------------------------------------------------------------- 1 | # Diffusion 2 | 3 | Diffusion models learn to denoise data 4 | 5 | # Model configs 6 | The model config file for a diffusion model should set the `model_type` to `diffusion_cond` if the model uses conditioning, or `diffusion_uncond` if it does not, and the `model` object should have the following properties: 7 | 8 | - `diffusion` 9 | - The configuration for the diffusion model itself. See below for more information on the diffusion model config 10 | - `pretransform` 11 | - The configuration of the diffusion model's [pretransform](pretransforms.md), such as an autoencoder for latent diffusion. 
12 | - Optional 13 | - `conditioning` 14 | - The configuration of the various [conditioning](conditioning.md) modules for the diffusion model 15 | - Only required for `diffusion_cond` 16 | - `io_channels` 17 | - The base number of input/output channels for the diffusion model 18 | - Used by inference scripts to determine the shape of the noise to generate for the diffusion model 19 | 20 | # Diffusion configs 21 | - `type` 22 | - The underlying model type for the diffusion model 23 | - For conditioned diffusion models, this must be one of `dit` ([Diffusion Transformer](#diffusion-transformers-dit)), `DAU1d` ([Dance Diffusion U-Net](#dance-diffusion-u-net)), or `adp_cfg_1d` ([audio-diffusion-pytorch U-Net](#audio-diffusion-pytorch-u-net-adp)) 24 | - Unconditioned diffusion models can also use `adp_1d` 25 | - `cross_attention_cond_ids` 26 | - Conditioner ids for conditioning information to be used as cross-attention input 27 | - If multiple ids are specified, the conditioning tensors will be concatenated along the sequence dimension 28 | - `global_cond_ids` 29 | - Conditioner ids for conditioning information to be used as global conditioning input 30 | - If multiple ids are specified, the conditioning tensors will be concatenated along the channel dimension 31 | - `prepend_cond_ids` 32 | - Conditioner ids for conditioning information to be prepended to the model input 33 | - If multiple ids are specified, the conditioning tensors will be concatenated along the sequence dimension 34 | - Only works with diffusion transformer models 35 | - `input_concat_ids` 36 | - Conditioner ids for conditioning information to be concatenated to the model input 37 | - If multiple ids are specified, the conditioning tensors will be concatenated along the channel dimension 38 | - If the conditioning tensors are not the same length as the model input, they will be interpolated along the sequence dimension to be the same length. 39 | - The interpolation algorithm is model-dependent, but usually uses nearest-neighbor resampling. 40 | - `config` 41 | - The configuration for the model backbone itself 42 | - Model-dependent 43 | 44 | # Training configs 45 | The `training` config in the diffusion model config file should have the following properties: 46 | 47 | - `learning_rate` 48 | - The learning rate to use during training 49 | - Defaults to a constant learning rate; can be overridden with `optimizer_configs` 50 | - `use_ema` 51 | - If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights. 52 | - Optional. Default: `true` 53 | - `log_loss_info` 54 | - If true, additional diffusion loss info will be gathered across all GPUs and displayed during training 55 | - Optional. 
Default: `false` 56 | - `loss_configs` 57 | - Configurations for the loss function calculation 58 | - Optional 59 | - `optimizer_configs` 60 | - Configuration for optimizers and schedulers 61 | - Optional, overrides `learning_rate` 62 | - `demo` 63 | - Configuration for the demos during training, including conditioning information 64 | 65 | ## Example config 66 | ```json 67 | "training": { 68 | "use_ema": true, 69 | "log_loss_info": false, 70 | "optimizer_configs": { 71 | "diffusion": { 72 | "optimizer": { 73 | "type": "AdamW", 74 | "config": { 75 | "lr": 5e-5, 76 | "betas": [0.9, 0.999], 77 | "weight_decay": 1e-3 78 | } 79 | }, 80 | "scheduler": { 81 | "type": "InverseLR", 82 | "config": { 83 | "inv_gamma": 1000000, 84 | "power": 0.5, 85 | "warmup": 0.99 86 | } 87 | } 88 | } 89 | }, 90 | "demo": { ... } 91 | } 92 | ``` 93 | 94 | # Demo configs 95 | The `demo` config in the diffusion model training config should have the following properties: 96 | - `demo_every` 97 | - How many training steps between demos 98 | - `demo_steps` 99 | - Number of diffusion timesteps to run for the demos 100 | - `num_demos` 101 | - This is the number of examples to generate in each demo 102 | - `demo_cond` 103 | - For conditioned diffusion models, this is the conditioning metadata to provide to each example, provided as a list 104 | - NOTE: List must be the same length as `num_demos` 105 | - `demo_cfg_scales` 106 | - For conditioned diffusion models, this provides a list of classifier-free guidance (CFG) scales to render during the demos. This can be helpful to get an idea of how the model responds to different conditioning strengths as training continues. 107 | 108 | ## Example config 109 | ```json 110 | "demo": { 111 | "demo_every": 2000, 112 | "demo_steps": 250, 113 | "num_demos": 4, 114 | "demo_cond": [ 115 | {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 80}, 116 | {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 250}, 117 | {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180}, 118 | {"prompt": "A grand orchestral arrangement", "seconds_start": 0, "seconds_total": 190} 119 | ], 120 | "demo_cfg_scales": [3, 6, 9] 121 | } 122 | ``` 123 | 124 | # Model types 125 | 126 | A variety of different model types can be used as the underlying backbone for a diffusion model. At the moment, this includes variants of U-Net and Transformer models. 127 | 128 | ## Diffusion Transformers (DiT) 129 | 130 | Transformers tend to consistently outperform U-Nets in terms of model quality, but are much more memory- and compute-intensive and work best on shorter sequences such as latent encodings of audio. 131 | 132 | ### Continuous Transformer 133 | 134 | This is our custom implementation of a transformer model, based on the `x-transformers` implementation, but with efficiency improvements such as fused QKV layers, and Flash Attention 2 support. 135 | 136 | ### `x-transformers` 137 | 138 | This model type uses the `ContinuousTransformerWrapper` class from the https://github.com/lucidrains/x-transformers repository as the diffusion transformer backbone. 139 | 140 | `x-transformers` is a great baseline transformer implementation with lots of options for various experimental settings. 
141 | It's great for testing out experimental features without implementing them yourself, but the implementations might not be fully optimized, and breaking changes may be introduced without much warning. 142 | 143 | ## Diffusion U-Net 144 | 145 | U-Nets use a hierarchical architecture to gradually downsample the input data before more heavy processing is performed, then upsample the data again, using skip connections to pass data across the downsampling "valley" (the "U" in the name) to the upsampling layer at the same resolution. 146 | 147 | ### audio-diffusion-pytorch U-Net (ADP) 148 | 149 | This model type uses a modified implementation of the `UNetCFG1D` class from version 0.0.94 of the `https://github.com/archinetai/audio-diffusion-pytorch` repo, with added Flash Attention support. 150 | 151 | ### Dance Diffusion U-Net 152 | 153 | This is a reimplementation of the U-Net used in [Dance Diffusion](https://github.com/Harmonai-org/sample-generator). It has minimal conditioning support, only really supporting global conditioning. Mostly used for unconditional diffusion models. -------------------------------------------------------------------------------- /ETTA/docs/conditioning.md: -------------------------------------------------------------------------------- 1 | # Conditioning 2 | Conditioning, in the context of `stable-audio-tools` is the use of additional signals in a model that are used to add an additional level of control over the model's behavior. For example, we can condition the outputs of a diffusion model on a text prompt, creating a text-to-audio model. 3 | 4 | # Conditioning types 5 | There are a few different kinds of conditioning depending on the conditioning signal being used. 6 | 7 | ## Cross attention 8 | Cross attention is a type of conditioning that allows us to find correlations between two sequences of potentially different lengths. For example, cross attention allows us to find correlations between a sequence of features from a text encoder and a sequence of high-level audio features. 9 | 10 | Signals used for cross-attention conditioning should be of the shape `[batch, sequence, channels]`. 11 | 12 | ## Global conditioning 13 | Global conditioning is the use of a single n-dimensional tensor to provide conditioning information that pertains to the whole sequence being conditioned. For example, this could be the single embedding output of a CLAP model, or a learned class embedding. 14 | 15 | Signals used for global conditioning should be of the shape `[batch, channels]`. 16 | 17 | ## Prepend conditioning 18 | Prepend conditioning involves prepending the conditioning tokens to the data tokens in the model, allowing for the information to be interpreted through the model's self-attention mechanism. 19 | 20 | This kind of conditioning is currently only supported by Transformer-based models such as diffusion transformers. 21 | 22 | Signals used for prepend conditioning should be of the shape `[batch, sequence, channels]`. 23 | 24 | ## Input concatenation 25 | Input concatenation applies a spatial conditioning signal to the model that correlates in the sequence dimension with the model's input, and is of the same length. The conditioning signal will be concatenated with the model's input data along the channel dimension. This can be used for things like inpainting information, melody conditioning, or for creating a diffusion autoencoder. 
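As a rough sketch (not the actual `stable-audio-tools` implementation), input concatenation boils down to a channel-wise concatenation of the conditioning signal with the model input; the function and tensor names below are illustrative:

```python
import torch

def apply_input_concat(x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    # x:    model input,               shape [batch, channels, sequence]
    # cond: input-concat conditioning, shape [batch, cond_channels, sequence]
    assert cond.shape[-1] == x.shape[-1], "conditioning must match the input length"
    # Concatenate along the channel dimension; the backbone's input layer has to
    # account for the extra cond_channels.
    return torch.cat([x, cond], dim=1)
```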
26 | 27 | Signals used for input concatenation conditioning should be of the shape `[batch, channels, sequence]` and must be the same length as the model's input. 28 | 29 | # Conditioners and conditioning configs 30 | `stable-audio-tools` uses Conditioner modules to translate human-readable metadata such as text prompts or a number of seconds into tensors that the model can take as input. 31 | 32 | Each conditioner has a corresponding `id` that it expects to find in the conditioning dictionary provided during training or inference. Each conditioner takes in the relevant conditioning data and returns a tuple containing the corresponding tensor and a mask. 33 | 34 | The ConditionedDiffusionModelWrapper manages the translation between the user-provided metadata dictionary (e.g. `{"prompt": "a beautiful song", "seconds_start": 22, "seconds_total": 193}`) and the dictionary of different conditioning types that the model uses (e.g. `{"cross_attn_cond": ...}`). 35 | 36 | To apply conditioning to a model, you must provide a `conditioning` configuration in the model's config. At the moment, we only support conditioning diffusion models through the `diffusion_cond` model type. 37 | 38 | The `conditioning` configuration should contain a `configs` array, which allows you to define multiple conditioning signals. 39 | 40 | Each item in the `configs` array should define the `id` for the corresponding metadata, the type of conditioner to be used, and the config for that conditioner. 41 | 42 | The `cond_dim` property is used to enforce the same dimension on all conditioning inputs; however, that can be overridden with an explicit `output_dim` property on any of the individual configs. 43 | 44 | ## Example config 45 | ```json 46 | "conditioning": { 47 | "configs": [ 48 | { 49 | "id": "prompt", 50 | "type": "t5", 51 | "config": { 52 | "t5_model_name": "t5-base", 53 | "max_length": 77, 54 | "project_out": true 55 | } 56 | } 57 | ], 58 | "cond_dim": 768 59 | } 60 | ``` 61 | 62 | # Conditioners 63 | 64 | ## Text encoders 65 | 66 | ### `t5` 67 | This uses a frozen [T5](https://huggingface.co/docs/transformers/model_doc/t5) text encoder from the `transformers` library to encode text prompts into a sequence of text features. 68 | 69 | The `t5_model_name` property determines which T5 model is loaded from the `transformers` library. 70 | 71 | The `max_length` property determines the maximum number of tokens that the text encoder will take in, as well as the sequence length of the output text features. 72 | 73 | If you set `enable_grad` to `true`, the T5 model will be un-frozen and saved with the model checkpoint, allowing you to fine-tune the T5 model. 74 | 75 | T5 encodings are only compatible with cross attention conditioning. 76 | 77 | #### Example config 78 | ```json 79 | { 80 | "id": "prompt", 81 | "type": "t5", 82 | "config": { 83 | "t5_model_name": "t5-base", 84 | "max_length": 77, 85 | "project_out": true 86 | } 87 | } 88 | ``` 89 | 90 | ### `clap_text` 91 | This loads the text encoder from a [CLAP](https://github.com/LAION-AI/CLAP) model, which can provide either a sequence of text features, or a single multimodal text/audio embedding. 92 | 93 | The CLAP model must be provided with a local file path, set in the `clap_ckpt_path` property, along with the correct `audio_model_type` and `enable_fusion` properties for the provided model. 94 | 95 | If the `use_text_features` property is set to `true`, the conditioner output will be a sequence of text features, instead of a single multimodal embedding. 
This allows for more fine-grained text information to be used by the model, at the cost of losing the ability to prompt with CLAP audio embeddings. 96 | 97 | By default, if `use_text_features` is true, the last layer of the CLAP text encoder's features are returned. You can return the text features of earlier layers by specifying the index of the layer to return in the `feature_layer_ix` property. For example, you can return the text features of the next-to-last layer of the CLAP model by setting `feature_layer_ix` to `-2`. 98 | 99 | If you set `enable_grad` to `true`, the CLAP model will be un-frozen and saved with the model checkpoint, allowing you to fine-tune the CLAP model. 100 | 101 | CLAP text embeddings are compatible with global conditioning and cross attention conditioning. If `use_text_features` is set to `true`, the features are not compatible with global conditioning. 102 | 103 | #### Example config 104 | ```json 105 | { 106 | "id": "prompt", 107 | "type": "clap_text", 108 | "config": { 109 | "clap_ckpt_path": "/path/to/clap/model.ckpt", 110 | "audio_model_type": "HTSAT-base", 111 | "enable_fusion": true, 112 | "use_text_features": true, 113 | "feature_layer_ix": -2 114 | } 115 | } 116 | ``` 117 | 118 | ## Number encoders 119 | 120 | ### `int` 121 | The IntConditioner takes in a list of integers in a given range, and returns a discrete learned embedding for each of those integers. 122 | 123 | The `min_val` and `max_val` properties set the range of the embedding values. Input integers are clamped to this range. 124 | 125 | This can be used for things like discrete timing embeddings, or learned class embeddings. 126 | 127 | Int embeddings are compatible with global conditioning and cross attention conditioning. 128 | 129 | #### Example config 130 | ```json 131 | { 132 | "id": "seconds_start", 133 | "type": "int", 134 | "config": { 135 | "min_val": 0, 136 | "max_val": 512 137 | } 138 | } 139 | ``` 140 | 141 | ### `number` 142 | The NumberConditioner takes in a a list of floats in a given range, and returns a continuous Fourier embedding of the provided floats. 143 | 144 | The `min_val` and `max_val` properties set the range of the float values. This is the range used to normalize the input float values. 145 | 146 | Number embeddings are compatible with global conditioning and cross attention conditioning. 147 | 148 | #### Example config 149 | ```json 150 | { 151 | "id": "seconds_total", 152 | "type": "number", 153 | "config": { 154 | "min_val": 0, 155 | "max_val": 512 156 | } 157 | } 158 | ``` -------------------------------------------------------------------------------- /A2SB/datasets/datamodule.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | 9 | import lightning.pytorch as pl 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from typing import Optional, List 13 | from datasets.datasets import MixAudioDataset, FullSequencePredictDataset 14 | from utils import SequenceLength 15 | import numpy as np 16 | 17 | def collate_fn(batch): 18 | all_lengths = torch.tensor([mel.shape[1] for mel in batch]) 19 | 20 | sequence_lengths = SequenceLength(all_lengths) 21 | max_length = all_lengths.max() 22 | channels = batch[0].shape[0] 23 | # pad length to closest power of 2 24 | # we assume mel channels fixed at 80, which is divisible by 2 up to 3 times 25 | padded_length = int(2**np.ceil(np.log2(all_lengths.max()))) 26 | all_mels = torch.zeros((len(batch), channels, padded_length)) 27 | for ind in range(len(batch)): 28 | all_mels[ind, :, :all_lengths[ind]] = torch.tensor(batch[ind]) 29 | 30 | output_dict = {'mels': all_mels, 'seq_lens': sequence_lengths} 31 | return output_dict 32 | 33 | 34 | class STFTAudioDataModule(pl.LightningDataModule): 35 | def __init__(self, 36 | mix_dataset_config={}, 37 | segment_length=2**16, 38 | sampling_rate=22050, 39 | num_workers=0, 40 | batch_size=8, 41 | transforms_gt=[], 42 | transforms_aug=[], 43 | transforms_aug_val=[], 44 | eval_transforms_aug=[], 45 | train_max_samples=None, 46 | val_max_samples=100, 47 | predict_filelist=[], 48 | predict_start_idx=0, 49 | predict_end_idx=None 50 | ): 51 | super().__init__() 52 | # self.fft_size = fft_size 53 | # self.hop_size = hop_size 54 | # self.win_length = win_length 55 | # self.window_type = window_type 56 | self.mix_dataset_config = mix_dataset_config 57 | self.segment_length = segment_length 58 | self.sampling_rate = sampling_rate 59 | self.num_workers = num_workers 60 | self.batch_size = batch_size 61 | self.transforms_aug = transforms_aug 62 | if not transforms_aug_val: 63 | self.transforms_aug_val = transforms_aug 64 | else: 65 | self.transforms_aug_val = transforms_aug_val # validation set augmentation (randomness is fixed) 66 | self.transforms_gt = transforms_gt 67 | self.eval_transforms_aug = eval_transforms_aug 68 | self.train_max_samples = train_max_samples 69 | self.val_max_samples = val_max_samples 70 | self.predict_filelist = predict_filelist 71 | self.predict_start_idx = predict_start_idx 72 | self.predict_end_idx = predict_end_idx 73 | 74 | 75 | def prepare_data(self): 76 | pass 77 | 78 | def setup(self, stage: str): 79 | if stage == "fit": 80 | print("initializing training dataset") 81 | self.trainset = MixAudioDataset( 82 | mix_dataset_config = self.mix_dataset_config, 83 | split = 'train', 84 | segment_length=self.segment_length, 85 | sampling_rate=self.sampling_rate, 86 | transforms_aug = self.transforms_aug, 87 | transforms_gt = self.transforms_gt, 88 | eval_transforms_aug = self.eval_transforms_aug, 89 | max_samples = self.train_max_samples 90 | ) 91 | 92 | # pass a list of validation datasets 93 | self.valset = [] 94 | i = 0 95 | for valset_name in self.mix_dataset_config: 96 | single_val_dataset_config = {valset_name: self.mix_dataset_config[valset_name]} 97 | valset_i = MixAudioDataset( 98 | mix_dataset_config=single_val_dataset_config, # instead of self.mix_dataset_config 99 | split='validation', 100 | segment_length=self.segment_length, 101 | sampling_rate=self.sampling_rate, 102 | transforms_aug=self.transforms_aug, 103 | transforms_gt=self.transforms_gt, 104 | eval_transforms_aug = self.eval_transforms_aug, 105 | evaluation_mode=True, 106 | 
max_samples = self.val_max_samples 107 | ) 108 | if len(valset_i) > 0: 109 | self.valset.append(valset_i) 110 | print("valset_{}: {}".format(i, valset_name)) 111 | i += 1 112 | 113 | elif stage == "validation": 114 | self.valset = [] 115 | for valset_name in self.mix_dataset_config: 116 | single_val_dataset_config = {valset_name: self.mix_dataset_config[valset_name]} 117 | valset_i = MixAudioDataset( 118 | mix_dataset_config=single_val_dataset_config, # instead of self.mix_dataset_config 119 | split='validation', 120 | segment_length=self.segment_length, 121 | sampling_rate=self.sampling_rate, 122 | transforms_aug=self.transforms_aug, 123 | transforms_gt=self.transforms_gt, 124 | eval_transforms_aug = self.eval_transforms_aug, 125 | evaluation_mode=True, 126 | max_samples = self.val_max_samples 127 | ) 128 | if len(valset_i) > 0: 129 | self.valset.append(valset_i) 130 | print("valset_{}: {}".format(i, valset_name)) 131 | i += 1 132 | 133 | elif stage == "test": 134 | self.testset = [] 135 | for testset_name in self.mix_dataset_config: 136 | single_test_dataset_config = {testset_name: self.mix_dataset_config[testset_name]} 137 | testset_i = MixAudioDataset( 138 | mix_dataset_config=single_test_dataset_config, # instead of self.mix_dataset_config 139 | split='test', 140 | segment_length=self.segment_length, 141 | sampling_rate=self.sampling_rate, 142 | transforms_aug=self.transforms_aug, 143 | transforms_gt=self.transforms_gt, 144 | eval_transforms_aug = self.eval_transforms_aug, 145 | evaluation_mode=True, 146 | max_samples = self.val_max_samples 147 | ) 148 | if len(testset_i) > 0: 149 | self.testset.append(testset_i) 150 | print("testset_{}: {}".format(i, testset_name)) 151 | i += 1 152 | elif stage == "predict": 153 | self.predict_dataset = FullSequencePredictDataset( 154 | audio_file_list=self.predict_filelist, 155 | sampling_rate=self.sampling_rate, 156 | transforms_aug=self.transforms_aug, 157 | transforms_gt=self.transforms_gt, 158 | start_idx=self.predict_start_idx, 159 | end_idx=self.predict_end_idx 160 | ) 161 | else: 162 | raise ValueError("Unimplemented stage in datamodule class") 163 | 164 | 165 | 166 | def train_dataloader(self): 167 | train_loader = DataLoader( 168 | self.trainset, num_workers=self.num_workers, shuffle=True, 169 | batch_size=self.batch_size, pin_memory=False, drop_last=True) 170 | return train_loader 171 | 172 | 173 | def val_dataloader(self): 174 | val_loader = [] 175 | for valset_i in self.valset: 176 | val_loader.append( 177 | DataLoader( 178 | valset_i, # self.valset 179 | num_workers=self.num_workers, shuffle=False, 180 | batch_size=self.batch_size, pin_memory=False, drop_last=False 181 | ) 182 | ) 183 | return val_loader 184 | 185 | 186 | def test_dataloader(self): 187 | print("initializing test dataloader") 188 | test_loader = [] 189 | for testset_i in self.testset: 190 | test_loader.append( 191 | DataLoader( 192 | testset_i, 193 | num_workers=self.num_workers, shuffle=False, 194 | batch_size=self.batch_size, pin_memory=False, drop_last=False 195 | ) 196 | ) 197 | return test_loader 198 | 199 | 200 | def predict_dataloader(self): 201 | return DataLoader( 202 | self.predict_dataset, 203 | num_workers=0, shuffle=False, 204 | batch_size=1, pin_memory=False, 205 | drop_last=False) 206 | -------------------------------------------------------------------------------- /A2SB/inference/A2SB_inpaint_dataset.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for A2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | 10 | # # If there is Error: mkl-service + Intel(R) MKL: MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library. 11 | # os.environ["MKL_THREADING_LAYER"] = "GNU" 12 | # import numpy as np 13 | # os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 14 | 15 | import numpy as np 16 | import json 17 | import argparse 18 | import glob 19 | from subprocess import Popen, PIPE 20 | import yaml 21 | import time 22 | from datetime import datetime 23 | import shutil 24 | import csv 25 | from tqdm import tqdm 26 | from copy import deepcopy 27 | 28 | import librosa 29 | import soundfile as sf 30 | 31 | 32 | def read_standard_csv(root_folder, filename): 33 | all_files = { 34 | "train": [], 35 | "validation": [], 36 | "test": [] 37 | } 38 | 39 | with open(os.path.join(root_folder, filename)) as csvfile: 40 | reader = csv.reader(csvfile, delimiter=',', quotechar='"') 41 | next(reader) 42 | for row in reader: 43 | assert len(row) == 3 44 | split, audio_filename, duration = row 45 | split = split.strip() 46 | audio_filename = audio_filename.strip() 47 | duration = float(duration) 48 | sample_rate = None # 44100 # not used 49 | all_files[split].append((audio_filename, duration, sample_rate)) 50 | 51 | return all_files 52 | 53 | 54 | def load_yaml(file_path): 55 | with open(file_path, 'r') as file: 56 | data = yaml.safe_load(file) 57 | return data 58 | 59 | 60 | def save_yaml(data, prefix="../configs/temp"): 61 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 62 | rnd_num = np.random.rand() 63 | rnd_num = rnd_num - rnd_num % 0.000001 64 | file_name = f"{prefix}_{timestamp}_{rnd_num}.yaml" 65 | with open(file_name, 'w') as f: 66 | yaml.dump(data, f) 67 | return file_name 68 | 69 | 70 | def shell_run_cmd(cmd): 71 | print('running:', cmd) 72 | p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True) 73 | stdout, stderr = p.communicate() 74 | print(stdout) 75 | print(stderr) 76 | 77 | 78 | def inpaint_one_sample(dataset_name, audio_filename, exp_root, exp_name, inpaint_length, inpaint_every, max_segment_length=-1, custom_output_subdir=None): 79 | assert 0 < inpaint_length < inpaint_every < 10 80 | 81 | # get paths ready 82 | if custom_output_subdir is not None: 83 | output_subdir = custom_output_subdir 84 | else: 85 | output_subdir = '_'.join(audio_filename.split('/')[-3:]) # get reasonably short filename 86 | output_subdir = '.'.join(output_subdir.split('.')[:-1]) # remove suffix 87 | 88 | output_dir = os.path.join(exp_root, exp_name, dataset_name, 'inpaint_{}_every_{}'.format(inpaint_length, inpaint_every)) 89 | if os.path.exists(os.path.join(output_dir, output_subdir, 'recon.wav')): 90 | print(audio_filename, ' - already inpainted') 91 | return 92 | # elif custom_output_subdir is None: 93 | # shutil.rmtree(os.path.join(output_dir, output_subdir)) 94 | 95 | # copy true file 96 | os.makedirs(os.path.join(output_dir, output_subdir), exist_ok=True) 97 | audio_suffix = audio_filename.split('.')[-1] 98 | original_target = os.path.join(output_dir, output_subdir, 'original.{}'.format(audio_suffix)) 99 | if not os.path.exists(original_target): 100 | shutil.copy(audio_filename, original_target) 101 | orig_audio, orig_sr = librosa.load(audio_filename, sr=None) 102 | duration = len(orig_audio) / orig_sr 103 | 104 | if 
(custom_output_subdir is None) and max_segment_length > 0 and duration > max_segment_length + 2.0: 105 | # split original audio into segments 106 | assert max_segment_length % inpaint_every == 0, 'max_segment_length={}, inpaint_every={}'.format(max_segment_length, inpaint_every) 107 | n_segments = int(np.ceil(duration / max_segment_length)) 108 | print('duration is {:.2f}; segment into {} parts'.format(duration, n_segments)) 109 | for i in range(n_segments): 110 | orig_segment = orig_audio[int(i*max_segment_length*orig_sr):int((i+1)*max_segment_length*orig_sr)] 111 | part_dir = os.path.join(output_dir, output_subdir, 'tmp_part{}'.format(i)) 112 | part_audio_filename = os.path.join(part_dir, 'original.{}'.format(audio_suffix)) 113 | os.makedirs(part_dir, exist_ok=True) 114 | sf.write(part_audio_filename, orig_segment, orig_sr) 115 | inpaint_one_sample(dataset_name, part_audio_filename, exp_root, exp_name, inpaint_length, inpaint_every, custom_output_subdir=part_dir) 116 | 117 | # concatenate all 118 | recon_audio = None 119 | for i in range(n_segments): 120 | part_dir = os.path.join(output_dir, output_subdir, 'tmp_part{}'.format(i)) 121 | recon_audio_part, recon_sr = librosa.load(os.path.join(part_dir, 'recon.wav'), sr=None) 122 | if recon_audio is None: 123 | recon_audio = recon_audio_part 124 | else: 125 | recon_audio = np.append(recon_audio, recon_audio_part) 126 | sf.write(os.path.join(output_dir, output_subdir, 'recon.wav'), recon_audio, recon_sr) 127 | 128 | else: 129 | pass 130 | 131 | # Load, modify, and store yaml file for the specific file 132 | template_yaml_file = '../configs/inference_files_inpainting.yaml' 133 | inference_config = load_yaml(template_yaml_file) 134 | inference_config['data']['predict_filelist'] = [{ 135 | 'filepath': audio_filename, 136 | 'output_subdir': output_subdir 137 | }] 138 | 139 | _each_transforms_aug = deepcopy(inference_config['data']['transforms_aug'][0]) 140 | inference_config['data']['transforms_aug'] = [] 141 | 142 | starts, ends = [], [] 143 | for i in range(int(duration // inpaint_every)): 144 | start = i * inpaint_every + (inpaint_every - inpaint_length) / 2 145 | end = i * inpaint_every + (inpaint_every + inpaint_length) / 2 146 | _each_transforms_aug['init_args']['start_time'] = start 147 | _each_transforms_aug['init_args']['end_time'] = end 148 | 149 | inference_config['data']['transforms_aug'].append(deepcopy(_each_transforms_aug)) 150 | starts.append(start) 151 | ends.append(end) 152 | 153 | temporary_yaml_file = save_yaml(inference_config) 154 | 155 | # compute degraded audio 156 | degraded_audio = deepcopy(orig_audio) 157 | for t1, t2 in zip(starts, ends): 158 | degraded_audio[int(t1*orig_sr):int(t2*orig_sr)] = 0 159 | sf.write(os.path.join(output_dir, output_subdir, 'degraded.{}'.format(audio_suffix)), degraded_audio, orig_sr) 160 | 161 | # run inpainting command 162 | cmd = "cd ../; \ 163 | python ensembled_inference.py predict \ 164 | -c configs/{}.yaml \ 165 | -c {} \ 166 | --model.fast_inpaint_mode=true \ 167 | --model.predict_n_steps=200 \ 168 | --model.predict_output_dir={}; \ 169 | cd inference/".format(exp_name, temporary_yaml_file.replace('../', ''), output_dir) 170 | 171 | shell_run_cmd(cmd) 172 | 173 | os.remove(temporary_yaml_file) 174 | 175 | 176 | def main(): 177 | parser = argparse.ArgumentParser(description='Description of your program') 178 | parser.add_argument('-dn','--dataset_name', help='dataset name', required=True) 179 | parser.add_argument('-exp','--exp_name', help='exp_name', required=True) 180 | 
parser.add_argument('-inp_len','--inpaint_length', type=float, help='inpaint_length', required=True) 181 | parser.add_argument('-inp_every','--inpaint_every', type=float, help='inpaint_every', required=True) 182 | parser.add_argument('-seg_len','--max_segment_length', type=float, default=-1, help='maximum segment length for inpainting') 183 | parser.add_argument('-start','--start', type=int, help='start', default=0) 184 | parser.add_argument('-end','--end', type=int, help='end', default=-1) 185 | args = parser.parse_args() 186 | 187 | manifest_root_folder = 'PATH/TO/MANIFEST/FOLDER' 188 | exp_root = './exp' 189 | 190 | dataset_name = args.dataset_name 191 | manifest_filename = '{}_manifest.csv'.format(dataset_name) 192 | all_files = read_standard_csv(manifest_root_folder, manifest_filename) 193 | 194 | exp_name = args.exp_name 195 | inpaint_length = args.inpaint_length 196 | inpaint_every = args.inpaint_every 197 | max_segment_length = args.max_segment_length 198 | 199 | start = args.start 200 | end = args.end if args.end > args.start else len(all_files['test']) 201 | 202 | for row in tqdm(all_files['test'][start:end]): 203 | (audio_filename, duration, sample_rate) = row 204 | inpaint_one_sample(dataset_name, audio_filename, exp_root, exp_name, inpaint_length, inpaint_every, max_segment_length) 205 | 206 | 207 | if __name__ == '__main__': 208 | main() 209 | 210 | --------------------------------------------------------------------------------
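A note on the manifest format: `read_standard_csv` above expects a CSV with a header row followed by `split,audio_filename,duration` rows, where `split` is one of train/validation/test, and `main()` looks for a file named `{dataset_name}_manifest.csv`. A minimal sketch of writing such a manifest (the file name, paths, and durations below are hypothetical):

import csv

rows = [
    ("train", "/data/audio/track_0001.wav", 12.8),
    ("test", "/data/audio/track_0002.flac", 7.25),
]
with open("mydataset_manifest.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["split", "audio_filename", "duration"])  # header row; skipped via next(reader)
    writer.writerows(rows)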