├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── app.py ├── audioldm ├── __init__.py ├── __main__.py ├── audio │ ├── __init__.py │ ├── audio_processing.py │ ├── stft.py │ └── tools.py ├── clap │ ├── __init__.py │ ├── encoders.py │ ├── open_clip │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── factory.py │ │ ├── feature_fusion.py │ │ ├── htsat.py │ │ ├── linear_probe.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── model_configs │ │ │ ├── HTSAT-base.json │ │ │ ├── HTSAT-large.json │ │ │ ├── HTSAT-tiny-win-1536.json │ │ │ ├── HTSAT-tiny.json │ │ │ ├── PANN-10.json │ │ │ ├── PANN-14-fmax-18k.json │ │ │ ├── PANN-14-fmax-8k-20s.json │ │ │ ├── PANN-14-tiny-transformer.json │ │ │ ├── PANN-14-win-1536.json │ │ │ ├── PANN-14.json │ │ │ ├── PANN-6.json │ │ │ ├── RN101-quickgelu.json │ │ │ ├── RN101.json │ │ │ ├── RN50-quickgelu.json │ │ │ ├── RN50.json │ │ │ ├── RN50x16.json │ │ │ ├── RN50x4.json │ │ │ ├── ViT-B-16.json │ │ │ ├── ViT-B-32-quickgelu.json │ │ │ ├── ViT-B-32.json │ │ │ └── ViT-L-14.json │ │ ├── openai.py │ │ ├── pann_model.py │ │ ├── pretrained.py │ │ ├── timm_model.py │ │ ├── tokenizer.py │ │ ├── transform.py │ │ ├── utils.py │ │ └── version.py │ └── training │ │ ├── __init__.py │ │ ├── audioset_textmap.npy │ │ ├── data.py │ │ ├── distributed.py │ │ ├── imagenet_zeroshot_data.py │ │ ├── infer_demo.py │ │ ├── logger.py │ │ ├── lp_main.py │ │ ├── lp_train.py │ │ ├── main.py │ │ ├── params.py │ │ ├── scheduler.py │ │ ├── train.py │ │ └── zero_shot.py ├── hifigan │ ├── __init__.py │ ├── models.py │ └── utilities.py ├── latent_diffusion │ ├── __init__.py │ ├── attention.py │ ├── ddim.py │ ├── ddpm.py │ ├── ema.py │ ├── openaimodel.py │ └── util.py ├── ldm.py ├── pipeline.py ├── utils.py └── variational_autoencoder │ ├── __init__.py │ ├── autoencoder.py │ ├── distributions.py │ └── modules.py ├── bg.png ├── bin ├── audioldm └── audioldm.cmd ├── ckpt └── .gitkeep ├── scripts ├── test.sh └── text2sound.py ├── setup.py └── trumpet.wav /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | __pycache__ 3 | build 4 | dist 5 | output 6 | *temp.py 7 | *.wav 8 | gradio_cached_examples 9 | run.sh 10 | run2.sh 11 | output/ -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.py LICENSE README.md 2 | recursive-include audioldm *.txt *.py *.gz *.npy *.json -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :sound: Audio Generation with AudioLDM (ICML 2023) 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2301.12503-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2301.12503) [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://audioldm.github.io/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/olaviinha/NeuralTextToAudio/blob/main/AudioLDM_pub.ipynb?force_theme=dark) [![Replicate](https://replicate.com/jagilley/audio-ldm/badge)](https://replicate.com/jagilley/audio-ldm) 4 | 5 | 6 | 7 | **Generate speech, sound effects, music and beyond.** 8 
| 9 | This repo currently supports: 10 | 11 | - **Text-to-Audio Generation**: Generate audio given a text input. 12 | - **Audio-to-Audio Generation**: Given an audio clip, generate another audio clip that contains the same type of sound. 13 | - **Text-guided Audio-to-Audio Style Transfer**: Transfer the sound of an audio clip into another one using a text description. 14 | 15 |
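The command line and the web app described below both wrap a small Python API. If you prefer to call it directly, the sketch below mirrors the calls made in `app.py` and `audioldm/__main__.py` (`build_model`, `text_to_audio`, `save_wave`); the argument names are copied from those files, so treat this as an indicative example rather than a stable, documented API.

```python
from audioldm import build_model, text_to_audio, save_wave

# Load a checkpoint (cached under ~/.cache/audioldm unless AUDIOLDM_CACHE_DIR is set).
audioldm = build_model(model_name="audioldm-s-full")

# Returns a batch of waveforms with shape [batch, 1, samples] at 16 kHz.
waveform = text_to_audio(
    audioldm,
    text="A hammer is hitting a wooden surface",
    seed=42,
    duration=5.0,                # the CLI asserts that duration is a multiple of 2.5 s
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,  # generate 3 candidates and keep the best one
)

save_wave(waveform, "./output", name="hammer")  # writes .wav files into ./output
```

For audio-to-audio generation and style transfer, `audioldm/__main__.py` shows the corresponding `text_to_audio(..., file_path)` and `style_transfer(...)` calls.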
16 | 17 | ## Important tricks to make your generated audio sound better 18 | 1. Try to provide more hints to AudioLDM, such as using more adjectives to describe your sound (e.g., clearly, high quality) or making your target more specific (e.g., "water stream in a forest" instead of "stream"). This helps make sure AudioLDM understands what you want. 19 | 2. Try different random seeds, which can sometimes affect the generation quality significantly. 20 | 3. It's best to use general terms like 'man' or 'woman' instead of specific names of individuals or abstract objects that humans may not be familiar with. 21 | 22 | # Change Log 23 | 24 | **2023-04-10**: Finetune AudioLDM with the MusicCaps and AudioCaps datasets. Add three more checkpoints, including audioldm-m-text-ft, audioldm-s-text-ft, and audioldm-m-full. 25 | 26 | **2023-03-04**: Add two more checkpoints: a small model with more training steps and a large model. Add model selection to the Gradio app. 27 | 28 | **2023-02-24**: Add audio-to-audio generation. Add test cases. Add a pipeline (Python function) for audio super-resolution and inpainting. 29 | 30 | **2023-02-15**: Add audio style transfer. Add more generation options. 31 | 32 | ## Web APP 33 | 34 | The web app currently only supports Text-to-Audio generation. For full functionality please refer to the [Commandline Usage](https://github.com/haoheliu/AudioLDM#commandline-usage). 35 | 36 | 1. Prepare the running environment 37 | ```shell 38 | conda create -n audioldm python=3.8; conda activate audioldm 39 | pip3 install git+https://github.com/haoheliu/AudioLDM.git 40 | git clone https://github.com/haoheliu/AudioLDM; cd AudioLDM 41 | ``` 42 | 2. Start the web application (powered by Gradio) 43 | ```shell 44 | python3 app.py 45 | ``` 46 | 3. A link will be printed out. Open the link in your browser to try the demo. 47 | 48 | ## Commandline Usage 49 | Prepare the running environment 50 | ```shell 51 | # Optional 52 | conda create -n audioldm python=3.8; conda activate audioldm 53 | # Install AudioLDM 54 | pip3 install git+https://github.com/haoheliu/AudioLDM.git 55 | ``` 56 | 57 | :star2: **Text-to-Audio Generation**: generate audio guided by a text prompt 58 | ```shell 59 | # The default --mode is "generation" 60 | audioldm -t "A hammer is hitting a wooden surface" 61 | # Result will be saved in "./output/generation" 62 | ``` 63 | 64 | :star2: **Audio-to-Audio Generation**: generate audio guided by an audio file (the output will contain similar audio events to the input audio file). 65 | ```shell 66 | audioldm --file_path trumpet.wav 67 | # Result will be saved in "./output/generation_audio_to_audio/trumpet" 68 | ``` 69 | 70 | :star2: **Text-guided Audio-to-Audio Style Transfer** 71 | ```shell 72 | # Test run 73 | # --file_path is the original audio file for transfer 74 | # -t is the text AudioLDM uses for the transfer. 75 | # Please make sure that --file_path exists 76 | audioldm --mode "transfer" --file_path trumpet.wav -t "Children Singing" 77 | # Result will be saved in "./output/transfer/trumpet" 78 | 79 | # Tuning the value of --transfer_strength is important! 80 | # --transfer_strength: A value between 0 and 1. 0 means the original audio without transfer, 1 means completely transferring to the audio indicated by the text 81 | audioldm --mode "transfer" --file_path trumpet.wav -t "Children Singing" --transfer_strength 0.25 82 | ``` 83 | 84 | :gear: How to choose between different model checkpoints? 
85 | ``` 86 | # Add the --model_name parameter, choice={audioldm-m-text-ft, audioldm-s-text-ft, audioldm-m-full, audioldm-s-full, audioldm-l-full, audioldm-s-full-v2} 87 | audioldm --model_name audioldm-s-full 88 | ``` 89 | 90 | - :star: audioldm-m-full (default, **recommended**): the medium AudioLDM without finetuning, trained with audio embeddings as the condition *(added 2023-04-10)*. 91 | - :star: audioldm-s-full (**recommended**): the original open-sourced version *(added 2023-02-01)*. 92 | - :star: audioldm-s-full-v2 (**recommended**): more training steps compared with audioldm-s-full *(added 2023-03-04)*. 93 | - audioldm-s-text-ft: the small AudioLDM finetuned with AudioCaps and MusicCaps audio-text pairs *(added 2023-04-10)*. 94 | - audioldm-m-text-ft: the medium AudioLDM finetuned with AudioCaps and MusicCaps audio-text pairs *(added 2023-04-10)*. 95 | - audioldm-l-full: a larger model compared with audioldm-s-full *(added 2023-03-04)*. 96 | 97 | > @haoheliu personally did an evaluation of the overall quality of the checkpoints, which gives audioldm-m-full (6.85/10), audioldm-s-full (6.62/10), audioldm-s-text-ft (6/10), and audioldm-m-text-ft (5.46/10). These scores are only for reference and may not reflect the true performance of each checkpoint. Checkpoint performance also varies with different text inputs. 98 | 99 | :grey_question: For more options on guidance scale, batchsize, seed, DDIM steps, etc., please run 100 | ```shell 101 | audioldm -h 102 | ``` 103 | ```console 104 | usage: audioldm [-h] [--mode {generation,transfer}] [-t TEXT] [-f FILE_PATH] [--transfer_strength TRANSFER_STRENGTH] [-s SAVE_PATH] [--model_name {audioldm-s-full,audioldm-l-full,audioldm-s-full-v2}] [-ckpt CKPT_PATH] 105 | [-b BATCHSIZE] [--ddim_steps DDIM_STEPS] [-gs GUIDANCE_SCALE] [-dur DURATION] [-n N_CANDIDATE_GEN_PER_TEXT] [--seed SEED] 106 | 107 | optional arguments: 108 | -h, --help show this help message and exit 109 | --mode {generation,transfer} 110 | generation: text-to-audio generation; transfer: style transfer 111 | -t TEXT, --text TEXT Text prompt to the model for audio generation, DEFAULT "" 112 | -f FILE_PATH, --file_path FILE_PATH 113 | (--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio, DEFAULT None 114 | --transfer_strength TRANSFER_STRENGTH 115 | A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text, DEFAULT 0.5 116 | -s SAVE_PATH, --save_path SAVE_PATH 117 | The path to save model output, DEFAULT "./output" 118 | --model_name {audioldm-s-full,audioldm-l-full,audioldm-s-full-v2} 119 | The checkpoint you gonna use, DEFAULT "audioldm-s-full" 120 | -ckpt CKPT_PATH, --ckpt_path CKPT_PATH 121 | (deprecated) The path to the pretrained .ckpt model, DEFAULT None 122 | -b BATCHSIZE, --batchsize BATCHSIZE 123 | Generate how many samples at the same time, DEFAULT 1 124 | --ddim_steps DDIM_STEPS 125 | The sampling step for DDIM, DEFAULT 200 126 | -gs GUIDANCE_SCALE, --guidance_scale GUIDANCE_SCALE 127 | Guidance scale (Large => better quality and relavancy to text; Small => better diversity), DEFAULT 2.5 128 | -dur DURATION, --duration DURATION 129 | The duration of the samples, DEFAULT 10 130 | -n N_CANDIDATE_GEN_PER_TEXT, --n_candidate_gen_per_text N_CANDIDATE_GEN_PER_TEXT 131 | Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). 
A Larger value usually lead to better quality with heavier computation, DEFAULT 3 132 | --seed SEED Change this value (any integer number) will lead to a different generation result. DEFAULT 42 133 | ``` 134 | 135 | For the evaluation of audio generative models, please refer to [audioldm_eval](https://github.com/haoheliu/audioldm_eval). 136 | 137 | # Hugging Face 🧨 Diffusers 138 | 139 | AudioLDM is available in the Hugging Face [🧨 Diffusers](https://github.com/huggingface/diffusers) library from v0.15.0 onwards. The official checkpoints can be found on the [Hugging Face Hub](https://huggingface.co/cvssp), alongside [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm) and [example scripts](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm). 140 | 141 | To install Diffusers and Transformers, run: 142 | ```bash 143 | pip install --upgrade diffusers transformers 144 | ``` 145 | 146 | You can then load pre-trained weights into the [AudioLDM pipeline](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm) and generate text-conditional audio outputs: 147 | ```python 148 | from diffusers import AudioLDMPipeline 149 | import torch 150 | 151 | repo_id = "cvssp/audioldm-s-full-v2" 152 | pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) 153 | pipe = pipe.to("cuda") 154 | 155 | prompt = "Techno music with a strong, upbeat tempo and high melodic riffs" 156 | audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0] 157 | ``` 158 | 159 | # Web Demo 160 | 161 | Integrated into [Hugging Face Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation) 162 | 163 | # TuneFlow Demo 164 | 165 | Try out AudioLDM as a [TuneFlow](https://tuneflow.com) plugin [![TuneFlow x AudioLDM](https://img.shields.io/badge/TuneFlow-AudioLDM-%23C563E6%20)](https://github.com/tuneflow/AudioLDM). See how it can work in a real DAW (Digital Audio Workstation). 166 | 167 | # TODO 168 | 169 | [!["Buy Me A Coffee"](https://www.buymeacoffee.com/assets/img/custom_images/orange_img.png)](https://www.buymeacoffee.com/haoheliuP) 170 | 171 | - [x] Update the checkpoint with more training steps. 172 | - [x] Update the checkpoint with more parameters (audioldm-l). 
173 | - [ ] Add AudioCaps finetuned AudioLDM-S model 174 | - [x] Build pip installable package for commandline use 175 | - [x] Build Gradio web application 176 | - [ ] Add super-resolution, inpainting into Gradio web application 177 | - [ ] Add style-transfer into Gradio web application 178 | - [x] Add text-guided style transfer 179 | - [x] Add audio-to-audio generation 180 | - [x] Add audio super-resolution 181 | - [x] Add audio inpainting 182 | 183 | ## Cite this work 184 | 185 | If you found this tool useful, please consider citing 186 | ```bibtex 187 | @article{liu2023audioldm, 188 | title={{AudioLDM}: Text-to-Audio Generation with Latent Diffusion Models}, 189 | author={Liu, Haohe and Chen, Zehua and Yuan, Yi and Mei, Xinhao and Liu, Xubo and Mandic, Danilo and Wang, Wenwu and Plumbley, Mark D}, 190 | journal={Proceedings of the International Conference on Machine Learning}, 191 | year={2023}, 192 | pages={21450-21474} 193 | } 194 | ``` 195 | 196 | # Hardware requirement 197 | - GPU with 8GB of dedicated VRAM 198 | - A system with a 64-bit operating system (Windows 7, 8.1 or 10, Ubuntu 16.04 or later, or macOS 10.13 or later) 16GB or more of system RAM 199 | 200 | ## Reference 201 | Part of the code is borrowed from the following repos. We would like to thank the authors of these repos for their contribution. 202 | 203 | > https://github.com/LAION-AI/CLAP 204 | 205 | > https://github.com/CompVis/stable-diffusion 206 | 207 | > https://github.com/v-iashin/SpecVQGAN 208 | 209 | > https://github.com/toshas/torch-fidelity 210 | 211 | 212 | We build the model with data from AudioSet, Freesound and BBC Sound Effect library. We share this demo based on the UK copyright exception of data for academic research. 213 | 214 | 215 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import numpy as np 3 | from audioldm import text_to_audio, build_model 4 | 5 | # from share_btn import community_icon_html, loading_icon_html, share_js 6 | 7 | model_id = "haoheliu/AudioLDM-S-Full" 8 | 9 | audioldm = None 10 | current_model_name = None 11 | # audioldm=None 12 | 13 | # def predict(input, history=[]): 14 | # # tokenize the new input sentence 15 | # new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt') 16 | 17 | # # append the new user input tokens to the chat history 18 | # bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1) 19 | 20 | # # generate a response 21 | # history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist() 22 | 23 | # # convert the tokens to text, and then split the responses into lines 24 | # response = tokenizer.decode(history[0]).split("<|endoftext|>") 25 | # response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list 26 | # return response, history 27 | 28 | def text2audio(text, duration, guidance_scale, random_seed, n_candidates, model_name): 29 | global audioldm, current_model_name 30 | 31 | if audioldm is None or model_name != current_model_name: 32 | audioldm=build_model(model_name=model_name) 33 | current_model_name = model_name 34 | 35 | # print(text, length, guidance_scale) 36 | waveform = text_to_audio( 37 | latent_diffusion=audioldm, 38 | text=text, 39 | seed=random_seed, 40 | duration=duration, 41 | guidance_scale=guidance_scale, 42 | 
n_candidate_gen_per_text=int(n_candidates), 43 | ) # [bs, 1, samples] 44 | waveform = [ 45 | gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform 46 | ] 47 | # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))] 48 | if len(waveform) == 1: 49 | waveform = waveform[0] 50 | return waveform 51 | 52 | 53 | # iface = gr.Interface(fn=text2audio, inputs=[ 54 | # gr.Textbox(value="A man is speaking in a huge room", max_lines=1), 55 | # gr.Slider(2.5, 10, value=5, step=2.5), 56 | # gr.Slider(0, 5, value=2.5, step=0.5), 57 | # gr.Number(value=42) 58 | # ], outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")], 59 | # allow_flagging="never" 60 | # ) 61 | # iface.launch(share=True) 62 | 63 | 64 | css = """ 65 | a { 66 | color: inherit; 67 | text-decoration: underline; 68 | } 69 | .gradio-container { 70 | font-family: 'IBM Plex Sans', sans-serif; 71 | } 72 | .gr-button { 73 | color: white; 74 | border-color: #000000; 75 | background: #000000; 76 | } 77 | input[type='range'] { 78 | accent-color: #000000; 79 | } 80 | .dark input[type='range'] { 81 | accent-color: #dfdfdf; 82 | } 83 | .container { 84 | max-width: 730px; 85 | margin: auto; 86 | padding-top: 1.5rem; 87 | } 88 | #gallery { 89 | min-height: 22rem; 90 | margin-bottom: 15px; 91 | margin-left: auto; 92 | margin-right: auto; 93 | border-bottom-right-radius: .5rem !important; 94 | border-bottom-left-radius: .5rem !important; 95 | } 96 | #gallery>div>.h-full { 97 | min-height: 20rem; 98 | } 99 | .details:hover { 100 | text-decoration: underline; 101 | } 102 | .gr-button { 103 | white-space: nowrap; 104 | } 105 | .gr-button:focus { 106 | border-color: rgb(147 197 253 / var(--tw-border-opacity)); 107 | outline: none; 108 | box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); 109 | --tw-border-opacity: 1; 110 | --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); 111 | --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color); 112 | --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity)); 113 | --tw-ring-opacity: .5; 114 | } 115 | #advanced-btn { 116 | font-size: .7rem !important; 117 | line-height: 19px; 118 | margin-top: 12px; 119 | margin-bottom: 12px; 120 | padding: 2px 8px; 121 | border-radius: 14px !important; 122 | } 123 | #advanced-options { 124 | margin-bottom: 20px; 125 | } 126 | .footer { 127 | margin-bottom: 45px; 128 | margin-top: 35px; 129 | text-align: center; 130 | border-bottom: 1px solid #e5e5e5; 131 | } 132 | .footer>p { 133 | font-size: .8rem; 134 | display: inline-block; 135 | padding: 0 10px; 136 | transform: translateY(10px); 137 | background: white; 138 | } 139 | .dark .footer { 140 | border-color: #303030; 141 | } 142 | .dark .footer>p { 143 | background: #0b0f19; 144 | } 145 | .acknowledgments h4{ 146 | margin: 1.25em 0 .25em 0; 147 | font-weight: bold; 148 | font-size: 115%; 149 | } 150 | #container-advanced-btns{ 151 | display: flex; 152 | flex-wrap: wrap; 153 | justify-content: space-between; 154 | align-items: center; 155 | } 156 | .animate-spin { 157 | animation: spin 1s linear infinite; 158 | } 159 | @keyframes spin { 160 | from { 161 | transform: rotate(0deg); 162 | } 163 | to { 164 | transform: rotate(360deg); 165 | } 166 | } 167 | #share-btn-container { 168 | display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; 
align-items: center; border-radius: 9999px !important; width: 13rem; 169 | margin-top: 10px; 170 | margin-left: auto; 171 | } 172 | #share-btn { 173 | all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; 174 | } 175 | #share-btn * { 176 | all: unset; 177 | } 178 | #share-btn-container div:nth-child(-n+2){ 179 | width: auto !important; 180 | min-height: 0px !important; 181 | } 182 | #share-btn-container .wrap { 183 | display: none !important; 184 | } 185 | .gr-form{ 186 | flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0; 187 | } 188 | #prompt-container{ 189 | gap: 0; 190 | } 191 | #generated_id{ 192 | min-height: 700px 193 | } 194 | #setting_id{ 195 | margin-bottom: 12px; 196 | text-align: center; 197 | font-weight: 900; 198 | } 199 | """ 200 | iface = gr.Blocks(css=css) 201 | 202 | with iface: 203 | gr.HTML( 204 | """ 205 |
206 | 214 | 215 | AudioLDM: Text-to-Audio Generation with Latent Diffusion Models 216 | 217 | 218 | 219 | [Paper] [Project page] 220 | 221 |
222 | """ 223 | ) 224 | with gr.Group(): 225 | with gr.Box(): 226 | ############# Input 227 | textbox = gr.Textbox( 228 | value="A hammer is hitting a wooden surface", 229 | max_lines=1, 230 | label="Input your text here. Please ensure it is descriptive and of moderate length.", 231 | elem_id="prompt-in", 232 | ) 233 | 234 | with gr.Accordion("Click to modify detailed configurations", open=False): 235 | seed = gr.Number( 236 | value=42, 237 | label="Change this value (any integer number) will lead to a different generation result.", 238 | ) 239 | duration = gr.Slider( 240 | 2.5, 10, value=5, step=2.5, label="Duration (seconds)" 241 | ) 242 | guidance_scale = gr.Slider( 243 | 0, 244 | 5, 245 | value=2.5, 246 | step=0.5, 247 | label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", 248 | ) 249 | n_candidates = gr.Slider( 250 | 1, 251 | 5, 252 | value=3, 253 | step=1, 254 | label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation", 255 | ) 256 | model_name = gr.Dropdown( 257 | ["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2","audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large", 258 | ) 259 | ############# Output 260 | # outputs=gr.Audio(label="Output", type="numpy") 261 | outputs = gr.Video(label="Output", elem_id="output-video") 262 | 263 | # with gr.Group(elem_id="container-advanced-btns"): 264 | # # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn") 265 | # with gr.Group(elem_id="share-btn-container"): 266 | # community_icon = gr.HTML(community_icon_html, visible=False) 267 | # loading_icon = gr.HTML(loading_icon_html, visible=False) 268 | # share_button = gr.Button("Share to community", elem_id="share-btn", visible=False) 269 | # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")] 270 | btn = gr.Button("Submit").style(full_width=True) 271 | 272 | # with gr.Group(elem_id="share-btn-container", visible=False): 273 | # community_icon = gr.HTML(community_icon_html) 274 | # loading_icon = gr.HTML(loading_icon_html) 275 | # share_button = gr.Button("Share to community", elem_id="share-btn") 276 | 277 | btn.click( 278 | text2audio, 279 | inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name], 280 | outputs=[outputs], 281 | ) 282 | 283 | # share_button.click(None, [], [], _js=share_js) 284 | gr.HTML( 285 | """ 286 | 293 | """ 294 | ) 295 | # gr.Examples( 296 | # [ 297 | # ["A hammer is hitting a wooden surface", 5, 2.5, 45, 3, "audioldm-s-full"], 298 | # [ 299 | # "Peaceful and calming ambient music with singing bowl and other instruments.", 300 | # 5, 301 | # 2.5, 302 | # 45, 303 | # 3, 304 | # "audioldm-s-full" 305 | # ], 306 | # ["A man is speaking in a small room.", 5, 2.5, 45, 3, "audioldm-s-full"], 307 | # ["A female is speaking followed by footstep sound", 5, 2.5, 45, 3, "audioldm-s-full"], 308 | # [ 309 | # "Wooden table tapping sound followed by water pouring sound.", 310 | # 5, 311 | # 2.5, 312 | # 45, 313 | # 3, 314 | # "audioldm-s-full" 315 | # ], 316 | # ], 317 | # fn=text2audio, 318 | # inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name], 319 | # outputs=[outputs], 320 | # cache_examples=True, 321 | 
# ) 322 | with gr.Accordion("Additional information", open=False): 323 | gr.HTML( 324 | """ 325 |
326 | We build the model with data from AudioSet, Freesound and BBC Sound Effect library. We share this demo based on the UK copyright exception of data for academic research. 327 |
328 | """ 329 | ) 330 | #

This demo is strictly for research demo purposes only. For commercial use please contact us.

331 | 332 | iface.queue(concurrency_count=3) 333 | # iface.launch(debug=True) 334 | iface.launch(debug=True, share=False) 335 | -------------------------------------------------------------------------------- /audioldm/__init__.py: -------------------------------------------------------------------------------- 1 | from .ldm import LatentDiffusion 2 | from .utils import seed_everything, save_wave, get_time, get_duration 3 | from .pipeline import * 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /audioldm/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration 4 | import argparse 5 | 6 | CACHE_DIR = os.getenv( 7 | "AUDIOLDM_CACHE_DIR", 8 | os.path.join(os.path.expanduser("~"), ".cache/audioldm")) 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument( 13 | "--mode", 14 | type=str, 15 | required=False, 16 | default="generation", 17 | help="generation: text-to-audio generation; transfer: style transfer", 18 | choices=["generation", "transfer"] 19 | ) 20 | 21 | parser.add_argument( 22 | "-t", 23 | "--text", 24 | type=str, 25 | required=False, 26 | default="", 27 | help="Text prompt to the model for audio generation", 28 | ) 29 | 30 | parser.add_argument( 31 | "-f", 32 | "--file_path", 33 | type=str, 34 | required=False, 35 | default=None, 36 | help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio", 37 | ) 38 | 39 | parser.add_argument( 40 | "--transfer_strength", 41 | type=float, 42 | required=False, 43 | default=0.5, 44 | help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text", 45 | ) 46 | 47 | parser.add_argument( 48 | "-s", 49 | "--save_path", 50 | type=str, 51 | required=False, 52 | help="The path to save model output", 53 | default="./output", 54 | ) 55 | 56 | parser.add_argument( 57 | "--model_name", 58 | type=str, 59 | required=False, 60 | help="The checkpoint you gonna use", 61 | default="audioldm-m-full", 62 | choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2","audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full"] 63 | ) 64 | 65 | parser.add_argument( 66 | "-ckpt", 67 | "--ckpt_path", 68 | type=str, 69 | required=False, 70 | help="The path to the pretrained .ckpt model", 71 | default=None, 72 | ) 73 | 74 | parser.add_argument( 75 | "-b", 76 | "--batchsize", 77 | type=int, 78 | required=False, 79 | default=1, 80 | help="Generate how many samples at the same time", 81 | ) 82 | 83 | parser.add_argument( 84 | "--ddim_steps", 85 | type=int, 86 | required=False, 87 | default=200, 88 | help="The sampling step for DDIM", 89 | ) 90 | 91 | parser.add_argument( 92 | "-gs", 93 | "--guidance_scale", 94 | type=float, 95 | required=False, 96 | default=2.5, 97 | help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", 98 | ) 99 | 100 | parser.add_argument( 101 | "-dur", 102 | "--duration", 103 | type=float, 104 | required=False, 105 | default=10.0, 106 | help="The duration of the samples", 107 | ) 108 | 109 | parser.add_argument( 110 | "-n", 111 | "--n_candidate_gen_per_text", 112 | type=int, 113 | required=False, 114 | default=3, 115 | help="Automatic quality control. 
This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation", 116 | ) 117 | 118 | parser.add_argument( 119 | "--seed", 120 | type=int, 121 | required=False, 122 | default=42, 123 | help="Change this value (any integer number) will lead to a different generation result.", 124 | ) 125 | 126 | args = parser.parse_args() 127 | 128 | if(args.ckpt_path is not None): 129 | print("Warning: ckpt_path has no effect after version 0.0.20.") 130 | 131 | assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5" 132 | 133 | mode = args.mode 134 | if(mode == "generation" and args.file_path is not None): 135 | mode = "generation_audio_to_audio" 136 | if(len(args.text) > 0): 137 | print("Warning: You have specified the --file_path. --text will be ignored") 138 | args.text = "" 139 | 140 | save_path = os.path.join(args.save_path, mode) 141 | 142 | if(args.file_path is not None): 143 | save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0])) 144 | 145 | text = args.text 146 | random_seed = args.seed 147 | duration = args.duration 148 | guidance_scale = args.guidance_scale 149 | n_candidate_gen_per_text = args.n_candidate_gen_per_text 150 | 151 | os.makedirs(save_path, exist_ok=True) 152 | audioldm = build_model(model_name=args.model_name) 153 | 154 | if(args.mode == "generation"): 155 | waveform = text_to_audio( 156 | audioldm, 157 | text, 158 | args.file_path, 159 | random_seed, 160 | duration=duration, 161 | guidance_scale=guidance_scale, 162 | ddim_steps=args.ddim_steps, 163 | n_candidate_gen_per_text=n_candidate_gen_per_text, 164 | batchsize=args.batchsize, 165 | ) 166 | 167 | elif(args.mode == "transfer"): 168 | assert args.file_path is not None 169 | assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path 170 | waveform = style_transfer( 171 | audioldm, 172 | text, 173 | args.file_path, 174 | args.transfer_strength, 175 | random_seed, 176 | duration=duration, 177 | guidance_scale=guidance_scale, 178 | ddim_steps=args.ddim_steps, 179 | batchsize=args.batchsize, 180 | ) 181 | waveform = waveform[:,None,:] 182 | 183 | save_wave(waveform, save_path, name="%s_%s" % (get_time(), text)) 184 | -------------------------------------------------------------------------------- /audioldm/audio/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import wav_to_fbank, read_wav_file 2 | from .stft import TacotronSTFT 3 | -------------------------------------------------------------------------------- /audioldm/audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 
22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return normalize_fun(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /audioldm/audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from audioldm.audio.audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = 
torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 | inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data, 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | 
self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes, normalize_fun): 152 | output = dynamic_range_compression(magnitudes, normalize_fun) 153 | return output 154 | 155 | def spectral_de_normalize(self, magnitudes): 156 | output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y, normalize_fun=torch.log): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1, torch.min(y.data) 170 | assert torch.max(y.data) <= 1, torch.max(y.data) 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output, normalize_fun) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) 179 | 180 | return mel_output, log_magnitudes, energy 181 | -------------------------------------------------------------------------------- /audioldm/audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchaudio 4 | 5 | 6 | def get_mel_from_wav(audio, _stft): 7 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 8 | audio = torch.autograd.Variable(audio, requires_grad=False) 9 | melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) 10 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 11 | log_magnitudes_stft = ( 12 | torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32) 13 | ) 14 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 15 | return melspec, log_magnitudes_stft, energy 16 | 17 | 18 | def _pad_spec(fbank, target_length=1024): 19 | n_frames = fbank.shape[0] 20 | p = target_length - n_frames 21 | # cut and pad 22 | if p > 0: 23 | m = torch.nn.ZeroPad2d((0, 0, 0, p)) 24 | fbank = m(fbank) 25 | elif p < 0: 26 | fbank = fbank[0:target_length, :] 27 | 28 | if fbank.size(-1) % 2 != 0: 29 | fbank = fbank[..., :-1] 30 | 31 | return fbank 32 | 33 | 34 | def pad_wav(waveform, segment_length): 35 | waveform_length = waveform.shape[-1] 36 | assert waveform_length > 100, "Waveform is too short, %s" % waveform_length 37 | if segment_length is None or waveform_length == segment_length: 38 | return waveform 39 | elif waveform_length > segment_length: 40 | return waveform[:segment_length] 41 | elif waveform_length < segment_length: 42 | temp_wav = np.zeros((1, segment_length)) 43 | temp_wav[:, :waveform_length] = waveform 44 | return temp_wav 45 | 46 | def normalize_wav(waveform): 47 | waveform = waveform - np.mean(waveform) 48 | waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) 49 | return waveform * 0.5 50 | 51 | 52 | def read_wav_file(filename, segment_length): 53 | # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower 54 | waveform, sr = torchaudio.load(filename) # Faster!!! 
55 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) 56 | waveform = waveform.numpy()[0, ...] 57 | waveform = normalize_wav(waveform) 58 | waveform = waveform[None, ...] 59 | waveform = pad_wav(waveform, segment_length) 60 | 61 | waveform = waveform / np.max(np.abs(waveform)) 62 | waveform = 0.5 * waveform 63 | 64 | return waveform 65 | 66 | 67 | def wav_to_fbank(filename, target_length=1024, fn_STFT=None): 68 | assert fn_STFT is not None 69 | 70 | # mixup 71 | waveform = read_wav_file(filename, target_length * 160) # hop size is 160 72 | 73 | waveform = waveform[0, ...] 74 | waveform = torch.FloatTensor(waveform) 75 | 76 | fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) 77 | 78 | fbank = torch.FloatTensor(fbank.T) 79 | log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T) 80 | 81 | fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( 82 | log_magnitudes_stft, target_length 83 | ) 84 | 85 | return fbank, log_magnitudes_stft, waveform 86 | -------------------------------------------------------------------------------- /audioldm/clap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/clap/__init__.py -------------------------------------------------------------------------------- /audioldm/clap/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from audioldm.clap.open_clip import create_model 4 | from audioldm.clap.training.data import get_audio_features 5 | import torchaudio 6 | from transformers import RobertaTokenizer 7 | import torch.nn.functional as F 8 | 9 | 10 | class CLAPAudioEmbeddingClassifierFreev2(nn.Module): 11 | def __init__( 12 | self, 13 | pretrained_path="", 14 | key="class", 15 | sampling_rate=16000, 16 | embed_mode="audio", 17 | amodel = "HTSAT-tiny", 18 | unconditional_prob=0.1, 19 | random_mute=False, 20 | max_random_mute_portion=0.5, 21 | training_mode=True, 22 | ): 23 | super().__init__() 24 | 25 | self.key = key 26 | self.device = "cpu" 27 | self.precision = "fp32" 28 | self.amodel = amodel # or 'PANN-14' 29 | self.tmodel = "roberta" # the best text encoder in our training 30 | self.enable_fusion = False # False if you do not want to use the fusion model 31 | self.fusion_type = "aff_2d" 32 | self.pretrained = pretrained_path 33 | self.embed_mode = embed_mode 34 | self.embed_mode_orig = embed_mode 35 | self.sampling_rate = sampling_rate 36 | self.unconditional_prob = unconditional_prob 37 | self.random_mute = random_mute 38 | self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") 39 | self.max_random_mute_portion = max_random_mute_portion 40 | self.training_mode = training_mode 41 | self.model, self.model_cfg = create_model( 42 | self.amodel, 43 | self.tmodel, 44 | self.pretrained, 45 | precision=self.precision, 46 | device=self.device, 47 | enable_fusion=self.enable_fusion, 48 | fusion_type=self.fusion_type, 49 | ) 50 | for p in self.model.parameters(): 51 | p.requires_grad = False 52 | 53 | self.model.eval() 54 | 55 | def get_unconditional_condition(self, batchsize): 56 | self.unconditional_token = self.model.get_text_embedding( 57 | self.tokenizer(["", ""]) 58 | )[0:1] 59 | return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) 60 | 61 | def batch_to_list(self, batch): 62 | ret = [] 63 | for i in 
range(batch.size(0)): 64 | ret.append(batch[i]) 65 | return ret 66 | 67 | def make_decision(self, probability): 68 | if float(torch.rand(1)) < probability: 69 | return True 70 | else: 71 | return False 72 | 73 | def random_uniform(self, start, end): 74 | val = torch.rand(1).item() 75 | return start + (end - start) * val 76 | 77 | def _random_mute(self, waveform): 78 | # waveform: [bs, t-steps] 79 | t_steps = waveform.size(-1) 80 | for i in range(waveform.size(0)): 81 | mute_size = int( 82 | self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) 83 | ) 84 | mute_start = int(self.random_uniform(0, t_steps - mute_size)) 85 | waveform[i, mute_start : mute_start + mute_size] = 0 86 | return waveform 87 | 88 | def cos_similarity(self, waveform, text): 89 | # waveform: [bs, t_steps] 90 | with torch.no_grad(): 91 | self.embed_mode = "audio" 92 | audio_emb = self(waveform.cuda()) 93 | self.embed_mode = "text" 94 | text_emb = self(text) 95 | similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) 96 | return similarity.squeeze() 97 | 98 | def forward(self, batch, key=None): 99 | # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 100 | # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 101 | if self.model.training == True and not self.training_mode: 102 | print( 103 | "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." 104 | ) 105 | self.model, self.model_cfg = create_model( 106 | self.amodel, 107 | self.tmodel, 108 | self.pretrained, 109 | precision=self.precision, 110 | device="cuda", 111 | enable_fusion=self.enable_fusion, 112 | fusion_type=self.fusion_type, 113 | ) 114 | for p in self.model.parameters(): 115 | p.requires_grad = False 116 | self.model.eval() 117 | 118 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 119 | if self.embed_mode == "audio": 120 | with torch.no_grad(): 121 | audio_dict_list = [] 122 | assert ( 123 | self.sampling_rate == 16000 124 | ), "We only support 16000 sampling rate" 125 | if self.random_mute: 126 | batch = self._random_mute(batch) 127 | # batch: [bs, 1, t-samples] 128 | batch = torchaudio.functional.resample( 129 | batch, orig_freq=self.sampling_rate, new_freq=48000 130 | ) 131 | for waveform in self.batch_to_list(batch): 132 | audio_dict = {} 133 | audio_dict = get_audio_features( 134 | audio_dict, 135 | waveform, 136 | 480000, 137 | data_truncating="fusion", 138 | data_filling="repeatpad", 139 | audio_cfg=self.model_cfg["audio_cfg"], 140 | ) 141 | audio_dict_list.append(audio_dict) 142 | # [bs, 512] 143 | embed = self.model.get_audio_embedding(audio_dict_list) 144 | elif self.embed_mode == "text": 145 | with torch.no_grad(): 146 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 147 | text_data = self.tokenizer(batch) 148 | embed = self.model.get_text_embedding(text_data) 149 | 150 | embed = embed.unsqueeze(1) 151 | self.unconditional_token = self.model.get_text_embedding( 152 | self.tokenizer(["", ""]) 153 | )[0:1] 154 | 155 | for i in range(embed.size(0)): 156 | if self.make_decision(self.unconditional_prob): 157 | embed[i] = self.unconditional_token 158 | 159 | # [bs, 1, 512] 160 | return embed.detach() 161 | 162 | def tokenizer(self, text): 163 | result = self.tokenize( 164 | text, 165 | padding="max_length", 166 | truncation=True, 167 | max_length=512, 168 | return_tensors="pt", 169 | ) 170 | return {k: v.squeeze(0) for k, v in 
result.items()} 171 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import ( 2 | list_models, 3 | create_model, 4 | create_model_and_transforms, 5 | add_model_config, 6 | ) 7 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics 8 | from .model import ( 9 | CLAP, 10 | CLAPTextCfg, 11 | CLAPVisionCfg, 12 | CLAPAudioCfp, 13 | convert_weights_to_fp16, 14 | trace_model, 15 | ) 16 | from .openai import load_openai_model, list_openai_models 17 | from .pretrained import ( 18 | list_pretrained, 19 | list_pretrained_tag_models, 20 | list_pretrained_model_tags, 21 | get_pretrained_url, 22 | download_pretrained, 23 | ) 24 | from .tokenizer import SimpleTokenizer, tokenize 25 | from .transform import image_transform 26 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, BertModel 2 | 3 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 4 | model = BertModel.from_pretrained("bert-base-uncased") 5 | text = "Replace me by any text you'd like." 6 | 7 | 8 | def bert_embeddings(text): 9 | # text = "Replace me by any text you'd like." 10 | encoded_input = tokenizer(text, return_tensors="pt") 11 | output = model(**encoded_input) 12 | return output 13 | 14 | 15 | from transformers import RobertaTokenizer, RobertaModel 16 | 17 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 18 | model = RobertaModel.from_pretrained("roberta-base") 19 | text = "Replace me by any text you'd like." 20 | 21 | 22 | def Roberta_embeddings(text): 23 | # text = "Replace me by any text you'd like." 24 | encoded_input = tokenizer(text, return_tensors="pt") 25 | output = model(**encoded_input) 26 | return output 27 | 28 | 29 | from transformers import BartTokenizer, BartModel 30 | 31 | tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 32 | model = BartModel.from_pretrained("facebook/bart-base") 33 | text = "Replace me by any text you'd like." 34 | 35 | 36 | def bart_embeddings(text): 37 | # text = "Replace me by any text you'd like." 
38 | encoded_input = tokenizer(text, return_tensors="pt") 39 | output = model(**encoded_input) 40 | return output 41 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /audioldm/clap/open_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | from .model import CLAP, convert_weights_to_fp16 12 | from .openai import load_openai_model 13 | from .pretrained import get_pretrained_url, download_pretrained 14 | from .transform import image_transform 15 | 16 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 17 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 18 | CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache/audioldm") 19 | 20 | 21 | 22 | def _natural_key(string_): 23 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 24 | 25 | 26 | def _rescan_model_configs(): 27 | global _MODEL_CONFIGS 28 | 29 | config_ext = (".json",) 30 | config_files = [] 31 | for config_path in _MODEL_CONFIG_PATHS: 32 | if config_path.is_file() and config_path.suffix in config_ext: 33 | config_files.append(config_path) 34 | elif config_path.is_dir(): 35 | for ext in config_ext: 36 | config_files.extend(config_path.glob(f"*{ext}")) 37 | 38 | for cf in config_files: 39 | if os.path.basename(cf)[0] == ".": 40 | continue # Ignore hidden files 41 | 42 | with open(cf, "r") as f: 43 | model_cfg = json.load(f) 44 | if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): 45 | _MODEL_CONFIGS[cf.stem] = model_cfg 46 | 47 | _MODEL_CONFIGS = { 48 | k: v 49 | for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) 50 | } 51 | 52 | 53 | _rescan_model_configs() # initial populate of model config registry 54 | 55 | 56 | def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): 57 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 58 | if isinstance(checkpoint, dict) and "state_dict" in checkpoint: 59 | state_dict = checkpoint["state_dict"] 60 | else: 61 | state_dict = checkpoint 62 | if skip_params: 63 | if next(iter(state_dict.items()))[0].startswith("module"): 64 | state_dict = {k[7:]: v for k, v in state_dict.items()} 65 | # for k in state_dict: 66 | # if k.startswith('transformer'): 67 | # v = state_dict.pop(k) 68 | # state_dict['text_branch.' 
+ k[12:]] = v 69 | return state_dict 70 | 71 | 72 | def create_model( 73 | amodel_name: str, 74 | tmodel_name: str, 75 | pretrained: str = "", 76 | precision: str = "fp32", 77 | device: torch.device = torch.device("cpu"), 78 | jit: bool = False, 79 | force_quick_gelu: bool = False, 80 | openai_model_cache_dir: str = os.path.expanduser(f"{CACHE_DIR}/clip"), 81 | skip_params=True, 82 | pretrained_audio: str = "", 83 | pretrained_text: str = "", 84 | enable_fusion: bool = False, 85 | fusion_type: str = "None" 86 | # pretrained_image: bool = False, 87 | ): 88 | amodel_name = amodel_name.replace( 89 | "/", "-" 90 | ) # for callers using old naming with / in ViT names 91 | pretrained_orig = pretrained 92 | pretrained = pretrained.lower() 93 | if pretrained == "openai": 94 | if amodel_name in _MODEL_CONFIGS: 95 | logging.info(f"Loading {amodel_name} model config.") 96 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 97 | else: 98 | logging.error( 99 | f"Model config for {amodel_name} not found; available models {list_models()}." 100 | ) 101 | raise RuntimeError(f"Model config for {amodel_name} not found.") 102 | 103 | logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") 104 | # Hard Code in model name 105 | model_cfg["text_cfg"]["model_type"] = tmodel_name 106 | model = load_openai_model( 107 | "ViT-B-16", 108 | model_cfg, 109 | device=device, 110 | jit=jit, 111 | cache_dir=openai_model_cache_dir, 112 | enable_fusion=enable_fusion, 113 | fusion_type=fusion_type, 114 | ) 115 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 116 | if precision == "amp" or precision == "fp32": 117 | model = model.float() 118 | else: 119 | if amodel_name in _MODEL_CONFIGS: 120 | logging.info(f"Loading {amodel_name} model config.") 121 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 122 | else: 123 | logging.error( 124 | f"Model config for {amodel_name} not found; available models {list_models()}." 125 | ) 126 | raise RuntimeError(f"Model config for {amodel_name} not found.") 127 | 128 | if force_quick_gelu: 129 | # override for use of QuickGELU on non-OpenAI transformer models 130 | model_cfg["quick_gelu"] = True 131 | 132 | # if pretrained_image: 133 | # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): 134 | # # pretrained weight loading for timm models set via vision_cfg 135 | # model_cfg['vision_cfg']['timm_model_pretrained'] = True 136 | # else: 137 | # assert False, 'pretrained image towers currently only supported for timm models' 138 | model_cfg["text_cfg"]["model_type"] = tmodel_name 139 | model_cfg["enable_fusion"] = enable_fusion 140 | model_cfg["fusion_type"] = fusion_type 141 | model = CLAP(**model_cfg) 142 | 143 | if pretrained: 144 | checkpoint_path = "" 145 | url = get_pretrained_url(amodel_name, pretrained) 146 | if url: 147 | checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) 148 | elif os.path.exists(pretrained_orig): 149 | checkpoint_path = pretrained_orig 150 | if checkpoint_path: 151 | logging.info( 152 | f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." 153 | ) 154 | ckpt = load_state_dict(checkpoint_path, skip_params=True) 155 | model.load_state_dict(ckpt) 156 | param_names = [n for n, p in model.named_parameters()] 157 | # for n in param_names: 158 | # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") 159 | else: 160 | logging.warning( 161 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 
162 | ) 163 | raise RuntimeError( 164 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 165 | ) 166 | 167 | if pretrained_audio: 168 | if amodel_name.startswith("PANN"): 169 | if "Cnn14_mAP" in pretrained_audio: # official checkpoint 170 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 171 | audio_ckpt = audio_ckpt["model"] 172 | keys = list(audio_ckpt.keys()) 173 | for key in keys: 174 | if ( 175 | "spectrogram_extractor" not in key 176 | and "logmel_extractor" not in key 177 | ): 178 | v = audio_ckpt.pop(key) 179 | audio_ckpt["audio_branch." + key] = v 180 | elif os.path.basename(pretrained_audio).startswith( 181 | "PANN" 182 | ): # checkpoint trained via HTSAT codebase 183 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 184 | audio_ckpt = audio_ckpt["state_dict"] 185 | keys = list(audio_ckpt.keys()) 186 | for key in keys: 187 | if key.startswith("sed_model"): 188 | v = audio_ckpt.pop(key) 189 | audio_ckpt["audio_branch." + key[10:]] = v 190 | elif os.path.basename(pretrained_audio).startswith( 191 | "finetuned" 192 | ): # checkpoint trained via linear probe codebase 193 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 194 | else: 195 | raise ValueError("Unknown audio checkpoint") 196 | elif amodel_name.startswith("HTSAT"): 197 | if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint 198 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 199 | audio_ckpt = audio_ckpt["state_dict"] 200 | keys = list(audio_ckpt.keys()) 201 | for key in keys: 202 | if key.startswith("sed_model") and ( 203 | "spectrogram_extractor" not in key 204 | and "logmel_extractor" not in key 205 | ): 206 | v = audio_ckpt.pop(key) 207 | audio_ckpt["audio_branch." + key[10:]] = v 208 | elif os.path.basename(pretrained_audio).startswith( 209 | "HTSAT" 210 | ): # checkpoint trained via HTSAT codebase 211 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 212 | audio_ckpt = audio_ckpt["state_dict"] 213 | keys = list(audio_ckpt.keys()) 214 | for key in keys: 215 | if key.startswith("sed_model"): 216 | v = audio_ckpt.pop(key) 217 | audio_ckpt["audio_branch." + key[10:]] = v 218 | elif os.path.basename(pretrained_audio).startswith( 219 | "finetuned" 220 | ): # checkpoint trained via linear probe codebase 221 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 222 | else: 223 | raise ValueError("Unknown audio checkpoint") 224 | else: 225 | raise f"this audio encoder pretrained checkpoint is not support" 226 | 227 | model.load_state_dict(audio_ckpt, strict=False) 228 | logging.info( 229 | f"Loading pretrained {amodel_name} weights ({pretrained_audio})." 
230 | ) 231 | param_names = [n for n, p in model.named_parameters()] 232 | for n in param_names: 233 | print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") 234 | 235 | model.to(device=device) 236 | if precision == "fp16": 237 | assert device.type != "cpu" 238 | convert_weights_to_fp16(model) 239 | 240 | if jit: 241 | model = torch.jit.script(model) 242 | 243 | return model, model_cfg 244 | 245 | 246 | def create_model_and_transforms( 247 | model_name: str, 248 | pretrained: str = "", 249 | precision: str = "fp32", 250 | device: torch.device = torch.device("cpu"), 251 | jit: bool = False, 252 | force_quick_gelu: bool = False, 253 | # pretrained_image: bool = False, 254 | ): 255 | model = create_model( 256 | model_name, 257 | pretrained, 258 | precision, 259 | device, 260 | jit, 261 | force_quick_gelu=force_quick_gelu, 262 | # pretrained_image=pretrained_image 263 | ) 264 | preprocess_train = image_transform(model.visual.image_size, is_train=True) 265 | preprocess_val = image_transform(model.visual.image_size, is_train=False) 266 | return model, preprocess_train, preprocess_val 267 | 268 | 269 | def list_models(): 270 | """enumerate available model architectures based on config files""" 271 | return list(_MODEL_CONFIGS.keys()) 272 | 273 | 274 | def add_model_config(path): 275 | """add model config path or file and update registry""" 276 | if not isinstance(path, Path): 277 | path = Path(path) 278 | _MODEL_CONFIG_PATHS.append(path) 279 | _rescan_model_configs() 280 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/feature_fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Feature Fusion for Varible-Length Data Processing 3 | AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py 4 | According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class DAF(nn.Module): 12 | """ 13 | 直接相加 DirectAddFuse 14 | """ 15 | 16 | def __init__(self): 17 | super(DAF, self).__init__() 18 | 19 | def forward(self, x, residual): 20 | return x + residual 21 | 22 | 23 | class iAFF(nn.Module): 24 | """ 25 | 多特征融合 iAFF 26 | """ 27 | 28 | def __init__(self, channels=64, r=4, type="2D"): 29 | super(iAFF, self).__init__() 30 | inter_channels = int(channels // r) 31 | 32 | if type == "1D": 33 | # 本地注意力 34 | self.local_att = nn.Sequential( 35 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 36 | nn.BatchNorm1d(inter_channels), 37 | nn.ReLU(inplace=True), 38 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 39 | nn.BatchNorm1d(channels), 40 | ) 41 | 42 | # 全局注意力 43 | self.global_att = nn.Sequential( 44 | nn.AdaptiveAvgPool1d(1), 45 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 46 | nn.BatchNorm1d(inter_channels), 47 | nn.ReLU(inplace=True), 48 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 49 | nn.BatchNorm1d(channels), 50 | ) 51 | 52 | # 第二次本地注意力 53 | self.local_att2 = nn.Sequential( 54 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 55 | nn.BatchNorm1d(inter_channels), 56 | nn.ReLU(inplace=True), 57 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 58 | nn.BatchNorm1d(channels), 59 | ) 60 | # 第二次全局注意力 61 | self.global_att2 = nn.Sequential( 
62 | nn.AdaptiveAvgPool1d(1), 63 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 64 | nn.BatchNorm1d(inter_channels), 65 | nn.ReLU(inplace=True), 66 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 67 | nn.BatchNorm1d(channels), 68 | ) 69 | elif type == "2D": 70 | # 本地注意力 71 | self.local_att = nn.Sequential( 72 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 73 | nn.BatchNorm2d(inter_channels), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 76 | nn.BatchNorm2d(channels), 77 | ) 78 | 79 | # 全局注意力 80 | self.global_att = nn.Sequential( 81 | nn.AdaptiveAvgPool2d(1), 82 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 83 | nn.BatchNorm2d(inter_channels), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 86 | nn.BatchNorm2d(channels), 87 | ) 88 | 89 | # 第二次本地注意力 90 | self.local_att2 = nn.Sequential( 91 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 92 | nn.BatchNorm2d(inter_channels), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 95 | nn.BatchNorm2d(channels), 96 | ) 97 | # 第二次全局注意力 98 | self.global_att2 = nn.Sequential( 99 | nn.AdaptiveAvgPool2d(1), 100 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 101 | nn.BatchNorm2d(inter_channels), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 104 | nn.BatchNorm2d(channels), 105 | ) 106 | else: 107 | raise f"the type is not supported" 108 | 109 | self.sigmoid = nn.Sigmoid() 110 | 111 | def forward(self, x, residual): 112 | flag = False 113 | xa = x + residual 114 | if xa.size(0) == 1: 115 | xa = torch.cat([xa, xa], dim=0) 116 | flag = True 117 | xl = self.local_att(xa) 118 | xg = self.global_att(xa) 119 | xlg = xl + xg 120 | wei = self.sigmoid(xlg) 121 | xi = x * wei + residual * (1 - wei) 122 | 123 | xl2 = self.local_att2(xi) 124 | xg2 = self.global_att(xi) 125 | xlg2 = xl2 + xg2 126 | wei2 = self.sigmoid(xlg2) 127 | xo = x * wei2 + residual * (1 - wei2) 128 | if flag: 129 | xo = xo[0].unsqueeze(0) 130 | return xo 131 | 132 | 133 | class AFF(nn.Module): 134 | """ 135 | 多特征融合 AFF 136 | """ 137 | 138 | def __init__(self, channels=64, r=4, type="2D"): 139 | super(AFF, self).__init__() 140 | inter_channels = int(channels // r) 141 | 142 | if type == "1D": 143 | self.local_att = nn.Sequential( 144 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 145 | nn.BatchNorm1d(inter_channels), 146 | nn.ReLU(inplace=True), 147 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 148 | nn.BatchNorm1d(channels), 149 | ) 150 | self.global_att = nn.Sequential( 151 | nn.AdaptiveAvgPool1d(1), 152 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 153 | nn.BatchNorm1d(inter_channels), 154 | nn.ReLU(inplace=True), 155 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 156 | nn.BatchNorm1d(channels), 157 | ) 158 | elif type == "2D": 159 | self.local_att = nn.Sequential( 160 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 161 | nn.BatchNorm2d(inter_channels), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 164 | nn.BatchNorm2d(channels), 165 | ) 166 | self.global_att = nn.Sequential( 167 | 
nn.AdaptiveAvgPool2d(1), 168 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 169 | nn.BatchNorm2d(inter_channels), 170 | nn.ReLU(inplace=True), 171 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 172 | nn.BatchNorm2d(channels), 173 | ) 174 | else: 175 | raise f"the type is not supported." 176 | 177 | self.sigmoid = nn.Sigmoid() 178 | 179 | def forward(self, x, residual): 180 | flag = False 181 | xa = x + residual 182 | if xa.size(0) == 1: 183 | xa = torch.cat([xa, xa], dim=0) 184 | flag = True 185 | xl = self.local_att(xa) 186 | xg = self.global_att(xa) 187 | xlg = xl + xg 188 | wei = self.sigmoid(xlg) 189 | xo = 2 * x * wei + 2 * residual * (1 - wei) 190 | if flag: 191 | xo = xo[0].unsqueeze(0) 192 | return xo 193 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/linear_probe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from .model import MLPLayers 5 | 6 | 7 | class LinearProbe(nn.Module): 8 | def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): 9 | """ 10 | Args: 11 | model: nn.Module 12 | mlp: bool, if True, then use the MLP layer as the linear probe module 13 | freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe 14 | in_ch: int, the output channel from CLAP model 15 | out_ch: int, the output channel from linear probe (class_num) 16 | act: torch.nn.functional, the activation function before the loss function 17 | """ 18 | super().__init__() 19 | in_ch = 512 20 | self.clap_model = model 21 | self.clap_model.text_branch = None # to save memory 22 | self.freeze = freeze 23 | if mlp: 24 | self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) 25 | else: 26 | self.lp_layer = nn.Linear(in_ch, out_ch) 27 | 28 | if self.freeze: 29 | for param in self.clap_model.parameters(): 30 | param.requires_grad = False 31 | 32 | if act == "None": 33 | self.act = None 34 | elif act == "relu": 35 | self.act = nn.ReLU() 36 | elif act == "elu": 37 | self.act = nn.ELU() 38 | elif act == "prelu": 39 | self.act = nn.PReLU(num_parameters=in_ch) 40 | elif act == "softmax": 41 | self.act = nn.Softmax(dim=-1) 42 | elif act == "sigmoid": 43 | self.act = nn.Sigmoid() 44 | 45 | def forward(self, x, mix_lambda=None, device=None): 46 | """ 47 | Args: 48 | x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list 49 | mix_lambda: torch.tensor [batch], the mixup lambda 50 | Returns: 51 | class_prob: torch.tensor [batch, class_num] 52 | 53 | """ 54 | # batchnorm cancel grandient 55 | if self.freeze: 56 | self.clap_model.eval() 57 | 58 | x = self.clap_model.audio_projection( 59 | self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[ 60 | "embedding" 61 | ] 62 | ) 63 | out = self.lp_layer(x) 64 | if self.act is not None: 65 | out = self.act(out) 66 | return out 67 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/HTSAT-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "base" 15 | }, 16 
| "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/HTSAT-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "large" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/HTSAT-tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/PANN-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn10" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 18000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } 
-------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 960000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 360, 10 | "fmin": 50, 11 | "fmax": 8000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 4 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/PANN-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/PANN-6.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn6" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/RN101-quickgelu.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | 
"layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /audioldm/clap/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import ( 14 | get_pretrained_url, 15 | list_pretrained_tag_models, 16 | download_pretrained, 17 | ) 18 | 19 | __all__ = ["list_openai_models", "load_openai_model"] 20 | 21 | CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache") 22 | 23 | 24 | 25 | def list_openai_models() -> List[str]: 26 | """Returns the names of available CLIP models""" 27 | return list_pretrained_tag_models("openai") 28 | 29 | 30 | def load_openai_model( 31 | name: str, 32 | model_cfg, 33 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 34 | jit=True, 35 | cache_dir=os.path.expanduser(f"{CACHE_DIR}/clip"), 36 | enable_fusion: bool = False, 37 | fusion_type: str = "None", 38 | ): 39 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model 40 | 41 | Parameters 42 | ---------- 43 | name : str 44 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 45 | device : Union[str, torch.device] 46 | The device to put the loaded model 47 | jit : bool 48 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
49 | 50 | Returns 51 | ------- 52 | model : torch.nn.Module 53 | The CLAP model 54 | preprocess : Callable[[PIL.Image], torch.Tensor] 55 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 56 | """ 57 | if get_pretrained_url(name, "openai"): 58 | model_path = download_pretrained( 59 | get_pretrained_url(name, "openai"), root=cache_dir 60 | ) 61 | elif os.path.isfile(name): 62 | model_path = name 63 | else: 64 | raise RuntimeError( 65 | f"Model {name} not found; available models = {list_openai_models()}" 66 | ) 67 | 68 | try: 69 | # loading JIT archive 70 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 71 | state_dict = None 72 | except RuntimeError: 73 | # loading saved state dict 74 | if jit: 75 | warnings.warn( 76 | f"File {model_path} is not a JIT archive. Loading as a state dict instead" 77 | ) 78 | jit = False 79 | state_dict = torch.load(model_path, map_location="cpu") 80 | 81 | if not jit: 82 | try: 83 | model = build_model_from_openai_state_dict( 84 | state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type 85 | ).to(device) 86 | except KeyError: 87 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 88 | model = build_model_from_openai_state_dict( 89 | sd, model_cfg, enable_fusion, fusion_type 90 | ).to(device) 91 | 92 | if str(device) == "cpu": 93 | model.float() 94 | return model 95 | 96 | # patch the device names 97 | device_holder = torch.jit.trace( 98 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 99 | ) 100 | device_node = [ 101 | n 102 | for n in device_holder.graph.findAllNodes("prim::Constant") 103 | if "Device" in repr(n) 104 | ][-1] 105 | 106 | def patch_device(module): 107 | try: 108 | graphs = [module.graph] if hasattr(module, "graph") else [] 109 | except RuntimeError: 110 | graphs = [] 111 | 112 | if hasattr(module, "forward1"): 113 | graphs.append(module.forward1.graph) 114 | 115 | for graph in graphs: 116 | for node in graph.findAllNodes("prim::Constant"): 117 | if "value" in node.attributeNames() and str(node["value"]).startswith( 118 | "cuda" 119 | ): 120 | node.copyAttributes(device_node) 121 | 122 | model.apply(patch_device) 123 | patch_device(model.encode_audio) 124 | patch_device(model.encode_text) 125 | 126 | # patch dtype to float32 on CPU 127 | if str(device) == "cpu": 128 | float_holder = torch.jit.trace( 129 | lambda: torch.ones([]).float(), example_inputs=[] 130 | ) 131 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 132 | float_node = float_input.node() 133 | 134 | def patch_float(module): 135 | try: 136 | graphs = [module.graph] if hasattr(module, "graph") else [] 137 | except RuntimeError: 138 | graphs = [] 139 | 140 | if hasattr(module, "forward1"): 141 | graphs.append(module.forward1.graph) 142 | 143 | for graph in graphs: 144 | for node in graph.findAllNodes("aten::to"): 145 | inputs = list(node.inputs()) 146 | for i in [ 147 | 1, 148 | 2, 149 | ]: # dtype can be the second or third argument to aten::to() 150 | if inputs[i].node()["value"] == 5: 151 | inputs[i].node().copyAttributes(float_node) 152 | 153 | model.apply(patch_float) 154 | patch_float(model.encode_audio) 155 | patch_float(model.encode_text) 156 | model.float() 157 | 158 | model.audio_branch.audio_length = model.audio_cfg.audio_length 159 | return model 160 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/pretrained.py: 
-------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache") 9 | 10 | _RN50 = dict( 11 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 12 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 13 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", 14 | ) 15 | 16 | _RN50_quickgelu = dict( 17 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 18 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 19 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", 20 | ) 21 | 22 | _RN101 = dict( 23 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 24 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", 25 | ) 26 | 27 | _RN101_quickgelu = dict( 28 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 29 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", 30 | ) 31 | 32 | _RN50x4 = dict( 33 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | ) 35 | 36 | _RN50x16 = dict( 37 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 38 | ) 39 | 40 | _RN50x64 = dict( 41 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 42 | ) 43 | 44 | _VITB32 = dict( 45 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 46 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 47 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 48 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 49 | ) 50 | 51 | _VITB32_quickgelu = dict( 52 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 53 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 54 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 55 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 56 | ) 57 | 58 | _VITB16 = dict( 59 | 
openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 60 | ) 61 | 62 | _VITL14 = dict( 63 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 64 | ) 65 | 66 | _PRETRAINED = { 67 | "RN50": _RN50, 68 | "RN50-quickgelu": _RN50_quickgelu, 69 | "RN101": _RN101, 70 | "RN101-quickgelu": _RN101_quickgelu, 71 | "RN50x4": _RN50x4, 72 | "RN50x16": _RN50x16, 73 | "ViT-B-32": _VITB32, 74 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 75 | "ViT-B-16": _VITB16, 76 | "ViT-L-14": _VITL14, 77 | } 78 | 79 | 80 | def list_pretrained(as_str: bool = False): 81 | """returns list of pretrained models 82 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 83 | """ 84 | return [ 85 | ":".join([k, t]) if as_str else (k, t) 86 | for k in _PRETRAINED.keys() 87 | for t in _PRETRAINED[k].keys() 88 | ] 89 | 90 | 91 | def list_pretrained_tag_models(tag: str): 92 | """return all models having the specified pretrain tag""" 93 | models = [] 94 | for k in _PRETRAINED.keys(): 95 | if tag in _PRETRAINED[k]: 96 | models.append(k) 97 | return models 98 | 99 | 100 | def list_pretrained_model_tags(model: str): 101 | """return all pretrain tags for the specified model architecture""" 102 | tags = [] 103 | if model in _PRETRAINED: 104 | tags.extend(_PRETRAINED[model].keys()) 105 | return tags 106 | 107 | 108 | def get_pretrained_url(model: str, tag: str): 109 | if model not in _PRETRAINED: 110 | return "" 111 | model_pretrained = _PRETRAINED[model] 112 | if tag not in model_pretrained: 113 | return "" 114 | return model_pretrained[tag] 115 | 116 | 117 | def download_pretrained(url: str, root: str = os.path.expanduser(f"{CACHE_DIR}/clip")): 118 | os.makedirs(root, exist_ok=True) 119 | filename = os.path.basename(url) 120 | 121 | if "openaipublic" in url: 122 | expected_sha256 = url.split("/")[-2] 123 | else: 124 | expected_sha256 = "" 125 | 126 | download_target = os.path.join(root, filename) 127 | 128 | if os.path.exists(download_target) and not os.path.isfile(download_target): 129 | raise RuntimeError(f"{download_target} exists and is not a regular file") 130 | 131 | if os.path.isfile(download_target): 132 | if expected_sha256: 133 | if ( 134 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 135 | == expected_sha256 136 | ): 137 | return download_target 138 | else: 139 | warnings.warn( 140 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 141 | ) 142 | else: 143 | return download_target 144 | 145 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 146 | with tqdm( 147 | total=int(source.info().get("Content-Length")), 148 | ncols=80, 149 | unit="iB", 150 | unit_scale=True, 151 | ) as loop: 152 | while True: 153 | buffer = source.read(8192) 154 | if not buffer: 155 | break 156 | 157 | output.write(buffer) 158 | loop.update(len(buffer)) 159 | 160 | if ( 161 | expected_sha256 162 | and hashlib.sha256(open(download_target, "rb").read()).hexdigest() 163 | != expected_sha256 164 | ): 165 | raise RuntimeError( 166 | f"Model has been downloaded but the SHA256 checksum does not not match" 167 | ) 168 | 169 | return download_target 170 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model 
adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import ( 14 | AttentionPool2d as AbsAttentionPool2d, 15 | ) 16 | except ImportError as e: 17 | timm = None 18 | 19 | from .utils import freeze_batch_norm_2d 20 | 21 | 22 | class TimmModel(nn.Module): 23 | """timm model adapter 24 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model_name, 30 | embed_dim, 31 | image_size=224, 32 | pool="avg", 33 | proj="linear", 34 | drop=0.0, 35 | pretrained=False, 36 | ): 37 | super().__init__() 38 | if timm is None: 39 | raise RuntimeError("Please `pip install timm` to use timm models.") 40 | 41 | self.image_size = to_2tuple(image_size) 42 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 43 | feat_size = self.trunk.default_cfg.get("pool_size", None) 44 | feature_ndim = 1 if not feat_size else 2 45 | if pool in ("abs_attn", "rot_attn"): 46 | assert feature_ndim == 2 47 | # if attn pooling used, remove both classifier and default pool 48 | self.trunk.reset_classifier(0, global_pool="") 49 | else: 50 | # reset global pool if pool config set, otherwise leave as network default 51 | reset_kwargs = dict(global_pool=pool) if pool else {} 52 | self.trunk.reset_classifier(0, **reset_kwargs) 53 | prev_chs = self.trunk.num_features 54 | 55 | head_layers = OrderedDict() 56 | if pool == "abs_attn": 57 | head_layers["pool"] = AbsAttentionPool2d( 58 | prev_chs, feat_size=feat_size, out_features=embed_dim 59 | ) 60 | prev_chs = embed_dim 61 | elif pool == "rot_attn": 62 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 63 | prev_chs = embed_dim 64 | else: 65 | assert proj, "projection layer needed if non-attention pooling is used." 
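        # Non-attention pooling keeps the trunk's feature width (prev_chs), so a
        # projection is still needed; `proj` selects between a Dropout+Linear head
        # and a two-layer Mlp, both mapping prev_chs -> embed_dim (see below).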
66 | 67 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 68 | if proj == "linear": 69 | head_layers["drop"] = nn.Dropout(drop) 70 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim) 71 | elif proj == "mlp": 72 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 73 | 74 | self.head = nn.Sequential(head_layers) 75 | 76 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 77 | """lock modules 78 | Args: 79 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 80 | """ 81 | if not unlocked_groups: 82 | # lock full model 83 | for param in self.trunk.parameters(): 84 | param.requires_grad = False 85 | if freeze_bn_stats: 86 | freeze_batch_norm_2d(self.trunk) 87 | else: 88 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 89 | try: 90 | # FIXME import here until API stable and in an official release 91 | from timm.models.helpers import group_parameters, group_modules 92 | except ImportError: 93 | raise RuntimeError( 94 | "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" 95 | ) 96 | matcher = self.trunk.group_matcher() 97 | gparams = group_parameters(self.trunk, matcher) 98 | max_layer_id = max(gparams.keys()) 99 | max_layer_id = max_layer_id - unlocked_groups 100 | for group_idx in range(max_layer_id + 1): 101 | group = gparams[group_idx] 102 | for param in group: 103 | self.trunk.get_parameter(param).requires_grad = False 104 | if freeze_bn_stats: 105 | gmodules = group_modules(self.trunk, matcher, reverse=True) 106 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 107 | freeze_batch_norm_2d(self.trunk, gmodules) 108 | 109 | def forward(self, x): 110 | x = self.trunk(x) 111 | x = self.head(x) 112 | return x 113 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join( 19 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 20 | ) 21 | 22 | 23 | @lru_cache() 24 | def bytes_to_unicode(): 25 | """ 26 | Returns list of utf-8 byte and a corresponding list of unicode strings. 27 | The reversible bpe codes work on unicode strings. 28 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 29 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 30 | This is a signficant percentage of your normal, say, 32K bpe vocab. 31 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 32 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
33 | """ 34 | bs = ( 35 | list(range(ord("!"), ord("~") + 1)) 36 | + list(range(ord("¡"), ord("¬") + 1)) 37 | + list(range(ord("®"), ord("ÿ") + 1)) 38 | ) 39 | cs = bs[:] 40 | n = 0 41 | for b in range(2**8): 42 | if b not in bs: 43 | bs.append(b) 44 | cs.append(2**8 + n) 45 | n += 1 46 | cs = [chr(n) for n in cs] 47 | return dict(zip(bs, cs)) 48 | 49 | 50 | def get_pairs(word): 51 | """Return set of symbol pairs in a word. 52 | Word is represented as tuple of symbols (symbols being variable-length strings). 53 | """ 54 | pairs = set() 55 | prev_char = word[0] 56 | for char in word[1:]: 57 | pairs.add((prev_char, char)) 58 | prev_char = char 59 | return pairs 60 | 61 | 62 | def basic_clean(text): 63 | text = ftfy.fix_text(text) 64 | text = html.unescape(html.unescape(text)) 65 | return text.strip() 66 | 67 | 68 | def whitespace_clean(text): 69 | text = re.sub(r"\s+", " ", text) 70 | text = text.strip() 71 | return text 72 | 73 | 74 | class SimpleTokenizer(object): 75 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 76 | self.byte_encoder = bytes_to_unicode() 77 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 78 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 79 | merges = merges[1 : 49152 - 256 - 2 + 1] 80 | merges = [tuple(merge.split()) for merge in merges] 81 | vocab = list(bytes_to_unicode().values()) 82 | vocab = vocab + [v + "" for v in vocab] 83 | for merge in merges: 84 | vocab.append("".join(merge)) 85 | if not special_tokens: 86 | special_tokens = ["", ""] 87 | else: 88 | special_tokens = ["", ""] + special_tokens 89 | vocab.extend(special_tokens) 90 | self.encoder = dict(zip(vocab, range(len(vocab)))) 91 | self.decoder = {v: k for k, v in self.encoder.items()} 92 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 93 | self.cache = {t: t for t in special_tokens} 94 | special = "|".join(special_tokens) 95 | self.pat = re.compile( 96 | special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 97 | re.IGNORECASE, 98 | ) 99 | 100 | self.vocab_size = len(self.encoder) 101 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 102 | 103 | def bpe(self, token): 104 | if token in self.cache: 105 | return self.cache[token] 106 | word = tuple(token[:-1]) + (token[-1] + "",) 107 | pairs = get_pairs(word) 108 | 109 | if not pairs: 110 | return token + "" 111 | 112 | while True: 113 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 114 | if bigram not in self.bpe_ranks: 115 | break 116 | first, second = bigram 117 | new_word = [] 118 | i = 0 119 | while i < len(word): 120 | try: 121 | j = word.index(first, i) 122 | new_word.extend(word[i:j]) 123 | i = j 124 | except: 125 | new_word.extend(word[i:]) 126 | break 127 | 128 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 129 | new_word.append(first + second) 130 | i += 2 131 | else: 132 | new_word.append(word[i]) 133 | i += 1 134 | new_word = tuple(new_word) 135 | word = new_word 136 | if len(word) == 1: 137 | break 138 | else: 139 | pairs = get_pairs(word) 140 | word = " ".join(word) 141 | self.cache[token] = word 142 | return word 143 | 144 | def encode(self, text): 145 | bpe_tokens = [] 146 | text = whitespace_clean(basic_clean(text)).lower() 147 | for token in re.findall(self.pat, text): 148 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 149 | bpe_tokens.extend( 150 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 151 | ) 152 | return bpe_tokens 153 
| 154 | def decode(self, tokens): 155 | text = "".join([self.decoder[token] for token in tokens]) 156 | text = ( 157 | bytearray([self.byte_decoder[c] for c in text]) 158 | .decode("utf-8", errors="replace") 159 | .replace("", " ") 160 | ) 161 | return text 162 | 163 | 164 | _tokenizer = SimpleTokenizer() 165 | 166 | 167 | def tokenize( 168 | texts: Union[str, List[str]], context_length: int = 77 169 | ) -> torch.LongTensor: 170 | """ 171 | Returns the tokenized representation of given input string(s) 172 | 173 | Parameters 174 | ---------- 175 | texts : Union[str, List[str]] 176 | An input string or a list of input strings to tokenize 177 | context_length : int 178 | The context length to use; all CLIP models use 77 as the context length 179 | 180 | Returns 181 | ------- 182 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 183 | """ 184 | if isinstance(texts, str): 185 | texts = [texts] 186 | 187 | sot_token = _tokenizer.encoder[""] 188 | eot_token = _tokenizer.encoder[""] 189 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 190 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 191 | 192 | for i, tokens in enumerate(all_tokens): 193 | if len(tokens) > context_length: 194 | tokens = tokens[:context_length] # Truncate 195 | result[i, : len(tokens)] = torch.tensor(tokens) 196 | 197 | return result 198 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import ( 2 | Normalize, 3 | Compose, 4 | RandomResizedCrop, 5 | InterpolationMode, 6 | ToTensor, 7 | Resize, 8 | CenterCrop, 9 | ) 10 | 11 | 12 | def _convert_to_rgb(image): 13 | return image.convert("RGB") 14 | 15 | 16 | def image_transform( 17 | image_size: int, 18 | is_train: bool, 19 | mean=(0.48145466, 0.4578275, 0.40821073), 20 | std=(0.26862954, 0.26130258, 0.27577711), 21 | ): 22 | normalize = Normalize(mean=mean, std=std) 23 | if is_train: 24 | return Compose( 25 | [ 26 | RandomResizedCrop( 27 | image_size, 28 | scale=(0.9, 1.0), 29 | interpolation=InterpolationMode.BICUBIC, 30 | ), 31 | _convert_to_rgb, 32 | ToTensor(), 33 | normalize, 34 | ] 35 | ) 36 | else: 37 | return Compose( 38 | [ 39 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 40 | CenterCrop(image_size), 41 | _convert_to_rgb, 42 | ToTensor(), 43 | normalize, 44 | ] 45 | ) 46 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | from torchvision.ops.misc import FrozenBatchNorm2d 5 | import logging 6 | 7 | # import h5py 8 | from tqdm import tqdm 9 | import random 10 | import json 11 | import os 12 | import pathlib 13 | 14 | # TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. 
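# dataset_split maps each supported dataset name to its available split names;
# exist() and get_tar_path_from_dataset_name() below consult this table when
# resolving which webdataset tar shards to load for a given dataset/split pair.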
15 | dataset_split = { 16 | "audiocaps": ["train", "valid", "test"], 17 | "audioset": ["balanced_train", "unbalanced_train", "eval"], 18 | "BBCSoundEffects": ["train", "test"], 19 | "Clotho": ["train", "test", "valid"], 20 | "free_to_use_sounds": ["train", "test"], 21 | "paramount_motion": ["train", "test"], 22 | "sonniss_game_effects": ["train", "test"], 23 | "wesoundeffects": ["train", "test"], 24 | "MACS": ["train", "test"], 25 | "freesound": ["train", "test"], 26 | "FSD50K": ["train", "test", "valid"], 27 | "fsd50k_class_label": ["train", "test", "valid"], 28 | "esc50": ["train", "test"], 29 | "audiostock": ["train", "test"], 30 | "freesound_no_overlap_noesc50": ["train", "test"], 31 | "epidemic_sound_effects": ["train", "test"], 32 | "VGGSound": ["train", "test"], 33 | "urbansound8k_class_label": ["train", "test"], 34 | "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], 35 | "epidemic_sound_effects_t5": ["train", "test"], 36 | "WavText5K": ["train", "test"], 37 | "esc50_no_overlap": ["train", "test"], 38 | "usd8k_no_overlap": ["train", "test"], 39 | "fsd50k_200_class_label": ["train", "test", "valid"], 40 | } 41 | 42 | 43 | def freeze_batch_norm_2d(module, module_match={}, name=""): 44 | """ 45 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 46 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 47 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 48 | 49 | Args: 50 | module (torch.nn.Module): Any PyTorch module. 51 | module_match (dict): Dictionary of full module names to freeze (all if empty) 52 | name (str): Full module name (prefix) 53 | 54 | Returns: 55 | torch.nn.Module: Resulting module 56 | 57 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 58 | """ 59 | res = module 60 | is_match = True 61 | if module_match: 62 | is_match = name in module_match 63 | if is_match and isinstance( 64 | module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) 65 | ): 66 | res = FrozenBatchNorm2d(module.num_features) 67 | res.num_features = module.num_features 68 | res.affine = module.affine 69 | if module.affine: 70 | res.weight.data = module.weight.data.clone().detach() 71 | res.bias.data = module.bias.data.clone().detach() 72 | res.running_mean.data = module.running_mean.data 73 | res.running_var.data = module.running_var.data 74 | res.eps = module.eps 75 | else: 76 | for child_name, child in module.named_children(): 77 | full_child_name = ".".join([name, child_name]) if name else child_name 78 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 79 | if new_child is not child: 80 | res.add_module(child_name, new_child) 81 | return res 82 | 83 | 84 | def exist(dataset_name, dataset_type): 85 | """ 86 | Check if dataset exists 87 | """ 88 | if dataset_type in dataset_split[dataset_name]: 89 | return True 90 | else: 91 | return False 92 | 93 | 94 | def get_tar_path_from_dataset_name( 95 | dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None 96 | ): 97 | """ 98 | Get tar path from dataset name and type 99 | """ 100 | output = [] 101 | for n in dataset_names: 102 | if full_dataset is not None and n in full_dataset: 103 | current_dataset_types = dataset_split[n] 104 | else: 105 | current_dataset_types = dataset_types 106 | for s in current_dataset_types: 107 | tmp = 
[] 108 | if islocal: 109 | sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" 110 | if not os.path.exists(sizefilepath_): 111 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" 112 | else: 113 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" 114 | if not os.path.exists(sizefilepath_): 115 | continue 116 | sizes = json.load(open(sizefilepath_, "r")) 117 | for k in sizes.keys(): 118 | if islocal: 119 | tmp.append(f"{dataset_path}/{n}/{s}/{k}") 120 | else: 121 | tmp.append( 122 | f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" 123 | ) 124 | if proportion != 1: 125 | tmp = random.sample(tmp, int(proportion * len(tmp))) 126 | output.append(tmp) 127 | return sum(output, []) 128 | 129 | 130 | def get_tar_path_from_txts(txt_path, islocal, proportion=1): 131 | """ 132 | Get tar path from txt path 133 | """ 134 | if isinstance(txt_path, (list, tuple)): 135 | return sum( 136 | [ 137 | get_tar_path_from_txts( 138 | txt_path[i], islocal=islocal, proportion=proportion 139 | ) 140 | for i in range(len(txt_path)) 141 | ], 142 | [], 143 | ) 144 | if isinstance(txt_path, str): 145 | with open(txt_path) as f: 146 | lines = f.readlines() 147 | if islocal: 148 | lines = [ 149 | lines[i] 150 | .split("\n")[0] 151 | .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") 152 | for i in range(len(lines)) 153 | ] 154 | else: 155 | lines = [ 156 | lines[i].split("\n")[0].replace(".tar", ".tar -") 157 | for i in range(len(lines)) 158 | ] 159 | if proportion != 1: 160 | print("Sampling tars with proportion of {}".format(proportion)) 161 | lines = random.sample(lines, int(proportion * len(lines))) 162 | return lines 163 | 164 | 165 | def get_mix_lambda(mixup_alpha, batch_size): 166 | mixup_lambdas = [ 167 | np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) 168 | ] 169 | return np.array(mixup_lambdas).astype(np.float32) 170 | 171 | 172 | def do_mixup(x, mixup_lambda): 173 | """ 174 | Args: 175 | x: (batch_size , ...) 176 | mixup_lambda: (batch_size,) 177 | Returns: 178 | out: (batch_size, ...) 179 | """ 180 | out = ( 181 | x.transpose(0, -1) * mixup_lambda 182 | + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) 183 | ).transpose(0, -1) 184 | return out 185 | 186 | 187 | def interpolate(x, ratio): 188 | """Interpolate data in time domain. This is used to compensate the 189 | resolution reduction in downsampling of a CNN. 190 | 191 | Args: 192 | x: (batch_size, time_steps, classes_num) 193 | ratio: int, ratio to interpolate 194 | Returns: 195 | upsampled: (batch_size, time_steps * ratio, classes_num) 196 | """ 197 | (batch_size, time_steps, classes_num) = x.shape 198 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 199 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 200 | return upsampled 201 | 202 | 203 | def pad_framewise_output(framewise_output, frames_num): 204 | """Pad framewise_output to the same length as input frames. The pad value 205 | is the same as the value of the last frame. 
206 | Args: 207 | framewise_output: (batch_size, frames_num, classes_num) 208 | frames_num: int, number of frames to pad 209 | Outputs: 210 | output: (batch_size, frames_num, classes_num) 211 | """ 212 | pad = framewise_output[:, -1:, :].repeat( 213 | 1, frames_num - framewise_output.shape[1], 1 214 | ) 215 | """tensor for padding""" 216 | 217 | output = torch.cat((framewise_output, pad), dim=1) 218 | """(batch_size, frames_num, classes_num)""" 219 | 220 | 221 | # def process_ipc(index_path, classes_num, filename): 222 | # # load data 223 | # logging.info("Load Data...............") 224 | # ipc = [[] for _ in range(classes_num)] 225 | # with h5py.File(index_path, "r") as f: 226 | # for i in tqdm(range(len(f["target"]))): 227 | # t_class = np.where(f["target"][i])[0] 228 | # for t in t_class: 229 | # ipc[t].append(i) 230 | # print(ipc) 231 | # np.save(filename, ipc) 232 | # logging.info("Load Data Succeed...............") 233 | 234 | 235 | def save_to_dict(s, o_={}): 236 | sp = s.split(": ") 237 | o_.update({sp[0]: float(sp[1])}) 238 | return o_ 239 | 240 | 241 | def get_data_from_log(txt_path): 242 | """ 243 | Output dictionary from out.txt log file 244 | """ 245 | with open(txt_path) as f: 246 | lines = f.readlines() 247 | val_data = {} 248 | train_data = {} 249 | train_losses = [] 250 | train_losses_epoch = [] 251 | for i in range(len(lines)): 252 | if "| INFO |" in lines[i]: 253 | if "Eval Epoch" in lines[i]: 254 | if "val_loss" in lines[i]: 255 | # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) 256 | line = lines[i].split("Eval Epoch: ")[-1] 257 | num_epoch = int(line.split(" ")[0].split(" ")[0]) 258 | d = { 259 | line.split(" ")[0] 260 | .split(" ")[1] 261 | .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) 262 | } 263 | for i in range(1, len(line.split(" "))): 264 | d = save_to_dict(line.split(" ")[i], d) 265 | val_data[num_epoch] = d 266 | elif "Train Epoch" in lines[i]: 267 | num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) 268 | loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) 269 | train_losses.append(loss) 270 | train_losses_epoch.append(num_epoch) 271 | for i in range(len(train_losses)): 272 | train_data[i] = { 273 | "num_epoch": train_losses_epoch[i], 274 | "train_loss": train_losses[i], 275 | } 276 | return train_data, val_data 277 | 278 | 279 | def save_p(obj, filename): 280 | import pickle 281 | 282 | try: 283 | from deepdiff import DeepDiff 284 | except: 285 | os.system("pip install deepdiff") 286 | from deepdiff import DeepDiff 287 | with open(filename, "wb") as file: 288 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol 289 | with open(filename, "rb") as file: 290 | z = pickle.load(file) 291 | assert ( 292 | DeepDiff(obj, z, ignore_string_case=True) == {} 293 | ), "there is something wrong with the saving process" 294 | return 295 | 296 | 297 | def load_p(filename): 298 | import pickle 299 | 300 | with open(filename, "rb") as file: 301 | z = pickle.load(file) 302 | return z 303 | 304 | 305 | def save_json(data, name="data.json"): 306 | import json 307 | 308 | with open(name, "w") as fp: 309 | json.dump(data, fp) 310 | return 311 | 312 | 313 | def load_json(name): 314 | import json 315 | 316 | with open(name, "r") as fp: 317 | data = json.load(fp) 318 | return data 319 | 320 | 321 | from multiprocessing import Process, Manager 322 | from multiprocessing import Process, Value, Array 323 | from ctypes import c_wchar 324 | 325 | 326 | def load_class_label(path): 327 | # 
https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing 328 | # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array 329 | out = None 330 | if path is not None: 331 | if pathlib.Path(path).suffix in [".pkl", ".pickle"]: 332 | out = load_p(path) 333 | elif pathlib.Path(path).suffix in [".json", ".txt"]: 334 | out = load_json(path) 335 | elif pathlib.Path(path).suffix in [".npy", ".npz"]: 336 | out = np.load(path) 337 | elif pathlib.Path(path).suffix in [".csv"]: 338 | import pandas as pd 339 | 340 | out = pd.read_csv(path) 341 | return out 342 | # if out is None: 343 | # return None 344 | # else: 345 | # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) 346 | # val = Array('i', out.values(), lock=False) 347 | # return (key, val) 348 | 349 | 350 | from torch import optim 351 | 352 | 353 | def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): 354 | if optimizer_name.lower() == "adamw": 355 | optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps) 356 | elif optimizer_name.lower() == "sgd": 357 | optimizer = optim.SGD(params, lr=lr, momentum=momentum) 358 | elif optimizer_name.lower() == "adam": 359 | optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps) 360 | else: 361 | raise ValueError("optimizer name is not correct") 362 | return optimizer 363 | -------------------------------------------------------------------------------- /audioldm/clap/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.1" 2 | -------------------------------------------------------------------------------- /audioldm/clap/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/clap/training/__init__.py -------------------------------------------------------------------------------- /audioldm/clap/training/audioset_textmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/clap/training/audioset_textmap.npy -------------------------------------------------------------------------------- /audioldm/clap/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import socket 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 
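# --- Editor's aside (illustrative sketch, not part of the original file) ---
# The detection below is purely environment-variable based: an Open MPI style
# launch (including horovodrun over MPI, per the NOTE above) exports
# OMPI_COMM_WORLD_* variables, while a SLURM/PMI launch exports PMI_RANK and
# PMI_SIZE. Minimal standalone sketch of the same check (the helper name
# `env_has_all` is hypothetical):
import os

def env_has_all(*names):
    # True only if every named variable is present in the environment.
    return all(name in os.environ for name in names)

launched_via_ompi = env_has_all("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE")
launched_via_pmi = env_has_all("PMI_RANK", "PMI_SIZE")
# --- end aside ---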
27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all( 30 | [var in os.environ for var in pmi_vars] 31 | ): 32 | return True 33 | else: 34 | return False 35 | 36 | 37 | def is_using_distributed(): 38 | if "WORLD_SIZE" in os.environ: 39 | return int(os.environ["WORLD_SIZE"]) > 1 40 | if "SLURM_NTASKS" in os.environ: 41 | return int(os.environ["SLURM_NTASKS"]) > 1 42 | return False 43 | 44 | 45 | def world_info_from_env(): 46 | local_rank = 0 47 | for v in ( 48 | "SLURM_LOCALID", 49 | "MPI_LOCALRANKID", 50 | "OMPI_COMM_WORLD_LOCAL_RANK", 51 | "LOCAL_RANK", 52 | ): 53 | if v in os.environ: 54 | local_rank = int(os.environ[v]) 55 | break 56 | global_rank = 0 57 | for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"): 58 | if v in os.environ: 59 | global_rank = int(os.environ[v]) 60 | break 61 | world_size = 1 62 | for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"): 63 | if v in os.environ: 64 | world_size = int(os.environ[v]) 65 | break 66 | 67 | return local_rank, global_rank, world_size 68 | 69 | 70 | def init_distributed_device(args): 71 | # Distributed training = training on more than one GPU. 72 | # Works in both single and multi-node scenarios. 73 | args.distributed = False 74 | args.world_size = 1 75 | args.rank = 0 # global rank 76 | args.local_rank = 0 77 | if args.horovod: 78 | assert hvd is not None, "Horovod is not installed" 79 | hvd.init() 80 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 81 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 82 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 83 | args.local_rank = local_rank 84 | args.rank = world_rank 85 | args.world_size = world_size 86 | # args.local_rank = int(hvd.local_rank()) 87 | # args.rank = hvd.rank() 88 | # args.world_size = hvd.size() 89 | args.distributed = True 90 | os.environ["LOCAL_RANK"] = str(args.local_rank) 91 | os.environ["RANK"] = str(args.rank) 92 | os.environ["WORLD_SIZE"] = str(args.world_size) 93 | print( 94 | f"Distributed training: local_rank={args.local_rank}, " 95 | f"rank={args.rank}, world_size={args.world_size}, " 96 | f"hostname={socket.gethostname()}, pid={os.getpid()}" 97 | ) 98 | elif is_using_distributed(): 99 | if "SLURM_PROCID" in os.environ: 100 | # DDP via SLURM 101 | args.local_rank, args.rank, args.world_size = world_info_from_env() 102 | # SLURM var -> torch.distributed vars in case needed 103 | os.environ["LOCAL_RANK"] = str(args.local_rank) 104 | os.environ["RANK"] = str(args.rank) 105 | os.environ["WORLD_SIZE"] = str(args.world_size) 106 | torch.distributed.init_process_group( 107 | backend=args.dist_backend, 108 | init_method=args.dist_url, 109 | world_size=args.world_size, 110 | rank=args.rank, 111 | ) 112 | elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster 113 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 114 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 115 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 116 | args.local_rank = local_rank 117 | args.rank = world_rank 118 | args.world_size = world_size 119 | torch.distributed.init_process_group( 120 | backend=args.dist_backend, 121 | init_method=args.dist_url, 122 | world_size=args.world_size, 123 | rank=args.rank, 124 | ) 125 | else: 126 | # DDP via torchrun, torch.distributed.launch 127 | args.local_rank, _, _ = world_info_from_env() 128 | torch.distributed.init_process_group( 129 | 
backend=args.dist_backend, init_method=args.dist_url 130 | ) 131 | args.world_size = torch.distributed.get_world_size() 132 | args.rank = torch.distributed.get_rank() 133 | args.distributed = True 134 | print( 135 | f"Distributed training: local_rank={args.local_rank}, " 136 | f"rank={args.rank}, world_size={args.world_size}, " 137 | f"hostname={socket.gethostname()}, pid={os.getpid()}" 138 | ) 139 | 140 | if torch.cuda.is_available(): 141 | if args.distributed and not args.no_set_device_rank: 142 | device = "cuda:%d" % args.local_rank 143 | else: 144 | device = "cuda:0" 145 | torch.cuda.set_device(device) 146 | else: 147 | device = "cpu" 148 | args.device = device 149 | device = torch.device(device) 150 | return device 151 | -------------------------------------------------------------------------------- /audioldm/clap/training/infer_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import os 4 | import torch 5 | import librosa 6 | from open_clip import create_model 7 | from training.data import get_audio_features 8 | from training.data import int16_to_float32, float32_to_int16 9 | from transformers import RobertaTokenizer 10 | 11 | tokenize = RobertaTokenizer.from_pretrained("roberta-base") 12 | 13 | 14 | def tokenizer(text): 15 | result = tokenize( 16 | text, 17 | padding="max_length", 18 | truncation=True, 19 | max_length=77, 20 | return_tensors="pt", 21 | ) 22 | return {k: v.squeeze(0) for k, v in result.items()} 23 | 24 | 25 | PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt" 26 | WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav" 27 | 28 | 29 | def infer_text(): 30 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 31 | precision = "fp32" 32 | amodel = "HTSAT-tiny" # or 'PANN-14' 33 | tmodel = "roberta" # the best text encoder in our training 34 | enable_fusion = False # False if you do not want to use the fusion model 35 | fusion_type = "aff_2d" 36 | pretrained = PRETRAINED_PATH 37 | 38 | model, model_cfg = create_model( 39 | amodel, 40 | tmodel, 41 | pretrained, 42 | precision=precision, 43 | device=device, 44 | enable_fusion=enable_fusion, 45 | fusion_type=fusion_type, 46 | ) 47 | # load the text, can be a list (i.e. 
batch size) 48 | text_data = ["I love the contrastive learning", "I love the pretrain model"] 49 | # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90 50 | text_data = tokenizer(text_data) 51 | 52 | text_embed = model.get_text_embedding(text_data) 53 | print(text_embed.size()) 54 | 55 | 56 | def infer_audio(): 57 | 58 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 59 | precision = "fp32" 60 | amodel = "HTSAT-tiny" # or 'PANN-14' 61 | tmodel = "roberta" # the best text encoder in our training 62 | enable_fusion = False # False if you do not want to use the fusion model 63 | fusion_type = "aff_2d" 64 | pretrained = PRETRAINED_PATH 65 | 66 | model, model_cfg = create_model( 67 | amodel, 68 | tmodel, 69 | pretrained, 70 | precision=precision, 71 | device=device, 72 | enable_fusion=enable_fusion, 73 | fusion_type=fusion_type, 74 | ) 75 | 76 | # load the waveform of the shape (T,), should resample to 48000 77 | audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000) 78 | # quantize 79 | audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) 80 | audio_waveform = torch.from_numpy(audio_waveform).float() 81 | audio_dict = {} 82 | 83 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 84 | import ipdb 85 | 86 | ipdb.set_trace() 87 | audio_dict = get_audio_features( 88 | audio_dict, 89 | audio_waveform, 90 | 480000, 91 | data_truncating="fusion", 92 | data_filling="repeatpad", 93 | audio_cfg=model_cfg["audio_cfg"], 94 | ) 95 | # can send a list to the model, to process many audio tracks in one time (i.e. batch size) 96 | audio_embed = model.get_audio_embedding([audio_dict]) 97 | print(audio_embed.size()) 98 | import ipdb 99 | 100 | ipdb.set_trace() 101 | 102 | 103 | if __name__ == "__main__": 104 | infer_text() 105 | infer_audio() 106 | -------------------------------------------------------------------------------- /audioldm/clap/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | 8 | hostname = socket.gethostname() 9 | formatter = logging.Formatter( 10 | f"%(asctime)s | {hostname} | %(levelname)s | %(message)s", 11 | datefmt="%Y-%m-%d,%H:%M:%S", 12 | ) 13 | else: 14 | formatter = logging.Formatter( 15 | "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S" 16 | ) 17 | 18 | logging.root.setLevel(level) 19 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 20 | for logger in loggers: 21 | logger.setLevel(level) 22 | 23 | stream_handler = logging.StreamHandler() 24 | stream_handler.setFormatter(formatter) 25 | logging.root.addHandler(stream_handler) 26 | 27 | if log_file: 28 | file_handler = logging.FileHandler(filename=log_file) 29 | file_handler.setFormatter(formatter) 30 | logging.root.addHandler(file_handler) 31 | -------------------------------------------------------------------------------- /audioldm/clap/training/lp_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import time 6 | from contextlib import suppress 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | try: 13 | import wandb 14 | except ImportError: 15 | wandb = None 16 | 17 | from open_clip import LPLoss, LPMetrics, lp_gather_features 18 | from 
open_clip.utils import do_mixup, get_mix_lambda 19 | from .distributed import is_master 20 | from .zero_shot import zero_shot_eval 21 | 22 | 23 | class AverageMeter(object): 24 | """Computes and stores the average and current value""" 25 | 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | 42 | def unwrap_model(model): 43 | if hasattr(model, "module"): 44 | return model.module 45 | else: 46 | return model 47 | 48 | 49 | def train_one_epoch( 50 | model, 51 | data, 52 | epoch, 53 | optimizer, 54 | scaler, 55 | scheduler, 56 | args, 57 | tb_writer=None, 58 | extra_suffix="", 59 | ): 60 | device = torch.device(args.device) 61 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress 62 | model.train() 63 | loss = LPLoss(args.lp_loss) 64 | 65 | dataloader, sampler = data["train"].dataloader, data["train"].sampler 66 | if args.distributed and sampler is not None: 67 | sampler.set_epoch(epoch) 68 | num_batches_per_epoch = dataloader.num_batches 69 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) 70 | 71 | # for toy dataset 72 | if args.dataset_type == "toy": 73 | dataloader.dataset.generate_queue() 74 | 75 | loss_m = AverageMeter() 76 | batch_time_m = AverageMeter() 77 | data_time_m = AverageMeter() 78 | end = time.time() 79 | 80 | for i, batch in enumerate(dataloader): 81 | step = num_batches_per_epoch * epoch + i 82 | 83 | if isinstance(scheduler, dict): 84 | for s in scheduler.values(): 85 | s(step) 86 | else: 87 | scheduler(step) 88 | 89 | audio = batch # contains mel_spec, wavform, and longer list 90 | class_label = batch["class_label"] 91 | # audio = audio.to(device=device, non_blocking=True) 92 | class_label = class_label.to(device=device, non_blocking=True) 93 | 94 | if args.mixup: 95 | # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146 96 | mix_lambda = torch.from_numpy( 97 | get_mix_lambda(0.5, len(audio["waveform"])) 98 | ).to(device) 99 | class_label = do_mixup(class_label, mix_lambda) 100 | else: 101 | mix_lambda = None 102 | 103 | data_time_m.update(time.time() - end) 104 | if isinstance(optimizer, dict): 105 | for o_ in optimizer.values(): 106 | o_.zero_grad() 107 | else: 108 | optimizer.zero_grad() 109 | 110 | with autocast(): 111 | pred = model(audio, mix_lambda=mix_lambda, device=device) 112 | total_loss = loss(pred, class_label) 113 | 114 | if isinstance(optimizer, dict): 115 | if scaler is not None: 116 | scaler.scale(total_loss).backward() 117 | for o_ in optimizer.values(): 118 | if args.horovod: 119 | o_.synchronize() 120 | scaler.unscale_(o_) 121 | with o_.skip_synchronize(): 122 | scaler.step(o_) 123 | else: 124 | scaler.step(o_) 125 | scaler.update() 126 | else: 127 | total_loss.backward() 128 | for o_ in optimizer.values(): 129 | o_.step() 130 | else: 131 | if scaler is not None: 132 | scaler.scale(total_loss).backward() 133 | if args.horovod: 134 | optimizer.synchronize() 135 | scaler.unscale_(optimizer) 136 | with optimizer.skip_synchronize(): 137 | scaler.step(optimizer) 138 | else: 139 | scaler.step(optimizer) 140 | scaler.update() 141 | else: 142 | total_loss.backward() 143 | optimizer.step() 144 | 145 | # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
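# --- Editor's aside (illustrative sketch, not part of the original file) ---
# The clamp below bounds the learned scales in log space: in CLIP-style
# contrastive models the similarity logits are multiplied by exp(logit_scale),
# so capping the stored log value at ln(100) caps that multiplier at 100, as
# the note above says. A minimal standalone sketch of the same bound:
import math
import torch

logit_scale = torch.tensor(6.0)          # hypothetical value that drifted too high
logit_scale.clamp_(0.0, math.log(100))   # 4.6052 = ln(100)
assert torch.isclose(logit_scale.exp(), torch.tensor(100.0))
# --- end aside ---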
146 | with torch.no_grad(): 147 | unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100)) 148 | unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100)) 149 | 150 | batch_time_m.update(time.time() - end) 151 | end = time.time() 152 | batch_count = i + 1 153 | 154 | if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): 155 | if isinstance(audio, dict): 156 | batch_size = len(audio["waveform"]) 157 | else: 158 | batch_size = len(audio) 159 | num_samples = batch_count * batch_size * args.world_size 160 | samples_per_epoch = dataloader.num_samples 161 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 162 | 163 | # NOTE loss is coarsely sampled, just master node and per log update 164 | loss_m.update(total_loss.item(), batch_size) 165 | if isinstance(optimizer, dict): 166 | logging.info( 167 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 168 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " 169 | f"Data (t): {data_time_m.avg:.3f} " 170 | f"Batch (t): {batch_time_m.avg:.3f} " 171 | f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}" 172 | ) 173 | log_data = { 174 | "loss": loss_m.val, 175 | "data_time": data_time_m.val, 176 | "batch_time": batch_time_m.val, 177 | "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], 178 | } 179 | else: 180 | logging.info( 181 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 182 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " 183 | f"Data (t): {data_time_m.avg:.3f} " 184 | f"Batch (t): {batch_time_m.avg:.3f} " 185 | f"LR: {optimizer.param_groups[0]['lr']:5f} " 186 | ) 187 | 188 | # Save train loss / etc. Using non avg meter values as loggers have their own smoothing 189 | log_data = { 190 | "loss": loss_m.val, 191 | "data_time": data_time_m.val, 192 | "batch_time": batch_time_m.val, 193 | "lr": optimizer.param_groups[0]["lr"], 194 | } 195 | for name, val in log_data.items(): 196 | name = f"train{extra_suffix}/{name}" 197 | if tb_writer is not None: 198 | tb_writer.add_scalar(name, val, step) 199 | if args.wandb: 200 | assert wandb is not None, "Please install wandb." 
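# --- Editor's aside (illustrative sketch, not part of the original file) ---
# wandb is treated as an optional dependency in this module: the import at the
# top is wrapped in `try: import wandb / except ImportError: wandb = None`, so
# the assert above is the only runtime guard before logging. A minimal
# standalone sketch of that pattern (the helper name `log_scalar` is
# hypothetical):
try:
    import wandb  # optional experiment tracker
except ImportError:
    wandb = None

def log_scalar(name, value, step, use_wandb=False):
    # Only touch wandb when logging was requested and the package is installed.
    if use_wandb:
        assert wandb is not None, "Please install wandb."
        wandb.log({name: value, "step": step})
# --- end aside ---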
201 | wandb.log({name: val, "step": step}) 202 | 203 | # resetting batch / data time meters per log window 204 | batch_time_m.reset() 205 | data_time_m.reset() 206 | # end for 207 | 208 | 209 | def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): 210 | metrics = {} 211 | if not args.parallel_eval: 212 | if not is_master(args): 213 | return metrics 214 | device = torch.device(args.device) 215 | model.eval() 216 | 217 | # CHANGE 218 | # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) 219 | # metrics.update(zero_shot_metrics) 220 | if is_master(args): 221 | print("Evaluating...") 222 | metric_names = args.lp_metrics.split(",") 223 | eval_tool = LPMetrics(metric_names=metric_names) 224 | 225 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress 226 | if "val" in data and ( 227 | args.val_frequency 228 | and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) 229 | ): 230 | if args.parallel_eval: 231 | dataloader, sampler = data["val"].dataloader, data["val"].sampler 232 | if args.distributed and sampler is not None: 233 | sampler.set_epoch(epoch) 234 | samples_per_val = dataloader.num_samples 235 | else: 236 | dataloader = data["val"].dataloader 237 | num_samples = 0 238 | samples_per_val = dataloader.num_samples 239 | 240 | eval_info = {"pred": [], "target": []} 241 | with torch.no_grad(): 242 | for i, batch in enumerate(dataloader): 243 | audio = batch # contains mel_spec, wavform, and longer list 244 | class_label = batch["class_label"] 245 | 246 | # audio = audio.to(device=device, non_blocking=True) 247 | class_label = class_label.to(device=device, non_blocking=True) 248 | 249 | with autocast(): 250 | pred = model(audio, device=device) 251 | if args.parallel_eval: 252 | pred, class_label = lp_gather_features( 253 | pred, class_label, args.world_size, args.horovod 254 | ) 255 | eval_info["pred"].append(pred) 256 | eval_info["target"].append(class_label) 257 | 258 | num_samples += class_label.shape[0] 259 | 260 | if (i % 100) == 0: # and i != 0: 261 | logging.info( 262 | f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" 263 | ) 264 | 265 | if is_master(args): 266 | eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu() 267 | eval_info["target"] = torch.cat(eval_info["target"], 0).cpu() 268 | metric_dict = eval_tool.evaluate_mertics( 269 | eval_info["pred"], eval_info["target"] 270 | ) 271 | metrics.update(metric_dict) 272 | if "epoch" not in metrics.keys(): 273 | metrics.update({"epoch": epoch}) 274 | 275 | if is_master(args): 276 | if not metrics: 277 | return metrics 278 | 279 | logging.info( 280 | f"Eval Epoch: {epoch} " 281 | + "\n".join( 282 | ["\t".join([f"{m}: {round(metrics[m], 4):.4f}"]) for m in metrics] 283 | ) 284 | ) 285 | if args.save_logs: 286 | for name, val in metrics.items(): 287 | if tb_writer is not None: 288 | tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch) 289 | 290 | with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: 291 | f.write(json.dumps(metrics)) 292 | f.write("\n") 293 | 294 | if args.wandb: 295 | assert wandb is not None, "Please install wandb." 
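# --- Editor's aside (illustrative sketch, not part of the original file) ---
# A few lines above, evaluation metrics are appended to results.jsonl in JSON
# Lines form: one JSON object per evaluated epoch, one object per line. A
# sketch of reading such a file back (the helper name `read_jsonl` is
# hypothetical):
import json

def read_jsonl(path):
    # Returns a list of dicts, one per non-empty line of the .jsonl file.
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
# --- end aside ---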
296 | for name, val in metrics.items(): 297 | wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch}) 298 | 299 | return metrics 300 | else: 301 | return metrics 302 | -------------------------------------------------------------------------------- /audioldm/clap/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | e = step - warmup_length 19 | es = steps - warmup_length 20 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 21 | assign_learning_rate(optimizer, lr) 22 | return lr 23 | 24 | return _lr_adjuster 25 | -------------------------------------------------------------------------------- /audioldm/clap/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | # NOTE: This script is currently not supported for CLAP. 2 | import logging 3 | from contextlib import suppress 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | from open_clip import tokenize 10 | from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template 11 | 12 | 13 | def zero_shot_classifier(model, classnames, templates, args): 14 | with torch.no_grad(): 15 | zeroshot_weights = [] 16 | for classname in tqdm(classnames): 17 | texts = [template(classname) for template in templates] # format with class 18 | texts = tokenize(texts).to(args.device) # tokenize 19 | if args.distributed and not args.horovod: 20 | class_embeddings = model.module.encode_text(texts) 21 | else: 22 | class_embeddings = model.encode_text(texts) 23 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 24 | class_embedding /= class_embedding.norm() 25 | zeroshot_weights.append(class_embedding) 26 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) 27 | return zeroshot_weights 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | pred = output.topk(max(topk), 1, True, True)[1].t() 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | return [ 34 | float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 35 | for k in topk 36 | ] 37 | 38 | 39 | def run(model, classifier, dataloader, args): 40 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress 41 | with torch.no_grad(): 42 | top1, top5, n = 0.0, 0.0, 0.0 43 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 44 | images = images.to(args.device) 45 | target = target.to(args.device) 46 | 47 | with autocast(): 48 | # predict 49 | if args.distributed and not args.horovod: 50 | image_features = model.module.encode_image(images) 51 | else: 52 | image_features = model.encode_image(images) 53 | image_features = F.normalize(image_features, dim=-1) 54 | logits = 100.0 * image_features @ classifier 55 | 56 | # measure accuracy 57 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 58 | top1 += acc1 59 | top5 += acc5 60 | n += images.size(0) 61 | 62 | top1 = top1 / n 63 | top5 = top5 / n 64 | return top1, top5 65 | 66 | 67 | def zero_shot_eval(model, data, epoch, args): 68 | if 
"imagenet-val" not in data and "imagenet-v2" not in data: 69 | return {} 70 | if args.zeroshot_frequency == 0: 71 | return {} 72 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 73 | return {} 74 | 75 | logging.info("Starting zero-shot imagenet.") 76 | 77 | logging.info("Building zero-shot classifier") 78 | classifier = zero_shot_classifier( 79 | model, imagenet_classnames, openai_imagenet_template, args 80 | ) 81 | 82 | logging.info("Using classifier") 83 | results = {} 84 | if "imagenet-val" in data: 85 | top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args) 86 | results["imagenet-zeroshot-val-top1"] = top1 87 | results["imagenet-zeroshot-val-top5"] = top5 88 | if "imagenet-v2" in data: 89 | top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args) 90 | results["imagenetv2-zeroshot-val-top1"] = top1 91 | results["imagenetv2-zeroshot-val-top5"] = top5 92 | 93 | logging.info("Finished zero-shot imagenet.") 94 | 95 | return results 96 | -------------------------------------------------------------------------------- /audioldm/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Generator 2 | 3 | 4 | class AttrDict(dict): 5 | def __init__(self, *args, **kwargs): 6 | super(AttrDict, self).__init__(*args, **kwargs) 7 | self.__dict__ = self 8 | -------------------------------------------------------------------------------- /audioldm/hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 
98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2**i), 129 | h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | # print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) 175 | -------------------------------------------------------------------------------- /audioldm/hifigan/utilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import audioldm.hifigan as hifigan 8 | 9 | HIFIGAN_16K_64 = { 10 | "resblock": "1", 11 | "num_gpus": 6, 12 | "batch_size": 16, 13 | "learning_rate": 0.0002, 14 | "adam_b1": 0.8, 15 | "adam_b2": 0.99, 16 | "lr_decay": 0.999, 17 | "seed": 1234, 18 | "upsample_rates": [5, 4, 2, 2, 2], 19 | "upsample_kernel_sizes": [16, 16, 8, 4, 4], 20 | "upsample_initial_channel": 1024, 21 | "resblock_kernel_sizes": [3, 7, 11], 22 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 23 | "segment_size": 8192, 24 | "num_mels": 64, 25 | "num_freq": 1025, 26 | "n_fft": 1024, 27 | "hop_size": 160, 28 | "win_size": 1024, 29 | "sampling_rate": 16000, 30 | "fmin": 0, 31 | "fmax": 8000, 32 | "fmax_for_loss": None, 33 | "num_workers": 4, 34 | "dist_config": { 35 | "dist_backend": "nccl", 36 | "dist_url": "tcp://localhost:54321", 37 | "world_size": 1, 38 | }, 39 | } 40 | 41 | 42 | def get_available_checkpoint_keys(model, ckpt): 43 | print("==> Attemp to reload from %s" % ckpt) 44 | state_dict = 
torch.load(ckpt)["state_dict"] 45 | current_state_dict = model.state_dict() 46 | new_state_dict = {} 47 | for k in state_dict.keys(): 48 | if ( 49 | k in current_state_dict.keys() 50 | and current_state_dict[k].size() == state_dict[k].size() 51 | ): 52 | new_state_dict[k] = state_dict[k] 53 | else: 54 | print("==> WARNING: Skipping %s" % k) 55 | print( 56 | "%s out of %s keys are matched" 57 | % (len(new_state_dict.keys()), len(state_dict.keys())) 58 | ) 59 | return new_state_dict 60 | 61 | 62 | def get_param_num(model): 63 | num_param = sum(param.numel() for param in model.parameters()) 64 | return num_param 65 | 66 | 67 | def get_vocoder(config, device): 68 | config = hifigan.AttrDict(HIFIGAN_16K_64) 69 | vocoder = hifigan.Generator(config) 70 | vocoder.eval() 71 | vocoder.remove_weight_norm() 72 | vocoder.to(device) 73 | return vocoder 74 | 75 | 76 | def vocoder_infer(mels, vocoder, lengths=None): 77 | with torch.no_grad(): 78 | wavs = vocoder(mels).squeeze(1) 79 | 80 | wavs = (wavs.cpu().numpy() * 32768).astype("int16") 81 | 82 | if lengths is not None: 83 | wavs = wavs[:, :lengths] 84 | 85 | return wavs 86 | -------------------------------------------------------------------------------- /audioldm/latent_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/latent_diffusion/__init__.py -------------------------------------------------------------------------------- /audioldm/latent_diffusion/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_( 47 | one_minus_decay * (shadow_params[sname] - m_param[key]) 48 | ) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 
62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /audioldm/latent_diffusion/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from audioldm.utils import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule( 22 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 23 | ): 24 | if schedule == "linear": 25 | betas = ( 26 | torch.linspace( 27 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 28 | ) 29 | ** 2 30 | ) 31 | 32 | elif schedule == "cosine": 33 | timesteps = ( 34 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 35 | ) 36 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 37 | alphas = torch.cos(alphas).pow(2) 38 | alphas = alphas / alphas[0] 39 | betas = 1 - alphas[1:] / alphas[:-1] 40 | betas = np.clip(betas, a_min=0, a_max=0.999) 41 | 42 | elif schedule == "sqrt_linear": 43 | betas = torch.linspace( 44 | linear_start, linear_end, n_timestep, dtype=torch.float64 45 | ) 46 | elif schedule == "sqrt": 47 | betas = ( 48 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 49 | ** 0.5 50 | ) 51 | else: 52 | raise ValueError(f"schedule '{schedule}' unknown.") 53 | return betas.numpy() 54 | 55 | 56 | def make_ddim_timesteps( 57 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True 58 | ): 59 | if ddim_discr_method == "uniform": 60 | c = num_ddpm_timesteps // num_ddim_timesteps 61 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 62 | elif ddim_discr_method == "quad": 63 | ddim_timesteps = ( 64 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 65 | ).astype(int) 66 | else: 67 | raise NotImplementedError( 68 | f'There is no ddim discretization method called "{ddim_discr_method}"' 69 | ) 70 | 71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 72 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 73 | steps_out = ddim_timesteps + 1 74 | if verbose: 75 | print(f"Selected timesteps 
for ddim sampler: {steps_out}") 76 | return steps_out 77 | 78 | 79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 80 | # select alphas for computing the variance schedule 81 | alphas = alphacums[ddim_timesteps] 82 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 83 | 84 | # according the the formula provided in https://arxiv.org/abs/2010.02502 85 | sigmas = eta * np.sqrt( 86 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) 87 | ) 88 | if verbose: 89 | print( 90 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" 91 | ) 92 | print( 93 | f"For the chosen value of eta, which is {eta}, " 94 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}" 95 | ) 96 | return sigmas, alphas, alphas_prev 97 | 98 | 99 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 100 | """ 101 | Create a beta schedule that discretizes the given alpha_t_bar function, 102 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 103 | :param num_diffusion_timesteps: the number of betas to produce. 104 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 105 | produces the cumulative product of (1-beta) up to that 106 | part of the diffusion process. 107 | :param max_beta: the maximum beta to use; use values lower than 1 to 108 | prevent singularities. 109 | """ 110 | betas = [] 111 | for i in range(num_diffusion_timesteps): 112 | t1 = i / num_diffusion_timesteps 113 | t2 = (i + 1) / num_diffusion_timesteps 114 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 115 | return np.array(betas) 116 | 117 | 118 | def extract_into_tensor(a, t, x_shape): 119 | b, *_ = t.shape 120 | out = a.gather(-1, t).contiguous() 121 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | :param func: the function to evaluate. 129 | :param inputs: the argument sequence to pass to `func`. 130 | :param params: a sequence of parameters `func` depends on but does not 131 | explicitly take as arguments. 132 | :param flag: if False, disable gradient checkpointing. 133 | """ 134 | if flag: 135 | args = tuple(inputs) + tuple(params) 136 | return CheckpointFunction.apply(func, len(inputs), *args) 137 | else: 138 | return func(*inputs) 139 | 140 | 141 | class CheckpointFunction(torch.autograd.Function): 142 | @staticmethod 143 | def forward(ctx, run_function, length, *args): 144 | ctx.run_function = run_function 145 | ctx.input_tensors = list(args[:length]) 146 | ctx.input_params = list(args[length:]) 147 | 148 | with torch.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with torch.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 
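# --- Editor's aside (illustrative sketch, not part of the original file) ---
# x.view_as(x) returns a new tensor object that shares the same underlying
# storage as x, so run_function below receives fresh tensors instead of the
# detach()'d leaves created above; that is the workaround described in the
# comment just before this aside. Minimal standalone check of the
# storage-sharing behaviour:
import torch

t = torch.zeros(3, requires_grad=True)
v = t.view_as(t)
print(v is t)                          # False: a distinct tensor object
print(v.data_ptr() == t.data_ptr())    # True: same underlying storage
# --- end aside ---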
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = torch.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | 172 | 173 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 174 | """ 175 | Create sinusoidal timestep embeddings. 176 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 177 | These may be fractional. 178 | :param dim: the dimension of the output. 179 | :param max_period: controls the minimum frequency of the embeddings. 180 | :return: an [N x dim] Tensor of positional embeddings. 181 | """ 182 | if not repeat_only: 183 | half = dim // 2 184 | freqs = torch.exp( 185 | -math.log(max_period) 186 | * torch.arange(start=0, end=half, dtype=torch.float32) 187 | / half 188 | ).to(device=timesteps.device) 189 | args = timesteps[:, None].float() * freqs[None] 190 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 191 | if dim % 2: 192 | embedding = torch.cat( 193 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 194 | ) 195 | else: 196 | embedding = repeat(timesteps, "b -> b d", d=dim) 197 | return embedding 198 | 199 | 200 | def zero_module(module): 201 | """ 202 | Zero out the parameters of a module and return it. 203 | """ 204 | for p in module.parameters(): 205 | p.detach().zero_() 206 | return module 207 | 208 | 209 | def scale_module(module, scale): 210 | """ 211 | Scale the parameters of a module and return it. 212 | """ 213 | for p in module.parameters(): 214 | p.detach().mul_(scale) 215 | return module 216 | 217 | 218 | def mean_flat(tensor): 219 | """ 220 | Take the mean over all non-batch dimensions. 221 | """ 222 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 223 | 224 | 225 | def normalization(channels): 226 | """ 227 | Make a standard normalization layer. 228 | :param channels: number of input channels. 229 | :return: an nn.Module for normalization. 230 | """ 231 | return GroupNorm32(32, channels) 232 | 233 | 234 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 235 | class SiLU(nn.Module): 236 | def forward(self, x): 237 | return x * torch.sigmoid(x) 238 | 239 | 240 | class GroupNorm32(nn.GroupNorm): 241 | def forward(self, x): 242 | return super().forward(x.float()).type(x.dtype) 243 | 244 | 245 | def conv_nd(dims, *args, **kwargs): 246 | """ 247 | Create a 1D, 2D, or 3D convolution module. 248 | """ 249 | if dims == 1: 250 | return nn.Conv1d(*args, **kwargs) 251 | elif dims == 2: 252 | return nn.Conv2d(*args, **kwargs) 253 | elif dims == 3: 254 | return nn.Conv3d(*args, **kwargs) 255 | raise ValueError(f"unsupported dimensions: {dims}") 256 | 257 | 258 | def linear(*args, **kwargs): 259 | """ 260 | Create a linear module. 261 | """ 262 | return nn.Linear(*args, **kwargs) 263 | 264 | 265 | def avg_pool_nd(dims, *args, **kwargs): 266 | """ 267 | Create a 1D, 2D, or 3D average pooling module. 
268 | """ 269 | if dims == 1: 270 | return nn.AvgPool1d(*args, **kwargs) 271 | elif dims == 2: 272 | return nn.AvgPool2d(*args, **kwargs) 273 | elif dims == 3: 274 | return nn.AvgPool3d(*args, **kwargs) 275 | raise ValueError(f"unsupported dimensions: {dims}") 276 | 277 | 278 | class HybridConditioner(nn.Module): 279 | def __init__(self, c_concat_config, c_crossattn_config): 280 | super().__init__() 281 | self.concat_conditioner = instantiate_from_config(c_concat_config) 282 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 283 | 284 | def forward(self, c_concat, c_crossattn): 285 | c_concat = self.concat_conditioner(c_concat) 286 | c_crossattn = self.crossattn_conditioner(c_crossattn) 287 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} 288 | 289 | 290 | def noise_like(shape, device, repeat=False): 291 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( 292 | shape[0], *((1,) * (len(shape) - 1)) 293 | ) 294 | noise = lambda: torch.randn(shape, device=device) 295 | return repeat_noise() if repeat else noise() 296 | -------------------------------------------------------------------------------- /audioldm/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import argparse 4 | import yaml 5 | import torch 6 | from torch import autocast 7 | from tqdm import tqdm, trange 8 | 9 | from audioldm import LatentDiffusion, seed_everything 10 | from audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint 11 | from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file 12 | from audioldm.latent_diffusion.ddim import DDIMSampler 13 | from einops import repeat 14 | import os 15 | 16 | def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1): 17 | text = [text] * batchsize 18 | if batchsize < 1: 19 | print("Warning: Batchsize must be at least 1. 
Batchsize is set to .") 20 | 21 | if(fbank is None): 22 | fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format 23 | else: 24 | fbank = torch.FloatTensor(fbank) 25 | fbank = fbank.expand(batchsize, 1024, 64) 26 | assert fbank.size(0) == batchsize 27 | 28 | stft = torch.zeros((batchsize, 1024, 512)) # Not used 29 | 30 | if(waveform is None): 31 | waveform = torch.zeros((batchsize, 160000)) # Not used 32 | else: 33 | waveform = torch.FloatTensor(waveform) 34 | waveform = waveform.expand(batchsize, -1) 35 | assert waveform.size(0) == batchsize 36 | 37 | fname = [""] * batchsize # Not used 38 | 39 | batch = ( 40 | fbank, 41 | stft, 42 | None, 43 | fname, 44 | waveform, 45 | text, 46 | ) 47 | return batch 48 | 49 | def round_up_duration(duration): 50 | return int(round(duration/2.5) + 1) * 2.5 51 | 52 | def build_model( 53 | ckpt_path=None, 54 | config=None, 55 | model_name="audioldm-s-full" 56 | ): 57 | print("Load AudioLDM: %s", model_name) 58 | 59 | if(ckpt_path is None): 60 | ckpt_path = get_metadata()[model_name]["path"] 61 | 62 | if(not os.path.exists(ckpt_path)): 63 | download_checkpoint(model_name) 64 | 65 | if torch.cuda.is_available(): 66 | device = torch.device("cuda:0") 67 | else: 68 | device = torch.device("cpu") 69 | 70 | if config is not None: 71 | assert type(config) is str 72 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) 73 | else: 74 | config = default_audioldm_config(model_name) 75 | 76 | # Use text as condition instead of using waveform during training 77 | config["model"]["params"]["device"] = device 78 | config["model"]["params"]["cond_stage_key"] = "text" 79 | 80 | # No normalization here 81 | latent_diffusion = LatentDiffusion(**config["model"]["params"]) 82 | 83 | resume_from_checkpoint = ckpt_path 84 | 85 | checkpoint = torch.load(resume_from_checkpoint, map_location=device) 86 | '''Original. Here is a bug that, an unexpected key "cond_stage_model.model.text_branch.embeddings.position_ids" exists in the checkpoint file. ''' 87 | # latent_diffusion.load_state_dict(checkpoint["state_dict"]) 88 | '''2023.10.17 Fix the bug by setting the paramer "strict" as "False" to ignore the unexpected key. 
''' 89 | latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False) 90 | 91 | latent_diffusion.eval() 92 | latent_diffusion = latent_diffusion.to(device) 93 | 94 | latent_diffusion.cond_stage_model.embed_mode = "text" 95 | return latent_diffusion 96 | 97 | def duration_to_latent_t_size(duration): 98 | return int(duration * 25.6) 99 | 100 | def set_cond_audio(latent_diffusion): 101 | latent_diffusion.cond_stage_key = "waveform" 102 | latent_diffusion.cond_stage_model.embed_mode="audio" 103 | return latent_diffusion 104 | 105 | def set_cond_text(latent_diffusion): 106 | latent_diffusion.cond_stage_key = "text" 107 | latent_diffusion.cond_stage_model.embed_mode="text" 108 | return latent_diffusion 109 | 110 | def text_to_audio( 111 | latent_diffusion, 112 | text, 113 | original_audio_file_path = None, 114 | seed=42, 115 | ddim_steps=200, 116 | duration=10, 117 | batchsize=1, 118 | guidance_scale=2.5, 119 | n_candidate_gen_per_text=3, 120 | config=None, 121 | ): 122 | seed_everything(int(seed)) 123 | waveform = None 124 | if(original_audio_file_path is not None): 125 | waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160) 126 | 127 | batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize) 128 | 129 | latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) 130 | 131 | if(waveform is not None): 132 | print("Generate audio that has similar content as %s" % original_audio_file_path) 133 | latent_diffusion = set_cond_audio(latent_diffusion) 134 | else: 135 | print("Generate audio using text %s" % text) 136 | latent_diffusion = set_cond_text(latent_diffusion) 137 | 138 | with torch.no_grad(): 139 | waveform = latent_diffusion.generate_sample( 140 | [batch], 141 | unconditional_guidance_scale=guidance_scale, 142 | ddim_steps=ddim_steps, 143 | n_candidate_gen_per_text=n_candidate_gen_per_text, 144 | duration=duration, 145 | ) 146 | return waveform 147 | 148 | def style_transfer( 149 | latent_diffusion, 150 | text, 151 | original_audio_file_path, 152 | transfer_strength, 153 | seed=42, 154 | duration=10, 155 | batchsize=1, 156 | guidance_scale=2.5, 157 | ddim_steps=200, 158 | config=None, 159 | ): 160 | if torch.cuda.is_available(): 161 | device = torch.device("cuda:0") 162 | else: 163 | device = torch.device("cpu") 164 | 165 | assert original_audio_file_path is not None, "You need to provide the original audio file path" 166 | 167 | audio_file_duration = get_duration(original_audio_file_path) 168 | 169 | assert get_bit_depth(original_audio_file_path) == 16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path 170 | 171 | # if(duration > 20): 172 | # print("Warning: The duration of the audio file %s must be less than 20 seconds. 
Longer duration will result in Nan in model output (we are still debugging that); Automatically set duration to 20 seconds") 173 | # duration = 20 174 | 175 | if(duration > audio_file_duration): 176 | print("Warning: Duration you specified %s-seconds must equal or smaller than the audio file duration %ss" % (duration, audio_file_duration)) 177 | duration = round_up_duration(audio_file_duration) 178 | print("Set new duration as %s-seconds" % duration) 179 | 180 | # duration = round_up_duration(duration) 181 | 182 | latent_diffusion = set_cond_text(latent_diffusion) 183 | 184 | if config is not None: 185 | assert type(config) is str 186 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) 187 | else: 188 | config = default_audioldm_config() 189 | 190 | seed_everything(int(seed)) 191 | # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) 192 | latent_diffusion.cond_stage_model.embed_mode = "text" 193 | 194 | fn_STFT = TacotronSTFT( 195 | config["preprocessing"]["stft"]["filter_length"], 196 | config["preprocessing"]["stft"]["hop_length"], 197 | config["preprocessing"]["stft"]["win_length"], 198 | config["preprocessing"]["mel"]["n_mel_channels"], 199 | config["preprocessing"]["audio"]["sampling_rate"], 200 | config["preprocessing"]["mel"]["mel_fmin"], 201 | config["preprocessing"]["mel"]["mel_fmax"], 202 | ) 203 | 204 | mel, _, _ = wav_to_fbank( 205 | original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT 206 | ) 207 | mel = mel.unsqueeze(0).unsqueeze(0).to(device) 208 | mel = repeat(mel, "1 ... -> b ...", b=batchsize) 209 | init_latent = latent_diffusion.get_first_stage_encoding( 210 | latent_diffusion.encode_first_stage(mel) 211 | ) # move to latent space, encode and sample 212 | if(torch.max(torch.abs(init_latent)) > 1e2): 213 | init_latent = torch.clip(init_latent, min=-10, max=10) 214 | sampler = DDIMSampler(latent_diffusion) 215 | sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False) 216 | 217 | t_enc = int(transfer_strength * ddim_steps) 218 | prompts = text 219 | 220 | with torch.no_grad(): 221 | with autocast("cuda"): 222 | with latent_diffusion.ema_scope(): 223 | uc = None 224 | if guidance_scale != 1.0: 225 | uc = latent_diffusion.cond_stage_model.get_unconditional_condition( 226 | batchsize 227 | ) 228 | 229 | c = latent_diffusion.get_learned_conditioning([prompts] * batchsize) 230 | z_enc = sampler.stochastic_encode( 231 | init_latent, torch.tensor([t_enc] * batchsize).to(device) 232 | ) 233 | samples = sampler.decode( 234 | z_enc, 235 | c, 236 | t_enc, 237 | unconditional_guidance_scale=guidance_scale, 238 | unconditional_conditioning=uc, 239 | ) 240 | # x_samples = latent_diffusion.decode_first_stage(samples) # Will result in Nan in output 241 | # print(torch.sum(torch.isnan(samples))) 242 | x_samples = latent_diffusion.decode_first_stage(samples) 243 | # print(x_samples) 244 | x_samples = latent_diffusion.decode_first_stage(samples[:,:,:-3,:]) 245 | # print(x_samples) 246 | waveform = latent_diffusion.first_stage_model.decode_to_waveform( 247 | x_samples 248 | ) 249 | 250 | return waveform 251 | 252 | def super_resolution_and_inpainting( 253 | latent_diffusion, 254 | text, 255 | original_audio_file_path = None, 256 | seed=42, 257 | ddim_steps=200, 258 | duration=None, 259 | batchsize=1, 260 | guidance_scale=2.5, 261 | n_candidate_gen_per_text=3, 262 | time_mask_ratio_start_and_end=(0.10, 0.15), # regenerate the 10% to 15% of the time steps in the spectrogram 263 | # time_mask_ratio_start_and_end=(1.0, 
1.0), # no inpainting 264 | # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the higher 75% to 100% mel bins 265 | freq_mask_ratio_start_and_end=(1.0, 1.0), # no super-resolution 266 | config=None, 267 | ): 268 | seed_everything(int(seed)) 269 | if config is not None: 270 | assert type(config) is str 271 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) 272 | else: 273 | config = default_audioldm_config() 274 | fn_STFT = TacotronSTFT( 275 | config["preprocessing"]["stft"]["filter_length"], 276 | config["preprocessing"]["stft"]["hop_length"], 277 | config["preprocessing"]["stft"]["win_length"], 278 | config["preprocessing"]["mel"]["n_mel_channels"], 279 | config["preprocessing"]["audio"]["sampling_rate"], 280 | config["preprocessing"]["mel"]["mel_fmin"], 281 | config["preprocessing"]["mel"]["mel_fmax"], 282 | ) 283 | 284 | # waveform = read_wav_file(original_audio_file_path, None) 285 | mel, _, _ = wav_to_fbank( 286 | original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT 287 | ) 288 | 289 | batch = make_batch_for_text_to_audio(text, fbank=mel[None,...], batchsize=batchsize) 290 | 291 | # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) 292 | latent_diffusion = set_cond_text(latent_diffusion) 293 | 294 | with torch.no_grad(): 295 | waveform = latent_diffusion.generate_sample_masked( 296 | [batch], 297 | unconditional_guidance_scale=guidance_scale, 298 | ddim_steps=ddim_steps, 299 | n_candidate_gen_per_text=n_candidate_gen_per_text, 300 | duration=duration, 301 | time_mask_ratio_start_and_end=time_mask_ratio_start_and_end, 302 | freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end 303 | ) 304 | return waveform 305 | -------------------------------------------------------------------------------- /audioldm/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import importlib 3 | 4 | from inspect import isfunction 5 | import os 6 | import soundfile as sf 7 | import time 8 | import wave 9 | 10 | import urllib.request 11 | import progressbar 12 | 13 | CACHE_DIR = os.getenv( 14 | "AUDIOLDM_CACHE_DIR", 15 | os.path.join(os.path.expanduser("~"), ".cache/audioldm")) 16 | 17 | def get_duration(fname): 18 | with contextlib.closing(wave.open(fname, 'r')) as f: 19 | frames = f.getnframes() 20 | rate = f.getframerate() 21 | return frames / float(rate) 22 | 23 | def get_bit_depth(fname): 24 | with contextlib.closing(wave.open(fname, 'r')) as f: 25 | bit_depth = f.getsampwidth() * 8 26 | return bit_depth 27 | 28 | def get_time(): 29 | t = time.localtime() 30 | return time.strftime("%d_%m_%Y_%H_%M_%S", t) 31 | 32 | def seed_everything(seed): 33 | import random, os 34 | import numpy as np 35 | import torch 36 | 37 | random.seed(seed) 38 | os.environ["PYTHONHASHSEED"] = str(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) 41 | torch.cuda.manual_seed(seed) 42 | torch.backends.cudnn.deterministic = True 43 | torch.backends.cudnn.benchmark = True 44 | 45 | 46 | def save_wave(waveform, savepath, name="outwav"): 47 | if type(name) is not list: 48 | name = [name] * waveform.shape[0] 49 | 50 | for i in range(waveform.shape[0]): 51 | path = os.path.join( 52 | savepath, 53 | "%s_%s.wav" 54 | % ( 55 | os.path.basename(name[i]) 56 | if (not ".wav" in name[i]) 57 | else os.path.basename(name[i]).split(".")[0], 58 | i, 59 | ), 60 | ) 61 | print("Save audio to %s" % path) 62 | sf.write(path, waveform[i, 0], samplerate=16000) 63 | 64 | 65 | def exists(x): 66 | return x is not 
None 67 | 68 | 69 | def default(val, d): 70 | if exists(val): 71 | return val 72 | return d() if isfunction(d) else d 73 | 74 | 75 | def count_params(model, verbose=False): 76 | total_params = sum(p.numel() for p in model.parameters()) 77 | if verbose: 78 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 79 | return total_params 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | def instantiate_from_config(config): 91 | if not "target" in config: 92 | if config == "__is_first_stage__": 93 | return None 94 | elif config == "__is_unconditional__": 95 | return None 96 | raise KeyError("Expected key `target` to instantiate.") 97 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 98 | 99 | 100 | def default_audioldm_config(model_name="audioldm-s-full"): 101 | basic_config = { 102 | "wave_file_save_path": "./output", 103 | "id": { 104 | "version": "v1", 105 | "name": "default", 106 | "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml", 107 | }, 108 | "preprocessing": { 109 | "audio": {"sampling_rate": 16000, "max_wav_value": 32768}, 110 | "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024}, 111 | "mel": { 112 | "n_mel_channels": 64, 113 | "mel_fmin": 0, 114 | "mel_fmax": 8000, 115 | "freqm": 0, 116 | "timem": 0, 117 | "blur": False, 118 | "mean": -4.63, 119 | "std": 2.74, 120 | "target_length": 1024, 121 | }, 122 | }, 123 | "model": { 124 | "device": "cuda", 125 | "target": "audioldm.pipline.LatentDiffusion", 126 | "params": { 127 | "base_learning_rate": 5e-06, 128 | "linear_start": 0.0015, 129 | "linear_end": 0.0195, 130 | "num_timesteps_cond": 1, 131 | "log_every_t": 200, 132 | "timesteps": 1000, 133 | "first_stage_key": "fbank", 134 | "cond_stage_key": "waveform", 135 | "latent_t_size": 256, 136 | "latent_f_size": 16, 137 | "channels": 8, 138 | "cond_stage_trainable": True, 139 | "conditioning_key": "film", 140 | "monitor": "val/loss_simple_ema", 141 | "scale_by_std": True, 142 | "unet_config": { 143 | "target": "audioldm.latent_diffusion.openaimodel.UNetModel", 144 | "params": { 145 | "image_size": 64, 146 | "extra_film_condition_dim": 512, 147 | "extra_film_use_concat": True, 148 | "in_channels": 8, 149 | "out_channels": 8, 150 | "model_channels": 128, 151 | "attention_resolutions": [8, 4, 2], 152 | "num_res_blocks": 2, 153 | "channel_mult": [1, 2, 3, 5], 154 | "num_head_channels": 32, 155 | "use_spatial_transformer": True, 156 | }, 157 | }, 158 | "first_stage_config": { 159 | "base_learning_rate": 4.5e-05, 160 | "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL", 161 | "params": { 162 | "monitor": "val/rec_loss", 163 | "image_key": "fbank", 164 | "subband": 1, 165 | "embed_dim": 8, 166 | "time_shuffle": 1, 167 | "ddconfig": { 168 | "double_z": True, 169 | "z_channels": 8, 170 | "resolution": 256, 171 | "downsample_time": False, 172 | "in_channels": 1, 173 | "out_ch": 1, 174 | "ch": 128, 175 | "ch_mult": [1, 2, 4], 176 | "num_res_blocks": 2, 177 | "attn_resolutions": [], 178 | "dropout": 0.0, 179 | }, 180 | }, 181 | }, 182 | "cond_stage_config": { 183 | "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2", 184 | "params": { 185 | "key": "waveform", 186 | "sampling_rate": 
16000, 187 | "embed_mode": "audio", 188 | "unconditional_prob": 0.1, 189 | }, 190 | }, 191 | }, 192 | }, 193 | } 194 | 195 | if("-l-" in model_name): 196 | basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256 197 | basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64 198 | elif("-m-" in model_name): 199 | basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192 200 | basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model use a larger HTAST 201 | 202 | return basic_config 203 | 204 | def get_metadata(): 205 | return { 206 | "audioldm-s-full": { 207 | "path": os.path.join( 208 | CACHE_DIR, 209 | "audioldm-s-full.ckpt", 210 | ), 211 | "url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1", 212 | }, 213 | "audioldm-l-full": { 214 | "path": os.path.join( 215 | CACHE_DIR, 216 | "audioldm-l-full.ckpt", 217 | ), 218 | "url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1", 219 | }, 220 | "audioldm-s-full-v2": { 221 | "path": os.path.join( 222 | CACHE_DIR, 223 | "audioldm-s-full-v2.ckpt", 224 | ), 225 | "url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1", 226 | }, 227 | "audioldm-m-text-ft": { 228 | "path": os.path.join( 229 | CACHE_DIR, 230 | "audioldm-m-text-ft.ckpt", 231 | ), 232 | "url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1", 233 | }, 234 | "audioldm-s-text-ft": { 235 | "path": os.path.join( 236 | CACHE_DIR, 237 | "audioldm-s-text-ft.ckpt", 238 | ), 239 | "url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1", 240 | }, 241 | "audioldm-m-full": { 242 | "path": os.path.join( 243 | CACHE_DIR, 244 | "audioldm-m-full.ckpt", 245 | ), 246 | "url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1", 247 | }, 248 | } 249 | 250 | class MyProgressBar(): 251 | def __init__(self): 252 | self.pbar = None 253 | 254 | def __call__(self, block_num, block_size, total_size): 255 | if not self.pbar: 256 | self.pbar=progressbar.ProgressBar(maxval=total_size) 257 | self.pbar.start() 258 | 259 | downloaded = block_num * block_size 260 | if downloaded < total_size: 261 | self.pbar.update(downloaded) 262 | else: 263 | self.pbar.finish() 264 | 265 | def download_checkpoint(checkpoint_name="audioldm-s-full"): 266 | meta = get_metadata() 267 | if(checkpoint_name not in meta.keys()): 268 | print("The model name you provided is not supported. 
Please use one of the following: ", meta.keys()) 269 | 270 | if not os.path.exists(meta[checkpoint_name]["path"]) or os.path.getsize(meta[checkpoint_name]["path"]) < 2*10**9: 271 | os.makedirs(os.path.dirname(meta[checkpoint_name]["path"]), exist_ok=True) 272 | print(f"Downloading the main structure of {checkpoint_name} into {os.path.dirname(meta[checkpoint_name]['path'])}") 273 | 274 | urllib.request.urlretrieve(meta[checkpoint_name]["url"], meta[checkpoint_name]["path"], MyProgressBar()) 275 | print( 276 | "Weights downloaded in: {} Size: {}".format( 277 | meta[checkpoint_name]["path"], 278 | os.path.getsize(meta[checkpoint_name]["path"]), 279 | ) 280 | ) 281 | -------------------------------------------------------------------------------- /audioldm/variational_autoencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/variational_autoencoder/__init__.py -------------------------------------------------------------------------------- /audioldm/variational_autoencoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from audioldm.latent_diffusion.ema import * 3 | from audioldm.variational_autoencoder.modules import Encoder, Decoder 4 | from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution 5 | 6 | from audioldm.hifigan.utilities import get_vocoder, vocoder_infer 7 | 8 | 9 | class AutoencoderKL(nn.Module): 10 | def __init__( 11 | self, 12 | ddconfig=None, 13 | lossconfig=None, 14 | image_key="fbank", 15 | embed_dim=None, 16 | time_shuffle=1, 17 | subband=1, 18 | ckpt_path=None, 19 | reload_from_ckpt=None, 20 | ignore_keys=[], 21 | colorize_nlabels=None, 22 | monitor=None, 23 | base_learning_rate=1e-5, 24 | ): 25 | super().__init__() 26 | 27 | self.encoder = Encoder(**ddconfig) 28 | self.decoder = Decoder(**ddconfig) 29 | 30 | self.subband = int(subband) 31 | 32 | if self.subband > 1: 33 | print("Use subband decomposition %s" % self.subband) 34 | 35 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 36 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 37 | 38 | self.vocoder = get_vocoder(None, "cpu") 39 | self.embed_dim = embed_dim 40 | 41 | if monitor is not None: 42 | self.monitor = monitor 43 | 44 | self.time_shuffle = time_shuffle 45 | self.reload_from_ckpt = reload_from_ckpt 46 | self.reloaded = False 47 | self.mean, self.std = None, None 48 | 49 | def encode(self, x): 50 | # x = self.time_shuffle_operation(x) 51 | x = self.freq_split_subband(x) 52 | h = self.encoder(x) 53 | moments = self.quant_conv(h) 54 | posterior = DiagonalGaussianDistribution(moments) 55 | return posterior 56 | 57 | def decode(self, z): 58 | z = self.post_quant_conv(z) 59 | dec = self.decoder(z) 60 | dec = self.freq_merge_subband(dec) 61 | return dec 62 | 63 | def decode_to_waveform(self, dec): 64 | dec = dec.squeeze(1).permute(0, 2, 1) 65 | wav_reconstruction = vocoder_infer(dec, self.vocoder) 66 | return wav_reconstruction 67 | 68 | def forward(self, input, sample_posterior=True): 69 | posterior = self.encode(input) 70 | if sample_posterior: 71 | z = posterior.sample() 72 | else: 73 | z = posterior.mode() 74 | 75 | if self.flag_first_run: 76 | print("Latent size: ", z.size()) 77 | self.flag_first_run = False 78 | 79 | dec = self.decode(z) 80 | 81 | return dec, posterior 82 | 83 | def 
freq_split_subband(self, fbank): 84 | if self.subband == 1 or self.image_key != "stft": 85 | return fbank 86 | 87 | bs, ch, tstep, fbins = fbank.size() 88 | 89 | assert fbank.size(-1) % self.subband == 0 90 | assert ch == 1 91 | 92 | return ( 93 | fbank.squeeze(1) 94 | .reshape(bs, tstep, self.subband, fbins // self.subband) 95 | .permute(0, 2, 1, 3) 96 | ) 97 | 98 | def freq_merge_subband(self, subband_fbank): 99 | if self.subband == 1 or self.image_key != "stft": 100 | return subband_fbank 101 | assert subband_fbank.size(1) == self.subband # Channel dimension 102 | bs, sub_ch, tstep, fbins = subband_fbank.size() 103 | return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1) 104 | -------------------------------------------------------------------------------- /audioldm/variational_autoencoder/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.mean( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.mean( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
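    # Note: the return expression below is the elementwise closed-form
    #   KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) )
    #     = 0.5 * ( -1 + logvar2 - logvar1
    #               + exp(logvar1 - logvar2)
    #               + (mean1 - mean2)^2 * exp(-logvar2) )
    #     = log(sigma2 / sigma1) + (sigma1^2 + (mean1 - mean2)^2) / (2 * sigma2^2) - 1/2.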
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/bg.png -------------------------------------------------------------------------------- /bin/audioldm: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration 4 | import argparse 5 | 6 | CACHE_DIR = os.getenv( 7 | "AUDIOLDM_CACHE_DIR", 8 | os.path.join(os.path.expanduser("~"), ".cache/audioldm")) 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument( 13 | "--mode", 14 | type=str, 15 | required=False, 16 | default="generation", 17 | help="generation: text-to-audio generation; transfer: style transfer", 18 | choices=["generation", "transfer"] 19 | ) 20 | 21 | parser.add_argument( 22 | "-t", 23 | "--text", 24 | type=str, 25 | required=False, 26 | default="", 27 | help="Text prompt to the model for audio generation", 28 | ) 29 | 30 | parser.add_argument( 31 | "-f", 32 | "--file_path", 33 | type=str, 34 | required=False, 35 | default=None, 36 | help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio", 37 | ) 38 | 39 | parser.add_argument( 40 | "--transfer_strength", 41 | type=float, 42 | required=False, 43 | default=0.5, 44 | help="A value between 0 and 1. 
0 means original audio without transfer, 1 means completely transfer to the audio indicated by text", 45 | ) 46 | 47 | parser.add_argument( 48 | "-s", 49 | "--save_path", 50 | type=str, 51 | required=False, 52 | help="The path to save model output", 53 | default="./output", 54 | ) 55 | 56 | parser.add_argument( 57 | "--model_name", 58 | type=str, 59 | required=False, 60 | help="The checkpoint you gonna use", 61 | default="audioldm-m-full", 62 | choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2","audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full"] 63 | ) 64 | 65 | parser.add_argument( 66 | "-ckpt", 67 | "--ckpt_path", 68 | type=str, 69 | required=False, 70 | help="The path to the pretrained .ckpt model", 71 | default=None, 72 | ) 73 | 74 | parser.add_argument( 75 | "-b", 76 | "--batchsize", 77 | type=int, 78 | required=False, 79 | default=1, 80 | help="Generate how many samples at the same time", 81 | ) 82 | 83 | parser.add_argument( 84 | "--ddim_steps", 85 | type=int, 86 | required=False, 87 | default=200, 88 | help="The sampling step for DDIM", 89 | ) 90 | 91 | parser.add_argument( 92 | "-gs", 93 | "--guidance_scale", 94 | type=float, 95 | required=False, 96 | default=2.5, 97 | help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", 98 | ) 99 | 100 | parser.add_argument( 101 | "-dur", 102 | "--duration", 103 | type=float, 104 | required=False, 105 | default=10.0, 106 | help="The duration of the samples", 107 | ) 108 | 109 | parser.add_argument( 110 | "-n", 111 | "--n_candidate_gen_per_text", 112 | type=int, 113 | required=False, 114 | default=3, 115 | help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation", 116 | ) 117 | 118 | parser.add_argument( 119 | "--seed", 120 | type=int, 121 | required=False, 122 | default=42, 123 | help="Change this value (any integer number) will lead to a different generation result.", 124 | ) 125 | 126 | args = parser.parse_args() 127 | 128 | if(args.ckpt_path is not None): 129 | print("Warning: ckpt_path has no effect after version 0.0.20.") 130 | 131 | assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5" 132 | 133 | mode = args.mode 134 | if(mode == "generation" and args.file_path is not None): 135 | mode = "generation_audio_to_audio" 136 | if(len(args.text) > 0): 137 | print("Warning: You have specified the --file_path. 
--text will be ignored") 138 | args.text = "" 139 | 140 | save_path = os.path.join(args.save_path, mode) 141 | 142 | if(args.file_path is not None): 143 | save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0])) 144 | 145 | text = args.text 146 | random_seed = args.seed 147 | duration = args.duration 148 | guidance_scale = args.guidance_scale 149 | n_candidate_gen_per_text = args.n_candidate_gen_per_text 150 | 151 | os.makedirs(save_path, exist_ok=True) 152 | audioldm = build_model(model_name=args.model_name) 153 | 154 | if(args.mode == "generation"): 155 | waveform = text_to_audio( 156 | audioldm, 157 | text, 158 | args.file_path, 159 | random_seed, 160 | duration=duration, 161 | guidance_scale=guidance_scale, 162 | ddim_steps=args.ddim_steps, 163 | n_candidate_gen_per_text=n_candidate_gen_per_text, 164 | batchsize=args.batchsize, 165 | ) 166 | 167 | elif(args.mode == "transfer"): 168 | assert args.file_path is not None 169 | assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path 170 | waveform = style_transfer( 171 | audioldm, 172 | text, 173 | args.file_path, 174 | args.transfer_strength, 175 | random_seed, 176 | duration=duration, 177 | guidance_scale=guidance_scale, 178 | ddim_steps=args.ddim_steps, 179 | batchsize=args.batchsize, 180 | ) 181 | waveform = waveform[:,None,:] 182 | 183 | save_wave(waveform, save_path, name="%s_%s" % (get_time(), text)) 184 | -------------------------------------------------------------------------------- /bin/audioldm.cmd: -------------------------------------------------------------------------------- 1 | @echo OFF 2 | python -m audioldm %* -------------------------------------------------------------------------------- /ckpt/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/ckpt/.gitkeep -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | # Generation 2 | audioldm --file_path trumpet.wav 3 | audioldm --file_path trumpet.wav -dur 25 4 | audioldm --file_path trumpet.wav -dur 2.5 5 | audioldm --text "A hammer is hitting a wooden surface" 6 | audioldm 7 | 8 | # False use cases 9 | audioldm --text "A hammer is hitting a wooden surface" --file_path trumpet.wav # Same as audioldm --file_path trumpet.wav 10 | 11 | 12 | # Transfer 13 | audioldm --mode "transfer" --file_path trumpet.wav -t "Children Singing" 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /scripts/text2sound.py: -------------------------------------------------------------------------------- 1 | import os 2 | from audioldm import text_to_audio, build_model, save_wave 3 | 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument( 9 | "-t", 10 | "--text", 11 | type=str, 12 | required=False, 13 | default="A hammer is hitting a wooden surface", 14 | help="Text prompt to the model for audio generation", 15 | ) 16 | 17 | parser.add_argument( 18 | "-s", 19 | "--save_path", 20 | type=str, 21 | required=False, 22 | help="The path to save model output", 23 | default="./output", 24 | ) 25 | 26 | parser.add_argument( 27 | "-ckpt", 28 | "--ckpt_path", 29 | type=str, 30 | required=False, 31 | help="The path to the pretrained .ckpt model", 32 | 
default="./ckpt/audioldm-s-full.ckpt", 33 | ) 34 | 35 | parser.add_argument( 36 | "-b", 37 | "--batchsize", 38 | type=int, 39 | required=False, 40 | default=1, 41 | help="Generate how many samples at the same time", 42 | ) 43 | 44 | parser.add_argument( 45 | "-gs", 46 | "--guidance_scale", 47 | type=float, 48 | required=False, 49 | default=2.5, 50 | help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", 51 | ) 52 | 53 | parser.add_argument( 54 | "-dur", 55 | "--duration", 56 | type=float, 57 | required=False, 58 | default=10.0, 59 | help="The duration of the samples", 60 | ) 61 | 62 | parser.add_argument( 63 | "-n", 64 | "--n_candidate_gen_per_text", 65 | type=int, 66 | required=False, 67 | default=3, 68 | help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation", 69 | ) 70 | 71 | parser.add_argument( 72 | "--seed", 73 | type=int, 74 | required=False, 75 | default=42, 76 | help="Change this value (any integer number) will lead to a different generation result.", 77 | ) 78 | 79 | args = parser.parse_args() 80 | 81 | assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5" 82 | 83 | save_path = args.save_path 84 | text = args.text 85 | random_seed = args.seed 86 | duration = args.duration 87 | guidance_scale = args.guidance_scale 88 | n_candidate_gen_per_text = args.n_candidate_gen_per_text 89 | 90 | os.makedirs(save_path, exist_ok=True) 91 | audioldm = build_model(ckpt_path=args.ckpt_path) 92 | waveform = text_to_audio( 93 | audioldm, 94 | text, 95 | seed=random_seed, 96 | duration=duration, 97 | guidance_scale=guidance_scale, 98 | n_candidate_gen_per_text=n_candidate_gen_per_text, 99 | batchsize=args.batchsize, 100 | ) 101 | 102 | save_wave(waveform, save_path, name=text) 103 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | # python3 setup.py sdist bdist_wheel 4 | """ 5 | @File : setup.py.py 6 | @Contact : haoheliu@gmail.com 7 | @License : (C)Copyright 2020-2100 8 | 9 | @Modify Time @Author @Version @Desciption 10 | ------------ ------- -------- ----------- 11 | 9/6/21 5:16 PM Haohe Liu 1.0 None 12 | """ 13 | 14 | # !/usr/bin/env python 15 | # -*- coding: utf-8 -*- 16 | 17 | # Note: To use the 'upload' functionality of this file, you must: 18 | # $ pipenv install twine --dev 19 | 20 | import io 21 | import os 22 | import sys 23 | from shutil import rmtree 24 | 25 | from setuptools import find_packages, setup, Command 26 | 27 | # Package meta-data. 28 | NAME = "audioldm" 29 | DESCRIPTION = "This package is written for text-to-audio generation." 30 | URL = "https://github.com/haoheliu/audioldm" 31 | EMAIL = "haoheliu@gmail.com" 32 | AUTHOR = "Haohe Liu" 33 | REQUIRES_PYTHON = ">=3.7.0" 34 | VERSION = "0.1.1" 35 | 36 | # What packages are required for this module to be executed? 37 | REQUIRED = [ 38 | "torch>=1.13.0", 39 | "torchaudio>=0.13.0", 40 | "torchvision>=0.14.0", 41 | "tqdm", 42 | "gradio", 43 | "pyyaml", 44 | "einops", 45 | "chardet", 46 | "numpy<=1.23.5", 47 | "soundfile", 48 | "librosa==0.9.2", 49 | "scipy", 50 | "pandas", 51 | "torchlibrosa==0.0.9", 52 | "transformers==4.29.0", 53 | "progressbar", 54 | "ftfy", 55 | ] 56 | 57 | # What packages are optional? 
58 | EXTRAS = {} 59 | 60 | # The rest you shouldn't have to touch too much :) 61 | # ------------------------------------------------ 62 | # Except, perhaps the License and Trove Classifiers! 63 | # If you do change the License, remember to change the Trove Classifier for that! 64 | 65 | here = os.path.abspath(os.path.dirname(__file__)) 66 | 67 | # Import the README and use it as the long-description. 68 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 69 | try: 70 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: 71 | long_description = "\n" + f.read() 72 | except FileNotFoundError: 73 | long_description = DESCRIPTION 74 | 75 | # Load the package's __version__.py module as a dictionary. 76 | about = {} 77 | if not VERSION: 78 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 79 | with open(os.path.join(here, project_slug, "__version__.py")) as f: 80 | exec(f.read(), about) 81 | else: 82 | about["__version__"] = VERSION 83 | 84 | 85 | class UploadCommand(Command): 86 | """Support setup.py upload.""" 87 | 88 | description = "Build and publish the package." 89 | user_options = [] 90 | 91 | @staticmethod 92 | def status(s): 93 | """Prints things in bold.""" 94 | print("\033[1m{0}\033[0m".format(s)) 95 | 96 | def initialize_options(self): 97 | pass 98 | 99 | def finalize_options(self): 100 | pass 101 | 102 | def run(self): 103 | try: 104 | self.status("Removing previous builds…") 105 | rmtree(os.path.join(here, "dist")) 106 | except OSError: 107 | pass 108 | 109 | self.status("Building Source and Wheel (universal) distribution…") 110 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) 111 | 112 | self.status("Uploading the package to PyPI via Twine…") 113 | os.system("twine upload dist/*") 114 | 115 | self.status("Pushing git tags…") 116 | os.system("git tag v{0}".format(about["__version__"])) 117 | os.system("git push --tags") 118 | 119 | sys.exit() 120 | 121 | 122 | # Where the magic happens: 123 | setup( 124 | name=NAME, 125 | version=about["__version__"], 126 | description=DESCRIPTION, 127 | long_description=long_description, 128 | long_description_content_type="text/markdown", 129 | author=AUTHOR, 130 | author_email=EMAIL, 131 | python_requires=REQUIRES_PYTHON, 132 | url=URL, 133 | # packages=find_packages(exclude=[]), 134 | # If your package is a single module, use this instead of 'packages': 135 | # py_modules=["audioldm"], 136 | # entry_points={ 137 | # 'console_scripts': ['mycli=mymodule:cli'], 138 | # }, 139 | install_requires=REQUIRED, 140 | extras_require=EXTRAS, 141 | packages=find_packages(), 142 | # package_data={'bpe': ['audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz']}, 143 | include_package_data=True, 144 | license="MIT", 145 | classifiers=[ 146 | # Trove classifiers 147 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 148 | "License :: OSI Approved :: MIT License", 149 | "Programming Language :: Python", 150 | "Programming Language :: Python :: 3", 151 | "Programming Language :: Python :: 3.7", 152 | "Programming Language :: Python :: Implementation :: CPython", 153 | "Programming Language :: Python :: Implementation :: PyPy", 154 | ], 155 | # $ setup.py publish support. 
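    # The mapping below registers the UploadCommand defined above, so that
    # `python setup.py upload` builds the sdist/wheel, uploads the package to
    # PyPI via Twine, and pushes a git version tag.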
156 | cmdclass={ 157 | "upload": UploadCommand, 158 | }, 159 | scripts=["bin/audioldm.cmd", "bin/audioldm"], 160 | ) 161 | -------------------------------------------------------------------------------- /trumpet.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/trumpet.wav --------------------------------------------------------------------------------
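As a closing illustration of the size bookkeeping implied by `default_audioldm_config` and the pipeline code dumped above, the sketch below shows how a requested duration maps to mel frames and to the latent sizes the config hard-codes (latent_t_size=256, latent_f_size=16 for the default 64-mel setting): 16 kHz audio with a 160-sample STFT hop gives roughly 100 frames per second, the pipeline pads this to duration * 102.4 frames, and the KL autoencoder's two downsampling stages (ch_mult [1, 2, 4]) divide both the time and mel axes by 4. The helper name, signature, and defaults are illustrative only and are not part of the repository's API.

def latent_sizes_for_duration(duration_seconds: float,
                              frames_per_second: float = 102.4,
                              n_mel_channels: int = 64,
                              downsample_factor: int = 4):
    """Return (mel_frames, latent_t_size, latent_f_size) for a requested duration."""
    # Mirrors the `duration % 2.5 == 0` assertion in bin/audioldm and
    # scripts/text2sound.py: 2.5 s corresponds to 256 mel frames / 64 latent steps.
    assert duration_seconds % 2.5 == 0, "Duration must be a multiple of 2.5"
    mel_frames = int(duration_seconds * frames_per_second)  # 10.0 s -> 1024 frames
    latent_t = mel_frames // downsample_factor              # 1024 -> 256
    latent_f = n_mel_channels // downsample_factor          # 64 -> 16
    return mel_frames, latent_t, latent_f


print(latent_sizes_for_duration(10.0))  # (1024, 256, 16)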