├── conf ├── .gitkeep ├── sample_demo_config.yml ├── sample_sampling_config.yml ├── sample_audio2audio_config.yml ├── sample_outpainting_config.yml ├── sample_interpolation_config.yml ├── sample_inpainting_config.yml └── sample_training_config.yml ├── msanii ├── __init__.py ├── modules │ ├── __init__.py │ └── modules.py ├── utils │ ├── __init__.py │ └── utils.py ├── demo │ ├── __init__.py │ ├── utils.py │ ├── helpers.py │ └── demo.py ├── diffusion │ ├── __init__.py │ ├── utils.py │ └── modules.py ├── losses │ ├── __init__.py │ └── spectrogram_losses.py ├── pipeline │ ├── __init__.py │ └── pipeline.py ├── transforms │ ├── __init__.py │ ├── utils.py │ ├── inverse_mel_scale.py │ ├── transforms.py │ ├── feature_scaling.py │ └── griffin_lim.py ├── data │ ├── __init__.py │ ├── audio_datamodule.py │ └── audio_dataset.py ├── models │ ├── __init__.py │ ├── vocoder.py │ └── unet.py ├── lit_models │ ├── __init__.py │ ├── utils.py │ ├── lit_vocoder.py │ └── lit_diffusion.py ├── scripts │ ├── __init__.py │ ├── utils.py │ ├── training.py │ └── inference.py └── config │ ├── demo_config.py │ ├── __init__.py │ ├── utils.py │ ├── inference_config.py │ └── training_config.py ├── setup.cfg ├── .github └── FUNDING.yml ├── requirements.txt ├── .pre-commit-config.yaml ├── LICENSE ├── setup.py ├── .gitignore ├── notebooks ├── msanii_demo.ipynb ├── msanii_training.ipynb └── msanii_inference.ipynb └── README.md /conf/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /msanii/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | profile = black -------------------------------------------------------------------------------- /msanii/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | -------------------------------------------------------------------------------- /msanii/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /msanii/demo/__init__.py: -------------------------------------------------------------------------------- 1 | from .demo import run_demo 2 | -------------------------------------------------------------------------------- /msanii/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | -------------------------------------------------------------------------------- /msanii/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .spectrogram_losses import * 2 | -------------------------------------------------------------------------------- /msanii/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import Pipeline 2 | -------------------------------------------------------------------------------- /msanii/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import Transforms 2 | 
-------------------------------------------------------------------------------- /msanii/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .audio_datamodule import AudioDataModule 2 | -------------------------------------------------------------------------------- /msanii/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import UNet 2 | from .vocoder import Vocoder 3 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [Kinyugo] 4 | -------------------------------------------------------------------------------- /msanii/lit_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .lit_diffusion import LitDiffusion 2 | from .lit_vocoder import LitVocoder 3 | -------------------------------------------------------------------------------- /conf/sample_demo_config.yml: -------------------------------------------------------------------------------- 1 | ckpt_path: checkpoints/pipeline/msanii.pt 2 | device: cpu 3 | dtype: float32 4 | launch: true 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | einops>=0.4 3 | gradio>=3.15.0 4 | lightning 5 | matplotlib>=3.6.2 6 | numpy 7 | omegaconf 8 | torch>=1.6 9 | torchaudio 10 | tqdm 11 | typing_extensions 12 | -------------------------------------------------------------------------------- /msanii/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import ( 2 | run_audio2audio, 3 | run_inpainting, 4 | run_interpolation, 5 | run_outpainting, 6 | run_sampling, 7 | ) 8 | from .training import run_training 9 | -------------------------------------------------------------------------------- /msanii/config/demo_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from omegaconf import MISSING 4 | 5 | 6 | @dataclass 7 | class DemoConfig: 8 | ckpt_path: str = MISSING 9 | device: str = "cpu" 10 | dtype: str = "float32" 11 | launch: bool = False 12 | -------------------------------------------------------------------------------- /conf/sample_sampling_config.yml: -------------------------------------------------------------------------------- 1 | ckpt_path: msanii.pt 2 | output_dir: samples/sampling 3 | batch_size: 4 4 | num_frames: 8387584 5 | duration: null 6 | output_audio_format: wav 7 | seed: 0 8 | device: cuda 9 | dtype: float32 10 | num_inference_steps: 200 11 | verbose: false 12 | use_neural_vocoder: true 13 | num_griffin_lim_iters: 200 14 | _kwargs_: {} 15 | channels: 2 16 | num_samples: 16 17 | -------------------------------------------------------------------------------- /msanii/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .demo_config import DemoConfig 2 | from .inference_config import ( 3 | Audio2AudioConfig, 4 | InpaintingConfig, 5 | InterpolationConfig, 6 | OutpaintingConfig, 7 | SamplingConfig, 8 | ) 9 | from .training_config import ( 10 | DiffusionTrainingConfig, 11 | TrainingConfig, 12 | 
VocoderTrainingConfig, 13 | ) 14 | from .utils import from_config 15 | -------------------------------------------------------------------------------- /msanii/transforms/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_complex_dtype(real_dtype: torch.dtype) -> torch.dtype: 5 | if "float64" in repr(real_dtype): 6 | return torch.cdouble 7 | if "float16" in repr(real_dtype): 8 | return torch.complex32 9 | if "float" in repr(real_dtype): 10 | return torch.cfloat 11 | 12 | raise ValueError(f"Unexpected dtype {real_dtype}") 13 | -------------------------------------------------------------------------------- /conf/sample_audio2audio_config.yml: -------------------------------------------------------------------------------- 1 | ckpt_path: msanii.pt 2 | output_dir: samples/audio2audio 3 | batch_size: 4 4 | num_frames: 8387584 5 | duration: null 6 | output_audio_format: wav 7 | seed: 0 8 | device: cuda 9 | dtype: float32 10 | num_inference_steps: 200 11 | verbose: false 12 | use_neural_vocoder: true 13 | num_griffin_lim_iters: 200 14 | _kwargs_: {} 15 | data_dir: audio_data_dir 16 | num_workers: 4 17 | pin_memory: false 18 | strength: 1.0 19 | -------------------------------------------------------------------------------- /conf/sample_outpainting_config.yml: -------------------------------------------------------------------------------- 1 | ckpt_path: msanii.pt 2 | output_dir: samples/outpainting 3 | batch_size: 4 4 | num_frames: 8387584 5 | duration: null 6 | output_audio_format: wav 7 | seed: 0 8 | device: cuda 9 | dtype: float32 10 | num_inference_steps: 200 11 | verbose: false 12 | use_neural_vocoder: true 13 | num_griffin_lim_iters: 200 14 | _kwargs_: {} 15 | data_dir: audio_data_dir 16 | num_workers: 4 17 | pin_memory: false 18 | num_spans: 2 19 | eta: 0.0 20 | jump_length: 10 21 | jump_n_sample: 10 22 | -------------------------------------------------------------------------------- /conf/sample_interpolation_config.yml: -------------------------------------------------------------------------------- 1 | ckpt_path: msanii.pt 2 | output_dir: samples/interpolation 3 | batch_size: 4 4 | num_frames: 8387584 5 | duration: null 6 | output_audio_format: wav 7 | seed: 0 8 | device: cuda 9 | dtype: float32 10 | num_inference_steps: 200 11 | verbose: false 12 | use_neural_vocoder: true 13 | num_griffin_lim_iters: 200 14 | _kwargs_: {} 15 | first_data_dir: audio_data_dir_1 16 | second_data_dir: audio_data_dir_2 17 | num_workers: 4 18 | pin_memory: false 19 | ratio: 0.5 20 | strength: 1.0 21 | -------------------------------------------------------------------------------- /conf/sample_inpainting_config.yml: -------------------------------------------------------------------------------- 1 | ckpt_path: msanii.pt 2 | output_dir: samples/inpainting 3 | batch_size: 4 4 | num_frames: 8387584 5 | duration: null 6 | output_audio_format: wav 7 | seed: 0 8 | device: cuda 9 | dtype: float32 10 | num_inference_steps: 200 11 | verbose: false 12 | use_neural_vocoder: true 13 | num_griffin_lim_iters: 200 14 | _kwargs_: {} 15 | data_dir: audio_data_dir 16 | num_workers: 4 17 | pin_memory: false 18 | masks: 19 | - 3-5,6-7,20-100 20 | eta: 0.0 21 | jump_length: 10 22 | jump_n_sample: 10 23 | -------------------------------------------------------------------------------- /msanii/diffusion/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from
torch import Tensor 5 | 6 | 7 | def noise_like(x: Tensor, generator: Optional[torch.Generator] = None) -> Tensor: 8 | return torch.randn(x.shape, dtype=x.dtype, device=x.device, generator=generator) 9 | 10 | 11 | def sequential_mask(like: Tensor, start: int) -> Tensor: 12 | length, device = like.shape[-1], like.device 13 | mask = torch.ones_like(like, dtype=torch.bool) 14 | mask[..., start:] = torch.zeros((length - start,), device=device) 15 | 16 | return mask 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v3.2.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | args: ["--maxkb=5000"] 12 | - repo: https://github.com/psf/black 13 | rev: 22.12.0 14 | hooks: 15 | - id: black-jupyter 16 | - repo: https://github.com/nbQA-dev/nbQA 17 | rev: 1.6.1 18 | hooks: 19 | - id: nbqa-isort 20 | args: ["--float-to-top"] 21 | - repo: https://github.com/roy-ht/pre-commit-jupyter 22 | rev: v1.2.1 23 | hooks: 24 | - id: jupyter-notebook-cleanup 25 | args: 26 | # - --remove-kernel-metadata 27 | - --pin-patterns 28 | - "[pin];[donotremove]" 29 | -------------------------------------------------------------------------------- /msanii/models/vocoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.configuration_utils import ConfigMixin, register_to_config 3 | from torch import Tensor, nn 4 | 5 | from ..modules import InputFn, OutputFn, ResidualBlock 6 | 7 | 8 | class Vocoder(ConfigMixin, nn.Module): 9 | config_name = "vocoder_config.json" 10 | 11 | @register_to_config 12 | def __init__( 13 | self, 14 | n_fft: int = 2048, 15 | n_mels: int = 128, 16 | d_model: int = 256, 17 | d_hidden_factor: int = 4, 18 | ) -> None: 19 | super().__init__() 20 | 21 | self.n_fft = n_fft 22 | self.n_mels = n_mels 23 | self.d_model = d_model 24 | self.d_hidden_factor = d_hidden_factor 25 | 26 | self.fn = nn.Sequential( 27 | InputFn(self.n_mels, self.d_model), 28 | ResidualBlock(self.d_model, self.d_model, self.d_hidden_factor), 29 | OutputFn((self.n_fft // 2 + 1), self.d_model), 30 | ) 31 | 32 | def forward(self, x: Tensor) -> Tensor: 33 | return torch.exp(self.fn(x)) 34 | -------------------------------------------------------------------------------- /msanii/scripts/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | def generate_batch_audio_mask( 9 | mask_strs: List[str], audio: Tensor, sample_rate: int 10 | ) -> Tensor: 11 | """Generate audio masks for a batch of audio.""" 12 | batch_intervals = list(map(mask_intervals_from_str, mask_strs)) 13 | 14 | audio_mask = torch.ones_like(audio) 15 | for i, intervals in enumerate(batch_intervals): 16 | for start, end in intervals: 17 | start_idx = int(start * sample_rate) 18 | end_idx = int(end * sample_rate) 19 | audio_mask[i, ..., start_idx:end_idx] = 0 20 | 21 | return audio_mask 22 | 23 | 24 | def mask_intervals_from_str(mask_str: str) -> List[Tuple[int, int]]: 25 | """Generates a list of start and end for mask intervals, 26 | e.g: '3-5, 6-7' gives [(3,5), 
(6,7)] 27 | """ 28 | intervals = re.findall(r"(\d+)-(\d+)", mask_str) 29 | intervals = [tuple(map(int, x)) for x in intervals] 30 | intervals.sort() 31 | 32 | return intervals 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Denis Kinyugo Maina 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /msanii/config/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from inspect import signature 3 | from typing import Any, Callable, Mapping, Optional 4 | 5 | 6 | def from_config( 7 | config: Mapping, target: Optional[Callable] = None, **kwargs: Any 8 | ) -> Any: 9 | # Convert config to a dictionary so we can add more keys 10 | config = dict(config) 11 | 12 | # Add target import path if given 13 | if target is not None: 14 | target_path = get_import_path_from_instance(target) 15 | config.update({"_target_": target_path}) 16 | 17 | # Get parameters of the target 18 | target = import_from_path(config.get("_target_")) 19 | target_signature = signature(target) 20 | target_params = [param for param in target_signature.parameters] 21 | 22 | # Select only arguments that are in the parameters of the target 23 | filtered_config = {k: v for k, v in config.items() if k in target_params} 24 | 25 | # Add placeholder kwargs 26 | if config.get("_kwargs_", None) is None: 27 | config.update({"_kwargs_": {}}) 28 | 29 | # Merge the configs 30 | merged_config = {**filtered_config, **config.get("_kwargs_"), **kwargs} 31 | 32 | return target(**merged_config) 33 | 34 | 35 | def get_import_path_from_instance(instance: Any) -> str: 36 | return f"{instance.__module__}.{instance.__name__}" 37 | 38 | 39 | def import_from_path(import_path: str) -> Callable: 40 | module_name, obj_name = import_path.rsplit(".", maxsplit=1) 41 | module = importlib.import_module(module_name) 42 | 43 | return getattr(module, obj_name) 44 | -------------------------------------------------------------------------------- /msanii/data/audio_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import lightning as L 4 | from torch.utils.data import DataLoader 5 | 6 | from .audio_dataset import AudioDataset 7 | 8 | 9 | 
class AudioDataModule(L.LightningDataModule): 10 | def __init__( 11 | self, 12 | data_dir: str, 13 | sample_rate: int = 44_100, 14 | num_frames: Optional[int] = None, 15 | load_random_slice: bool = False, 16 | normalize_amplitude: bool = True, 17 | batch_size: int = 32, 18 | num_workers: int = 4, 19 | pin_memory: bool = False, 20 | shuffle: bool = True, 21 | ) -> None: 22 | 23 | super().__init__() 24 | 25 | self.data_dir = data_dir 26 | self.sample_rate = sample_rate 27 | self.num_frames = num_frames 28 | self.load_random_slice = load_random_slice 29 | self.normalize_amplitude = normalize_amplitude 30 | self.batch_size = batch_size 31 | self.num_workers = num_workers 32 | self.pin_memory = pin_memory 33 | self.shuffle = shuffle 34 | 35 | def setup(self, stage: str = None) -> None: 36 | self.dataset = AudioDataset( 37 | self.data_dir, 38 | self.sample_rate, 39 | self.num_frames, 40 | self.load_random_slice, 41 | self.normalize_amplitude, 42 | ) 43 | 44 | def train_dataloader(self) -> DataLoader: 45 | return DataLoader( 46 | self.dataset, 47 | batch_size=self.batch_size, 48 | num_workers=self.num_workers, 49 | pin_memory=self.pin_memory, 50 | shuffle=self.shuffle, 51 | ) 52 | -------------------------------------------------------------------------------- /msanii/transforms/inverse_mel_scale.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | from torch import Tensor, nn 6 | from torchaudio import functional as F 7 | 8 | 9 | class InverseMelScale(nn.Module): 10 | def __init__( 11 | self, 12 | sample_rate: int, 13 | n_fft: int, 14 | n_mels: int, 15 | f_min: float = 0.0, 16 | f_max: Optional[float] = None, 17 | norm: Optional[str] = None, 18 | mel_scale: str = "htk", 19 | ) -> None: 20 | super().__init__() 21 | 22 | # Compute the inverse filter banks using the pseudo inverse 23 | f_max = f_max or float(sample_rate // 2) 24 | fb = F.melscale_fbanks( 25 | (n_fft // 2 + 1), f_min, f_max, n_mels, sample_rate, norm, mel_scale 26 | ) 27 | # Using pseudo-inverse is faster than calculating the least-squares in each 28 | # forward pass and experiments show that they converge to the same solution 29 | self.register_buffer("fb", torch.linalg.pinv(fb)) 30 | 31 | def forward(self, melspec: Tensor) -> Tensor: 32 | # Flatten the melspec except for the frequency and time dimension 33 | shape = melspec.shape 34 | melspec = rearrange(melspec, "... f t -> (...) 
f t") 35 | 36 | # Expand the filter banks to match the melspec 37 | fb = repeat(self.fb, "f m -> n m f", n=melspec.shape[0]) 38 | 39 | # Synthesize the stft specgram using the filter banks 40 | specgram = fb @ melspec 41 | # Ensure non-negative solution 42 | specgram = torch.clamp(specgram, min=0.0) 43 | 44 | # Unflatten the specgram (*, freq, time) 45 | specgram = specgram.view(shape[:-2] + (fb.shape[-2], shape[-1])) 46 | 47 | return specgram 48 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # Metadata 4 | NAME = "msanii" 5 | DESCRIPTION = "Msanii: High Fidelity Music Synthesis on a Shoestring Budget" 6 | URL = "https://github.com/Kinyugo/msanii" 7 | EMAIL = "kinyugomaina@gmail.com" 8 | AUTHOR = "Kinyugo Maina" 9 | REQUIRES_PYTHON = ">=3.6.0" 10 | VERSION = "0.0.1" 11 | 12 | # Required packages 13 | REQUIRED = [ 14 | "torch>=1.6", 15 | "torchaudio", 16 | "lightning", 17 | "diffusers", 18 | "tqdm", 19 | "numpy", 20 | "einops>=0.4", 21 | "gradio>=3.15.0", 22 | "matplotlib>=3.6.2", 23 | "omegaconf", 24 | "typing_extensions", 25 | "tornado", 26 | ] 27 | 28 | # Extra packages 29 | EXTRAS = {} 30 | 31 | # Load long description or fallback to short description 32 | try: 33 | with open("README.md", "r", encoding="Utf-8") as f: 34 | long_description = "\n" + f.read() 35 | except FileNotFoundError: 36 | long_description = DESCRIPTION 37 | 38 | # Setup package 39 | setup( 40 | name=NAME, 41 | version=VERSION, 42 | long_description=long_description, 43 | long_description_content_type="text/markdown", 44 | author=AUTHOR, 45 | author_email=EMAIL, 46 | python_requires=REQUIRES_PYTHON, 47 | url=URL, 48 | packages=find_packages(), 49 | install_requires=REQUIRED, 50 | extras_require=EXTRAS, 51 | include_package_data=True, 52 | license="MIT", 53 | classifiers=[ 54 | "Development Status :: 4 - Beta", 55 | "Intended Audience :: Science/Research", 56 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 57 | "License :: OSI Approved :: MIT License", 58 | "Programming Language :: Python :: 3.6", 59 | ], 60 | keywords=[ 61 | "artificial intelligence", 62 | "deep learning", 63 | "audio synthesis", 64 | "music synthesis", 65 | ], 66 | ) 67 | -------------------------------------------------------------------------------- /msanii/config/inference_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, Dict, List, Optional 3 | 4 | from omegaconf import MISSING 5 | 6 | 7 | @dataclass 8 | class SharedConfig: 9 | ckpt_path: str = MISSING 10 | output_dir: str = MISSING 11 | 12 | batch_size: int = 4 13 | num_frames: int = 8_387_584 14 | duration: Optional[int] = None 15 | output_audio_format: str = "wav" 16 | 17 | seed: int = 0 18 | device: str = "cpu" 19 | dtype: str = "float" 20 | 21 | num_inference_steps: Optional[int] = None 22 | verbose: bool = False 23 | use_neural_vocoder: bool = True 24 | num_griffin_lim_iters: Optional[int] = None 25 | 26 | _kwargs_: Dict[str, Any] = field(default_factory=lambda: {}) 27 | 28 | 29 | @dataclass 30 | class SamplingConfig(SharedConfig): 31 | channels: int = 2 32 | num_samples: int = 16 33 | 34 | 35 | @dataclass 36 | class Audio2AudioConfig(SharedConfig): 37 | data_dir: str = MISSING 38 | num_workers: int = 4 39 | pin_memory: bool = False 40 | 41 | strength: float = 1.0 42 | 43 | 44 |
@dataclass 45 | class InterpolationConfig(SharedConfig): 46 | first_data_dir: str = MISSING 47 | second_data_dir: str = MISSING 48 | num_workers: int = 4 49 | pin_memory: bool = False 50 | 51 | ratio: float = 0.5 52 | strength: float = 1.0 53 | 54 | 55 | @dataclass 56 | class InpaintingConfig(SharedConfig): 57 | data_dir: str = MISSING 58 | num_workers: int = 4 59 | pin_memory: bool = False 60 | 61 | masks: List[str] = MISSING 62 | eta: float = 0.0 63 | jump_length: int = 10 64 | jump_n_sample: int = 10 65 | 66 | 67 | @dataclass 68 | class OutpaintingConfig(SharedConfig): 69 | data_dir: str = MISSING 70 | num_workers: int = 4 71 | pin_memory: bool = False 72 | 73 | num_spans: int = 2 74 | eta: float = 0.0 75 | jump_length: int = 10 76 | jump_n_sample: int = 10 77 | -------------------------------------------------------------------------------- /msanii/losses/spectrogram_losses.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, TypedDict 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class SCLoss(nn.Module): 9 | def __init__(self, eps: float = 1e-6) -> None: 10 | super().__init__() 11 | 12 | self.eps = eps 13 | 14 | def forward(self, input_specgram: Tensor, target_specgram: Tensor) -> Tensor: 15 | return torch.norm((target_specgram - input_specgram), p="fro") / ( 16 | torch.norm(target_specgram, p="fro") + self.eps 17 | ) 18 | 19 | 20 | class LogMagnitudeLoss(nn.Module): 21 | def __init__(self, distance: str = "l1", eps: float = 1e-6) -> None: 22 | super().__init__() 23 | 24 | self.distance = distance 25 | self.eps = eps 26 | 27 | def forward(self, input_specgram: Tensor, target_specgram: Tensor) -> Tensor: 28 | input_specgram = torch.log(input_specgram + self.eps) 29 | target_specgram = torch.log(target_specgram + self.eps) 30 | 31 | if self.distance == "l1": 32 | return F.l1_loss(input_specgram, target_specgram) 33 | return F.mse_loss(input_specgram, target_specgram) 34 | 35 | 36 | class SpectrogramLossDict(TypedDict): 37 | sc_loss: Tensor 38 | lm_loss: Tensor 39 | 40 | 41 | class SpectrogramLoss(nn.Module): 42 | def __init__( 43 | self, 44 | w_sc_loss: float = 1.0, 45 | w_lm_loss: float = 1.0, 46 | eps: float = 1e-6, 47 | distance: str = "l1", 48 | ) -> None: 49 | super().__init__() 50 | 51 | self.w_sc_loss = w_sc_loss 52 | self.w_lm_loss = w_lm_loss 53 | 54 | self.sc_loss = SCLoss(eps) 55 | self.lm_loss = LogMagnitudeLoss(distance, eps) 56 | 57 | def forward( 58 | self, input_specgram: Tensor, target_specgram: Tensor 59 | ) -> Tuple[Tensor, SpectrogramLossDict]: 60 | sc_loss = self.w_sc_loss * self.sc_loss(input_specgram, target_specgram) 61 | lm_loss = self.w_lm_loss * self.lm_loss(input_specgram, target_specgram) 62 | 63 | total_loss = sc_loss + lm_loss 64 | losses = SpectrogramLossDict(sc_loss=sc_loss, lm_loss=lm_loss) 65 | 66 | return total_loss, losses 67 | -------------------------------------------------------------------------------- /msanii/lit_models/utils.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from lightning.pytorch.loggers import WandbLogger 3 | from torch import Tensor, nn 4 | 5 | import wandb 6 | 7 | from ..utils import plot_distribution, plot_spectrogram, plot_waveform 8 | 9 | 10 | def log_waveform( 11 | logger: WandbLogger, waveform: Tensor, sample_rate: int, id: str, caption: str = "" 12 | ) -> None: 13 | logger.experiment.log( 14 | { 15 | f"{id}_{idx}": 
plot_waveform(waveform[idx], sample_rate, caption) 16 | for idx in range(waveform.shape[0]) 17 | } 18 | ) 19 | 20 | 21 | def log_spectrogram( 22 | logger: WandbLogger, spectrogram: Tensor, id: str, caption: str = "" 23 | ) -> None: 24 | logger.experiment.log( 25 | { 26 | f"{id}_{idx}": plot_spectrogram(spectrogram[idx], caption) 27 | for idx in range(spectrogram.shape[0]) 28 | } 29 | ) 30 | 31 | 32 | def log_distribution( 33 | logger: WandbLogger, x: Tensor, id: str, caption: str = "" 34 | ) -> None: 35 | logger.experiment.log( 36 | {f"{id}_{idx}": plot_distribution(x[idx], caption) for idx in range(x.shape[0])} 37 | ) 38 | 39 | 40 | def log_audio( 41 | logger: WandbLogger, audio: Tensor, sample_rate: int, id: str, caption: str = "" 42 | ) -> None: 43 | audio = rearrange(audio, "b c t -> b t c").detach().cpu().numpy() 44 | logger.experiment.log( 45 | { 46 | f"{id}_{idx}": wandb.Audio(audio[idx], sample_rate, caption) 47 | for idx in range(audio.shape[0]) 48 | } 49 | ) 50 | 51 | 52 | def log_samples( 53 | logger: WandbLogger, 54 | waveform: Tensor, 55 | spectrogram: Tensor, 56 | sample_rate: int, 57 | tag: str, 58 | ) -> None: 59 | log_spectrogram(logger, spectrogram, f"spectrogram/{tag}", tag) 60 | log_distribution(logger, spectrogram, f"distribution/{tag}", tag) 61 | log_waveform(logger, waveform, sample_rate, f"waveform/{tag}", tag) 62 | log_audio(logger, waveform, sample_rate, f"audio/{tag}", tag) 63 | 64 | 65 | def update_ema_model(src_model: nn.Module, ema_model: nn.Module, decay: float) -> None: 66 | for ema_param, src_param in zip(ema_model.parameters(), src_model.parameters()): 67 | if ema_param.data is None: 68 | ema_param.data = src_param.data 69 | else: 70 | ema_param.data = decay * ema_param.data + (1 - decay) * src_param.data 71 | -------------------------------------------------------------------------------- /msanii/data/audio_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import Optional, Tuple 4 | 5 | import torchaudio 6 | from torch import Tensor 7 | from torch.nn import functional as F 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class AudioDataset(Dataset): 12 | def __init__( 13 | self, 14 | data_dir: str, 15 | sample_rate: int, 16 | num_frames: Optional[int] = None, 17 | load_random_slice: bool = False, 18 | normalize_amplitude: bool = True, 19 | ) -> None: 20 | super().__init__() 21 | 22 | self.data_dir = os.path.expanduser(data_dir) 23 | self.sample_rate = sample_rate 24 | self.num_frames = num_frames 25 | self.load_random_slice = load_random_slice 26 | self.normalize_amplitude = normalize_amplitude 27 | 28 | self.filenames = os.listdir(self.data_dir) 29 | 30 | def __len__(self) -> int: 31 | return len(self.filenames) 32 | 33 | def __getitem__(self, index: int) -> Tensor: 34 | filepath = os.path.join(self.data_dir, self.filenames[index]) 35 | 36 | if self.load_random_slice: 37 | waveform, sample_rate = self.__load_random_audio_slice(filepath) 38 | else: 39 | waveform, sample_rate = torchaudio.load(filepath) 40 | 41 | waveform = self.__resample(waveform, sample_rate) 42 | waveform = self.__pad(waveform) 43 | waveform = self.__normalize_amplitude(waveform) 44 | 45 | return waveform.clamp(min=-1.0, max=1.0) 46 | 47 | def __load_random_audio_slice(self, filepath: str) -> Tuple[Tensor, int]: 48 | metadata = torchaudio.info(filepath) 49 | frames_to_load = int( 50 | (metadata.sample_rate / self.sample_rate) * self.num_frames 51 | ) 52 | frame_offset = random.randint(0, 
max(0, metadata.num_frames - frames_to_load)) 53 | 54 | waveform, sample_rate = torchaudio.load( 55 | filepath, num_frames=frames_to_load, frame_offset=frame_offset 56 | ) 57 | 58 | return waveform, sample_rate 59 | 60 | def __resample(self, x: Tensor, sample_rate: int) -> Tensor: 61 | if sample_rate != self.sample_rate: 62 | return torchaudio.functional.resample(x, sample_rate, self.sample_rate) 63 | return x 64 | 65 | def __pad(self, x: Tensor) -> Tensor: 66 | if self.num_frames: 67 | return F.pad(x, (0, self.num_frames - x.shape[-1]), value=0.0) 68 | return x 69 | 70 | def __normalize_amplitude(self, x: Tensor) -> Tensor: 71 | if self.normalize_amplitude: 72 | return x / x.abs().max() 73 | return x 74 | -------------------------------------------------------------------------------- /msanii/demo/utils.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | 3 | import numpy as np 4 | import torch 5 | from einops import rearrange 6 | from torch import Tensor 7 | from torch.nn import functional as F 8 | from torchaudio import functional as AF 9 | 10 | from ..utils import pad_to_divisible_length 11 | 12 | 13 | def gradio_audio_preprocessing( 14 | audio: np.ndarray, 15 | src_sample_rate: int, 16 | target_sample_rate: int, 17 | target_length: int, 18 | hop_length: int, 19 | num_downsamples: int, 20 | dtype: torch.dtype, 21 | device: torch.device, 22 | pad_end: bool = True, 23 | ) -> Tensor: 24 | # Ensure audio is a float tensor between [-1, 1] 25 | if issubclass(audio.dtype.type, numbers.Integral): 26 | audio = audio / np.iinfo(audio.dtype).max 27 | 28 | # Load audio into tensor and resample 29 | audio = torch.from_numpy(audio) 30 | if audio.ndim == 1: 31 | audio = rearrange(audio, "l -> () () l") # to batched and mono-channels 32 | else: 33 | audio = rearrange(audio, "l c -> () c l") # to batched channel first 34 | audio = AF.resample(audio, src_sample_rate, target_sample_rate) 35 | 36 | # Rescale target length by the sample rate 37 | target_length = int((target_length * target_sample_rate) / src_sample_rate) 38 | 39 | # Pad audio to the target length 40 | if pad_end: 41 | audio = F.pad(audio, (0, target_length - audio.shape[-1])) 42 | else: 43 | audio = F.pad(audio, (target_length - audio.shape[-1], 0)) 44 | 45 | # Pad audio to a length divisible by the number of downsampling layers 46 | audio = pad_to_divisible_length(audio, hop_length, num_downsamples, pad_end) 47 | 48 | # Switch target dtype and device 49 | audio = audio.to(dtype).to(device) 50 | 51 | return audio 52 | 53 | 54 | def gradio_audio_postprocessing( 55 | audio: Tensor, target_length: int, pad_end: bool = True 56 | ) -> np.ndarray: 57 | # Ensure audio is the correct length 58 | if pad_end: 59 | audio = F.pad(audio, (0, target_length - audio.shape[-1])) 60 | else: 61 | audio = F.pad(audio, (target_length - audio.shape[-1], 0)) 62 | 63 | # Remove batch dimension & switch to channels last 64 | audio = rearrange(audio, "b c l -> (l b) c") 65 | 66 | return audio.detach().cpu().numpy() 67 | 68 | 69 | def generate_gradio_audio_mask( 70 | audio: np.ndarray, sample_rate: int, spec: str 71 | ) -> np.ndarray: 72 | # Convert mask string to a list of tuples of time intervals 73 | mask_intervals = [] 74 | for mask in spec.split(","): 75 | start, end = map(int, mask.split("-")) 76 | mask_intervals.append((start, end)) 77 | 78 | # Create a numpy array of ones with the same shape as input 79 | mask = np.ones_like(audio) 80 | 81 | # Set the values at the specified time intervals to 0 82 | for
start, end in mask_intervals: 83 | start_sample = int(start * sample_rate) 84 | end_sample = int(end * sample_rate) 85 | mask[start_sample:end_sample, ...] = 0 86 | 87 | return mask 88 | 89 | 90 | def max_abs_scaling(x: Tensor, max_abs_value: float = 0.05) -> Tensor: 91 | x = x / x.abs().max() 92 | x = x * max_abs_value 93 | 94 | return x 95 | -------------------------------------------------------------------------------- /msanii/utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from einops import reduce 5 | from matplotlib import pyplot as plt 6 | from torch import Tensor, nn 7 | from torch.nn import functional as F 8 | 9 | 10 | def plot_waveform(waveform: Tensor, sample_rate: int, title: str = "") -> plt.Figure: 11 | waveform = reduce(waveform, "... l -> l", reduction="mean") 12 | waveform = waveform.detach().cpu() 13 | 14 | n_frames = waveform.shape[-1] 15 | skip = int(n_frames / (0.01 * n_frames)) 16 | waveform = waveform[..., 0:-1:skip] 17 | 18 | n_frames = waveform.shape[-1] 19 | time_axis = torch.linspace(0, n_frames / (sample_rate / skip), steps=n_frames) 20 | 21 | fig = plt.figure(dpi=300) 22 | plt.plot(time_axis, waveform, linewidth=1) 23 | plt.grid(True) 24 | plt.title(title) 25 | plt.xlabel("Time (s)") 26 | plt.ylabel("Amplitude") 27 | 28 | return fig 29 | 30 | 31 | def plot_spectrogram(spectrogram: Tensor, title: str = "") -> plt.Figure: 32 | spectrogram = reduce(spectrogram, "... f t-> f t", reduction="mean") 33 | spectrogram = spectrogram.detach().cpu() 34 | 35 | fig = plt.figure(dpi=300) 36 | plt.imshow(spectrogram, origin="lower", aspect="auto", cmap="magma") 37 | plt.colorbar() 38 | plt.title(title) 39 | plt.xlabel("Time") 40 | plt.ylabel("Frequency") 41 | 42 | return fig 43 | 44 | 45 | def plot_distribution(x: Tensor, title: str = "") -> plt.Figure: 46 | x = x.detach().cpu() 47 | mean, std = x.mean(), x.std() 48 | 49 | hist, edges = torch.histogram(x, density=True) 50 | 51 | fig = plt.figure(dpi=300) 52 | plt.plot(edges[:-1], hist) 53 | plt.title(f"{title} | Mean: {mean:.4f} Std: {std:.4f}") 54 | plt.xlabel("X") 55 | plt.ylabel("Density") 56 | 57 | return fig 58 | 59 | 60 | def freeze_model(model: nn.Module) -> nn.Module: 61 | model = model.eval() 62 | for param in model.parameters(): 63 | param.requires_grad = False 64 | 65 | return model 66 | 67 | 68 | def clone_model_parameters(src_model: nn.Module, target_model: nn.Module) -> nn.Module: 69 | for src_param, target_param in zip( 70 | src_model.parameters(), target_model.parameters() 71 | ): 72 | target_param.data = src_param.data 73 | 74 | return target_model 75 | 76 | 77 | def compute_divisible_length( 78 | curr_length: int, hop_length: int, num_downsamples: int 79 | ) -> int: 80 | # Current time frame size 81 | num_time_frames = int((curr_length / hop_length) + 1) 82 | # Divisible time frames 83 | divisible_time_frames = math.ceil(num_time_frames / 2**num_downsamples) * ( 84 | 2**num_downsamples 85 | ) 86 | divisible_length = (divisible_time_frames - 1) * hop_length 87 | 88 | return divisible_length 89 | 90 | 91 | def pad_to_divisible_length( 92 | x: Tensor, hop_length: int, num_downsamples: int, pad_end: bool = True 93 | ) -> Tensor: 94 | divisible_length = compute_divisible_length( 95 | x.shape[-1], hop_length, num_downsamples 96 | ) 97 | # Pad to appropriate length 98 | if pad_end: 99 | x = F.pad(x, (0, divisible_length - x.shape[-1])) 100 | else: 101 | x = F.pad(x, (divisible_length - x.shape[-1], 0)) 102 | 103 | return x 104 
| -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # VsCode 163 | .vscode 164 | 165 | # Test config 166 | conf/*test* 167 | 168 | # Checkpoints saved during runs 169 | checkpoints/ 170 | 171 | # Logs created during runs 172 | logs/ 173 | 174 | # Random collection of scripts and notebooks for hacking 175 | playground/ 176 | 177 | # Generated samples 178 | samples/ 179 | 180 | # Dummy data for testing 181 | dummy_data/ 182 | -------------------------------------------------------------------------------- /msanii/lit_models/lit_vocoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import torch 4 | from lightning.pytorch import LightningModule 5 | from torch import Tensor, optim 6 | 7 | from ..losses import SpectrogramLoss 8 | from ..models import Vocoder 9 | from ..transforms import Transforms 10 | from .utils import log_samples 11 | 12 | 13 | class LitVocoder(LightningModule): 14 | def __init__( 15 | self, 16 | transforms: Transforms, 17 | vocoder: Vocoder, 18 | sample_rate: int = 44_100, 19 | transforms_decay: float = 0.999, 20 | lr: float = 2e-4, 21 | betas: Tuple[float, float] = (0.5, 0.999), 22 | lr_scheduler_start_factor: float = 1 / 3, 23 | lr_scheduler_iters: int = 500, 24 | sample_every_n_epochs: int = 10, 25 | num_samples: int = 4, 26 | ) -> None: 27 | super().__init__() 28 | 29 | self.save_hyperparameters(ignore=["transforms", "vocoder"]) 30 | 31 | self.transforms = transforms 32 | self.vocoder = vocoder 33 | self.sample_rate = sample_rate 34 | self.transforms_decay = transforms_decay 35 | self.lr = lr 36 | self.betas = betas 37 | self.lr_scheduler_start_factor = lr_scheduler_start_factor 38 | self.lr_scheduler_iters = lr_scheduler_iters 39 | self.sample_every_n_epochs = sample_every_n_epochs 40 | self.num_samples = num_samples 41 | 42 | self.loss_fn = SpectrogramLoss() 43 | 44 | def training_step( 45 | self, batch: Union[Tensor, List[Tensor]], batch_idx: int 46 | ) -> Tensor: 47 | # Drop labels if present 48 | if isinstance(batch, list): 49 | batch = batch[0] 50 | 51 | # Reconstruct magnitude stft spectrogram from mel spectrograms 52 | mel_spectrograms = self.transforms(batch) 53 | pred_mag_spectrograms = self.vocoder(mel_spectrograms) 54 | 55 | # Compute & log spectral losses 56 | target_mag_spectrograms = self.transforms.spectrogram(batch) 57 | loss, loss_dict = self.loss_fn(pred_mag_spectrograms, 
target_mag_spectrograms) 58 | self.log_dict({"total_loss": loss}) 59 | self.log_dict(loss_dict) 60 | 61 | # Update & log transforms parameters 62 | self.transforms.step() 63 | self.log_dict(self.transforms.params_dict) 64 | 65 | # Sample & log waveforms, spectrograms, distributions & audio samples 66 | if (self.current_epoch % self.sample_every_n_epochs == 0) and batch_idx == 0: 67 | self.__sample_and_log_samples( 68 | batch, mel_spectrograms, pred_mag_spectrograms 69 | ) 70 | 71 | return loss 72 | 73 | def configure_optimizers(self): 74 | opt = optim.Adam(self.vocoder.parameters(), lr=self.lr, betas=self.betas) 75 | sched = optim.lr_scheduler.LinearLR( 76 | opt, 77 | start_factor=self.lr_scheduler_start_factor, 78 | total_iters=self.lr_scheduler_iters, 79 | ) 80 | sched = {"scheduler": sched, "interval": "step"} 81 | 82 | return [opt], [sched] 83 | 84 | @torch.no_grad() 85 | def __sample_and_log_samples( 86 | self, waveforms: Tensor, mel_spectrograms: Tensor, pred_mag_spectrograms: Tensor 87 | ) -> None: 88 | # Ensure the number of samples does not exceed the batch size 89 | num_samples = min(self.num_samples, waveforms.shape[0]) 90 | 91 | # Ground truth samples 92 | log_samples( 93 | self.logger, 94 | waveforms[:num_samples], 95 | mel_spectrograms[:num_samples], 96 | self.sample_rate, 97 | "ground_truth", 98 | ) 99 | 100 | # Neural vocoder samples 101 | pred_batch = self.transforms.griffin_lim( 102 | pred_mag_spectrograms[:num_samples], length=waveforms.shape[-1] 103 | ) 104 | pred_mel_spectrograms = self.transforms(pred_batch) 105 | log_samples( 106 | self.logger, 107 | pred_batch, 108 | pred_mel_spectrograms, 109 | self.sample_rate, 110 | "neural_vocoder", 111 | ) 112 | 113 | # Direct reconstruction samples 114 | pred_batch = self.transforms.inverse_transform( 115 | mel_spectrograms[:num_samples], length=waveforms.shape[-1] 116 | ) 117 | pred_mel_spectrograms = self.transforms(pred_batch) 118 | log_samples( 119 | self.logger, 120 | pred_batch, 121 | pred_mel_spectrograms, 122 | self.sample_rate, 123 | "direct_reconstruction", 124 | ) 125 | -------------------------------------------------------------------------------- /msanii/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple 2 | 3 | import torch 4 | from diffusers.configuration_utils import ConfigMixin, register_to_config 5 | from torch import Tensor, nn 6 | from torchaudio import transforms as T 7 | 8 | from .feature_scaling import MinMaxScaler, StandardScaler 9 | from .griffin_lim import GriffinLim 10 | from .inverse_mel_scale import InverseMelScale 11 | 12 | 13 | class Transforms(ConfigMixin, nn.Module): 14 | config_name = "transforms_config.json" 15 | 16 | @register_to_config 17 | def __init__( 18 | self, 19 | sample_rate: int = 44_100, 20 | n_fft: int = 2048, 21 | win_length: Optional[int] = None, 22 | hop_length: Optional[int] = None, 23 | n_mels: int = 128, 24 | feature_range: Tuple[float, float] = (-1.0, 1.0), 25 | momentum: float = 1e-3, 26 | momentum_decay: float = 0.99, 27 | eps: float = 1e-5, 28 | clip: bool = True, 29 | num_griffin_lim_iters: int = 100, 30 | griffin_lim_momentum: float = 0.99, 31 | ) -> None: 32 | super().__init__() 33 | 34 | self.sample_rate = sample_rate 35 | self.n_fft = n_fft 36 | self.win_length = win_length or self.n_fft 37 | self.hop_length = hop_length or (self.win_length // 2) 38 | self.n_mels = n_mels 39 | self.num_griffin_lim_iters = num_griffin_lim_iters 40 | self.griffin_lim_momentum = 
griffin_lim_momentum 41 | 42 | self.spectrogram = T.Spectrogram( 43 | n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=1.0 44 | ) 45 | self.complex_spectrogram = T.Spectrogram( 46 | n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=None 47 | ) 48 | self.inverse_spectrogram = T.InverseSpectrogram( 49 | n_fft=n_fft, win_length=win_length, hop_length=hop_length 50 | ) 51 | self.mel_scale = T.MelScale( 52 | sample_rate=sample_rate, n_mels=n_mels, n_stft=(n_fft // 2 + 1) 53 | ) 54 | self.inverse_mel_scale = InverseMelScale( 55 | sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels 56 | ) 57 | self.griffin_lim = GriffinLim( 58 | num_iters=num_griffin_lim_iters, 59 | momentum=griffin_lim_momentum, 60 | n_fft=n_fft, 61 | win_length=win_length, 62 | hop_length=hop_length, 63 | power=1.0, 64 | ) 65 | self.standard_scaler = StandardScaler(momentum, momentum_decay, eps) 66 | self.minmax_scaler = MinMaxScaler(feature_range, momentum, momentum_decay, clip) 67 | 68 | @property 69 | def params_dict(self) -> Dict[str, Any]: 70 | return { 71 | "standard_scaler/momentum": self.standard_scaler.momentum, 72 | "standard_scaler/mean": self.standard_scaler.running_mean, 73 | "standard_scaler/var": self.standard_scaler.running_var, 74 | "minmax_scaler/momentum": self.minmax_scaler.momentum, 75 | "minmax_scaler/min": self.minmax_scaler.running_min, 76 | "minmax_scaler/max": self.minmax_scaler.running_max, 77 | } 78 | 79 | def forward( 80 | self, 81 | x: Tensor, 82 | inverse: bool = False, 83 | length: Optional[int] = None, 84 | num_griffin_lim_iters: Optional[int] = None, 85 | ) -> Tensor: 86 | if inverse: 87 | return self.inverse_transform(x, length, num_griffin_lim_iters) 88 | return self.transform(x) 89 | 90 | def transform(self, x: Tensor) -> Tensor: 91 | x_transformed = self.spectrogram(x) 92 | x_transformed = self.mel_scale(x_transformed) 93 | x_transformed = torch.log(x_transformed + 1e-5) 94 | x_transformed = self.standard_scaler(x_transformed) 95 | x_transformed = self.minmax_scaler(x_transformed) 96 | 97 | return x_transformed 98 | 99 | def inverse_transform( 100 | self, 101 | x: Tensor, 102 | length: Optional[int] = None, 103 | num_griffin_lim_iters: Optional[int] = None, 104 | ) -> Tensor: 105 | x_transformed = self.minmax_scaler(x, inverse=True) 106 | x_transformed = self.standard_scaler(x_transformed, inverse=True) 107 | x_transformed = torch.exp(x_transformed) 108 | x_transformed = self.inverse_mel_scale(x_transformed) 109 | x_transformed = self.griffin_lim( 110 | x_transformed, length=length, num_iters=num_griffin_lim_iters 111 | ) 112 | 113 | return x_transformed 114 | 115 | @torch.no_grad() 116 | def step(self) -> None: 117 | self.standard_scaler.step() 118 | self.minmax_scaler.step() 119 | -------------------------------------------------------------------------------- /notebooks/msanii_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Msanii Demo" 9 | ] 10 | }, 11 | { 12 | "attachments": {}, 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## GPU Check" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "!nvidia-smi" 26 | ] 27 | }, 28 | { 29 | "attachments": {}, 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Setup\n", 34 | "\n", 35 | "Run one of the following
install options. \n", 36 | "> **WARNING:** Restart the runtime or some packages will not be updated!" 37 | ] 38 | }, 39 | { 40 | "attachments": {}, 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### Install package from git" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "%pip install -q git+https://github.com/Kinyugo/msanii.git" 54 | ] 55 | }, 56 | { 57 | "attachments": {}, 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "### Install package in edit mode" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "!git clone https://github.com/Kinyugo/msanii.git\n", 71 | "!cd msanii\n", 72 | "%pip install -q -r requirements.txt\n", 73 | "%pip install -e ." 74 | ] 75 | }, 76 | { 77 | "attachments": {}, 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### Fetch model checkpoint" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "%pip install -q gdown --upgrade --no-cache" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "!gdown 1G9kF0r5vxYXPSdSuv4t3GR-sBO8xGFCe" 100 | ] 101 | }, 102 | { 103 | "attachments": {}, 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "### Imports" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "from omegaconf import OmegaConf\n", 117 | "\n", 118 | "from msanii.config import DemoConfig\n", 119 | "from msanii.demo import run_demo" 120 | ] 121 | }, 122 | { 123 | "attachments": {}, 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## Run Demo" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "dict_config = {\n", 137 | " \"ckpt_path\": \"\",\n", 138 | " \"device\": \"cuda\", # cpu or cuda\n", 139 | " \"dtype\": \"float32\", # torch.dtype\n", 140 | " \"launch\": False, # launch demo?\n", 141 | "}" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "default_config = OmegaConf.structured(DemoConfig)\n", 151 | "custom_config = OmegaConf.create(dict_config)\n", 152 | "config = OmegaConf.merge(default_config, custom_config)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "demo = run_demo(config)\n", 162 | "demo.launch(debug=True)" 163 | ] 164 | } 165 | ], 166 | "metadata": { 167 | "kernelspec": { 168 | "display_name": "torch", 169 | "language": "python", 170 | "name": "python3" 171 | }, 172 | "language_info": { 173 | "name": "python", 174 | "version": "3.10.5" 175 | }, 176 | "orig_nbformat": 4, 177 | "vscode": { 178 | "interpreter": { 179 | "hash": "c339664639c3e5019e3803d0baff2aab4fdaac0204aae143f6ed0f1a6cb76161" 180 | } 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 2 185 | } 186 | -------------------------------------------------------------------------------- /conf/sample_training_config.yml: -------------------------------------------------------------------------------- 1 | 
vocoder: 2 | datamodule: 3 | _target_: msanii.data.AudioDataModule 4 | data_dir: audio_data_dir 5 | sample_rate: 44100 6 | num_frames: 523264 7 | load_random_slice: true 8 | normalize_amplitude: false 9 | batch_size: 8 10 | num_workers: 8 11 | pin_memory: true 12 | transforms: 13 | _target_: msanii.transforms.Transforms 14 | sample_rate: 44100 15 | n_fft: 2048 16 | win_length: null 17 | hop_length: null 18 | n_mels: 128 19 | feature_range: 20 | - -1.0 21 | - 1.0 22 | momentum: 0.001 23 | eps: 1.0e-05 24 | clip: true 25 | num_griffin_lim_iters: 100 26 | griffin_lim_momentum: 0.99 27 | vocoder: 28 | _target_: msanii.models.Vocoder 29 | n_fft: 2048 30 | n_mels: 128 31 | d_model: 256 32 | d_hidden_factor: 4 33 | lit_vocoder: 34 | _target_: msanii.lit_models.LitVocoder 35 | sample_rate: 44100 36 | transforms_decay: 0.999 37 | lr: 0.0002 38 | betas: 39 | - 0.5 40 | - 0.999 41 | lr_scheduler_start_factor: 0.3333333333333333 42 | lr_scheduler_iters: 500 43 | sample_every_n_epochs: 20 44 | num_samples: 4 45 | wandb_logger: 46 | _target_: lightning.pytorch.loggers.WandbLogger 47 | save_dir: logs 48 | project: msanii 49 | name: null 50 | job_type: train 51 | log_model: true 52 | tags: null 53 | notes: null 54 | save_code: true 55 | offline: false 56 | _kwargs_: {} 57 | model_checkpoint: 58 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 59 | dirpath: null 60 | save_last: true 61 | verbose: false 62 | mode: min 63 | _kwargs_: {} 64 | trainer: 65 | _target_: lightning.Trainer 66 | accelerator: auto 67 | accumulate_grad_batches: 1 68 | devices: null 69 | default_root_dir: null 70 | detect_anomaly: false 71 | gradient_clip_val: 1.0 72 | gradient_clip_algorithm: norm 73 | limit_train_batches: 1.0 74 | log_every_n_steps: 10 75 | precision: 16 76 | max_epochs: 201 77 | max_steps: -1 78 | weights_save_path: null 79 | fast_dev_run: false 80 | _kwargs_: {} 81 | skip_training: false 82 | resume_ckpt_path: null 83 | diffusion: 84 | datamodule: 85 | _target_: msanii.data.AudioDataModule 86 | data_dir: audio_data_dir 87 | sample_rate: 44100 88 | num_frames: 8387584 89 | load_random_slice: true 90 | normalize_amplitude: false 91 | batch_size: 4 92 | num_workers: 8 93 | pin_memory: true 94 | unet: 95 | _target_: msanii.models.UNet 96 | d_freq: 128 97 | d_base: 256 98 | d_hidden_factor: 4 99 | d_multipliers: 100 | - 1 101 | - 1 102 | - 1 103 | - 1 104 | - 1 105 | - 1 106 | - 1 107 | d_timestep: 128 108 | dilations: 109 | - 1 110 | - 1 111 | - 1 112 | - 1 113 | - 1 114 | - 1 115 | - 1 116 | n_heads: 8 117 | has_attention: 118 | - false 119 | - false 120 | - false 121 | - false 122 | - false 123 | - true 124 | - true 125 | has_resampling: 126 | - true 127 | - true 128 | - true 129 | - true 130 | - true 131 | - true 132 | - false 133 | n_block_layers: 134 | - 2 135 | - 2 136 | - 2 137 | - 2 138 | - 2 139 | - 2 140 | - 2 141 | scheduler: 142 | _target_: diffusers.DDIMScheduler 143 | num_train_timesteps: 1000 144 | beta_schedule: squaredcos_cap_v2 145 | _kwargs_: {} 146 | lit_diffusion: 147 | _target_: msanii.lit_models.LitDiffusion 148 | sample_rate: 44100 149 | transforms_decay: 0.999 150 | ema_decay: 0.995 151 | ema_start_step: 2000 152 | ema_update_every: 10 153 | lr: 0.0002 154 | betas: 155 | - 0.5 156 | - 0.999 157 | lr_scheduler_start_factor: 0.3333333333333333 158 | lr_scheduler_iters: 500 159 | sample_every_n_epochs: 20 160 | num_samples: 4 161 | num_inference_steps: 20 162 | wandb_logger: 163 | _target_: lightning.pytorch.loggers.WandbLogger 164 | save_dir: logs 165 | project: msanii 166 | name: null 167 
| job_type: train 168 | log_model: true 169 | tags: null 170 | notes: null 171 | save_code: true 172 | offline: false 173 | _kwargs_: {} 174 | model_checkpoint: 175 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 176 | dirpath: null 177 | save_last: true 178 | verbose: false 179 | mode: min 180 | _kwargs_: {} 181 | trainer: 182 | _target_: lightning.Trainer 183 | accelerator: auto 184 | accumulate_grad_batches: 1 185 | devices: null 186 | default_root_dir: null 187 | detect_anomaly: false 188 | gradient_clip_val: 1.0 189 | gradient_clip_algorithm: norm 190 | limit_train_batches: 1.0 191 | log_every_n_steps: 10 192 | precision: 16 193 | max_epochs: 501 194 | max_steps: -1 195 | weights_save_path: null 196 | fast_dev_run: false 197 | _kwargs_: {} 198 | skip_training: false 199 | resume_ckpt_path: null 200 | seed: 0 201 | pipeline_wandb_name: msanii_pipeline 202 | pipeline_ckpt_path: checkpoints/msanii.pt 203 | -------------------------------------------------------------------------------- /msanii/transforms/feature_scaling.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | 7 | class StandardScaler(nn.Module): 8 | def __init__( 9 | self, momentum: float = 1e-3, momentum_decay: float = 0.99, eps: float = 1e-5 10 | ) -> None: 11 | super().__init__() 12 | 13 | self.eps = eps 14 | self.momentum_decay = momentum_decay 15 | 16 | self.register_buffer("momentum", torch.tensor(momentum)) 17 | self.register_buffer("running_mean", torch.tensor(0.0)) 18 | self.register_buffer("running_var", torch.tensor(1.0)) 19 | self.register_buffer("fitted", torch.tensor(False)) 20 | 21 | def forward(self, x: Tensor, inverse: bool = False) -> Tensor: 22 | if inverse: 23 | return self.inverse_transform(x) 24 | return self.transform(x) 25 | 26 | def transform(self, x: Tensor) -> Tensor: 27 | # Compute & update statistics over the current batch in training mode 28 | if self.training: 29 | batch_mean = x.mean() 30 | batch_var = x.var() 31 | 32 | self.__update_stats(batch_mean, batch_var) 33 | 34 | # Use running statistics in other modes 35 | else: 36 | batch_mean = self.running_mean 37 | batch_var = self.running_var 38 | 39 | return (x - batch_mean) / torch.sqrt(batch_var + self.eps) 40 | 41 | def inverse_transform(self, x: Tensor) -> Tensor: 42 | # Use running statistics to undo the standardization 43 | return (x * torch.sqrt(self.running_var + self.eps)) + self.running_mean 44 | 45 | @torch.no_grad() 46 | def step(self) -> None: 47 | self.momentum.data = self.momentum_decay * self.momentum.data 48 | 49 | @torch.no_grad() 50 | def __update_stats(self, batch_mean: Tensor, batch_var: Tensor) -> Tensor: 51 | # Copy batch statistics for the initial fitting 52 | if not self.fitted: 53 | self.running_mean.data = batch_mean.data 54 | self.running_var.data = batch_var.data 55 | self.fitted.data = torch.tensor(True, device=batch_mean.device) 56 | 57 | # Update a moving average of the statistics 58 | else: 59 | self.running_mean.data = ( 60 | 1.0 - self.momentum 61 | ) * self.running_mean + self.momentum * batch_mean 62 | self.running_var.data = ( 63 | 1.0 - self.momentum 64 | ) * self.running_var + self.momentum * batch_var 65 | 66 | 67 | class MinMaxScaler(nn.Module): 68 | def __init__( 69 | self, 70 | feature_range: Tuple[float, float] = (-1.0, 1.0), 71 | momentum: float = 1e-3, 72 | momentum_decay=0.99, 73 | clip: bool = True, 74 | ) -> None: 75 | super().__init__() 76 | 77 | self.feature_range 
= feature_range 78 | self.momentum_decay = momentum_decay 79 | self.clip = clip 80 | 81 | self.register_buffer("momentum", torch.tensor(momentum)) 82 | self.register_buffer("running_min", torch.tensor(0.0)) 83 | self.register_buffer("running_max", torch.tensor(1.0)) 84 | self.register_buffer("fitted", torch.tensor(False)) 85 | 86 | def forward(self, x: Tensor, inverse: bool = False) -> Tensor: 87 | if inverse: 88 | return self.inverse_transform(x) 89 | return self.transform(x) 90 | 91 | def transform(self, x: Tensor) -> Tensor: 92 | f_min, f_max = self.feature_range 93 | 94 | # Compute & update statistics over the current batch in training mode 95 | if self.training: 96 | batch_min = x.min() 97 | batch_max = x.max() 98 | 99 | self.__update_stats(batch_min, batch_max) 100 | 101 | # Use running statistics in other modes 102 | else: 103 | batch_min = self.running_min 104 | batch_max = self.running_max 105 | 106 | x_transformed = (x - batch_min) / (batch_max - batch_min) 107 | x_transformed = x_transformed * (f_max - f_min) + f_min 108 | 109 | # Ensure values are in the appropriate range 110 | if self.clip: 111 | return torch.clip(x_transformed, min=f_min, max=f_max) 112 | return x_transformed 113 | 114 | def inverse_transform(self, x: Tensor) -> Tensor: 115 | f_min, f_max = self.feature_range 116 | 117 | # Ensure values are in the appropriate range 118 | if self.clip: 119 | x = torch.clip(x, min=f_min, max=f_max) 120 | 121 | # Use running statistics to undo the min-max scaling 122 | x_transformed = (x - f_min) / (f_max - f_min) 123 | x_transformed = ( 124 | x_transformed * (self.running_max - self.running_min) + self.running_min 125 | ) 126 | 127 | return x_transformed 128 | 129 | @torch.no_grad() 130 | def step(self) -> None: 131 | self.momentum.data = self.momentum_decay * self.momentum.data 132 | 133 | @torch.no_grad() 134 | def __update_stats(self, batch_min: Tensor, batch_max: Tensor) -> Tensor: 135 | # Copy batch statistics for the initial fitting 136 | if not self.fitted: 137 | self.running_min.data = batch_min.data 138 | self.running_max.data = batch_max.data 139 | self.fitted.data = torch.tensor(True, device=batch_min.device) 140 | 141 | # Update a moving average of the statistics 142 | else: 143 | self.running_min.data = ( 144 | 1.0 - self.momentum 145 | ) * self.running_min + self.momentum * batch_min 146 | self.running_max.data = ( 147 | 1.0 - self.momentum 148 | ) * self.running_max + self.momentum * batch_max 149 | -------------------------------------------------------------------------------- /notebooks/msanii_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Msanii Training\n" 9 | ] 10 | }, 11 | { 12 | "attachments": {}, 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## GPU Check" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "!nvidia-smi" 26 | ] 27 | }, 28 | { 29 | "attachments": {}, 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Setup\n", 34 | "\n", 35 | "Run one of the the below install options. \n", 36 | "> **WARNING:** Restart the runtime or some packages will not be updated!" 
37 | ] 38 | }, 39 | { 40 | "attachments": {}, 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### Install package from git" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "%pip install -q git+https://github.com/Kinyugo/msanii.git" 54 | ] 55 | }, 56 | { 57 | "attachments": {}, 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "### Install package in edit mode" 62 | ] 63 | }, 64 | { 65 | "attachments": {}, 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "!git clone https://github.com/Kinyugo/msanii.git\n", 70 | "!cd msanii\n", 71 | "%pip install -q -r requirements.txt\n", 72 | "%pip install -e ." 73 | ] 74 | }, 75 | { 76 | "attachments": {}, 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "### Imports" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "import wandb\n", 90 | "from omegaconf import OmegaConf\n", 91 | "\n", 92 | "from msanii.config import TrainingConfig\n", 93 | "from msanii.scripts import run_training" 94 | ] 95 | }, 96 | { 97 | "attachments": {}, 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "## Training\n" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "dict_config = {\n", 111 | " \"vocoder\": {\n", 112 | " \"datamodule\": {\n", 113 | " \"data_dir\": \"\",\n", 114 | " \"batch_size\": 8,\n", 115 | " \"num_frames\": 523_264,\n", 116 | " \"sample_rate\": 44_100,\n", 117 | " \"load_random_slice\": True,\n", 118 | " \"num_workers\": 8,\n", 119 | " \"pin_memory\": True,\n", 120 | " },\n", 121 | " \"transforms\": {\n", 122 | " \"sample_rate\": 44_100,\n", 123 | " },\n", 124 | " \"lit_vocoder\": {\"sample_every_n_epochs\": 20},\n", 125 | " \"trainer\": {\n", 126 | " \"accumulate_grad_batches\": 1,\n", 127 | " \"log_every_n_steps\": 20,\n", 128 | " \"max_epochs\": 201,\n", 129 | " \"precision\": 16,\n", 130 | " },\n", 131 | " },\n", 132 | " \"diffusion\": {\n", 133 | " \"datamodule\": {\n", 134 | " \"data_dir\": \"\",\n", 135 | " \"batch_size\": 4,\n", 136 | " \"num_frames\": 8_387_584,\n", 137 | " \"sample_rate\": 44_100,\n", 138 | " \"load_random_slice\": True,\n", 139 | " \"num_workers\": 8,\n", 140 | " \"pin_memory\": True,\n", 141 | " },\n", 142 | " \"lit_diffusion\": {\"sample_every_n_epochs\": 20},\n", 143 | " \"trainer\": {\n", 144 | " \"accumulate_grad_batches\": 1,\n", 145 | " \"log_every_n_steps\": 20,\n", 146 | " \"max_epochs\": 501,\n", 147 | " \"precision\": 16,\n", 148 | " },\n", 149 | " },\n", 150 | "}" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "default_config = OmegaConf.structured(TrainingConfig)\n", 160 | "custom_config = OmegaConf.create(dict_config)\n", 161 | "config = OmegaConf.merge(default_config, custom_config)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# Run this if not logged into wandb\n", 171 | "# wandb.login()" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "run_training(config)" 181 | ] 182 | } 183 | ], 184 | "metadata": { 185 | 
"kernelspec": { 186 | "display_name": "torch", 187 | "language": "python", 188 | "name": "python3" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.10.8" 201 | }, 202 | "orig_nbformat": 4, 203 | "vscode": { 204 | "interpreter": { 205 | "hash": "c339664639c3e5019e3803d0baff2aab4fdaac0204aae143f6ed0f1a6cb76161" 206 | } 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 2 211 | } 212 | -------------------------------------------------------------------------------- /msanii/models/unet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from torch import Tensor, nn 7 | 8 | from ..modules import ( 9 | Block, 10 | Downsample, 11 | InputFn, 12 | MiddleBlock, 13 | OutputFn, 14 | TimestepEmbedding, 15 | Upsample, 16 | ) 17 | 18 | 19 | class UNet(ConfigMixin, nn.Module): 20 | config_name = "unet_config.json" 21 | 22 | @register_to_config 23 | def __init__( 24 | self, 25 | d_freq: int = 128, 26 | d_base: int = 256, 27 | d_hidden_factor: int = 4, 28 | d_multipliers: List[int] = [1, 1, 1, 1, 1, 1, 1], 29 | d_timestep: int = 128, 30 | dilations: List[int] = [1, 1, 1, 1, 1, 1, 1], 31 | n_heads: int = 8, 32 | has_attention: List[bool] = [False, False, False, False, False, True, True], 33 | has_resampling: List[bool] = [True, True, True, True, True, True, False], 34 | n_block_layers: List[int] = [2, 2, 2, 2, 2, 2, 2], 35 | ) -> None: 36 | super().__init__() 37 | 38 | self.d_freq = d_freq 39 | self.d_base = d_base 40 | self.d_hidden_factor = d_hidden_factor 41 | self.d_multipliers = d_multipliers 42 | self.d_timestep = d_timestep 43 | self.dilations = dilations 44 | self.n_heads = n_heads 45 | self.has_attention = has_attention 46 | self.has_resampling = has_resampling 47 | self.n_block_layers = n_block_layers 48 | 49 | self.input_fn = InputFn(self.d_freq, self.d_base) 50 | self.output_fn = OutputFn(self.d_freq, self.d_base) 51 | self.timestep_embedding = TimestepEmbedding( 52 | self.d_timestep, self.d_hidden_factor 53 | ) 54 | 55 | self.encoder_blocks = self.__make_encoder_blocks() 56 | self.middle_block = MiddleBlock( 57 | self.d_base * self.d_multipliers[-1], 58 | self.d_base * self.d_multipliers[-1], 59 | self.d_hidden_factor, 60 | self.n_heads, 61 | dilation=1, 62 | d_timestep=self.d_timestep, 63 | ) 64 | self.decoder_blocks = self.__make_decoder_blocks() 65 | 66 | def forward(self, x: Tensor, timestep: Tensor) -> Tensor: 67 | x_embed = self.input_fn(x) 68 | timestep_embed = self.timestep_embedding(timestep) 69 | 70 | x_hidden, enc_hiddens = self.__encode(x_embed, timestep_embed) 71 | x_hidden = self.middle_block(x_hidden, timestep_embed) 72 | x_hidden = self.__decode(x_hidden, enc_hiddens, timestep_embed) 73 | 74 | return self.output_fn(x_hidden) 75 | 76 | def __encode( 77 | self, x_hidden: Tensor, timestep_embed: Tensor 78 | ) -> Tuple[Tensor, List[Tensor]]: 79 | enc_hiddens = [] 80 | for block in self.encoder_blocks: 81 | if isinstance(block, Block): 82 | x_hidden = block(x_hidden, timestep_embed) 83 | enc_hiddens.append(x_hidden) 84 | else: 85 | x_hidden = block(x_hidden) 86 | 87 | return x_hidden, enc_hiddens 88 | 89 | def __decode( 90 | self, 
x_hidden: Tensor, enc_hiddens: List[Tensor], timestep_embed: Tensor 91 | ) -> Tensor: 92 | for block in self.decoder_blocks: 93 | if isinstance(block, Block): 94 | x_hidden = torch.cat((x_hidden, enc_hiddens.pop()), dim=1) 95 | x_hidden = block(x_hidden, timestep_embed) 96 | else: 97 | x_hidden = block(x_hidden) 98 | 99 | return x_hidden 100 | 101 | def __make_encoder_blocks(self) -> nn.ModuleList: 102 | blocks = nn.ModuleList() 103 | 104 | for idx, (d_in, d_out) in enumerate(self.__make_d_pairs()): 105 | # Append layers for the current blocks 106 | for _ in range(self.n_block_layers[idx]): 107 | blocks.append( 108 | self.__make_block( 109 | d_in, d_out, self.dilations[idx], self.has_attention[idx] 110 | ) 111 | ) 112 | d_in = d_out 113 | 114 | # Append downsampling block 115 | if self.has_resampling[idx]: 116 | blocks.append(Downsample(d_out)) 117 | 118 | return blocks 119 | 120 | def __make_decoder_blocks(self) -> nn.ModuleList: 121 | blocks = nn.ModuleList() 122 | 123 | # Append blocks in reverse order 124 | for idx, (d_out, d_in) in enumerate(self.__make_d_pairs()[::-1]): 125 | # Append upsampling blocks 126 | if self.has_resampling[::-1][idx]: 127 | blocks.append(Upsample(d_in)) 128 | 129 | # Append layers for the current block 130 | inner_blocks = nn.ModuleList() 131 | for _ in range(self.n_block_layers[::-1][idx]): 132 | inner_blocks.append( 133 | self.__make_block( 134 | d_in * 2, 135 | d_out, 136 | self.dilations[::-1][idx], 137 | self.has_attention[::-1][idx], 138 | ) 139 | ) 140 | d_out = d_in 141 | 142 | # Append layers in reversed order 143 | blocks.extend(inner_blocks[::-1]) 144 | 145 | return blocks 146 | 147 | def __make_d_pairs(self) -> nn.ModuleList: 148 | dims = np.multiply(self.d_multipliers + self.d_multipliers[-1:], self.d_base) 149 | return list(zip(dims[:-1], dims[1:])) 150 | 151 | def __make_block( 152 | self, d_in: int, d_out: int, dilation: int, has_attention: bool 153 | ) -> Block: 154 | return Block( 155 | d_in, 156 | d_out, 157 | self.d_hidden_factor, 158 | dilation, 159 | self.d_timestep, 160 | self.n_heads, 161 | has_attention, 162 | ) 163 | -------------------------------------------------------------------------------- /msanii/transforms/griffin_lim.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Callable, Optional, Tuple, Union 3 | 4 | import torch 5 | from einops import rearrange 6 | from torch import Tensor, nn 7 | 8 | from .utils import get_complex_dtype 9 | 10 | 11 | class BaseIterativeVocoder(ABC, nn.Module): 12 | def __init__( 13 | self, 14 | num_iters: int = 100, 15 | n_fft: int = 2048, 16 | win_length: Optional[int] = None, 17 | hop_length: Optional[int] = None, 18 | window_fn: Callable[..., Tensor] = torch.hann_window, 19 | wkwargs: Optional[dict] = None, 20 | power: float = 1.0, 21 | eps: float = 1e-16, 22 | ) -> None: 23 | super().__init__() 24 | 25 | self.num_iters = num_iters 26 | self.n_fft = n_fft 27 | self.win_length = n_fft if win_length is None else win_length 28 | self.hop_length = self.win_length // 2 if hop_length is None else hop_length 29 | window = ( 30 | window_fn(self.win_length) 31 | if wkwargs is None 32 | else window_fn(self.win_length, **wkwargs) 33 | ) 34 | self.register_buffer("window", window) 35 | self.power = power 36 | self.eps = eps 37 | 38 | def forward( 39 | self, 40 | specgram: Tensor, 41 | *, 42 | init_phase: Optional[Tensor] = None, 43 | length: Optional[int] = None, 44 | num_iters: Optional[int] = None, 45 | 
return_phase: bool = False, 46 | ) -> Union[Tensor, Tuple[Tensor, Tensor]]: 47 | # Flatten the specgram except the frequency and time dimension 48 | shape = specgram.shape 49 | specgram = rearrange(specgram, "... f t -> (...) f t") 50 | 51 | # Project arbitrary specgram into a magnitude specgram 52 | specgram = specgram.pow(1 / self.power) 53 | 54 | # Initialize the phase 55 | if init_phase is None: 56 | phase = torch.rand( 57 | specgram.shape, 58 | dtype=get_complex_dtype(specgram.dtype), 59 | device=specgram.device, 60 | ) 61 | else: 62 | phase = init_phase / (torch.abs(init_phase) + self.eps) 63 | phase = phase.reshape(specgram.shape) 64 | 65 | # Reconstruct the phase 66 | phase = self.reconstruct_phase( 67 | specgram, phase=phase, length=length, num_iters=num_iters 68 | ) 69 | 70 | # Synthesize the waveform using computed phase 71 | waveform = torch.istft( 72 | specgram * phase, 73 | n_fft=self.n_fft, 74 | win_length=self.win_length, 75 | hop_length=self.hop_length, 76 | window=self.window, 77 | length=length, 78 | ) 79 | 80 | # Unflatten the waveform & phase 81 | waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:]) 82 | phase = phase.reshape(shape) 83 | 84 | if return_phase: 85 | return waveform, phase 86 | return waveform 87 | 88 | @abstractmethod 89 | def reconstruct_phase( 90 | self, 91 | specgram: Tensor, 92 | *, 93 | phase: Optional[Tensor] = None, 94 | length: Optional[int] = None, 95 | num_iters: Optional[int] = None, 96 | ) -> Tensor: 97 | raise NotImplementedError 98 | 99 | def project_onto_magspec(self, magspec: Tensor, stftspec: Tensor) -> Tensor: 100 | return magspec * stftspec / (torch.abs(stftspec) + self.eps) 101 | 102 | def project_complex_spec( 103 | self, stftspec: Tensor, length: Optional[int] = None 104 | ) -> Tensor: 105 | # Invert with our current phase estimates 106 | inverse = torch.istft( 107 | stftspec, 108 | n_fft=self.n_fft, 109 | win_length=self.win_length, 110 | hop_length=self.hop_length, 111 | window=self.window, 112 | length=length, 113 | ) 114 | # inverse = inverse.clamp(min=-1.0, max=1.0) 115 | 116 | # Rebuild the complex spectrogram 117 | rebuilt = torch.stft( 118 | inverse, 119 | n_fft=self.n_fft, 120 | win_length=self.win_length, 121 | hop_length=self.hop_length, 122 | window=self.window, 123 | center=True, 124 | pad_mode="reflect", 125 | normalized=False, 126 | onesided=True, 127 | return_complex=True, 128 | ) 129 | 130 | return rebuilt 131 | 132 | 133 | class GriffinLim(BaseIterativeVocoder): 134 | def __init__( 135 | self, 136 | num_iters: int = 100, 137 | momentum: float = 0.99, 138 | n_fft: int = 2048, 139 | win_length: Optional[int] = None, 140 | hop_length: Optional[int] = None, 141 | window_fn: Callable[..., Tensor] = torch.hann_window, 142 | wkwargs: Optional[dict] = None, 143 | power: float = 1.0, 144 | eps: float = 1e-16, 145 | ) -> None: 146 | super().__init__( 147 | num_iters, n_fft, win_length, hop_length, window_fn, wkwargs, power, eps 148 | ) 149 | 150 | self.momentum = momentum 151 | 152 | def reconstruct_phase( 153 | self, 154 | specgram: Tensor, 155 | *, 156 | phase: Optional[Tensor] = None, 157 | length: Optional[int] = None, 158 | num_iters: Optional[int] = None, 159 | ) -> Tensor: 160 | if num_iters is None: 161 | num_iters = self.num_iters 162 | 163 | momentum = self.momentum / (1 + self.momentum) 164 | 165 | # Initialize our previous iterate 166 | prev_stftspec = torch.tensor(0.0, dtype=specgram.dtype, device=specgram.device) 167 | 168 | for _ in range(num_iters): 169 | # Invert and rebuild with the current phase 
estimate 170 | next_stftspec = self.project_complex_spec(specgram * phase, length) 171 | 172 | # Update our phase estimates 173 | phase = next_stftspec - (prev_stftspec * momentum) 174 | phase = phase / (torch.abs(phase) + self.eps) 175 | 176 | # Update our previous iterate 177 | prev_stftspec = next_stftspec 178 | 179 | return phase 180 | -------------------------------------------------------------------------------- /msanii/lit_models/lit_diffusion.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List, Tuple, Union 3 | 4 | import torch 5 | from diffusers import DDIMScheduler, DPMSolverMultistepScheduler 6 | from lightning.pytorch import LightningModule 7 | from torch import Tensor, nn, optim 8 | 9 | from ..diffusion import Sampler 10 | from ..models import UNet, Vocoder 11 | from ..transforms import Transforms 12 | from ..utils import freeze_model 13 | from .utils import log_samples, update_ema_model 14 | 15 | 16 | class LitDiffusion(LightningModule): 17 | def __init__( 18 | self, 19 | transforms: Transforms, 20 | vocoder: Vocoder, 21 | unet: UNet, 22 | ema_unet: UNet, 23 | scheduler: Union[DDIMScheduler, DPMSolverMultistepScheduler], 24 | sample_rate: int = 44_100, 25 | transforms_decay: float = 0.999, 26 | ema_decay: float = 0.995, 27 | ema_start_step: int = 2000, 28 | ema_update_every: int = 10, 29 | lr: float = 2e-4, 30 | betas: Tuple[float, float] = (0.5, 0.999), 31 | lr_scheduler_start_factor: float = 1 / 3, 32 | lr_scheduler_iters: int = 500, 33 | sample_every_n_epochs: int = 10, 34 | num_samples: int = 4, 35 | num_inference_steps: int = 20, 36 | ) -> None: 37 | super().__init__() 38 | 39 | self.save_hyperparameters( 40 | ignore=["transforms", "vocoder", "unet", "ema_unet", "scheduler"] 41 | ) 42 | 43 | self.transforms = transforms 44 | self.vocoder = freeze_model(vocoder) 45 | self.unet = unet 46 | self.ema_unet = ema_unet 47 | self.scheduler = scheduler 48 | self.sample_rate = sample_rate 49 | self.transforms_decay = transforms_decay 50 | self.ema_decay = ema_decay 51 | self.ema_start_step = ema_start_step 52 | self.ema_update_every = ema_update_every 53 | self.lr = lr 54 | self.betas = betas 55 | self.lr_scheduler_start_factor = lr_scheduler_start_factor 56 | self.lr_scheduler_iters = lr_scheduler_iters 57 | self.sample_every_n_epochs = sample_every_n_epochs 58 | self.num_samples = num_samples 59 | self.num_inference_steps = num_inference_steps 60 | 61 | self.loss_fn = nn.L1Loss() 62 | 63 | def training_step( 64 | self, batch: Union[Tensor, List[Tensor]], batch_idx: int 65 | ) -> Tensor: 66 | # Drop labels if present 67 | if isinstance(batch, list): 68 | batch = batch[0] 69 | 70 | # Transform batch to mel spectrograms & add noise 71 | mel_spectrograms = self.transforms(batch) 72 | noise = torch.randn_like(mel_spectrograms) 73 | timesteps = torch.randint( 74 | self.scheduler.num_train_timesteps, 75 | size=(batch.shape[0],), 76 | dtype=torch.long, 77 | device=self.device, 78 | ) 79 | noisy_mel_spectrograms = self.scheduler.add_noise( 80 | mel_spectrograms, noise, timesteps 81 | ) 82 | # Predict added noise 83 | pred_noise = self.unet(noisy_mel_spectrograms, timesteps) 84 | 85 | # Compute & log reconstruction loss 86 | loss = self.loss_fn(pred_noise, noise) 87 | self.log_dict({"total_loss": loss}) 88 | 89 | # Update & log transforms parameters 90 | self.transforms.step() 91 | self.log_dict(self.transforms.params_dict) 92 | 93 | # Update moving average model parameters 94 | if self.global_step % 
self.ema_update_every == 0: 95 | self.__update_ema_model() 96 | 97 | # Sample & log waveforms, spectrograms, distributions & audio samples 98 | if (self.current_epoch % self.sample_every_n_epochs == 0) and batch_idx == 0: 99 | self.__sample_and_log_samples(batch, mel_spectrograms) 100 | 101 | return loss 102 | 103 | def configure_optimizers(self): 104 | opt = optim.Adam(self.unet.parameters(), lr=self.lr, betas=self.betas) 105 | sched = optim.lr_scheduler.LinearLR( 106 | opt, 107 | start_factor=self.lr_scheduler_start_factor, 108 | total_iters=self.lr_scheduler_iters, 109 | ) 110 | sched = {"scheduler": sched, "interval": "step"} 111 | 112 | return [opt], [sched] 113 | 114 | @torch.no_grad() 115 | def __update_ema_model(self) -> None: 116 | # Copy source model parameters 117 | if self.global_step < self.ema_start_step: 118 | self.ema_unet.load_state_dict(self.unet.state_dict()) 119 | 120 | # Update ema model parameters with the moving average 121 | else: 122 | update_ema_model(self.unet, self.ema_unet, self.ema_decay) 123 | 124 | @torch.no_grad() 125 | def __sample_and_log_samples( 126 | self, waveforms: Tensor, mel_spectrograms: Tensor 127 | ) -> None: 128 | # Ensure the number of samples does not exceed the batch size 129 | num_samples = min(self.num_samples, waveforms.shape[0]) 130 | 131 | # Generate samples 132 | noise = torch.randn_like(mel_spectrograms)[:num_samples] 133 | sampler = Sampler(self.ema_unet, copy.deepcopy(self.scheduler)).to(self.device) 134 | sample_mel_spectrograms = sampler( 135 | noise, num_inference_steps=self.num_inference_steps, verbose=True 136 | ) 137 | 138 | # Ground truth samples 139 | log_samples( 140 | self.logger, 141 | waveforms[:num_samples], 142 | mel_spectrograms[:num_samples], 143 | self.sample_rate, 144 | "ground_truth", 145 | ) 146 | 147 | # Neural vocoder samples 148 | sample_waveforms = self.transforms.griffin_lim( 149 | self.vocoder(sample_mel_spectrograms), length=waveforms.shape[-1] 150 | ) 151 | log_samples( 152 | self.logger, 153 | sample_waveforms, 154 | sample_mel_spectrograms, 155 | self.sample_rate, 156 | "neural_vocoder", 157 | ) 158 | 159 | # Direct reconstruction samples 160 | sample_waveforms = self.transforms.inverse_transform( 161 | sample_mel_spectrograms, length=waveforms.shape[-1] 162 | ) 163 | log_samples( 164 | self.logger, 165 | sample_waveforms, 166 | sample_mel_spectrograms, 167 | self.sample_rate, 168 | "direct_reconstruction", 169 | ) 170 | -------------------------------------------------------------------------------- /msanii/modules/modules.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | import numpy as np 4 | import torch 5 | from einops import rearrange 6 | from einops.layers.torch import Rearrange 7 | from torch import Tensor, nn 8 | 9 | 10 | class Lambda(nn.Module): 11 | def __init__(self, fn: Callable[..., Any]) -> None: 12 | super().__init__() 13 | 14 | self.fn = fn 15 | 16 | def forward(self, *args: Any) -> Any: 17 | return self.fn(*args) 18 | 19 | 20 | def InputFn(d_in: int, d_model: int) -> nn.Sequential: 21 | return nn.Sequential( 22 | Rearrange("b c f t -> b f c t"), 23 | nn.Conv2d(d_in, d_model, kernel_size=1), 24 | ) 25 | 26 | 27 | def OutputFn(d_out: int, d_model: int) -> nn.Sequential: 28 | return nn.Sequential( 29 | nn.Conv2d(d_model, d_out, kernel_size=1), 30 | Rearrange("b f c t -> b c f t"), 31 | ) 32 | 33 | 34 | class TimestepEmbedding(nn.Module): 35 | def __init__(self, d_timestep: int, d_hidden_factor: int) 
-> None: 36 | super().__init__() 37 | 38 | self.d_timestep = d_timestep 39 | 40 | self.fn = nn.Sequential( 41 | nn.Linear(d_timestep, d_timestep * d_hidden_factor), 42 | nn.GELU(), 43 | nn.Linear(d_timestep * d_hidden_factor, d_timestep), 44 | Rearrange("b c -> b c () ()"), 45 | ) 46 | 47 | @property 48 | def dtype(self) -> torch.dtype: 49 | return next(self.parameters()).dtype 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | half_d_timestep = self.d_timestep // 2 53 | emb = np.log(10000) / (half_d_timestep - 1) 54 | emb = torch.exp(torch.arange(half_d_timestep, device=x.device) * -emb) 55 | emb = x[:, None] * emb[None, :] 56 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1).to(self.dtype) 57 | emb = self.fn(emb) 58 | 59 | return emb 60 | 61 | 62 | class ResidualBlock(nn.Module): 63 | def __init__( 64 | self, 65 | d_in: int, 66 | d_out: int, 67 | d_hidden_factor: int, 68 | dilation: int = 1, 69 | d_timestep: Optional[int] = None, 70 | ) -> None: 71 | super().__init__() 72 | 73 | self.in_fn = nn.Sequential( 74 | nn.InstanceNorm2d(d_in), 75 | nn.Conv2d(d_in, d_out, kernel_size=3, dilation=dilation, padding="same"), 76 | ) 77 | self.out_fn = nn.Sequential( 78 | nn.GELU(), 79 | nn.Conv2d(d_out, d_out * d_hidden_factor, kernel_size=1), 80 | nn.GELU(), 81 | nn.Conv2d(d_out * d_hidden_factor, d_out, kernel_size=1), 82 | ) 83 | self.timestep_fn = ( 84 | nn.Conv2d(d_timestep, d_out, kernel_size=1) if d_timestep else None 85 | ) 86 | self.residual_fn = ( 87 | nn.Conv2d(d_in, d_out, kernel_size=1) if d_in != d_out else nn.Identity() 88 | ) 89 | 90 | def forward(self, x: Tensor, timestep_embed: Optional[Tensor] = None) -> Tensor: 91 | x_hidden = self.in_fn(x) 92 | if self.timestep_fn: 93 | x_hidden = x_hidden + self.timestep_fn(timestep_embed) 94 | 95 | return self.out_fn(x_hidden) + self.residual_fn(x) 96 | 97 | 98 | class LinearAttention(nn.Module): 99 | def __init__(self, d_model: int, n_heads: int) -> None: 100 | super().__init__() 101 | 102 | self.scale = 1 / np.sqrt(d_model // n_heads) 103 | 104 | self.in_fn = nn.Sequential( 105 | nn.InstanceNorm2d(d_model), 106 | nn.Conv2d(d_model, d_model * 3, kernel_size=1), 107 | Lambda(lambda x: torch.chunk(x, chunks=3, dim=1)), 108 | Lambda( 109 | lambda x: [ 110 | rearrange(t, "b (h f) c t -> (b c) h f t", h=n_heads) for t in x 111 | ] 112 | ), 113 | ) 114 | self.out_fn = nn.Conv2d(d_model, d_model, kernel_size=1) 115 | 116 | def forward(self, x: Tensor) -> Tensor: 117 | q, k, v = self.in_fn(x) 118 | 119 | q = q.softmax(dim=-2) 120 | k = k.softmax(dim=-1) 121 | 122 | q = q * self.scale 123 | v = v / v.shape[-1] 124 | 125 | context = torch.einsum("b h k l, b h v l -> b h k v", k, v) 126 | 127 | out = torch.einsum("b h d v, b h d l -> b h v l", context, q) 128 | out = rearrange(out, "(b c) h f t -> b (h f) c t", c=x.shape[-2]) 129 | 130 | return self.out_fn(out) + x 131 | 132 | 133 | class Block(nn.Module): 134 | def __init__( 135 | self, 136 | d_in: int, 137 | d_out: int, 138 | d_hidden_factor: int, 139 | dilation: int = 1, 140 | d_timestep: Optional[int] = None, 141 | n_heads: int = 8, 142 | has_attention: bool = True, 143 | ) -> None: 144 | super().__init__() 145 | 146 | self.residual_fn = ResidualBlock( 147 | d_in, d_out, d_hidden_factor, dilation, d_timestep 148 | ) 149 | self.attention_fn = ( 150 | LinearAttention(d_out, n_heads) if has_attention else nn.Identity() 151 | ) 152 | 153 | def forward(self, x: Tensor, timestep_embed: Optional[Tensor] = None) -> Tensor: 154 | return self.attention_fn(self.residual_fn(x, timestep_embed)) 155 | 156 | 
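# Tensor layout note: InputFn rearranges spectrogram batches from
# (batch, audio_channels, freq, time) to (batch, freq, audio_channels, time) and
# projects the frequency axis into d_model, so the Conv2d layers in the blocks
# above treat frequency bins as channels while (audio_channels, time) form the
# spatial grid; LinearAttention then attends over the time axis independently
# for each audio channel. MiddleBlock below is the UNet bottleneck: a residual
# block, linear attention, then a second residual block, with the residual
# blocks conditioned on the timestep embedding.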
157 | class MiddleBlock(nn.Module): 158 | def __init__( 159 | self, 160 | d_in: int, 161 | d_out: int, 162 | d_hidden_factor: int, 163 | n_heads: int = 8, 164 | dilation: int = 1, 165 | d_timestep: int = None, 166 | ) -> None: 167 | super().__init__() 168 | 169 | self.in_residual_fn = ResidualBlock( 170 | d_in, d_out, d_hidden_factor, dilation, d_timestep 171 | ) 172 | self.attention_fn = LinearAttention(d_out, n_heads) 173 | self.out_residual_fn = ResidualBlock( 174 | d_in, d_out, d_hidden_factor, dilation, d_timestep 175 | ) 176 | 177 | def forward(self, x: Tensor, timestep_embed: Optional[Tensor] = None) -> Tensor: 178 | x_hidden = self.in_residual_fn(x, timestep_embed) 179 | x_hidden = self.attention_fn(x_hidden) 180 | 181 | return self.out_residual_fn(x_hidden, timestep_embed) 182 | 183 | 184 | def Downsample(d_model: int) -> nn.Conv2d: 185 | return nn.Conv2d( 186 | d_model, d_model, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1) 187 | ) 188 | 189 | 190 | def Upsample(d_model: int) -> nn.ConvTranspose2d: 191 | return nn.ConvTranspose2d( 192 | d_model, d_model, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1) 193 | ) 194 | -------------------------------------------------------------------------------- /msanii/config/training_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | from omegaconf import MISSING 5 | 6 | 7 | @dataclass 8 | class AudioDataModuleConfig: 9 | _target_: str = "msanii.data.AudioDataModule" 10 | data_dir: str = MISSING 11 | sample_rate: int = 44_100 12 | num_frames: Optional[int] = None 13 | load_random_slice: bool = False 14 | normalize_amplitude: bool = True 15 | batch_size: int = 32 16 | num_workers: int = 4 17 | pin_memory: bool = False 18 | 19 | 20 | @dataclass 21 | class TransformsConfig: 22 | _target_: str = "msanii.transforms.Transforms" 23 | sample_rate: int = 44_100 24 | n_fft: int = 2048 25 | win_length: Optional[int] = None 26 | hop_length: Optional[int] = None 27 | n_mels: int = 128 28 | feature_range: Tuple[float, float] = (-1.0, 1.0) 29 | momentum: float = 1e-3 30 | eps: float = 1e-5 31 | clip: bool = True 32 | num_griffin_lim_iters: int = 100 33 | griffin_lim_momentum: float = 0.99 34 | 35 | 36 | @dataclass 37 | class VocoderConfig: 38 | _target_: str = "msanii.models.Vocoder" 39 | n_fft: int = 2048 40 | n_mels: int = 128 41 | d_model: int = 256 42 | d_hidden_factor: int = 4 43 | 44 | 45 | @dataclass 46 | class UNetConfig: 47 | _target_: str = "msanii.models.UNet" 48 | d_freq: int = 128 49 | d_base: int = 256 50 | d_hidden_factor: int = 4 51 | d_multipliers: List[int] = field(default_factory=lambda: [1, 1, 1, 1, 1, 1, 1]) 52 | d_timestep: int = 128 53 | dilations: List[int] = field(default_factory=lambda: [1, 1, 1, 1, 1, 1, 1]) 54 | n_heads: int = 8 55 | has_attention: List[bool] = field( 56 | default_factory=lambda: [False, False, False, False, False, True, True] 57 | ) 58 | has_resampling: List[bool] = field( 59 | default_factory=lambda: [True, True, True, True, True, True, False] 60 | ) 61 | n_block_layers: List[int] = field(default_factory=lambda: [2, 2, 2, 2, 2, 2, 2]) 62 | 63 | 64 | @dataclass 65 | class SchedulerConfig: 66 | _target_: str = "diffusers.DDIMScheduler" 67 | num_train_timesteps: int = 1000 68 | beta_schedule: str = "squaredcos_cap_v2" 69 | _kwargs_: Dict[str, Any] = field(default_factory=lambda: {}) 70 | 71 | 72 | @dataclass 73 | class LitVocoderConfig: 74 | _target_: str = 
"msanii.lit_models.LitVocoder" 75 | sample_rate: int = 44_100 76 | transforms_decay: float = 0.999 77 | lr: float = 2e-4 78 | betas: Tuple[float, float] = (0.5, 0.999) 79 | lr_scheduler_start_factor: float = 1 / 3 80 | lr_scheduler_iters: int = 500 81 | sample_every_n_epochs: int = 10 82 | num_samples: int = 4 83 | 84 | 85 | @dataclass 86 | class LitDiffusionConfig: 87 | _target_: str = "msanii.lit_models.LitDiffusion" 88 | sample_rate: int = 44_100 89 | transforms_decay: float = 0.999 90 | ema_decay: float = 0.995 91 | ema_start_step: int = 2000 92 | ema_update_every: int = 10 93 | lr: float = 2e-4 94 | betas: Tuple[float, float] = (0.5, 0.999) 95 | lr_scheduler_start_factor: float = 1 / 3 96 | lr_scheduler_iters: int = 500 97 | sample_every_n_epochs: int = 10 98 | num_samples: int = 4 99 | num_inference_steps: int = 20 100 | 101 | 102 | @dataclass 103 | class WandbLoggerConfig: 104 | _target_: str = "lightning.pytorch.loggers.WandbLogger" 105 | save_dir: str = "logs" 106 | project: str = "msanii" 107 | name: Optional[str] = None 108 | job_type: Optional[str] = "train" 109 | log_model: Union[str, bool] = True 110 | tags: Optional[List[str]] = None 111 | notes: Optional[str] = None 112 | save_code: Optional[bool] = True 113 | offline: bool = False 114 | _kwargs_: Dict[str, Any] = field(default_factory=lambda: {}) 115 | 116 | 117 | @dataclass 118 | class ModelCheckpointConfig: 119 | _target_: str = "lightning.pytorch.callbacks.ModelCheckpoint" 120 | dirpath: Optional[str] = None 121 | save_last: Optional[bool] = True 122 | verbose: bool = False 123 | mode: str = "min" 124 | _kwargs_: Dict[str, Any] = field(default_factory=lambda: {}) 125 | 126 | 127 | @dataclass 128 | class TrainerConfig: 129 | _target_: str = "lightning.Trainer" 130 | accelerator: Optional[str] = "auto" 131 | accumulate_grad_batches: int = 1 132 | devices: Optional[Union[int, str]] = None 133 | default_root_dir: Optional[str] = None 134 | detect_anomaly: bool = False 135 | gradient_clip_val: float = 1.0 136 | gradient_clip_algorithm: str = "norm" 137 | limit_train_batches: Optional[Union[int, float]] = 1.0 138 | log_every_n_steps: int = 10 139 | precision: Union[int, str] = 32 140 | max_epochs: Optional[int] = 6 141 | max_steps: int = -1 142 | weights_save_path: Optional[str] = None 143 | fast_dev_run: bool = False 144 | _kwargs_: Dict[str, Any] = field(default_factory=lambda: {}) 145 | 146 | 147 | @dataclass 148 | class VocoderTrainingConfig: 149 | datamodule: AudioDataModuleConfig = field(default_factory=AudioDataModuleConfig) 150 | transforms: TransformsConfig = field(default_factory=TransformsConfig) 151 | vocoder: VocoderConfig = field(default_factory=VocoderConfig) 152 | lit_vocoder: LitVocoderConfig = field(default_factory=LitVocoderConfig) 153 | wandb_logger: WandbLoggerConfig = field(default_factory=WandbLoggerConfig) 154 | model_checkpoint: ModelCheckpointConfig = field( 155 | default_factory=ModelCheckpointConfig 156 | ) 157 | trainer: TrainerConfig = field(default_factory=TrainerConfig) 158 | 159 | skip_training: bool = False 160 | resume_ckpt_path: Optional[str] = None 161 | 162 | 163 | @dataclass 164 | class DiffusionTrainingConfig: 165 | datamodule: AudioDataModuleConfig = field(default_factory=AudioDataModuleConfig) 166 | unet: UNetConfig = field(default_factory=UNetConfig) 167 | scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) 168 | lit_diffusion: LitDiffusionConfig = field(default_factory=LitDiffusionConfig) 169 | wandb_logger: WandbLoggerConfig = 
field(default_factory=WandbLoggerConfig) 170 | model_checkpoint: ModelCheckpointConfig = field( 171 | default_factory=ModelCheckpointConfig 172 | ) 173 | trainer: TrainerConfig = field(default_factory=TrainerConfig) 174 | 175 | skip_training: bool = False 176 | resume_ckpt_path: Optional[str] = None 177 | 178 | 179 | @dataclass 180 | class TrainingConfig: 181 | vocoder: VocoderTrainingConfig = field(default_factory=VocoderTrainingConfig) 182 | diffusion: DiffusionTrainingConfig = field(default_factory=DiffusionTrainingConfig) 183 | 184 | seed: int = 0 185 | pipeline_wandb_name: str = "msanii_pipeline" 186 | pipeline_ckpt_path: str = "checkpoints/msanii.pt" 187 | -------------------------------------------------------------------------------- /msanii/scripts/training.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import Tuple, Union 3 | 4 | import lightning as L 5 | import matplotlib 6 | from diffusers import DDIMPipeline, DPMSolverMultistepScheduler 7 | from lightning.pytorch.callbacks import RichModelSummary, TQDMProgressBar 8 | from lightning.pytorch.loggers import WandbLogger 9 | from omegaconf import OmegaConf 10 | 11 | from ..config import ( 12 | DiffusionTrainingConfig, 13 | TrainingConfig, 14 | VocoderTrainingConfig, 15 | from_config, 16 | ) 17 | from ..models import UNet, Vocoder 18 | from ..pipeline import Pipeline 19 | from ..transforms import Transforms 20 | from ..utils import clone_model_parameters 21 | 22 | 23 | def run_vocoder_training(config: VocoderTrainingConfig) -> Tuple[Transforms, Vocoder]: 24 | # ------------------------------------------- 25 | # Data & Transforms 26 | # ------------------------------------------- 27 | datamodule = from_config(config.datamodule) 28 | transforms = from_config(config.transforms) 29 | 30 | # ----------------------------------------- 31 | # Model 32 | # ------------------------------------------ 33 | vocoder = from_config(config.vocoder) 34 | 35 | # ----------------------------------------- 36 | # Lit Model 37 | # ------------------------------------------ 38 | lit_vocoder = from_config( 39 | config.lit_vocoder, transforms=transforms, vocoder=vocoder 40 | ) 41 | 42 | # ----------------------------------------- 43 | # Logger & Callbacks 44 | # ------------------------------------------ 45 | wandb_logger = from_config(config.wandb_logger) 46 | model_checkpoint = from_config(config.model_checkpoint) 47 | callbacks = [model_checkpoint, RichModelSummary(), TQDMProgressBar()] 48 | 49 | # ----------------------------------------- 50 | # Trainer 51 | # ------------------------------------------ 52 | trainer = from_config(config.trainer, logger=wandb_logger, callbacks=callbacks) 53 | 54 | # ----------------------------------------- 55 | # Run Training 56 | # ------------------------------------------ 57 | # Save config to wandb 58 | wandb_logger.experiment.config.update(dict(config)) 59 | 60 | # Optionally run training 61 | if not config.skip_training: 62 | trainer.fit( 63 | lit_vocoder, datamodule=datamodule, ckpt_path=config.resume_ckpt_path 64 | ) 65 | 66 | # Terminate wandb run 67 | wandb_logger.experiment.finish() 68 | 69 | return lit_vocoder.transforms, lit_vocoder.vocoder 70 | 71 | 72 | def run_diffusion_training( 73 | config: DiffusionTrainingConfig, transforms: Transforms, vocoder: Vocoder 74 | ) -> Tuple[ 75 | Transforms, Vocoder, Union[DDIMPipeline, DPMSolverMultistepScheduler], WandbLogger 76 | ]: 77 | # 
------------------------------------------- 78 | # Data 79 | # ------------------------------------------- 80 | datamodule = from_config(config.datamodule) 81 | 82 | # ----------------------------------------- 83 | # Models 84 | # ------------------------------------------ 85 | unet = from_config(config.unet) 86 | ema_unet = from_config(config.unet) 87 | ema_unet = clone_model_parameters(unet, ema_unet) 88 | 89 | # ----------------------------------------- 90 | # Scheduler 91 | # ------------------------------------------ 92 | scheduler = from_config(config.scheduler) 93 | 94 | # ----------------------------------------- 95 | # Lit Model 96 | # ------------------------------------------ 97 | lit_diffusion = from_config( 98 | config.lit_diffusion, 99 | transforms=transforms, 100 | vocoder=vocoder, 101 | unet=unet, 102 | ema_unet=ema_unet, 103 | scheduler=scheduler, 104 | ) 105 | 106 | # ----------------------------------------- 107 | # Logger & Callbacks 108 | # ------------------------------------------ 109 | wandb_logger = from_config(config.wandb_logger, reinit=True) 110 | model_checkpoint = from_config(config.model_checkpoint) 111 | callbacks = [model_checkpoint, RichModelSummary(), TQDMProgressBar()] 112 | 113 | # ----------------------------------------- 114 | # Trainer 115 | # ------------------------------------------ 116 | trainer = from_config(config.trainer, logger=wandb_logger, callbacks=callbacks) 117 | 118 | # ----------------------------------------- 119 | # Run Training 120 | # ------------------------------------------ 121 | # Save config to wandb 122 | wandb_logger.experiment.config.update(dict(config)) 123 | 124 | # Optionally run training 125 | if not config.skip_training: 126 | trainer.fit( 127 | lit_diffusion, datamodule=datamodule, ckpt_path=config.resume_ckpt_path 128 | ) 129 | 130 | return ( 131 | lit_diffusion.transforms, 132 | lit_diffusion.ema_unet, 133 | lit_diffusion.scheduler, 134 | wandb_logger, 135 | ) 136 | 137 | 138 | def run_training(config: TrainingConfig) -> None: 139 | # ------------------------------------------- 140 | # Reproducibility 141 | # ------------------------------------------- 142 | L.seed_everything(config.seed) 143 | 144 | # ------------------------------------------- 145 | # Configure Matplotlib 146 | # ------------------------------------------- 147 | # Prevents pixelated fonts on figures 148 | matplotlib.use("webagg") 149 | matplotlib.style.use(["seaborn", "fast"]) 150 | 151 | # ------------------------------------------- 152 | # Train Vocoder 153 | # ------------------------------------------- 154 | transforms, vocoder = run_vocoder_training(config.vocoder) 155 | 156 | # ------------------------------------------- 157 | # Train Diffusion 158 | # ------------------------------------------- 159 | transforms, unet, scheduler, wandb_logger = run_diffusion_training( 160 | config.diffusion, transforms, vocoder 161 | ) 162 | 163 | # ------------------------------------------- 164 | # Save Pipeline Checkpoint 165 | # ------------------------------------------- 166 | pipeline = Pipeline(transforms, vocoder, unet, scheduler) 167 | pipeline.save_pretrained(config.pipeline_ckpt_path) 168 | 169 | # ------------------------------------------- 170 | # Log checkpoint to Wandb 171 | # ------------------------------------------- 172 | artifact = wandb_logger.experiment.Artifact( 173 | config.pipeline_wandb_name, type="model" 174 | ) 175 | artifact.add_file(config.pipeline_ckpt_path) 176 | wandb_logger.experiment.log_artifact(artifact) 177 | 178 | 179 | 
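# Example CLI invocation (the config path is a required positional argument;
# any sample under conf/, e.g. conf/sample_training_config.yml, can be adapted):
#
#     python -m msanii.scripts.training conf/sample_training_config.yml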
if __name__ == "__main__": 180 | parser = ArgumentParser() 181 | parser.add_argument("config_path", help="path to config file", type=str) 182 | args = parser.parse_args() 183 | 184 | default_training_config = OmegaConf.structured(TrainingConfig) 185 | file_training_config = OmegaConf.load(args.config_path) 186 | training_config = OmegaConf.merge(default_training_config, file_training_config) 187 | 188 | run_training(training_config) 189 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Msanii: High Fidelity Music Synthesis on a Shoestring Budget 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2301.06468-.svg)](https://arxiv.org/abs/2301.06468) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/kinyugo/msanii) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Kinyugo/msanii/blob/main/notebooks/msanii_demo.ipynb) [![GitHub Repo stars](https://img.shields.io/github/stars/Kinyugo/msanii?style=social) ](https://github.com/Kinyugo/msanii) 4 | 5 | A novel diffusion-based model for synthesizing long-context, high-fidelity music efficiently. 6 | 7 | ## Abstract 8 | 9 | > In this paper, we present Msanii, a novel diffusion-based model for synthesizing long-context, high-fidelity music efficiently. Our model combines the expressiveness of mel spectrograms, the generative capabilities of diffusion models, and the vocoding capabilities of neural vocoders. We demonstrate the effectiveness of Msanii by synthesizing tens of seconds (_190 seconds_) of _stereo_ music at high sample rates (_44.1 kHz_) without the use of concatenative synthesis, cascading architectures, or compression techniques. To the best of our knowledge, this is the first work to successfully employ a diffusion-based model for synthesizing such long music samples at high sample rates. Our demo can be found [here](https://kinyugo.github.io/msanii-demo) and our code [here](https://github.com/Kinyugo/msanii). 10 | 11 | ## Disclaimer 12 | 13 | This is a work in progress and has not been finalized. The results and approach presented are subject to change and should not be considered final. 14 | 15 | ## Samples 16 | 17 | See more [here](https://kinyugo.github.io/msanii-demo/). 18 | 19 | | **Midnight Melodies** | **Echoes of Yesterday** | 20 | | :---------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------: | 21 | | [ ![ Midnight Melodies ](http://img.youtube.com/vi/cFrpR0wc_A4/0.jpg) ](http://www.youtube.com/watch?v=cFrpR0wc_A4 "Midnight Melodies") | [ ![ Echoes of Yesterday ](http://img.youtube.com/vi/tWlEqkRxZSU/0.jpg) ](http://www.youtube.com/watch?v=tWlEqkRxZSU "Echoes of Yesterday") | 22 | | **Rainy Day Reflections** | **Starlight Sonatas** | 23 | | [ ![ Rainy Day Reflections ](http://img.youtube.com/vi/-ZikAJxNomM/0.jpg) ](http://www.youtube.com/watch?v=-ZikAJxNomM "Rainy Day Reflections") | [ ![ Starlight Sonatas ](http://img.youtube.com/vi/3adYlNVZSxA/0.jpg) ](http://www.youtube.com/watch?v=3adYlNVZSxA "Starlight Sonatas") | 24 | 25 | ## Setup 26 | 27 | Setup your virtual environment using conda or venv. 
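For example, a conda environment could look like the following (the environment name is arbitrary; the notebooks in this repository were run with Python 3.10):

```bash
conda create -n msanii python=3.10
conda activate msanii
```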
28 | 29 | ### Install package from git 30 | 31 | ```bash 32 | pip install -q git+https://github.com/Kinyugo/msanii.git 33 | ``` 34 | 35 | ### Install package in edit mode 36 | 37 | ```bash 38 | git clone https://github.com/Kinyugo/msanii.git 39 | cd msanii 40 | pip install -q -r requirements.txt 41 | pip install -e . 42 | ``` 43 | 44 | ## Training 45 | 46 | ### Notebook 47 | 48 | 49 | Open In Colab 50 | 51 | 52 | ### CLI 53 | 54 | To train via CLI you need to define a config file. Check for sample config files within the `conf` directory. 55 | 56 | ```bash 57 | wandb login 58 | python -m msanii.scripts.training 59 | ``` 60 | 61 | ## Inference 62 | 63 | ### Notebook 64 | 65 | 66 | Open In Colab 67 | 68 | 69 | ### CLI 70 | 71 | Msanii supports the following inference tasks: 72 | 73 | - sampling 74 | - audio2audio 75 | - interpolation 76 | - inpainting 77 | - outpainting 78 | 79 | Each task requires a different config file. Check `conf` directory for samples. 80 | 81 | ```bash 82 | gdown 1G9kF0r5vxYXPSdSuv4t3GR-sBO8xGFCe # model checkpoint 83 | python -m msanii.scripts.inference 84 | ``` 85 | 86 | ## Demo 87 | 88 | ### HF Spaces & Notebook 89 | 90 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/kinyugo/msanii) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Kinyugo/msanii/blob/main/notebooks/msanii_demo.ipynb) 91 | 92 | ### CLI 93 | 94 | To run the demo via CLI you need to define a config file. Check for sample config files within the `conf` directory. 95 | 96 | ```bash 97 | gdown 1G9kF0r5vxYXPSdSuv4t3GR-sBO8xGFCe # model checkpoint 98 | python -m msanii.demo.demo 99 | ``` 100 | 101 | ## Contribute to the Project 102 | 103 | We are always looking for ways to improve and expand our project, and we welcome contributions from the community. Here are a few ways you can get involved: 104 | 105 | - **Bug Fixes and Feature Requests:** If you find any issues with the project, please open a GitHub issue or submit a pull request with a fix. 106 | - **Data Collection:** We are always in need of more data to improve the performance of our models. If you have any relevant data that you would like to share, please let us know. 107 | - **Feedback:** We value feedback from our users and would love to hear your thoughts on the project. Please feel free to reach out to us with any suggestions or comments. 108 | - **Funding:** If you find our project helpful, consider supporting us through GitHub Sponsors. Your support will help us continue to maintain and improve the project. 109 | - **Computational Resources:** If you have access to computational resources such as GPU clusters, you can help us by providing access to these resources to run experiments and improve the project. 110 | - **Code Contributions:** If you are a developer and want to contribute to the codebase, feel free to open a pull request. 111 | - **Documentation:** If you have experience with documentation and want to help improve the project's documentation please let us know. 112 | - **Promotion:** Help increase the visibility and attract more contributors by sharing the project with your friends, colleagues, and on social media. 113 | - **Educational Material:** If you are an educator or content creator you can help by creating tutorials, guides or educational material that can help others understand the project better. 
114 | - **Discussing and Sharing Ideas:** Even if you don't have the time or technical skills to contribute directly to the code or documentation, you can still help by sharing and discussing ideas with the community. This can help identify new features or use cases, or find ways to improve existing ones. 115 | - **Ethical Review:** Help us ensure that the project follows ethical standards by reviewing data and models for potential infringements. Additionally, please do not use the project or its models to train or generate copyrighted works without proper authorization. 116 | -------------------------------------------------------------------------------- /msanii/diffusion/modules.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple, Union 2 | 3 | import torch 4 | from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, RePaintScheduler 5 | from torch import Tensor, nn 6 | from tqdm.autonotebook import tqdm 7 | 8 | from .utils import noise_like, sequential_mask 9 | 10 | 11 | class DiffusionModule(nn.Module): 12 | def __init__( 13 | self, 14 | eps_model: nn.Module, 15 | scheduler: Union[DDIMScheduler, DPMSolverMultistepScheduler], 16 | ) -> None: 17 | super().__init__() 18 | 19 | self.eps_model = eps_model 20 | self.scheduler = scheduler 21 | 22 | self.scheduler.set_timesteps(self.scheduler.num_train_timesteps) 23 | 24 | @property 25 | def device(self) -> torch.device: 26 | next(self.eps_model.parameters()).device 27 | 28 | def compute_timesteps( 29 | self, num_inference_steps: int, strength: float, device: torch.device 30 | ) -> Tuple[Tensor, int]: 31 | if strength > 1.0 or strength <= 0.0: 32 | raise ValueError(f"`strength`: {strength} must be between 0 and 1") 33 | 34 | # Ensure number of inference steps is not more than the start step 35 | num_inference_steps = min( 36 | num_inference_steps, round(strength * self.scheduler.num_train_timesteps) 37 | ) 38 | 39 | # Compute the total number of inference steps for the whole sampling schedule 40 | total_inference_steps = round(num_inference_steps / strength) 41 | total_inference_steps = min( 42 | total_inference_steps, self.scheduler.num_train_timesteps 43 | ) 44 | 45 | # Compute the step ratio using the total inference steps 46 | step_ratio = self.scheduler.num_train_timesteps // total_inference_steps 47 | 48 | # Generate timesteps 49 | timesteps = ( 50 | (torch.arange(start=0, end=num_inference_steps, device=device) * step_ratio) 51 | .round() 52 | .flipud() 53 | ) 54 | 55 | return timesteps, total_inference_steps 56 | 57 | 58 | class Sampler(DiffusionModule): 59 | def forward( 60 | self, 61 | x: Tensor, 62 | num_inference_steps: Optional[int] = None, 63 | strength: float = 1.0, 64 | generator: Optional[torch.Generator] = None, 65 | verbose: bool = False, 66 | **kwargs: Any, 67 | ) -> Tensor: 68 | # Set number of timesteps for denoising 69 | timesteps, total_inference_steps = self.compute_timesteps( 70 | num_inference_steps, strength, x.device 71 | ) 72 | self.scheduler.set_timesteps(total_inference_steps, x.device) 73 | 74 | for t in tqdm(timesteps, disable=(not verbose)): 75 | # Prepare a batch of timesteps 76 | batch_t = torch.full( 77 | (x.shape[0],), t.item(), dtype=torch.long, device=x.device 78 | ) 79 | # Get model estimate 80 | eps = self.eps_model(x, batch_t) 81 | # Compute the denoised sample 82 | if isinstance(self.scheduler, DDIMScheduler): 83 | x = self.scheduler.step( 84 | eps, t, x, generator=generator, **kwargs 85 | ).prev_sample 86 | else: 
87 | x = self.scheduler.step(eps, t, x, **kwargs).prev_sample 88 | 89 | return x 90 | 91 | 92 | class Interpolater(DiffusionModule): 93 | def __init__( 94 | self, 95 | eps_model: nn.Module, 96 | scheduler: Union[DDIMScheduler, DPMSolverMultistepScheduler], 97 | ) -> None: 98 | super().__init__(eps_model, scheduler) 99 | 100 | self.sampler = Sampler(self.eps_model, self.scheduler) 101 | 102 | def forward( 103 | self, 104 | x1: Tensor, 105 | x2: Tensor, 106 | ratio: float = 0.5, 107 | num_inference_steps: Optional[int] = None, 108 | strength: float = 1.0, 109 | generator: Optional[torch.Generator] = None, 110 | verbose: bool = False, 111 | **kwargs: Any, 112 | ) -> Tensor: 113 | # Get timestep schedule for the denoising process 114 | timesteps, _ = self.compute_timesteps(num_inference_steps, strength, x1.device) 115 | 116 | # Add noise up to the starting timestep 117 | batch_t = torch.full( 118 | (x1.shape[0],), timesteps[0], dtype=torch.long, device=x1.device 119 | ) 120 | x1_noisy = self.scheduler.add_noise(x1, noise_like(x1, generator), batch_t) 121 | x2_noisy = self.scheduler.add_noise(x2, noise_like(x2, generator), batch_t) 122 | 123 | # Interpolate between the two samples in latent/noisy space 124 | x = ratio * x1_noisy + (1 - ratio) * x2_noisy 125 | x = self.sampler(x, num_inference_steps, strength, generator, verbose, **kwargs) 126 | 127 | return x 128 | 129 | 130 | class Inpainter(DiffusionModule): 131 | def __init__( 132 | self, 133 | eps_model: nn.Module, 134 | scheduler: Union[DDIMScheduler, DPMSolverMultistepScheduler], 135 | ) -> None: 136 | super().__init__(eps_model, scheduler) 137 | 138 | self.repaint_scheduler = RePaintScheduler.from_config(self.scheduler.config) 139 | 140 | def forward( 141 | self, 142 | x: Tensor, 143 | mask: Tensor, 144 | num_inference_steps: Optional[int] = None, 145 | eta: float = 0.0, 146 | jump_length=10, 147 | jump_n_sample=10, 148 | generator: Optional[torch.Generator] = None, 149 | verbose: bool = False, 150 | **kwargs: Any, 151 | ) -> Tensor: 152 | # Iterate over all timesteps 153 | if num_inference_steps is None: 154 | num_inference_steps = self.scheduler.num_train_timesteps 155 | 156 | # Adjust scheduler inference steps 157 | self.repaint_scheduler.set_timesteps( 158 | num_inference_steps, jump_length, jump_n_sample, x.device 159 | ) 160 | self.repaint_scheduler.eta = eta 161 | 162 | # Start inpainting from complete noise 163 | x_inpainted = noise_like(x) 164 | 165 | t_last = self.repaint_scheduler.timesteps[0] + 1 166 | for t in tqdm(self.repaint_scheduler.timesteps, disable=(not verbose)): 167 | if t < t_last: 168 | # Prepare a batch of timesteps 169 | batch_t = torch.full( 170 | (x.shape[0],), t, dtype=torch.long, device=x.device 171 | ) 172 | # Get model estimate 173 | eps = self.eps_model(x_inpainted, batch_t) 174 | # Compute the denoised sample 175 | x_inpainted = self.repaint_scheduler.step( 176 | eps, t, x_inpainted, x, mask, generator, **kwargs 177 | ).prev_sample 178 | else: 179 | x_inpainted = self.repaint_scheduler.undo_step( 180 | x_inpainted, t_last, generator 181 | ) 182 | t_last = t 183 | 184 | return x_inpainted 185 | 186 | 187 | class Outpainter(DiffusionModule): 188 | def __init__( 189 | self, 190 | eps_model: nn.Module, 191 | scheduler: Union[DDIMScheduler, DPMSolverMultistepScheduler], 192 | ) -> None: 193 | super().__init__(eps_model, scheduler) 194 | 195 | self.inpainter = Inpainter(eps_model, scheduler) 196 | 197 | def forward( 198 | self, 199 | x: Tensor, 200 | num_spans: int = 1, 201 | num_inference_steps: 
Optional[int] = None, 202 | eta: float = 0.0, 203 | jump_length=10, 204 | jump_n_sample=10, 205 | generator: Optional[torch.Generator] = None, 206 | verbose: bool = False, 207 | **kwargs: Any, 208 | ) -> Tensor: 209 | half_length = x.shape[-1] // 2 210 | 211 | spans = list(x.chunk(chunks=2, dim=-1)) 212 | # Inpaint second half from first half 213 | inpaint = torch.zeros_like(x) 214 | inpaint[..., :half_length] = x[..., half_length:] 215 | inpaint_mask = sequential_mask(like=x, start=half_length).to(x.dtype) 216 | 217 | for _ in range(num_spans): 218 | # Inpaint second half 219 | span = self.inpainter( 220 | inpaint, 221 | inpaint_mask, 222 | num_inference_steps, 223 | eta, 224 | jump_length, 225 | jump_n_sample, 226 | generator, 227 | verbose, 228 | **kwargs, 229 | ) 230 | # Replace first half with generated second half 231 | second_half = span[..., half_length:] 232 | inpaint[..., :half_length] = second_half 233 | # Save generated span 234 | spans.append(second_half) 235 | 236 | return torch.cat(spans, dim=-1) 237 | -------------------------------------------------------------------------------- /msanii/pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Union 2 | 3 | import torch 4 | from diffusers import DDIMScheduler, DPMSolverMultistepScheduler 5 | from einops import repeat 6 | from torch import Tensor, nn 7 | from torch.nn import functional as F 8 | from typing_extensions import Self 9 | 10 | from ..config import from_config 11 | from ..diffusion import Inpainter, Interpolater, Outpainter, Sampler 12 | from ..diffusion.utils import noise_like 13 | from ..models import UNet, Vocoder 14 | from ..transforms import Transforms 15 | 16 | 17 | class Pipeline(nn.Module): 18 | def __init__( 19 | self, 20 | transforms: Transforms, 21 | vocoder: Vocoder, 22 | unet: UNet, 23 | scheduler: Union[DDIMScheduler, DPMSolverMultistepScheduler], 24 | ) -> None: 25 | super().__init__() 26 | 27 | self.transforms = transforms 28 | self.vocoder = vocoder 29 | self.unet = unet 30 | self.scheduler = scheduler 31 | 32 | @property 33 | def dtype(self) -> torch.dtype: 34 | return next(self.unet.parameters()).dtype 35 | 36 | @property 37 | def device(self) -> torch.device: 38 | return next(self.unet.parameters()).device 39 | 40 | @torch.no_grad() 41 | def forward(self, *args: Any, **kwds: Any) -> Any: 42 | raise NotImplementedError 43 | 44 | @torch.no_grad() 45 | def sample( 46 | self, 47 | x: Tensor, 48 | num_inference_steps: Optional[int] = None, 49 | strength: float = 1.0, 50 | generator: Optional[torch.Generator] = None, 51 | verbose: bool = False, 52 | use_input_as_seed: bool = False, 53 | use_neural_vocoder: bool = True, 54 | num_griffin_lim_iters: Optional[int] = None, 55 | **kwargs: Any, 56 | ) -> Tensor: 57 | # Initialize the sampler 58 | sampler = Sampler(self.unet, self.scheduler).to(self.device) 59 | 60 | num_frames = x.shape[-1] 61 | 62 | # Start from an initial sample 63 | if use_input_as_seed: 64 | # Add noise to the initial sample up to the last timestep 65 | timesteps, _ = sampler.compute_timesteps( 66 | num_inference_steps, strength, x.device 67 | ) 68 | 69 | # Convert waveform to mel and add noise to the starting timestep 70 | batch_t = torch.full( 71 | (x.shape[0],), timesteps[0], dtype=torch.long, device=self.device 72 | ) 73 | x = self.transforms(x) 74 | x = sampler.scheduler.add_noise(x, noise_like(x, generator), batch_t) 75 | 76 | # Start from random noise 77 | else: 78 | x = 
noise_like(self.transforms(x), generator) 79 | 80 | # Denoise the samples 81 | x = sampler(x, num_inference_steps, strength, generator, verbose, **kwargs) 82 | 83 | return self.__vocode(x, use_neural_vocoder, num_frames, num_griffin_lim_iters) 84 | 85 | @torch.no_grad() 86 | def interpolate( 87 | self, 88 | x1: Tensor, 89 | x2: Tensor, 90 | ratio: float = 0.5, 91 | num_inference_steps: Optional[int] = None, 92 | strength: float = 1.0, 93 | generator: Optional[torch.Generator] = None, 94 | verbose: bool = False, 95 | use_neural_vocoder: bool = True, 96 | num_griffin_lim_iters: Optional[int] = None, 97 | **kwargs: Any, 98 | ) -> Tensor: 99 | # Initialize the interpolater 100 | interpolater = Interpolater(self.unet, self.scheduler).to(self.device) 101 | 102 | # Transform inputs into mel spectrograms 103 | num_frames = x1.shape[-1] 104 | x1 = self.transforms(x1) 105 | x2 = self.transforms(x2) 106 | 107 | # Interpolate in the melspace 108 | x = interpolater( 109 | x1, x2, ratio, num_inference_steps, strength, generator, verbose, **kwargs 110 | ) 111 | 112 | return self.__vocode(x, use_neural_vocoder, num_frames, num_griffin_lim_iters) 113 | 114 | @torch.no_grad() 115 | def inpaint( 116 | self, 117 | x: Tensor, 118 | mask: Tensor, 119 | num_inference_steps: Optional[int] = None, 120 | eta: float = 0.0, 121 | jump_length: int = 10, 122 | jump_n_sample: int = 10, 123 | generator: Optional[torch.Generator] = None, 124 | verbose: bool = False, 125 | use_neural_vocoder: bool = True, 126 | num_griffin_lim_iters: Optional[int] = None, 127 | **kwargs: Any, 128 | ) -> Tensor: 129 | # Initialize the inpainter 130 | inpainter = Inpainter(self.unet, self.scheduler).to(self.device) 131 | 132 | # Transform input into mel spectrogram 133 | num_frames = x.shape[-1] 134 | x = self.transforms(x) 135 | 136 | # Rescale the mask to match shape of the spectrogram 137 | mask = F.interpolate(mask, size=x.shape[-1]) 138 | mask = repeat(mask, "b c t -> b c f t", f=x.shape[-2]) 139 | 140 | x = inpainter( 141 | x, 142 | mask, 143 | num_inference_steps, 144 | eta, 145 | jump_length, 146 | jump_n_sample, 147 | generator, 148 | verbose, 149 | **kwargs, 150 | ) 151 | 152 | return self.__vocode(x, use_neural_vocoder, num_frames, num_griffin_lim_iters) 153 | 154 | @torch.no_grad() 155 | def outpaint( 156 | self, 157 | x: Tensor, 158 | num_spans: int = 2, 159 | num_inference_steps: Optional[int] = None, 160 | eta: float = 0.0, 161 | jump_length=10, 162 | jump_n_sample=10, 163 | generator: Optional[torch.Generator] = None, 164 | verbose: bool = False, 165 | use_neural_vocoder: bool = True, 166 | num_griffin_lim_iters: Optional[int] = None, 167 | **kwargs: Any, 168 | ) -> Tensor: 169 | # Initialize the inpainter 170 | outpainter = Outpainter(self.unet, self.scheduler).to(self.device) 171 | 172 | # Transform input into mel spectrogram 173 | num_frames = x.shape[-1] 174 | x = self.transforms(x) 175 | 176 | x = outpainter( 177 | x, 178 | num_spans, 179 | num_inference_steps, 180 | eta, 181 | jump_length, 182 | jump_n_sample, 183 | generator, 184 | verbose, 185 | **kwargs, 186 | ) 187 | 188 | # Compute the new number of frames 189 | num_frames = num_frames + ((num_frames // 2) * num_spans) 190 | x = self.__vocode(x, use_neural_vocoder, None, num_griffin_lim_iters) 191 | 192 | return x[..., :num_frames] 193 | 194 | @torch.no_grad() 195 | def save_pretrained(self, ckpt_path: str) -> None: 196 | checkpoint = { 197 | "transforms_config": self.transforms.config, 198 | "vocoder_config": self.vocoder.config, 199 | "unet_config": 
self.unet.config, 200 | "scheduler_config": self.scheduler.config, 201 | "transforms_state_dict": self.transforms.state_dict(), 202 | "vocoder_state_dict": self.vocoder.state_dict(), 203 | "unet_state_dict": self.unet.state_dict(), 204 | } 205 | torch.save(checkpoint, ckpt_path) 206 | 207 | @classmethod 208 | def from_pretrained( 209 | cls, 210 | ckpt_path: str, 211 | scheduler_class: Union[ 212 | DDIMScheduler, DPMSolverMultistepScheduler 213 | ] = DDIMScheduler, 214 | device: Optional[torch.device] = None, 215 | ) -> Self: 216 | checkpoint = torch.load(ckpt_path) 217 | 218 | transforms = Pipeline._load_from_checkpoint( 219 | checkpoint, "transforms", Transforms 220 | ) 221 | vocoder = Pipeline._load_from_checkpoint(checkpoint, "vocoder", Vocoder) 222 | unet = Pipeline._load_from_checkpoint(checkpoint, "unet", UNet) 223 | scheduler = from_config(checkpoint["scheduler_config"], scheduler_class) 224 | 225 | return cls(transforms, vocoder, unet, scheduler).to(device) 226 | 227 | def __vocode( 228 | self, 229 | x: Tensor, 230 | use_neural_vocoder: bool = True, 231 | num_frames: Optional[int] = None, 232 | num_griffin_lim_iters: Optional[int] = None, 233 | ) -> Tensor: 234 | if use_neural_vocoder: 235 | return self.transforms.griffin_lim( 236 | self.vocoder(x), length=num_frames, num_iters=num_griffin_lim_iters 237 | ) 238 | 239 | return self.transforms( 240 | x, 241 | inverse=True, 242 | length=num_frames, 243 | num_griffin_lim_iters=num_griffin_lim_iters, 244 | ) 245 | 246 | @staticmethod 247 | def _load_from_checkpoint(checkpoint, prefix: str, target: Callable) -> Any: 248 | target_instance = from_config(checkpoint[f"{prefix}_config"], target) 249 | target_instance.load_state_dict(checkpoint[f"{prefix}_state_dict"]) 250 | 251 | return target_instance 252 | 253 | 254 | if __name__ == "__main__": 255 | pipeline = Pipeline(Transforms(), UNet(), Vocoder(), DDIMScheduler()) 256 | x = torch.randn((1, 2, 523_264)) 257 | x = pipeline.outpaint(x, num_spans=2, num_inference_steps=2, verbose=True) 258 | print("samples", x.shape) 259 | -------------------------------------------------------------------------------- /msanii/demo/helpers.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from matplotlib.figure import Figure 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | from torchaudio import functional as AF 10 | 11 | from ..pipeline import Pipeline 12 | from ..utils import compute_divisible_length, plot_spectrogram, plot_waveform 13 | from .utils import ( 14 | generate_gradio_audio_mask, 15 | gradio_audio_postprocessing, 16 | gradio_audio_preprocessing, 17 | max_abs_scaling, 18 | ) 19 | 20 | 21 | def run_sampling( 22 | pipeline: Pipeline, 23 | duration: int, 24 | channels: int, 25 | num_inference_steps: int, 26 | eta: float, 27 | use_neural_vocoder: bool, 28 | num_griffin_lim_iters: int, 29 | seed: float, 30 | verbose: bool, 31 | ) -> Tuple[Figure, Figure, Tuple[int, np.ndarray]]: 32 | if verbose: 33 | print(f"Pipeline device: {pipeline.device}, dtype: {pipeline.dtype}") 34 | # Prepare sample audio that will guide the sampling process 35 | audio_length = duration * pipeline.transforms.sample_rate 36 | divisible_length = compute_divisible_length( 37 | audio_length, pipeline.transforms.hop_length, sum(pipeline.unet.has_resampling) 38 | ) 39 | audio = torch.randn( 40 | (1, channels, divisible_length), dtype=pipeline.dtype, device=pipeline.device 41 | 
) 42 | 43 | # Generate sample from the pipeline 44 | generator = torch.Generator(pipeline.device).manual_seed(int(seed)) 45 | audio = pipeline.sample( 46 | audio, 47 | num_inference_steps, 48 | generator=generator, 49 | use_neural_vocoder=use_neural_vocoder, 50 | num_griffin_lim_iters=num_griffin_lim_iters, 51 | eta=eta, 52 | verbose=verbose, 53 | ) 54 | 55 | # Compute waveform and spectrogram representation 56 | spectrogram = plot_spectrogram(pipeline.transforms(audio)) 57 | waveform = plot_waveform(audio, pipeline.transforms.sample_rate) 58 | audio = gradio_audio_postprocessing(audio, audio_length) 59 | 60 | return spectrogram, waveform, (pipeline.transforms.sample_rate, audio) 61 | 62 | 63 | def run_audio2audio( 64 | pipeline: Pipeline, 65 | audio: Tuple[int, np.ndarray], 66 | num_inference_steps: int, 67 | strength: float, 68 | use_neural_vocoder: bool, 69 | num_griffin_lim_iters: int, 70 | seed: float, 71 | eta: float, 72 | max_abs_value: float, 73 | verbose: bool, 74 | ) -> Tuple[Figure, Figure, Tuple[int, np.ndarray]]: 75 | if verbose: 76 | print(f"Pipeline device: {pipeline.device}, dtype: {pipeline.dtype}") 77 | # Convert audio to tensor & resample 78 | sample_rate, audio = audio 79 | 80 | # Apply some preprocessing 81 | audio_len = audio.shape[0] 82 | audio = gradio_audio_preprocessing( 83 | audio, 84 | src_sample_rate=sample_rate, 85 | target_sample_rate=pipeline.transforms.sample_rate, 86 | target_length=audio_len, 87 | hop_length=pipeline.transforms.hop_length, 88 | num_downsamples=sum(pipeline.unet.has_resampling), 89 | dtype=pipeline.dtype, 90 | device=pipeline.device, 91 | ) 92 | audio = max_abs_scaling(audio, max_abs_value) 93 | 94 | # Generate sample from pipeline 95 | generator = torch.Generator(pipeline.device).manual_seed(int(seed)) 96 | audio = pipeline.sample( 97 | audio, 98 | num_inference_steps, 99 | strength=strength, 100 | generator=generator, 101 | use_neural_vocoder=use_neural_vocoder, 102 | use_input_as_seed=True, 103 | num_griffin_lim_iters=num_griffin_lim_iters, 104 | eta=eta, 105 | verbose=verbose, 106 | ) 107 | 108 | # Compute waveform and spectrogram representation 109 | spectrogram = plot_spectrogram(pipeline.transforms(audio)) 110 | waveform = plot_waveform(audio, pipeline.transforms.sample_rate) 111 | audio = max_abs_scaling(audio, max_abs_value=1.0) 112 | audio_len = int((audio_len * pipeline.transforms.sample_rate) / sample_rate) 113 | audio = gradio_audio_postprocessing(audio, audio_len) 114 | 115 | return spectrogram, waveform, (pipeline.transforms.sample_rate, audio) 116 | 117 | 118 | def run_interpolation( 119 | pipeline: Pipeline, 120 | first_audio: Tuple[int, np.ndarray], 121 | second_audio: Tuple[int, np.ndarray], 122 | num_inference_steps: int, 123 | ratio: float, 124 | strength: float, 125 | use_neural_vocoder: bool, 126 | num_griffin_lim_iters: int, 127 | seed: float, 128 | eta: float, 129 | max_abs_value: float, 130 | verbose: bool, 131 | ) -> Tuple[Figure, Figure, Tuple[int, np.ndarray]]: 132 | if verbose: 133 | print(f"Pipeline device: {pipeline.device}, dtype: {pipeline.dtype}") 134 | # Convert audio to tensor & resample 135 | first_sample_rate, first_audio = first_audio 136 | second_sample_rate, second_audio = second_audio 137 | 138 | # Apply some preprocessing 139 | audio_len = first_audio.shape[0] 140 | first_audio = gradio_audio_preprocessing( 141 | first_audio, 142 | src_sample_rate=first_sample_rate, 143 | target_sample_rate=pipeline.transforms.sample_rate, 144 | target_length=audio_len, 145 | 
hop_length=pipeline.transforms.hop_length, 146 | num_downsamples=sum(pipeline.unet.has_resampling), 147 | dtype=pipeline.dtype, 148 | device=pipeline.device, 149 | ) 150 | target_length = int((audio_len * second_sample_rate) / first_sample_rate) 151 | second_audio = gradio_audio_preprocessing( 152 | second_audio, 153 | src_sample_rate=second_sample_rate, 154 | target_sample_rate=pipeline.transforms.sample_rate, 155 | target_length=target_length, 156 | hop_length=pipeline.transforms.hop_length, 157 | num_downsamples=sum(pipeline.unet.has_resampling), 158 | dtype=pipeline.dtype, 159 | device=pipeline.device, 160 | ) 161 | 162 | # Scale the amplitude to a range close to what the model outputs 163 | first_audio = max_abs_scaling(first_audio, max_abs_value) 164 | second_audio = max_abs_scaling(second_audio, max_abs_value) 165 | 166 | # Generate sample from pipeline 167 | generator = torch.Generator(pipeline.device).manual_seed(int(seed)) 168 | audio = pipeline.interpolate( 169 | first_audio, 170 | second_audio, 171 | ratio=ratio, 172 | num_inference_steps=num_inference_steps, 173 | strength=strength, 174 | generator=generator, 175 | use_neural_vocoder=use_neural_vocoder, 176 | num_griffin_lim_iters=num_griffin_lim_iters, 177 | eta=eta, 178 | verbose=verbose, 179 | ) 180 | 181 | # Compute waveform and spectrogram representation 182 | spectrogram = plot_spectrogram(pipeline.transforms(audio)) 183 | waveform = plot_waveform(audio, pipeline.transforms.sample_rate) 184 | audio = max_abs_scaling(audio, max_abs_value=1.0) 185 | audio_len = int((audio_len * pipeline.transforms.sample_rate) / first_sample_rate) 186 | audio = gradio_audio_postprocessing(audio, audio_len) 187 | 188 | return spectrogram, waveform, (pipeline.transforms.sample_rate, audio) 189 | 190 | 191 | def run_inpainting( 192 | pipeline: Pipeline, 193 | audio: Tuple[int, np.ndarray], 194 | mask_spec: str, 195 | jump_length: int, 196 | jump_n_samples: int, 197 | num_inference_steps: int, 198 | use_neural_vocoder: bool, 199 | num_griffin_lim_iters: int, 200 | seed: float, 201 | eta: float, 202 | max_abs_value: float, 203 | verbose: bool, 204 | ) -> Tuple[Figure, Figure, Tuple[int, np.ndarray]]: 205 | if verbose: 206 | print(f"Pipeline device: {pipeline.device}, dtype: {pipeline.dtype}") 207 | 208 | sample_rate, audio = audio 209 | 210 | # Generate mask from the mask-spec 211 | audio_mask = generate_gradio_audio_mask(audio, sample_rate, mask_spec) 212 | 213 | # Apply some preprocessing 214 | audio_len = audio.shape[0] 215 | audio = gradio_audio_preprocessing( 216 | audio, 217 | src_sample_rate=sample_rate, 218 | target_sample_rate=pipeline.transforms.sample_rate, 219 | target_length=audio_len, 220 | hop_length=pipeline.transforms.hop_length, 221 | num_downsamples=sum(pipeline.unet.has_resampling), 222 | dtype=pipeline.dtype, 223 | device=pipeline.device, 224 | ) 225 | audio_mask = gradio_audio_preprocessing( 226 | audio_mask.astype(float), 227 | src_sample_rate=sample_rate, 228 | target_sample_rate=pipeline.transforms.sample_rate, 229 | target_length=audio_len, 230 | hop_length=pipeline.transforms.hop_length, 231 | num_downsamples=sum(pipeline.unet.has_resampling), 232 | dtype=pipeline.dtype, 233 | device=pipeline.device, 234 | ) 235 | 236 | # Scale the amplitude to a range close to what the model outputs 237 | audio = max_abs_scaling(audio, max_abs_value) 238 | 239 | # Generate sample from pipeline 240 | generator = torch.Generator(pipeline.device).manual_seed(int(seed)) 241 | audio = pipeline.inpaint( 242 | audio, 243 | audio_mask, 244 | 
num_inference_steps=num_inference_steps, 245 | jump_length=jump_length, 246 | jump_n_sample=jump_n_samples, 247 | generator=generator, 248 | use_neural_vocoder=use_neural_vocoder, 249 | num_griffin_lim_iters=num_griffin_lim_iters, 250 | eta=eta, 251 | verbose=verbose, 252 | ) 253 | 254 | # Compute waveform and spectrogram representation 255 | spectrogram = plot_spectrogram(pipeline.transforms(audio)) 256 | waveform = plot_waveform(audio, pipeline.transforms.sample_rate) 257 | audio = max_abs_scaling(audio, max_abs_value=1.0) 258 | audio_len = int((audio_len * pipeline.transforms.sample_rate) / sample_rate) 259 | audio = gradio_audio_postprocessing(audio, audio_len) 260 | 261 | return spectrogram, waveform, (pipeline.transforms.sample_rate, audio) 262 | 263 | 264 | def run_outpainting( 265 | pipeline: Pipeline, 266 | audio: Tuple[int, np.ndarray], 267 | num_spans: int, 268 | num_inference_steps: int, 269 | use_neural_vocoder: bool, 270 | num_griffin_lim_iters: int, 271 | seed: float, 272 | eta: float, 273 | max_abs_value, 274 | verbose: bool, 275 | ) -> Tuple[Figure, Figure, Tuple[int, np.ndarray]]: 276 | if verbose: 277 | print(f"Pipeline device: {pipeline.device}, dtype: {pipeline.dtype}") 278 | # Convert audio to tensor & resample 279 | sample_rate, audio = audio 280 | 281 | # Apply some preprocessing 282 | seed_length = audio.shape[0] 283 | audio = gradio_audio_preprocessing( 284 | audio, 285 | src_sample_rate=sample_rate, 286 | target_sample_rate=pipeline.transforms.sample_rate, 287 | target_length=seed_length, 288 | hop_length=pipeline.transforms.hop_length, 289 | num_downsamples=sum(pipeline.unet.has_resampling), 290 | dtype=pipeline.dtype, 291 | device=pipeline.device, 292 | pad_end=False, 293 | ) 294 | padded_length = audio.shape[-1] 295 | 296 | # Scale the amplitude to a range close to what the model outputs 297 | audio = max_abs_scaling(audio, max_abs_value) 298 | 299 | # Generate sample from pipeline 300 | generator = torch.Generator(pipeline.device).manual_seed(int(seed)) 301 | audio = pipeline.outpaint( 302 | audio, 303 | num_spans=num_spans, 304 | num_inference_steps=num_inference_steps, 305 | generator=generator, 306 | use_neural_vocoder=use_neural_vocoder, 307 | num_griffin_lim_iters=num_griffin_lim_iters, 308 | eta=eta, 309 | verbose=verbose, 310 | ) 311 | 312 | # Compute waveform and spectrogram representation 313 | seed_length = int((seed_length * pipeline.transforms.sample_rate) / sample_rate) 314 | target_length = int(seed_length + (((padded_length / 2) * num_spans))) 315 | spectrogram = plot_spectrogram(pipeline.transforms(audio)) 316 | waveform = plot_waveform(audio, pipeline.transforms.sample_rate) 317 | audio = max_abs_scaling(audio, max_abs_value=1.0) 318 | audio = gradio_audio_postprocessing(audio, target_length, pad_end=False) 319 | 320 | return spectrogram, waveform, (pipeline.transforms.sample_rate, audio) 321 | -------------------------------------------------------------------------------- /msanii/demo/demo.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import Optional 3 | 4 | import gradio as gr 5 | import matplotlib 6 | import numpy as np 7 | import torch 8 | from omegaconf import OmegaConf 9 | 10 | from ..config import DemoConfig 11 | from ..pipeline import Pipeline 12 | from .helpers import ( 13 | run_audio2audio, 14 | run_inpainting, 15 | run_interpolation, 16 | run_outpainting, 17 | run_sampling, 18 | ) 19 | 20 | HEADER = """ 21 | 35 | 36 | # Msanii: High 
Fidelity Music Synthesis on a Shoestring Budget 37 | """ 38 | 39 | 40 | def run_demo(config: DemoConfig) -> Optional[gr.Blocks]: 41 | # ------------------------------------------- 42 | # Configure Matplotlib 43 | # ------------------------------------------- 44 | # Prevents pixelated fonts on figures 45 | matplotlib.use("webagg") 46 | matplotlib.style.use(["seaborn", "fast"]) 47 | 48 | # ------------------------------------------- 49 | # Load Pipeline checkpoint 50 | # ------------------------------------------- 51 | pipeline = Pipeline.from_pretrained(config.ckpt_path) 52 | pipeline = pipeline.to(torch.device(config.device)) 53 | pipeline = pipeline.to(getattr(torch, config.dtype)) 54 | 55 | # ------------------------------------------- 56 | # Define gradio interface 57 | # ------------------------------------------- 58 | with gr.Blocks(title="Msanii") as demo: 59 | gr.Markdown(HEADER) 60 | 61 | with gr.Row(): 62 | # Main Section 63 | with gr.Column(scale=2): 64 | with gr.Tab("Sampling"): 65 | sampling_spectrogram_output = gr.Plot(label="Spectrogram") 66 | sampling_waveform_output = gr.Plot(label="Waveform") 67 | sampling_audio_output = gr.Audio(label="Sample Audio") 68 | sampling_button = gr.Button(value="Run", variant="primary") 69 | 70 | with gr.Tab("Audio2Audio"): 71 | a2a_audio_input = gr.Audio(label="Source Audio") 72 | a2a_spectrogram_output = gr.Plot(label="Spectrogram") 73 | a2a_waveform_output = gr.Plot(label="Waveform") 74 | a2a_audio_output = gr.Audio(label="Sample Audio") 75 | a2a_button = gr.Button(value="Run", variant="primary") 76 | 77 | with gr.Tab("Interpolation"): 78 | interpolation_first_audio_input = gr.Audio( 79 | label="First Source Audio" 80 | ) 81 | interpolation_second_audio_input = gr.Audio( 82 | label="Second Source Audio" 83 | ) 84 | interpolation_spectrogram_output = gr.Plot(label="Spectrogram") 85 | interpolation_waveform_output = gr.Plot(label="Waveform") 86 | interpolation_audio_output = gr.Audio(label="Sample Audio") 87 | interpolation_button = gr.Button(value="Run", variant="primary") 88 | 89 | with gr.Tab("Inpainting"): 90 | inpainting_audio_input = gr.Audio(label="Source Audio") 91 | inpainting_mask_input = gr.Text( 92 | label="Mask Intervals (seconds) e.g: 20-30,50-60" 93 | ) 94 | inpainting_spectrogram_output = gr.Plot(label="Spectrogram") 95 | inpainting_waveform_output = gr.Plot(label="Waveform") 96 | inpainting_audio_output = gr.Audio(label="Sample Audio") 97 | inpainting_button = gr.Button(value="Run", variant="primary") 98 | with gr.Tab("Outpainting"): 99 | outpainting_audio_input = gr.Audio(label="Source Audio") 100 | outpainting_spectrogram_output = gr.Plot(label="Spectrogram") 101 | outpainting_waveform_output = gr.Plot(label="Waveform") 102 | outpainting_audio_output = gr.Audio(label="Sample Audio") 103 | outpainting_button = gr.Button(value="Run", variant="primary") 104 | 105 | # Options 106 | with gr.Column(scale=1): 107 | with gr.Accordion(label="General Options", open=True): 108 | with gr.Group(): 109 | inference_steps_slider = gr.Slider( 110 | minimum=1, 111 | maximum=pipeline.scheduler.num_train_timesteps, 112 | value=20, 113 | step=1, 114 | label="Number of Inference Steps", 115 | ) 116 | griffin_lim_iters_slider = gr.Slider( 117 | minimum=1, 118 | maximum=1000, 119 | value=100, 120 | step=1, 121 | label="Number of GriffinLim Iterations", 122 | ) 123 | seed = gr.Number( 124 | value=lambda: np.random.randint(0, 1_000_000), label="Seed" 125 | ) 126 | 127 | with gr.Accordion(label="Task Specific Options", open=False): 128 | with 
gr.Group(): 129 | gr.Markdown("Sampling Options") 130 | duration_slider = gr.Slider( 131 | minimum=10, maximum=190, label="Audio Duration (seconds)" 132 | ) 133 | channels_slider = gr.Slider( 134 | minimum=1, 135 | maximum=2, 136 | step=1, 137 | value=2, 138 | label="Audio Channels", 139 | ) 140 | 141 | with gr.Group(): 142 | gr.Markdown("Audio2Audio & Interpolation Options") 143 | ratio_slider = gr.Slider( 144 | minimum=0, maximum=1, value=0.5, label="Interpolation Ratio" 145 | ) 146 | strength_slider = gr.Slider( 147 | minimum=0, maximum=1, value=0.1, label="Noise Strength" 148 | ) 149 | 150 | with gr.Group(): 151 | gr.Markdown("Inpainting & Outpainting") 152 | jump_length_slider = gr.Slider( 153 | minimum=1, 154 | maximum=pipeline.scheduler.num_train_timesteps, 155 | value=10, 156 | label="Number of Forward Steps", 157 | ) 158 | jump_n_samples_slider = gr.Slider( 159 | minimum=1, 160 | maximum=pipeline.scheduler.num_train_timesteps, 161 | value=10, 162 | label="Number of Forward Jumps", 163 | ) 164 | 165 | with gr.Group(): 166 | gr.Markdown("Outpainting") 167 | num_spans_slider = gr.Slider( 168 | minimum=1, 169 | maximum=100, 170 | step=1, 171 | value=5, 172 | label="Number of Outpaint Spans (1/2 duration)", 173 | ) 174 | with gr.Accordion(label="Advanced Options", open=False): 175 | use_nv_checkbox = gr.Checkbox( 176 | value=True, label="Use Neural Vocoder" 177 | ) 178 | eta_slider = gr.Slider(minimum=0, maximum=1, label="eta") 179 | max_abs_value_slider = gr.Slider( 180 | minimum=0, maximum=1, value=0.05, label="Maximum Absolute Value" 181 | ) 182 | verbose_checkbox = gr.Checkbox(value=False, label="Verbose") 183 | 184 | sampling_button.click( 185 | lambda *args: run_sampling(pipeline, *args), 186 | inputs=[ 187 | duration_slider, 188 | channels_slider, 189 | inference_steps_slider, 190 | eta_slider, 191 | use_nv_checkbox, 192 | griffin_lim_iters_slider, 193 | seed, 194 | verbose_checkbox, 195 | ], 196 | outputs=[ 197 | sampling_spectrogram_output, 198 | sampling_waveform_output, 199 | sampling_audio_output, 200 | ], 201 | ) 202 | 203 | a2a_button.click( 204 | lambda *args: run_audio2audio(pipeline, *args), 205 | inputs=[ 206 | a2a_audio_input, 207 | inference_steps_slider, 208 | strength_slider, 209 | use_nv_checkbox, 210 | griffin_lim_iters_slider, 211 | seed, 212 | eta_slider, 213 | max_abs_value_slider, 214 | verbose_checkbox, 215 | ], 216 | outputs=[a2a_spectrogram_output, a2a_waveform_output, a2a_audio_output], 217 | ) 218 | 219 | interpolation_button.click( 220 | lambda *args: run_interpolation(pipeline, *args), 221 | inputs=[ 222 | interpolation_first_audio_input, 223 | interpolation_second_audio_input, 224 | inference_steps_slider, 225 | ratio_slider, 226 | strength_slider, 227 | use_nv_checkbox, 228 | griffin_lim_iters_slider, 229 | seed, 230 | eta_slider, 231 | max_abs_value_slider, 232 | verbose_checkbox, 233 | ], 234 | outputs=[ 235 | interpolation_spectrogram_output, 236 | interpolation_waveform_output, 237 | interpolation_audio_output, 238 | ], 239 | ) 240 | 241 | inpainting_button.click( 242 | lambda *args: run_inpainting(pipeline, *args), 243 | inputs=[ 244 | inpainting_audio_input, 245 | inpainting_mask_input, 246 | jump_length_slider, 247 | jump_n_samples_slider, 248 | inference_steps_slider, 249 | use_nv_checkbox, 250 | griffin_lim_iters_slider, 251 | seed, 252 | eta_slider, 253 | max_abs_value_slider, 254 | verbose_checkbox, 255 | ], 256 | outputs=[ 257 | inpainting_spectrogram_output, 258 | inpainting_waveform_output, 259 | inpainting_audio_output, 260 | ], 261 | 
) 262 | 263 | outpainting_button.click( 264 | lambda *args: run_outpainting(pipeline, *args), 265 | inputs=[ 266 | outpainting_audio_input, 267 | num_spans_slider, 268 | inference_steps_slider, 269 | use_nv_checkbox, 270 | griffin_lim_iters_slider, 271 | seed, 272 | eta_slider, 273 | max_abs_value_slider, 274 | verbose_checkbox, 275 | ], 276 | outputs=[ 277 | outpainting_spectrogram_output, 278 | outpainting_waveform_output, 279 | outpainting_audio_output, 280 | ], 281 | ) 282 | 283 | if config.launch: 284 | demo.launch(debug=True) 285 | else: 286 | return demo 287 | 288 | 289 | if __name__ == "__main__": 290 | parser = ArgumentParser() 291 | parser.add_argument("config_path", help="path to config file", type=str) 292 | args = parser.parse_args() 293 | 294 | default_demo_config = OmegaConf.structured(DemoConfig) 295 | file_demo_config = OmegaConf.load(args.config_path) 296 | demo_config = OmegaConf.merge(default_demo_config, file_demo_config) 297 | 298 | run_demo(demo_config) 299 | -------------------------------------------------------------------------------- /notebooks/msanii_inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Msanii Inference\n" 9 | ] 10 | }, 11 | { 12 | "attachments": {}, 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## GPU Check" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "!nvidia-smi" 26 | ] 27 | }, 28 | { 29 | "attachments": {}, 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Setup\n", 34 | "\n", 35 | "Run one of the the below install options. \n", 36 | "> **WARNING:** Restart the runtime or some packages will not be updated!" 37 | ] 38 | }, 39 | { 40 | "attachments": {}, 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### Install package from git" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "%pip install -q git+https://github.com/Kinyugo/msanii.git" 54 | ] 55 | }, 56 | { 57 | "attachments": {}, 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "### Install package in edit mode" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "!git clone https://github.com/Kinyugo/msanii.git\n", 71 | "!cd msanii\n", 72 | "%pip install -q -r requirements.txt\n", 73 | "%pip install -e ." 
74 | ] 75 | }, 76 | { 77 | "attachments": {}, 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### Fetch model checkpoint" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "%pip install -q gdown --upgrade --no-cache" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "!gdown 1G9kF0r5vxYXPSdSuv4t3GR-sBO8xGFCe" 100 | ] 101 | }, 102 | { 103 | "attachments": {}, 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "### Imports" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "from omegaconf import OmegaConf\n", 117 | "\n", 118 | "from msanii.config import (\n", 119 | " Audio2AudioConfig,\n", 120 | " InpaintingConfig,\n", 121 | " InterpolationConfig,\n", 122 | " OutpaintingConfig,\n", 123 | " SamplingConfig,\n", 124 | ")\n", 125 | "from msanii.scripts import (\n", 126 | " run_audio2audio,\n", 127 | " run_inpainting,\n", 128 | " run_interpolation,\n", 129 | " run_outpainting,\n", 130 | " run_sampling,\n", 131 | ")" 132 | ] 133 | }, 134 | { 135 | "attachments": {}, 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## Sampling\n" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "sampling_dict_config = {\n", 149 | " \"ckpt_path\": \"\",\n", 150 | " \"output_dir\": \"\",\n", 151 | " \"batch_size\": 4,\n", 152 | " \"num_frames\": 8_387_584, # should divisible by the downsampling factor of the U-Net\n", 153 | " \"output_audio_format\": \"wav\", # ogg, mp3 ...\n", 154 | " \"seed\": 0,\n", 155 | " \"device\": \"cuda\", # cpu or cuda\n", 156 | " \"dtype\": \"float32\", # torch.dtype\n", 157 | " \"num_inference_steps\": 20,\n", 158 | " \"verbose\": True,\n", 159 | " \"use_neural_vocoder\": True,\n", 160 | " \"channels\": 2, # mono or stereo\n", 161 | " \"num_samples\": 16,\n", 162 | "}" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "sampling_default_config = OmegaConf.structured(SamplingConfig)\n", 172 | "sampling_custom_config = OmegaConf.create(sampling_dict_config)\n", 173 | "sampling_config = OmegaConf.merge(sampling_default_config, sampling_custom_config)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "run_sampling(sampling_config)" 183 | ] 184 | }, 185 | { 186 | "attachments": {}, 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "## Audio2Audio\n" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "audio2audio_dict_config = {\n", 200 | " \"ckpt_path\": \"\",\n", 201 | " \"output_dir\": \"\",\n", 202 | " \"batch_size\": 4,\n", 203 | " \"num_frames\": 8_387_584, # should divisible by the downsampling factor of the U-Net\n", 204 | " \"output_audio_format\": \"wav\", # ogg, mp3 ...\n", 205 | " \"seed\": 0,\n", 206 | " \"device\": \"cuda\", # cpu or cuda\n", 207 | " \"dtype\": \"float32\", # torch.dtype\n", 208 | " \"num_inference_steps\": 20,\n", 209 | " \"verbose\": True,\n", 210 | " \"use_neural_vocoder\": True,\n", 211 
| " \"data_dir\": \"\",\n", 212 | " \"num_workers\": 4,\n", 213 | " \"pin_memory\": True,\n", 214 | " \"strength\": 0.1, # controls how much noise is added; [0, 1]\n", 215 | "}" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "audio2audio_default_config = OmegaConf.structured(Audio2AudioConfig)\n", 225 | "audio2audio_custom_config = OmegaConf.create(audio2audio_dict_config)\n", 226 | "audio2audio_config = OmegaConf.merge(\n", 227 | " audio2audio_default_config, audio2audio_custom_config\n", 228 | ")" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "run_audio2audio(audio2audio_config)" 238 | ] 239 | }, 240 | { 241 | "attachments": {}, 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "## Interpolation\n" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "interpolation_dict_config = {\n", 255 | " \"ckpt_path\": \"\",\n", 256 | " \"output_dir\": \"\",\n", 257 | " \"batch_size\": 4,\n", 258 | " \"num_frames\": 8_387_584, # should divisible by the downsampling factor of the U-Net\n", 259 | " \"output_audio_format\": \"wav\", # ogg, mp3 ...\n", 260 | " \"seed\": 0,\n", 261 | " \"device\": \"cuda\", # cpu or cuda\n", 262 | " \"dtype\": \"float32\", # torch.dtype\n", 263 | " \"num_inference_steps\": 20,\n", 264 | " \"verbose\": True,\n", 265 | " \"use_neural_vocoder\": True,\n", 266 | " \"first_data_dir\": \"\",\n", 267 | " \"second_data_dir\": \"\",\n", 268 | " \"num_workers\": 4,\n", 269 | " \"pin_memory\": True,\n", 270 | " \"ratio\": 0.5, # controls how much of the first sample is in the interpolation\n", 271 | " \"strength\": 0.1, # controls how much noise is added; [0, 1]\n", 272 | "}" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "interpolation_default_config = OmegaConf.structured(InterpolationConfig)\n", 282 | "interpolation_custom_config = OmegaConf.create(interpolation_dict_config)\n", 283 | "interpolation_config = OmegaConf.merge(\n", 284 | " interpolation_default_config, interpolation_custom_config\n", 285 | ")" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "run_interpolation(interpolation_config)" 295 | ] 296 | }, 297 | { 298 | "attachments": {}, 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "## Inpainting\n" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "inpainting_dict_config = {\n", 312 | " \"ckpt_path\": \"\",\n", 313 | " \"output_dir\": \"\",\n", 314 | " \"batch_size\": 4,\n", 315 | " \"num_frames\": 8_387_584, # should divisible by the downsampling factor of the U-Net\n", 316 | " \"output_audio_format\": \"wav\", # ogg, mp3 ...\n", 317 | " \"seed\": 0,\n", 318 | " \"device\": \"cuda\", # cpu or cuda\n", 319 | " \"dtype\": \"float32\", # torch.dtype\n", 320 | " \"num_inference_steps\": 20,\n", 321 | " \"verbose\": True,\n", 322 | " \"use_neural_vocoder\": True,\n", 323 | " \"data_dir\": \"\",\n", 324 | " \"num_workers\": 4,\n", 325 | " \"pin_memory\": True,\n", 326 | " \"masks\": [], # e.g [\"3-5,10-50\",\"4-10\", 
...] for each sample if the folder,\n", 327 | " \"eta\": 0.0,\n", 328 | " \"jump_length\": 10,\n", 329 | " \"jump_n_sample\": 10,\n", 330 | "}" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "inpainting_default_config = OmegaConf.structured(InpaintingConfig)\n", 340 | "inpainting_custom_config = OmegaConf.create(inpainting_dict_config)\n", 341 | "inpainting_config = OmegaConf.merge(inpainting_default_config, sampling_custom_config)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "run_inpainting(inpainting_config)" 351 | ] 352 | }, 353 | { 354 | "attachments": {}, 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "## Outpainting\n" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "outpainting_dict_config = {\n", 368 | " \"ckpt_path\": \"\",\n", 369 | " \"output_dir\": \"\",\n", 370 | " \"batch_size\": 4,\n", 371 | " \"num_frames\": 8_387_584, # should divisible by the downsampling factor of the U-Net\n", 372 | " \"output_audio_format\": \"wav\", # ogg, mp3 ...\n", 373 | " \"seed\": 0,\n", 374 | " \"device\": \"cuda\", # cpu or cuda\n", 375 | " \"dtype\": \"float32\", # torch.dtype\n", 376 | " \"num_inference_steps\": 20,\n", 377 | " \"verbose\": True,\n", 378 | " \"use_neural_vocoder\": True,\n", 379 | " \"data_dir\": \"\",\n", 380 | " \"num_workers\": 4,\n", 381 | " \"pin_memory\": True,\n", 382 | " \"num_spans\": 2, # number of half the num_frames outpaints\n", 383 | " \"eta\": 0.0,\n", 384 | " \"jump_length\": 10,\n", 385 | " \"jump_n_sample\": 10,\n", 386 | "}" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "outpainting_default_config = OmegaConf.structured(OutpaintingConfig)\n", 396 | "outpainting_custom_config = OmegaConf.create(outpainting_dict_config)\n", 397 | "outpainting_config = OmegaConf.merge(\n", 398 | " outpainting_default_config, outpainting_custom_config\n", 399 | ")" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "run_outpainting(outpainting_config)" 409 | ] 410 | } 411 | ], 412 | "metadata": { 413 | "kernelspec": { 414 | "display_name": "torch", 415 | "language": "python", 416 | "name": "python3" 417 | }, 418 | "language_info": { 419 | "codemirror_mode": { 420 | "name": "ipython", 421 | "version": 3 422 | }, 423 | "file_extension": ".py", 424 | "mimetype": "text/x-python", 425 | "name": "python", 426 | "nbconvert_exporter": "python", 427 | "pygments_lexer": "ipython3", 428 | "version": "3.10.8" 429 | }, 430 | "orig_nbformat": 4, 431 | "vscode": { 432 | "interpreter": { 433 | "hash": "c339664639c3e5019e3803d0baff2aab4fdaac0204aae143f6ed0f1a6cb76161" 434 | } 435 | } 436 | }, 437 | "nbformat": 4, 438 | "nbformat_minor": 2 439 | } 440 | -------------------------------------------------------------------------------- /msanii/scripts/inference.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | import numpy as np 6 | import torch 7 | import torchaudio 8 | from omegaconf import OmegaConf 9 | from torch import Tensor 10 | from 
tqdm.autonotebook import tqdm 11 | 12 | from ..config import ( 13 | Audio2AudioConfig, 14 | InpaintingConfig, 15 | InterpolationConfig, 16 | OutpaintingConfig, 17 | SamplingConfig, 18 | ) 19 | from ..data import AudioDataModule 20 | from ..pipeline import Pipeline 21 | from ..utils import compute_divisible_length 22 | from .utils import generate_batch_audio_mask 23 | 24 | 25 | def save_batch_samples( 26 | samples: Tensor, offset: int, output_dir: str, sample_rate: int, audio_format: str 27 | ) -> None: 28 | # torchaudio can only save cpu samples 29 | samples = samples.detach().cpu() 30 | samples = samples.to(torch.float32) 31 | 32 | os.makedirs(output_dir, exist_ok=True) 33 | for i, sample in enumerate(samples): 34 | filename = f"{offset + i}.{audio_format}" 35 | filepath = os.path.join(output_dir, filename) 36 | torchaudio.save(filepath, sample, sample_rate=sample_rate, format=audio_format) 37 | 38 | 39 | def run_sampling(config: SamplingConfig) -> None: 40 | # ------------------------------------------- 41 | # Setup 42 | # ------------------------------------------- 43 | device = torch.device(config.device) 44 | dtype = getattr(torch, config.dtype) 45 | 46 | # ------------------------------------------- 47 | # Load pipeline 48 | # ------------------------------------------- 49 | pipeline = Pipeline.from_pretrained(config.ckpt_path).to(device).to(dtype) 50 | 51 | # ------------------------------------------- 52 | # Run sampling 53 | # ------------------------------------------- 54 | n_batches = config.num_samples // config.batch_size 55 | batches = np.array_split(range(config.num_samples), n_batches) 56 | 57 | # Optionally compute the number of frames from duration 58 | num_frames = config.num_frames 59 | if config.duration is not None: 60 | num_frames = config.duration * pipeline.transforms.sample_rate 61 | num_frames = compute_divisible_length( 62 | num_frames, 63 | pipeline.transforms.hop_length, 64 | sum(pipeline.unet.has_resampling), 65 | ) 66 | 67 | for batch_idx, batch in enumerate( 68 | tqdm(batches, desc="Sampling", disable=(not config.verbose)) 69 | ): 70 | samples = torch.randn( 71 | (len(batch), config.channels, num_frames), 72 | device=pipeline.device, 73 | dtype=dtype, 74 | ) 75 | samples = pipeline.sample( 76 | samples, 77 | num_inference_steps=config.num_inference_steps, 78 | generator=torch.Generator(device).manual_seed(config.seed), 79 | verbose=config.verbose, 80 | use_neural_vocoder=config.use_neural_vocoder, 81 | num_griffin_lim_iters=config.num_griffin_lim_iters, 82 | ) 83 | save_batch_samples( 84 | samples, 85 | offset=(batch_idx * config.batch_size), 86 | output_dir=config.output_dir, 87 | sample_rate=pipeline.transforms.sample_rate, 88 | audio_format=config.output_audio_format, 89 | ) 90 | 91 | 92 | def run_audio2audio(config: Audio2AudioConfig) -> None: 93 | # ------------------------------------------- 94 | # Setup 95 | # ------------------------------------------- 96 | device = torch.device(config.device) 97 | dtype = getattr(torch, config.dtype) 98 | 99 | # ------------------------------------------- 100 | # Load pipeline 101 | # ------------------------------------------- 102 | pipeline = Pipeline.from_pretrained(config.ckpt_path).to(device).to(dtype) 103 | 104 | # ------------------------------------------- 105 | # Prepare datamodule and dataloader 106 | # ------------------------------------------- 107 | # Optionally compute the number of frames from duration 108 | num_frames = config.num_frames 109 | if config.duration is not None: 110 | num_frames = 
config.duration * pipeline.transforms.sample_rate 111 | num_frames = compute_divisible_length( 112 | num_frames, 113 | pipeline.transforms.hop_length, 114 | sum(pipeline.unet.has_resampling), 115 | ) 116 | datamodule = AudioDataModule( 117 | config.data_dir, 118 | sample_rate=pipeline.transforms.sample_rate, 119 | num_frames=num_frames, 120 | load_random_slice=False, 121 | normalize_amplitude=False, 122 | batch_size=config.batch_size, 123 | num_workers=config.num_workers, 124 | pin_memory=config.pin_memory, 125 | shuffle=False, 126 | ) 127 | datamodule.prepare_data() 128 | datamodule.setup() 129 | dataloader = datamodule.train_dataloader() 130 | 131 | # ------------------------------------------- 132 | # Run sampling 133 | # ------------------------------------------- 134 | for batch_idx, batch in enumerate( 135 | tqdm(dataloader, desc="Audio2Audio", disable=(not config.verbose)) 136 | ): 137 | samples = pipeline.sample( 138 | batch.to(device).to(dtype), 139 | num_inference_steps=config.num_inference_steps, 140 | strength=config.strength, 141 | generator=torch.Generator(device).manual_seed(config.seed), 142 | verbose=config.verbose, 143 | use_input_as_seed=True, 144 | use_neural_vocoder=config.use_neural_vocoder, 145 | num_griffin_lim_iters=config.num_griffin_lim_iters, 146 | ) 147 | save_batch_samples( 148 | samples, 149 | offset=(batch_idx * config.batch_size), 150 | output_dir=config.output_dir, 151 | sample_rate=pipeline.transforms.sample_rate, 152 | audio_format=config.output_audio_format, 153 | ) 154 | 155 | 156 | def run_interpolation(config: InterpolationConfig) -> None: 157 | # ------------------------------------------- 158 | # Setup 159 | # ------------------------------------------- 160 | device = torch.device(config.device) 161 | dtype = getattr(torch, config.dtype) 162 | 163 | # ------------------------------------------- 164 | # Load pipeline 165 | # ------------------------------------------- 166 | pipeline = Pipeline.from_pretrained(config.ckpt_path).to(device).to(dtype) 167 | 168 | # ------------------------------------------- 169 | # Prepare datamodule and dataloader 170 | # ------------------------------------------- 171 | # Optionally compute the number of frames from duration 172 | num_frames = config.num_frames 173 | if config.duration is not None: 174 | num_frames = config.duration * pipeline.transforms.sample_rate 175 | num_frames = compute_divisible_length( 176 | num_frames, 177 | pipeline.transforms.hop_length, 178 | sum(pipeline.unet.has_resampling), 179 | ) 180 | first_datamodule = AudioDataModule( 181 | config.first_data_dir, 182 | sample_rate=pipeline.transforms.sample_rate, 183 | num_frames=num_frames, 184 | load_random_slice=False, 185 | normalize_amplitude=False, 186 | batch_size=config.batch_size, 187 | num_workers=config.num_workers, 188 | pin_memory=config.pin_memory, 189 | shuffle=False, 190 | ) 191 | second_datamodule = AudioDataModule( 192 | config.second_data_dir, 193 | sample_rate=pipeline.transforms.sample_rate, 194 | num_frames=num_frames, 195 | load_random_slice=False, 196 | normalize_amplitude=False, 197 | batch_size=config.batch_size, 198 | num_workers=config.num_workers, 199 | pin_memory=config.pin_memory, 200 | shuffle=False, 201 | ) 202 | first_datamodule.prepare_data() 203 | first_datamodule.setup() 204 | second_datamodule.prepare_data() 205 | second_datamodule.setup() 206 | 207 | first_dataloader = first_datamodule.train_dataloader() 208 | second_dataloader = second_datamodule.train_dataloader() 209 | 210 | assert len(first_dataloader) == 
len( 211 | second_dataloader 212 | ), "Samples should be equal in both directories" 213 | 214 | # ------------------------------------------- 215 | # Run sampling 216 | # ------------------------------------------- 217 | for batch_idx, (first_batch, second_batch) in enumerate( 218 | tqdm( 219 | zip(first_dataloader, second_dataloader), 220 | desc="Interpolation", 221 | disable=(not config.verbose), 222 | ) 223 | ): 224 | samples = pipeline.interpolate( 225 | first_batch, 226 | second_batch, 227 | ratio=config.ratio, 228 | num_inference_steps=config.num_inference_steps, 229 | strength=config.strength, 230 | generator=torch.Generator(device).manual_seed(config.seed), 231 | verbose=config.verbose, 232 | use_neural_vocoder=config.use_neural_vocoder, 233 | num_griffin_lim_iters=config.num_griffin_lim_iters, 234 | ) 235 | save_batch_samples( 236 | samples, 237 | offset=(batch_idx * config.batch_size), 238 | output_dir=config.output_dir, 239 | sample_rate=pipeline.transforms.sample_rate, 240 | audio_format=config.output_audio_format, 241 | ) 242 | 243 | 244 | def run_inpainting(config: InpaintingConfig) -> None: 245 | # ------------------------------------------- 246 | # Setup 247 | # ------------------------------------------- 248 | device = torch.device(config.device) 249 | dtype = getattr(torch, config.dtype) 250 | 251 | # ------------------------------------------- 252 | # Load pipeline 253 | # ------------------------------------------- 254 | pipeline = Pipeline.from_pretrained(config.ckpt_path).to(device).to(dtype) 255 | 256 | # ------------------------------------------- 257 | # Prepare datamodule and dataloader 258 | # ------------------------------------------- 259 | # Optionally compute the number of frames from duration 260 | num_frames = config.num_frames 261 | if config.duration is not None: 262 | num_frames = config.duration * pipeline.transforms.sample_rate 263 | num_frames = compute_divisible_length( 264 | num_frames, 265 | pipeline.transforms.hop_length, 266 | sum(pipeline.unet.has_resampling), 267 | ) 268 | datamodule = AudioDataModule( 269 | config.data_dir, 270 | sample_rate=pipeline.transforms.sample_rate, 271 | num_frames=num_frames, 272 | load_random_slice=False, 273 | normalize_amplitude=False, 274 | batch_size=config.batch_size, 275 | num_workers=config.num_workers, 276 | pin_memory=config.pin_memory, 277 | shuffle=False, 278 | ) 279 | datamodule.prepare_data() 280 | datamodule.setup() 281 | dataloader = datamodule.train_dataloader() 282 | 283 | # ------------------------------------------- 284 | # Run sampling 285 | # ------------------------------------------- 286 | assert len(dataloader) == math.ceil( 287 | len(config.masks) / config.batch_size 288 | ), "Number of masks should match the samples" 289 | 290 | for batch_idx, batch in enumerate( 291 | tqdm(dataloader, desc="Inpainting", disable=(not config.verbose)) 292 | ): 293 | batch = batch.to(device).to(dtype) 294 | masks = generate_batch_audio_mask( 295 | config.masks[batch_idx : batch_idx + len(batch)], 296 | batch, 297 | pipeline.transforms.sample_rate, 298 | ) 299 | 300 | samples = pipeline.inpaint( 301 | batch, 302 | masks, 303 | num_inference_steps=config.num_inference_steps, 304 | eta=config.eta, 305 | jump_length=config.jump_length, 306 | jump_n_sample=config.jump_n_sample, 307 | generator=torch.Generator(device).manual_seed(config.seed), 308 | verbose=config.verbose, 309 | use_neural_vocoder=config.use_neural_vocoder, 310 | num_griffin_lim_iters=config.num_griffin_lim_iters, 311 | ) 312 | save_batch_samples( 
313 | samples, 314 | offset=(batch_idx * config.batch_size), 315 | output_dir=config.output_dir, 316 | sample_rate=pipeline.transforms.sample_rate, 317 | audio_format=config.output_audio_format, 318 | ) 319 | 320 | 321 | def run_outpainting(config: OutpaintingConfig) -> None: 322 | # ------------------------------------------- 323 | # Setup 324 | # ------------------------------------------- 325 | device = torch.device(config.device) 326 | dtype = getattr(torch, config.dtype) 327 | 328 | # ------------------------------------------- 329 | # Load pipeline 330 | # ------------------------------------------- 331 | pipeline = Pipeline.from_pretrained(config.ckpt_path).to(device).to(dtype) 332 | 333 | # ------------------------------------------- 334 | # Prepare datamodule and dataloader 335 | # ------------------------------------------- 336 | # Optionally compute the number of frames from duration 337 | num_frames = config.num_frames 338 | if config.duration is not None: 339 | num_frames = config.duration * pipeline.transforms.sample_rate 340 | num_frames = compute_divisible_length( 341 | num_frames, 342 | pipeline.transforms.hop_length, 343 | sum(pipeline.unet.has_resampling), 344 | ) 345 | datamodule = AudioDataModule( 346 | config.data_dir, 347 | sample_rate=pipeline.transforms.sample_rate, 348 | num_frames=num_frames, 349 | load_random_slice=False, 350 | normalize_amplitude=False, 351 | batch_size=config.batch_size, 352 | num_workers=config.num_workers, 353 | pin_memory=config.pin_memory, 354 | shuffle=False, 355 | ) 356 | datamodule.prepare_data() 357 | datamodule.setup() 358 | dataloader = datamodule.train_dataloader() 359 | 360 | # ------------------------------------------- 361 | # Run sampling 362 | # ------------------------------------------- 363 | for batch_idx, batch in enumerate( 364 | tqdm(dataloader, desc="Outpainting", disable=(not config.verbose)) 365 | ): 366 | samples = pipeline.outpaint( 367 | batch.to(device).to(dtype), 368 | num_spans=config.num_spans, 369 | num_inference_steps=config.num_inference_steps, 370 | eta=config.eta, 371 | jump_length=config.jump_length, 372 | jump_n_sample=config.jump_n_sample, 373 | generator=torch.Generator(device).manual_seed(config.seed), 374 | verbose=config.verbose, 375 | use_neural_vocoder=config.use_neural_vocoder, 376 | num_griffin_lim_iters=config.num_griffin_lim_iters, 377 | ) 378 | save_batch_samples( 379 | samples, 380 | offset=(batch_idx * config.batch_size), 381 | output_dir=config.output_dir, 382 | sample_rate=pipeline.transforms.sample_rate, 383 | audio_format=config.output_audio_format, 384 | ) 385 | 386 | 387 | if __name__ == "__main__": 388 | parser = ArgumentParser() 389 | parser.add_argument( 390 | "task", 391 | help="task to run", 392 | choices=[ 393 | "sampling", 394 | "audio2audio", 395 | "interpolation", 396 | "inpainting", 397 | "outpainting", 398 | ], 399 | type=str, 400 | ) 401 | parser.add_argument("config_path", help="path to config file") 402 | args = parser.parse_args() 403 | 404 | file_config = OmegaConf.load(args.config_path) 405 | if args.task.lower() == "sampling": 406 | sampling_config = OmegaConf.structured(SamplingConfig) 407 | sampling_config = OmegaConf.merge(sampling_config, file_config) 408 | run_sampling(sampling_config) 409 | 410 | elif args.task.lower() == "audio2audio": 411 | audio2audio_config = OmegaConf.structured(Audio2AudioConfig) 412 | audio2audio_config = OmegaConf.merge(audio2audio_config, file_config) 413 | run_audio2audio(audio2audio_config) 414 | 415 | elif args.task.lower() == 
"interpolation": 416 | interpolation_config = OmegaConf.structured(InterpolationConfig) 417 | interpolation_config = OmegaConf.merge(InterpolationConfig, file_config) 418 | run_interpolation(interpolation_config) 419 | 420 | elif args.task.lower() == "inpainting": 421 | inpainting_config = OmegaConf.structured(InpaintingConfig) 422 | inpainting_config = OmegaConf.merge(InpaintingConfig, file_config) 423 | run_inpainting(inpainting_config) 424 | 425 | elif args.task.lower() == "outpainting": 426 | outpainting_config = OmegaConf.structured(OutpaintingConfig) 427 | outpainting_config = OmegaConf.merge(OutpaintingConfig, file_config) 428 | run_outpainting(outpainting_config) 429 | --------------------------------------------------------------------------------