├── .gitignore ├── README.md ├── audiocaps_test_16000_struct.tsv ├── configs ├── audiolcm.yaml ├── autoencoder1d.yaml └── teacher.yaml ├── ldm ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ ├── autoencoder1d.py │ ├── autoencoder_multi.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddim_solver.py │ │ ├── ddpm.py │ │ ├── ddpm_audio.py │ │ ├── ddpm_audio_inpaint.py │ │ ├── ddpm_audio_order.py │ │ ├── lcm_audio.py │ │ ├── plms.py │ │ └── scheduling_lcm.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── concatDiT.py │ │ ├── flag_large_dit.py │ │ ├── model.py │ │ └── util.py │ ├── discriminator │ │ ├── model.py │ │ └── multi_window_disc.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── CLAP │ │ │ ├── CLAPWrapper.py │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── clap.py │ │ │ ├── config.yaml │ │ │ └── utils.py │ │ ├── __init__.py │ │ ├── modules.py │ │ └── open_clap │ │ │ ├── __init__.py │ │ │ ├── bert.py │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ ├── factory.py │ │ │ ├── feature_fusion.py │ │ │ ├── htsat.py │ │ │ ├── linear_probe.py │ │ │ ├── loss.py │ │ │ ├── model.py │ │ │ ├── model_configs │ │ │ ├── HTSAT-base.json │ │ │ ├── HTSAT-large.json │ │ │ ├── HTSAT-tiny-win-1536.json │ │ │ ├── HTSAT-tiny.json │ │ │ ├── PANN-10.json │ │ │ ├── PANN-14-fmax-18k.json │ │ │ ├── PANN-14-fmax-8k-20s.json │ │ │ ├── PANN-14-tiny-transformer.json │ │ │ ├── PANN-14-win-1536.json │ │ │ ├── PANN-14.json │ │ │ ├── PANN-6.json │ │ │ ├── RN101-quickgelu.json │ │ │ ├── RN101.json │ │ │ ├── RN50-quickgelu.json │ │ │ ├── RN50.json │ │ │ ├── RN50x16.json │ │ │ ├── RN50x4.json │ │ │ ├── ViT-B-16.json │ │ │ ├── ViT-B-32-quickgelu.json │ │ │ ├── ViT-B-32.json │ │ │ └── ViT-L-14.json │ │ │ ├── openai.py │ │ │ ├── pann_model.py │ │ │ ├── pretrained.py │ │ │ ├── timm_model.py │ │ │ ├── tokenizer.py │ │ │ ├── transform.py │ │ │ ├── utils.py │ │ │ ├── version.py │ │ │ └── wrapper.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ ├── losses_audio │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ ├── contperceptual_dis.py │ │ ├── contperceptual_mask.py │ │ ├── contperceptual_multiw.py │ │ ├── lpaps.py │ │ └── vqperceptual.py │ ├── new_attention.py │ └── x_transformer.py └── util.py ├── main.py ├── pythonscripts ├── InferAPI.py ├── __pycache__ │ ├── InferAPI.cpython-38.pyc │ └── txt2audio_for_2cap.cpython-37.pyc ├── reconstruct_audio.py ├── txt2audio_for_2cap.py └── txt2audio_for_lcm.py ├── requirements.txt ├── scripts ├── reconstruct_audio.py ├── txt2audio_for_2cap.py └── txt2audio_for_lcm.py ├── vocoder └── bigvgan │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── activations.py │ ├── alias_free_torch │ ├── __init__.py │ ├── act.py │ ├── filter.py │ └── resample.py │ ├── bigvgan_audioset16khz_80band.json │ ├── configs │ ├── bigvgan_22khz_80band.json │ ├── bigvgan_24khz_100band.json │ ├── bigvgan_base_22khz_80band.json │ └── bigvgan_base_24khz_100band.json │ ├── env.py │ ├── incl_licenses │ ├── LICENSE_1 │ ├── LICENSE_2 │ ├── LICENSE_3 │ ├── LICENSE_4 │ └── LICENSE_5 │ ├── inference.py │ ├── inference_e2e.py │ ├── meldataset.py │ ├── models.py │ ├── parse_scripts │ └── parse_libritts.py │ ├── requirements.txt │ ├── train.py │ ├── train_vocoder.py │ └── utils.py └── wav_evaluation ├── cal_clap_score.py ├── 
cal_fad_score.py ├── metrics └── fad.py ├── models ├── CLAPWrapper.py ├── CLAPWrapper_for_CLAP.py ├── __init__.py ├── audio.py ├── clap.py └── utils.py └── useful_ckpts └── CLAP └── config.yml /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *__pycache__ 3 | useful_ckpts/bigvgan 4 | useful_ckpts/*.ckpt 5 | useful_ckpts/CLAP/*.ckpt 6 | evaluation 7 | .idea/ 8 | logs 9 | audiocaps_gen 10 | audioldm_eval 11 | src 12 | processed 13 | run.sh 14 | infer.sh 15 | *.DS_Store 16 | data_melnone16000nfft1024 17 | data 18 | audiocaps_mels 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AudioLCM: Text-to-Audio Generation with Latent Consistency Models 2 | 3 | #### Huadai Liu, Rongjie Huang, Yang Liu, Hengyuan Cao, Jialei Wang, Xize Cheng, Siqi Zheng, Zhou Zhao 4 | 5 | PyTorch implementation of [AudioLCM]: efficient and high-quality text-to-audio generation with latent consistency models. 6 | 7 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2406.00356v1) 8 | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/AIGC-Audio/AudioLCM) 9 | [![GitHub Stars](https://img.shields.io/github/stars/liuhuadai/AudioLCM?style=social)](https://github.com/liuhuadai/AudioLCM) 10 | 11 | We provide our implementation and pretrained models as open source in this repository. 12 | 13 | Visit our [demo page](https://audiolcm.github.io/) for audio samples. 14 | 15 | [AudioLCM HuggingFace Space](https://huggingface.co/spaces/AIGC-Audio/AudioLCM) 16 | 17 | ## News 18 | 19 | - June 2024: **[AudioLCM]** released on GitHub and Hugging Face. 20 | 21 | ## Quick Start 22 | We provide an example of how you can quickly generate high-fidelity samples using AudioLCM. 23 | 24 | To try it on your own dataset, clone this repo on a machine with an NVIDIA GPU + CUDA cuDNN and follow the instructions below. 25 | 26 | 27 | ### Supported Datasets and Pretrained Models 28 | 29 | Simply download the weights from [Huggingface](https://huggingface.co/liuhuadai/AudioLCM). 30 | 31 | 32 | ``` 33 | Download: 34 | audiolcm.ckpt and put it into ./ckpts 35 | BigVGAN vocoder and put it into ./vocoder/logs/bigvnat16k93.5w 36 | t5-v1_1-large and put it into ./ldm/modules/encoders/CLAP 37 | bert-base-uncased and put it into ./ldm/modules/encoders/CLAP 38 | CLAP_weights_2022.pth and put it into ./wav_evaluation/useful_ckpts/CLAP 39 | ``` 40 | 51 | 52 | 53 | ### Dependencies 54 | See the requirements in `requirements.txt`. 55 | 56 | ## Inference with a pre-trained model 57 | ```bash 58 | python scripts/txt2audio_for_lcm.py --ddim_steps 2 -b configs/audiolcm.yaml --sample_rate 16000 --vocoder-ckpt vocoder/logs/bigvnat16k93.5w --outdir results --test-dataset audiocaps -r ckpt/audiolcm.ckpt 59 | ``` 60 | 61 | ## Dataset preparation 62 | - We can't provide the dataset download links due to copyright issues. We provide the processing code to generate mel-spectrograms. 63 | - Before training, we need to collect the dataset information into a tsv file, which includes name (the id of each audio), dataset (which dataset the audio belongs to), audio_path (the path of the .wav file), caption (the caption of the audio), and mel_path (the path of the processed mel-spectrogram file of each audio); see the sketch below for an example row. 64 | - We provide a tsv file of the AudioCaps test set as a sample: ./audiocaps_test_16000_struct.tsv.
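For reference, here is a minimal sketch of how such a tsv could be assembled with pandas. The column names follow the description above; the concrete ids, paths, and caption are placeholders, not files shipped with the repo.

```python
import pandas as pd

# Placeholder rows illustrating the expected columns; replace them with your own data.
rows = [
    {
        "name": "audio_000001",                          # unique id of the audio clip
        "dataset": "audiocaps",                          # which dataset the clip belongs to
        "audio_path": "data/audiocaps/audio_000001.wav", # path of the .wav file
        "caption": "A dog barks while birds chirp in the background",
        "mel_path": "processed/mel/audio_000001.npy",    # filled in after mel extraction
    },
]

df = pd.DataFrame(rows, columns=["name", "dataset", "audio_path", "caption", "mel_path"])
df.to_csv("tmp.tsv", sep="\t", index=False)  # tab-separated file used by the preprocessing scripts
```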
65 | ### Generate the mel-spectrogram file of each audio 66 | Assume you already have a tsv file linking each caption to its audio_path, i.e. the tsv file has "name", "audio_path", "dataset" and "caption" columns. 67 | To get the mel-spectrograms, run the following command, which saves the mels in ./processed 68 | ```bash 69 | python ldm/data/preprocess/mel_spec.py --tsv_path tmp.tsv 70 | ``` 71 | Then add the duration column to the tsv file 72 | ```bash 73 | python ldm/data/preprocess/add_duration.py 74 | ``` 75 | ## Train the variational autoencoder 76 | Assume we have processed several datasets and saved the .tsv files in data/*.tsv. Replace **data.params.spec_dir_path** with **data** (the directory that contains the tsvs) in the config file. Then we can train the VAE with the following command. If you don't have 8 GPUs on your machine, adjust --gpus 0,1,...,gpu_nums accordingly. 77 | ```bash 78 | python main.py --base configs/train/vae.yaml -t --gpus 0,1,2,3,4,5,6,7 79 | ``` 80 | The training results will be saved in ./logs/ 81 | ## Train the latent diffusion model 82 | After training the VAE, replace model.params.first_stage_config.params.ckpt_path with your trained VAE checkpoint path in the config file. 83 | Run the following command to train the diffusion model 84 | ```bash 85 | python main.py --base configs/autoencoder1d.yaml -t --gpus 0,1,2,3,4,5,6,7 86 | ``` 87 | The training results will be saved in ./logs/ 88 | ## Evaluation 89 | Please refer to [Make-An-Audio](https://github.com/Text-to-Audio/Make-An-Audio?tab=readme-ov-file#evaluation). 90 | 91 | ## Acknowledgements 92 | This implementation uses parts of the code from the following GitHub repos: 93 | [Make-An-Audio](https://github.com/Text-to-Audio/Make-An-Audio), 94 | [CLAP](https://github.com/LAION-AI/CLAP), 95 | [Stable Diffusion](https://github.com/CompVis/stable-diffusion), 96 | as described in our code. 97 | 98 | ## Citations ## 99 | If you find this code useful in your research, please consider citing: 100 | ```bibtex 101 | @misc{liu2024audiolcm, 102 | title={AudioLCM: Text-to-Audio Generation with Latent Consistency Models}, 103 | author={Huadai Liu and Rongjie Huang and Yang Liu and Hengyuan Cao and Jialei Wang and Xize Cheng and Siqi Zheng and Zhou Zhao}, 104 | year={2024}, 105 | eprint={2406.00356}, 106 | archivePrefix={arXiv}, 107 | primaryClass={eess.AS} 108 | } 109 | ``` 110 | 111 | ## Disclaimer ## 112 | Any organization or individual is prohibited from using any technology mentioned in this paper to generate anyone's speech without their consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
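The config files that follow (configs/audiolcm.yaml, configs/autoencoder1d.yaml, configs/teacher.yaml) use the `target`/`params` convention common to latent-diffusion-style codebases. Below is a minimal sketch of how such a config can be loaded and turned into a model object, assuming OmegaConf is installed; the repository's own `main.py` provides its own loading logic, so this helper is only an illustration of the pattern.

```python
import importlib

from omegaconf import OmegaConf


def instantiate_from_config(config):
    """Build the class named in `config.target` with keyword args from `config.params`."""
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))


cfg = OmegaConf.load("configs/audiolcm.yaml")
model = instantiate_from_config(cfg.model)  # e.g. ldm.models.diffusion.lcm_audio.LCM_audio
```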
113 | -------------------------------------------------------------------------------- /configs/audiolcm.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 3.0e-06 3 | target: ldm.models.diffusion.lcm_audio.LCM_audio 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | mel_dim: 20 13 | mel_length: 312 14 | channels: 0 15 | cond_stage_trainable: False 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_by_std: true 19 | use_lcm: True 20 | num_ddim_timesteps: 50 21 | w_min: 4 22 | w_max: 12 23 | ckpt_path: ../ckpt/maa2.ckpt 24 | 25 | use_ema: false 26 | scheduler_config: 27 | target: ldm.lr_scheduler.LambdaLinearScheduler 28 | params: 29 | warm_up_steps: 30 | - 10000 31 | cycle_lengths: 32 | - 10000000000000 33 | f_start: 34 | - 1.0e-06 35 | f_max: 36 | - 1.0 37 | f_min: 38 | - 1.0 39 | unet_config: 40 | target: ldm.modules.diffusionmodules.concatDiT.ConcatDiT2MLP 41 | params: 42 | in_channels: 20 43 | context_dim: 1024 44 | hidden_size: 576 45 | num_heads: 8 46 | depth: 4 47 | max_len: 1000 48 | first_stage_config: 49 | target: ldm.models.autoencoder1d.AutoencoderKL 50 | params: 51 | embed_dim: 20 52 | monitor: val/rec_loss 53 | ckpt_path: ../logs/trainae/ckpt/epoch=000032.ckpt 54 | ddconfig: 55 | double_z: true 56 | in_channels: 80 57 | out_ch: 80 58 | z_channels: 20 59 | kernel_size: 5 60 | ch: 384 61 | ch_mult: 62 | - 1 63 | - 2 64 | - 4 65 | num_res_blocks: 2 66 | attn_layers: 67 | - 3 68 | down_layers: 69 | - 0 70 | dropout: 0.0 71 | lossconfig: 72 | target: torch.nn.Identity 73 | cond_stage_config: 74 | target: ldm.modules.encoders.modules.FrozenCLAPFLANEmbedder 75 | params: 76 | weights_path: ../useful_ckpts/CLAP/CLAP_weights_2022.pth 77 | 78 | lightning: 79 | callbacks: 80 | image_logger: 81 | target: main.AudioLogger 82 | params: 83 | sample_rate: 16000 84 | for_specs: true 85 | increase_log_steps: false 86 | batch_frequency: 5000 87 | max_images: 8 88 | melvmin: -5 89 | melvmax: 1.5 90 | vocoder_cfg: 91 | target: vocoder.bigvgan.models.VocoderBigVGAN 92 | params: 93 | ckpt_vocoder: ../vocoder/logs/bigvnat16k93.5w 94 | trainer: 95 | benchmark: True 96 | gradient_clip_val: 1.0 97 | replace_sampler_ddp: false 98 | max_epochs: 100 99 | modelcheckpoint: 100 | params: 101 | monitor: epoch 102 | mode: max 103 | # every_n_train_steps: 2000 104 | save_top_k: 100 105 | every_n_epochs: 3 106 | 107 | 108 | data: 109 | target: main.SpectrogramDataModuleFromConfig 110 | params: 111 | batch_size: 8 112 | num_workers: 32 113 | spec_dir_path: 'ldm/data/tsv_dirs/full_data/caps_struct' 114 | mel_num: 80 115 | train: 116 | target: ldm.data.joinaudiodataset_struct_anylen.JoinSpecsTrain 117 | params: 118 | specs_dataset_cfg: 119 | validation: 120 | target: ldm.data.joinaudiodataset_struct_anylen.JoinSpecsValidation 121 | params: 122 | specs_dataset_cfg: 123 | 124 | test_dataset: 125 | target: ldm.data.tsvdataset.TSVDatasetStruct 126 | params: 127 | tsv_path: audiocaps_test_16000_struct.tsv 128 | spec_crop_len: 624 129 | 130 | -------------------------------------------------------------------------------- /configs/autoencoder1d.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder1d.AutoencoderKL 4 | params: 5 | embed_dim: 20 6 | monitor: val/rec_loss 7 | ddconfig: 
8 | double_z: true 9 | in_channels: 80 10 | out_ch: 80 11 | z_channels: 20 12 | kernel_size: 5 13 | ch: 384 14 | ch_mult: 15 | - 1 16 | - 2 17 | - 4 18 | num_res_blocks: 2 19 | attn_layers: 20 | - 3 21 | down_layers: 22 | - 0 23 | dropout: 0.0 24 | lossconfig: 25 | target: ldm.modules.losses_audio.contperceptual.LPAPSWithDiscriminator 26 | params: 27 | disc_start: 80001 28 | perceptual_weight: 0.0 29 | kl_weight: 1.0e-06 30 | disc_weight: 0.5 31 | disc_in_channels: 1 32 | disc_loss: mse 33 | disc_factor: 2 34 | disc_conditional: false 35 | r1_reg_weight: 3 36 | 37 | lightning: 38 | callbacks: 39 | image_logger: 40 | target: main.AudioLogger 41 | params: 42 | for_specs: true 43 | increase_log_steps: false 44 | batch_frequency: 5000 45 | max_images: 8 46 | rescale: false 47 | melvmin: -5 48 | melvmax: 1.5 49 | vocoder_cfg: 50 | target: vocoder.bigvgan.models.VocoderBigVGAN 51 | params: 52 | ckpt_vocoder: vocoder/logs/bigvnat16k93.5w 53 | trainer: 54 | sync_batchnorm: false # not working with r1_regularization 55 | strategy: ddp 56 | 57 | 58 | data: 59 | target: main.SpectrogramDataModuleFromConfig 60 | params: 61 | batch_size: 4 62 | num_workers: 16 63 | spec_dir_path: ldm/data/tsv_dirs/full_data/V1_new 64 | mel_num: 80 65 | spec_len: 624 66 | spec_crop_len: 624 67 | train: 68 | target: ldm.data.joinaudiodataset_624.JoinSpecsTrain 69 | params: 70 | specs_dataset_cfg: null 71 | validation: 72 | target: ldm.data.joinaudiodataset_624.JoinSpecsValidation 73 | params: 74 | specs_dataset_cfg: null 75 | -------------------------------------------------------------------------------- /configs/teacher.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 3.0e-06 3 | target: ldm.models.diffusion.ddpm_audio.LatentDiffusion_audio 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | mel_dim: 20 13 | mel_length: 312 14 | channels: 0 15 | cond_stage_trainable: True 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_by_std: true 19 | use_ema: false 20 | scheduler_config: 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: 24 | - 10000 25 | cycle_lengths: 26 | - 10000000000000 27 | f_start: 28 | - 1.0e-06 29 | f_max: 30 | - 1.0 31 | f_min: 32 | - 1.0 33 | unet_config: 34 | target: ldm.modules.diffusionmodules.concatDiT.ConcatDiT2MLP 35 | params: 36 | in_channels: 20 37 | context_dim: 1024 38 | hidden_size: 576 39 | num_heads: 8 40 | depth: 4 41 | max_len: 1000 42 | first_stage_config: 43 | target: ldm.models.autoencoder1d.AutoencoderKL 44 | params: 45 | embed_dim: 20 46 | monitor: val/rec_loss 47 | ckpt_path: logs/trainae/ckpt/epoch=000032.ckpt 48 | ddconfig: 49 | double_z: true 50 | in_channels: 80 51 | out_ch: 80 52 | z_channels: 20 53 | kernel_size: 5 54 | ch: 384 55 | ch_mult: 56 | - 1 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_layers: 61 | - 3 62 | down_layers: 63 | - 0 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.FrozenCLAPFLANEmbedder 69 | params: 70 | weights_path: useful_ckpts/CLAP/CLAP_weights_2022.pth 71 | 72 | lightning: 73 | callbacks: 74 | image_logger: 75 | target: main.AudioLogger 76 | params: 77 | sample_rate: 16000 78 | for_specs: true 79 | increase_log_steps: false 80 | batch_frequency: 5000 81 | max_images: 8 82 | melvmin: -5 83 | 
melvmax: 1.5 84 | vocoder_cfg: 85 | target: vocoder.bigvgan.models.VocoderBigVGAN 86 | params: 87 | ckpt_vocoder: vocoder/logs/bigvnat16k93.5w 88 | trainer: 89 | benchmark: True 90 | gradient_clip_val: 1.0 91 | replace_sampler_ddp: false 92 | modelcheckpoint: 93 | params: 94 | monitor: epoch 95 | mode: max 96 | save_top_k: 10 97 | every_n_epochs: 5 98 | 99 | data: 100 | target: main.SpectrogramDataModuleFromConfig 101 | params: 102 | batch_size: 4 103 | num_workers: 32 104 | main_spec_dir_path: 'ldm/data/tsv_dirs/full_data/caps_struct' 105 | other_spec_dir_path: 'ldm/data/tsv_dirs/full_data/V2' 106 | mel_num: 80 107 | train: 108 | target: ldm.data.joinaudiodataset_struct_sample_anylen.JoinSpecsTrain 109 | params: 110 | specs_dataset_cfg: 111 | validation: 112 | target: ldm.data.joinaudiodataset_struct_sample_anylen.JoinSpecsValidation 113 | params: 114 | specs_dataset_cfg: 115 | 116 | test_dataset: 117 | target: ldm.data.tsvdataset.TSVDatasetStruct 118 | params: 119 | tsv_path: musiccap.tsv 120 | spec_crop_len: 624 121 | 122 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuhuadai/AudioLCM/be5a709a1020072e3ca2d66289724f15bb4c917c/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/ddim_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pytorch_lightning as pl 4 | 5 | 6 | def extract_into_tensor(a, t, x_shape): 7 | b, *_ = t.shape 8 | out = a.gather(-1, t) 9 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 10 | 11 | class DDIMSolver: 12 | def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50): 13 | # DDIM sampling parameters 14 | step_ratio = timesteps // ddim_timesteps 15 | self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1 16 | self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] 17 | self.ddim_alpha_cumprods_prev = np.asarray( 18 | [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() 19 | ) 20 | # convert to torch tensors 21 | self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() 22 | self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) 23 | self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev) 24 | 25 | def to(self, device): 26 | self.ddim_timesteps = self.ddim_timesteps.to(device) 27 | self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) 28 | self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) 29 | return self 30 | 31 | 
def ddim_step(self, pred_x0, pred_noise, timestep_index): 32 | alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev.to(pred_x0.device), timestep_index, pred_x0.shape) 33 | dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise 34 | x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt 35 | return x_prev -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuhuadai/AudioLCM/be5a709a1020072e3ca2d66289724f15bb4c917c/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuhuadai/AudioLCM/be5a709a1020072e3ca2d66289724f15bb4c917c/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | sum_dim = list(range(1,len(self.mean.shape))) 44 | if other is None: 45 | 46 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 47 | + self.var - 1.0 - self.logvar, 48 | dim=sum_dim) 49 | else: 50 | return 0.5 * torch.sum( 51 | torch.pow(self.mean - other.mean, 2) / other.var 52 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 53 | dim=sum_dim) 54 | 55 | def nll(self, sample, dims=[1,2,3]): 56 | if self.deterministic: 57 | return torch.Tensor([0.]) 58 | logtwopi = np.log(2.0 * np.pi) 59 | return 0.5 * torch.sum( 60 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 61 | dim=dims) 62 | 63 | def mode(self): 64 | return self.mean 65 | 66 | 67 | def normal_kl(mean1, logvar1, mean2, logvar2): 68 | """ 69 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 70 | Compute the KL divergence between two gaussians. 71 | Shapes are automatically broadcasted, so batches can be compared to 72 | scalars, among other use cases. 
73 | """ 74 | tensor = None 75 | for obj in (mean1, logvar1, mean2, logvar2): 76 | if isinstance(obj, torch.Tensor): 77 | tensor = obj 78 | break 79 | assert tensor is not None, "at least one argument must be a Tensor" 80 | 81 | # Force variances to be Tensors. Broadcasting helps convert scalars to 82 | # Tensors, but it does not work for torch.exp(). 83 | logvar1, logvar2 = [ 84 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 85 | for x in (logvar1, logvar2) 86 | ] 87 | 88 | return 0.5 * ( 89 | -1.0 90 | + logvar2 91 | - logvar1 92 | + torch.exp(logvar1 - logvar2) 93 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 94 | ) 95 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/__init__.py: -------------------------------------------------------------------------------- 1 | from . import clap 2 | from . import audio 3 | from . import utils -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 5 | 6 | def get_audio_encoder(name: str): 7 | if name == "Cnn14": 8 | return Cnn14 9 | else: 10 | raise Exception('The audio encoder name {} is incorrect or not supported'.format(name)) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_channels, out_channels): 15 | 16 | super(ConvBlock, self).__init__() 17 | 18 | self.conv1 = nn.Conv2d(in_channels=in_channels, 19 | out_channels=out_channels, 20 | kernel_size=(3, 3), stride=(1, 1), 21 | padding=(1, 1), bias=False) 22 | 23 | self.conv2 = nn.Conv2d(in_channels=out_channels, 24 | out_channels=out_channels, 25 | kernel_size=(3, 3), stride=(1, 1), 26 | padding=(1, 1), bias=False) 27 | 28 | self.bn1 = nn.BatchNorm2d(out_channels) 29 | self.bn2 = nn.BatchNorm2d(out_channels) 30 | 31 | 32 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 33 | 34 | x = input 35 | x = F.relu_(self.bn1(self.conv1(x))) 36 | x = F.relu_(self.bn2(self.conv2(x))) 37 | if pool_type == 'max': 38 | x = F.max_pool2d(x, kernel_size=pool_size) 39 | elif pool_type == 'avg': 40 | x = F.avg_pool2d(x, kernel_size=pool_size) 41 | elif pool_type == 'avg+max': 42 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 43 | x2 = F.max_pool2d(x, kernel_size=pool_size) 44 | x = x1 + x2 45 | else: 46 | raise Exception('Incorrect argument!') 47 | 48 | return x 49 | 50 | 51 | class ConvBlock5x5(nn.Module): 52 | def __init__(self, in_channels, out_channels): 53 | 54 | super(ConvBlock5x5, self).__init__() 55 | 56 | self.conv1 = nn.Conv2d(in_channels=in_channels, 57 | out_channels=out_channels, 58 | kernel_size=(5, 5), stride=(1, 1), 59 | padding=(2, 2), bias=False) 60 | 61 | self.bn1 = nn.BatchNorm2d(out_channels) 62 | 63 | 64 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 65 | 66 | x = input 67 | x = F.relu_(self.bn1(self.conv1(x))) 68 | if pool_type == 'max': 69 | x = F.max_pool2d(x, kernel_size=pool_size) 70 | elif pool_type == 'avg': 71 | x = F.avg_pool2d(x, kernel_size=pool_size) 72 | elif pool_type == 'avg+max': 73 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 74 | x2 = F.max_pool2d(x, kernel_size=pool_size) 75 | x = x1 + x2 76 | else: 77 | raise Exception('Incorrect argument!') 78 | 79 | return x 80 | 81 | 82 | class AttBlock(nn.Module): 83 | def __init__(self, n_in, n_out, activation='linear', temperature=1.): 84 | super(AttBlock, self).__init__() 85 | 86 | self.activation = activation 87 | self.temperature = temperature 88 | self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 89 | self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 90 | 91 | self.bn_att = nn.BatchNorm1d(n_out) 92 | 93 | def forward(self, x): 94 | # x: (n_samples, n_in, n_time) 95 | norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) 96 | 
cla = self.nonlinear_transform(self.cla(x)) 97 | x = torch.sum(norm_att * cla, dim=2) 98 | return x, norm_att, cla 99 | 100 | def nonlinear_transform(self, x): 101 | if self.activation == 'linear': 102 | return x 103 | elif self.activation == 'sigmoid': 104 | return torch.sigmoid(x) 105 | 106 | 107 | class Cnn14(nn.Module): 108 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 109 | fmax, classes_num, out_emb): 110 | 111 | super(Cnn14, self).__init__() 112 | 113 | window = 'hann' 114 | center = True 115 | pad_mode = 'reflect' 116 | ref = 1.0 117 | amin = 1e-10 118 | top_db = None 119 | 120 | # Spectrogram extractor 121 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 122 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 123 | freeze_parameters=True) 124 | 125 | # Logmel feature extractor 126 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 127 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 128 | freeze_parameters=True) 129 | 130 | self.bn0 = nn.BatchNorm2d(64) 131 | 132 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 133 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 134 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 135 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 136 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 137 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 138 | 139 | # out_emb is 2048 for best Cnn14 140 | self.fc1 = nn.Linear(2048, out_emb, bias=True) 141 | self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True) 142 | 143 | def forward(self, input, mixup_lambda=None): 144 | """ 145 | Input: (batch_size, data_length) 146 | """ 147 | 148 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 149 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 150 | 151 | x = x.transpose(1, 3) 152 | x = self.bn0(x) 153 | x = x.transpose(1, 3) 154 | 155 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 156 | x = F.dropout(x, p=0.2, training=self.training) 157 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 158 | x = F.dropout(x, p=0.2, training=self.training) 159 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 160 | x = F.dropout(x, p=0.2, training=self.training) 161 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 162 | x = F.dropout(x, p=0.2, training=self.training) 163 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 164 | x = F.dropout(x, p=0.2, training=self.training) 165 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 166 | x = F.dropout(x, p=0.2, training=self.training) 167 | x = torch.mean(x, dim=3) 168 | 169 | (x1, _) = torch.max(x, dim=2) 170 | x2 = torch.mean(x, dim=2) 171 | x = x1 + x2 172 | x = F.dropout(x, p=0.5, training=self.training) 173 | x = F.relu_(self.fc1(x)) 174 | embedding = F.dropout(x, p=0.5, training=self.training) 175 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 176 | 177 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 178 | 179 | return output_dict -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/clap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from 
transformers import AutoModel 6 | from .audio import get_audio_encoder 7 | 8 | class Projection(nn.Module): 9 | def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None: 10 | super().__init__() 11 | self.linear1 = nn.Linear(d_in, d_out, bias=False) 12 | self.linear2 = nn.Linear(d_out, d_out, bias=False) 13 | self.layer_norm = nn.LayerNorm(d_out) 14 | self.drop = nn.Dropout(p) 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | embed1 = self.linear1(x) 18 | embed2 = self.drop(self.linear2(F.gelu(embed1))) 19 | embeds = self.layer_norm(embed1 + embed2) 20 | return embeds 21 | 22 | class AudioEncoder(nn.Module): 23 | def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int, 24 | hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None: 25 | super().__init__() 26 | 27 | audio_encoder = get_audio_encoder(audioenc_name) 28 | 29 | self.base = audio_encoder( 30 | sample_rate, window_size, 31 | hop_size, mel_bins, fmin, fmax, 32 | classes_num, d_in) 33 | 34 | self.projection = Projection(d_in, d_out) 35 | 36 | def forward(self, x): 37 | out_dict = self.base(x) 38 | audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output'] 39 | projected_vec = self.projection(audio_features) 40 | return projected_vec, audio_classification_output 41 | 42 | class TextEncoder(nn.Module): 43 | def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None: 44 | super().__init__() 45 | self.base = AutoModel.from_pretrained(text_model) 46 | self.projection = Projection(transformer_embed_dim, d_out) 47 | 48 | def forward(self, x): 49 | out = self.base(**x)[0] 50 | out = out[:, 0, :] # get CLS token output 51 | projected_vec = self.projection(out) 52 | return projected_vec 53 | 54 | class CLAP(nn.Module): 55 | def __init__(self, 56 | # audio 57 | audioenc_name: str, 58 | sample_rate: int, 59 | window_size: int, 60 | hop_size: int, 61 | mel_bins: int, 62 | fmin: int, 63 | fmax: int, 64 | classes_num: int, 65 | out_emb: int, 66 | # text 67 | text_model: str, 68 | transformer_embed_dim: int, 69 | # common 70 | d_proj: int, 71 | ): 72 | super().__init__() 73 | 74 | 75 | self.audio_encoder = AudioEncoder( 76 | audioenc_name, out_emb, d_proj, 77 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num) 78 | 79 | self.caption_encoder = TextEncoder( 80 | d_proj, text_model, transformer_embed_dim 81 | ) 82 | 83 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 84 | 85 | def forward(self, audio, text): 86 | audio_embed, _ = self.audio_encoder(audio) 87 | caption_embed = self.caption_encoder(text) 88 | 89 | return caption_embed, audio_embed, self.logit_scale.exp() -------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/config.yaml: -------------------------------------------------------------------------------- 1 | # TEXT ENCODER CONFIG 2 | text_model: '../ldm/modules/encoders/CLAP/bert-base-uncased' 3 | text_len: 100 4 | transformer_embed_dim: 768 5 | freeze_text_encoder_weights: True 6 | 7 | # AUDIO ENCODER CONFIG 8 | audioenc_name: 'Cnn14' 9 | out_emb: 2048 10 | sampling_rate: 44100 11 | duration: 5 12 | fmin: 50 13 | fmax: 14000 14 | n_fft: 1028 15 | hop_size: 320 16 | mel_bins: 64 17 | window_size: 1024 18 | 19 | # PROJECTION SPACE CONFIG 20 | d_proj: 1024 21 | temperature: 0.003 22 | 23 | # TRAINING AND EVALUATION CONFIG 24 | num_classes: 527 25 | batch_size: 1024 26 | demo: False 27 | 
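As a quick orientation, here is a minimal sketch (not the repo's CLAPWrapper) showing how the CLAP module defined in clap.py above could be constructed from this config.yaml. The mapping of config keys to constructor arguments (e.g. sampling_rate -> sample_rate, num_classes -> classes_num) and the assumption that the repo root is on PYTHONPATH are mine, for illustration only.

```python
import yaml

from ldm.modules.encoders.CLAP.clap import CLAP

with open("ldm/modules/encoders/CLAP/config.yaml") as f:
    cfg = yaml.safe_load(f)

# Key renames below are an assumption about how the wrapper wires config values in.
model = CLAP(
    audioenc_name=cfg["audioenc_name"],
    sample_rate=cfg["sampling_rate"],
    window_size=cfg["window_size"],
    hop_size=cfg["hop_size"],
    mel_bins=cfg["mel_bins"],
    fmin=cfg["fmin"],
    fmax=cfg["fmax"],
    classes_num=cfg["num_classes"],
    out_emb=cfg["out_emb"],
    text_model=cfg["text_model"],
    transformer_embed_dim=cfg["transformer_embed_dim"],
    d_proj=cfg["d_proj"],
)
```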
-------------------------------------------------------------------------------- /ldm/modules/encoders/CLAP/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import sys 4 | 5 | def read_config_as_args(config_path,args=None,is_config_str=False): 6 | return_dict = {} 7 | 8 | if config_path is not None: 9 | if is_config_str: 10 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader) 11 | else: 12 | with open(config_path, "r") as f: 13 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 14 | 15 | if args != None: 16 | for k, v in yml_config.items(): 17 | if k in args.__dict__: 18 | args.__dict__[k] = v 19 | else: 20 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) 21 | else: 22 | for k, v in yml_config.items(): 23 | return_dict[k] = v 24 | 25 | args = args if args != None else return_dict 26 | return argparse.Namespace(**args) 27 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuhuadai/AudioLCM/be5a709a1020072e3ca2d66289724f15bb4c917c/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config 2 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics 3 | from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, CLAPAudioCfp, convert_weights_to_fp16, trace_model 4 | from .openai import load_openai_model, list_openai_models 5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ 6 | get_pretrained_url, download_pretrained 7 | from .tokenizer import SimpleTokenizer, tokenize 8 | from .transform import image_transform -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, BertModel 2 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 3 | model = BertModel.from_pretrained("bert-base-uncased") 4 | text = "Replace me by any text you'd like." 5 | 6 | def bert_embeddings(text): 7 | # text = "Replace me by any text you'd like." 8 | encoded_input = tokenizer(text, return_tensors='pt') 9 | output = model(**encoded_input) 10 | return output 11 | 12 | from transformers import RobertaTokenizer, RobertaModel 13 | 14 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 15 | model = RobertaModel.from_pretrained('roberta-base') 16 | text = "Replace me by any text you'd like." 17 | def Roberta_embeddings(text): 18 | # text = "Replace me by any text you'd like." 19 | encoded_input = tokenizer(text, return_tensors='pt') 20 | output = model(**encoded_input) 21 | return output 22 | 23 | from transformers import BartTokenizer, BartModel 24 | 25 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') 26 | model = BartModel.from_pretrained('facebook/bart-base') 27 | text = "Replace me by any text you'd like." 28 | def bart_embeddings(text): 29 | # text = "Replace me by any text you'd like." 
30 | encoded_input = tokenizer(text, return_tensors='pt') 31 | output = model(**encoded_input) 32 | return output -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuhuadai/AudioLCM/be5a709a1020072e3ca2d66289724f15bb4c917c/ldm/modules/encoders/open_clap/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/feature_fusion.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Feature Fusion for Variable-Length Data Processing 3 | AFF/iAFF is adapted and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py 4 | According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 5 | ''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class DAF(nn.Module): 12 | ''' 13 | Direct addition (DirectAddFuse) 14 | ''' 15 | 16 | def __init__(self): 17 | super(DAF, self).__init__() 18 | 19 | def forward(self, x, residual): 20 | return x + residual 21 | 22 | 23 | class iAFF(nn.Module): 24 | ''' 25 | Multi-feature fusion (iAFF) 26 | ''' 27 | 28 | def __init__(self, channels=64, r=4, type='2D'): 29 | super(iAFF, self).__init__() 30 | inter_channels = int(channels // r) 31 | 32 | if type == '1D': 33 | # local attention 34 | self.local_att = nn.Sequential( 35 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 36 | nn.BatchNorm1d(inter_channels), 37 | nn.ReLU(inplace=True), 38 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 39 | nn.BatchNorm1d(channels), 40 | ) 41 | 42 | # global attention 43 | self.global_att = nn.Sequential( 44 | nn.AdaptiveAvgPool1d(1), 45 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 46 | nn.BatchNorm1d(inter_channels), 47 | nn.ReLU(inplace=True), 48 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 49 | nn.BatchNorm1d(channels), 50 | ) 51 | 52 | # second local attention 53 | self.local_att2 = nn.Sequential( 54 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 55 | nn.BatchNorm1d(inter_channels), 56 | nn.ReLU(inplace=True), 57 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 58 | nn.BatchNorm1d(channels), 59 | ) 60 | # second global attention 61 | self.global_att2 = nn.Sequential( 62 | nn.AdaptiveAvgPool1d(1), 63 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 64 | nn.BatchNorm1d(inter_channels), 65 | nn.ReLU(inplace=True), 66 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 67 | nn.BatchNorm1d(channels), 68 | ) 69 | elif type == '2D': 70 | # local attention 71 | self.local_att = nn.Sequential( 72 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 73 | nn.BatchNorm2d(inter_channels), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 76 | nn.BatchNorm2d(channels), 77 | ) 78 | 79 | # global attention 80 | self.global_att = nn.Sequential( 81 | nn.AdaptiveAvgPool2d(1), 82 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 83 | nn.BatchNorm2d(inter_channels), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 86 | nn.BatchNorm2d(channels), 87 | ) 88 | 89 | # second local attention 90 |
self.local_att2 = nn.Sequential( 91 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 92 | nn.BatchNorm2d(inter_channels), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 95 | nn.BatchNorm2d(channels), 96 | ) 97 | # second global attention 98 | self.global_att2 = nn.Sequential( 99 | nn.AdaptiveAvgPool2d(1), 100 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 101 | nn.BatchNorm2d(inter_channels), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 104 | nn.BatchNorm2d(channels), 105 | ) 106 | else: 107 | raise ValueError('the type is not supported') 108 | 109 | self.sigmoid = nn.Sigmoid() 110 | 111 | def forward(self, x, residual): 112 | flag = False 113 | xa = x + residual 114 | if xa.size(0) == 1: 115 | xa = torch.cat([xa,xa],dim=0) 116 | flag = True 117 | xl = self.local_att(xa) 118 | xg = self.global_att(xa) 119 | xlg = xl + xg 120 | wei = self.sigmoid(xlg) 121 | xi = x * wei + residual * (1 - wei) 122 | 123 | xl2 = self.local_att2(xi) 124 | xg2 = self.global_att(xi) 125 | xlg2 = xl2 + xg2 126 | wei2 = self.sigmoid(xlg2) 127 | xo = x * wei2 + residual * (1 - wei2) 128 | if flag: 129 | xo = xo[0].unsqueeze(0) 130 | return xo 131 | 132 | 133 | class AFF(nn.Module): 134 | ''' 135 | Multi-feature fusion (AFF) 136 | ''' 137 | 138 | def __init__(self, channels=64, r=4, type='2D'): 139 | super(AFF, self).__init__() 140 | inter_channels = int(channels // r) 141 | 142 | if type == '1D': 143 | self.local_att = nn.Sequential( 144 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 145 | nn.BatchNorm1d(inter_channels), 146 | nn.ReLU(inplace=True), 147 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 148 | nn.BatchNorm1d(channels), 149 | ) 150 | self.global_att = nn.Sequential( 151 | nn.AdaptiveAvgPool1d(1), 152 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 153 | nn.BatchNorm1d(inter_channels), 154 | nn.ReLU(inplace=True), 155 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 156 | nn.BatchNorm1d(channels), 157 | ) 158 | elif type == '2D': 159 | self.local_att = nn.Sequential( 160 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 161 | nn.BatchNorm2d(inter_channels), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 164 | nn.BatchNorm2d(channels), 165 | ) 166 | self.global_att = nn.Sequential( 167 | nn.AdaptiveAvgPool2d(1), 168 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 169 | nn.BatchNorm2d(inter_channels), 170 | nn.ReLU(inplace=True), 171 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 172 | nn.BatchNorm2d(channels), 173 | ) 174 | else: 175 | raise ValueError('the type is not supported.')
176 | 177 | self.sigmoid = nn.Sigmoid() 178 | 179 | def forward(self, x, residual): 180 | flag = False 181 | xa = x + residual 182 | if xa.size(0) == 1: 183 | xa = torch.cat([xa,xa],dim=0) 184 | flag = True 185 | xl = self.local_att(xa) 186 | xg = self.global_att(xa) 187 | xlg = xl + xg 188 | wei = self.sigmoid(xlg) 189 | xo = 2 * x * wei + 2 * residual * (1 - wei) 190 | if flag: 191 | xo = xo[0].unsqueeze(0) 192 | return xo 193 | 194 | -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/linear_probe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from .model import MLPLayers 5 | 6 | 7 | class LinearProbe(nn.Module): 8 | def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): 9 | """ 10 | Args: 11 | model: nn.Module 12 | mlp: bool, if True, then use the MLP layer as the linear probe module 13 | freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe 14 | in_ch: int, the output channel from CLAP model 15 | out_ch: int, the output channel from linear probe (class_num) 16 | act: torch.nn.functional, the activation function before the loss function 17 | """ 18 | super().__init__() 19 | in_ch = 512 20 | self.clap_model = model 21 | self.clap_model.text_branch = None # to save memory 22 | self.freeze = freeze 23 | if mlp: 24 | self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) 25 | else: 26 | self.lp_layer = nn.Linear(in_ch, out_ch) 27 | 28 | if self.freeze: 29 | for param in self.clap_model.parameters(): 30 | param.requires_grad = False 31 | 32 | if act == 'None': 33 | self.act = None 34 | elif act == 'relu': 35 | self.act = nn.ReLU() 36 | elif act == 'elu': 37 | self.act = nn.ELU() 38 | elif act == 'prelu': 39 | self.act = nn.PReLU(num_parameters=in_ch) 40 | elif act == 'softmax': 41 | self.act = nn.Softmax(dim=-1) 42 | elif act == 'sigmoid': 43 | self.act = nn.Sigmoid() 44 | 45 | def forward(self, x, mix_lambda=None, device=None): 46 | """ 47 | Args: 48 | x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list 49 | mix_lambda: torch.tensor [batch], the mixup lambda 50 | Returns: 51 | class_prob: torch.tensor [batch, class_num] 52 | 53 | """ 54 | # batchnorm: cancel gradient (keep the frozen CLAP model in eval mode) 55 | if self.freeze: 56 | self.clap_model.eval() 57 | 58 | x = self.clap_model.audio_projection( 59 | self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)["embedding"]) 60 | out = self.lp_layer(x) 61 | if self.act is not None: 62 | out = self.act(out) 63 | return out 64 | -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/HTSAT-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "base" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/HTSAT-large.json:
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "large" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/HTSAT-tiny-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/HTSAT-tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/PANN-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn10" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/PANN-14-fmax-18k.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 18000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/PANN-14-fmax-8k-20s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 
4 | "audio_length": 1024, 5 | "clip_samples": 960000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 360, 10 | "fmin": 50, 11 | "fmax": 8000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/PANN-14-tiny-transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 4 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/PANN-14-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/PANN-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/PANN-6.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn6" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | 
"patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/ViT-B-32-quickgelu.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_tag_models('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | model_cfg, 26 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 27 | jit=True, 28 | cache_dir=os.path.expanduser("~/.cache/clip"), 29 | enable_fusion: bool = False, 30 | fusion_type: str = 'None' 31 | ): 32 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model 33 | 34 | Parameters 35 | ---------- 36 | name : str 37 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
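    model_cfg : dict
        The CLAP model config dict passed to build_model_from_openai_state_dict when the checkpoint is rebuilt from a state dict
    cache_dir : str
        Directory where downloaded checkpoints are cached (defaults to ~/.cache/clip)
    enable_fusion : bool
        Forwarded to build_model_from_openai_state_dict when rebuilding the model
    fusion_type : str
        Fusion type forwarded together with enable_fusion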
42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLAP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if get_pretrained_url(name, 'openai'): 51 | model_path = download_pretrained(get_pretrained_url(name, 'openai'), root=cache_dir) 52 | elif os.path.isfile(name): 53 | model_path = name 54 | else: 55 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 56 | 57 | try: 58 | # loading JIT archive 59 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 60 | state_dict = None 61 | except RuntimeError: 62 | # loading saved state dict 63 | if jit: 64 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 65 | jit = False 66 | state_dict = torch.load(model_path, map_location="cpu") 67 | 68 | if not jit: 69 | try: 70 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type).to(device) 71 | except KeyError: 72 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 73 | model = build_model_from_openai_state_dict(sd, model_cfg, enable_fusion, fusion_type).to(device) 74 | 75 | if str(device) == "cpu": 76 | model.float() 77 | return model 78 | 79 | # patch the device names 80 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 81 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 82 | 83 | def patch_device(module): 84 | try: 85 | graphs = [module.graph] if hasattr(module, "graph") else [] 86 | except RuntimeError: 87 | graphs = [] 88 | 89 | if hasattr(module, "forward1"): 90 | graphs.append(module.forward1.graph) 91 | 92 | for graph in graphs: 93 | for node in graph.findAllNodes("prim::Constant"): 94 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 95 | node.copyAttributes(device_node) 96 | 97 | model.apply(patch_device) 98 | patch_device(model.encode_audio) 99 | patch_device(model.encode_text) 100 | 101 | # patch dtype to float32 on CPU 102 | if str(device) == "cpu": 103 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 104 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 105 | float_node = float_input.node() 106 | 107 | def patch_float(module): 108 | try: 109 | graphs = [module.graph] if hasattr(module, "graph") else [] 110 | except RuntimeError: 111 | graphs = [] 112 | 113 | if hasattr(module, "forward1"): 114 | graphs.append(module.forward1.graph) 115 | 116 | for graph in graphs: 117 | for node in graph.findAllNodes("aten::to"): 118 | inputs = list(node.inputs()) 119 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 120 | if inputs[i].node()["value"] == 5: 121 | inputs[i].node().copyAttributes(float_node) 122 | 123 | model.apply(patch_float) 124 | patch_float(model.encode_audio) 125 | patch_float(model.encode_text) 126 | model.float() 127 | 128 | model.audio_branch.audio_length = model.audio_cfg.audio_length 129 | return model 130 | -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | _RN50 = 
dict( 9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 12 | ) 13 | 14 | _RN50_quickgelu = dict( 15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 18 | ) 19 | 20 | _RN101 = dict( 21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 23 | ) 24 | 25 | _RN101_quickgelu = dict( 26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 28 | ) 29 | 30 | _RN50x4 = dict( 31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | ) 33 | 34 | _RN50x16 = dict( 35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | ) 37 | 38 | _RN50x64 = dict( 39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 40 | ) 41 | 42 | _VITB32 = dict( 43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 44 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 45 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 46 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 47 | ) 48 | 49 | _VITB32_quickgelu = dict( 50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 53 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 54 | ) 55 | 56 | _VITB16 = dict( 57 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 58 | ) 59 | 60 | _VITL14 = dict( 61 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 62 | ) 63 | 64 | _PRETRAINED = { 65 | "RN50": _RN50, 66 | "RN50-quickgelu": _RN50_quickgelu, 
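    # each entry maps pretrain tags (e.g. 'openai', 'yfcc15m') to checkpoint URLs;
    # get_pretrained_url() below resolves models from this table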
67 | "RN101": _RN101, 68 | "RN101-quickgelu": _RN101_quickgelu, 69 | "RN50x4": _RN50x4, 70 | "RN50x16": _RN50x16, 71 | "ViT-B-32": _VITB32, 72 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 73 | "ViT-B-16": _VITB16, 74 | "ViT-L-14": _VITL14, 75 | } 76 | 77 | 78 | def list_pretrained(as_str: bool = False): 79 | """ returns list of pretrained models 80 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 81 | """ 82 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 83 | 84 | 85 | def list_pretrained_tag_models(tag: str): 86 | """ return all models having the specified pretrain tag """ 87 | models = [] 88 | for k in _PRETRAINED.keys(): 89 | if tag in _PRETRAINED[k]: 90 | models.append(k) 91 | return models 92 | 93 | 94 | def list_pretrained_model_tags(model: str): 95 | """ return all pretrain tags for the specified model architecture """ 96 | tags = [] 97 | if model in _PRETRAINED: 98 | tags.extend(_PRETRAINED[model].keys()) 99 | return tags 100 | 101 | 102 | def get_pretrained_url(model: str, tag: str): 103 | if model not in _PRETRAINED: 104 | return '' 105 | model_pretrained = _PRETRAINED[model] 106 | if tag not in model_pretrained: 107 | return '' 108 | return model_pretrained[tag] 109 | 110 | 111 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): 112 | os.makedirs(root, exist_ok=True) 113 | filename = os.path.basename(url) 114 | 115 | if 'openaipublic' in url: 116 | expected_sha256 = url.split("/")[-2] 117 | else: 118 | expected_sha256 = '' 119 | 120 | download_target = os.path.join(root, filename) 121 | 122 | if os.path.exists(download_target) and not os.path.isfile(download_target): 123 | raise RuntimeError(f"{download_target} exists and is not a regular file") 124 | 125 | if os.path.isfile(download_target): 126 | if expected_sha256: 127 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 128 | return download_target 129 | else: 130 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 131 | else: 132 | return download_target 133 | 134 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 135 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 136 | while True: 137 | buffer = source.read(8192) 138 | if not buffer: 139 | break 140 | 141 | output.write(buffer) 142 | loop.update(len(buffer)) 143 | 144 | if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 145 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 146 | 147 | return download_target 148 | -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 14 | except ImportError as e: 15 | timm = None 16 | 17 | from .utils import freeze_batch_norm_2d 18 | 19 | 20 | class TimmModel(nn.Module): 21 | """ timm model adapter 22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_name, 28 | embed_dim, 29 | image_size=224, 30 | pool='avg', 31 | proj='linear', 32 | drop=0., 33 | pretrained=False): 34 | super().__init__() 35 | if timm is None: 36 | raise RuntimeError("Please `pip install timm` to use timm models.") 37 | 38 | self.image_size = to_2tuple(image_size) 39 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 40 | feat_size = self.trunk.default_cfg.get('pool_size', None) 41 | feature_ndim = 1 if not feat_size else 2 42 | if pool in ('abs_attn', 'rot_attn'): 43 | assert feature_ndim == 2 44 | # if attn pooling used, remove both classifier and default pool 45 | self.trunk.reset_classifier(0, global_pool='') 46 | else: 47 | # reset global pool if pool config set, otherwise leave as network default 48 | reset_kwargs = dict(global_pool=pool) if pool else {} 49 | self.trunk.reset_classifier(0, **reset_kwargs) 50 | prev_chs = self.trunk.num_features 51 | 52 | head_layers = OrderedDict() 53 | if pool == 'abs_attn': 54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 55 | prev_chs = embed_dim 56 | elif pool == 'rot_attn': 57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 58 | prev_chs = embed_dim 59 | else: 60 | assert proj, 'projection layer needed if non-attention pooling is used.' 
61 | 62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 63 | if proj == 'linear': 64 | head_layers['drop'] = nn.Dropout(drop) 65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim) 66 | elif proj == 'mlp': 67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 68 | 69 | self.head = nn.Sequential(head_layers) 70 | 71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 72 | """ lock modules 73 | Args: 74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 75 | """ 76 | if not unlocked_groups: 77 | # lock full model 78 | for param in self.trunk.parameters(): 79 | param.requires_grad = False 80 | if freeze_bn_stats: 81 | freeze_batch_norm_2d(self.trunk) 82 | else: 83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 84 | try: 85 | # FIXME import here until API stable and in an official release 86 | from timm.models.helpers import group_parameters, group_modules 87 | except ImportError: 88 | raise RuntimeError( 89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 90 | matcher = self.trunk.group_matcher() 91 | gparams = group_parameters(self.trunk, matcher) 92 | max_layer_id = max(gparams.keys()) 93 | max_layer_id = max_layer_id - unlocked_groups 94 | for group_idx in range(max_layer_id + 1): 95 | group = gparams[group_idx] 96 | for param in group: 97 | self.trunk.get_parameter(param).requires_grad = False 98 | if freeze_bn_stats: 99 | gmodules = group_modules(self.trunk, matcher, reverse=True) 100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 101 | freeze_batch_norm_2d(self.trunk, gmodules) 102 | 103 | def forward(self, x): 104 | x = self.trunk(x) 105 | x = self.head(x) 106 | return x 107 | -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 19 | 20 | 21 | @lru_cache() 22 | def bytes_to_unicode(): 23 | """ 24 | Returns list of utf-8 byte and a corresponding list of unicode strings. 25 | The reversible bpe codes work on unicode strings. 26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 28 | This is a signficant percentage of your normal, say, 32K bpe vocab. 29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 30 | And avoids mapping to whitespace/control characters the bpe code barfs on. 31 | """ 32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8+n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 
46 | Word is represented as tuple of symbols (symbols being variable-length strings). 47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | merges = merges[1:49152-256-2+1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v+'' for v in vocab] 77 | for merge in merges: 78 | vocab.append(''.join(merge)) 79 | if not special_tokens: 80 | special_tokens = ['', ''] 81 | else: 82 | special_tokens = ['', ''] + special_tokens 83 | vocab.extend(special_tokens) 84 | self.encoder = dict(zip(vocab, range(len(vocab)))) 85 | self.decoder = {v: k for k, v in self.encoder.items()} 86 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 87 | self.cache = {t:t for t in special_tokens} 88 | special = "|".join(special_tokens) 89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 90 | 91 | self.vocab_size = len(self.encoder) 92 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 93 | 94 | def bpe(self, token): 95 | if token in self.cache: 96 | return self.cache[token] 97 | word = tuple(token[:-1]) + ( token[-1] + '',) 98 | pairs = get_pairs(word) 99 | 100 | if not pairs: 101 | return token+'' 102 | 103 | while True: 104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 105 | if bigram not in self.bpe_ranks: 106 | break 107 | first, second = bigram 108 | new_word = [] 109 | i = 0 110 | while i < len(word): 111 | try: 112 | j = word.index(first, i) 113 | new_word.extend(word[i:j]) 114 | i = j 115 | except: 116 | new_word.extend(word[i:]) 117 | break 118 | 119 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 120 | new_word.append(first+second) 121 | i += 2 122 | else: 123 | new_word.append(word[i]) 124 | i += 1 125 | new_word = tuple(new_word) 126 | word = new_word 127 | if len(word) == 1: 128 | break 129 | else: 130 | pairs = get_pairs(word) 131 | word = ' '.join(word) 132 | self.cache[token] = word 133 | return word 134 | 135 | def encode(self, text): 136 | bpe_tokens = [] 137 | text = whitespace_clean(basic_clean(text)).lower() 138 | for token in re.findall(self.pat, text): 139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = ''.join([self.decoder[token] for token in tokens]) 145 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 146 | return text 147 | 148 | 149 | _tokenizer = SimpleTokenizer() 150 | 151 | 152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 153 | """ 154 | Returns the tokenized representation of given input string(s) 155 | 156 | Parameters 
157 | ---------- 158 | texts : Union[str, List[str]] 159 | An input string or a list of input strings to tokenize 160 | context_length : int 161 | The context length to use; all CLIP models use 77 as the context length 162 | 163 | Returns 164 | ------- 165 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 166 | """ 167 | if isinstance(texts, str): 168 | texts = [texts] 169 | 170 | sot_token = _tokenizer.encoder[""] 171 | eot_token = _tokenizer.encoder[""] 172 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 173 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 174 | 175 | for i, tokens in enumerate(all_tokens): 176 | if len(tokens) > context_length: 177 | tokens = tokens[:context_length] # Truncate 178 | result[i, :len(tokens)] = torch.tensor(tokens) 179 | 180 | return result 181 | -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 2 | CenterCrop 3 | 4 | 5 | def _convert_to_rgb(image): 6 | return image.convert('RGB') 7 | 8 | 9 | def image_transform( 10 | image_size: int, 11 | is_train: bool, 12 | mean=(0.48145466, 0.4578275, 0.40821073), 13 | std=(0.26862954, 0.26130258, 0.27577711) 14 | ): 15 | normalize = Normalize(mean=mean, std=std) 16 | if is_train: 17 | return Compose([ 18 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 19 | _convert_to_rgb, 20 | ToTensor(), 21 | normalize, 22 | ]) 23 | else: 24 | return Compose([ 25 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 26 | CenterCrop(image_size), 27 | _convert_to_rgb, 28 | ToTensor(), 29 | normalize, 30 | ]) 31 | -------------------------------------------------------------------------------- /ldm/modules/encoders/open_clap/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.1' 2 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuhuadai/AudioLCM/be5a709a1020072e3ca2d66289724f15bb4c917c/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency 
yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses_audio.vqperceptual import DummyLoss 2 | 3 | # relative imports pain 4 | import os 5 | import sys 6 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'vggishish') 7 | sys.path.append(path) 8 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/contperceptual_dis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | 6 | sys.path.insert(0, '.') # nopep8 7 | from ldm.modules.losses_audio.vqperceptual import * 8 | from ldm.modules.discriminator.multi_window_disc import Discriminator 9 | 10 | class LPAPSWithDiscriminator(nn.Module):# 相比于contperceptual.py添加了MultiWindowDiscriminator 11 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 12 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 13 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 14 | disc_loss="hinge"): 15 | 16 | super().__init__() 17 | assert disc_loss in ["hinge", "vanilla"] 18 | self.kl_weight = kl_weight 19 | self.pixel_weight = pixelloss_weight 20 | self.perceptual_loss = LPAPS().eval() 21 | self.perceptual_weight = perceptual_weight 22 | # output log variance 23 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 24 | 25 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 26 | n_layers=disc_num_layers, 27 | use_actnorm=use_actnorm, 28 | ).apply(weights_init) 29 | self.discriminator_iter_start = disc_start 30 | if disc_loss == "hinge": 31 | self.disc_loss = hinge_d_loss 32 | elif disc_loss == "vanilla": 33 | self.disc_loss = vanilla_d_loss 34 | else: 35 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 36 | print(f"LPAPSWithDiscriminator running with {disc_loss} loss.") 37 | self.disc_factor = disc_factor 38 | self.discriminator_weight = disc_weight 39 | self.disc_conditional = disc_conditional 40 | 41 | disc_win_num = 3 42 | mel_disc_hidden_size = 128 43 | 
self.discriminator_multi = Discriminator(time_lengths=[32, 64, 128][:disc_win_num], 44 | freq_length=80, hidden_size=mel_disc_hidden_size, kernel=(3, 3), 45 | cond_size=0, norm_type="in", reduction="stack") 46 | 47 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 48 | if last_layer is not None: 49 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 50 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 51 | else: 52 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 53 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 54 | 55 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 56 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 57 | d_weight = d_weight * self.discriminator_weight 58 | return d_weight 59 | 60 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 61 | global_step, last_layer=None, cond=None, split="train", weights=None): 62 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 63 | if self.perceptual_weight > 0: 64 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 65 | rec_loss = rec_loss + self.perceptual_weight * p_loss 66 | else: 67 | p_loss = torch.tensor([0.0]) 68 | 69 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 70 | weighted_nll_loss = nll_loss 71 | if weights is not None: 72 | weighted_nll_loss = weights*nll_loss 73 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 74 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 75 | kl_loss = posteriors.kl() 76 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 77 | 78 | # now the GAN part 79 | if optimizer_idx == 0: 80 | # generator update 81 | if cond is None: 82 | assert not self.disc_conditional 83 | logits_fake = self.discriminator(reconstructions.contiguous()) 84 | else: 85 | assert self.disc_conditional 86 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 87 | 88 | logits_fake_multi = self.discriminator_multi(reconstructions.contiguous().squeeze(1).transpose(1, 2)) 89 | 90 | g_loss = -torch.mean(logits_fake) 91 | g_loss_multi = -torch.mean(logits_fake_multi['y']) 92 | 93 | try: 94 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 95 | d_weight_multi = self.calculate_adaptive_weight(nll_loss, g_loss_multi, last_layer=last_layer) 96 | except RuntimeError: 97 | assert not self.training 98 | d_weight = d_weight_multi = torch.tensor(0.0) 99 | 100 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 101 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + d_weight_multi * disc_factor * g_loss_multi 102 | 103 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 104 | "{}/logvar".format(split): self.logvar.detach(), 105 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 106 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 107 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 108 | "{}/d_weight".format(split): d_weight.detach(), 109 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 110 | "{}/g_loss".format(split): g_loss.detach().mean(), 111 | "{}/g_loss_multi".format(split): g_loss_multi.detach().mean(), 112 | } 113 | return loss, log 114 | 115 | if optimizer_idx == 1: 116 | # second pass for discriminator update 117 | if cond is None: 118 | 
logits_real = self.discriminator(inputs.contiguous().detach()) 119 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 120 | else: 121 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 123 | 124 | logits_real_multi = self.discriminator_multi(inputs.contiguous().detach().squeeze(1).transpose(1, 2)) 125 | logits_fake_multi = self.discriminator_multi(reconstructions.contiguous().detach().squeeze(1).transpose(1, 2)) 126 | 127 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 128 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 129 | d_loss_multi = disc_factor * self.disc_loss(logits_real_multi['y'], logits_fake_multi['y']) 130 | 131 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 132 | "{}/disc_loss_multi".format(split): d_loss_multi.clone().detach().mean(), 133 | "{}/logits_real".format(split): logits_real.detach().mean(), 134 | "{}/logits_fake".format(split): logits_fake.detach().mean() 135 | } 136 | return d_loss+d_loss_multi, log 137 | 138 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/contperceptual_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | 6 | sys.path.insert(0, '.') # nopep8 7 | from ldm.modules.losses_audio.vqperceptual import * 8 | 9 | def sequence_mask(length, max_length=None):# length shape (B,) 10 | if max_length is None: 11 | max_length = length.max() 12 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)# (max_length) 13 | return x.unsqueeze(0) < length.unsqueeze(1)# (B,max_length) 14 | 15 | class LPAPSWithDiscriminator(nn.Module): 16 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 17 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 18 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 19 | disc_loss="hinge",pad_value=-1): 20 | super().__init__() 21 | assert disc_loss in ["hinge", "vanilla"] 22 | self.pad_val = pad_value 23 | self.kl_weight = kl_weight 24 | self.pixel_weight = pixelloss_weight 25 | self.perceptual_weight = perceptual_weight 26 | if self.perceptual_weight > 0: 27 | self.perceptual_loss = LPAPS().eval()# LPIPS用于日常图像,而LPAPS用于梅尔谱图 28 | 29 | # output log variance 30 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 31 | 32 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 33 | n_layers=disc_num_layers, 34 | use_actnorm=use_actnorm, 35 | ).apply(weights_init) 36 | self.discriminator_iter_start = disc_start 37 | if disc_loss == "hinge": 38 | self.disc_loss = hinge_d_loss 39 | elif disc_loss == "vanilla": 40 | self.disc_loss = vanilla_d_loss 41 | else: 42 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 43 | print(f"LPAPSWithDiscriminator running with {disc_loss} loss.") 44 | self.disc_factor = disc_factor 45 | self.discriminator_weight = disc_weight 46 | self.disc_conditional = disc_conditional 47 | 48 | 49 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 50 | if last_layer is not None: 51 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 52 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 53 | 
else: 54 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 55 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 56 | 57 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 58 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 59 | d_weight = d_weight * self.discriminator_weight 60 | return d_weight 61 | 62 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 63 | global_step, last_layer=None, cond=None, split="train", weights=None): 64 | if len(inputs.shape) == 3: 65 | inputs,reconstructions = inputs.unsqueeze(1),reconstructions.unsqueeze(1) 66 | 67 | b,c,h,w = inputs.shape 68 | x_lengths = (inputs.mean(dim=(1,2)) > self.pad_val).long().sum(-1) 69 | x_mask = sequence_mask(x_lengths, max_length = w)[:,None,None,:].to(inputs.dtype)# (B,1,1,max_length), 0 is the padded place 70 | 71 | 72 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 73 | if self.perceptual_weight > 0: 74 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 75 | # print(f"p_loss {p_loss}") 76 | rec_loss = rec_loss + self.perceptual_weight * p_loss 77 | else: 78 | p_loss = torch.tensor([0.0]) 79 | 80 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 81 | weighted_nll_loss = nll_loss 82 | if weights is not None: 83 | weighted_nll_loss = weights*nll_loss 84 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 85 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 86 | kl_loss = posteriors.kl() 87 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 88 | 89 | # !!!!!!!!!!!! use the following line to avoid discriminator fail !!!!!!!!!!!!! 90 | reconstructions = reconstructions*x_mask + (1-x_mask)*self.pad_val 91 | # now the GAN part 92 | if optimizer_idx == 0: 93 | # generator update 94 | if cond is None: 95 | assert not self.disc_conditional 96 | logits_fake = self.discriminator(reconstructions.contiguous()) 97 | else: 98 | assert self.disc_conditional 99 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 100 | g_loss = -torch.mean(logits_fake) 101 | 102 | try: 103 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 104 | except RuntimeError: 105 | assert not self.training 106 | d_weight = torch.tensor(0.0) 107 | 108 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 109 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 110 | 111 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 112 | "{}/logvar".format(split): self.logvar.detach(), 113 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 114 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 115 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 116 | "{}/d_weight".format(split): d_weight.detach(), 117 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 118 | "{}/g_loss".format(split): g_loss.detach().mean(), 119 | } 120 | return loss, log 121 | 122 | if optimizer_idx == 1: 123 | # second pass for discriminator update 124 | if cond is None: 125 | logits_real = self.discriminator(inputs.contiguous().detach()) 126 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 127 | else: 128 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 129 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), 
dim=1)) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 133 | 134 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 135 | "{}/logits_real".format(split): logits_real.detach().mean(), 136 | "{}/logits_fake".format(split): logits_fake.detach().mean() 137 | } 138 | return d_loss, log 139 | 140 | 141 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/contperceptual_multiw.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | 6 | sys.path.insert(0, '.') # nopep8 7 | from ldm.modules.losses_audio.vqperceptual import * 8 | from ldm.modules.discriminator.multi_window_disc import Discriminator 9 | 10 | class LPAPSWithDiscriminator(nn.Module): 11 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 12 | time_lengths = [16,32,64], disc_factor=1.0, disc_weight=1.0, 13 | perceptual_weight=1.0, disc_conditional=False, 14 | disc_loss="hinge"): 15 | 16 | super().__init__() 17 | assert disc_loss in ["hinge", "vanilla"] 18 | self.kl_weight = kl_weight 19 | self.pixel_weight = pixelloss_weight 20 | self.perceptual_weight = perceptual_weight 21 | if self.perceptual_weight > 0: 22 | self.perceptual_loss = LPAPS().eval()# LPIPS用于日常图像,而LPAPS用于梅尔谱图 23 | 24 | # output log variance 25 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 26 | 27 | self.discriminator = Discriminator(time_lengths=time_lengths,reduction='stack').apply(weights_init) # h=8,w/(2**disc_num_layers) - 2 28 | self.discriminator_iter_start = disc_start 29 | if disc_loss == "hinge": 30 | self.disc_loss = hinge_d_loss 31 | elif disc_loss == "vanilla": 32 | self.disc_loss = vanilla_d_loss 33 | else: 34 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 35 | print(f"LPAPSWithDiscriminator running with {disc_loss} loss.") 36 | self.disc_factor = disc_factor 37 | self.discriminator_weight = disc_weight 38 | self.disc_conditional = disc_conditional 39 | 40 | 41 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 42 | if last_layer is not None: 43 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 44 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 45 | else: 46 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 47 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 48 | 49 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 50 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 51 | d_weight = d_weight * self.discriminator_weight 52 | return d_weight 53 | 54 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 55 | global_step, last_layer=None, cond=None, split="train", weights=None): 56 | if len(inputs.shape) == 3: # (B,melbins,T) 57 | inputs,reconstructions = inputs.unsqueeze(1),reconstructions.unsqueeze(1) 58 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 59 | if self.perceptual_weight > 0: 60 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 61 | rec_loss = rec_loss + self.perceptual_weight * p_loss 62 | else: 63 | p_loss = torch.tensor([0.0]) 64 | 65 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 66 | weighted_nll_loss 
= nll_loss 67 | if weights is not None: 68 | weighted_nll_loss = weights*nll_loss 69 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 70 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 71 | kl_loss = posteriors.kl() 72 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 73 | 74 | 75 | inputs,reconstructions = inputs.squeeze(1).transpose(1,2),reconstructions.squeeze(1).transpose(1,2) # (B,T,melbins) 76 | # now the GAN part 77 | if optimizer_idx == 0: 78 | # generator update 79 | if cond is None: 80 | assert not self.disc_conditional 81 | logits_fake = self.discriminator(reconstructions.contiguous())['y'] 82 | else: 83 | assert self.disc_conditional 84 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))['y'] 85 | g_loss = -torch.mean(logits_fake) # logits_fake the higher the better 86 | 87 | try: 88 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 89 | except RuntimeError: 90 | assert not self.training 91 | d_weight = torch.tensor(0.0) 92 | 93 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 94 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 95 | 96 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 97 | "{}/logvar".format(split): self.logvar.detach(), 98 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 99 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 100 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 101 | "{}/d_weight".format(split): d_weight.detach(), 102 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 103 | "{}/g_loss".format(split): g_loss.detach().mean(), 104 | } 105 | return loss, log 106 | 107 | if optimizer_idx == 1: 108 | # second pass for discriminator update 109 | if cond is None: 110 | logits_real = self.discriminator(inputs.contiguous().detach())['y'] 111 | logits_fake = self.discriminator(reconstructions.contiguous().detach())['y'] 112 | else: 113 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))['y'] 114 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))['y'] 115 | 116 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 117 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 118 | 119 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 120 | "{}/logits_real".format(split): logits_real.detach().mean(), 121 | "{}/logits_fake".format(split): logits_fake.detach().mean() 122 | } 123 | return d_loss, log 124 | 125 | 126 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/lpaps.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on https://github.com/CompVis/taming-transformers/blob/52720829/taming/modules/losses/lpips.py 3 | Adapted for spectrograms by Vladimir Iashin (v-iashin) 4 | """ 5 | from collections import namedtuple 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | import sys 12 | sys.path.insert(0, '.') # nopep8 13 | # from ldm.modules.losses_audio.vggishish.model import VGGishish 14 | from ldm.util import get_ckpt_path 15 | 16 | 17 | class LPAPS(nn.Module):# this model is trained on 80melbins22050hz mel 18 | # Learned perceptual metric 19 | def __init__(self, use_dropout=True): 20 | super().__init__() 21 | 
self.scaling_layer = ScalingLayer() 22 | self.chns = [64, 128, 256, 512, 512] # vggish16 features 23 | self.net = vggishish16(pretrained=True, requires_grad=False) 24 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 25 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 26 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 27 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 28 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 29 | self.load_from_pretrained() 30 | for param in self.parameters(): 31 | param.requires_grad = False 32 | 33 | def load_from_pretrained(self, name="vggishish_lpaps"): 34 | ckpt = get_ckpt_path(name, "ldm/modules/autoencoder/lpaps") 35 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 36 | print("loaded pretrained LPAPS loss from {}".format(ckpt)) 37 | 38 | @classmethod 39 | def from_pretrained(cls, name="vggishish_lpaps"): 40 | if name != "vggishish_lpaps": 41 | raise NotImplementedError 42 | model = cls() 43 | ckpt = get_ckpt_path(name) 44 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 45 | return model 46 | 47 | def forward(self, input, target): 48 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 49 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 50 | feats0, feats1, diffs = {}, {}, {} 51 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 52 | for kk in range(len(self.chns)): 53 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 54 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 55 | 56 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 57 | val = res[0] 58 | for l in range(1, len(self.chns)): 59 | val += res[l] 60 | return val 61 | 62 | class ScalingLayer(nn.Module): 63 | def __init__(self): 64 | super(ScalingLayer, self).__init__() 65 | # we are gonna use get_ckpt_path to donwload the stats as well 66 | stat_path = get_ckpt_path('vggishish_mean_std_melspec_10s_22050hz', 'ldm/modules/autoencoder/lpaps') 67 | # if for images we normalize on the channel dim, in spectrogram we will norm on frequency dimension 68 | means, stds = np.loadtxt(stat_path, dtype=np.float32).T 69 | # the normalization in means and stds are given for [0, 1], but specvqgan expects [-1, 1]: 70 | means = 2 * means - 1 71 | stds = 2 * stds 72 | # input is expected to be (B, 1, F, T) 73 | self.register_buffer('shift', torch.from_numpy(means)[None, None, :, None]) 74 | self.register_buffer('scale', torch.from_numpy(stds)[None, None, :, None]) 75 | 76 | def forward(self, inp): 77 | return (inp - self.shift) / self.scale 78 | 79 | 80 | class NetLinLayer(nn.Module): 81 | """ A single linear layer which does a 1x1 conv """ 82 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 83 | super(NetLinLayer, self).__init__() 84 | layers = [nn.Dropout(), ] if (use_dropout) else [] 85 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 86 | self.model = nn.Sequential(*layers) 87 | 88 | class vggishish16(torch.nn.Module): 89 | def __init__(self, requires_grad=False, pretrained=True): 90 | super().__init__() 91 | vgg_pretrained_features = self.vggishish16(pretrained=pretrained).features 92 | self.slice1 = torch.nn.Sequential() 93 | self.slice2 = torch.nn.Sequential() 94 | self.slice3 = torch.nn.Sequential() 95 | self.slice4 = torch.nn.Sequential() 96 | self.slice5 = torch.nn.Sequential() 97 | 
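To make the `LPAPS.forward` pass above concrete: features from each network depth are channel-normalized, differenced, squared, spatially averaged, and summed over depths. The sketch below uses random tensors as stand-ins for the VGGish activations and replaces the learned 1×1 `NetLinLayer` projection with a plain channel mean for brevity, so it approximates the shape flow rather than the trained metric.

```python
import torch

def normalize_tensor(x, eps=1e-10):
    # unit-normalize each feature vector along the channel dimension
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)

# Random stand-ins for activations of two spectrograms at three network depths.
feats_a = [torch.randn(4, c, 20, 106) for c in (64, 128, 256)]
feats_b = [torch.randn(4, c, 20, 106) for c in (64, 128, 256)]

val = 0.0
for fa, fb in zip(feats_a, feats_b):
    diff = (normalize_tensor(fa) - normalize_tensor(fb)) ** 2
    # LPAPS applies a learned 1x1 conv here; a channel mean is the simplest stand-in.
    val = val + diff.mean(dim=1, keepdim=True).mean(dim=[2, 3], keepdim=True)

print(val.shape)  # torch.Size([4, 1, 1, 1]): one perceptual distance per batch item
```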
self.N_slices = 5 98 | for x in range(4): 99 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 100 | for x in range(4, 9): 101 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 102 | for x in range(9, 16): 103 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 104 | for x in range(16, 23): 105 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 106 | for x in range(23, 30): 107 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 108 | if not requires_grad: 109 | for param in self.parameters(): 110 | param.requires_grad = False 111 | 112 | def forward(self, X): 113 | h = self.slice1(X) 114 | h_relu1_2 = h 115 | h = self.slice2(h) 116 | h_relu2_2 = h 117 | h = self.slice3(h) 118 | h_relu3_3 = h 119 | h = self.slice4(h) 120 | h_relu4_3 = h 121 | h = self.slice5(h) 122 | h_relu5_3 = h 123 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 124 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 125 | return out 126 | 127 | def vggishish16(self, pretrained: bool = True) -> VGGishish: 128 | # loading vggishish pretrained on vggsound 129 | num_classes_vggsound = 309 130 | conv_layers = [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512] 131 | model = VGGishish(conv_layers, use_bn=False, num_classes=num_classes_vggsound) 132 | if pretrained: 133 | ckpt_path = get_ckpt_path('vggishish_lpaps', "ldm/modules/autoencoder/lpaps") 134 | ckpt = torch.load(ckpt_path, map_location=torch.device("cpu")) 135 | model.load_state_dict(ckpt, strict=False) 136 | return model 137 | 138 | def normalize_tensor(x, eps=1e-10): 139 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 140 | return x / (norm_factor+eps) 141 | 142 | def spatial_average(x, keepdim=True): 143 | return x.mean([2, 3], keepdim=keepdim) 144 | 145 | 146 | if __name__ == '__main__': 147 | inputs = torch.rand((16, 1, 80, 848)) 148 | reconstructions = torch.rand((16, 1, 80, 848)) 149 | lpips = LPAPS().eval() 150 | loss_p = lpips(inputs.contiguous(), reconstructions.contiguous()) 151 | # (16, 1, 1, 1) 152 | print(loss_p.shape) 153 | -------------------------------------------------------------------------------- /ldm/modules/losses_audio/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | from ldm.util import exists 6 | sys.path.insert(0, '.') # nopep8 7 | from ldm.modules.discriminator.model import (NLayerDiscriminator, NLayerDiscriminator1dFeats, 8 | NLayerDiscriminator1dSpecs, 9 | weights_init) 10 | # from ldm.modules.losses_audio.lpaps import LPAPS 11 | from ldm.modules.losses.vqperceptual import l1, l2, measure_perplexity, hinge_d_loss, vanilla_d_loss, adopt_weight 12 | 13 | 14 | 15 | class DummyLoss(nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | class VQLPAPSWithDiscriminator(nn.Module): 20 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 21 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 22 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 23 | disc_ndf=64, disc_loss="hinge", n_classes=None, pixel_loss="l1"): 24 | super().__init__() 25 | assert disc_loss in ["hinge", "vanilla"] 26 | self.codebook_weight = codebook_weight 27 | self.pixel_weight = pixelloss_weight 28 | self.perceptual_loss = None # LPAPS().eval() 29 | self.perceptual_weight = 
perceptual_weight 30 | 31 | if pixel_loss == "l1": 32 | self.pixel_loss = l1 33 | else: 34 | self.pixel_loss = l2 35 | 36 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 37 | n_layers=disc_num_layers, 38 | use_actnorm=use_actnorm, 39 | ndf=disc_ndf 40 | ).apply(weights_init) 41 | self.discriminator_iter_start = disc_start 42 | if disc_loss == "hinge": 43 | self.disc_loss = hinge_d_loss 44 | elif disc_loss == "vanilla": 45 | self.disc_loss = vanilla_d_loss 46 | else: 47 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 48 | print(f"VQLPAPSWithDiscriminator running with {disc_loss} loss.") 49 | self.disc_factor = disc_factor 50 | self.discriminator_weight = disc_weight 51 | self.disc_conditional = disc_conditional 52 | self.n_classes = n_classes 53 | 54 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 55 | if last_layer is not None: 56 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 57 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 58 | else: 59 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 60 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 61 | 62 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 63 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 64 | d_weight = d_weight * self.discriminator_weight 65 | return d_weight 66 | 67 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 68 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 69 | if not exists(codebook_loss): 70 | codebook_loss = torch.tensor([0.]).to(inputs.device) 71 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 72 | if self.perceptual_weight > 0: 73 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 74 | rec_loss = rec_loss + self.perceptual_weight * p_loss 75 | else: 76 | p_loss = torch.tensor([0.0]) 77 | 78 | nll_loss = rec_loss 79 | # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 80 | nll_loss = torch.mean(nll_loss) 81 | 82 | # now the GAN part 83 | if optimizer_idx == 0: 84 | # generator update 85 | if cond is None: 86 | assert not self.disc_conditional 87 | logits_fake = self.discriminator(reconstructions.contiguous()) 88 | else: 89 | assert self.disc_conditional 90 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 91 | g_loss = -torch.mean(logits_fake) 92 | 93 | try: 94 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 95 | except RuntimeError: 96 | assert not self.training 97 | d_weight = torch.tensor(0.0) 98 | 99 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 100 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 101 | 102 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 103 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 104 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 105 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 106 | "{}/p_loss".format(split): p_loss.detach().mean(), 107 | "{}/d_weight".format(split): d_weight.detach(), 108 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 109 | "{}/g_loss".format(split): g_loss.detach().mean(), 110 | } 111 | # if predicted_indices is not None: 112 | # assert self.n_classes is not None 113 | # with torch.no_grad(): 114 
| # perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 115 | # log[f"{split}/perplexity"] = perplexity 116 | # log[f"{split}/cluster_usage"] = cluster_usage 117 | return loss, log 118 | 119 | if optimizer_idx == 1: 120 | # second pass for discriminator update 121 | if cond is None: 122 | logits_real = self.discriminator(inputs.contiguous().detach()) 123 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 124 | else: 125 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 126 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 127 | 128 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 129 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 130 | 131 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 132 | "{}/logits_real".format(split): logits_real.detach().mean(), 133 | "{}/logits_fake".format(split): logits_fake.detach().mean() 134 | } 135 | return d_loss, log 136 | 137 | -------------------------------------------------------------------------------- /pythonscripts/InferAPI.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import pathlib 3 | directory = pathlib.Path(os.getcwd()) 4 | print(directory) 5 | sys.path.append(str(directory)) 6 | import torch 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | from PIL import Image 10 | from tqdm import tqdm, trange 11 | from ldm.util import instantiate_from_config 12 | from ldm.models.diffusion.scheduling_lcm import LCMSampler 13 | from ldm.models.diffusion.plms import PLMSSampler 14 | import pandas as pd 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | from icecream import ic 18 | from pathlib import Path 19 | import soundfile as sf 20 | import yaml 21 | import datetime 22 | from vocoder.bigvgan.models import VocoderBigVGAN 23 | import soundfile 24 | # from pytorch_memlab import LineProfiler,profile 25 | 26 | def load_model_from_config(config, ckpt = None, verbose=True): 27 | model = instantiate_from_config(config.model) 28 | if ckpt: 29 | print(f"Loading model from {ckpt}") 30 | pl_sd = torch.load(ckpt, map_location="cpu") 31 | sd = pl_sd["state_dict"] 32 | 33 | m, u = model.load_state_dict(sd, strict=False) 34 | if len(m) > 0 and verbose: 35 | print("missing keys:") 36 | print(m) 37 | if len(u) > 0 and verbose: 38 | print("unexpected keys:") 39 | print(u) 40 | else: 41 | print(f"Note chat no ckpt is loaded !!!") 42 | 43 | model.cuda() 44 | model.eval() 45 | return model 46 | 47 | 48 | 49 | 50 | class GenSamples: 51 | def __init__(self,sampler,model,outpath,vocoder = None,save_mel = True,save_wav = True, original_inference_steps=None) -> None: 52 | self.sampler = sampler 53 | self.model = model 54 | self.outpath = outpath 55 | if save_wav: 56 | assert vocoder is not None 57 | self.vocoder = vocoder 58 | self.save_mel = save_mel 59 | self.save_wav = save_wav 60 | self.channel_dim = self.model.channels 61 | self.original_inference_steps = original_inference_steps 62 | 63 | def gen_test_sample(self,prompt,mel_name = None,wav_name = None):# prompt is {'ori_caption':’xxx‘,'struct_caption':'xxx'} 64 | uc = None 65 | record_dicts = [] 66 | # if os.path.exists(os.path.join(self.outpath,mel_name+f'_0.npy')): 67 | # return record_dicts 68 | emptycap = {'ori_caption':1*[""],'struct_caption':1*[""]} 69 | uc = 
self.model.get_learned_conditioning(emptycap) 70 | 71 | for n in range(1):# trange(self.opt.n_iter, desc="Sampling"): 72 | for k,v in prompt.items(): 73 | prompt[k] = 1 * [v] 74 | c = self.model.get_learned_conditioning(prompt)# shape:[1,77,1280],即还没有变成句子embedding,仍是每个单词的embedding 75 | if self.channel_dim>0: 76 | shape = [self.channel_dim, 20, 312] # (z_dim, 80//2^x, 848//2^x) 77 | else: 78 | shape = [20, 312] 79 | samples_ddim, _ = self.sampler.sample(S=2, 80 | conditioning=c, 81 | batch_size=1, 82 | shape=shape, 83 | verbose=False, 84 | guidance_scale=5, 85 | original_inference_steps=self.original_inference_steps 86 | ) 87 | x_samples_ddim = self.model.decode_first_stage(samples_ddim) 88 | for idx,spec in enumerate(x_samples_ddim): 89 | spec = spec.squeeze(0).cpu().numpy() 90 | record_dict = {'caption':prompt['ori_caption'][0]} 91 | if self.save_mel: 92 | mel_path = os.path.join(self.outpath,mel_name+f'_{idx}.npy') 93 | np.save(mel_path,spec) 94 | record_dict['mel_path'] = mel_path 95 | if self.save_wav: 96 | wav = self.vocoder.vocode(spec) 97 | wav_path = os.path.join(self.outpath,wav_name+f'_{idx}.wav') 98 | soundfile.write(wav_path, wav, 16000) 99 | record_dict['audio_path'] = wav_path 100 | record_dicts.append(record_dict) 101 | return record_dicts 102 | 103 | def AudioLCMInfer(ori_prompt, config_path = "configs/audiolcm.yaml", model_path = "./model/000184.ckpt", vocoder_path = "./model/vocoder"): 104 | 105 | prompt = dict(ori_caption=ori_prompt,struct_caption=f'<{ori_prompt}& all>') 106 | 107 | 108 | config = OmegaConf.load(config_path) 109 | 110 | # print("-------quick debug no load ckpt---------") 111 | # model = instantiate_from_config(config['model'])# for quick debug 112 | model = load_model_from_config(config, model_path) 113 | 114 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 115 | model = model.to(device) 116 | 117 | sampler = LCMSampler(model) 118 | 119 | os.makedirs("results/test", exist_ok=True) 120 | 121 | vocoder = VocoderBigVGAN(vocoder_path,device) 122 | 123 | 124 | generator = GenSamples(sampler,model,"results/test",vocoder,save_mel = False,save_wav = True, original_inference_steps=config.model.params.num_ddim_timesteps) 125 | csv_dicts = [] 126 | 127 | with torch.no_grad(): 128 | with model.ema_scope(): 129 | wav_name = f'{prompt["ori_caption"].strip().replace(" ", "-")}' 130 | generator.gen_test_sample(prompt,wav_name=wav_name) 131 | 132 | print(f"Your samples are ready and waiting four you here: \nresults/test \nEnjoy.") 133 | return "results/test/"+wav_name+"_0.wav" 134 | 135 | def AudioLCMBatchInfer(ori_prompts, config_path = "configs/audiolcm.yaml", model_path = "./model/000184.ckpt", vocoder_path = "./model/vocoder"): 136 | 137 | prompts = [dict(ori_caption=ori_prompt,struct_caption=f'<{ori_prompt}& all>') for ori_prompt in ori_prompts] 138 | 139 | 140 | config = OmegaConf.load(config_path) 141 | 142 | # print("-------quick debug no load ckpt---------") 143 | # model = instantiate_from_config(config['model'])# for quick debug 144 | model = load_model_from_config(config, model_path) 145 | 146 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 147 | model = model.to(device) 148 | 149 | sampler = LCMSampler(model) 150 | 151 | os.makedirs("results/test", exist_ok=True) 152 | 153 | vocoder = VocoderBigVGAN(vocoder_path,device) 154 | 155 | 156 | generator = GenSamples(sampler,model,"results/test",vocoder,save_mel = False,save_wav = True, 
original_inference_steps=config.model.params.num_ddim_timesteps) 157 | csv_dicts = [] 158 | 159 | for prompt in prompts: 160 | with torch.no_grad(): 161 | with model.ema_scope(): 162 | wav_name = f'{prompt["ori_caption"].strip().replace(" ", "-")}' 163 | generator.gen_test_sample(prompt,wav_name=wav_name) 164 | 165 | print(f"Your samples are ready and waiting four you here: \nresults/test \nEnjoy.") 166 | return "results/test/"+wav_name+"_0.wav" 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /pythonscripts/__pycache__/InferAPI.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuhuadai/AudioLCM/be5a709a1020072e3ca2d66289724f15bb4c917c/pythonscripts/__pycache__/InferAPI.cpython-38.pyc -------------------------------------------------------------------------------- /pythonscripts/__pycache__/txt2audio_for_2cap.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuhuadai/AudioLCM/be5a709a1020072e3ca2d66289724f15bb4c917c/pythonscripts/__pycache__/txt2audio_for_2cap.cpython-37.pyc -------------------------------------------------------------------------------- /pythonscripts/reconstruct_audio.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import pathlib 3 | directory = pathlib.Path(os.getcwd()) 4 | print(directory) 5 | sys.path.append(str(directory)) 6 | import torch 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | from PIL import Image 10 | from tqdm import tqdm, trange 11 | from ldm.util import instantiate_from_config 12 | from ldm.models.diffusion.ddim import DDIMSampler 13 | from ldm.models.diffusion.plms import PLMSSampler 14 | import pandas as pd 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | from icecream import ic 18 | from pathlib import Path 19 | import yaml 20 | from vocoder.bigvgan.models import VocoderBigVGAN 21 | import soundfile 22 | # from pytorch_memlab import LineProfiler,profile 23 | 24 | def load_model_from_config(config, ckpt = None, verbose=True): 25 | model = instantiate_from_config(config.model) 26 | if ckpt: 27 | print(f"Loading model from {ckpt}") 28 | pl_sd = torch.load(ckpt, map_location="cpu") 29 | sd = pl_sd["state_dict"] 30 | 31 | m, u = model.load_state_dict(sd, strict=False) 32 | if len(m) > 0 and verbose: 33 | print("missing keys:") 34 | print(m) 35 | if len(u) > 0 and verbose: 36 | print("unexpected keys:") 37 | print(u) 38 | else: 39 | print(f"Note chat no ckpt is loaded !!!") 40 | 41 | model.cuda() 42 | model.eval() 43 | return model 44 | 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser() 48 | 49 | parser.add_argument( 50 | "--sample_rate", 51 | type=int, 52 | default="16000", 53 | help="sample rate of wav" 54 | ) 55 | 56 | parser.add_argument( 57 | "--test-dataset", 58 | default="none", 59 | help="test which dataset: audiocaps/clotho/fsd50k" 60 | ) 61 | parser.add_argument( 62 | "--outdir", 63 | type=str, 64 | nargs="?", 65 | help="dir to write results to", 66 | default="outputs/txt2audio-samples" 67 | ) 68 | 69 | 70 | 71 | parser.add_argument( 72 | "-r", 73 | "--resume", 74 | type=str, 75 | const=True, 76 | default="", 77 | nargs="?", 78 | help="resume from logdir or checkpoint in logdir", 79 | ) 80 | parser.add_argument( 81 | "-b", 82 | "--base", 83 | type=str, 84 
| help="paths to base configs. Loaded from left-to-right. " 85 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 86 | default="", 87 | ) 88 | parser.add_argument( 89 | "--vocoder-ckpt", 90 | type=str, 91 | help="paths to vocoder checkpoint", 92 | default='vocoder/logs/bigvnat16k93.5w', 93 | ) 94 | 95 | return parser.parse_args() 96 | 97 | class GenSamples: 98 | def __init__(self,opt,model,outpath,vocoder = None,save_mel = False,save_wav = True) -> None: 99 | self.opt = opt 100 | self.model = model 101 | self.outpath = outpath 102 | if save_wav: 103 | assert vocoder is not None 104 | self.vocoder = vocoder 105 | self.save_mel = save_mel 106 | self.save_wav = save_wav 107 | 108 | def gen_test_sample(self,mel,mel_name = None,wav_name = None):# prompt is {'ori_caption':’xxx‘,'struct_caption':'xxx'} 109 | uc = None 110 | record_dicts = [] 111 | # if os.path.exists(os.path.join(self.outpath,mel_name+f'_0.npy')): 112 | # return record_dicts 113 | # import ipdb 114 | # ipdb.set_trace() 115 | recon_mel,posterior = self.model(mel) 116 | spec = recon_mel.squeeze(0).cpu().numpy() 117 | 118 | 119 | if self.save_wav: 120 | wav = self.vocoder.vocode(spec) 121 | wav_path = os.path.join(self.outpath,wav_name+'.wav') 122 | soundfile.write(wav_path, wav, self.opt.sample_rate) 123 | return 124 | 125 | def main(): 126 | opt = parse_args() 127 | 128 | config = OmegaConf.load(opt.base) 129 | # print("-------quick debug no load ckpt---------") 130 | # model = instantiate_from_config(config['model'])# for quick debug 131 | model = load_model_from_config(config, opt.resume) 132 | 133 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 134 | model = model.to(device) 135 | 136 | 137 | os.makedirs(opt.outdir, exist_ok=True) 138 | if 'mel' in opt.vocoder_ckpt: 139 | vocoder = VocoderMelGan(opt.vocoder_ckpt,device) 140 | elif 'hifi' in opt.vocoder_ckpt: 141 | vocoder = VocoderHifigan(opt.vocoder_ckpt,device) 142 | elif 'bigv' in opt.vocoder_ckpt: 143 | vocoder = VocoderBigVGAN(opt.vocoder_ckpt,device) 144 | 145 | 146 | generator = GenSamples(opt,model,opt.outdir,vocoder,save_mel = False,save_wav = True) 147 | csv_dicts = [] 148 | 149 | with torch.no_grad(): 150 | if opt.test_dataset != 'none': 151 | if opt.test_dataset == 'audiocaps': 152 | test_dataset = instantiate_from_config(config['test_dataset']) 153 | elif opt.test_dataset == 'clotho': 154 | test_dataset = instantiate_from_config(config['test_dataset2']) 155 | elif opt.test_dataset == 'fsd50k': 156 | test_dataset = instantiate_from_config(config['test_dataset3']) 157 | elif opt.test_dataset == 'musiccap': 158 | test_dataset = instantiate_from_config(config['test_dataset']) 159 | print(f"Dataset: {type(test_dataset)} LEN: {len(test_dataset)}") 160 | for item in tqdm(test_dataset): 161 | mel,f_name = item['image'],item['f_name'] 162 | mel = torch.from_numpy(mel).to(device).unsqueeze(0) 163 | vname_num_split_index = f_name.rfind('_')# file_names[b]:video_name+'_'+num 164 | v_n,num = f_name[:vname_num_split_index],f_name[vname_num_split_index+1:] 165 | mel_name = f'{v_n}_sample_{num}' 166 | wav_name = f'{v_n}_sample_{num}' 167 | generator.gen_test_sample(mel,mel_name=mel_name,wav_name=wav_name) 168 | # write_gt_wav(v_n,opt.test_dataset2,opt.outdir,opt.sample_rate) 169 | # csv_dicts.extend(generator.gen_test_sample(mel,mel_name=mel_name,wav_name=wav_name)) 170 | 171 | # df = pd.DataFrame.from_dict(csv_dicts) 172 | # df.to_csv(os.path.join(opt.outdir,'result.csv'),sep='\t',index=False) 173 
| else: 174 | with open(opt.prompt_txt,'r') as f: 175 | prompts = f.readlines() 176 | for prompt in prompts: 177 | wav_name = f'{prompt.strip().replace(" ", "-")}' 178 | generator.gen_test_sample(prompt,wav_name=wav_name) 179 | 180 | print(f"Your samples are ready and waiting four you here: \n{opt.outdir} \nEnjoy.") 181 | 182 | if __name__ == "__main__": 183 | main() 184 | 185 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # torch==1.12.1+cu113 2 | # torchaudio==0.12.1+cu113 3 | # torchvision==0.13.1+cu113 4 | torchlibrosa==0.1.0 5 | pytorch-lightning==1.7.0 6 | # librosa==0.8.0 7 | # soundfile==0.11.0 8 | # omegaconf==2.3.0 9 | jupyter==1.0.0 10 | icecream 11 | torchmetrics==0.11.4 12 | huggingface_hub==0.20.2 13 | calmsize 14 | librosa==0.10.1 15 | # tqdm 16 | # numpy 17 | taming-transformers-rom1504 18 | -------------------------------------------------------------------------------- /scripts/reconstruct_audio.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import pathlib 3 | directory = pathlib.Path(os.getcwd()) 4 | print(directory) 5 | sys.path.append(str(directory)) 6 | import torch 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | from PIL import Image 10 | from tqdm import tqdm, trange 11 | from ldm.util import instantiate_from_config 12 | from ldm.models.diffusion.ddim import DDIMSampler 13 | from ldm.models.diffusion.plms import PLMSSampler 14 | import pandas as pd 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | from icecream import ic 18 | from pathlib import Path 19 | import yaml 20 | from vocoder.bigvgan.models import VocoderBigVGAN 21 | import soundfile 22 | # from pytorch_memlab import LineProfiler,profile 23 | 24 | def load_model_from_config(config, ckpt = None, verbose=True): 25 | model = instantiate_from_config(config.model) 26 | if ckpt: 27 | print(f"Loading model from {ckpt}") 28 | pl_sd = torch.load(ckpt, map_location="cpu") 29 | sd = pl_sd["state_dict"] 30 | 31 | m, u = model.load_state_dict(sd, strict=False) 32 | if len(m) > 0 and verbose: 33 | print("missing keys:") 34 | print(m) 35 | if len(u) > 0 and verbose: 36 | print("unexpected keys:") 37 | print(u) 38 | else: 39 | print(f"Note chat no ckpt is loaded !!!") 40 | 41 | model.cuda() 42 | model.eval() 43 | return model 44 | 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser() 48 | 49 | parser.add_argument( 50 | "--sample_rate", 51 | type=int, 52 | default="16000", 53 | help="sample rate of wav" 54 | ) 55 | 56 | parser.add_argument( 57 | "--test-dataset", 58 | default="none", 59 | help="test which dataset: audiocaps/clotho/fsd50k" 60 | ) 61 | parser.add_argument( 62 | "--outdir", 63 | type=str, 64 | nargs="?", 65 | help="dir to write results to", 66 | default="outputs/txt2audio-samples" 67 | ) 68 | 69 | 70 | 71 | parser.add_argument( 72 | "-r", 73 | "--resume", 74 | type=str, 75 | const=True, 76 | default="", 77 | nargs="?", 78 | help="resume from logdir or checkpoint in logdir", 79 | ) 80 | parser.add_argument( 81 | "-b", 82 | "--base", 83 | type=str, 84 | help="paths to base configs. Loaded from left-to-right. 
" 85 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 86 | default="", 87 | ) 88 | parser.add_argument( 89 | "--vocoder-ckpt", 90 | type=str, 91 | help="paths to vocoder checkpoint", 92 | default='vocoder/logs/bigvnat16k93.5w', 93 | ) 94 | 95 | return parser.parse_args() 96 | 97 | class GenSamples: 98 | def __init__(self,opt,model,outpath,vocoder = None,save_mel = False,save_wav = True) -> None: 99 | self.opt = opt 100 | self.model = model 101 | self.outpath = outpath 102 | if save_wav: 103 | assert vocoder is not None 104 | self.vocoder = vocoder 105 | self.save_mel = save_mel 106 | self.save_wav = save_wav 107 | 108 | def gen_test_sample(self,mel,mel_name = None,wav_name = None):# prompt is {'ori_caption':’xxx‘,'struct_caption':'xxx'} 109 | uc = None 110 | record_dicts = [] 111 | # if os.path.exists(os.path.join(self.outpath,mel_name+f'_0.npy')): 112 | # return record_dicts 113 | # import ipdb 114 | # ipdb.set_trace() 115 | recon_mel,posterior = self.model(mel) 116 | spec = recon_mel.squeeze(0).cpu().numpy() 117 | 118 | 119 | if self.save_wav: 120 | wav = self.vocoder.vocode(spec) 121 | wav_path = os.path.join(self.outpath,wav_name+'.wav') 122 | soundfile.write(wav_path, wav, self.opt.sample_rate) 123 | return 124 | 125 | def main(): 126 | opt = parse_args() 127 | 128 | config = OmegaConf.load(opt.base) 129 | # print("-------quick debug no load ckpt---------") 130 | # model = instantiate_from_config(config['model'])# for quick debug 131 | model = load_model_from_config(config, opt.resume) 132 | 133 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 134 | model = model.to(device) 135 | 136 | 137 | os.makedirs(opt.outdir, exist_ok=True) 138 | if 'mel' in opt.vocoder_ckpt: 139 | vocoder = VocoderMelGan(opt.vocoder_ckpt,device) 140 | elif 'hifi' in opt.vocoder_ckpt: 141 | vocoder = VocoderHifigan(opt.vocoder_ckpt,device) 142 | elif 'bigv' in opt.vocoder_ckpt: 143 | vocoder = VocoderBigVGAN(opt.vocoder_ckpt,device) 144 | 145 | 146 | generator = GenSamples(opt,model,opt.outdir,vocoder,save_mel = False,save_wav = True) 147 | csv_dicts = [] 148 | 149 | with torch.no_grad(): 150 | if opt.test_dataset != 'none': 151 | if opt.test_dataset == 'audiocaps': 152 | test_dataset = instantiate_from_config(config['test_dataset']) 153 | elif opt.test_dataset == 'clotho': 154 | test_dataset = instantiate_from_config(config['test_dataset2']) 155 | elif opt.test_dataset == 'fsd50k': 156 | test_dataset = instantiate_from_config(config['test_dataset3']) 157 | elif opt.test_dataset == 'musiccap': 158 | test_dataset = instantiate_from_config(config['test_dataset']) 159 | print(f"Dataset: {type(test_dataset)} LEN: {len(test_dataset)}") 160 | for item in tqdm(test_dataset): 161 | mel,f_name = item['image'],item['f_name'] 162 | mel = torch.from_numpy(mel).to(device).unsqueeze(0) 163 | vname_num_split_index = f_name.rfind('_')# file_names[b]:video_name+'_'+num 164 | v_n,num = f_name[:vname_num_split_index],f_name[vname_num_split_index+1:] 165 | mel_name = f'{v_n}_sample_{num}' 166 | wav_name = f'{v_n}_sample_{num}' 167 | generator.gen_test_sample(mel,mel_name=mel_name,wav_name=wav_name) 168 | # write_gt_wav(v_n,opt.test_dataset2,opt.outdir,opt.sample_rate) 169 | # csv_dicts.extend(generator.gen_test_sample(mel,mel_name=mel_name,wav_name=wav_name)) 170 | 171 | # df = pd.DataFrame.from_dict(csv_dicts) 172 | # df.to_csv(os.path.join(opt.outdir,'result.csv'),sep='\t',index=False) 173 | else: 174 | with open(opt.prompt_txt,'r') as f: 175 | 
prompts = f.readlines() 176 | for prompt in prompts: 177 | wav_name = f'{prompt.strip().replace(" ", "-")}' 178 | generator.gen_test_sample(prompt,wav_name=wav_name) 179 | 180 | print(f"Your samples are ready and waiting four you here: \n{opt.outdir} \nEnjoy.") 181 | 182 | if __name__ == "__main__": 183 | main() 184 | 185 | -------------------------------------------------------------------------------- /vocoder/bigvgan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 NVIDIA CORPORATION. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /vocoder/bigvgan/README.md: -------------------------------------------------------------------------------- 1 | ## BigVGAN: A Universal Neural Vocoder with Large-Scale Training 2 | #### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon 3 | 4 |
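Within AudioLCM, this vocoder is driven through the `VocoderBigVGAN` wrapper used by the inference and reconstruction scripts above. A minimal mel-to-waveform sketch follows; the checkpoint directory, the `.npy` file name, and the (80, T) mel layout are assumptions based on those scripts rather than a fixed API contract.

```python
import numpy as np
import soundfile
import torch
from vocoder.bigvgan.models import VocoderBigVGAN

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
vocoder = VocoderBigVGAN("vocoder/logs/bigvnat16k93.5w", device)  # same ckpt dir as the scripts

spec = np.load("results/test/example_0.npy")  # hypothetical mel saved by the sampler, (80, T)
wav = vocoder.vocode(spec)                    # mel spectrogram -> 16 kHz waveform samples
soundfile.write("example_0.wav", wav, 16000)
```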
5 | 6 | 7 | ### [Paper](https://arxiv.org/abs/2206.04658) 8 | ### [Audio demo](https://bigvgan-demo.github.io/) 9 | 10 | ## Installation 11 | Clone the repository and install dependencies. 12 | ```shell 13 | # the codebase has been tested on Python 3.8 / 3.10 with PyTorch 1.12.1 / 1.13 conda binaries 14 | git clone https://github.com/NVIDIA/BigVGAN 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset. 19 | ``` shell 20 | cd LibriTTS && \ 21 | ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \ 22 | ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \ 23 | ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \ 24 | ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \ 25 | ln -s /path/to/your/LibriTTS/dev-other dev-other && \ 26 | ln -s /path/to/your/LibriTTS/test-clean test-clean && \ 27 | ln -s /path/to/your/LibriTTS/test-other test-other && \ 28 | cd .. 29 | ``` 30 | 31 | ## Training 32 | Train BigVGAN model. Below is an example command for training BigVGAN using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input. 33 | ```shell 34 | python train.py \ 35 | --config configs/bigvgan_24khz_100band.json \ 36 | --input_wavs_dir LibriTTS \ 37 | --input_training_file LibriTTS/train-full.txt \ 38 | --input_validation_file LibriTTS/val-full.txt \ 39 | --list_input_unseen_wavs_dir LibriTTS LibriTTS \ 40 | --list_input_unseen_validation_file LibriTTS/dev-clean.txt LibriTTS/dev-other.txt \ 41 | --checkpoint_path exp/bigvgan 42 | ``` 43 | 44 | ## Synthesis 45 | Synthesize from BigVGAN model. Below is an example command for generating audio from the model. 46 | It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`. 47 | ```shell 48 | python inference.py \ 49 | --checkpoint_file exp/bigvgan/g_05000000 \ 50 | --input_wavs_dir /path/to/your/input_wav \ 51 | --output_dir /path/to/your/output_wav 52 | ``` 53 | 54 | `inference_e2e.py` supports synthesis directly from the mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`. 55 | It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`. 56 | 57 | Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model. 58 | ```shell 59 | python inference_e2e.py \ 60 | --checkpoint_file exp/bigvgan/g_05000000 \ 61 | --input_mels_dir /path/to/your/input_mel \ 62 | --output_dir /path/to/your/output_wav 63 | ``` 64 | 65 | ## Pretrained Models 66 | We provide the [pretrained models](https://drive.google.com/drive/folders/1e9wdM29d-t3EHUpBb8T4dcHrkYGAXTgq). 67 | One can download the checkpoints of generator (e.g., g_05000000) and discriminator (e.g., do_05000000) within the listed folders. 68 | 69 | |Folder Name|Sampling Rate|Mel band|fmax|Params.|Dataset|Fine-Tuned| 70 | |------|---|---|---|---|------|---| 71 | |bigvgan_24khz_100band|24 kHz|100|12000|112M|LibriTTS|No| 72 | |bigvgan_base_24khz_100band|24 kHz|100|12000|14M|LibriTTS|No| 73 | |bigvgan_22khz_80band|22 kHz|80|8000|112M|LibriTTS + VCTK + LJSpeech|No| 74 | |bigvgan_base_22khz_80band|22 kHz|80|8000|14M|LibriTTS + VCTK + LJSpeech|No| 75 | 76 | The paper results are based on 24kHz BigVGAN models trained on LibriTTS dataset. 
77 | We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications. 78 | Note that, the latest checkpoints use ``snakebeta`` activation with log scale parameterization, which have the best overall quality. 79 | 80 | 81 | ## TODO 82 | 83 | Current codebase only provides a plain PyTorch implementation for the filtered nonlinearity. We are working on a fast CUDA kernel implementation, which will be released in the future. 84 | 85 | 86 | ## References 87 | * [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator) 88 | 89 | * [Snake](https://github.com/EdwardDixon/snake) (for periodic activation) 90 | 91 | * [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing) 92 | 93 | * [Julius](https://github.com/adefossez/julius) (for low-pass filter) 94 | 95 | * [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator) -------------------------------------------------------------------------------- /vocoder/bigvgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuhuadai/AudioLCM/be5a709a1020072e3ca2d66289724f15bb4c917c/vocoder/bigvgan/__init__.py -------------------------------------------------------------------------------- /vocoder/bigvgan/activations.py: -------------------------------------------------------------------------------- 1 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | from torch import nn, sin, pow 6 | from torch.nn import Parameter 7 | 8 | 9 | class Snake(nn.Module): 10 | ''' 11 | Implementation of a sine-based periodic activation function 12 | Shape: 13 | - Input: (B, C, T) 14 | - Output: (B, C, T), same shape as the input 15 | Parameters: 16 | - alpha - trainable parameter 17 | References: 18 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 19 | https://arxiv.org/abs/2006.08195 20 | Examples: 21 | >>> a1 = snake(256) 22 | >>> x = torch.randn(256) 23 | >>> x = a1(x) 24 | ''' 25 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 26 | ''' 27 | Initialization. 28 | INPUT: 29 | - in_features: shape of the input 30 | - alpha: trainable parameter 31 | alpha is initialized to 1 by default, higher values = higher-frequency. 32 | alpha will be trained along with the rest of your model. 33 | ''' 34 | super(Snake, self).__init__() 35 | self.in_features = in_features 36 | 37 | # initialize alpha 38 | self.alpha_logscale = alpha_logscale 39 | if self.alpha_logscale: # log scale alphas initialized to zeros 40 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 41 | else: # linear scale alphas initialized to ones 42 | self.alpha = Parameter(torch.ones(in_features) * alpha) 43 | 44 | self.alpha.requires_grad = alpha_trainable 45 | 46 | self.no_div_by_zero = 0.000000001 47 | 48 | def forward(self, x): 49 | ''' 50 | Forward pass of the function. 51 | Applies the function to the input elementwise. 
52 | Snake ∶= x + 1/a * sin^2 (xa) 53 | ''' 54 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 55 | if self.alpha_logscale: 56 | alpha = torch.exp(alpha) 57 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 58 | 59 | return x 60 | 61 | 62 | class SnakeBeta(nn.Module): 63 | ''' 64 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 65 | Shape: 66 | - Input: (B, C, T) 67 | - Output: (B, C, T), same shape as the input 68 | Parameters: 69 | - alpha - trainable parameter that controls frequency 70 | - beta - trainable parameter that controls magnitude 71 | References: 72 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 73 | https://arxiv.org/abs/2006.08195 74 | Examples: 75 | >>> a1 = snakebeta(256) 76 | >>> x = torch.randn(256) 77 | >>> x = a1(x) 78 | ''' 79 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 80 | ''' 81 | Initialization. 82 | INPUT: 83 | - in_features: shape of the input 84 | - alpha - trainable parameter that controls frequency 85 | - beta - trainable parameter that controls magnitude 86 | alpha is initialized to 1 by default, higher values = higher-frequency. 87 | beta is initialized to 1 by default, higher values = higher-magnitude. 88 | alpha will be trained along with the rest of your model. 89 | ''' 90 | super(SnakeBeta, self).__init__() 91 | self.in_features = in_features 92 | 93 | # initialize alpha 94 | self.alpha_logscale = alpha_logscale 95 | if self.alpha_logscale: # log scale alphas initialized to zeros 96 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 97 | self.beta = Parameter(torch.zeros(in_features) * alpha) 98 | else: # linear scale alphas initialized to ones 99 | self.alpha = Parameter(torch.ones(in_features) * alpha) 100 | self.beta = Parameter(torch.ones(in_features) * alpha) 101 | 102 | self.alpha.requires_grad = alpha_trainable 103 | self.beta.requires_grad = alpha_trainable 104 | 105 | self.no_div_by_zero = 0.000000001 106 | 107 | def forward(self, x): 108 | ''' 109 | Forward pass of the function. 110 | Applies the function to the input elementwise. 111 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 112 | ''' 113 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 114 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 115 | if self.alpha_logscale: 116 | alpha = torch.exp(alpha) 117 | beta = torch.exp(beta) 118 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 119 | 120 | return x -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .filter import * 5 | from .resample import * 6 | from .act import * -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
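As a quick numerical illustration of the Snake/SnakeBeta activations defined above, the update is elementwise y = x + (1/α)·sin²(αx) on (B, C, T) tensors; the values of x and α below are arbitrary.

```python
import torch
from torch import sin, pow

x = torch.linspace(-3.0, 3.0, steps=7).view(1, 1, -1)   # (B, C, T)
alpha = torch.tensor(1.0)

y = x + (1.0 / (alpha + 1e-9)) * pow(sin(x * alpha), 2)  # Snake with linear-scale alpha
print(y.squeeze())
# As alpha -> 0, (1/alpha)*sin^2(alpha*x) behaves like alpha*x^2 -> 0, so the activation
# approaches the identity; larger alpha superimposes a higher-frequency periodic ripple on x.
```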
3 | 4 | import torch.nn as nn 5 | from .resample import UpSample1d, DownSample1d 6 | 7 | 8 | class Activation1d(nn.Module): 9 | def __init__(self, 10 | activation, 11 | up_ratio: int = 2, 12 | down_ratio: int = 2, 13 | up_kernel_size: int = 12, 14 | down_kernel_size: int = 12): 15 | super().__init__() 16 | self.up_ratio = up_ratio 17 | self.down_ratio = down_ratio 18 | self.act = activation 19 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 20 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 21 | 22 | # x: [B,C,T] 23 | def forward(self, x): 24 | x = self.upsample(x) 25 | x = self.act(x) 26 | x = self.downsample(x) 27 | 28 | return x -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | if 'sinc' in dir(torch): 10 | sinc = torch.sinc 11 | else: 12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 13 | # https://adefossez.github.io/julius/julius/core.html 14 | # LICENSE is in incl_licenses directory. 15 | def sinc(x: torch.Tensor): 16 | """ 17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 19 | """ 20 | return torch.where(x == 0, 21 | torch.tensor(1., device=x.device, dtype=x.dtype), 22 | torch.sin(math.pi * x) / math.pi / x) 23 | 24 | 25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 26 | # https://adefossez.github.io/julius/julius/lowpass.html 27 | # LICENSE is in incl_licenses directory. 28 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 29 | even = (kernel_size % 2 == 0) 30 | half_size = kernel_size // 2 31 | 32 | #For kaiser window 33 | delta_f = 4 * half_width 34 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 35 | if A > 50.: 36 | beta = 0.1102 * (A - 8.7) 37 | elif A >= 21.: 38 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) 39 | else: 40 | beta = 0. 41 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 42 | 43 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 44 | if even: 45 | time = (torch.arange(-half_size, half_size) + 0.5) 46 | else: 47 | time = torch.arange(kernel_size) - half_size 48 | if cutoff == 0: 49 | filter_ = torch.zeros_like(time) 50 | else: 51 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 52 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 53 | # of the constant component in the input signal. 54 | filter_ /= filter_.sum() 55 | filter = filter_.view(1, 1, kernel_size) 56 | 57 | return filter 58 | 59 | 60 | class LowPassFilter1d(nn.Module): 61 | def __init__(self, 62 | cutoff=0.5, 63 | half_width=0.6, 64 | stride: int = 1, 65 | padding: bool = True, 66 | padding_mode: str = 'replicate', 67 | kernel_size: int = 12): 68 | # kernel_size should be even number for stylegan3 setup, 69 | # in this implementation, odd number is also possible. 
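A quick sanity check on `kaiser_sinc_filter1d` above: a low-pass kernel should have unit DC gain, i.e. its taps sum to 1 after the explicit normalization. The import path assumes the script is run from the repository root; the cutoff/half-width values mirror the ratio-2 case used by the resamplers.

```python
import torch
from vocoder.bigvgan.alias_free_torch.filter import kaiser_sinc_filter1d

kernel = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(kernel.shape)         # torch.Size([1, 1, 12])
print(float(kernel.sum()))  # ~1.0, so a constant input passes through unattenuated
```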
70 | super().__init__() 71 | if cutoff < -0.: 72 | raise ValueError("Minimum cutoff must be larger than zero.") 73 | if cutoff > 0.5: 74 | raise ValueError("A cutoff above 0.5 does not make sense.") 75 | self.kernel_size = kernel_size 76 | self.even = (kernel_size % 2 == 0) 77 | self.pad_left = kernel_size // 2 - int(self.even) 78 | self.pad_right = kernel_size // 2 79 | self.stride = stride 80 | self.padding = padding 81 | self.padding_mode = padding_mode 82 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 83 | self.register_buffer("filter", filter) 84 | 85 | #input [B, C, T] 86 | def forward(self, x): 87 | _, C, _ = x.shape 88 | 89 | if self.padding: 90 | x = F.pad(x, (self.pad_left, self.pad_right), 91 | mode=self.padding_mode) 92 | out = F.conv1d(x, self.filter.expand(C, -1, -1), 93 | stride=self.stride, groups=C) 94 | 95 | return out -------------------------------------------------------------------------------- /vocoder/bigvgan/alias_free_torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from .filter import LowPassFilter1d 7 | from .filter import kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, 20 | half_width=0.6 / ratio, 21 | kernel_size=self.kernel_size) 22 | self.register_buffer("filter", filter) 23 | 24 | # x: [B, C, T] 25 | def forward(self, x): 26 | _, C, _ = x.shape 27 | 28 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 29 | x = self.ratio * F.conv_transpose1d( 30 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 31 | x = x[..., self.pad_left:-self.pad_right] 32 | 33 | return x 34 | 35 | 36 | class DownSample1d(nn.Module): 37 | def __init__(self, ratio=2, kernel_size=None): 38 | super().__init__() 39 | self.ratio = ratio 40 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 41 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, 42 | half_width=0.6 / ratio, 43 | stride=ratio, 44 | kernel_size=self.kernel_size) 45 | 46 | def forward(self, x): 47 | xx = self.lowpass(x) 48 | 49 | return xx -------------------------------------------------------------------------------- /vocoder/bigvgan/bigvgan_audioset16khz_80band.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 32, 5 | "learning_rate": 0.0001, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [4,4,2,2,2,2], 12 | "upsample_kernel_sizes": [8,8,4,4,4,4], 13 | "upsample_initial_channel": 1536, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "activation": "snakebeta", 18 | "snake_logscale": true, 19 | 20 | "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], 21 | 
"mpd_reshapes": [2, 3, 5, 7, 11], 22 | "use_spectral_norm": false, 23 | "discriminator_channel_mult": 1, 24 | 25 | "segment_size": 8192, 26 | "num_mels": 80, 27 | "num_freq": 1025, 28 | "n_fft": 1024, 29 | "hop_size": 256, 30 | "win_size": 1024, 31 | 32 | "sampling_rate": 16000, 33 | 34 | "fmin": 125, 35 | "fmax": 7600, 36 | "fmax_for_loss": null, 37 | 38 | "num_workers": 4, 39 | 40 | "dist_config": { 41 | "dist_backend": "nccl", 42 | "dist_url": "tcp://localhost:44327", 43 | "world_size": 1 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /vocoder/bigvgan/configs/bigvgan_22khz_80band.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 32, 5 | "learning_rate": 0.0001, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [4,4,2,2,2,2], 12 | "upsample_kernel_sizes": [8,8,4,4,4,4], 13 | "upsample_initial_channel": 1536, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "activation": "snakebeta", 18 | "snake_logscale": true, 19 | 20 | "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], 21 | "mpd_reshapes": [2, 3, 5, 7, 11], 22 | "use_spectral_norm": false, 23 | "discriminator_channel_mult": 1, 24 | 25 | "segment_size": 8192, 26 | "num_mels": 80, 27 | "num_freq": 1025, 28 | "n_fft": 1024, 29 | "hop_size": 256, 30 | "win_size": 1024, 31 | 32 | "sampling_rate": 22050, 33 | 34 | "fmin": 0, 35 | "fmax": 8000, 36 | "fmax_for_loss": null, 37 | 38 | "num_workers": 4, 39 | 40 | "dist_config": { 41 | "dist_backend": "nccl", 42 | "dist_url": "tcp://localhost:54321", 43 | "world_size": 1 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /vocoder/bigvgan/configs/bigvgan_24khz_100band.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 32, 5 | "learning_rate": 0.0001, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [4,4,2,2,2,2], 12 | "upsample_kernel_sizes": [8,8,4,4,4,4], 13 | "upsample_initial_channel": 1536, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "activation": "snakebeta", 18 | "snake_logscale": true, 19 | 20 | "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], 21 | "mpd_reshapes": [2, 3, 5, 7, 11], 22 | "use_spectral_norm": false, 23 | "discriminator_channel_mult": 1, 24 | 25 | "segment_size": 8192, 26 | "num_mels": 100, 27 | "num_freq": 1025, 28 | "n_fft": 1024, 29 | "hop_size": 256, 30 | "win_size": 1024, 31 | 32 | "sampling_rate": 24000, 33 | 34 | "fmin": 0, 35 | "fmax": 12000, 36 | "fmax_for_loss": null, 37 | 38 | "num_workers": 4, 39 | 40 | "dist_config": { 41 | "dist_backend": "nccl", 42 | "dist_url": "tcp://localhost:54321", 43 | "world_size": 1 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /vocoder/bigvgan/configs/bigvgan_base_22khz_80band.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 32, 5 | "learning_rate": 0.0001, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | 
"upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "activation": "snakebeta", 18 | "snake_logscale": true, 19 | 20 | "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], 21 | "mpd_reshapes": [2, 3, 5, 7, 11], 22 | "use_spectral_norm": false, 23 | "discriminator_channel_mult": 1, 24 | 25 | "segment_size": 8192, 26 | "num_mels": 80, 27 | "num_freq": 1025, 28 | "n_fft": 1024, 29 | "hop_size": 256, 30 | "win_size": 1024, 31 | 32 | "sampling_rate": 22050, 33 | 34 | "fmin": 0, 35 | "fmax": 8000, 36 | "fmax_for_loss": null, 37 | 38 | "num_workers": 4, 39 | 40 | "dist_config": { 41 | "dist_backend": "nccl", 42 | "dist_url": "tcp://localhost:54321", 43 | "world_size": 1 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /vocoder/bigvgan/configs/bigvgan_base_24khz_100band.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 32, 5 | "learning_rate": 0.0001, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "activation": "snakebeta", 18 | "snake_logscale": true, 19 | 20 | "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], 21 | "mpd_reshapes": [2, 3, 5, 7, 11], 22 | "use_spectral_norm": false, 23 | "discriminator_channel_mult": 1, 24 | 25 | "segment_size": 8192, 26 | "num_mels": 100, 27 | "num_freq": 1025, 28 | "n_fft": 1024, 29 | "hop_size": 256, 30 | "win_size": 1024, 31 | 32 | "sampling_rate": 24000, 33 | 34 | "fmin": 0, 35 | "fmax": 12000, 36 | "fmax_for_loss": null, 37 | 38 | "num_workers": 4, 39 | 40 | "dist_config": { 41 | "dist_backend": "nccl", 42 | "dist_url": "tcp://localhost:54321", 43 | "world_size": 1 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /vocoder/bigvgan/env.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license. 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | import os 5 | import shutil 6 | 7 | 8 | class AttrDict(dict): 9 | def __init__(self, *args, **kwargs): 10 | super(AttrDict, self).__init__(*args, **kwargs) 11 | self.__dict__ = self 12 | 13 | 14 | def build_env(config, config_name, path): 15 | t_path = os.path.join(path, config_name) 16 | if config != t_path: 17 | os.makedirs(path, exist_ok=True) 18 | shutil.copyfile(config, os.path.join(path, config_name)) -------------------------------------------------------------------------------- /vocoder/bigvgan/incl_licenses/LICENSE_1: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /vocoder/bigvgan/incl_licenses/LICENSE_2: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Edward Dixon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /vocoder/bigvgan/incl_licenses/LICENSE_4: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Seungwon Park 박승원 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /vocoder/bigvgan/incl_licenses/LICENSE_5: -------------------------------------------------------------------------------- 1 | Copyright 2020 Alexandre Défossez 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 4 | associated documentation files (the "Software"), to deal in the Software without restriction, 5 | including without limitation the rights to use, copy, modify, merge, publish, distribute, 6 | sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 7 | furnished to do so, subject to the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or 10 | substantial portions of the Software. 11 | 12 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 13 | NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 14 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 15 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /vocoder/bigvgan/inference.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license. 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | from __future__ import absolute_import, division, print_function, unicode_literals 5 | 6 | import glob 7 | import os 8 | import argparse 9 | import json 10 | import torch 11 | from scipy.io.wavfile import write 12 | from env import AttrDict 13 | from meldataset import mel_spectrogram, MAX_WAV_VALUE 14 | from models import BigVGAN as Generator 15 | import librosa 16 | 17 | h = None 18 | device = None 19 | torch.backends.cudnn.benchmark = False 20 | 21 | 22 | def load_checkpoint(filepath, device): 23 | assert os.path.isfile(filepath) 24 | print("Loading '{}'".format(filepath)) 25 | checkpoint_dict = torch.load(filepath, map_location=device) 26 | print("Complete.") 27 | return checkpoint_dict 28 | 29 | 30 | def get_mel(x): 31 | return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) 32 | 33 | 34 | def scan_checkpoint(cp_dir, prefix): 35 | pattern = os.path.join(cp_dir, prefix + '*') 36 | cp_list = glob.glob(pattern) 37 | if len(cp_list) == 0: 38 | return '' 39 | return sorted(cp_list)[-1] 40 | 41 | 42 | def inference(a, h): 43 | generator = Generator(h).to(device) 44 | 45 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 46 | generator.load_state_dict(state_dict_g['generator']) 47 | 48 | filelist = os.listdir(a.input_wavs_dir) 49 | 50 | os.makedirs(a.output_dir, exist_ok=True) 51 | 52 | generator.eval() 53 | generator.remove_weight_norm() 54 | with torch.no_grad(): 55 | for i, filname in enumerate(filelist): 56 | # load the ground truth audio and resample if necessary 57 | wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), h.sampling_rate, mono=True) 58 | wav = torch.FloatTensor(wav).to(device) 59 | # compute mel spectrogram from the ground truth audio 60 | x = get_mel(wav.unsqueeze(0)) 61 | 62 | y_g_hat = generator(x) 63 | 64 | audio = y_g_hat.squeeze() 65 | audio = audio * MAX_WAV_VALUE 66 | audio = audio.cpu().numpy().astype('int16') 67 | 68 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated.wav') 69 | write(output_file, h.sampling_rate, audio) 70 | print(output_file) 71 | 72 | 73 | def main(): 74 | print('Initializing Inference Process..') 75 | 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument('--input_wavs_dir', default='test_files') 78 | parser.add_argument('--output_dir', default='generated_files') 79 | parser.add_argument('--checkpoint_file', required=True) 80 | 81 | a = parser.parse_args() 82 | 83 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 84 | with open(config_file) as f: 85 | data = f.read() 86 | 87 | global h 88 | json_config = json.loads(data) 89 | h = AttrDict(json_config) 90 | 91 | torch.manual_seed(h.seed) 92 | global device 93 | if torch.cuda.is_available(): 94 | torch.cuda.manual_seed(h.seed) 95 | device = torch.device('cuda') 96 | else: 97 | device = torch.device('cpu') 98 | 99 | inference(a, h) 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | 105 | -------------------------------------------------------------------------------- /vocoder/bigvgan/inference_e2e.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license. 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | from __future__ import absolute_import, division, print_function, unicode_literals 5 | 6 | import glob 7 | import os 8 | import numpy as np 9 | import argparse 10 | import json 11 | import torch 12 | from scipy.io.wavfile import write 13 | from env import AttrDict 14 | from meldataset import MAX_WAV_VALUE 15 | from models import BigVGAN as Generator 16 | 17 | h = None 18 | device = None 19 | torch.backends.cudnn.benchmark = False 20 | 21 | 22 | def load_checkpoint(filepath, device): 23 | assert os.path.isfile(filepath) 24 | print("Loading '{}'".format(filepath)) 25 | checkpoint_dict = torch.load(filepath, map_location=device) 26 | print("Complete.") 27 | return checkpoint_dict 28 | 29 | 30 | def scan_checkpoint(cp_dir, prefix): 31 | pattern = os.path.join(cp_dir, prefix + '*') 32 | cp_list = glob.glob(pattern) 33 | if len(cp_list) == 0: 34 | return '' 35 | return sorted(cp_list)[-1] 36 | 37 | 38 | def inference(a, h): 39 | generator = Generator(h).to(device) 40 | 41 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 42 | generator.load_state_dict(state_dict_g['generator']) 43 | 44 | filelist = os.listdir(a.input_mels_dir) 45 | 46 | os.makedirs(a.output_dir, exist_ok=True) 47 | 48 | generator.eval() 49 | generator.remove_weight_norm() 50 | with torch.no_grad(): 51 | for i, filname in enumerate(filelist): 52 | # load the mel spectrogram in .npy format 53 | x = np.load(os.path.join(a.input_mels_dir, filname)) 54 | x = torch.FloatTensor(x).to(device) 55 | if len(x.shape) == 2: 56 | x = x.unsqueeze(0) 57 | 58 | y_g_hat = generator(x) 59 | 60 | audio = y_g_hat.squeeze() 61 | audio = audio * MAX_WAV_VALUE 62 | audio = audio.cpu().numpy().astype('int16') 63 | 64 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated_e2e.wav') 65 | write(output_file, h.sampling_rate, audio) 66 | print(output_file) 67 | 68 | 69 | def main(): 70 | print('Initializing Inference Process..') 71 | 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--input_mels_dir', default='test_mel_files') 74 | parser.add_argument('--output_dir', default='generated_files_from_mel') 75 | parser.add_argument('--checkpoint_file', required=True) 76 | 77 | a = parser.parse_args() 78 | 79 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 80 | with open(config_file) as f: 81 | data = f.read() 82 | 83 | global h 84 | json_config = json.loads(data) 85 | h = AttrDict(json_config) 86 | 87 | torch.manual_seed(h.seed) 88 | global device 89 | if torch.cuda.is_available(): 90 | torch.cuda.manual_seed(h.seed) 91 | device = torch.device('cuda') 92 | else: 93 | device = torch.device('cpu') 94 | 95 | inference(a, h) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | 101 | -------------------------------------------------------------------------------- /vocoder/bigvgan/parse_scripts/parse_libritts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 
3 | 4 | import os, glob 5 | 6 | def get_wav_and_text_filelist(data_root, data_type, subsample=1): 7 | wav_list = sorted([path.replace(data_root, "")[1:] for path in glob.glob(os.path.join(data_root, data_type, "**/**/*.wav"))]) 8 | wav_list = wav_list[::subsample] 9 | txt_filelist = [path.replace('.wav', '.normalized.txt') for path in wav_list] 10 | 11 | txt_list = [] 12 | for txt_file in txt_filelist: 13 | with open(os.path.join(data_root, txt_file), 'r') as f_txt: 14 | text = f_txt.readline().strip('\n') 15 | txt_list.append(text) 16 | wav_list = [path.replace('.wav', '') for path in wav_list] 17 | 18 | return wav_list, txt_list 19 | 20 | def write_filelist(output_path, wav_list, txt_list): 21 | with open(output_path, 'w') as f: 22 | for i in range(len(wav_list)): 23 | filename = wav_list[i] + '|' + txt_list[i] 24 | f.write(filename + '\n') 25 | 26 | if __name__ == "__main__": 27 | 28 | data_root = "LibriTTS" 29 | 30 | # dev and test sets. subsample each sets to get ~100 utterances 31 | data_type_list = ["dev-clean", "dev-other", "test-clean", "test-other"] 32 | subsample_list = [50, 50, 50, 50] 33 | for (data_type, subsample) in zip(data_type_list, subsample_list): 34 | print("processing {}".format(data_type)) 35 | data_path = os.path.join(data_root, data_type) 36 | assert os.path.exists(data_path),\ 37 | "path {} not found. make sure the path is accessible by creating the symbolic link using the following command: "\ 38 | "ln -s /path/to/your/{} {}".format(data_path, data_path, data_path) 39 | wav_list, txt_list = get_wav_and_text_filelist(data_root, data_type, subsample) 40 | write_filelist(os.path.join(data_root, data_type+".txt"), wav_list, txt_list) 41 | 42 | # training and seen speaker validation datasets (libritts-full): train-clean-100 + train-clean-360 + train-other-500 43 | wav_list_train, txt_list_train = [], [] 44 | for data_type in ["train-clean-100", "train-clean-360", "train-other-500"]: 45 | print("processing {}".format(data_type)) 46 | data_path = os.path.join(data_root, data_type) 47 | assert os.path.exists(data_path),\ 48 | "path {} not found. make sure the path is accessible by creating the symbolic link using the following command: "\ 49 | "ln -s /path/to/your/{} {}".format(data_path, data_path, data_path) 50 | wav_list, txt_list = get_wav_and_text_filelist(data_root, data_type) 51 | wav_list_train.extend(wav_list) 52 | txt_list_train.extend(txt_list) 53 | 54 | # split the training set so that the seen speaker validation set contains ~100 utterances 55 | subsample_val = 3000 56 | wav_list_val, txt_list_val = wav_list_train[::subsample_val], txt_list_train[::subsample_val] 57 | del wav_list_train[::subsample_val] 58 | del txt_list_train[::subsample_val] 59 | write_filelist(os.path.join(data_root, "train-full.txt"), wav_list_train, txt_list_train) 60 | write_filelist(os.path.join(data_root, "val-full.txt"), wav_list_val, txt_list_val) 61 | 62 | print("done") -------------------------------------------------------------------------------- /vocoder/bigvgan/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | librosa==0.8.1 4 | scipy 5 | tensorboard 6 | soundfile 7 | matplotlib 8 | pesq 9 | auraloss 10 | tqdm -------------------------------------------------------------------------------- /vocoder/bigvgan/utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license. 
2 | # LICENSE is in incl_licenses directory. 3 | 4 | import glob 5 | import os 6 | import matplotlib 7 | import torch 8 | from torch.nn.utils import weight_norm 9 | matplotlib.use("Agg") 10 | import matplotlib.pylab as plt 11 | from meldataset import MAX_WAV_VALUE 12 | from scipy.io.wavfile import write 13 | 14 | 15 | def plot_spectrogram(spectrogram): 16 | fig, ax = plt.subplots(figsize=(10, 2)) 17 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 18 | interpolation='none') 19 | plt.colorbar(im, ax=ax) 20 | 21 | fig.canvas.draw() 22 | plt.close() 23 | 24 | return fig 25 | 26 | 27 | def plot_spectrogram_clipped(spectrogram, clip_max=2.): 28 | fig, ax = plt.subplots(figsize=(10, 2)) 29 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 30 | interpolation='none', vmin=1e-6, vmax=clip_max) 31 | plt.colorbar(im, ax=ax) 32 | 33 | fig.canvas.draw() 34 | plt.close() 35 | 36 | return fig 37 | 38 | 39 | def init_weights(m, mean=0.0, std=0.01): 40 | classname = m.__class__.__name__ 41 | if classname.find("Conv") != -1: 42 | m.weight.data.normal_(mean, std) 43 | 44 | 45 | def apply_weight_norm(m): 46 | classname = m.__class__.__name__ 47 | if classname.find("Conv") != -1: 48 | weight_norm(m) 49 | 50 | 51 | def get_padding(kernel_size, dilation=1): 52 | return int((kernel_size*dilation - dilation)/2) 53 | 54 | 55 | def load_checkpoint(filepath, device): 56 | assert os.path.isfile(filepath) 57 | print("Loading '{}'".format(filepath)) 58 | checkpoint_dict = torch.load(filepath, map_location=device) 59 | print("Complete.") 60 | return checkpoint_dict 61 | 62 | 63 | def save_checkpoint(filepath, obj): 64 | print("Saving checkpoint to {}".format(filepath)) 65 | torch.save(obj, filepath) 66 | print("Complete.") 67 | 68 | 69 | def scan_checkpoint(cp_dir, prefix): 70 | pattern = os.path.join(cp_dir, prefix + '????????') 71 | cp_list = glob.glob(pattern) 72 | if len(cp_list) == 0: 73 | return None 74 | return sorted(cp_list)[-1] 75 | 76 | def save_audio(audio, path, sr): 77 | # wav: torch with 1d shape 78 | audio = audio * MAX_WAV_VALUE 79 | audio = audio.cpu().numpy().astype('int16') 80 | write(path, sr, audio) -------------------------------------------------------------------------------- /wav_evaluation/cal_clap_score.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import sys 3 | import os 4 | directory = pathlib.Path(os.getcwd()) 5 | sys.path.append(str(directory)) 6 | import torch 7 | import numpy as np 8 | from wav_evaluation.models.CLAPWrapper import CLAPWrapper 9 | import torch.nn.functional as F 10 | import argparse 11 | import csv 12 | from tqdm import tqdm 13 | from torch.utils.data import Dataset,DataLoader 14 | import pandas as pd 15 | import json 16 | 17 | 18 | 19 | 20 | 21 | def cal_score_by_csv(csv_path,clap_model): # the CLAP score of the AudioCaps val ground-truth audio is 0.479077 22 | input_file = open(csv_path) 23 | input_lines = input_file.readlines() 24 | 25 | 26 | 27 | clap_scores = [] 28 | 29 | caption_list,audio_list = [],[] 30 | with torch.no_grad(): 31 | for idx in tqdm(range(len(input_lines))): 32 | # text_embeddings = clap_model.get_text_embeddings([getattr(t,'caption')])# normalized embeddings 33 | # audio_embeddings = clap_model.get_audio_embeddings([getattr(t,'audio_path')], resample=True) 34 | # score = clap_model.compute_similarity(audio_embeddings, text_embeddings,use_logit_scale=False) 35 | # clap_scores.append(score.cpu().numpy()) 36 | if input_lines[idx][0] == 'S': 37 | item_name, semantic = input_lines[idx].split('\t')
38 | 39 | index = item_name[2:] 40 | # import ipdb 41 | # ipdb.set_trace() 42 | caption_list.append(semantic.strip()) 43 | audio_list.append(f'/home1/liuhuadai/projects/VoiceLM-main/encodec_16k_6kbps_multiDisc/results/text_to_audio_0912/ref/{index}.wav') 44 | # import ipdb 45 | # ipdb.set_trace() 46 | if idx % 60 == 0: 47 | text_embeddings = clap_model.get_text_embeddings(caption_list)# normalized embeddings 48 | audio_embeddings = clap_model.get_audio_embeddings(audio_list, resample=True)# this step is time-consuming: it loads the audio and resamples it to 44100 Hz 49 | score_mat = clap_model.compute_similarity(audio_embeddings, text_embeddings,use_logit_scale=False) 50 | score = score_mat.diagonal() 51 | clap_scores.append(score.cpu().numpy()) 52 | # print(caption_list) 53 | # print(audio_list) 54 | # print(score) 55 | audio_list = [] 56 | caption_list = [] 57 | # print("mean:",np.mean(np.array(clap_scores).flatten())) 58 | return np.mean(np.array(clap_scores).flatten()) 59 | # [0.24463119, 0.24597324, 0.26050782, 0.25079757, 0.2501094, 0.2629509,0.25025588,0.25980043,0.27295044, 0.25655213, 0.2490872, 0.2598294,0.26491216,0.24698025,0.25086403,0.27533108,0.27969885,0.2596455,0.26313564,0.2658071] 60 | def add_clap_score_to_csv(csv_path,clap_model): 61 | df = pd.read_csv(csv_path,sep='\t') 62 | clap_scores_dict = {} 63 | with torch.no_grad(): 64 | for idx,t in enumerate(tqdm(df.itertuples()),start=1): 65 | text_embeddings = clap_model.get_text_embeddings([getattr(t,'caption')])# normalized embeddings 66 | audio_embeddings = clap_model.get_audio_embeddings([getattr(t,'audio_path')], resample=True) 67 | score = clap_model.compute_similarity(audio_embeddings, text_embeddings,use_logit_scale=False) 68 | clap_scores_dict[idx] = score.cpu().numpy() 69 | df['clap_score'] = clap_scores_dict 70 | df.to_csv(csv_path[:-4]+'_clap.csv',sep='\t',index=False) 71 | 72 | 73 | if __name__ == '__main__': 74 | ckpt_path = '/home1/liuhuadai/projects/VoiceLM-main/encodec_16k_6kbps_multiDisc/useful_ckpts/CLAP' 75 | clap_model = CLAPWrapper(os.path.join(ckpt_path,'CLAP_weights_2022.pth'),os.path.join(ckpt_path,'config.yml'), use_cuda=True) 76 | 77 | clap_score = cal_score_by_csv('/home1/liuhuadai/projects/VoiceLM-main/encodec_16k_6kbps_multiDisc/Test/generate-test.txt',clap_model) 78 | out = 'text_to_audio2_0908' 79 | print(f"clap_score for {out} is:{clap_score}") 80 | 81 | 82 | # os.remove(csv_path) -------------------------------------------------------------------------------- /wav_evaluation/cal_fad_score.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import sys 3 | import os 4 | directory = pathlib.Path(os.getcwd()) 5 | sys.path.append(str(directory)) 6 | import argparse 7 | from wav_evaluation.metrics.fad import FrechetAudioDistance 8 | """it will resample to 16000hz automatically""" 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | # parser.add_argument('--csv_path',type=str,default='tmp.csv') 12 | parser.add_argument('--pred_wavsdir',type=str) 13 | parser.add_argument('--gt_wavsdir', default="/home/tiger/nfs/data/audiocaps/test") 14 | args = parser.parse_args() 15 | return args 16 | 17 | if __name__ == '__main__': 18 | args = parse_args() 19 | frechet = FrechetAudioDistance( 20 | use_pca=False, 21 | use_activation=False, 22 | verbose=False 23 | ) 24 | fad_score = frechet.score(background_dir=args.gt_wavsdir,eval_dir=args.pred_wavsdir) 25 | print(f"Frechet Audio Distance {fad_score}") 26 |
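The FAD script above can also be driven from Python. A minimal sketch, assuming the two directory paths are placeholders for your own ground-truth and generated wavs (the constructor and `score` arguments mirror `cal_fad_score.py`):

```python
from wav_evaluation.metrics.fad import FrechetAudioDistance

# Same settings as cal_fad_score.py; audio is resampled to 16 kHz automatically.
frechet = FrechetAudioDistance(use_pca=False, use_activation=False, verbose=False)

# Placeholder directories: ground-truth test wavs vs. wavs generated by AudioLCM.
fad_score = frechet.score(background_dir="data/audiocaps/test", eval_dir="results/generated")
print(f"Frechet Audio Distance {fad_score}")
```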
-------------------------------------------------------------------------------- /wav_evaluation/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import clap 2 | from . import audio 3 | from . import utils -------------------------------------------------------------------------------- /wav_evaluation/models/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 5 | 6 | def get_audio_encoder(name: str): 7 | if name == "Cnn14": 8 | return Cnn14 9 | else: 10 | raise Exception('The audio encoder name {} is incorrect or not supported'.format(name)) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_channels, out_channels): 15 | 16 | super(ConvBlock, self).__init__() 17 | 18 | self.conv1 = nn.Conv2d(in_channels=in_channels, 19 | out_channels=out_channels, 20 | kernel_size=(3, 3), stride=(1, 1), 21 | padding=(1, 1), bias=False) 22 | 23 | self.conv2 = nn.Conv2d(in_channels=out_channels, 24 | out_channels=out_channels, 25 | kernel_size=(3, 3), stride=(1, 1), 26 | padding=(1, 1), bias=False) 27 | 28 | self.bn1 = nn.BatchNorm2d(out_channels) 29 | self.bn2 = nn.BatchNorm2d(out_channels) 30 | 31 | 32 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 33 | 34 | x = input 35 | x = F.relu_(self.bn1(self.conv1(x))) 36 | x = F.relu_(self.bn2(self.conv2(x))) 37 | if pool_type == 'max': 38 | x = F.max_pool2d(x, kernel_size=pool_size) 39 | elif pool_type == 'avg': 40 | x = F.avg_pool2d(x, kernel_size=pool_size) 41 | elif pool_type == 'avg+max': 42 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 43 | x2 = F.max_pool2d(x, kernel_size=pool_size) 44 | x = x1 + x2 45 | else: 46 | raise Exception('Incorrect argument!') 47 | 48 | return x 49 | 50 | 51 | class ConvBlock5x5(nn.Module): 52 | def __init__(self, in_channels, out_channels): 53 | 54 | super(ConvBlock5x5, self).__init__() 55 | 56 | self.conv1 = nn.Conv2d(in_channels=in_channels, 57 | out_channels=out_channels, 58 | kernel_size=(5, 5), stride=(1, 1), 59 | padding=(2, 2), bias=False) 60 | 61 | self.bn1 = nn.BatchNorm2d(out_channels) 62 | 63 | 64 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 65 | 66 | x = input 67 | x = F.relu_(self.bn1(self.conv1(x))) 68 | if pool_type == 'max': 69 | x = F.max_pool2d(x, kernel_size=pool_size) 70 | elif pool_type == 'avg': 71 | x = F.avg_pool2d(x, kernel_size=pool_size) 72 | elif pool_type == 'avg+max': 73 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 74 | x2 = F.max_pool2d(x, kernel_size=pool_size) 75 | x = x1 + x2 76 | else: 77 | raise Exception('Incorrect argument!') 78 | 79 | return x 80 | 81 | 82 | class AttBlock(nn.Module): 83 | def __init__(self, n_in, n_out, activation='linear', temperature=1.): 84 | super(AttBlock, self).__init__() 85 | 86 | self.activation = activation 87 | self.temperature = temperature 88 | self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 89 | self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 90 | 91 | self.bn_att = nn.BatchNorm1d(n_out) 92 | 93 | def forward(self, x): 94 | # x: (n_samples, n_in, n_time) 95 | norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) 96 | cla = self.nonlinear_transform(self.cla(x)) 97 | x = torch.sum(norm_att * cla, dim=2) 98 | return x, norm_att, cla 99 | 100 | 
def nonlinear_transform(self, x): 101 | if self.activation == 'linear': 102 | return x 103 | elif self.activation == 'sigmoid': 104 | return torch.sigmoid(x) 105 | 106 | 107 | class Cnn14(nn.Module): 108 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 109 | fmax, classes_num, out_emb): 110 | 111 | super(Cnn14, self).__init__() 112 | 113 | window = 'hann' 114 | center = True 115 | pad_mode = 'reflect' 116 | ref = 1.0 117 | amin = 1e-10 118 | top_db = None 119 | 120 | # Spectrogram extractor 121 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 122 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 123 | freeze_parameters=True) 124 | 125 | # Logmel feature extractor 126 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 127 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 128 | freeze_parameters=True) 129 | 130 | self.bn0 = nn.BatchNorm2d(64) 131 | 132 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 133 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 134 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 135 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 136 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 137 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 138 | 139 | # out_emb is 2048 for best Cnn14 140 | self.fc1 = nn.Linear(2048, out_emb, bias=True) 141 | self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True) 142 | 143 | def forward(self, input, mixup_lambda=None): 144 | """ 145 | Input: (batch_size, data_length) 146 | """ 147 | 148 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 149 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 150 | 151 | x = x.transpose(1, 3) 152 | x = self.bn0(x) 153 | x = x.transpose(1, 3) 154 | 155 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 156 | x = F.dropout(x, p=0.2, training=self.training) 157 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 158 | x = F.dropout(x, p=0.2, training=self.training) 159 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 160 | x = F.dropout(x, p=0.2, training=self.training) 161 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 162 | x = F.dropout(x, p=0.2, training=self.training) 163 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 164 | x = F.dropout(x, p=0.2, training=self.training) 165 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 166 | x = F.dropout(x, p=0.2, training=self.training) 167 | x = torch.mean(x, dim=3) 168 | 169 | (x1, _) = torch.max(x, dim=2) 170 | x2 = torch.mean(x, dim=2) 171 | x = x1 + x2 172 | x = F.dropout(x, p=0.5, training=self.training) 173 | x = F.relu_(self.fc1(x)) 174 | embedding = F.dropout(x, p=0.5, training=self.training) 175 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 176 | 177 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 178 | 179 | return output_dict -------------------------------------------------------------------------------- /wav_evaluation/models/clap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from transformers import AutoModel 6 | from .audio import get_audio_encoder 7 | 8 | class Projection(nn.Module): 9 | def __init__(self, d_in: 
int, d_out: int, p: float=0.5) -> None: 10 | super().__init__() 11 | self.linear1 = nn.Linear(d_in, d_out, bias=False) 12 | self.linear2 = nn.Linear(d_out, d_out, bias=False) 13 | self.layer_norm = nn.LayerNorm(d_out) 14 | self.drop = nn.Dropout(p) 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | embed1 = self.linear1(x) 18 | embed2 = self.drop(self.linear2(F.gelu(embed1))) 19 | embeds = self.layer_norm(embed1 + embed2) 20 | return embeds 21 | 22 | class AudioEncoder(nn.Module): 23 | def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int, 24 | hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None: 25 | super().__init__() 26 | 27 | audio_encoder = get_audio_encoder(audioenc_name) 28 | 29 | self.base = audio_encoder( 30 | sample_rate, window_size, 31 | hop_size, mel_bins, fmin, fmax, 32 | classes_num, d_in) 33 | 34 | self.projection = Projection(d_in, d_out) 35 | 36 | def forward(self, x): 37 | out_dict = self.base(x) 38 | audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output'] 39 | projected_vec = self.projection(audio_features) 40 | return projected_vec, audio_classification_output 41 | 42 | class TextEncoder(nn.Module): 43 | def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None: 44 | super().__init__() 45 | self.base = AutoModel.from_pretrained(text_model) 46 | 47 | self.projection = Projection(transformer_embed_dim, d_out) 48 | 49 | def forward(self, x): 50 | out = self.base(**x)[0] 51 | out = out[:, 0, :] # get CLS token output 52 | projected_vec = self.projection(out) 53 | return projected_vec 54 | 55 | class CLAP(nn.Module): 56 | def __init__(self, 57 | # audio 58 | audioenc_name: str, 59 | sample_rate: int, 60 | window_size: int, 61 | hop_size: int, 62 | mel_bins: int, 63 | fmin: int, 64 | fmax: int, 65 | classes_num: int, 66 | out_emb: int, 67 | # text 68 | text_model: str, 69 | transformer_embed_dim: int, 70 | # common 71 | d_proj: int, 72 | ): 73 | super().__init__() 74 | 75 | 76 | self.audio_encoder = AudioEncoder( 77 | audioenc_name, out_emb, d_proj, 78 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num) 79 | 80 | self.caption_encoder = TextEncoder( 81 | d_proj, text_model, transformer_embed_dim 82 | ) 83 | 84 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 85 | 86 | def forward(self, audio, text): 87 | audio_embed, _ = self.audio_encoder(audio) 88 | caption_embed = self.caption_encoder(text) 89 | 90 | return caption_embed, audio_embed, self.logit_scale.exp() -------------------------------------------------------------------------------- /wav_evaluation/models/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import sys 4 | 5 | def read_config_as_args(config_path,args=None,is_config_str=False): 6 | return_dict = {} 7 | 8 | if config_path is not None: 9 | if is_config_str: 10 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader) 11 | else: 12 | with open(config_path, "r") as f: 13 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 14 | 15 | if args != None: 16 | for k, v in yml_config.items(): 17 | if k in args.__dict__: 18 | args.__dict__[k] = v 19 | else: 20 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) 21 | else: 22 | for k, v in yml_config.items(): 23 | return_dict[k] = v 24 | 25 | args = args if args != None else return_dict 26 | return argparse.Namespace(**args) 27 | 
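A small usage sketch for `read_config_as_args` above, assuming the path below points at the CLAP `config.yml` shown next (with `args=None` the YAML keys come back as an `argparse.Namespace`):

```python
from wav_evaluation.models.utils import read_config_as_args

# Path is a placeholder following the repository layout.
cfg = read_config_as_args("wav_evaluation/useful_ckpts/CLAP/config.yml")
print(cfg.audioenc_name, cfg.sampling_rate, cfg.d_proj)  # e.g. Cnn14 44100 1024
```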
-------------------------------------------------------------------------------- /wav_evaluation/useful_ckpts/CLAP/config.yml: -------------------------------------------------------------------------------- 1 | # TEXT ENCODER CONFIG 2 | text_model: '/root/autodl-tmp/liuhuadai/AudioLCM/ldm/modules/encoders/CLAP/bert-base-uncased' 3 | text_len: 100 4 | transformer_embed_dim: 768 5 | freeze_text_encoder_weights: True 6 | 7 | # AUDIO ENCODER CONFIG 8 | audioenc_name: 'Cnn14' 9 | out_emb: 2048 10 | sampling_rate: 44100 11 | duration: 5 12 | fmin: 50 13 | fmax: 14000 14 | n_fft: 1028 15 | hop_size: 320 16 | mel_bins: 64 17 | window_size: 1024 18 | 19 | # PROJECTION SPACE CONFIG 20 | d_proj: 1024 21 | temperature: 0.003 22 | 23 | # TRAINING AND EVALUATION CONFIG 24 | num_classes: 527 25 | batch_size: 1024 26 | demo: False 27 | --------------------------------------------------------------------------------
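For reference, a sketch of how the evaluation code above wires this config together with the CLAP checkpoint; the caption and wav path are placeholders, and the checkpoint/config locations follow the `wav_evaluation/useful_ckpts/CLAP` layout used by `cal_clap_score.py`:

```python
import os
import torch
from wav_evaluation.models.CLAPWrapper import CLAPWrapper

ckpt_dir = "wav_evaluation/useful_ckpts/CLAP"
clap_model = CLAPWrapper(os.path.join(ckpt_dir, "CLAP_weights_2022.pth"),
                         os.path.join(ckpt_dir, "config.yml"), use_cuda=True)

with torch.no_grad():
    # Placeholder caption/wav pair; compute_similarity returns a text-audio similarity matrix.
    text_embeddings = clap_model.get_text_embeddings(["a dog barks while birds chirp"])
    audio_embeddings = clap_model.get_audio_embeddings(["results/sample_0.wav"], resample=True)
    score = clap_model.compute_similarity(audio_embeddings, text_embeddings, use_logit_scale=False)
print("CLAP score:", score.diagonal().mean().item())
```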