├── .gitignore ├── src ├── __init__.py ├── main.py ├── models.py ├── inference.py └── app.py ├── run.bat ├── requirements_native_hf.txt ├── images ├── intro.png └── podcast.png ├── requirements_base.txt ├── run.sh └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | __pycache__ -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Intentionally left empty -------------------------------------------------------------------------------- /run.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | python -m src.main 3 | pause -------------------------------------------------------------------------------- /requirements_native_hf.txt: -------------------------------------------------------------------------------- 1 | gradio==5.14.0 2 | xcodec2==0.1.3 3 | bitsandbytes>=0.39.0 -------------------------------------------------------------------------------- /images/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nivibilla/local-llasa-tts/HEAD/images/intro.png -------------------------------------------------------------------------------- /images/podcast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nivibilla/local-llasa-tts/HEAD/images/podcast.png -------------------------------------------------------------------------------- /requirements_base.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | torchaudio==2.5.1 3 | torchvision==0.20.1 4 | numpy==1.26.4 -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run.sh: Launch the Llasa TTS application using the src package 3 | 4 | # Ensure the script is run from the project root directory 5 | python -m src.main -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | import argparse 4 | import torch 5 | 6 | # Check Python version 7 | if sys.version_info < (3, 10): 8 | print("ERROR: Python 3.10 or higher is required.") 9 | sys.exit(1) 10 | 11 | if not torch.cuda.is_available(): 12 | print("ERROR: CUDA is not available. 
Please use a CUDA-capable GPU.") 13 | sys.exit(1) 14 | 15 | import os 16 | from .inference import initialize_models 17 | from .models import get_llasa_model 18 | from .app import build_dashboard 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser(description="Run the modular Llasa TTS Dashboard.") 22 | parser.add_argument("--share", help="Enable gradio share", action="store_true") 23 | args = parser.parse_args() 24 | 25 | print("Initializing CUDA backend...", flush=True) 26 | torch.cuda.init() 27 | _ = torch.zeros(1).cuda() 28 | print(f"Using device: {torch.cuda.get_device_name()}", flush=True) 29 | 30 | # Initialize local models 31 | print("\nStep 1: Loading XCodec2 and Whisper models...", flush=True) 32 | initialize_models() 33 | 34 | print("\nStep 2: Preloading Llasa 3B model (faster startup for standard usage)...", flush=True) 35 | get_llasa_model("3B") 36 | print("Preload done. Models are ready!") 37 | 38 | # Launch Gradio 39 | print("\nLaunching Gradio interface...", flush=True) 40 | app = build_dashboard() 41 | app.launch(share=args.share, server_name="0.0.0.0", server_port=7860) 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig 5 | from transformers.utils import move_cache 6 | 7 | # Global caches 8 | loaded_models = {} 9 | loaded_tokenizers = {} 10 | 11 | # Quantization configuration 12 | quantization_config = BitsAndBytesConfig( 13 | load_in_4bit=True, 14 | bnb_4bit_compute_dtype=torch.float16, 15 | bnb_4bit_use_double_quant=True 16 | ) 17 | 18 | def get_gpu_memory(): 19 | """Return current GPU memory usage in GB.""" 20 | if torch.cuda.is_available(): 21 | return torch.cuda.memory_allocated() / 1024**3 22 | return 0.0 23 | 24 | def unload_model(model_choice: str): 25 | """Unload a model from GPU and clear from cache.""" 26 | from .models import loaded_models, loaded_tokenizers 27 | if model_choice in loaded_models: 28 | print(f"Unloading {model_choice} model from GPU...", flush=True) 29 | if hasattr(loaded_models[model_choice], 'cpu'): 30 | loaded_models[model_choice].cpu() 31 | del loaded_models[model_choice] 32 | if model_choice in loaded_tokenizers: 33 | del loaded_tokenizers[model_choice] 34 | torch.cuda.empty_cache() 35 | print(f"{model_choice} model unloaded successfully!", flush=True) 36 | 37 | 38 | def get_llasa_model(model_choice: str, hf_api_key: str = None): 39 | """ 40 | Load and cache the specified model (1B, 3B, or 8B). 41 | If an API key is provided, it is used to authenticate with Hugging Face. 
42 | """ 43 | from .models import loaded_models, loaded_tokenizers, quantization_config 44 | 45 | # Determine repo name 46 | if model_choice == "1B": 47 | repo = "HKUSTAudio/Llasa-1B" 48 | elif model_choice == "3B": 49 | repo = "srinivasbilla/llasa-3b" 50 | else: 51 | repo = "HKUSTAudio/Llasa-8B" 52 | 53 | # Unload any other loaded model 54 | for existing_model in list(loaded_models.keys()): 55 | if existing_model != model_choice: 56 | unload_model(existing_model) 57 | 58 | if model_choice not in loaded_models: 59 | print(f"Preparing to load {repo}...", flush=True) 60 | print(f"Current GPU memory usage: {get_gpu_memory():.2f}GB", flush=True) 61 | 62 | hub_path = os.path.join( 63 | os.path.expanduser("~"), 64 | ".cache", 65 | "huggingface", 66 | "hub", 67 | "models--" + repo.replace("/", "--") 68 | ) 69 | 70 | if os.path.exists(hub_path): 71 | print(f"Loading {repo} from local cache...", flush=True) 72 | else: 73 | print(f"Model {repo} not found in cache. Starting download...", flush=True) 74 | 75 | print("Loading tokenizer...", flush=True) 76 | tokenizer = AutoTokenizer.from_pretrained(repo, use_auth_token=hf_api_key) 77 | print("Tokenizer loaded successfully!", flush=True) 78 | 79 | print(f"Loading {model_choice} model into memory...", flush=True) 80 | model = AutoModelForCausalLM.from_pretrained( 81 | repo, 82 | trust_remote_code=True, 83 | device_map='cuda', 84 | quantization_config=quantization_config, 85 | low_cpu_mem_usage=True, 86 | use_auth_token=hf_api_key, 87 | torch_dtype=torch.float16 88 | ) 89 | torch.cuda.empty_cache() 90 | print(f"{model_choice} model loaded successfully! (GPU memory: {get_gpu_memory():.2f}GB)", flush=True) 91 | loaded_tokenizers[model_choice] = tokenizer 92 | loaded_models[model_choice] = model 93 | 94 | return loaded_tokenizers[model_choice], loaded_models[model_choice] 95 | 96 | 97 | def check_model_in_cache(model_choice: str) -> bool: 98 | """ 99 | Check if the given model repo is already present in the local Hugging Face cache. 100 | """ 101 | if model_choice == "1B": 102 | repo = "HKUSTAudio/Llasa-1B" 103 | elif model_choice == "3B": 104 | repo = "srinivasbilla/llasa-3b" 105 | else: 106 | repo = "HKUSTAudio/Llasa-8B" 107 | 108 | hub_path = os.path.join( 109 | os.path.expanduser("~"), 110 | ".cache", 111 | "huggingface", 112 | "hub", 113 | "models--" + repo.replace("/", "--") 114 | ) 115 | return os.path.exists(hub_path) 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Llasa TTS Dashboard 2 | A powerful, local text-to-speech system powered by **Llasa TTS** models. This project offers a modern, interactive dashboard that supports multiple model sizes (1B, 3B, and 8B) and introduces a new **Podcast Mode** for multi-speaker conversation synthesis. 3 | 4 | Standard TTS Tab 5 | 6 | --- 7 | 8 | ## Overview 9 | The **Llasa TTS Dashboard** transforms traditional text-to-speech pipelines into a robust, user-friendly application. With efficient GPU utilization, and flexible generation controls, the dashboard is designed for both developers and end users who demand high-quality speech synthesis. 10 | 11 | ### Key Features 12 | - **Multi-Model Support** 13 | 14 | Switch easily between **1B**, **3B**, and **8B** models. 15 | 16 | - **Standard TTS Mode** 17 | 18 | Generate natural-sounding speech either from plain text or with a reference audio prompt. 
19 | 20 | - **Podcast Mode** 21 | 22 | Create multi-speaker podcasts from transcripts. Configure reference audio and seeds for each speaker to produce consistent character voices. 23 | 24 | - **Advanced Generation Controls** 25 | 26 | Fine-tune parameters such as max length, temperature, and top-p. Use random or fixed seeds for reproducibility. 27 | 28 | - **Clean & Modern UI** 29 | 30 | A sleek, two-panel interface built with Gradio. Enjoy a dark theme that enhances readability. 31 | 32 | --- 33 | ## System Requirements 34 | 35 | - **Python 3.10+** 36 | - **CUDA-Capable NVIDIA GPU** 37 | 38 | - **VRAM Requirements:** 39 | - **8.5 GB+ VRAM:** When running with Whisper Large Turbo in 4-bit mode. 40 | - **6.5 GB+ VRAM:** When running without Whisper and using the LLM in 4-bit mode. 41 | --- 42 | 43 | ## Installation 44 | 45 | ### Clone the Repository 46 | 47 | ```bash 48 | 49 | git clone https://github.com/nivibilla/local_llasa_tts.git 50 | 51 | cd local_llasa_tts 52 | 53 | ``` 54 | ### Setup the Environment 55 | 56 | **If you're on Windows, this works best when using WSL2.** 57 | Install the necessary dependencies: 58 | 59 | ```bash 60 | 61 | pip install -r requirements_base.txt 62 | 63 | pip install -r requirements_native_hf.txt 64 | 65 | ``` 66 | --- 67 | 68 | ## Usage 69 | 70 | You can start the application in several ways: 71 | 72 | ### Run via Module 73 | 74 | From the project root directory, execute: 75 | 76 | ```bash 77 | 78 | python -m src.main 79 | 80 | ``` 81 | 82 | ### Using Provided Scripts 83 | 84 | - **Unix/Linux/Mac:** 85 | 86 | Make sure `run.sh` is executable and run it: 87 | 88 | ```bash 89 | 90 | chmod +x run.sh 91 | 92 | ./run.sh 93 | 94 | ``` 95 | 96 | - **Windows:** 97 | 98 | Double-click `run.bat` or run it from the command prompt: 99 | 100 | ```batch 101 | 102 | run.bat 103 | 104 | ``` 105 | 106 | ### Dashboard Modes 107 | 108 | #### Standard TTS Mode 109 | 110 | - **Model Selection:** Choose between 1B, 3B, or 8B. 111 | 112 | - **Generation Mode:** Select "Text only" or "Reference audio." 113 | 114 | - **Advanced Settings:** Adjust max length, temperature, and top-p. 115 | 116 | - **Output:** Listen to the synthesized speech and review previous generations. 117 | 118 | #### Podcast Mode 119 | 120 | - **Transcript Input:** Enter a conversation transcript with each line formatted as `Speaker Name: message`. 121 | 122 | - **Speaker Configuration:** Optionally provide reference audio and seeds for each speaker. 123 | 124 | - **Advanced Settings:** Configure generation parameters similar to Standard TTS. 125 | 126 | - **Output:** Generate a complete podcast audio file with seamless transitions. 127 | 128 | *Screenshot:* 129 | 130 | Podcast TTS Tab 131 | 132 | --- 133 | ## Additional Resources 134 | 135 | 136 | 137 | - **Long Text Inference:** 138 | 139 | Refer to [llasa_vllm_longtext_inference.ipynb](./llasa_vllm_longtext_inference.ipynb) for handling long text inputs using VLLM and chunking. 140 | 141 | - **Google Colab:** 142 | 143 | If you do not have a suitable local GPU, try our [Colab Notebook](https://colab.research.google.com/github/YourUser/local_llasa_tts/blob/main/colab_notebook_4bit.ipynb). 144 | 145 | --- 146 | 147 | ## Acknowledgements 148 | 149 | - **Original LLaSA Training Repository:** Inspired by [zhenye234/LLaSA_training](https://github.com/zhenye234/LLaSA_training). 150 | 151 | - **Gradio Demo Inspiration:** UI concepts adapted from [mrfakename/E2-F5-TTS](https://huggingface.co/spaces/mrfakename/E2-F5-TTS). 
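
---

## Example Podcast Transcript

As a concrete illustration of the transcript format described in the Podcast Mode section above, a two-speaker input might look like this (the names and lines are purely illustrative):

```text
Alex: Welcome back to the show. Today we are testing local text-to-speech.
Jamie: Thanks, Alex. The 3B model handled my reference clip surprisingly well.
Alex: Let's hear the result.
```

Speaker names are matched case-insensitively against the names entered in the Speaker Configuration panel, so `Alex` in the transcript pairs with a speaker field filled in as `alex` or `ALEX` as well.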
152 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import base64 4 | import json 5 | import random 6 | import numpy as np 7 | import torch 8 | import torchaudio 9 | import soundfile as sf 10 | import gradio as gr 11 | 12 | from .models import get_llasa_model, get_gpu_memory 13 | from xcodec2.modeling_xcodec2 import XCodec2Model 14 | from transformers import pipeline 15 | 16 | # Global constants / settings 17 | HF_KEY_ENV_VAR = "LLASA_API_KEY" 18 | MAX_HISTORY = 5 # How many previous generations to keep 19 | history_data = [] # In-memory history list 20 | 21 | # Will hold references to XCodec2Model and Whisper pipeline 22 | Codec_model = None 23 | whisper_turbo_pipe = None 24 | 25 | def initialize_models(): 26 | """Initialize XCodec2 and Whisper models at startup.""" 27 | global Codec_model, whisper_turbo_pipe 28 | print("Step 1/3: Preparing XCodec2 model...", flush=True) 29 | model_path = "srinivasbilla/xcodec2" 30 | import os 31 | hub_path = os.path.join( 32 | os.path.expanduser("~"), 33 | ".cache", "huggingface", "hub", 34 | "models--" + model_path.replace("/", "--") 35 | ) 36 | if os.path.exists(hub_path): 37 | print(f"Loading XCodec2 model from local cache...", flush=True) 38 | else: 39 | print(f"Model {model_path} not found in cache. Starting download...", flush=True) 40 | print("Loading XCodec2 model into memory...", flush=True) 41 | Codec_model = XCodec2Model.from_pretrained(model_path) 42 | Codec_model.eval().cuda() 43 | torch.cuda.empty_cache() 44 | print(f"XCodec2 model loaded successfully! (GPU memory: {get_gpu_memory():.2f}GB)") 45 | 46 | print("\nStep 2/3: Preparing Whisper model...", flush=True) 47 | whisper_model = "openai/whisper-large-v3-turbo" 48 | hub_path = os.path.join( 49 | os.path.expanduser("~"), 50 | ".cache", "huggingface", "hub", 51 | "models--" + whisper_model.replace("/", "--") 52 | ) 53 | if os.path.exists(hub_path): 54 | print(f"Loading Whisper model from local cache...", flush=True) 55 | else: 56 | print(f"Model {whisper_model} not found in cache. Starting download...", flush=True) 57 | print("Loading Whisper model and preparing pipeline...", flush=True) 58 | whisper_turbo_pipe = pipeline( 59 | "automatic-speech-recognition", 60 | model=whisper_model, 61 | torch_dtype=torch.float16, 62 | device='cuda' 63 | ) 64 | torch.cuda.empty_cache() 65 | print(f"Whisper model loaded successfully! 
(GPU memory: {get_gpu_memory():.2f}GB)\n") 66 | 67 | 68 | ############################################################################### 69 | # Utility / Rendering Functions # 70 | ############################################################################### 71 | 72 | def ids_to_speech_tokens(speech_ids): 73 | """Convert list of integers to token strings.""" 74 | return [f"<|s_{speech_id}|>" for speech_id in speech_ids] 75 | 76 | def extract_speech_ids(speech_tokens_str): 77 | """Extract integer IDs from tokens like <|s_123|>.""" 78 | speech_ids = [] 79 | for token_str in speech_tokens_str: 80 | if token_str.startswith('<|s_') and token_str.endswith('|>'): 81 | try: 82 | num = int(token_str[4:-2]) 83 | speech_ids.append(num) 84 | except ValueError: 85 | print(f"Failed to convert token: {token_str}") 86 | else: 87 | print(f"Unexpected token: {token_str}") 88 | return speech_ids 89 | 90 | def generate_audio_data_url(audio_np, sample_rate=16000, format='WAV'): 91 | """Encode NumPy audio array into a base64 data URL for HTML audio tags.""" 92 | if audio_np.dtype != np.float32: 93 | audio_np = audio_np.astype(np.float32) 94 | if np.abs(audio_np).max() > 1.0: 95 | audio_np = audio_np / np.abs(audio_np).max() 96 | audio_int16 = (audio_np * 32767).astype(np.int16) 97 | with io.BytesIO() as buf: 98 | sf.write(buf, audio_int16, sample_rate, format=format, subtype='PCM_16') 99 | audio_data = base64.b64encode(buf.getvalue()).decode('utf-8') 100 | return f"data:audio/wav;base64,{audio_data}" 101 | 102 | def render_previous_generations(history_list, is_generating=False): 103 | """Render history entries as HTML.""" 104 | if not history_list and not is_generating: 105 | return "
No previous generations yet." 106 | html = """ 107 | 123 | """ 124 | # Show skeleton if is_generating 125 | if is_generating: 126 | html += """ 127 | <div class="history-card skeleton"> 128 | <div class="skeleton-line"></div> 129 | <div class="skeleton-line"></div> 130 | <div class="skeleton-line"></div> 131 | <div class="skeleton-line"></div> 132 | </div> 133 | """ 134 | if history_list: 135 | html += "<div class='history-list'>" 136 | for entry in reversed(history_list): 137 | card_html = f""" 138 | <div class="history-card"> 139 | <div>Mode: {entry['mode']}</div> 140 | <div>Text: {entry['text']}</div> 141 | <div>Params: max_len={entry['max_length']}, temp={entry['temperature']}, top_p={entry['top_p']}{', seed=' + str(entry.get('seed')) if entry.get('seed') is not None else ''}</div> 142 | <audio controls src="{entry['audio_url']}"></audio> 143 | 144 | 145 | </div> 146 | """ 147 | html += card_html 148 | html += "</div>
" 149 | return html 150 | 151 | 152 | ############################################################################### 153 | # Podcast Utility Functions # 154 | ############################################################################### 155 | 156 | def parse_conversation(transcript: str): 157 | """ 158 | Parse the transcript into a list of (speaker, message) tuples. 159 | Expected per line: "Speaker Name: message" 160 | """ 161 | lines = transcript.splitlines() 162 | conversation = [] 163 | speakers = set() 164 | for line in lines: 165 | if ':' not in line: 166 | continue 167 | speaker, text = line.split(":", 1) 168 | speaker = speaker.strip() 169 | text = text.strip() 170 | conversation.append((speaker, text)) 171 | speakers.add(speaker) 172 | return conversation, list(speakers) 173 | 174 | def join_audio_segments(segments, sample_rate=16000, crossfade_duration=0.05): 175 | """ 176 | Concatenate a list of 1D NumPy audio arrays with a brief crossfade. 177 | """ 178 | if not segments: 179 | return np.array([], dtype=np.float32) 180 | crossfade_samples = int(sample_rate * crossfade_duration) 181 | joined_audio = segments[0] 182 | for seg in segments[1:]: 183 | if crossfade_samples > 0 and len(joined_audio) >= crossfade_samples and len(seg) >= crossfade_samples: 184 | fade_out = np.linspace(1, 0, crossfade_samples) 185 | fade_in = np.linspace(0, 1, crossfade_samples) 186 | joined_audio[-crossfade_samples:] = joined_audio[-crossfade_samples:] * fade_out + seg[:crossfade_samples] * fade_in 187 | joined_audio = np.concatenate([joined_audio, seg[crossfade_samples:]]) 188 | else: 189 | joined_audio = np.concatenate([joined_audio, seg]) 190 | return joined_audio 191 | 192 | def build_transcript_html(conversation): 193 | """Build an HTML transcript with speaker labels.""" 194 | html = "" 195 | for speaker, text in conversation: 196 | html += f"

<div><b>{speaker}:</b> {text}</div>
\n" 197 | return html 198 | 199 | 200 | ############################################################################### 201 | # Core Inference # 202 | ############################################################################### 203 | 204 | def set_seed(seed): 205 | """Set seeds for reproducible generation.""" 206 | if seed is not None: 207 | torch.manual_seed(seed) 208 | torch.cuda.manual_seed_all(seed) 209 | np.random.seed(seed) 210 | random.seed(seed) 211 | torch.backends.cudnn.deterministic = True 212 | 213 | 214 | def infer( 215 | generation_mode, # "Text only" or "Reference audio" 216 | ref_audio_path, # path to ref audio (if any) 217 | target_text, # text to synthesize 218 | model_version, # "1B", "3B", or "8B" 219 | hf_api_key, # HF API key 220 | trim_audio, # trim ref audio to 15s? 221 | max_length, # generation param 222 | temperature, # generation param 223 | top_p, # generation param 224 | whisper_language, # whisper language 225 | user_seed, # user-provided seed 226 | random_seed_each_gen, # random seed if True 227 | beam_search_enabled, # beam search flag 228 | auto_optimize_length, # auto-optimize length 229 | prev_history, # prior generation history 230 | progress=gr.Progress() 231 | ): 232 | from .models import get_llasa_model 233 | 234 | # Handle seeds 235 | if random_seed_each_gen: 236 | chosen_seed = random.randint(0, 2**31 - 1) 237 | else: 238 | chosen_seed = user_seed 239 | set_seed(chosen_seed) 240 | 241 | # If there's an env var for HF token, fallback if no API key given 242 | if (not hf_api_key or not hf_api_key.strip()): 243 | env_key = os.environ.get(HF_KEY_ENV_VAR, "").strip() 244 | if env_key: 245 | hf_api_key = env_key 246 | 247 | # Acquire model and tokenizer 248 | tokenizer, model = get_llasa_model(model_version, hf_api_key=hf_api_key) 249 | 250 | if len(target_text) == 0: 251 | return None, render_previous_generations(prev_history), prev_history 252 | elif len(target_text) > 1000: 253 | gr.warning("Text is too long. Truncating to 1000 characters.") 254 | target_text = target_text[:1000] 255 | 256 | # Possibly auto-optimize max_length BEFORE we build final input 257 | # (We also do a check after building input_ids, below.) 258 | # -- We'll do the final check after the input is built to be safe. 
259 | 260 | from .inference import Codec_model, whisper_turbo_pipe 261 | 262 | # Handle reference audio if needed 263 | speech_ids_prefix = [] 264 | prompt_text = "" 265 | if generation_mode == "Reference audio" and ref_audio_path: 266 | progress(0, "Loading & trimming reference audio...") 267 | waveform, sample_rate = torchaudio.load(ref_audio_path) 268 | if trim_audio and (waveform.shape[1] / sample_rate) > 15: 269 | waveform = waveform[:, :sample_rate * 15] 270 | 271 | # Resample to 16k 272 | if waveform.size(0) > 1: 273 | waveform_mono = torch.mean(waveform, dim=0, keepdim=True) 274 | else: 275 | waveform_mono = waveform 276 | prompt_wav = torchaudio.transforms.Resample( 277 | orig_freq=sample_rate, new_freq=16000 278 | )(waveform_mono) 279 | 280 | # Transcribe with Whisper 281 | whisper_args = {} 282 | if whisper_language != "auto": 283 | whisper_args["language"] = whisper_language 284 | prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy(), generate_kwargs=whisper_args)['text'].strip() 285 | 286 | # Encode reference audio with XCodec2 287 | with torch.no_grad(): 288 | vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav) 289 | vq_code_prompt = vq_code_prompt[0, 0, :] # shape: [T] 290 | speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt) 291 | elif generation_mode == "Reference audio" and not ref_audio_path: 292 | gr.warning("No reference audio provided. Proceeding in text-only mode.") 293 | 294 | progress(0.5, "Generating speech...") 295 | 296 | # Combine any reference text + user text 297 | combined_input_text = prompt_text + " " + target_text 298 | prefix_str = "".join(speech_ids_prefix) if speech_ids_prefix else "" 299 | formatted_text = f"<|TEXT_UNDERSTANDING_START|>{combined_input_text}<|TEXT_UNDERSTANDING_END|>" 300 | chat = [ 301 | {"role": "user", "content": "Convert the text to speech:" + formatted_text}, 302 | {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + prefix_str}, 303 | ] 304 | num_beams = 2 if beam_search_enabled else 1 305 | early_stopping_val = (num_beams > 1) 306 | 307 | model_inputs = tokenizer.apply_chat_template( 308 | chat, 309 | tokenize=True, 310 | return_tensors="pt", 311 | continue_final_message=True 312 | ) 313 | input_ids = model_inputs.to("cuda") 314 | attention_mask = torch.ones_like(input_ids).to("cuda") 315 | 316 | # Final auto-optimize check 317 | if auto_optimize_length: 318 | input_len = input_ids.shape[1] 319 | margin = 100 if generation_mode == "Reference audio" else 50 320 | if input_len + margin > max_length: 321 | old_val = max_length 322 | max_length = input_len + margin 323 | print(f"Auto optimizing: input length is {input_len}, raising max_length from {old_val} to {max_length}.") 324 | 325 | # Generate tokens 326 | speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>") 327 | with torch.no_grad(): 328 | outputs = model.generate( 329 | input_ids, 330 | attention_mask=attention_mask, 331 | pad_token_id=(tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id), 332 | max_length=int(max_length), 333 | min_length=int(max_length * 0.5), 334 | eos_token_id=speech_end_id, 335 | do_sample=True, 336 | num_beams=num_beams, 337 | length_penalty=1.5, 338 | temperature=float(temperature), 339 | top_p=float(top_p), 340 | repetition_penalty=1.2, 341 | early_stopping=early_stopping_val, 342 | no_repeat_ngram_size=3, 343 | ) 344 | 345 | prefix_len = len(speech_ids_prefix) 346 | # cutting off prefix from the final output 347 | generated_ids = 
outputs[0][(input_ids.shape[1] - prefix_len) : -1] 348 | speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) 349 | speech_tokens = extract_speech_ids(speech_tokens) 350 | speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0) 351 | gen_wav = Codec_model.decode_code(speech_tokens) 352 | 353 | # If we had a reference prompt, remove that segment from the final 354 | if speech_ids_prefix: 355 | gen_wav = gen_wav[:, :, prompt_wav.shape[1] :] 356 | 357 | sr = 16000 358 | out_audio_np = gen_wav[0, 0, :].cpu().numpy() 359 | progress(0.9, "Finalizing audio...") 360 | 361 | audio_data_url = generate_audio_data_url(out_audio_np, sample_rate=sr) 362 | new_entry = { 363 | "mode": generation_mode, 364 | "text": target_text, 365 | "audio_url": audio_data_url, 366 | "temperature": temperature, 367 | "top_p": top_p, 368 | "max_length": max_length, 369 | "seed": chosen_seed, 370 | } 371 | 372 | if len(prev_history) >= MAX_HISTORY: 373 | prev_history.pop(0) 374 | prev_history.append(new_entry) 375 | updated_dashboard_html = render_previous_generations(prev_history, is_generating=False) 376 | 377 | return (sr, out_audio_np), updated_dashboard_html, prev_history 378 | 379 | 380 | def infer_podcast( 381 | conversation_text, 382 | generation_mode, # "Podcast" 383 | model_choice, 384 | hf_api_key, 385 | trim_audio, 386 | max_length, 387 | temperature, 388 | top_p, 389 | whisper_language, 390 | user_seed, 391 | random_seed_each_gen, 392 | beam_search_enabled, 393 | auto_optimize_length, 394 | prev_history, 395 | speaker_config=None, 396 | progress=gr.Progress() 397 | ): 398 | """ 399 | Generate podcast audio line by line, taking speaker-specific configurations. 400 | """ 401 | if speaker_config is None: 402 | speaker_config = {} 403 | 404 | from .inference import parse_conversation, generate_audio_data_url, render_previous_generations 405 | from .inference import join_audio_segments, infer 406 | 407 | conversation, speakers = parse_conversation(conversation_text) 408 | audio_segments = [] 409 | 410 | for speaker, line_text in conversation: 411 | # Retrieve speaker-specific config 412 | config = speaker_config.get(speaker.lower(), {"ref_audio": "", "seed": None}) 413 | ref_audio = config.get("ref_audio", "") 414 | seed = config.get("seed", None) 415 | 416 | # Decide generation mode 417 | line_mode = "Reference audio" if ref_audio else "Text only" 418 | result = infer( 419 | line_mode, 420 | ref_audio, 421 | line_text, 422 | model_choice, 423 | hf_api_key, 424 | trim_audio, 425 | max_length, 426 | temperature, 427 | top_p, 428 | whisper_language, 429 | seed, 430 | random_seed_each_gen, 431 | beam_search_enabled, 432 | auto_optimize_length, 433 | prev_history=[], 434 | progress=progress 435 | ) 436 | _, line_audio = result[0] 437 | audio_segments.append(line_audio) 438 | 439 | final_audio = join_audio_segments(audio_segments, sample_rate=16000, crossfade_duration=0.05) 440 | audio_url = generate_audio_data_url(final_audio, sample_rate=16000) 441 | 442 | new_entry = { 443 | "mode": "Podcast", 444 | "text": conversation_text, 445 | "audio_url": audio_url, 446 | "temperature": temperature, 447 | "top_p": top_p, 448 | "max_length": max_length, 449 | "seed": "N/A", 450 | } 451 | if len(prev_history) >= MAX_HISTORY: 452 | prev_history.pop(0) 453 | prev_history.append(new_entry) 454 | updated_dashboard_html = render_previous_generations(prev_history, is_generating=False) 455 | 456 | return (16000, final_audio), updated_dashboard_html, prev_history 457 | 
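# Usage sketch (illustrative only, not wired into the app): the pure helpers above can
# be exercised without a GPU or any model downloads. The sample values below are made up.
#
#     conversation, speakers = parse_conversation(
#         "Alex: Hello there.\nJamie: Hi Alex, good to be here."
#     )
#     # conversation == [("Alex", "Hello there."), ("Jamie", "Hi Alex, good to be here.")]
#
#     tokens = ids_to_speech_tokens([12, 345])          # ["<|s_12|>", "<|s_345|>"]
#     assert extract_speech_ids(tokens) == [12, 345]    # round-trips back to ints
#
#     a = np.ones(16000, dtype=np.float32)              # 1 s of audio at 16 kHz
#     b = np.zeros(16000, dtype=np.float32)
#     joined = join_audio_segments([a, b], sample_rate=16000, crossfade_duration=0.05)
#     # the 50 ms crossfade overlaps 800 samples, so len(joined) == 31200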
-------------------------------------------------------------------------------- /src/app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torch 3 | import os 4 | 5 | from .inference import ( 6 | initialize_models, 7 | infer, 8 | infer_podcast, 9 | render_previous_generations 10 | ) 11 | from .models import check_model_in_cache 12 | 13 | # Inline CSS for the dark theme (unchanged) 14 | NEW_CSS = """ 15 | /* Remove Gradio branding/footer */ 16 | #footer, .gradio-container a[target="_blank"] { display: none; } 17 | /* Simple dark background */ 18 | body, .gradio-container { margin: 0; padding: 0; background-color: #1E1E2A; color: #EAEAEA; font-family: 'Segoe UI', sans-serif; } 19 | /* Header styling */ 20 | #header { background-color: #2E2F46; padding: 1rem 2rem; text-align: center; } 21 | #header h1 { margin: 0; font-size: 2rem; } 22 | /* Main content row styling */ 23 | #content-row { display: flex; flex-direction: row; gap: 1rem; padding: 1rem 2rem; } 24 | /* Synthesis panel */ 25 | #synthesis-panel { flex: 2; background-color: #222233; border-radius: 8px; padding: 1.5rem; } 26 | /* History panel */ 27 | #history-panel { flex: 1; background-color: #222233; border-radius: 8px; padding: 1.5rem; } 28 | /* Form elements styling */ 29 | .gr-textbox input, .gr-textbox textarea, .gr-dropdown select { background-color: #38395A; border: 1px solid #4A4B6F; color: #F1F1F1; border-radius: 4px; padding: 0.5rem; } 30 | /* Audio components */ 31 | .audio-input, .audio-output { background-color: #2E2F46 !important; border-radius: 8px !important; padding: 12px !important; margin: 8px 0 !important; } 32 | """ 33 | 34 | def build_dashboard(): 35 | """ 36 | Build the Gradio interface with separate tabs for Standard TTS and Podcast Mode. 37 | Adds a dynamic check for whether the chosen model is in local cache. 38 | If not, the HF API Key field is shown. If in cache, it remains hidden. 
39 | """ 40 | theme = gr.themes.Default( 41 | primary_hue="blue", 42 | secondary_hue="slate", 43 | neutral_hue="slate", 44 | font=[gr.themes.GoogleFont("Inter")], 45 | font_mono=gr.themes.GoogleFont("IBM Plex Mono"), 46 | ).set( 47 | background_fill_primary="#1E1E2A", 48 | background_fill_secondary="#222233", 49 | border_color_primary="#4A4B6F", 50 | body_text_color="#EAEAEA", 51 | block_title_text_color="#EAEAEA", 52 | block_label_text_color="#EAEAEA", 53 | input_background_fill="#38395A", 54 | ) 55 | 56 | with gr.Blocks(theme=theme, css=NEW_CSS) as demo: 57 | gr.Markdown("", elem_id="header") 58 | # Shared state for previous generations 59 | prev_history_state = gr.State([]) 60 | 61 | ######################################## 62 | # DYNAMIC VISIBILITY: HF API Key Input 63 | ######################################## 64 | def toggle_api_key_visibility(model_choice): 65 | in_cache = check_model_in_cache(model_choice) 66 | # If the model is in local cache, hide the text field 67 | return gr.update(visible=not in_cache) 68 | 69 | with gr.Tabs(): 70 | # --- Standard TTS Tab --- 71 | with gr.TabItem("Standard TTS"): 72 | with gr.Row(elem_id="content-row"): 73 | with gr.Column(elem_id="synthesis-panel"): 74 | gr.Markdown("## Standard TTS") 75 | model_choice_std = gr.Dropdown( 76 | label="Select llasa Model", 77 | choices=["1B", "3B", "8B"], 78 | value="3B" 79 | ) 80 | generation_mode_std = gr.Radio( 81 | label="Generation Mode", 82 | choices=["Text only", "Reference audio"], 83 | value="Text only", 84 | type="value" 85 | ) 86 | with gr.Group(): 87 | ref_audio_input = gr.Audio( 88 | label="Reference Audio (Optional)", 89 | sources=["upload", "microphone"], 90 | type="filepath" 91 | ) 92 | trim_audio_checkbox_std = gr.Checkbox( 93 | label="Trim Reference Audio to 15s?", 94 | value=False 95 | ) 96 | gen_text_input = gr.Textbox( 97 | label="Text to Generate", 98 | lines=4, 99 | placeholder="Enter text here..." 
100 | ) 101 | 102 | with gr.Accordion("Advanced Generation Settings", open=False): 103 | max_length_slider_std = gr.Slider( 104 | minimum=64, 105 | maximum=4096, 106 | value=1024, 107 | step=64, 108 | label="Max Length (tokens)" 109 | ) 110 | temperature_slider_std = gr.Slider( 111 | minimum=0.1, 112 | maximum=2.0, 113 | value=1.0, 114 | step=0.1, 115 | label="Temperature" 116 | ) 117 | top_p_slider_std = gr.Slider( 118 | minimum=0.1, 119 | maximum=1.0, 120 | value=1.0, 121 | step=0.05, 122 | label="Top-p" 123 | ) 124 | whisper_language_std = gr.Dropdown( 125 | label="Whisper Language (for reference audio)", 126 | choices=["en", "auto", "ja", "zh", "de", "es", "ru", "ko", "fr"], 127 | value="en", 128 | type="value" 129 | ) 130 | random_seed_checkbox_std = gr.Checkbox( 131 | label="Random seed each generation", 132 | value=True 133 | ) 134 | beam_search_checkbox_std = gr.Checkbox( 135 | label="Enable beam search", 136 | value=False 137 | ) 138 | auto_optimize_checkbox_std = gr.Checkbox( 139 | label="[Text Only] Auto Optimize Length", 140 | value=True 141 | ) 142 | seed_number_std = gr.Number( 143 | label="Seed (if not random)", 144 | value=None, 145 | precision=0, 146 | minimum=0, 147 | maximum=2**32-1, 148 | step=1 149 | ) 150 | api_key_input_std = gr.Textbox( 151 | label="Hugging Face API Key (Required only if model not in cache)", 152 | type="password", 153 | placeholder="Enter your HF token or leave blank", 154 | visible=False 155 | ) 156 | 157 | synthesize_btn_std = gr.Button("Synthesize") 158 | 159 | with gr.Group(): 160 | audio_output_std = gr.Audio( 161 | label="Synthesized Audio", 162 | type="numpy", 163 | interactive=False, 164 | show_label=True, 165 | autoplay=False 166 | ) 167 | 168 | with gr.Column(elem_id="history-panel"): 169 | gr.Markdown("## Previous Generations") 170 | dashboard_html_std = gr.HTML( 171 | value="
No previous generations yet.
", 172 | show_label=False 173 | ) 174 | 175 | # --- Podcast Mode Tab --- 176 | with gr.TabItem("Podcast Mode"): 177 | with gr.Row(elem_id="content-row"): 178 | with gr.Column(elem_id="synthesis-panel"): 179 | gr.Markdown("## Podcast Mode") 180 | gr.Markdown("⚠️ **Experimental Feature** ⚠️\nWorks best with reference audio for each speaker.") 181 | 182 | model_choice_pod = gr.Dropdown( 183 | label="Select llasa Model", 184 | choices=["1B", "3B", "8B"], 185 | value="3B" 186 | ) 187 | podcast_transcript = gr.Textbox( 188 | label="Podcast Transcript", 189 | lines=6, 190 | placeholder="Each line -> 'Speaker Name: message'" 191 | ) 192 | with gr.Accordion("Speaker Configuration (Add as many as needed)", open=True): 193 | gr.Markdown("Fill out details for each speaker present in the transcript.") 194 | speaker1_name = gr.Textbox( 195 | label="Speaker 1 Name", 196 | placeholder="e.g., Alex" 197 | ) 198 | ref_audio_speaker1 = gr.Audio( 199 | label="Reference Audio for Speaker 1 (Optional)", 200 | sources=["upload", "microphone"], 201 | type="filepath" 202 | ) 203 | seed_speaker1 = gr.Number( 204 | label="Seed for Speaker 1 (Optional)", 205 | value=None, 206 | precision=0 207 | ) 208 | 209 | speaker2_name = gr.Textbox( 210 | label="Speaker 2 Name", 211 | placeholder="e.g., Jamie" 212 | ) 213 | ref_audio_speaker2 = gr.Audio( 214 | label="Reference Audio for Speaker 2 (Optional)", 215 | sources=["upload", "microphone"], 216 | type="filepath" 217 | ) 218 | seed_speaker2 = gr.Number( 219 | label="Seed for Speaker 2 (Optional)", 220 | value=None, 221 | precision=0 222 | ) 223 | 224 | speaker3_name = gr.Textbox( 225 | label="Speaker 3 Name (Optional)", 226 | placeholder="e.g., Casey" 227 | ) 228 | ref_audio_speaker3 = gr.Audio( 229 | label="Reference Audio for Speaker 3 (Optional)", 230 | sources=["upload", "microphone"], 231 | type="filepath" 232 | ) 233 | seed_speaker3 = gr.Number( 234 | label="Seed for Speaker 3 (Optional)", 235 | value=None, 236 | precision=0 237 | ) 238 | 239 | with gr.Accordion("Advanced Generation Settings", open=False): 240 | max_length_slider_pod = gr.Slider( 241 | minimum=64, 242 | maximum=4096, 243 | value=1024, 244 | step=64, 245 | label="Max Length (tokens)" 246 | ) 247 | temperature_slider_pod = gr.Slider( 248 | minimum=0.1, 249 | maximum=2.0, 250 | value=1.0, 251 | step=0.1, 252 | label="Temperature" 253 | ) 254 | top_p_slider_pod = gr.Slider( 255 | minimum=0.1, 256 | maximum=1.0, 257 | value=1.0, 258 | step=0.05, 259 | label="Top-p" 260 | ) 261 | whisper_language_pod = gr.Dropdown( 262 | label="Whisper Language (for reference audio)", 263 | choices=["en", "auto", "ja", "zh", "de", "es", "ru", "ko", "fr"], 264 | value="en", 265 | type="value" 266 | ) 267 | random_seed_checkbox_pod = gr.Checkbox( 268 | label="Random seed each generation", 269 | value=True 270 | ) 271 | beam_search_checkbox_pod = gr.Checkbox( 272 | label="Enable beam search", 273 | value=False 274 | ) 275 | auto_optimize_checkbox_pod = gr.Checkbox( 276 | label="[Text Only] Auto Optimize Length", 277 | value=True 278 | ) 279 | seed_number_pod = gr.Number( 280 | label="Seed (if not random)", 281 | value=None, 282 | precision=0, 283 | minimum=0, 284 | maximum=2**32-1, 285 | step=1 286 | ) 287 | 288 | api_key_input_pod = gr.Textbox( 289 | label="Hugging Face API Key (Required only if model not in cache)", 290 | type="password", 291 | placeholder="Enter your HF token or leave blank", 292 | visible=False 293 | ) 294 | 295 | synthesize_btn_pod = gr.Button("Synthesize Podcast") 296 | 297 | with gr.Group(): 298 | 
audio_output_pod = gr.Audio( 299 | label="Synthesized Podcast Audio", 300 | type="numpy", 301 | interactive=False, 302 | show_label=True, 303 | autoplay=False 304 | ) 305 | 306 | with gr.Column(elem_id="history-panel"): 307 | gr.Markdown("## Previous Generations") 308 | dashboard_html_pod = gr.HTML( 309 | value="
No previous generations yet.
", 310 | show_label=False 311 | ) 312 | 313 | # Define helper callback for Standard TTS 314 | def synthesize_standard( 315 | generation_mode, ref_audio_input, gen_text_input, model_choice, api_key_input, 316 | max_length_slider, temperature_slider, top_p_slider, whisper_language, 317 | seed_number, random_seed_checkbox, beam_search_checkbox, auto_optimize_checkbox, 318 | trim_audio, prev_history 319 | ): 320 | return infer( 321 | generation_mode, 322 | ref_audio_input, 323 | gen_text_input, 324 | model_choice, 325 | api_key_input, 326 | trim_audio, 327 | max_length_slider, 328 | temperature_slider, 329 | top_p_slider, 330 | whisper_language, 331 | seed_number, 332 | random_seed_checkbox, 333 | beam_search_checkbox, 334 | auto_optimize_checkbox, 335 | prev_history 336 | ) 337 | 338 | # Define helper callback for Podcast 339 | def synthesize_podcast_fn( 340 | podcast_transcript, model_choice, api_key_input, 341 | max_length_slider, temperature_slider, top_p_slider, whisper_language, 342 | seed_number, random_seed_checkbox, beam_search_checkbox, auto_optimize_checkbox, 343 | prev_history, 344 | speaker1_name, ref_audio_speaker1, seed_speaker1, 345 | speaker2_name, ref_audio_speaker2, seed_speaker2, 346 | speaker3_name, ref_audio_speaker3, seed_speaker3 347 | ): 348 | # Build speaker_config dictionary 349 | speaker_config = {} 350 | for name, ref, seed in [ 351 | (speaker1_name, ref_audio_speaker1, seed_speaker1), 352 | (speaker2_name, ref_audio_speaker2, seed_speaker2), 353 | (speaker3_name, ref_audio_speaker3, seed_speaker3), 354 | ]: 355 | if name and name.strip(): 356 | speaker_config[name.strip().lower()] = { 357 | "ref_audio": ref if ref else "", 358 | "seed": seed 359 | } 360 | 361 | return infer_podcast( 362 | podcast_transcript, 363 | "Podcast", 364 | model_choice, 365 | api_key_input, 366 | False, # trim_audio 367 | max_length_slider, 368 | temperature_slider, 369 | top_p_slider, 370 | whisper_language, 371 | seed_number, 372 | random_seed_checkbox, 373 | beam_search_checkbox, 374 | auto_optimize_checkbox, 375 | prev_history, 376 | speaker_config=speaker_config 377 | ) 378 | 379 | # --- Wire up Standard TTS Tab --- 380 | synthesize_btn_std.click( 381 | lambda history: render_previous_generations(history, is_generating=True), 382 | inputs=[prev_history_state], 383 | outputs=[dashboard_html_std] 384 | ).then( 385 | synthesize_standard, 386 | inputs=[ 387 | generation_mode_std, 388 | ref_audio_input, 389 | gen_text_input, 390 | model_choice_std, 391 | api_key_input_std, 392 | max_length_slider_std, 393 | temperature_slider_std, 394 | top_p_slider_std, 395 | whisper_language_std, 396 | seed_number_std, 397 | random_seed_checkbox_std, 398 | beam_search_checkbox_std, 399 | auto_optimize_checkbox_std, 400 | trim_audio_checkbox_std, 401 | prev_history_state 402 | ], 403 | outputs=[audio_output_std, dashboard_html_std, prev_history_state] 404 | ) 405 | 406 | # --- Wire up Podcast Mode Tab --- 407 | synthesize_btn_pod.click( 408 | lambda history: render_previous_generations(history, is_generating=True), 409 | inputs=[prev_history_state], 410 | outputs=[dashboard_html_pod] 411 | ).then( 412 | synthesize_podcast_fn, 413 | inputs=[ 414 | podcast_transcript, 415 | model_choice_pod, 416 | api_key_input_pod, 417 | max_length_slider_pod, 418 | temperature_slider_pod, 419 | top_p_slider_pod, 420 | whisper_language_pod, 421 | seed_number_pod, 422 | random_seed_checkbox_pod, 423 | beam_search_checkbox_pod, 424 | auto_optimize_checkbox_pod, 425 | prev_history_state, 426 | speaker1_name, 
ref_audio_speaker1, seed_speaker1, 427 | speaker2_name, ref_audio_speaker2, seed_speaker2, 428 | speaker3_name, ref_audio_speaker3, seed_speaker3 429 | ], 430 | outputs=[audio_output_pod, dashboard_html_pod, prev_history_state] 431 | ) 432 | 433 | # Show/hide API key input if model not cached 434 | model_choice_std.change( 435 | toggle_api_key_visibility, 436 | inputs=[model_choice_std], 437 | outputs=[api_key_input_std] 438 | ) 439 | model_choice_pod.change( 440 | toggle_api_key_visibility, 441 | inputs=[model_choice_pod], 442 | outputs=[api_key_input_pod] 443 | ) 444 | 445 | # On load, also run the toggle once to set correct visibility 446 | demo.load( 447 | fn=toggle_api_key_visibility, 448 | inputs=[model_choice_std], 449 | outputs=[api_key_input_std] 450 | ) 451 | demo.load( 452 | fn=toggle_api_key_visibility, 453 | inputs=[model_choice_pod], 454 | outputs=[api_key_input_pod] 455 | ) 456 | 457 | return demo 458 | --------------------------------------------------------------------------------
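As a quick reference, the pieces above can also be driven without `src/main.py` (for example from a notebook or a custom launcher). A minimal sketch, assuming the project root is the working directory, both requirements files are installed, and a CUDA GPU is available:

```python
# Minimal launcher sketch using only functions defined in this repo.
from src.inference import initialize_models   # loads XCodec2 + Whisper
from src.models import get_llasa_model        # loads/caches a Llasa checkpoint
from src.app import build_dashboard           # builds the Gradio Blocks app

initialize_models()        # one-time setup of the codec and ASR pipelines
get_llasa_model("3B")      # optional preload; "1B" and "8B" are also accepted
build_dashboard().launch(server_name="0.0.0.0", server_port=7860)
```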