├── .github
│   └── workflows
│       └── ci.yaml
├── .gitignore
├── .python-version
├── LICENSE
├── README.md
├── app.py
├── app_local.py
├── cli.py
├── dia
│   ├── __init__.py
│   ├── audio.py
│   ├── config.json
│   ├── config.py
│   ├── config_inference.json
│   ├── convert_ckpt.py
│   ├── dataset.py
│   ├── finetune.py
│   ├── interleaved_datasets.py
│   ├── layers.py
│   ├── model.py
│   └── static
│       └── images
│           └── banner.png
├── example
│   ├── simple.py
│   └── voice_clone.py
├── example_prompt.mp3
├── pyproject.toml
└── uv.lock
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: Continuous Integration
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | lint_and_format:
10 | runs-on: ubuntu-latest
11 | name: Lint and Format
12 | steps:
13 | - uses: actions/checkout@v4
14 | - uses: astral-sh/ruff-action@v3
15 | with:
16 | version: latest
17 |
18 | - name: Check Lint using Ruff
19 | run: ruff check
20 |
21 | - name: Check Format using Ruff
22 | run: ruff format --check --diff
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python-generated files
2 | __pycache__/
3 | *.py[oc]
4 | build/
5 | dist/
6 | wheels/
7 |
8 | *.egg-info
9 |
10 | # Virtual environments
11 | .venv
12 |
13 | .gradio
14 |
15 | **/*.pth
16 | **/*.mp3
17 | !example_prompt.mp3
18 | **/*.txt
19 | **/*.ipynb
20 |
21 | .ruff_cache
22 | .ipynb_checkpoints
23 | runs/
24 | results/
25 | ckpts/
26 | dia_finetune_mml/
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.10
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2025 Nari Labs
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dia TTS Model Fine-Tuning
2 |
3 | A training pipeline for fine-tuning the **Dia** TTS model using Hugging Face datasets or local audio–text pairs. Supports mixed-precision, model compilation, 8-bit optimizers, streaming datasets, and evaluation via TensorBoard.
4 | For multilingual training, the pipeline supports language tags of the form ```[iso_code]```. To train a multilingual model, provide a dataset with a `language` column containing the ISO code for each example.
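
A rough, hypothetical illustration of a single training example (the `audio` and `text` column names are assumptions; only the `language` column is what the pipeline requires, as described above):

```python
# Hypothetical multilingual training example. The "language" column holds the
# ISO code that becomes the [iso_code] tag (here [de] for German); the other
# column names are illustrative assumptions.
example = {
    "audio": "clips/sample_0001.wav",
    "text": "[S1] Guten Tag! [S2] Hallo, wie geht es dir?",
    "language": "de",
}
```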
5 |
6 |
7 | ---
8 |
9 |
10 | ## Installation
11 |
12 | ```bash
13 | git clone https://github.com/stlohrey/dia-finetuning.git
14 | cd dia-finetuning
15 | python -m venv .venv
16 | source .venv/bin/activate
17 | pip install -e .
18 | ```
19 |
20 | ---
21 |
22 | ## Usage Example
23 |
24 | ```bash
25 | python -m dia.finetune \
26 | --config path/to/dia/config.json \
27 | --dataset Paradoxia/opendata-iisys-hui \
28 | --hub_model nari-labs/Dia-1.6B \
29 | --run_name my_experiment \
30 | --output_dir ./checkpoints \
31 | ```
32 |
33 | ---
34 |
35 | ## Configuration
36 |
37 | * **JSON Config**: `dia/config.json` defines model sizes, token lengths, delay patterns, and audio PAD/BOS/EOS values.
38 | * **TrainConfig**: Default hyperparameters (epochs, batch size, learning rate, warmup, logging & saving steps, etc.) are set in the `TrainConfig` class inside the fine-tuning script (a sketch follows this list).
39 | * **CLI Config**: Training settings can also be passed as flags to `python -m dia.finetune` (see the table below).
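
For orientation, here is a minimal sketch of what such a `TrainConfig` might look like; the field names and default values below are illustrative assumptions, not the actual settings in `finetune.py`:

```python
from dataclasses import dataclass
from pathlib import Path


# Illustrative sketch only: the real TrainConfig lives in the fine-tuning
# script and its exact fields and defaults may differ.
@dataclass
class TrainConfig:
    epochs: int = 1                  # passes over the training data
    batch_size: int = 2              # per-step batch size
    learning_rate: float = 1e-5      # base learning rate (assumed)
    warmup_steps: int = 500          # LR warmup steps
    eval_step: int = 200             # evaluate every N steps
    save_step: int = 2000            # save a checkpoint every N steps
    run_name: str = "dia_finetune"   # TensorBoard run directory
    output_dir: Path = Path("./checkpoints")
```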
40 |
41 | ---
42 |
43 | ## Major CLI Arguments
44 |
45 | | Argument                | Type   | Default                        | Description                                                        |
46 | | ----------------------- | ------ | ------------------------------ | ------------------------------------------------------------------ |
47 | | `--config`              | `Path` | `dia/config.json`              | Path to the Dia JSON config.                                        |
48 | | `--dataset`             | `str`  | `Paradoxia/opendata-iisys-hui` | HF dataset name (train split).                                      |
49 | | `--dataset2`            | `str`  | `None`                         | (Optional) Second HF dataset to interleave.                         |
50 | | `--streaming`           | `bool` | `True`                         | Use HF streaming API.                                               |
51 | | `--hub_model`           | `str`  | `nari-labs/Dia-1.6B`           | HF Hub repo for base checkpoint.                                    |
52 | | `--local_ckpt`          | `str`  | `None`                         | Path to local model checkpoint (`.pth`).                            |
53 | | `--csv_path`            | `Path` | `None`                         | CSV file with pipe-separated `audio\|transcript` pairs (e.g. `example.wav\|transcript`). |
54 | | `--audio_root`          | `Path` | `None`                         | Base directory for local audio files (required if `--csv_path`).    |
55 | | `--run_name`            | `str`  |                                | TensorBoard run directory name.                                     |
56 | | `--output_dir`          | `Path` |                                | Directory for saving checkpoints.                                   |
57 | | `--shuffle_buffer_size` | `int`  | `None`                         | Buffer size for streaming shuffle.                                  |
58 | | `--seed`                | `int`  | `42`                           | Random seed for reproducibility.                                    |
59 | | `--half`                | `bool` | `False`                        | Load model in FP16.                                                 |
60 | | `--compile`             | `bool` | `False`                        | Enable `torch.compile` (Inductor backend).                          |
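
For local training data (`--csv_path` plus `--audio_root`), the CSV is assumed to hold pipe-separated `audio|transcript` pairs, with audio paths resolved relative to `--audio_root`. A hypothetical file could be written like this (the header name and exact parsing live in `dia/dataset.py` and may differ):

```python
from pathlib import Path

# Hypothetical local training CSV: one "audio|text" pair per line.
rows = [
    "audio|text",
    "clips/utt_0001.wav|[S1] Hello there. [S2] Hi, how are you?",
    "clips/utt_0002.wav|[S1] (laughs) That was a great take.",
]
Path("train.csv").write_text("\n".join(rows) + "\n", encoding="utf-8")
```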
61 |
62 | ---
63 |
64 | ## Monitoring & Evaluation
65 |
66 | * **TensorBoard**:
67 |
68 | ```bash
69 | tensorboard --logdir runs
70 | ```
71 |
72 | * `Loss/train`, `Loss/eval`, learning rate, and grad norm.
73 | * Audio samples for each test sentence (in multiple languages); the test sentences can be specified inside `finetune.py`.
74 |
75 | * **Checkpoints**: Saved in `output_dir` as `ckpt_step{N}.pth` and `ckpt_epoch{E}.pth`.
76 |
77 | ---
78 |
79 | ## Inference (Gradio UI)
80 |
81 | **Convert Checkpoint to fp32**
82 |
83 | If you used `--half` and `--compile` during training, you have to unwrap the checkpoint and convert it back to fp32:
84 | ```bash
85 | python -m dia.convert_ckpt \
86 | --input-ckpt /path/to/ckpt_epoch1.pth \
87 | --output-ckpt /path/to/ckpt_epoch1_fp32.pth \
88 | --config /path/to/config.json
89 | ```
90 |
91 | `app_local.py` is a Gradio-based web app for interactive text-to-speech synthesis. It provides sliders for the generation parameters and accepts an optional audio prompt.
92 |
93 | ```bash
94 | python app_local.py \
95 | --local_ckpt path/to/ckpt_epoch1_fp32.pth \
96 | --config path/to/inference/config.json
97 | ```
98 |
99 | Open the displayed URL in your browser to interact with the model.
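
If you prefer to skip the UI, the same converted checkpoint can also be loaded programmatically with `Dia.from_local`, mirroring what `app_local.py` and `cli.py` do (a minimal sketch; the paths are examples):

```python
import soundfile as sf
import torch

from dia.model import Dia

# Load the fp32 checkpoint together with its inference config (example paths).
model = Dia.from_local(
    config_path="dia/config_inference.json",
    checkpoint_path="checkpoints/ckpt_epoch1_fp32.pth",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

output = model.generate("[S1] Quick checkpoint smoke test. [S2] Sounds good to me.")
sf.write("smoke_test.wav", output, 44100)
```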
100 |
101 | ---
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 | Dia is a 1.6B parameter text to speech model created by Nari Labs.
124 |
125 | Dia **directly generates highly realistic dialogue from a transcript**. You can condition the output on audio, enabling emotion and tone control. The model can also produce nonverbal communications like laughter, coughing, clearing throat, etc.
126 |
127 | To accelerate research, we are providing access to pretrained model checkpoints and inference code. The model weights are hosted on [Hugging Face](https://huggingface.co/nari-labs/Dia-1.6B). The model only supports English generation at the moment.
128 |
129 | We also provide a [demo page](https://yummy-fir-7a4.notion.site/dia) comparing our model to [ElevenLabs Studio](https://elevenlabs.io/studio) and [Sesame CSM-1B](https://github.com/SesameAILabs/csm).
130 |
131 | - (Update) We have a ZeroGPU Space running! Try it now [here](https://huggingface.co/spaces/nari-labs/Dia-1.6B). Thanks to the HF team for the support :)
132 | - Join our [discord server](https://discord.gg/pgdB5YRe) for community support and access to new features.
133 | - Play with a larger version of Dia: generate fun conversations, remix content, and share with friends. 🔮 Join the [waitlist](https://tally.so/r/meokbo) for early access.
134 |
135 | ## ⚡️ Quickstart
136 |
137 | ### Install via pip
138 |
139 | ```bash
140 | # Install directly from GitHub
141 | pip install git+https://github.com/nari-labs/dia.git
142 | ```
143 |
144 | ### Run the Gradio UI
145 |
146 | This will open a Gradio UI that you can interact with.
147 |
148 | ```bash
149 | git clone https://github.com/nari-labs/dia.git
150 | cd dia && uv run app.py
151 | ```
152 |
153 | or if you do not have `uv` pre-installed:
154 |
155 | ```bash
156 | git clone https://github.com/nari-labs/dia.git
157 | cd dia
158 | python -m venv .venv
159 | source .venv/bin/activate
160 | pip install -e .
161 | python app.py
162 | ```
163 |
164 | Note that the model was not fine-tuned on a specific voice. Hence, you will get different voices every time you run the model.
165 | You can keep speaker consistency by either adding an audio prompt (a guide coming VERY soon - try it with the second example on Gradio for now), or fixing the seed.
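
For example, the seed can be fixed the same way `cli.py` does it (a condensed sketch of its `set_seed` helper):

```python
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Mirrors the seeding helper in cli.py for reproducible generations.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)  # fix the seed before calling model.generate(...)
```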
166 |
167 | ## Features
168 |
169 | - Generate dialogue via the `[S1]` and `[S2]` tags.
170 | - Generate non-verbal sounds like `(laughs)`, `(coughs)`, etc.
171 | - The non-verbal tags below will be recognized, but might result in unexpected output.
172 | - `(laughs), (clears throat), (sighs), (gasps), (coughs), (singing), (sings), (mumbles), (beep), (groans), (sniffs), (claps), (screams), (inhales), (exhales), (applause), (burps), (humming), (sneezes), (chuckle), (whistles)`
173 | - Voice cloning. See [`example/voice_clone.py`](example/voice_clone.py) for more information.
174 | - In the Hugging Face space, you can upload the audio you want to clone and place its transcript before your script. Make sure the transcript follows the required format. The model will then output only the content of your script.
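
A minimal sketch of that cloning flow (the prompt transcript and file names are placeholders; see [`example/voice_clone.py`](example/voice_clone.py) for the reference script):

```python
import soundfile as sf

from dia.model import Dia

model = Dia.from_pretrained("nari-labs/Dia-1.6B")

# Transcript of the prompt audio, followed by the new script to be spoken.
prompt_transcript = "[S1] This is the transcript of the audio prompt. "
script = "[S2] And this is the new line spoken in the cloned voice."

output = model.generate(
    prompt_transcript + script,
    audio_prompt_path="example_prompt.mp3",  # audio whose voice is cloned
)
sf.write("voice_clone.mp3", output, 44100)
```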
175 |
176 | ## ⚙️ Usage
177 |
178 | ### As a Python Library
179 |
180 | ```python
181 | import soundfile as sf
182 |
183 | from dia.model import Dia
184 |
185 |
186 | model = Dia.from_pretrained("nari-labs/Dia-1.6B")
187 |
188 | text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
189 |
190 | output = model.generate(text)
191 |
192 | sf.write("simple.mp3", output, 44100)
193 | ```
194 |
195 | A pypi package and a working CLI tool will be available soon.
196 |
197 | ## 💻 Hardware and Inference Speed
198 |
199 | Dia has only been tested on GPUs (PyTorch 2.0+, CUDA 12.6). CPU support will be added soon.
200 | The initial run will take longer as the Descript Audio Codec also needs to be downloaded.
201 |
202 | On enterprise GPUs, Dia can generate audio in real-time. On older GPUs, inference time will be slower.
203 | For reference, on an A4000 GPU, Dia generates roughly 40 tokens/s (86 tokens equal 1 second of audio).
204 | `torch.compile` will increase speeds for supported GPUs.
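
In this repository, compilation is requested per `generate` call via the `use_torch_compile` flag (the Gradio apps keep it off for stability); a minimal sketch:

```python
# Enable torch.compile for generation on supported GPUs. The first call is
# slower while the graph compiles; subsequent calls should speed up.
output = model.generate(
    "[S1] Compiled generation test. [S2] Should be faster after warm-up.",
    use_torch_compile=True,
)
```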
205 |
206 | The full version of Dia requires around 12-13GB of VRAM to run. We will be adding a quantized version in the future.
207 |
208 | If you don't have hardware available or if you want to play with bigger versions of our models, join the waitlist [here](https://tally.so/r/meokbo).
209 |
210 | ## 🪪 License
211 |
212 | This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
213 |
214 | ## ⚠️ Disclaimer
215 |
216 | This project offers a high-fidelity speech generation model intended for research and educational use. The following uses are **strictly forbidden**:
217 |
218 | - **Identity Misuse**: Do not produce audio resembling real individuals without permission.
219 | - **Deceptive Content**: Do not use this model to generate misleading content (e.g., fake news).
220 | - **Illegal or Malicious Use**: Do not use this model for activities that are illegal or intended to cause harm.
221 |
222 | By using this model, you agree to uphold relevant legal standards and ethical responsibilities. We **are not responsible** for any misuse and firmly oppose any unethical usage of this technology.
223 |
224 | ## 🔭 TODO / Future Work
225 |
226 | - Docker support.
227 | - Optimize inference speed.
228 | - Add quantization for memory efficiency.
229 |
230 | ## 🤝 Contributing
231 |
232 | We are a tiny team of 1 full-time and 1 part-time research engineer. Contributions are extremely welcome!
233 | Join our [Discord Server](https://discord.gg/pgdB5YRe) for discussions.
234 |
235 | ## 🤗 Acknowledgements
236 |
237 | - We thank the [Google TPU Research Cloud program](https://sites.research.google/trc/about/) for providing computation resources.
238 | - Our work was heavily inspired by [SoundStorm](https://arxiv.org/abs/2305.09636), [Parakeet](https://jordandarefsky.com/blog/2024/parakeet/), and [Descript Audio Codec](https://github.com/descriptinc/descript-audio-codec).
239 | - Hugging Face for providing the ZeroGPU Grant.
240 | - "Nari" is a pure Korean word for lily.
241 | - We thank Jason Y. for providing help with data filtering.
242 |
243 |
244 | ## ⭐ Star History
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import tempfile
3 | import time
4 | from pathlib import Path
5 | from typing import Optional, Tuple
6 |
7 | import gradio as gr
8 | import numpy as np
9 | import soundfile as sf
10 | import torch
11 |
12 | from dia.model import Dia
13 |
14 |
15 | # --- Global Setup ---
16 | parser = argparse.ArgumentParser(description="Gradio interface for Nari TTS")
17 | parser.add_argument("--device", type=str, default=None, help="Force device (e.g., 'cuda', 'mps', 'cpu')")
18 | parser.add_argument("--share", action="store_true", help="Enable Gradio sharing")
19 |
20 | args = parser.parse_args()
21 |
22 |
23 | # Determine device
24 | if args.device:
25 | device = torch.device(args.device)
26 | elif torch.cuda.is_available():
27 | device = torch.device("cuda")
28 | # Simplified MPS check for broader compatibility
29 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
30 | # Basic check is usually sufficient, detailed check can be problematic
31 | device = torch.device("mps")
32 | else:
33 | device = torch.device("cpu")
34 |
35 | print(f"Using device: {device}")
36 |
37 | # Load Nari model and config
38 | print("Loading Nari model...")
39 | try:
40 | # Use the function from inference.py
41 | model = Dia.from_pretrained("nari-labs/Dia-1.6B", device=device)
42 | except Exception as e:
43 | print(f"Error loading Nari model: {e}")
44 | raise
45 |
46 |
47 | def run_inference(
48 | text_input: str,
49 | audio_prompt_input: Optional[Tuple[int, np.ndarray]],
50 | max_new_tokens: int,
51 | cfg_scale: float,
52 | temperature: float,
53 | top_p: float,
54 | cfg_filter_top_k: int,
55 | speed_factor: float,
56 | ):
57 | """
58 | Runs Nari inference using the globally loaded model and provided inputs.
59 | Uses temporary files for text and audio prompt compatibility with inference.generate.
60 | """
61 | global model, device # Access global model, config, device
62 |
63 | if not text_input or text_input.isspace():
64 | raise gr.Error("Text input cannot be empty.")
65 |
66 | temp_txt_file_path = None
67 | temp_audio_prompt_path = None
68 | output_audio = (44100, np.zeros(1, dtype=np.float32))
69 |
70 | try:
71 | prompt_path_for_generate = None
72 | if audio_prompt_input is not None:
73 | sr, audio_data = audio_prompt_input
74 | # Check if audio_data is valid
75 | if audio_data is None or audio_data.size == 0 or audio_data.max() == 0: # Check for silence/empty
76 | gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
77 | else:
78 | # Save prompt audio to a temporary WAV file
79 | with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as f_audio:
80 | temp_audio_prompt_path = f_audio.name # Store path for cleanup
81 |
82 | # Basic audio preprocessing for consistency
83 | # Convert to float32 in [-1, 1] range if integer type
84 | if np.issubdtype(audio_data.dtype, np.integer):
85 | max_val = np.iinfo(audio_data.dtype).max
86 | audio_data = audio_data.astype(np.float32) / max_val
87 | elif not np.issubdtype(audio_data.dtype, np.floating):
88 | gr.Warning(f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion.")
89 | # Attempt conversion, might fail for complex types
90 | try:
91 | audio_data = audio_data.astype(np.float32)
92 | except Exception as conv_e:
93 | raise gr.Error(f"Failed to convert audio prompt to float32: {conv_e}")
94 |
95 | # Ensure mono (average channels if stereo)
96 | if audio_data.ndim > 1:
97 | if audio_data.shape[0] == 2: # Assume (2, N)
98 | audio_data = np.mean(audio_data, axis=0)
99 | elif audio_data.shape[1] == 2: # Assume (N, 2)
100 | audio_data = np.mean(audio_data, axis=1)
101 | else:
102 | gr.Warning(
103 | f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
104 | )
105 | audio_data = (
106 | audio_data[0] if audio_data.shape[0] < audio_data.shape[1] else audio_data[:, 0]
107 | )
108 | audio_data = np.ascontiguousarray(audio_data) # Ensure contiguous after slicing/mean
109 |
110 | # Write using soundfile
111 | try:
112 | sf.write(
113 | temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
114 | ) # Explicitly use FLOAT subtype
115 | prompt_path_for_generate = temp_audio_prompt_path
116 | print(f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})")
117 | except Exception as write_e:
118 | print(f"Error writing temporary audio file: {write_e}")
119 | raise gr.Error(f"Failed to save audio prompt: {write_e}")
120 |
121 | # 3. Run Generation
122 |
123 | start_time = time.time()
124 |
125 | # Use torch.inference_mode() context manager for the generation call
126 | with torch.inference_mode():
127 | output_audio_np = model.generate(
128 | text_input,
129 | max_tokens=max_new_tokens,
130 | cfg_scale=cfg_scale,
131 | temperature=temperature,
132 | top_p=top_p,
133 | use_cfg_filter=True,
134 | cfg_filter_top_k=cfg_filter_top_k, # Pass the value here
135 | use_torch_compile=False, # Keep False for Gradio stability
136 | audio_prompt_path=prompt_path_for_generate,
137 | )
138 |
139 | end_time = time.time()
140 | print(f"Generation finished in {end_time - start_time:.2f} seconds.")
141 |
142 | # 4. Convert Codes to Audio
143 | if output_audio_np is not None:
144 | # Get sample rate from the loaded DAC model
145 | output_sr = 44100
146 |
147 | # --- Slow down audio ---
148 | original_len = len(output_audio_np)
149 | # Ensure speed_factor is positive and not excessively small/large to avoid issues
150 | speed_factor = max(0.1, min(speed_factor, 5.0))
151 | target_len = int(original_len / speed_factor) # Target length based on speed_factor
152 | if target_len != original_len and target_len > 0: # Only interpolate if length changes and is valid
153 | x_original = np.arange(original_len)
154 | x_resampled = np.linspace(0, original_len - 1, target_len)
155 | resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
156 | output_audio = (
157 | output_sr,
158 | resampled_audio_np.astype(np.float32),
159 | ) # Use resampled audio
160 | print(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed.")
161 | else:
162 | output_audio = (
163 | output_sr,
164 | output_audio_np,
165 | ) # Keep original if calculation fails or no change
166 | print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
167 | # --- End slowdown ---
168 |
169 | print(f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}")
170 |
171 | else:
172 | print("\nGeneration finished, but no valid tokens were produced.")
173 | # Return default silence
174 | gr.Warning("Generation produced no output.")
175 |
176 | except Exception as e:
177 | print(f"Error during inference: {e}")
178 | import traceback
179 |
180 | traceback.print_exc()
181 | # Re-raise as Gradio error to display nicely in the UI
182 | raise gr.Error(f"Inference failed: {e}")
183 |
184 | finally:
185 | # 5. Cleanup Temporary Files defensively
186 | if temp_txt_file_path and Path(temp_txt_file_path).exists():
187 | try:
188 | Path(temp_txt_file_path).unlink()
189 | print(f"Deleted temporary text file: {temp_txt_file_path}")
190 | except OSError as e:
191 | print(f"Warning: Error deleting temporary text file {temp_txt_file_path}: {e}")
192 | if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
193 | try:
194 | Path(temp_audio_prompt_path).unlink()
195 | print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
196 | except OSError as e:
197 | print(f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}")
198 |
199 | return output_audio
200 |
201 |
202 | # --- Create Gradio Interface ---
203 | css = """
204 | #col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
205 | """
206 | # Attempt to load default text from example.txt
207 | default_text = "[S1] Dia is an open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] Wow. Amazing. (laughs) \n[S2] Try it now on Git hub or Hugging Face."
208 | example_txt_path = Path("./example.txt")
209 | if example_txt_path.exists():
210 | try:
211 | default_text = example_txt_path.read_text(encoding="utf-8").strip()
212 | if not default_text: # Handle empty example file
213 | default_text = "Example text file was empty."
214 | except Exception as e:
215 | print(f"Warning: Could not read example.txt: {e}")
216 |
217 |
218 | # Build Gradio UI
219 | with gr.Blocks(css=css) as demo:
220 | gr.Markdown("# Nari Text-to-Speech Synthesis")
221 |
222 | with gr.Row(equal_height=False):
223 | with gr.Column(scale=1):
224 | text_input = gr.Textbox(
225 | label="Input Text",
226 | placeholder="Enter text here...",
227 | value=default_text,
228 | lines=5, # Increased lines
229 | )
230 | audio_prompt_input = gr.Audio(
231 | label="Audio Prompt (Optional)",
232 | show_label=True,
233 | sources=["upload", "microphone"],
234 | type="numpy",
235 | )
236 | with gr.Accordion("Generation Parameters", open=False):
237 | max_new_tokens = gr.Slider(
238 | label="Max New Tokens (Audio Length)",
239 | minimum=860,
240 | maximum=3072,
241 | value=model.config.data.audio_length, # Use config default if available, else fallback
242 | step=50,
243 | info="Controls the maximum length of the generated audio (more tokens = longer audio).",
244 | )
245 | cfg_scale = gr.Slider(
246 | label="CFG Scale (Guidance Strength)",
247 | minimum=1.0,
248 | maximum=5.0,
249 | value=3.0, # Default from inference.py
250 | step=0.1,
251 | info="Higher values increase adherence to the text prompt.",
252 | )
253 | temperature = gr.Slider(
254 | label="Temperature (Randomness)",
255 | minimum=1.0,
256 | maximum=1.5,
257 | value=1.3, # Default from inference.py
258 | step=0.05,
259 | info="Lower values make the output more deterministic, higher values increase randomness.",
260 | )
261 | top_p = gr.Slider(
262 | label="Top P (Nucleus Sampling)",
263 | minimum=0.80,
264 | maximum=1.0,
265 | value=0.95, # Default from inference.py
266 | step=0.01,
267 | info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
268 | )
269 | cfg_filter_top_k = gr.Slider(
270 | label="CFG Filter Top K",
271 | minimum=15,
272 | maximum=50,
273 | value=30,
274 | step=1,
275 | info="Top k filter for CFG guidance.",
276 | )
277 | speed_factor_slider = gr.Slider(
278 | label="Speed Factor",
279 | minimum=0.8,
280 | maximum=1.0,
281 | value=0.94,
282 | step=0.02,
283 | info="Adjusts the speed of the generated audio (1.0 = original speed).",
284 | )
285 |
286 | run_button = gr.Button("Generate Audio", variant="primary")
287 |
288 | with gr.Column(scale=1):
289 | audio_output = gr.Audio(
290 | label="Generated Audio",
291 | type="numpy",
292 | autoplay=False,
293 | )
294 |
295 | # Link button click to function
296 | run_button.click(
297 | fn=run_inference,
298 | inputs=[
299 | text_input,
300 | audio_prompt_input,
301 | max_new_tokens,
302 | cfg_scale,
303 | temperature,
304 | top_p,
305 | cfg_filter_top_k,
306 | speed_factor_slider,
307 | ],
308 | outputs=[audio_output], # Add status_output here if using it
309 | api_name="generate_audio",
310 | )
311 |
312 | # Add examples (ensure the prompt path is correct or remove it if example file doesn't exist)
313 | example_prompt_path = "./example_prompt.mp3" # Adjust if needed
314 | examples_list = [
315 | [
316 | "[S1] Oh fire! Oh my goodness! What's the procedure? What to we do people? The smoke could be coming through an air duct! \n[S2] Oh my god! Okay.. it's happening. Everybody stay calm! \n[S1] What's the procedure... \n[S2] Everybody stay fucking calm!!!... Everybody fucking calm down!!!!! \n[S1] No! No! If you touch the handle, if its hot there might be a fire down the hallway! ",
317 | None,
318 | 3072,
319 | 3.0,
320 | 1.3,
321 | 0.95,
322 | 35,
323 | 0.94,
324 | ],
325 | [
326 | "[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share to your friends. \n[S2] This was Nari Labs.",
327 | example_prompt_path if Path(example_prompt_path).exists() else None,
328 | 3072,
329 | 3.0,
330 | 1.3,
331 | 0.95,
332 | 35,
333 | 0.94,
334 | ],
335 | ]
336 |
337 | if examples_list:
338 | gr.Examples(
339 | examples=examples_list,
340 | inputs=[
341 | text_input,
342 | audio_prompt_input,
343 | max_new_tokens,
344 | cfg_scale,
345 | temperature,
346 | top_p,
347 | cfg_filter_top_k,
348 | speed_factor_slider,
349 | ],
350 | outputs=[audio_output],
351 | fn=run_inference,
352 | cache_examples=False,
353 | label="Examples (Click to Run)",
354 | )
355 | else:
356 | gr.Markdown("_(No examples configured or example prompt file missing)_")
357 |
358 |
359 | # --- Launch the App ---
360 | if __name__ == "__main__":
361 | print("Launching Gradio interface...")
362 |
363 | # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
364 | # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
365 | demo.launch(share=args.share)
366 |
--------------------------------------------------------------------------------
/app_local.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import tempfile
3 | import time
4 | from pathlib import Path
5 | from typing import Optional, Tuple
6 |
7 | import gradio as gr
8 | import numpy as np
9 | import soundfile as sf
10 | import torch
11 |
12 | from dia.model import Dia
13 | from dia.config import DiaConfig
14 | from dia.layers import DiaModel
15 | import dac
16 |
17 |
18 | # --- Global Setup ---
19 | parser = argparse.ArgumentParser(description="Gradio interface for Nari TTS")
20 | parser.add_argument(
21 | "--device", type=str, default=None, help="Force device (e.g., 'cuda', 'mps', 'cpu')"
22 | )
23 | parser.add_argument("--share", action="store_true", help="Enable Gradio sharing")
24 | parser.add_argument("--local_ckpt", type=str, default="ckpt_epoch1_fp32.pth", help="path to your local checkpoint")
25 | parser.add_argument("--config", type=str, default="dia/config_inference.json", help="path to your inference")
26 | parser.add_argument("--half", type=bool, default=False, help="load model in fp16")
27 | parser.add_argument("--compile", type=bool, default=False, help="torch compile model")
28 |
29 | args = parser.parse_args()
30 |
31 |
32 | # Determine device
33 | if args.device:
34 | device = torch.device(args.device)
35 | elif torch.cuda.is_available():
36 | device = torch.device("cuda")
37 | # Simplified MPS check for broader compatibility
38 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
39 | # Basic check is usually sufficient, detailed check can be problematic
40 | device = torch.device("mps")
41 | else:
42 | device = torch.device("cpu")
43 |
44 | print(f"Using device: {device}")
45 |
46 | # Load Nari model and config
47 | print("Loading Nari model...")
48 | try:
49 | # Use the function from inference.py
50 | """cfg = DiaConfig.load("dia/config.json")
51 |
52 | ptmodel = DiaModel(cfg)
53 | if args.half:
54 | ptmodel=ptmodel.half()
55 | if args.compile:
56 | ptmodel = torch.compile(ptmodel, backend="inductor")
57 |
58 | state = torch.load(args.local_ckpt, map_location="cpu")
59 | ptmodel.load_state_dict(state)
60 | ptmodel = ptmodel.to(device).eval()
61 | #ptmodel.float()
62 | model = Dia(cfg, device)
63 | model.model = ptmodel
64 | dac_model = dac.DAC.load(dac.utils.download())
65 | dac_model = dac_model.to(device)
66 | model.dac_model=dac_model"""
67 |
68 |
69 | model = Dia.from_local(
70 | config_path=args.config,
71 | checkpoint_path=args.local_ckpt,
72 | device=device,)
73 |
74 | except Exception as e:
75 | print(f"Error loading Nari model: {e}")
76 | raise
77 |
78 |
79 | def run_inference(
80 | text_input: str,
81 | audio_prompt_input: Optional[Tuple[int, np.ndarray]],
82 | max_new_tokens: int,
83 | cfg_scale: float,
84 | temperature: float,
85 | top_p: float,
86 | cfg_filter_top_k: int,
87 | speed_factor: float,
88 | ):
89 | """
90 | Runs Nari inference using the globally loaded model and provided inputs.
91 | Uses temporary files for text and audio prompt compatibility with inference.generate.
92 | """
93 | global model, device # Access global model, config, device
94 |
95 | if not text_input or text_input.isspace():
96 | raise gr.Error("Text input cannot be empty.")
97 |
98 | temp_txt_file_path = None
99 | temp_audio_prompt_path = None
100 | output_audio = (44100, np.zeros(1, dtype=np.float32))
101 |
102 | try:
103 | prompt_path_for_generate = None
104 | if audio_prompt_input is not None:
105 | sr, audio_data = audio_prompt_input
106 | # Check if audio_data is valid
107 | if (
108 | audio_data is None or audio_data.size == 0 or audio_data.max() == 0
109 | ): # Check for silence/empty
110 | gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
111 | else:
112 | # Save prompt audio to a temporary WAV file
113 | with tempfile.NamedTemporaryFile(
114 | mode="wb", suffix=".wav", delete=False
115 | ) as f_audio:
116 | temp_audio_prompt_path = f_audio.name # Store path for cleanup
117 |
118 | # Basic audio preprocessing for consistency
119 | # Convert to float32 in [-1, 1] range if integer type
120 | if np.issubdtype(audio_data.dtype, np.integer):
121 | max_val = np.iinfo(audio_data.dtype).max
122 | audio_data = audio_data.astype(np.float32) / max_val
123 | elif not np.issubdtype(audio_data.dtype, np.floating):
124 | gr.Warning(
125 | f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion."
126 | )
127 | # Attempt conversion, might fail for complex types
128 | try:
129 | audio_data = audio_data.astype(np.float32)
130 | except Exception as conv_e:
131 | raise gr.Error(
132 | f"Failed to convert audio prompt to float32: {conv_e}"
133 | )
134 |
135 | # Ensure mono (average channels if stereo)
136 | if audio_data.ndim > 1:
137 | if audio_data.shape[0] == 2: # Assume (2, N)
138 | audio_data = np.mean(audio_data, axis=0)
139 | elif audio_data.shape[1] == 2: # Assume (N, 2)
140 | audio_data = np.mean(audio_data, axis=1)
141 | else:
142 | gr.Warning(
143 | f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
144 | )
145 | audio_data = (
146 | audio_data[0]
147 | if audio_data.shape[0] < audio_data.shape[1]
148 | else audio_data[:, 0]
149 | )
150 | audio_data = np.ascontiguousarray(
151 | audio_data
152 | ) # Ensure contiguous after slicing/mean
153 |
154 | # Write using soundfile
155 | try:
156 | sf.write(
157 | temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
158 | ) # Explicitly use FLOAT subtype
159 | prompt_path_for_generate = temp_audio_prompt_path
160 | print(
161 | f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})"
162 | )
163 | except Exception as write_e:
164 | print(f"Error writing temporary audio file: {write_e}")
165 | raise gr.Error(f"Failed to save audio prompt: {write_e}")
166 |
167 | # 3. Run Generation
168 |
169 | start_time = time.time()
170 |
171 | # Use torch.inference_mode() context manager for the generation call
172 | with torch.inference_mode():
173 | output_audio_np = model.generate(
174 | text_input,
175 | max_tokens=max_new_tokens,
176 | cfg_scale=cfg_scale,
177 | temperature=temperature,
178 | top_p=top_p,
179 | use_cfg_filter=True,
180 | cfg_filter_top_k=cfg_filter_top_k, # Pass the value here
181 | use_torch_compile=False, # Keep False for Gradio stability
182 | audio_prompt_path=prompt_path_for_generate,
183 | )
184 |
185 | end_time = time.time()
186 | print(f"Generation finished in {end_time - start_time:.2f} seconds.")
187 |
188 | # 4. Convert Codes to Audio
189 | if output_audio_np is not None:
190 | # Get sample rate from the loaded DAC model
191 | output_sr = 44100
192 |
193 | # --- Slow down audio ---
194 | original_len = len(output_audio_np)
195 | # Ensure speed_factor is positive and not excessively small/large to avoid issues
196 | speed_factor = max(0.1, min(speed_factor, 5.0))
197 | target_len = int(
198 | original_len / speed_factor
199 | ) # Target length based on speed_factor
200 | if (
201 | target_len != original_len and target_len > 0
202 | ): # Only interpolate if length changes and is valid
203 | x_original = np.arange(original_len)
204 | x_resampled = np.linspace(0, original_len - 1, target_len)
205 | resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
206 | output_audio = (
207 | output_sr,
208 | resampled_audio_np.astype(np.float32),
209 | ) # Use resampled audio
210 | print(
211 | f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
212 | )
213 | else:
214 | output_audio = (
215 | output_sr,
216 | output_audio_np,
217 | ) # Keep original if calculation fails or no change
218 | print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
219 | # --- End slowdown ---
220 |
221 | print(
222 | f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
223 | )
224 |
225 | else:
226 | print("\nGeneration finished, but no valid tokens were produced.")
227 | # Return default silence
228 | gr.Warning("Generation produced no output.")
229 |
230 | except Exception as e:
231 | print(f"Error during inference: {e}")
232 | import traceback
233 |
234 | traceback.print_exc()
235 | # Re-raise as Gradio error to display nicely in the UI
236 | raise gr.Error(f"Inference failed: {e}")
237 |
238 | finally:
239 | # 5. Cleanup Temporary Files defensively
240 | if temp_txt_file_path and Path(temp_txt_file_path).exists():
241 | try:
242 | Path(temp_txt_file_path).unlink()
243 | print(f"Deleted temporary text file: {temp_txt_file_path}")
244 | except OSError as e:
245 | print(
246 | f"Warning: Error deleting temporary text file {temp_txt_file_path}: {e}"
247 | )
248 | if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
249 | try:
250 | Path(temp_audio_prompt_path).unlink()
251 | print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
252 | except OSError as e:
253 | print(
254 | f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
255 | )
256 |
257 | return output_audio
258 |
259 |
260 | # --- Create Gradio Interface ---
261 | css = """
262 | #col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
263 | """
264 | # Attempt to load default text from example.txt
265 | default_text = ""
266 | example_txt_path = Path("./example.txt")
267 | if example_txt_path.exists():
268 | try:
269 | default_text = example_txt_path.read_text(encoding="utf-8").strip()
270 | if not default_text: # Handle empty example file
271 | default_text = "Example text file was empty."
272 | except Exception as e:
273 | print(f"Warning: Could not read example.txt: {e}")
274 |
275 |
276 | # Build Gradio UI
277 | with gr.Blocks(css=css) as demo:
278 | gr.Markdown("# Nari Text-to-Speech Synthesis")
279 |
280 | with gr.Row(equal_height=False):
281 | with gr.Column(scale=1):
282 | text_input = gr.Textbox(
283 | label="Input Text",
284 | placeholder="Enter text here...",
285 | value=default_text,
286 | lines=5, # Increased lines
287 | )
288 | audio_prompt_input = gr.Audio(
289 | label="Audio Prompt (Optional)",
290 | show_label=True,
291 | sources=["upload", "microphone"],
292 | type="numpy",
293 | )
294 | with gr.Accordion("Generation Parameters", open=False):
295 | max_new_tokens = gr.Slider(
296 | label="Max New Tokens (Audio Length)",
297 | minimum=860,
298 | maximum=3072,
299 | value=model.config.data.audio_length, # Use config default if available, else fallback
300 | step=50,
301 | info="Controls the maximum length of the generated audio (more tokens = longer audio).",
302 | )
303 | cfg_scale = gr.Slider(
304 | label="CFG Scale (Guidance Strength)",
305 | minimum=1.0,
306 | maximum=5.0,
307 | value=3.0, # Default from inference.py
308 | step=0.1,
309 | info="Higher values increase adherence to the text prompt.",
310 | )
311 | temperature = gr.Slider(
312 | label="Temperature (Randomness)",
313 | minimum=1.0,
314 | maximum=1.5,
315 | value=1.3, # Default from inference.py
316 | step=0.05,
317 | info="Lower values make the output more deterministic, higher values increase randomness.",
318 | )
319 | top_p = gr.Slider(
320 | label="Top P (Nucleus Sampling)",
321 | minimum=0.80,
322 | maximum=1.0,
323 | value=0.95, # Default from inference.py
324 | step=0.01,
325 | info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
326 | )
327 | cfg_filter_top_k = gr.Slider(
328 | label="CFG Filter Top K",
329 | minimum=15,
330 | maximum=50,
331 | value=35,
332 | step=1,
333 | info="Top k filter for CFG guidance.",
334 | )
335 | speed_factor_slider = gr.Slider(
336 | label="Speed Factor",
337 | minimum=0.8,
338 | maximum=1.0,
339 | value=1.0,
340 | step=0.02,
341 | info="Adjusts the speed of the generated audio (1.0 = original speed).",
342 | )
343 |
344 | run_button = gr.Button("Generate Audio", variant="primary")
345 |
346 | with gr.Column(scale=1):
347 | audio_output = gr.Audio(
348 | label="Generated Audio",
349 | type="numpy",
350 | autoplay=False,
351 | )
352 |
353 | # Link button click to function
354 | run_button.click(
355 | fn=run_inference,
356 | inputs=[
357 | text_input,
358 | audio_prompt_input,
359 | max_new_tokens,
360 | cfg_scale,
361 | temperature,
362 | top_p,
363 | cfg_filter_top_k,
364 | speed_factor_slider,
365 | ],
366 | outputs=[audio_output], # Add status_output here if using it
367 | api_name="generate_audio",
368 | )
369 |
370 |
371 | # --- Launch the App ---
372 | if __name__ == "__main__":
373 | print("Launching Gradio interface...")
374 | demo.launch(share=args.share, server_name="0.0.0.0")
375 |
--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import soundfile as sf
7 | import torch
8 |
9 | from dia.model import Dia
10 |
11 |
12 | def set_seed(seed: int):
13 | """Sets the random seed for reproducibility."""
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | if torch.cuda.is_available():
18 | torch.cuda.manual_seed(seed)
19 | torch.cuda.manual_seed_all(seed)
20 | # Ensure deterministic behavior for cuDNN (if used)
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 |
24 |
25 | def main():
26 | parser = argparse.ArgumentParser(description="Generate audio using the Dia model.")
27 |
28 | parser.add_argument("text", type=str, help="Input text for speech generation.")
29 | parser.add_argument(
30 | "--output", type=str, required=True, help="Path to save the generated audio file (e.g., output.wav)."
31 | )
32 |
33 | parser.add_argument(
34 | "--repo-id",
35 | type=str,
36 | default="nari-labs/Dia-1.6B",
37 | help="Hugging Face repository ID (e.g., nari-labs/Dia-1.6B).",
38 | )
39 | parser.add_argument(
40 | "--local-paths", action="store_true", help="Load model from local config and checkpoint files."
41 | )
42 |
43 | parser.add_argument(
44 | "--config", type=str, help="Path to local config.json file (required if --local-paths is set)."
45 | )
46 | parser.add_argument(
47 | "--checkpoint", type=str, help="Path to local model checkpoint .pth file (required if --local-paths is set)."
48 | )
49 | parser.add_argument(
50 | "--audio-prompt", type=str, default=None, help="Path to an optional audio prompt WAV file for voice cloning."
51 | )
52 |
53 | gen_group = parser.add_argument_group("Generation Parameters")
54 | gen_group.add_argument(
55 | "--max-tokens",
56 | type=int,
57 | default=None,
58 | help="Maximum number of audio tokens to generate (defaults to config value).",
59 | )
60 | gen_group.add_argument(
61 | "--cfg-scale", type=float, default=3.0, help="Classifier-Free Guidance scale (default: 3.0)."
62 | )
63 | gen_group.add_argument(
64 |         "--temperature", type=float, default=1.3, help="Sampling temperature (higher is more random, default: 1.3)."
65 | )
66 | gen_group.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling probability (default: 0.95).")
67 |
68 | infra_group = parser.add_argument_group("Infrastructure")
69 | infra_group.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
70 | infra_group.add_argument(
71 | "--device",
72 | type=str,
73 | default="cuda" if torch.cuda.is_available() else "cpu",
74 | help="Device to run inference on (e.g., 'cuda', 'cpu', default: auto).",
75 | )
76 |
77 | args = parser.parse_args()
78 |
79 | # Validation for local paths
80 | if args.local_paths:
81 | if not args.config:
82 | parser.error("--config is required when --local-paths is set.")
83 | if not args.checkpoint:
84 | parser.error("--checkpoint is required when --local-paths is set.")
85 | if not os.path.exists(args.config):
86 | parser.error(f"Config file not found: {args.config}")
87 | if not os.path.exists(args.checkpoint):
88 | parser.error(f"Checkpoint file not found: {args.checkpoint}")
89 |
90 | # Set seed if provided
91 | if args.seed is not None:
92 | set_seed(args.seed)
93 | print(f"Using random seed: {args.seed}")
94 |
95 | # Determine device
96 | device = torch.device(args.device)
97 | print(f"Using device: {device}")
98 |
99 | # Load model
100 | print("Loading model...")
101 | if args.local_paths:
102 | print(f"Loading from local paths: config='{args.config}', checkpoint='{args.checkpoint}'")
103 | try:
104 | model = Dia.from_local(args.config, args.checkpoint, device=device)
105 | except Exception as e:
106 | print(f"Error loading local model: {e}")
107 | exit(1)
108 | else:
109 | print(f"Loading from Hugging Face Hub: repo_id='{args.repo_id}'")
110 | try:
111 | model = Dia.from_pretrained(args.repo_id, device=device)
112 | except Exception as e:
113 | print(f"Error loading model from Hub: {e}")
114 | exit(1)
115 | print("Model loaded.")
116 |
117 | # Generate audio
118 | print("Generating audio...")
119 | try:
120 | sample_rate = 44100 # Default assumption
121 |
122 | output_audio = model.generate(
123 | text=args.text,
124 | audio_prompt_path=args.audio_prompt,
125 | max_tokens=args.max_tokens,
126 | cfg_scale=args.cfg_scale,
127 | temperature=args.temperature,
128 | top_p=args.top_p,
129 | )
130 | print("Audio generation complete.")
131 |
132 | print(f"Saving audio to {args.output}...")
133 | os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
134 |
135 | sf.write(args.output, output_audio, sample_rate)
136 | print(f"Audio successfully saved to {args.output}")
137 |
138 | except Exception as e:
139 | print(f"Error during audio generation or saving: {e}")
140 | exit(1)
141 |
142 |
143 | if __name__ == "__main__":
144 | main()
145 |
--------------------------------------------------------------------------------
/dia/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stlohrey/dia-finetuning/25501f2911a20b7211a42640f391a07d562fd2f6/dia/__init__.py
--------------------------------------------------------------------------------
/dia/audio.py:
--------------------------------------------------------------------------------
1 | import typing as tp
2 |
3 | import torch
4 |
5 | from .config import DataConfig
6 |
7 |
8 | def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
9 | """
10 | Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
11 | Negative t_idx => BOS; t_idx >= T => PAD.
12 | """
13 | delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
14 |
15 | t_idx_BxT = torch.broadcast_to(
16 | torch.arange(T, dtype=torch.int32)[None, :],
17 | [B, T],
18 | )
19 | t_idx_BxTx1 = t_idx_BxT[..., None]
20 | t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)
21 |
22 | b_idx_BxTxC = torch.broadcast_to(
23 | torch.arange(B, dtype=torch.int32).view(B, 1, 1),
24 | [B, T, C],
25 | )
26 | c_idx_BxTxC = torch.broadcast_to(
27 | torch.arange(C, dtype=torch.int32).view(1, 1, C),
28 | [B, T, C],
29 | )
30 |
31 | # We must clamp time indices to [0..T-1] so gather_nd equivalent won't fail
32 | t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
33 |
34 | indices_BTCx3 = torch.stack(
35 | [
36 | b_idx_BxTxC.reshape(-1),
37 | t_clamped_BxTxC.reshape(-1),
38 | c_idx_BxTxC.reshape(-1),
39 | ],
40 | dim=1,
41 | ).long() # Ensure indices are long type for indexing
42 |
43 | return t_idx_BxTxC, indices_BTCx3
44 |
45 |
46 | def apply_audio_delay(
47 | audio_BxTxC: torch.Tensor,
48 | pad_value: int,
49 | bos_value: int,
50 | precomp: tp.Tuple[torch.Tensor, torch.Tensor],
51 | ) -> torch.Tensor:
52 | """
53 | Applies the delay pattern to batched audio tokens using precomputed indices,
54 | inserting BOS where t_idx < 0 and PAD where t_idx >= T.
55 |
56 | Args:
57 | audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
58 | pad_value: the padding token
59 | bos_value: the BOS token
60 | precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices
61 |
62 | Returns:
63 | result_BxTxC: [B, T, C] delayed audio tokens
64 | """
65 | device = audio_BxTxC.device # Get device from input tensor
66 | t_idx_BxTxC, indices_BTCx3 = precomp
67 | t_idx_BxTxC = t_idx_BxTxC.to(device) # Move precomputed indices to device
68 | indices_BTCx3 = indices_BTCx3.to(device)
69 |
70 | # Equivalent of tf.gather_nd using advanced indexing
71 | # Ensure indices are long type if not already (build_delay_indices should handle this)
72 | gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
73 | gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
74 |
75 | # Create masks on the correct device
76 | mask_bos = t_idx_BxTxC < 0 # => place bos_value
77 | mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1] # => place pad_value
78 |
79 | # Create scalar tensors on the correct device
80 | bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
81 | pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
82 |
83 | # If mask_bos, BOS; else if mask_pad, PAD; else original gather
84 | # All tensors should now be on the same device
85 | result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))
86 |
87 | return result_BxTxC
88 |
89 |
90 | @torch.no_grad()
91 | @torch.inference_mode()
92 | def audio_to_codebook(
93 | model,
94 | input_values,
95 | data_config: DataConfig,
96 | padding_mask=None,
97 | sample_rate=44100,
98 | ):
99 | """
100 | Encodes the input audio waveform into discrete codes.
101 |
102 | Args:
103 | model: The model to use for encoding.
104 | input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
105 | Float values of the input audio waveform.
106 | padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
107 | Padding mask used to pad the `input_values`.
108 |         sample_rate (`int`, *optional*):
109 |             Sampling rate of the input audio signal.
110 |
111 | Returns:
112 | A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
113 |         factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
114 | `codebook` of shape `[batch_size, num_codebooks, frames]`.
115 | Scale is not used here.
116 |
117 | """
118 | audio_data = model.preprocess(input_values, sample_rate)
119 |
120 | if padding_mask is None:
121 | padding_mask = torch.ones_like(input_values).bool()
122 |
123 | _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None) # 1, C, T
124 | seq_length = encoded_frame.shape[2]
125 |
126 | t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
127 | B=1,
128 | T=seq_length,
129 | C=data_config.channels,
130 | delay_pattern=data_config.delay_pattern,
131 | )
132 |
133 | encoded_frame = apply_audio_delay(
134 | audio_BxTxC=encoded_frame.transpose(1, 2), # 1, T, C
135 | pad_value=data_config.audio_pad_value,
136 | bos_value=data_config.audio_bos_value,
137 | precomp=(t_idx_BxTxC, indices_BTCx3),
138 | )
139 |
140 | return encoded_frame
141 |
142 |
143 | def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
144 | """
145 | Precompute indices for the revert operation using PyTorch.
146 |
147 | Returns:
148 | A tuple (t_idx_BxTxC, indices_BTCx3) where:
149 | - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
150 | - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
151 | batch indices, clamped time indices, and channel indices.
152 | """
153 | # Use default device unless specified otherwise; assumes inputs might define device later
154 | device = None # Or determine dynamically if needed, e.g., from a model parameter
155 |
156 | delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
157 |
158 | t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
159 | t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
160 |
161 | t_idx_BxTxC = torch.minimum(
162 | t_idx_BT1 + delay_arr.view(1, 1, C),
163 | torch.tensor(T - 1, device=device),
164 | )
165 | b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
166 | c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
167 |
168 | indices_BTCx3 = torch.stack(
169 | [
170 | b_idx_BxTxC.reshape(-1),
171 | t_idx_BxTxC.reshape(-1),
172 | c_idx_BxTxC.reshape(-1),
173 | ],
174 |         dim=1,
175 | ).long() # Ensure indices are long type
176 |
177 | return t_idx_BxTxC, indices_BTCx3
178 |
179 |
180 | def revert_audio_delay(
181 | audio_BxTxC: torch.Tensor,
182 | pad_value: int,
183 | precomp: tp.Tuple[torch.Tensor, torch.Tensor],
184 | T: int,
185 | ) -> torch.Tensor:
186 | """
187 | Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).
188 |
189 | Args:
190 | audio_BxTxC: Input delayed audio tensor
191 | pad_value: Padding value for out-of-bounds indices
192 | precomp: Precomputed revert indices tuple containing:
193 | - t_idx_BxTxC: Time offset indices tensor
194 | - indices_BTCx3: Gather indices tensor for original audio
195 | T: Original sequence length before padding
196 |
197 | Returns:
198 | Reverted audio tensor with same shape as input
199 | """
200 | t_idx_BxTxC, indices_BTCx3 = precomp
201 | device = audio_BxTxC.device # Get device from input tensor
202 |
203 | # Move precomputed indices to the same device as audio_BxTxC if they aren't already
204 | t_idx_BxTxC = t_idx_BxTxC.to(device)
205 | indices_BTCx3 = indices_BTCx3.to(device)
206 |
207 | # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
208 | gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
209 | gathered_BxTxC = gathered_flat.view(audio_BxTxC.size()) # Use .size() for robust reshaping
210 |
211 | # Create pad_tensor on the correct device
212 | pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
213 | # Create T tensor on the correct device for comparison
214 | T_tensor = torch.tensor(T, device=device)
215 |
216 |     result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)  # PAD positions whose shifted time index falls past the original length T
217 |
218 | return result_BxTxC
219 |
220 |
221 | @torch.no_grad()
222 | @torch.inference_mode()
223 | def decode(
224 | model,
225 | audio_codes,
226 | ):
227 | """
228 | Decodes the given frames into an output audio waveform
229 | """
230 | if len(audio_codes) != 1:
231 | raise ValueError(f"Expected one frame, got {len(audio_codes)}")
232 |
233 | try:
234 | audio_values = model.quantizer.from_codes(audio_codes)
235 | audio_values = model.decode(audio_values[0])
236 |
237 | return audio_values
238 | except Exception as e:
239 | print(f"Error in decode method: {str(e)}")
240 | raise
241 |
242 |
243 | def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
244 | """Process a single codebook file to generate audio"""
245 | # Remove BOS token
246 | generated_codes = generated_codes[:, 1:]
247 |
248 | if generated_codes.shape[1] > T:
249 | generated_codes = generated_codes[:, :T]
250 |
251 | seq_length = generated_codes.shape[1]
252 |
253 | # Build revert indices
254 | t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)
255 |
256 | # Transpose and add batch dimension
257 | audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
258 | reverted_codebook = revert_audio_delay(
259 | audio_BxTxC=audio_BxTxC,
260 | pad_value=0,
261 | precomp=(t_idx_BxTxC, indices_BTCx3),
262 | T=seq_length,
263 | )
264 | reverted_codebook = reverted_codebook[:, :-30, :]
265 |
266 | codebook = reverted_codebook.transpose(1, 2)
267 |
268 | min_valid_index = 0
269 | max_valid_index = 1023
270 | invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
271 |
272 | num_invalid = torch.sum(invalid_mask).item()
273 | if num_invalid > 0:
274 | print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")
275 |
276 | # Set invalid values to 0 (modify the tensor in-place)
277 | codebook[invalid_mask] = 0
278 | audio_array = decode(model, codebook)
279 |
280 | return audio_array
281 |
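282 | # Round-trip sketch for the helpers above (shapes and special token values follow
283 | # dia/config.json; the random codes are purely illustrative):
284 | #
285 | #   B, T, C = 1, 128, 9
286 | #   delays = [0, 8, 9, 10, 11, 12, 13, 14, 15]
287 | #   codes = torch.randint(0, 1024, (B, T, C))
288 | #   pre = build_delay_indices(B, T, C, delays)
289 | #   delayed = apply_audio_delay(codes, pad_value=1025, bos_value=1026, precomp=pre)
290 | #   post = build_revert_indices(B, T, C, delays)
291 | #   restored = revert_audio_delay(delayed, pad_value=1025, precomp=post, T=T)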
--------------------------------------------------------------------------------
/dia/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.1",
3 | "model": {
4 | "encoder": {
5 | "n_layer": 12,
6 | "n_embd": 1024,
7 | "n_hidden": 4096,
8 | "n_head": 16,
9 | "head_dim": 128
10 | },
11 | "decoder": {
12 | "n_layer": 18,
13 | "n_embd": 2048,
14 | "n_hidden": 8192,
15 | "gqa_query_heads": 16,
16 | "cross_query_heads": 16,
17 | "kv_heads": 4,
18 | "gqa_head_dim": 128,
19 | "cross_head_dim": 128,
20 | "d_model" : 256
21 | },
22 | "src_vocab_size": 256,
23 | "tgt_vocab_size": 1028,
24 | "dropout": 0.0
25 | },
26 | "training": {
27 | "dtype": "bfloat16"
28 | },
29 | "data": {
30 | "text_length": 512,
31 | "audio_length": 1536,
32 | "channels": 9,
33 | "text_pad_value": 0,
34 | "audio_eos_value": 1024,
35 | "audio_pad_value": 1025,
36 | "audio_bos_value": 1026,
37 | "delay_pattern": [
38 | 0,
39 | 8,
40 | 9,
41 | 10,
42 | 11,
43 | 12,
44 | 13,
45 | 14,
46 | 15
47 | ]
48 | }
49 | }
--------------------------------------------------------------------------------
/dia/config.py:
--------------------------------------------------------------------------------
1 | """Configuration management module for the Dia model.
2 |
3 | This module provides comprehensive configuration management for the Dia model,
4 | utilizing Pydantic for validation. It defines configurations for data processing,
5 | model architecture (encoder and decoder), and training settings.
6 |
7 | Key components:
8 | - DataConfig: Parameters for data loading and preprocessing.
9 | - EncoderConfig: Architecture details for the encoder module.
10 | - DecoderConfig: Architecture details for the decoder module.
11 | - ModelConfig: Combined model architecture settings.
12 | - TrainingConfig: Training hyperparameters and settings.
13 | - DiaConfig: Master configuration combining all components.
14 | """
15 |
16 | import os
17 | from typing import Annotated
18 |
19 | from pydantic import BaseModel, BeforeValidator, Field
20 |
21 |
22 | class DataConfig(BaseModel, frozen=True):
23 | """Configuration for data loading and preprocessing.
24 |
25 | Attributes:
26 | text_length: Maximum length of text sequences (must be multiple of 128).
27 | audio_length: Maximum length of audio sequences (must be multiple of 128).
28 | channels: Number of audio channels.
29 | text_pad_value: Value used for padding text sequences.
30 | audio_eos_value: Value representing the end of audio sequences.
31 | audio_bos_value: Value representing the beginning of audio sequences.
32 | audio_pad_value: Value used for padding audio sequences.
33 | delay_pattern: List of delay values for each audio channel.
34 | """
35 |
36 | text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
37 | audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
38 | channels: int = Field(default=9, gt=0, multiple_of=1)
39 | text_pad_value: int = Field(default=0)
40 | audio_eos_value: int = Field(default=1024)
41 | audio_pad_value: int = Field(default=1025)
42 | audio_bos_value: int = Field(default=1026)
43 | delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])
44 |
45 | def __hash__(self) -> int:
46 | """Generate a hash based on all fields of the config."""
47 | return hash(
48 | (
49 | self.text_length,
50 | self.audio_length,
51 | self.channels,
52 | self.text_pad_value,
53 | self.audio_pad_value,
54 | self.audio_bos_value,
55 | self.audio_eos_value,
56 | tuple(self.delay_pattern),
57 | )
58 | )
59 |
60 |
61 | class EncoderConfig(BaseModel, frozen=True):
62 | """Configuration for the encoder component of the Dia model.
63 |
64 | Attributes:
65 | n_layer: Number of transformer layers.
66 | n_embd: Embedding dimension.
67 | n_hidden: Hidden dimension size in the MLP layers.
68 | n_head: Number of attention heads.
69 | head_dim: Dimension per attention head.
70 | mlp_activations: List of activation functions for the MLP layers.
71 | use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
72 | """
73 |
74 | n_layer: int = Field(gt=0)
75 | n_embd: int = Field(gt=0)
76 | n_hidden: int = Field(gt=0)
77 | n_head: int = Field(gt=0)
78 | head_dim: int = Field(gt=0)
79 | mlp_activations: list[str] = Field(default=["silu", "linear"])
80 | use_pre_norm: bool = Field(default=False)
81 |
82 |
83 | class DecoderConfig(BaseModel, frozen=True):
84 | """Configuration for the decoder component of the Dia model.
85 |
86 | Attributes:
87 | n_layer: Number of transformer layers.
88 | n_embd: Embedding dimension.
89 | n_hidden: Hidden dimension size in the MLP layers.
90 | gqa_query_heads: Number of query heads for grouped-query self-attention.
91 | kv_heads: Number of key/value heads for grouped-query self-attention.
92 | gqa_head_dim: Dimension per query head for grouped-query self-attention.
93 | cross_query_heads: Number of query heads for cross-attention.
94 | cross_head_dim: Dimension per cross-attention head.
95 | mlp_activations: List of activation functions for the MLP layers.
96 | use_pre_norm: Whether to use pre-normalization.
97 | """
98 |
99 | n_layer: int = Field(gt=0)
100 | n_embd: int = Field(gt=0)
101 | n_hidden: int = Field(gt=0)
102 | gqa_query_heads: int = Field(gt=0)
103 | kv_heads: int = Field(gt=0)
104 | gqa_head_dim: int = Field(gt=0)
105 | cross_query_heads: int = Field(gt=0)
106 | cross_head_dim: int = Field(gt=0)
107 | mlp_activations: list[str] = Field(default=["silu", "linear"])
108 | use_pre_norm: bool = Field(default=False)
109 |
110 |
111 | class ModelConfig(BaseModel, frozen=True):
112 | """Main configuration container for the Dia model architecture.
113 |
114 | Attributes:
115 | encoder: Configuration for the encoder component.
116 | decoder: Configuration for the decoder component.
117 | src_vocab_size: Size of the source (text) vocabulary.
118 | tgt_vocab_size: Size of the target (audio code) vocabulary.
119 | dropout: Dropout probability applied within the model.
120 | normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
121 | weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
122 | rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
123 | rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
124 | """
125 |
126 | encoder: EncoderConfig
127 | decoder: DecoderConfig
128 | src_vocab_size: int = Field(default=128, gt=0)
129 | tgt_vocab_size: int = Field(default=1028, gt=0)
130 | dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
131 | normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
132 | weight_dtype: str = Field(default="float32", description="Weight precision")
133 | rope_min_timescale: int = Field(default=1, description="Timescale For global Attention")
134 | rope_max_timescale: int = Field(default=10_000, description="Timescale For global Attention")
135 |
136 |
137 | class TrainingConfig(BaseModel, frozen=True):
138 | """Training process configuration and hyperparameters.
139 |
140 | Note: This configuration currently only includes precision settings.
141 | Other training parameters (like batch size, learning rate, optimizer settings)
142 | are assumed to be handled externally.
143 |
144 | Attributes:
145 | dtype: Data type for activations during training (e.g., "bfloat16", "float32").
146 | logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
147 | """
148 |
149 | dtype: str = Field(default="bfloat16", description="Activation precision")
150 | logits_dot_in_fp32: bool = Field(default=False)
151 |
152 |
153 | class DiaConfig(BaseModel, frozen=True):
154 | """Master configuration for the Dia model.
155 |
156 | Combines all sub-configurations into a single validated object.
157 |
158 | Attributes:
159 | version: Configuration version string.
160 | model: Model architecture configuration.
161 | training: Training process configuration (precision settings).
162 | data: Data loading and processing configuration.
163 | """
164 |
165 | version: str = Field(default="1.0")
166 | model: ModelConfig
167 | training: TrainingConfig
168 | data: DataConfig
169 |
170 | def save(self, path: str) -> None:
171 | """Save the current configuration instance to a JSON file.
172 |
173 |         Ensures the parent directory exists before writing.
174 |
175 | Args:
176 | path: The target file path to save the configuration.
177 |
180 | """
181 |         os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
182 | config_json = self.model_dump_json(indent=2)
183 | with open(path, "w") as f:
184 | f.write(config_json)
185 |
186 | @classmethod
187 | def load(cls, path: str) -> "DiaConfig | None":
188 | """Load and validate a Dia configuration from a JSON file.
189 |
190 | Args:
191 | path: The path to the configuration file.
192 |
193 | Returns:
194 | A validated DiaConfig instance if the file exists and is valid,
195 | otherwise None if the file is not found.
196 |
197 |         Raises:
198 |             pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
200 | """
201 | try:
202 | with open(path, "r") as f:
203 | content = f.read()
204 | return cls.model_validate_json(content)
205 | except FileNotFoundError:
206 | return None
207 |
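208 | # Usage sketch (the load path mirrors the config shipped in this repo; the save
209 | # target is illustrative):
210 | #
211 | #   cfg = DiaConfig.load("dia/config.json")
212 | #   if cfg is not None:
213 | #       print(cfg.model.decoder.n_layer, cfg.data.delay_pattern)
214 | #       cfg.save("ckpts/config_used.json")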
--------------------------------------------------------------------------------
/dia/config_inference.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.1",
3 | "model": {
4 | "encoder": {
5 | "n_layer": 12,
6 | "n_embd": 1024,
7 | "n_hidden": 4096,
8 | "n_head": 16,
9 | "head_dim": 128
10 | },
11 | "decoder": {
12 | "n_layer": 18,
13 | "n_embd": 2048,
14 | "n_hidden": 8192,
15 | "gqa_query_heads": 16,
16 | "cross_query_heads": 16,
17 | "kv_heads": 4,
18 | "gqa_head_dim": 128,
19 | "cross_head_dim": 128,
20 | "d_model" : 256
21 | },
22 | "src_vocab_size": 256,
23 | "tgt_vocab_size": 1028,
24 | "dropout": 0.0
25 | },
26 | "training": {
27 | "dtype": "float32"
28 | },
29 | "data": {
30 | "text_length": 512,
31 | "audio_length": 1536,
32 | "channels": 9,
33 | "text_pad_value": 0,
34 | "audio_eos_value": 1024,
35 | "audio_pad_value": 1025,
36 | "audio_bos_value": 1026,
37 | "delay_pattern": [
38 | 0,
39 | 8,
40 | 9,
41 | 10,
42 | 11,
43 | 12,
44 | 13,
45 | 14,
46 | 15
47 | ]
48 | }
49 | }
--------------------------------------------------------------------------------
/dia/convert_ckpt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from dia.layers import DiaModel # adjust your import if needed
4 | from dia.config import DiaConfig
5 |
6 | def convert_checkpoint(input_ckpt: str, output_ckpt: str, config_path: str):
7 | # select device
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 | # 1) Reconstruct exactly the same compiled model you saved
11 | dia_cfg = DiaConfig.load(config_path)
12 | model = DiaModel(dia_cfg).to(device)
13 | model = model.half()
14 | model = torch.compile(model, backend="inductor")
15 |
16 | # 2) Load your compiled/half checkpoint
17 | state = torch.load(input_ckpt, map_location=device)
18 | model.load_state_dict(state)
19 |
20 | # 3) Un-wrap to the original nn.Module
21 | orig = getattr(model, "_orig_mod", None) or getattr(model, "__wrapped__", None) or model
22 |
23 | # 4) Cast all params & buffers back to float32
24 | orig.float()
25 |
26 | # 5) Save its clean, float32 state_dict
27 | torch.save(orig.state_dict(), output_ckpt)
28 | print(f"Saved normal FP32 checkpoint to {output_ckpt}")
29 |
30 | def main():
31 | parser = argparse.ArgumentParser(
32 | description="Convert a compiled/half-precision checkpoint back to a standard FP32 state_dict."
33 | )
34 | parser.add_argument(
35 | "--input-ckpt", "-i",
36 | required=True,
37 | help="Path to the half-precision compiled checkpoint (.pth) to load"
38 | )
39 | parser.add_argument(
40 | "--output-ckpt", "-o",
41 | required=True,
42 | help="Path where the FP32 state_dict will be saved"
43 | )
44 | parser.add_argument(
45 | "--config", "-c",
46 | required=True,
47 | help="Path to your DiaConfig JSON file"
48 | )
49 |
50 | args = parser.parse_args()
51 | convert_checkpoint(args.input_ckpt, args.output_ckpt, args.config)
52 |
53 | if __name__ == "__main__":
54 | main()
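55 | 
56 | # Example invocation (paths are illustrative):
57 | #   python -m dia.convert_ckpt -i ckpts/ckpt_epoch1.pth -o ckpts/dia_fp32.pth -c dia/config.json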
--------------------------------------------------------------------------------
/dia/dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import torchaudio
5 | import pandas as pd
6 | from torch.utils.data import Dataset
7 |
8 | import dac
9 | from .config import DiaConfig
10 |
11 |
12 |
13 |
14 | class LocalDiaDataset(Dataset):
15 | """Load from a local CSV (sep='|') + an audio folder."""
16 | def __init__(self, csv_path: Path, audio_root: Path, config: DiaConfig, dac_model: dac.DAC):
17 | self.df = pd.read_csv(csv_path, sep=r"\s*\|\s*", engine="python",
18 |                               names=["audio", "text"])
19 | self.audio_root = audio_root
20 | self.config = config
21 | self.dac_model = dac_model
22 |
23 | def __len__(self) -> int:
24 | return len(self.df)
25 |
26 | def __getitem__(self, idx: int):
27 | row = self.df.iloc[idx]
28 | lang = row.get("language", None)
29 | text = f"[{lang}]" + row["text"] if lang else row["text"]
30 | audio_path = self.audio_root / row["audio"]
31 | waveform, sr = torchaudio.load(audio_path)
32 | if sr != 44100:
33 | waveform = torchaudio.functional.resample(waveform, sr, 44100)
34 | waveform = waveform.unsqueeze(0)
35 | with torch.no_grad():
36 | audio_tensor = self.dac_model.preprocess(
37 | waveform, 44100
38 | ).to(next(self.dac_model.parameters()).device)
39 | _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
40 | encoded = encoded.squeeze(0).transpose(0, 1)
41 | return text, encoded, waveform
42 |
43 |
44 | class HFDiaDataset(Dataset):
45 | def __init__(self, hf_dataset, config: DiaConfig, dac_model: dac.DAC):
46 | self.dataset = hf_dataset
47 | self.config = config
48 | self.dac_model = dac_model
49 |
50 | def __len__(self) -> int:
51 | return len(self.dataset)
52 |
53 | def __getitem__(self, idx: int):
54 | sample = self.dataset[idx]
55 | lang = sample.get("language", None)
56 | text = f"[{lang}]" + sample["text"] if lang else sample["text"]
57 | audio_info = sample["audio"]
58 | waveform = torch.tensor(audio_info["array"], dtype=torch.float32)
59 | if waveform.ndim == 1:
60 | waveform = waveform.unsqueeze(0).unsqueeze(0)
61 | elif waveform.ndim == 2:
62 | waveform = waveform.unsqueeze(0)
63 | sr = audio_info.get("sampling_rate", 44100)
64 | if sr != 44100:
65 | waveform = torchaudio.functional.resample(waveform, sr, 44100)
66 | with torch.no_grad():
67 | audio_tensor = (
68 | self.dac_model.preprocess(waveform, 44100)
69 | .to(next(self.dac_model.parameters()).device)
70 | )
71 | _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
72 | encoded = encoded.squeeze(0).transpose(0, 1)
73 | return text, encoded, waveform
74 |
75 |
76 |
77 | class HFDiaIterDataset(torch.utils.data.IterableDataset):
78 | """Iterable wrapper for a HF streaming Dataset that has `audio.array` & `text`."""
79 | def __init__(self, hf_iterable, config: DiaConfig, dac_model: dac.DAC):
80 | super().__init__()
81 | self.dataset = hf_iterable
82 | self.config = config
83 | self.dac_model = dac_model
84 |
85 | def __iter__(self):
86 | for sample in self.dataset:
87 | lang = sample.get("language", None)
88 | text = f"[{lang}]" + sample["text"] if lang else sample["text"]
89 | audio_info = sample['audio']
90 | waveform = torch.tensor(audio_info['array'], dtype=torch.float32)
91 | if waveform.ndim == 1:
92 | waveform = waveform.unsqueeze(0).unsqueeze(0)
93 | elif waveform.ndim == 2:
94 | waveform = waveform.unsqueeze(0)
95 | sr = audio_info.get('sampling_rate', 44100)
96 | if sr != 44100:
97 | waveform = torchaudio.functional.resample(waveform, sr, 44100)
98 | with torch.no_grad():
99 | audio_tensor = (
100 | self.dac_model.preprocess(waveform, 44100)
101 | .to(next(self.dac_model.parameters()).device)
102 | )
103 | _, encoded, *_ = self.dac_model.encode(audio_tensor, n_quantizers=None)
104 | encoded = encoded.squeeze(0).transpose(0, 1)
105 | yield text, encoded, waveform
106 |
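107 | # Construction sketch for the local variant (paths are placeholders):
108 | #
109 | #   dac_model = dac.DAC.load(dac.utils.download())
110 | #   cfg = DiaConfig.load("dia/config.json")
111 | #   ds = LocalDiaDataset(Path("data/metadata.csv"), Path("data/wavs"), cfg, dac_model)
112 | #   text, codes, waveform = ds[0]   # codes: (T, C) DAC codebook indices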
--------------------------------------------------------------------------------
/dia/finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | import tempfile
6 | from dataclasses import dataclass
7 | from pathlib import Path
8 |
9 | import torch
10 | import torchaudio
11 | import pandas as pd
12 | from torch.utils.data import Dataset, DataLoader, random_split
13 | from torch.cuda.amp import autocast
14 | from torch.utils.tensorboard import SummaryWriter
15 | from torch.nn.utils import clip_grad_norm_
16 | from transformers import get_scheduler
17 | import torch.nn.functional as F
18 | import bitsandbytes as bnb
19 | from tqdm import tqdm
20 | from datasets import load_dataset, interleave_datasets, get_dataset_config_names
21 | from huggingface_hub import hf_hub_download
22 | import math
23 | import gc
24 |
25 | import dac
26 | from .config import DiaConfig
27 | from .layers import DiaModel
28 | from .model import Dia
29 | from .audio import build_delay_indices, apply_audio_delay
30 | from .dataset import LocalDiaDataset, HFDiaDataset, HFDiaIterDataset
31 | from .interleaved_datasets import load_cml_tts_streamed, load_common_voice17_streamed
32 |
33 |
34 | # Configure logging
35 | logging.basicConfig(
36 | level=logging.INFO,
37 | format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
38 | )
39 | logger = logging.getLogger(__name__)
40 |
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 | torch.backends.cudnn.benchmark = True
43 |
44 | # Byte values that replace the leading "[lang]" tag when the text is encoded to UTF-8.
45 | LANG2BYTE = {
46 | "en": 3,
47 | "de": 4,
48 | "fr": 5,
49 | "es": 6,
50 | "it": 7,
51 | "nl": 14,
52 | "pl": 15,
53 | "pt": 16,
54 | "tr": 17,
55 | "hu": 18,
56 |
57 | }
58 |
59 | test_sentences = {
60 | "en": "In order to fully assess performance and the accuracy of language tags, this test sentence contains multiple subordinate clauses, varied punctuation, and a sufficient word count.",
61 | "de": "Um Leistung und die Korrektheit der Sprach-Tags umfassend zu prüfen, enthält dieser Testsatz mehrere Nebensätze, unterschiedliche Zeichensetzung und eine ausreichende Wortzahl.",
62 | #"fr": "Pour évaluer pleinement les performances et la précision des balises de langue, cette phrase de test comporte plusieurs propositions subordonnées, une ponctuation variée et un nombre de mots suffisant.",
63 | #"es": "Para evaluar completamente el rendimiento y la precisión de las etiquetas de idioma, esta frase de prueba incluye varias oraciones subordinadas, puntuación diversa y la cantidad de palabras necesaria.",
64 | #"it": "Per valutare appieno le prestazioni e la precisione dei tag di lingua, questa frase di prova contiene più proposizioni subordinate, punteggiatura varia e un numero adeguato di parole.",
65 | #"nl": "Om de prestaties en de nauwkeurigheid van de taaltags volledig te beoordelen, bevat deze testzin meerdere ondergeschikte zinnen, gevarieerde interpunctie en een voldoende woordenaantal.",
66 | #"pl": "Aby w pełni ocenić wydajność i poprawność tagów językowych, to zdanie testowe zawiera kilka zdań podrzędnych, zróżnicowaną interpunkcję i wystarczającą liczbę słów.",
67 | #"pt": "Para avaliar completamente o desempenho e a precisão das marcas de idioma, esta frase de teste contém várias orações subordinadas, pontuação diversa e um número adequado de palavras.",
68 | #"tr": "Akışı elemeden performansı ve dil etiketlerinin doğruluğunu tam olarak değerlendirmek için bu test cümlesi birden fazla yan cümle, çeşitli noktalama işaretleri ve yeterli kelime sayısı içerir.",
69 | #"hu": "A teljesítmény és a nyelvcímkék pontosságának átfogó értékeléséhez ez a tesztmondat több mellékmondatot, változatos írásjeleket és elegendő szószámot tartalmazza."
70 | }
71 |
72 | @dataclass
73 | class TrainConfig:
74 | epochs: int = 1
75 | batch_size: int = 2
76 | grad_accum_steps: int = 2
77 | learning_rate: float = 1e-5
78 | warmup_steps: int = 500
79 | unconditional_frac: float = 0.15
80 | eval_step: int = 200
81 | save_step: int = 2000
82 | split_ratio: float = 0.997
83 |     shuffle_buffer_size: int | None = None  # buffer size for streaming shuffle
84 | seed: int = 42 # seed for reproducibility
85 | runs_dir: Path = Path("runs")
86 | run_name: str = "dia_finetune_cv"
87 |     output_dir: Path = Path("ckpts/dia_finetune_cv")
88 |
89 |
90 | def get_args() -> argparse.Namespace:
91 | parser = argparse.ArgumentParser(description="Train the Dia audio model")
92 | parser.add_argument("--config", type=Path, default=Path("dia/config.json"))
93 | parser.add_argument("--dataset", type=str, default="Paradoxia/opendata-iisys-hui",
94 | help="HuggingFace dataset name (if not using --csv_path).")
95 | parser.add_argument("--dataset2", type=str, default=None,
96 | help="(Optional) second HF dataset to interleave (streaming)")
97 | parser.add_argument("--streaming",action="store_true",
98 | help="Enable HuggingFace dataset streaming")
99 | parser.add_argument("--hub_model", type=str, default="nari-labs/Dia-1.6B")
100 | parser.add_argument("--local_ckpt", type=str, default=None)
101 | parser.add_argument("--csv_path", type=Path, default=None,
102 | help="Path to local CSV/TSV file with `audio|text` (if you want to train locally).")
103 | parser.add_argument("--audio_root",type=Path, default=None,
104 | help="Root directory for local audio files (required if --csv_path is set).")
105 | parser.add_argument("--run_name", type=str, default=None)
106 | parser.add_argument("--output_dir",type=Path, default=None)
107 | parser.add_argument("--shuffle_buffer_size", type=int, default=None,
108 | help="Buffer size for streaming dataset shuffle.")
109 | parser.add_argument("--seed", type=int, default=42,
110 | help="Random seed for reproducibility.")
111 | parser.add_argument("--half", action="store_true", help="load model in fp16")
112 | parser.add_argument("--compile", action="store_true", help="torch compile model")
113 | return parser.parse_args()
114 |
115 |
116 |
117 | def collate_fn(batch, config: DiaConfig, device: torch.device):
118 | from torch.nn.functional import pad
119 |
120 | texts, encodings, waveforms = zip(*batch)
121 |
122 | # -- Text inputs ---------------------------------------------------------
123 |
124 | max_text = config.data.text_length
125 | pad_tok = config.data.text_pad_value
126 | text_ids = []
127 | for txt in texts:
128 | b_full = txt.encode('utf-8')
129 | # replace leading "[lang]" prefix
130 | for code, val in LANG2BYTE.items():
131 | prefix = f"[{code}]".encode('utf-8')
132 | if b_full.startswith(prefix):
133 | b_full = bytes([val]) + b_full[len(prefix):]
134 | break
135 | bts = b_full[:max_text]
136 | arr = list(bts) + [pad_tok] * (max_text - len(bts))
137 | text_ids.append(torch.tensor(arr, dtype=torch.long))
138 | src = torch.stack(text_ids).to(device)
139 | src_pos = torch.arange(max_text, device=device).unsqueeze(0).expand(src.size(0), -1)
140 | src_pad = src.ne(pad_tok)
141 | enc_self_attn_mask = (src_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)
142 |
143 | # -- Audio codes --------------------------------------------------------
144 |
145 | max_audio = config.data.audio_length
146 | # per-sample lengths (clipped to max_audio)
147 | seq_lens = [min(e.size(0), max_audio) for e in encodings]
148 | batch_max = max(seq_lens)
149 |
150 | # pad or trim each encoding to the batch max length
151 | padded = [pad(e, (0, 0, 0, batch_max - e.size(0))) if e.size(0) < batch_max else e[:batch_max]
152 | for e in encodings]
153 | codes = torch.stack(padded).to(device) # (B, T=batch_max, C)
154 |
155 | B, T, C = codes.shape
156 | t_idx, idxs = build_delay_indices(B, T, C, config.data.delay_pattern)
157 | delayed = apply_audio_delay(
158 | codes,
159 | config.data.audio_pad_value,
160 | config.data.audio_bos_value,
161 | (t_idx, idxs)
162 | )
163 | # ensure no longer than max_audio
164 | delayed = delayed[:, :max_audio, :]
165 |
166 | # -- Targets with per-sample EOS ----------------------------------------
167 |
168 | max_tgt_len = max_audio + 2
169 | pad_val = config.data.audio_pad_value
170 | bos_val = config.data.audio_bos_value
171 | eos_val = config.data.audio_eos_value
172 |
173 | tgt = torch.full((B, max_tgt_len, C), pad_val, dtype=torch.long, device=device)
174 | tgt[:, 0, :] = bos_val
175 | tgt_lens = []
176 | for i, L in enumerate(seq_lens):
177 | tgt[i, 1:1 + L, :] = delayed[i, :L, :]
178 | tgt[i, 1 + L, :] = eos_val
179 | tgt_lens.append(1 + L + 1)
180 |
181 | tgt_pos = torch.arange(max_tgt_len, device=device).unsqueeze(0).expand(B, -1)
182 | tgt_pad = tgt.ne(pad_val).any(-1)
183 |
184 | causal = torch.tril(torch.ones((max_tgt_len, max_tgt_len),
185 | dtype=torch.bool,
186 | device=device))
187 | dec_self_attn_mask = (tgt_pad.unsqueeze(2) & tgt_pad.unsqueeze(1) & causal).unsqueeze(1)
188 | dec_cross_attn_mask = (tgt_pad.unsqueeze(2) & src_pad.unsqueeze(1)).unsqueeze(1)
189 |
190 | return {
191 | 'src_tokens': src,
192 | 'src_positions': src_pos,
193 | 'enc_self_attn_mask': enc_self_attn_mask,
194 | 'tgt_tokens': tgt,
195 | 'tgt_positions': tgt_pos,
196 | 'dec_self_attn_mask': dec_self_attn_mask,
197 | 'dec_cross_attn_mask': dec_cross_attn_mask,
198 | 'waveforms': waveforms,
199 | 'raw_text': texts[0],
200 | 'tgt_lens': torch.tensor(tgt_lens, dtype=torch.long, device=device),
201 | }
202 |
203 | def setup_loaders(dataset, dia_cfg: DiaConfig, train_cfg: TrainConfig, device):
204 | collate = lambda b: collate_fn(b, dia_cfg, device)
205 | if isinstance(dataset, HFDiaIterDataset):
206 | total = getattr(dataset, "total_examples", None)
207 | if total is None:
208 | total = dataset.dataset.info.splits["train"].num_examples
209 | n_train = int(train_cfg.split_ratio * total)
210 | n_val = total - n_train
211 | if n_val <= 0:
212 | raise RuntimeError(f"No validation samples: total={total}, split_ratio={train_cfg.split_ratio}")
213 | base = dataset.dataset.shuffle(buffer_size=train_cfg.shuffle_buffer_size, seed=train_cfg.seed) if train_cfg.shuffle_buffer_size else dataset.dataset
214 | val_stream = base.take(n_val)
215 | train_stream = base.skip(n_val)
216 | train_ds = HFDiaIterDataset(train_stream, dia_cfg, dataset.dac_model)
217 | val_ds = HFDiaIterDataset(val_stream, dia_cfg, dataset.dac_model)
218 | train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=False, collate_fn=collate)
219 | train_loader.steps_per_epoch = math.ceil(n_train / train_cfg.batch_size)
220 | val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
221 | return train_loader, val_loader
222 | ds_len = len(dataset)
223 | n_train = int(train_cfg.split_ratio * ds_len)
224 | train_ds, val_ds = random_split(dataset, [n_train, ds_len - n_train])
225 | train_loader = DataLoader(train_ds, batch_size=train_cfg.batch_size, shuffle=True, collate_fn=collate)
226 | val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate)
227 | return train_loader, val_loader
228 |
229 |
230 |
231 | def setup_optimizer_and_scheduler(model, train_loader, train_cfg):
232 | opt = bnb.optim.AdamW8bit(model.parameters(), lr=train_cfg.learning_rate)
233 | # Determine steps per epoch: prefer len(), else use attached attribute
234 | try:
235 | steps_per_epoch = len(train_loader)
236 | except TypeError:
237 | if hasattr(train_loader, 'steps_per_epoch'):
238 | steps_per_epoch = train_loader.steps_per_epoch
239 | else:
240 | raise RuntimeError("Cannot determine steps_per_epoch for streaming loader")
241 | total_training_steps = steps_per_epoch * train_cfg.epochs
242 | sched = get_scheduler(
243 | 'cosine', opt,
244 |         num_warmup_steps=train_cfg.warmup_steps // train_cfg.grad_accum_steps,
245 |         num_training_steps=math.ceil(total_training_steps / train_cfg.grad_accum_steps),
246 | )
247 | return opt, sched
248 |
249 |
250 |
251 | def train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step, global_step):
252 | """
253 | Perform a single training step: forward, loss, backward, update, log.
254 |     Uses per-sample tgt_lens to mask out padding after each EOS,
255 | and applies 4× loss weight on the first channel.
256 | """
257 |     # Classifier-free guidance training: with probability `unconditional_frac`,
258 |     # drop the text conditioning by zeroing the source tokens and attention masks.
259 |     if random.random() < train_cfg.unconditional_frac:
260 | batch['src_tokens'] = torch.zeros_like(batch['src_tokens'])
261 | batch['enc_self_attn_mask'] = torch.zeros_like(batch['enc_self_attn_mask'])
262 | batch['dec_cross_attn_mask'] = torch.zeros_like(batch['dec_cross_attn_mask'])
263 |
264 | with autocast():
265 | # forward pass
266 | logits = model(
267 | src_BxS=batch['src_tokens'],
268 | tgt_BxTxC=batch['tgt_tokens'],
269 | src_positions=batch['src_positions'],
270 | tgt_positions=batch['tgt_positions'],
271 | enc_self_attn_mask=batch['enc_self_attn_mask'],
272 | dec_self_attn_mask=batch['dec_self_attn_mask'],
273 | dec_cross_attn_mask=batch['dec_cross_attn_mask'],
274 | enable_dropout=True,
275 | )
276 | # fetch per-sample target‑lengths (including BOS+frames+EOS)
277 | lens = batch['tgt_lens'] # shape: (B,)
278 | max_L = int(lens.max().item()) # maximum over batch
279 |
280 | # keep only up through the last possible EOS slot
281 | # logits: (B, T, C, V) -> (B, max_L-1, C, V)
282 | logits = logits[:, : max_L - 1]
283 |
284 | # targets: shift off the BOS so 0.. align with logits
285 | # target: (B, T, C) -> (B, max_L-1, C)
286 | target = batch['tgt_tokens'][:, 1:max_L, :]
287 |
288 | B, Tm1, C = target.shape
289 | pad_val = dia_cfg.data.audio_pad_value
290 |
291 | # build a mask [B x (max_L-1)] that is True for t < (lens[i]-1)
292 | time_idx = torch.arange(Tm1, device=lens.device).unsqueeze(0) # (1, Tm1)
293 | valid_time = time_idx < (lens.unsqueeze(1) - 1) # (B, Tm1)
294 | mask = valid_time.unsqueeze(-1).expand(-1, -1, C) # (B, Tm1, C)
295 |
296 | # apply 4× weight on first channel, 1× on others
297 | channel_weights = [4.0] + [1.0] * (C - 1)
298 | loss_c = 0.0
299 | _, _, _, V = logits.size()
300 |
301 | for c, w in enumerate(channel_weights):
302 | # flatten this channel
303 | lc = logits[:, :, c, :].reshape(-1, V) # (B*Tm1, V)
304 | tc = target[:, :, c].reshape(-1) # (B*Tm1,)
305 | mc = mask[:, :, c].reshape(-1) # (B*Tm1,)
306 |
307 | # mask out padding and compute cross-entropy
308 | lc_valid = lc[mc]
309 | tc_valid = tc[mc]
310 | loss_c += w * F.cross_entropy(
311 | lc_valid, tc_valid,
312 | ignore_index=pad_val
313 | )
314 |
315 | # normalize by sum of weights
316 | loss = loss_c / sum(channel_weights)
317 |
318 | # scale + backward
319 | loss = loss / train_cfg.grad_accum_steps
320 | loss.backward()
321 |
322 | # step & log
323 |
324 |     grad_norm = clip_grad_norm_(model.parameters(), max_norm=1e9)  # max_norm is effectively unbounded; the call is only used to measure the global grad norm
325 | writer.add_scalar('GradNorm/global', grad_norm, global_step)
326 | if (step + 1) % train_cfg.grad_accum_steps == 0:
327 | opt.step()
328 | sched.step()
329 | opt.zero_grad()
330 | true_loss = loss.item() * train_cfg.grad_accum_steps
331 | current_lr = sched.get_last_lr()[0]
332 | writer.add_scalar('LR', current_lr, global_step)
333 | writer.add_scalar('Loss/train', true_loss, global_step)
334 |
335 | return loss.item() * train_cfg.grad_accum_steps
336 |
337 |
338 |
339 | def eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step):
340 | """
341 | Run evaluation: compute average loss on validation set and log audio samples.
342 | """
343 | eval_losses = []
344 | last_batch = None
345 | with torch.inference_mode():
346 | for eb in tqdm(val_loader, desc="eval"):
347 | last_batch = eb
348 |
349 | # 1) do your forward in mixed precision
350 | with autocast():
351 | logits16 = model(
352 | src_BxS=eb['src_tokens'],
353 | tgt_BxTxC=eb['tgt_tokens'],
354 | src_positions=eb['src_positions'],
355 | tgt_positions=eb['tgt_positions'],
356 | enc_self_attn_mask=eb['enc_self_attn_mask'],
357 | dec_self_attn_mask=eb['dec_self_attn_mask'],
358 | dec_cross_attn_mask=eb['dec_cross_attn_mask'],
359 | enable_dropout=False,
360 | )[:, :-1]
361 |
362 | logits = logits16.float()
363 | target = eb['tgt_tokens'][:, 1:]
364 | B_e, T_e, C_e = target.shape
365 | V_e = logits.size(-1)
366 |
367 | loss_e = 0.0
368 | weights_e = [4.0] + [1.0] * (C_e - 1)
369 | for c, w in enumerate(weights_e):
370 | lc = logits[:, :, c, :].reshape(-1, V_e)
371 | tc = target[:, :, c].reshape(-1)
372 | loss_e += w * F.cross_entropy(
373 | lc, tc, ignore_index=dia_cfg.data.audio_pad_value
374 | )
375 | loss_e = loss_e / sum(weights_e)
376 |
377 | eval_losses.append(loss_e)
378 |
379 | avg_eval_loss = sum(eval_losses) / len(eval_losses)
380 | writer.add_scalar('Loss/eval', avg_eval_loss.item(), global_step)
381 |
382 | try:
383 | orig_dtype = next(model.parameters()).dtype
384 | model = model.float()
385 | dia_gen = Dia(dia_cfg, device)
386 | dia_gen.model, dia_gen.dac_model = model, dac_model
387 | with torch.inference_mode():
388 | for lang_code, sentence in test_sentences.items():
389 | text = f"[{lang_code}]{sentence}"
390 |                 try:
391 |                     audio = dia_gen.generate(text=text)
392 |                     writer.add_audio(f"Eval/{lang_code}", audio, global_step, 44100)
393 |                     del audio
394 |                 except Exception:
395 |                     logger.exception(f"Error synthesizing test sentence in {lang_code}.")
396 |                 gc.collect()
397 |                 torch.cuda.empty_cache()
398 |
399 | except Exception:
400 | logger.exception("Eval error")
401 |
402 | finally:
403 | if orig_dtype == torch.float16:
404 | model = model.half()
405 |
406 |
407 | def train(model, dia_cfg: DiaConfig, dac_model: dac.DAC, dataset, train_cfg: TrainConfig):
408 | """
409 | Run the full training loop over epochs.
410 | """
411 | # prepare directories
412 | train_cfg.output_dir.mkdir(parents=True, exist_ok=True)
413 | (train_cfg.runs_dir / train_cfg.run_name).mkdir(parents=True, exist_ok=True)
414 | model = model.to(device)
415 |
416 | train_loader, val_loader = setup_loaders(dataset, dia_cfg, train_cfg, device)
417 | opt, sched = setup_optimizer_and_scheduler(model, train_loader, train_cfg)
418 |
419 | writer = SummaryWriter(train_cfg.runs_dir / train_cfg.run_name)
420 | model.train()
421 |
422 | steps_per_epoch = getattr(train_loader, 'steps_per_epoch', None)
423 | if steps_per_epoch is None:
424 | try:
425 | steps_per_epoch = len(train_loader)
426 | except Exception:
427 | steps_per_epoch = None
428 |
429 | for epoch in range(train_cfg.epochs):
430 | # iterate with progress bar, using total if known
431 | loader_iter = tqdm(
432 | train_loader,
433 | desc=f"E{epoch+1}",
434 | total=steps_per_epoch
435 | )
436 | for step, batch in enumerate(loader_iter):
437 | global_step = epoch * (steps_per_epoch or 0) + step
438 | # training step
439 |             loss = train_step(model, batch, dia_cfg, train_cfg, opt, sched, writer, step, global_step)
440 |
441 | cur_alloc = torch.cuda.memory_allocated() # bytes currently allocated by tensors
442 | peak_alloc = torch.cuda.max_memory_allocated() # bytes peak during program
443 | # optionally convert to GB
444 | cur_gb = cur_alloc / 1024**3
445 | peak_gb = peak_alloc / 1024**3
446 |
447 | # update the tqdm postfix
448 | loader_iter.set_postfix({
449 | 'loss': f"{loss:.4f}",
450 | 'VRAM (GB)': f"{cur_gb:.2f}/{peak_gb:.2f}"
451 | })
452 |
453 |             # reset the peak statistic so the next step reports its own peak
454 | torch.cuda.reset_peak_memory_stats()
455 |
456 |
457 | # evaluation
458 | if step % train_cfg.eval_step == 0:
459 | model.eval()
460 | with torch.no_grad():
461 | eval_step(model, val_loader, dia_cfg, dac_model, writer, global_step)
462 | model.train()
463 |
464 | # checkpoint
465 | if step and step % train_cfg.save_step == 0:
466 | ckpt = train_cfg.output_dir / f"ckpt_step{global_step}.pth"
467 | torch.save(model.state_dict(), ckpt)
468 | logger.info(f"Saved checkpoint: {ckpt}")
469 |
470 | # end of epoch checkpoint
471 | ckpt_e = train_cfg.output_dir / f"ckpt_epoch{epoch+1}.pth"
472 | torch.save(model.state_dict(), ckpt_e)
473 | logger.info(f"Saved end-of-epoch checkpoint: {ckpt_e}")
474 |
475 | 
483 | def main():
484 | args = get_args()
485 | dia_cfg = DiaConfig.load(args.config)
486 | dac_model = dac.DAC.load(dac.utils.download()).to(device)
487 |
488 |
489 |     dataset = None
490 | 
491 |     # Optional interleaved streaming loaders (see dia/interleaved_datasets.py):
492 |     # dataset = load_cml_tts_streamed(dia_cfg, dac_model)
493 |     # dataset = load_common_voice17_streamed(dia_cfg, dac_model)
494 |
495 | # choose dataset
496 | if not dataset:
497 | if args.csv_path:
498 | if not args.audio_root:
499 | raise ValueError("`--audio_root` must be set when using `--csv_path`")
500 | dataset = LocalDiaDataset(args.csv_path, args.audio_root, dia_cfg, dac_model)
501 | else:
502 | # load one or two streaming HF datasets
503 | ds1 = load_dataset(args.dataset, split="train", streaming=args.streaming)
504 |
505 | if args.streaming:
506 | if args.dataset2:
507 | ds2 = load_dataset(args.dataset2, split="train", streaming=True)
508 | # sum their lengths
509 | total1 = ds1.info.splits['train'].num_examples
510 | total2 = ds2.info.splits['train'].num_examples
511 | total = total1 + total2
512 | hf_ds = interleave_datasets([ds1, ds2])
513 | dataset = HFDiaIterDataset(hf_ds, dia_cfg, dac_model)
514 | # attach total examples for loader
515 | dataset.total_examples = total
516 | else:
517 | hf_ds = ds1
518 | dataset = HFDiaIterDataset(hf_ds, dia_cfg, dac_model)
519 | else:
520 | dataset = HFDiaDataset(ds1, dia_cfg, dac_model)
521 |
522 |
523 |
524 | train_cfg = TrainConfig(
525 | run_name = args.run_name or TrainConfig.run_name,
526 | output_dir = args.output_dir or TrainConfig.output_dir,
527 | shuffle_buffer_size = args.shuffle_buffer_size,
528 | seed = args.seed,
529 | )
530 |
531 | # load model checkpoint
532 | if args.local_ckpt:
533 | ckpt_file = args.local_ckpt
534 | else:
535 | ckpt_file = hf_hub_download(args.hub_model, filename="dia-v0_1.pth")
536 | model = DiaModel(dia_cfg)
537 | if args.half:
538 | model=model.half()
539 | if args.compile:
540 | model = torch.compile(model, backend="inductor")
541 | model.load_state_dict(torch.load(ckpt_file, map_location="cpu"))
542 |
543 |
544 | # start training
545 | train(model, dia_cfg, dac_model, dataset, train_cfg)
546 |
547 |
548 | if __name__ == "__main__":
549 | main()
550 |
551 |
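552 | # Launch sketches (flags as defined in get_args above; dataset names and paths are
553 | # illustrative):
554 | #
555 | #   # Finetune on a local `audio|text` CSV:
556 | #   python -m dia.finetune --config dia/config.json --csv_path data/metadata.csv --audio_root data/wavs
557 | #
558 | #   # Finetune on a streaming HuggingFace dataset:
559 | #   python -m dia.finetune --config dia/config.json --dataset Paradoxia/opendata-iisys-hui --streaming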
--------------------------------------------------------------------------------
/dia/interleaved_datasets.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, get_dataset_config_names, interleave_datasets, load_dataset_builder
2 | from .dataset import HFDiaIterDataset
3 | import pandas as pd
4 | from huggingface_hub import hf_hub_download
5 |
6 |
7 | LANG_NAME_TO_CODE = {
8 | "dutch": "nl",
9 | "french": "fr",
10 | "german": "de",
11 | "italian": "it",
12 | "polish": "pl",
13 | "portuguese": "pt",
14 | "spanish": "es",
15 | # add more if other configs appear...
16 | }
17 |
18 |
19 |
20 |
21 |
22 |
23 | def load_cml_tts_streamed(dia_cfg, dac_model):
24 | """
25 | Stream all language subsets of the CML-TTS dataset in train split,
26 | add a `language` field, drop all except `text`, `audio`, `language`,
27 | and interleave them into one streaming Dataset.
28 |
29 | Returns:
30 | datasets.IterableDataset: interleaved streaming dataset
31 | """
32 | # 1) Discover all language subsets
33 | lang_configs = get_dataset_config_names("ylacombe/cml-tts")
34 |
35 | # 2) Build one streaming subset per language, with only desired columns
36 | streams = []
37 | num_ex=0
38 | for lang in lang_configs:
39 |
40 | iso_code = LANG_NAME_TO_CODE.get(lang, lang)
41 | ds_stream = load_dataset(
42 | "ylacombe/cml-tts",
43 | name=lang,
44 | split="train",
45 | streaming=True
46 | )
47 |
48 | num_ex += ds_stream.info.splits['train'].num_examples
49 | # keep only text, audio, and add language
50 | def _add_lang(ex, iso=iso_code):
51 | return {
52 | "text": ex["text"],
53 | "audio": ex["audio"],
54 | "language": iso
55 | }
56 | ds_stream = ds_stream.map(
57 | _add_lang,
58 | remove_columns=[c for c in ds_stream.column_names if c not in ["text", "audio", "language"]]
59 | )
60 | streams.append(ds_stream)
61 |
62 | # 3) Interleave all streams into one unified stream
63 | interleaved = interleave_datasets(streams, stopping_strategy="all_exhausted")
64 | ds = HFDiaIterDataset(interleaved, dia_cfg, dac_model)
65 | ds.total_examples = num_ex
66 | return ds
67 |
68 |
69 |
70 |
71 |
72 |
73 | def count_tsv_rows(
74 | repo_id: str,
75 | subset: str,
76 | split: str = "train",
77 | revision: str = "main"
78 | ) -> int:
79 | """Download the TSV for a given subset/split and return its number of rows."""
80 | file_path = f"transcript/{subset}/{split}.tsv"
81 | try:
82 | local_file = hf_hub_download(
83 | repo_id=repo_id,
84 | filename=file_path,
85 | repo_type="dataset",
86 | revision=revision
87 | )
88 |     except Exception as e:
89 |         raise RuntimeError(f"Error fetching TSV metadata for {subset}/{split}") from e
90 |
91 | df = pd.read_csv(local_file, sep="\t", low_memory=False)
92 | return len(df)
93 |
94 | def load_common_voice17_streamed(dia_cfg, dac_model, revision="main"):
95 | """
96 | Stream the train split of Common Voice 17 for the given language codes,
97 | rename `sentence`→`text`, keep only `text`, `audio`, and `language`,
98 | then interleave into a single streaming Dataset.
99 |
100 | Languages loaded: en, de, fr, es, it, nl, pl, pt, tr, hu
101 | """
102 | repo_id = "mozilla-foundation/common_voice_17_0"
103 | langs = ["en", "de", "fr", "es", "it", "nl", "pl", "pt", "tr", "hu"]
104 |
105 | streams = []
106 | row_counts = []
107 |
108 | for lang in langs:
109 | # 1) figure out how many rows in the TSV
110 | n_rows = count_tsv_rows(repo_id, lang, split="train", revision=revision)
111 | row_counts.append(n_rows)
112 |
113 | # 2) load in streaming mode
114 | ds_stream = load_dataset(
115 | repo_id,
116 | name=lang,
117 | split="train",
118 | streaming=True,
119 | revision=revision
120 | )
121 |
122 | # 3) map to desired schema
123 | def _prep(ex, iso=lang):
124 | return {
125 | "text": ex["sentence"],
126 | "audio": ex["audio"],
127 | "language": iso
128 | }
129 |
130 | ds_stream = ds_stream.map(
131 | _prep,
132 | remove_columns=[c for c in ds_stream.column_names if c not in ("sentence", "audio")]
133 | )
134 | streams.append(ds_stream)
135 |
136 | # 4) interleave: all_exhausted ⇒ max_length * num_streams
137 | interleaved = interleave_datasets(streams, stopping_strategy="all_exhausted")
138 |
139 | # 5) wrap and attach total_examples
140 | ds = HFDiaIterDataset(interleaved, dia_cfg, dac_model)
141 | ds.total_examples = max(row_counts) * len(langs)
142 |
143 | return ds
144 |
145 |
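146 | # Usage sketch (mirrors the commented-out calls in finetune.main):
147 | #
148 | #   ds = load_cml_tts_streamed(dia_cfg, dac_model)   # IterableDataset wrapper; .total_examples is set
149 | #   train_loader, val_loader = setup_loaders(ds, dia_cfg, train_cfg, device)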
--------------------------------------------------------------------------------
/dia/layers.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 | from torch.nn import RMSNorm
8 |
9 | from .config import DiaConfig
10 |
11 |
12 | def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
13 | return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
14 |
15 |
16 | def _str_to_dtype(dtype_str: str) -> torch.dtype | None:
17 | # Allow None for default behavior
18 | if dtype_str is None or dtype_str.lower() == "none":
19 | return None
20 | if dtype_str == "float32":
21 | return torch.float32
22 | elif dtype_str == "float16":
23 | return torch.float16
24 | elif dtype_str == "bfloat16":
25 | return torch.bfloat16
26 | else:
27 | raise ValueError(f"Unsupported dtype string: {dtype_str}")
28 |
29 |
30 | class DenseGeneral(nn.Module):
31 | """
32 | PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
33 |
34 | Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
35 | for the generalized matrix multiplication. Weight/bias shapes are calculated
36 | and parameters created during initialization based on config.
38 |
39 | Attributes:
40 | axis (Tuple[int, ...]): Input axis or axes to contract.
41 | in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
42 | out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
43 |         weight (nn.Parameter): The kernel parameter.
44 |         bias: Always None; this port does not use a bias term.
46 | """
47 |
48 | def __init__(
49 | self,
50 | in_shapes: tuple[int, ...],
51 | out_features: tuple[int, ...],
52 | axis: tuple[int, ...] = (-1,),
53 | dtype: torch.dtype | None = None,
54 | weight_dtype: torch.dtype | None = None,
55 | device: torch.device | None = None,
56 | ):
57 | super().__init__()
58 | self.in_shapes = in_shapes
59 | self.out_features = out_features
60 | self.axis = axis
61 | self.dtype = dtype
62 | self.kernel_shape = self.in_shapes + self.out_features
63 |
64 | factory_kwargs = {"device": device, "dtype": weight_dtype}
65 | self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
66 | self.register_parameter("bias", None)
67 |
68 | def forward(self, inputs: Tensor) -> Tensor:
69 | norm_axis = _normalize_axes(self.axis, inputs.ndim)
70 | kernel_contract_axes = tuple(range(len(norm_axis)))
71 |
72 | output = torch.tensordot(
73 | inputs.float(),
74 | self.weight.float(),
75 | dims=(norm_axis, kernel_contract_axes),
76 | ).to(inputs.dtype)
77 | return output
78 |
79 |
80 | def get_activation_fn(activation_string: str) -> nn.Module: # Return Module instance
81 | """Maps activation string to PyTorch activation function module."""
82 | if activation_string == "gelu":
83 | return nn.GELU()
84 | elif activation_string == "relu":
85 | return nn.ReLU()
86 | elif activation_string == "silu" or activation_string == "swish":
87 | return nn.SiLU()
88 | elif activation_string == "linear":
89 | return nn.Identity()
90 | else:
91 | raise ValueError(f"Unsupported activation function: {activation_string}")
92 |
93 |
94 | class MlpBlock(nn.Module):
95 | """MLP block using DenseGeneral."""
96 |
97 | def __init__(
98 | self,
99 | config: DiaConfig,
100 | embed_dim: int,
101 | intermediate_dim: int,
102 | dropout_rate: float,
103 | activations: list[str] = ["silu", "linear"],
104 | use_pre_norm: bool = False,
105 | ):
106 | super().__init__()
107 | self.use_pre_norm = use_pre_norm
108 | num_activations = len(activations)
109 | compute_dtype = _str_to_dtype(config.training.dtype)
110 | weight_dtype = _str_to_dtype(config.model.weight_dtype)
111 | self.dtype = compute_dtype
112 | # Assume default device for now, could be passed in config
113 |
114 | if use_pre_norm:
115 | self.pre_norm = RMSNorm(
116 | embed_dim,
117 | eps=config.model.normalization_layer_epsilon,
118 | dtype=torch.float32,
119 | )
120 |
121 | self.wi_fused = DenseGeneral(
122 | in_shapes=(embed_dim,),
123 | out_features=(
124 | num_activations,
125 | intermediate_dim,
126 | ),
127 | axis=(-1,),
128 | dtype=compute_dtype,
129 | weight_dtype=weight_dtype,
130 | )
131 |
132 | self.activation_fn_0 = get_activation_fn(activations[0]) # silu
133 | self.activation_fn_1 = get_activation_fn(activations[1]) # linear
134 |
135 | self.dropout = nn.Dropout(dropout_rate)
136 |
137 | # Output layer using DenseGeneral
138 | self.wo = DenseGeneral(
139 | in_shapes=(intermediate_dim,),
140 | out_features=(embed_dim,),
141 | axis=(-1,),
142 | dtype=compute_dtype,
143 | weight_dtype=weight_dtype,
144 | )
145 |
146 | def forward(self, x: torch.Tensor, deterministic: bool) -> torch.Tensor:
147 | """Forward pass."""
148 | if self.use_pre_norm and hasattr(self, "pre_norm"):
149 | x = self.pre_norm(x)
150 |
151 | fused_x = self.wi_fused(x)
152 |
153 | gate_input = fused_x[..., 0, :]
154 | up_input = fused_x[..., 1, :]
155 |
156 | gate = self.activation_fn_0(gate_input)
157 | up = self.activation_fn_1(up_input)
158 | hidden = torch.mul(gate, up).to(self.dtype)
159 |
160 | if not deterministic:
161 | hidden = self.dropout(hidden)
162 |
163 | output = self.wo(hidden)
164 | return output
165 |
166 |
167 | class RotaryEmbedding(nn.Module):
168 | """Rotary Position Embedding (RoPE) implementation in PyTorch."""
169 |
170 | def __init__(
171 | self,
172 | embedding_dims: int,
173 | min_timescale: int = 1,
174 | max_timescale: int = 10000,
175 | dtype: torch.dtype = torch.float32,
176 | ):
177 | super().__init__()
178 | if embedding_dims % 2 != 0:
179 | raise ValueError("Embedding dim must be even for RoPE.")
180 | self.embedding_dims = embedding_dims
181 | self.min_timescale = min_timescale
182 | self.max_timescale = max_timescale
183 | self.dtype = dtype
184 |
185 | half_embedding_dim = embedding_dims // 2
186 | fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
187 | self.register_buffer(
188 | "timescale",
189 | self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
190 | persistent=False,
191 | )
192 |
193 | def extra_repr(self) -> str:
194 | s = f"{self.timescale.shape}"
195 | return s
196 |
197 | def forward(self, inputs: torch.Tensor, position: torch.Tensor):
198 | """Applies RoPE."""
199 | position = position.unsqueeze(-1).unsqueeze(-1)
200 | timescale = self.timescale.to(inputs.device)
201 | sinusoid_inp = position / timescale
202 | sin = torch.sin(sinusoid_inp).to(inputs.dtype)
203 | cos = torch.cos(sinusoid_inp).to(inputs.dtype)
204 | first_half, second_half = torch.chunk(inputs, 2, dim=-1)
205 | first_part = first_half * cos - second_half * sin
206 | second_part = second_half * cos + first_half * sin
207 | return torch.cat((first_part, second_part), dim=-1)
208 |
209 |
210 | class KVCache:
211 | def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
212 | self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k
213 | self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v
214 | self.current_idx = 0
215 | self.max_len = max_len
216 |
217 | def get_kv_for_attention(self, current_k, current_v):
218 | if self.current_idx == 0:
219 | return current_k, current_v
220 | else:
221 | past_k = self.k[:, :, : self.current_idx, :]
222 | past_v = self.v[:, :, : self.current_idx, :]
223 | attn_k = torch.cat((past_k, current_k), dim=2)
224 | attn_v = torch.cat((past_v, current_v), dim=2)
225 | return attn_k, attn_v
226 |
227 | def update_cache(self, k, v):
228 | assert self.current_idx < self.max_len
229 | self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
230 | self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
231 | self.current_idx += 1
232 |
233 | def prefill_kv(self, k, v):
234 | prefill_len = k.shape[2]
235 | assert prefill_len <= self.max_len
236 | self.k[:, :, :prefill_len, :] = k
237 | self.v[:, :, :prefill_len, :] = v
238 | self.current_idx = prefill_len
239 |
240 |
241 | class Attention(nn.Module):
242 | """Attention using DenseGeneral."""
243 |
244 | def __init__(
245 | self,
246 | config: DiaConfig,
247 | q_embed_dim: int,
248 | kv_embed_dim: int,
249 | num_query_heads: int,
250 | num_kv_heads: int,
251 | head_dim: int,
252 | dropout_rate: float,
253 | is_cross_attn: bool = False,
254 | out_embed_dim: int | None = None,
255 | ):
256 | super().__init__()
257 | self.num_query_heads = num_query_heads
258 | self.num_kv_heads = num_kv_heads
259 | self.head_dim = head_dim
260 | self.is_cross_attn = is_cross_attn
261 | self.dropout_rate = dropout_rate
262 | compute_dtype = _str_to_dtype(config.training.dtype)
263 | weight_dtype = _str_to_dtype(config.model.weight_dtype)
264 | self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
265 | self.projected_query_dim = num_query_heads * head_dim
266 | if num_query_heads % num_kv_heads != 0:
267 | raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
268 | self.num_gqa_groups = num_query_heads // num_kv_heads
269 |
270 | # --- Projection Layers using DenseGeneral ---
271 | self.q_proj = DenseGeneral(
272 | in_shapes=(q_embed_dim,),
273 | out_features=(num_query_heads, head_dim),
274 | axis=(-1,),
275 | dtype=compute_dtype,
276 | weight_dtype=weight_dtype,
277 | )
278 | self.k_proj = DenseGeneral(
279 | in_shapes=(kv_embed_dim,),
280 | out_features=(num_kv_heads, head_dim),
281 | axis=(-1,),
282 | dtype=compute_dtype,
283 | weight_dtype=weight_dtype,
284 | )
285 | self.v_proj = DenseGeneral(
286 | in_shapes=(kv_embed_dim,),
287 | out_features=(num_kv_heads, head_dim),
288 | axis=(-1,),
289 | dtype=compute_dtype,
290 | weight_dtype=weight_dtype,
291 | )
292 | self.o_proj = DenseGeneral(
293 | in_shapes=(num_query_heads, head_dim),
294 | out_features=(self.output_dim,),
295 | axis=(-2, -1),
296 | dtype=compute_dtype,
297 | weight_dtype=weight_dtype,
298 | )
299 |
300 | # --- Rotary Embedding ---
301 | self.rotary_emb = RotaryEmbedding(
302 | embedding_dims=self.head_dim,
303 | min_timescale=config.model.rope_min_timescale,
304 | max_timescale=config.model.rope_max_timescale,
305 | dtype=compute_dtype,
306 | )
307 |
308 | def forward(
309 | self,
310 | Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
311 | Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
312 | q_positions: torch.Tensor, # (B, T)
313 | kv_positions: torch.Tensor | None = None, # (B, S)
314 | deterministic: bool = True,
315 | attn_mask: torch.Tensor | None = None, # None in Decoder Self Attention, Valid mask in Others
316 | cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
317 | prefill: bool = False, # True only when prefilling KV Cache
318 | ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
319 | """
320 | Performs attention calculation with optional KV caching.
321 |
322 | Args:
323 | Xq: Query tensor (B, T, D). T=1 during single-step decoding.
324 | Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
325 | q_positions: Positions for queries (B, T).
326 | kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
327 | deterministic: If True, disable dropout.
328 | attn_mask: Attention mask.
329 | cache: KVCache.
330 | prefill: If True, use prefill mode.
331 |
332 | Returns:
333 | A tuple containing:
334 | - output: The attention output tensor (B, T, output_dim).
335 |             - present_kv: For decoder self-attention decode steps, the (K, V) tensors of the current step (each (B, N, S, H)) that the caller writes into the KV cache; None for cross-attention, encoder self-attention, and prefill.
336 | """
337 | if kv_positions is None:
338 | kv_positions = q_positions
339 | original_dtype = Xq.dtype
340 |
341 | Xq_BxTxNxH = self.q_proj(Xq)
342 | Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
343 | Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
344 |
345 | # Input values into attention calculation
346 | attn_k: torch.Tensor | None = None
347 | attn_v: torch.Tensor | None = None
348 | new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None
349 |
350 | # Decoder Cross Attention
351 | if self.is_cross_attn:
352 | # Directly use cache (no need to check index)
353 | attn_k, attn_v = cache.k, cache.v
354 | if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
355 | raise ValueError(
356 | f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
357 | f"does not match num_query_heads ({self.num_query_heads}). "
358 | "Cache should be pre-repeated for GQA."
359 | )
360 | # Self Attention
361 | else:
362 | Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
363 | Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
364 | Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions) # (B, S, K, H)
365 |
366 | Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
367 | Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
368 | # S=1 for Decode Step
369 |
370 | if self.num_gqa_groups > 1:
371 | Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
372 | Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
373 | else:
374 | Xk_BxNxSxH = Xk_BxKxSxH
375 | Xv_BxNxSxH = Xv_BxKxSxH
376 |
377 | # Encoder Self Attention
378 | if cache is None:
379 | attn_k = Xk_BxNxSxH
380 | attn_v = Xv_BxNxSxH
381 | # Decoder Self Attention
382 | else:
383 | # In prefill mode, we fill in cache until prefill length
384 | if prefill:
385 | attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
386 | cache.prefill_kv(attn_k, attn_v)
387 | # In decode step, we add current K/V to cache step by step
388 | else:
389 | new_kv_cache = Xk_BxNxSxH, Xv_BxNxSxH
390 | attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)
391 |
392 | attn_output = F.scaled_dot_product_attention(
393 | Xq_BxNxTxH,
394 | attn_k,
395 | attn_v,
396 | attn_mask=attn_mask,
397 | dropout_p=self.dropout_rate if not deterministic else 0.0,
398 | scale=1.0,
399 | )
400 |
401 | attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
402 | output = self.o_proj(attn_output)
403 |
404 | return output.to(original_dtype), new_kv_cache
405 |
406 |
407 | class EncoderLayer(nn.Module):
408 | """Transformer Encoder Layer using DenseGeneral."""
409 |
410 | def __init__(self, config: DiaConfig):
411 | super().__init__()
412 | self.config = config
413 | model_config = config.model
414 | enc_config = config.model.encoder
415 | embed_dim = enc_config.n_embd
416 |
417 | self.pre_sa_norm = RMSNorm(
418 | embed_dim,
419 | eps=model_config.normalization_layer_epsilon,
420 | dtype=torch.float32,
421 | )
422 | self.self_attention = Attention(
423 | config=config,
424 | q_embed_dim=embed_dim,
425 | kv_embed_dim=embed_dim,
426 | num_query_heads=enc_config.n_head,
427 | num_kv_heads=enc_config.n_head,
428 | head_dim=enc_config.head_dim,
429 | dropout_rate=model_config.dropout,
430 | is_cross_attn=False,
431 | out_embed_dim=embed_dim,
432 | )
433 | self.post_sa_norm = RMSNorm(
434 | embed_dim,
435 | eps=model_config.normalization_layer_epsilon,
436 | dtype=torch.float32,
437 | )
438 | self.mlp = MlpBlock(
439 | config=config,
440 | embed_dim=embed_dim,
441 | intermediate_dim=enc_config.n_hidden,
442 | activations=enc_config.mlp_activations,
443 | dropout_rate=model_config.dropout,
444 | use_pre_norm=enc_config.use_pre_norm,
445 | )
446 | self.dropout = nn.Dropout(model_config.dropout)
447 |
448 | def forward(
449 | self,
450 | x: torch.Tensor,
451 | src_positions: torch.Tensor | None = None,
452 | deterministic: bool = True,
453 | attn_mask: torch.Tensor | None = None,
454 | ) -> torch.Tensor:
455 | residual = x
456 | x_norm = self.pre_sa_norm(x)
457 |
458 | sa_out, _ = self.self_attention(
459 | Xq=x_norm,
460 | Xkv=x_norm,
461 | q_positions=src_positions,
462 | kv_positions=src_positions,
463 | deterministic=deterministic,
464 | attn_mask=attn_mask,
465 | )
466 | x = residual + sa_out
467 |
468 | residual = x
469 | x_norm = self.post_sa_norm(x)
470 | mlp_out = self.mlp(x_norm, deterministic=deterministic)
471 | x = residual + mlp_out
472 |
473 | if not deterministic:
474 | x = self.dropout(x)
475 | return x
476 |
477 |
478 | class Encoder(nn.Module):
479 | """Transformer Encoder Stack using DenseGeneral."""
480 |
481 | def __init__(self, config: DiaConfig):
482 | super().__init__()
483 | self.config = config
484 | model_config = config.model
485 | enc_config = config.model.encoder
486 | compute_dtype = _str_to_dtype(config.training.dtype)
487 |
488 | self.embedding = nn.Embedding(
489 | model_config.src_vocab_size,
490 | enc_config.n_embd,
491 | dtype=compute_dtype,
492 | )
493 | self.dropout = nn.Dropout(model_config.dropout)
494 | self.layers = nn.ModuleList([EncoderLayer(config=config) for _ in range(enc_config.n_layer)])
495 | self.norm = RMSNorm(
496 | enc_config.n_embd,
497 | eps=model_config.normalization_layer_epsilon,
498 | dtype=torch.float32,
499 | )
500 |
501 | def forward(
502 | self,
503 | x_ids: torch.Tensor,
504 | src_positions: torch.Tensor | None = None,
505 | deterministic: bool = True,
506 | attn_mask: torch.Tensor | None = None,
507 | ) -> torch.Tensor:
508 | x = self.embedding(x_ids)
509 |
510 | if not deterministic:
511 | x = self.dropout(x)
512 |
513 | for layer in self.layers:
514 | x = layer(
515 | x,
516 | src_positions=src_positions,
517 | deterministic=deterministic,
518 | attn_mask=attn_mask,
519 | )
520 | x = self.norm(x)
521 | if not deterministic:
522 | x = self.dropout(x)
523 | return x
524 |
525 |
526 | class DecoderLayer(nn.Module):
527 | """Transformer Decoder Layer using DenseGeneral."""
528 |
529 | def __init__(self, config: DiaConfig):
530 | super().__init__()
531 | self.config = config
532 | model_config = config.model
533 | dec_config = config.model.decoder
534 | enc_config = config.model.encoder
535 | dec_embed_dim = dec_config.n_embd
536 | enc_embed_dim = enc_config.n_embd
537 |
538 | # Norms
539 | self.pre_sa_norm = RMSNorm(
540 | dec_embed_dim,
541 | eps=model_config.normalization_layer_epsilon,
542 | dtype=torch.float32,
543 | )
544 | self.pre_ca_norm = RMSNorm(
545 | dec_embed_dim,
546 | eps=model_config.normalization_layer_epsilon,
547 | dtype=torch.float32,
548 | )
549 | self.pre_mlp_norm = RMSNorm(
550 | dec_embed_dim,
551 | eps=model_config.normalization_layer_epsilon,
552 | dtype=torch.float32,
553 | )
554 |
555 | # Self-Attention (GQA) with Causal Masking
556 | self.self_attention = Attention(
557 | config=config,
558 | q_embed_dim=dec_embed_dim,
559 | kv_embed_dim=dec_embed_dim,
560 | num_query_heads=dec_config.gqa_query_heads,
561 | num_kv_heads=dec_config.kv_heads,
562 | head_dim=dec_config.gqa_head_dim,
563 | dropout_rate=model_config.dropout,
564 | is_cross_attn=False,
565 | out_embed_dim=dec_embed_dim,
566 | )
567 | # Cross-Attention (MHA)
568 | self.cross_attention = Attention(
569 | config=config,
570 | q_embed_dim=dec_embed_dim,
571 | kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
572 | num_query_heads=dec_config.cross_query_heads,
573 | num_kv_heads=dec_config.cross_query_heads,
574 | head_dim=dec_config.cross_head_dim,
575 | dropout_rate=model_config.dropout,
576 | is_cross_attn=True,
577 | out_embed_dim=dec_embed_dim,
578 | )
579 | # MLP
580 | self.mlp = MlpBlock(
581 | config=config,
582 | embed_dim=dec_embed_dim,
583 | intermediate_dim=dec_config.n_hidden,
584 | activations=dec_config.mlp_activations,
585 | dropout_rate=model_config.dropout,
586 | use_pre_norm=dec_config.use_pre_norm,
587 | )
588 |
589 | def forward(
590 | self,
591 | x: torch.Tensor,
592 | encoder_out: torch.Tensor,
593 | tgt_positions: torch.Tensor,
594 | src_positions: torch.Tensor | None,
595 | deterministic: bool,
596 | self_attn_mask: torch.Tensor,
597 | cross_attn_mask: torch.Tensor,
598 | self_attn_cache: KVCache,
599 | cross_attn_cache: KVCache,
600 | prefill: bool = False,
601 |     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
602 | residual = x
603 | x_norm = self.pre_sa_norm(x)
604 |
605 | sa_out, new_kv_cache = self.self_attention(
606 | Xq=x_norm, # (2, 1, D)
607 | Xkv=x_norm, # (2, 1, D)
608 | q_positions=tgt_positions, # (2, 1)
609 | kv_positions=tgt_positions, # (2, 1)
610 | deterministic=deterministic,
611 | attn_mask=self_attn_mask, # (2, 1, 1, S_max)
612 | cache=self_attn_cache,
613 | prefill=prefill,
614 | )
615 |
616 | x = residual + sa_out
617 |
618 | # 2. Cross-Attention
619 | residual = x
620 | x_norm = self.pre_ca_norm(x)
621 | ca_out, _ = self.cross_attention(
622 | Xq=x_norm,
623 | Xkv=encoder_out,
624 | q_positions=tgt_positions,
625 | kv_positions=src_positions,
626 | deterministic=deterministic,
627 | attn_mask=cross_attn_mask,
628 | cache=cross_attn_cache,
629 | )
630 | x = residual + ca_out
631 |
632 | # 3. MLP
633 | residual = x
634 | x_norm = self.pre_mlp_norm(x)
635 | mlp_out = self.mlp(x_norm, deterministic=deterministic)
636 | x = residual + mlp_out
637 |
638 | return x, new_kv_cache
639 |
640 |
641 | class Decoder(nn.Module):
642 | """Transformer Decoder Stack using DenseGeneral."""
643 |
644 | def __init__(self, config: DiaConfig):
645 | super().__init__()
646 | self.config = config
647 | model_config = config.model
648 | dec_config = config.model.decoder
649 | train_config = config.training
650 | data_config = config.data
651 | compute_dtype = _str_to_dtype(config.training.dtype)
652 | weight_dtype = _str_to_dtype(config.model.weight_dtype)
653 | self.num_channels = data_config.channels
654 | self.num_layers = dec_config.n_layer
655 |
656 | self.embeddings = nn.ModuleList(
657 | [
658 | nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype)
659 | for _ in range(self.num_channels)
660 | ]
661 | )
662 | self.dropout = nn.Dropout(model_config.dropout)
663 | self.layers = nn.ModuleList([DecoderLayer(config=config) for _ in range(self.num_layers)])
664 | self.norm = RMSNorm(
665 | dec_config.n_embd,
666 | eps=model_config.normalization_layer_epsilon,
667 | dtype=torch.float32,
668 | )
669 |
670 | # Final Logits Projection using DenseGeneral
671 | self.logits_dense = DenseGeneral(
672 | in_shapes=(dec_config.n_embd,),
673 | out_features=(self.num_channels, model_config.tgt_vocab_size),
674 | axis=(-1,),
675 | dtype=(torch.float32 if train_config.logits_dot_in_fp32 else compute_dtype),
676 | weight_dtype=weight_dtype,
677 | )
678 | self.logits_in_fp32 = train_config.logits_dot_in_fp32
679 |
680 | def precompute_cross_attention_kv(
681 | self,
682 | max_len: int,
683 | encoder_out: torch.Tensor, # (B, S, E)
684 | src_positions: torch.Tensor | None, # (B, S)
685 | ) -> list[KVCache]:
686 | """
687 | Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
688 | """
689 | per_layer_kv_cache: list[KVCache] = []
690 |
691 | for layer in self.layers:
692 | cross_attn_module = layer.cross_attention
693 | k_proj = cross_attn_module.k_proj(encoder_out)
694 | v_proj = cross_attn_module.v_proj(encoder_out)
695 |
696 | k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
697 | k = k_proj.transpose(1, 2)
698 | v = v_proj.transpose(1, 2)
699 |
700 | per_layer_kv_cache.append(
701 | KVCache(
702 | cross_attn_module.num_kv_heads,
703 | max_len,
704 | cross_attn_module.head_dim,
705 | k.device,
706 | k=k,
707 | v=v,
708 | )
709 | )
710 |
711 | return per_layer_kv_cache
712 |
713 | def decode_step(
714 | self,
715 | tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
716 | tgt_pos_Bx1: torch.Tensor, # [B, 1]
717 | encoder_out: torch.Tensor, # [B, S, E]
718 | self_attn_mask: Any, # None
719 | cross_attn_mask: torch.Tensor, # [B, 1, 1, S]
720 | self_attention_cache: list[KVCache],
721 | cross_attention_cache: list[KVCache],
722 |     ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
723 | """
724 | Performs a single decoding step, managing KV caches layer by layer.
725 |
726 | Returns:
727 |             - logits_Bx1xCxV: The output logits for the current step (B, 1, C, V), cast to float32.
728 |             - new_cache: The per-layer self-attention K/V from this step, for the caller to write into the caches.
729 | """
730 |         assert self_attn_mask is None, "Self-attention mask must be None in decode_step; causality is enforced by the KV cache"
731 |
732 | x = None
733 | for i in range(self.num_channels):
734 | channel_tokens = tgt_ids_Bx1xC[..., i]
735 | channel_embed = self.embeddings[i](channel_tokens)
736 | x = channel_embed if x is None else x + channel_embed
737 |
738 | new_cache = []
739 |
740 | for i, layer in enumerate(self.layers):
741 | self_cache = self_attention_cache[i]
742 | cross_cache = cross_attention_cache[i]
743 | x, new_kv_cache = layer(
744 | x, # (2, 1, D)
745 | encoder_out, # (2, S, E)
746 | src_positions=None, # CA KV is already computed
747 | tgt_positions=tgt_pos_Bx1, # (2, 1)
748 | deterministic=True,
749 | self_attn_mask=None,
750 | cross_attn_mask=cross_attn_mask,
751 | self_attn_cache=self_cache,
752 | cross_attn_cache=cross_cache,
753 | )
754 | new_cache.append(new_kv_cache)
755 |
756 | x = self.norm(x)
757 | logits_Bx1xCxV = self.logits_dense(x)
758 |
759 | return logits_Bx1xCxV.to(torch.float32), new_cache
760 |
761 | def forward(
762 | self,
763 | tgt_ids_BxTxC: torch.Tensor,
764 | encoder_out: torch.Tensor,
765 | tgt_positions: torch.Tensor,
766 | src_positions: torch.Tensor,
767 | deterministic: bool,
768 | self_attn_mask: torch.Tensor,
769 | cross_attn_mask: torch.Tensor,
770 | self_attention_cache: list[KVCache],
771 | cross_attention_cache: list[KVCache],
772 | ) -> torch.Tensor:
773 | """
774 | Forward pass for the Decoder stack, managing KV caches.
775 |
776 | Args:
777 | tgt_ids_BxTxC: Target token IDs (B, T, C).
778 | encoder_out: Output from the encoder (B, S, E).
779 | tgt_positions: Positions for target sequence (B, T).
780 | src_positions: Positions for source sequence (B, S).
781 | deterministic: Disable dropout if True.
782 | self_attn_mask: Mask for self-attention.
783 | cross_attn_mask: Mask for cross-attention.
784 |             self_attention_cache: Per-layer self-attention KVCache objects. During this
785 |                 forward pass each layer's cache is prefilled with the K/V of the
786 |                 full target sequence; `len(self_attention_cache)` should equal `num_layers`.
787 |             cross_attention_cache: Per-layer cross-attention KVCache objects holding the
788 |                 K/V precomputed from `encoder_out` (see `precompute_cross_attention_kv`).
789 |                 One cache is passed to each corresponding layer.
790 |
791 |         Returns:
792 |             logits: The final output logits (B, T, C, V), cast to float32.
793 |                 The per-layer self-attention caches are updated in place
794 |                 (prefilled with the full target sequence), so no per-step
795 |                 K/V state is returned from this method.
796 | """
797 | _, _, num_channels_in = tgt_ids_BxTxC.shape
798 | assert num_channels_in == self.num_channels, "Input channels mismatch"
799 |
800 | # Embeddings
801 | x = None
802 | for i in range(self.num_channels):
803 | channel_tokens = tgt_ids_BxTxC[..., i]
804 | channel_embed = self.embeddings[i](channel_tokens)
805 | x = channel_embed if x is None else x + channel_embed
806 |
807 | if not deterministic:
808 | x = self.dropout(x)
809 |
810 | for i, layer in enumerate(self.layers):
811 | x, _ = layer(
812 | x,
813 | encoder_out,
814 | tgt_positions=tgt_positions,
815 | src_positions=src_positions,
816 | deterministic=deterministic,
817 | self_attn_mask=self_attn_mask,
818 | cross_attn_mask=cross_attn_mask,
819 | self_attn_cache=self_attention_cache[i],
820 | cross_attn_cache=cross_attention_cache[i],
821 | prefill=True,
822 | )
823 |
824 | # Final Norm
825 | x = self.norm(x)
826 | logits_BxTxCxV = self.logits_dense(x)
827 |
828 | return logits_BxTxCxV.to(torch.float32)
829 |
830 |
831 | class DiaModel(nn.Module):
832 | """PyTorch Dia Model using DenseGeneral."""
833 |
834 | def __init__(self, config: DiaConfig):
835 | super().__init__()
836 | self.config = config
837 | self.encoder = Encoder(config)
838 | self.decoder = Decoder(config)
839 |         # self._init_weights()
840 |
841 |
842 | def _init_weights(self):
843 | for module in self.modules():
844 | if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d)):
845 | torch.nn.init.xavier_uniform_(module.weight)
846 | if module.bias is not None:
847 | torch.nn.init.zeros_(module.bias)
848 | elif isinstance(module, torch.nn.Embedding):
849 | torch.nn.init.xavier_uniform_(module.weight)
850 |             elif isinstance(module, (torch.nn.LayerNorm, torch.nn.modules.normalization.RMSNorm)):
851 |                 if hasattr(module, "weight") and module.weight is not None:
852 |                     torch.nn.init.ones_(module.weight)
853 |                 if hasattr(module, "bias") and module.bias is not None:
854 |                     torch.nn.init.zeros_(module.bias)
855 |
856 | def forward(
857 | self,
858 | src_BxS: torch.Tensor,
859 | tgt_BxTxC: torch.Tensor,
860 | src_positions: torch.Tensor | None = None,
861 | tgt_positions: torch.Tensor | None = None,
862 | enc_self_attn_mask: torch.Tensor | None = None,
863 | dec_self_attn_mask: torch.Tensor | None = None,
864 | dec_cross_attn_mask: torch.Tensor | None = None,
865 | enable_dropout: bool = True,
866 | ):
867 | deterministic = not enable_dropout
868 |
869 | # --- Encoder Pass ---
870 | encoder_out = self.encoder(
871 | x_ids=src_BxS,
872 | src_positions=src_positions,
873 | deterministic=deterministic,
874 | attn_mask=enc_self_attn_mask,
875 | )
876 |
877 | B, T, C = tgt_BxTxC.shape # Batch size, target sequence length, channels
878 | device = tgt_BxTxC.device
879 |
880 | self_attention_cache = [
881 | KVCache(
882 |                 num_heads=self.decoder.layers[i].self_attention.num_query_heads,  # use query heads: the cache stores GQA-expanded K/V
883 | max_len=T,
884 | head_dim=self.decoder.layers[i].self_attention.head_dim,
885 | device=device,
886 | )
887 | for i in range(self.decoder.num_layers)
888 | ]
889 |
890 | cross_attention_cache = self.decoder.precompute_cross_attention_kv(
891 | max_len=encoder_out.shape[1],
892 | encoder_out=encoder_out,
893 | src_positions=src_positions,
894 | )
895 |
896 | # --- Decoder Pass ---
897 | logits = self.decoder(
898 | tgt_ids_BxTxC=tgt_BxTxC,
899 | encoder_out=encoder_out,
900 | tgt_positions=tgt_positions,
901 | src_positions=src_positions,
902 | deterministic=deterministic,
903 | self_attn_mask=dec_self_attn_mask,
904 | cross_attn_mask=dec_cross_attn_mask,
905 | self_attention_cache=self_attention_cache,
906 | cross_attention_cache=cross_attention_cache
907 | )
908 |
909 | return logits
910 |
--------------------------------------------------------------------------------
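A standalone sketch of the core projection used throughout dia/layers.py: the torch.tensordot contraction in DenseGeneral.forward is equivalent to a bias-free linear layer applied over the flattened contraction axes. All shapes below are made up for illustration and nothing here is part of the package.

import torch

B, T, D, N, H = 2, 6, 16, 4, 8  # hypothetical batch/seq/embed/head sizes
x = torch.randn(B, T, D)        # input, contracted over its last axis
w = torch.randn(D, N, H)        # kernel_shape = in_shapes + out_features

# Contraction as in DenseGeneral.forward: sum over x's axis 2 and w's axis 0.
out_tensordot = torch.tensordot(x, w, dims=([2], [0]))  # (B, T, N, H)

# The same result via an ordinary matrix multiply with a flattened kernel.
out_matmul = (x @ w.reshape(D, N * H)).reshape(B, T, N, H)

assert torch.allclose(out_tensordot, out_matmul, atol=1e-4)
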
/dia/model.py:
--------------------------------------------------------------------------------
1 | import dac
2 | import numpy as np
3 | import torch
4 | import torchaudio
5 | from huggingface_hub import hf_hub_download
6 |
7 | from .audio import audio_to_codebook, codebook_to_audio
8 | from .config import DiaConfig
9 | from .layers import DiaModel, KVCache
10 |
11 |
12 | def get_default_device():
13 | if torch.cuda.is_available():
14 | return torch.device("cuda")
15 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
16 | return torch.device("mps")
17 | return torch.device("cpu")
18 |
19 |
20 | def _sample_next_token(
21 | logits_BCxV: torch.Tensor,
22 | temperature: float,
23 | top_p: float,
24 | use_cfg_filter: bool,
25 | cfg_filter_top_k: int | None = None,
26 | ) -> torch.Tensor:
27 | if temperature == 0.0:
28 | return torch.argmax(logits_BCxV, dim=-1)
29 |
30 | logits_BCxV = logits_BCxV / temperature
31 | if use_cfg_filter and cfg_filter_top_k is not None:
32 | _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
33 | mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
34 | mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
35 | logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
36 |
37 | if top_p < 1.0:
38 | probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
39 | sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
40 | cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
41 |
42 | # Calculate indices to remove based on top_p
43 | sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
44 | # Shift the mask to the right to keep the first token above the threshold
45 | sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[..., :-1].clone()
46 | sorted_indices_to_remove_BCxV[..., 0] = 0 # Always keep the most probable token
47 |
48 | indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
49 | indices_to_remove_BCxV.scatter_(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
50 | logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
51 |
52 | final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
53 |
54 | sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
55 | sampled_indices_C = sampled_indices_BC.squeeze(-1)
56 | return sampled_indices_C
57 |
58 |
59 | class Dia:
60 | def __init__(self, config: DiaConfig, device: torch.device | None = None):
61 | """Initializes the Dia model.
62 |
63 | Args:
64 | config: The configuration object for the model.
65 | device: The device to load the model onto. If None, will automatically select the best available device.
66 |
67 | Raises:
68 | RuntimeError: If there is an error loading the DAC model.
69 | """
70 | super().__init__()
71 | self.config = config
72 | self.device = device if device is not None else get_default_device()
73 | self.model = DiaModel(config)
74 | self.dac_model = None
75 |
76 | @classmethod
77 | def from_local(cls, config_path: str, checkpoint_path: str, device: torch.device | None = None) -> "Dia":
78 | """Loads the Dia model from local configuration and checkpoint files.
79 |
80 | Args:
81 | config_path: Path to the configuration JSON file.
82 | checkpoint_path: Path to the model checkpoint (.pth) file.
83 | device: The device to load the model onto. If None, will automatically select the best available device.
84 |
85 | Returns:
86 | An instance of the Dia model loaded with weights and set to eval mode.
87 |
88 | Raises:
89 | FileNotFoundError: If the config or checkpoint file is not found.
90 | RuntimeError: If there is an error loading the checkpoint.
91 | """
92 | config = DiaConfig.load(config_path)
93 | if config is None:
94 | raise FileNotFoundError(f"Config file not found at {config_path}")
95 |
96 | dia = cls(config, device)
97 |
98 | try:
99 | state_dict = torch.load(checkpoint_path, map_location=dia.device)
100 | dia.model.load_state_dict(state_dict)
101 | except FileNotFoundError:
102 | raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
103 | except Exception as e:
104 | raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e
105 |
106 | dia.model.to(dia.device)
107 | dia.model.eval()
108 | dia._load_dac_model()
109 | return dia
110 |
111 | @classmethod
112 | def from_pretrained(
113 | cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device | None = None
114 | ) -> "Dia":
115 | """Loads the Dia model from a Hugging Face Hub repository.
116 |
117 | Downloads the configuration and checkpoint files from the specified
118 | repository ID and then loads the model.
119 |
120 | Args:
121 |             model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
122 | device: The device to load the model onto. If None, will automatically select the best available device.
123 |
124 | Returns:
125 | An instance of the Dia model loaded with weights and set to eval mode.
126 |
127 | Raises:
128 | FileNotFoundError: If config or checkpoint download/loading fails.
129 | RuntimeError: If there is an error loading the checkpoint.
130 | """
131 | config_path = hf_hub_download(repo_id=model_name, filename="config.json")
132 | checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
133 | return cls.from_local(config_path, checkpoint_path, device)
134 |
135 | def _load_dac_model(self):
136 | try:
137 | dac_model_path = dac.utils.download()
138 | dac_model = dac.DAC.load(dac_model_path).to(self.device)
139 | except Exception as e:
140 | raise RuntimeError("Failed to load DAC model") from e
141 | self.dac_model = dac_model
142 |
143 | def _create_attn_mask(
144 | self,
145 | q_padding_mask_1d: torch.Tensor,
146 | k_padding_mask_1d: torch.Tensor,
147 | is_causal: bool = False,
148 | ) -> torch.Tensor:
149 | """
150 | Creates the attention mask (self or cross) mimicking JAX segment ID logic.
151 | """
152 | B1, Tq = q_padding_mask_1d.shape
153 | B2, Tk = k_padding_mask_1d.shape
154 | assert B1 == B2, "Query and key batch dimensions must match"
155 |
156 | p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
157 | p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
158 |
159 | # Condition A: Non-padding query attends to non-padding key
160 | non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
161 |
162 | # Condition B: Padding query attends to padding key
163 | pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
164 |
165 | # Combine: True if padding status is compatible (both non-pad OR both pad)
166 |         # This implementation follows the JAX TPU splash-attention kernel
167 | mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
168 |
169 | if is_causal:
170 | # Ensure causality for self-attention (Tq == Tk)
171 | assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
172 | # Standard lower-triangular causal mask (True means allow)
173 | causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=self.device)) # Shape [Tq, Tk]
174 | causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
175 | return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] for broadcasting across heads
176 | else:
177 | # For cross-attention or non-causal self-attention
178 | return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] for broadcasting across heads
179 |
180 | def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
181 | """Encodes text prompt, pads, and creates attention mask and positions."""
182 | text_pad_value = self.config.data.text_pad_value
183 | max_len = self.config.data.text_length
184 |
185 | byte_text = text.encode("utf-8")
186 |
187 |         # Replace language tags (e.g. "[de]") with single reserved byte values
188 |         replaced_bytes = byte_text
189 |
190 | LANG2BYTE = {
191 | "en": 3,
192 | "de": 4,
193 | "fr": 5,
194 | "es": 6,
195 | "it": 7,
196 | "nl": 14,
197 | "pl": 15,
198 | "pt": 16,
199 | "tr": 17,
200 | "hu": 18,
201 | }
202 |
203 | for lang, byte_val in LANG2BYTE.items():
204 | tag = f"[{lang}]".encode("ascii") # e.g. b"[de]"
205 | code = bytes([byte_val]) # e.g. b"\x04"
206 | replaced_bytes = replaced_bytes.replace(tag, code)
207 | text_tokens = list(replaced_bytes)
208 |
209 | current_len = len(text_tokens)
210 | padding_needed = max_len - current_len
211 | if padding_needed <= 0:
212 | text_tokens = text_tokens[:max_len]
213 | padded_text_np = np.array(text_tokens, dtype=np.uint8)
214 | else:
215 | padded_text_np = np.pad(
216 | text_tokens,
217 | (0, padding_needed),
218 | mode="constant",
219 | constant_values=text_pad_value,
220 | ).astype(np.uint8)
221 |
222 | src_tokens = torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0) # [1, S]
223 | src_positions = torch.arange(max_len, device=self.device).to(torch.long).unsqueeze(0) # [1, S]
224 |
225 | src_padding_mask = (src_tokens != text_pad_value).to(self.device) # [1, S]
226 |
227 |         enc_self_attn_mask = self._create_attn_mask(src_padding_mask, src_padding_mask, is_causal=False)  # [1, 1, S, S]
228 |
229 | return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask
230 |
231 | @torch.inference_mode()
232 | def generate(
233 | self,
234 | text: str,
235 | max_tokens: int | None = None,
236 | cfg_scale: float = 3.0,
237 | temperature: float = 1.3,
238 | top_p: float = 0.95,
239 | use_cfg_filter: bool = True,
240 | use_torch_compile: bool = False,
241 | cfg_filter_top_k: int = 35,
242 | audio_prompt_path: str | None = None,
243 | ) -> np.ndarray:
244 | """
245 | Generates audio from a text prompt (and optional audio prompt) using the Nari model.
246 |
247 | Returns:
248 |             A NumPy array containing the generated audio waveform (44.1 kHz sample rate).
249 | """
250 | num_channels = self.config.data.channels
251 | audio_bos_value = self.config.data.audio_bos_value
252 | audio_eos_value = self.config.data.audio_eos_value
253 | audio_pad_value = self.config.data.audio_pad_value
254 | delay_pattern = self.config.data.delay_pattern
255 | max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
256 | delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device)
257 | max_delay_pattern = max(delay_pattern)
258 | self.model.eval()
259 |
260 | (
261 | cond_src_BxS,
262 | cond_src_positions_BxS,
263 | cond_src_padding_mask_BxS,
264 | cond_enc_self_attn_mask_Bx1xSxS,
265 | ) = self._prepare_text_input(text)
266 |
267 | unc_src_BxS = torch.zeros_like(cond_src_BxS)
268 | src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
269 | src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
270 | src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
271 | enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1)
272 |
273 | # 2. Encoder Pass
274 | # with torch.autocast(device_type="cuda", dtype=forward_dtype):
275 | encoder_out = self.model.encoder(
276 | x_ids=src_BxS,
277 | src_positions=src_positions_BxS,
278 | deterministic=True,
279 | attn_mask=enc_self_attn_mask_Bx1xSxS,
280 | ) # Shape: (B, S, E)
281 |
282 | # 3. Prepare Decoder Inputs
283 | # 3-1. Allocate KV Cache (Static)
284 | decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv(
285 | max_tokens, encoder_out, src_positions_BxS
286 | )
287 |
288 | decoder_self_attention_cache: list[KVCache] = []
289 | for _ in range(self.model.decoder.num_layers):
290 | decoder_self_attention_cache.append(
291 | KVCache(
292 | self.config.model.decoder.gqa_query_heads,
293 | max_tokens,
294 | self.config.model.decoder.gqa_head_dim,
295 | self.device,
296 | )
297 | )
298 |
299 | # 3-2. Initialize Decoder Inputs
300 | generated_BxTxC = torch.full(
301 | (2, 1, num_channels),
302 | fill_value=audio_bos_value,
303 | dtype=torch.long,
304 | device=self.device,
305 | )
306 |
307 | current_step = 0
308 | prompt_len_inc_bos = 1 # Start with BOS length
309 |
310 | # 3-3. Load Audio Prompt (if provided)
311 | if audio_prompt_path is not None:
312 | audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True) # C, T
313 | if sr != 44100: # Resample to 44.1kHz
314 | audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100)
315 | audio_prompt = audio_prompt.to(self.device).unsqueeze(0) # 1, C, T
316 | audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data)
317 | generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1)
318 |
319 | prefill_len = generated_BxTxC.shape[1]
320 | prompt_len_inc_bos = prefill_len
321 | prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1)
322 | prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2)
323 |
324 | prefill_self_attn_mask = self._create_attn_mask(
325 | prefill_tgt_padding_mask,
326 | prefill_tgt_padding_mask,
327 | is_causal=True,
328 | )
329 | prefill_cross_attn_mask = self._create_attn_mask(
330 | prefill_tgt_padding_mask,
331 | src_padding_mask_BxS,
332 | is_causal=False,
333 | )
334 |
335 | _ = self.model.decoder.forward(
336 | tgt_ids_BxTxC=generated_BxTxC,
337 | encoder_out=encoder_out,
338 | tgt_positions=prefill_tgt_pos,
339 | src_positions=src_positions_BxS,
340 | deterministic=True,
341 | self_attn_mask=prefill_self_attn_mask,
342 | cross_attn_mask=prefill_cross_attn_mask,
343 | self_attention_cache=decoder_self_attention_cache,
344 | cross_attention_cache=decoder_cross_attention_cache,
345 | )
346 |
347 | current_step = prefill_len - 1
348 |
349 | # 4. Autoregressive Generation Loop
350 | eos_detected_channel_0 = False
351 | eos_countdown = -1
352 | extra_steps_after_eos = 30
353 | # Make generated_BxTxC a fixed size tensor
354 | # Length is either 1 + max tokens or 1 + prompt len + max tokens
355 | generated_BxTxC = torch.cat(
356 | [
357 | generated_BxTxC,
358 | torch.full(
359 | (2, max_tokens, num_channels),
360 | fill_value=-1,
361 | dtype=torch.long,
362 | device=self.device,
363 | ),
364 | ],
365 | dim=1,
366 | )
367 |
368 | decode_step = self.model.decoder.decode_step
369 | if use_torch_compile:
370 | decode_step = torch.compile(
371 | self.model.decoder.decode_step,
372 | mode="default",
373 | )
374 |
375 | tgt_padding_mask = (
376 | (generated_BxTxC[:, -1, :].unsqueeze(1) != audio_pad_value).any(dim=2).to(self.device)
377 | ) # [B, 1]
378 | # Generated tokens are never PAD, so we use fixed mask
379 | decoder_cross_attn_mask = self._create_attn_mask(
380 | tgt_padding_mask, # Query mask [B, 1]
381 | src_padding_mask_BxS, # Key mask [B, S]
382 | is_causal=False,
383 | ) # [B, 1, 1, S]
384 |
385 | for step in range(current_step, current_step + max_tokens):
386 | tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1)
387 | tgt_pos_Bx1 = torch.full(
388 | (2, 1),
389 | fill_value=step,
390 | dtype=torch.long,
391 | device=self.device,
392 | )
393 |
394 | logits_Bx1xCxV, new_cache = decode_step(
395 | tgt_ids_Bx1xC=tgt_ids_Bx1xC,
396 | tgt_pos_Bx1=tgt_pos_Bx1,
397 | encoder_out=encoder_out,
398 | self_attn_mask=None,
399 | cross_attn_mask=decoder_cross_attn_mask,
400 | self_attention_cache=decoder_self_attention_cache,
401 | cross_attention_cache=decoder_cross_attention_cache,
402 | )
403 |
404 | for i, layer_cache in enumerate(decoder_self_attention_cache):
405 | layer_cache.update_cache(new_cache[i][0], new_cache[i][1])
406 |
407 | V = self.config.model.tgt_vocab_size
408 | logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :] # B, C, V
409 | uncond_logits_CxV = logits_last_BxCxV[0, :, :]
410 | cond_logits_CxV = logits_last_BxCxV[1, :, :]
411 |
412 | cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
413 |
414 | logits_CxV = cfg_logits_CxV.reshape((-1, V)) # C, V
415 |         logits_CxV[:, 1025:] = -torch.inf  # restrict sampling to indices 0-1024 (audio codebook + EOS)
416 |
417 | # Sample next token
418 | pred_C = _sample_next_token(
419 | logits_CxV.float(),
420 | temperature=temperature,
421 | top_p=top_p,
422 | use_cfg_filter=use_cfg_filter,
423 | cfg_filter_top_k=cfg_filter_top_k,
424 | )
425 |
426 | generation_step_index = step - current_step
427 | if audio_prompt_path is None:
428 | pred_C = torch.where(
429 | generation_step_index >= delay_tensor,
430 | pred_C,
431 | audio_bos_value,
432 | )
433 |
434 | generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1)
435 |
436 | if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
437 | eos_detected_channel_0 = True
438 | eos_countdown = extra_steps_after_eos
439 |
440 | if eos_countdown > 0:
441 | step_after_eos = max_delay_pattern - eos_countdown
442 | for i, d in enumerate(delay_pattern):
443 | if step_after_eos == d:
444 | generated_BxTxC[:, step + 1, i] = audio_eos_value
445 | elif step_after_eos > d:
446 | generated_BxTxC[:, step + 1, i] = audio_pad_value
447 | eos_countdown -= 1
448 | if eos_countdown == 0:
449 | break
450 |
451 | generation_step_index = step - current_step + 1
452 |
453 | output_codes = generated_BxTxC[:, prompt_len_inc_bos : step + 1, :]
454 |
455 | generated_codes = output_codes[0]
456 |
457 | audio = codebook_to_audio(
458 | generated_codes.transpose(1, 0), self.dac_model, delay_pattern, B=1, T=max_tokens, C=num_channels
459 | )
460 | return audio.squeeze().cpu().numpy()
461 |
--------------------------------------------------------------------------------
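A standalone sketch of the two sampling-time steps implemented in generate() and _sample_next_token above: the classifier-free guidance mix followed by nucleus (top-p) filtering. The logits and sizes below are random and hypothetical, not produced by the model.

import torch

C, V = 9, 1028                  # hypothetical channels x vocab size
uncond_CxV = torch.randn(C, V)  # logits from the unconditional batch element
cond_CxV = torch.randn(C, V)    # logits from the conditional batch element
cfg_scale, top_p = 3.0, 0.95

# Classifier-free guidance: push the conditional logits away from the unconditional ones.
guided_CxV = cond_CxV + cfg_scale * (cond_CxV - uncond_CxV)

# Nucleus (top-p) filtering, mirroring _sample_next_token.
probs_CxV = torch.softmax(guided_CxV, dim=-1)
sorted_probs, sorted_idx = torch.sort(probs_CxV, dim=-1, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
to_remove = cumulative > top_p
to_remove[..., 1:] = to_remove[..., :-1].clone()  # shift right so the first token over the threshold is kept
to_remove[..., 0] = False                         # always keep the most probable token
mask_CxV = torch.zeros_like(to_remove).scatter_(dim=-1, index=sorted_idx, src=to_remove)
guided_CxV = guided_CxV.masked_fill(mask_CxV, -torch.inf)

# Sample one token per channel from the filtered distribution.
next_token_C = torch.multinomial(torch.softmax(guided_CxV, dim=-1), num_samples=1).squeeze(-1)
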
/dia/static/images/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stlohrey/dia-finetuning/25501f2911a20b7211a42640f391a07d562fd2f6/dia/static/images/banner.png
--------------------------------------------------------------------------------
/example/simple.py:
--------------------------------------------------------------------------------
1 | import soundfile as sf
2 |
3 | from dia.model import Dia
4 |
5 |
6 | model = Dia.from_pretrained("nari-labs/Dia-1.6B")
7 |
8 | text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
9 |
10 | output = model.generate(text)
11 |
12 | sf.write("simple.mp3", output, 44100)
13 |
--------------------------------------------------------------------------------
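The script above relies on generate()'s default sampling settings. The sketch below, assuming the same model and environment, passes the sampling controls exposed by the generate() signature in dia/model.py; the max_tokens value and the output filename are illustrative, and the other values shown are simply the signature defaults.

import soundfile as sf

from dia.model import Dia


model = Dia.from_pretrained("nari-labs/Dia-1.6B")

text = "[S1] Dia is an open weights text to dialogue model. [S2] Try it now on Git hub or Hugging Face."

output = model.generate(
    text,
    max_tokens=1024,          # illustrative cap; defaults to config.data.audio_length when None
    cfg_scale=3.0,            # classifier-free guidance strength
    temperature=1.3,          # sampling temperature (0.0 falls back to greedy argmax)
    top_p=0.95,               # nucleus sampling threshold
    use_torch_compile=False,  # set True to torch.compile the decode step
)

sf.write("simple_tuned.mp3", output, 44100)
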
/example/voice_clone.py:
--------------------------------------------------------------------------------
1 | import soundfile as sf
2 |
3 | from dia.model import Dia
4 |
5 |
6 | model = Dia.from_pretrained("nari-labs/Dia-1.6B")
7 |
8 | # You should put the transcript of the voice you want to clone
9 | # We will use the audio created by running simple.py as an example.
10 | # Note that you will be REQUIRED TO RUN simple.py for the script to work as-is.
11 | clone_from_text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."
12 | clone_from_audio = "simple.mp3"
13 |
14 | # For your custom needs, replace above with below and add your audio file to this directory:
15 | # clone_from_text = "[S1] ... [S2] ... [S1] ... corresponding to your_audio_name.mp3"
16 | # clone_from_audio = "your_audio_name.mp3"
17 |
18 | # Text to generate
19 | text_to_generate = "[S1] Hello, how are you? [S2] I'm good, thank you. [S1] What's your name? [S2] My name is Dia. [S1] Nice to meet you. [S2] Nice to meet you too."
20 |
21 | # It will only return the audio from the text_to_generate
22 | output = model.generate(clone_from_text + text_to_generate, audio_prompt_path=clone_from_audio)
23 |
24 | sf.write("voice_clone.mp3", output, 44100)
25 |
--------------------------------------------------------------------------------
/example_prompt.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stlohrey/dia-finetuning/25501f2911a20b7211a42640f391a07d562fd2f6/example_prompt.mp3
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "nari-tts"
3 | version = "0.1.0"
4 | description = "Dia - A text-to-speech model for dialogue generation"
5 | readme = "README.md"
6 | requires-python = ">=3.10"
7 | license = {file = "LICENSE"}
8 | authors = [
9 | {name = "Nari Labs", email = "contact@narilabs.ai"},
10 | {name = "Steffen Lohrey", email = "st.lohrey@gmail.com"}
11 | ]
12 | dependencies = [
13 |
14 | "descript-audio-codec>=1.0.0",
15 | "gradio>=5.25.2",
16 | "huggingface-hub>=0.30.2",
17 | "numpy>=2.2.4",
18 | "pydantic>=2.11.3",
19 | "soundfile>=0.13.1",
20 | "torch>=2.6.0",
21 | "torchaudio>=2.6.0",
22 | "triton>=3.2.0 ; sys_platform == 'linux'",
23 | "triton-windows>=3.2.0.post18 ; sys_platform == 'win32'",
24 | "transformers>=4.35.0",
25 | "bitsandbytes>=0.39.0",
26 | "tqdm>=4.65.0",
27 | "datasets>=2.13.0",
28 | "tensorboard>=2.12.0",
29 | ]
30 |
31 | [build-system]
32 | requires = ["hatchling"]
33 | build-backend = "hatchling.build"
34 |
35 | [project.urls]
36 | "Homepage" = "https://github.com/stlohrey/dia-finetuning"
37 | "Bug Tracker" = "https://github.com/stlohrey/dia-finetuning/issues"
38 |
39 | [tool.hatch.build.targets.wheel]
40 | packages = ["dia"]
41 |
42 | [tool.ruff]
43 | # Never enforce `E501` (line length violations).
44 | lint.ignore = ["C901", "E501", "E741", "W605"]
45 | lint.select = ["C", "E", "F", "I", "W"]
46 | line-length = 119
47 |
48 | # Ignore import violations in all `__init__.py` files.
49 | [tool.ruff.lint.per-file-ignores]
50 | "__init__.py" = ["E402", "F401", "F403", "F811"]
51 |
52 | [tool.ruff.lint.isort]
53 | lines-after-imports = 2
54 |
55 | [tool.uv.sources]
56 | torch = [
57 | { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
58 | ]
59 | torchaudio = [
60 | { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
61 | ]
62 |
63 | [[tool.uv.index]]
64 | name = "pytorch-cu126"
65 | url = "https://download.pytorch.org/whl/cu126"
66 | explicit = true
--------------------------------------------------------------------------------
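A typical workflow, assuming uv is installed: running uv sync in the repository root installs the locked dependencies (on Linux and Windows, torch and torchaudio are pulled from the pytorch-cu126 index per [tool.uv.sources]), after which the examples can be run with, for example, uv run python example/simple.py.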