├── LICENSE ├── README.md ├── audio_flow ├── adaptors │ ├── latent.py │ ├── onehot.py │ └── vae.py ├── data_transforms │ ├── dac2stereo.py │ ├── label2music.py │ ├── mono2stereo.py │ ├── mss.py │ ├── superresolution.py │ └── text2speech.py ├── datasets │ ├── gtzan_vae.py │ ├── ljspeech_vae.py │ ├── musdb18hq_dac2stereo_vae.py │ ├── musdb18hq_lowres2highres_vae.py │ ├── musdb18hq_mono2stereo_vae.py │ └── musdb18hq_vae.py ├── encoders │ └── dac.py ├── models │ ├── attention.py │ ├── embedders.py │ ├── pad.py │ ├── rope.py │ └── transformer1d.py ├── samplers │ └── sampler.py ├── utils.py └── vae │ └── levo.py ├── compute_latents ├── gtzan.py ├── ljspeech.py ├── maestro.py └── musdb18hq.py ├── configs ├── dac2stereo.yaml ├── mono2stereo.yaml ├── mss.yaml ├── superresolution.yaml ├── text2music.yaml └── text2speech.yaml ├── env.sh ├── finetune.py ├── sample.py ├── scripts ├── download_gtzan.sh ├── download_ljspeech.sh ├── download_maestro.sh └── download_musdb18hq.sh └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2025 CUHK (Qiuqiang Kong) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AudioFlow: Audio Generation with Flow Matching 2 | 3 | This repository contains a tutorial on audio generation using conditional flow matching implemented in PyTorch. Signals from any modality, including text, audio, MIDI, images, and video, can be converted to audio using conditional flow matching. The figure below shows the framework. 
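For readers new to flow matching, the core training objective fits in a few lines. The snippet below is an illustrative sketch rather than the repository's `train.py`: it assumes linear interpolation paths between Gaussian noise and target VAE latents, and a velocity model with the `model(t, x, emb_dict)` interface of `Transformer1D`; the time convention (noise at t=0, data at t=1) is one common choice and may differ from the actual code.

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x1, emb_dict):
    """Illustrative conditional flow matching loss with linear paths.

    x1: clean target latents, shape (b, d, t)
    emb_dict: conditioning embeddings passed to the velocity model
    """
    b = x1.shape[0]
    t = torch.rand(b, device=x1.device)   # random time steps in [0, 1]
    x0 = torch.randn_like(x1)             # Gaussian noise sample
    t_ = t[:, None, None]                 # broadcast over the (d, t) dims
    xt = (1 - t_) * x0 + t_ * x1          # point on the linear path between noise and data
    v_target = x1 - x0                    # constant target velocity along that path
    v_pred = model(t, xt, emb_dict)       # velocity predicted by the transformer
    return F.mse_loss(v_pred, v_target)
```

Training then reduces to regressing this velocity field with mean squared error; generation integrates the learned field from noise to data.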
4 | 5 | The supported tasks include: 6 | 7 | | Tasks | Supported | Dataset | Config yaml | 8 | |-------------------------|--------------|------------|--------------------------------------------------------------| 9 | | Text to music | ✅ | GTZAN | [configs/text2music.yaml](configs/text2music.yaml) | 10 | | MIDI to music | ✅ | MAESTRO | [configs/midi2music.yaml](configs/midi2music.yaml) | 11 | | Codec to audio | ✅ | MUSDB18HQ | [configs/codec2audio.yaml](configs/codec2audio.yaml) | 12 | | Mono to stereo | ✅ | MUSDB18HQ | [configs/mono2stereo.yaml](configs/mono2stereo.yaml) | 13 | | Super resolution | ✅ | MUSDB18HQ | [configs/superresolution.yaml](configs/superresolution.yaml) | 14 | | Music source separation | ✅ | MUSDB18HQ | [configs/mss.yaml](configs/mss.yaml) | 15 | | Vocal to music | ✅ | MUSDB18HQ | [configs/vocal2music.yaml](configs/vocal2music.yaml) | 16 | 17 | 18 | ## 0. Install dependencies 19 | 20 | ```bash 21 | # Clone the repo 22 | git clone https://github.com/qiuqiangkong/audio_flow 23 | cd audio_flow 24 | 25 | # Install Python environment 26 | conda create --name audio_flow python=3.10 27 | 28 | # Activate environment 29 | conda activate audio_flow 30 | 31 | # Install Python package dependencies 32 | bash env.sh 33 | ``` 34 | 35 | ## 1. Download datasets 36 | 37 | Download the dataset corresponding to the task. 38 | 39 | GTZAN (1.3 GB, 8 hours): 40 | 41 | ```bash 42 | bash ./scripts/download_gtzan.sh 43 | ``` 44 | 45 | MUSDB18HQ (30 GB, 10 hours): 46 | 47 | ```bash 48 | bash ./scripts/download_musdb18hq.sh 49 | ``` 50 | 51 | To download more datasets, please see [scripts](scripts). 52 | 53 | ## 2. Train 54 | 55 | ### 2.0 Pre-extract VAE latents 56 | 57 | In training, users can use (1) online VAE extraction or (2) offline VAE extraction. We adopt (2) to speed up the training of flow matching and to save RAM. 58 | 59 | ```bash 60 | CUDA_VISIBLE_DEVICES=0 python -m compute_latents.gtzan_vae \ 61 | --dataset_root="./datasets/gtzan" \ 62 | --out_dir="./datasets/gtzan_vae" \ 63 | --augmentation_repeats=10 64 | ``` 65 | 66 | ### 2.1 Train with a single GPU 67 | 68 | Here is an example of training a text-to-music generation system. Users can train other tasks using the config YAML files in [configs](configs). 69 | 70 | ```bash 71 | CUDA_VISIBLE_DEVICES=0 python train.py --config="./configs/text2music.yaml" --no_log 72 | ``` 73 | 74 | ### 2.2 Finetune 75 | 76 | Extract VAE latents: 77 | 78 | ```bash 79 | CUDA_VISIBLE_DEVICES=6 python -m compute_latents.musdb18hq_vae stems \ 80 | --dataset_root="./datasets/musdb18hq" \ 81 | --out_dir="./datasets/musdb18hq_vae" \ 82 | --augmentation_repeats=10 83 | ``` 84 | 85 | Train: 86 | 87 | ```bash 88 | CUDA_VISIBLE_DEVICES=0 python finetune.py \ 89 | --config="./configs/mss.yaml" \ 90 | --ckpt_path="checkpoints/train/text2music/step=300000_ema.pth" \ 91 | --no_log 92 | ``` 93 | 94 | To run more examples, please see [configs](configs).
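The repository also provides `sample.py` for generation. Conceptually, flow matching sampling integrates the learned velocity field from noise to a VAE latent and then decodes that latent to a waveform. The sketch below illustrates this with a fixed-step Euler solver; the function and variable names (`euler_sample`, `emb_dict`, `data_transform`, `latent_dim`, `num_frames`) are assumptions for illustration and do not reflect the exact interface of `sample.py`.

```python
import torch

@torch.no_grad()
def euler_sample(model, emb_dict, shape, steps=100, device="cuda"):
    """Illustrative Euler integration of dx/dt = v(t, x) from noise (t=0) to data (t=1)."""
    x = torch.randn(shape, device=device)             # start from Gaussian noise
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((shape[0],), i * dt, device=device)
        v = model(t, x, emb_dict)                     # predicted velocity at time t
        x = x + dt * v                                # Euler step toward the data distribution
    return x

# Example usage (names are illustrative):
# latent = euler_sample(model, emb_dict, shape=(1, latent_dim, num_frames))
# audio = data_transform.latent_to_audio(latent)      # decode the VAE latent to a waveform
```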
95 | 96 | ## External links 97 | 98 | [1] Conditional flow matching: https://github.com/atong01/conditional-flow-matching 99 | 100 | [2] DiT: https://github.com/facebookresearch/DiT -------------------------------------------------------------------------------- /audio_flow/adaptors/latent.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | from einops import rearrange 4 | 5 | 6 | class LatentEncoder(nn.Module): 7 | def __init__(self, in_channels: int, dim: int): 8 | super().__init__() 9 | 10 | self.mlp = nn.Sequential( 11 | nn.Linear(in_channels, dim), 12 | nn.SiLU(), 13 | nn.Linear(dim, dim, bias=True), 14 | ) 15 | 16 | def forward(self, cond_dict: dict) -> Tensor: 17 | r"""Compute latent embedding.""" 18 | 19 | if "c" in cond_dict: 20 | c = self.mlp(cond_dict["c"]) 21 | return {"c": c} 22 | 23 | elif "ct" in cond_dict: 24 | ct = rearrange(cond_dict["ct"], 'b d t -> b t d') 25 | ct = self.mlp(ct) 26 | ct = rearrange(ct, 'b t d -> b d t') 27 | return {"ct": ct} 28 | 29 | else: 30 | raise ValueError -------------------------------------------------------------------------------- /audio_flow/adaptors/onehot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | from einops import rearrange 4 | 5 | 6 | class OnehotEncoder(nn.Module): 7 | def __init__(self, num_classes: int, dim: int): 8 | super().__init__() 9 | 10 | self.mlp = nn.Sequential( 11 | nn.Embedding(num_classes, dim), 12 | nn.SiLU(), 13 | nn.Linear(dim, dim, bias=True), 14 | ) 15 | 16 | def forward(self, cond_dict: dict) -> Tensor: 17 | r"""Compute latent embedding.""" 18 | 19 | if cond_dict["id"].ndim == 1: 20 | c = self.mlp(cond_dict["id"]) # (b, d) 21 | return {"c": c} 22 | 23 | elif cond_dict["id"].ndim > 1: 24 | cx = self.mlp(cond_dict["id"]) # (b, t, d) 25 | cx = rearrange(cx, 'b t d -> b d t') 26 | return {"cx": cx} 27 | 28 | else: 29 | raise ValueError -------------------------------------------------------------------------------- /audio_flow/adaptors/vae.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | from einops import rearrange 4 | 5 | 6 | class VAEEncoder(nn.Module): 7 | def __init__(self, in_channels: int, dim: int): 8 | super().__init__() 9 | 10 | self.mlp = nn.Sequential( 11 | nn.Linear(in_channels, dim), 12 | nn.SiLU(), 13 | nn.Linear(dim, dim, bias=True), 14 | ) 15 | 16 | def forward(self, cond_dict: dict) -> Tensor: 17 | r"""Compute onehot embedding.""" 18 | 19 | ct = rearrange(cond_dict["ct"], 'b d t -> b t d') 20 | ct = self.mlp(ct) 21 | ct = rearrange(ct, 'b t d -> b d t') 22 | 23 | emb_dict = { 24 | "ct": ct, 25 | } 26 | 27 | return emb_dict -------------------------------------------------------------------------------- /audio_flow/data_transforms/dac2stereo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | from audidata.datasets import GTZAN 5 | from torch import Tensor 6 | from einops import rearrange 7 | 8 | from audio_flow.vae.levo import LevoVAE 9 | from audio_flow.encoders.dac import DAC 10 | from audio_flow.utils import align_temporal_features 11 | 12 | 13 | class Dac2StereoVAE(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | self.dac = DAC() 18 | self.vae = LevoVAE() 19 | self.sr = self.vae.sr 20 | 21 | 
def audio_to_latent(self, data: dict) -> tuple[Tensor, dict]: 22 | r"""Transform data into latent representations and conditions. 23 | 24 | b: batch_size 25 | c: channels_num 26 | l: audio_samples 27 | t: frames_num 28 | f: mel bins 29 | """ 30 | 31 | device = next(self.parameters()).device 32 | 33 | 34 | dac_code = data["dac_code"].to(device) 35 | vae_latent = data["vae_latent"].to(device) 36 | 37 | dac_latent = self.dac.code_to_latent(dac_code) # (b, d, t) 38 | 39 | dac_latent = align_temporal_features( 40 | input=dac_latent, 41 | target=vae_latent, 42 | input_fps=data["dac_fps"][0].item(), 43 | target_fps=data["vae_fps"][0].item() 44 | ) 45 | 46 | # Condition 47 | cond_dict = { 48 | "ct": dac_latent 49 | } 50 | 51 | return vae_latent, cond_dict 52 | 53 | def latent_to_audio(self, x: Tensor) -> Tensor: 54 | r"""Ues vocoder to convert mel spectrogram to audio. 55 | 56 | Args: 57 | x: (b, c, t, f) 58 | 59 | Outputs: 60 | y: (b, c, l) 61 | """ 62 | x = self.vae.decode(x) 63 | return x 64 | 65 | def __call__(self, data: dict) -> tuple[Tensor, dict]: 66 | return self.audio_to_latent(data) -------------------------------------------------------------------------------- /audio_flow/data_transforms/label2music.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | from audidata.datasets import GTZAN 5 | from torch import Tensor 6 | from einops import rearrange 7 | 8 | from audio_flow.vae.levo import LevoVAE 9 | 10 | 11 | class Label2MusicVAE(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | self.vae = LevoVAE() 16 | self.sr = self.vae.sr 17 | 18 | def audio_to_latent(self, data: dict) -> tuple[Tensor, dict]: 19 | r"""Transform data into latent representations and conditions. 20 | 21 | b: batch_size 22 | c: channels_num 23 | l: audio_samples 24 | t: frames_num 25 | f: mel bins 26 | """ 27 | 28 | device = next(self.parameters()).device 29 | 30 | # Mel spectrogram target 31 | latent = data["latent"].to(device) # (b, d, t) 32 | # latent = rearrange(latent, 'b d t -> b t d') 33 | 34 | ids = data["target"].to(device) # (b,) 35 | captions = data["label"] # (b,) 36 | 37 | # Condition 38 | cond_dict = { 39 | "id": ids, 40 | "caption": captions 41 | } 42 | 43 | return latent, cond_dict 44 | 45 | def latent_to_audio(self, x: Tensor) -> Tensor: 46 | r"""Ues vocoder to convert mel spectrogram to audio. 47 | 48 | Args: 49 | x: (b, c, t, f) 50 | 51 | Outputs: 52 | y: (b, c, l) 53 | """ 54 | x = self.vae.decode(x) 55 | return x 56 | 57 | def __call__(self, data: dict) -> tuple[Tensor, dict]: 58 | return self.audio_to_latent(data) -------------------------------------------------------------------------------- /audio_flow/data_transforms/mono2stereo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | from audidata.datasets import GTZAN 5 | from torch import Tensor 6 | from einops import rearrange 7 | 8 | from audio_flow.vae.levo import LevoVAE 9 | 10 | 11 | class Mono2StereoVAE(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | self.vae = LevoVAE() 16 | self.sr = self.vae.sr 17 | 18 | def audio_to_latent(self, data: dict) -> tuple[Tensor, dict]: 19 | r"""Transform data into latent representations and conditions. 
20 | 21 | b: batch_size 22 | c: channels_num 23 | l: audio_samples 24 | t: frames_num 25 | f: mel bins 26 | """ 27 | 28 | device = next(self.parameters()).device 29 | 30 | # Mel spectrogram target 31 | mono_latent = data["mono_latent"].to(device) # (b, d, t) 32 | target_latent = data["stereo_latent"].to(device) # (b, d, t) 33 | 34 | # Condition 35 | cond_dict = { 36 | "ct": mono_latent 37 | } 38 | 39 | return target_latent, cond_dict 40 | 41 | def latent_to_audio(self, x: Tensor) -> Tensor: 42 | r"""Ues vocoder to convert mel spectrogram to audio. 43 | 44 | Args: 45 | x: (b, c, t, f) 46 | 47 | Outputs: 48 | y: (b, c, l) 49 | """ 50 | x = self.vae.decode(x) 51 | return x 52 | 53 | def __call__(self, data: dict) -> tuple[Tensor, dict]: 54 | return self.audio_to_latent(data) -------------------------------------------------------------------------------- /audio_flow/data_transforms/mss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | from audidata.datasets import GTZAN 5 | from torch import Tensor 6 | from einops import rearrange 7 | 8 | from audio_flow.vae.levo import LevoVAE 9 | 10 | 11 | class MSSVAE(nn.Module): 12 | def __init__(self, target_stem: str): 13 | super().__init__() 14 | 15 | self.target_stem = target_stem 16 | self.vae = LevoVAE() 17 | self.sr = self.vae.sr 18 | 19 | def audio_to_latent(self, data: dict) -> tuple[Tensor, dict]: 20 | r"""Transform data into latent representations and conditions. 21 | 22 | b: batch_size 23 | c: channels_num 24 | l: audio_samples 25 | t: frames_num 26 | f: mel bins 27 | """ 28 | 29 | device = next(self.parameters()).device 30 | 31 | # Mel spectrogram target 32 | mixture_latent = data["mixture_latent"].to(device) # (b, d, t) 33 | target_latent = data["target_latent"].to(device) # (b, d, t) 34 | 35 | # Condition 36 | cond_dict = { 37 | "ct": mixture_latent 38 | } 39 | 40 | return target_latent, cond_dict 41 | 42 | def latent_to_audio(self, x: Tensor) -> Tensor: 43 | r"""Ues vocoder to convert mel spectrogram to audio. 44 | 45 | Args: 46 | x: (b, c, t, f) 47 | 48 | Outputs: 49 | y: (b, c, l) 50 | """ 51 | x = self.vae.decode(x) 52 | return x 53 | 54 | def __call__(self, data: dict) -> tuple[Tensor, dict]: 55 | return self.audio_to_latent(data) -------------------------------------------------------------------------------- /audio_flow/data_transforms/superresolution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | from audidata.datasets import GTZAN 5 | from torch import Tensor 6 | from einops import rearrange 7 | 8 | from audio_flow.vae.levo import LevoVAE 9 | 10 | 11 | class SuperResolutionVAE(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | self.vae = LevoVAE() 16 | self.sr = self.vae.sr 17 | 18 | def audio_to_latent(self, data: dict) -> tuple[Tensor, dict]: 19 | r"""Transform data into latent representations and conditions. 
20 | 21 | b: batch_size 22 | c: channels_num 23 | l: audio_samples 24 | t: frames_num 25 | f: mel bins 26 | """ 27 | 28 | device = next(self.parameters()).device 29 | 30 | 31 | # VAE latent target 32 | mono_latent = data["mono_latent"].to(device) # (b, d, t) 33 | target_latent = data["stereo_latent"].to(device) # (b, d, t) 34 | 35 | # Condition 36 | cond_dict = { 37 | "ct": mono_latent 38 | } 39 | 40 | return target_latent, cond_dict 41 | 42 | def latent_to_audio(self, x: Tensor) -> Tensor: 43 | r"""Use the VAE decoder to convert the latent to audio. 44 | 45 | Args: 46 | x: (b, d, t) 47 | 48 | Outputs: 49 | y: (b, c, l) 50 | """ 51 | x = self.vae.decode(x) 52 | return x 53 | 54 | def __call__(self, data: dict) -> tuple[Tensor, dict]: 55 | return self.audio_to_latent(data) -------------------------------------------------------------------------------- /audio_flow/data_transforms/text2speech.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | from audidata.datasets import GTZAN 5 | from torch import Tensor, LongTensor 6 | from einops import rearrange 7 | from transformers import AutoTokenizer 8 | 9 | from audio_flow.vae.levo import LevoVAE 10 | 11 | 12 | class Text2SpeechVAE(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | self.vae = LevoVAE() 17 | self.sr = self.vae.sr 18 | 19 | self.tok = AutoTokenizer.from_pretrained("bert-base-uncased") 20 | 21 | def audio_to_latent(self, data: dict) -> tuple[Tensor, dict]: 22 | r"""Transform data into latent representations and conditions. 23 | 24 | b: batch_size 25 | c: channels_num 26 | l: audio_samples 27 | t: frames_num 28 | f: mel bins 29 | """ 30 | 31 | device = next(self.parameters()).device 32 | 33 | # VAE latent target 34 | latent = data["latent"].to(device) # (b, d, t) 35 | captions = data["caption"] # (b,) 36 | 37 | ids = self.tok(captions, padding="longest")["input_ids"] 38 | ids = LongTensor(ids).to(device) 39 | 40 | # Condition 41 | cond_dict = { 42 | "id": ids, 43 | "caption": captions 44 | } 45 | 46 | return latent, cond_dict 47 | 48 | def latent_to_audio(self, x: Tensor) -> Tensor: 49 | r"""Use the VAE decoder to convert the latent to audio. 50 | 51 | Args: 52 | x: (b, d, t) 53 | 54 | Outputs: 55 | y: (b, c, l) 56 | """ 57 | x = self.vae.decode(x) 58 | return x 59 | 60 | def __call__(self, data: dict) -> tuple[Tensor, dict]: 61 | return self.audio_to_latent(data) -------------------------------------------------------------------------------- /audio_flow/datasets/gtzan_vae.py: -------------------------------------------------------------------------------- 1 | r"""Code modified from: https://github.com/AudioFans/audidata/blob/main/audidata/datasets/gtzan.py""" 2 | from __future__ import annotations 3 | 4 | import os 5 | import re 6 | from pathlib import Path 7 | 8 | import h5py 9 | import random 10 | import pickle 11 | import librosa 12 | import numpy as np 13 | from audidata.io.audio import load 14 | from audidata.io.crops import StartCrop 15 | from audidata.transforms.audio import Mono 16 | from audidata.transforms.onehot import OneHot 17 | from audidata.utils import call 18 | from torch.utils.data import Dataset 19 | from typing_extensions import Literal 20 | 21 | 22 | class GtzanVAE(Dataset): 23 | r"""GTZAN [1] is a music dataset containing 1,000 30-second audio clips. 24 | The total duration is 8.3 hours. GTZAN includes 10 genres.
All audio files 25 | are mono sampled at 22,050 Hz. After decompression, the dataset size is 26 | 1.3 GB. 27 | 28 | [1] Tzanetakis, G., et al., Musical genre classification of audio signals. 2002 29 | 30 | The dataset looks like: 31 | 32 | gtzan (1.3 GB) 33 | └── genres 34 | ├── blues (100 files) 35 | ├── classical (100 files) 36 | ├── country (100 files) 37 | ├── disco (100 files) 38 | ├── hiphop (100 files) 39 | ├── jazz (100 files) 40 | ├── metal (100 files) 41 | ├── pop (100 files) 42 | ├── reggae (100 files) 43 | └── rock (100 files) 44 | """ 45 | 46 | LABELS = ["blues", "classical", "country", "disco", "hiphop", "jazz", 47 | "metal", "pop", "reggae", "rock"] 48 | 49 | CLASSES_NUM = len(LABELS) 50 | LB_TO_IX = {lb: ix for ix, lb in enumerate(LABELS)} 51 | IX_TO_LB = {ix: lb for ix, lb in enumerate(LABELS)} 52 | 53 | def __init__( 54 | self, 55 | root: str = None, 56 | split: Literal["train", "test"] = "train", 57 | test_fold: int = 0, # E.g., fold 0 is used for testing. Fold 1 - 9 are used for training. 58 | duration: float = 10. 59 | ) -> None: 60 | 61 | self.root = root 62 | self.split = split 63 | self.test_fold = test_fold 64 | self.duration = duration 65 | 66 | self.labels = GtzanVAE.LABELS 67 | self.lb_to_ix = GtzanVAE.LB_TO_IX 68 | self.ix_to_lb = GtzanVAE.IX_TO_LB 69 | 70 | self.meta_dict = self.load_meta() 71 | 72 | def __getitem__(self, index: int) -> dict: 73 | 74 | path = str(self.meta_dict["path"][index]) 75 | label = self.meta_dict["label"][index] 76 | 77 | full_data = { 78 | "dataset_name": "GtzanVAE", 79 | "path": path, 80 | } 81 | 82 | # Load audio data 83 | latent_data = self.load_latent_data(path=path) 84 | full_data.update(latent_data) 85 | 86 | # Load target data 87 | target_data = self.load_target_data(label=label) 88 | full_data.update(target_data) 89 | 90 | return full_data 91 | 92 | def __len__(self) -> int: 93 | return len(self.meta_dict["name"]) 94 | 95 | def load_meta(self) -> dict: 96 | r"""Load metadata of the GTZAN dataset. 
97 | """ 98 | 99 | meta_dict = { 100 | "name": [], 101 | "path": [], 102 | "label": [], 103 | } 104 | 105 | out_dir = Path(self.root, "genres") 106 | 107 | for genre in self.labels: 108 | 109 | names = sorted(os.listdir(Path(out_dir, genre))) 110 | train_names, test_names = self.split_train_test(names) 111 | 112 | if self.split == "train": 113 | filtered_names = train_names 114 | 115 | elif self.split == "test": 116 | filtered_names = test_names 117 | 118 | else: 119 | raise ValueError(self.split) 120 | 121 | for name in filtered_names: 122 | path = str(Path(out_dir, genre, name)) 123 | meta_dict["name"].append(name) 124 | meta_dict["path"].append(path) 125 | meta_dict["label"].append(genre) 126 | 127 | return meta_dict 128 | 129 | def split_train_test(self, names: list) -> tuple[list, list]: 130 | 131 | train_names = [] 132 | test_names = [] 133 | 134 | test_ids = range(self.test_fold * 10, (self.test_fold + 1) * 10) 135 | # E.g., if test_fold = 3, then test_ids = [30, 31, 32, ..., 39] 136 | 137 | for name in names: 138 | 139 | audio_id = int(re.search(r'\d+', name).group()) 140 | # E.g., if name is "blues.00037.h5", then audio_id = 37 141 | 142 | if audio_id in test_ids: 143 | test_names.append(name) 144 | 145 | else: 146 | train_names.append(name) 147 | 148 | return train_names, test_names 149 | 150 | def load_latent_data(self, path: str) -> dict: 151 | 152 | with h5py.File(path, 'r') as hf: 153 | latent = hf["latent"][:] 154 | fps = hf.attrs["fps"] 155 | 156 | clip_frames = int(self.duration * fps) 157 | bgn_frame = random.randint(0, latent.shape[-1] - clip_frames) 158 | bgn_frame = max(0, bgn_frame) 159 | latent = latent[:, bgn_frame : bgn_frame + clip_frames] # (d, t) 160 | 161 | data = { 162 | "latent": latent, 163 | "fps": fps 164 | } 165 | 166 | return data 167 | 168 | def load_target_data(self, label: str) -> dict: 169 | 170 | target = self.lb_to_ix[label] 171 | 172 | data = { 173 | "label": label, 174 | "target": target # shape: (classes_num,) 175 | } 176 | 177 | return data -------------------------------------------------------------------------------- /audio_flow/datasets/ljspeech_vae.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | import pickle 5 | import os 6 | import h5py 7 | 8 | import librosa 9 | import numpy as np 10 | import pandas as pd 11 | from audidata.io.audio import load 12 | from audidata.io.crops import StartCrop 13 | from audidata.transforms.audio import Mono 14 | from audidata.utils import call 15 | from torch.utils.data import Dataset 16 | from typing_extensions import Literal 17 | 18 | 19 | class LJSpeechVAE(Dataset): 20 | 21 | def __init__( 22 | self, 23 | root: str = None, 24 | split: Literal["train", "valid" "test"] = "train", 25 | duration: float = 10 26 | ) -> None: 27 | 28 | self.root = root 29 | self.split = split 30 | self.duration = duration 31 | 32 | self.meta_dict = self.load_meta() 33 | 34 | def __getitem__(self, index: int) -> dict: 35 | 36 | audio_path = self.meta_dict["audio_path"][index] 37 | caption = self.meta_dict["caption"][index] 38 | 39 | full_data = { 40 | "dataset_name": "LJSpeechVAE", 41 | "audio_path": audio_path, 42 | } 43 | 44 | # Load audio data 45 | audio_data = self.load_latent_data(path=audio_path) 46 | full_data.update(audio_data) 47 | 48 | # Load target data 49 | target_data = self.load_target_data(caption=caption) 50 | full_data.update(target_data) 51 | 52 | return full_data 53 | 54 | def __len__(self) -> int: 
55 | return len(self.meta_dict["audio_name"]) 56 | 57 | def load_meta(self) -> dict: 58 | r"""Load metadata of the GTZAN dataset. 59 | """ 60 | 61 | # Load split file 62 | split_path = Path(self.root, "{}.txt".format(self.split)) 63 | df = pd.read_csv(split_path, header=None) 64 | split_names = df[0].values 65 | 66 | # Load csv file 67 | csv_path = Path(self.root, "metadata.csv") 68 | df = pd.read_csv(csv_path, sep="|", header=None) 69 | names = df[0].values 70 | captions = df[1].values 71 | 72 | # Get split indexes 73 | idxes = [] 74 | for i in range(len(names)): 75 | if names[i] in split_names: 76 | idxes.append(i) 77 | 78 | audios_num = len(os.listdir(Path(self.root, "wavs"))) 79 | num_repeats = audios_num // len(captions) 80 | 81 | all_names = [] 82 | all_captions = [] 83 | all_paths = [] 84 | 85 | for i in idxes: 86 | all_names.extend([names[i]] * num_repeats) 87 | all_captions.extend([captions[i]] * num_repeats) 88 | all_paths.extend([str(Path(self.root, "wavs", f"{names[i]}_{i:03d}_vae.h5")) for i in range(num_repeats)]) 89 | 90 | meta_dict = { 91 | "audio_name": all_names, 92 | "audio_path": all_paths, 93 | "caption": all_captions 94 | } 95 | 96 | return meta_dict 97 | 98 | def load_latent_data(self, path: str) -> dict: 99 | 100 | with h5py.File(path, 'r') as hf: 101 | latent = hf["latent"][:] 102 | fps = hf.attrs["fps"] 103 | 104 | data = { 105 | "latent": latent, 106 | "fps": fps 107 | } 108 | 109 | return data 110 | 111 | def load_target_data(self, caption: str) -> dict: 112 | 113 | target = caption 114 | 115 | data = { 116 | "caption": caption, 117 | "target": target 118 | } 119 | 120 | return data -------------------------------------------------------------------------------- /audio_flow/datasets/musdb18hq_dac2stereo_vae.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from pathlib import Path 5 | from typing import Union 6 | import pickle 7 | import random 8 | import h5py 9 | 10 | import librosa 11 | import numpy as np 12 | from audidata.io.audio import load 13 | from audidata.io.crops import RandomCrop 14 | from torch.utils.data import Dataset 15 | from typing_extensions import Literal 16 | 17 | 18 | class MUSDB18HqDac2StereoVAE(Dataset): 19 | 20 | def __init__( 21 | self, 22 | root: str, 23 | split: Literal["train", "test"] = "train", 24 | duration: float = 10., 25 | ) -> None: 26 | 27 | self.root = root 28 | self.split = split 29 | self.duration = duration 30 | 31 | self.meta_dict = self.load_meta() 32 | 33 | def __getitem__( 34 | self, 35 | index: int, 36 | ) -> dict: 37 | 38 | full_data = { 39 | "dataset_name": "MUSDB18HqMono2StereoVAE", 40 | } 41 | 42 | dac_path = self.meta_dict["dac_path"][index] 43 | vae_path = self.meta_dict["vae_path"][index] 44 | 45 | data = self.load_latent_data(dac_path, vae_path) 46 | full_data.update(data) 47 | 48 | return full_data 49 | 50 | def __len__(self) -> int: 51 | return len(self.meta_dict["vae_path"]) 52 | 53 | def load_meta(self): 54 | 55 | vae_paths = sorted(list(Path(self.root, self.split).rglob('*vae.h5'))) 56 | vae_paths = [str(s) for s in vae_paths] 57 | dac_paths = [s.replace("vae.h5", "dac.h5") for s in vae_paths] 58 | 59 | meta_dict = { 60 | "dac_path": dac_paths, 61 | "vae_path": vae_paths 62 | } 63 | 64 | return meta_dict 65 | 66 | def load_latent_data(self, mono_path: str, vae_path: str) -> dict: 67 | 68 | with h5py.File(mono_path, 'r') as hf: 69 | dac_code = hf["code"][:] 70 | dac_fps = hf.attrs["fps"] 71 | 72 | with 
h5py.File(vae_path, 'r') as hf: 73 | vae_latent = hf["latent"][:] 74 | vae_fps = hf.attrs["fps"] 75 | 76 | vae_frames = int(self.duration * vae_fps) 77 | vae_bgn = random.randint(0, vae_latent.shape[-1] - vae_frames) 78 | vae_bgn = max(0, vae_bgn) 79 | vae_latent = vae_latent[:, vae_bgn : vae_bgn + vae_frames] # (d, t) 80 | 81 | dac_frames = round(vae_frames * (dac_fps / vae_fps)) 82 | dac_bgn = round(vae_bgn * (dac_fps / vae_fps)) 83 | dac_code = dac_code[:, dac_bgn : dac_bgn + dac_frames] # (d, t) 84 | 85 | if dac_code.shape[-1] < dac_frames: 86 | dac_code = librosa.util.fix_length(dac_code, size=dac_frames, axis=-1, mode="edge") 87 | 88 | data = { 89 | "dac_code": dac_code, 90 | "vae_latent": vae_latent, 91 | "dac_fps": dac_fps, 92 | "vae_fps": vae_fps 93 | } 94 | 95 | return data -------------------------------------------------------------------------------- /audio_flow/datasets/musdb18hq_lowres2highres_vae.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from pathlib import Path 5 | from typing import Union 6 | import pickle 7 | import random 8 | import h5py 9 | 10 | import librosa 11 | import numpy as np 12 | from audidata.io.audio import load 13 | from audidata.io.crops import RandomCrop 14 | from torch.utils.data import Dataset 15 | from typing_extensions import Literal 16 | 17 | 18 | class MUSDB18HqLowres2HighresVAE(Dataset): 19 | 20 | def __init__( 21 | self, 22 | root: str, 23 | split: Literal["train", "test"] = "train", 24 | duration: float = 10., 25 | ) -> None: 26 | 27 | self.root = root 28 | self.split = split 29 | self.duration = duration 30 | 31 | self.meta_dict = self.load_meta() 32 | 33 | def __getitem__( 34 | self, 35 | index: int, 36 | ) -> dict: 37 | 38 | full_data = { 39 | "dataset_name": "MUSDB18HqLowres2HighresVAE", 40 | } 41 | 42 | mono_path = self.meta_dict["mono_path"][index] 43 | stereo_path = self.meta_dict["stereo_path"][index] 44 | 45 | data = self.load_latent_data(mono_path, stereo_path) 46 | full_data.update(data) 47 | 48 | return full_data 49 | 50 | def __len__(self) -> int: 51 | return len(self.meta_dict["mono_path"]) 52 | 53 | def load_meta(self): 54 | 55 | 56 | 57 | input_paths = sorted(list(Path(self.root, self.split).rglob('*lowres_vae.h5'))) 58 | input_paths = [str(s) for s in input_paths] 59 | target_paths = [s.replace("lowres_vae.h5", "highres_vae.h5") for s in input_paths] 60 | 61 | meta_dict = { 62 | "mono_path": input_paths, 63 | "stereo_path": target_paths 64 | } 65 | 66 | return meta_dict 67 | 68 | def load_latent_data(self, mono_path: str, stereo_path: str) -> dict: 69 | 70 | with h5py.File(mono_path, 'r') as hf: 71 | mono_latent = hf["latent"][:] 72 | fps = hf.attrs["fps"] 73 | 74 | with h5py.File(stereo_path, 'r') as hf: 75 | stereo_latent = hf["latent"][:] 76 | 77 | total_frames = mono_latent.shape[-1] 78 | clip_frames = int(self.duration * fps) 79 | bgn_frame = random.randint(0, total_frames - clip_frames) 80 | bgn_frame = max(0, bgn_frame) 81 | 82 | mono_latent = mono_latent[:, bgn_frame : bgn_frame + clip_frames] # (d, t) 83 | stereo_latent = stereo_latent[:, bgn_frame : bgn_frame + clip_frames] # (d, t) 84 | 85 | data = { 86 | "mono_latent": mono_latent, 87 | "stereo_latent": stereo_latent, 88 | "fps": fps 89 | } 90 | 91 | return data -------------------------------------------------------------------------------- /audio_flow/datasets/musdb18hq_mono2stereo_vae.py:
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from pathlib import Path 5 | from typing import Union 6 | import pickle 7 | import random 8 | import h5py 9 | 10 | import librosa 11 | import numpy as np 12 | from audidata.io.audio import load 13 | from audidata.io.crops import RandomCrop 14 | from torch.utils.data import Dataset 15 | from typing_extensions import Literal 16 | 17 | 18 | class MUSDB18HqMono2StereoVAE(Dataset): 19 | 20 | def __init__( 21 | self, 22 | root: str, 23 | split: Literal["train", "test"] = "train", 24 | duration: float = 10., 25 | ) -> None: 26 | 27 | self.root = root 28 | self.split = split 29 | self.duration = duration 30 | 31 | self.meta_dict = self.load_meta() 32 | 33 | def __getitem__( 34 | self, 35 | index: int, 36 | ) -> dict: 37 | 38 | full_data = { 39 | "dataset_name": "MUSDB18HqMono2StereoVAE", 40 | } 41 | 42 | mono_path = self.meta_dict["mono_path"][index] 43 | stereo_path = self.meta_dict["stereo_path"][index] 44 | 45 | data = self.load_latent_data(mono_path, stereo_path) 46 | full_data.update(data) 47 | 48 | return full_data 49 | 50 | def __len__(self) -> int: 51 | return len(self.meta_dict["mono_path"]) 52 | 53 | def load_meta(self): 54 | 55 | mono_paths = sorted(list(Path(self.root, self.split).rglob('*mono_vae.h5'))) 56 | mono_paths = [str(s) for s in mono_paths] 57 | stereo_paths = [s.replace("mono_vae.h5", "stereo_vae.h5") for s in mono_paths] 58 | 59 | meta_dict = { 60 | "mono_path": mono_paths, 61 | "stereo_path": stereo_paths 62 | } 63 | 64 | return meta_dict 65 | 66 | def load_latent_data(self, mono_path: str, stereo_path: str) -> dict: 67 | 68 | with h5py.File(mono_path, 'r') as hf: 69 | mono_latent = hf["latent"][:] 70 | fps = hf.attrs["fps"] 71 | 72 | with h5py.File(stereo_path, 'r') as hf: 73 | stereo_latent = hf["latent"][:] 74 | 75 | total_frames = mono_latent.shape[-1] 76 | clip_frames = int(self.duration * fps) 77 | bgn_frame = random.randint(0, total_frames - clip_frames) 78 | bgn_frame = max(0, bgn_frame) 79 | 80 | mono_latent = mono_latent[:, bgn_frame : bgn_frame + clip_frames] # (d, t) 81 | stereo_latent = stereo_latent[:, bgn_frame : bgn_frame + clip_frames] # (d, t) 82 | 83 | data = { 84 | "mono_latent": mono_latent, 85 | "stereo_latent": stereo_latent, 86 | "fps": fps 87 | } 88 | 89 | return data -------------------------------------------------------------------------------- /audio_flow/datasets/musdb18hq_vae.py: -------------------------------------------------------------------------------- 1 | r"""Code from https://github.com/AudioFans/audidata/blob/main/audidata/datasets/musdb18hq.py""" 2 | from __future__ import annotations 3 | 4 | import os 5 | from pathlib import Path 6 | from typing import Union 7 | import pickle 8 | import random 9 | import h5py 10 | 11 | import librosa 12 | import numpy as np 13 | from audidata.io.audio import load 14 | from audidata.io.crops import RandomCrop 15 | from torch.utils.data import Dataset 16 | from typing_extensions import Literal 17 | 18 | 19 | class MUSDB18HqVAE(Dataset): 20 | 21 | def __init__( 22 | self, 23 | root: str, 24 | split: Literal["train", "test"] = "train", 25 | duration: float = 10., 26 | target_stem: str = "vocals", 27 | ) -> None: 28 | 29 | self.root = root 30 | self.split = split 31 | self.duration = duration 32 | self.target_stem = target_stem 33 | 34 | self.meta_dict = self.load_meta() 35 | 36 | def __getitem__( 37 | self, 38 | index: int, 39 | ) -> dict: 40 | 41 | audio_names 
= {} 42 | audio_paths = {} 43 | start_times = {} 44 | clip_durations = {} 45 | 46 | full_data = { 47 | "dataset_name": "MUSDB18HqVAE", 48 | } 49 | 50 | mixture_path = self.meta_dict["mixture_path"][index] 51 | target_path = self.meta_dict["target_path"][index] 52 | 53 | data = self.load_latent_data(mixture_path, target_path) 54 | full_data.update(data) 55 | 56 | return full_data 57 | 58 | def __len__(self) -> int: 59 | return len(self.meta_dict["mixture_path"]) 60 | 61 | def load_meta(self): 62 | 63 | mixture_paths = sorted(list(Path(self.root, self.split).rglob("*mixture_*"))) 64 | mixture_paths = [str(s) for s in mixture_paths] 65 | target_paths = [s.replace("mixture_", "{}_".format(self.target_stem)) for s in mixture_paths] 66 | 67 | meta_dict = { 68 | "mixture_path": mixture_paths, 69 | "target_path": target_paths 70 | } 71 | 72 | return meta_dict 73 | 74 | def load_latent_data(self, mixture_path: str, target_path) -> dict: 75 | 76 | with h5py.File(mixture_path, 'r') as hf: 77 | mixture_latent = hf["latent"][:] 78 | fps = hf.attrs["fps"] 79 | 80 | with h5py.File(target_path, 'r') as hf: 81 | target_latent = hf["latent"][:] 82 | 83 | total_frames = mixture_latent.shape[-1] 84 | clip_frames = int(self.duration * fps) 85 | bgn_frame = random.randint(0, total_frames - clip_frames) 86 | bgn_frame = max(0, bgn_frame) 87 | 88 | mixture_latent = mixture_latent[:, bgn_frame : bgn_frame + clip_frames] # (d, t) 89 | target_latent = target_latent[:, bgn_frame : bgn_frame + clip_frames] # (d, t) 90 | 91 | data = { 92 | "mixture_latent": mixture_latent, 93 | "target_latent": target_latent, 94 | "fps": fps 95 | } 96 | 97 | return data -------------------------------------------------------------------------------- /audio_flow/encoders/dac.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dac 4 | import torch 5 | import torch.nn as nn 6 | from torch import LongTensor, Tensor 7 | 8 | 9 | class DAC(nn.Module): 10 | def __init__(self, n_quantizers: int = 2) -> None: 11 | super().__init__() 12 | 13 | model_path = dac.utils.download(model_type="44khz") 14 | self.model = dac.DAC.load(model_path) 15 | self.n_quantizers = n_quantizers 16 | self.sr = 44100 17 | self.fps = self.sr / 2 / 4 / 8 / 8 18 | 19 | 20 | def encode(self, audio: Tensor) -> Tensor: 21 | r"""Encode audio to discrete code. 22 | 23 | b: batch_size 24 | c: channels_num 25 | l: audio_samples 26 | t: time_steps 27 | q: n_quantizers 28 | 29 | Args: 30 | audio: (b, c, l) 31 | 32 | Outputs: 33 | x: (b, q, t) 34 | """ 35 | 36 | assert audio.shape[1] == 1 37 | 38 | with torch.no_grad(): 39 | self.model.eval() 40 | _, codes, _, _, _ = self.model.encode( 41 | audio_data=audio, 42 | n_quantizers=self.n_quantizers 43 | ) # codes: (b, q, t), integer, codebook indices 44 | 45 | # latent, _, _ = self.model.quantizer.from_codes(codes[:, 0 : self.n_quantizers, :]) 46 | 47 | return codes 48 | 49 | def decode( 50 | self, 51 | codes: LongTensor, 52 | ) -> Tensor: 53 | r"""Decode discrete code to audio. 
54 | 55 | d: latent_dim 56 | 57 | Args: 58 | codes: (b, q, t) 59 | 60 | Returns: 61 | audio: (b, c, l) 62 | """ 63 | 64 | with torch.no_grad(): 65 | self.model.eval() 66 | z, _, _ = self.model.quantizer.from_codes(codes) # (b, d, t) 67 | audio = self.model.decode(z) # (b, c, l) 68 | 69 | return audio 70 | 71 | def code_to_latent(self, codes): 72 | with torch.no_grad(): 73 | self.model.eval() 74 | latent, _, _ = self.model.quantizer.from_codes(codes) # (b, d, t) 75 | 76 | return latent 77 | 78 | def __call__(self, audio: Tensor) -> Tensor: 79 | return self.encode(audio) -------------------------------------------------------------------------------- /audio_flow/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | from einops import rearrange 6 | 7 | from audio_flow.models.rope import RoPE 8 | 9 | 10 | def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: 11 | r"""Modulate input with scale and shift. 12 | 13 | Args: 14 | x: (b, t, d) 15 | shift: (b, t, d) 16 | scale: (b, t, d) 17 | 18 | Outputs: 19 | out: (b, t, d) 20 | """ 21 | return x * (1 + scale) + shift 22 | 23 | 24 | class Block(nn.Module): 25 | r"""Self attention block. 26 | 27 | Ref: 28 | [1] https://github.com/facebookresearch/DiT/blob/main/models.py 29 | [2] https://huggingface.co/hpcai-tech/OpenSora-STDiT-v1-HQ-16x256x256/blob/main/layers.py 30 | """ 31 | def __init__(self, dim, num_heads) -> None: 32 | super().__init__() 33 | 34 | self.norm1 = RMSNorm(dim) 35 | self.norm2 = RMSNorm(dim) 36 | self.norm3 = RMSNorm(dim) 37 | 38 | self.self_attn = SelfAttention(dim, num_heads) 39 | self.cross_attn = CrossAttention(dim, num_heads) 40 | 41 | self.ffn = nn.Sequential( 42 | nn.Linear(dim, dim * 4), 43 | nn.GELU(approximate='tanh'), 44 | nn.Linear(dim * 4, dim) 45 | ) 46 | 47 | self.modulation = nn.Sequential( 48 | nn.SiLU(), 49 | nn.Linear(dim, 6 * dim) 50 | ) 51 | 52 | def forward( 53 | self, 54 | x: Tensor, 55 | e: Tensor, 56 | cx: Tensor, 57 | rope: RoPE, 58 | ) -> torch.Tensor: 59 | r"""Self attention block. 60 | 61 | Args: 62 | x: (b, l, d) 63 | rope: (t, head_dim/2, 2) 64 | mask: None | (1, 1, l, l) 65 | emb: (b, l, d) 66 | 67 | Outputs: 68 | out: (b, l, d) 69 | """ 70 | 71 | e = self.modulation(e).chunk(6, dim=2) 72 | 73 | # Self-attention 74 | h = modulate(self.norm1(x), e[0], e[1]) 75 | x = x + e[2] * self.self_attn(h, rope) 76 | 77 | # Cross-attention 78 | if cx is not None: 79 | x = x + self.cross_attn(self.norm2(x), cx, rope) 80 | 81 | # FFN 82 | h = modulate(self.norm3(x), e[3], e[4]) 83 | x = x + e[5] * self.ffn(h) 84 | 85 | return x 86 | 87 | 88 | class RMSNorm(nn.Module): 89 | r"""Root Mean Square Layer Normalization. 90 | 91 | Ref: https://github.com/meta-llama/llama/blob/main/llama/model.py 92 | """ 93 | def __init__(self, dim: int, eps: float = 1e-6): 94 | 95 | super().__init__() 96 | self.eps = eps 97 | self.scale = nn.Parameter(torch.ones(dim)) 98 | 99 | def forward(self, x): 100 | r"""RMSNorm. 
101 | 102 | Args: 103 | x: (b, t, d) 104 | 105 | Outputs: 106 | x: (b, t, d) 107 | """ 108 | norm_x = torch.mean(x ** 2, dim=-1, keepdim=True) 109 | output = x * torch.rsqrt(norm_x + self.eps) * self.scale 110 | return output 111 | 112 | 113 | class SelfAttention(nn.Module): 114 | def __init__(self, dim, num_heads) -> None: 115 | super().__init__() 116 | 117 | assert dim % num_heads == 0 118 | self.head_dim = dim // num_heads 119 | 120 | self.qkv_linear = nn.Linear(dim, 3 * dim) 121 | self.norm_q = RMSNorm(dim) 122 | self.norm_k = RMSNorm(dim) 123 | 124 | self.proj = nn.Linear(dim, dim) 125 | 126 | def forward( 127 | self, 128 | x: Tensor, 129 | rope: nn.Module, 130 | ) -> Tensor: 131 | r"""Causal self attention. 132 | 133 | b: batch_size 134 | l: seq_len 135 | d: latent_dim 136 | n: n_head 137 | h: head_dim 138 | 139 | Args: 140 | x: (b, l, d) 141 | rope: (l, head_dim/2, 2) 142 | mask: (1, 1) 143 | 144 | Outputs: 145 | x: (b, l, d) 146 | """ 147 | 148 | # Calculate query, key, values 149 | q, k, v = self.qkv_linear(x).chunk(chunks=3, dim=2) # shapes: (b, l, d) 150 | q = rearrange(self.norm_q(q), 'b l (n h) -> b l n h', h=self.head_dim) # (b, l, n, h) 151 | k = rearrange(self.norm_k(k), 'b l (n h) -> b l n h', h=self.head_dim) # (b, l, n, h) 152 | v = rearrange(v, 'b l (n h) -> b l n h', h=self.head_dim) # (b, l, n, h) 153 | 154 | # Apply RoPE 155 | q = rope(q) # (b, l, n, h) 156 | k = rope(k) # (b, l, n, h) 157 | 158 | # Efficient attention using Flash Attention CUDA kernels 159 | x = F.scaled_dot_product_attention( 160 | query=rearrange(q, 'b l n h -> b n l h'), 161 | key=rearrange(k, 'b l n h -> b n l h'), 162 | value=rearrange(v, 'b l n h -> b n l h'), 163 | attn_mask=None, 164 | dropout_p=0.0 165 | ) # (b, n, l, h) 166 | 167 | x = rearrange(x, 'b n l h -> b l (n h)') 168 | x = self.proj(x) # (b, l, d) 169 | 170 | return x 171 | 172 | 173 | class CrossAttention(nn.Module): 174 | def __init__(self, dim, num_heads): 175 | super().__init__() 176 | 177 | assert dim % num_heads == 0 178 | self.head_dim = dim // num_heads 179 | 180 | self.q_linear = nn.Linear(dim, dim) 181 | self.kv_linear = nn.Linear(dim, dim * 2) 182 | self.norm_q = RMSNorm(dim) 183 | self.norm_k = RMSNorm(dim) 184 | 185 | self.proj = nn.Linear(dim, dim) 186 | 187 | def forward( 188 | self, 189 | x: Tensor, 190 | cx: Tensor, 191 | rope: RoPE, 192 | ) -> Tensor: 193 | r"""Causal self attention. 
194 | 195 | b: batch_size 196 | l: seq_len 197 | d: latent_dim 198 | n: heads_num 199 | h: head_dim 200 | k: rope_dim 201 | 202 | Args: 203 | x: (b, l, d) 204 | rope: (l, h/2, 2) 205 | pos: (l, k) 206 | mask: (1, 1, ) 207 | 208 | Outputs: 209 | x: (b, l, d) 210 | """ 211 | B, L, D = x.shape 212 | 213 | # Calculate query, key, values 214 | q = self.q_linear(x) # shapes: (b, lx, d) 215 | k, v = self.kv_linear(cx).chunk(chunks=2, dim=2) # shapes: (b, lc, d) 216 | 217 | q = rearrange(self.norm_q(q), 'b l (n h) -> b l n h', h=self.head_dim) # (b, l, n, h) 218 | k = rearrange(self.norm_k(k), 'b l (n h) -> b l n h', h=self.head_dim) # (b, l, n, h) 219 | v = rearrange(v, 'b l (n h) -> b l n h', h=self.head_dim) # (b, l, n, h) 220 | 221 | # Apply RoPE 222 | q = rope(q) # (b, l, n, h) 223 | k = rope(k) # (b, l, n, h) 224 | 225 | # Efficient attention using Flash Attention CUDA kernels 226 | x = F.scaled_dot_product_attention( 227 | query=rearrange(q, 'b l n h -> b n l h'), 228 | key=rearrange(k, 'b l n h -> b n l h'), 229 | value=rearrange(v, 'b l n h -> b n l h'), 230 | attn_mask=None, 231 | dropout_p=0.0 232 | ) # (b, n, l, h) 233 | 234 | x = rearrange(x, 'b n l h -> b l (n h)') 235 | x = self.proj(x) # (b, l, d) 236 | 237 | return x 238 | 239 | 240 | # class MLP(nn.Module): 241 | # def __init__(self, dim) -> None: 242 | # super().__init__() 243 | 244 | # # The hyper-parameters follow https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py 245 | # hidden_dim = 4 * config.n_embd 246 | 247 | # self.fc1 = nn.Linear(dim, hidden_dim, bias=False) 248 | # self.fc2 = nn.Linear(dim, hidden_dim, bias=False) 249 | # self.proj = nn.Linear(hidden_dim, dim, bias=False) 250 | 251 | # def forward(self, x: Tensor) -> Tensor: 252 | # r"""Causal self attention. 253 | 254 | # Args: 255 | # x: (b, l, d) 256 | 257 | # Outputs: 258 | # x: (b, l, d) 259 | # """ 260 | 261 | # x = F.silu(self.fc1(x)) * self.fc2(x) 262 | # x = self.proj(x) 263 | # return x -------------------------------------------------------------------------------- /audio_flow/models/embedders.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import LongTensor, Tensor 6 | 7 | 8 | class TimestepEmbedder(nn.Module): 9 | r"""Time step embedder. 10 | 11 | References: 12 | [1] https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py 13 | [2] https://huggingface.co/hpcai-tech/OpenSora-STDiT-v1-HQ-16x256x256/blob/main/layers.py 14 | """ 15 | def __init__( 16 | self, 17 | dim: int, 18 | freq_size: int = 256, 19 | scale: float = 1. # Use 100. for flow matching 20 | ): 21 | super().__init__() 22 | 23 | self.freq_size = freq_size 24 | self.scale = scale 25 | 26 | self.mlp = nn.Sequential( 27 | nn.Linear(freq_size, dim, bias=True), 28 | nn.SiLU(), 29 | nn.Linear(dim, dim, bias=True), 30 | ) 31 | 32 | def timestep_embedding(self, t: Tensor, max_period=10000) -> Tensor: 33 | r""" 34 | 35 | Args: 36 | t: (b,), between 0. and 1. 37 | 38 | Outputs: 39 | embedding: (b, d) 40 | """ 41 | 42 | half = self.freq_size // 2 43 | freqs = torch.exp(-math.log(max_period) * torch.arange(half) / half).to(t.device) # (b,) 44 | args = self.scale * t[:, None] * freqs[None, :] # (b, dim/2) 45 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # (b, dim) 46 | 47 | return embedding 48 | 49 | def forward(self, t: Tensor) -> Tensor: 50 | r"""Calculate time embedding. 51 | 52 | Args: 53 | t: (b,), between 0. and 1. 
54 | 55 | Outputs: 56 | out: (b, d) 57 | """ 58 | 59 | t = self.timestep_embedding(t) 60 | t = self.mlp(t) 61 | 62 | return t 63 | 64 | 65 | class LabelEmbedder(nn.Module): 66 | 67 | def __init__(self, classes_num: int, out_channels: int): 68 | super().__init__() 69 | 70 | self.mlp = nn.Sequential( 71 | nn.Embedding(classes_num, out_channels), 72 | nn.SiLU(), 73 | nn.Linear(out_channels, out_channels, bias=True), 74 | ) 75 | 76 | def forward(self, x: LongTensor) -> Tensor: 77 | r"""Calculate label embedding. 78 | 79 | Args: 80 | x: (b,), LongTensor 81 | 82 | Outputs: 83 | out: (b, d) 84 | """ 85 | 86 | return self.mlp(x) 87 | 88 | 89 | class MlpEmbedder(nn.Module): 90 | 91 | def __init__(self, in_channels: int, out_channels: int): 92 | super().__init__() 93 | 94 | self.mlp = nn.Sequential( 95 | nn.Linear(in_channels, out_channels), 96 | nn.SiLU(), 97 | nn.Linear(out_channels, out_channels, bias=True), 98 | ) 99 | 100 | def forward(self, x: Tensor) -> Tensor: 101 | r"""Calculate MLP embedding. 102 | 103 | Args: 104 | x: (b, d, ...) 105 | 106 | Outputs: 107 | out: (b, d, ...) 108 | """ 109 | 110 | x = x.transpose(1, -1) # (b, ..., d) 111 | x = self.mlp(x) # (b, ..., d) 112 | x = x.transpose(1, -1) # (b, d, ...) 113 | return x -------------------------------------------------------------------------------- /audio_flow/models/pad.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | 9 | def pad1d(x: Tensor, patch_size: int) -> Tensor: 10 | r"""Pad a tensor along the last two dims. 11 | 12 | Args: 13 | x: (b, c, t) 14 | patch_size: tuple 15 | 16 | Outpus: 17 | out: (b, c, t) 18 | """ 19 | 20 | T = x.shape[2] 21 | t = patch_size 22 | 23 | pad_t = math.ceil(T / t) * t - T 24 | x = F.pad(x, pad=(0, pad_t)) 25 | 26 | return x 27 | -------------------------------------------------------------------------------- /audio_flow/models/rope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from torch import LongTensor, Tensor 5 | 6 | 7 | class RoPE(nn.Module): 8 | def __init__(self, head_dim: int, max_len: int = 8192, base: int = 10000): 9 | r"""Rotary position embedding. 10 | 11 | [1] Su, Jianlin, et al. "Roformer: Enhanced transformer with rotary 12 | position embedding." Neurocomputing, 2024 13 | 14 | h: head_dim 15 | l: seq_len 16 | """ 17 | super().__init__() 18 | 19 | self.head_dim = head_dim 20 | 21 | # Calculate θ = 1 / 10000**(2i/h) 22 | theta = 1.0 / (base ** (torch.arange(0, head_dim, 2) / head_dim)) # (h/2,) 23 | 24 | # Matrix pθ 25 | pos_theta = torch.outer(torch.arange(max_len), theta).float() # (l, h/2) 26 | 27 | # Rotation matrix 28 | w = torch.stack([torch.cos(pos_theta), torch.sin(pos_theta)], dim=-1) # (l, h/2, 2) 29 | self.register_buffer(name="w", tensor=w) 30 | 31 | def forward(self, x: Tensor) -> Tensor: 32 | r"""Apply RoPE. 
33 | 34 | b: batch_size 35 | l: seq_len 36 | n: heads_num 37 | h: head_dim 38 | 39 | Args: 40 | x: (b, l, n, h) 41 | 42 | Outputs: 43 | out: (b, l, n, h) 44 | """ 45 | L = x.shape[1] 46 | x = rearrange(x, 'b l n (h c) -> b l n h c', c=2) # (b, l, n, h/2, 2) 47 | w = self.w[0 : L][None, :, None, :, :] # (1, l, 1, h/2, 2) 48 | x = self.rotate(x, w) # (b, l, n, h/2, 2) 49 | x = rearrange(x, 'b l n h c -> b l n (h c)') # (b, l, n, h) 50 | 51 | return x 52 | 53 | def rotate(self, x: Tensor, w: Tensor) -> Tensor: 54 | r"""Rotate x. 55 | 56 | x0 = cos(θp)·x0 - sin(θp)·x1 57 | x1 = sin(θp)·x0 + cos(θp)·x1 58 | 59 | b: batch_size 60 | l: seq_len 61 | n: heads_num 62 | h: head_dim 63 | 64 | Args: 65 | x: (b, l, n, h/2, 2) 66 | w: (1, l, 1, h/2, 2) 67 | 68 | Outputs: 69 | out: (b, l, n, h/2, 2) 70 | """ 71 | 72 | out = torch.stack([ 73 | w[..., 0] * x[..., 0] - w[..., 1] * x[..., 1], 74 | w[..., 0] * x[..., 1] + w[..., 1] * x[..., 0] 75 | ], 76 | dim=-1, 77 | ) # (b, l, n, h/2, 2) 78 | 79 | return out 80 | 81 | def apply_nd(self, x: Tensor, pos: LongTensor) -> Tensor: 82 | r"""Apply Nd RoPE with sparse positions. 83 | 84 | b: batch_size 85 | l: seq_len 86 | n: heads_num 87 | h: head_dim 88 | k: data dim 89 | 90 | Args: 91 | x: (b, l, n, h) 92 | pos: (l, k) 93 | n_dim: int 94 | 95 | Outputs: 96 | out: (b, l, n, h) 97 | """ 98 | 99 | B, L, N, H = x.shape 100 | K = pos.shape[1] # rope_dim 101 | assert H == K * self.head_dim 102 | 103 | x = rearrange(x, 'b l n (k h c) -> k b l n h c', k=K, c=2) # (k, b, l, n, h/2, 2) 104 | x = x.contiguous() 105 | 106 | for i in range(K): 107 | p = pos[:, i] # (l,) 108 | w = self.w[p][None, :, None, :, :] # (1, l, 1, h/2, 2) 109 | x[i] = self.rotate(x[i], w).clone() # x: (k, b, l, n, h/2, 2) 110 | 111 | out = rearrange(x, 'k b l n h c -> b l n (k h c)') # (b, l, n, h) 112 | 113 | return out 114 | 115 | 116 | if __name__ == '__main__': 117 | 118 | torch.manual_seed(1234) 119 | 120 | B = 4 # batch_size 121 | L = 100 # time_steps 122 | N = 8 # n_head 123 | H = 24 # head_dim 124 | n_embd = N * H 125 | 126 | print("Example 1: RoPE (1D)") 127 | rope = RoPE(head_dim=H) 128 | x = torch.rand((B, L, N, H)) # (b, l, n, h) 129 | out = rope(x) # (b, l, n, h) 130 | print(out.shape) 131 | 132 | print("Example 2: RoPE (1D) with sparse positions") 133 | rope = RoPE(head_dim=H) 134 | x = torch.rand((B, 4, N, H)) # (b, l, n, h) 135 | pos = torch.LongTensor([[0], [3], [7], [8]]) # (l, 1) 136 | out = rope.apply_nd(x, pos) # (b, l, n, h) 137 | print(out.shape) 138 | 139 | print("Example 3: RoPE (2D image) with sparse positions") 140 | data_dim = 2 141 | rope = RoPE(head_dim=H // data_dim) 142 | x = torch.rand((B, 4, N, H)) 143 | pos = torch.LongTensor([[0, 0], [0, 1], [1, 0], [1, 1]]) 144 | out = rope.apply_nd(x, pos) # (b, l, n, h) 145 | print(out.shape) 146 | 147 | print("Example 4: RoPE (3D video) with sparse positions") 148 | data_dim = 3 149 | rope = RoPE(head_dim=H // data_dim) 150 | x = torch.rand((B, 4, N, H)) 151 | pos = torch.LongTensor([[0, 0, 0], [1, 3, 4], [2, 2, 2], [5, 4, 3]]) 152 | out3 = rope.apply_nd(x, pos) # (b, l, n, h) 153 | print(out3.shape) 154 | 155 | # Visualization of RoPE weights 156 | import matplotlib.pyplot as plt 157 | fig, axs = plt.subplots(2, 1, sharex=True) 158 | axs[0].matshow(rope.w[:, :, 0].data.cpu().numpy().T, origin='lower', aspect='auto', cmap='jet') 159 | axs[1].matshow(rope.w[:, :, 1].data.cpu().numpy().T, origin='lower', aspect='auto', cmap='jet') 160 | plt.savefig("rope.pdf") 161 | print("Write out to rope.pdf") 
-------------------------------------------------------------------------------- /audio_flow/models/transformer1d.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | from torch import Tensor 8 | 9 | from audio_flow.models.attention import Block 10 | from audio_flow.models.embedders import LabelEmbedder, MlpEmbedder, TimestepEmbedder 11 | from audio_flow.models.rope import RoPE 12 | from audio_flow.models.pad import pad1d 13 | 14 | 15 | class Transformer1D(nn.Module): 16 | def __init__( 17 | self, 18 | in_dim=16, 19 | patch_size=1, 20 | dim=384, 21 | mlp_ratio=4.0, 22 | num_layers=12, 23 | num_heads=12, 24 | rope_len=8192, 25 | **kwargs 26 | ): 27 | 28 | super().__init__() 29 | 30 | self.patch_size = patch_size 31 | 32 | self.patch_x = nn.Conv1d(in_dim, dim, kernel_size=patch_size, stride=patch_size) 33 | self.unpatch_x = nn.ConvTranspose1d(dim, in_dim, kernel_size=patch_size, stride=patch_size) 34 | 35 | # Time embedder 36 | self.t_embedder = TimestepEmbedder(dim=dim, freq_size=256, scale=100.) 37 | 38 | self.blocks = nn.ModuleList(Block(dim, num_heads) for _ in range(num_layers)) 39 | 40 | head_dim = dim // num_heads 41 | self.rope = RoPE(head_dim, max_len=rope_len) 42 | 43 | def forward( 44 | self, 45 | t: Tensor, 46 | x: Tensor, 47 | emb_dict: dict 48 | ) -> Tensor: 49 | """Model 50 | 51 | Args: 52 | t: (b,), random time steps between 0. and 1. 53 | x: (b, d, t) 54 | cond_dict: dict 55 | 56 | Outputs: 57 | output: (b, d, t) 58 | """ 59 | 60 | assert all(key in ["c", "ct", "cx"] for key in emb_dict.keys()), "Invalid key in emb_dict!" 61 | 62 | c = emb_dict.get("c", None) 63 | ct = emb_dict.get("ct", None) 64 | cx = emb_dict.get("cx", None) 65 | 66 | B, D, T = x.shape 67 | x = pad1d(x, self.patch_size) # x: (b, d, t) 68 | x = self.patch_x(x) # shape: (b, d, t, f) 69 | 70 | e = torch.zeros_like(x) 71 | 72 | # 2.1 Time embedder. Repeat B times for inference 73 | if t.dim() == 0: 74 | t = t.repeat(B) 75 | 76 | e += self.t_embedder(t)[:, :, None] 77 | 78 | if c is not None: 79 | e += c[:, :, None] 80 | 81 | if ct is not None: 82 | e += ct[:, :, :] 83 | 84 | if cx is not None: 85 | cx = rearrange(cx, 'b d t -> b t d') 86 | 87 | x = rearrange(x, 'b d t -> b t d') 88 | e = rearrange(e, 'b d t -> b t d') 89 | 90 | for block in self.blocks: 91 | x = block(x, e, cx, self.rope) 92 | 93 | x = rearrange(x, 'b t d -> b d t') 94 | 95 | x = self.unpatch_x(x) 96 | x = x[:, :, 0 : T] 97 | 98 | return x -------------------------------------------------------------------------------- /audio_flow/samplers/sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Sized 3 | 4 | 5 | class RepeatShuffleSampler: 6 | def __init__(self, dataset: Sized) -> None: 7 | r"""Randomly sample indices of a dataset without replacement. Execute 8 | this process infinitely. 9 | """ 10 | 11 | self.indices = list(range(len(dataset))) 12 | random.shuffle(self.indices) # self.indices: [3, 7, 0, ...] 
13 | self.p = 0 # pointer 14 | 15 | def __iter__(self) -> int: 16 | r"""Yield an index.""" 17 | 18 | while True: 19 | 20 | if self.p == len(self.indices): 21 | random.shuffle(self.indices) 22 | self.p = 0 23 | 24 | index = self.indices[self.p] 25 | self.p += 1 26 | 27 | yield index -------------------------------------------------------------------------------- /audio_flow/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import sys 5 | from collections import OrderedDict 6 | from contextlib import contextmanager 7 | 8 | import librosa 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import yaml 14 | from torch import Tensor 15 | from einops import rearrange 16 | 17 | 18 | def parse_yaml(config_yaml: str) -> dict: 19 | r"""Parse yaml file.""" 20 | 21 | with open(config_yaml, "r") as fr: 22 | return yaml.load(fr, Loader=yaml.FullLoader) 23 | 24 | 25 | class LinearWarmUp: 26 | r"""Linear learning rate warm up scheduler.""" 27 | 28 | def __init__(self, warm_up_steps: int) -> None: 29 | self.warm_up_steps = warm_up_steps 30 | 31 | def __call__(self, step: int) -> float: 32 | if step <= self.warm_up_steps: 33 | return step / self.warm_up_steps 34 | else: 35 | return 1. 36 | 37 | 38 | @torch.no_grad() 39 | def update_ema(ema: nn.Module, model: nn.Module, decay: float = 0.999) -> None: 40 | """Update EMA model weights and buffers from model.""" 41 | 42 | # Parameters 43 | for e, m in zip(ema.parameters(), model.parameters()): 44 | e.mul_(decay).add_(m.data.float(), alpha=1 - decay) 45 | 46 | # Buffers (BN running stats, etc) 47 | for e, m in zip(ema.buffers(), model.buffers()): 48 | if m.dtype in [torch.bool, torch.long]: 49 | continue 50 | e.mul_(decay).add_(m.data.float(), alpha=1 - decay) 51 | 52 | 53 | def requires_grad(model: nn.Module, flag=True) -> None: 54 | for p in model.parameters(): 55 | p.requires_grad = flag 56 | 57 | 58 | class CombinedModel(nn.Module): 59 | def __init__(self, base: nn.Module, adaptor: nn.Module) -> None: 60 | super().__init__() 61 | self.base = base 62 | self.adaptor = adaptor 63 | 64 | 65 | def forward_in_chunks(model: nn.Module, audio: np.array, clip_samples: int) -> np.array: 66 | 67 | device = next(model.parameters()).device 68 | latents = [] 69 | i = 0 70 | skip_samples = 10000 71 | 72 | while i < audio.shape[-1]: 73 | 74 | if audio.shape[-1] - i < skip_samples: 75 | break 76 | 77 | x = Tensor(audio[None, :, i : i + clip_samples]).to(device) 78 | 79 | with torch.no_grad(): 80 | model.eval() 81 | latent = model(x)[0].data.cpu().numpy() # (d, t) 82 | 83 | latents.append(latent) 84 | i += clip_samples 85 | 86 | latents = np.concatenate(latents, axis=-1) 87 | 88 | return latents 89 | 90 | 91 | def align_temporal_features( 92 | input: Tensor, 93 | target: Tensor, 94 | input_fps: float, 95 | target_fps: float 96 | ) -> Tensor: 97 | r"""Align or stack the input with the target along the temporal axis. 
98 | 99 | Args: 100 | input: (any, t1) 101 | target: (any, t2) 102 | input_fps: float 103 | target_fps: float 104 | 105 | Outputs: 106 | output: (any, t2) if t1 ≤ t2 107 | (any, t2*w) if t1 > t2 108 | """ 109 | 110 | if input_fps == target_fps: 111 | return input 112 | 113 | elif input_fps > target_fps: 114 | 115 | ratio = input_fps / target_fps 116 | width = round(ratio / 2) 117 | T = target.shape[-1] 118 | 119 | indices = torch.round(torch.arange(0, T) * ratio) # (t,) 120 | indices = indices[:, None] + torch.arange(-width, width + 1) # (t, w) 121 | indices = torch.clamp(indices, 0, T).long() # (t, w) 122 | 123 | indices = rearrange(indices, 't w -> (t w)') # (t*w) 124 | output = rearrange(input[..., indices], 'b d (t w) -> b (w d) t', t=T) # (b, w*d, t) 125 | return output 126 | 127 | else: 128 | ratio = input_fps / target_fps 129 | T = target.shape[-1] 130 | indices = (torch.arange(0, T) * ratio).long() # (t,) 131 | output = input[..., indices] # (b, d, t) 132 | return output 133 | 134 | 135 | def logmel(audio: np.ndarray, sr: float) -> np.ndarray: 136 | 137 | if audio.ndim == 2: 138 | audio = np.mean(audio, axis=0) 139 | 140 | return np.log10(librosa.feature.melspectrogram( 141 | y=audio, 142 | sr=sr, 143 | n_fft=2048, 144 | hop_length=round(sr * 0.01), 145 | n_mels=128 146 | )).T # (t, f) 147 | 148 | 149 | ''' 150 | def fix_length(x: Tensor, size: int) -> Tensor: 151 | 152 | if x.shape[-1] >= size: 153 | return x[:, :, 0 : size] 154 | else: 155 | pad_t = size - x.shape[-1] 156 | return F.pad(input=x, pad=(0, pad_t)) 157 | 158 | 159 | @contextmanager 160 | def suppress_print(): 161 | original_stdout = sys.stdout 162 | sys.stdout = open(os.devnull, 'w') 163 | try: 164 | yield 165 | finally: 166 | sys.stdout.close() 167 | sys.stdout = original_stdout 168 | 169 | 170 | class Logmel: 171 | def __init__(self, sr: float): 172 | self.sr = sr 173 | self.n_fft = 2048 174 | self.hop_length = round(sr * 0.01) 175 | self.n_mels = 128 176 | 177 | def __call__(self, audio: np.array) -> np.array: 178 | 179 | logmel = np.log10(librosa.feature.melspectrogram( 180 | y=audio, 181 | sr=self.sr, 182 | n_fft=self.n_fft, 183 | hop_length=self.hop_length, 184 | n_mels=self.n_mels 185 | )).T 186 | 187 | return logmel 188 | ''' 189 | -------------------------------------------------------------------------------- /audio_flow/vae/levo.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch.nn as nn 3 | import torch 4 | from torch import Tensor 5 | from huggingface_hub import hf_hub_download 6 | from stable_audio_tools.models.factory import create_model_from_config 7 | from stable_audio_tools.models.autoencoders import AudioAutoencoder 8 | 9 | 10 | class LevoVAE(nn.Module): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | 15 | config_path = hf_hub_download( 16 | repo_id="tencent/SongGeneration", 17 | filename="ckpt/vae/stable_audio_1920_vae.json" 18 | ) 19 | 20 | model_path = hf_hub_download( 21 | repo_id="tencent/SongGeneration", 22 | filename="ckpt/vae/autoencoder_music_1320k.ckpt" 23 | ) 24 | 25 | with open(config_path, "r") as f: 26 | model_config = json.load(f) 27 | 28 | self.vae = create_model_from_config(model_config) 29 | state_dict = torch.load(model_path, map_location="cpu")["state_dict"] 30 | self.vae.load_state_dict(state_dict) 31 | 32 | self.sr = model_config["sample_rate"] 33 | self.fps = 25 34 | 35 | def encode(self, audio: Tensor) -> Tensor: 36 | r""" 37 | 38 | Args: 39 | audio: (b, 2, l) 40 | """ 41 | 42 | with 
torch.no_grad(): 43 | self.vae.eval() 44 | latent = self.vae.encode_audio(audio) 45 | 46 | return latent 47 | 48 | def decode(self, latent): 49 | with torch.no_grad(): 50 | self.vae.eval() 51 | audio = self.vae.decode_audio(latent) 52 | 53 | return audio 54 | 55 | def __call__(self, audio: Tensor) -> Tensor: 56 | return self.encode(audio) -------------------------------------------------------------------------------- /compute_latents/gtzan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | from pathlib import Path 6 | 7 | import h5py 8 | import librosa 9 | import numpy as np 10 | 11 | from audio_flow.utils import forward_in_chunks 12 | from audio_flow.vae.levo import LevoVAE 13 | 14 | 15 | def compute_vae(args) -> None: 16 | 17 | root = args.dataset_root 18 | out_dir = args.out_dir 19 | aug_repeats = args.augmentation_repeats 20 | 21 | device = "cuda" 22 | clip_duration = 30. 23 | 24 | # Load VAE 25 | vae = LevoVAE().to(device) 26 | 27 | clip_samples = int(clip_duration * vae.sr) 28 | 29 | # Compuate VAE latent 30 | labels = sorted(os.listdir(Path(root, "genres"))) 31 | 32 | for k, label in enumerate(labels): 33 | 34 | print("{}/{}".format(k, len(labels))) 35 | 36 | paths = sorted(list(Path(root, "genres", label).glob("*.au"))) 37 | 38 | for path in paths: 39 | 40 | audio, fs = librosa.load(path=path, sr=vae.sr, mono=False) # (l,) 41 | audio = np.repeat(audio[None, :], repeats=2, axis=0) # (2, l) 42 | 43 | for i in range(aug_repeats): 44 | 45 | jitter = int((i / aug_repeats) * (vae.sr / vae.fps)) 46 | aug_audio = audio[:, jitter :] # (2, l) 47 | aug_audio = librosa.util.fix_length(data=aug_audio, size=clip_samples, axis=-1) 48 | 49 | t1 = time.time() 50 | latents = forward_in_chunks(vae, aug_audio, clip_samples) # (d, t) 51 | t = time.time() - t1 52 | 53 | out_path = Path(out_dir, "genres", label, "{}_{:03d}_vae.h5".format(Path(path).stem, i)) 54 | out_path.parent.mkdir(parents=True, exist_ok=True) 55 | 56 | with h5py.File(out_path, 'w') as hf: 57 | hf.create_dataset("latent", data=latents, dtype=np.float32) 58 | hf.attrs.create("fps", data=vae.fps, dtype=float) 59 | 60 | print(f"Write out to {out_path} time: {t:.2f} s") 61 | 62 | 63 | if __name__ == '__main__': 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--dataset_root", type=str, required=True, help="Path of config yaml.") 67 | parser.add_argument("--out_dir", type=str) 68 | parser.add_argument("--augmentation_repeats", type=int) 69 | args = parser.parse_args() 70 | 71 | compute_vae(args) -------------------------------------------------------------------------------- /compute_latents/ljspeech.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | from pathlib import Path 6 | 7 | import h5py 8 | import librosa 9 | import numpy as np 10 | 11 | from audio_flow.utils import forward_in_chunks 12 | from audio_flow.vae.levo import LevoVAE 13 | 14 | 15 | def compute_vae(args) -> None: 16 | 17 | root = args.dataset_root 18 | out_dir = args.out_dir 19 | aug_repeats = args.augmentation_repeats 20 | 21 | device = "cuda" 22 | clip_duration = 10. 
23 | 24 | # Load VAE 25 | vae = LevoVAE().to(device) 26 | 27 | clip_samples = int(clip_duration * vae.sr) 28 | 29 | # Compuate VAE latent 30 | paths = sorted(list(Path(root, "wavs").glob("*.wav"))) 31 | 32 | for i, path in enumerate(paths): 33 | 34 | print("{}/{}".format(i, len(paths))) 35 | 36 | audio, fs = librosa.load(path=path, sr=vae.sr, mono=False) 37 | audio = np.repeat(audio[None, :], repeats=2, axis=0) # (2, l) 38 | 39 | for i in range(aug_repeats): 40 | 41 | jitter = int((i / aug_repeats) * (vae.sr / vae.fps)) 42 | aug_audio = audio[:, jitter :] # (2, l) 43 | aug_audio = librosa.util.fix_length(data=aug_audio, size=clip_samples, axis=-1) 44 | 45 | t1 = time.time() 46 | latents = forward_in_chunks(vae, aug_audio, clip_samples) 47 | t = time.time() - t1 48 | 49 | out_path = Path(out_dir, "wavs", "{}_{:03d}_vae.h5".format(Path(path).stem, i)) 50 | out_path.parent.mkdir(parents=True, exist_ok=True) 51 | 52 | with h5py.File(out_path, 'w') as hf: 53 | hf.create_dataset("latent", data=latents, dtype=np.float32) 54 | hf.attrs.create("fps", data=vae.fps, dtype=float) 55 | 56 | print(f"Write out to {out_path} time: {t:.2f} s") 57 | 58 | 59 | 60 | if __name__ == '__main__': 61 | 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("--dataset_root", type=str, required=True, help="Path of config yaml.") 64 | parser.add_argument("--out_dir", type=str) 65 | parser.add_argument("--augmentation_repeats", type=int) 66 | args = parser.parse_args() 67 | 68 | compute_vae(args) -------------------------------------------------------------------------------- /compute_latents/maestro.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | from pathlib import Path 6 | 7 | import h5py 8 | import librosa 9 | import numpy as np 10 | 11 | from audio_flow.utils import forward_in_chunks 12 | from audio_flow.vae.levo import LevoVAE 13 | 14 | 15 | def compute_vae(args) -> None: 16 | 17 | root = args.dataset_root 18 | out_dir = args.out_dir 19 | aug_repeats = args.augmentation_repeats 20 | 21 | device = "cuda" 22 | clip_duration = 30. 
23 | 24 | # Load VAE 25 | vae = LevoVAE().to(device) 26 | 27 | clip_samples = int(clip_duration * vae.sr) 28 | 29 | # Compuate VAE latent 30 | labels = sorted(os.listdir(Path(root, "genres"))) 31 | 32 | for k, label in enumerate(labels): 33 | 34 | print("{}/{}".format(k, len(labels))) 35 | 36 | paths = sorted(list(Path(root, "genres", label).glob("*.au"))) 37 | 38 | for path in paths: 39 | 40 | audio, fs = librosa.load(path=path, sr=vae.sr, mono=False) # (l,) 41 | audio = np.repeat(audio[None, :], repeats=2, axis=0) # (2, l) 42 | 43 | for i in range(aug_repeats): 44 | 45 | jitter = int((i / aug_repeats) * (vae.sr / vae.fps)) 46 | aug_audio = audio[:, jitter :] # (2, l) 47 | aug_audio = librosa.util.fix_length(data=aug_audio, size=clip_samples, axis=-1) 48 | 49 | t1 = time.time() 50 | latents = forward_in_chunks(vae, aug_audio, clip_samples) # (d, t) 51 | t = time.time() - t1 52 | 53 | out_path = Path(out_dir, "genres", label, "{}_{:03d}_vae.h5".format(Path(path).stem, i)) 54 | out_path.parent.mkdir(parents=True, exist_ok=True) 55 | 56 | with h5py.File(out_path, 'w') as hf: 57 | hf.create_dataset("latent", data=latents, dtype=np.float32) 58 | hf.attrs.create("fps", data=vae.fps, dtype=float) 59 | 60 | print(f"Write out to {out_path} time: {t:.2f} s") 61 | 62 | 63 | if __name__ == '__main__': 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--dataset_root", type=str, required=True, help="Path of config yaml.") 67 | parser.add_argument("--out_dir", type=str) 68 | parser.add_argument("--augmentation_repeats", type=int) 69 | args = parser.parse_args() 70 | 71 | compute_vae(args) -------------------------------------------------------------------------------- /compute_latents/musdb18hq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | from pathlib import Path 6 | 7 | import h5py 8 | import librosa 9 | import torch 10 | import numpy as np 11 | 12 | from audio_flow.utils import forward_in_chunks 13 | from audio_flow.vae.levo import LevoVAE 14 | from audio_flow.encoders.dac import DAC 15 | 16 | 17 | def compute_stems_vae(args) -> None: 18 | 19 | root = args.dataset_root 20 | out_dir = args.out_dir 21 | aug_repeats = args.augmentation_repeats 22 | 23 | device = "cuda" 24 | clip_duration = 60. 
25 | 26 | # Load VAE 27 | vae = LevoVAE().to(device) 28 | 29 | clip_samples = int(clip_duration * vae.sr) 30 | 31 | # Compuate VAE latent 32 | stems = ["vocals", "bass", "drums", "other", "mixture"] 33 | 34 | for split in ["train", "test"]: 35 | 36 | audio_names = sorted(os.listdir(Path(root, split))) 37 | 38 | for i, name in enumerate(audio_names): 39 | 40 | print("{}/{}, {}".format(i, len(audio_names), name)) 41 | 42 | for stem in stems: 43 | 44 | path = Path(root, split, name, f"{stem}.wav") 45 | audio, fs = librosa.load(path=path, sr=vae.sr, mono=False) # (2, l) 46 | 47 | for i in range(aug_repeats): 48 | 49 | jitter = int((i / aug_repeats) * (vae.sr / vae.fps)) 50 | stereo = audio[:, jitter :] # (2, l) 51 | 52 | t1 = time.time() 53 | latents = forward_in_chunks(vae, stereo, clip_samples) # (d, t) 54 | t = time.time() - t1 55 | 56 | out_path = Path(out_dir, split, name, f"{stem}_{i:03d}_vae.h5") 57 | out_path.parent.mkdir(parents=True, exist_ok=True) 58 | 59 | with h5py.File(out_path, 'w') as hf: 60 | hf.create_dataset("latent", data=latents, dtype=np.float32) 61 | hf.attrs.create("fps", data=vae.fps, dtype=float) 62 | 63 | print(f"Write out to {out_path} time: {t:.2f} s") 64 | 65 | 66 | def compute_mono_stereo_vae(args) -> None: 67 | 68 | root = args.dataset_root 69 | out_dir = args.out_dir 70 | aug_repeats = args.augmentation_repeats 71 | 72 | device = "cuda" 73 | clip_duration = 60. 74 | 75 | # Load VAE 76 | vae = LevoVAE().to(device) 77 | 78 | clip_samples = int(clip_duration * vae.sr) 79 | 80 | # Compuate VAE latent 81 | stems = ["mixture"] 82 | 83 | for split in ["train", "test"]: 84 | 85 | audio_names = sorted(os.listdir(Path(root, split))) 86 | 87 | for i, name in enumerate(audio_names): 88 | 89 | print("{}/{}, {}".format(i, len(audio_names), name)) 90 | 91 | for stem in stems: 92 | 93 | path = Path(root, split, name, f"{stem}.wav") 94 | audio, fs = librosa.load(path=path, sr=vae.sr, mono=False) 95 | 96 | for i in range(aug_repeats): 97 | 98 | jitter = int((i / aug_repeats) * (vae.sr / vae.fps)) 99 | stereo = audio[:, jitter :] # (2, l) 100 | mono = stereo.mean(axis=0, keepdims=True).repeat(repeats=2, axis=0) 101 | 102 | t1 = time.time() 103 | stereo_latents = forward_in_chunks(vae, stereo, clip_samples) 104 | mono_latents = forward_in_chunks(vae, mono, clip_samples) 105 | t = time.time() - t1 106 | 107 | sub_dir = Path(out_dir, split, name) 108 | sub_dir.mkdir(parents=True, exist_ok=True) 109 | 110 | stereo_out_path = Path(sub_dir, f"{stem}_{i:03d}_stereo_vae.h5") 111 | mono_out_path = Path(sub_dir, f"{stem}_{i:03d}_mono_vae.h5") 112 | 113 | with h5py.File(stereo_out_path, 'w') as hf: 114 | hf.create_dataset("latent", data=stereo_latents, dtype=np.float32) 115 | hf.attrs.create("fps", data=vae.fps, dtype=float) 116 | 117 | with h5py.File(mono_out_path, 'w') as hf: 118 | hf.create_dataset("latent", data=mono_latents, dtype=np.float32) 119 | hf.attrs.create("fps", data=vae.fps, dtype=float) 120 | 121 | print(f"Write out to {stereo_out_path}") 122 | print(f"Write out to {mono_out_path}") 123 | print(f"time: {t:.2f} s") 124 | 125 | 126 | def compute_dac_stereo_vae(args) -> None: 127 | 128 | import dac 129 | 130 | root = args.dataset_root 131 | out_dir = args.out_dir 132 | aug_repeats = args.augmentation_repeats 133 | 134 | device = "cuda" 135 | clip_duration = 60. 
136 | n_quantizers = 2 137 | 138 | # Load VAE 139 | vae = LevoVAE().to(device) 140 | dac = DAC(n_quantizers=n_quantizers).to(device) 141 | 142 | clip_samples = int(clip_duration * vae.sr) 143 | 144 | # Compuate VAE latent 145 | stems = ["mixture"] 146 | 147 | for split in ["train", "test"]: 148 | 149 | audio_names = sorted(os.listdir(Path(root, split))) 150 | 151 | for i, name in enumerate(audio_names): 152 | 153 | print("{}/{}, {}".format(i, len(audio_names), name)) 154 | 155 | for stem in stems: 156 | 157 | path = Path(root, split, name, f"{stem}.wav") 158 | audio, fs = librosa.load(path=path, sr=vae.sr, mono=False) 159 | 160 | for i in range(aug_repeats): 161 | 162 | jitter = int((i / aug_repeats) * (vae.sr / vae.fps)) 163 | stereo = audio[:, jitter :] # (2, l) 164 | 165 | dac_audio = librosa.resample(y=stereo, orig_sr=vae.sr, target_sr=dac.sr) 166 | dac_audio = dac_audio.mean(axis=0, keepdims=True) 167 | 168 | t1 = time.time() 169 | vae_latents = forward_in_chunks(vae, stereo, clip_samples) 170 | dac_codes = forward_in_chunks(dac, dac_audio, clip_samples) 171 | t = time.time() - t1 172 | 173 | sub_dir = Path(out_dir, split, name) 174 | sub_dir.mkdir(parents=True, exist_ok=True) 175 | 176 | vae_out_path = Path(sub_dir, f"{stem}_{i:03d}_vae.h5") 177 | dac_out_path = Path(sub_dir, f"{stem}_{i:03d}_dac.h5") 178 | 179 | with h5py.File(vae_out_path, 'w') as hf: 180 | hf.create_dataset("latent", data=vae_latents, dtype=np.float32) 181 | hf.attrs.create("fps", data=vae.fps, dtype=float) 182 | 183 | with h5py.File(dac_out_path, 'w') as hf: 184 | hf.create_dataset("code", data=dac_codes, dtype=int) 185 | hf.attrs.create("fps", data=dac.fps, dtype=float) 186 | 187 | print(f"Write out to {vae_out_path}") 188 | print(f"Write out to {dac_out_path}") 189 | print(f"time: {t:.2f} s") 190 | 191 | 192 | def compute_8khz_44khz_vae(args) -> None: 193 | 194 | root = args.dataset_root 195 | out_dir = args.out_dir 196 | aug_repeats = args.augmentation_repeats 197 | 198 | device = "cuda" 199 | clip_duration = 60. 200 | lowres_sr = 8000. 
201 | 202 | # Load VAE 203 | vae = LevoVAE().to(device) 204 | 205 | clip_samples = int(clip_duration * vae.sr) 206 | 207 | # Compuate VAE latent 208 | stems = ["mixture"] 209 | 210 | for split in ["train", "test"]: 211 | 212 | audio_names = sorted(os.listdir(Path(root, split))) 213 | 214 | for i, name in enumerate(audio_names): 215 | 216 | print("{}/{}, {}".format(i, len(audio_names), name)) 217 | 218 | for stem in stems: 219 | 220 | path = Path(root, split, name, f"{stem}.wav") 221 | audio, fs = librosa.load(path=path, sr=vae.sr, mono=False) 222 | 223 | for i in range(aug_repeats): 224 | 225 | jitter = int((i / aug_repeats) * (vae.sr / vae.fps)) 226 | stereo = audio[:, jitter :] # (2, l) 227 | 228 | lowres_audio = librosa.resample(y=stereo, orig_sr=vae.sr, target_sr=lowres_sr) 229 | lowres_audio = librosa.resample(y=lowres_audio, orig_sr=lowres_sr, target_sr=vae.sr) 230 | 231 | t1 = time.time() 232 | stereo_latents = forward_in_chunks(vae, stereo, clip_samples) 233 | lowres_latents = forward_in_chunks(vae, lowres_audio, clip_samples) 234 | t = time.time() - t1 235 | 236 | sub_dir = Path(out_dir, split, name) 237 | sub_dir.mkdir(parents=True, exist_ok=True) 238 | 239 | highres_out_path = Path(sub_dir, f"{stem}_{i:03d}_highres_vae.h5") 240 | lowres_out_path = Path(sub_dir, f"{stem}_{i:03d}_lowres_vae.h5") 241 | 242 | with h5py.File(highres_out_path, 'w') as hf: 243 | hf.create_dataset("latent", data=stereo_latents, dtype=np.float32) 244 | hf.attrs.create("fps", data=vae.fps, dtype=float) 245 | 246 | with h5py.File(lowres_out_path, 'w') as hf: 247 | hf.create_dataset("latent", data=lowres_latents, dtype=np.float32) 248 | hf.attrs.create("fps", data=vae.fps, dtype=float) 249 | 250 | print(f"Write out to {highres_out_path}") 251 | print(f"Write out to {lowres_out_path}") 252 | print(f"time: {t:.2f} s") 253 | 254 | 255 | if __name__ == '__main__': 256 | 257 | parser = argparse.ArgumentParser() 258 | subparsers = parser.add_subparsers(dest="mode") 259 | 260 | parser_stems = subparsers.add_parser("stems") 261 | parser_stems.add_argument("--dataset_root", type=str, required=True, help="Path of config yaml.") 262 | parser_stems.add_argument("--out_dir", type=str) 263 | parser_stems.add_argument("--augmentation_repeats", type=int) 264 | 265 | parser_mono_stereo = subparsers.add_parser("mono_stereo") 266 | parser_mono_stereo.add_argument("--dataset_root", type=str, required=True, help="Path of config yaml.") 267 | parser_mono_stereo.add_argument("--out_dir", type=str) 268 | parser_mono_stereo.add_argument("--augmentation_repeats", type=int) 269 | 270 | parser_dac_stereo = subparsers.add_parser("dac_stereo") 271 | parser_dac_stereo.add_argument("--dataset_root", type=str, required=True, help="Path of config yaml.") 272 | parser_dac_stereo.add_argument("--out_dir", type=str) 273 | parser_dac_stereo.add_argument("--augmentation_repeats", type=int) 274 | 275 | parser_dac_stereo = subparsers.add_parser("8khz_44khz") 276 | parser_dac_stereo.add_argument("--dataset_root", type=str, required=True, help="Path of config yaml.") 277 | parser_dac_stereo.add_argument("--out_dir", type=str) 278 | parser_dac_stereo.add_argument("--augmentation_repeats", type=int) 279 | 280 | args = parser.parse_args() 281 | 282 | if args.mode == "stems": 283 | compute_stems_vae(args) 284 | 285 | elif args.mode == "mono_stereo": 286 | compute_mono_stereo_vae(args) 287 | 288 | elif args.mode == "dac_stereo": 289 | compute_dac_stereo_vae(args) 290 | 291 | elif args.mode == "8khz_44khz": 292 | compute_8khz_44khz_vae(args) 293 | 294 | 
else: 295 | raise ValueError -------------------------------------------------------------------------------- /configs/dac2stereo.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | clip_duration: 10. 3 | 4 | train_datasets: 5 | MUSDB18HqDac2StereoVAE: 6 | root: "./datasets/musdb18hq_dac_stereo_vae" 7 | split: "train" 8 | 9 | test_datasets: 10 | MUSDB18HqDac2StereoVAE: 11 | root: "./datasets/musdb18hq_dac_stereo_vae" 12 | split: "test" 13 | 14 | sampler: RepeatShuffleSampler 15 | 16 | data_transform: 17 | name: Dac2StereoVAE 18 | 19 | base: 20 | name: Transformer1D 21 | in_dim: 64 22 | patch_size: 1 23 | dim: 768 24 | num_layers: 12 25 | num_heads: 12 26 | rope_len: 8192 27 | 28 | adaptor: 29 | name: LatentEncoder 30 | dim: 5120 31 | 32 | train: 33 | device: cuda 34 | num_workers: 16 35 | precision: "no" # "no" (fp32) | "bf16" 36 | optimizer: AdamW 37 | lr: 1e-4 38 | warm_up_steps: 1000 # Leave blank if no warm up is used 39 | batch_size_per_device: 16 40 | test_every_n_steps: 10000 41 | save_every_n_steps: 50000 42 | training_steps: 500000 43 | resume_ckpt_path: # Leave blank if train from scratch 44 | 45 | valid_audios: 10 -------------------------------------------------------------------------------- /configs/mono2stereo.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | clip_duration: 10. 3 | 4 | train_datasets: 5 | MUSDB18HqMono2StereoVAE: 6 | root: "./datasets/musdb18hq_mono_stereo_vae" 7 | split: "train" 8 | 9 | test_datasets: 10 | MUSDB18HqMono2StereoVAE: 11 | root: "./datasets/musdb18hq_mono_stereo_vae" 12 | split: "test" 13 | 14 | sampler: RepeatShuffleSampler 15 | 16 | data_transform: 17 | name: Mono2StereoVAE 18 | 19 | base: 20 | name: Transformer1D 21 | in_dim: 64 22 | patch_size: 1 23 | dim: 768 24 | num_layers: 12 25 | num_heads: 12 26 | rope_len: 8192 27 | 28 | adaptor: 29 | name: LatentEncoder 30 | dim: 64 31 | 32 | train: 33 | device: cuda 34 | num_workers: 16 35 | precision: "no" # "no" (fp32) | "bf16" 36 | optimizer: AdamW 37 | lr: 1e-4 38 | warm_up_steps: 1000 # Leave blank if no warm up is used 39 | batch_size_per_device: 16 40 | test_every_n_steps: 10000 41 | save_every_n_steps: 50000 42 | training_steps: 500000 43 | resume_ckpt_path: # Leave blank if train from scratch 44 | 45 | valid_audios: 10 -------------------------------------------------------------------------------- /configs/mss.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | clip_duration: 10. 
3 | 4 | train_datasets: 5 | MUSDB18HqVAE: 6 | root: "./datasets/musdb18hq_stems_vae" 7 | split: "train" 8 | 9 | test_datasets: 10 | MUSDB18HqVAE: 11 | root: "./datasets/musdb18hq_stems_vae" 12 | split: "test" 13 | 14 | sampler: RepeatShuffleSampler 15 | 16 | data_transform: 17 | name: MSS 18 | target_stem: vocals 19 | 20 | base: 21 | name: Transformer1D 22 | in_dim: 64 23 | patch_size: 1 24 | dim: 768 25 | num_layers: 12 26 | num_heads: 12 27 | rope_len: 8192 28 | 29 | adaptor: 30 | name: LatentEncoder 31 | dim: 64 32 | 33 | train: 34 | device: cuda 35 | num_workers: 16 36 | precision: "no" # "no" (fp32) | "bf16" 37 | optimizer: AdamW 38 | lr: 1e-4 39 | warm_up_steps: 1000 # Leave blank if no warm up is used 40 | batch_size_per_device: 16 41 | test_every_n_steps: 10000 42 | save_every_n_steps: 50000 43 | training_steps: 500000 44 | resume_ckpt_path: # Leave blank if train from scratch 45 | 46 | valid_audios: 10 -------------------------------------------------------------------------------- /configs/superresolution.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | clip_duration: 10. 3 | 4 | train_datasets: 5 | MUSDB18HqLowres2HighresVAE: 6 | root: "./datasets/musdb18hq_8k_44k_vae" 7 | split: "train" 8 | 9 | test_datasets: 10 | MUSDB18HqLowres2HighresVAE: 11 | root: "./datasets/musdb18hq_8k_44k_vae" 12 | split: "test" 13 | 14 | sampler: RepeatShuffleSampler 15 | 16 | data_transform: 17 | name: SuperResolutionVAE 18 | 19 | base: 20 | name: Transformer1D 21 | in_dim: 64 22 | patch_size: 1 23 | dim: 768 24 | num_layers: 12 25 | num_heads: 12 26 | rope_len: 8192 27 | 28 | adaptor: 29 | name: LatentEncoder 30 | dim: 64 31 | 32 | train: 33 | device: cuda 34 | num_workers: 0 35 | precision: "no" # "no" (fp32) | "bf16" 36 | optimizer: AdamW 37 | lr: 1e-4 38 | warm_up_steps: 1000 # Leave blank if no warm up is used 39 | batch_size_per_device: 16 40 | test_every_n_steps: 10000 41 | save_every_n_steps: 50000 42 | training_steps: 500000 43 | resume_ckpt_path: # Leave blank if train from scratch 44 | 45 | valid_audios: 10 -------------------------------------------------------------------------------- /configs/text2music.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | clip_duration: 10. 3 | 4 | train_datasets: 5 | GtzanVAE: 6 | root: "./datasets/gtzan_vae" 7 | split: "train" 8 | 9 | test_datasets: 10 | GtzanVAE: 11 | root: "./datasets/gtzan_vae" 12 | split: "test" 13 | 14 | sampler: RepeatShuffleSampler 15 | 16 | data_transform: 17 | name: Label2MusicVAE 18 | 19 | base: 20 | name: Transformer1D 21 | in_dim: 64 22 | patch_size: 1 23 | dim: 768 24 | num_layers: 12 25 | num_heads: 12 26 | rope_len: 8192 27 | 28 | adaptor: 29 | name: OnehotEncoder 30 | num_classes: 10 31 | 32 | train: 33 | device: cuda 34 | num_workers: 16 35 | precision: "no" # "no" (fp32) | "bf16" 36 | optimizer: AdamW 37 | lr: 1e-4 38 | warm_up_steps: 1000 # Leave blank if no warm up is used 39 | batch_size_per_device: 16 40 | test_every_n_steps: 10000 41 | save_every_n_steps: 50000 42 | training_steps: 500000 43 | resume_ckpt_path: # Leave blank if train from scratch 44 | 45 | valid_audios: 10 -------------------------------------------------------------------------------- /configs/text2speech.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | clip_duration: 10. 
3 | 4 | train_datasets: 5 | LJSpeechVAE: 6 | root: "./datasets/ljspeech_vae" 7 | split: "train" 8 | 9 | test_datasets: 10 | LJSpeechVAE: 11 | root: "./datasets/ljspeech_vae" 12 | split: "test" 13 | 14 | sampler: RepeatShuffleSampler 15 | 16 | data_transform: 17 | name: Text2SpeechVAE 18 | 19 | base: 20 | name: Transformer1D 21 | in_dim: 64 22 | patch_size: 1 23 | dim: 768 24 | num_layers: 12 25 | num_heads: 12 26 | rope_len: 8192 27 | 28 | adaptor: 29 | name: OnehotEncoder 30 | num_classes: 30522 # Bert vocabulary 31 | 32 | train: 33 | device: cuda 34 | num_workers: 16 35 | precision: "no" # "no" (fp32) | "bf16" 36 | optimizer: AdamW 37 | lr: 1e-4 38 | warm_up_steps: 1000 # Leave blank if no warm up is used 39 | batch_size_per_device: 16 40 | test_every_n_steps: 10000 41 | save_every_n_steps: 50000 42 | training_steps: 500000 43 | resume_ckpt_path: # Leave blank if train from scratch 44 | 45 | valid_audios: 10 -------------------------------------------------------------------------------- /env.sh: -------------------------------------------------------------------------------- 1 | pip install librosa==0.11.0 2 | pip install torch==2.8.0 3 | pip install torchaudio==2.8.0 4 | pip install torchdiffeq==0.2.5 5 | pip install tqdm==4.67.1 6 | pip install einops==0.8.1 7 | pip install matplotlib==3.10.1 8 | pip install torchcfm==1.0.7 9 | pip install bigvgan 10 | pip install descript-audio-codec 11 | pip install audidata==0.0.5 12 | pip install wandb==0.19.10 13 | pip install accelerate==1.10.0 14 | pip install h5py==3.14.0 15 | pip install stable-audio-tools -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | from copy import deepcopy 5 | from pathlib import Path 6 | from typing import Iterable, Literal 7 | 8 | import matplotlib.pyplot as plt 9 | import soundfile 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torchdiffeq 14 | import wandb 15 | from audio_flow.utils import (CombinedModel, LinearWarmUp, parse_yaml, 16 | requires_grad, update_ema, logmel) 17 | from torch.utils.data import DataLoader, Dataset 18 | from torch.utils.data._utils.collate import default_collate 19 | from torchcfm.conditional_flow_matching import ConditionalFlowMatcher 20 | from tqdm import tqdm 21 | 22 | from train import get_dataset, get_data_transform, get_base, get_adaptor, get_sampler, get_optimizer_and_scheduler, validate 23 | 24 | 25 | def train(args) -> None: 26 | r"""Train audio generation with flow matching.""" 27 | 28 | # Arguments 29 | config_path = args.config 30 | ckpt_path = args.ckpt_path 31 | filename = Path(__file__).stem 32 | wandb_log = not args.no_log 33 | 34 | # Configs 35 | configs = parse_yaml(config_path) 36 | device = configs["train"]["device"] 37 | 38 | # Checkpoints directory 39 | config_name = Path(config_path).stem 40 | ckpts_dir = Path("./checkpoints", filename, config_name) 41 | Path(ckpts_dir).mkdir(parents=True, exist_ok=True) 42 | 43 | # Datasets 44 | train_dataset = get_dataset(configs, split="train", mode="train") 45 | 46 | # Sampler 47 | train_sampler = get_sampler(configs, train_dataset) 48 | 49 | # Dataloader 50 | train_dataloader = DataLoader( 51 | dataset=train_dataset, 52 | batch_size=configs["train"]["batch_size_per_device"], 53 | sampler=train_sampler, 54 | num_workers=configs["train"]["num_workers"], 55 | pin_memory=True, 56 | ) 57 | 58 | # Data 
processor 59 | data_transform = get_data_transform(configs).to(device) 60 | 61 | # Flow matching data processor 62 | fm = ConditionalFlowMatcher(sigma=0.) 63 | 64 | # Model 65 | base = get_base( 66 | configs=configs, 67 | ).to(device) 68 | 69 | adaptor = get_adaptor( 70 | configs=configs, 71 | ).to(device) 72 | 73 | model = CombinedModel(base, adaptor) 74 | 75 | if ckpt_path: 76 | ckpt = torch.load(ckpt_path) 77 | base_ckpt = {k: v for k, v in ckpt.items() if k.startswith("base")} 78 | missing, unexpected = model.load_state_dict(base_ckpt, strict=False) 79 | print(missing) 80 | 81 | # EMA (optional) 82 | ema = deepcopy(model).to(device) 83 | requires_grad(ema, False) 84 | update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights 85 | ema.eval() # EMA model should always be in eval mode 86 | 87 | # Optimizer 88 | optimizer, scheduler = get_optimizer_and_scheduler( 89 | configs=configs, 90 | params=model.parameters() 91 | ) 92 | 93 | # Logger 94 | if wandb_log: 95 | wandb.init(project="audio_flow", name=f"{filename}_{config_name}") 96 | 97 | for step, data in enumerate(tqdm(train_dataloader)): 98 | 99 | # ------ 1. Data preparation ------ 100 | # 1.1 Transform data into latent representations and conditions 101 | x_real, cond_dict = data_transform(data) 102 | 103 | # 1.2 Noise 104 | noise = torch.randn_like(x_real) 105 | 106 | # 1.3 Get input and velocity 107 | t, xt, ut = fm.sample_location_and_conditional_flow(x0=noise, x1=x_real) 108 | 109 | # ------ 2. Training ------ 110 | # 2.1 Forward 111 | model.train() 112 | emb_dict = model.adaptor(cond_dict) 113 | vt = model.base(t=t, x=xt, emb_dict=emb_dict) 114 | 115 | # 2.2 Loss 116 | loss = torch.mean((vt - ut) ** 2) 117 | 118 | # 2.3 Optimize 119 | optimizer.zero_grad() # Reset all parameter.grad to 0 120 | loss.backward() # Update all parameter.grad 121 | optimizer.step() # Update all parameters based on all parameter.grad 122 | update_ema(ema, model, decay=0.999) 123 | 124 | # 2.4 Learning rate scheduler 125 | if scheduler: 126 | scheduler.step() 127 | 128 | if step % 100 == 0: 129 | print("train loss: {:.4f}".format(loss.item())) 130 | 131 | # ------ 3. 
Evaluation ------ 132 | # 3.1 Evaluate 133 | if step % configs["train"]["test_every_n_steps"] == 0: 134 | 135 | for split in ["train", "test"]: 136 | validate( 137 | configs=configs, 138 | data_transform=data_transform, 139 | model=ema, 140 | split=split, 141 | out_dir=Path("./results", filename, config_name, f"steps={step}_ema") 142 | ) 143 | 144 | if wandb_log: 145 | wandb.log( 146 | data={ 147 | "train_loss": loss.item() 148 | }, 149 | step=step 150 | ) 151 | 152 | # 3.2 Save model 153 | if step % configs["train"]["save_every_n_steps"] == 0: 154 | 155 | ckpt_path = Path(ckpts_dir, f"step={step}_ema.pt") 156 | torch.save(ema.state_dict(), ckpt_path) 157 | print(f"Save model to {ckpt_path}") 158 | 159 | if step == configs["train"]["training_steps"]: 160 | break 161 | 162 | step += 1 163 | 164 | 165 | if __name__ == "__main__": 166 | 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("--config", type=str, required=True, help="Path of config yaml.") 169 | parser.add_argument("--ckpt_path", type=str, required=True, help="Path of config yaml.") 170 | parser.add_argument("--no_log", action="store_true", default=False) 171 | args = parser.parse_args() 172 | 173 | train(args) -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | from copy import deepcopy 5 | from pathlib import Path 6 | from typing import Iterable, Literal 7 | 8 | import matplotlib.pyplot as plt 9 | import soundfile 10 | import torch 11 | from torch import LongTensor 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torchdiffeq 15 | import wandb 16 | from audio_flow.utils import (CombinedModel, LinearWarmUp, parse_yaml, 17 | requires_grad, update_ema, logmel) 18 | from torch.utils.data import DataLoader, Dataset 19 | from torch.utils.data._utils.collate import default_collate 20 | from torchcfm.conditional_flow_matching import ConditionalFlowMatcher 21 | from tqdm import tqdm 22 | 23 | from train import get_data_transform, get_base, get_adaptor 24 | from audio_flow.datasets.gtzan_vae import GtzanVAE 25 | 26 | 27 | def sample(args): 28 | 29 | # Arguments 30 | config_path = args.config 31 | ckpt_path = args.ckpt_path 32 | device = "cuda" 33 | 34 | configs = parse_yaml(config_path) 35 | 36 | data_transform = get_data_transform(configs).to(device) 37 | 38 | # Model 39 | base = get_base(configs=configs).to(device) 40 | adaptor = get_adaptor(configs=configs).to(device) 41 | model = CombinedModel(base, adaptor) 42 | 43 | # Load checkpoint 44 | if ckpt_path: 45 | ckpt = torch.load(ckpt_path) 46 | model.load_state_dict(ckpt, strict=True) 47 | 48 | out_dir = "_tmp" 49 | Path(out_dir).mkdir(parents=True, exist_ok=True) 50 | 51 | # Prepare condition 52 | for id in range(10): 53 | 54 | cond_dict = { 55 | "id": LongTensor([id]).to(device), 56 | } 57 | 58 | noise = torch.randn(1, 64, 250).to(device) # (b, d, t) 59 | 60 | with torch.no_grad(): 61 | model.eval() 62 | emb_dict = model.adaptor(cond_dict) 63 | traj = torchdiffeq.odeint( 64 | lambda t, x: model.base(t, x, emb_dict), 65 | y0=noise, 66 | t=torch.linspace(0, 1, 2, device=device), 67 | atol=1e-4, 68 | rtol=1e-4, 69 | method="dopri5", 70 | ) 71 | 72 | x_gen = traj[-1] # (b, d, t) 73 | gen_audio = data_transform.latent_to_audio(x_gen).data.cpu().numpy()[0] # (c, l) 74 | 75 | # Visualize logmel 76 | sr = data_transform.sr 77 | gen_logmel = logmel(gen_audio, sr) 78 | 
fig, ax = plt.subplots(1, 1, figsize=(10, 10)) 79 | ax.matshow(gen_logmel.T, origin='lower', aspect='auto', cmap='jet', vmin=-10, vmax=5) 80 | out_path = Path(out_dir, "{}.pdf".format(GtzanVAE.IX_TO_LB[id])) 81 | plt.savefig(out_path) 82 | print(f"Write out to {out_path}") 83 | 84 | # Write to audio 85 | out_path = Path(out_dir, "{}.wav".format(GtzanVAE.IX_TO_LB[id])) 86 | soundfile.write(file=out_path, data=gen_audio.T, samplerate=sr) 87 | print(f"Write out to {out_path}") 88 | 89 | 90 | if __name__ == "__main__": 91 | 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--config", type=str, required=True, help="Path of config yaml.") 94 | parser.add_argument("--ckpt_path", type=str, required=True, help="Path of config yaml.") 95 | args = parser.parse_args() 96 | 97 | sample(args) -------------------------------------------------------------------------------- /scripts/download_gtzan.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download dataset 4 | mkdir -p ./downloaded_datasets/gtzan 5 | wget -O ./downloaded_datasets/gtzan/genres.tar.gz https://huggingface.co/datasets/qiuqiangkong/gtzan/resolve/main/genres.tar.gz?download=true 6 | 7 | # Unzip dataset 8 | mkdir -p ./datasets/gtzan 9 | tar -zxvf ./downloaded_datasets/gtzan/genres.tar.gz -C ./datasets/gtzan/ -------------------------------------------------------------------------------- /scripts/download_ljspeech.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget -O LJSpeech-1.1.tar.bz2 https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 4 | mkdir -p ./datasets 5 | tar -xvf LJSpeech-1.1.tar.bz2 -C ./datasets/ 6 | 7 | wget -O ./datasets/LJSpeech-1.1/train.txt https://huggingface.co/datasets/flexthink/ljspeech/resolve/main/train.txt?download=true 8 | wget -O ./datasets/LJSpeech-1.1/valid.txt https://huggingface.co/datasets/flexthink/ljspeech/resolve/main/valid.txt?download=true 9 | wget -O ./datasets/LJSpeech-1.1/test.txt https://huggingface.co/datasets/flexthink/ljspeech/resolve/main/test.txt?download=true -------------------------------------------------------------------------------- /scripts/download_maestro.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download dataset 4 | mkdir -p ./downloaded_datasets/maestro 5 | wget -O ./downloaded_datasets/maestro/maestro-v3.0.0.zip https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.zip 6 | 7 | # Unzip dataset 8 | sudo apt install p7zip-full 9 | 10 | mkdir -p ./datasets 11 | unzip ./downloaded_datasets/maestro/maestro-v3.0.0.zip -d ./datasets/ 12 | -------------------------------------------------------------------------------- /scripts/download_musdb18hq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download dataset 4 | mkdir -p ./downloaded_datasets/musdb18hq 5 | wget -O ./downloaded_datasets/musdb18hq/musdb18hq.zip https://zenodo.org/records/3338373/files/musdb18hq.zip?download=1 6 | 7 | # Unzip dataset 8 | mkdir -p ./datasets/musdb18hq 9 | unzip ./downloaded_datasets/musdb18hq/musdb18hq.zip -d ./datasets/musdb18hq/ -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | from copy import deepcopy 5 | from 
pathlib import Path 6 | from typing import Iterable, Literal 7 | 8 | import matplotlib.pyplot as plt 9 | import soundfile 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torchdiffeq 14 | import wandb 15 | from audio_flow.utils import (CombinedModel, LinearWarmUp, parse_yaml, 16 | requires_grad, update_ema, logmel) 17 | from torch.utils.data import DataLoader, Dataset 18 | from torch.utils.data._utils.collate import default_collate 19 | from torchcfm.conditional_flow_matching import ConditionalFlowMatcher 20 | from tqdm import tqdm 21 | 22 | 23 | def train(args) -> None: 24 | r"""Train audio generation with flow matching.""" 25 | 26 | # Arguments 27 | wandb_log = not args.no_log 28 | config_path = args.config 29 | filename = Path(__file__).stem 30 | 31 | # Configs 32 | configs = parse_yaml(config_path) 33 | device = configs["train"]["device"] 34 | ckpt_path = configs["train"]["resume_ckpt_path"] 35 | 36 | # Checkpoints directory 37 | config_name = Path(config_path).stem 38 | ckpts_dir = Path("./checkpoints", filename, config_name) 39 | Path(ckpts_dir).mkdir(parents=True, exist_ok=True) 40 | 41 | # Datasets 42 | train_dataset = get_dataset(configs, split="train", mode="train") 43 | 44 | # Sampler 45 | train_sampler = get_sampler(configs, train_dataset) 46 | 47 | # Dataloader 48 | train_dataloader = DataLoader( 49 | dataset=train_dataset, 50 | batch_size=configs["train"]["batch_size_per_device"], 51 | sampler=train_sampler, 52 | num_workers=configs["train"]["num_workers"], 53 | pin_memory=True, 54 | ) 55 | 56 | # Data processor 57 | data_transform = get_data_transform(configs).to(device) 58 | 59 | # Flow matching data processor 60 | fm = ConditionalFlowMatcher(sigma=0.) 61 | 62 | # Model 63 | base = get_base( 64 | configs=configs, 65 | ).to(device) 66 | 67 | adaptor = get_adaptor( 68 | configs=configs, 69 | ).to(device) 70 | 71 | model = CombinedModel(base, adaptor) 72 | 73 | if ckpt_path: 74 | ckpt = torch.load(ckpt_path) 75 | model.load_state_dict(ckpt, strict=True) 76 | 77 | # EMA (optional) 78 | ema = deepcopy(model).to(device) 79 | requires_grad(ema, False) 80 | update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights 81 | ema.eval() # EMA model should always be in eval mode 82 | 83 | # Optimizer 84 | optimizer, scheduler = get_optimizer_and_scheduler( 85 | configs=configs, 86 | params=model.parameters() 87 | ) 88 | 89 | # Logger 90 | if wandb_log: 91 | wandb.init(project="audio_flow", name=f"{filename}_{config_name}") 92 | 93 | for step, data in enumerate(tqdm(train_dataloader)): 94 | 95 | # ------ 1. Data preparation ------ 96 | # 1.1 Transform data into latent representations and conditions 97 | x_real, cond_dict = data_transform(data) 98 | 99 | # 1.2 Noise 100 | noise = torch.randn_like(x_real) 101 | 102 | # 1.3 Get input and velocity 103 | t, xt, ut = fm.sample_location_and_conditional_flow(x0=noise, x1=x_real) 104 | 105 | # ------ 2. 
Training ------ 106 | # 2.1 Forward 107 | model.train() 108 | emb_dict = model.adaptor(cond_dict) 109 | vt = model.base(t=t, x=xt, emb_dict=emb_dict) 110 | 111 | # 2.2 Loss 112 | loss = torch.mean((vt - ut) ** 2) 113 | 114 | # 2.3 Optimize 115 | optimizer.zero_grad() # Reset all parameter.grad to 0 116 | loss.backward() # Update all parameter.grad 117 | optimizer.step() # Update all parameters based on all parameter.grad 118 | update_ema(ema, model, decay=0.999) 119 | 120 | # 2.4 Learning rate scheduler 121 | if scheduler: 122 | scheduler.step() 123 | 124 | if step % 100 == 0: 125 | print("train loss: {:.4f}".format(loss.item())) 126 | 127 | # ------ 3. Evaluation ------ 128 | # 3.1 Evaluate 129 | if step % configs["train"]["test_every_n_steps"] == 0: 130 | 131 | for split in ["train", "test"]: 132 | validate( 133 | configs=configs, 134 | data_transform=data_transform, 135 | model=ema, 136 | split=split, 137 | out_dir=Path("./results", filename, config_name, f"steps={step}_ema") 138 | ) 139 | 140 | if wandb_log: 141 | wandb.log( 142 | data={ 143 | "train_loss": loss.item() 144 | }, 145 | step=step 146 | ) 147 | 148 | # 3.2 Save model 149 | if step % configs["train"]["save_every_n_steps"] == 0: 150 | 151 | ckpt_path = Path(ckpts_dir, f"step={step}_ema.pt") 152 | torch.save(ema.state_dict(), ckpt_path) 153 | print(f"Save model to {ckpt_path}") 154 | 155 | if step == configs["train"]["training_steps"]: 156 | break 157 | 158 | step += 1 159 | 160 | 161 | def get_dataset( 162 | configs: dict, 163 | split: Literal["train", "test"], 164 | mode: Literal["train", "test"] 165 | ) -> Dataset: 166 | r"""Get datasets.""" 167 | 168 | ds = f"{split}_datasets" 169 | 170 | for name in configs[ds].keys(): 171 | 172 | if name == "GtzanVAE": 173 | from audio_flow.datasets.gtzan_vae import GtzanVAE 174 | return GtzanVAE( 175 | root=configs[ds][name]["root"], 176 | split=configs[ds][name]["split"], 177 | test_fold=0, 178 | duration=configs["clip_duration"] 179 | ) 180 | 181 | elif name == "MUSDB18HqVAE": 182 | from audio_flow.datasets.musdb18hq_vae import MUSDB18HqVAE 183 | return MUSDB18HqVAE( 184 | root=configs[ds][name]["root"], 185 | split=configs[ds][name]["split"], 186 | duration=configs["clip_duration"] 187 | ) 188 | 189 | elif name == "MUSDB18HqMono2StereoVAE": 190 | from audio_flow.datasets.musdb18hq_mono2stereo_vae import MUSDB18HqMono2StereoVAE 191 | return MUSDB18HqMono2StereoVAE( 192 | root=configs[ds][name]["root"], 193 | split=configs[ds][name]["split"], 194 | duration=configs["clip_duration"] 195 | ) 196 | 197 | elif name == "MUSDB18HqLowres2HighresVAE": 198 | from audio_flow.datasets.musdb18hq_lowres2highres_vae import MUSDB18HqLowres2HighresVAE 199 | return MUSDB18HqLowres2HighresVAE( 200 | root=configs[ds][name]["root"], 201 | split=configs[ds][name]["split"], 202 | duration=configs["clip_duration"] 203 | ) 204 | 205 | elif name == "MUSDB18HqDac2StereoVAE": 206 | from audio_flow.datasets.musdb18hq_dac2stereo_vae import MUSDB18HqDac2StereoVAE 207 | return MUSDB18HqDac2StereoVAE( 208 | root=configs[ds][name]["root"], 209 | split=configs[ds][name]["split"], 210 | duration=configs["clip_duration"] 211 | ) 212 | 213 | elif name == "LJSpeechVAE": 214 | from audio_flow.datasets.ljspeech_vae import LJSpeechVAE 215 | return LJSpeechVAE( 216 | root=configs[ds][name]["root"], 217 | split=configs[ds][name]["split"], 218 | duration=configs["clip_duration"] 219 | ) 220 | 221 | else: 222 | raise ValueError(name) 223 | 224 | 225 | def get_sampler(configs: dict, dataset: Dataset) -> Iterable: 226 | r"""Get 
sampler.""" 227 | 228 | name = configs["sampler"] 229 | 230 | if name == "RepeatShuffleSampler": 231 | from audio_flow.samplers.sampler import RepeatShuffleSampler 232 | return RepeatShuffleSampler(dataset) 233 | 234 | else: 235 | raise ValueError(name) 236 | 237 | 238 | def get_data_transform(configs: dict): 239 | r"""Transform data into latent representations and conditions.""" 240 | 241 | name = configs["data_transform"]["name"] 242 | 243 | if name == "Label2MusicVAE": 244 | from audio_flow.data_transforms.label2music import Label2MusicVAE 245 | return Label2MusicVAE() 246 | 247 | elif name == "MSS": 248 | from audio_flow.data_transforms.mss import MSSVAE 249 | return MSSVAE(target_stem=configs["data_transform"]["target_stem"]) 250 | 251 | elif name == "Mono2StereoVAE": 252 | from audio_flow.data_transforms.mono2stereo import Mono2StereoVAE 253 | return Mono2StereoVAE() 254 | 255 | elif name == "SuperResolutionVAE": 256 | from audio_flow.data_transforms.mono2stereo import SuperResolutionVAE 257 | return SuperResolutionVAE() 258 | 259 | elif name == "Dac2StereoVAE": 260 | from audio_flow.data_transforms.dac2stereo import Dac2StereoVAE 261 | return Dac2StereoVAE() 262 | 263 | elif name == "Text2SpeechVAE": 264 | from audio_flow.data_transforms.text2speech import Text2SpeechVAE 265 | return Text2SpeechVAE() 266 | 267 | else: 268 | raise ValueError(name) 269 | 270 | 271 | def get_base( 272 | configs: dict, 273 | ) -> nn.Module: 274 | r"""Initialize base model.""" 275 | 276 | name = configs["base"]["name"] 277 | 278 | if name == "Transformer1D": 279 | from audio_flow.models.transformer1d import Transformer1D 280 | return Transformer1D(**configs["base"]) 281 | 282 | else: 283 | raise ValueError(name) 284 | 285 | 286 | def get_adaptor( 287 | configs: dict, 288 | ): 289 | r"""Initialize adaptor.""" 290 | 291 | name = configs["adaptor"]["name"] 292 | 293 | if name == "OnehotEncoder": 294 | from audio_flow.adaptors.onehot import OnehotEncoder 295 | return OnehotEncoder( 296 | num_classes=configs["adaptor"]["num_classes"], 297 | dim=configs["base"]["dim"] 298 | ) 299 | 300 | elif name == "VAEEncoder": 301 | from audio_flow.adaptors.vae import VAEEncoder 302 | return VAEEncoder( 303 | in_channels=configs["adaptor"]["dim"], 304 | dim=configs["base"]["dim"] 305 | ) 306 | 307 | elif name == "LatentEncoder": 308 | from audio_flow.adaptors.latent import LatentEncoder 309 | return LatentEncoder( 310 | in_channels=configs["adaptor"]["dim"], 311 | dim=configs["base"]["dim"] 312 | ) 313 | 314 | else: 315 | raise ValueError(name) 316 | 317 | 318 | def get_optimizer_and_scheduler( 319 | configs: dict, 320 | params: list[torch.Tensor] 321 | ) -> tuple[optim.Optimizer, None | optim.lr_scheduler.LambdaLR]: 322 | r"""Get optimizer and scheduler.""" 323 | 324 | lr = float(configs["train"]["lr"]) 325 | warm_up_steps = configs["train"]["warm_up_steps"] 326 | optimizer_name = configs["train"]["optimizer"] 327 | 328 | if optimizer_name == "AdamW": 329 | optimizer = optim.AdamW(params=params, lr=lr) 330 | 331 | if warm_up_steps: 332 | lr_lambda = LinearWarmUp(warm_up_steps) 333 | scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lr_lambda) 334 | else: 335 | scheduler = None 336 | 337 | return optimizer, scheduler 338 | 339 | 340 | def validate( 341 | configs: dict, 342 | data_transform: object, 343 | model: nn.Module, 344 | split: Literal["train", "test"], 345 | out_dir: str 346 | ) -> float: 347 | r"""Validate the model on part of data.""" 348 | 349 | device = next(model.parameters()).device 350 | 
out_dir.mkdir(parents=True, exist_ok=True) 351 | 352 | valid_audios = configs["valid_audios"] 353 | sr = data_transform.sr 354 | 355 | dataset = get_dataset(configs, split=split, mode="test") 356 | dataset[0] 357 | 358 | # Evaluate only part of data 359 | if valid_audios: 360 | skip_n = max(1, len(dataset) // valid_audios) 361 | else: 362 | skip_n = 1 363 | 364 | for i, idx in enumerate(range(0, len(dataset), skip_n)): 365 | 366 | # ------ 1. Data preparation ------ 367 | # 1.1 Get Data 368 | data = dataset[idx] 369 | data = default_collate([data]) 370 | 371 | # 1.2 Transform data into latent representations and conditions 372 | x_real, cond_dict = data_transform(data) 373 | 374 | # 1.3 Noise 375 | noise = torch.randn_like(x_real) 376 | 377 | # ------ 2. Forward with ODE ------ 378 | # 2.1 Iteratively forward 379 | with torch.no_grad(): 380 | model.eval() 381 | emb_dict = model.adaptor(cond_dict) 382 | traj = torchdiffeq.odeint( 383 | lambda t, x: model.base(t, x, emb_dict), 384 | y0=noise, 385 | t=torch.linspace(0, 1, 2, device=device), 386 | atol=1e-4, 387 | rtol=1e-4, 388 | method="dopri5", 389 | ) 390 | 391 | x_gen = traj[-1] # (b, d, t) 392 | 393 | # 2.2 Latent to audio 394 | if "ct" in cond_dict and cond_dict["ct"].shape == x_real.shape: 395 | in_audio = data_transform.latent_to_audio(cond_dict["ct"]).data.cpu().numpy()[0] # (c, l) 396 | else: 397 | in_audio = None 398 | 399 | gen_audio = data_transform.latent_to_audio(x_gen).data.cpu().numpy()[0] # (c, l) 400 | gt_audio = data_transform.latent_to_audio(x_real).data.cpu().numpy()[0] # (c, l) 401 | 402 | # ------ 3. Plot and Visualization ------ 403 | if in_audio is not None: 404 | in_logmel = logmel(in_audio, sr) 405 | gen_logmel = logmel(gen_audio, sr) 406 | gt_logmel = logmel(gt_audio, sr) 407 | 408 | fig, axs = plt.subplots(3, 1, figsize=(10, 10)) 409 | vmin, vmax = -10, 5 410 | if in_audio is not None: 411 | axs[0].matshow(in_logmel.T, origin='lower', aspect='auto', cmap='jet', vmin=vmin, vmax=vmax) 412 | axs[1].matshow(gen_logmel.T, origin='lower', aspect='auto', cmap='jet', vmin=vmin, vmax=vmax) 413 | axs[2].matshow(gt_logmel.T, origin='lower', aspect='auto', cmap='jet', vmin=vmin, vmax=vmax) 414 | axs[0].set_title("Input (if there are)") 415 | axs[1].set_title("Generation") 416 | axs[2].set_title("Ground truth") 417 | axs[2].xaxis.tick_bottom() 418 | 419 | if "caption" in cond_dict: 420 | caption = "_{}".format(cond_dict["caption"][0]) 421 | else: 422 | caption = "" 423 | 424 | out_path = Path(out_dir, f"{split}_{i}{caption}.png") 425 | plt.savefig(out_path) 426 | print(f"Write out to {out_path}") 427 | 428 | # 3.2 Save audio 429 | if in_audio is not None: 430 | out_path = Path(out_dir, f"{split}_{i}{caption}_in.wav") 431 | soundfile.write(file=out_path, data=in_audio.T, samplerate=sr) 432 | print(f"Write out to {out_path}") 433 | 434 | out_path = Path(out_dir, f"{split}_{i}{caption}_gen.wav") 435 | soundfile.write(file=out_path, data=gen_audio.T, samplerate=sr) 436 | print(f"Write out to {out_path}") 437 | 438 | out_path = Path(out_dir, f"{split}_{i}{caption}_gt.wav") 439 | soundfile.write(file=out_path, data=gt_audio.T, samplerate=sr) 440 | print(f"Write out to {out_path}") 441 | 442 | 443 | if __name__ == "__main__": 444 | 445 | parser = argparse.ArgumentParser() 446 | parser.add_argument("--config", type=str, required=True, help="Path of config yaml.") 447 | parser.add_argument("--no_log", action="store_true", default=False) 448 | args = parser.parse_args() 449 | 450 | train(args) 
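Editor's note on the training loop above: `ConditionalFlowMatcher(sigma=0.)` samples a straight-line path between noise and data, so `xt` reduces to a linear interpolation and the target velocity `ut` is the constant difference `x1 - x0`. A minimal sketch of that identity (plain PyTorch only; the batch and latent shapes are chosen for illustration to match the 64-dim, 25 fps latents used in the configs):

```python
# Illustrative sketch (not part of the repo): what
# fm.sample_location_and_conditional_flow(x0=noise, x1=x_real) reduces to
# when sigma = 0.
import torch

torch.manual_seed(0)
x0 = torch.randn(4, 64, 250)   # noise
x1 = torch.randn(4, 64, 250)   # data latent
t = torch.rand(4)              # one random time step per example

# Straight-line conditional path:
#   x_t = (1 - t) * x0 + t * x1,   u_t = x1 - x0
t_ = t[:, None, None]          # broadcast over (d, t) axes
xt = (1. - t_) * x0 + t_ * x1
ut = x1 - x0

# The model v(t, x) is then trained with loss = mean((v(t, xt) - ut) ** 2),
# as in section 2.2 of the loop above.
print(xt.shape, ut.shape)      # torch.Size([4, 64, 250]) twice
```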
--------------------------------------------------------------------------------
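Editor's note on sampling: `validate()` and `sample.py` above integrate the learned velocity field from t=0 (noise) to t=1 (data) with torchdiffeq's adaptive `dopri5` solver. The same generation can be approximated with a fixed-step Euler loop, which is often easier to reason about. The sketch below is illustrative only; `velocity_fn` is a hypothetical stand-in that would be replaced by `lambda t, x: model.base(t, x, emb_dict)` from the scripts above:

```python
# Minimal Euler-integration sketch for flow-matching sampling (assumption:
# velocity_fn(t, x) returns dx/dt with the same shape as x).
import torch

@torch.no_grad()
def euler_sample(velocity_fn, noise: torch.Tensor, num_steps: int = 50) -> torch.Tensor:
    """Integrate dx/dt = v(t, x) from t=0 to t=1 with uniform Euler steps."""
    x = noise
    ts = torch.linspace(0., 1., num_steps + 1, device=noise.device)
    for i in range(num_steps):
        dt = ts[i + 1] - ts[i]
        x = x + dt * velocity_fn(ts[i], x)
    return x

# Usage with a toy stand-in velocity field (replace with the trained model):
velocity_fn = lambda t, x: -x
x_gen = euler_sample(velocity_fn, torch.randn(1, 64, 250))
print(x_gen.shape)  # torch.Size([1, 64, 250])
```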