├── client ├── __init__.py ├── dht_discovery.py ├── api.py ├── discovery.py └── router.py ├── inference_node ├── __init__.py ├── heartbeat.py ├── dht_publisher.py ├── request_queue.py ├── metrics.py └── p2p_handler.py ├── dht ├── __init__.py ├── protocol.py └── routing_table.py ├── common ├── __init__.py ├── utils.py ├── unified_sse.py ├── port_utils.py ├── validation_utils.py ├── error_handler.py ├── p2p_transport.py ├── metrics_manager.py ├── hardware_fingerprint.py ├── service_manager.py ├── models.py ├── subnet_matcher.py └── sse_handler.py ├── static └── images │ └── screenshot.png ├── requirements-dev.txt ├── requirements.txt ├── docker ├── registry.Dockerfile ├── docker-compose.yml ├── inference.Dockerfile ├── start.sh └── gpu-detect.sh ├── .gitignore ├── setup.py ├── pyproject.toml ├── tools ├── quick_check.py ├── monitor.py └── network_status.py ├── scripts └── install-cuda-support.sh ├── examples ├── openai_client_example.py └── simple_client.py ├── start-app.sh └── LICENSE /client/__init__.py: -------------------------------------------------------------------------------- 1 | """LlamaNet client library""" 2 | -------------------------------------------------------------------------------- /inference_node/__init__.py: -------------------------------------------------------------------------------- 1 | """LlamaNet inference node""" 2 | -------------------------------------------------------------------------------- /dht/__init__.py: -------------------------------------------------------------------------------- 1 | """Kademlia DHT implementation for LlamaNet""" 2 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | """Common utilities and models for LlamaNet""" 2 | -------------------------------------------------------------------------------- /static/images/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/machaao/llama-net/HEAD/static/images/screenshot.png -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest>=6.0.0 2 | pytest-asyncio>=0.18.0 3 | black>=22.0.0 4 | flake8>=4.0.0 5 | mypy>=0.950 6 | -------------------------------------------------------------------------------- /client/dht_discovery.py: -------------------------------------------------------------------------------- 1 | # Import the event-based discovery as the default implementation 2 | from client.event_discovery import EventBasedDHTDiscovery 3 | 4 | # For backward compatibility, alias the event-based version as DHTDiscovery 5 | DHTDiscovery = EventBasedDHTDiscovery 6 | -------------------------------------------------------------------------------- /client/api.py: -------------------------------------------------------------------------------- 1 | from client.event_aware_client import EventAwareOpenAIClient 2 | from common.utils import get_logger 3 | 4 | logger = get_logger(__name__) 5 | 6 | # Simplified API with consolidated client 7 | OpenAIClient = EventAwareOpenAIClient 8 | Client = EventAwareOpenAIClient 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi>=0.68.0,<1.0.0 2 | uvicorn 3 | pydantic>=2.0.0,<3.0.0 
4 | requests>=2.26.0,<3.0.0 5 | llama-cpp-python>=0.2.20 6 | psutil>=5.8.0,<6.0.0 7 | pynvml>=11.4.1 8 | kademlia>=2.2.2,<3.0.0 9 | aiohttp>=3.8.0,<4.0.0 10 | p2pd 11 | ipaddress>=1.0.23 12 | httptools 13 | -------------------------------------------------------------------------------- /docker/registry.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim 2 | 3 | WORKDIR /app 4 | 5 | # Copy requirements and install Python dependencies 6 | COPY requirements.txt . 7 | RUN pip install --no-cache-dir -r requirements.txt 8 | 9 | # Copy application code 10 | COPY . . 11 | 12 | # Environment variables 13 | ENV REGISTRY_PORT=8080 14 | 15 | # Expose port 16 | EXPOSE 8080 17 | 18 | # Run the registry service 19 | CMD ["python", "-m", "registry.server"] 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual environments 24 | venv/ 25 | env/ 26 | ENV/ 27 | 28 | # IDE 29 | .vscode/ 30 | .idea/ 31 | *.swp 32 | *.swo 33 | 34 | # Models 35 | models/ 36 | *.gguf 37 | 38 | # Logs 39 | *.log 40 | 41 | # MACH-AI specific 42 | plans 43 | 44 | # Deployment 45 | start-app.sh.local 46 | -------------------------------------------------------------------------------- /client/discovery.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Optional 3 | from common.models import NodeInfo 4 | 5 | class DiscoveryInterface(ABC): 6 | """Abstract base class for node discovery mechanisms""" 7 | 8 | @abstractmethod 9 | async def start(self): 10 | """Start the discovery service""" 11 | pass 12 | 13 | @abstractmethod 14 | async def stop(self): 15 | """Stop the discovery service""" 16 | pass 17 | 18 | @abstractmethod 19 | async def get_nodes(self, model: Optional[str] = None, force_refresh: bool = False) -> List[NodeInfo]: 20 | """Get available nodes, optionally filtered by model""" 21 | pass 22 | 23 | @abstractmethod 24 | async def find_specific_node(self, node_id: str) -> Optional[NodeInfo]: 25 | """Find a specific node by ID""" 26 | pass 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="llamanet", 5 | version="0.1.0", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "fastapi>=0.68.0,<1.0.0", 9 | "uvicorn>=0.15.0,<1.0.0", 10 | "pydantic>=2.0.0,<3.0.0", 11 | "requests>=2.26.0,<3.0.0", 12 | "llama-cpp-python>=0.2.20", 13 | "psutil>=5.8.0,<6.0.0", 14 | "pynvml>=11.4.1", 15 | "kademlia>=2.2.2,<3.0.0", 16 | "aiohttp>=3.8.0,<4.0.0", 17 | "p2pd", 18 | "ipaddress>=1.0.23" 19 | ], 20 | author="LlamaNet Team", 21 | author_email="example@example.com", 22 | description="Decentralized Inference Swarm for llama.cpp", 23 | keywords="llm, inference, decentralized", 24 | url="https://github.com/yourusername/llamanet", 25 | license="Apache-2.0", 26 | classifiers=[ 27 | "Development Status :: 3 - Alpha", 28 | "Intended Audience :: Developers", 29 | "Programming 
Language :: Python :: 3", 30 | ], 31 | python_requires=">=3.8", 32 | ) 33 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llamanet" 7 | version = "0.1.0" 8 | description = "Decentralized Inference Swarm for llama.cpp" 9 | authors = [{name = "LlamaNet Team", email = "example@example.com"}] 10 | license = {text = "Apache-2.0"} 11 | readme = "README.md" 12 | requires-python = ">=3.8" 13 | keywords = ["llm", "inference", "decentralized"] 14 | classifiers = [ 15 | "Development Status :: 3 - Alpha", 16 | "Intended Audience :: Developers", 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: Apache Software License", 19 | ] 20 | 21 | dependencies = [ 22 | "fastapi>=0.68.0,<1.0.0", 23 | "uvicorn>=0.15.0,<1.0.0", 24 | "pydantic>=2.0.0,<3.0.0", 25 | "requests>=2.26.0,<3.0.0", 26 | "llama-cpp-python>=0.2.20", 27 | "psutil>=5.8.0,<6.0.0", 28 | "pynvml>=11.4.1", 29 | "kademlia>=2.2.2,<3.0.0", 30 | "aiohttp>=3.8.0,<4.0.0", 31 | "p2pd", 32 | "ipaddress>=1.0.23", 33 | ] 34 | 35 | [project.urls] 36 | Homepage = "https://github.com/yourusername/llamanet" 37 | Repository = "https://github.com/yourusername/llamanet" 38 | 39 | [project.scripts] 40 | llamanet-inference = "inference_node.server:start_server" 41 | llamanet-registry = "registry.server:start_server" 42 | llamanet-help = "inference_node.server:show_help" 43 | 44 | [tool.setuptools.packages.find] 45 | include = ["client*", "common*", "dht*", "inference_node*", "registry*"] 46 | -------------------------------------------------------------------------------- /docker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | # Bootstrap node - first inference node that others can connect to 5 | bootstrap: 6 | build: 7 | context: .. 8 | dockerfile: docker/inference.Dockerfile 9 | ports: 10 | - "8000:8000" 11 | - "8001:8001" # DHT port 12 | environment: 13 | - MODEL_PATH=/models/model.gguf 14 | - PORT=8000 15 | - DHT_PORT=8001 16 | - NODE_ID=bootstrap-node 17 | - BOOTSTRAP_NODES="" 18 | volumes: 19 | - ../models:/models 20 | restart: unless-stopped 21 | 22 | # Additional inference nodes 23 | inference1: 24 | build: 25 | context: .. 26 | dockerfile: docker/inference.Dockerfile 27 | ports: 28 | - "8002:8000" 29 | - "8003:8001" # DHT port 30 | environment: 31 | - MODEL_PATH=/models/model.gguf 32 | - PORT=8000 33 | - DHT_PORT=8001 34 | - BOOTSTRAP_NODES=bootstrap:8001 35 | volumes: 36 | - ../models:/models 37 | depends_on: 38 | - bootstrap 39 | restart: unless-stopped 40 | 41 | inference2: 42 | build: 43 | context: .. 
44 | dockerfile: docker/inference.Dockerfile 45 | ports: 46 | - "8004:8000" 47 | - "8005:8001" # DHT port 48 | environment: 49 | - MODEL_PATH=/models/model.gguf 50 | - PORT=8000 51 | - DHT_PORT=8001 52 | - BOOTSTRAP_NODES=bootstrap:8001 53 | volumes: 54 | - ../models:/models 55 | depends_on: 56 | - bootstrap 57 | restart: unless-stopped 58 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | from typing import Any, Dict, Optional 5 | 6 | # Configure logging 7 | logging.basicConfig( 8 | level=logging.INFO, 9 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 10 | ) 11 | 12 | def get_logger(name: str) -> logging.Logger: 13 | """Get a logger with the given name""" 14 | return logging.getLogger(name) 15 | 16 | def load_env_var(key: str, default: Any = None) -> Any: 17 | """Load an environment variable with a default value""" 18 | return os.environ.get(key, default) 19 | 20 | def get_host_ip() -> str: 21 | """Get the host IP address""" 22 | # This is a simple implementation - in production you might want something more robust 23 | import socket 24 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 25 | try: 26 | # Doesn't need to be reachable 27 | s.connect(('10.255.255.255', 1)) 28 | ip = s.getsockname()[0] 29 | except Exception: 30 | ip = '127.0.0.1' 31 | finally: 32 | s.close() 33 | return ip 34 | 35 | def normalize_stop_tokens(stop): 36 | """Normalize stop tokens to the format expected by llama-cpp-python""" 37 | if stop is None: 38 | return None 39 | elif isinstance(stop, str): 40 | return [stop] if stop.strip() else None 41 | elif isinstance(stop, list): 42 | # Filter out empty strings and ensure all items are strings 43 | normalized = [str(token).strip() for token in stop if str(token).strip()] 44 | return normalized if normalized else None 45 | else: 46 | return None 47 | -------------------------------------------------------------------------------- /tools/quick_check.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import requests 3 | from client.api import Client 4 | 5 | async def quick_network_check(): 6 | client = Client(bootstrap_nodes="localhost:8001") 7 | 8 | try: 9 | # Get all nodes 10 | nodes = await client.dht_discovery.get_nodes() 11 | print(f"🌐 Available nodes: {len(nodes)}") 12 | 13 | if not nodes: 14 | print("❌ No nodes found in the network") 15 | return 16 | 17 | print("\n📊 Node Status:") 18 | print("-" * 60) 19 | 20 | for node in nodes: 21 | print(f"🔹 {node.node_id[:8]}... 
({node.ip}:{node.port})") 22 | print(f" Model: {node.model}") 23 | print(f" Load: {node.load:.2f} | TPS: {node.tps:.2f}") 24 | 25 | # Try to get health status 26 | try: 27 | health_response = requests.get(f"http://{node.ip}:{node.port}/health", timeout=3) 28 | if health_response.status_code == 200: 29 | health = health_response.json() 30 | health_icon = "💚" if health.get('healthy', False) else "💔" 31 | print(f" Health: {health_icon} {'Healthy' if health.get('healthy', False) else 'Unhealthy'}") 32 | else: 33 | print(f" Health: ❓ Unknown") 34 | except: 35 | print(f" Health: ❌ Unreachable") 36 | 37 | print() 38 | 39 | finally: 40 | await client.close() 41 | 42 | if __name__ == "__main__": 43 | asyncio.run(quick_network_check()) 44 | -------------------------------------------------------------------------------- /tools/monitor.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | import os 4 | import requests 5 | from client.dht_discovery import DHTDiscovery 6 | 7 | async def monitor_network(bootstrap_nodes="localhost:8001", interval=10): 8 | """Monitor the network in real-time""" 9 | discovery = DHTDiscovery(bootstrap_nodes) 10 | 11 | try: 12 | while True: 13 | os.system('clear' if os.name == 'posix' else 'cls') # Clear screen 14 | 15 | print("🔄 LlamaNet Network Monitor") 16 | print(f"⏰ {time.strftime('%Y-%m-%d %H:%M:%S')}") 17 | print("=" * 60) 18 | 19 | nodes = await discovery.get_nodes(force_refresh=True) 20 | 21 | if nodes: 22 | print(f"📊 Active Nodes: {len(nodes)}") 23 | print("-" * 60) 24 | 25 | for node in nodes: 26 | status = "🟢" if time.time() - node.last_seen < 30 else "🟡" 27 | print(f"{status} {node.node_id[:12]}... | {node.ip}:{node.port} | {node.model}") 28 | print(f" Load: {node.load:.2f} | TPS: {node.tps:.2f} | Uptime: {node.uptime}s") 29 | print() 30 | else: 31 | print("❌ No nodes found") 32 | 33 | print(f"\n🔄 Refreshing in {interval} seconds... (Ctrl+C to exit)") 34 | await asyncio.sleep(interval) 35 | 36 | except KeyboardInterrupt: 37 | print("\n👋 Monitoring stopped") 38 | finally: 39 | await discovery.stop() 40 | 41 | if __name__ == "__main__": 42 | asyncio.run(monitor_network()) 43 | -------------------------------------------------------------------------------- /scripts/install-cuda-support.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # CUDA Support Installation Script for Python 3.12 3 | # This script installs llama-cpp-python with CUDA support 4 | 5 | set -e 6 | 7 | echo "Installing CUDA support for LlamaNet on Python 3.12..." 8 | 9 | # Check Python version 10 | PYTHON_VERSION=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") 11 | echo "Detected Python version: $PYTHON_VERSION" 12 | 13 | if [[ "$PYTHON_VERSION" < "3.8" ]]; then 14 | echo "Error: Python 3.8 or higher required" 15 | exit 1 16 | fi 17 | 18 | # Check for CUDA installation 19 | if command -v nvcc &> /dev/null; then 20 | CUDA_VERSION=$(nvcc --version | grep "release" | sed 's/.*release \([0-9]\+\.[0-9]\+\).*/\1/') 21 | echo "CUDA version detected: $CUDA_VERSION" 22 | else 23 | echo "Warning: CUDA not detected. Installing CPU-only version." 24 | pip install llama-cpp-python>=0.2.11 25 | exit 0 26 | fi 27 | 28 | # Install with CUDA support 29 | echo "Installing llama-cpp-python with CUDA support..." 
30 | CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python>=0.2.11 --force-reinstall --no-cache-dir 31 | 32 | # Install NVIDIA monitoring 33 | echo "Installing NVIDIA GPU monitoring..." 34 | pip install nvidia-ml-py3>=7.352.0 35 | 36 | # Verify installation 37 | echo "Verifying CUDA installation..." 38 | python3 -c " 39 | try: 40 | import llama_cpp 41 | import pynvml 42 | pynvml.nvmlInit() 43 | device_count = pynvml.nvmlDeviceGetCount() 44 | print(f'✓ CUDA support verified: {device_count} GPU(s) detected') 45 | except Exception as e: 46 | print(f'✗ CUDA verification failed: {e}') 47 | exit(1) 48 | " 49 | 50 | echo "✓ CUDA support installation completed successfully!" -------------------------------------------------------------------------------- /common/unified_sse.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import time 4 | from typing import Dict, Any, Optional, Callable, List, AsyncGenerator 5 | import aiohttp 6 | from common.sse_handler import SSEHandler, SSENetworkMonitor, SSEParser, SSEStreamHandler, OpenAISSETransformer 7 | from common.utils import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | class UnifiedSSEManager: 12 | """Unified SSE management combining all SSE functionality""" 13 | 14 | def __init__(self, base_url: str = None): 15 | self.handler = SSEHandler() 16 | self.monitor = SSENetworkMonitor(base_url) if base_url else None 17 | self.parser = SSEParser() 18 | self.stream_handler = SSEStreamHandler() 19 | self.transformer = OpenAISSETransformer() 20 | 21 | if self.monitor: 22 | self.monitor.set_sse_handler(self.handler) 23 | 24 | async def start(self): 25 | """Start all SSE components""" 26 | if self.monitor: 27 | await self.monitor.start() 28 | self.handler.running = True 29 | logger.info("Unified SSE manager started") 30 | 31 | async def stop(self): 32 | """Stop all SSE components""" 33 | if self.monitor: 34 | await self.monitor.stop() 35 | self.handler.running = False 36 | logger.info("Unified SSE manager stopped") 37 | 38 | # Delegate methods to appropriate components 39 | async def add_connection(self, connection_id: str): 40 | return await self.handler.add_connection(connection_id) 41 | 42 | async def remove_connection(self, connection_id: str): 43 | return await self.handler.remove_connection(connection_id) 44 | 45 | async def broadcast_event(self, event_type: str, event_data: Dict[str, Any]): 46 | return await self.handler.broadcast_event(event_type, event_data) 47 | 48 | def get_status(self): 49 | return self.handler.get_status() 50 | 51 | # Stream handling methods 52 | async def stream_from_response(self, response: aiohttp.ClientResponse, transform_func: Optional[Callable] = None): 53 | return self.stream_handler.stream_from_response(response, transform_func) 54 | 55 | def get_chat_transformer(self): 56 | return self.transformer.chat_transform 57 | 58 | def get_completion_transformer(self): 59 | return self.transformer.completion_transform 60 | -------------------------------------------------------------------------------- /tools/network_status.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import requests 3 | import json 4 | from client.dht_discovery import DHTDiscovery 5 | 6 | async def show_network_status(bootstrap_nodes="localhost:8001"): 7 | """Show the status of the entire LlamaNet network""" 8 | 9 | print("🔍 Discovering LlamaNet network...") 10 | 11 | # Create DHT discovery client 12 | discovery = 
DHTDiscovery(bootstrap_nodes) 13 | 14 | try: 15 | # Get all nodes 16 | nodes = await discovery.get_nodes() 17 | 18 | print(f"\n📊 Found {len(nodes)} active nodes:") 19 | print("-" * 80) 20 | 21 | for i, node in enumerate(nodes, 1): 22 | print(f"{i}. Node ID: {node.node_id}") 23 | print(f" Address: {node.ip}:{node.port}") 24 | print(f" Model: {node.model}") 25 | print(f" Load: {node.load:.2f}") 26 | print(f" TPS: {node.tps:.2f}") 27 | print(f" Uptime: {node.uptime}s") 28 | print(f" Last seen: {node.last_seen}") 29 | 30 | # Try to get additional info from HTTP API 31 | try: 32 | # Get basic info 33 | response = requests.get(f"http://{node.ip}:{node.port}/info", timeout=5) 34 | if response.status_code == 200: 35 | info = response.json() 36 | print(f" DHT Port: {info.get('dht_port', 'unknown')}") 37 | if 'system' in info: 38 | system = info['system'] 39 | print(f" CPU: {system.get('cpu', 'unknown')}") 40 | if system.get('gpu'): 41 | print(f" GPU: {system['gpu']}") 42 | 43 | # Get health status 44 | health_response = requests.get(f"http://{node.ip}:{node.port}/health", timeout=5) 45 | if health_response.status_code == 200: 46 | health = health_response.json() 47 | health_icon = "💚" if health.get('healthy', False) else "💔" 48 | print(f" Health: {health_icon} {'Healthy' if health.get('healthy', False) else 'Unhealthy'}") 49 | 50 | except: 51 | print(f" Status: HTTP API not reachable") 52 | 53 | print() 54 | 55 | except Exception as e: 56 | print(f"❌ Error discovering network: {e}") 57 | finally: 58 | await discovery.stop() 59 | 60 | if __name__ == "__main__": 61 | import sys 62 | bootstrap = sys.argv[1] if len(sys.argv) > 1 else "localhost:8001" 63 | asyncio.run(show_network_status(bootstrap)) 64 | -------------------------------------------------------------------------------- /common/port_utils.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from typing import List, Tuple, Optional 3 | from common.utils import get_logger 4 | 5 | logger = get_logger(__name__) 6 | 7 | class PortManager: 8 | """Centralized port management utilities""" 9 | 10 | @staticmethod 11 | def is_tcp_port_available(port: int, host: str = '') -> bool: 12 | """Check if a TCP port is available""" 13 | try: 14 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 15 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 16 | s.bind((host, port)) 17 | return True 18 | except OSError: 19 | return False 20 | 21 | @staticmethod 22 | def is_udp_port_available(port: int, host: str = '') -> bool: 23 | """Check if a UDP port is available""" 24 | try: 25 | with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: 26 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 27 | s.bind((host, port)) 28 | return True 29 | except OSError: 30 | return False 31 | 32 | @staticmethod 33 | def find_available_tcp_port(start_port: int = 8000, max_attempts: int = 100) -> int: 34 | """Find an available TCP port""" 35 | for port in range(start_port, start_port + max_attempts): 36 | if PortManager.is_tcp_port_available(port): 37 | return port 38 | raise RuntimeError(f"No available TCP ports found starting from {start_port}") 39 | 40 | @staticmethod 41 | def find_available_udp_port(start_port: int = 8001, max_attempts: int = 100) -> int: 42 | """Find an available UDP port""" 43 | for port in range(start_port, start_port + max_attempts): 44 | if PortManager.is_udp_port_available(port): 45 | return port 46 | raise RuntimeError(f"No available UDP ports found starting from 
{start_port}") 47 | 48 | @staticmethod 49 | def get_port_with_fallback(preferred_port: int, port_type: str = 'tcp') -> int: 50 | """Get preferred port or find alternative""" 51 | check_func = PortManager.is_tcp_port_available if port_type == 'tcp' else PortManager.is_udp_port_available 52 | find_func = PortManager.find_available_tcp_port if port_type == 'tcp' else PortManager.find_available_udp_port 53 | 54 | if check_func(preferred_port): 55 | return preferred_port 56 | else: 57 | logger.warning(f"Port {preferred_port} not available, finding alternative") 58 | alternative = find_func(preferred_port) 59 | logger.info(f"Using alternative {port_type.upper()} port: {alternative}") 60 | return alternative 61 | -------------------------------------------------------------------------------- /docker/inference.Dockerfile: -------------------------------------------------------------------------------- 1 | # Multi-stage Dockerfile for LlamaNet with GPU/CPU auto-detection 2 | ARG PYTHON_VERSION=3.11 3 | ARG CUDA_VERSION=12.1.1 4 | ARG UBUNTU_VERSION=22.04 5 | 6 | # Base stage - common dependencies 7 | FROM python:${PYTHON_VERSION}-slim as base 8 | 9 | # Install system dependencies 10 | RUN apt-get update && apt-get install -y \ 11 | build-essential \ 12 | cmake \ 13 | git \ 14 | wget \ 15 | curl \ 16 | pkg-config \ 17 | libssl-dev \ 18 | libopenblas-dev \ 19 | && rm -rf /var/lib/apt/lists/* 20 | 21 | # Set working directory 22 | WORKDIR /app 23 | 24 | # Copy requirements first for better caching 25 | COPY requirements.txt requirements-dev.txt ./ 26 | 27 | # GPU detection and setup stage 28 | FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} as gpu-base 29 | 30 | # Install Python 31 | RUN apt-get update && apt-get install -y \ 32 | python${PYTHON_VERSION} \ 33 | python3-pip \ 34 | python3-dev \ 35 | build-essential \ 36 | cmake \ 37 | git \ 38 | wget \ 39 | curl \ 40 | pkg-config \ 41 | libssl-dev \ 42 | libopenblas-dev \ 43 | && rm -rf /var/lib/apt/lists/* 44 | 45 | # Create symlinks for python 46 | RUN ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python && \ 47 | ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 48 | 49 | WORKDIR /app 50 | 51 | # Copy requirements 52 | COPY requirements.txt requirements-dev.txt ./ 53 | 54 | # Final runtime stage 55 | FROM base as runtime 56 | 57 | # Copy detection and startup scripts 58 | COPY docker/gpu-detect.sh /usr/local/bin/gpu-detect.sh 59 | COPY docker/start.sh /usr/local/bin/start.sh 60 | RUN chmod +x /usr/local/bin/gpu-detect.sh /usr/local/bin/start.sh 61 | 62 | # Install base Python dependencies (without llama-cpp-python) 63 | RUN pip install --no-cache-dir --upgrade pip && \ 64 | grep -v "llama-cpp-python" requirements.txt > /tmp/requirements-base.txt && \ 65 | pip install --no-cache-dir -r /tmp/requirements-base.txt 66 | 67 | # Copy application code 68 | COPY . . 69 | 70 | # Install package in development mode 71 | RUN pip install -e . 
72 | 73 | # Create directory for models 74 | RUN mkdir -p /models 75 | 76 | # Environment variables for configuration 77 | ENV PYTHONPATH=/app 78 | ENV PYTHONUNBUFFERED=1 79 | ENV MODEL_PATH=/models/model.gguf 80 | ENV HOST=0.0.0.0 81 | ENV PORT=8000 82 | ENV DHT_PORT=8001 83 | ENV BOOTSTRAP_NODES="" 84 | ENV HARDWARE_MODE=auto 85 | ENV N_GPU_LAYERS=0 86 | 87 | # Health check 88 | HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ 89 | CMD curl -f http://localhost:${PORT}/health || exit 1 90 | 91 | # Expose ports 92 | EXPOSE 8000 8001 93 | 94 | # Use startup script as entrypoint 95 | ENTRYPOINT ["/usr/local/bin/start.sh"] 96 | CMD ["python", "-m", "inference_node.server"] 97 | -------------------------------------------------------------------------------- /common/validation_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import ipaddress 3 | from typing import Dict, Any, Optional, List 4 | from common.utils import get_logger 5 | 6 | logger = get_logger(__name__) 7 | 8 | class NodeValidator: 9 | """Centralized node validation utilities""" 10 | 11 | @staticmethod 12 | def validate_node_id(node_id: str) -> bool: 13 | """Validate node ID format""" 14 | try: 15 | if not node_id or not isinstance(node_id, str): 16 | return False 17 | if len(node_id) != 40: # SHA-1 hex length 18 | return False 19 | int(node_id, 16) # Test if valid hex 20 | return True 21 | except (ValueError, TypeError): 22 | return False 23 | 24 | @staticmethod 25 | def validate_contact(contact) -> bool: 26 | """Validate DHT contact""" 27 | try: 28 | if not contact.node_id or not contact.ip: 29 | return False 30 | if not NodeValidator.validate_node_id(contact.node_id): 31 | return False 32 | ipaddress.IPv4Address(contact.ip) # Validate IP 33 | if hasattr(contact, 'last_seen'): 34 | if time.time() - contact.last_seen > 300: # 5 minutes 35 | return False 36 | return True 37 | except Exception: 38 | return False 39 | 40 | @staticmethod 41 | def validate_node_info(node_info: Dict[str, Any]) -> bool: 42 | """Validate node info structure""" 43 | required_fields = ['node_id', 'ip', 'port', 'model'] 44 | 45 | for field in required_fields: 46 | if field not in node_info: 47 | return False 48 | 49 | if not NodeValidator.validate_node_id(node_info['node_id']): 50 | return False 51 | 52 | try: 53 | ipaddress.IPv4Address(node_info['ip']) 54 | port = int(node_info['port']) 55 | if not (1024 <= port <= 65535): 56 | return False 57 | except (ValueError, ipaddress.AddressValueError): 58 | return False 59 | 60 | return True 61 | 62 | class NetworkValidator: 63 | """Network-level validation utilities""" 64 | 65 | @staticmethod 66 | def check_duplicate_endpoints(nodes: List[Dict[str, Any]]) -> List[str]: 67 | """Check for duplicate IP:port combinations""" 68 | seen_endpoints = {} 69 | duplicates = [] 70 | 71 | for node in nodes: 72 | endpoint = f"{node.get('ip')}:{node.get('port')}" 73 | node_id = node.get('node_id') 74 | 75 | if endpoint in seen_endpoints: 76 | existing_node = seen_endpoints[endpoint] 77 | if existing_node != node_id: 78 | duplicates.append(f"Duplicate endpoint {endpoint}: {existing_node[:8]}... 
vs {node_id[:8]}...") 79 | else: 80 | seen_endpoints[endpoint] = node_id 81 | 82 | return duplicates 83 | -------------------------------------------------------------------------------- /examples/openai_client_example.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import aiohttp 3 | import json 4 | 5 | async def test_openai_compatibility(): 6 | """Test OpenAI-compatible endpoints""" 7 | base_url = "http://localhost:8000" 8 | 9 | async with aiohttp.ClientSession() as session: 10 | 11 | # Test models endpoint 12 | print("🔍 Testing /v1/models endpoint...") 13 | async with session.get(f"{base_url}/v1/models") as response: 14 | if response.status == 200: 15 | models = await response.json() 16 | print(f"✅ Available models: {[m['id'] for m in models['data']]}") 17 | else: 18 | print(f"❌ Models endpoint failed: {response.status}") 19 | return 20 | 21 | # Test completions endpoint 22 | print("\n🤖 Testing /v1/completions endpoint...") 23 | completion_request = { 24 | "model": "llamanet", 25 | "prompt": "What is artificial intelligence?", 26 | "max_tokens": 100, 27 | "temperature": 0.7 28 | } 29 | 30 | async with session.post( 31 | f"{base_url}/v1/completions", 32 | json=completion_request, 33 | headers={"Content-Type": "application/json"} 34 | ) as response: 35 | if response.status == 200: 36 | completion = await response.json() 37 | print(f"✅ Completion response:") 38 | print(f" ID: {completion['id']}") 39 | print(f" Text: {completion['choices'][0]['text'][:100]}...") 40 | print(f" Tokens: {completion['usage']['total_tokens']}") 41 | else: 42 | print(f"❌ Completions endpoint failed: {response.status}") 43 | text = await response.text() 44 | print(f" Error: {text}") 45 | 46 | # Test chat completions endpoint 47 | print("\n💬 Testing /v1/chat/completions endpoint...") 48 | chat_request = { 49 | "model": "llamanet", 50 | "messages": [ 51 | {"role": "system", "content": "You are a helpful assistant."}, 52 | {"role": "user", "content": "Explain quantum computing in simple terms."} 53 | ], 54 | "max_tokens": 150, 55 | "temperature": 0.7 56 | } 57 | 58 | async with session.post( 59 | f"{base_url}/v1/chat/completions", 60 | json=chat_request, 61 | headers={"Content-Type": "application/json"} 62 | ) as response: 63 | if response.status == 200: 64 | chat_completion = await response.json() 65 | print(f"✅ Chat completion response:") 66 | print(f" ID: {chat_completion['id']}") 67 | print(f" Message: {chat_completion['choices'][0]['message']['content'][:100]}...") 68 | print(f" Tokens: {chat_completion['usage']['total_tokens']}") 69 | else: 70 | print(f"❌ Chat completions endpoint failed: {response.status}") 71 | text = await response.text() 72 | print(f" Error: {text}") 73 | 74 | def test_with_openai_library(): 75 | """Example using the official OpenAI Python library""" 76 | print("\n📚 Example using OpenAI Python library:") 77 | print(""" 78 | # Install: pip install openai 79 | 80 | import openai 81 | 82 | # Configure to use LlamaNet 83 | openai.api_base = "http://localhost:8000/v1" 84 | openai.api_key = "dummy-key" # Not used but required 85 | 86 | # Text completion 87 | response = openai.Completion.create( 88 | model="llamanet", 89 | prompt="What is machine learning?", 90 | max_tokens=100 91 | ) 92 | print(response.choices[0].text) 93 | 94 | # Chat completion 95 | response = openai.ChatCompletion.create( 96 | model="llamanet", 97 | messages=[ 98 | {"role": "user", "content": "Hello, how are you?"} 99 | ] 100 | ) 101 | 
print(response.choices[0].message.content) 102 | """) 103 | 104 | if __name__ == "__main__": 105 | print("🚀 Testing OpenAI-compatible endpoints...") 106 | print("Make sure you have a LlamaNet node running on localhost:8000") 107 | print() 108 | 109 | asyncio.run(test_openai_compatibility()) 110 | test_with_openai_library() 111 | -------------------------------------------------------------------------------- /common/error_handler.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Any, Callable, Optional, Dict 4 | from functools import wraps 5 | from common.utils import get_logger 6 | 7 | logger = get_logger(__name__) 8 | 9 | class ErrorHandler: 10 | """Centralized error handling utilities""" 11 | 12 | @staticmethod 13 | def safe_async_call(func: Callable, error_message: str = "Operation failed", 14 | default_return: Any = None, log_level: int = logging.ERROR): 15 | """Decorator for safe async function calls with consistent error handling""" 16 | @wraps(func) 17 | async def wrapper(*args, **kwargs): 18 | try: 19 | return await func(*args, **kwargs) 20 | except asyncio.CancelledError: 21 | logger.info(f"{func.__name__} was cancelled") 22 | raise 23 | except Exception as e: 24 | logger.log(log_level, f"{error_message}: {e}") 25 | return default_return 26 | return wrapper 27 | 28 | @staticmethod 29 | def safe_sync_call(func: Callable, error_message: str = "Operation failed", 30 | default_return: Any = None, log_level: int = logging.ERROR): 31 | """Decorator for safe sync function calls with consistent error handling""" 32 | @wraps(func) 33 | def wrapper(*args, **kwargs): 34 | try: 35 | return func(*args, **kwargs) 36 | except Exception as e: 37 | logger.log(log_level, f"{error_message}: {e}") 38 | return default_return 39 | return wrapper 40 | 41 | @staticmethod 42 | async def safe_task_cancellation(task: Optional[asyncio.Task], task_name: str = "task"): 43 | """Safely cancel an asyncio task""" 44 | if task and not task.done(): 45 | task.cancel() 46 | try: 47 | await task 48 | except asyncio.CancelledError: 49 | logger.debug(f"✅ {task_name} cancelled successfully") 50 | except Exception as e: 51 | logger.debug(f"❌ Error cancelling {task_name}: {e}") 52 | else: 53 | logger.debug(f"✅ {task_name} already done or None") 54 | 55 | @staticmethod 56 | async def safe_component_stop(component: Any, component_name: str = "component"): 57 | """Safely stop a component with stop() method""" 58 | if component and hasattr(component, 'stop'): 59 | try: 60 | await component.stop() 61 | logger.debug(f"✅ {component_name} stopped") 62 | except Exception as e: 63 | logger.debug(f"❌ Error stopping {component_name}: {e}") 64 | else: 65 | logger.debug(f"✅ {component_name} has no stop method or is None") 66 | 67 | @staticmethod 68 | async def safe_component_close(component: Any, component_name: str = "component"): 69 | """Safely close a component with close() method""" 70 | if component and hasattr(component, 'close'): 71 | try: 72 | await component.close() 73 | logger.debug(f"✅ {component_name} closed") 74 | except Exception as e: 75 | logger.debug(f"❌ Error closing {component_name}: {e}") 76 | else: 77 | logger.debug(f"✅ {component_name} has no close method or is None") 78 | 79 | @staticmethod 80 | def create_timeout_handler(timeout: float, operation_name: str = "operation"): 81 | """Create a timeout handler for operations""" 82 | async def timeout_handler(): 83 | await asyncio.sleep(timeout) 84 | logger.warning(f"⏰ 
{operation_name} timed out after {timeout}s") 85 | return timeout_handler 86 | 87 | @staticmethod 88 | async def execute_with_timeout(coro, timeout: float, operation_name: str = "operation", 89 | default_return: Any = None): 90 | """Execute a coroutine with timeout and error handling""" 91 | try: 92 | return await asyncio.wait_for(coro, timeout=timeout) 93 | except asyncio.TimeoutError: 94 | logger.warning(f"⏰ {operation_name} timed out after {timeout}s") 95 | return default_return 96 | except Exception as e: 97 | logger.error(f"❌ {operation_name} failed: {e}") 98 | return default_return 99 | -------------------------------------------------------------------------------- /common/p2p_transport.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import time 4 | from typing import Optional, Dict, Any, Callable, List 5 | from p2pd import P2PNode 6 | from common.utils import get_logger 7 | 8 | logger = get_logger(__name__) 9 | 10 | class P2PTransport: 11 | """P2PD transport layer for NAT traversal""" 12 | 13 | def __init__(self, node_id: str, model_name: str = None): 14 | self.node_id = node_id 15 | self.model_name = model_name 16 | self.p2p_node: Optional[P2PNode] = None 17 | self.message_callbacks = [] 18 | self.nickname = None 19 | self.running = False 20 | 21 | async def start(self, port: int = None): 22 | """Start P2P node""" 23 | try: 24 | self.p2p_node = await P2PNode() 25 | self.p2p_node.add_msg_cb(self._handle_message) 26 | 27 | # Create unique nickname based on model and node_id 28 | if self.model_name: 29 | self.nickname = f"{self.model_name}-{self.node_id[:8]}" 30 | else: 31 | self.nickname = f"client-{self.node_id[:8]}" 32 | 33 | # Note: p2pd library doesn't support nickname registration 34 | # Nickname is used for identification in connection attempts 35 | logger.info(f"P2P node started with identifier: {self.nickname}") 36 | 37 | self.running = True 38 | logger.info(f"P2P transport started for node: {self.node_id}") 39 | 40 | except Exception as e: 41 | logger.error(f"Failed to start P2P transport: {e}") 42 | raise 43 | 44 | async def connect_to_peer(self, peer_nickname: str, timeout: float = 10.0) -> Optional[Any]: 45 | """Connect to a peer using P2PD with timeout""" 46 | if not self.p2p_node: 47 | raise RuntimeError("P2P node not started") 48 | 49 | try: 50 | # Try to connect with timeout 51 | # Note: p2pd may require peer ID instead of nickname 52 | pipe = await asyncio.wait_for( 53 | self.p2p_node.connect(peer_nickname), 54 | timeout=timeout 55 | ) 56 | logger.info(f"P2P connected to peer: {peer_nickname}") 57 | return pipe 58 | except asyncio.TimeoutError: 59 | logger.warning(f"P2P connection timeout to peer: {peer_nickname}") 60 | return None 61 | except AttributeError as e: 62 | logger.error(f"P2P method not available: {e}") 63 | return None 64 | except Exception as e: 65 | logger.error(f"Failed to connect to peer {peer_nickname}: {e}") 66 | return None 67 | 68 | async def send_message(self, pipe, message: bytes): 69 | """Send message through P2P pipe""" 70 | try: 71 | await pipe.send(message) 72 | except Exception as e: 73 | logger.error(f"Failed to send P2P message: {e}") 74 | raise 75 | 76 | def add_message_callback(self, callback: Callable): 77 | """Add callback for incoming messages""" 78 | self.message_callbacks.append(callback) 79 | 80 | async def _handle_message(self, msg: bytes, client_tup, pipe): 81 | """Handle incoming P2P messages""" 82 | for callback in self.message_callbacks: 83 | try: 
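# Each registered callback receives the raw message bytes, the sender address tuple, and the originating pipe.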
84 | await callback(msg, client_tup, pipe) 85 | except Exception as e: 86 | logger.error(f"Error in P2P message callback: {e}") 87 | 88 | def get_address_info(self) -> Dict[str, Any]: 89 | """Get P2P address information""" 90 | if not self.p2p_node: 91 | return {} 92 | 93 | return { 94 | "nickname": self.nickname, 95 | "p2p_enabled": True, 96 | "supports_nat_traversal": True 97 | } 98 | 99 | async def close(self): 100 | """Close P2P transport""" 101 | self.running = False 102 | if self.p2p_node: 103 | try: 104 | await self.p2p_node.close() 105 | logger.info("P2P transport closed") 106 | except Exception as e: 107 | logger.error(f"Error closing P2P transport: {e}") 108 | -------------------------------------------------------------------------------- /common/metrics_manager.py: -------------------------------------------------------------------------------- 1 | import time 2 | import psutil 3 | from typing import Dict, Any, Optional 4 | from common.utils import get_logger 5 | 6 | logger = get_logger(__name__) 7 | 8 | class MetricsManager: 9 | """Centralized metrics collection and management""" 10 | 11 | def __init__(self): 12 | self.start_time = time.time() 13 | self.total_tokens_generated = 0 14 | self.total_generation_time = 0 15 | self.request_count = 0 16 | self.active_requests = 0 17 | 18 | def get_comprehensive_metrics(self) -> Dict[str, Any]: 19 | """Get all metrics in one call""" 20 | return { 21 | **self.get_performance_metrics(), 22 | **self.get_system_metrics(), 23 | **self.get_request_metrics() 24 | } 25 | 26 | def get_performance_metrics(self) -> Dict[str, Any]: 27 | """Get performance-related metrics""" 28 | uptime = int(time.time() - self.start_time) 29 | tps = self._calculate_tps() 30 | load = self._calculate_load() 31 | 32 | return { 33 | "uptime": uptime, 34 | "load": round(load, 2), 35 | "tps": round(tps, 2), 36 | "total_tokens": self.total_tokens_generated 37 | } 38 | 39 | def get_system_metrics(self) -> Dict[str, Any]: 40 | """Get system-related metrics""" 41 | try: 42 | cpu_percent = psutil.cpu_percent(interval=0) 43 | memory_percent = psutil.virtual_memory().percent 44 | 45 | try: 46 | disk_percent = psutil.disk_usage('/').percent 47 | except (OSError, PermissionError): 48 | try: 49 | disk_percent = psutil.disk_usage('.').percent 50 | except: 51 | disk_percent = 0.0 52 | 53 | return { 54 | "cpu_percent": cpu_percent, 55 | "memory_percent": memory_percent, 56 | "disk_percent": disk_percent 57 | } 58 | except Exception as e: 59 | logger.error(f"Error getting system metrics: {e}") 60 | return { 61 | "cpu_percent": 0.0, 62 | "memory_percent": 0.0, 63 | "disk_percent": 0.0 64 | } 65 | 66 | def get_request_metrics(self) -> Dict[str, Any]: 67 | """Get request-related metrics""" 68 | uptime = time.time() - self.start_time 69 | return { 70 | "active_requests": self.active_requests, 71 | "total_requests": self.request_count, 72 | "requests_per_second": self.request_count / uptime if uptime > 0 else 0 73 | } 74 | 75 | def _calculate_load(self) -> float: 76 | """Calculate current load""" 77 | return min(1.0, self.active_requests / 10) 78 | 79 | def _calculate_tps(self) -> float: 80 | """Calculate tokens per second""" 81 | if self.total_generation_time > 0: 82 | return self.total_tokens_generated / self.total_generation_time 83 | return 0.0 84 | 85 | def record_request_start(self): 86 | """Record request start""" 87 | self.active_requests += 1 88 | self.request_count += 1 89 | 90 | def record_request_end(self, tokens_generated: int = 0, generation_time: float = 0): 91 | """Record 
request completion""" 92 | self.active_requests = max(0, self.active_requests - 1) 93 | self.total_tokens_generated += tokens_generated 94 | self.total_generation_time += generation_time 95 | 96 | def is_overloaded(self, cpu_threshold: float = 80.0, 97 | memory_threshold: float = 85.0, 98 | max_requests: int = 10) -> Dict[str, Any]: 99 | """Check if node is overloaded""" 100 | system_metrics = self.get_system_metrics() 101 | 102 | overload_reasons = [] 103 | 104 | if system_metrics["cpu_percent"] > cpu_threshold: 105 | overload_reasons.append(f"CPU: {system_metrics['cpu_percent']:.1f}% > {cpu_threshold}%") 106 | 107 | if system_metrics["memory_percent"] > memory_threshold: 108 | overload_reasons.append(f"Memory: {system_metrics['memory_percent']:.1f}% > {memory_threshold}%") 109 | 110 | if self.active_requests >= max_requests: 111 | overload_reasons.append(f"Requests: {self.active_requests} >= {max_requests}") 112 | 113 | return { 114 | "is_overloaded": len(overload_reasons) > 0, 115 | "overload_reasons": overload_reasons, 116 | "load_metrics": system_metrics, 117 | "active_requests": self.active_requests, 118 | "thresholds": { 119 | "cpu": cpu_threshold, 120 | "memory": memory_threshold, 121 | "max_requests": max_requests 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /examples/simple_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from client.event_aware_client import EventAwareOpenAIClient 3 | 4 | async def node_change_callback(event): 5 | """Callback for real-time node events""" 6 | print(f"🔔 Network event: {event.event_type.value}") 7 | if event.node_info: 8 | print(f" Node: {event.node_info.node_id[:12]}... at {event.node_info.ip}:{event.node_info.port}") 9 | 10 | async def main(): 11 | # Create an event-aware OpenAI-compatible client with real-time discovery 12 | client = EventAwareOpenAIClient( 13 | bootstrap_nodes="localhost:8001", # Connect to bootstrap node 14 | model="llamanet", # Use the default model name 15 | event_callback=node_change_callback # Get real-time notifications 16 | ) 17 | 18 | try: 19 | # Start the client and wait for nodes 20 | print("🚀 Starting event-aware client...") 21 | await client.start() 22 | 23 | print("⏳ Waiting for nodes to be discovered...") 24 | nodes_found = await client.wait_for_nodes(min_nodes=1, timeout=30.0) 25 | 26 | if not nodes_found: 27 | print("❌ No nodes found within timeout") 28 | return 29 | 30 | # Show real-time network stats 31 | stats = client.get_real_time_stats() 32 | print(f"\n📊 Real-time network stats:") 33 | print(f" Total nodes: {stats['total_nodes']}") 34 | print(f" Network health: {stats['network_health']}") 35 | print(f" Average load: {stats.get('avg_load', 0):.3f}") 36 | print(f" Total capacity: {stats.get('total_capacity', 0):.2f} TPS") 37 | 38 | # Show available nodes (real-time, no polling) 39 | nodes = await client.get_available_nodes() 40 | print(f"\n🌐 Available nodes: {len(nodes)}") 41 | 42 | for node in nodes: 43 | print(f" • {node.node_id[:8]}... 
({node.ip}:{node.port}) - {node.model} - Load: {node.load:.2f}") 44 | 45 | # Test chat completions (OpenAI format) 46 | print("\n🤖 Testing chat completions...") 47 | messages = [ 48 | {"role": "system", "content": "You are a helpful assistant."}, 49 | {"role": "user", "content": "What is LlamaNet and how does it work?"} 50 | ] 51 | 52 | response = await client.chat_completions( 53 | messages=messages, 54 | max_tokens=150, 55 | temperature=0.7, 56 | strategy="round_robin" 57 | ) 58 | 59 | if response: 60 | print(f"✅ Chat completion response:") 61 | print(f" ID: {response.id}") 62 | print(f" Model: {response.model}") 63 | print(f" Content: {response.choices[0].message.content}") 64 | print(f" Tokens used: {response.usage.total_tokens}") 65 | else: 66 | print("❌ No response received") 67 | 68 | # Test text completions (OpenAI format) 69 | print("\n📝 Testing text completions...") 70 | completion_response = await client.completions( 71 | prompt="LlamaNet is a decentralized inference network that", 72 | max_tokens=100, 73 | temperature=0.7, 74 | strategy="round_robin" 75 | ) 76 | 77 | if completion_response: 78 | print(f"✅ Text completion response:") 79 | print(f" ID: {completion_response.id}") 80 | print(f" Text: {completion_response.choices[0].text}") 81 | print(f" Tokens used: {completion_response.usage.total_tokens}") 82 | else: 83 | print("❌ No completion response received") 84 | 85 | # Test round robin distribution with real-time updates 86 | print("\n🔄 Testing round robin distribution...") 87 | for i in range(6): 88 | print(f"\n--- Request {i+1} ---") 89 | 90 | # Show current network state 91 | current_stats = client.get_real_time_stats() 92 | print(f"Current nodes: {current_stats['total_nodes']}") 93 | 94 | response = await client.chat_completions( 95 | messages=[{"role": "user", "content": f"Hello {i+1}"}], 96 | max_tokens=50, 97 | strategy="round_robin" 98 | ) 99 | if response: 100 | print(f"Response from: {response.id}") 101 | 102 | # Small delay to see any real-time changes 103 | await asyncio.sleep(2) 104 | 105 | # Demonstrate real-time monitoring 106 | print("\n⏱️ Monitoring network for 10 seconds...") 107 | print(" (Try starting/stopping nodes to see real-time updates)") 108 | 109 | for i in range(10): 110 | await asyncio.sleep(1) 111 | stats = client.get_real_time_stats() 112 | print(f" Nodes: {stats['total_nodes']}, Health: {stats['network_health']}") 113 | 114 | finally: 115 | await client.close() 116 | 117 | if __name__ == "__main__": 118 | asyncio.run(main()) 119 | -------------------------------------------------------------------------------- /client/router.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Optional, Dict, Any, Callable 3 | from common.models import NodeInfo 4 | from client.dht_discovery import DHTDiscovery 5 | from common.utils import get_logger 6 | 7 | logger = get_logger(__name__) 8 | 9 | class NodeSelector: 10 | """Select the best node for inference""" 11 | 12 | def __init__(self, dht_discovery: DHTDiscovery): 13 | self.dht_discovery = dht_discovery 14 | self.round_robin_index = 0 # Track round robin position 15 | 16 | async def select_node(self, 17 | model: Optional[str] = None, 18 | min_tps: float = 0.0, 19 | max_load: float = 1.0, 20 | strategy: str = "round_robin", # "load_balanced", "round_robin", "random" 21 | randomize: bool = True, 22 | target_model: Optional[str] = None) -> Optional[NodeInfo]: 23 | """Select the best node based on criteria and strategy""" 24 | 25 | 
# Use target_model if specified, otherwise fall back to model parameter 26 | model_filter = target_model or model 27 | 28 | # Get nodes from event-based discovery (real-time, no polling) 29 | nodes = await self.dht_discovery.get_nodes(model=model_filter) 30 | 31 | if not nodes: 32 | logger.warning(f"No nodes available for model {model_filter}") 33 | return None 34 | 35 | # If target_model is specified, ensure we only get nodes with that exact model 36 | if target_model: 37 | nodes = [node for node in nodes if node.model == target_model] 38 | if not nodes: 39 | logger.warning(f"No nodes found running the specific model: {target_model}") 40 | return None 41 | 42 | # Log available nodes for debugging 43 | logger.debug(f"Available nodes for selection (model: {model_filter}):") 44 | for node in nodes: 45 | logger.debug(f" - {node.node_id[:8]}... at {node.ip}:{node.port} (model: {node.model}, load: {node.load})") 46 | 47 | # Filter by criteria (be more lenient for DHT contacts with unknown metrics) 48 | eligible_nodes = [] 49 | for node in nodes: 50 | # For nodes with unknown metrics (DHT contacts), be more permissive 51 | if node.model == "unknown" or node.tps == 0.0: 52 | eligible_nodes.append(node) # Include DHT contacts regardless of metrics 53 | elif node.tps >= min_tps and node.load <= max_load: 54 | eligible_nodes.append(node) 55 | 56 | if not eligible_nodes: 57 | logger.warning(f"No nodes meet criteria (min_tps={min_tps}, max_load={max_load}) for model {model_filter}") 58 | # Fall back to any available node with the target model 59 | eligible_nodes = nodes 60 | 61 | # Apply selection strategy with logging 62 | selected = None 63 | if strategy == "round_robin": 64 | selected = self._round_robin_select(eligible_nodes) 65 | elif strategy == "random": 66 | selected = random.choice(eligible_nodes) 67 | elif strategy == "load_balanced": 68 | selected = self._load_balanced_select(eligible_nodes, randomize) 69 | else: 70 | logger.warning(f"Unknown strategy {strategy}, using round_robin") 71 | selected = self._round_robin_select(eligible_nodes) 72 | 73 | if selected: 74 | logger.info(f"🎯 Selected node {selected.node_id[:8]}... at {selected.ip}:{selected.port} (model: {selected.model}) via {strategy} strategy") 75 | 76 | return selected 77 | 78 | def _round_robin_select(self, nodes: List[NodeInfo]) -> NodeInfo: 79 | """Select node using true round robin with persistent state""" 80 | if not nodes: 81 | return None 82 | 83 | # Sort nodes by node_id for consistent ordering across all clients 84 | sorted_nodes = sorted(nodes, key=lambda n: n.node_id) 85 | 86 | # Use modulo arithmetic for true round-robin 87 | selected_node = sorted_nodes[self.round_robin_index % len(sorted_nodes)] 88 | 89 | # Increment counter for next selection 90 | self.round_robin_index = (self.round_robin_index + 1) % len(sorted_nodes) 91 | 92 | logger.debug(f"Round robin selected node {selected_node.node_id[:8]}... 
(index: {self.round_robin_index - 1})") 93 | return selected_node 94 | 95 | def _load_balanced_select(self, nodes: List[NodeInfo], randomize: bool) -> NodeInfo: 96 | """Select node using load balancing (original logic)""" 97 | # Sort by load (ascending) 98 | nodes.sort(key=lambda n: n.load) 99 | 100 | # Get the best nodes (those with similar load) 101 | best_load = nodes[0].load 102 | best_nodes = [n for n in nodes if n.load <= best_load + 0.1] 103 | 104 | # Randomize if requested 105 | if randomize and len(best_nodes) > 1: 106 | return random.choice(best_nodes) 107 | 108 | # Return the first (lowest load) 109 | return best_nodes[0] 110 | -------------------------------------------------------------------------------- /inference_node/heartbeat.py: -------------------------------------------------------------------------------- 1 | import time 2 | import threading 3 | import requests 4 | from typing import Dict, Any, Optional 5 | import json 6 | from common.utils import get_logger, get_host_ip 7 | from common.models import NodeInfo 8 | from inference_node.config import InferenceConfig 9 | 10 | logger = get_logger(__name__) 11 | 12 | class HeartbeatSender: 13 | """Send heartbeats to the registry""" 14 | 15 | def __init__(self, config: InferenceConfig, metrics_callback): 16 | self.config = config 17 | self.metrics_callback = metrics_callback 18 | self.running = False 19 | self.thread = None 20 | self.node_info = NodeInfo( 21 | node_id=config.node_id, 22 | ip=get_host_ip(), 23 | port=config.port, 24 | model=config.model_name 25 | ) 26 | 27 | def start(self): 28 | """Start sending heartbeats""" 29 | if self.running: 30 | return 31 | 32 | self.running = True 33 | self.thread = threading.Thread(target=self._heartbeat_loop) 34 | self.thread.daemon = True 35 | self.thread.start() 36 | logger.info(f"Started heartbeat to registry at {self.config.registry_url}") 37 | 38 | def stop(self): 39 | """Stop sending heartbeats""" 40 | self.running = False 41 | if self.thread: 42 | self.thread.join(timeout=1) 43 | 44 | def _heartbeat_loop(self): 45 | """Send heartbeats at regular intervals""" 46 | while self.running: 47 | try: 48 | self._send_heartbeat() 49 | except Exception as e: 50 | logger.error(f"Failed to send heartbeat: {e}") 51 | 52 | time.sleep(self.config.heartbeat_interval) 53 | 54 | def _send_heartbeat(self): 55 | """Send a single heartbeat to the registry""" 56 | # Get current metrics 57 | metrics = self.metrics_callback() 58 | 59 | # Update node info 60 | self.node_info.load = metrics["load"] 61 | self.node_info.tps = metrics["tps"] 62 | self.node_info.uptime = metrics["uptime"] 63 | self.node_info.last_seen = int(time.time()) 64 | 65 | # Send to registry 66 | response = requests.post( 67 | f"{self.config.registry_url}/register", 68 | data=self.node_info.json(), 69 | headers={"Content-Type": "application/json"} 70 | ) 71 | 72 | if response.status_code != 200: 73 | logger.warning(f"Registry returned status {response.status_code}: {response.text}") 74 | import asyncio 75 | import time 76 | from typing import Dict, Any, Callable 77 | from common.utils import get_logger 78 | 79 | logger = get_logger(__name__) 80 | 81 | class HeartbeatManager: 82 | """Manages node health monitoring and heartbeat signals""" 83 | 84 | def __init__(self, node_id: str, metrics_callback: Callable[[], Dict[str, Any]], interval: int = 10): 85 | self.node_id = node_id 86 | self.metrics_callback = metrics_callback 87 | self.interval = interval 88 | self.running = False 89 | self.heartbeat_task = None 90 | self.last_heartbeat = 
0 91 | 92 | async def start(self): 93 | """Start the heartbeat system""" 94 | if self.running: 95 | return 96 | 97 | self.running = True 98 | self.heartbeat_task = asyncio.create_task(self._heartbeat_loop()) 99 | logger.info(f"Heartbeat manager started for node {self.node_id[:8]}...") 100 | 101 | async def stop(self): 102 | """Stop the heartbeat system""" 103 | self.running = False 104 | if self.heartbeat_task: 105 | self.heartbeat_task.cancel() 106 | try: 107 | await self.heartbeat_task 108 | except asyncio.CancelledError: 109 | pass 110 | logger.info("Heartbeat manager stopped") 111 | 112 | async def _heartbeat_loop(self): 113 | """Main heartbeat loop""" 114 | while self.running: 115 | try: 116 | await self._send_heartbeat() 117 | await asyncio.sleep(self.interval) 118 | except asyncio.CancelledError: 119 | break 120 | except Exception as e: 121 | logger.error(f"Error in heartbeat loop: {e}") 122 | await asyncio.sleep(5) # Wait before retrying 123 | 124 | async def _send_heartbeat(self): 125 | """Send a heartbeat signal""" 126 | current_time = time.time() 127 | metrics = self.metrics_callback() 128 | 129 | # Update last heartbeat time 130 | self.last_heartbeat = current_time 131 | 132 | # Log heartbeat (in production, this would send to monitoring system) 133 | logger.debug(f"💓 Heartbeat from {self.node_id[:8]}... - Load: {metrics.get('load', 0):.2f}, TPS: {metrics.get('tps', 0):.2f}") 134 | 135 | def is_healthy(self) -> bool: 136 | """Check if the node is healthy based on recent heartbeats""" 137 | if self.last_heartbeat == 0: 138 | return False 139 | return time.time() - self.last_heartbeat < (self.interval * 3) # Allow 3 missed heartbeats 140 | 141 | def get_health_status(self) -> Dict[str, Any]: 142 | """Get detailed health status""" 143 | current_time = time.time() 144 | time_since_last = current_time - self.last_heartbeat if self.last_heartbeat > 0 else float('inf') 145 | 146 | return { 147 | "healthy": self.is_healthy(), 148 | "running": self.running, 149 | "last_heartbeat": self.last_heartbeat, 150 | "time_since_last_heartbeat": time_since_last, 151 | "heartbeat_interval": self.interval 152 | } 153 | -------------------------------------------------------------------------------- /start-app.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # LlamaNet OpenAI-Compatible Inference Node Startup Script 4 | # This script handles deployment on MACHAAO platform and local development 5 | 6 | set -e 7 | 8 | echo "🚀 Starting LlamaNet OpenAI-Compatible Inference Node..." 9 | 10 | # Check if we're in a containerized environment 11 | if [ -d "/app" ] && [ "$(pwd)" = "/app" ]; then 12 | echo "📦 Running in containerized environment" 13 | CONTAINER_MODE=true 14 | else 15 | echo "💻 Running in local development mode" 16 | CONTAINER_MODE=false 17 | fi 18 | 19 | # Set default values 20 | DEFAULT_MODEL_PATH="${MODEL_PATH:-./models/model.gguf}" 21 | DEFAULT_HOST="${HOST:-0.0.0.0}" 22 | DEFAULT_PORT="${PORT:-8000}" 23 | DEFAULT_DHT_PORT="${DHT_PORT:-8001}" 24 | DEFAULT_NODE_ID="${NODE_ID:-}" 25 | DEFAULT_BOOTSTRAP_NODES="${BOOTSTRAP_NODES:-}" 26 | 27 | # Suppress Python semaphore warnings for cleaner output 28 | export PYTHONWARNINGS="ignore:semaphore:UserWarning:multiprocessing.resource_tracker,ignore:resource_tracker" 29 | export PYTHONDONTWRITEBYTECODE=1 30 | 31 | # Validate model file exists 32 | if [ ! 
-f "$DEFAULT_MODEL_PATH" ]; then 33 | echo "❌ Error: Model file not found at $DEFAULT_MODEL_PATH" 34 | echo "Please set MODEL_PATH environment variable or place model at ./models/model.gguf" 35 | exit 1 36 | fi 37 | 38 | echo "✅ Model file found: $DEFAULT_MODEL_PATH" 39 | 40 | # Check if Python dependencies are installed 41 | if ! python -c "import fastapi, uvicorn, llama_cpp" 2>/dev/null; then 42 | echo "📦 Installing Python dependencies..." 43 | if [ -f "requirements.txt" ]; then 44 | pip install -r requirements.txt 45 | else 46 | echo "❌ Error: requirements.txt not found" 47 | exit 1 48 | fi 49 | fi 50 | 51 | # Install package in development mode if not already installed 52 | if ! python -c "import inference_node" 2>/dev/null; then 53 | echo "📦 Installing LlamaNet package..." 54 | pip install -e . 55 | fi 56 | 57 | # Health check endpoint 58 | health_check() { 59 | local port=$1 60 | local max_attempts=30 61 | local attempt=1 62 | 63 | echo "🔍 Waiting for service to be ready on port $port..." 64 | 65 | while [ $attempt -le $max_attempts ]; do 66 | if curl -s "http://localhost:$port/health" >/dev/null 2>&1; then 67 | echo "✅ Service is ready!" 68 | return 0 69 | fi 70 | 71 | echo "⏳ Attempt $attempt/$max_attempts - waiting for service..." 72 | sleep 2 73 | attempt=$((attempt + 1)) 74 | done 75 | 76 | echo "❌ Service failed to start within expected time" 77 | return 1 78 | } 79 | 80 | # Signal handler for graceful shutdown 81 | cleanup() { 82 | echo "🛑 Received shutdown signal, stopping LlamaNet node..." 83 | if [ ! -z "$SERVER_PID" ]; then 84 | echo "📤 Sending SIGTERM to server process $SERVER_PID..." 85 | # Send SIGTERM and let the application handle graceful shutdown 86 | kill -TERM $SERVER_PID 2>/dev/null || true 87 | 88 | # Wait for graceful shutdown with appropriate timeout 89 | echo "⏳ Waiting for graceful shutdown (max 10 seconds)..." 90 | for i in $(seq 1 10); do 91 | if ! kill -0 $SERVER_PID 2>/dev/null; then 92 | echo "✅ Server shut down gracefully" 93 | exit 0 94 | fi 95 | sleep 1 96 | done 97 | 98 | # Send SIGINT if still running 99 | echo "⚠️ Sending SIGINT for faster shutdown..." 100 | kill -INT $SERVER_PID 2>/dev/null || true 101 | 102 | # Wait a bit more 103 | for i in $(seq 1 3); do 104 | if ! kill -0 $SERVER_PID 2>/dev/null; then 105 | echo "✅ Server shut down after SIGINT" 106 | exit 0 107 | fi 108 | sleep 1 109 | done 110 | 111 | # Force kill if still running 112 | echo "⚠️ Forcing server shutdown..." 113 | kill -KILL $SERVER_PID 2>/dev/null || true 114 | fi 115 | exit 0 116 | } 117 | 118 | # Set up signal traps - only trap in shell script, not in Python 119 | trap cleanup SIGINT SIGTERM 120 | 121 | # Build command line arguments 122 | ARGS="--model-path $DEFAULT_MODEL_PATH" 123 | ARGS="$ARGS --host $DEFAULT_HOST" 124 | ARGS="$ARGS --port $DEFAULT_PORT" 125 | ARGS="$ARGS --dht-port $DEFAULT_DHT_PORT" 126 | 127 | if [ -n "$DEFAULT_NODE_ID" ]; then 128 | ARGS="$ARGS --node-id $DEFAULT_NODE_ID" 129 | fi 130 | 131 | if [ -n "$DEFAULT_BOOTSTRAP_NODES" ]; then 132 | ARGS="$ARGS --bootstrap-nodes $DEFAULT_BOOTSTRAP_NODES" 133 | fi 134 | 135 | echo "🔧 Configuration:" 136 | echo " Model: $DEFAULT_MODEL_PATH" 137 | echo " Host: $DEFAULT_HOST" 138 | echo " HTTP Port: $DEFAULT_PORT" 139 | echo " DHT Port: $DEFAULT_DHT_PORT" 140 | echo " Node ID: ${DEFAULT_NODE_ID:-auto-generated}" 141 | echo " Bootstrap Nodes: ${DEFAULT_BOOTSTRAP_NODES:-none (bootstrap mode)}" 142 | 143 | # Start the inference node 144 | echo "🚀 Starting inference node with OpenAI-compatible API..." 
145 | echo "📡 API will be available at: http://$DEFAULT_HOST:$DEFAULT_PORT" 146 | echo "🌐 Web UI will be available at: http://$DEFAULT_HOST:$DEFAULT_PORT" 147 | echo "🔗 OpenAI-compatible endpoints:" 148 | echo " - GET /v1/models" 149 | echo " - POST /v1/completions" 150 | echo " - POST /v1/chat/completions" 151 | 152 | # Start the server in background for health check 153 | python -m inference_node.server $ARGS & 154 | SERVER_PID=$! 155 | 156 | # Wait for service to be ready 157 | if health_check $DEFAULT_PORT; then 158 | echo "🎉 LlamaNet OpenAI-Compatible Inference Node is running!" 159 | echo "📊 Monitor network status: python -m tools.monitor" 160 | echo "🔍 Quick network check: python -m tools.quick_check" 161 | echo "🛑 Press Ctrl+C for graceful shutdown" 162 | 163 | # Keep the server running in foreground with proper signal handling 164 | wait $SERVER_PID 165 | exit_code=$? 166 | 167 | if [ $exit_code -eq 0 ]; then 168 | echo "✅ Server exited gracefully" 169 | else 170 | echo "❌ Server exited with code $exit_code" 171 | fi 172 | 173 | exit $exit_code 174 | else 175 | echo "❌ Failed to start service" 176 | kill $SERVER_PID 2>/dev/null || true 177 | exit 1 178 | fi 179 | -------------------------------------------------------------------------------- /inference_node/dht_publisher.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | from typing import Dict, Any 4 | from inference_node.event_publisher import EventBasedDHTPublisher 5 | from common.utils import get_logger 6 | 7 | logger = get_logger(__name__) 8 | 9 | class HardwareBasedDHTPublisher(EventBasedDHTPublisher): 10 | """Hardware-based DHT publisher with enhanced validation and consistency checks""" 11 | 12 | def __init__(self, config, metrics_callback): 13 | super().__init__(config, metrics_callback) 14 | 15 | # Additional hardware-specific tracking 16 | self.hardware_validation_interval = 300 # 5 minutes 17 | self.last_hardware_validation = 0 18 | self.hardware_change_count = 0 19 | 20 | logger.info(f"Hardware-based DHT publisher initialized for node: {config.node_id[:16]}...") 21 | 22 | async def start(self): 23 | """Start the hardware-based DHT publisher with enhanced validation""" 24 | # Perform initial hardware validation 25 | await self._perform_initial_hardware_validation() 26 | 27 | # Start the base publisher 28 | await super().start() 29 | 30 | logger.info("Hardware-based DHT publisher started with validated node ID") 31 | 32 | async def _perform_initial_hardware_validation(self): 33 | """Perform comprehensive hardware validation before starting""" 34 | if not self.hardware_fingerprint: 35 | logger.warning("Hardware fingerprint not available for validation") 36 | return 37 | 38 | try: 39 | # Check if current node ID matches hardware 40 | expected_node_id = self.hardware_fingerprint.generate_node_id(self.config.port) 41 | 42 | if self.config.node_id != expected_node_id: 43 | logger.warning("Initial hardware validation failed!") 44 | logger.info(f"Expected: {expected_node_id[:16]}...") 45 | logger.info(f"Current: {self.config.node_id[:16]}...") 46 | 47 | # Check if we have a stored node ID that matches 48 | stored_node_id = self.config._get_stored_node_id() if hasattr(self.config, '_get_stored_node_id') else None 49 | 50 | if stored_node_id == expected_node_id: 51 | logger.info("Found matching stored node ID, updating configuration") 52 | self.config.node_id = expected_node_id 53 | elif stored_node_id and stored_node_id != expected_node_id: 54 | 
logger.warning("Stored node ID also doesn't match current hardware") 55 | logger.info("This indicates significant hardware changes") 56 | 57 | # Update to new hardware-based ID 58 | self.config.node_id = expected_node_id 59 | if hasattr(self.config, '_store_node_id'): 60 | self.config._store_node_id(expected_node_id) 61 | 62 | self.hardware_change_count += 1 63 | logger.info(f"Updated to new hardware-based node ID: {expected_node_id[:16]}...") 64 | else: 65 | logger.info("Initial hardware validation passed") 66 | 67 | except Exception as e: 68 | logger.error(f"Error during initial hardware validation: {e}") 69 | 70 | async def _monitor_changes(self): 71 | """Enhanced monitoring with hardware validation""" 72 | while self.running: 73 | try: 74 | current_metrics = self.metrics_callback() 75 | current_time = time.time() 76 | 77 | # Check if metrics changed significantly 78 | should_update = self._should_update_metrics(current_metrics) 79 | 80 | # Periodic hardware validation 81 | if current_time - self.last_hardware_validation > self.hardware_validation_interval: 82 | await self._validate_hardware_consistency() 83 | self.last_hardware_validation = current_time 84 | 85 | if should_update: 86 | await self._publish_node_info() 87 | self.last_published_metrics = current_metrics.copy() 88 | 89 | await asyncio.sleep(5) # Check every 5 seconds 90 | 91 | except asyncio.CancelledError: 92 | break 93 | except Exception as e: 94 | logger.error(f"Error monitoring changes: {e}") 95 | await asyncio.sleep(10) 96 | 97 | async def _validate_hardware_consistency(self): 98 | """Periodic hardware consistency validation""" 99 | if not self.hardware_fingerprint: 100 | return 101 | 102 | try: 103 | # Regenerate fingerprint to check for changes 104 | from common.hardware_fingerprint import HardwareFingerprint 105 | fresh_fingerprint = HardwareFingerprint() 106 | 107 | current_node_id = fresh_fingerprint.generate_node_id(self.config.port) 108 | 109 | if current_node_id != self.config.node_id: 110 | logger.warning("Hardware consistency check failed!") 111 | logger.info("Hardware may have changed during runtime") 112 | 113 | # Handle the hardware change 114 | await self.handle_hardware_change() 115 | 116 | # Update our fingerprint reference 117 | self.hardware_fingerprint = fresh_fingerprint 118 | self.hardware_change_count += 1 119 | else: 120 | logger.debug("Periodic hardware consistency check passed") 121 | 122 | except Exception as e: 123 | logger.error(f"Error during hardware consistency validation: {e}") 124 | 125 | def get_hardware_stats(self) -> Dict[str, Any]: 126 | """Get hardware-related statistics""" 127 | stats = self.get_node_info() 128 | 129 | stats.update({ 130 | 'hardware_validation_interval': self.hardware_validation_interval, 131 | 'last_hardware_validation': self.last_hardware_validation, 132 | 'hardware_change_count': self.hardware_change_count, 133 | 'validation_enabled': self.hardware_fingerprint is not None 134 | }) 135 | 136 | return stats 137 | 138 | # For backward compatibility, alias the event-based version as DHTPublisher 139 | DHTPublisher = HardwareBasedDHTPublisher 140 | -------------------------------------------------------------------------------- /common/hardware_fingerprint.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import platform 3 | import uuid 4 | import psutil 5 | import socket 6 | import os 7 | import atexit 8 | import threading 9 | from typing import Dict, List, Optional 10 | from common.utils import get_logger 
11 | 12 | logger = get_logger(__name__) 13 | 14 | # Global cleanup for hardware fingerprint 15 | _fingerprint_cache = {} 16 | _cache_lock = threading.Lock() 17 | 18 | def _cleanup_fingerprint_cache(): 19 | """Clean up fingerprint cache""" 20 | with _cache_lock: 21 | _fingerprint_cache.clear() 22 | 23 | atexit.register(_cleanup_fingerprint_cache) 24 | 25 | class HardwareFingerprint: 26 | """Generate consistent hardware-based fingerprints for node identification""" 27 | 28 | def __init__(self): 29 | self.fingerprint_data = {} 30 | self._collect_hardware_info() 31 | 32 | def _collect_hardware_info(self) -> None: 33 | """Collect hardware information for fingerprinting""" 34 | try: 35 | # CPU information 36 | self.fingerprint_data['cpu_count'] = psutil.cpu_count(logical=False) 37 | self.fingerprint_data['cpu_count_logical'] = psutil.cpu_count(logical=True) 38 | 39 | # Memory information (total RAM in GB, rounded) 40 | memory_gb = round(psutil.virtual_memory().total / (1024**3)) 41 | self.fingerprint_data['memory_gb'] = memory_gb 42 | 43 | # Platform information 44 | self.fingerprint_data['platform'] = platform.platform() 45 | self.fingerprint_data['machine'] = platform.machine() 46 | self.fingerprint_data['processor'] = platform.processor() 47 | 48 | # Network interfaces (MAC addresses) 49 | self.fingerprint_data['mac_addresses'] = self._get_mac_addresses() 50 | 51 | # Hostname (as fallback) 52 | self.fingerprint_data['hostname'] = socket.gethostname() 53 | 54 | logger.debug(f"Collected hardware fingerprint data: {self._sanitize_for_log()}") 55 | 56 | except Exception as e: 57 | logger.warning(f"Error collecting hardware info: {e}") 58 | # Fallback to minimal info 59 | self.fingerprint_data = { 60 | 'hostname': socket.gethostname(), 61 | 'platform': platform.platform(), 62 | 'fallback': True 63 | } 64 | 65 | def _get_mac_addresses(self) -> List[str]: 66 | """Get MAC addresses from network interfaces with caching""" 67 | cache_key = 'mac_addresses' 68 | 69 | with _cache_lock: 70 | if cache_key in _fingerprint_cache: 71 | return _fingerprint_cache[cache_key] 72 | 73 | mac_addresses = [] 74 | try: 75 | import psutil 76 | for interface, addrs in psutil.net_if_addrs().items(): 77 | for addr in addrs: 78 | if addr.family == psutil.AF_LINK: # MAC address 79 | mac = addr.address 80 | # Filter out virtual/temporary interfaces 81 | if (mac and mac != '00:00:00:00:00:00' and 82 | not interface.startswith(('veth', 'docker', 'br-', 'lo'))): 83 | mac_addresses.append(mac.upper()) 84 | except Exception as e: 85 | logger.debug(f"Error getting MAC addresses: {e}") 86 | 87 | # Sort for consistency 88 | result = sorted(list(set(mac_addresses))) 89 | 90 | with _cache_lock: 91 | _fingerprint_cache[cache_key] = result 92 | 93 | return result 94 | 95 | def _get_system_uuid(self) -> Optional[str]: 96 | """Get system UUID if available""" 97 | try: 98 | # Try different methods to get system UUID 99 | uuid_sources = [ 100 | '/sys/class/dmi/id/product_uuid', 101 | '/proc/sys/kernel/random/uuid' 102 | ] 103 | 104 | for source in uuid_sources: 105 | try: 106 | with open(source, 'r') as f: 107 | system_uuid = f.read().strip() 108 | if system_uuid and len(system_uuid) > 10: 109 | return system_uuid 110 | except (FileNotFoundError, PermissionError): 111 | continue 112 | 113 | # Fallback to Python's uuid if available 114 | try: 115 | return str(uuid.uuid1()) 116 | except: 117 | return None 118 | 119 | except Exception as e: 120 | logger.debug(f"Error getting system UUID: {e}") 121 | return None 122 | 123 | def 
generate_node_id(self, port: int = None) -> str: 124 | """Generate a consistent hardware-based node ID""" 125 | # Create a deterministic string from hardware info 126 | fingerprint_parts = [] 127 | 128 | # Primary identifiers (most stable) 129 | if self.fingerprint_data.get('mac_addresses'): 130 | fingerprint_parts.extend(self.fingerprint_data['mac_addresses']) 131 | 132 | # Secondary identifiers 133 | fingerprint_parts.extend([ 134 | str(self.fingerprint_data.get('cpu_count', 0)), 135 | str(self.fingerprint_data.get('memory_gb', 0)), 136 | self.fingerprint_data.get('machine', ''), 137 | self.fingerprint_data.get('hostname', '') 138 | ]) 139 | 140 | # Include port for uniqueness when multiple nodes on same hardware 141 | if port: 142 | fingerprint_parts.append(f"port:{port}") 143 | 144 | # Create deterministic hash 145 | fingerprint_string = '|'.join(filter(None, fingerprint_parts)) 146 | 147 | if not fingerprint_string: 148 | logger.warning("No hardware fingerprint data available, using fallback") 149 | fingerprint_string = f"fallback:{socket.gethostname()}:{uuid.uuid4().hex[:8]}" 150 | 151 | # Generate SHA-1 hash for Kademlia compatibility (160-bit) 152 | node_id = hashlib.sha1(fingerprint_string.encode('utf-8')).hexdigest() 153 | 154 | logger.info(f"using hardware-based node ID: {node_id}") 155 | return node_id 156 | 157 | def get_fingerprint_summary(self) -> Dict: 158 | """Get a summary of the hardware fingerprint for debugging""" 159 | return { 160 | 'mac_count': len(self.fingerprint_data.get('mac_addresses', [])), 161 | 'has_system_uuid': bool(self.fingerprint_data.get('system_uuid')), 162 | 'cpu_count': self.fingerprint_data.get('cpu_count'), 163 | 'memory_gb': self.fingerprint_data.get('memory_gb'), 164 | 'platform': self.fingerprint_data.get('platform', '')[:50], # Truncate for readability 165 | 'hostname': self.fingerprint_data.get('hostname'), 166 | 'is_fallback': self.fingerprint_data.get('fallback', False) 167 | } 168 | 169 | def _sanitize_for_log(self) -> Dict: 170 | """Sanitize fingerprint data for logging (hide sensitive info)""" 171 | sanitized = self.fingerprint_data.copy() 172 | 173 | # Mask MAC addresses for privacy 174 | if 'mac_addresses' in sanitized: 175 | sanitized['mac_addresses'] = [f"{mac[:8]}:XX:XX:XX" for mac in sanitized['mac_addresses']] 176 | 177 | # Mask system UUID 178 | if 'system_uuid' in sanitized and sanitized['system_uuid']: 179 | uuid_str = sanitized['system_uuid'] 180 | sanitized['system_uuid'] = f"{uuid_str[:8]}...{uuid_str[-4:]}" 181 | 182 | return sanitized 183 | 184 | def validate_consistency(self, stored_node_id: str, port: int = None) -> bool: 185 | """Validate that the stored node ID matches current hardware""" 186 | current_node_id = self.generate_node_id(port) 187 | is_consistent = stored_node_id == current_node_id 188 | 189 | if not is_consistent: 190 | logger.warning(f"Hardware fingerprint mismatch: stored={stored_node_id[:16]}..., current={current_node_id[:16]}...") 191 | logger.info("This may indicate hardware changes or first run on new hardware") 192 | 193 | return is_consistent 194 | -------------------------------------------------------------------------------- /common/service_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | from typing import Dict, Set, Callable, Optional 4 | from enum import Enum 5 | from dataclasses import dataclass 6 | from common.utils import get_logger 7 | 8 | logger = get_logger(__name__) 9 | 10 | class ServiceState(Enum): 11 
| PENDING = "pending" 12 | INITIALIZING = "initializing" 13 | READY = "ready" 14 | FAILED = "failed" 15 | 16 | @dataclass 17 | class ServiceInfo: 18 | name: str 19 | state: ServiceState 20 | start_time: Optional[float] = None 21 | ready_time: Optional[float] = None 22 | error: Optional[str] = None 23 | dependencies: Set[str] = None 24 | 25 | def __post_init__(self): 26 | if self.dependencies is None: 27 | self.dependencies = set() 28 | 29 | class ServiceInitializationManager: 30 | """Manages service initialization order and tracks readiness for DHT join""" 31 | 32 | def __init__(self): 33 | self.services: Dict[str, ServiceInfo] = {} 34 | self.ready_callbacks: Dict[str, Callable] = {} 35 | self.all_ready_callbacks: list = [] 36 | self.initialization_complete = False 37 | self._lock = asyncio.Lock() 38 | 39 | # Define service dependencies and initialization order 40 | self._define_service_dependencies() 41 | 42 | def _define_service_dependencies(self): 43 | """Define the required services and their dependencies""" 44 | # Core services that must be ready before DHT join 45 | self.register_service("config", dependencies=set()) 46 | self.register_service("llm", dependencies={"config"}) 47 | self.register_service("system_info", dependencies={"config"}) 48 | self.register_service("heartbeat_manager", dependencies={"llm"}) 49 | self.register_service("dht_service", dependencies={"config"}) 50 | self.register_service("dht_publisher", dependencies={"dht_service"}) 51 | self.register_service("dht_discovery", dependencies={"dht_service"}) 52 | self.register_service("sse_handler", dependencies={"dht_discovery"}) 53 | self.register_service("sse_network_monitor", dependencies={"sse_handler"}) 54 | self.register_service("discovery_bridge", dependencies={"sse_handler", "dht_discovery"}) 55 | self.register_service("node_selector", dependencies={"dht_discovery"}) 56 | self.register_service("p2p_handler", dependencies={"llm"}, optional=True) 57 | 58 | def register_service(self, name: str, dependencies: Set[str] = None, optional: bool = False): 59 | """Register a service with its dependencies""" 60 | self.services[name] = ServiceInfo( 61 | name=name, 62 | state=ServiceState.PENDING, 63 | dependencies=dependencies or set() 64 | ) 65 | if optional: 66 | self.services[name].optional = True 67 | logger.debug(f"Registered service: {name} with dependencies: {dependencies}") 68 | 69 | async def mark_service_initializing(self, name: str): 70 | """Mark a service as starting initialization""" 71 | async with self._lock: 72 | if name in self.services: 73 | self.services[name].state = ServiceState.INITIALIZING 74 | self.services[name].start_time = time.time() 75 | logger.debug(f"Service {name} is initializing") 76 | 77 | async def mark_service_ready(self, name: str): 78 | """Mark a service as ready and check if all services are ready""" 79 | async with self._lock: 80 | if name in self.services: 81 | self.services[name].state = ServiceState.READY 82 | self.services[name].ready_time = time.time() 83 | 84 | init_time = 0 85 | if self.services[name].start_time: 86 | init_time = self.services[name].ready_time - self.services[name].start_time 87 | 88 | logger.info(f"✅ Service {name} ready (init time: {init_time:.2f}s)") 89 | 90 | # Execute ready callback if registered 91 | if name in self.ready_callbacks: 92 | try: 93 | await self.ready_callbacks[name]() 94 | except Exception as e: 95 | logger.error(f"Error in ready callback for {name}: {e}") 96 | 97 | # Check if all required services are ready 98 | await 
self._check_all_services_ready() 99 | 100 | async def mark_service_failed(self, name: str, error: str): 101 | """Mark a service as failed""" 102 | async with self._lock: 103 | if name in self.services: 104 | self.services[name].state = ServiceState.FAILED 105 | self.services[name].error = error 106 | logger.error(f"❌ Service {name} failed: {error}") 107 | 108 | def register_ready_callback(self, service_name: str, callback: Callable): 109 | """Register a callback to execute when a specific service is ready""" 110 | self.ready_callbacks[service_name] = callback 111 | 112 | def register_all_ready_callback(self, callback: Callable): 113 | """Register a callback to execute when all services are ready""" 114 | self.all_ready_callbacks.append(callback) 115 | 116 | async def _check_all_services_ready(self): 117 | """Check if all required services are ready and trigger callbacks""" 118 | if self.initialization_complete: 119 | return 120 | 121 | required_services = {name: service for name, service in self.services.items() 122 | if not getattr(service, 'optional', False)} 123 | 124 | all_ready = all(service.state == ServiceState.READY for service in required_services.values()) 125 | 126 | if all_ready: 127 | self.initialization_complete = True 128 | total_time = time.time() - min(s.start_time for s in required_services.values() if s.start_time) 129 | 130 | logger.info(f"🎉 All services ready! Total initialization time: {total_time:.2f}s") 131 | 132 | # DISABLED: Don't execute callbacks here - join event will be sent post-uvicorn 133 | # Execute all ready callbacks 134 | # for callback in self.all_ready_callbacks: 135 | # try: 136 | # await callback() 137 | # except Exception as e: 138 | # logger.error(f"Error in all-ready callback: {e}") 139 | 140 | logger.info("Join event will be sent after uvicorn initialization completes") 141 | 142 | def are_dependencies_ready(self, service_name: str) -> bool: 143 | """Check if all dependencies for a service are ready""" 144 | if service_name not in self.services: 145 | return False 146 | 147 | service = self.services[service_name] 148 | for dep in service.dependencies: 149 | if dep not in self.services or self.services[dep].state != ServiceState.READY: 150 | return False 151 | return True 152 | 153 | def get_initialization_status(self) -> Dict: 154 | """Get current initialization status""" 155 | status = { 156 | "initialization_complete": self.initialization_complete, 157 | "services": {}, 158 | "ready_count": 0, 159 | "total_count": len([s for s in self.services.values() if not getattr(s, 'optional', False)]) 160 | } 161 | 162 | for name, service in self.services.items(): 163 | status["services"][name] = { 164 | "state": service.state.value, 165 | "dependencies": list(service.dependencies), 166 | "optional": getattr(service, 'optional', False), 167 | "start_time": service.start_time, 168 | "ready_time": service.ready_time, 169 | "error": service.error 170 | } 171 | 172 | if service.state == ServiceState.READY and not getattr(service, 'optional', False): 173 | status["ready_count"] += 1 174 | 175 | return status 176 | 177 | async def wait_for_all_services(self, timeout: float = 30.0) -> bool: 178 | """Wait for all services to be ready with timeout""" 179 | start_time = time.time() 180 | 181 | while not self.initialization_complete and (time.time() - start_time) < timeout: 182 | await asyncio.sleep(0.1) 183 | 184 | return self.initialization_complete 185 | 186 | # Global service manager instance 187 | _service_manager: Optional[ServiceInitializationManager] = None 
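# --- Editorial usage sketch (not part of the original module) ---
# A minimal illustration of how node startup code might drive the
# ServiceInitializationManager above: each service is marked as it comes up,
# then startup waits for the full required set before proceeding (e.g. before
# the DHT join). The service names reuse those registered in
# _define_service_dependencies(); the async entry point itself is hypothetical.
#
#     async def example_startup():
#         manager = get_service_manager()
#         await manager.mark_service_initializing("config")
#         # ... load configuration here ...
#         await manager.mark_service_ready("config")
#         # repeat for "llm", "dht_service", "dht_publisher", and the rest
#         if await manager.wait_for_all_services(timeout=30.0):
#             status = manager.get_initialization_status()
#             logger.info(f"{status['ready_count']}/{status['total_count']} services ready")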
188 | 189 | def get_service_manager() -> ServiceInitializationManager: 190 | """Get the global service manager instance""" 191 | global _service_manager 192 | if _service_manager is None: 193 | _service_manager = ServiceInitializationManager() 194 | return _service_manager 195 | -------------------------------------------------------------------------------- /inference_node/request_queue.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | import uuid 4 | from typing import Dict, Any, Optional, Callable, Awaitable 5 | from enum import Enum 6 | from dataclasses import dataclass 7 | from common.utils import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | class RequestStatus(Enum): 12 | QUEUED = "queued" 13 | PROCESSING = "processing" 14 | COMPLETED = "completed" 15 | FAILED = "failed" 16 | CANCELLED = "cancelled" 17 | 18 | @dataclass 19 | class QueuedRequest: 20 | request_id: str 21 | request_type: str # "completion" or "chat_completion" 22 | request_data: Dict[str, Any] 23 | future: asyncio.Future 24 | queued_at: float 25 | started_at: Optional[float] = None 26 | completed_at: Optional[float] = None 27 | status: RequestStatus = RequestStatus.QUEUED 28 | 29 | class RequestQueueManager: 30 | """Manages request queuing for single-threaded LLM processing""" 31 | 32 | def __init__(self, max_queue_size: int = 50): 33 | self.max_queue_size = max_queue_size 34 | self.request_queue = asyncio.Queue(maxsize=max_queue_size) 35 | self.active_requests: Dict[str, QueuedRequest] = {} 36 | self.processing_request: Optional[QueuedRequest] = None 37 | self.worker_task: Optional[asyncio.Task] = None 38 | self.running = False 39 | self.stats = { 40 | "total_requests": 0, 41 | "completed_requests": 0, 42 | "failed_requests": 0, 43 | "cancelled_requests": 0, 44 | "queue_full_rejections": 0 45 | } 46 | 47 | async def start(self): 48 | """Start the request queue worker""" 49 | if self.running: 50 | return 51 | 52 | self.running = True 53 | self.worker_task = asyncio.create_task(self._worker_loop()) 54 | logger.info("Request queue manager started") 55 | 56 | async def stop(self): 57 | """Stop the request queue manager""" 58 | if not self.running: 59 | return 60 | 61 | self.running = False 62 | 63 | # Cancel all queued requests 64 | while not self.request_queue.empty(): 65 | try: 66 | request = self.request_queue.get_nowait() 67 | request.status = RequestStatus.CANCELLED 68 | request.future.cancel() 69 | self.stats["cancelled_requests"] += 1 70 | except asyncio.QueueEmpty: 71 | break 72 | 73 | # Cancel processing request if any 74 | if self.processing_request: 75 | self.processing_request.status = RequestStatus.CANCELLED 76 | self.processing_request.future.cancel() 77 | 78 | # Cancel worker task 79 | if self.worker_task: 80 | self.worker_task.cancel() 81 | try: 82 | await self.worker_task 83 | except asyncio.CancelledError: 84 | pass 85 | 86 | logger.info("Request queue manager stopped") 87 | 88 | async def submit_request(self, 89 | request_type: str, 90 | request_data: Dict[str, Any], 91 | processor: Callable[[Dict[str, Any]], Awaitable[Any]]) -> Any: 92 | """Submit a request to the queue""" 93 | if not self.running: 94 | raise RuntimeError("Request queue manager not running") 95 | 96 | # Check if queue is full 97 | if self.request_queue.qsize() >= self.max_queue_size: 98 | self.stats["queue_full_rejections"] += 1 99 | raise asyncio.QueueFull("Request queue is full") 100 | 101 | # Create request 102 | request_id = 
f"{request_type}_{uuid.uuid4().hex[:8]}" 103 | future = asyncio.Future() 104 | 105 | queued_request = QueuedRequest( 106 | request_id=request_id, 107 | request_type=request_type, 108 | request_data=request_data, 109 | future=future, 110 | queued_at=time.time() 111 | ) 112 | 113 | # Store processor function in request data 114 | queued_request.processor = processor 115 | 116 | # Add to queue and tracking 117 | await self.request_queue.put(queued_request) 118 | self.active_requests[request_id] = queued_request 119 | self.stats["total_requests"] += 1 120 | 121 | logger.debug(f"Request {request_id} queued (queue size: {self.request_queue.qsize()})") 122 | 123 | try: 124 | # Wait for completion 125 | result = await future 126 | return result 127 | except asyncio.CancelledError: 128 | logger.debug(f"Request {request_id} cancelled") 129 | raise 130 | finally: 131 | # Clean up 132 | self.active_requests.pop(request_id, None) 133 | 134 | async def _worker_loop(self): 135 | """Main worker loop that processes requests sequentially""" 136 | logger.info("Request queue worker started") 137 | 138 | while self.running: 139 | try: 140 | # Get next request (with timeout to allow graceful shutdown) 141 | try: 142 | request = await asyncio.wait_for( 143 | self.request_queue.get(), 144 | timeout=1.0 145 | ) 146 | except asyncio.TimeoutError: 147 | continue 148 | 149 | # Process the request 150 | await self._process_request(request) 151 | 152 | except asyncio.CancelledError: 153 | logger.info("Request queue worker cancelled") 154 | break 155 | except Exception as e: 156 | logger.error(f"Error in request queue worker: {e}") 157 | await asyncio.sleep(0.1) # Brief pause on error 158 | 159 | logger.info("Request queue worker stopped") 160 | 161 | async def _process_request(self, request: QueuedRequest): 162 | """Process a single request""" 163 | self.processing_request = request 164 | request.status = RequestStatus.PROCESSING 165 | request.started_at = time.time() 166 | 167 | wait_time = request.started_at - request.queued_at 168 | logger.debug(f"Processing request {request.request_id} (waited {wait_time:.2f}s)") 169 | 170 | try: 171 | # Call the processor function 172 | result = await request.processor(request.request_data) 173 | 174 | # Mark as completed 175 | request.status = RequestStatus.COMPLETED 176 | request.completed_at = time.time() 177 | request.future.set_result(result) 178 | self.stats["completed_requests"] += 1 179 | 180 | processing_time = request.completed_at - request.started_at 181 | logger.debug(f"Request {request.request_id} completed in {processing_time:.2f}s") 182 | 183 | except asyncio.CancelledError: 184 | request.status = RequestStatus.CANCELLED 185 | request.future.cancel() 186 | self.stats["cancelled_requests"] += 1 187 | logger.debug(f"Request {request.request_id} cancelled during processing") 188 | 189 | except Exception as e: 190 | request.status = RequestStatus.FAILED 191 | request.completed_at = time.time() 192 | request.future.set_exception(e) 193 | self.stats["failed_requests"] += 1 194 | logger.error(f"Request {request.request_id} failed: {e}") 195 | 196 | finally: 197 | self.processing_request = None 198 | 199 | def get_status(self) -> Dict[str, Any]: 200 | """Get current queue status""" 201 | current_time = time.time() 202 | 203 | # Calculate queue wait times 204 | queue_wait_times = [] 205 | for request in list(self.active_requests.values()): 206 | if request.status == RequestStatus.QUEUED: 207 | wait_time = current_time - request.queued_at 208 | 
queue_wait_times.append(wait_time) 209 | 210 | avg_wait_time = sum(queue_wait_times) / len(queue_wait_times) if queue_wait_times else 0 211 | 212 | status = { 213 | "running": self.running, 214 | "queue_size": self.request_queue.qsize(), 215 | "max_queue_size": self.max_queue_size, 216 | "active_requests": len(self.active_requests), 217 | "processing_request": self.processing_request.request_id if self.processing_request else None, 218 | "avg_queue_wait_time": avg_wait_time, 219 | "max_queue_wait_time": max(queue_wait_times) if queue_wait_times else 0, 220 | "stats": self.stats.copy(), 221 | "timestamp": current_time 222 | } 223 | 224 | return status 225 | 226 | def is_busy(self) -> bool: 227 | """Check if the LLM is currently processing a request""" 228 | return self.processing_request is not None 229 | 230 | def get_queue_position(self, request_id: str) -> Optional[int]: 231 | """Get the position of a request in the queue""" 232 | if request_id not in self.active_requests: 233 | return None 234 | 235 | request = self.active_requests[request_id] 236 | if request.status != RequestStatus.QUEUED: 237 | return None 238 | 239 | # This is approximate since asyncio.Queue doesn't provide position info 240 | return self.request_queue.qsize() 241 | -------------------------------------------------------------------------------- /inference_node/metrics.py: -------------------------------------------------------------------------------- 1 | import time 2 | import psutil 3 | import platform 4 | from typing import Dict, Any, Optional 5 | from common.utils import get_logger 6 | import atexit 7 | import threading 8 | 9 | logger = get_logger(__name__) 10 | 11 | # Global cleanup registry 12 | _cleanup_registry = [] 13 | _cleanup_lock = threading.Lock() 14 | 15 | def register_cleanup(cleanup_func): 16 | """Register a cleanup function to be called at exit""" 17 | with _cleanup_lock: 18 | _cleanup_registry.append(cleanup_func) 19 | 20 | def _cleanup_all(): 21 | """Clean up all registered resources""" 22 | with _cleanup_lock: 23 | for cleanup_func in _cleanup_registry: 24 | try: 25 | cleanup_func() 26 | except Exception as e: 27 | logger.debug(f"Cleanup error: {e}") 28 | _cleanup_registry.clear() 29 | 30 | # Register cleanup at module import 31 | atexit.register(_cleanup_all) 32 | 33 | # RequestTracker functionality moved to common/metrics_manager.py 34 | # Use MetricsManager for all request tracking 35 | 36 | class SystemInfo: 37 | """Collect system information with proper resource management""" 38 | 39 | # Cache system info to avoid repeated psutil calls 40 | _cpu_info_cache = None 41 | _ram_info_cache = None 42 | _gpu_info_cache = None 43 | _cache_time = 0 44 | _cache_ttl = 300 # 5 minutes 45 | 46 | @staticmethod 47 | def get_cpu_info() -> str: 48 | """Get CPU information with caching and fallback methods""" 49 | current_time = time.time() 50 | 51 | # Use cache if available and fresh 52 | if (SystemInfo._cpu_info_cache and 53 | current_time - SystemInfo._cache_time < SystemInfo._cache_ttl): 54 | return SystemInfo._cpu_info_cache 55 | 56 | try: 57 | # Try multiple methods to get CPU info 58 | cpu_info = platform.processor() 59 | if cpu_info and cpu_info.strip(): 60 | SystemInfo._cpu_info_cache = cpu_info.strip() 61 | SystemInfo._cache_time = current_time 62 | return SystemInfo._cpu_info_cache 63 | 64 | # Fallback to machine type 65 | machine = platform.machine() 66 | system = platform.system() 67 | cpu_info = f"{system} {machine}" 68 | SystemInfo._cpu_info_cache = cpu_info 69 | SystemInfo._cache_time = 
current_time 70 | return cpu_info 71 | except Exception: 72 | fallback = "Unknown CPU" 73 | SystemInfo._cpu_info_cache = fallback 74 | return fallback 75 | 76 | @staticmethod 77 | def get_ram_info() -> Dict[str, Any]: 78 | """Get RAM information with better error handling and caching""" 79 | current_time = time.time() 80 | 81 | # Use cache if available and fresh 82 | if (SystemInfo._ram_info_cache and 83 | current_time - SystemInfo._cache_time < SystemInfo._cache_ttl): 84 | return SystemInfo._ram_info_cache 85 | 86 | try: 87 | mem = psutil.virtual_memory() 88 | ram_info = { 89 | "total": mem.total, 90 | "available": mem.available, 91 | "total_gb": round(mem.total / (1024**3), 2), 92 | "available_gb": round(mem.available / (1024**3), 2), 93 | "used_percent": mem.percent 94 | } 95 | SystemInfo._ram_info_cache = ram_info 96 | SystemInfo._cache_time = current_time 97 | return ram_info 98 | except Exception as e: 99 | logger.error(f"Error getting RAM info: {e}") 100 | fallback = { 101 | "total": 0, 102 | "available": 0, 103 | "total_gb": 0.0, 104 | "available_gb": 0.0, 105 | "used_percent": 0.0 106 | } 107 | SystemInfo._ram_info_cache = fallback 108 | return fallback 109 | 110 | @staticmethod 111 | def get_gpu_info() -> Optional[str]: 112 | """Get GPU information with caching and improved error handling""" 113 | current_time = time.time() 114 | 115 | # Use cache if available and fresh 116 | if (SystemInfo._gpu_info_cache is not None and 117 | current_time - SystemInfo._cache_time < SystemInfo._cache_ttl): 118 | return SystemInfo._gpu_info_cache 119 | 120 | try: 121 | # First check if we have NVIDIA GPUs using nvidia-smi (most reliable) 122 | import subprocess 123 | try: 124 | # Try nvidia-smi first (most reliable way to detect NVIDIA GPUs) 125 | result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader,nounits'], 126 | capture_output=True, text=True, timeout=5) 127 | if result.returncode == 0 and result.stdout.strip(): 128 | # Parse nvidia-smi output 129 | gpu_info = [] 130 | for line in result.stdout.strip().split('\n'): 131 | if line.strip(): 132 | parts = line.split(',') 133 | if len(parts) >= 2: 134 | name = parts[0].strip() 135 | memory_mb = parts[1].strip() 136 | gpu_info.append(f"{name} ({memory_mb}MB)") 137 | 138 | gpu_result = ", ".join(gpu_info) if gpu_info else None 139 | SystemInfo._gpu_info_cache = gpu_result 140 | SystemInfo._cache_time = current_time 141 | return gpu_result 142 | except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): 143 | # nvidia-smi not available, continue to pynvml 144 | pass 145 | 146 | # Try pynvml as fallback 147 | import pynvml 148 | pynvml.nvmlInit() 149 | device_count = pynvml.nvmlDeviceGetCount() 150 | 151 | if device_count == 0: 152 | SystemInfo._gpu_info_cache = None 153 | SystemInfo._cache_time = current_time 154 | return None 155 | 156 | # Return info for all GPUs 157 | gpu_info = [] 158 | for i in range(device_count): 159 | handle = pynvml.nvmlDeviceGetHandleByIndex(i) 160 | name = pynvml.nvmlDeviceGetName(handle) 161 | memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle) 162 | memory_mb = int(memory_info.total / (1024 * 1024)) 163 | 164 | # Handle both bytes and string returns from pynvml 165 | if isinstance(name, bytes): 166 | name = name.decode('utf-8') 167 | 168 | gpu_info.append(f"{name} ({memory_mb}MB)") 169 | 170 | gpu_result = ", ".join(gpu_info) 171 | SystemInfo._gpu_info_cache = gpu_result 172 | SystemInfo._cache_time = current_time 173 | return gpu_result 174 | 175 | 
except ImportError: 176 | logger.debug("pynvml not available - install with: pip install pynvml") 177 | SystemInfo._gpu_info_cache = None 178 | SystemInfo._cache_time = current_time 179 | return None 180 | except Exception as e: 181 | # Only log as debug for common "no NVIDIA GPU" scenarios 182 | error_msg = str(e).lower() 183 | if any(phrase in error_msg for phrase in ['nvml shared library not found', 'nvidia driver', 'no devices', 'nvml_error_uninitialized']): 184 | logger.debug(f"No NVIDIA GPUs detected: {e}") 185 | else: 186 | logger.warning(f"Could not get GPU info: {e}") 187 | SystemInfo._gpu_info_cache = None 188 | SystemInfo._cache_time = current_time 189 | return None 190 | 191 | @staticmethod 192 | def get_current_load() -> Dict[str, float]: 193 | """Get current system load metrics without blocking""" 194 | try: 195 | # Use non-blocking CPU measurement 196 | cpu_percent = psutil.cpu_percent(interval=0) # Non-blocking 197 | memory_percent = psutil.virtual_memory().percent 198 | 199 | # Safe disk usage check 200 | try: 201 | disk_percent = psutil.disk_usage('/').percent 202 | except (OSError, PermissionError): 203 | # Fallback for Windows or permission issues 204 | try: 205 | disk_percent = psutil.disk_usage('.').percent 206 | except: 207 | disk_percent = 0.0 208 | 209 | return { 210 | "cpu_percent": cpu_percent, 211 | "memory_percent": memory_percent, 212 | "disk_percent": disk_percent 213 | } 214 | except Exception as e: 215 | logger.error(f"Error getting load metrics: {e}") 216 | return { 217 | "cpu_percent": 0.0, 218 | "memory_percent": 0.0, 219 | "disk_percent": 0.0 220 | } 221 | 222 | # Overload checking moved to common/metrics_manager.py 223 | # Use MetricsManager.is_overloaded() for consistent overload detection 224 | 225 | @staticmethod 226 | def get_all_info() -> Dict[str, Any]: 227 | """Get all system information with caching""" 228 | return { 229 | "cpu": SystemInfo.get_cpu_info(), 230 | "ram": SystemInfo.get_ram_info(), 231 | "gpu": SystemInfo.get_gpu_info(), 232 | "platform": platform.platform() 233 | } 234 | 235 | # Module cleanup 236 | def cleanup_system_info(): 237 | """Clean up SystemInfo caches""" 238 | SystemInfo._cpu_info_cache = None 239 | SystemInfo._ram_info_cache = None 240 | SystemInfo._gpu_info_cache = None 241 | 242 | register_cleanup(cleanup_system_info) 243 | -------------------------------------------------------------------------------- /docker/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Startup script for LlamaNet Docker container 3 | # Orchestrates hardware detection, configuration, and application startup 4 | 5 | set -e 6 | 7 | # Color codes for output 8 | RED='\033[0;31m' 9 | GREEN='\033[0;32m' 10 | YELLOW='\033[1;33m' 11 | BLUE='\033[0;34m' 12 | CYAN='\033[0;36m' 13 | NC='\033[0m' # No Color 14 | 15 | # Logging functions 16 | log() { 17 | echo -e "${BLUE}[STARTUP]${NC} $1" 18 | } 19 | 20 | log_success() { 21 | echo -e "${GREEN}[STARTUP]${NC} ✅ $1" 22 | } 23 | 24 | log_warning() { 25 | echo -e "${YELLOW}[STARTUP]${NC} ⚠️ $1" 26 | } 27 | 28 | log_error() { 29 | echo -e "${RED}[STARTUP]${NC} ❌ $1" 30 | } 31 | 32 | log_info() { 33 | echo -e "${CYAN}[STARTUP]${NC} ℹ️ $1" 34 | } 35 | 36 | show_banner() { 37 | echo "" 38 | echo -e "${CYAN}╔══════════════════════════════════════════════════════════════╗${NC}" 39 | echo -e "${CYAN}║ 🦙 LlamaNet Container ║${NC}" 40 | echo -e "${CYAN}║ Distributed AI Inference Network ║${NC}" 41 | echo -e 
"${CYAN}╚══════════════════════════════════════════════════════════════╝${NC}" 42 | echo "" 43 | } 44 | 45 | validate_environment() { 46 | log "🔍 Validating environment configuration..." 47 | 48 | local validation_failed=false 49 | 50 | # Check required environment variables 51 | if [ -z "$MODEL_PATH" ]; then 52 | log_error "MODEL_PATH environment variable is required" 53 | log_info "Please set MODEL_PATH to point to your GGUF model file" 54 | log_info "Example: -e MODEL_PATH=/models/your-model.gguf" 55 | validation_failed=true 56 | fi 57 | 58 | # Check if model file exists 59 | if [ -n "$MODEL_PATH" ] && [ ! -f "$MODEL_PATH" ]; then 60 | log_error "Model file not found at $MODEL_PATH" 61 | log_info "Available files in /models:" 62 | if [ -d "/models" ]; then 63 | ls -la /models/ 2>/dev/null || log_warning "Models directory is empty or not accessible" 64 | else 65 | log_warning "Models directory not mounted" 66 | fi 67 | validation_failed=true 68 | fi 69 | 70 | # Validate model file format 71 | if [ -n "$MODEL_PATH" ] && [ -f "$MODEL_PATH" ]; then 72 | if [[ "$MODEL_PATH" == *.gguf ]]; then 73 | log_success "Model file found: $MODEL_PATH" 74 | log_info "📁 Model size: $(du -h "$MODEL_PATH" | cut -f1)" 75 | else 76 | log_warning "Model file does not have .gguf extension: $MODEL_PATH" 77 | log_info "LlamaNet expects GGUF format models" 78 | fi 79 | fi 80 | 81 | # Validate ports 82 | local port=${PORT:-8000} 83 | local dht_port=${DHT_PORT:-8001} 84 | 85 | if [ "$port" = "$dht_port" ]; then 86 | log_error "HTTP port ($port) and DHT port ($dht_port) cannot be the same" 87 | validation_failed=true 88 | fi 89 | 90 | # Check for port conflicts (basic check) 91 | if netstat -tuln 2>/dev/null | grep -q ":$port "; then 92 | log_warning "Port $port appears to be in use" 93 | fi 94 | 95 | if netstat -tuln 2>/dev/null | grep -q ":$dht_port "; then 96 | log_warning "DHT port $dht_port appears to be in use" 97 | fi 98 | 99 | if [ "$validation_failed" = true ]; then 100 | log_error "Environment validation failed" 101 | exit 1 102 | fi 103 | 104 | log_success "Environment validation passed" 105 | } 106 | 107 | run_hardware_detection() { 108 | log "🔧 Running hardware detection and setup..." 109 | 110 | # Source and run GPU detection script 111 | if [ -f "/usr/local/bin/gpu-detect.sh" ]; then 112 | source /usr/local/bin/gpu-detect.sh 113 | 114 | # Run the main detection function 115 | if ! main; then 116 | log_error "Hardware detection failed" 117 | exit 1 118 | fi 119 | 120 | # Load the hardware configuration 121 | if [ -f "/tmp/hardware_config" ]; then 122 | source /tmp/hardware_config 123 | log_success "Hardware configuration loaded" 124 | else 125 | log_warning "Hardware configuration file not found, using defaults" 126 | export HARDWARE_MODE=${HARDWARE_MODE:-cpu} 127 | export N_GPU_LAYERS=${N_GPU_LAYERS:-0} 128 | fi 129 | else 130 | log_error "GPU detection script not found at /usr/local/bin/gpu-detect.sh" 131 | exit 1 132 | fi 133 | } 134 | 135 | setup_python_environment() { 136 | log "🐍 Setting up Python environment..." 137 | 138 | # Set Python path and ensure UTF-8 encoding 139 | export PYTHONPATH=/app:$PYTHONPATH 140 | export PYTHONUNBUFFERED=1 141 | export PYTHONIOENCODING=utf-8 142 | export LC_ALL=C.UTF-8 143 | export LANG=C.UTF-8 144 | 145 | # Verify Python installation 146 | if ! 
python3 --version >/dev/null 2>&1; then 147 | log_error "Python 3 is not available" 148 | exit 1 149 | fi 150 | 151 | log_success "Python environment configured" 152 | log_info "Python version: $(python3 --version)" 153 | } 154 | 155 | configure_application() { 156 | log "⚙️ Configuring LlamaNet application..." 157 | 158 | # Set default values for environment variables 159 | export HOST=${HOST:-0.0.0.0} 160 | export PORT=${PORT:-8000} 161 | export DHT_PORT=${DHT_PORT:-8001} 162 | export NODE_ID=${NODE_ID:-$(hostname)-$(date +%s)} 163 | export BOOTSTRAP_NODES=${BOOTSTRAP_NODES:-""} 164 | 165 | # Hardware-specific configuration 166 | export N_GPU_LAYERS=${N_GPU_LAYERS:-0} 167 | export HARDWARE_MODE=${HARDWARE_MODE:-cpu} 168 | 169 | # Performance tuning based on hardware 170 | if [ "$HARDWARE_MODE" = "gpu" ]; then 171 | # GPU-specific optimizations 172 | export N_BATCH=${N_BATCH:-512} 173 | export N_CTX=${N_CTX:-4096} 174 | export N_THREADS=${N_THREADS:-$(nproc)} 175 | else 176 | # CPU-specific optimizations 177 | export N_BATCH=${N_BATCH:-128} 178 | export N_CTX=${N_CTX:-2048} 179 | export N_THREADS=${N_THREADS:-$(nproc)} 180 | fi 181 | 182 | # Logging configuration 183 | export LOG_LEVEL=${LOG_LEVEL:-info} 184 | 185 | log_success "Application configuration complete" 186 | } 187 | 188 | show_configuration() { 189 | echo "" 190 | log "🚀 LlamaNet Configuration Summary" 191 | log "==================================" 192 | echo -e " ${CYAN}Hardware Mode:${NC} $HARDWARE_MODE" 193 | echo -e " ${CYAN}GPU Layers:${NC} $N_GPU_LAYERS" 194 | echo -e " ${CYAN}Model Path:${NC} $MODEL_PATH" 195 | echo -e " ${CYAN}Host:${NC} $HOST" 196 | echo -e " ${CYAN}HTTP Port:${NC} $PORT" 197 | echo -e " ${CYAN}DHT Port:${NC} $DHT_PORT" 198 | echo -e " ${CYAN}Node ID:${NC} $NODE_ID" 199 | echo -e " ${CYAN}Bootstrap Nodes:${NC} ${BOOTSTRAP_NODES:-'none (bootstrap node)'}" 200 | echo -e " ${CYAN}Context Size:${NC} $N_CTX" 201 | echo -e " ${CYAN}Batch Size:${NC} $N_BATCH" 202 | echo -e " ${CYAN}Threads:${NC} $N_THREADS" 203 | log "==================================" 204 | echo "" 205 | } 206 | 207 | setup_signal_handlers() { 208 | log "📡 Setting up signal handlers..." 209 | 210 | # Function to handle shutdown signals 211 | shutdown_handler() { 212 | log_warning "Received shutdown signal, cleaning up..." 213 | 214 | # Kill any background processes 215 | jobs -p | xargs -r kill 2>/dev/null || true 216 | 217 | log_success "Cleanup complete, exiting" 218 | exit 0 219 | } 220 | 221 | # Set up signal traps 222 | trap shutdown_handler SIGTERM SIGINT SIGQUIT 223 | 224 | log_success "Signal handlers configured" 225 | } 226 | 227 | perform_health_checks() { 228 | log "🏥 Performing pre-startup health checks..." 
229 | 230 | # Check disk space 231 | local available_space=$(df /app | tail -1 | awk '{print $4}') 232 | if [ "$available_space" -lt 1048576 ]; then # Less than 1GB 233 | log_warning "Low disk space available: $(df -h /app | tail -1 | awk '{print $4}')" 234 | fi 235 | 236 | # Check memory 237 | local available_memory=$(free -m | grep '^Mem:' | awk '{print $7}') 238 | if [ "$available_memory" -lt 1024 ]; then # Less than 1GB 239 | log_warning "Low memory available: ${available_memory}MB" 240 | fi 241 | 242 | # Test model file accessibility 243 | if [ -f "$MODEL_PATH" ]; then 244 | if [ -r "$MODEL_PATH" ]; then 245 | log_success "Model file is readable" 246 | else 247 | log_error "Model file is not readable" 248 | exit 1 249 | fi 250 | fi 251 | 252 | # Test network connectivity (if bootstrap nodes specified) 253 | if [ -n "$BOOTSTRAP_NODES" ]; then 254 | log_info "Testing connectivity to bootstrap nodes..." 255 | IFS=',' read -ra NODES <<< "$BOOTSTRAP_NODES" 256 | for node in "${NODES[@]}"; do 257 | IFS=':' read -ra ADDR <<< "$node" 258 | local host=${ADDR[0]} 259 | local port=${ADDR[1]} 260 | 261 | if timeout 5 nc -z "$host" "$port" 2>/dev/null; then 262 | log_success "Bootstrap node reachable: $node" 263 | else 264 | log_warning "Bootstrap node not reachable: $node" 265 | fi 266 | done 267 | fi 268 | 269 | log_success "Health checks completed" 270 | } 271 | 272 | start_application() { 273 | log "🎬 Starting LlamaNet application..." 274 | 275 | # Change to application directory 276 | cd /app 277 | 278 | # Show final startup message 279 | log_success "Launching: $*" 280 | echo "" 281 | 282 | # Execute the application with all provided arguments 283 | exec "$@" 284 | } 285 | 286 | # Main startup sequence 287 | main() { 288 | show_banner 289 | 290 | # Core startup sequence 291 | validate_environment 292 | setup_python_environment 293 | setup_signal_handlers 294 | run_hardware_detection 295 | configure_application 296 | perform_health_checks 297 | show_configuration 298 | 299 | # Start the application 300 | start_application "$@" 301 | } 302 | 303 | # Run main function with all arguments 304 | main "$@" 305 | -------------------------------------------------------------------------------- /docker/gpu-detect.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # GPU Detection Script for LlamaNet Docker 3 | # Detects NVIDIA GPU availability and configures appropriate llama-cpp-python installation 4 | 5 | set -e 6 | 7 | # Color codes for output 8 | RED='\033[0;31m' 9 | GREEN='\033[0;32m' 10 | YELLOW='\033[1;33m' 11 | BLUE='\033[0;34m' 12 | NC='\033[0m' # No Color 13 | 14 | # Logging function 15 | log() { 16 | echo -e "${BLUE}[GPU-DETECT]${NC} $1" 17 | } 18 | 19 | log_success() { 20 | echo -e "${GREEN}[GPU-DETECT]${NC} ✅ $1" 21 | } 22 | 23 | log_warning() { 24 | echo -e "${YELLOW}[GPU-DETECT]${NC} ⚠️ $1" 25 | } 26 | 27 | log_error() { 28 | echo -e "${RED}[GPU-DETECT]${NC} ❌ $1" 29 | } 30 | 31 | detect_gpu() { 32 | log "🔍 Detecting GPU capabilities..." 
33 | 34 | # Check if nvidia-smi is available and working 35 | if command -v nvidia-smi >/dev/null 2>&1; then 36 | if nvidia-smi >/dev/null 2>&1; then 37 | GPU_COUNT=$(nvidia-smi --query-gpu=count --format=csv,noheader,nounits 2>/dev/null | head -1 || echo "0") 38 | GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | head -1 || echo "0") 39 | GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null | head -1 || echo "Unknown") 40 | 41 | if [ "$GPU_COUNT" -gt 0 ] && [ "$GPU_MEMORY" -gt 0 ]; then 42 | log_success "GPU detected: $GPU_NAME" 43 | log "📊 GPU Memory: ${GPU_MEMORY}MB" 44 | log "🔢 GPU Count: $GPU_COUNT" 45 | 46 | # Check CUDA runtime availability 47 | if python3 -c "import torch; print('CUDA available:', torch.cuda.is_available())" 2>/dev/null | grep -q "True"; then 48 | log_success "CUDA runtime verified" 49 | return 0 50 | else 51 | log_warning "CUDA runtime not available in Python, checking basic GPU access..." 52 | # Even without torch CUDA, we can still try CUDA compilation 53 | return 0 54 | fi 55 | else 56 | log_warning "GPU detected but not accessible (Count: $GPU_COUNT, Memory: ${GPU_MEMORY}MB)" 57 | return 1 58 | fi 59 | else 60 | log_warning "nvidia-smi found but not working, using CPU mode" 61 | return 1 62 | fi 63 | else 64 | log "ℹ️ No GPU detected, using CPU mode" 65 | return 1 66 | fi 67 | } 68 | 69 | install_gpu_support() { 70 | log "🚀 Installing GPU-optimized llama-cpp-python..." 71 | 72 | # Set CUDA compilation flags 73 | export CMAKE_ARGS="-DLLAMA_CUBLAS=ON -DLLAMA_CUDA_FORCE_DMMV=ON" 74 | export FORCE_CMAKE=1 75 | export CUDACXX=/usr/local/cuda/bin/nvcc 76 | 77 | # Ensure CUDA is in PATH 78 | export PATH="/usr/local/cuda/bin:$PATH" 79 | export LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH" 80 | 81 | log "🔧 CUDA compilation flags set: $CMAKE_ARGS" 82 | 83 | # Install GPU version with retry logic 84 | local max_attempts=3 85 | local attempt=1 86 | 87 | while [ $attempt -le $max_attempts ]; do 88 | log "📦 Installing llama-cpp-python with CUDA support (attempt $attempt/$max_attempts)..." 89 | 90 | if pip install --no-cache-dir --force-reinstall --no-deps llama-cpp-python; then 91 | log_success "GPU version installation completed" 92 | break 93 | else 94 | log_warning "Installation attempt $attempt failed" 95 | if [ $attempt -eq $max_attempts ]; then 96 | log_error "GPU installation failed after $max_attempts attempts" 97 | return 1 98 | fi 99 | attempt=$((attempt + 1)) 100 | sleep 2 101 | fi 102 | done 103 | 104 | # Verify GPU installation 105 | log "🔍 Verifying GPU installation..." 
106 | if python3 -c " 107 | import sys 108 | try: 109 | from llama_cpp import Llama 110 | print('✅ llama-cpp-python imported successfully') 111 | 112 | # Try to check for CUDA support 113 | try: 114 | # This is a basic test - actual CUDA verification happens at model load time 115 | print('✅ GPU support appears to be available') 116 | sys.exit(0) 117 | except Exception as e: 118 | print(f'⚠️ GPU support verification inconclusive: {e}') 119 | sys.exit(0) # Still proceed, will be verified at runtime 120 | 121 | except ImportError as e: 122 | print(f'❌ Failed to import llama-cpp-python: {e}') 123 | sys.exit(1) 124 | except Exception as e: 125 | print(f'❌ Unexpected error: {e}') 126 | sys.exit(1) 127 | " 2>&1; then 128 | log_success "GPU support installed and verified" 129 | 130 | # Set optimal GPU configuration 131 | export N_GPU_LAYERS=${N_GPU_LAYERS:-32} 132 | export HARDWARE_MODE=gpu 133 | 134 | # Calculate optimal GPU layers based on available memory 135 | if [ "$GPU_MEMORY" -gt 0 ]; then 136 | if [ "$GPU_MEMORY" -gt 16000 ]; then 137 | export N_GPU_LAYERS=40 # High-end GPU 138 | elif [ "$GPU_MEMORY" -gt 8000 ]; then 139 | export N_GPU_LAYERS=32 # Mid-range GPU 140 | elif [ "$GPU_MEMORY" -gt 4000 ]; then 141 | export N_GPU_LAYERS=20 # Lower-end GPU 142 | else 143 | export N_GPU_LAYERS=10 # Very limited GPU memory 144 | fi 145 | log "🎯 Optimized GPU layers for ${GPU_MEMORY}MB VRAM: $N_GPU_LAYERS" 146 | fi 147 | 148 | return 0 149 | else 150 | log_error "GPU installation verification failed" 151 | return 1 152 | fi 153 | } 154 | 155 | install_cpu_support() { 156 | log "🖥️ Installing CPU-optimized llama-cpp-python..." 157 | 158 | # Set CPU compilation flags for optimal performance 159 | export CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS -DLLAMA_NATIVE=ON" 160 | unset FORCE_CMAKE 161 | unset CUDACXX 162 | 163 | log "🔧 CPU compilation flags set: $CMAKE_ARGS" 164 | 165 | # Install CPU version with retry logic 166 | local max_attempts=3 167 | local attempt=1 168 | 169 | while [ $attempt -le $max_attempts ]; do 170 | log "📦 Installing llama-cpp-python with CPU optimization (attempt $attempt/$max_attempts)..." 171 | 172 | if pip install --no-cache-dir --force-reinstall --no-deps llama-cpp-python; then 173 | log_success "CPU version installation completed" 174 | break 175 | else 176 | log_warning "Installation attempt $attempt failed" 177 | if [ $attempt -eq $max_attempts ]; then 178 | log_error "CPU installation failed after $max_attempts attempts" 179 | return 1 180 | fi 181 | attempt=$((attempt + 1)) 182 | sleep 2 183 | fi 184 | done 185 | 186 | # Verify CPU installation 187 | log "🔍 Verifying CPU installation..." 
188 | if python3 -c " 189 | import sys 190 | try: 191 | from llama_cpp import Llama 192 | print('✅ llama-cpp-python imported successfully') 193 | print('✅ CPU support available') 194 | sys.exit(0) 195 | except ImportError as e: 196 | print(f'❌ Failed to import llama-cpp-python: {e}') 197 | sys.exit(1) 198 | except Exception as e: 199 | print(f'❌ Unexpected error: {e}') 200 | sys.exit(1) 201 | " 2>&1; then 202 | log_success "CPU support installed and verified" 203 | export N_GPU_LAYERS=0 204 | export HARDWARE_MODE=cpu 205 | return 0 206 | else 207 | log_error "CPU installation verification failed" 208 | return 1 209 | fi 210 | } 211 | 212 | get_system_info() { 213 | log "📋 System Information:" 214 | echo " OS: $(uname -s) $(uname -r)" 215 | echo " Architecture: $(uname -m)" 216 | echo " Python: $(python3 --version 2>&1)" 217 | echo " CPU Cores: $(nproc)" 218 | echo " Memory: $(free -h | grep '^Mem:' | awk '{print $2}') total" 219 | 220 | if command -v nvidia-smi >/dev/null 2>&1; then 221 | echo " NVIDIA Driver: $(nvidia-smi --query-gpu=driver_version --format=csv,noheader,nounits 2>/dev/null | head -1 || echo 'Not available')" 222 | echo " CUDA Version: $(nvcc --version 2>/dev/null | grep 'release' | awk '{print $6}' | cut -c2- || echo 'Not available')" 223 | fi 224 | } 225 | 226 | # Main detection and installation logic 227 | main() { 228 | echo "" 229 | log "🔧 LlamaNet Hardware Detection and Setup" 230 | log "========================================" 231 | 232 | # Show system information 233 | get_system_info 234 | echo "" 235 | 236 | # Check if hardware mode is forced via environment variable 237 | if [ "$HARDWARE_MODE" = "gpu" ]; then 238 | log "🎯 GPU mode forced via environment variable" 239 | if ! install_gpu_support; then 240 | log_error "Forced GPU mode failed, exiting" 241 | exit 1 242 | fi 243 | elif [ "$HARDWARE_MODE" = "cpu" ]; then 244 | log "🎯 CPU mode forced via environment variable" 245 | if ! install_cpu_support; then 246 | log_error "CPU mode installation failed" 247 | exit 1 248 | fi 249 | else 250 | log "🔍 Auto-detecting hardware capabilities..." 251 | if detect_gpu; then 252 | log "🎯 GPU detected, attempting GPU installation..." 253 | if ! install_gpu_support; then 254 | log_warning "GPU installation failed, falling back to CPU mode" 255 | if ! install_cpu_support; then 256 | log_error "Both GPU and CPU installation failed" 257 | exit 1 258 | fi 259 | fi 260 | else 261 | log "🎯 No GPU detected or GPU not accessible, using CPU mode" 262 | if ! 
install_cpu_support; then 263 | log_error "CPU mode installation failed" 264 | exit 1 265 | fi 266 | fi 267 | fi 268 | 269 | echo "" 270 | log "========================================" 271 | log_success "Hardware setup complete" 272 | log "🏃 Mode: $HARDWARE_MODE" 273 | log "🧠 GPU Layers: ${N_GPU_LAYERS:-0}" 274 | log "🔧 CMAKE_ARGS: ${CMAKE_ARGS:-'Not set'}" 275 | 276 | # Export variables for the startup script 277 | echo "export HARDWARE_MODE=$HARDWARE_MODE" > /tmp/hardware_config 278 | echo "export N_GPU_LAYERS=${N_GPU_LAYERS:-0}" >> /tmp/hardware_config 279 | echo "export CMAKE_ARGS='$CMAKE_ARGS'" >> /tmp/hardware_config 280 | 281 | log_success "Configuration saved to /tmp/hardware_config" 282 | echo "" 283 | } 284 | 285 | # Export function for use in other scripts 286 | if [ "${BASH_SOURCE[0]}" = "${0}" ]; then 287 | main "$@" 288 | fi 289 | -------------------------------------------------------------------------------- /dht/protocol.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import struct 4 | import time 5 | import hashlib 6 | import uuid 7 | from typing import Dict, Any, Tuple, Optional 8 | from common.utils import get_logger 9 | from common.error_handler import ErrorHandler 10 | from common.validation_utils import NodeValidator 11 | 12 | logger = get_logger(__name__) 13 | 14 | class KademliaProtocol(asyncio.DatagramProtocol): 15 | """UDP protocol for Kademlia messages""" 16 | 17 | def __init__(self, node): 18 | self.node = node 19 | self.transport = None 20 | self.pending_requests: Dict[str, asyncio.Future] = {} 21 | 22 | def connection_made(self, transport): 23 | self.transport = transport 24 | 25 | @ErrorHandler.safe_sync_call 26 | def error_received(self, exc): 27 | """Handle transport errors gracefully using consolidated error handling""" 28 | logger.error(f"UDP transport error: {exc}") 29 | 30 | # Handle specific error types without crashing 31 | error_str = str(exc).lower() 32 | if "connection refused" in error_str: 33 | logger.warning("Connection refused - remote node may be down") 34 | elif "network unreachable" in error_str: 35 | logger.warning("Network unreachable - check connectivity") 36 | elif "no buffer space available" in error_str: 37 | logger.warning("Network buffer full - system under load") 38 | elif "address already in use" in error_str: 39 | logger.error("Port already in use - check for conflicting processes") 40 | else: 41 | logger.error(f"Unhandled transport error: {exc}") 42 | 43 | def connection_lost(self, exc): 44 | """Handle connection loss""" 45 | if exc: 46 | logger.error(f"UDP connection lost: {exc}") 47 | else: 48 | logger.debug("UDP connection closed normally") 49 | 50 | def datagram_received(self, data: bytes, addr: Tuple[str, int]): 51 | """Handle incoming UDP messages with robust error handling""" 52 | try: 53 | # Validate data 54 | if not data: 55 | logger.debug(f"Received empty UDP packet from {addr}") 56 | return 57 | 58 | if len(data) > 65507: # Max UDP payload size 59 | logger.warning(f"Received oversized UDP packet ({len(data)} bytes) from {addr}") 60 | return 61 | 62 | # Decode with proper error handling 63 | try: 64 | message_str = data.decode('utf-8') 65 | except UnicodeDecodeError as e: 66 | logger.warning(f"Failed to decode message from {addr}: {e}") 67 | return 68 | 69 | # Parse JSON with validation 70 | try: 71 | message = json.loads(message_str) 72 | except json.JSONDecodeError as e: 73 | logger.warning(f"Failed to parse JSON from {addr}: {e}") 74 | 
return 75 | 76 | # Validate message structure 77 | if not isinstance(message, dict) or 'type' not in message: 78 | logger.warning(f"Invalid message structure from {addr}") 79 | return 80 | 81 | # Handle message asynchronously with error protection 82 | asyncio.create_task(self._handle_message_safe(message, addr)) 83 | 84 | except Exception as e: 85 | logger.error(f"Unexpected error in datagram_received from {addr}: {e}") 86 | # Don't re-raise - this would crash the event loop 87 | 88 | async def _handle_message_safe(self, message: Dict[str, Any], addr: Tuple[str, int]): 89 | """Safely handle message with error isolation""" 90 | try: 91 | await self._handle_message(message, addr) 92 | except Exception as e: 93 | logger.error(f"Error handling message from {addr}: {e}") 94 | # Log but don't crash the protocol handler 95 | 96 | async def _handle_message(self, message: Dict[str, Any], addr: Tuple[str, int]): 97 | """Simplified message handling - DHT operations only""" 98 | msg_type = message.get('type') 99 | sender_id = message.get('sender_id') 100 | 101 | logger.info(f"📡 DHT message received: {sender_id}, {msg_type}, {addr}") 102 | 103 | # Handle DHT protocol messages only 104 | # 'ping' also serves as a liveness announcement from the sending node 105 | if msg_type == 'ping': 106 | await self._handle_ping(message, addr) 107 | elif msg_type == 'store': 108 | await self._handle_store(message, addr) 109 | elif msg_type == 'find_node': 110 | await self._handle_find_node(message, addr) 111 | elif msg_type == 'find_value': 112 | await self._handle_find_value(message, addr) 113 | elif msg_type == 'response': 114 | await self._handle_response(message, addr) 115 | elif msg_type == 'leave_notification': 116 | await self._handle_leave_notification(message, addr) 117 | else: 118 | logger.debug(f"Unknown DHT message type: {msg_type}") 119 | 120 | async def _handle_ping(self, message: Dict[str, Any], addr: Tuple[str, int]): 121 | """Handle ping message - basic DHT operation only""" 122 | sender_id = message.get('sender_id') 123 | if sender_id: 124 | from dht.kademlia_node import Contact 125 | contact = Contact(sender_id, addr[0], addr[1]) 126 | self.node.routing_table.add_contact(contact) 127 | 128 | response = { 129 | 'type': 'response', 130 | 'id': message.get('id'), 131 | 'sender_id': self.node.node_id, 132 | 'data': { 133 | 'pong': True, 134 | 'sender_id': self.node.node_id 135 | } 136 | } 137 | await self._send_message(response, addr) 138 | 139 | async def _handle_leave_notification(self, message: Dict[str, Any], addr: Tuple[str, int]): 140 | """Handle leave notification from a departing node""" 141 | sender_id = message.get('sender_id') 142 | await self.node.handle_network_leave_event(sender_id, 'interrupted') 143 | 144 | async def _handle_store(self, message: Dict[str, Any], addr: Tuple[str, int]): 145 | """Handle store message""" 146 | key = message.get('key') 147 | value = message.get('value') 148 | 149 | if key and value is not None: 150 | self.node.storage[key] = { 151 | 'value': value, 152 | 'timestamp': time.time() 153 | } 154 | 155 | response = { 156 | 'type': 'response', 157 | 'id': message.get('id'), 158 | 'sender_id': self.node.node_id, 159 | 'data': {'stored': True} 160 | } 161 | await self._send_message(response, addr) 162 | 163 | async def _handle_find_node(self, message: Dict[str, Any], addr: Tuple[str, int]): 164 | """Handle find_node message""" 165 | target_id = message.get('target_id') 166 | closest = self.node.routing_table.find_closest_contacts(target_id, self.node.k) 167 | 168 | contacts_data = [ 169 | {'node_id': 
c.node_id, 'ip': c.ip, 'port': c.port} 170 | for c in closest 171 | ] 172 | 173 | response = { 174 | 'type': 'response', 175 | 'id': message.get('id'), 176 | 'sender_id': self.node.node_id, 177 | 'data': {'contacts': contacts_data} 178 | } 179 | await self._send_message(response, addr) 180 | 181 | async def _handle_find_value(self, message: Dict[str, Any], addr: Tuple[str, int]): 182 | """Handle find_value message""" 183 | key = message.get('key') 184 | 185 | # Check if we have the value 186 | if key in self.node.storage: 187 | stored_item = self.node.storage[key] 188 | if time.time() - stored_item['timestamp'] < self.node.ttl: 189 | response = { 190 | 'type': 'response', 191 | 'id': message.get('id'), 192 | 'sender_id': self.node.node_id, 193 | 'data': {'value': stored_item['value']} 194 | } 195 | await self._send_message(response, addr) 196 | return 197 | 198 | # Return closest nodes instead 199 | target_hash = hashlib.sha1(key.encode()).hexdigest() 200 | closest = self.node.routing_table.find_closest_contacts(target_hash, self.node.k) 201 | 202 | contacts_data = [ 203 | {'node_id': c.node_id, 'ip': c.ip, 'port': c.port} 204 | for c in closest 205 | ] 206 | 207 | response = { 208 | 'type': 'response', 209 | 'id': message.get('id'), 210 | 'sender_id': self.node.node_id, 211 | 'data': {'contacts': contacts_data} 212 | } 213 | await self._send_message(response, addr) 214 | 215 | async def _handle_response(self, message: Dict[str, Any], addr: Tuple[str, int]): 216 | """Handle response message""" 217 | msg_id = message.get('id') 218 | sender_id = message.get('sender_id') 219 | 220 | # Update contact activity for responses too 221 | if sender_id: 222 | self.node.routing_table.update_contact_seen(sender_id) 223 | 224 | if msg_id in self.pending_requests: 225 | future = self.pending_requests.pop(msg_id) 226 | if not future.done(): 227 | future.set_result(message.get('data')) 228 | 229 | async def _send_message(self, message: Dict[str, Any], addr: Tuple[str, int]): 230 | """Send a message to an address""" 231 | try: 232 | data = json.dumps(message).encode() 233 | self.transport.sendto(data, addr) 234 | except Exception as e: 235 | logger.error(f"Failed to send message to {addr}: {e}") 236 | 237 | async def send_request(self, message: Dict[str, Any], addr: Tuple[str, int], timeout: float = 5.0) -> Optional[Dict[str, Any]]: 238 | """Send a request and wait for response""" 239 | if not self.transport: 240 | logger.error("Cannot send request: transport not available") 241 | return None 242 | 243 | msg_id = message.get('id') 244 | if not msg_id: 245 | msg_id = str(uuid.uuid4()) 246 | message['id'] = msg_id 247 | 248 | # Create future for response 249 | future = asyncio.Future() 250 | self.pending_requests[msg_id] = future 251 | 252 | try: 253 | # Send message 254 | await self._send_message(message, addr) 255 | 256 | # Wait for response with timeout 257 | response = await asyncio.wait_for(future, timeout=timeout) 258 | return response 259 | 260 | except asyncio.TimeoutError: 261 | logger.debug(f"Request timeout to {addr}: {message.get('type')}") 262 | return None 263 | except Exception as e: 264 | logger.error(f"Error sending request to {addr}: {e}") 265 | return None 266 | finally: 267 | # Clean up pending request 268 | self.pending_requests.pop(msg_id, None) 269 | -------------------------------------------------------------------------------- /common/models.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from pydantic import BaseModel, 
Field 4 | from typing import Optional, List, Dict, Any, Union, AsyncGenerator 5 | import time 6 | import uuid 7 | 8 | class NodeInfo(BaseModel): 9 | """Information about an inference node""" 10 | node_id: str 11 | ip: str # Primary IP for backward compatibility 12 | port: int 13 | model: str 14 | load: float = 0.0 15 | tps: float = 0.0 16 | uptime: int = 0 17 | last_seen: int = Field(default_factory=lambda: int(time.time())) 18 | 19 | # Event-driven metadata 20 | event_driven: bool = True # Mark as event-driven update 21 | last_significant_change: Optional[int] = None # When metrics last changed significantly 22 | change_reason: Optional[str] = None # Why this update was triggered 23 | 24 | # Multi-IP support for auto IP selection 25 | available_ips: Optional[List[str]] = None # All available IP addresses 26 | ip_types: Optional[Dict[str, str]] = None # IP classification (public/private/loopback) 27 | preferred_ip: Optional[str] = None # Client's preferred IP after testing 28 | 29 | # Additional metadata 30 | cpu_info: Optional[str] = None 31 | ram_total: Optional[int] = None 32 | gpu_info: Optional[str] = None 33 | context_size: Optional[int] = None 34 | 35 | def get_all_ips(self) -> List[str]: 36 | """Get all available IP addresses for this node""" 37 | if self.available_ips: 38 | return self.available_ips 39 | return [self.ip] # Fallback to primary IP 40 | 41 | def get_best_ip_for_client(self, client_ip: str = None) -> str: 42 | """Get the best IP address for a client to connect to""" 43 | if self.preferred_ip: 44 | return self.preferred_ip 45 | 46 | # If we have multiple IPs, try to select the best one 47 | if self.available_ips and len(self.available_ips) > 1: 48 | # Prefer public IPs, then private, then others 49 | public_ips = [ip for ip in self.available_ips 50 | if self.ip_types and self.ip_types.get(ip) == "public"] 51 | private_ips = [ip for ip in self.available_ips 52 | if self.ip_types and self.ip_types.get(ip) == "private"] 53 | 54 | if public_ips: 55 | return public_ips[0] 56 | elif private_ips: 57 | return private_ips[0] 58 | 59 | return self.ip # Fallback to primary IP 60 | 61 | # OpenAI-compatible models with reasoning support 62 | class OpenAIMessage(BaseModel): 63 | """OpenAI chat message format with reasoning support""" 64 | role: str # "system", "user", "assistant" 65 | content: str 66 | reasoning_content: Optional[str] = None # Add reasoning_content field for reasoning models 67 | 68 | class OpenAICompletionRequest(BaseModel): 69 | """OpenAI-compatible completion request""" 70 | model: str = "llamanet" 71 | prompt: Union[str, List[str]] 72 | max_tokens: Optional[int] = 100 73 | temperature: Optional[float] = 0.7 74 | top_p: Optional[float] = 0.9 75 | n: Optional[int] = 1 76 | stream: Optional[bool] = False 77 | stop: Optional[Union[str, List[str]]] = None 78 | presence_penalty: Optional[float] = 0.0 79 | frequency_penalty: Optional[float] = 0.0 80 | logit_bias: Optional[Dict[str, float]] = None 81 | user: Optional[str] = None 82 | suffix: Optional[str] = None 83 | echo: Optional[bool] = False 84 | strategy: Optional[str] = "round_robin" 85 | target_model: Optional[str] = None # Add target model parameter 86 | reasoning: Optional[bool] = True # Add reasoning parameter 87 | 88 | class OpenAIChatCompletionRequest(BaseModel): 89 | """OpenAI-compatible chat completion request with reasoning support""" 90 | model: str = "llamanet" 91 | messages: List[OpenAIMessage] 92 | max_tokens: Optional[int] = 100 93 | temperature: Optional[float] = 0.7 94 | top_p: Optional[float] = 
0.9 95 | n: Optional[int] = 1 96 | stream: Optional[bool] = False 97 | stop: Optional[Union[str, List[str]]] = None 98 | presence_penalty: Optional[float] = 0.0 99 | frequency_penalty: Optional[float] = 0.0 100 | logit_bias: Optional[Dict[str, float]] = None 101 | user: Optional[str] = None 102 | strategy: Optional[str] = "round_robin" 103 | target_model: Optional[str] = None 104 | reasoning: Optional[bool] = True # Enable reasoning by default 105 | enable_reasoning: Optional[bool] = None # Alternative parameter name for compatibility 106 | 107 | class OpenAIChoice(BaseModel): 108 | """OpenAI choice object""" 109 | text: Optional[str] = None 110 | message: Optional[OpenAIMessage] = None 111 | index: int 112 | finish_reason: Optional[str] = "stop" 113 | logprobs: Optional[Dict] = None 114 | 115 | class OpenAIUsage(BaseModel): 116 | """OpenAI usage statistics""" 117 | prompt_tokens: int 118 | completion_tokens: int 119 | total_tokens: int 120 | 121 | class OpenAICompletionResponse(BaseModel): 122 | """OpenAI-compatible completion response""" 123 | id: str 124 | object: str = "text_completion" 125 | created: int 126 | model: str 127 | choices: List[OpenAIChoice] 128 | usage: OpenAIUsage 129 | node_info: Optional[Dict[str, Any]] = None 130 | 131 | class OpenAIChatCompletionResponse(BaseModel): 132 | """OpenAI-compatible chat completion response""" 133 | id: str 134 | object: str = "chat.completion" 135 | created: int 136 | model: str 137 | choices: List[OpenAIChoice] 138 | usage: OpenAIUsage 139 | node_info: Optional[Dict[str, Any]] = None 140 | 141 | class OpenAIModel(BaseModel): 142 | """OpenAI model object""" 143 | id: str 144 | object: str = "model" 145 | created: int 146 | owned_by: str = "llamanet" 147 | 148 | class OpenAIModelList(BaseModel): 149 | """OpenAI models list response""" 150 | object: str = "list" 151 | data: List[OpenAIModel] 152 | 153 | # Streaming OpenAI models with reasoning support 154 | class OpenAIStreamingDelta(BaseModel): 155 | """OpenAI streaming delta object with reasoning support""" 156 | content: Optional[str] = None 157 | role: Optional[str] = None 158 | reasoning_content: Optional[str] = None # Add reasoning_content field 159 | 160 | class OpenAIStreamingChoice(BaseModel): 161 | """OpenAI streaming choice object""" 162 | delta: OpenAIStreamingDelta 163 | index: int 164 | finish_reason: Optional[str] = None 165 | 166 | class OpenAIStreamingChatResponse(BaseModel): 167 | """OpenAI-compatible streaming chat response""" 168 | id: str 169 | object: str = "chat.completion.chunk" 170 | created: int 171 | model: str 172 | choices: List[OpenAIStreamingChoice] 173 | node_info: Optional[Dict[str, Any]] = None 174 | 175 | class OpenAIStreamingCompletionChoice(BaseModel): 176 | """OpenAI streaming completion choice""" 177 | text: str 178 | index: int 179 | finish_reason: Optional[str] = None 180 | logprobs: Optional[Dict] = None 181 | 182 | class OpenAIStreamingCompletionResponse(BaseModel): 183 | """OpenAI-compatible streaming completion response""" 184 | id: str 185 | object: str = "text_completion" 186 | created: int 187 | model: str 188 | choices: List[OpenAIStreamingCompletionChoice] 189 | node_info: Optional[Dict[str, Any]] = None 190 | 191 | 192 | # Streaming utilities 193 | def create_sse_data(data: Dict[str, Any]) -> str: 194 | """Create Server-Sent Events formatted data""" 195 | return f"data: {json.dumps(data)}\n\n" 196 | 197 | 198 | def create_sse_done() -> str: 199 | """Create SSE done signal""" 200 | return "data: [DONE]\n\n" 201 | 202 | 203 | async def 
create_streaming_chat_response( 204 | request_id: str, 205 | model: str, 206 | stream_generator: AsyncGenerator[Dict[str, Any], None], 207 | node_info: Optional[Dict[str, Any]] = None 208 | ) -> AsyncGenerator[str, None]: 209 | """Create OpenAI-compatible streaming chat completion response with reasoning support""" 210 | created = int(time.time()) 211 | 212 | # Send initial chunk with role and node info 213 | initial_chunk = OpenAIStreamingChatResponse( 214 | id=request_id, 215 | created=created, 216 | model=model, 217 | choices=[OpenAIStreamingChoice( 218 | delta=OpenAIStreamingDelta(role="assistant"), 219 | index=0 220 | )], 221 | node_info=node_info 222 | ) 223 | yield create_sse_data(initial_chunk.dict()) 224 | 225 | # Stream content chunks with reasoning support 226 | async for chunk in stream_generator: 227 | delta_content = {} 228 | 229 | # Handle reasoning content first (if available) 230 | if chunk.get("reasoning_content"): 231 | delta_content["reasoning_content"] = chunk["reasoning_content"] 232 | 233 | # Handle regular content 234 | if chunk.get("text") or chunk.get("content"): 235 | delta_content["content"] = chunk.get("text") or chunk.get("content") 236 | 237 | if delta_content: 238 | streaming_chunk = OpenAIStreamingChatResponse( 239 | id=request_id, 240 | created=created, 241 | model=model, 242 | choices=[OpenAIStreamingChoice( 243 | delta=OpenAIStreamingDelta(**delta_content), 244 | index=0, 245 | finish_reason=None if not chunk.get("finished") else "stop" 246 | )] 247 | ) 248 | yield create_sse_data(streaming_chunk.dict()) 249 | 250 | if chunk.get("finished"): 251 | # Send final chunk with finish_reason 252 | final_chunk = OpenAIStreamingChatResponse( 253 | id=request_id, 254 | created=created, 255 | model=model, 256 | choices=[OpenAIStreamingChoice( 257 | delta=OpenAIStreamingDelta(), 258 | index=0, 259 | finish_reason="stop" 260 | )] 261 | ) 262 | yield create_sse_data(final_chunk.dict()) 263 | break 264 | 265 | # Send done signal 266 | yield create_sse_done() 267 | 268 | 269 | async def create_streaming_completion_response( 270 | request_id: str, 271 | model: str, 272 | stream_generator: AsyncGenerator[Dict[str, Any], None], 273 | node_info: Optional[Dict[str, Any]] = None 274 | ) -> AsyncGenerator[str, None]: 275 | """Create OpenAI-compatible streaming completion response""" 276 | created = int(time.time()) 277 | 278 | # Stream content chunks 279 | async for chunk in stream_generator: 280 | if chunk.get("text"): 281 | streaming_chunk = OpenAIStreamingCompletionResponse( 282 | id=request_id, 283 | created=created, 284 | model=model, 285 | choices=[OpenAIStreamingCompletionChoice( 286 | text=chunk["text"], 287 | index=0, 288 | finish_reason=None if not chunk.get("finished") else "stop" 289 | )], 290 | node_info=node_info if chunk.get("text") else None # Include node_info in first content chunk 291 | ) 292 | yield create_sse_data(streaming_chunk.dict()) 293 | 294 | if chunk.get("finished"): 295 | break 296 | 297 | # Send done signal 298 | yield create_sse_done() 299 | -------------------------------------------------------------------------------- /common/subnet_matcher.py: -------------------------------------------------------------------------------- 1 | import ipaddress 2 | from typing import List, Tuple, Dict, Optional 3 | from dataclasses import dataclass 4 | from common.utils import get_logger 5 | 6 | logger = get_logger(__name__) 7 | 8 | @dataclass 9 | class SubnetMatch: 10 | """Result of subnet matching analysis""" 11 | ip: str 12 | subnet_score: float 13 | 
proximity_score: float 14 | total_score: float 15 | match_reason: str 16 | 17 | class SubnetMatcher: 18 | """Smart subnet matching for optimal IP selection""" 19 | 20 | def __init__(self): 21 | self.cache = {} 22 | self.cache_ttl = 300 # 5 minutes 23 | 24 | def analyze_bootstrap_context(self, bootstrap_ips: List[str]) -> Dict[str, any]: 25 | """Analyze bootstrap node IPs to understand network context""" 26 | if not bootstrap_ips: 27 | return {"subnets": [], "network_types": [], "analysis": "no_bootstrap_nodes"} 28 | 29 | context = { 30 | "subnets": [], 31 | "network_types": [], 32 | "ip_analysis": [], 33 | "dominant_type": None 34 | } 35 | 36 | for ip_str in bootstrap_ips: 37 | try: 38 | ip_obj = ipaddress.IPv4Address(ip_str) 39 | 40 | # Determine network type 41 | if ip_obj.is_private: 42 | if ip_str.startswith('192.168.'): 43 | net_type = 'private_class_c' 44 | subnet = ipaddress.IPv4Network(f"{ip_str}/24", strict=False) 45 | elif ip_str.startswith('10.'): 46 | net_type = 'private_class_a' 47 | subnet = ipaddress.IPv4Network(f"{ip_str}/8", strict=False) 48 | elif ip_str.startswith('172.'): 49 | net_type = 'private_class_b' 50 | subnet = ipaddress.IPv4Network(f"{ip_str}/12", strict=False) 51 | else: 52 | net_type = 'private_other' 53 | subnet = ipaddress.IPv4Network(f"{ip_str}/24", strict=False) 54 | else: 55 | net_type = 'public' 56 | subnet = ipaddress.IPv4Network(f"{ip_str}/24", strict=False) 57 | 58 | context["subnets"].append(subnet) 59 | context["network_types"].append(net_type) 60 | context["ip_analysis"].append({ 61 | "ip": ip_str, 62 | "type": net_type, 63 | "subnet": str(subnet), 64 | "is_private": ip_obj.is_private 65 | }) 66 | 67 | except Exception as e: 68 | logger.warning(f"Error analyzing bootstrap IP {ip_str}: {e}") 69 | 70 | # Determine dominant network type 71 | if context["network_types"]: 72 | type_counts = {} 73 | for net_type in context["network_types"]: 74 | type_counts[net_type] = type_counts.get(net_type, 0) + 1 75 | 76 | context["dominant_type"] = max(type_counts.items(), key=lambda x: x[1])[0] 77 | 78 | return context 79 | 80 | def rank_ips_for_bootstrap_context(self, 81 | available_ips: List, 82 | bootstrap_ips: List[str]) -> List[SubnetMatch]: 83 | """Rank available IPs based on compatibility with bootstrap context""" 84 | 85 | if not bootstrap_ips: 86 | # No bootstrap context, rank by general preference 87 | return self._rank_ips_general(available_ips) 88 | 89 | # Analyze bootstrap context 90 | bootstrap_context = self.analyze_bootstrap_context(bootstrap_ips) 91 | 92 | matches = [] 93 | 94 | for ip_class in available_ips: 95 | if not ip_class.is_reachable and ip_class.type != 'loopback': 96 | continue 97 | 98 | subnet_score = self._calculate_subnet_score(ip_class, bootstrap_context) 99 | proximity_score = self._calculate_proximity_score(ip_class, bootstrap_ips) 100 | 101 | # Combine scores with weights 102 | total_score = ( 103 | subnet_score * 0.6 + # Subnet compatibility (60%) 104 | proximity_score * 0.3 + # IP proximity (30%) 105 | ip_class.confidence_score * 0.1 # Base confidence (10%) 106 | ) 107 | 108 | match_reason = self._determine_match_reason(ip_class, bootstrap_context, subnet_score, proximity_score) 109 | 110 | matches.append(SubnetMatch( 111 | ip=ip_class.ip, 112 | subnet_score=subnet_score, 113 | proximity_score=proximity_score, 114 | total_score=total_score, 115 | match_reason=match_reason 116 | )) 117 | 118 | # Sort by total score (highest first) 119 | matches.sort(key=lambda x: x.total_score, reverse=True) 120 | 121 | # Log ranking results 
122 | logger.info(f"🎯 Subnet matching results for bootstrap context {bootstrap_ips}:") 123 | for i, match in enumerate(matches[:3]): # Show top 3 124 | logger.info(f" {i+1}. {match.ip} (score: {match.total_score:.2f}) - {match.match_reason}") 125 | 126 | return matches 127 | 128 | def _calculate_subnet_score(self, ip_class, bootstrap_context: Dict) -> float: 129 | """Calculate subnet compatibility score""" 130 | score = 0.0 131 | 132 | try: 133 | local_ip = ipaddress.IPv4Address(ip_class.ip) 134 | 135 | # Check direct subnet matches 136 | for bootstrap_subnet in bootstrap_context.get("subnets", []): 137 | if local_ip in bootstrap_subnet: 138 | score += 100.0 # Perfect subnet match 139 | break 140 | else: 141 | # No direct match, check network type compatibility 142 | local_is_private = local_ip.is_private 143 | dominant_type = bootstrap_context.get("dominant_type", "") 144 | 145 | if local_is_private and "private" in dominant_type: 146 | score += 75.0 # Same private network class 147 | elif not local_is_private and dominant_type == "public": 148 | score += 75.0 # Both public 149 | elif local_is_private and dominant_type == "public": 150 | score += 25.0 # Private to public (lower score) 151 | elif not local_is_private and "private" in dominant_type: 152 | score += 50.0 # Public to private (medium score) 153 | 154 | # Additional scoring for IP class proximity 155 | for analysis in bootstrap_context.get("ip_analysis", []): 156 | bootstrap_ip = analysis["ip"] 157 | if self._same_ip_class(ip_class.ip, bootstrap_ip): 158 | score += 25.0 159 | break 160 | 161 | except Exception as e: 162 | logger.debug(f"Error calculating subnet score for {ip_class.ip}: {e}") 163 | 164 | return min(score, 100.0) # Cap at 100 165 | 166 | def _calculate_proximity_score(self, ip_class, bootstrap_ips: List[str]) -> float: 167 | """Calculate IP proximity score""" 168 | score = 0.0 169 | 170 | try: 171 | local_parts = ip_class.ip.split('.') 172 | 173 | for bootstrap_ip in bootstrap_ips: 174 | bootstrap_parts = bootstrap_ip.split('.') 175 | 176 | # Score based on matching octets 177 | matching_octets = 0 178 | for i in range(min(len(local_parts), len(bootstrap_parts))): 179 | if local_parts[i] == bootstrap_parts[i]: 180 | matching_octets += 1 181 | else: 182 | break 183 | 184 | # Convert to score (4 octets = 100 points) 185 | proximity = (matching_octets / 4.0) * 100.0 186 | score = max(score, proximity) # Take best match 187 | 188 | except Exception as e: 189 | logger.debug(f"Error calculating proximity score for {ip_class.ip}: {e}") 190 | 191 | return score 192 | 193 | def _same_ip_class(self, ip1: str, ip2: str) -> bool: 194 | """Check if two IPs are in the same class (A/B/C)""" 195 | try: 196 | return ip1.split('.')[0] == ip2.split('.')[0] 197 | except: 198 | return False 199 | 200 | def _determine_match_reason(self, ip_class, bootstrap_context: Dict, 201 | subnet_score: float, proximity_score: float) -> str: 202 | """Determine the reason for the match score""" 203 | 204 | if subnet_score >= 100.0: 205 | return "perfect_subnet_match" 206 | elif subnet_score >= 75.0: 207 | return "same_network_type" 208 | elif proximity_score >= 75.0: 209 | return "high_ip_proximity" 210 | elif subnet_score >= 50.0: 211 | return "compatible_network_type" 212 | elif proximity_score >= 50.0: 213 | return "medium_ip_proximity" 214 | elif ip_class.is_reachable: 215 | return "reachable_fallback" 216 | else: 217 | return "last_resort" 218 | 219 | def _rank_ips_general(self, available_ips: List) -> List[SubnetMatch]: 220 | """General IP 
ranking when no bootstrap context is available""" 221 | matches = [] 222 | 223 | for ip_class in available_ips: 224 | # Use confidence score as primary ranking 225 | total_score = ip_class.confidence_score 226 | 227 | match_reason = f"general_ranking_{ip_class.type}" 228 | if ip_class.is_reachable: 229 | match_reason += "_reachable" 230 | 231 | matches.append(SubnetMatch( 232 | ip=ip_class.ip, 233 | subnet_score=0.0, 234 | proximity_score=0.0, 235 | total_score=total_score, 236 | match_reason=match_reason 237 | )) 238 | 239 | matches.sort(key=lambda x: x.total_score, reverse=True) 240 | return matches 241 | 242 | def get_best_ip_for_context(self, available_ips: List, 243 | bootstrap_ips: List[str] = None) -> Optional[str]: 244 | """Get the single best IP for the given context""" 245 | 246 | matches = self.rank_ips_for_bootstrap_context(available_ips, bootstrap_ips or []) 247 | 248 | if matches: 249 | best_match = matches[0] 250 | logger.info(f"🎯 Selected best IP: {best_match.ip} (score: {best_match.total_score:.2f}, reason: {best_match.match_reason})") 251 | return best_match.ip 252 | 253 | return None 254 | 255 | # Global instance 256 | _subnet_matcher = None 257 | 258 | def get_subnet_matcher() -> SubnetMatcher: 259 | """Get the global subnet matcher instance""" 260 | global _subnet_matcher 261 | if _subnet_matcher is None: 262 | _subnet_matcher = SubnetMatcher() 263 | return _subnet_matcher 264 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /inference_node/p2p_handler.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import time 4 | import uuid 5 | from typing import Dict, Any 6 | from common.p2p_transport import P2PTransport 7 | from common.models import ( 8 | OpenAICompletionRequest, OpenAIChatCompletionRequest, 9 | OpenAICompletionResponse, OpenAIChatCompletionResponse, 10 | OpenAIChoice, OpenAIUsage, OpenAIMessage 11 | ) 12 | from inference_node.llm_wrapper import LlamaWrapper 13 | from inference_node.config import InferenceConfig 14 | from common.utils import get_logger, get_host_ip 15 | 16 | logger = get_logger(__name__) 17 | 18 | class P2PRequestHandler: 19 | """Handle inference requests via P2P transport""" 20 | 21 | def __init__(self, config: InferenceConfig, llm: LlamaWrapper): 22 | self.config = config 23 | self.llm = llm 24 | self.transport = P2PTransport(config.node_id, config.model_name) 25 | 26 | async def start(self): 27 | """Start P2P request handler with error handling""" 28 | try: 29 | await self.transport.start() 30 | self.transport.add_message_callback(self._handle_p2p_request) 31 | logger.info(f"P2P request handler started") 32 | except ImportError as e: 33 | logger.warning(f"P2P dependencies not available: {e}") 34 | raise 35 | except Exception as e: 36 | logger.error(f"Failed to start P2P transport: {e}") 37 | raise 38 | 39 | async def _handle_p2p_request(self, msg: bytes, client_tup, pipe): 40 | """Handle incoming P2P inference requests""" 41 | try: 42 | # Parse request 43 | request_data = json.loads(msg.decode('utf-8')) 44 | request_type = request_data.get('type') 45 | 46 | if request_type == 'completion_request': 47 | await self._handle_completion_request(request_data, pipe) 48 | elif request_type == 'chat_completion_request': 49 | await self._handle_chat_completion_request(request_data, pipe) 50 | elif request_type == 'status_request': 51 | await self._handle_status_request(request_data, pipe) 52 | elif request_type == 'ping': 53 | await self._handle_ping_request(request_data, pipe) 54 | else: 55 | logger.warning(f"Unknown P2P request type: {request_type}") 56 | 57 | except Exception as e: 58 | logger.error(f"Error handling P2P request: {e}") 59 | # Send error response 60 | error_response = { 61 | 'type': 'error_response', 62 | 'error': str(e), 63 | 'timestamp': time.time() 64 | } 65 | try: 66 | await self.transport.send_message(pipe, json.dumps(error_response).encode('utf-8')) 67 | except: 68 | pass 69 | 70 | async def _handle_completion_request(self, request_data: Dict[str, Any], pipe): 71 | """Handle completion request via P2P""" 72 | try: 73 | request = OpenAICompletionRequest(**request_data['data']) 74 | 75 | # Handle prompt (can be string or list) 76 | if isinstance(request.prompt, list): 77 | prompt = request.prompt[0] if request.prompt else "" 78 | else: 79 | prompt = request.prompt 80 | 81 | # Normalize stop tokens 82 | stop_tokens = None 83 | if request.stop: 84 | if isinstance(request.stop, str): 85 | stop_tokens = [request.stop] if request.stop.strip() else None 86 | elif isinstance(request.stop, list): 87 | stop_tokens = [str(token).strip() for token in request.stop if str(token).strip()] 88 | stop_tokens = stop_tokens if stop_tokens else None 89 | 90 | # Generate response 91 | result = self.llm.generate( 92 | prompt=prompt, 93 | max_tokens=request.max_tokens or 100, 94 | temperature=request.temperature or 0.7, 95 | 
top_p=request.top_p or 0.9, 96 | stop=stop_tokens, 97 | repeat_penalty=1.0 + (request.frequency_penalty or 0.0) 98 | ) 99 | 100 | # Calculate token counts 101 | prompt_tokens = len(prompt.split()) 102 | completion_tokens = result["tokens_generated"] 103 | 104 | # Create response 105 | choice = OpenAIChoice( 106 | text=result["text"], 107 | index=0, 108 | finish_reason="stop" 109 | ) 110 | 111 | usage = OpenAIUsage( 112 | prompt_tokens=prompt_tokens, 113 | completion_tokens=completion_tokens, 114 | total_tokens=prompt_tokens + completion_tokens 115 | ) 116 | 117 | response = OpenAICompletionResponse( 118 | id=f"cmpl-{uuid.uuid4().hex[:8]}", 119 | created=int(time.time()), 120 | model=request.model, 121 | choices=[choice], 122 | usage=usage, 123 | node_info={ 124 | "node_id": self.config.node_id, 125 | "ip": get_host_ip(), 126 | "port": self.config.port, 127 | "model": self.config.model_name, 128 | "processing_node": "p2p", 129 | "transport": "p2p" 130 | } 131 | ) 132 | 133 | # Send response 134 | response_msg = { 135 | 'type': 'completion_response', 136 | 'data': response.dict(), 137 | 'request_id': request_data.get('request_id'), 138 | 'timestamp': time.time() 139 | } 140 | 141 | await self.transport.send_message(pipe, json.dumps(response_msg).encode('utf-8')) 142 | 143 | except Exception as e: 144 | logger.error(f"Error in P2P completion request: {e}") 145 | raise 146 | 147 | async def _handle_chat_completion_request(self, request_data: Dict[str, Any], pipe): 148 | """Handle chat completion request via P2P""" 149 | try: 150 | request = OpenAIChatCompletionRequest(**request_data['data']) 151 | 152 | # Convert messages to prompt 153 | prompt_parts = [] 154 | for message in request.messages: 155 | if message.role == "system": 156 | prompt_parts.append(f"System: {message.content}") 157 | elif message.role == "user": 158 | prompt_parts.append(f"Human: {message.content}") 159 | elif message.role == "assistant": 160 | prompt_parts.append(f"Assistant: {message.content}") 161 | 162 | prompt = "\n\n".join(prompt_parts) + "\n\nAssistant:" 163 | 164 | # Default stop tokens for chat format 165 | stop_tokens = ["\n\nHuman:", "\n\nUser:", "\nHuman:", "\nUser:", "Human:", "User:"] 166 | if request.stop: 167 | if isinstance(request.stop, str): 168 | stop_tokens.append(request.stop) 169 | elif isinstance(request.stop, list): 170 | stop_tokens.extend(request.stop) 171 | 172 | # Generate response 173 | result = self.llm.generate( 174 | prompt=prompt, 175 | max_tokens=request.max_tokens or 100, 176 | temperature=request.temperature or 0.7, 177 | top_p=request.top_p or 0.9, 178 | stop=stop_tokens, 179 | repeat_penalty=1.0 + (request.frequency_penalty or 0.0) 180 | ) 181 | 182 | # Calculate token counts 183 | prompt_tokens = len(prompt.split()) 184 | completion_tokens = result["tokens_generated"] 185 | 186 | # Create response 187 | response_message = OpenAIMessage( 188 | role="assistant", 189 | content=result["text"].strip() 190 | ) 191 | 192 | choice = OpenAIChoice( 193 | message=response_message, 194 | index=0, 195 | finish_reason="stop" 196 | ) 197 | 198 | usage = OpenAIUsage( 199 | prompt_tokens=prompt_tokens, 200 | completion_tokens=completion_tokens, 201 | total_tokens=prompt_tokens + completion_tokens 202 | ) 203 | 204 | response = OpenAIChatCompletionResponse( 205 | id=f"chatcmpl-{uuid.uuid4().hex[:8]}", 206 | created=int(time.time()), 207 | model=request.model, 208 | choices=[choice], 209 | usage=usage, 210 | node_info={ 211 | "node_id": self.config.node_id, 212 | "ip": get_host_ip(), 213 | "port": 
self.config.port, 214 | "model": self.config.model_name, 215 | "processing_node": "p2p", 216 | "transport": "p2p" 217 | } 218 | ) 219 | 220 | # Send response 221 | response_msg = { 222 | 'type': 'chat_completion_response', 223 | 'data': response.dict(), 224 | 'request_id': request_data.get('request_id'), 225 | 'timestamp': time.time() 226 | } 227 | 228 | await self.transport.send_message(pipe, json.dumps(response_msg).encode('utf-8')) 229 | 230 | except Exception as e: 231 | logger.error(f"Error in P2P chat completion request: {e}") 232 | raise 233 | 234 | async def _handle_status_request(self, request_data: Dict[str, Any], pipe): 235 | """Handle status request via P2P""" 236 | try: 237 | metrics = self.llm.get_metrics() 238 | 239 | response_msg = { 240 | 'type': 'status_response', 241 | 'data': { 242 | **metrics, 243 | 'node_id': self.config.node_id, 244 | 'model': self.config.model_name, 245 | 'transport': 'p2p', 246 | 'p2p_info': self.transport.get_address_info() 247 | }, 248 | 'request_id': request_data.get('request_id'), 249 | 'timestamp': time.time() 250 | } 251 | 252 | await self.transport.send_message(pipe, json.dumps(response_msg).encode('utf-8')) 253 | 254 | except Exception as e: 255 | logger.error(f"Error in P2P status request: {e}") 256 | raise 257 | 258 | async def _handle_ping_request(self, request_data: Dict[str, Any], pipe): 259 | """Handle ping request via P2P""" 260 | try: 261 | response_msg = { 262 | 'type': 'pong_response', 263 | 'data': { 264 | 'node_id': self.config.node_id, 265 | 'model': self.config.model_name, 266 | 'timestamp': time.time() 267 | }, 268 | 'request_id': request_data.get('request_id'), 269 | 'timestamp': time.time() 270 | } 271 | 272 | await self.transport.send_message(pipe, json.dumps(response_msg).encode('utf-8')) 273 | 274 | except Exception as e: 275 | logger.error(f"Error in P2P ping request: {e}") 276 | raise 277 | 278 | def get_p2p_info(self) -> Dict[str, Any]: 279 | """Get P2P transport information""" 280 | return self.transport.get_address_info() 281 | 282 | async def close(self): 283 | """Close P2P request handler""" 284 | await self.transport.close() 285 | -------------------------------------------------------------------------------- /dht/routing_table.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List, Dict, Optional, TYPE_CHECKING, Any 3 | from common.utils import get_logger 4 | from common.validation_utils import NodeValidator 5 | 6 | if TYPE_CHECKING: 7 | from dht.kademlia_node import Contact 8 | 9 | logger = get_logger(__name__) 10 | 11 | class KBucket: 12 | """K-bucket for storing contacts""" 13 | 14 | def __init__(self, k: int = 20): 15 | self.k = k 16 | self.contacts: List['Contact'] = [] 17 | self.last_updated = time.time() 18 | 19 | def add_contact(self, contact: 'Contact') -> bool: 20 | """Add a contact to the bucket""" 21 | # Remove if already exists 22 | self.contacts = [c for c in self.contacts if c.node_id != contact.node_id] 23 | 24 | # Add to front 25 | self.contacts.insert(0, contact) 26 | 27 | # Trim to k size 28 | if len(self.contacts) > self.k: 29 | self.contacts = self.contacts[:self.k] 30 | 31 | self.last_updated = time.time() 32 | return True 33 | 34 | def remove_contact(self, node_id: str): 35 | """Remove a contact from the bucket""" 36 | self.contacts = [c for c in self.contacts if c.node_id != node_id] 37 | 38 | def get_contacts(self) -> List['Contact']: 39 | """Get all contacts in the bucket""" 40 | return self.contacts.copy() 41 | 
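# --- Editor's note: the sketch below is illustrative only and is not part of the
# original dht/routing_table.py. It shows the KBucket behaviour defined above:
# the most recently seen contact sits at the front of the list, re-adding an
# existing node_id moves it back to the front, and the bucket is trimmed to at
# most k entries. The Contact(node_id, ip, port) constructor is assumed from its
# usage in dht/protocol.py; the real class may accept additional arguments.
def _kbucket_usage_sketch():
    from dht.kademlia_node import Contact

    bucket = KBucket(k=2)
    bucket.add_contact(Contact("aa" * 20, "10.0.0.1", 8001))
    bucket.add_contact(Contact("bb" * 20, "10.0.0.2", 8001))
    # Adding a third contact trims the bucket back to k=2, evicting the oldest ("aa...")
    bucket.add_contact(Contact("cc" * 20, "10.0.0.3", 8001))

    # Re-adding an existing contact moves it to the front of the bucket
    bucket.add_contact(Contact("bb" * 20, "10.0.0.2", 8001))
    assert [c.node_id[:2] for c in bucket.get_contacts()] == ["bb", "cc"]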
42 | class RoutingTable: 43 | """Kademlia routing table with k-buckets""" 44 | 45 | def __init__(self, node_id: str, k: int = 20): 46 | self.node_id = node_id 47 | self.k = k 48 | self.buckets: Dict[int, KBucket] = {} 49 | self.contact_timeout = 180 # Increased from 60 to 180 seconds (3 minutes) 50 | 51 | def add_contact(self, contact: 'Contact'): 52 | """Add a contact to the appropriate bucket""" 53 | if contact.node_id == self.node_id: 54 | return # Don't add ourselves 55 | 56 | bucket_index = self._get_bucket_index(contact.node_id) 57 | 58 | if bucket_index not in self.buckets: 59 | self.buckets[bucket_index] = KBucket(self.k) 60 | 61 | # Check if this is a new contact before adding 62 | existing_contact_ids = [c.node_id for c in self.buckets[bucket_index].contacts] 63 | is_new_contact = contact.node_id not in existing_contact_ids 64 | 65 | # IMPORTANT: Remove old contact with same node_id (not ip:port combo) 66 | # This ensures we update the contact info if ports change 67 | self.buckets[bucket_index].contacts = [ 68 | c for c in self.buckets[bucket_index].contacts 69 | if c.node_id != contact.node_id 70 | ] 71 | 72 | self.buckets[bucket_index].add_contact(contact) 73 | 74 | if is_new_contact: 75 | logger.info(f"🔗 New DHT contact added: {contact.node_id} ({contact.ip}:{contact.port})") 76 | else: 77 | logger.info(f"Updated contact {contact.node_id} in bucket {bucket_index} - new address: {contact.ip}:{contact.port}") 78 | 79 | def remove_contact(self, node_id: str): 80 | """Remove a contact from the routing table""" 81 | 82 | bucket_index = self._get_bucket_index(node_id) 83 | if bucket_index in self.buckets: 84 | self.buckets[bucket_index].remove_contact(node_id) 85 | 86 | logger.info(f"🔗 DHT contact removed: {node_id}") 87 | 88 | self.cleanup_stale_contacts() 89 | 90 | def cleanup_stale_contacts(self): 91 | """Remove contacts that haven't been seen recently""" 92 | current_time = time.time() 93 | removed_count = 0 94 | 95 | for bucket_index, bucket in list(self.buckets.items()): 96 | original_count = len(bucket.contacts) 97 | 98 | # Filter out stale contacts 99 | stale_contacts = [] 100 | active_contacts = [] 101 | 102 | for contact in bucket.contacts: 103 | if current_time - contact.last_seen < self.contact_timeout: 104 | active_contacts.append(contact) 105 | else: 106 | stale_contacts.append(contact) 107 | 108 | bucket.contacts = active_contacts 109 | 110 | # Log removed contacts 111 | for contact in stale_contacts: 112 | logger.info(f"🧹 Removed stale contact: {contact.node_id[:12]}... 
({contact.ip}:{contact.port}) - last seen {int(current_time - contact.last_seen)}s ago") 113 | 114 | removed = original_count - len(bucket.contacts) 115 | if removed > 0: 116 | removed_count += removed 117 | 118 | # Remove empty buckets 119 | if not bucket.contacts: 120 | del self.buckets[bucket_index] 121 | 122 | if removed_count > 0: 123 | logger.info(f"🧹 Total cleanup: removed {removed_count} stale contacts") 124 | 125 | return removed_count 126 | 127 | def update_contact_seen(self, node_id: str): 128 | """Update last_seen time for a contact""" 129 | for bucket in self.buckets.values(): 130 | for contact in bucket.contacts: 131 | if contact.node_id == node_id: 132 | contact.last_seen = time.time() 133 | logger.debug(f"📡 Updated last_seen for contact {node_id[:8]}...") 134 | return True 135 | return False 136 | 137 | def find_closest_contacts(self, target_id: str, count: int) -> List['Contact']: 138 | """Find the closest contacts to a target ID""" 139 | all_contacts = [] 140 | 141 | # Collect all contacts 142 | for bucket in self.buckets.values(): 143 | all_contacts.extend(bucket.get_contacts()) 144 | 145 | # Sort by distance to target 146 | all_contacts.sort(key=lambda c: c.distance(target_id)) 147 | 148 | return all_contacts[:count] 149 | 150 | def _get_bucket_index(self, node_id: str) -> int: 151 | """Get the bucket index for a node ID""" 152 | try: 153 | # Ensure both are strings 154 | self_id_str = str(self.node_id) 155 | node_id_str = str(node_id) 156 | 157 | distance = int(self_id_str, 16) ^ int(node_id_str, 16) 158 | if distance == 0: 159 | return 0 160 | return distance.bit_length() - 1 161 | except (ValueError, TypeError) as e: 162 | logger.error(f"Invalid node ID format for bucket calculation: {self.node_id} or {node_id}: {e}") 163 | return 0 # Default bucket 164 | 165 | def get_all_contacts(self) -> List['Contact']: 166 | """Get all contacts in the routing table""" 167 | all_contacts = [] 168 | for bucket in self.buckets.values(): 169 | all_contacts.extend(bucket.get_contacts()) 170 | return all_contacts 171 | 172 | def get_unique_contacts(self) -> List['Contact']: 173 | """Get all unique contacts (deduplicated by node_id)""" 174 | seen_ids = set() 175 | unique_contacts = [] 176 | 177 | for bucket in self.buckets.values(): 178 | for contact in bucket.get_contacts(): 179 | if contact.node_id not in seen_ids: 180 | seen_ids.add(contact.node_id) 181 | unique_contacts.append(contact) 182 | 183 | return unique_contacts 184 | 185 | def handle_node_join(self, contact: 'Contact', join_source: str = 'unknown') -> bool: 186 | """Handle explicit node join event with enhanced tracking""" 187 | if contact.node_id == self.node_id: 188 | return False # Don't add ourselves 189 | 190 | # Validate contact before adding 191 | if not NodeValidator.validate_contact(contact): 192 | logger.warning(f"Invalid contact in join event: {contact.node_id}") 193 | return False 194 | 195 | # Check if this is truly a new contact 196 | existing_contact = self.get_contact_by_id(contact.node_id) 197 | is_new_contact = existing_contact is None 198 | 199 | # Add/update the contact 200 | self.add_contact(contact) 201 | 202 | if is_new_contact: 203 | logger.info(f"🆕 Node joined DHT: {contact.node_id[:12]}... ({contact.ip}:{contact.port}) via {join_source}") 204 | else: 205 | logger.info(f"🔄 Updated existing contact: {contact.node_id[:8]}... 
new address: {contact.ip}:{contact.port}") 206 | 207 | return is_new_contact 208 | 209 | def handle_node_leave(self, node_id: str, leave_reason: str = 'unknown') -> bool: 210 | """Handle explicit node leave event""" 211 | if node_id == self.node_id: 212 | return False # Don't remove ourselves 213 | 214 | # Validate node ID 215 | if not NodeValidator.validate_node_id(node_id): 216 | logger.warning(f"Invalid node ID in leave event: {node_id}") 217 | return False 218 | 219 | # Check if we actually have this contact 220 | existing_contact = self.get_contact_by_id(node_id) 221 | if existing_contact: 222 | self.remove_contact(node_id) 223 | logger.info(f"👋 Node left DHT: {node_id[:12]}... (reason: {leave_reason})") 224 | return True 225 | else: 226 | logger.debug(f"🤷 Leave notification for unknown node: {node_id[:8]}...") 227 | return False 228 | 229 | def get_contact_by_id(self, node_id: str) -> Optional['Contact']: 230 | """Get a contact by node ID""" 231 | for bucket in self.buckets.values(): 232 | for contact in bucket.contacts: 233 | if contact.node_id == node_id: 234 | return contact 235 | return None 236 | 237 | def update_contact_from_event(self, node_id: str, new_ip: str, new_port: int) -> bool: 238 | """Update contact information from network events""" 239 | existing_contact = self.get_contact_by_id(node_id) 240 | if existing_contact: 241 | # Update contact info if it changed 242 | if existing_contact.ip != new_ip or existing_contact.port != new_port: 243 | logger.info(f"📍 Contact address updated: {node_id[:8]}... {existing_contact.ip}:{existing_contact.port} -> {new_ip}:{new_port}") 244 | existing_contact.ip = new_ip 245 | existing_contact.port = new_port 246 | existing_contact.last_seen = time.time() 247 | return True 248 | else: 249 | # Just update last_seen 250 | existing_contact.last_seen = time.time() 251 | return False 252 | return False 253 | 254 | def get_routing_table_events(self) -> Dict[str, Any]: 255 | """Get routing table statistics for event broadcasting""" 256 | current_time = time.time() 257 | all_contacts = self.get_all_contacts() 258 | 259 | active_contacts = [c for c in all_contacts if current_time - c.last_seen < 60] 260 | recent_contacts = [c for c in all_contacts if current_time - c.last_seen < 30] 261 | 262 | return { 263 | "total_contacts": len(all_contacts), 264 | "active_contacts": len(active_contacts), 265 | "recent_contacts": len(recent_contacts), 266 | "buckets_count": len(self.buckets), 267 | "last_updated": current_time 268 | } 269 | 270 | def get_stats(self) -> Dict[str, Any]: 271 | """Get routing table statistics""" 272 | current_time = time.time() 273 | all_contacts = self.get_all_contacts() 274 | 275 | active_contacts = [c for c in all_contacts if current_time - c.last_seen < 30] 276 | stale_contacts = [c for c in all_contacts if current_time - c.last_seen >= 30] 277 | 278 | return { 279 | "total_contacts": len(all_contacts), 280 | "active_contacts": len(active_contacts), 281 | "stale_contacts": len(stale_contacts), 282 | "buckets_count": len(self.buckets), 283 | "contact_timeout": self.contact_timeout 284 | } 285 | -------------------------------------------------------------------------------- /common/sse_handler.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import time 5 | from typing import AsyncGenerator, Dict, Any, Optional, Callable, List, Set 6 | import aiohttp 7 | from dataclasses import dataclass 8 | 9 | logger = logging.getLogger(__name__) 10 | 
11 | @dataclass 12 | class SSEChunk: 13 | """Represents a single SSE chunk""" 14 | data: str 15 | event: Optional[str] = None 16 | id: Optional[str] = None 17 | retry: Optional[int] = None 18 | 19 | class SSEHandler: 20 | """Server-side SSE connection and event management""" 21 | 22 | def __init__(self): 23 | self.active_connections: Dict[str, asyncio.Queue] = {} 24 | self.event_listeners: List[Callable] = [] 25 | self.running = False 26 | 27 | async def add_connection(self, connection_id: str) -> asyncio.Queue: 28 | """Add a new SSE connection""" 29 | event_queue = asyncio.Queue(maxsize=100) 30 | self.active_connections[connection_id] = event_queue 31 | logger.info(f"SSE connection added: {connection_id} (total: {len(self.active_connections)})") 32 | return event_queue 33 | 34 | async def remove_connection(self, connection_id: str): 35 | """Remove an SSE connection""" 36 | if connection_id in self.active_connections: 37 | del self.active_connections[connection_id] 38 | logger.info(f"SSE connection removed: {connection_id} (remaining: {len(self.active_connections)})") 39 | 40 | def add_event_listener(self, listener: Callable): 41 | """Add an event listener function""" 42 | self.event_listeners.append(listener) 43 | 44 | def remove_event_listener(self, listener: Callable): 45 | """Remove an event listener function""" 46 | if listener in self.event_listeners: 47 | self.event_listeners.remove(listener) 48 | 49 | async def broadcast_event(self, event_type: str, event_data: Dict[str, Any]): 50 | """Broadcast an event to all active connections - event-driven only""" 51 | if not self.active_connections: 52 | return 53 | 54 | event_payload = { 55 | "type": event_type, 56 | "timestamp": time.time(), 57 | "event_driven": True, # Mark as event-driven 58 | "polling_disabled": True, # Confirm no polling 59 | **event_data 60 | } 61 | 62 | # Send to all active connections 63 | disconnected_connections = [] 64 | for connection_id, queue in self.active_connections.items(): 65 | try: 66 | await asyncio.wait_for(queue.put(event_payload), timeout=1.0) 67 | except asyncio.TimeoutError: 68 | logger.warning(f"SSE queue full for connection {connection_id}") 69 | except Exception as e: 70 | logger.error(f"Error broadcasting to connection {connection_id}: {e}") 71 | disconnected_connections.append(connection_id) 72 | 73 | # Clean up disconnected connections 74 | for connection_id in disconnected_connections: 75 | await self.remove_connection(connection_id) 76 | 77 | # Notify event listeners 78 | for listener in self.event_listeners: 79 | try: 80 | await listener(event_payload) 81 | except Exception as e: 82 | logger.error(f"Error in event listener: {e}") 83 | 84 | def get_status(self) -> Dict[str, Any]: 85 | """Get SSE handler status""" 86 | return { 87 | "active_connections": len(self.active_connections), 88 | "event_listeners": len(self.event_listeners), 89 | "running": self.running, 90 | "connection_ids": list(self.active_connections.keys()) 91 | } 92 | 93 | class SSENetworkMonitor: 94 | """Monitor network changes and broadcast via SSE""" 95 | 96 | def __init__(self, base_url: str): 97 | self.base_url = base_url 98 | self.running = False 99 | self.monitor_task = None 100 | self.sse_handler = None 101 | 102 | async def start(self): 103 | """Start the network monitor""" 104 | self.running = True 105 | # Monitor task can be added here if needed for periodic checks 106 | logger.info("SSE Network Monitor started") 107 | 108 | async def stop(self): 109 | """Stop the network monitor""" 110 | self.running = False 111 | 
if self.monitor_task: 112 | self.monitor_task.cancel() 113 | try: 114 | await self.monitor_task 115 | except asyncio.CancelledError: 116 | pass 117 | logger.info("SSE Network Monitor stopped") 118 | 119 | def set_sse_handler(self, sse_handler: SSEHandler): 120 | """Set the SSE handler for broadcasting events""" 121 | self.sse_handler = sse_handler 122 | 123 | class SSEParser: 124 | """Parse Server-Sent Events from a stream""" 125 | 126 | def __init__(self): 127 | self.buffer = "" 128 | 129 | def parse_line(self, line: str) -> Optional[SSEChunk]: 130 | """Parse a single SSE line""" 131 | line = line.strip() 132 | 133 | if not line or line.startswith(':'): 134 | return None 135 | 136 | if line.startswith('data: '): 137 | data = line[6:] 138 | if data and data != '[DONE]': 139 | return SSEChunk(data=data) 140 | 141 | return None 142 | 143 | def parse_chunk(self, chunk: str) -> List[SSEChunk]: 144 | """Parse a chunk of data and return SSE events""" 145 | self.buffer += chunk 146 | lines = self.buffer.split('\n') 147 | self.buffer = lines.pop() # Keep incomplete line 148 | 149 | events = [] 150 | for line in lines: 151 | event = self.parse_line(line) 152 | if event: 153 | events.append(event) 154 | 155 | return events 156 | 157 | class SSEStreamHandler: 158 | """Handle SSE streaming with robust error handling""" 159 | 160 | def __init__(self, timeout: int = 30): 161 | self.timeout = timeout 162 | self.parser = SSEParser() 163 | 164 | async def stream_from_response( 165 | self, 166 | response: aiohttp.ClientResponse, 167 | transform_func: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None 168 | ) -> AsyncGenerator[Dict[str, Any], None]: 169 | """Stream and parse SSE data from an aiohttp response""" 170 | try: 171 | async for chunk in response.content.iter_any(): 172 | if chunk: 173 | chunk_str = chunk.decode('utf-8', errors='ignore') 174 | events = self.parser.parse_chunk(chunk_str) 175 | 176 | for event in events: 177 | try: 178 | data = json.loads(event.data) 179 | 180 | # Apply transformation if provided 181 | if transform_func: 182 | data = transform_func(data) 183 | 184 | if data: 185 | yield data 186 | 187 | except json.JSONDecodeError: 188 | logger.warning(f"Failed to parse SSE data: {event.data}") 189 | continue 190 | except Exception as e: 191 | logger.error(f"Error processing SSE event: {e}") 192 | continue 193 | 194 | except asyncio.CancelledError: 195 | logger.info("SSE stream cancelled") 196 | raise 197 | except aiohttp.ClientConnectionError as e: 198 | logger.warning(f"SSE connection closed: {e}") 199 | # Don't re-raise, just end the stream gracefully 200 | except Exception as e: 201 | logger.error(f"Unexpected error in SSE stream: {e}") 202 | # Don't re-raise, just end the stream gracefully 203 | 204 | class OpenAISSETransformer: 205 | """Transform OpenAI SSE format to internal format""" 206 | 207 | @staticmethod 208 | def completion_transform(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: 209 | """Transform OpenAI completion SSE to internal format""" 210 | if data.get('choices') and len(data['choices']) > 0: 211 | choice = data['choices'][0] 212 | if choice.get('text'): 213 | return { 214 | "text": choice['text'], 215 | "finished": choice.get('finish_reason') is not None 216 | } 217 | elif choice.get('finish_reason'): 218 | return {"text": "", "finished": True} 219 | return None 220 | 221 | @staticmethod 222 | def chat_transform(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: 223 | """Transform OpenAI chat SSE to internal format""" 224 | if data.get('choices') and 
len(data['choices']) > 0: 225 | choice = data['choices'][0] 226 | delta = choice.get('delta', {}) 227 | 228 | if delta.get('content'): 229 | return { 230 | "text": delta['content'], 231 | "finished": choice.get('finish_reason') is not None 232 | } 233 | elif choice.get('finish_reason'): 234 | return {"text": "", "finished": True} 235 | return None 236 | 237 | class SSEForwarder: 238 | """Forward SSE streams between nodes with error handling""" 239 | 240 | def __init__(self, timeout: int = 30): 241 | self.handler = SSEStreamHandler(timeout) 242 | self.transformer = OpenAISSETransformer() 243 | 244 | async def forward_completion_stream( 245 | self, 246 | url: str, 247 | request_data: Dict[str, Any], 248 | headers: Optional[Dict[str, str]] = None 249 | ) -> AsyncGenerator[Dict[str, Any], None]: 250 | """Forward a completion stream from another node""" 251 | 252 | default_headers = {"Content-Type": "application/json"} 253 | if headers: 254 | default_headers.update(headers) 255 | 256 | timeout = aiohttp.ClientTimeout(total=self.handler.timeout, connect=5) 257 | 258 | try: 259 | async with aiohttp.ClientSession(timeout=timeout) as session: 260 | async with session.post(url, json=request_data, headers=default_headers) as response: 261 | if response.status != 200: 262 | error_text = await response.text() 263 | logger.error(f"Forwarded request failed: {response.status} {error_text}") 264 | return 265 | 266 | async for data in self.handler.stream_from_response( 267 | response, 268 | self.transformer.completion_transform 269 | ): 270 | yield data 271 | 272 | except asyncio.TimeoutError: 273 | logger.error(f"Timeout forwarding stream to {url}") 274 | except Exception as e: 275 | logger.error(f"Error forwarding stream to {url}: {e}") 276 | 277 | async def forward_chat_stream( 278 | self, 279 | url: str, 280 | request_data: Dict[str, Any], 281 | headers: Optional[Dict[str, str]] = None 282 | ) -> AsyncGenerator[Dict[str, Any], None]: 283 | """Forward a chat completion stream from another node""" 284 | 285 | default_headers = {"Content-Type": "application/json"} 286 | if headers: 287 | default_headers.update(headers) 288 | 289 | timeout = aiohttp.ClientTimeout(total=self.handler.timeout, connect=5) 290 | 291 | try: 292 | async with aiohttp.ClientSession(timeout=timeout) as session: 293 | async with session.post(url, json=request_data, headers=default_headers) as response: 294 | if response.status != 200: 295 | error_text = await response.text() 296 | logger.error(f"Forwarded chat request failed: {response.status} {error_text}") 297 | return 298 | 299 | async for data in self.handler.stream_from_response( 300 | response, 301 | self.transformer.chat_transform 302 | ): 303 | yield data 304 | 305 | except asyncio.TimeoutError: 306 | logger.error(f"Timeout forwarding chat stream to {url}") 307 | except Exception as e: 308 | logger.error(f"Error forwarding chat stream to {url}: {e}") 309 | --------------------------------------------------------------------------------
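Usage note (not part of the repository): the k-bucket placement in dht/routing_table.py above comes down to the XOR distance between two hex node IDs, as implemented in RoutingTable._get_bucket_index. A minimal sketch of that math, with made-up node IDs used purely for illustration:

# Sketch of the XOR-distance bucket math from RoutingTable._get_bucket_index.
# The node IDs below are hypothetical hex strings, chosen only to show how
# nearby and distant IDs fall into low and high bucket indices.
def bucket_index(self_id: str, other_id: str) -> int:
    distance = int(self_id, 16) ^ int(other_id, 16)
    return 0 if distance == 0 else distance.bit_length() - 1

me = "a3f1"    # hypothetical local node ID
near = "a3f0"  # differs only in the lowest bit
far = "13f1"   # differs in a high-order bit

print(bucket_index(me, near))  # 0  -> long shared prefix, low bucket
print(bucket_index(me, far))   # 15 -> large XOR distance, high bucket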
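Similarly, a minimal, self-contained sketch of how SSEParser and OpenAISSETransformer from common/sse_handler.py fit together when consuming an OpenAI-style chat stream. It assumes the llama-net package is importable; the raw chunk is fabricated sample data, not output from a real node:

import json

from common.sse_handler import SSEParser, OpenAISSETransformer

# Feed raw SSE text through the parser, then normalize each chat delta
# into the internal {"text", "finished"} shape used when forwarding streams.
parser = SSEParser()
raw_chunk = (
    'data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}\n'
    'data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}\n'
    'data: [DONE]\n'
)

for event in parser.parse_chunk(raw_chunk):
    payload = json.loads(event.data)
    internal = OpenAISSETransformer.chat_transform(payload)
    if internal:
        print(internal)
# -> {'text': 'Hello', 'finished': False}
# -> {'text': '', 'finished': True}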