├── .github └── workflows │ └── docker-build-push.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── config.py ├── dia ├── __init__.py ├── audio.py ├── config.py ├── layers.py ├── model.py └── state.py ├── docker-compose.yml ├── documentation.md ├── download_model.py ├── engine.py ├── env.example.txt ├── models.py ├── reference_audio ├── Gianna.txt ├── Gianna.wav ├── Oliver_Luna.txt ├── Oliver_Luna.wav ├── Robert.txt └── Robert.wav ├── requirements.txt ├── server.py ├── static ├── screenshot-d.png └── screenshot-l.png ├── ui ├── index.html ├── presets.yaml └── script.js ├── utils.py └── voices ├── Abigail.txt ├── Abigail.wav ├── Abigail_Taylor.txt ├── Abigail_Taylor.wav ├── Adrian.txt ├── Adrian.wav ├── Adrian_Jade.txt ├── Adrian_Jade.wav ├── Alexander.txt ├── Alexander.wav ├── Alexander_Emily.txt ├── Alexander_Emily.wav ├── Alice.txt ├── Alice.wav ├── Austin.txt ├── Austin.wav ├── Austin_Jeremiah.txt ├── Austin_Jeremiah.wav ├── Axel.txt ├── Axel.wav ├── Axel_Miles.txt ├── Axel_Miles.wav ├── Connor.txt ├── Connor.wav ├── Connor_Ryan.txt ├── Connor_Ryan.wav ├── Cora.txt ├── Cora.wav ├── Cora_Gianna.txt ├── Cora_Gianna.wav ├── Elena.txt ├── Elena.wav ├── Elena_Emily.txt ├── Elena_Emily.wav ├── Eli.txt ├── Eli.wav ├── Emily.txt ├── Emily.wav ├── Everett.txt ├── Everett.wav ├── Everett_Jordan.txt ├── Everett_Jordan.wav ├── Gabriel.txt ├── Gabriel.wav ├── Gabriel_Ian.txt ├── Gabriel_Ian.wav ├── Gianna.txt ├── Gianna.wav ├── Henry.txt ├── Henry.wav ├── Ian.txt ├── Ian.wav ├── Jade.txt ├── Jade.wav ├── Jade_Layla.txt ├── Jade_Layla.wav ├── Jeremiah.txt ├── Jeremiah.wav ├── Jordan.txt ├── Jordan.wav ├── Julian.txt ├── Julian.wav ├── Julian_Thomas.txt ├── Julian_Thomas.wav ├── Layla.txt ├── Layla.wav ├── Leonardo.txt ├── Leonardo.wav ├── Leonardo_Olivia.txt ├── Leonardo_Olivia.wav ├── Michael.txt ├── Michael.wav ├── Michael_Emily.txt ├── Michael_Emily.wav ├── Miles.txt ├── Miles.wav ├── Oliver_Luna.txt ├── Oliver_Luna.wav ├── Olivia.txt ├── Olivia.wav ├── Ryan.txt ├── Ryan.wav ├── Taylor.txt ├── Taylor.wav ├── Thomas.txt └── Thomas.wav /.github/workflows/docker-build-push.yml: -------------------------------------------------------------------------------- 1 | name: Build and Push Docker Image to GHCR 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build-and-push: 12 | runs-on: ubuntu-latest 13 | permissions: 14 | contents: read 15 | packages: write 16 | 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Docker Buildx 22 | uses: docker/setup-buildx-action@v3 23 | 24 | - name: Log in to GitHub Container Registry 25 | uses: docker/login-action@v3 26 | with: 27 | registry: ghcr.io 28 | username: ${{ github.repository_owner }} 29 | password: ${{ secrets.GITHUB_TOKEN }} 30 | 31 | - name: Extract metadata for Docker 32 | id: meta 33 | uses: docker/metadata-action@v5 34 | with: 35 | images: ghcr.io/${{ github.repository_owner }}/dia-tts-server 36 | tags: | 37 | type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', 'main') || github.ref == format('refs/heads/{0}', 'master') }} 38 | type=sha,format=short 39 | - name: Build and push Docker image 40 | uses: docker/build-push-action@v5 41 | with: 42 | context: . 
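          # Note: the options below push the image only for non-pull-request events
          # (pull requests still build as a check) and reuse the GitHub Actions
          # build cache (type=gha) to speed up subsequent builds.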
43 | push: ${{ github.event_name != 'pull_request' }} 44 | tags: ${{ steps.meta.outputs.tags }} 45 | labels: ${{ steps.meta.outputs.labels }} 46 | cache-from: type=gha 47 | cache-to: type=gha,mode=max 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore Python cache and compiled files 2 | *.pyc 3 | *.pyo 4 | *.pyd 5 | __pycache__/ 6 | .ipynb_checkpoints/ 7 | 8 | # Ignore distribution build artifacts 9 | .dist/ 10 | dist/ 11 | build/ 12 | *.egg-info/ 13 | 14 | # Ignore IDE/Editor specific files 15 | .idea/ 16 | .vscode/ 17 | *.swp 18 | *.swo 19 | 20 | # Ignore OS generated files 21 | .DS_Store 22 | Thumbs.db 23 | 24 | # Ignore virtual environment directories 25 | .venv/ 26 | venv/ 27 | */.venv/ 28 | */venv/ 29 | env/ 30 | */env/ 31 | 32 | # Ignore sensitive environment file 33 | .env 34 | 35 | # Ignore generated output and user data directories 36 | outputs/ 37 | reference_audio/ 38 | model_cache/ # Also good practice to ignore the model cache 39 | 40 | # Ignore test reports/coverage 41 | .coverage 42 | htmlcov/ 43 | .pytest_cache/ 44 | 45 | # Ignore log files 46 | *.log 47 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.8.1-runtime-ubuntu22.04 2 | 3 | # Set environment variables 4 | ENV PYTHONDONTWRITEBYTECODE=1 5 | ENV PYTHONUNBUFFERED=1 6 | ENV DEBIAN_FRONTEND=noninteractive 7 | 8 | # Install system dependencies 9 | RUN apt-get update && apt-get install -y --no-install-recommends \ 10 | build-essential \ 11 | libsndfile1 \ 12 | ffmpeg \ 13 | python3 \ 14 | python3-pip \ 15 | python3-dev \ 16 | && apt-get clean \ 17 | && rm -rf /var/lib/apt/lists/* 18 | 19 | # Set up working directory 20 | WORKDIR /app 21 | 22 | # Install Python dependencies 23 | COPY requirements.txt . 24 | RUN pip3 install --no-cache-dir -r requirements.txt 25 | 26 | # Copy application code 27 | COPY . . 28 | 29 | # Create required directories 30 | RUN mkdir -p model_cache reference_audio outputs voices 31 | 32 | # Expose the port the application will run on (default to 8003 as per config) 33 | EXPOSE 8003 34 | 35 | # Command to run the application 36 | CMD ["python3", "server.py"] 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 devnen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Dia TTS Server: OpenAI-Compatible API with Web UI, Large Text Handling & Built-in Voices
2 | 
3 | **Self-host the powerful [Nari Labs Dia TTS model](https://github.com/nari-labs/dia) with this enhanced FastAPI server! Features an intuitive Web UI, flexible API endpoints (including OpenAI-compatible `/v1/audio/speech`), support for realistic dialogue (`[S1]`/`[S2]`), improved voice cloning, large text processing via intelligent chunking, and consistent, reproducible output using 43 built-in, ready-to-use voices and a generation seed feature.**
4 | 
5 | The latest Dia-TTS-Server version offers improved speed and lower VRAM usage. It defaults to efficient BF16 SafeTensors weights for a smaller memory footprint and faster inference, with support for the original `.pth` weights. Runs accelerated on NVIDIA GPUs (CUDA) with CPU fallback.
6 | 
7 | ➡️ **Announcing our new TTS project:** Explore the Chatterbox TTS Server and its features: [https://github.com/devnen/Chatterbox-TTS-Server](https://github.com/devnen/Chatterbox-TTS-Server)
8 | 
9 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge)](LICENSE)
10 | [![Python Version](https://img.shields.io/badge/Python-3.10+-blue.svg?style=for-the-badge)](https://www.python.org/downloads/)
11 | [![Framework](https://img.shields.io/badge/Framework-FastAPI-green.svg?style=for-the-badge)](https://fastapi.tiangolo.com/)
12 | [![Model Format](https://img.shields.io/badge/Weights-SafeTensors%20/%20pth-orange.svg?style=for-the-badge)](https://github.com/huggingface/safetensors)
13 | [![Docker](https://img.shields.io/badge/Docker-Supported-blue.svg?style=for-the-badge)](https://www.docker.com/)
14 | [![Web UI](https://img.shields.io/badge/Web_UI-Included-4285F4?style=for-the-badge&logo=googlechrome&logoColor=white)](#)
15 | [![CUDA Compatible](https://img.shields.io/badge/CUDA-Compatible-76B900?style=for-the-badge&logo=nvidia&logoColor=white)](https://developer.nvidia.com/cuda-zone)
16 | [![API](https://img.shields.io/badge/OpenAI_Compatible_API-Ready-000000?style=for-the-badge&logo=openai&logoColor=white)](https://platform.openai.com/docs/api-reference)
17 | 
18 | 
19 | ![Dia TTS Server Web UI - Dark Mode](static/screenshot-d.png)
20 | ![Dia TTS Server Web UI - Light Mode](static/screenshot-l.png)
21 | 
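As a quick taste, here is how a request to the OpenAI-compatible endpoint looks once the server is running. This is a minimal sketch using the request fields documented in the Usage section below and the default port 8003; adjust the host, port, and values to your setup.

```bash
curl -X POST http://localhost:8003/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{
        "input": "[S1] Dia is running locally. [S2] Great, generate some speech! (laughs)",
        "voice": "dialogue",
        "response_format": "wav",
        "speed": 1.0,
        "seed": 42
      }' \
  --output output.wav
```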
22 | 23 | --- 24 | 25 | ## 🗣️ Overview: Enhanced Dia TTS Access 26 | 27 | The original [Dia 1.6B TTS model by Nari Labs](https://github.com/nari-labs/dia) provides incredible capabilities for generating realistic dialogue, complete with speaker turns and non-verbal sounds like `(laughs)` or `(sighs)`. This project builds upon that foundation by providing a robust **[FastAPI](https://fastapi.tiangolo.com/) server** that makes Dia significantly easier to use and integrate. 28 | 29 | We solve the complexity of setting up and running the model by offering: 30 | 31 | * An **OpenAI-compatible API endpoint**, allowing you to use Dia TTS with tools expecting OpenAI's API structure. 32 | * A **modern Web UI** for easy experimentation, preset loading, reference audio management, and generation parameter tuning. The interface design draws inspiration from **[Lex-au's Orpheus-FastAPI project](https://github.com/Lex-au/Orpheus-FastAPI)**, adapting its intuitive layout and user experience for Dia TTS. 33 | * **Large Text Handling:** Intelligently splits long text inputs into manageable chunks based on sentence structure and speaker tags, processes them sequentially, and seamlessly concatenates the audio. 34 | * **Predefined Voices:** Select from 43 curated, ready-to-use synthetic voices for consistent and reliable output without cloning setup. 35 | * **Improved Voice Cloning:** Enhanced pipeline with automatic audio processing and transcript handling (local `.txt` file or experimental Whisper fallback). 36 | * **Consistent Generation:** Achieve consistent voice output across multiple generations or text chunks by using the "Predefined Voices" or "Voice Cloning" modes, optionally combined with a fixed integer **Seed**. 37 | * Support for both original `.pth` weights and modern, secure **[SafeTensors](https://github.com/huggingface/safetensors)**, defaulting to a **BF16 SafeTensors** version which uses roughly half the VRAM and offers improved speed. 38 | * Automatic **GPU (CUDA) acceleration** detection with fallback to CPU. 39 | * Configuration primarily via `config.yaml`, with `.env` used for initial setup/reset. 40 | * **Docker support** for easy containerized deployment with [Docker](https://www.docker.com/). 41 | 42 | This server is your gateway to leveraging Dia's advanced TTS capabilities seamlessly, now with enhanced stability, voice consistency, and large text support. 43 | 44 | ## ✨ What's New (v1.4.0 vs v1.0.0) 45 | 46 | This version introduces significant improvements and new features: 47 | 48 | **🚀 New Features:** 49 | 50 | * **Large Text Processing (Chunking):** 51 | * Automatically handles long text inputs by intelligently splitting them into smaller chunks based on sentence boundaries and speaker tags (`[S1]`/`[S2]`). 52 | * Processes each chunk individually and seamlessly concatenates the resulting audio, overcoming previous generation limits. 53 | * Configurable via UI toggle ("Split text into chunks") and chunk size slider. 54 | * **Predefined Voices:** 55 | * Added support for using 43 curated, ready-to-use synthetic voices stored in the `./voices` directory. 56 | * Selectable via UI dropdown ("Predefined Voices" mode). Server automatically uses required transcripts. 57 | * Provides reliable voice output without manual cloning setup and avoids potential licensing issues. 58 | * **Enhanced Voice Cloning:** 59 | * Improved backend pipeline for robustness. 60 | * Automatic reference audio processing: mono conversion, resampling to 44.1kHz, truncation (~20s). 
61 | * Automatic transcript handling: Prioritizes local `.txt` file (recommended for accuracy) -> **experimental Whisper generation** if `.txt` is missing. Backend handles transcript prepending automatically. 62 | * Robust reference file finding handles case-insensitivity and extensions. 63 | * **Whisper Integration:** Added `openai-whisper` for automatic transcript generation as an experimental fallback during cloning. Configurable model (`WHISPER_MODEL_NAME` in `config.yaml`). 64 | * **API Enhancements:** 65 | * `/tts` endpoint now supports `transcript` (for explicit clone transcript), `split_text`, `chunk_size`, and `seed`. 66 | * `/v1/audio/speech` endpoint now supports `seed`. 67 | * **Generation Seed:** Added `seed` parameter to UI and API for influencing generation results. Using a fixed integer seed *in combination with* Predefined Voices or Voice Cloning helps maintain consistency across chunks or separate generations. Use -1 for random variation. 68 | * **Terminal Progress:** Generation of long text (using chunking) now displays a `tqdm` progress bar in the server's terminal window. 69 | * **UI Configuration Management:** Added UI section to view/edit `config.yaml` settings and save generation defaults. 70 | * **Configuration System:** Migrated to `config.yaml` for primary runtime configuration, managed via `config.py`. `.env` is now used mainly for initial seeding or resetting defaults. 71 | 72 | **🔧 Fixes & Enhancements:** 73 | 74 | * **VRAM Usage Fixed & Optimized:** Resolved memory leaks during inference and significantly reduced VRAM usage (approx. 14GB+ down to ~7GB) through code optimizations, fixing memory leaks, and BF16 default. 75 | * **Performance:** Significant speed improvements reported (approaching 95% real-time on tested hardware: AMD Ryzen 9 9950X3D + NVIDIA RTX 3090). 76 | * **Audio Post-Processing:** Automatically applies silence trimming (leading/trailing), internal silence reduction, and unvoiced segment removal (using Parselmouth) to improve audio quality and remove artifacts. 77 | * **UI State Persistence:** Web UI now saves/restores text input, voice mode selection, file selections, and generation parameters (seed, chunking, sliders) in `config.yaml`. 78 | * **UI Improvements:** Better loading indicators (shows chunk processing), refined chunking controls, seed input field, theme toggle, dynamic preset loading from `ui/presets.yaml`, warning modals for chunking/generation quality. 79 | * **Cloning Workflow:** Backend now handles transcript prepending automatically. UI workflow simplified (user selects file, enters target text). 80 | * **Dependency Management:** Added `tqdm`, `PyYAML`, `openai-whisper`, `parselmouth` to `requirements.txt`. 81 | * **Code Refactoring:** Aligned internal engine code with refactored `dia` library structure. Updated `config.py` to use `YamlConfigManager`. 82 | 83 | ## ✅ Features 84 | 85 | * **Core Dia Capabilities (via [Nari Labs Dia](https://github.com/nari-labs/dia)):** 86 | * 🗣️ Generate multi-speaker dialogue using `[S1]` / `[S2]` tags. 87 | * 😂 Include non-verbal sounds like `(laughs)`, `(sighs)`, `(clears throat)`. 88 | * 🎭 Perform voice cloning using reference audio prompts. 89 | * **Enhanced Server & API:** 90 | * ⚡ Built with the high-performance **[FastAPI](https://fastapi.tiangolo.com/)** framework. 91 | * 🤖 **OpenAI-Compatible API Endpoint** (`/v1/audio/speech`) for easy integration (now includes `seed`). 
92 | * ⚙️ **Custom API Endpoint** (`/tts`) exposing all Dia generation parameters (now includes `seed`, `split_text`, `chunk_size`, `transcript`). 93 | * 📄 Interactive API documentation via Swagger UI (`/docs`). 94 | * 🩺 Health check endpoint (`/health`). 95 | * **Advanced Generation Features:** 96 | * 📚 **Large Text Handling:** Intelligently splits long inputs into chunks based on sentences and speaker tags, generates audio for each, and concatenates the results seamlessly. Configurable via `split_text` and `chunk_size`. 97 | * 🎤 **Predefined Voices:** Select from 43 curated, ready-to-use synthetic voices in the `./voices` directory for consistent output without cloning setup. 98 | * ✨ **Improved Voice Cloning:** Robust pipeline with automatic audio processing and transcript handling (local `.txt` or Whisper fallback). Backend handles transcript prepending. 99 | * 🌱 **Consistent Generation:** Use Predefined Voices or Voice Cloning modes, optionally with a fixed integer **Seed**, for consistent voice output across chunks or multiple requests. 100 | * 🔇 **Audio Post-Processing:** Automatic steps to trim silence, fix internal pauses, and remove long unvoiced segments/artifacts. 101 | * **Intuitive Web User Interface:** 102 | * 🖱️ Modern, easy-to-use interface inspired by **[Lex-au's Orpheus-FastAPI project](https://github.com/Lex-au/Orpheus-FastAPI)**. 103 | * 💡 **Presets:** Load example text and settings dynamically from `ui/presets.yaml`. Customize by editing the file. 104 | * 🎤 **Reference Audio Upload:** Easily upload `.wav`/`.mp3` files for voice cloning. 105 | * 🗣️ **Voice Mode Selection:** Choose between Predefined Voices, Voice Cloning, or Random/Dialogue modes. 106 | * 🎛️ **Parameter Control:** Adjust generation settings (CFG Scale, Temperature, Speed, Seed, etc.) via sliders and inputs. 107 | * 💾 **Configuration Management:** View and save server settings (`config.yaml`) and default generation parameters directly in the UI. 108 | * 💾 **Session Persistence:** Remembers your last used settings via `config.yaml`. 109 | * ✂️ **Chunking Controls:** Enable/disable text splitting and adjust approximate chunk size. 110 | * ⚠️ **Warning Modals:** Optional warnings for chunking voice consistency and general generation quality. 111 | * 🌓 **Light/Dark Mode:** Toggle between themes with preference saved locally. 112 | * 🔊 **Audio Player:** Integrated waveform player ([WaveSurfer.js](https://wavesurfer.xyz/)) for generated audio with download option. 113 | * ⏳ **Loading Indicator:** Shows status, including chunk processing information. 114 | * **Flexible & Efficient Model Handling:** 115 | * ☁️ Downloads models automatically from [Hugging Face Hub](https://huggingface.co/). 116 | * 🔒 Supports loading secure **`.safetensors`** weights (default). 117 | * 💾 Supports loading original **`.pth`** weights. 118 | * 🚀 Defaults to **BF16 SafeTensors** for reduced memory footprint (~half size) and potentially faster inference. (Credit: [ttj/dia-1.6b-safetensors](https://huggingface.co/ttj/dia-1.6b-safetensors)) 119 | * 🔄 Easily switch between model formats/versions via `config.yaml`. 120 | * **Performance & Configuration:** 121 | * 💻 **GPU Acceleration:** Automatically uses NVIDIA CUDA if available, falls back to CPU. Optimized VRAM usage (~7GB typical). 122 | * 📊 **Terminal Progress:** Displays `tqdm` progress bar when processing text chunks. 123 | * ⚙️ Primary configuration via `config.yaml`, initial seeding via `.env`. 124 | * 📦 Uses standard Python virtual environments. 
125 | * **Docker Support:** 126 | * 🐳 Containerized deployment via [Docker](https://www.docker.com/) and Docker Compose. 127 | * 🔌 NVIDIA GPU acceleration with Container Toolkit integration. 128 | * 💾 Persistent volumes for models, reference audio, predefined voices, outputs, and config. 129 | * 🚀 One-command setup and deployment (`docker compose up -d`). 130 | 131 | ## 🔩 System Prerequisites 132 | 133 | * **Operating System:** Windows 10/11 (64-bit) or Linux (Debian/Ubuntu recommended). 134 | * **Python:** Version 3.10 or later ([Download](https://www.python.org/downloads/)). 135 | * **Git:** For cloning the repository ([Download](https://git-scm.com/downloads)). 136 | * **Internet:** For downloading dependencies and models. 137 | * **(Optional but HIGHLY Recommended for Performance):** 138 | * **NVIDIA GPU:** CUDA-compatible (Maxwell architecture or newer). Check [NVIDIA CUDA GPUs](https://developer.nvidia.com/cuda-gpus). Optimized VRAM usage (~7GB typical), but more helps. 139 | * **NVIDIA Drivers:** Latest version for your GPU/OS ([Download](https://www.nvidia.com/Download/index.aspx)). 140 | * **CUDA Toolkit:** Compatible version (e.g., 11.8, 12.1) matching the PyTorch build you install. 141 | * **(Linux Only):** 142 | * `libsndfile1`: Audio library needed by `soundfile`. Install via package manager (e.g., `sudo apt install libsndfile1`). 143 | * `ffmpeg`: Required by `openai-whisper`. Install via package manager (e.g., `sudo apt install ffmpeg`). 144 | 145 | ## 💻 Installation and Setup 146 | 147 | Follow these steps carefully to get the server running. 148 | 149 | **1. Clone the Repository** 150 | ```bash 151 | git clone https://github.com/devnen/dia-tts-server.git 152 | cd dia-tts-server 153 | ``` 154 | 155 | **2. Set up Python Virtual Environment** 156 | 157 | Using a virtual environment is crucial! 158 | 159 | * **Windows (PowerShell):** 160 | ```powershell 161 | # In the dia-tts-server directory 162 | python -m venv venv 163 | .\venv\Scripts\activate 164 | # Your prompt should now start with (venv) 165 | ``` 166 | 167 | * **Linux (Bash - Debian/Ubuntu Example):** 168 | ```bash 169 | # Ensure prerequisites are installed 170 | sudo apt update && sudo apt install python3 python3-venv python3-pip libsndfile1 ffmpeg -y 171 | 172 | # In the dia-tts-server directory 173 | python3 -m venv venv 174 | source venv/bin/activate 175 | # Your prompt should now start with (venv) 176 | ``` 177 | 178 | **3. Install Dependencies** 179 | 180 | Make sure your virtual environment is activated (`(venv)` prefix visible). 181 | 182 | ```bash 183 | # Upgrade pip (recommended) 184 | pip install --upgrade pip 185 | 186 | # Install project requirements (includes tqdm, yaml, parselmouth etc.) 187 | pip install -r requirements.txt 188 | ``` 189 | ⭐ **Note:** This installation includes large libraries like PyTorch. The download and installation process may take some time depending on your internet speed and system performance. 190 | 191 | ⭐ **Important:** This installs the *CPU-only* version of PyTorch by default. If you have an NVIDIA GPU, proceed to Step 4 **before** running the server for GPU acceleration. 192 | 193 | **4. NVIDIA Driver and CUDA Setup (for GPU Acceleration)** 194 | 195 | Skip this step if you only have a CPU. 196 | 197 | * **Step 4a: Check/Install NVIDIA Drivers** 198 | * Run `nvidia-smi` in your terminal/command prompt. 199 | * If it works, note the **CUDA Version** listed (e.g., 12.1, 11.8). This is the *maximum* your driver supports. 
200 | * If it fails, download and install the latest drivers from [NVIDIA Driver Downloads](https://www.nvidia.com/Download/index.aspx) and **reboot**. Verify with `nvidia-smi` again. 201 | 202 | * **Step 4b: Install PyTorch with CUDA Support** 203 | * Go to the [Official PyTorch Website](https://pytorch.org/get-started/locally/). 204 | * Use the configuration tool: Select **Stable**, **Windows/Linux**, **Pip**, **Python**, and the **CUDA version** that is **equal to or lower** than the one shown by `nvidia-smi` (e.g., if `nvidia-smi` shows 12.4, choose CUDA 12.1). 205 | * Copy the generated command (it will include `--index-url https://download.pytorch.org/whl/cuXXX`). 206 | * **In your activated `(venv)`:** 207 | ```bash 208 | # Uninstall the CPU version first! 209 | pip uninstall torch torchvision torchaudio -y 210 | 211 | # Paste and run the command copied from the PyTorch website 212 | # Example (replace with your actual command): 213 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 214 | ``` 215 | 216 | * **Step 4c: Verify PyTorch CUDA Installation** 217 | * In your activated `(venv)`, run `python` and execute the following single line: 218 | ```python 219 | import torch; print(f"PyTorch version: {torch.__version__}"); print(f"CUDA available: {torch.cuda.is_available()}"); print(f"Device name: {torch.cuda.get_device_name(0)}") if torch.cuda.is_available() else None; exit() 220 | ``` 221 | * If `CUDA available:` shows `True`, the setup was successful. If `False`, double-check driver installation and the PyTorch install command. 222 | 223 | ## ⚙️ Configuration 224 | 225 | The server now primarily uses `config.yaml` for runtime configuration. 226 | 227 | * **`config.yaml`:** Located in the project root. This file stores all server settings, model paths, generation defaults, and UI state. It is created automatically on the first run if it doesn't exist. **This is the main file to edit for persistent configuration changes.** 228 | * **`.env` File:** Used **only** for the *initial creation* of `config.yaml` if it's missing, or when using the "Reset All Settings" button in the UI. Values in `.env` override hardcoded defaults during this initial seeding/reset process. It is **not** read during normal server operation once `config.yaml` exists. 229 | * **UI Configuration:** The "Server Configuration" and "Generation Parameters" sections in the Web UI allow direct editing and saving of values *into* `config.yaml`. 230 | 231 | **Key Configuration Areas (in `config.yaml` or UI):** 232 | 233 | * `server`: `host`, `port` 234 | * `model`: `repo_id`, `config_filename`, `weights_filename`, `whisper_model_name` 235 | * `paths`: `model_cache`, `reference_audio`, `output`, `voices` (for predefined) 236 | * `generation_defaults`: Default values for sliders/seed in the UI (`speed_factor`, `cfg_scale`, `temperature`, `top_p`, `cfg_filter_top_k`, `seed`, `split_text`, `chunk_size`). 237 | * `ui_state`: Stores the last used text, voice mode, file selections, etc., for UI persistence. 238 | 239 | ⭐ **Remember:** Changes made to `server`, `model`, or `paths` sections in `config.yaml` (or via the UI) **require a server restart** to take effect. Changes to `generation_defaults` or `ui_state` are applied dynamically or on the next page load. 
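For reference, a freshly created `config.yaml` looks roughly like the sketch below. It simply mirrors the hardcoded defaults from `config.py` (before any `.env` overrides or UI changes); your actual file will also contain a `ui_state` section that the server manages automatically.

```yaml
server:
  host: 0.0.0.0
  port: 8003
model:
  repo_id: ttj/dia-1.6b-safetensors
  config_filename: config.json
  weights_filename: dia-v0_1_bf16.safetensors
  whisper_model_name: small.en
paths:
  model_cache: ./model_cache
  reference_audio: ./reference_audio
  output: ./outputs
  voices: ./voices
generation_defaults:
  speed_factor: 1.0
  cfg_scale: 3.0
  temperature: 1.3
  top_p: 0.95
  cfg_filter_top_k: 35
  seed: 42
  split_text: true
  chunk_size: 120
```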
240 | 241 | ## ▶️ Running the Server 242 | 243 | **Note on Model Downloads:** 244 | The first time you run the server (or after changing model settings in `config.yaml`), it will download the required Dia and Whisper model files (~3-7GB depending on selection). Monitor the terminal logs for progress. The server starts fully *after* downloads complete. 245 | 246 | 1. **Activate the virtual environment (if not activated):** 247 | * Linux/macOS: `source venv/bin/activate` 248 | * Windows: `.\venv\Scripts\activate` 249 | 2. **Run the server:** 250 | ```bash 251 | python server.py 252 | ``` 253 | 3. **Access the UI:** The server should automatically attempt to open the Web UI in your default browser after startup. If it doesn't for any reason, manually navigate to `http://localhost:PORT` (e.g., `http://localhost:8003`). 254 | 4. **Access API Docs:** Open `http://localhost:PORT/docs`. 255 | 5. **Stop the server:** Press `CTRL+C` in the terminal. 256 | 257 | --- 258 | 259 | ## 🐳 Docker Installation 260 | 261 | Run Dia TTS Server easily using Docker. The recommended method uses Docker Compose with pre-built images from GitHub Container Registry (GHCR). 262 | 263 | ### Prerequisites 264 | 265 | * [Docker](https://docs.docker.com/get-docker/) installed. 266 | * [Docker Compose](https://docs.docker.com/compose/install/) installed (usually included with Docker Desktop). 267 | * (Optional but Recommended for GPU) NVIDIA GPU with up-to-date drivers and the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) installed. 268 | 269 | ### Option 1: Using Docker Compose (Recommended) 270 | 271 | This method uses `docker-compose.yml` to manage the container, volumes, and configuration easily. It leverages pre-built images hosted on GHCR. 272 | 273 | 1. **Clone the repository:** (You only need the `docker-compose.yml` and `env.example.txt` files from it) 274 | ```bash 275 | git clone https://github.com/devnen/dia-tts-server.git 276 | cd dia-tts-server 277 | ``` 278 | 279 | 2. **(Optional) Initial Configuration via `.env`:** 280 | * If this is your very first time running the container and you want to override the default settings *before* `config.yaml` is created inside the container, copy the example environment file: 281 | ```bash 282 | cp env.example.txt .env 283 | ``` 284 | * Edit the `.env` file with your desired initial settings (e.g., `PORT`, model filenames). 285 | * **Note:** This `.env` file is *only* used to seed the *initial* `config.yaml` on the very first container start if `/app/config.yaml` doesn't already exist inside the container's volume (which it won't initially). Subsequent configuration changes should be made via the UI or by editing `config.yaml` directly (see Configuration Note below). 286 | 287 | 3. **Review `docker-compose.yml`:** 288 | * The repository includes a `docker-compose.yml` file configured to use the pre-built image and recommended settings. Ensure it looks similar to this: 289 | 290 | ```yaml 291 | # docker-compose.yml 292 | version: '3.8' 293 | 294 | services: 295 | dia-tts-server: 296 | # Use the pre-built image from GitHub Container Registry 297 | image: ghcr.io/devnen/dia-tts-server:latest 298 | # Alternatively, to build locally (e.g., for development): 299 | # build: 300 | # context: . 
301 | # dockerfile: Dockerfile 302 | ports: 303 | # Map host port (default 8003) to container port 8003 304 | # You can change the host port via .env (e.g., PORT=8004) 305 | - "${PORT:-8003}:8003" 306 | volumes: 307 | # Mount local directories into the container for persistent data 308 | - ./model_cache:/app/model_cache 309 | - ./reference_audio:/app/reference_audio 310 | - ./outputs:/app/outputs 311 | - ./voices:/app/voices 312 | # DO NOT mount config.yaml - let the app create it inside 313 | 314 | # --- GPU Access --- 315 | # Modern method (Recommended for newer Docker/NVIDIA setups) 316 | devices: 317 | - nvidia.com/gpu=all 318 | device_cgroup_rules: 319 | - "c 195:* rmw" # Needed for some NVIDIA container toolkit versions 320 | - "c 236:* rmw" # Needed for some NVIDIA container toolkit versions 321 | 322 | # Legacy method (Alternative for older Docker/NVIDIA setups) 323 | # If the 'devices' block above doesn't work, comment it out and uncomment 324 | # the 'deploy' block below. Do not use both simultaneously. 325 | # deploy: 326 | # resources: 327 | # reservations: 328 | # devices: 329 | # - driver: nvidia 330 | # count: 1 # Or specify specific GPUs e.g., "device=0,1" 331 | # capabilities: [gpu] 332 | # --- End GPU Access --- 333 | 334 | restart: unless-stopped 335 | env_file: 336 | # Load environment variables from .env file for initial config seeding 337 | - .env 338 | environment: 339 | # Enable faster Hugging Face downloads inside the container 340 | - HF_HUB_ENABLE_HF_TRANSFER=1 341 | # Pass GPU capabilities (may be needed for legacy method if uncommented) 342 | - NVIDIA_VISIBLE_DEVICES=all 343 | - NVIDIA_DRIVER_CAPABILITIES=compute,utility 344 | 345 | # Optional: Define named volumes if you prefer them over host mounts 346 | # volumes: 347 | # model_cache: 348 | # reference_audio: 349 | # outputs: 350 | # voices: 351 | ``` 352 | 353 | 4. **Start the container:** 354 | ```bash 355 | docker compose up -d 356 | ``` 357 | * This command will: 358 | * Pull the latest `ghcr.io/devnen/dia-tts-server:latest` image. 359 | * Create the local directories (`model_cache`, `reference_audio`, `outputs`, `voices`) if they don't exist. 360 | * Start the container in detached mode (`-d`). 361 | * The first time you run this, it will download the TTS models into `./model_cache`, which may take some time depending on your internet speed. 362 | 363 | 5. **Access the UI:** 364 | Open your web browser to `http://localhost:8003` (or the host port you configured in `.env`). 365 | 366 | 6. **View logs:** 367 | ```bash 368 | docker compose logs -f 369 | ``` 370 | 371 | 7. **Stop the container:** 372 | ```bash 373 | docker compose down 374 | ``` 375 | 376 | ### Option 2: Using `docker run` (Alternative) 377 | 378 | This method runs the container directly without Docker Compose, requiring manual specification of ports, volumes, and GPU flags. 379 | 380 | ```bash 381 | # Ensure local directories exist first: 382 | # mkdir -p model_cache reference_audio outputs voices 383 | 384 | docker run -d \ 385 | --name dia-tts-server \ 386 | -p 8003:8003 \ 387 | -v ./model_cache:/app/model_cache \ 388 | -v ./reference_audio:/app/reference_audio \ 389 | -v ./outputs:/app/outputs \ 390 | -v ./voices:/app/voices \ 391 | --env HF_HUB_ENABLE_HF_TRANSFER=1 \ 392 | --gpus all \ 393 | ghcr.io/devnen/dia-tts-server:latest 394 | ``` 395 | 396 | * Replace `8003:8003` with `:8003` if needed. 397 | * `--gpus all` enables GPU access; consult NVIDIA Container Toolkit documentation for alternatives if needed. 
398 | * Initial configuration relies on model defaults unless you pass environment variables using multiple `-e VAR=VALUE` flags (more complex than using `.env` with Compose). 399 | 400 | ### Configuration Note 401 | 402 | * The server uses `config.yaml` inside the container (`/app/config.yaml`) for its settings. 403 | * On the *very first start*, if `/app/config.yaml` doesn't exist, the server creates it using defaults from the code, potentially overridden by variables in the `.env` file (if using Docker Compose and `.env` exists). 404 | * **After the first start,** changes should be made by: 405 | * Using the Web UI's settings page (if available). 406 | * Editing the `config.yaml` file *inside* the container (e.g., `docker compose exec dia-tts-server nano /app/config.yaml`). Changes require a container restart (`docker compose restart dia-tts-server`) to take effect for server/model/path settings. UI state changes are saved live. 407 | 408 | ### Performance Optimizations 409 | 410 | * **Faster Model Downloads**: `hf-transfer` is enabled by default in the provided `docker-compose.yml` and image, significantly speeding up initial model downloads from Hugging Face. 411 | * **GPU Acceleration**: The `docker-compose.yml` and `docker run` examples include flags (`devices` or `--gpus`) to enable NVIDIA GPU acceleration if available. The Docker image uses a CUDA runtime base for efficiency. 412 | 413 | ### Docker Volumes 414 | 415 | Persistent data is stored on your host machine via volume mounts: 416 | 417 | * `./model_cache:/app/model_cache` (Downloaded TTS and Whisper models) 418 | * `./reference_audio:/app/reference_audio` (Your uploaded reference audio files for cloning) 419 | * `./outputs:/app/outputs` (Generated audio files) 420 | * `./voices:/app/voices` (Predefined voice audio files) 421 | 422 | ### Available Images 423 | 424 | * **GitHub Container Registry**: `ghcr.io/devnen/dia-tts-server:latest` (Automatically built from the `main` branch) 425 | 426 | --- 427 | 428 | ## 💡 Usage 429 | 430 | ### Web UI (`http://localhost:PORT`) 431 | 432 | The most intuitive way to use the server: 433 | 434 | * **Text Input:** Enter your script. Use `[S1]`/`[S2]` for dialogue and non-verbals like `(laughs)`. Content is saved automatically. 435 | * **Generate Button & Chunking:** Click "Generate Speech". Below the text box: 436 | * **Split text into chunks:** Toggle checkbox (enabled by default). Enables splitting for long text (> ~2x chunk size). 437 | * **Chunk Size:** Adjust the slider (visible when splitting is possible) for approximate chunk character length (default 120). 438 | * **Voice Mode:** Choose: 439 | * `Predefined Voices`: Select a curated, ready-to-use synthetic voice from the `./voices` directory. 440 | * `Voice Cloning`: Select an uploaded reference file from `./reference_audio`. Requires a corresponding `.txt` transcript (recommended) or relies on experimental Whisper fallback. Backend handles transcript automatically. 441 | * `Random Single / Dialogue`: Uses `[S1]`/`[S2]` tags or generates a random voice if no tags. Use a fixed Seed for consistency. 442 | * **Presets:** Click buttons (loaded from `ui/presets.yaml`) to populate text and parameters. Customize by editing the YAML file. 443 | * **Reference Audio (Clone Mode):** Select an existing `.wav`/`.mp3` or click "Import" to upload new files to `./reference_audio`. 444 | * **Generation Parameters:** Adjust sliders/inputs for Speed, CFG, Temperature, Top P, Top K, and **Seed**. Settings are saved automatically. 
Click "Save Generation Parameters" to update the defaults in `config.yaml`. Use -1 seed for random, integer for specific results. 445 | * **Server Configuration:** View/edit `config.yaml` settings (requires server restart for some changes). 446 | * **Loading Overlay:** Appears during generation, showing chunk progress if applicable. 447 | * **Audio Player:** Appears on success with waveform, playback controls, download link, and generation info. 448 | * **Theme Toggle:** Switch between light/dark modes. 449 | 450 | ### API Endpoints (`/docs` for details) 451 | 452 | * **`/v1/audio/speech` (POST):** OpenAI-compatible. 453 | * `input`: Text. 454 | * `voice`: 'S1', 'S2', 'dialogue', 'predefined_voice_filename.wav', or 'reference_filename.wav'. 455 | * `response_format`: 'opus' or 'wav'. 456 | * `speed`: Playback speed factor (0.5-2.0). 457 | * `seed`: (Optional) Integer seed, -1 for random. 458 | * **`/tts` (POST):** Custom endpoint with full control. 459 | * `text`: Target text. 460 | * `voice_mode`: 'dialogue', 'single_s1', 'single_s2', 'clone', 'predefined'. 461 | * `clone_reference_filename`: Filename in `./reference_audio` (for clone) or `./voices` (for predefined). 462 | * `transcript`: (Optional, Clone Mode Only) Explicit transcript text to override file/Whisper lookup. 463 | * `output_format`: 'opus' or 'wav'. 464 | * `max_tokens`: (Optional) Max tokens *per chunk*. 465 | * `cfg_scale`, `temperature`, `top_p`, `cfg_filter_top_k`: Generation parameters. 466 | * `speed_factor`: Playback speed factor (0.5-2.0). 467 | * `seed`: (Optional) Integer seed, -1 for random. 468 | * `split_text`: (Optional) Boolean, enable/disable chunking (default: True). 469 | * `chunk_size`: (Optional) Integer, target chunk size (default: 120). 470 | 471 | ## 🔍 Troubleshooting 472 | 473 | * **CUDA Not Available / Slow:** Check NVIDIA drivers (`nvidia-smi`), ensure correct CUDA-enabled PyTorch is installed (Installation Step 4). 474 | * **VRAM Out of Memory (OOM):** 475 | * Ensure you are using the BF16 model (`dia-v0_1_bf16.safetensors` in `config.yaml`) if VRAM is limited (~7GB needed). 476 | * Close other GPU-intensive applications. VRAM optimizations and leak fixes have significantly reduced requirements. 477 | * If processing very long text even with chunking, try reducing `chunk_size` (e.g., 100). 478 | * **CUDA Out of Memory (OOM) During Startup:** This can happen due to temporary overhead. The server loads weights to CPU first to mitigate this. If it persists, check VRAM usage (`nvidia-smi`), ensure BF16 model is used, or try setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` environment variable before starting. 479 | * **Import Errors (`dac`, `tqdm`, `yaml`, `whisper`, `parselmouth`):** Activate venv, run `pip install -r requirements.txt`. Ensure `descript-audio-codec` installed correctly. 480 | * **`libsndfile` / `ffmpeg` Error (Linux):** Run `sudo apt install libsndfile1 ffmpeg`. 481 | * **Model Download Fails (Dia or Whisper):** Check internet, `config.yaml` settings (`model.repo_id`, `model.weights_filename`, `model.whisper_model_name`), Hugging Face status, cache path permissions (`paths.model_cache`). 482 | * **Voice Cloning Fails / Poor Quality:** 483 | * **Ensure accurate `.txt` transcript exists** alongside the reference audio in `./reference_audio`. Format: `[S1] text...` or `[S1] text... [S2] text...`. This is the most reliable method. 484 | * Whisper fallback is experimental and may be inaccurate. 485 | * Use clean, clear reference audio (5-20s). 
486 | * Check server logs for specific errors during `_prepare_cloning_inputs`. 487 | * **Permission Errors (Saving Files/Config):** Check write permissions for `paths.output`, `paths.reference_audio`, `paths.voices`, `paths.model_cache` (for Whisper transcript saves), and `config.yaml`. 488 | * **UI Issues / Settings Not Saving:** Clear browser cache/local storage. Check developer console (F12) for JS errors. Ensure `config.yaml` is writable by the server process. 489 | * **Inconsistent Voice with Chunking:** Use "Predefined Voices" or "Voice Cloning" mode. If using "Random/Dialogue" mode with splitting, use a fixed integer `seed` (not -1) for consistency across chunks. The UI provides a warning otherwise. 490 | * **Port Conflict (`Address already in use` / `Errno 98`):** Another process is using the port (default 8003). Stop the other process or change the `server.port` in `config.yaml` (requires restart). 491 | * **Explanation:** This usually happens if a previous server instance didn't shut down cleanly or another application is bound to the same port. 492 | * **Linux:** Find/kill process: `sudo lsof -i:PORT | grep LISTEN | awk '{print $2}' | xargs kill -9` (Replace PORT, e.g., 8003). 493 | * **Windows:** Find/kill process: `for /f "tokens=5" %i in ('netstat -ano ^| findstr :PORT') do taskkill /F /PID %i` (Replace PORT, e.g., 8003). Use with caution. 494 | * **Generation Cancel Button:** This is a "UI Cancel" - it stops the *frontend* from waiting but doesn't instantly halt ongoing backend model inference. Clicking Generate again cancels the previous UI wait. 495 | 496 | ### Selecting GPUs on Multi-GPU Systems 497 | 498 | Set the `CUDA_VISIBLE_DEVICES` environment variable **before** running `python server.py` to specify which GPU(s) PyTorch should see. The server uses the first visible one (`cuda:0`). 499 | 500 | * **Example (Use only physical GPU 1):** 501 | * Linux/macOS: `CUDA_VISIBLE_DEVICES="1" python server.py` 502 | * Windows CMD: `set CUDA_VISIBLE_DEVICES=1 && python server.py` 503 | * Windows PowerShell: `$env:CUDA_VISIBLE_DEVICES="1"; python server.py` 504 | 505 | * **Example (Use physical GPUs 6 and 7 - server uses GPU 6):** 506 | * Linux/macOS: `CUDA_VISIBLE_DEVICES="6,7" python server.py` 507 | * Windows CMD: `set CUDA_VISIBLE_DEVICES=6,7 && python server.py` 508 | * Windows PowerShell: `$env:CUDA_VISIBLE_DEVICES="6,7"; python server.py` 509 | 510 | **Note:** `CUDA_VISIBLE_DEVICES` selects GPUs; it does **not** fix OOM errors if the chosen GPU lacks sufficient memory. 511 | 512 | ## 🤝 Contributing 513 | 514 | Contributions are welcome! Please feel free to open an issue to report bugs or suggest features, or submit a Pull Request for improvements. 515 | 516 | ## 📜 License 517 | 518 | This project is licensed under the **MIT License**. 519 | 520 | You can find it here: [https://opensource.org/licenses/MIT](https://opensource.org/licenses/MIT) 521 | 522 | ## 🙏 Acknowledgements 523 | 524 | * **Core Model:** This project heavily relies on the excellent **[Dia TTS model](https://github.com/nari-labs/dia)** developed by **[Nari Labs](https://github.com/nari-labs)**. Their work in creating and open-sourcing the model is greatly appreciated. 525 | * **UI Inspiration:** Special thanks to **[Lex-au](https://github.com/Lex-au)** whose **[Orpheus-FastAPI](https://github.com/Lex-au/Orpheus-FastAPI)** project served as inspiration for the web interface design of this project. 
526 | * **SafeTensors Conversion:** Thank you to user **[ttj on Hugging Face](https://huggingface.co/ttj)** for providing the converted **[SafeTensors weights](https://huggingface.co/ttj/dia-1.6b-safetensors)** used as the default in this server. 527 | * **Containerization Technologies:** [Docker](https://www.docker.com/) and [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker) for enabling consistent deployment environments. 528 | * **Core Libraries:** 529 | * [FastAPI](https://fastapi.tiangolo.com/) 530 | * [Uvicorn](https://www.uvicorn.org/) 531 | * [PyTorch](https://pytorch.org/) 532 | * [Hugging Face Hub](https://huggingface.co/docs/huggingface_hub/index) & [SafeTensors](https://github.com/huggingface/safetensors) 533 | * [Descript Audio Codec (DAC)](https://github.com/descriptinc/descript-audio-codec) 534 | * [SoundFile](https://python-soundfile.readthedocs.io/) & [libsndfile](http://www.mega-nerd.com/libsndfile/) 535 | * [Jinja2](https://jinja.palletsprojects.com/) 536 | * [WaveSurfer.js](https://wavesurfer.xyz/) 537 | * [Tailwind CSS](https://tailwindcss.com/) (via CDN) 538 | 539 | --- 540 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | # Configuration management for Dia TTS server using YAML file. 3 | 4 | import os 5 | import logging 6 | import yaml 7 | import shutil 8 | from copy import deepcopy 9 | from threading import Lock 10 | from typing import Dict, Any, Optional, List, Tuple 11 | 12 | # Use dotenv for initial seeding ONLY if config.yaml is missing 13 | from dotenv import load_dotenv, find_dotenv 14 | 15 | # Configure logging 16 | logger = logging.getLogger(__name__) 17 | 18 | # --- Constants --- 19 | CONFIG_FILE_PATH = "config.yaml" 20 | ENV_FILE_PATH = find_dotenv() # Find .env file location 21 | 22 | # --- Default Configuration Structure --- 23 | # This defines the expected structure and default values for config.yaml 24 | # It's crucial to keep this up-to-date with all expected keys. 25 | DEFAULT_CONFIG: Dict[str, Any] = { 26 | "server": { 27 | "host": "0.0.0.0", 28 | "port": 8003, 29 | }, 30 | "model": { 31 | "repo_id": "ttj/dia-1.6b-safetensors", 32 | "config_filename": "config.json", 33 | "weights_filename": "dia-v0_1_bf16.safetensors", 34 | "whisper_model_name": "small.en", # Added Whisper model here 35 | }, 36 | "paths": { 37 | "model_cache": "./model_cache", 38 | "reference_audio": "./reference_audio", 39 | "output": "./outputs", 40 | "voices": "./voices", # Added predefined voices path 41 | }, 42 | "generation_defaults": { 43 | "speed_factor": 1.0, # Changed default to 1.0 44 | "cfg_scale": 3.0, 45 | "temperature": 1.3, 46 | "top_p": 0.95, 47 | "cfg_filter_top_k": 35, 48 | "seed": 42, 49 | "split_text": True, # Default split enabled 50 | "chunk_size": 120, # Default chunk size 51 | }, 52 | "ui_state": { 53 | "last_text": "", 54 | "last_voice_mode": "predefined", # Default to predefined 55 | "last_predefined_voice": None, # Store original filename (e.g., "Michael_Emily.wav") 56 | "last_reference_file": None, 57 | "last_seed": 42, 58 | "last_chunk_size": 120, 59 | "last_split_text_enabled": True, 60 | "hide_chunk_warning": False, 61 | "hide_generation_warning": False, 62 | }, 63 | } 64 | 65 | # Mapping from .env variable names to config.yaml nested keys 66 | # Used ONLY during initial seeding if config.yaml is missing. 
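# Example: if the .env file contains PORT=8004, the initial seeding pass in
# _load_env_overrides() below effectively calls
#   _set_nested_value(config, ["server", "port"], int("8004"))
# so the newly created config.yaml starts out with server.port == 8004.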
67 | ENV_TO_YAML_MAP: Dict[str, Tuple[List[str], type]] = { 68 | # Server 69 | "HOST": (["server", "host"], str), 70 | "PORT": (["server", "port"], int), 71 | # Model 72 | "DIA_MODEL_REPO_ID": (["model", "repo_id"], str), 73 | "DIA_MODEL_CONFIG_FILENAME": (["model", "config_filename"], str), 74 | "DIA_MODEL_WEIGHTS_FILENAME": (["model", "weights_filename"], str), 75 | "WHISPER_MODEL_NAME": (["model", "whisper_model_name"], str), 76 | # Paths 77 | "DIA_MODEL_CACHE_PATH": (["paths", "model_cache"], str), 78 | "REFERENCE_AUDIO_PATH": (["paths", "reference_audio"], str), 79 | "OUTPUT_PATH": (["paths", "output"], str), 80 | # Generation Defaults 81 | "GEN_DEFAULT_SPEED_FACTOR": (["generation_defaults", "speed_factor"], float), 82 | "GEN_DEFAULT_CFG_SCALE": (["generation_defaults", "cfg_scale"], float), 83 | "GEN_DEFAULT_TEMPERATURE": (["generation_defaults", "temperature"], float), 84 | "GEN_DEFAULT_TOP_P": (["generation_defaults", "top_p"], float), 85 | "GEN_DEFAULT_CFG_FILTER_TOP_K": (["generation_defaults", "cfg_filter_top_k"], int), 86 | "GEN_DEFAULT_SEED": (["generation_defaults", "seed"], int), 87 | # Note: split_text and chunk_size defaults are not typically in .env 88 | # Note: ui_state is never loaded from .env 89 | } 90 | 91 | 92 | def _deep_merge_dicts(source: Dict, destination: Dict) -> Dict: 93 | """ 94 | Recursively merges source dict into destination dict. 95 | Modifies destination in place. 96 | """ 97 | for key, value in source.items(): 98 | if isinstance(value, dict): 99 | # get node or create one 100 | node = destination.setdefault(key, {}) 101 | _deep_merge_dicts(value, node) 102 | else: 103 | destination[key] = value 104 | return destination 105 | 106 | 107 | def _set_nested_value(d: Dict, keys: List[str], value: Any): 108 | """Sets a value in a nested dictionary using a list of keys.""" 109 | for key in keys[:-1]: 110 | d = d.setdefault(key, {}) 111 | d[keys[-1]] = value 112 | 113 | 114 | def _get_nested_value(d: Dict, keys: List[str], default: Any = None) -> Any: 115 | """Gets a value from a nested dictionary using a list of keys.""" 116 | for key in keys: 117 | if isinstance(d, dict) and key in d: 118 | d = d[key] 119 | else: 120 | return default 121 | return d 122 | 123 | 124 | class YamlConfigManager: 125 | """Manages configuration for the TTS server using a YAML file.""" 126 | 127 | def __init__(self): 128 | """Initialize the configuration manager by loading the config.""" 129 | self.config: Dict[str, Any] = {} 130 | self._lock = Lock() # Lock for thread-safe file writing 131 | self.load_config() 132 | 133 | def _load_defaults(self) -> Dict[str, Any]: 134 | """Returns a deep copy of the hardcoded default configuration.""" 135 | return deepcopy(DEFAULT_CONFIG) 136 | 137 | def _load_env_overrides(self, config_dict: Dict[str, Any]) -> Dict[str, Any]: 138 | """ 139 | Loads .env file (if found) and overrides values in the provided config_dict. 140 | Used ONLY during initial seeding or reset. 
141 | """ 142 | if not ENV_FILE_PATH: 143 | logger.info("No .env file found, skipping environment variable overrides.") 144 | return config_dict 145 | 146 | logger.info(f"Loading environment variables from: {ENV_FILE_PATH}") 147 | # Load .env variables into os.environ temporarily 148 | load_dotenv(dotenv_path=ENV_FILE_PATH, override=True) 149 | 150 | env_values_applied = 0 151 | for env_var, (yaml_path, target_type) in ENV_TO_YAML_MAP.items(): 152 | env_value_str = os.environ.get(env_var) 153 | if env_value_str is not None: 154 | try: 155 | # Attempt type conversion 156 | if target_type is bool: 157 | converted_value = env_value_str.lower() in ( 158 | "true", 159 | "1", 160 | "t", 161 | "yes", 162 | "y", 163 | ) 164 | else: 165 | converted_value = target_type(env_value_str) 166 | 167 | _set_nested_value(config_dict, yaml_path, converted_value) 168 | logger.debug( 169 | f"Applied .env override: {'->'.join(yaml_path)} = {converted_value}" 170 | ) 171 | env_values_applied += 1 172 | except (ValueError, TypeError) as e: 173 | logger.warning( 174 | f"Could not apply .env override for '{env_var}'. Invalid value '{env_value_str}' for type {target_type.__name__}. Using default. Error: {e}" 175 | ) 176 | # Clean up the loaded env var from os.environ if desired, though it's usually harmless 177 | # if env_var in os.environ: 178 | # del os.environ[env_var] 179 | 180 | if env_values_applied > 0: 181 | logger.info(f"Applied {env_values_applied} overrides from .env file.") 182 | else: 183 | logger.info("No applicable overrides found in .env file.") 184 | 185 | return config_dict 186 | 187 | def load_config(self): 188 | """ 189 | Loads configuration from config.yaml. 190 | If config.yaml doesn't exist, creates it by seeding from .env (if found) 191 | and hardcoded defaults. 192 | Ensures all default keys are present in the loaded config. 193 | """ 194 | with self._lock: # Ensure loading process is atomic if called concurrently (unlikely at startup) 195 | loaded_config = self._load_defaults() # Start with defaults 196 | 197 | if os.path.exists(CONFIG_FILE_PATH): 198 | logger.info(f"Loading configuration from: {CONFIG_FILE_PATH}") 199 | try: 200 | with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as f: 201 | yaml_data = yaml.safe_load(f) 202 | if isinstance(yaml_data, dict): 203 | # Merge loaded data onto defaults to ensure all keys exist 204 | # and to add new default keys if the file is old. 205 | loaded_config = _deep_merge_dicts(yaml_data, loaded_config) 206 | logger.info("Successfully loaded and merged config.yaml.") 207 | else: 208 | logger.error( 209 | f"Invalid format in {CONFIG_FILE_PATH}. Expected a dictionary (key-value pairs). Using defaults and attempting to overwrite." 210 | ) 211 | # Proceed using defaults, will attempt to save later 212 | if not self._save_config_yaml_internal(loaded_config): 213 | logger.error( 214 | f"Failed to overwrite invalid {CONFIG_FILE_PATH} with defaults." 215 | ) 216 | 217 | except yaml.YAMLError as e: 218 | logger.error( 219 | f"Error parsing {CONFIG_FILE_PATH}: {e}. Using defaults and attempting to overwrite." 220 | ) 221 | if not self._save_config_yaml_internal(loaded_config): 222 | logger.error( 223 | f"Failed to overwrite corrupted {CONFIG_FILE_PATH} with defaults." 224 | ) 225 | except Exception as e: 226 | logger.error( 227 | f"Unexpected error loading {CONFIG_FILE_PATH}: {e}. Using defaults.", 228 | exc_info=True, 229 | ) 230 | # Don't try to save if it was an unexpected error 231 | 232 | else: 233 | logger.info( 234 | f"{CONFIG_FILE_PATH} not found. 
Creating initial configuration..." 235 | ) 236 | # Seed from .env overrides onto the defaults 237 | loaded_config = self._load_env_overrides(loaded_config) 238 | # Save the newly created config 239 | if self._save_config_yaml_internal(loaded_config): 240 | logger.info( 241 | f"Successfully created and saved initial configuration to {CONFIG_FILE_PATH}." 242 | ) 243 | else: 244 | logger.error( 245 | f"Failed to save initial configuration to {CONFIG_FILE_PATH}. Using in-memory defaults." 246 | ) 247 | 248 | self.config = loaded_config 249 | logger.debug(f"Current config loaded: {self.config}") 250 | return self.config 251 | 252 | def _save_config_yaml_internal(self, config_dict: Dict[str, Any]) -> bool: 253 | """ 254 | Internal method to save the configuration dictionary to config.yaml. 255 | Includes backup and restore mechanism. Assumes lock is already held. 256 | """ 257 | temp_file_path = CONFIG_FILE_PATH + ".tmp" 258 | backup_file_path = CONFIG_FILE_PATH + ".bak" 259 | 260 | try: 261 | # Write to temporary file first 262 | with open(temp_file_path, "w", encoding="utf-8") as f: 263 | yaml.dump( 264 | config_dict, f, default_flow_style=False, sort_keys=False, indent=2 265 | ) 266 | 267 | # Backup existing file if it exists 268 | if os.path.exists(CONFIG_FILE_PATH): 269 | try: 270 | shutil.move(CONFIG_FILE_PATH, backup_file_path) 271 | logger.debug(f"Backed up existing config to {backup_file_path}") 272 | except Exception as backup_err: 273 | logger.warning( 274 | f"Could not create backup of {CONFIG_FILE_PATH}: {backup_err}" 275 | ) 276 | # Decide whether to proceed without backup or fail 277 | # Proceeding for now, but log the warning. 278 | 279 | # Rename temporary file to actual config file 280 | shutil.move(temp_file_path, CONFIG_FILE_PATH) 281 | logger.info(f"Configuration successfully saved to {CONFIG_FILE_PATH}") 282 | return True 283 | 284 | except yaml.YAMLError as e: 285 | logger.error( 286 | f"Error formatting data for {CONFIG_FILE_PATH}: {e}", exc_info=True 287 | ) 288 | return False 289 | except Exception as e: 290 | logger.error( 291 | f"Failed to save configuration to {CONFIG_FILE_PATH}: {e}", 292 | exc_info=True, 293 | ) 294 | # Attempt to restore backup if save failed mid-way 295 | if os.path.exists(backup_file_path) and not os.path.exists( 296 | CONFIG_FILE_PATH 297 | ): 298 | try: 299 | shutil.move(backup_file_path, CONFIG_FILE_PATH) 300 | logger.info( 301 | f"Restored configuration from backup {backup_file_path}" 302 | ) 303 | except Exception as restore_err: 304 | logger.error( 305 | f"Failed to restore configuration from backup: {restore_err}" 306 | ) 307 | # Clean up temp file if it still exists 308 | if os.path.exists(temp_file_path): 309 | try: 310 | os.remove(temp_file_path) 311 | except Exception as remove_err: 312 | logger.warning( 313 | f"Could not remove temporary config file {temp_file_path}: {remove_err}" 314 | ) 315 | return False 316 | finally: 317 | # Clean up backup file if main file exists and save was successful (or if backup failed) 318 | if os.path.exists(CONFIG_FILE_PATH) and os.path.exists(backup_file_path): 319 | try: 320 | # Only remove backup if the main file write seems okay 321 | if os.path.getsize(CONFIG_FILE_PATH) > 0: # Basic check 322 | os.remove(backup_file_path) 323 | logger.debug(f"Removed backup file {backup_file_path}") 324 | except Exception as remove_bak_err: 325 | logger.warning( 326 | f"Could not remove backup config file {backup_file_path}: {remove_bak_err}" 327 | ) 328 | 329 | def save_config_yaml(self, config_dict: Dict[str, 
Any]) -> bool: 330 | """Public method to save the configuration dictionary with locking.""" 331 | with self._lock: 332 | return self._save_config_yaml_internal(config_dict) 333 | 334 | def get(self, key_path: str, default: Any = None) -> Any: 335 | """ 336 | Get a configuration value using a dot-separated key path. 337 | e.g., get('server.port', 8000) 338 | """ 339 | keys = key_path.split(".") 340 | value = _get_nested_value(self.config, keys, default) 341 | # Ensure we return a copy for mutable types like dicts/lists 342 | return deepcopy(value) if isinstance(value, (dict, list)) else value 343 | 344 | def get_all(self) -> Dict[str, Any]: 345 | """Get a deep copy of all current configuration values.""" 346 | with self._lock: # Ensure consistency while copying 347 | return deepcopy(self.config) 348 | 349 | def update_and_save(self, partial_update_dict: Dict[str, Any]) -> bool: 350 | """ 351 | Deep merges a partial update dictionary into the current config 352 | and saves the entire configuration back to the YAML file. 353 | """ 354 | if not isinstance(partial_update_dict, dict): 355 | logger.error("Invalid partial update data: must be a dictionary.") 356 | return False 357 | 358 | with self._lock: 359 | try: 360 | # Create a deep copy to avoid modifying self.config directly before successful save 361 | updated_config = deepcopy(self.config) 362 | # Merge the partial update into the copy 363 | _deep_merge_dicts(partial_update_dict, updated_config) 364 | 365 | # Save the fully merged configuration 366 | if self._save_config_yaml_internal(updated_config): 367 | # If save was successful, update the in-memory config 368 | self.config = updated_config 369 | logger.info("Configuration updated and saved successfully.") 370 | return True 371 | else: 372 | logger.error("Failed to save updated configuration.") 373 | return False 374 | except Exception as e: 375 | logger.error( 376 | f"Error during configuration update and save: {e}", exc_info=True 377 | ) 378 | return False 379 | 380 | def reset_and_save(self) -> bool: 381 | """ 382 | Resets the configuration to hardcoded defaults, potentially overridden 383 | by values in the .env file, and saves it to config.yaml. 384 | """ 385 | with self._lock: 386 | logger.warning( 387 | "Resetting configuration to defaults (with .env overrides)..." 388 | ) 389 | reset_config = self._load_defaults() 390 | reset_config = self._load_env_overrides( 391 | reset_config 392 | ) # Apply .env overrides to defaults 393 | 394 | if self._save_config_yaml_internal(reset_config): 395 | self.config = reset_config # Update in-memory config 396 | logger.info("Configuration successfully reset and saved.") 397 | return True 398 | else: 399 | logger.error("Failed to save reset configuration.") 400 | # Keep the old config in memory if save failed 401 | return False 402 | 403 | # --- Type-specific Getters with Error Handling --- 404 | def get_int(self, key_path: str, default: Optional[int] = None) -> int: 405 | """Get a configuration value as an integer.""" 406 | value = self.get(key_path) 407 | if value is None: 408 | if default is not None: 409 | logger.debug( 410 | f"Config '{key_path}' not found, using provided default: {default}" 411 | ) 412 | return default 413 | else: 414 | logger.error( 415 | f"Mandatory config '{key_path}' not found and no default. Returning 0." 416 | ) 417 | return 0 418 | try: 419 | return int(value) 420 | except (ValueError, TypeError): 421 | logger.warning( 422 | f"Invalid integer value '{value}' for '{key_path}'. 
Using default: {default}" 423 | ) 424 | if isinstance(default, int): 425 | return default 426 | else: 427 | logger.error( 428 | f"Cannot parse '{value}' as int for '{key_path}' and no valid default. Returning 0." 429 | ) 430 | return 0 431 | 432 | def get_float(self, key_path: str, default: Optional[float] = None) -> float: 433 | """Get a configuration value as a float.""" 434 | value = self.get(key_path) 435 | if value is None: 436 | if default is not None: 437 | logger.debug( 438 | f"Config '{key_path}' not found, using provided default: {default}" 439 | ) 440 | return default 441 | else: 442 | logger.error( 443 | f"Mandatory config '{key_path}' not found and no default. Returning 0.0." 444 | ) 445 | return 0.0 446 | try: 447 | return float(value) 448 | except (ValueError, TypeError): 449 | logger.warning( 450 | f"Invalid float value '{value}' for '{key_path}'. Using default: {default}" 451 | ) 452 | if isinstance(default, float): 453 | return default 454 | else: 455 | logger.error( 456 | f"Cannot parse '{value}' as float for '{key_path}' and no valid default. Returning 0.0." 457 | ) 458 | return 0.0 459 | 460 | def get_bool(self, key_path: str, default: Optional[bool] = None) -> bool: 461 | """Get a configuration value as a boolean.""" 462 | value = self.get(key_path) 463 | if value is None: 464 | if default is not None: 465 | logger.debug( 466 | f"Config '{key_path}' not found, using provided default: {default}" 467 | ) 468 | return default 469 | else: 470 | logger.error( 471 | f"Mandatory config '{key_path}' not found and no default. Returning False." 472 | ) 473 | return False 474 | if isinstance(value, bool): 475 | return value 476 | if isinstance(value, str): 477 | return value.lower() in ("true", "1", "t", "yes", "y") 478 | try: 479 | # Handle numeric representations (e.g., 1 for True, 0 for False) 480 | return bool(int(value)) 481 | except (ValueError, TypeError): 482 | logger.warning( 483 | f"Invalid boolean value '{value}' for '{key_path}'. Using default: {default}" 484 | ) 485 | if isinstance(default, bool): 486 | return default 487 | else: 488 | logger.error( 489 | f"Cannot parse '{value}' as bool for '{key_path}' and no valid default. Returning False." 
490 | ) 491 | return False 492 | 493 | 494 | # --- Create a singleton instance for global access --- 495 | config_manager = YamlConfigManager() 496 | 497 | # --- Export common getters for easy access --- 498 | 499 | 500 | # Helper to get default value from the DEFAULT_CONFIG structure 501 | def _get_default(key_path: str) -> Any: 502 | keys = key_path.split(".") 503 | return _get_nested_value(DEFAULT_CONFIG, keys) 504 | 505 | 506 | # Server Settings 507 | def get_host() -> str: 508 | return config_manager.get("server.host", _get_default("server.host")) 509 | 510 | 511 | def get_port() -> int: 512 | return config_manager.get_int("server.port", _get_default("server.port")) 513 | 514 | 515 | # Model Source Settings 516 | def get_model_repo_id() -> str: 517 | return config_manager.get("model.repo_id", _get_default("model.repo_id")) 518 | 519 | 520 | def get_model_config_filename() -> str: 521 | return config_manager.get( 522 | "model.config_filename", _get_default("model.config_filename") 523 | ) 524 | 525 | 526 | def get_model_weights_filename() -> str: 527 | return config_manager.get( 528 | "model.weights_filename", _get_default("model.weights_filename") 529 | ) 530 | 531 | 532 | def get_whisper_model_name() -> str: 533 | return config_manager.get( 534 | "model.whisper_model_name", _get_default("model.whisper_model_name") 535 | ) 536 | 537 | 538 | # Path Settings 539 | def get_model_cache_path() -> str: 540 | return os.path.abspath( 541 | config_manager.get("paths.model_cache", _get_default("paths.model_cache")) 542 | ) 543 | 544 | 545 | def get_reference_audio_path() -> str: 546 | return os.path.abspath( 547 | config_manager.get( 548 | "paths.reference_audio", _get_default("paths.reference_audio") 549 | ) 550 | ) 551 | 552 | 553 | def get_output_path() -> str: 554 | return os.path.abspath( 555 | config_manager.get("paths.output", _get_default("paths.output")) 556 | ) 557 | 558 | 559 | def get_predefined_voices_path() -> str: 560 | return os.path.abspath( 561 | config_manager.get("paths.voices", _get_default("paths.voices")) 562 | ) 563 | 564 | 565 | # Default Generation Parameter Getters 566 | def get_gen_default_speed_factor() -> float: 567 | return config_manager.get_float( 568 | "generation_defaults.speed_factor", 569 | _get_default("generation_defaults.speed_factor"), 570 | ) 571 | 572 | 573 | def get_gen_default_cfg_scale() -> float: 574 | return config_manager.get_float( 575 | "generation_defaults.cfg_scale", _get_default("generation_defaults.cfg_scale") 576 | ) 577 | 578 | 579 | def get_gen_default_temperature() -> float: 580 | return config_manager.get_float( 581 | "generation_defaults.temperature", 582 | _get_default("generation_defaults.temperature"), 583 | ) 584 | 585 | 586 | def get_gen_default_top_p() -> float: 587 | return config_manager.get_float( 588 | "generation_defaults.top_p", _get_default("generation_defaults.top_p") 589 | ) 590 | 591 | 592 | def get_gen_default_cfg_filter_top_k() -> int: 593 | return config_manager.get_int( 594 | "generation_defaults.cfg_filter_top_k", 595 | _get_default("generation_defaults.cfg_filter_top_k"), 596 | ) 597 | 598 | 599 | def get_gen_default_seed() -> int: 600 | return config_manager.get_int( 601 | "generation_defaults.seed", _get_default("generation_defaults.seed") 602 | ) 603 | 604 | 605 | def get_gen_default_split_text() -> bool: 606 | return config_manager.get_bool( 607 | "generation_defaults.split_text", _get_default("generation_defaults.split_text") 608 | ) 609 | 610 | 611 | def get_gen_default_chunk_size() -> int: 612 | return 
config_manager.get_int( 613 | "generation_defaults.chunk_size", _get_default("generation_defaults.chunk_size") 614 | ) 615 | 616 | 617 | # UI State Getters (might be less frequently needed directly in backend) 618 | def get_ui_state() -> Dict[str, Any]: 619 | """Gets the entire UI state dictionary.""" 620 | return config_manager.get("ui_state", _get_default("ui_state")) 621 | 622 | 623 | def get_hide_chunk_warning() -> bool: 624 | """Gets the flag for hiding the chunk warning dialog.""" 625 | return config_manager.get_bool( 626 | "ui_state.hide_chunk_warning", _get_default("ui_state.hide_chunk_warning") 627 | ) 628 | 629 | 630 | def get_hide_generation_warning() -> bool: 631 | """Gets the flag for hiding the general generation warning dialog.""" 632 | return config_manager.get_bool( 633 | "ui_state.hide_generation_warning", 634 | _get_default("ui_state.hide_generation_warning"), 635 | ) 636 | -------------------------------------------------------------------------------- /dia/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/dia/__init__.py -------------------------------------------------------------------------------- /dia/audio.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import torch 4 | 5 | 6 | def build_delay_indices( 7 | B: int, T: int, C: int, delay_pattern: tp.List[int] 8 | ) -> tp.Tuple[torch.Tensor, torch.Tensor]: 9 | """ 10 | Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c]. 11 | Negative t_idx => BOS; t_idx >= T => PAD. 12 | """ 13 | delay_arr = torch.tensor(delay_pattern, dtype=torch.int32) 14 | 15 | t_idx_BxT = torch.broadcast_to( 16 | torch.arange(T, dtype=torch.int32)[None, :], 17 | [B, T], 18 | ) 19 | t_idx_BxTx1 = t_idx_BxT[..., None] 20 | t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C) 21 | 22 | b_idx_BxTxC = torch.broadcast_to( 23 | torch.arange(B, dtype=torch.int32).view(B, 1, 1), 24 | [B, T, C], 25 | ) 26 | c_idx_BxTxC = torch.broadcast_to( 27 | torch.arange(C, dtype=torch.int32).view(1, 1, C), 28 | [B, T, C], 29 | ) 30 | 31 | # We must clamp time indices to [0..T-1] so gather_nd equivalent won't fail 32 | t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1) 33 | 34 | indices_BTCx3 = torch.stack( 35 | [ 36 | b_idx_BxTxC.reshape(-1), 37 | t_clamped_BxTxC.reshape(-1), 38 | c_idx_BxTxC.reshape(-1), 39 | ], 40 | dim=1, 41 | ).long() # Ensure indices are long type for indexing 42 | 43 | return t_idx_BxTxC, indices_BTCx3 44 | 45 | 46 | def apply_audio_delay( 47 | audio_BxTxC: torch.Tensor, 48 | pad_value: int, 49 | bos_value: int, 50 | precomp: tp.Tuple[torch.Tensor, torch.Tensor], 51 | ) -> torch.Tensor: 52 | """ 53 | Applies the delay pattern to batched audio tokens using precomputed indices, 54 | inserting BOS where t_idx < 0 and PAD where t_idx >= T. 
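    Example (illustrative values, not from the original source): with
    delay_pattern=[0, 1] and T=3, channel 0 passes through unchanged while
    channel 1 is shifted right by one step, so its first frame becomes BOS:

        in  (T x C): [[a0, b0], [a1, b1], [a2, b2]]
        out (T x C): [[a0, BOS], [a1, b0], [a2, b1]]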
55 | 56 | Args: 57 | audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float) 58 | pad_value: the padding token 59 | bos_value: the BOS token 60 | precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices 61 | 62 | Returns: 63 | result_BxTxC: [B, T, C] delayed audio tokens 64 | """ 65 | device = audio_BxTxC.device # Get device from input tensor 66 | t_idx_BxTxC, indices_BTCx3 = precomp 67 | t_idx_BxTxC = t_idx_BxTxC.to(device) # Move precomputed indices to device 68 | indices_BTCx3 = indices_BTCx3.to(device) 69 | 70 | # Equivalent of tf.gather_nd using advanced indexing 71 | # Ensure indices are long type if not already (build_delay_indices should handle this) 72 | gathered_flat = audio_BxTxC[ 73 | indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2] 74 | ] 75 | gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape) 76 | 77 | # Create masks on the correct device 78 | mask_bos = t_idx_BxTxC < 0 # => place bos_value 79 | mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1] # => place pad_value 80 | 81 | # Create scalar tensors on the correct device 82 | bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device) 83 | pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device) 84 | 85 | # If mask_bos, BOS; else if mask_pad, PAD; else original gather 86 | # All tensors should now be on the same device 87 | result_BxTxC = torch.where( 88 | mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC) 89 | ) 90 | 91 | return result_BxTxC 92 | 93 | 94 | def build_revert_indices( 95 | B: int, T: int, C: int, delay_pattern: tp.List[int] 96 | ) -> tp.Tuple[torch.Tensor, torch.Tensor]: 97 | """ 98 | Precompute indices for the revert operation using PyTorch. 99 | 100 | Returns: 101 | A tuple (t_idx_BxTxC, indices_BTCx3) where: 102 | - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay. 103 | - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from: 104 | batch indices, clamped time indices, and channel indices. 105 | """ 106 | # Use default device unless specified otherwise; assumes inputs might define device later 107 | device = None # Or determine dynamically if needed, e.g., from a model parameter 108 | 109 | delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device) 110 | 111 | t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T]) 112 | t_idx_BT1 = t_idx_BT1.unsqueeze(-1) 113 | 114 | t_idx_BxTxC = torch.minimum( 115 | t_idx_BT1 + delay_arr.view(1, 1, C), 116 | torch.tensor(T - 1, device=device), 117 | ) 118 | b_idx_BxTxC = torch.broadcast_to( 119 | torch.arange(B, device=device).view(B, 1, 1), [B, T, C] 120 | ) 121 | c_idx_BxTxC = torch.broadcast_to( 122 | torch.arange(C, device=device).view(1, 1, C), [B, T, C] 123 | ) 124 | 125 | indices_BTCx3 = torch.stack( 126 | [ 127 | b_idx_BxTxC.reshape(-1), 128 | t_idx_BxTxC.reshape(-1), 129 | c_idx_BxTxC.reshape(-1), 130 | ], 131 | axis=1, 132 | ).long() # Ensure indices are long type 133 | 134 | return t_idx_BxTxC, indices_BTCx3 135 | 136 | 137 | def revert_audio_delay( 138 | audio_BxTxC: torch.Tensor, 139 | pad_value: int, 140 | precomp: tp.Tuple[torch.Tensor, torch.Tensor], 141 | T: int, 142 | ) -> torch.Tensor: 143 | """ 144 | Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version). 
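    Example (illustrative, mirroring the apply_audio_delay example above):
    reverting [[a0, BOS], [a1, b0], [a2, b1]] with delay_pattern=[0, 1] and T=3
    recovers [[a0, b0], [a1, b1], [a2, b1]]; the trailing max(delay_pattern)
    frames are only partially reconstructed, which is why callers trim them
    (see the slicing by max_delay_pattern in dia/model.py).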
145 | 146 | Args: 147 | audio_BxTxC: Input delayed audio tensor 148 | pad_value: Padding value for out-of-bounds indices 149 | precomp: Precomputed revert indices tuple containing: 150 | - t_idx_BxTxC: Time offset indices tensor 151 | - indices_BTCx3: Gather indices tensor for original audio 152 | T: Original sequence length before padding 153 | 154 | Returns: 155 | Reverted audio tensor with same shape as input 156 | """ 157 | t_idx_BxTxC, indices_BTCx3 = precomp 158 | device = audio_BxTxC.device # Get device from input tensor 159 | 160 | # Move precomputed indices to the same device as audio_BxTxC if they aren't already 161 | t_idx_BxTxC = t_idx_BxTxC.to(device) 162 | indices_BTCx3 = indices_BTCx3.to(device) 163 | 164 | # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent) 165 | gathered_flat = audio_BxTxC[ 166 | indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2] 167 | ] 168 | gathered_BxTxC = gathered_flat.view( 169 | audio_BxTxC.size() 170 | ) # Use .size() for robust reshaping 171 | 172 | # Create pad_tensor on the correct device 173 | pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device) 174 | # Create T tensor on the correct device for comparison 175 | T_tensor = torch.tensor(T, device=device) 176 | 177 | result_BxTxC = torch.where( 178 | t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC 179 | ) # Changed np.where to torch.where 180 | 181 | return result_BxTxC 182 | 183 | 184 | @torch.no_grad() 185 | @torch.inference_mode() 186 | def decode( 187 | model, 188 | audio_codes, 189 | ): 190 | """ 191 | Decodes the given frames into an output audio waveform 192 | """ 193 | if len(audio_codes) != 1: 194 | raise ValueError(f"Expected one frame, got {len(audio_codes)}") 195 | 196 | try: 197 | audio_values = model.quantizer.from_codes(audio_codes) 198 | audio_values = model.decode(audio_values[0]) 199 | 200 | return audio_values 201 | except Exception as e: 202 | print(f"Error in decode method: {str(e)}") 203 | raise 204 | -------------------------------------------------------------------------------- /dia/config.py: -------------------------------------------------------------------------------- 1 | """Configuration management module for the Dia model. 2 | 3 | This module provides comprehensive configuration management for the Dia model, 4 | utilizing Pydantic for validation. It defines configurations for data processing, 5 | model architecture (encoder and decoder), and training settings. 6 | 7 | Key components: 8 | - DataConfig: Parameters for data loading and preprocessing. 9 | - EncoderConfig: Architecture details for the encoder module. 10 | - DecoderConfig: Architecture details for the decoder module. 11 | - ModelConfig: Combined model architecture settings. 12 | - TrainingConfig: Training hyperparameters and settings. 13 | - DiaConfig: Master configuration combining all components. 14 | """ 15 | 16 | import os 17 | from typing import Annotated 18 | 19 | from pydantic import BaseModel, BeforeValidator, Field 20 | 21 | 22 | class DataConfig(BaseModel, frozen=True): 23 | """Configuration for data loading and preprocessing. 24 | 25 | Attributes: 26 | text_length: Maximum length of text sequences (must be multiple of 128). 27 | audio_length: Maximum length of audio sequences (must be multiple of 128). 28 | channels: Number of audio channels. 29 | text_pad_value: Value used for padding text sequences. 30 | audio_eos_value: Value representing the end of audio sequences. 
31 | audio_bos_value: Value representing the beginning of audio sequences. 32 | audio_pad_value: Value used for padding audio sequences. 33 | delay_pattern: List of delay values for each audio channel. 34 | """ 35 | 36 | text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = ( 37 | Field(gt=0, multiple_of=128) 38 | ) 39 | audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = ( 40 | Field(gt=0, multiple_of=128) 41 | ) 42 | channels: int = Field(default=9, gt=0, multiple_of=1) 43 | text_pad_value: int = Field(default=0) 44 | audio_eos_value: int = Field(default=1024) 45 | audio_pad_value: int = Field(default=1025) 46 | audio_bos_value: int = Field(default=1026) 47 | delay_pattern: list[Annotated[int, Field(ge=0)]] = Field( 48 | default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15] 49 | ) 50 | 51 | def __hash__(self) -> int: 52 | """Generate a hash based on all fields of the config.""" 53 | return hash( 54 | ( 55 | self.text_length, 56 | self.audio_length, 57 | self.channels, 58 | self.text_pad_value, 59 | self.audio_pad_value, 60 | self.audio_bos_value, 61 | self.audio_eos_value, 62 | tuple(self.delay_pattern), 63 | ) 64 | ) 65 | 66 | 67 | class EncoderConfig(BaseModel, frozen=True): 68 | """Configuration for the encoder component of the Dia model. 69 | 70 | Attributes: 71 | n_layer: Number of transformer layers. 72 | n_embd: Embedding dimension. 73 | n_hidden: Hidden dimension size in the MLP layers. 74 | n_head: Number of attention heads. 75 | head_dim: Dimension per attention head. 76 | """ 77 | 78 | n_layer: int = Field(gt=0) 79 | n_embd: int = Field(gt=0) 80 | n_hidden: int = Field(gt=0) 81 | n_head: int = Field(gt=0) 82 | head_dim: int = Field(gt=0) 83 | 84 | 85 | class DecoderConfig(BaseModel, frozen=True): 86 | """Configuration for the decoder component of the Dia model. 87 | 88 | Attributes: 89 | n_layer: Number of transformer layers. 90 | n_embd: Embedding dimension. 91 | n_hidden: Hidden dimension size in the MLP layers. 92 | gqa_query_heads: Number of query heads for grouped-query self-attention. 93 | kv_heads: Number of key/value heads for grouped-query self-attention. 94 | gqa_head_dim: Dimension per query head for grouped-query self-attention. 95 | cross_query_heads: Number of query heads for cross-attention. 96 | cross_head_dim: Dimension per cross-attention head. 97 | """ 98 | 99 | n_layer: int = Field(gt=0) 100 | n_embd: int = Field(gt=0) 101 | n_hidden: int = Field(gt=0) 102 | gqa_query_heads: int = Field(gt=0) 103 | kv_heads: int = Field(gt=0) 104 | gqa_head_dim: int = Field(gt=0) 105 | cross_query_heads: int = Field(gt=0) 106 | cross_head_dim: int = Field(gt=0) 107 | 108 | 109 | class ModelConfig(BaseModel, frozen=True): 110 | """Main configuration container for the Dia model architecture. 111 | 112 | Attributes: 113 | encoder: Configuration for the encoder component. 114 | decoder: Configuration for the decoder component. 115 | src_vocab_size: Size of the source (text) vocabulary. 116 | tgt_vocab_size: Size of the target (audio code) vocabulary. 117 | dropout: Dropout probability applied within the model. 118 | normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm). 119 | weight_dtype: Data type for model weights (e.g., "float32", "bfloat16"). 120 | rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE). 121 | rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE). 
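    Note:
        The RoPE timescale attributes are consumed by RotaryEmbedding in
        dia/layers.py, which derives per-dimension timescales as roughly
        rope_min_timescale * (rope_max_timescale / rope_min_timescale) ** (2 * i / head_dim)
        for i in range(head_dim // 2).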
122 | """ 123 | 124 | encoder: EncoderConfig 125 | decoder: DecoderConfig 126 | src_vocab_size: int = Field(default=128, gt=0) 127 | tgt_vocab_size: int = Field(default=1028, gt=0) 128 | dropout: float = Field(default=0.0, ge=0.0, lt=1.0) 129 | normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0) 130 | weight_dtype: str = Field(default="float32", description="Weight precision") 131 | rope_min_timescale: int = Field( 132 | default=1, description="Timescale For global Attention" 133 | ) 134 | rope_max_timescale: int = Field( 135 | default=10_000, description="Timescale For global Attention" 136 | ) 137 | 138 | 139 | class TrainingConfig(BaseModel, frozen=True): 140 | pass 141 | 142 | 143 | class DiaConfig(BaseModel, frozen=True): 144 | """Master configuration for the Dia model. 145 | 146 | Combines all sub-configurations into a single validated object. 147 | 148 | Attributes: 149 | version: Configuration version string. 150 | model: Model architecture configuration. 151 | training: Training process configuration (precision settings). 152 | data: Data loading and processing configuration. 153 | """ 154 | 155 | version: str = Field(default="1.0") 156 | model: ModelConfig 157 | # TODO: remove training. this is just for backward compatibility 158 | training: TrainingConfig | None = Field(default=None) 159 | data: DataConfig 160 | 161 | def save(self, path: str) -> None: 162 | """Save the current configuration instance to a JSON file. 163 | 164 | Ensures the parent directory exists and the file has a .json extension. 165 | 166 | Args: 167 | path: The target file path to save the configuration. 168 | 169 | Raises: 170 | ValueError: If the path is not a file with a .json extension. 171 | """ 172 | os.makedirs(os.path.dirname(path), exist_ok=True) 173 | config_json = self.model_dump_json(indent=2) 174 | with open(path, "w") as f: 175 | f.write(config_json) 176 | 177 | @classmethod 178 | def load(cls, path: str) -> "DiaConfig | None": 179 | """Load and validate a Dia configuration from a JSON file. 180 | 181 | Args: 182 | path: The path to the configuration file. 183 | 184 | Returns: 185 | A validated DiaConfig instance if the file exists and is valid, 186 | otherwise None if the file is not found. 187 | 188 | Raises: 189 | ValueError: If the path does not point to an existing .json file. 190 | pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema. 191 | """ 192 | try: 193 | with open(path, "r") as f: 194 | content = f.read() 195 | return cls.model_validate_json(content) 196 | except FileNotFoundError: 197 | return None 198 | -------------------------------------------------------------------------------- /dia/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from huggingface_hub import PyTorchModelHubMixin 5 | from torch import Tensor 6 | from torch.nn import RMSNorm 7 | 8 | from .config import DiaConfig 9 | from .state import DecoderInferenceState, EncoderInferenceState, KVCache 10 | 11 | 12 | def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]: 13 | return tuple(ax if ax >= 0 else ndim + ax for ax in axes) 14 | 15 | 16 | class DenseGeneral(nn.Module): 17 | """ 18 | PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init. 19 | 20 | Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot 21 | for the generalized matrix multiplication. 
Weight/bias shapes are calculated 22 | and parameters created during initialization based on config. 23 | `load_weights` validates shapes and copies data. 24 | 25 | Attributes: 26 | axis (Tuple[int, ...]): Input axis or axes to contract. 27 | in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`. 28 | out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims). 29 | use_bias (bool): Whether to add a bias term. 30 | weight (nn.Parameter): The kernel parameter. 31 | bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True). 32 | """ 33 | 34 | def __init__( 35 | self, 36 | in_shapes: tuple[int, ...], 37 | out_features: tuple[int, ...], 38 | axis: tuple[int, ...] = (-1,), 39 | weight_dtype: torch.dtype | None = None, 40 | device: torch.device | None = None, 41 | ): 42 | super().__init__() 43 | self.in_shapes = in_shapes 44 | self.out_features = out_features 45 | self.axis = axis 46 | self.kernel_shape = self.in_shapes + self.out_features 47 | 48 | factory_kwargs = {"device": device, "dtype": weight_dtype} 49 | self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs)) 50 | 51 | def forward(self, inputs: Tensor) -> Tensor: 52 | norm_axis = _normalize_axes(self.axis, inputs.ndim) 53 | kernel_contract_axes = tuple(range(len(norm_axis))) 54 | 55 | output = torch.tensordot( 56 | inputs.to(self.weight.dtype), 57 | self.weight, 58 | dims=(norm_axis, kernel_contract_axes), 59 | ).to(inputs.dtype) 60 | return output 61 | 62 | 63 | class MlpBlock(nn.Module): 64 | """MLP block using DenseGeneral.""" 65 | 66 | def __init__( 67 | self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype 68 | ): 69 | super().__init__() 70 | self.dtype = compute_dtype 71 | 72 | self.wi_fused = DenseGeneral( 73 | in_shapes=(embed_dim,), 74 | out_features=(2, intermediate_dim), 75 | axis=(-1,), 76 | weight_dtype=compute_dtype, 77 | ) 78 | 79 | self.wo = DenseGeneral( 80 | in_shapes=(intermediate_dim,), 81 | out_features=(embed_dim,), 82 | axis=(-1,), 83 | weight_dtype=compute_dtype, 84 | ) 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: 87 | """Forward pass.""" 88 | fused_x = self.wi_fused(x) 89 | 90 | gate = fused_x[..., 0, :] 91 | up = fused_x[..., 1, :] 92 | 93 | hidden = torch.mul(F.silu(gate), up).to(self.dtype) 94 | 95 | output = self.wo(hidden) 96 | return output 97 | 98 | 99 | class RotaryEmbedding(nn.Module): 100 | """Rotary Position Embedding (RoPE) implementation in PyTorch.""" 101 | 102 | def __init__( 103 | self, 104 | embedding_dims: int, 105 | min_timescale: int = 1, 106 | max_timescale: int = 10000, 107 | dtype: torch.dtype = torch.float32, 108 | ): 109 | super().__init__() 110 | if embedding_dims % 2 != 0: 111 | raise ValueError("Embedding dim must be even for RoPE.") 112 | self.embedding_dims = embedding_dims 113 | self.min_timescale = min_timescale 114 | self.max_timescale = max_timescale 115 | self.compute_dtype = dtype 116 | 117 | half_embedding_dim = embedding_dims // 2 118 | fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims 119 | timescale = ( 120 | self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction 121 | ).to(torch.float32) 122 | self.register_buffer("timescale", timescale, persistent=False) 123 | 124 | def forward(self, inputs: torch.Tensor, position: torch.Tensor): 125 | """Applies RoPE.""" 126 | position = position.unsqueeze(-1).unsqueeze(-1) 127 | sinusoid_inp = position / self.timescale 128 | sin = torch.sin(sinusoid_inp) 129 | cos = 
torch.cos(sinusoid_inp) 130 | first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1) 131 | first_part = first_half * cos - second_half * sin 132 | second_part = second_half * cos + first_half * sin 133 | return torch.cat( 134 | (first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)), 135 | dim=-1, 136 | ) 137 | 138 | 139 | class Attention(nn.Module): 140 | """Attention using DenseGeneral.""" 141 | 142 | def __init__( 143 | self, 144 | config: DiaConfig, 145 | q_embed_dim: int, 146 | kv_embed_dim: int, 147 | num_query_heads: int, 148 | num_kv_heads: int, 149 | head_dim: int, 150 | compute_dtype: torch.dtype, 151 | is_cross_attn: bool = False, 152 | out_embed_dim: int | None = None, 153 | ): 154 | super().__init__() 155 | self.num_query_heads = num_query_heads 156 | self.num_kv_heads = num_kv_heads 157 | self.head_dim = head_dim 158 | self.is_cross_attn = is_cross_attn 159 | self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim 160 | self.projected_query_dim = num_query_heads * head_dim 161 | if num_query_heads % num_kv_heads != 0: 162 | raise ValueError( 163 | f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})" 164 | ) 165 | self.num_gqa_groups = num_query_heads // num_kv_heads 166 | 167 | # --- Projection Layers using DenseGeneral --- 168 | self.q_proj = DenseGeneral( 169 | in_shapes=(q_embed_dim,), 170 | out_features=(num_query_heads, head_dim), 171 | axis=(-1,), 172 | weight_dtype=compute_dtype, 173 | ) 174 | self.k_proj = DenseGeneral( 175 | in_shapes=(kv_embed_dim,), 176 | out_features=(num_kv_heads, head_dim), 177 | axis=(-1,), 178 | weight_dtype=compute_dtype, 179 | ) 180 | self.v_proj = DenseGeneral( 181 | in_shapes=(kv_embed_dim,), 182 | out_features=(num_kv_heads, head_dim), 183 | axis=(-1,), 184 | weight_dtype=compute_dtype, 185 | ) 186 | self.o_proj = DenseGeneral( 187 | in_shapes=(num_query_heads, head_dim), 188 | out_features=(self.output_dim,), 189 | axis=(-2, -1), 190 | weight_dtype=compute_dtype, 191 | ) 192 | 193 | # --- Rotary Embedding --- 194 | self.rotary_emb = RotaryEmbedding( 195 | embedding_dims=self.head_dim, 196 | min_timescale=config.model.rope_min_timescale, 197 | max_timescale=config.model.rope_max_timescale, 198 | dtype=compute_dtype, 199 | ) 200 | 201 | def forward( 202 | self, 203 | Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation 204 | Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation 205 | q_positions: torch.Tensor, # (B, T) 206 | kv_positions: torch.Tensor | None = None, # (B, S) 207 | attn_mask: ( 208 | torch.Tensor | None 209 | ) = None, # None in Decoder Self Attention, Valid mask in Others 210 | cache: KVCache | None = None, # None in Encoder, KVCache in Decoder 211 | prefill: bool = False, 212 | is_causal: bool = False, 213 | ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: 214 | """ 215 | Performs attention calculation with optional KV caching. 216 | 217 | Args: 218 | Xq: Query tensor (B, T, D). T=1 during single-step decoding. 219 | Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn. 220 | q_positions: Positions for queries (B, T). 221 | kv_positions: Positions for keys/values (B, S). If None, uses q_positions. 222 | attn_mask: Attention mask. 223 | cache: KVCache. 224 | prefill: If True, use prefill mode. 225 | 226 | Returns: 227 | A tuple containing: 228 | - output: The attention output tensor (B, T, output_dim). 
229 | - present_kv: The K/V state to be cached for the next step ((B, N, S_new, H), (B, N, S_new, H)). For self-attn, S_new = S_past + S. For cross-attn, S_new = S_kv. 230 | """ 231 | if kv_positions is None: 232 | kv_positions = q_positions 233 | original_dtype = Xq.dtype 234 | 235 | Xq_BxTxNxH = self.q_proj(Xq) 236 | Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions) 237 | Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2) 238 | 239 | attn_k: torch.Tensor | None = None 240 | attn_v: torch.Tensor | None = None 241 | 242 | if self.is_cross_attn: 243 | attn_k, attn_v = cache.k, cache.v 244 | else: 245 | Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H) 246 | Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H) 247 | Xk_BxSxKxH = self.rotary_emb( 248 | Xk_BxSxKxH, position=kv_positions 249 | ) # (B, S, K, H) 250 | 251 | Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H) 252 | Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H) 253 | 254 | if cache is None: 255 | attn_k = Xk_BxKxSxH 256 | attn_v = Xv_BxKxSxH 257 | else: 258 | if prefill: 259 | attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH 260 | cache.prefill(attn_k, attn_v) 261 | else: 262 | attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH) 263 | 264 | if attn_k is not None and attn_v is not None: 265 | attn_k = attn_k.to(Xq_BxNxTxH.dtype) 266 | attn_v = attn_v.to(Xq_BxNxTxH.dtype) 267 | 268 | attn_output = F.scaled_dot_product_attention( 269 | Xq_BxNxTxH, 270 | attn_k, 271 | attn_v, 272 | attn_mask=attn_mask, 273 | scale=1.0, 274 | enable_gqa=self.num_gqa_groups > 1, 275 | is_causal=is_causal, 276 | ) 277 | 278 | attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H) 279 | output = self.o_proj(attn_output) 280 | 281 | return output.to(original_dtype) 282 | 283 | 284 | class EncoderLayer(nn.Module): 285 | """Transformer Encoder Layer using DenseGeneral.""" 286 | 287 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype): 288 | super().__init__() 289 | self.config = config 290 | model_config = config.model 291 | enc_config = config.model.encoder 292 | embed_dim = enc_config.n_embd 293 | self.compute_dtype = compute_dtype 294 | 295 | self.pre_sa_norm = RMSNorm( 296 | embed_dim, 297 | eps=model_config.normalization_layer_epsilon, 298 | dtype=torch.float32, 299 | ) 300 | self.self_attention = Attention( 301 | config, 302 | q_embed_dim=embed_dim, 303 | kv_embed_dim=embed_dim, 304 | num_query_heads=enc_config.n_head, 305 | num_kv_heads=enc_config.n_head, 306 | head_dim=enc_config.head_dim, 307 | compute_dtype=compute_dtype, 308 | is_cross_attn=False, 309 | out_embed_dim=embed_dim, 310 | ) 311 | self.post_sa_norm = RMSNorm( 312 | embed_dim, 313 | eps=model_config.normalization_layer_epsilon, 314 | dtype=torch.float32, 315 | ) 316 | self.mlp = MlpBlock( 317 | embed_dim=embed_dim, 318 | intermediate_dim=enc_config.n_hidden, 319 | compute_dtype=compute_dtype, 320 | ) 321 | 322 | def forward( 323 | self, 324 | x: torch.Tensor, 325 | state: EncoderInferenceState, 326 | ) -> torch.Tensor: 327 | residual = x 328 | x_norm = self.pre_sa_norm(x).to(self.compute_dtype) 329 | 330 | sa_out = self.self_attention( 331 | Xq=x_norm, 332 | Xkv=x_norm, 333 | q_positions=state.positions, 334 | kv_positions=state.positions, 335 | attn_mask=state.attn_mask, 336 | ) 337 | x = residual + sa_out 338 | 339 | residual = x 340 | x_norm = self.post_sa_norm(x).to(self.compute_dtype) 341 | mlp_out = self.mlp(x_norm) 342 | x = residual + mlp_out 343 | 344 | return x 345 | 346 | 347 | class Encoder(nn.Module): 348 | """Transformer Encoder Stack using 
DenseGeneral.""" 349 | 350 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype): 351 | super().__init__() 352 | self.config = config 353 | model_config = config.model 354 | enc_config = config.model.encoder 355 | self.compute_dtype = compute_dtype 356 | 357 | self.embedding = nn.Embedding( 358 | model_config.src_vocab_size, 359 | enc_config.n_embd, 360 | dtype=compute_dtype, 361 | ) 362 | self.layers = nn.ModuleList( 363 | [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)] 364 | ) 365 | self.norm = RMSNorm( 366 | enc_config.n_embd, 367 | eps=model_config.normalization_layer_epsilon, 368 | dtype=torch.float32, 369 | ) 370 | 371 | def forward( 372 | self, 373 | x_ids: torch.Tensor, 374 | state: EncoderInferenceState, 375 | ) -> torch.Tensor: 376 | x = self.embedding(x_ids) 377 | 378 | for layer in self.layers: 379 | x = layer(x, state) 380 | 381 | x = self.norm(x).to(self.compute_dtype) 382 | return x 383 | 384 | 385 | class DecoderLayer(nn.Module): 386 | """Transformer Decoder Layer using DenseGeneral.""" 387 | 388 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype): 389 | super().__init__() 390 | self.config = config 391 | model_config = config.model 392 | dec_config = config.model.decoder 393 | enc_config = config.model.encoder 394 | dec_embed_dim = dec_config.n_embd 395 | enc_embed_dim = enc_config.n_embd 396 | self.compute_dtype = compute_dtype 397 | 398 | # Norms 399 | self.pre_sa_norm = RMSNorm( 400 | dec_embed_dim, 401 | eps=model_config.normalization_layer_epsilon, 402 | dtype=torch.float32, 403 | ) 404 | self.pre_ca_norm = RMSNorm( 405 | dec_embed_dim, 406 | eps=model_config.normalization_layer_epsilon, 407 | dtype=torch.float32, 408 | ) 409 | self.pre_mlp_norm = RMSNorm( 410 | dec_embed_dim, 411 | eps=model_config.normalization_layer_epsilon, 412 | dtype=torch.float32, 413 | ) 414 | 415 | # Self-Attention (GQA) with Causal Masking 416 | self.self_attention = Attention( 417 | config, 418 | q_embed_dim=dec_embed_dim, 419 | kv_embed_dim=dec_embed_dim, 420 | num_query_heads=dec_config.gqa_query_heads, 421 | num_kv_heads=dec_config.kv_heads, 422 | head_dim=dec_config.gqa_head_dim, 423 | compute_dtype=compute_dtype, 424 | is_cross_attn=False, 425 | out_embed_dim=dec_embed_dim, 426 | ) 427 | # Cross-Attention (MHA) 428 | self.cross_attention = Attention( 429 | config=config, 430 | q_embed_dim=dec_embed_dim, 431 | kv_embed_dim=enc_embed_dim, # Note kv_embed_dim 432 | num_query_heads=dec_config.cross_query_heads, 433 | num_kv_heads=dec_config.cross_query_heads, 434 | head_dim=dec_config.cross_head_dim, 435 | compute_dtype=compute_dtype, 436 | is_cross_attn=True, 437 | out_embed_dim=dec_embed_dim, 438 | ) 439 | # MLP 440 | self.mlp = MlpBlock( 441 | embed_dim=dec_embed_dim, 442 | intermediate_dim=dec_config.n_hidden, 443 | compute_dtype=compute_dtype, 444 | ) 445 | 446 | def forward( 447 | self, 448 | x: torch.Tensor, 449 | state: DecoderInferenceState, 450 | self_attn_cache: KVCache | None = None, 451 | cross_attn_cache: KVCache | None = None, 452 | prefill: bool = False, 453 | ) -> torch.Tensor: 454 | residual = x 455 | x_norm = self.pre_sa_norm(x).to(self.compute_dtype) 456 | 457 | sa_out = self.self_attention( 458 | Xq=x_norm, # (2, 1, D) 459 | Xkv=x_norm, # (2, 1, D) 460 | q_positions=state.dec_positions, # (2, 1) 461 | kv_positions=state.dec_positions, # (2, 1) 462 | attn_mask=None, 463 | cache=self_attn_cache, 464 | prefill=prefill, 465 | is_causal=prefill, 466 | ) 467 | 468 | x = residual + sa_out 469 | 470 | residual = x 471 | 
x_norm = self.pre_ca_norm(x).to(self.compute_dtype) 472 | ca_out = self.cross_attention( 473 | Xq=x_norm, 474 | Xkv=state.enc_out, 475 | q_positions=state.dec_positions, 476 | kv_positions=state.enc_positions, 477 | attn_mask=state.dec_cross_attn_mask, 478 | cache=cross_attn_cache, 479 | ) 480 | x = residual + ca_out 481 | 482 | residual = x 483 | x_norm = self.pre_mlp_norm(x).to(self.compute_dtype) 484 | mlp_out = self.mlp(x_norm) 485 | x = residual + mlp_out 486 | 487 | return x 488 | 489 | 490 | class Decoder(nn.Module): 491 | """Transformer Decoder Stack using DenseGeneral.""" 492 | 493 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype): 494 | super().__init__() 495 | self.config = config 496 | model_config = config.model 497 | dec_config = config.model.decoder 498 | data_config = config.data 499 | self.num_channels = data_config.channels 500 | self.num_layers = dec_config.n_layer 501 | 502 | self.embeddings = nn.ModuleList( 503 | [ 504 | nn.Embedding( 505 | model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype 506 | ) 507 | for _ in range(self.num_channels) 508 | ] 509 | ) 510 | self.layers = nn.ModuleList( 511 | [ 512 | DecoderLayer(config=config, compute_dtype=compute_dtype) 513 | for _ in range(self.num_layers) 514 | ] 515 | ) 516 | 517 | self.norm = RMSNorm( 518 | dec_config.n_embd, 519 | eps=model_config.normalization_layer_epsilon, 520 | dtype=torch.float32, 521 | ) 522 | 523 | self.logits_dense = DenseGeneral( 524 | in_shapes=(dec_config.n_embd,), 525 | out_features=(self.num_channels, model_config.tgt_vocab_size), 526 | axis=(-1,), 527 | weight_dtype=compute_dtype, 528 | ) 529 | 530 | def precompute_cross_attn_cache( 531 | self, 532 | enc_out: torch.Tensor, # (B, S, E) 533 | enc_positions: torch.Tensor, # (B, S) 534 | ) -> list[KVCache]: 535 | """ 536 | Computes the Key and Value tensors for cross-attention for each layer from the encoder output. 537 | """ 538 | per_layer_kv_cache: list[KVCache] = [] 539 | 540 | for layer in self.layers: 541 | cross_attn_module = layer.cross_attention 542 | k_proj = cross_attn_module.k_proj(enc_out) 543 | v_proj = cross_attn_module.v_proj(enc_out) 544 | 545 | k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions) 546 | k = k_proj.transpose(1, 2) 547 | v = v_proj.transpose(1, 2) 548 | 549 | per_layer_kv_cache.append(KVCache.from_kv(k, v)) 550 | 551 | return per_layer_kv_cache 552 | 553 | def decode_step( 554 | self, 555 | tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C] 556 | state: DecoderInferenceState, 557 | ) -> torch.Tensor: 558 | """ 559 | Performs a single decoding step, managing KV caches layer by layer. 560 | 561 | Returns: 562 | A tuple containing: 563 | - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32. 
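            Note: as actually returned by the current implementation, the tensor is
            shaped (B, 1, C, V) rather than a flattened (B, 1, C*V); for the default
            config (channels=9, tgt_vocab_size=1028) under CFG this is (2, 1, 9, 1028).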
564 | """ 565 | 566 | x = None 567 | for i in range(self.num_channels): 568 | channel_tokens = tgt_ids_Bx1xC[..., i] 569 | channel_embed = self.embeddings[i](channel_tokens) 570 | x = channel_embed if x is None else x + channel_embed 571 | 572 | for i, layer in enumerate(self.layers): 573 | self_cache = state.self_attn_cache[i] 574 | cross_cache = state.cross_attn_cache[i] 575 | x = layer( 576 | x, # (2, 1, D) 577 | state, 578 | self_attn_cache=self_cache, 579 | cross_attn_cache=cross_cache, 580 | ) 581 | 582 | x = self.norm(x) 583 | logits_Bx1xCxV = self.logits_dense(x) 584 | 585 | return logits_Bx1xCxV.to(torch.float32) 586 | 587 | def forward( 588 | self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState 589 | ) -> torch.Tensor: 590 | """ 591 | Forward pass for the Decoder stack, managing KV caches. 592 | 593 | Args: 594 | tgt_ids_BxTxC: Target token IDs (B, T, C). 595 | encoder_out: Output from the encoder (B, S, E). 596 | tgt_positions: Positions for target sequence (B, T). 597 | src_positions: Positions for source sequence (B, S). 598 | self_attn_mask: Mask for self-attention. 599 | cross_attn_mask: Mask for cross-attention. 600 | past_key_values: List containing the self-attention KV cache for each layer 601 | from the previous decoding step. `len(past_key_values)` should 602 | equal `num_layers`. 603 | precomputed_cross_attn_kv: A single tuple containing the pre-computed K/V cache 604 | derived from `encoder_out`. This is passed identically 605 | to all layers. 606 | 607 | Returns: 608 | A tuple containing: 609 | - logits: The final output logits (B, T, C * V), cast to float32. 610 | - present_key_values: A list containing the updated self-attention KV cache 611 | for each layer for the *current* decoding step. 612 | """ 613 | _, _, num_channels_in = tgt_ids_BxTxC.shape 614 | assert num_channels_in == self.num_channels, "Input channels mismatch" 615 | 616 | # Embeddings 617 | x = None 618 | for i in range(self.num_channels): 619 | channel_tokens = tgt_ids_BxTxC[..., i] 620 | channel_embed = self.embeddings[i](channel_tokens) 621 | x = channel_embed if x is None else x + channel_embed 622 | 623 | for i, layer in enumerate(self.layers): 624 | self_cache = state.self_attn_cache[i] 625 | cross_cache = state.cross_attn_cache[i] 626 | x = layer( 627 | x, 628 | state, 629 | self_attn_cache=self_cache, 630 | cross_attn_cache=cross_cache, 631 | prefill=True, 632 | ) 633 | 634 | # Final Norm 635 | x = self.norm(x) 636 | logits_BxTxCxV = self.logits_dense(x) 637 | 638 | return logits_BxTxCxV.to(torch.float32) 639 | 640 | 641 | class DiaModel( 642 | nn.Module, 643 | PyTorchModelHubMixin, 644 | repo_url="https://github.com/nari-labs/dia", 645 | pipeline_tag="text-to-speech", 646 | license="apache-2.0", 647 | coders={ 648 | DiaConfig: ( 649 | lambda x: x.model_dump(), 650 | lambda data: DiaConfig.model_validate(data), 651 | ), 652 | }, 653 | ): 654 | """PyTorch Dia Model using DenseGeneral.""" 655 | 656 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype): 657 | super().__init__() 658 | self.config = config 659 | self.encoder = Encoder(config, compute_dtype) 660 | self.decoder = Decoder(config, compute_dtype) 661 | -------------------------------------------------------------------------------- /dia/model.py: -------------------------------------------------------------------------------- 1 | import time 2 | from enum import Enum 3 | 4 | import dac 5 | import numpy as np 6 | import torch 7 | import torchaudio 8 | 9 | from .audio import ( 10 | apply_audio_delay, 11 | 
build_delay_indices, 12 | build_revert_indices, 13 | decode, 14 | revert_audio_delay, 15 | ) 16 | from .config import DiaConfig 17 | from .layers import DiaModel 18 | from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState 19 | 20 | 21 | DEFAULT_SAMPLE_RATE = 44100 22 | 23 | 24 | def _get_default_device(): 25 | if torch.cuda.is_available(): 26 | return torch.device("cuda") 27 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): 28 | return torch.device("mps") 29 | return torch.device("cpu") 30 | 31 | 32 | def _sample_next_token( 33 | logits_BCxV: torch.Tensor, 34 | temperature: float, 35 | top_p: float, 36 | cfg_filter_top_k: int | None = None, 37 | generator: torch.Generator = None, # Added generator parameter 38 | ) -> torch.Tensor: 39 | if temperature == 0.0: 40 | return torch.argmax(logits_BCxV, dim=-1) 41 | 42 | logits_BCxV = logits_BCxV / temperature 43 | if cfg_filter_top_k is not None: 44 | _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1) 45 | mask = torch.ones_like(logits_BCxV, dtype=torch.bool) 46 | mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False) 47 | logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf) 48 | 49 | if top_p < 1.0: 50 | probs_BCxV = torch.softmax(logits_BCxV, dim=-1) 51 | sorted_probs_BCxV, sorted_indices_BCxV = torch.sort( 52 | probs_BCxV, dim=-1, descending=True 53 | ) 54 | cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1) 55 | 56 | sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p 57 | sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[ 58 | ..., :-1 59 | ].clone() 60 | sorted_indices_to_remove_BCxV[..., 0] = 0 61 | 62 | indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV) 63 | indices_to_remove_BCxV.scatter_( 64 | dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV 65 | ) 66 | logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf) 67 | 68 | final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1) 69 | 70 | sampled_indices_BC = torch.multinomial( 71 | final_probs_BCxV, 72 | num_samples=1, 73 | generator=generator, # Pass generator to multinomial 74 | ) 75 | sampled_indices_C = sampled_indices_BC.squeeze(-1) 76 | return sampled_indices_C 77 | 78 | 79 | class ComputeDtype(str, Enum): 80 | FLOAT32 = "float32" 81 | FLOAT16 = "float16" 82 | BFLOAT16 = "bfloat16" 83 | 84 | def to_dtype(self) -> torch.dtype: 85 | if self == ComputeDtype.FLOAT32: 86 | return torch.float32 87 | elif self == ComputeDtype.FLOAT16: 88 | return torch.float16 89 | elif self == ComputeDtype.BFLOAT16: 90 | return torch.bfloat16 91 | else: 92 | raise ValueError(f"Unsupported compute dtype: {self}") 93 | 94 | 95 | class Dia: 96 | def __init__( 97 | self, 98 | config: DiaConfig, 99 | compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32, 100 | device: torch.device | None = None, 101 | ): 102 | """Initializes the Dia model. 103 | 104 | Args: 105 | config: The configuration object for the model. 106 | device: The device to load the model onto. If None, will automatically select the best available device. 107 | 108 | Raises: 109 | RuntimeError: If there is an error loading the DAC model. 
110 | """ 111 | super().__init__() 112 | self.config = config 113 | self.device = device if device is not None else _get_default_device() 114 | if isinstance(compute_dtype, str): 115 | compute_dtype = ComputeDtype(compute_dtype) 116 | self.compute_dtype = compute_dtype.to_dtype() 117 | self.model = DiaModel(config, self.compute_dtype) 118 | self.dac_model = None 119 | 120 | @classmethod 121 | def from_local( 122 | cls, 123 | config_path: str, 124 | checkpoint_path: str, 125 | compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32, 126 | device: torch.device | None = None, 127 | ) -> "Dia": 128 | """Loads the Dia model from local configuration and checkpoint files. 129 | 130 | Args: 131 | config_path: Path to the configuration JSON file. 132 | checkpoint_path: Path to the model checkpoint (.pth) file. 133 | device: The device to load the model onto. If None, will automatically select the best available device. 134 | 135 | Returns: 136 | An instance of the Dia model loaded with weights and set to eval mode. 137 | 138 | Raises: 139 | FileNotFoundError: If the config or checkpoint file is not found. 140 | RuntimeError: If there is an error loading the checkpoint. 141 | """ 142 | config = DiaConfig.load(config_path) 143 | if config is None: 144 | raise FileNotFoundError(f"Config file not found at {config_path}") 145 | 146 | dia = cls(config, compute_dtype, device) 147 | 148 | try: 149 | state_dict = torch.load(checkpoint_path, map_location=dia.device) 150 | dia.model.load_state_dict(state_dict) 151 | except FileNotFoundError: 152 | raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") 153 | except Exception as e: 154 | raise RuntimeError( 155 | f"Error loading checkpoint from {checkpoint_path}" 156 | ) from e 157 | 158 | dia.model.to(dia.device) 159 | dia.model.eval() 160 | dia._load_dac_model() 161 | return dia 162 | 163 | @classmethod 164 | def from_pretrained( 165 | cls, 166 | model_name: str = "nari-labs/Dia-1.6B", 167 | compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32, 168 | device: torch.device | None = None, 169 | ) -> "Dia": 170 | """Loads the Dia model from a Hugging Face Hub repository. 171 | 172 | Downloads the configuration and checkpoint files from the specified 173 | repository ID and then loads the model. 174 | 175 | Args: 176 | model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B"). 177 | compute_dtype: The computation dtype to use. 178 | device: The device to load the model onto. If None, will automatically select the best available device. 179 | 180 | Returns: 181 | An instance of the Dia model loaded with weights and set to eval mode. 182 | 183 | Raises: 184 | FileNotFoundError: If config or checkpoint download/loading fails. 185 | RuntimeError: If there is an error loading the checkpoint. 
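        Example (illustrative sketch; the text, output file name, and dtype are
        hypothetical, not taken from the original source):

            text = "[S1] Hello there. [S2] Hi, how are you?"
            dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
            audio = dia.generate(text, text_to_generate_size=len(text))
            dia.save_audio("output.wav", audio)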
186 | """ 187 | if isinstance(compute_dtype, str): 188 | compute_dtype = ComputeDtype(compute_dtype) 189 | loaded_model = DiaModel.from_pretrained( 190 | model_name, compute_dtype=compute_dtype.to_dtype() 191 | ) 192 | config = loaded_model.config 193 | dia = cls(config, compute_dtype, device) 194 | 195 | dia.model = loaded_model 196 | dia.model.to(dia.device) 197 | dia.model.eval() 198 | dia._load_dac_model() 199 | return dia 200 | 201 | def _load_dac_model(self): 202 | try: 203 | dac_model_path = dac.utils.download() 204 | dac_model = dac.DAC.load(dac_model_path).to(self.device) 205 | except Exception as e: 206 | raise RuntimeError("Failed to load DAC model") from e 207 | self.dac_model = dac_model 208 | 209 | def _prepare_text_input(self, text: str) -> torch.Tensor: 210 | """Encodes text prompt, pads, and creates attention mask and positions.""" 211 | text_pad_value = self.config.data.text_pad_value 212 | max_len = self.config.data.text_length 213 | 214 | byte_text = text.encode("utf-8") 215 | replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02") 216 | text_tokens = list(replaced_bytes) 217 | 218 | current_len = len(text_tokens) 219 | padding_needed = max_len - current_len 220 | if padding_needed <= 0: 221 | text_tokens = text_tokens[:max_len] 222 | padded_text_np = np.array(text_tokens, dtype=np.uint8) 223 | else: 224 | padded_text_np = np.pad( 225 | text_tokens, 226 | (0, padding_needed), 227 | mode="constant", 228 | constant_values=text_pad_value, 229 | ).astype(np.uint8) 230 | 231 | src_tokens = ( 232 | torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0) 233 | ) # [1, S] 234 | return src_tokens 235 | 236 | def _prepare_audio_prompt( 237 | self, audio_prompt: torch.Tensor | None 238 | ) -> tuple[torch.Tensor, int]: 239 | num_channels = self.config.data.channels 240 | audio_bos_value = self.config.data.audio_bos_value 241 | audio_pad_value = self.config.data.audio_pad_value 242 | delay_pattern = self.config.data.delay_pattern 243 | max_delay_pattern = max(delay_pattern) 244 | 245 | prefill = torch.full( 246 | (1, num_channels), 247 | fill_value=audio_bos_value, 248 | dtype=torch.int, 249 | device=self.device, 250 | ) 251 | 252 | prefill_step = 1 253 | 254 | if audio_prompt is not None: 255 | prefill_step += audio_prompt.shape[0] 256 | prefill = torch.cat([prefill, audio_prompt], dim=0) 257 | 258 | delay_pad_tensor = torch.full( 259 | (max_delay_pattern, num_channels), 260 | fill_value=-1, 261 | dtype=torch.int, 262 | device=self.device, 263 | ) 264 | prefill = torch.cat([prefill, delay_pad_tensor], dim=0) 265 | 266 | delay_precomp = build_delay_indices( 267 | B=1, 268 | T=prefill.shape[0], 269 | C=num_channels, 270 | delay_pattern=delay_pattern, 271 | ) 272 | 273 | prefill = apply_audio_delay( 274 | audio_BxTxC=prefill.unsqueeze(0), 275 | pad_value=audio_pad_value, 276 | bos_value=audio_bos_value, 277 | precomp=delay_precomp, 278 | ).squeeze(0) 279 | 280 | return prefill, prefill_step 281 | 282 | def _prepare_generation( 283 | self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool 284 | ): 285 | enc_input_cond = self._prepare_text_input(text) 286 | enc_input_uncond = torch.zeros_like(enc_input_cond) 287 | enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0) 288 | 289 | if isinstance(audio_prompt, str): 290 | audio_prompt = self.load_audio(audio_prompt) 291 | prefill, prefill_step = self._prepare_audio_prompt(audio_prompt) 292 | 293 | if verbose: 294 | print("generate: data loaded") 295 | 296 | enc_state = 
EncoderInferenceState.new(self.config, enc_input_cond) 297 | encoder_out = self.model.encoder(enc_input, enc_state) 298 | 299 | # Clean up inputs after encoding 300 | del enc_input_uncond, enc_input 301 | 302 | dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache( 303 | encoder_out, enc_state.positions 304 | ) 305 | dec_state = DecoderInferenceState.new( 306 | self.config, 307 | enc_state, 308 | encoder_out, 309 | dec_cross_attn_cache, 310 | self.compute_dtype, 311 | ) 312 | 313 | # Can delete encoder output after it's used by decoder state 314 | del dec_cross_attn_cache 315 | 316 | dec_output = DecoderOutput.new(self.config, self.device) 317 | dec_output.prefill(prefill, prefill_step) 318 | 319 | dec_step = prefill_step - 1 320 | if dec_step > 0: 321 | dec_state.prepare_step(0, dec_step) 322 | tokens_BxTxC = ( 323 | dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1) 324 | ) 325 | self.model.decoder.forward(tokens_BxTxC, dec_state) 326 | 327 | return dec_state, dec_output 328 | 329 | def _decoder_step( 330 | self, 331 | tokens_Bx1xC: torch.Tensor, 332 | dec_state: DecoderInferenceState, 333 | cfg_scale: float, 334 | temperature: float, 335 | top_p: float, 336 | cfg_filter_top_k: int, 337 | generator: torch.Generator = None, # Added generator parameter 338 | ) -> torch.Tensor: 339 | audio_eos_value = self.config.data.audio_eos_value 340 | logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state) 341 | 342 | logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :] 343 | # ADD: Remove the full logits tensor 344 | del logits_Bx1xCxV 345 | 346 | uncond_logits_CxV = logits_last_BxCxV[0, :, :] 347 | cond_logits_CxV = logits_last_BxCxV[1, :, :] 348 | # ADD: Remove the combined logits 349 | del logits_last_BxCxV 350 | 351 | logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV) 352 | # ADD: Remove component logits 353 | del uncond_logits_CxV, cond_logits_CxV 354 | 355 | logits_CxV[:, audio_eos_value + 1 :] = -torch.inf 356 | logits_CxV[1:, audio_eos_value:] = -torch.inf 357 | 358 | pred_C = _sample_next_token( 359 | logits_CxV.float(), 360 | temperature=temperature, 361 | top_p=top_p, 362 | cfg_filter_top_k=cfg_filter_top_k, 363 | generator=generator, # Pass generator to _sample_next_token 364 | ) 365 | # ADD: Remove final logits 366 | del logits_CxV 367 | 368 | return pred_C 369 | 370 | def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray: 371 | num_channels = self.config.data.channels 372 | seq_length = generated_codes.shape[0] 373 | delay_pattern = self.config.data.delay_pattern 374 | audio_pad_value = self.config.data.audio_pad_value 375 | max_delay_pattern = max(delay_pattern) 376 | 377 | revert_precomp = build_revert_indices( 378 | B=1, 379 | T=seq_length, 380 | C=num_channels, 381 | delay_pattern=delay_pattern, 382 | ) 383 | 384 | codebook = revert_audio_delay( 385 | audio_BxTxC=generated_codes.unsqueeze(0), 386 | pad_value=audio_pad_value, 387 | precomp=revert_precomp, 388 | T=seq_length, 389 | )[:, :-max_delay_pattern, :] 390 | 391 | # ADD: Clean up intermediate tensors 392 | del revert_precomp, generated_codes 393 | 394 | min_valid_index = 0 395 | max_valid_index = 1023 396 | invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index) 397 | codebook[invalid_mask] = 0 398 | 399 | # Process final audio 400 | audio = decode(self.dac_model, codebook.transpose(1, 2)) 401 | 402 | # ADD: Clean up codebook 403 | del codebook, invalid_mask 404 | 405 | result = audio.squeeze().cpu().numpy() 406 | 
407 | # ADD: Clean up audio tensor 408 | del audio 409 | torch.cuda.empty_cache() 410 | 411 | return result 412 | 413 | def load_audio(self, audio_path: str) -> torch.Tensor: 414 | audio, sr = torchaudio.load(audio_path, channels_first=True) # C, T 415 | if sr != DEFAULT_SAMPLE_RATE: 416 | audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE) 417 | audio = audio.to(self.device).unsqueeze(0) # 1, C, T 418 | audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE) 419 | _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data) # 1, C, T 420 | 421 | # Clean up intermediate tensors 422 | del audio_data 423 | if sr != DEFAULT_SAMPLE_RATE: 424 | del audio 425 | encoded_result = encoded_frame.squeeze(0).transpose(0, 1) 426 | del encoded_frame 427 | 428 | return encoded_result 429 | 430 | def save_audio(self, path: str, audio: np.ndarray): 431 | import soundfile as sf 432 | 433 | sf.write(path, audio, DEFAULT_SAMPLE_RATE) 434 | 435 | @torch.inference_mode() 436 | def generate( 437 | self, 438 | text: str, 439 | max_tokens: int | None = None, 440 | cfg_scale: float = 3.0, 441 | temperature: float = 1.3, 442 | top_p: float = 0.95, 443 | use_torch_compile: bool = False, 444 | cfg_filter_top_k: int = 35, 445 | audio_prompt: str | torch.Tensor | None = None, 446 | audio_prompt_path: str | None = None, 447 | use_cfg_filter: bool | None = None, 448 | seed: int = 42, # Added seed parameter 449 | verbose: bool = True, 450 | text_to_generate_size: int | None = None, 451 | ) -> np.ndarray: 452 | # Import tqdm here to avoid requiring it as a dependency if progress bar isn't used 453 | from tqdm import tqdm 454 | 455 | audio_eos_value = self.config.data.audio_eos_value 456 | audio_pad_value = self.config.data.audio_pad_value 457 | delay_pattern = self.config.data.delay_pattern 458 | 459 | # Estimate tokens based on text length (using the 6.19 tokens per char ratio) 460 | if audio_prompt: 461 | estimated_tokens = int(text_to_generate_size * 3.4) 462 | else: 463 | estimated_tokens = int(text_to_generate_size * 6.7) 464 | 465 | current_step = 0 466 | 467 | # Cap at model maximum 468 | max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens 469 | 470 | max_delay_pattern = max(delay_pattern) if delay_pattern else 0 471 | self.model.eval() 472 | 473 | if audio_prompt_path: 474 | print("Warning: audio_prompt_path is deprecated. 
Use audio_prompt instead.") 475 | audio_prompt = audio_prompt_path 476 | if use_cfg_filter is not None: 477 | print("Warning: use_cfg_filter is deprecated.") 478 | 479 | # Create generator from seed if provided 480 | generator = None 481 | if seed >= 0: 482 | if verbose: 483 | print(f"Using seed: {seed} for generation") 484 | generator = torch.Generator(device=self.device) 485 | generator.manual_seed(seed) 486 | else: 487 | if verbose: 488 | print("Using random seed for generation") 489 | 490 | if verbose: 491 | total_start_time = time.time() 492 | 493 | # Prepare generation states 494 | dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose) 495 | dec_step = dec_output.prefill_step - 1 496 | 497 | bos_countdown = max_delay_pattern 498 | eos_detected = False 499 | eos_countdown = -1 500 | 501 | if use_torch_compile: 502 | step_fn = torch.compile(self._decoder_step, mode="default") 503 | else: 504 | step_fn = self._decoder_step 505 | 506 | if verbose: 507 | print("generate: starting generation loop") 508 | if use_torch_compile: 509 | print( 510 | "generate: by using use_torch_compile=True, the first step would take long" 511 | ) 512 | start_time = time.time() 513 | 514 | token_history = [] # Track generated tokens 515 | 516 | try: 517 | while dec_step < max_tokens: 518 | dec_state.prepare_step(dec_step) 519 | tokens_Bx1xC = ( 520 | dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1) 521 | ) 522 | pred_C = step_fn( 523 | tokens_Bx1xC, 524 | dec_state, 525 | cfg_scale, 526 | temperature, 527 | top_p, 528 | cfg_filter_top_k, 529 | generator=generator, # Pass generator to step_fn 530 | ) 531 | # Clean up tokens after use 532 | del tokens_Bx1xC 533 | 534 | if ( 535 | not eos_detected and pred_C[0] == audio_eos_value 536 | ) or dec_step == max_tokens - max_delay_pattern - 1: 537 | eos_detected = True 538 | eos_countdown = max_delay_pattern 539 | 540 | if eos_countdown > 0: 541 | step_after_eos = max_delay_pattern - eos_countdown 542 | for i, d in enumerate(delay_pattern): 543 | if step_after_eos == d: 544 | pred_C[i] = audio_eos_value 545 | elif step_after_eos > d: 546 | pred_C[i] = audio_pad_value 547 | eos_countdown -= 1 548 | 549 | bos_countdown = max(0, bos_countdown - 1) 550 | dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0) 551 | 552 | # Add to token history 553 | token_history.append(pred_C.detach().clone()) 554 | 555 | if eos_countdown == 0: 556 | break 557 | 558 | dec_step += 1 559 | 560 | if verbose and dec_step % 86 == 0: 561 | duration = time.time() - start_time 562 | print( 563 | f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x" 564 | ) 565 | start_time = time.time() 566 | 567 | # Periodic memory cleanup during long generations 568 | if torch.cuda.is_available() and dec_step > 500: 569 | torch.cuda.empty_cache() 570 | 571 | except Exception as e: 572 | # Ensure cleanup even on error 573 | if "dec_state" in locals(): 574 | del dec_state 575 | if "dec_output" in locals(): 576 | del dec_output 577 | if "token_history" in locals(): 578 | del token_history 579 | self.reset_state() 580 | raise e 581 | 582 | finally: 583 | # Close progress bar in finally block to ensure it's closed even on error 584 | # pbar.close() 585 | 586 | # Check if dec_output was created 587 | if "dec_output" in locals() and dec_output is not None: 588 | if dec_output.prefill_step >= dec_step + 1: 589 | print("Warning: Nothing generated") 590 | # Cleanup on early return 591 | if "dec_state" in locals(): 592 | del 
dec_state 593 | if "dec_output" in locals(): 594 | del dec_output 595 | if "token_history" in locals(): 596 | del token_history 597 | self.reset_state() 598 | return None 599 | 600 | generated_codes = dec_output.generated_tokens[ 601 | dec_output.prefill_step : dec_step + 1, : 602 | ] 603 | 604 | # Clean up state objects 605 | if "dec_state" in locals(): 606 | del dec_state 607 | if "dec_output" in locals(): 608 | del dec_output 609 | if "token_history" in locals(): 610 | del token_history 611 | 612 | if verbose: 613 | if "dec_output" in locals(): 614 | total_step = dec_step + 1 - dec_output.prefill_step 615 | total_duration = time.time() - total_start_time 616 | print( 617 | f"generate: total step={total_step}, total duration={total_duration:.3f}s" 618 | ) 619 | 620 | # Process output 621 | output = self._generate_output(generated_codes) 622 | 623 | # Final cleanup 624 | del generated_codes 625 | self.reset_state() 626 | 627 | return output 628 | else: 629 | # Handle case where dec_output wasn't created 630 | self.reset_state() 631 | return None 632 | 633 | def reset_state(self): 634 | """Reset internal model state and clear CUDA cache to prevent memory leaks.""" 635 | # 1. Clear any cached states in the model 636 | if hasattr(self, "dac_model") and hasattr(self.dac_model, "reset"): 637 | self.dac_model.reset() 638 | 639 | # 2. Clear any encoder/decoder buffers 640 | if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "buffers"): 641 | for buffer in self.model.encoder.buffers(): 642 | if buffer.device.type == "cuda": 643 | buffer.data = buffer.data.clone() # Force copy to break references 644 | 645 | if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "buffers"): 646 | for buffer in self.model.decoder.buffers(): 647 | if buffer.device.type == "cuda": 648 | buffer.data = buffer.data.clone() # Force copy to break references 649 | 650 | # 3. Force garbage collection first 651 | import gc 652 | 653 | gc.collect() 654 | 655 | # 4. Then clear CUDA cache 656 | if torch.cuda.is_available(): 657 | torch.cuda.empty_cache() 658 | 659 | return True 660 | -------------------------------------------------------------------------------- /dia/state.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | 5 | from .config import DiaConfig 6 | 7 | 8 | def create_attn_mask( 9 | q_padding_mask_1d: torch.Tensor, 10 | k_padding_mask_1d: torch.Tensor, 11 | device: torch.device, 12 | is_causal: bool = False, 13 | ) -> torch.Tensor: 14 | """ 15 | Creates the attention mask (self or cross) mimicking JAX segment ID logic. 
16 | """ 17 | B1, Tq = q_padding_mask_1d.shape 18 | B2, Tk = k_padding_mask_1d.shape 19 | assert B1 == B2, "Query and key batch dimensions must match" 20 | 21 | p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1] 22 | p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk] 23 | 24 | # Condition A: Non-padding query attends to non-padding key 25 | non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk] 26 | 27 | # Condition B: Padding query attends to padding key 28 | pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk] 29 | 30 | # Combine: True if padding status is compatible (both non-pad OR both pad) 31 | mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk] 32 | 33 | if is_causal: 34 | assert ( 35 | Tq == Tk 36 | ), "Causal mask requires query and key sequence lengths to be equal" 37 | causal_mask_2d = torch.tril( 38 | torch.ones((Tq, Tk), dtype=torch.bool, device=device) 39 | ) # Shape [Tq, Tk] 40 | causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk] 41 | return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] 42 | else: 43 | return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] 44 | 45 | 46 | @dataclass 47 | class EncoderInferenceState: 48 | """Parameters specifically for encoder inference.""" 49 | 50 | max_seq_len: int 51 | device: torch.device 52 | positions: torch.Tensor 53 | padding_mask: torch.Tensor 54 | attn_mask: torch.Tensor 55 | 56 | @classmethod 57 | def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState": 58 | """Creates EtorchrInferenceParams from DiaConfig and a device.""" 59 | device = cond_src.device 60 | 61 | positions = ( 62 | torch.arange(config.data.text_length, dtype=torch.float32, device=device) 63 | .unsqueeze(0) 64 | .expand(2, -1) 65 | ) 66 | padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1) 67 | attn_mask = create_attn_mask( 68 | padding_mask, padding_mask, device, is_causal=False 69 | ) 70 | 71 | return cls( 72 | max_seq_len=config.data.text_length, 73 | device=device, 74 | positions=positions, 75 | padding_mask=padding_mask, 76 | attn_mask=attn_mask, 77 | ) 78 | 79 | 80 | class KVCache: 81 | def __init__( 82 | self, 83 | num_heads: int, 84 | max_len: int, 85 | head_dim: int, 86 | dtype: torch.dtype, 87 | device: torch.device, 88 | k: torch.Tensor | None = None, 89 | v: torch.Tensor | None = None, 90 | ): 91 | self.k = ( 92 | torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device) 93 | if k is None 94 | else k 95 | ) 96 | self.v = ( 97 | torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device) 98 | if v is None 99 | else v 100 | ) 101 | self.current_idx = torch.tensor(0) 102 | 103 | @classmethod 104 | def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache": 105 | return cls( 106 | num_heads=k.shape[1], 107 | max_len=k.shape[2], 108 | head_dim=k.shape[3], 109 | dtype=k.dtype, 110 | device=k.device, 111 | k=k, 112 | v=v, 113 | ) 114 | 115 | def update( 116 | self, k: torch.Tensor, v: torch.Tensor 117 | ) -> tuple[torch.Tensor, torch.Tensor]: 118 | self.k[:, :, self.current_idx : self.current_idx + 1, :] = k 119 | self.v[:, :, self.current_idx : self.current_idx + 1, :] = v 120 | self.current_idx += 1 121 | return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :] 122 | 123 | def prefill( 124 | self, k: torch.Tensor, v: torch.Tensor 125 | ) -> tuple[torch.Tensor, torch.Tensor]: 126 | prefill_len = k.shape[2] 127 | self.k[:, :, :prefill_len, :] = k 128 | self.v[:, :, :prefill_len, :] = v 129 | 
self.current_idx = prefill_len - 1 130 | 131 | 132 | @dataclass 133 | class DecoderInferenceState: 134 | """Parameters specifically for decoder inference.""" 135 | 136 | device: torch.device 137 | dtype: torch.dtype 138 | enc_out: torch.Tensor 139 | enc_positions: torch.Tensor 140 | dec_positions: torch.Tensor 141 | dec_cross_attn_mask: torch.Tensor 142 | self_attn_cache: list[KVCache] 143 | cross_attn_cache: list[KVCache] 144 | 145 | @classmethod 146 | def new( 147 | cls, 148 | config: DiaConfig, 149 | enc_state: EncoderInferenceState, 150 | enc_out: torch.Tensor, 151 | dec_cross_attn_cache: list[KVCache], 152 | compute_dtype: torch.dtype, 153 | ) -> "DecoderInferenceState": 154 | """Creates DecoderInferenceParams from DiaConfig and a device.""" 155 | device = enc_out.device 156 | max_audio_len = config.data.audio_length 157 | 158 | dec_positions = torch.full( 159 | (2, 1), fill_value=0, dtype=torch.long, device=device 160 | ) 161 | tgt_padding_mask = torch.ones((2, 1), dtype=torch.bool, device=device) 162 | dec_cross_attn_mask = create_attn_mask( 163 | tgt_padding_mask, enc_state.padding_mask, device, is_causal=False 164 | ) 165 | 166 | self_attn_cache = [ 167 | KVCache( 168 | config.model.decoder.kv_heads, 169 | max_audio_len, 170 | config.model.decoder.gqa_head_dim, 171 | compute_dtype, 172 | device, 173 | ) 174 | for _ in range(config.model.decoder.n_layer) 175 | ] 176 | 177 | return cls( 178 | device=device, 179 | dtype=compute_dtype, 180 | enc_out=enc_out, 181 | enc_positions=enc_state.positions, 182 | dec_positions=dec_positions, 183 | dec_cross_attn_mask=dec_cross_attn_mask, 184 | self_attn_cache=self_attn_cache, 185 | cross_attn_cache=dec_cross_attn_cache, 186 | ) 187 | 188 | def prepare_step(self, step_from: int, step_to: int | None = None) -> None: 189 | if step_to is None: 190 | step_to = step_from + 1 191 | self.dec_positions = ( 192 | torch.arange(step_from, step_to, dtype=torch.float32, device=self.device) 193 | .unsqueeze(0) 194 | .expand(2, -1) 195 | ) 196 | 197 | 198 | @dataclass 199 | class DecoderOutput: 200 | generated_tokens: torch.Tensor 201 | prefill_step: int 202 | 203 | @classmethod 204 | def new(cls, config: DiaConfig, device: torch.device) -> "DecoderOutput": 205 | max_audio_len = config.data.audio_length 206 | return cls( 207 | generated_tokens=torch.full( 208 | (max_audio_len, config.data.channels), 209 | fill_value=-1, 210 | dtype=torch.int, 211 | device=device, 212 | ), 213 | prefill_step=0, 214 | ) 215 | 216 | def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor: 217 | if step_to is None: 218 | step_to = step_from + 1 219 | return self.generated_tokens[step_from:step_to, :] 220 | 221 | def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False): 222 | if apply_mask: 223 | mask = self.generated_tokens[step : step + 1, :] == -1 224 | self.generated_tokens[step : step + 1, :] = torch.where( 225 | mask, dec_out, self.generated_tokens[step : step + 1, :] 226 | ) 227 | else: 228 | self.generated_tokens[step : step + 1, :] = dec_out 229 | 230 | def prefill(self, dec_out: torch.Tensor, prefill_step: int): 231 | length = dec_out.shape[0] 232 | self.generated_tokens[0:length, :] = dec_out 233 | self.prefill_step = prefill_step 234 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | dia-tts-server: 5 | build: 6 | context: . 
7 | dockerfile: Dockerfile 8 | ports: 9 | - "${PORT:-8003}:${PORT:-8003}" 10 | volumes: 11 | # Mount local directories into the container for persistent data 12 | - ./model_cache:/app/model_cache 13 | - ./reference_audio:/app/reference_audio 14 | - ./outputs:/app/outputs 15 | - ./voices:/app/voices 16 | # DO NOT mount config.yaml - let the app create it inside 17 | 18 | # --- GPU Access --- 19 | # Modern method (Recommended for newer Docker/NVIDIA setups) 20 | devices: 21 | - nvidia.com/gpu=all 22 | device_cgroup_rules: 23 | - "c 195:* rmw" # Needed for some NVIDIA container toolkit versions 24 | - "c 236:* rmw" # Needed for some NVIDIA container toolkit versions 25 | 26 | # Legacy method (Alternative for older Docker/NVIDIA setups) 27 | # If the 'devices' block above doesn't work, comment it out and uncomment 28 | # the 'deploy' block below. Do not use both simultaneously. 29 | # deploy: 30 | # resources: 31 | # reservations: 32 | # devices: 33 | # - driver: nvidia 34 | # count: 1 # Or specify specific GPUs e.g., "device=0,1" 35 | # capabilities: [gpu] 36 | # --- End GPU Access --- 37 | 38 | restart: unless-stopped 39 | env_file: 40 | # Load environment variables from .env file for initial config seeding 41 | - .env 42 | environment: 43 | # Enable faster Hugging Face downloads inside the container 44 | - HF_HUB_ENABLE_HF_TRANSFER=1 45 | # Pass GPU capabilities (may be needed for legacy method if uncommented) 46 | - NVIDIA_VISIBLE_DEVICES=all 47 | - NVIDIA_DRIVER_CAPABILITIES=compute,utility 48 | -------------------------------------------------------------------------------- /download_model.py: -------------------------------------------------------------------------------- 1 | # download_model.py 2 | # Utility script to download the Dia model and dependencies without starting the server. 3 | 4 | import logging 5 | import os 6 | import engine # Import the engine module to trigger its loading logic 7 | 8 | # Import Whisper for model download check 9 | try: 10 | import whisper 11 | 12 | WHISPER_AVAILABLE = True 13 | except ImportError: 14 | WHISPER_AVAILABLE = False 15 | logging.warning("Whisper library not found. Cannot download Whisper model.") 16 | 17 | # Configure basic logging for the script 18 | logging.basicConfig( 19 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s" 20 | ) 21 | logger = logging.getLogger("ModelDownloader") 22 | 23 | if __name__ == "__main__": 24 | logger.info("--- Starting Dia & Whisper Model Download ---") 25 | 26 | # Ensure cache directory exists (redundant if engine.load_model does it, but safe) 27 | try: 28 | from config import get_model_cache_path, get_whisper_model_name 29 | 30 | cache_path = get_model_cache_path() 31 | os.makedirs(cache_path, exist_ok=True) 32 | logger.info( 33 | f"Ensured model cache directory exists: {os.path.abspath(cache_path)}" 34 | ) 35 | except Exception as e: 36 | logger.warning(f"Could not ensure cache directory exists: {e}") 37 | cache_path = None # Ensure cache_path is defined or None 38 | 39 | # Trigger the Dia model loading function from the engine 40 | logger.info("Calling engine.load_model() to initiate Dia download if necessary...") 41 | dia_success = engine.load_model() 42 | 43 | if dia_success: 44 | logger.info("--- Dia model download/load process completed successfully ---") 45 | else: 46 | logger.error( 47 | "--- Dia model download/load process failed. Check logs for details. 
---" 48 | ) 49 | # Optionally exit if Dia fails, or continue to try Whisper 50 | # exit(1) 51 | 52 | # --- Download Whisper Model --- 53 | whisper_success = False 54 | if WHISPER_AVAILABLE and cache_path: 55 | whisper_model_name = get_whisper_model_name() 56 | logger.info(f"Attempting to download Whisper model '{whisper_model_name}'...") 57 | try: 58 | # Use download_root to specify our cache directory 59 | whisper.load_model(whisper_model_name, download_root=cache_path) 60 | logger.info( 61 | f"Whisper model '{whisper_model_name}' downloaded/found successfully in {cache_path}." 62 | ) 63 | whisper_success = True 64 | except Exception as e: 65 | logger.error( 66 | f"Failed to download/load Whisper model '{whisper_model_name}': {e}", 67 | exc_info=True, 68 | ) 69 | elif not WHISPER_AVAILABLE: 70 | logger.warning( 71 | "Skipping Whisper model download: Whisper library not installed." 72 | ) 73 | elif not cache_path: 74 | logger.warning( 75 | "Skipping Whisper model download: Cache path could not be determined." 76 | ) 77 | 78 | # --- Final Status --- 79 | if dia_success and whisper_success: 80 | logger.info("--- All required models downloaded/verified successfully ---") 81 | elif dia_success: 82 | logger.warning( 83 | "--- Dia model OK, but Whisper model download failed or was skipped ---" 84 | ) 85 | elif whisper_success: 86 | logger.warning("--- Whisper model OK, but Dia model download failed ---") 87 | else: 88 | logger.error( 89 | "--- Both Dia and Whisper model downloads failed or were skipped ---" 90 | ) 91 | exit(1) # Exit with error code if essential models failed 92 | 93 | logger.info("You can now start the server using 'python server.py'") 94 | -------------------------------------------------------------------------------- /env.example.txt: -------------------------------------------------------------------------------- 1 | # .env - Initial Configuration Seed for Dia TTS Server 2 | # IMPORTANT: This file is ONLY used on the very first server start 3 | # IF config.yaml does NOT exist. 4 | # After config.yaml is created, settings are managed there. 5 | # Changes here will NOT affect a running server with an existing config.yaml. 6 | # Use the Web UI or directly edit config.yaml for subsequent changes. 7 | 8 | # --- Server Settings --- 9 | HOST='0.0.0.0' 10 | PORT='8003' 11 | 12 | # --- Path Settings --- 13 | # These paths will be written into the initial config.yaml 14 | DIA_MODEL_CACHE_PATH='./model_cache' 15 | REFERENCE_AUDIO_PATH='./reference_audio' 16 | OUTPUT_PATH='./outputs' 17 | # Path for predefined voices is now set in config.py defaults 18 | # PREDEFINED_VOICES_PATH='./voices' 19 | 20 | # --- Model Source Settings --- 21 | # Defaulting to BF16 safetensors. Uncomment/modify to seed config.yaml differently. 22 | DIA_MODEL_REPO_ID='ttj/dia-1.6b-safetensors' 23 | DIA_MODEL_CONFIG_FILENAME='config.json' 24 | DIA_MODEL_WEIGHTS_FILENAME='dia-v0_1_bf16.safetensors' 25 | 26 | # Example: Seed with full precision safetensors 27 | # DIA_MODEL_REPO_ID=ttj/dia-1.6b-safetensors 28 | # DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.safetensors 29 | 30 | # Example: Seed with original Nari Labs .pth model 31 | # DIA_MODEL_REPO_ID=nari-labs/Dia-1.6B 32 | # DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.pth 33 | 34 | # --- Whisper Transcription Settings --- 35 | # Model name for automatic transcription if reference .txt is missing. 36 | WHISPER_MODEL_NAME='small.en' 37 | 38 | # --- Default Generation Parameters --- 39 | # These set the initial values in the 'generation_defaults' section of config.yaml. 
40 | # They can be changed later via the UI ('Save Generation Defaults' button) or by editing config.yaml. 41 | GEN_DEFAULT_SPEED_FACTOR='1.0' 42 | GEN_DEFAULT_CFG_SCALE='3.0' 43 | GEN_DEFAULT_TEMPERATURE='1.3' 44 | GEN_DEFAULT_TOP_P='0.95' 45 | GEN_DEFAULT_CFG_FILTER_TOP_K='35' 46 | # -1 for random seed 47 | GEN_DEFAULT_SEED='42' 48 | # Now managed in config.py defaults 49 | # GEN_DEFAULT_SPLIT_TEXT='True' 50 | # GEN_DEFAULT_CHUNK_SIZE='120' 51 | 52 | # --- UI State --- 53 | # UI state (last text, selections, etc.) is NOT seeded from .env. 54 | # It starts with defaults defined in config.py and is saved in config.yaml. 55 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # models.py 2 | # Pydantic models for API requests and potentially responses 3 | 4 | from pydantic import BaseModel, Field 5 | from typing import Optional, Literal 6 | 7 | # --- Request Models --- 8 | 9 | 10 | class OpenAITTSRequest(BaseModel): 11 | """Request model compatible with the OpenAI TTS API.""" 12 | 13 | model: str = Field( 14 | default="dia-1.6b", 15 | description="Model identifier (ignored by this server, always uses Dia). Included for compatibility.", 16 | ) 17 | input: str = Field(..., description="The text to synthesize.") 18 | voice: str = Field( 19 | default="S1", 20 | description="Voice mode or reference audio filename. Examples: 'S1', 'S2', 'dialogue', 'my_reference.wav'.", 21 | ) 22 | response_format: Literal["opus", "wav"] = Field( 23 | default="opus", description="The desired audio output format." 24 | ) 25 | speed: float = Field( 26 | default=1.0, 27 | ge=0.5, # Allow wider range for speed factor post-processing 28 | le=2.0, 29 | description="Adjusts the speed of the generated audio (0.5 to 2.0).", 30 | ) 31 | # Add seed parameter, defaulting to random (-1) 32 | seed: Optional[int] = Field( 33 | default=-1, 34 | description="Generation seed. Use -1 for random (default), or a specific integer for deterministic output.", 35 | ) 36 | 37 | 38 | class CustomTTSRequest(BaseModel): 39 | """Request model for the custom /tts endpoint.""" 40 | 41 | text: str = Field( 42 | ..., 43 | description="The text to synthesize. For 'dialogue' mode, include [S1]/[S2] tags.", 44 | ) 45 | voice_mode: Literal["dialogue", "single_s1", "single_s2", "clone"] = Field( 46 | default="single_s1", description="Specifies the generation mode." 47 | ) 48 | clone_reference_filename: Optional[str] = Field( 49 | default=None, 50 | description="Filename of the reference audio within the configured reference path (required if voice_mode is 'clone').", 51 | ) 52 | # New: Optional transcript for cloning 53 | transcript: Optional[str] = Field( 54 | default=None, 55 | description="Optional transcript of the reference audio for cloning. If provided, overrides local .txt file lookup and Whisper generation.", 56 | ) 57 | output_format: Literal["opus", "wav"] = Field( 58 | default="opus", description="The desired audio output format." 
59 | ) 60 | # Dia-specific generation parameters 61 | max_tokens: Optional[int] = Field( 62 | default=None, 63 | gt=0, 64 | description="Maximum number of audio tokens to generate per chunk (defaults to model's internal config value).", 65 | ) 66 | cfg_scale: float = Field( 67 | default=3.0, 68 | ge=1.0, 69 | le=5.0, 70 | description="Classifier-Free Guidance scale (1.0-5.0).", 71 | ) 72 | temperature: float = Field( 73 | default=1.3, 74 | ge=0.1, 75 | le=1.5, 76 | description="Sampling temperature (0.1-1.5).", # Allow lower temp for greedy-like 77 | ) 78 | top_p: float = Field( 79 | default=0.95, 80 | ge=0.1, # Allow lower top_p 81 | le=1.0, 82 | description="Nucleus sampling probability (0.1-1.0).", 83 | ) 84 | speed_factor: float = Field( 85 | default=0.94, 86 | ge=0.5, # Allow wider range for speed factor post-processing 87 | le=2.0, 88 | description="Adjusts the speed of the generated audio (0.5 to 2.0).", 89 | ) 90 | cfg_filter_top_k: int = Field( 91 | default=35, 92 | ge=1, 93 | le=100, 94 | description="Top k filter for CFG guidance (1-100).", # Allow wider range 95 | ) 96 | # Add seed parameter, defaulting to random (-1) 97 | seed: Optional[int] = Field( 98 | default=-1, 99 | description="Generation seed. Use -1 for random (default), or a specific integer for deterministic output.", 100 | ) 101 | # Add text splitting parameters 102 | split_text: Optional[bool] = Field( 103 | default=True, # Default to splitting enabled 104 | description="Whether to automatically split long text into chunks for processing.", 105 | ) 106 | chunk_size: Optional[int] = Field( 107 | default=300, # Default target chunk size 108 | ge=100, # Minimum reasonable chunk size 109 | le=1000, # Maximum reasonable chunk size 110 | description="Approximate target character length for text chunks when splitting is enabled (100-1000).", 111 | ) 112 | 113 | 114 | # --- Response Models (Optional, can be simple dicts too) --- 115 | 116 | 117 | class TTSResponse(BaseModel): 118 | """Basic response model for successful generation (if returning JSON).""" 119 | 120 | request_id: str 121 | status: str = "completed" 122 | generation_time_sec: float 123 | output_url: Optional[str] = None # If saving file and returning URL 124 | 125 | 126 | class ErrorResponse(BaseModel): 127 | """Error response model.""" 128 | 129 | detail: str 130 | -------------------------------------------------------------------------------- /reference_audio/Gianna.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /reference_audio/Gianna.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/reference_audio/Gianna.wav -------------------------------------------------------------------------------- /reference_audio/Oliver_Luna.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. 
[S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /reference_audio/Oliver_Luna.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/reference_audio/Oliver_Luna.wav -------------------------------------------------------------------------------- /reference_audio/Robert.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /reference_audio/Robert.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/reference_audio/Robert.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # requirements.txt 2 | 3 | # Core Web Framework 4 | fastapi 5 | uvicorn[standard] 6 | 7 | # Machine Learning & Audio 8 | torch 9 | torchaudio 10 | numpy 11 | soundfile # Requires libsndfile system library (e.g., sudo apt-get install libsndfile1 on Debian/Ubuntu) 12 | huggingface_hub 13 | descript-audio-codec 14 | safetensors 15 | openai-whisper 16 | 17 | # Configuration & Utilities 18 | pydantic 19 | python-dotenv # Used ONLY for initial config seeding if config.yaml missing 20 | Jinja2 21 | python-multipart # For file uploads 22 | requests # For health checks or other potential uses 23 | PyYAML # For parsing presets.yaml AND primary config.yaml 24 | tqdm 25 | 26 | # Audio Post-processing 27 | pydub 28 | praat-parselmouth # For unvoiced segment removal 29 | librosa # for changes to sampling 30 | hf-transfer 31 | -------------------------------------------------------------------------------- /static/screenshot-d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/static/screenshot-d.png -------------------------------------------------------------------------------- /static/screenshot-l.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/static/screenshot-l.png -------------------------------------------------------------------------------- /ui/presets.yaml: -------------------------------------------------------------------------------- 1 | # ui/presets.yaml 2 | # Predefined examples for the Dia TTS UI 3 | 4 | - name: "Short Dialogue" # Renamed 5 | # voice_mode ignored by JS loader 6 | text: | 7 | [S1] Hey, how's it going? 8 | [S2] Pretty good! Just grabbing some coffee. You? 9 | [S1] Same here. Need the fuel! 
(laughs) 10 | params: 11 | cfg_scale: 3.0 12 | temperature: 1.3 13 | top_p: 0.95 14 | cfg_filter_top_k: 35 15 | # speed_factor uses the saved default 16 | # seed uses the saved default 17 | 18 | - name: "Long Dialogue" # Added 19 | text: | 20 | [S1] Hey, is that my coffee you're drinking? 21 | [S2] Pretty sure it's mine. Wait, let me check. (squints) It says "Bark"? That can't be right. 22 | [S1] That's Mark with an M! (sighs) My handwriting isn't that bad. 23 | [S2] Oh! (laughs) Well, in my defense, this is my third coffee today. 24 | [S1] And yet you thought you needed another one? 25 | [S2] Sleep is for the weak! (yawns) Though apparently so am I. 26 | [S1] Here, just give me my coffee before you (gasps) 27 | [S2] What? What's wrong? 28 | [S1] That was a triple espresso with extra shots! How are you still standing? 29 | [S2] (laughs) I can see through time now. Is it normal to taste colors? 30 | [S1] I'll order you some water. 31 | [S2] No need! I've never felt better! (laughs) 32 | [S1] Your left eye is twitching. 33 | [S2] Is it? I can't feel my face. (laughs) 34 | [S1] Maybe we should sit down for a minute. 35 | [S2] Sitting is for people who can't vibrate in place! (sighs) Fine, maybe for a minute. 36 | [S1] I've never seen someone drink the wrong coffee so enthusiastically. 37 | [S2] I've never met a coffee I didn't like. (laughs) Though this one might be trying to kill me. 38 | [S1] Next time, maybe read the name before you drink? 39 | [S2] Reading requires focus. Focus requires sleep. It's a whole problem. (yawns) 40 | [S1] How about I buy you a decaf? 41 | [S2] Decaf? (gasps) What did I ever do to you? 42 | [S1] Just trying to keep you alive until tomorrow. (laughs) 43 | [S2] Tomorrow is overrated. Today has coffee! 44 | params: # Using standard defaults 45 | cfg_scale: 3.0 46 | temperature: 1.3 47 | top_p: 0.95 48 | cfg_filter_top_k: 35 49 | 50 | - name: "Expressive Narration" 51 | text: | 52 | [S1] The old house stood on a windswept hill, its windows like empty eyes staring out at the stormy sea. (sighs) It felt... lonely. 53 | params: 54 | cfg_scale: 3.0 55 | temperature: 1.2 # Slightly lower temp for clarity 56 | top_p: 0.95 57 | cfg_filter_top_k: 35 58 | 59 | - name: "Quick Announcement" 60 | text: | 61 | [S1] Attention shoppers! The store will be closing in 15 minutes. Please bring your final purchases to the checkout. 62 | params: 63 | cfg_scale: 2.8 # Slightly lower CFG for potentially more natural tone 64 | temperature: 1.3 65 | top_p: 0.95 66 | cfg_filter_top_k: 35 67 | 68 | - name: "Funny Exchange" 69 | text: | 70 | [S1] Did you remember to buy the alien repellent? 71 | [S2] The what now? (laughs) I thought you were joking! 72 | [S1] Joking? They're landing tonight! (clears throat) Probably. 73 | params: 74 | cfg_scale: 3.2 # Slightly higher CFG 75 | temperature: 1.35 # Slightly higher temp 76 | top_p: 0.95 77 | cfg_filter_top_k: 35 78 | 79 | - name: "Simple Sentence" 80 | text: | 81 | [S1] This is a test of the text to speech system. 82 | params: 83 | cfg_scale: 3.0 84 | temperature: 1.3 85 | top_p: 0.95 86 | cfg_filter_top_k: 35 87 | -------------------------------------------------------------------------------- /voices/Abigail.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. 
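The reference transcripts in voices/ and reference_audio/ pair each .wav with the text spoken in it, using the same [S1]/[S2] speaker tags as the presets above. A minimal client-side sketch of a clone-mode request against the custom /tts endpoint described by CustomTTSRequest in models.py follows; the localhost:8003 address comes from the default PORT, while the exact route and the raw-audio-bytes response are assumptions that have not been verified against server.py.

# Minimal sketch of a clone-mode request to the custom /tts endpoint.
# Assumptions: server reachable at http://localhost:8003 (default PORT), the
# route is "/tts" as suggested by the CustomTTSRequest docstring, and the
# response body is the raw audio. Field names and ranges follow models.py.
import requests

payload = {
    "text": "[S1] Hello from a cloned voice. [S2] And a reply from the second speaker.",
    "voice_mode": "clone",
    "clone_reference_filename": "Gianna.wav",  # a .wav in the configured reference path
    "output_format": "wav",
    "cfg_scale": 3.0,
    "temperature": 1.3,
    "top_p": 0.95,
    "cfg_filter_top_k": 35,
    "speed_factor": 0.94,
    "seed": -1,          # -1 selects a random seed
    "split_text": True,
    "chunk_size": 300,
}

resp = requests.post("http://localhost:8003/tts", json=payload, timeout=600)
resp.raise_for_status()
with open("cloned_output.wav", "wb") as f:
    f.write(resp.content)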
-------------------------------------------------------------------------------- /voices/Abigail.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Abigail.wav -------------------------------------------------------------------------------- /voices/Abigail_Taylor.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Abigail_Taylor.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Abigail_Taylor.wav -------------------------------------------------------------------------------- /voices/Adrian.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Adrian.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Adrian.wav -------------------------------------------------------------------------------- /voices/Adrian_Jade.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Adrian_Jade.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Adrian_Jade.wav -------------------------------------------------------------------------------- /voices/Alexander.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. 
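For use outside the HTTP server, the generation class in dia/model.py can also be driven directly. The sketch below assumes the class is exported as Dia and that from_pretrained() accepts the Hugging Face repo id listed in env.example.txt; the keyword arguments follow the generate() signature shown earlier in this dump.

# Direct-usage sketch of the generation class in dia/model.py (the class name and
# the from_pretrained() argument are assumptions; generate() keywords are as shown above).
from dia.model import Dia

text = "[S1] This is a quick smoke test. [S2] Sounds good to me."

model = Dia.from_pretrained("nari-labs/Dia-1.6B")
audio = model.generate(
    text,
    cfg_scale=3.0,
    temperature=1.3,
    top_p=0.95,
    cfg_filter_top_k=35,
    seed=42,                          # seed >= 0 makes sampling deterministic
    text_to_generate_size=len(text),  # used to estimate the audio-token budget
    verbose=True,
)
if audio is not None:                 # generate() returns None if nothing was produced
    model.save_audio("smoke_test.wav", audio)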
-------------------------------------------------------------------------------- /voices/Alexander.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Alexander.wav -------------------------------------------------------------------------------- /voices/Alexander_Emily.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Alexander_Emily.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Alexander_Emily.wav -------------------------------------------------------------------------------- /voices/Alice.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Alice.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Alice.wav -------------------------------------------------------------------------------- /voices/Austin.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Austin.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Austin.wav -------------------------------------------------------------------------------- /voices/Austin_Jeremiah.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. 
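The OpenAI-compatible request shape is also defined in models.py as OpenAITTSRequest. The sketch below only builds and prints that request body; the route that accepts it lives in server.py, which is outside this excerpt, so no URL is assumed.

# Builds the OpenAI-compatible request body from the Pydantic model in models.py.
from models import OpenAITTSRequest

req = OpenAITTSRequest(
    input="[S1] Attention shoppers, the store closes in fifteen minutes.",
    voice="Gianna.wav",       # 'S1', 'S2', 'dialogue', or a reference audio filename
    response_format="wav",
    speed=1.0,
    seed=-1,                  # -1 selects a random seed
)
print(req.model_dump_json(indent=2))  # Pydantic v2; on Pydantic v1 use req.json(indent=2)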
-------------------------------------------------------------------------------- /voices/Austin_Jeremiah.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Austin_Jeremiah.wav -------------------------------------------------------------------------------- /voices/Axel.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Axel.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Axel.wav -------------------------------------------------------------------------------- /voices/Axel_Miles.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Axel_Miles.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Axel_Miles.wav -------------------------------------------------------------------------------- /voices/Connor.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Connor.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Connor.wav -------------------------------------------------------------------------------- /voices/Connor_Ryan.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. 
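Every entry under voices/ follows the same pairing convention: a <name>.txt transcript sits next to <name>.wav, and models.py notes that a transcript supplied in the request overrides both the local .txt lookup and the Whisper fallback configured via WHISPER_MODEL_NAME. The helper below merely illustrates that lookup idea; it is not the server's actual code.

# Illustration of the <name>.wav / <name>.txt pairing convention (not server code).
from pathlib import Path

def lookup_transcript(reference_wav: str) -> str | None:
    """Return the sibling .txt transcript for a reference .wav, if one exists."""
    txt = Path(reference_wav).with_suffix(".txt")
    return txt.read_text(encoding="utf-8").strip() if txt.exists() else None

print(lookup_transcript("voices/Connor_Ryan.wav"))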
-------------------------------------------------------------------------------- /voices/Connor_Ryan.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Connor_Ryan.wav -------------------------------------------------------------------------------- /voices/Cora.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Cora.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Cora.wav -------------------------------------------------------------------------------- /voices/Cora_Gianna.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Cora_Gianna.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Cora_Gianna.wav -------------------------------------------------------------------------------- /voices/Elena.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Elena.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Elena.wav -------------------------------------------------------------------------------- /voices/Elena_Emily.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. 
-------------------------------------------------------------------------------- /voices/Elena_Emily.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Elena_Emily.wav -------------------------------------------------------------------------------- /voices/Eli.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Eli.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Eli.wav -------------------------------------------------------------------------------- /voices/Emily.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Emily.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Emily.wav -------------------------------------------------------------------------------- /voices/Everett.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. -------------------------------------------------------------------------------- /voices/Everett.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Everett.wav -------------------------------------------------------------------------------- /voices/Everett_Jordan.txt: -------------------------------------------------------------------------------- 1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. 
--------------------------------------------------------------------------------
/voices/Everett_Jordan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Everett_Jordan.wav
--------------------------------------------------------------------------------
/voices/Gabriel.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Gabriel.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Gabriel.wav
--------------------------------------------------------------------------------
/voices/Gabriel_Ian.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Gabriel_Ian.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Gabriel_Ian.wav
--------------------------------------------------------------------------------
/voices/Gianna.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Gianna.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Gianna.wav
--------------------------------------------------------------------------------
/voices/Henry.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Henry.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Henry.wav
--------------------------------------------------------------------------------
/voices/Ian.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Ian.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Ian.wav
--------------------------------------------------------------------------------
/voices/Jade.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Jade.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Jade.wav
--------------------------------------------------------------------------------
/voices/Jade_Layla.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Jade_Layla.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Jade_Layla.wav
--------------------------------------------------------------------------------
/voices/Jeremiah.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Jeremiah.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Jeremiah.wav
--------------------------------------------------------------------------------
/voices/Jordan.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Jordan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Jordan.wav
--------------------------------------------------------------------------------
/voices/Julian.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Julian.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Julian.wav
--------------------------------------------------------------------------------
/voices/Julian_Thomas.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Julian_Thomas.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Julian_Thomas.wav
--------------------------------------------------------------------------------
/voices/Layla.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Layla.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Layla.wav
--------------------------------------------------------------------------------
/voices/Leonardo.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Leonardo.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Leonardo.wav
--------------------------------------------------------------------------------
/voices/Leonardo_Olivia.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Leonardo_Olivia.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Leonardo_Olivia.wav
--------------------------------------------------------------------------------
/voices/Michael.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Michael.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Michael.wav
--------------------------------------------------------------------------------
/voices/Michael_Emily.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Michael_Emily.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Michael_Emily.wav
--------------------------------------------------------------------------------
/voices/Miles.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Miles.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Miles.wav
--------------------------------------------------------------------------------
/voices/Oliver_Luna.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Oliver_Luna.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Oliver_Luna.wav
--------------------------------------------------------------------------------
/voices/Olivia.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Olivia.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Olivia.wav
--------------------------------------------------------------------------------
/voices/Ryan.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Ryan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Ryan.wav
--------------------------------------------------------------------------------
/voices/Taylor.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Taylor.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Taylor.wav
--------------------------------------------------------------------------------
/voices/Thomas.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Thomas.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Thomas.wav
--------------------------------------------------------------------------------