├── audiosr ├── clap │ ├── __init__.py │ ├── training │ │ ├── __init__.py │ │ ├── audioset_textmap.npy │ │ └── bpe_simple_vocab_16e6.txt.gz │ └── open_clip │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── model_configs │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32.json │ │ ├── ViT-L-14.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── RN101.json │ │ ├── RN50.json │ │ ├── RN50x16.json │ │ ├── RN50x4.json │ │ ├── RN101-quickgelu.json │ │ ├── RN50-quickgelu.json │ │ ├── PANN-10.json │ │ ├── PANN-14.json │ │ ├── PANN-6.json │ │ ├── HTSAT-base.json │ │ ├── HTSAT-large.json │ │ ├── HTSAT-tiny.json │ │ ├── HTSAT-tiny-win-1536.json │ │ ├── PANN-14-fmax-18k.json │ │ ├── PANN-14-fmax-8k-20s.json │ │ ├── PANN-14-win-1536.json │ │ └── PANN-14-tiny-transformer.json │ │ ├── __init__.py │ │ ├── transform.py │ │ ├── timm_model.py │ │ ├── openai.py │ │ ├── pretrained.py │ │ ├── tokenizer.py │ │ ├── feature_fusion.py │ │ ├── factory.py │ │ └── utils.py ├── latent_diffusion │ ├── __init__.py │ ├── models │ │ └── __init__.py │ ├── modules │ │ ├── __init__.py │ │ ├── audiomae │ │ │ ├── __init__.py │ │ │ ├── util │ │ │ │ ├── lr_sched.py │ │ │ │ ├── crop.py │ │ │ │ ├── datasets.py │ │ │ │ ├── lars.py │ │ │ │ ├── stat.py │ │ │ │ ├── lr_decay.py │ │ │ │ ├── patch_embed.py │ │ │ │ └── pos_embed.py │ │ │ ├── AudioMAE.py │ │ │ └── models_vit.py │ │ ├── encoders │ │ │ └── __init__.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ └── util.py │ │ ├── phoneme_encoder │ │ │ ├── __init__.py │ │ │ ├── text │ │ │ │ ├── symbols.py │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ └── cleaners.py │ │ │ ├── encoder.py │ │ │ └── commons.py │ │ └── ema.py │ └── util.py ├── latent_encoder │ ├── __init__.py │ └── autoencoder.py ├── utilities │ ├── data │ │ └── __init__.py │ ├── __init__.py │ ├── audio │ │ ├── __init__.py │ │ ├── tools.py │ │ ├── audio_processing.py │ │ └── stft.py │ └── model.py ├── __init__.py ├── hifigan │ ├── __init__.py │ ├── LICENSE │ └── models.py ├── __main__.py ├── pipeline.py └── lowpass.py ├── .gitignore ├── datasets ├── 8khz_sample │ ├── 1.mp3 │ ├── 2.mp3 │ └── 3.mp3 ├── 8khz_sample2 │ ├── 1.wav │ ├── 2.wav │ └── 3.wav └── 8khz_sample3 │ └── 1.wav ├── requirements.txt ├── Readme.md └── main.py /audiosr/clap/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/clap/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/latent_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
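The tree above shows the full package layout. Going by `audiosr/__init__.py` and `audiosr/__main__.py` later in this dump, the public entry points are `build_model`, `super_resolution`, and `save_wave`; the sketch below simply mirrors the flow in `__main__.py` and is illustrative only (the output directory name is a placeholder, not part of the repository).

```python
# Minimal sketch mirroring audiosr/__main__.py further down in this dump.
# The "output" directory name is illustrative; the input file is one of the
# bundled samples under ./datasets.
import torch
from audiosr import build_model, super_resolution, save_wave

torch.set_float32_matmul_precision("high")

model = build_model(model_name="basic", device="auto")  # "basic" or "speech"
waveform = super_resolution(
    model,
    "datasets/8khz_sample/1.mp3",  # any low-sample-rate input file
    seed=42,
    guidance_scale=3.5,
    ddim_steps=50,
    latent_t_per_second=12.8,
)
save_wave(waveform, "output", name="1_upsampled", samplerate=48000)
```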
/audiosr/latent_diffusion/modules/audiomae/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/phoneme_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audiosr/utilities/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | flagged 4 | output 5 | gradio_cached* 6 | *egg-info 7 | build* -------------------------------------------------------------------------------- /audiosr/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import * 2 | from .data import * 3 | from .model import * 4 | -------------------------------------------------------------------------------- /datasets/8khz_sample/1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/datasets/8khz_sample/1.mp3 -------------------------------------------------------------------------------- /datasets/8khz_sample/2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/datasets/8khz_sample/2.mp3 -------------------------------------------------------------------------------- /datasets/8khz_sample/3.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/datasets/8khz_sample/3.mp3 -------------------------------------------------------------------------------- /audiosr/utilities/audio/__init__.py: -------------------------------------------------------------------------------- 1 | from .audio_processing import * 2 | from .stft import * 3 | from .tools import * 4 | -------------------------------------------------------------------------------- /datasets/8khz_sample2/1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/datasets/8khz_sample2/1.wav -------------------------------------------------------------------------------- /datasets/8khz_sample2/2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/datasets/8khz_sample2/2.wav 
-------------------------------------------------------------------------------- /datasets/8khz_sample2/3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/datasets/8khz_sample2/3.wav -------------------------------------------------------------------------------- /datasets/8khz_sample3/1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/datasets/8khz_sample3/1.wav -------------------------------------------------------------------------------- /audiosr/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import seed_everything, save_wave, get_time, get_duration, read_list 2 | from .pipeline import * 3 | -------------------------------------------------------------------------------- /audiosr/clap/training/audioset_textmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/audiosr/clap/training/audioset_textmap.npy -------------------------------------------------------------------------------- /audiosr/clap/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/audiosr/clap/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /audiosr/clap/training/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ORI-Muchim/AudioSR-Upsampling/HEAD/audiosr/clap/training/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /audiosr/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models_v2 import Generator 2 | from .models import Generator as Generator_old 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | 
"image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/huggingface/diffusers.git 2 | git+https://github.com/huggingface/transformers.git 3 | --extra-index-url https://download.pytorch.org/whl/cu117 4 | torch >= 2.0 5 | huggingface_hub 6 | transformers 7 | soundfile 8 | librosa==0.9.2 9 | unidecode 10 | progressbar 11 | phonemizer 12 | ftfy 13 | timm 14 | einops 15 | torchlibrosa>=0.0.9 16 | chardet 17 | gradio 18 | numpy<=1.23.5 19 | pandas 20 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- 
/audiosr/clap/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/PANN-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn10" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/PANN-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/PANN-6.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn6" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/phoneme_encoder/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 
5 | """ 6 | _pad = "_" 7 | _punctuation = ';:,.!?¡¿—…"«»“” ' 8 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 9 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | 11 | 12 | # Export all symbols: 13 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 14 | 15 | # Special symbol ids 16 | SPACE_ID = symbols.index(" ") 17 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/HTSAT-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "base" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/HTSAT-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "large" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/HTSAT-tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/PANN-14-fmax-18k.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | 
"hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 18000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 960000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 360, 10 | "fmin": 50, 11 | "fmax": 8000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/PANN-14-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/model_configs/PANN-14-tiny-transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 4 22 | } 23 | } -------------------------------------------------------------------------------- /audiosr/clap/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import ( 2 | list_models, 3 | create_model, 4 | create_model_and_transforms, 5 | add_model_config, 6 | ) 7 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics 8 | from .model import ( 9 | CLAP, 10 | CLAPTextCfg, 11 | CLAPVisionCfg, 12 | CLAPAudioCfp, 13 | convert_weights_to_fp16, 14 | trace_model, 15 | ) 16 | from .openai import load_openai_model, list_openai_models 17 | from .pretrained import ( 18 | list_pretrained, 19 | list_pretrained_tag_models, 20 | list_pretrained_model_tags, 21 | get_pretrained_url, 22 | download_pretrained, 23 | ) 24 | from .tokenizer import SimpleTokenizer, tokenize 25 | from .transform import image_transform 26 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | 10 | def adjust_learning_rate(optimizer, epoch, args): 11 | """Decay the learning rate with half-cycle cosine after warmup""" 12 | if epoch < args.warmup_epochs: 13 | lr = args.lr * epoch / args.warmup_epochs 14 | else: 15 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( 16 | 1.0 17 | + math.cos( 18 | math.pi 19 | * (epoch - args.warmup_epochs) 20 | / (args.epochs - args.warmup_epochs) 21 | ) 22 | ) 23 | for param_group in optimizer.param_groups: 24 | if "lr_scale" in param_group: 25 | param_group["lr"] = lr * param_group["lr_scale"] 26 | else: 27 | param_group["lr"] = lr 28 | return lr 29 | -------------------------------------------------------------------------------- /audiosr/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/phoneme_encoder/text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import ( 2 | Normalize, 3 | Compose, 4 | RandomResizedCrop, 5 | InterpolationMode, 6 | ToTensor, 7 | Resize, 8 | CenterCrop, 9 | ) 10 | 11 | 12 | def _convert_to_rgb(image): 13 | return image.convert("RGB") 14 | 15 | 16 | def image_transform( 17 | image_size: int, 18 | is_train: bool, 19 | mean=(0.48145466, 0.4578275, 0.40821073), 20 | std=(0.26862954, 0.26130258, 0.27577711), 21 | ): 22 | normalize = Normalize(mean=mean, std=std) 23 | if is_train: 24 | return Compose( 25 | [ 26 | RandomResizedCrop( 27 | image_size, 28 | scale=(0.9, 1.0), 29 | interpolation=InterpolationMode.BICUBIC, 30 | ), 31 | _convert_to_rgb, 32 | ToTensor(), 33 | normalize, 34 | ] 35 | ) 36 | else: 37 | return Compose( 38 | [ 39 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 40 | CenterCrop(image_size), 41 | _convert_to_rgb, 42 | ToTensor(), 43 | normalize, 44 | ] 45 | ) 46 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 
19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | 23 | @staticmethod 24 | def get_params(img, scale, ratio): 25 | width, height = F._get_image_size(img) 26 | area = height * width 27 | 28 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 29 | log_ratio = torch.log(torch.tensor(ratio)) 30 | aspect_ratio = torch.exp( 31 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 32 | ).item() 33 | 34 | w = int(round(math.sqrt(target_area * aspect_ratio))) 35 | h = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | w = min(w, width) 38 | h = min(h, height) 39 | 40 | i = torch.randint(0, height - h + 1, size=(1,)).item() 41 | j = torch.randint(0, width - w + 1, size=(1,)).item() 42 | 43 | return i, j, h, w 44 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/phoneme_encoder/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | import audiosr.latent_diffusion.modules.phoneme_encoder.commons as commons 6 | import audiosr.latent_diffusion.modules.phoneme_encoder.attentions as attentions 7 | 8 | 9 | class TextEncoder(nn.Module): 10 | def __init__( 11 | self, 12 | n_vocab, 13 | out_channels=192, 14 | hidden_channels=192, 15 | filter_channels=768, 16 | n_heads=2, 17 | n_layers=6, 18 | kernel_size=3, 19 | p_dropout=0.1, 20 | ): 21 | super().__init__() 22 | self.n_vocab = n_vocab 23 | self.out_channels = out_channels 24 | self.hidden_channels = hidden_channels 25 | self.filter_channels = filter_channels 26 | self.n_heads = n_heads 27 | self.n_layers = n_layers 28 | self.kernel_size = kernel_size 29 | self.p_dropout = p_dropout 30 | 31 | self.emb = nn.Embedding(n_vocab, hidden_channels) 32 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 33 | 34 | self.encoder = attentions.Encoder( 35 | hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout 36 | ) 37 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 38 | 39 | def forward(self, x, x_lengths): 40 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 41 | x = torch.transpose(x, 1, -1) # [b, h, t] 42 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( 43 | x.dtype 44 | ) 45 | 46 | x = self.encoder(x * x_mask, x_mask) 47 | stats = self.proj(x) * x_mask 48 | 49 | m, logs = torch.split(stats, self.out_channels, dim=1) 50 | return x, m, logs, x_mask 51 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/phoneme_encoder/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from audiosr.latent_diffusion.modules.phoneme_encoder.text import cleaners 3 | from audiosr.latent_diffusion.modules.phoneme_encoder.text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | cleaner = getattr(cleaners, "english_cleaners2") 11 | 12 | 13 | def text_to_sequence(text, cleaner_names): 14 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
15 | Args: 16 | text: string to convert to a sequence 17 | cleaner_names: names of the cleaner functions to run the text through 18 | Returns: 19 | List of integers corresponding to the symbols in the text 20 | """ 21 | sequence = [] 22 | 23 | clean_text = _clean_text(text, cleaner_names) 24 | for symbol in clean_text: 25 | symbol_id = _symbol_to_id[symbol] 26 | sequence += [symbol_id] 27 | return sequence 28 | 29 | 30 | def cleaned_text_to_sequence(cleaned_text): 31 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 32 | Args: 33 | text: string to convert to a sequence 34 | Returns: 35 | List of integers corresponding to the symbols in the text 36 | """ 37 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 38 | return sequence 39 | 40 | 41 | def sequence_to_text(sequence): 42 | """Converts a sequence of IDs back to a string""" 43 | result = "" 44 | for symbol_id in sequence: 45 | s = _id_to_symbol[symbol_id] 46 | result += s 47 | return result 48 | 49 | 50 | def _clean_text(text, cleaner_names): 51 | text = cleaner(text) 52 | return text 53 | -------------------------------------------------------------------------------- /audiosr/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import argparse 5 | from audiosr import super_resolution, build_model, save_wave 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | logger = logging.getLogger(__name__) 9 | 10 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 11 | torch.set_float32_matmul_precision("high") 12 | 13 | 14 | def main(args): 15 | audiosr = build_model(model_name=args.model_name, device="auto") 16 | 17 | waveform = super_resolution( 18 | audiosr, 19 | args.input_path, 20 | seed=42, 21 | guidance_scale=3.5, 22 | ddim_steps=50, 23 | latent_t_per_second=12.8 24 | ) 25 | 26 | save_wave(waveform, args.save_path, name="output", samplerate=48000) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description='Perform Upsampling an audio files using audiosr package.') 31 | 32 | parser.add_argument('-i', '--input_path', required=True, help='Path to the input waveform file.') 33 | parser.add_argument('-s', '--save_path', required=True, help='Path to save the output waveform file.') 34 | parser.add_argument('--model_name', choices=['basic', 'speech'], default='speech', help='Name of the model to be used.') 35 | parser.add_argument('-d', '--device', default="auto", help='The device for computation. If not specified, the script will automatically choose the device based on your environment.') 36 | parser.add_argument('--ddim_steps', type=int, default=50, help='The sampling step for DDIM.') 37 | parser.add_argument('-gs', '--guidance_scale', type=float, default=3.5, help='Guidance scale (Large => better quality and relavancy to text; Small => better diversity).') 38 | parser.add_argument('--seed', type=int, default=42, help='Change this value (any integer number) will lead to a different generation result.') 39 | parser.add_argument('-il', '--input_file_list', help='A file that contains all audio files that need to perform audio super resolution.') 40 | 41 | args = parser.parse_args() 42 | main(args) 43 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, "train" if is_train else "val") 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation="bicubic", 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize( 60 | size, interpolation=PIL.Image.BICUBIC 61 | ), # to maintain same ratio w.r.t. 224 images 62 | ) 63 | t.append(transforms.CenterCrop(args.input_size)) 64 | 65 | t.append(transforms.ToTensor()) 66 | t.append(transforms.Normalize(mean, std)) 67 | return transforms.Compose(t) 68 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
17 | """ 18 | 19 | def __init__( 20 | self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001 21 | ): 22 | defaults = dict( 23 | lr=lr, 24 | weight_decay=weight_decay, 25 | momentum=momentum, 26 | trust_coefficient=trust_coefficient, 27 | ) 28 | super().__init__(params, defaults) 29 | 30 | @torch.no_grad() 31 | def step(self): 32 | for g in self.param_groups: 33 | for p in g["params"]: 34 | dp = p.grad 35 | 36 | if dp is None: 37 | continue 38 | 39 | if p.ndim > 1: # if not normalization gamma/beta or bias 40 | dp = dp.add(p, alpha=g["weight_decay"]) 41 | param_norm = torch.norm(p) 42 | update_norm = torch.norm(dp) 43 | one = torch.ones_like(param_norm) 44 | q = torch.where( 45 | param_norm > 0.0, 46 | torch.where( 47 | update_norm > 0, 48 | (g["trust_coefficient"] * param_norm / update_norm), 49 | one, 50 | ), 51 | one, 52 | ) 53 | dp = dp.mul(q) 54 | 55 | param_state = self.state[p] 56 | if "mu" not in param_state: 57 | param_state["mu"] = torch.zeros_like(p) 58 | mu = param_state["mu"] 59 | mu.mul_(g["momentum"]).add_(dp) 60 | p.add_(mu, alpha=-g["lr"]) 61 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/util/stat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | from sklearn import metrics 4 | import torch 5 | 6 | 7 | def d_prime(auc): 8 | standard_normal = stats.norm() 9 | d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) 10 | return d_prime 11 | 12 | 13 | @torch.no_grad() 14 | def concat_all_gather(tensor): 15 | """ 16 | Performs all_gather operation on the provided tensors. 17 | *** Warning ***: torch.distributed.all_gather has no gradient. 18 | """ 19 | tensors_gather = [ 20 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 21 | ] 22 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 23 | 24 | output = torch.cat(tensors_gather, dim=0) 25 | return output 26 | 27 | 28 | def calculate_stats(output, target): 29 | """Calculate statistics including mAP, AUC, etc. 30 | 31 | Args: 32 | output: 2d array, (samples_num, classes_num) 33 | target: 2d array, (samples_num, classes_num) 34 | 35 | Returns: 36 | stats: list of statistic of each class. 
37 | """ 38 | 39 | classes_num = target.shape[-1] 40 | stats = [] 41 | 42 | # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet 43 | acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) 44 | 45 | # Class-wise statistics 46 | for k in range(classes_num): 47 | # Average precision 48 | avg_precision = metrics.average_precision_score( 49 | target[:, k], output[:, k], average=None 50 | ) 51 | 52 | # AUC 53 | # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None) 54 | 55 | # Precisions, recalls 56 | (precisions, recalls, thresholds) = metrics.precision_recall_curve( 57 | target[:, k], output[:, k] 58 | ) 59 | 60 | # FPR, TPR 61 | (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k]) 62 | 63 | save_every_steps = 1000 # Sample statistics to reduce size 64 | dict = { 65 | "precisions": precisions[0::save_every_steps], 66 | "recalls": recalls[0::save_every_steps], 67 | "AP": avg_precision, 68 | "fpr": fpr[0::save_every_steps], 69 | "fnr": 1.0 - tpr[0::save_every_steps], 70 | # 'auc': auc, 71 | # note acc is not class-wise, this is just to keep consistent with other metrics 72 | "acc": acc, 73 | } 74 | stats.append(dict) 75 | 76 | return stats 77 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # AudioSR: Versatile Audio Super-resolution at Scale 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2309.07314-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2309.07314) [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://audioldm.github.io/audiosr) 4 | 5 | Pass your audio in, AudioSR will make it high fidelity! 6 | 7 | Work on all types of audio (e.g., music, speech, dog, raining, ...) & all sampling rates. 8 | 9 | [Original Repo](https://github.com/haoheliu/versatile_audio_super_resolution) 10 | 11 | 12 | ## Table of Contents 13 | - [Installation](#installation) 14 | - [Prepare_Datasets](#prepare_datasets) 15 | - [Usage](#usage) 16 | - [Reference](#reference) 17 | 18 | 19 | ## Installation 20 | 1. **Create an Anaconda environment:** 21 | 22 | ```sh 23 | conda create -n audiosr python=3.9 24 | ``` 25 | 26 | 2. **Activate the environment:** 27 | 28 | ```sh 29 | conda activate audiosr 30 | ``` 31 | 32 | 3. **Clone this repository to your local machine:** 33 | 34 | ```sh 35 | git clone https://github.com/ORI-Muchim/AudioSR-Upsampling.git 36 | ``` 37 | 38 | 4. **Navigate to the cloned directory:** 39 | 40 | ```sh 41 | cd AudioSR-Upsampling 42 | ``` 43 | 44 | 5. **Install the necessary dependencies:** 45 | 46 | ```sh 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | 51 | ## Prepare_Datasets 52 | 53 | Place the audio files as follows. 54 | 55 | .mp3 or .wav files are okay. 56 | 57 | ``` 58 | AudioSR-Upsampling 59 | ├────datasets 60 | │ ├───speaker0 61 | │ │ ├────1.mp3 62 | │ │ └────2.mp3 63 | │ ├───speaker1 64 | │ │ ├───1.wav 65 | │ │ └───1.wav 66 | │ ├───speaker2 67 | │ │ ├────1.wav 68 | │ └───└────1.wav 69 | ├────audiosr 70 | ├────.gitignore 71 | ├────main.py 72 | ├────Readme.md 73 | └────requirements.txt 74 | ``` 75 | 76 | This is just an example, and it's okay to add more speakers. 77 | 78 | ### When you put audio datasets in one folder, please unify all the extensions into one. 
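
As a quick check before running `main.py`, the sketch below walks each speaker folder and flags mixed extensions. It is only a hedged helper, not part of this repository, and it assumes the example `datasets/` layout shown above.

```python
# Hedged sketch: confirm each speaker folder under ./datasets uses a single
# audio extension, as the note above asks. Folder names are the example ones.
from pathlib import Path

for speaker_dir in sorted(Path("datasets").iterdir()):
    if not speaker_dir.is_dir():
        continue
    exts = {f.suffix.lower() for f in speaker_dir.iterdir() if f.is_file()}
    if len(exts) > 1:
        print(f"{speaker_dir.name}: mixed extensions {sorted(exts)} - please unify")
    else:
        print(f"{speaker_dir.name}: OK {sorted(exts)}")
```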
79 | 80 | 81 | ## Usage 82 | 83 | ```sh 84 | python main.py 85 | ``` 86 | 87 | ## Reference 88 | 89 | Thank you for [falsewinnet](https://github.com/falseywinchnet) for helping me create `./audiosr/__main__.py`. 90 | 91 | 92 | 93 | If you find this repo useful, please consider citing: 94 | ```bibtex 95 | @article{liu2023audiosr, 96 | title={{AudioSR}: Versatile Audio Super-resolution at Scale}, 97 | author={Liu, Haohe and Chen, Ke and Tian, Qiao and Wang, Wenwu and Plumbley, Mark D}, 98 | journal={arXiv preprint arXiv:2309.07314}, 99 | year={2023} 100 | } 101 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | 13 | def param_groups_lrd( 14 | model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 15 | ): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0.0 35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ["cls_token", "pos_embed"]: 70 | return 0 71 | elif name.startswith("patch_embed"): 72 | return 0 73 | elif name.startswith("blocks"): 74 | return int(name.split(".")[1]) + 1 75 | else: 76 | return num_layers 77 | -------------------------------------------------------------------------------- /audiosr/utilities/audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | import torchaudio 5 | 6 | from audiosr.utilities.audio.audio_processing 
import griffin_lim 7 | 8 | 9 | def pad_wav(waveform, segment_length): 10 | waveform_length = waveform.shape[-1] 11 | assert waveform_length > 100, "Waveform is too short, %s" % waveform_length 12 | if segment_length is None or waveform_length == segment_length: 13 | return waveform 14 | elif waveform_length > segment_length: 15 | return waveform[:segment_length] 16 | elif waveform_length < segment_length: 17 | temp_wav = np.zeros((1, segment_length)) 18 | temp_wav[:, :waveform_length] = waveform 19 | return temp_wav 20 | 21 | 22 | def normalize_wav(waveform): 23 | waveform = waveform - np.mean(waveform) 24 | waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) 25 | return waveform * 0.5 26 | 27 | 28 | def read_wav_file(filename, segment_length): 29 | # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower 30 | waveform, sr = torchaudio.load(filename) # Faster!!! 31 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) 32 | waveform = waveform.numpy()[0, ...] 33 | waveform = normalize_wav(waveform) 34 | waveform = waveform[None, ...] 35 | waveform = pad_wav(waveform, segment_length) 36 | 37 | waveform = waveform / np.max(np.abs(waveform)) 38 | waveform = 0.5 * waveform 39 | 40 | return waveform 41 | 42 | 43 | def get_mel_from_wav(audio, _stft): 44 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 45 | audio = torch.autograd.Variable(audio, requires_grad=False) 46 | melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio) 47 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 48 | magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32) 49 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 50 | return melspec, magnitudes, energy 51 | 52 | 53 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 54 | mel = torch.stack([mel]) 55 | mel_decompress = _stft.spectral_de_normalize(mel) 56 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 57 | spec_from_mel_scaling = 1000 58 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 59 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 60 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 61 | 62 | audio = griffin_lim( 63 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters 64 | ) 65 | 66 | audio = audio.squeeze() 67 | audio = audio.cpu().numpy() 68 | audio_path = out_filename 69 | write(audio_path, _stft.sampling_rate, audio) 70 | -------------------------------------------------------------------------------- /audiosr/utilities/audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 
22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(data=win_sq, size=n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return normalize_fun(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | 
one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_( 47 | one_minus_decay * (shadow_params[sname] - m_param[key]) 48 | ) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.mean( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.mean( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return 
torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/phoneme_encoder/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | """ 14 | 15 | import re 16 | from unidecode import unidecode 17 | from phonemizer import phonemize 18 | 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r"\s+") 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [ 25 | (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) 26 | for x in [ 27 | ("mrs", "misess"), 28 | ("mr", "mister"), 29 | ("dr", "doctor"), 30 | ("st", "saint"), 31 | ("co", "company"), 32 | ("jr", "junior"), 33 | ("maj", "major"), 34 | ("gen", "general"), 35 | ("drs", "doctors"), 36 | ("rev", "reverend"), 37 | ("lt", "lieutenant"), 38 | ("hon", "honorable"), 39 | ("sgt", "sergeant"), 40 | ("capt", "captain"), 41 | ("esq", "esquire"), 42 | ("ltd", "limited"), 43 | ("col", "colonel"), 44 | ("ft", "fort"), 45 | ] 46 | ] 47 | 48 | 49 | def expand_abbreviations(text): 50 | for regex, replacement in _abbreviations: 51 | text = re.sub(regex, replacement, text) 52 | return text 53 | 54 | 55 | def expand_numbers(text): 56 | return normalize_numbers(text) 57 | 58 | 59 | def lowercase(text): 60 | return text.lower() 61 | 62 | 63 | def collapse_whitespace(text): 64 | return re.sub(_whitespace_re, " ", text) 65 | 66 | 67 | def convert_to_ascii(text): 68 | return unidecode(text) 69 | 70 | 71 | def basic_cleaners(text): 72 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 73 | text = lowercase(text) 74 | text = collapse_whitespace(text) 75 | return text 76 | 77 | 78 | def transliteration_cleaners(text): 79 | """Pipeline for non-English text that transliterates to ASCII.""" 80 | text = convert_to_ascii(text) 81 | text = lowercase(text) 82 | text = collapse_whitespace(text) 83 | return text 84 | 85 | 86 | def english_cleaners(text): 87 | """Pipeline for English text, including abbreviation expansion.""" 88 | text = convert_to_ascii(text) 89 | text = lowercase(text) 90 | text = expand_abbreviations(text) 91 | phonemes = phonemize(text, language="en-us", backend="espeak", strip=True) 92 | phonemes = collapse_whitespace(phonemes) 93 | return phonemes 94 | 95 | 96 | def english_cleaners2(text): 97 | """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" 98 | text = convert_to_ascii(text) 99 | text = lowercase(text) 100 | text = expand_abbreviations(text) 101 | phonemes = phonemize( 102 | text, 103 | language="en-us", 104 | backend="espeak", 105 | strip=True, 106 | preserve_punctuation=True, 107 | with_stress=True, 108 | ) 109 | phonemes = collapse_whitespace(phonemes) 110 | return phonemes 111 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import json 5 | import tempfile 6 | from audiosr import build_model, super_resolution, save_wave 7 | 8 | 9 | def get_sample_rate(file_path): 10 | command = [ 11 | 'ffprobe', 12 | '-v', 'error', 13 | '-select_streams', 'a:0', 14 | '-show_entries', 'stream=sample_rate', 15 | '-of', 'json', 16 | file_path 17 | ] 18 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 19 | output = json.loads(result.stdout) 20 | return int(output['streams'][0]['sample_rate']) 21 | 22 | 23 | 24 | def get_duration(file_location): 25 | command = [ 26 | 'ffprobe', 27 | '-v', 'error', 28 | '-select_streams', 'a:0', 29 | '-show_entries', 'format=duration', 30 | '-sexagesimal', 31 | '-of', 'json', 32 | file_location 33 | ] 34 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 35 | output = json.loads(result.stdout) 36 | return output['format']['duration'] 37 | 38 | 39 | def remove_silence(duration, input_path, output_path): 40 | command = [ 41 | 'ffmpeg', 42 | '-ss', '00:00:00', 43 | '-i', input_path, 44 | '-t', duration, 45 | '-c', 'copy', 46 | output_path 47 | ] 48 | subprocess.run(command) 49 | os.remove(input_path) 50 | 51 | 52 | tmp_dir = tempfile.gettempdir() 53 | datasets_path = './datasets/' 54 | speakers = [d for d in os.listdir(datasets_path) if os.path.isdir(os.path.join(datasets_path, d))] 55 | audiosr = build_model(model_name='speech', device="auto") 56 | 57 | for speaker in speakers: 58 | folder_path = os.path.join(datasets_path, speaker) 59 | source_wavs_folder = os.path.join(folder_path, 'source_wavs') 60 | if not os.path.exists(source_wavs_folder): 61 | os.makedirs(source_wavs_folder) 62 | 63 | files = os.listdir(folder_path) 64 | 65 | for file in files: 66 | input_path = os.path.join(folder_path, file) 67 | base_name, _ = os.path.splitext(file) 68 | wav_file_name = f"{base_name}.wav" 69 | wav_path = os.path.join(folder_path, wav_file_name) 70 | 71 | if file.endswith('.mp3'): 72 | try: 73 | sample_rate = get_sample_rate(input_path) 74 | subprocess.run(["ffmpeg", "-i", input_path, "-ar", str(sample_rate), "-ac", "1", wav_path], check=True) 75 | os.remove(input_path) 76 | except subprocess.CalledProcessError as e: 77 | print(f"Error occurred while converting {input_path} to WAV: {e}") 78 | except Exception as e: 79 | print(f"An error occurred: {e}") 80 | 81 | if file.endswith('.wav') or file.endswith('.mp3'): 82 | shutil.move(wav_path, os.path.join(source_wavs_folder, wav_file_name)) 83 | 84 | for speaker in speakers: 85 | folder_path = os.path.join(datasets_path, speaker) 86 | wavs_folder = os.path.join(folder_path, 'wavs') 87 | source_wavs_folder = os.path.join(folder_path, 'source_wavs') 88 | if not os.path.exists(wavs_folder): 89 | os.makedirs(wavs_folder) 90 | 91 | source_wavs_files = os.listdir(source_wavs_folder) 92 | 93 | for file in source_wavs_files: 94 | if file.endswith('.wav'): 95 | input_path = os.path.join(source_wavs_folder, file) 
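            # Descriptive note (added): for each source WAV the loop below infers the
            # original duration, runs AudioSR super_resolution() to produce a 48 kHz
            # waveform, writes it to a temp file via save_wave(), and then uses the
            # ffmpeg-based remove_silence() to trim the result back to the source
            # duration before placing it in the speaker's wavs/ folder (the model
            # output may be slightly longer than the input, hence the trim).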
96 | input_filename = os.path.basename(input_path) 97 | duration = get_duration(input_path) 98 | save_path = os.path.join(wavs_folder) 99 | try: 100 | waveform = super_resolution( 101 | audiosr, 102 | input_path, 103 | seed=42, 104 | guidance_scale=3.5, 105 | ddim_steps=50, 106 | latent_t_per_second=12.8 107 | ) 108 | base_name, _ = os.path.splitext(file) 109 | tmp_file_path = os.path.join(tmp_dir, f"{base_name}.wav") 110 | save_wave(waveform, tmp_dir, name=base_name, samplerate=48000) 111 | remove_silence(duration, tmp_file_path, os.path.join(save_path, input_filename)) 112 | 113 | except Exception as e: 114 | print(f"An error occurred: {e}") 115 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import ( 14 | AttentionPool2d as AbsAttentionPool2d, 15 | ) 16 | except ImportError: 17 | timm = None 18 | 19 | from .utils import freeze_batch_norm_2d 20 | 21 | 22 | class TimmModel(nn.Module): 23 | """timm model adapter 24 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model_name, 30 | embed_dim, 31 | image_size=224, 32 | pool="avg", 33 | proj="linear", 34 | drop=0.0, 35 | pretrained=False, 36 | ): 37 | super().__init__() 38 | if timm is None: 39 | raise RuntimeError("Please `pip install timm` to use timm models.") 40 | 41 | self.image_size = to_2tuple(image_size) 42 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 43 | feat_size = self.trunk.default_cfg.get("pool_size", None) 44 | feature_ndim = 1 if not feat_size else 2 45 | if pool in ("abs_attn", "rot_attn"): 46 | assert feature_ndim == 2 47 | # if attn pooling used, remove both classifier and default pool 48 | self.trunk.reset_classifier(0, global_pool="") 49 | else: 50 | # reset global pool if pool config set, otherwise leave as network default 51 | reset_kwargs = dict(global_pool=pool) if pool else {} 52 | self.trunk.reset_classifier(0, **reset_kwargs) 53 | prev_chs = self.trunk.num_features 54 | 55 | head_layers = OrderedDict() 56 | if pool == "abs_attn": 57 | head_layers["pool"] = AbsAttentionPool2d( 58 | prev_chs, feat_size=feat_size, out_features=embed_dim 59 | ) 60 | prev_chs = embed_dim 61 | elif pool == "rot_attn": 62 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 63 | prev_chs = embed_dim 64 | else: 65 | assert proj, "projection layer needed if non-attention pooling is used." 
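            # Descriptive note (added): with plain ("avg"/"") or no pooling the trunk
            # emits a prev_chs-dimensional feature vector, so the linear or MLP head
            # built below is needed to project it to the shared embedding dimension.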
66 | 67 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 68 | if proj == "linear": 69 | head_layers["drop"] = nn.Dropout(drop) 70 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim) 71 | elif proj == "mlp": 72 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 73 | 74 | self.head = nn.Sequential(head_layers) 75 | 76 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 77 | """lock modules 78 | Args: 79 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 80 | """ 81 | if not unlocked_groups: 82 | # lock full model 83 | for param in self.trunk.parameters(): 84 | param.requires_grad = False 85 | if freeze_bn_stats: 86 | freeze_batch_norm_2d(self.trunk) 87 | else: 88 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 89 | try: 90 | # FIXME import here until API stable and in an official release 91 | from timm.models.helpers import group_parameters, group_modules 92 | except ImportError: 93 | raise RuntimeError( 94 | "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" 95 | ) 96 | matcher = self.trunk.group_matcher() 97 | gparams = group_parameters(self.trunk, matcher) 98 | max_layer_id = max(gparams.keys()) 99 | max_layer_id = max_layer_id - unlocked_groups 100 | for group_idx in range(max_layer_id + 1): 101 | group = gparams[group_idx] 102 | for param in group: 103 | self.trunk.get_parameter(param).requires_grad = False 104 | if freeze_bn_stats: 105 | gmodules = group_modules(self.trunk, matcher, reverse=True) 106 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 107 | freeze_batch_norm_2d(self.trunk, gmodules) 108 | 109 | def forward(self, x): 110 | x = self.trunk(x) 111 | x = self.head(x) 112 | return x 113 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/util/patch_embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import to_2tuple 4 | 5 | 6 | class PatchEmbed_org(nn.Module): 7 | """Image to Patch Embedding""" 8 | 9 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 10 | super().__init__() 11 | img_size = to_2tuple(img_size) 12 | patch_size = to_2tuple(patch_size) 13 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 14 | self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) 15 | self.img_size = img_size 16 | self.patch_size = patch_size 17 | self.num_patches = num_patches 18 | 19 | self.proj = nn.Conv2d( 20 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 21 | ) 22 | 23 | def forward(self, x): 24 | B, C, H, W = x.shape 25 | # FIXME look at relaxing size constraints 26 | # assert H == self.img_size[0] and W == self.img_size[1], \ 27 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
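        # Descriptive note (added): the Conv2d projection with kernel_size == stride ==
        # patch_size cuts the input into non-overlapping patches,
        # (B, C, H, W) -> (B, embed_dim, H/ps, W/ps); flatten(2).transpose(1, 2) then
        # yields the (B, num_patches, embed_dim) token sequence returned below.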
28 | x = self.proj(x) 29 | y = x.flatten(2).transpose(1, 2) 30 | return y 31 | 32 | 33 | class PatchEmbed_new(nn.Module): 34 | """Flexible Image to Patch Embedding""" 35 | 36 | def __init__( 37 | self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 38 | ): 39 | super().__init__() 40 | img_size = to_2tuple(img_size) 41 | patch_size = to_2tuple(patch_size) 42 | stride = to_2tuple(stride) 43 | 44 | self.img_size = img_size 45 | self.patch_size = patch_size 46 | 47 | self.proj = nn.Conv2d( 48 | in_chans, embed_dim, kernel_size=patch_size, stride=stride 49 | ) # with overlapped patches 50 | # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 51 | 52 | # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) 53 | # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 54 | _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w 55 | self.patch_hw = (h, w) 56 | self.num_patches = h * w 57 | 58 | def get_output_shape(self, img_size): 59 | # todo: don't be lazy.. 60 | return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape 61 | 62 | def forward(self, x): 63 | B, C, H, W = x.shape 64 | # FIXME look at relaxing size constraints 65 | # assert H == self.img_size[0] and W == self.img_size[1], \ 66 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 67 | # x = self.proj(x).flatten(2).transpose(1, 2) 68 | x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 69 | x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 70 | x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 71 | return x 72 | 73 | 74 | class PatchEmbed3D_new(nn.Module): 75 | """Flexible Image to Patch Embedding""" 76 | 77 | def __init__( 78 | self, 79 | video_size=(16, 224, 224), 80 | patch_size=(2, 16, 16), 81 | in_chans=3, 82 | embed_dim=768, 83 | stride=(2, 16, 16), 84 | ): 85 | super().__init__() 86 | 87 | self.video_size = video_size 88 | self.patch_size = patch_size 89 | self.in_chans = in_chans 90 | 91 | self.proj = nn.Conv3d( 92 | in_chans, embed_dim, kernel_size=patch_size, stride=stride 93 | ) 94 | _, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w 95 | self.patch_thw = (t, h, w) 96 | self.num_patches = t * h * w 97 | 98 | def get_output_shape(self, video_size): 99 | # todo: don't be lazy.. 
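        # Descriptive note (added): a dummy tensor is pushed through the Conv3d
        # projection so the (T', H', W') token grid is inferred from the actual
        # output shape rather than computed from stride arithmetic.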
100 | return self.proj( 101 | torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2]) 102 | ).shape 103 | 104 | def forward(self, x): 105 | B, C, T, H, W = x.shape 106 | x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14 107 | x = x.flatten(2) # 32, 768, 1568 108 | x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768 109 | return x 110 | 111 | 112 | if __name__ == "__main__": 113 | # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16)) 114 | # input = torch.rand(8,1,1024,128) 115 | # output = patch_emb(input) 116 | # print(output.shape) # (8,512,64) 117 | 118 | patch_emb = PatchEmbed3D_new( 119 | video_size=(6, 224, 224), 120 | patch_size=(2, 16, 16), 121 | in_chans=3, 122 | embed_dim=768, 123 | stride=(2, 16, 16), 124 | ) 125 | input = torch.rand(8, 3, 6, 224, 224) 126 | output = patch_emb(input) 127 | print(output.shape) # (8,64) 128 | -------------------------------------------------------------------------------- /audiosr/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import yaml 5 | import torch 6 | import torchaudio 7 | import numpy as np 8 | 9 | import audiosr.latent_diffusion.modules.phoneme_encoder.text as text 10 | from audiosr.latent_diffusion.models.ddpm import LatentDiffusion 11 | from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding 12 | from audiosr.utils import ( 13 | default_audioldm_config, 14 | download_checkpoint, 15 | read_audio_file, 16 | lowpass_filtering_prepare_inference, 17 | wav_feature_extraction, 18 | ) 19 | import os 20 | 21 | 22 | def seed_everything(seed): 23 | import random, os 24 | import numpy as np 25 | import torch 26 | 27 | random.seed(seed) 28 | os.environ["PYTHONHASHSEED"] = str(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = True 34 | 35 | 36 | def text2phoneme(data): 37 | return text._clean_text(re.sub(r"<.*?>", "", data), ["english_cleaners2"]) 38 | 39 | 40 | def text_to_filename(text): 41 | return text.replace(" ", "_").replace("'", "_").replace('"', "_") 42 | 43 | 44 | def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec): 45 | norm_mean = -4.2677393 46 | norm_std = 4.5689974 47 | 48 | if sampling_rate != 16000: 49 | waveform_16k = torchaudio.functional.resample( 50 | waveform, orig_freq=sampling_rate, new_freq=16000 51 | ) 52 | else: 53 | waveform_16k = waveform 54 | 55 | waveform_16k = waveform_16k - waveform_16k.mean() 56 | fbank = torchaudio.compliance.kaldi.fbank( 57 | waveform_16k, 58 | htk_compat=True, 59 | sample_frequency=16000, 60 | use_energy=False, 61 | window_type="hanning", 62 | num_mel_bins=128, 63 | dither=0.0, 64 | frame_shift=10, 65 | ) 66 | 67 | TARGET_LEN = log_mel_spec.size(0) 68 | 69 | # cut and pad 70 | n_frames = fbank.shape[0] 71 | p = TARGET_LEN - n_frames 72 | if p > 0: 73 | m = torch.nn.ZeroPad2d((0, 0, 0, p)) 74 | fbank = m(fbank) 75 | elif p < 0: 76 | fbank = fbank[:TARGET_LEN, :] 77 | 78 | fbank = (fbank - norm_mean) / (norm_std * 2) 79 | 80 | return {"ta_kaldi_fbank": fbank} # [1024, 128] 81 | 82 | 83 | def make_batch_for_super_resolution(input_file, waveform=None, fbank=None): 84 | log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file) 85 | 86 | batch = { 87 | "waveform": torch.FloatTensor(waveform), 88 | "stft": torch.FloatTensor(stft), 89 | "log_mel_spec": 
torch.FloatTensor(log_mel_spec), 90 | "sampling_rate": 48000, 91 | } 92 | 93 | # print(batch["waveform"].size(), batch["stft"].size(), batch["log_mel_spec"].size()) 94 | 95 | batch.update(lowpass_filtering_prepare_inference(batch)) 96 | 97 | assert "waveform_lowpass" in batch.keys() 98 | lowpass_mel, lowpass_stft = wav_feature_extraction( 99 | batch["waveform_lowpass"], target_frame 100 | ) 101 | batch["lowpass_mel"] = lowpass_mel 102 | 103 | for k in batch.keys(): 104 | if type(batch[k]) == torch.Tensor: 105 | batch[k] = torch.FloatTensor(batch[k]).unsqueeze(0) 106 | 107 | return batch, duration 108 | 109 | 110 | def round_up_duration(duration): 111 | return int(round(duration / 2.5) + 1) * 2.5 112 | 113 | 114 | def build_model(ckpt_path=None, config=None, device=None, model_name="basic"): 115 | if device is None or device == "auto": 116 | if torch.cuda.is_available(): 117 | device = torch.device("cuda:0") 118 | elif torch.backends.mps.is_available(): 119 | device = torch.device("mps") 120 | else: 121 | device = torch.device("cpu") 122 | 123 | print("Loading AudioSR: %s" % model_name) 124 | print("Loading model on %s" % device) 125 | 126 | ckpt_path = download_checkpoint(model_name) 127 | 128 | if config is not None: 129 | assert type(config) is str 130 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) 131 | else: 132 | config = default_audioldm_config(model_name) 133 | 134 | # # Use text as condition instead of using waveform during training 135 | config["model"]["params"]["device"] = device 136 | # config["model"]["params"]["cond_stage_key"] = "text" 137 | 138 | # No normalization here 139 | latent_diffusion = LatentDiffusion(**config["model"]["params"]) 140 | 141 | resume_from_checkpoint = ckpt_path 142 | 143 | checkpoint = torch.load(resume_from_checkpoint, map_location=device) 144 | 145 | latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False) 146 | 147 | latent_diffusion.eval() 148 | latent_diffusion = latent_diffusion.to(device) 149 | 150 | return latent_diffusion 151 | 152 | 153 | def super_resolution( 154 | latent_diffusion, 155 | input_file, 156 | seed=42, 157 | ddim_steps=200, 158 | guidance_scale=3.5, 159 | latent_t_per_second=12.8, 160 | config=None, 161 | ): 162 | seed_everything(int(seed)) 163 | waveform = None 164 | 165 | batch, duration = make_batch_for_super_resolution(input_file, waveform=waveform) 166 | 167 | with torch.no_grad(): 168 | waveform = latent_diffusion.generate_batch( 169 | batch, 170 | unconditional_guidance_scale=guidance_scale, 171 | ddim_steps=ddim_steps, 172 | duration=duration, 173 | ) 174 | 175 | return waveform 176 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/phoneme_encoder/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def init_weights(m, mean=0.0, std=0.01): 7 | classname = m.__class__.__name__ 8 | if classname.find("Conv") != -1: 9 | m.weight.data.normal_(mean, std) 10 | 11 | 12 | def get_padding(kernel_size, dilation=1): 13 | return int((kernel_size * dilation - dilation) / 2) 14 | 15 | 16 | def convert_pad_shape(pad_shape): 17 | l = pad_shape[::-1] 18 | pad_shape = [item for sublist in l for item in sublist] 19 | return pad_shape 20 | 21 | 22 | def intersperse(lst, item): 23 | result = [item] * (len(lst) * 2 + 1) 24 | result[1::2] = lst 25 | return result 26 | 27 | 28 | def kl_divergence(m_p, logs_p, m_q, 
logs_q): 29 | """KL(P||Q)""" 30 | kl = (logs_q - logs_p) - 0.5 31 | kl += ( 32 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 33 | ) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | 57 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 58 | b, d, t = x.size() 59 | if x_lengths is None: 60 | x_lengths = t 61 | ids_str_max = x_lengths - segment_size + 1 62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 63 | ret = slice_segments(x, ids_str, segment_size) 64 | return ret, ids_str 65 | 66 | 67 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 68 | position = torch.arange(length, dtype=torch.float) 69 | num_timescales = channels // 2 70 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 71 | num_timescales - 1 72 | ) 73 | inv_timescales = min_timescale * torch.exp( 74 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 75 | ) 76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 78 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 79 | signal = signal.view(1, channels, length) 80 | return signal 81 | 82 | 83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 84 | b, channels, length = x.size() 85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 86 | return x + signal.to(dtype=x.dtype, device=x.device) 87 | 88 | 89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 90 | b, channels, length = x.size() 91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 93 | 94 | 95 | def subsequent_mask(length): 96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 97 | return mask 98 | 99 | 100 | @torch.jit.script 101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 102 | n_channels_int = n_channels[0] 103 | in_act = input_a + input_b 104 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 106 | acts = t_act * s_act 107 | return acts 108 | 109 | 110 | def convert_pad_shape(pad_shape): 111 | l = pad_shape[::-1] 112 | pad_shape = [item for sublist in l for item in sublist] 113 | return pad_shape 114 | 115 | 116 | def shift_1d(x): 117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 118 | return x 119 | 120 | 121 | def sequence_mask(length, max_length=None): 122 | if max_length is None: 123 | max_length = length.max() 124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 125 | return x.unsqueeze(0) < length.unsqueeze(1) 126 | 127 | 128 | def generate_path(duration, mask): 129 | """ 130 | duration: [b, 1, t_x] 131 | mask: [b, 1, t_y, t_x] 132 | 
""" 133 | duration.device 134 | 135 | b, _, t_y, t_x = mask.shape 136 | cum_duration = torch.cumsum(duration, -1) 137 | 138 | cum_duration_flat = cum_duration.view(b * t_x) 139 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 140 | path = path.view(b, t_x, t_y) 141 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 142 | path = path.unsqueeze(1).transpose(2, 3) * mask 143 | return path 144 | 145 | 146 | def clip_grad_value_(parameters, clip_value, norm_type=2): 147 | if isinstance(parameters, torch.Tensor): 148 | parameters = [parameters] 149 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 150 | norm_type = float(norm_type) 151 | if clip_value is not None: 152 | clip_value = float(clip_value) 153 | 154 | total_norm = 0 155 | for p in parameters: 156 | param_norm = p.grad.data.norm(norm_type) 157 | total_norm += param_norm.item() ** norm_type 158 | if clip_value is not None: 159 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 160 | total_norm = total_norm ** (1.0 / norm_type) 161 | return total_norm 162 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import ( 14 | get_pretrained_url, 15 | list_pretrained_tag_models, 16 | download_pretrained, 17 | ) 18 | 19 | __all__ = ["list_openai_models", "load_openai_model"] 20 | 21 | 22 | def list_openai_models() -> List[str]: 23 | """Returns the names of available CLIP models""" 24 | return list_pretrained_tag_models("openai") 25 | 26 | 27 | def load_openai_model( 28 | name: str, 29 | model_cfg, 30 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 31 | jit=True, 32 | cache_dir=os.path.expanduser("~/.cache/clip"), 33 | enable_fusion: bool = False, 34 | fusion_type: str = "None", 35 | ): 36 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model 37 | 38 | Parameters 39 | ---------- 40 | name : str 41 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 42 | device : Union[str, torch.device] 43 | The device to put the loaded model 44 | jit : bool 45 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 46 | 47 | Returns 48 | ------- 49 | model : torch.nn.Module 50 | The CLAP model 51 | preprocess : Callable[[PIL.Image], torch.Tensor] 52 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 53 | """ 54 | if get_pretrained_url(name, "openai"): 55 | model_path = download_pretrained( 56 | get_pretrained_url(name, "openai"), root=cache_dir 57 | ) 58 | elif os.path.isfile(name): 59 | model_path = name 60 | else: 61 | raise RuntimeError( 62 | f"Model {name} not found; available models = {list_openai_models()}" 63 | ) 64 | 65 | try: 66 | # loading JIT archive 67 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 68 | state_dict = None 69 | except RuntimeError: 70 | # loading saved state dict 71 | if jit: 72 | warnings.warn( 73 | f"File {model_path} is not a JIT archive. 
Loading as a state dict instead" 74 | ) 75 | jit = False 76 | state_dict = torch.load(model_path, map_location="cpu") 77 | 78 | if not jit: 79 | try: 80 | model = build_model_from_openai_state_dict( 81 | state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type 82 | ).to(device) 83 | except KeyError: 84 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 85 | model = build_model_from_openai_state_dict( 86 | sd, model_cfg, enable_fusion, fusion_type 87 | ).to(device) 88 | 89 | if str(device) == "cpu": 90 | model.float() 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace( 95 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 96 | ) 97 | device_node = [ 98 | n 99 | for n in device_holder.graph.findAllNodes("prim::Constant") 100 | if "Device" in repr(n) 101 | ][-1] 102 | 103 | def patch_device(module): 104 | try: 105 | graphs = [module.graph] if hasattr(module, "graph") else [] 106 | except RuntimeError: 107 | graphs = [] 108 | 109 | if hasattr(module, "forward1"): 110 | graphs.append(module.forward1.graph) 111 | 112 | for graph in graphs: 113 | for node in graph.findAllNodes("prim::Constant"): 114 | if "value" in node.attributeNames() and str(node["value"]).startswith( 115 | "cuda" 116 | ): 117 | node.copyAttributes(device_node) 118 | 119 | model.apply(patch_device) 120 | patch_device(model.encode_audio) 121 | patch_device(model.encode_text) 122 | 123 | # patch dtype to float32 on CPU 124 | if str(device) == "cpu": 125 | float_holder = torch.jit.trace( 126 | lambda: torch.ones([]).float(), example_inputs=[] 127 | ) 128 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 129 | float_node = float_input.node() 130 | 131 | def patch_float(module): 132 | try: 133 | graphs = [module.graph] if hasattr(module, "graph") else [] 134 | except RuntimeError: 135 | graphs = [] 136 | 137 | if hasattr(module, "forward1"): 138 | graphs.append(module.forward1.graph) 139 | 140 | for graph in graphs: 141 | for node in graph.findAllNodes("aten::to"): 142 | inputs = list(node.inputs()) 143 | for i in [ 144 | 1, 145 | 2, 146 | ]: # dtype can be the second or third argument to aten::to() 147 | if inputs[i].node()["value"] == 5: 148 | inputs[i].node().copyAttributes(float_node) 149 | 150 | model.apply(patch_float) 151 | patch_float(model.encode_audio) 152 | patch_float(model.encode_text) 153 | model.float() 154 | 155 | model.audio_branch.audio_length = model.audio_cfg.audio_length 156 | return model 157 | -------------------------------------------------------------------------------- /audiosr/utilities/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import audiosr.hifigan as hifigan 4 | 5 | 6 | def get_vocoder_config(): 7 | return { 8 | "resblock": "1", 9 | "num_gpus": 6, 10 | "batch_size": 16, 11 | "learning_rate": 0.0002, 12 | "adam_b1": 0.8, 13 | "adam_b2": 0.99, 14 | "lr_decay": 0.999, 15 | "seed": 1234, 16 | "upsample_rates": [5, 4, 2, 2, 2], 17 | "upsample_kernel_sizes": [16, 16, 8, 4, 4], 18 | "upsample_initial_channel": 1024, 19 | "resblock_kernel_sizes": [3, 7, 11], 20 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 21 | "segment_size": 8192, 22 | "num_mels": 64, 23 | "num_freq": 1025, 24 | "n_fft": 1024, 25 | "hop_size": 160, 26 | "win_size": 1024, 27 | "sampling_rate": 16000, 28 | "fmin": 0, 29 | "fmax": 8000, 30 | "fmax_for_loss": None, 31 | "num_workers": 4, 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | 
"dist_url": "tcp://localhost:54321", 35 | "world_size": 1, 36 | }, 37 | } 38 | 39 | 40 | def get_vocoder_config_48k(): 41 | return { 42 | "resblock": "1", 43 | "num_gpus": 8, 44 | "batch_size": 128, 45 | "learning_rate": 0.0001, 46 | "adam_b1": 0.8, 47 | "adam_b2": 0.99, 48 | "lr_decay": 0.999, 49 | "seed": 1234, 50 | "upsample_rates": [6, 5, 4, 2, 2], 51 | "upsample_kernel_sizes": [12, 10, 8, 4, 4], 52 | "upsample_initial_channel": 1536, 53 | "resblock_kernel_sizes": [3, 7, 11, 15], 54 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]], 55 | "segment_size": 15360, 56 | "num_mels": 256, 57 | "n_fft": 2048, 58 | "hop_size": 480, 59 | "win_size": 2048, 60 | "sampling_rate": 48000, 61 | "fmin": 20, 62 | "fmax": 24000, 63 | "fmax_for_loss": None, 64 | "num_workers": 8, 65 | "dist_config": { 66 | "dist_backend": "nccl", 67 | "dist_url": "tcp://localhost:18273", 68 | "world_size": 1, 69 | }, 70 | } 71 | 72 | 73 | def get_available_checkpoint_keys(model, ckpt): 74 | state_dict = torch.load(ckpt)["state_dict"] 75 | current_state_dict = model.state_dict() 76 | new_state_dict = {} 77 | for k in state_dict.keys(): 78 | if ( 79 | k in current_state_dict.keys() 80 | and current_state_dict[k].size() == state_dict[k].size() 81 | ): 82 | new_state_dict[k] = state_dict[k] 83 | else: 84 | print("==> WARNING: Skipping %s" % k) 85 | print( 86 | "%s out of %s keys are matched" 87 | % (len(new_state_dict.keys()), len(state_dict.keys())) 88 | ) 89 | return new_state_dict 90 | 91 | 92 | def get_param_num(model): 93 | num_param = sum(param.numel() for param in model.parameters()) 94 | return num_param 95 | 96 | 97 | def torch_version_orig_mod_remove(state_dict): 98 | new_state_dict = {} 99 | new_state_dict["generator"] = {} 100 | for key in state_dict["generator"].keys(): 101 | if "_orig_mod." 
in key: 102 | new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[ 103 | "generator" 104 | ][key] 105 | else: 106 | new_state_dict["generator"][key] = state_dict["generator"][key] 107 | return new_state_dict 108 | 109 | 110 | def get_vocoder(config, device, mel_bins): 111 | name = "HiFi-GAN" 112 | speaker = "" 113 | if name == "MelGAN": 114 | if speaker == "LJSpeech": 115 | vocoder = torch.hub.load( 116 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" 117 | ) 118 | elif speaker == "universal": 119 | vocoder = torch.hub.load( 120 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" 121 | ) 122 | vocoder.mel2wav.eval() 123 | vocoder.mel2wav.to(device) 124 | elif name == "HiFi-GAN": 125 | if mel_bins == 64: 126 | config = get_vocoder_config() 127 | config = hifigan.AttrDict(config) 128 | vocoder = hifigan.Generator_old(config) 129 | # print("Load hifigan/g_01080000") 130 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) 131 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) 132 | # ckpt = torch_version_orig_mod_remove(ckpt) 133 | # vocoder.load_state_dict(ckpt["generator"]) 134 | vocoder.eval() 135 | vocoder.remove_weight_norm() 136 | vocoder.to(device) 137 | else: 138 | config = get_vocoder_config_48k() 139 | config = hifigan.AttrDict(config) 140 | vocoder = hifigan.Generator_old(config) 141 | # print("Load hifigan/g_01080000") 142 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) 143 | # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) 144 | # ckpt = torch_version_orig_mod_remove(ckpt) 145 | # vocoder.load_state_dict(ckpt["generator"]) 146 | vocoder.eval() 147 | vocoder.remove_weight_norm() 148 | vocoder.to(device) 149 | return vocoder 150 | 151 | 152 | def vocoder_infer(mels, vocoder, lengths=None): 153 | with torch.no_grad(): 154 | wavs = vocoder(mels).squeeze(1) 155 | 156 | wavs = (wavs.cpu().numpy() * 32768).astype("int16") 157 | 158 | if lengths is not None: 159 | wavs = wavs[:, :lengths] 160 | 161 | # wavs = [wav for wav in wavs] 162 | 163 | # for i in range(len(mels)): 164 | # if lengths is not None: 165 | # wavs[i] = wavs[i][: lengths[i]] 166 | 167 | return wavs 168 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/AudioMAE.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference Repo: https://github.com/facebookresearch/AudioMAE 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from timm.models.layers import to_2tuple 8 | import audiosr.latent_diffusion.modules.audiomae.models_vit as models_vit 9 | import audiosr.latent_diffusion.modules.audiomae.models_mae as models_mae 10 | 11 | # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128)) 12 | 13 | 14 | class PatchEmbed_new(nn.Module): 15 | """Flexible Image to Patch Embedding""" 16 | 17 | def __init__( 18 | self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 19 | ): 20 | super().__init__() 21 | img_size = to_2tuple(img_size) 22 | patch_size = to_2tuple(patch_size) 23 | stride = to_2tuple(stride) 24 | 25 | self.img_size = img_size 26 | self.patch_size = patch_size 27 | 28 | self.proj = nn.Conv2d( 29 | in_chans, embed_dim, kernel_size=patch_size, stride=stride 30 | ) # with overlapped patches 31 | # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 32 | 33 | # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // 
patch_size[0]) 34 | # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 35 | _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w 36 | self.patch_hw = (h, w) 37 | self.num_patches = h * w 38 | 39 | def get_output_shape(self, img_size): 40 | # todo: don't be lazy.. 41 | return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape 42 | 43 | def forward(self, x): 44 | B, C, H, W = x.shape 45 | # FIXME look at relaxing size constraints 46 | # assert H == self.img_size[0] and W == self.img_size[1], \ 47 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 48 | x = self.proj(x) 49 | x = x.flatten(2).transpose(1, 2) 50 | return x 51 | 52 | 53 | class AudioMAE(nn.Module): 54 | """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)""" 55 | 56 | def __init__( 57 | self, 58 | ): 59 | super().__init__() 60 | model = models_vit.__dict__["vit_base_patch16"]( 61 | num_classes=527, 62 | drop_path_rate=0.1, 63 | global_pool=True, 64 | mask_2d=True, 65 | use_custom_patch=False, 66 | ) 67 | 68 | img_size = (1024, 128) 69 | emb_dim = 768 70 | 71 | model.patch_embed = PatchEmbed_new( 72 | img_size=img_size, 73 | patch_size=(16, 16), 74 | in_chans=1, 75 | embed_dim=emb_dim, 76 | stride=16, 77 | ) 78 | num_patches = model.patch_embed.num_patches 79 | # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8 80 | model.pos_embed = nn.Parameter( 81 | torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False 82 | ) # fixed sin-cos embedding 83 | 84 | # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth' 85 | # checkpoint = torch.load(checkpoint_path, map_location='cpu') 86 | # msg = model.load_state_dict(checkpoint['model'], strict=False) 87 | # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') 88 | 89 | self.model = model 90 | 91 | def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0): 92 | """ 93 | x: mel fbank [Batch, 1, T, F] 94 | mask_t_prob: 'T masking ratio (percentage of removed patches).' 95 | mask_f_prob: 'F masking ratio (percentage of removed patches).' 96 | """ 97 | return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob) 98 | 99 | 100 | class Vanilla_AudioMAE(nn.Module): 101 | """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)""" 102 | 103 | def __init__( 104 | self, 105 | ): 106 | super().__init__() 107 | model = models_mae.__dict__["mae_vit_base_patch16"]( 108 | in_chans=1, audio_exp=True, img_size=(1024, 128) 109 | ) 110 | 111 | # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth' 112 | # checkpoint = torch.load(checkpoint_path, map_location='cpu') 113 | # msg = model.load_state_dict(checkpoint['model'], strict=False) 114 | 115 | # Skip the missing keys of decoder modules (not required) 116 | # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') 117 | 118 | self.model = model.eval() 119 | 120 | def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False): 121 | """ 122 | x: mel fbank [Batch, 1, 1024 (T), 128 (F)] 123 | mask_ratio: 'masking ratio (percentage of removed patches).' 
124 | """ 125 | with torch.no_grad(): 126 | # embed: [B, 513, 768] for mask_ratio=0.0 127 | if no_mask: 128 | if no_average: 129 | raise RuntimeError("This function is deprecated") 130 | embed = self.model.forward_encoder_no_random_mask_no_average( 131 | x 132 | ) # mask_ratio 133 | else: 134 | embed = self.model.forward_encoder_no_mask(x) # mask_ratio 135 | else: 136 | raise RuntimeError("This function is deprecated") 137 | embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio) 138 | return embed 139 | 140 | 141 | if __name__ == "__main__": 142 | model = Vanilla_AudioMAE().cuda() 143 | input = torch.randn(4, 1, 1024, 128).cuda() 144 | print("The first run") 145 | embed = model(input, mask_ratio=0.0, no_mask=True) 146 | print(embed) 147 | print("The second run") 148 | embed = model(input, mask_ratio=0.0) 149 | print(embed) 150 | -------------------------------------------------------------------------------- /audiosr/hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = 
len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2**i), 129 | h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | # print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) 175 | -------------------------------------------------------------------------------- /audiosr/utilities/audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from audiosr.utilities.audio.audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 | 
inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data, 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes, normalize_fun): 152 | output = dynamic_range_compression(magnitudes, normalize_fun) 153 | return output 154 | 155 | def spectral_de_normalize(self, magnitudes): 156 | 
output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y, normalize_fun=torch.log): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1, torch.min(y.data) 170 | assert torch.max(y.data) <= 1, torch.max(y.data) 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output, normalize_fun) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | return mel_output, magnitudes, phases, energy 179 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | _RN50 = dict( 9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", 12 | ) 13 | 14 | _RN50_quickgelu = dict( 15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", 18 | ) 19 | 20 | _RN101 = dict( 21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", 23 | ) 24 | 25 | _RN101_quickgelu = dict( 26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", 28 | ) 29 | 30 | _RN50x4 = dict( 31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | ) 33 | 34 | _RN50x16 = dict( 35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | ) 37 | 38 | _RN50x64 = dict( 39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 40 | ) 41 | 42 | _VITB32 = dict( 43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 44 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 45 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 46 | 
laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 47 | ) 48 | 49 | _VITB32_quickgelu = dict( 50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 53 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 54 | ) 55 | 56 | _VITB16 = dict( 57 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 58 | ) 59 | 60 | _VITL14 = dict( 61 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 62 | ) 63 | 64 | _PRETRAINED = { 65 | "RN50": _RN50, 66 | "RN50-quickgelu": _RN50_quickgelu, 67 | "RN101": _RN101, 68 | "RN101-quickgelu": _RN101_quickgelu, 69 | "RN50x4": _RN50x4, 70 | "RN50x16": _RN50x16, 71 | "ViT-B-32": _VITB32, 72 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 73 | "ViT-B-16": _VITB16, 74 | "ViT-L-14": _VITL14, 75 | } 76 | 77 | 78 | def list_pretrained(as_str: bool = False): 79 | """returns list of pretrained models 80 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 81 | """ 82 | return [ 83 | ":".join([k, t]) if as_str else (k, t) 84 | for k in _PRETRAINED.keys() 85 | for t in _PRETRAINED[k].keys() 86 | ] 87 | 88 | 89 | def list_pretrained_tag_models(tag: str): 90 | """return all models having the specified pretrain tag""" 91 | models = [] 92 | for k in _PRETRAINED.keys(): 93 | if tag in _PRETRAINED[k]: 94 | models.append(k) 95 | return models 96 | 97 | 98 | def list_pretrained_model_tags(model: str): 99 | """return all pretrain tags for the specified model architecture""" 100 | tags = [] 101 | if model in _PRETRAINED: 102 | tags.extend(_PRETRAINED[model].keys()) 103 | return tags 104 | 105 | 106 | def get_pretrained_url(model: str, tag: str): 107 | if model not in _PRETRAINED: 108 | return "" 109 | model_pretrained = _PRETRAINED[model] 110 | if tag not in model_pretrained: 111 | return "" 112 | return model_pretrained[tag] 113 | 114 | 115 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): 116 | os.makedirs(root, exist_ok=True) 117 | filename = os.path.basename(url) 118 | 119 | if "openaipublic" in url: 120 | expected_sha256 = url.split("/")[-2] 121 | else: 122 | expected_sha256 = "" 123 | 124 | download_target = os.path.join(root, filename) 125 | 126 | if os.path.exists(download_target) and not os.path.isfile(download_target): 127 | raise RuntimeError(f"{download_target} exists and is not a regular file") 128 | 129 | if os.path.isfile(download_target): 130 | if expected_sha256: 131 | if ( 132 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 133 | == expected_sha256 134 | ): 135 | return download_target 136 | else: 137 | warnings.warn( 138 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 139 | ) 140 | else: 141 | return download_target 142 | 143 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 144 | with tqdm( 145 | 
total=int(source.info().get("Content-Length")), 146 | ncols=80, 147 | unit="iB", 148 | unit_scale=True, 149 | ) as loop: 150 | while True: 151 | buffer = source.read(8192) 152 | if not buffer: 153 | break 154 | 155 | output.write(buffer) 156 | loop.update(len(buffer)) 157 | 158 | if ( 159 | expected_sha256 160 | and hashlib.sha256(open(download_target, "rb").read()).hexdigest() 161 | != expected_sha256 162 | ): 163 | raise RuntimeError( 164 | f"Model has been downloaded but the SHA256 checksum does not not match" 165 | ) 166 | 167 | return download_target 168 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join( 19 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 20 | ) 21 | 22 | 23 | @lru_cache() 24 | def bytes_to_unicode(): 25 | """ 26 | Returns list of utf-8 byte and a corresponding list of unicode strings. 27 | The reversible bpe codes work on unicode strings. 28 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 29 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 30 | This is a signficant percentage of your normal, say, 32K bpe vocab. 31 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 32 | And avoids mapping to whitespace/control characters the bpe code barfs on. 33 | """ 34 | bs = ( 35 | list(range(ord("!"), ord("~") + 1)) 36 | + list(range(ord("¡"), ord("¬") + 1)) 37 | + list(range(ord("®"), ord("ÿ") + 1)) 38 | ) 39 | cs = bs[:] 40 | n = 0 41 | for b in range(2**8): 42 | if b not in bs: 43 | bs.append(b) 44 | cs.append(2**8 + n) 45 | n += 1 46 | cs = [chr(n) for n in cs] 47 | return dict(zip(bs, cs)) 48 | 49 | 50 | def get_pairs(word): 51 | """Return set of symbol pairs in a word. 52 | Word is represented as tuple of symbols (symbols being variable-length strings). 
53 | """ 54 | pairs = set() 55 | prev_char = word[0] 56 | for char in word[1:]: 57 | pairs.add((prev_char, char)) 58 | prev_char = char 59 | return pairs 60 | 61 | 62 | def basic_clean(text): 63 | text = ftfy.fix_text(text) 64 | text = html.unescape(html.unescape(text)) 65 | return text.strip() 66 | 67 | 68 | def whitespace_clean(text): 69 | text = re.sub(r"\s+", " ", text) 70 | text = text.strip() 71 | return text 72 | 73 | 74 | class SimpleTokenizer(object): 75 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 76 | self.byte_encoder = bytes_to_unicode() 77 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 78 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 79 | merges = merges[1 : 49152 - 256 - 2 + 1] 80 | merges = [tuple(merge.split()) for merge in merges] 81 | vocab = list(bytes_to_unicode().values()) 82 | vocab = vocab + [v + "" for v in vocab] 83 | for merge in merges: 84 | vocab.append("".join(merge)) 85 | if not special_tokens: 86 | special_tokens = ["", ""] 87 | else: 88 | special_tokens = ["", ""] + special_tokens 89 | vocab.extend(special_tokens) 90 | self.encoder = dict(zip(vocab, range(len(vocab)))) 91 | self.decoder = {v: k for k, v in self.encoder.items()} 92 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 93 | self.cache = {t: t for t in special_tokens} 94 | special = "|".join(special_tokens) 95 | self.pat = re.compile( 96 | special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 97 | re.IGNORECASE, 98 | ) 99 | 100 | self.vocab_size = len(self.encoder) 101 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 102 | 103 | def bpe(self, token): 104 | if token in self.cache: 105 | return self.cache[token] 106 | word = tuple(token[:-1]) + (token[-1] + "",) 107 | pairs = get_pairs(word) 108 | 109 | if not pairs: 110 | return token + "" 111 | 112 | while True: 113 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 114 | if bigram not in self.bpe_ranks: 115 | break 116 | first, second = bigram 117 | new_word = [] 118 | i = 0 119 | while i < len(word): 120 | try: 121 | j = word.index(first, i) 122 | new_word.extend(word[i:j]) 123 | i = j 124 | except: 125 | new_word.extend(word[i:]) 126 | break 127 | 128 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 129 | new_word.append(first + second) 130 | i += 2 131 | else: 132 | new_word.append(word[i]) 133 | i += 1 134 | new_word = tuple(new_word) 135 | word = new_word 136 | if len(word) == 1: 137 | break 138 | else: 139 | pairs = get_pairs(word) 140 | word = " ".join(word) 141 | self.cache[token] = word 142 | return word 143 | 144 | def encode(self, text): 145 | bpe_tokens = [] 146 | text = whitespace_clean(basic_clean(text)).lower() 147 | for token in re.findall(self.pat, text): 148 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 149 | bpe_tokens.extend( 150 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 151 | ) 152 | return bpe_tokens 153 | 154 | def decode(self, tokens): 155 | text = "".join([self.decoder[token] for token in tokens]) 156 | text = ( 157 | bytearray([self.byte_decoder[c] for c in text]) 158 | .decode("utf-8", errors="replace") 159 | .replace("", " ") 160 | ) 161 | return text 162 | 163 | 164 | _tokenizer = SimpleTokenizer() 165 | 166 | 167 | def tokenize( 168 | texts: Union[str, List[str]], context_length: int = 77 169 | ) -> torch.LongTensor: 170 | """ 171 | Returns the tokenized representation of given input string(s) 172 | 
173 | Parameters 174 | ---------- 175 | texts : Union[str, List[str]] 176 | An input string or a list of input strings to tokenize 177 | context_length : int 178 | The context length to use; all CLIP models use 77 as the context length 179 | 180 | Returns 181 | ------- 182 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 183 | """ 184 | if isinstance(texts, str): 185 | texts = [texts] 186 | 187 | sot_token = _tokenizer.encoder[""] 188 | eot_token = _tokenizer.encoder[""] 189 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 190 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 191 | 192 | for i, tokens in enumerate(all_tokens): 193 | if len(tokens) > context_length: 194 | tokens = tokens[:context_length] # Truncate 195 | result[i, : len(tokens)] = torch.tensor(tokens) 196 | 197 | return result 198 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/feature_fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Feature Fusion for Varible-Length Data Processing 3 | AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py 4 | According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class DAF(nn.Module): 12 | """ 13 | 直接相加 DirectAddFuse 14 | """ 15 | 16 | def __init__(self): 17 | super(DAF, self).__init__() 18 | 19 | def forward(self, x, residual): 20 | return x + residual 21 | 22 | 23 | class iAFF(nn.Module): 24 | """ 25 | 多特征融合 iAFF 26 | """ 27 | 28 | def __init__(self, channels=64, r=4, type="2D"): 29 | super(iAFF, self).__init__() 30 | inter_channels = int(channels // r) 31 | 32 | if type == "1D": 33 | # 本地注意力 34 | self.local_att = nn.Sequential( 35 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 36 | nn.BatchNorm1d(inter_channels), 37 | nn.ReLU(inplace=True), 38 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 39 | nn.BatchNorm1d(channels), 40 | ) 41 | 42 | # 全局注意力 43 | self.global_att = nn.Sequential( 44 | nn.AdaptiveAvgPool1d(1), 45 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 46 | nn.BatchNorm1d(inter_channels), 47 | nn.ReLU(inplace=True), 48 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 49 | nn.BatchNorm1d(channels), 50 | ) 51 | 52 | # 第二次本地注意力 53 | self.local_att2 = nn.Sequential( 54 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 55 | nn.BatchNorm1d(inter_channels), 56 | nn.ReLU(inplace=True), 57 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 58 | nn.BatchNorm1d(channels), 59 | ) 60 | # 第二次全局注意力 61 | self.global_att2 = nn.Sequential( 62 | nn.AdaptiveAvgPool1d(1), 63 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 64 | nn.BatchNorm1d(inter_channels), 65 | nn.ReLU(inplace=True), 66 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 67 | nn.BatchNorm1d(channels), 68 | ) 69 | elif type == "2D": 70 | # 本地注意力 71 | self.local_att = nn.Sequential( 72 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 73 | nn.BatchNorm2d(inter_channels), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(inter_channels, channels, kernel_size=1, 
stride=1, padding=0), 76 | nn.BatchNorm2d(channels), 77 | ) 78 | 79 | # 全局注意力 80 | self.global_att = nn.Sequential( 81 | nn.AdaptiveAvgPool2d(1), 82 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 83 | nn.BatchNorm2d(inter_channels), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 86 | nn.BatchNorm2d(channels), 87 | ) 88 | 89 | # 第二次本地注意力 90 | self.local_att2 = nn.Sequential( 91 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 92 | nn.BatchNorm2d(inter_channels), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 95 | nn.BatchNorm2d(channels), 96 | ) 97 | # 第二次全局注意力 98 | self.global_att2 = nn.Sequential( 99 | nn.AdaptiveAvgPool2d(1), 100 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 101 | nn.BatchNorm2d(inter_channels), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 104 | nn.BatchNorm2d(channels), 105 | ) 106 | else: 107 | raise f"the type is not supported" 108 | 109 | self.sigmoid = nn.Sigmoid() 110 | 111 | def forward(self, x, residual): 112 | flag = False 113 | xa = x + residual 114 | if xa.size(0) == 1: 115 | xa = torch.cat([xa, xa], dim=0) 116 | flag = True 117 | xl = self.local_att(xa) 118 | xg = self.global_att(xa) 119 | xlg = xl + xg 120 | wei = self.sigmoid(xlg) 121 | xi = x * wei + residual * (1 - wei) 122 | 123 | xl2 = self.local_att2(xi) 124 | xg2 = self.global_att(xi) 125 | xlg2 = xl2 + xg2 126 | wei2 = self.sigmoid(xlg2) 127 | xo = x * wei2 + residual * (1 - wei2) 128 | if flag: 129 | xo = xo[0].unsqueeze(0) 130 | return xo 131 | 132 | 133 | class AFF(nn.Module): 134 | """ 135 | 多特征融合 AFF 136 | """ 137 | 138 | def __init__(self, channels=64, r=4, type="2D"): 139 | super(AFF, self).__init__() 140 | inter_channels = int(channels // r) 141 | 142 | if type == "1D": 143 | self.local_att = nn.Sequential( 144 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 145 | nn.BatchNorm1d(inter_channels), 146 | nn.ReLU(inplace=True), 147 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 148 | nn.BatchNorm1d(channels), 149 | ) 150 | self.global_att = nn.Sequential( 151 | nn.AdaptiveAvgPool1d(1), 152 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 153 | nn.BatchNorm1d(inter_channels), 154 | nn.ReLU(inplace=True), 155 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 156 | nn.BatchNorm1d(channels), 157 | ) 158 | elif type == "2D": 159 | self.local_att = nn.Sequential( 160 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 161 | nn.BatchNorm2d(inter_channels), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 164 | nn.BatchNorm2d(channels), 165 | ) 166 | self.global_att = nn.Sequential( 167 | nn.AdaptiveAvgPool2d(1), 168 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 169 | nn.BatchNorm2d(inter_channels), 170 | nn.ReLU(inplace=True), 171 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 172 | nn.BatchNorm2d(channels), 173 | ) 174 | else: 175 | raise f"the type is not supported." 
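        # AFF fuses two feature maps of identical shape with a learned,
        # sigmoid-gated attention weight. Minimal usage sketch (illustrative,
        # assuming 2D inputs of shape [B, channels, H, W]):
        #
        #   aff = AFF(channels=64, r=4, type="2D")
        #   fused = aff(torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32))
        #   # -> same shape as the inputs: [2, 64, 32, 32]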
176 | 177 | self.sigmoid = nn.Sigmoid() 178 | 179 | def forward(self, x, residual): 180 | flag = False 181 | xa = x + residual 182 | if xa.size(0) == 1: 183 | xa = torch.cat([xa, xa], dim=0) 184 | flag = True 185 | xl = self.local_att(xa) 186 | xg = self.global_att(xa) 187 | xlg = xl + xg 188 | wei = self.sigmoid(xlg) 189 | xo = 2 * x * wei + 2 * residual * (1 - wei) 190 | if flag: 191 | xo = xo[0].unsqueeze(0) 192 | return xo 193 | -------------------------------------------------------------------------------- /audiosr/lowpass.py: -------------------------------------------------------------------------------- 1 | from scipy.signal import butter, lfilter 2 | import torch 3 | from scipy import signal 4 | import librosa 5 | import numpy as np 6 | 7 | from scipy.signal import sosfiltfilt 8 | from scipy.signal import butter, cheby1, cheby2, ellip, bessel 9 | from scipy.signal import resample_poly 10 | 11 | 12 | def align_length(x=None, y=None, Lx=None): 13 | """align the length of y to that of x 14 | 15 | Args: 16 | x (np.array): reference signal 17 | y (np.array): the signal needs to be length aligned 18 | 19 | Return: 20 | yy (np.array): signal with the same length as x 21 | """ 22 | assert y is not None 23 | 24 | if Lx is None: 25 | Lx = len(x) 26 | Ly = len(y) 27 | 28 | if Lx == Ly: 29 | return y 30 | elif Lx > Ly: 31 | # pad y with zeros 32 | return np.pad(y, (0, Lx - Ly), mode="constant") 33 | else: 34 | # cut y 35 | return y[:Lx] 36 | 37 | 38 | def bandpass_filter(x, lowcut, highcut, fs, order, ftype): 39 | """process input signal x using bandpass filter 40 | 41 | Args: 42 | x (np.array): input signal 43 | lowcut (float): low cutoff frequency 44 | highcut (float): high cutoff frequency 45 | order (int): the order of filter 46 | ftype (string): type of filter 47 | ['butter', 'cheby1', 'cheby2', 'ellip', 'bessel'] 48 | 49 | Return: 50 | y (np.array): filtered signal 51 | """ 52 | nyq = 0.5 * fs 53 | lo = lowcut / nyq 54 | hi = highcut / nyq 55 | 56 | if ftype == "butter": 57 | # b, a = butter(order, [lo, hi], btype='band') 58 | sos = butter(order, [lo, hi], btype="band", output="sos") 59 | elif ftype == "cheby1": 60 | sos = cheby1(order, 0.1, [lo, hi], btype="band", output="sos") 61 | elif ftype == "cheby2": 62 | sos = cheby2(order, 60, [lo, hi], btype="band", output="sos") 63 | elif ftype == "ellip": 64 | sos = ellip(order, 0.1, 60, [lo, hi], btype="band", output="sos") 65 | elif ftype == "bessel": 66 | sos = bessel(order, [lo, hi], btype="band", output="sos") 67 | else: 68 | raise Exception(f"The bandpass filter {ftype} is not supported!") 69 | 70 | # y = lfilter(b, a, x) 71 | y = sosfiltfilt(sos, x) 72 | 73 | if len(y) != len(x): 74 | y = align_length(x, y) 75 | return y 76 | 77 | 78 | def lowpass_filter(x, highcut, fs, order, ftype): 79 | """process input signal x using lowpass filter 80 | 81 | Args: 82 | x (np.array): input signal 83 | highcut (float): high cutoff frequency 84 | order (int): the order of filter 85 | ftype (string): type of filter 86 | ['butter', 'cheby1', 'cheby2', 'ellip', 'bessel'] 87 | 88 | Return: 89 | y (np.array): filtered signal 90 | """ 91 | nyq = 0.5 * fs 92 | hi = highcut / nyq 93 | 94 | if ftype == "butter": 95 | sos = butter(order, hi, btype="low", output="sos") 96 | elif ftype == "cheby1": 97 | sos = cheby1(order, 0.1, hi, btype="low", output="sos") 98 | elif ftype == "cheby2": 99 | sos = cheby2(order, 60, hi, btype="low", output="sos") 100 | elif ftype == "ellip": 101 | sos = ellip(order, 0.1, 60, hi, btype="low", output="sos") 102 | elif ftype == 
"bessel": 103 | sos = bessel(order, hi, btype="low", output="sos") 104 | else: 105 | raise Exception(f"The lowpass filter {ftype} is not supported!") 106 | 107 | y = sosfiltfilt(sos, x) 108 | 109 | if len(y) != len(x): 110 | y = align_length(x, y) 111 | 112 | y_len = len(y) 113 | 114 | y = stft_hard_lowpass(y, hi, fs_ori=fs) 115 | 116 | y = sosfiltfilt(sos, y) 117 | 118 | if len(y) != y_len: 119 | y = align_length(y=y, Lx=y_len) 120 | 121 | return y 122 | 123 | 124 | def stft_hard_lowpass(data, lowpass_ratio, fs_ori=44100): 125 | fs_down = int(lowpass_ratio * fs_ori) 126 | # downsample to the low sampling rate 127 | y = resample_poly(data, fs_down, fs_ori) 128 | 129 | # upsample to the original sampling rate 130 | y = resample_poly(y, fs_ori, fs_down) 131 | 132 | if len(y) != len(data): 133 | y = align_length(data, y) 134 | return y 135 | 136 | 137 | def limit(integer, high, low): 138 | if integer > high: 139 | return high 140 | elif integer < low: 141 | return low 142 | else: 143 | return int(integer) 144 | 145 | 146 | def lowpass(data, highcut, fs, order=5, _type="butter"): 147 | """ 148 | :param data: np.float32 type 1d time numpy array, (samples,) , can not be (samples, 1) !!!!!!!!!!!! 149 | :param highcut: cutoff frequency 150 | :param fs: sample rate of the original data 151 | :param order: order of the filter 152 | :return: filtered data, (samples,) 153 | """ 154 | 155 | if len(list(data.shape)) != 1: 156 | raise ValueError( 157 | "Error (chebyshev_lowpass_filter): Data " 158 | + str(data.shape) 159 | + " should be type 1d time array, (samples,) , can not be (samples, 1)" 160 | ) 161 | 162 | if _type in "butter": 163 | order = limit(order, high=10, low=2) 164 | return lowpass_filter( 165 | x=data, highcut=int(highcut), fs=fs, order=order, ftype="butter" 166 | ) 167 | elif _type in "cheby1": 168 | order = limit(order, high=10, low=2) 169 | return lowpass_filter( 170 | x=data, highcut=int(highcut), fs=fs, order=order, ftype="cheby1" 171 | ) 172 | elif _type in "ellip": 173 | order = limit(order, high=10, low=2) 174 | return lowpass_filter( 175 | x=data, highcut=int(highcut), fs=fs, order=order, ftype="ellip" 176 | ) 177 | elif _type in "bessel": 178 | order = limit(order, high=10, low=2) 179 | return lowpass_filter( 180 | x=data, highcut=int(highcut), fs=fs, order=order, ftype="bessel" 181 | ) 182 | # elif(_type in "stft"): 183 | # return stft_hard_lowpass(data, lowpass_ratio=highcut / int(fs / 2)) 184 | # elif(_type in "stft_hard"): 185 | # return stft_hard_lowpass_v0(data, lowpass_ratio=highcut / int(fs / 2)) 186 | else: 187 | raise ValueError("Error: Unexpected filter type " + _type) 188 | 189 | 190 | def bandpass(data, lowcut, highcut, fs, order=5, _type="butter"): 191 | """ 192 | :param data: np.float32 type 1d time numpy array, (samples,) , can not be (samples, 1) !!!!!!!!!!!! 
193 | :param lowcut: low cutoff frequency 194 | :param highcut: high cutoff frequency 195 | :param fs: sample rate of the original data 196 | :param order: order of the filter 197 | :param _type: type of filter 198 | :return: filtered data, (samples,) 199 | """ 200 | if len(list(data.shape)) != 1: 201 | raise ValueError( 202 | "Error (chebyshev_lowpass_filter): Data " 203 | + str(data.shape) 204 | + " should be type 1d time array, (samples,) , can not be (samples, 1)" 205 | ) 206 | if _type in "butter": 207 | order = limit(order, high=10, low=2) 208 | return bandpass_filter( 209 | x=data, 210 | lowcut=int(lowcut), 211 | highcut=int(highcut), 212 | fs=fs, 213 | order=order, 214 | ftype="butter", 215 | ) 216 | elif _type in "cheby1": 217 | order = limit(order, high=10, low=2) 218 | return bandpass_filter( 219 | x=data, 220 | lowcut=int(lowcut), 221 | highcut=int(highcut), 222 | fs=fs, 223 | order=order, 224 | ftype="cheby1", 225 | ) 226 | # elif(_type in "cheby2"): 227 | # return bandpass_filter(x=data,lowcut=int(lowcut),highcut=int(highcut), fs=fs, order=order,ftype="cheby2") 228 | elif _type in "ellip": 229 | order = limit(order, high=10, low=2) 230 | return bandpass_filter( 231 | x=data, 232 | lowcut=int(lowcut), 233 | highcut=int(highcut), 234 | fs=fs, 235 | order=order, 236 | ftype="ellip", 237 | ) 238 | elif _type in "bessel": 239 | order = limit(order, high=10, low=2) 240 | return bandpass_filter( 241 | x=data, 242 | lowcut=int(lowcut), 243 | highcut=int(highcut), 244 | fs=fs, 245 | order=order, 246 | ftype="bessel", 247 | ) 248 | else: 249 | raise ValueError("Error: Unexpected filter type " + _type) 250 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
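# This module wraps timm's VisionTransformer with optional global average
# pooling and 2-D (time/frequency) random masking for AudioMAE-style
# spectrogram encoders. Minimal sketch (illustrative; uses timm's default
# 224x224, 3-channel input settings and assumes a timm version compatible
# with this subclass):
#
#   import torch
#   model = vit_base_patch16(num_classes=10)
#   logits = model(torch.randn(2, 3, 224, 224))   # -> [2, 10]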
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | import timm.models.vision_transformer 17 | 18 | 19 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 20 | """Vision Transformer with support for global average pooling""" 21 | 22 | def __init__( 23 | self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs 24 | ): 25 | super(VisionTransformer, self).__init__(**kwargs) 26 | 27 | self.global_pool = global_pool 28 | if self.global_pool: 29 | norm_layer = kwargs["norm_layer"] 30 | embed_dim = kwargs["embed_dim"] 31 | self.fc_norm = norm_layer(embed_dim) 32 | del self.norm # remove the original norm 33 | self.mask_2d = mask_2d 34 | self.use_custom_patch = use_custom_patch 35 | 36 | def forward_features(self, x): 37 | B = x.shape[0] 38 | x = self.patch_embed(x) 39 | x = x + self.pos_embed[:, 1:, :] 40 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 41 | cls_tokens = cls_token.expand( 42 | B, -1, -1 43 | ) # stole cls_tokens impl from Phil Wang, thanks 44 | x = torch.cat((cls_tokens, x), dim=1) 45 | x = self.pos_drop(x) 46 | 47 | for blk in self.blocks: 48 | x = blk(x) 49 | 50 | if self.global_pool: 51 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 52 | outcome = self.fc_norm(x) 53 | else: 54 | x = self.norm(x) 55 | outcome = x[:, 0] 56 | 57 | return outcome 58 | 59 | def random_masking(self, x, mask_ratio): 60 | """ 61 | Perform per-sample random masking by per-sample shuffling. 62 | Per-sample shuffling is done by argsort random noise. 63 | x: [N, L, D], sequence 64 | """ 65 | N, L, D = x.shape # batch, length, dim 66 | len_keep = int(L * (1 - mask_ratio)) 67 | 68 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 69 | 70 | # sort noise for each sample 71 | ids_shuffle = torch.argsort( 72 | noise, dim=1 73 | ) # ascend: small is keep, large is remove 74 | ids_restore = torch.argsort(ids_shuffle, dim=1) 75 | 76 | # keep the first subset 77 | ids_keep = ids_shuffle[:, :len_keep] 78 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 79 | 80 | # generate the binary mask: 0 is keep, 1 is remove 81 | mask = torch.ones([N, L], device=x.device) 82 | mask[:, :len_keep] = 0 83 | # unshuffle to get the binary mask 84 | mask = torch.gather(mask, dim=1, index=ids_restore) 85 | 86 | return x_masked, mask, ids_restore 87 | 88 | def random_masking_2d(self, x, mask_t_prob, mask_f_prob): 89 | """ 90 | 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) 91 | Perform per-sample random masking by per-sample shuffling. 92 | Per-sample shuffling is done by argsort random noise. 
93 | x: [N, L, D], sequence 94 | """ 95 | 96 | N, L, D = x.shape # batch, length, dim 97 | if self.use_custom_patch: 98 | # # for AS 99 | T = 101 # 64,101 100 | F = 12 # 8,12 101 | # # for ESC 102 | # T=50 103 | # F=12 104 | # for SPC 105 | # T=12 106 | # F=12 107 | else: 108 | # ## for AS 109 | T = 64 110 | F = 8 111 | # ## for ESC 112 | # T=32 113 | # F=8 114 | ## for SPC 115 | # T=8 116 | # F=8 117 | 118 | # mask T 119 | x = x.reshape(N, T, F, D) 120 | len_keep_T = int(T * (1 - mask_t_prob)) 121 | noise = torch.rand(N, T, device=x.device) # noise in [0, 1] 122 | # sort noise for each sample 123 | ids_shuffle = torch.argsort( 124 | noise, dim=1 125 | ) # ascend: small is keep, large is remove 126 | ids_keep = ids_shuffle[:, :len_keep_T] 127 | index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D) 128 | # x_masked = torch.gather(x, dim=1, index=index) 129 | # x_masked = x_masked.reshape(N,len_keep_T*F,D) 130 | x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D 131 | 132 | # mask F 133 | # x = x.reshape(N, T, F, D) 134 | x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D 135 | len_keep_F = int(F * (1 - mask_f_prob)) 136 | noise = torch.rand(N, F, device=x.device) # noise in [0, 1] 137 | # sort noise for each sample 138 | ids_shuffle = torch.argsort( 139 | noise, dim=1 140 | ) # ascend: small is keep, large is remove 141 | ids_keep = ids_shuffle[:, :len_keep_F] 142 | # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D) 143 | index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D) 144 | x_masked = torch.gather(x, dim=1, index=index) 145 | x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D 146 | # x_masked = x_masked.reshape(N,len_keep*T,D) 147 | x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D) 148 | 149 | return x_masked, None, None 150 | 151 | def forward_features_mask(self, x, mask_t_prob, mask_f_prob): 152 | B = x.shape[0] # 4,1,1024,128 153 | x = self.patch_embed(x) # 4, 512, 768 154 | 155 | x = x + self.pos_embed[:, 1:, :] 156 | if self.random_masking_2d: 157 | x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob) 158 | else: 159 | x, mask, ids_restore = self.random_masking(x, mask_t_prob) 160 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 161 | cls_tokens = cls_token.expand(B, -1, -1) 162 | x = torch.cat((cls_tokens, x), dim=1) 163 | x = self.pos_drop(x) 164 | 165 | # apply Transformer blocks 166 | for blk in self.blocks: 167 | x = blk(x) 168 | 169 | if self.global_pool: 170 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 171 | outcome = self.fc_norm(x) 172 | else: 173 | x = self.norm(x) 174 | outcome = x[:, 0] 175 | 176 | return outcome 177 | 178 | # overwrite original timm 179 | def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0): 180 | if mask_t_prob > 0.0 or mask_f_prob > 0.0: 181 | x = self.forward_features_mask( 182 | x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob 183 | ) 184 | else: 185 | x = self.forward_features(x) 186 | x = self.head(x) 187 | return x 188 | 189 | 190 | def vit_small_patch16(**kwargs): 191 | model = VisionTransformer( 192 | patch_size=16, 193 | embed_dim=384, 194 | depth=12, 195 | num_heads=6, 196 | mlp_ratio=4, 197 | qkv_bias=True, 198 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 199 | **kwargs 200 | ) 201 | return model 202 | 203 | 204 | def vit_base_patch16(**kwargs): 205 | model = VisionTransformer( 206 | patch_size=16, 207 | embed_dim=768, 208 | depth=12, 209 | num_heads=12, 210 | mlp_ratio=4, 211 | qkv_bias=True, 212 | 
norm_layer=partial(nn.LayerNorm, eps=1e-6), 213 | **kwargs 214 | ) 215 | return model 216 | 217 | 218 | def vit_large_patch16(**kwargs): 219 | model = VisionTransformer( 220 | patch_size=16, 221 | embed_dim=1024, 222 | depth=24, 223 | num_heads=16, 224 | mlp_ratio=4, 225 | qkv_bias=True, 226 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 227 | **kwargs 228 | ) 229 | return model 230 | 231 | 232 | def vit_huge_patch14(**kwargs): 233 | model = VisionTransformer( 234 | patch_size=14, 235 | embed_dim=1280, 236 | depth=32, 237 | num_heads=16, 238 | mlp_ratio=4, 239 | qkv_bias=True, 240 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 241 | **kwargs 242 | ) 243 | return model 244 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | 7 | import multiprocessing as mp 8 | from threading import Thread 9 | from queue import Queue 10 | 11 | from inspect import isfunction 12 | from PIL import Image, ImageDraw, ImageFont 13 | 14 | CACHE = { 15 | "get_vits_phoneme_ids": { 16 | "PAD_LENGTH": 310, 17 | "_pad": "_", 18 | "_punctuation": ';:,.!?¡¿—…"«»“” ', 19 | "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", 20 | "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ", 21 | "_special": "♪☎☒☝⚠", 22 | } 23 | } 24 | 25 | CACHE["get_vits_phoneme_ids"]["symbols"] = ( 26 | [CACHE["get_vits_phoneme_ids"]["_pad"]] 27 | + list(CACHE["get_vits_phoneme_ids"]["_punctuation"]) 28 | + list(CACHE["get_vits_phoneme_ids"]["_letters"]) 29 | + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"]) 30 | + list(CACHE["get_vits_phoneme_ids"]["_special"]) 31 | ) 32 | CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = { 33 | s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"]) 34 | } 35 | 36 | 37 | def get_vits_phoneme_ids_no_padding(phonemes): 38 | pad_token_id = 0 39 | pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"] 40 | _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] 41 | batchsize = len(phonemes) 42 | 43 | clean_text = phonemes[0] + "⚠" 44 | sequence = [] 45 | 46 | for symbol in clean_text: 47 | if symbol not in _symbol_to_id.keys(): 48 | print("%s is not in the vocabulary. %s" % (symbol, clean_text)) 49 | symbol = "_" 50 | symbol_id = _symbol_to_id[symbol] 51 | sequence += [symbol_id] 52 | 53 | def _pad_phonemes(phonemes_list): 54 | return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list)) 55 | 56 | sequence = sequence[:pad_length] 57 | 58 | return { 59 | "phoneme_idx": torch.LongTensor(_pad_phonemes(sequence)) 60 | .unsqueeze(0) 61 | .expand(batchsize, -1) 62 | } 63 | 64 | 65 | def log_txt_as_img(wh, xc, size=10): 66 | # wh a tuple of (width, height) 67 | # xc a list of captions to plot 68 | b = len(xc) 69 | txts = list() 70 | for bi in range(b): 71 | txt = Image.new("RGB", wh, color="white") 72 | draw = ImageDraw.Draw(txt) 73 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 74 | nc = int(40 * (wh[0] / 256)) 75 | lines = "\n".join( 76 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 77 | ) 78 | 79 | try: 80 | draw.text((0, 0), lines, fill="black", font=font) 81 | except UnicodeEncodeError: 82 | print("Cant encode string for logging. 
Skipping.") 83 | 84 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 85 | txts.append(txt) 86 | txts = np.stack(txts) 87 | txts = torch.tensor(txts) 88 | return txts 89 | 90 | 91 | def ismap(x): 92 | if not isinstance(x, torch.Tensor): 93 | return False 94 | return (len(x.shape) == 4) and (x.shape[1] > 3) 95 | 96 | 97 | def isimage(x): 98 | if not isinstance(x, torch.Tensor): 99 | return False 100 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 101 | 102 | 103 | def int16_to_float32(x): 104 | return (x / 32767.0).astype(np.float32) 105 | 106 | 107 | def float32_to_int16(x): 108 | x = np.clip(x, a_min=-1.0, a_max=1.0) 109 | return (x * 32767.0).astype(np.int16) 110 | 111 | 112 | def exists(x): 113 | return x is not None 114 | 115 | 116 | def default(val, d): 117 | if exists(val): 118 | return val 119 | return d() if isfunction(d) else d 120 | 121 | 122 | def mean_flat(tensor): 123 | """ 124 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 125 | Take the mean over all non-batch dimensions. 126 | """ 127 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 128 | 129 | 130 | def count_params(model, verbose=False): 131 | total_params = sum(p.numel() for p in model.parameters()) 132 | if verbose: 133 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 134 | return total_params 135 | 136 | 137 | def instantiate_from_config(config): 138 | if not "target" in config: 139 | if config == "__is_first_stage__": 140 | return None 141 | elif config == "__is_unconditional__": 142 | return None 143 | raise KeyError("Expected key `target` to instantiate.") 144 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 145 | 146 | 147 | def get_obj_from_str(string, reload=False): 148 | module, cls = string.rsplit(".", 1) 149 | if reload: 150 | module_imp = importlib.import_module(module) 151 | importlib.reload(module_imp) 152 | return getattr(importlib.import_module(module, package=None), cls) 153 | 154 | 155 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 156 | # create dummy dataset instance 157 | 158 | # run prefetching 159 | if idx_to_fn: 160 | res = func(data, worker_id=idx) 161 | else: 162 | res = func(data) 163 | Q.put([idx, res]) 164 | Q.put("Done") 165 | 166 | 167 | def parallel_data_prefetch( 168 | func: callable, 169 | data, 170 | n_proc, 171 | target_data_type="ndarray", 172 | cpu_intensive=True, 173 | use_worker_id=False, 174 | ): 175 | # if target_data_type not in ["ndarray", "list"]: 176 | # raise ValueError( 177 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 178 | # ) 179 | if isinstance(data, np.ndarray) and target_data_type == "list": 180 | raise ValueError("list expected but function got ndarray.") 181 | elif isinstance(data, abc.Iterable): 182 | if isinstance(data, dict): 183 | print( 184 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 185 | ) 186 | data = list(data.values()) 187 | if target_data_type == "ndarray": 188 | data = np.asarray(data) 189 | else: 190 | data = list(data) 191 | else: 192 | raise TypeError( 193 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
194 | ) 195 | 196 | if cpu_intensive: 197 | Q = mp.Queue(1000) 198 | proc = mp.Process 199 | else: 200 | Q = Queue(1000) 201 | proc = Thread 202 | # spawn processes 203 | if target_data_type == "ndarray": 204 | arguments = [ 205 | [func, Q, part, i, use_worker_id] 206 | for i, part in enumerate(np.array_split(data, n_proc)) 207 | ] 208 | else: 209 | step = ( 210 | int(len(data) / n_proc + 1) 211 | if len(data) % n_proc != 0 212 | else int(len(data) / n_proc) 213 | ) 214 | arguments = [ 215 | [func, Q, part, i, use_worker_id] 216 | for i, part in enumerate( 217 | [data[i : i + step] for i in range(0, len(data), step)] 218 | ) 219 | ] 220 | processes = [] 221 | for i in range(n_proc): 222 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 223 | processes += [p] 224 | 225 | # start processes 226 | print(f"Start prefetching...") 227 | import time 228 | 229 | start = time.time() 230 | gather_res = [[] for _ in range(n_proc)] 231 | try: 232 | for p in processes: 233 | p.start() 234 | 235 | k = 0 236 | while k < n_proc: 237 | # get result 238 | res = Q.get() 239 | if res == "Done": 240 | k += 1 241 | else: 242 | gather_res[res[0]] = res[1] 243 | 244 | except Exception as e: 245 | print("Exception: ", e) 246 | for p in processes: 247 | p.terminate() 248 | 249 | raise e 250 | finally: 251 | for p in processes: 252 | p.join() 253 | print(f"Prefetching complete. [{time.time() - start} sec.]") 254 | 255 | if target_data_type == "ndarray": 256 | if not isinstance(gather_res[0], np.ndarray): 257 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 258 | 259 | # order outputs 260 | return np.concatenate(gather_res, axis=0) 261 | elif target_data_type == "list": 262 | out = [] 263 | for r in gather_res: 264 | out.extend(r) 265 | return out 266 | else: 267 | return gather_res 268 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/audiomae/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | 15 | # -------------------------------------------------------- 16 | # 2D sine-cosine position embedding 17 | # References: 18 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 19 | # MoCo v3: https://github.com/facebookresearch/moco-v3 20 | # -------------------------------------------------------- 21 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 22 | """ 23 | grid_size: int of the grid height and width 24 | return: 25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 26 | """ 27 | grid_h = np.arange(grid_size, dtype=np.float32) 28 | grid_w = np.arange(grid_size, dtype=np.float32) 29 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 30 | grid = np.stack(grid, axis=0) 31 | 32 | grid = grid.reshape([2, 1, grid_size, grid_size]) 33 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 34 | if cls_token: 35 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 36 | return pos_embed 37 | 38 | 39 | def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False): 40 | """ 41 | grid_size: int of the grid height and width 42 | return: 43 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 44 | """ 45 | grid_h = np.arange(grid_size[0], dtype=np.float32) 46 | grid_w = np.arange(grid_size[1], dtype=np.float32) 47 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 48 | grid = np.stack(grid, axis=0) 49 | 50 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) 51 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 52 | if cls_token: 53 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 54 | return pos_embed 55 | 56 | 57 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 58 | assert embed_dim % 2 == 0 59 | 60 | # use half of dimensions to encode grid_h 61 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 62 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 63 | 64 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 65 | return emb 66 | 67 | 68 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 69 | """ 70 | embed_dim: output dimension for each position 71 | pos: a list of positions to be encoded: size (M,) 72 | out: (M, D) 73 | """ 74 | assert embed_dim % 2 == 0 75 | # omega = np.arange(embed_dim // 2, dtype=np.float) 76 | omega = np.arange(embed_dim // 2, dtype=float) 77 | omega /= embed_dim / 2.0 78 | omega = 1.0 / 10000**omega # (D/2,) 79 | 80 | pos = pos.reshape(-1) # (M,) 81 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 82 | 83 | emb_sin = np.sin(out) # (M, D/2) 84 | emb_cos = np.cos(out) # (M, D/2) 85 | 86 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 87 | return emb 88 | 89 | 90 | # -------------------------------------------------------- 91 | # Interpolate position embeddings for high-resolution 92 | # References: 93 | # DeiT: https://github.com/facebookresearch/deit 94 | # -------------------------------------------------------- 95 | def interpolate_pos_embed(model, checkpoint_model): 96 | if "pos_embed" in checkpoint_model: 97 | pos_embed_checkpoint = 
checkpoint_model["pos_embed"] 98 | embedding_size = pos_embed_checkpoint.shape[-1] 99 | num_patches = model.patch_embed.num_patches 100 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 101 | # height (== width) for the checkpoint position embedding 102 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 103 | # height (== width) for the new position embedding 104 | new_size = int(num_patches**0.5) 105 | # class_token and dist_token are kept unchanged 106 | if orig_size != new_size: 107 | print( 108 | "Position interpolate from %dx%d to %dx%d" 109 | % (orig_size, orig_size, new_size, new_size) 110 | ) 111 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 112 | # only the position tokens are interpolated 113 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 114 | pos_tokens = pos_tokens.reshape( 115 | -1, orig_size, orig_size, embedding_size 116 | ).permute(0, 3, 1, 2) 117 | pos_tokens = torch.nn.functional.interpolate( 118 | pos_tokens, 119 | size=(new_size, new_size), 120 | mode="bicubic", 121 | align_corners=False, 122 | ) 123 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 124 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 125 | checkpoint_model["pos_embed"] = new_pos_embed 126 | 127 | 128 | def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size): 129 | if "pos_embed" in checkpoint_model: 130 | pos_embed_checkpoint = checkpoint_model["pos_embed"] 131 | embedding_size = pos_embed_checkpoint.shape[-1] 132 | num_patches = model.patch_embed.num_patches 133 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 134 | # height (== width) for the checkpoint position embedding 135 | # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 136 | # height (== width) for the new position embedding 137 | # new_size = int(num_patches ** 0.5) 138 | # class_token and dist_token are kept unchanged 139 | if orig_size != new_size: 140 | print( 141 | "Position interpolate from %dx%d to %dx%d" 142 | % (orig_size[0], orig_size[1], new_size[0], new_size[1]) 143 | ) 144 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 145 | # only the position tokens are interpolated 146 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 147 | pos_tokens = pos_tokens.reshape( 148 | -1, orig_size[0], orig_size[1], embedding_size 149 | ).permute(0, 3, 1, 2) 150 | pos_tokens = torch.nn.functional.interpolate( 151 | pos_tokens, 152 | size=(new_size[0], new_size[1]), 153 | mode="bicubic", 154 | align_corners=False, 155 | ) 156 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 157 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 158 | checkpoint_model["pos_embed"] = new_pos_embed 159 | 160 | 161 | def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size): 162 | if "pos_embed" in checkpoint_model: 163 | pos_embed_checkpoint = checkpoint_model["pos_embed"] 164 | embedding_size = pos_embed_checkpoint.shape[-1] 165 | num_patches = model.patch_embed.num_patches 166 | model.pos_embed.shape[-2] - num_patches 167 | if orig_size != new_size: 168 | print( 169 | "Position interpolate from %dx%d to %dx%d" 170 | % (orig_size[0], orig_size[1], new_size[0], new_size[1]) 171 | ) 172 | # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 173 | # only the position tokens are interpolated 174 | cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1) 175 | pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove 176 | pos_tokens = pos_tokens.reshape( 177 
| -1, orig_size[0], orig_size[1], embedding_size 178 | ) # .permute(0, 3, 1, 2) 179 | # pos_tokens = torch.nn.functional.interpolate( 180 | # pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False) 181 | 182 | # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 183 | pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff 184 | pos_tokens = pos_tokens.flatten(1, 2) 185 | new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1) 186 | checkpoint_model["pos_embed"] = new_pos_embed 187 | 188 | 189 | def interpolate_patch_embed_audio( 190 | model, 191 | checkpoint_model, 192 | orig_channel, 193 | new_channel=1, 194 | kernel_size=(16, 16), 195 | stride=(16, 16), 196 | padding=(0, 0), 197 | ): 198 | if orig_channel != new_channel: 199 | if "patch_embed.proj.weight" in checkpoint_model: 200 | # aggregate 3 channels in rgb ckpt to 1 channel for audio 201 | new_proj_weight = torch.nn.Parameter( 202 | torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze( 203 | 1 204 | ) 205 | ) 206 | checkpoint_model["patch_embed.proj.weight"] = new_proj_weight 207 | -------------------------------------------------------------------------------- /audiosr/latent_diffusion/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
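# Minimal usage sketch (assuming this module is importable as
# audiosr.latent_diffusion.modules.diffusionmodules.util): the helpers below
# build the diffusion noise schedule and the sinusoidal timestep embeddings
# fed to the denoising network.
#
#   import torch
#   from audiosr.latent_diffusion.modules.diffusionmodules.util import (
#       make_beta_schedule, timestep_embedding)
#
#   betas = make_beta_schedule("linear", n_timestep=1000)   # np.ndarray, shape (1000,)
#   t = torch.randint(0, 1000, (4,))                        # one timestep index per sample
#   emb = timestep_embedding(t, dim=128)                    # torch.Tensor, shape [4, 128]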
9 | 10 | 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | from einops import repeat 16 | 17 | from audiosr.latent_diffusion.util import instantiate_from_config 18 | 19 | 20 | def make_beta_schedule( 21 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 22 | ): 23 | if schedule == "linear": 24 | betas = ( 25 | torch.linspace( 26 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 27 | ) 28 | ** 2 29 | ) 30 | 31 | elif schedule == "cosine": 32 | timesteps = ( 33 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 34 | ) 35 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 36 | alphas = torch.cos(alphas).pow(2) 37 | alphas = alphas / alphas[0] 38 | betas = 1 - alphas[1:] / alphas[:-1] 39 | # betas = np.clip(betas, a_min=0, a_max=0.999) 40 | 41 | elif schedule == "sqrt_linear": 42 | betas = torch.linspace( 43 | linear_start, linear_end, n_timestep, dtype=torch.float64 44 | ) 45 | elif schedule == "sqrt": 46 | betas = ( 47 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 48 | ** 0.5 49 | ) 50 | else: 51 | raise ValueError(f"schedule '{schedule}' unknown.") 52 | return betas.numpy() 53 | 54 | 55 | def make_ddim_timesteps( 56 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True 57 | ): 58 | if ddim_discr_method == "uniform": 59 | c = num_ddpm_timesteps // num_ddim_timesteps 60 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 61 | elif ddim_discr_method == "quad": 62 | ddim_timesteps = ( 63 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 64 | ).astype(int) 65 | else: 66 | raise NotImplementedError( 67 | f'There is no ddim discretization method called "{ddim_discr_method}"' 68 | ) 69 | 70 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 71 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 72 | steps_out = ddim_timesteps + 1 73 | if verbose: 74 | print(f"Selected timesteps for ddim sampler: {steps_out}") 75 | return steps_out 76 | 77 | 78 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 79 | # select alphas for computing the variance schedule 80 | alphas = alphacums[ddim_timesteps] 81 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 82 | 83 | # according the the formula provided in https://arxiv.org/abs/2010.02502 84 | sigmas = eta * np.sqrt( 85 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) 86 | ) 87 | if verbose: 88 | print( 89 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" 90 | ) 91 | print( 92 | f"For the chosen value of eta, which is {eta}, " 93 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}" 94 | ) 95 | return sigmas, alphas, alphas_prev 96 | 97 | 98 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 99 | """ 100 | Create a beta schedule that discretizes the given alpha_t_bar function, 101 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 102 | :param num_diffusion_timesteps: the number of betas to produce. 103 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 104 | produces the cumulative product of (1-beta) up to that 105 | part of the diffusion process. 106 | :param max_beta: the maximum beta to use; use values lower than 1 to 107 | prevent singularities. 
108 | """ 109 | betas = [] 110 | for i in range(num_diffusion_timesteps): 111 | t1 = i / num_diffusion_timesteps 112 | t2 = (i + 1) / num_diffusion_timesteps 113 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 114 | return np.array(betas) 115 | 116 | 117 | def extract_into_tensor(a, t, x_shape): 118 | b, *_ = t.shape 119 | out = a.gather(-1, t).contiguous() 120 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() 121 | 122 | 123 | def checkpoint(func, inputs, params, flag): 124 | """ 125 | Evaluate a function without caching intermediate activations, allowing for 126 | reduced memory at the expense of extra compute in the backward pass. 127 | :param func: the function to evaluate. 128 | :param inputs: the argument sequence to pass to `func`. 129 | :param params: a sequence of parameters `func` depends on but does not 130 | explicitly take as arguments. 131 | :param flag: if False, disable gradient checkpointing. 132 | """ 133 | if flag: 134 | args = tuple(inputs) + tuple(params) 135 | return CheckpointFunction.apply(func, len(inputs), *args) 136 | else: 137 | return func(*inputs) 138 | 139 | 140 | class CheckpointFunction(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, run_function, length, *args): 143 | ctx.run_function = run_function 144 | ctx.input_tensors = list(args[:length]) 145 | ctx.input_params = list(args[length:]) 146 | 147 | with torch.no_grad(): 148 | output_tensors = ctx.run_function(*ctx.input_tensors) 149 | return output_tensors 150 | 151 | @staticmethod 152 | def backward(ctx, *output_grads): 153 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 154 | with torch.enable_grad(): 155 | # Fixes a bug where the first op in run_function modifies the 156 | # Tensor storage in place, which is not allowed for detach()'d 157 | # Tensors. 158 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 159 | output_tensors = ctx.run_function(*shallow_copies) 160 | input_grads = torch.autograd.grad( 161 | output_tensors, 162 | ctx.input_tensors + ctx.input_params, 163 | output_grads, 164 | allow_unused=True, 165 | ) 166 | del ctx.input_tensors 167 | del ctx.input_params 168 | del output_tensors 169 | return (None, None) + input_grads 170 | 171 | 172 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 173 | """ 174 | Create sinusoidal timestep embeddings. 175 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 176 | These may be fractional. 177 | :param dim: the dimension of the output. 178 | :param max_period: controls the minimum frequency of the embeddings. 179 | :return: an [N x dim] Tensor of positional embeddings. 180 | """ 181 | if not repeat_only: 182 | half = dim // 2 183 | freqs = torch.exp( 184 | -math.log(max_period) 185 | * torch.arange(start=0, end=half, dtype=torch.float32) 186 | / half 187 | ).to(device=timesteps.device) 188 | args = timesteps[:, None].float() * freqs[None] 189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 190 | if dim % 2: 191 | embedding = torch.cat( 192 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 193 | ) 194 | else: 195 | embedding = repeat(timesteps, "b -> b d", d=dim) 196 | return embedding 197 | 198 | 199 | def zero_module(module): 200 | """ 201 | Zero out the parameters of a module and return it. 
202 | """ 203 | for p in module.parameters(): 204 | p.detach().zero_() 205 | return module 206 | 207 | 208 | def scale_module(module, scale): 209 | """ 210 | Scale the parameters of a module and return it. 211 | """ 212 | for p in module.parameters(): 213 | p.detach().mul_(scale) 214 | return module 215 | 216 | 217 | def mean_flat(tensor): 218 | """ 219 | Take the mean over all non-batch dimensions. 220 | """ 221 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 222 | 223 | 224 | def normalization(channels): 225 | """ 226 | Make a standard normalization layer. 227 | :param channels: number of input channels. 228 | :return: an nn.Module for normalization. 229 | """ 230 | return GroupNorm32(32, channels) 231 | 232 | 233 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 234 | class SiLU(nn.Module): 235 | def forward(self, x): 236 | return x * torch.sigmoid(x) 237 | 238 | 239 | class GroupNorm32(nn.GroupNorm): 240 | def forward(self, x): 241 | return super().forward(x.float()).type(x.dtype) 242 | 243 | 244 | def conv_nd(dims, *args, **kwargs): 245 | """ 246 | Create a 1D, 2D, or 3D convolution module. 247 | """ 248 | if dims == 1: 249 | return nn.Conv1d(*args, **kwargs) 250 | elif dims == 2: 251 | return nn.Conv2d(*args, **kwargs) 252 | elif dims == 3: 253 | return nn.Conv3d(*args, **kwargs) 254 | raise ValueError(f"unsupported dimensions: {dims}") 255 | 256 | 257 | def linear(*args, **kwargs): 258 | """ 259 | Create a linear module. 260 | """ 261 | return nn.Linear(*args, **kwargs) 262 | 263 | 264 | def avg_pool_nd(dims, *args, **kwargs): 265 | """ 266 | Create a 1D, 2D, or 3D average pooling module. 267 | """ 268 | if dims == 1: 269 | return nn.AvgPool1d(*args, **kwargs) 270 | elif dims == 2: 271 | return nn.AvgPool2d(*args, **kwargs) 272 | elif dims == 3: 273 | return nn.AvgPool3d(*args, **kwargs) 274 | raise ValueError(f"unsupported dimensions: {dims}") 275 | 276 | 277 | class HybridConditioner(nn.Module): 278 | def __init__(self, c_concat_config, c_crossattn_config): 279 | super().__init__() 280 | self.concat_conditioner = instantiate_from_config(c_concat_config) 281 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 282 | 283 | def forward(self, c_concat, c_crossattn): 284 | c_concat = self.concat_conditioner(c_concat) 285 | c_crossattn = self.crossattn_conditioner(c_crossattn) 286 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} 287 | 288 | 289 | def noise_like(shape, device, repeat=False): 290 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( 291 | shape[0], *((1,) * (len(shape) - 1)) 292 | ) 293 | noise = lambda: torch.randn(shape, device=device) 294 | return repeat_noise() if repeat else noise() 295 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import re 5 | from copy import deepcopy 6 | from pathlib import Path 7 | 8 | import torch 9 | 10 | from .model import CLAP, convert_weights_to_fp16 11 | from .openai import load_openai_model 12 | from .pretrained import get_pretrained_url, download_pretrained 13 | from .transform import image_transform 14 | 15 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 16 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 17 | 18 | 19 | def _natural_key(string_): 20 | return [int(s) if s.isdigit() else s for s in 
re.split(r"(\d+)", string_.lower())] 21 | 22 | 23 | def _rescan_model_configs(): 24 | global _MODEL_CONFIGS 25 | 26 | config_ext = (".json",) 27 | config_files = [] 28 | for config_path in _MODEL_CONFIG_PATHS: 29 | if config_path.is_file() and config_path.suffix in config_ext: 30 | config_files.append(config_path) 31 | elif config_path.is_dir(): 32 | for ext in config_ext: 33 | config_files.extend(config_path.glob(f"*{ext}")) 34 | 35 | for cf in config_files: 36 | if os.path.basename(cf)[0] == ".": 37 | continue # Ignore hidden files 38 | 39 | with open(cf, "r") as f: 40 | model_cfg = json.load(f) 41 | if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): 42 | _MODEL_CONFIGS[cf.stem] = model_cfg 43 | 44 | _MODEL_CONFIGS = { 45 | k: v 46 | for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) 47 | } 48 | 49 | 50 | _rescan_model_configs() # initial populate of model config registry 51 | 52 | 53 | def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): 54 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 55 | if isinstance(checkpoint, dict) and "state_dict" in checkpoint: 56 | state_dict = checkpoint["state_dict"] 57 | else: 58 | state_dict = checkpoint 59 | if skip_params: 60 | if next(iter(state_dict.items()))[0].startswith("module"): 61 | state_dict = {k[7:]: v for k, v in state_dict.items()} 62 | # for k in state_dict: 63 | # if k.startswith('transformer'): 64 | # v = state_dict.pop(k) 65 | # state_dict['text_branch.' + k[12:]] = v 66 | return state_dict 67 | 68 | 69 | def create_model( 70 | amodel_name: str, 71 | tmodel_name: str, 72 | pretrained: str = "", 73 | precision: str = "fp32", 74 | device: torch.device = torch.device("cpu"), 75 | jit: bool = False, 76 | force_quick_gelu: bool = False, 77 | openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), 78 | skip_params=True, 79 | pretrained_audio: str = "", 80 | pretrained_text: str = "", 81 | enable_fusion: bool = False, 82 | fusion_type: str = "None" 83 | # pretrained_image: bool = False, 84 | ): 85 | amodel_name = amodel_name.replace( 86 | "/", "-" 87 | ) # for callers using old naming with / in ViT names 88 | pretrained_orig = pretrained 89 | pretrained = pretrained.lower() 90 | if pretrained == "openai": 91 | if amodel_name in _MODEL_CONFIGS: 92 | logging.info(f"Loading {amodel_name} model config.") 93 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 94 | else: 95 | logging.error( 96 | f"Model config for {amodel_name} not found; available models {list_models()}." 97 | ) 98 | raise RuntimeError(f"Model config for {amodel_name} not found.") 99 | 100 | logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") 101 | # Hard Code in model name 102 | model_cfg["text_cfg"]["model_type"] = tmodel_name 103 | model = load_openai_model( 104 | "ViT-B-16", 105 | model_cfg, 106 | device=device, 107 | jit=jit, 108 | cache_dir=openai_model_cache_dir, 109 | enable_fusion=enable_fusion, 110 | fusion_type=fusion_type, 111 | ) 112 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 113 | if precision == "amp" or precision == "fp32": 114 | model = model.float() 115 | else: 116 | if amodel_name in _MODEL_CONFIGS: 117 | logging.info(f"Loading {amodel_name} model config.") 118 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 119 | else: 120 | logging.error( 121 | f"Model config for {amodel_name} not found; available models {list_models()}." 
122 | ) 123 | raise RuntimeError(f"Model config for {amodel_name} not found.") 124 | 125 | if force_quick_gelu: 126 | # override for use of QuickGELU on non-OpenAI transformer models 127 | model_cfg["quick_gelu"] = True 128 | 129 | # if pretrained_image: 130 | # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): 131 | # # pretrained weight loading for timm models set via vision_cfg 132 | # model_cfg['vision_cfg']['timm_model_pretrained'] = True 133 | # else: 134 | # assert False, 'pretrained image towers currently only supported for timm models' 135 | model_cfg["text_cfg"]["model_type"] = tmodel_name 136 | model_cfg["enable_fusion"] = enable_fusion 137 | model_cfg["fusion_type"] = fusion_type 138 | model = CLAP(**model_cfg) 139 | 140 | if pretrained: 141 | checkpoint_path = "" 142 | url = get_pretrained_url(amodel_name, pretrained) 143 | if url: 144 | checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) 145 | elif os.path.exists(pretrained_orig): 146 | checkpoint_path = pretrained_orig 147 | if checkpoint_path: 148 | logging.info( 149 | f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." 150 | ) 151 | ckpt = load_state_dict(checkpoint_path, skip_params=True) 152 | model.load_state_dict(ckpt) 153 | param_names = [n for n, p in model.named_parameters()] 154 | # for n in param_names: 155 | # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") 156 | else: 157 | logging.warning( 158 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 159 | ) 160 | raise RuntimeError( 161 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 162 | ) 163 | 164 | if pretrained_audio: 165 | if amodel_name.startswith("PANN"): 166 | if "Cnn14_mAP" in pretrained_audio: # official checkpoint 167 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 168 | audio_ckpt = audio_ckpt["model"] 169 | keys = list(audio_ckpt.keys()) 170 | for key in keys: 171 | if ( 172 | "spectrogram_extractor" not in key 173 | and "logmel_extractor" not in key 174 | ): 175 | v = audio_ckpt.pop(key) 176 | audio_ckpt["audio_branch." + key] = v 177 | elif os.path.basename(pretrained_audio).startswith( 178 | "PANN" 179 | ): # checkpoint trained via HTSAT codebase 180 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 181 | audio_ckpt = audio_ckpt["state_dict"] 182 | keys = list(audio_ckpt.keys()) 183 | for key in keys: 184 | if key.startswith("sed_model"): 185 | v = audio_ckpt.pop(key) 186 | audio_ckpt["audio_branch." + key[10:]] = v 187 | elif os.path.basename(pretrained_audio).startswith( 188 | "finetuned" 189 | ): # checkpoint trained via linear probe codebase 190 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 191 | else: 192 | raise ValueError("Unknown audio checkpoint") 193 | elif amodel_name.startswith("HTSAT"): 194 | if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint 195 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 196 | audio_ckpt = audio_ckpt["state_dict"] 197 | keys = list(audio_ckpt.keys()) 198 | for key in keys: 199 | if key.startswith("sed_model") and ( 200 | "spectrogram_extractor" not in key 201 | and "logmel_extractor" not in key 202 | ): 203 | v = audio_ckpt.pop(key) 204 | audio_ckpt["audio_branch." 
+ key[10:]] = v 205 | elif os.path.basename(pretrained_audio).startswith( 206 | "HTSAT" 207 | ): # checkpoint trained via HTSAT codebase 208 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 209 | audio_ckpt = audio_ckpt["state_dict"] 210 | keys = list(audio_ckpt.keys()) 211 | for key in keys: 212 | if key.startswith("sed_model"): 213 | v = audio_ckpt.pop(key) 214 | audio_ckpt["audio_branch." + key[10:]] = v 215 | elif os.path.basename(pretrained_audio).startswith( 216 | "finetuned" 217 | ): # checkpoint trained via linear probe codebase 218 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 219 | else: 220 | raise ValueError("Unknown audio checkpoint") 221 | else: 222 | raise f"this audio encoder pretrained checkpoint is not support" 223 | 224 | model.load_state_dict(audio_ckpt, strict=False) 225 | logging.info( 226 | f"Loading pretrained {amodel_name} weights ({pretrained_audio})." 227 | ) 228 | param_names = [n for n, p in model.named_parameters()] 229 | for n in param_names: 230 | print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") 231 | 232 | model.to(device=device) 233 | if precision == "fp16": 234 | assert device.type != "cpu" 235 | convert_weights_to_fp16(model) 236 | 237 | if jit: 238 | model = torch.jit.script(model) 239 | 240 | return model, model_cfg 241 | 242 | 243 | def create_model_and_transforms( 244 | model_name: str, 245 | pretrained: str = "", 246 | precision: str = "fp32", 247 | device: torch.device = torch.device("cpu"), 248 | jit: bool = False, 249 | force_quick_gelu: bool = False, 250 | # pretrained_image: bool = False, 251 | ): 252 | model = create_model( 253 | model_name, 254 | pretrained, 255 | precision, 256 | device, 257 | jit, 258 | force_quick_gelu=force_quick_gelu, 259 | # pretrained_image=pretrained_image 260 | ) 261 | preprocess_train = image_transform(model.visual.image_size, is_train=True) 262 | preprocess_val = image_transform(model.visual.image_size, is_train=False) 263 | return model, preprocess_train, preprocess_val 264 | 265 | 266 | def list_models(): 267 | """enumerate available model architectures based on config files""" 268 | return list(_MODEL_CONFIGS.keys()) 269 | 270 | 271 | def add_model_config(path): 272 | """add model config path or file and update registry""" 273 | if not isinstance(path, Path): 274 | path = Path(path) 275 | _MODEL_CONFIG_PATHS.append(path) 276 | _rescan_model_configs() 277 | -------------------------------------------------------------------------------- /audiosr/latent_encoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from audiosr.latent_diffusion.modules.ema import * 7 | 8 | from audiosr.latent_diffusion.modules.diffusionmodules.model import Encoder, Decoder 9 | from audiosr.latent_diffusion.modules.distributions.distributions import ( 10 | DiagonalGaussianDistribution, 11 | ) 12 | import soundfile as sf 13 | 14 | from audiosr.utilities.model import get_vocoder 15 | from audiosr.utilities.tools import synth_one_sample 16 | 17 | 18 | class AutoencoderKL(nn.Module): 19 | def __init__( 20 | self, 21 | ddconfig=None, 22 | lossconfig=None, 23 | batchsize=None, 24 | embed_dim=None, 25 | time_shuffle=1, 26 | subband=1, 27 | sampling_rate=16000, 28 | ckpt_path=None, 29 | reload_from_ckpt=None, 30 | ignore_keys=[], 31 | image_key="fbank", 32 | colorize_nlabels=None, 33 | monitor=None, 34 | base_learning_rate=1e-5, 35 | ): 36 | 
super().__init__() 37 | self.automatic_optimization = False 38 | assert ( 39 | "mel_bins" in ddconfig.keys() 40 | ), "mel_bins is not specified in the Autoencoder config" 41 | num_mel = ddconfig["mel_bins"] 42 | self.image_key = image_key 43 | self.sampling_rate = sampling_rate 44 | self.encoder = Encoder(**ddconfig) 45 | self.decoder = Decoder(**ddconfig) 46 | 47 | self.loss = None 48 | self.subband = int(subband) 49 | 50 | if self.subband > 1: 51 | print("Use subband decomposition %s" % self.subband) 52 | 53 | assert ddconfig["double_z"] 54 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 55 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 56 | 57 | if self.image_key == "fbank": 58 | self.vocoder = get_vocoder(None, "cpu", num_mel) 59 | self.embed_dim = embed_dim 60 | if colorize_nlabels is not None: 61 | assert type(colorize_nlabels) == int 62 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 63 | if monitor is not None: 64 | self.monitor = monitor 65 | if ckpt_path is not None: 66 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 67 | self.learning_rate = float(base_learning_rate) 68 | # print("Initial learning rate %s" % self.learning_rate) 69 | 70 | self.time_shuffle = time_shuffle 71 | self.reload_from_ckpt = reload_from_ckpt 72 | self.reloaded = False 73 | self.mean, self.std = None, None 74 | 75 | self.feature_cache = None 76 | self.flag_first_run = True 77 | self.train_step = 0 78 | 79 | self.logger_save_dir = None 80 | self.logger_exp_name = None 81 | 82 | def get_log_dir(self): 83 | if self.logger_save_dir is None and self.logger_exp_name is None: 84 | return os.path.join(self.logger.save_dir, self.logger._project) 85 | else: 86 | return os.path.join(self.logger_save_dir, self.logger_exp_name) 87 | 88 | def set_log_dir(self, save_dir, exp_name): 89 | self.logger_save_dir = save_dir 90 | self.logger_exp_name = exp_name 91 | 92 | def init_from_ckpt(self, path, ignore_keys=list()): 93 | sd = torch.load(path, map_location="cpu")["state_dict"] 94 | keys = list(sd.keys()) 95 | for k in keys: 96 | for ik in ignore_keys: 97 | if k.startswith(ik): 98 | print("Deleting key {} from state_dict.".format(k)) 99 | del sd[k] 100 | self.load_state_dict(sd, strict=False) 101 | print(f"Restored from {path}") 102 | 103 | def encode(self, x): 104 | # x = self.time_shuffle_operation(x) 105 | # x = self.freq_split_subband(x) 106 | h = self.encoder(x) 107 | moments = self.quant_conv(h) 108 | posterior = DiagonalGaussianDistribution(moments) 109 | return posterior 110 | 111 | def decode(self, z): 112 | z = self.post_quant_conv(z) 113 | dec = self.decoder(z) 114 | # bs, ch, shuffled_timesteps, fbins = dec.size() 115 | # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins) 116 | # dec = self.freq_merge_subband(dec) 117 | return dec 118 | 119 | def decode_to_waveform(self, dec): 120 | from audiosr.utilities.model import vocoder_infer 121 | 122 | if self.image_key == "fbank": 123 | dec = dec.squeeze(1).permute(0, 2, 1) 124 | wav_reconstruction = vocoder_infer(dec, self.vocoder) 125 | elif self.image_key == "stft": 126 | dec = dec.squeeze(1).permute(0, 2, 1) 127 | wav_reconstruction = self.wave_decoder(dec) 128 | return wav_reconstruction 129 | 130 | def visualize_latent(self, input): 131 | import matplotlib.pyplot as plt 132 | 133 | # for i in range(10): 134 | # zero_input = torch.zeros_like(input) - 11.59 135 | # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59 136 | 137 | # posterior 
= self.encode(zero_input) 138 | # latent = posterior.sample() 139 | # avg_latent = torch.mean(latent, dim=1)[0] 140 | # plt.imshow(avg_latent.cpu().detach().numpy().T) 141 | # plt.savefig("%s.png" % i) 142 | # plt.close() 143 | 144 | np.save("input.npy", input.cpu().detach().numpy()) 145 | # zero_input = torch.zeros_like(input) - 11.59 146 | time_input = input.clone() 147 | time_input[:, :, :, :32] *= 0 148 | time_input[:, :, :, :32] -= 11.59 149 | 150 | np.save("time_input.npy", time_input.cpu().detach().numpy()) 151 | 152 | posterior = self.encode(time_input) 153 | latent = posterior.sample() 154 | np.save("time_latent.npy", latent.cpu().detach().numpy()) 155 | avg_latent = torch.mean(latent, dim=1) 156 | for i in range(avg_latent.size(0)): 157 | plt.imshow(avg_latent[i].cpu().detach().numpy().T) 158 | plt.savefig("freq_%s.png" % i) 159 | plt.close() 160 | 161 | freq_input = input.clone() 162 | freq_input[:, :, :512, :] *= 0 163 | freq_input[:, :, :512, :] -= 11.59 164 | 165 | np.save("freq_input.npy", freq_input.cpu().detach().numpy()) 166 | 167 | posterior = self.encode(freq_input) 168 | latent = posterior.sample() 169 | np.save("freq_latent.npy", latent.cpu().detach().numpy()) 170 | avg_latent = torch.mean(latent, dim=1) 171 | for i in range(avg_latent.size(0)): 172 | plt.imshow(avg_latent[i].cpu().detach().numpy().T) 173 | plt.savefig("time_%s.png" % i) 174 | plt.close() 175 | 176 | def get_input(self, batch): 177 | fname, text, label_indices, waveform, stft, fbank = ( 178 | batch["fname"], 179 | batch["text"], 180 | batch["label_vector"], 181 | batch["waveform"], 182 | batch["stft"], 183 | batch["log_mel_spec"], 184 | ) 185 | # if(self.time_shuffle != 1): 186 | # if(fbank.size(1) % self.time_shuffle != 0): 187 | # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle) 188 | # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len)) 189 | 190 | ret = {} 191 | 192 | ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = ( 193 | fbank.unsqueeze(1), 194 | stft.unsqueeze(1), 195 | fname, 196 | waveform.unsqueeze(1), 197 | ) 198 | 199 | return ret 200 | 201 | def save_wave(self, batch_wav, fname, save_dir): 202 | os.makedirs(save_dir, exist_ok=True) 203 | 204 | for wav, name in zip(batch_wav, fname): 205 | name = os.path.basename(name) 206 | 207 | sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate) 208 | 209 | def get_last_layer(self): 210 | return self.decoder.conv_out.weight 211 | 212 | @torch.no_grad() 213 | def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs): 214 | log = dict() 215 | x = batch.to(self.device) 216 | if not only_inputs: 217 | xrec, posterior = self(x) 218 | log["samples"] = self.decode(posterior.sample()) 219 | log["reconstructions"] = xrec 220 | 221 | log["inputs"] = x 222 | wavs = self._log_img(log, train=train, index=0, waveform=waveform) 223 | return wavs 224 | 225 | def _log_img(self, log, train=True, index=0, waveform=None): 226 | images_input = self.tensor2numpy(log["inputs"][index, 0]).T 227 | images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T 228 | images_samples = self.tensor2numpy(log["samples"][index, 0]).T 229 | 230 | if train: 231 | name = "train" 232 | else: 233 | name = "val" 234 | 235 | if self.logger is not None: 236 | self.logger.log_image( 237 | "img_%s" % name, 238 | [images_input, images_reconstruct, images_samples], 239 | caption=["input", "reconstruct", "samples"], 240 | ) 241 | 242 | inputs, reconstructions, samples = ( 243 | log["inputs"], 244 | 
log["reconstructions"], 245 | log["samples"], 246 | ) 247 | 248 | if self.image_key == "fbank": 249 | wav_original, wav_prediction = synth_one_sample( 250 | inputs[index], 251 | reconstructions[index], 252 | labels="validation", 253 | vocoder=self.vocoder, 254 | ) 255 | wav_original, wav_samples = synth_one_sample( 256 | inputs[index], samples[index], labels="validation", vocoder=self.vocoder 257 | ) 258 | wav_original, wav_samples, wav_prediction = ( 259 | wav_original[0], 260 | wav_samples[0], 261 | wav_prediction[0], 262 | ) 263 | elif self.image_key == "stft": 264 | wav_prediction = ( 265 | self.decode_to_waveform(reconstructions)[index, 0] 266 | .cpu() 267 | .detach() 268 | .numpy() 269 | ) 270 | wav_samples = ( 271 | self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy() 272 | ) 273 | wav_original = waveform[index, 0].cpu().detach().numpy() 274 | 275 | if self.logger is not None: 276 | self.logger.experiment.log( 277 | { 278 | "original_%s" 279 | % name: wandb.Audio( 280 | wav_original, caption="original", sample_rate=self.sampling_rate 281 | ), 282 | "reconstruct_%s" 283 | % name: wandb.Audio( 284 | wav_prediction, 285 | caption="reconstruct", 286 | sample_rate=self.sampling_rate, 287 | ), 288 | "samples_%s" 289 | % name: wandb.Audio( 290 | wav_samples, caption="samples", sample_rate=self.sampling_rate 291 | ), 292 | } 293 | ) 294 | 295 | return wav_original, wav_prediction, wav_samples 296 | 297 | def tensor2numpy(self, tensor): 298 | return tensor.cpu().detach().numpy() 299 | 300 | def to_rgb(self, x): 301 | assert self.image_key == "segmentation" 302 | if not hasattr(self, "colorize"): 303 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 304 | x = F.conv2d(x, weight=self.colorize) 305 | x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 306 | return x 307 | 308 | 309 | class IdentityFirstStage(torch.nn.Module): 310 | def __init__(self, *args, vq_interface=False, **kwargs): 311 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 312 | super().__init__() 313 | 314 | def encode(self, x, *args, **kwargs): 315 | return x 316 | 317 | def decode(self, x, *args, **kwargs): 318 | return x 319 | 320 | def quantize(self, x, *args, **kwargs): 321 | if self.vq_interface: 322 | return x, None, [None, None, None] 323 | return x 324 | 325 | def forward(self, x, *args, **kwargs): 326 | return x 327 | -------------------------------------------------------------------------------- /audiosr/clap/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | from torchvision.ops.misc import FrozenBatchNorm2d 5 | import logging 6 | from tqdm import tqdm 7 | import random 8 | import json 9 | import os 10 | import pathlib 11 | 12 | # TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. 
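# Maps each supported dataset name to the splits that were prepared for it.
# Consumed by exist() and get_tar_path_from_dataset_name() below when resolving
# which webdataset tar shards to load, e.g.:
#   exist("Clotho", "valid")  -> True
#   exist("esc50", "valid")   -> False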
13 | dataset_split = { 14 | "audiocaps": ["train", "valid", "test"], 15 | "audioset": ["balanced_train", "unbalanced_train", "eval"], 16 | "BBCSoundEffects": ["train", "test"], 17 | "Clotho": ["train", "test", "valid"], 18 | "free_to_use_sounds": ["train", "test"], 19 | "paramount_motion": ["train", "test"], 20 | "sonniss_game_effects": ["train", "test"], 21 | "wesoundeffects": ["train", "test"], 22 | "MACS": ["train", "test"], 23 | "freesound": ["train", "test"], 24 | "FSD50K": ["train", "test", "valid"], 25 | "fsd50k_class_label": ["train", "test", "valid"], 26 | "esc50": ["train", "test"], 27 | "audiostock": ["train", "test"], 28 | "freesound_no_overlap_noesc50": ["train", "test"], 29 | "epidemic_sound_effects": ["train", "test"], 30 | "VGGSound": ["train", "test"], 31 | "urbansound8k_class_label": ["train", "test"], 32 | "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], 33 | "epidemic_sound_effects_t5": ["train", "test"], 34 | "WavText5K": ["train", "test"], 35 | "esc50_no_overlap": ["train", "test"], 36 | "usd8k_no_overlap": ["train", "test"], 37 | "fsd50k_200_class_label": ["train", "test", "valid"], 38 | } 39 | 40 | 41 | def freeze_batch_norm_2d(module, module_match={}, name=""): 42 | """ 43 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 44 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 45 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 46 | 47 | Args: 48 | module (torch.nn.Module): Any PyTorch module. 49 | module_match (dict): Dictionary of full module names to freeze (all if empty) 50 | name (str): Full module name (prefix) 51 | 52 | Returns: 53 | torch.nn.Module: Resulting module 54 | 55 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 56 | """ 57 | res = module 58 | is_match = True 59 | if module_match: 60 | is_match = name in module_match 61 | if is_match and isinstance( 62 | module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) 63 | ): 64 | res = FrozenBatchNorm2d(module.num_features) 65 | res.num_features = module.num_features 66 | res.affine = module.affine 67 | if module.affine: 68 | res.weight.data = module.weight.data.clone().detach() 69 | res.bias.data = module.bias.data.clone().detach() 70 | res.running_mean.data = module.running_mean.data 71 | res.running_var.data = module.running_var.data 72 | res.eps = module.eps 73 | else: 74 | for child_name, child in module.named_children(): 75 | full_child_name = ".".join([name, child_name]) if name else child_name 76 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 77 | if new_child is not child: 78 | res.add_module(child_name, new_child) 79 | return res 80 | 81 | 82 | def exist(dataset_name, dataset_type): 83 | """ 84 | Check if dataset exists 85 | """ 86 | if dataset_type in dataset_split[dataset_name]: 87 | return True 88 | else: 89 | return False 90 | 91 | 92 | def get_tar_path_from_dataset_name( 93 | dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None 94 | ): 95 | """ 96 | Get tar path from dataset name and type 97 | """ 98 | output = [] 99 | for n in dataset_names: 100 | if full_dataset is not None and n in full_dataset: 101 | current_dataset_types = dataset_split[n] 102 | else: 103 | current_dataset_types = dataset_types 104 | for s in current_dataset_types: 105 | tmp = [] 
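            # Prefer a sizes.json found under the local dataset_path; otherwise fall
            # back to the bundled ./json_files metadata. When not running locally, the
            # shards are streamed from the s-laion-audio S3 bucket via a "pipe:aws s3 ... cp" URL.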
106 | if islocal: 107 | sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" 108 | if not os.path.exists(sizefilepath_): 109 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" 110 | else: 111 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" 112 | if not os.path.exists(sizefilepath_): 113 | continue 114 | sizes = json.load(open(sizefilepath_, "r")) 115 | for k in sizes.keys(): 116 | if islocal: 117 | tmp.append(f"{dataset_path}/{n}/{s}/{k}") 118 | else: 119 | tmp.append( 120 | f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" 121 | ) 122 | if proportion != 1: 123 | tmp = random.sample(tmp, int(proportion * len(tmp))) 124 | output.append(tmp) 125 | return sum(output, []) 126 | 127 | 128 | def get_tar_path_from_txts(txt_path, islocal, proportion=1): 129 | """ 130 | Get tar path from txt path 131 | """ 132 | if isinstance(txt_path, (list, tuple)): 133 | return sum( 134 | [ 135 | get_tar_path_from_txts( 136 | txt_path[i], islocal=islocal, proportion=proportion 137 | ) 138 | for i in range(len(txt_path)) 139 | ], 140 | [], 141 | ) 142 | if isinstance(txt_path, str): 143 | with open(txt_path) as f: 144 | lines = f.readlines() 145 | if islocal: 146 | lines = [ 147 | lines[i] 148 | .split("\n")[0] 149 | .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") 150 | for i in range(len(lines)) 151 | ] 152 | else: 153 | lines = [ 154 | lines[i].split("\n")[0].replace(".tar", ".tar -") 155 | for i in range(len(lines)) 156 | ] 157 | if proportion != 1: 158 | print("Sampling tars with proportion of {}".format(proportion)) 159 | lines = random.sample(lines, int(proportion * len(lines))) 160 | return lines 161 | 162 | 163 | def get_mix_lambda(mixup_alpha, batch_size): 164 | mixup_lambdas = [ 165 | np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) 166 | ] 167 | return np.array(mixup_lambdas).astype(np.float32) 168 | 169 | 170 | def do_mixup(x, mixup_lambda): 171 | """ 172 | Args: 173 | x: (batch_size , ...) 174 | mixup_lambda: (batch_size,) 175 | Returns: 176 | out: (batch_size, ...) 177 | """ 178 | out = ( 179 | x.transpose(0, -1) * mixup_lambda 180 | + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) 181 | ).transpose(0, -1) 182 | return out 183 | 184 | 185 | def interpolate(x, ratio): 186 | """Interpolate data in time domain. This is used to compensate the 187 | resolution reduction in downsampling of a CNN. 188 | 189 | Args: 190 | x: (batch_size, time_steps, classes_num) 191 | ratio: int, ratio to interpolate 192 | Returns: 193 | upsampled: (batch_size, time_steps * ratio, classes_num) 194 | """ 195 | (batch_size, time_steps, classes_num) = x.shape 196 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 197 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 198 | return upsampled 199 | 200 | 201 | def pad_framewise_output(framewise_output, frames_num): 202 | """Pad framewise_output to the same length as input frames. The pad value 203 | is the same as the value of the last frame. 
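    (The CNN's temporal pooling shortens the framewise output, so the last frame is repeated here until the output again has frames_num frames.)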
204 | Args: 205 | framewise_output: (batch_size, frames_num, classes_num) 206 | frames_num: int, number of frames to pad 207 | Outputs: 208 | output: (batch_size, frames_num, classes_num) 209 | """ 210 | pad = framewise_output[:, -1:, :].repeat( 211 | 1, frames_num - framewise_output.shape[1], 1 212 | ) 213 | """tensor for padding""" 214 | 215 | output = torch.cat((framewise_output, pad), dim=1) 216 | """(batch_size, frames_num, classes_num)""" 217 | 218 | 219 | # def process_ipc(index_path, classes_num, filename): 220 | # # load data 221 | # logging.info("Load Data...............") 222 | # ipc = [[] for _ in range(classes_num)] 223 | # with h5py.File(index_path, "r") as f: 224 | # for i in tqdm(range(len(f["target"]))): 225 | # t_class = np.where(f["target"][i])[0] 226 | # for t in t_class: 227 | # ipc[t].append(i) 228 | # print(ipc) 229 | # np.save(filename, ipc) 230 | # logging.info("Load Data Succeed...............") 231 | 232 | 233 | def save_to_dict(s, o_={}): 234 | sp = s.split(": ") 235 | o_.update({sp[0]: float(sp[1])}) 236 | return o_ 237 | 238 | 239 | def get_data_from_log(txt_path): 240 | """ 241 | Output dictionary from out.txt log file 242 | """ 243 | with open(txt_path) as f: 244 | lines = f.readlines() 245 | val_data = {} 246 | train_data = {} 247 | train_losses = [] 248 | train_losses_epoch = [] 249 | for i in range(len(lines)): 250 | if "| INFO |" in lines[i]: 251 | if "Eval Epoch" in lines[i]: 252 | if "val_loss" in lines[i]: 253 | # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) 254 | line = lines[i].split("Eval Epoch: ")[-1] 255 | num_epoch = int(line.split(" ")[0].split(" ")[0]) 256 | d = { 257 | line.split(" ")[0] 258 | .split(" ")[1] 259 | .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) 260 | } 261 | for i in range(1, len(line.split(" "))): 262 | d = save_to_dict(line.split(" ")[i], d) 263 | val_data[num_epoch] = d 264 | elif "Train Epoch" in lines[i]: 265 | num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) 266 | loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) 267 | train_losses.append(loss) 268 | train_losses_epoch.append(num_epoch) 269 | for i in range(len(train_losses)): 270 | train_data[i] = { 271 | "num_epoch": train_losses_epoch[i], 272 | "train_loss": train_losses[i], 273 | } 274 | return train_data, val_data 275 | 276 | 277 | def save_p(obj, filename): 278 | import pickle 279 | 280 | try: 281 | from deepdiff import DeepDiff 282 | except: 283 | os.system("pip install deepdiff") 284 | from deepdiff import DeepDiff 285 | with open(filename, "wb") as file: 286 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol 287 | with open(filename, "rb") as file: 288 | z = pickle.load(file) 289 | assert ( 290 | DeepDiff(obj, z, ignore_string_case=True) == {} 291 | ), "there is something wrong with the saving process" 292 | return 293 | 294 | 295 | def load_p(filename): 296 | import pickle 297 | 298 | with open(filename, "rb") as file: 299 | z = pickle.load(file) 300 | return z 301 | 302 | 303 | def save_json(data, name="data.json"): 304 | import json 305 | 306 | with open(name, "w") as fp: 307 | json.dump(data, fp) 308 | return 309 | 310 | 311 | def load_json(name): 312 | import json 313 | 314 | with open(name, "r") as fp: 315 | data = json.load(fp) 316 | return data 317 | 318 | 319 | def load_class_label(path): 320 | # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing 321 | # 
https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array 322 | out = None 323 | if path is not None: 324 | if pathlib.Path(path).suffix in [".pkl", ".pickle"]: 325 | out = load_p(path) 326 | elif pathlib.Path(path).suffix in [".json", ".txt"]: 327 | out = load_json(path) 328 | elif pathlib.Path(path).suffix in [".npy", ".npz"]: 329 | out = np.load(path) 330 | elif pathlib.Path(path).suffix in [".csv"]: 331 | import pandas as pd 332 | 333 | out = pd.read_csv(path) 334 | return out 335 | # if out is None: 336 | # return None 337 | # else: 338 | # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) 339 | # val = Array('i', out.values(), lock=False) 340 | # return (key, val) 341 | 342 | 343 | from torch import optim 344 | 345 | 346 | def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): 347 | if optimizer_name.lower() == "adamw": 348 | optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps) 349 | elif optimizer_name.lower() == "sgd": 350 | optimizer = optim.SGD(params, lr=lr, momentum=momentum) 351 | elif optimizer_name.lower() == "adam": 352 | optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps) 353 | else: 354 | raise ValueError("optimizer name is not correct") 355 | return optimizer 356 | --------------------------------------------------------------------------------
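A minimal usage sketch for two of the diffusion helpers defined in audiosr/latent_diffusion/modules/diffusionmodules/util.py above. It is illustrative only and not part of the repository: it assumes the audiosr package and its dependencies are importable, and the shapes, schedule, and variable names are invented for the example.

import torch

from audiosr.latent_diffusion.modules.diffusionmodules.util import (
    extract_into_tensor,
    timestep_embedding,
)

batch = 4
t = torch.randint(0, 1000, (batch,))      # one diffusion timestep index per sample
emb = timestep_embedding(t, dim=128)      # -> [4, 128] sinusoidal timestep embedding

# extract_into_tensor gathers one coefficient per sample from a 1-D schedule and
# reshapes it to [batch, 1, 1, 1] so it broadcasts over a latent tensor.
alphas_cumprod = torch.linspace(0.9999, 0.01, steps=1000)   # stand-in schedule
x = torch.randn(batch, 8, 256, 16)        # stand-in latent
coeff = extract_into_tensor(alphas_cumprod.sqrt(), t, x.shape)   # -> [4, 1, 1, 1]
x_scaled = coeff * x                      # a full q_sample step would also add scaled noise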