├── .gitignore ├── README.md ├── class_mapping.json ├── configs ├── config_embeddings.yaml └── config_ge2e.yaml ├── notebooks └── VisualizeEmbeddings.ipynb ├── requirements.txt ├── scripts ├── README.md ├── download_resources.py ├── generate_embeddings.py ├── get_classification_metrics.py ├── ood_detector.py └── preprocess_dataset.py ├── src ├── datasets │ ├── __init__.py │ ├── dataset.py │ ├── samplers.py │ └── utils.py ├── losses.py ├── models │ ├── NSD.py │ ├── __init__.py │ ├── w2v2_aasist.py │ └── w2v2_encoder.py └── utils.py ├── train_ge2e.py └── train_refd.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | data -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Baselines for Interspeech 2025 Special Session on Source Tracing 2 | 3 | The following repository contains baselines to start your work with the task of DeepFake Source Tracing as 4 | part of [Source tracing: The origins of synthetic or manipulated speech](https://www.interspeech2025.org/special-sessions) INTERSPEECH 2025 Special Session. 5 | 6 | ## Attribution 7 | 8 | Special thanks to [Resemble AI](https://www.resemble.ai) and [AI4Trust project](https://ai4trust.eu/) for their support and affiliation. 9 | 10 | ## Contributors 11 | - [Piotr Kawa](https://github.com/piotrkawa), 12 | - [Adriana Stan](https://github.com/adrianastan), 13 | - [Nicolas M. Müller](https://github.com/mueller91). 14 | 15 | 16 | ## Before you start 17 | 18 | ### Download dataset 19 | 20 | The baseline is based on the [MLAAD (Source Tracing Protocols) dataset](https://deepfake-total.com/sourcetracing). 21 | To download the required resources run: 22 | ```bash 23 | python scripts/download_resources.py 24 | ``` 25 | 26 | The default scripts' arguments assume that all the required data is put into `data` dir in the project root directory. 27 | 28 | ### Install dependencies 29 | 30 | Install all the required dependencies from the `requirements.txt` file. The baseline was created using Python 3.11. 31 | ```bash 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | 36 | ## GE2E + Wav2Vec2.0 Baseline 37 | 38 | To train the feature extractor based on Wav2Vec2.0-based encoder using [GE2E-Loss](https://arxiv.org/pdf/1710.10467) run: 39 | ```bash 40 | python train_ge2e.py --config configs/config_ge2e.yaml 41 | ``` 42 | 43 | 44 | ## REFD Baseline 45 | 46 | This baseline builds upon the work of Xie et al. ["Generalized Source Tracing: Detecting Novel Audio Deepfake Algorithm 47 | with Real Emphasis and Fake Dispersion Strategy"](https://arxiv.org/abs/2406.03240) and its associated [Github repo](https://github.com/xieyuankun/REFD/). 48 | 49 | The work uses a data augmentation technique and an OOD detection method to improve the classification of unseen 50 | deepfake algorithms. However, in this repository we implement the very basic setup, and leave potential 51 | authors the option to improve upon it. 
52 | 53 | 54 |
55 | More details here 56 | 57 | 58 | ### Download data augmentation datasets 59 | 60 | 61 | For the required data augmentation step you will need the [MUSAN](https://www.openslr.org/17/) and [RIRS_NOISES](https://www.openslr.org/28/) datasets. 62 | 63 | 64 | 65 | ### Step 1. Data augmentation and feature extraction 66 | 67 | The first step of the tool reads the original MLAAD data, augments it with random noise and RIR and extracts 68 | the `wav2vec2-base` features needed to train the AASIST model. Additional parameters can be set from the script, 69 | such as the maximum audio length, the feature extraction model, etc. 70 | 71 | ```bash 72 | python scripts/preprocess_dataset.py 73 | ``` 74 | 75 | Output will be written to `exp/preprocess_wav2vec2-base/`. You can change the path in the script. 76 | 77 | ### Step 2. Train an AASIST model on top of the wav2vec2-base features 78 | 79 | Using the augmented features, we then train an AASIST model for 30 epochs. The model learns to classify the samples 80 | with respect to their source system. The class assignment will be written to `exp/label_assignment.txt`. 81 | 82 | ```bash 83 | python train_refd.py 84 | ``` 85 | 86 | ### Step 3. Get the classification metrics for the known (in-domain) classes 87 | 88 | Given the trained model stored in `exp/trained_models/`, we can now compute its accuracy over the known classes (those 89 | seen at training time). 90 | 91 | ```bash 92 | python scripts/get_classification_metrics.py 93 | ``` 94 | 95 | The script limits the data in the `dev` and `eval` sets to the samples which come from the known systems 96 | (i.e. those also present in the training data) and computes their classification metrics. 97 | 98 | ### Step 4. Run the OOD detector and evaluate it 99 | 100 | ```bash 101 | python scripts/ood_detector.py --feature_extraction_step 102 | ``` 103 | The script builds an NSD OOD detector as described in the original paper. The OOD detector is based on the hidden states and logits of the AASIST model. It first extracts this information from the trained model and stores it in separate dictionaries. It then loads the training data and determines the in-domain scores. 104 | 105 | It then computes the scores for the development set. Based on these scores, for which the OOD class assignments are known, 106 | it determines the EER and the associated threshold. The threshold is then used to classify the 107 | evaluation data into OOD and known systems and to report the corresponding metrics. 108 | 109 | The baseline result at this point is a 63% EER with an F1-score of 0.31 on the eval data. 110 | 111 |
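For reference, the thresholding logic of Step 4 boils down to computing the ROC curve of the dev-set OOD scores, taking the operating point where the false-positive rate meets the false-negative rate (the EER), and reusing that score as a hard decision boundary on the eval set. The sketch below is illustrative only: the `eer_threshold` helper mirrors the `compute_eer` routine in `scripts/ood_detector.py`, and the score arrays are synthetic placeholders rather than actual NSD detector outputs.

```python
import numpy as np
from sklearn.metrics import classification_report, roc_curve


def eer_threshold(labels, scores):
    """Return (EER, threshold) at the point where FPR and FNR (1 - TPR) cross."""
    fpr, tpr, thresholds = roc_curve(labels, scores)
    idx = np.nanargmin(np.abs(fpr - (1 - tpr)))
    return fpr[idx], thresholds[idx]


# Synthetic stand-ins for the detector scores (1 = OOD, 0 = known system).
rng = np.random.default_rng(0)
dev_labels = rng.integers(0, 2, 1000)
dev_scores = dev_labels + rng.normal(0.0, 0.8, 1000)
eval_labels = rng.integers(0, 2, 1000)
eval_scores = eval_labels + rng.normal(0.0, 0.8, 1000)

eer, threshold = eer_threshold(dev_labels, dev_scores)   # calibrate on dev
predictions = (eval_scores > threshold).astype(int)      # apply to eval
print(f"dev EER: {eer * 100:.2f}% | threshold: {threshold:.2f}")
print(classification_report(eval_labels, predictions))
```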
112 | 113 | 114 | ## License 115 | This repository is licensed under the [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/) for original content. 116 | 117 | 118 | ### Exceptions: 119 | - Portions of this repository include code from [REFD repository](https://github.com/xieyuankun/REFD/), which does not have a license. 120 | - As per copyright law, such code is "All Rights Reserved" and is not covered by the CC BY-NC license. Users should not reuse or redistribute it without the original author's explicit permission. 121 | 122 | 123 | ## References 124 | The following repository is built using the following open-source repositories: 125 | 126 | 127 | ### REFD 128 | * [GitHub repository](https://github.com/xieyuankun/REFD), 129 | ``` 130 | @inproceedings{xie24_interspeech, 131 | title = {Generalized Source Tracing: Detecting Novel Audio Deepfake Algorithm with Real Emphasis and Fake Dispersion Strategy}, 132 | author = {Yuankun Xie and Ruibo Fu and Zhengqi Wen and Zhiyong Wang and Xiaopeng Wang and Haonnan Cheng and Long Ye and Jianhua Tao}, 133 | year = {2024}, 134 | booktitle = {Interspeech 2024}, 135 | pages = {4833--4837}, 136 | doi = {10.21437/Interspeech.2024-254}, 137 | issn = {2958-1796}, 138 | } 139 | ``` 140 | 141 | 142 | ### Coqui.ai TTS 143 | * [GitHub repository](https://github.com/coqui-ai/TTS), 144 | 145 | ``` 146 | @software{Eren_Coqui_TTS_2021, 147 | author = {Eren, Gölge and {The Coqui TTS Team}}, 148 | doi = {10.5281/zenodo.6334862}, 149 | license = {MPL-2.0}, 150 | month = jan, 151 | title = {{Coqui TTS}}, 152 | url = {https://github.com/coqui-ai/TTS}, 153 | version = {1.4}, 154 | year = {2021} 155 | } 156 | ``` 157 | -------------------------------------------------------------------------------- /class_mapping.json: -------------------------------------------------------------------------------- 1 | { 2 | "Mars5": 0, 3 | "MatchaTTS": 1, 4 | "MeloTTS": 2, 5 | "Metavoice-1B": 3, 6 | "OpenVoiceV2": 4, 7 | "WhisperSpeech": 5, 8 | "e2-tts": 6, 9 | "f5-tts": 7, 10 | "facebook/mms-tts-deu": 8, 11 | "facebook/mms-tts-eng": 9, 12 | "facebook/mms-tts-fin": 10, 13 | "facebook/mms-tts-fra": 11, 14 | "facebook/mms-tts-hun": 12, 15 | "facebook/mms-tts-nld": 13, 16 | "facebook/mms-tts-ron": 14, 17 | "facebook/mms-tts-rus": 15, 18 | "facebook/mms-tts-swe": 16, 19 | "facebook/mms-tts-ukr": 17, 20 | "griffin_lim": 18, 21 | "microsoft/speecht5_tts": 19, 22 | "optispeech": 20, 23 | "parler_tts_large_v1": 21, 24 | "parler_tts_mini_v0.1": 22, 25 | "parler_tts_mini_v1": 23, 26 | "suno/bark": 24, 27 | "suno/bark-small": 25, 28 | "tts_models/bg/cv/vits": 26, 29 | "tts_models/bn/custom/vits-female": 27, 30 | "tts_models/bn/custom/vits-male": 28, 31 | "tts_models/cs/cv/vits": 29, 32 | "tts_models/da/cv/vits": 30, 33 | "tts_models/de/css10/vits-neon": 31, 34 | "tts_models/de/thorsten/tacotron2-DCA": 32, 35 | "tts_models/de/thorsten/tacotron2-DDC": 33, 36 | "tts_models/de/thorsten/vits": 34, 37 | "tts_models/el/cv/vits": 35, 38 | "tts_models/en/blizzard2013/capacitron-t2-c50": 36, 39 | "tts_models/en/ek1/tacotron2": 37, 40 | "tts_models/en/jenny/jenny": 38, 41 | "tts_models/en/ljspeech/fast_pitch": 39, 42 | "tts_models/en/ljspeech/glow-tts": 40, 43 | "tts_models/en/ljspeech/neural_hmm": 41, 44 | "tts_models/en/ljspeech/overflow": 42, 45 | "tts_models/en/ljspeech/speedy-speech": 43, 46 | "tts_models/en/ljspeech/tacotron2-DCA": 44, 47 | "tts_models/en/ljspeech/tacotron2-DDC": 45, 48 | "tts_models/en/ljspeech/tacotron2-DDC_ph": 46, 49 | "tts_models/en/ljspeech/vits": 47, 50 | 
"tts_models/en/ljspeech/vits--neon": 48, 51 | "tts_models/en/multi-dataset/tortoise-v2": 49, 52 | "tts_models/en/sam/tacotron-DDC": 50, 53 | "tts_models/es/css10/vits": 51, 54 | "tts_models/es/mai/tacotron2-DDC": 52, 55 | "tts_models/et/cv/vits": 53, 56 | "tts_models/fa/custom/glow-tts": 54, 57 | "tts_models/fi/css10/vits": 55, 58 | "tts_models/fr/css10/vits": 56, 59 | "tts_models/fr/mai/tacotron2-DDC": 57, 60 | "tts_models/ga/cv/vits": 58, 61 | "tts_models/hr/cv/vits": 59, 62 | "tts_models/hu/css10/vits": 60, 63 | "tts_models/it/mai_female/glow-tts": 61, 64 | "tts_models/it/mai_female/vits": 62, 65 | "tts_models/it/mai_male/glow-tts": 63, 66 | "tts_models/it/mai_male/vits": 64, 67 | "tts_models/lt/cv/vits": 65, 68 | "tts_models/lv/cv/vits": 66, 69 | "tts_models/mt/cv/vits": 67, 70 | "tts_models/multilingual/multi-dataset/bark": 68, 71 | "tts_models/multilingual/multi-dataset/xtts_v1.1": 69, 72 | "tts_models/multilingual/multi-dataset/xtts_v2": 70, 73 | "tts_models/pl/mai_female/vits": 71, 74 | "tts_models/pt/cv/vits": 72, 75 | "tts_models/ro/cv/vits": 73, 76 | "tts_models/sk/cv/vits": 74, 77 | "tts_models/sl/cv/vits": 75, 78 | "tts_models/sv/cv/vits": 76, 79 | "tts_models/tr/common-voice/glow-tts": 77, 80 | "tts_models/uk/mai/glow-tts": 78, 81 | "tts_models/uk/mai/vits": 79, 82 | "tts_models/zh-CN/baker/tacotron2-DDC-GST": 80, 83 | "vixTTS": 81 84 | } -------------------------------------------------------------------------------- /configs/config_embeddings.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | mlaad_root_path: data/MLAADv5 3 | protocols_root_path: data/MLAADv5_for_sourcetracing 4 | sampling_rate: 16000 5 | sample_length_s: 4 6 | num_workers: 4 7 | batch_size: 32 8 | 9 | training: 10 | num_epochs: 4 11 | lr: 0.0001 12 | save_path: models/baseline_model 13 | log_interval: 10 14 | seed: 42 15 | 16 | model: 17 | model_name: wav2vec2 18 | checkpoint_path: "" 19 | -------------------------------------------------------------------------------- /configs/config_ge2e.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | mlaad_root_path: data/MLAADv5 3 | protocols_root_path: data/MLAADv5_for_sourcetracing 4 | sampling_rate: 16000 5 | sample_length_s: 4 6 | num_workers: 4 7 | 8 | training: 9 | num_epochs: 4 10 | lr: 0.0001 11 | save_path: models/baseline_model 12 | log_interval: 10 13 | seed: 42 14 | n_classes_in_batch: 10 15 | n_utter_per_class: 4 16 | 17 | model: 18 | model_name: wav2vec2 19 | checkpoint_path: "" 20 | -------------------------------------------------------------------------------- /notebooks/VisualizeEmbeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Visualize Embeddings\n", 8 | "\n", 9 | "The following notebook hels to visualize embeddings using UMAP algorithm.\n", 10 | "Make sure to generate them using the `generate_embeddings.py` script.\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "\"\"\"\n", 20 | "DISCLAIMER:\n", 21 | "This code is provided \"as-is\" without any warranty of any kind, either expressed or implied,\n", 22 | "including but not limited to the implied warranties of merchantability and fitness for a particular purpose.\n", 23 | "The author assumes no liability for any damages or consequences resulting from the use 
of this code.\n", 24 | "Use it at your own risk.\n", 25 | "\n", 26 | "Utility to download and extract all resources needed for the MLAADv5 project.\n", 27 | "\n", 28 | "This script handles the downloading of large files with progress bars, ensures\n", 29 | "caching of already downloaded files, and extracts `.zip` files using 7-Zip.\n", 30 | "\n", 31 | "## Author: Piotr KAWA\n", 32 | "## December 2024\n", 33 | "\"\"\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import json\n", 43 | "from pathlib import Path\n", 44 | "\n", 45 | "import matplotlib.pyplot as plt\n", 46 | "import numpy as np\n", 47 | "import pandas as pd\n", 48 | "import umap" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "embeddings_root_dir = \"../data/embeddings\"" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "def find_samples(embeddings_dir_path: Path) -> list[dict]:\n", 67 | " embeddings_dir_path = Path(embeddings_dir_path)\n", 68 | " samples = []\n", 69 | " for p in embeddings_dir_path.rglob(\"*.npy\"):\n", 70 | " samples.append(\n", 71 | " {\n", 72 | " \"embedding_path\": str(p),\n", 73 | " \"class_id\": p.parent.name,\n", 74 | " }\n", 75 | " )\n", 76 | " return samples" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "embeddings_dir_root = Path(embeddings_root_dir)\n", 86 | "train_subdir_root = embeddings_dir_root / \"train\"\n", 87 | "dev_subdir_root = embeddings_dir_root / \"dev\"\n", 88 | "test_subdir_root = embeddings_dir_root / \"test\"\n", 89 | "\n", 90 | "train_and_dev_samples = pd.DataFrame(\n", 91 | " find_samples(train_subdir_root) + find_samples(dev_subdir_root)\n", 92 | ")\n", 93 | "test_samples = pd.DataFrame(find_samples(test_subdir_root))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "with open(\"../class_mapping.json\") as f:\n", 103 | " class_mapping = json.load(f)\n", 104 | "\n", 105 | "inv_class_mapping = {}\n", 106 | "\n", 107 | "for k, v in class_mapping.items():\n", 108 | " inv_class_mapping[v] = k\n", 109 | "\n", 110 | "train_and_dev_samples[\"class_name\"] = train_and_dev_samples[\"class_id\"].apply(\n", 111 | " lambda x: str(inv_class_mapping[int(x)])\n", 112 | ")\n", 113 | "test_samples[\"class_name\"] = test_samples[\"class_id\"].apply(\n", 114 | " lambda x: str(inv_class_mapping[int(x)])\n", 115 | ")" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "train_and_dev_embeddings = [\n", 125 | " np.load(path) for path in train_and_dev_samples[\"embedding_path\"]\n", 126 | "]\n", 127 | "test_embeddings = [np.load(path) for path in test_samples[\"embedding_path\"]]" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "reducer = umap.UMAP()\n", 137 | "print(\"Fit + transform train and dev embeddings\")\n", 138 | "train_embedding_umap = reducer.fit_transform(train_and_dev_embeddings)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "vscode": { 146 | "languageId": "ruby" 147 
| } 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "plt.figure(figsize=(15, 15))\n", 152 | "for class_name in train_and_dev_samples[\"class_name\"].unique():\n", 153 | " indices = train_and_dev_samples[\"class_name\"] == class_name\n", 154 | " plt.scatter(\n", 155 | " train_embedding_umap[indices, 0],\n", 156 | " train_embedding_umap[indices, 1],\n", 157 | " s=3,\n", 158 | " label=class_name,\n", 159 | " )\n", 160 | "\n", 161 | "plt.title(\"UMAP projection of the train and dev embeddings\")\n", 162 | "plt.legend(markerscale=5, bbox_to_anchor=(1.05, 1), loc=\"upper left\")\n", 163 | "plt.show()" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "print(\"Transforming test embeddings\")\n", 173 | "test_embedding_umap = reducer.transform(test_embeddings)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "plt.figure(figsize=(15, 15))\n", 183 | "for class_name in test_samples[\"class_name\"].unique():\n", 184 | " indices = test_samples[\"class_name\"] == class_name\n", 185 | " plt.scatter(\n", 186 | " test_embedding_umap[indices, 0],\n", 187 | " test_embedding_umap[indices, 1],\n", 188 | " s=3,\n", 189 | " label=class_name,\n", 190 | " )\n", 191 | "\n", 192 | "plt.title(\"UMAP projection of the test embeddings\")\n", 193 | "plt.legend(markerscale=5, bbox_to_anchor=(1.05, 1), loc=\"upper left\")\n", 194 | "plt.show()" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "source-tracing", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.11.10" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audioread==3.0.1 2 | certifi==2024.8.30 3 | cffi==1.17.1 4 | charset-normalizer==3.4.0 5 | decorator==5.1.1 6 | filelock==3.16.1 7 | fsspec==2024.10.0 8 | huggingface-hub==0.26.3 9 | idna==3.10 10 | Jinja2==3.1.4 11 | joblib==1.4.2 12 | jupyter 13 | librosa==0.9.2 14 | llvmlite==0.43.0 15 | MarkupSafe==3.0.2 16 | mpmath==1.3.0 17 | networkx==3.4.2 18 | numba==0.60.0 19 | numpy==2.0.2 20 | nvidia-cublas-cu12==12.4.5.8 21 | nvidia-cuda-cupti-cu12==12.4.127 22 | nvidia-cuda-nvrtc-cu12==12.4.127 23 | nvidia-cuda-runtime-cu12==12.4.127 24 | nvidia-cudnn-cu12==9.1.0.70 25 | nvidia-cufft-cu12==11.2.1.3 26 | nvidia-curand-cu12==10.3.5.147 27 | nvidia-cusolver-cu12==11.6.1.9 28 | nvidia-cusparse-cu12==12.3.1.170 29 | nvidia-nccl-cu12==2.21.5 30 | nvidia-nvjitlink-cu12==12.4.127 31 | nvidia-nvtx-cu12==12.4.127 32 | packaging==24.2 33 | pandas==2.2.3 34 | pathlib==1.0.1 35 | platformdirs==4.3.6 36 | pooch==1.8.2 37 | pycparser==2.22 38 | python-dateutil==2.9.0.post0 39 | pytz==2024.2 40 | PyYAML==6.0.2 41 | regex==2024.11.6 42 | requests==2.32.3 43 | resampy==0.4.3 44 | safetensors==0.4.5 45 | scikit-learn==1.5.2 46 | scipy==1.14.1 47 | six==1.16.0 48 | soundfile==0.12.1 49 | sympy==1.13.1 50 | threadpoolctl==3.5.0 51 | tokenizers==0.20.3 52 | torch==2.5.1 53 | tqdm==4.67.1 54 
| transformers==4.46.3 55 | torchaudio==2.5.0 56 | triton==3.1.0 57 | typing==3.7.4.3 58 | typing_extensions==4.12.2 59 | tzdata==2024.2 60 | umap-learn==0.5.7 61 | urllib3==2.2.3 62 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Scripts 2 | 3 | Steps required for REFD approach (more info [here](../README.md)): 4 | * `preprocess_dataset.py`, 5 | * `ood_detector.py`, 6 | * `get_classification_metrics.py`. 7 | 8 | 9 | The directory contains the following auxiliary scripts: 10 | * `download_resources.py` - required to download Special Sessio dataset, 11 | * `generate_embeddings.py` - generate embeddings for each subset, 12 | 13 | 14 | ## Generate embeddings 15 | 16 | Use your embeddings architecture to generate embeddings for train, dev and test subsets. 17 | 18 | ```bash 19 | python scripts/generate_embeddings.py --config $config_path --embeddings_root_dir $output_emb_dir 20 | ``` 21 | 22 | The generated structure looks like this: 23 | ``` 24 | $output_emb_dir 25 | ├── dev 26 | │ ├── 0 # class name mapped to an ID 27 | │ │ ├── filename_1.npy # embedding of the filename_1 audio file 28 | │ │ ├── ... 29 | │ │ └── filename_n.npy 30 | │ ├── ... 31 | │ └── 1 32 | │ ├── filename_1.npy 33 | │ ├── ... 34 | │ └── filename_n.npy 35 | ├── test 36 | │ ├── 0 37 | │ │ ├── filename_1.npy 38 | │ │ ├── ... 39 | │ │ └── filename_n.npy 40 | │ ├── ... 41 | │ └── 1 42 | │ ├── filename_1.npy 43 | │ ├── ... 44 | │ └── filename_n.npy 45 | └── train 46 | ├── 0 47 | │ ├── filename_1.npy 48 | │ ├── ... 49 | │ └── filename_n.npy 50 | ├── ... 51 | └── 1 52 | ├── filename_1.npy 53 | ├── ... 54 | └── filename_n.npy 55 | ``` 56 | 57 | ## Visualize embeddings 58 | 59 | To visualize embeddings created using `scripts/generate_embeddings.py` script use `notebooks/VisualizeEmbeddings.ipynb`. The notebook creates UMAP visualization of train and test subsets. 60 | 61 | 62 | -------------------------------------------------------------------------------- /scripts/download_resources.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 7 | 8 | Utility to download and extract all resources needed for the MLAADv5 project. 9 | 10 | This script handles the downloading of large files with progress bars, ensures 11 | caching of already downloaded files, and extracts `.zip` files using 7-Zip. 12 | 13 | ## Author: Nicolas MUELLER 14 | ## December 2024 15 | """ 16 | 17 | import sys 18 | from pathlib import Path 19 | 20 | # Enables running the script from root directory 21 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 22 | 23 | import os 24 | import subprocess 25 | 26 | import requests 27 | from tqdm import tqdm 28 | 29 | 30 | def download_file(session, file_url, save_path): 31 | """ 32 | Download a file with a progress bar. 33 | 34 | Parameters: 35 | session (requests.Session): The HTTP session for downloading. 36 | file_url (str): The URL of the file to download. 37 | save_path (str): The local path where the file should be saved. 
38 | 39 | Returns: 40 | None 41 | """ 42 | # Check if the file exists 43 | if os.path.exists(save_path): 44 | print(f"File already exists: {save_path}") 45 | return 46 | 47 | # Get the file size from headers 48 | response = session.head(file_url, allow_redirects=True) 49 | if response.status_code != 200: 50 | print( 51 | f"Failed to fetch headers for: {file_url}, status code: {response.status_code}" 52 | ) 53 | return 54 | 55 | file_size = int(response.headers.get("content-length", 0)) 56 | 57 | # Start downloading the file with a progress bar 58 | response = session.get(file_url, stream=True) 59 | if response.status_code == 200: 60 | with open(save_path, "wb") as file, tqdm( 61 | desc=f"Downloading {os.path.basename(save_path)}", 62 | total=file_size, 63 | unit="B", 64 | unit_scale=True, 65 | unit_divisor=1024, 66 | ) as progress_bar: 67 | for chunk in response.iter_content(chunk_size=8192): 68 | file.write(chunk) 69 | progress_bar.update(len(chunk)) 70 | print(f"Download completed: {save_path}") 71 | else: 72 | print(f"Failed to download: {file_url}, status code: {response.status_code}") 73 | 74 | 75 | def extract_zip_file(zip_path, extract_dir): 76 | """ 77 | Extract a `.zip` file using 7-Zip. 78 | 79 | Parameters: 80 | zip_path (str): The path to the `.zip` file to be extracted. 81 | extract_dir (str): The directory where the files will be extracted. 82 | 83 | Returns: 84 | None 85 | """ 86 | print(f"Extracting {zip_path} to {extract_dir}...") 87 | try: 88 | subprocess.run(["7za", "x", zip_path, f"-o{extract_dir}"], check=True) 89 | print(f"Extraction completed: {zip_path}") 90 | except subprocess.CalledProcessError as e: 91 | print(f"Extraction failed: {e}") 92 | except FileNotFoundError: 93 | print("7-Zip (7za) is not installed. Please install it to enable extraction.") 94 | 95 | 96 | def download_MLAADv5(save_dir): 97 | """ 98 | Download and extract MLAADv5. 99 | Files are saved in the specified directory, and existing files are skipped. 
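    Example (illustrative; "data/MLAADv5" is the repository's default location):
        download_MLAADv5("data/MLAADv5")
    The call only downloads the multi-part archive; checksum verification and
    extraction are left to the user via the commands printed at the end of the run.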
100 | """ 101 | # MLAADv5 dataset file URLs 102 | files_mlaad = [ 103 | f"https://owncloud.fraunhofer.de/index.php/s/tL2Y1FKrWiX4ZtP/download?path=%2Fv5&files=mlaad_v5.z0{i}" 104 | for i in range(1, 10) 105 | ] + [ 106 | f"https://owncloud.fraunhofer.de/index.php/s/tL2Y1FKrWiX4ZtP/download?path=%2Fv5&files=mlaad_v5.z10", 107 | "https://owncloud.fraunhofer.de/index.php/s/tL2Y1FKrWiX4ZtP/download?path=%2Fv5&files=mlaad_v5.zip", 108 | "https://owncloud.fraunhofer.de/index.php/s/tL2Y1FKrWiX4ZtP/download?path=%2Fv5&files=mlaad_v5.zip.md5", 109 | ] 110 | 111 | # Directory to save downloaded files 112 | os.makedirs(save_dir, exist_ok=True) 113 | 114 | # Create a session 115 | session = requests.Session() 116 | 117 | # Download dataset and protocol files 118 | all_files = files_mlaad 119 | for file_url in all_files: 120 | # Extract file name from URL 121 | file_name = os.path.basename(file_url.split("&files=")[-1]) 122 | save_path = os.path.join(save_dir, file_name) 123 | # Download file if not cached 124 | download_file(session, file_url, save_path) 125 | 126 | print(f"Now, go into the directory {save_dir} and perform:") 127 | print(f"md5sum -c mlaad_v5.zip.md5") 128 | print(f"7za x mlaad_v5.zip -o./") 129 | 130 | 131 | if __name__ == "__main__": 132 | # Download MLAADv5 dataset 133 | download_MLAADv5("data/MLAADv5") 134 | 135 | # Protocol file URL 136 | save_protocols_path = "data/MLAADv5_for_sourcetracing/mlaadv5_for_sourcetracing.zip" 137 | os.makedirs(os.path.dirname(save_protocols_path), exist_ok=True) 138 | if not os.path.exists(save_protocols_path): 139 | url_protocols = "https://deepfake-total.com/data/mlaad4sourcetracing.zip" 140 | download_file(requests.Session(), url_protocols, save_protocols_path) 141 | -------------------------------------------------------------------------------- /scripts/generate_embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 
7 | 8 | ## Author: Piotr KAWA 9 | ## December 2024 10 | """ 11 | 12 | import argparse 13 | import sys 14 | from pathlib import Path 15 | 16 | # Enables running the script from root directory 17 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 18 | 19 | import numpy as np 20 | import pandas as pd 21 | import torch 22 | import yaml 23 | from sklearn.preprocessing import LabelEncoder 24 | from torch import nn 25 | from torch.utils.data import DataLoader 26 | from tqdm import tqdm 27 | 28 | from src import models 29 | from src.datasets.dataset import BaseDataset 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser(description="Generate embeddings script") 34 | parser.add_argument( 35 | "--config", 36 | type=str, 37 | default="configs/configs_embeddings.yaml", 38 | required=True, 39 | help="Path to config file", 40 | ) 41 | parser.add_argument( 42 | "--embeddings_root_dir", 43 | type=str, 44 | required=True, 45 | help="Path to the embeddings directory", 46 | ) 47 | parser.add_argument( 48 | "--cpu", 49 | action="store_true", 50 | help="Force the use of CPU even if GPU is available", 51 | ) 52 | return parser.parse_args() 53 | 54 | 55 | def generate_embeddings( 56 | data_loader: DataLoader, model: nn.Module, embeddings_root_dir: Path 57 | ): 58 | embeddings_root_dir = Path(embeddings_root_dir) 59 | 60 | for x, y, paths in tqdm(data_loader, total=len(data_loader)): 61 | x = x.to(device) 62 | with torch.no_grad(): 63 | embeddings = model(x) 64 | 65 | for embedding, label, path in zip(embeddings, y, paths): 66 | cls_dir = embeddings_root_dir / str(label.item()) 67 | cls_dir.mkdir(parents=True, exist_ok=True) 68 | np.save(cls_dir / f"{path.stem}.npy", embedding.cpu().numpy()) 69 | 70 | 71 | if __name__ == "__main__": 72 | args = parse_args() 73 | 74 | with open(args.config, "r") as f: 75 | config = yaml.safe_load(f) 76 | 77 | device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu" 78 | 79 | model = models.get_model( 80 | model_name=config["model"]["model_name"], 81 | checkpoint_path=config["model"]["checkpoint_path"], 82 | ) 83 | model.to(device) 84 | model.eval() 85 | 86 | embeddings_root_dir = Path(args.embeddings_root_dir) 87 | embeddings_root_dir.mkdir(parents=True, exist_ok=True) 88 | 89 | # Data 90 | path_mlaad = config["data"]["mlaad_root_path"] 91 | path_protocols = Path(config["data"]["protocols_root_path"]) 92 | 93 | protocols = { 94 | "train": path_protocols / "train.csv", 95 | "dev": path_protocols / "dev.csv", 96 | "test": path_protocols / "eval.csv", 97 | } 98 | 99 | for subset, protocols_root_path in protocols.items(): 100 | assert protocols_root_path.exists(), f"{protocols_root_path} does not exist" 101 | 102 | dataframes = { 103 | subset: pd.read_csv(protocols_root_path) 104 | for subset, protocols_root_path in protocols.items() 105 | } 106 | 107 | # Concat all datasets to transform model names into model ids and as new column do each df 108 | all_df = pd.concat(dataframes.values()) 109 | le = LabelEncoder() 110 | le.fit(all_df["model_name"]) 111 | class_mapping = {name: idx for idx, name in enumerate(le.classes_)} 112 | 113 | for subset, df in dataframes.items(): 114 | df["model_id"] = le.transform(df["model_name"]) 115 | 116 | for subset, df in dataframes.items(): 117 | print(f"Generating '{subset}' subset embeddings") 118 | dataset = BaseDataset( 119 | basepath=path_mlaad, 120 | sr=config["data"]["sampling_rate"], 121 | sample_length_s=config["data"]["sample_length_s"], 122 | meta_data=df.to_dict(orient="records"), 123 | 
class_mapping=class_mapping, 124 | ) 125 | data_loader = DataLoader( 126 | dataset, 127 | batch_size=config["data"]["batch_size"], 128 | collate_fn=dataset.collate_fn, 129 | shuffle=True, 130 | num_workers=config["data"]["num_workers"], 131 | ) 132 | generate_embeddings( 133 | data_loader=data_loader, 134 | model=model, 135 | embeddings_root_dir=embeddings_root_dir / subset, 136 | ) 137 | -------------------------------------------------------------------------------- /scripts/get_classification_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 7 | 8 | ## Author: Adriana STAN 9 | ## December 2024 10 | """ 11 | 12 | import argparse 13 | import os 14 | import sys 15 | from pathlib import Path 16 | 17 | # Enables running the script from root directory 18 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 19 | 20 | import numpy as np 21 | import torch 22 | import torch.nn.functional as F 23 | from sklearn.metrics import classification_report 24 | from torch.utils.data import DataLoader 25 | from tqdm import tqdm 26 | 27 | from src.datasets.dataset import MLAADFDDataset 28 | from src.models.w2v2_aasist import W2VAASIST 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser(description="Get metrics") 33 | parser.add_argument( 34 | "--model_path", 35 | type=str, 36 | default="exp/trained_models/anti-spoofing_feat_model.pt", 37 | help="Path to trained model", 38 | ) 39 | parser.add_argument( 40 | "--path_to_features", 41 | type=str, 42 | default="exp/preprocess_wav2vec2-base/", 43 | help="Path to features", 44 | ) 45 | parser.add_argument( 46 | "--results_path", 47 | type=str, 48 | default="exp/results/", 49 | help="Where to write the results", 50 | ) 51 | 52 | parser.add_argument( 53 | "--batch_size", type=int, default=128, help="Batch size for inference" 54 | ) 55 | parser.add_argument( 56 | "--feat_dim", 57 | type=int, 58 | default=768, 59 | help="Feature dimension of wav2vec features", 60 | ) 61 | parser.add_argument( 62 | "--num_classes", 63 | type=int, 64 | default=24, 65 | help="Number of systems in the training dataset", 66 | ) 67 | args = parser.parse_args() 68 | if not os.path.exists((args.results_path)): 69 | os.makedirs(args.results_path) 70 | return args 71 | 72 | 73 | def main(args): 74 | device = "cuda" if torch.cuda.is_available() else "cpu" 75 | print(f"Running on {device}..") 76 | 77 | # Read the data 78 | dev_dataset = MLAADFDDataset( 79 | args.path_to_features, "dev", mode="known", max_samples=-1 80 | ) 81 | dev_loader = DataLoader(dev_dataset, batch_size=args.batch_size, num_workers=0) 82 | 83 | eval_dataset = MLAADFDDataset( 84 | args.path_to_features, "eval", mode="known", max_samples=-1 85 | ) 86 | eval_loader = DataLoader(eval_dataset, batch_size=args.batch_size, num_workers=0) 87 | 88 | if len(eval_dataset) == 0: 89 | print("No data found for evaluation! 
Exiting...") 90 | exit(1) 91 | 92 | print(f"Loading model from {args.model_path}") 93 | model = W2VAASIST(args.feat_dim, args.num_classes) 94 | state_dict = torch.load(args.model_path) 95 | model.load_state_dict(state_dict) 96 | model.to(device) 97 | model.eval() 98 | 99 | print("Running on dev data...") 100 | with torch.no_grad(): 101 | all_predicted = np.zeros(len(dev_dataset), dtype=int) 102 | all_labels = np.zeros(len(dev_dataset), dtype=int) 103 | 104 | dev_bar = tqdm(dev_loader, desc=f"Evaluation") 105 | for idx, batch in enumerate(dev_bar): 106 | sample_number = idx * args.batch_size 107 | feats, filename, labels = batch 108 | feats = feats.transpose(2, 3).to(device) 109 | _, logits = model(feats) 110 | logits = F.softmax(logits, dim=1) 111 | predicted = torch.argmax(logits, dim=1).detach().cpu().numpy() 112 | all_predicted[sample_number : sample_number + labels.shape[0]] = predicted 113 | all_labels[sample_number : sample_number + labels.shape[0]] = labels 114 | 115 | print("Classification report for DEV data: ") 116 | report_path = os.path.join(args.results_path, "dev_in_domain_results.txt") 117 | report = classification_report( 118 | all_labels, all_predicted, labels=np.unique(all_labels), zero_division=1.0 119 | ) 120 | with open(report_path, "w") as f: 121 | f.write(report) 122 | print(report) 123 | print(f"... also written to {report_path}") 124 | 125 | print("Running on evaluation data...") 126 | with torch.no_grad(): 127 | all_predicted = np.zeros(len(eval_dataset), dtype=int) 128 | all_labels = np.zeros(len(eval_dataset), dtype=int) 129 | 130 | eval_bar = tqdm(eval_loader, desc=f"Evaluation") 131 | for idx, batch in enumerate(eval_bar): 132 | sample_number = idx * args.batch_size 133 | feats, filename, labels = batch 134 | feats = feats.transpose(2, 3).to(device) 135 | _, logits = model(feats) 136 | logits = F.softmax(logits, dim=1) 137 | predicted = torch.argmax(logits, dim=1).detach().cpu().numpy() 138 | all_predicted[sample_number : sample_number + labels.shape[0]] = predicted 139 | all_labels[sample_number : sample_number + labels.shape[0]] = labels 140 | 141 | print("Classification report for EVAL data:") 142 | report_path = os.path.join(args.results_path, "eval_in_domain_results.txt") 143 | report = classification_report( 144 | all_labels, all_predicted, labels=np.unique(all_labels), zero_division=1.0 145 | ) 146 | with open(report_path, "w") as f: 147 | f.write(report) 148 | print(report) 149 | print(f"... also written to {report_path}") 150 | 151 | 152 | if __name__ == "__main__": 153 | args = parse_args() 154 | main(args) 155 | -------------------------------------------------------------------------------- /scripts/ood_detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 
7 | 8 | ## Author: Adriana STAN 9 | ## December 2024 10 | """ 11 | 12 | import argparse 13 | import os 14 | import sys 15 | from pathlib import Path 16 | 17 | # Enables running the script from root directory 18 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 19 | import numpy as np 20 | import torch 21 | from sklearn.metrics import classification_report, roc_curve 22 | from torch.utils.data import DataLoader 23 | from tqdm import tqdm 24 | 25 | from src.models.w2v2_aasist import W2VAASIST 26 | from src.datasets.dataset import MLAADFDDataset 27 | from src.models.NSD import NSDOODDetector 28 | 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser("OOD Detector script") 32 | # Paths 33 | parser.add_argument( 34 | "--model_path", 35 | type=str, 36 | default="exp/trained_models/anti-spoofing_feat_model.pt", 37 | help="Path to trained model", 38 | ) 39 | parser.add_argument( 40 | "--feature_path", 41 | type=str, 42 | default="exp/preprocess_wav2vec2-base/", 43 | help="Path to features", 44 | ) 45 | parser.add_argument( 46 | "--out_folder", type=str, default="exp/ood_step/", help="Path to output results" 47 | ) 48 | parser.add_argument( 49 | "--label_assignment_path", 50 | type=str, 51 | default="exp/label_assignment.txt", 52 | help="Path to the file which lists the class assignments as written in the preprocessing step", 53 | ) 54 | # Hyperparameters 55 | parser.add_argument("--batch_size", type=int, default=128, help="Batch_size") 56 | parser.add_argument("--feat_dim", type=int, default=768, help="Feature dimension") 57 | parser.add_argument("--hidden_dim", type=int, default=160, help="Hidden size dim") 58 | parser.add_argument( 59 | "--num_classes", type=int, default=24, help="Number of known systems" 60 | ) 61 | parser.add_argument( 62 | "--feature_extraction_step", 63 | action="store_true", 64 | help="Whether to run the feature extraction step or just the OOD", 65 | ) 66 | 67 | args = parser.parse_args() 68 | if not os.path.exists(args.out_folder): 69 | os.makedirs(args.out_folder) 70 | 71 | return args 72 | 73 | 74 | def compute_eer(labels, scores): 75 | fpr, tpr, thresholds = roc_curve(labels, scores) 76 | eer_index = np.nanargmin(np.abs(fpr - (1 - tpr))) 77 | return fpr[eer_index], thresholds[eer_index] 78 | 79 | 80 | def main(args): 81 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 82 | print(f"Loading model from {args.model_path}") 83 | model = W2VAASIST(args.feat_dim, args.num_classes) 84 | state_dict = torch.load(args.model_path) 85 | model.load_state_dict(state_dict) 86 | model.to(device) 87 | model.eval() 88 | 89 | if args.feature_extraction_step: 90 | # Loading datasets 91 | train_dataset = MLAADFDDataset(args.feature_path, "train") 92 | train_loader = DataLoader( 93 | train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 94 | ) 95 | 96 | dev_dataset = MLAADFDDataset(args.feature_path, "dev") 97 | dev_loader = DataLoader( 98 | dev_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 99 | ) 100 | 101 | eval_dataset = MLAADFDDataset(args.feature_path, "eval") 102 | eval_loader = DataLoader( 103 | eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 104 | ) 105 | 106 | # Extract logits and hidden states from trained model 107 | for subset_, loader in zip( 108 | ["train", "dev", "eval"], [train_loader, dev_loader, eval_loader] 109 | ): 110 | print(f"Running hidden feature extraction for {subset_}") 111 | all_feats = np.zeros((len(loader) * args.batch_size, args.hidden_dim)) 112 | 
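            # NOTE: these buffers are allocated for len(loader) * batch_size rows; if the
            # final batch is smaller than batch_size, the trailing rows simply remain zero.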
all_logits = np.zeros((len(loader) * args.batch_size, args.num_classes)) 113 | all_labels = np.zeros(len(loader) * args.batch_size) 114 | 115 | for idx, batch in enumerate(tqdm(loader)): 116 | sample_num = idx * args.batch_size 117 | feats, filename, labels = batch 118 | feats = feats.transpose(2, 3).to(device) 119 | with torch.no_grad(): 120 | hidden_state, logits = model(feats) 121 | # Store all info 122 | all_feats[sample_num : sample_num + feats.shape[0]] = ( 123 | hidden_state.detach().cpu().numpy() 124 | ) 125 | all_logits[sample_num : sample_num + feats.shape[0]] = ( 126 | logits.detach().cpu().numpy() 127 | ) 128 | all_labels[sample_num : sample_num + feats.shape[0]] = labels 129 | # Save the info 130 | out_path = os.path.join(args.out_folder, f"{subset_}_dict.npy") 131 | np.save( 132 | out_path, 133 | {"feats": all_feats, "logits": all_logits, "labels": all_labels}, 134 | ) 135 | print(f"Saved hidden_states to {out_path}") 136 | 137 | train_dict = np.load( 138 | os.path.join(args.out_folder, "train_dict.npy"), allow_pickle=True 139 | ).item() 140 | dev_dict = np.load( 141 | os.path.join(args.out_folder, "dev_dict.npy"), allow_pickle=True 142 | ).item() 143 | eval_dict = np.load( 144 | os.path.join(args.out_folder, "eval_dict.npy"), allow_pickle=True 145 | ).item() 146 | 147 | print("Setting up the OOD detector using the training data...") 148 | ood_detector = NSDOODDetector() 149 | ood_detector.setup(args, train_dict) 150 | 151 | # Get scores for OOD 152 | print("Getting OOD scores for the dev set...") 153 | dev_scores = ood_detector.infer(dev_dict) 154 | 155 | # Get the systems' labels assigned to OOD samples 156 | # Convert the system numbers into classes: OOD=1 and KNOWN=0 157 | with open(args.label_assignment_path) as f: 158 | OOD_classes = [ 159 | int(line.split("|")[1]) 160 | for line in f.readlines() 161 | if line.strip().split("|")[2] == "OOD" 162 | ] 163 | dev_ood_labels = [ 164 | 1 if int(dev_dict["labels"][k]) in OOD_classes else 0 165 | for k in range(len(dev_dict["labels"])) 166 | ] 167 | 168 | # Compute a EER threshold over the dev scores 169 | print("\nComputing the EER threshold over the development set...") 170 | eer, threshold = compute_eer(dev_ood_labels, dev_scores) 171 | print(f"DEV EER: {eer*100:.2f} | Threshold: {threshold:.2f}") 172 | 173 | # Set the threshold and compute the OOD accuracy over the eval set 174 | print("\nComputing the evaluation results using the dev threshold...") 175 | print("Class 1 is OOD, Class 0 is ID") 176 | eval_scores = ood_detector.infer(eval_dict) 177 | eval_ood_labels = [ 178 | 1 if int(eval_dict["labels"][k]) in OOD_classes else 0 179 | for k in range(len(eval_dict["labels"])) 180 | ] 181 | predicts = [ 182 | 1 if eval_scores[k] > threshold else 0 for k in range(len(eval_dict["labels"])) 183 | ] 184 | 185 | print("OOD classification report for eval data:") 186 | report = classification_report(eval_ood_labels, predicts) 187 | report_path = os.path.join(args.out_folder, "OOD_eval_results.txt") 188 | with open(report_path, "w") as f: 189 | f.write(report) 190 | print(report) 191 | print(f"... 
also written to {report_path}") 192 | 193 | 194 | if __name__ == "__main__": 195 | args = parse_args() 196 | main(args) 197 | -------------------------------------------------------------------------------- /scripts/preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 7 | 8 | ## Author: Adriana STAN 9 | ## December 2024 10 | """ 11 | 12 | import argparse 13 | import os 14 | import sys 15 | from pathlib import Path 16 | 17 | # Enables running the script from root directory 18 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 19 | import pandas as pd 20 | import torch 21 | from sklearn.preprocessing import LabelEncoder 22 | from torch.utils.data import DataLoader 23 | from tqdm import tqdm 24 | 25 | from src.datasets.dataset import MLAADBaseDataset 26 | from src.datasets.utils import HuggingFaceFeatureExtractor, WaveformEmphasiser 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description="Data augmentation script") 31 | # Datasets and protocols 32 | parser.add_argument( 33 | "--mlaad_path", 34 | type=str, 35 | default="data/MLAAD_v5/", 36 | help="Path to MLAADv5 dataset", 37 | ) 38 | parser.add_argument( 39 | "--protocol_path", 40 | type=str, 41 | default="data/MLAADv5_for_sourcetracing/", 42 | help="Path to MLAADv5 protocols", 43 | ) 44 | parser.add_argument( 45 | "--musan_path", 46 | type=str, 47 | default="data/musan/", 48 | help="Path to the MUSAN dataset", 49 | ) 50 | parser.add_argument( 51 | "--rir_path", 52 | type=str, 53 | default="data/RIRS_NOISES/", 54 | help="Path to RIRs dataset", 55 | ) 56 | 57 | # HuggingFace feature extractor 58 | parser.add_argument( 59 | "--model_name", 60 | type=str, 61 | default="wav2vec2-base", 62 | help="name of the feature extractor", 63 | ) 64 | parser.add_argument( 65 | "--model_class", 66 | type=str, 67 | default="Wav2Vec2Model", 68 | help="Class of the feature extractor", 69 | ) 70 | parser.add_argument( 71 | "--model_layer", 72 | type=int, 73 | default=5, 74 | help="Which layer to use from the feature extractor", 75 | ) 76 | parser.add_argument( 77 | "--hugging_face_path", 78 | type=str, 79 | default="facebook/wav2vec2-base", 80 | help="Path from the HF collections", 81 | ) 82 | parser.add_argument( 83 | "--sampling_rate", type=int, default=16_000, help="Audio sampling rate" 84 | ) 85 | parser.add_argument( 86 | "--max_length", type=int, default=4, help="Crop the audio to X seconds" 87 | ) 88 | parser.add_argument( 89 | "--batch_size", type=int, default=1, help="Batch size for preprocessing" 90 | ) 91 | parser.add_argument( 92 | "--num_workers", type=int, default=0, help="Workers for loaders" 93 | ) 94 | 95 | # Output folder 96 | parser.add_argument( 97 | "--out_folder", type=str, default="exp", help="Where to write the results" 98 | ) 99 | args = parser.parse_args() 100 | if not os.path.exists(args.out_folder): 101 | os.makedirs(args.out_folder) 102 | return args 103 | 104 | 105 | def main(args): 106 | 107 | # Read the MLAAD data 108 | path_mlaad = args.mlaad_path 109 | path_protocols = args.protocol_path 110 | train_protocol = os.path.join(path_protocols, "train.csv") 111 | dev_protocol = 
os.path.join(path_protocols, "dev.csv") 112 | test_protocol = os.path.join(path_protocols, "eval.csv") 113 | assert os.path.exists(train_protocol), f"{train_protocol} does not exist" 114 | assert os.path.exists(dev_protocol), f"{dev_protocol} does not exist" 115 | assert os.path.exists(test_protocol), f"{test_protocol} does not exist" 116 | train_df = pd.read_csv(train_protocol) 117 | dev_df = pd.read_csv(dev_protocol) 118 | test_df = pd.read_csv(test_protocol) 119 | 120 | # Encode the system names to unique int values 121 | # Use only the training data classes. The others are OOD 122 | le = LabelEncoder() 123 | le.fit(train_df["model_name"]) 124 | train_df["model_id"] = le.transform(train_df["model_name"]) 125 | class_mapping = {name: [idx, "ID"] for idx, name in enumerate(le.classes_)} 126 | 127 | # Add a OOD label for unseen systems in the training data 128 | for k in pd.concat([dev_df["model_name"], test_df["model_name"]]): 129 | if k not in class_mapping: 130 | class_mapping[k] = [len(class_mapping), "OOD"] 131 | 132 | # Save the label assignment 133 | with open(os.path.join(args.out_folder, "label_assignment.txt"), "w") as fout: 134 | for k, v in sorted( 135 | class_mapping.items(), key=lambda item: (item[1], item[0].lower()) 136 | ): 137 | fout.write(f"{k.ljust(50)}|{str(v[0]).ljust(3)}|{v[1]}\n") 138 | print( 139 | f"[INFO] Label assignment written to: {args.out_folder}/label_assignment.txt" 140 | ) 141 | 142 | # Prepare dataloaders 143 | train_data = MLAADBaseDataset( 144 | basepath=path_mlaad, 145 | sr=args.sampling_rate, 146 | sample_length_s=args.max_length, 147 | meta_data=train_df.to_dict(orient="records"), 148 | class_mapping=class_mapping, 149 | max_samples=-1, 150 | ) 151 | train_loader = DataLoader( 152 | train_data, 153 | batch_size=args.batch_size, 154 | collate_fn=train_data.collate_fn, 155 | shuffle=False, 156 | num_workers=args.num_workers, 157 | ) 158 | 159 | dev_data = MLAADBaseDataset( 160 | basepath=path_mlaad, 161 | sr=args.sampling_rate, 162 | sample_length_s=args.max_length, 163 | meta_data=dev_df.to_dict(orient="records"), 164 | class_mapping=class_mapping, 165 | max_samples=-1, 166 | ) 167 | dev_loader = DataLoader( 168 | dev_data, 169 | batch_size=args.batch_size, 170 | collate_fn=train_data.collate_fn, 171 | shuffle=False, 172 | num_workers=args.num_workers, 173 | ) 174 | 175 | test_data = MLAADBaseDataset( 176 | basepath=path_mlaad, 177 | sr=args.sampling_rate, 178 | sample_length_s=args.max_length, 179 | meta_data=test_df.to_dict(orient="records"), 180 | class_mapping=class_mapping, 181 | max_samples=-1, 182 | ) 183 | test_loader = DataLoader( 184 | test_data, 185 | batch_size=args.batch_size, 186 | collate_fn=train_data.collate_fn, 187 | shuffle=False, 188 | num_workers=args.num_workers, 189 | ) 190 | 191 | # Load the feature extractor 192 | feature_extractor = HuggingFaceFeatureExtractor( 193 | model_class_name=args.model_class, 194 | layer=args.model_layer, 195 | name=args.hugging_face_path, 196 | ) 197 | 198 | ## Run the augmentation 199 | list_of_emphases = ["original", "reverb", "speech", "music", "noise"] 200 | emphasiser = WaveformEmphasiser(args.sampling_rate, args.musan_path, args.rir_path) 201 | for subset_, loader in zip( 202 | ["train", "dev", "eval"], [train_loader, dev_loader, test_loader] 203 | ): 204 | count = 0 205 | feature_folder = os.path.join(args.out_folder, "preprocess_" + args.model_name) 206 | target_dir = os.path.join(feature_folder, subset_) 207 | if not os.path.exists(target_dir): 208 | os.makedirs(target_dir) 209 | 
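        # Each utterance below is saved once per emphasis condition (original, reverb,
        # speech, music, noise), so every subset yields five feature files per audio file.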
print(f"[INFO] Processing {subset_} data...") 210 | print(f"[INFO] Writing features to {target_dir}") 211 | for waveform, label, file_name in tqdm(loader): 212 | for emphasis in list_of_emphases: 213 | waveform = emphasiser(waveform, emphasis) 214 | hidden_state = feature_extractor(waveform, args.sampling_rate) 215 | 216 | # Create a unique filename which also includes the class id 217 | # i.e. 000001_class_emphasisType_originalFileName.pt 218 | orig_file_name = os.path.splitext(os.path.split(file_name[0])[1])[0] 219 | out_file_name = ( 220 | f"{count:06d}_{label.item()}_{emphasis}_{orig_file_name}.pt" 221 | ) 222 | torch.save( 223 | hidden_state.float(), os.path.join(target_dir, out_file_name) 224 | ) 225 | count += 1 226 | print("[INFO] Augmentation step finished") 227 | 228 | 229 | if __name__ == "__main__": 230 | args = parse_args() 231 | main(args) 232 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piotrkawa/audio-deepfake-source-tracing/c8292bbef4756e723480eb319fa9ff6a4e25b94e/src/datasets/__init__.py -------------------------------------------------------------------------------- /src/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 7 | 8 | ## Authors: Piotr KAWA, Adriana STAN 9 | ## December 2024 10 | """ 11 | 12 | import os 13 | import random 14 | from collections import defaultdict 15 | from pathlib import Path 16 | 17 | import librosa 18 | import numpy as np 19 | import torch 20 | from torch.utils.data import Dataset 21 | 22 | 23 | class BaseDataset(Dataset): 24 | def __init__( 25 | self, 26 | meta_data: dict, 27 | basepath: str, 28 | class_mapping: dict, 29 | num_utter_per_class: int = 10, 30 | sr: int = 16_000, 31 | sample_length_s: float = 4, 32 | verbose: bool = False, 33 | ): 34 | super().__init__() 35 | self.class_mapping = class_mapping 36 | self.basepath = Path(basepath) 37 | self.sr = sr 38 | self.sample_length_s = sample_length_s 39 | self.num_utter_per_class = num_utter_per_class 40 | self.verbose = verbose 41 | self.samples, self.classes_in_subset = self._parse_samples(meta_data) 42 | 43 | if self.verbose: 44 | self._print_initialization_info() 45 | 46 | def _print_initialization_info(self): 47 | print("\n > Dataset initialization") 48 | print(f" | > Number of instances: {len(self.samples)}") 49 | print(f" | > Sequence length: {self.sample_length_s} s") 50 | print(f" | > Sampling rate: {self.sr}") 51 | print(f" | > Num Classes: {len(self.classes_in_subset)}") 52 | print(f" | > Classes: {list(self.classes_in_subset)}") 53 | 54 | def load_wav(self, file_path: str) -> np.ndarray: 55 | audio, sr = librosa.load(file_path, sr=None) 56 | if sr != self.sr: 57 | audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sr) 58 | return audio 59 | 60 | def _parse_samples(self, meta_data): 61 | class_to_utters = defaultdict(list) 62 | for sample in meta_data: 63 | path_ = self.basepath / sample["path"] 64 | assert path_.exists(), f"File does not exist: {path_}" 65 | class_name = 
sample["model_id"] 66 | class_to_utters[class_name].append(path_) 67 | 68 | inv_class_mapping = {v: k for k, v in self.class_mapping.items()} 69 | 70 | # skip classes with number of samples >= self.num_utter_per_class 71 | class_to_utters = { 72 | inv_class_mapping[k]: v 73 | for (k, v) in class_to_utters.items() 74 | if len(v) >= self.num_utter_per_class 75 | } 76 | 77 | classes = list(class_to_utters.keys()) 78 | classes.sort() 79 | 80 | new_items = [] 81 | for sample in meta_data: 82 | path_ = self.basepath / sample["path"] 83 | class_name = sample["model_name"] 84 | new_items.append( 85 | { 86 | "wav_file_path": path_, 87 | "class_name": class_name, 88 | "class_id": self.class_mapping[class_name], 89 | } 90 | ) 91 | 92 | return new_items, classes 93 | 94 | def __len__(self): 95 | return len(self.samples) 96 | 97 | def get_num_classes(self): 98 | return len(self.classes_in_subset) 99 | 100 | def get_class_list(self): 101 | return list(self.classes_in_subset) 102 | 103 | def __getitem__(self, idx: int) -> dict: 104 | return self.samples[idx] 105 | 106 | def collate_fn(self, batch: torch.Tensor, return_metadata: bool = True) -> tuple: 107 | labels, feats, paths = [], [], [] 108 | target_length = int(self.sample_length_s * self.sr) 109 | 110 | for item in batch: 111 | wav = self.load_wav(item["wav_file_path"]) 112 | wav = self._process_wav(wav, target_length) 113 | feats.append(torch.from_numpy(wav).unsqueeze(0).float()) 114 | labels.append(item["class_id"]) 115 | paths.append(item["wav_file_path"]) 116 | 117 | feats_tensor = torch.stack(feats) 118 | labels_tensor = torch.LongTensor(labels) 119 | return ( 120 | (feats_tensor, labels_tensor, paths) 121 | if return_metadata 122 | else (feats_tensor, labels_tensor) 123 | ) 124 | 125 | def _process_wav(self, wav: np.ndarray, target_length: int) -> np.ndarray: 126 | if wav.shape[0] >= target_length: 127 | offset = random.randint(0, wav.shape[0] - target_length) 128 | wav = wav[offset : offset + target_length] 129 | else: 130 | wav = np.pad(wav, (0, target_length - wav.shape[0]), mode="wrap") 131 | return wav 132 | 133 | 134 | class MLAADBaseDataset(Dataset): 135 | def __init__( 136 | self, 137 | meta_data: dict, 138 | basepath: str, 139 | class_mapping: dict, 140 | sr: int = 16_000, 141 | sample_length_s: float = 4, 142 | max_samples=-1, 143 | verbose: bool = True, 144 | ): 145 | super().__init__() 146 | self.class_mapping = {k: v[0] for k, v in class_mapping.items()} 147 | self.items = meta_data 148 | self.sample_length_s = sample_length_s 149 | self.basepath = basepath 150 | self.sr = sr 151 | self.verbose = verbose 152 | self.classes, self.items = self._parse_items() 153 | 154 | # [TEMP] limit the number of samples per class for testing 155 | if max_samples > 0: 156 | counts = {k: 0 for k in self.classes} 157 | new_items = [] 158 | for k in range(len(self.items)): 159 | if counts[self.items[k]["class_id"]] < max_samples: 160 | new_items.append(self.items[k]) 161 | counts[self.items[k]["class_id"]] += 1 162 | 163 | self.items = new_items 164 | 165 | if self.verbose: 166 | self._print_initialization_info() 167 | 168 | def _print_initialization_info(self): 169 | print("\n > DataLoader initialization") 170 | print(f" | > Number of instances : {len(self.items)}") 171 | print(f" | > Max sequence length: {self.sample_length_s} seconds") 172 | print(f" | > Num Classes: {len(self.classes)}") 173 | print(f" | > Classes: {self.classes}") 174 | 175 | def load_wav(self, file_path: str) -> np.ndarray: 176 | audio, sr = librosa.load(file_path, sr=None) 177 
| if sr != self.sr: 178 | audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sr) 179 | return audio 180 | 181 | def _parse_items(self): 182 | class_to_utters = defaultdict(list) 183 | for item in self.items: 184 | path = Path(self.basepath) / item["path"] 185 | assert os.path.exists(path), f"File does not exist: {path}" 186 | class_id = self.class_mapping[item["model_name"]] 187 | class_to_utters[class_id].append(path) 188 | 189 | classes = sorted(class_to_utters.keys()) 190 | new_items = [ 191 | { 192 | "wav_file_path": Path(self.basepath) / item["path"], 193 | "class_id": self.class_mapping[item["model_name"]], 194 | } 195 | for item in self.items 196 | ] 197 | return classes, new_items 198 | 199 | def __len__(self): 200 | return len(self.items) 201 | 202 | def get_num_classes(self): 203 | return len(self.classes) 204 | 205 | def get_class_list(self): 206 | return self.classes 207 | 208 | def __getitem__(self, idx: int) -> dict: 209 | return self.items[idx] 210 | 211 | def collate_fn(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 212 | labels, feats, files = [], [], [] 213 | target_length = int(self.sample_length_s * self.sr) 214 | 215 | for item in batch: 216 | utter_path = item["wav_file_path"] 217 | class_id = item["class_id"] 218 | wav = self.load_wav(utter_path) 219 | wav = self._process_wav(wav, target_length) 220 | feats.append(torch.from_numpy(wav).unsqueeze(0).float()) 221 | labels.append(class_id) 222 | files.append(item["wav_file_path"]) 223 | return torch.stack(feats), torch.LongTensor(labels), files 224 | 225 | def _process_wav(self, wav: np.ndarray, target_length: int) -> np.ndarray: 226 | if wav.shape[0] >= target_length: 227 | offset = random.randint(0, wav.shape[0] - target_length) 228 | wav = wav[offset : offset + target_length] 229 | else: 230 | wav = np.pad(wav, (0, max(0, target_length - wav.shape[0])), mode="wrap") 231 | return wav 232 | 233 | 234 | class MLAADFDDataset(Dataset): 235 | def __init__(self, path_to_features, part="train", mode="train", max_samples=-1): 236 | super().__init__() 237 | self.path_to_features = path_to_features 238 | self.part = part 239 | self.ptf = os.path.join(path_to_features, self.part) 240 | self.all_files = librosa.util.find_files(self.ptf, ext="pt") 241 | if mode == "known": 242 | # keep only known classes seen during training for F1 metrics 243 | self.all_files = [ 244 | x for x in self.all_files if int(os.path.basename(x).split("_")[1]) < 24 245 | ] 246 | 247 | if max_samples > 0: 248 | self.all_files = self.all_files[:max_samples] 249 | 250 | # Determine the set of labels 251 | self.labels = sorted( 252 | set([int(os.path.split(x)[1].split("_")[1]) for x in self.all_files]) 253 | ) 254 | self._print_info() 255 | 256 | def _print_info(self): 257 | print(f"Searching for features in folder: {self.ptf}") 258 | print(f"Found {len(self.all_files)} files...") 259 | print(f"Using {len(self.labels)} classes\n") 260 | print( 261 | "Seen classes: ", 262 | set([int(os.path.basename(x).split("_")[1]) for x in self.all_files]), 263 | ) 264 | 265 | def __len__(self): 266 | return len(self.all_files) 267 | 268 | def __getitem__(self, idx): 269 | filepath = self.all_files[idx] 270 | basename = os.path.basename(filepath) 271 | all_info = basename.split("_") 272 | 273 | feature_tensor = torch.load(filepath) 274 | filename = "_".join(all_info[2:-1]) 275 | label = int(all_info[1]) 276 | 277 | return feature_tensor, filename, label 278 | -------------------------------------------------------------------------------- 
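The dataset classes above are driven entirely by the protocol CSVs and by the class mapping built in `scripts/preprocess_dataset.py`. The snippet below is a minimal, self-contained usage sketch for `MLAADBaseDataset` and is not part of the repository: the two wav files, the model names and the `tmp_demo` folder are placeholders generated on the fly purely for illustration.

```python
# Illustrative sketch only: writes two tiny placeholder wav files so that the
# dataset's existence checks pass; the model names and paths are invented.
import os

import numpy as np
from scipy.io import wavfile
from torch.utils.data import DataLoader

from src.datasets.dataset import MLAADBaseDataset

sr = 16_000
for rel in ["model_a/utt1.wav", "model_b/utt2.wav"]:
    os.makedirs(os.path.join("tmp_demo", os.path.dirname(rel)), exist_ok=True)
    tone = (0.1 * np.sin(2 * np.pi * 440 * np.arange(sr) / sr)).astype(np.float32)
    wavfile.write(os.path.join("tmp_demo", rel), sr, tone)

# class_mapping mirrors the {model_name: [class_id, "ID"/"OOD"]} structure produced by
# scripts/preprocess_dataset.py; meta_data rows mirror the protocol CSV columns.
class_mapping = {"model_a": [0, "ID"], "model_b": [1, "OOD"]}
meta_data = [
    {"path": "model_a/utt1.wav", "model_name": "model_a"},
    {"path": "model_b/utt2.wav", "model_name": "model_b"},
]

dataset = MLAADBaseDataset(
    meta_data=meta_data,
    basepath="tmp_demo",
    class_mapping=class_mapping,
    sr=sr,
    sample_length_s=4,
)
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)
waveforms, labels, paths = next(iter(loader))
print(waveforms.shape, labels)  # torch.Size([2, 1, 64000]) tensor([0, 1])
```

In `collate_fn`, clips shorter than `sample_length_s` are wrap-padded and longer clips are randomly cropped, so every batch has a fixed `(batch, 1, sr * sample_length_s)` shape regardless of the original file durations.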
/src/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 7 | 8 | ## Authors: Piotr KAWA 9 | ## December 2024 10 | ## Code adapted from coqui.ai - https://github.com/coqui-ai/TTS/blob/dev/TTS/utils/samplers.py 11 | """ 12 | 13 | 14 | import random 15 | from collections import defaultdict 16 | 17 | from torch.utils.data.sampler import Sampler, SubsetRandomSampler 18 | 19 | 20 | class PerfectBatchSampler(Sampler): 21 | """ 22 | Samples mini-batches of indices with balanced class representation 23 | 24 | Args: 25 | dataset_items(list): dataset items to sample from. 26 | classes (list): list of classes of dataset_items to sample from. 27 | batch_size (int): total number of samples to be sampled in a mini-batch. 28 | num_classes_in_batch (int): number of distinct classes drawn into each mini-batch. 29 | num_gpus (int): number of GPUs in the data parallel mode. 30 | drop_last (bool): if True, drops last incomplete batch. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | dataset_items, 36 | classes: list, 37 | batch_size: int, 38 | num_classes_in_batch: int, 39 | num_gpus: int = 1, 40 | drop_last: bool = False, 41 | label_key: str = "class_name", 42 | ): 43 | super().__init__(dataset_items) 44 | assert ( 45 | batch_size % (num_classes_in_batch * num_gpus) == 0 46 | ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
47 | 48 | label_indices = defaultdict(list) 49 | for idx, item in enumerate(dataset_items): 50 | label = item[label_key] 51 | label_indices[label].append(idx) 52 | 53 | self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes] 54 | 55 | self._batch_size = batch_size 56 | self._drop_last = drop_last 57 | self._dp_devices = num_gpus 58 | self._num_classes_in_batch = num_classes_in_batch 59 | 60 | def __iter__(self): 61 | batch = [] 62 | if self._num_classes_in_batch != len(self._samplers): 63 | valid_samplers_idx = random.sample( 64 | range(len(self._samplers)), self._num_classes_in_batch 65 | ) 66 | else: 67 | valid_samplers_idx = None 68 | 69 | iters = [iter(s) for s in self._samplers] 70 | done = False 71 | 72 | while True: 73 | b = [] 74 | for i, it in enumerate(iters): 75 | if valid_samplers_idx is not None and i not in valid_samplers_idx: 76 | continue 77 | idx = next(it, None) 78 | if idx is None: 79 | done = True 80 | break 81 | b.append(idx) 82 | if done: 83 | break 84 | batch += b 85 | if len(batch) == self._batch_size: 86 | # yield batch 87 | for b in batch: 88 | yield b 89 | batch = [] 90 | if valid_samplers_idx is not None: 91 | valid_samplers_idx = random.sample( 92 | range(len(self._samplers)), self._num_classes_in_batch 93 | ) 94 | 95 | if not self._drop_last: 96 | if len(batch) > 0: 97 | groups = len(batch) // self._num_classes_in_batch 98 | if groups % self._dp_devices == 0: 99 | yield batch 100 | else: 101 | batch = batch[ 102 | : (groups // self._dp_devices) 103 | * self._dp_devices 104 | * self._num_classes_in_batch 105 | ] 106 | if len(batch) > 0: 107 | yield batch 108 | 109 | def __len__(self): 110 | class_batch_size = self._batch_size // self._num_classes_in_batch 111 | return min( 112 | ((len(s) + class_batch_size - 1) // class_batch_size) 113 | for s in self._samplers 114 | ) 115 | -------------------------------------------------------------------------------- /src/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | 5 | import librosa 6 | import numpy 7 | import torch 8 | import transformers 9 | from scipy import signal 10 | 11 | 12 | class HuggingFaceFeatureExtractor: 13 | def __init__(self, model_class_name, layer=-1, name=None): 14 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 15 | self.feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(name) 16 | model_class = getattr(transformers, model_class_name) 17 | 18 | self.model = model_class.from_pretrained(name, output_hidden_states=True) 19 | self.model.eval() 20 | self.model.to(self.device) 21 | self.layer = layer 22 | 23 | def __call__(self, audio, sr): 24 | inputs = self.feature_extractor( 25 | audio, 26 | sampling_rate=sr, 27 | return_tensors="pt", 28 | padding=True, 29 | ) 30 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 31 | with torch.no_grad(): 32 | outputs = self.model(**inputs) 33 | return outputs.hidden_states[self.layer] 34 | 35 | 36 | class WaveformEmphasiser: 37 | def __init__(self, sampling_rate, musan_path, rir_path): 38 | self.sampling_rate = sampling_rate 39 | self.noisesnr = {"noise": [0, 15], "speech": [13, 20], "music": [5, 15]} 40 | self.numnoise = {"noise": [1, 1], "speech": [3, 8], "music": [1, 1]} 41 | self.noiselist = {} 42 | self.rir_files = glob.glob(os.path.join(rir_path, "*/*/*/*.wav")) 43 | 44 | self.augment_files = glob.glob(os.path.join(musan_path, "*/*/*.wav")) 45 | ## group the noises by category 46 | for file in 
self.augment_files: 47 | if file.split("/")[-3] not in self.noiselist: 48 | self.noiselist[file.split("/")[-3]] = [] 49 | self.noiselist[file.split("/")[-3]].append(file) 50 | 51 | def __call__(self, waveform, emphasis="original"): 52 | waveform = self._unpack(waveform) 53 | if emphasis == "original": 54 | waveform = waveform 55 | elif emphasis == "reverb": 56 | waveform = self.add_reverb(waveform) 57 | elif emphasis in ["speech", "music", "noise"]: 58 | waveform = self.add_noise(waveform, emphasis) 59 | 60 | return self._pack(waveform) 61 | 62 | def _unpack(self, waveform): 63 | return waveform.squeeze().cpu().numpy() 64 | 65 | def _pack(self, waveform): 66 | return torch.Tensor(waveform) 67 | 68 | def add_reverb(self, audio): 69 | rir_file = random.choice(self.rir_files) 70 | rir, sr = librosa.load(rir_file, sr=self.sampling_rate) 71 | rir = rir / numpy.sqrt(numpy.sum(rir**2)) 72 | # print(f"Audio shape: {audio.shape}") 73 | # print(f"RIR shape: {rir.shape}") 74 | result = signal.convolve(audio, rir, mode="full")[: audio.shape[0]] 75 | return result 76 | 77 | def add_noise(self, audio, noise_type="speech"): 78 | audio_db = 10 * numpy.log10(numpy.mean(audio**2) + 1e-4) 79 | noise_file = random.choice(self.noiselist[noise_type]) 80 | noise, sr = librosa.load(noise_file, sr=self.sampling_rate) 81 | if noise.shape[0] <= audio.shape[0]: 82 | noise = numpy.pad(noise, (0, audio.shape[0] - noise.shape[0]), "wrap") 83 | else: 84 | noise = noise[: audio.shape[0]] 85 | noise_db = 10 * numpy.log10(numpy.mean(noise**2) + 1e-4) 86 | random_noise_snr = random.uniform( 87 | self.noisesnr[noise_type][0], self.noisesnr[noise_type][1] 88 | ) 89 | noise = ( 90 | numpy.sqrt(10 ** ((audio_db - noise_db - random_noise_snr) / 10)) * noise 91 | ) 92 | result = audio + noise 93 | return result 94 | 95 | 96 | def shuffle( 97 | feat: torch.Tensor, labels: torch.Tensor 98 | ) -> tuple[torch.Tensor, torch.Tensor]: 99 | shuffle_index = torch.randperm(labels.shape[0]) 100 | feat = feat[shuffle_index] 101 | labels = labels[shuffle_index] 102 | return feat, labels 103 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 7 | 8 | ## Author: Piotr KAWA 9 | ## December 2024 10 | ## Code adapted from https://github.com/coqui-ai/TTS/blob/dev/TTS/encoder/losses.py 11 | """ 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn 16 | 17 | 18 | class GE2ELoss(nn.Module): 19 | def __init__( 20 | self, init_w: float = 10.0, init_b: float = -5.0, loss_method: str = "softmax" 21 | ): 22 | """ 23 | Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1] 24 | Accepts an input of size (N, M, D) 25 | where N is the number of speakers in the batch, 26 | M is the number of utterances per speaker, 27 | and D is the dimensionality of the embedding vector (e.g.
d-vector) 28 | Args: 29 | - init_w (float): defines the initial value of w in Equation (5) of [1] 30 | - init_b (float): definies the initial value of b in Equation (5) of [1] 31 | """ 32 | super().__init__() 33 | # pylint: disable=E1102 34 | self.w = nn.Parameter(torch.tensor(init_w)) 35 | # pylint: disable=E1102 36 | self.b = nn.Parameter(torch.tensor(init_b)) 37 | self.loss_method = loss_method 38 | 39 | print(" > Initialized Generalized End-to-End loss") 40 | 41 | assert self.loss_method in ["softmax", "contrast"] 42 | 43 | if self.loss_method == "softmax": 44 | self.embed_loss = self.embed_loss_softmax 45 | if self.loss_method == "contrast": 46 | self.embed_loss = self.embed_loss_contrast 47 | 48 | # pylint: disable=R0201 49 | def calc_new_centroids(self, dvecs, centroids, spkr, utt): 50 | """ 51 | Calculates the new centroids excluding the reference utterance 52 | """ 53 | excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])) 54 | excl = torch.mean(excl, 0) 55 | new_centroids = [] 56 | for i, centroid in enumerate(centroids): 57 | if i == spkr: 58 | new_centroids.append(excl) 59 | else: 60 | new_centroids.append(centroid) 61 | return torch.stack(new_centroids) 62 | 63 | def calc_cosine_sim(self, dvecs, centroids): 64 | """ 65 | Make the cosine similarity matrix with dims (N,M,N) 66 | """ 67 | cos_sim_matrix = [] 68 | for spkr_idx, speaker in enumerate(dvecs): 69 | cs_row = [] 70 | for utt_idx, utterance in enumerate(speaker): 71 | new_centroids = self.calc_new_centroids( 72 | dvecs, centroids, spkr_idx, utt_idx 73 | ) 74 | # vector based cosine similarity for speed 75 | cs_row.append( 76 | torch.clamp( 77 | torch.mm( 78 | utterance.unsqueeze(1).transpose(0, 1).contiguous(), 79 | new_centroids.transpose(0, 1).contiguous(), 80 | ) 81 | / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), 82 | 1e-6, 83 | ) 84 | ) 85 | cs_row = torch.cat(cs_row, dim=0) 86 | cos_sim_matrix.append(cs_row) 87 | return torch.stack(cos_sim_matrix) 88 | 89 | # pylint: disable=R0201 90 | def embed_loss_softmax(self, dvecs, cos_sim_matrix): 91 | """ 92 | Calculates the loss on each embedding $L(e_{ji})$ by taking softmax 93 | """ 94 | N, M, _ = dvecs.shape 95 | L = [] 96 | for j in range(N): 97 | L_row = [] 98 | for i in range(M): 99 | L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j]) 100 | L_row = torch.stack(L_row) 101 | L.append(L_row) 102 | return torch.stack(L) 103 | 104 | # pylint: disable=R0201 105 | def embed_loss_contrast(self, dvecs, cos_sim_matrix): 106 | """ 107 | Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid 108 | """ 109 | N, M, _ = dvecs.shape 110 | L = [] 111 | for j in range(N): 112 | L_row = [] 113 | for i in range(M): 114 | centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) 115 | excl_centroids_sigmoids = torch.cat( 116 | (centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]) 117 | ) 118 | L_row.append( 119 | 1.0 120 | - torch.sigmoid(cos_sim_matrix[j, i, j]) 121 | + torch.max(excl_centroids_sigmoids) 122 | ) 123 | L_row = torch.stack(L_row) 124 | L.append(L_row) 125 | return torch.stack(L) 126 | 127 | def forward(self, x, _label=None): 128 | """ 129 | Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) 130 | """ 131 | 132 | assert x.size()[1] >= 2 133 | 134 | centroids = torch.mean(x, 1) 135 | cos_sim_matrix = self.calc_cosine_sim(x, centroids) 136 | torch.clamp(self.w, 1e-6) 137 | cos_sim_matrix = self.w * cos_sim_matrix + self.b 138 | L = self.embed_loss(x, 
cos_sim_matrix) 139 | return L.mean() 140 | -------------------------------------------------------------------------------- /src/models/NSD.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 7 | 8 | ### Adapted from: https://github.com/xieyuankun/REFD/blob/main/code/ADD2023t3_FD/ood_detectors/NSD.py 9 | ### and: https://github.com/xieyuankun/REFD/blob/main/code/ADD2023t3_FD/ood_detectors/interface.py 10 | 11 | 12 | ## Author: Adriana STAN 13 | ## December 2024 14 | ## Parts of this code are taken from https://github.com/xieyuankun/REFD/tree/main/code/ADD2023t3_FD 15 | """ 16 | 17 | from abc import ABC, abstractmethod 18 | from typing import Dict 19 | 20 | import numpy as np 21 | import torch 22 | import torch.nn.functional as F 23 | 24 | 25 | class OODDetector(ABC): 26 | @abstractmethod 27 | def setup(self, args, train_model_outputs: Dict[str, torch.Tensor]): 28 | pass 29 | 30 | @abstractmethod 31 | def infer(self, model_outputs: Dict[str, torch.Tensor]) -> torch.Tensor: 32 | pass 33 | 34 | 35 | def NSD_with_angle(feats_train, feats, min=False): 36 | feas_train = feats_train.cpu().numpy() 37 | feats = feats.cpu().numpy() 38 | cos_similarity = np.dot(feats, feats_train.T) 39 | if min: 40 | scores = np.array(cos_similarity.min(axis=1)) 41 | else: 42 | scores = np.array(cos_similarity.mean(axis=1)) 43 | return scores 44 | 45 | 46 | class NSDOODDetector(OODDetector): 47 | def setup(self, args, train_model_outputs): 48 | # Compute the training set info 49 | logits_train = torch.Tensor(train_model_outputs["logits"]) 50 | feats_train = torch.Tensor(train_model_outputs["feats"]) 51 | train_labels = train_model_outputs["labels"] 52 | feats_train = F.normalize(feats_train, p=2, dim=-1) 53 | confs_train = torch.logsumexp(logits_train, dim=1) 54 | self.scaled_feats_train = feats_train * confs_train[:, None] 55 | 56 | def infer(self, model_outputs): 57 | feats = torch.Tensor(model_outputs["feats"]) 58 | logits = torch.Tensor(model_outputs["logits"]) 59 | feats = F.normalize(feats, p=2, dim=-1) 60 | confs = torch.logsumexp(logits, dim=1) 61 | guidances = NSD_with_angle(self.scaled_feats_train, feats) 62 | scores = torch.from_numpy(guidances).to(confs.device) * confs 63 | return scores 64 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 
7 | 8 | ## Author: Adriana STAN, Piotr KAWA, Nicolas MUELLER 9 | ## December 2024 10 | """ 11 | import torch 12 | from torch import nn 13 | 14 | from src.models.w2v2_encoder import Wav2Vec2Encoder 15 | 16 | 17 | def count_parameters(model): 18 | total = sum(p.numel() for p in model.parameters()) 19 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 20 | non_trainable = total - trainable 21 | return total, trainable, non_trainable 22 | 23 | 24 | def get_model(model_name: str, checkpoint_path: str | None) -> nn.Module: 25 | if model_name == "wav2vec2": 26 | model = Wav2Vec2Encoder() 27 | else: 28 | raise ValueError(f"Model '{model_name}' not implemented") 29 | 30 | print(f" > Initialized model '{model_name}'") 31 | 32 | total, trainable, non_trainable = count_parameters(model) 33 | print( 34 | f" > Number of parameters: {total:,}, Trainable: {trainable:,}, Non-Trainable: {non_trainable:,}" 35 | ) 36 | 37 | if checkpoint_path: 38 | print(f" > Loading weights from '{checkpoint_path}'") 39 | checkpoint = torch.load(checkpoint_path) 40 | model.load_state_dict(checkpoint["model_state_dict"]) 41 | 42 | return model 43 | -------------------------------------------------------------------------------- /src/models/w2v2_aasist.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 7 | 8 | ### Source: https://github.com/xieyuankun/REFD/blob/main/code/ADD2023t3_RE/model.py 9 | 10 | ## Author: Adriana STAN 11 | ## December 2024 12 | ## Parts of this code are taken from https://github.com/xieyuankun/REFD/tree/main/code/ADD2023t3_FD 13 | """ 14 | 15 | from typing import Union 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class GraphAttentionLayer(nn.Module): 23 | def __init__(self, in_dim, out_dim, **kwargs): 24 | super().__init__() 25 | 26 | # attention map 27 | self.att_proj = nn.Linear(in_dim, out_dim) 28 | self.att_weight = self._init_new_params(out_dim, 1) 29 | 30 | # project 31 | self.proj_with_att = nn.Linear(in_dim, out_dim) 32 | self.proj_without_att = nn.Linear(in_dim, out_dim) 33 | 34 | # batch norm 35 | self.bn = nn.BatchNorm1d(out_dim) 36 | 37 | # dropout for inputs 38 | self.input_drop = nn.Dropout(p=0.2) 39 | 40 | # activate 41 | self.act = nn.SELU(inplace=True) 42 | 43 | # temperature 44 | self.temp = 1.0 45 | if "temperature" in kwargs: 46 | self.temp = kwargs["temperature"] 47 | 48 | def forward(self, x): 49 | """ 50 | x :(#bs, #node, #dim) 51 | """ 52 | # apply input dropout 53 | x = self.input_drop(x) 54 | # print(x.shape,'GraphAttentionLayer_x') 55 | 56 | # derive attention map 57 | att_map = self._derive_att_map(x) 58 | 59 | # projection 60 | x = self._project(x, att_map) 61 | 62 | # apply batch norm 63 | x = self._apply_BN(x) 64 | x = self.act(x) 65 | return x 66 | 67 | def _pairwise_mul_nodes(self, x): 68 | """ 69 | Calculates pairwise multiplication of nodes. 
70 | - for attention map 71 | x :(#bs, #node, #dim) 72 | out_shape :(#bs, #node, #node, #dim) 73 | """ 74 | 75 | nb_nodes = x.size(1) 76 | x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) 77 | x_mirror = x.transpose(1, 2) 78 | 79 | return x * x_mirror 80 | 81 | def _derive_att_map(self, x): 82 | """ 83 | x :(#bs, #node, #dim) 84 | out_shape :(#bs, #node, #node, 1) 85 | """ 86 | att_map = self._pairwise_mul_nodes(x) 87 | # size: (#bs, #node, #node, #dim_out) 88 | att_map = torch.tanh(self.att_proj(att_map)) 89 | # size: (#bs, #node, #node, 1) 90 | att_map = torch.matmul(att_map, self.att_weight) 91 | 92 | # apply temperature 93 | att_map = att_map / self.temp 94 | 95 | att_map = F.softmax(att_map, dim=-2) 96 | 97 | return att_map 98 | 99 | def _project(self, x, att_map): 100 | x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x)) 101 | x2 = self.proj_without_att(x) 102 | 103 | return x1 + x2 104 | 105 | def _apply_BN(self, x): 106 | org_size = x.size() 107 | x = x.view(-1, org_size[-1]) 108 | x = self.bn(x) 109 | x = x.view(org_size) 110 | 111 | return x 112 | 113 | def _init_new_params(self, *size): 114 | out = nn.Parameter(torch.FloatTensor(*size)) 115 | nn.init.xavier_normal_(out) 116 | return out 117 | 118 | 119 | class HtrgGraphAttentionLayer(nn.Module): 120 | def __init__(self, in_dim, out_dim, **kwargs): 121 | super().__init__() 122 | 123 | self.proj_type1 = nn.Linear(in_dim, in_dim) 124 | self.proj_type2 = nn.Linear(in_dim, in_dim) 125 | 126 | # attention map 127 | self.att_proj = nn.Linear(in_dim, out_dim) 128 | self.att_projM = nn.Linear(in_dim, out_dim) 129 | 130 | self.att_weight11 = self._init_new_params(out_dim, 1) 131 | self.att_weight22 = self._init_new_params(out_dim, 1) 132 | self.att_weight12 = self._init_new_params(out_dim, 1) 133 | self.att_weightM = self._init_new_params(out_dim, 1) 134 | 135 | # project 136 | self.proj_with_att = nn.Linear(in_dim, out_dim) 137 | self.proj_without_att = nn.Linear(in_dim, out_dim) 138 | 139 | self.proj_with_attM = nn.Linear(in_dim, out_dim) 140 | self.proj_without_attM = nn.Linear(in_dim, out_dim) 141 | 142 | # batch norm 143 | self.bn = nn.BatchNorm1d(out_dim) 144 | 145 | # dropout for inputs 146 | self.input_drop = nn.Dropout(p=0.2) 147 | 148 | # activate 149 | self.act = nn.SELU(inplace=True) 150 | 151 | # temperature 152 | self.temp = 1.0 153 | if "temperature" in kwargs: 154 | self.temp = kwargs["temperature"] 155 | 156 | def forward(self, x1, x2, master=None): 157 | """ 158 | x1 :(#bs, #node, #dim) 159 | x2 :(#bs, #node, #dim) 160 | """ 161 | # print('x1',x1.shape) 162 | # print('x2',x2.shape) 163 | num_type1 = x1.size(1) 164 | num_type2 = x2.size(1) 165 | # print('num_type1',num_type1) 166 | # print('num_type2',num_type2) 167 | x1 = self.proj_type1(x1) 168 | # print('proj_type1',x1.shape) 169 | x2 = self.proj_type2(x2) 170 | # print('proj_type2',x2.shape) 171 | x = torch.cat([x1, x2], dim=1) 172 | # print('Concat x1 and x2',x.shape) 173 | 174 | if master is None: 175 | master = torch.mean(x, dim=1, keepdim=True) 176 | # print('master',master.shape) 177 | # apply input dropout 178 | x = self.input_drop(x) 179 | 180 | # derive attention map 181 | att_map = self._derive_att_map(x, num_type1, num_type2) 182 | # print('master',master.shape) 183 | # directional edge for master node 184 | master = self._update_master(x, master) 185 | # print('master',master.shape) 186 | # projection 187 | x = self._project(x, att_map) 188 | # print('proj x',x.shape) 189 | # apply batch norm 190 | x = self._apply_BN(x) 191 | x = self.act(x) 192 | 
193 | x1 = x.narrow(1, 0, num_type1) 194 | # print('x1',x1.shape) 195 | x2 = x.narrow(1, num_type1, num_type2) 196 | # print('x2',x2.shape) 197 | return x1, x2, master 198 | 199 | def _update_master(self, x, master): 200 | 201 | att_map = self._derive_att_map_master(x, master) 202 | master = self._project_master(x, master, att_map) 203 | 204 | return master 205 | 206 | def _pairwise_mul_nodes(self, x): 207 | """ 208 | Calculates pairwise multiplication of nodes. 209 | - for attention map 210 | x :(#bs, #node, #dim) 211 | out_shape :(#bs, #node, #node, #dim) 212 | """ 213 | 214 | nb_nodes = x.size(1) 215 | x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) 216 | x_mirror = x.transpose(1, 2) 217 | 218 | return x * x_mirror 219 | 220 | def _derive_att_map_master(self, x, master): 221 | """ 222 | x :(#bs, #node, #dim) 223 | out_shape :(#bs, #node, #node, 1) 224 | """ 225 | att_map = x * master 226 | att_map = torch.tanh(self.att_projM(att_map)) 227 | 228 | att_map = torch.matmul(att_map, self.att_weightM) 229 | 230 | # apply temperature 231 | att_map = att_map / self.temp 232 | 233 | att_map = F.softmax(att_map, dim=-2) 234 | 235 | return att_map 236 | 237 | def _derive_att_map(self, x, num_type1, num_type2): 238 | """ 239 | x :(#bs, #node, #dim) 240 | out_shape :(#bs, #node, #node, 1) 241 | """ 242 | att_map = self._pairwise_mul_nodes(x) 243 | # size: (#bs, #node, #node, #dim_out) 244 | att_map = torch.tanh(self.att_proj(att_map)) 245 | # size: (#bs, #node, #node, 1) 246 | 247 | att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1) 248 | 249 | att_board[:, :num_type1, :num_type1, :] = torch.matmul( 250 | att_map[:, :num_type1, :num_type1, :], self.att_weight11 251 | ) 252 | att_board[:, num_type1:, num_type1:, :] = torch.matmul( 253 | att_map[:, num_type1:, num_type1:, :], self.att_weight22 254 | ) 255 | att_board[:, :num_type1, num_type1:, :] = torch.matmul( 256 | att_map[:, :num_type1, num_type1:, :], self.att_weight12 257 | ) 258 | att_board[:, num_type1:, :num_type1, :] = torch.matmul( 259 | att_map[:, num_type1:, :num_type1, :], self.att_weight12 260 | ) 261 | 262 | att_map = att_board 263 | 264 | # apply temperature 265 | att_map = att_map / self.temp 266 | 267 | att_map = F.softmax(att_map, dim=-2) 268 | 269 | return att_map 270 | 271 | def _project(self, x, att_map): 272 | x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x)) 273 | x2 = self.proj_without_att(x) 274 | 275 | return x1 + x2 276 | 277 | def _project_master(self, x, master, att_map): 278 | 279 | x1 = self.proj_with_attM(torch.matmul(att_map.squeeze(-1).unsqueeze(1), x)) 280 | x2 = self.proj_without_attM(master) 281 | 282 | return x1 + x2 283 | 284 | def _apply_BN(self, x): 285 | org_size = x.size() 286 | x = x.view(-1, org_size[-1]) 287 | x = self.bn(x) 288 | x = x.view(org_size) 289 | 290 | return x 291 | 292 | def _init_new_params(self, *size): 293 | out = nn.Parameter(torch.FloatTensor(*size)) 294 | nn.init.xavier_normal_(out) 295 | return out 296 | 297 | 298 | class GraphPool(nn.Module): 299 | def __init__(self, k: float, in_dim: int, p: Union[float, int]): 300 | super().__init__() 301 | self.k = k 302 | self.sigmoid = nn.Sigmoid() 303 | self.proj = nn.Linear(in_dim, 1) 304 | self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity() 305 | self.in_dim = in_dim 306 | 307 | def forward(self, h): 308 | Z = self.drop(h) 309 | weights = self.proj(Z) 310 | scores = self.sigmoid(weights) 311 | new_h = self.top_k_graph(scores, h, self.k) 312 | 313 | return new_h 314 | 315 | def top_k_graph(self, scores, h, k): 316 
| """ 317 | args 318 | ===== 319 | scores: attention-based weights (#bs, #node, 1) 320 | h: graph data (#bs, #node, #dim) 321 | k: ratio of remaining nodes, (float) 322 | returns 323 | ===== 324 | h: graph pool applied data (#bs, #node', #dim) 325 | """ 326 | _, n_nodes, n_feat = h.size() 327 | n_nodes = max(int(n_nodes * k), 1) 328 | _, idx = torch.topk(scores, n_nodes, dim=1) 329 | idx = idx.expand(-1, -1, n_feat) 330 | 331 | h = h * scores 332 | h = torch.gather(h, 1, idx) 333 | 334 | return h 335 | 336 | 337 | class Residual_block(nn.Module): 338 | def __init__(self, nb_filts, first=False): 339 | super().__init__() 340 | self.first = first 341 | 342 | if not self.first: 343 | self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0]) 344 | self.conv1 = nn.Conv2d( 345 | in_channels=nb_filts[0], 346 | out_channels=nb_filts[1], 347 | kernel_size=(2, 3), 348 | padding=(1, 1), 349 | stride=1, 350 | ) 351 | self.selu = nn.SELU(inplace=True) 352 | 353 | self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1]) 354 | self.conv2 = nn.Conv2d( 355 | in_channels=nb_filts[1], 356 | out_channels=nb_filts[1], 357 | kernel_size=(2, 3), 358 | padding=(0, 1), 359 | stride=1, 360 | ) 361 | 362 | if nb_filts[0] != nb_filts[1]: 363 | self.downsample = True 364 | self.conv_downsample = nn.Conv2d( 365 | in_channels=nb_filts[0], 366 | out_channels=nb_filts[1], 367 | padding=(0, 1), 368 | kernel_size=(1, 3), 369 | stride=1, 370 | ) 371 | 372 | else: 373 | self.downsample = False 374 | 375 | def forward(self, x): 376 | identity = x 377 | if not self.first: 378 | out = self.bn1(x) 379 | out = self.selu(out) 380 | else: 381 | out = x 382 | 383 | # print('out',out.shape) 384 | out = self.conv1(x) 385 | 386 | # print('aft conv1 out',out.shape) 387 | out = self.bn2(out) 388 | out = self.selu(out) 389 | # print('out',out.shape) 390 | out = self.conv2(out) 391 | # print('conv2 out',out.shape) 392 | 393 | if self.downsample: 394 | identity = self.conv_downsample(identity) 395 | 396 | out += identity 397 | # out = self.mp(out) 398 | return out 399 | 400 | 401 | class W2VAASIST(nn.Module): 402 | def __init__(self, feature_dim, num_labels): 403 | super().__init__() 404 | # AASIST parameters 405 | filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]] 406 | gat_dims = [64, 32] 407 | pool_ratios = [0.5, 0.5, 0.5, 0.5] 408 | temperatures = [2.0, 2.0, 100.0, 100.0] 409 | #### 410 | # create network wav2vec 2.0 411 | #### 412 | self.first_bn = nn.BatchNorm2d(num_features=1) 413 | self.first_bn1 = nn.BatchNorm2d(num_features=64) 414 | self.drop = nn.Dropout(0.5, inplace=True) 415 | self.drop_way = nn.Dropout(0.2, inplace=True) 416 | self.selu = nn.SELU(inplace=True) 417 | 418 | # RawNet2 encoder 419 | self.encoder = nn.Sequential( 420 | nn.Sequential(Residual_block(nb_filts=filts[1], first=True)), 421 | nn.Sequential(Residual_block(nb_filts=filts[2])), 422 | nn.Sequential(Residual_block(nb_filts=filts[3])), 423 | nn.Sequential(Residual_block(nb_filts=filts[4])), 424 | nn.Sequential(Residual_block(nb_filts=filts[4])), 425 | nn.Sequential(Residual_block(nb_filts=filts[4])), 426 | ) 427 | self.LL = nn.Linear(feature_dim, 128) 428 | 429 | self.attention = nn.Sequential( 430 | nn.Conv2d(64, 128, kernel_size=(1, 1)), 431 | nn.SELU(inplace=True), 432 | nn.BatchNorm2d(128), 433 | nn.Conv2d(128, 64, kernel_size=(1, 1)), 434 | ) 435 | # position encoding 436 | self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1])) 437 | 438 | self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) 439 | self.master2 = nn.Parameter(torch.randn(1, 1, 
gat_dims[0])) 440 | 441 | # Graph module 442 | self.GAT_layer_S = GraphAttentionLayer( 443 | filts[-1][-1], gat_dims[0], temperature=temperatures[0] 444 | ) 445 | self.GAT_layer_T = GraphAttentionLayer( 446 | filts[-1][-1], gat_dims[0], temperature=temperatures[1] 447 | ) 448 | # HS-GAL layer 449 | self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer( 450 | gat_dims[0], gat_dims[1], temperature=temperatures[2] 451 | ) 452 | self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer( 453 | gat_dims[1], gat_dims[1], temperature=temperatures[2] 454 | ) 455 | self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer( 456 | gat_dims[0], gat_dims[1], temperature=temperatures[2] 457 | ) 458 | self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer( 459 | gat_dims[1], gat_dims[1], temperature=temperatures[2] 460 | ) 461 | # Graph pooling layers 462 | self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3) 463 | self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3) 464 | self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 465 | self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 466 | 467 | self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 468 | self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 469 | 470 | self.out_layer = nn.Linear(5 * gat_dims[1], num_labels) 471 | 472 | def forward(self, x): 473 | 474 | # -------pre-trained Wav2vec model fine tunning ------------------------## 475 | x = x.squeeze(dim=1) 476 | x = x.transpose(1, 2) 477 | x = self.LL(x) 478 | x = x.transpose(1, 2) # (bs,feat_out_dim,frame_number) 479 | x = x.unsqueeze(dim=1) # add channel 480 | x = F.max_pool2d(x, (3, 3)) 481 | x = self.first_bn(x) 482 | x = self.selu(x) 483 | 484 | # RawNet2-based encoder 485 | x = self.encoder(x) 486 | x = self.first_bn1(x) 487 | x = self.selu(x) 488 | w = self.attention(x) 489 | 490 | # ------------SA for spectral feature-------------# 491 | w1 = F.softmax(w, dim=-1) 492 | m = torch.sum(x * w1, dim=-1) 493 | e_S = m.transpose(1, 2) + self.pos_S 494 | gat_S = self.GAT_layer_S(e_S) 495 | out_S = self.pool_S(gat_S) # (#bs, #node, #dim) 496 | 497 | # ------------SA for temporal feature-------------# 498 | w2 = F.softmax(w, dim=-2) 499 | m1 = torch.sum(x * w2, dim=-2) 500 | e_T = m1.transpose(1, 2) 501 | 502 | # graph module layer 503 | gat_T = self.GAT_layer_T(e_T) 504 | out_T = self.pool_T(gat_T) 505 | 506 | # learnable master node 507 | master1 = self.master1.expand(x.size(0), -1, -1) 508 | master2 = self.master2.expand(x.size(0), -1, -1) 509 | 510 | # inference 1 511 | out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11( 512 | out_T, out_S, master=self.master1 513 | ) 514 | 515 | out_S1 = self.pool_hS1(out_S1) 516 | out_T1 = self.pool_hT1(out_T1) 517 | 518 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12( 519 | out_T1, out_S1, master=master1 520 | ) 521 | out_T1 = out_T1 + out_T_aug 522 | out_S1 = out_S1 + out_S_aug 523 | master1 = master1 + master_aug 524 | 525 | # inference 2 526 | out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21( 527 | out_T, out_S, master=self.master2 528 | ) 529 | out_S2 = self.pool_hS2(out_S2) 530 | out_T2 = self.pool_hT2(out_T2) 531 | 532 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22( 533 | out_T2, out_S2, master=master2 534 | ) 535 | out_T2 = out_T2 + out_T_aug 536 | out_S2 = out_S2 + out_S_aug 537 | master2 = master2 + master_aug 538 | 539 | out_T1 = self.drop_way(out_T1) 540 | out_T2 = self.drop_way(out_T2) 541 | out_S1 = self.drop_way(out_S1) 542 | out_S2 = self.drop_way(out_S2) 543 | master1 = self.drop_way(master1) 544 | 
master2 = self.drop_way(master2) 545 | 546 | out_T = torch.max(out_T1, out_T2) 547 | out_S = torch.max(out_S1, out_S2) 548 | master = torch.max(master1, master2) 549 | 550 | # Readout operation 551 | T_max, _ = torch.max(torch.abs(out_T), dim=1) 552 | T_avg = torch.mean(out_T, dim=1) 553 | 554 | S_max, _ = torch.max(torch.abs(out_S), dim=1) 555 | S_avg = torch.mean(out_S, dim=1) 556 | last_hidden = torch.cat([T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1) 557 | last_hidden = self.drop(last_hidden) 558 | output = self.out_layer(last_hidden) 559 | return last_hidden, output 560 | -------------------------------------------------------------------------------- /src/models/w2v2_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 7 | 8 | ## Author: Piotr KAWA, Nicolas MUELLER 9 | ## December 2024 10 | """ 11 | 12 | import torch 13 | from transformers import AutoFeatureExtractor, AutoModelForCTC 14 | 15 | 16 | class Wav2Vec2Encoder(torch.nn.Module): 17 | def __init__(self, device: str = "cuda", sr: int = 16_000): 18 | super().__init__() 19 | self.sr = sr 20 | self.model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h").to( 21 | device 22 | ) 23 | self.feature_extractor = AutoFeatureExtractor.from_pretrained( 24 | "facebook/wav2vec2-base-960h", sampling_rate=sr 25 | ) 26 | self.LL = torch.nn.Sequential( 27 | torch.nn.Linear(768, 256), torch.nn.ReLU(), torch.nn.Linear(256, 256) 28 | ).to(device) 29 | 30 | def forward(self, x): 31 | input_values = self.feature_extractor( 32 | x, return_tensors="pt", sampling_rate=self.sr 33 | ).input_values 34 | input_values = input_values.to(self.model.device).squeeze() 35 | outputs = self.model(input_values, output_hidden_states=True) 36 | hidden_states = outputs.hidden_states[-1] 37 | hidden_states = hidden_states.mean(1) # B x T x D -> B x D 38 | return self.LL(hidden_states) 39 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 
7 | 8 | ## Author: Adriana STAN, Piotr KAWA 9 | ## December 2024 10 | ## Parts of this code are taken from https://github.com/xieyuankun/REFD/tree/main/code/ADD2023t3_FD 11 | """ 12 | 13 | import os 14 | import random 15 | 16 | import numpy as np 17 | import torch 18 | 19 | 20 | def set_seed(seed: int): 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | if torch.cuda.is_available(): 25 | torch.cuda.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | os.environ["PYTHONHASHSEED"] = str(seed) 30 | 31 | 32 | def adjust_learning_rate(args, lr, optimizer, epoch_num): 33 | lr = lr * (args.lr_decay ** (epoch_num // args.interval)) 34 | for param_group in optimizer.param_groups: 35 | param_group["lr"] = lr 36 | 37 | 38 | def mixup_data(x_mels, y, device, alpha=0.5): 39 | """Returns mixed inputs, pairs of targets, and lambda""" 40 | if alpha > 0: 41 | lam = np.random.beta(alpha, alpha) 42 | else: 43 | lam = 1 44 | 45 | batch_size = x_mels.size()[0] 46 | index = torch.randperm(batch_size).to(device) 47 | 48 | mixed_x_mels = lam * x_mels + (1 - lam) * x_mels[index, :] 49 | y_a, y_b = y, y[index] 50 | return mixed_x_mels, y_a, y_b, lam 51 | 52 | 53 | def regmix_criterion(criterion, pred, y_a, y_b, lam): 54 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 55 | -------------------------------------------------------------------------------- /train_ge2e.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk.
7 | 8 | ## Author: Piotr KAWA 9 | ## December 2024 10 | """ 11 | 12 | import argparse 13 | from pathlib import Path 14 | 15 | import pandas as pd 16 | import torch 17 | import torch.optim as optim 18 | import yaml 19 | from sklearn.calibration import LabelEncoder 20 | from torch.utils.data import DataLoader 21 | from tqdm import tqdm 22 | 23 | from src.datasets import samplers 24 | from src.datasets.dataset import BaseDataset 25 | from src.losses import GE2ELoss 26 | from src.models import get_model 27 | from src.utils import set_seed 28 | 29 | 30 | def reshape_to_loss_format( 31 | x: torch.Tensor, 32 | num_utter_per_class: int, 33 | num_classes_in_batch: int, 34 | ) -> torch.Tensor: 35 | return ( 36 | x.contiguous() 37 | .reshape(num_utter_per_class, num_classes_in_batch, x.shape[-1]) 38 | .transpose(0, 1) 39 | .contiguous() 40 | ) 41 | 42 | 43 | def parse_args(): 44 | parser = argparse.ArgumentParser(description="Training script") 45 | parser.add_argument( 46 | "--config", type=str, default="configs/config.yaml", help="Path to config file" 47 | ) 48 | parser.add_argument( 49 | "--cpu", 50 | action="store_true", 51 | help="Force the use of CPU even if GPU is available", 52 | ) 53 | return parser.parse_args() 54 | 55 | 56 | def initialize_datasets( 57 | config: dict, path_mlaad: Path, path_protocols: Path, batch_size: int 58 | ) -> tuple[DataLoader, DataLoader]: 59 | protocols = { 60 | "train": path_protocols / "train.csv", 61 | "dev": path_protocols / "dev.csv", 62 | "test": path_protocols / "eval.csv", 63 | } 64 | 65 | for subset, protocols_root_path in protocols.items(): 66 | if not protocols_root_path.exists(): 67 | raise FileNotFoundError(f"{protocols_root_path} does not exist") 68 | 69 | dataframes = { 70 | subset: pd.read_csv(protocols_root_path) 71 | for subset, protocols_root_path in protocols.items() 72 | } 73 | 74 | all_df = pd.concat(dataframes.values()) 75 | le = LabelEncoder() 76 | le.fit(all_df["model_name"]) 77 | class_mapping = {name: idx for idx, name in enumerate(le.classes_)} 78 | 79 | for subset, df in dataframes.items(): 80 | df["model_id"] = le.transform(df["model_name"]) 81 | dataframes[subset] = df 82 | 83 | # we train on concatenation of train and dev 84 | train_and_dev = pd.concat([dataframes["train"], dataframes["dev"]]) 85 | train_dataset = BaseDataset( 86 | meta_data=train_and_dev.to_dict(orient="records"), 87 | basepath=path_mlaad, 88 | class_mapping=class_mapping, 89 | sr=config["data"]["sampling_rate"], 90 | sample_length_s=config["data"]["sample_length_s"], 91 | verbose=True, 92 | ) 93 | 94 | test_dataset = BaseDataset( 95 | meta_data=dataframes["test"].to_dict(orient="records"), 96 | basepath=path_mlaad, 97 | class_mapping=class_mapping, 98 | sr=config["data"]["sampling_rate"], 99 | sample_length_s=config["data"]["sample_length_s"], 100 | verbose=True, 101 | ) 102 | 103 | train_sampler = samplers.PerfectBatchSampler( 104 | dataset_items=train_dataset.samples, 105 | classes=train_dataset.get_class_list(), 106 | batch_size=batch_size, 107 | num_classes_in_batch=n_classes_in_batch, 108 | num_gpus=1, 109 | drop_last=True, 110 | ) 111 | 112 | train_loader = DataLoader( 113 | train_dataset, 114 | batch_size=batch_size, 115 | num_workers=config["data"]["num_workers"], 116 | collate_fn=train_dataset.collate_fn, 117 | pin_memory=True, 118 | sampler=train_sampler, 119 | ) 120 | 121 | test_sampler = samplers.PerfectBatchSampler( 122 | dataset_items=test_dataset.samples, 123 | classes=test_dataset.get_class_list(), 124 | batch_size=batch_size, 125 | 
num_classes_in_batch=n_classes_in_batch, 126 | num_gpus=1, 127 | ) 128 | 129 | test_loader = DataLoader( 130 | test_dataset, 131 | batch_size=batch_size, 132 | collate_fn=test_dataset.collate_fn, 133 | num_workers=config["data"]["num_workers"], 134 | pin_memory=True, 135 | sampler=test_sampler, 136 | drop_last=True, 137 | ) 138 | 139 | return train_loader, test_loader 140 | 141 | 142 | def train_model( 143 | model: torch.nn.Module, 144 | train_loader: DataLoader, 145 | criterion: torch.nn.Module, 146 | optimizer: optim.Optimizer, 147 | device: str, 148 | num_epochs: int, 149 | log_interval: int, 150 | save_path: Path, 151 | n_utter_per_class: int, 152 | n_classes_in_batch: int, 153 | ) -> torch.nn.Module: 154 | 155 | best_loss = float("inf") 156 | for epoch in tqdm(range(num_epochs)): 157 | tqdm.write(f"Epoch {epoch+1}/{num_epochs}") 158 | model.train() 159 | running_loss = 0.0 160 | num_total = 0 161 | 162 | for batch_idx, (x, y, paths) in enumerate(train_loader): 163 | batch_size = y.size(0) 164 | x, y = x.to(device), y.to(device) 165 | optimizer.zero_grad() 166 | 167 | output = model(x) 168 | out_reshaped = reshape_to_loss_format( 169 | output, n_utter_per_class, n_classes_in_batch 170 | ) 171 | train_loss = criterion(out_reshaped) 172 | 173 | train_loss.backward() 174 | 175 | optimizer.step() 176 | running_loss += train_loss.item() * batch_size 177 | num_total += batch_size 178 | 179 | if (batch_idx + 1) % log_interval == 0: 180 | print( 181 | f"Batch [{batch_idx+1}]: Train Loss: {running_loss / num_total:.4f}" 182 | ) 183 | running_loss /= num_total 184 | 185 | print(f"Epoch [{epoch+1}/{num_epochs}]: Train Loss: {running_loss:.4f}") 186 | 187 | if running_loss < best_loss: 188 | model_save_path = save_path / "best_model.pth" 189 | print( 190 | f"Loss improved ({best_loss:.4f} -> {running_loss:.4f}). Saving model to '{model_save_path}'." 
191 | ) 192 | best_loss = running_loss 193 | torch.save( 194 | { 195 | "epoch": epoch, 196 | "model_state_dict": model.state_dict(), 197 | "optimizer_state_dict": optimizer.state_dict(), 198 | "loss": running_loss, 199 | }, 200 | model_save_path, 201 | ) 202 | return model 203 | 204 | 205 | if __name__ == "__main__": 206 | args = parse_args() 207 | 208 | with open(args.config, "r") as f: 209 | config = yaml.safe_load(f) 210 | 211 | set_seed(config["training"]["seed"]) 212 | 213 | path_mlaad = config["data"]["mlaad_root_path"] 214 | path_protocols = config["data"]["protocols_root_path"] 215 | 216 | model = get_model( 217 | model_name=config["model"]["model_name"], 218 | checkpoint_path=config["model"]["checkpoint_path"], 219 | ) 220 | model = model.train() 221 | lr = config["training"]["lr"] 222 | num_epochs = config["training"]["num_epochs"] 223 | device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu" 224 | num_workers = config["data"]["num_workers"] 225 | 226 | criterion = GE2ELoss() 227 | optimizer = optim.Adam(model.parameters(), lr=lr) 228 | 229 | path_mlaad = config["data"]["mlaad_root_path"] 230 | path_protocols = Path(config["data"]["protocols_root_path"]) 231 | 232 | save_path = Path(config["training"]["save_path"]) 233 | print(f"Model-related data will be saved in '{save_path}'") 234 | 235 | save_path.mkdir(parents=True, exist_ok=True) 236 | log_interval = config["training"]["log_interval"] 237 | n_classes_in_batch = config["training"]["n_classes_in_batch"] 238 | n_utter_per_class = config["training"]["n_utter_per_class"] 239 | 240 | batch_size = n_classes_in_batch * n_utter_per_class 241 | 242 | train_loader, test_loader = initialize_datasets( 243 | config=config, 244 | path_mlaad=path_mlaad, 245 | path_protocols=path_protocols, 246 | batch_size=batch_size, 247 | ) 248 | 249 | model.to(device) 250 | model = train_model( 251 | model=model, 252 | train_loader=train_loader, 253 | criterion=criterion, 254 | optimizer=optimizer, 255 | device=device, 256 | num_epochs=num_epochs, 257 | log_interval=log_interval, 258 | save_path=save_path, 259 | n_classes_in_batch=n_classes_in_batch, 260 | n_utter_per_class=n_utter_per_class, 261 | ) 262 | 263 | print("Finished training. Started test procedure!") 264 | model.eval() 265 | test_running_loss = 0 266 | num_total = 0 267 | 268 | with torch.no_grad(): 269 | for batch_idx, (x, y, paths) in enumerate(test_loader): 270 | batch_size = y.size(0) 271 | x, y = x.to(device), y.to(device) 272 | 273 | output = model(x) 274 | out_reshaped = reshape_to_loss_format( 275 | output, n_utter_per_class, n_classes_in_batch 276 | ) 277 | loss = criterion(out_reshaped) 278 | 279 | test_running_loss += loss.item() * batch_size 280 | num_total += batch_size 281 | 282 | if (batch_idx + 1) % log_interval == 0: 283 | print( 284 | f"Batch [{batch_idx+1}]: Test Loss: {test_running_loss / num_total:.4f}" 285 | ) 286 | 287 | test_running_loss /= num_total 288 | print(f"Test Loss: {test_running_loss:.4f}") 289 | -------------------------------------------------------------------------------- /train_refd.py: -------------------------------------------------------------------------------- 1 | """ 2 | DISCLAIMER: 3 | This code is provided "as-is" without any warranty of any kind, either expressed or implied, 4 | including but not limited to the implied warranties of merchantability and fitness for a particular purpose. 5 | The author assumes no liability for any damages or consequences resulting from the use of this code. 6 | Use it at your own risk. 
7 | 8 | ## Author: Adriana STAN 9 | ## December 2024 10 | """ 11 | 12 | import argparse 13 | import json 14 | import os 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torch.utils.data.sampler as torch_sampler 21 | from torch.utils.data import DataLoader 22 | from tqdm import tqdm 23 | 24 | from src import utils 25 | from src.datasets.dataset import MLAADFDDataset 26 | from src.models.w2v2_aasist import W2VAASIST 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser("Training script parameters") 31 | 32 | # Paths to features and output 33 | parser.add_argument( 34 | "-f", 35 | "--path_to_features", 36 | type=str, 37 | default="./exp/preprocess_wav2vec2-base/", 38 | help="Path to the previuosly extracted features", 39 | ) 40 | parser.add_argument( 41 | "--out_folder", type=str, default="./exp/trained_models/", help="Output folder" 42 | ) 43 | 44 | # Training hyperparameters 45 | parser.add_argument("--seed", type=int, help="random number seed", default=688) 46 | parser.add_argument( 47 | "--feat_dim", 48 | type=int, 49 | default=768, 50 | help="Feature dimension from the wav2vec model", 51 | ) 52 | parser.add_argument( 53 | "--num_classes", type=int, default=24, help="Number of in domain classes" 54 | ) 55 | parser.add_argument( 56 | "--num_epochs", type=int, default=30, help="Number of epochs for training" 57 | ) 58 | parser.add_argument( 59 | "--batch_size", type=int, default=128, help="Batch size for training" 60 | ) 61 | parser.add_argument("--lr", type=float, default=0.0005, help="learning rate") 62 | parser.add_argument( 63 | "--lr_decay", type=float, default=0.5, help="decay learning rate" 64 | ) 65 | parser.add_argument("--interval", type=int, default=10, help="interval to decay lr") 66 | parser.add_argument("--beta_1", type=float, default=0.9, help="bata_1 for Adam") 67 | parser.add_argument("--beta_2", type=float, default=0.999, help="beta_2 for Adam") 68 | parser.add_argument("--eps", type=float, default=1e-8, help="epsilon for Adam") 69 | parser.add_argument("--num_workers", type=int, default=0, help="number of workers") 70 | parser.add_argument( 71 | "--base_loss", 72 | type=str, 73 | default="ce", 74 | choices=["ce", "bce"], 75 | help="Loss for basic training", 76 | ) 77 | args = parser.parse_args() 78 | 79 | # Set seeds 80 | utils.set_seed(args.seed) 81 | 82 | # Path for output data 83 | if not os.path.exists(args.out_folder): 84 | os.makedirs(args.out_folder) 85 | 86 | # Folder for intermediate results 87 | if not os.path.exists(os.path.join(args.out_folder, "checkpoint")): 88 | os.makedirs(os.path.join(args.out_folder, "checkpoint")) 89 | 90 | # Path for input data 91 | assert os.path.exists(args.path_to_features) 92 | 93 | # Save training arguments 94 | with open(os.path.join(args.out_folder, "args.json"), "w") as file: 95 | file.write(json.dumps(vars(args), sort_keys=True, separators=("\n", ":"))) 96 | 97 | cuda = torch.cuda.is_available() 98 | print("Running on: ", "cuda" if cuda else "cpu") 99 | args.device = torch.device("cuda" if cuda else "cpu") 100 | return args 101 | 102 | 103 | def train(args): 104 | 105 | # Load the train and dev data 106 | print("Loading training data...") 107 | training_set = MLAADFDDataset(args.path_to_features, "train") 108 | print("\nLoading dev data...") 109 | dev_set = MLAADFDDataset(args.path_to_features, "dev", mode="known") 110 | 111 | train_loader = DataLoader( 112 | training_set, 113 | batch_size=args.batch_size, 114 | shuffle=False, 115 | 
num_workers=args.num_workers, 116 | sampler=torch_sampler.SubsetRandomSampler(range(len(training_set))), 117 | ) 118 | dev_loader = DataLoader( 119 | dev_set, 120 | batch_size=args.batch_size, 121 | shuffle=False, 122 | num_workers=args.num_workers, 123 | sampler=torch_sampler.SubsetRandomSampler(range(len(dev_set))), 124 | ) 125 | 126 | # Setup the model 127 | model = W2VAASIST(args.feat_dim, args.num_classes).to(args.device) 128 | print(f"Training a {type(model).__name__} model for {args.num_epochs} epochs") 129 | feat_optimizer = torch.optim.Adam( 130 | model.parameters(), 131 | lr=args.lr, 132 | betas=(args.beta_1, args.beta_2), 133 | eps=args.eps, 134 | weight_decay=0.0005, 135 | ) 136 | if args.base_loss == "ce": 137 | criterion = nn.CrossEntropyLoss() 138 | else: 139 | criterion = nn.BCELoss() 140 | 141 | prev_loss = 1e8 142 | # Main training loop 143 | for epoch_num in range(args.num_epochs): 144 | model.train() 145 | utils.adjust_learning_rate(args, args.lr, feat_optimizer, epoch_num) 146 | 147 | epoch_bar = tqdm(train_loader, desc=f"Epoch [{epoch_num+1}/{args.num_epochs}]") 148 | accuracy, train_loss = [], [] 149 | for iter_num, batch in enumerate(epoch_bar): 150 | feat, audio, labels = batch 151 | feat = feat.transpose(2, 3).to(args.device) 152 | labels = labels.to(args.device) 153 | 154 | mix_feat, y_a, y_b, lam = utils.mixup_data( 155 | feat, labels, args.device, alpha=0.5 156 | ) 157 | 158 | targets_a = torch.cat([labels, y_a]) 159 | targets_b = torch.cat([labels, y_b]) 160 | feat = torch.cat([feat, mix_feat], dim=0) 161 | 162 | feats, feat_outputs = model(feat) 163 | if args.base_loss == "bce": 164 | feat_loss = criterion(feat_outputs, labels.unsqueeze(1).float()) 165 | else: 166 | feat_loss = utils.regmix_criterion( 167 | criterion, feat_outputs, targets_a, targets_b, lam 168 | ) 169 | 170 | score = F.softmax(feat_outputs, dim=1) # [:, 0] 171 | predicted_classes = np.argmax(score.detach().cpu().numpy(), axis=1) 172 | correct_predictions = [ 173 | 1 for k in range(len(labels)) if predicted_classes[k] == labels[k] 174 | ] 175 | accuracy.append(sum(correct_predictions) / len(labels) * 100) 176 | train_loss.append(feat_loss.item()) 177 | epoch_bar.set_postfix( 178 | { 179 | "train_loss": f"{sum(train_loss)/(iter_num+1):.4f}", 180 | "acc": f"{sum(accuracy)/(iter_num+1):.2f}", 181 | } 182 | ) 183 | 184 | feat_optimizer.zero_grad() 185 | feat_loss.backward() 186 | feat_optimizer.step() 187 | 188 | # Epoch eval 189 | model.eval() 190 | with torch.no_grad(): 191 | val_bar = tqdm(dev_loader, desc=f"Validation for epoch {epoch_num+1}") 192 | accuracy, val_loss = [], [] 193 | for iter_num, batch in enumerate(val_bar): 194 | feat, _, labels = batch 195 | feat = feat.transpose(2, 3).to(args.device) 196 | labels = labels.to(args.device) 197 | 198 | feats, feat_outputs = model(feat) 199 | if args.base_loss == "bce": 200 | feat_loss = criterion(feat_outputs, labels.unsqueeze(1).float()) 201 | score = feat_outputs 202 | else: 203 | feat_loss = criterion(feat_outputs, labels) 204 | score = F.softmax(feat_outputs, dim=1) 205 | 206 | predicted_classes = np.argmax(score.detach().cpu().numpy(), axis=1) 207 | correct_predictions = [ 208 | 1 for k in range(len(labels)) if predicted_classes[k] == labels[k] 209 | ] 210 | accuracy.append(sum(correct_predictions) / len(labels) * 100) 211 | 212 | val_loss.append(feat_loss.item()) 213 | val_bar.set_postfix( 214 | { 215 | "val_loss": f"{sum(val_loss)/(iter_num+1):.4f}", 216 | "val_acc": f"{sum(accuracy)/(iter_num+1):.2f}", 217 | } 218 | ) 219 | 220 | 
epoch_val_loss = sum(val_loss) / (iter_num + 1) 221 | if epoch_val_loss < prev_loss: 222 | # Save the checkpoint with better val_loss 223 | checkpoint_path = os.path.join( 224 | args.out_folder, "anti-spoofing_feat_model.pth" 225 | ) 226 | print(f"[INFO] Saving model with better val_loss to {checkpoint_path}") 227 | torch.save(model.state_dict(), checkpoint_path) 228 | prev_loss = epoch_val_loss 229 | 230 | elif (epoch_num + 1) % 10 == 0: 231 | # Save the intermediate checkpoints just in case 232 | checkpoint_path = os.path.join( 233 | args.out_folder, 234 | "checkpoint", 235 | "anti-spoofing_feat_model_%02d.pth" % (epoch_num + 1), 236 | ) 237 | print( 238 | f"[INFO] Saving intermediate model at epoch {epoch_num+1} to {checkpoint_path}" 239 | ) 240 | torch.save(model.state_dict(), checkpoint_path) 241 | print("\n") 242 | 243 | 244 | if __name__ == "__main__": 245 | args = parse_args() 246 | train(args) 247 | --------------------------------------------------------------------------------
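For a quick sanity check after `train_refd.py` finishes, the best checkpoint it writes (`anti-spoofing_feat_model.pth`) can be reloaded for inference. The snippet below is only a minimal sketch under the default settings from `parse_args()` (`feat_dim=768`, `num_classes=24`, `out_folder=./exp/trained_models/`); `features` stands for a hypothetical batch of pre-extracted wav2vec2 features, transposed the same way as in the training loop (`feat.transpose(2, 3)`).

```python
import os

import torch

from src.models.w2v2_aasist import W2VAASIST

# Defaults taken from parse_args(); adjust if the model was trained with other values.
out_folder = "./exp/trained_models/"
model = W2VAASIST(768, 24)
state_dict = torch.load(
    os.path.join(out_folder, "anti-spoofing_feat_model.pth"), map_location="cpu"
)
model.load_state_dict(state_dict)
model.eval()

# `features` is a hypothetical tensor of pre-extracted wav2vec2 features,
# already transposed as in train(); the model returns (embeddings, class logits).
# with torch.no_grad():
#     embeddings, logits = model(features)
#     predicted_system = logits.softmax(dim=1).argmax(dim=1)
```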