├── templates
│   ├── index.html
│   ├── crud.html
│   ├── setup.html
│   └── chat.html
├── static
│   ├── test.mp3
│   ├── crud.js
│   ├── app.js
│   └── chat.js
├── finetuned_model
│   └── config.json
├── .gitignore
├── requirements.txt
├── .github
│   └── FUNDING.yml
├── loadandmergecheckpoint.py
├── test.py
├── config.py
├── run_csm.py
├── setup.py
├── llm_interface.py
├── models.py
├── vad.py
├── README.md
├── LICENSE
├── rag_system.py
└── generator.py

--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/static/test.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidbrowne17/csm-streaming/HEAD/static/test.mp3

--------------------------------------------------------------------------------
/finetuned_model/config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "audio_num_codebooks": 32,
3 |     "audio_vocab_size": 2051,
4 |     "backbone_flavor": "llama-1B",
5 |     "decoder_flavor": "llama-100M",
6 |     "text_vocab_size": 128256
7 | }

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 | 
23 | # Virtual Environment
24 | .env
25 | .venv
26 | env/
27 | venv/
28 | ENV/
29 | 
30 | # IDE
31 | .idea/
32 | .vscode/
33 | *.swp
34 | *.swo
35 | 
36 | # Project specific
37 | .python-version
38 | *.wav
39 | output_*/
40 | basic_audio.wav
41 | full_conversation.wav
42 | context_audio.wav
43 | 
44 | # Model files
45 | *.pt
46 | *.ckpt

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url=https://download.pytorch.org/whl/cu128
2 | vllm==0.8.0
3 | torch==2.6.0
4 | torchaudio==2.6.0
5 | torchvision==0.21.0
6 | tokenizers==0.21.0
7 | transformers==4.49.0
8 | huggingface_hub==0.28.1
9 | moshi==0.2.2
10 | sounddevice
11 | torchtune==0.4.0
12 | torchao==0.9.0
13 | bitsandbytes
14 | peft
15 | wandb
16 | silero_vad
17 | python-multipart>=0.0.6
18 | aiofiles>=23.1.0
19 | sentence-transformers>=2.2.2
20 | ctransformers>=0.2.24
21 | sqlalchemy>=2.0.0
22 | pydantic>=2.0.0
23 | fastapi>=0.95.0
24 | uvicorn>=0.22.0
25 | websockets>=11.0.3
26 | jinja2>=3.0.0
27 | speechbrain>=0.5.15
28 | matplotlib
29 | whisper-openai
30 | silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
31 | numpy==1.26.0
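requirements.txt pins the whole PyTorch stack to the CUDA 12.8 wheel index via --extra-index-url. A quick sanity check that the pinned stack resolved as intended — a sketch, assuming the pins above (torch 2.6.0, cu128 wheels) are what got installed:

    import torch

    # Expected values follow the pins in requirements.txt above.
    print(torch.__version__)          # 2.6.0+cu128 if the cu128 wheel was selected
    print(torch.version.cuda)         # "12.8"
    print(torch.cuda.is_available())  # True on a working GPU setup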
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 | 
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: davidbrowne17
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
12 | polar: # Replace with a single Polar username
13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
14 | thanks_dev: # Replace with a single thanks.dev username
15 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
16 | 

--------------------------------------------------------------------------------
/templates/crud.html:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
--------------------------------------------------------------------------------
/static/crud.js:
--------------------------------------------------------------------------------
33 |         container.innerHTML = '<p>No conversations found.</p>';
34 |         return;
35 |     }
36 | 
37 |     list.forEach(conv => {
38 |         const div = document.createElement('div');
39 |         div.className = "bg-gray-800 p-4 rounded shadow";
40 |         div.innerHTML = `
41 | 
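crud.js evidently renders conversation rows fetched from the backend. For context, a sketch of the kind of read endpoint such a page would call — the route name is hypothetical, but the table matches the one setup.py creates below:

    from fastapi import FastAPI
    from sqlalchemy import create_engine, text

    app = FastAPI()
    engine = create_engine("sqlite:///companion.db")

    @app.get("/api/conversations")  # hypothetical route; not taken from the repository
    def list_conversations():
        with engine.connect() as conn:
            rows = conn.execute(
                text("SELECT id, session_id, timestamp, user_message, ai_message FROM conversations")
            )
            # Each Row exposes its columns via ._mapping (SQLAlchemy 1.4+).
            return [dict(r._mapping) for r in rows]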
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
110 |     <p>Redirecting to AI Companion...</p>
111 | </body>
112 | </html>
113 | """)
114 |         logger.info("Created index template for redirection")
115 | 
116 | def setup_database():
117 |     """Initialize the SQLite database"""
118 |     logger.info("Setting up database...")
119 | 
120 |     try:
121 |         from sqlalchemy import create_engine, Column, Integer, String, Text
122 |         from sqlalchemy.orm import declarative_base
123 |         from sqlalchemy.orm import sessionmaker
124 | 
125 |         Base = declarative_base()
126 |         engine = create_engine("sqlite:///companion.db")
127 | 
128 |         class Conversation(Base):
129 |             __tablename__ = "conversations"
130 |             id = Column(Integer, primary_key=True, index=True)
131 |             session_id = Column(String, index=True)
132 |             timestamp = Column(String)
133 |             user_message = Column(Text)
134 |             ai_message = Column(Text)
135 |             audio_path = Column(String)
136 | 
137 |         # Create tables
138 |         Base.metadata.create_all(bind=engine)
139 |         logger.info("Database initialized successfully")
140 | 
141 |     except Exception as e:
142 |         logger.error(f"Failed to set up database: {e}")
143 | 
144 | def check_cuda():
145 |     """Check if CUDA is available for PyTorch"""
146 |     if torch.cuda.is_available():
147 |         device_name = torch.cuda.get_device_name(0)
148 |         logger.info(f"CUDA is available: {device_name}")
149 |         logger.info(f"CUDA version: {torch.version.cuda}")
150 |     else:
151 |         logger.warning("CUDA is not available. The application will run on the CPU, which may be very slow.")
152 |         logger.warning("For optimal performance, a CUDA-capable GPU is recommended.")
153 | 
154 | def main():
155 |     """Main setup function"""
156 |     logger.info("Starting AI Companion setup...")
157 | 
158 |     # Check for CUDA availability
159 |     check_cuda()
160 | 
161 |     # Check and install requirements
162 |     #check_requirements()
163 | 
164 |     # Create directories
165 |     setup_directories()
166 | 
167 |     # Set up database
168 |     setup_database()
169 | 
170 |     # Download models
171 |     download_vad_model()
172 |     download_embedding_models()
173 | 
174 |     logger.info("Setup completed successfully!")
175 |     logger.info("You can now start the application with:")
176 |     logger.info("    python main.py")
177 | 
178 | if __name__ == "__main__":
179 |     main()
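setup_database() stores timestamps and audio paths as plain strings, so writing a row needs no ORM session; a sketch against the same schema (all values illustrative):

    from datetime import datetime, timezone
    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///companion.db")

    # engine.begin() opens a transaction and commits on exit (SQLAlchemy 2.0 style).
    with engine.begin() as conn:
        conn.execute(
            text("INSERT INTO conversations (session_id, timestamp, user_message, ai_message, audio_path) "
                 "VALUES (:sid, :ts, :u, :a, :p)"),
            {"sid": "demo", "ts": datetime.now(timezone.utc).isoformat(),
             "u": "Hello?", "a": "Hi there!", "p": "static/test.mp3"},
        )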
--------------------------------------------------------------------------------
/llm_interface.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import List, Dict, Any, Optional
3 | import torch
4 | from vllm import LLM, SamplingParams
5 | 
6 | class LLMInterface:
7 |     def __init__(self, model_path: str, max_tokens: int = 8192, n_threads: int = 8, gpu_layers: int = -1):
8 |         """Initialize the LLM interface using vLLM with a given model.
9 | 
10 |         Args:
11 |             model_path (str): Path to the model or Hugging Face model name
12 |             max_tokens (int, optional): Maximum context length. Defaults to 8192.
13 |             n_threads (int, optional): Number of CPU threads. Defaults to 8.
14 |             gpu_layers (int, optional): Not used by vLLM; kept for API compatibility.
15 |         """
16 |         # vLLM configuration
17 |         self.llm = LLM(
18 |             model=model_path,
19 |             tensor_parallel_size=1,  # Adjust based on the number of GPUs available
20 |             gpu_memory_utilization=0.6,
21 |             max_model_len=max_tokens,
22 |             swap_space=0,
23 |             trust_remote_code=True,
24 |             dtype=torch.float16,
25 |         )
26 | 
27 |         # Store configuration for reference
28 |         self.config = {
29 |             "model_path": model_path,
30 |             "max_tokens": max_tokens,
31 |         }
32 | 
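A minimal instantiation sketch; the model path below is a hypothetical placeholder, not one the repository prescribes. Note that gpu_memory_utilization=0.6 deliberately leaves VRAM headroom for whatever else shares the GPU — presumably the CSM speech model running alongside:

    from llm_interface import LLMInterface

    # Hypothetical checkpoint; any local or Hugging Face Llama-style
    # instruct model using the <|start_header_id|> chat format would do.
    llm = LLMInterface("meta-llama/Llama-3.2-1B-Instruct", max_tokens=8192)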
33 |     def trim_to_last_sentence(self, text: str) -> str:
34 |         """
35 |         Return *text* truncated at the final full sentence boundary.
36 |         A boundary is considered to be any '.', '!' or '?' followed by
37 |         optional quotes/brackets, optional whitespace, and then end-of-string.
38 | 
39 |         If no sentence terminator exists, the original text is returned.
40 |         """
41 |         # Regex explanation:
42 |         #   (.*?[.!?]["')\]]?)  any text, lazily, until a terminator
43 |         #   \s*$                followed only by whitespace till end-of-string
44 |         m = re.match(r"^(.*?[.!?][\"')\]]?)\s*$", text, re.DOTALL)
45 |         if m:
46 |             return m.group(1).strip()
47 |         # Fall back to manual search (handles cases with additional text)
48 |         for i in range(len(text) - 1, -1, -1):
49 |             if text[i] in ".!?":
50 |                 return text[: i + 1].strip()
51 |         return text.strip()
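A quick illustration of both paths through trim_to_last_sentence — the regex hit and the reverse-scan fallback. Creating the object with __new__ is only a demo shortcut to avoid loading a model:

    f = LLMInterface.__new__(LLMInterface)  # bypass __init__; this method needs no model

    print(f.trim_to_last_sentence("It works. Mostly, any"))  # -> It works.
    print(f.trim_to_last_sentence('He asked "why?"'))        # -> He asked "why?"
    print(f.trim_to_last_sentence("no terminator at all"))   # -> no terminator at all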
52 | 
53 |     def generate_response(self, system_prompt: str, user_message: str, conversation_history: str = "") -> str:
54 |         """Generate a response from the LLM using chat-style prompt formatting.
55 | 
56 |         Args:
57 |             system_prompt (str): The system prompt/instructions
58 |             user_message (str): The user's input message
59 |             conversation_history (str, optional): Any prior conversation context. Defaults to "".
60 | 
61 |         Returns:
62 |             str: The generated response
63 |         """
64 |         # Format the prompt following the chat template structure
65 |         prompt = f"""<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>
66 | {conversation_history}
67 | <|start_header_id|>user<|end_header_id|>\n{user_message}<|eot_id|>
68 | <|start_header_id|>assistant<|end_header_id|>\n"""
69 | 
70 |         # Define sampling parameters (equivalent to the previous implementation)
71 |         sampling_params = SamplingParams(
72 |             temperature=1.0,
73 |             top_p=0.95,
74 |             max_tokens=100,
75 |             repetition_penalty=1.2,
76 |             top_k=200,
77 |             stop=["</s>", "<|endoftext|>", "<|eot_id|>"],
78 |         )
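A hypothetical call, reusing the llm instance sketched earlier. max_tokens=100 in SamplingParams keeps replies short, which suits low-latency speech synthesis; the docstring promises a plain string back:

    reply = llm.generate_response(
        system_prompt="You are a concise, friendly voice companion.",
        user_message="Suggest a one-line icebreaker.",
    )
    print(reply)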
--------------------------------------------------------------------------------
/static/chat.js:
--------------------------------------------------------------------------------
95 |         .replace(/\*\*([^*]+)\*\*/g,'<strong>$1</strong>')
96 |         .replace(/`([^`]+)`/g,'<code>$1</code>')
97 |         .replace(/\n/g,'<br>')
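The replace chain above is a miniature Markdown-to-HTML pass on the client (inline code and line breaks are visible in the fragment; the bolding rule on line 95 is reconstructed). A Python sketch of the same transformation — illustrative, not repository code — e.g. for rendering stored history server-side:

    import re

    def render_chat_markdown(text: str) -> str:
        """Mirror of the chat.js replace chain: **bold**, `code`, newlines."""
        text = re.sub(r"\*\*([^*]+)\*\*", r"<strong>\1</strong>", text)
        text = re.sub(r"`([^`]+)`", r"<code>\1</code>", text)
        return text.replace("\n", "<br>")

    print(render_chat_markdown("**Hi!** Try `python main.py`\nthen open the chat page."))
    # -> <strong>Hi!</strong> Try <code>python main.py</code><br>then open the chat page.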