├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── app.py
├── audioldm
│   ├── __init__.py
│   ├── __main__.py
│   ├── audio
│   │   ├── __init__.py
│   │   ├── audio_processing.py
│   │   ├── stft.py
│   │   └── tools.py
│   ├── clap
│   │   ├── __init__.py
│   │   ├── encoders.py
│   │   ├── open_clip
│   │   │   ├── __init__.py
│   │   │   ├── bert.py
│   │   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   │   ├── factory.py
│   │   │   ├── feature_fusion.py
│   │   │   ├── htsat.py
│   │   │   ├── linear_probe.py
│   │   │   ├── loss.py
│   │   │   ├── model.py
│   │   │   ├── model_configs
│   │   │   │   ├── HTSAT-base.json
│   │   │   │   ├── HTSAT-large.json
│   │   │   │   ├── HTSAT-tiny-win-1536.json
│   │   │   │   ├── HTSAT-tiny.json
│   │   │   │   ├── PANN-10.json
│   │   │   │   ├── PANN-14-fmax-18k.json
│   │   │   │   ├── PANN-14-fmax-8k-20s.json
│   │   │   │   ├── PANN-14-tiny-transformer.json
│   │   │   │   ├── PANN-14-win-1536.json
│   │   │   │   ├── PANN-14.json
│   │   │   │   ├── PANN-6.json
│   │   │   │   ├── RN101-quickgelu.json
│   │   │   │   ├── RN101.json
│   │   │   │   ├── RN50-quickgelu.json
│   │   │   │   ├── RN50.json
│   │   │   │   ├── RN50x16.json
│   │   │   │   ├── RN50x4.json
│   │   │   │   ├── ViT-B-16.json
│   │   │   │   ├── ViT-B-32-quickgelu.json
│   │   │   │   ├── ViT-B-32.json
│   │   │   │   └── ViT-L-14.json
│   │   │   ├── openai.py
│   │   │   ├── pann_model.py
│   │   │   ├── pretrained.py
│   │   │   ├── timm_model.py
│   │   │   ├── tokenizer.py
│   │   │   ├── transform.py
│   │   │   ├── utils.py
│   │   │   └── version.py
│   │   └── training
│   │       ├── __init__.py
│   │       ├── audioset_textmap.npy
│   │       ├── data.py
│   │       ├── distributed.py
│   │       ├── imagenet_zeroshot_data.py
│   │       ├── infer_demo.py
│   │       ├── logger.py
│   │       ├── lp_main.py
│   │       ├── lp_train.py
│   │       ├── main.py
│   │       ├── params.py
│   │       ├── scheduler.py
│   │       ├── train.py
│   │       └── zero_shot.py
│   ├── hifigan
│   │   ├── __init__.py
│   │   ├── models.py
│   │   └── utilities.py
│   ├── latent_diffusion
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── ddim.py
│   │   ├── ddpm.py
│   │   ├── ema.py
│   │   ├── openaimodel.py
│   │   └── util.py
│   ├── ldm.py
│   ├── pipeline.py
│   ├── utils.py
│   └── variational_autoencoder
│       ├── __init__.py
│       ├── autoencoder.py
│       ├── distributions.py
│       └── modules.py
├── bg.png
├── bin
│   ├── audioldm
│   └── audioldm.cmd
├── ckpt
│   └── .gitkeep
├── scripts
│   ├── test.sh
│   └── text2sound.py
├── setup.py
└── trumpet.wav
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | __pycache__
3 | build
4 | dist
5 | output
6 | *temp.py
7 | *.wav
8 | gradio_cached_examples
9 | run.sh
10 | run2.sh
11 | output/
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.py LICENSE README.md
2 | recursive-include audioldm *.txt *.py *.gz *.npy *.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # :sound: Audio Generation with AudioLDM (ICML 2023)
2 |
3 | [arXiv](https://arxiv.org/abs/2301.12503) [Project Page](https://audioldm.github.io/) [Hugging Face Space](https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation) [Colab](https://colab.research.google.com/github/olaviinha/NeuralTextToAudio/blob/main/AudioLDM_pub.ipynb?force_theme=dark) [Replicate](https://replicate.com/jagilley/audio-ldm)
4 |
5 |
6 |
7 | **Generate speech, sound effects, music and beyond.**
8 |
9 | This repo currently supports:
10 |
11 | - **Text-to-Audio Generation**: Generate audio given text input.
12 | - **Audio-to-Audio Generation**: Given an audio clip, generate another clip that contains the same types of sound.
13 | - **Text-guided Audio-to-Audio Style Transfer**: Transfer the sound of an audio clip into another style guided by a text description.
14 |
15 |
16 |
17 | ## Important tricks to make your generated audio sound better
18 | 1. Try to provide more hints to AudioLDM, such as using more adjectives to describe your sound (e.g., clearly, high quality) or making your target more specific (e.g., "water stream in a forest" instead of "stream"). This helps ensure that AudioLDM understands what you want.
19 | 2. Try different random seeds, which can sometimes affect the generation quality significantly (see the examples below).
20 | 3. It's best to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with.
21 |
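For example, the commands below illustrate tricks 1 and 2 using the commandline interface described later in this README; the prompts and seed values are only illustrative.

```shell
# Trick 1: a vague prompt versus a more descriptive, specific one
audioldm -t "stream"
audioldm -t "A gentle water stream flowing through a forest, high quality"

# Trick 2: the same prompt with two different random seeds
audioldm -t "A hammer is hitting a wooden surface" --seed 42
audioldm -t "A hammer is hitting a wooden surface" --seed 1234
```
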
22 | # Change Log
23 |
24 | **2023-04-10**: Finetune AudioLDM with the MusicCaps and AudioCaps datasets. Add three more checkpoints: audioldm-m-text-ft, audioldm-s-text-ft, and audioldm-m-full.
25 |
26 | **2023-03-04**: Add two more checkpoints: a small model with more training steps and a large model. Add model selection in the Gradio app.
27 |
28 | **2023-02-24**: Add audio-to-audio generation. Add test cases. Add a pipeline (python function) for audio super-resolution and inpainting.
29 |
30 | **2023-02-15**: Add audio style transfer. Add more options on generation.
31 |
32 | ## Web APP
33 |
34 | The web app currently only supports text-to-audio generation. For full functionality, please refer to the [Commandline Usage](https://github.com/haoheliu/AudioLDM#commandline-usage).
35 |
36 | 1. Prepare running environment
37 | ```shell
38 | conda create -n audioldm python=3.8; conda activate audioldm
39 | pip3 install git+https://github.com/haoheliu/AudioLDM.git
40 | git clone https://github.com/haoheliu/AudioLDM; cd AudioLDM
41 | ```
42 | 2. Start the web application (powered by Gradio)
43 | ```shell
44 | python3 app.py
45 | ```
46 | 3. A link will be printed out. Click the link to open the browser and play.
47 |
48 | ## Commandline Usage
49 | Prepare running environment
50 | ```shell
51 | # Optional
52 | conda create -n audioldm python=3.8; conda activate audioldm
53 | # Install AudioLDM
54 | pip3 install git+https://github.com/haoheliu/AudioLDM.git
55 | ```
56 |
57 | :star2: **Text-to-Audio Generation**: generate audio guided by a text prompt
58 | ```shell
59 | # The default --mode is "generation"
60 | audioldm -t "A hammer is hitting a wooden surface"
61 | # Result will be saved in "./output/generation"
62 | ```
63 |
64 | :star2: **Audio-to-Audio Generation**: generate audio guided by an audio file (the output will contain audio events similar to those in the input).
65 | ```shell
66 | audioldm --file_path trumpet.wav
67 | # Result will be saved in "./output/generation_audio_to_audio/trumpet"
68 | ```
69 |
70 | :star2: **Text-guided Audio-to-Audio Style Transfer**
71 | ```shell
72 | # Test run
73 | # --file_path is the original audio file for transfer
74 | # -t is the text AudioLDM uses for transfer.
75 | # Please make sure that the file given by --file_path exists
76 | audioldm --mode "transfer" --file_path trumpet.wav -t "Children Singing"
77 | # Result will be saved in "./output/transfer/trumpet"
78 |
79 | # Tuning the value of --transfer_strength is important!
80 | # --transfer_strength: a value between 0 and 1. 0 keeps the original audio without transfer; 1 fully transfers to the audio indicated by the text
81 | audioldm --mode "transfer" --file_path trumpet.wav -t "Children Singing" --transfer_strength 0.25
82 | ```
83 |
84 | :gear: How to choose between different model checkpoints?
85 | ```shell
86 | # Add the --model_name parameter, choices={audioldm-m-text-ft, audioldm-s-text-ft, audioldm-m-full, audioldm-s-full, audioldm-l-full, audioldm-s-full-v2}
87 | audioldm --model_name audioldm-s-full
88 | ```
89 |
90 | - :star: audioldm-m-full (default, **recommended**): the medium AudioLDM, trained with audio embeddings as the condition and without finetuning *(added 2023-04-10)*.
91 | - :star: audioldm-s-full (**recommended**): the original open-sourced version *(added 2023-02-01)*.
92 | - :star: audioldm-s-full-v2 (**recommended**): more training steps compared with audioldm-s-full *(added 2023-03-04)*.
93 | - audioldm-s-text-ft: the small AudioLDM finetuned with AudioCaps and MusicCaps audio-text pairs *(added 2023-04-10)*.
94 | - audioldm-m-text-ft: the medium AudioLDM finetuned with AudioCaps and MusicCaps audio-text pairs *(added 2023-04-10)*.
95 | - audioldm-l-full: a larger model compared with audioldm-s-full *(added 2023-03-04)*.
96 |
97 | > @haoheliu personally did an informal evaluation of the overall quality of the checkpoints, which gives audioldm-m-full (6.85/10), audioldm-s-full (6.62/10), audioldm-s-text-ft (6/10), and audioldm-m-text-ft (5.46/10). These scores are only for reference and may not reflect the true performance of each checkpoint, which also varies with different text inputs.
98 |
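If you prefer Python over the commandline, the same checkpoints can be used through the package API that `app.py` and `audioldm/__main__.py` build on. Below is a minimal sketch; the prompt and parameter values are only examples.

```python
from audioldm import build_model, text_to_audio, save_wave

# Load a checkpoint by name (downloaded to the local cache on first use)
audioldm = build_model(model_name="audioldm-s-full")

# Generate audio for a text prompt; keyword arguments mirror the CLI flags
waveform = text_to_audio(
    latent_diffusion=audioldm,
    text="A hammer is hitting a wooden surface",
    seed=42,
    duration=5,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
)  # shape: [batch, 1, samples]

save_wave(waveform, "./output", name="hammer")
```
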
99 | :grey_question: For more options on guidance scale, batch size, seed, DDIM steps, etc., please run
100 | ```shell
101 | audioldm -h
102 | ```
103 | ```console
104 | usage: audioldm [-h] [--mode {generation,transfer}] [-t TEXT] [-f FILE_PATH] [--transfer_strength TRANSFER_STRENGTH] [-s SAVE_PATH] [--model_name {audioldm-s-full,audioldm-l-full,audioldm-s-full-v2,audioldm-m-text-ft,audioldm-s-text-ft,audioldm-m-full}] [-ckpt CKPT_PATH]
105 | [-b BATCHSIZE] [--ddim_steps DDIM_STEPS] [-gs GUIDANCE_SCALE] [-dur DURATION] [-n N_CANDIDATE_GEN_PER_TEXT] [--seed SEED]
106 |
107 | optional arguments:
108 | -h, --help show this help message and exit
109 | --mode {generation,transfer}
110 | generation: text-to-audio generation; transfer: style transfer
111 | -t TEXT, --text TEXT Text prompt to the model for audio generation, DEFAULT ""
112 | -f FILE_PATH, --file_path FILE_PATH
113 | (--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating similar audio, DEFAULT None
114 | --transfer_strength TRANSFER_STRENGTH
115 | A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text, DEFAULT 0.5
116 | -s SAVE_PATH, --save_path SAVE_PATH
117 | The path to save model output, DEFAULT "./output"
118 | --model_name {audioldm-s-full,audioldm-l-full,audioldm-s-full-v2,audioldm-m-text-ft,audioldm-s-text-ft,audioldm-m-full}
119 | The checkpoint you are going to use, DEFAULT "audioldm-m-full"
120 | -ckpt CKPT_PATH, --ckpt_path CKPT_PATH
121 | (deprecated) The path to the pretrained .ckpt model, DEFAULT None
122 | -b BATCHSIZE, --batchsize BATCHSIZE
123 | How many samples to generate at the same time, DEFAULT 1
124 | --ddim_steps DDIM_STEPS
125 | The sampling step for DDIM, DEFAULT 200
126 | -gs GUIDANCE_SCALE, --guidance_scale GUIDANCE_SCALE
127 | Guidance scale (Large => better quality and relevance to text; Small => better diversity), DEFAULT 2.5
128 | -dur DURATION, --duration DURATION
129 | The duration of the samples, DEFAULT 10
130 | -n N_CANDIDATE_GEN_PER_TEXT, --n_candidate_gen_per_text N_CANDIDATE_GEN_PER_TEXT
131 | Automatic quality control. This number controls the number of candidates (e.g., generate three audio clips and choose the best to show you). A larger value usually leads to better quality with heavier computation, DEFAULT 3
132 | --seed SEED Changing this value (any integer number) will lead to a different generation result. DEFAULT 42
133 | ```
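As a worked example, several of these options can be combined in a single call; the values below are only illustrative:

```shell
# Generate a 10-second clip with an explicit model, seed, sampling settings and output folder
audioldm -t "A hammer is hitting a wooden surface" \
    --model_name audioldm-m-full --seed 1234 \
    -dur 10 -gs 2.5 --ddim_steps 200 -n 3 -b 1 \
    -s ./output
```
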
134 |
135 | For the evaluation of audio generative models, please refer to [audioldm_eval](https://github.com/haoheliu/audioldm_eval).
136 |
137 | # Hugging Face 🧨 Diffusers
138 |
139 | AudioLDM is available in the Hugging Face [🧨 Diffusers](https://github.com/huggingface/diffusers) library from v0.15.0 onwards. The official checkpoints can be found on the [Hugging Face Hub](https://huggingface.co/cvssp), alongside [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm) and [example scripts](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm).
140 |
141 | To install Diffusers and Transformers, run:
142 | ```bash
143 | pip install --upgrade diffusers transformers
144 | ```
145 |
146 | You can then load pre-trained weights into the [AudioLDM pipeline](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm) and generate text-conditional audio outputs:
147 | ```python
148 | from diffusers import AudioLDMPipeline
149 | import torch
150 |
151 | repo_id = "cvssp/audioldm-s-full-v2"
152 | pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
153 | pipe = pipe.to("cuda")
154 |
155 | prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
156 | audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
157 | ```
158 |
159 | # Web Demo
160 |
161 | Integrated into [Hugging Face Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the [Web Demo](https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation).
162 |
163 | # TuneFlow Demo
164 |
165 | Try out AudioLDM as a [TuneFlow](https://tuneflow.com) plugin: [tuneflow/AudioLDM](https://github.com/tuneflow/AudioLDM). See how it can work in a real DAW (Digital Audio Workstation).
166 |
167 | # TODO
168 |
169 | [Buy me a coffee](https://www.buymeacoffee.com/haoheliuP)
170 |
171 | - [x] Update the checkpoint with more training steps.
172 | - [x] Update the checkpoint with more parameters (audioldm-l).
173 | - [ ] Add AudioCaps finetuned AudioLDM-S model
174 | - [x] Build pip installable package for commandline use
175 | - [x] Build Gradio web application
176 | - [ ] Add super-resolution, inpainting into Gradio web application
177 | - [ ] Add style-transfer into Gradio web application
178 | - [x] Add text-guided style transfer
179 | - [x] Add audio-to-audio generation
180 | - [x] Add audio super-resolution
181 | - [x] Add audio inpainting
182 |
183 | ## Cite this work
184 |
185 | If you find this tool useful, please consider citing:
186 | ```bibtex
187 | @article{liu2023audioldm,
188 | title={{AudioLDM}: Text-to-Audio Generation with Latent Diffusion Models},
189 | author={Liu, Haohe and Chen, Zehua and Yuan, Yi and Mei, Xinhao and Liu, Xubo and Mandic, Danilo and Wang, Wenwu and Plumbley, Mark D},
190 | journal={Proceedings of the International Conference on Machine Learning},
191 | year={2023},
192 | pages={21450-21474}
193 | }
194 | ```
195 |
196 | # Hardware requirements
197 | - GPU with 8 GB of dedicated VRAM
198 | - A system with a 64-bit operating system (Windows 7, 8.1 or 10; Ubuntu 16.04 or later; or macOS 10.13 or later) and 16 GB or more of system RAM
199 |
200 | ## Reference
201 | Part of the code is borrowed from the following repos. We would like to thank the authors of these repos for their contribution.
202 |
203 | > https://github.com/LAION-AI/CLAP
204 |
205 | > https://github.com/CompVis/stable-diffusion
206 |
207 | > https://github.com/v-iashin/SpecVQGAN
208 |
209 | > https://github.com/toshas/torch-fidelity
210 |
211 |
212 | We built the model with data from AudioSet, Freesound, and the BBC Sound Effects library. We share this demo based on the UK copyright exception for data used in academic research.
213 |
214 |
215 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import numpy as np
3 | from audioldm import text_to_audio, build_model
4 |
5 | # from share_btn import community_icon_html, loading_icon_html, share_js
6 |
7 | model_id = "haoheliu/AudioLDM-S-Full"
8 |
9 | audioldm = None
10 | current_model_name = None
11 | # audioldm=None
12 |
13 | # def predict(input, history=[]):
14 | # # tokenize the new input sentence
15 | # new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
16 |
17 | # # append the new user input tokens to the chat history
18 | # bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
19 |
20 | # # generate a response
21 | # history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
22 |
23 | # # convert the tokens to text, and then split the responses into lines
24 | # response = tokenizer.decode(history[0]).split("<|endoftext|>")
25 | # response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list
26 | # return response, history
27 |
28 | def text2audio(text, duration, guidance_scale, random_seed, n_candidates, model_name):
29 | global audioldm, current_model_name
30 |
31 | if audioldm is None or model_name != current_model_name:
32 | audioldm=build_model(model_name=model_name)
33 | current_model_name = model_name
34 |
35 | # print(text, length, guidance_scale)
36 | waveform = text_to_audio(
37 | latent_diffusion=audioldm,
38 | text=text,
39 | seed=random_seed,
40 | duration=duration,
41 | guidance_scale=guidance_scale,
42 | n_candidate_gen_per_text=int(n_candidates),
43 | ) # [bs, 1, samples]
44 | waveform = [
45 | gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform
46 | ]
47 | # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))]
48 | if len(waveform) == 1:
49 | waveform = waveform[0]
50 | return waveform
51 |
52 |
53 | # iface = gr.Interface(fn=text2audio, inputs=[
54 | # gr.Textbox(value="A man is speaking in a huge room", max_lines=1),
55 | # gr.Slider(2.5, 10, value=5, step=2.5),
56 | # gr.Slider(0, 5, value=2.5, step=0.5),
57 | # gr.Number(value=42)
58 | # ], outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")],
59 | # allow_flagging="never"
60 | # )
61 | # iface.launch(share=True)
62 |
63 |
64 | css = """
65 | a {
66 | color: inherit;
67 | text-decoration: underline;
68 | }
69 | .gradio-container {
70 | font-family: 'IBM Plex Sans', sans-serif;
71 | }
72 | .gr-button {
73 | color: white;
74 | border-color: #000000;
75 | background: #000000;
76 | }
77 | input[type='range'] {
78 | accent-color: #000000;
79 | }
80 | .dark input[type='range'] {
81 | accent-color: #dfdfdf;
82 | }
83 | .container {
84 | max-width: 730px;
85 | margin: auto;
86 | padding-top: 1.5rem;
87 | }
88 | #gallery {
89 | min-height: 22rem;
90 | margin-bottom: 15px;
91 | margin-left: auto;
92 | margin-right: auto;
93 | border-bottom-right-radius: .5rem !important;
94 | border-bottom-left-radius: .5rem !important;
95 | }
96 | #gallery>div>.h-full {
97 | min-height: 20rem;
98 | }
99 | .details:hover {
100 | text-decoration: underline;
101 | }
102 | .gr-button {
103 | white-space: nowrap;
104 | }
105 | .gr-button:focus {
106 | border-color: rgb(147 197 253 / var(--tw-border-opacity));
107 | outline: none;
108 | box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
109 | --tw-border-opacity: 1;
110 | --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
111 | --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
112 | --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
113 | --tw-ring-opacity: .5;
114 | }
115 | #advanced-btn {
116 | font-size: .7rem !important;
117 | line-height: 19px;
118 | margin-top: 12px;
119 | margin-bottom: 12px;
120 | padding: 2px 8px;
121 | border-radius: 14px !important;
122 | }
123 | #advanced-options {
124 | margin-bottom: 20px;
125 | }
126 | .footer {
127 | margin-bottom: 45px;
128 | margin-top: 35px;
129 | text-align: center;
130 | border-bottom: 1px solid #e5e5e5;
131 | }
132 | .footer>p {
133 | font-size: .8rem;
134 | display: inline-block;
135 | padding: 0 10px;
136 | transform: translateY(10px);
137 | background: white;
138 | }
139 | .dark .footer {
140 | border-color: #303030;
141 | }
142 | .dark .footer>p {
143 | background: #0b0f19;
144 | }
145 | .acknowledgments h4{
146 | margin: 1.25em 0 .25em 0;
147 | font-weight: bold;
148 | font-size: 115%;
149 | }
150 | #container-advanced-btns{
151 | display: flex;
152 | flex-wrap: wrap;
153 | justify-content: space-between;
154 | align-items: center;
155 | }
156 | .animate-spin {
157 | animation: spin 1s linear infinite;
158 | }
159 | @keyframes spin {
160 | from {
161 | transform: rotate(0deg);
162 | }
163 | to {
164 | transform: rotate(360deg);
165 | }
166 | }
167 | #share-btn-container {
168 | display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
169 | margin-top: 10px;
170 | margin-left: auto;
171 | }
172 | #share-btn {
173 | all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
174 | }
175 | #share-btn * {
176 | all: unset;
177 | }
178 | #share-btn-container div:nth-child(-n+2){
179 | width: auto !important;
180 | min-height: 0px !important;
181 | }
182 | #share-btn-container .wrap {
183 | display: none !important;
184 | }
185 | .gr-form{
186 | flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
187 | }
188 | #prompt-container{
189 | gap: 0;
190 | }
191 | #generated_id{
192 | min-height: 700px
193 | }
194 | #setting_id{
195 | margin-bottom: 12px;
196 | text-align: center;
197 | font-weight: 900;
198 | }
199 | """
200 | iface = gr.Blocks(css=css)
201 |
202 | with iface:
203 | gr.HTML(
204 | """
205 |
206 |
214 |
215 | AudioLDM: Text-to-Audio Generation with Latent Diffusion Models
216 |
217 |
218 |
219 | [Paper] [Project page]
220 |
221 |
222 | """
223 | )
224 | with gr.Group():
225 | with gr.Box():
226 | ############# Input
227 | textbox = gr.Textbox(
228 | value="A hammer is hitting a wooden surface",
229 | max_lines=1,
230 | label="Input your text here. Please ensure it is descriptive and of moderate length.",
231 | elem_id="prompt-in",
232 | )
233 |
234 | with gr.Accordion("Click to modify detailed configurations", open=False):
235 | seed = gr.Number(
236 | value=42,
237 | label="Change this value (any integer number) will lead to a different generation result.",
238 | )
239 | duration = gr.Slider(
240 | 2.5, 10, value=5, step=2.5, label="Duration (seconds)"
241 | )
242 | guidance_scale = gr.Slider(
243 | 0,
244 | 5,
245 | value=2.5,
246 | step=0.5,
247 | label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
248 | )
249 | n_candidates = gr.Slider(
250 | 1,
251 | 5,
252 | value=3,
253 | step=1,
254 | label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
255 | )
256 | model_name = gr.Dropdown(
257 | ["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2","audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large",
258 | )
259 | ############# Output
260 | # outputs=gr.Audio(label="Output", type="numpy")
261 | outputs = gr.Video(label="Output", elem_id="output-video")
262 |
263 | # with gr.Group(elem_id="container-advanced-btns"):
264 | # # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
265 | # with gr.Group(elem_id="share-btn-container"):
266 | # community_icon = gr.HTML(community_icon_html, visible=False)
267 | # loading_icon = gr.HTML(loading_icon_html, visible=False)
268 | # share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
269 | # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")]
270 | btn = gr.Button("Submit").style(full_width=True)
271 |
272 | # with gr.Group(elem_id="share-btn-container", visible=False):
273 | # community_icon = gr.HTML(community_icon_html)
274 | # loading_icon = gr.HTML(loading_icon_html)
275 | # share_button = gr.Button("Share to community", elem_id="share-btn")
276 |
277 | btn.click(
278 | text2audio,
279 | inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
280 | outputs=[outputs],
281 | )
282 |
283 | # share_button.click(None, [], [], _js=share_js)
284 | gr.HTML(
285 | """
286 |
293 | """
294 | )
295 | # gr.Examples(
296 | # [
297 | # ["A hammer is hitting a wooden surface", 5, 2.5, 45, 3, "audioldm-s-full"],
298 | # [
299 | # "Peaceful and calming ambient music with singing bowl and other instruments.",
300 | # 5,
301 | # 2.5,
302 | # 45,
303 | # 3,
304 | # "audioldm-s-full"
305 | # ],
306 | # ["A man is speaking in a small room.", 5, 2.5, 45, 3, "audioldm-s-full"],
307 | # ["A female is speaking followed by footstep sound", 5, 2.5, 45, 3, "audioldm-s-full"],
308 | # [
309 | # "Wooden table tapping sound followed by water pouring sound.",
310 | # 5,
311 | # 2.5,
312 | # 45,
313 | # 3,
314 | # "audioldm-s-full"
315 | # ],
316 | # ],
317 | # fn=text2audio,
318 | # inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
319 | # outputs=[outputs],
320 | # cache_examples=True,
321 | # )
322 | with gr.Accordion("Additional information", open=False):
323 | gr.HTML(
324 | """
325 |
328 | """
329 | )
330 | # This demo is strictly for research demo purpose only. For commercial use please contact us.
331 |
332 | iface.queue(concurrency_count=3)
333 | # iface.launch(debug=True)
334 | iface.launch(debug=True, share=False)
335 |
--------------------------------------------------------------------------------
/audioldm/__init__.py:
--------------------------------------------------------------------------------
1 | from .ldm import LatentDiffusion
2 | from .utils import seed_everything, save_wave, get_time, get_duration
3 | from .pipeline import *
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/audioldm/__main__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | import os
3 | from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
4 | import argparse
5 |
6 | CACHE_DIR = os.getenv(
7 | "AUDIOLDM_CACHE_DIR",
8 | os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
9 |
10 | parser = argparse.ArgumentParser()
11 |
12 | parser.add_argument(
13 | "--mode",
14 | type=str,
15 | required=False,
16 | default="generation",
17 | help="generation: text-to-audio generation; transfer: style transfer",
18 | choices=["generation", "transfer"]
19 | )
20 |
21 | parser.add_argument(
22 | "-t",
23 | "--text",
24 | type=str,
25 | required=False,
26 | default="",
27 | help="Text prompt to the model for audio generation",
28 | )
29 |
30 | parser.add_argument(
31 | "-f",
32 | "--file_path",
33 | type=str,
34 | required=False,
35 | default=None,
36 | help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio",
37 | )
38 |
39 | parser.add_argument(
40 | "--transfer_strength",
41 | type=float,
42 | required=False,
43 | default=0.5,
44 | help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
45 | )
46 |
47 | parser.add_argument(
48 | "-s",
49 | "--save_path",
50 | type=str,
51 | required=False,
52 | help="The path to save model output",
53 | default="./output",
54 | )
55 |
56 | parser.add_argument(
57 | "--model_name",
58 | type=str,
59 | required=False,
60 | help="The checkpoint you gonna use",
61 | default="audioldm-m-full",
62 | choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2","audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full"]
63 | )
64 |
65 | parser.add_argument(
66 | "-ckpt",
67 | "--ckpt_path",
68 | type=str,
69 | required=False,
70 | help="The path to the pretrained .ckpt model",
71 | default=None,
72 | )
73 |
74 | parser.add_argument(
75 | "-b",
76 | "--batchsize",
77 | type=int,
78 | required=False,
79 | default=1,
80 | help="Generate how many samples at the same time",
81 | )
82 |
83 | parser.add_argument(
84 | "--ddim_steps",
85 | type=int,
86 | required=False,
87 | default=200,
88 | help="The sampling step for DDIM",
89 | )
90 |
91 | parser.add_argument(
92 | "-gs",
93 | "--guidance_scale",
94 | type=float,
95 | required=False,
96 | default=2.5,
97 | help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
98 | )
99 |
100 | parser.add_argument(
101 | "-dur",
102 | "--duration",
103 | type=float,
104 | required=False,
105 | default=10.0,
106 | help="The duration of the samples",
107 | )
108 |
109 | parser.add_argument(
110 | "-n",
111 | "--n_candidate_gen_per_text",
112 | type=int,
113 | required=False,
114 | default=3,
115 | help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
116 | )
117 |
118 | parser.add_argument(
119 | "--seed",
120 | type=int,
121 | required=False,
122 | default=42,
123 | help="Change this value (any integer number) will lead to a different generation result.",
124 | )
125 |
126 | args = parser.parse_args()
127 |
128 | if(args.ckpt_path is not None):
129 | print("Warning: ckpt_path has no effect after version 0.0.20.")
130 |
131 | assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
132 |
133 | mode = args.mode
134 | if(mode == "generation" and args.file_path is not None):
135 | mode = "generation_audio_to_audio"
136 | if(len(args.text) > 0):
137 | print("Warning: You have specified the --file_path. --text will be ignored")
138 | args.text = ""
139 |
140 | save_path = os.path.join(args.save_path, mode)
141 |
142 | if(args.file_path is not None):
143 | save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0]))
144 |
145 | text = args.text
146 | random_seed = args.seed
147 | duration = args.duration
148 | guidance_scale = args.guidance_scale
149 | n_candidate_gen_per_text = args.n_candidate_gen_per_text
150 |
151 | os.makedirs(save_path, exist_ok=True)
152 | audioldm = build_model(model_name=args.model_name)
153 |
154 | if(args.mode == "generation"):
155 | waveform = text_to_audio(
156 | audioldm,
157 | text,
158 | args.file_path,
159 | random_seed,
160 | duration=duration,
161 | guidance_scale=guidance_scale,
162 | ddim_steps=args.ddim_steps,
163 | n_candidate_gen_per_text=n_candidate_gen_per_text,
164 | batchsize=args.batchsize,
165 | )
166 |
167 | elif(args.mode == "transfer"):
168 | assert args.file_path is not None
169 | assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path
170 | waveform = style_transfer(
171 | audioldm,
172 | text,
173 | args.file_path,
174 | args.transfer_strength,
175 | random_seed,
176 | duration=duration,
177 | guidance_scale=guidance_scale,
178 | ddim_steps=args.ddim_steps,
179 | batchsize=args.batchsize,
180 | )
181 | waveform = waveform[:,None,:]
182 |
183 | save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))
184 |
--------------------------------------------------------------------------------
/audioldm/audio/__init__.py:
--------------------------------------------------------------------------------
1 | from .tools import wav_to_fbank, read_wav_file
2 | from .stft import TacotronSTFT
3 |
--------------------------------------------------------------------------------
/audioldm/audio/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import librosa.util as librosa_util
4 | from scipy.signal import get_window
5 |
6 |
7 | def window_sumsquare(
8 | window,
9 | n_frames,
10 | hop_length,
11 | win_length,
12 | n_fft,
13 | dtype=np.float32,
14 | norm=None,
15 | ):
16 | """
17 | # from librosa 0.6
18 | Compute the sum-square envelope of a window function at a given hop length.
19 |
20 | This is used to estimate modulation effects induced by windowing
21 | observations in short-time fourier transforms.
22 |
23 | Parameters
24 | ----------
25 | window : string, tuple, number, callable, or list-like
26 | Window specification, as in `get_window`
27 |
28 | n_frames : int > 0
29 | The number of analysis frames
30 |
31 | hop_length : int > 0
32 | The number of samples to advance between frames
33 |
34 | win_length : [optional]
35 | The length of the window function. By default, this matches `n_fft`.
36 |
37 | n_fft : int > 0
38 | The length of each analysis frame.
39 |
40 | dtype : np.dtype
41 | The data type of the output
42 |
43 | Returns
44 | -------
45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46 | The sum-squared envelope of the window function
47 | """
48 | if win_length is None:
49 | win_length = n_fft
50 |
51 | n = n_fft + hop_length * (n_frames - 1)
52 | x = np.zeros(n, dtype=dtype)
53 |
54 | # Compute the squared window at the desired length
55 | win_sq = get_window(window, win_length, fftbins=True)
56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57 | win_sq = librosa_util.pad_center(win_sq, n_fft)
58 |
59 | # Fill the envelope
60 | for i in range(n_frames):
61 | sample = i * hop_length
62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63 | return x
64 |
65 |
66 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
67 | """
68 | PARAMS
69 | ------
70 | magnitudes: spectrogram magnitudes
71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72 | """
73 |
74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75 | angles = angles.astype(np.float32)
76 | angles = torch.autograd.Variable(torch.from_numpy(angles))
77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78 |
79 | for i in range(n_iters):
80 | _, angles = stft_fn.transform(signal)
81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82 | return signal
83 |
84 |
85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86 | """
87 | PARAMS
88 | ------
89 | C: compression factor
90 | """
91 | return normalize_fun(torch.clamp(x, min=clip_val) * C)
92 |
93 |
94 | def dynamic_range_decompression(x, C=1):
95 | """
96 | PARAMS
97 | ------
98 | C: compression factor used to compress
99 | """
100 | return torch.exp(x) / C
101 |
--------------------------------------------------------------------------------
/audioldm/audio/stft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy.signal import get_window
5 | from librosa.util import pad_center, tiny
6 | from librosa.filters import mel as librosa_mel_fn
7 |
8 | from audioldm.audio.audio_processing import (
9 | dynamic_range_compression,
10 | dynamic_range_decompression,
11 | window_sumsquare,
12 | )
13 |
14 |
15 | class STFT(torch.nn.Module):
16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17 |
18 | def __init__(self, filter_length, hop_length, win_length, window="hann"):
19 | super(STFT, self).__init__()
20 | self.filter_length = filter_length
21 | self.hop_length = hop_length
22 | self.win_length = win_length
23 | self.window = window
24 | self.forward_transform = None
25 | scale = self.filter_length / self.hop_length
26 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
27 |
28 | cutoff = int((self.filter_length / 2 + 1))
29 | fourier_basis = np.vstack(
30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31 | )
32 |
33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34 | inverse_basis = torch.FloatTensor(
35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36 | )
37 |
38 | if window is not None:
39 | assert filter_length >= win_length
40 | # get window and zero center pad it to filter_length
41 | fft_window = get_window(window, win_length, fftbins=True)
42 | fft_window = pad_center(fft_window, filter_length)
43 | fft_window = torch.from_numpy(fft_window).float()
44 |
45 | # window the bases
46 | forward_basis *= fft_window
47 | inverse_basis *= fft_window
48 |
49 | self.register_buffer("forward_basis", forward_basis.float())
50 | self.register_buffer("inverse_basis", inverse_basis.float())
51 |
52 | def transform(self, input_data):
53 | num_batches = input_data.size(0)
54 | num_samples = input_data.size(1)
55 |
56 | self.num_samples = num_samples
57 |
58 | # similar to librosa, reflect-pad the input
59 | input_data = input_data.view(num_batches, 1, num_samples)
60 | input_data = F.pad(
61 | input_data.unsqueeze(1),
62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63 | mode="reflect",
64 | )
65 | input_data = input_data.squeeze(1)
66 |
67 | forward_transform = F.conv1d(
68 | input_data,
69 | torch.autograd.Variable(self.forward_basis, requires_grad=False),
70 | stride=self.hop_length,
71 | padding=0,
72 | ).cpu()
73 |
74 | cutoff = int((self.filter_length / 2) + 1)
75 | real_part = forward_transform[:, :cutoff, :]
76 | imag_part = forward_transform[:, cutoff:, :]
77 |
78 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80 |
81 | return magnitude, phase
82 |
83 | def inverse(self, magnitude, phase):
84 | recombine_magnitude_phase = torch.cat(
85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86 | )
87 |
88 | inverse_transform = F.conv_transpose1d(
89 | recombine_magnitude_phase,
90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91 | stride=self.hop_length,
92 | padding=0,
93 | )
94 |
95 | if self.window is not None:
96 | window_sum = window_sumsquare(
97 | self.window,
98 | magnitude.size(-1),
99 | hop_length=self.hop_length,
100 | win_length=self.win_length,
101 | n_fft=self.filter_length,
102 | dtype=np.float32,
103 | )
104 | # remove modulation effects
105 | approx_nonzero_indices = torch.from_numpy(
106 | np.where(window_sum > tiny(window_sum))[0]
107 | )
108 | window_sum = torch.autograd.Variable(
109 | torch.from_numpy(window_sum), requires_grad=False
110 | )
111 | window_sum = window_sum
112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113 | approx_nonzero_indices
114 | ]
115 |
116 | # scale by hop ratio
117 | inverse_transform *= float(self.filter_length) / self.hop_length
118 |
119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121 |
122 | return inverse_transform
123 |
124 | def forward(self, input_data):
125 | self.magnitude, self.phase = self.transform(input_data)
126 | reconstruction = self.inverse(self.magnitude, self.phase)
127 | return reconstruction
128 |
129 |
130 | class TacotronSTFT(torch.nn.Module):
131 | def __init__(
132 | self,
133 | filter_length,
134 | hop_length,
135 | win_length,
136 | n_mel_channels,
137 | sampling_rate,
138 | mel_fmin,
139 | mel_fmax,
140 | ):
141 | super(TacotronSTFT, self).__init__()
142 | self.n_mel_channels = n_mel_channels
143 | self.sampling_rate = sampling_rate
144 | self.stft_fn = STFT(filter_length, hop_length, win_length)
145 | mel_basis = librosa_mel_fn(
146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
147 | )
148 | mel_basis = torch.from_numpy(mel_basis).float()
149 | self.register_buffer("mel_basis", mel_basis)
150 |
151 | def spectral_normalize(self, magnitudes, normalize_fun):
152 | output = dynamic_range_compression(magnitudes, normalize_fun)
153 | return output
154 |
155 | def spectral_de_normalize(self, magnitudes):
156 | output = dynamic_range_decompression(magnitudes)
157 | return output
158 |
159 | def mel_spectrogram(self, y, normalize_fun=torch.log):
160 | """Computes mel-spectrograms from a batch of waves
161 | PARAMS
162 | ------
163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164 |
165 | RETURNS
166 | -------
167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168 | """
169 | assert torch.min(y.data) >= -1, torch.min(y.data)
170 | assert torch.max(y.data) <= 1, torch.max(y.data)
171 |
172 | magnitudes, phases = self.stft_fn.transform(y)
173 | magnitudes = magnitudes.data
174 | mel_output = torch.matmul(self.mel_basis, magnitudes)
175 | mel_output = self.spectral_normalize(mel_output, normalize_fun)
176 | energy = torch.norm(magnitudes, dim=1)
177 |
178 | log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
179 |
180 | return mel_output, log_magnitudes, energy
181 |
--------------------------------------------------------------------------------
/audioldm/audio/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchaudio
4 |
5 |
6 | def get_mel_from_wav(audio, _stft):
7 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
8 | audio = torch.autograd.Variable(audio, requires_grad=False)
9 | melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
10 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
11 | log_magnitudes_stft = (
12 | torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
13 | )
14 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
15 | return melspec, log_magnitudes_stft, energy
16 |
17 |
18 | def _pad_spec(fbank, target_length=1024):
19 | n_frames = fbank.shape[0]
20 | p = target_length - n_frames
21 | # cut and pad
22 | if p > 0:
23 | m = torch.nn.ZeroPad2d((0, 0, 0, p))
24 | fbank = m(fbank)
25 | elif p < 0:
26 | fbank = fbank[0:target_length, :]
27 |
28 | if fbank.size(-1) % 2 != 0:
29 | fbank = fbank[..., :-1]
30 |
31 | return fbank
32 |
33 |
34 | def pad_wav(waveform, segment_length):
35 | waveform_length = waveform.shape[-1]
36 | assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
37 | if segment_length is None or waveform_length == segment_length:
38 | return waveform
39 | elif waveform_length > segment_length:
40 | return waveform[:segment_length]
41 | elif waveform_length < segment_length:
42 | temp_wav = np.zeros((1, segment_length))
43 | temp_wav[:, :waveform_length] = waveform
44 | return temp_wav
45 |
46 | def normalize_wav(waveform):
47 | waveform = waveform - np.mean(waveform)
48 | waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
49 | return waveform * 0.5
50 |
51 |
52 | def read_wav_file(filename, segment_length):
53 | # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
54 | waveform, sr = torchaudio.load(filename) # Faster!!!
55 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
56 | waveform = waveform.numpy()[0, ...]
57 | waveform = normalize_wav(waveform)
58 | waveform = waveform[None, ...]
59 | waveform = pad_wav(waveform, segment_length)
60 |
61 | waveform = waveform / np.max(np.abs(waveform))
62 | waveform = 0.5 * waveform
63 |
64 | return waveform
65 |
66 |
67 | def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
68 | assert fn_STFT is not None
69 |
70 | # mixup
71 | waveform = read_wav_file(filename, target_length * 160) # hop size is 160
72 |
73 | waveform = waveform[0, ...]
74 | waveform = torch.FloatTensor(waveform)
75 |
76 | fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
77 |
78 | fbank = torch.FloatTensor(fbank.T)
79 | log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
80 |
81 | fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
82 | log_magnitudes_stft, target_length
83 | )
84 |
85 | return fbank, log_magnitudes_stft, waveform
86 |
--------------------------------------------------------------------------------
/audioldm/clap/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/clap/__init__.py
--------------------------------------------------------------------------------
/audioldm/clap/encoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from audioldm.clap.open_clip import create_model
4 | from audioldm.clap.training.data import get_audio_features
5 | import torchaudio
6 | from transformers import RobertaTokenizer
7 | import torch.nn.functional as F
8 |
9 |
10 | class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
11 | def __init__(
12 | self,
13 | pretrained_path="",
14 | key="class",
15 | sampling_rate=16000,
16 | embed_mode="audio",
17 | amodel = "HTSAT-tiny",
18 | unconditional_prob=0.1,
19 | random_mute=False,
20 | max_random_mute_portion=0.5,
21 | training_mode=True,
22 | ):
23 | super().__init__()
24 |
25 | self.key = key
26 | self.device = "cpu"
27 | self.precision = "fp32"
28 | self.amodel = amodel # or 'PANN-14'
29 | self.tmodel = "roberta" # the best text encoder in our training
30 | self.enable_fusion = False # False if you do not want to use the fusion model
31 | self.fusion_type = "aff_2d"
32 | self.pretrained = pretrained_path
33 | self.embed_mode = embed_mode
34 | self.embed_mode_orig = embed_mode
35 | self.sampling_rate = sampling_rate
36 | self.unconditional_prob = unconditional_prob
37 | self.random_mute = random_mute
38 | self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
39 | self.max_random_mute_portion = max_random_mute_portion
40 | self.training_mode = training_mode
41 | self.model, self.model_cfg = create_model(
42 | self.amodel,
43 | self.tmodel,
44 | self.pretrained,
45 | precision=self.precision,
46 | device=self.device,
47 | enable_fusion=self.enable_fusion,
48 | fusion_type=self.fusion_type,
49 | )
50 | for p in self.model.parameters():
51 | p.requires_grad = False
52 |
53 | self.model.eval()
54 |
55 | def get_unconditional_condition(self, batchsize):
56 | self.unconditional_token = self.model.get_text_embedding(
57 | self.tokenizer(["", ""])
58 | )[0:1]
59 | return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
60 |
61 | def batch_to_list(self, batch):
62 | ret = []
63 | for i in range(batch.size(0)):
64 | ret.append(batch[i])
65 | return ret
66 |
67 | def make_decision(self, probability):
68 | if float(torch.rand(1)) < probability:
69 | return True
70 | else:
71 | return False
72 |
73 | def random_uniform(self, start, end):
74 | val = torch.rand(1).item()
75 | return start + (end - start) * val
76 |
77 | def _random_mute(self, waveform):
78 | # waveform: [bs, t-steps]
79 | t_steps = waveform.size(-1)
80 | for i in range(waveform.size(0)):
81 | mute_size = int(
82 | self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
83 | )
84 | mute_start = int(self.random_uniform(0, t_steps - mute_size))
85 | waveform[i, mute_start : mute_start + mute_size] = 0
86 | return waveform
87 |
88 | def cos_similarity(self, waveform, text):
89 | # waveform: [bs, t_steps]
90 | with torch.no_grad():
91 | self.embed_mode = "audio"
92 | audio_emb = self(waveform.cuda())
93 | self.embed_mode = "text"
94 | text_emb = self(text)
95 | similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
96 | return similarity.squeeze()
97 |
98 | def forward(self, batch, key=None):
99 | # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
100 | # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
101 | if self.model.training == True and not self.training_mode:
102 | print(
103 | "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
104 | )
105 | self.model, self.model_cfg = create_model(
106 | self.amodel,
107 | self.tmodel,
108 | self.pretrained,
109 | precision=self.precision,
110 | device="cuda",
111 | enable_fusion=self.enable_fusion,
112 | fusion_type=self.fusion_type,
113 | )
114 | for p in self.model.parameters():
115 | p.requires_grad = False
116 | self.model.eval()
117 |
118 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
119 | if self.embed_mode == "audio":
120 | with torch.no_grad():
121 | audio_dict_list = []
122 | assert (
123 | self.sampling_rate == 16000
124 | ), "We only support 16000 sampling rate"
125 | if self.random_mute:
126 | batch = self._random_mute(batch)
127 | # batch: [bs, 1, t-samples]
128 | batch = torchaudio.functional.resample(
129 | batch, orig_freq=self.sampling_rate, new_freq=48000
130 | )
131 | for waveform in self.batch_to_list(batch):
132 | audio_dict = {}
133 | audio_dict = get_audio_features(
134 | audio_dict,
135 | waveform,
136 | 480000,
137 | data_truncating="fusion",
138 | data_filling="repeatpad",
139 | audio_cfg=self.model_cfg["audio_cfg"],
140 | )
141 | audio_dict_list.append(audio_dict)
142 | # [bs, 512]
143 | embed = self.model.get_audio_embedding(audio_dict_list)
144 | elif self.embed_mode == "text":
145 | with torch.no_grad():
146 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
147 | text_data = self.tokenizer(batch)
148 | embed = self.model.get_text_embedding(text_data)
149 |
150 | embed = embed.unsqueeze(1)
151 | self.unconditional_token = self.model.get_text_embedding(
152 | self.tokenizer(["", ""])
153 | )[0:1]
154 |
155 | for i in range(embed.size(0)):
156 | if self.make_decision(self.unconditional_prob):
157 | embed[i] = self.unconditional_token
158 |
159 | # [bs, 1, 512]
160 | return embed.detach()
161 |
162 | def tokenizer(self, text):
163 | result = self.tokenize(
164 | text,
165 | padding="max_length",
166 | truncation=True,
167 | max_length=512,
168 | return_tensors="pt",
169 | )
170 | return {k: v.squeeze(0) for k, v in result.items()}
171 |
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .factory import (
2 | list_models,
3 | create_model,
4 | create_model_and_transforms,
5 | add_model_config,
6 | )
7 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8 | from .model import (
9 | CLAP,
10 | CLAPTextCfg,
11 | CLAPVisionCfg,
12 | CLAPAudioCfp,
13 | convert_weights_to_fp16,
14 | trace_model,
15 | )
16 | from .openai import load_openai_model, list_openai_models
17 | from .pretrained import (
18 | list_pretrained,
19 | list_pretrained_tag_models,
20 | list_pretrained_model_tags,
21 | get_pretrained_url,
22 | download_pretrained,
23 | )
24 | from .tokenizer import SimpleTokenizer, tokenize
25 | from .transform import image_transform
26 |
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/bert.py:
--------------------------------------------------------------------------------
1 | from transformers import BertTokenizer, BertModel
2 |
3 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
4 | model = BertModel.from_pretrained("bert-base-uncased")
5 | text = "Replace me by any text you'd like."
6 |
7 |
8 | def bert_embeddings(text):
9 | # text = "Replace me by any text you'd like."
10 | encoded_input = tokenizer(text, return_tensors="pt")
11 | output = model(**encoded_input)
12 | return output
13 |
14 |
15 | from transformers import RobertaTokenizer, RobertaModel
16 |
17 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
18 | model = RobertaModel.from_pretrained("roberta-base")
19 | text = "Replace me by any text you'd like."
20 |
21 |
22 | def Roberta_embeddings(text):
23 | # text = "Replace me by any text you'd like."
24 | encoded_input = tokenizer(text, return_tensors="pt")
25 | output = model(**encoded_input)
26 | return output
27 |
28 |
29 | from transformers import BartTokenizer, BartModel
30 |
31 | tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
32 | model = BartModel.from_pretrained("facebook/bart-base")
33 | text = "Replace me by any text you'd like."
34 |
35 |
36 | def bart_embeddings(text):
37 | # text = "Replace me by any text you'd like."
38 | encoded_input = tokenizer(text, return_tensors="pt")
39 | output = model(**encoded_input)
40 | return output
41 |
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/factory.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import pathlib
5 | import re
6 | from copy import deepcopy
7 | from pathlib import Path
8 |
9 | import torch
10 |
11 | from .model import CLAP, convert_weights_to_fp16
12 | from .openai import load_openai_model
13 | from .pretrained import get_pretrained_url, download_pretrained
14 | from .transform import image_transform
15 |
16 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17 | _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
18 | CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache/audioldm")
19 |
20 |
21 |
22 | def _natural_key(string_):
23 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
24 |
25 |
26 | def _rescan_model_configs():
27 | global _MODEL_CONFIGS
28 |
29 | config_ext = (".json",)
30 | config_files = []
31 | for config_path in _MODEL_CONFIG_PATHS:
32 | if config_path.is_file() and config_path.suffix in config_ext:
33 | config_files.append(config_path)
34 | elif config_path.is_dir():
35 | for ext in config_ext:
36 | config_files.extend(config_path.glob(f"*{ext}"))
37 |
38 | for cf in config_files:
39 | if os.path.basename(cf)[0] == ".":
40 | continue # Ignore hidden files
41 |
42 | with open(cf, "r") as f:
43 | model_cfg = json.load(f)
44 | if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
45 | _MODEL_CONFIGS[cf.stem] = model_cfg
46 |
47 | _MODEL_CONFIGS = {
48 | k: v
49 | for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
50 | }
51 |
52 |
53 | _rescan_model_configs() # initial populate of model config registry
54 |
55 |
56 | def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
57 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
58 | if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
59 | state_dict = checkpoint["state_dict"]
60 | else:
61 | state_dict = checkpoint
62 | if skip_params:
63 | if next(iter(state_dict.items()))[0].startswith("module"):
64 | state_dict = {k[7:]: v for k, v in state_dict.items()}
65 | # for k in state_dict:
66 | # if k.startswith('transformer'):
67 | # v = state_dict.pop(k)
68 | # state_dict['text_branch.' + k[12:]] = v
69 | return state_dict
70 |
71 |
72 | def create_model(
73 | amodel_name: str,
74 | tmodel_name: str,
75 | pretrained: str = "",
76 | precision: str = "fp32",
77 | device: torch.device = torch.device("cpu"),
78 | jit: bool = False,
79 | force_quick_gelu: bool = False,
80 | openai_model_cache_dir: str = os.path.expanduser(f"{CACHE_DIR}/clip"),
81 | skip_params=True,
82 | pretrained_audio: str = "",
83 | pretrained_text: str = "",
84 | enable_fusion: bool = False,
85 | fusion_type: str = "None"
86 | # pretrained_image: bool = False,
87 | ):
88 | amodel_name = amodel_name.replace(
89 | "/", "-"
90 | ) # for callers using old naming with / in ViT names
91 | pretrained_orig = pretrained
92 | pretrained = pretrained.lower()
93 | if pretrained == "openai":
94 | if amodel_name in _MODEL_CONFIGS:
95 | logging.info(f"Loading {amodel_name} model config.")
96 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
97 | else:
98 | logging.error(
99 | f"Model config for {amodel_name} not found; available models {list_models()}."
100 | )
101 | raise RuntimeError(f"Model config for {amodel_name} not found.")
102 |
103 | logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
104 | # Hard Code in model name
105 | model_cfg["text_cfg"]["model_type"] = tmodel_name
106 | model = load_openai_model(
107 | "ViT-B-16",
108 | model_cfg,
109 | device=device,
110 | jit=jit,
111 | cache_dir=openai_model_cache_dir,
112 | enable_fusion=enable_fusion,
113 | fusion_type=fusion_type,
114 | )
115 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
116 | if precision == "amp" or precision == "fp32":
117 | model = model.float()
118 | else:
119 | if amodel_name in _MODEL_CONFIGS:
120 | logging.info(f"Loading {amodel_name} model config.")
121 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
122 | else:
123 | logging.error(
124 | f"Model config for {amodel_name} not found; available models {list_models()}."
125 | )
126 | raise RuntimeError(f"Model config for {amodel_name} not found.")
127 |
128 | if force_quick_gelu:
129 | # override for use of QuickGELU on non-OpenAI transformer models
130 | model_cfg["quick_gelu"] = True
131 |
132 | # if pretrained_image:
133 | # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
134 | # # pretrained weight loading for timm models set via vision_cfg
135 | # model_cfg['vision_cfg']['timm_model_pretrained'] = True
136 | # else:
137 | # assert False, 'pretrained image towers currently only supported for timm models'
138 | model_cfg["text_cfg"]["model_type"] = tmodel_name
139 | model_cfg["enable_fusion"] = enable_fusion
140 | model_cfg["fusion_type"] = fusion_type
141 | model = CLAP(**model_cfg)
142 |
143 | if pretrained:
144 | checkpoint_path = ""
145 | url = get_pretrained_url(amodel_name, pretrained)
146 | if url:
147 | checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
148 | elif os.path.exists(pretrained_orig):
149 | checkpoint_path = pretrained_orig
150 | if checkpoint_path:
151 | logging.info(
152 | f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
153 | )
154 | ckpt = load_state_dict(checkpoint_path, skip_params=True)
155 | model.load_state_dict(ckpt)
156 | param_names = [n for n, p in model.named_parameters()]
157 | # for n in param_names:
158 | # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
159 | else:
160 | logging.warning(
161 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
162 | )
163 | raise RuntimeError(
164 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
165 | )
166 |
167 | if pretrained_audio:
168 | if amodel_name.startswith("PANN"):
169 | if "Cnn14_mAP" in pretrained_audio: # official checkpoint
170 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
171 | audio_ckpt = audio_ckpt["model"]
172 | keys = list(audio_ckpt.keys())
173 | for key in keys:
174 | if (
175 | "spectrogram_extractor" not in key
176 | and "logmel_extractor" not in key
177 | ):
178 | v = audio_ckpt.pop(key)
179 | audio_ckpt["audio_branch." + key] = v
180 | elif os.path.basename(pretrained_audio).startswith(
181 | "PANN"
182 | ): # checkpoint trained via HTSAT codebase
183 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
184 | audio_ckpt = audio_ckpt["state_dict"]
185 | keys = list(audio_ckpt.keys())
186 | for key in keys:
187 | if key.startswith("sed_model"):
188 | v = audio_ckpt.pop(key)
189 | audio_ckpt["audio_branch." + key[10:]] = v
190 | elif os.path.basename(pretrained_audio).startswith(
191 | "finetuned"
192 | ): # checkpoint trained via linear probe codebase
193 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
194 | else:
195 | raise ValueError("Unknown audio checkpoint")
196 | elif amodel_name.startswith("HTSAT"):
197 | if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
198 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
199 | audio_ckpt = audio_ckpt["state_dict"]
200 | keys = list(audio_ckpt.keys())
201 | for key in keys:
202 | if key.startswith("sed_model") and (
203 | "spectrogram_extractor" not in key
204 | and "logmel_extractor" not in key
205 | ):
206 | v = audio_ckpt.pop(key)
207 | audio_ckpt["audio_branch." + key[10:]] = v
208 | elif os.path.basename(pretrained_audio).startswith(
209 | "HTSAT"
210 | ): # checkpoint trained via HTSAT codebase
211 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
212 | audio_ckpt = audio_ckpt["state_dict"]
213 | keys = list(audio_ckpt.keys())
214 | for key in keys:
215 | if key.startswith("sed_model"):
216 | v = audio_ckpt.pop(key)
217 | audio_ckpt["audio_branch." + key[10:]] = v
218 | elif os.path.basename(pretrained_audio).startswith(
219 | "finetuned"
220 | ): # checkpoint trained via linear probe codebase
221 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
222 | else:
223 | raise ValueError("Unknown audio checkpoint")
224 | else:
225 |             raise ValueError(f"Unsupported audio encoder for pretrained checkpoint: {amodel_name}")
226 |
227 | model.load_state_dict(audio_ckpt, strict=False)
228 | logging.info(
229 | f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
230 | )
231 | param_names = [n for n, p in model.named_parameters()]
232 | for n in param_names:
233 | print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
234 |
235 | model.to(device=device)
236 | if precision == "fp16":
237 | assert device.type != "cpu"
238 | convert_weights_to_fp16(model)
239 |
240 | if jit:
241 | model = torch.jit.script(model)
242 |
243 | return model, model_cfg
244 |
245 |
246 | def create_model_and_transforms(
247 | model_name: str,
248 | pretrained: str = "",
249 | precision: str = "fp32",
250 | device: torch.device = torch.device("cpu"),
251 | jit: bool = False,
252 | force_quick_gelu: bool = False,
253 | # pretrained_image: bool = False,
254 | ):
255 |     model, model_cfg = create_model(  # create_model returns a (model, model_cfg) tuple
256 | model_name,
257 | pretrained,
258 | precision,
259 | device,
260 | jit,
261 | force_quick_gelu=force_quick_gelu,
262 | # pretrained_image=pretrained_image
263 | )
264 | preprocess_train = image_transform(model.visual.image_size, is_train=True)
265 | preprocess_val = image_transform(model.visual.image_size, is_train=False)
266 | return model, preprocess_train, preprocess_val
267 |
268 |
269 | def list_models():
270 | """enumerate available model architectures based on config files"""
271 | return list(_MODEL_CONFIGS.keys())
272 |
273 |
274 | def add_model_config(path):
275 | """add model config path or file and update registry"""
276 | if not isinstance(path, Path):
277 | path = Path(path)
278 | _MODEL_CONFIG_PATHS.append(path)
279 | _rescan_model_configs()
280 |
--------------------------------------------------------------------------------
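A quick orientation to factory.py above: the `_rescan_model_configs()` call at import time fills `_MODEL_CONFIGS` from the bundled `model_configs/*.json` files, and `create_model()` is the main entry point, returning a `(model, model_cfg)` tuple. Below is a minimal sketch, assuming the audioldm package is importable and that no pretrained weights are requested; the "HTSAT-tiny" / "transformer" pair is an assumption based on the bundled configs, and the sketch is not run here.

import torch
from audioldm.clap.open_clip.factory import list_models, create_model

print(list_models())  # e.g. ['HTSAT-base', 'HTSAT-large', ..., 'PANN-14', ...]

# Build a randomly initialised CLAP model from a registered config (no download).
model, model_cfg = create_model(
    amodel_name="HTSAT-tiny",
    tmodel_name="transformer",
    pretrained="",                 # empty string -> skip checkpoint loading
    precision="fp32",
    device=torch.device("cpu"),
    enable_fusion=False,
)
print(model_cfg["embed_dim"])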
/audioldm/clap/open_clip/feature_fusion.py:
--------------------------------------------------------------------------------
1 | """
2 | Feature Fusion for Variable-Length Data Processing
3 | AFF/iAFF are adapted and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4 | Following the paper: Yimian Dai et al., Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class DAF(nn.Module):
12 | """
13 |     DirectAddFuse: fuse by direct element-wise addition
14 | """
15 |
16 | def __init__(self):
17 | super(DAF, self).__init__()
18 |
19 | def forward(self, x, residual):
20 | return x + residual
21 |
22 |
23 | class iAFF(nn.Module):
24 | """
25 |     iAFF: iterative attentional feature fusion
26 | """
27 |
28 | def __init__(self, channels=64, r=4, type="2D"):
29 | super(iAFF, self).__init__()
30 | inter_channels = int(channels // r)
31 |
32 | if type == "1D":
33 |             # local attention
34 | self.local_att = nn.Sequential(
35 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36 | nn.BatchNorm1d(inter_channels),
37 | nn.ReLU(inplace=True),
38 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39 | nn.BatchNorm1d(channels),
40 | )
41 |
42 |             # global attention
43 | self.global_att = nn.Sequential(
44 | nn.AdaptiveAvgPool1d(1),
45 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46 | nn.BatchNorm1d(inter_channels),
47 | nn.ReLU(inplace=True),
48 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49 | nn.BatchNorm1d(channels),
50 | )
51 |
52 |             # second local attention
53 | self.local_att2 = nn.Sequential(
54 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55 | nn.BatchNorm1d(inter_channels),
56 | nn.ReLU(inplace=True),
57 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58 | nn.BatchNorm1d(channels),
59 | )
60 |             # second global attention
61 | self.global_att2 = nn.Sequential(
62 | nn.AdaptiveAvgPool1d(1),
63 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64 | nn.BatchNorm1d(inter_channels),
65 | nn.ReLU(inplace=True),
66 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67 | nn.BatchNorm1d(channels),
68 | )
69 | elif type == "2D":
70 |             # local attention
71 | self.local_att = nn.Sequential(
72 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73 | nn.BatchNorm2d(inter_channels),
74 | nn.ReLU(inplace=True),
75 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76 | nn.BatchNorm2d(channels),
77 | )
78 |
79 |             # global attention
80 | self.global_att = nn.Sequential(
81 | nn.AdaptiveAvgPool2d(1),
82 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83 | nn.BatchNorm2d(inter_channels),
84 | nn.ReLU(inplace=True),
85 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86 | nn.BatchNorm2d(channels),
87 | )
88 |
89 |             # second local attention
90 | self.local_att2 = nn.Sequential(
91 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92 | nn.BatchNorm2d(inter_channels),
93 | nn.ReLU(inplace=True),
94 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95 | nn.BatchNorm2d(channels),
96 | )
97 |             # second global attention
98 | self.global_att2 = nn.Sequential(
99 | nn.AdaptiveAvgPool2d(1),
100 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101 | nn.BatchNorm2d(inter_channels),
102 | nn.ReLU(inplace=True),
103 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104 | nn.BatchNorm2d(channels),
105 | )
106 | else:
107 |             raise ValueError(f"fusion type {type} is not supported")
108 |
109 | self.sigmoid = nn.Sigmoid()
110 |
111 | def forward(self, x, residual):
112 | flag = False
113 | xa = x + residual
114 | if xa.size(0) == 1:
115 | xa = torch.cat([xa, xa], dim=0)
116 | flag = True
117 | xl = self.local_att(xa)
118 | xg = self.global_att(xa)
119 | xlg = xl + xg
120 | wei = self.sigmoid(xlg)
121 | xi = x * wei + residual * (1 - wei)
122 |
123 | xl2 = self.local_att2(xi)
124 |         xg2 = self.global_att(xi)  # NOTE: reuses the first global attention block; global_att2 is defined above but unused
125 | xlg2 = xl2 + xg2
126 | wei2 = self.sigmoid(xlg2)
127 | xo = x * wei2 + residual * (1 - wei2)
128 | if flag:
129 | xo = xo[0].unsqueeze(0)
130 | return xo
131 |
132 |
133 | class AFF(nn.Module):
134 | """
135 |     AFF: attentional feature fusion
136 | """
137 |
138 | def __init__(self, channels=64, r=4, type="2D"):
139 | super(AFF, self).__init__()
140 | inter_channels = int(channels // r)
141 |
142 | if type == "1D":
143 | self.local_att = nn.Sequential(
144 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145 | nn.BatchNorm1d(inter_channels),
146 | nn.ReLU(inplace=True),
147 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148 | nn.BatchNorm1d(channels),
149 | )
150 | self.global_att = nn.Sequential(
151 | nn.AdaptiveAvgPool1d(1),
152 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153 | nn.BatchNorm1d(inter_channels),
154 | nn.ReLU(inplace=True),
155 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156 | nn.BatchNorm1d(channels),
157 | )
158 | elif type == "2D":
159 | self.local_att = nn.Sequential(
160 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161 | nn.BatchNorm2d(inter_channels),
162 | nn.ReLU(inplace=True),
163 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164 | nn.BatchNorm2d(channels),
165 | )
166 | self.global_att = nn.Sequential(
167 | nn.AdaptiveAvgPool2d(1),
168 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169 | nn.BatchNorm2d(inter_channels),
170 | nn.ReLU(inplace=True),
171 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172 | nn.BatchNorm2d(channels),
173 | )
174 | else:
175 |             raise ValueError(f"fusion type {type} is not supported")
176 |
177 | self.sigmoid = nn.Sigmoid()
178 |
179 | def forward(self, x, residual):
180 | flag = False
181 | xa = x + residual
182 | if xa.size(0) == 1:
183 | xa = torch.cat([xa, xa], dim=0)
184 | flag = True
185 | xl = self.local_att(xa)
186 | xg = self.global_att(xa)
187 | xlg = xl + xg
188 | wei = self.sigmoid(xlg)
189 | xo = 2 * x * wei + 2 * residual * (1 - wei)
190 | if flag:
191 | xo = xo[0].unsqueeze(0)
192 | return xo
193 |
--------------------------------------------------------------------------------
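The fusion blocks above are self-contained nn.Modules, so they can be smoke-tested in isolation. A minimal sketch (shapes are illustrative only); note that both forward passes temporarily duplicate a batch of size 1 so that BatchNorm still works.

import torch
from audioldm.clap.open_clip.feature_fusion import AFF

fusion = AFF(channels=64, r=4, type="2D")
x = torch.randn(2, 64, 16, 16)         # (batch, channels, H, W)
residual = torch.randn(2, 64, 16, 16)
out = fusion(x, residual)              # attention-weighted mix of x and residual
print(out.shape)                       # torch.Size([2, 64, 16, 16])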
/audioldm/clap/open_clip/linear_probe.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from .model import MLPLayers
5 |
6 |
7 | class LinearProbe(nn.Module):
8 | def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
9 | """
10 | Args:
11 | model: nn.Module
12 | mlp: bool, if True, then use the MLP layer as the linear probe module
13 |             freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe
14 | in_ch: int, the output channel from CLAP model
15 | out_ch: int, the output channel from linear probe (class_num)
16 | act: torch.nn.functional, the activation function before the loss function
17 | """
18 | super().__init__()
19 |         in_ch = 512  # NOTE: hard-coded; overrides the in_ch argument
20 | self.clap_model = model
21 | self.clap_model.text_branch = None # to save memory
22 | self.freeze = freeze
23 | if mlp:
24 | self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
25 | else:
26 | self.lp_layer = nn.Linear(in_ch, out_ch)
27 |
28 | if self.freeze:
29 | for param in self.clap_model.parameters():
30 | param.requires_grad = False
31 |
32 |         if act == "None" or act is None:
33 | self.act = None
34 | elif act == "relu":
35 | self.act = nn.ReLU()
36 | elif act == "elu":
37 | self.act = nn.ELU()
38 | elif act == "prelu":
39 | self.act = nn.PReLU(num_parameters=in_ch)
40 | elif act == "softmax":
41 | self.act = nn.Softmax(dim=-1)
42 | elif act == "sigmoid":
43 | self.act = nn.Sigmoid()
44 |
45 | def forward(self, x, mix_lambda=None, device=None):
46 | """
47 | Args:
48 | x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
49 | mix_lambda: torch.tensor [batch], the mixup lambda
50 | Returns:
51 | class_prob: torch.tensor [batch, class_num]
52 |
53 | """
54 |         # eval() keeps the frozen CLAP's batchnorm statistics fixed while the probe trains
55 | if self.freeze:
56 | self.clap_model.eval()
57 |
58 | x = self.clap_model.audio_projection(
59 | self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
60 | "embedding"
61 | ]
62 | )
63 | out = self.lp_layer(x)
64 | if self.act is not None:
65 | out = self.act(out)
66 | return out
67 |
--------------------------------------------------------------------------------
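LinearProbe only relies on two things from the wrapped model: an `audio_branch` callable that returns a dict with an "embedding" entry, and an `audio_projection` module. The stand-in `DummyCLAP` below is hypothetical and exists only to illustrate that interface; in practice the wrapped model comes from `create_model()`.

import torch
import torch.nn as nn
from audioldm.clap.open_clip.linear_probe import LinearProbe

class DummyCLAP(nn.Module):
    """Hypothetical stand-in exposing the interface LinearProbe expects."""
    def __init__(self, embed_dim=512):
        super().__init__()
        self.text_branch = nn.Identity()              # discarded by LinearProbe
        self.audio_projection = nn.Linear(embed_dim, 512)
        self._encoder = nn.Linear(1000, embed_dim)

    def audio_branch(self, x, mixup_lambda=None, device=None):
        return {"embedding": self._encoder(x)}

probe = LinearProbe(DummyCLAP(), mlp=False, freeze=True, in_ch=512, out_ch=527, act="sigmoid")
logits = probe(torch.randn(4, 1000))                  # (batch, class_num)
print(logits.shape)                                   # torch.Size([4, 527])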
/audioldm/clap/open_clip/model_configs/HTSAT-base.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "base"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
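For orientation, the audio_cfg fields in these configs are internally consistent; a rough sanity check, under the assumption that clip_samples describes one 10-second, 48 kHz training clip:

cfg = {"clip_samples": 480000, "sample_rate": 48000, "hop_size": 480, "audio_length": 1024}
print(cfg["clip_samples"] / cfg["sample_rate"])   # 10.0 seconds per clip
print(cfg["clip_samples"] // cfg["hop_size"])     # 1000 STFT frames, close to audio_length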
/audioldm/clap/open_clip/model_configs/HTSAT-large.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "large"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1536,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "tiny"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "tiny"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/PANN-10.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn10"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 18000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 960000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 360,
10 | "fmin": 50,
11 | "fmax": 8000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 4
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1536,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/PANN-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/PANN-6.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn6"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 23,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/RN101.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 23,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 6,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/RN50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 6,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/RN50x16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 384,
5 | "layers": [
6 | 6,
7 | 8,
8 | 18,
9 | 8
10 | ],
11 | "width": 96,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 768,
18 | "heads": 12,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/RN50x4.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 288,
5 | "layers": [
6 | 4,
7 | 6,
8 | 10,
9 | 6
10 | ],
11 | "width": 80,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 640,
18 | "heads": 10,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/model_configs/ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import Union, List
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict
13 | from .pretrained import (
14 | get_pretrained_url,
15 | list_pretrained_tag_models,
16 | download_pretrained,
17 | )
18 |
19 | __all__ = ["list_openai_models", "load_openai_model"]
20 |
21 | CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache")
22 |
23 |
24 |
25 | def list_openai_models() -> List[str]:
26 | """Returns the names of available CLIP models"""
27 | return list_pretrained_tag_models("openai")
28 |
29 |
30 | def load_openai_model(
31 | name: str,
32 | model_cfg,
33 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
34 | jit=True,
35 | cache_dir=os.path.expanduser(f"{CACHE_DIR}/clip"),
36 | enable_fusion: bool = False,
37 | fusion_type: str = "None",
38 | ):
39 |     """Load a CLIP model, keep its pretrained text branch, and set it in the CLAP model
40 |
41 | Parameters
42 | ----------
43 | name : str
44 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
45 | device : Union[str, torch.device]
46 | The device to put the loaded model
47 | jit : bool
48 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
49 |
50 | Returns
51 | -------
52 | model : torch.nn.Module
53 | The CLAP model
54 |     Note: unlike the original CLIP loader, no image preprocessing transform is
55 |         returned here; only the CLAP model itself is produced.
56 | """
57 | if get_pretrained_url(name, "openai"):
58 | model_path = download_pretrained(
59 | get_pretrained_url(name, "openai"), root=cache_dir
60 | )
61 | elif os.path.isfile(name):
62 | model_path = name
63 | else:
64 | raise RuntimeError(
65 | f"Model {name} not found; available models = {list_openai_models()}"
66 | )
67 |
68 | try:
69 | # loading JIT archive
70 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
71 | state_dict = None
72 | except RuntimeError:
73 | # loading saved state dict
74 | if jit:
75 | warnings.warn(
76 | f"File {model_path} is not a JIT archive. Loading as a state dict instead"
77 | )
78 | jit = False
79 | state_dict = torch.load(model_path, map_location="cpu")
80 |
81 | if not jit:
82 | try:
83 | model = build_model_from_openai_state_dict(
84 | state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
85 | ).to(device)
86 | except KeyError:
87 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
88 | model = build_model_from_openai_state_dict(
89 | sd, model_cfg, enable_fusion, fusion_type
90 | ).to(device)
91 |
92 | if str(device) == "cpu":
93 | model.float()
94 | return model
95 |
96 | # patch the device names
97 | device_holder = torch.jit.trace(
98 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
99 | )
100 | device_node = [
101 | n
102 | for n in device_holder.graph.findAllNodes("prim::Constant")
103 | if "Device" in repr(n)
104 | ][-1]
105 |
106 | def patch_device(module):
107 | try:
108 | graphs = [module.graph] if hasattr(module, "graph") else []
109 | except RuntimeError:
110 | graphs = []
111 |
112 | if hasattr(module, "forward1"):
113 | graphs.append(module.forward1.graph)
114 |
115 | for graph in graphs:
116 | for node in graph.findAllNodes("prim::Constant"):
117 | if "value" in node.attributeNames() and str(node["value"]).startswith(
118 | "cuda"
119 | ):
120 | node.copyAttributes(device_node)
121 |
122 | model.apply(patch_device)
123 | patch_device(model.encode_audio)
124 | patch_device(model.encode_text)
125 |
126 | # patch dtype to float32 on CPU
127 | if str(device) == "cpu":
128 | float_holder = torch.jit.trace(
129 | lambda: torch.ones([]).float(), example_inputs=[]
130 | )
131 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
132 | float_node = float_input.node()
133 |
134 | def patch_float(module):
135 | try:
136 | graphs = [module.graph] if hasattr(module, "graph") else []
137 | except RuntimeError:
138 | graphs = []
139 |
140 | if hasattr(module, "forward1"):
141 | graphs.append(module.forward1.graph)
142 |
143 | for graph in graphs:
144 | for node in graph.findAllNodes("aten::to"):
145 | inputs = list(node.inputs())
146 | for i in [
147 | 1,
148 | 2,
149 | ]: # dtype can be the second or third argument to aten::to()
150 | if inputs[i].node()["value"] == 5:
151 | inputs[i].node().copyAttributes(float_node)
152 |
153 | model.apply(patch_float)
154 | patch_float(model.encode_audio)
155 | patch_float(model.encode_text)
156 | model.float()
157 |
158 | model.audio_branch.audio_length = model.audio_cfg.audio_length
159 | return model
160 |
--------------------------------------------------------------------------------
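The usual caller of `load_openai_model` is `create_model(..., pretrained="openai")`, which builds the config dict and hands it over. The sketch below mirrors that path directly; it is not run here, it would download the OpenAI ViT-B-16 checkpoint into the cache directory, and the "HTSAT-tiny" / "transformer" choices are assumptions based on the bundled configs.

from copy import deepcopy
from audioldm.clap.open_clip.factory import _MODEL_CONFIGS
from audioldm.clap.open_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())                    # model names that have an "openai" pretrained tag

model_cfg = deepcopy(_MODEL_CONFIGS["HTSAT-tiny"])
model_cfg["text_cfg"]["model_type"] = "transformer"
model = load_openai_model("ViT-B-16", model_cfg, device="cpu", jit=False)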
/audioldm/clap/open_clip/pretrained.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 |
6 | from tqdm import tqdm
7 |
8 | CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache")
9 |
10 | _RN50 = dict(
11 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
12 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
13 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
14 | )
15 |
16 | _RN50_quickgelu = dict(
17 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
18 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
19 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
20 | )
21 |
22 | _RN101 = dict(
23 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
24 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
25 | )
26 |
27 | _RN101_quickgelu = dict(
28 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
29 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
30 | )
31 |
32 | _RN50x4 = dict(
33 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | )
35 |
36 | _RN50x16 = dict(
37 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
38 | )
39 |
40 | _RN50x64 = dict(
41 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
42 | )
43 |
44 | _VITB32 = dict(
45 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
46 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
47 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
48 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
49 | )
50 |
51 | _VITB32_quickgelu = dict(
52 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
53 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
54 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
55 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
56 | )
57 |
58 | _VITB16 = dict(
59 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
60 | )
61 |
62 | _VITL14 = dict(
63 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
64 | )
65 |
66 | _PRETRAINED = {
67 | "RN50": _RN50,
68 | "RN50-quickgelu": _RN50_quickgelu,
69 | "RN101": _RN101,
70 | "RN101-quickgelu": _RN101_quickgelu,
71 | "RN50x4": _RN50x4,
72 | "RN50x16": _RN50x16,
73 | "ViT-B-32": _VITB32,
74 | "ViT-B-32-quickgelu": _VITB32_quickgelu,
75 | "ViT-B-16": _VITB16,
76 | "ViT-L-14": _VITL14,
77 | }
78 |
79 |
80 | def list_pretrained(as_str: bool = False):
81 | """returns list of pretrained models
82 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
83 | """
84 | return [
85 | ":".join([k, t]) if as_str else (k, t)
86 | for k in _PRETRAINED.keys()
87 | for t in _PRETRAINED[k].keys()
88 | ]
89 |
90 |
91 | def list_pretrained_tag_models(tag: str):
92 | """return all models having the specified pretrain tag"""
93 | models = []
94 | for k in _PRETRAINED.keys():
95 | if tag in _PRETRAINED[k]:
96 | models.append(k)
97 | return models
98 |
99 |
100 | def list_pretrained_model_tags(model: str):
101 | """return all pretrain tags for the specified model architecture"""
102 | tags = []
103 | if model in _PRETRAINED:
104 | tags.extend(_PRETRAINED[model].keys())
105 | return tags
106 |
107 |
108 | def get_pretrained_url(model: str, tag: str):
109 | if model not in _PRETRAINED:
110 | return ""
111 | model_pretrained = _PRETRAINED[model]
112 | if tag not in model_pretrained:
113 | return ""
114 | return model_pretrained[tag]
115 |
116 |
117 | def download_pretrained(url: str, root: str = os.path.expanduser(f"{CACHE_DIR}/clip")):
118 | os.makedirs(root, exist_ok=True)
119 | filename = os.path.basename(url)
120 |
121 | if "openaipublic" in url:
122 | expected_sha256 = url.split("/")[-2]
123 | else:
124 | expected_sha256 = ""
125 |
126 | download_target = os.path.join(root, filename)
127 |
128 | if os.path.exists(download_target) and not os.path.isfile(download_target):
129 | raise RuntimeError(f"{download_target} exists and is not a regular file")
130 |
131 | if os.path.isfile(download_target):
132 | if expected_sha256:
133 | if (
134 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
135 | == expected_sha256
136 | ):
137 | return download_target
138 | else:
139 | warnings.warn(
140 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
141 | )
142 | else:
143 | return download_target
144 |
145 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
146 | with tqdm(
147 | total=int(source.info().get("Content-Length")),
148 | ncols=80,
149 | unit="iB",
150 | unit_scale=True,
151 | ) as loop:
152 | while True:
153 | buffer = source.read(8192)
154 | if not buffer:
155 | break
156 |
157 | output.write(buffer)
158 | loop.update(len(buffer))
159 |
160 | if (
161 | expected_sha256
162 | and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
163 | != expected_sha256
164 | ):
165 | raise RuntimeError(
166 |             "Model has been downloaded but the SHA256 checksum does not match"
167 | )
168 |
169 | return download_target
170 |
--------------------------------------------------------------------------------
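A small sketch of the registry helpers above; the actual download call is left commented out because it fetches several hundred megabytes and verifies the SHA256 embedded in OpenAI's URLs.

from audioldm.clap.open_clip.pretrained import (
    list_pretrained, get_pretrained_url, download_pretrained,
)

print(list_pretrained(as_str=True))       # e.g. ['RN50:openai', 'RN50:yfcc15m', ...]
url = get_pretrained_url("ViT-B-16", "openai")
print(url)                                # '' if the (model, tag) pair is unknown
# path = download_pretrained(url)         # would download + checksum into ~/.cache/clip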
/audioldm/clap/open_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch.nn as nn
8 |
9 | try:
10 | import timm
11 | from timm.models.layers import Mlp, to_2tuple
12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
13 | from timm.models.layers.attention_pool2d import (
14 | AttentionPool2d as AbsAttentionPool2d,
15 | )
16 | except ImportError as e:
17 | timm = None
18 |
19 | from .utils import freeze_batch_norm_2d
20 |
21 |
22 | class TimmModel(nn.Module):
23 | """timm model adapter
24 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
25 | """
26 |
27 | def __init__(
28 | self,
29 | model_name,
30 | embed_dim,
31 | image_size=224,
32 | pool="avg",
33 | proj="linear",
34 | drop=0.0,
35 | pretrained=False,
36 | ):
37 | super().__init__()
38 | if timm is None:
39 | raise RuntimeError("Please `pip install timm` to use timm models.")
40 |
41 | self.image_size = to_2tuple(image_size)
42 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
43 | feat_size = self.trunk.default_cfg.get("pool_size", None)
44 | feature_ndim = 1 if not feat_size else 2
45 | if pool in ("abs_attn", "rot_attn"):
46 | assert feature_ndim == 2
47 | # if attn pooling used, remove both classifier and default pool
48 | self.trunk.reset_classifier(0, global_pool="")
49 | else:
50 | # reset global pool if pool config set, otherwise leave as network default
51 | reset_kwargs = dict(global_pool=pool) if pool else {}
52 | self.trunk.reset_classifier(0, **reset_kwargs)
53 | prev_chs = self.trunk.num_features
54 |
55 | head_layers = OrderedDict()
56 | if pool == "abs_attn":
57 | head_layers["pool"] = AbsAttentionPool2d(
58 | prev_chs, feat_size=feat_size, out_features=embed_dim
59 | )
60 | prev_chs = embed_dim
61 | elif pool == "rot_attn":
62 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
63 | prev_chs = embed_dim
64 | else:
65 | assert proj, "projection layer needed if non-attention pooling is used."
66 |
67 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
68 | if proj == "linear":
69 | head_layers["drop"] = nn.Dropout(drop)
70 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
71 | elif proj == "mlp":
72 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
73 |
74 | self.head = nn.Sequential(head_layers)
75 |
76 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
77 | """lock modules
78 | Args:
79 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
80 | """
81 | if not unlocked_groups:
82 | # lock full model
83 | for param in self.trunk.parameters():
84 | param.requires_grad = False
85 | if freeze_bn_stats:
86 | freeze_batch_norm_2d(self.trunk)
87 | else:
88 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
89 | try:
90 | # FIXME import here until API stable and in an official release
91 | from timm.models.helpers import group_parameters, group_modules
92 | except ImportError:
93 | raise RuntimeError(
94 | "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
95 | )
96 | matcher = self.trunk.group_matcher()
97 | gparams = group_parameters(self.trunk, matcher)
98 | max_layer_id = max(gparams.keys())
99 | max_layer_id = max_layer_id - unlocked_groups
100 | for group_idx in range(max_layer_id + 1):
101 | group = gparams[group_idx]
102 | for param in group:
103 | self.trunk.get_parameter(param).requires_grad = False
104 | if freeze_bn_stats:
105 | gmodules = group_modules(self.trunk, matcher, reverse=True)
106 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
107 | freeze_batch_norm_2d(self.trunk, gmodules)
108 |
109 | def forward(self, x):
110 | x = self.trunk(x)
111 | x = self.head(x)
112 | return x
113 |
--------------------------------------------------------------------------------
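A minimal sketch of the timm adapter, assuming a timm version compatible with the imports at the top of this file; "resnet18" is just an example backbone and no pretrained weights are fetched.

import torch
from audioldm.clap.open_clip.timm_model import TimmModel

vision_tower = TimmModel("resnet18", embed_dim=512, image_size=224, pool="avg", proj="linear")
feats = vision_tower(torch.randn(1, 3, 224, 224))   # trunk features -> linear projection head
print(feats.shape)                                  # torch.Size([1, 512])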
/audioldm/clap/open_clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 |
16 | @lru_cache()
17 | def default_bpe():
18 | return os.path.join(
19 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
20 | )
21 |
22 |
23 | @lru_cache()
24 | def bytes_to_unicode():
25 | """
26 | Returns list of utf-8 byte and a corresponding list of unicode strings.
27 | The reversible bpe codes work on unicode strings.
28 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30 | This is a significant percentage of your normal, say, 32K bpe vocab.
31 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32 | And avoids mapping to whitespace/control characters the bpe code barfs on.
33 | """
34 | bs = (
35 | list(range(ord("!"), ord("~") + 1))
36 | + list(range(ord("¡"), ord("¬") + 1))
37 | + list(range(ord("®"), ord("ÿ") + 1))
38 | )
39 | cs = bs[:]
40 | n = 0
41 | for b in range(2**8):
42 | if b not in bs:
43 | bs.append(b)
44 | cs.append(2**8 + n)
45 | n += 1
46 | cs = [chr(n) for n in cs]
47 | return dict(zip(bs, cs))
48 |
49 |
50 | def get_pairs(word):
51 | """Return set of symbol pairs in a word.
52 | Word is represented as tuple of symbols (symbols being variable-length strings).
53 | """
54 | pairs = set()
55 | prev_char = word[0]
56 | for char in word[1:]:
57 | pairs.add((prev_char, char))
58 | prev_char = char
59 | return pairs
60 |
61 |
62 | def basic_clean(text):
63 | text = ftfy.fix_text(text)
64 | text = html.unescape(html.unescape(text))
65 | return text.strip()
66 |
67 |
68 | def whitespace_clean(text):
69 | text = re.sub(r"\s+", " ", text)
70 | text = text.strip()
71 | return text
72 |
73 |
74 | class SimpleTokenizer(object):
75 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
76 | self.byte_encoder = bytes_to_unicode()
77 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
78 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
79 | merges = merges[1 : 49152 - 256 - 2 + 1]
80 | merges = [tuple(merge.split()) for merge in merges]
81 | vocab = list(bytes_to_unicode().values())
82 |         vocab = vocab + [v + "</w>" for v in vocab]
83 | for merge in merges:
84 | vocab.append("".join(merge))
85 | if not special_tokens:
86 |             special_tokens = ["<start_of_text>", "<end_of_text>"]
87 | else:
88 |             special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
89 | vocab.extend(special_tokens)
90 | self.encoder = dict(zip(vocab, range(len(vocab))))
91 | self.decoder = {v: k for k, v in self.encoder.items()}
92 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
93 | self.cache = {t: t for t in special_tokens}
94 | special = "|".join(special_tokens)
95 | self.pat = re.compile(
96 | special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
97 | re.IGNORECASE,
98 | )
99 |
100 | self.vocab_size = len(self.encoder)
101 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
102 |
103 | def bpe(self, token):
104 | if token in self.cache:
105 | return self.cache[token]
106 |         word = tuple(token[:-1]) + (token[-1] + "</w>",)
107 | pairs = get_pairs(word)
108 |
109 | if not pairs:
110 |             return token + "</w>"
111 |
112 | while True:
113 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
114 | if bigram not in self.bpe_ranks:
115 | break
116 | first, second = bigram
117 | new_word = []
118 | i = 0
119 | while i < len(word):
120 | try:
121 | j = word.index(first, i)
122 | new_word.extend(word[i:j])
123 | i = j
124 | except:
125 | new_word.extend(word[i:])
126 | break
127 |
128 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
129 | new_word.append(first + second)
130 | i += 2
131 | else:
132 | new_word.append(word[i])
133 | i += 1
134 | new_word = tuple(new_word)
135 | word = new_word
136 | if len(word) == 1:
137 | break
138 | else:
139 | pairs = get_pairs(word)
140 | word = " ".join(word)
141 | self.cache[token] = word
142 | return word
143 |
144 | def encode(self, text):
145 | bpe_tokens = []
146 | text = whitespace_clean(basic_clean(text)).lower()
147 | for token in re.findall(self.pat, text):
148 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
149 | bpe_tokens.extend(
150 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
151 | )
152 | return bpe_tokens
153 |
154 | def decode(self, tokens):
155 | text = "".join([self.decoder[token] for token in tokens])
156 | text = (
157 | bytearray([self.byte_decoder[c] for c in text])
158 | .decode("utf-8", errors="replace")
159 |             .replace("</w>", " ")
160 | )
161 | return text
162 |
163 |
164 | _tokenizer = SimpleTokenizer()
165 |
166 |
167 | def tokenize(
168 | texts: Union[str, List[str]], context_length: int = 77
169 | ) -> torch.LongTensor:
170 | """
171 | Returns the tokenized representation of given input string(s)
172 |
173 | Parameters
174 | ----------
175 | texts : Union[str, List[str]]
176 | An input string or a list of input strings to tokenize
177 | context_length : int
178 | The context length to use; all CLIP models use 77 as the context length
179 |
180 | Returns
181 | -------
182 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
183 | """
184 | if isinstance(texts, str):
185 | texts = [texts]
186 |
187 |     sot_token = _tokenizer.encoder["<start_of_text>"]
188 |     eot_token = _tokenizer.encoder["<end_of_text>"]
189 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
190 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
191 |
192 | for i, tokens in enumerate(all_tokens):
193 | if len(tokens) > context_length:
194 | tokens = tokens[:context_length] # Truncate
195 | result[i, : len(tokens)] = torch.tensor(tokens)
196 |
197 | return result
198 |
--------------------------------------------------------------------------------
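The module-level `tokenize()` helper is all that the rest of the codebase needs. A minimal sketch, assuming the `ftfy` and `regex` dependencies imported above are installed:

from audioldm.clap.open_clip.tokenizer import tokenize

tokens = tokenize(["a dog barking in the distance", "hammer hitting a wooden surface"])
print(tokens.shape)     # torch.Size([2, 77]); zero-padded, truncated beyond 77 tokens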
/audioldm/clap/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import (
2 | Normalize,
3 | Compose,
4 | RandomResizedCrop,
5 | InterpolationMode,
6 | ToTensor,
7 | Resize,
8 | CenterCrop,
9 | )
10 |
11 |
12 | def _convert_to_rgb(image):
13 | return image.convert("RGB")
14 |
15 |
16 | def image_transform(
17 | image_size: int,
18 | is_train: bool,
19 | mean=(0.48145466, 0.4578275, 0.40821073),
20 | std=(0.26862954, 0.26130258, 0.27577711),
21 | ):
22 | normalize = Normalize(mean=mean, std=std)
23 | if is_train:
24 | return Compose(
25 | [
26 | RandomResizedCrop(
27 | image_size,
28 | scale=(0.9, 1.0),
29 | interpolation=InterpolationMode.BICUBIC,
30 | ),
31 | _convert_to_rgb,
32 | ToTensor(),
33 | normalize,
34 | ]
35 | )
36 | else:
37 | return Compose(
38 | [
39 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
40 | CenterCrop(image_size),
41 | _convert_to_rgb,
42 | ToTensor(),
43 | normalize,
44 | ]
45 | )
46 |
--------------------------------------------------------------------------------
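These image transforms are carried over from the CLIP codebase and are not applied to audio; they are only used when a vision tower is involved. A small sketch, assuming torchvision and Pillow are installed:

from PIL import Image
from audioldm.clap.open_clip.transform import image_transform

preprocess = image_transform(224, is_train=False)        # resize, center-crop, normalize
img = Image.new("RGB", (640, 480), color=(128, 128, 128))
tensor = preprocess(img)
print(tensor.shape)                                       # torch.Size([3, 224, 224])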
/audioldm/clap/open_clip/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn as nn
4 | from torchvision.ops.misc import FrozenBatchNorm2d
5 | import logging
6 |
7 | # import h5py
8 | from tqdm import tqdm
9 | import random
10 | import json
11 | import os
12 | import pathlib
13 |
14 | # TODO: (yusong) this is not a good place to store this information and it does not scale. Needs to be fixed later.
15 | dataset_split = {
16 | "audiocaps": ["train", "valid", "test"],
17 | "audioset": ["balanced_train", "unbalanced_train", "eval"],
18 | "BBCSoundEffects": ["train", "test"],
19 | "Clotho": ["train", "test", "valid"],
20 | "free_to_use_sounds": ["train", "test"],
21 | "paramount_motion": ["train", "test"],
22 | "sonniss_game_effects": ["train", "test"],
23 | "wesoundeffects": ["train", "test"],
24 | "MACS": ["train", "test"],
25 | "freesound": ["train", "test"],
26 | "FSD50K": ["train", "test", "valid"],
27 | "fsd50k_class_label": ["train", "test", "valid"],
28 | "esc50": ["train", "test"],
29 | "audiostock": ["train", "test"],
30 | "freesound_no_overlap_noesc50": ["train", "test"],
31 | "epidemic_sound_effects": ["train", "test"],
32 | "VGGSound": ["train", "test"],
33 | "urbansound8k_class_label": ["train", "test"],
34 | "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
35 | "epidemic_sound_effects_t5": ["train", "test"],
36 | "WavText5K": ["train", "test"],
37 | "esc50_no_overlap": ["train", "test"],
38 | "usd8k_no_overlap": ["train", "test"],
39 | "fsd50k_200_class_label": ["train", "test", "valid"],
40 | }
41 |
42 |
43 | def freeze_batch_norm_2d(module, module_match={}, name=""):
44 | """
45 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
46 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
47 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
48 |
49 | Args:
50 | module (torch.nn.Module): Any PyTorch module.
51 | module_match (dict): Dictionary of full module names to freeze (all if empty)
52 | name (str): Full module name (prefix)
53 |
54 | Returns:
55 | torch.nn.Module: Resulting module
56 |
57 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
58 | """
59 | res = module
60 | is_match = True
61 | if module_match:
62 | is_match = name in module_match
63 | if is_match and isinstance(
64 | module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
65 | ):
66 | res = FrozenBatchNorm2d(module.num_features)
67 | res.num_features = module.num_features
68 | res.affine = module.affine
69 | if module.affine:
70 | res.weight.data = module.weight.data.clone().detach()
71 | res.bias.data = module.bias.data.clone().detach()
72 | res.running_mean.data = module.running_mean.data
73 | res.running_var.data = module.running_var.data
74 | res.eps = module.eps
75 | else:
76 | for child_name, child in module.named_children():
77 | full_child_name = ".".join([name, child_name]) if name else child_name
78 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
79 | if new_child is not child:
80 | res.add_module(child_name, new_child)
81 | return res
82 |
83 |
84 | def exist(dataset_name, dataset_type):
85 | """
86 | Check if dataset exists
87 | """
88 | if dataset_type in dataset_split[dataset_name]:
89 | return True
90 | else:
91 | return False
92 |
93 |
94 | def get_tar_path_from_dataset_name(
95 | dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
96 | ):
97 | """
98 | Get tar path from dataset name and type
99 | """
100 | output = []
101 | for n in dataset_names:
102 | if full_dataset is not None and n in full_dataset:
103 | current_dataset_types = dataset_split[n]
104 | else:
105 | current_dataset_types = dataset_types
106 | for s in current_dataset_types:
107 | tmp = []
108 | if islocal:
109 | sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
110 | if not os.path.exists(sizefilepath_):
111 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
112 | else:
113 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
114 | if not os.path.exists(sizefilepath_):
115 | continue
116 | sizes = json.load(open(sizefilepath_, "r"))
117 | for k in sizes.keys():
118 | if islocal:
119 | tmp.append(f"{dataset_path}/{n}/{s}/{k}")
120 | else:
121 | tmp.append(
122 | f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
123 | )
124 | if proportion != 1:
125 | tmp = random.sample(tmp, int(proportion * len(tmp)))
126 | output.append(tmp)
127 | return sum(output, [])
128 |
129 |
130 | def get_tar_path_from_txts(txt_path, islocal, proportion=1):
131 | """
132 | Get tar path from txt path
133 | """
134 | if isinstance(txt_path, (list, tuple)):
135 | return sum(
136 | [
137 | get_tar_path_from_txts(
138 | txt_path[i], islocal=islocal, proportion=proportion
139 | )
140 | for i in range(len(txt_path))
141 | ],
142 | [],
143 | )
144 | if isinstance(txt_path, str):
145 | with open(txt_path) as f:
146 | lines = f.readlines()
147 | if islocal:
148 | lines = [
149 | lines[i]
150 | .split("\n")[0]
151 | .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
152 | for i in range(len(lines))
153 | ]
154 | else:
155 | lines = [
156 | lines[i].split("\n")[0].replace(".tar", ".tar -")
157 | for i in range(len(lines))
158 | ]
159 | if proportion != 1:
160 | print("Sampling tars with proportion of {}".format(proportion))
161 | lines = random.sample(lines, int(proportion * len(lines)))
162 | return lines
163 |
164 |
165 | def get_mix_lambda(mixup_alpha, batch_size):
166 | mixup_lambdas = [
167 | np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
168 | ]
169 | return np.array(mixup_lambdas).astype(np.float32)
170 |
171 |
172 | def do_mixup(x, mixup_lambda):
173 | """
174 | Args:
175 | x: (batch_size , ...)
176 | mixup_lambda: (batch_size,)
177 | Returns:
178 | out: (batch_size, ...)
179 | """
180 | out = (
181 | x.transpose(0, -1) * mixup_lambda
182 | + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
183 | ).transpose(0, -1)
184 | return out
185 |
186 |
187 | def interpolate(x, ratio):
188 |     """Interpolate data in the time domain. This is used to compensate for the
189 | resolution reduction in downsampling of a CNN.
190 |
191 | Args:
192 | x: (batch_size, time_steps, classes_num)
193 | ratio: int, ratio to interpolate
194 | Returns:
195 | upsampled: (batch_size, time_steps * ratio, classes_num)
196 | """
197 | (batch_size, time_steps, classes_num) = x.shape
198 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
199 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
200 | return upsampled
201 |
202 |
203 | def pad_framewise_output(framewise_output, frames_num):
204 | """Pad framewise_output to the same length as input frames. The pad value
205 | is the same as the value of the last frame.
206 | Args:
207 | framewise_output: (batch_size, frames_num, classes_num)
208 |         frames_num: int, target number of frames after padding
209 | Outputs:
210 | output: (batch_size, frames_num, classes_num)
211 | """
212 | pad = framewise_output[:, -1:, :].repeat(
213 | 1, frames_num - framewise_output.shape[1], 1
214 | )
215 | """tensor for padding"""
216 |
217 | output = torch.cat((framewise_output, pad), dim=1)
218 |     """(batch_size, frames_num, classes_num)"""
219 |     return output
220 |
221 | # def process_ipc(index_path, classes_num, filename):
222 | # # load data
223 | # logging.info("Load Data...............")
224 | # ipc = [[] for _ in range(classes_num)]
225 | # with h5py.File(index_path, "r") as f:
226 | # for i in tqdm(range(len(f["target"]))):
227 | # t_class = np.where(f["target"][i])[0]
228 | # for t in t_class:
229 | # ipc[t].append(i)
230 | # print(ipc)
231 | # np.save(filename, ipc)
232 | # logging.info("Load Data Succeed...............")
233 |
234 |
235 | def save_to_dict(s, o_={}):
236 | sp = s.split(": ")
237 | o_.update({sp[0]: float(sp[1])})
238 | return o_
239 |
240 |
241 | def get_data_from_log(txt_path):
242 | """
243 | Output dictionary from out.txt log file
244 | """
245 | with open(txt_path) as f:
246 | lines = f.readlines()
247 | val_data = {}
248 | train_data = {}
249 | train_losses = []
250 | train_losses_epoch = []
251 | for i in range(len(lines)):
252 | if "| INFO |" in lines[i]:
253 | if "Eval Epoch" in lines[i]:
254 | if "val_loss" in lines[i]:
255 | # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
256 | line = lines[i].split("Eval Epoch: ")[-1]
257 | num_epoch = int(line.split(" ")[0].split(" ")[0])
258 | d = {
259 | line.split(" ")[0]
260 | .split(" ")[1]
261 | .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
262 | }
263 | for i in range(1, len(line.split(" "))):
264 | d = save_to_dict(line.split(" ")[i], d)
265 | val_data[num_epoch] = d
266 | elif "Train Epoch" in lines[i]:
267 | num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
268 | loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
269 | train_losses.append(loss)
270 | train_losses_epoch.append(num_epoch)
271 | for i in range(len(train_losses)):
272 | train_data[i] = {
273 | "num_epoch": train_losses_epoch[i],
274 | "train_loss": train_losses[i],
275 | }
276 | return train_data, val_data
277 |
278 |
279 | def save_p(obj, filename):
280 | import pickle
281 |
282 | try:
283 | from deepdiff import DeepDiff
284 | except:
285 | os.system("pip install deepdiff")
286 | from deepdiff import DeepDiff
287 | with open(filename, "wb") as file:
288 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
289 | with open(filename, "rb") as file:
290 | z = pickle.load(file)
291 | assert (
292 | DeepDiff(obj, z, ignore_string_case=True) == {}
293 | ), "there is something wrong with the saving process"
294 | return
295 |
296 |
297 | def load_p(filename):
298 | import pickle
299 |
300 | with open(filename, "rb") as file:
301 | z = pickle.load(file)
302 | return z
303 |
304 |
305 | def save_json(data, name="data.json"):
306 | import json
307 |
308 | with open(name, "w") as fp:
309 | json.dump(data, fp)
310 | return
311 |
312 |
313 | def load_json(name):
314 | import json
315 |
316 | with open(name, "r") as fp:
317 | data = json.load(fp)
318 | return data
319 |
320 |
321 | from multiprocessing import Process, Manager
322 | from multiprocessing import Process, Value, Array
323 | from ctypes import c_wchar
324 |
325 |
326 | def load_class_label(path):
327 | # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
328 | # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
329 | out = None
330 | if path is not None:
331 | if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
332 | out = load_p(path)
333 | elif pathlib.Path(path).suffix in [".json", ".txt"]:
334 | out = load_json(path)
335 | elif pathlib.Path(path).suffix in [".npy", ".npz"]:
336 | out = np.load(path)
337 | elif pathlib.Path(path).suffix in [".csv"]:
338 | import pandas as pd
339 |
340 | out = pd.read_csv(path)
341 | return out
342 | # if out is None:
343 | # return None
344 | # else:
345 | # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
346 | # val = Array('i', out.values(), lock=False)
347 | # return (key, val)
348 |
349 |
350 | from torch import optim
351 |
352 |
353 | def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
354 | if optimizer_name.lower() == "adamw":
355 | optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
356 | elif optimizer_name.lower() == "sgd":
357 | optimizer = optim.SGD(params, lr=lr, momentum=momentum)
358 | elif optimizer_name.lower() == "adam":
359 | optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
360 | else:
361 | raise ValueError("optimizer name is not correct")
362 | return optimizer
363 |
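The snippet below is a minimal, standalone sketch of how get_mix_lambda and do_mixup above are typically combined; the pairing is an assumption based on the function signatures, not a copy of the training code. Each sample is blended with the sample at the mirrored position in the batch using a Beta(alpha, alpha) weight.

    import torch
    # Assumes get_mix_lambda and do_mixup from this module are in scope
    # (e.g. from audioldm.clap.open_clip.utils import get_mix_lambda, do_mixup).

    batch_size = 4
    targets = torch.eye(batch_size)                          # toy one-hot targets, shape (4, 4)
    lam = torch.from_numpy(get_mix_lambda(0.5, batch_size))  # per-sample Beta(0.5, 0.5) weights
    mixed = do_mixup(targets, lam)                           # row i blended with row (batch_size - 1 - i)
    print(mixed.shape)                                       # torch.Size([4, 4])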
--------------------------------------------------------------------------------
/audioldm/clap/open_clip/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.1"
2 |
--------------------------------------------------------------------------------
/audioldm/clap/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/clap/training/__init__.py
--------------------------------------------------------------------------------
/audioldm/clap/training/audioset_textmap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/clap/training/audioset_textmap.npy
--------------------------------------------------------------------------------
/audioldm/clap/training/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import socket
5 |
6 | try:
7 | import horovod.torch as hvd
8 | except ImportError:
9 | hvd = None
10 |
11 |
12 | def is_global_master(args):
13 | return args.rank == 0
14 |
15 |
16 | def is_local_master(args):
17 | return args.local_rank == 0
18 |
19 |
20 | def is_master(args, local=False):
21 | return is_local_master(args) if local else is_global_master(args)
22 |
23 |
24 | def is_using_horovod():
25 |     # NOTE with a horovod run the OMPI vars should be set, but with SLURM the PMI vars will be set
26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"]
29 | if all([var in os.environ for var in ompi_vars]) or all(
30 | [var in os.environ for var in pmi_vars]
31 | ):
32 | return True
33 | else:
34 | return False
35 |
36 |
37 | def is_using_distributed():
38 | if "WORLD_SIZE" in os.environ:
39 | return int(os.environ["WORLD_SIZE"]) > 1
40 | if "SLURM_NTASKS" in os.environ:
41 | return int(os.environ["SLURM_NTASKS"]) > 1
42 | return False
43 |
44 |
45 | def world_info_from_env():
46 | local_rank = 0
47 | for v in (
48 | "SLURM_LOCALID",
49 | "MPI_LOCALRANKID",
50 | "OMPI_COMM_WORLD_LOCAL_RANK",
51 | "LOCAL_RANK",
52 | ):
53 | if v in os.environ:
54 | local_rank = int(os.environ[v])
55 | break
56 | global_rank = 0
57 | for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"):
58 | if v in os.environ:
59 | global_rank = int(os.environ[v])
60 | break
61 | world_size = 1
62 | for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"):
63 | if v in os.environ:
64 | world_size = int(os.environ[v])
65 | break
66 |
67 | return local_rank, global_rank, world_size
68 |
69 |
70 | def init_distributed_device(args):
71 | # Distributed training = training on more than one GPU.
72 | # Works in both single and multi-node scenarios.
73 | args.distributed = False
74 | args.world_size = 1
75 | args.rank = 0 # global rank
76 | args.local_rank = 0
77 | if args.horovod:
78 | assert hvd is not None, "Horovod is not installed"
79 | hvd.init()
80 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
81 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
82 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
83 | args.local_rank = local_rank
84 | args.rank = world_rank
85 | args.world_size = world_size
86 | # args.local_rank = int(hvd.local_rank())
87 | # args.rank = hvd.rank()
88 | # args.world_size = hvd.size()
89 | args.distributed = True
90 | os.environ["LOCAL_RANK"] = str(args.local_rank)
91 | os.environ["RANK"] = str(args.rank)
92 | os.environ["WORLD_SIZE"] = str(args.world_size)
93 | print(
94 | f"Distributed training: local_rank={args.local_rank}, "
95 | f"rank={args.rank}, world_size={args.world_size}, "
96 | f"hostname={socket.gethostname()}, pid={os.getpid()}"
97 | )
98 | elif is_using_distributed():
99 | if "SLURM_PROCID" in os.environ:
100 | # DDP via SLURM
101 | args.local_rank, args.rank, args.world_size = world_info_from_env()
102 | # SLURM var -> torch.distributed vars in case needed
103 | os.environ["LOCAL_RANK"] = str(args.local_rank)
104 | os.environ["RANK"] = str(args.rank)
105 | os.environ["WORLD_SIZE"] = str(args.world_size)
106 | torch.distributed.init_process_group(
107 | backend=args.dist_backend,
108 | init_method=args.dist_url,
109 | world_size=args.world_size,
110 | rank=args.rank,
111 | )
112 | elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster
113 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
114 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
115 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
116 | args.local_rank = local_rank
117 | args.rank = world_rank
118 | args.world_size = world_size
119 | torch.distributed.init_process_group(
120 | backend=args.dist_backend,
121 | init_method=args.dist_url,
122 | world_size=args.world_size,
123 | rank=args.rank,
124 | )
125 | else:
126 | # DDP via torchrun, torch.distributed.launch
127 | args.local_rank, _, _ = world_info_from_env()
128 | torch.distributed.init_process_group(
129 | backend=args.dist_backend, init_method=args.dist_url
130 | )
131 | args.world_size = torch.distributed.get_world_size()
132 | args.rank = torch.distributed.get_rank()
133 | args.distributed = True
134 | print(
135 | f"Distributed training: local_rank={args.local_rank}, "
136 | f"rank={args.rank}, world_size={args.world_size}, "
137 | f"hostname={socket.gethostname()}, pid={os.getpid()}"
138 | )
139 |
140 | if torch.cuda.is_available():
141 | if args.distributed and not args.no_set_device_rank:
142 | device = "cuda:%d" % args.local_rank
143 | else:
144 | device = "cuda:0"
145 | torch.cuda.set_device(device)
146 | else:
147 | device = "cpu"
148 | args.device = device
149 | device = torch.device(device)
150 | return device
151 |
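As a hedged illustration of what init_distributed_device expects, here is a single-process sketch; this is my example, not the repository's training entry point (which presumably builds args via argparse in training/params.py). With no WORLD_SIZE or SLURM variables set, the function leaves args.distributed as False and simply picks cuda:0 or cpu.

    from types import SimpleNamespace
    # Assumes init_distributed_device from this module is in scope.

    args = SimpleNamespace(
        horovod=False,            # attributes read by init_distributed_device
        dist_backend="nccl",
        dist_url="env://",
        no_set_device_rank=False,
    )
    device = init_distributed_device(args)
    print(device, args.distributed, args.world_size, args.rank)   # e.g. cpu False 1 0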
--------------------------------------------------------------------------------
/audioldm/clap/training/infer_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import os
4 | import torch
5 | import librosa
6 | from open_clip import create_model
7 | from training.data import get_audio_features
8 | from training.data import int16_to_float32, float32_to_int16
9 | from transformers import RobertaTokenizer
10 |
11 | tokenize = RobertaTokenizer.from_pretrained("roberta-base")
12 |
13 |
14 | def tokenizer(text):
15 | result = tokenize(
16 | text,
17 | padding="max_length",
18 | truncation=True,
19 | max_length=77,
20 | return_tensors="pt",
21 | )
22 | return {k: v.squeeze(0) for k, v in result.items()}
23 |
24 |
25 | PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt"
26 | WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav"
27 |
28 |
29 | def infer_text():
30 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
31 | precision = "fp32"
32 | amodel = "HTSAT-tiny" # or 'PANN-14'
33 | tmodel = "roberta" # the best text encoder in our training
34 | enable_fusion = False # False if you do not want to use the fusion model
35 | fusion_type = "aff_2d"
36 | pretrained = PRETRAINED_PATH
37 |
38 | model, model_cfg = create_model(
39 | amodel,
40 | tmodel,
41 | pretrained,
42 | precision=precision,
43 | device=device,
44 | enable_fusion=enable_fusion,
45 | fusion_type=fusion_type,
46 | )
47 |     # load the text; it can be a list (i.e. a batch)
48 | text_data = ["I love the contrastive learning", "I love the pretrain model"]
49 | # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90
50 | text_data = tokenizer(text_data)
51 |
52 | text_embed = model.get_text_embedding(text_data)
53 | print(text_embed.size())
54 |
55 |
56 | def infer_audio():
57 |
58 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
59 | precision = "fp32"
60 | amodel = "HTSAT-tiny" # or 'PANN-14'
61 | tmodel = "roberta" # the best text encoder in our training
62 | enable_fusion = False # False if you do not want to use the fusion model
63 | fusion_type = "aff_2d"
64 | pretrained = PRETRAINED_PATH
65 |
66 | model, model_cfg = create_model(
67 | amodel,
68 | tmodel,
69 | pretrained,
70 | precision=precision,
71 | device=device,
72 | enable_fusion=enable_fusion,
73 | fusion_type=fusion_type,
74 | )
75 |
76 |     # load the waveform of shape (T,); it should be resampled to 48000 Hz
77 | audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000)
78 | # quantize
79 | audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
80 | audio_waveform = torch.from_numpy(audio_waveform).float()
81 | audio_dict = {}
82 |
83 |     # the 'fusion' truncation mode can be changed to 'rand_trunc' when running without the fusion model
84 | import ipdb
85 |
86 | ipdb.set_trace()
87 | audio_dict = get_audio_features(
88 | audio_dict,
89 | audio_waveform,
90 | 480000,
91 | data_truncating="fusion",
92 | data_filling="repeatpad",
93 | audio_cfg=model_cfg["audio_cfg"],
94 | )
95 |     # a list can be sent to the model to process many audio tracks at once (i.e. a batch)
96 | audio_embed = model.get_audio_embedding([audio_dict])
97 | print(audio_embed.size())
98 | import ipdb
99 |
100 | ipdb.set_trace()
101 |
102 |
103 | if __name__ == "__main__":
104 | infer_text()
105 | infer_audio()
106 |
--------------------------------------------------------------------------------
/audioldm/clap/training/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def setup_logging(log_file, level, include_host=False):
5 | if include_host:
6 | import socket
7 |
8 | hostname = socket.gethostname()
9 | formatter = logging.Formatter(
10 | f"%(asctime)s | {hostname} | %(levelname)s | %(message)s",
11 | datefmt="%Y-%m-%d,%H:%M:%S",
12 | )
13 | else:
14 | formatter = logging.Formatter(
15 | "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S"
16 | )
17 |
18 | logging.root.setLevel(level)
19 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
20 | for logger in loggers:
21 | logger.setLevel(level)
22 |
23 | stream_handler = logging.StreamHandler()
24 | stream_handler.setFormatter(formatter)
25 | logging.root.addHandler(stream_handler)
26 |
27 | if log_file:
28 | file_handler = logging.FileHandler(filename=log_file)
29 | file_handler.setFormatter(formatter)
30 | logging.root.addHandler(file_handler)
31 |
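A minimal usage sketch (my addition, not part of the repository): configure the root logger once, then use the standard logging calls everywhere else.

    import logging
    # Assumes setup_logging from this module is in scope.

    setup_logging("train.log", logging.INFO)
    logging.info("starting run")   # written to stderr and appended to train.log with the timestamped format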
--------------------------------------------------------------------------------
/audioldm/clap/training/lp_train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import math
4 | import os
5 | import time
6 | from contextlib import suppress
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | try:
13 | import wandb
14 | except ImportError:
15 | wandb = None
16 |
17 | from open_clip import LPLoss, LPMetrics, lp_gather_features
18 | from open_clip.utils import do_mixup, get_mix_lambda
19 | from .distributed import is_master
20 | from .zero_shot import zero_shot_eval
21 |
22 |
23 | class AverageMeter(object):
24 | """Computes and stores the average and current value"""
25 |
26 | def __init__(self):
27 | self.reset()
28 |
29 | def reset(self):
30 | self.val = 0
31 | self.avg = 0
32 | self.sum = 0
33 | self.count = 0
34 |
35 | def update(self, val, n=1):
36 | self.val = val
37 | self.sum += val * n
38 | self.count += n
39 | self.avg = self.sum / self.count
40 |
41 |
42 | def unwrap_model(model):
43 | if hasattr(model, "module"):
44 | return model.module
45 | else:
46 | return model
47 |
48 |
49 | def train_one_epoch(
50 | model,
51 | data,
52 | epoch,
53 | optimizer,
54 | scaler,
55 | scheduler,
56 | args,
57 | tb_writer=None,
58 | extra_suffix="",
59 | ):
60 | device = torch.device(args.device)
61 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
62 | model.train()
63 | loss = LPLoss(args.lp_loss)
64 |
65 | dataloader, sampler = data["train"].dataloader, data["train"].sampler
66 | if args.distributed and sampler is not None:
67 | sampler.set_epoch(epoch)
68 | num_batches_per_epoch = dataloader.num_batches
69 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
70 |
71 | # for toy dataset
72 | if args.dataset_type == "toy":
73 | dataloader.dataset.generate_queue()
74 |
75 | loss_m = AverageMeter()
76 | batch_time_m = AverageMeter()
77 | data_time_m = AverageMeter()
78 | end = time.time()
79 |
80 | for i, batch in enumerate(dataloader):
81 | step = num_batches_per_epoch * epoch + i
82 |
83 | if isinstance(scheduler, dict):
84 | for s in scheduler.values():
85 | s(step)
86 | else:
87 | scheduler(step)
88 |
89 | audio = batch # contains mel_spec, wavform, and longer list
90 | class_label = batch["class_label"]
91 | # audio = audio.to(device=device, non_blocking=True)
92 | class_label = class_label.to(device=device, non_blocking=True)
93 |
94 | if args.mixup:
95 | # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146
96 | mix_lambda = torch.from_numpy(
97 | get_mix_lambda(0.5, len(audio["waveform"]))
98 | ).to(device)
99 | class_label = do_mixup(class_label, mix_lambda)
100 | else:
101 | mix_lambda = None
102 |
103 | data_time_m.update(time.time() - end)
104 | if isinstance(optimizer, dict):
105 | for o_ in optimizer.values():
106 | o_.zero_grad()
107 | else:
108 | optimizer.zero_grad()
109 |
110 | with autocast():
111 | pred = model(audio, mix_lambda=mix_lambda, device=device)
112 | total_loss = loss(pred, class_label)
113 |
114 | if isinstance(optimizer, dict):
115 | if scaler is not None:
116 | scaler.scale(total_loss).backward()
117 | for o_ in optimizer.values():
118 | if args.horovod:
119 | o_.synchronize()
120 | scaler.unscale_(o_)
121 | with o_.skip_synchronize():
122 | scaler.step(o_)
123 | else:
124 | scaler.step(o_)
125 | scaler.update()
126 | else:
127 | total_loss.backward()
128 | for o_ in optimizer.values():
129 | o_.step()
130 | else:
131 | if scaler is not None:
132 | scaler.scale(total_loss).backward()
133 | if args.horovod:
134 | optimizer.synchronize()
135 | scaler.unscale_(optimizer)
136 | with optimizer.skip_synchronize():
137 | scaler.step(optimizer)
138 | else:
139 | scaler.step(optimizer)
140 | scaler.update()
141 | else:
142 | total_loss.backward()
143 | optimizer.step()
144 |
145 | # Note: we clamp to 4.6052 = ln(100), as in the original paper.
146 | with torch.no_grad():
147 | unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100))
148 | unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100))
149 |
150 | batch_time_m.update(time.time() - end)
151 | end = time.time()
152 | batch_count = i + 1
153 |
154 | if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
155 | if isinstance(audio, dict):
156 | batch_size = len(audio["waveform"])
157 | else:
158 | batch_size = len(audio)
159 | num_samples = batch_count * batch_size * args.world_size
160 | samples_per_epoch = dataloader.num_samples
161 | percent_complete = 100.0 * batch_count / num_batches_per_epoch
162 |
163 | # NOTE loss is coarsely sampled, just master node and per log update
164 | loss_m.update(total_loss.item(), batch_size)
165 | if isinstance(optimizer, dict):
166 | logging.info(
167 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
168 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
169 | f"Data (t): {data_time_m.avg:.3f} "
170 | f"Batch (t): {batch_time_m.avg:.3f} "
171 | f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}"
172 | )
173 | log_data = {
174 | "loss": loss_m.val,
175 | "data_time": data_time_m.val,
176 | "batch_time": batch_time_m.val,
177 | "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
178 | }
179 | else:
180 | logging.info(
181 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
182 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
183 | f"Data (t): {data_time_m.avg:.3f} "
184 | f"Batch (t): {batch_time_m.avg:.3f} "
185 | f"LR: {optimizer.param_groups[0]['lr']:5f} "
186 | )
187 |
188 |             # Save train loss, etc. Use non-averaged meter values, as loggers have their own smoothing
189 | log_data = {
190 | "loss": loss_m.val,
191 | "data_time": data_time_m.val,
192 | "batch_time": batch_time_m.val,
193 | "lr": optimizer.param_groups[0]["lr"],
194 | }
195 | for name, val in log_data.items():
196 | name = f"train{extra_suffix}/{name}"
197 | if tb_writer is not None:
198 | tb_writer.add_scalar(name, val, step)
199 | if args.wandb:
200 | assert wandb is not None, "Please install wandb."
201 | wandb.log({name: val, "step": step})
202 |
203 | # resetting batch / data time meters per log window
204 | batch_time_m.reset()
205 | data_time_m.reset()
206 | # end for
207 |
208 |
209 | def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
210 | metrics = {}
211 | if not args.parallel_eval:
212 | if not is_master(args):
213 | return metrics
214 | device = torch.device(args.device)
215 | model.eval()
216 |
217 | # CHANGE
218 | # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
219 | # metrics.update(zero_shot_metrics)
220 | if is_master(args):
221 | print("Evaluating...")
222 | metric_names = args.lp_metrics.split(",")
223 | eval_tool = LPMetrics(metric_names=metric_names)
224 |
225 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
226 | if "val" in data and (
227 | args.val_frequency
228 | and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
229 | ):
230 | if args.parallel_eval:
231 | dataloader, sampler = data["val"].dataloader, data["val"].sampler
232 | if args.distributed and sampler is not None:
233 | sampler.set_epoch(epoch)
234 | samples_per_val = dataloader.num_samples
235 | else:
236 | dataloader = data["val"].dataloader
237 | num_samples = 0
238 | samples_per_val = dataloader.num_samples
239 |
240 | eval_info = {"pred": [], "target": []}
241 | with torch.no_grad():
242 | for i, batch in enumerate(dataloader):
243 | audio = batch # contains mel_spec, wavform, and longer list
244 | class_label = batch["class_label"]
245 |
246 | # audio = audio.to(device=device, non_blocking=True)
247 | class_label = class_label.to(device=device, non_blocking=True)
248 |
249 | with autocast():
250 | pred = model(audio, device=device)
251 | if args.parallel_eval:
252 | pred, class_label = lp_gather_features(
253 | pred, class_label, args.world_size, args.horovod
254 | )
255 | eval_info["pred"].append(pred)
256 | eval_info["target"].append(class_label)
257 |
258 | num_samples += class_label.shape[0]
259 |
260 | if (i % 100) == 0: # and i != 0:
261 | logging.info(
262 | f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
263 | )
264 |
265 | if is_master(args):
266 | eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu()
267 | eval_info["target"] = torch.cat(eval_info["target"], 0).cpu()
268 | metric_dict = eval_tool.evaluate_mertics(
269 | eval_info["pred"], eval_info["target"]
270 | )
271 | metrics.update(metric_dict)
272 | if "epoch" not in metrics.keys():
273 | metrics.update({"epoch": epoch})
274 |
275 | if is_master(args):
276 | if not metrics:
277 | return metrics
278 |
279 | logging.info(
280 | f"Eval Epoch: {epoch} "
281 | + "\n".join(
282 | ["\t".join([f"{m}: {round(metrics[m], 4):.4f}"]) for m in metrics]
283 | )
284 | )
285 | if args.save_logs:
286 | for name, val in metrics.items():
287 | if tb_writer is not None:
288 | tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch)
289 |
290 | with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
291 | f.write(json.dumps(metrics))
292 | f.write("\n")
293 |
294 | if args.wandb:
295 | assert wandb is not None, "Please install wandb."
296 | for name, val in metrics.items():
297 | wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch})
298 |
299 | return metrics
300 | else:
301 | return metrics
302 |
--------------------------------------------------------------------------------
/audioldm/clap/training/scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def assign_learning_rate(optimizer, new_lr):
5 | for param_group in optimizer.param_groups:
6 | param_group["lr"] = new_lr
7 |
8 |
9 | def _warmup_lr(base_lr, warmup_length, step):
10 | return base_lr * (step + 1) / warmup_length
11 |
12 |
13 | def cosine_lr(optimizer, base_lr, warmup_length, steps):
14 | def _lr_adjuster(step):
15 | if step < warmup_length:
16 | lr = _warmup_lr(base_lr, warmup_length, step)
17 | else:
18 | e = step - warmup_length
19 | es = steps - warmup_length
20 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
21 | assign_learning_rate(optimizer, lr)
22 | return lr
23 |
24 | return _lr_adjuster
25 |
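The scheduler is a plain callable rather than a torch.optim scheduler. Below is a sketch of how it might be driven, assumed from the signature above rather than taken from the training loop: call it once per optimizer step with the global step index.

    import torch
    # Assumes cosine_lr from this module is in scope.

    model = torch.nn.Linear(8, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.0)        # the lr is overwritten by the scheduler
    adjust = cosine_lr(opt, base_lr=1e-3, warmup_length=10, steps=100)

    for step in range(100):
        lr = adjust(step)    # linear warmup for the first 10 steps, then cosine decay towards 0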
--------------------------------------------------------------------------------
/audioldm/clap/training/zero_shot.py:
--------------------------------------------------------------------------------
1 | # NOTE: This script is currently not supported for CLAP.
2 | import logging
3 | from contextlib import suppress
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from tqdm import tqdm
8 |
9 | from open_clip import tokenize
10 | from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
11 |
12 |
13 | def zero_shot_classifier(model, classnames, templates, args):
14 | with torch.no_grad():
15 | zeroshot_weights = []
16 | for classname in tqdm(classnames):
17 | texts = [template(classname) for template in templates] # format with class
18 | texts = tokenize(texts).to(args.device) # tokenize
19 | if args.distributed and not args.horovod:
20 | class_embeddings = model.module.encode_text(texts)
21 | else:
22 | class_embeddings = model.encode_text(texts)
23 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
24 | class_embedding /= class_embedding.norm()
25 | zeroshot_weights.append(class_embedding)
26 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
27 | return zeroshot_weights
28 |
29 |
30 | def accuracy(output, target, topk=(1,)):
31 | pred = output.topk(max(topk), 1, True, True)[1].t()
32 | correct = pred.eq(target.view(1, -1).expand_as(pred))
33 | return [
34 | float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
35 | for k in topk
36 | ]
37 |
38 |
39 | def run(model, classifier, dataloader, args):
40 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
41 | with torch.no_grad():
42 | top1, top5, n = 0.0, 0.0, 0.0
43 | for images, target in tqdm(dataloader, unit_scale=args.batch_size):
44 | images = images.to(args.device)
45 | target = target.to(args.device)
46 |
47 | with autocast():
48 | # predict
49 | if args.distributed and not args.horovod:
50 | image_features = model.module.encode_image(images)
51 | else:
52 | image_features = model.encode_image(images)
53 | image_features = F.normalize(image_features, dim=-1)
54 | logits = 100.0 * image_features @ classifier
55 |
56 | # measure accuracy
57 | acc1, acc5 = accuracy(logits, target, topk=(1, 5))
58 | top1 += acc1
59 | top5 += acc5
60 | n += images.size(0)
61 |
62 | top1 = top1 / n
63 | top5 = top5 / n
64 | return top1, top5
65 |
66 |
67 | def zero_shot_eval(model, data, epoch, args):
68 | if "imagenet-val" not in data and "imagenet-v2" not in data:
69 | return {}
70 | if args.zeroshot_frequency == 0:
71 | return {}
72 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
73 | return {}
74 |
75 | logging.info("Starting zero-shot imagenet.")
76 |
77 | logging.info("Building zero-shot classifier")
78 | classifier = zero_shot_classifier(
79 | model, imagenet_classnames, openai_imagenet_template, args
80 | )
81 |
82 | logging.info("Using classifier")
83 | results = {}
84 | if "imagenet-val" in data:
85 | top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args)
86 | results["imagenet-zeroshot-val-top1"] = top1
87 | results["imagenet-zeroshot-val-top5"] = top5
88 | if "imagenet-v2" in data:
89 | top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args)
90 | results["imagenetv2-zeroshot-val-top1"] = top1
91 | results["imagenetv2-zeroshot-val-top5"] = top5
92 |
93 | logging.info("Finished zero-shot imagenet.")
94 |
95 | return results
96 |
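For reference, accuracy() returns correct-prediction counts, not fractions; run() divides by the number of samples afterwards. A tiny self-contained check (my example, not repository code):

    import torch
    # Assumes accuracy from this module is in scope.

    logits = torch.tensor([[0.1, 0.9, 0.0],
                           [0.8, 0.1, 0.1]])
    target = torch.tensor([1, 0])
    top1, = accuracy(logits, target, topk=(1,))
    print(top1)   # 2.0: both samples are correct at top-1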
--------------------------------------------------------------------------------
/audioldm/hifigan/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import Generator
2 |
3 |
4 | class AttrDict(dict):
5 | def __init__(self, *args, **kwargs):
6 | super(AttrDict, self).__init__(*args, **kwargs)
7 | self.__dict__ = self
8 |
--------------------------------------------------------------------------------
/audioldm/hifigan/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Conv1d, ConvTranspose1d
5 | from torch.nn.utils import weight_norm, remove_weight_norm
6 |
7 | LRELU_SLOPE = 0.1
8 |
9 |
10 | def init_weights(m, mean=0.0, std=0.01):
11 | classname = m.__class__.__name__
12 | if classname.find("Conv") != -1:
13 | m.weight.data.normal_(mean, std)
14 |
15 |
16 | def get_padding(kernel_size, dilation=1):
17 | return int((kernel_size * dilation - dilation) / 2)
18 |
19 |
20 | class ResBlock(torch.nn.Module):
21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22 | super(ResBlock, self).__init__()
23 | self.h = h
24 | self.convs1 = nn.ModuleList(
25 | [
26 | weight_norm(
27 | Conv1d(
28 | channels,
29 | channels,
30 | kernel_size,
31 | 1,
32 | dilation=dilation[0],
33 | padding=get_padding(kernel_size, dilation[0]),
34 | )
35 | ),
36 | weight_norm(
37 | Conv1d(
38 | channels,
39 | channels,
40 | kernel_size,
41 | 1,
42 | dilation=dilation[1],
43 | padding=get_padding(kernel_size, dilation[1]),
44 | )
45 | ),
46 | weight_norm(
47 | Conv1d(
48 | channels,
49 | channels,
50 | kernel_size,
51 | 1,
52 | dilation=dilation[2],
53 | padding=get_padding(kernel_size, dilation[2]),
54 | )
55 | ),
56 | ]
57 | )
58 | self.convs1.apply(init_weights)
59 |
60 | self.convs2 = nn.ModuleList(
61 | [
62 | weight_norm(
63 | Conv1d(
64 | channels,
65 | channels,
66 | kernel_size,
67 | 1,
68 | dilation=1,
69 | padding=get_padding(kernel_size, 1),
70 | )
71 | ),
72 | weight_norm(
73 | Conv1d(
74 | channels,
75 | channels,
76 | kernel_size,
77 | 1,
78 | dilation=1,
79 | padding=get_padding(kernel_size, 1),
80 | )
81 | ),
82 | weight_norm(
83 | Conv1d(
84 | channels,
85 | channels,
86 | kernel_size,
87 | 1,
88 | dilation=1,
89 | padding=get_padding(kernel_size, 1),
90 | )
91 | ),
92 | ]
93 | )
94 | self.convs2.apply(init_weights)
95 |
96 | def forward(self, x):
97 | for c1, c2 in zip(self.convs1, self.convs2):
98 | xt = F.leaky_relu(x, LRELU_SLOPE)
99 | xt = c1(xt)
100 | xt = F.leaky_relu(xt, LRELU_SLOPE)
101 | xt = c2(xt)
102 | x = xt + x
103 | return x
104 |
105 | def remove_weight_norm(self):
106 | for l in self.convs1:
107 | remove_weight_norm(l)
108 | for l in self.convs2:
109 | remove_weight_norm(l)
110 |
111 |
112 | class Generator(torch.nn.Module):
113 | def __init__(self, h):
114 | super(Generator, self).__init__()
115 | self.h = h
116 | self.num_kernels = len(h.resblock_kernel_sizes)
117 | self.num_upsamples = len(h.upsample_rates)
118 | self.conv_pre = weight_norm(
119 | Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
120 | )
121 | resblock = ResBlock
122 |
123 | self.ups = nn.ModuleList()
124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125 | self.ups.append(
126 | weight_norm(
127 | ConvTranspose1d(
128 | h.upsample_initial_channel // (2**i),
129 | h.upsample_initial_channel // (2 ** (i + 1)),
130 | k,
131 | u,
132 | padding=(k - u) // 2,
133 | )
134 | )
135 | )
136 |
137 | self.resblocks = nn.ModuleList()
138 | for i in range(len(self.ups)):
139 | ch = h.upsample_initial_channel // (2 ** (i + 1))
140 | for j, (k, d) in enumerate(
141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142 | ):
143 | self.resblocks.append(resblock(h, ch, k, d))
144 |
145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146 | self.ups.apply(init_weights)
147 | self.conv_post.apply(init_weights)
148 |
149 | def forward(self, x):
150 | x = self.conv_pre(x)
151 | for i in range(self.num_upsamples):
152 | x = F.leaky_relu(x, LRELU_SLOPE)
153 | x = self.ups[i](x)
154 | xs = None
155 | for j in range(self.num_kernels):
156 | if xs is None:
157 | xs = self.resblocks[i * self.num_kernels + j](x)
158 | else:
159 | xs += self.resblocks[i * self.num_kernels + j](x)
160 | x = xs / self.num_kernels
161 | x = F.leaky_relu(x)
162 | x = self.conv_post(x)
163 | x = torch.tanh(x)
164 |
165 | return x
166 |
167 | def remove_weight_norm(self):
168 | # print("Removing weight norm...")
169 | for l in self.ups:
170 | remove_weight_norm(l)
171 | for l in self.resblocks:
172 | l.remove_weight_norm()
173 | remove_weight_norm(self.conv_pre)
174 | remove_weight_norm(self.conv_post)
175 |
--------------------------------------------------------------------------------
/audioldm/hifigan/utilities.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | import numpy as np
6 |
7 | import audioldm.hifigan as hifigan
8 |
9 | HIFIGAN_16K_64 = {
10 | "resblock": "1",
11 | "num_gpus": 6,
12 | "batch_size": 16,
13 | "learning_rate": 0.0002,
14 | "adam_b1": 0.8,
15 | "adam_b2": 0.99,
16 | "lr_decay": 0.999,
17 | "seed": 1234,
18 | "upsample_rates": [5, 4, 2, 2, 2],
19 | "upsample_kernel_sizes": [16, 16, 8, 4, 4],
20 | "upsample_initial_channel": 1024,
21 | "resblock_kernel_sizes": [3, 7, 11],
22 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
23 | "segment_size": 8192,
24 | "num_mels": 64,
25 | "num_freq": 1025,
26 | "n_fft": 1024,
27 | "hop_size": 160,
28 | "win_size": 1024,
29 | "sampling_rate": 16000,
30 | "fmin": 0,
31 | "fmax": 8000,
32 | "fmax_for_loss": None,
33 | "num_workers": 4,
34 | "dist_config": {
35 | "dist_backend": "nccl",
36 | "dist_url": "tcp://localhost:54321",
37 | "world_size": 1,
38 | },
39 | }
40 |
41 |
42 | def get_available_checkpoint_keys(model, ckpt):
43 |     print("==> Attempt to reload from %s" % ckpt)
44 | state_dict = torch.load(ckpt)["state_dict"]
45 | current_state_dict = model.state_dict()
46 | new_state_dict = {}
47 | for k in state_dict.keys():
48 | if (
49 | k in current_state_dict.keys()
50 | and current_state_dict[k].size() == state_dict[k].size()
51 | ):
52 | new_state_dict[k] = state_dict[k]
53 | else:
54 | print("==> WARNING: Skipping %s" % k)
55 | print(
56 | "%s out of %s keys are matched"
57 | % (len(new_state_dict.keys()), len(state_dict.keys()))
58 | )
59 | return new_state_dict
60 |
61 |
62 | def get_param_num(model):
63 | num_param = sum(param.numel() for param in model.parameters())
64 | return num_param
65 |
66 |
67 | def get_vocoder(config, device):
68 | config = hifigan.AttrDict(HIFIGAN_16K_64)
69 | vocoder = hifigan.Generator(config)
70 | vocoder.eval()
71 | vocoder.remove_weight_norm()
72 | vocoder.to(device)
73 | return vocoder
74 |
75 |
76 | def vocoder_infer(mels, vocoder, lengths=None):
77 | with torch.no_grad():
78 | wavs = vocoder(mels).squeeze(1)
79 |
80 | wavs = (wavs.cpu().numpy() * 32768).astype("int16")
81 |
82 | if lengths is not None:
83 | wavs = wavs[:, :lengths]
84 |
85 | return wavs
86 |
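A short sketch of the vocoder path (an illustration only: with the randomly initialised Generator below the output is noise, since trained HiFi-GAN weights are presumably loaded from a checkpoint elsewhere in the repository). Mel spectrograms of shape (batch, 64, frames) go in; int16 waveforms at 16 kHz come out, 160 samples per frame.

    import torch
    # Assumes get_vocoder and vocoder_infer from this module are in scope.

    vocoder = get_vocoder(None, "cpu")        # the config argument is ignored; HIFIGAN_16K_64 is used
    mel = torch.randn(1, 64, 100)             # (batch, num_mels, frames)
    wav = vocoder_infer(mel, vocoder)         # int16 numpy array
    print(wav.shape)                          # (1, 16000): 100 frames * hop_size 160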
--------------------------------------------------------------------------------
/audioldm/latent_diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/latent_diffusion/__init__.py
--------------------------------------------------------------------------------
/audioldm/latent_diffusion/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError("Decay must be between 0 and 1")
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | "num_updates",
15 | torch.tensor(0, dtype=torch.int)
16 | if use_num_upates
17 | else torch.tensor(-1, dtype=torch.int),
18 | )
19 |
20 | for name, p in model.named_parameters():
21 | if p.requires_grad:
22 | # remove as '.'-character is not allowed in buffers
23 | s_name = name.replace(".", "")
24 | self.m_name2s_name.update({name: s_name})
25 | self.register_buffer(s_name, p.clone().detach().data)
26 |
27 | self.collected_params = []
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(
47 | one_minus_decay * (shadow_params[sname] - m_param[key])
48 | )
49 | else:
50 | assert not key in self.m_name2s_name
51 |
52 | def copy_to(self, model):
53 | m_param = dict(model.named_parameters())
54 | shadow_params = dict(self.named_buffers())
55 | for key in m_param:
56 | if m_param[key].requires_grad:
57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
58 | else:
59 | assert not key in self.m_name2s_name
60 |
61 | def store(self, parameters):
62 | """
63 | Save the current parameters for restoring later.
64 | Args:
65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66 | temporarily stored.
67 | """
68 | self.collected_params = [param.clone() for param in parameters]
69 |
70 | def restore(self, parameters):
71 | """
72 | Restore the parameters stored with the `store` method.
73 | Useful to validate the model with EMA parameters without affecting the
74 | original optimization process. Store the parameters before the
75 | `copy_to` method. After validation (or model saving), use this to
76 | restore the former parameters.
77 | Args:
78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
79 | updated with the stored parameters.
80 | """
81 | for c_param, param in zip(self.collected_params, parameters):
82 | param.data.copy_(c_param.data)
83 |
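A minimal sketch of the intended store / copy_to / restore cycle, assumed from the docstrings above rather than taken from the training code: update the shadow weights after each optimizer step, swap them in for evaluation, then restore the raw weights.

    import torch
    from torch import nn
    # Assumes LitEma from this module is in scope.

    model = nn.Linear(4, 4)
    ema = LitEma(model, decay=0.999)

    # ... after every optimizer step:
    ema(model)                          # update the exponential moving average of the weights

    # evaluate with the EMA weights, then switch back:
    ema.store(model.parameters())
    ema.copy_to(model)                  # model now holds the EMA weights
    # ... run validation ...
    ema.restore(model.parameters())     # back to the raw training weights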
--------------------------------------------------------------------------------
/audioldm/latent_diffusion/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from audioldm.utils import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(
22 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
23 | ):
24 | if schedule == "linear":
25 | betas = (
26 | torch.linspace(
27 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
28 | )
29 | ** 2
30 | )
31 |
32 | elif schedule == "cosine":
33 | timesteps = (
34 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
35 | )
36 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
37 | alphas = torch.cos(alphas).pow(2)
38 | alphas = alphas / alphas[0]
39 | betas = 1 - alphas[1:] / alphas[:-1]
40 | betas = np.clip(betas, a_min=0, a_max=0.999)
41 |
42 | elif schedule == "sqrt_linear":
43 | betas = torch.linspace(
44 | linear_start, linear_end, n_timestep, dtype=torch.float64
45 | )
46 | elif schedule == "sqrt":
47 | betas = (
48 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
49 | ** 0.5
50 | )
51 | else:
52 | raise ValueError(f"schedule '{schedule}' unknown.")
53 | return betas.numpy()
54 |
55 |
56 | def make_ddim_timesteps(
57 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
58 | ):
59 | if ddim_discr_method == "uniform":
60 | c = num_ddpm_timesteps // num_ddim_timesteps
61 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
62 | elif ddim_discr_method == "quad":
63 | ddim_timesteps = (
64 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
65 | ).astype(int)
66 | else:
67 | raise NotImplementedError(
68 | f'There is no ddim discretization method called "{ddim_discr_method}"'
69 | )
70 |
71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
72 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
73 | steps_out = ddim_timesteps + 1
74 | if verbose:
75 | print(f"Selected timesteps for ddim sampler: {steps_out}")
76 | return steps_out
77 |
78 |
79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
80 | # select alphas for computing the variance schedule
81 | alphas = alphacums[ddim_timesteps]
82 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
83 |
84 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
85 | sigmas = eta * np.sqrt(
86 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
87 | )
88 | if verbose:
89 | print(
90 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
91 | )
92 | print(
93 | f"For the chosen value of eta, which is {eta}, "
94 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
95 | )
96 | return sigmas, alphas, alphas_prev
97 |
98 |
99 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
100 | """
101 | Create a beta schedule that discretizes the given alpha_t_bar function,
102 | which defines the cumulative product of (1-beta) over time from t = [0,1].
103 | :param num_diffusion_timesteps: the number of betas to produce.
104 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
105 | produces the cumulative product of (1-beta) up to that
106 | part of the diffusion process.
107 | :param max_beta: the maximum beta to use; use values lower than 1 to
108 | prevent singularities.
109 | """
110 | betas = []
111 | for i in range(num_diffusion_timesteps):
112 | t1 = i / num_diffusion_timesteps
113 | t2 = (i + 1) / num_diffusion_timesteps
114 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
115 | return np.array(betas)
116 |
117 |
118 | def extract_into_tensor(a, t, x_shape):
119 | b, *_ = t.shape
120 | out = a.gather(-1, t).contiguous()
121 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 | :param func: the function to evaluate.
129 | :param inputs: the argument sequence to pass to `func`.
130 | :param params: a sequence of parameters `func` depends on but does not
131 | explicitly take as arguments.
132 | :param flag: if False, disable gradient checkpointing.
133 | """
134 | if flag:
135 | args = tuple(inputs) + tuple(params)
136 | return CheckpointFunction.apply(func, len(inputs), *args)
137 | else:
138 | return func(*inputs)
139 |
140 |
141 | class CheckpointFunction(torch.autograd.Function):
142 | @staticmethod
143 | def forward(ctx, run_function, length, *args):
144 | ctx.run_function = run_function
145 | ctx.input_tensors = list(args[:length])
146 | ctx.input_params = list(args[length:])
147 |
148 | with torch.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with torch.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = torch.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
172 |
173 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
174 | """
175 | Create sinusoidal timestep embeddings.
176 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
177 | These may be fractional.
178 | :param dim: the dimension of the output.
179 | :param max_period: controls the minimum frequency of the embeddings.
180 | :return: an [N x dim] Tensor of positional embeddings.
181 | """
182 | if not repeat_only:
183 | half = dim // 2
184 | freqs = torch.exp(
185 | -math.log(max_period)
186 | * torch.arange(start=0, end=half, dtype=torch.float32)
187 | / half
188 | ).to(device=timesteps.device)
189 | args = timesteps[:, None].float() * freqs[None]
190 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
191 | if dim % 2:
192 | embedding = torch.cat(
193 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
194 | )
195 | else:
196 | embedding = repeat(timesteps, "b -> b d", d=dim)
197 | return embedding
198 |
199 |
200 | def zero_module(module):
201 | """
202 | Zero out the parameters of a module and return it.
203 | """
204 | for p in module.parameters():
205 | p.detach().zero_()
206 | return module
207 |
208 |
209 | def scale_module(module, scale):
210 | """
211 | Scale the parameters of a module and return it.
212 | """
213 | for p in module.parameters():
214 | p.detach().mul_(scale)
215 | return module
216 |
217 |
218 | def mean_flat(tensor):
219 | """
220 | Take the mean over all non-batch dimensions.
221 | """
222 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
223 |
224 |
225 | def normalization(channels):
226 | """
227 | Make a standard normalization layer.
228 | :param channels: number of input channels.
229 | :return: an nn.Module for normalization.
230 | """
231 | return GroupNorm32(32, channels)
232 |
233 |
234 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
235 | class SiLU(nn.Module):
236 | def forward(self, x):
237 | return x * torch.sigmoid(x)
238 |
239 |
240 | class GroupNorm32(nn.GroupNorm):
241 | def forward(self, x):
242 | return super().forward(x.float()).type(x.dtype)
243 |
244 |
245 | def conv_nd(dims, *args, **kwargs):
246 | """
247 | Create a 1D, 2D, or 3D convolution module.
248 | """
249 | if dims == 1:
250 | return nn.Conv1d(*args, **kwargs)
251 | elif dims == 2:
252 | return nn.Conv2d(*args, **kwargs)
253 | elif dims == 3:
254 | return nn.Conv3d(*args, **kwargs)
255 | raise ValueError(f"unsupported dimensions: {dims}")
256 |
257 |
258 | def linear(*args, **kwargs):
259 | """
260 | Create a linear module.
261 | """
262 | return nn.Linear(*args, **kwargs)
263 |
264 |
265 | def avg_pool_nd(dims, *args, **kwargs):
266 | """
267 | Create a 1D, 2D, or 3D average pooling module.
268 | """
269 | if dims == 1:
270 | return nn.AvgPool1d(*args, **kwargs)
271 | elif dims == 2:
272 | return nn.AvgPool2d(*args, **kwargs)
273 | elif dims == 3:
274 | return nn.AvgPool3d(*args, **kwargs)
275 | raise ValueError(f"unsupported dimensions: {dims}")
276 |
277 |
278 | class HybridConditioner(nn.Module):
279 | def __init__(self, c_concat_config, c_crossattn_config):
280 | super().__init__()
281 | self.concat_conditioner = instantiate_from_config(c_concat_config)
282 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
283 |
284 | def forward(self, c_concat, c_crossattn):
285 | c_concat = self.concat_conditioner(c_concat)
286 | c_crossattn = self.crossattn_conditioner(c_crossattn)
287 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
288 |
289 |
290 | def noise_like(shape, device, repeat=False):
291 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
292 | shape[0], *((1,) * (len(shape) - 1))
293 | )
294 | noise = lambda: torch.randn(shape, device=device)
295 | return repeat_noise() if repeat else noise()
296 |
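Two of the helpers above in isolation (a small sketch, assuming the module's imports resolve): timestep_embedding maps integer timesteps to sinusoidal vectors, and make_beta_schedule produces the noise schedule as a numpy array.

    import torch
    # Assumes timestep_embedding and make_beta_schedule from this module are in scope.

    t = torch.tensor([0, 250, 999])                         # diffusion timestep indices
    emb = timestep_embedding(t, dim=128)                    # shape (3, 128)

    betas = make_beta_schedule("linear", n_timestep=1000)   # numpy array of shape (1000,)
    print(emb.shape, betas.shape)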
--------------------------------------------------------------------------------
/audioldm/pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import argparse
4 | import yaml
5 | import torch
6 | from torch import autocast
7 | from tqdm import tqdm, trange
8 |
9 | from audioldm import LatentDiffusion, seed_everything
10 | from audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint
11 | from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file
12 | from audioldm.latent_diffusion.ddim import DDIMSampler
13 | from einops import repeat
14 | import os
15 |
16 | def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
17 |     if batchsize < 1:
18 |         print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
19 |         batchsize = 1
20 |     text = [text] * batchsize
21 | if(fbank is None):
22 | fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format
23 | else:
24 | fbank = torch.FloatTensor(fbank)
25 | fbank = fbank.expand(batchsize, 1024, 64)
26 | assert fbank.size(0) == batchsize
27 |
28 | stft = torch.zeros((batchsize, 1024, 512)) # Not used
29 |
30 | if(waveform is None):
31 | waveform = torch.zeros((batchsize, 160000)) # Not used
32 | else:
33 | waveform = torch.FloatTensor(waveform)
34 | waveform = waveform.expand(batchsize, -1)
35 | assert waveform.size(0) == batchsize
36 |
37 | fname = [""] * batchsize # Not used
38 |
39 | batch = (
40 | fbank,
41 | stft,
42 | None,
43 | fname,
44 | waveform,
45 | text,
46 | )
47 | return batch
48 |
49 | def round_up_duration(duration):
50 | return int(round(duration/2.5) + 1) * 2.5
51 |
52 | def build_model(
53 | ckpt_path=None,
54 | config=None,
55 | model_name="audioldm-s-full"
56 | ):
57 |     print("Load AudioLDM: %s" % model_name)
58 |
59 | if(ckpt_path is None):
60 | ckpt_path = get_metadata()[model_name]["path"]
61 |
62 | if(not os.path.exists(ckpt_path)):
63 | download_checkpoint(model_name)
64 |
65 | if torch.cuda.is_available():
66 | device = torch.device("cuda:0")
67 | else:
68 | device = torch.device("cpu")
69 |
70 | if config is not None:
71 | assert type(config) is str
72 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
73 | else:
74 | config = default_audioldm_config(model_name)
75 |
76 | # Use text as condition instead of using waveform during training
77 | config["model"]["params"]["device"] = device
78 | config["model"]["params"]["cond_stage_key"] = "text"
79 |
80 | # No normalization here
81 | latent_diffusion = LatentDiffusion(**config["model"]["params"])
82 |
83 | resume_from_checkpoint = ckpt_path
84 |
85 | checkpoint = torch.load(resume_from_checkpoint, map_location=device)
86 |     '''Original code. Bug: an unexpected key "cond_stage_model.model.text_branch.embeddings.position_ids" exists in the checkpoint file.'''
87 |     # latent_diffusion.load_state_dict(checkpoint["state_dict"])
88 |     '''2023.10.17: fixed by setting the parameter "strict" to False so that the unexpected key is ignored.'''
89 |     latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)
90 |
91 | latent_diffusion.eval()
92 | latent_diffusion = latent_diffusion.to(device)
93 |
94 | latent_diffusion.cond_stage_model.embed_mode = "text"
95 | return latent_diffusion
96 |
97 | def duration_to_latent_t_size(duration):
98 | return int(duration * 25.6)
99 |
100 | def set_cond_audio(latent_diffusion):
101 | latent_diffusion.cond_stage_key = "waveform"
102 | latent_diffusion.cond_stage_model.embed_mode="audio"
103 | return latent_diffusion
104 |
105 | def set_cond_text(latent_diffusion):
106 | latent_diffusion.cond_stage_key = "text"
107 | latent_diffusion.cond_stage_model.embed_mode="text"
108 | return latent_diffusion
109 |
110 | def text_to_audio(
111 | latent_diffusion,
112 | text,
113 | original_audio_file_path = None,
114 | seed=42,
115 | ddim_steps=200,
116 | duration=10,
117 | batchsize=1,
118 | guidance_scale=2.5,
119 | n_candidate_gen_per_text=3,
120 | config=None,
121 | ):
122 | seed_everything(int(seed))
123 | waveform = None
124 | if(original_audio_file_path is not None):
125 | waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160)
126 |
127 | batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)
128 |
129 | latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
130 |
131 | if(waveform is not None):
132 |         print("Generate audio that has similar content to %s" % original_audio_file_path)
133 | latent_diffusion = set_cond_audio(latent_diffusion)
134 | else:
135 | print("Generate audio using text %s" % text)
136 | latent_diffusion = set_cond_text(latent_diffusion)
137 |
138 | with torch.no_grad():
139 | waveform = latent_diffusion.generate_sample(
140 | [batch],
141 | unconditional_guidance_scale=guidance_scale,
142 | ddim_steps=ddim_steps,
143 | n_candidate_gen_per_text=n_candidate_gen_per_text,
144 | duration=duration,
145 | )
146 | return waveform
147 |
148 | def style_transfer(
149 | latent_diffusion,
150 | text,
151 | original_audio_file_path,
152 | transfer_strength,
153 | seed=42,
154 | duration=10,
155 | batchsize=1,
156 | guidance_scale=2.5,
157 | ddim_steps=200,
158 | config=None,
159 | ):
160 | if torch.cuda.is_available():
161 | device = torch.device("cuda:0")
162 | else:
163 | device = torch.device("cpu")
164 |
165 | assert original_audio_file_path is not None, "You need to provide the original audio file path"
166 |
167 | audio_file_duration = get_duration(original_audio_file_path)
168 |
169 | assert get_bit_depth(original_audio_file_path) == 16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path
170 |
171 | # if(duration > 20):
172 | # print("Warning: The duration of the audio file %s must be less than 20 seconds. Longer duration will result in Nan in model output (we are still debugging that); Automatically set duration to 20 seconds")
173 | # duration = 20
174 |
175 | if(duration > audio_file_duration):
176 |         print("Warning: The duration you specified (%s seconds) must be equal to or smaller than the audio file duration (%s seconds)" % (duration, audio_file_duration))
177 | duration = round_up_duration(audio_file_duration)
178 | print("Set new duration as %s-seconds" % duration)
179 |
180 | # duration = round_up_duration(duration)
181 |
182 | latent_diffusion = set_cond_text(latent_diffusion)
183 |
184 | if config is not None:
185 | assert type(config) is str
186 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
187 | else:
188 | config = default_audioldm_config()
189 |
190 | seed_everything(int(seed))
191 | # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
192 | latent_diffusion.cond_stage_model.embed_mode = "text"
193 |
194 | fn_STFT = TacotronSTFT(
195 | config["preprocessing"]["stft"]["filter_length"],
196 | config["preprocessing"]["stft"]["hop_length"],
197 | config["preprocessing"]["stft"]["win_length"],
198 | config["preprocessing"]["mel"]["n_mel_channels"],
199 | config["preprocessing"]["audio"]["sampling_rate"],
200 | config["preprocessing"]["mel"]["mel_fmin"],
201 | config["preprocessing"]["mel"]["mel_fmax"],
202 | )
203 |
204 | mel, _, _ = wav_to_fbank(
205 | original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
206 | )
207 | mel = mel.unsqueeze(0).unsqueeze(0).to(device)
208 | mel = repeat(mel, "1 ... -> b ...", b=batchsize)
209 | init_latent = latent_diffusion.get_first_stage_encoding(
210 | latent_diffusion.encode_first_stage(mel)
211 | ) # move to latent space, encode and sample
212 | if(torch.max(torch.abs(init_latent)) > 1e2):
213 | init_latent = torch.clip(init_latent, min=-10, max=10)
214 | sampler = DDIMSampler(latent_diffusion)
215 | sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False)
216 |
217 | t_enc = int(transfer_strength * ddim_steps)
218 | prompts = text
219 |
220 | with torch.no_grad():
221 | with autocast("cuda"):
222 | with latent_diffusion.ema_scope():
223 | uc = None
224 | if guidance_scale != 1.0:
225 | uc = latent_diffusion.cond_stage_model.get_unconditional_condition(
226 | batchsize
227 | )
228 |
229 | c = latent_diffusion.get_learned_conditioning([prompts] * batchsize)
230 | z_enc = sampler.stochastic_encode(
231 | init_latent, torch.tensor([t_enc] * batchsize).to(device)
232 | )
233 | samples = sampler.decode(
234 | z_enc,
235 | c,
236 | t_enc,
237 | unconditional_guidance_scale=guidance_scale,
238 | unconditional_conditioning=uc,
239 | )
240 |                 # Decoding the full latent was reported to produce NaNs in the output,
241 |                 # so the last three latent frames are trimmed before decoding to the mel domain.
242 |                 x_samples = latent_diffusion.decode_first_stage(samples[:, :, :-3, :])
246 | waveform = latent_diffusion.first_stage_model.decode_to_waveform(
247 | x_samples
248 | )
249 |
250 | return waveform
251 |
252 | def super_resolution_and_inpainting(
253 | latent_diffusion,
254 | text,
255 | original_audio_file_path = None,
256 | seed=42,
257 | ddim_steps=200,
258 | duration=None,
259 | batchsize=1,
260 | guidance_scale=2.5,
261 | n_candidate_gen_per_text=3,
262 |     time_mask_ratio_start_and_end=(0.10, 0.15), # regenerate the segment between 10% and 15% of the time steps in the spectrogram
263 |     # time_mask_ratio_start_and_end=(1.0, 1.0), # no inpainting
264 |     # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the upper mel bins (75% to 100%)
265 |     freq_mask_ratio_start_and_end=(1.0, 1.0), # no super-resolution
266 | config=None,
267 | ):
268 | seed_everything(int(seed))
269 | if config is not None:
270 | assert type(config) is str
271 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
272 | else:
273 | config = default_audioldm_config()
274 | fn_STFT = TacotronSTFT(
275 | config["preprocessing"]["stft"]["filter_length"],
276 | config["preprocessing"]["stft"]["hop_length"],
277 | config["preprocessing"]["stft"]["win_length"],
278 | config["preprocessing"]["mel"]["n_mel_channels"],
279 | config["preprocessing"]["audio"]["sampling_rate"],
280 | config["preprocessing"]["mel"]["mel_fmin"],
281 | config["preprocessing"]["mel"]["mel_fmax"],
282 | )
283 |
284 | # waveform = read_wav_file(original_audio_file_path, None)
285 | mel, _, _ = wav_to_fbank(
286 | original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
287 | )
288 |
289 | batch = make_batch_for_text_to_audio(text, fbank=mel[None,...], batchsize=batchsize)
290 |
291 | # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
292 | latent_diffusion = set_cond_text(latent_diffusion)
293 |
294 | with torch.no_grad():
295 | waveform = latent_diffusion.generate_sample_masked(
296 | [batch],
297 | unconditional_guidance_scale=guidance_scale,
298 | ddim_steps=ddim_steps,
299 | n_candidate_gen_per_text=n_candidate_gen_per_text,
300 | duration=duration,
301 | time_mask_ratio_start_and_end=time_mask_ratio_start_and_end,
302 | freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end
303 | )
304 | return waveform
305 |
--------------------------------------------------------------------------------
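
A quick sketch of how the duration-related constants in the pipeline above fit together, assuming the defaults from default_audioldm_config() further below (16 kHz audio, STFT hop of 160, a 1024-frame mel target for 10 seconds). The factor of 4 between 102.4 and 25.6 presumably corresponds to the VAE's temporal downsampling; the helper names are illustrative only, the constants come from the code above.

# Hedged sketch: relation between duration, mel frames, waveform samples and latent size.

def duration_to_mel_frames(duration_s):
    # pipeline: wav_to_fbank(..., target_length=int(duration * 102.4))  -> 1024 frames for 10 s
    return int(duration_s * 102.4)

def duration_to_waveform_samples(duration_s, hop_length=160):
    # pipeline: read_wav_file(path, int(duration * 102.4) * 160)
    return duration_to_mel_frames(duration_s) * hop_length

def duration_to_latent_t(duration_s, vae_time_downsample=4):
    # pipeline: duration_to_latent_t_size(duration) == int(duration * 25.6), i.e. 102.4 / 4
    return duration_to_mel_frames(duration_s) // vae_time_downsample

for d in (2.5, 10, 20):
    print(d, duration_to_mel_frames(d), duration_to_waveform_samples(d), duration_to_latent_t(d))
# 2.5 -> 256 mel frames, 40960 samples, 64 latent frames
# 10  -> 1024 mel frames, 163840 samples, 256 latent frames (latent_t_size default in the config)
# 20  -> 2048 mel frames, 327680 samples, 512 latent frames
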
/audioldm/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import importlib
3 |
4 | from inspect import isfunction
5 | import os
6 | import soundfile as sf
7 | import time
8 | import wave
9 |
10 | import urllib.request
11 | import progressbar
12 |
13 | CACHE_DIR = os.getenv(
14 | "AUDIOLDM_CACHE_DIR",
15 | os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
16 |
17 | def get_duration(fname):
18 | with contextlib.closing(wave.open(fname, 'r')) as f:
19 | frames = f.getnframes()
20 | rate = f.getframerate()
21 | return frames / float(rate)
22 |
23 | def get_bit_depth(fname):
24 | with contextlib.closing(wave.open(fname, 'r')) as f:
25 | bit_depth = f.getsampwidth() * 8
26 | return bit_depth
27 |
28 | def get_time():
29 | t = time.localtime()
30 | return time.strftime("%d_%m_%Y_%H_%M_%S", t)
31 |
32 | def seed_everything(seed):
33 | import random, os
34 | import numpy as np
35 | import torch
36 |
37 | random.seed(seed)
38 | os.environ["PYTHONHASHSEED"] = str(seed)
39 | np.random.seed(seed)
40 | torch.manual_seed(seed)
41 | torch.cuda.manual_seed(seed)
42 | torch.backends.cudnn.deterministic = True
43 |     torch.backends.cudnn.benchmark = False  # benchmark=True lets cuDNN pick non-deterministic kernels, defeating the deterministic flag above
44 |
45 |
46 | def save_wave(waveform, savepath, name="outwav"):
47 | if type(name) is not list:
48 | name = [name] * waveform.shape[0]
49 |
50 | for i in range(waveform.shape[0]):
51 | path = os.path.join(
52 | savepath,
53 | "%s_%s.wav"
54 | % (
55 | os.path.basename(name[i])
56 |                 if ".wav" not in name[i]
57 | else os.path.basename(name[i]).split(".")[0],
58 | i,
59 | ),
60 | )
61 | print("Save audio to %s" % path)
62 | sf.write(path, waveform[i, 0], samplerate=16000)
63 |
64 |
65 | def exists(x):
66 | return x is not None
67 |
68 |
69 | def default(val, d):
70 | if exists(val):
71 | return val
72 | return d() if isfunction(d) else d
73 |
74 |
75 | def count_params(model, verbose=False):
76 | total_params = sum(p.numel() for p in model.parameters())
77 | if verbose:
78 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
79 | return total_params
80 |
81 |
82 | def get_obj_from_str(string, reload=False):
83 | module, cls = string.rsplit(".", 1)
84 | if reload:
85 | module_imp = importlib.import_module(module)
86 | importlib.reload(module_imp)
87 | return getattr(importlib.import_module(module, package=None), cls)
88 |
89 |
90 | def instantiate_from_config(config):
91 |     if "target" not in config:
92 | if config == "__is_first_stage__":
93 | return None
94 | elif config == "__is_unconditional__":
95 | return None
96 | raise KeyError("Expected key `target` to instantiate.")
97 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
98 |
99 |
100 | def default_audioldm_config(model_name="audioldm-s-full"):
101 | basic_config = {
102 | "wave_file_save_path": "./output",
103 | "id": {
104 | "version": "v1",
105 | "name": "default",
106 | "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml",
107 | },
108 | "preprocessing": {
109 | "audio": {"sampling_rate": 16000, "max_wav_value": 32768},
110 | "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
111 | "mel": {
112 | "n_mel_channels": 64,
113 | "mel_fmin": 0,
114 | "mel_fmax": 8000,
115 | "freqm": 0,
116 | "timem": 0,
117 | "blur": False,
118 | "mean": -4.63,
119 | "std": 2.74,
120 | "target_length": 1024,
121 | },
122 | },
123 | "model": {
124 | "device": "cuda",
125 | "target": "audioldm.pipline.LatentDiffusion",
126 | "params": {
127 | "base_learning_rate": 5e-06,
128 | "linear_start": 0.0015,
129 | "linear_end": 0.0195,
130 | "num_timesteps_cond": 1,
131 | "log_every_t": 200,
132 | "timesteps": 1000,
133 | "first_stage_key": "fbank",
134 | "cond_stage_key": "waveform",
135 | "latent_t_size": 256,
136 | "latent_f_size": 16,
137 | "channels": 8,
138 | "cond_stage_trainable": True,
139 | "conditioning_key": "film",
140 | "monitor": "val/loss_simple_ema",
141 | "scale_by_std": True,
142 | "unet_config": {
143 | "target": "audioldm.latent_diffusion.openaimodel.UNetModel",
144 | "params": {
145 | "image_size": 64,
146 | "extra_film_condition_dim": 512,
147 | "extra_film_use_concat": True,
148 | "in_channels": 8,
149 | "out_channels": 8,
150 | "model_channels": 128,
151 | "attention_resolutions": [8, 4, 2],
152 | "num_res_blocks": 2,
153 | "channel_mult": [1, 2, 3, 5],
154 | "num_head_channels": 32,
155 | "use_spatial_transformer": True,
156 | },
157 | },
158 | "first_stage_config": {
159 | "base_learning_rate": 4.5e-05,
160 | "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
161 | "params": {
162 | "monitor": "val/rec_loss",
163 | "image_key": "fbank",
164 | "subband": 1,
165 | "embed_dim": 8,
166 | "time_shuffle": 1,
167 | "ddconfig": {
168 | "double_z": True,
169 | "z_channels": 8,
170 | "resolution": 256,
171 | "downsample_time": False,
172 | "in_channels": 1,
173 | "out_ch": 1,
174 | "ch": 128,
175 | "ch_mult": [1, 2, 4],
176 | "num_res_blocks": 2,
177 | "attn_resolutions": [],
178 | "dropout": 0.0,
179 | },
180 | },
181 | },
182 | "cond_stage_config": {
183 | "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2",
184 | "params": {
185 | "key": "waveform",
186 | "sampling_rate": 16000,
187 | "embed_mode": "audio",
188 | "unconditional_prob": 0.1,
189 | },
190 | },
191 | },
192 | },
193 | }
194 |
195 | if("-l-" in model_name):
196 | basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256
197 | basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64
198 | elif("-m-" in model_name):
199 | basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192
200 |         basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model uses the larger HTSAT audio encoder
201 |
202 | return basic_config
203 |
204 | def get_metadata():
205 | return {
206 | "audioldm-s-full": {
207 | "path": os.path.join(
208 | CACHE_DIR,
209 | "audioldm-s-full.ckpt",
210 | ),
211 | "url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1",
212 | },
213 | "audioldm-l-full": {
214 | "path": os.path.join(
215 | CACHE_DIR,
216 | "audioldm-l-full.ckpt",
217 | ),
218 | "url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1",
219 | },
220 | "audioldm-s-full-v2": {
221 | "path": os.path.join(
222 | CACHE_DIR,
223 | "audioldm-s-full-v2.ckpt",
224 | ),
225 | "url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1",
226 | },
227 | "audioldm-m-text-ft": {
228 | "path": os.path.join(
229 | CACHE_DIR,
230 | "audioldm-m-text-ft.ckpt",
231 | ),
232 | "url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1",
233 | },
234 | "audioldm-s-text-ft": {
235 | "path": os.path.join(
236 | CACHE_DIR,
237 | "audioldm-s-text-ft.ckpt",
238 | ),
239 | "url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1",
240 | },
241 | "audioldm-m-full": {
242 | "path": os.path.join(
243 | CACHE_DIR,
244 | "audioldm-m-full.ckpt",
245 | ),
246 | "url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1",
247 | },
248 | }
249 |
250 | class MyProgressBar():
251 | def __init__(self):
252 | self.pbar = None
253 |
254 | def __call__(self, block_num, block_size, total_size):
255 | if not self.pbar:
256 | self.pbar=progressbar.ProgressBar(maxval=total_size)
257 | self.pbar.start()
258 |
259 | downloaded = block_num * block_size
260 | if downloaded < total_size:
261 | self.pbar.update(downloaded)
262 | else:
263 | self.pbar.finish()
264 |
265 | def download_checkpoint(checkpoint_name="audioldm-s-full"):
266 | meta = get_metadata()
267 |     if(checkpoint_name not in meta.keys()):
268 |         raise ValueError("The model name you provided is not supported. Please use one of the following: %s" % list(meta.keys()))
269 |
270 | if not os.path.exists(meta[checkpoint_name]["path"]) or os.path.getsize(meta[checkpoint_name]["path"]) < 2*10**9:
271 | os.makedirs(os.path.dirname(meta[checkpoint_name]["path"]), exist_ok=True)
272 |         print(f"Downloading the checkpoint {checkpoint_name} into {os.path.dirname(meta[checkpoint_name]['path'])}")
273 |
274 | urllib.request.urlretrieve(meta[checkpoint_name]["url"], meta[checkpoint_name]["path"], MyProgressBar())
275 | print(
276 | "Weights downloaded in: {} Size: {}".format(
277 | meta[checkpoint_name]["path"],
278 | os.path.getsize(meta[checkpoint_name]["path"]),
279 | )
280 | )
281 |
--------------------------------------------------------------------------------
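
The `target`/`params` pattern handled by get_obj_from_str and instantiate_from_config above is the usual latent-diffusion config convention: a dotted import path plus a keyword dictionary. A minimal usage sketch follows; the target below (datetime.timedelta) is a stand-in purely for illustration, whereas the real configs point at classes such as audioldm.latent_diffusion.openaimodel.UNetModel.

from audioldm.utils import instantiate_from_config

cfg = {
    "target": "datetime.timedelta",   # dotted path, resolved via importlib + getattr
    "params": {"seconds": 30},        # forwarded as keyword arguments to the constructor
}
obj = instantiate_from_config(cfg)
print(obj)  # 0:00:30
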
/audioldm/variational_autoencoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/audioldm/variational_autoencoder/__init__.py
--------------------------------------------------------------------------------
/audioldm/variational_autoencoder/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from audioldm.latent_diffusion.ema import *
3 | from audioldm.variational_autoencoder.modules import Encoder, Decoder
4 | from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
5 |
6 | from audioldm.hifigan.utilities import get_vocoder, vocoder_infer
7 |
8 |
9 | class AutoencoderKL(nn.Module):
10 | def __init__(
11 | self,
12 | ddconfig=None,
13 | lossconfig=None,
14 | image_key="fbank",
15 | embed_dim=None,
16 | time_shuffle=1,
17 | subband=1,
18 | ckpt_path=None,
19 | reload_from_ckpt=None,
20 | ignore_keys=[],
21 | colorize_nlabels=None,
22 | monitor=None,
23 | base_learning_rate=1e-5,
24 | ):
25 | super().__init__()
26 |
27 | self.encoder = Encoder(**ddconfig)
28 | self.decoder = Decoder(**ddconfig)
29 |
30 | self.subband = int(subband)
31 |
32 | if self.subband > 1:
33 | print("Use subband decomposition %s" % self.subband)
34 |
35 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
36 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
37 |
38 | self.vocoder = get_vocoder(None, "cpu")
39 | self.embed_dim = embed_dim
40 |
41 | if monitor is not None:
42 | self.monitor = monitor
43 |
44 | self.time_shuffle = time_shuffle
45 | self.reload_from_ckpt = reload_from_ckpt
46 | self.reloaded = False
47 |         self.mean, self.std = None, None
48 |         self.image_key, self.flag_first_run = image_key, True  # both are read later (freq_split_subband / forward) but were never set
49 | def encode(self, x):
50 | # x = self.time_shuffle_operation(x)
51 | x = self.freq_split_subband(x)
52 | h = self.encoder(x)
53 | moments = self.quant_conv(h)
54 | posterior = DiagonalGaussianDistribution(moments)
55 | return posterior
56 |
57 | def decode(self, z):
58 | z = self.post_quant_conv(z)
59 | dec = self.decoder(z)
60 | dec = self.freq_merge_subband(dec)
61 | return dec
62 |
63 | def decode_to_waveform(self, dec):
64 | dec = dec.squeeze(1).permute(0, 2, 1)
65 | wav_reconstruction = vocoder_infer(dec, self.vocoder)
66 | return wav_reconstruction
67 |
68 | def forward(self, input, sample_posterior=True):
69 | posterior = self.encode(input)
70 | if sample_posterior:
71 | z = posterior.sample()
72 | else:
73 | z = posterior.mode()
74 |
75 | if self.flag_first_run:
76 | print("Latent size: ", z.size())
77 | self.flag_first_run = False
78 |
79 | dec = self.decode(z)
80 |
81 | return dec, posterior
82 |
83 | def freq_split_subband(self, fbank):
84 | if self.subband == 1 or self.image_key != "stft":
85 | return fbank
86 |
87 | bs, ch, tstep, fbins = fbank.size()
88 |
89 | assert fbank.size(-1) % self.subband == 0
90 | assert ch == 1
91 |
92 | return (
93 | fbank.squeeze(1)
94 | .reshape(bs, tstep, self.subband, fbins // self.subband)
95 | .permute(0, 2, 1, 3)
96 | )
97 |
98 | def freq_merge_subband(self, subband_fbank):
99 | if self.subband == 1 or self.image_key != "stft":
100 | return subband_fbank
101 | assert subband_fbank.size(1) == self.subband # Channel dimension
102 | bs, sub_ch, tstep, fbins = subband_fbank.size()
103 | return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
104 |
--------------------------------------------------------------------------------
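
The freq_split_subband / freq_merge_subband pair above only does work when subband > 1 and image_key == "stft"; with the defaults (subband=1, image_key="fbank") both are no-ops. Below is a standalone shape check of the same permute/reshape round trip, using plain tensors rather than the class; the sizes (1024 x 64 mel, subband 4) are illustrative.

# Standalone shape check of the subband split/merge reshapes used by AutoencoderKL.
import torch

bs, ch, tstep, fbins, subband = 2, 1, 1024, 64, 4
fbank = torch.randn(bs, ch, tstep, fbins)

# freq_split_subband: (bs, 1, T, F) -> (bs, subband, T, F // subband)
split = fbank.squeeze(1).reshape(bs, tstep, subband, fbins // subband).permute(0, 2, 1, 3)
assert split.shape == (bs, subband, tstep, fbins // subband)

# freq_merge_subband: inverse permute/reshape back to (bs, 1, T, F)
merged = split.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
assert torch.equal(merged, fbank)
print(split.shape, merged.shape)  # torch.Size([2, 4, 1024, 16]) torch.Size([2, 1, 1024, 64])
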
/audioldm/variational_autoencoder/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(
34 | device=self.parameters.device
35 | )
36 |
37 | def sample(self):
38 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
39 | device=self.parameters.device
40 | )
41 | return x
42 |
43 | def kl(self, other=None):
44 | if self.deterministic:
45 | return torch.Tensor([0.0])
46 | else:
47 | if other is None:
48 | return 0.5 * torch.mean(
49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50 | dim=[1, 2, 3],
51 | )
52 | else:
53 | return 0.5 * torch.mean(
54 | torch.pow(self.mean - other.mean, 2) / other.var
55 | + self.var / other.var
56 | - 1.0
57 | - self.logvar
58 | + other.logvar,
59 | dim=[1, 2, 3],
60 | )
61 |
62 | def nll(self, sample, dims=[1, 2, 3]):
63 | if self.deterministic:
64 | return torch.Tensor([0.0])
65 | logtwopi = np.log(2.0 * np.pi)
66 | return 0.5 * torch.sum(
67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68 | dim=dims,
69 | )
70 |
71 | def mode(self):
72 | return self.mean
73 |
74 |
75 | def normal_kl(mean1, logvar1, mean2, logvar2):
76 | """
77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78 | Compute the KL divergence between two gaussians.
79 | Shapes are automatically broadcasted, so batches can be compared to
80 | scalars, among other use cases.
81 | """
82 | tensor = None
83 | for obj in (mean1, logvar1, mean2, logvar2):
84 | if isinstance(obj, torch.Tensor):
85 | tensor = obj
86 | break
87 | assert tensor is not None, "at least one argument must be a Tensor"
88 |
89 | # Force variances to be Tensors. Broadcasting helps convert scalars to
90 | # Tensors, but it does not work for torch.exp().
91 | logvar1, logvar2 = [
92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93 | for x in (logvar1, logvar2)
94 | ]
95 |
96 | return 0.5 * (
97 | -1.0
98 | + logvar2
99 | - logvar1
100 | + torch.exp(logvar1 - logvar2)
101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102 | )
103 |
--------------------------------------------------------------------------------
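
A small usage sketch of DiagonalGaussianDistribution as the autoencoder above uses it: the quant_conv output carries 2 * embed_dim channels, which the class chunks into a mean and a clamped log-variance along dim=1. The tensor sizes below mirror the default first_stage_config (embed_dim=8, latent 256 x 16) but are otherwise illustrative.

import torch
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

embed_dim, t, f = 8, 256, 16                   # mirrors the default first_stage_config
moments = torch.randn(1, 2 * embed_dim, t, f)  # stand-in for quant_conv(encoder(x))

posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()        # reparameterised sample, shape (1, 8, 256, 16)
kl = posterior.kl()           # KL to a standard normal, averaged over non-batch dims
print(z.shape, kl.shape)      # torch.Size([1, 8, 256, 16]) torch.Size([1])
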
/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/bg.png
--------------------------------------------------------------------------------
/bin/audioldm:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | import os
3 | from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
4 | import argparse
5 |
6 | CACHE_DIR = os.getenv(
7 | "AUDIOLDM_CACHE_DIR",
8 | os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
9 |
10 | parser = argparse.ArgumentParser()
11 |
12 | parser.add_argument(
13 | "--mode",
14 | type=str,
15 | required=False,
16 | default="generation",
17 | help="generation: text-to-audio generation; transfer: style transfer",
18 | choices=["generation", "transfer"]
19 | )
20 |
21 | parser.add_argument(
22 | "-t",
23 | "--text",
24 | type=str,
25 | required=False,
26 | default="",
27 | help="Text prompt to the model for audio generation",
28 | )
29 |
30 | parser.add_argument(
31 | "-f",
32 | "--file_path",
33 | type=str,
34 | required=False,
35 | default=None,
36 |     help="(--mode transfer): the original audio file for style transfer; or (--mode generation): the guidance audio file for generating similar audio",
37 | )
38 |
39 | parser.add_argument(
40 | "--transfer_strength",
41 | type=float,
42 | required=False,
43 | default=0.5,
44 |     help="A value between 0 and 1. 0 keeps the original audio without transfer; 1 fully transfers it to the sound described by the text",
45 | )
46 |
47 | parser.add_argument(
48 | "-s",
49 | "--save_path",
50 | type=str,
51 | required=False,
52 | help="The path to save model output",
53 | default="./output",
54 | )
55 |
56 | parser.add_argument(
57 | "--model_name",
58 | type=str,
59 | required=False,
60 |     help="The checkpoint to use",
61 | default="audioldm-m-full",
62 | choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2","audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full"]
63 | )
64 |
65 | parser.add_argument(
66 | "-ckpt",
67 | "--ckpt_path",
68 | type=str,
69 | required=False,
70 | help="The path to the pretrained .ckpt model",
71 | default=None,
72 | )
73 |
74 | parser.add_argument(
75 | "-b",
76 | "--batchsize",
77 | type=int,
78 | required=False,
79 | default=1,
80 |     help="How many samples to generate at the same time",
81 | )
82 |
83 | parser.add_argument(
84 | "--ddim_steps",
85 | type=int,
86 | required=False,
87 | default=200,
88 | help="The sampling step for DDIM",
89 | )
90 |
91 | parser.add_argument(
92 | "-gs",
93 | "--guidance_scale",
94 | type=float,
95 | required=False,
96 | default=2.5,
97 |     help="Guidance scale (larger => better quality and relevance to the text; smaller => better diversity)",
98 | )
99 |
100 | parser.add_argument(
101 | "-dur",
102 | "--duration",
103 | type=float,
104 | required=False,
105 | default=10.0,
106 |     help="The duration of the generated samples, in seconds",
107 | )
108 |
109 | parser.add_argument(
110 | "-n",
111 | "--n_candidate_gen_per_text",
112 | type=int,
113 | required=False,
114 | default=3,
115 |     help="Automatic quality control. This number controls how many candidates are generated (e.g., generate three audios and keep the best one). A larger value usually leads to better quality at a higher computational cost",
116 | )
117 |
118 | parser.add_argument(
119 | "--seed",
120 | type=int,
121 | required=False,
122 | default=42,
123 |     help="Changing this value (any integer) will lead to a different generation result.",
124 | )
125 |
126 | args = parser.parse_args()
127 |
128 | if(args.ckpt_path is not None):
129 | print("Warning: ckpt_path has no effect after version 0.0.20.")
130 |
131 | assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
132 |
133 | mode = args.mode
134 | if(mode == "generation" and args.file_path is not None):
135 | mode = "generation_audio_to_audio"
136 | if(len(args.text) > 0):
137 |         print("Warning: --file_path was specified, so --text will be ignored")
138 | args.text = ""
139 |
140 | save_path = os.path.join(args.save_path, mode)
141 |
142 | if(args.file_path is not None):
143 | save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0]))
144 |
145 | text = args.text
146 | random_seed = args.seed
147 | duration = args.duration
148 | guidance_scale = args.guidance_scale
149 | n_candidate_gen_per_text = args.n_candidate_gen_per_text
150 |
151 | os.makedirs(save_path, exist_ok=True)
152 | audioldm = build_model(model_name=args.model_name)
153 |
154 | if(args.mode == "generation"):
155 | waveform = text_to_audio(
156 | audioldm,
157 | text,
158 | args.file_path,
159 | random_seed,
160 | duration=duration,
161 | guidance_scale=guidance_scale,
162 | ddim_steps=args.ddim_steps,
163 | n_candidate_gen_per_text=n_candidate_gen_per_text,
164 | batchsize=args.batchsize,
165 | )
166 |
167 | elif(args.mode == "transfer"):
168 | assert args.file_path is not None
169 | assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path
170 | waveform = style_transfer(
171 | audioldm,
172 | text,
173 | args.file_path,
174 | args.transfer_strength,
175 | random_seed,
176 | duration=duration,
177 | guidance_scale=guidance_scale,
178 | ddim_steps=args.ddim_steps,
179 | batchsize=args.batchsize,
180 | )
181 | waveform = waveform[:,None,:]
182 |
183 | save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))
184 |
--------------------------------------------------------------------------------
/bin/audioldm.cmd:
--------------------------------------------------------------------------------
1 | @echo OFF
2 | python -m audioldm %*
--------------------------------------------------------------------------------
/ckpt/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/ckpt/.gitkeep
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | # Generation
2 | audioldm --file_path trumpet.wav
3 | audioldm --file_path trumpet.wav -dur 25
4 | audioldm --file_path trumpet.wav -dur 2.5
5 | audioldm --text "A hammer is hitting a wooden surface"
6 | audioldm
7 |
8 | # False use cases
9 | audioldm --text "A hammer is hitting a wooden surface" --file_path trumpet.wav # Same as audioldm --file_path trumpet.wav
10 |
11 |
12 | # Transfer
13 | audioldm --mode "transfer" --file_path trumpet.wav -t "Children Singing"
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/scripts/text2sound.py:
--------------------------------------------------------------------------------
1 | import os
2 | from audioldm import text_to_audio, build_model, save_wave
3 |
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser()
7 |
8 | parser.add_argument(
9 | "-t",
10 | "--text",
11 | type=str,
12 | required=False,
13 | default="A hammer is hitting a wooden surface",
14 | help="Text prompt to the model for audio generation",
15 | )
16 |
17 | parser.add_argument(
18 | "-s",
19 | "--save_path",
20 | type=str,
21 | required=False,
22 | help="The path to save model output",
23 | default="./output",
24 | )
25 |
26 | parser.add_argument(
27 | "-ckpt",
28 | "--ckpt_path",
29 | type=str,
30 | required=False,
31 | help="The path to the pretrained .ckpt model",
32 | default="./ckpt/audioldm-s-full.ckpt",
33 | )
34 |
35 | parser.add_argument(
36 | "-b",
37 | "--batchsize",
38 | type=int,
39 | required=False,
40 | default=1,
41 |     help="How many samples to generate at the same time",
42 | )
43 |
44 | parser.add_argument(
45 | "-gs",
46 | "--guidance_scale",
47 | type=float,
48 | required=False,
49 | default=2.5,
50 |     help="Guidance scale (larger => better quality and relevance to the text; smaller => better diversity)",
51 | )
52 |
53 | parser.add_argument(
54 | "-dur",
55 | "--duration",
56 | type=float,
57 | required=False,
58 | default=10.0,
59 | help="The duration of the samples",
60 | )
61 |
62 | parser.add_argument(
63 | "-n",
64 | "--n_candidate_gen_per_text",
65 | type=int,
66 | required=False,
67 | default=3,
68 |     help="Automatic quality control. This number controls how many candidates are generated (e.g., generate three audios and keep the best one). A larger value usually leads to better quality at a higher computational cost",
69 | )
70 |
71 | parser.add_argument(
72 | "--seed",
73 | type=int,
74 | required=False,
75 | default=42,
76 | help="Change this value (any integer number) will lead to a different generation result.",
77 | )
78 |
79 | args = parser.parse_args()
80 |
81 | assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
82 |
83 | save_path = args.save_path
84 | text = args.text
85 | random_seed = args.seed
86 | duration = args.duration
87 | guidance_scale = args.guidance_scale
88 | n_candidate_gen_per_text = args.n_candidate_gen_per_text
89 |
90 | os.makedirs(save_path, exist_ok=True)
91 | audioldm = build_model(ckpt_path=args.ckpt_path)
92 | waveform = text_to_audio(
93 | audioldm,
94 | text,
95 | seed=random_seed,
96 | duration=duration,
97 | guidance_scale=guidance_scale,
98 | n_candidate_gen_per_text=n_candidate_gen_per_text,
99 | batchsize=args.batchsize,
100 | )
101 |
102 | save_wave(waveform, save_path, name=text)
103 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | # python3 setup.py sdist bdist_wheel
4 | """
5 | @File    :   setup.py
6 | @Contact : haoheliu@gmail.com
7 | @License : (C)Copyright 2020-2100
8 |
9 | @Modify Time      @Author    @Version    @Description
10 | ------------ ------- -------- -----------
11 | 9/6/21 5:16 PM Haohe Liu 1.0 None
12 | """
13 |
14 | # !/usr/bin/env python
15 | # -*- coding: utf-8 -*-
16 |
17 | # Note: To use the 'upload' functionality of this file, you must:
18 | # $ pipenv install twine --dev
19 |
20 | import io
21 | import os
22 | import sys
23 | from shutil import rmtree
24 |
25 | from setuptools import find_packages, setup, Command
26 |
27 | # Package meta-data.
28 | NAME = "audioldm"
29 | DESCRIPTION = "This package is written for text-to-audio generation."
30 | URL = "https://github.com/haoheliu/audioldm"
31 | EMAIL = "haoheliu@gmail.com"
32 | AUTHOR = "Haohe Liu"
33 | REQUIRES_PYTHON = ">=3.7.0"
34 | VERSION = "0.1.1"
35 |
36 | # What packages are required for this module to be executed?
37 | REQUIRED = [
38 | "torch>=1.13.0",
39 | "torchaudio>=0.13.0",
40 | "torchvision>=0.14.0",
41 | "tqdm",
42 | "gradio",
43 | "pyyaml",
44 | "einops",
45 | "chardet",
46 | "numpy<=1.23.5",
47 | "soundfile",
48 | "librosa==0.9.2",
49 | "scipy",
50 | "pandas",
51 | "torchlibrosa==0.0.9",
52 | "transformers==4.29.0",
53 | "progressbar",
54 | "ftfy",
55 | ]
56 |
57 | # What packages are optional?
58 | EXTRAS = {}
59 |
60 | # The rest you shouldn't have to touch too much :)
61 | # ------------------------------------------------
62 | # Except, perhaps the License and Trove Classifiers!
63 | # If you do change the License, remember to change the Trove Classifier for that!
64 |
65 | here = os.path.abspath(os.path.dirname(__file__))
66 |
67 | # Import the README and use it as the long-description.
68 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
69 | try:
70 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
71 | long_description = "\n" + f.read()
72 | except FileNotFoundError:
73 | long_description = DESCRIPTION
74 |
75 | # Load the package's __version__.py module as a dictionary.
76 | about = {}
77 | if not VERSION:
78 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
79 | with open(os.path.join(here, project_slug, "__version__.py")) as f:
80 | exec(f.read(), about)
81 | else:
82 | about["__version__"] = VERSION
83 |
84 |
85 | class UploadCommand(Command):
86 | """Support setup.py upload."""
87 |
88 | description = "Build and publish the package."
89 | user_options = []
90 |
91 | @staticmethod
92 | def status(s):
93 | """Prints things in bold."""
94 | print("\033[1m{0}\033[0m".format(s))
95 |
96 | def initialize_options(self):
97 | pass
98 |
99 | def finalize_options(self):
100 | pass
101 |
102 | def run(self):
103 | try:
104 | self.status("Removing previous builds…")
105 | rmtree(os.path.join(here, "dist"))
106 | except OSError:
107 | pass
108 |
109 | self.status("Building Source and Wheel (universal) distribution…")
110 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable))
111 |
112 | self.status("Uploading the package to PyPI via Twine…")
113 | os.system("twine upload dist/*")
114 |
115 | self.status("Pushing git tags…")
116 | os.system("git tag v{0}".format(about["__version__"]))
117 | os.system("git push --tags")
118 |
119 | sys.exit()
120 |
121 |
122 | # Where the magic happens:
123 | setup(
124 | name=NAME,
125 | version=about["__version__"],
126 | description=DESCRIPTION,
127 | long_description=long_description,
128 | long_description_content_type="text/markdown",
129 | author=AUTHOR,
130 | author_email=EMAIL,
131 | python_requires=REQUIRES_PYTHON,
132 | url=URL,
133 | # packages=find_packages(exclude=[]),
134 | # If your package is a single module, use this instead of 'packages':
135 | # py_modules=["audioldm"],
136 | # entry_points={
137 | # 'console_scripts': ['mycli=mymodule:cli'],
138 | # },
139 | install_requires=REQUIRED,
140 | extras_require=EXTRAS,
141 | packages=find_packages(),
142 | # package_data={'bpe': ['audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz']},
143 | include_package_data=True,
144 | license="MIT",
145 | classifiers=[
146 | # Trove classifiers
147 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
148 | "License :: OSI Approved :: MIT License",
149 | "Programming Language :: Python",
150 | "Programming Language :: Python :: 3",
151 | "Programming Language :: Python :: 3.7",
152 | "Programming Language :: Python :: Implementation :: CPython",
153 | "Programming Language :: Python :: Implementation :: PyPy",
154 | ],
155 | # $ setup.py publish support.
156 | cmdclass={
157 | "upload": UploadCommand,
158 | },
159 | scripts=["bin/audioldm.cmd", "bin/audioldm"],
160 | )
161 |
--------------------------------------------------------------------------------
/trumpet.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haoheliu/AudioLDM/c0059c2414f52cdb1720954b09e6d113e79bfdfe/trumpet.wav
--------------------------------------------------------------------------------