├── preprocess ├── __init__.py ├── NAT_mel.py └── mel_spec.py ├── vocoder └── bigvgan │ ├── __init__.py │ ├── alias_free_torch │ ├── __init__.py │ ├── act.py │ ├── resample.py │ └── filter.py │ └── activations.py ├── ldm ├── models │ ├── diffusion │ │ ├── __init__.py │ │ └── classifier.py │ └── autoencoder_multi.py ├── modules │ ├── encoders │ │ ├── __init__.py │ │ ├── CLAP │ │ │ ├── __init__.py │ │ │ ├── config.yml │ │ │ ├── utils.py │ │ │ ├── clap.py │ │ │ ├── audio.py │ │ │ └── CLAPWrapper.py │ │ └── modules.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ └── util.py │ ├── losses_audio │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ ├── ema.py │ ├── discriminator │ │ ├── multi_window_disc.py │ │ └── model.py │ └── attention.py ├── data │ └── joinaudiodataset_624.py ├── lr_scheduler.py └── util.py ├── wav_evaluation ├── models │ ├── __init__.py │ ├── utils.py │ ├── clap.py │ └── audio.py └── cal_clap_score.py ├── useful_ckpts ├── .DS_Store └── CLAP │ └── config.yml ├── .gitignore ├── requirements.txt ├── configs ├── text_to_audio │ ├── clap_args.yaml │ ├── bigvgan_args.yaml │ └── txt2audio_args.yaml └── train │ ├── vae.yaml │ └── diffusion.yaml ├── scripts ├── test.py └── audio2audio.py ├── LICENSE ├── .gitattributes ├── gen_wav.py ├── gen_wavs_by_tsv.py └── README.md /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vocoder/bigvgan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wav_evaluation/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import clap 2 | from . import audio 3 | from . import utils -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/__init__.py: -------------------------------------------------------------------------------- 1 | from . import clap 2 | from . import audio 3 | from . 
import utils -------------------------------------------------------------------------------- /useful_ckpts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Text-to-Audio/Make-An-Audio/HEAD/useful_ckpts/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *__pycache__ 3 | useful_ckpts/bigvgan 4 | useful_ckpts/*.ckpt 5 | useful_ckpts/CLAP/*.ckpt 6 | evaluation 7 | .idea/ 8 | logs 9 | audiocaps_gen 10 | audioldm_eval 11 | src 12 | processed 13 | run.sh 14 | *.DS_Store -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .filter import * 5 | from .resample import * 6 | from .act import * -------------------------------------------------------------------------------- /ldm/modules/losses_audio/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses_audio.vqperceptual import DummyLoss 2 | 3 | # relative imports pain 4 | import os 5 | import sys 6 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'vggishish') 7 | sys.path.append(path) 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | torch 3 | torch-fidelity==0.3.0 4 | scipy 5 | importlib_resources 6 | torchaudio>=0.13.0 7 | torchvision>=0.14.0 8 | tqdm 9 | omegaconf 10 | einops 11 | numpy<=1.23.5 12 | soundfile 13 | librosa==0.9.2 14 | pandas 15 | torchlibrosa 16 | transformers==4.18.0 17 | ftfy 18 | pytorch-lightning==1.7.0 19 | torchmetrics==0.11.1 20 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 21 | -------------------------------------------------------------------------------- /useful_ckpts/CLAP/config.yml: -------------------------------------------------------------------------------- 1 | # TEXT ENCODER CONFIG 2 | text_model: 'bert-base-uncased' 3 | text_len: 100 4 | transformer_embed_dim: 768 5 | freeze_text_encoder_weights: True 6 | 7 | # AUDIO ENCODER CONFIG 8 | audioenc_name: 'Cnn14' 9 | out_emb: 2048 10 | sampling_rate: 44100 11 | duration: 9 12 | fmin: 50 13 | fmax: 14000 14 | n_fft: 1028 15 | hop_size: 320 16 | mel_bins: 64 17 | window_size: 1024 18 | 19 | # PROJECTION SPACE CONFIG 20 | d_proj: 1024 21 | temperature: 0.003 22 | 23 | # TRAINING AND EVALUATION CONFIG 24 | num_classes: 527 25 | batch_size: 1024 26 | demo: False 27 | -------------------------------------------------------------------------------- /configs/text_to_audio/clap_args.yaml: -------------------------------------------------------------------------------- 1 | # TEXT ENCODER CONFIG 2 | text_model: 'bert-base-uncased' 3 | text_len: 100 4 | transformer_embed_dim: 768 5 | freeze_text_encoder_weights: True 6 | 7 | # AUDIO ENCODER CONFIG 8 | audioenc_name: 'Cnn14' 9 | out_emb: 2048 10 | sampling_rate: 44100 11 | duration: 9 12 | fmin: 50 13 | fmax: 14000 14 | n_fft: 1028 15 | hop_size: 320 16 | mel_bins: 64 17 | 
window_size: 1024 18 | 19 | # PROJECTION SPACE CONFIG 20 | d_proj: 1024 21 | temperature: 0.003 22 | 23 | # TRAINING AND EVALUATION CONFIG 24 | num_classes: 527 25 | batch_size: 1024 26 | demo: False 27 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/config.yml: -------------------------------------------------------------------------------- 1 | # TEXT ENCODER CONFIG 2 | text_model: 'bert-base-uncased' 3 | text_len: 100 4 | transformer_embed_dim: 768 5 | freeze_text_encoder_weights: True 6 | 7 | # AUDIO ENCODER CONFIG 8 | audioenc_name: 'Cnn14' 9 | out_emb: 2048 10 | sampling_rate: 44100 11 | duration: 5 12 | fmin: 50 13 | fmax: 14000 14 | n_fft: 1028 15 | hop_size: 320 16 | mel_bins: 64 17 | window_size: 1024 18 | 19 | # PROJECTION SPACE CONFIG 20 | d_proj: 1024 21 | temperature: 0.003 22 | 23 | # TRAINING AND EVALUATION CONFIG 24 | num_classes: 527 25 | batch_size: 1024 26 | demo: False 27 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from audioldm_eval import EvaluationHelper 3 | import argparse 4 | 5 | device = torch.device(f"cuda:{0}") 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--pred_wavsdir',type=str) 10 | parser.add_argument('--gt_wavsdir',type=str) 11 | args = parser.parse_args() 12 | return args 13 | 14 | if __name__ == '__main__': 15 | args = parse_args() 16 | generation_result_path = args.pred_wavsdir 17 | target_audio_path = args.gt_wavsdir 18 | 19 | evaluator = EvaluationHelper(16000, device) 20 | 21 | # Perform evaluation, result will be print out and saved as json 22 | metrics = evaluator.main( 23 | generation_result_path, 24 | target_audio_path, 25 | ) 26 | -------------------------------------------------------------------------------- /wav_evaluation/models/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import sys 4 | 5 | def read_config_as_args(config_path,args=None,is_config_str=False): 6 | return_dict = {} 7 | 8 | if config_path is not None: 9 | if is_config_str: 10 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader) 11 | else: 12 | with open(config_path, "r") as f: 13 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 14 | 15 | if args != None: 16 | for k, v in yml_config.items(): 17 | if k in args.__dict__: 18 | args.__dict__[k] = v 19 | else: 20 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) 21 | else: 22 | for k, v in yml_config.items(): 23 | return_dict[k] = v 24 | 25 | args = args if args != None else return_dict 26 | return argparse.Namespace(**args) 27 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import sys 4 | 5 | def read_config_as_args(config_path,args=None,is_config_str=False): 6 | return_dict = {} 7 | 8 | if config_path is not None: 9 | if is_config_str: 10 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader) 11 | else: 12 | with open(config_path, "r") as f: 13 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 14 | 15 | if args != None: 16 | for k, v in yml_config.items(): 17 | if k in args.__dict__: 18 | args.__dict__[k] = v 19 | else: 20 | sys.stderr.write("Ignored 
unknown parameter {} in yaml.\n".format(k)) 21 | else: 22 | for k, v in yml_config.items(): 23 | return_dict[k] = v 24 | 25 | args = args if args != None else return_dict 26 | return argparse.Namespace(**args) 27 | -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from .resample import UpSample1d, DownSample1d 6 | 7 | 8 | class Activation1d(nn.Module): 9 | def __init__(self, 10 | activation, 11 | up_ratio: int = 2, 12 | down_ratio: int = 2, 13 | up_kernel_size: int = 12, 14 | down_kernel_size: int = 12): 15 | super().__init__() 16 | self.up_ratio = up_ratio 17 | self.down_ratio = down_ratio 18 | self.act = activation 19 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 20 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 21 | 22 | # x: [B,C,T] 23 | def forward(self, x): 24 | x = self.upsample(x) 25 | x = self.act(x) 26 | x = self.downsample(x) 27 | 28 | return x -------------------------------------------------------------------------------- /configs/text_to_audio/bigvgan_args.yaml: -------------------------------------------------------------------------------- 1 | resblock: '1' 2 | num_gpus: 0 3 | batch_size: 64 4 | num_mels: 80 5 | learning_rate: 0.0001 6 | adam_b1: 0.8 7 | adam_b2: 0.99 8 | lr_decay: 0.999 9 | seed: 1234 10 | upsample_rates: 11 | - 4 12 | - 4 13 | - 2 14 | - 2 15 | - 2 16 | - 2 17 | upsample_kernel_sizes: 18 | - 8 19 | - 8 20 | - 4 21 | - 4 22 | - 4 23 | - 4 24 | upsample_initial_channel: 1536 25 | resblock_kernel_sizes: 26 | - 3 27 | - 7 28 | - 11 29 | resblock_dilation_sizes: 30 | - - 1 31 | - 3 32 | - 5 33 | - - 1 34 | - 3 35 | - 5 36 | - - 1 37 | - 3 38 | - 5 39 | activation: snakebeta 40 | snake_logscale: true 41 | resolutions: 42 | - - 1024 43 | - 120 44 | - 600 45 | - - 2048 46 | - 240 47 | - 1200 48 | - - 512 49 | - 50 50 | - 240 51 | mpd_reshapes: 52 | - 2 53 | - 3 54 | - 5 55 | - 7 56 | - 11 57 | use_spectral_norm: false 58 | discriminator_channel_mult: 1 59 | num_workers: 4 60 | dist_config: 61 | dist_backend: nccl 62 | dist_url: tcp://localhost:54341 63 | world_size: 1 64 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Text-to-Audio 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tflite filter=lfs diff=lfs merge=lfs -text 29 | *.tgz filter=lfs diff=lfs merge=lfs -text 30 | *.wasm filter=lfs diff=lfs merge=lfs -text 31 | *.xz filter=lfs diff=lfs merge=lfs -text 32 | *.zip filter=lfs diff=lfs merge=lfs -text 33 | *.zst filter=lfs diff=lfs merge=lfs -text 34 | *tfevents* filter=lfs diff=lfs merge=lfs -text 35 | -------------------------------------------------------------------------------- /configs/train/vae.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | ddconfig: 8 | double_z: true 9 | z_channels: 4 10 | resolution: 624 11 | in_channels: 1 12 | out_ch: 1 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 2 17 | - 2 18 | - 4 19 | num_res_blocks: 2 20 | attn_resolutions: 21 | - 78 22 | - 156 23 | dropout: 0.0 24 | lossconfig: 25 | target: ldm.modules.losses_audio.contperceptual.LPAPSWithDiscriminator 26 | params: 27 | disc_start: 50001 28 | kl_weight: 1.0e-06 29 | perceptual_weight: 0.0 30 | disc_weight: 0.5 31 | disc_in_channels: 1 32 | disc_conditional: false 33 | 34 | lightning: 35 | callbacks: 36 | image_logger: 37 | target: main.AudioLogger 38 | params: 39 | sample_rate: 16000 40 | for_specs: true 41 | increase_log_steps: false 42 | batch_frequency: 5000 43 | max_images: 8 44 | melvmin: -5 45 | melvmax: 1.5 46 | vocoder_cfg: 47 | target: vocoder.bigvgan.models.VocoderBigVGAN 48 | params: 49 | ckpt_vocoder: useful_ckpts/bigvnat 50 | trainer: 51 | strategy: ddp 52 | gpus: 0,1,2,3,4,5,6,7 53 | 54 | 55 | data: 56 | target: main.SpectrogramDataModuleFromConfig 57 | params: 58 | batch_size: 4 59 | num_workers: 16 60 | spec_dir_path: data 61 | 
spec_crop_len: 624 62 | drop: 0.1 63 | train: 64 | target: ldm.data.joinaudiodataset_624.JoinSpecsTrain 65 | params: 66 | specs_dataset_cfg: null 67 | validation: 68 | target: ldm.data.joinaudiodataset_624.JoinSpecsValidation 69 | params: 70 | specs_dataset_cfg: null 71 | 72 | -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from .filter import LowPassFilter1d 7 | from .filter import kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, 20 | half_width=0.6 / ratio, 21 | kernel_size=self.kernel_size) 22 | self.register_buffer("filter", filter) 23 | 24 | # x: [B, C, T] 25 | def forward(self, x): 26 | _, C, _ = x.shape 27 | 28 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 29 | x = self.ratio * F.conv_transpose1d( 30 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 31 | x = x[..., self.pad_left:-self.pad_right] 32 | 33 | return x 34 | 35 | 36 | class DownSample1d(nn.Module): 37 | def __init__(self, ratio=2, kernel_size=None): 38 | super().__init__() 39 | self.ratio = ratio 40 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 41 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, 42 | half_width=0.6 / ratio, 43 | stride=ratio, 44 | kernel_size=self.kernel_size) 45 | 46 | def forward(self, x): 47 | xx = self.lowpass(x) 48 | 49 | return xx -------------------------------------------------------------------------------- /configs/text_to_audio/txt2audio_args.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-05 3 | target: ldm.models.diffusion.ddpm_audio.LatentDiffusion_audio 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 # unused 13 | mel_dim: 10 # 80 // 2^3 14 | mel_length: 78 # 624 // 2^3 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 19 | scale_by_std: True 20 | use_ema: False 21 | 22 | scheduler_config: # 10000 warmup steps 23 | target: ldm.lr_scheduler.LambdaLinearScheduler 24 | params: 25 | warm_up_steps: [10000] 26 | cycle_lengths: [10000000000000] 27 | f_start: [1.e-6] 28 | f_max: [1.] 29 | f_min: [ 1.] 
30 | 31 | unet_config: 32 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 33 | params: 34 | image_size: 32 # ununsed 35 | in_channels: 4 36 | out_channels: 4 37 | model_channels: 320 38 | attention_resolutions: 39 | - 1 40 | - 2 41 | num_res_blocks: 2 42 | channel_mult: # num_down = len(ch_mult)-1 43 | - 1 44 | - 2 45 | num_heads: 8 46 | use_spatial_transformer: true 47 | transformer_depth: 1 48 | context_dim: 1024 49 | use_checkpoint: true 50 | legacy: False 51 | 52 | first_stage_config: 53 | target: ldm.models.autoencoder.AutoencoderKL 54 | params: 55 | embed_dim: 4 56 | monitor: val/rec_loss 57 | ckpt_path: 58 | ddconfig: 59 | double_z: true 60 | z_channels: 4 61 | resolution: 624 62 | in_channels: 1 63 | out_ch: 1 64 | ch: 128 65 | ch_mult: [ 1, 2, 2, 4 ] # num_down = len(ch_mult)-1 66 | num_res_blocks: 2 67 | attn_resolutions: [78, 156] 68 | dropout: 0.0 69 | lossconfig: 70 | target: torch.nn.Identity 71 | 72 | cond_stage_config: 73 | target: ldm.modules.encoders.modules.FrozenCLAPEmbedder 74 | params: 75 | weights_path: # useful_ckpts/CLAP/CLAP_weights_2022.pth 76 | 77 | ckpt_path: useful_ckpts/maa1_caps.ckpt 78 | 79 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 
67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /wav_evaluation/models/clap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from transformers import AutoModel 6 | from .audio import get_audio_encoder 7 | 8 | class Projection(nn.Module): 9 | def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None: 10 | super().__init__() 11 | self.linear1 = nn.Linear(d_in, d_out, bias=False) 12 | self.linear2 = nn.Linear(d_out, d_out, bias=False) 13 | self.layer_norm = nn.LayerNorm(d_out) 14 | self.drop = nn.Dropout(p) 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | embed1 = self.linear1(x) 18 | embed2 = self.drop(self.linear2(F.gelu(embed1))) 19 | embeds = self.layer_norm(embed1 + embed2) 20 | return embeds 21 | 22 | class AudioEncoder(nn.Module): 23 | def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int, 24 | hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None: 25 | super().__init__() 26 | 27 | audio_encoder = get_audio_encoder(audioenc_name) 28 | 29 | self.base = audio_encoder( 30 | sample_rate, window_size, 31 | hop_size, mel_bins, fmin, fmax, 32 | classes_num, d_in) 33 | 34 | self.projection = Projection(d_in, d_out) 35 | 36 | def forward(self, x): 37 | out_dict = self.base(x) 38 | audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output'] 39 | projected_vec = self.projection(audio_features) 40 | return projected_vec, audio_classification_output 41 | 42 | class TextEncoder(nn.Module): 43 | def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None: 44 | super().__init__() 45 | self.base = AutoModel.from_pretrained(text_model) 46 | 47 | self.projection = Projection(transformer_embed_dim, d_out) 48 | 49 | def forward(self, x): 50 | out = self.base(**x)[0] 51 | out = out[:, 0, :] # get CLS token output 52 | projected_vec = self.projection(out) 53 | return projected_vec 54 | 55 | class CLAP(nn.Module): 56 | def __init__(self, 57 | # audio 58 | audioenc_name: str, 59 | sample_rate: int, 60 | window_size: int, 61 | hop_size: int, 62 | mel_bins: int, 63 | fmin: int, 64 | fmax: int, 65 | classes_num: int, 66 | out_emb: int, 67 | # text 68 | text_model: str, 69 | transformer_embed_dim: int, 70 | # common 71 | d_proj: int, 72 | ): 73 | super().__init__() 74 | 75 | 76 | self.audio_encoder = AudioEncoder( 77 | audioenc_name, out_emb, d_proj, 78 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num) 79 | 80 | self.caption_encoder = TextEncoder( 81 | d_proj, text_model, transformer_embed_dim 82 | ) 83 | 84 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 85 | 86 | def forward(self, audio, text): 87 | audio_embed, _ = self.audio_encoder(audio) 88 | caption_embed = self.caption_encoder(text) 89 | 90 | return caption_embed, audio_embed, self.logit_scale.exp() 
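A minimal usage sketch (not one of the repository files): the CLAP model defined directly above is normally built through wav_evaluation/models/CLAPWrapper.py, but it can also be instantiated straight from the hyperparameters in useful_ckpts/CLAP/config.yml shown earlier. The YAML keys sampling_rate and num_classes are mapped explicitly to the constructor arguments sample_rate and classes_num. Running from the repository root is assumed, and the key layout of CLAP_weights_2022.pth is an assumption, so weight loading is left commented out.

import yaml
import torch
from wav_evaluation.models.clap import CLAP

with open('useful_ckpts/CLAP/config.yml') as f:
    cfg = yaml.safe_load(f)

model = CLAP(
    audioenc_name=cfg['audioenc_name'],              # 'Cnn14' audio backbone
    sample_rate=cfg['sampling_rate'],                # 44100
    window_size=cfg['window_size'],
    hop_size=cfg['hop_size'],
    mel_bins=cfg['mel_bins'],
    fmin=cfg['fmin'],
    fmax=cfg['fmax'],
    classes_num=cfg['num_classes'],                  # 527 AudioSet classes
    out_emb=cfg['out_emb'],                          # 2048-d audio embedding before projection
    text_model=cfg['text_model'],                    # 'bert-base-uncased', downloaded via transformers
    transformer_embed_dim=cfg['transformer_embed_dim'],
    d_proj=cfg['d_proj'],                            # shared 1024-d projection space
)
model.eval()
# Weight loading is handled by CLAPWrapper in this repo; something like the line
# below may work, but the checkpoint's key layout is an assumption:
# model.load_state_dict(torch.load('useful_ckpts/CLAP/CLAP_weights_2022.pth', map_location='cpu')['model'], strict=False)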
-------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/data/joinaudiodataset_624.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import torch 4 | import logging 5 | import pandas as pd 6 | import glob 7 | logger = logging.getLogger(f'main.{__name__}') 8 | 9 | sys.path.insert(0, '.') # nopep8 10 | 11 | class JoinManifestSpecs(torch.utils.data.Dataset): 12 | def __init__(self, split, spec_dir_path, mel_num=None, spec_crop_len=None,drop=0,**kwargs): 13 | super().__init__() 14 | self.split = split 15 | self.batch_max_length = spec_crop_len 16 | self.batch_min_length = 50 17 | self.mel_num = mel_num 18 | self.drop = drop 19 | manifest_files = [] 20 | for dir_path in spec_dir_path.split(','): 21 | manifest_files += glob.glob(f'{dir_path}/**/*.tsv',recursive=True) 22 | df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files] 23 | df = pd.concat(df_list,ignore_index=True) 24 | 25 | if split == 'train': 26 | self.dataset = df.iloc[100:] 27 | elif split == 'valid' or split == 'val': 28 | self.dataset = df.iloc[:100] 29 | elif split == 'test': 30 | df = self.add_name_num(df) 31 | self.dataset = df 32 | else: 33 | raise ValueError(f'Unknown split {split}') 34 | self.dataset.reset_index(inplace=True) 35 | print('dataset len:', len(self.dataset)) 36 | 37 | def add_name_num(self,df): 38 | """each file may have different caption, we add num to filename to identify each audio-caption pair""" 39 | name_count_dict = {} 40 | change = [] 41 | for t in df.itertuples(): 42 | name = getattr(t,'name') 43 | if name in name_count_dict: 44 | name_count_dict[name] += 1 45 | else: 46 | name_count_dict[name] = 0 47 | change.append((t[0],name_count_dict[name])) 48 | for t in change: 49 | df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}' 50 | return df 51 | 52 | def __getitem__(self, idx): 53 | data = self.dataset.iloc[idx] 54 | item = {} 55 | try: 56 | spec = np.load(data['mel_path']) # mel spec [80, 624] 57 | except: 58 | mel_path = data['mel_path'] 59 | print(f'corrupted:{mel_path}') 60 | spec = np.zeros((self.mel_num,self.batch_max_length)).astype(np.float32) 61 | 62 | if spec.shape[1] < self.batch_max_length: 63 | spec = np.tile(spec,reps=(self.batch_max_length//spec.shape[1])+1) 64 | 65 | item['image'] = spec[:,:self.batch_max_length] 66 | p = np.random.uniform(0,1) 67 | if p > self.drop: 68 | item["caption"] = data['caption'] 69 | else: 70 | item["caption"] = "" 71 | if self.split == 'test': 72 | item['f_name'] = data['name'] 73 | return item 74 | 75 | def __len__(self): 76 | return len(self.dataset) 77 | 78 | 79 | class JoinSpecsTrain(JoinManifestSpecs): 80 | def __init__(self, specs_dataset_cfg): 81 | super().__init__('train', **specs_dataset_cfg) 82 | 83 | class JoinSpecsValidation(JoinManifestSpecs): 84 | def __init__(self, specs_dataset_cfg): 85 | super().__init__('valid', **specs_dataset_cfg) 86 | 87 | class JoinSpecsTest(JoinManifestSpecs): 88 | def __init__(self, specs_dataset_cfg): 89 | super().__init__('test', **specs_dataset_cfg) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/clap.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from transformers import AutoModel 6 | from .audio import get_audio_encoder 7 | 8 | class Projection(nn.Module): 9 | def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None: 10 | super().__init__() 11 | self.linear1 = nn.Linear(d_in, d_out, bias=False) 12 | self.linear2 = nn.Linear(d_out, d_out, bias=False) 13 | self.layer_norm = nn.LayerNorm(d_out) 14 | self.drop = nn.Dropout(p) 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | embed1 = self.linear1(x) 18 | embed2 = self.drop(self.linear2(F.gelu(embed1))) 19 | embeds = self.layer_norm(embed1 + embed2) 20 | return embeds 21 | 22 | class AudioEncoder(nn.Module): 23 | def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int, 24 | hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None: 25 | super().__init__() 26 | 27 | audio_encoder = get_audio_encoder(audioenc_name) 28 | 29 | self.base = audio_encoder( 30 | sample_rate, window_size, 31 | hop_size, mel_bins, fmin, fmax, 32 | classes_num, d_in) 33 | 34 | self.projection = Projection(d_in, d_out) 35 | 36 | def forward(self, x): 37 | out_dict = self.base(x) 38 | audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output'] 39 | projected_vec = self.projection(audio_features) 40 | return projected_vec, audio_classification_output 41 | 42 | class TextEncoder(nn.Module): 43 | def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None: 44 | super().__init__() 45 | self.base = AutoModel.from_pretrained(text_model) 46 | self.projection = Projection(transformer_embed_dim, d_out) 47 | 48 | def forward(self, x): 49 | out = self.base(**x)[0] 50 | out = out[:, 0, :] # get CLS token output 51 | projected_vec = self.projection(out) 52 | return projected_vec 53 | 54 | class CLAP(nn.Module): 55 | def __init__(self, 56 | # audio 57 | audioenc_name: str, 58 | sample_rate: int, 59 | window_size: int, 60 | hop_size: int, 61 | mel_bins: int, 62 | fmin: int, 63 | fmax: int, 64 | classes_num: int, 65 | out_emb: int, 66 | # text 67 | text_model: str, 68 | transformer_embed_dim: int, 69 | # common 70 | d_proj: int, 71 | ): 72 | super().__init__() 73 | 74 | 75 | self.audio_encoder = AudioEncoder( 76 | audioenc_name, out_emb, d_proj, 77 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num) 78 | 79 | self.caption_encoder = TextEncoder( 80 | d_proj, text_model, transformer_embed_dim 81 | ) 82 | 83 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 84 | 85 | def forward(self, audio, text): 86 | audio_embed, _ = self.audio_encoder(audio) 87 | caption_embed = self.caption_encoder(text) 88 | 89 | return caption_embed, audio_embed, self.logit_scale.exp() -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | if 'sinc' in dir(torch): 10 | sinc = torch.sinc 11 | else: 12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 13 | # https://adefossez.github.io/julius/julius/core.html 14 | # LICENSE is in incl_licenses directory. 15 | def sinc(x: torch.Tensor): 16 | """ 17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 19 | """ 20 | return torch.where(x == 0, 21 | torch.tensor(1., device=x.device, dtype=x.dtype), 22 | torch.sin(math.pi * x) / math.pi / x) 23 | 24 | 25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 26 | # https://adefossez.github.io/julius/julius/lowpass.html 27 | # LICENSE is in incl_licenses directory. 28 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 29 | even = (kernel_size % 2 == 0) 30 | half_size = kernel_size // 2 31 | 32 | #For kaiser window 33 | delta_f = 4 * half_width 34 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 35 | if A > 50.: 36 | beta = 0.1102 * (A - 8.7) 37 | elif A >= 21.: 38 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) 39 | else: 40 | beta = 0. 41 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 42 | 43 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 44 | if even: 45 | time = (torch.arange(-half_size, half_size) + 0.5) 46 | else: 47 | time = torch.arange(kernel_size) - half_size 48 | if cutoff == 0: 49 | filter_ = torch.zeros_like(time) 50 | else: 51 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 52 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 53 | # of the constant component in the input signal. 54 | filter_ /= filter_.sum() 55 | filter = filter_.view(1, 1, kernel_size) 56 | 57 | return filter 58 | 59 | 60 | class LowPassFilter1d(nn.Module): 61 | def __init__(self, 62 | cutoff=0.5, 63 | half_width=0.6, 64 | stride: int = 1, 65 | padding: bool = True, 66 | padding_mode: str = 'replicate', 67 | kernel_size: int = 12): 68 | # kernel_size should be even number for stylegan3 setup, 69 | # in this implementation, odd number is also possible. 
70 | super().__init__() 71 | if cutoff < -0.: 72 | raise ValueError("Minimum cutoff must be larger than zero.") 73 | if cutoff > 0.5: 74 | raise ValueError("A cutoff above 0.5 does not make sense.") 75 | self.kernel_size = kernel_size 76 | self.even = (kernel_size % 2 == 0) 77 | self.pad_left = kernel_size // 2 - int(self.even) 78 | self.pad_right = kernel_size // 2 79 | self.stride = stride 80 | self.padding = padding 81 | self.padding_mode = padding_mode 82 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 83 | self.register_buffer("filter", filter) 84 | 85 | #input [B, C, T] 86 | def forward(self, x): 87 | _, C, _ = x.shape 88 | 89 | if self.padding: 90 | x = F.pad(x, (self.pad_left, self.pad_right), 91 | mode=self.padding_mode) 92 | out = F.conv1d(x, self.filter.expand(C, -1, -1), 93 | stride=self.stride, groups=C) 94 | 95 | return out -------------------------------------------------------------------------------- /configs/train/diffusion.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-05 3 | target: ldm.models.diffusion.ddpm_audio.LatentDiffusion_audio 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 # unused 13 | mel_dim: 10 # 80 // 2^3 14 | mel_length: 78 # 624 // 2^3 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 19 | scale_by_std: True 20 | use_ema: False 21 | 22 | scheduler_config: # 10000 warmup steps 23 | target: ldm.lr_scheduler.LambdaLinearScheduler 24 | params: 25 | warm_up_steps: [10000] 26 | cycle_lengths: [10000000000000] 27 | f_start: [1.e-6] 28 | f_max: [1.] 29 | f_min: [ 1.] 
30 | 31 | unet_config: 32 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 33 | params: 34 | image_size: 32 # ununsed 35 | in_channels: 4 36 | out_channels: 4 37 | model_channels: 320 38 | attention_resolutions: 39 | - 1 40 | - 2 41 | num_res_blocks: 2 42 | channel_mult: # num_down = len(ch_mult)-1 43 | - 1 44 | - 2 45 | num_heads: 8 46 | use_spatial_transformer: true 47 | transformer_depth: 1 48 | context_dim: 1024 49 | use_checkpoint: true 50 | legacy: False 51 | 52 | first_stage_config: 53 | target: ldm.models.autoencoder.AutoencoderKL 54 | params: 55 | embed_dim: 4 56 | monitor: val/rec_loss 57 | ckpt_path: save your pretrained vae path here 58 | ddconfig: 59 | double_z: true 60 | z_channels: 4 61 | resolution: 624 62 | in_channels: 1 63 | out_ch: 1 64 | ch: 128 65 | ch_mult: [ 1, 2, 2, 4 ] # num_down = len(ch_mult)-1 66 | num_res_blocks: 2 67 | attn_resolutions: [78, 156] 68 | dropout: 0.0 69 | lossconfig: 70 | target: torch.nn.Identity 71 | 72 | cond_stage_config: 73 | target: ldm.modules.encoders.modules.FrozenCLAPEmbedder 74 | params: 75 | weights_path: useful_ckpts/CLAP/CLAP_weights_2022.pth 76 | 77 | 78 | data: 79 | target: main.SpectrogramDataModuleFromConfig 80 | params: 81 | batch_size: 4 82 | num_workers: 16 83 | spec_dir_path: data 84 | spec_crop_len: 624 85 | drop: 0.1 86 | train: 87 | target: ldm.data.joinaudiodataset_624.JoinSpecsTrain 88 | params: 89 | specs_dataset_cfg: null 90 | validation: 91 | target: ldm.data.joinaudiodataset_624.JoinSpecsValidation 92 | params: 93 | specs_dataset_cfg: null 94 | 95 | lightning: 96 | callbacks: 97 | image_logger: 98 | target: main.AudioLogger 99 | params: 100 | sample_rate: 16000 101 | for_specs: true 102 | increase_log_steps: false 103 | batch_frequency: 5000 104 | max_images: 8 105 | melvmin: -5 106 | melvmax: 1.5 107 | vocoder_cfg: 108 | target: vocoder.bigvgan.models.VocoderBigVGAN 109 | params: 110 | ckpt_vocoder: useful_ckpts/bigvnat 111 | trainer: 112 | benchmark: true 113 | gradient_clip_val: 1.0 114 | strategy: ddp 115 | gpus: 0,1,2,3 116 | modelcheckpoint: 117 | params: 118 | monitor: epoch 119 | mode: max 120 | save_top_k: 8 121 | every_n_epochs: 8 122 | 123 | test_dataset: 124 | target: ldm.data.tsvdataset.TSVDataset 125 | params: 126 | tsv_path: data/audiocaps_test.tsv 127 | spec_crop_len: 624 128 | 129 | -------------------------------------------------------------------------------- /wav_evaluation/cal_clap_score.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import sys 3 | import os 4 | directory = pathlib.Path(os.getcwd()) 5 | sys.path.append(str(directory)) 6 | import torch 7 | import numpy as np 8 | from wav_evaluation.models.CLAPWrapper import CLAPWrapper 9 | import argparse 10 | from tqdm import tqdm 11 | import pandas as pd 12 | import json 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--tsv_path',type=str,default='') 18 | parser.add_argument('--wavsdir',type=str) 19 | parser.add_argument('--mean',type=bool,default=True) 20 | parser.add_argument('--ckpt_path', default="useful_ckpts/CLAP") 21 | args = parser.parse_args() 22 | return args 23 | 24 | def add_audio_path(df): 25 | df['audio_path'] = df.apply(lambda x:x['mel_path'].replace('.npy','.wav'),axis=1) 26 | return df 27 | 28 | def build_tsv_from_wavs(root_dir): 29 | with open('ldm/data/audiocaps_fn2cap.json','r') as f: 30 | fn2cap = json.load(f) 31 | if os.path.exists(os.path.join(root_dir,'fake_class')): 32 | wavs_root = 
os.path.join(root_dir,'fake_class') 33 | else: 34 | wavs_root = root_dir 35 | wavfiles = os.listdir(wavs_root) 36 | wavfiles = list(filter(lambda x:x.endswith('.wav') and x[-6:-4]!='gt',wavfiles)) 37 | print(len(wavfiles)) 38 | dict_list = [] 39 | for wavfile in wavfiles: 40 | tmpd = {'audio_path':os.path.join(wavs_root,wavfile)} 41 | key = wavfile.rsplit('_sample')[0] + wavfile.rsplit('_sample')[1][:2] 42 | tmpd['caption'] = fn2cap[key] 43 | dict_list.append(tmpd) 44 | df = pd.DataFrame.from_dict(dict_list) 45 | tsv_path = f'{os.path.basename(root_dir)}.tsv' 46 | tsv_path = os.path.join(wavs_root,tsv_path) 47 | df.to_csv(tsv_path,sep='\t',index=False) 48 | return tsv_path 49 | 50 | def cal_score_by_tsv(tsv_path,clap_model): # the CLAP score computed on the ground-truth audio of the AudioCaps validation set is 0.479077 51 | df = pd.read_csv(tsv_path,sep='\t') 52 | clap_scores = [] 53 | if not ('audio_path' in df.columns): 54 | df = add_audio_path(df) 55 | caption_list,audio_list = [],[] 56 | with torch.no_grad(): 57 | for idx,t in enumerate(tqdm(df.itertuples()),start=1): 58 | caption_list.append(getattr(t,'caption')) 59 | audio_list.append(getattr(t,'audio_path')) 60 | if idx % 20 == 0: 61 | text_embeddings = clap_model.get_text_embeddings(caption_list)# embeddings are already normalized 62 | audio_embeddings = clap_model.get_audio_embeddings(audio_list, resample=True)# this step is slow: it loads the audio and resamples it to 44100 Hz 63 | score_mat = clap_model.compute_similarity(audio_embeddings, text_embeddings,use_logit_scale=False) 64 | score = score_mat.diagonal() 65 | clap_scores.append(score.cpu().numpy()) 66 | audio_list = [] 67 | caption_list = [] 68 | return np.mean(np.array(clap_scores).flatten()) 69 | 70 | def add_clap_score_to_tsv(tsv_path,clap_model): 71 | df = pd.read_csv(tsv_path,sep='\t') 72 | clap_scores_dict = {} 73 | with torch.no_grad(): 74 | for idx,t in enumerate(tqdm(df.itertuples()),start=1): 75 | text_embeddings = clap_model.get_text_embeddings([getattr(t,'caption')])# embeddings are already normalized 76 | audio_embeddings = clap_model.get_audio_embeddings([getattr(t,'audio_path')], resample=True) 77 | score = clap_model.compute_similarity(audio_embeddings, text_embeddings,use_logit_scale=False) 78 | clap_scores_dict[idx] = score.cpu().numpy() 79 | df['clap_score'] = clap_scores_dict 80 | df.to_csv(tsv_path[:-4]+'_clap.tsv',sep='\t',index=False) 81 | 82 | 83 | if __name__ == '__main__': 84 | args = parse_args() 85 | if args.tsv_path: 86 | tsv_path = args.tsv_path 87 | else: 88 | tsv_path = os.path.join(args.wavsdir,'result.tsv') 89 | if not os.path.exists(tsv_path): 90 | print("result tsv does not exist, building it") 91 | tsv_path = build_tsv_from_wavs(args.wavsdir) 92 | clap_model = CLAPWrapper(os.path.join(args.ckpt_path,'CLAP_weights_2022.pth'),os.path.join(args.ckpt_path,'config.yml'), use_cuda=True) 93 | clap_score = cal_score_by_tsv(tsv_path,clap_model) 94 | out = args.wavsdir if args.wavsdir else args.tsv_path 95 | print(f"clap_score for {out} is: {clap_score}") 96 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /gen_wav.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from vocoder.bigvgan.models import VocoderBigVGAN 4 | from ldm.models.diffusion.ddim import DDIMSampler 5 | from ldm.util import instantiate_from_config 6 | from omegaconf import OmegaConf 7 | import argparse 8 | 
import soundfile 9 | device = 'cuda' # change to 'cpu' if you do not have a GPU; generating on the CPU is very slow. 10 | SAMPLE_RATE = 16000 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument( 16 | "--prompt", 17 | type=str, 18 | default="a bird chirps", 19 | help="the prompt to generate audio" 20 | ) 21 | 22 | parser.add_argument( 23 | "--ddim_steps", 24 | type=int, 25 | default=100, 26 | help="number of ddim sampling steps", 27 | ) 28 | 29 | parser.add_argument( 30 | "--duration", 31 | type=int, 32 | default=10, 33 | help="audio duration, seconds", 34 | ) 35 | 36 | parser.add_argument( 37 | "--n_samples", 38 | type=int, 39 | default=1, 40 | help="how many samples to produce for the given prompt", 41 | ) 42 | 43 | parser.add_argument( 44 | "--scale", 45 | type=float, 46 | default=3.0, # if it's 1, only the condition is taken into consideration 47 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 48 | ) 49 | 50 | parser.add_argument( 51 | "--save_name", 52 | type=str, 53 | default='test', 54 | help="audio path name for saving", 55 | ) 56 | return parser.parse_args() 57 | 58 | def initialize_model(config, ckpt,device=device): 59 | config = OmegaConf.load(config) 60 | model = instantiate_from_config(config.model) 61 | model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False) 62 | 63 | model = model.to(device) 64 | model.cond_stage_model.to(model.device) 65 | model.cond_stage_model.device = model.device 66 | print(model.device,device,model.cond_stage_model.device) 67 | sampler = DDIMSampler(model) 68 | 69 | return sampler 70 | 71 | def dur_to_size(duration): 72 | latent_width = int(duration * 7.8) 73 | if latent_width % 4 != 0: 74 | latent_width = (latent_width // 4 + 1) * 4 75 | return latent_width 76 | 77 | def gen_wav(sampler,vocoder,prompt,ddim_steps,scale,duration,n_samples): 78 | latent_width = dur_to_size(duration) 79 | start_code = torch.randn(n_samples, sampler.model.first_stage_model.embed_dim, 10, latent_width).to(device=device, dtype=torch.float32) 80 | 81 | uc = None 82 | if scale != 1.0: 83 | uc = sampler.model.get_learned_conditioning(n_samples * [""]) 84 | c = sampler.model.get_learned_conditioning(n_samples * [prompt]) 85 | shape = [sampler.model.first_stage_model.embed_dim, 10, latent_width] # 10 is latent height 86 | samples_ddim, _ = sampler.sample(S=ddim_steps, 87 | conditioning=c, 88 | batch_size=n_samples, 89 | shape=shape, 90 | verbose=False, 91 | unconditional_guidance_scale=scale, 92 | unconditional_conditioning=uc, 93 | x_T=start_code) 94 | 95 | x_samples_ddim = sampler.model.decode_first_stage(samples_ddim) 96 | 97 | wav_list = [] 98 | for idx,spec in enumerate(x_samples_ddim): 99 | wav = vocoder.vocode(spec) 100 | if len(wav) < SAMPLE_RATE * duration: 101 | wav = np.pad(wav,SAMPLE_RATE*duration-len(wav),mode='constant',constant_values=0) 102 | wav_list.append(wav) 103 | return wav_list 104 | 105 | if __name__ == '__main__': 106 | args = parse_args() 107 | sampler = initialize_model('configs/text_to_audio/txt2audio_args.yaml', 'useful_ckpts/maa1_full.ckpt') 108 | vocoder = VocoderBigVGAN('useful_ckpts/bigvgan',device=device) 109 | print("Generating audio, this may take a long time depending on your GPU performance") 110 | wav_list = gen_wav(sampler,vocoder,prompt=args.prompt,ddim_steps=args.ddim_steps,scale=args.scale,duration=args.duration,n_samples=args.n_samples) 111 | for idx,wav in enumerate(wav_list): 112 | 
soundfile.write(f'{args.save_name}_{idx}.wav',wav,samplerate=SAMPLE_RATE) 113 | print(f"audios are saved in {args.save_name}_i.wav") 114 | -------------------------------------------------------------------------------- /vocoder/bigvgan/activations.py: -------------------------------------------------------------------------------- 1 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | from torch import nn, sin, pow 6 | from torch.nn import Parameter 7 | 8 | 9 | class Snake(nn.Module): 10 | ''' 11 | Implementation of a sine-based periodic activation function 12 | Shape: 13 | - Input: (B, C, T) 14 | - Output: (B, C, T), same shape as the input 15 | Parameters: 16 | - alpha - trainable parameter 17 | References: 18 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 19 | https://arxiv.org/abs/2006.08195 20 | Examples: 21 | >>> a1 = snake(256) 22 | >>> x = torch.randn(256) 23 | >>> x = a1(x) 24 | ''' 25 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 26 | ''' 27 | Initialization. 28 | INPUT: 29 | - in_features: shape of the input 30 | - alpha: trainable parameter 31 | alpha is initialized to 1 by default, higher values = higher-frequency. 32 | alpha will be trained along with the rest of your model. 33 | ''' 34 | super(Snake, self).__init__() 35 | self.in_features = in_features 36 | 37 | # initialize alpha 38 | self.alpha_logscale = alpha_logscale 39 | if self.alpha_logscale: # log scale alphas initialized to zeros 40 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 41 | else: # linear scale alphas initialized to ones 42 | self.alpha = Parameter(torch.ones(in_features) * alpha) 43 | 44 | self.alpha.requires_grad = alpha_trainable 45 | 46 | self.no_div_by_zero = 0.000000001 47 | 48 | def forward(self, x): 49 | ''' 50 | Forward pass of the function. 51 | Applies the function to the input elementwise. 52 | Snake ∶= x + 1/a * sin^2 (xa) 53 | ''' 54 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 55 | if self.alpha_logscale: 56 | alpha = torch.exp(alpha) 57 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 58 | 59 | return x 60 | 61 | 62 | class SnakeBeta(nn.Module): 63 | ''' 64 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 65 | Shape: 66 | - Input: (B, C, T) 67 | - Output: (B, C, T), same shape as the input 68 | Parameters: 69 | - alpha - trainable parameter that controls frequency 70 | - beta - trainable parameter that controls magnitude 71 | References: 72 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 73 | https://arxiv.org/abs/2006.08195 74 | Examples: 75 | >>> a1 = snakebeta(256) 76 | >>> x = torch.randn(256) 77 | >>> x = a1(x) 78 | ''' 79 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 80 | ''' 81 | Initialization. 82 | INPUT: 83 | - in_features: shape of the input 84 | - alpha - trainable parameter that controls frequency 85 | - beta - trainable parameter that controls magnitude 86 | alpha is initialized to 1 by default, higher values = higher-frequency. 87 | beta is initialized to 1 by default, higher values = higher-magnitude. 88 | alpha will be trained along with the rest of your model. 
89 | ''' 90 | super(SnakeBeta, self).__init__() 91 | self.in_features = in_features 92 | 93 | # initialize alpha 94 | self.alpha_logscale = alpha_logscale 95 | if self.alpha_logscale: # log scale alphas initialized to zeros 96 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 97 | self.beta = Parameter(torch.zeros(in_features) * alpha) 98 | else: # linear scale alphas initialized to ones 99 | self.alpha = Parameter(torch.ones(in_features) * alpha) 100 | self.beta = Parameter(torch.ones(in_features) * alpha) 101 | 102 | self.alpha.requires_grad = alpha_trainable 103 | self.beta.requires_grad = alpha_trainable 104 | 105 | self.no_div_by_zero = 0.000000001 106 | 107 | def forward(self, x): 108 | ''' 109 | Forward pass of the function. 110 | Applies the function to the input elementwise. 111 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 112 | ''' 113 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 114 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 115 | if self.alpha_logscale: 116 | alpha = torch.exp(alpha) 117 | beta = torch.exp(beta) 118 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 119 | 120 | return x -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from inspect import isfunction 7 | from PIL import Image, ImageDraw, ImageFont 8 | import hashlib 9 | import requests 10 | import os 11 | 12 | URL_MAP = { 13 | 'vggishish_lpaps': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt', 14 | 'vggishish_mean_std_melspec_10s_22050hz': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt', 15 | 'melception': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt', 16 | } 17 | 18 | CKPT_MAP = { 19 | 'vggishish_lpaps': 'vggishish16.pt', 20 | 'vggishish_mean_std_melspec_10s_22050hz': 'train_means_stds_melspec_10s_22050hz.txt', 21 | 'melception': 'melception-21-05-10T09-28-40.pt', 22 | } 23 | 24 | MD5_MAP = { 25 | 'vggishish_lpaps': '197040c524a07ccacf7715d7080a80bd', 26 | 'vggishish_mean_std_melspec_10s_22050hz': 'f449c6fd0e248936c16f6d22492bb625', 27 | 'melception': 'a71a41041e945b457c7d3d814bbcf72d', 28 | } 29 | 30 | 31 | def download(url, local_path, chunk_size=1024): 32 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 33 | with requests.get(url, stream=True) as r: 34 | total_size = int(r.headers.get("content-length", 0)) 35 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 36 | with open(local_path, "wb") as f: 37 | for data in r.iter_content(chunk_size=chunk_size): 38 | if data: 39 | f.write(data) 40 | pbar.update(chunk_size) 41 | 42 | 43 | def md5_hash(path): 44 | with open(path, "rb") as f: 45 | content = f.read() 46 | return hashlib.md5(content).hexdigest() 47 | 48 | 49 | 50 | def log_txt_as_img(wh, xc, size=10): 51 | # wh a tuple of (width, height) 52 | # xc a list of captions to plot 53 | b = len(xc) 54 | txts = list() 55 | for bi in range(b): 56 | txt = Image.new("RGB", wh, color="white") 57 | draw = ImageDraw.Draw(txt) 58 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 59 | nc = int(40 * (wh[0] / 256)) 60 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 61 | 62 | try: 63 | 
draw.text((0, 0), lines, fill="black", font=font) 64 | except UnicodeEncodeError: 65 | print("Cant encode string for logging. Skipping.") 66 | 67 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 68 | txts.append(txt) 69 | txts = np.stack(txts) 70 | txts = torch.tensor(txts) 71 | return txts 72 | 73 | 74 | def ismap(x): 75 | if not isinstance(x, torch.Tensor): 76 | return False 77 | return (len(x.shape) == 4) and (x.shape[1] > 3) 78 | 79 | 80 | def isimage(x): 81 | if not isinstance(x,torch.Tensor): 82 | return False 83 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 84 | 85 | 86 | def exists(x): 87 | return x is not None 88 | 89 | 90 | def default(val, d): 91 | if exists(val): 92 | return val 93 | return d() if isfunction(d) else d 94 | 95 | 96 | def mean_flat(tensor): 97 | """ 98 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 99 | Take the mean over all non-batch dimensions. 100 | """ 101 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 102 | 103 | 104 | def count_params(model, verbose=False): 105 | total_params = sum(p.numel() for p in model.parameters()) 106 | if verbose: 107 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 108 | return total_params 109 | 110 | 111 | def instantiate_from_config(config,reload=False): 112 | if not "target" in config: 113 | if config == '__is_first_stage__': 114 | return None 115 | elif config == "__is_unconditional__": 116 | return None 117 | raise KeyError("Expected key `target` to instantiate.") 118 | return get_obj_from_str(config["target"],reload=reload)(**config.get("params", dict())) 119 | 120 | 121 | def get_obj_from_str(string, reload=False): 122 | module, cls = string.rsplit(".", 1) 123 | if reload: 124 | module_imp = importlib.import_module(module) 125 | importlib.reload(module_imp) 126 | return getattr(importlib.import_module(module, package=None), cls) 127 | 128 | def get_ckpt_path(name, root, check=False): 129 | assert name in URL_MAP 130 | path = os.path.join(root, CKPT_MAP[name]) 131 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 132 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 133 | download(URL_MAP[name], path) 134 | md5 = md5_hash(path) 135 | assert md5 == MD5_MAP[name], md5 136 | return path 137 | -------------------------------------------------------------------------------- /preprocess/NAT_mel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | import torch 7 | import torch.nn as nn 8 | 9 | MAX_WAV_VALUE = 32768.0 10 | 11 | 12 | def load_wav(full_path): 13 | sampling_rate, data = read(full_path) 14 | return data, sampling_rate 15 | 16 | 17 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 18 | return np.log10(np.clip(x, a_min=clip_val, a_max=None) * C) 19 | 20 | 21 | def dynamic_range_decompression(x, C=1): 22 | return np.exp(x) / C 23 | 24 | 25 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 26 | return torch.log10(torch.clamp(x, min=clip_val) * C) 27 | 28 | 29 | def dynamic_range_decompression_torch(x, C=1): 30 | return torch.exp(x) / C 31 | 32 | 33 | def spectral_normalize_torch(magnitudes): 34 | output = dynamic_range_compression_torch(magnitudes) 35 | return output 36 | 37 | 38 | def 
spectral_de_normalize_torch(magnitudes): 39 | output = dynamic_range_decompression_torch(magnitudes) 40 | return output 41 | 42 | class MelNet(nn.Module): 43 | def __init__(self,hparams,device='cpu') -> None: 44 | super().__init__() 45 | self.n_fft = hparams['fft_size'] 46 | self.num_mels = hparams['audio_num_mel_bins'] 47 | self.sampling_rate = hparams['audio_sample_rate'] 48 | self.hop_size = hparams['hop_size'] 49 | self.win_size = hparams['win_size'] 50 | self.fmin = hparams['fmin'] 51 | self.fmax = hparams['fmax'] 52 | self.device = device 53 | 54 | mel = librosa_mel_fn(self.sampling_rate, self.n_fft, self.num_mels, self.fmin, self.fmax) 55 | self.mel_basis = torch.from_numpy(mel).float().to(self.device) 56 | self.hann_window = torch.hann_window(self.win_size).to(self.device) 57 | 58 | def to(self,device,**kwagrs): 59 | super().to(device=device,**kwagrs) 60 | self.mel_basis = self.mel_basis.to(device) 61 | self.hann_window = self.hann_window.to(device) 62 | self.device = device 63 | 64 | def forward(self,y,center=False, complex=False): 65 | if isinstance(y,np.ndarray): 66 | y = torch.FloatTensor(y) 67 | if len(y.shape) == 1: 68 | y = y.unsqueeze(0) 69 | y = y.clamp(min=-1., max=1.).to(self.device) 70 | 71 | y = torch.nn.functional.pad(y.unsqueeze(1), [int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)], 72 | mode='reflect') 73 | y = y.squeeze(1) 74 | 75 | spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window, 76 | center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=complex) 77 | 78 | if not complex: 79 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 80 | spec = torch.matmul(self.mel_basis, spec) 81 | spec = spectral_normalize_torch(spec) 82 | else: 83 | B, C, T, _ = spec.shape 84 | spec = spec.transpose(1, 2) # [B, T, n_fft, 2] 85 | return spec 86 | 87 | ## below can be used in one gpu, but not ddp 88 | mel_basis = {} 89 | hann_window = {} 90 | 91 | 92 | def mel_spectrogram(y, hparams, center=False, complex=False): # y should be a tensor with shape (b,wav_len) 93 | # hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) 94 | # win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) 95 | # fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 96 | # fmax: 10000 # To be increased/reduced depending on data. 97 | # fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter 98 | # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, 99 | n_fft = hparams['fft_size'] 100 | num_mels = hparams['audio_num_mel_bins'] 101 | sampling_rate = hparams['audio_sample_rate'] 102 | hop_size = hparams['hop_size'] 103 | win_size = hparams['win_size'] 104 | fmin = hparams['fmin'] 105 | fmax = hparams['fmax'] 106 | if isinstance(y,np.ndarray): 107 | y = torch.FloatTensor(y) 108 | if len(y.shape) == 1: 109 | y = y.unsqueeze(0) 110 | y = y.clamp(min=-1., max=1.) 
111 | global mel_basis, hann_window 112 | if fmax not in mel_basis: 113 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 114 | mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 115 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 116 | 117 | y = torch.nn.functional.pad(y.unsqueeze(1), [int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)], 118 | mode='reflect') 119 | y = y.squeeze(1) 120 | 121 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 122 | center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=complex) 123 | 124 | if not complex: 125 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 126 | spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec) 127 | spec = spectral_normalize_torch(spec) 128 | else: 129 | B, C, T, _ = spec.shape 130 | spec = spec.transpose(1, 2) # [B, T, n_fft, 2] 131 | return spec 132 | -------------------------------------------------------------------------------- /gen_wavs_by_tsv.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import torch 3 | from tqdm import tqdm 4 | import pandas as pd 5 | import numpy as np 6 | from vocoder.bigvgan.models import VocoderBigVGAN 7 | from ldm.models.diffusion.ddim import DDIMSampler 8 | from ldm.util import instantiate_from_config 9 | from omegaconf import OmegaConf 10 | import argparse 11 | import soundfile 12 | device = 'cuda' # change to 'cpu‘ if you do not have gpu. generating with cpu is very slow. 13 | SAMPLE_RATE = 16000 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument( 19 | "--tsv_path", 20 | type=str, 21 | help="the tsv contains name and caption" 22 | ) 23 | 24 | parser.add_argument( 25 | "--save_dir", 26 | type=str, 27 | help="the directory contains wavs" 28 | ) 29 | 30 | parser.add_argument( 31 | "--ddim_steps", 32 | type=int, 33 | default=100, 34 | help="number of ddim sampling steps", 35 | ) 36 | 37 | parser.add_argument( 38 | "--duration", 39 | type=int, 40 | default=10, 41 | help="audio duration, seconds", 42 | ) 43 | 44 | parser.add_argument( 45 | "--n_samples", 46 | type=int, 47 | default=1, 48 | help="how many samples to produce for the given prompt", 49 | ) 50 | 51 | parser.add_argument( 52 | "--scale", 53 | type=float, 54 | default=3.0, # if it's 1, only condition is taken into consideration 55 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 56 | ) 57 | 58 | parser.add_argument( 59 | "--save_name", 60 | type=str, 61 | default='test', 62 | help="audio path name for saving", 63 | ) 64 | return parser.parse_args() 65 | 66 | def initialize_model(config, ckpt,device=device): 67 | config = OmegaConf.load(config) 68 | model = instantiate_from_config(config.model) 69 | model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False) 70 | 71 | model = model.to(device) 72 | model.cond_stage_model.to(model.device) 73 | model.cond_stage_model.device = model.device 74 | print(model.device,device,model.cond_stage_model.device) 75 | sampler = DDIMSampler(model) 76 | 77 | return sampler 78 | 79 | def dur_to_size(duration): 80 | latent_width = int(duration * 7.8) 81 | if latent_width % 4 != 0: 82 | latent_width = (latent_width // 4) * 4 83 | return latent_width 84 | 85 | def build_name2caption(tsv_path): 86 | df = pd.read_csv(tsv_path,sep='\t') 87 | name2cap = {} 88 
| name_count = {} 89 | for t in df.itertuples(): 90 | name = getattr(t,'name') 91 | caption = getattr(t,'caption') 92 | if name not in name_count: 93 | name_count[name] = 0 94 | else: 95 | name_count[name]+=1 96 | name2cap[name+f'_{name_count[name]}'] = caption 97 | 98 | return name2cap 99 | 100 | def gen_wav(sampler,vocoder,prompt,ddim_steps,scale,duration,n_samples): 101 | latent_width = dur_to_size(duration) 102 | start_code = torch.randn(n_samples, sampler.model.first_stage_model.embed_dim, 10, latent_width).to(device=device, dtype=torch.float32) 103 | 104 | uc = None 105 | if scale != 1.0: 106 | uc = sampler.model.get_learned_conditioning(n_samples * [""]) 107 | c = sampler.model.get_learned_conditioning(n_samples * [prompt]) 108 | shape = [sampler.model.first_stage_model.embed_dim, 10, latent_width] # 10 is latent height 109 | samples_ddim, _ = sampler.sample(S=ddim_steps, 110 | conditioning=c, 111 | batch_size=n_samples, 112 | shape=shape, 113 | verbose=False, 114 | unconditional_guidance_scale=scale, 115 | unconditional_conditioning=uc, 116 | x_T=start_code) 117 | 118 | x_samples_ddim = sampler.model.decode_first_stage(samples_ddim) 119 | 120 | wav_list = [] 121 | for idx,spec in enumerate(x_samples_ddim): 122 | wav = vocoder.vocode(spec) 123 | if len(wav) < SAMPLE_RATE * duration: 124 | wav = np.pad(wav,SAMPLE_RATE*duration-len(wav),mode='constant',constant_values=0) 125 | wav_list.append(wav) 126 | return wav_list 127 | 128 | if __name__ == '__main__': 129 | args = parse_args() 130 | save_dir = args.save_dir 131 | os.makedirs(save_dir,exist_ok=True) 132 | 133 | sampler = initialize_model('configs/text_to_audio/txt2audio_args.yaml', 'useful_ckpts/maa1_full.ckpt') 134 | vocoder = VocoderBigVGAN('useful_ckpts/bigvnat',device=device) 135 | print("Generating audios, it may takes a long time depending on your gpu performance") 136 | name2cap = build_name2caption(args.tsv_path) 137 | result = {'caption':[],'audio_path':[]} 138 | for name,caption in tqdm(name2cap.items()): 139 | wav_list = gen_wav(sampler,vocoder,prompt=caption,ddim_steps=args.ddim_steps,scale=args.scale,duration=args.duration,n_samples=1) 140 | for idx,wav in enumerate(wav_list): 141 | audio_path = f'{save_dir}/{name}.wav' 142 | soundfile.write(audio_path,wav,samplerate=SAMPLE_RATE) 143 | result['caption'].append(caption) 144 | result['audio_path'].append(audio_path) 145 | result = pd.DataFrame(result) 146 | result.to_csv(f'{save_dir}/result.tsv',sep='\t',index=False) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Make-An-Audio: Text-To-Audio Generation with Prompt-Enhanced Diffusion Models 2 | 3 | #### Rongjie Huang, Jiawei Huang, Dongchao Yang, Yi Ren, Luping Liu, Mingze Li, Zhenhui Ye, Jinglin Liu, Xiang Yin, Zhou Zhao 4 | 5 | PyTorch Implementation of [Make-An-Audio (ICML'23)](https://arxiv.org/abs/2301.12661): a conditional diffusion probabilistic model capable of generating high fidelity audio efficiently from X modality. 
6 | 7 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2301.12661) 8 | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/AIGC-Audio/Make_An_Audio) 9 | [![GitHub Stars](https://img.shields.io/github/stars/Text-to-Audio/Make-An-Audio?style=social)](https://github.com/Text-to-Audio/Make-An-Audio) 10 | 11 | We provide our implementation and pretrained models as open source in this repository. 12 | 13 | Visit our [demo page](https://text-to-audio.github.io/) for audio samples. 14 | 15 | [Text-to-Audio HuggingFace Space](https://huggingface.co/spaces/AIGC-Audio/Make_An_Audio) | [Audio Inpainting HuggingFace Space](https://huggingface.co/spaces/AIGC-Audio/Make_An_Audio_inpaint) 16 | 17 | ## News 18 | - Jan, 2023: **[Make-An-Audio](https://arxiv.org/abs/2207.06389)** submitted to arXiv. 19 | - August, 2023: **[Make-An-Audio](https://arxiv.org/abs/2301.12661) (ICML 2023)** released on GitHub. 20 | 21 | ## Quick Start 22 | We provide an example of how you can generate high-fidelity samples using Make-An-Audio. 23 | 24 | To try it on your own dataset, simply clone this repo to a local machine with an NVIDIA GPU + CUDA cuDNN and follow the instructions below. 25 | 26 | 27 | ### Supported Datasets and Pretrained Models 28 | 29 | Download the pretrained weights from [Google drive](https://drive.google.com/drive/folders/1zZTI3-nHrUIywKFqwxlFO6PjB66JA8jI?usp=drive_link). 30 | Download the CLAP weights from [Hugging Face](https://huggingface.co/microsoft/msclap/blob/main/CLAP_weights_2022.pth). 31 | 32 | ``` 33 | Download: 34 | maa1_full.ckpt and put it into ./useful_ckpts 35 | BigVGAN vocoder and put it into ./useful_ckpts 36 | CLAP_weights_2022.pth and put it into ./useful_ckpts/CLAP 37 | ``` 38 | The directory structure should be: 39 | ``` 40 | useful_ckpts/ 41 | ├── bigvgan 42 | │ ├── args.yml 43 | │ └── best_netG.pt 44 | ├── CLAP 45 | │ ├── config.yml 46 | │ └── CLAP_weights_2022.pth 47 | └── maa1_full.ckpt 48 | ``` 49 | 50 | 51 | ### Dependencies 52 | See the requirements in `requirements.txt`. 53 | 54 | ## Inference with a pretrained model 55 | ```bash 56 | python gen_wav.py --prompt "a bird chirps" --ddim_steps 100 --duration 10 --scale 3 --n_samples 1 --save_name "results" 57 | ``` 58 | # Train 59 | ## Dataset preparation 60 | We cannot provide dataset download links due to copyright issues, but we provide the preprocessing code to generate mel-spectrograms. 61 | Before training, construct the dataset information into a tsv file that includes name (an id for each audio clip), dataset (the dataset the audio belongs to), audio_path (the path of the .wav file), caption (the caption of the audio) and mel_path (the path of the processed mel-spectrogram file for each audio clip). We provide a tsv file of the AudioCaps test set as a sample: ./data/audiocaps_test.tsv. A tiny illustrative example of this format is sketched at the end of this README. 62 | ### Generate the mel-spectrogram file of each audio clip 63 | Assume you already have a tsv file linking each caption to its audio_path, i.e. the tsv file has "name", "audio_path", "dataset" and "caption" columns. 64 | To compute the mel-spectrograms, run the following command, which saves the mels in ./processed: 65 | ```bash 66 | python preprocess/mel_spec.py --tsv_path tmp.tsv --num_gpus 1 --max_duration 10 67 | ``` 68 | ## Train variational autoencoder 69 | Assume we have processed several datasets and saved the .tsv files as data/*.tsv. Set **data.params.spec_dir_path** to **data** (the directory that contains the tsvs) in the config file, as sketched below. 
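For orientation, the relevant portion of `configs/train/vae.yaml` might look roughly like the sketch below; only the dotted key `data.params.spec_dir_path` comes from the description above, and the surrounding structure is illustrative rather than a copy of the actual config.
```yaml
# Sketch only: the nesting mirrors the dotted key data.params.spec_dir_path
# mentioned above; all other keys of configs/train/vae.yaml are omitted here.
data:
  params:
    spec_dir_path: data   # the directory that contains the processed *.tsv files
```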
Then we can train the VAE with the following command. If you do not have 8 GPUs on your machine, replace `--gpus 0,1,...,gpu_nums` with the indices of the GPUs you do have. 70 | ```bash 71 | python main.py --base configs/train/vae.yaml -t --gpus 0,1,2,3,4,5,6,7 72 | ``` 73 | The training results will be saved in ./logs/ 74 | ## Train latent diffusion 75 | After training the VAE, set model.params.first_stage_config.params.ckpt_path to your trained VAE checkpoint path in the config file. 76 | Run the following command to train the diffusion model: 77 | ```bash 78 | python main.py --base configs/train/diffusion.yaml -t --gpus 0,1,2,3,4,5,6,7 79 | ``` 80 | The training results will be saved in ./logs/ 81 | # Evaluation 82 | ## Generate AudioCaps samples 83 | ```bash 84 | python gen_wavs_by_tsv.py --tsv_path data/audiocaps_test.tsv --save_dir audiocaps_gen 85 | ``` 86 | 87 | ## Calculate FD, FAD, IS, KL 88 | Install [audioldm_eval](https://github.com/haoheliu/audioldm_eval) with: 89 | ```bash 90 | git clone git@github.com:haoheliu/audioldm_eval.git 91 | ``` 92 | Then test with: 93 | ```bash 94 | python scripts/test.py --pred_wavsdir {the directory that saves the audios you generated} --gt_wavsdir {the directory that saves the audiocaps test set waves} 95 | ``` 96 | ## Calculate CLAP score 97 | ```bash 98 | python wav_evaluation/cal_clap_score.py --tsv_path {the directory that saves the audios you generated}/result.tsv 99 | ``` 100 | # X-To-Audio 101 | ## Audio2Audio 102 | ```bash 103 | python scripts/audio2audio.py --prompt "a bird chirping" --strength 0.3 --init-audio sample.wav --ckpt useful_ckpts/maa1_full.ckpt --vocoder_ckpt useful_ckpts/bigvgan --config configs/text_to_audio/txt2audio_args.yaml --outdir audio2audio_samples 104 | ``` 105 | 106 | ## Acknowledgements 107 | This implementation uses parts of the code from the following GitHub repos: 108 | [CLAP](https://github.com/LAION-AI/CLAP), 109 | [Stable Diffusion](https://github.com/CompVis/stable-diffusion), 110 | as described in our code. 111 | 112 | ## Citations 113 | If you find this code useful in your research, please consider citing: 114 | ```bibtex 115 | @article{huang2023make, 116 | title={Make-an-audio: Text-to-audio generation with prompt-enhanced diffusion models}, 117 | author={Huang, Rongjie and Huang, Jiawei and Yang, Dongchao and Ren, Yi and Liu, Luping and Li, Mingze and Ye, Zhenhui and Liu, Jinglin and Yin, Xiang and Zhao, Zhou}, 118 | journal={arXiv preprint arXiv:2301.12661}, 119 | year={2023} 120 | } 121 | ``` 122 | 123 | # Disclaimer 124 | Any organization or individual is prohibited from using any technology mentioned in this paper to generate someone's speech without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws. 
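As referenced in the dataset preparation section above, here is a tiny, made-up sketch of the tab-separated tsv format (header names follow the column description in that section; the name, paths, caption and the mel file extension are purely illustrative assumptions):
```
name	dataset	audio_path	caption	mel_path
example_0001	audiocaps	data/audiocaps/example_0001.wav	a bird chirps	processed/audiocaps/example_0001.npy
```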
125 | -------------------------------------------------------------------------------- /wav_evaluation/models/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 5 | 6 | def get_audio_encoder(name: str): 7 | if name == "Cnn14": 8 | return Cnn14 9 | else: 10 | raise Exception('The audio encoder name {} is incorrect or not supported'.format(name)) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_channels, out_channels): 15 | 16 | super(ConvBlock, self).__init__() 17 | 18 | self.conv1 = nn.Conv2d(in_channels=in_channels, 19 | out_channels=out_channels, 20 | kernel_size=(3, 3), stride=(1, 1), 21 | padding=(1, 1), bias=False) 22 | 23 | self.conv2 = nn.Conv2d(in_channels=out_channels, 24 | out_channels=out_channels, 25 | kernel_size=(3, 3), stride=(1, 1), 26 | padding=(1, 1), bias=False) 27 | 28 | self.bn1 = nn.BatchNorm2d(out_channels) 29 | self.bn2 = nn.BatchNorm2d(out_channels) 30 | 31 | 32 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 33 | 34 | x = input 35 | x = F.relu_(self.bn1(self.conv1(x))) 36 | x = F.relu_(self.bn2(self.conv2(x))) 37 | if pool_type == 'max': 38 | x = F.max_pool2d(x, kernel_size=pool_size) 39 | elif pool_type == 'avg': 40 | x = F.avg_pool2d(x, kernel_size=pool_size) 41 | elif pool_type == 'avg+max': 42 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 43 | x2 = F.max_pool2d(x, kernel_size=pool_size) 44 | x = x1 + x2 45 | else: 46 | raise Exception('Incorrect argument!') 47 | 48 | return x 49 | 50 | 51 | class ConvBlock5x5(nn.Module): 52 | def __init__(self, in_channels, out_channels): 53 | 54 | super(ConvBlock5x5, self).__init__() 55 | 56 | self.conv1 = nn.Conv2d(in_channels=in_channels, 57 | out_channels=out_channels, 58 | kernel_size=(5, 5), stride=(1, 1), 59 | padding=(2, 2), bias=False) 60 | 61 | self.bn1 = nn.BatchNorm2d(out_channels) 62 | 63 | 64 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 65 | 66 | x = input 67 | x = F.relu_(self.bn1(self.conv1(x))) 68 | if pool_type == 'max': 69 | x = F.max_pool2d(x, kernel_size=pool_size) 70 | elif pool_type == 'avg': 71 | x = F.avg_pool2d(x, kernel_size=pool_size) 72 | elif pool_type == 'avg+max': 73 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 74 | x2 = F.max_pool2d(x, kernel_size=pool_size) 75 | x = x1 + x2 76 | else: 77 | raise Exception('Incorrect argument!') 78 | 79 | return x 80 | 81 | 82 | class AttBlock(nn.Module): 83 | def __init__(self, n_in, n_out, activation='linear', temperature=1.): 84 | super(AttBlock, self).__init__() 85 | 86 | self.activation = activation 87 | self.temperature = temperature 88 | self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 89 | self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 90 | 91 | self.bn_att = nn.BatchNorm1d(n_out) 92 | 93 | def forward(self, x): 94 | # x: (n_samples, n_in, n_time) 95 | norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) 96 | cla = self.nonlinear_transform(self.cla(x)) 97 | x = torch.sum(norm_att * cla, dim=2) 98 | return x, norm_att, cla 99 | 100 | def nonlinear_transform(self, x): 101 | if self.activation == 'linear': 102 | return x 103 | elif self.activation == 'sigmoid': 104 | return torch.sigmoid(x) 105 | 106 | 107 | class Cnn14(nn.Module): 108 | def __init__(self, sample_rate, window_size, hop_size, 
mel_bins, fmin, 109 | fmax, classes_num, out_emb): 110 | 111 | super(Cnn14, self).__init__() 112 | 113 | window = 'hann' 114 | center = True 115 | pad_mode = 'reflect' 116 | ref = 1.0 117 | amin = 1e-10 118 | top_db = None 119 | 120 | # Spectrogram extractor 121 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 122 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 123 | freeze_parameters=True) 124 | 125 | # Logmel feature extractor 126 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 127 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 128 | freeze_parameters=True) 129 | 130 | self.bn0 = nn.BatchNorm2d(64) 131 | 132 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 133 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 134 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 135 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 136 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 137 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 138 | 139 | # out_emb is 2048 for best Cnn14 140 | self.fc1 = nn.Linear(2048, out_emb, bias=True) 141 | self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True) 142 | 143 | def forward(self, input, mixup_lambda=None): 144 | """ 145 | Input: (batch_size, data_length) 146 | """ 147 | 148 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 149 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 150 | 151 | x = x.transpose(1, 3) 152 | x = self.bn0(x) 153 | x = x.transpose(1, 3) 154 | 155 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 156 | x = F.dropout(x, p=0.2, training=self.training) 157 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 158 | x = F.dropout(x, p=0.2, training=self.training) 159 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 160 | x = F.dropout(x, p=0.2, training=self.training) 161 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 162 | x = F.dropout(x, p=0.2, training=self.training) 163 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 164 | x = F.dropout(x, p=0.2, training=self.training) 165 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 166 | x = F.dropout(x, p=0.2, training=self.training) 167 | x = torch.mean(x, dim=3) 168 | 169 | (x1, _) = torch.max(x, dim=2) 170 | x2 = torch.mean(x, dim=2) 171 | x = x1 + x2 172 | x = F.dropout(x, p=0.5, training=self.training) 173 | x = F.relu_(self.fc1(x)) 174 | embedding = F.dropout(x, p=0.5, training=self.training) 175 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 176 | 177 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 178 | 179 | return output_dict -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 5 | 6 | def get_audio_encoder(name: str): 7 | if name == "Cnn14": 8 | return Cnn14 9 | else: 10 | raise Exception('The audio encoder name {} is incorrect or not supported'.format(name)) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_channels, out_channels): 15 | 16 | super(ConvBlock, self).__init__() 17 | 18 | self.conv1 = 
nn.Conv2d(in_channels=in_channels, 19 | out_channels=out_channels, 20 | kernel_size=(3, 3), stride=(1, 1), 21 | padding=(1, 1), bias=False) 22 | 23 | self.conv2 = nn.Conv2d(in_channels=out_channels, 24 | out_channels=out_channels, 25 | kernel_size=(3, 3), stride=(1, 1), 26 | padding=(1, 1), bias=False) 27 | 28 | self.bn1 = nn.BatchNorm2d(out_channels) 29 | self.bn2 = nn.BatchNorm2d(out_channels) 30 | 31 | 32 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 33 | 34 | x = input 35 | x = F.relu_(self.bn1(self.conv1(x))) 36 | x = F.relu_(self.bn2(self.conv2(x))) 37 | if pool_type == 'max': 38 | x = F.max_pool2d(x, kernel_size=pool_size) 39 | elif pool_type == 'avg': 40 | x = F.avg_pool2d(x, kernel_size=pool_size) 41 | elif pool_type == 'avg+max': 42 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 43 | x2 = F.max_pool2d(x, kernel_size=pool_size) 44 | x = x1 + x2 45 | else: 46 | raise Exception('Incorrect argument!') 47 | 48 | return x 49 | 50 | 51 | class ConvBlock5x5(nn.Module): 52 | def __init__(self, in_channels, out_channels): 53 | 54 | super(ConvBlock5x5, self).__init__() 55 | 56 | self.conv1 = nn.Conv2d(in_channels=in_channels, 57 | out_channels=out_channels, 58 | kernel_size=(5, 5), stride=(1, 1), 59 | padding=(2, 2), bias=False) 60 | 61 | self.bn1 = nn.BatchNorm2d(out_channels) 62 | 63 | 64 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 65 | 66 | x = input 67 | x = F.relu_(self.bn1(self.conv1(x))) 68 | if pool_type == 'max': 69 | x = F.max_pool2d(x, kernel_size=pool_size) 70 | elif pool_type == 'avg': 71 | x = F.avg_pool2d(x, kernel_size=pool_size) 72 | elif pool_type == 'avg+max': 73 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 74 | x2 = F.max_pool2d(x, kernel_size=pool_size) 75 | x = x1 + x2 76 | else: 77 | raise Exception('Incorrect argument!') 78 | 79 | return x 80 | 81 | 82 | class AttBlock(nn.Module): 83 | def __init__(self, n_in, n_out, activation='linear', temperature=1.): 84 | super(AttBlock, self).__init__() 85 | 86 | self.activation = activation 87 | self.temperature = temperature 88 | self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 89 | self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 90 | 91 | self.bn_att = nn.BatchNorm1d(n_out) 92 | 93 | def forward(self, x): 94 | # x: (n_samples, n_in, n_time) 95 | norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) 96 | cla = self.nonlinear_transform(self.cla(x)) 97 | x = torch.sum(norm_att * cla, dim=2) 98 | return x, norm_att, cla 99 | 100 | def nonlinear_transform(self, x): 101 | if self.activation == 'linear': 102 | return x 103 | elif self.activation == 'sigmoid': 104 | return torch.sigmoid(x) 105 | 106 | 107 | class Cnn14(nn.Module): 108 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 109 | fmax, classes_num, out_emb): 110 | 111 | super(Cnn14, self).__init__() 112 | 113 | window = 'hann' 114 | center = True 115 | pad_mode = 'reflect' 116 | ref = 1.0 117 | amin = 1e-10 118 | top_db = None 119 | 120 | # Spectrogram extractor 121 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 122 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 123 | freeze_parameters=True) 124 | 125 | # Logmel feature extractor 126 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 127 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 128 | freeze_parameters=True) 129 | 
130 | self.bn0 = nn.BatchNorm2d(64) 131 | 132 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 133 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 134 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 135 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 136 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 137 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 138 | 139 | # out_emb is 2048 for best Cnn14 140 | self.fc1 = nn.Linear(2048, out_emb, bias=True) 141 | self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True) 142 | 143 | def forward(self, input, mixup_lambda=None): 144 | """ 145 | Input: (batch_size, data_length) 146 | """ 147 | 148 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 149 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 150 | 151 | x = x.transpose(1, 3) 152 | x = self.bn0(x) 153 | x = x.transpose(1, 3) 154 | 155 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 156 | x = F.dropout(x, p=0.2, training=self.training) 157 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 158 | x = F.dropout(x, p=0.2, training=self.training) 159 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 160 | x = F.dropout(x, p=0.2, training=self.training) 161 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 162 | x = F.dropout(x, p=0.2, training=self.training) 163 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 164 | x = F.dropout(x, p=0.2, training=self.training) 165 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 166 | x = F.dropout(x, p=0.2, training=self.training) 167 | x = torch.mean(x, dim=3) 168 | 169 | (x1, _) = torch.max(x, dim=2) 170 | x2 = torch.mean(x, dim=2) 171 | x = x1 + x2 172 | x = F.dropout(x, p=0.5, training=self.training) 173 | x = F.relu_(self.fc1(x)) 174 | embedding = F.dropout(x, p=0.5, training=self.training) 175 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 176 | 177 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 178 | 179 | return output_dict -------------------------------------------------------------------------------- /ldm/modules/losses_audio/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | 6 | sys.path.insert(0, '.') # nopep8 7 | from ldm.modules.losses_audio.vqperceptual import * 8 | 9 | def discriminator_loss_mse(disc_real_outputs, disc_generated_outputs): 10 | r_losses = 0 11 | g_losses = 0 12 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 13 | r_loss = torch.mean((1 - dr) ** 2) 14 | g_loss = torch.mean(dg ** 2) 15 | r_losses += r_loss 16 | g_losses += g_loss 17 | r_losses = r_losses / len(disc_real_outputs) 18 | g_losses = g_losses / len(disc_real_outputs) 19 | total = 0.5 * (r_losses + g_losses) 20 | return total 21 | 22 | class LPAPSWithDiscriminator(nn.Module): 23 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 24 | disc_num_layers=3, disc_in_channels=3,disc_hidden_size=64, disc_factor=1.0, disc_weight=1.0, 25 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 26 | disc_loss="hinge",r1_reg_weight=5): 27 | 28 | super().__init__() 29 | assert disc_loss in ["hinge", "vanilla","mse"] 30 | self.kl_weight = kl_weight 31 | self.pixel_weight = pixelloss_weight 32 | self.perceptual_weight 
= perceptual_weight 33 | if self.perceptual_weight > 0: 34 | raise RuntimeError("don't use perceptual loss") 35 | 36 | # output log variance 37 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 38 | 39 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 40 | ndf = disc_hidden_size, 41 | n_layers=disc_num_layers, 42 | use_actnorm=use_actnorm, 43 | ).apply(weights_init) # h=8,w/(2**disc_num_layers) - 2 44 | self.discriminator_iter_start = disc_start 45 | if disc_loss == "hinge": 46 | self.disc_loss = hinge_d_loss 47 | elif disc_loss == "vanilla": 48 | self.disc_loss = vanilla_d_loss 49 | elif disc_loss == 'mse': 50 | self.disc_loss = discriminator_loss_mse 51 | else: 52 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 53 | print(f"LPAPSWithDiscriminator running with {disc_loss} loss.") 54 | self.disc_factor = disc_factor 55 | self.discriminator_weight = disc_weight 56 | self.disc_conditional = disc_conditional 57 | self.r1_reg_weight = r1_reg_weight 58 | 59 | 60 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 61 | if last_layer is not None: 62 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 63 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 64 | else: 65 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 67 | 68 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 69 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 70 | d_weight = d_weight * self.discriminator_weight 71 | return d_weight 72 | 73 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 74 | global_step, last_layer=None, cond=None, split="train", weights=None): 75 | if len(inputs.shape) == 3: 76 | inputs,reconstructions = inputs.unsqueeze(1),reconstructions.unsqueeze(1) 77 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 78 | if self.perceptual_weight > 0: 79 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 80 | # print(f"p_loss {p_loss}") 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 86 | weighted_nll_loss = nll_loss 87 | if weights is not None: 88 | weighted_nll_loss = weights*nll_loss 89 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 90 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 91 | kl_loss = posteriors.kl() 92 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 93 | 94 | # now the GAN part 95 | if optimizer_idx == 0: 96 | # generator update 97 | if cond is None: 98 | assert not self.disc_conditional 99 | logits_fake = self.discriminator(reconstructions.contiguous()) 100 | else: 101 | assert self.disc_conditional 102 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 103 | g_loss = -torch.mean(logits_fake) 104 | 105 | try: 106 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 107 | except RuntimeError: 108 | assert not self.training 109 | d_weight = torch.tensor(0.0) 110 | 111 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 112 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 113 | 114 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 115 | 
"{}/logvar".format(split): self.logvar.detach(), 116 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 117 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 118 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 119 | "{}/d_weight".format(split): d_weight.detach(), 120 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 121 | "{}/g_loss".format(split): g_loss.detach().mean(), 122 | } 123 | return loss, log 124 | 125 | if optimizer_idx == 1: 126 | # second pass for discriminator update 127 | if cond is None: 128 | d_real_in = inputs.contiguous().detach() 129 | d_real_in.requires_grad = True 130 | logits_real = self.discriminator(d_real_in) 131 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 132 | else: 133 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 134 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 135 | 136 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 137 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) # logits_real越大,logits_fake越小说明discriminator越强 138 | if self.r1_reg_weight > 0 and split=='train': 139 | r1_grads = torch.autograd.grad(outputs=[logits_real.sum()], inputs=[d_real_in], create_graph=True, only_inputs=True) 140 | r1_grads = r1_grads[0] 141 | r1_penalty = r1_grads.square().mean() 142 | d_loss += self.r1_reg_weight * r1_penalty 143 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 144 | "{}/logits_real".format(split): logits_real.detach().mean(), 145 | "{}/logits_fake".format(split): logits_fake.detach().mean() 146 | } 147 | if self.r1_reg_weight and split=='train': 148 | log["{}/r1_prnalty".format(split)] = r1_penalty 149 | return d_loss, log 150 | 151 | 152 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | from ldm.util import exists 6 | sys.path.insert(0, '.') # nopep8 7 | # from ldm.modules.losses_audio.lpaps import LPAPS 8 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 9 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 10 | 11 | 12 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 13 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 14 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 15 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 16 | loss_real = (weights * loss_real).sum() / weights.sum() 17 | loss_fake = (weights * loss_fake).sum() / weights.sum() 18 | d_loss = 0.5 * (loss_real + loss_fake) 19 | return d_loss 20 | 21 | def adopt_weight(weight, global_step, threshold=0, value=0.): 22 | if global_step < threshold: 23 | weight = value 24 | return weight 25 | 26 | 27 | def measure_perplexity(predicted_indices, n_embed): 28 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 29 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 30 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 31 | avg_probs = encodings.mean(0) 32 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 33 | cluster_use = torch.sum(avg_probs > 0) 34 | return perplexity, cluster_use 35 | 36 | def l1(x, y): 37 | return torch.abs(x-y) 38 | 39 | 40 | def l2(x, y): 41 | return torch.pow((x-y), 2) 42 | 43 | 44 | 45 | class DummyLoss(nn.Module): 46 | def __init__(self): 47 | super().__init__() 48 | 49 | class VQLPAPSWithDiscriminator(nn.Module): 50 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 51 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 52 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 53 | disc_ndf=64, disc_loss="hinge", n_classes=None, pixel_loss="l1"): 54 | super().__init__() 55 | assert disc_loss in ["hinge", "vanilla"] 56 | self.codebook_weight = codebook_weight 57 | self.pixel_weight = pixelloss_weight 58 | self.perceptual_loss = None # LPAPS().eval() 59 | self.perceptual_weight = perceptual_weight 60 | 61 | if pixel_loss == "l1": 62 | self.pixel_loss = l1 63 | else: 64 | self.pixel_loss = l2 65 | 66 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 67 | n_layers=disc_num_layers, 68 | use_actnorm=use_actnorm, 69 | ndf=disc_ndf 70 | ).apply(weights_init) 71 | self.discriminator_iter_start = disc_start 72 | if disc_loss == "hinge": 73 | self.disc_loss = hinge_d_loss 74 | elif disc_loss == "vanilla": 75 | self.disc_loss = vanilla_d_loss 76 | else: 77 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 78 | print(f"VQLPAPSWithDiscriminator running with {disc_loss} loss.") 79 | self.disc_factor = disc_factor 80 | self.discriminator_weight = disc_weight 81 | self.disc_conditional = disc_conditional 82 | self.n_classes = n_classes 83 | 84 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 85 | if last_layer is not None: 86 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 87 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 88 | else: 89 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 90 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 91 | 92 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 93 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 94 | d_weight = d_weight * self.discriminator_weight 95 | return d_weight 96 | 97 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 98 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 99 | if not exists(codebook_loss): 100 | codebook_loss = torch.tensor([0.]).to(inputs.device) 101 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 102 | if self.perceptual_weight > 0: 103 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | rec_loss = rec_loss + self.perceptual_weight * p_loss 105 | else: 106 | p_loss = torch.tensor([0.0]) 107 | 108 | nll_loss = rec_loss 109 | # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 110 | nll_loss = torch.mean(nll_loss) 111 | 112 | # now the GAN part 113 | if optimizer_idx == 0: 114 | # generator update 115 | if cond is None: 116 | assert not self.disc_conditional 117 | logits_fake = self.discriminator(reconstructions.contiguous()) 118 | else: 119 | assert 
self.disc_conditional 120 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 121 | g_loss = -torch.mean(logits_fake) 122 | 123 | try: 124 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 125 | except RuntimeError: 126 | assert not self.training 127 | d_weight = torch.tensor(0.0) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 131 | 132 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 133 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 134 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 135 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 136 | "{}/p_loss".format(split): p_loss.detach().mean(), 137 | "{}/d_weight".format(split): d_weight.detach(), 138 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 139 | "{}/g_loss".format(split): g_loss.detach().mean(), 140 | } 141 | # if predicted_indices is not None: 142 | # assert self.n_classes is not None 143 | # with torch.no_grad(): 144 | # perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 145 | # log[f"{split}/perplexity"] = perplexity 146 | # log[f"{split}/cluster_usage"] = cluster_usage 147 | return loss, log 148 | 149 | if optimizer_idx == 1: 150 | # second pass for discriminator update 151 | if cond is None: 152 | logits_real = self.discriminator(inputs.contiguous().detach()) 153 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 154 | else: 155 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 156 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 157 | 158 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 159 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 160 | 161 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 162 | "{}/logits_real".format(split): logits_real.detach().mean(), 163 | "{}/logits_fake".format(split): logits_fake.detach().mean() 164 | } 165 | return d_loss, log 166 | 167 | -------------------------------------------------------------------------------- /ldm/modules/discriminator/multi_window_disc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Discriminator2DFactory(nn.Module): 7 | def __init__(self, time_length, freq_length=80, kernel=(3, 3), c_in=1, hidden_size=128, 8 | norm_type='bn', reduction='sum'): 9 | super(Discriminator2DFactory, self).__init__() 10 | padding = (kernel[0] // 2, kernel[1] // 2) 11 | 12 | def discriminator_block(in_filters, out_filters, first=False): 13 | """ 14 | Input: (B, in, 2H, 2W) 15 | Output:(B, out, H, W) 16 | """ 17 | conv = nn.Conv2d(in_filters, out_filters, kernel, (2, 2), padding) 18 | if norm_type == 'sn': 19 | conv = nn.utils.spectral_norm(conv) 20 | block = [ 21 | conv, # padding = kernel//2 22 | nn.LeakyReLU(0.2, inplace=True), 23 | nn.Dropout2d(0.25) 24 | ] 25 | if norm_type == 'bn' and not first: 26 | block.append(nn.BatchNorm2d(out_filters, 0.8)) 27 | if norm_type == 'in' and not first: 28 | block.append(nn.InstanceNorm2d(out_filters, affine=True)) 29 | block = nn.Sequential(*block) 30 | return block 31 | 32 | self.model = 
nn.ModuleList([ 33 | discriminator_block(c_in, hidden_size, first=True), 34 | discriminator_block(hidden_size, hidden_size), 35 | discriminator_block(hidden_size, hidden_size), 36 | ]) 37 | 38 | self.reduction = reduction 39 | ds_size = (time_length // 2 ** 3, (freq_length + 7) // 2 ** 3) 40 | if reduction != 'none': 41 | # The height and width of downsampled image 42 | self.adv_layer = nn.Linear(hidden_size * ds_size[0] * ds_size[1], 1) 43 | else: 44 | self.adv_layer = nn.Linear(hidden_size * ds_size[1], 1) 45 | 46 | def forward(self, x): 47 | """ 48 | 49 | :param x: [B, C, T, n_bins] 50 | :return: validity: [B, 1], h: List of hiddens 51 | """ 52 | h = [] 53 | for l in self.model: 54 | x = l(x) 55 | h.append(x) 56 | if self.reduction != 'none': 57 | x = x.view(x.shape[0], -1) 58 | validity = self.adv_layer(x) # [B, 1] 59 | else: 60 | B, _, T_, _ = x.shape 61 | x = x.transpose(1, 2).reshape(B, T_, -1) 62 | validity = self.adv_layer(x)[:, :, 0] # [B, T] 63 | return validity, h 64 | 65 | 66 | class MultiWindowDiscriminator(nn.Module): 67 | def __init__(self, time_lengths, cond_size=0, freq_length=80, kernel=(3, 3), 68 | c_in=1, hidden_size=128, norm_type='bn', reduction='sum'): 69 | super(MultiWindowDiscriminator, self).__init__() 70 | self.win_lengths = time_lengths 71 | self.reduction = reduction 72 | 73 | self.conv_layers = nn.ModuleList() 74 | if cond_size > 0: 75 | self.cond_proj_layers = nn.ModuleList() 76 | self.mel_proj_layers = nn.ModuleList() 77 | for time_length in time_lengths: 78 | conv_layer = [ 79 | Discriminator2DFactory( 80 | time_length, freq_length, kernel, c_in=c_in, hidden_size=hidden_size, 81 | norm_type=norm_type, reduction=reduction) 82 | ] 83 | self.conv_layers += conv_layer 84 | if cond_size > 0: 85 | self.cond_proj_layers.append(nn.Linear(cond_size, freq_length)) 86 | self.mel_proj_layers.append(nn.Linear(freq_length, freq_length)) 87 | 88 | def forward(self, x, x_len, cond=None, start_frames_wins=None): 89 | ''' 90 | Args: 91 | x (tensor): input mel, (B, c_in, T, n_bins). 92 | x_length (tensor): len of per mel. (B,). 93 | 94 | Returns: 95 | tensor : (B). 96 | ''' 97 | validity = [] 98 | if start_frames_wins is None: 99 | start_frames_wins = [None] * len(self.conv_layers) 100 | h = [] 101 | for i, start_frames in zip(range(len(self.conv_layers)), start_frames_wins): 102 | x_clip, c_clip, start_frames = self.clip( 103 | x, cond, x_len, self.win_lengths[i], start_frames) # (B, win_length, C) 104 | start_frames_wins[i] = start_frames 105 | if x_clip is None: 106 | continue 107 | if cond is not None: 108 | x_clip = self.mel_proj_layers[i](x_clip) # (B, 1, win_length, C) 109 | c_clip = self.cond_proj_layers[i](c_clip)[:, None] # (B, 1, win_length, C) 110 | x_clip = x_clip + c_clip 111 | x_clip, h_ = self.conv_layers[i](x_clip) 112 | h += h_ 113 | validity.append(x_clip) 114 | if len(validity) != len(self.conv_layers): 115 | return None, start_frames_wins, h 116 | if self.reduction == 'sum': 117 | validity = sum(validity) # [B] 118 | elif self.reduction == 'stack': 119 | validity = torch.stack(validity, -1) # [B, W_L] 120 | elif self.reduction == 'none': 121 | validity = torch.cat(validity, -1) # [B, W_sum] 122 | return validity, start_frames_wins, h 123 | 124 | def clip(self, x, cond, x_len, win_length, start_frames=None): 125 | '''Ramdom clip x to win_length. 126 | Args: 127 | x (tensor) : (B, c_in, T, n_bins). 128 | cond (tensor) : (B, T, H). 129 | x_len (tensor) : (B,). 
130 | win_length (int): target clip length 131 | 132 | Returns: 133 | (tensor) : (B, c_in, win_length, n_bins). 134 | 135 | ''' 136 | T_start = 0 137 | T_end = x_len.max() - win_length 138 | if T_end < 0: 139 | return None, None, start_frames 140 | T_end = T_end.item() 141 | if start_frames is None: 142 | start_frame = np.random.randint(low=T_start, high=T_end + 1) 143 | start_frames = [start_frame] * x.size(0) 144 | else: 145 | start_frame = start_frames[0] 146 | x_batch = x[:, :, start_frame: start_frame + win_length] 147 | c_batch = cond[:, start_frame: start_frame + win_length] if cond is not None else None 148 | return x_batch, c_batch, start_frames 149 | 150 | 151 | class Discriminator(nn.Module): 152 | def __init__(self, time_lengths=[32, 64, 128], freq_length=80, cond_size=0, kernel=(3, 3), c_in=1, 153 | hidden_size=128, norm_type='bn', reduction='sum', uncond_disc=True): 154 | super(Discriminator, self).__init__() 155 | self.time_lengths = time_lengths 156 | self.cond_size = cond_size 157 | self.reduction = reduction 158 | self.uncond_disc = uncond_disc 159 | if uncond_disc: 160 | self.discriminator = MultiWindowDiscriminator( 161 | freq_length=freq_length, 162 | time_lengths=time_lengths, 163 | kernel=kernel, 164 | c_in=c_in, hidden_size=hidden_size, norm_type=norm_type, 165 | reduction=reduction 166 | ) 167 | if cond_size > 0: 168 | self.cond_disc = MultiWindowDiscriminator( 169 | freq_length=freq_length, 170 | time_lengths=time_lengths, 171 | cond_size=cond_size, 172 | kernel=kernel, 173 | c_in=c_in, hidden_size=hidden_size, norm_type=norm_type, 174 | reduction=reduction 175 | ) 176 | 177 | def forward(self, x, cond=None, start_frames_wins=None): 178 | """ 179 | 180 | :param x: [B, T, 80] 181 | :param cond: [B, T, cond_size] 182 | :param return_y_only: 183 | :return: 184 | """ 185 | if len(x.shape) == 3: 186 | x = x[:, None, :, :] 187 | x_len = x.sum([1, -1]).ne(0).int().sum([-1]) 188 | ret = {'y_c': None, 'y': None} 189 | if self.uncond_disc: 190 | ret['y'], start_frames_wins, ret['h'] = self.discriminator( 191 | x, x_len, start_frames_wins=start_frames_wins) 192 | if self.cond_size > 0 and cond is not None: 193 | ret['y_c'], start_frames_wins, ret['h_c'] = self.cond_disc( 194 | x, x_len, cond, start_frames_wins=start_frames_wins) 195 | ret['start_frames_wins'] = start_frames_wins 196 | return ret -------------------------------------------------------------------------------- /ldm/models/autoencoder_multi.py: -------------------------------------------------------------------------------- 1 | """ 2 | 与autoencoder.py的区别在于,autoencoder.py计算loss时只有一个discriminator,而此处又多了个multiwindowDiscriminator,所以优化器 3 | 优化的参数改为: 4 | opt_disc = torch.optim.Adam(list(self.loss.discriminator.parameters()) + list(self.loss.discriminator_multi.parameters()), 5 | lr=lr, betas=(0.5, 0.9)) 6 | """ 7 | 8 | import os 9 | import torch 10 | import pytorch_lightning as pl 11 | import torch.nn.functional as F 12 | from contextlib import contextmanager 13 | 14 | from packaging import version 15 | import numpy as np 16 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 17 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 18 | from torch.optim.lr_scheduler import LambdaLR 19 | from ldm.util import instantiate_from_config 20 | 21 | 22 | 23 | class AutoencoderKL(pl.LightningModule): 24 | def __init__(self, 25 | ddconfig, 26 | lossconfig, 27 | embed_dim, 28 | ckpt_path=None, 29 | ignore_keys=[], 30 | image_key="image", 31 | colorize_nlabels=None, 32 | monitor=None, 33 
| ): 34 | super().__init__() 35 | self.image_key = image_key 36 | self.encoder = Encoder(**ddconfig) 37 | self.decoder = Decoder(**ddconfig) 38 | self.loss = instantiate_from_config(lossconfig) 39 | assert ddconfig["double_z"] 40 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 41 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 42 | self.embed_dim = embed_dim 43 | if colorize_nlabels is not None: 44 | assert type(colorize_nlabels)==int 45 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 46 | if monitor is not None: 47 | self.monitor = monitor 48 | if ckpt_path is not None: 49 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 50 | 51 | def init_from_ckpt(self, path, ignore_keys=list()): 52 | sd = torch.load(path, map_location="cpu")["state_dict"] 53 | keys = list(sd.keys()) 54 | for k in keys: 55 | for ik in ignore_keys: 56 | if k.startswith(ik): 57 | print("Deleting key {} from state_dict.".format(k)) 58 | del sd[k] 59 | self.load_state_dict(sd, strict=False) 60 | print(f"Restored from {path}") 61 | 62 | def encode(self, x): 63 | h = self.encoder(x) 64 | moments = self.quant_conv(h) 65 | posterior = DiagonalGaussianDistribution(moments) 66 | return posterior 67 | 68 | def decode(self, z): 69 | z = self.post_quant_conv(z) 70 | dec = self.decoder(z) 71 | return dec 72 | 73 | def forward(self, input, sample_posterior=True): 74 | posterior = self.encode(input) 75 | if sample_posterior: 76 | z = posterior.sample() 77 | else: 78 | z = posterior.mode() 79 | dec = self.decode(z) 80 | return dec, posterior 81 | 82 | def get_input(self, batch, k): 83 | x = batch[k] 84 | if len(x.shape) == 3: 85 | x = x[..., None] 86 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 87 | return x 88 | 89 | def training_step(self, batch, batch_idx, optimizer_idx): 90 | inputs = self.get_input(batch, self.image_key) 91 | reconstructions, posterior = self(inputs) 92 | 93 | if optimizer_idx == 0: 94 | # train encoder+decoder+logvar 95 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 96 | last_layer=self.get_last_layer(), split="train") 97 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 98 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 99 | return aeloss 100 | 101 | if optimizer_idx == 1: 102 | # train the discriminator 103 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 104 | last_layer=self.get_last_layer(), split="train") 105 | 106 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 107 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 108 | return discloss 109 | 110 | def validation_step(self, batch, batch_idx): 111 | inputs = self.get_input(batch, self.image_key) 112 | reconstructions, posterior = self(inputs) 113 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 114 | last_layer=self.get_last_layer(), split="val") 115 | 116 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 117 | last_layer=self.get_last_layer(), split="val") 118 | 119 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 120 | self.log_dict(log_dict_ae) 121 | self.log_dict(log_dict_disc) 122 | return self.log_dict 123 | 124 | def test_step(self, batch, batch_idx): 125 | inputs = 
self.get_input(batch, self.image_key)# inputs shape:(b,c,mel_len,T) or (b,c,h,w) 126 | reconstructions, posterior = self(inputs)# reconstructions:(b,c,mel_len,T) or (b,c,h,w) 127 | reconstructions = (reconstructions + 1)/2 # to mel scale 128 | test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path) 129 | savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class') 130 | if not os.path.exists(savedir): 131 | os.makedirs(savedir) 132 | 133 | file_names = batch['f_name'] 134 | # print(f"reconstructions.shape:{reconstructions.shape}",file_names) 135 | reconstructions = reconstructions.cpu().numpy().squeeze(1) # squuze channel dim 136 | for b in range(reconstructions.shape[0]): 137 | vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num 138 | v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:] 139 | save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}.npy') 140 | np.save(save_img_path,reconstructions[b]) 141 | 142 | return None 143 | 144 | def configure_optimizers(self): 145 | lr = self.learning_rate 146 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 147 | list(self.decoder.parameters())+ 148 | list(self.quant_conv.parameters())+ 149 | list(self.post_quant_conv.parameters()), 150 | lr=lr, betas=(0.5, 0.9)) 151 | opt_disc = torch.optim.Adam(list(self.loss.discriminator.parameters()) + list(self.loss.discriminator_multi.parameters()), 152 | lr=lr, betas=(0.5, 0.9)) 153 | return [opt_ae, opt_disc], [] 154 | 155 | def get_last_layer(self): 156 | return self.decoder.conv_out.weight 157 | 158 | @torch.no_grad() 159 | def log_images(self, batch, only_inputs=False, **kwargs): 160 | log = dict() 161 | x = self.get_input(batch, self.image_key) 162 | x = x.to(self.device) 163 | if not only_inputs: 164 | xrec, posterior = self(x) 165 | if x.shape[1] > 3: 166 | # colorize with random projection 167 | assert xrec.shape[1] > 3 168 | x = self.to_rgb(x) 169 | xrec = self.to_rgb(xrec) 170 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 171 | log["reconstructions"] = xrec 172 | log["inputs"] = x 173 | return log 174 | 175 | def to_rgb(self, x): 176 | assert self.image_key == "segmentation" 177 | if not hasattr(self, "colorize"): 178 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 179 | x = F.conv2d(x, weight=self.colorize) 180 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
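# (editor's note) the rescale above maps the colorized output linearly onto [-1, 1] before it is returned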
181 | return x 182 | 183 | 184 | class IdentityFirstStage(torch.nn.Module): 185 | def __init__(self, *args, vq_interface=False, **kwargs): 186 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 187 | super().__init__() 188 | 189 | def encode(self, x, *args, **kwargs): 190 | return x 191 | 192 | def decode(self, x, *args, **kwargs): 193 | return x 194 | 195 | def quantize(self, x, *args, **kwargs): 196 | if self.vq_interface: 197 | return x, None, [None, None, None] 198 | return x 199 | 200 | def forward(self, x, *args, **kwargs): 201 | return x -------------------------------------------------------------------------------- /scripts/audio2audio.py: -------------------------------------------------------------------------------- 1 | """make variations of an input audio clip""" 2 | 3 | import argparse, os, sys, glob 4 | import PIL 5 | import torch 6 | import numpy as np 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | from tqdm import tqdm, trange 10 | from itertools import islice 11 | from einops import rearrange, repeat 12 | from torchvision.utils import make_grid 13 | from torch import autocast 14 | import librosa 15 | # from contextlib import nullcontext 16 | import time 17 | from pytorch_lightning import seed_everything 18 | import math 19 | from ldm.util import instantiate_from_config 20 | from ldm.models.diffusion.ddim import DDIMSampler 21 | from vocoder.bigvgan.models import VocoderBigVGAN 22 | # from ldm.data.extract_mel_spectrogram import TRANSFORMS_22050,TRANSFORMS_16000 23 | from preprocess.NAT_mel import MelNet 24 | import soundfile 25 | 26 | batch_max_length = 624 27 | SAMPLE_RATE= 16000 28 | 29 | def chunk(it, size): 30 | it = iter(it) 31 | return iter(lambda: tuple(islice(it, size)), ()) 32 | 33 | 34 | def load_model_from_config(config, ckpt, verbose=True): 35 | print(f"Loading model from {ckpt}") 36 | pl_sd = torch.load(ckpt, map_location="cpu") 37 | if "global_step" in pl_sd: 38 | print(f"Global Step: {pl_sd['global_step']}") 39 | sd = pl_sd["state_dict"] 40 | model = instantiate_from_config(config.model) 41 | m, u = model.load_state_dict(sd, strict=False) 42 | if len(m) > 0 and verbose: 43 | print("missing keys:") 44 | print(m) 45 | if len(u) > 0 and verbose: 46 | print("unexpected keys:") 47 | print(u) 48 | 49 | model.cuda() 50 | model.eval() 51 | return model 52 | 53 | def load_audio(path,transform,sr=16000,batch_max_length=624):# load wav and return mel 54 | wav,_ = librosa.load(path,sr=sr) 55 | 56 | audio = transform(wav) # (1,melbins,T) 57 | if audio.shape[2] <= batch_max_length: 58 | n_repeat = math.ceil((batch_max_length + 1) / audio.shape[2]) # tile along the time axis (dim 2), not the mel-bin axis 59 | audio = audio.repeat(1,1, n_repeat) 60 | 61 | audio = audio[..., :batch_max_length].unsqueeze(0) # shape [1,1,80,batch_max_length] 62 | return audio 63 | 64 | def load_img(path):# load mel 65 | audio = np.load(path) 66 | if audio.shape[1] <= batch_max_length: 67 | n_repeat = math.ceil((batch_max_length + 1) / audio.shape[1]) 68 | audio = np.tile(audio, reps=(1, n_repeat)) 69 | 70 | audio = audio[:, :batch_max_length] 71 | audio = torch.FloatTensor(audio)[None, None, :, :] # [1,1,80,batch_max_length] 72 | return audio 73 | 74 | def parse_args(): 75 | parser = argparse.ArgumentParser() 76 | 77 | parser.add_argument( 78 | "--prompt", 79 | type=str, 80 | nargs="?", 81 | default="a bird chirping", 82 | help="the prompt to render" 83 | ) 84 | 85 | parser.add_argument( 86 | "--init-audio", 87 | type=str, 88 | nargs="?", 89 | help="path to the input audio clip" 90 | ) 91 |
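# --- Editor's usage sketch (not part of the original script) -----------------
# A minimal sketch of how this script might be invoked; every path below is a
# placeholder/assumption to be replaced with your own checkpoint, config,
# vocoder directory and input clip:
#
#   python scripts/audio2audio.py \
#       --prompt "a bird chirping" \
#       --init-audio /path/to/input.wav \
#       --config /path/to/txt2audio_config.yaml \
#       --ckpt /path/to/model.ckpt \
#       --vocoder_ckpt /path/to/bigvgan_dir \
#       --strength 0.3 --ddim_steps 100 --scale 3.0
# ------------------------------------------------------------------------------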
92 | parser.add_argument( 93 | "--outdir", 94 | type=str, 95 | nargs="?", 96 | help="dir to write results to", 97 | default="outputs/audio2audio-samples" 98 | ) 99 | 100 | 101 | parser.add_argument( 102 | "--ddim_steps", 103 | type=int, 104 | default=100, 105 | help="number of ddim sampling steps", 106 | ) 107 | 108 | parser.add_argument( 109 | "--ddim_eta", 110 | type=float, 111 | default=0.0, 112 | help="ddim eta (eta=0.0 corresponds to deterministic sampling)", 113 | ) 114 | parser.add_argument( 115 | "--n_iter", 116 | type=int, 117 | default=1, 118 | help="sample this often", 119 | ) 120 | parser.add_argument( 121 | "--n_samples", 122 | type=int, 123 | default=2, 124 | help="how many samples to produce for each given prompt. A.k.a batch size", 125 | ) 126 | 127 | parser.add_argument( 128 | "--scale", 129 | type=float, 130 | default=3.0, 131 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 132 | ) 133 | 134 | parser.add_argument( 135 | "--strength", 136 | type=float, 137 | default=0.3, 138 | help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in the init audio", 139 | ) 140 | parser.add_argument( 141 | "--from-file", 142 | type=str, 143 | help="if specified, load prompts from this file", 144 | ) 145 | parser.add_argument( 146 | "--config", 147 | type=str, 148 | default="configs/stable-diffusion/v1-inference.yaml", 149 | help="path to config which constructs model", 150 | ) 151 | parser.add_argument( 152 | "--ckpt", 153 | type=str, 154 | default="models/ldm/stable-diffusion-v1/model.ckpt", 155 | help="path to checkpoint of model", 156 | ) 157 | parser.add_argument( 158 | "--seed", 159 | type=int, 160 | default=42, 161 | help="the seed (for reproducible sampling)", 162 | ) 163 | parser.add_argument( 164 | "-v", 165 | "--vocoder_ckpt", 166 | type=str, 167 | help="resume from vocoder checkpoint", 168 | default="", 169 | ) 170 | return parser.parse_args() 171 | 172 | def main(): 173 | opt = parse_args() 174 | seed_everything(opt.seed) 175 | 176 | config = OmegaConf.load(f"{opt.config}") 177 | model = load_model_from_config(config, f"{opt.ckpt}") 178 | 179 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 180 | model = model.to(device) 181 | 182 | hparams = { 183 | 'audio_sample_rate': SAMPLE_RATE, 184 | 'audio_num_mel_bins':80, 185 | 'fft_size': 1024, 186 | 'win_size': 1024, 187 | 'hop_size': 256, 188 | 'fmin': 0, 189 | 'fmax': 8000, 190 | 'batch_max_length': 1248, 191 | 'mode': 'pad', # pad,none, 192 | } 193 | melnet = MelNet(hparams) 194 | sampler = DDIMSampler(model) 195 | vocoder = VocoderBigVGAN(opt.vocoder_ckpt,device) 196 | 197 | os.makedirs(opt.outdir, exist_ok=True) 198 | outpath = opt.outdir 199 | 200 | batch_size = opt.n_samples # each prompt produces n_samples results 201 | if not opt.from_file: # use --prompt directly; otherwise prompts are read from the file below 202 | prompt = opt.prompt 203 | assert prompt is not None 204 | data = [batch_size * [prompt]] 205 | else: 206 | print(f"reading prompts from {opt.from_file}") 207 | with open(opt.from_file, "r") as f: 208 | data = f.read().splitlines() 209 | data = list(chunk(data, batch_size)) 210 | 211 | sample_path = os.path.join(outpath, "samples") 212 | os.makedirs(sample_path, exist_ok=True) 213 | base_count = len(os.listdir(sample_path)) 214 | 215 | assert os.path.isfile(opt.init_audio) 216 | init_image = load_audio(opt.init_audio,transform=melnet).to(device) 217 | init_image = repeat(init_image, '1 ...
-> b ...', b=batch_size) 218 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space 219 | sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) 220 | 221 | assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' 222 | t_enc = int(opt.strength * opt.ddim_steps) 223 | print(f"target t_enc is {t_enc} steps") 224 | 225 | with torch.no_grad(): 226 | with model.ema_scope(): 227 | tic = time.time() 228 | all_samples = list() 229 | for n in trange(opt.n_iter, desc="Sampling"): 230 | for prompts in tqdm(data, desc="data"): 231 | uc = None 232 | if opt.scale != 1.0: # default=5.0 233 | uc = model.get_learned_conditioning(batch_size * [""]) 234 | if isinstance(prompts, tuple): 235 | prompts = list(prompts) 236 | c = model.get_learned_conditioning(prompts) 237 | z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) # [B, channel, c, h] 238 | # decode it 239 | samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, 240 | unconditional_conditioning=uc,) 241 | 242 | x_samples = model.decode_first_stage(samples) 243 | print(x_samples.shape) 244 | for x_sample in x_samples: 245 | spec = x_sample[0].cpu().numpy() 246 | spec_ori = init_image[0][0].cpu().numpy() 247 | print(x_sample.shape,spec.shape,init_image.shape) 248 | wav = vocoder.vocode(spec) 249 | wav_ori = vocoder.vocode(spec_ori) 250 | soundfile.write(os.path.join(outpath, f'{prompt.replace(" ", "-")}.wav'), wav, SAMPLE_RATE, 'FLOAT') 251 | soundfile.write(os.path.join(outpath, f'{prompt.replace(" ", "-")}_ori.wav'), wav_ori, SAMPLE_RATE, 'FLOAT') 252 | base_count += 1 253 | all_samples.append(x_samples) 254 | 255 | 256 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n" 257 | f" \nEnjoy.") 258 | 259 | 260 | if __name__ == "__main__": 261 | main() 262 | -------------------------------------------------------------------------------- /preprocess/mel_spec.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | sys.path.append(os.getcwd()) 3 | from preprocess.NAT_mel import MelNet 4 | from tqdm import tqdm 5 | from glob import glob 6 | import math 7 | import pandas as pd 8 | import argparse 9 | from argparse import Namespace 10 | import math 11 | import audioread 12 | from tqdm.contrib.concurrent import process_map 13 | import torch 14 | import torchaudio 15 | import numpy as np 16 | from torch.distributed import init_process_group 17 | from torch.utils.data import Dataset,DataLoader,DistributedSampler 18 | import torch.multiprocessing as mp 19 | import json 20 | 21 | 22 | class tsv_dataset(Dataset): 23 | def __init__(self,tsv_path,sr,mode='none',hop_size = None,target_mel_length = None) -> None: 24 | super().__init__() 25 | if os.path.isdir(tsv_path): 26 | files = glob(os.path.join(tsv_path,'*.tsv')) 27 | df = pd.concat([pd.read_csv(file,sep='\t') for file in files]) 28 | else: 29 | df = pd.read_csv(tsv_path,sep='\t') 30 | self.audio_paths = [] 31 | self.sr = sr 32 | self.mode = mode 33 | self.target_mel_length = target_mel_length 34 | self.hop_size = hop_size 35 | for t in tqdm(df.itertuples()): 36 | self.audio_paths.append(getattr(t,'audio_path')) 37 | 38 | def __len__(self): 39 | return len(self.audio_paths) 40 | 41 | def pad_wav(self,wav): 42 | # wav should be in shape(1,wav_len) 43 | wav_length = wav.shape[-1] 44 | assert wav_length > 100, "wav is too short, %s" % wav_length 45 | segment_length = 
(self.target_mel_length + 1) * self.hop_size # final mel will crop the last mel, mel = mel[:,:-1] 46 | if segment_length is None or wav_length == segment_length: 47 | return wav 48 | elif wav_length > segment_length: 49 | return wav[:,:segment_length] 50 | elif wav_length < segment_length: 51 | temp_wav = torch.zeros((1, segment_length),dtype=torch.float32) 52 | temp_wav[:, :wav_length] = wav 53 | return temp_wav 54 | 55 | def __getitem__(self, index): 56 | audio_path = self.audio_paths[index] 57 | wav, orisr = torchaudio.load(audio_path) 58 | if wav.shape[0] != 1: # stereo to mono (2,wav_len) -> (1,wav_len) 59 | wav = wav.mean(0,keepdim=True) 60 | wav = torchaudio.functional.resample(wav, orig_freq=orisr, new_freq=self.sr) 61 | if self.mode == 'pad': 62 | assert self.target_mel_length is not None 63 | wav = self.pad_wav(wav) 64 | return audio_path,wav 65 | 66 | def process_audio_by_tsv(rank,args): 67 | if args.num_gpus > 1: 68 | init_process_group(backend=args.dist_config['dist_backend'], init_method=args.dist_config['dist_url'], 69 | world_size=args.dist_config['world_size'] * args.num_gpus, rank=rank) 70 | 71 | sr = args.audio_sample_rate 72 | dataset = tsv_dataset(args.tsv_path,sr = sr,mode=args.mode,hop_size=args.hop_size,target_mel_length=args.batch_max_length) 73 | sampler = DistributedSampler(dataset,shuffle=False) if args.num_gpus > 1 else None 74 | # batch_size must == 1,since wav_len is not equal 75 | loader = DataLoader(dataset, sampler=sampler,batch_size=1, num_workers=16,drop_last=False) 76 | 77 | device = torch.device('cuda:{:d}'.format(rank)) 78 | mel_net = MelNet(args.__dict__) 79 | mel_net.to(device) 80 | # if args.num_gpus > 1: # RuntimeError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. 
81 | # mel_net = DistributedDataParallel(mel_net, device_ids=[rank]).to(device) 82 | root = args.save_path 83 | loader = tqdm(loader) if rank == 0 else loader 84 | for batch in loader: 85 | audio_paths,wavs = batch 86 | wavs = wavs.to(device) 87 | if args.save_resample: 88 | for audio_path,wav in zip(audio_paths,wavs): 89 | psplits = audio_path.split('/') 90 | wav_name = psplits[-1] 91 | # save resample 92 | resample_root,resample_name = root+f'_{sr}',wav_name[:-4]+'_audio.npy' 93 | resample_dir_name = os.path.join(resample_root,*psplits[1:-1]) 94 | resample_path = os.path.join(resample_dir_name,resample_name) 95 | os.makedirs(resample_dir_name,exist_ok=True) 96 | np.save(resample_path,wav.cpu().numpy().squeeze(0)) 97 | 98 | if args.save_mel: 99 | mode = args.mode 100 | batch_max_length = args.batch_max_length 101 | 102 | for audio_path,wav in zip(audio_paths,wavs): 103 | psplits = audio_path.split('/') 104 | wav_name = psplits[-1] 105 | mel_root,mel_name = root,wav_name[:-4]+'_mel.npy' 106 | mel_dir_name = os.path.join(mel_root,f'mel{mode}{sr}',*psplits[1:-1]) 107 | mel_path = os.path.join(mel_dir_name,mel_name) 108 | if not os.path.exists(mel_path): 109 | mel_spec = mel_net(wav).cpu().numpy().squeeze(0) # (mel_bins,mel_len) 110 | if mel_spec.shape[1] <= batch_max_length: 111 | if mode == 'tile': # pad is done in dataset as pad wav 112 | n_repeat = math.ceil((batch_max_length + 1) / mel_spec.shape[1]) 113 | mel_spec = np.tile(mel_spec,reps=(1,n_repeat)) 114 | elif mode == 'none' or mode == 'pad': 115 | pass 116 | else: 117 | raise ValueError(f'mode:{mode} is not supported') 118 | mel_spec = mel_spec[:,:batch_max_length] 119 | os.makedirs(mel_dir_name,exist_ok=True) 120 | np.save(mel_path,mel_spec) 121 | 122 | 123 | def split_list(i_list,num): 124 | each_num = math.ceil(len(i_list) / num) 125 | result = [] 126 | for i in range(num): 127 | s = each_num * i 128 | e = (each_num * (i+1)) 129 | result.append(i_list[s:e]) 130 | return result 131 | 132 | 133 | def drop_bad_wav(item): 134 | index,path = item 135 | try: 136 | with audioread.audio_open(path) as f: 137 | totalsec = f.duration 138 | if totalsec < 0.1: 139 | return index # index 140 | except: 141 | print(f"corrupted wav:{path}") 142 | return index 143 | return False 144 | 145 | def drop_bad_wavs(tsv_path):# 'audioset.csv' 146 | df = pd.read_csv(tsv_path,sep='\t') 147 | item_list = [] 148 | for item in tqdm(df.itertuples()): 149 | item_list.append((item[0],getattr(item,'audio_path'))) 150 | 151 | r = process_map(drop_bad_wav,item_list,max_workers=16,chunksize=16) 152 | bad_indices = list(filter(lambda x: x is not False, r)) # identity check so a bad wav at index 0 is also dropped 153 | 154 | print(bad_indices) 155 | with open('bad_wavs.json','w') as f: 156 | x = [item_list[i] for i in bad_indices] 157 | json.dump(x,f) 158 | df = df.drop(bad_indices,axis=0) 159 | df.to_csv(tsv_path,sep='\t',index=False) 160 | 161 | def addmel2tsv(save_dir,tsv_path): 162 | df = pd.read_csv(tsv_path,sep='\t') 163 | mels = glob(f'{save_dir}/mel{args.mode}{args.audio_sample_rate}/**/*_mel.npy',recursive=True) 164 | name2mel,idx2name,idx2mel = {},{},{} 165 | for mel in mels: 166 | bn = os.path.basename(mel)[:-8]# remove _mel.npy 167 | name2mel[bn] = mel 168 | for t in df.itertuples(): 169 | idx = int(t[0]) 170 | bn = os.path.basename(getattr(t,'audio_path'))[:-4] 171 | idx2name[idx] = bn 172 | for k,v in idx2name.items(): 173 | idx2mel[k] = name2mel[v] 174 | df['mel_path'] = df.index.map(idx2mel) 175 | df.to_csv(tsv_path,sep='\t',index=False) 176 | 177 | def parse_args(): 178 | parser = argparse.ArgumentParser() 179 |
parser.add_argument( "--tsv_path",type=str) 180 | parser.add_argument( "--num_gpus",type=int,default=1) 181 | parser.add_argument( "--max_duration",type=int,default=30) 182 | return parser.parse_args() 183 | 184 | if __name__ == '__main__': 185 | pargs = parse_args() 186 | tsv_path = pargs.tsv_path 187 | if os.path.isdir(tsv_path): 188 | files = glob(os.path.join(tsv_path,'*.tsv')) 189 | for file in files: 190 | drop_bad_wavs(file) 191 | else: 192 | drop_bad_wavs(tsv_path) 193 | num_gpus = pargs.num_gpus 194 | batch_max_length = int(pargs.max_duration * 62.5)# 62.5 is the mel length for 1 second 195 | save_path = 'processed' 196 | args = { 197 | 'audio_sample_rate': 16000, 198 | 'audio_num_mel_bins':80, 199 | 'fft_size': 1024, 200 | 'win_size': 1024, 201 | 'hop_size': 256, 202 | 'fmin': 0, 203 | 'fmax': 8000, 204 | 'batch_max_length': batch_max_length, 205 | 'tsv_path': tsv_path, 206 | 'num_gpus': num_gpus, 207 | 'mode': 'pad', # pad,none, 208 | 'save_resample':False, 209 | 'save_mel' :True, 210 | 'save_path': save_path, 211 | } 212 | os.makedirs(save_path,exist_ok=True) 213 | args = Namespace(**args) 214 | args.dist_config = { 215 | "dist_backend": "nccl", 216 | "dist_url": "tcp://localhost:54189", 217 | "world_size": 1 218 | } 219 | if args.num_gpus>1: 220 | mp.spawn(process_audio_by_tsv,nprocs=args.num_gpus,args=(args,)) 221 | else: 222 | process_audio_by_tsv(0,args=args) 223 | print("proceoss mel done") 224 | addmel2tsv(save_path,tsv_path) 225 | print("done") 226 | 227 | -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):# if context_dim is given, this becomes cross-attention rather than self-attention 154 | super().__init__() 155 | inner_dim = dim_head * heads # inner_dim == SpatialTransformer.model_channels 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None):# x:(b,h*w,c), context:(b,seq_len,context_dim) 171 | h = self.heads 172 | 173 | q = self.to_q(x)# q:(b,h*w,inner_dim) 174 | context = default(context, x) 175 | k = self.to_k(context)# (b,seq_len,inner_dim) 176 | v = self.to_v(context)# (b,seq_len,inner_dim) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))# n is seq_len for k and v 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale #
(b*head,h*w,seq_len) 181 | 182 | if exists(mask):# false 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v)# (b*head,h*w,inner_dim/head) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)# (b,h*w,inner_dim) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 
224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape # such as [2,320,10,106] 253 | x_in = x 254 | x = self.norm(x)# group norm 255 | x = self.proj_in(x)# no shape change 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context)# context shape [b,seq_len=77,context_dim] 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
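# --- Editor's usage sketch (not part of the original file) -------------------
# A minimal, self-contained example of the schedule/embedding helpers defined
# below, assuming this module is importable as ldm.modules.diffusionmodules.util:
#
#   import torch
#   from ldm.modules.diffusionmodules.util import make_beta_schedule, timestep_embedding
#
#   betas = make_beta_schedule("linear", n_timestep=1000)  # numpy array of shape (1000,)
#   emb = timestep_embedding(torch.arange(4), dim=320)     # tensor of shape (4, 320)
# ------------------------------------------------------------------------------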
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 
186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class 
NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 
is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, 
batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? 
--> test 6 | from torch.utils.checkpoint import checkpoint 7 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoTokenizer 8 | from importlib_resources import files 9 | from ldm.modules.encoders.CLAP.utils import read_config_as_args 10 | from ldm.modules.encoders.CLAP.clap import TextEncoder 11 | from ldm.util import default, count_params 12 | 13 | 14 | class AbstractEncoder(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def encode(self, *args, **kwargs): 19 | raise NotImplementedError 20 | 21 | 22 | class ClassEmbedder(nn.Module): 23 | def __init__(self, embed_dim, n_classes=1000, key='class'): 24 | super().__init__() 25 | self.key = key 26 | self.embedding = nn.Embedding(n_classes, embed_dim) 27 | 28 | def forward(self, batch, key=None): 29 | if key is None: 30 | key = self.key 31 | # this is for use in crossattn 32 | c = batch[key][:, None]# (bsz,1) 33 | c = self.embedding(c) 34 | return c 35 | 36 | 37 | class TransformerEmbedder(AbstractEncoder): 38 | """Some transformer encoder layers""" 39 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 40 | super().__init__() 41 | self.device = device 42 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 43 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 44 | 45 | def forward(self, tokens): 46 | tokens = tokens.to(self.device) # meh 47 | z = self.transformer(tokens, return_embeddings=True) 48 | return z 49 | 50 | def encode(self, x): 51 | return self(x) 52 | 53 | 54 | class BERTTokenizer(AbstractEncoder): 55 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" 56 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 57 | super().__init__() 58 | from transformers import BertTokenizerFast # TODO: add to requirements 59 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 60 | self.device = device 61 | self.vq_interface = vq_interface 62 | self.max_length = max_length 63 | 64 | def forward(self, text): 65 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 66 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 67 | tokens = batch_encoding["input_ids"].to(self.device) 68 | return tokens 69 | 70 | @torch.no_grad() 71 | def encode(self, text): 72 | tokens = self(text) 73 | if not self.vq_interface: 74 | return tokens 75 | return None, None, [None, None, tokens] 76 | 77 | def decode(self, text): 78 | return text 79 | 80 | 81 | class BERTEmbedder(AbstractEncoder):# not a pretrained BERT model: it combines the transformers BertTokenizer with a custom TransformerWrapper 82 | """Uses the BERT tokenizer and adds some transformer encoder layers""" 83 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 84 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 85 | super().__init__() 86 | self.use_tknz_fn = use_tokenizer 87 | if self.use_tknz_fn: 88 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 89 | self.device = device 90 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 91 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 92 | emb_dropout=embedding_dropout) 93 | 94 | def forward(self, text): 95 | if self.use_tknz_fn: 96 | tokens = self.tknz_fn(text)#.to(self.device) 97 | else: 98 | tokens = text 99 | z = self.transformer(tokens, return_embeddings=True) 100 | return z 101 | 102 | def encode(self, text): 103 | # output
of length 77 104 | return self(text) 105 | 106 | 107 | class SpatialRescaler(nn.Module): 108 | def __init__(self, 109 | n_stages=1, 110 | method='bilinear', 111 | multiplier=0.5, 112 | in_channels=3, 113 | out_channels=None, 114 | bias=False): 115 | super().__init__() 116 | self.n_stages = n_stages 117 | assert self.n_stages >= 0 118 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 119 | self.multiplier = multiplier 120 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 121 | self.remap_output = out_channels is not None 122 | if self.remap_output: 123 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 124 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 125 | 126 | def forward(self,x): 127 | for stage in range(self.n_stages): 128 | x = self.interpolator(x, scale_factor=self.multiplier) 129 | 130 | 131 | if self.remap_output: 132 | x = self.channel_mapper(x) 133 | return x 134 | 135 | def encode(self, x): 136 | return self(x) 137 | 138 | def disabled_train(self, mode=True): 139 | """Overwrite model.train with this function to make sure train/eval mode 140 | does not change anymore.""" 141 | return self 142 | 143 | class FrozenT5Embedder(AbstractEncoder): 144 | """Uses the T5 transformer encoder for text""" 145 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 146 | super().__init__() 147 | self.tokenizer = T5Tokenizer.from_pretrained(version) 148 | self.transformer = T5EncoderModel.from_pretrained(version) 149 | self.device = device 150 | self.max_length = max_length # TODO: typical value? 151 | if freeze: 152 | self.freeze() 153 | 154 | def freeze(self): 155 | self.transformer = self.transformer.eval() 156 | #self.train = disabled_train 157 | for param in self.parameters(): 158 | param.requires_grad = False 159 | 160 | def forward(self, text): 161 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 162 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 163 | tokens = batch_encoding["input_ids"].to(self.device) 164 | outputs = self.transformer(input_ids=tokens) 165 | 166 | z = outputs.last_hidden_state 167 | return z 168 | 169 | def encode(self, text): 170 | return self(text) 171 | 172 | 173 | class FrozenCLAPEmbedder(AbstractEncoder): 174 | """Uses the CLAP transformer encoder for text (from huggingface)""" 175 | def __init__(self, weights_path=None, freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32 176 | super().__init__() 177 | if weights_path: 178 | model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model'] 179 | match_params = dict() 180 | for key in list(model_state_dict.keys()): 181 | if 'caption_encoder' in key: 182 | match_params[key.replace('caption_encoder.', '')] = model_state_dict[key] 183 | 184 | config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text() 185 | args = read_config_as_args(config_as_str, is_config_str=True) 186 | 187 | # To device 188 | self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model 189 | self.caption_encoder = TextEncoder( 190 | args.d_proj, args.text_model, args.transformer_embed_dim 191 | ) 192 | 193 | self.max_length = max_length 194 | self.device = device 195 | if freeze: self.freeze() 196 | 197 | print(f"{self.caption_encoder.__class__.__name__} 
comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.") 198 | 199 | def freeze(self): 200 | self.caption_encoder.base = self.caption_encoder.base.eval() 201 | for param in self.caption_encoder.base.parameters(): 202 | param.requires_grad = False 203 | 204 | 205 | def encode(self, text): 206 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 207 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 208 | tokens = batch_encoding["input_ids"].to(self.device) 209 | 210 | outputs = self.caption_encoder.base(input_ids=tokens) 211 | z = self.caption_encoder.projection(outputs.last_hidden_state) 212 | return z 213 | 214 | class FrozenCLAPEmbedderNoLoad(AbstractEncoder): 215 | def __init__(self, config, freeze=True, device="cpu", max_length=77): 216 | super().__init__() 217 | args = config 218 | 219 | # To device 220 | self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model 221 | self.caption_encoder = TextEncoder( 222 | args.d_proj, args.text_model, args.transformer_embed_dim 223 | ) 224 | 225 | self.max_length = max_length 226 | self.device = device 227 | if freeze: self.freeze() 228 | 229 | print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.") 230 | 231 | def freeze(self): 232 | self.caption_encoder.base = self.caption_encoder.base.eval() 233 | for param in self.caption_encoder.base.parameters(): 234 | param.requires_grad = False 235 | 236 | 237 | def encode(self, text): 238 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 239 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 240 | tokens = batch_encoding["input_ids"].to(self.device) 241 | 242 | outputs = self.caption_encoder.base(input_ids=tokens) 243 | z = self.caption_encoder.projection(outputs.last_hidden_state) 244 | return z 245 | 246 | 247 | class FrozenFLANEmbedder(AbstractEncoder): 248 | """Uses the T5 transformer encoder for text""" 249 | def __init__(self, version="google/flan-t5-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 250 | super().__init__() 251 | self.tokenizer = T5Tokenizer.from_pretrained(version) 252 | self.transformer = T5EncoderModel.from_pretrained(version) 253 | self.device = device 254 | self.max_length = max_length # TODO: typical value? 
255 | if freeze: 256 | self.freeze() 257 | 258 | def freeze(self): 259 | self.transformer = self.transformer.eval() 260 | #self.train = disabled_train 261 | for param in self.parameters(): 262 | param.requires_grad = False 263 | 264 | def forward(self, text): 265 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 266 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 267 | tokens = batch_encoding["input_ids"].to(self.device) 268 | outputs = self.transformer(input_ids=tokens) 269 | 270 | z = outputs.last_hidden_state 271 | return z 272 | 273 | def encode(self, text): 274 | return self(text) -------------------------------------------------------------------------------- /ldm/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | import torch  # used at module level by ActNorm (torch.zeros/ones/no_grad/log) 4 | 5 | class ActNorm(nn.Module): 6 | def __init__(self, num_features, logdet=False, affine=True, 7 | allow_reverse_init=False): 8 | assert affine 9 | super().__init__() 10 | self.logdet = logdet 11 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 12 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 13 | self.allow_reverse_init = allow_reverse_init 14 | 15 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 16 | 17 | def initialize(self, input): 18 | with torch.no_grad(): 19 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 20 | mean = ( 21 | flatten.mean(1) 22 | .unsqueeze(1) 23 | .unsqueeze(2) 24 | .unsqueeze(3) 25 | .permute(1, 0, 2, 3) 26 | ) 27 | std = ( 28 | flatten.std(1) 29 | .unsqueeze(1) 30 | .unsqueeze(2) 31 | .unsqueeze(3) 32 | .permute(1, 0, 2, 3) 33 | ) 34 | 35 | self.loc.data.copy_(-mean) 36 | self.scale.data.copy_(1 / (std + 1e-6)) 37 | 38 | def forward(self, input, reverse=False): 39 | if reverse: 40 | return self.reverse(input) 41 | if len(input.shape) == 2: 42 | input = input[:, :, None, None] 43 | squeeze = True 44 | else: 45 | squeeze = False 46 | 47 | _, _, height, width = input.shape 48 | 49 | if self.training and self.initialized.item() == 0: 50 | self.initialize(input) 51 | self.initialized.fill_(1) 52 | 53 | h = self.scale * (input + self.loc) 54 | 55 | if squeeze: 56 | h = h.squeeze(-1).squeeze(-1) 57 | 58 | if self.logdet: 59 | log_abs = torch.log(torch.abs(self.scale)) 60 | logdet = height * width * torch.sum(log_abs) 61 | logdet = logdet * torch.ones(input.shape[0]).to(input) 62 | return h, logdet 63 | 64 | return h 65 | 66 | def reverse(self, output): 67 | if self.training and self.initialized.item() == 0: 68 | if not self.allow_reverse_init: 69 | raise RuntimeError( 70 | "Initializing ActNorm in reverse direction is " 71 | "disabled by default. Use allow_reverse_init=True to enable."
72 | ) 73 | else: 74 | self.initialize(output) 75 | self.initialized.fill_(1) 76 | 77 | if len(output.shape) == 2: 78 | output = output[:, :, None, None] 79 | squeeze = True 80 | else: 81 | squeeze = False 82 | 83 | h = output / self.scale - self.loc 84 | 85 | if squeeze: 86 | h = h.squeeze(-1).squeeze(-1) 87 | return h 88 | 89 | def weights_init(m): 90 | classname = m.__class__.__name__ 91 | if classname.find('Conv') != -1: 92 | nn.init.normal_(m.weight.data, 0.0, 0.02) 93 | elif classname.find('BatchNorm') != -1: 94 | nn.init.normal_(m.weight.data, 1.0, 0.02) 95 | nn.init.constant_(m.bias.data, 0) 96 | 97 | 98 | class NLayerDiscriminator(nn.Module): 99 | """Defines a PatchGAN discriminator as in Pix2Pix 100 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 101 | """ 102 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 103 | """Construct a PatchGAN discriminator 104 | Parameters: 105 | input_nc (int) -- the number of channels in input images 106 | ndf (int) -- the number of filters in the last conv layer 107 | n_layers (int) -- the number of conv layers in the discriminator 108 | norm_layer -- normalization layer 109 | """ 110 | super(NLayerDiscriminator, self).__init__() 111 | if not use_actnorm: 112 | norm_layer = nn.BatchNorm2d 113 | else: 114 | norm_layer = ActNorm 115 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 116 | use_bias = norm_layer.func != nn.BatchNorm2d 117 | else: 118 | use_bias = norm_layer != nn.BatchNorm2d 119 | 120 | kw = 4 121 | padw = 1 122 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 123 | nf_mult = 1 124 | nf_mult_prev = 1 125 | for n in range(1, n_layers): # gradually increase the number of filters 126 | nf_mult_prev = nf_mult 127 | nf_mult = min(2 ** n, 8) 128 | sequence += [ 129 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 130 | norm_layer(ndf * nf_mult), 131 | nn.LeakyReLU(0.2, True) 132 | ] 133 | 134 | nf_mult_prev = nf_mult 135 | nf_mult = min(2 ** n_layers, 8) 136 | sequence += [ 137 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 138 | norm_layer(ndf * nf_mult), 139 | nn.LeakyReLU(0.2, True) 140 | ] 141 | # output 1 channel prediction map 142 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 143 | self.main = nn.Sequential(*sequence) 144 | 145 | def forward(self, input): 146 | """Standard forward.""" 147 | return self.main(input) 148 | 149 | class NLayerDiscriminator1dFeats(NLayerDiscriminator): 150 | """Defines a PatchGAN discriminator as in Pix2Pix 151 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 152 | """ 153 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 154 | """Construct a PatchGAN discriminator 155 | Parameters: 156 | input_nc (int) -- the number of channels in input feats 157 | ndf (int) -- the number of filters in the last conv layer 158 | n_layers (int) -- the number of conv layers in the discriminator 159 | norm_layer -- normalization layer 160 | """ 161 | super().__init__(input_nc=input_nc, ndf=64, n_layers=n_layers, use_actnorm=use_actnorm) 162 | 163 | if not use_actnorm: 164 | norm_layer = nn.BatchNorm1d 165 | else: 166 | norm_layer = ActNorm 167 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm has affine 
parameters 168 | use_bias = norm_layer.func != nn.BatchNorm1d 169 | else: 170 | use_bias = norm_layer != nn.BatchNorm1d 171 | 172 | kw = 4 173 | padw = 1 174 | sequence = [nn.Conv1d(input_nc, input_nc//2, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 175 | nf_mult = input_nc//2 176 | nf_mult_prev = 1 177 | for n in range(1, n_layers): # gradually decrease the number of filters 178 | nf_mult_prev = nf_mult 179 | nf_mult = max(nf_mult_prev // (2 ** n), 8) 180 | sequence += [ 181 | nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 182 | norm_layer(nf_mult), 183 | nn.LeakyReLU(0.2, True) 184 | ] 185 | 186 | nf_mult_prev = nf_mult 187 | nf_mult = max(nf_mult_prev // (2 ** n), 8) 188 | sequence += [ 189 | nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 190 | norm_layer(nf_mult), 191 | nn.LeakyReLU(0.2, True) 192 | ] 193 | nf_mult_prev = nf_mult 194 | nf_mult = max(nf_mult_prev // (2 ** n), 8) 195 | sequence += [ 196 | nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 197 | norm_layer(nf_mult), 198 | nn.LeakyReLU(0.2, True) 199 | ] 200 | # output 1 channel prediction map 201 | sequence += [nn.Conv1d(nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 202 | self.main = nn.Sequential(*sequence) 203 | 204 | 205 | class NLayerDiscriminator1dSpecs(NLayerDiscriminator): 206 | """Defines a PatchGAN discriminator as in Pix2Pix 207 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 208 | """ 209 | def __init__(self, input_nc=80, ndf=64, n_layers=3, use_actnorm=False): 210 | """Construct a PatchGAN discriminator 211 | Parameters: 212 | input_nc (int) -- the number of channels in input specs 213 | ndf (int) -- the number of filters in the last conv layer 214 | n_layers (int) -- the number of conv layers in the discriminator 215 | norm_layer -- normalization layer 216 | """ 217 | super().__init__(input_nc=input_nc, ndf=64, n_layers=n_layers, use_actnorm=use_actnorm) 218 | 219 | if not use_actnorm: 220 | norm_layer = nn.BatchNorm1d 221 | else: 222 | norm_layer = ActNorm 223 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm has affine parameters 224 | use_bias = norm_layer.func != nn.BatchNorm1d 225 | else: 226 | use_bias = norm_layer != nn.BatchNorm1d 227 | 228 | kw = 4 229 | padw = 1 230 | sequence = [nn.Conv1d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 231 | nf_mult = 1 232 | nf_mult_prev = 1 233 | for n in range(1, n_layers): # gradually decrease the number of filters 234 | nf_mult_prev = nf_mult 235 | nf_mult = min(2 ** n, 8) 236 | sequence += [ 237 | nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 238 | norm_layer(ndf * nf_mult), 239 | nn.LeakyReLU(0.2, True) 240 | ] 241 | 242 | nf_mult_prev = nf_mult 243 | nf_mult = min(2 ** n_layers, 8) 244 | sequence += [ 245 | nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 246 | norm_layer(ndf * nf_mult), 247 | nn.LeakyReLU(0.2, True) 248 | ] 249 | # output 1 channel prediction map 250 | sequence += [nn.Conv1d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 251 | self.main = nn.Sequential(*sequence) 252 | 253 | def forward(self, input): 254 | """Standard forward.""" 255 | # (B, C, L) 256 | input = input.squeeze(1) 257 | input = self.main(input) 258 | return input 259 | 260 | 261 | if __name__ == 
'__main__': 262 | import torch 263 | 264 | ## FEATURES 265 | disc_in_channels = 2048 266 | disc_num_layers = 2 267 | use_actnorm = False 268 | disc_ndf = 64 269 | discriminator = NLayerDiscriminator1dFeats(input_nc=disc_in_channels, n_layers=disc_num_layers, 270 | use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init) 271 | inputs = torch.rand((6, 2048, 212)) 272 | outputs = discriminator(inputs) 273 | print(outputs.shape) 274 | 275 | ## AUDIO 276 | disc_in_channels = 1 277 | disc_num_layers = 3 278 | use_actnorm = False 279 | disc_ndf = 64 280 | discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers, 281 | use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init) 282 | inputs = torch.rand((6, 1, 80, 848)) 283 | outputs = discriminator(inputs) 284 | print(outputs.shape) 285 | 286 | ## IMAGE 287 | disc_in_channels = 3 288 | disc_num_layers = 3 289 | use_actnorm = False 290 | disc_ndf = 64 291 | discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers, 292 | use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init) 293 | inputs = torch.rand((6, 3, 256, 256)) 294 | outputs = discriminator(inputs) 295 | print(outputs.shape) 296 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/CLAPWrapper.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchaudio 3 | string_classes = (str, bytes)  # was: from torch._six import string_classes (removed in recent PyTorch); only used for isinstance checks below 4 | import collections 5 | import re 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from transformers import AutoTokenizer 9 | from ldm.modules.encoders.CLAP.utils import read_config_as_args 10 | from ldm.modules.encoders.CLAP.clap import CLAP 11 | import math 12 | import torchaudio.transforms as T 13 | import os 14 | import torch 15 | from importlib_resources import files 16 | 17 | 18 | class CLAPWrapper(): 19 | """ 20 | A class for interfacing with the CLAP model.
21 | """ 22 | 23 | def __init__(self, model_fp, device): 24 | self.np_str_obj_array_pattern = re.compile(r'[SaUO]') 25 | self.file_path = os.path.realpath(__file__) 26 | self.default_collate_err_msg_format = ( 27 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 28 | "dicts or lists; found {}") 29 | self.config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text() 30 | self.model_fp = model_fp 31 | self.device = device 32 | self.clap, self.tokenizer, self.args = self.load_clap() 33 | 34 | def load_clap(self): 35 | r"""Load CLAP model with args from config file""" 36 | 37 | args = read_config_as_args(self.config_as_str, is_config_str=True) 38 | 39 | if 'bert' in args.text_model: 40 | self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask'] 41 | else: 42 | self.token_keys = ['input_ids', 'attention_mask'] 43 | 44 | clap = CLAP( 45 | audioenc_name=args.audioenc_name, 46 | sample_rate=args.sampling_rate, 47 | window_size=args.window_size, 48 | hop_size=args.hop_size, 49 | mel_bins=args.mel_bins, 50 | fmin=args.fmin, 51 | fmax=args.fmax, 52 | classes_num=args.num_classes, 53 | out_emb=args.out_emb, 54 | text_model=args.text_model, 55 | transformer_embed_dim=args.transformer_embed_dim, 56 | d_proj=args.d_proj 57 | ) 58 | 59 | # Load pretrained weights for model 60 | model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model'] 61 | clap.load_state_dict(model_state_dict) 62 | 63 | clap.eval() # set clap in eval mode 64 | tokenizer = AutoTokenizer.from_pretrained(args.text_model) 65 | 66 | clap = clap.to(self.device) 67 | # NOTE: Hugging Face tokenizers are not torch modules and have no .to(); token tensors are moved to the device in preprocess_text() 68 | 69 | return clap, tokenizer, args 70 | 71 | def default_collate(self, batch): 72 | r"""Puts each data field into a tensor with outer dimension batch size""" 73 | elem = batch[0] 74 | elem_type = type(elem) 75 | if isinstance(elem, torch.Tensor): 76 | out = None 77 | if torch.utils.data.get_worker_info() is not None: 78 | # If we're in a background process, concatenate directly into a 79 | # shared memory tensor to avoid an extra copy 80 | numel = sum([x.numel() for x in batch]) 81 | storage = elem.storage()._new_shared(numel) 82 | out = elem.new(storage) 83 | return torch.stack(batch, 0, out=out) 84 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 85 | and elem_type.__name__ != 'string_': 86 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 87 | # array of string classes and object 88 | if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None: 89 | raise TypeError( 90 | self.default_collate_err_msg_format.format(elem.dtype)) 91 | 92 | return self.default_collate([torch.as_tensor(b) for b in batch]) 93 | elif elem.shape == (): # scalars 94 | return torch.as_tensor(batch) 95 | elif isinstance(elem, float): 96 | return torch.tensor(batch, dtype=torch.float64) 97 | elif isinstance(elem, int): 98 | return torch.tensor(batch) 99 | elif isinstance(elem, string_classes): 100 | return batch 101 | elif isinstance(elem, collections.abc.Mapping): 102 | return {key: self.default_collate([d[key] for d in batch]) for key in elem} 103 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 104 | return elem_type(*(self.default_collate(samples) for samples in zip(*batch))) 105 | elif isinstance(elem, collections.abc.Sequence): 106 | # check to make sure that the elements in batch have consistent size 107 | it = iter(batch) 108 | elem_size = len(next(it)) 109 | if not all(len(elem) == elem_size
for elem in it): 110 | raise RuntimeError( 111 | 'each element in list of batch should be of equal size') 112 | transposed = zip(*batch) 113 | return [self.default_collate(samples) for samples in transposed] 114 | 115 | raise TypeError(self.default_collate_err_msg_format.format(elem_type)) 116 | 117 | def load_audio_into_tensor(self, audio_path, audio_duration, resample=False): 118 | r"""Loads audio file and returns raw audio.""" 119 | # Randomly sample a segment of audio_duration from the clip or pad to match duration 120 | audio_time_series, sample_rate = torchaudio.load(audio_path) 121 | resample_rate = self.args.sampling_rate 122 | if resample: 123 | resampler = T.Resample(sample_rate, resample_rate) 124 | audio_time_series = resampler(audio_time_series) 125 | audio_time_series = audio_time_series.reshape(-1) 126 | 127 | # audio_time_series is shorter than predefined audio duration, 128 | # so audio_time_series is extended 129 | if audio_duration*sample_rate >= audio_time_series.shape[0]: 130 | repeat_factor = int(np.ceil((audio_duration*sample_rate) / 131 | audio_time_series.shape[0])) 132 | # Repeat audio_time_series by repeat_factor to match audio_duration 133 | audio_time_series = audio_time_series.repeat(repeat_factor) 134 | # remove excess part of audio_time_series 135 | audio_time_series = audio_time_series[0:audio_duration*sample_rate] 136 | else: 137 | # audio_time_series is longer than predefined audio duration, 138 | # so audio_time_series is trimmed 139 | start_index = random.randrange( 140 | audio_time_series.shape[0] - audio_duration*sample_rate) 141 | audio_time_series = audio_time_series[start_index:start_index + 142 | audio_duration*sample_rate] 143 | return torch.FloatTensor(audio_time_series) 144 | 145 | def preprocess_audio(self, audio_files, resample): 146 | r"""Load list of audio files and return raw audio""" 147 | audio_tensors = [] 148 | for audio_file in audio_files: 149 | audio_tensor = self.load_audio_into_tensor( 150 | audio_file, self.args.duration, resample) 151 | audio_tensor = audio_tensor.reshape(1, -1).to(self.device) 152 | audio_tensors.append(audio_tensor) 153 | return self.default_collate(audio_tensors) 154 | 155 | def preprocess_text(self, text_queries, text_len=100): 156 | r"""Load list of class labels and return tokenized text""" 157 | device = next(self.clap.parameters()).device 158 | tokenized_texts = [] 159 | for ttext in text_queries: 160 | tok = self.tokenizer.encode_plus( 161 | text=ttext, add_special_tokens=True, max_length=text_len, pad_to_max_length=True, return_tensors="pt") 162 | for key in self.token_keys: 163 | tok[key] = tok[key].reshape(-1).to(device) 164 | tokenized_texts.append(tok) 165 | return self.default_collate(tokenized_texts) 166 | 167 | def get_text_embeddings(self, class_labels): 168 | r"""Load list of class labels and return text embeddings""" 169 | preprocessed_text = self.preprocess_text(class_labels) 170 | text_embeddings = self._get_text_embeddings(preprocessed_text) 171 | text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True) 172 | return text_embeddings 173 | 174 | def get_audio_embeddings(self, audio_files, resample): 175 | r"""Load list of audio files and return a audio embeddings""" 176 | preprocessed_audio = self.preprocess_audio(audio_files, resample) 177 | audio_embeddings = self._get_audio_embeddings(preprocessed_audio) 178 | audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True) 179 | return audio_embeddings 180 | 181 | def 
_get_text_embeddings(self, preprocessed_text): 182 | r"""Load preprocessed text and return text embeddings""" 183 | with torch.no_grad(): 184 | text_embeddings = self.clap.caption_encoder(preprocessed_text) 185 | text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True) 186 | return text_embeddings 187 | 188 | def _get_audio_embeddings(self, preprocessed_audio): 189 | r"""Load preprocessed audio and return audio embeddings""" 190 | with torch.no_grad(): 191 | preprocessed_audio = preprocessed_audio.reshape( 192 | preprocessed_audio.shape[0], preprocessed_audio.shape[2]) 193 | # [0] is the audio embedding, [1] has the output class probabilities 194 | audio_embeddings = self.clap.audio_encoder(preprocessed_audio)[0] 195 | audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True) 196 | return audio_embeddings 197 | 198 | def compute_similarity(self, audio_embeddings, text_embeddings): 199 | r"""Compute similarity between text and audio embeddings""" 200 | logit_scale = self.clap.logit_scale.exp() 201 | similarity = logit_scale*text_embeddings @ audio_embeddings.T 202 | return similarity.T 203 | 204 | def _generic_batch_inference(self, func, *args): 205 | r"""Process audio and/or text per batch""" 206 | input_tmp = args[0] 207 | batch_size = args[-1] 208 | # args[0] has audio_files, args[1] has class_labels 209 | inputs = [args[0], args[1]] if len(args) == 3 else [args[0]] 210 | args0_len = len(args[0]) 211 | # compute text_embeddings once for all the audio_files batches 212 | if len(inputs) == 2: 213 | text_embeddings = self.get_text_embeddings(args[1]) 214 | inputs = [args[0], args[1], text_embeddings] 215 | dataset_idx = 0 216 | for _ in range(math.ceil(args0_len/batch_size)): 217 | next_batch_idx = dataset_idx + batch_size 218 | # batch size is bigger than available audio/text items 219 | if next_batch_idx >= args0_len: 220 | inputs[0] = input_tmp[dataset_idx:] 221 | return func(*tuple(inputs)) 222 | else: 223 | inputs[0] = input_tmp[dataset_idx:next_batch_idx] 224 | yield func(*tuple(inputs)) 225 | dataset_idx = next_batch_idx 226 | 227 | def get_audio_embeddings_per_batch(self, audio_files, batch_size): 228 | r"""Load preprocessed audio and return audio embeddings per batch""" 229 | return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size) 230 | 231 | def get_text_embeddings_per_batch(self, class_labels, batch_size): 232 | r"""Load preprocessed text and return text embeddings per batch""" 233 | return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size) 234 | 235 | def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size): 236 | r"""Compute classification probabilities for each audio recording in a batch and each class label""" 237 | return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size) 238 | 239 | if __name__ == '__main__': 240 | 241 | # Load and initialize CLAP 242 | weights_path = "/home1/huangrongjie/Project/Diffusion/LatentDiffusion/CLAP/CLAP_weights_2022.pth" 243 | clap_model = CLAPWrapper(weights_path, device='cpu')  # CLAPWrapper takes (model_fp, device) 244 | 245 | y = ["A woman talks nearby as water pours", "Multiple clanging and clanking sounds"] 246 | x = ['/home2/huangjiawei/data/audiocaps/train/Yr1nicOVtvkQ.wav', '/home2/huangjiawei/data/audiocaps/train/YUDGBjjwyaqE.wav'] 247 | 248 | # Computing text embeddings 249 | text_embeddings = clap_model.get_text_embeddings(y) 250 | 251 | import ipdb 252 | ipdb.set_trace() 253 | 254 |
# Computing audio embeddings 255 | audio_embeddings = clap_model.get_audio_embeddings(x, resample=True) 256 | similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings) 257 | 258 | --------------------------------------------------------------------------------
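A minimal usage sketch for the text encoders above, assuming a locally downloaded CLAP checkpoint (the checkpoint path, prompts, and printed shape below are illustrative; d_proj=1024 is taken from useful_ckpts/CLAP/config.yml): FrozenCLAPEmbedder from ldm/modules/encoders/modules.py tokenizes a list of captions and returns per-token CLAP text features of shape (batch, max_length, d_proj).

import torch
from ldm.modules.encoders.modules import FrozenCLAPEmbedder

# Placeholder path to a CLAP checkpoint whose state dict contains a 'caption_encoder' sub-module.
clap_ckpt = "useful_ckpts/CLAP/CLAP_weights_2022.pth"

# device="cpu" keeps the token ids and the caption encoder on the same device,
# since the embedder only moves token ids to the requested device.
encoder = FrozenCLAPEmbedder(weights_path=clap_ckpt, freeze=True, device="cpu", max_length=77)

with torch.no_grad():
    z = encoder.encode(["a dog barking in the distance", "rain hitting a tin roof"])

print(z.shape)  # expected: torch.Size([2, 77, 1024]) with the shipped config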