├── upload.py ├── .gitignore ├── README.md ├── client.py ├── generate_timbre_embeddings.py ├── server.py ├── emolia-explorer.html └── annotate_audio.py /upload.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from datasets import Dataset, Audio, Features, Value 4 | from tqdm import tqdm 5 | import pandas as pd 6 | from huggingface_hub import HfApi 7 | from joblib import Parallel, delayed 8 | 9 | # --- Configuration --- 10 | # The root directory containing the folders with audio/json pairs. 11 | root_dir = os.path.expanduser("~/emilia-yodas/EN") 12 | # The Hugging Face Hub repository to push the dataset to. 13 | dataset_repo = "laion/Emilia-Annotated-WIP" 14 | # Number of parallel jobs to run. -1 uses all available CPU cores. 15 | N_JOBS = -1 16 | 17 | # --- Helper Functions for Parallel Processing --- 18 | 19 | def scan_folder(foldername, root_dir): 20 | """ 21 | Scans a single folder for valid .mp3 and .json file pairs. 22 | 23 | Args: 24 | foldername (str): The name of the folder to scan. 25 | root_dir (str): The root directory containing the folder. 26 | 27 | Returns: 28 | list: A list of tuples, where each tuple contains the path to an .mp3 file 29 | and its corresponding .json file. 30 | """ 31 | folder_path = os.path.join(root_dir, foldername) 32 | if not os.path.isdir(folder_path): 33 | return [] 34 | 35 | pairs = [] 36 | for filename in os.listdir(folder_path): 37 | if filename.endswith(".mp3"): 38 | base_name = filename[:-4] 39 | mp3_path = os.path.join(folder_path, filename) 40 | json_path = os.path.join(folder_path, f"{base_name}.json") 41 | if os.path.exists(json_path): 42 | pairs.append((mp3_path, json_path)) 43 | return pairs 44 | 45 | def process_file_pair(mp3_path, json_path): 46 | """ 47 | Processes a single mp3/json pair, validates it, and extracts data. 48 | 49 | Args: 50 | mp3_path (str): The file path to the MP3 audio file. 51 | json_path (str): The file path to the corresponding JSON metadata file. 52 | 53 | Returns: 54 | dict: A dictionary containing the extracted data ('audio', 'caption', 55 | 'emotions', 'raw_json') if the pair is valid. 56 | None: Returns None if the file pair is invalid or cannot be read. 
57 | """ 58 | try: 59 | # Read the entire JSON file content first 60 | with open(json_path, "r", encoding="utf-8") as f: 61 | raw_json_content = f.read() 62 | metadata = json.loads(raw_json_content) 63 | except (json.JSONDecodeError, IOError): 64 | # Return None if JSON is malformed or the file can't be read 65 | return None 66 | 67 | # Validate that the metadata is a dictionary and has the required fields 68 | if ( 69 | not isinstance(metadata, dict) 70 | or "caption" not in metadata 71 | or "emotions" not in metadata 72 | ): 73 | return None 74 | 75 | # Normalize the caption: if it starts with "AA", reduce it to a single "A" 76 | caption = metadata["caption"] 77 | if isinstance(caption, str) and caption.startswith("AA"): 78 | caption = "A" + caption[2:] 79 | 80 | # Return the processed data as a dictionary 81 | return { 82 | "audio": mp3_path, 83 | "caption": caption, 84 | "emotions": json.dumps(metadata["emotions"]), 85 | "raw_json": raw_json_content, 86 | } 87 | 88 | # --- Main Script --- 89 | 90 | # Step 1: Collect valid .mp3 and .json file pairs using parallel folder scanning 91 | print("Starting parallel scan of folders...") 92 | folder_list = os.listdir(root_dir) 93 | parallel_results = Parallel(n_jobs=N_JOBS)( 94 | delayed(scan_folder)(foldername, root_dir) for foldername in tqdm(folder_list, desc="Scanning folders") 95 | ) 96 | # Flatten the list of lists into a single list of file pairs 97 | file_pairs = [pair for sublist in parallel_results for pair in sublist] 98 | print(f"Found {len(file_pairs)} potential file pairs.") 99 | 100 | 101 | # Step 2: Process files in parallel to extract metadata 102 | print("\nStarting parallel processing of files...") 103 | processed_results = Parallel(n_jobs=N_JOBS)( 104 | delayed(process_file_pair)(mp3_path, json_path) for mp3_path, json_path in tqdm(file_pairs, desc="Processing files") 105 | ) 106 | 107 | # Filter out the None values from invalid/skipped files 108 | valid_data = [item for item in processed_results if item is not None] 109 | skipped_files = len(file_pairs) - len(valid_data) 110 | 111 | print(f"\nSuccessfully processed {len(valid_data)} files.") 112 | print(f"Skipped {skipped_files} invalid or unreadable files.") 113 | 114 | 115 | # Step 3: Create the Hugging Face Dataset object with the new schema 116 | features = Features({ 117 | "audio": Audio(sampling_rate=16000), 118 | "caption": Value("string"), 119 | "emotions": Value("string"), 120 | "raw_json": Value("string") 121 | }) 122 | 123 | # Convert the list of dictionaries directly to a pandas DataFrame 124 | df = pd.DataFrame(valid_data) 125 | dataset = Dataset.from_pandas(df, features=features) 126 | 127 | print("\nDataset object created successfully.") 128 | print(dataset) 129 | 130 | # Step 4: Push the dataset to the Hugging Face Hub 131 | print(f"\nPushing dataset to Hugging Face Hub repository: {dataset_repo}") 132 | dataset.push_to_hub(dataset_repo, max_shard_size="500MB") 133 | 134 | print("\nScript finished.") 135 | # Note: The push_to_hub command is commented out. 136 | # Uncomment it when you are ready to upload the data. 
137 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | #poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 115 | #pdm.lock 116 | #pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # pixi 121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 122 | #pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 
125 | .pixi 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .envrc 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | #.idea/ 177 | 178 | # Abstra 179 | # Abstra is an AI-powered process automation framework. 180 | # Ignore directories containing user credentials, local state, and settings. 181 | # Learn more at https://abstra.io/docs 182 | .abstra/ 183 | 184 | # Visual Studio Code 185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 187 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 188 | # you could uncomment the following to ignore the entire vscode folder 189 | # .vscode/ 190 | 191 | # Ruff stuff: 192 | .ruff_cache/ 193 | 194 | # PyPI configuration file 195 | .pypirc 196 | 197 | # Cursor 198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 199 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 200 | # refer to https://docs.cursor.com/context/ignore-files 201 | .cursorignore 202 | .cursorindexingignore 203 | 204 | # Marimo 205 | marimo/_static/ 206 | marimo/_lsp/ 207 | __marimo__/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EmoNet - Voice Annotation Toolkit 2 | 3 | This repository contains a suite of tools for analyzing audio files. The core functionalities are generating transcriptions, detailed emotional content scores, and speaker-specific embeddings. The toolkit is optimized for GPU acceleration and designed for efficient batch processing of large datasets. 4 | 5 | The primary goal is to generate rich, multi-faceted annotations for use in downstream tasks, such as training expressive Text-to-Speech (TTS) models. The underlying concepts are explored in more detail in LAION's research blog post, ["Do They See What We See?"](https://laion.ai/blog/do-they-see-what-we-see/). 6 | 7 | 8 | 9 | ## Features 10 | 11 | - **Dual-Mode Annotation**: Provides both transcription and a rich set of 55 emotion/attribute scores from a single, efficient pass over the audio using the **Empathic Insight Voice Small** models, which are based on the **EmoNet** architecture. 
12 | - **Batch Processing Script**: A standalone script (`annotate_audio.py`) to process entire folders of audio files recursively, leveraging GPU batching for throughput. 13 | - **Intelligent File Handling**: The script automatically skips already processed files and can merge new annotations into existing JSON files without overwriting other data. 14 | - **Server/Client Architecture**: Includes a robust, asynchronous FastAPI server and a versatile client for real-time inference applications. 15 | - **Speaker Timbre Embeddings**: A script (`generate_timbre_embeddings.py`) that generates unique speaker embeddings using `Orange/Speaker-wavLM-tbr`, allowing for clustering of speakers based on their unique vocal characteristics (timbre). 16 | - **Optimizations**: Leverages FP16 (half-precision), Flash Attention 2 (with a stable fallback), and targeted `torch.compile` for performance on modern NVIDIA GPUs. 17 | 18 | ## The Annotation Workflow 19 | 20 | The tools in this repository can be used to create a powerful data pipeline for training next-generation TTS models: 21 | 22 | 1. **Speaker Clustering**: Use the speaker embedding script (`generate_timbre_embeddings.py`) on your dataset. This generates timbre-based embeddings that can cluster speech snippets from the same (or very similar) speakers together, even if they are speaking with different emotions. This allows you to assign a consistent pseudo-identity (e.g., `speaker_001`, `speaker_002`) to each voice in a large, unlabeled dataset. 23 | 2. **Emotion & Transcription Annotation**: Run the `annotate_audio.py` script on the same dataset. This will generate a JSON file for each audio clip containing the transcription and all 55 emotion/attribute scores from the **Empathic Insight Voice** models. 24 | 3. **Training Data Assembly**: You can now assemble a rich training dataset. Each data point can contain: 25 | - The raw audio waveform. 26 | - The text transcription (the target for the TTS model to speak). 27 | - The speaker identity (from the clustering step). 28 | - The emotion scores (which become controllable conditioning signals). 29 | 4. **Training a Controllable TTS Model**: With this data, one can train a TTS model to take a speaker identity, a text prompt, and a desired set of emotion scores as input. This allows the final model to generate speech for the same speaker but with different emotional expressions. 30 | 31 | --- 32 | 33 | ## 1. Standalone Folder Annotation Script (`annotate_audio.py`) 34 | 35 | This is the primary tool for offline batch processing. It's a single, self-contained script that scans an input folder, finds all audio files, and generates detailed JSON annotations for each one. 36 | 37 | ### How it Works 38 | 39 | - **Input**: A folder path. 40 | - **Processing**: 41 | - Recursively finds all supported audio files (`.wav`, `.mp3`, `.m4a`, etc.). 42 | - Intelligently checks if a corresponding `.json` file already exists and contains complete annotations. If so, it skips the audio file. 43 | - Groups the remaining files into batches to process efficiently on the GPU. 44 | - For each audio file, it generates a transcription and the 55 emotion/attribute scores. 45 | - Saves the output to a `.json` file with the same name as the audio file. If the JSON file already exists but is missing data, this script will add the new annotations to it without overwriting existing, unrelated data. 46 | - **Output**: A `.json` file for each processed audio file, saved in the same directory. 47 | 48 | ### Usage 49 | 50 | 1.
**Prerequisites**: Ensure you have the required libraries installed: `pip install torch transformers huggingface-hub librosa soundfile tqdm` 51 | 2. **Run the script**: 52 | ```bash 53 | python annotate_audio.py /path/to/your/audio_dataset 54 | ``` 55 | 56 | ### Example Output (`your_audio.json`) 57 | ```json 58 | { 59 | "source_audio_file": "your_audio.wav", 60 | "caption": "This is the transcribed text from the audio file.", 61 | "emotions": { 62 | "Amusement": 0.038, 63 | "Interest": 2.822, 64 | "Contentment": 1.776, 65 | "Age": 3.010, 66 | "Valence": 1.950, 67 | "...": "..." 68 | } 69 | } 70 | ``` 71 | 72 | --- 73 | 74 | ## 2. High-Performance Server & Client 75 | 76 | For real-time applications, a robust server/client architecture is provided. 77 | 78 | ### Server (`server.py`) 79 | 80 | A high-performance FastAPI server that pre-loads all models into GPU memory and uses a dynamic batching system to handle concurrent requests with high throughput. 81 | 82 | #### Features 83 | - **Pre-loading**: All models are loaded once on startup. 84 | - **Dynamic Batching**: Groups incoming requests into optimal batches to maximize GPU utilization. 85 | - **Robust Optimizations**: Attempts to use Flash Attention 2 and falls back gracefully. 86 | 87 | #### Usage 88 | 1. **Launch the server**: 89 | ```bash 90 | uvicorn server:app --host 0.0.0.0 --port 8022 91 | ``` 92 | 93 | ### Client (`client.py`) 94 | 95 | An asynchronous client for interacting with the server. It can be used for single-file analysis, batch processing of folders, or running performance benchmarks. 96 | 97 | #### Usage 98 | - **Run a demo on a sample file**: 99 | ```bash 100 | python client.py --demo 101 | ``` 102 | - **Analyze a single local file**: 103 | ```bash 104 | python client.py --file /path/to/your/audio.wav 105 | ``` 106 | - **Run a high-throughput benchmark**: 107 | ```bash 108 | python client.py --benchmark 109 | ``` 110 | 111 | --- 112 | 113 | ## 3. Speaker Timbre Embedding Script (Advanced) 114 | 115 | This script is a key component for enabling advanced voice cloning and speaker identification capabilities. While not developed as part of the core Empathic Insight project, we provide and recommend this tool for its powerful and complementary functionality. It uses the **[Orange/Speaker-wavLM-tbr](https://huggingface.co/Orange/Speaker-wavLM-tbr)** model. 116 | 117 | ### The Utility of Timbre Embeddings 118 | 119 | The core concept is to separate the *identity* of a speaker from the *emotion* of their speech. 120 | - **Emotion** is conveyed through prosody, pitch, and energy, which change from moment to moment. 121 | - **Timbre** is the unique, underlying "fingerprint" or "color" of a voice that remains constant. It's what makes a specific person's voice recognizable, regardless of whether they are whispering happily or shouting angrily. 122 | 123 | The `Speaker-wavLM-tbr` model is specifically trained to listen to an audio clip and generate an embedding vector that represents only this timbre, effectively ignoring the emotional content. 124 | 125 | ### From Embeddings to Controllable Voice Cloning 126 | 127 | This emotion-invariant property is particularly useful for processing large-scale, unlabeled datasets: 128 | 129 | 1. **The Goal**: You have thousands of audio clips from many different, unknown speakers expressing various emotions. You want to group all clips from "Speaker A" together, all clips from "Speaker B" together, and so on. 130 | 131 | 2. 
**The Process**: 132 | - Run the speaker embedding script on your entire dataset. Each audio file now has a corresponding timbre vector. 133 | - Use a clustering algorithm (like K-Means, HDBSCAN, etc.) on these vectors. 134 | - The algorithm will automatically group the vectors into distinct clusters. Because the embeddings are emotion-invariant, a happy clip and a sad clip from the same person will have very similar vectors and will be placed in the same cluster. 135 | 136 | 3. **The Result: Pseudo Speaker Identities**: 137 | Each resulting cluster represents a unique speaker. You can now assign a **"pseudo speaker identity"** (e.g., `speaker_001`, `speaker_002`) to every audio file in your dataset based on which cluster it belongs to. 138 | 139 | 4. **The Application: Controllable TTS**: 140 | With this final piece of data, you can train a highly sophisticated TTS model. The model can be conditioned on multiple inputs: a reference audio for the speaker's voice, a target text, and a target emotion. The training data would be structured as follows: 141 | - **Reference Input**: Reference audio tokens, reference transcription, reference emotion scores. 142 | - **Target Input**: Target transcription, target emotion conditioning. 143 | - **Prediction**: The model predicts the target audio tokens. 144 | 145 | This allows the model to learn the separation of identity and emotion. At inference time, you can provide a single audio file of any speaker to define the voice timbre, and then ask the model to generate **any new text with any new emotion** in that person's voice. This enables true zero-shot voice cloning with emotional control. 146 | 147 | ### Usage 148 | *(The `generate_timbre_embeddings.py` script in this repository implements this step: it is structured similarly to `annotate_audio.py`, imports `EmbeddingsModel` from the model's `spk_embeddings.py`, loads `Orange/Speaker-wavLM-tbr`, and saves a `.npy` embedding file next to each audio file.)* 149 | 150 | ```bash 151 | # Set INPUT_FOLDER at the top of the script, then run: 152 | python generate_timbre_embeddings.py 153 | ``` 154 | -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | # Final client with caption display.
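# Usage summary (these flags are defined in main() below; see the README for details):
#   python client.py --demo                      # download (if needed) and analyze the demo file
#   python client.py --file /path/to/audio.wav   # analyze a single local file
#   python client.py --folder /path/to/folder    # analyze every supported file in a folder
#   python client.py --benchmark                 # send 128 concurrent copies of the demo file
# Optional flags: --url <prediction endpoint> and --concurrency <max parallel requests>.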
2 | 3 | import asyncio 4 | import aiohttp 5 | import argparse 6 | from pathlib import Path 7 | import time 8 | import os 9 | import shutil 10 | import tempfile 11 | import requests 12 | from typing import List, Dict 13 | 14 | # --- Configuration --- 15 | DEFAULT_SERVER_URL = "http://127.0.0.1:8022/predict" 16 | DEMO_FILE_URL = "https://huggingface.co/datasets/laion/School_BUD-E/resolve/main/juniper-long-en.wav" 17 | SUPPORTED_EXTENSIONS = (".wav", ".mp3", ".flac") 18 | CACHE_DIR = Path("./demo_cache") 19 | 20 | CORE_EMOTION_KEYS: List[str] = [ 21 | "Amusement", 22 | "Elation", 23 | "Pleasure_Ecstasy", 24 | "Contentment", 25 | "Thankfulness_Gratitude", 26 | "Affection", 27 | "Infatuation", 28 | "Hope_Enthusiasm_Optimism", 29 | "Triumph", 30 | "Pride", 31 | "Interest", 32 | "Awe", 33 | "Astonishment_Surprise", 34 | "Concentration", 35 | "Contemplation", 36 | "Relief", 37 | "Longing", 38 | "Teasing", 39 | "Impatience_and_Irritability", 40 | "Sexual_Lust", 41 | "Doubt", 42 | "Fear", 43 | "Distress", 44 | "Confusion", 45 | "Embarrassment", 46 | "Shame", 47 | "Disappointment", 48 | "Sadness", 49 | "Bitterness", 50 | "Contempt", 51 | "Disgust", 52 | "Anger", 53 | "Malevolence_Malice", 54 | "Sourness", 55 | "Pain", 56 | "Helplessness", 57 | "Fatigue_Exhaustion", 58 | "Emotional_Numbness", 59 | "Intoxication_Altered_States_of_Consciousness", 60 | "Jealousy_&_Envy", 61 | ] 62 | ATTRIBUTE_KEYS: List[str] = [ 63 | "Age", 64 | "Arousal", 65 | "Authenticity", 66 | "Background_Noise", 67 | "Confident_vs._Hesitant", 68 | "Gender", 69 | "High-Pitched_vs._Low-Pitched", 70 | "Monotone_vs._Expressive", 71 | "Recording_Quality", 72 | "Serious_vs._Humorous", 73 | "Soft_vs._Harsh", 74 | "Submissive_vs._Dominant", 75 | "Valence", 76 | "Vulnerable_vs._Emotionally_Detached", 77 | "Warm_vs._Cold", 78 | ] 79 | 80 | 81 | def get_demo_file() -> Path: 82 | CACHE_DIR.mkdir(exist_ok=True) 83 | dest_path = CACHE_DIR / DEMO_FILE_URL.split("/")[-1] 84 | if dest_path.exists() and dest_path.stat().st_size > 0: 85 | print(f"Demo file found in cache: {dest_path}") 86 | return dest_path 87 | print(f"Downloading demo file to cache: {dest_path}...") 88 | try: 89 | with requests.get(DEMO_FILE_URL, stream=True) as r: 90 | r.raise_for_status() 91 | with open(dest_path, "wb") as f: 92 | for chunk in r.iter_content(chunk_size=8192): 93 | f.write(chunk) 94 | print("Download complete.") 95 | return dest_path 96 | except requests.exceptions.RequestException as e: 97 | print(f"Error downloading file: {e}") 98 | raise 99 | 100 | 101 | def format_and_print_results(result: dict): 102 | print("\n" + "=" * 50) 103 | print(f"RESULTS FOR: {result.get('file', 'N/A')}") 104 | print("=" * 50) 105 | if "error" in result: 106 | print(f"An error occurred: {result['error']}") 107 | print("=" * 50 + "\n") 108 | return 109 | 110 | # --- ADDED: Display Transcription --- 111 | caption = result.get("caption", "No caption returned.") 112 | print("\n--- Transcription ---\n") 113 | print(f'"{caption}"\n') 114 | # --- END OF ADDITION --- 115 | 116 | if "emotions" not in result or not result["emotions"]: 117 | print("No emotion data returned from server.") 118 | print("=" * 50 + "\n") 119 | return 120 | 121 | all_scores = result["emotions"] 122 | emotion_scores = {k: v for k, v in all_scores.items() if k in CORE_EMOTION_KEYS} 123 | attribute_scores = {k: v for k, v in all_scores.items() if k in ATTRIBUTE_KEYS} 124 | sorted_emotions = sorted( 125 | emotion_scores.items(), key=lambda item: item[1], reverse=True 126 | ) 127 | sorted_attributes = sorted( 128 | 
attribute_scores.items(), key=lambda item: item[1], reverse=True 129 | ) 130 | 131 | print("--- Top 5 Core Emotions ---") 132 | for emotion, score in sorted_emotions[:5]: 133 | print(f"- {emotion.replace('_', ' '):<35} | Score: {score:.3f}") 134 | print("\n--- Top 5 Attributes / Dimensions ---") 135 | for attr, score in sorted_attributes[:5]: 136 | print(f"- {attr.replace('_', ' '):<35} | Score: {score:.3f}") 137 | print(f"\n--- All {len(sorted_emotions)} Emotion Scores (descending) ---") 138 | for emotion, score in sorted_emotions: 139 | print(f"- {emotion.replace('_', ' '):<40} {score:7.3f}") 140 | print(f"\n--- All {len(sorted_attributes)} Attribute Scores (descending) ---") 141 | for attr, score in sorted_attributes: 142 | print(f"- {attr.replace('_', ' '):<40} {score:7.3f}") 143 | print("=" * 50 + "\n") 144 | 145 | 146 | async def query_single_file( 147 | session: aiohttp.ClientSession, file_path: Path, url: str 148 | ) -> dict: 149 | try: 150 | with open(file_path, "rb") as f: 151 | data = aiohttp.FormData() 152 | data.add_field( 153 | "file", f, filename=file_path.name, content_type="audio/mpeg" 154 | ) 155 | async with session.post(url, data=data) as response: 156 | if response.status == 200: 157 | json_res = await response.json() 158 | json_res["file"] = file_path.name 159 | return json_res 160 | else: 161 | return { 162 | "file": file_path.name, 163 | "error": f"HTTP {response.status}: {await response.text()}", 164 | } 165 | except Exception as e: 166 | return {"file": file_path.name, "error": str(e)} 167 | 168 | 169 | async def main(): 170 | parser = argparse.ArgumentParser( 171 | description="Enhanced client for the Emotion Annotation API.", 172 | formatter_class=argparse.RawTextHelpFormatter, 173 | ) 174 | group = parser.add_mutually_exclusive_group(required=True) 175 | group.add_argument( 176 | "--demo", 177 | action="store_true", 178 | help="Run inference on the cached demo file and print results.", 179 | ) 180 | group.add_argument( 181 | "--benchmark", 182 | action="store_true", 183 | help="Run a high-throughput benchmark with 128 concurrent requests.", 184 | ) 185 | group.add_argument( 186 | "--file", type=str, help="Path to a single audio file to process." 187 | ) 188 | group.add_argument( 189 | "--folder", type=str, help="Path to a folder of audio files to process." 
190 | ) 191 | parser.add_argument( 192 | "--url", 193 | type=str, 194 | default=DEFAULT_SERVER_URL, 195 | help="URL of the prediction endpoint.", 196 | ) 197 | parser.add_argument( 198 | "--concurrency", 199 | type=int, 200 | default=32, 201 | help="Number of concurrent requests for folder/benchmark modes.", 202 | ) 203 | args = parser.parse_args() 204 | 205 | if args.demo: 206 | print("--- Running in DEMO mode ---") 207 | demo_file_path = get_demo_file() 208 | async with aiohttp.ClientSession() as session: 209 | start_time = time.time() 210 | result = await query_single_file(session, demo_file_path, args.url) 211 | end_time = time.time() 212 | format_and_print_results(result) 213 | print(f"Time for single inference: {end_time - start_time:.2f} seconds.") 214 | elif args.benchmark: 215 | print("--- Running in BENCHMARK mode ---") 216 | num_copies = 128 217 | demo_file_path = get_demo_file() 218 | with tempfile.TemporaryDirectory() as temp_dir: 219 | temp_path = Path(temp_dir) 220 | print(f"Creating {num_copies-1} temporary copies for benchmark...") 221 | files_to_process = [demo_file_path] + [ 222 | shutil.copy( 223 | demo_file_path, temp_path / f"copy_{i}_{demo_file_path.name}" 224 | ) 225 | for i in range(num_copies - 1) 226 | ] 227 | print("Starting benchmark...") 228 | start_time = time.time() 229 | connector = aiohttp.TCPConnector(limit=args.concurrency, ssl=False) 230 | async with aiohttp.ClientSession(connector=connector) as session: 231 | tasks = [ 232 | query_single_file(session, file, args.url) 233 | for file in files_to_process 234 | ] 235 | results = await asyncio.gather(*tasks) 236 | end_time = time.time() 237 | total_time, success_count = end_time - start_time, sum( 238 | 1 for r in results if "error" not in r 239 | ) 240 | print("\n--- Benchmark Complete ---") 241 | print(f"Total time to process {num_copies} files: {total_time:.2f} seconds") 242 | print( 243 | f"Throughput: {success_count / total_time:.2f} files/second" 244 | if total_time > 0 245 | else "N/A" 246 | ) 247 | print( 248 | f"Effective time per file: {total_time / success_count:.3f} seconds" 249 | if success_count > 0 250 | else "N/A" 251 | ) 252 | else: 253 | files_to_process = ( 254 | [Path(args.file)] 255 | if args.file 256 | else [ 257 | p 258 | for p in Path(args.folder).rglob("*") 259 | if p.suffix.lower() in SUPPORTED_EXTENSIONS 260 | ] 261 | ) 262 | if not files_to_process: 263 | print("No supported audio files found.") 264 | return 265 | print(f"Found {len(files_to_process)} file(s) to process...") 266 | start_time = time.time() 267 | connector = aiohttp.TCPConnector(limit=args.concurrency, ssl=False) 268 | async with aiohttp.ClientSession(connector=connector) as session: 269 | results = await asyncio.gather( 270 | *[ 271 | query_single_file(session, file, args.url) 272 | for file in files_to_process 273 | ] 274 | ) 275 | end_time = time.time() 276 | print(f"\n--- Processing of {len(files_to_process)} file(s) complete ---") 277 | for res in results[:3]: 278 | format_and_print_results(res) 279 | if len(results) > 3: 280 | print(f"...and {len(results) - 3} more results not shown.") 281 | print(f"Total time taken: {end_time - start_time:.2f} seconds") 282 | 283 | 284 | if __name__ == "__main__": 285 | if os.name == "nt": 286 | asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) 287 | asyncio.run(main()) 288 | -------------------------------------------------------------------------------- /generate_timbre_embeddings.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torchaudio 5 | import numpy as np 6 | from pathlib import Path 7 | import multiprocessing 8 | from tqdm import tqdm 9 | 10 | # --- Configuration --- 11 | # !!! SET YOUR INPUT FOLDER HERE !!! 12 | INPUT_FOLDER = "/mnt/raid/spirit/Speaker-wavLM-tbr/laion_audio_output_data/" 13 | # The Hugging Face model name for timbre embeddings 14 | MODEL_NAME = "Orange/Speaker-wavLM-tbr" 15 | # Audio file extensions to search for 16 | AUDIO_EXTENSIONS = (".flac", ".mp3", ".wav", ".m4a", ".ogg", ".opus") 17 | # Batch size for GPU inference (how many audio files are processed by the model in one go on a single GPU) 18 | BATCH_SIZE = 16 19 | # Target sample rate for the model 20 | TARGET_SAMPLE_RATE = 16000 21 | # Maximum audio duration in seconds. Files longer than this will be truncated. 22 | # Set to None for no truncation. The model authors used 30s for some evaluations. 23 | MAX_AUDIO_DURATION_SEC = 30 24 | MAX_AUDIO_SAMPLES = ( 25 | int(TARGET_SAMPLE_RATE * MAX_AUDIO_DURATION_SEC) if MAX_AUDIO_DURATION_SEC else None 26 | ) 27 | 28 | # --- Ensure spk_embeddings.py can be imported --- 29 | # This adds the script's directory to Python's path, helping to find spk_embeddings.py 30 | # and ensuring that spk_embeddings.py can import its own dependencies (like transformers). 31 | script_dir = os.path.dirname(os.path.abspath(__file__)) 32 | if script_dir not in sys.path: 33 | sys.path.insert(0, script_dir) 34 | 35 | try: 36 | from spk_embeddings import EmbeddingsModel 37 | except ImportError as e: 38 | print( 39 | f"Detailed ImportError when trying to import 'EmbeddingsModel' from 'spk_embeddings.py': {e}" 40 | ) 41 | print("\nThis could be because:") 42 | print( 43 | "1. spk_embeddings.py is not in the current path (less likely if it's in the same directory)." 44 | ) 45 | print( 46 | "2. A module that spk_embeddings.py itself tries to import is missing or cannot be imported." 47 | ) 48 | print( 49 | " Common dependencies for spk_embeddings.py are: torch, torchaudio, transformers, huggingface_hub." 50 | ) 51 | print(" Please ensure these are installed in your Python environment.") 52 | print(f"\nAttempted to load spk_embeddings.py from directory: {script_dir}") 53 | print("Current sys.path includes:") 54 | for p_item in sys.path: 55 | print(f" - {p_item}") 56 | print( 57 | "\nIf you haven't, try: pip install torch torchaudio transformers huggingface_hub" 58 | ) 59 | sys.exit(1) 60 | except Exception as e: 61 | print(f"An unexpected error occurred during import: {e}") 62 | sys.exit(1) 63 | 64 | # --- Helper Functions --- 65 | 66 | 67 | def find_audio_files(folder_path: Path, extensions: tuple) -> list: 68 | """Recursively finds all audio files with given extensions in a folder.""" 69 | audio_files = [] 70 | print( 71 | f"Scanning for audio files with extensions {extensions} in '{folder_path}'..." 
72 | ) 73 | for ext in extensions: 74 | audio_files.extend(list(folder_path.rglob(f"*{ext.lower()}"))) 75 | audio_files.extend( 76 | list(folder_path.rglob(f"*{ext.upper()}")) 77 | ) # Case-insensitive 78 | 79 | # Sort and remove duplicates 80 | unique_audio_files = sorted(list(set(audio_files))) 81 | print(f"Found {len(unique_audio_files)} unique audio file paths.") 82 | return unique_audio_files 83 | 84 | 85 | def load_and_resample_audio(file_path: Path, target_sr: int): 86 | """ 87 | Loads an audio file, resamples it to target_sr, converts to mono, 88 | and returns it as a Tensor of shape (1, num_samples). 89 | """ 90 | try: 91 | waveform, sample_rate = torchaudio.load(file_path) # (num_channels, num_frames) 92 | 93 | # Resample if necessary 94 | if sample_rate != target_sr: 95 | resampler = torchaudio.transforms.Resample( 96 | orig_freq=sample_rate, new_freq=target_sr 97 | ) 98 | waveform = resampler(waveform) 99 | 100 | # Convert to mono if stereo, ensure shape is (1, num_samples) 101 | if waveform.shape[0] > 1: # If more than 1 channel 102 | waveform = torch.mean(waveform, dim=0, keepdim=True) # Average channels 103 | elif waveform.shape[0] == 0: # Should not happen with valid audio 104 | print( 105 | f"Warning: Audio file {file_path} has 0 channels after loading. Skipping." 106 | ) 107 | return None, None 108 | # If waveform.shape[0] == 1, it's already mono and correctly shaped (1, num_samples) 109 | 110 | return waveform, target_sr 111 | except Exception as e: 112 | print(f"Error loading or resampling {file_path}: {e}") 113 | return None, None 114 | 115 | 116 | def worker_process_files( 117 | process_id: int, 118 | gpu_id: int, 119 | assigned_file_paths: list, # List of all file paths assigned to this worker process 120 | model_name: str, 121 | progress_queue: multiprocessing.Queue, 122 | gpu_batch_size: int, # Actual batch size for GPU inference 123 | max_samples: int, # Max samples for truncation 124 | ): 125 | """ 126 | Worker function to process a list of audio files on a specific GPU. 127 | It loads the model once and processes its assigned files in batches. 
128 | """ 129 | # Set CUDA device for this specific process 130 | # This ensures each process uses its designated GPU 131 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 132 | device = torch.device( 133 | f"cuda:{0}" 134 | ) # After CUDA_VISIBLE_DEVICES, cuda:0 is the target GPU 135 | 136 | try: 137 | # Each process loads its own instance of the model 138 | model = EmbeddingsModel.from_pretrained(model_name) 139 | model.to(device) 140 | model.eval() 141 | except Exception as e: 142 | print( 143 | f"[Process {process_id} on GPU {gpu_id}] Error loading model '{model_name}': {e}" 144 | ) 145 | # If model loading fails, signal failure for all assigned files for progress tracking 146 | for _ in assigned_file_paths: 147 | progress_queue.put(1) 148 | return 149 | 150 | files_processed_by_worker = 0 151 | # Process assigned files in chunks of gpu_batch_size 152 | for i in range(0, len(assigned_file_paths), gpu_batch_size): 153 | current_chunk_paths = assigned_file_paths[i : i + gpu_batch_size] 154 | if not current_chunk_paths: 155 | continue 156 | 157 | waveforms_for_batch = [] 158 | valid_paths_in_chunk = ( 159 | [] 160 | ) # To keep track of which files loaded successfully for this chunk 161 | max_len_in_chunk = 0 162 | 163 | for audio_file_path in current_chunk_paths: 164 | output_npy_path = audio_file_path.with_suffix(".npy") 165 | if output_npy_path.exists(): 166 | # print(f"[Process {process_id}] Skipping {output_npy_path}, already exists.") 167 | progress_queue.put(1) # Signal progress for this skipped file 168 | continue # Skip this file for batch processing 169 | 170 | waveform, sr = load_and_resample_audio(audio_file_path, TARGET_SAMPLE_RATE) 171 | if waveform is None: # Loading or resampling failed 172 | progress_queue.put(1) # Signal failure for this file 173 | continue # Skip this file for batch processing 174 | 175 | # Truncate if MAX_AUDIO_SAMPLES is set and waveform is longer 176 | if max_samples is not None and waveform.shape[1] > max_samples: 177 | waveform = waveform[:, :max_samples] # waveform is (1, num_samples) 178 | 179 | waveforms_for_batch.append( 180 | waveform.squeeze(0) 181 | ) # Squeeze to (num_samples) for padding 182 | valid_paths_in_chunk.append(audio_file_path) 183 | if waveform.shape[1] > max_len_in_chunk: 184 | max_len_in_chunk = waveform.shape[1] 185 | 186 | if ( 187 | not waveforms_for_batch 188 | ): # All files in this chunk were skipped or failed loading 189 | continue 190 | 191 | # Pad waveforms in the current batch to the max_len_in_chunk 192 | padded_waveforms_list = [] 193 | for wf in waveforms_for_batch: 194 | padding_needed = max_len_in_chunk - wf.shape[0] 195 | # Pad at the end of the audio signal (dim 0 after squeeze) 196 | padded_wf = torch.nn.functional.pad(wf, (0, padding_needed)) 197 | padded_waveforms_list.append(padded_wf) 198 | 199 | # Stack into a batch tensor: (current_chunk_actual_size, max_len_in_chunk) 200 | batch_input_tensor = torch.stack(padded_waveforms_list).to(device) 201 | 202 | try: 203 | with torch.no_grad(): 204 | # The EmbeddingsModel.forward expects (batch_size, num_samples) 205 | embeddings_batch_output = model(batch_input_tensor) 206 | 207 | # Save embeddings for each valid file in the processed chunk 208 | for idx, original_path in enumerate(valid_paths_in_chunk): 209 | output_npy_path = original_path.with_suffix(".npy") 210 | # embeddings_batch_output[idx] is a 1D tensor for the embedding 211 | np.save(output_npy_path, embeddings_batch_output[idx].cpu().numpy()) 212 | files_processed_by_worker += 1 213 | except Exception 
as e: 214 | print( 215 | f"[Process {process_id} on GPU {gpu_id}] Error processing batch (starts with {valid_paths_in_chunk[0].name if valid_paths_in_chunk else 'N/A'}): {e}" 216 | ) 217 | # If batch processing fails, signal progress for all files intended for this batch as failures 218 | for ( 219 | _ 220 | ) in ( 221 | valid_paths_in_chunk 222 | ): # These were files that loaded but failed during model inference/saving 223 | progress_queue.put(1) 224 | continue # to the next chunk of files for this worker 225 | 226 | # Signal progress for all successfully processed files in this batch 227 | for _ in valid_paths_in_chunk: 228 | progress_queue.put(1) 229 | 230 | # print(f"[Process {process_id} on GPU {gpu_id}] Finished. Processed {files_processed_by_worker} new files.") 231 | 232 | 233 | # --- Main Logic --- 234 | if __name__ == "__main__": 235 | # 'spawn' is recommended for CUDA with multiprocessing for safety 236 | multiprocessing.set_start_method("spawn", force=True) 237 | 238 | input_path = Path(INPUT_FOLDER) 239 | if not input_path.is_dir(): 240 | print( 241 | f"Error: Input folder '{INPUT_FOLDER}' does not exist or is not a directory." 242 | ) 243 | sys.exit(1) 244 | 245 | all_audio_files = find_audio_files(input_path, AUDIO_EXTENSIONS) 246 | 247 | if not all_audio_files: 248 | print("No audio files found in the specified folder and subfolders. Exiting.") 249 | sys.exit(0) 250 | 251 | # Filter out files for which embeddings (.npy) already exist 252 | files_to_process = [] 253 | for f_path in all_audio_files: 254 | if not f_path.with_suffix(".npy").exists(): 255 | files_to_process.append(f_path) 256 | 257 | if not files_to_process: 258 | print( 259 | "All audio files seem to have corresponding .npy embeddings already. Exiting." 260 | ) 261 | sys.exit(0) 262 | 263 | print(f"Found {len(files_to_process)} audio files that need new embeddings.") 264 | 265 | # Determine number of GPUs to use 266 | if torch.cuda.is_available(): 267 | num_gpus = torch.cuda.device_count() 268 | print(f"Found {num_gpus} CUDA device(s). Using all available GPUs.") 269 | else: 270 | print("CUDA not available. This script is optimized for GPU acceleration.") 271 | print("It will attempt to run on CPU using 1 process, which will be very slow.") 272 | num_gpus = 1 # Fallback to 1 CPU process (GPU ID will be 0, but device will be CPU in worker) 273 | # The worker will still try to use 'cuda:0'. This needs a CPU fallback in worker. 274 | # For simplicity, let's assume CUDA is the primary target. 275 | # A robust CPU version would explicitly set device to 'cpu'. 276 | # Let's enforce CUDA for this version as per typical high-performance audio processing. 277 | if not torch.cuda.is_available(): # Re-check for clarity 278 | print( 279 | "ERROR: No CUDA GPUs found. This script requires CUDA for efficient operation." 280 | ) 281 | print( 282 | "Please install PyTorch with CUDA support or adapt the script for CPU use." 283 | ) 284 | sys.exit(1) 285 | 286 | # Distribute files_to_process among worker processes (one worker per GPU) 287 | worker_assignments = [[] for _ in range(num_gpus)] 288 | for idx, file_path in enumerate(files_to_process): 289 | # Simple round-robin assignment to GPUs 290 | worker_assignments[idx % num_gpus].append(file_path) 291 | 292 | processes = [] 293 | progress_queue = multiprocessing.Queue() # For TQDM progress updates 294 | 295 | print( 296 | f"\nStarting {num_gpus} worker process(es). Each will handle its assigned files." 
297 | ) 298 | print(f"GPU batch size for model inference within each worker: {BATCH_SIZE}") 299 | print( 300 | f"Max audio samples per file (after {TARGET_SAMPLE_RATE}Hz resample): {MAX_AUDIO_SAMPLES or 'No limit'}\n" 301 | ) 302 | 303 | for i in range(num_gpus): 304 | if not worker_assignments[ 305 | i 306 | ]: # If a GPU has no files assigned (e.g., fewer files than GPUs) 307 | continue 308 | 309 | # process_id is 'i', gpu_id is also 'i' (0-indexed) 310 | # worker_assignments[i] is the list of file paths for this worker 311 | p = multiprocessing.Process( 312 | target=worker_process_files, 313 | args=( 314 | i, 315 | i, 316 | worker_assignments[i], 317 | MODEL_NAME, 318 | progress_queue, 319 | BATCH_SIZE, 320 | MAX_AUDIO_SAMPLES, 321 | ), 322 | ) 323 | processes.append(p) 324 | p.start() 325 | 326 | # Progress bar handling using TQDM 327 | # total=len(files_to_process) because progress_queue.put(1) is called for each file attempt 328 | with tqdm( 329 | total=len(files_to_process), desc="Generating Embeddings", unit="file" 330 | ) as pbar: 331 | for _ in range(len(files_to_process)): 332 | progress_queue.get() # Wait for a signal that one file attempt is complete 333 | pbar.update(1) 334 | 335 | # Wait for all processes to complete 336 | for p in processes: 337 | p.join() 338 | 339 | print("\nAll embeddings generated successfully.") 340 | print( 341 | f"Embeddings (.npy files) are saved alongside their original audio files in '{INPUT_FOLDER}'." 342 | ) 343 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | # FINAL VERSION: Generates both transcription and emotion scores in one efficient pass. 2 | # Includes Flash Attention 2 enabled by default with a robust fallback. 3 | 4 | import os 5 | import logging 6 | import asyncio 7 | import threading 8 | import time 9 | from typing import List, Dict, Any 10 | from pathlib import Path 11 | import shutil 12 | import tempfile 13 | import uuid 14 | from collections import OrderedDict 15 | 16 | # --- Core ML/AI Libraries --- 17 | import torch 18 | import torch.nn as nn 19 | import librosa 20 | from transformers import AutoProcessor, WhisperForConditionalGeneration 21 | from huggingface_hub import snapshot_download 22 | 23 | # --- Web Framework --- 24 | from fastapi import FastAPI, UploadFile, File, HTTPException 25 | from fastapi.responses import JSONResponse 26 | from contextlib import asynccontextmanager 27 | from dataclasses import dataclass, field 28 | 29 | # --- Configuration Section --- 30 | logging.basicConfig( 31 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 32 | ) 33 | 34 | if torch.cuda.is_available() and torch.cuda.device_count() > 1: 35 | DEVICE = "cuda:1" 36 | elif torch.cuda.is_available(): 37 | DEVICE = "cuda:0" 38 | else: 39 | DEVICE = "cpu" 40 | logging.info(f"Using device: {DEVICE}") 41 | 42 | WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1" 43 | HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" 44 | LOCAL_MLP_MODELS_DIR = Path("./empathic_insight_models") 45 | 46 | WHISPER_SEQ_LEN = 1500 47 | WHISPER_EMBED_DIM = 768 48 | PROJECTION_DIM = 64 49 | MLP_HIDDEN_DIMS = [64, 32, 16] 50 | MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] 51 | SAMPLING_RATE = 16000 52 | MAX_AUDIO_SECONDS = 30 53 | 54 | # --- PERFORMANCE OPTIMIZATION FLAGS --- 55 | # Flash Attention 2 provides a significant speed-up on compatible GPUs. 56 | # Set to True by default. 
The server will automatically fall back if it fails. 57 | USE_FLASH_ATTENTION_2 = False # True 58 | USE_TORCH_COMPILE_FOR_WHISPER = False # Kept False for stability 59 | USE_TORCH_COMPILE_FOR_MLPS = True 60 | USE_FP16 = True 61 | 62 | 63 | class FullEmbeddingMLP(nn.Module): 64 | def __init__( 65 | self, 66 | seq_len: int, 67 | embed_dim: int, 68 | projection_dim: int, 69 | mlp_hidden_dims: List[int], 70 | mlp_dropout_rates: List[float], 71 | ): 72 | super().__init__() 73 | if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1: 74 | raise ValueError(f"Dropout rates length error.") 75 | self.flatten = nn.Flatten() 76 | self.proj = nn.Linear(seq_len * embed_dim, projection_dim) 77 | layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])] 78 | current_dim = projection_dim 79 | for i, h_dim in enumerate(mlp_hidden_dims): 80 | layers.extend( 81 | [ 82 | nn.Linear(current_dim, h_dim), 83 | nn.ReLU(), 84 | nn.Dropout(mlp_dropout_rates[i + 1]), 85 | ] 86 | ) 87 | current_dim = h_dim 88 | layers.append(nn.Linear(current_dim, 1)) 89 | self.mlp = nn.Sequential(*layers) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | if x.ndim == 4 and x.shape[1] == 1: 93 | x = x.squeeze(1) 94 | projected = self.proj(self.flatten(x)) 95 | return self.mlp(projected) 96 | 97 | 98 | def load_models_and_processor(): 99 | logging.info("Starting model loading process...") 100 | logging.info(f"Loading Whisper model: {WHISPER_MODEL_ID}") 101 | whisper_dtype = torch.float16 if USE_FP16 and DEVICE != "cpu" else torch.float32 102 | whisper_model = None 103 | if USE_FLASH_ATTENTION_2 and DEVICE != "cpu": 104 | try: 105 | logging.info("Attempting to load Whisper model with Flash Attention 2...") 106 | whisper_model = WhisperForConditionalGeneration.from_pretrained( 107 | WHISPER_MODEL_ID, 108 | torch_dtype=whisper_dtype, 109 | low_cpu_mem_usage=True, 110 | use_safetensors=True, 111 | attn_implementation="flash_attention_2", 112 | ).to(DEVICE) 113 | logging.info("Successfully loaded Whisper model with Flash Attention 2.") 114 | except Exception as e: 115 | logging.warning( 116 | f"Flash Attention 2 is enabled but failed to load. Error: {e}" 117 | ) 118 | logging.warning("FALLING BACK to the default 'sdpa' attention mechanism.") 119 | if whisper_model is None: 120 | whisper_model = WhisperForConditionalGeneration.from_pretrained( 121 | WHISPER_MODEL_ID, 122 | torch_dtype=whisper_dtype, 123 | low_cpu_mem_usage=True, 124 | use_safetensors=True, 125 | attn_implementation="sdpa", 126 | ).to(DEVICE) 127 | logging.info("Whisper model loaded successfully with default attention.") 128 | whisper_model.eval() 129 | 130 | whisper_processor = AutoProcessor.from_pretrained(WHISPER_MODEL_ID) 131 | logging.info( 132 | f"Downloading MLP models from {HF_MLP_REPO_ID} to {LOCAL_MLP_MODELS_DIR}..." 
133 | ) 134 | snapshot_download( 135 | repo_id=HF_MLP_REPO_ID, 136 | local_dir=LOCAL_MLP_MODELS_DIR, 137 | local_dir_use_symlinks=False, 138 | ignore_patterns=["*.mp3", "*.md", ".gitattributes"], 139 | ) 140 | mlp_models = {} 141 | mlp_files = list(LOCAL_MLP_MODELS_DIR.glob("*.pth")) 142 | logging.info(f"Found {len(mlp_files)} MLP model files.") 143 | for model_path in mlp_files: 144 | filename = model_path.stem 145 | parts = filename.split("_") 146 | dimension_name = ( 147 | "_".join(parts[1:-1]) if "best" in parts[-1] else "_".join(parts[1:]) 148 | ) 149 | mlp_model = FullEmbeddingMLP( 150 | seq_len=WHISPER_SEQ_LEN, 151 | embed_dim=WHISPER_EMBED_DIM, 152 | projection_dim=PROJECTION_DIM, 153 | mlp_hidden_dims=MLP_HIDDEN_DIMS, 154 | mlp_dropout_rates=MLP_DROPOUTS, 155 | ).to(DEVICE) 156 | state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True) 157 | if any(k.startswith("_orig_mod.") for k in state_dict.keys()): 158 | state_dict = OrderedDict( 159 | (k.replace("_orig_mod.", ""), v) for k, v in state_dict.items() 160 | ) 161 | mlp_model.load_state_dict(state_dict) 162 | mlp_model.eval() 163 | if USE_FP16 and DEVICE != "cpu": 164 | mlp_model.half() 165 | if USE_TORCH_COMPILE_FOR_MLPS: 166 | logging.info(f"Compiling MLP for: {dimension_name}...") 167 | mlp_model = torch.compile(mlp_model, mode="reduce-overhead", fullgraph=True) 168 | mlp_models[dimension_name] = mlp_model 169 | logging.info("All models loaded and optimized successfully.") 170 | return whisper_processor, whisper_model, mlp_models 171 | 172 | 173 | class BatchInferenceManager: 174 | def __init__(self, whisper_processor, whisper_model, mlp_models): 175 | self.whisper_processor = whisper_processor 176 | self.whisper_model = whisper_model 177 | self.mlp_models = mlp_models 178 | 179 | def process_batch(self, audio_paths: List[Path]) -> List[Dict[str, Any]]: 180 | if not audio_paths: 181 | return [] 182 | processed_audios = [self._load_and_process_audio(p) for p in audio_paths] 183 | valid_audios = [a for a in processed_audios if a is not None] 184 | if not valid_audios: 185 | return [{"error": "Invalid audio"} for _ in audio_paths] 186 | 187 | with torch.no_grad(): 188 | inputs = self.whisper_processor( 189 | [a["waveform"] for a in valid_audios], 190 | sampling_rate=SAMPLING_RATE, 191 | return_tensors="pt", 192 | padding="max_length", 193 | truncation=True, 194 | ).to(DEVICE, non_blocking=True) 195 | if USE_FP16 and DEVICE != "cpu": 196 | inputs["input_features"] = inputs["input_features"].to( 197 | dtype=self.whisper_model.dtype 198 | ) 199 | 200 | # --- EFFICIENT DUAL INFERENCE --- 201 | # 1. Run the encoder ONLY ONCE 202 | encoder_outputs = self.whisper_model.get_encoder()( 203 | inputs.input_features, return_dict=True 204 | ) 205 | 206 | # 2. Use the encoder output for transcription 207 | predicted_ids = self.whisper_model.generate(encoder_outputs=encoder_outputs) 208 | captions = self.whisper_processor.batch_decode( 209 | predicted_ids, skip_special_tokens=True 210 | ) 211 | 212 | # 3. 
Use the SAME encoder output for emotion analysis 213 | embeddings = encoder_outputs.last_hidden_state 214 | 215 | batch_results = [] 216 | for i in range(embeddings.size(0)): 217 | # MLP processing 218 | single_embedding = embeddings[i : i + 1] 219 | emotion_predictions = { 220 | dim: mlp(single_embedding).item() 221 | for dim, mlp in self.mlp_models.items() 222 | } 223 | 224 | # Assemble the final result with the caption 225 | batch_results.append( 226 | { 227 | "file": valid_audios[i]["path"].name, 228 | "caption": captions[i].strip(), 229 | "emotions": emotion_predictions, 230 | } 231 | ) 232 | 233 | final_results = [] 234 | result_map = {res["file"]: res for res in batch_results} 235 | for path in audio_paths: 236 | final_results.append( 237 | result_map.get( 238 | path.name, {"file": path.name, "error": "Processing failed"} 239 | ) 240 | ) 241 | return final_results 242 | 243 | def _load_and_process_audio(self, audio_path: Path): 244 | try: 245 | waveform, _ = librosa.load( 246 | audio_path, sr=SAMPLING_RATE, mono=True, duration=MAX_AUDIO_SECONDS 247 | ) 248 | return {"waveform": waveform, "path": audio_path} 249 | except Exception as e: 250 | logging.error(f"Failed to load audio file {audio_path}: {e}") 251 | return None 252 | 253 | 254 | # --- DynamicBatcher and FastAPI App (no changes needed) --- 255 | @dataclass 256 | class PendingRequest: 257 | future: asyncio.Future 258 | temp_path: Path 259 | enqueue_time: float = field(default_factory=time.time) 260 | 261 | 262 | class DynamicBatcher: 263 | def __init__( 264 | self, 265 | inference_manager: BatchInferenceManager, 266 | batch_size: int, 267 | max_wait_time: float, 268 | ): 269 | self.inference_manager, self.batch_size, self.max_wait_time = ( 270 | inference_manager, 271 | batch_size, 272 | max_wait_time, 273 | ) 274 | self.queue: List[PendingRequest] = [] 275 | self.lock = threading.Lock() 276 | self.shutdown_event = threading.Event() 277 | self.processing_thread = threading.Thread( 278 | target=self._batching_loop, daemon=True 279 | ) 280 | self.loop = None 281 | 282 | def start(self): 283 | self.loop = asyncio.get_running_loop() 284 | self.processing_thread.start() 285 | logging.info( 286 | f"Dynamic batcher started with batch_size={self.batch_size}, max_wait_time={self.max_wait_time}s" 287 | ) 288 | 289 | def stop(self): 290 | self.shutdown_event.set() 291 | self.processing_thread.join() 292 | 293 | async def add_request(self, temp_path: Path) -> Dict: 294 | future = self.loop.create_future() 295 | request = PendingRequest(future=future, temp_path=temp_path) 296 | with self.lock: 297 | self.queue.append(request) 298 | try: 299 | return await asyncio.wait_for(future, timeout=MAX_AUDIO_SECONDS + 15) 300 | except asyncio.TimeoutError: 301 | raise HTTPException(status_code=504, detail="Request timed out in queue.") 302 | 303 | def _batching_loop(self): 304 | while not self.shutdown_event.is_set(): 305 | batch_to_process = None 306 | with self.lock: 307 | if self.queue and ( 308 | len(self.queue) >= self.batch_size 309 | or time.time() - self.queue[0].enqueue_time > self.max_wait_time 310 | ): 311 | batch_to_process = self.queue[:] 312 | self.queue.clear() 313 | if batch_to_process: 314 | self._process_and_respond(batch_to_process) 315 | else: 316 | time.sleep(0.005) 317 | 318 | def _process_and_respond(self, batch: List[PendingRequest]): 319 | batch_paths = [req.temp_path for req in batch] 320 | logging.info(f"Processing batch of size {len(batch_paths)}...") 321 | try: 322 | results = 
self.inference_manager.process_batch(batch_paths) 323 | result_map = {Path(res["file"]).name: res for res in results} 324 | for request in batch: 325 | result = result_map.get( 326 | request.temp_path.name, {"error": "Processing failed."} 327 | ) 328 | self.loop.call_soon_threadsafe(request.future.set_result, result) 329 | except Exception as e: 330 | logging.error(f"Error processing batch: {e}", exc_info=True) 331 | error_result = {"error": f"An internal server error occurred: {e}"} 332 | for request in batch: 333 | if not request.future.done(): 334 | self.loop.call_soon_threadsafe( 335 | request.future.set_result, error_result 336 | ) 337 | finally: 338 | for path in batch_paths: 339 | try: 340 | path.unlink() 341 | except OSError as e: 342 | logging.warning(f"Could not delete temp file {path}: {e}") 343 | 344 | 345 | @asynccontextmanager 346 | async def lifespan(app: FastAPI): 347 | global batcher 348 | processor, whisper_model, mlps = load_models_and_processor() 349 | inference_manager = BatchInferenceManager(processor, whisper_model, mlps) 350 | batcher = DynamicBatcher(inference_manager, batch_size=16, max_wait_time=0.05) 351 | batcher.start() 352 | yield 353 | batcher.stop() 354 | 355 | 356 | app = FastAPI(title="High-Performance Emotion Annotation API", lifespan=lifespan) 357 | 358 | 359 | @app.post( 360 | "/predict", summary="Analyze a single audio file for emotion and transcription" 361 | ) 362 | async def predict_emotion(file: UploadFile = File(...)): 363 | try: 364 | suffix = Path(file.filename).suffix 365 | with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: 366 | shutil.copyfileobj(file.file, temp_file) 367 | temp_path = Path(temp_file.name) 368 | except Exception as e: 369 | raise HTTPException( 370 | status_code=500, detail=f"Failed to save uploaded file: {e}" 371 | ) 372 | result = await batcher.add_request(temp_path) 373 | return JSONResponse(content=result) 374 | 375 | 376 | @app.get("/health", summary="Health check endpoint") 377 | async def health_check(): 378 | return {"status": "ok", "device": DEVICE, "models_loaded": batcher is not None} 379 | 380 | 381 | if __name__ == "__main__": 382 | import uvicorn 383 | 384 | uvicorn.run(app, host="0.0.0.0", port=8022) 385 | -------------------------------------------------------------------------------- /emolia-explorer.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Emolia — Ranked Search (Light) 6 | 7 | 44 | 45 | 46 | 47 |

[emolia-explorer.html — the page's markup, styles, and inline scripts did not survive extraction; the recoverable page text is kept below.]

🎧 Emolia — Ranked Search
Choose a numeric column to rank by, add conditions (range or expression), pick languages, set K, and play MP3s.

API base: your fastapi_parquet_rank.py server.
Audio base: root of /mnt/sdc/emolia/audiodata.

Query builder
Rank by: pick any numeric dimension (e.g., Pain_best).
Languages: if "ALL languages" is checked, the request sends no language filter → all languages are included. Otherwise, the server matches language or the 2-letter prefix of speaker/id.
Range condition examples: 1 < Age_best < 2, Arousal_best >= 1.5.
Free expression (NumExpr): allowed are + − × ÷, comparisons, parentheses, boolean &/|. Column names must exist.
Conditions list: shows "No conditions yet." until one is added.

Serve MP3s
Run a tiny HTTP server rooted at /mnt/sdc/emolia/audiodata, then set Audio base above:
PORT=8040
python3 -m http.server "$PORT" --directory /mnt/sdc/emolia/audiodata
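For context, a minimal sketch of how such a ranked search could be answered server-side. The actual fastapi_parquet_rank.py is not part of this snapshot; the parquet path, the language column, and the use of pandas (which delegates query expressions to NumExpr when it is installed) are assumptions based on the UI hints above.

import pandas as pd

def ranked_search(df, rank_by, conditions, languages=None, k=20):
    # Apply each free-form condition; pandas' query() evaluates expressions such as
    # "Arousal_best >= 1.5" or "1 < Age_best < 2" against the numeric columns.
    out = df
    for expr in conditions:
        out = out.query(expr)
    if languages:                    # an omitted/empty list means "ALL languages"
        out = out[out["language"].isin(languages)]
    return out.nlargest(k, rank_by)  # top-K rows by the chosen numeric column

# Hypothetical usage:
# df = pd.read_parquet("emolia.parquet")
# top = ranked_search(df, rank_by="Pain_best", conditions=["Arousal_best >= 1.5"], k=10)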
377 | 378 | 379 | 380 | -------------------------------------------------------------------------------- /annotate_audio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import time 5 | import os 6 | from pathlib import Path 7 | from typing import List, Dict, Any 8 | from collections import OrderedDict 9 | import multiprocessing as mp 10 | 11 | # --- Core ML/AI Libraries --- 12 | import torch 13 | import torch.distributed as dist 14 | import torch.nn as nn 15 | import librosa 16 | from transformers import AutoProcessor, WhisperForConditionalGeneration 17 | from huggingface_hub import snapshot_download 18 | from tqdm import tqdm 19 | 20 | # --- Configuration Section --- 21 | # Set up logging to be clean and informative 22 | logging.basicConfig( 23 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 24 | ) 25 | 26 | # --- Performance & Model Configuration --- 27 | # Multi-GPU configuration 28 | NUM_GPUS = torch.cuda.device_count() 29 | WORLD_SIZE = NUM_GPUS if NUM_GPUS > 0 else 1 30 | USE_DDP = NUM_GPUS > 1 31 | 32 | # Enable optimizations 33 | USE_FLASH_ATTENTION_2 = True # Enable Flash Attention 2 34 | USE_TORCH_COMPILE = False # Enable model compilation 35 | USE_FP16 = True # Use half-precision 36 | BASE_BATCH_SIZE = 16 # Base batch size per GPU 37 | MAX_AUDIO_LENGTH = 30 # Max audio length in seconds 38 | 39 | # Model and Audio Parameters 40 | WHISPER_MODEL_ID = "mkrausio/EmoWhisper-AnS-Small-v0.1" 41 | HF_MLP_REPO_ID = "laion/Empathic-Insight-Voice-Small" 42 | LOCAL_MODELS_DIR = Path("./models_cache") # Dedicated directory for downloaded models 43 | SAMPLING_RATE = 16000 44 | 45 | # Parameters required for the MLP architecture 46 | WHISPER_SEQ_LEN = 1500 47 | WHISPER_EMBED_DIM = 768 48 | PROJECTION_DIM = 64 49 | MLP_HIDDEN_DIMS = [64, 32, 16] 50 | MLP_DROPOUTS = [0.0, 0.1, 0.1, 0.1] 51 | 52 | # List of supported audio file extensions 53 | SUPPORTED_AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".m4a", ".ogg"} 54 | 55 | 56 | class FullEmbeddingMLP(nn.Module): 57 | """The defined architecture for the emotion/attribute MLP classifiers.""" 58 | 59 | def __init__( 60 | self, 61 | seq_len: int, 62 | embed_dim: int, 63 | projection_dim: int, 64 | mlp_hidden_dims: List[int], 65 | mlp_dropout_rates: List[float], 66 | ): 67 | super().__init__() 68 | self.flatten = nn.Flatten() 69 | self.proj = nn.Linear(seq_len * embed_dim, projection_dim) 70 | layers = [nn.ReLU(), nn.Dropout(mlp_dropout_rates[0])] 71 | current_dim = projection_dim 72 | for i, h_dim in enumerate(mlp_hidden_dims): 73 | layers.extend( 74 | [ 75 | nn.Linear(current_dim, h_dim), 76 | nn.ReLU(), 77 | nn.Dropout(mlp_dropout_rates[i + 1]), 78 | ] 79 | ) 80 | current_dim = h_dim 81 | layers.append(nn.Linear(current_dim, 1)) 82 | self.mlp = nn.Sequential(*layers) 83 | 84 | def forward(self, x: torch.Tensor) -> torch.Tensor: 85 | if x.ndim == 4 and x.shape[1] == 1: 86 | x = x.squeeze(1) 87 | projected = self.proj(self.flatten(x)) 88 | return self.mlp(projected) 89 | 90 | 91 | def load_models_and_processor(device: torch.device, use_fp16: bool = True): 92 | """ 93 | Loads all required models (Whisper and MLPs) from Hugging Face Hub into memory. 94 | This function is called once per process in DDP mode. 
95 | """ 96 | rank = dist.get_rank() if USE_DDP else 0 97 | if rank == 0: 98 | logging.info("Starting model loading process...") 99 | 100 | LOCAL_MODELS_DIR.mkdir(exist_ok=True, parents=True) 101 | 102 | # Synchronize processes before model loading 103 | if USE_DDP: 104 | dist.barrier() 105 | 106 | # --- Robust Whisper Model Loading with Flash Attention Fallback --- 107 | if rank == 0: 108 | logging.info(f"Loading Whisper model: {WHISPER_MODEL_ID}") 109 | 110 | whisper_dtype = ( 111 | torch.float16 if use_fp16 and device.type == "cuda" else torch.float32 112 | ) 113 | whisper_model = None 114 | 115 | if USE_FLASH_ATTENTION_2 and device.type == "cuda": 116 | try: 117 | if rank == 0: 118 | logging.info( 119 | "Attempting to load Whisper model with Flash Attention 2..." 120 | ) 121 | whisper_model = WhisperForConditionalGeneration.from_pretrained( 122 | WHISPER_MODEL_ID, 123 | torch_dtype=whisper_dtype, 124 | cache_dir=LOCAL_MODELS_DIR, 125 | use_safetensors=True, 126 | attn_implementation="flash_attention_2", 127 | ) 128 | if rank == 0: 129 | logging.info( 130 | "Successfully loaded Whisper model with Flash Attention 2." 131 | ) 132 | except (ValueError, ImportError) as e: 133 | if rank == 0: 134 | logging.warning(f"Flash Attention 2 failed: {e}") 135 | logging.warning("Falling back to default attention mechanism.") 136 | 137 | if whisper_model is None: 138 | whisper_model = WhisperForConditionalGeneration.from_pretrained( 139 | WHISPER_MODEL_ID, 140 | torch_dtype=whisper_dtype, 141 | cache_dir=LOCAL_MODELS_DIR, 142 | use_safetensors=True, 143 | attn_implementation="sdpa", 144 | ) 145 | if rank == 0: 146 | logging.info("Whisper model loaded with default attention.") 147 | 148 | whisper_model.to(device) 149 | whisper_model.eval() 150 | 151 | # Compile model for better performance 152 | if USE_TORCH_COMPILE and device.type == "cuda": 153 | try: 154 | whisper_model = torch.compile(whisper_model, mode="reduce-overhead") 155 | if rank == 0: 156 | logging.info("Whisper model compiled with torch.compile") 157 | except Exception as e: 158 | if rank == 0: 159 | logging.warning(f"Model compilation failed: {e}") 160 | 161 | # Load processor 162 | whisper_processor = AutoProcessor.from_pretrained( 163 | WHISPER_MODEL_ID, cache_dir=LOCAL_MODELS_DIR 164 | ) 165 | 166 | # --- Load MLP Models --- 167 | mlp_models_dir = LOCAL_MODELS_DIR / "empathic_insight_models" 168 | if rank == 0: 169 | logging.info( 170 | f"Downloading MLP models from {HF_MLP_REPO_ID} to {mlp_models_dir}..." 
171 | ) 172 | snapshot_download( 173 | repo_id=HF_MLP_REPO_ID, 174 | local_dir=mlp_models_dir, 175 | local_dir_use_symlinks=False, 176 | ignore_patterns=["*.mp3", "*.md", ".gitattributes"], 177 | ) 178 | 179 | # Synchronize after download 180 | if USE_DDP: 181 | dist.barrier() 182 | 183 | mlp_models = {} 184 | mlp_files = list(mlp_models_dir.glob("*.pth")) 185 | 186 | if rank == 0: 187 | logging.info(f"Found {len(mlp_files)} MLP model files.") 188 | 189 | for model_path in mlp_files: 190 | filename = model_path.stem 191 | parts = filename.split("_") 192 | dimension_name = ( 193 | "_".join(parts[1:-1]) if "best" in parts[-1] else "_".join(parts[1:]) 194 | ) 195 | 196 | mlp_model = FullEmbeddingMLP( 197 | seq_len=WHISPER_SEQ_LEN, 198 | embed_dim=WHISPER_EMBED_DIM, 199 | projection_dim=PROJECTION_DIM, 200 | mlp_hidden_dims=MLP_HIDDEN_DIMS, 201 | mlp_dropout_rates=MLP_DROPOUTS, 202 | ).to(device) 203 | 204 | state_dict = torch.load(model_path, map_location=device) 205 | if any(k.startswith("_orig_mod.") for k in state_dict.keys()): 206 | state_dict = OrderedDict( 207 | (k.replace("_orig_mod.", ""), v) for k, v in state_dict.items() 208 | ) 209 | 210 | mlp_model.load_state_dict(state_dict) 211 | mlp_model.eval() 212 | if use_fp16 and device.type == "cuda": 213 | mlp_model.half() 214 | 215 | # Compile MLP model 216 | if USE_TORCH_COMPILE and device.type == "cuda": 217 | try: 218 | mlp_model = torch.compile(mlp_model) 219 | except Exception as e: 220 | logging.warning(f"MLP compilation failed: {e}") 221 | 222 | mlp_models[dimension_name] = mlp_model 223 | 224 | if rank == 0: 225 | logging.info("All models loaded successfully.") 226 | 227 | return whisper_processor, whisper_model, mlp_models 228 | 229 | 230 | class InferenceProcessor: 231 | """A class to manage the loaded models and perform efficient batch inference.""" 232 | 233 | def __init__(self, whisper_processor, whisper_model, mlp_models, device): 234 | self.whisper_processor = whisper_processor 235 | self.whisper_model = whisper_model 236 | self.mlp_models = mlp_models 237 | self.device = device 238 | 239 | def process_batch(self, audio_paths: List[Path]) -> List[Dict[str, Any]]: 240 | """Processes a batch of audio files, generating transcription and emotion scores.""" 241 | if not audio_paths: 242 | return [] 243 | 244 | # Load audio waveforms for the current batch 245 | processed_audios = [self._load_audio(p) for p in audio_paths] 246 | valid_audios = [a for a in processed_audios if a is not None] 247 | if not valid_audios: 248 | return [{"error": "Invalid audio"} for _ in audio_paths] 249 | 250 | with torch.no_grad(): 251 | # Pre-process the entire batch of audio 252 | inputs = self.whisper_processor( 253 | [a["waveform"] for a in valid_audios], 254 | sampling_rate=SAMPLING_RATE, 255 | return_tensors="pt", 256 | padding="max_length", 257 | truncation=True, 258 | ).to(self.device, non_blocking=True) 259 | 260 | if USE_FP16 and self.device.type == "cuda": 261 | inputs["input_features"] = inputs["input_features"].to( 262 | dtype=self.whisper_model.dtype 263 | ) 264 | 265 | # --- EFFICIENT DUAL INFERENCE --- 266 | # 1. Run the encoder ONLY ONCE 267 | encoder_outputs = self.whisper_model.get_encoder()( 268 | inputs.input_features, return_dict=True 269 | ) 270 | 271 | # 2. Use the encoder output for transcription (decoding) 272 | predicted_ids = self.whisper_model.generate(encoder_outputs=encoder_outputs) 273 | captions = self.whisper_processor.batch_decode( 274 | predicted_ids, skip_special_tokens=True 275 | ) 276 | 277 | # 3. 
Use the SAME encoder output's hidden state for emotion analysis 278 | embeddings = encoder_outputs.last_hidden_state 279 | 280 | batch_results = [] 281 | for i in range(embeddings.size(0)): 282 | single_embedding = embeddings[i : i + 1] 283 | emotion_predictions = {} 284 | for dim, mlp in self.mlp_models.items(): 285 | # Ensure consistent precision 286 | if USE_FP16 and self.device.type == "cuda": 287 | single_embedding = single_embedding.half() 288 | emotion_predictions[dim] = mlp(single_embedding).item() 289 | 290 | # Assemble the final result with the caption and emotion scores 291 | batch_results.append( 292 | { 293 | "audio_file": str( 294 | valid_audios[i]["path"] 295 | ), # Store full path for reference 296 | "caption": captions[i].strip(), 297 | "emotions": emotion_predictions, 298 | } 299 | ) 300 | 301 | # Map results back to the original input paths to handle any loading errors 302 | final_results_map = {res["audio_file"]: res for res in batch_results} 303 | return [ 304 | final_results_map.get( 305 | str(p), {"audio_file": str(p), "error": "Processing failed"} 306 | ) 307 | for p in audio_paths 308 | ] 309 | 310 | def _load_audio(self, audio_path: Path): 311 | try: 312 | waveform, _ = librosa.load( 313 | audio_path, sr=SAMPLING_RATE, mono=True, duration=MAX_AUDIO_LENGTH 314 | ) 315 | return {"waveform": waveform, "path": audio_path} 316 | except Exception as e: 317 | logging.error(f"Failed to load audio file {audio_path}: {e}") 318 | return None 319 | 320 | 321 | def setup_ddp(rank, world_size): 322 | """Initialize distributed processing with proper device assignment""" 323 | os.environ["MASTER_ADDR"] = "localhost" 324 | os.environ["MASTER_PORT"] = "12355" 325 | 326 | # Explicitly set the CUDA device before initializing DDP 327 | torch.cuda.set_device(rank) 328 | 329 | # Initialize process group 330 | dist.init_process_group( 331 | backend="nccl", rank=rank, world_size=world_size, init_method="env://" 332 | ) 333 | logging.info(f"Rank {rank} initialized on device cuda:{rank}") 334 | 335 | 336 | def cleanup_ddp(): 337 | dist.destroy_process_group() 338 | 339 | 340 | def process_worker(rank, world_size, input_folder, file_chunk): 341 | """Worker process for DDP execution""" 342 | # Set device explicitly at the start 343 | device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") 344 | if device.type == "cuda": 345 | torch.cuda.set_device(rank) 346 | 347 | # Initialize DDP if using multi-GPU 348 | if USE_DDP and world_size > 1: 349 | try: 350 | setup_ddp(rank, world_size) 351 | except RuntimeError as e: 352 | logging.error(f"Rank {rank} failed to initialize DDP: {e}") 353 | return 354 | 355 | try: 356 | # Calculate dynamic batch size based on GPU memory 357 | if device.type == "cuda": 358 | total_mem = torch.cuda.get_device_properties(device).total_memory 359 | available_mem = total_mem * 0.8 # Use 80% of available memory 360 | batch_size = max(1, int(available_mem / (3 * 1024**3))) # ~3GB per batch 361 | else: 362 | batch_size = BASE_BATCH_SIZE 363 | 364 | if rank == 0: 365 | logging.info(f"Using batch size: {batch_size} per GPU") 366 | 367 | # Load models 368 | whisper_processor, whisper_model, mlp_models = load_models_and_processor( 369 | device, USE_FP16 370 | ) 371 | inference_processor = InferenceProcessor( 372 | whisper_processor, whisper_model, mlp_models, device 373 | ) 374 | 375 | # Process file chunk 376 | with tqdm(total=len(file_chunk), desc=f"GPU {rank}", position=rank) as pbar: 377 | for i in range(0, len(file_chunk), batch_size): 378 | 
batch_paths = file_chunk[i : i + batch_size] 379 | batch_results = inference_processor.process_batch(batch_paths) 380 | 381 | # Save results 382 | for idx, result in enumerate(batch_results): 383 | audio_path = batch_paths[idx] # Original audio file path 384 | json_path = audio_path.with_suffix( 385 | ".json" 386 | ) # JSON in same directory 387 | 388 | # Skip saving if error occurred 389 | if "error" in result: 390 | continue 391 | 392 | # Load existing data for merging 393 | existing_data = {} 394 | if json_path.exists(): 395 | try: 396 | with open(json_path, "r", encoding="utf-8") as f: 397 | existing_data = json.load(f) 398 | except Exception as e: 399 | logging.warning(f"Error reading {json_path}: {e}") 400 | 401 | # Update with new results 402 | existing_data.update( 403 | { 404 | "caption": result.get("caption"), 405 | "emotions": result.get("emotions"), 406 | "source_audio_file": os.path.basename( 407 | result.get("audio_file") 408 | ), 409 | } 410 | ) 411 | 412 | # Ensure directory exists 413 | json_path.parent.mkdir(parents=True, exist_ok=True) 414 | 415 | # Write with JSON 416 | with open(json_path, "w", encoding="utf-8") as f: 417 | json.dump(existing_data, f, indent=4) 418 | 419 | pbar.update(len(batch_paths)) 420 | except Exception as e: 421 | logging.error(f"Error in worker process {rank}: {e}") 422 | finally: 423 | # Cleanup DDP 424 | if USE_DDP and world_size > 1: 425 | cleanup_ddp() 426 | 427 | 428 | def process_folder(input_folder: Path): 429 | """ 430 | Main orchestration function. Scans a folder, splits files across GPUs, 431 | and spawns worker processes for parallel processing. 432 | """ 433 | start_time = time.time() 434 | logging.info(f"Available GPUs: {NUM_GPUS}") 435 | logging.info(f"Using DDP: {'Yes' if USE_DDP else 'No'}") 436 | logging.info( 437 | f"Using optimizations: FlashAttention2={USE_FLASH_ATTENTION_2}, TorchCompile={USE_TORCH_COMPILE}" 438 | ) 439 | 440 | # 1. Scan for all supported audio files recursively. 441 | logging.info(f"Scanning for audio files in '{input_folder}'...") 442 | all_audio_files = [ 443 | p 444 | for p in input_folder.rglob("*") 445 | if p.suffix.lower() in SUPPORTED_AUDIO_EXTENSIONS 446 | ] 447 | 448 | if not all_audio_files: 449 | logging.warning("No supported audio files found.") 450 | return 451 | 452 | logging.info(f"Found {len(all_audio_files)} audio files. Checking which need processing...") 453 | 454 | # 2. Filter files that need processing, with a progress bar 455 | files_to_process = [] 456 | for audio_path in tqdm(all_audio_files, desc="Checking existing files"): 457 | json_path = audio_path.with_suffix(".json") 458 | if not json_path.exists(): 459 | files_to_process.append(audio_path) 460 | continue 461 | 462 | try: 463 | with open(json_path, "r", encoding="utf-8") as f: 464 | data = json.load(f) 465 | if ( 466 | "caption" not in data 467 | or "emotions" not in data 468 | or len(data.get("emotions", {})) < 55 469 | ): 470 | files_to_process.append(audio_path) 471 | except Exception: 472 | files_to_process.append(audio_path) 473 | 474 | if not files_to_process: 475 | logging.info("All files already processed.") 476 | return 477 | 478 | logging.info(f"Processing {len(files_to_process)} files...") 479 | 480 | # 3. 
Split files across GPUs 481 | chunk_size = -(-len(files_to_process) // WORLD_SIZE)  # ceiling division: every file lands in one of at most WORLD_SIZE chunks, and chunk_size is never 0 482 | file_chunks = [ 483 | files_to_process[i : i + chunk_size] 484 | for i in range(0, len(files_to_process), chunk_size) 485 | ] 486 | # Pad with empty chunks so there is exactly one chunk per rank 487 | while len(file_chunks) < WORLD_SIZE: 488 | file_chunks.append([]) 489 | 490 | # 4. Spawn worker processes 491 | processes = [] 492 | for rank in range(WORLD_SIZE): 493 | if file_chunks[rank]: 494 | p = mp.Process( 495 | target=process_worker, 496 | args=(rank, WORLD_SIZE, input_folder, file_chunks[rank]), 497 | ) 498 | p.start() 499 | processes.append(p) 500 | 501 | for p in processes: 502 | p.join() 503 | 504 | duration = time.time() - start_time 505 | processed_count = len(files_to_process) 506 | logging.info(f"Processing complete. Total time: {duration:.2f} seconds") 507 | logging.info(f"Throughput: {processed_count/duration:.2f} files/sec") 508 | logging.info(f"Total files processed: {processed_count}") 509 | 510 | 511 | if __name__ == "__main__": 512 | # Set multiprocessing start method 513 | mp.set_start_method("spawn", force=True) 514 | 515 | parser = argparse.ArgumentParser( 516 | description="Multi-GPU audio processing for transcription and emotion analysis", 517 | formatter_class=argparse.RawTextHelpFormatter, 518 | ) 519 | parser.add_argument( 520 | "input_folder", type=str, help="Path to folder containing audio files" 521 | ) 522 | args = parser.parse_args() 523 | 524 | input_path = Path(args.input_folder) 525 | if not input_path.is_dir(): 526 | logging.error(f"Invalid directory: {input_path}") 527 | else: 528 | process_folder(input_path) 529 | --------------------------------------------------------------------------------
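A quick sketch of consuming the sidecar JSON files that process_worker writes next to each audio clip; the folder and file name below are hypothetical.

import json
from pathlib import Path

# Produced by: python annotate_audio.py /path/to/audio_folder
# Each sidecar holds "caption", "emotions" (dimension -> score), and "source_audio_file",
# merged into whatever metadata the JSON file already contained.
sidecar = Path("/path/to/audio_folder/example_clip.json")
data = json.loads(sidecar.read_text(encoding="utf-8"))
top5 = sorted(data["emotions"].items(), key=lambda kv: kv[1], reverse=True)[:5]
print(data["caption"], top5)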