├── AudioSep_Colab.ipynb ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets └── results.png ├── benchmark.py ├── callbacks └── base.py ├── cog.yaml ├── config └── audiosep_base.yaml ├── data ├── audiotext_dataset.py ├── datamodules.py └── waveform_mixers.py ├── datafiles └── template.json ├── environment.yml ├── environment_win64.yaml ├── evaluation ├── evaluate_audiocaps.py ├── evaluate_audioset.py ├── evaluate_clotho.py ├── evaluate_esc50.py ├── evaluate_music.py ├── evaluate_vggsound.py └── metadata │ ├── audiocaps_eval.csv │ ├── audioset_eval.csv │ ├── class_labels_indices.csv │ ├── clotho_eval.csv │ ├── esc50_eval.csv │ ├── music_eval.csv │ └── vggsound_eval.csv ├── losses.py ├── models ├── CLAP │ ├── __init__.py │ ├── open_clip │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── factory.py │ │ ├── feature_fusion.py │ │ ├── htsat.py │ │ ├── linear_probe.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── model_configs │ │ │ ├── HTSAT-base.json │ │ │ ├── HTSAT-large.json │ │ │ ├── HTSAT-tiny-win-1536.json │ │ │ ├── HTSAT-tiny.json │ │ │ ├── PANN-10.json │ │ │ ├── PANN-14-fmax-18k.json │ │ │ ├── PANN-14-fmax-8k-20s.json │ │ │ ├── PANN-14-tiny-transformer.json │ │ │ ├── PANN-14-win-1536.json │ │ │ ├── PANN-14.json │ │ │ ├── PANN-6.json │ │ │ ├── RN101-quickgelu.json │ │ │ ├── RN101.json │ │ │ ├── RN50-quickgelu.json │ │ │ ├── RN50.json │ │ │ ├── RN50x16.json │ │ │ ├── RN50x4.json │ │ │ ├── ViT-B-16.json │ │ │ ├── ViT-B-32-quickgelu.json │ │ │ ├── ViT-B-32.json │ │ │ └── ViT-L-14.json │ │ ├── openai.py │ │ ├── pann_model.py │ │ ├── pretrained.py │ │ ├── timm_model.py │ │ ├── tokenizer.py │ │ ├── transform.py │ │ ├── utils.py │ │ └── version.py │ └── training │ │ ├── __init__.py │ │ ├── audioset_textmap.npy │ │ ├── data.py │ │ ├── distributed.py │ │ ├── imagenet_zeroshot_data.py │ │ ├── infer_demo.py │ │ ├── logger.py │ │ ├── lp_main.py │ │ ├── lp_train.py │ │ ├── main.py │ │ ├── params.py │ │ ├── scheduler.py │ │ ├── train.py │ │ └── zero_shot.py ├── audiosep.py ├── base.py ├── clap_encoder.py └── resunet.py ├── optimizers └── lr_schedulers.py ├── pipeline.py ├── predict.py ├── train.py └── utils.py /AudioSep_Colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import Path\n", 10 | "\n", 11 | "repo_path = Path(\"/content/AudioSep\")\n", 12 | "if not repo_path.exists():\n", 13 | " !git clone https://github.com/Audio-AGI/AudioSep.git\n", 14 | "\n", 15 | "%cd /content/AudioSep" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "id": "pjIhw5ECS_3_" 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "!pip install torchlibrosa==0.1.0 gradio==3.47.1 gdown lightning transformers==4.28.1 ftfy braceexpand webdataset soundfile wget h5py" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": { 33 | "id": "t6h9KB3CcjBd" 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "checkpoints_dir = Path(\"checkpoint\")\n", 38 | "checkpoints_dir.mkdir(exist_ok=True)\n", 39 | "\n", 40 | "models = (\n", 41 | " (\n", 42 | " \"https://huggingface.co/spaces/badayvedat/AudioSep/resolve/main/checkpoint/audiosep_base_4M_steps.ckpt\",\n", 43 | " checkpoints_dir / \"audiosep_base_4M_steps.ckpt\"\n", 44 | " ),\n", 45 | " (\n", 46 | " 
\"https://huggingface.co/spaces/badayvedat/AudioSep/resolve/main/checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt\",\n", 47 | " checkpoints_dir / \"music_speech_audioset_epoch_15_esc_89.98.pt\"\n", 48 | " )\n", 49 | ")\n", 50 | "\n", 51 | "for model_url, model_path in models:\n", 52 | " if not model_path.exists():\n", 53 | " !wget {model_url} -O {model_path}" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "id": "3uDrzCQyY58h" 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "!wget \"https://audio-agi.github.io/Separate-Anything-You-Describe/demos/exp31_water drops_mixture.wav\"" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "id": "0nr77CGXTwO1" 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "import torch\n", 76 | "from pipeline import build_audiosep, separate_audio\n", 77 | "\n", 78 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 79 | "\n", 80 | "model = build_audiosep(\n", 81 | " config_yaml='config/audiosep_base.yaml',\n", 82 | " checkpoint_path=str(models[0][1]),\n", 83 | " device=device)\n", 84 | "\n", 85 | "audio_file = 'exp31_water drops_mixture.wav'\n", 86 | "text = 'water drops'\n", 87 | "output_file='separated_audio.wav'\n", 88 | "\n", 89 | "# AudioSep processes the audio at 32 kHz sampling rate\n", 90 | "separate_audio(model, audio_file, text, output_file, device)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "id": "kssOe0pbPSWp" 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "print(f\"The separated audio is saved to: '{output_file}' file.\")" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": { 108 | "id": "sl35U3dAR6KN" 109 | }, 110 | "outputs": [], 111 | "source": [] 112 | } 113 | ], 114 | "metadata": { 115 | "colab": { 116 | "provenance": [] 117 | }, 118 | "kernelspec": { 119 | "display_name": "Python 3", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "name": "python" 124 | } 125 | }, 126 | "nbformat": 4, 127 | "nbformat_minor": 0 128 | } 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # 🎵 Contributing to AudioSep 2 | 3 | Welcome to the AudioSep repository, where your contributions can harmonize the world of audio separation. To ensure a harmonious and organized collaboration, please follow the contribution guidelines outlined below. 4 | 5 | ## **Submitting Contributions** 6 | 7 | To contribute to this project, please adhere to the following steps: 8 | 9 | ### **1. Choose or Create an Issue** 10 | 11 | - Start by reviewing the existing issues to identify areas where your contributions can make a significant impact. 12 | - If you have ideas for new features, enhancements, or bug fixes, feel free to create a new issue to propose your contributions. Provide comprehensive details for clarity. 13 | 14 | ### **2. Fork the Repository** 15 | 16 | - To initiate your contribution, fork the primary repository by clicking the "Fork" button. This will create a copy of the repository in your personal GitHub account. 17 | 18 | ### **3. Clone Your Forked Repository** 19 | 20 | - Clone your forked repository to your local development environment using the following command: 21 | 22 | ```bash 23 | git clone https://github.com/your-username/AudioSep.git 24 | ``` 25 | 26 | ### **4. 
Set Up the Upstream Remote** 27 | 28 | - Maintain a reference to the primary project by adding it as the upstream remote: 29 | 30 | ```bash 31 | cd AudioSep 32 | git remote add upstream https://github.com/Audio-AGI/AudioSep 33 | git remote -v 34 | ``` 35 | 36 | ### **5. Create a New Branch** 37 | 38 | - Before starting your contribution, establish a new branch dedicated to your specific task: 39 | 40 | ```bash 41 | git checkout -b my-contribution 42 | ``` 43 | 44 | ## **Working on Your Contribution** 45 | 46 | Now that your development environment is ready and a new branch is established, you can start working on your contribution. Please ensure you adhere to the following guidelines: 47 | 48 | ### **6. Make Changes** 49 | 50 | - Implement the necessary changes, including code additions, enhancements, or bug fixes. Ensure your contributions are well-structured, documented, and aligned with the project's objectives. 51 | 52 | ### **7. Commit Your Changes** 53 | 54 | - Commit your changes using informative commit messages that clearly convey the purpose of your contributions: 55 | 56 | ```bash 57 | git commit -m "Add a descriptive message here" 58 | ``` 59 | 60 | ### **8. Push Your Changes** 61 | 62 | - Push the committed changes to your remote repository on GitHub: 63 | 64 | ```bash 65 | git push origin my-contribution 66 | ``` 67 | 68 | ### **9. Create a Pull Request** 69 | 70 | - Visit your repository on GitHub and click the "New Pull Request" button to initiate a pull request from your branch to the primary repository. 71 | 72 | ### **10. Await Review** 73 | 74 | - Your pull request will undergo review, and feedback will be provided by the project maintainers or fellow contributors. Be prepared to address any suggested changes or refinements. 75 | 76 | ## **Community Engagement** 77 | 78 | While contributing, please consider engaging with the community in the following ways: 79 | 80 | ### **11. Join Discussions** 81 | 82 | - Participate in discussions related to audio separation techniques and their applications. Share your insights, experiences, and expertise in the audio field. 83 | 84 | ### **12. Share Ideas** 85 | 86 | - If you have innovative ideas for advancing the project or optimizing audio separation, such as new algorithms or research findings, feel free to open issues to initiate productive discussions. 87 | 88 | ## **Acknowledgment** 89 | 90 | We appreciate your dedication to the world of audio separation. Your contributions play a crucial role in harmonizing audio and improving the listening experience for all. If you have questions or require assistance, please don't hesitate to contact the project maintainers. 91 | 92 | Thank you for your valuable contributions, and we eagerly anticipate collaborating with you on AudioSep! 
🎶🙌 93 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Xubo Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Separate Anything You Describe 2 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2308.05037) [![GitHub Stars](https://img.shields.io/github/stars/Audio-AGI/AudioSep?style=social)](https://github.com/Audio-AGI/AudioSep/) [![githubio](https://img.shields.io/badge/GitHub.io-Demo_Page-blue?logo=Github&style=flat-square)](https://audio-agi.github.io/Separate-Anything-You-Describe) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Audio-AGI/AudioSep/blob/main/AudioSep_Colab.ipynb) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Audio-AGI/AudioSep) [![Replicate](https://replicate.com/cjwbw/audiosep/badge)](https://replicate.com/cjwbw/audiosep) 3 | 4 | 5 | This repository contains the official implementation of ["Separate Anything You Describe"](https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf). 6 | 7 | We introduce AudioSep, a foundation model for open-domain sound separation with natural language queries. AudioSep demonstrates strong separation performance and impressive zero-shot generalization ability on numerous tasks, such as audio event separation, musical instrument separation, and speech enhancement. Check out the separated audio examples on the [Demo Page](https://audio-agi.github.io/Separate-Anything-You-Describe/)! 8 | 9 |

10 | <img src="assets/results.png" alt="AudioSep separation results"/> 11 | 12 | 13 |
14 | 15 | ## Setup 16 | Clone the repository and set up the conda environment: 17 | 18 | ```shell 19 | git clone https://github.com/Audio-AGI/AudioSep.git && \ 20 | cd AudioSep && \ 21 | conda env create -f environment.yml && \ 22 | conda activate AudioSep 23 | ``` 24 | Download the [model weights](https://huggingface.co/spaces/Audio-AGI/AudioSep/tree/main/checkpoint) and place them under `checkpoint/`. 25 | 26 | 27 | If you're using this checkpoint to participate in the DCASE 2024 Task 9 challenge, please note that it was trained on 32 kHz audio with a window size of 2048 points and a hop size of 320 points in the STFT operation, which differs from the provided challenge baseline system (16 kHz, window size 1024, hop size 160). 28 |
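For reference, the sketch below contrasts the two STFT configurations described above. It is an illustration only (not part of the AudioSep codebase) and assumes the stated window size is used as both `n_fft` and the Hann window length:

```python
import torch

# Dummy 5-second signals at the two sampling rates (illustration only).
wav_32k = torch.randn(1, 32000 * 5)  # audio as processed by this AudioSep checkpoint
wav_16k = torch.randn(1, 16000 * 5)  # audio as processed by the DCASE 2024 baseline

# STFT settings of the released AudioSep checkpoint: 32 kHz, window 2048, hop 320.
spec_audiosep = torch.stft(
    wav_32k, n_fft=2048, hop_length=320,
    window=torch.hann_window(2048), return_complex=True)

# STFT settings of the DCASE 2024 Task 9 baseline: 16 kHz, window 1024, hop 160.
spec_baseline = torch.stft(
    wav_16k, n_fft=1024, hop_length=160,
    window=torch.hann_window(1024), return_complex=True)

print(spec_audiosep.shape, spec_baseline.shape)  # different time-frequency resolutions
```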
29 | 30 | ## Inference 31 | 32 | 33 | ```python 34 | from pipeline import build_audiosep, inference 35 | import torch 36 | 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | 39 | model = build_audiosep( 40 | config_yaml='config/audiosep_base.yaml', 41 | checkpoint_path='checkpoint/audiosep_base_4M_steps.ckpt', 42 | device=device) 43 | 44 | audio_file = 'path_to_audio_file' 45 | text = 'textual_description' 46 | output_file='separated_audio.wav' 47 | 48 | # AudioSep processes the audio at 32 kHz sampling rate 49 | inference(model, audio_file, text, output_file, device) 50 | ``` 51 | 52 |
53 | 54 | To load directly from Hugging Face, you can do the following: 55 | 56 | ```python 57 | from models.audiosep import AudioSep 58 | from utils import get_ss_model 59 | import torch 60 | 61 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 62 | 63 | ss_model = get_ss_model('config/audiosep_base.yaml') 64 | 65 | model = AudioSep.from_pretrained("nielsr/audiosep-demo", ss_model=ss_model) 66 | 67 | audio_file = 'path_to_audio_file' 68 | text = 'textual_description' 69 | output_file='separated_audio.wav' 70 | 71 | # AudioSep processes the audio at 32 kHz sampling rate 72 | inference(model, audio_file, text, output_file, device) 73 | ``` 74 |
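As in the first example, the `inference` function used above is provided by the `pipeline` module, so it also needs to be imported:

```python
from pipeline import inference
```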
75 | 76 | Use chunk-based inference to save memory: 77 | ```python 78 | inference(model, audio_file, text, output_file, device, use_chunk=True) 79 | ``` 80 | 81 | ## Training 82 | 83 | To utilize your audio-text paired dataset: 84 | 85 | 1. Format your dataset to match our JSON structure. Refer to the provided template at `datafiles/template.json`. 86 | 87 | 2. Update the `config/audiosep_base.yaml` file by listing your formatted JSON data files under `datafiles`. For example: 88 | 89 | ```yaml 90 | data: 91 | datafiles: 92 | - 'datafiles/your_datafile_1.json' 93 | - 'datafiles/your_datafile_2.json' 94 | ... 95 | ``` 96 | 97 | Train AudioSep from scratch: 98 | ```python 99 | python train.py --workspace workspace/AudioSep --config_yaml config/audiosep_base.yaml --resume_checkpoint_path checkpoint/ '' 100 | ``` 101 | 102 | Finetune AudioSep from pretrained checkpoint: 103 | ```python 104 | python train.py --workspace workspace/AudioSep --config_yaml config/audiosep_base.yaml --resume_checkpoint_path path_to_checkpoint 105 | ``` 106 | 107 |
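To generate a datafile in the JSON format described in step 1 above, a minimal sketch such as the following can be used (the wav paths and captions are placeholders to be replaced with your own audio-text data):

```python
import json
from pathlib import Path

# Placeholder (wav_path, caption) pairs -- substitute your own audio-text data.
pairs = [
    ("path/to/audio_1.wav", "a dog barking in the distance"),
    ("path/to/audio_2.wav", "rain falling on a metal roof"),
]

datafile = {"data": [{"wav": wav, "caption": caption} for wav, caption in pairs]}

out_path = Path("datafiles/your_datafile_1.json")
out_path.parent.mkdir(parents=True, exist_ok=True)
with open(out_path, "w") as fp:
    json.dump(datafile, fp, indent=2)
```

The resulting file follows the same structure as `datafiles/template.json` and can then be listed under `datafiles` in `config/audiosep_base.yaml`.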
108 | 109 | ## Benchmark Evaluation 110 | Download the [evaluation data](https://drive.google.com/drive/folders/1PbCsuvdrzwAZZ_fwIzF0PeVGZkTk0-kL?usp=sharing) under the `evaluation/data` folder. The data should be organized as follows: 111 | 112 | ```yaml 113 | evaluation: 114 | data: 115 | - audioset/ 116 | - audiocaps/ 117 | - vggsound/ 118 | - music/ 119 | - clotho/ 120 | - esc50/ 121 | ``` 122 | Run benchmark inference script, the results will be saved at `eval_logs/` 123 | ```python 124 | python benchmark.py --checkpoint_path audiosep_base_4M_steps.ckpt 125 | 126 | """ 127 | Evaluation Results: 128 | 129 | VGGSound Avg SDRi: 9.144, SISDR: 9.043 130 | MUSIC Avg SDRi: 10.508, SISDR: 9.425 131 | ESC-50 Avg SDRi: 10.040, SISDR: 8.810 132 | AudioSet Avg SDRi: 7.739, SISDR: 6.903 133 | AudioCaps Avg SDRi: 8.220, SISDR: 7.189 134 | Clotho Avg SDRi: 6.850, SISDR: 5.242 135 | """ 136 | ``` 137 | 138 | ## Cite this work 139 | 140 | If you found this tool useful, please consider citing 141 | ```bibtex 142 | @article{liu2023separate, 143 | title={Separate Anything You Describe}, 144 | author={Liu, Xubo and Kong, Qiuqiang and Zhao, Yan and Liu, Haohe and Yuan, Yi, and Liu, Yuzhuo, and Xia, Rui and Wang, Yuxuan, and Plumbley, Mark D and Wang, Wenwu}, 145 | journal={arXiv preprint arXiv:2308.05037}, 146 | year={2023} 147 | } 148 | ``` 149 | ```bibtex 150 | @inproceedings{liu22w_interspeech, 151 | title={Separate What You Describe: Language-Queried Audio Source Separation}, 152 | author={Liu, Xubo and Liu, Haohe and Kong, Qiuqiang and Mei, Xinhao and Zhao, Jinzheng and Huang, Qiushi, and Plumbley, Mark D and Wang, Wenwu}, 153 | year=2022, 154 | booktitle={Proc. Interspeech}, 155 | pages={1801--1805}, 156 | } 157 | ``` 158 | 159 | ## Contributors : 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-AGI/AudioSep/944583f18b84589dc965de3ad77525c945334252/assets/results.png -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | from evaluation.evaluate_audioset import AudioSetEvaluator 5 | from evaluation.evaluate_audiocaps import AudioCapsEvaluator 6 | from evaluation.evaluate_vggsound import VGGSoundEvaluator 7 | from evaluation.evaluate_music import MUSICEvaluator 8 | from evaluation.evaluate_esc50 import ESC50Evaluator 9 | from evaluation.evaluate_clotho import ClothoEvaluator 10 | from models.clap_encoder import CLAP_Encoder 11 | 12 | from utils import ( 13 | load_ss_model, 14 | calculate_sdr, 15 | calculate_sisdr, 16 | parse_yaml, 17 | get_mean_sdr_from_dict, 18 | ) 19 | 20 | def eval(checkpoint_path, config_yaml='config/audiosep_base.yaml'): 21 | 22 | log_dir = 'eval_logs' 23 | os.makedirs(log_dir, exist_ok=True) 24 | 25 | device = "cuda" 26 | 27 | configs = parse_yaml(config_yaml) 28 | 29 | # AudioSet Evaluators 30 | audioset_evaluator = AudioSetEvaluator() 31 | # AudioCaps Evaluator 32 | audiocaps_evaluator = AudioCapsEvaluator() 33 | # VGGSound+ Evaluator 34 | vggsound_evaluator = VGGSoundEvaluator() 35 | # Clotho Evaluator 36 | clotho_evaluator = ClothoEvaluator() 37 | # MUSIC Evaluator 38 | music_evaluator = MUSICEvaluator() 39 | # ESC-50 Evaluator 40 | esc50_evaluator = ESC50Evaluator() 41 | 42 | # Load model 
43 | query_encoder = CLAP_Encoder().eval() 44 | 45 | pl_model = load_ss_model( 46 | configs=configs, 47 | checkpoint_path=checkpoint_path, 48 | query_encoder=query_encoder 49 | ).to(device) 50 | 51 | print(f'------- Start Evaluation -------') 52 | 53 | # evaluation on Clotho 54 | SISDR, SDRi = clotho_evaluator(pl_model) 55 | msg_clotho = "Clotho Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR) 56 | print(msg_clotho) 57 | 58 | # evaluation on VGGSound+ (YAN) 59 | SISDR, SDRi = vggsound_evaluator(pl_model) 60 | msg_vgg = "VGGSound Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR) 61 | print(msg_vgg) 62 | 63 | # evaluation on MUSIC 64 | SISDR, SDRi = music_evaluator(pl_model) 65 | msg_music = "MUSIC Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR) 66 | print(msg_music) 67 | 68 | # evaluation on ESC-50 69 | SISDR, SDRi = esc50_evaluator(pl_model) 70 | msg_esc50 = "ESC-50 Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR) 71 | print(msg_esc50) 72 | 73 | # evaluation on AudioSet 74 | stats_dict = audioset_evaluator(pl_model=pl_model) 75 | median_sdris = {} 76 | median_sisdrs = {} 77 | 78 | for class_id in range(527): 79 | median_sdris[class_id] = np.nanmedian(stats_dict["sdris_dict"][class_id]) 80 | median_sisdrs[class_id] = np.nanmedian(stats_dict["sisdrs_dict"][class_id]) 81 | 82 | SDRi = get_mean_sdr_from_dict(median_sdris) 83 | SISDR = get_mean_sdr_from_dict(median_sisdrs) 84 | msg_audioset = "AudioSet Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR) 85 | print(msg_audioset) 86 | 87 | # evaluation on AudioCaps 88 | SISDR, SDRi = audiocaps_evaluator(pl_model) 89 | msg_audiocaps = "AudioCaps Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR) 90 | print(msg_audiocaps) 91 | 92 | # evaluation on Clotho 93 | SISDR, SDRi = clotho_evaluator(pl_model) 94 | msg_clotho = "Clotho Avg SDRi: {:.3f}, SISDR: {:.3f}".format(SDRi, SISDR) 95 | print(msg_clotho) 96 | 97 | msgs = [msg_audioset, msg_vgg, msg_audiocaps, msg_clotho, msg_music, msg_esc50] 98 | 99 | # open file in write mode 100 | log_path = os.path.join(log_dir, 'eval_results.txt') 101 | with open(log_path, 'w') as fp: 102 | for msg in msgs: 103 | fp.write(msg + '\n') 104 | print(f'Eval log is written to {log_path} ...') 105 | print('------------------------- Done ---------------------------') 106 | 107 | 108 | if __name__ == '__main__': 109 | eval(checkpoint_path='checkpoint/audiosep_base.ckpt') 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /callbacks/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lightning.pytorch as pl 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | class CheckpointEveryNSteps(pl.Callback): 7 | def __init__( 8 | self, 9 | checkpoints_dir, 10 | save_step_frequency, 11 | ) -> None: 12 | r"""Save a checkpoint every N steps. 
13 | 14 | Args: 15 | checkpoints_dir (str): directory to save checkpoints 16 | save_step_frequency (int): save checkpoint every N step 17 | """ 18 | 19 | self.checkpoints_dir = checkpoints_dir 20 | self.save_step_frequency = save_step_frequency 21 | 22 | @rank_zero_only 23 | def on_train_batch_end(self, *args, **kwargs) -> None: 24 | r"""Save a checkpoint every N steps.""" 25 | 26 | trainer = args[0] 27 | global_step = trainer.global_step 28 | 29 | if global_step == 1 or global_step % self.save_step_frequency == 0: 30 | 31 | ckpt_path = os.path.join( 32 | self.checkpoints_dir, 33 | "step={}.ckpt".format(global_step)) 34 | trainer.save_checkpoint(ckpt_path) 35 | print("Save checkpoint to {}".format(ckpt_path)) 36 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | python_version: "3.11" 7 | python_packages: 8 | - "torchlibrosa==0.1.0" 9 | - "lightning==2.1.0" 10 | - "torch==2.0.1" 11 | - "transformers==4.28.1" 12 | - "braceexpand==0.1.7" 13 | - "webdataset==0.2.60" 14 | - "soundfile==0.12.1" 15 | - "torchaudio==2.0.2" 16 | - "torchvision==0.15.2" 17 | - "h5py==3.10.0" 18 | - "ftfy==6.1.1" 19 | - "pandas==2.1.1" 20 | - "wget==3.2" 21 | predict: "predict.py:Predictor" 22 | -------------------------------------------------------------------------------- /config/audiosep_base.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | task_name: AudioSep 3 | 4 | data: 5 | datafiles: 6 | - 'datafiles/template.json' 7 | 8 | sampling_rate: 32000 9 | segment_seconds: 5 10 | loudness_norm: 11 | lower_db: -10 12 | higher_db: 10 13 | max_mix_num: 2 14 | 15 | model: 16 | query_net: CLAP 17 | condition_size: 512 18 | model_type: ResUNet30 19 | input_channels: 1 20 | output_channels: 1 21 | resume_checkpoint: "" 22 | use_text_ratio: 1.0 23 | 24 | train: 25 | optimizer: 26 | optimizer_type: AdamW 27 | learning_rate: 1e-3 28 | warm_up_steps: 10000 29 | reduce_lr_steps: 1000000 30 | lr_lambda_type: constant_warm_up 31 | num_nodes: 1 32 | num_workers: 6 33 | loss_type: l1_wav 34 | sync_batchnorm: True 35 | batch_size_per_device: 12 36 | steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. 37 | evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. 38 | save_step_frequency: 20000 # Save every #save_step_frequency steps. 
39 | early_stop_steps: 10000001 40 | random_seed: 1234 41 | 42 | -------------------------------------------------------------------------------- /data/audiotext_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import torch 4 | import torchaudio 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class AudioTextDataset(Dataset): 9 | """Can sample data from audio-text databases 10 | Params: 11 | sampling_rate: audio sampling rate 12 | max_clip_len: max length (seconds) of audio clip to be sampled 13 | """ 14 | def __init__( 15 | self, 16 | datafiles=[''], 17 | sampling_rate=32000, 18 | max_clip_len=5, 19 | ): 20 | all_data_json = [] 21 | for datafile in datafiles: 22 | with open(datafile, 'r') as fp: 23 | data_json = json.load(fp)['data'] 24 | all_data_json.extend(data_json) 25 | self.all_data_json = all_data_json 26 | 27 | self.sampling_rate = sampling_rate 28 | self.max_length = max_clip_len * sampling_rate 29 | 30 | def __len__(self): 31 | return len(self.all_data_json) 32 | 33 | def _cut_or_randomcrop(self, waveform): 34 | # waveform: [1, samples] 35 | # random crop 36 | if waveform.size(1) > self.max_length: 37 | random_idx = random.randint(0, waveform.size(1)-self.max_length) 38 | waveform = waveform[:, random_idx:random_idx+self.max_length] 39 | else: 40 | temp_wav = torch.zeros(1, self.max_length) 41 | temp_wav[:, 0:waveform.size(1)] = waveform 42 | waveform = temp_wav 43 | 44 | assert waveform.size(1) == self.max_length, \ 45 | f"number of audio samples is {waveform.size(1)}" 46 | 47 | return waveform 48 | 49 | def _read_audio(self, index): 50 | try: 51 | audio_path = self.all_data_json[index]['wav'] 52 | audio_data, audio_rate = torchaudio.load(audio_path, channels_first=True) 53 | text = self.all_data_json[index]['caption'] 54 | 55 | # drop short utterance 56 | if audio_data.size(1) < self.sampling_rate * 1: 57 | raise Exception(f'{audio_path} is too short, drop it ...') 58 | 59 | return text, audio_data, audio_rate 60 | 61 | except Exception as e: 62 | print(f'error: {e} occurs, when loading {audio_path}') 63 | random_index = random.randint(0, len(self.all_data_json)-1) 64 | return self._read_audio(index=random_index) 65 | 66 | def __getitem__(self, index): 67 | # create a audio tensor 68 | text, audio_data, audio_rate = self._read_audio(index) 69 | audio_len = audio_data.shape[1] / audio_rate 70 | # convert stero to single channel 71 | if audio_data.shape[0] > 1: 72 | # audio_data: [samples] 73 | audio_data = (audio_data[0] + audio_data[1]) / 2 74 | else: 75 | audio_data = audio_data.squeeze(0) 76 | 77 | # resample audio clip 78 | if audio_rate != self.sampling_rate: 79 | audio_data = torchaudio.functional.resample(audio_data, orig_freq=audio_rate, new_freq=self.sampling_rate) 80 | 81 | audio_data = audio_data.unsqueeze(0) 82 | 83 | audio_data = self._cut_or_randomcrop(audio_data) 84 | 85 | data_dict = { 86 | 'text': text, 87 | 'waveform': audio_data, 88 | 'modality': 'audio_text' 89 | } 90 | 91 | return data_dict 92 | -------------------------------------------------------------------------------- /data/datamodules.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, NoReturn 2 | import torch 3 | import lightning.pytorch as pl 4 | from torch.utils.data import DataLoader 5 | from data.audiotext_dataset import AudioTextDataset 6 | 7 | 8 | class DataModule(pl.LightningDataModule): 9 | def __init__( 10 | self, 11 | 
train_dataset: object, 12 | batch_size: int, 13 | num_workers: int 14 | ): 15 | r"""Data module. To get one batch of data: 16 | 17 | code-block:: python 18 | 19 | data_module.setup() 20 | 21 | for batch_data_dict in data_module.train_dataloader(): 22 | print(batch_data_dict.keys()) 23 | break 24 | 25 | Args: 26 | train_sampler: Sampler object 27 | train_dataset: Dataset object 28 | num_workers: int 29 | distributed: bool 30 | """ 31 | super().__init__() 32 | self._train_dataset = train_dataset 33 | self.num_workers = num_workers 34 | self.batch_size = batch_size 35 | self.collate_fn = collate_fn 36 | 37 | 38 | def prepare_data(self): 39 | # download, split, etc... 40 | # only called on 1 GPU/TPU in distributed 41 | pass 42 | 43 | def setup(self, stage: Optional[str] = None) -> NoReturn: 44 | r"""called on every device.""" 45 | 46 | # make assignments here (val/train/test split) 47 | # called on every process in DDP 48 | 49 | # SegmentSampler is used for selecting segments for training. 50 | # On multiple devices, each SegmentSampler samples a part of mini-batch 51 | # data. 52 | self.train_dataset = self._train_dataset 53 | 54 | 55 | def train_dataloader(self) -> torch.utils.data.DataLoader: 56 | r"""Get train loader.""" 57 | train_loader = DataLoader( 58 | dataset=self.train_dataset, 59 | batch_size=self.batch_size, 60 | collate_fn=self.collate_fn, 61 | num_workers=self.num_workers, 62 | pin_memory=True, 63 | persistent_workers=False, 64 | shuffle=True 65 | ) 66 | 67 | return train_loader 68 | 69 | def val_dataloader(self): 70 | # val_split = Dataset(...) 71 | # return DataLoader(val_split) 72 | pass 73 | 74 | def test_dataloader(self): 75 | # test_split = Dataset(...) 76 | # return DataLoader(test_split) 77 | pass 78 | 79 | def teardown(self): 80 | # clean up after fit or test 81 | # called on every process in DDP 82 | pass 83 | 84 | 85 | def collate_fn(list_data_dict): 86 | r"""Collate mini-batch data to inputs and targets for training. 87 | 88 | Args: 89 | list_data_dict: e.g., [ 90 | { 91 | 'text': 'a sound of dog', 92 | 'waveform': (1, samples), 93 | 'modality': 'audio_text' 94 | } 95 | ... 96 | ] 97 | Returns: 98 | data_dict: e.g. 99 | 'audio_text': { 100 | 'text': ['a sound of dog', ...] 
101 | 'waveform': (batch_size, 1, samples) 102 | } 103 | """ 104 | 105 | at_list_data_dict = [data_dict for data_dict in list_data_dict if data_dict['modality']=='audio_text'] 106 | 107 | at_data_dict = {} 108 | 109 | if len(at_list_data_dict) > 0: 110 | for key in at_list_data_dict[0].keys(): 111 | at_data_dict[key] = [at_data_dict[key] for at_data_dict in at_list_data_dict] 112 | if key == 'waveform': 113 | at_data_dict[key] = torch.stack(at_data_dict[key]) 114 | elif key == 'text': 115 | at_data_dict[key] = [text for text in at_data_dict[key]] 116 | 117 | 118 | data_dict = { 119 | 'audio_text': at_data_dict 120 | } 121 | 122 | return data_dict -------------------------------------------------------------------------------- /data/waveform_mixers.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sre_compile 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import pyloudnorm as pyln 7 | 8 | 9 | class SegmentMixer(nn.Module): 10 | def __init__(self, max_mix_num, lower_db, higher_db): 11 | super(SegmentMixer, self).__init__() 12 | 13 | self.max_mix_num = max_mix_num 14 | self.loudness_param = { 15 | 'lower_db': lower_db, 16 | 'higher_db': higher_db, 17 | } 18 | 19 | def __call__(self, waveforms): 20 | 21 | batch_size = waveforms.shape[0] 22 | 23 | data_dict = { 24 | 'segment': [], 25 | 'mixture': [], 26 | } 27 | 28 | for n in range(0, batch_size): 29 | 30 | segment = waveforms[n].clone() 31 | 32 | # create zero tensors as the background template 33 | noise = torch.zeros_like(segment) 34 | 35 | mix_num = random.randint(2, self.max_mix_num) 36 | assert mix_num >= 2 37 | 38 | for i in range(1, mix_num): 39 | next_segment = waveforms[(n + i) % batch_size] 40 | rescaled_next_segment = dynamic_loudnorm(audio=next_segment, reference=segment, **self.loudness_param) 41 | noise += rescaled_next_segment 42 | 43 | # randomly normalize background noise 44 | noise = dynamic_loudnorm(audio=noise, reference=segment, **self.loudness_param) 45 | 46 | # create audio mixyure 47 | mixture = segment + noise 48 | 49 | # declipping if need be 50 | max_value = torch.max(torch.abs(mixture)) 51 | if max_value > 1: 52 | segment *= 0.9 / max_value 53 | mixture *= 0.9 / max_value 54 | 55 | data_dict['segment'].append(segment) 56 | data_dict['mixture'].append(mixture) 57 | 58 | for key in data_dict.keys(): 59 | data_dict[key] = torch.stack(data_dict[key], dim=0) 60 | 61 | # return data_dict 62 | return data_dict['mixture'], data_dict['segment'] 63 | 64 | 65 | def rescale_to_match_energy(segment1, segment2): 66 | 67 | ratio = get_energy_ratio(segment1, segment2) 68 | rescaled_segment1 = segment1 / ratio 69 | return rescaled_segment1 70 | 71 | 72 | def get_energy(x): 73 | return torch.mean(x ** 2) 74 | 75 | 76 | def get_energy_ratio(segment1, segment2): 77 | 78 | energy1 = get_energy(segment1) 79 | energy2 = max(get_energy(segment2), 1e-10) 80 | ratio = (energy1 / energy2) ** 0.5 81 | ratio = torch.clamp(ratio, 0.02, 50) 82 | return ratio 83 | 84 | 85 | def dynamic_loudnorm(audio, reference, lower_db=-10, higher_db=10): 86 | rescaled_audio = rescale_to_match_energy(audio, reference) 87 | 88 | delta_loudness = random.randint(lower_db, higher_db) 89 | 90 | gain = np.power(10.0, delta_loudness / 20.0) 91 | 92 | return gain * rescaled_audio 93 | 94 | 95 | def torch_to_numpy(tensor): 96 | """Convert a PyTorch tensor to a NumPy array.""" 97 | if isinstance(tensor, torch.Tensor): 98 | return tensor.detach().cpu().numpy() 99 | else: 100 | raise 
ValueError("Input must be a PyTorch tensor.") 101 | 102 | 103 | def numpy_to_torch(array): 104 | """Convert a NumPy array to a PyTorch tensor.""" 105 | if isinstance(array, np.ndarray): 106 | return torch.from_numpy(array) 107 | else: 108 | raise ValueError("Input must be a NumPy array.") 109 | 110 | 111 | # decayed 112 | def random_loudness_norm(audio, lower_db=-35, higher_db=-15, sr=32000): 113 | device = audio.device 114 | audio = torch_to_numpy(audio.squeeze(0)) 115 | # randomly select a norm volume 116 | norm_vol = random.randint(lower_db, higher_db) 117 | 118 | # measure the loudness first 119 | meter = pyln.Meter(sr) # create BS.1770 meter 120 | loudness = meter.integrated_loudness(audio) 121 | # loudness normalize audio 122 | normalized_audio = pyln.normalize.loudness(audio, loudness, norm_vol) 123 | 124 | normalized_audio = numpy_to_torch(normalized_audio).unsqueeze(0) 125 | 126 | return normalized_audio.to(device) 127 | 128 | -------------------------------------------------------------------------------- /datafiles/template.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "wav": "path_to_audio_file", 5 | "caption": "textual_desciptions" 6 | } 7 | ] 8 | } -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: AudioSep 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - backcall=0.2.0=pyhd3eb1b0_0 10 | - blas=1.0=mkl 11 | - boltons=23.0.0=py310h06a4308_0 12 | - brotlipy=0.7.0=py310h7f8727e_1002 13 | - bzip2=1.0.8=h7b6447c_0 14 | - ca-certificates=2023.01.10=h06a4308_0 15 | - certifi=2022.12.7=py310h06a4308_0 16 | - cffi=1.15.1=py310h5eee18b_3 17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 18 | - comm=0.1.2=py310h06a4308_0 19 | - conda=23.3.1=py310h06a4308_0 20 | - conda-content-trust=0.1.3=py310h06a4308_0 21 | - conda-package-handling=2.0.2=py310h06a4308_0 22 | - conda-package-streaming=0.7.0=py310h06a4308_0 23 | - cryptography=38.0.4=py310h9ce1e76_0 24 | - cuda=11.6.1=0 25 | - cuda-cccl=11.6.55=hf6102b2_0 26 | - cuda-command-line-tools=11.6.2=0 27 | - cuda-compiler=11.6.2=0 28 | - cuda-cudart=11.6.55=he381448_0 29 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 30 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 31 | - cuda-cupti=11.6.124=h86345e5_0 32 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 33 | - cuda-driver-dev=11.6.55=0 34 | - cuda-gdb=12.1.55=0 35 | - cuda-libraries=11.6.1=0 36 | - cuda-libraries-dev=11.6.1=0 37 | - cuda-memcheck=11.8.86=0 38 | - cuda-nsight=12.1.55=0 39 | - cuda-nsight-compute=12.1.0=0 40 | - cuda-nvcc=11.6.124=hbba6d2d_0 41 | - cuda-nvdisasm=12.1.55=0 42 | - cuda-nvml-dev=11.6.55=haa9ef22_0 43 | - cuda-nvprof=12.1.55=0 44 | - cuda-nvprune=11.6.124=he22ec0a_0 45 | - cuda-nvrtc=11.6.124=h020bade_0 46 | - cuda-nvrtc-dev=11.6.124=h249d397_0 47 | - cuda-nvtx=11.6.124=h0630a44_0 48 | - cuda-nvvp=12.1.55=0 49 | - cuda-runtime=11.6.1=0 50 | - cuda-samples=11.6.101=h8efea70_0 51 | - cuda-sanitizer-api=12.1.55=0 52 | - cuda-toolkit=11.6.1=0 53 | - cuda-tools=11.6.1=0 54 | - cuda-visual-tools=11.6.1=0 55 | - debugpy=1.5.1=py310h295c915_0 56 | - decorator=5.1.1=pyhd3eb1b0_0 57 | - flit-core=3.8.0=py310h06a4308_0 58 | - freetype=2.12.1=h4a9f257_0 59 | - gds-tools=1.6.0.25=0 60 | - giflib=5.2.1=h5eee18b_3 61 | - gmp=6.2.1=h295c915_3 62 | - gnutls=3.6.15=he1e5248_0 63 | - idna=3.4=py310h06a4308_0 64 | - 
intel-openmp=2021.4.0=h06a4308_3561 65 | - ipykernel=6.19.2=py310h2f386ee_0 66 | - ipython=8.12.0=py310h06a4308_0 67 | - jpeg=9e=h5eee18b_1 68 | - jsonpatch=1.32=pyhd3eb1b0_0 69 | - jsonpointer=2.1=pyhd3eb1b0_0 70 | - jupyter_client=8.1.0=py310h06a4308_0 71 | - jupyter_core=5.3.0=py310h06a4308_0 72 | - lame=3.100=h7b6447c_0 73 | - lcms2=2.12=h3be6417_0 74 | - ld_impl_linux-64=2.38=h1181459_1 75 | - lerc=3.0=h295c915_0 76 | - libcublas=11.9.2.110=h5e84587_0 77 | - libcublas-dev=11.9.2.110=h5c901ab_0 78 | - libcufft=10.7.1.112=hf425ae0_0 79 | - libcufft-dev=10.7.1.112=ha5ce4c0_0 80 | - libcufile=1.6.0.25=0 81 | - libcufile-dev=1.6.0.25=0 82 | - libcurand=10.3.2.56=0 83 | - libcurand-dev=10.3.2.56=0 84 | - libcusolver=11.3.4.124=h33c3c4e_0 85 | - libcusparse=11.7.2.124=h7538f96_0 86 | - libcusparse-dev=11.7.2.124=hbbe9722_0 87 | - libdeflate=1.17=h5eee18b_0 88 | - libffi=3.4.2=h6a678d5_6 89 | - libgcc-ng=11.2.0=h1234567_1 90 | - libgomp=11.2.0=h1234567_1 91 | - libiconv=1.16=h7f8727e_2 92 | - libidn2=2.3.2=h7f8727e_0 93 | - libnpp=11.6.3.124=hd2722f0_0 94 | - libnpp-dev=11.6.3.124=h3c42840_0 95 | - libnvjpeg=11.6.2.124=hd473ad6_0 96 | - libnvjpeg-dev=11.6.2.124=hb5906b9_0 97 | - libpng=1.6.39=h5eee18b_0 98 | - libsodium=1.0.18=h7b6447c_0 99 | - libstdcxx-ng=11.2.0=h1234567_1 100 | - libtasn1=4.19.0=h5eee18b_0 101 | - libtiff=4.5.0=h6a678d5_2 102 | - libunistring=0.9.10=h27cfd23_0 103 | - libuuid=1.41.5=h5eee18b_0 104 | - libwebp=1.2.4=h11a3e52_1 105 | - libwebp-base=1.2.4=h5eee18b_1 106 | - lz4-c=1.9.4=h6a678d5_0 107 | - matplotlib-inline=0.1.6=py310h06a4308_0 108 | - mkl=2021.4.0=h06a4308_640 109 | - mkl-service=2.4.0=py310h7f8727e_0 110 | - mkl_fft=1.3.1=py310hd6ae3a3_0 111 | - mkl_random=1.2.2=py310h00e6091_0 112 | - ncurses=6.4=h6a678d5_0 113 | - nest-asyncio=1.5.6=py310h06a4308_0 114 | - nettle=3.7.3=hbbd107a_1 115 | - nsight-compute=2023.1.0.15=0 116 | - numpy=1.23.5=py310hd5efca6_0 117 | - numpy-base=1.23.5=py310h8e6c178_0 118 | - openh264=2.1.1=h4ff587b_0 119 | - openssl=1.1.1t=h7f8727e_0 120 | - packaging=23.0=py310h06a4308_0 121 | - parso=0.8.3=pyhd3eb1b0_0 122 | - pexpect=4.8.0=pyhd3eb1b0_3 123 | - pickleshare=0.7.5=pyhd3eb1b0_1003 124 | - pip=22.3.1=py310h06a4308_0 125 | - platformdirs=2.5.2=py310h06a4308_0 126 | - pluggy=1.0.0=py310h06a4308_1 127 | - psutil=5.9.0=py310h5eee18b_0 128 | - ptyprocess=0.7.0=pyhd3eb1b0_2 129 | - pure_eval=0.2.2=pyhd3eb1b0_0 130 | - pycosat=0.6.4=py310h5eee18b_0 131 | - pycparser=2.21=pyhd3eb1b0_0 132 | - pyopenssl=22.0.0=pyhd3eb1b0_0 133 | - pysocks=1.7.1=py310h06a4308_0 134 | - python=3.10.9=h7a1cb2a_0 135 | - python-dateutil=2.8.2=pyhd3eb1b0_0 136 | - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0 137 | - pytorch-cuda=11.6=h867d48c_1 138 | - pytorch-mutex=1.0=cuda 139 | - pyzmq=23.2.0=py310h6a678d5_0 140 | - readline=8.2=h5eee18b_0 141 | - requests=2.28.1=py310h06a4308_0 142 | - ruamel.yaml=0.17.21=py310h5eee18b_0 143 | - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 144 | - setuptools=65.6.3=py310h06a4308_0 145 | - six=1.16.0=pyhd3eb1b0_1 146 | - sqlite=3.40.1=h5082296_0 147 | - stack_data=0.2.0=pyhd3eb1b0_0 148 | - tk=8.6.12=h1ccaba5_0 149 | - toolz=0.12.0=py310h06a4308_0 150 | - torchaudio=0.13.1=py310_cu116 151 | - torchvision=0.14.1=py310_cu116 152 | - tornado=6.2=py310h5eee18b_0 153 | - tqdm=4.64.1=py310h06a4308_0 154 | - typing_extensions=4.4.0=py310h06a4308_0 155 | - tzdata=2022g=h04d1e81_0 156 | - urllib3=1.26.14=py310h06a4308_0 157 | - wheel=0.37.1=pyhd3eb1b0_0 158 | - xz=5.2.10=h5eee18b_1 159 | - zeromq=4.3.4=h2531618_0 160 | - 
zlib=1.2.13=h5eee18b_0 161 | - zstandard=0.18.0=py310h5eee18b_0 162 | - zstd=1.5.4=hc292b87_0 163 | - pip: 164 | - absl-py==1.4.0 165 | - aiohttp==3.8.4 166 | - aiosignal==1.3.1 167 | - anyio==3.6.2 168 | - appdirs==1.4.4 169 | - arrow==1.2.3 170 | - asttokens==2.2.1 171 | - async-generator==1.10 172 | - async-timeout==4.0.2 173 | - attrs==22.2.0 174 | - audioread==3.0.0 175 | - av==10.0.0 176 | - beartype==0.12.0 177 | - beautifulsoup4==4.12.2 178 | - blessed==1.20.0 179 | - braceexpand==0.1.7 180 | - cachetools==5.3.0 181 | - click==8.1.3 182 | - contourpy==1.0.7 183 | - croniter==1.3.10 184 | - cycler==0.11.0 185 | - dataclasses-json==0.5.8 186 | - dateutils==0.6.12 187 | - decord==0.6.0 188 | - deepdiff==6.3.0 189 | - dtk==0.2 190 | - exceptiongroup==1.1.1 191 | - executing==1.2.0 192 | - fastapi==0.88.0 193 | - ffmpeg==1.4 194 | - ffmpeg-python==0.2.0 195 | - filelock==3.12.0 196 | - fonttools==4.39.3 197 | - frozenlist==1.3.3 198 | - fsspec==2023.4.0 199 | - ftfy==6.1.1 200 | - future==0.18.3 201 | - gammatone==1.0 202 | - google-auth==2.17.3 203 | - google-auth-oauthlib==1.0.0 204 | - greenlet==2.0.2 205 | - grpcio==1.54.0 206 | - h11==0.14.0 207 | - h5py==3.8.0 208 | - hickle==5.0.2 209 | - huggingface-hub==0.14.1 210 | - humanize==4.6.0 211 | - imageio==2.27.0 212 | - inquirer==3.1.3 213 | - ipdb==0.13.13 214 | - itsdangerous==2.1.2 215 | - jedi==0.18.2 216 | - jinja2==3.1.2 217 | - joblib==1.2.0 218 | - kiwisolver==1.4.4 219 | - langchain==0.0.216 220 | - langchainplus-sdk==0.0.17 221 | - lazy-loader==0.2 222 | - librosa==0.10.0.post2 223 | - lightning==2.0.0 224 | - lightning-cloud==0.5.33 225 | - lightning-utilities==0.8.0 226 | - llvmlite==0.39.1 227 | - markdown==3.4.3 228 | - markdown-it-py==2.2.0 229 | - markupsafe==2.1.2 230 | - marshmallow==3.19.0 231 | - marshmallow-enum==1.5.1 232 | - matplotlib==3.7.1 233 | - mdurl==0.1.2 234 | - mergedeep==1.3.4 235 | - mock==5.0.2 236 | - msgpack==1.0.5 237 | - msgpack-numpy==0.4.8 238 | - multidict==6.0.4 239 | - musdb==0.4.0 240 | - mypy-extensions==1.0.0 241 | - networkx==3.1 242 | - nose==1.3.7 243 | - numba==0.56.4 244 | - numexpr==2.8.4 245 | - oauthlib==3.2.2 246 | - openai==0.27.8 247 | - openapi-schema-pydantic==1.2.4 248 | - opencv-python==4.7.0.72 249 | - ordered-set==4.1.0 250 | - outcome==1.2.0 251 | - pandas==1.5.3 252 | - panns-inference==0.1.0 253 | - pesq==0.0.4 254 | - pillow==9.5.0 255 | - pooch==1.6.0 256 | - prompt-toolkit==3.0.38 257 | - protobuf==4.22.3 258 | - pyaml==23.5.9 259 | - pyasn1==0.5.0 260 | - pyasn1-modules==0.3.0 261 | - pydantic==1.10.7 262 | - pygments==2.14.0 263 | - pyjwt==2.6.0 264 | - pyloudnorm==0.1.1 265 | - pyparsing==3.0.9 266 | - pystoi==0.3.3 267 | - python-editor==1.0.4 268 | - python-multipart==0.0.6 269 | - pytorch-ignite==0.3.0 270 | - pytorch-lightning==2.0.1.post0 271 | - pytz==2023.3 272 | - pywavelets==1.4.1 273 | - pyyaml==6.0 274 | - readchar==4.0.5 275 | - regex==2023.3.23 276 | - requests-oauthlib==1.3.1 277 | - resampy==0.4.2 278 | - rich==13.3.3 279 | - rsa==4.9 280 | - scikit-image==0.20.0 281 | - scikit-learn==1.2.2 282 | - scipy==1.10.1 283 | - selenium==4.8.3 284 | - simplejpeg==1.6.6 285 | - sniffio==1.3.0 286 | - sortedcontainers==2.4.0 287 | - soundfile==0.12.1 288 | - soupsieve==2.4 289 | - soxr==0.3.5 290 | - sqlalchemy==2.0.17 291 | - stack-data==0.6.2 292 | - starlette==0.22.0 293 | - starsessions==1.3.0 294 | - stempeg==0.2.3 295 | - tenacity==8.2.2 296 | - tensorboard==2.12.2 297 | - tensorboard-data-server==0.7.0 298 | - tensorboard-plugin-wit==1.8.1 299 | - 
termcolor==1.1.0 300 | - threadpoolctl==3.1.0 301 | - tifffile==2023.3.21 302 | - timm==0.3.2 303 | - tokenizers==0.13.3 304 | - tomli==2.0.1 305 | - torchfile==0.1.0 306 | - torchlibrosa==0.1.0 307 | - torchmetrics==0.11.4 308 | - traitlets==5.9.0 309 | - transformers==4.28.1 310 | - trio==0.22.0 311 | - trio-websocket==0.10.2 312 | - typeguard==3.0.2 313 | - typing-extensions==4.5.0 314 | - typing-inspect==0.9.0 315 | - uvicorn==0.21.1 316 | - visdom==0.1.8.9 317 | - wcwidth==0.2.6 318 | - webdataset==0.2.48 319 | - websocket-client==1.5.1 320 | - websockets==11.0.1 321 | - werkzeug==2.2.3 322 | - wget==3.2 323 | - wsproto==1.2.0 324 | - yarl==1.8.2 325 | - zenodo-get==1.3.4 326 | - zsvision==0.7.8 -------------------------------------------------------------------------------- /environment_win64.yaml: -------------------------------------------------------------------------------- 1 | name: AudioSep 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - pytorch 6 | - nvidia 7 | dependencies: 8 | - audioread=3.0.0=py310h5588dad_1 9 | - blas=1.0=mkl 10 | - brotli-python=1.0.9=py310hd77b12b_7 11 | - bzip2=1.0.8=h2bbff1b_5 12 | - ca-certificates=2023.12.12=haa95532_0 13 | - cchardet=2.1.7=py310hd77b12b_0 14 | - certifi=2024.2.2=py310haa95532_0 15 | - chardet=4.0.0=py310haa95532_1003 16 | - cuda-cccl=12.4.99=h57928b3_0 17 | - cuda-cccl-impl=2.0.1=h57928b3_1 18 | - cuda-cccl_win-64=12.4.99=h57928b3_0 19 | - cuda-cudart=11.7.99=0 20 | - cuda-cudart-dev=11.7.99=0 21 | - cuda-cupti=11.7.101=0 22 | - cuda-libraries=11.7.1=0 23 | - cuda-libraries-dev=11.7.1=0 24 | - cuda-nvrtc=11.7.99=0 25 | - cuda-nvrtc-dev=11.7.99=0 26 | - cuda-nvtx=11.7.91=0 27 | - cuda-runtime=11.7.1=0 28 | - cuda-version=12.4=h3060b56_3 29 | - freetype=2.12.1=ha860e81_0 30 | - intel-openmp=2023.1.0=h59b6b97_46320 31 | - jpeg=9e=h2bbff1b_1 32 | - lerc=3.0=hd77b12b_0 33 | - libcublas=11.10.3.66=0 34 | - libcublas-dev=11.10.3.66=0 35 | - libcufft=10.7.2.124=0 36 | - libcufft-dev=10.7.2.124=0 37 | - libcurand=10.3.5.119=0 38 | - libcurand-dev=10.3.5.119=0 39 | - libcusolver=11.4.0.1=0 40 | - libcusolver-dev=11.4.0.1=0 41 | - libcusparse=11.7.4.91=0 42 | - libcusparse-dev=11.7.4.91=0 43 | - libdeflate=1.17=h2bbff1b_1 44 | - libffi=3.4.4=hd77b12b_0 45 | - libnpp=11.7.4.75=0 46 | - libnpp-dev=11.7.4.75=0 47 | - libnvjpeg=11.8.0.2=0 48 | - libnvjpeg-dev=11.8.0.2=0 49 | - libpng=1.6.39=h8cc25b3_0 50 | - libtiff=4.5.1=hd77b12b_0 51 | - libuv=1.44.2=h2bbff1b_0 52 | - libwebp-base=1.3.2=h2bbff1b_0 53 | - lz4-c=1.9.4=h2bbff1b_0 54 | - mkl=2023.1.0=h6b88ed4_46358 55 | - mkl-service=2.4.0=py310h2bbff1b_1 56 | - mkl_fft=1.3.8=py310h2bbff1b_0 57 | - mkl_random=1.2.4=py310h59b6b97_0 58 | - numpy=1.23.5=py310h85e1a82_1 59 | - numpy-base=1.23.5=py310hb5c95e7_1 60 | - openjpeg=2.4.0=h4fc8c34_0 61 | - openssl=3.0.13=h2bbff1b_0 62 | - pysocks=1.7.1=py310haa95532_0 63 | - python=3.10.13=he1021f5_0 64 | - python_abi=3.10=2_cp310 65 | - pytorch=1.13.1=py3.10_cuda11.7_cudnn8_0 66 | - pytorch-cuda=11.7=h16d0643_5 67 | - pytorch-mutex=1.0=cuda 68 | - requests=2.31.0=py310haa95532_1 69 | - setuptools=68.2.2=py310haa95532_0 70 | - sqlite=3.41.2=h2bbff1b_0 71 | - tbb=2021.8.0=h59b6b97_0 72 | - tk=8.6.12=h2bbff1b_0 73 | - typing_extensions=4.9.0=py310haa95532_1 74 | - tzdata=2024a=h04d1e81_0 75 | - vc=14.2=h21ff451_1 76 | - vs2015_runtime=14.27.29016=h5e58377_2 77 | - wheel=0.41.2=py310haa95532_0 78 | - win_inet_pton=1.1.0=py310haa95532_0 79 | - xz=5.4.6=h8cc25b3_0 80 | - zlib=1.2.13=h8cc25b3_0 81 | - zstd=1.5.5=hd43e919_0 82 | - pip: 83 | - absl-py==1.4.0 84 
| - aiohttp==3.8.4 85 | - aiosignal==1.3.1 86 | - ansicon==1.89.0 87 | - anyio==3.6.2 88 | - appdirs==1.4.4 89 | - argon2-cffi==23.1.0 90 | - argon2-cffi-bindings==21.2.0 91 | - arrow==1.2.3 92 | - asttokens==2.2.1 93 | - async-generator==1.10 94 | - async-lru==2.0.4 95 | - async-timeout==4.0.2 96 | - attrs==22.2.0 97 | - av==10.0.0 98 | - babel==2.14.0 99 | - beartype==0.12.0 100 | - beautifulsoup4==4.12.2 101 | - bleach==6.1.0 102 | - blessed==1.20.0 103 | - braceexpand==0.1.7 104 | - cachetools==5.3.0 105 | - cffi==1.16.0 106 | - charset-normalizer==3.3.2 107 | - click==8.1.3 108 | - cog==0.9.5 109 | - colorama==0.4.6 110 | - comm==0.2.2 111 | - contourpy==1.0.7 112 | - croniter==1.3.10 113 | - cycler==0.11.0 114 | - dataclasses-json==0.5.8 115 | - dateutils==0.6.12 116 | - debugpy==1.8.1 117 | - decorator==5.1.1 118 | - decord==0.6.0 119 | - deepdiff==6.3.0 120 | - defusedxml==0.7.1 121 | - dtk==0.2 122 | - exceptiongroup==1.1.1 123 | - executing==1.2.0 124 | - fastapi==0.88.0 125 | - fastjsonschema==2.19.1 126 | - ffmpeg==1.4 127 | - ffmpeg-python==0.2.0 128 | - filelock==3.12.0 129 | - fonttools==4.39.3 130 | - fqdn==1.5.1 131 | - frozenlist==1.3.3 132 | - fsspec==2023.4.0 133 | - ftfy==6.1.1 134 | - future==0.18.3 135 | - gammatone==1.0.0 136 | - google-auth==2.17.3 137 | - google-auth-oauthlib==1.0.0 138 | - greenlet==2.0.2 139 | - grpcio==1.54.0 140 | - h11==0.14.0 141 | - h5py==3.8.0 142 | - hickle==5.0.2 143 | - httpcore==1.0.4 144 | - httptools==0.6.1 145 | - httpx==0.27.0 146 | - huggingface-hub==0.14.1 147 | - humanize==4.6.0 148 | - idna==3.6 149 | - imageio==2.27.0 150 | - inquirer==3.1.3 151 | - ipdb==0.13.13 152 | - ipykernel==6.29.3 153 | - ipython==8.18.0 154 | - ipywidgets==8.1.2 155 | - isoduration==20.11.0 156 | - itsdangerous==2.1.2 157 | - jedi==0.18.2 158 | - jinja2==3.1.2 159 | - jinxed==1.2.1 160 | - joblib==1.2.0 161 | - json5==0.9.22 162 | - jsonpatch==1.33 163 | - jsonpointer==2.4 164 | - jsonschema==4.21.1 165 | - jsonschema-specifications==2023.12.1 166 | - jupyter==1.0.0 167 | - jupyter-client==8.6.1 168 | - jupyter-console==6.6.3 169 | - jupyter-core==5.7.2 170 | - jupyter-events==0.9.1 171 | - jupyter-lsp==2.2.4 172 | - jupyter-server==2.13.0 173 | - jupyter-server-terminals==0.5.3 174 | - jupyterlab==4.1.5 175 | - jupyterlab-pygments==0.3.0 176 | - jupyterlab-server==2.25.4 177 | - jupyterlab-widgets==3.0.10 178 | - kiwisolver==1.4.4 179 | - langchain==0.0.216 180 | - langchainplus-sdk==0.0.17 181 | - lazy-loader==0.2 182 | - librosa==0.10.0.post2 183 | - lightning==2.0.0 184 | - lightning-cloud==0.5.33 185 | - lightning-utilities==0.8.0 186 | - llvmlite==0.39.1 187 | - markdown==3.4.3 188 | - markdown-it-py==2.2.0 189 | - markupsafe==2.1.2 190 | - marshmallow==3.19.0 191 | - marshmallow-enum==1.5.1 192 | - matplotlib==3.7.1 193 | - matplotlib-inline==0.1.6 194 | - mdurl==0.1.2 195 | - mergedeep==1.3.4 196 | - mistune==3.0.2 197 | - mock==5.0.2 198 | - mpmath==1.3.0 199 | - msgpack==1.0.5 200 | - msgpack-numpy==0.4.8 201 | - multidict==6.0.4 202 | - musdb==0.4.0 203 | - mypy-extensions==1.0.0 204 | - nbclient==0.10.0 205 | - nbconvert==7.16.2 206 | - nbformat==5.10.3 207 | - nest-asyncio==1.6.0 208 | - networkx==3.1 209 | - nose==1.3.7 210 | - notebook==7.1.2 211 | - notebook-shim==0.2.4 212 | - numba==0.56.4 213 | - numexpr==2.8.4 214 | - oauthlib==3.2.2 215 | - openai==0.27.8 216 | - openapi-schema-pydantic==1.2.4 217 | - opencv-python==4.7.0.72 218 | - ordered-set==4.1.0 219 | - outcome==1.2.0 220 | - overrides==7.7.0 221 | - packaging==24.0 222 | - 
pandas==1.5.3 223 | - pandocfilters==1.5.1 224 | - panns-inference==0.1.0 225 | - parso==0.8.3 226 | - pesq==0.0.4 227 | - pillow==9.5.0 228 | - pip==24.0 229 | - platformdirs==4.2.0 230 | - pooch==1.6.0 231 | - prometheus-client==0.20.0 232 | - prompt-toolkit==3.0.38 233 | - protobuf==4.22.3 234 | - psutil==5.9.8 235 | - pure-eval==0.2.2 236 | - pyaml==23.5.9 237 | - pyasn1==0.5.0 238 | - pyasn1-modules==0.3.0 239 | - pycparser==2.21 240 | - pydantic==1.10.7 241 | - pygments==2.14.0 242 | - pyjwt==2.6.0 243 | - pyloudnorm==0.1.1 244 | - pyparsing==3.0.9 245 | - pystoi==0.3.3 246 | - python-dateutil==2.9.0.post0 247 | - python-dotenv==1.0.1 248 | - python-editor==1.0.4 249 | - python-json-logger==2.0.7 250 | - python-multipart==0.0.6 251 | - pytorch-ignite==0.3.0 252 | - pytorch-lightning==2.0.1.post0 253 | - pytz==2023.3 254 | - pywavelets==1.4.1 255 | - pywin32==306 256 | - pywinpty==2.0.13 257 | - pyyaml==6.0 258 | - pyzmq==25.1.2 259 | - qtconsole==5.5.1 260 | - qtpy==2.4.1 261 | - readchar==4.0.5 262 | - referencing==0.33.0 263 | - regex==2023.3.23 264 | - requests-oauthlib==1.3.1 265 | - resampy==0.4.2 266 | - rfc3339-validator==0.1.4 267 | - rfc3986-validator==0.1.1 268 | - rich==13.3.3 269 | - rpds-py==0.18.0 270 | - rsa==4.9 271 | - scikit-image==0.20.0 272 | - scikit-learn==1.2.2 273 | - scipy==1.10.1 274 | - selenium==4.8.3 275 | - send2trash==1.8.2 276 | - simplejpeg==1.6.6 277 | - six==1.16.0 278 | - sniffio==1.3.0 279 | - sortedcontainers==2.4.0 280 | - soundfile==0.12.1 281 | - soupsieve==2.4 282 | - soxr==0.3.5 283 | - sqlalchemy==2.0.17 284 | - stack-data==0.6.2 285 | - starlette==0.22.0 286 | - starsessions==1.3.0 287 | - stempeg==0.2.3 288 | - structlog==24.1.0 289 | - sympy==1.12 290 | - tenacity==8.2.2 291 | - tensorboard==2.12.2 292 | - tensorboard-data-server==0.7.0 293 | - tensorboard-plugin-wit==1.8.1 294 | - termcolor==1.1.0 295 | - terminado==0.18.1 296 | - threadpoolctl==3.1.0 297 | - tifffile==2023.3.21 298 | - timm==0.3.2 299 | - tinycss2==1.2.1 300 | - tokenizers==0.13.3 301 | - tomli==2.0.1 302 | - torch==2.1.2 303 | - torchaudio==0.13.1 304 | - torchfile==0.1.0 305 | - torchlibrosa==0.1.0 306 | - torchmetrics==0.11.4 307 | - torchvision==0.16.2 308 | - tornado==6.4 309 | - tqdm==4.66.2 310 | - traitlets==5.9.0 311 | - transformers==4.28.1 312 | - trio==0.22.0 313 | - trio-websocket==0.10.2 314 | - typeguard==3.0.2 315 | - typing-extensions==4.5.0 316 | - typing-inspect==0.9.0 317 | - uri-template==1.3.0 318 | - urllib3==1.26.18 319 | - uvicorn==0.21.1 320 | - visdom==0.1.8.9 321 | - watchfiles==0.21.0 322 | - wcwidth==0.2.6 323 | - webcolors==1.13 324 | - webdataset==0.2.48 325 | - webencodings==0.5.1 326 | - websocket-client==1.5.1 327 | - websockets==11.0.1 328 | - werkzeug==2.2.3 329 | - wget==3.2 330 | - widgetsnbextension==4.0.10 331 | - wsproto==1.2.0 332 | - yarl==1.8.2 333 | - zenodo-get==1.3.4 334 | - zsvision==0.7.8 335 | -------------------------------------------------------------------------------- /evaluation/evaluate_audiocaps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from typing import Dict, List 5 | 6 | import csv 7 | import pandas as pd 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | import pathlib 12 | import librosa 13 | import lightning.pytorch as pl 14 | from models.clap_encoder import CLAP_Encoder 15 | 16 | sys.path.append('../AudioSep/') 17 | from utils import ( 18 | load_ss_model, 19 | calculate_sdr, 20 | 
calculate_sisdr, 21 | parse_yaml, 22 | get_mean_sdr_from_dict, 23 | ) 24 | 25 | 26 | class AudioCapsEvaluator: 27 | def __init__( 28 | self, 29 | query='caption', 30 | sampling_rate=32000, 31 | ) -> None: 32 | r"""AudioCaps evaluator. 33 | 34 | Args: 35 | query (str): type of query, 'caption' or 'labels' 36 | Returns: 37 | None 38 | """ 39 | 40 | self.query = query 41 | self.sampling_rate = sampling_rate 42 | 43 | with open(f'evaluation/metadata/audiocaps_eval.csv') as csv_file: 44 | csv_reader = csv.reader(csv_file, delimiter=',') 45 | eval_list = [row for row in csv_reader][1:] 46 | 47 | self.eval_list = eval_list 48 | self.audio_dir = f'evaluation/data/audiocaps' 49 | 50 | def __call__( 51 | self, 52 | pl_model: pl.LightningModule 53 | ) -> Dict: 54 | r"""Evalute.""" 55 | 56 | print(f'Evaluation on AudioCaps with [{self.query}] queries.') 57 | 58 | pl_model.eval() 59 | device = pl_model.device 60 | 61 | sisdrs_list = [] 62 | sdris_list = [] 63 | 64 | with torch.no_grad(): 65 | for eval_data in tqdm(self.eval_list): 66 | 67 | idx, caption, labels, _, _ = eval_data 68 | 69 | source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav') 70 | mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav') 71 | 72 | source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True) 73 | mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True) 74 | 75 | sdr_no_sep = calculate_sdr(ref=source, est=mixture) 76 | 77 | if self.query == 'caption': 78 | text = [caption] 79 | elif self.query == 'labels': 80 | text = [labels] 81 | 82 | conditions = pl_model.query_encoder.get_query_embed( 83 | modality='text', 84 | text=text, 85 | device=device 86 | ) 87 | 88 | input_dict = { 89 | "mixture": torch.Tensor(mixture)[None, None, :].to(device), 90 | "condition": conditions, 91 | } 92 | 93 | 94 | sep_segment = pl_model.ss_model(input_dict)["waveform"] 95 | # sep_segment: (batch_size=1, channels_num=1, segment_samples) 96 | 97 | sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy() 98 | # sep_segment: (segment_samples,) 99 | 100 | sdr = calculate_sdr(ref=source, est=sep_segment) 101 | sdri = sdr - sdr_no_sep 102 | sisdr = calculate_sisdr(ref=source, est=sep_segment) 103 | 104 | sisdrs_list.append(sisdr) 105 | sdris_list.append(sdri) 106 | 107 | mean_sisdr = np.mean(sisdrs_list) 108 | mean_sdri = np.mean(sdris_list) 109 | 110 | return mean_sisdr, mean_sdri -------------------------------------------------------------------------------- /evaluation/evaluate_audioset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from typing import Dict, List 5 | 6 | import pandas as pd 7 | import numpy as np 8 | import torch 9 | from tqdm import tqdm 10 | import pathlib 11 | import librosa 12 | import lightning.pytorch as pl 13 | from models.clap_encoder import CLAP_Encoder 14 | 15 | sys.path.append('../AudioSep/') 16 | from utils import ( 17 | load_ss_model, 18 | calculate_sdr, 19 | calculate_sisdr, 20 | parse_yaml, 21 | get_mean_sdr_from_dict, 22 | ) 23 | 24 | 25 | meta_csv_file = "evaluation/metadata/class_labels_indices.csv" 26 | df = pd.read_csv(meta_csv_file, sep=',') 27 | 28 | IDS = df['mid'].tolist() 29 | LABELS = df['display_name'].tolist() 30 | 31 | CLASSES_NUM = len(LABELS) 32 | 33 | IX_TO_LB = {i : label for i, label in enumerate(LABELS)} 34 | 35 | 36 | class AudioSetEvaluator: 37 | def __init__( 38 | self, 39 | audios_dir='evaluation/data/audioset', 40 | classes_num=527, 41 | 
sampling_rate=32000, 42 | number_per_class=10, 43 | ) -> None: 44 | r"""AudioSet evaluator. 45 | 46 | Args: 47 | audios_dir (str): directory of evaluation segments 48 | classes_num (int): the number of sound classes 49 | number_per_class (int), the number of samples to evaluate for each sound class 50 | 51 | Returns: 52 | None 53 | """ 54 | 55 | self.audios_dir = audios_dir 56 | self.classes_num = classes_num 57 | self.number_per_class = number_per_class 58 | self.sampling_rate = sampling_rate 59 | 60 | @torch.no_grad() 61 | def __call__( 62 | self, 63 | pl_model: pl.LightningModule 64 | ) -> Dict: 65 | r"""Evalute.""" 66 | 67 | pl_model.eval() 68 | 69 | sisdrs_dict = {class_id: [] for class_id in range(self.classes_num)} 70 | sdris_dict = {class_id: [] for class_id in range(self.classes_num)} 71 | 72 | print('Evaluation on AudioSet with [text label] queries.') 73 | 74 | for class_id in tqdm(range(self.classes_num)): 75 | 76 | sub_dir = os.path.join( 77 | self.audios_dir, 78 | "class_id={}".format(class_id)) 79 | 80 | audio_names = self._get_audio_names(audios_dir=sub_dir) 81 | 82 | for audio_index, audio_name in enumerate(audio_names): 83 | 84 | if audio_index == self.number_per_class: 85 | break 86 | 87 | source_path = os.path.join( 88 | sub_dir, "{},source.wav".format(audio_name)) 89 | mixture_path = os.path.join( 90 | sub_dir, "{},mixture.wav".format(audio_name)) 91 | 92 | source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True) 93 | mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True) 94 | 95 | sdr_no_sep = calculate_sdr(ref=source, est=mixture) 96 | 97 | device = pl_model.device 98 | 99 | text = [IX_TO_LB[class_id]] 100 | 101 | conditions = pl_model.query_encoder.get_query_embed( 102 | modality='text', 103 | text=text, 104 | device=device 105 | ) 106 | 107 | input_dict = { 108 | "mixture": torch.Tensor(mixture)[None, None, :].to(device), 109 | "condition": conditions, 110 | } 111 | 112 | sep_segment = pl_model.ss_model(input_dict)["waveform"] 113 | # sep_segment: (batch_size=1, channels_num=1, segment_samples) 114 | 115 | sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy() 116 | # sep_segment: (segment_samples,) 117 | 118 | sdr = calculate_sdr(ref=source, est=sep_segment) 119 | sdri = sdr - sdr_no_sep 120 | sisdr = calculate_sisdr(ref=source, est=sep_segment) 121 | 122 | 123 | sisdrs_dict[class_id].append(sisdr) 124 | sdris_dict[class_id].append(sdri) 125 | 126 | 127 | stats_dict = { 128 | "sisdrs_dict": sisdrs_dict, 129 | "sdris_dict": sdris_dict, 130 | } 131 | 132 | return stats_dict 133 | 134 | def _get_audio_names(self, audios_dir: str) -> List[str]: 135 | r"""Get evaluation audio names.""" 136 | audio_names = sorted(os.listdir(audios_dir)) 137 | 138 | audio_names = [audio_name for audio_name in audio_names if '.wav' in audio_name] 139 | 140 | audio_names = [ 141 | re.search( 142 | "(.*),(mixture|source).wav", 143 | audio_name).group(1) for audio_name in audio_names] 144 | 145 | audio_names = sorted(list(set(audio_names))) 146 | 147 | return audio_names 148 | 149 | @staticmethod 150 | def get_median_metrics(stats_dict, metric_type): 151 | class_ids = stats_dict[metric_type].keys() 152 | median_stats_dict = { 153 | class_id: np.nanmedian( 154 | stats_dict[metric_type][class_id]) for class_id in class_ids} 155 | return median_stats_dict 156 | -------------------------------------------------------------------------------- /evaluation/evaluate_clotho.py: 
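Once a trained AudioSep `LightningModule` is in hand, the evaluator above is a plain callable. The following is a minimal sketch of driving `AudioSetEvaluator` and aggregating its per-class statistics; it assumes `pl_model` has already been built and loaded elsewhere (the checkpoint-loading helpers in `utils.py` are not shown in this dump) and that the evaluation segments sit under the default `evaluation/data/audioset` layout.

```python
import numpy as np
from evaluation.evaluate_audioset import AudioSetEvaluator

def run_audioset_eval(pl_model):
    """Evaluate a loaded AudioSep LightningModule on the AudioSet split."""
    evaluator = AudioSetEvaluator(
        audios_dir="evaluation/data/audioset",
        classes_num=527,
        number_per_class=10,
    )
    stats = evaluator(pl_model)  # {"sisdrs_dict": ..., "sdris_dict": ...}

    # Per-class medians, then a mean over the 527 classes (nan-safe for empty classes).
    median_sisdr = AudioSetEvaluator.get_median_metrics(stats, "sisdrs_dict")
    median_sdri = AudioSetEvaluator.get_median_metrics(stats, "sdris_dict")
    return np.nanmean(list(median_sisdr.values())), np.nanmean(list(median_sdri.values()))
```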
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from typing import Dict, List 5 | 6 | import csv 7 | import pandas as pd 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | import pathlib 12 | import librosa 13 | import lightning.pytorch as pl 14 | from models.clap_encoder import CLAP_Encoder 15 | 16 | sys.path.append('../AudioSep/') 17 | from utils import ( 18 | load_ss_model, 19 | calculate_sdr, 20 | calculate_sisdr, 21 | parse_yaml, 22 | get_mean_sdr_from_dict, 23 | ) 24 | 25 | 26 | class ClothoEvaluator: 27 | def __init__( 28 | self, 29 | sampling_rate=32000, 30 | ) -> None: 31 | r"""Clotho evaluator. 32 | Returns: 33 | None 34 | """ 35 | 36 | self.sampling_rate = sampling_rate 37 | 38 | with open('evaluation/metadata/clotho_eval.csv') as csv_file: 39 | csv_reader = csv.reader(csv_file, delimiter=',') 40 | eval_list = [row for row in csv_reader][1:] 41 | 42 | self.eval_list = eval_list 43 | self.audio_dir = 'evaluation/data/clotho' 44 | 45 | def __call__( 46 | self, 47 | pl_model: pl.LightningModule 48 | ) -> Dict: 49 | r"""Evalute.""" 50 | 51 | print(f'Evaluation on Clotho Evaluation with [caption] queries.') 52 | 53 | pl_model.eval() 54 | device = pl_model.device 55 | 56 | sisdrs_list = [] 57 | sdris_list = [] 58 | 59 | with torch.no_grad(): 60 | for eval_data in tqdm(self.eval_list): 61 | 62 | idx, caption, _, _, _ = eval_data 63 | 64 | source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav') 65 | mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav') 66 | 67 | source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True) 68 | mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True) 69 | 70 | sdr_no_sep = calculate_sdr(ref=source, est=mixture) 71 | 72 | text = [caption] 73 | 74 | conditions = pl_model.query_encoder.get_query_embed( 75 | modality='text', 76 | text=text, 77 | device=device 78 | ) 79 | 80 | input_dict = { 81 | "mixture": torch.Tensor(mixture)[None, None, :].to(device), 82 | "condition": conditions, 83 | } 84 | 85 | sep_segment = pl_model.ss_model(input_dict)["waveform"] 86 | # sep_segment: (batch_size=1, channels_num=1, segment_samples) 87 | 88 | sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy() 89 | # sep_segment: (segment_samples,) 90 | 91 | sdr = calculate_sdr(ref=source, est=sep_segment) 92 | sdri = sdr - sdr_no_sep 93 | sisdr = calculate_sisdr(ref=source, est=sep_segment) 94 | 95 | 96 | sisdrs_list.append(sisdr) 97 | sdris_list.append(sdri) 98 | 99 | mean_sisdr = np.mean(sisdrs_list) 100 | mean_sdri = np.mean(sdris_list) 101 | 102 | return mean_sisdr, mean_sdri -------------------------------------------------------------------------------- /evaluation/evaluate_esc50.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from typing import Dict, List 5 | 6 | import csv 7 | import pandas as pd 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | import pathlib 12 | import librosa 13 | import lightning.pytorch as pl 14 | from models.clap_encoder import CLAP_Encoder 15 | 16 | sys.path.append('../AudioSep/') 17 | from utils import ( 18 | load_ss_model, 19 | calculate_sdr, 20 | calculate_sisdr, 21 | parse_yaml, 22 | get_mean_sdr_from_dict, 23 | ) 24 | 25 | 26 | class ESC50Evaluator: 27 | def __init__( 28 | self, 29 | sampling_rate=32000 30 | ) -> None: 31 | r"""ESC-50 evaluator. 
32 | 33 | Returns: 34 | None 35 | """ 36 | 37 | self.sampling_rate = sampling_rate 38 | 39 | with open('evaluation/metadata/esc50_eval.csv') as csv_file: 40 | csv_reader = csv.reader(csv_file, delimiter=',') 41 | eval_list = [row for row in csv_reader][1:] 42 | 43 | self.eval_list = eval_list 44 | self.audio_dir = 'evaluation/data/esc50' 45 | 46 | def __call__( 47 | self, 48 | pl_model: pl.LightningModule 49 | ) -> Dict: 50 | r"""Evalute.""" 51 | 52 | print(f'Evaluation on ESC-50 with [text label] queries.') 53 | 54 | pl_model.eval() 55 | device = pl_model.device 56 | 57 | sisdrs_list = [] 58 | sdris_list = [] 59 | 60 | with torch.no_grad(): 61 | for eval_data in tqdm(self.eval_list): 62 | 63 | idx, caption, _, _, = eval_data 64 | 65 | source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav') 66 | mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav') 67 | 68 | source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True) 69 | mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True) 70 | 71 | sdr_no_sep = calculate_sdr(ref=source, est=mixture) 72 | 73 | text = [caption] 74 | 75 | conditions = pl_model.query_encoder.get_query_embed( 76 | modality='text', 77 | text=text, 78 | device=device 79 | ) 80 | 81 | input_dict = { 82 | "mixture": torch.Tensor(mixture)[None, None, :].to(device), 83 | "condition": conditions, 84 | } 85 | 86 | sep_segment = pl_model.ss_model(input_dict)["waveform"] 87 | # sep_segment: (batch_size=1, channels_num=1, segment_samples) 88 | 89 | sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy() 90 | # sep_segment: (segment_samples,) 91 | 92 | sdr = calculate_sdr(ref=source, est=sep_segment) 93 | sdri = sdr - sdr_no_sep 94 | sisdr = calculate_sisdr(ref=source, est=sep_segment) 95 | 96 | sisdrs_list.append(sisdr) 97 | sdris_list.append(sdri) 98 | 99 | mean_sdri = np.mean(sdris_list) 100 | mean_sisdr = np.mean(sisdrs_list) 101 | 102 | return mean_sisdr, mean_sdri 103 | -------------------------------------------------------------------------------- /evaluation/evaluate_music.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from typing import Dict, List 5 | 6 | import csv 7 | import pandas as pd 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | import pathlib 12 | import librosa 13 | import lightning.pytorch as pl 14 | from models.clap_encoder import CLAP_Encoder 15 | 16 | sys.path.append('../AudioSep/') 17 | from utils import ( 18 | load_ss_model, 19 | calculate_sdr, 20 | calculate_sisdr, 21 | parse_yaml, 22 | get_mean_sdr_from_dict, 23 | ) 24 | 25 | 26 | class MUSICEvaluator: 27 | def __init__( 28 | self, 29 | sampling_rate=32000 30 | ) -> None: 31 | 32 | self.sampling_rate = sampling_rate 33 | 34 | with open('evaluation/metadata/music_eval.csv') as csv_file: 35 | csv_reader = csv.reader(csv_file, delimiter=',') 36 | eval_list = [row for row in csv_reader][1:] 37 | 38 | self.eval_list = eval_list 39 | self.audio_dir = 'evaluation/data/music' 40 | 41 | self.source_types = [ 42 | "acoustic guitar", 43 | "violin", 44 | "accordion", 45 | "xylophone", 46 | "erhu", 47 | "trumpet", 48 | "tuba", 49 | "cello", 50 | "flute", 51 | "saxophone"] 52 | 53 | def __call__( 54 | self, 55 | pl_model: pl.LightningModule 56 | ) -> Dict: 57 | r"""Evalute.""" 58 | 59 | print(f'Evaluation on MUSIC Test with [text label] queries.') 60 | 61 | pl_model.eval() 62 | device = pl_model.device 63 | 64 | sisdrs_list = {source_type: [] for 
source_type in self.source_types} 65 | sdris_list = {source_type: [] for source_type in self.source_types} 66 | 67 | with torch.no_grad(): 68 | for eval_data in tqdm(self.eval_list): 69 | 70 | idx, caption, _, _, = eval_data 71 | 72 | source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav') 73 | mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav') 74 | 75 | source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True) 76 | mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True) 77 | 78 | sdr_no_sep = calculate_sdr(ref=source, est=mixture) 79 | 80 | text = [caption] 81 | 82 | conditions = pl_model.query_encoder.get_query_embed( 83 | modality='text', 84 | text=text, 85 | device=device 86 | ) 87 | 88 | input_dict = { 89 | "mixture": torch.Tensor(mixture)[None, None, :].to(device), 90 | "condition": conditions, 91 | } 92 | 93 | sep_segment = pl_model.ss_model(input_dict)["waveform"] 94 | # sep_segment: (batch_size=1, channels_num=1, segment_samples) 95 | 96 | sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy() 97 | # sep_segment: (segment_samples,) 98 | 99 | sdr = calculate_sdr(ref=source, est=sep_segment) 100 | sdri = sdr - sdr_no_sep 101 | sisdr = calculate_sisdr(ref=source, est=sep_segment) 102 | 103 | sisdrs_list[caption].append(sisdr) 104 | sdris_list[caption].append(sdri) 105 | 106 | mean_sisdr_list = [] 107 | mean_sdri_list = [] 108 | 109 | for source_class in self.source_types: 110 | sisdr = np.mean(sisdrs_list[source_class]) 111 | sdri = np.mean(sdris_list[source_class]) 112 | mean_sisdr_list.append(sisdr) 113 | mean_sdri_list.append(sdri) 114 | 115 | mean_sdri = np.mean(mean_sdri_list) 116 | mean_sisdr = np.mean(mean_sisdr_list) 117 | 118 | return mean_sisdr, mean_sdri -------------------------------------------------------------------------------- /evaluation/evaluate_vggsound.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from typing import Dict, List 5 | 6 | import csv 7 | import pandas as pd 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | import pathlib 12 | import librosa 13 | import lightning.pytorch as pl 14 | from models.clap_encoder import CLAP_Encoder 15 | 16 | sys.path.append('../AudioSep/') 17 | from utils import ( 18 | load_ss_model, 19 | calculate_sdr, 20 | calculate_sisdr, 21 | parse_yaml, 22 | get_mean_sdr_from_dict, 23 | ) 24 | 25 | 26 | class VGGSoundEvaluator: 27 | def __init__( 28 | self, 29 | sampling_rate=32000 30 | ) -> None: 31 | r"""VGGSound evaluator. 
32 | 33 | Args: 34 | data_recipe (str): dataset split, 'yan' 35 | Returns: 36 | None 37 | """ 38 | 39 | self.sampling_rate = sampling_rate 40 | 41 | with open('evaluation/metadata/vggsound_eval.csv') as csv_file: 42 | csv_reader = csv.reader(csv_file, delimiter=',') 43 | eval_list = [row for row in csv_reader][1:] 44 | 45 | self.eval_list = eval_list 46 | self.audio_dir = 'evaluation/data/vggsound' 47 | 48 | def __call__( 49 | self, 50 | pl_model: pl.LightningModule 51 | ) -> Dict: 52 | r"""Evalute.""" 53 | 54 | print(f'Evaluation on VGGSound+ with [text label] queries.') 55 | 56 | pl_model.eval() 57 | device = pl_model.device 58 | 59 | sisdrs_list = [] 60 | sdris_list = [] 61 | sisdris_list = [] 62 | 63 | 64 | with torch.no_grad(): 65 | for eval_data in tqdm(self.eval_list): 66 | 67 | # labels, source_path, mixture_path = eval_data 68 | file_id, mix_wav, s0_wav, s0_text, s1_wav, s1_text = eval_data 69 | 70 | labels = s0_text 71 | 72 | mixture_path = os.path.join(self.audio_dir, mix_wav) 73 | source_path = os.path.join(self.audio_dir, s0_wav) 74 | 75 | 76 | source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True) 77 | mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True) 78 | 79 | sdr_no_sep = calculate_sdr(ref=source, est=mixture) 80 | 81 | text = [labels] 82 | conditions = pl_model.query_encoder.get_query_embed( 83 | modality='text', 84 | text=text, 85 | device=device 86 | ) 87 | 88 | input_dict = { 89 | "mixture": torch.Tensor(mixture)[None, None, :].to(device), 90 | "condition": conditions, 91 | } 92 | 93 | sep_segment = pl_model.ss_model(input_dict)["waveform"] 94 | # sep_segment: (batch_size=1, channels_num=1, segment_samples) 95 | 96 | sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy() 97 | # sep_segment: (segment_samples,) 98 | 99 | sdr = calculate_sdr(ref=source, est=sep_segment) 100 | sdri = sdr - sdr_no_sep 101 | 102 | sisdr_no_sep = calculate_sisdr(ref=source, est=mixture) 103 | sisdr = calculate_sisdr(ref=source, est=sep_segment) 104 | sisdri = sisdr - sisdr_no_sep 105 | 106 | sisdrs_list.append(sisdr) 107 | sdris_list.append(sdri) 108 | sisdris_list.append(sisdri) 109 | 110 | 111 | mean_sisdr = np.mean(sisdrs_list) 112 | mean_sdri = np.mean(sdris_list) 113 | 114 | return mean_sisdr, mean_sdri -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def l1(output, target): 5 | return torch.mean(torch.abs(output - target)) 6 | 7 | 8 | def l1_wav(output_dict, target_dict): 9 | return l1(output_dict['segment'], target_dict['segment']) 10 | 11 | 12 | def get_loss_function(loss_type): 13 | if loss_type == "l1_wav": 14 | return l1_wav 15 | 16 | else: 17 | raise NotImplementedError("Error!") 18 | -------------------------------------------------------------------------------- /models/CLAP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-AGI/AudioSep/944583f18b84589dc965de3ad77525c945334252/models/CLAP/__init__.py -------------------------------------------------------------------------------- /models/CLAP/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import ( 2 | list_models, 3 | create_model, 4 | create_model_and_transforms, 5 | add_model_config, 6 | ) 7 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, 
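`losses.py` exposes a single registered loss, `l1_wav`, selected by name through `get_loss_function`. A short usage sketch follows, assuming the repository root is on `PYTHONPATH`; the tensor shapes are placeholders, and both dictionaries must carry the waveform under the `'segment'` key, as in the definitions above.

```python
import torch
from losses import get_loss_function

loss_fn = get_loss_function("l1_wav")                # any other name raises NotImplementedError

output_dict = {"segment": torch.randn(4, 1, 32000)}  # (batch, channels, samples)
target_dict = {"segment": torch.randn(4, 1, 32000)}

loss = loss_fn(output_dict, target_dict)             # mean absolute error over all elements
print(loss.item())
```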
LPMetrics 8 | from .model import ( 9 | CLAP, 10 | CLAPTextCfg, 11 | CLAPVisionCfg, 12 | CLAPAudioCfp, 13 | convert_weights_to_fp16, 14 | trace_model, 15 | ) 16 | from .openai import load_openai_model, list_openai_models 17 | from .pretrained import ( 18 | list_pretrained, 19 | list_pretrained_tag_models, 20 | list_pretrained_model_tags, 21 | get_pretrained_url, 22 | download_pretrained, 23 | ) 24 | from .tokenizer import SimpleTokenizer, tokenize 25 | from .transform import image_transform 26 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, BertModel 2 | 3 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 4 | model = BertModel.from_pretrained("bert-base-uncased") 5 | text = "Replace me by any text you'd like." 6 | 7 | 8 | def bert_embeddings(text): 9 | # text = "Replace me by any text you'd like." 10 | encoded_input = tokenizer(text, return_tensors="pt") 11 | output = model(**encoded_input) 12 | return output 13 | 14 | 15 | from transformers import RobertaTokenizer, RobertaModel 16 | 17 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 18 | model = RobertaModel.from_pretrained("roberta-base") 19 | text = "Replace me by any text you'd like." 20 | 21 | 22 | def Roberta_embeddings(text): 23 | # text = "Replace me by any text you'd like." 24 | encoded_input = tokenizer(text, return_tensors="pt") 25 | output = model(**encoded_input) 26 | return output 27 | 28 | 29 | from transformers import BartTokenizer, BartModel 30 | 31 | tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 32 | model = BartModel.from_pretrained("facebook/bart-base") 33 | text = "Replace me by any text you'd like." 34 | 35 | 36 | def bart_embeddings(text): 37 | # text = "Replace me by any text you'd like." 
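The helpers in `bert.py` wrap three Hugging Face encoders; note that all three models are instantiated at import time, so the first import downloads BERT, RoBERTa and BART. A quick sketch of calling them, assuming the repository root is importable:

```python
from models.CLAP.open_clip.bert import bert_embeddings, Roberta_embeddings

out = bert_embeddings("a dog barking in the distance")
print(out.last_hidden_state.shape)   # (1, num_tokens, 768) for bert-base-uncased
print(out.pooler_output.shape)       # (1, 768)

out = Roberta_embeddings("a dog barking in the distance")
print(out.last_hidden_state.shape)
```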
38 | encoded_input = tokenizer(text, return_tensors="pt") 39 | output = model(**encoded_input) 40 | return output 41 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-AGI/AudioSep/944583f18b84589dc965de3ad77525c945334252/models/CLAP/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /models/CLAP/open_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | from .model import CLAP, convert_weights_to_fp16 12 | from .openai import load_openai_model 13 | from .pretrained import get_pretrained_url, download_pretrained 14 | from .transform import image_transform 15 | 16 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 17 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 18 | 19 | 20 | def _natural_key(string_): 21 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 22 | 23 | 24 | def _rescan_model_configs(): 25 | global _MODEL_CONFIGS 26 | 27 | config_ext = (".json",) 28 | config_files = [] 29 | for config_path in _MODEL_CONFIG_PATHS: 30 | if config_path.is_file() and config_path.suffix in config_ext: 31 | config_files.append(config_path) 32 | elif config_path.is_dir(): 33 | for ext in config_ext: 34 | config_files.extend(config_path.glob(f"*{ext}")) 35 | 36 | for cf in config_files: 37 | if os.path.basename(cf)[0] == ".": 38 | continue # Ignore hidden files 39 | 40 | with open(cf, "r") as f: 41 | model_cfg = json.load(f) 42 | if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): 43 | _MODEL_CONFIGS[cf.stem] = model_cfg 44 | 45 | _MODEL_CONFIGS = { 46 | k: v 47 | for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) 48 | } 49 | 50 | 51 | _rescan_model_configs() # initial populate of model config registry 52 | 53 | 54 | def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): 55 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 56 | if isinstance(checkpoint, dict) and "state_dict" in checkpoint: 57 | state_dict = checkpoint["state_dict"] 58 | else: 59 | state_dict = checkpoint 60 | if skip_params: 61 | if next(iter(state_dict.items()))[0].startswith("module"): 62 | state_dict = {k[7:]: v for k, v in state_dict.items()} 63 | # for k in state_dict: 64 | # if k.startswith('transformer'): 65 | # v = state_dict.pop(k) 66 | # state_dict['text_branch.' 
+ k[12:]] = v 67 | return state_dict 68 | 69 | 70 | def create_model( 71 | amodel_name: str, 72 | tmodel_name: str, 73 | pretrained: str = "", 74 | precision: str = "fp32", 75 | device: torch.device = torch.device("cpu"), 76 | jit: bool = False, 77 | force_quick_gelu: bool = False, 78 | openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), 79 | skip_params=True, 80 | pretrained_audio: str = "", 81 | pretrained_text: str = "", 82 | enable_fusion: bool = False, 83 | fusion_type: str = "None" 84 | # pretrained_image: bool = False, 85 | ): 86 | amodel_name = amodel_name.replace( 87 | "/", "-" 88 | ) # for callers using old naming with / in ViT names 89 | pretrained_orig = pretrained 90 | pretrained = pretrained.lower() 91 | if pretrained == "openai": 92 | if amodel_name in _MODEL_CONFIGS: 93 | logging.info(f"Loading {amodel_name} model config.") 94 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 95 | else: 96 | logging.error( 97 | f"Model config for {amodel_name} not found; available models {list_models()}." 98 | ) 99 | raise RuntimeError(f"Model config for {amodel_name} not found.") 100 | 101 | logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") 102 | # Hard Code in model name 103 | model_cfg["text_cfg"]["model_type"] = tmodel_name 104 | model = load_openai_model( 105 | "ViT-B-16", 106 | model_cfg, 107 | device=device, 108 | jit=jit, 109 | cache_dir=openai_model_cache_dir, 110 | enable_fusion=enable_fusion, 111 | fusion_type=fusion_type, 112 | ) 113 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 114 | if precision == "amp" or precision == "fp32": 115 | model = model.float() 116 | else: 117 | if amodel_name in _MODEL_CONFIGS: 118 | logging.info(f"Loading {amodel_name} model config.") 119 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 120 | else: 121 | logging.error( 122 | f"Model config for {amodel_name} not found; available models {list_models()}." 123 | ) 124 | raise RuntimeError(f"Model config for {amodel_name} not found.") 125 | 126 | if force_quick_gelu: 127 | # override for use of QuickGELU on non-OpenAI transformer models 128 | model_cfg["quick_gelu"] = True 129 | 130 | # if pretrained_image: 131 | # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): 132 | # # pretrained weight loading for timm models set via vision_cfg 133 | # model_cfg['vision_cfg']['timm_model_pretrained'] = True 134 | # else: 135 | # assert False, 'pretrained image towers currently only supported for timm models' 136 | model_cfg["text_cfg"]["model_type"] = tmodel_name 137 | model_cfg["enable_fusion"] = enable_fusion 138 | model_cfg["fusion_type"] = fusion_type 139 | model = CLAP(**model_cfg) 140 | 141 | if pretrained: 142 | checkpoint_path = "" 143 | url = get_pretrained_url(amodel_name, pretrained) 144 | if url: 145 | checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) 146 | elif os.path.exists(pretrained_orig): 147 | checkpoint_path = pretrained_orig 148 | if checkpoint_path: 149 | logging.info( 150 | f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." 151 | ) 152 | ckpt = load_state_dict(checkpoint_path, skip_params=True) 153 | model.load_state_dict(ckpt) 154 | param_names = [n for n, p in model.named_parameters()] 155 | # for n in param_names: 156 | # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") 157 | else: 158 | logging.warning( 159 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 
160 | ) 161 | raise RuntimeError( 162 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 163 | ) 164 | 165 | if pretrained_audio: 166 | if amodel_name.startswith("PANN"): 167 | if "Cnn14_mAP" in pretrained_audio: # official checkpoint 168 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 169 | audio_ckpt = audio_ckpt["model"] 170 | keys = list(audio_ckpt.keys()) 171 | for key in keys: 172 | if ( 173 | "spectrogram_extractor" not in key 174 | and "logmel_extractor" not in key 175 | ): 176 | v = audio_ckpt.pop(key) 177 | audio_ckpt["audio_branch." + key] = v 178 | elif os.path.basename(pretrained_audio).startswith( 179 | "PANN" 180 | ): # checkpoint trained via HTSAT codebase 181 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 182 | audio_ckpt = audio_ckpt["state_dict"] 183 | keys = list(audio_ckpt.keys()) 184 | for key in keys: 185 | if key.startswith("sed_model"): 186 | v = audio_ckpt.pop(key) 187 | audio_ckpt["audio_branch." + key[10:]] = v 188 | elif os.path.basename(pretrained_audio).startswith( 189 | "finetuned" 190 | ): # checkpoint trained via linear probe codebase 191 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 192 | else: 193 | raise ValueError("Unknown audio checkpoint") 194 | elif amodel_name.startswith("HTSAT"): 195 | if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint 196 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 197 | audio_ckpt = audio_ckpt["state_dict"] 198 | keys = list(audio_ckpt.keys()) 199 | for key in keys: 200 | if key.startswith("sed_model") and ( 201 | "spectrogram_extractor" not in key 202 | and "logmel_extractor" not in key 203 | ): 204 | v = audio_ckpt.pop(key) 205 | audio_ckpt["audio_branch." + key[10:]] = v 206 | elif os.path.basename(pretrained_audio).startswith( 207 | "HTSAT" 208 | ): # checkpoint trained via HTSAT codebase 209 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 210 | audio_ckpt = audio_ckpt["state_dict"] 211 | keys = list(audio_ckpt.keys()) 212 | for key in keys: 213 | if key.startswith("sed_model"): 214 | v = audio_ckpt.pop(key) 215 | audio_ckpt["audio_branch." + key[10:]] = v 216 | elif os.path.basename(pretrained_audio).startswith( 217 | "finetuned" 218 | ): # checkpoint trained via linear probe codebase 219 | audio_ckpt = torch.load(pretrained_audio, map_location="cpu") 220 | else: 221 | raise ValueError("Unknown audio checkpoint") 222 | else: 223 | raise f"this audio encoder pretrained checkpoint is not support" 224 | 225 | model.load_state_dict(audio_ckpt, strict=False) 226 | logging.info( 227 | f"Loading pretrained {amodel_name} weights ({pretrained_audio})." 
228 | ) 229 | param_names = [n for n, p in model.named_parameters()] 230 | for n in param_names: 231 | print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") 232 | 233 | model.to(device=device) 234 | if precision == "fp16": 235 | assert device.type != "cpu" 236 | convert_weights_to_fp16(model) 237 | 238 | if jit: 239 | model = torch.jit.script(model) 240 | 241 | return model, model_cfg 242 | 243 | 244 | def create_model_and_transforms( 245 | model_name: str, 246 | pretrained: str = "", 247 | precision: str = "fp32", 248 | device: torch.device = torch.device("cpu"), 249 | jit: bool = False, 250 | force_quick_gelu: bool = False, 251 | # pretrained_image: bool = False, 252 | ): 253 | model = create_model( 254 | model_name, 255 | pretrained, 256 | precision, 257 | device, 258 | jit, 259 | force_quick_gelu=force_quick_gelu, 260 | # pretrained_image=pretrained_image 261 | ) 262 | preprocess_train = image_transform(model.visual.image_size, is_train=True) 263 | preprocess_val = image_transform(model.visual.image_size, is_train=False) 264 | return model, preprocess_train, preprocess_val 265 | 266 | 267 | def list_models(): 268 | """enumerate available model architectures based on config files""" 269 | return list(_MODEL_CONFIGS.keys()) 270 | 271 | 272 | def add_model_config(path): 273 | """add model config path or file and update registry""" 274 | if not isinstance(path, Path): 275 | path = Path(path) 276 | _MODEL_CONFIG_PATHS.append(path) 277 | _rescan_model_configs() 278 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/feature_fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Feature Fusion for Variable-Length Data Processing 3 | AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py 4 | According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class DAF(nn.Module): 12 | """ 13 | 直接相加 DirectAddFuse 14 | """ 15 | 16 | def __init__(self): 17 | super(DAF, self).__init__() 18 | 19 | def forward(self, x, residual): 20 | return x + residual 21 | 22 | 23 | class iAFF(nn.Module): 24 | """ 25 | 多特征融合 iAFF 26 | """ 27 | 28 | def __init__(self, channels=64, r=4, type="2D"): 29 | super(iAFF, self).__init__() 30 | inter_channels = int(channels // r) 31 | 32 | if type == "1D": 33 | # 本地注意力 34 | self.local_att = nn.Sequential( 35 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 36 | nn.BatchNorm1d(inter_channels), 37 | nn.ReLU(inplace=True), 38 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 39 | nn.BatchNorm1d(channels), 40 | ) 41 | 42 | # 全局注意力 43 | self.global_att = nn.Sequential( 44 | nn.AdaptiveAvgPool1d(1), 45 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 46 | nn.BatchNorm1d(inter_channels), 47 | nn.ReLU(inplace=True), 48 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 49 | nn.BatchNorm1d(channels), 50 | ) 51 | 52 | # 第二次本地注意力 53 | self.local_att2 = nn.Sequential( 54 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 55 | nn.BatchNorm1d(inter_channels), 56 | nn.ReLU(inplace=True), 57 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 58 | nn.BatchNorm1d(channels), 59 | ) 60 | # 第二次全局注意力 61 | self.global_att2 = nn.Sequential( 62 
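`create_model` returns a `(model, model_cfg)` pair; the audio model name must match one of the JSON files in `model_configs/`. The sketch below builds a randomly initialised CLAP on CPU. Passing `"roberta"` as the text tower is an assumption about a supported `tmodel_name` value, not something defined in this file, and with `pretrained=""` no weights are loaded.

```python
import torch
from models.CLAP.open_clip.factory import create_model, list_models

print(list_models())                    # config names discovered under model_configs/

model, model_cfg = create_model(
    amodel_name="HTSAT-tiny",
    tmodel_name="roberta",              # assumed-valid text branch type
    pretrained="",                      # empty string: random initialisation
    precision="fp32",
    device=torch.device("cpu"),
    enable_fusion=False,
)
print(model_cfg["embed_dim"])           # 768 for the HTSAT-tiny config shown further below
```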
| nn.AdaptiveAvgPool1d(1), 63 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 64 | nn.BatchNorm1d(inter_channels), 65 | nn.ReLU(inplace=True), 66 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 67 | nn.BatchNorm1d(channels), 68 | ) 69 | elif type == "2D": 70 | # 本地注意力 71 | self.local_att = nn.Sequential( 72 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 73 | nn.BatchNorm2d(inter_channels), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 76 | nn.BatchNorm2d(channels), 77 | ) 78 | 79 | # 全局注意力 80 | self.global_att = nn.Sequential( 81 | nn.AdaptiveAvgPool2d(1), 82 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 83 | nn.BatchNorm2d(inter_channels), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 86 | nn.BatchNorm2d(channels), 87 | ) 88 | 89 | # 第二次本地注意力 90 | self.local_att2 = nn.Sequential( 91 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 92 | nn.BatchNorm2d(inter_channels), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 95 | nn.BatchNorm2d(channels), 96 | ) 97 | # 第二次全局注意力 98 | self.global_att2 = nn.Sequential( 99 | nn.AdaptiveAvgPool2d(1), 100 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 101 | nn.BatchNorm2d(inter_channels), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 104 | nn.BatchNorm2d(channels), 105 | ) 106 | else: 107 | raise f"the type is not supported" 108 | 109 | self.sigmoid = nn.Sigmoid() 110 | 111 | def forward(self, x, residual): 112 | flag = False 113 | xa = x + residual 114 | if xa.size(0) == 1: 115 | xa = torch.cat([xa, xa], dim=0) 116 | flag = True 117 | xl = self.local_att(xa) 118 | xg = self.global_att(xa) 119 | xlg = xl + xg 120 | wei = self.sigmoid(xlg) 121 | xi = x * wei + residual * (1 - wei) 122 | 123 | xl2 = self.local_att2(xi) 124 | xg2 = self.global_att(xi) 125 | xlg2 = xl2 + xg2 126 | wei2 = self.sigmoid(xlg2) 127 | xo = x * wei2 + residual * (1 - wei2) 128 | if flag: 129 | xo = xo[0].unsqueeze(0) 130 | return xo 131 | 132 | 133 | class AFF(nn.Module): 134 | """ 135 | 多特征融合 AFF 136 | """ 137 | 138 | def __init__(self, channels=64, r=4, type="2D"): 139 | super(AFF, self).__init__() 140 | inter_channels = int(channels // r) 141 | 142 | if type == "1D": 143 | self.local_att = nn.Sequential( 144 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 145 | nn.BatchNorm1d(inter_channels), 146 | nn.ReLU(inplace=True), 147 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 148 | nn.BatchNorm1d(channels), 149 | ) 150 | self.global_att = nn.Sequential( 151 | nn.AdaptiveAvgPool1d(1), 152 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 153 | nn.BatchNorm1d(inter_channels), 154 | nn.ReLU(inplace=True), 155 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 156 | nn.BatchNorm1d(channels), 157 | ) 158 | elif type == "2D": 159 | self.local_att = nn.Sequential( 160 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 161 | nn.BatchNorm2d(inter_channels), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 164 | nn.BatchNorm2d(channels), 165 | ) 166 | self.global_att = nn.Sequential( 167 | nn.AdaptiveAvgPool2d(1), 
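Two details of `iAFF.forward` above are worth flagging: a batch-of-one input is temporarily duplicated so the BatchNorm layers inside the attention branches see a batch of two (presumably to avoid single-sample batch statistics), and the second attention stage calls `self.global_att` rather than `self.global_att2`, leaving `global_att2` defined but unused. A small sketch exercising the batch-of-one path:

```python
import torch
from models.CLAP.open_clip.feature_fusion import iAFF

x = torch.randn(1, 64, 8, 8)          # (batch, channels, H, W) for the default type="2D"
residual = torch.randn(1, 64, 8, 8)

fusion = iAFF(channels=64, r=4, type="2D")
fused = fusion(x, residual)           # internally duplicated to batch 2, then sliced back
print(fused.shape)                    # torch.Size([1, 64, 8, 8])
```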
168 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 169 | nn.BatchNorm2d(inter_channels), 170 | nn.ReLU(inplace=True), 171 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 172 | nn.BatchNorm2d(channels), 173 | ) 174 | else: 175 | raise f"the type is not supported." 176 | 177 | self.sigmoid = nn.Sigmoid() 178 | 179 | def forward(self, x, residual): 180 | flag = False 181 | xa = x + residual 182 | if xa.size(0) == 1: 183 | xa = torch.cat([xa, xa], dim=0) 184 | flag = True 185 | xl = self.local_att(xa) 186 | xg = self.global_att(xa) 187 | xlg = xl + xg 188 | wei = self.sigmoid(xlg) 189 | xo = 2 * x * wei + 2 * residual * (1 - wei) 190 | if flag: 191 | xo = xo[0].unsqueeze(0) 192 | return xo 193 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/linear_probe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from .model import MLPLayers 5 | 6 | 7 | class LinearProbe(nn.Module): 8 | def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): 9 | """ 10 | Args: 11 | model: nn.Module 12 | mlp: bool, if True, then use the MLP layer as the linear probe module 13 | freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe 14 | in_ch: int, the output channel from CLAP model 15 | out_ch: int, the output channel from linear probe (class_num) 16 | act: torch.nn.functional, the activation function before the loss function 17 | """ 18 | super().__init__() 19 | in_ch = 512 20 | self.clap_model = model 21 | self.clap_model.text_branch = None # to save memory 22 | self.freeze = freeze 23 | if mlp: 24 | self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) 25 | else: 26 | self.lp_layer = nn.Linear(in_ch, out_ch) 27 | 28 | if self.freeze: 29 | for param in self.clap_model.parameters(): 30 | param.requires_grad = False 31 | 32 | if act == "None": 33 | self.act = None 34 | elif act == "relu": 35 | self.act = nn.ReLU() 36 | elif act == "elu": 37 | self.act = nn.ELU() 38 | elif act == "prelu": 39 | self.act = nn.PReLU(num_parameters=in_ch) 40 | elif act == "softmax": 41 | self.act = nn.Softmax(dim=-1) 42 | elif act == "sigmoid": 43 | self.act = nn.Sigmoid() 44 | 45 | def forward(self, x, mix_lambda=None, device=None): 46 | """ 47 | Args: 48 | x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list 49 | mix_lambda: torch.tensor [batch], the mixup lambda 50 | Returns: 51 | class_prob: torch.tensor [batch, class_num] 52 | 53 | """ 54 | # batchnorm cancel grandient 55 | if self.freeze: 56 | self.clap_model.eval() 57 | 58 | x = self.clap_model.audio_projection( 59 | self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[ 60 | "embedding" 61 | ] 62 | ) 63 | out = self.lp_layer(x) 64 | if self.act is not None: 65 | out = self.act(out) 66 | return out 67 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/HTSAT-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "base" 15 | }, 16 | "text_cfg": { 17 | 
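`LinearProbe` bolts a trainable head onto a frozen CLAP audio tower (the text branch is dropped to save memory). Below is a hedged construction sketch, assuming `clap_model` was produced by `factory.create_model` as in the earlier sketch; the waveform/dict format expected by `clap_model.audio_branch` depends on the chosen audio encoder and is not shown in this file.

```python
from models.CLAP.open_clip.linear_probe import LinearProbe

def build_probe(clap_model, num_classes: int = 527) -> LinearProbe:
    """527-way multi-label probe (e.g. AudioSet tagging) on top of a frozen CLAP."""
    return LinearProbe(
        model=clap_model,
        mlp=True,         # MLP head instead of a single nn.Linear
        freeze=True,      # backbone parameters get requires_grad=False
        in_ch=512,        # note: __init__ hard-codes in_ch to 512 regardless of this value
        out_ch=num_classes,
        act="sigmoid",    # per-class probabilities for multi-label targets
    )
```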
"context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/HTSAT-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "large" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/HTSAT-tiny-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/HTSAT-tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/PANN-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn10" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/PANN-14-fmax-18k.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 18000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } 
-------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/PANN-14-fmax-8k-20s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 960000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 360, 10 | "fmin": 50, 11 | "fmax": 8000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/PANN-14-tiny-transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 4 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/PANN-14-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/PANN-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/PANN-6.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn6" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/RN101-quickgelu.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 
12 15 | } 16 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /models/CLAP/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import ( 14 | get_pretrained_url, 15 | list_pretrained_tag_models, 16 | download_pretrained, 17 | ) 18 | 19 | __all__ = ["list_openai_models", "load_openai_model"] 20 | 21 | 22 | def list_openai_models() -> List[str]: 23 | """Returns the names of available CLIP models""" 24 | return list_pretrained_tag_models("openai") 25 | 26 | 27 | def load_openai_model( 28 | name: str, 29 | model_cfg, 30 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 31 | jit=True, 32 | cache_dir=os.path.expanduser("~/.cache/clip"), 33 | enable_fusion: bool = False, 34 | fusion_type: str = "None", 35 | ): 36 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model 37 | 38 | Parameters 39 | ---------- 40 | name : str 41 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 42 | device : Union[str, torch.device] 43 | The device to put the loaded model 44 | jit : bool 45 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
46 | 47 | Returns 48 | ------- 49 | model : torch.nn.Module 50 | The CLAP model 51 | preprocess : Callable[[PIL.Image], torch.Tensor] 52 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 53 | """ 54 | if get_pretrained_url(name, "openai"): 55 | model_path = download_pretrained( 56 | get_pretrained_url(name, "openai"), root=cache_dir 57 | ) 58 | elif os.path.isfile(name): 59 | model_path = name 60 | else: 61 | raise RuntimeError( 62 | f"Model {name} not found; available models = {list_openai_models()}" 63 | ) 64 | 65 | try: 66 | # loading JIT archive 67 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 68 | state_dict = None 69 | except RuntimeError: 70 | # loading saved state dict 71 | if jit: 72 | warnings.warn( 73 | f"File {model_path} is not a JIT archive. Loading as a state dict instead" 74 | ) 75 | jit = False 76 | state_dict = torch.load(model_path, map_location="cpu") 77 | 78 | if not jit: 79 | try: 80 | model = build_model_from_openai_state_dict( 81 | state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type 82 | ).to(device) 83 | except KeyError: 84 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 85 | model = build_model_from_openai_state_dict( 86 | sd, model_cfg, enable_fusion, fusion_type 87 | ).to(device) 88 | 89 | if str(device) == "cpu": 90 | model.float() 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace( 95 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 96 | ) 97 | device_node = [ 98 | n 99 | for n in device_holder.graph.findAllNodes("prim::Constant") 100 | if "Device" in repr(n) 101 | ][-1] 102 | 103 | def patch_device(module): 104 | try: 105 | graphs = [module.graph] if hasattr(module, "graph") else [] 106 | except RuntimeError: 107 | graphs = [] 108 | 109 | if hasattr(module, "forward1"): 110 | graphs.append(module.forward1.graph) 111 | 112 | for graph in graphs: 113 | for node in graph.findAllNodes("prim::Constant"): 114 | if "value" in node.attributeNames() and str(node["value"]).startswith( 115 | "cuda" 116 | ): 117 | node.copyAttributes(device_node) 118 | 119 | model.apply(patch_device) 120 | patch_device(model.encode_audio) 121 | patch_device(model.encode_text) 122 | 123 | # patch dtype to float32 on CPU 124 | if str(device) == "cpu": 125 | float_holder = torch.jit.trace( 126 | lambda: torch.ones([]).float(), example_inputs=[] 127 | ) 128 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 129 | float_node = float_input.node() 130 | 131 | def patch_float(module): 132 | try: 133 | graphs = [module.graph] if hasattr(module, "graph") else [] 134 | except RuntimeError: 135 | graphs = [] 136 | 137 | if hasattr(module, "forward1"): 138 | graphs.append(module.forward1.graph) 139 | 140 | for graph in graphs: 141 | for node in graph.findAllNodes("aten::to"): 142 | inputs = list(node.inputs()) 143 | for i in [ 144 | 1, 145 | 2, 146 | ]: # dtype can be the second or third argument to aten::to() 147 | if inputs[i].node()["value"] == 5: 148 | inputs[i].node().copyAttributes(float_node) 149 | 150 | model.apply(patch_float) 151 | patch_float(model.encode_audio) 152 | patch_float(model.encode_text) 153 | model.float() 154 | 155 | model.audio_branch.audio_length = model.audio_cfg.audio_length 156 | return model 157 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/pretrained.py: 
-------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | _RN50 = dict( 9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", 12 | ) 13 | 14 | _RN50_quickgelu = dict( 15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", 18 | ) 19 | 20 | _RN101 = dict( 21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", 23 | ) 24 | 25 | _RN101_quickgelu = dict( 26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", 28 | ) 29 | 30 | _RN50x4 = dict( 31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | ) 33 | 34 | _RN50x16 = dict( 35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | ) 37 | 38 | _RN50x64 = dict( 39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 40 | ) 41 | 42 | _VITB32 = dict( 43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 44 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 45 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 46 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 47 | ) 48 | 49 | _VITB32_quickgelu = dict( 50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 53 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 54 | ) 55 | 56 | _VITB16 = dict( 57 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 58 | ) 59 | 60 | _VITL14 = dict( 61 | 
openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 62 | ) 63 | 64 | _PRETRAINED = { 65 | "RN50": _RN50, 66 | "RN50-quickgelu": _RN50_quickgelu, 67 | "RN101": _RN101, 68 | "RN101-quickgelu": _RN101_quickgelu, 69 | "RN50x4": _RN50x4, 70 | "RN50x16": _RN50x16, 71 | "ViT-B-32": _VITB32, 72 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 73 | "ViT-B-16": _VITB16, 74 | "ViT-L-14": _VITL14, 75 | } 76 | 77 | 78 | def list_pretrained(as_str: bool = False): 79 | """returns list of pretrained models 80 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 81 | """ 82 | return [ 83 | ":".join([k, t]) if as_str else (k, t) 84 | for k in _PRETRAINED.keys() 85 | for t in _PRETRAINED[k].keys() 86 | ] 87 | 88 | 89 | def list_pretrained_tag_models(tag: str): 90 | """return all models having the specified pretrain tag""" 91 | models = [] 92 | for k in _PRETRAINED.keys(): 93 | if tag in _PRETRAINED[k]: 94 | models.append(k) 95 | return models 96 | 97 | 98 | def list_pretrained_model_tags(model: str): 99 | """return all pretrain tags for the specified model architecture""" 100 | tags = [] 101 | if model in _PRETRAINED: 102 | tags.extend(_PRETRAINED[model].keys()) 103 | return tags 104 | 105 | 106 | def get_pretrained_url(model: str, tag: str): 107 | if model not in _PRETRAINED: 108 | return "" 109 | model_pretrained = _PRETRAINED[model] 110 | if tag not in model_pretrained: 111 | return "" 112 | return model_pretrained[tag] 113 | 114 | 115 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): 116 | os.makedirs(root, exist_ok=True) 117 | filename = os.path.basename(url) 118 | 119 | if "openaipublic" in url: 120 | expected_sha256 = url.split("/")[-2] 121 | else: 122 | expected_sha256 = "" 123 | 124 | download_target = os.path.join(root, filename) 125 | 126 | if os.path.exists(download_target) and not os.path.isfile(download_target): 127 | raise RuntimeError(f"{download_target} exists and is not a regular file") 128 | 129 | if os.path.isfile(download_target): 130 | if expected_sha256: 131 | if ( 132 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 133 | == expected_sha256 134 | ): 135 | return download_target 136 | else: 137 | warnings.warn( 138 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 139 | ) 140 | else: 141 | return download_target 142 | 143 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 144 | with tqdm( 145 | total=int(source.info().get("Content-Length")), 146 | ncols=80, 147 | unit="iB", 148 | unit_scale=True, 149 | ) as loop: 150 | while True: 151 | buffer = source.read(8192) 152 | if not buffer: 153 | break 154 | 155 | output.write(buffer) 156 | loop.update(len(buffer)) 157 | 158 | if ( 159 | expected_sha256 160 | and hashlib.sha256(open(download_target, "rb").read()).hexdigest() 161 | != expected_sha256 162 | ): 163 | raise RuntimeError( 164 | f"Model has been downloaded but the SHA256 checksum does not not match" 165 | ) 166 | 167 | return download_target 168 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import ( 14 | AttentionPool2d as AbsAttentionPool2d, 15 | ) 16 | except ImportError as e: 17 | timm = None 18 | 19 | from .utils import freeze_batch_norm_2d 20 | 21 | 22 | class TimmModel(nn.Module): 23 | """timm model adapter 24 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model_name, 30 | embed_dim, 31 | image_size=224, 32 | pool="avg", 33 | proj="linear", 34 | drop=0.0, 35 | pretrained=False, 36 | ): 37 | super().__init__() 38 | if timm is None: 39 | raise RuntimeError("Please `pip install timm` to use timm models.") 40 | 41 | self.image_size = to_2tuple(image_size) 42 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 43 | feat_size = self.trunk.default_cfg.get("pool_size", None) 44 | feature_ndim = 1 if not feat_size else 2 45 | if pool in ("abs_attn", "rot_attn"): 46 | assert feature_ndim == 2 47 | # if attn pooling used, remove both classifier and default pool 48 | self.trunk.reset_classifier(0, global_pool="") 49 | else: 50 | # reset global pool if pool config set, otherwise leave as network default 51 | reset_kwargs = dict(global_pool=pool) if pool else {} 52 | self.trunk.reset_classifier(0, **reset_kwargs) 53 | prev_chs = self.trunk.num_features 54 | 55 | head_layers = OrderedDict() 56 | if pool == "abs_attn": 57 | head_layers["pool"] = AbsAttentionPool2d( 58 | prev_chs, feat_size=feat_size, out_features=embed_dim 59 | ) 60 | prev_chs = embed_dim 61 | elif pool == "rot_attn": 62 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 63 | prev_chs = embed_dim 64 | else: 65 | assert proj, "projection layer needed if non-attention pooling is used." 
66 | 67 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 68 | if proj == "linear": 69 | head_layers["drop"] = nn.Dropout(drop) 70 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim) 71 | elif proj == "mlp": 72 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 73 | 74 | self.head = nn.Sequential(head_layers) 75 | 76 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 77 | """lock modules 78 | Args: 79 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 80 | """ 81 | if not unlocked_groups: 82 | # lock full model 83 | for param in self.trunk.parameters(): 84 | param.requires_grad = False 85 | if freeze_bn_stats: 86 | freeze_batch_norm_2d(self.trunk) 87 | else: 88 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 89 | try: 90 | # FIXME import here until API stable and in an official release 91 | from timm.models.helpers import group_parameters, group_modules 92 | except ImportError: 93 | raise RuntimeError( 94 | "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" 95 | ) 96 | matcher = self.trunk.group_matcher() 97 | gparams = group_parameters(self.trunk, matcher) 98 | max_layer_id = max(gparams.keys()) 99 | max_layer_id = max_layer_id - unlocked_groups 100 | for group_idx in range(max_layer_id + 1): 101 | group = gparams[group_idx] 102 | for param in group: 103 | self.trunk.get_parameter(param).requires_grad = False 104 | if freeze_bn_stats: 105 | gmodules = group_modules(self.trunk, matcher, reverse=True) 106 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 107 | freeze_batch_norm_2d(self.trunk, gmodules) 108 | 109 | def forward(self, x): 110 | x = self.trunk(x) 111 | x = self.head(x) 112 | return x 113 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join( 19 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 20 | ) 21 | 22 | 23 | @lru_cache() 24 | def bytes_to_unicode(): 25 | """ 26 | Returns list of utf-8 byte and a corresponding list of unicode strings. 27 | The reversible bpe codes work on unicode strings. 28 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 29 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 30 | This is a significant percentage of your normal, say, 32K bpe vocab. 31 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 32 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
33 | """ 34 | bs = ( 35 | list(range(ord("!"), ord("~") + 1)) 36 | + list(range(ord("¡"), ord("¬") + 1)) 37 | + list(range(ord("®"), ord("ÿ") + 1)) 38 | ) 39 | cs = bs[:] 40 | n = 0 41 | for b in range(2**8): 42 | if b not in bs: 43 | bs.append(b) 44 | cs.append(2**8 + n) 45 | n += 1 46 | cs = [chr(n) for n in cs] 47 | return dict(zip(bs, cs)) 48 | 49 | 50 | def get_pairs(word): 51 | """Return set of symbol pairs in a word. 52 | Word is represented as tuple of symbols (symbols being variable-length strings). 53 | """ 54 | pairs = set() 55 | prev_char = word[0] 56 | for char in word[1:]: 57 | pairs.add((prev_char, char)) 58 | prev_char = char 59 | return pairs 60 | 61 | 62 | def basic_clean(text): 63 | text = ftfy.fix_text(text) 64 | text = html.unescape(html.unescape(text)) 65 | return text.strip() 66 | 67 | 68 | def whitespace_clean(text): 69 | text = re.sub(r"\s+", " ", text) 70 | text = text.strip() 71 | return text 72 | 73 | 74 | class SimpleTokenizer(object): 75 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 76 | self.byte_encoder = bytes_to_unicode() 77 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 78 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 79 | merges = merges[1 : 49152 - 256 - 2 + 1] 80 | merges = [tuple(merge.split()) for merge in merges] 81 | vocab = list(bytes_to_unicode().values()) 82 | vocab = vocab + [v + "" for v in vocab] 83 | for merge in merges: 84 | vocab.append("".join(merge)) 85 | if not special_tokens: 86 | special_tokens = ["", ""] 87 | else: 88 | special_tokens = ["", ""] + special_tokens 89 | vocab.extend(special_tokens) 90 | self.encoder = dict(zip(vocab, range(len(vocab)))) 91 | self.decoder = {v: k for k, v in self.encoder.items()} 92 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 93 | self.cache = {t: t for t in special_tokens} 94 | special = "|".join(special_tokens) 95 | self.pat = re.compile( 96 | special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 97 | re.IGNORECASE, 98 | ) 99 | 100 | self.vocab_size = len(self.encoder) 101 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 102 | 103 | def bpe(self, token): 104 | if token in self.cache: 105 | return self.cache[token] 106 | word = tuple(token[:-1]) + (token[-1] + "",) 107 | pairs = get_pairs(word) 108 | 109 | if not pairs: 110 | return token + "" 111 | 112 | while True: 113 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 114 | if bigram not in self.bpe_ranks: 115 | break 116 | first, second = bigram 117 | new_word = [] 118 | i = 0 119 | while i < len(word): 120 | try: 121 | j = word.index(first, i) 122 | new_word.extend(word[i:j]) 123 | i = j 124 | except: 125 | new_word.extend(word[i:]) 126 | break 127 | 128 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 129 | new_word.append(first + second) 130 | i += 2 131 | else: 132 | new_word.append(word[i]) 133 | i += 1 134 | new_word = tuple(new_word) 135 | word = new_word 136 | if len(word) == 1: 137 | break 138 | else: 139 | pairs = get_pairs(word) 140 | word = " ".join(word) 141 | self.cache[token] = word 142 | return word 143 | 144 | def encode(self, text): 145 | bpe_tokens = [] 146 | text = whitespace_clean(basic_clean(text)).lower() 147 | for token in re.findall(self.pat, text): 148 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 149 | bpe_tokens.extend( 150 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 151 | ) 152 | return bpe_tokens 153 
| 154 | def decode(self, tokens): 155 | text = "".join([self.decoder[token] for token in tokens]) 156 | text = ( 157 | bytearray([self.byte_decoder[c] for c in text]) 158 | .decode("utf-8", errors="replace") 159 | .replace("", " ") 160 | ) 161 | return text 162 | 163 | 164 | _tokenizer = SimpleTokenizer() 165 | 166 | 167 | def tokenize( 168 | texts: Union[str, List[str]], context_length: int = 77 169 | ) -> torch.LongTensor: 170 | """ 171 | Returns the tokenized representation of given input string(s) 172 | 173 | Parameters 174 | ---------- 175 | texts : Union[str, List[str]] 176 | An input string or a list of input strings to tokenize 177 | context_length : int 178 | The context length to use; all CLIP models use 77 as the context length 179 | 180 | Returns 181 | ------- 182 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 183 | """ 184 | if isinstance(texts, str): 185 | texts = [texts] 186 | 187 | sot_token = _tokenizer.encoder[""] 188 | eot_token = _tokenizer.encoder[""] 189 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 190 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 191 | 192 | for i, tokens in enumerate(all_tokens): 193 | if len(tokens) > context_length: 194 | tokens = tokens[:context_length] # Truncate 195 | result[i, : len(tokens)] = torch.tensor(tokens) 196 | 197 | return result 198 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import ( 2 | Normalize, 3 | Compose, 4 | RandomResizedCrop, 5 | InterpolationMode, 6 | ToTensor, 7 | Resize, 8 | CenterCrop, 9 | ) 10 | 11 | 12 | def _convert_to_rgb(image): 13 | return image.convert("RGB") 14 | 15 | 16 | def image_transform( 17 | image_size: int, 18 | is_train: bool, 19 | mean=(0.48145466, 0.4578275, 0.40821073), 20 | std=(0.26862954, 0.26130258, 0.27577711), 21 | ): 22 | normalize = Normalize(mean=mean, std=std) 23 | if is_train: 24 | return Compose( 25 | [ 26 | RandomResizedCrop( 27 | image_size, 28 | scale=(0.9, 1.0), 29 | interpolation=InterpolationMode.BICUBIC, 30 | ), 31 | _convert_to_rgb, 32 | ToTensor(), 33 | normalize, 34 | ] 35 | ) 36 | else: 37 | return Compose( 38 | [ 39 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 40 | CenterCrop(image_size), 41 | _convert_to_rgb, 42 | ToTensor(), 43 | normalize, 44 | ] 45 | ) 46 | -------------------------------------------------------------------------------- /models/CLAP/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.1" 2 | -------------------------------------------------------------------------------- /models/CLAP/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-AGI/AudioSep/944583f18b84589dc965de3ad77525c945334252/models/CLAP/training/__init__.py -------------------------------------------------------------------------------- /models/CLAP/training/audioset_textmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-AGI/AudioSep/944583f18b84589dc965de3ad77525c945334252/models/CLAP/training/audioset_textmap.npy -------------------------------------------------------------------------------- 
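A minimal usage sketch (not taken from the repository) of the CLIP BPE `tokenize` helper defined above; in this codebase it is mainly referenced by `training/zero_shot.py`, while CLAP's text branch is tokenized with `RobertaTokenizer` (see `training/infer_demo.py` and `models/clap_encoder.py`). The sketch assumes the repository root is on `PYTHONPATH`:

```python
from models.CLAP.open_clip.tokenizer import tokenize

# Two text queries -> LongTensor of shape [2, 77]: longer inputs are truncated
# and shorter ones are zero-padded up to the 77-token context length.
tokens = tokenize(["water drops", "a dog barking in the distance"])
print(tokens.shape, tokens.dtype)  # torch.Size([2, 77]) torch.int64
```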
/models/CLAP/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import socket 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all( 30 | [var in os.environ for var in pmi_vars] 31 | ): 32 | return True 33 | else: 34 | return False 35 | 36 | 37 | def is_using_distributed(): 38 | if "WORLD_SIZE" in os.environ: 39 | return int(os.environ["WORLD_SIZE"]) > 1 40 | if "SLURM_NTASKS" in os.environ: 41 | return int(os.environ["SLURM_NTASKS"]) > 1 42 | return False 43 | 44 | 45 | def world_info_from_env(): 46 | local_rank = 0 47 | for v in ( 48 | "SLURM_LOCALID", 49 | "MPI_LOCALRANKID", 50 | "OMPI_COMM_WORLD_LOCAL_RANK", 51 | "LOCAL_RANK", 52 | ): 53 | if v in os.environ: 54 | local_rank = int(os.environ[v]) 55 | break 56 | global_rank = 0 57 | for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"): 58 | if v in os.environ: 59 | global_rank = int(os.environ[v]) 60 | break 61 | world_size = 1 62 | for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"): 63 | if v in os.environ: 64 | world_size = int(os.environ[v]) 65 | break 66 | 67 | return local_rank, global_rank, world_size 68 | 69 | 70 | def init_distributed_device(args): 71 | # Distributed training = training on more than one GPU. 72 | # Works in both single and multi-node scenarios. 
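    # Rank / world-size information is read from the environment: Horovod's OMPI
    # variables, SLURM_PROCID / SLURM_NTASKS, or RANK / WORLD_SIZE / LOCAL_RANK as
    # set by torchrun. When more than one process is detected the default process
    # group is initialised, and the torch.device this process should use is returned.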
73 | args.distributed = False 74 | args.world_size = 1 75 | args.rank = 0 # global rank 76 | args.local_rank = 0 77 | if args.horovod: 78 | assert hvd is not None, "Horovod is not installed" 79 | hvd.init() 80 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 81 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 82 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 83 | args.local_rank = local_rank 84 | args.rank = world_rank 85 | args.world_size = world_size 86 | # args.local_rank = int(hvd.local_rank()) 87 | # args.rank = hvd.rank() 88 | # args.world_size = hvd.size() 89 | args.distributed = True 90 | os.environ["LOCAL_RANK"] = str(args.local_rank) 91 | os.environ["RANK"] = str(args.rank) 92 | os.environ["WORLD_SIZE"] = str(args.world_size) 93 | print( 94 | f"Distributed training: local_rank={args.local_rank}, " 95 | f"rank={args.rank}, world_size={args.world_size}, " 96 | f"hostname={socket.gethostname()}, pid={os.getpid()}" 97 | ) 98 | elif is_using_distributed(): 99 | if "SLURM_PROCID" in os.environ: 100 | # DDP via SLURM 101 | args.local_rank, args.rank, args.world_size = world_info_from_env() 102 | # SLURM var -> torch.distributed vars in case needed 103 | os.environ["LOCAL_RANK"] = str(args.local_rank) 104 | os.environ["RANK"] = str(args.rank) 105 | os.environ["WORLD_SIZE"] = str(args.world_size) 106 | torch.distributed.init_process_group( 107 | backend=args.dist_backend, 108 | init_method=args.dist_url, 109 | world_size=args.world_size, 110 | rank=args.rank, 111 | ) 112 | elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster 113 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 114 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 115 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 116 | args.local_rank = local_rank 117 | args.rank = world_rank 118 | args.world_size = world_size 119 | torch.distributed.init_process_group( 120 | backend=args.dist_backend, 121 | init_method=args.dist_url, 122 | world_size=args.world_size, 123 | rank=args.rank, 124 | ) 125 | else: 126 | # DDP via torchrun, torch.distributed.launch 127 | args.local_rank, _, _ = world_info_from_env() 128 | torch.distributed.init_process_group( 129 | backend=args.dist_backend, init_method=args.dist_url 130 | ) 131 | args.world_size = torch.distributed.get_world_size() 132 | args.rank = torch.distributed.get_rank() 133 | args.distributed = True 134 | print( 135 | f"Distributed training: local_rank={args.local_rank}, " 136 | f"rank={args.rank}, world_size={args.world_size}, " 137 | f"hostname={socket.gethostname()}, pid={os.getpid()}" 138 | ) 139 | 140 | if torch.cuda.is_available(): 141 | if args.distributed and not args.no_set_device_rank: 142 | device = "cuda:%d" % args.local_rank 143 | else: 144 | device = "cuda:0" 145 | torch.cuda.set_device(device) 146 | else: 147 | device = "cpu" 148 | args.device = device 149 | device = torch.device(device) 150 | return device 151 | -------------------------------------------------------------------------------- /models/CLAP/training/infer_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append( 4 | "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/src" 5 | ) 6 | 7 | import os 8 | import torch 9 | import librosa 10 | from open_clip import create_model 11 | from training.data import get_audio_features 12 | from training.data import int16_to_float32, float32_to_int16 13 | from transformers import RobertaTokenizer 14 | 15 | tokenize = 
RobertaTokenizer.from_pretrained("roberta-base") 16 | 17 | 18 | def tokenizer(text): 19 | result = tokenize( 20 | text, 21 | padding="max_length", 22 | truncation=True, 23 | max_length=77, 24 | return_tensors="pt", 25 | ) 26 | return {k: v.squeeze(0) for k, v in result.items()} 27 | 28 | 29 | PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt" 30 | WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav" 31 | 32 | 33 | def infer_text(): 34 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 35 | precision = "fp32" 36 | amodel = "HTSAT-tiny" # or 'PANN-14' 37 | tmodel = "roberta" # the best text encoder in our training 38 | enable_fusion = False # False if you do not want to use the fusion model 39 | fusion_type = "aff_2d" 40 | pretrained = PRETRAINED_PATH 41 | 42 | model, model_cfg = create_model( 43 | amodel, 44 | tmodel, 45 | pretrained, 46 | precision=precision, 47 | device=device, 48 | enable_fusion=enable_fusion, 49 | fusion_type=fusion_type, 50 | ) 51 | # load the text, can be a list (i.e. batch size) 52 | text_data = ["I love the contrastive learning", "I love the pretrain model"] 53 | # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90 54 | text_data = tokenizer(text_data) 55 | 56 | text_embed = model.get_text_embedding(text_data) 57 | print(text_embed.size()) 58 | 59 | 60 | def infer_audio(): 61 | 62 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 63 | precision = "fp32" 64 | amodel = "HTSAT-tiny" # or 'PANN-14' 65 | tmodel = "roberta" # the best text encoder in our training 66 | enable_fusion = False # False if you do not want to use the fusion model 67 | fusion_type = "aff_2d" 68 | pretrained = PRETRAINED_PATH 69 | 70 | model, model_cfg = create_model( 71 | amodel, 72 | tmodel, 73 | pretrained, 74 | precision=precision, 75 | device=device, 76 | enable_fusion=enable_fusion, 77 | fusion_type=fusion_type, 78 | ) 79 | 80 | # load the waveform of the shape (T,), should resample to 48000 81 | audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000) 82 | # quantize 83 | audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) 84 | audio_waveform = torch.from_numpy(audio_waveform).float() 85 | audio_dict = {} 86 | 87 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 88 | import ipdb 89 | 90 | ipdb.set_trace() 91 | audio_dict = get_audio_features( 92 | audio_dict, 93 | audio_waveform, 94 | 480000, 95 | data_truncating="fusion", 96 | data_filling="repeatpad", 97 | audio_cfg=model_cfg["audio_cfg"], 98 | ) 99 | # can send a list to the model, to process many audio tracks in one time (i.e. 
batch size) 100 | audio_embed = model.get_audio_embedding([audio_dict]) 101 | print(audio_embed.size()) 102 | import ipdb 103 | 104 | ipdb.set_trace() 105 | 106 | 107 | if __name__ == "__main__": 108 | infer_text() 109 | infer_audio() 110 | -------------------------------------------------------------------------------- /models/CLAP/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | 8 | hostname = socket.gethostname() 9 | formatter = logging.Formatter( 10 | f"%(asctime)s | {hostname} | %(levelname)s | %(message)s", 11 | datefmt="%Y-%m-%d,%H:%M:%S", 12 | ) 13 | else: 14 | formatter = logging.Formatter( 15 | "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S" 16 | ) 17 | 18 | logging.root.setLevel(level) 19 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 20 | for logger in loggers: 21 | logger.setLevel(level) 22 | 23 | stream_handler = logging.StreamHandler() 24 | stream_handler.setFormatter(formatter) 25 | logging.root.addHandler(stream_handler) 26 | 27 | if log_file: 28 | file_handler = logging.FileHandler(filename=log_file) 29 | file_handler.setFormatter(formatter) 30 | logging.root.addHandler(file_handler) 31 | -------------------------------------------------------------------------------- /models/CLAP/training/lp_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import time 6 | from contextlib import suppress 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | try: 13 | import wandb 14 | except ImportError: 15 | wandb = None 16 | 17 | from open_clip import LPLoss, LPMetrics, lp_gather_features 18 | from open_clip.utils import do_mixup, get_mix_lambda 19 | from .distributed import is_master 20 | from .zero_shot import zero_shot_eval 21 | 22 | 23 | class AverageMeter(object): 24 | """Computes and stores the average and current value""" 25 | 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | 42 | def unwrap_model(model): 43 | if hasattr(model, "module"): 44 | return model.module 45 | else: 46 | return model 47 | 48 | 49 | def train_one_epoch( 50 | model, 51 | data, 52 | epoch, 53 | optimizer, 54 | scaler, 55 | scheduler, 56 | args, 57 | tb_writer=None, 58 | extra_suffix="", 59 | ): 60 | device = torch.device(args.device) 61 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress 62 | model.train() 63 | loss = LPLoss(args.lp_loss) 64 | 65 | dataloader, sampler = data["train"].dataloader, data["train"].sampler 66 | if args.distributed and sampler is not None: 67 | sampler.set_epoch(epoch) 68 | num_batches_per_epoch = dataloader.num_batches 69 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) 70 | 71 | # for toy dataset 72 | if args.dataset_type == "toy": 73 | dataloader.dataset.generate_queue() 74 | 75 | loss_m = AverageMeter() 76 | batch_time_m = AverageMeter() 77 | data_time_m = AverageMeter() 78 | end = time.time() 79 | 80 | for i, batch in enumerate(dataloader): 81 | step = num_batches_per_epoch * epoch + i 82 | 83 | if 
isinstance(scheduler, dict): 84 | for s in scheduler.values(): 85 | s(step) 86 | else: 87 | scheduler(step) 88 | 89 | audio = batch # contains mel_spec, wavform, and longer list 90 | class_label = batch["class_label"] 91 | # audio = audio.to(device=device, non_blocking=True) 92 | class_label = class_label.to(device=device, non_blocking=True) 93 | 94 | if args.mixup: 95 | # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146 96 | mix_lambda = torch.from_numpy( 97 | get_mix_lambda(0.5, len(audio["waveform"])) 98 | ).to(device) 99 | class_label = do_mixup(class_label, mix_lambda) 100 | else: 101 | mix_lambda = None 102 | 103 | data_time_m.update(time.time() - end) 104 | if isinstance(optimizer, dict): 105 | for o_ in optimizer.values(): 106 | o_.zero_grad() 107 | else: 108 | optimizer.zero_grad() 109 | 110 | with autocast(): 111 | pred = model(audio, mix_lambda=mix_lambda, device=device) 112 | total_loss = loss(pred, class_label) 113 | 114 | if isinstance(optimizer, dict): 115 | if scaler is not None: 116 | scaler.scale(total_loss).backward() 117 | for o_ in optimizer.values(): 118 | if args.horovod: 119 | o_.synchronize() 120 | scaler.unscale_(o_) 121 | with o_.skip_synchronize(): 122 | scaler.step(o_) 123 | else: 124 | scaler.step(o_) 125 | scaler.update() 126 | else: 127 | total_loss.backward() 128 | for o_ in optimizer.values(): 129 | o_.step() 130 | else: 131 | if scaler is not None: 132 | scaler.scale(total_loss).backward() 133 | if args.horovod: 134 | optimizer.synchronize() 135 | scaler.unscale_(optimizer) 136 | with optimizer.skip_synchronize(): 137 | scaler.step(optimizer) 138 | else: 139 | scaler.step(optimizer) 140 | scaler.update() 141 | else: 142 | total_loss.backward() 143 | optimizer.step() 144 | 145 | # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
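        # Keeping logit_scale_a / logit_scale_t in [0, ln(100)] caps the learned
        # temperature at exp(logit_scale) = 100, which keeps the contrastive
        # softmax over audio-text similarities numerically stable.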
146 | with torch.no_grad(): 147 | unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100)) 148 | unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100)) 149 | 150 | batch_time_m.update(time.time() - end) 151 | end = time.time() 152 | batch_count = i + 1 153 | 154 | if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): 155 | if isinstance(audio, dict): 156 | batch_size = len(audio["waveform"]) 157 | else: 158 | batch_size = len(audio) 159 | num_samples = batch_count * batch_size * args.world_size 160 | samples_per_epoch = dataloader.num_samples 161 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 162 | 163 | # NOTE loss is coarsely sampled, just master node and per log update 164 | loss_m.update(total_loss.item(), batch_size) 165 | if isinstance(optimizer, dict): 166 | logging.info( 167 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 168 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " 169 | f"Data (t): {data_time_m.avg:.3f} " 170 | f"Batch (t): {batch_time_m.avg:.3f} " 171 | f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}" 172 | ) 173 | log_data = { 174 | "loss": loss_m.val, 175 | "data_time": data_time_m.val, 176 | "batch_time": batch_time_m.val, 177 | "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], 178 | } 179 | else: 180 | logging.info( 181 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 182 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " 183 | f"Data (t): {data_time_m.avg:.3f} " 184 | f"Batch (t): {batch_time_m.avg:.3f} " 185 | f"LR: {optimizer.param_groups[0]['lr']:5f} " 186 | ) 187 | 188 | # Save train loss / etc. Using non avg meter values as loggers have their own smoothing 189 | log_data = { 190 | "loss": loss_m.val, 191 | "data_time": data_time_m.val, 192 | "batch_time": batch_time_m.val, 193 | "lr": optimizer.param_groups[0]["lr"], 194 | } 195 | for name, val in log_data.items(): 196 | name = f"train{extra_suffix}/{name}" 197 | if tb_writer is not None: 198 | tb_writer.add_scalar(name, val, step) 199 | if args.wandb: 200 | assert wandb is not None, "Please install wandb." 
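                    # The same scalars that go to TensorBoard are mirrored to wandb,
                    # keyed by the global optimisation step so runs can be compared.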
201 | wandb.log({name: val, "step": step}) 202 | 203 | # resetting batch / data time meters per log window 204 | batch_time_m.reset() 205 | data_time_m.reset() 206 | # end for 207 | 208 | 209 | def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): 210 | metrics = {} 211 | if not args.parallel_eval: 212 | if not is_master(args): 213 | return metrics 214 | device = torch.device(args.device) 215 | model.eval() 216 | 217 | # CHANGE 218 | # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) 219 | # metrics.update(zero_shot_metrics) 220 | if is_master(args): 221 | print("Evaluating...") 222 | metric_names = args.lp_metrics.split(",") 223 | eval_tool = LPMetrics(metric_names=metric_names) 224 | 225 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress 226 | if "val" in data and ( 227 | args.val_frequency 228 | and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) 229 | ): 230 | if args.parallel_eval: 231 | dataloader, sampler = data["val"].dataloader, data["val"].sampler 232 | if args.distributed and sampler is not None: 233 | sampler.set_epoch(epoch) 234 | samples_per_val = dataloader.num_samples 235 | else: 236 | dataloader = data["val"].dataloader 237 | num_samples = 0 238 | samples_per_val = dataloader.num_samples 239 | 240 | eval_info = {"pred": [], "target": []} 241 | with torch.no_grad(): 242 | for i, batch in enumerate(dataloader): 243 | audio = batch # contains mel_spec, wavform, and longer list 244 | class_label = batch["class_label"] 245 | 246 | # audio = audio.to(device=device, non_blocking=True) 247 | class_label = class_label.to(device=device, non_blocking=True) 248 | 249 | with autocast(): 250 | pred = model(audio, device=device) 251 | if args.parallel_eval: 252 | pred, class_label = lp_gather_features( 253 | pred, class_label, args.world_size, args.horovod 254 | ) 255 | eval_info["pred"].append(pred) 256 | eval_info["target"].append(class_label) 257 | 258 | num_samples += class_label.shape[0] 259 | 260 | if (i % 100) == 0: # and i != 0: 261 | logging.info( 262 | f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" 263 | ) 264 | 265 | if is_master(args): 266 | eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu() 267 | eval_info["target"] = torch.cat(eval_info["target"], 0).cpu() 268 | metric_dict = eval_tool.evaluate_mertics( 269 | eval_info["pred"], eval_info["target"] 270 | ) 271 | metrics.update(metric_dict) 272 | if "epoch" not in metrics.keys(): 273 | metrics.update({"epoch": epoch}) 274 | 275 | if is_master(args): 276 | if not metrics: 277 | return metrics 278 | 279 | logging.info( 280 | f"Eval Epoch: {epoch} " 281 | + "\n".join( 282 | ["\t".join([f"{m}: {round(metrics[m], 4):.4f}"]) for m in metrics] 283 | ) 284 | ) 285 | if args.save_logs: 286 | for name, val in metrics.items(): 287 | if tb_writer is not None: 288 | tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch) 289 | 290 | with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: 291 | f.write(json.dumps(metrics)) 292 | f.write("\n") 293 | 294 | if args.wandb: 295 | assert wandb is not None, "Please install wandb." 
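                # Validation metrics are mirrored to wandb once per evaluation pass,
                # keyed by epoch (training scalars above are keyed by step instead).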
296 | for name, val in metrics.items(): 297 | wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch}) 298 | 299 | return metrics 300 | else: 301 | return metrics 302 | -------------------------------------------------------------------------------- /models/CLAP/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | e = step - warmup_length 19 | es = steps - warmup_length 20 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 21 | assign_learning_rate(optimizer, lr) 22 | return lr 23 | 24 | return _lr_adjuster 25 | -------------------------------------------------------------------------------- /models/CLAP/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | # NOTE: This script is currently not supported for CLAP. 2 | import logging 3 | from contextlib import suppress 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | from open_clip import tokenize 10 | from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template 11 | 12 | 13 | def zero_shot_classifier(model, classnames, templates, args): 14 | with torch.no_grad(): 15 | zeroshot_weights = [] 16 | for classname in tqdm(classnames): 17 | texts = [template(classname) for template in templates] # format with class 18 | texts = tokenize(texts).to(args.device) # tokenize 19 | if args.distributed and not args.horovod: 20 | class_embeddings = model.module.encode_text(texts) 21 | else: 22 | class_embeddings = model.encode_text(texts) 23 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 24 | class_embedding /= class_embedding.norm() 25 | zeroshot_weights.append(class_embedding) 26 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) 27 | return zeroshot_weights 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | pred = output.topk(max(topk), 1, True, True)[1].t() 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | return [ 34 | float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 35 | for k in topk 36 | ] 37 | 38 | 39 | def run(model, classifier, dataloader, args): 40 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress 41 | with torch.no_grad(): 42 | top1, top5, n = 0.0, 0.0, 0.0 43 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 44 | images = images.to(args.device) 45 | target = target.to(args.device) 46 | 47 | with autocast(): 48 | # predict 49 | if args.distributed and not args.horovod: 50 | image_features = model.module.encode_image(images) 51 | else: 52 | image_features = model.encode_image(images) 53 | image_features = F.normalize(image_features, dim=-1) 54 | logits = 100.0 * image_features @ classifier 55 | 56 | # measure accuracy 57 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 58 | top1 += acc1 59 | top5 += acc5 60 | n += images.size(0) 61 | 62 | top1 = top1 / n 63 | top5 = top5 / n 64 | return top1, top5 65 | 66 | 67 | def zero_shot_eval(model, data, epoch, args): 68 | if 
"imagenet-val" not in data and "imagenet-v2" not in data: 69 | return {} 70 | if args.zeroshot_frequency == 0: 71 | return {} 72 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 73 | return {} 74 | 75 | logging.info("Starting zero-shot imagenet.") 76 | 77 | logging.info("Building zero-shot classifier") 78 | classifier = zero_shot_classifier( 79 | model, imagenet_classnames, openai_imagenet_template, args 80 | ) 81 | 82 | logging.info("Using classifier") 83 | results = {} 84 | if "imagenet-val" in data: 85 | top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args) 86 | results["imagenet-zeroshot-val-top1"] = top1 87 | results["imagenet-zeroshot-val-top5"] = top5 88 | if "imagenet-v2" in data: 89 | top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args) 90 | results["imagenetv2-zeroshot-val-top1"] = top1 91 | results["imagenetv2-zeroshot-val-top5"] = top5 92 | 93 | logging.info("Finished zero-shot imagenet.") 94 | 95 | return results 96 | -------------------------------------------------------------------------------- /models/audiosep.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict 2 | import random 3 | import lightning.pytorch as pl 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.optim.lr_scheduler import LambdaLR 8 | 9 | from models.clap_encoder import CLAP_Encoder 10 | 11 | from huggingface_hub import PyTorchModelHubMixin 12 | 13 | 14 | class AudioSep(pl.LightningModule, PyTorchModelHubMixin): 15 | def __init__( 16 | self, 17 | ss_model: nn.Module = None, 18 | waveform_mixer = None, 19 | query_encoder: nn.Module = CLAP_Encoder().eval(), 20 | loss_function = None, 21 | optimizer_type: str = None, 22 | learning_rate: float = None, 23 | lr_lambda_func = None, 24 | use_text_ratio: float =1.0, 25 | ): 26 | r"""Pytorch Lightning wrapper of PyTorch model, including forward, 27 | optimization of model, etc. 28 | 29 | Args: 30 | ss_model: nn.Module 31 | anchor_segment_detector: nn.Module 32 | loss_function: function or object 33 | learning_rate: float 34 | lr_lambda: function 35 | """ 36 | 37 | super().__init__() 38 | self.ss_model = ss_model 39 | self.waveform_mixer = waveform_mixer 40 | self.query_encoder = query_encoder 41 | self.query_encoder_type = self.query_encoder.encoder_type 42 | self.use_text_ratio = use_text_ratio 43 | self.loss_function = loss_function 44 | self.optimizer_type = optimizer_type 45 | self.learning_rate = learning_rate 46 | self.lr_lambda_func = lr_lambda_func 47 | 48 | 49 | def forward(self, x): 50 | pass 51 | 52 | def training_step(self, batch_data_dict, batch_idx): 53 | r"""Forward a mini-batch data to model, calculate loss function, and 54 | train for one step. A mini-batch data is evenly distributed to multiple 55 | devices (if there are) for parallel training. 56 | 57 | Args: 58 | batch_data_dict: e.g. 59 | 'audio_text': { 60 | 'text': ['a sound of dog', ...] 
61 | 'waveform': (batch_size, 1, samples) 62 | } 63 | batch_idx: int 64 | 65 | Returns: 66 | loss: float, loss function of this mini-batch 67 | """ 68 | # [important] fix random seeds across devices 69 | random.seed(batch_idx) 70 | 71 | batch_audio_text_dict = batch_data_dict['audio_text'] 72 | 73 | batch_text = batch_audio_text_dict['text'] 74 | batch_audio = batch_audio_text_dict['waveform'] 75 | device = batch_audio.device 76 | 77 | mixtures, segments = self.waveform_mixer( 78 | waveforms=batch_audio 79 | ) 80 | 81 | # calculate text embed for audio-text data 82 | if self.query_encoder_type == 'CLAP': 83 | conditions = self.query_encoder.get_query_embed( 84 | modality='hybird', 85 | text=batch_text, 86 | audio=segments.squeeze(1), 87 | use_text_ratio=self.use_text_ratio, 88 | ) 89 | 90 | input_dict = { 91 | 'mixture': mixtures[:, None, :].squeeze(1), 92 | 'condition': conditions, 93 | } 94 | 95 | target_dict = { 96 | 'segment': segments.squeeze(1), 97 | } 98 | 99 | self.ss_model.train() 100 | sep_segment = self.ss_model(input_dict)['waveform'] 101 | sep_segment = sep_segment.squeeze() 102 | # (batch_size, 1, segment_samples) 103 | 104 | output_dict = { 105 | 'segment': sep_segment, 106 | } 107 | 108 | # Calculate loss. 109 | loss = self.loss_function(output_dict, target_dict) 110 | 111 | self.log_dict({"train_loss": loss}) 112 | 113 | return loss 114 | 115 | def test_step(self, batch, batch_idx): 116 | pass 117 | 118 | def configure_optimizers(self): 119 | r"""Configure optimizer. 120 | """ 121 | 122 | if self.optimizer_type == "AdamW": 123 | optimizer = optim.AdamW( 124 | params=self.ss_model.parameters(), 125 | lr=self.learning_rate, 126 | betas=(0.9, 0.999), 127 | eps=1e-08, 128 | weight_decay=0.0, 129 | amsgrad=True, 130 | ) 131 | else: 132 | raise NotImplementedError 133 | 134 | scheduler = LambdaLR(optimizer, self.lr_lambda_func) 135 | 136 | output_dict = { 137 | "optimizer": optimizer, 138 | "lr_scheduler": { 139 | 'scheduler': scheduler, 140 | 'interval': 'step', 141 | 'frequency': 1, 142 | } 143 | } 144 | 145 | return output_dict 146 | 147 | 148 | def get_model_class(model_type): 149 | if model_type == 'ResUNet30': 150 | from models.resunet import ResUNet30 151 | return ResUNet30 152 | 153 | else: 154 | raise NotImplementedError 155 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import math 6 | from torchlibrosa.stft import magphase 7 | 8 | 9 | def init_layer(layer): 10 | """Initialize a Linear or Convolutional layer. """ 11 | nn.init.xavier_uniform_(layer.weight) 12 | 13 | if hasattr(layer, "bias"): 14 | if layer.bias is not None: 15 | layer.bias.data.fill_(0.0) 16 | 17 | 18 | def init_bn(bn): 19 | """Initialize a Batchnorm layer. """ 20 | bn.bias.data.fill_(0.0) 21 | bn.weight.data.fill_(1.0) 22 | 23 | 24 | def init_embedding(layer): 25 | """Initialize a Linear or Convolutional layer. """ 26 | nn.init.uniform_(layer.weight, -1., 1.) 27 | 28 | if hasattr(layer, 'bias'): 29 | if layer.bias is not None: 30 | layer.bias.data.fill_(0.) 31 | 32 | 33 | def init_gru(rnn): 34 | """Initialize a GRU layer. 
""" 35 | 36 | def _concat_init(tensor, init_funcs): 37 | (length, fan_out) = tensor.shape 38 | fan_in = length // len(init_funcs) 39 | 40 | for (i, init_func) in enumerate(init_funcs): 41 | init_func(tensor[i * fan_in : (i + 1) * fan_in, :]) 42 | 43 | def _inner_uniform(tensor): 44 | fan_in = nn.init._calculate_correct_fan(tensor, "fan_in") 45 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) 46 | 47 | for i in range(rnn.num_layers): 48 | _concat_init( 49 | getattr(rnn, "weight_ih_l{}".format(i)), 50 | [_inner_uniform, _inner_uniform, _inner_uniform], 51 | ) 52 | torch.nn.init.constant_(getattr(rnn, "bias_ih_l{}".format(i)), 0) 53 | 54 | _concat_init( 55 | getattr(rnn, "weight_hh_l{}".format(i)), 56 | [_inner_uniform, _inner_uniform, nn.init.orthogonal_], 57 | ) 58 | torch.nn.init.constant_(getattr(rnn, "bias_hh_l{}".format(i)), 0) 59 | 60 | 61 | def act(x, activation): 62 | if activation == "relu": 63 | return F.relu_(x) 64 | 65 | elif activation == "leaky_relu": 66 | return F.leaky_relu_(x, negative_slope=0.01) 67 | 68 | elif activation == "swish": 69 | return x * torch.sigmoid(x) 70 | 71 | else: 72 | raise Exception("Incorrect activation!") 73 | 74 | 75 | class Base: 76 | def __init__(self): 77 | pass 78 | 79 | def spectrogram(self, input, eps=0.): 80 | (real, imag) = self.stft(input) 81 | return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 82 | 83 | def spectrogram_phase(self, input, eps=0.): 84 | (real, imag) = self.stft(input) 85 | mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 86 | cos = real / mag 87 | sin = imag / mag 88 | return mag, cos, sin 89 | 90 | 91 | def wav_to_spectrogram_phase(self, input, eps=1e-10): 92 | """Waveform to spectrogram. 93 | 94 | Args: 95 | input: (batch_size, segment_samples, channels_num) 96 | 97 | Outputs: 98 | output: (batch_size, channels_num, time_steps, freq_bins) 99 | """ 100 | sp_list = [] 101 | cos_list = [] 102 | sin_list = [] 103 | channels_num = input.shape[1] 104 | for channel in range(channels_num): 105 | mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps) 106 | sp_list.append(mag) 107 | cos_list.append(cos) 108 | sin_list.append(sin) 109 | 110 | sps = torch.cat(sp_list, dim=1) 111 | coss = torch.cat(cos_list, dim=1) 112 | sins = torch.cat(sin_list, dim=1) 113 | return sps, coss, sins 114 | 115 | def wav_to_spectrogram(self, input, eps=0.): 116 | """Waveform to spectrogram. 117 | 118 | Args: 119 | input: (batch_size, segment_samples, channels_num) 120 | 121 | Outputs: 122 | output: (batch_size, channels_num, time_steps, freq_bins) 123 | """ 124 | sp_list = [] 125 | channels_num = input.shape[1] 126 | for channel in range(channels_num): 127 | sp_list.append(self.spectrogram(input[:, channel, :], eps=eps)) 128 | 129 | output = torch.cat(sp_list, dim=1) 130 | return output 131 | 132 | 133 | def spectrogram_to_wav(self, input, spectrogram, length=None): 134 | """Spectrogram to waveform. 
135 | 136 | Args: 137 | input: (batch_size, segment_samples, channels_num) 138 | spectrogram: (batch_size, channels_num, time_steps, freq_bins) 139 | 140 | Outputs: 141 | output: (batch_size, segment_samples, channels_num) 142 | """ 143 | channels_num = input.shape[1] 144 | wav_list = [] 145 | for channel in range(channels_num): 146 | (real, imag) = self.stft(input[:, channel, :]) 147 | (_, cos, sin) = magphase(real, imag) 148 | wav_list.append(self.istft(spectrogram[:, channel : channel + 1, :, :] * cos, 149 | spectrogram[:, channel : channel + 1, :, :] * sin, length)) 150 | 151 | output = torch.stack(wav_list, dim=1) 152 | return output 153 | -------------------------------------------------------------------------------- /models/clap_encoder.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import torchaudio 5 | from models.CLAP.open_clip import create_model 6 | from models.CLAP.training.data import get_audio_features 7 | from transformers import RobertaTokenizer 8 | 9 | 10 | class CLAP_Encoder(nn.Module): 11 | def __init__( 12 | self, 13 | pretrained_path='checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt', 14 | sampling_rate=32000, 15 | amodel = "HTSAT-base", 16 | ): 17 | super().__init__() 18 | self.device = "cpu" 19 | self.precision = "fp32" 20 | self.amodel = amodel # or 'PANN-14' 21 | self.tmodel = "roberta" # the best text encoder in our training 22 | self.enable_fusion = False # False if you do not want to use the fusion model 23 | self.fusion_type = "aff_2d" 24 | self.pretrained = pretrained_path 25 | self.sampling_rate = sampling_rate 26 | self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") 27 | 28 | self.model, self.model_cfg = create_model( 29 | self.amodel, 30 | self.tmodel, 31 | self.pretrained, 32 | precision=self.precision, 33 | device=self.device, 34 | enable_fusion=self.enable_fusion, 35 | fusion_type=self.fusion_type, 36 | ) 37 | 38 | for p in self.model.parameters(): 39 | p.requires_grad = False 40 | 41 | self.model.eval() 42 | self.encoder_type = 'CLAP' 43 | 44 | def batch_to_list(self, batch): 45 | ret = [] 46 | for i in range(batch.size(0)): 47 | ret.append(batch[i]) 48 | return ret 49 | 50 | def _get_audio_embed(self, batch): 51 | # batch: [B, samples] 52 | with torch.no_grad(): 53 | audio_dict_list = [] 54 | assert ( 55 | self.sampling_rate == 32000 56 | ), "We only support 32000 sampling rate" 57 | 58 | # batch: [bs, 1, t-samples] 59 | batch = torchaudio.functional.resample( 60 | batch, orig_freq=self.sampling_rate, new_freq=48000 61 | ) 62 | for waveform in self.batch_to_list(batch): 63 | audio_dict = {} 64 | audio_dict = get_audio_features( 65 | audio_dict, 66 | waveform, 67 | 480000, 68 | data_truncating="fusion", 69 | data_filling="repeatpad", 70 | audio_cfg=self.model_cfg["audio_cfg"], 71 | ) 72 | audio_dict_list.append(audio_dict) 73 | # [bs, 512] 74 | embed = self.model.get_audio_embedding(audio_dict_list) 75 | 76 | return embed.detach() 77 | 78 | def _get_text_embed(self, batch): 79 | double_batch = False 80 | if len(batch) == 1: 81 | batch = batch * 2 82 | double_batch = True 83 | with torch.no_grad(): 84 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 85 | text_data = self.tokenizer(batch) 86 | embed = self.model.get_text_embedding(text_data) 87 | if double_batch: 88 | embed = embed[0].unsqueeze(0) 89 | 90 | return embed.detach() 91 | 92 | 93 | def get_query_embed(self, modality, audio=None, 
text=None, use_text_ratio=0.5, device=None): 94 | if modality == 'audio': 95 | embed = self._get_audio_embed(audio) 96 | elif modality == 'text': 97 | embed = self._get_text_embed(text) 98 | elif modality == 'hybird': 99 | if random.random() > use_text_ratio: 100 | embed = self._get_audio_embed(audio) 101 | else: 102 | embed = self._get_text_embed(text) 103 | else: 104 | raise NotImplementedError("Please check flag 'training_modality'.") 105 | 106 | return embed.float() 107 | 108 | def tokenizer(self, text): 109 | result = self.tokenize( 110 | text, 111 | padding="max_length", 112 | truncation=True, 113 | max_length=512, 114 | return_tensors="pt", 115 | ) 116 | return {k: v.squeeze(0) for k, v in result.items()} 117 | -------------------------------------------------------------------------------- /optimizers/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable 3 | 4 | 5 | def linear_warm_up( 6 | step: int, 7 | warm_up_steps: int, 8 | reduce_lr_steps: int 9 | ) -> float: 10 | r"""Get linear warm up scheduler for LambdaLR. 11 | 12 | Args: 13 | step (int): global step 14 | warm_up_steps (int): steps for warm up 15 | reduce_lr_steps (int): reduce learning rate by a factor of 0.9 #reduce_lr_steps step 16 | 17 | .. code-block: python 18 | >>> lr_lambda = partial(linear_warm_up, warm_up_steps=1000, reduce_lr_steps=10000) 19 | >>> from torch.optim.lr_scheduler import LambdaLR 20 | >>> LambdaLR(optimizer, lr_lambda) 21 | 22 | Returns: 23 | lr_scale (float): learning rate scaler 24 | """ 25 | 26 | if step <= warm_up_steps: 27 | lr_scale = step / warm_up_steps 28 | else: 29 | lr_scale = 0.9 ** (step // reduce_lr_steps) 30 | 31 | return lr_scale 32 | 33 | 34 | def constant_warm_up( 35 | step: int, 36 | warm_up_steps: int, 37 | reduce_lr_steps: int 38 | ) -> float: 39 | r"""Get constant warm up scheduler for LambdaLR. 40 | 41 | Args: 42 | step (int): global step 43 | warm_up_steps (int): steps for warm up 44 | reduce_lr_steps (int): reduce learning rate by a factor of 0.9 #reduce_lr_steps step 45 | 46 | .. code-block: python 47 | >>> lr_lambda = partial(constant_warm_up, warm_up_steps=1000, reduce_lr_steps=10000) 48 | >>> from torch.optim.lr_scheduler import LambdaLR 49 | >>> LambdaLR(optimizer, lr_lambda) 50 | 51 | Returns: 52 | lr_scale (float): learning rate scaler 53 | """ 54 | 55 | if 0 <= step < warm_up_steps: 56 | lr_scale = 0.001 57 | 58 | elif warm_up_steps <= step < 2 * warm_up_steps: 59 | lr_scale = 0.01 60 | 61 | elif 2 * warm_up_steps <= step < 3 * warm_up_steps: 62 | lr_scale = 0.1 63 | 64 | else: 65 | lr_scale = 1 66 | 67 | return lr_scale 68 | 69 | 70 | def get_lr_lambda( 71 | lr_lambda_type: str, 72 | **kwargs 73 | ) -> Callable: 74 | r"""Get learning scheduler. 
75 | 
76 |     Args:
77 |         lr_lambda_type (str), e.g., "constant_warm_up" | "linear_warm_up"
78 | 
79 |     Returns:
80 |         lr_lambda_func (Callable)
81 |     """
82 |     if lr_lambda_type == "constant_warm_up":
83 | 
84 |         lr_lambda_func = partial(
85 |             constant_warm_up,
86 |             warm_up_steps=kwargs["warm_up_steps"],
87 |             reduce_lr_steps=kwargs["reduce_lr_steps"],
88 |         )
89 | 
90 |     elif lr_lambda_type == "linear_warm_up":
91 | 
92 |         lr_lambda_func = partial(
93 |             linear_warm_up,
94 |             warm_up_steps=kwargs["warm_up_steps"],
95 |             reduce_lr_steps=kwargs["reduce_lr_steps"],
96 |         )
97 | 
98 |     else:
99 |         raise NotImplementedError
100 | 
101 |     return lr_lambda_func
102 | 
--------------------------------------------------------------------------------
/pipeline.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from typing import List
3 | import torch
4 | import numpy as np
5 | import librosa
6 | from scipy.io.wavfile import write
7 | from utils import ignore_warnings, parse_yaml, load_ss_model
8 | from models.clap_encoder import CLAP_Encoder
9 | 
10 | def build_audiosep(config_yaml, checkpoint_path, device):
11 |     ignore_warnings()
12 |     configs = parse_yaml(config_yaml)
13 | 
14 |     query_encoder = CLAP_Encoder().eval()
15 |     model = load_ss_model(configs=configs, checkpoint_path=checkpoint_path, query_encoder=query_encoder).eval().to(device)
16 | 
17 |     print(f'Loaded AudioSep model from [{checkpoint_path}]')
18 |     return model
19 | 
20 | def separate_audio(model, audio_file, text, output_file, device='cuda', use_chunk=False):
21 |     print(f'Separating audio from [{audio_file}] with textual query: [{text}]')
22 |     mixture, fs = librosa.load(audio_file, sr=32000, mono=True)
23 |     with torch.no_grad():
24 |         text = [text]
25 | 
26 |         conditions = model.query_encoder.get_query_embed(
27 |             modality='text',
28 |             text=text,
29 |             device=device
30 |         )
31 | 
32 |         input_dict = {
33 |             "mixture": torch.Tensor(mixture)[None, None, :].to(device),
34 |             "condition": conditions,
35 |         }
36 | 
37 |         if use_chunk:
38 |             sep_segment = model.ss_model.chunk_inference(input_dict)
39 |             sep_segment = np.squeeze(sep_segment)
40 |         else:
41 |             sep_segment = model.ss_model(input_dict)["waveform"]
42 |             sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
43 | 
44 |         write(output_file, 32000, np.round(sep_segment * 32767).astype(np.int16))
45 |         print(f'Separated audio written to [{output_file}]')
46 | 
47 | if __name__ == '__main__':
48 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49 |     model = build_audiosep(
50 |         config_yaml='config/audiosep_base.yaml',
51 |         checkpoint_path='checkpoint/step=3920000.ckpt',
52 |         device=device)
53 | 
54 |     audio_file = '/mnt/bn/data-xubo/project/AudioShop/YT_audios/Y3VHpLxtd498.wav'
55 |     text = 'pigeons are cooing in the background'
56 |     output_file = 'separated_audio.wav'
57 | 
58 |     separate_audio(model, audio_file, text, output_file, device)
59 | 
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # Prediction interface for Cog ⚙️
2 | # https://github.com/replicate/cog/blob/main/docs/python.md
3 | 
4 | import torch
5 | from cog import BasePredictor, Input, Path
6 | 
7 | from pipeline import build_audiosep, separate_audio
8 | 
9 | 
10 | class Predictor(BasePredictor):
11 |     def setup(self) -> None:
12 |         """Load the model into memory to make running multiple predictions efficient"""
13 | 
14 |         self.model = build_audiosep(
15 |             config_yaml="config/audiosep_base.yaml",
16 |             checkpoint_path="checkpoint/audiosep_base_4M_steps.ckpt",
17 |             device="cuda",
18 |         )
19 | 
20 |     def predict(
21 |         self,
22 |         audio_file: Path = Input(description="Input audio file."),
23 |         text: str = Input(description="Input text.", default="water drops"),
24 |     ) -> Path:
25 |         """Run a single prediction on the model"""
26 | 
27 |         output_file = "/tmp/separated_audio.wav"
28 | 
29 |         # AudioSep processes the audio at 32 kHz sampling rate
30 |         separate_audio(self.model, str(audio_file), text, output_file, "cuda")
31 |         return Path(output_file)
32 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import pathlib
5 | from typing import List, NoReturn
6 | import lightning.pytorch as pl
7 | from lightning.pytorch.strategies import DDPStrategy
8 | from torch.utils.tensorboard import SummaryWriter
9 | from data.datamodules import *
10 | from utils import create_logging, parse_yaml
11 | from models.resunet import *
12 | from losses import get_loss_function
13 | from models.audiosep import AudioSep, get_model_class
14 | from data.waveform_mixers import SegmentMixer
15 | from models.clap_encoder import CLAP_Encoder
16 | from callbacks.base import CheckpointEveryNSteps
17 | from optimizers.lr_schedulers import get_lr_lambda
18 | 
19 | 
20 | def get_dirs(
21 |     workspace: str,
22 |     filename: str,
23 |     config_yaml: str,
24 |     devices_num: int
25 | ) -> List[str]:
26 |     r"""Get directories and paths.
27 | 
28 |     Args:
29 |         workspace (str): directory of workspace
30 |         filename (str): filename of current .py file.
31 |         config_yaml (str): config yaml path
32 |         devices_num (int): 0 for cpu and 8 for training with 8 GPUs
33 | 
34 |     Returns:
35 |         checkpoints_dir (str): directory to save checkpoints
36 |         logs_dir (str), directory to save logs
37 |         tf_logs_dir (str), directory to save TensorBoard logs
38 |         statistics_path (str), directory to save statistics
39 |     """
40 | 
41 |     os.makedirs(workspace, exist_ok=True)
42 | 
43 |     yaml_name = pathlib.Path(config_yaml).stem
44 | 
45 |     # Directory to save checkpoints
46 |     checkpoints_dir = os.path.join(
47 |         workspace,
48 |         "checkpoints",
49 |         filename,
50 |         "{},devices={}".format(yaml_name, devices_num),
51 |     )
52 |     os.makedirs(checkpoints_dir, exist_ok=True)
53 | 
54 |     # Directory to save logs
55 |     logs_dir = os.path.join(
56 |         workspace,
57 |         "logs",
58 |         filename,
59 |         "{},devices={}".format(yaml_name, devices_num),
60 |     )
61 |     os.makedirs(logs_dir, exist_ok=True)
62 | 
63 |     # Directory to save TensorBoard logs
64 |     create_logging(logs_dir, filemode="w")
65 |     logging.info(args)
66 | 
67 |     tf_logs_dir = os.path.join(
68 |         workspace,
69 |         "tf_logs",
70 |         filename,
71 |         "{},devices={}".format(yaml_name, devices_num),
72 |     )
73 | 
74 |     # Directory to save statistics
75 |     statistics_path = os.path.join(
76 |         workspace,
77 |         "statistics",
78 |         filename,
79 |         "{},devices={}".format(yaml_name, devices_num),
80 |         "statistics.pkl",
81 |     )
82 |     os.makedirs(os.path.dirname(statistics_path), exist_ok=True)
83 | 
84 |     return checkpoints_dir, logs_dir, tf_logs_dir, statistics_path
85 | 
86 | 
87 | def get_data_module(
88 |     config_yaml: str,
89 |     num_workers: int,
90 |     batch_size: int,
91 | ) -> DataModule:
92 |     r"""Create data_module.
Mini-batch data can be obtained by: 93 | 94 | code-block:: python 95 | 96 | data_module.setup() 97 | 98 | for batch_data_dict in data_module.train_dataloader(): 99 | print(batch_data_dict.keys()) 100 | break 101 | 102 | Args: 103 | workspace: str 104 | config_yaml: str 105 | num_workers: int, e.g., 0 for non-parallel and 8 for using cpu cores 106 | for preparing data in parallel 107 | distributed: bool 108 | 109 | Returns: 110 | data_module: DataModule 111 | """ 112 | 113 | # read configurations 114 | configs = parse_yaml(config_yaml) 115 | sampling_rate = configs['data']['sampling_rate'] 116 | segment_seconds = configs['data']['segment_seconds'] 117 | 118 | # audio-text datasets 119 | datafiles = configs['data']['datafiles'] 120 | 121 | # dataset 122 | dataset = AudioTextDataset( 123 | datafiles=datafiles, 124 | sampling_rate=sampling_rate, 125 | max_clip_len=segment_seconds, 126 | ) 127 | 128 | 129 | # data module 130 | data_module = DataModule( 131 | train_dataset=dataset, 132 | num_workers=num_workers, 133 | batch_size=batch_size 134 | ) 135 | 136 | return data_module 137 | 138 | 139 | def train(args) -> NoReturn: 140 | r"""Train, evaluate, and save checkpoints. 141 | 142 | Args: 143 | workspace: str, directory of workspace 144 | gpus: int, number of GPUs to train 145 | config_yaml: str 146 | """ 147 | 148 | # arguments & parameters 149 | workspace = args.workspace 150 | config_yaml = args.config_yaml 151 | filename = args.filename 152 | 153 | devices_num = torch.cuda.device_count() 154 | # Read config file. 155 | configs = parse_yaml(config_yaml) 156 | 157 | # Configuration of data 158 | max_mix_num = configs['data']['max_mix_num'] 159 | sampling_rate = configs['data']['sampling_rate'] 160 | lower_db = configs['data']['loudness_norm']['lower_db'] 161 | higher_db = configs['data']['loudness_norm']['higher_db'] 162 | 163 | # Configuration of the separation model 164 | query_net = configs['model']['query_net'] 165 | model_type = configs['model']['model_type'] 166 | input_channels = configs['model']['input_channels'] 167 | output_channels = configs['model']['output_channels'] 168 | condition_size = configs['model']['condition_size'] 169 | use_text_ratio = configs['model']['use_text_ratio'] 170 | 171 | # Configuration of the trainer 172 | num_nodes = configs['train']['num_nodes'] 173 | batch_size = configs['train']['batch_size_per_device'] 174 | sync_batchnorm = configs['train']['sync_batchnorm'] 175 | num_workers = configs['train']['num_workers'] 176 | loss_type = configs['train']['loss_type'] 177 | optimizer_type = configs["train"]["optimizer"]["optimizer_type"] 178 | learning_rate = float(configs['train']["optimizer"]['learning_rate']) 179 | lr_lambda_type = configs['train']["optimizer"]['lr_lambda_type'] 180 | warm_up_steps = configs['train']["optimizer"]['warm_up_steps'] 181 | reduce_lr_steps = configs['train']["optimizer"]['reduce_lr_steps'] 182 | save_step_frequency = configs['train']['save_step_frequency'] 183 | resume_checkpoint_path = args.resume_checkpoint_path 184 | if resume_checkpoint_path == "": 185 | resume_checkpoint_path = None 186 | else: 187 | logging.info(f'Finetuning AudioSep with checkpoint [{resume_checkpoint_path}]') 188 | 189 | # Get directories and paths 190 | checkpoints_dir, logs_dir, tf_logs_dir, statistics_path = get_dirs( 191 | workspace, filename, config_yaml, devices_num, 192 | ) 193 | 194 | logging.info(configs) 195 | 196 | # data module 197 | data_module = get_data_module( 198 | config_yaml=config_yaml, 199 | batch_size=batch_size, 200 | 
num_workers=num_workers, 201 | ) 202 | 203 | # model 204 | Model = get_model_class(model_type=model_type) 205 | 206 | ss_model = Model( 207 | input_channels=input_channels, 208 | output_channels=output_channels, 209 | condition_size=condition_size, 210 | ) 211 | 212 | # loss function 213 | loss_function = get_loss_function(loss_type) 214 | 215 | segment_mixer = SegmentMixer( 216 | max_mix_num=max_mix_num, 217 | lower_db=lower_db, 218 | higher_db=higher_db 219 | ) 220 | 221 | 222 | if query_net == 'CLAP': 223 | query_encoder = CLAP_Encoder() 224 | else: 225 | raise NotImplementedError 226 | 227 | lr_lambda_func = get_lr_lambda( 228 | lr_lambda_type=lr_lambda_type, 229 | warm_up_steps=warm_up_steps, 230 | reduce_lr_steps=reduce_lr_steps, 231 | ) 232 | 233 | # pytorch-lightning model 234 | pl_model = AudioSep( 235 | ss_model=ss_model, 236 | waveform_mixer=segment_mixer, 237 | query_encoder=query_encoder, 238 | loss_function=loss_function, 239 | optimizer_type=optimizer_type, 240 | learning_rate=learning_rate, 241 | lr_lambda_func=lr_lambda_func, 242 | use_text_ratio=use_text_ratio 243 | ) 244 | 245 | checkpoint_every_n_steps = CheckpointEveryNSteps( 246 | checkpoints_dir=checkpoints_dir, 247 | save_step_frequency=save_step_frequency, 248 | ) 249 | 250 | summary_writer = SummaryWriter(log_dir=tf_logs_dir) 251 | 252 | callbacks = [checkpoint_every_n_steps] 253 | 254 | trainer = pl.Trainer( 255 | accelerator='auto', 256 | devices='auto', 257 | strategy='ddp_find_unused_parameters_true', 258 | num_nodes=num_nodes, 259 | precision="32-true", 260 | logger=None, 261 | callbacks=callbacks, 262 | fast_dev_run=False, 263 | max_epochs=-1, 264 | log_every_n_steps=50, 265 | use_distributed_sampler=True, 266 | sync_batchnorm=sync_batchnorm, 267 | num_sanity_val_steps=2, 268 | enable_checkpointing=False, 269 | enable_progress_bar=True, 270 | enable_model_summary=True, 271 | ) 272 | 273 | # Fit, evaluate, and save checkpoints. 274 | trainer.fit( 275 | model=pl_model, 276 | train_dataloaders=None, 277 | val_dataloaders=None, 278 | datamodule=data_module, 279 | ckpt_path=resume_checkpoint_path, 280 | ) 281 | 282 | 283 | if __name__ == "__main__": 284 | 285 | parser = argparse.ArgumentParser() 286 | parser.add_argument( 287 | "--workspace", type=str, required=True, help="Directory of workspace." 
288 |     )
289 |     parser.add_argument(
290 |         "--config_yaml",
291 |         type=str,
292 |         required=True,
293 |         help="Path of config file for training.",
294 |     )
295 | 
296 |     parser.add_argument(
297 |         "--resume_checkpoint_path",
298 |         type=str,
299 |         required=False,
300 |         default='',
301 |         help="Path of pretrained checkpoint for finetuning.",
302 |     )
303 | 
304 |     args = parser.parse_args()
305 |     args.filename = pathlib.Path(__file__).stem
306 | 
307 |     train(args)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import json
4 | import logging
5 | import librosa
6 | import pickle
7 | from typing import Dict
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import yaml
12 | from models.audiosep import AudioSep, get_model_class
13 | 
14 | 
15 | def ignore_warnings():
16 |     import warnings
17 |     # Ignore UserWarning from torch.meshgrid
18 |     warnings.filterwarnings('ignore', category=UserWarning, module='torch.functional')
19 | 
20 |     # Refined regex pattern to capture variations in the warning message
21 |     pattern = r"Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: \['lm_head\..*'\].*"
22 |     warnings.filterwarnings('ignore', message=pattern)
23 | 
24 | 
25 | 
26 | def create_logging(log_dir, filemode):
27 |     os.makedirs(log_dir, exist_ok=True)
28 |     i1 = 0
29 | 
30 |     while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
31 |         i1 += 1
32 | 
33 |     log_path = os.path.join(log_dir, "{:04d}.log".format(i1))
34 |     logging.basicConfig(
35 |         level=logging.DEBUG,
36 |         format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
37 |         datefmt="%a, %d %b %Y %H:%M:%S",
38 |         filename=log_path,
39 |         filemode=filemode,
40 |     )
41 | 
42 |     # Print to console
43 |     console = logging.StreamHandler()
44 |     console.setLevel(logging.INFO)
45 |     formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
46 |     console.setFormatter(formatter)
47 |     logging.getLogger("").addHandler(console)
48 | 
49 |     return logging
50 | 
51 | 
52 | def float32_to_int16(x: np.ndarray) -> np.ndarray:
53 |     x = np.clip(x, a_min=-1, a_max=1)
54 |     return (x * 32767.0).astype(np.int16)
55 | 
56 | 
57 | def int16_to_float32(x: np.ndarray) -> np.ndarray:
58 |     return (x / 32767.0).astype(np.float32)
59 | 
60 | 
61 | def parse_yaml(config_yaml: str) -> Dict:
62 |     r"""Parse yaml file.
63 | 
64 |     Args:
65 |         config_yaml (str): config yaml path
66 | 
67 |     Returns:
68 |         yaml_dict (Dict): parsed yaml file
69 |     """
70 | 
71 |     with open(config_yaml, "r") as fr:
72 |         return yaml.load(fr, Loader=yaml.FullLoader)
73 | 
74 | 
75 | def get_audioset632_id_to_lb(ontology_path: str) -> Dict:
76 |     r"""Get AudioSet 632 classes ID to label mapping."""
77 | 
78 |     audioset632_id_to_lb = {}
79 | 
80 |     with open(ontology_path) as f:
81 |         data_list = json.load(f)
82 | 
83 |     for e in data_list:
84 |         audioset632_id_to_lb[e["id"]] = e["name"]
85 | 
86 |     return audioset632_id_to_lb
87 | 
88 | 
89 | def load_pretrained_panns(
90 |     model_type: str,
91 |     checkpoint_path: str,
92 |     freeze: bool
93 | ) -> nn.Module:
94 |     r"""Load pretrained audio neural networks (PANNs).
95 | 96 | Args: 97 | model_type: str, e.g., "Cnn14" 98 | checkpoint_path, str, e.g., "Cnn14_mAP=0.431.pth" 99 | freeze: bool 100 | 101 | Returns: 102 | model: nn.Module 103 | """ 104 | 105 | if model_type == "Cnn14": 106 | Model = Cnn14 107 | 108 | elif model_type == "Cnn14_DecisionLevelMax": 109 | Model = Cnn14_DecisionLevelMax 110 | 111 | else: 112 | raise NotImplementedError 113 | 114 | model = Model(sample_rate=32000, window_size=1024, hop_size=320, 115 | mel_bins=64, fmin=50, fmax=14000, classes_num=527) 116 | 117 | if checkpoint_path: 118 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 119 | model.load_state_dict(checkpoint["model"]) 120 | 121 | if freeze: 122 | for param in model.parameters(): 123 | param.requires_grad = False 124 | 125 | return model 126 | 127 | 128 | def energy(x): 129 | return torch.mean(x ** 2) 130 | 131 | 132 | def magnitude_to_db(x): 133 | eps = 1e-10 134 | return 20. * np.log10(max(x, eps)) 135 | 136 | 137 | def db_to_magnitude(x): 138 | return 10. ** (x / 20) 139 | 140 | 141 | def ids_to_hots(ids, classes_num, device): 142 | hots = torch.zeros(classes_num).to(device) 143 | for id in ids: 144 | hots[id] = 1 145 | return hots 146 | 147 | 148 | def calculate_sdr( 149 | ref: np.ndarray, 150 | est: np.ndarray, 151 | eps=1e-10 152 | ) -> float: 153 | r"""Calculate SDR between reference and estimation. 154 | 155 | Args: 156 | ref (np.ndarray), reference signal 157 | est (np.ndarray), estimated signal 158 | """ 159 | reference = ref 160 | noise = est - reference 161 | 162 | 163 | numerator = np.clip(a=np.mean(reference ** 2), a_min=eps, a_max=None) 164 | 165 | denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None) 166 | 167 | sdr = 10. * np.log10(numerator / denominator) 168 | 169 | return sdr 170 | 171 | 172 | def calculate_sisdr(ref, est): 173 | r"""Calculate SDR between reference and estimation. 
174 | 175 | Args: 176 | ref (np.ndarray), reference signal 177 | est (np.ndarray), estimated signal 178 | """ 179 | 180 | eps = np.finfo(ref.dtype).eps 181 | 182 | reference = ref.copy() 183 | estimate = est.copy() 184 | 185 | reference = reference.reshape(reference.size, 1) 186 | estimate = estimate.reshape(estimate.size, 1) 187 | 188 | Rss = np.dot(reference.T, reference) 189 | # get the scaling factor for clean sources 190 | a = (eps + np.dot(reference.T, estimate)) / (Rss + eps) 191 | 192 | e_true = a * reference 193 | e_res = estimate - e_true 194 | 195 | Sss = (e_true**2).sum() 196 | Snn = (e_res**2).sum() 197 | 198 | sisdr = 10 * np.log10((eps+ Sss)/(eps + Snn)) 199 | 200 | return sisdr 201 | 202 | 203 | class StatisticsContainer(object): 204 | def __init__(self, statistics_path): 205 | self.statistics_path = statistics_path 206 | 207 | self.backup_statistics_path = "{}_{}.pkl".format( 208 | os.path.splitext(self.statistics_path)[0], 209 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 210 | ) 211 | 212 | self.statistics_dict = {"balanced_train": [], "test": []} 213 | 214 | def append(self, steps, statistics, split, flush=True): 215 | statistics["steps"] = steps 216 | self.statistics_dict[split].append(statistics) 217 | 218 | if flush: 219 | self.flush() 220 | 221 | def flush(self): 222 | pickle.dump(self.statistics_dict, open(self.statistics_path, "wb")) 223 | pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb")) 224 | logging.info(" Dump statistics to {}".format(self.statistics_path)) 225 | logging.info(" Dump statistics to {}".format(self.backup_statistics_path)) 226 | 227 | 228 | def get_mean_sdr_from_dict(sdris_dict): 229 | mean_sdr = np.nanmean(list(sdris_dict.values())) 230 | return mean_sdr 231 | 232 | 233 | def remove_silence(audio: np.ndarray, sample_rate: int) -> np.ndarray: 234 | r"""Remove silent frames.""" 235 | window_size = int(sample_rate * 0.1) 236 | threshold = 0.02 237 | 238 | frames = librosa.util.frame(x=audio, frame_length=window_size, hop_length=window_size).T 239 | # shape: (frames_num, window_size) 240 | 241 | new_frames = get_active_frames(frames, threshold) 242 | # shape: (new_frames_num, window_size) 243 | 244 | new_audio = new_frames.flatten() 245 | # shape: (new_audio_samples,) 246 | 247 | return new_audio 248 | 249 | 250 | def get_active_frames(frames: np.ndarray, threshold: float) -> np.ndarray: 251 | r"""Get active frames.""" 252 | 253 | energy = np.max(np.abs(frames), axis=-1) 254 | # shape: (frames_num,) 255 | 256 | active_indexes = np.where(energy > threshold)[0] 257 | # shape: (new_frames_num,) 258 | 259 | new_frames = frames[active_indexes] 260 | # shape: (new_frames_num,) 261 | 262 | return new_frames 263 | 264 | 265 | def repeat_to_length(audio: np.ndarray, segment_samples: int) -> np.ndarray: 266 | r"""Repeat audio to length.""" 267 | 268 | repeats_num = (segment_samples // audio.shape[-1]) + 1 269 | audio = np.tile(audio, repeats_num)[0 : segment_samples] 270 | 271 | return audio 272 | 273 | def calculate_segmentwise_sdr(ref, est, hop_samples, return_sdr_list=False): 274 | min_len = min(ref.shape[-1], est.shape[-1]) 275 | pointer = 0 276 | sdrs = [] 277 | while pointer + hop_samples < min_len: 278 | sdr = calculate_sdr( 279 | ref=ref[:, pointer : pointer + hop_samples], 280 | est=est[:, pointer : pointer + hop_samples], 281 | ) 282 | sdrs.append(sdr) 283 | pointer += hop_samples 284 | 285 | sdr = np.nanmedian(sdrs) 286 | 287 | if return_sdr_list: 288 | return sdr, sdrs 289 | else: 290 | return sdr 291 | 292 | 
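# A minimal, illustrative sketch (not part of the original utils.py) of how the
# SDR helpers above can be exercised. The synthetic signals and the 1-second hop
# at 32 kHz are assumptions for the example, not values fixed by this repository.
def _demo_segmentwise_sdr() -> None:
    rng = np.random.default_rng(0)
    # 4-second reference and a noisy estimate, shape (channels, samples)
    ref = rng.standard_normal((1, 32000 * 4)).astype(np.float32)
    est = ref + 0.1 * rng.standard_normal((1, 32000 * 4)).astype(np.float32)
    # Median SDR over 1-second segments is more robust to silent regions
    # than a single global SDR.
    median_sdr, per_segment_sdrs = calculate_segmentwise_sdr(
        ref=ref, est=est, hop_samples=32000, return_sdr_list=True
    )
    print(f"median segment SDR: {median_sdr:.2f} dB over {len(per_segment_sdrs)} segments")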
293 | def loudness(data, input_loudness, target_loudness): 294 | """ Loudness normalize a signal. 295 | 296 | Normalize an input signal to a user loudness in dB LKFS. 297 | 298 | Params 299 | ------- 300 | data : torch.Tensor 301 | Input multichannel audio data. 302 | input_loudness : float 303 | Loudness of the input in dB LUFS. 304 | target_loudness : float 305 | Target loudness of the output in dB LUFS. 306 | 307 | Returns 308 | ------- 309 | output : torch.Tensor 310 | Loudness normalized output data. 311 | """ 312 | 313 | # calculate the gain needed to scale to the desired loudness level 314 | delta_loudness = target_loudness - input_loudness 315 | gain = torch.pow(10.0, delta_loudness / 20.0) 316 | 317 | output = gain * data 318 | 319 | # check for potentially clipped samples 320 | # if torch.max(torch.abs(output)) >= 1.0: 321 | # warnings.warn("Possible clipped samples in output.") 322 | 323 | return output 324 | 325 | 326 | def get_ss_model(config_yaml) -> nn.Module: 327 | r"""Load trained universal source separation model. 328 | 329 | Args: 330 | configs (Dict) 331 | checkpoint_path (str): path of the checkpoint to load 332 | device (str): e.g., "cpu" | "cuda" 333 | 334 | Returns: 335 | pl_model: pl.LightningModule 336 | """ 337 | configs = parse_yaml(config_yaml) 338 | 339 | ss_model_type = configs["model"]["model_type"] 340 | input_channels = configs["model"]["input_channels"] 341 | output_channels = configs["model"]["output_channels"] 342 | condition_size = configs["model"]["condition_size"] 343 | 344 | # Initialize separation model 345 | SsModel = get_model_class(model_type=ss_model_type) 346 | 347 | ss_model = SsModel( 348 | input_channels=input_channels, 349 | output_channels=output_channels, 350 | condition_size=condition_size, 351 | ) 352 | 353 | return ss_model 354 | 355 | 356 | def load_ss_model( 357 | configs: Dict, 358 | checkpoint_path: str, 359 | query_encoder: nn.Module 360 | ) -> nn.Module: 361 | r"""Load trained universal source separation model. 362 | 363 | Args: 364 | configs (Dict) 365 | checkpoint_path (str): path of the checkpoint to load 366 | device (str): e.g., "cpu" | "cuda" 367 | 368 | Returns: 369 | pl_model: pl.LightningModule 370 | """ 371 | 372 | ss_model_type = configs["model"]["model_type"] 373 | input_channels = configs["model"]["input_channels"] 374 | output_channels = configs["model"]["output_channels"] 375 | condition_size = configs["model"]["condition_size"] 376 | 377 | # Initialize separation model 378 | SsModel = get_model_class(model_type=ss_model_type) 379 | 380 | ss_model = SsModel( 381 | input_channels=input_channels, 382 | output_channels=output_channels, 383 | condition_size=condition_size, 384 | ) 385 | 386 | # Load PyTorch Lightning model 387 | pl_model = AudioSep.load_from_checkpoint( 388 | checkpoint_path=checkpoint_path, 389 | strict=False, 390 | ss_model=ss_model, 391 | waveform_mixer=None, 392 | query_encoder=query_encoder, 393 | loss_function=None, 394 | optimizer_type=None, 395 | learning_rate=None, 396 | lr_lambda_func=None, 397 | map_location=torch.device('cpu'), 398 | ) 399 | 400 | return pl_model 401 | 402 | 403 | def parse_yaml(config_yaml: str) -> Dict: 404 | r"""Parse yaml file. 405 | 406 | Args: 407 | config_yaml (str): config yaml path 408 | 409 | Returns: 410 | yaml_dict (Dict): parsed yaml file 411 | """ 412 | 413 | with open(config_yaml, "r") as fr: 414 | return yaml.load(fr, Loader=yaml.FullLoader) --------------------------------------------------------------------------------
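The warm-up lambdas in optimizers/lr_schedulers.py are designed to be passed to PyTorch's LambdaLR, as their docstrings hint. The sketch below shows that wiring end to end; it is illustrative only, and the stand-in model, optimizer choice, learning rate, and step counts are assumptions rather than values taken from config/audiosep_base.yaml.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

from optimizers.lr_schedulers import get_lr_lambda

# Stand-in module; in training this would be the AudioSep separation model.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# "linear_warm_up": the lr scale grows as step / warm_up_steps during warm-up,
# then decays by a factor of 0.9 every reduce_lr_steps steps.
lr_lambda_func = get_lr_lambda(
    lr_lambda_type="linear_warm_up",
    warm_up_steps=10000,
    reduce_lr_steps=1000000,
)
scheduler = LambdaLR(optimizer, lr_lambda_func)

for step in range(100):
    optimizer.step()   # model update would happen here
    scheduler.step()   # apply the warm-up / decay schedule
```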