├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── datasets ├── fma_pop_tracks.csv └── musiccaps-public-openai.csv ├── example ├── LICENSE_gudgud96 └── prompts │ ├── gpt4_quality.py │ └── gpt4_refine.py ├── fadtk ├── __init__.py ├── __main__.py ├── embeds.py ├── fad.py ├── fad_batch.py ├── model_loader.py ├── package.py ├── stats │ └── fma_pop.npz ├── test │ ├── __init__.py │ ├── __main__.py │ ├── samples │ │ ├── mg-1634.opus │ │ ├── mg-1648.opus │ │ ├── mg-1741.opus │ │ ├── mg-2344.opus │ │ ├── mg-2551.opus │ │ ├── mg-2759.opus │ │ ├── mg-284.opus │ │ ├── mg-483.opus │ │ ├── mg-66.opus │ │ ├── mg-911.opus │ │ ├── mg-974.opus │ │ ├── mlm-1619.opus │ │ ├── mlm-1698.opus │ │ ├── mlm-2940.opus │ │ ├── mlm-483.opus │ │ ├── mlm-974.opus │ │ ├── mubert-130.opus │ │ ├── mubert-1348.opus │ │ ├── mubert-1619.opus │ │ ├── mubert-2120.opus │ │ ├── mubert-2545.opus │ │ ├── mubert-2701.opus │ │ └── mubert-806.opus │ ├── samples_FAD_scores.csv │ └── test_cleanup.sh └── utils.py ├── pyproject.toml └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | .idea/ 153 | node_modules 154 | yarn_error.log 155 | crawler.iml 156 | tokens.toml 157 | /data/ 158 | .syncthing* 159 | *.pt 160 | .model-checkpoints 161 | fadtk/test/fad_scores 162 | fadtk/test/samples/embeddings 163 | fadtk/test/samples/convert 164 | fadtk/test/samples/stats 165 | fadtk/test/comparison.csv 166 | 167 | .DS_Store 168 | ._* 169 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Frechet Audio Distance Toolkit 2 | 3 | A simple and standardized library for Frechet Audio Distance (FAD) calculation. This library is published along with the paper [_Adapting Frechet Audio Distance for Generative Music Evaluation_](https://ieeexplore.ieee.org/document/10446663) ([arXiv](https://arxiv.org/abs/2311.01616)). The datasets associated with this paper and sample code tools used in the paper are also available under this repository. 
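As a quick preview of the workflow documented below, here is a minimal sketch of how the toolkit can be driven from Python. It mirrors what the `fadtk` command-line launcher does internally (see [\_\_main\_\_.py](fadtk/__main__.py)); the audio directory path is a placeholder, and `fma_pop` refers to the pre-computed baseline statistics bundled with the package.

```python
from pathlib import Path

from fadtk.fad import FrechetAudioDistance
from fadtk.fad_batch import cache_embedding_files
from fadtk.model_loader import get_all_models

# Pick an embedding model by its FADtk name (see the Supported Models table below)
models = {m.name: m for m in get_all_models()}
model = models['encodec-emb']

# Compute and cache embeddings for every audio file in the evaluation directory
eval_dir = Path('/path/to/evaluation/audio')  # placeholder path
cache_embedding_files(eval_dir, model, workers=8)

# Score the cached embeddings against the bundled FMA-Pop baseline statistics
fad = FrechetAudioDistance(model, load_model=False)
score = fad.score('fma_pop', eval_dir)
print(f'FAD ({model.name}) against fma_pop: {score}')
```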
4 | 5 | You can listen to audio samples of per-song FAD outliers on the online demo here: https://fadtk.hydev.org/ 6 | 7 | ## 0x00. Features 8 | 9 | * Easily and efficiently compute audio embeddings with various models. 10 | * Compute FAD∞ scores between two datasets for evaluation. 11 | * Use pre-computed statistics ("weights") to compute FAD∞ scores from existing baselines. 12 | * Compute per-song FAD to find outliers in the dataset. 13 | 14 | ### Supported Models 15 | 16 | | Model | Name in FADtk | Description | Creator | 17 | | --- | --- | --- | --- | 18 | | [CLAP](https://github.com/microsoft/CLAP) | `clap-2023` | Learning audio concepts from natural language supervision | Microsoft | 19 | | [CLAP](https://github.com/LAION-AI/CLAP) | `clap-laion-{audio/music}` | Contrastive Language-Audio Pretraining | LAION | 20 | | [Encodec](https://github.com/facebookresearch/encodec) | `encodec-emb` | State-of-the-art deep learning based audio codec | Facebook/Meta Research | 21 | | [MERT](https://huggingface.co/m-a-p/MERT-v1-95M) | `MERT-v1-95M-{layer}` | Acoustic Music Understanding Model with Large-Scale Self-supervised Training | m-a-p | 22 | | [VGGish](https://github.com/tensorflow/models/blob/master/research/audioset/vggish/README.md) | `vggish` | Audio feature classification embedding | Google | 23 | | [DAC](https://github.com/descriptinc/descript-audio-codec)* | `dac-44kHz` | High-Fidelity Audio Compression with Improved RVQGAN | Descript | 24 | | [CDPAM](https://github.com/pranaymanocha/PerceptualAudio)* | `cdpam-{acoustic/content}` | Contrastive learning-based Deep Perceptual Audio Metric | Pranay Manocha et al. | 25 | | [Wav2vec 2.0](https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/README.md) | `w2v2-{base/large}` | Wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations | Facebook/Meta Research | 26 | | [HuBERT](https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/README.md) | `hubert-{base/large}` | HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units | Facebook/Meta Research | 27 | | [WavLM](https://github.com/microsoft/unilm/tree/master/wavlm) | `wavlm-{base/base-plus/large}` | WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing | Microsoft | 28 | | [Whisper](https://github.com/openai/whisper) | `whisper-{tiny/base/small/medium/large}` | Robust Speech Recognition via Large-Scale Weak Supervision | OpenAI | 29 | 30 | > [!NOTE] 31 | > * The models marked with an asterisk (*) are not included in the default installation of FADtk. You need to install them separately. Check the [Installation](#0x01-installation) section for more details. 32 | 33 | ## 0x01. Installation 34 | 35 | To use the FAD toolkit, you must first install it. This library was created and tested on Python 3.12 on Linux, but it should work on Python >3.10 and on Windows and macOS as well. 36 | 37 | 1. Install torch: https://pytorch.org/ 38 | 2. `pip install fadtk` 39 | 40 | To ensure that the environment is set up correctly and everything works as intended, we recommend running our tests with the command `python -m fadtk.test` after installing. 41 | 42 | ### Optional Dependencies 43 | 44 | Optionally, you can install dependencies that add support for additional embedding models. They are: 45 | 46 | * CDPAM: `pip install cdpam` 47 | * DAC: `pip install descript-audio-codec==1.0.0` 48 | 49 | ## 0x02.
Command Line Usage 50 | 51 | ```sh 52 | # Evaluation 53 | fadtk <model> <baseline> <eval> [csv] [--inf/--indiv] 54 | 55 | # Compute embeddings 56 | fadtk.embeds -m <models ...> -d <datasets ...> 57 | ``` 58 | #### Example 1: Computing FAD-inf scores on the FMA-Pop baseline 59 | 60 | ```sh 61 | # Compute FAD-inf between the baseline and evaluation datasets on two different models 62 | fadtk clap-laion-audio fma_pop /path/to/evaluation/audio --inf 63 | fadtk encodec-emb fma_pop /path/to/evaluation/audio --inf 64 | ``` 65 | 66 | #### Example 2: Compute individual FAD scores for each song 67 | 68 | ```sh 69 | fadtk encodec-emb fma_pop /path/to/evaluation/audio scores.csv --indiv 70 | ``` 71 | 72 | #### Example 3: Compute FAD scores with your own baseline 73 | 74 | First, create two directories, one for the baseline and one for the evaluation, and place *only* the audio files in them. Then, run the following commands: 75 | 76 | ```sh 77 | # Compute FAD between the baseline and evaluation datasets 78 | fadtk clap-laion-audio /path/to/baseline/audio /path/to/evaluation/audio 79 | ``` 80 | 81 | #### Example 4: Just compute embeddings 82 | 83 | If you only want to compute embeddings with a list of specific models for a list of datasets, you can do that using the command line. 84 | 85 | ```sh 86 | fadtk.embeds -m Model1 Model2 -d /dataset1 /dataset2 87 | ``` 88 | 89 | ## 0x03. Best Practices 90 | 91 | When using the FAD toolkit to compute FAD scores, it's essential to consider the following best practices to ensure accuracy and relevance in the reported findings. 92 | 93 | 1. **Choose a Meaningful Reference Set**: Do not default to commonly used reference sets like MusicCaps without consideration. A reference set that aligns with the specific goal of the research should be chosen. For generative music, we recommend using the FMA-Pop subset as proposed in our paper. 94 | 2. **Select an Appropriate Embedding**: The choice of embedding can heavily influence the scoring. For instance, VGGish is optimized for classification, and it might not be the most suitable if the research objective is to measure aspects like quality. 95 | 3. **Provide Comprehensive Reporting**: Ensure that all test statistics are included in the report: 96 | * The chosen reference set. 97 | * The selected embedding. 98 | * The number of samples and their duration in both the reference and test sets. 99 | 100 | This level of transparency ensures that readers understand the context and potential variability of the reported FAD scores. 101 | 4. **Benchmark Against the State-of-the-Art**: When making comparisons, researchers should ideally assess the state-of-the-art models under the same setup. Without a consistent setup, the FAD comparison might lose its significance. 102 | 5. **Interpret FAD Scores Contextually**: Per-sample FAD scores should be calculated. Listening to the per-sample outliers will provide a hands-on understanding of what the current setup is capturing, and what "low" and "high" FAD scores signify in the context of the study. 103 | 104 | By adhering to these best practices, you can ensure that your use of our FAD toolkit is both methodologically sound and contextually relevant. 105 | 106 | 107 | ## 0x04.
Programmatic Usage 108 | 109 | ### Doing the above in Python 110 | 111 | If you want to know how to perform the above command-line processes in Python, you can check out how our launchers are implemented ([\_\_main\_\_.py](fadtk/__main__.py) and [embeds.py](fadtk/embeds.py)). 112 | 113 | ### Adding New Embeddings 114 | 115 | To add a new embedding, the only file you would need to modify is [model_loader.py](fadtk/model_loader.py). You must create a new class that inherits from the `ModelLoader` class and implement the constructor, the `load_model` method, and the `_get_embedding` method. You can start with the template below: 116 | 117 | ```python 118 | class YourModel(ModelLoader): 119 |     """ 120 |     Add a short description of your model here. 121 |     """ 122 |     def __init__(self): 123 |         # Define your sample rate and number of features here. Audio will automatically be resampled to this sample rate. 124 |         super().__init__("Model name including variant", num_features=128, sr=16000) 125 |         # Add any other variables you need here 126 | 127 |     def load_model(self): 128 |         # Load your model here 129 |         pass 130 | 131 |     def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 132 |         # Calculate the embeddings using your model 133 |         return np.zeros((1, self.num_features)) 134 | 135 |     def load_wav(self, wav_file: Path): 136 |         # Optionally, you can override this method to load the wav file in a different way. The input wav_file is already at the sample rate specified in the constructor. 137 |         return super().load_wav(wav_file) 138 | ``` 139 | 140 | ## 0x05. Published Data and Code 141 | 142 | We also include some sample code and data from the paper in this repo. 143 | 144 | ### Refined Datasets 145 | 146 | [musiccaps-public-openai.csv](datasets/musiccaps-public-openai.csv): This file contains the original MusicCaps song IDs and captions along with GPT4 labels for their quality and the GPT4-refined prompts used for music generation. 147 | 148 | [fma_pop_tracks.csv](datasets/fma_pop_tracks.csv): This file contains the 4839 song IDs and metadata for the FMA-Pop subset we proposed in our paper. After downloading the Free Music Archive dataset from the [original source](https://github.com/mdeff/fma), you can easily locate the audio files for this FMA-Pop subset using these song IDs. 149 | 150 | ### Sample Code 151 | 152 | The method we used to create GPT4 one-shot prompts for generating the refined MusicCaps prompts and for classifying quality from the MusicCaps captions can be found in [example/prompts](example/prompts). 153 | 154 | ## 0x06. Citation, Acknowledgments and Licenses 155 | 156 | The code in this toolkit is licensed under the [MIT License](./LICENSE). Please cite our work if this repository helped you in your project. 157 | 158 | ```latex 159 | @inproceedings{fadtk, 160 | title = {Adapting Frechet Audio Distance for Generative Music Evaluation}, 161 | author = {Azalea Gui and Hannes Gamper and Sebastian Braun and Dimitra Emmanouilidou}, 162 | booktitle = {Proc. IEEE ICASSP 2024}, 163 | year = {2024}, 164 | url = {https://arxiv.org/abs/2311.01616}, 165 | } 166 | ``` 167 | 168 | Please also cite the FMA (Free Music Archive) dataset if you used FMA-Pop as your FAD scoring baseline.
169 | 170 | ```latex 171 | @inproceedings{fma_dataset, 172 | title = {{FMA}: A Dataset for Music Analysis}, 173 | author = {Defferrard, Micha\"el and Benzi, Kirell and Vandergheynst, Pierre and Bresson, Xavier}, 174 | booktitle = {18th International Society for Music Information Retrieval Conference (ISMIR)}, 175 | year = {2017}, 176 | archiveprefix = {arXiv}, 177 | eprint = {1612.01840}, 178 | url = {https://arxiv.org/abs/1612.01840}, 179 | } 180 | ``` 181 | 182 | You may also refer to our work on measuring music emotion and mitigating emotion bias using this toolkit. 183 | 184 | ```latex 185 | @article{emotionbias_fad, 186 | title = {Rethinking Emotion Bias in Music via Frechet Audio Distance}, 187 | author = {Li, Yuanchao and Gui, Azalea and Emmanouilidou, Dimitra and Gamper, Hannes}, 188 | journal={arXiv preprint arXiv:2409.15545}, 189 | year={2024} 190 | } 191 | ``` 192 | 193 | ### Special Thanks 194 | 195 | **Immense gratitude to the foundational repository [gudgud96/frechet-audio-distance](https://github.com/gudgud96/frechet-audio-distance) - "A lightweight library for Frechet Audio Distance calculation"**. Much of our project has been adapted and enhanced from gudgud96's contributions. In honor of this work, we've retained the [original MIT license](example/LICENSE_gudgud96). 196 | 197 | * Encodec from Facebook: [facebookresearch/encodec](https://github.com/facebookresearch/encodec/) 198 | * CLAP: [microsoft/CLAP](https://github.com/microsoft/CLAP) 199 | * CLAP from LAION: [LAION-AI/CLAP](https://github.com/LAION-AI/CLAP) 200 | * MERT from M-A-P: [m-a-p/MERT](https://huggingface.co/m-a-p/MERT-v1-95M) 201 | * Wav2vec 2.0: [facebookresearch/wav2vec 2.0](https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/README.md) 202 | * HuBERT: [facebookresearch/HuBERT](https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/README.md) 203 | * WavLM: [microsoft/WavLM](https://github.com/microsoft/unilm/tree/master/wavlm) 204 | * Whisper: [OpenAI/Whisper](https://github.com/openai/whisper) 205 | * VGGish in PyTorch: [harritaylor/torchvggish](https://github.com/harritaylor/torchvggish) 206 | * Free Music Archive: [mdeff/fma](https://github.com/mdeff/fma) 207 | * Frechet Inception Distance implementation: [mseitzer/pytorch-fid](https://github.com/mseitzer/pytorch-fid) 208 | * Frechet Audio Distance paper: [arxiv/1812.08466](https://arxiv.org/abs/1812.08466) 209 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /example/LICENSE_gudgud96: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Hao Hao Tan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /example/prompts/gpt4_quality.py: -------------------------------------------------------------------------------- 1 | # Define prompts 2 | prompt_music = """ 3 | Musicality or "musical quality" is defined as the artistic and aesthetic content of the music (e.g. whether the melody and dynamics are expressed/performed well, or whether the voice is clear), while "acoustic quality" defines the quality of the sound recording (e.g. whether the sound is clean from recording noise). 4 | 5 | Does the following music comment describe the music as high or low in musicality or acoustic quality? Please answer each with "High", "Medium", "Low", or "Not mentioned". No explanation is needed. 6 | 7 | {s} 8 | """.strip() 9 | 10 | example_comment = "A female Arabic singer sings this beautiful melody with backup singers in vocal harmony. The song is medium tempo with a string section, Arabic percussion instruments, tambourine percussion, steady drum rhythm, groovy bass line and keyboard accompaniment. The song is romantic and celebratory in nature. The audio quality is very poor." 11 | 12 | example_response = """ 13 | Musicality: High 14 | Acoustic: Low 15 | """.strip() 16 | 17 | def create_prompt(comment: str): 18 | return [ 19 | {'role': 'system', 'content': 'You are a professional musician asked to review music-related comment.'}, 20 | # Give example 21 | {'role': 'user', 'content': prompt_music.replace('{s}', example_comment)}, 22 | {'role': 'assistant', 'content': example_response}, 23 | # Ask for response 24 | {'role': 'user', 'content': comment}, 25 | ] -------------------------------------------------------------------------------- /example/prompts/gpt4_refine.py: -------------------------------------------------------------------------------- 1 | # Define prompts 2 | prompt_music = """ 3 | Please write an short and objective one-sentence description based on the below musical commentary. Please use simple words and do not mention the quality aspects of the music. 4 | 5 | {s} 6 | """.strip() 7 | 8 | def create_prompt(comment: str): 9 | return [ 10 | {'role': 'system', 'content': 'You are a professional musician asked to review music-related comment.'}, 11 | # Give example 12 | {'role': 'user', 'content': prompt_music.replace('{s}', "This song features a rubber instrument being played. The strumming is fast. The melody is played on one fretted string and other open strings. The melody is played on the lower octave and is later repeated on the higher octave. This song can be played at a folk party. 
This song has low recording quality, as if recorded from a mobile phone.")}, 13 | {'role': 'assistant', 'content': "Folk tune played on a rubber instrument with quick strumming and a two-octave melody"}, 14 | # Ask for response 15 | {'role': 'user', 'content': prompt_music.replace('{s}', comment)} 16 | ] 17 | -------------------------------------------------------------------------------- /fadtk/__init__.py: -------------------------------------------------------------------------------- 1 | from .fad import * 2 | from .fad_batch import * 3 | from .model_loader import * 4 | from .utils import * -------------------------------------------------------------------------------- /fadtk/__main__.py: -------------------------------------------------------------------------------- 1 | import time 2 | from argparse import ArgumentParser 3 | 4 | from .fad import FrechetAudioDistance, log 5 | from .model_loader import * 6 | from .fad_batch import cache_embedding_files 7 | 8 | 9 | def main(): 10 | """ 11 | Launcher for running FAD on two directories using a model. 12 | """ 13 | models = {m.name: m for m in get_all_models()} 14 | 15 | agupa = ArgumentParser() 16 | # Two positional arguments: model and two directories 17 | agupa.add_argument('model', type=str, choices=list(models.keys()), help="The embedding model to use") 18 | agupa.add_argument('baseline', type=str, help="The baseline dataset") 19 | agupa.add_argument('eval', type=str, help="The directory to evaluate against") 20 | agupa.add_argument('csv', type=str, nargs='?', 21 | help="The CSV file to append results to. " 22 | "If this argument is not supplied, single-value results will be printed to stdout, " 23 | "and for --indiv, the results will be saved to 'fad-individual-results.csv'") 24 | 25 | # Add optional arguments 26 | agupa.add_argument('-w', '--workers', type=int, default=8) 27 | agupa.add_argument('-s', '--sox-path', type=str, default='/usr/bin/sox') 28 | agupa.add_argument('--inf', action='store_true', help="Use FAD-inf extrapolation") 29 | agupa.add_argument('--indiv', action='store_true', 30 | help="Calculate FAD for individual songs and store the results in the given file") 31 | 32 | args = agupa.parse_args() 33 | model = models[args.model] 34 | 35 | baseline = args.baseline 36 | eval = args.eval 37 | 38 | # 1. Calculate embedding files for each dataset 39 | for d in [baseline, eval]: 40 | if Path(d).is_dir(): 41 | cache_embedding_files(d, model, workers=args.workers) 42 | 43 | # 2. Calculate FAD 44 | fad = FrechetAudioDistance(model, audio_load_worker=args.workers, load_model=False) 45 | if args.inf: 46 | assert Path(eval).is_dir(), "FAD-inf requires a directory as the evaluation dataset" 47 | score = fad.score_inf(baseline, list(Path(eval).glob('*.*'))) 48 | print("FAD-inf Information:", score) 49 | score, inf_r2 = score.score, score.r2 50 | elif args.indiv: 51 | assert Path(eval).is_dir(), "Individual FAD requires a directory as the evaluation dataset" 52 | csv_path = Path(args.csv or 'fad-individual-results.csv') 53 | fad.score_individual(baseline, eval, csv_path) 54 | log.info(f"Individual FAD scores saved to {csv_path}") 55 | exit(0) 56 | else: 57 | score = fad.score(baseline, eval) 58 | inf_r2 = None 59 | 60 | # 3. 
Print results 61 | log.info("FAD computed.") 62 | if args.csv: 63 | Path(args.csv).parent.mkdir(parents=True, exist_ok=True) 64 | if not Path(args.csv).is_file(): 65 | Path(args.csv).write_text('model,baseline,eval,score,inf_r2,time\n') 66 | with open(args.csv, 'a') as f: 67 | f.write(f'{model.name},{baseline},{eval},{score},{inf_r2},{time.time()}\n') 68 | log.info(f"FAD score appended to {args.csv}") 69 | 70 | log.info(f"The FAD {model.name} score between {baseline} and {eval} is: {score}") 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /fadtk/embeds.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from .model_loader import * 3 | from .fad_batch import cache_embedding_files 4 | 5 | def main(): 6 | """ 7 | Launcher for caching embeddings of directories using multiple models. 8 | """ 9 | models = {m.name: m for m in get_all_models()} 10 | 11 | agupa = ArgumentParser() 12 | 13 | # Accept multiple models and directories with distinct prefixes 14 | agupa.add_argument('-m', '--models', type=str, choices=list(models.keys()), nargs='+', required=True) 15 | agupa.add_argument('-d', '--dirs', type=str, nargs='+', required=True) 16 | 17 | # Add optional arguments 18 | agupa.add_argument('-w', '--workers', type=int, default=8) 19 | agupa.add_argument('-s', '--sox-path', type=str, default='/usr/bin/sox') 20 | 21 | args = agupa.parse_args() 22 | 23 | for model_name in args.models: 24 | model = models[model_name] 25 | for d in args.dirs: 26 | log.info(f"Caching embeddings for {d} using {model.name}") 27 | cache_embedding_files(d, model, workers=args.workers) 28 | 29 | 30 | if __name__ == "__main__": 31 | main() -------------------------------------------------------------------------------- /fadtk/fad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import tempfile 5 | import traceback 6 | from typing import NamedTuple, Union 7 | import numpy as np 8 | import torch 9 | import torchaudio 10 | from scipy import linalg 11 | from numpy.lib.scimath import sqrt as scisqrt 12 | from pathlib import Path 13 | from hypy_utils import write 14 | from hypy_utils.tqdm_utils import tq, tmap 15 | from hypy_utils.logging_utils import setup_logger 16 | 17 | from .model_loader import ModelLoader 18 | from .utils import * 19 | 20 | log = setup_logger() 21 | sox_path = os.environ.get('SOX_PATH', 'sox') 22 | ffmpeg_path = os.environ.get('FFMPEG_PATH', 'ffmpeg') 23 | torchaudio_backend = os.environ.get('TORCHAUDIO_BACKEND', 'soundfile') 24 | TORCHAUDIO_RESAMPLING = True 25 | 26 | if not(TORCHAUDIO_RESAMPLING): 27 | if not shutil.which(sox_path): 28 | log.error(f"Could not find SoX executable at {sox_path}, please install SoX and set the SOX_PATH environment variable.") 29 | exit(3) 30 | if not shutil.which(ffmpeg_path): 31 | log.error(f"Could not find ffmpeg executable at {ffmpeg_path}, please install ffmpeg and set the FFMPEG_PATH environment variable.") 32 | exit(3) 33 | 34 | 35 | class FADInfResults(NamedTuple): 36 | score: float 37 | slope: float 38 | r2: float 39 | points: list[tuple[int, float]] 40 | 41 | 42 | def calc_embd_statistics(embd_lst: np.ndarray) -> tuple[np.ndarray, np.ndarray]: 43 | """ 44 | Calculate the mean and covariance matrix of a list of embeddings. 
45 | """ 46 | assert embd_lst.shape[0] >= 2, (f"FAD requires at least two embedding window frames, you have {embd_lst.shape}." 47 | " (This probably means that your audio is too short)") 48 | return np.mean(embd_lst, axis=0), np.cov(embd_lst, rowvar=False) 49 | 50 | 51 | def calc_frechet_distance(mu1, cov1, mu2, cov2, eps=1e-6): 52 | """ 53 | Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py 54 | 55 | Numpy implementation of the Frechet Distance. 56 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 57 | and X_2 ~ N(mu_2, C_2) is 58 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 59 | Stable version by Dougal J. Sutherland. 60 | Params: 61 | -- mu1 : Numpy array containing the activations of a layer of the 62 | inception net (like returned by the function 'get_predictions') 63 | for generated samples. 64 | -- mu2 : The sample mean over activations, precalculated on an 65 | representative data set. 66 | -- cov1: The covariance matrix over activations for generated samples. 67 | -- cov2: The covariance matrix over activations, precalculated on an 68 | representative data set. 69 | Returns: 70 | -- : The Frechet Distance. 71 | """ 72 | mu1 = np.atleast_1d(mu1) 73 | mu2 = np.atleast_1d(mu2) 74 | 75 | cov1 = np.atleast_2d(cov1) 76 | cov2 = np.atleast_2d(cov2) 77 | 78 | assert mu1.shape == mu2.shape, \ 79 | f'Training and test mean vectors have different lengths ({mu1.shape} vs {mu2.shape})' 80 | assert cov1.shape == cov2.shape, \ 81 | f'Training and test covariances have different dimensions ({cov1.shape} vs {cov2.shape})' 82 | 83 | diff = mu1 - mu2 84 | 85 | # Product might be almost singular 86 | # NOTE: issues with sqrtm for newer scipy versions 87 | # using eigenvalue method as workaround 88 | covmean_sqrtm, _ = linalg.sqrtm(cov1.dot(cov2), disp=False) 89 | 90 | # eigenvalue method 91 | D, V = linalg.eig(cov1.dot(cov2)) 92 | covmean = (V * scisqrt(D)) @ linalg.inv(V) 93 | 94 | if not np.isfinite(covmean).all(): 95 | msg = ('fid calculation produces singular product; ' 96 | 'adding %s to diagonal of cov estimates') % eps 97 | log.info(msg) 98 | offset = np.eye(cov1.shape[0]) * eps 99 | covmean = linalg.sqrtm((cov1 + offset).dot(cov2 + offset)) 100 | 101 | # Numerical error might give slight imaginary component 102 | if np.iscomplexobj(covmean): 103 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 104 | m = np.max(np.abs(covmean.imag)) 105 | raise ValueError('Imaginary component {}'.format(m)) 106 | covmean = covmean.real 107 | 108 | tr_covmean = np.trace(covmean) 109 | tr_covmean_sqrtm = np.trace(covmean_sqrtm) 110 | if np.iscomplexobj(tr_covmean_sqrtm): 111 | if np.abs(tr_covmean_sqrtm.imag) < 1e-3: 112 | tr_covmean_sqrtm = tr_covmean_sqrtm.real 113 | 114 | if not(np.iscomplexobj(tr_covmean_sqrtm)): 115 | delt = np.abs(tr_covmean - tr_covmean_sqrtm) 116 | if delt > 1e-3: 117 | log.warning(f'Detected high error in sqrtm calculation: {delt}') 118 | 119 | return (diff.dot(diff) + np.trace(cov1) 120 | + np.trace(cov2) - 2 * tr_covmean) 121 | 122 | 123 | class FrechetAudioDistance: 124 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 125 | loaded = False 126 | 127 | def __init__(self, ml: ModelLoader, audio_load_worker=8, load_model=True): 128 | self.ml = ml 129 | self.audio_load_worker = audio_load_worker 130 | self.sox_formats = find_sox_formats(sox_path) 131 | 132 | if load_model: 133 | self.ml.load_model() 134 | self.loaded = True 135 | 136 | # Disable 
gradient calculation because we're not training 137 | torch.autograd.set_grad_enabled(False) 138 | 139 | def load_audio(self, f: Union[str, Path]): 140 | f = Path(f) 141 | 142 | # Create a directory for storing normalized audio files 143 | cache_dir = f.parent / "convert" / str(self.ml.sr) 144 | new = (cache_dir / f.name).with_suffix(".wav") 145 | 146 | if not new.exists(): 147 | cache_dir.mkdir(parents=True, exist_ok=True) 148 | if TORCHAUDIO_RESAMPLING: 149 | x, fsorig = torchaudio.load(str(f), backend=torchaudio_backend) 150 | x = torch.mean(x,0).unsqueeze(0) # convert to mono 151 | resampler = torchaudio.transforms.Resample( 152 | fsorig, 153 | self.ml.sr, 154 | lowpass_filter_width=64, 155 | rolloff=0.9475937167399596, 156 | resampling_method="sinc_interp_kaiser", 157 | beta=14.769656459379492, 158 | ) 159 | y = resampler(x) 160 | torchaudio.save(str(new), y, self.ml.sr, encoding="PCM_S", bits_per_sample=16) 161 | else: 162 | sox_args = ['-r', str(self.ml.sr), '-c', '1', '-b', '16'] 163 | 164 | # ffmpeg has bad resampling compared to SoX 165 | # SoX has bad format support compared to ffmpeg 166 | # If the file format is not supported by SoX, use ffmpeg to convert it to wav 167 | 168 | if f.suffix[1:] not in self.sox_formats: 169 | # Use ffmpeg for format conversion and then pipe to sox for resampling 170 | with tempfile.TemporaryDirectory() as tmp: 171 | tmp = Path(tmp) / 'temp.wav' 172 | 173 | # Open ffmpeg process for format conversion 174 | subprocess.run([ 175 | ffmpeg_path, 176 | "-hide_banner", "-loglevel", "error", 177 | "-i", f, tmp]) 178 | 179 | # Open sox process for resampling, taking input from ffmpeg's output 180 | subprocess.run([sox_path, tmp, *sox_args, new]) 181 | 182 | else: 183 | # Use sox for resampling 184 | subprocess.run([sox_path, f, *sox_args, new]) 185 | 186 | return self.ml.load_wav(new) 187 | 188 | def cache_embedding_file(self, audio_dir: Union[str, Path]): 189 | """ 190 | Compute embedding for an audio file and cache it to a file. 191 | """ 192 | cache = get_cache_embedding_path(self.ml.name, audio_dir) 193 | 194 | if cache.exists(): 195 | return 196 | 197 | # Load file, get embedding, save embedding 198 | wav_data = self.load_audio(audio_dir) 199 | embd = self.ml.get_embedding(wav_data) 200 | cache.parent.mkdir(parents=True, exist_ok=True) 201 | np.save(cache, embd) 202 | 203 | def read_embedding_file(self, audio_dir: Union[str, Path]): 204 | """ 205 | Read embedding from a cached file. 206 | """ 207 | cache = get_cache_embedding_path(self.ml.name, audio_dir) 208 | assert cache.exists(), f"Embedding file {cache} does not exist, please run cache_embedding_file first." 209 | return np.load(cache) 210 | 211 | def load_embeddings(self, dir: Union[str, Path], max_count: int = -1, concat: bool = True): 212 | """ 213 | Load embeddings for all audio files in a directory. 214 | """ 215 | files = list(Path(dir).glob("*.*")) 216 | log.info(f"Loading {len(files)} audio files from {dir}...") 217 | 218 | return self._load_embeddings(files, max_count=max_count, concat=concat) 219 | 220 | def _load_embeddings(self, files: list[Path], max_count: int = -1, concat: bool = True): 221 | """ 222 | Load embeddings for a list of audio files. 
223 | """ 224 | if len(files) == 0: 225 | raise ValueError("No files provided") 226 | 227 | # Load embeddings 228 | if max_count == -1: 229 | embd_lst = tmap(self.read_embedding_file, files, desc="Loading audio files...", max_workers=self.audio_load_worker) 230 | else: 231 | total_len = 0 232 | embd_lst = [] 233 | for f in tq(files, "Loading files"): 234 | embd_lst.append(self.read_embedding_file(f)) 235 | total_len += embd_lst[-1].shape[0] 236 | if total_len > max_count: 237 | break 238 | 239 | # Concatenate embeddings if needed 240 | if concat: 241 | return np.concatenate(embd_lst, axis=0) 242 | else: 243 | return embd_lst, files 244 | 245 | def load_stats(self, path: PathLike): 246 | """ 247 | Load embedding statistics from a directory. 248 | """ 249 | if isinstance(path, str): 250 | # Check if it's a pre-computed statistic file 251 | bp = Path(__file__).parent / "stats" 252 | stats = bp / (path.lower() + ".npz") 253 | print(stats) 254 | if stats.exists(): 255 | path = stats 256 | 257 | path = Path(path) 258 | 259 | # Check if path is a file 260 | if path.is_file(): 261 | # Load it as a npz 262 | log.info(f"Loading embedding statistics from {path}...") 263 | with np.load(path) as data: 264 | if f'{self.ml.name}.mu' not in data or f'{self.ml.name}.cov' not in data: 265 | raise ValueError(f"FAD statistics file {path} doesn't contain data for model {self.ml.name}") 266 | return data[f'{self.ml.name}.mu'], data[f'{self.ml.name}.cov'] 267 | 268 | cache_dir = path / "stats" / self.ml.name 269 | emb_dir = path / "embeddings" / self.ml.name 270 | if cache_dir.exists(): 271 | log.info(f"Embedding statistics is already cached for {path}, loading...") 272 | mu = np.load(cache_dir / "mu.npy") 273 | cov = np.load(cache_dir / "cov.npy") 274 | return mu, cov 275 | 276 | if not path.is_dir(): 277 | log.error(f"The dataset you want to use ({path}) is not a directory nor a file.") 278 | exit(1) 279 | 280 | log.info(f"Loading embedding files from {path}...") 281 | 282 | mu, cov = calculate_embd_statistics_online(list(emb_dir.glob("*.npy"))) 283 | log.info("> Embeddings statistics calculated.") 284 | 285 | # Save statistics 286 | cache_dir.mkdir(parents=True, exist_ok=True) 287 | np.save(cache_dir / "mu.npy", mu) 288 | np.save(cache_dir / "cov.npy", cov) 289 | 290 | return mu, cov 291 | 292 | def score(self, baseline: PathLike, eval: PathLike): 293 | """ 294 | Calculate a single FAD score between a background and an eval set. 295 | 296 | :param baseline: Baseline matrix or directory containing baseline audio files 297 | :param eval: Eval matrix or directory containing eval audio files 298 | """ 299 | mu_bg, cov_bg = self.load_stats(baseline) 300 | mu_eval, cov_eval = self.load_stats(eval) 301 | 302 | return calc_frechet_distance(mu_bg, cov_bg, mu_eval, cov_eval) 303 | 304 | def score_inf(self, baseline: PathLike, eval_files: list[Path], steps: int = 25, min_n = 500, raw: bool = False): 305 | """ 306 | Calculate FAD for different n (number of samples) and compute FAD-inf. 307 | 308 | :param baseline: Baseline matrix or directory containing baseline audio files 309 | :param eval_files: list of eval audio files 310 | :param steps: number of steps to use 311 | :param min_n: minimum n to use 312 | :param raw: return raw results in addition to FAD-inf 313 | """ 314 | log.info(f"Calculating FAD-inf for {self.ml.name}...") 315 | # 1. 
Load background embeddings 316 | mu_base, cov_base = self.load_stats(baseline) 317 | # If all of the embedding files end in .npy, we can load them directly 318 | if all([f.suffix == '.npy' for f in eval_files]): 319 | embeds = [np.load(f) for f in eval_files] 320 | embeds = np.concatenate(embeds, axis=0) 321 | else: 322 | embeds = self._load_embeddings(eval_files, concat=True) 323 | 324 | # Calculate maximum n 325 | max_n = len(embeds) 326 | 327 | # Generate list of ns to use 328 | ns = [int(n) for n in np.linspace(min_n, max_n, steps)] 329 | 330 | results = [] 331 | for n in tq(ns, desc="Calculating FAD-inf"): 332 | # Select n feature frames randomly (with replacement) 333 | indices = np.random.choice(embeds.shape[0], size=n, replace=True) 334 | embds_eval = embeds[indices] 335 | 336 | mu_eval, cov_eval = calc_embd_statistics(embds_eval) 337 | fad_score = calc_frechet_distance(mu_base, cov_base, mu_eval, cov_eval) 338 | 339 | # Add to results 340 | results.append([n, fad_score]) 341 | 342 | # Compute FAD-inf based on linear regression of 1/n 343 | ys = np.array(results) 344 | xs = 1 / np.array(ns) 345 | slope, intercept = np.polyfit(xs, ys[:, 1], 1) 346 | 347 | # Compute R^2 348 | r2 = 1 - np.sum((ys[:, 1] - (slope * xs + intercept)) ** 2) / np.sum((ys[:, 1] - np.mean(ys[:, 1])) ** 2) 349 | 350 | # Since intercept is the FAD-inf, we can just return it 351 | return FADInfResults(score=intercept, slope=slope, r2=r2, points=results) 352 | 353 | def score_individual(self, baseline: PathLike, eval_dir: PathLike, csv_name: Union[Path, str]) -> Path: 354 | """ 355 | Calculate the FAD score for each individual file in eval_dir and write the results to a csv file. 356 | 357 | :param baseline: Baseline matrix or directory containing baseline audio files 358 | :param eval_dir: Directory containing eval audio files 359 | :param csv_name: Name of the csv file to write the results to 360 | :return: Path to the csv file 361 | """ 362 | csv = Path(csv_name) 363 | if isinstance(csv_name, str): 364 | csv = Path('data') / f'fad-individual' / self.ml.name / csv_name 365 | if csv.exists(): 366 | log.info(f"CSV file {csv} already exists, exiting...") 367 | return csv 368 | 369 | # 1. Load background embeddings 370 | mu, cov = self.load_stats(baseline) 371 | 372 | # 2. Define helper function for calculating z score 373 | def _find_z_helper(f): 374 | try: 375 | # Calculate FAD for individual songs 376 | embd = self.read_embedding_file(f) 377 | mu_eval, cov_eval = calc_embd_statistics(embd) 378 | return calc_frechet_distance(mu, cov, mu_eval, cov_eval) 379 | 380 | except Exception as e: 381 | traceback.print_exc() 382 | log.error(f"An error occurred calculating individual FAD using model {self.ml.name} on file {f}") 383 | log.error(e) 384 | 385 | # 3. Calculate z score for each eval file 386 | _files = list(Path(eval_dir).glob("*.*")) 387 | scores = tmap(_find_z_helper, _files, desc=f"Calculating scores", max_workers=self.audio_load_worker) 388 | 389 | # 4. 
Write the sorted z scores to csv 390 | pairs = list(zip(_files, scores)) 391 | pairs = [p for p in pairs if p[1] is not None] 392 | pairs = sorted(pairs, key=lambda x: np.abs(x[1])) 393 | write(csv, "\n".join([",".join([str(x).replace(',', '_') for x in row]) for row in pairs])) 394 | 395 | return csv 396 | -------------------------------------------------------------------------------- /fadtk/fad_batch.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from pathlib import Path 4 | from typing import Callable, Union 5 | import numpy as np 6 | 7 | import torch 8 | 9 | from .fad import log, FrechetAudioDistance 10 | from .model_loader import ModelLoader 11 | from .utils import get_cache_embedding_path 12 | 13 | 14 | 15 | def _cache_embedding_batch(args): 16 | fs: list[Path] 17 | ml: ModelLoader 18 | fs, ml, kwargs = args 19 | fad = FrechetAudioDistance(ml, **kwargs) 20 | for f in fs: 21 | log.info(f"Loading {f} using {ml.name}") 22 | fad.cache_embedding_file(f) 23 | 24 | 25 | def cache_embedding_files(files: Union[list[Path], str, Path], ml: ModelLoader, workers: int = 8, **kwargs): 26 | """ 27 | Get embeddings for all audio files in a directory. 28 | 29 | :param ml_fn: A function that returns a ModelLoader instance. 30 | """ 31 | if isinstance(files, (str, Path)): 32 | files = list(Path(files).glob('*.*')) 33 | 34 | # Filter out files that already have embeddings 35 | files = [f for f in files if not get_cache_embedding_path(ml.name, f).exists()] 36 | if len(files) == 0: 37 | log.info("All files already have embeddings, skipping.") 38 | return 39 | 40 | log.info(f"[Frechet Audio Distance] Loading {len(files)} audio files...") 41 | 42 | # Split files into batches 43 | batches = list(np.array_split(files, workers)) 44 | 45 | # Cache embeddings in parallel 46 | multiprocessing.set_start_method('spawn', force=True) 47 | with torch.multiprocessing.Pool(workers) as pool: 48 | pool.map(_cache_embedding_batch, [(b, ml, kwargs) for b in batches]) -------------------------------------------------------------------------------- /fadtk/model_loader.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import logging 3 | import math 4 | from typing import Literal 5 | import numpy as np 6 | import soundfile 7 | 8 | import torch 9 | import librosa 10 | from torch import nn 11 | from pathlib import Path 12 | from hypy_utils.downloader import download_file 13 | import torch.nn.functional as F 14 | import importlib.util 15 | import importlib.metadata 16 | 17 | 18 | log = logging.getLogger(__name__) 19 | 20 | 21 | class ModelLoader(ABC): 22 | """ 23 | Abstract class for loading a model and getting embeddings from it. The model should be loaded in the `load_model` method. 24 | """ 25 | def __init__(self, name: str, num_features: int, sr: int, min_len: int = -1): 26 | """ 27 | Args: 28 | name (str): A unique identifier for the model. 29 | num_features (int): Number of features in the output embedding (dimensionality). 30 | sr (int): Sample rate of the audio. 31 | min_len (int, optional): Enforce a minimal length for the audio in seconds. Defaults to -1 (no minimum). 
32 | """ 33 | self.model = None 34 | self.sr = sr 35 | self.num_features = num_features 36 | self.name = name 37 | self.min_len = min_len 38 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 39 | 40 | def get_embedding(self, audio: np.ndarray): 41 | embd = self._get_embedding(audio) 42 | if self.device == torch.device('cuda'): 43 | embd = embd.cpu() 44 | embd = embd.detach().numpy() 45 | 46 | # If embedding is float32, convert to float16 to be space-efficient 47 | if embd.dtype == np.float32: 48 | embd = embd.astype(np.float16) 49 | 50 | return embd 51 | 52 | @abstractmethod 53 | def load_model(self): 54 | pass 55 | 56 | @abstractmethod 57 | def _get_embedding(self, audio: np.ndarray): 58 | """ 59 | Returns the embedding of the audio file. The resulting vector should be of shape (n_frames, n_features). 60 | """ 61 | pass 62 | 63 | def load_wav(self, wav_file: Path): 64 | wav_data, _ = soundfile.read(wav_file, dtype='int16') 65 | wav_data = wav_data / 32768.0 # Convert to [-1.0, +1.0] 66 | 67 | # Enforce minimum length 68 | wav_data = self.enforce_min_len(wav_data) 69 | 70 | return wav_data 71 | 72 | def enforce_min_len(self, audio: np.ndarray) -> np.ndarray: 73 | """ 74 | Enforce a minimum length for the audio. If the audio is too short, output a warning and pad it with zeros. 75 | """ 76 | if self.min_len < 0: 77 | return audio 78 | if audio.shape[0] < self.min_len * self.sr: 79 | log.warning( 80 | f"Audio is too short for {self.name}.\n" 81 | f"The model requires a minimum length of {self.min_len}s, audio is {audio.shape[0] / self.sr:.2f}s.\n" 82 | f"Padding with zeros." 83 | ) 84 | audio = np.pad(audio, (0, int(np.ceil(self.min_len * self.sr - audio.shape[0])))) 85 | print() 86 | return audio 87 | 88 | 89 | class VGGishModel(ModelLoader): 90 | """ 91 | S. Hershey et al., "CNN Architectures for Large-Scale Audio Classification", ICASSP 2017 92 | """ 93 | def __init__(self, use_pca=False, use_activation=False): 94 | super().__init__("vggish", 128, 16000, min_len=1) 95 | self.use_pca = use_pca 96 | self.use_activation = use_activation 97 | 98 | def load_model(self): 99 | self.model = torch.hub.load('harritaylor/torchvggish', 'vggish') 100 | if not self.use_pca: 101 | self.model.postprocess = False 102 | if not self.use_activation: 103 | self.model.embeddings = nn.Sequential(*list(self.model.embeddings.children())[:-1]) 104 | self.model.eval() 105 | self.model.to(self.device) 106 | 107 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 108 | return self.model.forward(audio, self.sr) 109 | 110 | 111 | class EncodecEmbModel(ModelLoader): 112 | """ 113 | Encodec model from https://github.com/facebookresearch/encodec 114 | 115 | Thiss version uses the embedding outputs (continuous values of 128 features). 
116 | """ 117 | def __init__(self, variant: Literal['48k', '24k'] = '24k'): 118 | super().__init__('encodec-emb' if variant == '24k' else f"encodec-emb-{variant}", 128, 119 | sr=24000 if variant == '24k' else 48000) 120 | self.variant = variant 121 | 122 | def load_model(self): 123 | from encodec import EncodecModel 124 | if self.variant == '48k': 125 | self.model = EncodecModel.encodec_model_48khz() 126 | self.model.set_target_bandwidth(24) 127 | else: 128 | self.model = EncodecModel.encodec_model_24khz() 129 | self.model.set_target_bandwidth(12) 130 | self.model.to(self.device) 131 | 132 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 133 | segment_length = self.model.segment_length 134 | 135 | # The 24k model doesn't use segmenting 136 | if segment_length is None: 137 | return self._get_frame(audio) 138 | 139 | # The 48k model uses segmenting 140 | assert audio.dim() == 3 141 | _, channels, length = audio.shape 142 | assert channels > 0 and channels <= 2 143 | stride = segment_length 144 | 145 | encoded_frames: list[torch.Tensor] = [] 146 | for offset in range(0, length, stride): 147 | frame = audio[:, :, offset:offset + segment_length] 148 | encoded_frames.append(self._get_frame(frame)) 149 | 150 | # Concatenate 151 | encoded_frames = torch.cat(encoded_frames, dim=0) # [timeframes, 128] 152 | return encoded_frames 153 | 154 | def _get_frame(self, audio: np.ndarray) -> np.ndarray: 155 | with torch.no_grad(): 156 | length = audio.shape[-1] 157 | duration = length / self.sr 158 | assert self.model.segment is None or duration <= 1e-5 + self.model.segment, f"Audio is too long ({duration} > {self.model.segment})" 159 | 160 | emb = self.model.encoder(audio.to(self.device)) # [1, 128, timeframes] 161 | emb = emb[0] # [128, timeframes] 162 | emb = emb.transpose(0, 1) # [timeframes, 128] 163 | return emb 164 | 165 | def load_wav(self, wav_file: Path): 166 | import torchaudio 167 | from encodec.utils import convert_audio 168 | 169 | wav, sr = torchaudio.load(str(wav_file)) 170 | wav = convert_audio(wav, sr, self.sr, self.model.channels) 171 | 172 | # If it's longer than 3 minutes, cut it 173 | if wav.shape[1] > 3 * 60 * self.sr: 174 | wav = wav[:, :3 * 60 * self.sr] 175 | 176 | return wav.unsqueeze(0) 177 | 178 | def _decode_frame(self, emb: np.ndarray) -> np.ndarray: 179 | with torch.no_grad(): 180 | emb = torch.from_numpy(emb).float().to(self.device) # [timeframes, 128] 181 | emb = emb.transpose(0, 1) # [128, timeframes] 182 | emb = emb.unsqueeze(0) # [1, 128, timeframes] 183 | audio = self.model.decoder(emb) # [1, 1, timeframes] 184 | audio = audio[0, 0] # [timeframes] 185 | 186 | return audio.cpu().numpy() 187 | 188 | 189 | class DACModel(ModelLoader): 190 | """ 191 | DAC model from https://github.com/descriptinc/descript-audio-codec 192 | 193 | pip install descript-audio-codec 194 | """ 195 | def __init__(self): 196 | super().__init__("dac-44kHz", 1024, 44100) 197 | 198 | def load_model(self): 199 | from dac.utils import load_model 200 | self.model = load_model(tag='latest', model_type='44khz') 201 | self.model.eval() 202 | self.model.to(self.device) 203 | 204 | def _get_embedding(self, audio) -> np.ndarray: 205 | from audiotools import AudioSignal 206 | import time 207 | 208 | audio: AudioSignal 209 | 210 | # Set variables 211 | win_len = 5.0 212 | overlap_hop_ratio = 0.5 213 | 214 | # Fix overlap window so that it's divisible by 4 in # of samples 215 | win_len = ((win_len * self.sr) // 4) * 4 216 | win_len = win_len / self.sr 217 | hop_len = win_len * overlap_hop_ratio 218 | 
219 | stime = time.time() 220 | 221 | # Sanitize input 222 | audio.normalize(-16) 223 | audio.ensure_max_of_audio() 224 | 225 | nb, nac, nt = audio.audio_data.shape 226 | audio.audio_data = audio.audio_data.reshape(nb * nac, 1, nt) 227 | 228 | pad_length = math.ceil(audio.signal_duration / win_len) * win_len 229 | audio.zero_pad_to(int(pad_length * self.sr)) 230 | audio = audio.collect_windows(win_len, hop_len) 231 | 232 | print(win_len, hop_len, audio.batch_size, f"(processed in {(time.time() - stime) * 1000:.0f}ms)") 233 | stime = time.time() 234 | 235 | emb = [] 236 | for i in range(audio.batch_size): 237 | signal_from_batch = AudioSignal(audio.audio_data[i, ...], self.sr) 238 | signal_from_batch.to(self.device) 239 | e1 = self.model.encoder(signal_from_batch.audio_data).cpu() # [1, 1024, timeframes] 240 | e1 = e1[0] # [1024, timeframes] 241 | e1 = e1.transpose(0, 1) # [timeframes, 1024] 242 | emb.append(e1) 243 | 244 | emb = torch.cat(emb, dim=0) 245 | print(emb.shape, f'(computing finished in {(time.time() - stime) * 1000:.0f}ms)') 246 | 247 | return emb 248 | 249 | def load_wav(self, wav_file: Path): 250 | from audiotools import AudioSignal 251 | return AudioSignal(wav_file) 252 | 253 | 254 | class MERTModel(ModelLoader): 255 | """ 256 | MERT model from https://huggingface.co/m-a-p/MERT-v1-330M 257 | 258 | Please specify the layer to use (1-12). 259 | """ 260 | def __init__(self, size='v1-95M', layer=12, limit_minutes=6): 261 | super().__init__(f"MERT-{size}" + ("" if layer == 12 else f"-{layer}"), 768, 24000) 262 | self.huggingface_id = f"m-a-p/MERT-{size}" 263 | self.layer = layer 264 | self.limit = limit_minutes * 60 * self.sr 265 | 266 | def load_model(self): 267 | from transformers import Wav2Vec2FeatureExtractor, AutoModel, AutoConfig 268 | 269 | cfg = AutoConfig.from_pretrained(self.huggingface_id, trust_remote_code=True) 270 | cfg.conv_pos_batch_norm = False 271 | self.model = AutoModel.from_pretrained(self.huggingface_id, trust_remote_code=True, config=cfg) 272 | self.processor = Wav2Vec2FeatureExtractor.from_pretrained(self.huggingface_id, trust_remote_code=True) 273 | # self.sr = self.processor.sampling_rate 274 | self.model.to(self.device) 275 | 276 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 277 | # Limit to 9 minutes 278 | if audio.shape[0] > self.limit: 279 | log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). 
Truncating.") 280 | audio = audio[:self.limit] 281 | 282 | inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device) 283 | with torch.no_grad(): 284 | out = self.model(**inputs, output_hidden_states=True) 285 | out = torch.stack(out.hidden_states).squeeze() # [13 layers, timeframes, 768] 286 | out = out[self.layer] # [timeframes, 768] 287 | 288 | return out 289 | 290 | 291 | class CLAPLaionModel(ModelLoader): 292 | """ 293 | CLAP model from https://github.com/LAION-AI/CLAP 294 | """ 295 | 296 | def __init__(self, type: Literal['audio', 'music']): 297 | super().__init__(f"clap-laion-{type}", 512, 48000) 298 | self.type = type 299 | 300 | if type == 'audio': 301 | url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-best.pt' 302 | elif type == 'music': 303 | url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt' 304 | 305 | self.model_file = Path(__file__).parent / ".model-checkpoints" / url.split('/')[-1] 306 | 307 | # Download file if it doesn't exist 308 | if not self.model_file.exists(): 309 | self.model_file.parent.mkdir(parents=True, exist_ok=True) 310 | download_file(url, self.model_file) 311 | 312 | # Patch the model file to remove position_ids (will raise an error otherwise) 313 | # This key must be removed for CLAP version <= 1.1.5 314 | # But it must be kept for CLAP version >= 1.1.6 315 | package_name = "laion_clap" 316 | from packaging import version 317 | ver = version.parse(importlib.metadata.version(package_name)) 318 | if ver < version.parse("1.1.6"): 319 | self.patch_model_430(self.model_file) 320 | else: 321 | self.unpatch_model_430(self.model_file) 322 | 323 | 324 | def patch_model_430(self, file: Path): 325 | """ 326 | Patch the model file to remove position_ids (will raise an error otherwise) 327 | This is a new issue after the transformers 4.30.0 update 328 | Please refer to https://github.com/LAION-AI/CLAP/issues/127 329 | """ 330 | # Create a "patched" file when patching is done 331 | patched = file.parent / f"{file.name}.patched.430" 332 | if patched.exists(): 333 | return 334 | 335 | log.warning("Patching LAION-CLAP's model checkpoints") 336 | 337 | # Load the checkpoint from the given path 338 | ck = torch.load(file, map_location="cpu") 339 | 340 | # Extract the state_dict from the checkpoint 341 | unwrap = isinstance(ck, dict) and "state_dict" in ck 342 | sd = ck["state_dict"] if unwrap else ck 343 | 344 | # Delete the specific key from the state_dict 345 | sd.pop("module.text_branch.embeddings.position_ids", None) 346 | 347 | # Save the modified state_dict back to the checkpoint 348 | if isinstance(ck, dict) and "state_dict" in ck: 349 | ck["state_dict"] = sd 350 | 351 | # Save the modified checkpoint 352 | torch.save(ck, file) 353 | log.warning(f"Saved patched checkpoint to {file}") 354 | 355 | # Create a "patched" file when patching is done 356 | patched.touch() 357 | 358 | 359 | def unpatch_model_430(self, file: Path): 360 | """ 361 | Since CLAP 1.1.6, its codebase provided its own workarounds that isn't compatible 362 | with our patch. This function will revert the patch to make it compatible with the new 363 | CLAP version. 
364 | """ 365 | patched = file.parent / f"{file.name}.patched.430" 366 | if not patched.exists(): 367 | return 368 | 369 | # The below is an inverse operation of the patch_model_430 function, so comments are omitted 370 | log.warning("Unpatching LAION-CLAP's model checkpoints") 371 | ck = torch.load(file, map_location="cpu") 372 | unwrap = isinstance(ck, dict) and "state_dict" in ck 373 | sd = ck["state_dict"] if unwrap else ck 374 | sd["module.text_branch.embeddings.position_ids"] = 0 375 | if isinstance(ck, dict) and "state_dict" in ck: 376 | ck["state_dict"] = sd 377 | torch.save(ck, file) 378 | log.warning(f"Saved unpatched checkpoint to {file}") 379 | patched.unlink() 380 | 381 | 382 | def load_model(self): 383 | import laion_clap 384 | 385 | self.model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-tiny' if self.type == 'audio' else 'HTSAT-base') 386 | self.model.load_ckpt(self.model_file) 387 | self.model.to(self.device) 388 | 389 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 390 | audio = audio.reshape(1, -1) 391 | 392 | # The int16-float32 conversion is used for quantization 393 | audio = self.int16_to_float32(self.float32_to_int16(audio)) 394 | 395 | # Split the audio into 10s chunks with 1s hop 396 | chunk_size = 10 * self.sr # 10 seconds 397 | hop_size = self.sr # 1 second 398 | chunks = [audio[:, i:i+chunk_size] for i in range(0, audio.shape[1], hop_size)] 399 | 400 | # Calculate embeddings for each chunk 401 | embeddings = [] 402 | for chunk in chunks: 403 | with torch.no_grad(): 404 | chunk = chunk if chunk.shape[1] == chunk_size else np.pad(chunk, ((0,0), (0, chunk_size-chunk.shape[1]))) 405 | chunk = torch.from_numpy(chunk).float().to(self.device) 406 | emb = self.model.get_audio_embedding_from_data(x = chunk, use_tensor=True) 407 | embeddings.append(emb) 408 | 409 | # Concatenate the embeddings 410 | emb = torch.cat(embeddings, dim=0) # [timeframes, 512] 411 | return emb 412 | 413 | def int16_to_float32(self, x): 414 | return (x / 32767.0).astype(np.float32) 415 | 416 | def float32_to_int16(self, x): 417 | x = np.clip(x, a_min=-1., a_max=1.) 
418 | return (x * 32767.).astype(np.int16) 419 | 420 | 421 | class CdpamModel(ModelLoader): 422 | """ 423 | CDPAM model from https://github.com/pranaymanocha/PerceptualAudio/tree/master/cdpam 424 | """ 425 | def __init__(self, mode: Literal['acoustic', 'content']) -> None: 426 | super().__init__(f"cdpam-{mode}", 512, 22050) 427 | self.mode = mode 428 | assert mode in ['acoustic', 'content'], "Mode must be 'acoustic' or 'content'" 429 | 430 | def load_model(self): 431 | from cdpam import CDPAM 432 | self.model = CDPAM(dev=self.device) 433 | 434 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 435 | audio = torch.from_numpy(audio).float().to(self.device) 436 | 437 | # Take 1s chunks 438 | chunk_size = self.sr 439 | frames = [] 440 | for i in range(0, audio.shape[1], chunk_size): 441 | chunk = audio[:, i:i+chunk_size] 442 | _, acoustic, content = self.model.model.base_encoder.forward(chunk.unsqueeze(1)) 443 | v = acoustic if self.mode == 'acoustic' else content 444 | v = F.normalize(v, dim=1) 445 | frames.append(v) 446 | 447 | # Concatenate the embeddings 448 | emb = torch.cat(frames, dim=0) # [timeframes, 512] 449 | return emb 450 | 451 | def load_wav(self, wav_file: Path): 452 | x, _ = librosa.load(wav_file, sr=self.sr) 453 | 454 | # Scale to the int16 value range expected by CDPAM, keeping floating-point dtype 455 | x = np.round(x.astype(np.float64) * 32768) 456 | x = np.reshape(x, [-1, 1]) 457 | x = np.reshape(x, [1, x.shape[0]]) 458 | x = np.float32(x) 459 | 460 | return x 461 | 462 | 463 | class CLAPModel(ModelLoader): 464 | """ 465 | CLAP model from https://github.com/microsoft/CLAP 466 | """ 467 | def __init__(self, type: Literal['2023']): 468 | super().__init__(f"clap-{type}", 1024, 44100) 469 | self.type = type 470 | 471 | if type == '2023': 472 | url = 'https://huggingface.co/microsoft/msclap/resolve/main/CLAP_weights_2023.pth' 473 | 474 | self.model_file = Path(__file__).parent / ".model-checkpoints" / url.split('/')[-1] 475 | 476 | # Download file if it doesn't exist 477 | if not self.model_file.exists(): 478 | self.model_file.parent.mkdir(parents=True, exist_ok=True) 479 | download_file(url, self.model_file) 480 | 481 | def load_model(self): 482 | from msclap import CLAP 483 | 484 | self.model = CLAP(self.model_file, version = self.type, use_cuda=self.device == torch.device('cuda')) 485 | #self.model.to(self.device) 486 | 487 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 488 | audio = audio.reshape(1, -1) 489 | 490 | # The int16-float32 conversion is used for quantization 491 | #audio = self.int16_to_float32(self.float32_to_int16(audio)) 492 | 493 | # Split the audio into 7s chunks with 1s hop 494 | chunk_size = 7 * self.sr # 7 seconds 495 | hop_size = self.sr # 1 second 496 | chunks = [audio[:, i:i+chunk_size] for i in range(0, audio.shape[1], hop_size)] 497 | 498 | # zero-pad chunks to make equal length 499 | clen = [x.shape[1] for x in chunks] 500 | chunks = [np.pad(ch, ((0,0), (0,np.max(clen) - ch.shape[1]))) for ch in chunks] 501 | 502 | self.model.default_collate(chunks) 503 | 504 | # Calculate embeddings for each chunk 505 | embeddings = [] 506 | for chunk in chunks: 507 | with torch.no_grad(): 508 | chunk = chunk if chunk.shape[1] == chunk_size else np.pad(chunk, ((0,0), (0, chunk_size-chunk.shape[1]))) 509 | chunk = torch.from_numpy(chunk).float().to(self.device) 510 | emb = self.model.clap.audio_encoder(chunk)[0] 511 | embeddings.append(emb) 512 | 513 | # Concatenate the embeddings 514 | emb = torch.cat(embeddings, dim=0) # [timeframes, 1024] 515 | return emb 516 | 517 | def 
int16_to_float32(self, x): 518 | return (x / 32767.0).astype(np.float32) 519 | 520 | def float32_to_int16(self, x): 521 | x = np.clip(x, a_min=-1., a_max=1.) 522 | return (x * 32767.).astype(np.int16) 523 | 524 | 525 | class W2V2Model(ModelLoader): 526 | """ 527 | W2V2 model from https://huggingface.co/facebook/wav2vec2-base-960h, https://huggingface.co/facebook/wav2vec2-large-960h 528 | 529 | Please specify the size ('base' or 'large') and the layer to use (1-12 for 'base' or 1-24 for 'large'). 530 | """ 531 | def __init__(self, size: Literal['base', 'large'], layer: Literal['12', '24'], limit_minutes=6): 532 | model_dim = 768 if size == 'base' else 1024 533 | model_identifier = f"w2v2-{size}" + ("" if (layer == 12 and size == 'base') or (layer == 24 and size == 'large') else f"-{layer}") 534 | 535 | super().__init__(model_identifier, model_dim, 16000) 536 | self.huggingface_id = f"facebook/wav2vec2-{size}-960h" 537 | self.layer = layer 538 | self.limit = limit_minutes * 60 * self.sr 539 | 540 | def load_model(self): 541 | from transformers import AutoProcessor, Wav2Vec2Model 542 | 543 | self.model = Wav2Vec2Model.from_pretrained(self.huggingface_id) 544 | self.processor = AutoProcessor.from_pretrained(self.huggingface_id) 545 | self.model.to(self.device) 546 | 547 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 548 | # Limit to specified minutes 549 | if audio.shape[0] > self.limit: 550 | log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.") 551 | audio = audio[:self.limit] 552 | 553 | inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device) 554 | with torch.no_grad(): 555 | out = self.model(**inputs, output_hidden_states=True) 556 | out = torch.stack(out.hidden_states).squeeze() # [13 or 25 layers, timeframes, 768 or 1024] 557 | out = out[self.layer] # [timeframes, 768 or 1024] 558 | 559 | return out 560 | 561 | 562 | class HuBERTModel(ModelLoader): 563 | """ 564 | HuBERT model from https://huggingface.co/facebook/hubert-base-ls960, https://huggingface.co/facebook/hubert-large-ls960 565 | 566 | Please specify the size ('base' or 'large') and the layer to use (1-12 for 'base' or 1-24 for 'large'). 567 | """ 568 | def __init__(self, size: Literal['base', 'large'], layer: Literal['12', '24'], limit_minutes=6): 569 | model_dim = 768 if size == 'base' else 1024 570 | model_identifier = f"hubert-{size}" + ("" if (layer == 12 and size == 'base') or (layer == 24 and size == 'large') else f"-{layer}") 571 | 572 | super().__init__(model_identifier, model_dim, 16000) 573 | self.huggingface_id = f"facebook/hubert-{size}-ls960" 574 | self.layer = layer 575 | self.limit = limit_minutes * 60 * self.sr 576 | 577 | def load_model(self): 578 | from transformers import AutoProcessor, HubertModel 579 | 580 | self.model = HubertModel.from_pretrained(self.huggingface_id) 581 | self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft") 582 | self.model.to(self.device) 583 | 584 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 585 | # Limit to specified minutes 586 | if audio.shape[0] > self.limit: 587 | log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). 
Truncating.") 588 | audio = audio[:self.limit] 589 | 590 | inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device) 591 | with torch.no_grad(): 592 | out = self.model(**inputs, output_hidden_states=True) 593 | out = torch.stack(out.hidden_states).squeeze() # [13 or 25 layers, timeframes, 768 or 1024] 594 | out = out[self.layer] # [timeframes, 768 or 1024] 595 | 596 | return out 597 | 598 | 599 | class WavLMModel(ModelLoader): 600 | """ 601 | WavLM model from https://huggingface.co/microsoft/wavlm-base, https://huggingface.co/microsoft/wavlm-base-plus, https://huggingface.co/microsoft/wavlm-large 602 | 603 | Please specify the model size ('base', 'base-plus', or 'large') and the layer to use (1-12 for 'base' or 'base-plus' and 1-24 for 'large'). 604 | """ 605 | def __init__(self, size: Literal['base', 'base-plus', 'large'], layer: Literal['12', '24'], limit_minutes=6): 606 | model_dim = 768 if size in ['base', 'base-plus'] else 1024 607 | model_identifier = f"wavlm-{size}" + ("" if (layer == 12 and size in ['base', 'base-plus']) or (layer == 24 and size == 'large') else f"-{layer}") 608 | 609 | super().__init__(model_identifier, model_dim, 16000) 610 | self.huggingface_id = f"patrickvonplaten/wavlm-libri-clean-100h-{size}" 611 | self.layer = layer 612 | self.limit = limit_minutes * 60 * self.sr 613 | 614 | def load_model(self): 615 | from transformers import AutoProcessor, WavLMModel 616 | 617 | self.model = WavLMModel.from_pretrained(self.huggingface_id) 618 | self.processor = AutoProcessor.from_pretrained(self.huggingface_id) 619 | self.model.to(self.device) 620 | 621 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 622 | # Limit to specified minutes 623 | if audio.shape[0] > self.limit: 624 | log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.") 625 | audio = audio[:self.limit] 626 | 627 | inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device) 628 | with torch.no_grad(): 629 | out = self.model(**inputs, output_hidden_states=True) 630 | out = torch.stack(out.hidden_states).squeeze() # [13 or 25 layers, timeframes, 768 or 1024] 631 | out = out[self.layer] # [timeframes, 768 or 1024] 632 | 633 | return out 634 | 635 | 636 | class WhisperModel(ModelLoader): 637 | """ 638 | Whisper model from https://huggingface.co/openai/whisper-base 639 | 640 | Please specify the model size ('tiny', 'base', 'small', 'medium', or 'large'). 
641 | """ 642 | def __init__(self, size: Literal['tiny', 'base', 'small', 'medium', 'large']): 643 | dimensions = { 644 | 'tiny': 384, 645 | 'base': 512, 646 | 'small': 768, 647 | 'medium': 1024, 648 | 'large': 1280 649 | } 650 | model_dim = dimensions.get(size) 651 | model_identifier = f"whisper-{size}" 652 | 653 | super().__init__(model_identifier, model_dim, 16000) 654 | self.huggingface_id = f"openai/whisper-{size}" 655 | 656 | def load_model(self): 657 | from transformers import AutoFeatureExtractor 658 | from transformers import WhisperModel 659 | 660 | self.model = WhisperModel.from_pretrained(self.huggingface_id) 661 | self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.huggingface_id) 662 | self.decoder_input_ids = (torch.tensor([[1, 1]]) * self.model.config.decoder_start_token_id).to(self.device) 663 | self.model.to(self.device) 664 | 665 | def _get_embedding(self, audio: np.ndarray) -> np.ndarray: 666 | inputs = self.feature_extractor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device) 667 | input_features = inputs.input_features.to(self.device) 668 | with torch.no_grad(): 669 | out = self.model(input_features, decoder_input_ids=self.decoder_input_ids).last_hidden_state # [1, timeframes, 512] 670 | out = out.squeeze() # [timeframes, (384 or 512 or 768 or 1024 or 1280)] 671 | 672 | return out 673 | 674 | 675 | 676 | def get_all_models() -> list[ModelLoader]: 677 | ms = [ 678 | CLAPModel('2023'), 679 | CLAPLaionModel('audio'), CLAPLaionModel('music'), 680 | VGGishModel(), 681 | *(MERTModel(layer=v) for v in range(1, 13)), 682 | EncodecEmbModel('24k'), EncodecEmbModel('48k'), 683 | # DACModel(), 684 | # CdpamModel('acoustic'), CdpamModel('content'), 685 | *(W2V2Model('base', layer=v) for v in range(1, 13)), 686 | *(W2V2Model('large', layer=v) for v in range(1, 25)), 687 | *(HuBERTModel('base', layer=v) for v in range(1, 13)), 688 | *(HuBERTModel('large', layer=v) for v in range(1, 25)), 689 | *(WavLMModel('base', layer=v) for v in range(1, 13)), 690 | *(WavLMModel('base-plus', layer=v) for v in range(1, 13)), 691 | *(WavLMModel('large', layer=v) for v in range(1, 25)), 692 | WhisperModel('tiny'), WhisperModel('small'), 693 | WhisperModel('base'), WhisperModel('medium'), 694 | WhisperModel('large'), 695 | ] 696 | if importlib.util.find_spec("dac") is not None: 697 | ms.append(DACModel()) 698 | if importlib.util.find_spec("cdpam") is not None: 699 | ms += [CdpamModel('acoustic'), CdpamModel('content')] 700 | 701 | return ms 702 | -------------------------------------------------------------------------------- /fadtk/package.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from .fad import FrechetAudioDistance 4 | from .model_loader import * 5 | from .fad_batch import cache_embedding_files 6 | 7 | if __name__ == "__main__": 8 | """ 9 | Launcher for packaging statistics of a directory using a model. 10 | """ 11 | models = {m.name: m for m in get_all_models()} 12 | 13 | agupa = ArgumentParser() 14 | agupa.add_argument('directory', type=str) 15 | agupa.add_argument('out', type=str) 16 | 17 | # Add optional arguments 18 | agupa.add_argument('-w', '--workers', type=int, default=8) 19 | agupa.add_argument('-s', '--sox-path', type=str, default='/usr/bin/sox') 20 | 21 | args = agupa.parse_args() 22 | 23 | out = Path(args.out) 24 | if out.suffix != '.npz': 25 | print('The output file you specified is not a npz file, are you sure? 
(y/N)') 26 | if input().lower() != 'y': 27 | exit(1) 28 | 29 | # 1. Calculate embedding files for each model 30 | for model in models.values(): 31 | cache_embedding_files(args.directory, model, workers=args.workers) 32 | 33 | # 2. Calculate statistics for each model 34 | data = {} 35 | for model in models.values(): 36 | fad = FrechetAudioDistance(model, load_model=False) 37 | mu, cov = fad.load_stats(args.directory) 38 | data[f'{model.name}.mu'] = mu 39 | data[f'{model.name}.cov'] = cov 40 | 41 | # 3. Save statistics 42 | np.savez(out, **data) -------------------------------------------------------------------------------- /fadtk/stats/fma_pop.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/stats/fma_pop.npz -------------------------------------------------------------------------------- /fadtk/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/__init__.py -------------------------------------------------------------------------------- /fadtk/test/__main__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import traceback 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from fadtk.fad import FrechetAudioDistance 7 | from fadtk.model_loader import get_all_models 8 | from hypy_utils.logging_utils import setup_logger 9 | 10 | log = setup_logger() 11 | 12 | if __name__ == '__main__': 13 | # Read samples csv 14 | fp = Path(__file__).parent 15 | reference = pd.read_csv(fp / 'samples_FAD_scores.csv') 16 | 17 | # Get reference models in column names 18 | reference_models = [c.split('_', 1)[1].replace('_fma_pop', '') for c in reference.columns if c.startswith('FAD_')] 19 | print("Models with reference data:", reference_models) 20 | 21 | # Compute FAD score 22 | for model in get_all_models(): 23 | if model.name.replace('-', '_') not in reference_models: 24 | print(f'No reference data for {model.name}, skipping') 25 | continue 26 | 27 | # Because of the heavy computation required to run each test, we limit the MERT models to only a few layers 28 | if model.name.startswith('MERT') and model.name[-1] not in ['1', '4', '8', 'M']: 29 | continue 30 | 31 | log.info(f'Computing FAD score for {model.name}') 32 | csv = fp / 'fad_scores' / f'{model.name}.csv' 33 | if csv.is_file(): 34 | continue 35 | 36 | fad = FrechetAudioDistance(model, audio_load_worker=1, load_model=True) 37 | 38 | # Cache embedding files 39 | try: 40 | for f in (fp / 'samples').glob('*.*'): 41 | fad.cache_embedding_file(f) 42 | except Exception as e: 43 | traceback.print_exc() 44 | log.error(f'Error when caching embedding files for {model.name}: {e}') 45 | exit(1) 46 | 47 | try: 48 | # Compute FAD score 49 | fad.score_individual('fma_pop', fp / 'samples', csv) 50 | except Exception as e: 51 | traceback.print_exc() 52 | log.error(f'Error when computing FAD score for {model.name}: {e}') 53 | exit(1) 54 | 55 | # Compute FAD for entire set 56 | all_score = fad.score('fma_pop', fp / 'samples') 57 | 58 | # Add all_score to csv with file name '/samples/all' 59 | data = pd.read_csv(csv, names=['file', 'score']) 60 | data = pd.concat([data, pd.DataFrame([['/samples/all', all_score]], columns=['file', 'score'])]) 61 | data.to_csv(csv, index=False, header=False) 62 | 63 | # Read from csvs 64 | 
table = [] 65 | for f in (fp / 'fad_scores').glob('*.csv'): 66 | model_name = f.stem.replace('-', '_') 67 | data = pd.read_csv(f, names=['file', 'score']) 68 | data['file'] = data['file'].replace(r'\\', '/', regex=True) # convert Windows paths 69 | data['file'] = data['file'].apply(lambda x: '/'.join(x.split('/')[-2:]).split('.')[0]) 70 | 71 | # Get the scores of the same model from the reference csv as an array 72 | # They should be in FAD_{model_name}_fma_pop column 73 | test = reference.loc[:, ['song_id', f'FAD_{model_name}_fma_pop']].copy() 74 | test.columns = ['file', 'score'] 75 | 76 | # Transform test to a dictionary of file: score 77 | test = test.set_index('file').to_dict()['score'] 78 | 79 | test = np.array([test[f] for f in data['file']]) 80 | data = np.array(data['score']) 81 | 82 | # Compare mean squared error 83 | mse = ((data - test) ** 2).mean() 84 | max_abs_diff = np.abs(data - test).max() 85 | mean = np.mean(data) 86 | madp = max_abs_diff / mean * 100 87 | table.append({ 88 | 'model': model_name, 89 | 'mse': mse, 90 | 'max_abs_diff': max_abs_diff, 91 | 'mean': mean, 92 | 'mad%': madp, 93 | 'pass': madp < 5 # 5% threshold 94 | }) 95 | 96 | # Print table 97 | table = pd.DataFrame(table) 98 | log.info(table) 99 | table.to_csv(fp / 'comparison.csv') 100 | 101 | # If anything failed, exit with error code 2 102 | if not table['pass'].all(): 103 | log.error('Some models failed the test') 104 | exit(2) 105 | -------------------------------------------------------------------------------- /fadtk/test/samples/mg-1634.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-1634.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-1648.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-1648.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-1741.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-1741.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-2344.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-2344.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-2551.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-2551.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-2759.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-2759.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-284.opus: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-284.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-483.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-483.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-66.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-66.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-911.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-911.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mg-974.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mg-974.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mlm-1619.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mlm-1619.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mlm-1698.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mlm-1698.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mlm-2940.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mlm-2940.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mlm-483.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mlm-483.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mlm-974.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mlm-974.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mubert-130.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mubert-130.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mubert-1348.opus: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mubert-1348.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mubert-1619.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mubert-1619.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mubert-2120.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mubert-2120.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mubert-2545.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mubert-2545.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mubert-2701.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mubert-2701.opus -------------------------------------------------------------------------------- /fadtk/test/samples/mubert-806.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/fadtk/736bea03ce4646939aed88d6344378359e86a03a/fadtk/test/samples/mubert-806.opus -------------------------------------------------------------------------------- /fadtk/test/samples_FAD_scores.csv: -------------------------------------------------------------------------------- 1 | song_id,dataset,FAD_MERT_v1_95M_fma_pop,FAD_MERT_v1_95M_1_fma_pop,FAD_MERT_v1_95M_10_fma_pop,FAD_MERT_v1_95M_11_fma_pop,FAD_MERT_v1_95M_2_fma_pop,FAD_MERT_v1_95M_3_fma_pop,FAD_MERT_v1_95M_4_fma_pop,FAD_MERT_v1_95M_5_fma_pop,FAD_MERT_v1_95M_6_fma_pop,FAD_MERT_v1_95M_7_fma_pop,FAD_MERT_v1_95M_8_fma_pop,FAD_MERT_v1_95M_9_fma_pop,FAD_cdpam_acoustic_fma_pop,FAD_cdpam_content_fma_pop,FAD_clap_2023_fma_pop,FAD_clap_laion_audio_fma_pop,FAD_clap_laion_music_fma_pop,FAD_dac_44kHz_fma_pop,FAD_encodec_emb_fma_pop,FAD_encodec_emb_48k_fma_pop,FAD_vggish_fma_pop 2 | samples/all,all,17.507623637608,5.99507060988391,9.94651184353292,9.37752873733533,6.73767164102728,6.70681690772166,6.72373361429806,7.12919250152479,7.15825634627564,7.51992102891785,7.77216661926487,8.43765497705942,0.377092359959228,0.47654955691766,393.814494087621,0.463373894226911,0.538692108488325,522.632204202623,203.768676678184,76.0831788663788,5.12190021819319 3 | samples/mg-1634,mg,71.3362036099082,39.8437203020927,43.0415715642971,37.7596380812923,43.6583277552788,40.4821531917241,39.2641409017981,39.370442521202,38.3304482556851,39.5591055617214,40.0412283882476,39.2260397890029,0.778897442658879,1.13394646892678,1294.84208192179,1.61757086823932,1.65739017160921,1957.05497330936,379.363303084985,85.5908273752556,21.1936975843131 4 | 
samples/mg-1648,mg,48.8742065861995,19.1292334750774,29.2418803601123,27.2339161030925,22.6311159717578,23.1818971440938,22.3887182311243,24.0421732697707,23.9437197910516,25.3930147933895,25.8151493144358,25.7397508931274,0.6453184096024,0.916682708889174,842.0711302362,1.0430790021216,1.18698420181586,629.849342310313,129.107436357212,63.6964833896851,13.2565169640479 5 | samples/mg-1741,mg,77.0597172166065,38.6971619098325,44.6550587271052,38.4117758526756,41.6240556044805,42.6492858701077,43.7701520996705,45.8179236869159,45.4652780605834,46.0753878784217,45.4668447607005,42.9057410327667,0.611267179545179,0.618816291338183,980.717389933703,1.3614043099841,1.33289644400882,1893.10907311245,295.461173766258,105.799267868335,14.7468648485356 6 | samples/mg-2344,mg,54.9708276589295,21.6168903310035,33.268342265003,30.3646887094772,24.5554528856044,25.3512628251146,24.8953403419635,27.2002940840582,26.4766838273595,27.9599042508857,28.6521443333935,29.0028565341121,0.427708009742122,0.630103220536839,744.439501233073,1.07706552883832,1.25782860559843,407.64058467579,147.235847586915,70.8576667376201,13.1076233339954 7 | samples/mg-2551,mg,56.2368839807781,24.1691851965527,35.5530242173428,32.9998409081344,27.3750554494129,28.1386650023749,27.7384169449046,29.1202861667472,28.8348492237499,30.4831309685653,31.320417355044,31.160535338523,0.413011708946023,0.564915638509509,791.962140875087,1.07834391443475,1.24739895464233,609.33935152996,139.481695843542,68.6070534525258,15.1091090242704 8 | samples/mg-2759,mg,62.2369916193419,27.2624965504285,37.6459748319574,33.2103431029877,30.8377798351998,31.4479762119181,31.5118384118286,33.081790286326,33.1889715283694,35.5768225432313,35.8112717978789,34.8501358528227,0.504052113146365,0.509351823170957,822.848493978304,1.26486211518186,1.3422612611368,1193.58372779068,201.103896552495,105.684368282768,12.2754878191094 9 | samples/mg-284,mg,55.399311415545,26.6973849159013,33.5758131438032,29.0805823012182,30.5947423422831,30.6350474797407,29.3666209515128,30.260204217221,30.1069530013485,30.5936895150823,30.6610273353169,29.986522094177,0.801224109839,1.20238913845932,777.062695848219,1.11450308181136,1.28464376224534,824.170720722038,182.692049199627,80.6855174635208,13.5742702203118 10 | samples/mg-483,mg,104.432234302718,46.7548364962776,58.0706948043744,53.930141353241,50.4743473660986,49.6249358259456,49.0336657018291,50.2519652492477,50.7287217938592,50.4662123267369,52.8844325131702,52.5690248050584,0.931939685224798,1.10594010905462,1202.98475591413,1.31708189944967,1.50362239903169,2193.37186729926,453.948527633986,126.046448475786,17.5861877480217 11 | samples/mg-66,mg,45.8694014843326,19.5577146857098,28.3062207363163,25.3655137046982,23.0231252938428,23.3079717206791,22.392078920764,23.7988017891072,23.206145653285,24.158695201948,24.8618135752129,24.7809293583104,0.406231599328502,0.601530851126908,808.559234309483,1.10365601184615,1.24564064620103,493.343902619674,134.094986692187,58.7062067381099,14.8002430508453 12 | samples/mg-911,mg,55.9829090603676,23.0503115568987,30.3712459369207,26.45543015949,26.0256094477589,26.4649513028545,25.9864960701806,27.8524915965922,28.4163233210563,28.5705315369895,28.2677392499928,27.6171239725696,0.825487225935979,1.13433964275771,689.603806487139,1.01235747833845,1.17477355584989,830.233194656376,210.175363353941,98.1807173463038,12.5859739896452 13 | 
samples/mg-974,mg,66.7221525802332,27.0102029263874,36.6700901902076,34.4613933413184,31.8416844021753,33.153409238938,32.8431400657846,34.5015504147122,34.2633879553798,34.9320539198528,34.4355693433223,33.2521991026848,0.451192977186941,0.563557726206954,1179.59124005687,1.300793240673,1.40602544058679,766.486395072358,161.23115306535,69.9412731626551,18.6844645734393 14 | samples/mlm-1619,mlm,85.5273025758343,40.5591641198165,46.8461822769414,40.683711019671,43.9978642390336,43.7805407622785,43.6008395938013,44.0286183693192,43.4697429008807,44.0331696796011,44.9007642864178,44.1337270678425,0.699871214430989,0.849968532037253,1071.29382294647,1.45565446674086,1.5499924760128,2461.8749791105,425.376212459412,106.351814237451,15.0694401383733 15 | samples/mlm-1698,mlm,68.4294008985223,29.6346459988953,40.0606200020435,35.7108474127443,32.6866755118191,33.7118265833757,33.7633349760594,34.6593057572188,34.6704178295455,36.4454354502372,36.766483053307,37.1137464967243,0.762597036549152,0.609984186048036,805.336403050304,1.18999870916053,1.31291739698564,2420.53432634619,431.532937811971,165.026333191364,12.9596215188161 16 | samples/mlm-2940,mlm,55.0506749919466,28.4506804543523,36.7270771559485,33.6859943695681,33.0044855542651,33.5330450300275,31.7578770210049,33.2700848839499,32.2589829981452,33.1153762939929,33.3343716341666,33.0833920406232,0.464994273597244,0.80710328791953,757.519264405368,1.06415353187918,1.29784461088253,1318.55950031994,144.768956959405,69.622762187482,17.9456511436777 17 | samples/mlm-483,mlm,97.7937646998296,41.2079217129662,52.8920568092558,45.9620854351014,44.1255160136679,42.8787267791423,42.3912159871317,43.1612731121052,43.4402704937094,43.7959955667822,45.5758118744137,46.5547951484056,0.919072390857642,1.11006049202089,885.443521133834,1.241367847222,1.39273989781635,2981.88978126039,500.504292833461,135.109786913593,21.8693849024174 18 | samples/mlm-974,mlm,111.889331897077,56.9792102433184,62.6900350692379,53.5463649352641,61.3395590392334,58.9282019366559,58.9073850968861,57.8035877167931,56.938719439565,57.6862073478158,59.4618540055342,59.1655528451923,1.00347289379112,0.915323603462277,999.273424859983,1.44415982969382,1.43029937425656,4348.28446307349,732.653663068936,123.012612967667,13.9441760969872 19 | samples/mubert-130,mubert,50.8235283315129,22.378165942894,33.7484475464131,30.9482257113891,26.9890662414308,27.9837793347469,27.7481991361942,29.9285845215756,29.7841793622775,31.0659447732828,30.8607799515758,30.440360250389,0.555648514118777,0.866070977625159,796.228004407963,1.03005028114499,1.25765071373695,615.189890226723,219.559389390064,110.310227568308,11.803243997576 20 | samples/mubert-1348,mubert,62.6121252062435,32.7925985944015,41.3297185990143,39.7789498942616,36.1143550221214,36.3989510421795,36.031717004017,38.0792211944131,37.2018128841214,37.3474747603751,37.1449058725771,36.5637016323658,0.575725539139856,0.867632393370907,906.528426381916,1.01274777947414,1.22193754212061,874.420618272201,264.676394024166,116.588752110439,16.4978571303347 21 | samples/mubert-1619,mubert,57.7727778637722,24.4647346297904,38.6344912004916,35.5765291939075,28.5859277306325,30.1221738234506,30.064416616189,32.4384402237868,32.9438130409224,34.7622298542342,35.1364841931532,34.1400886134817,0.515306099671927,0.571470513482893,818.623103449384,1.07822559523702,1.23283852877291,733.082151640439,228.123376978644,104.702140253941,11.7600279825437 22 | 
samples/mubert-2120,mubert,75.3419219844084,29.6577976597442,43.648951392431,40.3027239372224,38.7885958786418,38.5902761102932,36.3963680531586,38.0565296749038,37.1837180470804,38.0270398025062,38.1776356925927,38.8835931355453,0.561327271238918,0.690617981068813,818.406244635316,1.09597624912158,1.24455354485224,924.160201926797,271.349773329117,113.71921956066,14.2967682647498 23 | samples/mubert-2545,mubert,76.5445250623636,34.0802206642497,46.6516969323719,39.7696583683051,39.2688173875896,40.9649098986066,40.8391181015553,41.9852400048466,41.3646566047989,42.5189934205931,44.2233943555682,44.2900771880233,0.896823858258401,1.20478117157653,851.924257730908,1.27246867580541,1.3544066150614,2930.11512929638,520.958262021986,177.414736262219,17.5662757763885 24 | samples/mubert-2701,mubert,80.5515054820278,38.3306151174236,49.4688787798489,42.7001415091607,44.6608673654546,46.9413223251688,45.2684402873795,46.6601943712715,45.5826274742136,46.6026806977783,47.9051688234711,46.3651649687786,0.98384531308024,1.22557193849572,844.263735870806,1.24125398619656,1.30180078809523,2396.76442727082,539.946750203467,176.750314361437,23.5384476473007 25 | samples/mubert-806,mubert,84.3694845057076,48.1099386432757,55.737843252305,46.3145193220992,50.758929065534,52.6613893460833,53.832603465228,56.8606118412191,56.5052420055989,57.3205701156203,56.3484574494681,53.9732795813797,0.686687126997863,0.978479105341783,893.627802738387,1.41071587586771,1.39823628902474,2049.15705349703,505.580991939341,123.812629732119,15.7087936836904 26 | -------------------------------------------------------------------------------- /fadtk/test/test_cleanup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Get script path's directory 4 | src=$(dirname "$0") 5 | echo "Cleaning up $src" 6 | 7 | rm -rfv "$src/fad_scores" 8 | rm -rfv "$src/comparison.csv" 9 | rm -rfv "$src/samples/convert" 10 | rm -rfv "$src/samples/embeddings" 11 | -------------------------------------------------------------------------------- /fadtk/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import subprocess 3 | import numpy as np 4 | from typing import Union 5 | 6 | from hypy_utils.nlp_utils import substr_between 7 | from hypy_utils.tqdm_utils import pmap 8 | 9 | 10 | PathLike = Union[str, Path] 11 | 12 | 13 | def _process_file(file: PathLike): 14 | embd = np.load(file) 15 | n = embd.shape[0] 16 | return np.mean(embd, axis=0), np.cov(embd, rowvar=False) * (n - 1), n 17 | 18 | 19 | def calculate_embd_statistics_online(files: list[PathLike]) -> tuple[np.ndarray, np.ndarray]: 20 | """ 21 | Calculate the mean and covariance matrix of a list of embeddings in an online manner. 
22 | 23 | :param files: A list of npy files containing ndarrays with shape (n_frames, n_features) 24 | """ 25 | assert len(files) > 0, "No files provided" 26 | 27 | # Load the first file to get the embedding dimension 28 | embd_dim = np.load(files[0]).shape[-1] 29 | 30 | # Initialize the mean and covariance matrix 31 | mu = np.zeros(embd_dim) 32 | S = np.zeros((embd_dim, embd_dim)) # Sum of squares for online covariance computation 33 | n = 0 # Counter for total number of frames 34 | 35 | results = pmap(_process_file, files, desc='Calculating statistics') 36 | for _mu, _S, _n in results: 37 | delta = _mu - mu 38 | mu += _n / (n + _n) * delta 39 | S += _S + delta[:, None] * delta[None, :] * n * _n / (n + _n) 40 | n += _n 41 | 42 | if n < 2: 43 | return mu, np.zeros_like(S) 44 | else: 45 | cov = S / (n - 1) # compute the covariance matrix 46 | return mu, cov 47 | 48 | 49 | def find_sox_formats(sox_path: str) -> list[str]: 50 | """ 51 | Find a list of file formats supported by SoX 52 | """ 53 | try: 54 | out = subprocess.check_output((sox_path, "-h")).decode() 55 | return substr_between(out, "AUDIO FILE FORMATS: ", "\n").split() 56 | except: 57 | return [] 58 | 59 | 60 | def get_cache_embedding_path(model: str, audio_dir: PathLike) -> Path: 61 | """ 62 | Get the path to the cached embedding npy file for an audio file. 63 | 64 | :param model: The name of the model 65 | :param audio_dir: The path to the audio file 66 | """ 67 | audio_dir = Path(audio_dir) 68 | return audio_dir.parent / "embeddings" / model / audio_dir.with_suffix(".npy").name 69 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fadtk" 3 | version = "1.1.0" 4 | description = "A simple and standardized library for Frechet Audio Distance calculation." 5 | authors = [{ name = "Azalea", email = "me@hydev.org" }] 6 | requires-python = ">=3.10,<=3.13" 7 | readme = "README.md" 8 | dependencies = [ 9 | "numpy<2", 10 | "pandas>=2.0.3,<3", 11 | "scipy>=1.11.2,<2", 12 | "hypy-utils>=1.0.19,<2", 13 | "soundfile>=0.12.1,<1", 14 | "wheel>=0.41.1,<1", 15 | "librosa>=0.10.1,<1", 16 | "encodec>=0.1.1,<1", 17 | "transformers>=4.30.0,<5", 18 | "laion-clap>=1.1.6,<2", 19 | "nnaudio>=0.3.2,<1", 20 | "torch>=2.3.0,<3", 21 | "resampy", 22 | "llvmlite>=0.44", 23 | "torchvision>=0.22.0", 24 | "msclap>=1.3.4", 25 | "audioread>=3.0.1", 26 | ] 27 | 28 | [project.urls] 29 | Homepage = "https://github.com/Microsoft/fadtk" 30 | Repository = "https://github.com/Microsoft/fadtk" 31 | 32 | [project.scripts] 33 | fadtk = "fadtk.__main__:main" 34 | fadtk-embeds = "fadtk.embeds.__main__:main" 35 | 36 | [build-system] 37 | requires = ["hatchling"] 38 | build-backend = "hatchling.build" 39 | --------------------------------------------------------------------------------