├── .github
│   └── workflows
│       └── docker-build-push.yml
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── config.py
├── dia
│   ├── __init__.py
│   ├── audio.py
│   ├── config.py
│   ├── layers.py
│   ├── model.py
│   └── state.py
├── docker-compose.yml
├── documentation.md
├── download_model.py
├── engine.py
├── env.example.txt
├── models.py
├── reference_audio
│   ├── Gianna.txt
│   ├── Gianna.wav
│   ├── Oliver_Luna.txt
│   ├── Oliver_Luna.wav
│   ├── Robert.txt
│   └── Robert.wav
├── requirements.txt
├── server.py
├── static
│   ├── screenshot-d.png
│   └── screenshot-l.png
├── ui
│   ├── index.html
│   ├── presets.yaml
│   └── script.js
├── utils.py
└── voices
    ├── Abigail.txt
    ├── Abigail.wav
    ├── Abigail_Taylor.txt
    ├── Abigail_Taylor.wav
    ├── Adrian.txt
    ├── Adrian.wav
    ├── Adrian_Jade.txt
    ├── Adrian_Jade.wav
    ├── Alexander.txt
    ├── Alexander.wav
    ├── Alexander_Emily.txt
    ├── Alexander_Emily.wav
    ├── Alice.txt
    ├── Alice.wav
    ├── Austin.txt
    ├── Austin.wav
    ├── Austin_Jeremiah.txt
    ├── Austin_Jeremiah.wav
    ├── Axel.txt
    ├── Axel.wav
    ├── Axel_Miles.txt
    ├── Axel_Miles.wav
    ├── Connor.txt
    ├── Connor.wav
    ├── Connor_Ryan.txt
    ├── Connor_Ryan.wav
    ├── Cora.txt
    ├── Cora.wav
    ├── Cora_Gianna.txt
    ├── Cora_Gianna.wav
    ├── Elena.txt
    ├── Elena.wav
    ├── Elena_Emily.txt
    ├── Elena_Emily.wav
    ├── Eli.txt
    ├── Eli.wav
    ├── Emily.txt
    ├── Emily.wav
    ├── Everett.txt
    ├── Everett.wav
    ├── Everett_Jordan.txt
    ├── Everett_Jordan.wav
    ├── Gabriel.txt
    ├── Gabriel.wav
    ├── Gabriel_Ian.txt
    ├── Gabriel_Ian.wav
    ├── Gianna.txt
    ├── Gianna.wav
    ├── Henry.txt
    ├── Henry.wav
    ├── Ian.txt
    ├── Ian.wav
    ├── Jade.txt
    ├── Jade.wav
    ├── Jade_Layla.txt
    ├── Jade_Layla.wav
    ├── Jeremiah.txt
    ├── Jeremiah.wav
    ├── Jordan.txt
    ├── Jordan.wav
    ├── Julian.txt
    ├── Julian.wav
    ├── Julian_Thomas.txt
    ├── Julian_Thomas.wav
    ├── Layla.txt
    ├── Layla.wav
    ├── Leonardo.txt
    ├── Leonardo.wav
    ├── Leonardo_Olivia.txt
    ├── Leonardo_Olivia.wav
    ├── Michael.txt
    ├── Michael.wav
    ├── Michael_Emily.txt
    ├── Michael_Emily.wav
    ├── Miles.txt
    ├── Miles.wav
    ├── Oliver_Luna.txt
    ├── Oliver_Luna.wav
    ├── Olivia.txt
    ├── Olivia.wav
    ├── Ryan.txt
    ├── Ryan.wav
    ├── Taylor.txt
    ├── Taylor.wav
    ├── Thomas.txt
    └── Thomas.wav
/.github/workflows/docker-build-push.yml:
--------------------------------------------------------------------------------
1 | name: Build and Push Docker Image to GHCR
2 |
3 | on:
4 |   push:
5 |     branches: [main]
6 |   pull_request:
7 |     branches: [main]
8 |   workflow_dispatch:
9 |
10 | jobs:
11 |   build-and-push:
12 |     runs-on: ubuntu-latest
13 |     permissions:
14 |       contents: read
15 |       packages: write
16 |
17 |     steps:
18 |       - name: Checkout repository
19 |         uses: actions/checkout@v4
20 |
21 |       - name: Set up Docker Buildx
22 |         uses: docker/setup-buildx-action@v3
23 |
24 |       - name: Log in to GitHub Container Registry
25 |         uses: docker/login-action@v3
26 |         with:
27 |           registry: ghcr.io
28 |           username: ${{ github.repository_owner }}
29 |           password: ${{ secrets.GITHUB_TOKEN }}
30 |
31 |       - name: Extract metadata for Docker
32 |         id: meta
33 |         uses: docker/metadata-action@v5
34 |         with:
35 |           images: ghcr.io/${{ github.repository_owner }}/dia-tts-server
36 |           tags: |
37 |             type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', 'main') || github.ref == format('refs/heads/{0}', 'master') }}
38 |             type=sha,format=short
39 |       - name: Build and push Docker image
40 |         uses: docker/build-push-action@v5
41 |         with:
42 |           context: .
43 |           push: ${{ github.event_name != 'pull_request' }}
44 |           tags: ${{ steps.meta.outputs.tags }}
45 |           labels: ${{ steps.meta.outputs.labels }}
46 |           cache-from: type=gha
47 |           cache-to: type=gha,mode=max
48 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore Python cache and compiled files
2 | *.pyc
3 | *.pyo
4 | *.pyd
5 | __pycache__/
6 | .ipynb_checkpoints/
7 |
8 | # Ignore distribution build artifacts
9 | .dist/
10 | dist/
11 | build/
12 | *.egg-info/
13 |
14 | # Ignore IDE/Editor specific files
15 | .idea/
16 | .vscode/
17 | *.swp
18 | *.swo
19 |
20 | # Ignore OS generated files
21 | .DS_Store
22 | Thumbs.db
23 |
24 | # Ignore virtual environment directories
25 | .venv/
26 | venv/
27 | */.venv/
28 | */venv/
29 | env/
30 | */env/
31 |
32 | # Ignore sensitive environment file
33 | .env
34 |
35 | # Ignore generated output, user data, and model cache directories
36 | outputs/
37 | reference_audio/
38 | model_cache/
39 |
40 | # Ignore test reports/coverage
41 | .coverage
42 | htmlcov/
43 | .pytest_cache/
44 |
45 | # Ignore log files
46 | *.log
47 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.8.1-runtime-ubuntu22.04
2 |
3 | # Set environment variables
4 | ENV PYTHONDONTWRITEBYTECODE=1
5 | ENV PYTHONUNBUFFERED=1
6 | ENV DEBIAN_FRONTEND=noninteractive
7 |
8 | # Install system dependencies
9 | RUN apt-get update && apt-get install -y --no-install-recommends \
10 | build-essential \
11 | libsndfile1 \
12 | ffmpeg \
13 | python3 \
14 | python3-pip \
15 | python3-dev \
16 | && apt-get clean \
17 | && rm -rf /var/lib/apt/lists/*
18 |
19 | # Set up working directory
20 | WORKDIR /app
21 |
22 | # Install Python dependencies
23 | COPY requirements.txt .
24 | RUN pip3 install --no-cache-dir -r requirements.txt
25 |
26 | # Copy application code
27 | COPY . .
28 |
29 | # Create required directories
30 | RUN mkdir -p model_cache reference_audio outputs voices
31 |
32 | # Expose the port the application will run on (defaults to 8003, as per config)
33 | EXPOSE 8003
34 |
35 | # Command to run the application
36 | CMD ["python3", "server.py"]
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 devnen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dia TTS Server: OpenAI-Compatible API with Web UI, Large Text Handling & Built-in Voices
2 |
3 | **Self-host the powerful [Nari Labs Dia TTS model](https://github.com/nari-labs/dia) with this enhanced FastAPI server! Features an intuitive Web UI, flexible API endpoints (including OpenAI-compatible `/v1/audio/speech`), support for realistic dialogue (`[S1]`/`[S2]`), improved voice cloning, large text processing via intelligent chunking, and consistent, reproducible output using 43 built-in, ready-to-use voices and generation seeds.**
4 |
5 | The latest Dia-TTS-Server version offers improved speed and reduced VRAM usage. It defaults to efficient BF16 SafeTensors weights for a smaller memory footprint and faster inference, while still supporting the original `.pth` weights, and runs accelerated on NVIDIA GPUs (CUDA) with CPU fallback.
6 |
7 | ➡️ **Announcing our new TTS project:** Explore the Chatterbox TTS Server and its features: [https://github.com/devnen/Chatterbox-TTS-Server](https://github.com/devnen/Chatterbox-TTS-Server)
8 |
9 | [MIT License](LICENSE)
10 | [Python 3.10+](https://www.python.org/downloads/)
11 | [FastAPI](https://fastapi.tiangolo.com/)
12 | [SafeTensors](https://github.com/huggingface/safetensors)
13 | [Docker](https://www.docker.com/)
14 |
15 | [CUDA](https://developer.nvidia.com/cuda-zone)
16 | [OpenAI-Compatible API](https://platform.openai.com/docs/api-reference)
17 |
18 |
19 | ![Dia TTS Server Web UI](static/screenshot-d.png)
20 | ![Dia TTS Server Web UI](static/screenshot-l.png)
21 |
22 |
23 | ---
24 |
25 | ## 🗣️ Overview: Enhanced Dia TTS Access
26 |
27 | The original [Dia 1.6B TTS model by Nari Labs](https://github.com/nari-labs/dia) provides incredible capabilities for generating realistic dialogue, complete with speaker turns and non-verbal sounds like `(laughs)` or `(sighs)`. This project builds upon that foundation by providing a robust **[FastAPI](https://fastapi.tiangolo.com/) server** that makes Dia significantly easier to use and integrate.
28 |
29 | We solve the complexity of setting up and running the model by offering:
30 |
31 | * An **OpenAI-compatible API endpoint**, allowing you to use Dia TTS with tools expecting OpenAI's API structure.
32 | * A **modern Web UI** for easy experimentation, preset loading, reference audio management, and generation parameter tuning. The interface design draws inspiration from **[Lex-au's Orpheus-FastAPI project](https://github.com/Lex-au/Orpheus-FastAPI)**, adapting its intuitive layout and user experience for Dia TTS.
33 | * **Large Text Handling:** Intelligently splits long text inputs into manageable chunks based on sentence structure and speaker tags, processes them sequentially, and seamlessly concatenates the audio.
34 | * **Predefined Voices:** Select from 43 curated, ready-to-use synthetic voices for consistent and reliable output without cloning setup.
35 | * **Improved Voice Cloning:** Enhanced pipeline with automatic audio processing and transcript handling (local `.txt` file or experimental Whisper fallback).
36 | * **Consistent Generation:** Achieve consistent voice output across multiple generations or text chunks by using the "Predefined Voices" or "Voice Cloning" modes, optionally combined with a fixed integer **Seed**.
37 | * Support for both original `.pth` weights and modern, secure **[SafeTensors](https://github.com/huggingface/safetensors)**, defaulting to a **BF16 SafeTensors** version which uses roughly half the VRAM and offers improved speed.
38 | * Automatic **GPU (CUDA) acceleration** detection with fallback to CPU.
39 | * Configuration primarily via `config.yaml`, with `.env` used for initial setup/reset.
40 | * **Docker support** for easy containerized deployment with [Docker](https://www.docker.com/).
41 |
42 | This server is your gateway to leveraging Dia's advanced TTS capabilities seamlessly, now with enhanced stability, voice consistency, and large text support.
43 |
44 | ## ✨ What's New (v1.4.0 vs v1.0.0)
45 |
46 | This version introduces significant improvements and new features:
47 |
48 | **🚀 New Features:**
49 |
50 | * **Large Text Processing (Chunking):**
51 | * Automatically handles long text inputs by intelligently splitting them into smaller chunks based on sentence boundaries and speaker tags (`[S1]`/`[S2]`). A simplified sketch of this splitting idea appears after this list.
52 | * Processes each chunk individually and seamlessly concatenates the resulting audio, overcoming previous generation limits.
53 | * Configurable via UI toggle ("Split text into chunks") and chunk size slider.
54 | * **Predefined Voices:**
55 | * Added support for using 43 curated, ready-to-use synthetic voices stored in the `./voices` directory.
56 | * Selectable via UI dropdown ("Predefined Voices" mode). Server automatically uses required transcripts.
57 | * Provides reliable voice output without manual cloning setup and avoids potential licensing issues.
58 | * **Enhanced Voice Cloning:**
59 | * Improved backend pipeline for robustness.
60 | * Automatic reference audio processing: mono conversion, resampling to 44.1kHz, truncation (~20s).
61 | * Automatic transcript handling: Prioritizes local `.txt` file (recommended for accuracy) -> **experimental Whisper generation** if `.txt` is missing. Backend handles transcript prepending automatically.
62 | * Robust reference file finding handles case-insensitivity and extensions.
63 | * **Whisper Integration:** Added `openai-whisper` for automatic transcript generation as an experimental fallback during cloning. Configurable model (`WHISPER_MODEL_NAME` in `config.yaml`).
64 | * **API Enhancements:**
65 | * `/tts` endpoint now supports `transcript` (for explicit clone transcript), `split_text`, `chunk_size`, and `seed`.
66 | * `/v1/audio/speech` endpoint now supports `seed`.
67 | * **Generation Seed:** Added `seed` parameter to UI and API for influencing generation results. Using a fixed integer seed *in combination with* Predefined Voices or Voice Cloning helps maintain consistency across chunks or separate generations. Use -1 for random variation.
68 | * **Terminal Progress:** Generation of long text (using chunking) now displays a `tqdm` progress bar in the server's terminal window.
69 | * **UI Configuration Management:** Added UI section to view/edit `config.yaml` settings and save generation defaults.
70 | * **Configuration System:** Migrated to `config.yaml` for primary runtime configuration, managed via `config.py`. `.env` is now used mainly for initial seeding or resetting defaults.
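
To make the chunking behavior above concrete, here is a toy sketch of the general idea (split on speaker tags, then pack whole sentences into chunks of roughly `chunk_size` characters). It is only an illustration under those assumptions, not the server's actual implementation:

```python
import re

def split_into_chunks(text: str, chunk_size: int = 120) -> list[str]:
    """Illustrative splitter: keep [S1]/[S2] tags attached, pack sentences into ~chunk_size chars."""
    segments = re.split(r"(?=\[S[12]\])", text)  # split *before* each speaker tag
    chunks, current = [], ""
    for segment in segments:
        for sentence in re.findall(r"[^.!?]+[.!?]*\s*", segment):
            # Start a new chunk when adding this sentence would exceed the target size.
            if current and len(current) + len(sentence) > chunk_size:
                chunks.append(current.strip())
                current = ""
            current += sentence
    if current.strip():
        chunks.append(current.strip())
    return chunks

print(split_into_chunks("[S1] First sentence. Second one. [S2] A short reply!", chunk_size=40))
# -> ['[S1] First sentence. Second one.', '[S2] A short reply!']
```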
71 |
72 | **🔧 Fixes & Enhancements:**
73 |
74 | * **VRAM Usage Fixed & Optimized:** Resolved memory leaks during inference and significantly reduced VRAM usage (from approx. 14GB+ down to ~7GB) through code optimizations and the BF16 default.
75 | * **Performance:** Significant speed improvements reported (approaching 95% real-time on tested hardware: AMD Ryzen 9 9950X3D + NVIDIA RTX 3090).
76 | * **Audio Post-Processing:** Automatically applies silence trimming (leading/trailing), internal silence reduction, and unvoiced segment removal (using Parselmouth) to improve audio quality and remove artifacts.
77 | * **UI State Persistence:** Web UI now saves/restores text input, voice mode selection, file selections, and generation parameters (seed, chunking, sliders) in `config.yaml`.
78 | * **UI Improvements:** Better loading indicators (shows chunk processing), refined chunking controls, seed input field, theme toggle, dynamic preset loading from `ui/presets.yaml`, warning modals for chunking/generation quality.
79 | * **Cloning Workflow:** Backend now handles transcript prepending automatically. UI workflow simplified (user selects file, enters target text).
80 | * **Dependency Management:** Added `tqdm`, `PyYAML`, `openai-whisper`, `parselmouth` to `requirements.txt`.
81 | * **Code Refactoring:** Aligned internal engine code with refactored `dia` library structure. Updated `config.py` to use `YamlConfigManager`.
82 |
83 | ## ✅ Features
84 |
85 | * **Core Dia Capabilities (via [Nari Labs Dia](https://github.com/nari-labs/dia)):**
86 | * 🗣️ Generate multi-speaker dialogue using `[S1]` / `[S2]` tags.
87 | * 😂 Include non-verbal sounds like `(laughs)`, `(sighs)`, `(clears throat)`.
88 | * 🎭 Perform voice cloning using reference audio prompts.
89 | * **Enhanced Server & API:**
90 | * ⚡ Built with the high-performance **[FastAPI](https://fastapi.tiangolo.com/)** framework.
91 | * 🤖 **OpenAI-Compatible API Endpoint** (`/v1/audio/speech`) for easy integration (now includes `seed`).
92 | * ⚙️ **Custom API Endpoint** (`/tts`) exposing all Dia generation parameters (now includes `seed`, `split_text`, `chunk_size`, `transcript`).
93 | * 📄 Interactive API documentation via Swagger UI (`/docs`).
94 | * 🩺 Health check endpoint (`/health`).
95 | * **Advanced Generation Features:**
96 | * 📚 **Large Text Handling:** Intelligently splits long inputs into chunks based on sentences and speaker tags, generates audio for each, and concatenates the results seamlessly. Configurable via `split_text` and `chunk_size`.
97 | * 🎤 **Predefined Voices:** Select from 43 curated, ready-to-use synthetic voices in the `./voices` directory for consistent output without cloning setup.
98 | * ✨ **Improved Voice Cloning:** Robust pipeline with automatic audio processing and transcript handling (local `.txt` or Whisper fallback). Backend handles transcript prepending.
99 | * 🌱 **Consistent Generation:** Use Predefined Voices or Voice Cloning modes, optionally with a fixed integer **Seed**, for consistent voice output across chunks or multiple requests.
100 | * 🔇 **Audio Post-Processing:** Automatic steps to trim silence, fix internal pauses, and remove long unvoiced segments/artifacts.
101 | * **Intuitive Web User Interface:**
102 | * 🖱️ Modern, easy-to-use interface inspired by **[Lex-au's Orpheus-FastAPI project](https://github.com/Lex-au/Orpheus-FastAPI)**.
103 | * 💡 **Presets:** Load example text and settings dynamically from `ui/presets.yaml`. Customize by editing the file.
104 | * 🎤 **Reference Audio Upload:** Easily upload `.wav`/`.mp3` files for voice cloning.
105 | * 🗣️ **Voice Mode Selection:** Choose between Predefined Voices, Voice Cloning, or Random/Dialogue modes.
106 | * 🎛️ **Parameter Control:** Adjust generation settings (CFG Scale, Temperature, Speed, Seed, etc.) via sliders and inputs.
107 | * 💾 **Configuration Management:** View and save server settings (`config.yaml`) and default generation parameters directly in the UI.
108 | * 💾 **Session Persistence:** Remembers your last used settings via `config.yaml`.
109 | * ✂️ **Chunking Controls:** Enable/disable text splitting and adjust approximate chunk size.
110 | * ⚠️ **Warning Modals:** Optional warnings for chunking voice consistency and general generation quality.
111 | * 🌓 **Light/Dark Mode:** Toggle between themes with preference saved locally.
112 | * 🔊 **Audio Player:** Integrated waveform player ([WaveSurfer.js](https://wavesurfer.xyz/)) for generated audio with download option.
113 | * ⏳ **Loading Indicator:** Shows status, including chunk processing information.
114 | * **Flexible & Efficient Model Handling:**
115 | * ☁️ Downloads models automatically from [Hugging Face Hub](https://huggingface.co/).
116 | * 🔒 Supports loading secure **`.safetensors`** weights (default).
117 | * 💾 Supports loading original **`.pth`** weights.
118 | * 🚀 Defaults to **BF16 SafeTensors** for reduced memory footprint (~half size) and potentially faster inference. (Credit: [ttj/dia-1.6b-safetensors](https://huggingface.co/ttj/dia-1.6b-safetensors))
119 | * 🔄 Easily switch between model formats/versions via `config.yaml`.
120 | * **Performance & Configuration:**
121 | * 💻 **GPU Acceleration:** Automatically uses NVIDIA CUDA if available, falls back to CPU. Optimized VRAM usage (~7GB typical).
122 | * 📊 **Terminal Progress:** Displays `tqdm` progress bar when processing text chunks.
123 | * ⚙️ Primary configuration via `config.yaml`, initial seeding via `.env`.
124 | * 📦 Uses standard Python virtual environments.
125 | * **Docker Support:**
126 | * 🐳 Containerized deployment via [Docker](https://www.docker.com/) and Docker Compose.
127 | * 🔌 NVIDIA GPU acceleration with Container Toolkit integration.
128 | * 💾 Persistent volumes for models, reference audio, predefined voices, outputs, and config.
129 | * 🚀 One-command setup and deployment (`docker compose up -d`).
130 |
131 | ## 🔩 System Prerequisites
132 |
133 | * **Operating System:** Windows 10/11 (64-bit) or Linux (Debian/Ubuntu recommended).
134 | * **Python:** Version 3.10 or later ([Download](https://www.python.org/downloads/)).
135 | * **Git:** For cloning the repository ([Download](https://git-scm.com/downloads)).
136 | * **Internet:** For downloading dependencies and models.
137 | * **(Optional but HIGHLY Recommended for Performance):**
138 | * **NVIDIA GPU:** CUDA-compatible (Maxwell architecture or newer). Check [NVIDIA CUDA GPUs](https://developer.nvidia.com/cuda-gpus). Optimized VRAM usage (~7GB typical), but more helps.
139 | * **NVIDIA Drivers:** Latest version for your GPU/OS ([Download](https://www.nvidia.com/Download/index.aspx)).
140 | * **CUDA Toolkit:** Compatible version (e.g., 11.8, 12.1) matching the PyTorch build you install.
141 | * **(Linux Only):**
142 | * `libsndfile1`: Audio library needed by `soundfile`. Install via package manager (e.g., `sudo apt install libsndfile1`).
143 | * `ffmpeg`: Required by `openai-whisper`. Install via package manager (e.g., `sudo apt install ffmpeg`).
144 |
145 | ## 💻 Installation and Setup
146 |
147 | Follow these steps carefully to get the server running.
148 |
149 | **1. Clone the Repository**
150 | ```bash
151 | git clone https://github.com/devnen/dia-tts-server.git
152 | cd dia-tts-server
153 | ```
154 |
155 | **2. Set up Python Virtual Environment**
156 |
157 | Using a virtual environment is crucial!
158 |
159 | * **Windows (PowerShell):**
160 | ```powershell
161 | # In the dia-tts-server directory
162 | python -m venv venv
163 | .\venv\Scripts\activate
164 | # Your prompt should now start with (venv)
165 | ```
166 |
167 | * **Linux (Bash - Debian/Ubuntu Example):**
168 | ```bash
169 | # Ensure prerequisites are installed
170 | sudo apt update && sudo apt install python3 python3-venv python3-pip libsndfile1 ffmpeg -y
171 |
172 | # In the dia-tts-server directory
173 | python3 -m venv venv
174 | source venv/bin/activate
175 | # Your prompt should now start with (venv)
176 | ```
177 |
178 | **3. Install Dependencies**
179 |
180 | Make sure your virtual environment is activated (`(venv)` prefix visible).
181 |
182 | ```bash
183 | # Upgrade pip (recommended)
184 | pip install --upgrade pip
185 |
186 | # Install project requirements (includes tqdm, yaml, parselmouth etc.)
187 | pip install -r requirements.txt
188 | ```
189 | ⭐ **Note:** This installation includes large libraries like PyTorch. The download and installation process may take some time depending on your internet speed and system performance.
190 |
191 | ⭐ **Important:** This installs the *CPU-only* version of PyTorch by default. If you have an NVIDIA GPU, proceed to Step 4 **before** running the server for GPU acceleration.
192 |
193 | **4. NVIDIA Driver and CUDA Setup (for GPU Acceleration)**
194 |
195 | Skip this step if you only have a CPU.
196 |
197 | * **Step 4a: Check/Install NVIDIA Drivers**
198 | * Run `nvidia-smi` in your terminal/command prompt.
199 | * If it works, note the **CUDA Version** listed (e.g., 12.1, 11.8). This is the *maximum* your driver supports.
200 | * If it fails, download and install the latest drivers from [NVIDIA Driver Downloads](https://www.nvidia.com/Download/index.aspx) and **reboot**. Verify with `nvidia-smi` again.
201 |
202 | * **Step 4b: Install PyTorch with CUDA Support**
203 | * Go to the [Official PyTorch Website](https://pytorch.org/get-started/locally/).
204 | * Use the configuration tool: Select **Stable**, **Windows/Linux**, **Pip**, **Python**, and the **CUDA version** that is **equal to or lower** than the one shown by `nvidia-smi` (e.g., if `nvidia-smi` shows 12.4, choose CUDA 12.1).
205 | * Copy the generated command (it will include `--index-url https://download.pytorch.org/whl/cuXXX`).
206 | * **In your activated `(venv)`:**
207 | ```bash
208 | # Uninstall the CPU version first!
209 | pip uninstall torch torchvision torchaudio -y
210 |
211 | # Paste and run the command copied from the PyTorch website
212 | # Example (replace with your actual command):
213 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
214 | ```
215 |
216 | * **Step 4c: Verify PyTorch CUDA Installation**
217 | * In your activated `(venv)`, run `python` and execute the following single line:
218 | ```python
219 | import torch; print(f"PyTorch version: {torch.__version__}"); print(f"CUDA available: {torch.cuda.is_available()}"); print(f"Device name: {torch.cuda.get_device_name(0)}") if torch.cuda.is_available() else None; exit()
220 | ```
221 | * If `CUDA available:` shows `True`, the setup was successful. If `False`, double-check driver installation and the PyTorch install command.
222 |
223 | ## ⚙️ Configuration
224 |
225 | The server now primarily uses `config.yaml` for runtime configuration.
226 |
227 | * **`config.yaml`:** Located in the project root. This file stores all server settings, model paths, generation defaults, and UI state. It is created automatically on the first run if it doesn't exist. **This is the main file to edit for persistent configuration changes.**
228 | * **`.env` File:** Used **only** for the *initial creation* of `config.yaml` if it's missing, or when using the "Reset All Settings" button in the UI. Values in `.env` override hardcoded defaults during this initial seeding/reset process. It is **not** read during normal server operation once `config.yaml` exists. A short example `.env` follows this list.
229 | * **UI Configuration:** The "Server Configuration" and "Generation Parameters" sections in the Web UI allow direct editing and saving of values *into* `config.yaml`.
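
For reference, a minimal `.env` used only for this initial seeding might look like the following. The variable names come from `ENV_TO_YAML_MAP` in `config.py`; every entry is optional, and the values shown are simply the shipped defaults:

```env
# Read only when config.yaml is first created or reset
PORT=8003
DIA_MODEL_REPO_ID=ttj/dia-1.6b-safetensors
DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1_bf16.safetensors
WHISPER_MODEL_NAME=small.en
DIA_MODEL_CACHE_PATH=./model_cache
GEN_DEFAULT_SEED=42
```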
230 |
231 | **Key Configuration Areas (in `config.yaml` or UI):**
232 |
233 | * `server`: `host`, `port`
234 | * `model`: `repo_id`, `config_filename`, `weights_filename`, `whisper_model_name`
235 | * `paths`: `model_cache`, `reference_audio`, `output`, `voices` (for predefined)
236 | * `generation_defaults`: Default values for sliders/seed in the UI (`speed_factor`, `cfg_scale`, `temperature`, `top_p`, `cfg_filter_top_k`, `seed`, `split_text`, `chunk_size`).
237 | * `ui_state`: Stores the last used text, voice mode, file selections, etc., for UI persistence.
238 |
239 | ⭐ **Remember:** Changes made to `server`, `model`, or `paths` sections in `config.yaml` (or via the UI) **require a server restart** to take effect. Changes to `generation_defaults` or `ui_state` are applied dynamically or on the next page load.
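
For orientation, a `config.yaml` freshly generated from the hardcoded defaults (see `DEFAULT_CONFIG` in `config.py`) looks roughly like this; exact quoting and ordering may differ slightly because the file is written by `yaml.dump`:

```yaml
server:
  host: 0.0.0.0
  port: 8003
model:
  repo_id: ttj/dia-1.6b-safetensors
  config_filename: config.json
  weights_filename: dia-v0_1_bf16.safetensors
  whisper_model_name: small.en
paths:
  model_cache: ./model_cache
  reference_audio: ./reference_audio
  output: ./outputs
  voices: ./voices
generation_defaults:
  speed_factor: 1.0
  cfg_scale: 3.0
  temperature: 1.3
  top_p: 0.95
  cfg_filter_top_k: 35
  seed: 42
  split_text: true
  chunk_size: 120
ui_state:
  last_text: ""
  last_voice_mode: predefined
  last_predefined_voice: null
  last_reference_file: null
  last_seed: 42
  last_chunk_size: 120
  last_split_text_enabled: true
  hide_chunk_warning: false
  hide_generation_warning: false
```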
240 |
241 | ## ▶️ Running the Server
242 |
243 | **Note on Model Downloads:**
244 | The first time you run the server (or after changing model settings in `config.yaml`), it will download the required Dia and Whisper model files (~3-7GB depending on selection). Monitor the terminal logs for progress. The server starts fully *after* downloads complete.
245 |
246 | 1. **Activate the virtual environment (if not activated):**
247 | * Linux/macOS: `source venv/bin/activate`
248 | * Windows: `.\venv\Scripts\activate`
249 | 2. **Run the server:**
250 | ```bash
251 | python server.py
252 | ```
253 | 3. **Access the UI:** The server attempts to open the Web UI in your default browser after startup. If it doesn't, manually navigate to `http://localhost:PORT` (e.g., `http://localhost:8003`).
254 | 4. **Access API Docs:** Open `http://localhost:PORT/docs`.
255 | 5. **Stop the server:** Press `CTRL+C` in the terminal.
256 |
257 | ---
258 |
259 | ## 🐳 Docker Installation
260 |
261 | Run Dia TTS Server easily using Docker. The recommended method uses Docker Compose with pre-built images from GitHub Container Registry (GHCR).
262 |
263 | ### Prerequisites
264 |
265 | * [Docker](https://docs.docker.com/get-docker/) installed.
266 | * [Docker Compose](https://docs.docker.com/compose/install/) installed (usually included with Docker Desktop).
267 | * (Optional but Recommended for GPU) NVIDIA GPU with up-to-date drivers and the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) installed.
268 |
269 | ### Option 1: Using Docker Compose (Recommended)
270 |
271 | This method uses `docker-compose.yml` to manage the container, volumes, and configuration easily. It leverages pre-built images hosted on GHCR.
272 |
273 | 1. **Clone the repository:** (You only need the `docker-compose.yml` and `env.example.txt` files from it)
274 | ```bash
275 | git clone https://github.com/devnen/dia-tts-server.git
276 | cd dia-tts-server
277 | ```
278 |
279 | 2. **(Optional) Initial Configuration via `.env`:**
280 | * If this is your very first time running the container and you want to override the default settings *before* `config.yaml` is created inside the container, copy the example environment file:
281 | ```bash
282 | cp env.example.txt .env
283 | ```
284 | * Edit the `.env` file with your desired initial settings (e.g., `PORT`, model filenames).
285 | * **Note:** This `.env` file is *only* used to seed the *initial* `config.yaml` on the very first container start if `/app/config.yaml` doesn't already exist inside the container's volume (which it won't initially). Subsequent configuration changes should be made via the UI or by editing `config.yaml` directly (see Configuration Note below).
286 |
287 | 3. **Review `docker-compose.yml`:**
288 | * The repository includes a `docker-compose.yml` file configured to use the pre-built image and recommended settings. Ensure it looks similar to this:
289 |
290 | ```yaml
291 | # docker-compose.yml
292 | version: '3.8'
293 |
294 | services:
295 |   dia-tts-server:
296 |     # Use the pre-built image from GitHub Container Registry
297 |     image: ghcr.io/devnen/dia-tts-server:latest
298 |     # Alternatively, to build locally (e.g., for development):
299 |     # build:
300 |     #   context: .
301 |     #   dockerfile: Dockerfile
302 |     ports:
303 |       # Map host port (default 8003) to container port 8003
304 |       # You can change the host port via .env (e.g., PORT=8004)
305 |       - "${PORT:-8003}:8003"
306 |     volumes:
307 |       # Mount local directories into the container for persistent data
308 |       - ./model_cache:/app/model_cache
309 |       - ./reference_audio:/app/reference_audio
310 |       - ./outputs:/app/outputs
311 |       - ./voices:/app/voices
312 |       # DO NOT mount config.yaml - let the app create it inside
313 |
314 |     # --- GPU Access ---
315 |     # Modern method (Recommended for newer Docker/NVIDIA setups)
316 |     devices:
317 |       - nvidia.com/gpu=all
318 |     device_cgroup_rules:
319 |       - "c 195:* rmw" # Needed for some NVIDIA container toolkit versions
320 |       - "c 236:* rmw" # Needed for some NVIDIA container toolkit versions
321 |
322 |     # Legacy method (Alternative for older Docker/NVIDIA setups)
323 |     # If the 'devices' block above doesn't work, comment it out and uncomment
324 |     # the 'deploy' block below. Do not use both simultaneously.
325 |     # deploy:
326 |     #   resources:
327 |     #     reservations:
328 |     #       devices:
329 |     #         - driver: nvidia
330 |     #           count: 1 # Or specify specific GPUs e.g., "device=0,1"
331 |     #           capabilities: [gpu]
332 |     # --- End GPU Access ---
333 |
334 |     restart: unless-stopped
335 |     env_file:
336 |       # Load environment variables from .env file for initial config seeding
337 |       - .env
338 |     environment:
339 |       # Enable faster Hugging Face downloads inside the container
340 |       - HF_HUB_ENABLE_HF_TRANSFER=1
341 |       # Pass GPU capabilities (may be needed for legacy method if uncommented)
342 |       - NVIDIA_VISIBLE_DEVICES=all
343 |       - NVIDIA_DRIVER_CAPABILITIES=compute,utility
344 |
345 | # Optional: Define named volumes if you prefer them over host mounts
346 | # volumes:
347 | #   model_cache:
348 | #   reference_audio:
349 | #   outputs:
350 | #   voices:
351 | ```
352 |
353 | 4. **Start the container:**
354 | ```bash
355 | docker compose up -d
356 | ```
357 | * This command will:
358 | * Pull the latest `ghcr.io/devnen/dia-tts-server:latest` image.
359 | * Create the local directories (`model_cache`, `reference_audio`, `outputs`, `voices`) if they don't exist.
360 | * Start the container in detached mode (`-d`).
361 | * The first time you run this, it will download the TTS models into `./model_cache`, which may take some time depending on your internet speed.
362 |
363 | 5. **Access the UI:**
364 | Open your web browser to `http://localhost:8003` (or the host port you configured in `.env`).
365 |
366 | 6. **View logs:**
367 | ```bash
368 | docker compose logs -f
369 | ```
370 |
371 | 7. **Stop the container:**
372 | ```bash
373 | docker compose down
374 | ```
375 |
376 | ### Option 2: Using `docker run` (Alternative)
377 |
378 | This method runs the container directly without Docker Compose, requiring manual specification of ports, volumes, and GPU flags.
379 |
380 | ```bash
381 | # Ensure local directories exist first:
382 | # mkdir -p model_cache reference_audio outputs voices
383 |
384 | docker run -d \
385 | --name dia-tts-server \
386 | -p 8003:8003 \
387 | -v ./model_cache:/app/model_cache \
388 | -v ./reference_audio:/app/reference_audio \
389 | -v ./outputs:/app/outputs \
390 | -v ./voices:/app/voices \
391 | --env HF_HUB_ENABLE_HF_TRANSFER=1 \
392 | --gpus all \
393 | ghcr.io/devnen/dia-tts-server:latest
394 | ```
395 |
396 | * To use a different host port, change the first number in `-p 8003:8003` (e.g., `-p 8004:8003`).
397 | * `--gpus all` enables GPU access; consult NVIDIA Container Toolkit documentation for alternatives if needed.
398 | * Initial configuration relies on model defaults unless you pass environment variables using multiple `-e VAR=VALUE` flags (more complex than using `.env` with Compose).
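
If you do want to seed the initial `config.yaml` this way, the variables are the same ones listed in `ENV_TO_YAML_MAP` in `config.py` (the example values below are illustrative):

```bash
# Same as the command above, with initial-seeding variables passed via -e flags
docker run -d --name dia-tts-server -p 8003:8003 \
  -v ./model_cache:/app/model_cache -v ./outputs:/app/outputs \
  -e WHISPER_MODEL_NAME=small.en -e GEN_DEFAULT_SEED=42 \
  --gpus all ghcr.io/devnen/dia-tts-server:latest
```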
399 |
400 | ### Configuration Note
401 |
402 | * The server uses `config.yaml` inside the container (`/app/config.yaml`) for its settings.
403 | * On the *very first start*, if `/app/config.yaml` doesn't exist, the server creates it using defaults from the code, potentially overridden by variables in the `.env` file (if using Docker Compose and `.env` exists).
404 | * **After the first start,** changes should be made by:
405 | * Using the Web UI's settings page (if available).
406 | * Editing the `config.yaml` file *inside* the container (e.g., `docker compose exec dia-tts-server nano /app/config.yaml`). Changes require a container restart (`docker compose restart dia-tts-server`) to take effect for server/model/path settings. UI state changes are saved live.
407 |
408 | ### Performance Optimizations
409 |
410 | * **Faster Model Downloads**: `hf-transfer` is enabled by default in the provided `docker-compose.yml` and image, significantly speeding up initial model downloads from Hugging Face.
411 | * **GPU Acceleration**: The `docker-compose.yml` and `docker run` examples include flags (`devices` or `--gpus`) to enable NVIDIA GPU acceleration if available. The Docker image uses a CUDA runtime base for efficiency.
412 |
413 | ### Docker Volumes
414 |
415 | Persistent data is stored on your host machine via volume mounts:
416 |
417 | * `./model_cache:/app/model_cache` (Downloaded TTS and Whisper models)
418 | * `./reference_audio:/app/reference_audio` (Your uploaded reference audio files for cloning)
419 | * `./outputs:/app/outputs` (Generated audio files)
420 | * `./voices:/app/voices` (Predefined voice audio files)
421 |
422 | ### Available Images
423 |
424 | * **GitHub Container Registry**: `ghcr.io/devnen/dia-tts-server:latest` (Automatically built from the `main` branch)
425 |
426 | ---
427 |
428 | ## 💡 Usage
429 |
430 | ### Web UI (`http://localhost:PORT`)
431 |
432 | The most intuitive way to use the server:
433 |
434 | * **Text Input:** Enter your script. Use `[S1]`/`[S2]` for dialogue and non-verbals like `(laughs)`. Content is saved automatically.
435 | * **Generate Button & Chunking:** Click "Generate Speech". Below the text box:
436 | * **Split text into chunks:** Toggle checkbox (enabled by default). Enables splitting for long text (> ~2x chunk size).
437 | * **Chunk Size:** Adjust the slider (visible when splitting is possible) for approximate chunk character length (default 120).
438 | * **Voice Mode:** Choose:
439 | * `Predefined Voices`: Select a curated, ready-to-use synthetic voice from the `./voices` directory.
440 | * `Voice Cloning`: Select an uploaded reference file from `./reference_audio`. Requires a corresponding `.txt` transcript (recommended) or relies on experimental Whisper fallback. Backend handles the transcript automatically. An example transcript file is shown after this list.
441 | * `Random Single / Dialogue`: Uses `[S1]`/`[S2]` tags or generates a random voice if no tags. Use a fixed Seed for consistency.
442 | * **Presets:** Click buttons (loaded from `ui/presets.yaml`) to populate text and parameters. Customize by editing the YAML file.
443 | * **Reference Audio (Clone Mode):** Select an existing `.wav`/`.mp3` or click "Import" to upload new files to `./reference_audio`.
444 | * **Generation Parameters:** Adjust sliders/inputs for Speed, CFG, Temperature, Top P, Top K, and **Seed**. Settings are saved automatically. Click "Save Generation Parameters" to update the defaults in `config.yaml`. Use -1 seed for random, integer for specific results.
445 | * **Server Configuration:** View/edit `config.yaml` settings (requires server restart for some changes).
446 | * **Loading Overlay:** Appears during generation, showing chunk progress if applicable.
447 | * **Audio Player:** Appears on success with waveform, playback controls, download link, and generation info.
448 | * **Theme Toggle:** Switch between light/dark modes.
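
For reference, a cloning transcript is just a plain-text file next to the reference audio with the same base name (e.g., `./reference_audio/Gianna.txt` alongside `Gianna.wav`), containing the spoken text in the tagged format described in Troubleshooting. The sentence below is purely illustrative:

```
[S1] This is the exact text spoken in the reference recording, written out in full.
```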
449 |
450 | ### API Endpoints (`/docs` for details)
451 |
452 | * **`/v1/audio/speech` (POST):** OpenAI-compatible.
453 |     * `input`: Text.
454 |     * `voice`: 'S1', 'S2', 'dialogue', 'predefined_voice_filename.wav', or 'reference_filename.wav'.
455 |     * `response_format`: 'opus' or 'wav'.
456 |     * `speed`: Playback speed factor (0.5-2.0).
457 |     * `seed`: (Optional) Integer seed, -1 for random.
458 | * **`/tts` (POST):** Custom endpoint with full control.
459 |     * `text`: Target text.
460 |     * `voice_mode`: 'dialogue', 'single_s1', 'single_s2', 'clone', 'predefined'.
461 |     * `clone_reference_filename`: Filename in `./reference_audio` (for clone) or `./voices` (for predefined).
462 |     * `transcript`: (Optional, Clone Mode Only) Explicit transcript text to override file/Whisper lookup.
463 |     * `output_format`: 'opus' or 'wav'.
464 |     * `max_tokens`: (Optional) Max tokens *per chunk*.
465 |     * `cfg_scale`, `temperature`, `top_p`, `cfg_filter_top_k`: Generation parameters.
466 |     * `speed_factor`: Playback speed factor (0.5-2.0).
467 |     * `seed`: (Optional) Integer seed, -1 for random.
468 |     * `split_text`: (Optional) Boolean, enable/disable chunking (default: True).
469 |     * `chunk_size`: (Optional) Integer, target chunk size (default: 120).
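
As a quick sanity check, both endpoints can be exercised with `curl`. The JSON field names below are the ones listed above, but treat this as a sketch: see `/docs` for the authoritative request schemas, and note the voice filename is only an example:

```bash
# Health check
curl http://localhost:8003/health

# OpenAI-compatible endpoint
curl -X POST http://localhost:8003/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{"input": "[S1] Hello there. [S2] Hi! (laughs)", "voice": "dialogue", "response_format": "wav", "speed": 1.0, "seed": 42}' \
  --output speech.wav

# Custom /tts endpoint with full parameter control
curl -X POST http://localhost:8003/tts \
  -H "Content-Type: application/json" \
  -d '{"text": "[S1] A longer script to synthesize.", "voice_mode": "predefined", "clone_reference_filename": "Michael_Emily.wav", "output_format": "wav", "seed": 42, "split_text": true, "chunk_size": 120}' \
  --output tts_output.wav
```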
470 |
471 | ## 🔍 Troubleshooting
472 |
473 | * **CUDA Not Available / Slow:** Check NVIDIA drivers (`nvidia-smi`), ensure correct CUDA-enabled PyTorch is installed (Installation Step 4).
474 | * **VRAM Out of Memory (OOM):**
475 | * Ensure you are using the BF16 model (`dia-v0_1_bf16.safetensors` in `config.yaml`) if VRAM is limited (~7GB needed).
476 | * Close other GPU-intensive applications. VRAM optimizations and leak fixes have significantly reduced requirements.
477 | * If processing very long text even with chunking, try reducing `chunk_size` (e.g., 100).
478 | * **CUDA Out of Memory (OOM) During Startup:** This can happen due to temporary loading overhead. The server loads weights to CPU first to mitigate this. If it persists, check VRAM usage (`nvidia-smi`), ensure the BF16 model is used, or try setting the `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` environment variable before starting (see the example after this list).
479 | * **Import Errors (`dac`, `tqdm`, `yaml`, `whisper`, `parselmouth`):** Activate venv, run `pip install -r requirements.txt`. Ensure `descript-audio-codec` installed correctly.
480 | * **`libsndfile` / `ffmpeg` Error (Linux):** Run `sudo apt install libsndfile1 ffmpeg`.
481 | * **Model Download Fails (Dia or Whisper):** Check internet, `config.yaml` settings (`model.repo_id`, `model.weights_filename`, `model.whisper_model_name`), Hugging Face status, cache path permissions (`paths.model_cache`).
482 | * **Voice Cloning Fails / Poor Quality:**
483 | * **Ensure accurate `.txt` transcript exists** alongside the reference audio in `./reference_audio`. Format: `[S1] text...` or `[S1] text... [S2] text...`. This is the most reliable method.
484 | * Whisper fallback is experimental and may be inaccurate.
485 | * Use clean, clear reference audio (5-20s).
486 | * Check server logs for specific errors during `_prepare_cloning_inputs`.
487 | * **Permission Errors (Saving Files/Config):** Check write permissions for `paths.output`, `paths.reference_audio`, `paths.voices`, `paths.model_cache` (for Whisper transcript saves), and `config.yaml`.
488 | * **UI Issues / Settings Not Saving:** Clear browser cache/local storage. Check developer console (F12) for JS errors. Ensure `config.yaml` is writable by the server process.
489 | * **Inconsistent Voice with Chunking:** Use "Predefined Voices" or "Voice Cloning" mode. If using "Random/Dialogue" mode with splitting, use a fixed integer `seed` (not -1) for consistency across chunks. The UI provides a warning otherwise.
490 | * **Port Conflict (`Address already in use` / `Errno 98`):** Another process is using the port (default 8003). Stop the other process or change the `server.port` in `config.yaml` (requires restart).
491 | * **Explanation:** This usually happens if a previous server instance didn't shut down cleanly or another application is bound to the same port.
492 | * **Linux:** Find/kill process: `sudo lsof -i:PORT | grep LISTEN | awk '{print $2}' | xargs kill -9` (Replace PORT, e.g., 8003).
493 | * **Windows:** Find/kill process: `for /f "tokens=5" %i in ('netstat -ano ^| findstr :PORT') do taskkill /F /PID %i` (Replace PORT, e.g., 8003). Use with caution.
494 | * **Generation Cancel Button:** This is a "UI Cancel" - it stops the *frontend* from waiting but doesn't instantly halt ongoing backend model inference. Clicking Generate again cancels the previous UI wait.
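
For the startup OOM case above, the allocator setting can be applied like this before launching (mirroring the `CUDA_VISIBLE_DEVICES` examples in the next section):

```bash
# Linux / macOS
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python server.py
```

On Windows PowerShell, the equivalent is `$env:PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"; python server.py`.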
495 |
496 | ### Selecting GPUs on Multi-GPU Systems
497 |
498 | Set the `CUDA_VISIBLE_DEVICES` environment variable **before** running `python server.py` to specify which GPU(s) PyTorch should see. The server uses the first visible one (`cuda:0`).
499 |
500 | * **Example (Use only physical GPU 1):**
501 | * Linux/macOS: `CUDA_VISIBLE_DEVICES="1" python server.py`
502 | * Windows CMD: `set CUDA_VISIBLE_DEVICES=1 && python server.py`
503 | * Windows PowerShell: `$env:CUDA_VISIBLE_DEVICES="1"; python server.py`
504 |
505 | * **Example (Use physical GPUs 6 and 7 - server uses GPU 6):**
506 | * Linux/macOS: `CUDA_VISIBLE_DEVICES="6,7" python server.py`
507 | * Windows CMD: `set CUDA_VISIBLE_DEVICES=6,7 && python server.py`
508 | * Windows PowerShell: `$env:CUDA_VISIBLE_DEVICES="6,7"; python server.py`
509 |
510 | **Note:** `CUDA_VISIBLE_DEVICES` selects GPUs; it does **not** fix OOM errors if the chosen GPU lacks sufficient memory.
511 |
512 | ## 🤝 Contributing
513 |
514 | Contributions are welcome! Please feel free to open an issue to report bugs or suggest features, or submit a Pull Request for improvements.
515 |
516 | ## 📜 License
517 |
518 | This project is licensed under the **MIT License**.
519 |
520 | You can find it here: [https://opensource.org/licenses/MIT](https://opensource.org/licenses/MIT)
521 |
522 | ## 🙏 Acknowledgements
523 |
524 | * **Core Model:** This project heavily relies on the excellent **[Dia TTS model](https://github.com/nari-labs/dia)** developed by **[Nari Labs](https://github.com/nari-labs)**. Their work in creating and open-sourcing the model is greatly appreciated.
525 | * **UI Inspiration:** Special thanks to **[Lex-au](https://github.com/Lex-au)** whose **[Orpheus-FastAPI](https://github.com/Lex-au/Orpheus-FastAPI)** project served as inspiration for the web interface design of this project.
526 | * **SafeTensors Conversion:** Thank you to user **[ttj on Hugging Face](https://huggingface.co/ttj)** for providing the converted **[SafeTensors weights](https://huggingface.co/ttj/dia-1.6b-safetensors)** used as the default in this server.
527 | * **Containerization Technologies:** [Docker](https://www.docker.com/) and [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker) for enabling consistent deployment environments.
528 | * **Core Libraries:**
529 | * [FastAPI](https://fastapi.tiangolo.com/)
530 | * [Uvicorn](https://www.uvicorn.org/)
531 | * [PyTorch](https://pytorch.org/)
532 | * [Hugging Face Hub](https://huggingface.co/docs/huggingface_hub/index) & [SafeTensors](https://github.com/huggingface/safetensors)
533 | * [Descript Audio Codec (DAC)](https://github.com/descriptinc/descript-audio-codec)
534 | * [SoundFile](https://python-soundfile.readthedocs.io/) & [libsndfile](http://www.mega-nerd.com/libsndfile/)
535 | * [Jinja2](https://jinja.palletsprojects.com/)
536 | * [WaveSurfer.js](https://wavesurfer.xyz/)
537 | * [Tailwind CSS](https://tailwindcss.com/) (via CDN)
538 |
539 | ---
540 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # config.py
2 | # Configuration management for Dia TTS server using YAML file.
3 |
4 | import os
5 | import logging
6 | import yaml
7 | import shutil
8 | from copy import deepcopy
9 | from threading import Lock
10 | from typing import Dict, Any, Optional, List, Tuple
11 |
12 | # Use dotenv for initial seeding ONLY if config.yaml is missing
13 | from dotenv import load_dotenv, find_dotenv
14 |
15 | # Configure logging
16 | logger = logging.getLogger(__name__)
17 |
18 | # --- Constants ---
19 | CONFIG_FILE_PATH = "config.yaml"
20 | ENV_FILE_PATH = find_dotenv() # Find .env file location
21 |
22 | # --- Default Configuration Structure ---
23 | # This defines the expected structure and default values for config.yaml
24 | # It's crucial to keep this up-to-date with all expected keys.
25 | DEFAULT_CONFIG: Dict[str, Any] = {
26 | "server": {
27 | "host": "0.0.0.0",
28 | "port": 8003,
29 | },
30 | "model": {
31 | "repo_id": "ttj/dia-1.6b-safetensors",
32 | "config_filename": "config.json",
33 | "weights_filename": "dia-v0_1_bf16.safetensors",
34 | "whisper_model_name": "small.en", # Added Whisper model here
35 | },
36 | "paths": {
37 | "model_cache": "./model_cache",
38 | "reference_audio": "./reference_audio",
39 | "output": "./outputs",
40 | "voices": "./voices", # Added predefined voices path
41 | },
42 | "generation_defaults": {
43 | "speed_factor": 1.0, # Changed default to 1.0
44 | "cfg_scale": 3.0,
45 | "temperature": 1.3,
46 | "top_p": 0.95,
47 | "cfg_filter_top_k": 35,
48 | "seed": 42,
49 | "split_text": True, # Default split enabled
50 | "chunk_size": 120, # Default chunk size
51 | },
52 | "ui_state": {
53 | "last_text": "",
54 | "last_voice_mode": "predefined", # Default to predefined
55 | "last_predefined_voice": None, # Store original filename (e.g., "Michael_Emily.wav")
56 | "last_reference_file": None,
57 | "last_seed": 42,
58 | "last_chunk_size": 120,
59 | "last_split_text_enabled": True,
60 | "hide_chunk_warning": False,
61 | "hide_generation_warning": False,
62 | },
63 | }
64 |
65 | # Mapping from .env variable names to config.yaml nested keys
66 | # Used ONLY during initial seeding if config.yaml is missing.
67 | ENV_TO_YAML_MAP: Dict[str, Tuple[List[str], type]] = {
68 |     # Server
69 |     "HOST": (["server", "host"], str),
70 |     "PORT": (["server", "port"], int),
71 |     # Model
72 |     "DIA_MODEL_REPO_ID": (["model", "repo_id"], str),
73 |     "DIA_MODEL_CONFIG_FILENAME": (["model", "config_filename"], str),
74 |     "DIA_MODEL_WEIGHTS_FILENAME": (["model", "weights_filename"], str),
75 |     "WHISPER_MODEL_NAME": (["model", "whisper_model_name"], str),
76 |     # Paths
77 |     "DIA_MODEL_CACHE_PATH": (["paths", "model_cache"], str),
78 |     "REFERENCE_AUDIO_PATH": (["paths", "reference_audio"], str),
79 |     "OUTPUT_PATH": (["paths", "output"], str),
80 |     # Generation Defaults
81 |     "GEN_DEFAULT_SPEED_FACTOR": (["generation_defaults", "speed_factor"], float),
82 |     "GEN_DEFAULT_CFG_SCALE": (["generation_defaults", "cfg_scale"], float),
83 |     "GEN_DEFAULT_TEMPERATURE": (["generation_defaults", "temperature"], float),
84 |     "GEN_DEFAULT_TOP_P": (["generation_defaults", "top_p"], float),
85 |     "GEN_DEFAULT_CFG_FILTER_TOP_K": (["generation_defaults", "cfg_filter_top_k"], int),
86 |     "GEN_DEFAULT_SEED": (["generation_defaults", "seed"], int),
87 |     # Note: split_text and chunk_size defaults are not typically in .env
88 |     # Note: ui_state is never loaded from .env
89 | }
90 |
91 |
92 | def _deep_merge_dicts(source: Dict, destination: Dict) -> Dict:
93 | """
94 | Recursively merges source dict into destination dict.
95 | Modifies destination in place.
96 | """
97 | for key, value in source.items():
98 | if isinstance(value, dict):
99 | # get node or create one
100 | node = destination.setdefault(key, {})
101 | _deep_merge_dicts(value, node)
102 | else:
103 | destination[key] = value
104 | return destination
105 |
106 |
107 | def _set_nested_value(d: Dict, keys: List[str], value: Any):
108 | """Sets a value in a nested dictionary using a list of keys."""
109 | for key in keys[:-1]:
110 | d = d.setdefault(key, {})
111 | d[keys[-1]] = value
112 |
113 |
114 | def _get_nested_value(d: Dict, keys: List[str], default: Any = None) -> Any:
115 | """Gets a value from a nested dictionary using a list of keys."""
116 | for key in keys:
117 | if isinstance(d, dict) and key in d:
118 | d = d[key]
119 | else:
120 | return default
121 | return d
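# Illustrative behavior of the two helpers above (example values only):
#   d = {}
#   _set_nested_value(d, ["server", "port"], 8004)     # d == {"server": {"port": 8004}}
#   _get_nested_value(d, ["server", "port"])            # -> 8004
#   _get_nested_value(d, ["server", "missing"], 1234)   # -> 1234 (default)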
122 |
123 |
124 | class YamlConfigManager:
125 | """Manages configuration for the TTS server using a YAML file."""
126 |
127 | def __init__(self):
128 | """Initialize the configuration manager by loading the config."""
129 | self.config: Dict[str, Any] = {}
130 | self._lock = Lock() # Lock for thread-safe file writing
131 | self.load_config()
132 |
133 | def _load_defaults(self) -> Dict[str, Any]:
134 | """Returns a deep copy of the hardcoded default configuration."""
135 | return deepcopy(DEFAULT_CONFIG)
136 |
137 | def _load_env_overrides(self, config_dict: Dict[str, Any]) -> Dict[str, Any]:
138 | """
139 | Loads .env file (if found) and overrides values in the provided config_dict.
140 | Used ONLY during initial seeding or reset.
141 | """
142 | if not ENV_FILE_PATH:
143 | logger.info("No .env file found, skipping environment variable overrides.")
144 | return config_dict
145 |
146 | logger.info(f"Loading environment variables from: {ENV_FILE_PATH}")
147 | # Load .env variables into os.environ temporarily
148 | load_dotenv(dotenv_path=ENV_FILE_PATH, override=True)
149 |
150 | env_values_applied = 0
151 | for env_var, (yaml_path, target_type) in ENV_TO_YAML_MAP.items():
152 | env_value_str = os.environ.get(env_var)
153 | if env_value_str is not None:
154 | try:
155 | # Attempt type conversion
156 | if target_type is bool:
157 | converted_value = env_value_str.lower() in (
158 | "true",
159 | "1",
160 | "t",
161 | "yes",
162 | "y",
163 | )
164 | else:
165 | converted_value = target_type(env_value_str)
166 |
167 | _set_nested_value(config_dict, yaml_path, converted_value)
168 | logger.debug(
169 | f"Applied .env override: {'->'.join(yaml_path)} = {converted_value}"
170 | )
171 | env_values_applied += 1
172 | except (ValueError, TypeError) as e:
173 | logger.warning(
174 | f"Could not apply .env override for '{env_var}'. Invalid value '{env_value_str}' for type {target_type.__name__}. Using default. Error: {e}"
175 | )
176 | # Clean up the loaded env var from os.environ if desired, though it's usually harmless
177 | # if env_var in os.environ:
178 | # del os.environ[env_var]
179 |
180 | if env_values_applied > 0:
181 | logger.info(f"Applied {env_values_applied} overrides from .env file.")
182 | else:
183 | logger.info("No applicable overrides found in .env file.")
184 |
185 | return config_dict
186 |
187 | def load_config(self):
188 | """
189 | Loads configuration from config.yaml.
190 | If config.yaml doesn't exist, creates it by seeding from .env (if found)
191 | and hardcoded defaults.
192 | Ensures all default keys are present in the loaded config.
193 | """
194 | with self._lock: # Ensure loading process is atomic if called concurrently (unlikely at startup)
195 | loaded_config = self._load_defaults() # Start with defaults
196 |
197 | if os.path.exists(CONFIG_FILE_PATH):
198 | logger.info(f"Loading configuration from: {CONFIG_FILE_PATH}")
199 | try:
200 | with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as f:
201 | yaml_data = yaml.safe_load(f)
202 | if isinstance(yaml_data, dict):
203 | # Merge loaded data onto defaults to ensure all keys exist
204 | # and to add new default keys if the file is old.
205 | loaded_config = _deep_merge_dicts(yaml_data, loaded_config)
206 | logger.info("Successfully loaded and merged config.yaml.")
207 | else:
208 | logger.error(
209 | f"Invalid format in {CONFIG_FILE_PATH}. Expected a dictionary (key-value pairs). Using defaults and attempting to overwrite."
210 | )
211 | # Proceed using defaults, will attempt to save later
212 | if not self._save_config_yaml_internal(loaded_config):
213 | logger.error(
214 | f"Failed to overwrite invalid {CONFIG_FILE_PATH} with defaults."
215 | )
216 |
217 | except yaml.YAMLError as e:
218 | logger.error(
219 | f"Error parsing {CONFIG_FILE_PATH}: {e}. Using defaults and attempting to overwrite."
220 | )
221 | if not self._save_config_yaml_internal(loaded_config):
222 | logger.error(
223 | f"Failed to overwrite corrupted {CONFIG_FILE_PATH} with defaults."
224 | )
225 | except Exception as e:
226 | logger.error(
227 | f"Unexpected error loading {CONFIG_FILE_PATH}: {e}. Using defaults.",
228 | exc_info=True,
229 | )
230 | # Don't try to save if it was an unexpected error
231 |
232 | else:
233 | logger.info(
234 | f"{CONFIG_FILE_PATH} not found. Creating initial configuration..."
235 | )
236 | # Seed from .env overrides onto the defaults
237 | loaded_config = self._load_env_overrides(loaded_config)
238 | # Save the newly created config
239 | if self._save_config_yaml_internal(loaded_config):
240 | logger.info(
241 | f"Successfully created and saved initial configuration to {CONFIG_FILE_PATH}."
242 | )
243 | else:
244 | logger.error(
245 | f"Failed to save initial configuration to {CONFIG_FILE_PATH}. Using in-memory defaults."
246 | )
247 |
248 | self.config = loaded_config
249 | logger.debug(f"Current config loaded: {self.config}")
250 | return self.config
251 |
252 | def _save_config_yaml_internal(self, config_dict: Dict[str, Any]) -> bool:
253 | """
254 | Internal method to save the configuration dictionary to config.yaml.
255 | Includes backup and restore mechanism. Assumes lock is already held.
256 | """
257 | temp_file_path = CONFIG_FILE_PATH + ".tmp"
258 | backup_file_path = CONFIG_FILE_PATH + ".bak"
259 |
260 | try:
261 | # Write to temporary file first
262 | with open(temp_file_path, "w", encoding="utf-8") as f:
263 | yaml.dump(
264 | config_dict, f, default_flow_style=False, sort_keys=False, indent=2
265 | )
266 |
267 | # Backup existing file if it exists
268 | if os.path.exists(CONFIG_FILE_PATH):
269 | try:
270 | shutil.move(CONFIG_FILE_PATH, backup_file_path)
271 | logger.debug(f"Backed up existing config to {backup_file_path}")
272 | except Exception as backup_err:
273 | logger.warning(
274 | f"Could not create backup of {CONFIG_FILE_PATH}: {backup_err}"
275 | )
276 | # Decide whether to proceed without backup or fail
277 | # Proceeding for now, but log the warning.
278 |
279 | # Rename temporary file to actual config file
280 | shutil.move(temp_file_path, CONFIG_FILE_PATH)
281 | logger.info(f"Configuration successfully saved to {CONFIG_FILE_PATH}")
282 | return True
283 |
284 | except yaml.YAMLError as e:
285 | logger.error(
286 | f"Error formatting data for {CONFIG_FILE_PATH}: {e}", exc_info=True
287 | )
288 | return False
289 | except Exception as e:
290 | logger.error(
291 | f"Failed to save configuration to {CONFIG_FILE_PATH}: {e}",
292 | exc_info=True,
293 | )
294 | # Attempt to restore backup if save failed mid-way
295 | if os.path.exists(backup_file_path) and not os.path.exists(
296 | CONFIG_FILE_PATH
297 | ):
298 | try:
299 | shutil.move(backup_file_path, CONFIG_FILE_PATH)
300 | logger.info(
301 | f"Restored configuration from backup {backup_file_path}"
302 | )
303 | except Exception as restore_err:
304 | logger.error(
305 | f"Failed to restore configuration from backup: {restore_err}"
306 | )
307 | # Clean up temp file if it still exists
308 | if os.path.exists(temp_file_path):
309 | try:
310 | os.remove(temp_file_path)
311 | except Exception as remove_err:
312 | logger.warning(
313 | f"Could not remove temporary config file {temp_file_path}: {remove_err}"
314 | )
315 | return False
316 | finally:
317 | # Clean up backup file if main file exists and save was successful (or if backup failed)
318 | if os.path.exists(CONFIG_FILE_PATH) and os.path.exists(backup_file_path):
319 | try:
320 | # Only remove backup if the main file write seems okay
321 | if os.path.getsize(CONFIG_FILE_PATH) > 0: # Basic check
322 | os.remove(backup_file_path)
323 | logger.debug(f"Removed backup file {backup_file_path}")
324 | except Exception as remove_bak_err:
325 | logger.warning(
326 | f"Could not remove backup config file {backup_file_path}: {remove_bak_err}"
327 | )
328 |
329 | def save_config_yaml(self, config_dict: Dict[str, Any]) -> bool:
330 | """Public method to save the configuration dictionary with locking."""
331 | with self._lock:
332 | return self._save_config_yaml_internal(config_dict)
333 |
334 | def get(self, key_path: str, default: Any = None) -> Any:
335 | """
336 | Get a configuration value using a dot-separated key path.
337 | e.g., get('server.port', 8000)
338 | """
339 | keys = key_path.split(".")
340 | value = _get_nested_value(self.config, keys, default)
341 | # Ensure we return a copy for mutable types like dicts/lists
342 | return deepcopy(value) if isinstance(value, (dict, list)) else value
343 |
344 | def get_all(self) -> Dict[str, Any]:
345 | """Get a deep copy of all current configuration values."""
346 | with self._lock: # Ensure consistency while copying
347 | return deepcopy(self.config)
348 |
349 | def update_and_save(self, partial_update_dict: Dict[str, Any]) -> bool:
350 | """
351 | Deep merges a partial update dictionary into the current config
352 | and saves the entire configuration back to the YAML file.
353 | """
354 | if not isinstance(partial_update_dict, dict):
355 | logger.error("Invalid partial update data: must be a dictionary.")
356 | return False
357 |
358 | with self._lock:
359 | try:
360 | # Create a deep copy to avoid modifying self.config directly before successful save
361 | updated_config = deepcopy(self.config)
362 | # Merge the partial update into the copy
363 | _deep_merge_dicts(partial_update_dict, updated_config)
364 |
365 | # Save the fully merged configuration
366 | if self._save_config_yaml_internal(updated_config):
367 | # If save was successful, update the in-memory config
368 | self.config = updated_config
369 | logger.info("Configuration updated and saved successfully.")
370 | return True
371 | else:
372 | logger.error("Failed to save updated configuration.")
373 | return False
374 | except Exception as e:
375 | logger.error(
376 | f"Error during configuration update and save: {e}", exc_info=True
377 | )
378 | return False
379 |
380 | def reset_and_save(self) -> bool:
381 | """
382 | Resets the configuration to hardcoded defaults, potentially overridden
383 | by values in the .env file, and saves it to config.yaml.
384 | """
385 | with self._lock:
386 | logger.warning(
387 | "Resetting configuration to defaults (with .env overrides)..."
388 | )
389 | reset_config = self._load_defaults()
390 | reset_config = self._load_env_overrides(
391 | reset_config
392 | ) # Apply .env overrides to defaults
393 |
394 | if self._save_config_yaml_internal(reset_config):
395 | self.config = reset_config # Update in-memory config
396 | logger.info("Configuration successfully reset and saved.")
397 | return True
398 | else:
399 | logger.error("Failed to save reset configuration.")
400 | # Keep the old config in memory if save failed
401 | return False
402 |
403 | # --- Type-specific Getters with Error Handling ---
404 | def get_int(self, key_path: str, default: Optional[int] = None) -> int:
405 | """Get a configuration value as an integer."""
406 | value = self.get(key_path)
407 | if value is None:
408 | if default is not None:
409 | logger.debug(
410 | f"Config '{key_path}' not found, using provided default: {default}"
411 | )
412 | return default
413 | else:
414 | logger.error(
415 | f"Mandatory config '{key_path}' not found and no default. Returning 0."
416 | )
417 | return 0
418 | try:
419 | return int(value)
420 | except (ValueError, TypeError):
421 | logger.warning(
422 | f"Invalid integer value '{value}' for '{key_path}'. Using default: {default}"
423 | )
424 | if isinstance(default, int):
425 | return default
426 | else:
427 | logger.error(
428 | f"Cannot parse '{value}' as int for '{key_path}' and no valid default. Returning 0."
429 | )
430 | return 0
431 |
432 | def get_float(self, key_path: str, default: Optional[float] = None) -> float:
433 | """Get a configuration value as a float."""
434 | value = self.get(key_path)
435 | if value is None:
436 | if default is not None:
437 | logger.debug(
438 | f"Config '{key_path}' not found, using provided default: {default}"
439 | )
440 | return default
441 | else:
442 | logger.error(
443 | f"Mandatory config '{key_path}' not found and no default. Returning 0.0."
444 | )
445 | return 0.0
446 | try:
447 | return float(value)
448 | except (ValueError, TypeError):
449 | logger.warning(
450 | f"Invalid float value '{value}' for '{key_path}'. Using default: {default}"
451 | )
452 | if isinstance(default, float):
453 | return default
454 | else:
455 | logger.error(
456 | f"Cannot parse '{value}' as float for '{key_path}' and no valid default. Returning 0.0."
457 | )
458 | return 0.0
459 |
460 | def get_bool(self, key_path: str, default: Optional[bool] = None) -> bool:
461 | """Get a configuration value as a boolean."""
462 | value = self.get(key_path)
463 | if value is None:
464 | if default is not None:
465 | logger.debug(
466 | f"Config '{key_path}' not found, using provided default: {default}"
467 | )
468 | return default
469 | else:
470 | logger.error(
471 | f"Mandatory config '{key_path}' not found and no default. Returning False."
472 | )
473 | return False
474 | if isinstance(value, bool):
475 | return value
476 | if isinstance(value, str):
477 | return value.lower() in ("true", "1", "t", "yes", "y")
478 | try:
479 | # Handle numeric representations (e.g., 1 for True, 0 for False)
480 | return bool(int(value))
481 | except (ValueError, TypeError):
482 | logger.warning(
483 | f"Invalid boolean value '{value}' for '{key_path}'. Using default: {default}"
484 | )
485 | if isinstance(default, bool):
486 | return default
487 | else:
488 | logger.error(
489 | f"Cannot parse '{value}' as bool for '{key_path}' and no valid default. Returning False."
490 | )
491 | return False
492 |
493 |
494 | # --- Create a singleton instance for global access ---
495 | config_manager = YamlConfigManager()
496 |
497 | # --- Export common getters for easy access ---
498 |
499 |
500 | # Helper to get default value from the DEFAULT_CONFIG structure
501 | def _get_default(key_path: str) -> Any:
502 | keys = key_path.split(".")
503 | return _get_nested_value(DEFAULT_CONFIG, keys)
504 |
505 |
506 | # Server Settings
507 | def get_host() -> str:
508 | return config_manager.get("server.host", _get_default("server.host"))
509 |
510 |
511 | def get_port() -> int:
512 | return config_manager.get_int("server.port", _get_default("server.port"))
513 |
514 |
515 | # Model Source Settings
516 | def get_model_repo_id() -> str:
517 | return config_manager.get("model.repo_id", _get_default("model.repo_id"))
518 |
519 |
520 | def get_model_config_filename() -> str:
521 | return config_manager.get(
522 | "model.config_filename", _get_default("model.config_filename")
523 | )
524 |
525 |
526 | def get_model_weights_filename() -> str:
527 | return config_manager.get(
528 | "model.weights_filename", _get_default("model.weights_filename")
529 | )
530 |
531 |
532 | def get_whisper_model_name() -> str:
533 | return config_manager.get(
534 | "model.whisper_model_name", _get_default("model.whisper_model_name")
535 | )
536 |
537 |
538 | # Path Settings
539 | def get_model_cache_path() -> str:
540 | return os.path.abspath(
541 | config_manager.get("paths.model_cache", _get_default("paths.model_cache"))
542 | )
543 |
544 |
545 | def get_reference_audio_path() -> str:
546 | return os.path.abspath(
547 | config_manager.get(
548 | "paths.reference_audio", _get_default("paths.reference_audio")
549 | )
550 | )
551 |
552 |
553 | def get_output_path() -> str:
554 | return os.path.abspath(
555 | config_manager.get("paths.output", _get_default("paths.output"))
556 | )
557 |
558 |
559 | def get_predefined_voices_path() -> str:
560 | return os.path.abspath(
561 | config_manager.get("paths.voices", _get_default("paths.voices"))
562 | )
563 |
564 |
565 | # Default Generation Parameter Getters
566 | def get_gen_default_speed_factor() -> float:
567 | return config_manager.get_float(
568 | "generation_defaults.speed_factor",
569 | _get_default("generation_defaults.speed_factor"),
570 | )
571 |
572 |
573 | def get_gen_default_cfg_scale() -> float:
574 | return config_manager.get_float(
575 | "generation_defaults.cfg_scale", _get_default("generation_defaults.cfg_scale")
576 | )
577 |
578 |
579 | def get_gen_default_temperature() -> float:
580 | return config_manager.get_float(
581 | "generation_defaults.temperature",
582 | _get_default("generation_defaults.temperature"),
583 | )
584 |
585 |
586 | def get_gen_default_top_p() -> float:
587 | return config_manager.get_float(
588 | "generation_defaults.top_p", _get_default("generation_defaults.top_p")
589 | )
590 |
591 |
592 | def get_gen_default_cfg_filter_top_k() -> int:
593 | return config_manager.get_int(
594 | "generation_defaults.cfg_filter_top_k",
595 | _get_default("generation_defaults.cfg_filter_top_k"),
596 | )
597 |
598 |
599 | def get_gen_default_seed() -> int:
600 | return config_manager.get_int(
601 | "generation_defaults.seed", _get_default("generation_defaults.seed")
602 | )
603 |
604 |
605 | def get_gen_default_split_text() -> bool:
606 | return config_manager.get_bool(
607 | "generation_defaults.split_text", _get_default("generation_defaults.split_text")
608 | )
609 |
610 |
611 | def get_gen_default_chunk_size() -> int:
612 | return config_manager.get_int(
613 | "generation_defaults.chunk_size", _get_default("generation_defaults.chunk_size")
614 | )
615 |
616 |
617 | # UI State Getters (rarely needed directly by the backend)
618 | def get_ui_state() -> Dict[str, Any]:
619 | """Gets the entire UI state dictionary."""
620 | return config_manager.get("ui_state", _get_default("ui_state"))
621 |
622 |
623 | def get_hide_chunk_warning() -> bool:
624 | """Gets the flag for hiding the chunk warning dialog."""
625 | return config_manager.get_bool(
626 | "ui_state.hide_chunk_warning", _get_default("ui_state.hide_chunk_warning")
627 | )
628 |
629 |
630 | def get_hide_generation_warning() -> bool:
631 | """Gets the flag for hiding the general generation warning dialog."""
632 | return config_manager.get_bool(
633 | "ui_state.hide_generation_warning",
634 | _get_default("ui_state.hide_generation_warning"),
635 | )
636 |
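637 | 
638 | # --- Usage sketch (illustrative only) ---
639 | # A minimal example of how the exported getters and config_manager are meant
640 | # to be consumed. The update payload below is hypothetical (the seed value is
641 | # arbitrary), and the block only runs when this module is executed directly.
642 | if __name__ == "__main__":
643 |     print(f"Server configured for {get_host()}:{get_port()}")
644 |     print(f"Output directory: {get_output_path()}")
645 |     print(f"Default generation seed: {get_gen_default_seed()}")
646 | 
647 |     # Deep-merges the partial dict into the current config and writes config.yaml.
648 |     if config_manager.update_and_save({"generation_defaults": {"seed": 1234}}):
649 |         print(f"Seed is now {get_gen_default_seed()}")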
--------------------------------------------------------------------------------
/dia/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/dia/__init__.py
--------------------------------------------------------------------------------
/dia/audio.py:
--------------------------------------------------------------------------------
1 | import typing as tp
2 |
3 | import torch
4 |
5 |
6 | def build_delay_indices(
7 | B: int, T: int, C: int, delay_pattern: tp.List[int]
8 | ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
9 | """
10 | Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
11 | Negative t_idx => BOS; t_idx >= T => PAD.
12 | """
13 | delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
14 |
15 | t_idx_BxT = torch.broadcast_to(
16 | torch.arange(T, dtype=torch.int32)[None, :],
17 | [B, T],
18 | )
19 | t_idx_BxTx1 = t_idx_BxT[..., None]
20 | t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)
21 |
22 | b_idx_BxTxC = torch.broadcast_to(
23 | torch.arange(B, dtype=torch.int32).view(B, 1, 1),
24 | [B, T, C],
25 | )
26 | c_idx_BxTxC = torch.broadcast_to(
27 | torch.arange(C, dtype=torch.int32).view(1, 1, C),
28 | [B, T, C],
29 | )
30 |
31 | # We must clamp time indices to [0..T-1] so gather_nd equivalent won't fail
32 | t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
33 |
34 | indices_BTCx3 = torch.stack(
35 | [
36 | b_idx_BxTxC.reshape(-1),
37 | t_clamped_BxTxC.reshape(-1),
38 | c_idx_BxTxC.reshape(-1),
39 | ],
40 | dim=1,
41 | ).long() # Ensure indices are long type for indexing
42 |
43 | return t_idx_BxTxC, indices_BTCx3
44 |
45 |
46 | def apply_audio_delay(
47 | audio_BxTxC: torch.Tensor,
48 | pad_value: int,
49 | bos_value: int,
50 | precomp: tp.Tuple[torch.Tensor, torch.Tensor],
51 | ) -> torch.Tensor:
52 | """
53 | Applies the delay pattern to batched audio tokens using precomputed indices,
54 | inserting BOS where t_idx < 0 and PAD where t_idx >= T.
55 |
56 | Args:
57 | audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
58 | pad_value: the padding token
59 | bos_value: the BOS token
60 | precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices
61 |
62 | Returns:
63 | result_BxTxC: [B, T, C] delayed audio tokens
64 | """
65 | device = audio_BxTxC.device # Get device from input tensor
66 | t_idx_BxTxC, indices_BTCx3 = precomp
67 | t_idx_BxTxC = t_idx_BxTxC.to(device) # Move precomputed indices to device
68 | indices_BTCx3 = indices_BTCx3.to(device)
69 |
70 | # Equivalent of tf.gather_nd using advanced indexing
71 | # Ensure indices are long type if not already (build_delay_indices should handle this)
72 | gathered_flat = audio_BxTxC[
73 | indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
74 | ]
75 | gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
76 |
77 | # Create masks on the correct device
78 | mask_bos = t_idx_BxTxC < 0 # => place bos_value
79 | mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1] # => place pad_value
80 |
81 | # Create scalar tensors on the correct device
82 | bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
83 | pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
84 |
85 | # If mask_bos, BOS; else if mask_pad, PAD; else original gather
86 | # All tensors should now be on the same device
87 | result_BxTxC = torch.where(
88 | mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC)
89 | )
90 |
91 | return result_BxTxC
92 |
93 |
94 | def build_revert_indices(
95 | B: int, T: int, C: int, delay_pattern: tp.List[int]
96 | ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
97 | """
98 | Precompute indices for the revert operation using PyTorch.
99 |
100 | Returns:
101 | A tuple (t_idx_BxTxC, indices_BTCx3) where:
102 | - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
103 | - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
104 | batch indices, clamped time indices, and channel indices.
105 | """
106 | # Use default device unless specified otherwise; assumes inputs might define device later
107 | device = None # Or determine dynamically if needed, e.g., from a model parameter
108 |
109 | delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
110 |
111 | t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
112 | t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
113 |
114 | t_idx_BxTxC = torch.minimum(
115 | t_idx_BT1 + delay_arr.view(1, 1, C),
116 | torch.tensor(T - 1, device=device),
117 | )
118 | b_idx_BxTxC = torch.broadcast_to(
119 | torch.arange(B, device=device).view(B, 1, 1), [B, T, C]
120 | )
121 | c_idx_BxTxC = torch.broadcast_to(
122 | torch.arange(C, device=device).view(1, 1, C), [B, T, C]
123 | )
124 |
125 | indices_BTCx3 = torch.stack(
126 | [
127 | b_idx_BxTxC.reshape(-1),
128 | t_idx_BxTxC.reshape(-1),
129 | c_idx_BxTxC.reshape(-1),
130 | ],
131 |         dim=1,
132 | ).long() # Ensure indices are long type
133 |
134 | return t_idx_BxTxC, indices_BTCx3
135 |
136 |
137 | def revert_audio_delay(
138 | audio_BxTxC: torch.Tensor,
139 | pad_value: int,
140 | precomp: tp.Tuple[torch.Tensor, torch.Tensor],
141 | T: int,
142 | ) -> torch.Tensor:
143 | """
144 | Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).
145 |
146 | Args:
147 | audio_BxTxC: Input delayed audio tensor
148 | pad_value: Padding value for out-of-bounds indices
149 | precomp: Precomputed revert indices tuple containing:
150 | - t_idx_BxTxC: Time offset indices tensor
151 | - indices_BTCx3: Gather indices tensor for original audio
152 | T: Original sequence length before padding
153 |
154 | Returns:
155 | Reverted audio tensor with same shape as input
156 | """
157 | t_idx_BxTxC, indices_BTCx3 = precomp
158 | device = audio_BxTxC.device # Get device from input tensor
159 |
160 | # Move precomputed indices to the same device as audio_BxTxC if they aren't already
161 | t_idx_BxTxC = t_idx_BxTxC.to(device)
162 | indices_BTCx3 = indices_BTCx3.to(device)
163 |
164 | # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
165 | gathered_flat = audio_BxTxC[
166 | indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
167 | ]
168 | gathered_BxTxC = gathered_flat.view(
169 | audio_BxTxC.size()
170 | ) # Use .size() for robust reshaping
171 |
172 | # Create pad_tensor on the correct device
173 | pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
174 | # Create T tensor on the correct device for comparison
175 | T_tensor = torch.tensor(T, device=device)
176 |
177 | result_BxTxC = torch.where(
178 | t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC
179 | ) # Changed np.where to torch.where
180 |
181 | return result_BxTxC
182 |
183 |
184 | @torch.no_grad()
185 | @torch.inference_mode()
186 | def decode(
187 | model,
188 | audio_codes,
189 | ):
190 | """
191 |     Decodes the given codebook frames into an output audio waveform.
192 | """
193 | if len(audio_codes) != 1:
194 | raise ValueError(f"Expected one frame, got {len(audio_codes)}")
195 |
196 | try:
197 | audio_values = model.quantizer.from_codes(audio_codes)
198 | audio_values = model.decode(audio_values[0])
199 |
200 | return audio_values
201 | except Exception as e:
202 | print(f"Error in decode method: {str(e)}")
203 | raise
204 |
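205 | 
206 | # --- Worked example (illustrative only) ---
207 | # A tiny round-trip showing how the delay pattern shifts each channel and how
208 | # the revert step undoes it. The shapes and delay values are made up for the
209 | # demo; pad/bos values mirror the DataConfig defaults (1025/1026).
210 | if __name__ == "__main__":
211 |     B, T, C = 1, 6, 3
212 |     delay = [0, 1, 2]
213 |     tokens = torch.arange(B * T * C, dtype=torch.int32).reshape(B, T, C)
214 | 
215 |     delayed = apply_audio_delay(
216 |         tokens,
217 |         pad_value=1025,
218 |         bos_value=1026,
219 |         precomp=build_delay_indices(B, T, C, delay),
220 |     )
221 |     restored = revert_audio_delay(
222 |         delayed,
223 |         pad_value=1025,
224 |         precomp=build_revert_indices(B, T, C, delay),
225 |         T=T,
226 |     )
227 |     # Channel c is shifted down by delay[c] steps (BOS fills the gap); the
228 |     # revert recovers the original values except the last delay[c] steps.
229 |     print(delayed[0])
230 |     print(restored[0])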
--------------------------------------------------------------------------------
/dia/config.py:
--------------------------------------------------------------------------------
1 | """Configuration management module for the Dia model.
2 |
3 | This module provides comprehensive configuration management for the Dia model,
4 | utilizing Pydantic for validation. It defines configurations for data processing,
5 | model architecture (encoder and decoder), and training settings.
6 |
7 | Key components:
8 | - DataConfig: Parameters for data loading and preprocessing.
9 | - EncoderConfig: Architecture details for the encoder module.
10 | - DecoderConfig: Architecture details for the decoder module.
11 | - ModelConfig: Combined model architecture settings.
12 | - TrainingConfig: Training hyperparameters and settings.
13 | - DiaConfig: Master configuration combining all components.
14 | """
15 |
16 | import os
17 | from typing import Annotated
18 |
19 | from pydantic import BaseModel, BeforeValidator, Field
20 |
21 |
22 | class DataConfig(BaseModel, frozen=True):
23 | """Configuration for data loading and preprocessing.
24 |
25 | Attributes:
26 | text_length: Maximum length of text sequences (must be multiple of 128).
27 | audio_length: Maximum length of audio sequences (must be multiple of 128).
28 | channels: Number of audio channels.
29 | text_pad_value: Value used for padding text sequences.
30 | audio_eos_value: Value representing the end of audio sequences.
31 | audio_bos_value: Value representing the beginning of audio sequences.
32 | audio_pad_value: Value used for padding audio sequences.
33 | delay_pattern: List of delay values for each audio channel.
34 | """
35 |
36 | text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
37 | Field(gt=0, multiple_of=128)
38 | )
39 | audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
40 | Field(gt=0, multiple_of=128)
41 | )
42 | channels: int = Field(default=9, gt=0, multiple_of=1)
43 | text_pad_value: int = Field(default=0)
44 | audio_eos_value: int = Field(default=1024)
45 | audio_pad_value: int = Field(default=1025)
46 | audio_bos_value: int = Field(default=1026)
47 | delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
48 | default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
49 | )
50 |
51 | def __hash__(self) -> int:
52 | """Generate a hash based on all fields of the config."""
53 | return hash(
54 | (
55 | self.text_length,
56 | self.audio_length,
57 | self.channels,
58 | self.text_pad_value,
59 | self.audio_pad_value,
60 | self.audio_bos_value,
61 | self.audio_eos_value,
62 | tuple(self.delay_pattern),
63 | )
64 | )
65 |
66 |
67 | class EncoderConfig(BaseModel, frozen=True):
68 | """Configuration for the encoder component of the Dia model.
69 |
70 | Attributes:
71 | n_layer: Number of transformer layers.
72 | n_embd: Embedding dimension.
73 | n_hidden: Hidden dimension size in the MLP layers.
74 | n_head: Number of attention heads.
75 | head_dim: Dimension per attention head.
76 | """
77 |
78 | n_layer: int = Field(gt=0)
79 | n_embd: int = Field(gt=0)
80 | n_hidden: int = Field(gt=0)
81 | n_head: int = Field(gt=0)
82 | head_dim: int = Field(gt=0)
83 |
84 |
85 | class DecoderConfig(BaseModel, frozen=True):
86 | """Configuration for the decoder component of the Dia model.
87 |
88 | Attributes:
89 | n_layer: Number of transformer layers.
90 | n_embd: Embedding dimension.
91 | n_hidden: Hidden dimension size in the MLP layers.
92 | gqa_query_heads: Number of query heads for grouped-query self-attention.
93 | kv_heads: Number of key/value heads for grouped-query self-attention.
94 | gqa_head_dim: Dimension per query head for grouped-query self-attention.
95 | cross_query_heads: Number of query heads for cross-attention.
96 | cross_head_dim: Dimension per cross-attention head.
97 | """
98 |
99 | n_layer: int = Field(gt=0)
100 | n_embd: int = Field(gt=0)
101 | n_hidden: int = Field(gt=0)
102 | gqa_query_heads: int = Field(gt=0)
103 | kv_heads: int = Field(gt=0)
104 | gqa_head_dim: int = Field(gt=0)
105 | cross_query_heads: int = Field(gt=0)
106 | cross_head_dim: int = Field(gt=0)
107 |
108 |
109 | class ModelConfig(BaseModel, frozen=True):
110 | """Main configuration container for the Dia model architecture.
111 |
112 | Attributes:
113 | encoder: Configuration for the encoder component.
114 | decoder: Configuration for the decoder component.
115 | src_vocab_size: Size of the source (text) vocabulary.
116 | tgt_vocab_size: Size of the target (audio code) vocabulary.
117 | dropout: Dropout probability applied within the model.
118 | normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
119 | weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
120 | rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
121 | rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
122 | """
123 |
124 | encoder: EncoderConfig
125 | decoder: DecoderConfig
126 | src_vocab_size: int = Field(default=128, gt=0)
127 | tgt_vocab_size: int = Field(default=1028, gt=0)
128 | dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
129 | normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
130 | weight_dtype: str = Field(default="float32", description="Weight precision")
131 | rope_min_timescale: int = Field(
132 | default=1, description="Timescale For global Attention"
133 | )
134 | rope_max_timescale: int = Field(
135 | default=10_000, description="Timescale For global Attention"
136 | )
137 |
138 |
139 | class TrainingConfig(BaseModel, frozen=True):
140 | pass
141 |
142 |
143 | class DiaConfig(BaseModel, frozen=True):
144 | """Master configuration for the Dia model.
145 |
146 | Combines all sub-configurations into a single validated object.
147 |
148 | Attributes:
149 | version: Configuration version string.
150 | model: Model architecture configuration.
151 | training: Training process configuration (precision settings).
152 | data: Data loading and processing configuration.
153 | """
154 |
155 | version: str = Field(default="1.0")
156 | model: ModelConfig
157 | # TODO: remove training. this is just for backward compatibility
158 | training: TrainingConfig | None = Field(default=None)
159 | data: DataConfig
160 |
161 | def save(self, path: str) -> None:
162 | """Save the current configuration instance to a JSON file.
163 |
164 |         Ensures the parent directory exists before writing the JSON file.
165 |
166 | Args:
167 | path: The target file path to save the configuration.
168 |
169 |         Raises:
170 |             OSError: If the parent directory cannot be created or the file cannot be written.
171 | """
172 | os.makedirs(os.path.dirname(path), exist_ok=True)
173 | config_json = self.model_dump_json(indent=2)
174 | with open(path, "w") as f:
175 | f.write(config_json)
176 |
177 | @classmethod
178 | def load(cls, path: str) -> "DiaConfig | None":
179 | """Load and validate a Dia configuration from a JSON file.
180 |
181 | Args:
182 | path: The path to the configuration file.
183 |
184 | Returns:
185 | A validated DiaConfig instance if the file exists and is valid,
186 | otherwise None if the file is not found.
187 |
188 |         Raises:
189 |             pydantic.ValidationError: If the JSON content fails validation against
190 |                 the DiaConfig schema. A missing file does not raise; None is returned instead.
191 | """
192 | try:
193 | with open(path, "r") as f:
194 | content = f.read()
195 | return cls.model_validate_json(content)
196 | except FileNotFoundError:
197 | return None
198 |
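199 | 
200 | # --- Usage sketch (illustrative only) ---
201 | # Builds a deliberately tiny DiaConfig (the layer/head sizes below are made
202 | # up, not the released architecture) and round-trips it through save()/load().
203 | if __name__ == "__main__":
204 |     import tempfile
205 | 
206 |     cfg = DiaConfig(
207 |         model=ModelConfig(
208 |             encoder=EncoderConfig(
209 |                 n_layer=2, n_embd=64, n_hidden=128, n_head=4, head_dim=16
210 |             ),
211 |             decoder=DecoderConfig(
212 |                 n_layer=2,
213 |                 n_embd=64,
214 |                 n_hidden=128,
215 |                 gqa_query_heads=4,
216 |                 kv_heads=2,
217 |                 gqa_head_dim=16,
218 |                 cross_query_heads=4,
219 |                 cross_head_dim=16,
220 |             ),
221 |         ),
222 |         # text_length/audio_length are rounded up to multiples of 128 (-> 128, 1024)
223 |         data=DataConfig(text_length=100, audio_length=1000),
224 |     )
225 |     path = os.path.join(tempfile.mkdtemp(), "config.json")
226 |     cfg.save(path)
227 |     reloaded = DiaConfig.load(path)
228 |     assert reloaded is not None and reloaded.data.text_length == 128
229 |     print(f"Round-tripped DiaConfig v{reloaded.version} via {path}")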
--------------------------------------------------------------------------------
/dia/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from huggingface_hub import PyTorchModelHubMixin
5 | from torch import Tensor
6 | from torch.nn import RMSNorm
7 |
8 | from .config import DiaConfig
9 | from .state import DecoderInferenceState, EncoderInferenceState, KVCache
10 |
11 |
12 | def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
13 | return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
14 |
15 |
16 | class DenseGeneral(nn.Module):
17 | """
18 | PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
19 |
20 |     Stores the kernel (`weight`) in the same layout as JAX and uses
21 |     torch.tensordot for the generalized matrix multiplication. The kernel
22 |     shape is computed from `in_shapes`, `out_features`, and `axis`, and the
23 |     parameter is created (uninitialized) in __init__; weights are populated
24 |     later, e.g. when a checkpoint is loaded via load_state_dict.
25 | 
26 |     Attributes:
27 |         axis (Tuple[int, ...]): Input axis or axes to contract.
28 |         in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
29 |         out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
30 |         kernel_shape (Tuple[int, ...]): Full kernel shape, `in_shapes + out_features`.
31 |         weight (nn.Parameter): The kernel parameter.
32 | """
33 |
34 | def __init__(
35 | self,
36 | in_shapes: tuple[int, ...],
37 | out_features: tuple[int, ...],
38 | axis: tuple[int, ...] = (-1,),
39 | weight_dtype: torch.dtype | None = None,
40 | device: torch.device | None = None,
41 | ):
42 | super().__init__()
43 | self.in_shapes = in_shapes
44 | self.out_features = out_features
45 | self.axis = axis
46 | self.kernel_shape = self.in_shapes + self.out_features
47 |
48 | factory_kwargs = {"device": device, "dtype": weight_dtype}
49 | self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
50 |
51 | def forward(self, inputs: Tensor) -> Tensor:
52 | norm_axis = _normalize_axes(self.axis, inputs.ndim)
53 | kernel_contract_axes = tuple(range(len(norm_axis)))
54 |
55 | output = torch.tensordot(
56 | inputs.to(self.weight.dtype),
57 | self.weight,
58 | dims=(norm_axis, kernel_contract_axes),
59 | ).to(inputs.dtype)
60 | return output
61 |
62 |
63 | class MlpBlock(nn.Module):
64 | """MLP block using DenseGeneral."""
65 |
66 | def __init__(
67 | self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype
68 | ):
69 | super().__init__()
70 | self.dtype = compute_dtype
71 |
72 | self.wi_fused = DenseGeneral(
73 | in_shapes=(embed_dim,),
74 | out_features=(2, intermediate_dim),
75 | axis=(-1,),
76 | weight_dtype=compute_dtype,
77 | )
78 |
79 | self.wo = DenseGeneral(
80 | in_shapes=(intermediate_dim,),
81 | out_features=(embed_dim,),
82 | axis=(-1,),
83 | weight_dtype=compute_dtype,
84 | )
85 |
86 | def forward(self, x: torch.Tensor) -> torch.Tensor:
87 | """Forward pass."""
88 | fused_x = self.wi_fused(x)
89 |
90 | gate = fused_x[..., 0, :]
91 | up = fused_x[..., 1, :]
92 |
93 | hidden = torch.mul(F.silu(gate), up).to(self.dtype)
94 |
95 | output = self.wo(hidden)
96 | return output
97 |
98 |
99 | class RotaryEmbedding(nn.Module):
100 | """Rotary Position Embedding (RoPE) implementation in PyTorch."""
101 |
102 | def __init__(
103 | self,
104 | embedding_dims: int,
105 | min_timescale: int = 1,
106 | max_timescale: int = 10000,
107 | dtype: torch.dtype = torch.float32,
108 | ):
109 | super().__init__()
110 | if embedding_dims % 2 != 0:
111 | raise ValueError("Embedding dim must be even for RoPE.")
112 | self.embedding_dims = embedding_dims
113 | self.min_timescale = min_timescale
114 | self.max_timescale = max_timescale
115 | self.compute_dtype = dtype
116 |
117 | half_embedding_dim = embedding_dims // 2
118 | fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
119 | timescale = (
120 | self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
121 | ).to(torch.float32)
122 | self.register_buffer("timescale", timescale, persistent=False)
123 |
124 | def forward(self, inputs: torch.Tensor, position: torch.Tensor):
125 | """Applies RoPE."""
126 | position = position.unsqueeze(-1).unsqueeze(-1)
127 | sinusoid_inp = position / self.timescale
128 | sin = torch.sin(sinusoid_inp)
129 | cos = torch.cos(sinusoid_inp)
130 | first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
131 | first_part = first_half * cos - second_half * sin
132 | second_part = second_half * cos + first_half * sin
133 | return torch.cat(
134 | (first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)),
135 | dim=-1,
136 | )
137 |
138 |
139 | class Attention(nn.Module):
140 | """Attention using DenseGeneral."""
141 |
142 | def __init__(
143 | self,
144 | config: DiaConfig,
145 | q_embed_dim: int,
146 | kv_embed_dim: int,
147 | num_query_heads: int,
148 | num_kv_heads: int,
149 | head_dim: int,
150 | compute_dtype: torch.dtype,
151 | is_cross_attn: bool = False,
152 | out_embed_dim: int | None = None,
153 | ):
154 | super().__init__()
155 | self.num_query_heads = num_query_heads
156 | self.num_kv_heads = num_kv_heads
157 | self.head_dim = head_dim
158 | self.is_cross_attn = is_cross_attn
159 | self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
160 | self.projected_query_dim = num_query_heads * head_dim
161 | if num_query_heads % num_kv_heads != 0:
162 | raise ValueError(
163 | f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
164 | )
165 | self.num_gqa_groups = num_query_heads // num_kv_heads
166 |
167 | # --- Projection Layers using DenseGeneral ---
168 | self.q_proj = DenseGeneral(
169 | in_shapes=(q_embed_dim,),
170 | out_features=(num_query_heads, head_dim),
171 | axis=(-1,),
172 | weight_dtype=compute_dtype,
173 | )
174 | self.k_proj = DenseGeneral(
175 | in_shapes=(kv_embed_dim,),
176 | out_features=(num_kv_heads, head_dim),
177 | axis=(-1,),
178 | weight_dtype=compute_dtype,
179 | )
180 | self.v_proj = DenseGeneral(
181 | in_shapes=(kv_embed_dim,),
182 | out_features=(num_kv_heads, head_dim),
183 | axis=(-1,),
184 | weight_dtype=compute_dtype,
185 | )
186 | self.o_proj = DenseGeneral(
187 | in_shapes=(num_query_heads, head_dim),
188 | out_features=(self.output_dim,),
189 | axis=(-2, -1),
190 | weight_dtype=compute_dtype,
191 | )
192 |
193 | # --- Rotary Embedding ---
194 | self.rotary_emb = RotaryEmbedding(
195 | embedding_dims=self.head_dim,
196 | min_timescale=config.model.rope_min_timescale,
197 | max_timescale=config.model.rope_max_timescale,
198 | dtype=compute_dtype,
199 | )
200 |
201 | def forward(
202 | self,
203 | Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
204 | Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
205 | q_positions: torch.Tensor, # (B, T)
206 | kv_positions: torch.Tensor | None = None, # (B, S)
207 | attn_mask: (
208 | torch.Tensor | None
209 | ) = None, # None in Decoder Self Attention, Valid mask in Others
210 | cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
211 | prefill: bool = False,
212 | is_causal: bool = False,
213 |     ) -> torch.Tensor:
214 | """
215 | Performs attention calculation with optional KV caching.
216 |
217 | Args:
218 | Xq: Query tensor (B, T, D). T=1 during single-step decoding.
219 | Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
220 | q_positions: Positions for queries (B, T).
221 | kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
222 | attn_mask: Attention mask.
223 |             cache: KVCache. Updated in place for self-attention; read as-is for cross-attention.
224 |             prefill: If True, use prefill mode (write the full K/V sequence into the cache).
225 |             is_causal: If True, apply a causal mask within scaled_dot_product_attention.
226 | 
227 |         Returns:
228 |             output: The attention output tensor (B, T, output_dim). Updated K/V
229 |                 state lives in the provided KVCache rather than being returned.
230 |         """
231 | if kv_positions is None:
232 | kv_positions = q_positions
233 | original_dtype = Xq.dtype
234 |
235 | Xq_BxTxNxH = self.q_proj(Xq)
236 | Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
237 | Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
238 |
239 | attn_k: torch.Tensor | None = None
240 | attn_v: torch.Tensor | None = None
241 |
242 | if self.is_cross_attn:
243 | attn_k, attn_v = cache.k, cache.v
244 | else:
245 | Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
246 | Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
247 | Xk_BxSxKxH = self.rotary_emb(
248 | Xk_BxSxKxH, position=kv_positions
249 | ) # (B, S, K, H)
250 |
251 | Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
252 | Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
253 |
254 | if cache is None:
255 | attn_k = Xk_BxKxSxH
256 | attn_v = Xv_BxKxSxH
257 | else:
258 | if prefill:
259 | attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
260 | cache.prefill(attn_k, attn_v)
261 | else:
262 | attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)
263 |
264 | if attn_k is not None and attn_v is not None:
265 | attn_k = attn_k.to(Xq_BxNxTxH.dtype)
266 | attn_v = attn_v.to(Xq_BxNxTxH.dtype)
267 |
268 | attn_output = F.scaled_dot_product_attention(
269 | Xq_BxNxTxH,
270 | attn_k,
271 | attn_v,
272 | attn_mask=attn_mask,
273 | scale=1.0,
274 | enable_gqa=self.num_gqa_groups > 1,
275 | is_causal=is_causal,
276 | )
277 |
278 | attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
279 | output = self.o_proj(attn_output)
280 |
281 | return output.to(original_dtype)
282 |
283 |
284 | class EncoderLayer(nn.Module):
285 | """Transformer Encoder Layer using DenseGeneral."""
286 |
287 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
288 | super().__init__()
289 | self.config = config
290 | model_config = config.model
291 | enc_config = config.model.encoder
292 | embed_dim = enc_config.n_embd
293 | self.compute_dtype = compute_dtype
294 |
295 | self.pre_sa_norm = RMSNorm(
296 | embed_dim,
297 | eps=model_config.normalization_layer_epsilon,
298 | dtype=torch.float32,
299 | )
300 | self.self_attention = Attention(
301 | config,
302 | q_embed_dim=embed_dim,
303 | kv_embed_dim=embed_dim,
304 | num_query_heads=enc_config.n_head,
305 | num_kv_heads=enc_config.n_head,
306 | head_dim=enc_config.head_dim,
307 | compute_dtype=compute_dtype,
308 | is_cross_attn=False,
309 | out_embed_dim=embed_dim,
310 | )
311 | self.post_sa_norm = RMSNorm(
312 | embed_dim,
313 | eps=model_config.normalization_layer_epsilon,
314 | dtype=torch.float32,
315 | )
316 | self.mlp = MlpBlock(
317 | embed_dim=embed_dim,
318 | intermediate_dim=enc_config.n_hidden,
319 | compute_dtype=compute_dtype,
320 | )
321 |
322 | def forward(
323 | self,
324 | x: torch.Tensor,
325 | state: EncoderInferenceState,
326 | ) -> torch.Tensor:
327 | residual = x
328 | x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
329 |
330 | sa_out = self.self_attention(
331 | Xq=x_norm,
332 | Xkv=x_norm,
333 | q_positions=state.positions,
334 | kv_positions=state.positions,
335 | attn_mask=state.attn_mask,
336 | )
337 | x = residual + sa_out
338 |
339 | residual = x
340 | x_norm = self.post_sa_norm(x).to(self.compute_dtype)
341 | mlp_out = self.mlp(x_norm)
342 | x = residual + mlp_out
343 |
344 | return x
345 |
346 |
347 | class Encoder(nn.Module):
348 | """Transformer Encoder Stack using DenseGeneral."""
349 |
350 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
351 | super().__init__()
352 | self.config = config
353 | model_config = config.model
354 | enc_config = config.model.encoder
355 | self.compute_dtype = compute_dtype
356 |
357 | self.embedding = nn.Embedding(
358 | model_config.src_vocab_size,
359 | enc_config.n_embd,
360 | dtype=compute_dtype,
361 | )
362 | self.layers = nn.ModuleList(
363 | [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)]
364 | )
365 | self.norm = RMSNorm(
366 | enc_config.n_embd,
367 | eps=model_config.normalization_layer_epsilon,
368 | dtype=torch.float32,
369 | )
370 |
371 | def forward(
372 | self,
373 | x_ids: torch.Tensor,
374 | state: EncoderInferenceState,
375 | ) -> torch.Tensor:
376 | x = self.embedding(x_ids)
377 |
378 | for layer in self.layers:
379 | x = layer(x, state)
380 |
381 | x = self.norm(x).to(self.compute_dtype)
382 | return x
383 |
384 |
385 | class DecoderLayer(nn.Module):
386 | """Transformer Decoder Layer using DenseGeneral."""
387 |
388 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
389 | super().__init__()
390 | self.config = config
391 | model_config = config.model
392 | dec_config = config.model.decoder
393 | enc_config = config.model.encoder
394 | dec_embed_dim = dec_config.n_embd
395 | enc_embed_dim = enc_config.n_embd
396 | self.compute_dtype = compute_dtype
397 |
398 | # Norms
399 | self.pre_sa_norm = RMSNorm(
400 | dec_embed_dim,
401 | eps=model_config.normalization_layer_epsilon,
402 | dtype=torch.float32,
403 | )
404 | self.pre_ca_norm = RMSNorm(
405 | dec_embed_dim,
406 | eps=model_config.normalization_layer_epsilon,
407 | dtype=torch.float32,
408 | )
409 | self.pre_mlp_norm = RMSNorm(
410 | dec_embed_dim,
411 | eps=model_config.normalization_layer_epsilon,
412 | dtype=torch.float32,
413 | )
414 |
415 | # Self-Attention (GQA) with Causal Masking
416 | self.self_attention = Attention(
417 | config,
418 | q_embed_dim=dec_embed_dim,
419 | kv_embed_dim=dec_embed_dim,
420 | num_query_heads=dec_config.gqa_query_heads,
421 | num_kv_heads=dec_config.kv_heads,
422 | head_dim=dec_config.gqa_head_dim,
423 | compute_dtype=compute_dtype,
424 | is_cross_attn=False,
425 | out_embed_dim=dec_embed_dim,
426 | )
427 | # Cross-Attention (MHA)
428 | self.cross_attention = Attention(
429 | config=config,
430 | q_embed_dim=dec_embed_dim,
431 | kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
432 | num_query_heads=dec_config.cross_query_heads,
433 | num_kv_heads=dec_config.cross_query_heads,
434 | head_dim=dec_config.cross_head_dim,
435 | compute_dtype=compute_dtype,
436 | is_cross_attn=True,
437 | out_embed_dim=dec_embed_dim,
438 | )
439 | # MLP
440 | self.mlp = MlpBlock(
441 | embed_dim=dec_embed_dim,
442 | intermediate_dim=dec_config.n_hidden,
443 | compute_dtype=compute_dtype,
444 | )
445 |
446 | def forward(
447 | self,
448 | x: torch.Tensor,
449 | state: DecoderInferenceState,
450 | self_attn_cache: KVCache | None = None,
451 | cross_attn_cache: KVCache | None = None,
452 | prefill: bool = False,
453 | ) -> torch.Tensor:
454 | residual = x
455 | x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
456 |
457 | sa_out = self.self_attention(
458 | Xq=x_norm, # (2, 1, D)
459 | Xkv=x_norm, # (2, 1, D)
460 | q_positions=state.dec_positions, # (2, 1)
461 | kv_positions=state.dec_positions, # (2, 1)
462 | attn_mask=None,
463 | cache=self_attn_cache,
464 | prefill=prefill,
465 | is_causal=prefill,
466 | )
467 |
468 | x = residual + sa_out
469 |
470 | residual = x
471 | x_norm = self.pre_ca_norm(x).to(self.compute_dtype)
472 | ca_out = self.cross_attention(
473 | Xq=x_norm,
474 | Xkv=state.enc_out,
475 | q_positions=state.dec_positions,
476 | kv_positions=state.enc_positions,
477 | attn_mask=state.dec_cross_attn_mask,
478 | cache=cross_attn_cache,
479 | )
480 | x = residual + ca_out
481 |
482 | residual = x
483 | x_norm = self.pre_mlp_norm(x).to(self.compute_dtype)
484 | mlp_out = self.mlp(x_norm)
485 | x = residual + mlp_out
486 |
487 | return x
488 |
489 |
490 | class Decoder(nn.Module):
491 | """Transformer Decoder Stack using DenseGeneral."""
492 |
493 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
494 | super().__init__()
495 | self.config = config
496 | model_config = config.model
497 | dec_config = config.model.decoder
498 | data_config = config.data
499 | self.num_channels = data_config.channels
500 | self.num_layers = dec_config.n_layer
501 |
502 | self.embeddings = nn.ModuleList(
503 | [
504 | nn.Embedding(
505 | model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
506 | )
507 | for _ in range(self.num_channels)
508 | ]
509 | )
510 | self.layers = nn.ModuleList(
511 | [
512 | DecoderLayer(config=config, compute_dtype=compute_dtype)
513 | for _ in range(self.num_layers)
514 | ]
515 | )
516 |
517 | self.norm = RMSNorm(
518 | dec_config.n_embd,
519 | eps=model_config.normalization_layer_epsilon,
520 | dtype=torch.float32,
521 | )
522 |
523 | self.logits_dense = DenseGeneral(
524 | in_shapes=(dec_config.n_embd,),
525 | out_features=(self.num_channels, model_config.tgt_vocab_size),
526 | axis=(-1,),
527 | weight_dtype=compute_dtype,
528 | )
529 |
530 | def precompute_cross_attn_cache(
531 | self,
532 | enc_out: torch.Tensor, # (B, S, E)
533 | enc_positions: torch.Tensor, # (B, S)
534 | ) -> list[KVCache]:
535 | """
536 | Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
537 | """
538 | per_layer_kv_cache: list[KVCache] = []
539 |
540 | for layer in self.layers:
541 | cross_attn_module = layer.cross_attention
542 | k_proj = cross_attn_module.k_proj(enc_out)
543 | v_proj = cross_attn_module.v_proj(enc_out)
544 |
545 | k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
546 | k = k_proj.transpose(1, 2)
547 | v = v_proj.transpose(1, 2)
548 |
549 | per_layer_kv_cache.append(KVCache.from_kv(k, v))
550 |
551 | return per_layer_kv_cache
552 |
553 | def decode_step(
554 | self,
555 | tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
556 | state: DecoderInferenceState,
557 | ) -> torch.Tensor:
558 | """
559 | Performs a single decoding step, managing KV caches layer by layer.
560 |
561 |         Returns:
562 |             logits_Bx1xCxV: The output logits for the current step (B, 1, C, V),
563 |                 cast to float32.
564 |         """
565 |
566 | x = None
567 | for i in range(self.num_channels):
568 | channel_tokens = tgt_ids_Bx1xC[..., i]
569 | channel_embed = self.embeddings[i](channel_tokens)
570 | x = channel_embed if x is None else x + channel_embed
571 |
572 | for i, layer in enumerate(self.layers):
573 | self_cache = state.self_attn_cache[i]
574 | cross_cache = state.cross_attn_cache[i]
575 | x = layer(
576 | x, # (2, 1, D)
577 | state,
578 | self_attn_cache=self_cache,
579 | cross_attn_cache=cross_cache,
580 | )
581 |
582 | x = self.norm(x)
583 | logits_Bx1xCxV = self.logits_dense(x)
584 |
585 | return logits_Bx1xCxV.to(torch.float32)
586 |
587 | def forward(
588 | self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState
589 | ) -> torch.Tensor:
590 |         """
591 |         Forward pass for the Decoder stack, managing KV caches.
592 | 
593 |         Used to prefill the per-layer self-attention caches from an audio
594 |         prompt: each layer runs with prefill=True so the full target sequence
595 |         is written into its KV cache. The encoder output, positions, and masks
596 |         are not passed separately; they live in the DecoderInferenceState.
597 | 
598 |         Args:
599 |             tgt_ids_BxTxC: Target token IDs (B, T, C).
600 |             state: DecoderInferenceState holding:
601 |                 - enc_out: Output from the encoder (B, S, E).
602 |                 - dec_positions / enc_positions: Positions for the target (B, T)
603 |                   and source (B, S) sequences.
604 |                 - dec_cross_attn_mask: Mask for cross-attention.
605 |                 - self_attn_cache: Per-layer self-attention KV caches, filled in
606 |                   place by this call (one per decoder layer).
607 |                 - cross_attn_cache: Per-layer pre-computed cross-attention K/V
608 |                   derived from the encoder output.
609 | 
610 |         Returns:
611 |             logits_BxTxCxV: The output logits (B, T, C, V), cast to float32.
612 |         """
613 | _, _, num_channels_in = tgt_ids_BxTxC.shape
614 | assert num_channels_in == self.num_channels, "Input channels mismatch"
615 |
616 | # Embeddings
617 | x = None
618 | for i in range(self.num_channels):
619 | channel_tokens = tgt_ids_BxTxC[..., i]
620 | channel_embed = self.embeddings[i](channel_tokens)
621 | x = channel_embed if x is None else x + channel_embed
622 |
623 | for i, layer in enumerate(self.layers):
624 | self_cache = state.self_attn_cache[i]
625 | cross_cache = state.cross_attn_cache[i]
626 | x = layer(
627 | x,
628 | state,
629 | self_attn_cache=self_cache,
630 | cross_attn_cache=cross_cache,
631 | prefill=True,
632 | )
633 |
634 | # Final Norm
635 | x = self.norm(x)
636 | logits_BxTxCxV = self.logits_dense(x)
637 |
638 | return logits_BxTxCxV.to(torch.float32)
639 |
640 |
641 | class DiaModel(
642 | nn.Module,
643 | PyTorchModelHubMixin,
644 | repo_url="https://github.com/nari-labs/dia",
645 | pipeline_tag="text-to-speech",
646 | license="apache-2.0",
647 | coders={
648 | DiaConfig: (
649 | lambda x: x.model_dump(),
650 | lambda data: DiaConfig.model_validate(data),
651 | ),
652 | },
653 | ):
654 | """PyTorch Dia Model using DenseGeneral."""
655 |
656 | def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
657 | super().__init__()
658 | self.config = config
659 | self.encoder = Encoder(config, compute_dtype)
660 | self.decoder = Decoder(config, compute_dtype)
661 |
--------------------------------------------------------------------------------
/dia/model.py:
--------------------------------------------------------------------------------
1 | import time
2 | from enum import Enum
3 |
4 | import dac
5 | import numpy as np
6 | import torch
7 | import torchaudio
8 |
9 | from .audio import (
10 | apply_audio_delay,
11 | build_delay_indices,
12 | build_revert_indices,
13 | decode,
14 | revert_audio_delay,
15 | )
16 | from .config import DiaConfig
17 | from .layers import DiaModel
18 | from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
19 |
20 |
21 | DEFAULT_SAMPLE_RATE = 44100
22 |
23 |
24 | def _get_default_device():
25 | if torch.cuda.is_available():
26 | return torch.device("cuda")
27 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
28 | return torch.device("mps")
29 | return torch.device("cpu")
30 |
31 |
32 | def _sample_next_token(
33 | logits_BCxV: torch.Tensor,
34 | temperature: float,
35 | top_p: float,
36 | cfg_filter_top_k: int | None = None,
37 |     generator: torch.Generator | None = None,  # Optional generator for reproducible sampling
38 | ) -> torch.Tensor:
39 | if temperature == 0.0:
40 | return torch.argmax(logits_BCxV, dim=-1)
41 |
42 | logits_BCxV = logits_BCxV / temperature
43 | if cfg_filter_top_k is not None:
44 | _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
45 | mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
46 | mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
47 | logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
48 |
49 | if top_p < 1.0:
50 | probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
51 | sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
52 | probs_BCxV, dim=-1, descending=True
53 | )
54 | cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
55 |
56 | sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
57 | sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
58 | ..., :-1
59 | ].clone()
60 | sorted_indices_to_remove_BCxV[..., 0] = 0
61 |
62 | indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
63 | indices_to_remove_BCxV.scatter_(
64 | dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
65 | )
66 | logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
67 |
68 | final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
69 |
70 | sampled_indices_BC = torch.multinomial(
71 | final_probs_BCxV,
72 | num_samples=1,
73 | generator=generator, # Pass generator to multinomial
74 | )
75 | sampled_indices_C = sampled_indices_BC.squeeze(-1)
76 | return sampled_indices_C
77 |
78 |
79 | class ComputeDtype(str, Enum):
80 | FLOAT32 = "float32"
81 | FLOAT16 = "float16"
82 | BFLOAT16 = "bfloat16"
83 |
84 | def to_dtype(self) -> torch.dtype:
85 | if self == ComputeDtype.FLOAT32:
86 | return torch.float32
87 | elif self == ComputeDtype.FLOAT16:
88 | return torch.float16
89 | elif self == ComputeDtype.BFLOAT16:
90 | return torch.bfloat16
91 | else:
92 | raise ValueError(f"Unsupported compute dtype: {self}")
93 |
94 |
95 | class Dia:
96 | def __init__(
97 | self,
98 | config: DiaConfig,
99 | compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
100 | device: torch.device | None = None,
101 | ):
102 | """Initializes the Dia model.
103 |
104 | Args:
105 | config: The configuration object for the model.
106 | device: The device to load the model onto. If None, will automatically select the best available device.
107 |
108 | Raises:
109 | RuntimeError: If there is an error loading the DAC model.
110 | """
111 | super().__init__()
112 | self.config = config
113 | self.device = device if device is not None else _get_default_device()
114 | if isinstance(compute_dtype, str):
115 | compute_dtype = ComputeDtype(compute_dtype)
116 | self.compute_dtype = compute_dtype.to_dtype()
117 | self.model = DiaModel(config, self.compute_dtype)
118 | self.dac_model = None
119 |
120 | @classmethod
121 | def from_local(
122 | cls,
123 | config_path: str,
124 | checkpoint_path: str,
125 | compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
126 | device: torch.device | None = None,
127 | ) -> "Dia":
128 | """Loads the Dia model from local configuration and checkpoint files.
129 |
130 | Args:
131 | config_path: Path to the configuration JSON file.
132 | checkpoint_path: Path to the model checkpoint (.pth) file.
133 | device: The device to load the model onto. If None, will automatically select the best available device.
134 |
135 | Returns:
136 | An instance of the Dia model loaded with weights and set to eval mode.
137 |
138 | Raises:
139 | FileNotFoundError: If the config or checkpoint file is not found.
140 | RuntimeError: If there is an error loading the checkpoint.
141 | """
142 | config = DiaConfig.load(config_path)
143 | if config is None:
144 | raise FileNotFoundError(f"Config file not found at {config_path}")
145 |
146 | dia = cls(config, compute_dtype, device)
147 |
148 | try:
149 | state_dict = torch.load(checkpoint_path, map_location=dia.device)
150 | dia.model.load_state_dict(state_dict)
151 | except FileNotFoundError:
152 | raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
153 | except Exception as e:
154 | raise RuntimeError(
155 | f"Error loading checkpoint from {checkpoint_path}"
156 | ) from e
157 |
158 | dia.model.to(dia.device)
159 | dia.model.eval()
160 | dia._load_dac_model()
161 | return dia
162 |
163 | @classmethod
164 | def from_pretrained(
165 | cls,
166 | model_name: str = "nari-labs/Dia-1.6B",
167 | compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
168 | device: torch.device | None = None,
169 | ) -> "Dia":
170 | """Loads the Dia model from a Hugging Face Hub repository.
171 |
172 | Downloads the configuration and checkpoint files from the specified
173 | repository ID and then loads the model.
174 |
175 | Args:
176 | model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
177 | compute_dtype: The computation dtype to use.
178 | device: The device to load the model onto. If None, will automatically select the best available device.
179 |
180 | Returns:
181 | An instance of the Dia model loaded with weights and set to eval mode.
182 |
183 | Raises:
184 | FileNotFoundError: If config or checkpoint download/loading fails.
185 | RuntimeError: If there is an error loading the checkpoint.
186 | """
187 | if isinstance(compute_dtype, str):
188 | compute_dtype = ComputeDtype(compute_dtype)
189 | loaded_model = DiaModel.from_pretrained(
190 | model_name, compute_dtype=compute_dtype.to_dtype()
191 | )
192 | config = loaded_model.config
193 | dia = cls(config, compute_dtype, device)
194 |
195 | dia.model = loaded_model
196 | dia.model.to(dia.device)
197 | dia.model.eval()
198 | dia._load_dac_model()
199 | return dia
200 |
201 | def _load_dac_model(self):
202 | try:
203 | dac_model_path = dac.utils.download()
204 | dac_model = dac.DAC.load(dac_model_path).to(self.device)
205 | except Exception as e:
206 | raise RuntimeError("Failed to load DAC model") from e
207 | self.dac_model = dac_model
208 |
209 | def _prepare_text_input(self, text: str) -> torch.Tensor:
210 | """Encodes text prompt, pads, and creates attention mask and positions."""
211 | text_pad_value = self.config.data.text_pad_value
212 | max_len = self.config.data.text_length
213 |
214 | byte_text = text.encode("utf-8")
215 | replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
216 | text_tokens = list(replaced_bytes)
217 |
218 | current_len = len(text_tokens)
219 | padding_needed = max_len - current_len
220 | if padding_needed <= 0:
221 | text_tokens = text_tokens[:max_len]
222 | padded_text_np = np.array(text_tokens, dtype=np.uint8)
223 | else:
224 | padded_text_np = np.pad(
225 | text_tokens,
226 | (0, padding_needed),
227 | mode="constant",
228 | constant_values=text_pad_value,
229 | ).astype(np.uint8)
230 |
231 | src_tokens = (
232 | torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
233 | ) # [1, S]
234 | return src_tokens
235 |
236 | def _prepare_audio_prompt(
237 | self, audio_prompt: torch.Tensor | None
238 | ) -> tuple[torch.Tensor, int]:
239 | num_channels = self.config.data.channels
240 | audio_bos_value = self.config.data.audio_bos_value
241 | audio_pad_value = self.config.data.audio_pad_value
242 | delay_pattern = self.config.data.delay_pattern
243 | max_delay_pattern = max(delay_pattern)
244 |
245 | prefill = torch.full(
246 | (1, num_channels),
247 | fill_value=audio_bos_value,
248 | dtype=torch.int,
249 | device=self.device,
250 | )
251 |
252 | prefill_step = 1
253 |
254 | if audio_prompt is not None:
255 | prefill_step += audio_prompt.shape[0]
256 | prefill = torch.cat([prefill, audio_prompt], dim=0)
257 |
258 | delay_pad_tensor = torch.full(
259 | (max_delay_pattern, num_channels),
260 | fill_value=-1,
261 | dtype=torch.int,
262 | device=self.device,
263 | )
264 | prefill = torch.cat([prefill, delay_pad_tensor], dim=0)
265 |
266 | delay_precomp = build_delay_indices(
267 | B=1,
268 | T=prefill.shape[0],
269 | C=num_channels,
270 | delay_pattern=delay_pattern,
271 | )
272 |
273 | prefill = apply_audio_delay(
274 | audio_BxTxC=prefill.unsqueeze(0),
275 | pad_value=audio_pad_value,
276 | bos_value=audio_bos_value,
277 | precomp=delay_precomp,
278 | ).squeeze(0)
279 |
280 | return prefill, prefill_step
281 |
282 | def _prepare_generation(
283 | self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
284 | ):
285 | enc_input_cond = self._prepare_text_input(text)
286 | enc_input_uncond = torch.zeros_like(enc_input_cond)
287 | enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)
288 |
289 | if isinstance(audio_prompt, str):
290 | audio_prompt = self.load_audio(audio_prompt)
291 | prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)
292 |
293 | if verbose:
294 | print("generate: data loaded")
295 |
296 | enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
297 | encoder_out = self.model.encoder(enc_input, enc_state)
298 |
299 | # Clean up inputs after encoding
300 | del enc_input_uncond, enc_input
301 |
302 | dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
303 | encoder_out, enc_state.positions
304 | )
305 | dec_state = DecoderInferenceState.new(
306 | self.config,
307 | enc_state,
308 | encoder_out,
309 | dec_cross_attn_cache,
310 | self.compute_dtype,
311 | )
312 |
313 | # Can delete encoder output after it's used by decoder state
314 | del dec_cross_attn_cache
315 |
316 | dec_output = DecoderOutput.new(self.config, self.device)
317 | dec_output.prefill(prefill, prefill_step)
318 |
319 | dec_step = prefill_step - 1
320 | if dec_step > 0:
321 | dec_state.prepare_step(0, dec_step)
322 | tokens_BxTxC = (
323 | dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
324 | )
325 | self.model.decoder.forward(tokens_BxTxC, dec_state)
326 |
327 | return dec_state, dec_output
328 |
329 | def _decoder_step(
330 | self,
331 | tokens_Bx1xC: torch.Tensor,
332 | dec_state: DecoderInferenceState,
333 | cfg_scale: float,
334 | temperature: float,
335 | top_p: float,
336 | cfg_filter_top_k: int,
337 |         generator: torch.Generator | None = None,  # optional RNG for reproducible sampling
338 | ) -> torch.Tensor:
339 | audio_eos_value = self.config.data.audio_eos_value
340 | logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)
341 |
342 | logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
343 |         # Free the full logits tensor
344 | del logits_Bx1xCxV
345 |
346 | uncond_logits_CxV = logits_last_BxCxV[0, :, :]
347 | cond_logits_CxV = logits_last_BxCxV[1, :, :]
348 |         # Free the combined last-step logits
349 | del logits_last_BxCxV
350 |
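        # Classifier-free guidance: extrapolate from the unconditional logits toward the
        # conditional ones. cfg_scale=0 reproduces the conditional logits; larger values push
        # the distribution further toward the text-conditioned prediction.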
351 | logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
352 |         # Free the component logits
353 | del uncond_logits_CxV, cond_logits_CxV
354 |
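        # Mask everything above the EOS token id, and additionally forbid EOS on all channels
        # except the first, so end-of-audio can only be signalled on channel 0.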
355 | logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
356 | logits_CxV[1:, audio_eos_value:] = -torch.inf
357 |
358 | pred_C = _sample_next_token(
359 | logits_CxV.float(),
360 | temperature=temperature,
361 | top_p=top_p,
362 | cfg_filter_top_k=cfg_filter_top_k,
363 | generator=generator, # Pass generator to _sample_next_token
364 | )
365 |         # Free the final logits
366 | del logits_CxV
367 |
368 | return pred_C
369 |
370 | def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
371 | num_channels = self.config.data.channels
372 | seq_length = generated_codes.shape[0]
373 | delay_pattern = self.config.data.delay_pattern
374 | audio_pad_value = self.config.data.audio_pad_value
375 | max_delay_pattern = max(delay_pattern)
376 |
377 | revert_precomp = build_revert_indices(
378 | B=1,
379 | T=seq_length,
380 | C=num_channels,
381 | delay_pattern=delay_pattern,
382 | )
383 |
384 | codebook = revert_audio_delay(
385 | audio_BxTxC=generated_codes.unsqueeze(0),
386 | pad_value=audio_pad_value,
387 | precomp=revert_precomp,
388 | T=seq_length,
389 | )[:, :-max_delay_pattern, :]
390 |
391 |         # Clean up intermediate tensors
392 | del revert_precomp, generated_codes
393 |
394 | min_valid_index = 0
395 | max_valid_index = 1023
396 | invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
397 | codebook[invalid_mask] = 0
398 |
399 | # Process final audio
400 | audio = decode(self.dac_model, codebook.transpose(1, 2))
401 |
402 |         # Clean up codebook tensors
403 | del codebook, invalid_mask
404 |
405 | result = audio.squeeze().cpu().numpy()
406 |
407 |         # Clean up audio tensor
408 | del audio
409 | torch.cuda.empty_cache()
410 |
411 | return result
412 |
413 | def load_audio(self, audio_path: str) -> torch.Tensor:
414 | audio, sr = torchaudio.load(audio_path, channels_first=True) # C, T
415 | if sr != DEFAULT_SAMPLE_RATE:
416 | audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
417 | audio = audio.to(self.device).unsqueeze(0) # 1, C, T
418 | audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
419 | _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data) # 1, C, T
420 |
421 | # Clean up intermediate tensors
422 | del audio_data
423 | if sr != DEFAULT_SAMPLE_RATE:
424 | del audio
425 | encoded_result = encoded_frame.squeeze(0).transpose(0, 1)
426 | del encoded_frame
427 |
428 | return encoded_result
429 |
430 | def save_audio(self, path: str, audio: np.ndarray):
431 | import soundfile as sf
432 |
433 | sf.write(path, audio, DEFAULT_SAMPLE_RATE)
434 |
435 | @torch.inference_mode()
436 | def generate(
437 | self,
438 | text: str,
439 | max_tokens: int | None = None,
440 | cfg_scale: float = 3.0,
441 | temperature: float = 1.3,
442 | top_p: float = 0.95,
443 | use_torch_compile: bool = False,
444 | cfg_filter_top_k: int = 35,
445 | audio_prompt: str | torch.Tensor | None = None,
446 | audio_prompt_path: str | None = None,
447 | use_cfg_filter: bool | None = None,
448 |         seed: int = 42,  # negative values select a random seed
449 | verbose: bool = True,
450 | text_to_generate_size: int | None = None,
451 | ) -> np.ndarray:
452 | # Import tqdm here to avoid requiring it as a dependency if progress bar isn't used
453 | from tqdm import tqdm
454 |
455 | audio_eos_value = self.config.data.audio_eos_value
456 | audio_pad_value = self.config.data.audio_pad_value
457 | delay_pattern = self.config.data.delay_pattern
458 |
459 |         # Estimate tokens from the requested text size (empirical ratios: ~3.4 tokens/char with an audio prompt, ~6.7 without)
460 |         if text_to_generate_size is not None:
461 |             ratio = 3.4 if audio_prompt is not None else 6.7
462 |             estimated_tokens = int(text_to_generate_size * ratio)
463 |         # (estimated_tokens is informational only; the loop below is bounded by max_tokens)
464 |
465 | current_step = 0
466 |
467 |         # Default to the model's maximum audio length when max_tokens is not provided
468 | max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
469 |
470 | max_delay_pattern = max(delay_pattern) if delay_pattern else 0
471 | self.model.eval()
472 |
473 | if audio_prompt_path:
474 | print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
475 | audio_prompt = audio_prompt_path
476 | if use_cfg_filter is not None:
477 | print("Warning: use_cfg_filter is deprecated.")
478 |
479 | # Create generator from seed if provided
480 | generator = None
481 | if seed >= 0:
482 | if verbose:
483 | print(f"Using seed: {seed} for generation")
484 | generator = torch.Generator(device=self.device)
485 | generator.manual_seed(seed)
486 | else:
487 | if verbose:
488 | print("Using random seed for generation")
489 |
490 | if verbose:
491 | total_start_time = time.time()
492 |
493 | # Prepare generation states
494 | dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
495 | dec_step = dec_output.prefill_step - 1
496 |
497 | bos_countdown = max_delay_pattern
498 | eos_detected = False
499 | eos_countdown = -1
500 |
501 | if use_torch_compile:
502 | step_fn = torch.compile(self._decoder_step, mode="default")
503 | else:
504 | step_fn = self._decoder_step
505 |
506 | if verbose:
507 | print("generate: starting generation loop")
508 | if use_torch_compile:
509 | print(
510 |                     "generate: use_torch_compile=True makes the first step noticeably slower (one-time compilation)"
511 | )
512 | start_time = time.time()
513 |
514 | token_history = [] # Track generated tokens
515 |
516 | try:
517 | while dec_step < max_tokens:
518 | dec_state.prepare_step(dec_step)
519 | tokens_Bx1xC = (
520 | dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
521 | )
522 | pred_C = step_fn(
523 | tokens_Bx1xC,
524 | dec_state,
525 | cfg_scale,
526 | temperature,
527 | top_p,
528 | cfg_filter_top_k,
529 | generator=generator, # Pass generator to step_fn
530 | )
531 | # Clean up tokens after use
532 | del tokens_Bx1xC
533 |
534 | if (
535 | not eos_detected and pred_C[0] == audio_eos_value
536 | ) or dec_step == max_tokens - max_delay_pattern - 1:
537 | eos_detected = True
538 | eos_countdown = max_delay_pattern
539 |
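                # Once EOS fires on channel 0 (or the token budget is nearly exhausted), walk the
                # delay pattern: each channel emits its own EOS when its delay is reached and
                # padding afterwards, so every channel terminates cleanly.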
540 | if eos_countdown > 0:
541 | step_after_eos = max_delay_pattern - eos_countdown
542 | for i, d in enumerate(delay_pattern):
543 | if step_after_eos == d:
544 | pred_C[i] = audio_eos_value
545 | elif step_after_eos > d:
546 | pred_C[i] = audio_pad_value
547 | eos_countdown -= 1
548 |
549 | bos_countdown = max(0, bos_countdown - 1)
550 | dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)
551 |
552 | # Add to token history
553 | token_history.append(pred_C.detach().clone())
554 |
555 | if eos_countdown == 0:
556 | break
557 |
558 | dec_step += 1
559 |
560 | if verbose and dec_step % 86 == 0:
561 | duration = time.time() - start_time
562 | print(
563 | f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
564 | )
565 | start_time = time.time()
566 |
567 | # Periodic memory cleanup during long generations
568 | if torch.cuda.is_available() and dec_step > 500:
569 | torch.cuda.empty_cache()
570 |
571 | except Exception as e:
572 | # Ensure cleanup even on error
573 | if "dec_state" in locals():
574 | del dec_state
575 | if "dec_output" in locals():
576 | del dec_output
577 | if "token_history" in locals():
578 | del token_history
579 | self.reset_state()
580 | raise e
581 |
582 | finally:
583 |             # Progress bar handling was removed; the commented call is kept for reference.
584 |             # pbar.close()
585 |
586 | # Check if dec_output was created
587 | if "dec_output" in locals() and dec_output is not None:
588 | if dec_output.prefill_step >= dec_step + 1:
589 | print("Warning: Nothing generated")
590 | # Cleanup on early return
591 | if "dec_state" in locals():
592 | del dec_state
593 | if "dec_output" in locals():
594 | del dec_output
595 | if "token_history" in locals():
596 | del token_history
597 | self.reset_state()
598 | return None
599 |
600 |                 prefill_step_final = dec_output.prefill_step  # captured before cleanup below
601 |                 generated_codes = dec_output.generated_tokens[
602 |                     prefill_step_final : dec_step + 1, :
603 |                 ]
604 |
605 |                 # Clean up state objects
606 |                 if "dec_state" in locals():
607 |                     del dec_state
608 |                 if "dec_output" in locals():
609 |                     del dec_output
610 |                 if "token_history" in locals():
611 |                     del token_history
612 |
613 |                 if verbose:
614 |                     total_step = dec_step + 1 - prefill_step_final
615 |                     total_duration = time.time() - total_start_time
616 |                     print(
617 |                         f"generate: total step={total_step}, total duration={total_duration:.3f}s"
618 |                     )
619 |
620 | # Process output
621 | output = self._generate_output(generated_codes)
622 |
623 | # Final cleanup
624 | del generated_codes
625 | self.reset_state()
626 |
627 | return output
628 | else:
629 | # Handle case where dec_output wasn't created
630 | self.reset_state()
631 | return None
632 |
633 | def reset_state(self):
634 | """Reset internal model state and clear CUDA cache to prevent memory leaks."""
635 | # 1. Clear any cached states in the model
636 | if hasattr(self, "dac_model") and hasattr(self.dac_model, "reset"):
637 | self.dac_model.reset()
638 |
639 | # 2. Clear any encoder/decoder buffers
640 | if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "buffers"):
641 | for buffer in self.model.encoder.buffers():
642 | if buffer.device.type == "cuda":
643 | buffer.data = buffer.data.clone() # Force copy to break references
644 |
645 | if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "buffers"):
646 | for buffer in self.model.decoder.buffers():
647 | if buffer.device.type == "cuda":
648 | buffer.data = buffer.data.clone() # Force copy to break references
649 |
650 | # 3. Force garbage collection first
651 | import gc
652 |
653 | gc.collect()
654 |
655 | # 4. Then clear CUDA cache
656 | if torch.cuda.is_available():
657 | torch.cuda.empty_cache()
658 |
659 | return True
660 |
--------------------------------------------------------------------------------
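A minimal usage sketch for the generate() API above, assuming the Dia wrapper class and its from_pretrained loader defined earlier in dia/model.py (the repo id matches the default documented in env.example.txt); illustrative only, not part of the source tree:

# usage_sketch.py (illustrative)
from dia.model import Dia  # assumed import path for the wrapper class

model = Dia.from_pretrained("nari-labs/Dia-1.6B")  # assumed loader defined earlier in model.py
text = "[S1] Hello there. [S2] Hi, how are you today?"
audio = model.generate(
    text,
    cfg_scale=3.0,
    temperature=1.3,
    top_p=0.95,
    cfg_filter_top_k=35,
    seed=42,
    text_to_generate_size=len(text),  # size hint used only for the rough token estimate
)
if audio is not None:
    model.save_audio("dialogue.wav", audio)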
/dia/state.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 |
5 | from .config import DiaConfig
6 |
7 |
8 | def create_attn_mask(
9 | q_padding_mask_1d: torch.Tensor,
10 | k_padding_mask_1d: torch.Tensor,
11 | device: torch.device,
12 | is_causal: bool = False,
13 | ) -> torch.Tensor:
14 | """
15 | Creates the attention mask (self or cross) mimicking JAX segment ID logic.
16 | """
17 | B1, Tq = q_padding_mask_1d.shape
18 | B2, Tk = k_padding_mask_1d.shape
19 | assert B1 == B2, "Query and key batch dimensions must match"
20 |
21 | p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
22 | p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
23 |
24 | # Condition A: Non-padding query attends to non-padding key
25 | non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
26 |
27 | # Condition B: Padding query attends to padding key
28 | pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
29 |
30 | # Combine: True if padding status is compatible (both non-pad OR both pad)
31 | mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
32 |
33 | if is_causal:
34 | assert (
35 | Tq == Tk
36 | ), "Causal mask requires query and key sequence lengths to be equal"
37 | causal_mask_2d = torch.tril(
38 | torch.ones((Tq, Tk), dtype=torch.bool, device=device)
39 | ) # Shape [Tq, Tk]
40 | causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
41 | return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
42 | else:
43 | return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
44 |
45 |
46 | @dataclass
47 | class EncoderInferenceState:
48 | """Parameters specifically for encoder inference."""
49 |
50 | max_seq_len: int
51 | device: torch.device
52 | positions: torch.Tensor
53 | padding_mask: torch.Tensor
54 | attn_mask: torch.Tensor
55 |
56 | @classmethod
57 | def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState":
58 |         """Creates EncoderInferenceState from DiaConfig and the conditional source tokens."""
59 | device = cond_src.device
60 |
61 | positions = (
62 | torch.arange(config.data.text_length, dtype=torch.float32, device=device)
63 | .unsqueeze(0)
64 | .expand(2, -1)
65 | )
66 | padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1)
67 | attn_mask = create_attn_mask(
68 | padding_mask, padding_mask, device, is_causal=False
69 | )
70 |
71 | return cls(
72 | max_seq_len=config.data.text_length,
73 | device=device,
74 | positions=positions,
75 | padding_mask=padding_mask,
76 | attn_mask=attn_mask,
77 | )
78 |
79 |
80 | class KVCache:
81 | def __init__(
82 | self,
83 | num_heads: int,
84 | max_len: int,
85 | head_dim: int,
86 | dtype: torch.dtype,
87 | device: torch.device,
88 | k: torch.Tensor | None = None,
89 | v: torch.Tensor | None = None,
90 | ):
91 | self.k = (
92 | torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
93 | if k is None
94 | else k
95 | )
96 | self.v = (
97 | torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
98 | if v is None
99 | else v
100 | )
101 | self.current_idx = torch.tensor(0)
102 |
103 | @classmethod
104 | def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
105 | return cls(
106 | num_heads=k.shape[1],
107 | max_len=k.shape[2],
108 | head_dim=k.shape[3],
109 | dtype=k.dtype,
110 | device=k.device,
111 | k=k,
112 | v=v,
113 | )
114 |
115 | def update(
116 | self, k: torch.Tensor, v: torch.Tensor
117 | ) -> tuple[torch.Tensor, torch.Tensor]:
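        # Write this step's K/V at current_idx, advance the pointer, and return views covering
        # every cached position so far (including the one just written).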
118 | self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
119 | self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
120 | self.current_idx += 1
121 | return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :]
122 |
123 | def prefill(
124 | self, k: torch.Tensor, v: torch.Tensor
125 | ) -> tuple[torch.Tensor, torch.Tensor]:
126 | prefill_len = k.shape[2]
127 | self.k[:, :, :prefill_len, :] = k
128 | self.v[:, :, :prefill_len, :] = v
129 | self.current_idx = prefill_len - 1
130 |
131 |
132 | @dataclass
133 | class DecoderInferenceState:
134 | """Parameters specifically for decoder inference."""
135 |
136 | device: torch.device
137 | dtype: torch.dtype
138 | enc_out: torch.Tensor
139 | enc_positions: torch.Tensor
140 | dec_positions: torch.Tensor
141 | dec_cross_attn_mask: torch.Tensor
142 | self_attn_cache: list[KVCache]
143 | cross_attn_cache: list[KVCache]
144 |
145 | @classmethod
146 | def new(
147 | cls,
148 | config: DiaConfig,
149 | enc_state: EncoderInferenceState,
150 | enc_out: torch.Tensor,
151 | dec_cross_attn_cache: list[KVCache],
152 | compute_dtype: torch.dtype,
153 | ) -> "DecoderInferenceState":
154 |         """Creates a DecoderInferenceState from DiaConfig, the encoder state/output, and caches."""
155 | device = enc_out.device
156 | max_audio_len = config.data.audio_length
157 |
158 | dec_positions = torch.full(
159 | (2, 1), fill_value=0, dtype=torch.long, device=device
160 | )
161 | tgt_padding_mask = torch.ones((2, 1), dtype=torch.bool, device=device)
162 | dec_cross_attn_mask = create_attn_mask(
163 | tgt_padding_mask, enc_state.padding_mask, device, is_causal=False
164 | )
165 |
166 | self_attn_cache = [
167 | KVCache(
168 | config.model.decoder.kv_heads,
169 | max_audio_len,
170 | config.model.decoder.gqa_head_dim,
171 | compute_dtype,
172 | device,
173 | )
174 | for _ in range(config.model.decoder.n_layer)
175 | ]
176 |
177 | return cls(
178 | device=device,
179 | dtype=compute_dtype,
180 | enc_out=enc_out,
181 | enc_positions=enc_state.positions,
182 | dec_positions=dec_positions,
183 | dec_cross_attn_mask=dec_cross_attn_mask,
184 | self_attn_cache=self_attn_cache,
185 | cross_attn_cache=dec_cross_attn_cache,
186 | )
187 |
188 | def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
189 | if step_to is None:
190 | step_to = step_from + 1
191 | self.dec_positions = (
192 | torch.arange(step_from, step_to, dtype=torch.float32, device=self.device)
193 | .unsqueeze(0)
194 | .expand(2, -1)
195 | )
196 |
197 |
198 | @dataclass
199 | class DecoderOutput:
200 | generated_tokens: torch.Tensor
201 | prefill_step: int
202 |
203 | @classmethod
204 | def new(cls, config: DiaConfig, device: torch.device) -> "DecoderOutput":
205 | max_audio_len = config.data.audio_length
206 | return cls(
207 | generated_tokens=torch.full(
208 | (max_audio_len, config.data.channels),
209 | fill_value=-1,
210 | dtype=torch.int,
211 | device=device,
212 | ),
213 | prefill_step=0,
214 | )
215 |
216 | def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
217 | if step_to is None:
218 | step_to = step_from + 1
219 | return self.generated_tokens[step_from:step_to, :]
220 |
221 | def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
222 | if apply_mask:
223 | mask = self.generated_tokens[step : step + 1, :] == -1
224 | self.generated_tokens[step : step + 1, :] = torch.where(
225 | mask, dec_out, self.generated_tokens[step : step + 1, :]
226 | )
227 | else:
228 | self.generated_tokens[step : step + 1, :] = dec_out
229 |
230 | def prefill(self, dec_out: torch.Tensor, prefill_step: int):
231 | length = dec_out.shape[0]
232 | self.generated_tokens[0:length, :] = dec_out
233 | self.prefill_step = prefill_step
234 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 |
3 | services:
4 | dia-tts-server:
5 | build:
6 | context: .
7 | dockerfile: Dockerfile
8 | ports:
9 | - "${PORT:-8003}:${PORT:-8003}"
10 | volumes:
11 | # Mount local directories into the container for persistent data
12 | - ./model_cache:/app/model_cache
13 | - ./reference_audio:/app/reference_audio
14 | - ./outputs:/app/outputs
15 | - ./voices:/app/voices
16 | # DO NOT mount config.yaml - let the app create it inside
17 |
18 | # --- GPU Access ---
19 | # Modern method (Recommended for newer Docker/NVIDIA setups)
20 | devices:
21 | - nvidia.com/gpu=all
22 | device_cgroup_rules:
23 | - "c 195:* rmw" # Needed for some NVIDIA container toolkit versions
24 | - "c 236:* rmw" # Needed for some NVIDIA container toolkit versions
25 |
26 | # Legacy method (Alternative for older Docker/NVIDIA setups)
27 | # If the 'devices' block above doesn't work, comment it out and uncomment
28 | # the 'deploy' block below. Do not use both simultaneously.
29 | # deploy:
30 | # resources:
31 | # reservations:
32 | # devices:
33 | # - driver: nvidia
34 | # count: 1 # Or specify specific GPUs e.g., "device=0,1"
35 | # capabilities: [gpu]
36 | # --- End GPU Access ---
37 |
38 | restart: unless-stopped
39 | env_file:
40 | # Load environment variables from .env file for initial config seeding
41 | - .env
42 | environment:
43 | # Enable faster Hugging Face downloads inside the container
44 | - HF_HUB_ENABLE_HF_TRANSFER=1
45 | # Pass GPU capabilities (may be needed for legacy method if uncommented)
46 | - NVIDIA_VISIBLE_DEVICES=all
47 | - NVIDIA_DRIVER_CAPABILITIES=compute,utility
48 |
--------------------------------------------------------------------------------
/download_model.py:
--------------------------------------------------------------------------------
1 | # download_model.py
2 | # Utility script to download the Dia model and dependencies without starting the server.
3 |
4 | import logging
5 | import os
6 | import engine # Import the engine module to trigger its loading logic
7 |
8 | # Import Whisper for model download check
9 | try:
10 | import whisper
11 |
12 | WHISPER_AVAILABLE = True
13 | except ImportError:
14 | WHISPER_AVAILABLE = False
15 | logging.warning("Whisper library not found. Cannot download Whisper model.")
16 |
17 | # Configure basic logging for the script
18 | logging.basicConfig(
19 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
20 | )
21 | logger = logging.getLogger("ModelDownloader")
22 |
23 | if __name__ == "__main__":
24 | logger.info("--- Starting Dia & Whisper Model Download ---")
25 |
26 | # Ensure cache directory exists (redundant if engine.load_model does it, but safe)
27 | try:
28 | from config import get_model_cache_path, get_whisper_model_name
29 |
30 | cache_path = get_model_cache_path()
31 | os.makedirs(cache_path, exist_ok=True)
32 | logger.info(
33 | f"Ensured model cache directory exists: {os.path.abspath(cache_path)}"
34 | )
35 | except Exception as e:
36 | logger.warning(f"Could not ensure cache directory exists: {e}")
37 | cache_path = None # Ensure cache_path is defined or None
38 |
39 | # Trigger the Dia model loading function from the engine
40 | logger.info("Calling engine.load_model() to initiate Dia download if necessary...")
41 | dia_success = engine.load_model()
42 |
43 | if dia_success:
44 | logger.info("--- Dia model download/load process completed successfully ---")
45 | else:
46 | logger.error(
47 | "--- Dia model download/load process failed. Check logs for details. ---"
48 | )
49 | # Optionally exit if Dia fails, or continue to try Whisper
50 | # exit(1)
51 |
52 | # --- Download Whisper Model ---
53 | whisper_success = False
54 | if WHISPER_AVAILABLE and cache_path:
55 | whisper_model_name = get_whisper_model_name()
56 | logger.info(f"Attempting to download Whisper model '{whisper_model_name}'...")
57 | try:
58 | # Use download_root to specify our cache directory
59 | whisper.load_model(whisper_model_name, download_root=cache_path)
60 | logger.info(
61 | f"Whisper model '{whisper_model_name}' downloaded/found successfully in {cache_path}."
62 | )
63 | whisper_success = True
64 | except Exception as e:
65 | logger.error(
66 | f"Failed to download/load Whisper model '{whisper_model_name}': {e}",
67 | exc_info=True,
68 | )
69 | elif not WHISPER_AVAILABLE:
70 | logger.warning(
71 | "Skipping Whisper model download: Whisper library not installed."
72 | )
73 | elif not cache_path:
74 | logger.warning(
75 | "Skipping Whisper model download: Cache path could not be determined."
76 | )
77 |
78 | # --- Final Status ---
79 | if dia_success and whisper_success:
80 | logger.info("--- All required models downloaded/verified successfully ---")
81 | elif dia_success:
82 | logger.warning(
83 | "--- Dia model OK, but Whisper model download failed or was skipped ---"
84 | )
85 | elif whisper_success:
86 | logger.warning("--- Whisper model OK, but Dia model download failed ---")
87 | else:
88 | logger.error(
89 | "--- Both Dia and Whisper model downloads failed or were skipped ---"
90 | )
91 | exit(1) # Exit with error code if essential models failed
92 |
93 | logger.info("You can now start the server using 'python server.py'")
94 |
--------------------------------------------------------------------------------
/env.example.txt:
--------------------------------------------------------------------------------
1 | # .env - Initial Configuration Seed for Dia TTS Server
2 | # IMPORTANT: This file is ONLY used on the very first server start
3 | # IF config.yaml does NOT exist.
4 | # After config.yaml is created, settings are managed there.
5 | # Changes here will NOT affect a running server with an existing config.yaml.
6 | # Use the Web UI or directly edit config.yaml for subsequent changes.
7 |
8 | # --- Server Settings ---
9 | HOST='0.0.0.0'
10 | PORT='8003'
11 |
12 | # --- Path Settings ---
13 | # These paths will be written into the initial config.yaml
14 | DIA_MODEL_CACHE_PATH='./model_cache'
15 | REFERENCE_AUDIO_PATH='./reference_audio'
16 | OUTPUT_PATH='./outputs'
17 | # Path for predefined voices is now set in config.py defaults
18 | # PREDEFINED_VOICES_PATH='./voices'
19 |
20 | # --- Model Source Settings ---
21 | # Defaulting to BF16 safetensors. Uncomment/modify to seed config.yaml differently.
22 | DIA_MODEL_REPO_ID='ttj/dia-1.6b-safetensors'
23 | DIA_MODEL_CONFIG_FILENAME='config.json'
24 | DIA_MODEL_WEIGHTS_FILENAME='dia-v0_1_bf16.safetensors'
25 |
26 | # Example: Seed with full precision safetensors
27 | # DIA_MODEL_REPO_ID=ttj/dia-1.6b-safetensors
28 | # DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.safetensors
29 |
30 | # Example: Seed with original Nari Labs .pth model
31 | # DIA_MODEL_REPO_ID=nari-labs/Dia-1.6B
32 | # DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.pth
33 |
34 | # --- Whisper Transcription Settings ---
35 | # Model name for automatic transcription if reference .txt is missing.
36 | WHISPER_MODEL_NAME='small.en'
37 |
38 | # --- Default Generation Parameters ---
39 | # These set the initial values in the 'generation_defaults' section of config.yaml.
40 | # They can be changed later via the UI ('Save Generation Defaults' button) or by editing config.yaml.
41 | GEN_DEFAULT_SPEED_FACTOR='1.0'
42 | GEN_DEFAULT_CFG_SCALE='3.0'
43 | GEN_DEFAULT_TEMPERATURE='1.3'
44 | GEN_DEFAULT_TOP_P='0.95'
45 | GEN_DEFAULT_CFG_FILTER_TOP_K='35'
46 | # -1 for random seed
47 | GEN_DEFAULT_SEED='42'
48 | # Now managed in config.py defaults
49 | # GEN_DEFAULT_SPLIT_TEXT='True'
50 | # GEN_DEFAULT_CHUNK_SIZE='120'
51 |
52 | # --- UI State ---
53 | # UI state (last text, selections, etc.) is NOT seeded from .env.
54 | # It starts with defaults defined in config.py and is saved in config.yaml.
55 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # models.py
2 | # Pydantic models for API requests and potentially responses
3 |
4 | from pydantic import BaseModel, Field
5 | from typing import Optional, Literal
6 |
7 | # --- Request Models ---
8 |
9 |
10 | class OpenAITTSRequest(BaseModel):
11 | """Request model compatible with the OpenAI TTS API."""
12 |
13 | model: str = Field(
14 | default="dia-1.6b",
15 | description="Model identifier (ignored by this server, always uses Dia). Included for compatibility.",
16 | )
17 | input: str = Field(..., description="The text to synthesize.")
18 | voice: str = Field(
19 | default="S1",
20 | description="Voice mode or reference audio filename. Examples: 'S1', 'S2', 'dialogue', 'my_reference.wav'.",
21 | )
22 | response_format: Literal["opus", "wav"] = Field(
23 | default="opus", description="The desired audio output format."
24 | )
25 | speed: float = Field(
26 | default=1.0,
27 | ge=0.5, # Allow wider range for speed factor post-processing
28 | le=2.0,
29 | description="Adjusts the speed of the generated audio (0.5 to 2.0).",
30 | )
31 | # Add seed parameter, defaulting to random (-1)
32 | seed: Optional[int] = Field(
33 | default=-1,
34 | description="Generation seed. Use -1 for random (default), or a specific integer for deterministic output.",
35 | )
36 |
37 |
38 | class CustomTTSRequest(BaseModel):
39 | """Request model for the custom /tts endpoint."""
40 |
41 | text: str = Field(
42 | ...,
43 | description="The text to synthesize. For 'dialogue' mode, include [S1]/[S2] tags.",
44 | )
45 | voice_mode: Literal["dialogue", "single_s1", "single_s2", "clone"] = Field(
46 | default="single_s1", description="Specifies the generation mode."
47 | )
48 | clone_reference_filename: Optional[str] = Field(
49 | default=None,
50 | description="Filename of the reference audio within the configured reference path (required if voice_mode is 'clone').",
51 | )
52 | # New: Optional transcript for cloning
53 | transcript: Optional[str] = Field(
54 | default=None,
55 | description="Optional transcript of the reference audio for cloning. If provided, overrides local .txt file lookup and Whisper generation.",
56 | )
57 | output_format: Literal["opus", "wav"] = Field(
58 | default="opus", description="The desired audio output format."
59 | )
60 | # Dia-specific generation parameters
61 | max_tokens: Optional[int] = Field(
62 | default=None,
63 | gt=0,
64 | description="Maximum number of audio tokens to generate per chunk (defaults to model's internal config value).",
65 | )
66 | cfg_scale: float = Field(
67 | default=3.0,
68 | ge=1.0,
69 | le=5.0,
70 | description="Classifier-Free Guidance scale (1.0-5.0).",
71 | )
72 | temperature: float = Field(
73 | default=1.3,
74 | ge=0.1,
75 | le=1.5,
76 | description="Sampling temperature (0.1-1.5).", # Allow lower temp for greedy-like
77 | )
78 | top_p: float = Field(
79 | default=0.95,
80 | ge=0.1, # Allow lower top_p
81 | le=1.0,
82 | description="Nucleus sampling probability (0.1-1.0).",
83 | )
84 | speed_factor: float = Field(
85 | default=0.94,
86 | ge=0.5, # Allow wider range for speed factor post-processing
87 | le=2.0,
88 | description="Adjusts the speed of the generated audio (0.5 to 2.0).",
89 | )
90 | cfg_filter_top_k: int = Field(
91 | default=35,
92 | ge=1,
93 | le=100,
94 | description="Top k filter for CFG guidance (1-100).", # Allow wider range
95 | )
96 | # Add seed parameter, defaulting to random (-1)
97 | seed: Optional[int] = Field(
98 | default=-1,
99 | description="Generation seed. Use -1 for random (default), or a specific integer for deterministic output.",
100 | )
101 | # Add text splitting parameters
102 | split_text: Optional[bool] = Field(
103 | default=True, # Default to splitting enabled
104 | description="Whether to automatically split long text into chunks for processing.",
105 | )
106 | chunk_size: Optional[int] = Field(
107 | default=300, # Default target chunk size
108 | ge=100, # Minimum reasonable chunk size
109 | le=1000, # Maximum reasonable chunk size
110 | description="Approximate target character length for text chunks when splitting is enabled (100-1000).",
111 | )
112 |
113 |
114 | # --- Response Models (Optional, can be simple dicts too) ---
115 |
116 |
117 | class TTSResponse(BaseModel):
118 | """Basic response model for successful generation (if returning JSON)."""
119 |
120 | request_id: str
121 | status: str = "completed"
122 | generation_time_sec: float
123 | output_url: Optional[str] = None # If saving file and returning URL
124 |
125 |
126 | class ErrorResponse(BaseModel):
127 | """Error response model."""
128 |
129 | detail: str
130 |
--------------------------------------------------------------------------------
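An illustrative client call for CustomTTSRequest above. Assumptions: the server listens on the default port 8003 from env.example.txt, server.py mounts this model at POST /tts as the docstring states, and the response body is the encoded audio itself; treat this as a sketch rather than documented behaviour:

# tts_client_sketch.py (illustrative)
import requests

payload = {
    "text": "[S1] Hello from the custom endpoint. [S2] Hi there!",
    "voice_mode": "dialogue",
    "output_format": "wav",
    "cfg_scale": 3.0,
    "temperature": 1.3,
    "top_p": 0.95,
    "cfg_filter_top_k": 35,
    "speed_factor": 1.0,
    "seed": 42,
    "split_text": True,
    "chunk_size": 300,
}
resp = requests.post("http://localhost:8003/tts", json=payload, timeout=600)
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(resp.content)  # assumes the endpoint returns raw audio bytes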
/reference_audio/Gianna.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/reference_audio/Gianna.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/reference_audio/Gianna.wav
--------------------------------------------------------------------------------
/reference_audio/Oliver_Luna.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/reference_audio/Oliver_Luna.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/reference_audio/Oliver_Luna.wav
--------------------------------------------------------------------------------
/reference_audio/Robert.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/reference_audio/Robert.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/reference_audio/Robert.wav
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # requirements.txt
2 |
3 | # Core Web Framework
4 | fastapi
5 | uvicorn[standard]
6 |
7 | # Machine Learning & Audio
8 | torch
9 | torchaudio
10 | numpy
11 | soundfile # Requires libsndfile system library (e.g., sudo apt-get install libsndfile1 on Debian/Ubuntu)
12 | huggingface_hub
13 | descript-audio-codec
14 | safetensors
15 | openai-whisper
16 |
17 | # Configuration & Utilities
18 | pydantic
19 | python-dotenv # Used ONLY for initial config seeding if config.yaml missing
20 | Jinja2
21 | python-multipart # For file uploads
22 | requests # For health checks or other potential uses
23 | PyYAML # For parsing presets.yaml AND primary config.yaml
24 | tqdm
25 |
26 | # Audio Post-processing
27 | pydub
28 | praat-parselmouth # For unvoiced segment removal
29 | librosa # for resampling / sample-rate changes
30 | hf-transfer
31 |
--------------------------------------------------------------------------------
/static/screenshot-d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/static/screenshot-d.png
--------------------------------------------------------------------------------
/static/screenshot-l.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/static/screenshot-l.png
--------------------------------------------------------------------------------
/ui/presets.yaml:
--------------------------------------------------------------------------------
1 | # ui/presets.yaml
2 | # Predefined examples for the Dia TTS UI
3 |
4 | - name: "Short Dialogue" # Renamed
5 | # voice_mode ignored by JS loader
6 | text: |
7 | [S1] Hey, how's it going?
8 | [S2] Pretty good! Just grabbing some coffee. You?
9 | [S1] Same here. Need the fuel! (laughs)
10 | params:
11 | cfg_scale: 3.0
12 | temperature: 1.3
13 | top_p: 0.95
14 | cfg_filter_top_k: 35
15 | # speed_factor uses the saved default
16 | # seed uses the saved default
17 |
18 | - name: "Long Dialogue" # Added
19 | text: |
20 | [S1] Hey, is that my coffee you're drinking?
21 | [S2] Pretty sure it's mine. Wait, let me check. (squints) It says "Bark"? That can't be right.
22 | [S1] That's Mark with an M! (sighs) My handwriting isn't that bad.
23 | [S2] Oh! (laughs) Well, in my defense, this is my third coffee today.
24 | [S1] And yet you thought you needed another one?
25 | [S2] Sleep is for the weak! (yawns) Though apparently so am I.
26 | [S1] Here, just give me my coffee before you (gasps)
27 | [S2] What? What's wrong?
28 | [S1] That was a triple espresso with extra shots! How are you still standing?
29 | [S2] (laughs) I can see through time now. Is it normal to taste colors?
30 | [S1] I'll order you some water.
31 | [S2] No need! I've never felt better! (laughs)
32 | [S1] Your left eye is twitching.
33 | [S2] Is it? I can't feel my face. (laughs)
34 | [S1] Maybe we should sit down for a minute.
35 | [S2] Sitting is for people who can't vibrate in place! (sighs) Fine, maybe for a minute.
36 | [S1] I've never seen someone drink the wrong coffee so enthusiastically.
37 | [S2] I've never met a coffee I didn't like. (laughs) Though this one might be trying to kill me.
38 | [S1] Next time, maybe read the name before you drink?
39 | [S2] Reading requires focus. Focus requires sleep. It's a whole problem. (yawns)
40 | [S1] How about I buy you a decaf?
41 | [S2] Decaf? (gasps) What did I ever do to you?
42 | [S1] Just trying to keep you alive until tomorrow. (laughs)
43 | [S2] Tomorrow is overrated. Today has coffee!
44 | params: # Using standard defaults
45 | cfg_scale: 3.0
46 | temperature: 1.3
47 | top_p: 0.95
48 | cfg_filter_top_k: 35
49 |
50 | - name: "Expressive Narration"
51 | text: |
52 | [S1] The old house stood on a windswept hill, its windows like empty eyes staring out at the stormy sea. (sighs) It felt... lonely.
53 | params:
54 | cfg_scale: 3.0
55 | temperature: 1.2 # Slightly lower temp for clarity
56 | top_p: 0.95
57 | cfg_filter_top_k: 35
58 |
59 | - name: "Quick Announcement"
60 | text: |
61 | [S1] Attention shoppers! The store will be closing in 15 minutes. Please bring your final purchases to the checkout.
62 | params:
63 | cfg_scale: 2.8 # Slightly lower CFG for potentially more natural tone
64 | temperature: 1.3
65 | top_p: 0.95
66 | cfg_filter_top_k: 35
67 |
68 | - name: "Funny Exchange"
69 | text: |
70 | [S1] Did you remember to buy the alien repellent?
71 | [S2] The what now? (laughs) I thought you were joking!
72 | [S1] Joking? They're landing tonight! (clears throat) Probably.
73 | params:
74 | cfg_scale: 3.2 # Slightly higher CFG
75 | temperature: 1.35 # Slightly higher temp
76 | top_p: 0.95
77 | cfg_filter_top_k: 35
78 |
79 | - name: "Simple Sentence"
80 | text: |
81 | [S1] This is a test of the text to speech system.
82 | params:
83 | cfg_scale: 3.0
84 | temperature: 1.3
85 | top_p: 0.95
86 | cfg_filter_top_k: 35
87 |
--------------------------------------------------------------------------------
/voices/Abigail.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Abigail.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Abigail.wav
--------------------------------------------------------------------------------
/voices/Abigail_Taylor.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Abigail_Taylor.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Abigail_Taylor.wav
--------------------------------------------------------------------------------
/voices/Adrian.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Adrian.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Adrian.wav
--------------------------------------------------------------------------------
/voices/Adrian_Jade.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Adrian_Jade.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Adrian_Jade.wav
--------------------------------------------------------------------------------
/voices/Alexander.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Alexander.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Alexander.wav
--------------------------------------------------------------------------------
/voices/Alexander_Emily.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Alexander_Emily.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Alexander_Emily.wav
--------------------------------------------------------------------------------
/voices/Alice.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Alice.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Alice.wav
--------------------------------------------------------------------------------
/voices/Austin.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Austin.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Austin.wav
--------------------------------------------------------------------------------
/voices/Austin_Jeremiah.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Austin_Jeremiah.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Austin_Jeremiah.wav
--------------------------------------------------------------------------------
/voices/Axel.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Axel.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Axel.wav
--------------------------------------------------------------------------------
/voices/Axel_Miles.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Axel_Miles.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Axel_Miles.wav
--------------------------------------------------------------------------------
/voices/Connor.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Connor.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Connor.wav
--------------------------------------------------------------------------------
/voices/Connor_Ryan.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Connor_Ryan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Connor_Ryan.wav
--------------------------------------------------------------------------------
/voices/Cora.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Cora.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Cora.wav
--------------------------------------------------------------------------------
/voices/Cora_Gianna.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Cora_Gianna.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Cora_Gianna.wav
--------------------------------------------------------------------------------
/voices/Elena.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Elena.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Elena.wav
--------------------------------------------------------------------------------
/voices/Elena_Emily.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Elena_Emily.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Elena_Emily.wav
--------------------------------------------------------------------------------
/voices/Eli.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Eli.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Eli.wav
--------------------------------------------------------------------------------
/voices/Emily.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Emily.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Emily.wav
--------------------------------------------------------------------------------
/voices/Everett.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Everett.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Everett.wav
--------------------------------------------------------------------------------
/voices/Everett_Jordan.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Everett_Jordan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Everett_Jordan.wav
--------------------------------------------------------------------------------
/voices/Gabriel.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Gabriel.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Gabriel.wav
--------------------------------------------------------------------------------
/voices/Gabriel_Ian.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Gabriel_Ian.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Gabriel_Ian.wav
--------------------------------------------------------------------------------
/voices/Gianna.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Gianna.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Gianna.wav
--------------------------------------------------------------------------------
/voices/Henry.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Henry.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Henry.wav
--------------------------------------------------------------------------------
/voices/Ian.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Ian.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Ian.wav
--------------------------------------------------------------------------------
/voices/Jade.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Jade.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Jade.wav
--------------------------------------------------------------------------------
/voices/Jade_Layla.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Jade_Layla.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Jade_Layla.wav
--------------------------------------------------------------------------------
/voices/Jeremiah.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Jeremiah.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Jeremiah.wav
--------------------------------------------------------------------------------
/voices/Jordan.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Jordan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Jordan.wav
--------------------------------------------------------------------------------
/voices/Julian.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Julian.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Julian.wav
--------------------------------------------------------------------------------
/voices/Julian_Thomas.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Julian_Thomas.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Julian_Thomas.wav
--------------------------------------------------------------------------------
/voices/Layla.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Layla.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Layla.wav
--------------------------------------------------------------------------------
/voices/Leonardo.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Leonardo.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Leonardo.wav
--------------------------------------------------------------------------------
/voices/Leonardo_Olivia.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Leonardo_Olivia.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Leonardo_Olivia.wav
--------------------------------------------------------------------------------
/voices/Michael.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Michael.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Michael.wav
--------------------------------------------------------------------------------
/voices/Michael_Emily.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Michael_Emily.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Michael_Emily.wav
--------------------------------------------------------------------------------
/voices/Miles.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Miles.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Miles.wav
--------------------------------------------------------------------------------
/voices/Oliver_Luna.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding. [S2] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Oliver_Luna.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Oliver_Luna.wav
--------------------------------------------------------------------------------
/voices/Olivia.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Olivia.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Olivia.wav
--------------------------------------------------------------------------------
/voices/Ryan.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Ryan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Ryan.wav
--------------------------------------------------------------------------------
/voices/Taylor.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Taylor.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Taylor.wav
--------------------------------------------------------------------------------
/voices/Thomas.txt:
--------------------------------------------------------------------------------
1 | [S1] We believe that exploring new ideas and sharing knowledge helps make the world a brighter place for everyone. Continuous learning and open communication are essential for progress and mutual understanding.
--------------------------------------------------------------------------------
/voices/Thomas.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devnen/Dia-TTS-Server/3e428f864b2fb71a5a4e98712ed99594f0cc02c2/voices/Thomas.wav
--------------------------------------------------------------------------------
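
Note on the convention visible in the listings above: every voice ships as a pair, where `<Name>.txt` holds the `[S1]`/`[S2]`-tagged transcript of the matching `<Name>.wav` reference clip (two-speaker voices such as `Oliver_Luna` include both an `[S1]` and an `[S2]` segment). The following is a minimal, hypothetical sketch of loading such a pair from the `voices/` directory; the helper name and the repository-relative path are assumptions for illustration, not part of the server's API.

    from pathlib import Path

    VOICES_DIR = Path("voices")  # assumption: repository-relative voices directory

    def load_voice_reference(name: str) -> tuple[str, Path]:
        """Load the reference transcript and audio path for a named voice.

        Assumes the pairing convention shown above: <Name>.txt contains the
        [S1]/[S2]-tagged transcript of the matching <Name>.wav reference clip.
        """
        transcript = (VOICES_DIR / f"{name}.txt").read_text(encoding="utf-8").strip()
        audio_path = VOICES_DIR / f"{name}.wav"
        if not audio_path.exists():
            raise FileNotFoundError(f"Missing reference audio for voice '{name}'")
        return transcript, audio_path

    if __name__ == "__main__":
        # Example: one single-speaker and one two-speaker reference pair
        for voice in ("Thomas", "Oliver_Luna"):
            text, wav = load_voice_reference(voice)
            print(f"{voice}: {wav.name} -> {text[:60]}...")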