├── ldm ├── models │ └── diffusion │ │ ├── __init__.py │ │ ├── cfm1_audio_sampler.py │ │ ├── cfm1_audio.py │ │ └── classifier.py ├── modules │ ├── encoders │ │ ├── __init__.py │ │ └── CLAP │ │ │ ├── __init__.py │ │ │ ├── config.yml │ │ │ ├── utils.py │ │ │ ├── clap.py │ │ │ └── audio.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── diffusionmodules │ │ └── __init__.py │ ├── losses_audio │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ ├── ema.py │ ├── discriminator │ │ └── multi_window_disc.py │ ├── new_attention.py │ └── attention.py ├── data │ ├── joinaudiodataset_624.py │ └── joinaudiodataset_struct_sample_anylen.py ├── lr_scheduler.py └── util.py ├── vocoder ├── parallel_wavegan │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ └── utils.py │ ├── losses │ │ ├── __init__.py │ │ └── stft_loss.py │ ├── optimizers │ │ ├── __init__.py │ │ └── radam.py │ ├── models │ │ └── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── causal_conv.py │ │ ├── residual_stack.py │ │ ├── tf_layers.py │ │ ├── pqmf.py │ │ ├── residual_block.py │ │ └── upsample.py │ └── stft_loss.py ├── bigvgan │ ├── __init__.py │ ├── alias_free_torch │ │ ├── __init__.py │ │ ├── act.py │ │ ├── resample.py │ │ └── filter.py │ └── activations.py └── hifigan │ ├── __init__.py │ ├── hifigan_utils.py │ ├── hifigan.py │ └── hifigan_nsf.py ├── requirements.txt ├── utils ├── os_utils.py └── commons │ ├── data_utils.py │ ├── ckpt_utils.py │ └── hparams.py ├── configs ├── ae_accomp.yaml └── vocal2music.yaml ├── preprocess ├── preprocess.py └── NAT_mel.py └── README.md /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * # NOQA 2 | -------------------------------------------------------------------------------- /vocoder/bigvgan/__init__.py: -------------------------------------------------------------------------------- 1 | from vocoder.bigvgan.models import VocoderBigVGAN 2 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .stft_loss import * # NOQA 2 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/__init__.py: -------------------------------------------------------------------------------- 1 | from . import clap 2 | from . import audio 3 | from . 
import utils -------------------------------------------------------------------------------- /vocoder/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .hifigan import HifiGAN, CodeUpsampleHifiGan 2 | from .hifigan_nsf import HifiGAN_NSF -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.optim import * # NOQA 2 | from .radam import * # NOQA 3 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .melgan import * # NOQA 2 | from .parallel_wavegan import * # NOQA 3 | -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .filter import * 5 | from .resample import * 6 | from .act import * -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .causal_conv import * # NOQA 2 | from .pqmf import * # NOQA 3 | from .residual_block import * # NOQA 4 | from vocoder.parallel_wavegan.layers.residual_stack import * # NOQA 5 | from .upsample import * # NOQA 6 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses_audio.vqperceptual import DummyLoss 2 | 3 | # relative imports pain 4 | import os 5 | import sys 6 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'vggishish') 7 | sys.path.append(path) 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | h5py==3.11.0 3 | librosa==0.10.1 4 | matplotlib==3.8.2 5 | numba==0.60.0 6 | numpy==1.26.4 7 | packaging==24.2 8 | pandas==2.2.3 9 | pyloudnorm==0.1.1 10 | scipy 11 | tensorflow 12 | tensorboard 13 | torch==2.1.0+cu121 14 | torchaudio==2.1.0+cu121 15 | tqdm==4.66.5 16 | webrtcvad==2.0.10 17 | torch-fidelity==0.3.0 18 | importlib_resources 19 | omegaconf 20 | soundfile 21 | torchlibrosa 22 | ftfy 23 | pytorch-lightning==1.9.0 24 | torchmetrics==1.2.1 25 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/config.yml: -------------------------------------------------------------------------------- 1 | # TEXT ENCODER CONFIG 2 | text_model: 'bert-base-uncased' 3 | text_len: 100 4 | transformer_embed_dim: 768 5 | freeze_text_encoder_weights: True 6 | 7 | # AUDIO ENCODER CONFIG 8 | audioenc_name: 'Cnn14' 9 | out_emb: 2048 10 | sampling_rate: 44100 11 | duration: 5 12 | fmin: 50 13 | fmax: 14000 14 | n_fft: 1028 15 | hop_size: 320 16 | mel_bins: 64 17 | window_size: 1024 18 | 19 | # PROJECTION SPACE CONFIG 20 | d_proj: 1024 21 | temperature: 0.003 22 | 23 | # TRAINING AND EVALUATION CONFIG 24 | num_classes: 527 25 
| batch_size: 1024 26 | demo: False 27 | -------------------------------------------------------------------------------- /utils/os_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | 5 | def link_file(from_file, to_file): 6 | subprocess.check_call( 7 | f'ln -s "`realpath --relative-to="{os.path.dirname(to_file)}" "{from_file}"`" "{to_file}"', shell=True) 8 | 9 | 10 | def move_file(from_file, to_file): 11 | subprocess.check_call(f'mv "{from_file}" "{to_file}"', shell=True) 12 | 13 | 14 | def copy_file(from_file, to_file): 15 | subprocess.check_call(f'cp -r "{from_file}" "{to_file}"', shell=True) 16 | 17 | 18 | def remove_file(*fns): 19 | for f in fns: 20 | subprocess.check_call(f'rm -rf "{f}"', shell=True) 21 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import sys 4 | 5 | def read_config_as_args(config_path,args=None,is_config_str=False): 6 | return_dict = {} 7 | 8 | if config_path is not None: 9 | if is_config_str: 10 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader) 11 | else: 12 | with open(config_path, "r") as f: 13 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 14 | 15 | if args != None: 16 | for k, v in yml_config.items(): 17 | if k in args.__dict__: 18 | args.__dict__[k] = v 19 | else: 20 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) 21 | else: 22 | for k, v in yml_config.items(): 23 | return_dict[k] = v 24 | 25 | args = args if args != None else return_dict 26 | return argparse.Namespace(**args) 27 | -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
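# Usage sketch (illustrative addition, not part of the original file; assumes any pointwise nn.Module can serve as `activation`):
#   act = Activation1d(activation=torch.nn.SiLU())
#   y = act(x)  # x and y are [B, C, T]; the activation is applied at 2x the input rate (upsample -> act -> downsample) to limit aliasing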
3 | 4 | import torch.nn as nn 5 | from .resample import UpSample1d, DownSample1d 6 | 7 | 8 | class Activation1d(nn.Module): 9 | def __init__(self, 10 | activation, 11 | up_ratio: int = 2, 12 | down_ratio: int = 2, 13 | up_kernel_size: int = 12, 14 | down_kernel_size: int = 12): 15 | super().__init__() 16 | self.up_ratio = up_ratio 17 | self.down_ratio = down_ratio 18 | self.act = activation 19 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 20 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 21 | 22 | # x: [B,C,T] 23 | def forward(self, x): 24 | x = self.upsample(x) 25 | x = self.act(x) 26 | x = self.downsample(x) 27 | 28 | return x -------------------------------------------------------------------------------- /utils/commons/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import csv 4 | 5 | def safe_path(path): 6 | os.makedirs(Path(path).parent, exist_ok=True) 7 | return path 8 | 9 | def load_samples_from_tsv(tsv_path): 10 | tsv_path = Path(tsv_path) 11 | if not tsv_path.is_file(): 12 | raise FileNotFoundError(f"Dataset not found: {tsv_path}") 13 | with open(tsv_path) as f: 14 | reader = csv.DictReader( 15 | f, 16 | delimiter="\t", 17 | quotechar=None, 18 | doublequote=False, 19 | lineterminator="\n", 20 | quoting=csv.QUOTE_NONE, 21 | ) 22 | samples = [dict(e) for e in reader] 23 | if len(samples) == 0: 24 | print(f"warning: empty manifest: {tsv_path}") 25 | return [] 26 | return samples 27 | 28 | def load_dict_from_tsv(tsv_path, key): 29 | samples = load_samples_from_tsv(tsv_path) 30 | samples = {sample[key]: sample for sample in samples} 31 | return samples 32 | 33 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/layers/causal_conv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Causal convolusion layer modules.""" 7 | 8 | 9 | import torch 10 | 11 | 12 | class CausalConv1d(torch.nn.Module): 13 | """CausalConv1d module with customized initialization.""" 14 | 15 | def __init__(self, in_channels, out_channels, kernel_size, 16 | dilation=1, bias=True, pad="ConstantPad1d", pad_params={"value": 0.0}): 17 | """Initialize CausalConv1d module.""" 18 | super(CausalConv1d, self).__init__() 19 | self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params) 20 | self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, 21 | dilation=dilation, bias=bias) 22 | 23 | def forward(self, x): 24 | """Calculate forward propagation. 25 | 26 | Args: 27 | x (Tensor): Input tensor (B, in_channels, T). 28 | 29 | Returns: 30 | Tensor: Output tensor (B, out_channels, T). 31 | 32 | """ 33 | return self.conv(self.pad(x))[:, :, :x.size(2)] 34 | 35 | 36 | class CausalConvTranspose1d(torch.nn.Module): 37 | """CausalConvTranspose1d module with customized initialization.""" 38 | 39 | def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True): 40 | """Initialize CausalConvTranspose1d module.""" 41 | super(CausalConvTranspose1d, self).__init__() 42 | self.deconv = torch.nn.ConvTranspose1d( 43 | in_channels, out_channels, kernel_size, stride, bias=bias) 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | """Calculate forward propagation. 48 | 49 | Args: 50 | x (Tensor): Input tensor (B, in_channels, T_in). 
51 | 52 | Returns: 53 | Tensor: Output tensor (B, out_channels, T_out). 54 | 55 | """ 56 | return self.deconv(x)[:, :, :-self.stride] 57 | -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from .filter import LowPassFilter1d 7 | from .filter import kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, 20 | half_width=0.6 / ratio, 21 | kernel_size=self.kernel_size) 22 | self.register_buffer("filter", filter) 23 | 24 | # x: [B, C, T] 25 | def forward(self, x): 26 | _, C, _ = x.shape 27 | 28 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 29 | x = self.ratio * F.conv_transpose1d( 30 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 31 | x = x[..., self.pad_left:-self.pad_right] 32 | 33 | return x 34 | 35 | 36 | class DownSample1d(nn.Module): 37 | def __init__(self, ratio=2, kernel_size=None): 38 | super().__init__() 39 | self.ratio = ratio 40 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 41 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, 42 | half_width=0.6 / ratio, 43 | stride=ratio, 44 | kernel_size=self.kernel_size) 45 | 46 | def forward(self, x): 47 | xx = self.lowpass(x) 48 | 49 | return xx -------------------------------------------------------------------------------- /configs/ae_accomp.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder1d.AutoencoderKL 4 | params: 5 | embed_dim: 20 6 | monitor: val/rec_loss 7 | ddconfig: 8 | double_z: true 9 | in_channels: 80 10 | out_ch: 80 11 | z_channels: 20 12 | kernel_size: 5 13 | ch: 384 14 | ch_mult: 15 | - 1 16 | - 2 17 | - 4 18 | num_res_blocks: 2 19 | attn_layers: 20 | - 3 21 | down_layers: 22 | - 0 23 | dropout: 0.0 24 | lossconfig: 25 | target: ldm.modules.losses_audio.contperceptual.LPAPSWithDiscriminator 26 | params: 27 | disc_start: 80001 28 | perceptual_weight: 0.0 29 | kl_weight: 1.0e-06 30 | disc_weight: 0.5 31 | disc_in_channels: 1 32 | disc_loss: mse 33 | disc_factor: 2 34 | disc_conditional: false 35 | r1_reg_weight: 3 36 | 37 | lightning: 38 | callbacks: 39 | image_logger: 40 | target: main.AudioLogger 41 | params: 42 | for_specs: true 43 | increase_log_steps: false 44 | batch_frequency: 5000 45 | max_images: 8 46 | rescale: false 47 | melvmin: -5 48 | melvmax: 1.5 49 | sample_rate: 24000 50 | vocoder_cfg: 51 | target: vocoder.bigvgan.models.VocoderBigVGAN 52 | params: 53 | ckpt_vocoder: useful_ckpts/hifigan 54 | trainer: 55 | sync_batchnorm: false 56 | strategy: ddp 57 | 58 | 59 | data: 60 | target: main.SpectrogramDataModuleFromConfig 61 | params: 62 | batch_size: 20 63 | num_workers: 16 64 | 
spec_dir_path: /root/autodl-tmp/data/manifests/vocal_to_accomp/train/v2c_0905 65 | mel_num: 80 66 | spec_len: 624 67 | spec_crop_len: 624 68 | train: 69 | target: ldm.data.joinaudiodataset_624.JoinSpecsTrain 70 | params: 71 | specs_dataset_cfg: null 72 | validation: 73 | target: ldm.data.joinaudiodataset_624.JoinSpecsValidation 74 | params: 75 | specs_dataset_cfg: null 76 | 77 | -------------------------------------------------------------------------------- /vocoder/hifigan/hifigan_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from utils.commons.ckpt_utils import get_last_checkpoint  # needed below when ckpt_base_dir is a directory 5 | def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True): 6 | if os.path.isfile(ckpt_base_dir): 7 | base_dir = os.path.dirname(ckpt_base_dir) 8 | ckpt_path = ckpt_base_dir 9 | checkpoint = torch.load(ckpt_base_dir, map_location='cpu') 10 | else: 11 | base_dir = ckpt_base_dir 12 | checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir) 13 | if checkpoint is not None: 14 | state_dict = checkpoint["state_dict"] 15 | if len([k for k in state_dict.keys() if '.' in k]) > 0: 16 | state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items() 17 | if k.startswith(f'{model_name}.')} 18 | else: 19 | if '.' not in model_name: 20 | state_dict = state_dict[model_name] 21 | else: 22 | base_model_name = model_name.split('.')[0] 23 | rest_model_name = model_name[len(base_model_name) + 1:] 24 | state_dict = { 25 | k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items() 26 | if k.startswith(f'{rest_model_name}.')} 27 | if not strict: 28 | cur_model_state_dict = cur_model.state_dict() 29 | unmatched_keys = [] 30 | for key, param in state_dict.items(): 31 | if key in cur_model_state_dict: 32 | new_param = cur_model_state_dict[key] 33 | if new_param.shape != param.shape: 34 | unmatched_keys.append(key) 35 | print("| Unmatched keys: ", key, new_param.shape, param.shape) 36 | for key in unmatched_keys: 37 | del state_dict[key] 38 | # print(state_dict) 39 | cur_model.load_state_dict(state_dict, strict=strict) 40 | print(f"| load '{model_name}' from '{ckpt_path}'.") 41 | else: 42 | e_msg = f"| ckpt not found in {base_dir}."
43 | if force: 44 | assert False, e_msg 45 | else: 46 | print(e_msg) -------------------------------------------------------------------------------- /utils/commons/ckpt_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | import torch 5 | 6 | 7 | def get_last_checkpoint(work_dir, steps=None): 8 | checkpoint = None 9 | last_ckpt_path = None 10 | ckpt_paths = get_all_ckpts(work_dir, steps) 11 | if len(ckpt_paths) > 0: 12 | last_ckpt_path = ckpt_paths[0] 13 | checkpoint = torch.load(last_ckpt_path, map_location='cpu') 14 | return checkpoint, last_ckpt_path 15 | 16 | 17 | def get_all_ckpts(work_dir, steps=None): 18 | if steps is None: 19 | ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt' 20 | else: 21 | ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt' 22 | return sorted(glob.glob(ckpt_path_pattern), 23 | key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0])) 24 | 25 | 26 | def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True): 27 | if os.path.isfile(ckpt_base_dir): 28 | base_dir = os.path.dirname(ckpt_base_dir) 29 | ckpt_path = ckpt_base_dir 30 | checkpoint = torch.load(ckpt_base_dir, map_location='cpu') 31 | else: 32 | base_dir = ckpt_base_dir 33 | checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir) 34 | if checkpoint is not None: 35 | state_dict = checkpoint["state_dict"] 36 | if len([k for k in state_dict.keys() if '.' in k]) > 0: 37 | state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items() 38 | if k.startswith(f'{model_name}.')} 39 | else: 40 | if '.' not in model_name: 41 | state_dict = state_dict[model_name] 42 | else: 43 | base_model_name = model_name.split('.')[0] 44 | rest_model_name = model_name[len(base_model_name) + 1:] 45 | state_dict = { 46 | k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items() 47 | if k.startswith(f'{rest_model_name}.')} 48 | if not strict: 49 | cur_model_state_dict = cur_model.state_dict() 50 | unmatched_keys = [] 51 | for key, param in state_dict.items(): 52 | if key in cur_model_state_dict: 53 | new_param = cur_model_state_dict[key] 54 | if new_param.shape != param.shape: 55 | unmatched_keys.append(key) 56 | print("| Unmatched keys: ", key, new_param.shape, param.shape) 57 | for key in unmatched_keys: 58 | del state_dict[key] 59 | # print(state_dict) 60 | cur_model.load_state_dict(state_dict, strict=strict) 61 | print(f"| load '{model_name}' from '{ckpt_path}'.") 62 | else: 63 | e_msg = f"| ckpt not found in {base_dir}." 
64 | if force: 65 | assert False, e_msg 66 | else: 67 | print(e_msg) 68 | -------------------------------------------------------------------------------- /vocoder/hifigan/hifigan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vocoder.hifigan.modules.hifigan import HifiGanGenerator, CodeUpsampleHifiGanGenerator 3 | from utils.commons.ckpt_utils import load_ckpt 4 | from utils.commons.hparams import set_hparams, hparams 5 | 6 | class HifiGAN(torch.nn.Module): 7 | def __init__(self, vocoder_ckpt, device=None): 8 | super().__init__() 9 | # base_dir = hparams['vocoder_ckpt'] 10 | base_dir = vocoder_ckpt # ckpt dir 11 | config_path = f'{base_dir}/config.yaml' 12 | self.config = config = set_hparams(config_path, global_hparams=False, print_hparams=False) 13 | self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | self.model = HifiGanGenerator(config) 15 | load_ckpt(self.model, base_dir, 'model_gen') 16 | self.model.to(self.device) 17 | self.model.eval() 18 | 19 | def spec2wav(self, mel, **kwargs): 20 | device = self.device 21 | with torch.no_grad(): 22 | c = torch.FloatTensor(mel).unsqueeze(0).to(device) # [1, T, C] 23 | c = c.transpose(2, 1) 24 | y = self.model(c).view(-1) 25 | wav_out = y.cpu().numpy() 26 | return wav_out 27 | 28 | def __call__(self, mel): 29 | return self.spec2wav(mel) 30 | 31 | def vocode(self, mel): 32 | assert len(mel.shape) == 2 33 | device = self.device 34 | with torch.no_grad(): 35 | c = torch.FloatTensor(mel).unsqueeze(0).to(device) 36 | # print('mel.shape', c.shape) 37 | if c.shape[1] != 80: 38 | c = c.transpose(2, 1) 39 | # print('c.shape', c.shape) 40 | y = self.model(c).view(-1) 41 | wav_out = y.cpu().numpy() 42 | return wav_out 43 | 44 | class CodeUpsampleHifiGan(torch.nn.Module): 45 | def __init__(self, vocoder_ckpt, device=None): 46 | super(CodeUpsampleHifiGan, self).__init__() 47 | # base_dir = hparams['vocoder_ckpt'] 48 | base_dir = vocoder_ckpt # ckpt dir 49 | config_path = f'{base_dir}/config.yaml' 50 | self.config = config = set_hparams(config_path, global_hparams=False, print_hparams=False) 51 | self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu") 52 | self.model = CodeUpsampleHifiGanGenerator(config) 53 | load_ckpt(self.model, base_dir, 'model_gen') 54 | self.model.to(self.device) 55 | self.model.eval() 56 | 57 | def spec2wav(self, mel, **kwargs): 58 | # mel (T, C) 59 | device = self.device 60 | with torch.no_grad(): 61 | if not isinstance(mel, torch.Tensor): 62 | mel = torch.LongTensor(mel) 63 | c = mel.unsqueeze(0) 64 | if device != mel.device: 65 | c = c.to(device) # [1, T, C] 66 | c = c.transpose(2, 1) # [1, C, T] 67 | y = self.model(c).view(-1) 68 | wav_out = y.cpu().numpy() 69 | return wav_out 70 | 71 | def __call__(self, mel): 72 | return self.spec2wav(mel) 73 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/layers/residual_stack.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Residual stack module in MelGAN.""" 7 | 8 | import torch 9 | 10 | from . 
import CausalConv1d 11 | 12 | 13 | class ResidualStack(torch.nn.Module): 14 | """Residual stack module introduced in MelGAN.""" 15 | 16 | def __init__(self, 17 | kernel_size=3, 18 | channels=32, 19 | dilation=1, 20 | bias=True, 21 | nonlinear_activation="LeakyReLU", 22 | nonlinear_activation_params={"negative_slope": 0.2}, 23 | pad="ReflectionPad1d", 24 | pad_params={}, 25 | use_causal_conv=False, 26 | ): 27 | """Initialize ResidualStack module. 28 | 29 | Args: 30 | kernel_size (int): Kernel size of dilation convolution layer. 31 | channels (int): Number of channels of convolution layers. 32 | dilation (int): Dilation factor. 33 | bias (bool): Whether to add bias parameter in convolution layers. 34 | nonlinear_activation (str): Activation function module name. 35 | nonlinear_activation_params (dict): Hyperparameters for activation function. 36 | pad (str): Padding function module name before dilated convolution layer. 37 | pad_params (dict): Hyperparameters for padding function. 38 | use_causal_conv (bool): Whether to use causal convolution. 39 | 40 | """ 41 | super(ResidualStack, self).__init__() 42 | 43 | # defile residual stack part 44 | if not use_causal_conv: 45 | assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 46 | self.stack = torch.nn.Sequential( 47 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 48 | getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params), 49 | torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias), 50 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 51 | torch.nn.Conv1d(channels, channels, 1, bias=bias), 52 | ) 53 | else: 54 | self.stack = torch.nn.Sequential( 55 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 56 | CausalConv1d(channels, channels, kernel_size, dilation=dilation, 57 | bias=bias, pad=pad, pad_params=pad_params), 58 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 59 | torch.nn.Conv1d(channels, channels, 1, bias=bias), 60 | ) 61 | 62 | # defile extra layer for skip connection 63 | self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias) 64 | 65 | def forward(self, c): 66 | """Calculate forward propagation. 67 | 68 | Args: 69 | c (Tensor): Input tensor (B, channels, T). 70 | 71 | Returns: 72 | Tensor: Output tensor (B, chennels, T). 
73 | 74 | """ 75 | return self.stack(c) + self.skip_layer(c) 76 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import kl_divergence, register_kl 3 | 4 | class DiagonalGaussianDistribution(object): 5 | def __init__(self, parameters, deterministic=False): 6 | self.parameters = parameters 7 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 8 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 9 | self.deterministic = deterministic 10 | self.std = torch.exp(0.5 * self.logvar) 11 | self.var = torch.exp(self.logvar) 12 | if self.deterministic: 13 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 14 | 15 | def sample(self): 16 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 17 | return x 18 | 19 | def kl(self, other=None): 20 | if self.deterministic: 21 | return torch.Tensor([0.]) 22 | else: 23 | sum_dim = list(range(1, len(self.mean.shape))) 24 | if other is None: 25 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 26 | + self.var - 1.0 - self.logvar, 27 | dim=sum_dim) 28 | else: 29 | return 0.5 * torch.sum( 30 | torch.pow(self.mean - other.mean, 2) / other.var 31 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 32 | dim=sum_dim) 33 | 34 | def nll(self, sample, dims=[1, 2, 3]): 35 | if self.deterministic: 36 | return torch.Tensor([0.]) 37 | logtwopi = torch.log(torch.tensor(2.0 * torch.pi)) 38 | return 0.5 * torch.sum( 39 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 40 | dim=dims) 41 | 42 | def mode(self): 43 | return self.mean 44 | 45 | # 定义 KL 散度的计算函数 46 | def kl_divergence_diag_gaussian(p, q): 47 | return p.kl(q) 48 | 49 | # 注册自定义的 KL 散度计算函数 50 | register_kl(DiagonalGaussianDistribution, DiagonalGaussianDistribution)(kl_divergence_diag_gaussian) 51 | 52 | 53 | 54 | def normal_kl(mean1, logvar1, mean2, logvar2): 55 | """ 56 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 57 | Compute the KL divergence between two gaussians. 58 | Shapes are automatically broadcasted, so batches can be compared to 59 | scalars, among other use cases. 60 | """ 61 | tensor = None 62 | for obj in (mean1, logvar1, mean2, logvar2): 63 | if isinstance(obj, torch.Tensor): 64 | tensor = obj 65 | break 66 | assert tensor is not None, "at least one argument must be a Tensor" 67 | 68 | # Force variances to be Tensors. Broadcasting helps convert scalars to 69 | # Tensors, but it does not work for torch.exp(). 
70 | logvar1, logvar2 = [ 71 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 72 | for x in (logvar1, logvar2) 73 | ] 74 | 75 | return 0.5 * ( 76 | -1.0 77 | + logvar2 78 | - logvar1 79 | + torch.exp(logvar1 - logvar2) 80 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 81 | ) 82 | -------------------------------------------------------------------------------- /preprocess/preprocess.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | from tqdm import tqdm 4 | import torchaudio 5 | from typing import Any, Dict, List, Optional, Union 6 | from pathlib import Path 7 | import pandas as pd 8 | import random 9 | import os 10 | import csv 11 | import ast 12 | import librosa 13 | 14 | def save_df_to_tsv(dataframe, path: Union[str, Path]): 15 | _path = path if isinstance(path, str) else path.as_posix() 16 | dataframe.to_csv( 17 | _path, 18 | sep="\t", 19 | header=True, 20 | index=False, 21 | encoding="utf-8", 22 | escapechar="\\", 23 | quoting=csv.QUOTE_NONE, 24 | ) 25 | 26 | def generate(): 27 | inputs = ['/root/autodl-tmp/data/manifests/msd_prompts/crawl_2_concat_msd_prompts_refined.tsv', 28 | '/root/autodl-tmp/data/manifests/msd_prompts/crawl_new_concat_msd_prompts_refined.tsv', 29 | '/root/autodl-tmp/data/manifests/msd_prompts/yt_crawl_msd_prompts_refined.tsv'] 30 | MANIFEST_COLUMNS = ["name", "dataset", "audio_path", "mel_path"] 31 | 32 | def items_generator(input_file): 33 | with open(input_file, encoding='utf-8') as f: 34 | reader = csv.DictReader( 35 | f, 36 | delimiter="\t", 37 | quotechar=None, 38 | doublequote=False, 39 | lineterminator="\n", 40 | quoting=csv.QUOTE_NONE, 41 | ) 42 | for item in tqdm(reader): 43 | yield dict(item) 44 | 45 | skip = 0 46 | manifest = {c: [] for c in MANIFEST_COLUMNS} 47 | 48 | count=0 49 | for input_file in inputs: 50 | for i, item in enumerate(items_generator(input_file)): 51 | parts = item['item_name'].split("") 52 | if parts[0] == 'yt_crawl': 53 | parts[0] = 'yt_song_crawler' 54 | wav_path = f"/root/autodl-tmp/data/{parts[0]}_sp_demix_24k/{parts[1]}/[{parts[3]}]{parts[2]}.accomp.wav" 55 | if not os.path.exists(wav_path): 56 | # print(wav_path) 57 | skip += 1 58 | continue 59 | if not os.path.exists(wav_path.replace('accomp','vocal')): 60 | # print(wav_path) 61 | skip += 1 62 | continue 63 | count+=1 64 | mel_path = f"/root/autodl-tmp/data/{parts[0]}_sp_demix_24k/{parts[1]}/[{parts[3]}]{parts[2]}.accomp_mel.npy" 65 | 66 | dur= librosa.get_duration(filename=wav_path) 67 | caption= ast.literal_eval(item['caption']) 68 | caption=''.join(caption) 69 | # for t,cap in enumerate(caption): 70 | manifest["name"].append(str(item['item_name'])) 71 | manifest["dataset"].append(parts[0]) 72 | manifest["audio_path"].append(wav_path) 73 | manifest["mel_path"].append(mel_path) 74 | manifest["name"].append(str(item['item_name']+'vocal')) 75 | manifest["dataset"].append(parts[0]) 76 | manifest["audio_path"].append(wav_path.replace('accomp','vocal')) 77 | manifest["mel_path"].append(mel_path.replace('accomp','vocal')) 78 | print(count) 79 | 80 | print(f"skip: {skip}") 81 | save_df_to_tsv(pd.DataFrame.from_dict(manifest), f'/root/autodl-tmp/vocal2music/data/music24k/music.tsv') 82 | 83 | if __name__ == '__main__': 84 | generate() 85 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/clap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import 
torch.nn.functional as F 4 | from torch import nn 5 | from transformers import AutoModel 6 | from .audio import get_audio_encoder 7 | 8 | class Projection(nn.Module): 9 | def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None: 10 | super().__init__() 11 | self.linear1 = nn.Linear(d_in, d_out, bias=False) 12 | self.linear2 = nn.Linear(d_out, d_out, bias=False) 13 | self.layer_norm = nn.LayerNorm(d_out) 14 | self.drop = nn.Dropout(p) 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | embed1 = self.linear1(x) 18 | embed2 = self.drop(self.linear2(F.gelu(embed1))) 19 | embeds = self.layer_norm(embed1 + embed2) 20 | return embeds 21 | 22 | class AudioEncoder(nn.Module): 23 | def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int, 24 | hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None: 25 | super().__init__() 26 | 27 | audio_encoder = get_audio_encoder(audioenc_name) 28 | 29 | self.base = audio_encoder( 30 | sample_rate, window_size, 31 | hop_size, mel_bins, fmin, fmax, 32 | classes_num, d_in) 33 | 34 | self.projection = Projection(d_in, d_out) 35 | 36 | def forward(self, x): 37 | out_dict = self.base(x) 38 | audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output'] 39 | projected_vec = self.projection(audio_features) 40 | return projected_vec, audio_classification_output 41 | 42 | class TextEncoder(nn.Module): 43 | def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None: 44 | super().__init__() 45 | self.base = AutoModel.from_pretrained(text_model) 46 | self.projection = Projection(transformer_embed_dim, d_out) 47 | 48 | def forward(self, x): 49 | out = self.base(**x)[0] 50 | out = out[:, 0, :] # get CLS token output 51 | projected_vec = self.projection(out) 52 | return projected_vec 53 | 54 | class CLAP(nn.Module): 55 | def __init__(self, 56 | # audio 57 | audioenc_name: str, 58 | sample_rate: int, 59 | window_size: int, 60 | hop_size: int, 61 | mel_bins: int, 62 | fmin: int, 63 | fmax: int, 64 | classes_num: int, 65 | out_emb: int, 66 | # text 67 | text_model: str, 68 | transformer_embed_dim: int, 69 | # common 70 | d_proj: int, 71 | ): 72 | super().__init__() 73 | 74 | 75 | self.audio_encoder = AudioEncoder( 76 | audioenc_name, out_emb, d_proj, 77 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num) 78 | 79 | self.caption_encoder = TextEncoder( 80 | d_proj, text_model, transformer_embed_dim 81 | ) 82 | 83 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 84 | 85 | def forward(self, audio, text): 86 | audio_embed, _ = self.audio_encoder(audio) 87 | caption_embed = self.caption_encoder(text) 88 | 89 | return caption_embed, audio_embed, self.logit_scale.exp() -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/optimizers/radam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """RAdam optimizer. 4 | 5 | This code is drived from https://github.com/LiyuanLucasLiu/RAdam. 
6 | """ 7 | 8 | import math 9 | import torch 10 | 11 | from torch.optim.optimizer import Optimizer 12 | 13 | 14 | class RAdam(Optimizer): 15 | """Rectified Adam optimizer.""" 16 | 17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 18 | """Initilize RAdam optimizer.""" 19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 20 | self.buffer = [[None, None, None] for ind in range(10)] 21 | super(RAdam, self).__init__(params, defaults) 22 | 23 | def __setstate__(self, state): 24 | """Set state.""" 25 | super(RAdam, self).__setstate__(state) 26 | 27 | def step(self, closure=None): 28 | """Run one step.""" 29 | loss = None 30 | if closure is not None: 31 | loss = closure() 32 | 33 | for group in self.param_groups: 34 | 35 | for p in group['params']: 36 | if p.grad is None: 37 | continue 38 | grad = p.grad.data.float() 39 | if grad.is_sparse: 40 | raise RuntimeError('RAdam does not support sparse gradients') 41 | 42 | p_data_fp32 = p.data.float() 43 | 44 | state = self.state[p] 45 | 46 | if len(state) == 0: 47 | state['step'] = 0 48 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 49 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 50 | else: 51 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 52 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 53 | 54 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 55 | beta1, beta2 = group['betas'] 56 | 57 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 58 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 59 | 60 | state['step'] += 1 61 | buffered = self.buffer[int(state['step'] % 10)] 62 | if state['step'] == buffered[0]: 63 | N_sma, step_size = buffered[1], buffered[2] 64 | else: 65 | buffered[0] = state['step'] 66 | beta2_t = beta2 ** state['step'] 67 | N_sma_max = 2 / (1 - beta2) - 1 68 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 69 | buffered[1] = N_sma 70 | 71 | # more conservative since it's an approximated value 72 | if N_sma >= 5: 73 | step_size = math.sqrt( 74 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) # NOQA 75 | else: 76 | step_size = 1.0 / (1 - beta1 ** state['step']) 77 | buffered[2] = step_size 78 | 79 | if group['weight_decay'] != 0: 80 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 81 | 82 | # more conservative since it's an approximated value 83 | if N_sma >= 5: 84 | denom = exp_avg_sq.sqrt().add_(group['eps']) 85 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 86 | else: 87 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 88 | 89 | p.data.copy_(p_data_fp32) 90 | 91 | return loss 92 | -------------------------------------------------------------------------------- /ldm/data/joinaudiodataset_624.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import torch 4 | import logging 5 | import pandas as pd 6 | import glob 7 | logger = logging.getLogger(f'main.{__name__}') 8 | 9 | sys.path.insert(0, '.') # nopep8 10 | 11 | class JoinManifestSpecs(torch.utils.data.Dataset): 12 | def __init__(self, split, spec_dir_path, mel_num=None, spec_crop_len=None,drop=0,**kwargs): 13 | super().__init__() 14 | self.split = split 15 | self.batch_max_length = spec_crop_len 16 | self.batch_min_length = 50 17 | self.mel_num = mel_num 18 | self.drop = drop 19 | manifest_files = [] 20 | for dir_path in spec_dir_path.split(','): 21 | 
manifest_files += glob.glob(f'{dir_path}/**/*.tsv',recursive=True) 22 | df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files] 23 | df = pd.concat(df_list,ignore_index=True) 24 | 25 | if split == 'train': 26 | self.dataset = df.iloc[100:] 27 | elif split == 'valid' or split == 'val': 28 | self.dataset = df.iloc[:100] 29 | elif split == 'test': 30 | df = self.add_name_num(df) 31 | self.dataset = df 32 | else: 33 | raise ValueError(f'Unknown split {split}') 34 | self.dataset.reset_index(inplace=True) 35 | print('dataset len:', len(self.dataset)) 36 | 37 | def add_name_num(self,df): 38 | """each file may have different caption, we add num to filename to identify each audio-caption pair""" 39 | name_count_dict = {} 40 | change = [] 41 | for t in df.itertuples(): 42 | name = getattr(t,'name') 43 | if name in name_count_dict: 44 | name_count_dict[name] += 1 45 | else: 46 | name_count_dict[name] = 0 47 | change.append((t[0],name_count_dict[name])) 48 | for t in change: 49 | df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}' 50 | return df 51 | 52 | def __getitem__(self, idx): 53 | data = self.dataset.iloc[idx] 54 | item = {} 55 | try: 56 | spec = np.load(data['mel_path']) # mel spec [80, 624] 57 | except: 58 | mel_path = data['mel_path'] 59 | print(f'corrupted:{mel_path}') 60 | spec = np.zeros((self.mel_num,self.batch_max_length)).astype(np.float32) 61 | 62 | if spec.shape[1] < self.batch_max_length: 63 | # spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1]))) # [80, 624] 64 | spec = np.tile(spec, reps=(self.batch_max_length//spec.shape[1])+1) 65 | 66 | # 随机选择 67 | if spec.shape[1] > self.batch_max_length: 68 | start = np.random.randint(spec.shape[1] - self.batch_max_length) 69 | spec = spec[:, start: start + self.batch_max_length] 70 | 71 | item['image'] = spec[:,:self.batch_max_length] 72 | # p = np.random.uniform(0,1) 73 | # if p > self.drop: 74 | # item["caption"] = data['caption'] 75 | # else: 76 | # item["caption"] = "" 77 | if self.split == 'test': 78 | item['f_name'] = data['name'] 79 | return item 80 | 81 | def __len__(self): 82 | return len(self.dataset) 83 | 84 | 85 | class JoinSpecsTrain(JoinManifestSpecs): 86 | def __init__(self, specs_dataset_cfg): 87 | super().__init__('train', **specs_dataset_cfg) 88 | 89 | class JoinSpecsValidation(JoinManifestSpecs): 90 | def __init__(self, specs_dataset_cfg): 91 | super().__init__('valid', **specs_dataset_cfg) 92 | 93 | class JoinSpecsTest(JoinManifestSpecs): 94 | def __init__(self, specs_dataset_cfg): 95 | super().__init__('test', **specs_dataset_cfg) 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | if 'sinc' in dir(torch): 10 | sinc = torch.sinc 11 | else: 12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 13 | # https://adefossez.github.io/julius/julius/core.html 14 | # LICENSE is in incl_licenses directory. 15 | def sinc(x: torch.Tensor): 16 | """ 17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 
19 | """ 20 | return torch.where(x == 0, 21 | torch.tensor(1., device=x.device, dtype=x.dtype), 22 | torch.sin(math.pi * x) / math.pi / x) 23 | 24 | 25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 26 | # https://adefossez.github.io/julius/julius/lowpass.html 27 | # LICENSE is in incl_licenses directory. 28 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 29 | even = (kernel_size % 2 == 0) 30 | half_size = kernel_size // 2 31 | 32 | #For kaiser window 33 | delta_f = 4 * half_width 34 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 35 | if A > 50.: 36 | beta = 0.1102 * (A - 8.7) 37 | elif A >= 21.: 38 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) 39 | else: 40 | beta = 0. 41 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 42 | 43 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 44 | if even: 45 | time = (torch.arange(-half_size, half_size) + 0.5) 46 | else: 47 | time = torch.arange(kernel_size) - half_size 48 | if cutoff == 0: 49 | filter_ = torch.zeros_like(time) 50 | else: 51 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 52 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 53 | # of the constant component in the input signal. 54 | filter_ /= filter_.sum() 55 | filter = filter_.view(1, 1, kernel_size) 56 | 57 | return filter 58 | 59 | 60 | class LowPassFilter1d(nn.Module): 61 | def __init__(self, 62 | cutoff=0.5, 63 | half_width=0.6, 64 | stride: int = 1, 65 | padding: bool = True, 66 | padding_mode: str = 'replicate', 67 | kernel_size: int = 12): 68 | # kernel_size should be even number for stylegan3 setup, 69 | # in this implementation, odd number is also possible. 70 | super().__init__() 71 | if cutoff < -0.: 72 | raise ValueError("Minimum cutoff must be larger than zero.") 73 | if cutoff > 0.5: 74 | raise ValueError("A cutoff above 0.5 does not make sense.") 75 | self.kernel_size = kernel_size 76 | self.even = (kernel_size % 2 == 0) 77 | self.pad_left = kernel_size // 2 - int(self.even) 78 | self.pad_right = kernel_size // 2 79 | self.stride = stride 80 | self.padding = padding 81 | self.padding_mode = padding_mode 82 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 83 | self.register_buffer("filter", filter) 84 | 85 | #input [B, C, T] 86 | def forward(self, x): 87 | _, C, _ = x.shape 88 | 89 | if self.padding: 90 | x = F.pad(x, (self.pad_left, self.pad_right), 91 | mode=self.padding_mode) 92 | out = F.conv1d(x, self.filter.expand(C, -1, -1), 93 | stride=self.stride, groups=C) 94 | 95 | return out -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | import librosa 8 | import torch 9 | 10 | from vocoder.parallel_wavegan.losses import LogSTFTMagnitudeLoss, SpectralConvergengeLoss, stft 11 | 12 | 13 | class STFTLoss(torch.nn.Module): 14 | """STFT loss module.""" 15 | 16 | def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", 17 | use_mel_loss=False): 18 | """Initialize STFT loss module.""" 19 | super(STFTLoss, self).__init__() 20 | self.fft_size = fft_size 21 | self.shift_size = shift_size 22 | self.win_length = win_length 23 | self.window = getattr(torch, 
window)(win_length) 24 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 25 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 26 | self.use_mel_loss = use_mel_loss 27 | self.mel_basis = None 28 | 29 | def forward(self, x, y): 30 | """Calculate forward propagation. 31 | 32 | Args: 33 | x (Tensor): Predicted signal (B, T). 34 | y (Tensor): Groundtruth signal (B, T). 35 | 36 | Returns: 37 | Tensor: Spectral convergence loss value. 38 | Tensor: Log STFT magnitude loss value. 39 | 40 | """ 41 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 42 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 43 | if self.use_mel_loss: 44 | if self.mel_basis is None: 45 | self.mel_basis = torch.from_numpy(librosa.filters.mel(sr=22050, n_fft=self.fft_size, n_mels=80)).cuda().T  # librosa>=0.10 (pinned in requirements) requires keyword arguments here 46 | x_mag = x_mag @ self.mel_basis 47 | y_mag = y_mag @ self.mel_basis 48 | 49 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 50 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 51 | 52 | return sc_loss, mag_loss 53 | 54 | 55 | class MultiResolutionSTFTLoss(torch.nn.Module): 56 | """Multi resolution STFT loss module.""" 57 | 58 | def __init__(self, 59 | fft_sizes=[1024, 2048, 512], 60 | hop_sizes=[120, 240, 50], 61 | win_lengths=[600, 1200, 240], 62 | window="hann_window", 63 | use_mel_loss=False): 64 | """Initialize Multi resolution STFT loss module. 65 | 66 | Args: 67 | fft_sizes (list): List of FFT sizes. 68 | hop_sizes (list): List of hop sizes. 69 | win_lengths (list): List of window lengths. 70 | window (str): Window function type. 71 | 72 | """ 73 | super(MultiResolutionSTFTLoss, self).__init__() 74 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 75 | self.stft_losses = torch.nn.ModuleList() 76 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 77 | self.stft_losses += [STFTLoss(fs, ss, wl, window, use_mel_loss)] 78 | 79 | def forward(self, x, y): 80 | """Calculate forward propagation. 81 | 82 | Args: 83 | x (Tensor): Predicted signal (B, T). 84 | y (Tensor): Groundtruth signal (B, T). 85 | 86 | Returns: 87 | Tensor: Multi resolution spectral convergence loss value. 88 | Tensor: Multi resolution log STFT magnitude loss value.
89 | 90 | """ 91 | sc_loss = 0.0 92 | mag_loss = 0.0 93 | for f in self.stft_losses: 94 | sc_l, mag_l = f(x, y) 95 | sc_loss += sc_l 96 | mag_loss += mag_l 97 | sc_loss /= len(self.stft_losses) 98 | mag_loss /= len(self.stft_losses) 99 | 100 | return sc_loss, mag_loss 101 | -------------------------------------------------------------------------------- /configs/vocal2music.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 3.0e-06 3 | target: ldm.models.diffusion.cfm1_audio.CFM 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | mel_dim: 20 13 | mel_length: 750 14 | channels: 0 15 | cond_stage_trainable: True 16 | conditioning_key: hybrid 17 | monitor: val/loss_simple_ema 18 | scale_by_std: true 19 | use_ema: false 20 | scheduler_config: 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: 24 | - 10000 25 | cycle_lengths: 26 | - 10000000000000 27 | f_start: 28 | - 1.0e-06 29 | f_max: 30 | - 1.0 31 | f_min: 32 | - 1.0 33 | unet_config: 34 | target: ldm.modules.diffusionmodules.vocal2music_moe.TxtFlagLargeImprovedDiTV2 35 | params: 36 | in_channels: 20 37 | ori_dim: 1024 38 | context_dim: 768 39 | hidden_size: 768 40 | num_heads: 8 41 | depth: 4 42 | max_len: 1500 43 | num_experts: 4 44 | 45 | first_stage_config: 46 | target: ldm.models.autoencoder1d.AutoencoderKL 47 | params: 48 | embed_dim: 20 49 | monitor: val/rec_loss 50 | ckpt_path: logs/2024-04-21T17-06-16_ae_accomp/bkup_ckpts/last.ckpt 51 | ddconfig: 52 | double_z: true 53 | in_channels: 80 54 | out_ch: 80 55 | z_channels: 20 56 | kernel_size: 5 57 | ch: 384 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | attn_layers: 64 | - 3 65 | down_layers: 66 | - 0 67 | dropout: 0.0 68 | lossconfig: 69 | target: torch.nn.Identity 70 | cond_stage_config: 71 | target: ldm.modules.encoders.modules.FrozenTextVocalEmbedder 72 | params: 73 | version: useful_ckpts/flan-t5-large 74 | max_length: 80 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.AudioLogger 80 | params: 81 | sample_rate: 24000 82 | for_specs: true 83 | increase_log_steps: false 84 | batch_frequency: 5000 85 | max_images: 8 86 | melvmin: -5 87 | melvmax: 1.5 88 | vocoder_cfg: 89 | target: vocoder.hifigan.hifigan.HifiGAN 90 | params: 91 | vocoder_ckpt: useful_ckpts/hifigan 92 | trainer: 93 | benchmark: True 94 | gradient_clip_val: 1.0 95 | replace_sampler_ddp: false 96 | modelcheckpoint: 97 | params: 98 | monitor: epoch 99 | mode: max 100 | save_top_k: 10 101 | every_n_epochs: 10 102 | 103 | data: 104 | target: main.SpectrogramDataModuleFromConfig 105 | params: 106 | batch_size: 8 107 | num_workers: 16 108 | main_spec_dir_path: '/root/autodl-tmp/data/manifests/vocal_to_accomp/train/v2c_0905' 109 | other_spec_dir_path: '' 110 | mel_num: 80 111 | drop: 0.1 112 | spec_crop_len: 1500 113 | other_condition: '/root/autodl-tmp/data/manifests/vocal_to_accomp/train/v2c_0905/midi.npy' # codec npy path!!! 
114 | train: 115 | target: ldm.data.vocal2accomp_musical_dataset.JoinSpecsTrain 116 | params: 117 | specs_dataset_cfg: 118 | validation: 119 | target: ldm.data.vocal2accomp_musical_dataset.JoinSpecsValidation 120 | params: 121 | specs_dataset_cfg: 122 | 123 | test_dataset: 124 | target: ldm.data.tsvdataset.TSVDatasetStruct 125 | params: 126 | tsv_path: audiocaps_test_16000_struct2.tsv 127 | spec_crop_len: 1500 128 | 129 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/layers/tf_layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 MINH ANH (@dathudeptrai) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Tensorflow Layer modules complatible with pytorch.""" 7 | 8 | import tensorflow as tf 9 | 10 | 11 | class TFReflectionPad1d(tf.keras.layers.Layer): 12 | """Tensorflow ReflectionPad1d module.""" 13 | 14 | def __init__(self, padding_size): 15 | """Initialize TFReflectionPad1d module. 16 | 17 | Args: 18 | padding_size (int): Padding size. 19 | 20 | """ 21 | super(TFReflectionPad1d, self).__init__() 22 | self.padding_size = padding_size 23 | 24 | @tf.function 25 | def call(self, x): 26 | """Calculate forward propagation. 27 | 28 | Args: 29 | x (Tensor): Input tensor (B, T, 1, C). 30 | 31 | Returns: 32 | Tensor: Padded tensor (B, T + 2 * padding_size, 1, C). 33 | 34 | """ 35 | return tf.pad(x, [[0, 0], [self.padding_size, self.padding_size], [0, 0], [0, 0]], "REFLECT") 36 | 37 | 38 | class TFConvTranspose1d(tf.keras.layers.Layer): 39 | """Tensorflow ConvTranspose1d module.""" 40 | 41 | def __init__(self, channels, kernel_size, stride, padding): 42 | """Initialize TFConvTranspose1d( module. 43 | 44 | Args: 45 | channels (int): Number of channels. 46 | kernel_size (int): kernel size. 47 | strides (int): Stride width. 48 | padding (str): Padding type ("same" or "valid"). 
49 | 50 | """ 51 | super(TFConvTranspose1d, self).__init__() 52 | self.conv1d_transpose = tf.keras.layers.Conv2DTranspose( 53 | filters=channels, 54 | kernel_size=(kernel_size, 1), 55 | strides=(stride, 1), 56 | padding=padding, 57 | ) 58 | 59 | @tf.function 60 | def call(self, x): 61 | """Calculate forward propagation. 62 | 63 | Args: 64 | x (Tensor): Input tensor (B, T, 1, C). 65 | 66 | Returns: 67 | Tensors: Output tensor (B, T', 1, C'). 68 | 69 | """ 70 | x = self.conv1d_transpose(x) 71 | return x 72 | 73 | 74 | class TFResidualStack(tf.keras.layers.Layer): 75 | """Tensorflow ResidualStack module.""" 76 | 77 | def __init__(self, 78 | kernel_size, 79 | channels, 80 | dilation, 81 | bias, 82 | nonlinear_activation, 83 | nonlinear_activation_params, 84 | padding, 85 | ): 86 | """Initialize TFResidualStack module. 87 | 88 | Args: 89 | kernel_size (int): Kernel size. 90 | channles (int): Number of channels. 91 | dilation (int): Dilation ine. 92 | bias (bool): Whether to add bias parameter in convolution layers. 93 | nonlinear_activation (str): Activation function module name. 94 | nonlinear_activation_params (dict): Hyperparameters for activation function. 95 | padding (str): Padding type ("same" or "valid"). 96 | 97 | """ 98 | super(TFResidualStack, self).__init__() 99 | self.block = [ 100 | getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params), 101 | TFReflectionPad1d(dilation), 102 | tf.keras.layers.Conv2D( 103 | filters=channels, 104 | kernel_size=(kernel_size, 1), 105 | dilation_rate=(dilation, 1), 106 | use_bias=bias, 107 | padding="valid", 108 | ), 109 | getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params), 110 | tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias) 111 | ] 112 | self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias) 113 | 114 | @tf.function 115 | def call(self, x): 116 | """Calculate forward propagation. 117 | 118 | Args: 119 | x (Tensor): Input tensor (B, T, 1, C). 120 | 121 | Returns: 122 | Tensor: Output tensor (B, T, 1, C). 123 | 124 | """ 125 | _x = tf.identity(x) 126 | for i, layer in enumerate(self.block): 127 | _x = layer(_x) 128 | shortcut = self.shortcut(x) 129 | return shortcut + _x 130 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/layers/pqmf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Pseudo QMF modules.""" 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from scipy.signal import kaiser 13 | 14 | 15 | def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0): 16 | """Design prototype filter for PQMF. 17 | 18 | This method is based on `A Kaiser window approach for the design of prototype 19 | filters of cosine modulated filterbanks`_. 20 | 21 | Args: 22 | taps (int): The number of filter taps. 23 | cutoff_ratio (float): Cut-off frequency ratio. 24 | beta (float): Beta coefficient for kaiser window. 25 | 26 | Returns: 27 | ndarray: Impluse response of prototype filter (taps + 1,). 28 | 29 | .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: 30 | https://ieeexplore.ieee.org/abstract/document/681427 31 | 32 | """ 33 | # check the arguments are valid 34 | assert taps % 2 == 0, "The number of taps mush be even number." 
35 | assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0." 36 | 37 | # make initial filter 38 | omega_c = np.pi * cutoff_ratio 39 | with np.errstate(invalid='ignore'): 40 | h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \ 41 | / (np.pi * (np.arange(taps + 1) - 0.5 * taps)) 42 | h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form 43 | 44 | # apply kaiser window 45 | w = kaiser(taps + 1, beta) 46 | h = h_i * w 47 | 48 | return h 49 | 50 | 51 | class PQMF(torch.nn.Module): 52 | """PQMF module. 53 | 54 | This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. 55 | 56 | .. _`Near-perfect-reconstruction pseudo-QMF banks`: 57 | https://ieeexplore.ieee.org/document/258122 58 | 59 | """ 60 | 61 | def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0): 62 | """Initilize PQMF module. 63 | 64 | Args: 65 | subbands (int): The number of subbands. 66 | taps (int): The number of filter taps. 67 | cutoff_ratio (float): Cut-off frequency ratio. 68 | beta (float): Beta coefficient for kaiser window. 69 | 70 | """ 71 | super(PQMF, self).__init__() 72 | 73 | # define filter coefficient 74 | h_proto = design_prototype_filter(taps, cutoff_ratio, beta) 75 | h_analysis = np.zeros((subbands, len(h_proto))) 76 | h_synthesis = np.zeros((subbands, len(h_proto))) 77 | for k in range(subbands): 78 | h_analysis[k] = 2 * h_proto * np.cos( 79 | (2 * k + 1) * (np.pi / (2 * subbands)) * 80 | (np.arange(taps + 1) - ((taps - 1) / 2)) + 81 | (-1) ** k * np.pi / 4) 82 | h_synthesis[k] = 2 * h_proto * np.cos( 83 | (2 * k + 1) * (np.pi / (2 * subbands)) * 84 | (np.arange(taps + 1) - ((taps - 1) / 2)) - 85 | (-1) ** k * np.pi / 4) 86 | 87 | # convert to tensor 88 | analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) 89 | synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0) 90 | 91 | # register coefficients as beffer 92 | self.register_buffer("analysis_filter", analysis_filter) 93 | self.register_buffer("synthesis_filter", synthesis_filter) 94 | 95 | # filter for downsampling & upsampling 96 | updown_filter = torch.zeros((subbands, subbands, subbands)).float() 97 | for k in range(subbands): 98 | updown_filter[k, k, 0] = 1.0 99 | self.register_buffer("updown_filter", updown_filter) 100 | self.subbands = subbands 101 | 102 | # keep padding info 103 | self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) 104 | 105 | def analysis(self, x): 106 | """Analysis with PQMF. 107 | 108 | Args: 109 | x (Tensor): Input tensor (B, 1, T). 110 | 111 | Returns: 112 | Tensor: Output tensor (B, subbands, T // subbands). 113 | 114 | """ 115 | x = F.conv1d(self.pad_fn(x), self.analysis_filter) 116 | return F.conv1d(x, self.updown_filter, stride=self.subbands) 117 | 118 | def synthesis(self, x): 119 | """Synthesis with PQMF. 120 | 121 | Args: 122 | x (Tensor): Input tensor (B, subbands, T // subbands). 123 | 124 | Returns: 125 | Tensor: Output tensor (B, 1, T). 126 | 127 | """ 128 | x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands) 129 | return F.conv1d(self.pad_fn(x), self.synthesis_filter) 130 | -------------------------------------------------------------------------------- /vocoder/bigvgan/activations.py: -------------------------------------------------------------------------------- 1 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 2 | # LICENSE is in incl_licenses directory. 
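# (Editor's sketch, not part of the original file.) Both activations defined below act
# elementwise on (B, C, T) feature maps with one learnable alpha (and, for SnakeBeta, beta)
# per channel: Snake(x) = x + (1/alpha) * sin^2(alpha * x) and
# SnakeBeta(x) = x + (1/beta) * sin^2(alpha * x). A minimal, assumed usage:
#   act = SnakeBeta(in_features=512, alpha_logscale=True)
#   y = act(torch.randn(2, 512, 100))  # output keeps the (B, C, T) input shape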
3 | 4 | import torch 5 | from torch import nn, sin, pow 6 | from torch.nn import Parameter 7 | 8 | 9 | class Snake(nn.Module): 10 | ''' 11 | Implementation of a sine-based periodic activation function 12 | Shape: 13 | - Input: (B, C, T) 14 | - Output: (B, C, T), same shape as the input 15 | Parameters: 16 | - alpha - trainable parameter 17 | References: 18 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 19 | https://arxiv.org/abs/2006.08195 20 | Examples: 21 | >>> a1 = snake(256) 22 | >>> x = torch.randn(256) 23 | >>> x = a1(x) 24 | ''' 25 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 26 | ''' 27 | Initialization. 28 | INPUT: 29 | - in_features: shape of the input 30 | - alpha: trainable parameter 31 | alpha is initialized to 1 by default, higher values = higher-frequency. 32 | alpha will be trained along with the rest of your model. 33 | ''' 34 | super(Snake, self).__init__() 35 | self.in_features = in_features 36 | 37 | # initialize alpha 38 | self.alpha_logscale = alpha_logscale 39 | if self.alpha_logscale: # log scale alphas initialized to zeros 40 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 41 | else: # linear scale alphas initialized to ones 42 | self.alpha = Parameter(torch.ones(in_features) * alpha) 43 | 44 | self.alpha.requires_grad = alpha_trainable 45 | 46 | self.no_div_by_zero = 0.000000001 47 | 48 | def forward(self, x): 49 | ''' 50 | Forward pass of the function. 51 | Applies the function to the input elementwise. 52 | Snake ∶= x + 1/a * sin^2 (xa) 53 | ''' 54 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 55 | if self.alpha_logscale: 56 | alpha = torch.exp(alpha) 57 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 58 | 59 | return x 60 | 61 | 62 | class SnakeBeta(nn.Module): 63 | ''' 64 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 65 | Shape: 66 | - Input: (B, C, T) 67 | - Output: (B, C, T), same shape as the input 68 | Parameters: 69 | - alpha - trainable parameter that controls frequency 70 | - beta - trainable parameter that controls magnitude 71 | References: 72 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 73 | https://arxiv.org/abs/2006.08195 74 | Examples: 75 | >>> a1 = snakebeta(256) 76 | >>> x = torch.randn(256) 77 | >>> x = a1(x) 78 | ''' 79 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 80 | ''' 81 | Initialization. 82 | INPUT: 83 | - in_features: shape of the input 84 | - alpha - trainable parameter that controls frequency 85 | - beta - trainable parameter that controls magnitude 86 | alpha is initialized to 1 by default, higher values = higher-frequency. 87 | beta is initialized to 1 by default, higher values = higher-magnitude. 88 | alpha will be trained along with the rest of your model. 
89 | ''' 90 | super(SnakeBeta, self).__init__() 91 | self.in_features = in_features 92 | 93 | # initialize alpha 94 | self.alpha_logscale = alpha_logscale 95 | if self.alpha_logscale: # log scale alphas initialized to zeros 96 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 97 | self.beta = Parameter(torch.zeros(in_features) * alpha) 98 | else: # linear scale alphas initialized to ones 99 | self.alpha = Parameter(torch.ones(in_features) * alpha) 100 | self.beta = Parameter(torch.ones(in_features) * alpha) 101 | 102 | self.alpha.requires_grad = alpha_trainable 103 | self.beta.requires_grad = alpha_trainable 104 | 105 | self.no_div_by_zero = 0.000000001 106 | 107 | def forward(self, x): 108 | ''' 109 | Forward pass of the function. 110 | Applies the function to the input elementwise. 111 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 112 | ''' 113 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 114 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 115 | if self.alpha_logscale: 116 | alpha = torch.exp(alpha) 117 | beta = torch.exp(beta) 118 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 119 | 120 | return x -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/layers/residual_block.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Residual block module in WaveNet. 4 | 5 | This code is modified from https://github.com/r9y9/wavenet_vocoder. 6 | 7 | """ 8 | 9 | import math 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | class Conv1d(torch.nn.Conv1d): 16 | """Conv1d module with customized initialization.""" 17 | 18 | def __init__(self, *args, **kwargs): 19 | """Initialize Conv1d module.""" 20 | super(Conv1d, self).__init__(*args, **kwargs) 21 | 22 | def reset_parameters(self): 23 | """Reset parameters.""" 24 | torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") 25 | if self.bias is not None: 26 | torch.nn.init.constant_(self.bias, 0.0) 27 | 28 | 29 | class Conv1d1x1(Conv1d): 30 | """1x1 Conv1d with customized initialization.""" 31 | 32 | def __init__(self, in_channels, out_channels, bias): 33 | """Initialize 1x1 Conv1d module.""" 34 | super(Conv1d1x1, self).__init__(in_channels, out_channels, 35 | kernel_size=1, padding=0, 36 | dilation=1, bias=bias) 37 | 38 | 39 | class ResidualBlock(torch.nn.Module): 40 | """Residual block module in WaveNet.""" 41 | 42 | def __init__(self, 43 | kernel_size=3, 44 | residual_channels=64, 45 | gate_channels=128, 46 | skip_channels=64, 47 | aux_channels=80, 48 | dropout=0.0, 49 | dilation=1, 50 | bias=True, 51 | use_causal_conv=False 52 | ): 53 | """Initialize ResidualBlock module. 54 | 55 | Args: 56 | kernel_size (int): Kernel size of dilation convolution layer. 57 | residual_channels (int): Number of channels for residual connection. 58 | skip_channels (int): Number of channels for skip connection. 59 | aux_channels (int): Local conditioning channels i.e. auxiliary input dimension. 60 | dropout (float): Dropout probability. 61 | dilation (int): Dilation factor. 62 | bias (bool): Whether to add bias parameter in convolution layers. 63 | use_causal_conv (bool): Whether to use use_causal_conv or non-use_causal_conv convolution. 
64 | 65 | """ 66 | super(ResidualBlock, self).__init__() 67 | self.dropout = dropout 68 | # no future time stamps available 69 | if use_causal_conv: 70 | padding = (kernel_size - 1) * dilation 71 | else: 72 | assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 73 | padding = (kernel_size - 1) // 2 * dilation 74 | self.use_causal_conv = use_causal_conv 75 | 76 | # dilation conv 77 | self.conv = Conv1d(residual_channels, gate_channels, kernel_size, 78 | padding=padding, dilation=dilation, bias=bias) 79 | 80 | # local conditioning 81 | if aux_channels > 0: 82 | self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) 83 | else: 84 | self.conv1x1_aux = None 85 | 86 | # conv output is split into two groups 87 | gate_out_channels = gate_channels // 2 88 | self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) 89 | self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias) 90 | 91 | def forward(self, x, c): 92 | """Calculate forward propagation. 93 | 94 | Args: 95 | x (Tensor): Input tensor (B, residual_channels, T). 96 | c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T). 97 | 98 | Returns: 99 | Tensor: Output tensor for residual connection (B, residual_channels, T). 100 | Tensor: Output tensor for skip connection (B, skip_channels, T). 101 | 102 | """ 103 | residual = x 104 | x = F.dropout(x, p=self.dropout, training=self.training) 105 | x = self.conv(x) 106 | 107 | # remove future time steps if use_causal_conv conv 108 | x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x 109 | 110 | # split into two part for gated activation 111 | splitdim = 1 112 | xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) 113 | 114 | # local conditioning 115 | if c is not None: 116 | assert self.conv1x1_aux is not None 117 | c = self.conv1x1_aux(c) 118 | ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) 119 | xa, xb = xa + ca, xb + cb 120 | 121 | x = torch.tanh(xa) * torch.sigmoid(xb) 122 | 123 | # for skip connection 124 | s = self.conv1x1_skip(x) 125 | 126 | # for residual connection 127 | x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5) 128 | 129 | return x, s 130 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from inspect import isfunction 7 | from PIL import Image, ImageDraw, ImageFont 8 | import hashlib 9 | import requests 10 | import os 11 | 12 | URL_MAP = { 13 | 'vggishish_lpaps': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt', 14 | 'vggishish_mean_std_melspec_10s_22050hz': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt', 15 | 'melception': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt', 16 | } 17 | 18 | CKPT_MAP = { 19 | 'vggishish_lpaps': 'vggishish16.pt', 20 | 'vggishish_mean_std_melspec_10s_22050hz': 'train_means_stds_melspec_10s_22050hz.txt', 21 | 'melception': 'melception-21-05-10T09-28-40.pt', 22 | } 23 | 24 | MD5_MAP = { 25 | 'vggishish_lpaps': '197040c524a07ccacf7715d7080a80bd', 26 | 'vggishish_mean_std_melspec_10s_22050hz': 'f449c6fd0e248936c16f6d22492bb625', 27 | 'melception': 'a71a41041e945b457c7d3d814bbcf72d', 28 | } 29 | 30 | 31 | def download(url, local_path, 
chunk_size=1024): 32 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 33 | with requests.get(url, stream=True) as r: 34 | total_size = int(r.headers.get("content-length", 0)) 35 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 36 | with open(local_path, "wb") as f: 37 | for data in r.iter_content(chunk_size=chunk_size): 38 | if data: 39 | f.write(data) 40 | pbar.update(chunk_size) 41 | 42 | 43 | def md5_hash(path): 44 | with open(path, "rb") as f: 45 | content = f.read() 46 | return hashlib.md5(content).hexdigest() 47 | 48 | 49 | 50 | def log_txt_as_img(wh, xc, size=10): 51 | # wh a tuple of (width, height),xc a list of captions to plot 52 | b = len(xc) 53 | txts = list() 54 | for bi in range(b): 55 | txt = Image.new("RGB", wh, color="white") 56 | draw = ImageDraw.Draw(txt) 57 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 58 | nc = int(40 * (wh[0] / 256)) 59 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 60 | 61 | try: 62 | draw.text((0, 0), lines, fill="black", font=font) 63 | except UnicodeEncodeError: 64 | print("Cant encode string for logging. Skipping.") 65 | 66 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 67 | txts.append(txt) 68 | txts = np.stack(txts) 69 | txts = torch.tensor(txts) 70 | return txts 71 | 72 | 73 | def ismap(x): 74 | if not isinstance(x, torch.Tensor): 75 | return False 76 | return (len(x.shape) == 4) and (x.shape[1] > 3) 77 | 78 | 79 | def isimage(x): 80 | if not isinstance(x,torch.Tensor): 81 | return False 82 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 83 | 84 | 85 | def exists(x): 86 | return x is not None 87 | 88 | 89 | def default(val, d): 90 | if exists(val): 91 | return val 92 | return d() if isfunction(d) else d 93 | 94 | 95 | def mean_flat(tensor): 96 | """ 97 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 98 | Take the mean over all non-batch dimensions. 
99 | """ 100 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 101 | 102 | 103 | def count_params(model, verbose=False): 104 | total_params = sum(p.numel() for p in model.parameters()) 105 | if verbose: 106 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 107 | return total_params 108 | 109 | 110 | def instantiate_from_config(config,reload=False): 111 | if not "target" in config: 112 | if config == '__is_first_stage__': 113 | return None 114 | elif config == "__is_unconditional__": 115 | return None 116 | raise KeyError("Expected key `target` to instantiate.") 117 | return get_obj_from_str(config["target"],reload=reload)(**config.get("params", dict())) 118 | 119 | 120 | def get_obj_from_str(string, reload=False): 121 | module, cls = string.rsplit(".", 1) 122 | if reload: 123 | module_imp = importlib.import_module(module) 124 | importlib.reload(module_imp) 125 | return getattr(importlib.import_module(module, package=None), cls) 126 | 127 | def get_ckpt_path(name, root, check=False): 128 | assert name in URL_MAP 129 | path = os.path.join(root, CKPT_MAP[name]) 130 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 131 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 132 | download(URL_MAP[name], path) 133 | md5 = md5_hash(path) 134 | assert md5 == MD5_MAP[name], md5 135 | return path 136 | -------------------------------------------------------------------------------- /vocoder/hifigan/hifigan_nsf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa 3 | import json 4 | import glob 5 | import re 6 | import os 7 | 8 | import torch 9 | from vocoder.hifigan.modules.hifigan_nsf import HifiGanGenerator 10 | from utils.commons.ckpt_utils import load_ckpt 11 | from utils.commons.hparams import set_hparams, hparams 12 | 13 | def denoise(wav, v=0.1): 14 | spec = librosa.stft(y=wav, n_fft=1024, hop_length=256, 15 | win_length=1024, pad_mode='constant') 16 | spec_m = np.abs(spec) 17 | spec_m = np.clip(spec_m - v, a_min=0, a_max=None) 18 | spec_a = np.angle(spec) 19 | 20 | return librosa.istft(spec_m * np.exp(1j * spec_a), hop_length=256, 21 | win_length=1024) 22 | 23 | def load_model(config_path, checkpoint_path): 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | ckpt_dict = torch.load(checkpoint_path, map_location="cpu") 26 | if '.yaml' in config_path: 27 | config = set_hparams(config_path, global_hparams=False) 28 | state = ckpt_dict["state_dict"]["model_gen"] 29 | elif '.json' in config_path: 30 | config = json.load(open(config_path, 'r')) 31 | state = ckpt_dict["generator"] 32 | 33 | model = HifiGanGenerator(config) 34 | model.load_state_dict(state, strict=True) 35 | model.remove_weight_norm() 36 | model = model.eval().to(device) 37 | print(f"| Loaded model parameters from {checkpoint_path}.") 38 | print(f"| HifiGAN device: {device}.") 39 | return model, config, device 40 | 41 | total_time = 0 42 | 43 | class HifiGAN_NSF(torch.nn.Module): 44 | def __init__(self, vocoder_ckpt, device=None, use_nsf=True): 45 | super().__init__() 46 | self.use_nsf = use_nsf 47 | base_dir = vocoder_ckpt 48 | config_path = f'{base_dir}/config.yaml' 49 | if os.path.exists(config_path): 50 | ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key= 51 | lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1] 52 | print('| load HifiGAN: ', ckpt) 53 | self.model, self.config, self.device = 
load_model(config_path=config_path, checkpoint_path=ckpt) 54 | else: 55 | config_path = f'{base_dir}/config.json' 56 | ckpt = f'{base_dir}/generator_v1' 57 | if os.path.exists(config_path): 58 | self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt) 59 | 60 | def extract_f0_from_mel(self, mel): 61 | # Method for extracting f0 62 | import librosa 63 | import numpy as np 64 | 65 | # Assume the sampling rate and other parameters are known 66 | sr = 48000 67 | n_fft = 1024 68 | hop_length = 256 69 | win_length = 1024 70 | n_mels = mel.shape[0] 71 | 72 | # Since extracting f0 directly from the mel-spectrogram is not very accurate, use a simple estimation method here 73 | # Invert the mel-spectrogram back to a linear spectrogram 74 | mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels) 75 | mel_basis_inv = np.linalg.pinv(mel_basis) 76 | linear_spec = np.dot(mel_basis_inv, mel) 77 | 78 | # Get the magnitude spectrogram 79 | S_abs = np.abs(linear_spec) 80 | 81 | # Estimate frequencies with librosa's piptrack method 82 | frequencies, magnitudes = librosa.piptrack(S=S_abs, sr=sr, n_fft=n_fft, hop_length=hop_length) 83 | 84 | # Initialize the f0 array 85 | f0 = np.zeros(frequencies.shape[1]) 86 | 87 | # For each frame, pick the frequency with the largest magnitude 88 | for i in range(frequencies.shape[1]): 89 | index = magnitudes[:, i].argmax() 90 | f0[i] = frequencies[index, i] 91 | 92 | return f0 93 | 94 | def spec2wav(self, mel, **kwargs): 95 | device = self.device 96 | with torch.no_grad(): 97 | c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device) 98 | f0 = kwargs.get('f0') 99 | if self.use_nsf: 100 | if f0 is None: 101 | # Extract f0 from the mel 102 | f0 = self.extract_f0_from_mel(mel) 103 | f0 = torch.FloatTensor(f0[None, :]).to(device) 104 | y = self.model(c, f0).view(-1) 105 | else: 106 | y = self.model(c).view(-1) 107 | wav_out = y.cpu().numpy() 108 | if hparams.get('vocoder_denoise_c', 0.0) > 0: 109 | wav_out = denoise(wav_out, v=hparams['vocoder_denoise_c']) 110 | return wav_out 111 | 112 | def vocode(self, mel, **kwargs): 113 | assert len(mel.shape) == 2 114 | device = self.device 115 | with torch.no_grad(): 116 | c = torch.FloatTensor(mel).unsqueeze(0).to(device) 117 | f0 = kwargs.get('f0') 118 | if c.shape[1] != 80: 119 | c = c.transpose(2, 1) 120 | if self.use_nsf: 121 | if f0 is None: 122 | # Extract f0 from the mel 123 | f0 = self.extract_f0_from_mel(mel) 124 | 125 | f0 = torch.FloatTensor(f0[None, :]).to(device) 126 | y = self.model(c, f0).view(-1) 127 | else: 128 | y = self.model(c).view(-1) 129 | wav_out = y.cpu().numpy() 130 | if hparams.get('vocoder_denoise_c', 0.0) > 0: 131 | wav_out = denoise(wav_out, v=hparams['vocoder_denoise_c']) 132 | return wav_out 133 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Utility functions.""" 7 | 8 | import fnmatch 9 | import logging 10 | import os 11 | import sys 12 | 13 | import h5py 14 | import numpy as np 15 | 16 | 17 | def find_files(root_dir, query="*.wav", include_root_dir=True): 18 | """Find files recursively. 19 | 20 | Args: 21 | root_dir (str): Root root_dir to find. 22 | query (str): Query to find. 23 | include_root_dir (bool): If False, root_dir name is not included. 24 | 25 | Returns: 26 | list: List of found filenames.
27 | 28 | """ 29 | files = [] 30 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 31 | for filename in fnmatch.filter(filenames, query): 32 | files.append(os.path.join(root, filename)) 33 | if not include_root_dir: 34 | files = [file_.replace(root_dir + "/", "") for file_ in files] 35 | 36 | return files 37 | 38 | 39 | def read_hdf5(hdf5_name, hdf5_path): 40 | """Read hdf5 dataset. 41 | 42 | Args: 43 | hdf5_name (str): Filename of hdf5 file. 44 | hdf5_path (str): Dataset name in hdf5 file. 45 | 46 | Return: 47 | any: Dataset values. 48 | 49 | """ 50 | if not os.path.exists(hdf5_name): 51 | logging.error(f"There is no such a hdf5 file ({hdf5_name}).") 52 | sys.exit(1) 53 | 54 | hdf5_file = h5py.File(hdf5_name, "r") 55 | 56 | if hdf5_path not in hdf5_file: 57 | logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})") 58 | sys.exit(1) 59 | 60 | hdf5_data = hdf5_file[hdf5_path][()] 61 | hdf5_file.close() 62 | 63 | return hdf5_data 64 | 65 | 66 | def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True): 67 | """Write dataset to hdf5. 68 | 69 | Args: 70 | hdf5_name (str): Hdf5 dataset filename. 71 | hdf5_path (str): Dataset path in hdf5. 72 | write_data (ndarray): Data to write. 73 | is_overwrite (bool): Whether to overwrite dataset. 74 | 75 | """ 76 | # convert to numpy array 77 | write_data = np.array(write_data) 78 | 79 | # check folder existence 80 | folder_name, _ = os.path.split(hdf5_name) 81 | if not os.path.exists(folder_name) and len(folder_name) != 0: 82 | os.makedirs(folder_name) 83 | 84 | # check hdf5 existence 85 | if os.path.exists(hdf5_name): 86 | # if already exists, open with r+ mode 87 | hdf5_file = h5py.File(hdf5_name, "r+") 88 | # check dataset existence 89 | if hdf5_path in hdf5_file: 90 | if is_overwrite: 91 | logging.warning("Dataset in hdf5 file already exists. " 92 | "recreate dataset in hdf5.") 93 | hdf5_file.__delitem__(hdf5_path) 94 | else: 95 | logging.error("Dataset in hdf5 file already exists. " 96 | "if you want to overwrite, please set is_overwrite = True.") 97 | hdf5_file.close() 98 | sys.exit(1) 99 | else: 100 | # if not exists, open with w mode 101 | hdf5_file = h5py.File(hdf5_name, "w") 102 | 103 | # write data to hdf5 104 | hdf5_file.create_dataset(hdf5_path, data=write_data) 105 | hdf5_file.flush() 106 | hdf5_file.close() 107 | 108 | 109 | class HDF5ScpLoader(object): 110 | """Loader class for a fests.scp file of hdf5 file. 111 | 112 | Examples: 113 | key1 /some/path/a.h5:feats 114 | key2 /some/path/b.h5:feats 115 | key3 /some/path/c.h5:feats 116 | key4 /some/path/d.h5:feats 117 | ... 118 | >>> loader = HDF5ScpLoader("hdf5.scp") 119 | >>> array = loader["key1"] 120 | 121 | key1 /some/path/a.h5 122 | key2 /some/path/b.h5 123 | key3 /some/path/c.h5 124 | key4 /some/path/d.h5 125 | ... 126 | >>> loader = HDF5ScpLoader("hdf5.scp", "feats") 127 | >>> array = loader["key1"] 128 | 129 | """ 130 | 131 | def __init__(self, feats_scp, default_hdf5_path="feats"): 132 | """Initialize HDF5 scp loader. 133 | 134 | Args: 135 | feats_scp (str): Kaldi-style feats.scp file with hdf5 format. 136 | default_hdf5_path (str): Path in hdf5 file. If the scp contain the info, not used. 
137 | 138 | """ 139 | self.default_hdf5_path = default_hdf5_path 140 | with open(feats_scp) as f: 141 | lines = [line.replace("\n", "") for line in f.readlines()] 142 | self.data = {} 143 | for line in lines: 144 | key, value = line.split() 145 | self.data[key] = value 146 | 147 | def get_path(self, key): 148 | """Get hdf5 file path for a given key.""" 149 | return self.data[key] 150 | 151 | def __getitem__(self, key): 152 | """Get ndarray for a given key.""" 153 | p = self.data[key] 154 | if ":" in p: 155 | return read_hdf5(*p.split(":")) 156 | else: 157 | return read_hdf5(p, self.default_hdf5_path) 158 | 159 | def __len__(self): 160 | """Return the length of the scp file.""" 161 | return len(self.data) 162 | 163 | def __iter__(self): 164 | """Return the iterator of the scp file.""" 165 | return iter(self.data) 166 | 167 | def keys(self): 168 | """Return the keys of the scp file.""" 169 | return self.data.keys() 170 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/losses/stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def stft(x, fft_size, hop_size, win_length, window): 13 | """Perform STFT and convert to magnitude spectrogram. 14 | 15 | Args: 16 | x (Tensor): Input signal tensor (B, T). 17 | fft_size (int): FFT size. 18 | hop_size (int): Hop size. 19 | win_length (int): Window length. 20 | window (str): Window function type. 21 | 22 | Returns: 23 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 24 | 25 | """ 26 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window) 27 | real = x_stft[..., 0] 28 | imag = x_stft[..., 1] 29 | 30 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 31 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 32 | 33 | 34 | class SpectralConvergengeLoss(torch.nn.Module): 35 | """Spectral convergence loss module.""" 36 | 37 | def __init__(self): 38 | """Initilize spectral convergence loss module.""" 39 | super(SpectralConvergengeLoss, self).__init__() 40 | 41 | def forward(self, x_mag, y_mag): 42 | """Calculate forward propagation. 43 | 44 | Args: 45 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 46 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 47 | 48 | Returns: 49 | Tensor: Spectral convergence loss value. 50 | 51 | """ 52 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 53 | 54 | 55 | class LogSTFTMagnitudeLoss(torch.nn.Module): 56 | """Log STFT magnitude loss module.""" 57 | 58 | def __init__(self): 59 | """Initilize los STFT magnitude loss module.""" 60 | super(LogSTFTMagnitudeLoss, self).__init__() 61 | 62 | def forward(self, x_mag, y_mag): 63 | """Calculate forward propagation. 64 | 65 | Args: 66 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 67 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 68 | 69 | Returns: 70 | Tensor: Log STFT magnitude loss value. 
71 | 72 | """ 73 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 74 | 75 | 76 | class STFTLoss(torch.nn.Module): 77 | """STFT loss module.""" 78 | 79 | def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): 80 | """Initialize STFT loss module.""" 81 | super(STFTLoss, self).__init__() 82 | self.fft_size = fft_size 83 | self.shift_size = shift_size 84 | self.win_length = win_length 85 | self.window = getattr(torch, window)(win_length) 86 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 87 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 88 | 89 | def forward(self, x, y): 90 | """Calculate forward propagation. 91 | 92 | Args: 93 | x (Tensor): Predicted signal (B, T). 94 | y (Tensor): Groundtruth signal (B, T). 95 | 96 | Returns: 97 | Tensor: Spectral convergence loss value. 98 | Tensor: Log STFT magnitude loss value. 99 | 100 | """ 101 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 102 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 103 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 104 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 105 | 106 | return sc_loss, mag_loss 107 | 108 | 109 | class MultiResolutionSTFTLoss(torch.nn.Module): 110 | """Multi resolution STFT loss module.""" 111 | 112 | def __init__(self, 113 | fft_sizes=[1024, 2048, 512], 114 | hop_sizes=[120, 240, 50], 115 | win_lengths=[600, 1200, 240], 116 | window="hann_window"): 117 | """Initialize Multi resolution STFT loss module. 118 | 119 | Args: 120 | fft_sizes (list): List of FFT sizes. 121 | hop_sizes (list): List of hop sizes. 122 | win_lengths (list): List of window lengths. 123 | window (str): Window function type. 124 | 125 | """ 126 | super(MultiResolutionSTFTLoss, self).__init__() 127 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 128 | self.stft_losses = torch.nn.ModuleList() 129 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 130 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 131 | 132 | def forward(self, x, y): 133 | """Calculate forward propagation. 134 | 135 | Args: 136 | x (Tensor): Predicted signal (B, T). 137 | y (Tensor): Groundtruth signal (B, T). 138 | 139 | Returns: 140 | Tensor: Multi resolution spectral convergence loss value. 141 | Tensor: Multi resolution log STFT magnitude loss value. 
142 | 143 | """ 144 | sc_loss = 0.0 145 | mag_loss = 0.0 146 | for f in self.stft_losses: 147 | sc_l, mag_l = f(x, y) 148 | sc_loss += sc_l 149 | mag_loss += mag_l 150 | sc_loss /= len(self.stft_losses) 151 | mag_loss /= len(self.stft_losses) 152 | 153 | return sc_loss, mag_loss 154 | -------------------------------------------------------------------------------- /preprocess/NAT_mel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | import torch 7 | import torch.nn as nn 8 | 9 | MAX_WAV_VALUE = 32768.0 10 | 11 | 12 | def load_wav(full_path): 13 | sampling_rate, data = read(full_path) 14 | return data, sampling_rate 15 | 16 | 17 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 18 | return np.log10(np.clip(x, a_min=clip_val, a_max=None) * C) 19 | 20 | 21 | def dynamic_range_decompression(x, C=1): 22 | return np.exp(x) / C 23 | 24 | 25 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 26 | return torch.log10(torch.clamp(x, min=clip_val) * C) 27 | 28 | 29 | def dynamic_range_decompression_torch(x, C=1): 30 | return torch.exp(x) / C 31 | 32 | 33 | def spectral_normalize_torch(magnitudes): 34 | output = dynamic_range_compression_torch(magnitudes) 35 | return output 36 | 37 | 38 | def spectral_de_normalize_torch(magnitudes): 39 | output = dynamic_range_decompression_torch(magnitudes) 40 | return output 41 | 42 | class MelNet(nn.Module): 43 | def __init__(self,hparams,device='cpu') -> None: 44 | super().__init__() 45 | self.n_fft = hparams['fft_size'] 46 | self.num_mels = hparams['audio_num_mel_bins'] 47 | self.sampling_rate = hparams['audio_sample_rate'] 48 | self.hop_size = hparams['hop_size'] 49 | self.win_size = hparams['win_size'] 50 | self.fmin = hparams['fmin'] 51 | self.fmax = hparams['fmax'] 52 | self.device = device 53 | mel = librosa_mel_fn(sr=self.sampling_rate, n_fft=self.n_fft, n_mels=self.num_mels, fmin=self.fmin, fmax=self.fmax) 54 | self.mel_basis = torch.from_numpy(mel).float().to(self.device) 55 | self.hann_window = torch.hann_window(self.win_size).to(self.device) 56 | 57 | def to(self,device,**kwagrs): 58 | super().to(device=device,**kwagrs) 59 | self.mel_basis = self.mel_basis.to(device) 60 | self.hann_window = self.hann_window.to(device) 61 | self.device = device 62 | 63 | def forward(self,y,center=False, complex=False): 64 | if isinstance(y,np.ndarray): 65 | y = torch.FloatTensor(y) 66 | if len(y.shape) == 1: 67 | y = y.unsqueeze(0) 68 | y = y.clamp(min=-1., max=1.).to(self.device) 69 | 70 | y = torch.nn.functional.pad(y.unsqueeze(1), [int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)], 71 | mode='reflect') 72 | y = y.squeeze(1) 73 | 74 | spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window, 75 | center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=True) 76 | if not complex: 77 | spec = torch.view_as_real(spec) 78 | 79 | if not complex: 80 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 81 | spec = torch.matmul(self.mel_basis, spec) 82 | spec = spectral_normalize_torch(spec) 83 | else: 84 | B, C, T, _ = spec.shape 85 | spec = spec.transpose(1, 2) # [B, T, n_fft, 2] 86 | return spec 87 | 88 | ## below can be used in one gpu, but not ddp 89 | mel_basis = {} 90 | hann_window = {} 91 | 92 | 93 | def mel_spectrogram(y, hparams, center=False, 
complex=False): # y should be a tensor with shape (b,wav_len) 94 | # hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) 95 | # win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) 96 | # fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 97 | # fmax: 10000 # To be increased/reduced depending on data. 98 | # fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter 99 | # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, 100 | n_fft = hparams['fft_size'] 101 | num_mels = hparams['audio_num_mel_bins'] 102 | sampling_rate = hparams['audio_sample_rate'] 103 | hop_size = hparams['hop_size'] 104 | win_size = hparams['win_size'] 105 | fmin = hparams['fmin'] 106 | fmax = hparams['fmax'] 107 | if isinstance(y,np.ndarray): 108 | y = torch.FloatTensor(y) 109 | if len(y.shape) == 1: 110 | y = y.unsqueeze(0) 111 | y = y.clamp(min=-1., max=1.) 112 | global mel_basis, hann_window 113 | if fmax not in mel_basis: 114 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 115 | mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 116 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 117 | 118 | y = torch.nn.functional.pad(y.unsqueeze(1), [int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)], 119 | mode='reflect') 120 | y = y.squeeze(1) 121 | 122 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 123 | center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=complex) 124 | 125 | if not complex: 126 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 127 | spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec) 128 | spec = spectral_normalize_torch(spec) 129 | else: 130 | B, C, T, _ = spec.shape 131 | spec = spec.transpose(1, 2) # [B, T, n_fft, 2] 132 | return spec 133 | -------------------------------------------------------------------------------- /ldm/models/diffusion/cfm1_audio_sampler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pytorch_memlab import LineProfiler,profile 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from einops import rearrange, repeat 9 | from contextlib import contextmanager 10 | from functools import partial 11 | from tqdm import tqdm 12 | 13 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps 14 | from torchvision.utils import make_grid 15 | try: 16 | from pytorch_lightning.utilities.distributed import rank_zero_only 17 | except: 18 | from pytorch_lightning.utilities import rank_zero_only # torch2 19 | from torchdyn.core import NeuralODE 20 | from ldm.models.diffusion.cfm1_audio import Wrapper, Wrapper_cfg 21 | from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like 22 | from omegaconf import ListConfig 23 | 24 | from ldm.util import log_txt_as_img, exists, default 25 | 26 | class CFMSampler(object): 27 | 28 | def __init__(self, model, num_timesteps, schedule="linear", **kwargs): 29 | super().__init__() 30 | self.model = model 31 | self.ddpm_num_timesteps = model.num_timesteps 32 | self.num_timesteps = num_timesteps 33 | self.schedule = schedule 
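    # (Editor's sketch, not part of the original file.) A minimal usage example for this
    # sampler, assuming `model` is a trained CFM instance and `cond` / `uncond` are
    # conditioning inputs produced by its cond stage (the names and the guidance scale
    # below are illustrative, not taken from the repo):
    #   sampler = CFMSampler(model, num_timesteps=25)
    #   mel, traj = sampler.sample(cond, batch_size=4, timesteps=25)  # Euler NeuralODE, t: 0 -> 1
    #   mel_cfg, _ = sampler.sample_cfg(cond, unconditional_guidance_scale=3.0,
    #                                   unconditional_conditioning=uncond, batch_size=4)
    #   # both calls return (final_state, full_trajectory); the final latent is decoded by the
    #   # first-stage autoencoder and vocoder during inference (see scripts/test_final.py).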
34 | 35 | def register_buffer(self, name, attr): 36 | if type(attr) == torch.Tensor: 37 | if attr.device != torch.device("cuda"): 38 | attr = attr.to(torch.device("cuda")) 39 | setattr(self, name, attr) 40 | 41 | def stochastic_encode(self, x_start, t, noise=None): 42 | x1 = x_start 43 | x0 = default(noise, lambda: torch.randn_like(x_start)) 44 | t_unsqueeze = 1 - t.unsqueeze(1).unsqueeze(1).float() / self.num_timesteps 45 | x_noisy = t_unsqueeze * x1 + (1. - (1 - self.model.sigma_min) * t_unsqueeze) * x0 46 | return x_noisy 47 | 48 | @torch.no_grad() 49 | def sample(self, cond, batch_size=16, timesteps=None, shape=None, x_latent=None, t_start=None, **kwargs): 50 | 51 | # print(shape) 52 | 53 | if shape is None: 54 | if self.model.channels > 0: 55 | shape = (batch_size, self.model.channels, self.model.mel_dim, self.model.mel_length) 56 | else: 57 | shape = (batch_size, self.model.mel_dim, self.model.mel_length) 58 | # if cond is not None: 59 | # if isinstance(cond, dict): 60 | # cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else 61 | # list(map(lambda x: x[:batch_size], cond[key])) for key in cond} 62 | # else: 63 | # cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] 64 | 65 | if len(shape)==3: 66 | C, H, W = shape 67 | shape = (batch_size, C, H, W) 68 | else: 69 | C, T = shape 70 | shape = (batch_size, C, T) 71 | 72 | neural_ode = NeuralODE(self.ode_wrapper(cond), solver='euler', sensitivity="adjoint", atol=1e-4, rtol=1e-4) 73 | t_span = torch.linspace(0, 1, 25 if timesteps is None else timesteps) 74 | if t_start is not None: 75 | t_span = t_span[t_start:] 76 | 77 | x0 = torch.randn(shape, device=self.model.device) if x_latent is None else x_latent 78 | eval_points, traj = neural_ode(x0, t_span) 79 | 80 | return traj[-1], traj 81 | 82 | def ode_wrapper(self, cond): 83 | # self.estimator receives x, mask, mu, t, spk as arguments 84 | return Wrapper(self.model, cond) 85 | 86 | @torch.no_grad() 87 | def sample_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning, batch_size=16, timesteps=None, shape=None, x_latent=None, t_start=None, **kwargs): 88 | if shape is None: 89 | if self.model.channels > 0: 90 | shape = (batch_size, self.model.channels, self.model.mel_dim, self.model.mel_length) 91 | else: 92 | shape = (batch_size, self.model.mel_dim, self.model.mel_length) 93 | # if cond is not None: 94 | # if isinstance(cond, dict): 95 | # cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else 96 | # list(map(lambda x: x[:batch_size], cond[key])) for key in cond} 97 | # else: 98 | # cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] 99 | 100 | if len(shape)==3: 101 | C, H, W = shape 102 | shape = (batch_size, C, H, W) 103 | else: 104 | C, T = shape 105 | shape = (batch_size, C, T) 106 | 107 | neural_ode = NeuralODE(self.ode_wrapper_cfg(cond, unconditional_guidance_scale, unconditional_conditioning), solver='euler', sensitivity="adjoint", atol=1e-4, rtol=1e-4) 108 | t_span = torch.linspace(0, 1, 25 if timesteps is None else timesteps) 109 | 110 | if t_start is not None: 111 | t_span = t_span[t_start:] 112 | 113 | x0 = torch.randn(shape, device=self.model.device) if x_latent is None else x_latent 114 | eval_points, traj = neural_ode(x0, t_span) 115 | 116 | return traj[-1], traj 117 | 118 | def ode_wrapper_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning): 119 | # self.estimator receives x, mask, mu, t, spk as arguments 120 | return 
Wrapper_cfg(self.model, cond, unconditional_guidance_scale, unconditional_conditioning) 121 | 122 | -------------------------------------------------------------------------------- /utils/commons/hparams.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import yaml 4 | 5 | from utils.os_utils import remove_file 6 | 7 | global_print_hparams = True 8 | hparams = {} 9 | 10 | 11 | class Args: 12 | def __init__(self, **kwargs): 13 | for k, v in kwargs.items(): 14 | self.__setattr__(k, v) 15 | 16 | 17 | def override_config(old_config: dict, new_config: dict): 18 | for k, v in new_config.items(): 19 | if isinstance(v, dict) and k in old_config: 20 | override_config(old_config[k], new_config[k]) 21 | else: 22 | old_config[k] = v 23 | 24 | 25 | def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True, root_dir=''): 26 | if config == '' and exp_name == '': 27 | parser = argparse.ArgumentParser(description='') 28 | parser.add_argument('--config', type=str, default='', 29 | help='location of the data corpus') 30 | parser.add_argument('--exp_name', type=str, default='', help='exp_name') 31 | parser.add_argument('-hp', '--hparams', type=str, default='', 32 | help='location of the data corpus') 33 | parser.add_argument('--infer', action='store_true', help='infer') 34 | parser.add_argument('--validate', action='store_true', help='validate') 35 | parser.add_argument('--reset', action='store_true', help='reset hparams') 36 | parser.add_argument('--remove', action='store_true', help='remove old ckpt') 37 | parser.add_argument('--debug', action='store_true', help='debug') 38 | parser.add_argument('--root_dir', type=str, default='', help='root directory of the project.') 39 | args, unknown = parser.parse_known_args() 40 | print("| Unknow hparams: ", unknown) 41 | else: 42 | args = Args(config=config, exp_name=exp_name, hparams=hparams_str, 43 | infer=False, validate=False, reset=False, debug=False, remove=False, root_dir=root_dir) 44 | global hparams 45 | assert args.config != '' or args.exp_name != '' 46 | root_dir = args.root_dir 47 | if args.config != '': 48 | assert os.path.exists(os.path.join(root_dir, args.config)), f'| Wrong config path! 
root_dir: {root_dir}, config_path: {args.config}' 49 | 50 | config_chains = [] 51 | loaded_config = set() 52 | 53 | def load_config(config_fn): 54 | # deep first inheritance and avoid the second visit of one node 55 | if not os.path.exists(os.path.join(root_dir, config_fn)): 56 | return {} 57 | with open(os.path.join(root_dir, config_fn)) as f: 58 | hparams_ = yaml.safe_load(f) 59 | loaded_config.add(config_fn) 60 | if 'base_config' in hparams_: 61 | ret_hparams = {} 62 | if not isinstance(hparams_['base_config'], list): 63 | hparams_['base_config'] = [hparams_['base_config']] 64 | for c in hparams_['base_config']: 65 | if c.startswith('.'): 66 | c = f'{os.path.dirname(config_fn)}/{c}' 67 | c = os.path.normpath(c) 68 | if c not in loaded_config: 69 | override_config(ret_hparams, load_config(c)) 70 | override_config(ret_hparams, hparams_) 71 | else: 72 | ret_hparams = hparams_ 73 | config_chains.append(config_fn) 74 | return ret_hparams 75 | 76 | saved_hparams = {} 77 | args_work_dir = '' 78 | if args.exp_name != '': 79 | args_work_dir = os.path.join(root_dir, f'checkpoints/{args.exp_name}') 80 | ckpt_config_path = f'{args_work_dir}/config.yaml' 81 | if os.path.exists(ckpt_config_path): 82 | with open(ckpt_config_path) as f: 83 | saved_hparams_ = yaml.safe_load(f) 84 | if saved_hparams_ is not None: 85 | saved_hparams.update(saved_hparams_) 86 | hparams_ = {} 87 | if args.config != '': 88 | hparams_.update(load_config(args.config)) 89 | if not args.reset: 90 | hparams_.update(saved_hparams) 91 | hparams_['work_dir'] = args_work_dir 92 | 93 | # Support config overriding in command line. Support list type config overriding. 94 | # Examples: --hparams="a=1,b.c=2,d=[1 1 1]" 95 | if args.hparams != "": 96 | for new_hparam in args.hparams.split(","): 97 | k, v = new_hparam.split("=") 98 | v = v.strip("\'\" ") 99 | config_node = hparams_ 100 | for k_ in k.split(".")[:-1]: 101 | config_node = config_node[k_] 102 | k = k.split(".")[-1] 103 | if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]: 104 | if type(config_node[k]) == list: 105 | v = v.replace(" ", ",") 106 | config_node[k] = eval(v) 107 | else: 108 | config_node[k] = type(config_node[k])(v) 109 | if args_work_dir != '' and args.remove: 110 | answer = input("REMOVE old checkpoint? 
Y/N [Default: N]: ") 111 | if answer.lower() == "y": 112 | remove_file(args_work_dir) 113 | if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer: 114 | os.makedirs(hparams_['work_dir'], exist_ok=True) 115 | with open(ckpt_config_path, 'w') as f: 116 | yaml.safe_dump(hparams_, f) 117 | 118 | hparams_['infer'] = args.infer 119 | hparams_['debug'] = args.debug 120 | hparams_['validate'] = args.validate 121 | hparams_['exp_name'] = args.exp_name 122 | global global_print_hparams 123 | if global_hparams: 124 | hparams.clear() 125 | hparams.update(hparams_) 126 | if print_hparams and global_print_hparams and global_hparams: 127 | print('| Hparams chains: ', config_chains) 128 | print('| Hparams: ') 129 | for i, (k, v) in enumerate(sorted(hparams_.items())): 130 | print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "") 131 | print("") 132 | global_print_hparams = False 133 | return hparams_ 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Versatile Framework for Song Generation with Prompt-based Control 2 | 3 | #### Yu Zhang, Wenxiang Guo, Changhao Pan, Zhiyuan Zhu, Ruiqi Li, Jingyu Lu, Rongjie Huang, Ruiyuan Zhang, Zhiqing Hong, Ziyue Jiang, Zhou Zhao | Zhejiang University 4 | 5 | PyTorch implementation of AccompBand of **[VersBand (EMNLP 2025)](https://arxiv.org/abs/2504.19062): Versatile Framework for Song Generation with Prompt-based Control**. 6 | 7 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2504.19062) 8 | [![Demo](https://img.shields.io/badge/🚀%20Demo%20Page-blue)](https://aaronz345.github.io/VersBandDemo/) 9 | [![weixin](https://img.shields.io/badge/-WeChat@语音之家-000000?logo=wechat&logoColor=07C160)](https://mp.weixin.qq.com/s/fsbNdPyfEFi-_QHOCy85RQ) 10 | [![weixin](https://img.shields.io/badge/-WeChat@PaperWeekly-000000?logo=wechat&logoColor=07C160)](https://mp.weixin.qq.com/s/tIwnbBXqVEwlUKRwxPceSA) 11 | [![zhihu](https://img.shields.io/badge/-知乎-000000?logo=zhihu&logoColor=0084FF)](https://zhuanlan.zhihu.com/p/1943351555119097448) 12 | [![GitHub Stars](https://img.shields.io/github/stars/AaronZ345/VersBand?style=social&label=GitHub+Stars)](https://github.com/AaronZ345/VersBand) 13 | 14 | Visit our [demo page](https://aaronz345.github.io/VersBandDemo/) for song samples. 15 | 16 | ## News 17 | 18 | - 2025.08: We released the code of AcccompBand! 19 | - 2025.08: VersBand is accepted by EMNLP 2025! 20 | 21 | ## Key Features 22 | - We propose **VersBand**, a multi-task song generation approach for generating high-quality, aligned songs with prompt-based control. 23 | - We design a decoupled model **VocalBand**, which leverages the flow-matching method to generate singing styles, pitches, and melspectrograms, enabling fast and high-quality vocal synthesis with high-level style control. 24 | - We introduce a flow-based transformer model **AccompBand** to generate high-quality, controllable, aligned accompaniments, with the Band-MOE, selecting suitable experts for enhanced quality, alignment, and control. 25 | - Experimental results demonstrate that VersBand achieves superior objective and subjective evaluations compared to baseline models across multiple **song generation** tasks. 
26 | 27 | ## Quick Start 28 | Since VocalBand is similar to our other SVS models (like [TCSinger](https://github.com/AaronZ345/TCSinger), [TechSinger](https://github.com/gwx314/TechSinger)), we only provide the code of **AccompBand** in this repo. We give an example of how you can train your own model and infer with AccompBand. 29 | 30 | To try on your own song dataset, clone this repo on your local machine with NVIDIA GPU + CUDA cuDNN and follow the instructions below. 31 | 32 | ### Dependencies 33 | 34 | A suitable [conda](https://conda.io/) environment named `versband` can be created 35 | and activated with: 36 | 37 | ``` 38 | conda create -n versband python=3.10 39 | conda install --yes --file requirements.txt 40 | conda activate versband 41 | ``` 42 | 43 | ### Multi-GPU 44 | 45 | By default, this implementation uses as many GPUs in parallel as returned by `torch.cuda.device_count()`. 46 | You can specify which GPUs to use by setting the `CUDA_DEVICES_AVAILABLE` environment variable before running the training module. 47 | 48 | ### Data Preparation 49 | 50 | 1. Crawl websites to build your own song datasets, then annotate them with automatic tools, like [source–accompaniment separation](https://github.com/Anjok07/ultimatevocalremovergui), [MIDI extraction](https://github.com/RickyL-2000/ROSVOT), [beat tracking](https://github.com/mjhydri/BeatNet), and [music caption annotation](https://github.com/seungheondoh/lp-music-caps). 51 | 2. Prepare TSV files that include at least an item_name column, and adapt preprocess/preprocess.py to parse your custom file format accordingly. 52 | 3. Preprocess the dataset: 53 | ```bash 54 | export PYTHONPATH=. 55 | python preprocess/preprocess.py 56 | ``` 57 | 58 | 4. Compute mel-spectrograms: 59 | 60 | ```bash 61 | python preprocess/mel_spec_24k.py --tsv_path ./data/music24k/music.tsv --num_gpus 4 --max_duration 20 62 | ``` 63 | 64 | 5. Post-process: 65 | 66 | ```bash 67 | python preprocess/postprocess_data.py 68 | ``` 69 | 70 | 6. Download [HIFI-GAN](https://drive.google.com/drive/folders/19DHgcdDHl0WOLulTtpSHPg9h7B7m-b_B?usp=drive_link) as the vocoder in `useful_ckpts/hifigan` and [FLAN-T5](https://huggingface.co/google/flan-t5-large) in `useful_ckpts/flan-t5-large`. 71 | 72 | ### Training AccompBand 73 | 74 | 1. Train the VAE module and duration predictor 75 | ```bash 76 | python main.py --base configs/ae_accomp.yaml -t --gpus 0,1,2,3,4,5,6,7 77 | ``` 78 | 79 | 2. Train the main VersBand model 80 | 81 | ```bash 82 | python main.py --base configs/vocal2music.yaml -t --gpus 0,1,2,3,4,5,6,7 83 | ``` 84 | 85 | *Notes* 86 | - Adjust the compression ratio in the config files (and related scripts). 87 | - Change the padding length in the dataloader as needed. 88 | 89 | ### Inference with AccompBand 90 | 91 | ```bash 92 | python scripts/test_final.py 93 | ``` 94 | 95 | *Replace the checkpoint path and CFG coefficient as required.* 96 | 97 | 98 | ## Acknowledgements 99 | 100 | This implementation uses parts of the code from the following Github repos: 101 | [Make-An-Audio-3](https://github.com/Text-to-Audio/Make-An-Audio-3), 102 | [TCSinger2](https://github.com/AaronZ345/TCSinger2) 103 | [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) 104 | as described in our code. 
105 | 106 | ## Citations ## 107 | 108 | If you find this code useful in your research, please cite our work: 109 | ```bib 110 | @article{zhang2025versatile, 111 | title={Versatile framework for song generation with prompt-based control}, 112 | author={Zhang, Yu and Guo, Wenxiang and Pan, Changhao and Zhu, Zhiyuan and Li, Ruiqi and Lu, Jingyu and Huang, Rongjie and Zhang, Ruiyuan and Hong, Zhiqing and Jiang, Ziyue and others}, 113 | journal={arXiv preprint arXiv:2504.19062}, 114 | year={2025} 115 | } 116 | ``` 117 | 118 | ## Disclaimer ## 119 | 120 | Any organization or individual is prohibited from using any technology mentioned in this paper to generate someone's songs without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws. 121 | 122 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=AaronZ345/VersBand) 123 | 124 | -------------------------------------------------------------------------------- /vocoder/parallel_wavegan/layers/upsample.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Upsampling module. 4 | 5 | This code is modified from https://github.com/r9y9/wavenet_vocoder. 6 | 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from . import Conv1d 14 | 15 | 16 | class Stretch2d(torch.nn.Module): 17 | """Stretch2d module.""" 18 | 19 | def __init__(self, x_scale, y_scale, mode="nearest"): 20 | """Initialize Stretch2d module. 21 | 22 | Args: 23 | x_scale (int): X scaling factor (Time axis in spectrogram). 24 | y_scale (int): Y scaling factor (Frequency axis in spectrogram). 25 | mode (str): Interpolation mode. 26 | 27 | """ 28 | super(Stretch2d, self).__init__() 29 | self.x_scale = x_scale 30 | self.y_scale = y_scale 31 | self.mode = mode 32 | 33 | def forward(self, x): 34 | """Calculate forward propagation. 35 | 36 | Args: 37 | x (Tensor): Input tensor (B, C, F, T). 38 | 39 | Returns: 40 | Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), 41 | 42 | """ 43 | return F.interpolate( 44 | x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) 45 | 46 | 47 | class Conv2d(torch.nn.Conv2d): 48 | """Conv2d module with customized initialization.""" 49 | 50 | def __init__(self, *args, **kwargs): 51 | """Initialize Conv2d module.""" 52 | super(Conv2d, self).__init__(*args, **kwargs) 53 | 54 | def reset_parameters(self): 55 | """Reset parameters.""" 56 | self.weight.data.fill_(1. / np.prod(self.kernel_size)) 57 | if self.bias is not None: 58 | torch.nn.init.constant_(self.bias, 0.0) 59 | 60 | 61 | class UpsampleNetwork(torch.nn.Module): 62 | """Upsampling network module.""" 63 | 64 | def __init__(self, 65 | upsample_scales, 66 | nonlinear_activation=None, 67 | nonlinear_activation_params={}, 68 | interpolate_mode="nearest", 69 | freq_axis_kernel_size=1, 70 | use_causal_conv=False, 71 | ): 72 | """Initialize upsampling network module. 73 | 74 | Args: 75 | upsample_scales (list): List of upsampling scales. 76 | nonlinear_activation (str): Activation function name. 77 | nonlinear_activation_params (dict): Arguments for specified activation function. 78 | interpolate_mode (str): Interpolation mode. 79 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 
80 | 81 | """ 82 | super(UpsampleNetwork, self).__init__() 83 | self.use_causal_conv = use_causal_conv 84 | self.up_layers = torch.nn.ModuleList() 85 | for scale in upsample_scales: 86 | # interpolation layer 87 | stretch = Stretch2d(scale, 1, interpolate_mode) 88 | self.up_layers += [stretch] 89 | 90 | # conv layer 91 | assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size." 92 | freq_axis_padding = (freq_axis_kernel_size - 1) // 2 93 | kernel_size = (freq_axis_kernel_size, scale * 2 + 1) 94 | if use_causal_conv: 95 | padding = (freq_axis_padding, scale * 2) 96 | else: 97 | padding = (freq_axis_padding, scale) 98 | conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) 99 | self.up_layers += [conv] 100 | 101 | # nonlinear 102 | if nonlinear_activation is not None: 103 | nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) 104 | self.up_layers += [nonlinear] 105 | 106 | def forward(self, c): 107 | """Calculate forward propagation. 108 | 109 | Args: 110 | c : Input tensor (B, C, T). 111 | 112 | Returns: 113 | Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales). 114 | 115 | """ 116 | c = c.unsqueeze(1) # (B, 1, C, T) 117 | for f in self.up_layers: 118 | if self.use_causal_conv and isinstance(f, Conv2d): 119 | c = f(c)[..., :c.size(-1)] 120 | else: 121 | c = f(c) 122 | return c.squeeze(1) # (B, C, T') 123 | 124 | 125 | class ConvInUpsampleNetwork(torch.nn.Module): 126 | """Convolution + upsampling network module.""" 127 | 128 | def __init__(self, 129 | upsample_scales, 130 | nonlinear_activation=None, 131 | nonlinear_activation_params={}, 132 | interpolate_mode="nearest", 133 | freq_axis_kernel_size=1, 134 | aux_channels=80, 135 | aux_context_window=0, 136 | use_causal_conv=False 137 | ): 138 | """Initialize convolution + upsampling network module. 139 | 140 | Args: 141 | upsample_scales (list): List of upsampling scales. 142 | nonlinear_activation (str): Activation function name. 143 | nonlinear_activation_params (dict): Arguments for specified activation function. 144 | mode (str): Interpolation mode. 145 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 146 | aux_channels (int): Number of channels of pre-convolutional layer. 147 | aux_context_window (int): Context window size of the pre-convolutional layer. 148 | use_causal_conv (bool): Whether to use causal structure. 149 | 150 | """ 151 | super(ConvInUpsampleNetwork, self).__init__() 152 | self.aux_context_window = aux_context_window 153 | self.use_causal_conv = use_causal_conv and aux_context_window > 0 154 | # To capture wide-context information in conditional features 155 | kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 156 | # NOTE(kan-bayashi): Here do not use padding because the input is already padded 157 | self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False) 158 | self.upsample = UpsampleNetwork( 159 | upsample_scales=upsample_scales, 160 | nonlinear_activation=nonlinear_activation, 161 | nonlinear_activation_params=nonlinear_activation_params, 162 | interpolate_mode=interpolate_mode, 163 | freq_axis_kernel_size=freq_axis_kernel_size, 164 | use_causal_conv=use_causal_conv, 165 | ) 166 | 167 | def forward(self, c): 168 | """Calculate forward propagation. 169 | 170 | Args: 171 | c : Input tensor (B, C, T'). 
172 | 173 | Returns: 174 | Tensor: Upsampled tensor (B, C, T), 175 | where T = (T' - aux_context_window * 2) * prod(upsample_scales). 176 | 177 | Note: 178 | The length of inputs considers the context window size. 179 | 180 | """ 181 | c_ = self.conv_in(c) 182 | c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ 183 | return self.upsample(c) 184 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 5 | 6 | def get_audio_encoder(name: str): 7 | if name == "Cnn14": 8 | return Cnn14 9 | else: 10 | raise Exception('The audio encoder name {} is incorrect or not supported'.format(name)) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_channels, out_channels): 15 | 16 | super(ConvBlock, self).__init__() 17 | 18 | self.conv1 = nn.Conv2d(in_channels=in_channels, 19 | out_channels=out_channels, 20 | kernel_size=(3, 3), stride=(1, 1), 21 | padding=(1, 1), bias=False) 22 | 23 | self.conv2 = nn.Conv2d(in_channels=out_channels, 24 | out_channels=out_channels, 25 | kernel_size=(3, 3), stride=(1, 1), 26 | padding=(1, 1), bias=False) 27 | 28 | self.bn1 = nn.BatchNorm2d(out_channels) 29 | self.bn2 = nn.BatchNorm2d(out_channels) 30 | 31 | 32 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 33 | 34 | x = input 35 | x = F.relu_(self.bn1(self.conv1(x))) 36 | x = F.relu_(self.bn2(self.conv2(x))) 37 | if pool_type == 'max': 38 | x = F.max_pool2d(x, kernel_size=pool_size) 39 | elif pool_type == 'avg': 40 | x = F.avg_pool2d(x, kernel_size=pool_size) 41 | elif pool_type == 'avg+max': 42 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 43 | x2 = F.max_pool2d(x, kernel_size=pool_size) 44 | x = x1 + x2 45 | else: 46 | raise Exception('Incorrect argument!') 47 | 48 | return x 49 | 50 | 51 | class ConvBlock5x5(nn.Module): 52 | def __init__(self, in_channels, out_channels): 53 | 54 | super(ConvBlock5x5, self).__init__() 55 | 56 | self.conv1 = nn.Conv2d(in_channels=in_channels, 57 | out_channels=out_channels, 58 | kernel_size=(5, 5), stride=(1, 1), 59 | padding=(2, 2), bias=False) 60 | 61 | self.bn1 = nn.BatchNorm2d(out_channels) 62 | 63 | 64 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 65 | 66 | x = input 67 | x = F.relu_(self.bn1(self.conv1(x))) 68 | if pool_type == 'max': 69 | x = F.max_pool2d(x, kernel_size=pool_size) 70 | elif pool_type == 'avg': 71 | x = F.avg_pool2d(x, kernel_size=pool_size) 72 | elif pool_type == 'avg+max': 73 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 74 | x2 = F.max_pool2d(x, kernel_size=pool_size) 75 | x = x1 + x2 76 | else: 77 | raise Exception('Incorrect argument!') 78 | 79 | return x 80 | 81 | 82 | class AttBlock(nn.Module): 83 | def __init__(self, n_in, n_out, activation='linear', temperature=1.): 84 | super(AttBlock, self).__init__() 85 | 86 | self.activation = activation 87 | self.temperature = temperature 88 | self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 89 | self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 90 | 91 | self.bn_att = nn.BatchNorm1d(n_out) 92 | 93 | def forward(self, x): 94 | # x: (n_samples, n_in, n_time) 95 | norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) 96 | cla = 
self.nonlinear_transform(self.cla(x)) 97 | x = torch.sum(norm_att * cla, dim=2) 98 | return x, norm_att, cla 99 | 100 | def nonlinear_transform(self, x): 101 | if self.activation == 'linear': 102 | return x 103 | elif self.activation == 'sigmoid': 104 | return torch.sigmoid(x) 105 | 106 | 107 | class Cnn14(nn.Module): 108 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 109 | fmax, classes_num, out_emb): 110 | 111 | super(Cnn14, self).__init__() 112 | 113 | window = 'hann' 114 | center = True 115 | pad_mode = 'reflect' 116 | ref = 1.0 117 | amin = 1e-10 118 | top_db = None 119 | 120 | # Spectrogram extractor 121 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 122 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 123 | freeze_parameters=True) 124 | 125 | # Logmel feature extractor 126 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 127 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 128 | freeze_parameters=True) 129 | 130 | self.bn0 = nn.BatchNorm2d(64) 131 | 132 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 133 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 134 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 135 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 136 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 137 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 138 | 139 | # out_emb is 2048 for best Cnn14 140 | self.fc1 = nn.Linear(2048, out_emb, bias=True) 141 | self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True) 142 | 143 | def forward(self, input, mixup_lambda=None): 144 | """ 145 | Input: (batch_size, data_length) 146 | """ 147 | 148 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 149 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 150 | 151 | x = x.transpose(1, 3) 152 | x = self.bn0(x) 153 | x = x.transpose(1, 3) 154 | 155 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 156 | x = F.dropout(x, p=0.2, training=self.training) 157 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 158 | x = F.dropout(x, p=0.2, training=self.training) 159 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 160 | x = F.dropout(x, p=0.2, training=self.training) 161 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 162 | x = F.dropout(x, p=0.2, training=self.training) 163 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 164 | x = F.dropout(x, p=0.2, training=self.training) 165 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 166 | x = F.dropout(x, p=0.2, training=self.training) 167 | x = torch.mean(x, dim=3) 168 | 169 | (x1, _) = torch.max(x, dim=2) 170 | x2 = torch.mean(x, dim=2) 171 | x = x1 + x2 172 | x = F.dropout(x, p=0.5, training=self.training) 173 | x = F.relu_(self.fc1(x)) 174 | embedding = F.dropout(x, p=0.5, training=self.training) 175 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 176 | 177 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 178 | 179 | return output_dict -------------------------------------------------------------------------------- /ldm/models/diffusion/cfm1_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pytorch_memlab import LineProfiler,profile 3 | import torch 4 | import torch.nn as nn 5 | import 
numpy as np 6 | import pytorch_lightning as pl 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from einops import rearrange, repeat 9 | from contextlib import contextmanager 10 | from functools import partial 11 | from tqdm import tqdm 12 | 13 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps 14 | from torchvision.utils import make_grid 15 | try: 16 | from pytorch_lightning.utilities.distributed import rank_zero_only 17 | except: 18 | from pytorch_lightning.utilities import rank_zero_only # torch2 19 | from torchdyn.core import NeuralODE 20 | from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config 21 | from ldm.models.diffusion.ddpm_audio import LatentDiffusion_audio, disabled_train 22 | from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like 23 | from omegaconf import ListConfig 24 | import math 25 | 26 | __conditioning_keys__ = {'concat': 'c_concat', 27 | 'crossattn': 'c_crossattn', 28 | 'adm': 'y'} 29 | 30 | 31 | class CFM(LatentDiffusion_audio): 32 | 33 | def __init__(self, **kwargs): 34 | 35 | super(CFM, self).__init__(**kwargs) 36 | self.sigma_min = 1e-4 37 | 38 | def p_losses(self, x_start, cond, t, noise=None): 39 | x1 = x_start 40 | x0 = default(noise, lambda: torch.randn_like(x_start)) 41 | ut = x1 - (1 - self.sigma_min) * x0 # 和ut的梯度没关系 42 | t_unsqueeze = t.unsqueeze(1).unsqueeze(1).float() / self.num_timesteps 43 | x_noisy = t_unsqueeze * x1 + (1. - (1 - self.sigma_min) * t_unsqueeze) * x0 44 | 45 | model_output,lb_loss = self.apply_model(x_noisy, t, cond) 46 | 47 | loss_dict = {} 48 | prefix = 'train' if self.training else 'val' 49 | target = ut 50 | 51 | mean_dims = list(range(1,len(target.shape))) 52 | loss_simple = self.get_loss(model_output, target, mean=False).mean(dim=mean_dims) 53 | 54 | 55 | loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) 56 | loss_dict.update({f'{prefix}/lb_loss': lb_loss}) 57 | 58 | 59 | loss = loss_simple 60 | loss = self.l_simple_weight * loss.mean()+lb_loss 61 | loss_dict.update({f'{prefix}/loss': loss}) 62 | 63 | return loss, loss_dict 64 | 65 | @torch.no_grad() 66 | def sample(self, cond, batch_size=16, timesteps=None, shape=None, x_latent=None, t_start=None, **kwargs): 67 | if shape is None: 68 | mel_length = math.ceil(cond['acoustic']['acousitc'].shape[2] * 1 / 2) 69 | shape = (self.channels, self.mel_dim, mel_length) if self.channels > 0 else (self.mel_dim, mel_length) 70 | if cond is not None: 71 | if isinstance(cond, dict): 72 | cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else 73 | list(map(lambda x: x[:batch_size], cond[key])) for key in cond} 74 | else: 75 | cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] 76 | 77 | neural_ode = NeuralODE(self.ode_wrapper(cond), solver='euler', sensitivity="adjoint", atol=1e-4, rtol=1e-4) 78 | t_span = torch.linspace(0, 1, 25 if timesteps is None else timesteps) 79 | if t_start is not None: 80 | t_span = t_span[t_start:] 81 | 82 | x0 = torch.randn(shape, device=self.device) if x_latent is None else x_latent 83 | eval_points, traj = neural_ode(x0, t_span) 84 | 85 | return traj[-1], traj 86 | 87 | def ode_wrapper(self, cond): 88 | # self.estimator receives x, mask, mu, t, spk as arguments 89 | return Wrapper(self, cond) 90 | 91 | @torch.no_grad() 92 | def sample_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning, batch_size=16, timesteps=None, 
shape=None, x_latent=None, t_start=None, **kwargs): 93 | if shape is None: 94 | # if self.channels > 0: 95 | # shape = (batch_size, self.channels, self.mel_dim, self.mel_length) 96 | # else: 97 | # shape = (batch_size, self.mel_dim, self.mel_length) 98 | mel_length = math.ceil(cond['acoustic']['acoustic'].shape[2] * 1 / 2) 99 | shape = (self.channels, self.mel_dim, mel_length) if self.channels > 0 else (self.mel_dim, mel_length) 100 | if cond is not None: 101 | if isinstance(cond, dict): 102 | cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else 103 | list(map(lambda x: x[:batch_size], cond[key])) for key in cond} 104 | else: 105 | cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] 106 | 107 | neural_ode = NeuralODE(self.ode_wrapper_cfg(cond, unconditional_guidance_scale, unconditional_conditioning), solver='euler', sensitivity="adjoint", atol=1e-4, rtol=1e-4) 108 | t_span = torch.linspace(0, 1, 25 if timesteps is None else timesteps) 109 | 110 | if t_start is not None: 111 | t_span = t_span[t_start:] 112 | 113 | x0 = torch.randn(shape, device=self.device) if x_latent is None else x_latent 114 | eval_points, traj = neural_ode(x0, t_span) 115 | 116 | return traj[-1], traj 117 | 118 | def ode_wrapper_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning): 119 | # self.estimator receives x, mask, mu, t, spk as arguments 120 | return Wrapper_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning) 121 | 122 | 123 | @torch.no_grad() 124 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 125 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 126 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 127 | if noise is None: 128 | noise = torch.randn_like(x0) 129 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 130 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 131 | 132 | 133 | class Wrapper(nn.Module): 134 | def __init__(self, net, cond): 135 | super(Wrapper, self).__init__() 136 | self.net = net 137 | self.cond = cond 138 | 139 | def forward(self, t, x, args): 140 | t = torch.tensor([t * 1000] * x.shape[0], device=t.device).long() 141 | results,loss= self.net.apply_model(x, t, self.cond) 142 | return results 143 | 144 | 145 | class Wrapper_cfg(nn.Module): 146 | 147 | def __init__(self, net, cond, unconditional_guidance_scale, unconditional_conditioning): 148 | super(Wrapper_cfg, self).__init__() 149 | self.net = net 150 | self.cond = cond 151 | self.unconditional_conditioning = unconditional_conditioning 152 | self.unconditional_guidance_scale = unconditional_guidance_scale 153 | 154 | def forward(self, t, x, args): 155 | # x_in = torch.cat([x] * 2) 156 | t = torch.tensor([t * 1000] * x.shape[0], device=t.device).long() 157 | # t_in = torch.cat([t] * 2) 158 | e_t,loss= self.net.apply_model(x, t, self.cond) 159 | e_t_uncond,loss= self.net.apply_model(x, t, self.unconditional_conditioning) 160 | e_t = e_t_uncond + self.unconditional_guidance_scale * (e_t - e_t_uncond) 161 | 162 | return e_t 163 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | 6 | sys.path.insert(0, '.') # nopep8 7 | from ldm.modules.losses_audio.vqperceptual import * 8 | 9 | def 
discriminator_loss_mse(disc_real_outputs, disc_generated_outputs): 10 | r_losses = 0 11 | g_losses = 0 12 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 13 | r_loss = torch.mean((1 - dr) ** 2) 14 | g_loss = torch.mean(dg ** 2) 15 | r_losses += r_loss 16 | g_losses += g_loss 17 | r_losses = r_losses / len(disc_real_outputs) 18 | g_losses = g_losses / len(disc_real_outputs) 19 | total = 0.5 * (r_losses + g_losses) 20 | return total 21 | 22 | class LPAPSWithDiscriminator(nn.Module): 23 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 24 | disc_num_layers=3, disc_in_channels=3,disc_hidden_size=64, disc_factor=1.0, disc_weight=1.0, 25 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 26 | disc_loss="hinge",r1_reg_weight=5): 27 | 28 | super().__init__() 29 | assert disc_loss in ["hinge", "vanilla","mse"] 30 | self.kl_weight = kl_weight 31 | self.pixel_weight = pixelloss_weight 32 | self.perceptual_weight = perceptual_weight 33 | if self.perceptual_weight > 0: 34 | raise RuntimeError("don't use perceptual loss") 35 | 36 | # output log variance 37 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 38 | 39 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 40 | ndf = disc_hidden_size, 41 | n_layers=disc_num_layers, 42 | use_actnorm=use_actnorm, 43 | ).apply(weights_init) # h=8,w/(2**disc_num_layers) - 2 44 | self.discriminator_iter_start = disc_start 45 | if disc_loss == "hinge": 46 | self.disc_loss = hinge_d_loss 47 | elif disc_loss == "vanilla": 48 | self.disc_loss = vanilla_d_loss 49 | elif disc_loss == 'mse': 50 | self.disc_loss = discriminator_loss_mse 51 | else: 52 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 53 | print(f"LPAPSWithDiscriminator running with {disc_loss} loss.") 54 | self.disc_factor = disc_factor 55 | self.discriminator_weight = disc_weight 56 | self.disc_conditional = disc_conditional 57 | self.r1_reg_weight = r1_reg_weight 58 | 59 | 60 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 61 | if last_layer is not None: 62 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 63 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 64 | else: 65 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 67 | 68 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 69 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 70 | d_weight = d_weight * self.discriminator_weight 71 | return d_weight 72 | 73 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 74 | global_step, last_layer=None, cond=None, split="train", weights=None): 75 | if len(inputs.shape) == 3: 76 | inputs,reconstructions = inputs.unsqueeze(1),reconstructions.unsqueeze(1) 77 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 78 | if self.perceptual_weight > 0: 79 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 80 | # print(f"p_loss {p_loss}") 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 86 | weighted_nll_loss = nll_loss 87 | if weights is not None: 88 | weighted_nll_loss = weights*nll_loss 89 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 90 | nll_loss = torch.sum(nll_loss) / 
nll_loss.shape[0] 91 | kl_loss = posteriors.kl() 92 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 93 | 94 | # now the GAN part 95 | if optimizer_idx == 0: 96 | # generator update 97 | if cond is None: 98 | assert not self.disc_conditional 99 | logits_fake = self.discriminator(reconstructions.contiguous()) 100 | else: 101 | assert self.disc_conditional 102 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 103 | g_loss = -torch.mean(logits_fake) 104 | 105 | try: 106 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 107 | except RuntimeError: 108 | assert not self.training 109 | d_weight = torch.tensor(0.0) 110 | 111 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 112 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 113 | 114 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 115 | "{}/logvar".format(split): self.logvar.detach(), 116 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 117 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 118 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 119 | "{}/d_weight".format(split): d_weight.detach(), 120 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 121 | "{}/g_loss".format(split): g_loss.detach().mean(), 122 | } 123 | return loss, log 124 | 125 | if optimizer_idx == 1: 126 | # second pass for discriminator update 127 | if cond is None: 128 | d_real_in = inputs.contiguous().detach() 129 | d_real_in.requires_grad = True 130 | logits_real = self.discriminator(d_real_in) 131 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 132 | else: 133 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 134 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 135 | 136 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 137 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) # logits_real越大,logits_fake越小说明discriminator越强 138 | if self.r1_reg_weight > 0 and split=='train': 139 | r1_grads = torch.autograd.grad(outputs=[logits_real.sum()], inputs=[d_real_in], create_graph=True, only_inputs=True) 140 | r1_grads = r1_grads[0] 141 | r1_penalty = r1_grads.square().mean() 142 | d_loss += self.r1_reg_weight * r1_penalty 143 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 144 | "{}/logits_real".format(split): logits_real.detach().mean(), 145 | "{}/logits_fake".format(split): logits_fake.detach().mean() 146 | } 147 | if self.r1_reg_weight and split=='train': 148 | log["{}/r1_prnalty".format(split)] = r1_penalty 149 | return d_loss, log 150 | 151 | 152 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | from ldm.util import exists 6 | sys.path.insert(0, '.') # nopep8 7 | # from ldm.modules.losses_audio.lpaps import LPAPS 8 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 9 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 10 | 11 | 12 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 13 | assert weights.shape[0] == 
logits_real.shape[0] == logits_fake.shape[0] 14 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 15 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 16 | loss_real = (weights * loss_real).sum() / weights.sum() 17 | loss_fake = (weights * loss_fake).sum() / weights.sum() 18 | d_loss = 0.5 * (loss_real + loss_fake) 19 | return d_loss 20 | 21 | def adopt_weight(weight, global_step, threshold=0, value=0.): 22 | if global_step < threshold: 23 | weight = value 24 | return weight 25 | 26 | 27 | def measure_perplexity(predicted_indices, n_embed): 28 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 29 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 30 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 31 | avg_probs = encodings.mean(0) 32 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 33 | cluster_use = torch.sum(avg_probs > 0) 34 | return perplexity, cluster_use 35 | 36 | def l1(x, y): 37 | return torch.abs(x-y) 38 | 39 | 40 | def l2(x, y): 41 | return torch.pow((x-y), 2) 42 | 43 | 44 | 45 | class DummyLoss(nn.Module): 46 | def __init__(self): 47 | super().__init__() 48 | 49 | class VQLPAPSWithDiscriminator(nn.Module): 50 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 51 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 52 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 53 | disc_ndf=64, disc_loss="hinge", n_classes=None, pixel_loss="l1"): 54 | super().__init__() 55 | assert disc_loss in ["hinge", "vanilla"] 56 | self.codebook_weight = codebook_weight 57 | self.pixel_weight = pixelloss_weight 58 | self.perceptual_loss = None # LPAPS().eval() 59 | self.perceptual_weight = perceptual_weight 60 | 61 | if pixel_loss == "l1": 62 | self.pixel_loss = l1 63 | else: 64 | self.pixel_loss = l2 65 | 66 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 67 | n_layers=disc_num_layers, 68 | use_actnorm=use_actnorm, 69 | ndf=disc_ndf 70 | ).apply(weights_init) 71 | self.discriminator_iter_start = disc_start 72 | if disc_loss == "hinge": 73 | self.disc_loss = hinge_d_loss 74 | elif disc_loss == "vanilla": 75 | self.disc_loss = vanilla_d_loss 76 | else: 77 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 78 | print(f"VQLPAPSWithDiscriminator running with {disc_loss} loss.") 79 | self.disc_factor = disc_factor 80 | self.discriminator_weight = disc_weight 81 | self.disc_conditional = disc_conditional 82 | self.n_classes = n_classes 83 | 84 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 85 | if last_layer is not None: 86 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 87 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 88 | else: 89 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 90 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 91 | 92 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 93 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 94 | d_weight = d_weight * self.discriminator_weight 95 | return d_weight 96 | 97 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 98 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 99 | if not exists(codebook_loss): 100 | codebook_loss = torch.tensor([0.]).to(inputs.device) 101 | 
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 102 | if self.perceptual_weight > 0: 103 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | rec_loss = rec_loss + self.perceptual_weight * p_loss 105 | else: 106 | p_loss = torch.tensor([0.0]) 107 | 108 | nll_loss = rec_loss 109 | # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 110 | nll_loss = torch.mean(nll_loss) 111 | 112 | # now the GAN part 113 | if optimizer_idx == 0: 114 | # generator update 115 | if cond is None: 116 | assert not self.disc_conditional 117 | logits_fake = self.discriminator(reconstructions.contiguous()) 118 | else: 119 | assert self.disc_conditional 120 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 121 | g_loss = -torch.mean(logits_fake) 122 | 123 | try: 124 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 125 | except RuntimeError: 126 | assert not self.training 127 | d_weight = torch.tensor(0.0) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 131 | 132 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 133 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 134 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 135 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 136 | "{}/p_loss".format(split): p_loss.detach().mean(), 137 | "{}/d_weight".format(split): d_weight.detach(), 138 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 139 | "{}/g_loss".format(split): g_loss.detach().mean(), 140 | } 141 | # if predicted_indices is not None: 142 | # assert self.n_classes is not None 143 | # with torch.no_grad(): 144 | # perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 145 | # log[f"{split}/perplexity"] = perplexity 146 | # log[f"{split}/cluster_usage"] = cluster_usage 147 | return loss, log 148 | 149 | if optimizer_idx == 1: 150 | # second pass for discriminator update 151 | if cond is None: 152 | logits_real = self.discriminator(inputs.contiguous().detach()) 153 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 154 | else: 155 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 156 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 157 | 158 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 159 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 160 | 161 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 162 | "{}/logits_real".format(split): logits_real.detach().mean(), 163 | "{}/logits_fake".format(split): logits_fake.detach().mean() 164 | } 165 | return d_loss, log 166 | 167 | -------------------------------------------------------------------------------- /ldm/modules/discriminator/multi_window_disc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Discriminator2DFactory(nn.Module): 7 | def __init__(self, time_length, freq_length=80, kernel=(3, 3), c_in=1, hidden_size=128, 8 | norm_type='bn', reduction='sum'):# if reduction = 'sum', return shape (B,1),else reduction shape(B,T) 9 | super(Discriminator2DFactory, 
self).__init__() 10 | padding = (kernel[0] // 2, kernel[1] // 2) 11 | 12 | def discriminator_block(in_filters, out_filters, first=False): 13 | """ 14 | Input: (B, in, 2H, 2W) 15 | Output:(B, out, H, W) 16 | """ 17 | conv = nn.Conv2d(in_filters, out_filters, kernel, (2, 2), padding) 18 | if norm_type == 'sn': 19 | conv = nn.utils.spectral_norm(conv) 20 | block = [ 21 | conv, # padding = kernel//2 22 | nn.LeakyReLU(0.2, inplace=True), 23 | nn.Dropout2d(0.25) 24 | ] 25 | if norm_type == 'bn' and not first: 26 | block.append(nn.BatchNorm2d(out_filters, 0.8)) 27 | if norm_type == 'in' and not first: 28 | block.append(nn.InstanceNorm2d(out_filters, affine=True)) 29 | block = nn.Sequential(*block) 30 | return block 31 | 32 | self.model = nn.ModuleList([ 33 | discriminator_block(c_in, hidden_size, first=True), 34 | discriminator_block(hidden_size, hidden_size), 35 | discriminator_block(hidden_size, hidden_size), 36 | ]) 37 | 38 | self.reduction = reduction 39 | ds_size = (time_length // 2 ** 3, (freq_length + 7) // 2 ** 3) 40 | if reduction != 'none': 41 | # The height and width of downsampled image 42 | self.adv_layer = nn.Linear(hidden_size * ds_size[0] * ds_size[1], 1) 43 | else: 44 | self.adv_layer = nn.Linear(hidden_size * ds_size[1], 1) 45 | 46 | def forward(self, x): 47 | """ 48 | 49 | :param x: [B, C, T, n_bins] 50 | :return: validity: [B, 1], h: List of hiddens 51 | """ 52 | h = [] 53 | for l in self.model: 54 | x = l(x) 55 | h.append(x) 56 | if self.reduction != 'none': 57 | x = x.view(x.shape[0], -1) 58 | validity = self.adv_layer(x) # [B, 1] 59 | else: 60 | B, _, T_, _ = x.shape 61 | x = x.transpose(1, 2).reshape(B, T_, -1) 62 | validity = self.adv_layer(x)[:, :, 0] # [B, T] 63 | return validity, h 64 | 65 | 66 | class MultiWindowDiscriminator(nn.Module): 67 | def __init__(self, time_lengths, cond_size=0, freq_length=80, kernel=(3, 3), 68 | c_in=1, hidden_size=128, norm_type='bn', reduction='sum'): 69 | super(MultiWindowDiscriminator, self).__init__() 70 | self.win_lengths = time_lengths 71 | self.reduction = reduction 72 | 73 | self.conv_layers = nn.ModuleList() 74 | if cond_size > 0: 75 | self.cond_proj_layers = nn.ModuleList() 76 | self.mel_proj_layers = nn.ModuleList() 77 | for time_length in time_lengths: 78 | conv_layer = [ 79 | Discriminator2DFactory( 80 | time_length, freq_length, kernel, c_in=c_in, hidden_size=hidden_size, 81 | norm_type=norm_type, reduction=reduction) 82 | ] 83 | self.conv_layers += conv_layer 84 | if cond_size > 0: 85 | self.cond_proj_layers.append(nn.Linear(cond_size, freq_length)) 86 | self.mel_proj_layers.append(nn.Linear(freq_length, freq_length)) 87 | 88 | def forward(self, x, x_len, cond=None, start_frames_wins=None): 89 | ''' 90 | Args: 91 | x (tensor): input mel, (B, c_in, T, n_bins). 92 | x_length (tensor): len of per mel. (B,). 93 | 94 | Returns: 95 | tensor : (B). 
96 | ''' 97 | validity = [] 98 | if start_frames_wins is None: 99 | start_frames_wins = [None] * len(self.conv_layers) 100 | h = [] 101 | for i, start_frames in zip(range(len(self.conv_layers)), start_frames_wins): 102 | x_clip, c_clip, start_frames = self.clip( 103 | x, cond, x_len, self.win_lengths[i], start_frames) # x_clip:(B, 1, win_length, C) 104 | start_frames_wins[i] = start_frames 105 | if x_clip is None: 106 | continue 107 | if cond is not None: 108 | x_clip = self.mel_proj_layers[i](x_clip) # (B, 1, win_length, C) 109 | c_clip = self.cond_proj_layers[i](c_clip)[:, None] # (B, 1, win_length, C) 110 | x_clip = x_clip + c_clip 111 | x_clip, h_ = self.conv_layers[i](x_clip) 112 | h += h_ 113 | validity.append(x_clip) 114 | if len(validity) != len(self.conv_layers): 115 | return None, start_frames_wins, h 116 | if self.reduction == 'sum': 117 | validity = sum(validity) # [B] 118 | elif self.reduction == 'stack': 119 | validity = torch.stack(validity, -1) # [B, W_L] 120 | elif self.reduction == 'none': 121 | validity = torch.cat(validity, -1) # [B, W_sum] 122 | return validity, start_frames_wins, h 123 | 124 | def clip(self, x, cond, x_len, win_length, start_frames=None): 125 | '''Ramdom clip x to win_length. 126 | Args: 127 | x (tensor) : (B, c_in, T, n_bins). 128 | cond (tensor) : (B, T, H). 129 | x_len (tensor) : (B,). 130 | win_length (int): target clip length 131 | 132 | Returns: 133 | (tensor) : (B, c_in, win_length, n_bins). 134 | 135 | ''' 136 | T_start = 0 137 | T_end = x_len.max() - win_length # if x_len < win_length. None will be returned 138 | if T_end < 0: 139 | return None, None, start_frames 140 | T_end = T_end.item() 141 | if start_frames is None: 142 | start_frame = np.random.randint(low=T_start, high=T_end + 1) 143 | start_frames = [start_frame] * x.size(0) 144 | else: 145 | start_frame = start_frames[0] 146 | x_batch = x[:, :, start_frame: start_frame + win_length] 147 | c_batch = cond[:, start_frame: start_frame + win_length] if cond is not None else None 148 | return x_batch, c_batch, start_frames 149 | 150 | 151 | class Discriminator(nn.Module): 152 | def __init__(self, time_lengths=[32, 64, 128], freq_length=80, cond_size=0, kernel=(3, 3), c_in=1, 153 | hidden_size=128, norm_type='bn', reduction='sum', uncond_disc=True): 154 | super(Discriminator, self).__init__() 155 | self.time_lengths = time_lengths 156 | self.cond_size = cond_size 157 | self.reduction = reduction 158 | self.uncond_disc = uncond_disc 159 | if uncond_disc: 160 | self.discriminator = MultiWindowDiscriminator( 161 | freq_length=freq_length, 162 | time_lengths=time_lengths, 163 | kernel=kernel, 164 | c_in=c_in, hidden_size=hidden_size, norm_type=norm_type, 165 | reduction=reduction 166 | ) 167 | if cond_size > 0: 168 | self.cond_disc = MultiWindowDiscriminator( 169 | freq_length=freq_length, 170 | time_lengths=time_lengths, 171 | cond_size=cond_size, 172 | kernel=kernel, 173 | c_in=c_in, hidden_size=hidden_size, norm_type=norm_type, 174 | reduction=reduction 175 | ) 176 | 177 | def forward(self, x, cond=None,x_len=None, start_frames_wins=None): 178 | """ 179 | 180 | :param x: [B, T, 80] 181 | :param cond: [B, T, cond_size] 182 | :param return_y_only: 183 | :return: 184 | """ 185 | if len(x.shape) == 3: 186 | x = x[:, None, :, :] 187 | if x_len == None: 188 | # print("注意这里x_len的统计方式有问题这里假设padvalue是0,此外reconstruction注意传入之前就要处理好mask") 189 | x_len = x.sum([1, -1]).ne(0).int().sum([-1]) # shape(B,) 190 | ret = {'y_c': None, 'y': None} 191 | if self.uncond_disc: 192 | ret['y'], start_frames_wins, 
ret['h'] = self.discriminator( 193 | x, x_len, start_frames_wins=start_frames_wins) 194 | if self.cond_size > 0 and cond is not None: 195 | ret['y_c'], start_frames_wins, ret['h_c'] = self.cond_disc( 196 | x, x_len, cond, start_frames_wins=start_frames_wins) 197 | ret['start_frames_wins'] = start_frames_wins 198 | return ret -------------------------------------------------------------------------------- /ldm/modules/new_attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | 48 | class Conv1dGEGLU(nn.Module): 49 | def __init__(self, dim_in, dim_out,kernel_size = 9): 50 | super().__init__() 51 | self.proj = nn.Conv1d(dim_in, dim_out * 2,kernel_size=kernel_size,padding=kernel_size//2) 52 | 53 | def forward(self, x): 54 | x, gate = self.proj(x).chunk(2, dim=1) 55 | return x * F.gelu(gate) 56 | 57 | class Conv1dFeedForward(nn.Module): 58 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.,kernel_size = 9): 59 | super().__init__() 60 | inner_dim = int(dim * mult) 61 | dim_out = default(dim_out, dim) 62 | project_in = nn.Sequential( 63 | nn.Conv1d(dim, inner_dim,kernel_size=kernel_size,padding=kernel_size//2), 64 | nn.GELU() 65 | ) if not glu else Conv1dGEGLU(dim, inner_dim) 66 | 67 | self.net = nn.Sequential( 68 | project_in, 69 | nn.Dropout(dropout), 70 | nn.Conv1d(inner_dim, dim_out,kernel_size=kernel_size,padding=kernel_size//2) 71 | ) 72 | 73 | def forward(self, x): # x shape (B,C,T) 74 | return self.net(x) 75 | 76 | def zero_module(module): 77 | """ 78 | Zero out the parameters of a module and return it.zero-initializing the final convolutional layer in each block prior to any residual connections can accelerate training. 
79 | """ 80 | for p in module.parameters(): 81 | p.detach().zero_() 82 | return module 83 | 84 | 85 | def Normalize(in_channels): 86 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 87 | 88 | 89 | class CrossAttention(nn.Module): 90 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):# 如果设置了context_dim就不是自注意力了 91 | super().__init__() 92 | inner_dim = dim_head * heads # inner_dim == SpatialTransformer.model_channels 93 | context_dim = default(context_dim, query_dim) 94 | 95 | self.scale = dim_head ** -0.5 96 | self.heads = heads 97 | 98 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 99 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 100 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 101 | 102 | self.to_out = nn.Sequential( 103 | nn.Linear(inner_dim, query_dim), 104 | nn.Dropout(dropout) 105 | ) 106 | 107 | def forward(self, x, context=None, mask=None):# x:(b,T,C), context:(b,seq_len,context_dim) 108 | h = self.heads 109 | 110 | q = self.to_q(x)# q:(b,T,inner_dim) 111 | context = default(context, x) 112 | k = self.to_k(context)# (b,seq_len,inner_dim) 113 | v = self.to_v(context)# (b,seq_len,inner_dim) 114 | 115 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))# n is seq_len for k and v 116 | 117 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (b*head,T,seq_len) 118 | 119 | if exists(mask):# false 120 | mask = rearrange(mask, 'b ... -> b (...)') 121 | max_neg_value = -torch.finfo(sim.dtype).max 122 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 123 | sim.masked_fill_(~mask, max_neg_value) 124 | 125 | # attention, what we cannot get enough of 126 | attn = sim.softmax(dim=-1) 127 | 128 | out = einsum('b i j, b j d -> b i d', attn, v)# (b*head,T,inner_dim/head) 129 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)# (b,T,inner_dim) 130 | return self.to_out(out) 131 | 132 | class BasicTransformerBlock(nn.Module): 133 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): # 1 self 1 cross or 2 self 134 | super().__init__() 135 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention,if context is none 136 | self.ff = Conv1dFeedForward(dim, dropout=dropout, glu=gated_ff) 137 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 138 | heads=n_heads, dim_head=d_head, dropout=dropout) # use as cross attention 139 | self.norm1 = nn.LayerNorm(dim) 140 | self.norm2 = nn.LayerNorm(dim) 141 | self.norm3 = nn.LayerNorm(dim) 142 | self.checkpoint = checkpoint 143 | 144 | def forward(self, x, context=None): 145 | if context is None: 146 | return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) 147 | else: 148 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 149 | 150 | def _forward(self, x, context=None):# x shape:(B,T,C) 151 | x = self.attn1(self.norm1(x)) + x 152 | x = self.attn2(self.norm2(x), context=context) + x 153 | 154 | x = self.ff(self.norm3(x).permute(0,2,1)).permute(0,2,1) + x 155 | return x 156 | 157 | class TemporalTransformer(nn.Module): 158 | """ 159 | Transformer block for image-like data. 160 | First, project the input (aka embedding) 161 | and reshape to b, t, d. 162 | Then apply standard transformer action. 
163 | Finally, reshape to image 164 | """ 165 | def __init__(self, in_channels, n_heads, d_head, 166 | depth=1, dropout=0., context_dim=None): 167 | super().__init__() 168 | self.in_channels = in_channels 169 | inner_dim = n_heads * d_head 170 | self.norm = Normalize(in_channels) 171 | 172 | self.proj_in = nn.Conv1d(in_channels, 173 | inner_dim, 174 | kernel_size=1, 175 | stride=1, 176 | padding=0) 177 | 178 | self.transformer_blocks = nn.ModuleList( 179 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 180 | for d in range(depth)] 181 | ) 182 | 183 | self.proj_out = zero_module(nn.Conv1d(inner_dim, 184 | in_channels, 185 | kernel_size=1, 186 | stride=1, 187 | padding=0))# initialize with zero 188 | 189 | def forward(self, x, context=None):# x shape (b,c,t) 190 | # note: if no context is given, cross-attention defaults to self-attention 191 | x_in = x 192 | x = self.norm(x)# group norm 193 | x = self.proj_in(x)# no shape change 194 | x = rearrange(x,'b c t -> b t c') 195 | for block in self.transformer_blocks: 196 | x = block(x, context=context)# context shape [b,seq_len=77,context_dim] 197 | x = rearrange(x,'b t c -> b c t') 198 | 199 | x = self.proj_out(x) 200 | return x + x_in 201 | 202 | 203 | class PositionEmbedding(nn.Module): 204 | MODE_EXPAND = 'MODE_EXPAND' 205 | MODE_ADD = 'MODE_ADD' 206 | MODE_CONCAT = 'MODE_CONCAT' 207 | def __init__(self, 208 | num_embeddings, 209 | embedding_dim, 210 | mode=MODE_ADD): 211 | super(PositionEmbedding, self).__init__() 212 | self.num_embeddings = num_embeddings 213 | self.embedding_dim = embedding_dim 214 | self.mode = mode 215 | if self.mode == self.MODE_EXPAND: 216 | self.weight = nn.Parameter(torch.Tensor(num_embeddings * 2 + 1, embedding_dim)) 217 | else: 218 | self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) 219 | self.reset_parameters() 220 | 221 | def reset_parameters(self): 222 | # use xavier_normal_ to initialize 223 | torch.nn.init.xavier_normal_(self.weight) 224 | # use sin cons to initialize 225 | # X = torch.arange(self.num_embeddings, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, 226 | # torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) / self.embedding_dim) 227 | # init = torch.Tensor(self.num_embeddings,self.embedding_dim) 228 | # init[:, 0::2] = torch.sin(X) 229 | # init[:, 1::2] = torch.cos(X) 230 | # self.weight.data.copy_(init) 231 | 232 | def forward(self, x): 233 | if self.mode == self.MODE_EXPAND: 234 | indices = torch.clamp(x, -self.num_embeddings, self.num_embeddings) + self.num_embeddings 235 | return F.embedding(indices.type(torch.LongTensor), self.weight) 236 | batch_size, seq_len = x.size()[:2] 237 | embeddings = self.weight[:seq_len, :].view(1, seq_len, self.embedding_dim) 238 | if self.mode == self.MODE_ADD: 239 | return x + embeddings 240 | if self.mode == self.MODE_CONCAT: 241 | return torch.cat((x, embeddings.repeat(batch_size, 1, 1)), dim=-1) 242 | raise NotImplementedError('Unknown mode: %s' % self.mode) 243 | 244 | def extra_repr(self): 245 | return 'num_embeddings={}, embedding_dim={}, mode={}'.format( 246 | self.num_embeddings, self.embedding_dim, self.mode, 247 | ) 248 | -------------------------------------------------------------------------------- /ldm/data/joinaudiodataset_struct_sample_anylen.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. 
(“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/Text-to-Audio/Make-An-Audio/blob/main/ldm/data/joinaudiodataset_624.py 8 | import sys 9 | import numpy as np 10 | import torch 11 | from typing import TypeVar, Optional, Iterator 12 | import logging 13 | import pandas as pd 14 | from ldm.data.joinaudiodataset_anylen import * 15 | import glob 16 | logger = logging.getLogger(f'main.{__name__}') 17 | 18 | sys.path.insert(0, '.') # nopep8 19 | 20 | class JoinManifestSpecs(torch.utils.data.Dataset): 21 | def __init__(self, split, main_spec_dir_path,other_spec_dir_path, mel_num=80,mode='pad', spec_crop_len=1248,pad_value=-5,drop=0,**kwargs): 22 | super().__init__() 23 | self.split = split 24 | self.max_batch_len = spec_crop_len 25 | self.min_batch_len = 64 26 | self.min_factor = 4 27 | self.mel_num = mel_num 28 | self.drop = drop 29 | self.pad_value = pad_value 30 | assert mode in ['pad','tile'] 31 | self.collate_mode = mode 32 | manifest_files = [] 33 | for dir_path in main_spec_dir_path.split(','): 34 | manifest_files += glob.glob(f'{dir_path}/*.tsv') 35 | df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files] 36 | self.df_main = pd.concat(df_list,ignore_index=True) 37 | 38 | manifest_files = [] 39 | for dir_path in other_spec_dir_path.split(','): 40 | manifest_files += glob.glob(f'{dir_path}/*.tsv') 41 | df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files] 42 | self.df_other = pd.concat(df_list,ignore_index=True) 43 | self.df_other.reset_index(inplace=True) 44 | 45 | if split == 'train': 46 | self.dataset = self.df_main.iloc[100:] 47 | elif split == 'valid' or split == 'val': 48 | self.dataset = self.df_main.iloc[:100] 49 | elif split == 'test': 50 | self.df_main = self.add_name_num(self.df_main) 51 | self.dataset = self.df_main 52 | else: 53 | raise ValueError(f'Unknown split {split}') 54 | self.dataset.reset_index(inplace=True) 55 | print('dataset len:', len(self.dataset),"drop_rate",self.drop) 56 | 57 | def add_name_num(self,df): 58 | """each file may have different caption, we add num to filename to identify each audio-caption pair""" 59 | name_count_dict = {} 60 | change = [] 61 | for t in df.itertuples(): 62 | name = getattr(t,'name') 63 | if name in name_count_dict: 64 | name_count_dict[name] += 1 65 | else: 66 | name_count_dict[name] = 0 67 | change.append((t[0],name_count_dict[name])) 68 | for t in change: 69 | df.loc[t[0],'name'] = str(df.loc[t[0],'name']) + f'_{t[1]}' 70 | return df 71 | 72 | def ordered_indices(self): 73 | index2dur = self.dataset[['duration']].sort_values(by='duration') 74 | index2dur_other = self.df_other[['duration']].sort_values(by='duration') 75 | other_indices = list(index2dur_other.index) 76 | offset = len(self.dataset) 77 | other_indices = [x + offset for x in other_indices] 78 | return list(index2dur.index),other_indices 79 | 80 | def collater(self,inputs): 81 | to_dict = {} 82 | for l in inputs: 83 | for k,v in l.items(): 84 | if k in to_dict: 85 | to_dict[k].append(v) 86 | else: 87 | to_dict[k] = [v] 88 | 89 | if self.collate_mode == 'pad': 90 | to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor) 91 | elif self.collate_mode == 'tile': 92 | to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = 
self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor) 93 | else: 94 | raise NotImplementedError 95 | to_dict['caption'] = {'ori_caption':[c['ori_caption'] for c in to_dict['caption']], 96 | 'struct_caption':[c['struct_caption'] for c in to_dict['caption']]} 97 | 98 | return to_dict 99 | 100 | def __getitem__(self, idx): 101 | if idx < len(self.dataset): 102 | data = self.dataset.iloc[idx] 103 | p = np.random.uniform(0,1) 104 | if p > self.drop: 105 | ori_caption = data['ori_cap'] 106 | struct_caption = data['caption'] 107 | else: 108 | ori_caption = "" 109 | struct_caption = "" 110 | else: 111 | data = self.df_other.iloc[idx-len(self.dataset)] 112 | p = np.random.uniform(0,1) 113 | if p > self.drop: 114 | ori_caption = data['caption'] 115 | struct_caption = f'<{ori_caption}& all>' 116 | else: 117 | ori_caption = "" 118 | struct_caption = "" 119 | item = {} 120 | try: 121 | spec = np.load(data['mel_path']) # mel spec [80, T] 122 | if spec.shape[1] > self.max_batch_len: 123 | spec = spec[:,:self.max_batch_len] 124 | except: 125 | mel_path = data['mel_path'] 126 | print(f'corrupted:{mel_path}') 127 | spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value 128 | 129 | item['image'] = spec 130 | item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption} 131 | if self.split == 'test': 132 | item['f_name'] = data['name'] 133 | return item 134 | 135 | def __len__(self): 136 | return len(self.dataset) + len(self.df_other) 137 | 138 | 139 | class JoinSpecsTrain(JoinManifestSpecs): 140 | def __init__(self, specs_dataset_cfg): 141 | super().__init__('train', **specs_dataset_cfg) 142 | 143 | class JoinSpecsValidation(JoinManifestSpecs): 144 | def __init__(self, specs_dataset_cfg): 145 | super().__init__('valid', **specs_dataset_cfg) 146 | 147 | class JoinSpecsTest(JoinManifestSpecs): 148 | def __init__(self, specs_dataset_cfg): 149 | super().__init__('test', **specs_dataset_cfg) 150 | 151 | 152 | 153 | class DDPIndexBatchSampler(Sampler):# 让长度相似的音频的indices合到一个batch中以避免过长的pad 154 | def __init__(self, main_indices,other_indices,batch_size, num_replicas: Optional[int] = None, 155 | rank: Optional[int] = None, shuffle: bool = True, 156 | seed: int = 0, drop_last: bool = False) -> None: 157 | if num_replicas is None: 158 | if not dist.is_initialized(): 159 | # raise RuntimeError("Requires distributed package to be available") 160 | print("Not in distributed mode") 161 | num_replicas = 1 162 | else: 163 | num_replicas = dist.get_world_size() 164 | if rank is None: 165 | if not dist.is_initialized(): 166 | # raise RuntimeError("Requires distributed package to be available") 167 | rank = 0 168 | else: 169 | rank = dist.get_rank() 170 | if rank >= num_replicas or rank < 0: 171 | raise ValueError( 172 | "Invalid rank {}, rank should be in the interval" 173 | " [0, {}]".format(rank, num_replicas - 1)) 174 | self.main_indices = main_indices 175 | self.other_indices = other_indices 176 | self.max_index = max(self.other_indices) 177 | self.num_replicas = num_replicas 178 | self.rank = rank 179 | self.epoch = 0 180 | self.drop_last = drop_last 181 | self.batch_size = batch_size 182 | self.shuffle = shuffle 183 | self.batches = self.build_batches() 184 | self.seed = seed 185 | 186 | def set_epoch(self,epoch): 187 | # print("!!!!!!!!!!!set epoch is called!!!!!!!!!!!!!!") 188 | self.epoch = epoch 189 | if self.shuffle: 190 | np.random.seed(self.seed+self.epoch) 191 | self.batches = self.build_batches() 192 | 193 | def build_batches(self): 194 | 
batches,batch = [],[] 195 | for index in self.main_indices: 196 | batch.append(index) 197 | if len(batch) == self.batch_size: 198 | batches.append(batch) 199 | batch = [] 200 | if not self.drop_last and len(batch) > 0: 201 | batches.append(batch) 202 | selected_others = np.random.choice(len(self.other_indices),len(batches),replace=False) 203 | for index in selected_others: 204 | if index + self.batch_size > len(self.other_indices): 205 | index = len(self.other_indices) - self.batch_size 206 | batch = [self.other_indices[index + i] for i in range(self.batch_size)] 207 | batches.append(batch) 208 | self.batches = batches 209 | if self.shuffle: 210 | self.batches = np.random.permutation(self.batches) 211 | if self.rank == 0: 212 | print(f"rank: {self.rank}, batches_num {len(self.batches)}") 213 | 214 | if self.drop_last and len(self.batches) % self.num_replicas != 0: 215 | self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas] 216 | if len(self.batches) >= self.num_replicas: 217 | self.batches = self.batches[self.rank::self.num_replicas] 218 | else: # may happen in sanity checking 219 | self.batches = [self.batches[0]] 220 | if self.rank == 0: 221 | print(f"after split batches_num {len(self.batches)}") 222 | 223 | return self.batches 224 | 225 | def __iter__(self) -> Iterator[List[int]]: 226 | print(f"len(self.batches):{len(self.batches)}") 227 | for batch in self.batches: 228 | yield batch 229 | 230 | def __len__(self) -> int: 231 | return len(self.batches) 232 | -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):# 如果设置了context_dim就不是自注意力了 154 | super().__init__() 155 | inner_dim = dim_head * heads # inner_dim == SpatialTransformer.model_channels 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None):# x:(b,t,c), context:(b,seq_len,context_dim) mask shape (b,1,t,seq_len) 171 | h = self.heads 172 | 173 | q = self.to_q(x)# q:(b,t,inner_dim) 174 | context = default(context, x).contiguous() 175 | k = self.to_k(context)# (b,seq_len,inner_dim) 176 | v = self.to_v(context)# (b,seq_len,inner_dim) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))# n is seq_len for k and v 179 | 180 | sim = einsum('b i d, b j d 
-> b i j', q, k) * self.scale # (b*head,t,seq_len) 181 | 182 | if exists(mask):# false 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v)# (b*head,t,inner_dim/head) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)# (b,t,inner_dim) 193 | return self.to_out(out).contiguous() 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | if context is None: 210 | return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) 211 | else: 212 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 213 | 214 | def _forward(self, x, context=None): 215 | x = self.attn1(self.norm1(x)) + x 216 | x = self.attn2(self.norm2(x), context=context) + x 217 | x = self.ff(self.norm3(x)) + x 218 | return x 219 | 220 | 221 | class SpatialTransformer(nn.Module): 222 | """ 223 | Transformer block for image-like data. 224 | First, project the input (aka embedding) 225 | and reshape to b, t, d. 226 | Then apply standard transformer action. 
227 | Finally, reshape to image 228 | """ 229 | def __init__(self, in_channels, n_heads, d_head, 230 | depth=1, dropout=0., context_dim=None): 231 | super().__init__() 232 | self.in_channels = in_channels 233 | inner_dim = n_heads * d_head 234 | self.norm = Normalize(in_channels) 235 | 236 | self.proj_in = nn.Conv2d(in_channels, 237 | inner_dim, 238 | kernel_size=1, 239 | stride=1, 240 | padding=0) 241 | 242 | self.transformer_blocks = nn.ModuleList( 243 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 244 | for d in range(depth)] 245 | ) 246 | 247 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 248 | in_channels, 249 | kernel_size=1, 250 | stride=1, 251 | padding=0)) 252 | 253 | def forward(self, x, context=None): 254 | # note: if no context is given, cross-attention defaults to self-attention 255 | b, c, h, w = x.shape # such as [2,320,10,106] 256 | x_in = x 257 | x = self.norm(x)# group norm 258 | x = self.proj_in(x)# no shape change 259 | x = rearrange(x, 'b c h w -> b (h w) c') 260 | for block in self.transformer_blocks: 261 | x = block(x, context=context)# context shape [b,seq_len=77,context_dim] 262 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 263 | x = self.proj_out(x) 264 | x_out = x + x_in 265 | return x_out -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in 
__models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction 
== "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', torch.tensor(self.global_step, dtype=torch.float32), logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = 
log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | --------------------------------------------------------------------------------
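
The CrossAttention and SpatialTransformer modules in ldm/modules/attention.py above condition the mel-spectrogram latent on text by flattening the (b, c, h, w) feature map into h*w tokens, projecting queries from those tokens and keys/values from the conditioning sequence, and folding the attended result back into the original spatial layout. The stand-alone sketch below (not part of the repository) traces that shape bookkeeping with plain torch and einops; the 2x320x10x106 latent and the 77-token context follow the inline comments in the source, while context_dim=1024, heads=8 and dim_head=40 are illustrative assumptions chosen so the per-head projections keep the channel count unchanged.

import torch
from torch import einsum
from einops import rearrange

# Shapes taken from the inline comments in attention.py; context_dim, heads and
# dim_head are illustrative assumptions, not values read from any repo config.
b, c, h, w = 2, 320, 10, 106            # example latent quoted in SpatialTransformer.forward
seq_len, context_dim = 77, 1024         # assumed text-conditioning sequence
heads, dim_head = 8, 40                 # chosen so inner_dim == c (no channel change)
inner_dim = heads * dim_head

x = torch.randn(b, c, h, w)                      # mel latent
context = torch.randn(b, seq_len, context_dim)   # conditioning tokens

to_q = torch.nn.Linear(c, inner_dim, bias=False)
to_k = torch.nn.Linear(context_dim, inner_dim, bias=False)
to_v = torch.nn.Linear(context_dim, inner_dim, bias=False)

tokens = rearrange(x, 'b c h w -> b (h w) c')                      # (2, 1060, 320)
q = rearrange(to_q(tokens), 'b n (h d) -> (b h) n d', h=heads)     # (16, 1060, 40)
k = rearrange(to_k(context), 'b n (h d) -> (b h) n d', h=heads)    # (16, 77, 40)
v = rearrange(to_v(context), 'b n (h d) -> (b h) n d', h=heads)    # (16, 77, 40)

sim = einsum('b i d, b j d -> b i j', q, k) * dim_head ** -0.5     # (16, 1060, 77)
attn = sim.softmax(dim=-1)                # each latent token attends over the 77 context tokens
out = einsum('b i j, b j d -> b i d', attn, v)                     # (16, 1060, 40)
out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)            # (2, 1060, 320)
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)             # back to (2, 320, 10, 106)
print(out.shape)

The printed shape matches the input latent, which is why SpatialTransformer can be inserted at any UNet resolution without changing the surrounding tensor shapes.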