├── audioldm2 ├── clap │ ├── __init__.py │ ├── training │ │ ├── __init__.py │ │ ├── audioset_textmap.npy │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ └── __pycache__ │ │ │ ├── data.cpython-310.pyc │ │ │ └── __init__.cpython-310.pyc │ └── open_clip │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── model_configs │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32.json │ │ ├── ViT-L-14.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── RN50.json │ │ ├── RN101.json │ │ ├── RN50x16.json │ │ ├── RN50x4.json │ │ ├── RN101-quickgelu.json │ │ ├── RN50-quickgelu.json │ │ ├── PANN-6.json │ │ ├── HTSAT-base.json │ │ ├── HTSAT-tiny.json │ │ ├── PANN-10.json │ │ ├── PANN-14.json │ │ ├── HTSAT-large.json │ │ ├── PANN-14-fmax-18k.json │ │ ├── PANN-14-win-1536.json │ │ ├── HTSAT-tiny-win-1536.json │ │ ├── PANN-14-fmax-8k-20s.json │ │ └── PANN-14-tiny-transformer.json │ │ ├── __init__.py │ │ ├── transform.py │ │ ├── timm_model.py │ │ ├── openai.py │ │ ├── pretrained.py │ │ ├── tokenizer.py │ │ └── feature_fusion.py ├── latent_diffusion │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── ddim.cpython-310.pyc │ │ │ ├── ddpm.cpython-310.pyc │ │ │ ├── plms.cpython-310.pyc │ │ │ └── __init__.cpython-310.pyc │ ├── modules │ │ ├── __init__.py │ │ ├── audiomae │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── AudioMAE.cpython-310.pyc │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ ├── models_mae.cpython-310.pyc │ │ │ │ └── models_vit.cpython-310.pyc │ │ │ ├── util │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── pos_embed.cpython-310.pyc │ │ │ │ │ └── patch_embed.cpython-310.pyc │ │ │ │ ├── lr_sched.py │ │ │ │ ├── crop.py │ │ │ │ ├── datasets.py │ │ │ │ ├── lars.py │ │ │ │ ├── stat.py │ │ │ │ ├── lr_decay.py │ │ │ │ ├── patch_embed.py │ │ │ │ └── pos_embed.py │ │ │ ├── AudioMAE.py │ │ │ └── models_vit.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── modules.cpython-310.pyc │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── util.cpython-310.pyc │ │ │ │ └── __init__.cpython-310.pyc │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── distributions.cpython-310.pyc │ │ │ └── distributions.py │ │ ├── phoneme_encoder │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── commons.cpython-310.pyc │ │ │ │ ├── encoder.cpython-310.pyc │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── attentions.cpython-310.pyc │ │ │ ├── text │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── symbols.cpython-310.pyc │ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ │ └── cleaners.cpython-310.pyc │ │ │ │ ├── symbols.py │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ └── cleaners.py │ │ │ ├── encoder.py │ │ │ └── commons.py │ │ ├── __pycache__ │ │ │ ├── ema.cpython-310.pyc │ │ │ └── __init__.cpython-310.pyc │ │ └── ema.py │ └── util.py ├── utilities │ ├── data │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── dataset.cpython-310.pyc │ ├── __init__.py │ ├── audio │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── stft.cpython-310.pyc │ │ │ ├── tools.cpython-310.pyc │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── audio_processing.cpython-310.pyc │ │ ├── tools.py │ │ ├── audio_processing.py │ │ └── stft.py │ └── model.py ├── audiomae_gen │ ├── __init__.py │ └── utils.py ├── __init__.py ├── hifigan │ ├── __init__.py │ ├── LICENSE │ └── models.py ├── __main__.py └── pipeline.py ├── visuals └── framework.png ├── modules ├── __pycache__ │ ├── layers.cpython-310.pyc │ 
├── autoencoder.cpython-310.pyc │ └── conditioner.cpython-310.pyc └── conditioner.py ├── scripts ├── train_l.sh ├── train_s.sh ├── train_b.sh └── train_g.sh ├── utils.py ├── config ├── example.txt └── 16k_64.yaml ├── constants.py ├── model.py ├── sample.py ├── README.md └── test.py /audioldm2/clap/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/clap/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audioldm2/utilities/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | -------------------------------------------------------------------------------- /audioldm2/audiomae_gen/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequence_input import Sequence2AudioMAE 2 | -------------------------------------------------------------------------------- /visuals/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/visuals/framework.png -------------------------------------------------------------------------------- /audioldm2/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import * 2 | from .data import * 3 | from .model import * 4 | -------------------------------------------------------------------------------- /audioldm2/utilities/audio/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .audio_processing import * 2 | from .stft import * 3 | from .tools import * 4 | -------------------------------------------------------------------------------- /audioldm2/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import seed_everything, save_wave, get_time, get_duration, read_list 2 | from .pipeline import * 3 | -------------------------------------------------------------------------------- /modules/__pycache__/layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/modules/__pycache__/layers.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/clap/training/audioset_textmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/training/audioset_textmap.npy -------------------------------------------------------------------------------- /modules/__pycache__/autoencoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/modules/__pycache__/autoencoder.cpython-310.pyc -------------------------------------------------------------------------------- /modules/__pycache__/conditioner.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/modules/__pycache__/conditioner.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /audioldm2/clap/training/__pycache__/data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/training/__pycache__/data.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/utilities/audio/__pycache__/stft.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/audio/__pycache__/stft.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/clap/training/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/training/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/utilities/audio/__pycache__/tools.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/audio/__pycache__/tools.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/utilities/data/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/data/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/utilities/data/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/data/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/utilities/audio/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/audio/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/models/__pycache__/ddim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/models/__pycache__/ddim.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/models/__pycache__/ddpm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/models/__pycache__/ddpm.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/models/__pycache__/plms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/models/__pycache__/plms.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/__pycache__/ema.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/__pycache__/ema.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/train_l.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes=4 --nproc_per_node=8 train.py \ 2 | --version large \ 3 | --data-path combine_dataset.json \ 4 | --global_batch_size 128 \ 5 | --global-seed 2023 -------------------------------------------------------------------------------- /scripts/train_s.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes=1 --nproc_per_node=8 train.py \ 2 | --version small \ 3 | --data-path combine_dataset.json \ 4 | --global_batch_size 128 \ 5 | --global-seed 2023 -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/models/__pycache__/__init__.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/utilities/audio/__pycache__/audio_processing.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/audio/__pycache__/audio_processing.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/train_b.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes=2 --nproc_per_node=8 train.py \ 2 | --version base \ 3 | --data-path combine_dataset.json \ 4 | --global_batch_size 128 \ 5 | --resume xxx \ 6 | --global-seed 2023 7 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/__pycache__/AudioMAE.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/__pycache__/AudioMAE.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/encoders/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/encoders/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/encoders/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/encoders/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/__pycache__/models_mae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/__pycache__/models_mae.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/__pycache__/models_vit.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/__pycache__/models_vit.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/__pycache__/pos_embed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/util/__pycache__/pos_embed.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/diffusionmodules/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/diffusionmodules/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/distributions/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/distributions/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/commons.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/commons.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/encoder.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/__pycache__/patch_embed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/util/__pycache__/patch_embed.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/distributions/__pycache__/distributions.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/distributions/__pycache__/distributions.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/attentions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/attentions.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/symbols.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/symbols.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/cleaners.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/cleaners.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/train_g.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes=4 --nproc_per_node=8 train.py \ 2 | --version giant \ 3 | --data-path combine_dataset.json \ 4 | --resume results/giant/checkpoints/0050000.pt \ 5 | --global_batch_size 128 \ 6 | --global-seed 2023 \ 7 | --accum_iter 8 8 | -------------------------------------------------------------------------------- /audioldm2/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models_v2 import Generator 2 | from .models import Generator as Generator_old 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 
512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | 
], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/PANN-6.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn6" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/HTSAT-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "base" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/PANN-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn10" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- 
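All of the model_configs JSON files in this listing follow the same schema: a top-level embed_dim plus an audio_cfg (PANN/HTSAT) or vision_cfg (CLIP) block and a text_cfg block. As a quick illustration of what those fields imply, here is a short stand-alone Python sketch (not part of the repository; the config path is an assumption relative to the repo root) that reads one config with the standard library and derives the clip duration and STFT hop count encoded in its audio_cfg:

import json
from pathlib import Path

# Assumed checkout-relative path to one of the configs shown in this listing.
cfg_path = Path("audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json")
cfg = json.loads(cfg_path.read_text())

audio = cfg["audio_cfg"]
duration_s = audio["clip_samples"] / audio["sample_rate"]  # 480000 / 48000 = 10.0 seconds per clip
n_hops = audio["clip_samples"] // audio["hop_size"]        # 480000 // 480 = 1000 STFT hops
print(cfg["embed_dim"], audio["model_type"], audio["model_name"], duration_s, n_hops)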
/audioldm2/clap/open_clip/model_configs/PANN-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | ''' 6 | _pad = '_' 7 | _punctuation = ';:,.!?¡¿—…"«»“” ' 8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 9 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | 11 | 12 | # Export all symbols: 13 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 14 | 15 | # Special symbol ids 16 | SPACE_ID = symbols.index(" ") 17 | -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/HTSAT-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "large" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 18000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- 
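The remaining configs below vary mainly in backbone and audio front-end settings (window size, fmax, clip length). Their file stems (e.g. HTSAT-tiny, PANN-14) presumably double as the model identifiers consumed by the list_models and create_model helpers exported from open_clip/__init__.py further down; the stdlib-only sketch below is an assumed layout check rather than a call into the library's own factory API, and simply enumerates the directory and summarizes each config:

import json
from pathlib import Path

# Assumed checkout-relative location of the vendored CLAP model configs.
config_dir = Path("audioldm2/clap/open_clip/model_configs")

for cfg_file in sorted(config_dir.glob("*.json")):
    cfg = json.loads(cfg_file.read_text())
    # PANN/HTSAT configs carry an audio_cfg block; the CLIP ViT/ResNet configs carry vision_cfg instead.
    backbone = cfg.get("audio_cfg") or cfg.get("vision_cfg") or {}
    name = backbone.get("model_name", "CLIP vision tower")
    print(f"{cfg_file.stem}: embed_dim={cfg['embed_dim']}, backbone={name}")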
/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 960000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 360, 10 | "fmin": 50, 11 | "fmax": 8000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 4 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import ( 2 | list_models, 3 | create_model, 4 | create_model_and_transforms, 5 | add_model_config, 6 | ) 7 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics 8 | from .model import ( 9 | CLAP, 10 | CLAPTextCfg, 11 | CLAPVisionCfg, 12 | CLAPAudioCfp, 13 | convert_weights_to_fp16, 14 | trace_model, 15 | ) 16 | from .openai import load_openai_model, list_openai_models 17 | from .pretrained import ( 18 | list_pretrained, 19 | list_pretrained_tag_models, 20 | list_pretrained_model_tags, 21 | get_pretrained_url, 22 | download_pretrained, 23 | ) 24 | from .tokenizer import SimpleTokenizer, tokenize 25 | from .transform import image_transform 26 | -------------------------------------------------------------------------------- /audioldm2/audiomae_gen/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Prenet(nn.Module): 5 | def __init__(self, in_dim, sizes=[256, 128], dropout_rate=0.5): 6 | super(Prenet, self).__init__() 7 | in_sizes = [in_dim] + sizes[:-1] 8 | self.layers = nn.ModuleList( 9 | [ 10 | nn.Linear(in_size, out_size) 11 | for (in_size, out_size) in zip(in_sizes, sizes) 12 | ] 13 | ) 14 | self.relu = nn.ReLU() 15 | self.dropout = nn.Dropout(dropout_rate) 16 | 17 | def forward(self, inputs): 18 | for linear in self.layers: 
19 | inputs = self.dropout(self.relu(linear(inputs))) 20 | return inputs 21 | 22 | 23 | if __name__ == "__main__": 24 | model = Prenet(in_dim=128, sizes=[256, 256, 128]) 25 | import ipdb 26 | 27 | ipdb.set_trace() 28 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | 10 | def adjust_learning_rate(optimizer, epoch, args): 11 | """Decay the learning rate with half-cycle cosine after warmup""" 12 | if epoch < args.warmup_epochs: 13 | lr = args.lr * epoch / args.warmup_epochs 14 | else: 15 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( 16 | 1.0 17 | + math.cos( 18 | math.pi 19 | * (epoch - args.warmup_epochs) 20 | / (args.epochs - args.warmup_epochs) 21 | ) 22 | ) 23 | for param_group in optimizer.param_groups: 24 | if "lr_scale" in param_group: 25 | param_group["lr"] = lr * param_group["lr_scale"] 26 | else: 27 | param_group["lr"] = lr 28 | return lr 29 | -------------------------------------------------------------------------------- /audioldm2/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import ( 2 | Normalize, 3 | Compose, 4 | RandomResizedCrop, 5 | InterpolationMode, 6 | ToTensor, 7 | Resize, 8 | CenterCrop, 9 | ) 10 | 11 | 12 | def _convert_to_rgb(image): 13 | return image.convert("RGB") 14 | 15 | 16 | def image_transform( 17 | image_size: int, 18 | is_train: bool, 19 | mean=(0.48145466, 0.4578275, 0.40821073), 20 | std=(0.26862954, 0.26130258, 0.27577711), 21 | ): 22 | normalize = Normalize(mean=mean, std=std) 23 | if is_train: 24 | return Compose( 25 | [ 26 | RandomResizedCrop( 27 | image_size, 28 | scale=(0.9, 1.0), 29 | interpolation=InterpolationMode.BICUBIC, 30 | ), 31 | _convert_to_rgb, 32 | ToTensor(), 33 | normalize, 34 | ] 35 | ) 36 | else: 37 | return Compose( 38 | [ 39 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 40 | CenterCrop(image_size), 41 | _convert_to_rgb, 42 | ToTensor(), 43 | normalize, 44 | ] 45 | ) 46 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.autoencoder import AutoEncoder, AutoEncoderParams 3 | from modules.conditioner import HFEmbedder 4 | from safetensors.torch import load_file as load_sft 5 | 6 | 7 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: 8 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough) 9 | return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device) 10 | 11 | 12 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: 13 | return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) 14 | 15 | 16 | def load_clap(device: str | torch.device = "cuda", max_length: int = 256) -> HFEmbedder: 17 | return HFEmbedder("laion/larger_clap_music", max_length=max_length, torch_dtype=torch.bfloat16).to(device) 18 | 19 | def load_ae(ckpt_path, device: str | torch.device = "cuda") -> AutoEncoder: 20 | ae_params = AutoEncoderParams( 21 | resolution=256, 22 | in_channels=3, 23 | ch=128, 24 | out_ch=3, 25 | ch_mult=[1, 2, 4, 4], 26 | num_res_blocks=2, 27 | z_channels=16, 28 | scale_factor=0.3611, 29 | shift_factor=0.1159, 30 | ) 31 | # Load the autoencoder and restore its weights from the safetensors checkpoint 32 | ae = AutoEncoder(ae_params) 33 | sd = load_sft(ckpt_path) 34 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) 35 | ae.to(device) 36 | return ae 37 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | 23 | @staticmethod 24 | def get_params(img, scale, ratio): 25 | width, height = F._get_image_size(img) 26 | area = height * width 27 | 28 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 29 | log_ratio = torch.log(torch.tensor(ratio)) 30 | aspect_ratio = torch.exp( 31 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 32 | ).item() 33 | 34 | w = int(round(math.sqrt(target_area * aspect_ratio))) 35 | h = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | w = min(w, width) 38 | h = min(h, height) 39 | 40 | i = torch.randint(0, height - h + 1, size=(1,)).item() 41 | j = torch.randint(0, width - w + 1, size=(1,)).item() 42 | 43 | return i, j, h, w 44 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from audioldm2.latent_diffusion.modules.phoneme_encoder.text import cleaners 3 | from audioldm2.latent_diffusion.modules.phoneme_encoder.text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | cleaner = getattr(cleaners, "english_cleaners2") 11 | 12 | def text_to_sequence(text, cleaner_names): 13 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 14 | Args: 15 | text: string to convert to a sequence 16 | cleaner_names: names of the cleaner functions to run the text through 17 | Returns: 18 | List of integers corresponding to the symbols in the text 19 | ''' 20 | sequence = [] 21 | 22 | clean_text = _clean_text(text, cleaner_names) 23 | for symbol in clean_text: 24 | symbol_id = _symbol_to_id[symbol] 25 | sequence += [symbol_id] 26 | return sequence 27 | 28 | def cleaned_text_to_sequence(cleaned_text): 29 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
30 | Args: 31 | text: string to convert to a sequence 32 | Returns: 33 | List of integers corresponding to the symbols in the text 34 | ''' 35 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 36 | return sequence 37 | 38 | def sequence_to_text(sequence): 39 | '''Converts a sequence of IDs back to a string''' 40 | result = '' 41 | for symbol_id in sequence: 42 | s = _id_to_symbol[symbol_id] 43 | result += s 44 | return result 45 | 46 | def _clean_text(text, cleaner_names): 47 | text = cleaner(text) 48 | return text 49 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons 6 | import audioldm2.latent_diffusion.modules.phoneme_encoder.attentions as attentions 7 | 8 | 9 | class TextEncoder(nn.Module): 10 | def __init__( 11 | self, 12 | n_vocab, 13 | out_channels=192, 14 | hidden_channels=192, 15 | filter_channels=768, 16 | n_heads=2, 17 | n_layers=6, 18 | kernel_size=3, 19 | p_dropout=0.1, 20 | ): 21 | super().__init__() 22 | self.n_vocab = n_vocab 23 | self.out_channels = out_channels 24 | self.hidden_channels = hidden_channels 25 | self.filter_channels = filter_channels 26 | self.n_heads = n_heads 27 | self.n_layers = n_layers 28 | self.kernel_size = kernel_size 29 | self.p_dropout = p_dropout 30 | 31 | self.emb = nn.Embedding(n_vocab, hidden_channels) 32 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 33 | 34 | self.encoder = attentions.Encoder( 35 | hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout 36 | ) 37 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 38 | 39 | def forward(self, x, x_lengths): 40 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 41 | x = torch.transpose(x, 1, -1) # [b, h, t] 42 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( 43 | x.dtype 44 | ) 45 | 46 | x = self.encoder(x * x_mask, x_mask) 47 | stats = self.proj(x) * x_mask 48 | 49 | m, logs = torch.split(stats, self.out_channels, dim=1) 50 | return x, m, logs, x_mask 51 | -------------------------------------------------------------------------------- /config/example.txt: -------------------------------------------------------------------------------- 1 | An experimental and free-jazz psych-folk tune with jazz influences from Europe. 2 | A powerful and dynamic rock track that blends industrial and electronic elements for a thrilling and energetic sound. 3 | A lively and energizing rock anthem that gets audiences pumped up and singing along. 4 | This song is a fusion of pop and experimental pop that creates a unique and innovative sound. 5 | This song is a mashup of punk, progressive, rock, and no wave genres. 6 | This song is a unique mix of experimental sounds and styles that push the boundaries of conventional music. 7 | An experimental electroacoustic instrumental composition featuring ambient soundscapes and field recordings. 8 | This rock song has a punk essence to it. 9 | This song is a fusion of both heavier metal and more melodic rock genres. 10 | This song features international sounds and influences. 11 | A chaotic, yet danceable blend of punk, new wave, and post-punk elements with electronic, experimental, and industrial influences, featuring chip music and a heavy drone focus. 
12 | A wild and rebellious rock anthem with psychedelic and bluesy undertones, this song is loud and energetic. 13 | A soothing and atmospheric instrumental folk song plays in the background. 14 | A high-energy punk rock track with grungy garage undertones. 15 | A hip hop and electronic infused tune with strong elements of dubstep. 16 | This song is a combination of pop and experimental pop. 17 | A diverse indie-rock song that incorporates elements of rock, electronic, goth and trip-hop. 18 | This song is a classic rock anthem with driving guitars and powerful lyrics. 19 | A folk song from a British artist with traditional string instruments such as the fiddle and guitar. 20 | The song is an epic blend of space-rock, rock, and post-rock genres. 21 | A noisy and experimental track with a chaotic sound. 22 | A wild electrified sound consisting of experimental jazz, avant-garde, glitch, and free-jazz, with traces of noise and electroacoustics out there. 23 | A song that incorporates avant-garde, electroacoustic, experimental and musique concrete with field recordings. 24 | This song combines elements of punk, rock, and no wave. -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, "train" if is_train else "val") 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation="bicubic", 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize( 60 | size, interpolation=PIL.Image.BICUBIC 61 | ), # to maintain same ratio w.r.t. 
224 images 62 | ) 63 | t.append(transforms.CenterCrop(args.input_size)) 64 | 65 | t.append(transforms.ToTensor()) 66 | t.append(transforms.Normalize(mean, std)) 67 | return transforms.Compose(t) 68 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | 19 | def __init__( 20 | self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001 21 | ): 22 | defaults = dict( 23 | lr=lr, 24 | weight_decay=weight_decay, 25 | momentum=momentum, 26 | trust_coefficient=trust_coefficient, 27 | ) 28 | super().__init__(params, defaults) 29 | 30 | @torch.no_grad() 31 | def step(self): 32 | for g in self.param_groups: 33 | for p in g["params"]: 34 | dp = p.grad 35 | 36 | if dp is None: 37 | continue 38 | 39 | if p.ndim > 1: # if not normalization gamma/beta or bias 40 | dp = dp.add(p, alpha=g["weight_decay"]) 41 | param_norm = torch.norm(p) 42 | update_norm = torch.norm(dp) 43 | one = torch.ones_like(param_norm) 44 | q = torch.where( 45 | param_norm > 0.0, 46 | torch.where( 47 | update_norm > 0, 48 | (g["trust_coefficient"] * param_norm / update_norm), 49 | one, 50 | ), 51 | one, 52 | ) 53 | dp = dp.mul(q) 54 | 55 | param_state = self.state[p] 56 | if "mu" not in param_state: 57 | param_state["mu"] = torch.zeros_like(p) 58 | mu = param_state["mu"] 59 | mu.mul_(g["momentum"]).add_(dp) 60 | p.add_(mu, alpha=-g["lr"]) 61 | -------------------------------------------------------------------------------- /config/16k_64.yaml: -------------------------------------------------------------------------------- 1 | metadata_root: "./data/dataset/metadata/dataset_root.json" 2 | log_directory: "./log/latent_diffusion" 3 | project: "audioldm" 4 | precision: "high" 5 | 6 | variables: 7 | sampling_rate: &sampling_rate 16000 8 | mel_bins: &mel_bins 64 9 | latent_embed_dim: &latent_embed_dim 8 10 | latent_t_size: &latent_t_size 256 # TODO might need to change 11 | latent_f_size: &latent_f_size 16 12 | in_channels: &unet_in_channels 8 13 | optimize_ddpm_parameter: &optimize_ddpm_parameter true 14 | optimize_gpt: &optimize_gpt true 15 | warmup_steps: &warmup_steps 2000 16 | 17 | data: 18 | train: ["audiocaps"] 19 | val: "audiocaps" 20 | test: "audiocaps" 21 | class_label_indices: "audioset_eval_subset" 22 | dataloader_add_ons: ["waveform_rs_48k"] 23 | 24 | step: 25 | validation_every_n_epochs: 15 26 | save_checkpoint_every_n_steps: 5000 27 | # limit_val_batches: 2 28 | max_steps: 800000 29 | save_top_k: 1 30 | 31 | preprocessing: 32 | audio: 33 | sampling_rate: *sampling_rate 34 | max_wav_value: 32768.0 35 | duration: 10.24 36 | stft: 37 | filter_length: 1024 38 | hop_length: 160 39 | win_length: 1024 40 | mel: 41 | n_mel_channels: *mel_bins 42 | mel_fmin: 0 43 | mel_fmax: 8000 44 | 45 | augmentation: 46 | mixup: 0.0 47 | 48 | model: 49 | 
base_learning_rate: 8.0e-06 50 | target: audioldm_train.modules.latent_encoder.autoencoder.AutoencoderKL 51 | params: 52 | # reload_from_ckpt: "data/checkpoints/vae_mel_16k_64bins.ckpt" 53 | sampling_rate: *sampling_rate 54 | batchsize: 4 55 | monitor: val/rec_loss 56 | image_key: fbank 57 | subband: 1 58 | embed_dim: *latent_embed_dim 59 | time_shuffle: 1 60 | lossconfig: 61 | target: audioldm_train.losses.LPIPSWithDiscriminator 62 | params: 63 | disc_start: 50001 64 | kl_weight: 1000.0 65 | disc_weight: 0.5 66 | disc_in_channels: 1 67 | ddconfig: 68 | double_z: true 69 | mel_bins: *mel_bins # The frequency bins of mel spectrogram 70 | z_channels: 8 71 | resolution: 256 72 | downsample_time: false 73 | in_channels: 1 74 | out_ch: 1 75 | ch: 128 76 | ch_mult: 77 | - 1 78 | - 2 79 | - 4 80 | num_res_blocks: 2 81 | attn_resolutions: [] 82 | dropout: 0.0 83 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/stat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | from sklearn import metrics 4 | import torch 5 | 6 | 7 | def d_prime(auc): 8 | standard_normal = stats.norm() 9 | d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) 10 | return d_prime 11 | 12 | 13 | @torch.no_grad() 14 | def concat_all_gather(tensor): 15 | """ 16 | Performs all_gather operation on the provided tensors. 17 | *** Warning ***: torch.distributed.all_gather has no gradient. 18 | """ 19 | tensors_gather = [ 20 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 21 | ] 22 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 23 | 24 | output = torch.cat(tensors_gather, dim=0) 25 | return output 26 | 27 | 28 | def calculate_stats(output, target): 29 | """Calculate statistics including mAP, AUC, etc. 30 | 31 | Args: 32 | output: 2d array, (samples_num, classes_num) 33 | target: 2d array, (samples_num, classes_num) 34 | 35 | Returns: 36 | stats: list of statistic of each class. 
37 | """ 38 | 39 | classes_num = target.shape[-1] 40 | stats = [] 41 | 42 | # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet 43 | acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) 44 | 45 | # Class-wise statistics 46 | for k in range(classes_num): 47 | # Average precision 48 | avg_precision = metrics.average_precision_score( 49 | target[:, k], output[:, k], average=None 50 | ) 51 | 52 | # AUC 53 | # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None) 54 | 55 | # Precisions, recalls 56 | (precisions, recalls, thresholds) = metrics.precision_recall_curve( 57 | target[:, k], output[:, k] 58 | ) 59 | 60 | # FPR, TPR 61 | (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k]) 62 | 63 | save_every_steps = 1000 # Sample statistics to reduce size 64 | dict = { 65 | "precisions": precisions[0::save_every_steps], 66 | "recalls": recalls[0::save_every_steps], 67 | "AP": avg_precision, 68 | "fpr": fpr[0::save_every_steps], 69 | "fnr": 1.0 - tpr[0::save_every_steps], 70 | # 'auc': auc, 71 | # note acc is not class-wise, this is just to keep consistent with other metrics 72 | "acc": acc, 73 | } 74 | stats.append(dict) 75 | 76 | return stats 77 | -------------------------------------------------------------------------------- /modules/conditioner.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, 3 | T5Tokenizer, AutoTokenizer, ClapTextModel) 4 | 5 | 6 | class HFEmbedder(nn.Module): 7 | def __init__(self, version: str, max_length: int, **hf_kwargs): 8 | super().__init__() 9 | self.is_t5 = version.startswith("google") 10 | self.max_length = max_length 11 | self.output_key = "last_hidden_state" if self.is_t5 else "pooler_output" 12 | 13 | if version.startswith("openai"): 14 | local_path = '/maindata/data/shared/multimodal/public/ckpts/stable-diffusion-3-medium-diffusers/text_encoder' 15 | local_path_tokenizer = '/maindata/data/shared/multimodal/public/ckpts/stable-diffusion-3-medium-diffusers/tokenizer' 16 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(local_path_tokenizer, max_length=max_length) 17 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(local_path, **hf_kwargs).half() 18 | elif version.startswith("laion"): 19 | local_path = '/maindata/data/shared/multimodal/public/dataset_music/clap' 20 | self.tokenizer = AutoTokenizer.from_pretrained(local_path, max_length=max_length) 21 | self.hf_module: ClapTextModel = ClapTextModel.from_pretrained(local_path, **hf_kwargs).half() 22 | else: 23 | local_path = '/maindata/data/shared/multimodal/public/ckpts/stable-diffusion-3-medium-diffusers/text_encoder_3' 24 | local_path_tokenizer = '/maindata/data/shared/multimodal/public/ckpts/stable-diffusion-3-medium-diffusers/tokenizer_3' 25 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(local_path_tokenizer, max_length=max_length) 26 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(local_path, **hf_kwargs).half() 27 | 28 | self.hf_module = self.hf_module.eval().requires_grad_(False) 29 | 30 | def forward(self, text: list[str]) -> Tensor: 31 | batch_encoding = self.tokenizer( 32 | text, 33 | truncation=True, 34 | max_length=self.max_length, 35 | return_length=False, 36 | return_overflowing_tokens=False, 37 | padding="max_length", 38 | return_tensors="pt", 39 | ) 40 | 41 | outputs = self.hf_module( 42 
| input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 43 | attention_mask=None, 44 | output_hidden_states=False, 45 | ) 46 | return outputs[self.output_key] 47 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | 13 | def param_groups_lrd( 14 | model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 15 | ): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0.0 35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ["cls_token", "pos_embed"]: 70 | return 0 71 | elif name.startswith("patch_embed"): 72 | return 0 73 | elif name.startswith("blocks"): 74 | return int(name.split(".")[1]) + 1 75 | else: 76 | return num_layers 77 | -------------------------------------------------------------------------------- /audioldm2/utilities/audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | import torchaudio 5 | 6 | from audioldm2.utilities.audio.audio_processing import griffin_lim 7 | 8 | 9 | def pad_wav(waveform, segment_length): 10 | waveform_length = waveform.shape[-1] 11 | assert waveform_length > 100, "Waveform is too short, %s" % waveform_length 12 | if segment_length is None or waveform_length == segment_length: 13 | return waveform 14 | elif waveform_length > segment_length: 15 | return waveform[:segment_length] 16 | elif waveform_length < 
segment_length: 17 | temp_wav = np.zeros((1, segment_length)) 18 | temp_wav[:, :waveform_length] = waveform 19 | return temp_wav 20 | 21 | 22 | def normalize_wav(waveform): 23 | waveform = waveform - np.mean(waveform) 24 | waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) 25 | return waveform * 0.5 26 | 27 | 28 | def read_wav_file(filename, segment_length): 29 | # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower 30 | waveform, sr = torchaudio.load(filename) # Faster!!! 31 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) 32 | waveform = waveform.numpy()[0, ...] 33 | waveform = normalize_wav(waveform) 34 | waveform = waveform[None, ...] 35 | waveform = pad_wav(waveform, segment_length) 36 | 37 | waveform = waveform / np.max(np.abs(waveform)) 38 | waveform = 0.5 * waveform 39 | 40 | return waveform 41 | 42 | 43 | def get_mel_from_wav(audio, _stft): 44 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 45 | audio = torch.autograd.Variable(audio, requires_grad=False) 46 | melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio) 47 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 48 | magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32) 49 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 50 | return melspec, magnitudes, energy 51 | 52 | 53 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 54 | mel = torch.stack([mel]) 55 | mel_decompress = _stft.spectral_de_normalize(mel) 56 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 57 | spec_from_mel_scaling = 1000 58 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 59 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 60 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 61 | 62 | audio = griffin_lim( 63 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters 64 | ) 65 | 66 | audio = audio.squeeze() 67 | audio = audio.cpu().numpy() 68 | audio_path = out_filename 69 | write(audio_path, _stft.sampling_rate, audio) 70 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | from model import FluxParams, Flux 2 | 3 | def build_model(version='base'): 4 | if version == 'base': 5 | params=FluxParams( 6 | in_channels=32, 7 | vec_in_dim=768, 8 | context_in_dim=4096, 9 | hidden_size=768, 10 | mlp_ratio=4.0, 11 | num_heads=16, 12 | depth=12, 13 | depth_single_blocks=24, 14 | axes_dim=[16, 16, 16], 15 | theta=10_000, 16 | qkv_bias=True, 17 | guidance_embed=True, 18 | ) 19 | 20 | elif version == 'small': 21 | params=FluxParams( 22 | in_channels=32, 23 | vec_in_dim=768, 24 | context_in_dim=4096, 25 | hidden_size=512, 26 | mlp_ratio=4.0, 27 | num_heads=16, 28 | depth=8, 29 | depth_single_blocks=16, 30 | axes_dim=[8, 12, 12], 31 | theta=10_000, 32 | qkv_bias=True, 33 | guidance_embed=True, 34 | ) 35 | 36 | 37 | elif version == 'large': 38 | params=FluxParams( 39 | in_channels=32, 40 | vec_in_dim=768, 41 | context_in_dim=4096, 42 | hidden_size=1024, 43 | mlp_ratio=4.0, 44 | num_heads=16, 45 | depth=12, 46 | depth_single_blocks=24, 47 | axes_dim=[16, 24, 24], 48 | theta=10_000, 49 | qkv_bias=True, 50 | guidance_embed=True, 51 | ) 52 | 53 | elif version == 'biggiant': 54 | params=FluxParams( 55 | in_channels=32, 56 | vec_in_dim=768, 57 | context_in_dim=4096, 58 | hidden_size=2048, 59 | mlp_ratio=4.0, 60 | num_heads=16, 61 | 
depth=19, 62 | depth_single_blocks=38, 63 | axes_dim=[32, 48, 48], 64 | theta=10_000, 65 | qkv_bias=True, 66 | guidance_embed=True, 67 | ) 68 | 69 | elif version == 'giant_full': 70 | params=FluxParams( 71 | in_channels=32, 72 | vec_in_dim=768, 73 | context_in_dim=4096, 74 | hidden_size=1408, 75 | mlp_ratio=4.0, 76 | num_heads=16, 77 | depth=12, 78 | depth_single_blocks=24, 79 | axes_dim=[16, 36, 36], 80 | theta=10_000, 81 | qkv_bias=True, 82 | guidance_embed=True, 83 | ) 84 | 85 | else: 86 | params=FluxParams( 87 | in_channels=32, 88 | vec_in_dim=768, 89 | context_in_dim=4096, 90 | hidden_size=1408, 91 | mlp_ratio=4.0, 92 | num_heads=16, 93 | depth=16, 94 | depth_single_blocks=32, 95 | axes_dim=[16, 36, 36], 96 | theta=10_000, 97 | qkv_bias=True, 98 | guidance_embed=True, 99 | ) 100 | 101 | 102 | model = Flux(params) 103 | return model -------------------------------------------------------------------------------- /audioldm2/utilities/audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 
39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return normalize_fun(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_( 47 | one_minus_decay * (shadow_params[sname] - m_param[key]) 48 | ) 49 | else: 50 | assert not key in self.m_name2s_name 51 
| 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from phonemizer import phonemize 18 | 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ]] 44 | 45 | 46 | def expand_abbreviations(text): 47 | for regex, replacement in _abbreviations: 48 | text = re.sub(regex, replacement, text) 49 | return text 50 | 51 | 52 | def expand_numbers(text): 53 | return normalize_numbers(text) 54 | 55 | 56 | def lowercase(text): 57 | return text.lower() 58 | 59 | 60 | def collapse_whitespace(text): 61 | return re.sub(_whitespace_re, ' ', text) 62 | 63 | 64 | def convert_to_ascii(text): 65 | return unidecode(text) 66 | 67 | 68 | def basic_cleaners(text): 69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 70 | text = lowercase(text) 71 | text = collapse_whitespace(text) 72 | return text 73 | 74 | 75 | def transliteration_cleaners(text): 76 | '''Pipeline for non-English text that transliterates to ASCII.''' 77 | text = convert_to_ascii(text) 78 | text = lowercase(text) 79 | text = collapse_whitespace(text) 80 | return text 81 | 82 | 83 | def english_cleaners(text): 84 | '''Pipeline for English text, including abbreviation expansion.''' 85 | text = convert_to_ascii(text) 86 | text = lowercase(text) 87 | text = expand_abbreviations(text) 88 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True) 89 | phonemes = collapse_whitespace(phonemes) 90 | return phonemes 91 | 92 | 93 | def english_cleaners2(text): 94 | '''Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress''' 95 | text = convert_to_ascii(text) 96 | text = lowercase(text) 97 | text = expand_abbreviations(text) 98 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True, preserve_punctuation=True, with_stress=True) 99 | phonemes = collapse_whitespace(phonemes) 100 | return phonemes 101 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.mean( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.mean( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
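# Added note: the expression returned below is the standard closed form
#   KL(N(mu1, s1^2) || N(mu2, s2^2))
#     = 0.5 * (log(s2^2) - log(s1^2) + (s1^2 + (mu1 - mu2)^2) / s2^2 - 1),
# written here in terms of log-variances for numerical stability.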
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | from modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, 7 | MLPEmbedder, SingleStreamBlock, 8 | timestep_embedding) 9 | 10 | 11 | @dataclass 12 | class FluxParams: 13 | in_channels: int 14 | vec_in_dim: int 15 | context_in_dim: int 16 | hidden_size: int 17 | mlp_ratio: float 18 | num_heads: int 19 | depth: int 20 | depth_single_blocks: int 21 | axes_dim: list[int] 22 | theta: int 23 | qkv_bias: bool 24 | guidance_embed: bool 25 | 26 | 27 | class Flux(nn.Module): 28 | """ 29 | Transformer model for flow matching on sequences. 30 | """ 31 | 32 | def __init__(self, params: FluxParams): 33 | super().__init__() 34 | 35 | self.params = params 36 | self.in_channels = params.in_channels 37 | self.out_channels = self.in_channels 38 | if params.hidden_size % params.num_heads != 0: 39 | raise ValueError( 40 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" 41 | ) 42 | pe_dim = params.hidden_size // params.num_heads 43 | if sum(params.axes_dim) != pe_dim: 44 | raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") 45 | self.hidden_size = params.hidden_size 46 | self.num_heads = params.num_heads 47 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) 48 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) 49 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 50 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) 51 | # self.guidance_in = ( 52 | # MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() 53 | # ) 54 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) 55 | 56 | self.double_blocks = nn.ModuleList( 57 | [ 58 | DoubleStreamBlock( 59 | self.hidden_size, 60 | self.num_heads, 61 | mlp_ratio=params.mlp_ratio, 62 | qkv_bias=params.qkv_bias, 63 | ) 64 | for _ in range(params.depth) 65 | ] 66 | ) 67 | 68 | self.single_blocks = nn.ModuleList( 69 | [ 70 | SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) 71 | for _ in range(params.depth_single_blocks) 72 | ] 73 | ) 74 | 75 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 76 | 77 | def forward( 78 | self, 79 | x: Tensor, 80 | img_ids: Tensor, 81 | txt: Tensor, 82 | txt_ids: Tensor, 83 | t: Tensor, 84 | y: Tensor, 85 | guidance: Tensor | None = None, 86 | ) -> Tensor: 87 | if x.ndim != 3 or txt.ndim != 3: 88 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 89 | 90 | # running on sequences img 91 | img = self.img_in(x) 92 | vec = self.time_in(timestep_embedding(t, 256)) 93 | # if self.params.guidance_embed: 94 | # if guidance is None: 95 | # raise ValueError("Didn't get guidance strength for guidance distilled model.") 96 | # vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) 97 | vec = vec + self.vector_in(y) 98 | txt = self.txt_in(txt) 99 | 100 | ids = torch.cat((txt_ids, img_ids), 
dim=1) 101 | pe = self.pe_embedder(ids) 102 | 103 | for block in self.double_blocks: 104 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe) 105 | 106 | img = torch.cat((txt, img), 1) 107 | for block in self.single_blocks: 108 | img = block(img, vec=vec, pe=pe) 109 | img = img[:, txt.shape[1] :, ...] 110 | 111 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 112 | return img 113 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import math 5 | from einops import rearrange, repeat 6 | from PIL import Image 7 | from diffusers import AutoencoderKL 8 | from transformers import SpeechT5HifiGan 9 | 10 | from utils import load_t5, load_clap, load_ae 11 | from train import RF 12 | from constants import build_model 13 | 14 | 15 | def prepare(t5, clip, img, prompt): 16 | bs, c, h, w = img.shape 17 | if bs == 1 and not isinstance(prompt, str): 18 | bs = len(prompt) 19 | 20 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 21 | if img.shape[0] == 1 and bs > 1: 22 | img = repeat(img, "1 ... -> bs ...", bs=bs) 23 | 24 | img_ids = torch.zeros(h // 2, w // 2, 3) 25 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 26 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 27 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 28 | 29 | if isinstance(prompt, str): 30 | prompt = [prompt] 31 | txt = t5(prompt) 32 | if txt.shape[0] == 1 and bs > 1: 33 | txt = repeat(txt, "1 ... -> bs ...", bs=bs) 34 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 35 | 36 | vec = clip(prompt) 37 | if vec.shape[0] == 1 and bs > 1: 38 | vec = repeat(vec, "1 ... -> bs ...", bs=bs) 39 | 40 | print(img_ids.size(), txt.size(), vec.size()) 41 | return img, { 42 | "img_ids": img_ids.to(img.device), 43 | "txt": txt.to(img.device), 44 | "txt_ids": txt_ids.to(img.device), 45 | "y": vec.to(img.device), 46 | } 47 | 48 | def main(args): 49 | print('generate with MusicFlux') 50 | torch.manual_seed(args.seed) 51 | torch.set_grad_enabled(False) 52 | device = "cuda" if torch.cuda.is_available() else "cpu" 53 | 54 | latent_size = (256, 16) 55 | 56 | model = build_model(args.version).to(device) 57 | local_path = args.ckpt_path 58 | state_dict = torch.load(local_path, map_location=lambda storage, loc: storage) 59 | model.load_state_dict(state_dict['ema']) 60 | model.eval() # important! 
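# Overview of the sampling pipeline below (as implemented in this script):
# text prompts are encoded with T5 (sequence features) and CLAP (pooled vector),
# the rectified-flow sampler (RF) denoises latents with classifier-free guidance,
# the AudioLDM2 VAE decodes the latents to a mel spectrogram, and the SpeechT5
# HiFi-GAN vocoder converts the mel spectrogram to a 16 kHz waveform.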
61 | diffusion = RF() 62 | 63 | # Setup VAE 64 | t5 = load_t5(device, max_length=256) 65 | clap = load_clap(device, max_length=256) 66 | 67 | vae = AutoencoderKL.from_pretrained(os.path.join(args.audioldm2_model_path, 'vae')).to(device) 68 | vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(args.audioldm2_model_path, 'vocoder')).to(device) 69 | 70 | with open(args.prompt_file, 'r') as f: 71 | conds_txt = f.readlines() 72 | L = len(conds_txt) 73 | unconds_txt = ["low quality, gentle"] * L 74 | print(L, conds_txt, unconds_txt) 75 | 76 | init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).cuda() 77 | 78 | STEPSIZE = 50 79 | img, conds = prepare(t5, clap, init_noise, conds_txt) 80 | _, unconds = prepare(t5, clap, init_noise, unconds_txt) 81 | with torch.autocast(device_type='cuda'): 82 | images = diffusion.sample_with_xps(model, img, conds=conds, null_cond=unconds, sample_steps = STEPSIZE, cfg = 7.0) 83 | 84 | print(images[-1].size(), ) 85 | 86 | images = rearrange( 87 | images[-1], 88 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 89 | h=128, 90 | w=8, 91 | ph=2, 92 | pw=2,) 93 | # print(images.size()) 94 | latents = 1 / vae.config.scaling_factor * images 95 | mel_spectrogram = vae.decode(latents).sample 96 | print(mel_spectrogram.size()) 97 | 98 | for i in range(L): 99 | x_i = mel_spectrogram[i] 100 | if x_i.dim() == 4: 101 | x_i = x_i.squeeze(1) 102 | waveform = vocoder(x_i) 103 | waveform = waveform[0].cpu().float().detach().numpy() 104 | print(waveform.shape) 105 | # import soundfile as sf 106 | # sf.write('reconstruct.wav', waveform, samplerate=16000) 107 | from scipy.io import wavfile 108 | wavfile.write('wav/sample_' + str(i) + '.wav', 16000, waveform) 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--version", type=str, default="small") 114 | parser.add_argument("--prompt_file", type=str, default='config/example.txt') 115 | parser.add_argument("--ckpt_path", type=str, default='musicflow_s.pt') 116 | parser.add_argument("--audioldm2_model_path", type=str, default='/maindata/data/shared/multimodal/public/dataset_music/audioldm2' ) 117 | parser.add_argument("--seed", type=int, default=2024) 118 | args = parser.parse_args() 119 | main(args) 120 | 121 | 122 | -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import ( 14 | AttentionPool2d as AbsAttentionPool2d, 15 | ) 16 | except ImportError: 17 | timm = None 18 | 19 | from .utils import freeze_batch_norm_2d 20 | 21 | 22 | class TimmModel(nn.Module): 23 | """timm model adapter 24 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model_name, 30 | embed_dim, 31 | image_size=224, 32 | pool="avg", 33 | proj="linear", 34 | drop=0.0, 35 | pretrained=False, 36 | ): 37 | super().__init__() 38 | if timm is None: 39 | raise RuntimeError("Please `pip install timm` to use timm models.") 40 | 41 | self.image_size = to_2tuple(image_size) 42 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 43 | feat_size = self.trunk.default_cfg.get("pool_size", None) 44 | feature_ndim = 1 if not feat_size else 2 45 | if pool in ("abs_attn", "rot_attn"): 46 | assert feature_ndim == 2 47 | # if attn pooling used, remove both classifier and default pool 48 | self.trunk.reset_classifier(0, global_pool="") 49 | else: 50 | # reset global pool if pool config set, otherwise leave as network default 51 | reset_kwargs = dict(global_pool=pool) if pool else {} 52 | self.trunk.reset_classifier(0, **reset_kwargs) 53 | prev_chs = self.trunk.num_features 54 | 55 | head_layers = OrderedDict() 56 | if pool == "abs_attn": 57 | head_layers["pool"] = AbsAttentionPool2d( 58 | prev_chs, feat_size=feat_size, out_features=embed_dim 59 | ) 60 | prev_chs = embed_dim 61 | elif pool == "rot_attn": 62 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 63 | prev_chs = embed_dim 64 | else: 65 | assert proj, "projection layer needed if non-attention pooling is used." 
66 | 67 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 68 | if proj == "linear": 69 | head_layers["drop"] = nn.Dropout(drop) 70 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim) 71 | elif proj == "mlp": 72 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 73 | 74 | self.head = nn.Sequential(head_layers) 75 | 76 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 77 | """lock modules 78 | Args: 79 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 80 | """ 81 | if not unlocked_groups: 82 | # lock full model 83 | for param in self.trunk.parameters(): 84 | param.requires_grad = False 85 | if freeze_bn_stats: 86 | freeze_batch_norm_2d(self.trunk) 87 | else: 88 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 89 | try: 90 | # FIXME import here until API stable and in an official release 91 | from timm.models.helpers import group_parameters, group_modules 92 | except ImportError: 93 | raise RuntimeError( 94 | "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" 95 | ) 96 | matcher = self.trunk.group_matcher() 97 | gparams = group_parameters(self.trunk, matcher) 98 | max_layer_id = max(gparams.keys()) 99 | max_layer_id = max_layer_id - unlocked_groups 100 | for group_idx in range(max_layer_id + 1): 101 | group = gparams[group_idx] 102 | for param in group: 103 | self.trunk.get_parameter(param).requires_grad = False 104 | if freeze_bn_stats: 105 | gmodules = group_modules(self.trunk, matcher, reverse=True) 106 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 107 | freeze_batch_norm_2d(self.trunk, gmodules) 108 | 109 | def forward(self, x): 110 | x = self.trunk(x) 111 | x = self.head(x) 112 | return x 113 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import to_2tuple 4 | 5 | 6 | class PatchEmbed_org(nn.Module): 7 | """Image to Patch Embedding""" 8 | 9 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 10 | super().__init__() 11 | img_size = to_2tuple(img_size) 12 | patch_size = to_2tuple(patch_size) 13 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 14 | self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) 15 | self.img_size = img_size 16 | self.patch_size = patch_size 17 | self.num_patches = num_patches 18 | 19 | self.proj = nn.Conv2d( 20 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 21 | ) 22 | 23 | def forward(self, x): 24 | B, C, H, W = x.shape 25 | # FIXME look at relaxing size constraints 26 | # assert H == self.img_size[0] and W == self.img_size[1], \ 27 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
28 | x = self.proj(x) 29 | y = x.flatten(2).transpose(1, 2) 30 | return y 31 | 32 | 33 | class PatchEmbed_new(nn.Module): 34 | """Flexible Image to Patch Embedding""" 35 | 36 | def __init__( 37 | self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 38 | ): 39 | super().__init__() 40 | img_size = to_2tuple(img_size) 41 | patch_size = to_2tuple(patch_size) 42 | stride = to_2tuple(stride) 43 | 44 | self.img_size = img_size 45 | self.patch_size = patch_size 46 | 47 | self.proj = nn.Conv2d( 48 | in_chans, embed_dim, kernel_size=patch_size, stride=stride 49 | ) # with overlapped patches 50 | # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 51 | 52 | # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) 53 | # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 54 | _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w 55 | self.patch_hw = (h, w) 56 | self.num_patches = h * w 57 | 58 | def get_output_shape(self, img_size): 59 | # todo: don't be lazy.. 60 | return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape 61 | 62 | def forward(self, x): 63 | B, C, H, W = x.shape 64 | # FIXME look at relaxing size constraints 65 | # assert H == self.img_size[0] and W == self.img_size[1], \ 66 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 67 | # x = self.proj(x).flatten(2).transpose(1, 2) 68 | x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 69 | x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 70 | x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 71 | return x 72 | 73 | 74 | class PatchEmbed3D_new(nn.Module): 75 | """Flexible Image to Patch Embedding""" 76 | 77 | def __init__( 78 | self, 79 | video_size=(16, 224, 224), 80 | patch_size=(2, 16, 16), 81 | in_chans=3, 82 | embed_dim=768, 83 | stride=(2, 16, 16), 84 | ): 85 | super().__init__() 86 | 87 | self.video_size = video_size 88 | self.patch_size = patch_size 89 | self.in_chans = in_chans 90 | 91 | self.proj = nn.Conv3d( 92 | in_chans, embed_dim, kernel_size=patch_size, stride=stride 93 | ) 94 | _, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w 95 | self.patch_thw = (t, h, w) 96 | self.num_patches = t * h * w 97 | 98 | def get_output_shape(self, video_size): 99 | # todo: don't be lazy.. 100 | return self.proj( 101 | torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2]) 102 | ).shape 103 | 104 | def forward(self, x): 105 | B, C, T, H, W = x.shape 106 | x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14 107 | x = x.flatten(2) # 32, 768, 1568 108 | x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768 109 | return x 110 | 111 | 112 | if __name__ == "__main__": 113 | # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16)) 114 | # input = torch.rand(8,1,1024,128) 115 | # output = patch_emb(input) 116 | # print(output.shape) # (8,512,64) 117 | 118 | patch_emb = PatchEmbed3D_new( 119 | video_size=(6, 224, 224), 120 | patch_size=(2, 16, 16), 121 | in_chans=3, 122 | embed_dim=768, 123 | stride=(2, 16, 16), 124 | ) 125 | input = torch.rand(8, 3, 6, 224, 224) 126 | output = patch_emb(input) 127 | print(output.shape) # (8,64) 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## FluxMusic: Text-to-Music Generation with Rectified Flow Transformer
Official PyTorch Implementation 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for the paper *Flux that plays music*. 11 | It explores a simple extension of diffusion-based rectified flow Transformers to text-to-music generation. The model architecture is shown below: 12 | 13 | 14 | 15 | 16 | ### To-do list 17 | 18 | - [x] training / inference scripts 19 | - [x] clean code 20 | - [x] all ckpts and part of the dataset 21 | 22 | 23 | ### 1. Training 24 | 25 | You can refer to the [link](https://github.com/black-forest-labs/flux) to build the running environment. 26 | 27 | To launch latent-space training of the small version with `N` GPUs on one node using PyTorch DDP: 28 | ```bash 29 | torchrun --nnodes=1 --nproc_per_node=N train.py \ 30 | --version small \ 31 | --data-path xxx \ 32 | --global_batch_size 128 33 | ``` 34 | 35 | Training scripts for the other model sizes can be found in the `scripts` directory. 36 | 37 | 38 | ### 2. Inference 39 | 40 | We include a [`sample.py`](sample.py) script that samples music clips from a FluxMusic model conditioned on text prompts: 41 | ```bash 42 | python sample.py \ 43 | --version small \ 44 | --ckpt_path /path/to/model \ 45 | --prompt_file config/example.txt 46 | ``` 47 | 48 | All prompts used in the paper are listed in `config/example.txt`. 49 | 50 | 51 | ### 3. Download Ckpts and Data 52 | 53 | We use the VAE and vocoder from AudioLDM2, together with CLAP-L and T5-XXL. You can download them directly from the table below; we also provide the training scripts used in our experiments. 54 | 55 | Note that some runs were restarted due to a machine malfunction, so a few scripts include resume options. 56 | 57 | 58 | | Model | Training steps | URL | Training scripts | 59 | |-------|--------|------------------|---------| 60 | | VAE | - | [link](https://huggingface.co/cvssp/audioldm2/tree/main/vae) | - | 61 | | Vocoder | - | [link](https://huggingface.co/cvssp/audioldm2/tree/main/vocoder) | - | 62 | | T5-XXL | - | [link](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/tree/main/text_encoder_3) | - | 63 | | CLAP-L | - | [link](https://huggingface.co/laion/larger_clap_music/tree/main) | - | 64 | | FluxMusic-Small | 200K | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_s.pt) | [link](https://github.com/feizc/FluxMusic/blob/main/scripts/train_s.sh) | 65 | | FluxMusic-Base | 200K | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_b.pt) | [link](https://github.com/feizc/FluxMusic/blob/main/scripts/train_b.sh) | 66 | | FluxMusic-Large | 200K | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_l.pt) | [link](https://github.com/feizc/FluxMusic/blob/main/scripts/train_l.sh) | 67 | | FluxMusic-Giant | 200K | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_g.pt) | [link](https://github.com/feizc/FluxMusic/blob/main/scripts/train_g.sh) | 68 | | FluxMusic-Giant-Full | 2M | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_g_full.pt) | - | 69 | 70 | 71 | Note that the 200K-step checkpoints are trained on a subset of the training data and were used for the scaling experiments and case studies in the paper. 72 | The full version used for the main results will be released right away. 73 | 74 | For the construction of the training data, refer to `test.py`, which shows a simple example of combining different datasets into a single JSON metadata file; a hypothetical sketch is given below.
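The sketch below is illustrative only: the field names (`wav`, `caption`, `dataset`), file paths, and dataset names are hypothetical and are not taken from `test.py`; check that file for the actual schema expected by the data loader.

```python
import json

# Merge several per-dataset metadata files into one training JSON,
# keeping an audio path and a text caption for every clip.
entries = []
for dataset_name, meta_path in [
    ("dataset_a", "data/dataset_a/meta.json"),  # hypothetical paths
    ("dataset_b", "data/dataset_b/meta.json"),
]:
    with open(meta_path) as f:
        for item in json.load(f):
            entries.append({
                "dataset": dataset_name,
                "wav": item["wav"],          # path to the audio clip
                "caption": item["caption"],  # text prompt describing the music
            })

with open("data/combined_meta.json", "w") as f:
    json.dump(entries, f, indent=2)
```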
75 | 76 | Considering copyright issues, the data used in the paper needs to be downloaded by oneself. 77 | 78 | We provide a clean subset in:   79 | 80 | A quick download link for other datasets can be found in [Huggingface](https://huggingface.co/datasets?search=music) : ). 81 | 82 | This is a research project, and it is recommended to try advanced products: 83 |   84 | 85 | 86 | ### Acknowledgments 87 | 88 | The codebase is based on the awesome [Flux](https://github.com/black-forest-labs/flux) and [AudioLDM2](https://github.com/haoheliu/AudioLDM2) repos. 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /audioldm2/__main__.py: -------------------------------------------------------------------------------- 1 | #!D:\GitDownload\SupThirdParty\audioldm2\venv\Scripts\python.exe 2 | import os 3 | import torch 4 | import logging 5 | from audioldm2 import text_to_audio, build_model, save_wave, get_time, read_list 6 | import argparse 7 | 8 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 9 | matplotlib_logger = logging.getLogger('matplotlib') 10 | matplotlib_logger.setLevel(logging.WARNING) 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument( 15 | "-t", 16 | "--text", 17 | type=str, 18 | required=False, 19 | default="", 20 | help="Text prompt to the model for audio generation", 21 | ) 22 | 23 | parser.add_argument( 24 | "--transcription", 25 | type=str, 26 | required=False, 27 | default="", 28 | help="Transcription for Text-to-Speech", 29 | ) 30 | 31 | parser.add_argument( 32 | "-tl", 33 | "--text_list", 34 | type=str, 35 | required=False, 36 | default="", 37 | help="A file that contains text prompt to the model for audio generation", 38 | ) 39 | 40 | parser.add_argument( 41 | "-s", 42 | "--save_path", 43 | type=str, 44 | required=False, 45 | help="The path to save model output", 46 | default="./output", 47 | ) 48 | 49 | parser.add_argument( 50 | "--model_name", 51 | type=str, 52 | required=False, 53 | help="The checkpoint you gonna use", 54 | default="audioldm_48k", 55 | choices=["audioldm_48k", "audioldm_16k_crossattn_t5", "audioldm2-full", "audioldm2-music-665k", 56 | "audioldm2-full-large-1150k", "audioldm2-speech-ljspeech", "audioldm2-speech-gigaspeech"] 57 | ) 58 | 59 | parser.add_argument( 60 | "-d", 61 | "--device", 62 | type=str, 63 | required=False, 64 | help="The device for computation. If not specified, the script will automatically choose the device based on your environment.", 65 | default="auto", 66 | ) 67 | 68 | parser.add_argument( 69 | "-b", 70 | "--batchsize", 71 | type=int, 72 | required=False, 73 | default=1, 74 | help="Generate how many samples at the same time", 75 | ) 76 | 77 | parser.add_argument( 78 | "--ddim_steps", 79 | type=int, 80 | required=False, 81 | default=200, 82 | help="The sampling step for DDIM", 83 | ) 84 | 85 | parser.add_argument( 86 | "-gs", 87 | "--guidance_scale", 88 | type=float, 89 | required=False, 90 | default=3.5, 91 | help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", 92 | ) 93 | 94 | parser.add_argument( 95 | "-dur", 96 | "--duration", 97 | type=float, 98 | required=False, 99 | default=10.0, 100 | help="The duration of the samples", 101 | ) 102 | 103 | parser.add_argument( 104 | "-n", 105 | "--n_candidate_gen_per_text", 106 | type=int, 107 | required=False, 108 | default=3, 109 | help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). 
A Larger value usually lead to better quality with heavier computation", 110 | ) 111 | 112 | parser.add_argument( 113 | "--seed", 114 | type=int, 115 | required=False, 116 | default=0, 117 | help="Change this value (any integer number) will lead to a different generation result.", 118 | ) 119 | 120 | args = parser.parse_args() 121 | 122 | torch.set_float32_matmul_precision("high") 123 | 124 | save_path = os.path.join(args.save_path, get_time()) 125 | 126 | text = args.text 127 | random_seed = args.seed 128 | duration = args.duration 129 | sample_rate = 16000 130 | 131 | if ("audioldm2" in args.model_name): 132 | print( 133 | "Warning: For AudioLDM2 we currently only support 10s of generation. Please use audioldm_48k or audioldm_16k_crossattn_t5 if you want a different duration.") 134 | duration = 10 135 | if ("48k" in args.model_name): 136 | sample_rate = 48000 137 | 138 | guidance_scale = args.guidance_scale 139 | n_candidate_gen_per_text = args.n_candidate_gen_per_text 140 | transcription = args.transcription 141 | 142 | if (transcription): 143 | if "speech" not in args.model_name: 144 | print( 145 | "Warning: You choose to perform Text-to-Speech by providing the transcription.However you do not choose the correct model name (audioldm2-speech-gigaspeech or audioldm2-speech-ljspeech).") 146 | print("Warning: We will use audioldm2-speech-gigaspeech by default") 147 | args.model_name = "audioldm2-speech-gigaspeech" 148 | if (not text): 149 | print( 150 | "Warning: You should provide text as a input to describe the speaker. Use default (A male reporter is speaking)") 151 | text = "A female reporter is speaking full of emotion" 152 | 153 | os.makedirs(save_path, exist_ok=True) 154 | audioldm2 = build_model(model_name=args.model_name, device=args.device) 155 | 156 | if (args.text_list): 157 | print("Generate audio based on the text prompts in %s" % args.text_list) 158 | prompt_todo = read_list(args.text_list) 159 | else: 160 | prompt_todo = [text] 161 | 162 | for text in prompt_todo: 163 | if ("|" in text): 164 | text, name = text.split("|") 165 | else: 166 | name = text[:128] 167 | 168 | if (transcription): 169 | name += "-TTS-%s" % transcription 170 | 171 | waveform = text_to_audio( 172 | audioldm2, 173 | text, 174 | transcription=transcription, # To avoid the model to ignore the last vocab 175 | seed=random_seed, 176 | duration=duration, 177 | guidance_scale=guidance_scale, 178 | ddim_steps=args.ddim_steps, 179 | n_candidate_gen_per_text=n_candidate_gen_per_text, 180 | batchsize=args.batchsize, 181 | ) 182 | 183 | save_wave(waveform, save_path, name=name, samplerate=sample_rate) 184 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def init_weights(m, mean=0.0, std=0.01): 7 | classname = m.__class__.__name__ 8 | if classname.find("Conv") != -1: 9 | m.weight.data.normal_(mean, std) 10 | 11 | 12 | def get_padding(kernel_size, dilation=1): 13 | return int((kernel_size * dilation - dilation) / 2) 14 | 15 | 16 | def convert_pad_shape(pad_shape): 17 | l = pad_shape[::-1] 18 | pad_shape = [item for sublist in l for item in sublist] 19 | return pad_shape 20 | 21 | 22 | def intersperse(lst, item): 23 | result = [item] * (len(lst) * 2 + 1) 24 | result[1::2] = lst 25 | return result 26 | 27 | 28 | def kl_divergence(m_p, logs_p, m_q, logs_q): 
29 | """KL(P||Q)""" 30 | kl = (logs_q - logs_p) - 0.5 31 | kl += ( 32 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 33 | ) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | 57 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 58 | b, d, t = x.size() 59 | if x_lengths is None: 60 | x_lengths = t 61 | ids_str_max = x_lengths - segment_size + 1 62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 63 | ret = slice_segments(x, ids_str, segment_size) 64 | return ret, ids_str 65 | 66 | 67 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 68 | position = torch.arange(length, dtype=torch.float) 69 | num_timescales = channels // 2 70 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 71 | num_timescales - 1 72 | ) 73 | inv_timescales = min_timescale * torch.exp( 74 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 75 | ) 76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 78 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 79 | signal = signal.view(1, channels, length) 80 | return signal 81 | 82 | 83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 84 | b, channels, length = x.size() 85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 86 | return x + signal.to(dtype=x.dtype, device=x.device) 87 | 88 | 89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 90 | b, channels, length = x.size() 91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 93 | 94 | 95 | def subsequent_mask(length): 96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 97 | return mask 98 | 99 | 100 | @torch.jit.script 101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 102 | n_channels_int = n_channels[0] 103 | in_act = input_a + input_b 104 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 106 | acts = t_act * s_act 107 | return acts 108 | 109 | 110 | def convert_pad_shape(pad_shape): 111 | l = pad_shape[::-1] 112 | pad_shape = [item for sublist in l for item in sublist] 113 | return pad_shape 114 | 115 | 116 | def shift_1d(x): 117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 118 | return x 119 | 120 | 121 | def sequence_mask(length, max_length=None): 122 | if max_length is None: 123 | max_length = length.max() 124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 125 | return x.unsqueeze(0) < length.unsqueeze(1) 126 | 127 | 128 | def generate_path(duration, mask): 129 | """ 130 | duration: [b, 1, t_x] 131 | mask: [b, 1, t_y, t_x] 132 | """ 133 | 
duration.device 134 | 135 | b, _, t_y, t_x = mask.shape 136 | cum_duration = torch.cumsum(duration, -1) 137 | 138 | cum_duration_flat = cum_duration.view(b * t_x) 139 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 140 | path = path.view(b, t_x, t_y) 141 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 142 | path = path.unsqueeze(1).transpose(2, 3) * mask 143 | return path 144 | 145 | 146 | def clip_grad_value_(parameters, clip_value, norm_type=2): 147 | if isinstance(parameters, torch.Tensor): 148 | parameters = [parameters] 149 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 150 | norm_type = float(norm_type) 151 | if clip_value is not None: 152 | clip_value = float(clip_value) 153 | 154 | total_norm = 0 155 | for p in parameters: 156 | param_norm = p.grad.data.norm(norm_type) 157 | total_norm += param_norm.item() ** norm_type 158 | if clip_value is not None: 159 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 160 | total_norm = total_norm ** (1.0 / norm_type) 161 | return total_norm 162 | -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import ( 14 | get_pretrained_url, 15 | list_pretrained_tag_models, 16 | download_pretrained, 17 | ) 18 | 19 | __all__ = ["list_openai_models", "load_openai_model"] 20 | 21 | 22 | def list_openai_models() -> List[str]: 23 | """Returns the names of available CLIP models""" 24 | return list_pretrained_tag_models("openai") 25 | 26 | 27 | def load_openai_model( 28 | name: str, 29 | model_cfg, 30 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 31 | jit=True, 32 | cache_dir=os.path.expanduser("~/.cache/clip"), 33 | enable_fusion: bool = False, 34 | fusion_type: str = "None", 35 | ): 36 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model 37 | 38 | Parameters 39 | ---------- 40 | name : str 41 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 42 | device : Union[str, torch.device] 43 | The device to put the loaded model 44 | jit : bool 45 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 46 | 47 | Returns 48 | ------- 49 | model : torch.nn.Module 50 | The CLAP model 51 | preprocess : Callable[[PIL.Image], torch.Tensor] 52 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 53 | """ 54 | if get_pretrained_url(name, "openai"): 55 | model_path = download_pretrained( 56 | get_pretrained_url(name, "openai"), root=cache_dir 57 | ) 58 | elif os.path.isfile(name): 59 | model_path = name 60 | else: 61 | raise RuntimeError( 62 | f"Model {name} not found; available models = {list_openai_models()}" 63 | ) 64 | 65 | try: 66 | # loading JIT archive 67 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 68 | state_dict = None 69 | except RuntimeError: 70 | # loading saved state dict 71 | if jit: 72 | warnings.warn( 73 | f"File {model_path} is not a JIT archive. 
Loading as a state dict instead" 74 | ) 75 | jit = False 76 | state_dict = torch.load(model_path, map_location="cpu") 77 | 78 | if not jit: 79 | try: 80 | model = build_model_from_openai_state_dict( 81 | state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type 82 | ).to(device) 83 | except KeyError: 84 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 85 | model = build_model_from_openai_state_dict( 86 | sd, model_cfg, enable_fusion, fusion_type 87 | ).to(device) 88 | 89 | if str(device) == "cpu": 90 | model.float() 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace( 95 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 96 | ) 97 | device_node = [ 98 | n 99 | for n in device_holder.graph.findAllNodes("prim::Constant") 100 | if "Device" in repr(n) 101 | ][-1] 102 | 103 | def patch_device(module): 104 | try: 105 | graphs = [module.graph] if hasattr(module, "graph") else [] 106 | except RuntimeError: 107 | graphs = [] 108 | 109 | if hasattr(module, "forward1"): 110 | graphs.append(module.forward1.graph) 111 | 112 | for graph in graphs: 113 | for node in graph.findAllNodes("prim::Constant"): 114 | if "value" in node.attributeNames() and str(node["value"]).startswith( 115 | "cuda" 116 | ): 117 | node.copyAttributes(device_node) 118 | 119 | model.apply(patch_device) 120 | patch_device(model.encode_audio) 121 | patch_device(model.encode_text) 122 | 123 | # patch dtype to float32 on CPU 124 | if str(device) == "cpu": 125 | float_holder = torch.jit.trace( 126 | lambda: torch.ones([]).float(), example_inputs=[] 127 | ) 128 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 129 | float_node = float_input.node() 130 | 131 | def patch_float(module): 132 | try: 133 | graphs = [module.graph] if hasattr(module, "graph") else [] 134 | except RuntimeError: 135 | graphs = [] 136 | 137 | if hasattr(module, "forward1"): 138 | graphs.append(module.forward1.graph) 139 | 140 | for graph in graphs: 141 | for node in graph.findAllNodes("aten::to"): 142 | inputs = list(node.inputs()) 143 | for i in [ 144 | 1, 145 | 2, 146 | ]: # dtype can be the second or third argument to aten::to() 147 | if inputs[i].node()["value"] == 5: 148 | inputs[i].node().copyAttributes(float_node) 149 | 150 | model.apply(patch_float) 151 | patch_float(model.encode_audio) 152 | patch_float(model.encode_text) 153 | model.float() 154 | 155 | model.audio_branch.audio_length = model.audio_cfg.audio_length 156 | return model 157 | -------------------------------------------------------------------------------- /audioldm2/utilities/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import audioldm2.hifigan as hifigan 4 | 5 | 6 | def get_vocoder_config(): 7 | return { 8 | "resblock": "1", 9 | "num_gpus": 6, 10 | "batch_size": 16, 11 | "learning_rate": 0.0002, 12 | "adam_b1": 0.8, 13 | "adam_b2": 0.99, 14 | "lr_decay": 0.999, 15 | "seed": 1234, 16 | "upsample_rates": [5, 4, 2, 2, 2], 17 | "upsample_kernel_sizes": [16, 16, 8, 4, 4], 18 | "upsample_initial_channel": 1024, 19 | "resblock_kernel_sizes": [3, 7, 11], 20 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 21 | "segment_size": 8192, 22 | "num_mels": 64, 23 | "num_freq": 1025, 24 | "n_fft": 1024, 25 | "hop_size": 160, 26 | "win_size": 1024, 27 | "sampling_rate": 16000, 28 | "fmin": 0, 29 | "fmax": 8000, 30 | "fmax_for_loss": None, 31 | "num_workers": 4, 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | 
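# Note (derived from the values above): the product of "upsample_rates",
# 5 * 4 * 2 * 2 * 2 = 160, equals "hop_size", so the generator upsamples each
# 64-bin mel frame back to 160 waveform samples, i.e. 10 ms at 16 kHz.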
"dist_url": "tcp://localhost:54321", 35 | "world_size": 1, 36 | }, 37 | } 38 | 39 | def get_vocoder_config_48k(): 40 | return { 41 | "resblock": "1", 42 | "num_gpus": 8, 43 | "batch_size": 128, 44 | "learning_rate": 0.0001, 45 | "adam_b1": 0.8, 46 | "adam_b2": 0.99, 47 | "lr_decay": 0.999, 48 | "seed": 1234, 49 | 50 | "upsample_rates": [6,5,4,2,2], 51 | "upsample_kernel_sizes": [12,10,8,4,4], 52 | "upsample_initial_channel": 1536, 53 | "resblock_kernel_sizes": [3,7,11,15], 54 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5], [1,3,5]], 55 | 56 | "segment_size": 15360, 57 | "num_mels": 256, 58 | "n_fft": 2048, 59 | "hop_size": 480, 60 | "win_size": 2048, 61 | 62 | "sampling_rate": 48000, 63 | 64 | "fmin": 20, 65 | "fmax": 24000, 66 | "fmax_for_loss": None, 67 | 68 | "num_workers": 8, 69 | 70 | "dist_config": { 71 | "dist_backend": "nccl", 72 | "dist_url": "tcp://localhost:18273", 73 | "world_size": 1 74 | } 75 | } 76 | 77 | 78 | def get_available_checkpoint_keys(model, ckpt): 79 | state_dict = torch.load(ckpt)["state_dict"] 80 | current_state_dict = model.state_dict() 81 | new_state_dict = {} 82 | for k in state_dict.keys(): 83 | if ( 84 | k in current_state_dict.keys() 85 | and current_state_dict[k].size() == state_dict[k].size() 86 | ): 87 | new_state_dict[k] = state_dict[k] 88 | else: 89 | print("==> WARNING: Skipping %s" % k) 90 | print( 91 | "%s out of %s keys are matched" 92 | % (len(new_state_dict.keys()), len(state_dict.keys())) 93 | ) 94 | return new_state_dict 95 | 96 | 97 | def get_param_num(model): 98 | num_param = sum(param.numel() for param in model.parameters()) 99 | return num_param 100 | 101 | 102 | def torch_version_orig_mod_remove(state_dict): 103 | new_state_dict = {} 104 | new_state_dict["generator"] = {} 105 | for key in state_dict["generator"].keys(): 106 | if "_orig_mod." 
in key: 107 | new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[ 108 | "generator" 109 | ][key] 110 | else: 111 | new_state_dict["generator"][key] = state_dict["generator"][key] 112 | return new_state_dict 113 | 114 | 115 | def get_vocoder(config, device, mel_bins): 116 | name = "HiFi-GAN" 117 | speaker = "" 118 | if name == "MelGAN": 119 | if speaker == "LJSpeech": 120 | vocoder = torch.hub.load( 121 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" 122 | ) 123 | elif speaker == "universal": 124 | vocoder = torch.hub.load( 125 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" 126 | ) 127 | vocoder.mel2wav.eval() 128 | vocoder.mel2wav.to(device) 129 | elif name == "HiFi-GAN": 130 | if(mel_bins == 64): 131 | config = get_vocoder_config() 132 | config = hifigan.AttrDict(config) 133 | vocoder = hifigan.Generator_old(config) 134 | # print("Load hifigan/g_01080000") 135 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) 136 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) 137 | # ckpt = torch_version_orig_mod_remove(ckpt) 138 | # vocoder.load_state_dict(ckpt["generator"]) 139 | vocoder.eval() 140 | vocoder.remove_weight_norm() 141 | vocoder.to(device) 142 | else: 143 | config = get_vocoder_config_48k() 144 | config = hifigan.AttrDict(config) 145 | vocoder = hifigan.Generator_old(config) 146 | # print("Load hifigan/g_01080000") 147 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) 148 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) 149 | # ckpt = torch_version_orig_mod_remove(ckpt) 150 | # vocoder.load_state_dict(ckpt["generator"]) 151 | vocoder.eval() 152 | vocoder.remove_weight_norm() 153 | vocoder.to(device) 154 | return vocoder 155 | 156 | 157 | def vocoder_infer(mels, vocoder, lengths=None): 158 | with torch.no_grad(): 159 | wavs = vocoder(mels).squeeze(1) 160 | 161 | wavs = (wavs.cpu().numpy() * 32768).astype("int16") 162 | 163 | if lengths is not None: 164 | wavs = wavs[:, :lengths] 165 | 166 | # wavs = [wav for wav in wavs] 167 | 168 | # for i in range(len(mels)): 169 | # if lengths is not None: 170 | # wavs[i] = wavs[i][: lengths[i]] 171 | 172 | return wavs 173 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference Repo: https://github.com/facebookresearch/AudioMAE 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from timm.models.layers import to_2tuple 8 | import audioldm2.latent_diffusion.modules.audiomae.models_vit as models_vit 9 | import audioldm2.latent_diffusion.modules.audiomae.models_mae as models_mae 10 | 11 | # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128)) 12 | 13 | 14 | class PatchEmbed_new(nn.Module): 15 | """Flexible Image to Patch Embedding""" 16 | 17 | def __init__( 18 | self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 19 | ): 20 | super().__init__() 21 | img_size = to_2tuple(img_size) 22 | patch_size = to_2tuple(patch_size) 23 | stride = to_2tuple(stride) 24 | 25 | self.img_size = img_size 26 | self.patch_size = patch_size 27 | 28 | self.proj = nn.Conv2d( 29 | in_chans, embed_dim, kernel_size=patch_size, stride=stride 30 | ) # with overlapped patches 31 | # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 32 | 33 | # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // 
patch_size[0]) 34 | # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 35 | _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w 36 | self.patch_hw = (h, w) 37 | self.num_patches = h * w 38 | 39 | def get_output_shape(self, img_size): 40 | # todo: don't be lazy.. 41 | return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape 42 | 43 | def forward(self, x): 44 | B, C, H, W = x.shape 45 | # FIXME look at relaxing size constraints 46 | # assert H == self.img_size[0] and W == self.img_size[1], \ 47 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 48 | x = self.proj(x) 49 | x = x.flatten(2).transpose(1, 2) 50 | return x 51 | 52 | 53 | class AudioMAE(nn.Module): 54 | """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)""" 55 | 56 | def __init__( 57 | self, 58 | ): 59 | super().__init__() 60 | model = models_vit.__dict__["vit_base_patch16"]( 61 | num_classes=527, 62 | drop_path_rate=0.1, 63 | global_pool=True, 64 | mask_2d=True, 65 | use_custom_patch=False, 66 | ) 67 | 68 | img_size = (1024, 128) 69 | emb_dim = 768 70 | 71 | model.patch_embed = PatchEmbed_new( 72 | img_size=img_size, 73 | patch_size=(16, 16), 74 | in_chans=1, 75 | embed_dim=emb_dim, 76 | stride=16, 77 | ) 78 | num_patches = model.patch_embed.num_patches 79 | # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8 80 | model.pos_embed = nn.Parameter( 81 | torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False 82 | ) # fixed sin-cos embedding 83 | 84 | # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth' 85 | # checkpoint = torch.load(checkpoint_path, map_location='cpu') 86 | # msg = model.load_state_dict(checkpoint['model'], strict=False) 87 | # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') 88 | 89 | self.model = model 90 | 91 | def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0): 92 | """ 93 | x: mel fbank [Batch, 1, T, F] 94 | mask_t_prob: 'T masking ratio (percentage of removed patches).' 95 | mask_f_prob: 'F masking ratio (percentage of removed patches).' 96 | """ 97 | return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob) 98 | 99 | 100 | class Vanilla_AudioMAE(nn.Module): 101 | """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)""" 102 | 103 | def __init__( 104 | self, 105 | ): 106 | super().__init__() 107 | model = models_mae.__dict__["mae_vit_base_patch16"]( 108 | in_chans=1, audio_exp=True, img_size=(1024, 128) 109 | ) 110 | 111 | # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth' 112 | # checkpoint = torch.load(checkpoint_path, map_location='cpu') 113 | # msg = model.load_state_dict(checkpoint['model'], strict=False) 114 | 115 | # Skip the missing keys of decoder modules (not required) 116 | # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') 117 | 118 | self.model = model.eval() 119 | 120 | def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False): 121 | """ 122 | x: mel fbank [Batch, 1, 1024 (T), 128 (F)] 123 | mask_ratio: 'masking ratio (percentage of removed patches).' 
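        Output shape: with a 1024 x 128 input and 16 x 16 patches the encoder
        yields (1024 / 16) * (128 / 16) + 1 = 513 tokens of width 768, which
        matches the [B, 513, 768] shape note below.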
124 | """ 125 | with torch.no_grad(): 126 | # embed: [B, 513, 768] for mask_ratio=0.0 127 | if no_mask: 128 | if no_average: 129 | raise RuntimeError("This function is deprecated") 130 | embed = self.model.forward_encoder_no_random_mask_no_average( 131 | x 132 | ) # mask_ratio 133 | else: 134 | embed = self.model.forward_encoder_no_mask(x) # mask_ratio 135 | else: 136 | raise RuntimeError("This function is deprecated") 137 | embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio) 138 | return embed 139 | 140 | 141 | if __name__ == "__main__": 142 | model = Vanilla_AudioMAE().cuda() 143 | input = torch.randn(4, 1, 1024, 128).cuda() 144 | print("The first run") 145 | embed = model(input, mask_ratio=0.0, no_mask=True) 146 | print(embed) 147 | print("The second run") 148 | embed = model(input, mask_ratio=0.0) 149 | print(embed) 150 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def test_reconstuct(): 5 | import yaml 6 | from diffusers import AutoencoderKL 7 | from transformers import SpeechT5HifiGan 8 | from audioldm2.utilities.data.dataset import AudioDataset 9 | from utils import load_clip, load_clap, load_t5 10 | 11 | model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2' 12 | config = yaml.load( 13 | open( 14 | 'config/16k_64.yaml', 15 | 'r' 16 | ), 17 | Loader=yaml.FullLoader, 18 | ) 19 | print(config) 20 | t5 = load_t5('cuda', max_length=256) 21 | clap = load_clap('cuda', max_length=256) 22 | 23 | dataset = AudioDataset( 24 | config=config, split="train", waveform_only=False, dataset_json_path='mini_dataset.json', 25 | tokenizer=clap.tokenizer, 26 | uncond_pro=0.1, 27 | text_ctx_len=77, 28 | tokenizer_t5=t5.tokenizer, 29 | text_ctx_len_t5=256, 30 | uncond_pro_t5=0.1, 31 | ) 32 | print(dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0).size()) 33 | 34 | vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae')) 35 | vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder')) 36 | latents = vae.encode(dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0)).latent_dist.sample().mul_(vae.config.scaling_factor) 37 | print('laten size:', latents.size()) 38 | 39 | latents = 1 / vae.config.scaling_factor * latents 40 | mel_spectrogram = vae.decode(latents).sample 41 | print(mel_spectrogram.size()) 42 | if mel_spectrogram.dim() == 4: 43 | mel_spectrogram = mel_spectrogram.squeeze(1) 44 | waveform = vocoder(mel_spectrogram) 45 | waveform = waveform[0].cpu().float().detach().numpy() 46 | print(waveform.shape) 47 | # import soundfile as sf 48 | # sf.write('reconstruct.wav', waveform, samplerate=16000) 49 | from scipy.io import wavfile 50 | # wavfile.write('reconstruct.wav', 16000, waveform) 51 | 52 | 53 | 54 | def mini_dataset(num=32): 55 | data = [] 56 | for i in range(num): 57 | data.append( 58 | { 59 | 'wav': 'case.mp3', 60 | 'label': 'a beautiful music', 61 | } 62 | ) 63 | 64 | with open('mini_dataset.json', 'w') as f: 65 | json.dump(data, f, indent=4) 66 | 67 | 68 | def fma_dataset(): 69 | import pandas as pd 70 | 71 | annotation_prex = "/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/annotation" 72 | annotation_list = ['test-00000-of-00001.parquet', 'train-00000-of-00001.parquet', 'valid-00000-of-00001.parquet'] 73 | dataset_prex = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/fma_large' 74 | 75 | data = [] 76 | for 
annotation_file in annotation_list: 77 | annotation_file = os.path.join(annotation_prex, annotation_file) 78 | df=pd.read_parquet(annotation_file) 79 | print(df.shape) 80 | for id, row in df.iterrows(): 81 | #print(id, row['pseudo_caption'], row['path']) 82 | tmp_path = os.path.join(dataset_prex, row['path'] + '.mp3') 83 | # print(tmp_path) 84 | if os.path.exists(tmp_path): 85 | data.append( 86 | { 87 | 'wav': tmp_path, 88 | 'label': row['pseudo_caption'], 89 | } 90 | ) 91 | # break 92 | print(len(data)) 93 | with open('fma_dataset.json', 'w') as f: 94 | json.dump(data, f, indent=4) 95 | 96 | 97 | 98 | 99 | 100 | def audioset_dataset(): 101 | import pandas as pd 102 | dataset_prex = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset' 103 | annotation_path = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset/balanced_train-00000-of-00001.parquet' 104 | df=pd.read_parquet(annotation_path) 105 | print(df.shape) 106 | 107 | data = [] 108 | for id, row in df.iterrows(): 109 | #print(id, row['pseudo_caption'], row['path']) 110 | try: 111 | tmp_path = os.path.join(dataset_prex, row['path'] + '.flac') 112 | except: 113 | print(row['path']) 114 | 115 | if os.path.exists(tmp_path): 116 | # print(tmp_path) 117 | data.append( 118 | { 119 | 'wav': tmp_path, 120 | 'label': row['pseudo_caption'], 121 | } 122 | ) 123 | print(len(data)) 124 | with open('audioset_dataset.json', 'w') as f: 125 | json.dump(data, f, indent=4) 126 | 127 | 128 | 129 | def combine_dataset(): 130 | data_list = ['fma_dataset.json', 'audioset_dataset.json'] 131 | 132 | data = [] 133 | for data_file in data_list: 134 | with open(data_file, 'r') as f: 135 | data += json.load(f) 136 | print(len(data)) 137 | with open('combine_dataset.json', 'w') as f: 138 | json.dump(data, f, indent=4) 139 | 140 | 141 | 142 | def test_music_format(): 143 | import torchaudio 144 | filename = '2.flac' 145 | waveform, sr = torchaudio.load(filename,) 146 | print(waveform, sr ) 147 | 148 | 149 | def test_flops(): 150 | version = 'giant' 151 | import torch 152 | from constants import build_model 153 | from thop import profile 154 | 155 | model = build_model(version).cuda() 156 | img_ids = torch.randn((1, 1024, 3)).cuda() 157 | txt = torch.randn((1, 256, 4096)).cuda() 158 | txt_ids = torch.randn((1, 256, 3)).cuda() 159 | y = torch.randn((1, 768)).cuda() 160 | x = torch.randn((1, 1024, 32)).cuda() 161 | t = torch.tensor([1] * 1).cuda() 162 | flops, _ = profile(model, inputs=(x, img_ids, txt, txt_ids, t, y,)) 163 | print('FLOPs = ' + str(flops * 2/1000**3) + 'G') 164 | 165 | 166 | # test_music_format() 167 | # test_reconstuct() 168 | # mini_dataset() 169 | # fma_dataset() 170 | # audioset_dataset() 171 | # combine_dataset() 172 | test_flops() -------------------------------------------------------------------------------- /audioldm2/hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, 
kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2**i), 129 | h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | # print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | 
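# Weight norm is a training-time reparameterization; folding it back into the
# plain Conv1d weights here (for the upsampling layers, every ResBlock, and
# the pre/post convolutions) is the usual HiFi-GAN step before pure inference.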
l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) 175 | -------------------------------------------------------------------------------- /audioldm2/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import yaml 5 | import torch 6 | import torchaudio 7 | 8 | import audioldm2.latent_diffusion.modules.phoneme_encoder.text as text 9 | from audioldm2.latent_diffusion.models.ddpm import LatentDiffusion 10 | from audioldm2.latent_diffusion.util import get_vits_phoneme_ids_no_padding 11 | from audioldm2.utils import default_audioldm_config, download_checkpoint 12 | import os 13 | 14 | # CACHE_DIR = os.getenv( 15 | # "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2") 16 | # ) 17 | 18 | def seed_everything(seed): 19 | import random, os 20 | import numpy as np 21 | import torch 22 | 23 | random.seed(seed) 24 | os.environ["PYTHONHASHSEED"] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = True 30 | 31 | def text2phoneme(data): 32 | return text._clean_text(re.sub(r'<.*?>', '', data), ["english_cleaners2"]) 33 | 34 | def text_to_filename(text): 35 | return text.replace(" ", "_").replace("'", "_").replace('"', "_") 36 | 37 | def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec): 38 | norm_mean = -4.2677393 39 | norm_std = 4.5689974 40 | 41 | if sampling_rate != 16000: 42 | waveform_16k = torchaudio.functional.resample( 43 | waveform, orig_freq=sampling_rate, new_freq=16000 44 | ) 45 | else: 46 | waveform_16k = waveform 47 | 48 | waveform_16k = waveform_16k - waveform_16k.mean() 49 | fbank = torchaudio.compliance.kaldi.fbank( 50 | waveform_16k, 51 | htk_compat=True, 52 | sample_frequency=16000, 53 | use_energy=False, 54 | window_type="hanning", 55 | num_mel_bins=128, 56 | dither=0.0, 57 | frame_shift=10, 58 | ) 59 | 60 | TARGET_LEN = log_mel_spec.size(0) 61 | 62 | # cut and pad 63 | n_frames = fbank.shape[0] 64 | p = TARGET_LEN - n_frames 65 | if p > 0: 66 | m = torch.nn.ZeroPad2d((0, 0, 0, p)) 67 | fbank = m(fbank) 68 | elif p < 0: 69 | fbank = fbank[:TARGET_LEN, :] 70 | 71 | fbank = (fbank - norm_mean) / (norm_std * 2) 72 | 73 | return {"ta_kaldi_fbank": fbank} # [1024, 128] 74 | 75 | def make_batch_for_text_to_audio(text, transcription="", waveform=None, fbank=None, batchsize=1): 76 | text = [text] * batchsize 77 | if(transcription): 78 | transcription = text2phoneme(transcription) 79 | transcription = [transcription] * batchsize 80 | 81 | if batchsize < 1: 82 | print("Warning: Batchsize must be at least 1. 
Batchsize is set to .") 83 | 84 | if fbank is None: 85 | fbank = torch.zeros( 86 | (batchsize, 1024, 64) 87 | ) # Not used, here to keep the code format 88 | else: 89 | fbank = torch.FloatTensor(fbank) 90 | fbank = fbank.expand(batchsize, 1024, 64) 91 | assert fbank.size(0) == batchsize 92 | 93 | stft = torch.zeros((batchsize, 1024, 512)) # Not used 94 | phonemes = get_vits_phoneme_ids_no_padding(transcription) 95 | 96 | if waveform is None: 97 | waveform = torch.zeros((batchsize, 160000)) # Not used 98 | ta_kaldi_fbank = torch.zeros((batchsize, 1024, 128)) 99 | else: 100 | waveform = torch.FloatTensor(waveform) 101 | waveform = waveform.expand(batchsize, -1) 102 | assert waveform.size(0) == batchsize 103 | ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, 16000, fbank) 104 | 105 | batch = { 106 | "text": text, # list 107 | "fname": [text_to_filename(t) for t in text], # list 108 | "waveform": waveform, 109 | "stft": stft, 110 | "log_mel_spec": fbank, 111 | "ta_kaldi_fbank": ta_kaldi_fbank, 112 | } 113 | batch.update(phonemes) 114 | return batch 115 | 116 | 117 | def round_up_duration(duration): 118 | return int(round(duration / 2.5) + 1) * 2.5 119 | 120 | 121 | # def split_clap_weight_to_pth(checkpoint): 122 | # if os.path.exists(os.path.join(CACHE_DIR, "clap.pth")): 123 | # return 124 | # print("Constructing the weight for the CLAP model.") 125 | # include_keys = "cond_stage_models.0.cond_stage_models.0.model." 126 | # new_state_dict = {} 127 | # for each in checkpoint["state_dict"].keys(): 128 | # if include_keys in each: 129 | # new_state_dict[each.replace(include_keys, "module.")] = checkpoint[ 130 | # "state_dict" 131 | # ][each] 132 | # torch.save({"state_dict": new_state_dict}, os.path.join(CACHE_DIR, "clap.pth")) 133 | 134 | 135 | def build_model(ckpt_path=None, config=None, device=None, model_name="audioldm2-full"): 136 | 137 | if device is None or device == "auto": 138 | if torch.cuda.is_available(): 139 | device = torch.device("cuda:0") 140 | elif torch.backends.mps.is_available(): 141 | device = torch.device("mps") 142 | else: 143 | device = torch.device("cpu") 144 | 145 | print("Loading AudioLDM-2: %s" % model_name) 146 | print("Loading model on %s" % device) 147 | 148 | ckpt_path = download_checkpoint(model_name) 149 | 150 | if config is not None: 151 | assert type(config) is str 152 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) 153 | else: 154 | config = default_audioldm_config(model_name) 155 | 156 | # # Use text as condition instead of using waveform during training 157 | config["model"]["params"]["device"] = device 158 | # config["model"]["params"]["cond_stage_key"] = "text" 159 | 160 | # No normalization here 161 | latent_diffusion = LatentDiffusion(**config["model"]["params"]) 162 | 163 | resume_from_checkpoint = ckpt_path 164 | 165 | checkpoint = torch.load(resume_from_checkpoint, map_location=device) 166 | 167 | latent_diffusion.load_state_dict(checkpoint["state_dict"]) 168 | 169 | latent_diffusion.eval() 170 | latent_diffusion = latent_diffusion.to(device) 171 | 172 | return latent_diffusion 173 | 174 | def text_to_audio( 175 | latent_diffusion, 176 | text, 177 | transcription="", 178 | seed=42, 179 | ddim_steps=200, 180 | duration=10, 181 | batchsize=1, 182 | guidance_scale=3.5, 183 | n_candidate_gen_per_text=3, 184 | latent_t_per_second=25.6, 185 | config=None, 186 | ): 187 | 188 | seed_everything(int(seed)) 189 | waveform = None 190 | 191 | batch = make_batch_for_text_to_audio(text, transcription=transcription, waveform=waveform, 
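# Note: right after this call the requested duration is mapped to the latent
# length as int(duration * latent_t_per_second); with the default value of
# 25.6, a 10 s prompt therefore corresponds to 256 latent time steps.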
batchsize=batchsize) 192 | 193 | latent_diffusion.latent_t_size = int(duration * latent_t_per_second) 194 | 195 | with torch.no_grad(): 196 | waveform = latent_diffusion.generate_batch( 197 | batch, 198 | unconditional_guidance_scale=guidance_scale, 199 | ddim_steps=ddim_steps, 200 | n_gen=n_candidate_gen_per_text, 201 | duration=duration, 202 | ) 203 | 204 | return waveform 205 | -------------------------------------------------------------------------------- /audioldm2/utilities/audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from audioldm2.utilities.audio.audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 | inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data, 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = 
F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes, normalize_fun): 152 | output = dynamic_range_compression(magnitudes, normalize_fun) 153 | return output 154 | 155 | def spectral_de_normalize(self, magnitudes): 156 | output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y, normalize_fun=torch.log): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1, torch.min(y.data) 170 | assert torch.max(y.data) <= 1, torch.max(y.data) 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output, normalize_fun) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | return mel_output, magnitudes, phases, energy 179 | -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | _RN50 = dict( 9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 10 | 
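# For the "openaipublic" URLs the second-to-last path component is the SHA256
# of the checkpoint file; download_pretrained() below relies on this
# convention (url.split("/")[-2]) to verify the downloaded weights.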
yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", 12 | ) 13 | 14 | _RN50_quickgelu = dict( 15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", 18 | ) 19 | 20 | _RN101 = dict( 21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", 23 | ) 24 | 25 | _RN101_quickgelu = dict( 26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", 28 | ) 29 | 30 | _RN50x4 = dict( 31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | ) 33 | 34 | _RN50x16 = dict( 35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | ) 37 | 38 | _RN50x64 = dict( 39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 40 | ) 41 | 42 | _VITB32 = dict( 43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 44 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 45 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 46 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 47 | ) 48 | 49 | _VITB32_quickgelu = dict( 50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 53 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 54 | ) 55 | 56 | _VITB16 = dict( 57 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 58 | ) 59 | 60 | _VITL14 = dict( 61 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 62 | ) 63 | 64 | _PRETRAINED = { 65 | "RN50": _RN50, 66 | "RN50-quickgelu": _RN50_quickgelu, 67 | "RN101": _RN101, 68 | "RN101-quickgelu": _RN101_quickgelu, 69 | "RN50x4": _RN50x4, 70 | "RN50x16": _RN50x16, 71 | "ViT-B-32": _VITB32, 72 
| "ViT-B-32-quickgelu": _VITB32_quickgelu, 73 | "ViT-B-16": _VITB16, 74 | "ViT-L-14": _VITL14, 75 | } 76 | 77 | 78 | def list_pretrained(as_str: bool = False): 79 | """returns list of pretrained models 80 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 81 | """ 82 | return [ 83 | ":".join([k, t]) if as_str else (k, t) 84 | for k in _PRETRAINED.keys() 85 | for t in _PRETRAINED[k].keys() 86 | ] 87 | 88 | 89 | def list_pretrained_tag_models(tag: str): 90 | """return all models having the specified pretrain tag""" 91 | models = [] 92 | for k in _PRETRAINED.keys(): 93 | if tag in _PRETRAINED[k]: 94 | models.append(k) 95 | return models 96 | 97 | 98 | def list_pretrained_model_tags(model: str): 99 | """return all pretrain tags for the specified model architecture""" 100 | tags = [] 101 | if model in _PRETRAINED: 102 | tags.extend(_PRETRAINED[model].keys()) 103 | return tags 104 | 105 | 106 | def get_pretrained_url(model: str, tag: str): 107 | if model not in _PRETRAINED: 108 | return "" 109 | model_pretrained = _PRETRAINED[model] 110 | if tag not in model_pretrained: 111 | return "" 112 | return model_pretrained[tag] 113 | 114 | 115 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): 116 | os.makedirs(root, exist_ok=True) 117 | filename = os.path.basename(url) 118 | 119 | if "openaipublic" in url: 120 | expected_sha256 = url.split("/")[-2] 121 | else: 122 | expected_sha256 = "" 123 | 124 | download_target = os.path.join(root, filename) 125 | 126 | if os.path.exists(download_target) and not os.path.isfile(download_target): 127 | raise RuntimeError(f"{download_target} exists and is not a regular file") 128 | 129 | if os.path.isfile(download_target): 130 | if expected_sha256: 131 | if ( 132 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 133 | == expected_sha256 134 | ): 135 | return download_target 136 | else: 137 | warnings.warn( 138 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 139 | ) 140 | else: 141 | return download_target 142 | 143 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 144 | with tqdm( 145 | total=int(source.info().get("Content-Length")), 146 | ncols=80, 147 | unit="iB", 148 | unit_scale=True, 149 | ) as loop: 150 | while True: 151 | buffer = source.read(8192) 152 | if not buffer: 153 | break 154 | 155 | output.write(buffer) 156 | loop.update(len(buffer)) 157 | 158 | if ( 159 | expected_sha256 160 | and hashlib.sha256(open(download_target, "rb").read()).hexdigest() 161 | != expected_sha256 162 | ): 163 | raise RuntimeError( 164 | f"Model has been downloaded but the SHA256 checksum does not not match" 165 | ) 166 | 167 | return download_target 168 | -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join( 19 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 20 | ) 21 | 22 | 23 | @lru_cache() 24 | def bytes_to_unicode(): 25 | """ 26 | Returns list of utf-8 byte and a corresponding list of unicode strings. 27 | The reversible bpe codes work on unicode strings. 28 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 29 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 30 | This is a signficant percentage of your normal, say, 32K bpe vocab. 31 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 32 | And avoids mapping to whitespace/control characters the bpe code barfs on. 33 | """ 34 | bs = ( 35 | list(range(ord("!"), ord("~") + 1)) 36 | + list(range(ord("¡"), ord("¬") + 1)) 37 | + list(range(ord("®"), ord("ÿ") + 1)) 38 | ) 39 | cs = bs[:] 40 | n = 0 41 | for b in range(2**8): 42 | if b not in bs: 43 | bs.append(b) 44 | cs.append(2**8 + n) 45 | n += 1 46 | cs = [chr(n) for n in cs] 47 | return dict(zip(bs, cs)) 48 | 49 | 50 | def get_pairs(word): 51 | """Return set of symbol pairs in a word. 52 | Word is represented as tuple of symbols (symbols being variable-length strings). 53 | """ 54 | pairs = set() 55 | prev_char = word[0] 56 | for char in word[1:]: 57 | pairs.add((prev_char, char)) 58 | prev_char = char 59 | return pairs 60 | 61 | 62 | def basic_clean(text): 63 | text = ftfy.fix_text(text) 64 | text = html.unescape(html.unescape(text)) 65 | return text.strip() 66 | 67 | 68 | def whitespace_clean(text): 69 | text = re.sub(r"\s+", " ", text) 70 | text = text.strip() 71 | return text 72 | 73 | 74 | class SimpleTokenizer(object): 75 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 76 | self.byte_encoder = bytes_to_unicode() 77 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 78 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 79 | merges = merges[1 : 49152 - 256 - 2 + 1] 80 | merges = [tuple(merge.split()) for merge in merges] 81 | vocab = list(bytes_to_unicode().values()) 82 | vocab = vocab + [v + "" for v in vocab] 83 | for merge in merges: 84 | vocab.append("".join(merge)) 85 | if not special_tokens: 86 | special_tokens = ["", ""] 87 | else: 88 | special_tokens = ["", ""] + special_tokens 89 | vocab.extend(special_tokens) 90 | self.encoder = dict(zip(vocab, range(len(vocab)))) 91 | self.decoder = {v: k for k, v in self.encoder.items()} 92 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 93 | self.cache = {t: t for t in special_tokens} 94 | special = "|".join(special_tokens) 95 | self.pat = re.compile( 96 | special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 97 | re.IGNORECASE, 98 | ) 99 | 100 | self.vocab_size = len(self.encoder) 101 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 102 | 103 | def bpe(self, token): 104 | if token in self.cache: 105 | return self.cache[token] 106 | word = tuple(token[:-1]) + (token[-1] + "",) 107 | pairs = get_pairs(word) 108 | 109 | if not pairs: 110 | return token + "" 111 | 112 | while True: 113 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 114 | if bigram not in self.bpe_ranks: 115 | break 116 | first, 
second = bigram 117 | new_word = [] 118 | i = 0 119 | while i < len(word): 120 | try: 121 | j = word.index(first, i) 122 | new_word.extend(word[i:j]) 123 | i = j 124 | except: 125 | new_word.extend(word[i:]) 126 | break 127 | 128 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 129 | new_word.append(first + second) 130 | i += 2 131 | else: 132 | new_word.append(word[i]) 133 | i += 1 134 | new_word = tuple(new_word) 135 | word = new_word 136 | if len(word) == 1: 137 | break 138 | else: 139 | pairs = get_pairs(word) 140 | word = " ".join(word) 141 | self.cache[token] = word 142 | return word 143 | 144 | def encode(self, text): 145 | bpe_tokens = [] 146 | text = whitespace_clean(basic_clean(text)).lower() 147 | for token in re.findall(self.pat, text): 148 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 149 | bpe_tokens.extend( 150 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 151 | ) 152 | return bpe_tokens 153 | 154 | def decode(self, tokens): 155 | text = "".join([self.decoder[token] for token in tokens]) 156 | text = ( 157 | bytearray([self.byte_decoder[c] for c in text]) 158 | .decode("utf-8", errors="replace") 159 | .replace("", " ") 160 | ) 161 | return text 162 | 163 | 164 | _tokenizer = SimpleTokenizer() 165 | 166 | 167 | def tokenize( 168 | texts: Union[str, List[str]], context_length: int = 77 169 | ) -> torch.LongTensor: 170 | """ 171 | Returns the tokenized representation of given input string(s) 172 | 173 | Parameters 174 | ---------- 175 | texts : Union[str, List[str]] 176 | An input string or a list of input strings to tokenize 177 | context_length : int 178 | The context length to use; all CLIP models use 77 as the context length 179 | 180 | Returns 181 | ------- 182 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 183 | """ 184 | if isinstance(texts, str): 185 | texts = [texts] 186 | 187 | sot_token = _tokenizer.encoder[""] 188 | eot_token = _tokenizer.encoder[""] 189 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 190 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 191 | 192 | for i, tokens in enumerate(all_tokens): 193 | if len(tokens) > context_length: 194 | tokens = tokens[:context_length] # Truncate 195 | result[i, : len(tokens)] = torch.tensor(tokens) 196 | 197 | return result 198 | -------------------------------------------------------------------------------- /audioldm2/clap/open_clip/feature_fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Feature Fusion for Varible-Length Data Processing 3 | AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py 4 | According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class DAF(nn.Module): 12 | """ 13 | 直接相加 DirectAddFuse 14 | """ 15 | 16 | def __init__(self): 17 | super(DAF, self).__init__() 18 | 19 | def forward(self, x, residual): 20 | return x + residual 21 | 22 | 23 | class iAFF(nn.Module): 24 | """ 25 | 多特征融合 iAFF 26 | """ 27 | 28 | def __init__(self, channels=64, r=4, type="2D"): 29 | super(iAFF, self).__init__() 30 | inter_channels = int(channels // r) 31 | 32 | if type == "1D": 33 | # 本地注意力 34 | self.local_att = nn.Sequential( 35 | nn.Conv1d(channels, 
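# Both fusion branches below are 1x1 convolution bottlenecks
# (channels -> channels // r -> channels, with BatchNorm and ReLU); the
# "global" branch additionally average-pools over the remaining axes, and
# forward() mixes x and residual with sigmoid weights computed from the sum
# of the two branches, giving an SE-style attention over the fused features.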
inter_channels, kernel_size=1, stride=1, padding=0), 36 | nn.BatchNorm1d(inter_channels), 37 | nn.ReLU(inplace=True), 38 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 39 | nn.BatchNorm1d(channels), 40 | ) 41 | 42 | # 全局注意力 43 | self.global_att = nn.Sequential( 44 | nn.AdaptiveAvgPool1d(1), 45 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 46 | nn.BatchNorm1d(inter_channels), 47 | nn.ReLU(inplace=True), 48 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 49 | nn.BatchNorm1d(channels), 50 | ) 51 | 52 | # 第二次本地注意力 53 | self.local_att2 = nn.Sequential( 54 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 55 | nn.BatchNorm1d(inter_channels), 56 | nn.ReLU(inplace=True), 57 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 58 | nn.BatchNorm1d(channels), 59 | ) 60 | # 第二次全局注意力 61 | self.global_att2 = nn.Sequential( 62 | nn.AdaptiveAvgPool1d(1), 63 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 64 | nn.BatchNorm1d(inter_channels), 65 | nn.ReLU(inplace=True), 66 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 67 | nn.BatchNorm1d(channels), 68 | ) 69 | elif type == "2D": 70 | # 本地注意力 71 | self.local_att = nn.Sequential( 72 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 73 | nn.BatchNorm2d(inter_channels), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 76 | nn.BatchNorm2d(channels), 77 | ) 78 | 79 | # 全局注意力 80 | self.global_att = nn.Sequential( 81 | nn.AdaptiveAvgPool2d(1), 82 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 83 | nn.BatchNorm2d(inter_channels), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 86 | nn.BatchNorm2d(channels), 87 | ) 88 | 89 | # 第二次本地注意力 90 | self.local_att2 = nn.Sequential( 91 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 92 | nn.BatchNorm2d(inter_channels), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 95 | nn.BatchNorm2d(channels), 96 | ) 97 | # 第二次全局注意力 98 | self.global_att2 = nn.Sequential( 99 | nn.AdaptiveAvgPool2d(1), 100 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 101 | nn.BatchNorm2d(inter_channels), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 104 | nn.BatchNorm2d(channels), 105 | ) 106 | else: 107 | raise f"the type is not supported" 108 | 109 | self.sigmoid = nn.Sigmoid() 110 | 111 | def forward(self, x, residual): 112 | flag = False 113 | xa = x + residual 114 | if xa.size(0) == 1: 115 | xa = torch.cat([xa, xa], dim=0) 116 | flag = True 117 | xl = self.local_att(xa) 118 | xg = self.global_att(xa) 119 | xlg = xl + xg 120 | wei = self.sigmoid(xlg) 121 | xi = x * wei + residual * (1 - wei) 122 | 123 | xl2 = self.local_att2(xi) 124 | xg2 = self.global_att(xi) 125 | xlg2 = xl2 + xg2 126 | wei2 = self.sigmoid(xlg2) 127 | xo = x * wei2 + residual * (1 - wei2) 128 | if flag: 129 | xo = xo[0].unsqueeze(0) 130 | return xo 131 | 132 | 133 | class AFF(nn.Module): 134 | """ 135 | 多特征融合 AFF 136 | """ 137 | 138 | def __init__(self, channels=64, r=4, type="2D"): 139 | super(AFF, self).__init__() 140 | inter_channels = int(channels // r) 141 | 142 | if type == "1D": 143 | self.local_att = nn.Sequential( 144 | nn.Conv1d(channels, inter_channels, 
/audioldm2/latent_diffusion/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | 
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | 
7 | import multiprocessing as mp
8 | from threading import Thread
9 | from queue import Queue
10 | 
11 | from inspect import isfunction
12 | from PIL import Image, ImageDraw, ImageFont
13 | 
14 | CACHE = {
15 | "get_vits_phoneme_ids":{
16 | "PAD_LENGTH": 310,
17 | "_pad": '_',
18 | "_punctuation": ';:,.!?¡¿—…"«»“” ',
19 | "_letters": 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz',
20 | "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
21 | "_special": "♪☎☒☝⚠"
22 | }
23 | }
24 | 
25 | CACHE["get_vits_phoneme_ids"]["symbols"] = [CACHE["get_vits_phoneme_ids"]["_pad"]] + list(CACHE["get_vits_phoneme_ids"]["_punctuation"]) + list(CACHE["get_vits_phoneme_ids"]["_letters"]) + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"]) + list(CACHE["get_vits_phoneme_ids"]["_special"])
26 | CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])}
27 | 
28 | def get_vits_phoneme_ids_no_padding(phonemes):
29 | pad_token_id = 0
30 | pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
31 | _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]
32 | batchsize = len(phonemes)
33 | 
34 | clean_text = phonemes[0] + "⚠"
35 | sequence = []
36 | 
37 | for symbol in clean_text:
38 | if(symbol not in _symbol_to_id.keys()):
39 | print("%s is not in the vocabulary. 
%s" % (symbol, clean_text)) 40 | symbol = "_" 41 | symbol_id = _symbol_to_id[symbol] 42 | sequence += [symbol_id] 43 | 44 | def _pad_phonemes(phonemes_list): 45 | return phonemes_list + [pad_token_id] * (pad_length-len(phonemes_list)) 46 | 47 | sequence = sequence[:pad_length] 48 | 49 | return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence)).unsqueeze(0).expand(batchsize, -1)} 50 | 51 | def log_txt_as_img(wh, xc, size=10): 52 | # wh a tuple of (width, height) 53 | # xc a list of captions to plot 54 | b = len(xc) 55 | txts = list() 56 | for bi in range(b): 57 | txt = Image.new("RGB", wh, color="white") 58 | draw = ImageDraw.Draw(txt) 59 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 60 | nc = int(40 * (wh[0] / 256)) 61 | lines = "\n".join( 62 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 63 | ) 64 | 65 | try: 66 | draw.text((0, 0), lines, fill="black", font=font) 67 | except UnicodeEncodeError: 68 | print("Cant encode string for logging. Skipping.") 69 | 70 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 71 | txts.append(txt) 72 | txts = np.stack(txts) 73 | txts = torch.tensor(txts) 74 | return txts 75 | 76 | 77 | def ismap(x): 78 | if not isinstance(x, torch.Tensor): 79 | return False 80 | return (len(x.shape) == 4) and (x.shape[1] > 3) 81 | 82 | 83 | def isimage(x): 84 | if not isinstance(x, torch.Tensor): 85 | return False 86 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 87 | 88 | 89 | def int16_to_float32(x): 90 | return (x / 32767.0).astype(np.float32) 91 | 92 | 93 | def float32_to_int16(x): 94 | x = np.clip(x, a_min=-1.0, a_max=1.0) 95 | return (x * 32767.0).astype(np.int16) 96 | 97 | 98 | def exists(x): 99 | return x is not None 100 | 101 | 102 | def default(val, d): 103 | if exists(val): 104 | return val 105 | return d() if isfunction(d) else d 106 | 107 | 108 | def mean_flat(tensor): 109 | """ 110 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 111 | Take the mean over all non-batch dimensions. 
112 | """ 113 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 114 | 115 | 116 | def count_params(model, verbose=False): 117 | total_params = sum(p.numel() for p in model.parameters()) 118 | if verbose: 119 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 120 | return total_params 121 | 122 | 123 | def instantiate_from_config(config): 124 | if not "target" in config: 125 | if config == "__is_first_stage__": 126 | return None 127 | elif config == "__is_unconditional__": 128 | return None 129 | raise KeyError("Expected key `target` to instantiate.") 130 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 131 | 132 | 133 | def get_obj_from_str(string, reload=False): 134 | module, cls = string.rsplit(".", 1) 135 | if reload: 136 | module_imp = importlib.import_module(module) 137 | importlib.reload(module_imp) 138 | return getattr(importlib.import_module(module, package=None), cls) 139 | 140 | 141 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 142 | # create dummy dataset instance 143 | 144 | # run prefetching 145 | if idx_to_fn: 146 | res = func(data, worker_id=idx) 147 | else: 148 | res = func(data) 149 | Q.put([idx, res]) 150 | Q.put("Done") 151 | 152 | 153 | def parallel_data_prefetch( 154 | func: callable, 155 | data, 156 | n_proc, 157 | target_data_type="ndarray", 158 | cpu_intensive=True, 159 | use_worker_id=False, 160 | ): 161 | # if target_data_type not in ["ndarray", "list"]: 162 | # raise ValueError( 163 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 164 | # ) 165 | if isinstance(data, np.ndarray) and target_data_type == "list": 166 | raise ValueError("list expected but function got ndarray.") 167 | elif isinstance(data, abc.Iterable): 168 | if isinstance(data, dict): 169 | print( 170 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 171 | ) 172 | data = list(data.values()) 173 | if target_data_type == "ndarray": 174 | data = np.asarray(data) 175 | else: 176 | data = list(data) 177 | else: 178 | raise TypeError( 179 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
180 | ) 181 | 182 | if cpu_intensive: 183 | Q = mp.Queue(1000) 184 | proc = mp.Process 185 | else: 186 | Q = Queue(1000) 187 | proc = Thread 188 | # spawn processes 189 | if target_data_type == "ndarray": 190 | arguments = [ 191 | [func, Q, part, i, use_worker_id] 192 | for i, part in enumerate(np.array_split(data, n_proc)) 193 | ] 194 | else: 195 | step = ( 196 | int(len(data) / n_proc + 1) 197 | if len(data) % n_proc != 0 198 | else int(len(data) / n_proc) 199 | ) 200 | arguments = [ 201 | [func, Q, part, i, use_worker_id] 202 | for i, part in enumerate( 203 | [data[i : i + step] for i in range(0, len(data), step)] 204 | ) 205 | ] 206 | processes = [] 207 | for i in range(n_proc): 208 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 209 | processes += [p] 210 | 211 | # start processes 212 | print(f"Start prefetching...") 213 | import time 214 | 215 | start = time.time() 216 | gather_res = [[] for _ in range(n_proc)] 217 | try: 218 | for p in processes: 219 | p.start() 220 | 221 | k = 0 222 | while k < n_proc: 223 | # get result 224 | res = Q.get() 225 | if res == "Done": 226 | k += 1 227 | else: 228 | gather_res[res[0]] = res[1] 229 | 230 | except Exception as e: 231 | print("Exception: ", e) 232 | for p in processes: 233 | p.terminate() 234 | 235 | raise e 236 | finally: 237 | for p in processes: 238 | p.join() 239 | print(f"Prefetching complete. [{time.time() - start} sec.]") 240 | 241 | if target_data_type == "ndarray": 242 | if not isinstance(gather_res[0], np.ndarray): 243 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 244 | 245 | # order outputs 246 | return np.concatenate(gather_res, axis=0) 247 | elif target_data_type == "list": 248 | out = [] 249 | for r in gather_res: 250 | out.extend(r) 251 | return out 252 | else: 253 | return gather_res 254 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | import timm.models.vision_transformer 17 | 18 | 19 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 20 | """Vision Transformer with support for global average pooling""" 21 | 22 | def __init__( 23 | self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs 24 | ): 25 | super(VisionTransformer, self).__init__(**kwargs) 26 | 27 | self.global_pool = global_pool 28 | if self.global_pool: 29 | norm_layer = kwargs["norm_layer"] 30 | embed_dim = kwargs["embed_dim"] 31 | self.fc_norm = norm_layer(embed_dim) 32 | del self.norm # remove the original norm 33 | self.mask_2d = mask_2d 34 | self.use_custom_patch = use_custom_patch 35 | 36 | def forward_features(self, x): 37 | B = x.shape[0] 38 | x = self.patch_embed(x) 39 | x = x + self.pos_embed[:, 1:, :] 40 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 41 | cls_tokens = cls_token.expand( 42 | B, -1, -1 43 | ) # stole cls_tokens impl from Phil Wang, thanks 44 | x = torch.cat((cls_tokens, x), dim=1) 45 | x = self.pos_drop(x) 46 | 47 | for blk in self.blocks: 48 | x = blk(x) 49 | 50 | if self.global_pool: 51 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 52 | outcome = self.fc_norm(x) 53 | else: 54 | x = self.norm(x) 55 | outcome = x[:, 0] 56 | 57 | return outcome 58 | 59 | def random_masking(self, x, mask_ratio): 60 | """ 61 | Perform per-sample random masking by per-sample shuffling. 62 | Per-sample shuffling is done by argsort random noise. 63 | x: [N, L, D], sequence 64 | """ 65 | N, L, D = x.shape # batch, length, dim 66 | len_keep = int(L * (1 - mask_ratio)) 67 | 68 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 69 | 70 | # sort noise for each sample 71 | ids_shuffle = torch.argsort( 72 | noise, dim=1 73 | ) # ascend: small is keep, large is remove 74 | ids_restore = torch.argsort(ids_shuffle, dim=1) 75 | 76 | # keep the first subset 77 | ids_keep = ids_shuffle[:, :len_keep] 78 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 79 | 80 | # generate the binary mask: 0 is keep, 1 is remove 81 | mask = torch.ones([N, L], device=x.device) 82 | mask[:, :len_keep] = 0 83 | # unshuffle to get the binary mask 84 | mask = torch.gather(mask, dim=1, index=ids_restore) 85 | 86 | return x_masked, mask, ids_restore 87 | 88 | def random_masking_2d(self, x, mask_t_prob, mask_f_prob): 89 | """ 90 | 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) 91 | Perform per-sample random masking by per-sample shuffling. 92 | Per-sample shuffling is done by argsort random noise. 
93 | x: [N, L, D], sequence 94 | """ 95 | 96 | N, L, D = x.shape # batch, length, dim 97 | if self.use_custom_patch: 98 | # # for AS 99 | T = 101 # 64,101 100 | F = 12 # 8,12 101 | # # for ESC 102 | # T=50 103 | # F=12 104 | # for SPC 105 | # T=12 106 | # F=12 107 | else: 108 | # ## for AS 109 | T = 64 110 | F = 8 111 | # ## for ESC 112 | # T=32 113 | # F=8 114 | ## for SPC 115 | # T=8 116 | # F=8 117 | 118 | # mask T 119 | x = x.reshape(N, T, F, D) 120 | len_keep_T = int(T * (1 - mask_t_prob)) 121 | noise = torch.rand(N, T, device=x.device) # noise in [0, 1] 122 | # sort noise for each sample 123 | ids_shuffle = torch.argsort( 124 | noise, dim=1 125 | ) # ascend: small is keep, large is remove 126 | ids_keep = ids_shuffle[:, :len_keep_T] 127 | index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D) 128 | # x_masked = torch.gather(x, dim=1, index=index) 129 | # x_masked = x_masked.reshape(N,len_keep_T*F,D) 130 | x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D 131 | 132 | # mask F 133 | # x = x.reshape(N, T, F, D) 134 | x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D 135 | len_keep_F = int(F * (1 - mask_f_prob)) 136 | noise = torch.rand(N, F, device=x.device) # noise in [0, 1] 137 | # sort noise for each sample 138 | ids_shuffle = torch.argsort( 139 | noise, dim=1 140 | ) # ascend: small is keep, large is remove 141 | ids_keep = ids_shuffle[:, :len_keep_F] 142 | # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D) 143 | index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D) 144 | x_masked = torch.gather(x, dim=1, index=index) 145 | x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D 146 | # x_masked = x_masked.reshape(N,len_keep*T,D) 147 | x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D) 148 | 149 | return x_masked, None, None 150 | 151 | def forward_features_mask(self, x, mask_t_prob, mask_f_prob): 152 | B = x.shape[0] # 4,1,1024,128 153 | x = self.patch_embed(x) # 4, 512, 768 154 | 155 | x = x + self.pos_embed[:, 1:, :] 156 | if self.random_masking_2d: 157 | x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob) 158 | else: 159 | x, mask, ids_restore = self.random_masking(x, mask_t_prob) 160 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 161 | cls_tokens = cls_token.expand(B, -1, -1) 162 | x = torch.cat((cls_tokens, x), dim=1) 163 | x = self.pos_drop(x) 164 | 165 | # apply Transformer blocks 166 | for blk in self.blocks: 167 | x = blk(x) 168 | 169 | if self.global_pool: 170 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 171 | outcome = self.fc_norm(x) 172 | else: 173 | x = self.norm(x) 174 | outcome = x[:, 0] 175 | 176 | return outcome 177 | 178 | # overwrite original timm 179 | def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0): 180 | if mask_t_prob > 0.0 or mask_f_prob > 0.0: 181 | x = self.forward_features_mask( 182 | x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob 183 | ) 184 | else: 185 | x = self.forward_features(x) 186 | x = self.head(x) 187 | return x 188 | 189 | 190 | def vit_small_patch16(**kwargs): 191 | model = VisionTransformer( 192 | patch_size=16, 193 | embed_dim=384, 194 | depth=12, 195 | num_heads=6, 196 | mlp_ratio=4, 197 | qkv_bias=True, 198 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 199 | **kwargs 200 | ) 201 | return model 202 | 203 | 204 | def vit_base_patch16(**kwargs): 205 | model = VisionTransformer( 206 | patch_size=16, 207 | embed_dim=768, 208 | depth=12, 209 | num_heads=12, 210 | mlp_ratio=4, 211 | qkv_bias=True, 212 | 
norm_layer=partial(nn.LayerNorm, eps=1e-6), 213 | **kwargs 214 | ) 215 | return model 216 | 217 | 218 | def vit_large_patch16(**kwargs): 219 | model = VisionTransformer( 220 | patch_size=16, 221 | embed_dim=1024, 222 | depth=24, 223 | num_heads=16, 224 | mlp_ratio=4, 225 | qkv_bias=True, 226 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 227 | **kwargs 228 | ) 229 | return model 230 | 231 | 232 | def vit_huge_patch14(**kwargs): 233 | model = VisionTransformer( 234 | patch_size=14, 235 | embed_dim=1280, 236 | depth=32, 237 | num_heads=16, 238 | mlp_ratio=4, 239 | qkv_bias=True, 240 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 241 | **kwargs 242 | ) 243 | return model 244 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | 15 | # -------------------------------------------------------- 16 | # 2D sine-cosine position embedding 17 | # References: 18 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 19 | # MoCo v3: https://github.com/facebookresearch/moco-v3 20 | # -------------------------------------------------------- 21 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 22 | """ 23 | grid_size: int of the grid height and width 24 | return: 25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 26 | """ 27 | grid_h = np.arange(grid_size, dtype=np.float32) 28 | grid_w = np.arange(grid_size, dtype=np.float32) 29 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 30 | grid = np.stack(grid, axis=0) 31 | 32 | grid = grid.reshape([2, 1, grid_size, grid_size]) 33 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 34 | if cls_token: 35 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 36 | return pos_embed 37 | 38 | 39 | def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False): 40 | """ 41 | grid_size: int of the grid height and width 42 | return: 43 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 44 | """ 45 | grid_h = np.arange(grid_size[0], dtype=np.float32) 46 | grid_w = np.arange(grid_size[1], dtype=np.float32) 47 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 48 | grid = np.stack(grid, axis=0) 49 | 50 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) 51 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 52 | if cls_token: 53 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 54 | return pos_embed 55 | 56 | 57 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 58 | assert embed_dim % 2 == 0 59 | 60 | # use half of dimensions to encode grid_h 61 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 62 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 63 | 64 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 65 | return 
emb 66 | 67 | 68 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 69 | """ 70 | embed_dim: output dimension for each position 71 | pos: a list of positions to be encoded: size (M,) 72 | out: (M, D) 73 | """ 74 | assert embed_dim % 2 == 0 75 | # omega = np.arange(embed_dim // 2, dtype=np.float) 76 | omega = np.arange(embed_dim // 2, dtype=float) 77 | omega /= embed_dim / 2.0 78 | omega = 1.0 / 10000**omega # (D/2,) 79 | 80 | pos = pos.reshape(-1) # (M,) 81 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 82 | 83 | emb_sin = np.sin(out) # (M, D/2) 84 | emb_cos = np.cos(out) # (M, D/2) 85 | 86 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 87 | return emb 88 | 89 | 90 | # -------------------------------------------------------- 91 | # Interpolate position embeddings for high-resolution 92 | # References: 93 | # DeiT: https://github.com/facebookresearch/deit 94 | # -------------------------------------------------------- 95 | def interpolate_pos_embed(model, checkpoint_model): 96 | if "pos_embed" in checkpoint_model: 97 | pos_embed_checkpoint = checkpoint_model["pos_embed"] 98 | embedding_size = pos_embed_checkpoint.shape[-1] 99 | num_patches = model.patch_embed.num_patches 100 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 101 | # height (== width) for the checkpoint position embedding 102 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 103 | # height (== width) for the new position embedding 104 | new_size = int(num_patches**0.5) 105 | # class_token and dist_token are kept unchanged 106 | if orig_size != new_size: 107 | print( 108 | "Position interpolate from %dx%d to %dx%d" 109 | % (orig_size, orig_size, new_size, new_size) 110 | ) 111 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 112 | # only the position tokens are interpolated 113 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 114 | pos_tokens = pos_tokens.reshape( 115 | -1, orig_size, orig_size, embedding_size 116 | ).permute(0, 3, 1, 2) 117 | pos_tokens = torch.nn.functional.interpolate( 118 | pos_tokens, 119 | size=(new_size, new_size), 120 | mode="bicubic", 121 | align_corners=False, 122 | ) 123 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 124 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 125 | checkpoint_model["pos_embed"] = new_pos_embed 126 | 127 | 128 | def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size): 129 | if "pos_embed" in checkpoint_model: 130 | pos_embed_checkpoint = checkpoint_model["pos_embed"] 131 | embedding_size = pos_embed_checkpoint.shape[-1] 132 | num_patches = model.patch_embed.num_patches 133 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 134 | # height (== width) for the checkpoint position embedding 135 | # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 136 | # height (== width) for the new position embedding 137 | # new_size = int(num_patches ** 0.5) 138 | # class_token and dist_token are kept unchanged 139 | if orig_size != new_size: 140 | print( 141 | "Position interpolate from %dx%d to %dx%d" 142 | % (orig_size[0], orig_size[1], new_size[0], new_size[1]) 143 | ) 144 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 145 | # only the position tokens are interpolated 146 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 147 | pos_tokens = pos_tokens.reshape( 148 | -1, orig_size[0], orig_size[1], embedding_size 149 | ).permute(0, 3, 1, 2) 150 | pos_tokens = 
torch.nn.functional.interpolate( 151 | pos_tokens, 152 | size=(new_size[0], new_size[1]), 153 | mode="bicubic", 154 | align_corners=False, 155 | ) 156 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 157 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 158 | checkpoint_model["pos_embed"] = new_pos_embed 159 | 160 | 161 | def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size): 162 | if "pos_embed" in checkpoint_model: 163 | pos_embed_checkpoint = checkpoint_model["pos_embed"] 164 | embedding_size = pos_embed_checkpoint.shape[-1] 165 | num_patches = model.patch_embed.num_patches 166 | model.pos_embed.shape[-2] - num_patches 167 | if orig_size != new_size: 168 | print( 169 | "Position interpolate from %dx%d to %dx%d" 170 | % (orig_size[0], orig_size[1], new_size[0], new_size[1]) 171 | ) 172 | # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 173 | # only the position tokens are interpolated 174 | cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1) 175 | pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove 176 | pos_tokens = pos_tokens.reshape( 177 | -1, orig_size[0], orig_size[1], embedding_size 178 | ) # .permute(0, 3, 1, 2) 179 | # pos_tokens = torch.nn.functional.interpolate( 180 | # pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False) 181 | 182 | # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 183 | pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff 184 | pos_tokens = pos_tokens.flatten(1, 2) 185 | new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1) 186 | checkpoint_model["pos_embed"] = new_pos_embed 187 | 188 | 189 | def interpolate_patch_embed_audio( 190 | model, 191 | checkpoint_model, 192 | orig_channel, 193 | new_channel=1, 194 | kernel_size=(16, 16), 195 | stride=(16, 16), 196 | padding=(0, 0), 197 | ): 198 | if orig_channel != new_channel: 199 | if "patch_embed.proj.weight" in checkpoint_model: 200 | # aggregate 3 channels in rgb ckpt to 1 channel for audio 201 | new_proj_weight = torch.nn.Parameter( 202 | torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze( 203 | 1 204 | ) 205 | ) 206 | checkpoint_model["patch_embed.proj.weight"] = new_proj_weight 207 | -------------------------------------------------------------------------------- /audioldm2/latent_diffusion/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | from einops import repeat 16 | 17 | from audioldm2.latent_diffusion.util import instantiate_from_config 18 | 19 | 20 | def make_beta_schedule( 21 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 22 | ): 23 | if schedule == "linear": 24 | betas = ( 25 | torch.linspace( 26 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 27 | ) 28 | ** 2 29 | ) 30 | 31 | elif schedule == "cosine": 32 | timesteps = ( 33 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 34 | ) 35 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 36 | alphas = torch.cos(alphas).pow(2) 37 | alphas = alphas / alphas[0] 38 | betas = 1 - alphas[1:] / alphas[:-1] 39 | betas = np.clip(betas, a_min=0, a_max=0.999) 40 | 41 | elif schedule == "sqrt_linear": 42 | betas = torch.linspace( 43 | linear_start, linear_end, n_timestep, dtype=torch.float64 44 | ) 45 | elif schedule == "sqrt": 46 | betas = ( 47 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 48 | ** 0.5 49 | ) 50 | else: 51 | raise ValueError(f"schedule '{schedule}' unknown.") 52 | return betas.numpy() 53 | 54 | 55 | def make_ddim_timesteps( 56 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True 57 | ): 58 | if ddim_discr_method == "uniform": 59 | c = num_ddpm_timesteps // num_ddim_timesteps 60 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 61 | elif ddim_discr_method == "quad": 62 | ddim_timesteps = ( 63 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 64 | ).astype(int) 65 | else: 66 | raise NotImplementedError( 67 | f'There is no ddim discretization method called "{ddim_discr_method}"' 68 | ) 69 | 70 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 71 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 72 | steps_out = ddim_timesteps + 1 73 | if verbose: 74 | print(f"Selected timesteps for ddim sampler: {steps_out}") 75 | return steps_out 76 | 77 | 78 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 79 | # select alphas for computing the variance schedule 80 | alphas = alphacums[ddim_timesteps] 81 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 82 | 83 | # according the the formula provided in https://arxiv.org/abs/2010.02502 84 | sigmas = eta * np.sqrt( 85 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) 86 | ) 87 | if verbose: 88 | print( 89 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" 90 | ) 91 | print( 92 | f"For the chosen value of eta, which is {eta}, " 93 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}" 94 | ) 95 | return sigmas, alphas, alphas_prev 96 | 97 | 98 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 99 | """ 100 | Create a beta schedule that discretizes the given alpha_t_bar function, 101 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 102 | :param num_diffusion_timesteps: the number of betas to produce. 103 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 104 | produces the cumulative product of (1-beta) up to that 105 | part of the diffusion process. 106 | :param max_beta: the maximum beta to use; use values lower than 1 to 107 | prevent singularities. 
108 | """ 109 | betas = [] 110 | for i in range(num_diffusion_timesteps): 111 | t1 = i / num_diffusion_timesteps 112 | t2 = (i + 1) / num_diffusion_timesteps 113 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 114 | return np.array(betas) 115 | 116 | 117 | def extract_into_tensor(a, t, x_shape): 118 | b, *_ = t.shape 119 | out = a.gather(-1, t).contiguous() 120 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() 121 | 122 | 123 | def checkpoint(func, inputs, params, flag): 124 | """ 125 | Evaluate a function without caching intermediate activations, allowing for 126 | reduced memory at the expense of extra compute in the backward pass. 127 | :param func: the function to evaluate. 128 | :param inputs: the argument sequence to pass to `func`. 129 | :param params: a sequence of parameters `func` depends on but does not 130 | explicitly take as arguments. 131 | :param flag: if False, disable gradient checkpointing. 132 | """ 133 | if flag: 134 | args = tuple(inputs) + tuple(params) 135 | return CheckpointFunction.apply(func, len(inputs), *args) 136 | else: 137 | return func(*inputs) 138 | 139 | 140 | class CheckpointFunction(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, run_function, length, *args): 143 | ctx.run_function = run_function 144 | ctx.input_tensors = list(args[:length]) 145 | ctx.input_params = list(args[length:]) 146 | 147 | with torch.no_grad(): 148 | output_tensors = ctx.run_function(*ctx.input_tensors) 149 | return output_tensors 150 | 151 | @staticmethod 152 | def backward(ctx, *output_grads): 153 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 154 | with torch.enable_grad(): 155 | # Fixes a bug where the first op in run_function modifies the 156 | # Tensor storage in place, which is not allowed for detach()'d 157 | # Tensors. 158 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 159 | output_tensors = ctx.run_function(*shallow_copies) 160 | input_grads = torch.autograd.grad( 161 | output_tensors, 162 | ctx.input_tensors + ctx.input_params, 163 | output_grads, 164 | allow_unused=True, 165 | ) 166 | del ctx.input_tensors 167 | del ctx.input_params 168 | del output_tensors 169 | return (None, None) + input_grads 170 | 171 | 172 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 173 | """ 174 | Create sinusoidal timestep embeddings. 175 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 176 | These may be fractional. 177 | :param dim: the dimension of the output. 178 | :param max_period: controls the minimum frequency of the embeddings. 179 | :return: an [N x dim] Tensor of positional embeddings. 180 | """ 181 | if not repeat_only: 182 | half = dim // 2 183 | freqs = torch.exp( 184 | -math.log(max_period) 185 | * torch.arange(start=0, end=half, dtype=torch.float32) 186 | / half 187 | ).to(device=timesteps.device) 188 | args = timesteps[:, None].float() * freqs[None] 189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 190 | if dim % 2: 191 | embedding = torch.cat( 192 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 193 | ) 194 | else: 195 | embedding = repeat(timesteps, "b -> b d", d=dim) 196 | return embedding 197 | 198 | 199 | def zero_module(module): 200 | """ 201 | Zero out the parameters of a module and return it. 
202 | """ 203 | for p in module.parameters(): 204 | p.detach().zero_() 205 | return module 206 | 207 | 208 | def scale_module(module, scale): 209 | """ 210 | Scale the parameters of a module and return it. 211 | """ 212 | for p in module.parameters(): 213 | p.detach().mul_(scale) 214 | return module 215 | 216 | 217 | def mean_flat(tensor): 218 | """ 219 | Take the mean over all non-batch dimensions. 220 | """ 221 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 222 | 223 | 224 | def normalization(channels): 225 | """ 226 | Make a standard normalization layer. 227 | :param channels: number of input channels. 228 | :return: an nn.Module for normalization. 229 | """ 230 | return GroupNorm32(32, channels) 231 | 232 | 233 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 234 | class SiLU(nn.Module): 235 | def forward(self, x): 236 | return x * torch.sigmoid(x) 237 | 238 | 239 | class GroupNorm32(nn.GroupNorm): 240 | def forward(self, x): 241 | return super().forward(x.float()).type(x.dtype) 242 | 243 | 244 | def conv_nd(dims, *args, **kwargs): 245 | """ 246 | Create a 1D, 2D, or 3D convolution module. 247 | """ 248 | if dims == 1: 249 | return nn.Conv1d(*args, **kwargs) 250 | elif dims == 2: 251 | return nn.Conv2d(*args, **kwargs) 252 | elif dims == 3: 253 | return nn.Conv3d(*args, **kwargs) 254 | raise ValueError(f"unsupported dimensions: {dims}") 255 | 256 | 257 | def linear(*args, **kwargs): 258 | """ 259 | Create a linear module. 260 | """ 261 | return nn.Linear(*args, **kwargs) 262 | 263 | 264 | def avg_pool_nd(dims, *args, **kwargs): 265 | """ 266 | Create a 1D, 2D, or 3D average pooling module. 267 | """ 268 | if dims == 1: 269 | return nn.AvgPool1d(*args, **kwargs) 270 | elif dims == 2: 271 | return nn.AvgPool2d(*args, **kwargs) 272 | elif dims == 3: 273 | return nn.AvgPool3d(*args, **kwargs) 274 | raise ValueError(f"unsupported dimensions: {dims}") 275 | 276 | 277 | class HybridConditioner(nn.Module): 278 | def __init__(self, c_concat_config, c_crossattn_config): 279 | super().__init__() 280 | self.concat_conditioner = instantiate_from_config(c_concat_config) 281 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 282 | 283 | def forward(self, c_concat, c_crossattn): 284 | c_concat = self.concat_conditioner(c_concat) 285 | c_crossattn = self.crossattn_conditioner(c_crossattn) 286 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} 287 | 288 | 289 | def noise_like(shape, device, repeat=False): 290 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( 291 | shape[0], *((1,) * (len(shape) - 1)) 292 | ) 293 | noise = lambda: torch.randn(shape, device=device) 294 | return repeat_noise() if repeat else noise() 295 | --------------------------------------------------------------------------------
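
Usage sketch for the schedule helpers above: make_beta_schedule produces the per-step betas, extract_into_tensor gathers schedule values at per-sample timesteps and reshapes them for broadcasting, and timestep_embedding builds the sinusoidal conditioning vector. The snippet below is a minimal, illustrative combination of the three; the latent shape and embedding width are placeholders, and the import path simply mirrors the package layout shown earlier in this listing.

import numpy as np
import torch

from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
    make_beta_schedule,
    extract_into_tensor,
    timestep_embedding,
)

# Build a linear beta schedule and its cumulative alpha products.
n_timestep = 1000
betas = make_beta_schedule("linear", n_timestep)        # np.ndarray, shape (1000,)
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
sqrt_acp = torch.tensor(np.sqrt(alphas_cumprod), dtype=torch.float32)

# Gather the schedule at per-sample timesteps and embed the timesteps.
x0 = torch.randn(4, 8, 256, 16)                         # dummy latents (N, C, T, F), shape illustrative
t = torch.randint(0, n_timestep, (4,))                  # one timestep index per sample
coef = extract_into_tensor(sqrt_acp, t, x0.shape)       # (4, 1, 1, 1), broadcastable against x0
t_emb = timestep_embedding(t, dim=128)                  # (4, 128) sinusoidal embedding
print(coef.shape, t_emb.shape)

Because extract_into_tensor reshapes the gathered values to (N, 1, 1, 1), they broadcast directly against the latent tensor, which is how diffusion training code typically scales the clean latents and the injected noise.
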