├── templates ├── index.html ├── crud.html ├── setup.html └── chat.html ├── static ├── test.mp3 ├── crud.js ├── app.js └── chat.js ├── finetuned_model └── config.json ├── .gitignore ├── requirements.txt ├── .github └── FUNDING.yml ├── loadandmergecheckpoint.py ├── test.py ├── config.py ├── run_csm.py ├── setup.py ├── llm_interface.py ├── models.py ├── vad.py ├── README.md ├── LICENSE ├── rag_system.py └── generator.py /templates/index.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /static/test.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidbrowne17/csm-streaming/HEAD/static/test.mp3 -------------------------------------------------------------------------------- /finetuned_model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "audio_num_codebooks": 32, 3 | "audio_vocab_size": 2051, 4 | "backbone_flavor": "llama-1B", 5 | "decoder_flavor": "llama-100M", 6 | "text_vocab_size": 128256 7 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual Environment 24 | .env 25 | .venv 26 | env/ 27 | venv/ 28 | ENV/ 29 | 30 | # IDE 31 | .idea/ 32 | .vscode/ 33 | *.swp 34 | *.swo 35 | 36 | # Project specific 37 | .python-version 38 | *.wav 39 | output_*/ 40 | basic_audio.wav 41 | full_conversation.wav 42 | context_audio.wav 43 | 44 | # Model files 45 | *.pt 46 | *.ckpt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url=https://download.pytorch.org/whl/cu128 2 | vllm==0.8.0 3 | torch==2.6.0 4 | torchaudio==2.6.0 5 | torchvision==0.21.0 6 | tokenizers==0.21.0 7 | transformers==4.49.0 8 | huggingface_hub==0.28.1 9 | moshi==0.2.2 10 | sounddevice 11 | torchtune==0.4.0 12 | torchao==0.9.0 13 | bitsandbytes 14 | peft 15 | wandb 16 | silero_vad 17 | python-multipart>=0.0.6 18 | aiofiles>=23.1.0 19 | sentence-transformers>=2.2.2 20 | ctransformers>=0.2.24 21 | python-multipart>=0.0.6 22 | sqlalchemy>=2.0.0 23 | pydantic>=2.0.0 24 | fastapi>=0.95.0 25 | uvicorn>=0.22.0 26 | websockets>=11.0.3 27 | jinja2>=3.0.0 28 | speechbrain>=0.5.15 29 | matplotlib 30 | whisper-openai 31 | silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master 32 | numpy==1.26.0 -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: davidbrowne17 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., 
cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username 14 | thanks_dev: # Replace with a single thanks.dev username 15 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 16 | -------------------------------------------------------------------------------- /templates/crud.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Conversation Manager 6 | 7 | 8 | 9 | 10 |
11 | 12 | Return to Setup 13 | 14 |

Memory Manager

15 |
16 | 22 | 23 |
24 | 25 |
26 |
27 | 28 | 29 | -------------------------------------------------------------------------------- /loadandmergecheckpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | from models import Model 5 | from safetensors.torch import save_file, load_file 6 | 7 | from lora import ( 8 | remove_lora_modules, 9 | merge_lora_weights, 10 | strip_bias_keys, 11 | DEVICE, 12 | OUTPUT_DIR, 13 | replace_linear_with_lora, 14 | ) 15 | MODEL_NAME = "sesame/csm-1b" 16 | R=32 17 | ALPHA=32 18 | 19 | def find_latest_checkpoint(dir_path): 20 | checkpoints = [ 21 | (int(re.search(r"checkpoint-epoch-(\d+)", d).group(1)), os.path.join(dir_path, d)) 22 | for d in os.listdir(dir_path) 23 | if os.path.isdir(os.path.join(dir_path, d)) and "checkpoint-epoch" in d 24 | ] 25 | if not checkpoints: 26 | raise FileNotFoundError("No checkpoints found.") 27 | latest_epoch, latest_path = max(checkpoints, key=lambda x: x[0]) 28 | print(f"Latest checkpoint: epoch {latest_epoch} -> {latest_path}") 29 | return latest_path 30 | 31 | def load_checkpoint_and_merge(): 32 | print("Loading base model...") 33 | model = Model.from_pretrained(MODEL_NAME).to(DEVICE) 34 | 35 | print("Applying LoRA structure to the model...") 36 | target_layers = ['q_proj', 'k_proj', 'v_proj', 'o_proj'] 37 | 38 | model = replace_linear_with_lora(model, r=R, alpha=ALPHA, dropout=0.0, target_linear_names=target_layers) 39 | checkpoint_path = find_latest_checkpoint(OUTPUT_DIR) 40 | 41 | print(f"Loading state dictionary from safetensors file...") 42 | state_dict = load_file(os.path.join(checkpoint_path, "model.safetensors"), device=DEVICE) 43 | 44 | print("Loading weights into the model...") 45 | model.load_state_dict(state_dict, strict=False) 46 | 47 | print("Merging LoRA weights into base model...") 48 | merge_lora_weights(model) 49 | 50 | print("Replacing LoRALinear modules with standard nn.Linear...") 51 | model = remove_lora_modules(model) 52 | 53 | print("Stripping bias keys for final clean model...") 54 | merged_state = strip_bias_keys(model.state_dict()) 55 | 56 | final_path = os.path.join(OUTPUT_DIR, "model.safetensors") 57 | save_file(merged_state, final_path) 58 | print(f"Merged and cleaned model saved to: {final_path}") 59 | 60 | if __name__ == "__main__": 61 | load_checkpoint_and_merge() 62 | -------------------------------------------------------------------------------- /static/crud.js: -------------------------------------------------------------------------------- 1 | let allConversations = []; 2 | 3 | document.addEventListener('DOMContentLoaded', async () => { 4 | await loadConversations(); 5 | 6 | document.getElementById('searchInput').addEventListener('input', () => { 7 | const query = document.getElementById('searchInput').value.toLowerCase(); 8 | const filtered = allConversations.filter(c => 9 | c.user_message.toLowerCase().includes(query) || 10 | c.ai_message.toLowerCase().includes(query) 11 | ); 12 | renderConversations(filtered); 13 | }); 14 | 15 | document.getElementById('deleteAllBtn').addEventListener('click', async () => { 16 | if (!confirm("Are you sure you want to delete ALL conversations?")) return; 17 | await fetch('/api/conversations', { method: 'DELETE' }); 18 | await loadConversations(); 19 | }); 20 | }); 21 | 22 | async function loadConversations() { 23 | const res = await fetch('/api/conversations'); 24 | allConversations = await res.json(); 25 | renderConversations(allConversations); 26 | } 27 | 28 | function renderConversations(list) { 29 | 
const container = document.getElementById('conversationList'); 30 | container.innerHTML = ''; 31 | 32 | if (list.length === 0) { 33 | container.innerHTML = '

No conversations found.

'; 34 | return; 35 | } 36 | 37 | list.forEach(conv => { 38 | const div = document.createElement('div'); 39 | div.className = "bg-gray-800 p-4 rounded shadow"; 40 | div.innerHTML = ` 41 |
User:
42 | 43 |
AI:
44 | 45 | 46 | 47 | `; 48 | container.appendChild(div); 49 | 50 | div.querySelector('.saveBtn').addEventListener('click', async () => { 51 | const id = conv.id; 52 | const user = div.querySelector('textarea[data-field="user"]').value; 53 | const ai = div.querySelector('textarea[data-field="ai"]').value; 54 | await fetch(`/api/conversations/${id}`, { 55 | method: 'PUT', 56 | headers: { 'Content-Type': 'application/json' }, 57 | body: JSON.stringify({ user_message: user, ai_message: ai }) 58 | }); 59 | alert("Saved."); 60 | }); 61 | 62 | div.querySelector('.deleteBtn').addEventListener('click', async () => { 63 | const id = conv.id; 64 | if (!confirm("Delete this conversation?")) return; 65 | await fetch(`/api/conversations/${id}`, { method: 'DELETE' }); 66 | await loadConversations(); 67 | }); 68 | }); 69 | } 70 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import time 2 | from generator import Segment, load_csm_1b, generate_streaming_audio 3 | import torchaudio 4 | 5 | print(f"Starting script at: {time.strftime('%H:%M:%S')}") 6 | start_time = time.time() 7 | 8 | print("Downloading model...") 9 | model_start = time.time() 10 | print(f"Model download completed in {time.time() - model_start:.2f} seconds") 11 | 12 | print("Loading model to CUDA...") 13 | load_start = time.time() 14 | generator = load_csm_1b("cuda") 15 | print(f"Model loaded in {time.time() - load_start:.2f} seconds") 16 | 17 | speakers = [0, 1, 0, 0] 18 | transcripts = [ 19 | "Hey how are you doing.", 20 | "Pretty good, pretty good.", 21 | "I'm great.", 22 | "So happy to be speaking to you.", 23 | ] 24 | audio_paths = [ 25 | "utterance_0.wav", 26 | "utterance_1.wav", 27 | "utterance_2.wav", 28 | "utterance_3.wav", 29 | ] 30 | 31 | def load_audio(audio_path): 32 | print(f"Loading reference audio: {audio_path}") 33 | audio_load_start = time.time() 34 | audio_tensor, sample_rate = torchaudio.load(audio_path) 35 | audio_tensor = torchaudio.functional.resample( 36 | audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate 37 | ) 38 | print(f"Audio loaded and resampled in {time.time() - audio_load_start:.2f} seconds") 39 | return audio_tensor 40 | 41 | print("Creating segments with reference audio...") 42 | segments_start = time.time() 43 | segments = [ 44 | Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path)) 45 | for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths) 46 | ] 47 | print(f"Segments created in {time.time() - segments_start:.2f} seconds") 48 | 49 | # Option 1: Regular generation with streaming internally enabled 50 | print("Generating audio (with internal streaming)...") 51 | gen_start = time.time() 52 | audio = generator.generate( 53 | text="Me too, this is some cool stuff huh?", 54 | speaker=0, 55 | context=segments, 56 | max_audio_length_ms=10_000, 57 | stream=True # Enable internal streaming 58 | ) 59 | print(f"Audio generation completed in {time.time() - gen_start:.2f} seconds") 60 | 61 | print("Saving audio file...") 62 | save_start = time.time() 63 | torchaudio.save("audio_regular.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) 64 | print(f"Audio saved in {time.time() - save_start:.2f} seconds") 65 | 66 | # Option 2: Use the streaming helper function that saves as it goes 67 | print("Generating audio using streaming API...") 68 | generate_streaming_audio( 69 | generator=generator, 70 | text="Me too, this is some 
cool stuff huh?", 71 | speaker=0, 72 | context=segments, 73 | output_file="audio_streamed.wav", 74 | max_audio_length_ms=10_000, 75 | play_audio=True # Set to True to play audio in real-time (requires sounddevice package) 76 | ) 77 | 78 | total_time = time.time() - start_time 79 | print(f"Total execution time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)") 80 | print(f"Script completed at: {time.strftime('%H:%M:%S')}") -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | from typing import Dict, Any, Optional 5 | from pydantic import BaseModel 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | class ConfigManager: 10 | """ 11 | Manages configuration persistence for the AI Companion app. 12 | Saves and loads configuration to avoid re-entering model paths. 13 | """ 14 | def __init__(self, config_path: str = "config/app_config.json"): 15 | """ 16 | Initialize the configuration manager. 17 | 18 | Args: 19 | config_path: Path to store the configuration file 20 | """ 21 | self.config_path = config_path 22 | self.config_dir = os.path.dirname(config_path) 23 | 24 | # Create config directory if it doesn't exist 25 | if not os.path.exists(self.config_dir): 26 | os.makedirs(self.config_dir, exist_ok=True) 27 | logger.info(f"Created configuration directory: {self.config_dir}") 28 | 29 | def save_config(self, config_data: Dict[str, Any]) -> bool: 30 | """ 31 | Save configuration data to the config file. 32 | 33 | Args: 34 | config_data: Configuration data to save 35 | 36 | Returns: 37 | bool: True if successful, False otherwise 38 | """ 39 | try: 40 | # Ensure directory exists 41 | os.makedirs(self.config_dir, exist_ok=True) 42 | print(config_data) 43 | # Verify all reference paths are included 44 | ref_paths = [ 45 | "reference_audio_path", 46 | "reference_audio_path2", 47 | "reference_audio_path3" 48 | ] 49 | 50 | # Log which references are being saved 51 | for path_key in ref_paths: 52 | if path_key in config_data and config_data[path_key]: 53 | logger.info(f"Saving reference path: {path_key}={config_data[path_key]}") 54 | else: 55 | logger.info(f"No {path_key} provided in configuration") 56 | 57 | # Save configuration 58 | with open(self.config_path, 'w') as f: 59 | json.dump(config_data, f, indent=2) 60 | 61 | logger.info(f"Configuration saved to {self.config_path}") 62 | return True 63 | 64 | except Exception as e: 65 | logger.error(f"Failed to save configuration: {e}") 66 | return False 67 | 68 | def load_config(self) -> Optional[Dict[str, Any]]: 69 | """ 70 | Load configuration data from the config file. 
71 | 72 | Returns: 73 | Dict or None: Configuration data if successful, None otherwise 74 | """ 75 | if not os.path.exists(self.config_path): 76 | logger.info(f"Configuration file does not exist at {self.config_path}") 77 | return None 78 | 79 | try: 80 | with open(self.config_path, 'r') as f: 81 | config_data = json.load(f) 82 | 83 | # Log which references are being loaded 84 | ref_paths = [ 85 | "reference_audio_path", 86 | "reference_audio_path2", 87 | "reference_audio_path3" 88 | ] 89 | 90 | for path_key in ref_paths: 91 | if path_key in config_data and config_data[path_key]: 92 | logger.info(f"Loaded reference path: {path_key}={config_data[path_key]}") 93 | 94 | logger.info(f"Configuration loaded from {self.config_path}") 95 | return config_data 96 | 97 | except Exception as e: 98 | logger.error(f"Failed to load configuration: {e}") 99 | return None 100 | 101 | # Helper function to convert Pydantic model to dict 102 | def model_to_dict(model: BaseModel) -> Dict[str, Any]: 103 | """Convert a Pydantic model to a dictionary suitable for JSON serialization""" 104 | return json.loads(model.json()) -------------------------------------------------------------------------------- /run_csm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | from huggingface_hub import hf_hub_download 5 | from generator import load_csm_1b, Segment 6 | from dataclasses import dataclass 7 | 8 | 9 | # Default prompts are available at https://hf.co/sesame/csm-1b 10 | prompt_filepath_conversational_a = hf_hub_download( 11 | repo_id="sesame/csm-1b", 12 | filename="prompts/conversational_a.wav" 13 | ) 14 | prompt_filepath_conversational_b = hf_hub_download( 15 | repo_id="sesame/csm-1b", 16 | filename="prompts/conversational_b.wav" 17 | ) 18 | 19 | SPEAKER_PROMPTS = { 20 | "conversational_a": { 21 | "text": ( 22 | "like revising for an exam I'd have to try and like keep up the momentum because I'd " 23 | "start really early I'd be like okay I'm gonna start revising now and then like " 24 | "you're revising for ages and then I just like start losing steam I didn't do that " 25 | "for the exam we had recently to be fair that was a more of a last minute scenario " 26 | "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I " 27 | "sort of start the day with this not like a panic but like a" 28 | ), 29 | "audio": prompt_filepath_conversational_a 30 | }, 31 | "conversational_b": { 32 | "text": ( 33 | "like a super Mario level. Like it's very like high detail. And like, once you get " 34 | "into the park, it just like, everything looks like a computer game and they have all " 35 | "these, like, you know, if, if there's like a, you know, like in a Mario game, they " 36 | "will have like a question block. And if you like, you know, punch it, a coin will " 37 | "come out. So like everyone, when they come into the park, they get like this little " 38 | "bracelet and then you can go punching question blocks around." 
39 | ), 40 | "audio": prompt_filepath_conversational_b 41 | } 42 | } 43 | 44 | def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor: 45 | audio_tensor, sample_rate = torchaudio.load(audio_path) 46 | audio_tensor = audio_tensor.squeeze(0) 47 | # Resample is lazy so we can always call it 48 | audio_tensor = torchaudio.functional.resample( 49 | audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate 50 | ) 51 | return audio_tensor 52 | 53 | def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment: 54 | audio_tensor = load_prompt_audio(audio_path, sample_rate) 55 | return Segment(text=text, speaker=speaker, audio=audio_tensor) 56 | 57 | def main(): 58 | # Select the best available device, skipping MPS due to float64 limitations 59 | if torch.cuda.is_available(): 60 | device = "cuda" 61 | else: 62 | device = "cpu" 63 | print(f"Using device: {device}") 64 | 65 | # Load model 66 | generator = load_csm_1b(device) 67 | 68 | # Prepare prompts 69 | prompt_a = prepare_prompt( 70 | SPEAKER_PROMPTS["conversational_a"]["text"], 71 | 0, 72 | SPEAKER_PROMPTS["conversational_a"]["audio"], 73 | generator.sample_rate 74 | ) 75 | 76 | prompt_b = prepare_prompt( 77 | SPEAKER_PROMPTS["conversational_b"]["text"], 78 | 1, 79 | SPEAKER_PROMPTS["conversational_b"]["audio"], 80 | generator.sample_rate 81 | ) 82 | 83 | # Generate conversation 84 | conversation = [ 85 | {"text": "Hey how are you doing?", "speaker_id": 0}, 86 | {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1}, 87 | {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0}, 88 | {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1} 89 | ] 90 | 91 | # Generate each utterance 92 | generated_segments = [] 93 | prompt_segments = [prompt_a, prompt_b] 94 | 95 | for utterance in conversation: 96 | print(f"Generating: {utterance['text']}") 97 | audio_tensor = generator.generate( 98 | text=utterance['text'], 99 | speaker=utterance['speaker_id'], 100 | context=prompt_segments + generated_segments, 101 | max_audio_length_ms=10_000, 102 | ) 103 | generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)) 104 | 105 | # Concatenate all generations 106 | all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0) 107 | torchaudio.save( 108 | "full_conversation.wav", 109 | all_audio.unsqueeze(0).cpu(), 110 | generator.sample_rate 111 | ) 112 | print("Successfully generated full_conversation.wav") 113 | 114 | if __name__ == "__main__": 115 | main() -------------------------------------------------------------------------------- /templates/setup.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | AI Companion Setup 7 | 8 | 9 | 44 | 45 | 46 |
47 |
48 |

AI Companion Setup

49 | 50 | Open Conversation DB 51 | 52 |
53 | 56 | 59 | 62 | 63 | 64 |
65 |
66 |

Primary Reference Audio (Required)

67 |
68 | 71 | 74 |
75 | 76 | 77 |
78 |
79 |

Secondary Reference (Optional)

80 | For better voice quality 81 |
82 | 85 | 88 |
89 | 90 | 91 |
92 |
93 |

Tertiary Reference (Optional)

94 | For even better voice quality 95 |
96 | 99 | 102 |
103 | 104 | 107 |
108 | 109 | 110 |
111 | 114 |
115 | 116 | 117 |
118 |
119 | 120 | 124 |
125 |
126 |
127 |
128 | 129 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | import logging 5 | import urllib.request 6 | import torch 7 | import time 8 | import shutil 9 | from pathlib import Path 10 | 11 | # Configure logging 12 | logging.basicConfig(level=logging.INFO, 13 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 14 | logger = logging.getLogger(__name__) 15 | 16 | def check_requirements(): 17 | """Check if all required Python packages are installed""" 18 | logger.info("Checking requirements...") 19 | 20 | requirements = [ 21 | "torch", "torchaudio", "fastapi", "uvicorn", "websockets", "numpy", 22 | "scikit-learn", "sqlalchemy", "pydantic", "jinja2", "whisper", 23 | "sounddevice", "soundfile", "sentence_transformers", "ctransformers" 24 | ] 25 | 26 | missing = [] 27 | for req in requirements: 28 | try: 29 | __import__(req) 30 | except ImportError: 31 | missing.append(req) 32 | 33 | if missing: 34 | logger.warning(f"Missing required packages: {', '.join(missing)}") 35 | logger.info("Installing missing requirements...") 36 | subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]) 37 | logger.info("Requirements installed successfully") 38 | else: 39 | logger.info("All requirements are satisfied") 40 | 41 | def download_vad_model(): 42 | """Download the Silero VAD model using PyTorch Hub instead of direct URL""" 43 | model_path = "silero_vad.jit" 44 | 45 | if os.path.exists(model_path): 46 | logger.info(f"Silero VAD model already exists at {model_path}") 47 | return 48 | 49 | logger.info("Downloading Silero VAD model using PyTorch Hub...") 50 | try: 51 | # Use torch.hub to download the model instead of direct URL 52 | torch.hub.set_dir("./models") 53 | model, utils = torch.hub.load(repo_or_dir="snakers4/silero-vad", 54 | model="silero_vad", 55 | force_reload=True, 56 | onnx=False) 57 | 58 | # Save the model 59 | torch.jit.save(model, model_path) 60 | logger.info(f"Model downloaded and saved to {model_path}") 61 | 62 | except Exception as e: 63 | logger.error(f"Failed to download Silero VAD model using PyTorch Hub: {e}") 64 | logger.info("Falling back to energy-based VAD - the system will still work but with simpler voice detection") 65 | 66 | def download_embedding_models(): 67 | """Download the sentence transformer models for RAG""" 68 | logger.info("Setting up sentence transformer models...") 69 | 70 | try: 71 | from sentence_transformers import SentenceTransformer 72 | 73 | # Download lightweight model for embeddings 74 | logger.info("Downloading embedding models (this may take a few minutes)...") 75 | models = [ 76 | "all-MiniLM-L6-v2", # Fast 77 | "all-mpnet-base-v2", # Balanced 78 | "multi-qa-mpnet-base-dot-v1" # Best for Q&A 79 | ] 80 | 81 | for model_name in models: 82 | logger.info(f"Setting up model: {model_name}") 83 | _ = SentenceTransformer(model_name) 84 | logger.info(f"Model {model_name} is ready") 85 | 86 | except Exception as e: 87 | logger.error(f"Failed to download embedding models: {e}") 88 | logger.error("Please try running the script again or download models manually") 89 | 90 | def setup_directories(): 91 | """Create necessary directories for the application""" 92 | directories = ["static", "responses", "embeddings_cache", "templates"] 93 | 94 | for directory in directories: 95 | os.makedirs(directory, exist_ok=True) 96 | logger.info(f"Directory {directory} is 
ready") 97 | 98 | # Create template redirect file 99 | template_dir = Path("templates") 100 | index_html = template_dir / "index.html" 101 | 102 | with open(index_html, "w") as f: 103 | f.write(""" 104 | 105 | 106 | 107 | 108 | 109 | 110 |

Redirecting to AI Companion...

111 | 112 | 113 | """) 114 | logger.info("Created index template for redirection") 115 | 116 | def setup_database(): 117 | """Initialize the SQLite database""" 118 | logger.info("Setting up database...") 119 | 120 | try: 121 | from sqlalchemy import create_engine, Column, Integer, String, Text 122 | from sqlalchemy.ext.declarative import declarative_base 123 | from sqlalchemy.orm import sessionmaker 124 | 125 | Base = declarative_base() 126 | engine = create_engine("sqlite:///companion.db") 127 | 128 | class Conversation(Base): 129 | __tablename__ = "conversations" 130 | id = Column(Integer, primary_key=True, index=True) 131 | session_id = Column(String, index=True) 132 | timestamp = Column(String) 133 | user_message = Column(Text) 134 | ai_message = Column(Text) 135 | audio_path = Column(String) 136 | 137 | # Create tables 138 | Base.metadata.create_all(bind=engine) 139 | logger.info("Database initialized successfully") 140 | 141 | except Exception as e: 142 | logger.error(f"Failed to set up database: {e}") 143 | 144 | def check_cuda(): 145 | """Check if CUDA is available for PyTorch""" 146 | if torch.cuda.is_available(): 147 | device_name = torch.cuda.get_device_name(0) 148 | logger.info(f"CUDA is available: {device_name}") 149 | logger.info(f"CUDA version: {torch.version.cuda}") 150 | else: 151 | logger.warning("CUDA is not available. The application will run on CPU, which may be very slow") 152 | logger.warning("For optimal performance, a CUDA-capable GPU is recommended") 153 | 154 | def main(): 155 | """Main setup function""" 156 | logger.info("Starting AI Companion setup...") 157 | 158 | # Check for CUDA availability 159 | check_cuda() 160 | 161 | # Check and install requirements 162 | #check_requirements() 163 | 164 | # Create directories 165 | setup_directories() 166 | 167 | # Set up database 168 | setup_database() 169 | 170 | # Download models 171 | download_vad_model() 172 | download_embedding_models() 173 | 174 | logger.info("Setup completed successfully!") 175 | logger.info("You can now start the application with:") 176 | logger.info(" python main.py") 177 | 178 | if __name__ == "__main__": 179 | main() -------------------------------------------------------------------------------- /llm_interface.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Dict, Any, Optional 3 | import torch 4 | from vllm import LLM, SamplingParams 5 | 6 | class LLMInterface: 7 | def __init__(self, model_path: str, max_tokens: int = 8192, n_threads: int = 8, gpu_layers: int = -1): 8 | """Initialize the LLM interface using VLLM with a given model. 9 | 10 | Args: 11 | model_path (str): Path to the model or HuggingFace model name 12 | max_tokens (int, optional): Maximum context length. Defaults to 8192. 13 | n_threads (int, optional): Number of CPU threads. Defaults to 8. 14 | gpu_layers (int, optional): Not used in VLLM, maintained for API compatibility. 15 | """ 16 | # VLLM configuration 17 | self.llm = LLM( 18 | model=model_path, 19 | tensor_parallel_size=1, # Adjust based on number of GPUs available 20 | gpu_memory_utilization=0.6, 21 | max_model_len=max_tokens, 22 | swap_space=0, 23 | trust_remote_code=True, 24 | dtype=torch.float16, 25 | ) 26 | 27 | # Store configuration for reference 28 | self.config = { 29 | "model_path": model_path, 30 | "max_tokens": max_tokens, 31 | } 32 | 33 | def trim_to_last_sentence(self, text: str) -> str: 34 | """ 35 | Return *text* truncated at the final full sentence boundary. 
36 | A boundary is considered to be any '.', '!' or '?' followed by 37 | optional quotes/brackets, optional whitespace, and then end-of-string. 38 | 39 | If no sentence terminator exists, the original text is returned. 40 | """ 41 | # Regex explanation: 42 | # (.*?[.!?]["')\]]?) any text lazily until a terminator 43 | # \s*$ followed only by whitespace till end-of-string 44 | m = re.match(r"^(.*?[.!?][\"')\]]?)\s*$", text, re.DOTALL) 45 | if m: 46 | return m.group(1).strip() 47 | # Fall back to manual search (handles cases with additional text) 48 | for i in range(len(text) - 1, -1, -1): 49 | if text[i] in ".!?": 50 | return text[: i + 1].strip() 51 | return text.strip() 52 | 53 | def generate_response(self, system_prompt: str, user_message: str, conversation_history: str = "") -> str: 54 | """Generate a response from the LLM using chat-style prompt formatting. 55 | 56 | Args: 57 | system_prompt (str): The system prompt/instructions 58 | user_message (str): The user's input message 59 | conversation_history (str, optional): Any prior conversation context. Defaults to "". 60 | 61 | Returns: 62 | str: The generated response 63 | """ 64 | # Format prompt following chat template structure 65 | prompt = f"""<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|> 66 | {conversation_history} 67 | <|start_header_id|>user<|end_header_id|>\n{user_message}<|eot_id|> 68 | <|start_header_id|>assistant<|end_header_id|>\n""" 69 | 70 | # Define sampling parameters (equivalent to the previous implementation) 71 | sampling_params = SamplingParams( 72 | temperature=1.0, 73 | top_p=0.95, 74 | max_tokens=100, 75 | repetition_penalty=1.2, 76 | top_k=200, 77 | stop=["", "<|endoftext|>", "<>", "<>", "<>", 78 | "<>", "<>", "<|end_header_id|>", "<>", 79 | "<|eot_id|>", "<|im_end|>", "user:", "User:", "user :", "User :"] 80 | ) 81 | 82 | # Generate response using VLLM 83 | outputs = self.llm.generate(prompt, sampling_params) 84 | 85 | # Extract and return the generated text 86 | if outputs and len(outputs) > 0: 87 | text = outputs[0].outputs[0].text 88 | return self.trim_to_last_sentence(text) 89 | return "" 90 | 91 | def tokenize(self, text: str) -> List[int]: 92 | """Tokenize text using VLLM's tokenizer. 93 | 94 | Args: 95 | text (str): Text to tokenize 96 | 97 | Returns: 98 | List[int]: List of token IDs 99 | """ 100 | # VLLM doesn't expose tokenizer directly in the same way 101 | # We can access the tokenizer through the LLM instance 102 | tokenizer = self.llm.get_tokenizer() 103 | return tokenizer.encode(text) 104 | 105 | def get_token_count(self, text: str) -> int: 106 | """Return token count of the input text. 107 | 108 | Args: 109 | text (str): Text to count tokens for 110 | 111 | Returns: 112 | int: Number of tokens 113 | """ 114 | return len(self.tokenize(text)) 115 | 116 | def batch_generate(self, prompts: List[Dict[str, str]], 117 | max_tokens: int = 512, 118 | temperature: float = 0.7) -> List[str]: 119 | """Generate responses for multiple prompts in a batch. 
120 | Args: 121 | prompts (List[Dict[str, str]]): List of prompt dictionaries, each with 122 | 'system', 'user' and optional 'history' keys 123 | max_tokens (int, optional): Maximum tokens to generate per response 124 | temperature (float, optional): Temperature for sampling 125 | 126 | Returns: 127 | List[str]: Generated responses 128 | """ 129 | formatted_prompts = [] 130 | 131 | # Format each prompt according to the chat template 132 | for p in prompts: 133 | system = p.get("system", "") 134 | user = p.get("user", "") 135 | history = p.get("history", "") 136 | 137 | formatted_prompt = f"""<|start_header_id|>system<|end_header_id|>\n{system}<|eot_id|> 138 | {history} 139 | <|start_header_id|>user<|end_header_id|>\n{user}<|eot_id|> 140 | <|start_header_id|>assistant<|end_header_id|>\n""" 141 | 142 | formatted_prompts.append(formatted_prompt) 143 | 144 | # Set up batch sampling parameters 145 | sampling_params = SamplingParams( 146 | temperature=temperature, 147 | top_p=0.95, 148 | max_tokens=max_tokens, 149 | repetition_penalty=1.2, 150 | top_k=400, 151 | stop=["", "<|endoftext|>", "<>", "<>", "<>", 152 | "<>", "<>", "<|end_header_id|>", "<>", 153 | "<|eot_id|>", "<|im_end|>", "user:", "User:", "user :", "User :"] 154 | ) 155 | 156 | # Generate responses for all prompts in a batch 157 | outputs = self.llm.generate(formatted_prompts, sampling_params) 158 | 159 | # Extract and return the generated texts 160 | results = [] 161 | for output in outputs: 162 | if output.outputs: 163 | results.append(output.outputs[0].text.strip()) 164 | else: 165 | results.append("") 166 | 167 | return results -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchtune 7 | from huggingface_hub import PyTorchModelHubMixin 8 | from torchtune.models import llama3_2 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder: 13 | return llama3_2.llama3_2( 14 | vocab_size=128_256, 15 | num_layers=16, 16 | num_heads=32, 17 | num_kv_heads=8, 18 | embed_dim=2048, 19 | max_seq_len=2048, 20 | intermediate_dim=8192, 21 | attn_dropout=0.0, 22 | norm_eps=1e-5, 23 | rope_base=500_000, 24 | scale_factor=32, 25 | ) 26 | 27 | def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder: 28 | return llama3_2.llama3_2( 29 | vocab_size=128_256, 30 | num_layers=4, 31 | num_heads=8, 32 | num_kv_heads=2, 33 | embed_dim=1024, 34 | max_seq_len=2048, 35 | intermediate_dim=8192, 36 | attn_dropout=0.0, 37 | norm_eps=1e-5, 38 | rope_base=500_000, 39 | scale_factor=32, 40 | ) 41 | 42 | FLAVORS = { 43 | "llama-1B": llama3_2_1B, 44 | "llama-100M": llama3_2_100M, 45 | } 46 | 47 | def _prepare_transformer(model): 48 | embed_dim = model.tok_embeddings.embedding_dim 49 | model.tok_embeddings = nn.Identity() 50 | model.output = nn.Identity() 51 | return model, embed_dim 52 | 53 | def _create_causal_mask(seq_len: int, device: torch.device): 54 | return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)) 55 | 56 | def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor): 57 | """ 58 | Args: 59 | mask: (max_seq_len, max_seq_len) 60 | input_pos: (batch_size, seq_len) 61 | 62 | Returns: 63 | (batch_size, seq_len, max_seq_len) 64 | """ 65 | r = mask[input_pos, :] 66 | return r 67 | 68 | def 
_multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization 69 | q = torch.empty_like(probs).exponential_(1) 70 | return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) 71 | 72 | def sample_topk(logits: torch.Tensor, topk: int, temperature: float): 73 | logits = logits / temperature 74 | 75 | filter_value: float = -float("Inf") 76 | indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None] 77 | scores_processed = logits.masked_fill(indices_to_remove, filter_value) 78 | scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1) 79 | probs = torch.nn.functional.softmax(scores_processed, dim=-1) 80 | 81 | sample_token = _multinomial_sample_one_no_sync(probs) 82 | return sample_token 83 | 84 | @dataclass 85 | class ModelArgs: 86 | backbone_flavor: str 87 | decoder_flavor: str 88 | text_vocab_size: int 89 | audio_vocab_size: int 90 | audio_num_codebooks: int 91 | 92 | 93 | class Model( 94 | nn.Module, 95 | PyTorchModelHubMixin, 96 | repo_url="https://github.com/SesameAILabs/csm", 97 | pipeline_tag="text-to-speech", 98 | license="apache-2.0", 99 | ): 100 | def __init__(self, config: ModelArgs): 101 | super().__init__() 102 | self.config = config 103 | 104 | self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]()) 105 | self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]()) 106 | 107 | self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim) 108 | self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim) 109 | 110 | self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False) 111 | self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False) 112 | self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)) 113 | 114 | def setup_caches(self, max_batch_size: int) -> torch.Tensor: 115 | """Setup KV caches and return a causal mask.""" 116 | dtype = next(self.parameters()).dtype 117 | device = next(self.parameters()).device 118 | 119 | with device: 120 | self.backbone.setup_caches(max_batch_size, dtype) 121 | self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks) 122 | 123 | self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device)) 124 | self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device)) 125 | 126 | def generate_frame( 127 | self, 128 | tokens: torch.Tensor, 129 | tokens_mask: torch.Tensor, 130 | input_pos: torch.Tensor, 131 | temperature: float, 132 | topk: int, 133 | ) -> torch.Tensor: 134 | """ 135 | Args: 136 | tokens: (batch_size, seq_len, audio_num_codebooks+1) 137 | tokens_mask: (batch_size, seq_len, audio_num_codebooks+1) 138 | input_pos: (batch_size, seq_len) positions for each token 139 | mask: (batch_size, seq_len, max_seq_len 140 | 141 | Returns: 142 | (batch_size, audio_num_codebooks) sampled tokens 143 | """ 144 | dtype = next(self.parameters()).dtype 145 | b, s, _ = tokens.size() 146 | 147 | assert self.backbone.caches_are_enabled(), "backbone caches are not enabled" 148 | curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos) 149 | embeds = self._embed_tokens(tokens) 150 | masked_embeds = embeds * tokens_mask.unsqueeze(-1) 151 | h = masked_embeds.sum(dim=2) 152 | h = self.backbone(h, input_pos=input_pos, 
mask=curr_backbone_mask).to(dtype=dtype) 153 | 154 | last_h = h[:, -1, :] 155 | c0_logits = self.codebook0_head(last_h) 156 | c0_sample = sample_topk(c0_logits, topk, temperature) 157 | c0_embed = self._embed_audio(0, c0_sample) 158 | 159 | curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1) 160 | curr_sample = c0_sample.clone() 161 | curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1) 162 | 163 | # Decoder caches must be reset every frame. 164 | self.decoder.reset_caches() 165 | for i in range(1, self.config.audio_num_codebooks): 166 | curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos) 167 | decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to( 168 | dtype=dtype 169 | ) 170 | ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1]) 171 | ci_sample = sample_topk(ci_logits, topk, temperature) 172 | ci_embed = self._embed_audio(i, ci_sample) 173 | 174 | curr_h = ci_embed 175 | curr_sample = torch.cat([curr_sample, ci_sample], dim=1) 176 | curr_pos = curr_pos[:, -1:] + 1 177 | 178 | return curr_sample 179 | 180 | def reset_caches(self): 181 | self.backbone.reset_caches() 182 | self.decoder.reset_caches() 183 | 184 | def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor: 185 | return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size) 186 | 187 | def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor: 188 | text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2) 189 | 190 | audio_tokens = tokens[:, :, :-1] + ( 191 | self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device) 192 | ) 193 | audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape( 194 | tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1 195 | ) 196 | 197 | return torch.cat([audio_embeds, text_embeds], dim=-2) 198 | 199 | -------------------------------------------------------------------------------- /vad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Callable, Dict, List 4 | from collections import deque 5 | class VoiceActivityDetector: 6 | def __init__( 7 | self, 8 | model, 9 | utils, 10 | sample_rate: int = 16000, 11 | threshold: float = 0.3, 12 | silence_duration: int = 45 13 | ): 14 | self.model = model 15 | self.sample_rate = sample_rate 16 | self.threshold = threshold 17 | self.silence_duration = silence_duration 18 | 19 | # Get functions from utils 20 | self.get_speech_timestamps = utils[0] 21 | 22 | self.is_speaking = False 23 | self.silent_frames = 0 24 | self.frame_size = 512 if sample_rate == 16000 else 256 # Required by Silero VAD 25 | 26 | print(f"VAD initialized with threshold {threshold}, frame size {self.frame_size}, silence duration {silence_duration}") 27 | 28 | def reset(self) -> None: 29 | self.is_speaking = False 30 | self.silent_frames = 0 31 | 32 | if hasattr(self.model, "reset_states"): 33 | self.model.reset_states() 34 | elif hasattr(self.model, "reset_state"): 35 | self.model.reset_state() 36 | else: 37 | for buf in ("h", "c"): 38 | if hasattr(self.model, buf): 39 | getattr(self.model, buf).zero_() 40 | 41 | def process_audio_chunk(self, audio_chunk: np.ndarray) -> bool: 42 | # Prepare audio chunk 43 | if audio_chunk.ndim > 1: 44 | audio_chunk = np.mean(audio_chunk, axis=1) 45 | if audio_chunk.dtype != np.float32: 46 | audio_chunk = 
audio_chunk.astype(np.float32) 47 | 48 | # Process in chunks of the correct size 49 | speech_detected = False 50 | turn_ended = False 51 | 52 | speech_probs = [] 53 | 54 | # Process audio in correct sized chunks for Silero VAD 55 | for i in range(0, len(audio_chunk), self.frame_size): 56 | # Get chunk of correct size 57 | chunk = audio_chunk[i:i+self.frame_size] 58 | 59 | # If we don't have enough samples, pad with zeros 60 | if len(chunk) < self.frame_size: 61 | chunk = np.pad(chunk, (0, self.frame_size - len(chunk))) 62 | 63 | # Convert to tensor 64 | audio_tensor = torch.tensor(chunk).to('cpu') 65 | 66 | # Get speech probability 67 | 68 | speech_prob = self.model(audio_tensor, self.sample_rate).item() 69 | 70 | speech_probs.append(speech_prob) 71 | 72 | # Update speaking state 73 | if speech_prob >= self.threshold: 74 | speech_detected = True 75 | self.silent_frames = 0 76 | else: 77 | if self.is_speaking: 78 | self.silent_frames += 1 79 | 80 | # Print detailed speech detection information 81 | # print(f"Speech probabilities: {speech_probs}") 82 | # print(f"Speech detected: {speech_detected}, Current state: {self.is_speaking}") 83 | # print(f"Silent frames: {self.silent_frames}, Threshold: {self.silence_duration}") 84 | 85 | # Update speaking state based on all chunks 86 | if speech_detected: 87 | self.is_speaking = True 88 | self.silent_frames = 0 89 | elif self.is_speaking and self.silent_frames >= self.silence_duration: 90 | # Transition to not speaking if we've had enough silent frames 91 | self.is_speaking = False 92 | turn_ended = True 93 | print(f"Turn ended after {self.silent_frames} silent frames") 94 | self.silent_frames = 0 95 | 96 | return turn_ended 97 | 98 | 99 | class AudioStreamProcessor: 100 | def __init__( 101 | self, 102 | model, 103 | utils, 104 | sample_rate: int = 16000, 105 | chunk_size: int = 512, 106 | vad_threshold: float = 0.3, 107 | callbacks: Dict[str, Callable] = None, 108 | pre_speech_buffer_size: int = 10 109 | ): 110 | self.sample_rate = sample_rate 111 | self.chunk_size = chunk_size 112 | self.pre_speech_buffer = deque(maxlen=pre_speech_buffer_size) 113 | # Ensure model is on CPU 114 | if hasattr(model, 'to'): 115 | model = model.to('cpu') 116 | 117 | self.vad = VoiceActivityDetector( 118 | model=model, 119 | utils=utils, 120 | sample_rate=sample_rate, 121 | threshold=vad_threshold, 122 | silence_duration=45 # Increased for better end detection 123 | ) 124 | 125 | self.audio_buffer = [] 126 | self.is_collecting = False 127 | self.callbacks = callbacks or {} 128 | self.silent_chunk_count = 0 129 | self.max_silent_chunks = 30 # Force end after this many silent chunks 130 | 131 | print(f"AudioStreamProcessor initialized with threshold: {vad_threshold}") 132 | 133 | def process_audio(self, audio_chunk: np.ndarray): 134 | # Always add to pre-speech buffer 135 | self.pre_speech_buffer.append(audio_chunk) 136 | 137 | if self.is_collecting: 138 | self.audio_buffer.append(audio_chunk) 139 | 140 | # Process with VAD 141 | is_turn_end = self.vad.process_audio_chunk(audio_chunk) 142 | 143 | # Start collecting on speech detection 144 | if self.vad.is_speaking and not self.is_collecting: 145 | self.is_collecting = True 146 | self.silent_chunk_count = 0 147 | # Include pre-speech buffer in the audio buffer 148 | self.audio_buffer = list(self.pre_speech_buffer) 149 | print(f"Speech started, beginning collection with {len(self.pre_speech_buffer)} pre-speech chunks") 150 | if "on_speech_start" in self.callbacks: 151 | self.callbacks["on_speech_start"]() 152 | 153 | # 
Count silent chunks when collecting but not speaking 154 | if self.is_collecting and not self.vad.is_speaking: 155 | self.silent_chunk_count += 1 156 | print(f"Silent chunk count: {self.silent_chunk_count}, max: {self.max_silent_chunks}") 157 | # Force end after too many silent chunks 158 | if self.silent_chunk_count >= self.max_silent_chunks: 159 | is_turn_end = True 160 | print(f"Forcing speech end after {self.silent_chunk_count} silent chunks") 161 | else: 162 | self.silent_chunk_count = 0 163 | 164 | # End collection on turn end 165 | if is_turn_end and self.is_collecting: 166 | print("Turn end detected, processing collected audio") 167 | self.is_collecting = False 168 | if self.audio_buffer: 169 | print(f"Audio buffer length: {len(self.audio_buffer)} chunks") 170 | print("Speech ended, processing collected audio") 171 | complete_audio = np.concatenate(self.audio_buffer) 172 | print(f"Complete audio length: {len(complete_audio)}") 173 | 174 | if "on_speech_end" in self.callbacks: 175 | try: 176 | print("Calling on_speech_end callback") 177 | self.callbacks["on_speech_end"](complete_audio, self.sample_rate) 178 | print("on_speech_end callback completed successfully") 179 | except Exception as e: 180 | print(f"Error in on_speech_end callback: {e}") 181 | 182 | # Clear buffer after processing 183 | self.audio_buffer = [] 184 | self.silent_chunk_count = 0 185 | 186 | def reset(self): 187 | self.vad.reset() 188 | self.audio_buffer = [] 189 | self.is_collecting = False 190 | self.silent_chunk_count = 0 191 | print("AudioStreamProcessor reset") -------------------------------------------------------------------------------- /static/app.js: -------------------------------------------------------------------------------- 1 | let ws; 2 | let micAnalyser, micContext, micSource, micStream; 3 | let outputAnalyser, outputAudioCtx; 4 | let lastConfig = null; 5 | let isLoading = false; 6 | 7 | document.addEventListener('DOMContentLoaded', async () => { 8 | await populateAudioDevices(); 9 | 10 | ws = new WebSocket(`ws://${window.location.host}/ws`); 11 | 12 | ws.onopen = () => { 13 | console.log("WebSocket connected, requesting saved config..."); 14 | ws.send(JSON.stringify({ type: "request_saved_config" })); 15 | }; 16 | 17 | ws.onmessage = async (event) => { 18 | const data = JSON.parse(event.data); 19 | 20 | if (data.type === "saved_config" && data.config) { 21 | document.getElementById('systemPrompt').value = data.config.system_prompt || ""; 22 | document.getElementById('modelPath').value = data.config.model_path || ""; 23 | document.getElementById('llmPath').value = data.config.llm_path || ""; 24 | document.getElementById('referenceAudio').value = data.config.reference_audio_path || ""; 25 | document.getElementById('referenceText').value = data.config.reference_text || ""; 26 | document.getElementById('referenceAudio2').value = data.config.reference_audio_path2 || ""; 27 | document.getElementById('referenceText2').value = data.config.reference_text2 || ""; 28 | document.getElementById('referenceAudio3').value = data.config.reference_audio_path3 || ""; 29 | document.getElementById('referenceText3').value = data.config.reference_text3 || ""; 30 | 31 | setTimeout(() => { 32 | if (data.config.mic_id) document.getElementById('micSelect').value = data.config.mic_id; 33 | if (data.config.output_id) document.getElementById('outputSelect').value = data.config.output_id; 34 | }, 500); 35 | } 36 | 37 | if (data.type === "status") { 38 | if (data.message.includes("Models initialized")) { 39 | 
console.log("Model initialization confirmed. Redirecting..."); 40 | 41 | // Save config again just to be safe 42 | localStorage.setItem('ai_config', JSON.stringify(lastConfig)); 43 | 44 | // Close WebSocket before navigating 45 | if (ws && ws.readyState === WebSocket.OPEN) { 46 | ws.close(); 47 | } 48 | 49 | // Wait briefly to let server clean up, then redirect 50 | setTimeout(() => { 51 | window.location.href = "/chat"; 52 | }, 100); 53 | } else if (data.message.includes("Initializing") || data.message.includes("Loading")) { 54 | // Show that models are being loaded 55 | showLoading(true, data.message); 56 | } 57 | } 58 | }; 59 | 60 | document.getElementById('testMicBtn').addEventListener('click', async () => { 61 | const micId = getSelectedMic(); 62 | micStream = await navigator.mediaDevices.getUserMedia({ audio: { deviceId: micId } }); 63 | 64 | micContext = new AudioContext(); 65 | micSource = micContext.createMediaStreamSource(micStream); 66 | micAnalyser = micContext.createAnalyser(); 67 | micSource.connect(micAnalyser); 68 | visualizeMic(micAnalyser, 'micCanvas'); 69 | 70 | const recorder = new MediaRecorder(micStream); 71 | const chunks = []; 72 | 73 | recorder.ondataavailable = e => { 74 | if (e.data.size > 0) chunks.push(e.data); 75 | }; 76 | 77 | recorder.onstop = () => { 78 | const blob = new Blob(chunks, { type: 'audio/webm' }); 79 | const url = URL.createObjectURL(blob); 80 | const audio = new Audio(url); 81 | audio.play(); 82 | 83 | micStream.getTracks().forEach(track => track.stop()); 84 | micContext.close(); 85 | }; 86 | 87 | recorder.start(); 88 | setTimeout(() => recorder.stop(), 3000); 89 | }); 90 | 91 | document.getElementById('testOutputBtn').addEventListener('click', () => { 92 | const audio = new Audio('/static/test.mp3'); 93 | audio.setSinkId(getSelectedOutput()).then(() => { 94 | outputAudioCtx = new AudioContext(); 95 | const outputSource = outputAudioCtx.createMediaElementSource(audio); 96 | outputAnalyser = outputAudioCtx.createAnalyser(); 97 | outputSource.connect(outputAnalyser); 98 | outputAnalyser.connect(outputAudioCtx.destination); 99 | visualizeMic(outputAnalyser, 'outputCanvas'); 100 | audio.play(); 101 | }).catch(err => { 102 | console.warn("Failed to route output:", err); 103 | }); 104 | }); 105 | 106 | document.getElementById('saveAndStart').addEventListener('click', () => { 107 | lastConfig = { 108 | system_prompt: document.getElementById('systemPrompt').value, 109 | model_path: document.getElementById('modelPath').value, 110 | llm_path: document.getElementById('llmPath').value, 111 | reference_audio_path: document.getElementById('referenceAudio').value, 112 | reference_text: document.getElementById('referenceText').value, 113 | reference_audio_path2: document.getElementById('referenceAudio2').value, 114 | reference_text2: document.getElementById('referenceText2').value, 115 | reference_audio_path3: document.getElementById('referenceAudio3').value, 116 | reference_text3: document.getElementById('referenceText3').value, 117 | mic_id: getSelectedMic(), 118 | output_id: getSelectedOutput(), 119 | }; 120 | console.log("Sending config to backend..."); 121 | console.log(lastConfig) 122 | showLoading(true, "Initializing models, please wait..."); 123 | ws.send(JSON.stringify({ type: "config", config: lastConfig })); 124 | // we wait for the backend to reply with model status before navigating 125 | }); 126 | }); 127 | 128 | function showLoading(show, message) { 129 | const saveButton = document.getElementById('saveAndStart'); 130 | const loadingContainer = 
document.getElementById('loadingContainer'); 131 | const loadingSpinner = document.getElementById('loadingSpinner'); 132 | const loadingText = document.getElementById('loadingText'); 133 | 134 | isLoading = show; 135 | 136 | if (show) { 137 | saveButton.disabled = true; 138 | saveButton.classList.add('opacity-50', 'cursor-not-allowed'); 139 | saveButton.classList.remove('hover:bg-green-500'); 140 | loadingContainer.classList.remove('hidden'); 141 | loadingSpinner.style.display = 'block'; 142 | if (message) { 143 | loadingText.textContent = message; 144 | } 145 | } else { 146 | saveButton.disabled = false; 147 | saveButton.classList.remove('opacity-50', 'cursor-not-allowed'); 148 | saveButton.classList.add('hover:bg-green-500'); 149 | loadingContainer.classList.add('hidden'); 150 | loadingSpinner.style.display = 'none'; 151 | } 152 | } 153 | 154 | function getSelectedMic() { 155 | return document.getElementById('micSelect').value; 156 | } 157 | 158 | function getSelectedOutput() { 159 | return document.getElementById('outputSelect').value; 160 | } 161 | 162 | async function populateAudioDevices() { 163 | try { 164 | await navigator.mediaDevices.getUserMedia({ audio: true }); 165 | } catch (err) { 166 | console.warn("Microphone permission denied or not granted."); 167 | return; 168 | } 169 | 170 | const devices = await navigator.mediaDevices.enumerateDevices(); 171 | const micSelect = document.getElementById('micSelect'); 172 | const outputSelect = document.getElementById('outputSelect'); 173 | 174 | micSelect.innerHTML = ''; 175 | outputSelect.innerHTML = ''; 176 | 177 | devices.forEach(device => { 178 | const option = new Option(device.label || `${device.kind}`, device.deviceId); 179 | if (device.kind === 'audioinput') micSelect.add(option.cloneNode(true)); 180 | if (device.kind === 'audiooutput') { 181 | outputSelect.add(option.cloneNode(true)); 182 | } 183 | }); 184 | 185 | if (micSelect.options.length === 0) { 186 | micSelect.add(new Option("No mic devices found", "")); 187 | } 188 | if (outputSelect.options.length === 0) { 189 | outputSelect.add(new Option("Default Output", "default")); 190 | } 191 | } 192 | 193 | function visualizeMic(analyser, canvasId) { 194 | const canvas = document.getElementById(canvasId); 195 | const ctx = canvas.getContext("2d"); 196 | analyser.fftSize = 256; 197 | const bufferLength = analyser.frequencyBinCount; 198 | const dataArray = new Uint8Array(bufferLength); 199 | 200 | function draw() { 201 | requestAnimationFrame(draw); 202 | analyser.getByteFrequencyData(dataArray); 203 | ctx.fillStyle = "#1f2937"; 204 | ctx.fillRect(0, 0, canvas.width, canvas.height); 205 | const barWidth = canvas.width / bufferLength; 206 | for (let i = 0; i < bufferLength; i++) { 207 | const barHeight = dataArray[i]; 208 | ctx.fillStyle = "#4ade80"; 209 | ctx.fillRect(i * barWidth, canvas.height - barHeight / 2, barWidth - 1, barHeight / 2); 210 | } 211 | } 212 | draw(); 213 | } -------------------------------------------------------------------------------- /templates/chat.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | AI Companion - Chat 7 | 8 | 9 | 105 | 106 | 107 | 108 |
109 | 110 |
111 |

AI Companion

112 | 115 |
116 | 117 |
118 | 119 | 164 | 165 |
166 | 167 |
168 |
169 |
170 | 171 |
172 |
173 | 179 | 187 |
188 |
189 |
190 |
191 | 192 | 244 | 245 | 246 | 247 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CSM - Optimized Streaming/Finetuning Edition 2 | 3 | --- 4 | 5 | CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes. 6 | 7 | Our fork adds **streaming audio generation**, **real-time playback**, and **performance optimizations** to the original implementation. 8 | 9 | ## Requirements 10 | 11 | * A CUDA-compatible GPU 12 | * The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions 13 | * Similarly, Python 3.10 is recommended, but newer versions may be fine 14 | * For some audio operations, `ffmpeg` may be required 15 | * For real-time audio playback: `pip install sounddevice` 16 | * Access to the following Hugging Face models: 17 | * [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) 18 | * [CSM-1B](https://huggingface.co/sesame/csm-1b) 19 | 20 | ### Setup 21 | 22 | ```bash 23 | sudo apt-get update && sudo apt-get install -y libportaudio2 libportaudio-dev 24 | git clone git@github.com:davidbrowne17/csm-streaming.git 25 | cd csm-streaming 26 | python3.10 -m venv .venv 27 | source .venv/bin/activate 28 | pip install -r requirements.txt 29 | 30 | # Optional speedup 31 | pip install flash-attn 32 | # You will need access to CSM-1B and Llama-3.2-1B 33 | huggingface-cli login 34 | ``` 35 | 36 | ### Windows Setup 37 | 38 | The `triton` package cannot be installed on Windows; use `pip install triton-windows` instead. 39 | The realtime demo uses vLLM for inference speed, which is not currently supported on Windows; you can try https://github.com/SystemPanic/vllm-windows until official support is added. 40 | 41 | ## Quickstart 42 | 43 | Generate a sentence with streaming (chunks are processed and output as they're generated): 44 | 45 | ```python 46 | import time 47 | from huggingface_hub import hf_hub_download 48 | from generator import Generator, Segment, load_csm_1b, generate_streaming_audio 49 | import torchaudio 50 | 51 | # Load the model 52 | generator = load_csm_1b("cuda") 53 | 54 | # Generate audio with streaming and real-time playback 55 | generate_streaming_audio( 56 | generator=generator, 57 | text="Hello, this is streaming audio generation in action!", 58 | speaker=0, 59 | context=[], # No context needed for basic generation 60 | output_file="streaming_audio.wav", 61 | play_audio=True # Enable real-time playback 62 | ) 63 | ``` 64 | ## Finetuning 65 | To finetune CSM, all you need are raw wav files of the speaker voice you want to train on. Place them in a folder called audio_data and run lora.py. 66 | You can configure the training parameters, such as batch size, number of epochs, and learning rate, by modifying the values at the top of lora.py. 67 | You will need a CUDA GPU with at least 12 GB of VRAM, depending on your dataset size and training parameters. You can monitor the training metrics via the dynamic PNG created in the /finetuned_model/ folder, which contains various graphs to help you track training progress. If you want to try a checkpoint, use loadandmergecheckpoint.py (make sure to set the same R and alpha values as you used in training); a sketch of the kind of values involved follows below.
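As a rough illustration, the tunable values referred to above might look something like the block below. `DEVICE` and `OUTPUT_DIR` are the names that loadandmergecheckpoint.py imports from lora.py; every other name and value here is a placeholder for illustration, not the script's actual API:

```python
# Hypothetical sketch of the tunable block at the top of lora.py.
# DEVICE and OUTPUT_DIR are confirmed by the imports in loadandmergecheckpoint.py;
# the remaining names only stand in for the batch size, epoch count, learning rate,
# and LoRA rank/alpha that the Finetuning section says you can edit.
DEVICE = "cuda"
OUTPUT_DIR = "finetuned_model"   # training artifacts and the metrics PNG end up here
BATCH_SIZE = 4                   # placeholder: utterances per training step
NUM_EPOCHS = 10                  # placeholder: passes over audio_data/
LEARNING_RATE = 1e-4             # placeholder: LoRA adapter learning rate
LORA_R = 32                      # placeholder: keep in sync with R in loadandmergecheckpoint.py
LORA_ALPHA = 32                  # placeholder: keep in sync with ALPHA in loadandmergecheckpoint.py
```

Whatever the script actually calls these values, the rank/alpha pair used for training must match the R and ALPHA constants in loadandmergecheckpoint.py before you merge a checkpoint.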
68 | 69 | ## Realtime chat demo 70 | To use the realtime demo, run `setup.py` to download the required models, and then run `main.py`. This will open a setup page at http://localhost:8000 where you can set the path to your chosen LLM, configure the CSM paths and reference audio, and select your headset and mic. Once loaded, you will be able to chat with the AI in real time, just like the CSM demo. Our demo includes a dynamic RAG system so the AI can remember previous conversations. By default, the demo uses whisper-large-v3-turbo for STT and includes automatic voice activity detection using Silero VAD. 71 | 72 | ## Usage 73 | 74 | Our optimized version offers several ways to use CSM with streaming capabilities: 75 | 76 | ### 1. Basic Streaming Generation 77 | 78 | Generate audio with streaming and save to a file: 79 | 80 | ```python 81 | from generator import load_csm_1b, generate_streaming_audio 82 | 83 | generator = load_csm_1b("cuda") 84 | 85 | # Generate with streaming (writes to file as it generates) 86 | generate_streaming_audio( 87 | generator=generator, 88 | text="This audio will be generated in chunks for faster response times.", 89 | speaker=0, 90 | context=[], 91 | output_file="streaming_output.wav" 92 | ) 93 | ``` 94 | 95 | ### 2. Real-time Audio Playback 96 | 97 | Generate and play audio in real-time as it's being generated: 98 | 99 | ```python 100 | from generator import load_csm_1b, generate_streaming_audio 101 | 102 | generator = load_csm_1b("cuda") 103 | 104 | # Generate with streaming and play in real-time 105 | generate_streaming_audio( 106 | generator=generator, 107 | text="You'll hear me speaking as I'm being generated!", 108 | speaker=0, 109 | context=[], 110 | output_file="streaming_output.wav", 111 | play_audio=True # Enable real-time playback 112 | ) 113 | ``` 114 | 115 | ### 3. Low-level Streaming API 116 | 117 | For more control, use the low-level streaming API: 118 | 119 | ```python 120 | from generator import load_csm_1b, Segment 121 | import torchaudio 122 | 123 | generator = load_csm_1b("cuda") 124 | 125 | # Process audio chunks as they're generated 126 | for audio_chunk in generator.generate_stream( 127 | text="This is generated chunk by chunk.", 128 | speaker=0, 129 | context=[] 130 | ): 131 | # Do something with each chunk as it's generated 132 | print(f"Received chunk of size: {audio_chunk.shape}") 133 | 134 | # You could process or play each chunk here 135 | # For example, write to a file incrementally 136 | # Or send over a network connection 137 | ``` 138 |
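Building on the low-level API above, here is a minimal sketch of collecting the streamed chunks and writing them out once generation finishes. It assumes each yielded chunk is a 1-D audio tensor, as the example above suggests; adapt it to your own buffering or network transport.

```python
import torch
import torchaudio
from generator import load_csm_1b

generator = load_csm_1b("cuda")

# Collect chunks as they arrive; each one could also be sent over a socket instead.
chunks = []
for audio_chunk in generator.generate_stream(
    text="Streaming straight into a file.",
    speaker=0,
    context=[]
):
    chunks.append(audio_chunk.detach().cpu())

# Concatenate the chunks and write a single WAV once the stream is done.
audio = torch.cat(chunks)
torchaudio.save("incremental_output.wav", audio.unsqueeze(0), generator.sample_rate)
```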
139 | ### 4. Generate with Context 140 | 141 | For best results, provide reference audio context: 142 | 143 | ```python 144 | from generator import load_csm_1b, Segment, generate_streaming_audio 145 | import torchaudio 146 | 147 | generator = load_csm_1b("cuda") 148 | 149 | # Load reference audio 150 | def load_audio(audio_path): 151 | audio_tensor, sample_rate = torchaudio.load(audio_path) 152 | audio_tensor = torchaudio.functional.resample( 153 | audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate 154 | ) 155 | return audio_tensor 156 | 157 | # Create context segments 158 | segments = [ 159 | Segment( 160 | text="I knew I could trust you.", 161 | speaker=0, 162 | audio=load_audio("reference.wav") 163 | ) 164 | ] 165 | 166 | # Generate with streaming using the context 167 | generate_streaming_audio( 168 | generator=generator, 169 | text="Me too, this is some cool stuff huh?", 170 | speaker=0, 171 | context=segments, 172 | output_file="contextual_streaming.wav", 173 | play_audio=True 174 | ) 175 | ``` 176 | 177 | ### 5. Regular Generation with Internal Streaming 178 | 179 | Use the original API with streaming enabled internally: 180 | 181 | ```python 182 | from generator import load_csm_1b, Segment 183 | import torchaudio 184 | 185 | generator = load_csm_1b("cuda") 186 | 187 | # Regular generation but with internal streaming optimization 188 | audio = generator.generate( 189 | text="This uses internal streaming for faster processing.", 190 | speaker=0, 191 | context=[], 192 | max_audio_length_ms=10_000, 193 | stream=True # Enable internal streaming optimization 194 | ) 195 | 196 | torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) 197 | ``` 198 | ## Performance Optimizations 199 | 200 | Our optimized version includes several performance enhancements: 201 | 202 | - **Streaming Generation**: Processes and outputs audio in chunks instead of waiting for the entire generation, achieving a real-time factor (RTF) of 0.28x (target: <1.0) on an RTX 4090 (10 seconds of audio takes about 2.8 seconds to generate) 203 | - **Frame Batching**: Processes multiple frames at once for better GPU utilization 204 | - **Half-precision Inference**: Uses bfloat16/float16 for faster processing 205 | - **CUDA Optimizations**: Enables cuDNN benchmarking and Flash Attention where available 206 | - **Memory Management**: Clears GPU cache before generation to reduce memory pressure 207 | 208 | ## FAQ 209 | 210 | **How much faster is the streaming version?** 211 | 212 | The perceived response time is significantly faster since you get the first audio chunks in milliseconds instead of waiting for the entire generation to complete (see the short timing sketch after this FAQ). The actual total generation time is also improved by 40-60% depending on your hardware. 213 | 214 | **Does this model come with any voices?** 215 | 216 | The model is a base generation model capable of producing a variety of voices but hasn't been fine-tuned on any specific voice. Provide reference audio for best results. 217 | 218 | **Can I converse with the model?** 219 | 220 | CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. Using a separate LLM, you can converse with the realtime demo via the web UI. 221 | 222 | **Does it support other languages?** 223 | 224 | The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well.
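As mentioned in the FAQ above, the main win is time to first audio. A rough way to measure it yourself, using only the streaming API shown earlier, is the following sketch:

```python
import time
from generator import load_csm_1b

generator = load_csm_1b("cuda")

start = time.time()
for i, audio_chunk in enumerate(generator.generate_stream(
    text="Measuring time to first audio.",
    speaker=0,
    context=[]
)):
    if i == 0:
        # Latency until the first chunk is available for playback
        print(f"First chunk after {time.time() - start:.2f}s")
print(f"Full generation finished after {time.time() - start:.2f}s")
```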
225 | 226 | ## Misuse and abuse ⚠️ 227 | 228 | This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following: 229 | 230 | - **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent. 231 | - **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls. 232 | - **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes. 233 | 234 | By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology. 235 | 236 | --- 237 | 238 | ## Original Authors 239 | Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team. 240 | 241 | ## Streaming, Realtime Demo and Finetuning Implementation 242 | David Browne 243 | 244 | ## Support me 245 | Support this project on Ko-fi: https://ko-fi.com/davidbrowne17 246 | 247 | ## Transformers streaming 248 | If you want to use streaming with the Transformers implementation you can find it here: https://github.com/davidbrowne17/csm-streaming-tf 249 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /rag_system.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import numpy as np 3 | import json 4 | from pathlib import Path 5 | import time 6 | from typing import List, Dict, Any, Tuple, Optional 7 | from sentence_transformers import SentenceTransformer 8 | from sklearn.metrics.pairwise import cosine_similarity 9 | import torch 10 | 11 | class RAGSystem: 12 | def __init__(self, db_path: str, model_name: str = "all-MiniLM-L6-v2", cache_dir: str = "./embeddings_cache"): 13 | """ 14 | Initialize the enhanced RAG system with embeddings. 15 | 16 | Args: 17 | db_path: Path to the SQLite database 18 | model_name: Name of the sentence-transformer model to use 19 | cache_dir: Directory to cache embeddings 20 | """ 21 | self.db_path = db_path 22 | self.cache_dir = Path(cache_dir) 23 | self.cache_dir.mkdir(exist_ok=True) 24 | 25 | # Load embedding model 26 | print(f"Loading embedding model: {model_name}") 27 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 28 | self.model = SentenceTransformer(model_name, device=self.device) 29 | print(f"Embedding model loaded on {self.device}") 30 | 31 | # Cache for embeddings 32 | self.embedding_cache = self._load_embedding_cache() 33 | 34 | # Initialize database tables if needed 35 | self._initialize_db() 36 | 37 | # Load existing conversations and cache embeddings 38 | self._load_conversations() 39 | 40 | def _initialize_db(self): 41 | """Create necessary tables if they don't exist.""" 42 | conn = sqlite3.connect(self.db_path) 43 | cursor = conn.cursor() 44 | 45 | # Create conversations table if it doesn't exist 46 | cursor.execute(""" 47 | CREATE TABLE IF NOT EXISTS conversations ( 48 | id INTEGER PRIMARY KEY, 49 | user_message TEXT, 50 | ai_message TEXT, 51 | timestamp DATETIME DEFAULT CURRENT_TIMESTAMP 52 | ) 53 | """) 54 | 55 | # Create embeddings table if it doesn't exist 56 | cursor.execute(""" 57 | CREATE TABLE IF NOT EXISTS embeddings ( 58 | id INTEGER PRIMARY KEY, 59 | conversation_id INTEGER, 60 | text TEXT, 61 | embedding_file TEXT, 62 | chunk_id TEXT, 63 | FOREIGN KEY (conversation_id) REFERENCES conversations(id) 64 | ) 65 | """) 66 | 67 | conn.commit() 68 | conn.close() 69 | 70 | def _load_embedding_cache(self) -> Dict[str, np.ndarray]: 71 | """Load cached embeddings from disk.""" 72 | cache = {} 73 | 74 | for cache_file in self.cache_dir.glob("*.json"): 75 | try: 76 | with open(cache_file, "r") as f: 77 | cache_data = json.load(f) 78 | for chunk_id, embedding_data in cache_data.items(): 79 | cache[chunk_id] = np.array(embedding_data) 80 | except Exception as e: 81 | print(f"Error loading cache file {cache_file}: {e}") 82 | 83 | print(f"Loaded {len(cache)} cached embeddings") 84 | return cache 85 | 86 | def _save_embedding_to_cache(self, chunk_id: str, embedding: np.ndarray): 87 | """Save an embedding to the cache.""" 88 | cache_file = self.cache_dir / f"{chunk_id[:2]}.json" 89 | 90 | # Load existing cache file or create new one 91 | if cache_file.exists(): 92 | try: 93 | with open(cache_file, "r") as f: 94 | cache_data = json.load(f) 95 | except: 96 | cache_data = {} 97 | else: 98 | cache_data = {} 99 | 100 | # Add new embedding 101 | cache_data[chunk_id] = embedding.tolist() 102 | 103 | # Save cache file 104 | with open(cache_file, "w") as f: 105 | json.dump(cache_data, f) 106 | 107 | def _load_conversations(self): 108 | """Load existing conversations from the database and cache their 
embeddings.""" 109 | try: 110 | conn = sqlite3.connect(self.db_path) 111 | cursor = conn.cursor() 112 | 113 | # First check if the conversations table exists 114 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='conversations'") 115 | if not cursor.fetchone(): 116 | print("Conversations table does not exist yet") 117 | conn.close() 118 | return 119 | 120 | # Get all conversations not yet in the embeddings table 121 | cursor.execute(""" 122 | SELECT c.id, c.user_message, c.ai_message 123 | FROM conversations c 124 | LEFT JOIN embeddings e ON c.id = e.conversation_id 125 | WHERE e.id IS NULL 126 | """) 127 | 128 | conversations = cursor.fetchall() 129 | if not conversations: 130 | conn.close() 131 | return 132 | 133 | print(f"Processing embeddings for {len(conversations)} new conversations") 134 | 135 | for conv_id, user_message, ai_message in conversations: 136 | # Create chunks for indexing 137 | if user_message is not None and ai_message is not None: # Ensure neither is None 138 | self._process_conversation(conv_id, user_message, ai_message, conn) 139 | 140 | conn.close() 141 | print("Finished processing conversation embeddings") 142 | except Exception as e: 143 | print(f"Error loading conversations: {e}") 144 | 145 | def _process_conversation(self, conv_id: int, user_message: str, ai_message: str, conn: sqlite3.Connection): 146 | """Process a conversation and store its embeddings.""" 147 | try: 148 | cursor = conn.cursor() 149 | 150 | # Combine user and AI messages 151 | full_text = f"User: {user_message}\nAI: {ai_message}" 152 | 153 | # For simplicity, we're using the entire message as a chunk 154 | # In a more sophisticated system, you might split long messages into smaller chunks 155 | chunk_id = f"conv_{conv_id}" 156 | 157 | # Check if we already have this embedding cached 158 | if chunk_id not in self.embedding_cache: 159 | # Generate embedding 160 | embedding = self.model.encode(full_text) 161 | self.embedding_cache[chunk_id] = embedding 162 | 163 | # Save to cache 164 | self._save_embedding_to_cache(chunk_id, embedding) 165 | else: 166 | embedding = self.embedding_cache[chunk_id] 167 | 168 | # Store reference in database 169 | embedding_file = f"{chunk_id[:2]}.json" 170 | cursor.execute( 171 | "INSERT INTO embeddings (conversation_id, text, embedding_file, chunk_id) VALUES (?, ?, ?, ?)", 172 | (conv_id, full_text, embedding_file, chunk_id) 173 | ) 174 | 175 | conn.commit() 176 | except Exception as e: 177 | print(f"Error processing conversation {conv_id}: {e}") 178 | 179 | def add_conversation(self, user_message: str, ai_message: str) -> int: 180 | """ 181 | Add a new conversation to the RAG system. 182 | 183 | Returns: 184 | The id of the newly added conversation 185 | """ 186 | try: 187 | conn = sqlite3.connect(self.db_path) 188 | cursor = conn.cursor() 189 | 190 | # Insert the conversation first 191 | cursor.execute( 192 | "INSERT INTO conversations (user_message, ai_message) VALUES (?, ?)", 193 | (user_message, ai_message) 194 | ) 195 | 196 | # Get the ID of the new conversation 197 | conv_id = cursor.lastrowid 198 | 199 | # Process the conversation for embeddings 200 | self._process_conversation(conv_id, user_message, ai_message, conn) 201 | 202 | conn.commit() 203 | conn.close() 204 | 205 | return conv_id 206 | except Exception as e: 207 | print(f"Error adding conversation: {e}") 208 | return -1 209 | 210 | def query(self, query_text: str, top_k: int = 3) -> List[Tuple[str, float]]: 211 | """ 212 | Query the RAG system for relevant context. 
213 | 214 | Args: 215 | query_text: The query text 216 | top_k: Number of top results to return 217 | 218 | Returns: 219 | List of tuples with (text, similarity_score) 220 | """ 221 | if query_text is None or query_text.strip() == "": 222 | print("Error: Empty query text") 223 | return [] 224 | 225 | try: 226 | # Generate query embedding 227 | query_embedding = self.model.encode(query_text) 228 | 229 | # Find most similar conversations 230 | results = self._find_similar(query_embedding, top_k) 231 | 232 | return results 233 | except Exception as e: 234 | print(f"Error during query: {e}") 235 | return [] 236 | 237 | def get_context(self, query_text: str, top_k: int = 3, threshold: float = 0.6) -> str: 238 | """ 239 | Get formatted context from the RAG system. 240 | 241 | Args: 242 | query_text: The query text 243 | top_k: Number of top results to return 244 | threshold: Minimum similarity score to include 245 | 246 | Returns: 247 | String with relevant context 248 | """ 249 | results = self.query(query_text, top_k) 250 | 251 | if not results: 252 | return "" 253 | 254 | # Format results 255 | context_parts = [] 256 | for text, score in results: 257 | # Only include really relevant results 258 | if score < threshold: # Threshold for relevance 259 | continue 260 | context_parts.append(f"Relevance: {score:.2f}\n{text}") 261 | 262 | return "\n---\n".join(context_parts) 263 | 264 | def _find_similar(self, query_embedding: np.ndarray, top_k: int) -> List[Tuple[str, float]]: 265 | """Find the most similar conversations to the query.""" 266 | try: 267 | conn = sqlite3.connect(self.db_path) 268 | cursor = conn.cursor() 269 | 270 | # Check if the embeddings table exists 271 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='embeddings'") 272 | if not cursor.fetchone(): 273 | print("Embeddings table does not exist yet") 274 | conn.close() 275 | return [] 276 | 277 | # Get all embeddings from the database 278 | cursor.execute("SELECT id, text, embedding_file, chunk_id FROM embeddings") 279 | results = cursor.fetchall() 280 | 281 | if not results: 282 | conn.close() 283 | return [] 284 | 285 | # Calculate similarities 286 | similarities = [] 287 | for db_id, text, embedding_file, chunk_id in results: 288 | # Get embedding from cache 289 | if chunk_id in self.embedding_cache: 290 | embedding = self.embedding_cache[chunk_id] 291 | else: 292 | # This should not happen, but just in case 293 | # We'll reload from the cache file 294 | cache_file = self.cache_dir / embedding_file 295 | if cache_file.exists(): 296 | with open(cache_file, "r") as f: 297 | cache_data = json.load(f) 298 | if chunk_id in cache_data: 299 | embedding = np.array(cache_data[chunk_id]) 300 | self.embedding_cache[chunk_id] = embedding 301 | else: 302 | continue 303 | else: 304 | continue 305 | 306 | # Calculate similarity 307 | similarity = cosine_similarity( 308 | query_embedding.reshape(1, -1), 309 | embedding.reshape(1, -1) 310 | )[0][0] 311 | 312 | similarities.append((text, similarity)) 313 | 314 | conn.close() 315 | 316 | # Sort by similarity and return top_k 317 | similarities.sort(key=lambda x: x[1], reverse=True) 318 | return similarities[:top_k] 319 | except Exception as e: 320 | print(f"Error finding similar documents: {e}") 321 | return [] 322 | 323 | def refresh(self): 324 | """Refresh embeddings from the database.""" 325 | self._load_conversations() 326 | 327 | # Example usage 328 | if __name__ == "__main__": 329 | # Initialize the RAG system 330 | rag = RAGSystem("conversations.db") 
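    # --- Illustrative usage sketch (editorial addition, not part of the original file). ---
    # Only methods defined above are used; the message strings are made-up examples.
    conv_id = rag.add_conversation(
        "What's your favourite colour?",
        "I don't have eyes, but I'm partial to indigo.",
    )
    print(f"Stored conversation {conv_id}")

    # Retrieve formatted context for a follow-up question.
    context = rag.get_context("Do you remember which colour you like?", top_k=3, threshold=0.6)
    print(context or "No sufficiently relevant context found")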
-------------------------------------------------------------------------------- /static/chat.js: -------------------------------------------------------------------------------- 1 | let ws; 2 | let sessionStartTime = null; 3 | let messageCount = 0; 4 | let audioLevelsChart = null; 5 | let isRecording = false; 6 | let isAudioCurrentlyPlaying = false; 7 | let configSaved = false; 8 | let currentAudioSource = null; 9 | let interruptRequested = false; 10 | let interruptInProgress = false; 11 | let audioContext = null; 12 | let lastSeenGenId = 0; 13 | let reconnecting = false; 14 | let reconnectAttempts = 0; 15 | let maxReconnectAttempts = 10; 16 | 17 | const SESSION_ID = "default"; 18 | console.log("chat.js loaded"); 19 | 20 | let micStream; 21 | let selectedMicId = null; 22 | let selectedOutputId = null; 23 | 24 | let audioPlaybackQueue = []; 25 | let audioDataHistory = []; 26 | let micAnalyser, micContext; 27 | let activeGenId = 0; 28 | 29 | function createPermanentVoiceCircle() { 30 | if (document.getElementById('voice-circle')) return; 31 | const style = document.createElement('style'); 32 | style.textContent = ` 33 | #voice-circle{ 34 | position:fixed;top:50%;left:50%; 35 | width:180px;height:180px;border-radius:50%; 36 | background:rgba(99,102,241,.20); 37 | transform:translate(-50%,-50%) scale(var(--dynamic-scale,1)); 38 | pointer-events:none;z-index:50; 39 | transition:background-color .35s ease; 40 | } 41 | #voice-circle.active{ 42 | animation:pulse-circle 2s infinite alternate ease-in-out; 43 | } 44 | @keyframes pulse-circle{ 45 | 0%{background:rgba(99,102,241,.55)} 46 | 100%{background:rgba(99,102,241,.20)} 47 | }`; 48 | document.head.appendChild(style); 49 | 50 | const c = document.createElement('div'); 51 | c.id='voice-circle'; 52 | document.body.appendChild(c); 53 | console.log("Created permanent voice circle"); 54 | } 55 | 56 | function showVoiceCircle() { 57 | const c=document.getElementById('voice-circle')||createPermanentVoiceCircle(); 58 | c.classList.add('active'); 59 | } 60 | 61 | function hideVoiceCircle() { 62 | const c=document.getElementById('voice-circle'); 63 | if (c){ c.classList.remove('active'); c.style.setProperty('--dynamic-scale',1); } 64 | } 65 | 66 | function showNotification(msg, type='info'){ 67 | const n=document.createElement('div'); 68 | n.className=`fixed bottom-4 right-4 px-4 py-3 rounded-lg shadow-lg z-50 69 | ${type==='success'?'bg-green-600': 70 | type==='error' ?'bg-red-600':'bg-indigo-600'}`; 71 | n.textContent=msg; 72 | document.body.appendChild(n); 73 | setTimeout(()=>{n.classList.add('opacity-0'); 74 | setTimeout(()=>n.remove(),500)},3000); 75 | } 76 | 77 | function addMessageToConversation(sender,text){ 78 | const pane=document.getElementById('conversationHistory'); 79 | if(!pane) return; 80 | const box=document.createElement('div'); 81 | box.className=`p-3 mb-3 rounded-lg text-sm ${ 82 | sender==='user'?'bg-gray-800 ml-2':'bg-indigo-900 mr-2'}`; 83 | box.innerHTML=` 84 |
85 |
87 | ${sender==='user'?'U':'AI'} 88 |
89 | ${new Date().toLocaleTimeString()} 90 |
91 |
${text 92 | .replace(/&/g,'&').replace(/$1') 94 | .replace(/\*(.*?)\*/g,'$1') 95 | .replace(/```([^`]+)```/g,'
$1
') 96 | .replace(/`([^`]+)`/g,'$1') 97 | .replace(/\n/g,'
')}
`; 98 | pane.appendChild(box); 99 | pane.scrollTop=pane.scrollHeight; 100 | } 101 | 102 | function connectWebSocket() { 103 | if (reconnecting && reconnectAttempts >= maxReconnectAttempts) { 104 | console.error("Maximum reconnect attempts reached. Please refresh the page."); 105 | showNotification("Connection lost. Please refresh the page.", "error"); 106 | return; 107 | } 108 | 109 | if (ws && ws.readyState !== WebSocket.CLOSED && ws.readyState !== WebSocket.CLOSING) { 110 | try { 111 | ws.close(); 112 | } catch (e) { 113 | console.warn("Error closing existing WebSocket:", e); 114 | } 115 | } 116 | 117 | const proto = location.protocol === 'https:' ? 'wss:' : 'ws:'; 118 | ws = new WebSocket(`${proto}//${location.host}/ws`); 119 | window.ws = ws; 120 | 121 | const connLbl = document.getElementById('connectionStatus'); 122 | if (connLbl) { 123 | connLbl.textContent = reconnecting ? 'Reconnecting…' : 'Connecting…'; 124 | connLbl.className = 'text-yellow-500'; 125 | } 126 | 127 | ws.onopen = () => { 128 | if (connLbl) { 129 | connLbl.textContent = 'Connected'; 130 | connLbl.className = 'text-green-500'; 131 | } 132 | 133 | reconnecting = false; 134 | reconnectAttempts = 0; 135 | 136 | ws.send(JSON.stringify({type: 'request_saved_config'})); 137 | 138 | if (!reconnecting) { 139 | addMessageToConversation('ai', 'WebSocket connected. Ready for voice or text.'); 140 | } else { 141 | showNotification("Reconnected successfully", "success"); 142 | } 143 | }; 144 | 145 | ws.onclose = (event) => { 146 | console.log("WebSocket closed with code:", event.code); 147 | if (connLbl) { 148 | connLbl.textContent = 'Disconnected'; 149 | connLbl.className = 'text-red-500'; 150 | } 151 | 152 | // Clear audio state on disconnection 153 | clearAudioPlayback(); 154 | 155 | // Don't auto-reconnect if this was a normal closure 156 | if (event.code !== 1000 && event.code !== 1001) { 157 | reconnecting = true; 158 | reconnectAttempts++; 159 | 160 | const delay = Math.min(1000 * Math.pow(1.5, reconnectAttempts), 1000); 161 | console.log(`Reconnecting in ${delay}ms (attempt ${reconnectAttempts})`); 162 | 163 | setTimeout(connectWebSocket, delay); 164 | } 165 | }; 166 | 167 | ws.onerror = (error) => { 168 | console.error("WebSocket error:", error); 169 | if (connLbl) { 170 | connLbl.textContent = 'Error'; 171 | connLbl.className = 'text-red-500'; 172 | } 173 | }; 174 | 175 | ws.onmessage = (e) => { 176 | try { 177 | const data = JSON.parse(e.data); 178 | handleWebSocketMessage(data); 179 | } catch (err) { 180 | console.error("Error handling WebSocket message:", err); 181 | } 182 | }; 183 | } 184 | 185 | function sendTextMessage(txt) { 186 | if (!txt.trim()) return; 187 | 188 | if (!ws || ws.readyState !== WebSocket.OPEN) { 189 | showNotification("Not connected", "error"); 190 | return; 191 | } 192 | 193 | console.log("Force clearing all audio state before sending text message"); 194 | 195 | // Stop any playing audio 196 | if (isAudioCurrentlyPlaying) { 197 | if (currentAudioSource) { 198 | try { 199 | if (currentAudioSource.disconnect) currentAudioSource.disconnect(); 200 | if (currentAudioSource.stop) currentAudioSource.stop(0); 201 | } catch (e) { 202 | console.warn("Error stopping audio:", e); 203 | } 204 | currentAudioSource = null; 205 | } 206 | isAudioCurrentlyPlaying = false; 207 | } 208 | 209 | // Clear all flags and queues 210 | interruptRequested = false; 211 | interruptInProgress = false; 212 | activeGenId = 0; 213 | audioPlaybackQueue = []; 214 | 215 | // Always force interruption to be absolutely sure 216 | 
if (ws && ws.readyState === WebSocket.OPEN) { 217 | try { 218 | ws.send(JSON.stringify({type: 'interrupt', immediate: true})); 219 | } catch (e) { 220 | console.warn("Error sending interrupt:", e); 221 | } 222 | } 223 | 224 | // Wait a bit before sending the actual message 225 | setTimeout(() => { 226 | try { 227 | // Show visual feedback 228 | showVoiceCircle(); 229 | 230 | // Send the message 231 | ws.send(JSON.stringify({ 232 | type: 'text_message', 233 | text: txt, 234 | session_id: SESSION_ID 235 | })); 236 | 237 | const cnt = document.getElementById('messageCount'); 238 | if (cnt) cnt.textContent = ++messageCount; 239 | 240 | document.getElementById('textInput').value = ''; 241 | 242 | console.log("Text message sent successfully"); 243 | } catch (error) { 244 | console.error("Error sending message:", error); 245 | showNotification("Error sending message", "error"); 246 | } 247 | }, 300); 248 | } 249 | 250 | // Reset all audio state to ensure clean state for new interactions 251 | function resetAudioState() { 252 | console.log("Resetting audio state"); 253 | 254 | // Clear any stale generation information 255 | activeGenId = 0; 256 | lastSeenGenId = 0; 257 | 258 | // Clear any remaining flags 259 | interruptRequested = false; 260 | interruptInProgress = false; 261 | 262 | // Make sure we don't have any playing audio 263 | if (isAudioCurrentlyPlaying) { 264 | clearAudioPlayback(); 265 | } 266 | 267 | // Clear any queued audio 268 | audioPlaybackQueue = []; 269 | } 270 | 271 | function clearAudioPlayback() { 272 | console.log("FORCEFULLY CLEARING AUDIO PLAYBACK"); 273 | 274 | interruptRequested = true; 275 | interruptInProgress = true; 276 | 277 | try { 278 | // Empty the queue first - do this before stopping current source 279 | console.log(`Clearing queue with ${audioPlaybackQueue.length} items`); 280 | audioPlaybackQueue = []; 281 | 282 | activeGenId = 0; 283 | 284 | // Stop any currently playing audio 285 | if (currentAudioSource) { 286 | console.log("Stopping active audio source"); 287 | 288 | try { 289 | if (currentAudioSource.disconnect) { 290 | currentAudioSource.disconnect(); 291 | } 292 | } catch (e) { 293 | console.warn("Error disconnecting audio source:", e); 294 | } 295 | 296 | try { 297 | if (currentAudioSource.stop) { 298 | currentAudioSource.stop(0); 299 | } 300 | } catch (e) { 301 | console.warn("Error stopping audio source:", e); 302 | } 303 | 304 | currentAudioSource = null; 305 | } 306 | 307 | try { 308 | if (audioContext) { 309 | const oldContext = audioContext; 310 | audioContext = new (window.AudioContext || window.webkitAudioContext)(); 311 | window.audioContext = audioContext; 312 | 313 | try { 314 | oldContext.close(); 315 | } catch (closeError) { 316 | console.warn("Error closing old audio context:", closeError); 317 | } 318 | } else { 319 | audioContext = new (window.AudioContext || window.webkitAudioContext)(); 320 | window.audioContext = audioContext; 321 | } 322 | } catch (contextError) { 323 | console.error("Error recreating audio context:", contextError); 324 | } 325 | } catch (err) { 326 | console.error("Error clearing audio:", err); 327 | } 328 | 329 | // Reset state 330 | isAudioCurrentlyPlaying = false; 331 | hideVoiceCircle(); 332 | 333 | console.log("Audio playback cleared successfully"); 334 | 335 | // After a short delay, reset the interrupt flags to accept new audio 336 | setTimeout(() => { 337 | interruptInProgress = false; 338 | // Keep interruptRequested true until we get a new generation 339 | }, 300); 340 | } 341 | 342 | 343 | // Handle 
interruption request from user 344 | function requestInterrupt() { 345 | console.log("User requested interruption"); 346 | 347 | if (interruptInProgress) { 348 | console.log("Interrupt already in progress - force clearing again"); 349 | clearAudioPlayback(); 350 | return false; 351 | } 352 | 353 | // Set the flags immediately 354 | interruptRequested = true; 355 | interruptInProgress = true; 356 | 357 | // Show visual feedback 358 | showNotification("Interrupting...", "info"); 359 | 360 | // Force clear all audio immediately on client side 361 | clearAudioPlayback(); 362 | 363 | // Show visual feedback for the button 364 | const interruptBtn = document.getElementById('interruptBtn'); 365 | if (interruptBtn) { 366 | interruptBtn.classList.add('bg-red-800'); 367 | setTimeout(() => { 368 | interruptBtn.classList.remove('bg-red-800'); 369 | }, 300); 370 | } 371 | 372 | // Then notify the server 373 | if (ws && ws.readyState === WebSocket.OPEN) { 374 | console.log("Sending interrupt request to server"); 375 | try { 376 | ws.send(JSON.stringify({ 377 | type: 'interrupt', 378 | immediate: true 379 | })); 380 | } catch (error) { 381 | console.error("Error sending interrupt request:", error); 382 | } 383 | 384 | // Set a timeout to reset interrupt flags if we don't get server confirmation 385 | setTimeout(() => { 386 | if (interruptInProgress) { 387 | console.log("No interrupt confirmation received from server, resetting state"); 388 | interruptInProgress = false; 389 | } 390 | }, 2000); 391 | 392 | return true; 393 | } else { 394 | console.warn("WebSocket not available for interrupt request"); 395 | // Reset flag after brief delay if we couldn't send to server 396 | setTimeout(() => { 397 | interruptInProgress = false; 398 | }, 500); 399 | return false; 400 | } 401 | } 402 | 403 | function handleWebSocketMessage(d) { 404 | console.log("Received message:", d.type, d); 405 | 406 | switch(d.type) { 407 | case 'transcription': 408 | addMessageToConversation('user', d.text); 409 | showVoiceCircle(); 410 | break; 411 | 412 | case 'response': 413 | addMessageToConversation('ai', d.text); 414 | showVoiceCircle(); 415 | 416 | console.log("NEW RESPONSE RECEIVED - FORCE RESETTING ALL AUDIO STATE"); 417 | 418 | if (isAudioCurrentlyPlaying) { 419 | if (currentAudioSource) { 420 | try { 421 | if (currentAudioSource.disconnect) currentAudioSource.disconnect(); 422 | if (currentAudioSource.stop) currentAudioSource.stop(0); 423 | } catch (e) { 424 | console.warn("Error stopping current audio:", e); 425 | } 426 | currentAudioSource = null; 427 | } 428 | isAudioCurrentlyPlaying = false; 429 | } 430 | 431 | interruptRequested = false; 432 | interruptInProgress = false; 433 | 434 | activeGenId = 0; 435 | 436 | audioPlaybackQueue = []; 437 | 438 | try { 439 | if (audioContext) { 440 | if (audioContext.state === 'suspended') { 441 | audioContext.resume().catch(e => console.warn("Error resuming audio context:", e)); 442 | } 443 | } else { 444 | audioContext = new (window.AudioContext || window.webkitAudioContext)(); 445 | window.audioContext = audioContext; 446 | } 447 | } catch (e) { 448 | console.warn("Error with audio context:", e); 449 | audioContext = new (window.AudioContext || window.webkitAudioContext)(); 450 | window.audioContext = audioContext; 451 | } 452 | 453 | console.log("Audio state fully reset and ready for new audio"); 454 | break; 455 | 456 | case 'audio_chunk': 457 | console.log("Audio chunk received, flags:", 458 | "interruptRequested:", interruptRequested, 459 | "interruptInProgress:", 
interruptInProgress, 460 | "genId:", d.gen_id, 461 | "activeGenId:", activeGenId); 462 | 463 | if (!isAudioCurrentlyPlaying && activeGenId === 0) { 464 | console.log("FIRST AUDIO CHUNK - FORCING FLAGS RESET"); 465 | interruptRequested = false; 466 | interruptInProgress = false; 467 | } 468 | 469 | // Don't queue new audio if an interrupt was requested 470 | if (interruptRequested || interruptInProgress) { 471 | console.log("Interrupt active - ignoring new audio chunk"); 472 | return; 473 | } 474 | 475 | // Set active generation ID on first chunk 476 | if (activeGenId === 0) { 477 | activeGenId = d.gen_id || 1; 478 | console.log("!!! Setting activeGenId to:", activeGenId); 479 | } 480 | 481 | // Only accept chunks that match our active generation 482 | if ((d.gen_id === activeGenId) || (activeGenId === 0)) { 483 | queueAudioForPlayback(d.audio, d.sample_rate, d.gen_id || 0); 484 | showVoiceCircle(); 485 | } else { 486 | console.log(`Ignored stale chunk - current gen: ${activeGenId}, received: ${d.gen_id}`); 487 | } 488 | break; 489 | 490 | case 'audio_status': 491 | console.log("Audio status update:", d.status); 492 | 493 | if (d.status === 'generating') { 494 | console.log("GOT GENERATING STATUS - IMMEDIATELY CLEARING ALL INTERRUPT FLAGS"); 495 | interruptRequested = false; 496 | interruptInProgress = false; 497 | 498 | // Capture the generation ID for new generations 499 | if (d.gen_id) { 500 | console.log(`New generation starting with ID: ${d.gen_id}`); 501 | activeGenId = d.gen_id; 502 | } 503 | 504 | showVoiceCircle(); 505 | } 506 | else if (d.status === 'complete') { 507 | console.log("Audio generation complete"); 508 | if (!d.gen_id || d.gen_id === activeGenId) { 509 | activeGenId = 0; // Reset for next generation 510 | } 511 | if (!isAudioCurrentlyPlaying) { 512 | hideVoiceCircle(); 513 | } 514 | } 515 | else if (d.status === 'interrupted' || d.status === 'interrupt_acknowledged') { 516 | console.log("Server confirmed interrupt - clearing audio"); 517 | clearAudioPlayback(); 518 | 519 | setTimeout(() => { 520 | console.log("Resetting interrupt flags after server confirmation"); 521 | interruptRequested = false; 522 | interruptInProgress = false; 523 | }, 300); 524 | } 525 | break; 526 | 527 | case 'status': 528 | if (d.message === 'Thinking...') { 529 | showVoiceCircle(); 530 | 531 | interruptRequested = false; 532 | interruptInProgress = false; 533 | activeGenId = 0; 534 | } 535 | break; 536 | 537 | case 'error': 538 | showNotification(d.message, 'error'); 539 | hideVoiceCircle(); 540 | break; 541 | 542 | case 'vad_status': 543 | if (d.status === 'speech_started') { 544 | console.log(`[VAD] speech_started | should_interrupt=${d.should_interrupt}`); 545 | 546 | if (d.should_interrupt && isAudioCurrentlyPlaying) { 547 | console.log('[VAD] confirmed – sending interrupt'); 548 | requestInterrupt(); 549 | } else { 550 | console.log('[VAD] ignored (echo / early AI audio)'); 551 | } 552 | } 553 | break; 554 | } 555 | } 556 | 557 | function queueAudioForPlayback(arr, sr, genId = 0) { 558 | if (activeGenId !== 0 && genId !== activeGenId) { 559 | console.log(`Stale chunk ignored (genId mismatch): ${genId} vs ${activeGenId}`); 560 | return; 561 | } 562 | 563 | // Don't queue if interrupting 564 | if (interruptRequested || interruptInProgress) { 565 | console.log("Interrupt active - skipping audio chunk"); 566 | return; 567 | } 568 | 569 | console.log("Queueing audio chunk for playback"); 570 | audioPlaybackQueue.push({arr, sr, genId}); 571 | 572 | if (!isAudioCurrentlyPlaying) { 573 | 
console.log("▶Starting audio playback"); 574 | processAudioPlaybackQueue(); 575 | } 576 | } 577 | 578 | function queueAudioForPlayback(arr, sr, genId = 0) { 579 | // Extra logging for the first audio chunk 580 | if (!isAudioCurrentlyPlaying) { 581 | console.log("Queueing first audio chunk", 582 | "interruptRequested:", interruptRequested, 583 | "interruptInProgress:", interruptInProgress); 584 | } 585 | 586 | if (!isAudioCurrentlyPlaying && audioPlaybackQueue.length === 0) { 587 | console.log("First audio chunk - forcing clear of interrupt flags"); 588 | interruptRequested = false; 589 | interruptInProgress = false; 590 | } 591 | 592 | // Don't queue audio from a different generation than our active one 593 | if (activeGenId !== 0 && genId !== activeGenId) { 594 | console.log(`Stale chunk ignored (genId mismatch): ${genId} vs ${activeGenId}`); 595 | return; 596 | } 597 | 598 | // Don't queue if interrupting - BUT CHECK AGAIN THAT FLAGS ARE VALID 599 | if (interruptRequested || interruptInProgress) { 600 | console.log("Interrupt active - skipping audio chunk"); 601 | return; 602 | } 603 | 604 | console.log("Queueing audio chunk for playback"); 605 | audioPlaybackQueue.push({arr, sr, genId}); 606 | 607 | if (!isAudioCurrentlyPlaying) { 608 | console.log("STARTING AUDIO PLAYBACK - FIRST CHUNK"); 609 | processAudioPlaybackQueue(); 610 | } 611 | } 612 | 613 | 614 | // Modified to ensure first audio actually plays 615 | function processAudioPlaybackQueue() { 616 | if (!isAudioCurrentlyPlaying && audioPlaybackQueue.length > 0) { 617 | console.log("Starting first audio chunk - force clearing interrupt flags"); 618 | interruptRequested = false; 619 | interruptInProgress = false; 620 | } 621 | 622 | // Double-check interrupt flags AFTER clearling them 623 | if (interruptRequested || interruptInProgress) { 624 | console.log("Interrupt active - not processing audio queue"); 625 | isAudioCurrentlyPlaying = false; 626 | hideVoiceCircle(); 627 | return; 628 | } 629 | 630 | // Check if queue is empty 631 | if (!audioPlaybackQueue.length) { 632 | console.log("📭 Audio queue empty, stopping playback"); 633 | isAudioCurrentlyPlaying = false; 634 | hideVoiceCircle(); 635 | currentAudioSource = null; 636 | return; 637 | } 638 | 639 | // Enable the interrupt button when audio is playing 640 | const interruptBtn = document.getElementById('interruptBtn'); 641 | if (interruptBtn) { 642 | interruptBtn.disabled = false; 643 | interruptBtn.classList.remove('opacity-50'); 644 | } 645 | 646 | console.log("Processing next audio chunk"); 647 | isAudioCurrentlyPlaying = true; 648 | 649 | // Get the genId from the chunk 650 | const {arr, sr, genId} = audioPlaybackQueue.shift(); 651 | 652 | // Skip if it's a stale chunk 653 | if (activeGenId !== 0 && genId !== activeGenId) { 654 | console.log(`Skipping stale chunk playback (gen ${genId} vs active ${activeGenId})`); 655 | processAudioPlaybackQueue(); // Continue with next chunk 656 | return; 657 | } 658 | 659 | playAudioChunk(arr, sr) 660 | .then(() => { 661 | // Check interrupt status again after playback 662 | if (!interruptRequested && !interruptInProgress) { 663 | processAudioPlaybackQueue(); 664 | } else { 665 | console.log("interrupt active - stopping queue processing"); 666 | isAudioCurrentlyPlaying = false; 667 | hideVoiceCircle(); 668 | } 669 | }) 670 | .catch(err => { 671 | console.error("Error in audio playback:", err); 672 | isAudioCurrentlyPlaying = false; 673 | hideVoiceCircle(); 674 | 675 | // Try to continue with next chunk despite errors 676 | setTimeout(() 
=> { 677 | if (audioPlaybackQueue.length > 0 && !interruptRequested) { 678 | processAudioPlaybackQueue(); 679 | } 680 | }, 200); 681 | }); 682 | } 683 | 684 | async function playAudioChunk(audioArr, sampleRate) { 685 | // Skip playback if interrupt was requested 686 | if (interruptRequested || interruptInProgress) { 687 | console.log("Interrupt active - not playing audio chunk"); 688 | return Promise.resolve(); 689 | } 690 | 691 | try { 692 | // Ensure we have a valid audio context 693 | if (!audioContext) { 694 | audioContext = new (window.AudioContext || window.webkitAudioContext)(); 695 | window.audioContext = audioContext; 696 | } 697 | 698 | // Make sure context is resumed 699 | if (audioContext.state === 'suspended') { 700 | await audioContext.resume(); 701 | } 702 | 703 | const buf = audioContext.createBuffer(1, audioArr.length, sampleRate); 704 | buf.copyToChannel(new Float32Array(audioArr), 0); 705 | 706 | const src = audioContext.createBufferSource(); 707 | src.buffer = buf; 708 | 709 | // Store reference to current source for potential interruption 710 | currentAudioSource = src; 711 | 712 | const an = audioContext.createAnalyser(); 713 | an.fftSize = 256; 714 | src.connect(an); 715 | an.connect(audioContext.destination); 716 | src.start(); 717 | 718 | console.log("🎵 Started playing audio chunk"); 719 | 720 | const arr = new Uint8Array(an.frequencyBinCount); 721 | const circle = document.getElementById('voice-circle'); 722 | 723 | // Animation function that respects interruption 724 | function pump() { 725 | // Stop animation if source is no longer current or interrupt requested 726 | if (src !== currentAudioSource || interruptRequested || interruptInProgress) { 727 | return; 728 | } 729 | 730 | try { 731 | an.getByteFrequencyData(arr); 732 | const avg = arr.reduce((a,b) => a+b, 0) / arr.length; 733 | if (circle) { 734 | circle.style.setProperty('--dynamic-scale', (1+avg/255*1.5).toFixed(3)); 735 | } 736 | } catch (e) { 737 | console.warn("Error in animation pump:", e); 738 | return; 739 | } 740 | 741 | if (src.playbackState !== src.FINISHED_STATE) { 742 | requestAnimationFrame(pump); 743 | } 744 | } 745 | pump(); 746 | 747 | return new Promise(resolve => { 748 | src.onended = () => { 749 | // Only resolve if this is still the current source and no interrupt 750 | if (src === currentAudioSource && !interruptRequested && !interruptInProgress) { 751 | resolve(); 752 | } else { 753 | resolve(); // Still resolve to maintain chain 754 | } 755 | }; 756 | }); 757 | } catch (error) { 758 | console.error("Error playing audio chunk:", error); 759 | return Promise.resolve(); // Resolve anyway to keep chain going 760 | } 761 | } 762 | 763 | async function startRecording() { 764 | if (isRecording) return; 765 | try { 766 | const constraints = { 767 | audio: selectedMicId ? 
{deviceId:{exact:selectedMicId}} : true 768 | }; 769 | micStream = await navigator.mediaDevices.getUserMedia(constraints); 770 | 771 | if (!audioContext) audioContext = new (AudioContext||webkitAudioContext)(); 772 | const src = audioContext.createMediaStreamSource(micStream); 773 | const proc = audioContext.createScriptProcessor(4096,1,1); 774 | src.connect(proc); proc.connect(audioContext.destination); 775 | 776 | proc.onaudioprocess = e => { 777 | const samples = Array.from(e.inputBuffer.getChannelData(0)); 778 | if (ws && ws.readyState === WebSocket.OPEN) { 779 | try { 780 | ws.send(JSON.stringify({ 781 | type:'audio', 782 | audio:samples, 783 | sample_rate:audioContext.sampleRate, 784 | session_id:SESSION_ID 785 | })); 786 | } catch (error) { 787 | console.error("Error sending audio data:", error); 788 | stopRecording(); 789 | } 790 | } 791 | }; 792 | 793 | window._micProcessor = proc; 794 | isRecording = true; 795 | document.getElementById('micStatus').textContent = 'Listening…'; 796 | showVoiceCircle(); 797 | } catch (err) { 798 | console.error("Microphone access error:", err); 799 | showNotification('Microphone access denied','error'); 800 | } 801 | } 802 | 803 | function stopRecording() { 804 | if (!isRecording) return; 805 | try { 806 | if (window._micProcessor) { 807 | window._micProcessor.disconnect(); 808 | window._micProcessor = null; 809 | } 810 | if (micStream) { 811 | micStream.getTracks().forEach(t => t.stop()); 812 | micStream = null; 813 | } 814 | } catch(e) { 815 | console.warn("Error stopping recording:", e); 816 | } 817 | isRecording = false; 818 | 819 | const micStatus = document.getElementById('micStatus'); 820 | if (micStatus) { 821 | micStatus.textContent = 'Click to speak'; 822 | } 823 | hideVoiceCircle(); 824 | } 825 | 826 | async function setupChatUI() { 827 | document.documentElement.classList.add('bg-gray-950'); 828 | document.documentElement.style.backgroundColor = '#030712'; 829 | 830 | createPermanentVoiceCircle(); 831 | connectWebSocket(); 832 | 833 | initAudioLevelsChart(); 834 | 835 | const txt = document.getElementById('textInput'); 836 | const btn = document.getElementById('sendTextBtn'); 837 | 838 | // Setup enhanced interrupt button 839 | const interruptBtn = document.createElement('button'); 840 | interruptBtn.id = 'interruptBtn'; 841 | interruptBtn.className = 'px-3 py-2 ml-2 bg-red-600 text-white rounded hover:bg-red-700 flex items-center transition duration-150'; 842 | interruptBtn.innerHTML = ' Stop'; 843 | interruptBtn.onclick = (e) => { 844 | e.preventDefault(); 845 | try { 846 | requestInterrupt(); 847 | interruptBtn.classList.add('bg-red-800', 'scale-95'); 848 | setTimeout(() => interruptBtn.classList.remove('bg-red-800', 'scale-95'), 150); 849 | } catch (error) { 850 | console.error("Error in interrupt button handler:", error); 851 | } 852 | }; 853 | interruptBtn.title = "Stop AI speech (Space or Esc)"; 854 | interruptBtn.disabled = true; // Disabled by default 855 | interruptBtn.classList.add('opacity-50', 'cursor-not-allowed'); 856 | 857 | if (btn && btn.parentElement) { 858 | btn.parentElement.appendChild(interruptBtn); 859 | } 860 | 861 | // Add debug button for easier debugging of interrupt issues 862 | const debugBtn = document.createElement('button'); 863 | debugBtn.innerText = "Debug Audio"; 864 | debugBtn.className = "px-3 py-2 ml-2 bg-blue-600 text-white rounded text-xs"; 865 | debugBtn.onclick = () => { 866 | console.log("- Debug info:"); 867 | console.log("- Audio playing:", isAudioCurrentlyPlaying); 868 | console.log("- 
Interrupt requested:", interruptRequested); 869 | console.log("- Interrupt in progress:", interruptInProgress); 870 | console.log("- Current source:", currentAudioSource); 871 | console.log("- Queue length:", audioPlaybackQueue.length); 872 | console.log("- Audio context state:", audioContext?.state); 873 | console.log("- Active generation ID:", activeGenId); 874 | console.log("- Last seen generation ID:", lastSeenGenId); 875 | console.log("- WebSocket state:", ws ? ws.readyState : "no websocket"); 876 | showNotification("Debug info in console", "info"); 877 | }; 878 | 879 | if (btn && btn.parentElement) { 880 | btn.parentElement.appendChild(debugBtn); 881 | } 882 | 883 | // Run the update function periodically 884 | setInterval(() => { 885 | const interruptBtn = document.getElementById('interruptBtn'); 886 | if (interruptBtn) { 887 | if (isAudioCurrentlyPlaying && !interruptRequested && !interruptInProgress) { 888 | interruptBtn.disabled = false; 889 | interruptBtn.classList.remove('opacity-50', 'cursor-not-allowed'); 890 | } else { 891 | interruptBtn.disabled = true; 892 | interruptBtn.classList.add('opacity-50', 'cursor-not-allowed'); 893 | } 894 | } 895 | }, 300); 896 | 897 | if (btn) { 898 | btn.onclick = () => { 899 | try { 900 | sendTextMessage(txt.value); 901 | } catch (error) { 902 | console.error("Error in send button handler:", error); 903 | } 904 | }; 905 | } 906 | 907 | if (txt) { 908 | txt.addEventListener('keydown', e => { 909 | if (e.key === 'Enter' && !e.shiftKey) { 910 | e.preventDefault(); 911 | try { 912 | sendTextMessage(txt.value); 913 | } catch (error) { 914 | console.error("Error in text input handler:", error); 915 | } 916 | } 917 | }); 918 | } 919 | 920 | const micBtn = document.getElementById('micToggleBtn'); 921 | if (micBtn) { 922 | micBtn.addEventListener('click', () => { 923 | try { 924 | if (isRecording) stopRecording(); 925 | else startRecording(); 926 | } catch (error) { 927 | console.error("Error in mic button handler:", error); 928 | } 929 | }); 930 | } 931 | 932 | // Add event listeners to detect keyboard interruptions 933 | document.addEventListener('keydown', e => { 934 | // Allow space or escape to interrupt 935 | if ((e.code === 'Space' || e.code === 'Escape') && isAudioCurrentlyPlaying) { 936 | e.preventDefault(); 937 | try { 938 | requestInterrupt(); 939 | 940 | // Add visual feedback 941 | const interruptBtn = document.getElementById('interruptBtn'); 942 | if (interruptBtn) { 943 | interruptBtn.classList.add('bg-red-800'); 944 | setTimeout(() => { 945 | interruptBtn.classList.remove('bg-red-800'); 946 | }, 200); 947 | } 948 | } catch (error) { 949 | console.error("Error in keyboard interrupt handler:", error); 950 | } 951 | } 952 | }); 953 | 954 | // Initialize audio context 955 | if (!audioContext) { 956 | try { 957 | audioContext = new (window.AudioContext || window.webkitAudioContext)(); 958 | window.audioContext = audioContext; 959 | } catch (error) { 960 | console.error("Error creating audio context:", error); 961 | showNotification("Audio initialization failed. 
Please refresh the page.", "error"); 962 | } 963 | } 964 | 965 | // Try to unlock audio context on user interaction 966 | ['click', 'touchstart', 'keydown'].forEach(ev => 967 | document.addEventListener(ev, function unlock() { 968 | if (audioContext && audioContext.state === 'suspended') { 969 | try { 970 | audioContext.resume(); 971 | } catch (error) { 972 | console.warn("Error resuming audio context:", error); 973 | } 974 | } 975 | document.removeEventListener(ev, unlock); 976 | }) 977 | ); 978 | 979 | console.log("Chat UI ready with enhanced interruption support"); 980 | } 981 | 982 | if (document.readyState === 'loading') { 983 | document.addEventListener('DOMContentLoaded', setupChatUI); 984 | } else { 985 | setupChatUI(); 986 | } 987 | 988 | function initAudioLevelsChart() { 989 | const ctx = document.getElementById('audioLevels'); 990 | if (!ctx) return; 991 | 992 | try { 993 | if (audioLevelsChart) audioLevelsChart.destroy(); 994 | 995 | const grad = ctx.getContext('2d').createLinearGradient(0, 0, 0, 100); 996 | grad.addColorStop(0, 'rgba(79,70,229,.6)'); 997 | grad.addColorStop(1, 'rgba(79,70,229,.1)'); 998 | 999 | audioLevelsChart = new Chart(ctx, { 1000 | type: 'line', 1001 | data: { 1002 | labels: Array(30).fill(''), 1003 | datasets: [{ 1004 | data: Array(30).fill(0), 1005 | backgroundColor: grad, 1006 | borderColor: 'rgba(99,102,241,1)', 1007 | borderWidth: 2, 1008 | tension: .4, 1009 | fill: true, 1010 | pointRadius: 0 1011 | }] 1012 | }, 1013 | options: { 1014 | animation: false, 1015 | responsive: true, 1016 | scales: { 1017 | y: { 1018 | beginAtZero: true, 1019 | max: 100, 1020 | ticks: {display: false}, 1021 | grid: {color: 'rgba(255,255,255,.1)'} 1022 | }, 1023 | x: {display: false, grid: {display: false}} 1024 | }, 1025 | plugins: { 1026 | legend: {display: false}, 1027 | tooltip: {enabled: false} 1028 | }, 1029 | elements: {point: {radius: 0}} 1030 | } 1031 | }); 1032 | } catch (error) { 1033 | console.error("Error initializing audio chart:", error); 1034 | } 1035 | } -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import math 3 | import os 4 | from typing import List, Tuple, Generator as PyGenerator, Optional, Callable 5 | import time 6 | import queue 7 | import threading 8 | import platform 9 | from typing_extensions import OrderedDict 10 | import wave 11 | import numpy as np 12 | import torch 13 | import torchaudio 14 | from huggingface_hub import hf_hub_download 15 | from models import Model, ModelArgs 16 | from moshi.models import loaders 17 | from tokenizers.processors import TemplateProcessing 18 | from transformers import AutoTokenizer 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | @dataclass 24 | class Segment: 25 | speaker: int 26 | text: str 27 | sample_rate = 24_000 28 | audio: torch.Tensor 29 | 30 | 31 | def load_llama3_tokenizer(): 32 | """ 33 | https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992 34 | """ 35 | tokenizer_name = "unsloth/Llama-3.2-1B" 36 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 37 | bos = tokenizer.bos_token 38 | eos = tokenizer.eos_token 39 | tokenizer._tokenizer.post_processor = TemplateProcessing( 40 | single=f"{bos}:0 $A:0 {eos}:0", 41 | pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1", 42 | special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)], 43 | ) 44 | 
45 | return tokenizer 46 | 47 | 48 | class Generator: 49 | def __init__(self, model: Model): 50 | self._model = model 51 | self._model.setup_caches(1) 52 | 53 | self._text_tokenizer = load_llama3_tokenizer() 54 | device = next(model.parameters()).device 55 | 56 | mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME) 57 | mimi = loaders.get_mimi(mimi_weight, device=device) 58 | 59 | num_codebooks = model.config.audio_num_codebooks 60 | mimi.set_num_codebooks(num_codebooks) 61 | self._num_codebooks = num_codebooks 62 | self._audio_tokenizer = mimi 63 | 64 | self.sample_rate = mimi.sample_rate 65 | self.device = device 66 | 67 | self._stream_buffer_size = 20 68 | self.max_seq_len = 2048 69 | self._cache = OrderedDict() 70 | self._text_token_cache = {} 71 | torch.set_num_threads(16) 72 | torch.cuda.set_per_process_memory_fraction(0.95) 73 | 74 | def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]: 75 | """ 76 | Tokenize text segment with caching optimization for reduced latency. 77 | """ 78 | # Check cache first 79 | cache_key = f"{speaker}:{text}" 80 | if not hasattr(self, '_text_token_cache'): 81 | self._text_token_cache = {} 82 | 83 | if cache_key in self._text_token_cache: 84 | return self._text_token_cache[cache_key] 85 | 86 | text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}") 87 | text_frame = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.long, device=self.device) 88 | text_frame_mask = torch.zeros(len(text_tokens), self._num_codebooks+1, dtype=torch.bool, device=self.device) 89 | text_frame[:, -1] = torch.tensor(text_tokens, device=self.device) 90 | text_frame_mask[:, -1] = True 91 | 92 | frame_tokens = [text_frame] 93 | frame_masks = [text_frame_mask] 94 | 95 | result = (torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)) 96 | 97 | self._text_token_cache[cache_key] = result 98 | 99 | return result 100 | 101 | 102 | def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 103 | 104 | frame_tokens = [] 105 | frame_masks = [] 106 | 107 | # (K, T) 108 | audio = audio.to(self.device) 109 | audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0] 110 | 111 | # Limit to the number of codebooks set in MIMI 112 | audio_tokens = audio_tokens[:self._num_codebooks, :] 113 | 114 | # add EOS frame 115 | eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device) 116 | audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1) 117 | 118 | audio_frame = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).long().to(self.device) 119 | audio_frame_mask = torch.zeros(audio_tokens.size(1), self._num_codebooks+1).bool().to(self.device) 120 | audio_frame[:, :self._num_codebooks] = audio_tokens.transpose(0, 1) 121 | audio_frame_mask[:, :self._num_codebooks] = True 122 | 123 | frame_tokens.append(audio_frame) 124 | frame_masks.append(audio_frame_mask) 125 | 126 | return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0) 127 | 128 | def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]: 129 | text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker) 130 | audio_tokens, audio_masks = self._tokenize_audio(segment.audio) 131 | 132 | total_len = text_tokens.size(0) + audio_tokens.size(0) 133 | 134 | if total_len > self.max_seq_len: 135 | overflow = total_len - self.max_seq_len 136 | 137 | if text_tokens.size(0) > overflow: 138 | text_tokens = text_tokens[overflow:] 139 | text_masks = 
text_masks[overflow:] 140 | else: 141 | audio_overflow = overflow - text_tokens.size(0) 142 | text_tokens = text_tokens[0:0] 143 | text_masks = text_masks[0:0] 144 | audio_tokens = audio_tokens[audio_overflow:] 145 | audio_masks = audio_masks[audio_overflow:] 146 | 147 | return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0) 148 | 149 | @torch.inference_mode() 150 | def _decode_frames(self, frames): 151 | if not frames: 152 | return torch.tensor([]) 153 | 154 | # Only use first N codebooks for faster decoding 155 | frames_reduced = [frame[:, :self._num_codebooks//2] for frame in frames] 156 | audio = self._audio_tokenizer.decode(torch.stack(frames_reduced).permute(1, 2, 0)).squeeze(0).squeeze(0) 157 | return audio 158 | 159 | @torch.inference_mode() 160 | def generate_stream( 161 | self, 162 | text: str, 163 | speaker: int, 164 | context: List[Segment], 165 | max_audio_length_ms: float = 90_000, 166 | temperature: float = 0.7, 167 | topk: int = 30, 168 | on_chunk_generated: Optional[Callable[[torch.Tensor], None]] = None, 169 | ): 170 | """ 171 | Generate audio in a streaming fashion, optimized for lower latency to first chunk. 172 | """ 173 | if torch.cuda.is_available(): 174 | torch.backends.cuda.matmul.allow_tf32 = True 175 | torch.backends.cudnn.benchmark = True 176 | torch.cuda.empty_cache() 177 | torch.cuda.synchronize() 178 | 179 | self._model.reset_caches() 180 | 181 | max_generation_len = int(max_audio_length_ms / 80) 182 | 183 | tokens, tokens_mask = [], [] 184 | 185 | initial_batch_size = 20 186 | normal_batch_size = 20 187 | initial_buffer_size = 20 188 | normal_buffer_size = 20 189 | 190 | batch_size = initial_batch_size 191 | buffer_size = initial_buffer_size 192 | first_chunk_delivered = False 193 | 194 | context_start = time.time() 195 | if context: 196 | for segment in context: 197 | segment_tokens, segment_tokens_mask = self._tokenize_segment(segment) 198 | tokens.append(segment_tokens) 199 | tokens_mask.append(segment_tokens_mask) 200 | 201 | gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker) 202 | tokens.append(gen_segment_tokens) 203 | tokens_mask.append(gen_segment_tokens_mask) 204 | 205 | prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device) 206 | prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device) 207 | 208 | max_seq_len = 2048 209 | if prompt_tokens.size(0) > max_seq_len: 210 | prompt_tokens = prompt_tokens[-max_seq_len:] 211 | prompt_tokens_mask = prompt_tokens_mask[-max_seq_len:] 212 | 213 | curr_tokens = prompt_tokens.unsqueeze(0) 214 | curr_tokens_mask = prompt_tokens_mask.unsqueeze(0) 215 | curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device) 216 | 217 | expected_frame_count = buffer_size 218 | frame_buffer = [] 219 | 220 | zeros_1_1 = torch.zeros(1, 1).long().to(self.device) 221 | zeros_mask_1_1 = torch.zeros(1, 1).bool().to(self.device) 222 | 223 | def update_tokens(sample): 224 | nonlocal curr_tokens, curr_tokens_mask, curr_pos 225 | ones = torch.ones_like(sample).bool() 226 | curr_tokens = torch.cat([sample, zeros_1_1], dim=1).unsqueeze(1) 227 | curr_tokens_mask = torch.cat([ones, zeros_mask_1_1], dim=1).unsqueeze(1) 228 | curr_pos = curr_pos[:, -1:] + 1 229 | 230 | with self._audio_tokenizer.streaming(1): 231 | i = 0 232 | generation_start = time.time() 233 | 234 | while i < max_generation_len: 235 | batch_end = min(i + batch_size, max_generation_len) 236 | batch_size_actual = batch_end - i 237 | 238 | 
batch_samples = [] 239 | 240 | for _ in range(batch_size_actual): 241 | with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): 242 | sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk) 243 | if torch.cuda.is_available() and hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available"): 244 | try: 245 | torch.cuda.synchronize() # Force sync before checking 246 | if sample.numel() == 0 or torch.isnan(sample).any(): 247 | print("Warning: Generated empty or NaN sample, stopping generation") 248 | break 249 | except: 250 | print("Error checking tensor, stopping generation") 251 | break 252 | if torch.all(sample == 0): 253 | break 254 | 255 | batch_samples.append(sample) 256 | update_tokens(sample) 257 | 258 | if not batch_samples: 259 | break 260 | 261 | frame_buffer.extend(batch_samples) 262 | i += len(batch_samples) 263 | 264 | if len(frame_buffer) >= buffer_size: 265 | frames_to_process = frame_buffer[:expected_frame_count] 266 | 267 | # If we don't have enough frames, pad with zeros to match expected shape 268 | if len(frames_to_process) < expected_frame_count: 269 | # Create padding frames (zeros) 270 | padding_frames = [ 271 | torch.zeros_like(frames_to_process[0]) 272 | for _ in range(expected_frame_count - len(frames_to_process)) 273 | ] 274 | 275 | # Combine actual frames with padding 276 | frames_to_process = frames_to_process + padding_frames 277 | 278 | frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0) 279 | audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0) 280 | 281 | # Keep remaining frames for next iteration 282 | frame_buffer = frame_buffer[expected_frame_count:] 283 | 284 | # Process and yield the chunk 285 | cpu_chunk = audio_chunk.cpu() 286 | if on_chunk_generated: 287 | on_chunk_generated(cpu_chunk) 288 | 289 | # After first chunk is delivered, switch to normal batch and buffer sizes 290 | if not first_chunk_delivered: 291 | batch_size = normal_batch_size 292 | buffer_size = normal_buffer_size 293 | expected_frame_count = buffer_size 294 | first_chunk_delivered = True 295 | 296 | yield cpu_chunk 297 | 298 | # Occasionally print progress and sync GPU 299 | if i >= 100 and (i % 100 == 0): 300 | if torch.cuda.is_available(): 301 | torch.cuda.synchronize() 302 | print(f"Generated {i} frames ({i * 0.08:.2f}s of audio)") 303 | 304 | # Process any remaining frames 305 | if frame_buffer: 306 | # Pad frame buffer if necessary 307 | if len(frame_buffer) < expected_frame_count: 308 | padding_frames = [ 309 | torch.zeros_like(frame_buffer[0]) 310 | for _ in range(expected_frame_count - len(frame_buffer)) 311 | ] 312 | frames_to_process = frame_buffer + padding_frames 313 | else: 314 | # Otherwise take as many frames as possible that are a multiple of expected_frame_count 315 | frames_multiple = (len(frame_buffer) // expected_frame_count) * expected_frame_count 316 | frames_to_process = frame_buffer[:frames_multiple] 317 | 318 | frames_stacked = torch.stack(frames_to_process).permute(1, 2, 0) 319 | audio_chunk = self._audio_tokenizer.decode(frames_stacked).squeeze(0).squeeze(0) 320 | 321 | # Determine actual audio length (before padding) 322 | actual_frames_percentage = min(len(frame_buffer), expected_frame_count) / expected_frame_count 323 | actual_samples = int(audio_chunk.shape[0] * actual_frames_percentage) 324 | 325 | # Return only the non-padded portion of audio if we added padding 326 | if len(frame_buffer) < expected_frame_count: 327 | audio_chunk = 
audio_chunk[:actual_samples] 328 | 329 | cpu_chunk = audio_chunk.cpu() 330 | if on_chunk_generated: 331 | on_chunk_generated(cpu_chunk) 332 | yield cpu_chunk 333 | 334 | # Print final performance metrics 335 | if torch.cuda.is_available(): 336 | torch.cuda.synchronize() 337 | total_time = time.time() - generation_start 338 | frames_generated = i 339 | audio_seconds = frames_generated * 0.08 340 | rtf = total_time / audio_seconds if audio_seconds > 0 else float('inf') 341 | print(f"Total time: {total_time:.2f}s") 342 | print(f"Generated {frames_generated} frames ({audio_seconds:.2f}s of audio)") 343 | print(f"Real-time factor: {rtf:.3f}x (target: <1.0)") 344 | 345 | @torch.inference_mode() 346 | def generate( 347 | self, 348 | text: str, 349 | speaker: int, 350 | context: List[Segment], 351 | max_audio_length_ms: float = 90_000, 352 | temperature: float = 0.8, 353 | topk: int = 40, 354 | stream: bool = False, 355 | output_file: Optional[str] = None, 356 | ): 357 | """ 358 | Generate audio with optional streaming and file output. 359 | 360 | Args: 361 | text: Text to generate audio for 362 | speaker: Speaker ID 363 | context: List of context segments 364 | max_audio_length_ms: Maximum audio length in milliseconds 365 | temperature: Sampling temperature 366 | topk: Top-k sampling parameter 367 | stream: Whether to use streaming generation 368 | output_file: If provided and stream=True, output will be saved to this file 369 | 370 | Returns: 371 | torch.Tensor: Generated audio tensor 372 | """ 373 | if stream: 374 | if output_file: 375 | # Setup streaming to file 376 | write_chunk, close_wav = stream_audio_to_wav(output_file, self.sample_rate) 377 | 378 | # Collect chunks while streaming to file 379 | audio_chunks = [] 380 | t1 = time.time() 381 | 382 | for i, chunk in enumerate(self.generate_stream( 383 | text, speaker, context, max_audio_length_ms, temperature, topk 384 | )): 385 | # Write to file 386 | write_chunk(chunk) 387 | # Store for return value 388 | audio_chunks.append(chunk) 389 | 390 | # Occasionally print progress 391 | if i % 5 == 0: 392 | print(f"Part {i+1} available after {time.time() - t1:.4f}s") 393 | t1 = time.time() 394 | 395 | # Close file 396 | close_wav() 397 | print(f"Streaming complete, WAV file saved to {output_file}") 398 | else: 399 | # Just collect chunks without file output 400 | audio_chunks = [] 401 | for chunk in self.generate_stream(text, speaker, context, max_audio_length_ms, temperature, topk): 402 | audio_chunks.append(chunk) 403 | 404 | if not audio_chunks: 405 | return torch.tensor([]) 406 | return torch.cat(audio_chunks) 407 | 408 | # Non-streaming generation remains unchanged 409 | if torch.cuda.is_available(): 410 | torch.cuda.empty_cache() 411 | 412 | self._model.reset_caches() 413 | 414 | max_generation_len = int(max_audio_length_ms / 80) 415 | tokens, tokens_mask = [], [] 416 | 417 | for segment in context: 418 | segment_tokens, segment_tokens_mask = self._tokenize_segment(segment) 419 | tokens.append(segment_tokens) 420 | tokens_mask.append(segment_tokens_mask) 421 | 422 | gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker) 423 | tokens.append(gen_segment_tokens) 424 | tokens_mask.append(gen_segment_tokens_mask) 425 | 426 | prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device) 427 | prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device) 428 | 429 | max_seq_len = 2048 430 | if prompt_tokens.size(0) > max_seq_len: 431 | prompt_tokens = prompt_tokens[-max_seq_len:] 432 | 
prompt_tokens_mask = prompt_tokens_mask[-max_seq_len:] 433 | 434 | curr_tokens = prompt_tokens.unsqueeze(0) 435 | curr_tokens_mask = prompt_tokens_mask.unsqueeze(0) 436 | curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device) 437 | 438 | samples = [] 439 | with self._audio_tokenizer.streaming(1): 440 | for _ in range(max_generation_len): 441 | sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk) 442 | if torch.all(sample == 0): 443 | break 444 | samples.append(sample) 445 | 446 | curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1) 447 | curr_tokens_mask = torch.cat( 448 | [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1 449 | ).unsqueeze(1) 450 | curr_pos = curr_pos[:, -1:] + 1 451 | 452 | if not samples: 453 | return torch.tensor([]) 454 | 455 | return self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0) 456 | 457 | class AudioStreamWriter: 458 | """ 459 | Helper class for writing streaming audio to a file. 460 | """ 461 | def __init__(self, filename, sample_rate): 462 | self.filename = filename 463 | self.sample_rate = sample_rate 464 | self.audio_chunks = [] 465 | self.lock = threading.Lock() 466 | self.queue = queue.Queue() 467 | self.running = True 468 | 469 | # Start background writer thread 470 | self.writer_thread = threading.Thread(target=self._writer_worker, daemon=True) 471 | self.writer_thread.start() 472 | 473 | def _writer_worker(self): 474 | """Background thread that handles audio chunk processing""" 475 | buffer_chunks = [] 476 | last_flush_time = time.time() 477 | 478 | while self.running or not self.queue.empty(): 479 | try: 480 | # Get chunk with timeout to allow for regular checks 481 | chunk = self.queue.get(timeout=0.2) 482 | buffer_chunks.append(chunk) 483 | 484 | # Periodically flush the buffer to the main list 485 | current_time = time.time() 486 | if len(buffer_chunks) >= 10 or (current_time - last_flush_time > 2.0 and buffer_chunks): 487 | with self.lock: 488 | self.audio_chunks.extend(buffer_chunks) 489 | buffer_chunks = [] 490 | last_flush_time = current_time 491 | 492 | except queue.Empty: 493 | # If queue is empty but we have pending chunks, add them 494 | if buffer_chunks: 495 | with self.lock: 496 | self.audio_chunks.extend(buffer_chunks) 497 | buffer_chunks = [] 498 | last_flush_time = time.time() 499 | 500 | # Final flush of any remaining chunks 501 | if buffer_chunks: 502 | with self.lock: 503 | self.audio_chunks.extend(buffer_chunks) 504 | 505 | def add_chunk(self, chunk): 506 | """Add an audio chunk to the buffer queue without blocking""" 507 | try: 508 | self.queue.put(chunk, timeout=0.1) 509 | except queue.Full: 510 | # If queue is full, add directly to avoid losing data 511 | with self.lock: 512 | self.audio_chunks.append(chunk) 513 | 514 | def write_file(self): 515 | """Write all collected audio chunks to file and clean up""" 516 | # Signal the background thread to stop 517 | self.running = False 518 | # Wait for the thread to finish with a timeout 519 | self.writer_thread.join(timeout=3.0) 520 | 521 | with self.lock: 522 | if not self.audio_chunks: 523 | return 524 | 525 | # Concatenate all chunks 526 | audio = torch.cat(self.audio_chunks) 527 | # Save to file 528 | torchaudio.save(self.filename, audio.unsqueeze(0).cpu(), self.sample_rate) 529 | 530 | from safetensors.torch import load_file 531 | import os 532 | import torch 533 | from models import Model, 
ModelArgs 534 | from generator import Generator 535 | 536 | def load_csm_1b_local(model_path: str, device: str = "cuda", audio_num_codebooks: int = 32): 537 | """ 538 | Load the CSM-1B model from a local checkpoint with extreme optimizations and warmup. 539 | """ 540 | import torch 541 | import platform 542 | from functools import lru_cache 543 | from generator import Generator, Model, ModelArgs 544 | 545 | # Enable all CUDA optimizations 546 | torch.backends.cuda.matmul.allow_tf32 = True 547 | if hasattr(torch.backends.cuda, 'enable_flash_sdp'): 548 | torch.backends.cuda.enable_flash_sdp(True) 549 | torch.backends.cudnn.benchmark = True 550 | torch.backends.cudnn.enabled = True 551 | 552 | print(f"Loading CSM-1B model from local checkpoint '{model_path}' with extreme optimizations...") 553 | 554 | if torch.cuda.is_available(): 555 | torch.cuda.empty_cache() 556 | torch.cuda.synchronize() 557 | 558 | config = ModelArgs( 559 | backbone_flavor="llama-1B", 560 | decoder_flavor="llama-100M", 561 | text_vocab_size=128256, 562 | audio_vocab_size=2051, 563 | audio_num_codebooks=audio_num_codebooks, 564 | ) 565 | 566 | model = Model.from_pretrained(model_path) 567 | model.eval() 568 | 569 | dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 570 | 571 | model.backbone = torch.compile(model.backbone,mode='reduce-overhead', fullgraph=True, backend='inductor') 572 | model.decoder = torch.compile(model.decoder,mode='reduce-overhead', fullgraph=True, backend='inductor') 573 | 574 | model.to(device=device, dtype=dtype) 575 | 576 | print("Model compilation complete. Creating generator...") 577 | 578 | generator = Generator(model) 579 | generator._stream_buffer_size = 20 580 | 581 | # Setup tokenization caching 582 | generator._tokenization_cache = {} 583 | 584 | original_tokenize_text = generator._tokenize_text_segment 585 | 586 | @lru_cache(maxsize=2048) 587 | def cached_tokenize_text_segment(text_str, speaker_int): 588 | return original_tokenize_text(text_str, speaker_int) 589 | 590 | generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker) 591 | 592 | # Perform warmup 593 | warmup_generator(generator) 594 | 595 | return generator 596 | 597 | def warmup_generator(gen: Generator, warmup_text: str = "Hello, this is a comprehensive warmup text that will exercise the model's generation capabilities.", speaker_id: int = 0): 598 | """ 599 | Perform an extremely aggressive warmup to drastically reduce first-generation latency. 
600 | """ 601 | print("Starting maximum-intensity warmup sequence...") 602 | 603 | # Directly access and optimize the model's internal state 604 | if hasattr(gen._model, 'backbone') and hasattr(gen._model.backbone, 'positional_embedding'): 605 | # Force calculation of position embeddings to ensure they're cached 606 | with torch.inference_mode(): 607 | positions = torch.arange(0, 2048).to(gen.device) 608 | _ = gen._model.backbone.positional_embedding(positions) 609 | 610 | # Pre-allocate CUDA memory to prevent fragmentation during generation 611 | if torch.cuda.is_available(): 612 | print("Optimizing GPU memory allocation...") 613 | # Try to reserve a large chunk of memory 614 | try: 615 | import math 616 | reserved_memory = [] 617 | # Reserve multiple blocks of different sizes 618 | for size_mb in [128, 256, 512, 256, 128, 64]: 619 | size = int(size_mb * 1024 * 1024 / 4) # Convert MB to float32 elements 620 | tensor_size = int(math.sqrt(size)) 621 | tensor = torch.ones((tensor_size, tensor_size), device=gen.device, dtype=torch.float32) 622 | tensor = tensor * 1.0 # Force allocation 623 | reserved_memory.append(tensor) 624 | torch.cuda.synchronize() 625 | 626 | # Now free the memory 627 | for tensor in reserved_memory: 628 | del tensor 629 | reserved_memory = [] 630 | torch.cuda.empty_cache() 631 | torch.cuda.synchronize() 632 | except Exception as e: 633 | print(f"Memory pre-allocation: {e}") 634 | 635 | # Create multiple dummy audio segments with varying characteristics 636 | print("Creating diverse audio contexts...") 637 | audio_segments = [] 638 | 639 | # Create 3 different audio patterns 640 | for i in range(3): 641 | length = 24000 * (i + 1) # 1s, 2s, 3s 642 | audio = torch.zeros(length).to(gen.device) 643 | 644 | # Add different patterns to each segment 645 | if i == 0: 646 | # Sine wave pattern 647 | import math 648 | t = torch.linspace(0, 8 * math.pi, length).to(gen.device) 649 | audio = torch.sin(t) * 0.1 650 | elif i == 1: 651 | # Random noise pattern 652 | audio = torch.randn(length).to(gen.device) * 0.05 653 | else: 654 | # Pulse pattern 655 | audio[::800] = 0.2 656 | audio[::801] = -0.2 657 | 658 | segment = Segment( 659 | speaker=speaker_id, 660 | text=f"Warmup segment {i+1} with {length/24000:.1f}s of audio.", 661 | audio=audio 662 | ) 663 | audio_segments.append(segment) 664 | 665 | # Force compilation of critical model components 666 | print("Forcing compilation of critical components...") 667 | 668 | # Directly exercise the audio tokenizer with real data 669 | with torch.inference_mode(): 670 | for segment in audio_segments: 671 | # Force tokenization of both text and audio 672 | gen._tokenize_segment(segment) 673 | 674 | # Exercise the model's generation capabilities directly 675 | with torch.inference_mode(): 676 | 677 | # Generate some sample frames to ensure model is compiled 678 | dummy_tokens = torch.ones(1, 10, gen._num_codebooks+1).long().to(gen.device) 679 | dummy_mask = torch.ones(1, 10, gen._num_codebooks+1).bool().to(gen.device) 680 | dummy_pos = torch.arange(0, 10).unsqueeze(0).to(gen.device) 681 | 682 | # Generate multiple frames with different parameters 683 | for temp in [0.6, 0.7, 0.8]: 684 | for topk in [20, 30, 40]: 685 | _ = gen._model.generate_frame(dummy_tokens, dummy_mask, dummy_pos, temp, topk) 686 | 687 | gen._text_token_cache.clear() 688 | 689 | print("Running final generation with exact same setup as a real request...") 690 | 691 | final_text = "This is the final warmup that exactly matches a real generation request." 
692 | 693 | # First tokenize the text - to fill the cache 694 | gen._tokenize_text_segment(final_text, speaker_id) 695 | 696 | try: 697 | # Now run a complete generation with a single context segment 698 | generate_streaming_audio( 699 | generator=gen, 700 | text=final_text, 701 | speaker=speaker_id, 702 | context=[audio_segments[0]], # Just one context segment 703 | output_file="warmup_final.wav", 704 | max_audio_length_ms=6000, 705 | temperature=0.7, 706 | topk=30, 707 | play_audio=False 708 | ) 709 | except Exception as e: 710 | print(f"Final warmup run exception (ignorable): {e}") 711 | 712 | # Force final synchronization and memory optimization 713 | if torch.cuda.is_available(): 714 | print("Final GPU optimization...") 715 | torch.cuda.synchronize() 716 | torch.cuda.empty_cache() 717 | 718 | try: 719 | # Allocate a large tensor to force compaction 720 | large_tensor = torch.empty(int(1e9//4), dtype=torch.float, device=gen.device) 721 | # Immediately delete it 722 | del large_tensor 723 | except RuntimeError: 724 | # Expected if there's not enough memory 725 | pass 726 | 727 | # Final cleanup 728 | torch.cuda.empty_cache() 729 | torch.cuda.synchronize() 730 | 731 | print("Maximum-intensity warmup complete. First generation should now be MUCH faster.") 732 | 733 | def load_csm_1b(device: str = "cuda") -> Generator: 734 | """ 735 | Load the CSM-1B model with extreme optimizations for real-time performance. 736 | """ 737 | # Enable all CUDA optimizations 738 | torch.backends.cuda.matmul.allow_tf32 = True 739 | torch.backends.cuda.enable_flash_sdp(True) 740 | torch.backends.cudnn.benchmark = True 741 | torch.backends.cudnn.enabled = True 742 | 743 | print("Loading CSM-1B model with extreme optimizations for real-time performance...") 744 | 745 | if torch.cuda.is_available(): 746 | torch.cuda.empty_cache() 747 | torch.cuda.synchronize() 748 | 749 | model = Model.from_pretrained("sesame/csm-1b") 750 | 751 | dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 752 | model.backbone = torch.compile(model.backbone,mode='reduce-overhead', fullgraph=True, backend='inductor') 753 | model.decoder = torch.compile(model.decoder,mode='reduce-overhead', fullgraph=True, backend='inductor') 754 | 755 | model.to(device=device, dtype=dtype) 756 | 757 | print("Model compilation complete. Creating generator...") 758 | 759 | generator = Generator(model) 760 | 761 | generator._stream_buffer_size = 20 762 | 763 | 764 | generator._tokenization_cache = {} 765 | 766 | from functools import lru_cache 767 | 768 | # Patch the tokenize method with caching 769 | original_tokenize_text = generator._tokenize_text_segment 770 | 771 | @lru_cache(maxsize=2048) 772 | def cached_tokenize_text_segment(text_str, speaker_int): 773 | return original_tokenize_text(text_str, speaker_int) 774 | 775 | generator._tokenize_text_segment = lambda text, speaker: cached_tokenize_text_segment(text, speaker) 776 | 777 | warmup_generator(generator) 778 | 779 | return generator 780 | 781 | def stream_audio_to_wav(filename, sample_rate): 782 | """ 783 | Initialize a WAV writer for streaming audio chunks. 
784 | 785 | Args: 786 | filename: Output WAV file path 787 | sample_rate: Audio sample rate in Hz 788 | 789 | Returns: 790 | tuple: (write_chunk, close) functions for writing audio data and closing the file 791 | """ 792 | # Create a WAV file with the proper header 793 | wav_file = wave.open(filename, 'wb') 794 | wav_file.setnchannels(1) # Mono 795 | wav_file.setsampwidth(2) # 16-bit 796 | wav_file.setframerate(sample_rate) 797 | 798 | def write_chunk(audio_chunk): 799 | # Convert tensor to numpy and then to int16 PCM format 800 | if isinstance(audio_chunk, torch.Tensor): 801 | # Ensure it's on CPU and detached before converting to numpy 802 | audio_np = audio_chunk.detach().cpu().numpy() 803 | else: 804 | audio_np = audio_chunk 805 | 806 | # Normalize if needed (assuming audio is in [-1, 1] range) 807 | if audio_np.max() <= 1.0 and audio_np.min() >= -1.0: 808 | audio_int = (audio_np * 32767).astype(np.int16) 809 | else: 810 | audio_int = audio_np.astype(np.int16) 811 | 812 | # Write to WAV file 813 | wav_file.writeframes(audio_int.tobytes()) 814 | 815 | def close(): 816 | wav_file.close() 817 | 818 | return write_chunk, close 819 | 820 | def generate_streaming_audio( 821 | generator: Generator, 822 | text: str, 823 | speaker: int, 824 | context: List[Segment], 825 | output_file: str, 826 | max_audio_length_ms: float = 90_000, 827 | temperature: float = 1.0, 828 | topk: int = 50, 829 | play_audio: bool = False, 830 | ): 831 | """ 832 | Generate audio with streaming output and comprehensive timing metrics. 833 | Optimized for reduced first-chunk latency. 834 | """ 835 | # Initialize the streaming WAV writer 836 | write_chunk, close_wav = stream_audio_to_wav(output_file, generator.sample_rate) 837 | 838 | # Set up audio playback if requested 839 | audio_queue = queue.Queue(maxsize=100) if play_audio else None 840 | stop_event = threading.Event() 841 | 842 | if play_audio: 843 | try: 844 | import sounddevice as sd 845 | 846 | # Get available sample rates for default output device to check compatibility 847 | device_info = sd.query_devices(kind='output') 848 | supported_rate = device_info.get('default_samplerate', 44100) 849 | need_resampling = abs(supported_rate - generator.sample_rate) > 100 850 | 851 | if need_resampling: 852 | try: 853 | # Use resampling if sample rate doesn't match 854 | import librosa 855 | print(f"Resampling from {generator.sample_rate}Hz to {int(supported_rate)}Hz for playback") 856 | 857 | def audio_playback_worker(): 858 | while not stop_event.is_set() or not audio_queue.empty(): 859 | try: 860 | chunk = audio_queue.get(timeout=0.5) 861 | if isinstance(chunk, torch.Tensor) and chunk.numel() == 0: 862 | audio_queue.task_done() 863 | continue 864 | 865 | audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk 866 | 867 | # Skip very short chunks (likely noise) 868 | if len(audio_np) < 100: 869 | audio_queue.task_done() 870 | continue 871 | 872 | # Resample to device's supported rate 873 | resampled = librosa.resample( 874 | audio_np, 875 | orig_sr=generator.sample_rate, 876 | target_sr=int(supported_rate) 877 | ) 878 | sd.play(resampled, supported_rate, blocking=True) 879 | # Add a small delay to ensure audio finishes playing 880 | time.sleep(0.05) 881 | audio_queue.task_done() 882 | except queue.Empty: 883 | # If queue empty but not stopping, keep trying 884 | if not stop_event.is_set(): 885 | continue 886 | else: 887 | break 888 | except Exception as e: 889 | print(f"Playback error: {e}") 890 | audio_queue.task_done() 891 | except ImportError: 892 | 
print("Librosa not found. Using direct playback which may cause sample rate warnings.") 893 | need_resampling = False 894 | 895 | if not need_resampling: 896 | def audio_playback_worker(): 897 | while not stop_event.is_set() or not audio_queue.empty(): 898 | try: 899 | chunk = audio_queue.get(timeout=0.5) 900 | if isinstance(chunk, torch.Tensor) and chunk.numel() == 0: 901 | audio_queue.task_done() 902 | continue 903 | 904 | audio_np = chunk.numpy() if isinstance(chunk, torch.Tensor) else chunk 905 | 906 | # Skip very short chunks (likely noise) 907 | if len(audio_np) < 100: 908 | audio_queue.task_done() 909 | continue 910 | 911 | sd.play(audio_np, generator.sample_rate, blocking=True) 912 | # Add a small delay to ensure audio finishes playing 913 | time.sleep(0.05) 914 | audio_queue.task_done() 915 | except queue.Empty: 916 | # If queue empty but not stopping, keep trying 917 | if not stop_event.is_set(): 918 | continue 919 | else: 920 | break 921 | except Exception as e: 922 | print(f"Playback error: {e}") 923 | audio_queue.task_done() 924 | 925 | # Start playback thread 926 | playback_thread = threading.Thread(target=audio_playback_worker, daemon=False) 927 | playback_thread.start() 928 | 929 | except ImportError: 930 | print("sounddevice library not found. Install with 'pip install sounddevice' for real-time playback.") 931 | play_audio = False 932 | 933 | # Timing metrics 934 | chunk_times = [] 935 | latency_to_first_chunk = None 936 | total_audio_duration = 0 937 | chunk_count = 0 938 | 939 | # Function to handle each generated chunk 940 | def on_chunk_generated(chunk): 941 | nonlocal chunk_count, latency_to_first_chunk, total_audio_duration 942 | 943 | current_time = time.time() 944 | if chunk_count == 0: 945 | latency_to_first_chunk = current_time - start_time 946 | print(f"First chunk latency: {latency_to_first_chunk*1000:.1f}ms") 947 | 948 | # Save chunk to WAV file 949 | write_chunk(chunk) 950 | 951 | # Update metrics 952 | chunk_count += 1 953 | chunk_duration = len(chunk) / generator.sample_rate 954 | total_audio_duration += chunk_duration 955 | chunk_times.append(current_time) 956 | 957 | # Send to audio player if enabled 958 | if play_audio and audio_queue is not None: 959 | try: 960 | audio_queue.put(chunk, timeout=1.0) 961 | except queue.Full: 962 | pass # Skip if queue is full to avoid blocking 963 | 964 | if torch.cuda.is_available(): 965 | print("Preparing GPU for low-latency generation...") 966 | torch.cuda.empty_cache() 967 | torch.cuda.synchronize() 968 | 969 | # Pre-allocate some GPU memory to avoid allocation during generation 970 | dummy_tensors = [] 971 | for i in range(5): 972 | dummy = torch.ones((100, 100), device=generator.device) 973 | dummy = dummy + 1.0 # Force computation 974 | dummy_tensors.append(dummy) # Keep reference to prevent deallocation 975 | 976 | torch.cuda.synchronize() 977 | 978 | # Set process priority to improve performance - use higher priority 979 | try: 980 | import psutil 981 | process = psutil.Process() 982 | if platform.system() == 'Windows': 983 | process.nice(psutil.HIGH_PRIORITY_CLASS) 984 | else: 985 | process.nice(-1) 986 | except (ImportError, PermissionError, psutil.AccessDenied): 987 | pass 988 | 989 | print(f"Starting audio generation for: '{text[:50]}{'...' 
if len(text) > 50 else ''}'") 990 | start_time = time.time() 991 | 992 | # Generate audio in chunks, catching possible errors 993 | frame_count = 0 994 | audio_chunks = [] # Store all chunks for possible use at the end 995 | 996 | try: 997 | for audio_chunk in generator.generate_stream( 998 | text=text, 999 | speaker=speaker, 1000 | context=context, 1001 | max_audio_length_ms=max_audio_length_ms, 1002 | temperature=temperature, 1003 | topk=topk, 1004 | on_chunk_generated=on_chunk_generated 1005 | ): 1006 | frame_count += 1 1007 | audio_chunks.append(audio_chunk) # Store the chunk 1008 | 1009 | # Print timing info less frequently to reduce overhead 1010 | if frame_count % 10 == 0: 1011 | current_time = time.time() 1012 | elapsed = current_time - start_time 1013 | if total_audio_duration > 0: 1014 | rtf = elapsed / total_audio_duration 1015 | remaining_time = (max_audio_length_ms/1000 - total_audio_duration) * rtf 1016 | print(f"Chunk {chunk_count}: {total_audio_duration:.1f}s audio in {elapsed:.1f}s " 1017 | f"(RTF: {rtf:.2f}x, Est. remaining: {remaining_time:.1f}s)") 1018 | except Exception as e: 1019 | print(f"Error during audio generation: {e}") 1020 | import traceback 1021 | traceback.print_exc() 1022 | 1023 | # Release dummy tensors to free memory 1024 | if 'dummy_tensors' in locals(): 1025 | del dummy_tensors 1026 | 1027 | # Ensure all chunks are properly processed 1028 | if play_audio and audio_queue is not None: 1029 | print("Waiting for playback queue to finish...") 1030 | try: 1031 | timeout_start = time.time() 1032 | while not audio_queue.empty() and time.time() - timeout_start < 5.0: 1033 | time.sleep(0.1) 1034 | except: 1035 | pass 1036 | 1037 | # Add a small delay to ensure everything is processed 1038 | time.sleep(0.5) 1039 | 1040 | # Signal audio worker that generation is complete 1041 | stop_event.set() 1042 | 1043 | # Close WAV file 1044 | close_wav() 1045 | 1046 | # Wait for audio playback to complete if enabled 1047 | if play_audio and 'playback_thread' in locals(): 1048 | print("Waiting for audio playback to complete...") 1049 | 1050 | # First, ensure the queue is empty 1051 | try: 1052 | timeout_start = time.time() 1053 | while not audio_queue.empty() and time.time() - timeout_start < 5.0: 1054 | time.sleep(0.1) 1055 | except: 1056 | pass 1057 | 1058 | # Set a flag to indicate complete audio playback is needed 1059 | if hasattr(sd, 'wait'): 1060 | try: 1061 | sd.wait() 1062 | except: 1063 | pass 1064 | 1065 | # Join the playback thread with timeout 1066 | playback_thread.join(timeout=5.0) 1067 | 1068 | # Force sounddevice to stop if it's still playing 1069 | try: 1070 | sd.stop() 1071 | except: 1072 | pass 1073 | 1074 | # Calculate and print detailed performance metrics 1075 | end_time = time.time() 1076 | total_elapsed = end_time - start_time 1077 | 1078 | # Calculate inter-chunk latency 1079 | if len(chunk_times) > 1: 1080 | inter_chunk_latencies = [chunk_times[i] - chunk_times[i-1] for i in range(1, len(chunk_times))] 1081 | avg_inter_chunk_latency = sum(inter_chunk_latencies) / len(inter_chunk_latencies) 1082 | max_inter_chunk_latency = max(inter_chunk_latencies) if inter_chunk_latencies else 0 1083 | min_inter_chunk_latency = min(inter_chunk_latencies) if inter_chunk_latencies else 0 1084 | else: 1085 | avg_inter_chunk_latency = max_inter_chunk_latency = min_inter_chunk_latency = 0 1086 | 1087 | rtf = total_elapsed / total_audio_duration if total_audio_duration > 0 else float('inf') 1088 | 1089 | print("\n" + "="*50) 1090 | print("AUDIO GENERATION PERFORMANCE 
METRICS") 1091 | print("="*50) 1092 | print(f"First chunk latency: {latency_to_first_chunk*1000:.1f}ms") 1093 | print(f"Total generation time: {total_elapsed:.2f}s") 1094 | print(f"Audio duration: {total_audio_duration:.2f}s") 1095 | print(f"Real-time factor (RTF): {rtf:.3f}x (target: <1.0)") 1096 | print(f"Number of chunks: {chunk_count}") 1097 | print(f"Average chunk size: {(total_audio_duration/chunk_count)*1000:.1f}ms") if chunk_count > 0 else None 1098 | print(f"Average inter-chunk latency: {avg_inter_chunk_latency*1000:.1f}ms") 1099 | print(f"Min/Max inter-chunk latency: {min_inter_chunk_latency*1000:.1f}ms / {max_inter_chunk_latency*1000:.1f}ms") 1100 | print(f"Chunks per second: {chunk_count/total_elapsed:.2f}") 1101 | print(f"Output file: {output_file}") 1102 | print("="*50) --------------------------------------------------------------------------------