├── audioldm2
│ ├── clap
│ │ ├── __init__.py
│ │ ├── training
│ │ │ ├── __init__.py
│ │ │ ├── audioset_textmap.npy
│ │ │ ├── bpe_simple_vocab_16e6.txt.gz
│ │ │ └── __pycache__
│ │ │ │ ├── data.cpython-310.pyc
│ │ │ │ └── __init__.cpython-310.pyc
│ │ └── open_clip
│ │ │ ├── bpe_simple_vocab_16e6.txt.gz
│ │ │ ├── model_configs
│ │ │ │ ├── ViT-B-16.json
│ │ │ │ ├── ViT-B-32.json
│ │ │ │ ├── ViT-L-14.json
│ │ │ │ ├── ViT-B-32-quickgelu.json
│ │ │ │ ├── RN50.json
│ │ │ │ ├── RN101.json
│ │ │ │ ├── RN50x16.json
│ │ │ │ ├── RN50x4.json
│ │ │ │ ├── RN101-quickgelu.json
│ │ │ │ ├── RN50-quickgelu.json
│ │ │ │ ├── PANN-6.json
│ │ │ │ ├── HTSAT-base.json
│ │ │ │ ├── HTSAT-tiny.json
│ │ │ │ ├── PANN-10.json
│ │ │ │ ├── PANN-14.json
│ │ │ │ ├── HTSAT-large.json
│ │ │ │ ├── PANN-14-fmax-18k.json
│ │ │ │ ├── PANN-14-win-1536.json
│ │ │ │ ├── HTSAT-tiny-win-1536.json
│ │ │ │ ├── PANN-14-fmax-8k-20s.json
│ │ │ │ └── PANN-14-tiny-transformer.json
│ │ │ ├── __init__.py
│ │ │ ├── transform.py
│ │ │ ├── timm_model.py
│ │ │ ├── openai.py
│ │ │ ├── pretrained.py
│ │ │ ├── tokenizer.py
│ │ │ └── feature_fusion.py
│ ├── latent_diffusion
│ │ ├── __init__.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ └── __pycache__
│ │ │ │ ├── ddim.cpython-310.pyc
│ │ │ │ ├── ddpm.cpython-310.pyc
│ │ │ │ ├── plms.cpython-310.pyc
│ │ │ │ └── __init__.cpython-310.pyc
│ │ ├── modules
│ │ │ ├── __init__.py
│ │ │ ├── audiomae
│ │ │ │ ├── __init__.py
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── AudioMAE.cpython-310.pyc
│ │ │ │ │ ├── __init__.cpython-310.pyc
│ │ │ │ │ ├── models_mae.cpython-310.pyc
│ │ │ │ │ └── models_vit.cpython-310.pyc
│ │ │ │ ├── util
│ │ │ │ │ ├── __pycache__
│ │ │ │ │ │ ├── pos_embed.cpython-310.pyc
│ │ │ │ │ │ └── patch_embed.cpython-310.pyc
│ │ │ │ │ ├── lr_sched.py
│ │ │ │ │ ├── crop.py
│ │ │ │ │ ├── datasets.py
│ │ │ │ │ ├── lars.py
│ │ │ │ │ ├── stat.py
│ │ │ │ │ ├── lr_decay.py
│ │ │ │ │ ├── patch_embed.py
│ │ │ │ │ └── pos_embed.py
│ │ │ │ ├── AudioMAE.py
│ │ │ │ └── models_vit.py
│ │ │ ├── encoders
│ │ │ │ ├── __init__.py
│ │ │ │ └── __pycache__
│ │ │ │ │ ├── __init__.cpython-310.pyc
│ │ │ │ │ └── modules.cpython-310.pyc
│ │ │ ├── diffusionmodules
│ │ │ │ ├── __init__.py
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── util.cpython-310.pyc
│ │ │ │ │ └── __init__.cpython-310.pyc
│ │ │ │ └── util.py
│ │ │ ├── distributions
│ │ │ │ ├── __init__.py
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── __init__.cpython-310.pyc
│ │ │ │ │ └── distributions.cpython-310.pyc
│ │ │ │ └── distributions.py
│ │ │ ├── phoneme_encoder
│ │ │ │ ├── __init__.py
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── commons.cpython-310.pyc
│ │ │ │ │ ├── encoder.cpython-310.pyc
│ │ │ │ │ ├── __init__.cpython-310.pyc
│ │ │ │ │ └── attentions.cpython-310.pyc
│ │ │ │ ├── text
│ │ │ │ │ ├── __pycache__
│ │ │ │ │ │ ├── symbols.cpython-310.pyc
│ │ │ │ │ │ ├── __init__.cpython-310.pyc
│ │ │ │ │ │ └── cleaners.cpython-310.pyc
│ │ │ │ │ ├── symbols.py
│ │ │ │ │ ├── LICENSE
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── cleaners.py
│ │ │ │ ├── encoder.py
│ │ │ │ └── commons.py
│ │ │ ├── __pycache__
│ │ │ │ ├── ema.cpython-310.pyc
│ │ │ │ └── __init__.cpython-310.pyc
│ │ │ └── ema.py
│ │ └── util.py
│ ├── utilities
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ └── __pycache__
│ │ │ │ ├── __init__.cpython-310.pyc
│ │ │ │ └── dataset.cpython-310.pyc
│ │ ├── __init__.py
│ │ ├── audio
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── stft.cpython-310.pyc
│ │ │ │ ├── tools.cpython-310.pyc
│ │ │ │ ├── __init__.cpython-310.pyc
│ │ │ │ └── audio_processing.cpython-310.pyc
│ │ │ ├── tools.py
│ │ │ ├── audio_processing.py
│ │ │ └── stft.py
│ │ └── model.py
│ ├── audiomae_gen
│ │ ├── __init__.py
│ │ └── utils.py
│ ├── __init__.py
│ ├── hifigan
│ │ ├── __init__.py
│ │ ├── LICENSE
│ │ └── models.py
│ ├── __main__.py
│ └── pipeline.py
├── visuals
│ └── framework.png
├── modules
│ ├── __pycache__
│ │ ├── layers.cpython-310.pyc
│ │ ├── autoencoder.cpython-310.pyc
│ │ └── conditioner.cpython-310.pyc
│ └── conditioner.py
├── scripts
│ ├── train_l.sh
│ ├── train_s.sh
│ ├── train_b.sh
│ └── train_g.sh
├── utils.py
├── config
│ ├── example.txt
│ └── 16k_64.yaml
├── constants.py
├── model.py
├── sample.py
├── README.md
└── test.py
/audioldm2/clap/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/clap/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/audioldm2/utilities/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset
2 |
--------------------------------------------------------------------------------
/audioldm2/audiomae_gen/__init__.py:
--------------------------------------------------------------------------------
1 | from .sequence_input import Sequence2AudioMAE
2 |
--------------------------------------------------------------------------------
/visuals/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/visuals/framework.png
--------------------------------------------------------------------------------
/audioldm2/utilities/__init__.py:
--------------------------------------------------------------------------------
1 | from .tools import *
2 | from .data import *
3 | from .model import *
4 |
--------------------------------------------------------------------------------
/audioldm2/utilities/audio/__init__.py:
--------------------------------------------------------------------------------
1 | from .audio_processing import *
2 | from .stft import *
3 | from .tools import *
4 |
--------------------------------------------------------------------------------
/audioldm2/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import seed_everything, save_wave, get_time, get_duration, read_list
2 | from .pipeline import *
3 |
--------------------------------------------------------------------------------
/modules/__pycache__/layers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/modules/__pycache__/layers.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/clap/training/audioset_textmap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/training/audioset_textmap.npy
--------------------------------------------------------------------------------
/modules/__pycache__/autoencoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/modules/__pycache__/autoencoder.cpython-310.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/conditioner.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/modules/__pycache__/conditioner.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/audioldm2/clap/training/__pycache__/data.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/training/__pycache__/data.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/utilities/audio/__pycache__/stft.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/audio/__pycache__/stft.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/clap/training/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/clap/training/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/utilities/audio/__pycache__/tools.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/audio/__pycache__/tools.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/utilities/data/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/data/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/utilities/data/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/data/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/utilities/audio/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/audio/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/models/__pycache__/ddim.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/models/__pycache__/ddim.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/models/__pycache__/ddpm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/models/__pycache__/ddpm.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/models/__pycache__/plms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/models/__pycache__/plms.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/__pycache__/ema.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/__pycache__/ema.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/train_l.sh:
--------------------------------------------------------------------------------
1 | torchrun --nnodes=4 --nproc_per_node=8 train.py \
2 | --version large \
3 | --data-path combine_dataset.json \
4 | --global_batch_size 128 \
5 | --global-seed 2023
--------------------------------------------------------------------------------
/scripts/train_s.sh:
--------------------------------------------------------------------------------
1 | torchrun --nnodes=1 --nproc_per_node=8 train.py \
2 | --version small \
3 | --data-path combine_dataset.json \
4 | --global_batch_size 128 \
5 | --global-seed 2023
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/models/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/models/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/utilities/audio/__pycache__/audio_processing.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/utilities/audio/__pycache__/audio_processing.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/train_b.sh:
--------------------------------------------------------------------------------
1 | torchrun --nnodes=2 --nproc_per_node=8 train.py \
2 | --version base \
3 | --data-path combine_dataset.json \
4 | --global_batch_size 128 \
5 | --resume xxx \
6 | --global-seed 2023
7 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/__pycache__/AudioMAE.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/__pycache__/AudioMAE.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/encoders/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/encoders/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/encoders/__pycache__/modules.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/encoders/__pycache__/modules.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/__pycache__/models_mae.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/__pycache__/models_mae.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/__pycache__/models_vit.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/__pycache__/models_vit.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/util/__pycache__/pos_embed.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/util/__pycache__/pos_embed.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/diffusionmodules/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/diffusionmodules/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/distributions/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/distributions/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/commons.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/commons.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/util/__pycache__/patch_embed.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/audiomae/util/__pycache__/patch_embed.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/distributions/__pycache__/distributions.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/distributions/__pycache__/distributions.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/attentions.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/__pycache__/attentions.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/symbols.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/symbols.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/cleaners.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/FluxMusic/HEAD/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__pycache__/cleaners.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/train_g.sh:
--------------------------------------------------------------------------------
1 | torchrun --nnodes=4 --nproc_per_node=8 train.py \
2 | --version giant \
3 | --data-path combine_dataset.json \
4 | --resume results/giant/checkpoints/0050000.pt \
5 | --global_batch_size 128 \
6 | --global-seed 2023 \
7 | --accum_iter 8
8 |
--------------------------------------------------------------------------------
/audioldm2/hifigan/__init__.py:
--------------------------------------------------------------------------------
1 | from .models_v2 import Generator
2 | from .models import Generator as Generator_old
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/RN50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 6,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/RN101.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 23,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/RN50x16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 384,
5 | "layers": [
6 | 6,
7 | 8,
8 | 18,
9 | 8
10 | ],
11 | "width": 96,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 768,
18 | "heads": 12,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/RN50x4.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 288,
5 | "layers": [
6 | 4,
7 | 6,
8 | 10,
9 | 6
10 | ],
11 | "width": 80,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 640,
18 | "heads": 10,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 23,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 6,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/PANN-6.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn6"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/HTSAT-base.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "base"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "tiny"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/PANN-10.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn10"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/PANN-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Defines the set of symbols used in text input to the model.
5 | '''
6 | _pad = '_'
7 | _punctuation = ';:,.!?¡¿—…"«»“” '
8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
9 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
10 |
11 |
12 | # Export all symbols:
13 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
14 |
15 | # Special symbol ids
16 | SPACE_ID = symbols.index(" ")
17 |
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/HTSAT-large.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "large"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 18000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1536,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1536,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "tiny"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 960000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 360,
10 | "fmin": 50,
11 | "fmax": 8000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 4
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .factory import (
2 | list_models,
3 | create_model,
4 | create_model_and_transforms,
5 | add_model_config,
6 | )
7 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8 | from .model import (
9 | CLAP,
10 | CLAPTextCfg,
11 | CLAPVisionCfg,
12 | CLAPAudioCfp,
13 | convert_weights_to_fp16,
14 | trace_model,
15 | )
16 | from .openai import load_openai_model, list_openai_models
17 | from .pretrained import (
18 | list_pretrained,
19 | list_pretrained_tag_models,
20 | list_pretrained_model_tags,
21 | get_pretrained_url,
22 | download_pretrained,
23 | )
24 | from .tokenizer import SimpleTokenizer, tokenize
25 | from .transform import image_transform
26 |
--------------------------------------------------------------------------------
/audioldm2/audiomae_gen/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Prenet(nn.Module):
5 | def __init__(self, in_dim, sizes=[256, 128], dropout_rate=0.5):
6 | super(Prenet, self).__init__()
7 | in_sizes = [in_dim] + sizes[:-1]
8 | self.layers = nn.ModuleList(
9 | [
10 | nn.Linear(in_size, out_size)
11 | for (in_size, out_size) in zip(in_sizes, sizes)
12 | ]
13 | )
14 | self.relu = nn.ReLU()
15 | self.dropout = nn.Dropout(dropout_rate)
16 |
17 | def forward(self, inputs):
18 | for linear in self.layers:
19 | inputs = self.dropout(self.relu(linear(inputs)))
20 | return inputs
21 |
22 |
23 | if __name__ == "__main__":
24 | model = Prenet(in_dim=128, sizes=[256, 256, 128])
25 | import ipdb
26 |
27 | ipdb.set_trace()
28 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 |
10 | def adjust_learning_rate(optimizer, epoch, args):
11 | """Decay the learning rate with half-cycle cosine after warmup"""
12 | if epoch < args.warmup_epochs:
13 | lr = args.lr * epoch / args.warmup_epochs
14 | else:
15 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
16 | 1.0
17 | + math.cos(
18 | math.pi
19 | * (epoch - args.warmup_epochs)
20 | / (args.epochs - args.warmup_epochs)
21 | )
22 | )
23 | for param_group in optimizer.param_groups:
24 | if "lr_scale" in param_group:
25 | param_group["lr"] = lr * param_group["lr_scale"]
26 | else:
27 | param_group["lr"] = lr
28 | return lr
29 |
--------------------------------------------------------------------------------
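A minimal sketch of the warmup-plus-cosine schedule implemented by adjust_learning_rate above; the argparse.Namespace stands in for the training args object and the single-parameter SGD optimizer is only a placeholder:

    import argparse
    import torch
    from audioldm2.latent_diffusion.modules.audiomae.util.lr_sched import adjust_learning_rate

    # Placeholder args: linear warmup for 5 epochs, cosine decay to min_lr over 50 epochs.
    args = argparse.Namespace(lr=1e-3, min_lr=1e-5, warmup_epochs=5, epochs=50)
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

    for epoch in (0, 2, 5, 25, 49):
        lr = adjust_learning_rate(opt, epoch, args)
        print(f"epoch {epoch:2d}: lr = {lr:.6f}")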
/audioldm2/hifigan/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Jungil Kong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/text/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import (
2 | Normalize,
3 | Compose,
4 | RandomResizedCrop,
5 | InterpolationMode,
6 | ToTensor,
7 | Resize,
8 | CenterCrop,
9 | )
10 |
11 |
12 | def _convert_to_rgb(image):
13 | return image.convert("RGB")
14 |
15 |
16 | def image_transform(
17 | image_size: int,
18 | is_train: bool,
19 | mean=(0.48145466, 0.4578275, 0.40821073),
20 | std=(0.26862954, 0.26130258, 0.27577711),
21 | ):
22 | normalize = Normalize(mean=mean, std=std)
23 | if is_train:
24 | return Compose(
25 | [
26 | RandomResizedCrop(
27 | image_size,
28 | scale=(0.9, 1.0),
29 | interpolation=InterpolationMode.BICUBIC,
30 | ),
31 | _convert_to_rgb,
32 | ToTensor(),
33 | normalize,
34 | ]
35 | )
36 | else:
37 | return Compose(
38 | [
39 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
40 | CenterCrop(image_size),
41 | _convert_to_rgb,
42 | ToTensor(),
43 | normalize,
44 | ]
45 | )
46 |
--------------------------------------------------------------------------------
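A short usage sketch for image_transform above, applied to a dummy PIL image; the 224-pixel size matches the ViT configs listed earlier, and the result is a normalized 3x224x224 tensor:

    from PIL import Image
    from audioldm2.clap.open_clip.transform import image_transform

    preprocess = image_transform(image_size=224, is_train=False)  # Resize + CenterCrop + ToTensor + Normalize
    img = Image.new("RGB", (640, 480), color=(127, 127, 127))     # dummy image in place of real data
    print(preprocess(img).shape)                                  # torch.Size([3, 224, 224])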
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from modules.autoencoder import AutoEncoder, AutoEncoderParams
3 | from modules.conditioner import HFEmbedder
4 | from safetensors.torch import load_file as load_sft
5 |
6 |
7 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
8 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
9 | return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
10 |
11 |
12 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
13 | return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
14 |
15 |
16 | def load_clap(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
17 |     return HFEmbedder("laion/larger_clap_music", max_length=256, torch_dtype=torch.bfloat16).to(device)  # note: the max_length argument is unused; CLAP is always loaded with max_length=256
18 |
19 | def load_ae(ckpt_path, device: str | torch.device = "cuda",) -> AutoEncoder:
20 | ae_params=AutoEncoderParams(
21 | resolution=256,
22 | in_channels=3,
23 | ch=128,
24 | out_ch=3,
25 | ch_mult=[1, 2, 4, 4],
26 | num_res_blocks=2,
27 | z_channels=16,
28 | scale_factor=0.3611,
29 | shift_factor=0.1159,
30 | )
31 | # Loading the autoencoder
32 | ae = AutoEncoder(ae_params)
33 | sd = load_sft(ckpt_path,)
34 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
35 | ae.to(device)
36 | return ae
37 |
--------------------------------------------------------------------------------
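A minimal loading sketch for the helpers in /utils.py above; it assumes the Hugging Face weights are downloadable, and "ae.safetensors" is a placeholder path for a locally stored VAE checkpoint, not a file shipped in this listing:

    import torch
    from utils import load_t5, load_clap, load_ae

    device = "cuda" if torch.cuda.is_available() else "cpu"

    t5 = load_t5(device, max_length=256)           # google/t5-v1_1-xxl text encoder
    clap = load_clap(device)                       # laion/larger_clap_music, max_length fixed at 256 internally
    ae = load_ae("ae.safetensors", device=device)  # placeholder checkpoint path

    print(sum(p.numel() for p in ae.parameters()), "autoencoder parameters")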
/audioldm2/latent_diffusion/modules/audiomae/util/crop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | import torch
10 |
11 | from torchvision import transforms
12 | from torchvision.transforms import functional as F
13 |
14 |
15 | class RandomResizedCrop(transforms.RandomResizedCrop):
16 | """
17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18 |     This may lead to results different from torchvision's version.
19 | Following BYOL's TF code:
20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
21 | """
22 |
23 | @staticmethod
24 | def get_params(img, scale, ratio):
25 | width, height = F._get_image_size(img)
26 | area = height * width
27 |
28 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
29 | log_ratio = torch.log(torch.tensor(ratio))
30 | aspect_ratio = torch.exp(
31 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
32 | ).item()
33 |
34 | w = int(round(math.sqrt(target_area * aspect_ratio)))
35 | h = int(round(math.sqrt(target_area / aspect_ratio)))
36 |
37 | w = min(w, width)
38 | h = min(h, height)
39 |
40 | i = torch.randint(0, height - h + 1, size=(1,)).item()
41 | j = torch.randint(0, width - w + 1, size=(1,)).item()
42 |
43 | return i, j, h, w
44 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | from audioldm2.latent_diffusion.modules.phoneme_encoder.text import cleaners
3 | from audioldm2.latent_diffusion.modules.phoneme_encoder.text.symbols import symbols
4 |
5 |
6 | # Mappings from symbol to numeric ID and vice versa:
7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9 |
10 | cleaner = getattr(cleaners, "english_cleaners2")
11 |
12 | def text_to_sequence(text, cleaner_names):
13 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
14 | Args:
15 | text: string to convert to a sequence
16 | cleaner_names: names of the cleaner functions to run the text through
17 | Returns:
18 | List of integers corresponding to the symbols in the text
19 | '''
20 | sequence = []
21 |
22 | clean_text = _clean_text(text, cleaner_names)
23 | for symbol in clean_text:
24 | symbol_id = _symbol_to_id[symbol]
25 | sequence += [symbol_id]
26 | return sequence
27 |
28 | def cleaned_text_to_sequence(cleaned_text):
29 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
30 | Args:
31 | text: string to convert to a sequence
32 | Returns:
33 | List of integers corresponding to the symbols in the text
34 | '''
35 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text]
36 | return sequence
37 |
38 | def sequence_to_text(sequence):
39 | '''Converts a sequence of IDs back to a string'''
40 | result = ''
41 | for symbol_id in sequence:
42 | s = _id_to_symbol[symbol_id]
43 | result += s
44 | return result
45 |
46 | def _clean_text(text, cleaner_names):
47 | text = cleaner(text)
48 | return text
49 |
--------------------------------------------------------------------------------
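A brief round-trip sketch for the symbol mapping above. It assumes the dependencies pulled in by cleaners.py (e.g. a phonemizer backend) are installed, and it uses cleaned_text_to_sequence directly so no cleaner is invoked; note that text_to_sequence always applies english_cleaners2 regardless of the cleaner_names argument:

    from audioldm2.latent_diffusion.modules.phoneme_encoder.text import (
        cleaned_text_to_sequence,
        sequence_to_text,
    )

    ids = cleaned_text_to_sequence("həloʊ wɜːld")  # already-cleaned IPA string
    print(ids)                                     # symbol IDs from symbols.py
    print(sequence_to_text(ids))                   # "həloʊ wɜːld"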
/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 |
5 | import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons
6 | import audioldm2.latent_diffusion.modules.phoneme_encoder.attentions as attentions
7 |
8 |
9 | class TextEncoder(nn.Module):
10 | def __init__(
11 | self,
12 | n_vocab,
13 | out_channels=192,
14 | hidden_channels=192,
15 | filter_channels=768,
16 | n_heads=2,
17 | n_layers=6,
18 | kernel_size=3,
19 | p_dropout=0.1,
20 | ):
21 | super().__init__()
22 | self.n_vocab = n_vocab
23 | self.out_channels = out_channels
24 | self.hidden_channels = hidden_channels
25 | self.filter_channels = filter_channels
26 | self.n_heads = n_heads
27 | self.n_layers = n_layers
28 | self.kernel_size = kernel_size
29 | self.p_dropout = p_dropout
30 |
31 | self.emb = nn.Embedding(n_vocab, hidden_channels)
32 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
33 |
34 | self.encoder = attentions.Encoder(
35 | hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
36 | )
37 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
38 |
39 | def forward(self, x, x_lengths):
40 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
41 | x = torch.transpose(x, 1, -1) # [b, h, t]
42 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
43 | x.dtype
44 | )
45 |
46 | x = self.encoder(x * x_mask, x_mask)
47 | stats = self.proj(x) * x_mask
48 |
49 | m, logs = torch.split(stats, self.out_channels, dim=1)
50 | return x, m, logs, x_mask
51 |
--------------------------------------------------------------------------------
/config/example.txt:
--------------------------------------------------------------------------------
1 | An experimental and free-jazz psych-folk tune with jazz influences from Europe.
2 | A powerful and dynamic rock track that blends industrial and electronic elements for a thrilling and energetic sound.
3 | A lively and energizing rock anthem that gets audiences pumped up and singing along.
4 | This song is a fusion of pop and experimental pop that creates a unique and innovative sound.
5 | This song is a mashup of punk, progressive, rock, and no wave genres.
6 | This song is a unique mix of experimental sounds and styles that push the boundaries of conventional music.
7 | An experimental electroacoustic instrumental composition featuring ambient soundscapes and field recordings.
8 | This rock song has a punk essence to it.
9 | This song is a fusion of both heavier metal and more melodic rock genres.
10 | This song features international sounds and influences.
11 | A chaotic, yet danceable blend of punk, new wave, and post-punk elements with electronic, experimental, and industrial influences, featuring chip music and a heavy drone focus.
12 | A wild and rebellious rock anthem with psychedelic and bluesy undertones, this song is loud and energetic.
13 | A soothing and atmospheric instrumental folk song plays in the background.
14 | A high-energy punk rock track with grungy garage undertones.
15 | A hip hop and electronic infused tune with strong elements of dubstep.
16 | This song is a combination of pop and experimental pop.
17 | A diverse indie-rock song that incorporates elements of rock, electronic, goth and trip-hop.
18 | This song is a classic rock anthem with driving guitars and powerful lyrics.
19 | A folk song from a British artist with traditional string instruments such as the fiddle and guitar.
20 | The song is an epic blend of space-rock, rock, and post-rock genres.
21 | A noisy and experimental track with a chaotic sound.
22 | A wild electrified sound consisting of experimental jazz, avant-garde, glitch, and free-jazz, with traces of noise and electroacoustics out there.
23 | A song that incorporates avant-garde, electroacoustic, experimental and musique concrete with field recordings.
24 | This song combines elements of punk, rock, and no wave.
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # --------------------------------------------------------
10 |
11 | import os
12 | import PIL
13 |
14 | from torchvision import datasets, transforms
15 |
16 | from timm.data import create_transform
17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18 |
19 |
20 | def build_dataset(is_train, args):
21 | transform = build_transform(is_train, args)
22 |
23 | root = os.path.join(args.data_path, "train" if is_train else "val")
24 | dataset = datasets.ImageFolder(root, transform=transform)
25 |
26 | print(dataset)
27 |
28 | return dataset
29 |
30 |
31 | def build_transform(is_train, args):
32 | mean = IMAGENET_DEFAULT_MEAN
33 | std = IMAGENET_DEFAULT_STD
34 | # train transform
35 | if is_train:
36 | # this should always dispatch to transforms_imagenet_train
37 | transform = create_transform(
38 | input_size=args.input_size,
39 | is_training=True,
40 | color_jitter=args.color_jitter,
41 | auto_augment=args.aa,
42 | interpolation="bicubic",
43 | re_prob=args.reprob,
44 | re_mode=args.remode,
45 | re_count=args.recount,
46 | mean=mean,
47 | std=std,
48 | )
49 | return transform
50 |
51 | # eval transform
52 | t = []
53 | if args.input_size <= 224:
54 | crop_pct = 224 / 256
55 | else:
56 | crop_pct = 1.0
57 | size = int(args.input_size / crop_pct)
58 | t.append(
59 | transforms.Resize(
60 | size, interpolation=PIL.Image.BICUBIC
61 | ), # to maintain same ratio w.r.t. 224 images
62 | )
63 | t.append(transforms.CenterCrop(args.input_size))
64 |
65 | t.append(transforms.ToTensor())
66 | t.append(transforms.Normalize(mean, std))
67 | return transforms.Compose(t)
68 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/util/lars.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # LARS optimizer, implementation from MoCo v3:
8 | # https://github.com/facebookresearch/moco-v3
9 | # --------------------------------------------------------
10 |
11 | import torch
12 |
13 |
14 | class LARS(torch.optim.Optimizer):
15 | """
16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
17 | """
18 |
19 | def __init__(
20 | self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001
21 | ):
22 | defaults = dict(
23 | lr=lr,
24 | weight_decay=weight_decay,
25 | momentum=momentum,
26 | trust_coefficient=trust_coefficient,
27 | )
28 | super().__init__(params, defaults)
29 |
30 | @torch.no_grad()
31 | def step(self):
32 | for g in self.param_groups:
33 | for p in g["params"]:
34 | dp = p.grad
35 |
36 | if dp is None:
37 | continue
38 |
39 | if p.ndim > 1: # if not normalization gamma/beta or bias
40 | dp = dp.add(p, alpha=g["weight_decay"])
41 | param_norm = torch.norm(p)
42 | update_norm = torch.norm(dp)
43 | one = torch.ones_like(param_norm)
44 | q = torch.where(
45 | param_norm > 0.0,
46 | torch.where(
47 | update_norm > 0,
48 | (g["trust_coefficient"] * param_norm / update_norm),
49 | one,
50 | ),
51 | one,
52 | )
53 | dp = dp.mul(q)
54 |
55 | param_state = self.state[p]
56 | if "mu" not in param_state:
57 | param_state["mu"] = torch.zeros_like(p)
58 | mu = param_state["mu"]
59 | mu.mul_(g["momentum"]).add_(dp)
60 | p.add_(mu, alpha=-g["lr"])
61 |
--------------------------------------------------------------------------------
/config/16k_64.yaml:
--------------------------------------------------------------------------------
1 | metadata_root: "./data/dataset/metadata/dataset_root.json"
2 | log_directory: "./log/latent_diffusion"
3 | project: "audioldm"
4 | precision: "high"
5 |
6 | variables:
7 | sampling_rate: &sampling_rate 16000
8 | mel_bins: &mel_bins 64
9 | latent_embed_dim: &latent_embed_dim 8
10 | latent_t_size: &latent_t_size 256 # TODO might need to change
11 | latent_f_size: &latent_f_size 16
12 | in_channels: &unet_in_channels 8
13 | optimize_ddpm_parameter: &optimize_ddpm_parameter true
14 | optimize_gpt: &optimize_gpt true
15 | warmup_steps: &warmup_steps 2000
16 |
17 | data:
18 | train: ["audiocaps"]
19 | val: "audiocaps"
20 | test: "audiocaps"
21 | class_label_indices: "audioset_eval_subset"
22 | dataloader_add_ons: ["waveform_rs_48k"]
23 |
24 | step:
25 | validation_every_n_epochs: 15
26 | save_checkpoint_every_n_steps: 5000
27 | # limit_val_batches: 2
28 | max_steps: 800000
29 | save_top_k: 1
30 |
31 | preprocessing:
32 | audio:
33 | sampling_rate: *sampling_rate
34 | max_wav_value: 32768.0
35 | duration: 10.24
36 | stft:
37 | filter_length: 1024
38 | hop_length: 160
39 | win_length: 1024
40 | mel:
41 | n_mel_channels: *mel_bins
42 | mel_fmin: 0
43 | mel_fmax: 8000
44 |
45 | augmentation:
46 | mixup: 0.0
47 |
48 | model:
49 | base_learning_rate: 8.0e-06
50 | target: audioldm_train.modules.latent_encoder.autoencoder.AutoencoderKL
51 | params:
52 | # reload_from_ckpt: "data/checkpoints/vae_mel_16k_64bins.ckpt"
53 | sampling_rate: *sampling_rate
54 | batchsize: 4
55 | monitor: val/rec_loss
56 | image_key: fbank
57 | subband: 1
58 | embed_dim: *latent_embed_dim
59 | time_shuffle: 1
60 | lossconfig:
61 | target: audioldm_train.losses.LPIPSWithDiscriminator
62 | params:
63 | disc_start: 50001
64 | kl_weight: 1000.0
65 | disc_weight: 0.5
66 | disc_in_channels: 1
67 | ddconfig:
68 | double_z: true
69 | mel_bins: *mel_bins # The frequency bins of mel spectrogram
70 | z_channels: 8
71 | resolution: 256
72 | downsample_time: false
73 | in_channels: 1
74 | out_ch: 1
75 | ch: 128
76 | ch_mult:
77 | - 1
78 | - 2
79 | - 4
80 | num_res_blocks: 2
81 | attn_resolutions: []
82 | dropout: 0.0
83 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/util/stat.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import stats
3 | from sklearn import metrics
4 | import torch
5 |
6 |
7 | def d_prime(auc):
8 | standard_normal = stats.norm()
9 | d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)
10 | return d_prime
11 |
12 |
13 | @torch.no_grad()
14 | def concat_all_gather(tensor):
15 | """
16 | Performs all_gather operation on the provided tensors.
17 | *** Warning ***: torch.distributed.all_gather has no gradient.
18 | """
19 | tensors_gather = [
20 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
21 | ]
22 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
23 |
24 | output = torch.cat(tensors_gather, dim=0)
25 | return output
26 |
27 |
28 | def calculate_stats(output, target):
29 | """Calculate statistics including mAP, AUC, etc.
30 |
31 | Args:
32 | output: 2d array, (samples_num, classes_num)
33 | target: 2d array, (samples_num, classes_num)
34 |
35 | Returns:
36 | stats: list of statistic of each class.
37 | """
38 |
39 | classes_num = target.shape[-1]
40 | stats = []
41 |
42 | # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet
43 | acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
44 |
45 | # Class-wise statistics
46 | for k in range(classes_num):
47 | # Average precision
48 | avg_precision = metrics.average_precision_score(
49 | target[:, k], output[:, k], average=None
50 | )
51 |
52 | # AUC
53 | # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None)
54 |
55 | # Precisions, recalls
56 | (precisions, recalls, thresholds) = metrics.precision_recall_curve(
57 | target[:, k], output[:, k]
58 | )
59 |
60 | # FPR, TPR
61 | (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k])
62 |
63 | save_every_steps = 1000 # Sample statistics to reduce size
64 |         stat_dict = {  # renamed from `dict` to avoid shadowing the builtin
65 | "precisions": precisions[0::save_every_steps],
66 | "recalls": recalls[0::save_every_steps],
67 | "AP": avg_precision,
68 | "fpr": fpr[0::save_every_steps],
69 | "fnr": 1.0 - tpr[0::save_every_steps],
70 | # 'auc': auc,
71 | # note acc is not class-wise, this is just to keep consistent with other metrics
72 | "acc": acc,
73 | }
74 |         stats.append(stat_dict)
75 |
76 | return stats
77 |
--------------------------------------------------------------------------------
/modules/conditioner.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor, nn
2 | from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
3 | T5Tokenizer, AutoTokenizer, ClapTextModel)
4 |
5 |
6 | class HFEmbedder(nn.Module):
7 | def __init__(self, version: str, max_length: int, **hf_kwargs):
8 | super().__init__()
9 | self.is_t5 = version.startswith("google")
10 | self.max_length = max_length
11 | self.output_key = "last_hidden_state" if self.is_t5 else "pooler_output"
12 |
13 | if version.startswith("openai"):
14 | local_path = '/maindata/data/shared/multimodal/public/ckpts/stable-diffusion-3-medium-diffusers/text_encoder'
15 | local_path_tokenizer = '/maindata/data/shared/multimodal/public/ckpts/stable-diffusion-3-medium-diffusers/tokenizer'
16 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(local_path_tokenizer, max_length=max_length)
17 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(local_path, **hf_kwargs).half()
18 | elif version.startswith("laion"):
19 | local_path = '/maindata/data/shared/multimodal/public/dataset_music/clap'
20 | self.tokenizer = AutoTokenizer.from_pretrained(local_path, max_length=max_length)
21 | self.hf_module: ClapTextModel = ClapTextModel.from_pretrained(local_path, **hf_kwargs).half()
22 | else:
23 | local_path = '/maindata/data/shared/multimodal/public/ckpts/stable-diffusion-3-medium-diffusers/text_encoder_3'
24 | local_path_tokenizer = '/maindata/data/shared/multimodal/public/ckpts/stable-diffusion-3-medium-diffusers/tokenizer_3'
25 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(local_path_tokenizer, max_length=max_length)
26 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(local_path, **hf_kwargs).half()
27 |
28 | self.hf_module = self.hf_module.eval().requires_grad_(False)
29 |
30 | def forward(self, text: list[str]) -> Tensor:
31 | batch_encoding = self.tokenizer(
32 | text,
33 | truncation=True,
34 | max_length=self.max_length,
35 | return_length=False,
36 | return_overflowing_tokens=False,
37 | padding="max_length",
38 | return_tensors="pt",
39 | )
40 |
41 | outputs = self.hf_module(
42 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
43 | attention_mask=None,
44 | output_hidden_states=False,
45 | )
46 | return outputs[self.output_key]
47 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 |
13 | def param_groups_lrd(
14 | model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75
15 | ):
16 | """
17 | Parameter groups for layer-wise lr decay
18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19 | """
20 | param_group_names = {}
21 | param_groups = {}
22 |
23 | num_layers = len(model.blocks) + 1
24 |
25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26 |
27 | for n, p in model.named_parameters():
28 | if not p.requires_grad:
29 | continue
30 |
31 | # no decay: all 1D parameters and model specific ones
32 | if p.ndim == 1 or n in no_weight_decay_list:
33 | g_decay = "no_decay"
34 | this_decay = 0.0
35 | else:
36 | g_decay = "decay"
37 | this_decay = weight_decay
38 |
39 | layer_id = get_layer_id_for_vit(n, num_layers)
40 | group_name = "layer_%d_%s" % (layer_id, g_decay)
41 |
42 | if group_name not in param_group_names:
43 | this_scale = layer_scales[layer_id]
44 |
45 | param_group_names[group_name] = {
46 | "lr_scale": this_scale,
47 | "weight_decay": this_decay,
48 | "params": [],
49 | }
50 | param_groups[group_name] = {
51 | "lr_scale": this_scale,
52 | "weight_decay": this_decay,
53 | "params": [],
54 | }
55 |
56 | param_group_names[group_name]["params"].append(n)
57 | param_groups[group_name]["params"].append(p)
58 |
59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60 |
61 | return list(param_groups.values())
62 |
63 |
64 | def get_layer_id_for_vit(name, num_layers):
65 | """
66 | Assign a parameter with its layer id
67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68 | """
69 | if name in ["cls_token", "pos_embed"]:
70 | return 0
71 | elif name.startswith("patch_embed"):
72 | return 0
73 | elif name.startswith("blocks"):
74 | return int(name.split(".")[1]) + 1
75 | else:
76 | return num_layers
77 |
--------------------------------------------------------------------------------
/audioldm2/utilities/audio/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.io.wavfile import write
4 | import torchaudio
5 |
6 | from audioldm2.utilities.audio.audio_processing import griffin_lim
7 |
8 |
9 | def pad_wav(waveform, segment_length):
10 | waveform_length = waveform.shape[-1]
11 | assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
12 | if segment_length is None or waveform_length == segment_length:
13 | return waveform
14 | elif waveform_length > segment_length:
15 | return waveform[:segment_length]
16 | elif waveform_length < segment_length:
17 | temp_wav = np.zeros((1, segment_length))
18 | temp_wav[:, :waveform_length] = waveform
19 | return temp_wav
20 |
21 |
22 | def normalize_wav(waveform):
23 | waveform = waveform - np.mean(waveform)
24 | waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
25 | return waveform * 0.5
26 |
27 |
28 | def read_wav_file(filename, segment_length):
29 | # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
30 | waveform, sr = torchaudio.load(filename) # Faster!!!
31 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
32 | waveform = waveform.numpy()[0, ...]
33 | waveform = normalize_wav(waveform)
34 | waveform = waveform[None, ...]
35 | waveform = pad_wav(waveform, segment_length)
36 |
37 | waveform = waveform / np.max(np.abs(waveform))
38 | waveform = 0.5 * waveform
39 |
40 | return waveform
41 |
42 |
43 | def get_mel_from_wav(audio, _stft):
44 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
45 | audio = torch.autograd.Variable(audio, requires_grad=False)
46 | melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio)
47 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
48 | magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32)
49 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
50 | return melspec, magnitudes, energy
51 |
52 |
53 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
54 | mel = torch.stack([mel])
55 | mel_decompress = _stft.spectral_de_normalize(mel)
56 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
57 | spec_from_mel_scaling = 1000
58 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
59 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
60 | spec_from_mel = spec_from_mel * spec_from_mel_scaling
61 |
62 | audio = griffin_lim(
63 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
64 | )
65 |
66 | audio = audio.squeeze()
67 | audio = audio.cpu().numpy()
68 | audio_path = out_filename
69 | write(audio_path, _stft.sampling_rate, audio)
70 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | from model import FluxParams, Flux
2 |
3 | def build_model(version='base'):
4 | if version == 'base':
5 | params=FluxParams(
6 | in_channels=32,
7 | vec_in_dim=768,
8 | context_in_dim=4096,
9 | hidden_size=768,
10 | mlp_ratio=4.0,
11 | num_heads=16,
12 | depth=12,
13 | depth_single_blocks=24,
14 | axes_dim=[16, 16, 16],
15 | theta=10_000,
16 | qkv_bias=True,
17 | guidance_embed=True,
18 | )
19 |
20 | elif version == 'small':
21 | params=FluxParams(
22 | in_channels=32,
23 | vec_in_dim=768,
24 | context_in_dim=4096,
25 | hidden_size=512,
26 | mlp_ratio=4.0,
27 | num_heads=16,
28 | depth=8,
29 | depth_single_blocks=16,
30 | axes_dim=[8, 12, 12],
31 | theta=10_000,
32 | qkv_bias=True,
33 | guidance_embed=True,
34 | )
35 |
36 |
37 | elif version == 'large':
38 | params=FluxParams(
39 | in_channels=32,
40 | vec_in_dim=768,
41 | context_in_dim=4096,
42 | hidden_size=1024,
43 | mlp_ratio=4.0,
44 | num_heads=16,
45 | depth=12,
46 | depth_single_blocks=24,
47 | axes_dim=[16, 24, 24],
48 | theta=10_000,
49 | qkv_bias=True,
50 | guidance_embed=True,
51 | )
52 |
53 | elif version == 'biggiant':
54 | params=FluxParams(
55 | in_channels=32,
56 | vec_in_dim=768,
57 | context_in_dim=4096,
58 | hidden_size=2048,
59 | mlp_ratio=4.0,
60 | num_heads=16,
61 | depth=19,
62 | depth_single_blocks=38,
63 | axes_dim=[32, 48, 48],
64 | theta=10_000,
65 | qkv_bias=True,
66 | guidance_embed=True,
67 | )
68 |
69 | elif version == 'giant_full':
70 | params=FluxParams(
71 | in_channels=32,
72 | vec_in_dim=768,
73 | context_in_dim=4096,
74 | hidden_size=1408,
75 | mlp_ratio=4.0,
76 | num_heads=16,
77 | depth=12,
78 | depth_single_blocks=24,
79 | axes_dim=[16, 36, 36],
80 | theta=10_000,
81 | qkv_bias=True,
82 | guidance_embed=True,
83 | )
84 |
85 | else:
86 | params=FluxParams(
87 | in_channels=32,
88 | vec_in_dim=768,
89 | context_in_dim=4096,
90 | hidden_size=1408,
91 | mlp_ratio=4.0,
92 | num_heads=16,
93 | depth=16,
94 | depth_single_blocks=32,
95 | axes_dim=[16, 36, 36],
96 | theta=10_000,
97 | qkv_bias=True,
98 | guidance_embed=True,
99 | )
100 |
101 |
102 | model = Flux(params)
103 | return model
--------------------------------------------------------------------------------
/audioldm2/utilities/audio/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import librosa.util as librosa_util
4 | from scipy.signal import get_window
5 |
6 |
7 | def window_sumsquare(
8 | window,
9 | n_frames,
10 | hop_length,
11 | win_length,
12 | n_fft,
13 | dtype=np.float32,
14 | norm=None,
15 | ):
16 | """
17 | # from librosa 0.6
18 | Compute the sum-square envelope of a window function at a given hop length.
19 |
20 | This is used to estimate modulation effects induced by windowing
21 | observations in short-time fourier transforms.
22 |
23 | Parameters
24 | ----------
25 | window : string, tuple, number, callable, or list-like
26 | Window specification, as in `get_window`
27 |
28 | n_frames : int > 0
29 | The number of analysis frames
30 |
31 | hop_length : int > 0
32 | The number of samples to advance between frames
33 |
34 | win_length : [optional]
35 | The length of the window function. By default, this matches `n_fft`.
36 |
37 | n_fft : int > 0
38 | The length of each analysis frame.
39 |
40 | dtype : np.dtype
41 | The data type of the output
42 |
43 | Returns
44 | -------
45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46 | The sum-squared envelope of the window function
47 | """
48 | if win_length is None:
49 | win_length = n_fft
50 |
51 | n = n_fft + hop_length * (n_frames - 1)
52 | x = np.zeros(n, dtype=dtype)
53 |
54 | # Compute the squared window at the desired length
55 | win_sq = get_window(window, win_length, fftbins=True)
56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57 | win_sq = librosa_util.pad_center(win_sq, n_fft)
58 |
59 | # Fill the envelope
60 | for i in range(n_frames):
61 | sample = i * hop_length
62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63 | return x
64 |
65 |
66 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
67 | """
68 | PARAMS
69 | ------
70 | magnitudes: spectrogram magnitudes
71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72 | """
73 |
74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75 | angles = angles.astype(np.float32)
76 | angles = torch.autograd.Variable(torch.from_numpy(angles))
77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78 |
79 | for i in range(n_iters):
80 | _, angles = stft_fn.transform(signal)
81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82 | return signal
83 |
84 |
85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86 | """
87 | PARAMS
88 | ------
89 | C: compression factor
90 | """
91 | return normalize_fun(torch.clamp(x, min=clip_val) * C)
92 |
93 |
94 | def dynamic_range_decompression(x, C=1):
95 | """
96 | PARAMS
97 | ------
98 | C: compression factor used to compress
99 | """
100 | return torch.exp(x) / C
101 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError("Decay must be between 0 and 1")
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | "num_updates",
15 | torch.tensor(0, dtype=torch.int)
16 | if use_num_upates
17 | else torch.tensor(-1, dtype=torch.int),
18 | )
19 |
20 | for name, p in model.named_parameters():
21 | if p.requires_grad:
22 |                 # remove '.' since it is not allowed in buffer names
23 | s_name = name.replace(".", "")
24 | self.m_name2s_name.update({name: s_name})
25 | self.register_buffer(s_name, p.clone().detach().data)
26 |
27 | self.collected_params = []
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(
47 | one_minus_decay * (shadow_params[sname] - m_param[key])
48 | )
49 | else:
50 | assert not key in self.m_name2s_name
51 |
52 | def copy_to(self, model):
53 | m_param = dict(model.named_parameters())
54 | shadow_params = dict(self.named_buffers())
55 | for key in m_param:
56 | if m_param[key].requires_grad:
57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
58 | else:
59 | assert not key in self.m_name2s_name
60 |
61 | def store(self, parameters):
62 | """
63 | Save the current parameters for restoring later.
64 | Args:
65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66 | temporarily stored.
67 | """
68 | self.collected_params = [param.clone() for param in parameters]
69 |
70 | def restore(self, parameters):
71 | """
72 | Restore the parameters stored with the `store` method.
73 | Useful to validate the model with EMA parameters without affecting the
74 | original optimization process. Store the parameters before the
75 | `copy_to` method. After validation (or model saving), use this to
76 | restore the former parameters.
77 | Args:
78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
79 | updated with the stored parameters.
80 | """
81 | for c_param, param in zip(self.collected_params, parameters):
82 | param.data.copy_(c_param.data)
83 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 | import re
16 | from unidecode import unidecode
17 | from phonemizer import phonemize
18 |
19 |
20 | # Regular expression matching whitespace:
21 | _whitespace_re = re.compile(r'\s+')
22 |
23 | # List of (regular expression, replacement) pairs for abbreviations:
24 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
25 | ('mrs', 'misess'),
26 | ('mr', 'mister'),
27 | ('dr', 'doctor'),
28 | ('st', 'saint'),
29 | ('co', 'company'),
30 | ('jr', 'junior'),
31 | ('maj', 'major'),
32 | ('gen', 'general'),
33 | ('drs', 'doctors'),
34 | ('rev', 'reverend'),
35 | ('lt', 'lieutenant'),
36 | ('hon', 'honorable'),
37 | ('sgt', 'sergeant'),
38 | ('capt', 'captain'),
39 | ('esq', 'esquire'),
40 | ('ltd', 'limited'),
41 | ('col', 'colonel'),
42 | ('ft', 'fort'),
43 | ]]
44 |
45 |
46 | def expand_abbreviations(text):
47 | for regex, replacement in _abbreviations:
48 | text = re.sub(regex, replacement, text)
49 | return text
50 |
51 |
52 | def expand_numbers(text):
53 |     return normalize_numbers(text)  # NOTE: normalize_numbers is not bundled in this package; this helper is unused by the cleaners below
54 |
55 |
56 | def lowercase(text):
57 | return text.lower()
58 |
59 |
60 | def collapse_whitespace(text):
61 | return re.sub(_whitespace_re, ' ', text)
62 |
63 |
64 | def convert_to_ascii(text):
65 | return unidecode(text)
66 |
67 |
68 | def basic_cleaners(text):
69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
70 | text = lowercase(text)
71 | text = collapse_whitespace(text)
72 | return text
73 |
74 |
75 | def transliteration_cleaners(text):
76 | '''Pipeline for non-English text that transliterates to ASCII.'''
77 | text = convert_to_ascii(text)
78 | text = lowercase(text)
79 | text = collapse_whitespace(text)
80 | return text
81 |
82 |
83 | def english_cleaners(text):
84 | '''Pipeline for English text, including abbreviation expansion.'''
85 | text = convert_to_ascii(text)
86 | text = lowercase(text)
87 | text = expand_abbreviations(text)
88 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True)
89 | phonemes = collapse_whitespace(phonemes)
90 | return phonemes
91 |
92 |
93 | def english_cleaners2(text):
94 | '''Pipeline for English text, including abbreviation expansion. + punctuation + stress'''
95 | text = convert_to_ascii(text)
96 | text = lowercase(text)
97 | text = expand_abbreviations(text)
98 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True, preserve_punctuation=True, with_stress=True)
99 | phonemes = collapse_whitespace(phonemes)
100 | return phonemes
101 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(
34 | device=self.parameters.device
35 | )
36 |
37 | def sample(self):
38 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
39 | device=self.parameters.device
40 | )
41 | return x
42 |
43 | def kl(self, other=None):
44 | if self.deterministic:
45 | return torch.Tensor([0.0])
46 | else:
47 | if other is None:
48 | return 0.5 * torch.mean(
49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50 | dim=[1, 2, 3],
51 | )
52 | else:
53 | return 0.5 * torch.mean(
54 | torch.pow(self.mean - other.mean, 2) / other.var
55 | + self.var / other.var
56 | - 1.0
57 | - self.logvar
58 | + other.logvar,
59 | dim=[1, 2, 3],
60 | )
61 |
62 | def nll(self, sample, dims=[1, 2, 3]):
63 | if self.deterministic:
64 | return torch.Tensor([0.0])
65 | logtwopi = np.log(2.0 * np.pi)
66 | return 0.5 * torch.sum(
67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68 | dim=dims,
69 | )
70 |
71 | def mode(self):
72 | return self.mean
73 |
74 |
75 | def normal_kl(mean1, logvar1, mean2, logvar2):
76 | """
77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78 | Compute the KL divergence between two gaussians.
79 | Shapes are automatically broadcasted, so batches can be compared to
80 | scalars, among other use cases.
81 | """
82 | tensor = None
83 | for obj in (mean1, logvar1, mean2, logvar2):
84 | if isinstance(obj, torch.Tensor):
85 | tensor = obj
86 | break
87 | assert tensor is not None, "at least one argument must be a Tensor"
88 |
89 | # Force variances to be Tensors. Broadcasting helps convert scalars to
90 | # Tensors, but it does not work for torch.exp().
91 | logvar1, logvar2 = [
92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93 | for x in (logvar1, logvar2)
94 | ]
95 |
96 | return 0.5 * (
97 | -1.0
98 | + logvar2
99 | - logvar1
100 | + torch.exp(logvar1 - logvar2)
101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102 | )
103 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | from torch import Tensor, nn
5 |
6 | from modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
7 | MLPEmbedder, SingleStreamBlock,
8 | timestep_embedding)
9 |
10 |
11 | @dataclass
12 | class FluxParams:
13 | in_channels: int
14 | vec_in_dim: int
15 | context_in_dim: int
16 | hidden_size: int
17 | mlp_ratio: float
18 | num_heads: int
19 | depth: int
20 | depth_single_blocks: int
21 | axes_dim: list[int]
22 | theta: int
23 | qkv_bias: bool
24 | guidance_embed: bool
25 |
26 |
27 | class Flux(nn.Module):
28 | """
29 | Transformer model for flow matching on sequences.
30 | """
31 |
32 | def __init__(self, params: FluxParams):
33 | super().__init__()
34 |
35 | self.params = params
36 | self.in_channels = params.in_channels
37 | self.out_channels = self.in_channels
38 | if params.hidden_size % params.num_heads != 0:
39 | raise ValueError(
40 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
41 | )
42 | pe_dim = params.hidden_size // params.num_heads
43 | if sum(params.axes_dim) != pe_dim:
44 | raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
45 | self.hidden_size = params.hidden_size
46 | self.num_heads = params.num_heads
47 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
48 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
49 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
50 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
51 | # self.guidance_in = (
52 | # MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
53 | # )
54 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
55 |
56 | self.double_blocks = nn.ModuleList(
57 | [
58 | DoubleStreamBlock(
59 | self.hidden_size,
60 | self.num_heads,
61 | mlp_ratio=params.mlp_ratio,
62 | qkv_bias=params.qkv_bias,
63 | )
64 | for _ in range(params.depth)
65 | ]
66 | )
67 |
68 | self.single_blocks = nn.ModuleList(
69 | [
70 | SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
71 | for _ in range(params.depth_single_blocks)
72 | ]
73 | )
74 |
75 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
76 |
77 | def forward(
78 | self,
79 | x: Tensor,
80 | img_ids: Tensor,
81 | txt: Tensor,
82 | txt_ids: Tensor,
83 | t: Tensor,
84 | y: Tensor,
85 | guidance: Tensor | None = None,
86 | ) -> Tensor:
87 | if x.ndim != 3 or txt.ndim != 3:
88 | raise ValueError("Input img and txt tensors must have 3 dimensions.")
89 |
90 | # running on sequences img
91 | img = self.img_in(x)
92 | vec = self.time_in(timestep_embedding(t, 256))
93 | # if self.params.guidance_embed:
94 | # if guidance is None:
95 | # raise ValueError("Didn't get guidance strength for guidance distilled model.")
96 | # vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
97 | vec = vec + self.vector_in(y)
98 | txt = self.txt_in(txt)
99 |
100 | ids = torch.cat((txt_ids, img_ids), dim=1)
101 | pe = self.pe_embedder(ids)
102 |
103 | for block in self.double_blocks:
104 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
105 |
106 | img = torch.cat((txt, img), 1)
107 | for block in self.single_blocks:
108 | img = block(img, vec=vec, pe=pe)
109 | img = img[:, txt.shape[1] :, ...]
110 |
111 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
112 | return img
113 |
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import math
5 | from einops import rearrange, repeat
6 | from PIL import Image
7 | from diffusers import AutoencoderKL
8 | from transformers import SpeechT5HifiGan
9 |
10 | from utils import load_t5, load_clap, load_ae
11 | from train import RF
12 | from constants import build_model
13 |
14 |
15 | def prepare(t5, clip, img, prompt):
16 | bs, c, h, w = img.shape
17 | if bs == 1 and not isinstance(prompt, str):
18 | bs = len(prompt)
19 |
20 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
21 | if img.shape[0] == 1 and bs > 1:
22 | img = repeat(img, "1 ... -> bs ...", bs=bs)
23 |
24 | img_ids = torch.zeros(h // 2, w // 2, 3)
25 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
26 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
27 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
28 |
29 | if isinstance(prompt, str):
30 | prompt = [prompt]
31 | txt = t5(prompt)
32 | if txt.shape[0] == 1 and bs > 1:
33 | txt = repeat(txt, "1 ... -> bs ...", bs=bs)
34 | txt_ids = torch.zeros(bs, txt.shape[1], 3)
35 |
36 | vec = clip(prompt)
37 | if vec.shape[0] == 1 and bs > 1:
38 | vec = repeat(vec, "1 ... -> bs ...", bs=bs)
39 |
40 | print(img_ids.size(), txt.size(), vec.size())
41 | return img, {
42 | "img_ids": img_ids.to(img.device),
43 | "txt": txt.to(img.device),
44 | "txt_ids": txt_ids.to(img.device),
45 | "y": vec.to(img.device),
46 | }
47 |
48 | def main(args):
49 | print('generate with MusicFlux')
50 | torch.manual_seed(args.seed)
51 | torch.set_grad_enabled(False)
52 | device = "cuda" if torch.cuda.is_available() else "cpu"
53 |
54 | latent_size = (256, 16)
55 |
56 | model = build_model(args.version).to(device)
57 | local_path = args.ckpt_path
58 | state_dict = torch.load(local_path, map_location=lambda storage, loc: storage)
59 | model.load_state_dict(state_dict['ema'])
60 | model.eval() # important!
61 | diffusion = RF()
62 |
63 | # Setup VAE
64 | t5 = load_t5(device, max_length=256)
65 | clap = load_clap(device, max_length=256)
66 |
67 | vae = AutoencoderKL.from_pretrained(os.path.join(args.audioldm2_model_path, 'vae')).to(device)
68 | vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(args.audioldm2_model_path, 'vocoder')).to(device)
69 |
70 | with open(args.prompt_file, 'r') as f:
71 | conds_txt = f.readlines()
72 | L = len(conds_txt)
73 | unconds_txt = ["low quality, gentle"] * L
74 | print(L, conds_txt, unconds_txt)
75 |
76 | init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).cuda()
77 |
78 | STEPSIZE = 50
79 | img, conds = prepare(t5, clap, init_noise, conds_txt)
80 | _, unconds = prepare(t5, clap, init_noise, unconds_txt)
81 | with torch.autocast(device_type='cuda'):
82 | images = diffusion.sample_with_xps(model, img, conds=conds, null_cond=unconds, sample_steps = STEPSIZE, cfg = 7.0)
83 |
84 | print(images[-1].size(), )
85 |
86 | images = rearrange(
87 | images[-1],
88 | "b (h w) (c ph pw) -> b c (h ph) (w pw)",
89 | h=128,
90 | w=8,
91 | ph=2,
92 | pw=2,)
93 | # print(images.size())
94 | latents = 1 / vae.config.scaling_factor * images
95 | mel_spectrogram = vae.decode(latents).sample
96 | print(mel_spectrogram.size())
97 |
98 | for i in range(L):
99 | x_i = mel_spectrogram[i]
100 | if x_i.dim() == 4:
101 | x_i = x_i.squeeze(1)
102 | waveform = vocoder(x_i)
103 | waveform = waveform[0].cpu().float().detach().numpy()
104 | print(waveform.shape)
105 | # import soundfile as sf
106 | # sf.write('reconstruct.wav', waveform, samplerate=16000)
107 | from scipy.io import wavfile
108 | wavfile.write('wav/sample_' + str(i) + '.wav', 16000, waveform)
109 |
110 |
111 | if __name__ == "__main__":
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument("--version", type=str, default="small")
114 | parser.add_argument("--prompt_file", type=str, default='config/example.txt')
115 | parser.add_argument("--ckpt_path", type=str, default='musicflow_s.pt')
116 | parser.add_argument("--audioldm2_model_path", type=str, default='/maindata/data/shared/multimodal/public/dataset_music/audioldm2' )
117 | parser.add_argument("--seed", type=int, default=2024)
118 | args = parser.parse_args()
119 | main(args)
120 |
121 |
122 |
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch.nn as nn
8 |
9 | try:
10 | import timm
11 | from timm.models.layers import Mlp, to_2tuple
12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
13 | from timm.models.layers.attention_pool2d import (
14 | AttentionPool2d as AbsAttentionPool2d,
15 | )
16 | except ImportError:
17 | timm = None
18 |
19 | from .utils import freeze_batch_norm_2d
20 |
21 |
22 | class TimmModel(nn.Module):
23 | """timm model adapter
24 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
25 | """
26 |
27 | def __init__(
28 | self,
29 | model_name,
30 | embed_dim,
31 | image_size=224,
32 | pool="avg",
33 | proj="linear",
34 | drop=0.0,
35 | pretrained=False,
36 | ):
37 | super().__init__()
38 | if timm is None:
39 | raise RuntimeError("Please `pip install timm` to use timm models.")
40 |
41 | self.image_size = to_2tuple(image_size)
42 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
43 | feat_size = self.trunk.default_cfg.get("pool_size", None)
44 | feature_ndim = 1 if not feat_size else 2
45 | if pool in ("abs_attn", "rot_attn"):
46 | assert feature_ndim == 2
47 | # if attn pooling used, remove both classifier and default pool
48 | self.trunk.reset_classifier(0, global_pool="")
49 | else:
50 | # reset global pool if pool config set, otherwise leave as network default
51 | reset_kwargs = dict(global_pool=pool) if pool else {}
52 | self.trunk.reset_classifier(0, **reset_kwargs)
53 | prev_chs = self.trunk.num_features
54 |
55 | head_layers = OrderedDict()
56 | if pool == "abs_attn":
57 | head_layers["pool"] = AbsAttentionPool2d(
58 | prev_chs, feat_size=feat_size, out_features=embed_dim
59 | )
60 | prev_chs = embed_dim
61 | elif pool == "rot_attn":
62 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
63 | prev_chs = embed_dim
64 | else:
65 | assert proj, "projection layer needed if non-attention pooling is used."
66 |
67 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
68 | if proj == "linear":
69 | head_layers["drop"] = nn.Dropout(drop)
70 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
71 | elif proj == "mlp":
72 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
73 |
74 | self.head = nn.Sequential(head_layers)
75 |
76 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
77 | """lock modules
78 | Args:
79 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
80 | """
81 | if not unlocked_groups:
82 | # lock full model
83 | for param in self.trunk.parameters():
84 | param.requires_grad = False
85 | if freeze_bn_stats:
86 | freeze_batch_norm_2d(self.trunk)
87 | else:
88 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
89 | try:
90 | # FIXME import here until API stable and in an official release
91 | from timm.models.helpers import group_parameters, group_modules
92 | except ImportError:
93 | raise RuntimeError(
94 | "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
95 | )
96 | matcher = self.trunk.group_matcher()
97 | gparams = group_parameters(self.trunk, matcher)
98 | max_layer_id = max(gparams.keys())
99 | max_layer_id = max_layer_id - unlocked_groups
100 | for group_idx in range(max_layer_id + 1):
101 | group = gparams[group_idx]
102 | for param in group:
103 | self.trunk.get_parameter(param).requires_grad = False
104 | if freeze_bn_stats:
105 | gmodules = group_modules(self.trunk, matcher, reverse=True)
106 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
107 | freeze_batch_norm_2d(self.trunk, gmodules)
108 |
109 | def forward(self, x):
110 | x = self.trunk(x)
111 | x = self.head(x)
112 | return x
113 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import to_2tuple
4 |
5 |
6 | class PatchEmbed_org(nn.Module):
7 | """Image to Patch Embedding"""
8 |
9 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
10 | super().__init__()
11 | img_size = to_2tuple(img_size)
12 | patch_size = to_2tuple(patch_size)
13 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
14 | self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
15 | self.img_size = img_size
16 | self.patch_size = patch_size
17 | self.num_patches = num_patches
18 |
19 | self.proj = nn.Conv2d(
20 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
21 | )
22 |
23 | def forward(self, x):
24 | B, C, H, W = x.shape
25 | # FIXME look at relaxing size constraints
26 | # assert H == self.img_size[0] and W == self.img_size[1], \
27 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
28 | x = self.proj(x)
29 | y = x.flatten(2).transpose(1, 2)
30 | return y
31 |
32 |
33 | class PatchEmbed_new(nn.Module):
34 | """Flexible Image to Patch Embedding"""
35 |
36 | def __init__(
37 | self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
38 | ):
39 | super().__init__()
40 | img_size = to_2tuple(img_size)
41 | patch_size = to_2tuple(patch_size)
42 | stride = to_2tuple(stride)
43 |
44 | self.img_size = img_size
45 | self.patch_size = patch_size
46 |
47 | self.proj = nn.Conv2d(
48 | in_chans, embed_dim, kernel_size=patch_size, stride=stride
49 | ) # with overlapped patches
50 | # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
51 |
52 | # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
53 | # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
54 | _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
55 | self.patch_hw = (h, w)
56 | self.num_patches = h * w
57 |
58 | def get_output_shape(self, img_size):
59 | # todo: don't be lazy..
60 | return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
61 |
62 | def forward(self, x):
63 | B, C, H, W = x.shape
64 | # FIXME look at relaxing size constraints
65 | # assert H == self.img_size[0] and W == self.img_size[1], \
66 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
67 | # x = self.proj(x).flatten(2).transpose(1, 2)
68 | x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12
69 | x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212
70 | x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768
71 | return x
72 |
73 |
74 | class PatchEmbed3D_new(nn.Module):
75 | """Flexible Image to Patch Embedding"""
76 |
77 | def __init__(
78 | self,
79 | video_size=(16, 224, 224),
80 | patch_size=(2, 16, 16),
81 | in_chans=3,
82 | embed_dim=768,
83 | stride=(2, 16, 16),
84 | ):
85 | super().__init__()
86 |
87 | self.video_size = video_size
88 | self.patch_size = patch_size
89 | self.in_chans = in_chans
90 |
91 | self.proj = nn.Conv3d(
92 | in_chans, embed_dim, kernel_size=patch_size, stride=stride
93 | )
94 | _, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w
95 | self.patch_thw = (t, h, w)
96 | self.num_patches = t * h * w
97 |
98 | def get_output_shape(self, video_size):
99 | # todo: don't be lazy..
100 | return self.proj(
101 | torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2])
102 | ).shape
103 |
104 | def forward(self, x):
105 | B, C, T, H, W = x.shape
106 | x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14
107 | x = x.flatten(2) # 32, 768, 1568
108 | x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768
109 | return x
110 |
111 |
112 | if __name__ == "__main__":
113 | # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16))
114 | # input = torch.rand(8,1,1024,128)
115 | # output = patch_emb(input)
116 | # print(output.shape) # (8,512,64)
117 |
118 | patch_emb = PatchEmbed3D_new(
119 | video_size=(6, 224, 224),
120 | patch_size=(2, 16, 16),
121 | in_chans=3,
122 | embed_dim=768,
123 | stride=(2, 16, 16),
124 | )
125 | input = torch.rand(8, 3, 6, 224, 224)
126 | output = patch_emb(input)
127 |     print(output.shape)  # (8, 588, 768): 3*14*14 = 588 patches of dim 768
128 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## FluxMusic: Text-to-Music Generation with Rectified Flow Transformer
Official PyTorch Implementation
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for the paper *Flux that plays music*.
11 | It explores a simple extension of diffusion-based rectified flow Transformers for text-to-music generation. The model architecture can be seen as follows:
12 |
13 |
14 |
15 |
16 | ### To-do list
17 |
18 | - [x] training / inference scripts
19 | - [x] clean code
20 | - [x] all ckpts and part of dataset
21 |
22 |
23 | ### 1. Training
24 |
25 | You can refer to this [link](https://github.com/black-forest-labs/flux) to set up the running environment.
26 |
27 | To launch latent-space training of the small version with `N` GPUs on one node using PyTorch DDP:
28 | ```bash
29 | torchrun --nnodes=1 --nproc_per_node=N train.py \
30 | --version small \
31 | --data-path xxx \
32 | --global_batch_size 128
33 | ```
34 |
35 | Training scripts for the other model sizes can be found in the `scripts` directory.
36 |
37 |
38 | ### 2. Inference
39 |
40 | We include a [`sample.py`](sample.py) script that samples music clips from a MusicFlux model according to text conditions:
41 | ```bash
42 | python sample.py \
43 | --version small \
44 | --ckpt_path /path/to/model \
45 | --prompt_file config/example.txt
46 | ```
47 |
48 | All prompts used in the paper are listed in `config/example.txt`.
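
`sample.py` reads the prompt file with `readlines()`, so each line is treated as one text prompt and all prompts in the file are generated in a single batch. Below is a minimal sketch of preparing a custom prompt file; the prompts and the `my_prompts.txt` file name are placeholders, not files shipped with the repo.

```python
# Hypothetical prompt file: one text description per line.
prompts = [
    "An upbeat electronic track with a driving bass line",
    "A calm solo piano piece in a minor key",
]
with open("my_prompts.txt", "w") as f:
    f.write("\n".join(prompts))
# Then sample with:
#   python sample.py --version small --ckpt_path /path/to/model --prompt_file my_prompts.txt
```

Note that `sample.py` writes its outputs to `wav/sample_<i>.wav` at 16 kHz, so make sure a `wav/` directory exists before sampling.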
49 |
50 |
51 | ### 3. Download Ckpts and Data
52 |
53 | We use the VAE and vocoder from AudioLDM2, together with the CLAP-L and T5-XXL text encoders. You can download them directly from the table below; we also provide the training scripts used in our experiments.
54 |
55 | Note that in the actual experiments a run was restarted due to a machine malfunction, so some scripts contain resume options.
56 |
57 |
58 | | Model | Training steps | URL | Training scripts |
59 | |-------|--------|------------------|---------|
60 | | VAE | -| [link](https://huggingface.co/cvssp/audioldm2/tree/main/vae) | - |
61 | | Vocoder |-| [link](https://huggingface.co/cvssp/audioldm2/tree/main/vocoder) | - |
62 | | T5-XXL | - | [link](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/tree/main/text_encoder_3) | - |
63 | | CLAP-L | -| [link](https://huggingface.co/laion/larger_clap_music/tree/main) | - |
64 | | FluxMusic-Small | 200K | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_s.pt) | [link](https://github.com/feizc/FluxMusic/blob/main/scripts/train_s.sh) |
65 | | FluxMusic-Base | 200K | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_b.pt) | [link](https://github.com/feizc/FluxMusic/blob/main/scripts/train_b.sh) |
66 | | FluxMusic-Large | 200K | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_l.pt) | [link](https://github.com/feizc/FluxMusic/blob/main/scripts/train_l.sh) |
67 | | FluxMusic-Giant | 200K | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_g.pt) | [link](https://github.com/feizc/FluxMusic/blob/main/scripts/train_g.sh) |
68 | | FluxMusic-Giant-Full | 2M | [link](https://huggingface.co/feizhengcong/FluxMusic/blob/main/musicflow_g_full.pt) | - |
69 |
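For orientation, the sketch below mirrors how [`sample.py`](sample.py) wires these downloads together at inference time. The checkpoint and folder paths are placeholders and assume the AudioLDM2 `vae/` and `vocoder/` subfolders from the table have been saved under one local directory (the directory passed via `--audioldm2_model_path`).

```python
import torch
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan
from constants import build_model

# Flow-matching backbone; "small" corresponds to musicflow_s.pt.
model = build_model("small")
state = torch.load("musicflow_s.pt", map_location="cpu")
model.load_state_dict(state["ema"])  # checkpoints store EMA weights
model.eval()

# AudioLDM2 latent VAE and vocoder, loaded from <audioldm2_model_path>/{vae,vocoder}.
vae = AutoencoderKL.from_pretrained("audioldm2/vae")
vocoder = SpeechT5HifiGan.from_pretrained("audioldm2/vocoder")
```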
70 |
71 | Note that the 200K-step checkpoints were trained on a training subset and were used for the scaling plots and the case studies in the paper.
72 | The full version of the main results will be released right away.
73 |
74 | For constructing the training data, refer to `test.py`, which shows a simple way of combining different datasets into a JSON metadata file; a hedged sketch is given below.
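
As a rough illustration only: the file names and the `{"wav": ..., "caption": ...}` schema below are assumptions made for this sketch; the exact fields expected by the training code are defined in `test.py`.

```python
import json

# Placeholder per-dataset metadata files; each is assumed to be a JSON list of
# {"wav": <path>, "caption": <text>} entries (check test.py for the real schema).
dataset_files = ["audiocaps.json", "musiccaps.json"]

combined = []
for path in dataset_files:
    with open(path) as f:
        combined.extend(json.load(f))

with open("train_metadata.json", "w") as f:
    json.dump(combined, f, indent=2)
```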
75 |
76 | Due to copyright issues, the data used in the paper needs to be downloaded by users themselves.
77 |
78 | We provide a clean subset in:
79 |
80 | Quick download links for other datasets can be found on [Huggingface](https://huggingface.co/datasets?search=music) :).
81 |
82 | This is a research project; for production use, we recommend trying more advanced products:
83 |
84 |
85 |
86 | ### Acknowledgments
87 |
88 | The codebase is based on the awesome [Flux](https://github.com/black-forest-labs/flux) and [AudioLDM2](https://github.com/haoheliu/AudioLDM2) repos.
89 |
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/audioldm2/__main__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 | import logging
5 | from audioldm2 import text_to_audio, build_model, save_wave, get_time, read_list
6 | import argparse
7 |
8 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
9 | matplotlib_logger = logging.getLogger('matplotlib')
10 | matplotlib_logger.setLevel(logging.WARNING)
11 |
12 | parser = argparse.ArgumentParser()
13 |
14 | parser.add_argument(
15 | "-t",
16 | "--text",
17 | type=str,
18 | required=False,
19 | default="",
20 | help="Text prompt to the model for audio generation",
21 | )
22 |
23 | parser.add_argument(
24 | "--transcription",
25 | type=str,
26 | required=False,
27 | default="",
28 | help="Transcription for Text-to-Speech",
29 | )
30 |
31 | parser.add_argument(
32 | "-tl",
33 | "--text_list",
34 | type=str,
35 | required=False,
36 | default="",
37 | help="A file that contains text prompt to the model for audio generation",
38 | )
39 |
40 | parser.add_argument(
41 | "-s",
42 | "--save_path",
43 | type=str,
44 | required=False,
45 | help="The path to save model output",
46 | default="./output",
47 | )
48 |
49 | parser.add_argument(
50 | "--model_name",
51 | type=str,
52 | required=False,
53 |     help="The checkpoint you are going to use",
54 | default="audioldm_48k",
55 | choices=["audioldm_48k", "audioldm_16k_crossattn_t5", "audioldm2-full", "audioldm2-music-665k",
56 | "audioldm2-full-large-1150k", "audioldm2-speech-ljspeech", "audioldm2-speech-gigaspeech"]
57 | )
58 |
59 | parser.add_argument(
60 | "-d",
61 | "--device",
62 | type=str,
63 | required=False,
64 | help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
65 | default="auto",
66 | )
67 |
68 | parser.add_argument(
69 | "-b",
70 | "--batchsize",
71 | type=int,
72 | required=False,
73 | default=1,
74 | help="Generate how many samples at the same time",
75 | )
76 |
77 | parser.add_argument(
78 | "--ddim_steps",
79 | type=int,
80 | required=False,
81 | default=200,
82 | help="The sampling step for DDIM",
83 | )
84 |
85 | parser.add_argument(
86 | "-gs",
87 | "--guidance_scale",
88 | type=float,
89 | required=False,
90 | default=3.5,
91 |     help="Guidance scale (Large => better quality and relevance to text; Small => better diversity)",
92 | )
93 |
94 | parser.add_argument(
95 | "-dur",
96 | "--duration",
97 | type=float,
98 | required=False,
99 | default=10.0,
100 | help="The duration of the samples",
101 | )
102 |
103 | parser.add_argument(
104 | "-n",
105 | "--n_candidate_gen_per_text",
106 | type=int,
107 | required=False,
108 | default=3,
109 |     help="Automatic quality control. This number controls the number of candidates (e.g., generate three audios and choose the best one to show you). A larger value usually leads to better quality with heavier computation",
110 | )
111 |
112 | parser.add_argument(
113 | "--seed",
114 | type=int,
115 | required=False,
116 | default=0,
117 |     help="Changing this value (any integer number) will lead to a different generation result.",
118 | )
119 |
120 | args = parser.parse_args()
121 |
122 | torch.set_float32_matmul_precision("high")
123 |
124 | save_path = os.path.join(args.save_path, get_time())
125 |
126 | text = args.text
127 | random_seed = args.seed
128 | duration = args.duration
129 | sample_rate = 16000
130 |
131 | if ("audioldm2" in args.model_name):
132 | print(
133 | "Warning: For AudioLDM2 we currently only support 10s of generation. Please use audioldm_48k or audioldm_16k_crossattn_t5 if you want a different duration.")
134 | duration = 10
135 | if ("48k" in args.model_name):
136 | sample_rate = 48000
137 |
138 | guidance_scale = args.guidance_scale
139 | n_candidate_gen_per_text = args.n_candidate_gen_per_text
140 | transcription = args.transcription
141 |
142 | if (transcription):
143 | if "speech" not in args.model_name:
144 | print(
145 |             "Warning: You chose to perform Text-to-Speech by providing a transcription. However, you did not choose the correct model name (audioldm2-speech-gigaspeech or audioldm2-speech-ljspeech).")
146 | print("Warning: We will use audioldm2-speech-gigaspeech by default")
147 | args.model_name = "audioldm2-speech-gigaspeech"
148 | if (not text):
149 | print(
150 |             "Warning: You should provide text as an input to describe the speaker. Using the default (A female reporter is speaking full of emotion).")
151 | text = "A female reporter is speaking full of emotion"
152 |
153 | os.makedirs(save_path, exist_ok=True)
154 | audioldm2 = build_model(model_name=args.model_name, device=args.device)
155 |
156 | if (args.text_list):
157 | print("Generate audio based on the text prompts in %s" % args.text_list)
158 | prompt_todo = read_list(args.text_list)
159 | else:
160 | prompt_todo = [text]
161 |
162 | for text in prompt_todo:
163 | if ("|" in text):
164 | text, name = text.split("|")
165 | else:
166 | name = text[:128]
167 |
168 | if (transcription):
169 | name += "-TTS-%s" % transcription
170 |
171 | waveform = text_to_audio(
172 | audioldm2,
173 | text,
174 | transcription=transcription, # To avoid the model to ignore the last vocab
175 | seed=random_seed,
176 | duration=duration,
177 | guidance_scale=guidance_scale,
178 | ddim_steps=args.ddim_steps,
179 | n_candidate_gen_per_text=n_candidate_gen_per_text,
180 | batchsize=args.batchsize,
181 | )
182 |
183 | save_wave(waveform, save_path, name=name, samplerate=sample_rate)
184 |
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
6 | def init_weights(m, mean=0.0, std=0.01):
7 | classname = m.__class__.__name__
8 | if classname.find("Conv") != -1:
9 | m.weight.data.normal_(mean, std)
10 |
11 |
12 | def get_padding(kernel_size, dilation=1):
13 | return int((kernel_size * dilation - dilation) / 2)
14 |
15 |
16 | def convert_pad_shape(pad_shape):
17 | l = pad_shape[::-1]
18 | pad_shape = [item for sublist in l for item in sublist]
19 | return pad_shape
20 |
21 |
22 | def intersperse(lst, item):
23 | result = [item] * (len(lst) * 2 + 1)
24 | result[1::2] = lst
25 | return result
26 |
27 |
28 | def kl_divergence(m_p, logs_p, m_q, logs_q):
29 | """KL(P||Q)"""
30 | kl = (logs_q - logs_p) - 0.5
31 | kl += (
32 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33 | )
34 | return kl
35 |
36 |
37 | def rand_gumbel(shape):
38 | """Sample from the Gumbel distribution, protect from overflows."""
39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40 | return -torch.log(-torch.log(uniform_samples))
41 |
42 |
43 | def rand_gumbel_like(x):
44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45 | return g
46 |
47 |
48 | def slice_segments(x, ids_str, segment_size=4):
49 | ret = torch.zeros_like(x[:, :, :segment_size])
50 | for i in range(x.size(0)):
51 | idx_str = ids_str[i]
52 | idx_end = idx_str + segment_size
53 | ret[i] = x[i, :, idx_str:idx_end]
54 | return ret
55 |
56 |
57 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
58 | b, d, t = x.size()
59 | if x_lengths is None:
60 | x_lengths = t
61 | ids_str_max = x_lengths - segment_size + 1
62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63 | ret = slice_segments(x, ids_str, segment_size)
64 | return ret, ids_str
65 |
66 |
67 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
68 | position = torch.arange(length, dtype=torch.float)
69 | num_timescales = channels // 2
70 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
71 | num_timescales - 1
72 | )
73 | inv_timescales = min_timescale * torch.exp(
74 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
75 | )
76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78 | signal = F.pad(signal, [0, 0, 0, channels % 2])
79 | signal = signal.view(1, channels, length)
80 | return signal
81 |
82 |
83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84 | b, channels, length = x.size()
85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86 | return x + signal.to(dtype=x.dtype, device=x.device)
87 |
88 |
89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90 | b, channels, length = x.size()
91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93 |
94 |
95 | def subsequent_mask(length):
96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97 | return mask
98 |
99 |
100 | @torch.jit.script
101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102 | n_channels_int = n_channels[0]
103 | in_act = input_a + input_b
104 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106 | acts = t_act * s_act
107 | return acts
108 |
109 |
110 | def convert_pad_shape(pad_shape):
111 | l = pad_shape[::-1]
112 | pad_shape = [item for sublist in l for item in sublist]
113 | return pad_shape
114 |
115 |
116 | def shift_1d(x):
117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118 | return x
119 |
120 |
121 | def sequence_mask(length, max_length=None):
122 | if max_length is None:
123 | max_length = length.max()
124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125 | return x.unsqueeze(0) < length.unsqueeze(1)
126 |
127 |
128 | def generate_path(duration, mask):
129 | """
130 | duration: [b, 1, t_x]
131 | mask: [b, 1, t_y, t_x]
132 | """
133 |
134 |
135 | b, _, t_y, t_x = mask.shape
136 | cum_duration = torch.cumsum(duration, -1)
137 |
138 | cum_duration_flat = cum_duration.view(b * t_x)
139 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
140 | path = path.view(b, t_x, t_y)
141 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
142 | path = path.unsqueeze(1).transpose(2, 3) * mask
143 | return path
144 |
145 |
146 | def clip_grad_value_(parameters, clip_value, norm_type=2):
147 | if isinstance(parameters, torch.Tensor):
148 | parameters = [parameters]
149 | parameters = list(filter(lambda p: p.grad is not None, parameters))
150 | norm_type = float(norm_type)
151 | if clip_value is not None:
152 | clip_value = float(clip_value)
153 |
154 | total_norm = 0
155 | for p in parameters:
156 | param_norm = p.grad.data.norm(norm_type)
157 | total_norm += param_norm.item() ** norm_type
158 | if clip_value is not None:
159 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
160 | total_norm = total_norm ** (1.0 / norm_type)
161 | return total_norm
162 |
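
A minimal usage sketch (not part of the repository): it shows how sequence_mask and generate_path above turn integer phoneme durations into a hard monotonic alignment, with shapes following the generate_path docstring. The duration values are hypothetical.

import torch

from audioldm2.latent_diffusion.modules.phoneme_encoder.commons import generate_path

# three phonemes lasting 2, 1 and 3 frames -> duration: [b=1, 1, t_x=3]
duration = torch.tensor([[[2, 1, 3]]])
t_y = int(duration.sum())                        # total number of output frames (6)
mask = torch.ones(1, 1, t_y, duration.size(-1))  # [b, 1, t_y, t_x]
path = generate_path(duration, mask)             # [1, 1, 6, 3], one-hot frame-to-phoneme alignment
print(path[0, 0])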
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import Union, List
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict
13 | from .pretrained import (
14 | get_pretrained_url,
15 | list_pretrained_tag_models,
16 | download_pretrained,
17 | )
18 |
19 | __all__ = ["list_openai_models", "load_openai_model"]
20 |
21 |
22 | def list_openai_models() -> List[str]:
23 | """Returns the names of available CLIP models"""
24 | return list_pretrained_tag_models("openai")
25 |
26 |
27 | def load_openai_model(
28 | name: str,
29 | model_cfg,
30 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
31 | jit=True,
32 | cache_dir=os.path.expanduser("~/.cache/clip"),
33 | enable_fusion: bool = False,
34 | fusion_type: str = "None",
35 | ):
36 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
37 |
38 | Parameters
39 | ----------
40 | name : str
41 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
42 | device : Union[str, torch.device]
43 | The device to put the loaded model
44 | jit : bool
45 | Whether to load the optimized JIT model (default) or the more hackable non-JIT model.
46 |
47 | Returns
48 | -------
49 | model : torch.nn.Module
50 | The CLAP model
51 | preprocess : Callable[[PIL.Image], torch.Tensor]
52 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
53 | """
54 | if get_pretrained_url(name, "openai"):
55 | model_path = download_pretrained(
56 | get_pretrained_url(name, "openai"), root=cache_dir
57 | )
58 | elif os.path.isfile(name):
59 | model_path = name
60 | else:
61 | raise RuntimeError(
62 | f"Model {name} not found; available models = {list_openai_models()}"
63 | )
64 |
65 | try:
66 | # loading JIT archive
67 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
68 | state_dict = None
69 | except RuntimeError:
70 | # loading saved state dict
71 | if jit:
72 | warnings.warn(
73 | f"File {model_path} is not a JIT archive. Loading as a state dict instead"
74 | )
75 | jit = False
76 | state_dict = torch.load(model_path, map_location="cpu")
77 |
78 | if not jit:
79 | try:
80 | model = build_model_from_openai_state_dict(
81 | state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
82 | ).to(device)
83 | except KeyError:
84 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
85 | model = build_model_from_openai_state_dict(
86 | sd, model_cfg, enable_fusion, fusion_type
87 | ).to(device)
88 |
89 | if str(device) == "cpu":
90 | model.float()
91 | return model
92 |
93 | # patch the device names
94 | device_holder = torch.jit.trace(
95 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
96 | )
97 | device_node = [
98 | n
99 | for n in device_holder.graph.findAllNodes("prim::Constant")
100 | if "Device" in repr(n)
101 | ][-1]
102 |
103 | def patch_device(module):
104 | try:
105 | graphs = [module.graph] if hasattr(module, "graph") else []
106 | except RuntimeError:
107 | graphs = []
108 |
109 | if hasattr(module, "forward1"):
110 | graphs.append(module.forward1.graph)
111 |
112 | for graph in graphs:
113 | for node in graph.findAllNodes("prim::Constant"):
114 | if "value" in node.attributeNames() and str(node["value"]).startswith(
115 | "cuda"
116 | ):
117 | node.copyAttributes(device_node)
118 |
119 | model.apply(patch_device)
120 | patch_device(model.encode_audio)
121 | patch_device(model.encode_text)
122 |
123 | # patch dtype to float32 on CPU
124 | if str(device) == "cpu":
125 | float_holder = torch.jit.trace(
126 | lambda: torch.ones([]).float(), example_inputs=[]
127 | )
128 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
129 | float_node = float_input.node()
130 |
131 | def patch_float(module):
132 | try:
133 | graphs = [module.graph] if hasattr(module, "graph") else []
134 | except RuntimeError:
135 | graphs = []
136 |
137 | if hasattr(module, "forward1"):
138 | graphs.append(module.forward1.graph)
139 |
140 | for graph in graphs:
141 | for node in graph.findAllNodes("aten::to"):
142 | inputs = list(node.inputs())
143 | for i in [
144 | 1,
145 | 2,
146 | ]: # dtype can be the second or third argument to aten::to()
147 | if inputs[i].node()["value"] == 5:
148 | inputs[i].node().copyAttributes(float_node)
149 |
150 | model.apply(patch_float)
151 | patch_float(model.encode_audio)
152 | patch_float(model.encode_text)
153 | model.float()
154 |
155 | model.audio_branch.audio_length = model.audio_cfg.audio_length
156 | return model
157 |
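
A minimal sketch (not part of the repository) of the loader API above. The load_openai_model call is left commented out because it downloads the OpenAI checkpoint and needs a CLAP model_cfg dict, which is assumed here to come from the JSON files under open_clip/model_configs.

from audioldm2.clap.open_clip.openai import list_openai_models

print(list_openai_models())  # model names registered with an "openai" pretrained tag (RN50 ... ViT-L-14)
# model = load_openai_model("ViT-B-16", model_cfg, device="cpu", jit=False)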
--------------------------------------------------------------------------------
/audioldm2/utilities/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import audioldm2.hifigan as hifigan
4 |
5 |
6 | def get_vocoder_config():
7 | return {
8 | "resblock": "1",
9 | "num_gpus": 6,
10 | "batch_size": 16,
11 | "learning_rate": 0.0002,
12 | "adam_b1": 0.8,
13 | "adam_b2": 0.99,
14 | "lr_decay": 0.999,
15 | "seed": 1234,
16 | "upsample_rates": [5, 4, 2, 2, 2],
17 | "upsample_kernel_sizes": [16, 16, 8, 4, 4],
18 | "upsample_initial_channel": 1024,
19 | "resblock_kernel_sizes": [3, 7, 11],
20 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
21 | "segment_size": 8192,
22 | "num_mels": 64,
23 | "num_freq": 1025,
24 | "n_fft": 1024,
25 | "hop_size": 160,
26 | "win_size": 1024,
27 | "sampling_rate": 16000,
28 | "fmin": 0,
29 | "fmax": 8000,
30 | "fmax_for_loss": None,
31 | "num_workers": 4,
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1,
36 | },
37 | }
38 |
39 | def get_vocoder_config_48k():
40 | return {
41 | "resblock": "1",
42 | "num_gpus": 8,
43 | "batch_size": 128,
44 | "learning_rate": 0.0001,
45 | "adam_b1": 0.8,
46 | "adam_b2": 0.99,
47 | "lr_decay": 0.999,
48 | "seed": 1234,
49 |
50 | "upsample_rates": [6,5,4,2,2],
51 | "upsample_kernel_sizes": [12,10,8,4,4],
52 | "upsample_initial_channel": 1536,
53 | "resblock_kernel_sizes": [3,7,11,15],
54 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5], [1,3,5]],
55 |
56 | "segment_size": 15360,
57 | "num_mels": 256,
58 | "n_fft": 2048,
59 | "hop_size": 480,
60 | "win_size": 2048,
61 |
62 | "sampling_rate": 48000,
63 |
64 | "fmin": 20,
65 | "fmax": 24000,
66 | "fmax_for_loss": None,
67 |
68 | "num_workers": 8,
69 |
70 | "dist_config": {
71 | "dist_backend": "nccl",
72 | "dist_url": "tcp://localhost:18273",
73 | "world_size": 1
74 | }
75 | }
76 |
77 |
78 | def get_available_checkpoint_keys(model, ckpt):
79 | state_dict = torch.load(ckpt)["state_dict"]
80 | current_state_dict = model.state_dict()
81 | new_state_dict = {}
82 | for k in state_dict.keys():
83 | if (
84 | k in current_state_dict.keys()
85 | and current_state_dict[k].size() == state_dict[k].size()
86 | ):
87 | new_state_dict[k] = state_dict[k]
88 | else:
89 | print("==> WARNING: Skipping %s" % k)
90 | print(
91 | "%s out of %s keys are matched"
92 | % (len(new_state_dict.keys()), len(state_dict.keys()))
93 | )
94 | return new_state_dict
95 |
96 |
97 | def get_param_num(model):
98 | num_param = sum(param.numel() for param in model.parameters())
99 | return num_param
100 |
101 |
102 | def torch_version_orig_mod_remove(state_dict):
103 | new_state_dict = {}
104 | new_state_dict["generator"] = {}
105 | for key in state_dict["generator"].keys():
106 | if "_orig_mod." in key:
107 | new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[
108 | "generator"
109 | ][key]
110 | else:
111 | new_state_dict["generator"][key] = state_dict["generator"][key]
112 | return new_state_dict
113 |
114 |
115 | def get_vocoder(config, device, mel_bins):
116 | name = "HiFi-GAN"
117 | speaker = ""
118 | if name == "MelGAN":
119 | if speaker == "LJSpeech":
120 | vocoder = torch.hub.load(
121 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
122 | )
123 | elif speaker == "universal":
124 | vocoder = torch.hub.load(
125 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
126 | )
127 | vocoder.mel2wav.eval()
128 | vocoder.mel2wav.to(device)
129 | elif name == "HiFi-GAN":
130 | if mel_bins == 64:
131 | config = get_vocoder_config()
132 | config = hifigan.AttrDict(config)
133 | vocoder = hifigan.Generator_old(config)
134 | # print("Load hifigan/g_01080000")
135 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
136 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
137 | # ckpt = torch_version_orig_mod_remove(ckpt)
138 | # vocoder.load_state_dict(ckpt["generator"])
139 | vocoder.eval()
140 | vocoder.remove_weight_norm()
141 | vocoder.to(device)
142 | else:
143 | config = get_vocoder_config_48k()
144 | config = hifigan.AttrDict(config)
145 | vocoder = hifigan.Generator_old(config)
146 | # print("Load hifigan/g_01080000")
147 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
148 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
149 | # ckpt = torch_version_orig_mod_remove(ckpt)
150 | # vocoder.load_state_dict(ckpt["generator"])
151 | vocoder.eval()
152 | vocoder.remove_weight_norm()
153 | vocoder.to(device)
154 | return vocoder
155 |
156 |
157 | def vocoder_infer(mels, vocoder, lengths=None):
158 | with torch.no_grad():
159 | wavs = vocoder(mels).squeeze(1)
160 |
161 | wavs = (wavs.cpu().numpy() * 32768).astype("int16")
162 |
163 | if lengths is not None:
164 | wavs = wavs[:, :lengths]
165 |
166 | # wavs = [wav for wav in wavs]
167 |
168 | # for i in range(len(mels)):
169 | # if lengths is not None:
170 | # wavs[i] = wavs[i][: lengths[i]]
171 |
172 | return wavs
173 |
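
A minimal sketch (not part of the repository) of the vocoder helpers above: get_vocoder rebuilds its config from get_vocoder_config() when mel_bins == 64, and vocoder_infer maps a mel batch to int16 waveforms. Because the checkpoint loading above is commented out, the weights are untrained and the output is only a shape check.

import torch

from audioldm2.utilities.model import get_vocoder, vocoder_infer

vocoder = get_vocoder(config=None, device="cpu", mel_bins=64)  # 16 kHz HiFi-GAN, hop_size=160
mel = torch.randn(1, 64, 100)                                  # [batch, num_mels, frames]
wav = vocoder_infer(mel, vocoder)                              # int16 numpy array, 100 * 160 samples per item
print(wav.shape)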
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference Repo: https://github.com/facebookresearch/AudioMAE
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | from timm.models.layers import to_2tuple
8 | import audioldm2.latent_diffusion.modules.audiomae.models_vit as models_vit
9 | import audioldm2.latent_diffusion.modules.audiomae.models_mae as models_mae
10 |
11 | # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
12 |
13 |
14 | class PatchEmbed_new(nn.Module):
15 | """Flexible Image to Patch Embedding"""
16 |
17 | def __init__(
18 | self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
19 | ):
20 | super().__init__()
21 | img_size = to_2tuple(img_size)
22 | patch_size = to_2tuple(patch_size)
23 | stride = to_2tuple(stride)
24 |
25 | self.img_size = img_size
26 | self.patch_size = patch_size
27 |
28 | self.proj = nn.Conv2d(
29 | in_chans, embed_dim, kernel_size=patch_size, stride=stride
30 | ) # with overlapped patches
31 | # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
32 |
33 | # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
34 | # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
35 | _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
36 | self.patch_hw = (h, w)
37 | self.num_patches = h * w
38 |
39 | def get_output_shape(self, img_size):
40 | # todo: don't be lazy..
41 | return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
42 |
43 | def forward(self, x):
44 | B, C, H, W = x.shape
45 | # FIXME look at relaxing size constraints
46 | # assert H == self.img_size[0] and W == self.img_size[1], \
47 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
48 | x = self.proj(x)
49 | x = x.flatten(2).transpose(1, 2)
50 | return x
51 |
52 |
53 | class AudioMAE(nn.Module):
54 | """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)"""
55 |
56 | def __init__(
57 | self,
58 | ):
59 | super().__init__()
60 | model = models_vit.__dict__["vit_base_patch16"](
61 | num_classes=527,
62 | drop_path_rate=0.1,
63 | global_pool=True,
64 | mask_2d=True,
65 | use_custom_patch=False,
66 | )
67 |
68 | img_size = (1024, 128)
69 | emb_dim = 768
70 |
71 | model.patch_embed = PatchEmbed_new(
72 | img_size=img_size,
73 | patch_size=(16, 16),
74 | in_chans=1,
75 | embed_dim=emb_dim,
76 | stride=16,
77 | )
78 | num_patches = model.patch_embed.num_patches
79 | # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8
80 | model.pos_embed = nn.Parameter(
81 | torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
82 | ) # fixed sin-cos embedding
83 |
84 | # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth'
85 | # checkpoint = torch.load(checkpoint_path, map_location='cpu')
86 | # msg = model.load_state_dict(checkpoint['model'], strict=False)
87 | # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
88 |
89 | self.model = model
90 |
91 | def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
92 | """
93 | x: mel fbank [Batch, 1, T, F]
94 | mask_t_prob: 'T masking ratio (percentage of removed patches).'
95 | mask_f_prob: 'F masking ratio (percentage of removed patches).'
96 | """
97 | return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
98 |
99 |
100 | class Vanilla_AudioMAE(nn.Module):
101 | """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)"""
102 |
103 | def __init__(
104 | self,
105 | ):
106 | super().__init__()
107 | model = models_mae.__dict__["mae_vit_base_patch16"](
108 | in_chans=1, audio_exp=True, img_size=(1024, 128)
109 | )
110 |
111 | # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth'
112 | # checkpoint = torch.load(checkpoint_path, map_location='cpu')
113 | # msg = model.load_state_dict(checkpoint['model'], strict=False)
114 |
115 | # Skip the missing keys of decoder modules (not required)
116 | # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
117 |
118 | self.model = model.eval()
119 |
120 | def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
121 | """
122 | x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
123 | mask_ratio: 'masking ratio (percentage of removed patches).'
124 | """
125 | with torch.no_grad():
126 | # embed: [B, 513, 768] for mask_ratio=0.0
127 | if no_mask:
128 | if no_average:
129 | raise RuntimeError("This function is deprecated")
130 | embed = self.model.forward_encoder_no_random_mask_no_average(
131 | x
132 | ) # mask_ratio
133 | else:
134 | embed = self.model.forward_encoder_no_mask(x) # mask_ratio
135 | else:
136 | raise RuntimeError("This function is deprecated")
137 | embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
138 | return embed
139 |
140 |
141 | if __name__ == "__main__":
142 | model = Vanilla_AudioMAE().cuda()
143 | input = torch.randn(4, 1, 1024, 128).cuda()
144 | print("The first run")
145 | embed = model(input, mask_ratio=0.0, no_mask=True)
146 | print(embed)
147 | print("The second run")
148 | embed = model(input, mask_ratio=0.0)
149 | print(embed)
150 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | def test_reconstruct():
5 | import yaml
6 | from diffusers import AutoencoderKL
7 | from transformers import SpeechT5HifiGan
8 | from audioldm2.utilities.data.dataset import AudioDataset
9 | from utils import load_clip, load_clap, load_t5
10 |
11 | model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
12 | config = yaml.load(
13 | open(
14 | 'config/16k_64.yaml',
15 | 'r'
16 | ),
17 | Loader=yaml.FullLoader,
18 | )
19 | print(config)
20 | t5 = load_t5('cuda', max_length=256)
21 | clap = load_clap('cuda', max_length=256)
22 |
23 | dataset = AudioDataset(
24 | config=config, split="train", waveform_only=False, dataset_json_path='mini_dataset.json',
25 | tokenizer=clap.tokenizer,
26 | uncond_pro=0.1,
27 | text_ctx_len=77,
28 | tokenizer_t5=t5.tokenizer,
29 | text_ctx_len_t5=256,
30 | uncond_pro_t5=0.1,
31 | )
32 | print(dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0).size())
33 |
34 | vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae'))
35 | vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder'))
36 | latents = vae.encode(dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0)).latent_dist.sample().mul_(vae.config.scaling_factor)
37 | print('latent size:', latents.size())
38 |
39 | latents = 1 / vae.config.scaling_factor * latents
40 | mel_spectrogram = vae.decode(latents).sample
41 | print(mel_spectrogram.size())
42 | if mel_spectrogram.dim() == 4:
43 | mel_spectrogram = mel_spectrogram.squeeze(1)
44 | waveform = vocoder(mel_spectrogram)
45 | waveform = waveform[0].cpu().float().detach().numpy()
46 | print(waveform.shape)
47 | # import soundfile as sf
48 | # sf.write('reconstruct.wav', waveform, samplerate=16000)
49 | from scipy.io import wavfile
50 | # wavfile.write('reconstruct.wav', 16000, waveform)
51 |
52 |
53 |
54 | def mini_dataset(num=32):
55 | data = []
56 | for i in range(num):
57 | data.append(
58 | {
59 | 'wav': 'case.mp3',
60 | 'label': 'a beautiful music',
61 | }
62 | )
63 |
64 | with open('mini_dataset.json', 'w') as f:
65 | json.dump(data, f, indent=4)
66 |
67 |
68 | def fma_dataset():
69 | import pandas as pd
70 |
71 | annotation_prex = "/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/annotation"
72 | annotation_list = ['test-00000-of-00001.parquet', 'train-00000-of-00001.parquet', 'valid-00000-of-00001.parquet']
73 | dataset_prex = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/fma_large'
74 |
75 | data = []
76 | for annotation_file in annotation_list:
77 | annotation_file = os.path.join(annotation_prex, annotation_file)
78 | df = pd.read_parquet(annotation_file)
79 | print(df.shape)
80 | for id, row in df.iterrows():
81 | #print(id, row['pseudo_caption'], row['path'])
82 | tmp_path = os.path.join(dataset_prex, row['path'] + '.mp3')
83 | # print(tmp_path)
84 | if os.path.exists(tmp_path):
85 | data.append(
86 | {
87 | 'wav': tmp_path,
88 | 'label': row['pseudo_caption'],
89 | }
90 | )
91 | # break
92 | print(len(data))
93 | with open('fma_dataset.json', 'w') as f:
94 | json.dump(data, f, indent=4)
95 |
96 |
97 |
98 |
99 |
100 | def audioset_dataset():
101 | import pandas as pd
102 | dataset_prex = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset'
103 | annotation_path = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset/balanced_train-00000-of-00001.parquet'
104 | df = pd.read_parquet(annotation_path)
105 | print(df.shape)
106 |
107 | data = []
108 | for id, row in df.iterrows():
109 | #print(id, row['pseudo_caption'], row['path'])
110 | try:
111 | tmp_path = os.path.join(dataset_prex, row['path'] + '.flac')
112 | except:
113 | print(row['path'])
114 |
115 | if os.path.exists(tmp_path):
116 | # print(tmp_path)
117 | data.append(
118 | {
119 | 'wav': tmp_path,
120 | 'label': row['pseudo_caption'],
121 | }
122 | )
123 | print(len(data))
124 | with open('audioset_dataset.json', 'w') as f:
125 | json.dump(data, f, indent=4)
126 |
127 |
128 |
129 | def combine_dataset():
130 | data_list = ['fma_dataset.json', 'audioset_dataset.json']
131 |
132 | data = []
133 | for data_file in data_list:
134 | with open(data_file, 'r') as f:
135 | data += json.load(f)
136 | print(len(data))
137 | with open('combine_dataset.json', 'w') as f:
138 | json.dump(data, f, indent=4)
139 |
140 |
141 |
142 | def test_music_format():
143 | import torchaudio
144 | filename = '2.flac'
145 | waveform, sr = torchaudio.load(filename)
146 | print(waveform, sr)
147 |
148 |
149 | def test_flops():
150 | version = 'giant'
151 | import torch
152 | from constants import build_model
153 | from thop import profile
154 |
155 | model = build_model(version).cuda()
156 | img_ids = torch.randn((1, 1024, 3)).cuda()
157 | txt = torch.randn((1, 256, 4096)).cuda()
158 | txt_ids = torch.randn((1, 256, 3)).cuda()
159 | y = torch.randn((1, 768)).cuda()
160 | x = torch.randn((1, 1024, 32)).cuda()
161 | t = torch.tensor([1] * 1).cuda()
162 | flops, _ = profile(model, inputs=(x, img_ids, txt, txt_ids, t, y,))
163 | print('FLOPs = ' + str(flops * 2/1000**3) + 'G')
164 |
165 |
166 | # test_music_format()
167 | # test_reconstruct()
168 | # mini_dataset()
169 | # fma_dataset()
170 | # audioset_dataset()
171 | # combine_dataset()
172 | test_flops()
--------------------------------------------------------------------------------
/audioldm2/hifigan/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Conv1d, ConvTranspose1d
5 | from torch.nn.utils import weight_norm, remove_weight_norm
6 |
7 | LRELU_SLOPE = 0.1
8 |
9 |
10 | def init_weights(m, mean=0.0, std=0.01):
11 | classname = m.__class__.__name__
12 | if classname.find("Conv") != -1:
13 | m.weight.data.normal_(mean, std)
14 |
15 |
16 | def get_padding(kernel_size, dilation=1):
17 | return int((kernel_size * dilation - dilation) / 2)
18 |
19 |
20 | class ResBlock(torch.nn.Module):
21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22 | super(ResBlock, self).__init__()
23 | self.h = h
24 | self.convs1 = nn.ModuleList(
25 | [
26 | weight_norm(
27 | Conv1d(
28 | channels,
29 | channels,
30 | kernel_size,
31 | 1,
32 | dilation=dilation[0],
33 | padding=get_padding(kernel_size, dilation[0]),
34 | )
35 | ),
36 | weight_norm(
37 | Conv1d(
38 | channels,
39 | channels,
40 | kernel_size,
41 | 1,
42 | dilation=dilation[1],
43 | padding=get_padding(kernel_size, dilation[1]),
44 | )
45 | ),
46 | weight_norm(
47 | Conv1d(
48 | channels,
49 | channels,
50 | kernel_size,
51 | 1,
52 | dilation=dilation[2],
53 | padding=get_padding(kernel_size, dilation[2]),
54 | )
55 | ),
56 | ]
57 | )
58 | self.convs1.apply(init_weights)
59 |
60 | self.convs2 = nn.ModuleList(
61 | [
62 | weight_norm(
63 | Conv1d(
64 | channels,
65 | channels,
66 | kernel_size,
67 | 1,
68 | dilation=1,
69 | padding=get_padding(kernel_size, 1),
70 | )
71 | ),
72 | weight_norm(
73 | Conv1d(
74 | channels,
75 | channels,
76 | kernel_size,
77 | 1,
78 | dilation=1,
79 | padding=get_padding(kernel_size, 1),
80 | )
81 | ),
82 | weight_norm(
83 | Conv1d(
84 | channels,
85 | channels,
86 | kernel_size,
87 | 1,
88 | dilation=1,
89 | padding=get_padding(kernel_size, 1),
90 | )
91 | ),
92 | ]
93 | )
94 | self.convs2.apply(init_weights)
95 |
96 | def forward(self, x):
97 | for c1, c2 in zip(self.convs1, self.convs2):
98 | xt = F.leaky_relu(x, LRELU_SLOPE)
99 | xt = c1(xt)
100 | xt = F.leaky_relu(xt, LRELU_SLOPE)
101 | xt = c2(xt)
102 | x = xt + x
103 | return x
104 |
105 | def remove_weight_norm(self):
106 | for l in self.convs1:
107 | remove_weight_norm(l)
108 | for l in self.convs2:
109 | remove_weight_norm(l)
110 |
111 |
112 | class Generator(torch.nn.Module):
113 | def __init__(self, h):
114 | super(Generator, self).__init__()
115 | self.h = h
116 | self.num_kernels = len(h.resblock_kernel_sizes)
117 | self.num_upsamples = len(h.upsample_rates)
118 | self.conv_pre = weight_norm(
119 | Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
120 | )
121 | resblock = ResBlock
122 |
123 | self.ups = nn.ModuleList()
124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125 | self.ups.append(
126 | weight_norm(
127 | ConvTranspose1d(
128 | h.upsample_initial_channel // (2**i),
129 | h.upsample_initial_channel // (2 ** (i + 1)),
130 | k,
131 | u,
132 | padding=(k - u) // 2,
133 | )
134 | )
135 | )
136 |
137 | self.resblocks = nn.ModuleList()
138 | for i in range(len(self.ups)):
139 | ch = h.upsample_initial_channel // (2 ** (i + 1))
140 | for j, (k, d) in enumerate(
141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142 | ):
143 | self.resblocks.append(resblock(h, ch, k, d))
144 |
145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146 | self.ups.apply(init_weights)
147 | self.conv_post.apply(init_weights)
148 |
149 | def forward(self, x):
150 | x = self.conv_pre(x)
151 | for i in range(self.num_upsamples):
152 | x = F.leaky_relu(x, LRELU_SLOPE)
153 | x = self.ups[i](x)
154 | xs = None
155 | for j in range(self.num_kernels):
156 | if xs is None:
157 | xs = self.resblocks[i * self.num_kernels + j](x)
158 | else:
159 | xs += self.resblocks[i * self.num_kernels + j](x)
160 | x = xs / self.num_kernels
161 | x = F.leaky_relu(x)
162 | x = self.conv_post(x)
163 | x = torch.tanh(x)
164 |
165 | return x
166 |
167 | def remove_weight_norm(self):
168 | # print("Removing weight norm...")
169 | for l in self.ups:
170 | remove_weight_norm(l)
171 | for l in self.resblocks:
172 | l.remove_weight_norm()
173 | remove_weight_norm(self.conv_pre)
174 | remove_weight_norm(self.conv_post)
175 |
--------------------------------------------------------------------------------
/audioldm2/pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | import yaml
5 | import torch
6 | import torchaudio
7 |
8 | import audioldm2.latent_diffusion.modules.phoneme_encoder.text as text
9 | from audioldm2.latent_diffusion.models.ddpm import LatentDiffusion
10 | from audioldm2.latent_diffusion.util import get_vits_phoneme_ids_no_padding
11 | from audioldm2.utils import default_audioldm_config, download_checkpoint
12 | import os
13 |
14 | # CACHE_DIR = os.getenv(
15 | # "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
16 | # )
17 |
18 | def seed_everything(seed):
19 | import random, os
20 | import numpy as np
21 | import torch
22 |
23 | random.seed(seed)
24 | os.environ["PYTHONHASHSEED"] = str(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | torch.backends.cudnn.deterministic = True
29 | torch.backends.cudnn.benchmark = True
30 |
31 | def text2phoneme(data):
32 | return text._clean_text(re.sub(r'<.*?>', '', data), ["english_cleaners2"])
33 |
34 | def text_to_filename(text):
35 | return text.replace(" ", "_").replace("'", "_").replace('"', "_")
36 |
37 | def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
38 | norm_mean = -4.2677393
39 | norm_std = 4.5689974
40 |
41 | if sampling_rate != 16000:
42 | waveform_16k = torchaudio.functional.resample(
43 | waveform, orig_freq=sampling_rate, new_freq=16000
44 | )
45 | else:
46 | waveform_16k = waveform
47 |
48 | waveform_16k = waveform_16k - waveform_16k.mean()
49 | fbank = torchaudio.compliance.kaldi.fbank(
50 | waveform_16k,
51 | htk_compat=True,
52 | sample_frequency=16000,
53 | use_energy=False,
54 | window_type="hanning",
55 | num_mel_bins=128,
56 | dither=0.0,
57 | frame_shift=10,
58 | )
59 |
60 | TARGET_LEN = log_mel_spec.size(0)
61 |
62 | # cut and pad
63 | n_frames = fbank.shape[0]
64 | p = TARGET_LEN - n_frames
65 | if p > 0:
66 | m = torch.nn.ZeroPad2d((0, 0, 0, p))
67 | fbank = m(fbank)
68 | elif p < 0:
69 | fbank = fbank[:TARGET_LEN, :]
70 |
71 | fbank = (fbank - norm_mean) / (norm_std * 2)
72 |
73 | return {"ta_kaldi_fbank": fbank} # [1024, 128]
74 |
75 | def make_batch_for_text_to_audio(text, transcription="", waveform=None, fbank=None, batchsize=1):
76 | text = [text] * batchsize
77 | if(transcription):
78 | transcription = text2phoneme(transcription)
79 | transcription = [transcription] * batchsize
80 |
81 | if batchsize < 1:
82 | print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
83 | batchsize = 1
84 | if fbank is None:
85 | fbank = torch.zeros(
86 | (batchsize, 1024, 64)
87 | ) # Not used, here to keep the code format
88 | else:
89 | fbank = torch.FloatTensor(fbank)
90 | fbank = fbank.expand(batchsize, 1024, 64)
91 | assert fbank.size(0) == batchsize
92 |
93 | stft = torch.zeros((batchsize, 1024, 512)) # Not used
94 | phonemes = get_vits_phoneme_ids_no_padding(transcription)
95 |
96 | if waveform is None:
97 | waveform = torch.zeros((batchsize, 160000)) # Not used
98 | ta_kaldi_fbank = torch.zeros((batchsize, 1024, 128))
99 | else:
100 | waveform = torch.FloatTensor(waveform)
101 | waveform = waveform.expand(batchsize, -1)
102 | assert waveform.size(0) == batchsize
103 | ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, 16000, fbank)
104 |
105 | batch = {
106 | "text": text, # list
107 | "fname": [text_to_filename(t) for t in text], # list
108 | "waveform": waveform,
109 | "stft": stft,
110 | "log_mel_spec": fbank,
111 | "ta_kaldi_fbank": ta_kaldi_fbank,
112 | }
113 | batch.update(phonemes)
114 | return batch
115 |
116 |
117 | def round_up_duration(duration):
118 | return int(round(duration / 2.5) + 1) * 2.5
119 |
120 |
121 | # def split_clap_weight_to_pth(checkpoint):
122 | # if os.path.exists(os.path.join(CACHE_DIR, "clap.pth")):
123 | # return
124 | # print("Constructing the weight for the CLAP model.")
125 | # include_keys = "cond_stage_models.0.cond_stage_models.0.model."
126 | # new_state_dict = {}
127 | # for each in checkpoint["state_dict"].keys():
128 | # if include_keys in each:
129 | # new_state_dict[each.replace(include_keys, "module.")] = checkpoint[
130 | # "state_dict"
131 | # ][each]
132 | # torch.save({"state_dict": new_state_dict}, os.path.join(CACHE_DIR, "clap.pth"))
133 |
134 |
135 | def build_model(ckpt_path=None, config=None, device=None, model_name="audioldm2-full"):
136 |
137 | if device is None or device == "auto":
138 | if torch.cuda.is_available():
139 | device = torch.device("cuda:0")
140 | elif torch.backends.mps.is_available():
141 | device = torch.device("mps")
142 | else:
143 | device = torch.device("cpu")
144 |
145 | print("Loading AudioLDM-2: %s" % model_name)
146 | print("Loading model on %s" % device)
147 |
148 | ckpt_path = download_checkpoint(model_name)
149 |
150 | if config is not None:
151 | assert type(config) is str
152 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
153 | else:
154 | config = default_audioldm_config(model_name)
155 |
156 | # # Use text as condition instead of using waveform during training
157 | config["model"]["params"]["device"] = device
158 | # config["model"]["params"]["cond_stage_key"] = "text"
159 |
160 | # No normalization here
161 | latent_diffusion = LatentDiffusion(**config["model"]["params"])
162 |
163 | resume_from_checkpoint = ckpt_path
164 |
165 | checkpoint = torch.load(resume_from_checkpoint, map_location=device)
166 |
167 | latent_diffusion.load_state_dict(checkpoint["state_dict"])
168 |
169 | latent_diffusion.eval()
170 | latent_diffusion = latent_diffusion.to(device)
171 |
172 | return latent_diffusion
173 |
174 | def text_to_audio(
175 | latent_diffusion,
176 | text,
177 | transcription="",
178 | seed=42,
179 | ddim_steps=200,
180 | duration=10,
181 | batchsize=1,
182 | guidance_scale=3.5,
183 | n_candidate_gen_per_text=3,
184 | latent_t_per_second=25.6,
185 | config=None,
186 | ):
187 |
188 | seed_everything(int(seed))
189 | waveform = None
190 |
191 | batch = make_batch_for_text_to_audio(text, transcription=transcription, waveform=waveform, batchsize=batchsize)
192 |
193 | latent_diffusion.latent_t_size = int(duration * latent_t_per_second)
194 |
195 | with torch.no_grad():
196 | waveform = latent_diffusion.generate_batch(
197 | batch,
198 | unconditional_guidance_scale=guidance_scale,
199 | ddim_steps=ddim_steps,
200 | n_gen=n_candidate_gen_per_text,
201 | duration=duration,
202 | )
203 |
204 | return waveform
205 |
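
A minimal sketch (not part of the repository) of the entry points above. build_model and text_to_audio are left commented out because they download checkpoints; round_up_duration runs locally.

from audioldm2.pipeline import build_model, text_to_audio, round_up_duration

# ldm = build_model(model_name="audioldm2-full", device="cpu")
# wav = text_to_audio(ldm, "A dog is barking", duration=10, ddim_steps=200, guidance_scale=3.5)

print(round_up_duration(7.0))  # 7.0 s is rounded to the 2.5 s grid and bumped one step -> 10.0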
--------------------------------------------------------------------------------
/audioldm2/utilities/audio/stft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy.signal import get_window
5 | from librosa.util import pad_center, tiny
6 | from librosa.filters import mel as librosa_mel_fn
7 |
8 | from audioldm2.utilities.audio.audio_processing import (
9 | dynamic_range_compression,
10 | dynamic_range_decompression,
11 | window_sumsquare,
12 | )
13 |
14 |
15 | class STFT(torch.nn.Module):
16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17 |
18 | def __init__(self, filter_length, hop_length, win_length, window="hann"):
19 | super(STFT, self).__init__()
20 | self.filter_length = filter_length
21 | self.hop_length = hop_length
22 | self.win_length = win_length
23 | self.window = window
24 | self.forward_transform = None
25 | scale = self.filter_length / self.hop_length
26 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
27 |
28 | cutoff = int((self.filter_length / 2 + 1))
29 | fourier_basis = np.vstack(
30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31 | )
32 |
33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34 | inverse_basis = torch.FloatTensor(
35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36 | )
37 |
38 | if window is not None:
39 | assert filter_length >= win_length
40 | # get window and zero center pad it to filter_length
41 | fft_window = get_window(window, win_length, fftbins=True)
42 | fft_window = pad_center(fft_window, filter_length)
43 | fft_window = torch.from_numpy(fft_window).float()
44 |
45 | # window the bases
46 | forward_basis *= fft_window
47 | inverse_basis *= fft_window
48 |
49 | self.register_buffer("forward_basis", forward_basis.float())
50 | self.register_buffer("inverse_basis", inverse_basis.float())
51 |
52 | def transform(self, input_data):
53 | num_batches = input_data.size(0)
54 | num_samples = input_data.size(1)
55 |
56 | self.num_samples = num_samples
57 |
58 | # similar to librosa, reflect-pad the input
59 | input_data = input_data.view(num_batches, 1, num_samples)
60 | input_data = F.pad(
61 | input_data.unsqueeze(1),
62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63 | mode="reflect",
64 | )
65 | input_data = input_data.squeeze(1)
66 |
67 | forward_transform = F.conv1d(
68 | input_data,
69 | torch.autograd.Variable(self.forward_basis, requires_grad=False),
70 | stride=self.hop_length,
71 | padding=0,
72 | ).cpu()
73 |
74 | cutoff = int((self.filter_length / 2) + 1)
75 | real_part = forward_transform[:, :cutoff, :]
76 | imag_part = forward_transform[:, cutoff:, :]
77 |
78 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80 |
81 | return magnitude, phase
82 |
83 | def inverse(self, magnitude, phase):
84 | recombine_magnitude_phase = torch.cat(
85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86 | )
87 |
88 | inverse_transform = F.conv_transpose1d(
89 | recombine_magnitude_phase,
90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91 | stride=self.hop_length,
92 | padding=0,
93 | )
94 |
95 | if self.window is not None:
96 | window_sum = window_sumsquare(
97 | self.window,
98 | magnitude.size(-1),
99 | hop_length=self.hop_length,
100 | win_length=self.win_length,
101 | n_fft=self.filter_length,
102 | dtype=np.float32,
103 | )
104 | # remove modulation effects
105 | approx_nonzero_indices = torch.from_numpy(
106 | np.where(window_sum > tiny(window_sum))[0]
107 | )
108 | window_sum = torch.autograd.Variable(
109 | torch.from_numpy(window_sum), requires_grad=False
110 | )
111 | window_sum = window_sum
112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113 | approx_nonzero_indices
114 | ]
115 |
116 | # scale by hop ratio
117 | inverse_transform *= float(self.filter_length) / self.hop_length
118 |
119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121 |
122 | return inverse_transform
123 |
124 | def forward(self, input_data):
125 | self.magnitude, self.phase = self.transform(input_data)
126 | reconstruction = self.inverse(self.magnitude, self.phase)
127 | return reconstruction
128 |
129 |
130 | class TacotronSTFT(torch.nn.Module):
131 | def __init__(
132 | self,
133 | filter_length,
134 | hop_length,
135 | win_length,
136 | n_mel_channels,
137 | sampling_rate,
138 | mel_fmin,
139 | mel_fmax,
140 | ):
141 | super(TacotronSTFT, self).__init__()
142 | self.n_mel_channels = n_mel_channels
143 | self.sampling_rate = sampling_rate
144 | self.stft_fn = STFT(filter_length, hop_length, win_length)
145 | mel_basis = librosa_mel_fn(
146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
147 | )
148 | mel_basis = torch.from_numpy(mel_basis).float()
149 | self.register_buffer("mel_basis", mel_basis)
150 |
151 | def spectral_normalize(self, magnitudes, normalize_fun):
152 | output = dynamic_range_compression(magnitudes, normalize_fun)
153 | return output
154 |
155 | def spectral_de_normalize(self, magnitudes):
156 | output = dynamic_range_decompression(magnitudes)
157 | return output
158 |
159 | def mel_spectrogram(self, y, normalize_fun=torch.log):
160 | """Computes mel-spectrograms from a batch of waves
161 | PARAMS
162 | ------
163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164 |
165 | RETURNS
166 | -------
167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168 | """
169 | assert torch.min(y.data) >= -1, torch.min(y.data)
170 | assert torch.max(y.data) <= 1, torch.max(y.data)
171 |
172 | magnitudes, phases = self.stft_fn.transform(y)
173 | magnitudes = magnitudes.data
174 | mel_output = torch.matmul(self.mel_basis, magnitudes)
175 | mel_output = self.spectral_normalize(mel_output, normalize_fun)
176 | energy = torch.norm(magnitudes, dim=1)
177 |
178 | return mel_output, magnitudes, phases, energy
179 |
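
A minimal sketch (not part of the repository) of TacotronSTFT with the 16 kHz settings from get_vocoder_config(); it assumes a librosa version that still accepts the positional librosa.filters.mel and pad_center calls used in this module.

import torch

from audioldm2.utilities.audio.stft import TacotronSTFT

stft = TacotronSTFT(
    filter_length=1024, hop_length=160, win_length=1024,
    n_mel_channels=64, sampling_rate=16000, mel_fmin=0, mel_fmax=8000,
)
wav = torch.rand(1, 16000) * 2 - 1                    # (B, T) waveform in [-1, 1]
mel, mag, phases, energy = stft.mel_spectrogram(wav)  # mel: (1, 64, ~101 frames)
print(mel.shape)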
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/pretrained.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 |
6 | from tqdm import tqdm
7 |
8 | _RN50 = dict(
9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
12 | )
13 |
14 | _RN50_quickgelu = dict(
15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
18 | )
19 |
20 | _RN101 = dict(
21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
23 | )
24 |
25 | _RN101_quickgelu = dict(
26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
28 | )
29 |
30 | _RN50x4 = dict(
31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32 | )
33 |
34 | _RN50x16 = dict(
35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36 | )
37 |
38 | _RN50x64 = dict(
39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40 | )
41 |
42 | _VITB32 = dict(
43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
45 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
46 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
47 | )
48 |
49 | _VITB32_quickgelu = dict(
50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
54 | )
55 |
56 | _VITB16 = dict(
57 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
58 | )
59 |
60 | _VITL14 = dict(
61 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
62 | )
63 |
64 | _PRETRAINED = {
65 | "RN50": _RN50,
66 | "RN50-quickgelu": _RN50_quickgelu,
67 | "RN101": _RN101,
68 | "RN101-quickgelu": _RN101_quickgelu,
69 | "RN50x4": _RN50x4,
70 | "RN50x16": _RN50x16,
71 | "ViT-B-32": _VITB32,
72 | "ViT-B-32-quickgelu": _VITB32_quickgelu,
73 | "ViT-B-16": _VITB16,
74 | "ViT-L-14": _VITL14,
75 | }
76 |
77 |
78 | def list_pretrained(as_str: bool = False):
79 | """returns list of pretrained models
80 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
81 | """
82 | return [
83 | ":".join([k, t]) if as_str else (k, t)
84 | for k in _PRETRAINED.keys()
85 | for t in _PRETRAINED[k].keys()
86 | ]
87 |
88 |
89 | def list_pretrained_tag_models(tag: str):
90 | """return all models having the specified pretrain tag"""
91 | models = []
92 | for k in _PRETRAINED.keys():
93 | if tag in _PRETRAINED[k]:
94 | models.append(k)
95 | return models
96 |
97 |
98 | def list_pretrained_model_tags(model: str):
99 | """return all pretrain tags for the specified model architecture"""
100 | tags = []
101 | if model in _PRETRAINED:
102 | tags.extend(_PRETRAINED[model].keys())
103 | return tags
104 |
105 |
106 | def get_pretrained_url(model: str, tag: str):
107 | if model not in _PRETRAINED:
108 | return ""
109 | model_pretrained = _PRETRAINED[model]
110 | if tag not in model_pretrained:
111 | return ""
112 | return model_pretrained[tag]
113 |
114 |
115 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
116 | os.makedirs(root, exist_ok=True)
117 | filename = os.path.basename(url)
118 |
119 | if "openaipublic" in url:
120 | expected_sha256 = url.split("/")[-2]
121 | else:
122 | expected_sha256 = ""
123 |
124 | download_target = os.path.join(root, filename)
125 |
126 | if os.path.exists(download_target) and not os.path.isfile(download_target):
127 | raise RuntimeError(f"{download_target} exists and is not a regular file")
128 |
129 | if os.path.isfile(download_target):
130 | if expected_sha256:
131 | if (
132 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
133 | == expected_sha256
134 | ):
135 | return download_target
136 | else:
137 | warnings.warn(
138 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
139 | )
140 | else:
141 | return download_target
142 |
143 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
144 | with tqdm(
145 | total=int(source.info().get("Content-Length")),
146 | ncols=80,
147 | unit="iB",
148 | unit_scale=True,
149 | ) as loop:
150 | while True:
151 | buffer = source.read(8192)
152 | if not buffer:
153 | break
154 |
155 | output.write(buffer)
156 | loop.update(len(buffer))
157 |
158 | if (
159 | expected_sha256
160 | and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
161 | != expected_sha256
162 | ):
163 | raise RuntimeError(
164 | f"Model has been downloaded but the SHA256 checksum does not not match"
165 | )
166 |
167 | return download_target
168 |
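
A minimal sketch (not part of the repository) of the registry helpers above; nothing is downloaded unless the last line is uncommented.

from audioldm2.clap.open_clip.pretrained import (
    list_pretrained,
    list_pretrained_model_tags,
    get_pretrained_url,
    download_pretrained,
)

print(list_pretrained(as_str=True)[:3])          # ['RN50:openai', 'RN50:yfcc15m', 'RN50:cc12m']
print(list_pretrained_model_tags("ViT-B-32"))    # ['openai', 'laion400m_e31', ...]
print(get_pretrained_url("ViT-B-16", "openai"))  # checkpoint URL, or "" for an unknown model/tag
# download_pretrained(get_pretrained_url("RN50", "openai"))  # fetches the checkpoint into ~/.cache/clip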
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 |
16 | @lru_cache()
17 | def default_bpe():
18 | return os.path.join(
19 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
20 | )
21 |
22 |
23 | @lru_cache()
24 | def bytes_to_unicode():
25 | """
26 | Returns list of utf-8 byte and a corresponding list of unicode strings.
27 | The reversible bpe codes work on unicode strings.
28 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30 | This is a significant percentage of your normal, say, 32K bpe vocab.
31 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32 | And avoids mapping to whitespace/control characters the bpe code barfs on.
33 | """
34 | bs = (
35 | list(range(ord("!"), ord("~") + 1))
36 | + list(range(ord("¡"), ord("¬") + 1))
37 | + list(range(ord("®"), ord("ÿ") + 1))
38 | )
39 | cs = bs[:]
40 | n = 0
41 | for b in range(2**8):
42 | if b not in bs:
43 | bs.append(b)
44 | cs.append(2**8 + n)
45 | n += 1
46 | cs = [chr(n) for n in cs]
47 | return dict(zip(bs, cs))
48 |
49 |
50 | def get_pairs(word):
51 | """Return set of symbol pairs in a word.
52 | Word is represented as tuple of symbols (symbols being variable-length strings).
53 | """
54 | pairs = set()
55 | prev_char = word[0]
56 | for char in word[1:]:
57 | pairs.add((prev_char, char))
58 | prev_char = char
59 | return pairs
60 |
61 |
62 | def basic_clean(text):
63 | text = ftfy.fix_text(text)
64 | text = html.unescape(html.unescape(text))
65 | return text.strip()
66 |
67 |
68 | def whitespace_clean(text):
69 | text = re.sub(r"\s+", " ", text)
70 | text = text.strip()
71 | return text
72 |
73 |
74 | class SimpleTokenizer(object):
75 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
76 | self.byte_encoder = bytes_to_unicode()
77 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
78 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
79 | merges = merges[1 : 49152 - 256 - 2 + 1]
80 | merges = [tuple(merge.split()) for merge in merges]
81 | vocab = list(bytes_to_unicode().values())
82 | vocab = vocab + [v + "" for v in vocab]
83 | for merge in merges:
84 | vocab.append("".join(merge))
85 | if not special_tokens:
86 | special_tokens = ["", ""]
87 | else:
88 | special_tokens = ["", ""] + special_tokens
89 | vocab.extend(special_tokens)
90 | self.encoder = dict(zip(vocab, range(len(vocab))))
91 | self.decoder = {v: k for k, v in self.encoder.items()}
92 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
93 | self.cache = {t: t for t in special_tokens}
94 | special = "|".join(special_tokens)
95 | self.pat = re.compile(
96 | special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
97 | re.IGNORECASE,
98 | )
99 |
100 | self.vocab_size = len(self.encoder)
101 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
102 |
103 | def bpe(self, token):
104 | if token in self.cache:
105 | return self.cache[token]
106 | word = tuple(token[:-1]) + (token[-1] + "",)
107 | pairs = get_pairs(word)
108 |
109 | if not pairs:
110 | return token + ""
111 |
112 | while True:
113 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
114 | if bigram not in self.bpe_ranks:
115 | break
116 | first, second = bigram
117 | new_word = []
118 | i = 0
119 | while i < len(word):
120 | try:
121 | j = word.index(first, i)
122 | new_word.extend(word[i:j])
123 | i = j
124 | except:
125 | new_word.extend(word[i:])
126 | break
127 |
128 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
129 | new_word.append(first + second)
130 | i += 2
131 | else:
132 | new_word.append(word[i])
133 | i += 1
134 | new_word = tuple(new_word)
135 | word = new_word
136 | if len(word) == 1:
137 | break
138 | else:
139 | pairs = get_pairs(word)
140 | word = " ".join(word)
141 | self.cache[token] = word
142 | return word
143 |
144 | def encode(self, text):
145 | bpe_tokens = []
146 | text = whitespace_clean(basic_clean(text)).lower()
147 | for token in re.findall(self.pat, text):
148 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
149 | bpe_tokens.extend(
150 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
151 | )
152 | return bpe_tokens
153 |
154 | def decode(self, tokens):
155 | text = "".join([self.decoder[token] for token in tokens])
156 | text = (
157 | bytearray([self.byte_decoder[c] for c in text])
158 | .decode("utf-8", errors="replace")
159 | .replace("", " ")
160 | )
161 | return text
162 |
163 |
164 | _tokenizer = SimpleTokenizer()
165 |
166 |
167 | def tokenize(
168 | texts: Union[str, List[str]], context_length: int = 77
169 | ) -> torch.LongTensor:
170 | """
171 | Returns the tokenized representation of given input string(s)
172 |
173 | Parameters
174 | ----------
175 | texts : Union[str, List[str]]
176 | An input string or a list of input strings to tokenize
177 | context_length : int
178 | The context length to use; all CLIP models use 77 as the context length
179 |
180 | Returns
181 | -------
182 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
183 | """
184 | if isinstance(texts, str):
185 | texts = [texts]
186 |
187 | sot_token = _tokenizer.encoder[""]
188 | eot_token = _tokenizer.encoder[""]
189 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
190 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
191 |
192 | for i, tokens in enumerate(all_tokens):
193 | if len(tokens) > context_length:
194 | tokens = tokens[:context_length] # Truncate
195 | result[i, : len(tokens)] = torch.tensor(tokens)
196 |
197 | return result
198 |
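
A minimal sketch (not part of the repository) of the tokenizer API above; it relies only on the ftfy and regex packages already imported at the top of this module.

from audioldm2.clap.open_clip.tokenizer import tokenize, _tokenizer

toks = tokenize(["a dog barking", "rain on a tin roof"], context_length=77)
print(toks.shape)  # torch.Size([2, 77]), zero-padded, with <start_of_text>/<end_of_text> added
print(_tokenizer.decode([t for t in toks[0].tolist() if t > 0]))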
--------------------------------------------------------------------------------
/audioldm2/clap/open_clip/feature_fusion.py:
--------------------------------------------------------------------------------
1 | """
2 | Feature Fusion for Variable-Length Data Processing
3 | AFF/iAFF are adapted from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4 | According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class DAF(nn.Module):
12 | """
13 | Direct addition fusion (DirectAddFuse)
14 | """
15 |
16 | def __init__(self):
17 | super(DAF, self).__init__()
18 |
19 | def forward(self, x, residual):
20 | return x + residual
21 |
22 |
23 | class iAFF(nn.Module):
24 | """
25 | Multi-feature fusion: iAFF
26 | """
27 |
28 | def __init__(self, channels=64, r=4, type="2D"):
29 | super(iAFF, self).__init__()
30 | inter_channels = int(channels // r)
31 |
32 | if type == "1D":
33 | # local attention
34 | self.local_att = nn.Sequential(
35 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36 | nn.BatchNorm1d(inter_channels),
37 | nn.ReLU(inplace=True),
38 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39 | nn.BatchNorm1d(channels),
40 | )
41 |
42 | # global attention
43 | self.global_att = nn.Sequential(
44 | nn.AdaptiveAvgPool1d(1),
45 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46 | nn.BatchNorm1d(inter_channels),
47 | nn.ReLU(inplace=True),
48 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49 | nn.BatchNorm1d(channels),
50 | )
51 |
52 | # second-round local attention
53 | self.local_att2 = nn.Sequential(
54 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55 | nn.BatchNorm1d(inter_channels),
56 | nn.ReLU(inplace=True),
57 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58 | nn.BatchNorm1d(channels),
59 | )
60 | # second-round global attention
61 | self.global_att2 = nn.Sequential(
62 | nn.AdaptiveAvgPool1d(1),
63 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64 | nn.BatchNorm1d(inter_channels),
65 | nn.ReLU(inplace=True),
66 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67 | nn.BatchNorm1d(channels),
68 | )
69 | elif type == "2D":
70 | # Local attention
71 | self.local_att = nn.Sequential(
72 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73 | nn.BatchNorm2d(inter_channels),
74 | nn.ReLU(inplace=True),
75 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76 | nn.BatchNorm2d(channels),
77 | )
78 |
79 | # Global attention
80 | self.global_att = nn.Sequential(
81 | nn.AdaptiveAvgPool2d(1),
82 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83 | nn.BatchNorm2d(inter_channels),
84 | nn.ReLU(inplace=True),
85 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86 | nn.BatchNorm2d(channels),
87 | )
88 |
89 | # Second-stage local attention
90 | self.local_att2 = nn.Sequential(
91 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92 | nn.BatchNorm2d(inter_channels),
93 | nn.ReLU(inplace=True),
94 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95 | nn.BatchNorm2d(channels),
96 | )
97 | # Second-stage global attention
98 | self.global_att2 = nn.Sequential(
99 | nn.AdaptiveAvgPool2d(1),
100 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101 | nn.BatchNorm2d(inter_channels),
102 | nn.ReLU(inplace=True),
103 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104 | nn.BatchNorm2d(channels),
105 | )
106 | else:
107 | raise ValueError("the type is not supported")
108 |
109 | self.sigmoid = nn.Sigmoid()
110 |
111 | def forward(self, x, residual):
112 | flag = False
113 | xa = x + residual
114 | if xa.size(0) == 1:
115 | xa = torch.cat([xa, xa], dim=0)
116 | flag = True
117 | xl = self.local_att(xa)
118 | xg = self.global_att(xa)
119 | xlg = xl + xg
120 | wei = self.sigmoid(xlg)
121 | xi = x * wei + residual * (1 - wei)
122 |
123 | xl2 = self.local_att2(xi)
124 | xg2 = self.global_att(xi)  # note: reuses global_att; global_att2 defined above is never applied
125 | xlg2 = xl2 + xg2
126 | wei2 = self.sigmoid(xlg2)
127 | xo = x * wei2 + residual * (1 - wei2)
128 | if flag:
129 | xo = xo[0].unsqueeze(0)
130 | return xo
131 |
132 |
133 | class AFF(nn.Module):
134 | """
135 | Attentional feature fusion (AFF)
136 | """
137 |
138 | def __init__(self, channels=64, r=4, type="2D"):
139 | super(AFF, self).__init__()
140 | inter_channels = int(channels // r)
141 |
142 | if type == "1D":
143 | self.local_att = nn.Sequential(
144 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145 | nn.BatchNorm1d(inter_channels),
146 | nn.ReLU(inplace=True),
147 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148 | nn.BatchNorm1d(channels),
149 | )
150 | self.global_att = nn.Sequential(
151 | nn.AdaptiveAvgPool1d(1),
152 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153 | nn.BatchNorm1d(inter_channels),
154 | nn.ReLU(inplace=True),
155 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156 | nn.BatchNorm1d(channels),
157 | )
158 | elif type == "2D":
159 | self.local_att = nn.Sequential(
160 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161 | nn.BatchNorm2d(inter_channels),
162 | nn.ReLU(inplace=True),
163 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164 | nn.BatchNorm2d(channels),
165 | )
166 | self.global_att = nn.Sequential(
167 | nn.AdaptiveAvgPool2d(1),
168 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169 | nn.BatchNorm2d(inter_channels),
170 | nn.ReLU(inplace=True),
171 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172 | nn.BatchNorm2d(channels),
173 | )
174 | else:
175 | raise ValueError("the type is not supported.")
176 |
177 | self.sigmoid = nn.Sigmoid()
178 |
179 | def forward(self, x, residual):
180 | flag = False
181 | xa = x + residual
182 | if xa.size(0) == 1:
183 | xa = torch.cat([xa, xa], dim=0)
184 | flag = True
185 | xl = self.local_att(xa)
186 | xg = self.global_att(xa)
187 | xlg = xl + xg
188 | wei = self.sigmoid(xlg)
189 | xo = 2 * x * wei + 2 * residual * (1 - wei)
190 | if flag:
191 | xo = xo[0].unsqueeze(0)
192 | return xo
193 |
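A minimal usage sketch of AFF (illustrative, not part of the repository; shapes are arbitrary and the import path assumes the package layout shown in the file headers):

    import torch
    from audioldm2.clap.open_clip.feature_fusion import AFF

    x = torch.randn(4, 64, 16, 16)         # one feature branch, (N, C, H, W)
    residual = torch.randn(4, 64, 16, 16)  # the other branch, same shape
    fuse = AFF(channels=64, r=4, type="2D")
    out = fuse(x, residual)                # (4, 64, 16, 16): sigmoid-weighted blend of the two branches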
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 |
7 | import multiprocessing as mp
8 | from threading import Thread
9 | from queue import Queue
10 |
11 | from inspect import isfunction
12 | from PIL import Image, ImageDraw, ImageFont
13 |
14 | CACHE = {
15 | "get_vits_phoneme_ids":{
16 | "PAD_LENGTH": 310,
17 | "_pad": '_',
18 | "_punctuation": ';:,.!?¡¿—…"«»“” ',
19 | "_letters": 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz',
20 | "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
21 | "_special": "♪☎☒☝⚠"
22 | }
23 | }
24 |
25 | CACHE["get_vits_phoneme_ids"]["symbols"] = [CACHE["get_vits_phoneme_ids"]["_pad"]] + list(CACHE["get_vits_phoneme_ids"]["_punctuation"]) + list(CACHE["get_vits_phoneme_ids"]["_letters"]) + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"]) + list(CACHE["get_vits_phoneme_ids"]["_special"])
26 | CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])}
27 |
28 | def get_vits_phoneme_ids_no_padding(phonemes):
29 | pad_token_id = 0
30 | pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
31 | _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]
32 | batchsize = len(phonemes)
33 |
34 | clean_text = phonemes[0] + "⚠"
35 | sequence = []
36 |
37 | for symbol in clean_text:
38 | if(symbol not in _symbol_to_id.keys()):
39 | print("%s is not in the vocabulary. %s" % (symbol, clean_text))
40 | symbol = "_"
41 | symbol_id = _symbol_to_id[symbol]
42 | sequence += [symbol_id]
43 |
44 | def _pad_phonemes(phonemes_list):
45 | return phonemes_list + [pad_token_id] * (pad_length-len(phonemes_list))
46 |
47 | sequence = sequence[:pad_length]
48 |
49 | return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence)).unsqueeze(0).expand(batchsize, -1)}
50 |
51 | def log_txt_as_img(wh, xc, size=10):
52 | # wh a tuple of (width, height)
53 | # xc a list of captions to plot
54 | b = len(xc)
55 | txts = list()
56 | for bi in range(b):
57 | txt = Image.new("RGB", wh, color="white")
58 | draw = ImageDraw.Draw(txt)
59 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
60 | nc = int(40 * (wh[0] / 256))
61 | lines = "\n".join(
62 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
63 | )
64 |
65 | try:
66 | draw.text((0, 0), lines, fill="black", font=font)
67 | except UnicodeEncodeError:
68 | print("Cant encode string for logging. Skipping.")
69 |
70 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
71 | txts.append(txt)
72 | txts = np.stack(txts)
73 | txts = torch.tensor(txts)
74 | return txts
75 |
76 |
77 | def ismap(x):
78 | if not isinstance(x, torch.Tensor):
79 | return False
80 | return (len(x.shape) == 4) and (x.shape[1] > 3)
81 |
82 |
83 | def isimage(x):
84 | if not isinstance(x, torch.Tensor):
85 | return False
86 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
87 |
88 |
89 | def int16_to_float32(x):
90 | return (x / 32767.0).astype(np.float32)
91 |
92 |
93 | def float32_to_int16(x):
94 | x = np.clip(x, a_min=-1.0, a_max=1.0)
95 | return (x * 32767.0).astype(np.int16)
96 |
97 |
98 | def exists(x):
99 | return x is not None
100 |
101 |
102 | def default(val, d):
103 | if exists(val):
104 | return val
105 | return d() if isfunction(d) else d
106 |
107 |
108 | def mean_flat(tensor):
109 | """
110 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
111 | Take the mean over all non-batch dimensions.
112 | """
113 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
114 |
115 |
116 | def count_params(model, verbose=False):
117 | total_params = sum(p.numel() for p in model.parameters())
118 | if verbose:
119 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
120 | return total_params
121 |
122 |
123 | def instantiate_from_config(config):
124 | if not "target" in config:
125 | if config == "__is_first_stage__":
126 | return None
127 | elif config == "__is_unconditional__":
128 | return None
129 | raise KeyError("Expected key `target` to instantiate.")
130 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
131 |
132 |
133 | def get_obj_from_str(string, reload=False):
134 | module, cls = string.rsplit(".", 1)
135 | if reload:
136 | module_imp = importlib.import_module(module)
137 | importlib.reload(module_imp)
138 | return getattr(importlib.import_module(module, package=None), cls)
139 |
140 |
141 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
142 | # create dummy dataset instance
143 |
144 | # run prefetching
145 | if idx_to_fn:
146 | res = func(data, worker_id=idx)
147 | else:
148 | res = func(data)
149 | Q.put([idx, res])
150 | Q.put("Done")
151 |
152 |
153 | def parallel_data_prefetch(
154 | func: callable,
155 | data,
156 | n_proc,
157 | target_data_type="ndarray",
158 | cpu_intensive=True,
159 | use_worker_id=False,
160 | ):
161 | # if target_data_type not in ["ndarray", "list"]:
162 | # raise ValueError(
163 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
164 | # )
165 | if isinstance(data, np.ndarray) and target_data_type == "list":
166 | raise ValueError("list expected but function got ndarray.")
167 | elif isinstance(data, abc.Iterable):
168 | if isinstance(data, dict):
169 | print(
170 | 'WARNING: "data" argument passed to parallel_data_prefetch is a dict: using only its values and disregarding keys.'
171 | )
172 | data = list(data.values())
173 | if target_data_type == "ndarray":
174 | data = np.asarray(data)
175 | else:
176 | data = list(data)
177 | else:
178 | raise TypeError(
179 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
180 | )
181 |
182 | if cpu_intensive:
183 | Q = mp.Queue(1000)
184 | proc = mp.Process
185 | else:
186 | Q = Queue(1000)
187 | proc = Thread
188 | # spawn processes
189 | if target_data_type == "ndarray":
190 | arguments = [
191 | [func, Q, part, i, use_worker_id]
192 | for i, part in enumerate(np.array_split(data, n_proc))
193 | ]
194 | else:
195 | step = (
196 | int(len(data) / n_proc + 1)
197 | if len(data) % n_proc != 0
198 | else int(len(data) / n_proc)
199 | )
200 | arguments = [
201 | [func, Q, part, i, use_worker_id]
202 | for i, part in enumerate(
203 | [data[i : i + step] for i in range(0, len(data), step)]
204 | )
205 | ]
206 | processes = []
207 | for i in range(n_proc):
208 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
209 | processes += [p]
210 |
211 | # start processes
212 | print(f"Start prefetching...")
213 | import time
214 |
215 | start = time.time()
216 | gather_res = [[] for _ in range(n_proc)]
217 | try:
218 | for p in processes:
219 | p.start()
220 |
221 | k = 0
222 | while k < n_proc:
223 | # get result
224 | res = Q.get()
225 | if res == "Done":
226 | k += 1
227 | else:
228 | gather_res[res[0]] = res[1]
229 |
230 | except Exception as e:
231 | print("Exception: ", e)
232 | for p in processes:
233 | p.terminate()
234 |
235 | raise e
236 | finally:
237 | for p in processes:
238 | p.join()
239 | print(f"Prefetching complete. [{time.time() - start} sec.]")
240 |
241 | if target_data_type == "ndarray":
242 | if not isinstance(gather_res[0], np.ndarray):
243 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
244 |
245 | # order outputs
246 | return np.concatenate(gather_res, axis=0)
247 | elif target_data_type == "list":
248 | out = []
249 | for r in gather_res:
250 | out.extend(r)
251 | return out
252 | else:
253 | return gather_res
254 |
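A minimal usage sketch of instantiate_from_config (illustrative, not part of the repository): it builds an object from the {"target": dotted.class.path, "params": {...}} dicts used throughout the latent-diffusion configs.

    from audioldm2.latent_diffusion.util import instantiate_from_config

    cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
    layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(8, 4)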
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/models_vit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # --------------------------------------------------------
11 |
12 | from functools import partial
13 |
14 | import torch
15 | import torch.nn as nn
16 | import timm.models.vision_transformer
17 |
18 |
19 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
20 | """Vision Transformer with support for global average pooling"""
21 |
22 | def __init__(
23 | self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs
24 | ):
25 | super(VisionTransformer, self).__init__(**kwargs)
26 |
27 | self.global_pool = global_pool
28 | if self.global_pool:
29 | norm_layer = kwargs["norm_layer"]
30 | embed_dim = kwargs["embed_dim"]
31 | self.fc_norm = norm_layer(embed_dim)
32 | del self.norm # remove the original norm
33 | self.mask_2d = mask_2d
34 | self.use_custom_patch = use_custom_patch
35 |
36 | def forward_features(self, x):
37 | B = x.shape[0]
38 | x = self.patch_embed(x)
39 | x = x + self.pos_embed[:, 1:, :]
40 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
41 | cls_tokens = cls_token.expand(
42 | B, -1, -1
43 | ) # stole cls_tokens impl from Phil Wang, thanks
44 | x = torch.cat((cls_tokens, x), dim=1)
45 | x = self.pos_drop(x)
46 |
47 | for blk in self.blocks:
48 | x = blk(x)
49 |
50 | if self.global_pool:
51 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
52 | outcome = self.fc_norm(x)
53 | else:
54 | x = self.norm(x)
55 | outcome = x[:, 0]
56 |
57 | return outcome
58 |
59 | def random_masking(self, x, mask_ratio):
60 | """
61 | Perform per-sample random masking by per-sample shuffling.
62 | Per-sample shuffling is done by argsort random noise.
63 | x: [N, L, D], sequence
64 | """
65 | N, L, D = x.shape # batch, length, dim
66 | len_keep = int(L * (1 - mask_ratio))
67 |
68 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
69 |
70 | # sort noise for each sample
71 | ids_shuffle = torch.argsort(
72 | noise, dim=1
73 | ) # ascend: small is keep, large is remove
74 | ids_restore = torch.argsort(ids_shuffle, dim=1)
75 |
76 | # keep the first subset
77 | ids_keep = ids_shuffle[:, :len_keep]
78 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
79 |
80 | # generate the binary mask: 0 is keep, 1 is remove
81 | mask = torch.ones([N, L], device=x.device)
82 | mask[:, :len_keep] = 0
83 | # unshuffle to get the binary mask
84 | mask = torch.gather(mask, dim=1, index=ids_restore)
85 |
86 | return x_masked, mask, ids_restore
87 |
88 | def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
89 | """
90 | 2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)
91 | Perform per-sample random masking by per-sample shuffling.
92 | Per-sample shuffling is done by argsort random noise.
93 | x: [N, L, D], sequence
94 | """
95 |
96 | N, L, D = x.shape # batch, length, dim
97 | if self.use_custom_patch:
98 | # # for AS
99 | T = 101 # 64,101
100 | F = 12 # 8,12
101 | # # for ESC
102 | # T=50
103 | # F=12
104 | # for SPC
105 | # T=12
106 | # F=12
107 | else:
108 | # ## for AS
109 | T = 64
110 | F = 8
111 | # ## for ESC
112 | # T=32
113 | # F=8
114 | ## for SPC
115 | # T=8
116 | # F=8
117 |
118 | # mask T
119 | x = x.reshape(N, T, F, D)
120 | len_keep_T = int(T * (1 - mask_t_prob))
121 | noise = torch.rand(N, T, device=x.device) # noise in [0, 1]
122 | # sort noise for each sample
123 | ids_shuffle = torch.argsort(
124 | noise, dim=1
125 | ) # ascend: small is keep, large is remove
126 | ids_keep = ids_shuffle[:, :len_keep_T]
127 | index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)
128 | # x_masked = torch.gather(x, dim=1, index=index)
129 | # x_masked = x_masked.reshape(N,len_keep_T*F,D)
130 | x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D
131 |
132 | # mask F
133 | # x = x.reshape(N, T, F, D)
134 | x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D
135 | len_keep_F = int(F * (1 - mask_f_prob))
136 | noise = torch.rand(N, F, device=x.device) # noise in [0, 1]
137 | # sort noise for each sample
138 | ids_shuffle = torch.argsort(
139 | noise, dim=1
140 | ) # ascend: small is keep, large is remove
141 | ids_keep = ids_shuffle[:, :len_keep_F]
142 | # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)
143 | index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)
144 | x_masked = torch.gather(x, dim=1, index=index)
145 | x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D
146 | # x_masked = x_masked.reshape(N,len_keep*T,D)
147 | x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D)
148 |
149 | return x_masked, None, None
150 |
151 | def forward_features_mask(self, x, mask_t_prob, mask_f_prob):
152 | B = x.shape[0] # 4,1,1024,128
153 | x = self.patch_embed(x) # 4, 512, 768
154 |
155 | x = x + self.pos_embed[:, 1:, :]
156 | if self.mask_2d:  # original tested the bound method (always truthy); the flag is what is intended
157 | x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
158 | else:
159 | x, mask, ids_restore = self.random_masking(x, mask_t_prob)
160 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
161 | cls_tokens = cls_token.expand(B, -1, -1)
162 | x = torch.cat((cls_tokens, x), dim=1)
163 | x = self.pos_drop(x)
164 |
165 | # apply Transformer blocks
166 | for blk in self.blocks:
167 | x = blk(x)
168 |
169 | if self.global_pool:
170 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
171 | outcome = self.fc_norm(x)
172 | else:
173 | x = self.norm(x)
174 | outcome = x[:, 0]
175 |
176 | return outcome
177 |
178 | # overwrite original timm
179 | def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):
180 | if mask_t_prob > 0.0 or mask_f_prob > 0.0:
181 | x = self.forward_features_mask(
182 | x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob
183 | )
184 | else:
185 | x = self.forward_features(x)
186 | x = self.head(x)
187 | return x
188 |
189 |
190 | def vit_small_patch16(**kwargs):
191 | model = VisionTransformer(
192 | patch_size=16,
193 | embed_dim=384,
194 | depth=12,
195 | num_heads=6,
196 | mlp_ratio=4,
197 | qkv_bias=True,
198 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
199 | **kwargs
200 | )
201 | return model
202 |
203 |
204 | def vit_base_patch16(**kwargs):
205 | model = VisionTransformer(
206 | patch_size=16,
207 | embed_dim=768,
208 | depth=12,
209 | num_heads=12,
210 | mlp_ratio=4,
211 | qkv_bias=True,
212 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
213 | **kwargs
214 | )
215 | return model
216 |
217 |
218 | def vit_large_patch16(**kwargs):
219 | model = VisionTransformer(
220 | patch_size=16,
221 | embed_dim=1024,
222 | depth=24,
223 | num_heads=16,
224 | mlp_ratio=4,
225 | qkv_bias=True,
226 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
227 | **kwargs
228 | )
229 | return model
230 |
231 |
232 | def vit_huge_patch14(**kwargs):
233 | model = VisionTransformer(
234 | patch_size=14,
235 | embed_dim=1280,
236 | depth=32,
237 | num_heads=16,
238 | mlp_ratio=4,
239 | qkv_bias=True,
240 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
241 | **kwargs
242 | )
243 | return model
244 |
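A minimal usage sketch of the masked ViT encoder (illustrative, not part of the repository). The img_size/in_chans values are assumptions chosen to match the AudioSet patch grid (T=64, F=8) hard-coded in random_masking_2d, and a timm version compatible with this subclass is assumed:

    import torch
    from audioldm2.latent_diffusion.modules.audiomae.models_vit import vit_base_patch16

    model = vit_base_patch16(img_size=(1024, 128), in_chans=1, num_classes=527)
    spec = torch.randn(2, 1, 1024, 128)   # (batch, channel, time frames, mel bins)
    feats = model(spec)                                      # unmasked pass -> (2, 527) logits
    masked = model(spec, mask_t_prob=0.2, mask_f_prob=0.2)   # training-style masked pass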
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 |
15 | # --------------------------------------------------------
16 | # 2D sine-cosine position embedding
17 | # References:
18 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
19 | # MoCo v3: https://github.com/facebookresearch/moco-v3
20 | # --------------------------------------------------------
21 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
22 | """
23 | grid_size: int of the grid height and width
24 | return:
25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
26 | """
27 | grid_h = np.arange(grid_size, dtype=np.float32)
28 | grid_w = np.arange(grid_size, dtype=np.float32)
29 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
30 | grid = np.stack(grid, axis=0)
31 |
32 | grid = grid.reshape([2, 1, grid_size, grid_size])
33 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
34 | if cls_token:
35 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
36 | return pos_embed
37 |
38 |
39 | def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
40 | """
41 | grid_size: int of the grid height and width
42 | return:
43 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
44 | """
45 | grid_h = np.arange(grid_size[0], dtype=np.float32)
46 | grid_w = np.arange(grid_size[1], dtype=np.float32)
47 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
48 | grid = np.stack(grid, axis=0)
49 |
50 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
51 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
52 | if cls_token:
53 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
54 | return pos_embed
55 |
56 |
57 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
58 | assert embed_dim % 2 == 0
59 |
60 | # use half of dimensions to encode grid_h
61 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
62 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
63 |
64 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
65 | return emb
66 |
67 |
68 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
69 | """
70 | embed_dim: output dimension for each position
71 | pos: a list of positions to be encoded: size (M,)
72 | out: (M, D)
73 | """
74 | assert embed_dim % 2 == 0
75 | # omega = np.arange(embed_dim // 2, dtype=np.float)
76 | omega = np.arange(embed_dim // 2, dtype=float)
77 | omega /= embed_dim / 2.0
78 | omega = 1.0 / 10000**omega # (D/2,)
79 |
80 | pos = pos.reshape(-1) # (M,)
81 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
82 |
83 | emb_sin = np.sin(out) # (M, D/2)
84 | emb_cos = np.cos(out) # (M, D/2)
85 |
86 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
87 | return emb
88 |
89 |
90 | # --------------------------------------------------------
91 | # Interpolate position embeddings for high-resolution
92 | # References:
93 | # DeiT: https://github.com/facebookresearch/deit
94 | # --------------------------------------------------------
95 | def interpolate_pos_embed(model, checkpoint_model):
96 | if "pos_embed" in checkpoint_model:
97 | pos_embed_checkpoint = checkpoint_model["pos_embed"]
98 | embedding_size = pos_embed_checkpoint.shape[-1]
99 | num_patches = model.patch_embed.num_patches
100 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
101 | # height (== width) for the checkpoint position embedding
102 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
103 | # height (== width) for the new position embedding
104 | new_size = int(num_patches**0.5)
105 | # class_token and dist_token are kept unchanged
106 | if orig_size != new_size:
107 | print(
108 | "Position interpolate from %dx%d to %dx%d"
109 | % (orig_size, orig_size, new_size, new_size)
110 | )
111 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
112 | # only the position tokens are interpolated
113 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
114 | pos_tokens = pos_tokens.reshape(
115 | -1, orig_size, orig_size, embedding_size
116 | ).permute(0, 3, 1, 2)
117 | pos_tokens = torch.nn.functional.interpolate(
118 | pos_tokens,
119 | size=(new_size, new_size),
120 | mode="bicubic",
121 | align_corners=False,
122 | )
123 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
124 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
125 | checkpoint_model["pos_embed"] = new_pos_embed
126 |
127 |
128 | def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):
129 | if "pos_embed" in checkpoint_model:
130 | pos_embed_checkpoint = checkpoint_model["pos_embed"]
131 | embedding_size = pos_embed_checkpoint.shape[-1]
132 | num_patches = model.patch_embed.num_patches
133 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
134 | # height (== width) for the checkpoint position embedding
135 | # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
136 | # height (== width) for the new position embedding
137 | # new_size = int(num_patches ** 0.5)
138 | # class_token and dist_token are kept unchanged
139 | if orig_size != new_size:
140 | print(
141 | "Position interpolate from %dx%d to %dx%d"
142 | % (orig_size[0], orig_size[1], new_size[0], new_size[1])
143 | )
144 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
145 | # only the position tokens are interpolated
146 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
147 | pos_tokens = pos_tokens.reshape(
148 | -1, orig_size[0], orig_size[1], embedding_size
149 | ).permute(0, 3, 1, 2)
150 | pos_tokens = torch.nn.functional.interpolate(
151 | pos_tokens,
152 | size=(new_size[0], new_size[1]),
153 | mode="bicubic",
154 | align_corners=False,
155 | )
156 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
157 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
158 | checkpoint_model["pos_embed"] = new_pos_embed
159 |
160 |
161 | def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):
162 | if "pos_embed" in checkpoint_model:
163 | pos_embed_checkpoint = checkpoint_model["pos_embed"]
164 | embedding_size = pos_embed_checkpoint.shape[-1]
165 | num_patches = model.patch_embed.num_patches
166 | model.pos_embed.shape[-2] - num_patches  # no-op: value discarded (extra-token count not needed in this variant)
167 | if orig_size != new_size:
168 | print(
169 | "Position interpolate from %dx%d to %dx%d"
170 | % (orig_size[0], orig_size[1], new_size[0], new_size[1])
171 | )
172 | # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
173 | # only the position tokens are interpolated
174 | cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)
175 | pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove
176 | pos_tokens = pos_tokens.reshape(
177 | -1, orig_size[0], orig_size[1], embedding_size
178 | ) # .permute(0, 3, 1, 2)
179 | # pos_tokens = torch.nn.functional.interpolate(
180 | # pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)
181 |
182 | # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
183 | pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff
184 | pos_tokens = pos_tokens.flatten(1, 2)
185 | new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)
186 | checkpoint_model["pos_embed"] = new_pos_embed
187 |
188 |
189 | def interpolate_patch_embed_audio(
190 | model,
191 | checkpoint_model,
192 | orig_channel,
193 | new_channel=1,
194 | kernel_size=(16, 16),
195 | stride=(16, 16),
196 | padding=(0, 0),
197 | ):
198 | if orig_channel != new_channel:
199 | if "patch_embed.proj.weight" in checkpoint_model:
200 | # aggregate 3 channels in rgb ckpt to 1 channel for audio
201 | new_proj_weight = torch.nn.Parameter(
202 | torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze(
203 | 1
204 | )
205 | )
206 | checkpoint_model["patch_embed.proj.weight"] = new_proj_weight
207 |
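A minimal usage sketch of the sin-cos table builder (illustrative, not part of the repository):

    from audioldm2.latent_diffusion.modules.audiomae.util.pos_embed import (
        get_2d_sincos_pos_embed,
    )

    # Fixed (non-learned) embedding for an 8 x 8 patch grid, plus a zero row for the cls token.
    pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=8, cls_token=True)
    print(pos.shape)  # (65, 768)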
--------------------------------------------------------------------------------
/audioldm2/latent_diffusion/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import math
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | from einops import repeat
16 |
17 | from audioldm2.latent_diffusion.util import instantiate_from_config
18 |
19 |
20 | def make_beta_schedule(
21 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
22 | ):
23 | if schedule == "linear":
24 | betas = (
25 | torch.linspace(
26 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
27 | )
28 | ** 2
29 | )
30 |
31 | elif schedule == "cosine":
32 | timesteps = (
33 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
34 | )
35 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
36 | alphas = torch.cos(alphas).pow(2)
37 | alphas = alphas / alphas[0]
38 | betas = 1 - alphas[1:] / alphas[:-1]
39 | betas = np.clip(betas, a_min=0, a_max=0.999)
40 |
41 | elif schedule == "sqrt_linear":
42 | betas = torch.linspace(
43 | linear_start, linear_end, n_timestep, dtype=torch.float64
44 | )
45 | elif schedule == "sqrt":
46 | betas = (
47 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
48 | ** 0.5
49 | )
50 | else:
51 | raise ValueError(f"schedule '{schedule}' unknown.")
52 | return betas.numpy()
53 |
54 |
55 | def make_ddim_timesteps(
56 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
57 | ):
58 | if ddim_discr_method == "uniform":
59 | c = num_ddpm_timesteps // num_ddim_timesteps
60 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
61 | elif ddim_discr_method == "quad":
62 | ddim_timesteps = (
63 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
64 | ).astype(int)
65 | else:
66 | raise NotImplementedError(
67 | f'There is no ddim discretization method called "{ddim_discr_method}"'
68 | )
69 |
70 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
71 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
72 | steps_out = ddim_timesteps + 1
73 | if verbose:
74 | print(f"Selected timesteps for ddim sampler: {steps_out}")
75 | return steps_out
76 |
77 |
78 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
79 | # select alphas for computing the variance schedule
80 | alphas = alphacums[ddim_timesteps]
81 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
82 |
83 | # according to the formula provided in https://arxiv.org/abs/2010.02502
84 | sigmas = eta * np.sqrt(
85 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
86 | )
87 | if verbose:
88 | print(
89 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
90 | )
91 | print(
92 | f"For the chosen value of eta, which is {eta}, "
93 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
94 | )
95 | return sigmas, alphas, alphas_prev
96 |
97 |
98 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
99 | """
100 | Create a beta schedule that discretizes the given alpha_t_bar function,
101 | which defines the cumulative product of (1-beta) over time from t = [0,1].
102 | :param num_diffusion_timesteps: the number of betas to produce.
103 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
104 | produces the cumulative product of (1-beta) up to that
105 | part of the diffusion process.
106 | :param max_beta: the maximum beta to use; use values lower than 1 to
107 | prevent singularities.
108 | """
109 | betas = []
110 | for i in range(num_diffusion_timesteps):
111 | t1 = i / num_diffusion_timesteps
112 | t2 = (i + 1) / num_diffusion_timesteps
113 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
114 | return np.array(betas)
115 |
116 |
117 | def extract_into_tensor(a, t, x_shape):
118 | b, *_ = t.shape
119 | out = a.gather(-1, t).contiguous()
120 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
121 |
122 |
123 | def checkpoint(func, inputs, params, flag):
124 | """
125 | Evaluate a function without caching intermediate activations, allowing for
126 | reduced memory at the expense of extra compute in the backward pass.
127 | :param func: the function to evaluate.
128 | :param inputs: the argument sequence to pass to `func`.
129 | :param params: a sequence of parameters `func` depends on but does not
130 | explicitly take as arguments.
131 | :param flag: if False, disable gradient checkpointing.
132 | """
133 | if flag:
134 | args = tuple(inputs) + tuple(params)
135 | return CheckpointFunction.apply(func, len(inputs), *args)
136 | else:
137 | return func(*inputs)
138 |
139 |
140 | class CheckpointFunction(torch.autograd.Function):
141 | @staticmethod
142 | def forward(ctx, run_function, length, *args):
143 | ctx.run_function = run_function
144 | ctx.input_tensors = list(args[:length])
145 | ctx.input_params = list(args[length:])
146 |
147 | with torch.no_grad():
148 | output_tensors = ctx.run_function(*ctx.input_tensors)
149 | return output_tensors
150 |
151 | @staticmethod
152 | def backward(ctx, *output_grads):
153 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
154 | with torch.enable_grad():
155 | # Fixes a bug where the first op in run_function modifies the
156 | # Tensor storage in place, which is not allowed for detach()'d
157 | # Tensors.
158 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
159 | output_tensors = ctx.run_function(*shallow_copies)
160 | input_grads = torch.autograd.grad(
161 | output_tensors,
162 | ctx.input_tensors + ctx.input_params,
163 | output_grads,
164 | allow_unused=True,
165 | )
166 | del ctx.input_tensors
167 | del ctx.input_params
168 | del output_tensors
169 | return (None, None) + input_grads
170 |
171 |
172 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
173 | """
174 | Create sinusoidal timestep embeddings.
175 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
176 | These may be fractional.
177 | :param dim: the dimension of the output.
178 | :param max_period: controls the minimum frequency of the embeddings.
179 | :return: an [N x dim] Tensor of positional embeddings.
180 | """
181 | if not repeat_only:
182 | half = dim // 2
183 | freqs = torch.exp(
184 | -math.log(max_period)
185 | * torch.arange(start=0, end=half, dtype=torch.float32)
186 | / half
187 | ).to(device=timesteps.device)
188 | args = timesteps[:, None].float() * freqs[None]
189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190 | if dim % 2:
191 | embedding = torch.cat(
192 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
193 | )
194 | else:
195 | embedding = repeat(timesteps, "b -> b d", d=dim)
196 | return embedding
197 |
198 |
199 | def zero_module(module):
200 | """
201 | Zero out the parameters of a module and return it.
202 | """
203 | for p in module.parameters():
204 | p.detach().zero_()
205 | return module
206 |
207 |
208 | def scale_module(module, scale):
209 | """
210 | Scale the parameters of a module and return it.
211 | """
212 | for p in module.parameters():
213 | p.detach().mul_(scale)
214 | return module
215 |
216 |
217 | def mean_flat(tensor):
218 | """
219 | Take the mean over all non-batch dimensions.
220 | """
221 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
222 |
223 |
224 | def normalization(channels):
225 | """
226 | Make a standard normalization layer.
227 | :param channels: number of input channels.
228 | :return: an nn.Module for normalization.
229 | """
230 | return GroupNorm32(32, channels)
231 |
232 |
233 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
234 | class SiLU(nn.Module):
235 | def forward(self, x):
236 | return x * torch.sigmoid(x)
237 |
238 |
239 | class GroupNorm32(nn.GroupNorm):
240 | def forward(self, x):
241 | return super().forward(x.float()).type(x.dtype)
242 |
243 |
244 | def conv_nd(dims, *args, **kwargs):
245 | """
246 | Create a 1D, 2D, or 3D convolution module.
247 | """
248 | if dims == 1:
249 | return nn.Conv1d(*args, **kwargs)
250 | elif dims == 2:
251 | return nn.Conv2d(*args, **kwargs)
252 | elif dims == 3:
253 | return nn.Conv3d(*args, **kwargs)
254 | raise ValueError(f"unsupported dimensions: {dims}")
255 |
256 |
257 | def linear(*args, **kwargs):
258 | """
259 | Create a linear module.
260 | """
261 | return nn.Linear(*args, **kwargs)
262 |
263 |
264 | def avg_pool_nd(dims, *args, **kwargs):
265 | """
266 | Create a 1D, 2D, or 3D average pooling module.
267 | """
268 | if dims == 1:
269 | return nn.AvgPool1d(*args, **kwargs)
270 | elif dims == 2:
271 | return nn.AvgPool2d(*args, **kwargs)
272 | elif dims == 3:
273 | return nn.AvgPool3d(*args, **kwargs)
274 | raise ValueError(f"unsupported dimensions: {dims}")
275 |
276 |
277 | class HybridConditioner(nn.Module):
278 | def __init__(self, c_concat_config, c_crossattn_config):
279 | super().__init__()
280 | self.concat_conditioner = instantiate_from_config(c_concat_config)
281 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
282 |
283 | def forward(self, c_concat, c_crossattn):
284 | c_concat = self.concat_conditioner(c_concat)
285 | c_crossattn = self.crossattn_conditioner(c_crossattn)
286 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
287 |
288 |
289 | def noise_like(shape, device, repeat=False):
290 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
291 | shape[0], *((1,) * (len(shape) - 1))
292 | )
293 | noise = lambda: torch.randn(shape, device=device)
294 | return repeat_noise() if repeat else noise()
295 |
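A minimal usage sketch of the schedule and embedding helpers (illustrative, not part of the repository):

    import torch
    from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
        make_beta_schedule,
        timestep_embedding,
    )

    betas = make_beta_schedule("linear", n_timestep=1000)  # np.ndarray, shape (1000,)
    t = torch.randint(0, 1000, (4,))                       # one diffusion step per batch item
    emb = timestep_embedding(t, dim=256)                   # (4, 256) sinusoidal embeddings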
--------------------------------------------------------------------------------