├── .gitignore
├── src
│   ├── __init__.py
│   ├── main.py
│   ├── models.py
│   ├── inference.py
│   └── app.py
├── run.bat
├── requirements_native_hf.txt
├── images
│   ├── intro.png
│   └── podcast.png
├── requirements_base.txt
├── run.sh
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | venv
2 | __pycache__
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Intentionally left empty
--------------------------------------------------------------------------------
/run.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | python -m src.main
3 | pause
--------------------------------------------------------------------------------
/requirements_native_hf.txt:
--------------------------------------------------------------------------------
1 | gradio==5.14.0
2 | xcodec2==0.1.3
3 | bitsandbytes>=0.39.0
--------------------------------------------------------------------------------
/images/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nivibilla/local-llasa-tts/HEAD/images/intro.png
--------------------------------------------------------------------------------
/images/podcast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nivibilla/local-llasa-tts/HEAD/images/podcast.png
--------------------------------------------------------------------------------
/requirements_base.txt:
--------------------------------------------------------------------------------
1 | torch==2.5.1
2 | torchaudio==2.5.1
3 | torchvision==0.20.1
4 | numpy==1.26.4
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # run.sh: Launch the Llasa TTS application using the src package
3 |
4 | # Ensure the script is run from the project root directory
5 | python -m src.main
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import sys
3 | import argparse
4 | import torch
5 |
6 | # Check Python version
7 | if sys.version_info < (3, 10):
8 | print("ERROR: Python 3.10 or higher is required.")
9 | sys.exit(1)
10 |
11 | if not torch.cuda.is_available():
12 | print("ERROR: CUDA is not available. Please use a CUDA-capable GPU.")
13 | sys.exit(1)
14 |
15 | import os
16 | from .inference import initialize_models
17 | from .models import get_llasa_model
18 | from .app import build_dashboard
19 |
20 | def main():
21 | parser = argparse.ArgumentParser(description="Run the modular Llasa TTS Dashboard.")
22 | parser.add_argument("--share", help="Enable gradio share", action="store_true")
23 | args = parser.parse_args()
24 |
25 | print("Initializing CUDA backend...", flush=True)
26 | torch.cuda.init()
27 | _ = torch.zeros(1).cuda()
28 | print(f"Using device: {torch.cuda.get_device_name()}", flush=True)
29 |
30 | # Initialize local models
31 | print("\nStep 1: Loading XCodec2 and Whisper models...", flush=True)
32 | initialize_models()
33 |
34 | print("\nStep 2: Preloading Llasa 3B model (faster startup for standard usage)...", flush=True)
35 | get_llasa_model("3B")
36 | print("Preload done. Models are ready!")
37 |
38 | # Launch Gradio
39 | print("\nLaunching Gradio interface...", flush=True)
40 | app = build_dashboard()
41 | app.launch(share=args.share, server_name="0.0.0.0", server_port=7860)
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
5 | from transformers.utils import move_cache
6 |
7 | # Global caches
8 | loaded_models = {}
9 | loaded_tokenizers = {}
10 |
11 | # 4-bit quantization configuration (fp16 compute dtype, double quantization) applied to every Llasa checkpoint to reduce VRAM usage
12 | quantization_config = BitsAndBytesConfig(
13 | load_in_4bit=True,
14 | bnb_4bit_compute_dtype=torch.float16,
15 | bnb_4bit_use_double_quant=True
16 | )
17 |
18 | def get_gpu_memory():
19 | """Return current GPU memory usage in GB."""
20 | if torch.cuda.is_available():
21 | return torch.cuda.memory_allocated() / 1024**3
22 | return 0.0
23 |
24 | def unload_model(model_choice: str):
25 | """Unload a model from GPU and clear from cache."""
26 | from .models import loaded_models, loaded_tokenizers
27 | if model_choice in loaded_models:
28 | print(f"Unloading {model_choice} model from GPU...", flush=True)
29 | if hasattr(loaded_models[model_choice], 'cpu'):
30 | loaded_models[model_choice].cpu()
31 | del loaded_models[model_choice]
32 | if model_choice in loaded_tokenizers:
33 | del loaded_tokenizers[model_choice]
34 | torch.cuda.empty_cache()
35 | print(f"{model_choice} model unloaded successfully!", flush=True)
36 |
37 |
38 | def get_llasa_model(model_choice: str, hf_api_key: str = None):
39 | """
40 | Load and cache the specified model (1B, 3B, or 8B).
41 | If an API key is provided, it is used to authenticate with Hugging Face.
42 | """
43 | from .models import loaded_models, loaded_tokenizers, quantization_config
44 |
45 | # Determine repo name
46 | if model_choice == "1B":
47 | repo = "HKUSTAudio/Llasa-1B"
48 | elif model_choice == "3B":
49 | repo = "srinivasbilla/llasa-3b"
50 | else:
51 | repo = "HKUSTAudio/Llasa-8B"
52 |
53 | # Unload any other loaded model
54 | for existing_model in list(loaded_models.keys()):
55 | if existing_model != model_choice:
56 | unload_model(existing_model)
57 |
58 | if model_choice not in loaded_models:
59 | print(f"Preparing to load {repo}...", flush=True)
60 | print(f"Current GPU memory usage: {get_gpu_memory():.2f}GB", flush=True)
61 |
62 | hub_path = os.path.join(
63 | os.path.expanduser("~"),
64 | ".cache",
65 | "huggingface",
66 | "hub",
67 | "models--" + repo.replace("/", "--")
68 | )
69 |
70 | if os.path.exists(hub_path):
71 | print(f"Loading {repo} from local cache...", flush=True)
72 | else:
73 | print(f"Model {repo} not found in cache. Starting download...", flush=True)
74 |
75 | print("Loading tokenizer...", flush=True)
76 | tokenizer = AutoTokenizer.from_pretrained(repo, use_auth_token=hf_api_key)
77 | print("Tokenizer loaded successfully!", flush=True)
78 |
79 | print(f"Loading {model_choice} model into memory...", flush=True)
80 | model = AutoModelForCausalLM.from_pretrained(
81 | repo,
82 | trust_remote_code=True,
83 | device_map='cuda',
84 | quantization_config=quantization_config,
85 | low_cpu_mem_usage=True,
86 | use_auth_token=hf_api_key,
87 | torch_dtype=torch.float16
88 | )
89 | torch.cuda.empty_cache()
90 | print(f"{model_choice} model loaded successfully! (GPU memory: {get_gpu_memory():.2f}GB)", flush=True)
91 | loaded_tokenizers[model_choice] = tokenizer
92 | loaded_models[model_choice] = model
93 |
94 | return loaded_tokenizers[model_choice], loaded_models[model_choice]
95 |
96 |
97 | def check_model_in_cache(model_choice: str) -> bool:
98 | """
99 | Check if the given model repo is already present in the local Hugging Face cache.
100 | """
101 | if model_choice == "1B":
102 | repo = "HKUSTAudio/Llasa-1B"
103 | elif model_choice == "3B":
104 | repo = "srinivasbilla/llasa-3b"
105 | else:
106 | repo = "HKUSTAudio/Llasa-8B"
107 |
108 | hub_path = os.path.join(
109 | os.path.expanduser("~"),
110 | ".cache",
111 | "huggingface",
112 | "hub",
113 | "models--" + repo.replace("/", "--")
114 | )
115 | return os.path.exists(hub_path)
116 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Llasa TTS Dashboard
2 | A powerful, local text-to-speech system powered by **Llasa TTS** models. This project offers a modern, interactive dashboard that supports multiple model sizes (1B, 3B, and 8B) and introduces a new **Podcast Mode** for multi-speaker conversation synthesis.
3 |
4 | ![Llasa TTS Dashboard](images/intro.png)
5 |
6 | ---
7 |
8 | ## Overview
9 | The **Llasa TTS Dashboard** transforms traditional text-to-speech pipelines into a robust, user-friendly application. With efficient GPU utilization and flexible generation controls, the dashboard is designed for both developers and end users who demand high-quality speech synthesis.
10 |
11 | ### Key Features
12 | - **Multi-Model Support**
13 |
14 | Switch easily between **1B**, **3B**, and **8B** models.
15 |
16 | - **Standard TTS Mode**
17 |
18 | Generate natural-sounding speech either from plain text or with a reference audio prompt.
19 |
20 | - **Podcast Mode**
21 |
22 | Create multi-speaker podcasts from transcripts. Configure reference audio and seeds for each speaker to produce consistent character voices.
23 |
24 | - **Advanced Generation Controls**
25 |
26 | Fine-tune parameters such as max length, temperature, and top-p. Use random or fixed seeds for reproducibility.
27 |
28 | - **Clean & Modern UI**
29 |
30 | A sleek, two-panel interface built with Gradio. Enjoy a dark theme that enhances readability.
31 |
32 | ---
33 | ## System Requirements
34 |
35 | - **Python 3.10+**
36 | - **CUDA-Capable NVIDIA GPU**
37 |
38 | - **VRAM Requirements:**
39 | - **8.5 GB+ VRAM:** when running the 4-bit LLM together with Whisper Large v3 Turbo.
40 | - **6.5 GB+ VRAM:** when running the 4-bit LLM without Whisper.
41 | ---
42 |
43 | ## Installation
44 |
45 | ### Clone the Repository
46 |
47 | ```bash
48 |
49 | git clone https://github.com/nivibilla/local-llasa-tts.git
50 |
51 | cd local-llasa-tts
52 |
53 | ```
54 | ### Setup the Environment
55 |
56 | **If you're on Windows, this works best when using WSL2.**
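Optionally, create and activate a virtual environment first (the repository's `.gitignore` already anticipates a `venv/` folder). A minimal sketch; any environment manager works:

```bash
python -m venv venv
source venv/bin/activate
```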
57 | Install the necessary dependencies:
58 |
59 | ```bash
60 |
61 | pip install -r requirements_base.txt
62 |
63 | pip install -r requirements_native_hf.txt
64 |
65 | ```
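Before launching, you can optionally confirm that PyTorch can see your GPU; the app exits at startup if CUDA is unavailable (see `src/main.py`):

```bash
python -c "import torch; print(torch.cuda.is_available())"
```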
66 | ---
67 |
68 | ## Usage
69 |
70 | You can start the application in several ways:
71 |
72 | ### Run via Module
73 |
74 | From the project root directory, execute:
75 |
76 | ```bash
77 |
78 | python -m src.main
79 |
80 | ```
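The entry point also accepts a `--share` flag (defined in `src/main.py`), which asks Gradio to create a temporary public link:

```bash
python -m src.main --share
```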
81 |
82 | ### Using Provided Scripts
83 |
84 | - **Unix/Linux/Mac:**
85 |
86 | Make sure `run.sh` is executable and run it:
87 |
88 | ```bash
89 |
90 | chmod +x run.sh
91 |
92 | ./run.sh
93 |
94 | ```
95 |
96 | - **Windows:**
97 |
98 | Double-click `run.bat` or run it from the command prompt:
99 |
100 | ```batch
101 |
102 | run.bat
103 |
104 | ```
105 |
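If a model download requires Hugging Face authentication and no API key is entered in the UI, the app falls back to the `LLASA_API_KEY` environment variable (see `HF_KEY_ENV_VAR` in `src/inference.py`). For example, on Linux/macOS (the token value below is a placeholder):

```bash
export LLASA_API_KEY=hf_your_token_here
./run.sh
```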
106 | ### Dashboard Modes
107 |
108 | #### Standard TTS Mode
109 |
110 | - **Model Selection:** Choose between 1B, 3B, or 8B.
111 |
112 | - **Generation Mode:** Select "Text only" or "Reference audio."
113 |
114 | - **Advanced Settings:** Adjust max length, temperature, and top-p.
115 |
116 | - **Output:** Listen to the synthesized speech and review previous generations.
117 |
118 | #### Podcast Mode
119 |
120 | - **Transcript Input:** Enter a conversation transcript with each line formatted as `Speaker Name: message` (see the example after this list).
121 |
122 | - **Speaker Configuration:** Optionally provide reference audio and seeds for each speaker.
123 |
124 | - **Advanced Settings:** Configure generation parameters similar to Standard TTS.
125 |
126 | - **Output:** Generate a complete podcast audio file with seamless transitions.
127 |
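For example, a short transcript might look like this (speaker names are illustrative):

```text
Alice: Welcome back to the show! Today we're talking about running text-to-speech locally.
Bob: Thanks, Alice. I'm excited to dig into how the Llasa models handle multiple voices.
Alice: Let's start with the basics.
```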
128 | *Screenshot:*
129 |
130 | ![Podcast Mode](images/podcast.png)
131 |
132 | ---
133 | ## Additional Resources
134 |
135 |
136 |
137 | - **Long Text Inference:**
138 |
139 | Refer to [llasa_vllm_longtext_inference.ipynb](./llasa_vllm_longtext_inference.ipynb) for handling long text inputs using VLLM and chunking.
140 |
141 | - **Google Colab:**
142 |
143 | If you do not have a suitable local GPU, try our [Colab Notebook](https://colab.research.google.com/github/YourUser/local_llasa_tts/blob/main/colab_notebook_4bit.ipynb).
144 |
145 | ---
146 |
147 | ## Acknowledgements
148 |
149 | - **Original LLaSA Training Repository:** Inspired by [zhenye234/LLaSA_training](https://github.com/zhenye234/LLaSA_training).
150 |
151 | - **Gradio Demo Inspiration:** UI concepts adapted from [mrfakename/E2-F5-TTS](https://huggingface.co/spaces/mrfakename/E2-F5-TTS).
152 |
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import base64
4 | import json
5 | import random
6 | import numpy as np
7 | import torch
8 | import torchaudio
9 | import soundfile as sf
10 | import gradio as gr
11 |
12 | from .models import get_llasa_model, get_gpu_memory
13 | from xcodec2.modeling_xcodec2 import XCodec2Model
14 | from transformers import pipeline
15 |
16 | # Global constants / settings
17 | HF_KEY_ENV_VAR = "LLASA_API_KEY"
18 | MAX_HISTORY = 5 # How many previous generations to keep
19 | history_data = [] # In-memory history list
20 |
21 | # Will hold references to XCodec2Model and Whisper pipeline
22 | Codec_model = None
23 | whisper_turbo_pipe = None
24 |
25 | def initialize_models():
26 | """Initialize XCodec2 and Whisper models at startup."""
27 | global Codec_model, whisper_turbo_pipe
28 | print("Step 1/3: Preparing XCodec2 model...", flush=True)
29 | model_path = "srinivasbilla/xcodec2"
30 | import os
31 | hub_path = os.path.join(
32 | os.path.expanduser("~"),
33 | ".cache", "huggingface", "hub",
34 | "models--" + model_path.replace("/", "--")
35 | )
36 | if os.path.exists(hub_path):
37 | print(f"Loading XCodec2 model from local cache...", flush=True)
38 | else:
39 | print(f"Model {model_path} not found in cache. Starting download...", flush=True)
40 | print("Loading XCodec2 model into memory...", flush=True)
41 | Codec_model = XCodec2Model.from_pretrained(model_path)
42 | Codec_model.eval().cuda()
43 | torch.cuda.empty_cache()
44 | print(f"XCodec2 model loaded successfully! (GPU memory: {get_gpu_memory():.2f}GB)")
45 |
46 | print("\nStep 2/3: Preparing Whisper model...", flush=True)
47 | whisper_model = "openai/whisper-large-v3-turbo"
48 | hub_path = os.path.join(
49 | os.path.expanduser("~"),
50 | ".cache", "huggingface", "hub",
51 | "models--" + whisper_model.replace("/", "--")
52 | )
53 | if os.path.exists(hub_path):
54 | print(f"Loading Whisper model from local cache...", flush=True)
55 | else:
56 | print(f"Model {whisper_model} not found in cache. Starting download...", flush=True)
57 | print("Loading Whisper model and preparing pipeline...", flush=True)
58 | whisper_turbo_pipe = pipeline(
59 | "automatic-speech-recognition",
60 | model=whisper_model,
61 | torch_dtype=torch.float16,
62 | device='cuda'
63 | )
64 | torch.cuda.empty_cache()
65 | print(f"Whisper model loaded successfully! (GPU memory: {get_gpu_memory():.2f}GB)\n")
66 |
67 |
68 | ###############################################################################
69 | # Utility / Rendering Functions #
70 | ###############################################################################
71 |
72 | def ids_to_speech_tokens(speech_ids):
73 | """Convert list of integers to token strings."""
74 | return [f"<|s_{speech_id}|>" for speech_id in speech_ids]
75 |
76 | def extract_speech_ids(speech_tokens_str):
77 | """Extract integer IDs from tokens like <|s_123|>."""
78 | speech_ids = []
79 | for token_str in speech_tokens_str:
80 | if token_str.startswith('<|s_') and token_str.endswith('|>'):
81 | try:
82 | num = int(token_str[4:-2])
83 | speech_ids.append(num)
84 | except ValueError:
85 | print(f"Failed to convert token: {token_str}")
86 | else:
87 | print(f"Unexpected token: {token_str}")
88 | return speech_ids
89 |
90 | def generate_audio_data_url(audio_np, sample_rate=16000, format='WAV'):
91 | """Encode NumPy audio array into a base64 data URL for HTML audio tags."""
92 | if audio_np.dtype != np.float32:
93 | audio_np = audio_np.astype(np.float32)
94 | if np.abs(audio_np).max() > 1.0:
95 | audio_np = audio_np / np.abs(audio_np).max()
96 | audio_int16 = (audio_np * 32767).astype(np.int16)
97 | with io.BytesIO() as buf:
98 | sf.write(buf, audio_int16, sample_rate, format=format, subtype='PCM_16')
99 | audio_data = base64.b64encode(buf.getvalue()).decode('utf-8')
100 | return f"data:audio/wav;base64,{audio_data}"
101 |
102 | def render_previous_generations(history_list, is_generating=False):
103 | """Render history entries as HTML."""
104 | if not history_list and not is_generating:
105 |         return "
...
        Text: {entry['text']}
141 |         Params: max_len={entry['max_length']}, temp={entry['temperature']}, top_p={entry['top_p']}{', seed=' + str(entry.get('seed')) if entry.get('seed') is not None else ''}
142 |         {speaker}: {text}
...
197 |     return html
198 | 
199 | 
200 | ###############################################################################
201 | #                               Core Inference                               #
202 | ###############################################################################
203 | 
204 | def set_seed(seed):
205 |     """Set seeds for reproducible generation."""
206 |     if seed is not None:
207 |         torch.manual_seed(seed)
208 |         torch.cuda.manual_seed_all(seed)
209 |         np.random.seed(seed)
210 |         random.seed(seed)
211 |         torch.backends.cudnn.deterministic = True
212 | 
213 | 
214 | def infer(
215 |     generation_mode,        # "Text only" or "Reference audio"
216 |     ref_audio_path,         # path to ref audio (if any)
217 |     target_text,            # text to synthesize
218 |     model_version,          # "1B", "3B", or "8B"
219 |     hf_api_key,             # HF API key
220 |     trim_audio,             # trim ref audio to 15s?
221 |     max_length,             # generation param
222 |     temperature,            # generation param
223 |     top_p,                  # generation param
224 |     whisper_language,       # whisper language
225 |     user_seed,              # user-provided seed
226 |     random_seed_each_gen,   # random seed if True
227 |     beam_search_enabled,    # beam search flag
228 |     auto_optimize_length,   # auto-optimize length
229 |     prev_history,           # prior generation history
230 |     progress=gr.Progress()
231 | ):
232 |     from .models import get_llasa_model
233 | 
234 |     # Handle seeds
235 |     if random_seed_each_gen:
236 |         chosen_seed = random.randint(0, 2**31 - 1)
237 |     else:
238 |         chosen_seed = user_seed
239 |     set_seed(chosen_seed)
240 | 
241 |     # If there's an env var for HF token, fallback if no API key given
242 |     if (not hf_api_key or not hf_api_key.strip()):
243 |         env_key = os.environ.get(HF_KEY_ENV_VAR, "").strip()
244 |         if env_key:
245 |             hf_api_key = env_key
246 | 
247 |     # Acquire model and tokenizer
248 |     tokenizer, model = get_llasa_model(model_version, hf_api_key=hf_api_key)
249 | 
250 |     if len(target_text) == 0:
251 |         return None, render_previous_generations(prev_history), prev_history
252 |     elif len(target_text) > 1000:
253 |         gr.warning("Text is too long. Truncating to 1000 characters.")
254 |         target_text = target_text[:1000]
255 | 
256 |     # Possibly auto-optimize max_length BEFORE we build final input
257 |     # (We also do a check after building input_ids, below.)
258 |     # -- We'll do the final check after the input is built to be safe.
259 | 
260 |     from .inference import Codec_model, whisper_turbo_pipe
261 | 
262 |     # Handle reference audio if needed
263 |     speech_ids_prefix = []
264 |     prompt_text = ""
265 |     if generation_mode == "Reference audio" and ref_audio_path:
266 |         progress(0, "Loading & trimming reference audio...")
267 |         waveform, sample_rate = torchaudio.load(ref_audio_path)
268 |         if trim_audio and (waveform.shape[1] / sample_rate) > 15:
269 |             waveform = waveform[:, :sample_rate * 15]
270 | 
271 |         # Resample to 16k
272 |         if waveform.size(0) > 1:
273 |             waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
274 |         else:
275 |             waveform_mono = waveform
276 |         prompt_wav = torchaudio.transforms.Resample(
277 |             orig_freq=sample_rate, new_freq=16000
278 |         )(waveform_mono)
279 | 
280 |         # Transcribe with Whisper
281 |         whisper_args = {}
282 |         if whisper_language != "auto":
283 |             whisper_args["language"] = whisper_language
284 |         prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy(), generate_kwargs=whisper_args)['text'].strip()
285 | 
286 |         # Encode reference audio with XCodec2
287 |         with torch.no_grad():
288 |             vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
289 |             vq_code_prompt = vq_code_prompt[0, 0, :]  # shape: [T]
290 |             speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
291 |     elif generation_mode == "Reference audio" and not ref_audio_path:
292 |         gr.warning("No reference audio provided. Proceeding in text-only mode.")
293 | 
294 |     progress(0.5, "Generating speech...")
295 | 
296 |     # Combine any reference text + user text
297 |     combined_input_text = prompt_text + " " + target_text
298 |     prefix_str = "".join(speech_ids_prefix) if speech_ids_prefix else ""
299 |     formatted_text = f"<|TEXT_UNDERSTANDING_START|>{combined_input_text}<|TEXT_UNDERSTANDING_END|>"
300 |     chat = [
301 |         {"role": "user", "content": "Convert the text to speech:" + formatted_text},
302 |         {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + prefix_str},
303 |     ]
304 |     num_beams = 2 if beam_search_enabled else 1
305 |     early_stopping_val = (num_beams > 1)
306 | 
307 |     model_inputs = tokenizer.apply_chat_template(
308 |         chat,
309 |         tokenize=True,
310 |         return_tensors="pt",
311 |         continue_final_message=True
312 |     )
313 |     input_ids = model_inputs.to("cuda")
314 |     attention_mask = torch.ones_like(input_ids).to("cuda")
315 | 
316 |     # Final auto-optimize check
317 |     if auto_optimize_length:
318 |         input_len = input_ids.shape[1]
319 |         margin = 100 if generation_mode == "Reference audio" else 50
320 |         if input_len + margin > max_length:
321 |             old_val = max_length
322 |             max_length = input_len + margin
323 |             print(f"Auto optimizing: input length is {input_len}, raising max_length from {old_val} to {max_length}.")
324 | 
325 |     # Generate tokens
326 |     speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
327 |     with torch.no_grad():
328 |         outputs = model.generate(
329 |             input_ids,
330 |             attention_mask=attention_mask,
331 |             pad_token_id=(tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id),
332 |             max_length=int(max_length),
333 |             min_length=int(max_length * 0.5),
334 |             eos_token_id=speech_end_id,
335 |             do_sample=True,
336 |             num_beams=num_beams,
337 |             length_penalty=1.5,
338 |             temperature=float(temperature),
339 |             top_p=float(top_p),
340 |             repetition_penalty=1.2,
341 |             early_stopping=early_stopping_val,
342 |             no_repeat_ngram_size=3,
343 |         )
344 | 
345 |     prefix_len = len(speech_ids_prefix)
346 |     # cutting off prefix from the final output
347 |     generated_ids = outputs[0][(input_ids.shape[1] - prefix_len) : -1]
348 |     speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
349 |     speech_tokens = extract_speech_ids(speech_tokens)
350 |     speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
351 |     gen_wav = Codec_model.decode_code(speech_tokens)
352 | 
353 |     # If we had a reference prompt, remove that segment from the final
354 |     if speech_ids_prefix:
355 |         gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]
356 | 
357 |     sr = 16000
358 |     out_audio_np = gen_wav[0, 0, :].cpu().numpy()
359 |     progress(0.9, "Finalizing audio...")
360 | 
361 |     audio_data_url = generate_audio_data_url(out_audio_np, sample_rate=sr)
362 |     new_entry = {
363 |         "mode": generation_mode,
364 |         "text": target_text,
365 |         "audio_url": audio_data_url,
366 |         "temperature": temperature,
367 |         "top_p": top_p,
368 |         "max_length": max_length,
369 |         "seed": chosen_seed,
370 |     }
371 | 
372 |     if len(prev_history) >= MAX_HISTORY:
373 |         prev_history.pop(0)
374 |     prev_history.append(new_entry)
375 |     updated_dashboard_html = render_previous_generations(prev_history, is_generating=False)
376 | 
377 |     return (sr, out_audio_np), updated_dashboard_html, prev_history
378 | 
379 | 
380 | def infer_podcast(
381 |     conversation_text,
382 |     generation_mode,        # "Podcast"
383 |     model_choice,
384 |     hf_api_key,
385 |     trim_audio,
386 |     max_length,
387 |     temperature,
388 |     top_p,
389 |     whisper_language,
390 |     user_seed,
391 |     random_seed_each_gen,
392 |     beam_search_enabled,
393 |     auto_optimize_length,
394 |     prev_history,
395 |     speaker_config=None,
396 |     progress=gr.Progress()
397 | ):
398 |     """
399 |     Generate podcast audio line by line, taking speaker-specific configurations.
400 |     """
401 |     if speaker_config is None:
402 |         speaker_config = {}
403 | 
404 |     from .inference import parse_conversation, generate_audio_data_url, render_previous_generations
405 |     from .inference import join_audio_segments, infer
406 | 
407 |     conversation, speakers = parse_conversation(conversation_text)
408 |     audio_segments = []
409 | 
410 |     for speaker, line_text in conversation:
411 |         # Retrieve speaker-specific config
412 |         config = speaker_config.get(speaker.lower(), {"ref_audio": "", "seed": None})
413 |         ref_audio = config.get("ref_audio", "")
414 |         seed = config.get("seed", None)
415 | 
416 |         # Decide generation mode
417 |         line_mode = "Reference audio" if ref_audio else "Text only"
418 |         result = infer(
419 |             line_mode,
420 |             ref_audio,
421 |             line_text,
422 |             model_choice,
423 |             hf_api_key,
424 |             trim_audio,
425 |             max_length,
426 |             temperature,
427 |             top_p,
428 |             whisper_language,
429 |             seed,
430 |             random_seed_each_gen,
431 |             beam_search_enabled,
432 |             auto_optimize_length,
433 |             prev_history=[],
434 |             progress=progress
435 |         )
436 |         _, line_audio = result[0]
437 |         audio_segments.append(line_audio)
438 | 
439 |     final_audio = join_audio_segments(audio_segments, sample_rate=16000, crossfade_duration=0.05)
440 |     audio_url = generate_audio_data_url(final_audio, sample_rate=16000)
441 | 
442 |     new_entry = {
443 |         "mode": "Podcast",
444 |         "text": conversation_text,
445 |         "audio_url": audio_url,
446 |         "temperature": temperature,
447 |         "top_p": top_p,
448 |         "max_length": max_length,
449 |         "seed": "N/A",
450 |     }
451 |     if len(prev_history) >= MAX_HISTORY:
452 |         prev_history.pop(0)
453 |     prev_history.append(new_entry)
454 |     updated_dashboard_html = render_previous_generations(prev_history, is_generating=False)
455 | 
456 |     return (16000, final_audio), updated_dashboard_html, prev_history
457 | 
--------------------------------------------------------------------------------
/src/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import torch
3 | import os
4 | 
5 | from .inference import (
6 |     initialize_models,
7 |     infer,
8 |     infer_podcast,
9 |     render_previous_generations
10 | )
11 | from .models import check_model_in_cache
12 | 
13 | # Inline CSS for the dark theme (unchanged)
14 | NEW_CSS = """
15 | /* Remove Gradio branding/footer */
16 | #footer, .gradio-container a[target="_blank"] { display: none; }
17 | /* Simple dark background */
18 | body, .gradio-container { margin: 0; padding: 0; background-color: #1E1E2A; color: #EAEAEA; font-family: 'Segoe UI', sans-serif; }
19 | /* Header styling */
20 | #header { background-color: #2E2F46; padding: 1rem 2rem; text-align: center; }
21 | #header h1 { margin: 0; font-size: 2rem; }
22 | /* Main content row styling */
23 | #content-row { display: flex; flex-direction: row; gap: 1rem; padding: 1rem 2rem; }
24 | /* Synthesis panel */
25 | #synthesis-panel { flex: 2; background-color: #222233; border-radius: 8px; padding: 1.5rem; }
26 | /* History panel */
27 | #history-panel { flex: 1; background-color: #222233; border-radius: 8px; padding: 1.5rem; }
28 | /* Form elements styling */
29 | .gr-textbox input, .gr-textbox textarea, .gr-dropdown select { background-color: #38395A; border: 1px solid #4A4B6F; color: #F1F1F1; border-radius: 4px; padding: 0.5rem; }
30 | /* Audio components */
31 | .audio-input, .audio-output { background-color: #2E2F46 !important; border-radius: 8px !important; padding: 12px !important; margin: 8px 0 !important; }
32 | """
33 | 
34 | def build_dashboard():
35 |     """
36 |     Build the Gradio interface with separate tabs for Standard TTS and Podcast Mode.
37 |     Adds a dynamic check for whether the chosen model is in local cache.
38 |     If not, the HF API Key field is shown. If in cache, it remains hidden.
39 |     """
40 |     theme = gr.themes.Default(
41 |         primary_hue="blue",
42 |         secondary_hue="slate",
43 |         neutral_hue="slate",
44 |         font=[gr.themes.GoogleFont("Inter")],
45 |         font_mono=gr.themes.GoogleFont("IBM Plex Mono"),
46 |     ).set(
47 |         background_fill_primary="#1E1E2A",
48 |         background_fill_secondary="#222233",
49 |         border_color_primary="#4A4B6F",
50 |         body_text_color="#EAEAEA",
51 |         block_title_text_color="#EAEAEA",
52 |         block_label_text_color="#EAEAEA",
53 |         input_background_fill="#38395A",
54 |     )
55 | 
56 |     with gr.Blocks(theme=theme, css=NEW_CSS) as demo:
57 |         gr.Markdown("