├── .dbrheo.json
├── packages
│   ├── cli
│   │   ├── src
│   │   │   ├── tests
│   │   │   │   └── __init__.py
│   │   │   └── dbrheo_cli
│   │   │       ├── app
│   │   │       │   ├── __init__.py
│   │   │       │   └── config.py
│   │   │       ├── handlers
│   │   │       │   ├── __init__.py
│   │   │       │   ├── input_handler.py
│   │   │       │   └── event_handler.py
│   │   │       ├── ui
│   │   │       │   ├── __init__.py
│   │   │       │   ├── messages.py
│   │   │       │   ├── console.py
│   │   │       │   ├── ascii_art.py
│   │   │       │   ├── startup.py
│   │   │       │   └── tools.py
│   │   │       ├── __init__.py
│   │   │       ├── utils
│   │   │       │   ├── __init__.py
│   │   │       │   └── api_key_checker.py
│   │   │       └── constants.py
│   │   ├── .dbrheo.json
│   │   ├── .gitignore
│   │   ├── cli.py
│   │   ├── pyproject.toml
│   │   └── setup_enhanced_layout.py
│   ├── web
│   │   ├── README.md
│   │   ├── src
│   │   │   ├── main.tsx
│   │   │   ├── styles
│   │   │   │   └── global.css
│   │   │   ├── components
│   │   │   │   ├── chat
│   │   │   │   │   └── ChatContainer.tsx
│   │   │   │   └── database
│   │   │   │       ├── QueryEditor.tsx
│   │   │   │       └── ResultTable.tsx
│   │   │   └── App.tsx
│   │   ├── tsconfig.node.json
│   │   ├── index.html
│   │   ├── tailwind.config.js
│   │   ├── vite.config.ts
│   │   ├── tsconfig.json
│   │   └── package.json
│   └── core
│       ├── src
│       │   └── dbrheo
│       │       ├── utils
│       │       │   ├── __init__.py
│       │       │   ├── type_converter.py
│       │       │   ├── parameter_sanitizer.py
│       │       │   ├── retry.py
│       │       │   ├── errors.py
│       │       │   └── log_integration.py
│       │       ├── config
│       │       │   ├── __init__.py
│       │       │   └── test_config.py
│       │       ├── api
│       │       │   ├── __init__.py
│       │       │   ├── routes
│       │       │   │   ├── __init__.py
│       │       │   │   ├── websocket.py
│       │       │   │   ├── chat.py
│       │       │   │   └── database.py
│       │       │   ├── dependencies.py
│       │       │   └── app.py
│       │       ├── adapters
│       │       │   ├── __init__.py
│       │       │   ├── connection_manager.py
│       │       │   └── base.py
│       │       ├── services
│       │       │   ├── __init__.py
│       │       │   └── llm_factory.py
│       │       ├── telemetry
│       │       │   ├── __init__.py
│       │       │   ├── tracer.py
│       │       │   ├── logger.py
│       │       │   └── metrics.py
│       │       ├── tools
│       │       │   ├── __init__.py
│       │       │   └── mcp
│       │       │       └── __init__.py
│       │       ├── types
│       │       │   ├── __init__.py
│       │       │   ├── core_types.py
│       │       │   ├── file_types.py
│       │       │   └── tool_types.py
│       │       ├── core
│       │       │   ├── __init__.py
│       │       │   ├── next_speaker.py
│       │       │   ├── turn.py
│       │       │   ├── compression.py
│       │       │   └── token_statistics.py
│       │       ├── prompts(old)
│       │       │   └── optimized_database_prompt.py
│       │       ├── __init__.py
│       │       ├── __main__.py
│       │       └── prompts.py
│       └── pyproject.toml
├── logs
│   └── .gitkeep
├── testdata
│   ├── Index
│   ├── old.adult.names
│   └── adult.names
├── requirements.in
├── .claude
│   └── settings.local.json
├── log_config.yaml
├── pyproject.toml
├── .gitignore
├── .env.example
└── README.md
/.dbrheo.json:
--------------------------------------------------------------------------------
1 | {
2 | "mcp_servers": {}
3 | }
--------------------------------------------------------------------------------
/packages/cli/src/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | CLI test module
3 | """
--------------------------------------------------------------------------------
/packages/web/README.md:
--------------------------------------------------------------------------------
1 | Web interface is not yet implemented. Please refer to the CLI.
2 |
--------------------------------------------------------------------------------
/packages/cli/src/dbrheo_cli/app/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Application core module
3 |
4 | Contains the CLI application's main class, configuration management, and lifecycle management.
5 | """
--------------------------------------------------------------------------------
/packages/cli/src/dbrheo_cli/handlers/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Event handler module
3 |
4 | Handles events from the DbRheo core and converts them into UI operations.
5 | """
--------------------------------------------------------------------------------
/packages/cli/src/dbrheo_cli/ui/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | UI layer module
3 |
4 | Contains all UI-related components and display logic.
5 | Consolidates the UI-related parts of the former display, components, and utils modules.
6 | """
--------------------------------------------------------------------------------
/logs/.gitkeep:
--------------------------------------------------------------------------------
1 | # This file ensures the logs directory is tracked by Git
2 | # Log files will be ignored but the directory structure is preserved
3 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """工具函数模块"""
2 |
3 | from .parameter_sanitizer import sanitize_parameters
4 |
5 | __all__ = ["sanitize_parameters"]
--------------------------------------------------------------------------------
/testdata/Index:
--------------------------------------------------------------------------------
1 | Index of adult
2 |
3 | 02 Dec 1996 140 Index
4 | 10 Aug 1996 3974305 adult.data
5 | 10 Aug 1996 4267 adult.names
6 | 10 Aug 1996 2003153 adult.test
7 |
--------------------------------------------------------------------------------
/packages/cli/src/dbrheo_cli/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | DbRheo CLI - command-line interface for the intelligent database agent
3 |
4 | A professional terminal UI built on the Rich library, providing a smooth interactive experience.
5 | """
6 |
7 | __version__ = "0.2.0"
8 | __author__ = "DbRheo Team"
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/config/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration management system - layered configuration loading and validation
3 | Supports multiple configuration sources such as environment variables and config files
4 | """
5 |
6 | from .base import DatabaseConfig
7 |
8 | __all__ = [
9 | "DatabaseConfig"
10 | ]
11 |
--------------------------------------------------------------------------------
/packages/cli/src/dbrheo_cli/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility module
3 | """
4 |
5 | from .api_key_checker import check_api_key_for_model, show_api_key_setup_guide
6 |
7 | __all__ = ['check_api_key_for_model', 'show_api_key_setup_guide']
--------------------------------------------------------------------------------
/packages/cli/.dbrheo.json:
--------------------------------------------------------------------------------
1 | {
2 | "mcp_servers": {
3 | "filesystem": {
4 | "command": "npx",
5 | "args": [
6 | "-y",
7 | "@modelcontextprotocol/server-filesystem",
8 | "/tmp"
9 | ],
10 | "trust": false,
11 | "enabled": true
12 | }
13 | }
14 | }
--------------------------------------------------------------------------------
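A minimal sketch (not part of the repository) of reading the mcp_servers block above with the standard library; the project's actual loader presumably lives in dbrheo/tools/mcp (MCPConfig), so this only illustrates the file shape.

    import json
    from pathlib import Path

    def load_mcp_servers(path: str = ".dbrheo.json") -> dict:
        """Return the 'mcp_servers' mapping, or an empty dict if absent."""
        config = json.loads(Path(path).read_text(encoding="utf-8"))
        return config.get("mcp_servers", {})

    # e.g. name == "filesystem", server["command"] == "npx"
    for name, server in load_mcp_servers("packages/cli/.dbrheo.json").items():
        print(name, server["command"], server.get("enabled", False))
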
/packages/core/src/dbrheo/api/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | API layer - FastAPI application and route definitions
3 | Provides RESTful API and WebSocket interfaces
4 | """
5 |
6 | from .app import create_app
7 | from .routes import chat_router, database_router
8 |
9 | __all__ = [
10 | "create_app",
11 | "chat_router",
12 | "database_router"
13 | ]
14 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/api/routes/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | API route module - organizes routes by feature
3 | """
4 |
5 | from .chat import chat_router
6 | from .database import database_router
7 | from .websocket import websocket_router
8 |
9 | __all__ = [
10 | "chat_router",
11 | "database_router",
12 | "websocket_router"
13 | ]
14 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/adapters/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Database adapter system - supports multiple database dialects and connection management
3 | Provides a unified database operation interface for MySQL, PostgreSQL, SQLite, and more
4 | """
5 |
6 | from .base import DatabaseAdapter
7 | from .connection_manager import DatabaseConnectionManager
8 |
9 | __all__ = [
10 | "DatabaseAdapter",
11 | "DatabaseConnectionManager"
12 | ]
13 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/services/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Service layer - handles external API calls and business services
3 | Includes the Gemini API service, file operation services, etc.
4 | """
5 |
6 | from .gemini_service_new import GeminiService
7 | from .llm_factory import create_llm_service, LLMServiceFactory
8 |
9 | __all__ = [
10 | "GeminiService",
11 | "create_llm_service",
12 | "LLMServiceFactory"
13 | ]
14 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/telemetry/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Monitoring and telemetry system - fully aligned with Gemini CLI's telemetry mechanism
3 | Provides OpenTelemetry integration, performance monitoring, error tracking, and more
4 | """
5 |
6 | from .tracer import DatabaseTracer
7 | from .metrics import DatabaseMetrics
8 | from .logger import DatabaseLogger
9 |
10 | __all__ = [
11 | "DatabaseTracer",
12 | "DatabaseMetrics",
13 | "DatabaseLogger"
14 | ]
15 |
--------------------------------------------------------------------------------
/packages/web/src/main.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * Web application entry point
3 | * Initializes the React app and global configuration - DbRheo database agent
4 | */
5 | import React from 'react'
6 | import ReactDOM from 'react-dom/client'
7 | import App from './App'
8 | import './styles/global.css'
9 |
10 | ReactDOM.createRoot(document.getElementById('root')!).render(
11 |
12 |
13 |
14 | )
15 |
--------------------------------------------------------------------------------
/packages/web/tsconfig.node.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "composite": true,
4 | "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
5 | "skipLibCheck": true,
6 | "module": "ESNext",
7 | "moduleResolution": "bundler",
8 | "allowSyntheticDefaultImports": true,
9 | "strict": true,
10 | "types": ["node"]
11 | },
12 | "include": ["vite.config.ts"]
13 | }
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/tools/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Tool system - implements core tools such as SQLTool and SchemaDiscoveryTool
3 | Follows the design principle of "keep tools minimal, keep intelligence in the Agent layer"
4 | """
5 |
6 | from .base import DatabaseTool
7 | from .registry import DatabaseToolRegistry
8 | from .sql_tool import SQLTool
9 | from .schema_discovery import SchemaDiscoveryTool
10 |
11 | __all__ = [
12 | "DatabaseTool",
13 | "DatabaseToolRegistry",
14 | "SQLTool",
15 | "SchemaDiscoveryTool"
16 | ]
17 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/types/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Type definition system - fully aligned with Gemini CLI's type system
3 | Provides core types, tool types, database types, and other definitions
4 | """
5 |
6 | from .core_types import *
7 | from .tool_types import *
8 |
9 | __all__ = [
10 | # Core types
11 | "Part",
12 | "PartListUnion",
13 | "Content",
14 | "AbortSignal",
15 |
16 | # Tool types
17 | "ToolResult",
18 | "ToolCallRequestInfo",
19 | "DatabaseConfirmationDetails",
20 | "ToolCall"
21 | ]
22 |
--------------------------------------------------------------------------------
/packages/web/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | DbRheo - Intelligent Database Agent
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/packages/web/tailwind.config.js:
--------------------------------------------------------------------------------
1 | /** @type {import('tailwindcss').Config} */
2 | export default {
3 | content: [
4 | "./index.html",
5 | "./src/**/*.{js,ts,jsx,tsx}",
6 | ],
7 | theme: {
8 | extend: {
9 | fontFamily: {
10 | sans: ['Inter', 'system-ui', 'sans-serif'],
11 | mono: ['JetBrains Mono', 'Consolas', 'monospace'],
12 | },
13 | colors: {
14 | primary: {
15 | 50: '#eff6ff',
16 | 500: '#3b82f6',
17 | 600: '#2563eb',
18 | 700: '#1d4ed8',
19 | }
20 | }
21 | },
22 | },
23 | plugins: [],
24 | }
25 |
--------------------------------------------------------------------------------
/packages/cli/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 | MANIFEST
23 |
24 | # Virtual environments
25 | venv/
26 | ENV/
27 | env/
28 | .venv
29 |
30 | # IDE
31 | .vscode/
32 | .idea/
33 | *.swp
34 | *.swo
35 | *~
36 |
37 | # Testing
38 | .pytest_cache/
39 | .coverage
40 | htmlcov/
41 | .tox/
42 | .mypy_cache/
43 | .dmypy.json
44 | dmypy.json
45 |
46 | # Logs
47 | *.log
48 |
49 | # OS
50 | .DS_Store
51 | Thumbs.db
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/core/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Core logic layer - implements the Turn system, conversation management, tool scheduling, and other core features
3 | Fully aligned with the Gemini CLI architecture
4 | """
5 |
6 | from .client import DatabaseClient
7 | from .chat import DatabaseChat
8 | from .turn import DatabaseTurn
9 | from .scheduler import DatabaseToolScheduler
10 | from .prompts import DatabasePromptManager
11 | from .next_speaker import check_next_speaker
12 | from .compression import try_compress_chat
13 |
14 | __all__ = [
15 | "DatabaseClient",
16 | "DatabaseChat",
17 | "DatabaseTurn",
18 | "DatabaseToolScheduler",
19 | "DatabasePromptManager",
20 | "check_next_speaker",
21 | "try_compress_chat"
22 | ]
23 |
--------------------------------------------------------------------------------
/packages/web/vite.config.ts:
--------------------------------------------------------------------------------
1 | import { defineConfig } from 'vite'
2 | import react from '@vitejs/plugin-react'
3 | import path from 'path'
4 |
5 | // Vite build configuration - DbRheo web interface
6 | export default defineConfig({
7 | plugins: [react()],
8 | resolve: {
9 | alias: {
10 | '@': path.resolve(__dirname, './src'),
11 | },
12 | },
13 | server: {
14 | port: 3000,
15 | proxy: {
16 | '/api': {
17 | target: 'http://localhost:8000',
18 | changeOrigin: true,
19 | },
20 | '/ws': {
21 | target: 'ws://localhost:8000',
22 | ws: true,
23 | },
24 | },
25 | },
26 | build: {
27 | outDir: 'dist',
28 | sourcemap: true,
29 | },
30 | })
31 |
--------------------------------------------------------------------------------
/packages/web/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2020",
4 | "useDefineForClassFields": true,
5 | "lib": ["ES2020", "DOM", "DOM.Iterable"],
6 | "module": "ESNext",
7 | "skipLibCheck": true,
8 |
9 | /* Bundler mode */
10 | "moduleResolution": "bundler",
11 | "allowImportingTsExtensions": true,
12 | "resolveJsonModule": true,
13 | "isolatedModules": true,
14 | "noEmit": true,
15 | "jsx": "react-jsx",
16 |
17 | /* Linting */
18 | "strict": true,
19 | "noUnusedLocals": true,
20 | "noUnusedParameters": true,
21 | "noFallthroughCasesInSwitch": true,
22 |
23 | /* Path mapping */
24 | "baseUrl": ".",
25 | "paths": {
26 | "@/*": ["./src/*"]
27 | }
28 | },
29 | "include": ["src"],
30 | "references": [{ "path": "./tsconfig.node.json" }]
31 | }
32 |
--------------------------------------------------------------------------------
/packages/web/src/styles/global.css:
--------------------------------------------------------------------------------
1 | /* Global styles - DbRheo web interface */
2 | @tailwind base;
3 | @tailwind components;
4 | @tailwind utilities;
5 |
6 | /* Base style resets and global settings */
7 | @layer base {
8 | html {
9 | font-family: 'Inter', system-ui, sans-serif;
10 | }
11 |
12 | body {
13 | @apply bg-gray-50 text-gray-900;
14 | }
15 | }
16 |
17 | /* Component styles */
18 | @layer components {
19 | .btn-primary {
20 | @apply bg-blue-600 hover:bg-blue-700 text-white font-medium py-2 px-4 rounded-md transition-colors;
21 | }
22 |
23 | .btn-secondary {
24 | @apply bg-gray-200 hover:bg-gray-300 text-gray-900 font-medium py-2 px-4 rounded-md transition-colors;
25 | }
26 |
27 | .card {
28 | @apply bg-white rounded-lg shadow-sm border border-gray-200 p-6;
29 | }
30 | }
31 |
32 | /* Utility styles */
33 | @layer utilities {
34 | .text-balance {
35 | text-wrap: balance;
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/api/dependencies.py:
--------------------------------------------------------------------------------
1 | """
2 | API dependencies - functions for dependency injection
3 | Avoids circular import problems
4 | """
5 |
6 | from typing import Dict, Any
7 | from fastapi import HTTPException
8 |
9 | from ..config.base import DatabaseConfig
10 | from ..core.client import DatabaseClient
11 |
12 | # Global application state store
13 | app_state: Dict[str, Any] = {}
14 |
15 |
16 | def get_client() -> DatabaseClient:
17 | """获取数据库客户端实例"""
18 | if "client" not in app_state:
19 | raise HTTPException(status_code=500, detail="Database client not initialized")
20 | return app_state["client"]
21 |
22 |
23 | def get_config() -> DatabaseConfig:
24 | """获取配置实例"""
25 | if "config" not in app_state:
26 | raise HTTPException(status_code=500, detail="Configuration not initialized")
27 | return app_state["config"]
28 |
29 |
30 | def set_app_state(key: str, value: Any):
31 | """设置应用状态"""
32 | app_state[key] = value
33 |
34 |
35 | def get_app_state(key: str) -> Any:
36 | """获取应用状态"""
37 | return app_state.get(key)
--------------------------------------------------------------------------------
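A minimal sketch (not part of the repository) of consuming these helpers from a route via FastAPI's Depends; the route path and function name are hypothetical.

    from fastapi import APIRouter, Depends

    from dbrheo.api.dependencies import get_client
    from dbrheo.core.client import DatabaseClient

    router = APIRouter()

    @router.get("/health")  # illustrative path, not an actual project route
    async def health(client: DatabaseClient = Depends(get_client)) -> dict:
        # get_client() raises a 500 until set_app_state("client", ...) has been
        # called during application startup.
        return {"client_ready": client is not None}
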
/packages/core/src/dbrheo/tools/mcp/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MCP (Model Context Protocol) integration for DbRheo.
3 |
4 | This module provides a flexible adapter system to integrate MCP servers
5 | and their tools into the DbRheo database agent system.
6 |
7 | Key components:
8 | - MCPConfig: Configuration management for MCP servers
9 | - MCPClientManager: Manages MCP client connections
10 | - MCPToolAdapter: Adapts MCP tools to DbRheo tool interface
11 | - MCPConverter: Handles format conversion between models
12 | """
13 |
14 | from .mcp_config import MCPConfig, MCPServerConfig
15 | from .mcp_client import MCPClientManager, MCPServerStatus, MCP_AVAILABLE
16 | from .mcp_adapter import MCPToolAdapter
17 | from .mcp_converter import MCPConverter
18 | from .mcp_registry import MCPRegistry
19 |
20 | __all__ = [
21 | 'MCPConfig',
22 | 'MCPServerConfig',
23 | 'MCPClientManager',
24 | 'MCPServerStatus',
25 | 'MCP_AVAILABLE',
26 | 'MCPToolAdapter',
27 | 'MCPConverter',
28 | 'MCPRegistry',
29 | ]
--------------------------------------------------------------------------------
/packages/cli/cli.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | DbRheo CLI quick-start script
4 |
5 | Used to launch the CLI quickly during development without installing the package.
6 | """
7 |
8 | import sys
9 | import os
10 | from pathlib import Path
11 |
12 | # Create a filtering stderr wrapper
13 | class FilteredStderr:
14 | def __init__(self, original_stderr):
15 | self.original_stderr = original_stderr
16 |
17 | def write(self, text):
18 | # Filter out a specific warning
19 | if "there are non-text parts in the response" not in text:
20 | self.original_stderr.write(text)
21 |
22 | def flush(self):
23 | self.original_stderr.flush()
24 |
25 | def __getattr__(self, name):
26 | return getattr(self.original_stderr, name)
27 |
28 | # Replace the standard error stream
29 | sys.stderr = FilteredStderr(sys.stderr)
30 |
31 | # Add the src directory to the Python path
32 | src_path = Path(__file__).parent / "src"
33 | if str(src_path) not in sys.path:
34 | sys.path.insert(0, str(src_path))
35 |
36 | # Add the core package path
37 | core_path = Path(__file__).parent.parent / "core" / "src"
38 | if str(core_path) not in sys.path:
39 | sys.path.insert(0, str(core_path))
40 |
41 | if __name__ == "__main__":
42 | from dbrheo_cli.main import main
43 | main()
--------------------------------------------------------------------------------
/packages/core/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Core package configuration - DbRheo core business logic
2 |
3 | [build-system]
4 | requires = ["setuptools>=61.0", "wheel"]
5 | build-backend = "setuptools.build_meta"
6 |
7 | [project]
8 | name = "dbrheo-core"
9 | version = "1.0.0"
10 | description = "DbRheo database agent core package"
11 | authors = [{name = "DbRheo Team", email = "team@dbrheo.com"}]
12 | license = {text = "MIT"}
13 | requires-python = ">=3.9"
14 | dependencies = [
15 | # AI and API
16 | "google-generativeai>=0.8.3",
17 | "google-auth>=2.35.0",
18 |
19 | # Database-related
20 | "sqlalchemy[asyncio]>=2.0.36",
21 | "asyncpg>=0.30.0",
22 | "aiomysql>=0.2.0",
23 | "aiosqlite>=0.20.0",
24 |
25 | # Core dependencies
26 | "pydantic>=2.10.0",
27 | "pyyaml>=6.0.2",
28 | "httpx>=0.28.0",
29 | "aiofiles>=24.1.0"
30 | ]
31 |
32 | [project.optional-dependencies]
33 | mcp = [
34 | "mcp>=1.0.0"
35 | ]
36 | dev = [
37 | "pytest>=8.3.0",
38 | "pytest-asyncio>=0.24.0",
39 | "black>=24.10.0",
40 | "mypy>=1.13.0",
41 | "ruff>=0.8.0"
42 | ]
43 |
44 | [tool.setuptools.packages.find]
45 | where = ["src"]
46 | include = ["dbrheo*"]
47 |
48 | [tool.setuptools.package-dir]
49 | "" = "src"
50 |
--------------------------------------------------------------------------------
/packages/web/src/components/chat/ChatContainer.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * Chat container component - manages conversational interaction with the database agent
3 | * Implements streaming conversation, tool-call confirmation, and other core features
4 | */
5 | import React from 'react'
6 |
7 | interface ChatContainerProps {
8 | // TODO: define prop types
9 | }
10 |
11 | export function ChatContainer(props: ChatContainerProps) {
12 | return (
13 |
14 |
15 | {/* Message list area */}
16 |
17 |
18 | Chat container component - to be implemented
19 |
20 |
21 |
22 |
23 |
24 | {/* Message input area */}
25 |
26 |
31 |
34 |
35 |
36 |
37 | )
38 | }
39 |
40 | export default ChatContainer
41 |
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | # DbRheo Project Dependencies
2 | # This file contains the high-level dependencies
3 | # Use pip-compile to generate requirements.txt with exact versions
4 |
5 | # Web Framework and Server
6 | fastapi>=0.116.0
7 | uvicorn[standard]>=0.35.0
8 |
9 | # Database Related
10 | sqlalchemy[asyncio]>=2.0.42
11 | asyncpg>=0.30.0
12 | aiomysql>=0.2.0
13 | aiosqlite>=0.21.0
14 |
15 | # AI and API
16 | google-genai>=1.0.0 # New-generation Gemini SDK with explicit cache support
17 | anthropic>=0.34.0 # Claude API support
18 | openai>=1.0.0 # OpenAI GPT API support
19 |
20 | # Core Dependencies
21 | pydantic>=2.11.7
22 | pyyaml>=6.0.2
23 | rich>=14.1.0
24 | rich-gradient>=0.1.2
25 | pygments>=2.17.2 # Code highlighting support, required by Rich's Syntax component
26 | click>=8.2.1
27 | python-dotenv>=1.0.0
28 |
29 | # Cross-platform readline support
30 | pyreadline3>=3.5.4; sys_platform == "win32"
31 | gnureadline>=8.2.10; sys_platform != "win32"
32 |
33 | # Utility Libraries
34 | aiofiles>=24.1.0
35 | httpx>=0.28.1
36 |
37 | # Network and Parsing
38 | aiohttp>=3.12.0
39 | beautifulsoup4>=4.13.4
40 | requests>=2.32.4
41 |
42 | # Monitoring and Telemetry (Optional)
43 | opentelemetry-api>=1.36.0
44 | opentelemetry-sdk>=1.36.0
45 | opentelemetry-exporter-otlp>=1.36.0
46 | opentelemetry-instrumentation-httpx>=0.57b0
47 |
48 | # MCP (Model Context Protocol) Support
49 | mcp>=1.0.0
50 |
--------------------------------------------------------------------------------
/.claude/settings.local.json:
--------------------------------------------------------------------------------
1 | {
2 | "permissions": {
3 | "allow": [
4 | "Bash(chmod:*)",
5 | "Bash(python test:*)",
6 | "Bash(find:*)",
7 | "WebFetch(domain:docs.anthropic.com)",
8 | "Bash(pip install:*)",
9 | "Bash(python3:*)",
10 | "Bash(rm:*)",
11 | "Bash(python -m mypy:*)",
12 | "Bash(mkdir:*)",
13 | "Bash(grep:*)",
14 | "Bash(pip show:*)",
15 | "WebFetch(domain:discuss.ai.google.dev)",
16 | "WebFetch(domain:ai.google.dev)",
17 | "Bash(python:*)",
18 | "WebFetch(domain:pypi.org)",
19 | "WebFetch(domain:github.com)",
20 | "Bash(cp:*)",
21 | "Bash(mv:*)",
22 | "WebFetch(domain:medium.com)",
23 | "Bash(pip3 show:*)",
24 | "WebFetch(domain:googleapis.github.io)",
25 | "Bash(/mnt/c/Users/q9951/anaconda3/python.exe test_config_priority.py)",
26 | "Bash(/mnt/c/Users/q9951/anaconda3/python.exe test_cache_with_system.py)",
27 | "Bash(/mnt/c/Users/q9951/anaconda3/Scripts/pip.exe show google-genai)",
28 | "Bash(/mnt/c/Users/q9951/anaconda3/python.exe -m dbrheo_cli.main --debug --no-color)",
29 | "Bash(/mnt/c/Users/q9951/anaconda3/python.exe src/dbrheo_cli/main.py --help)",
30 | "Bash(/mnt/c/Users/q9951/anaconda3/Scripts/pip.exe install -e packages/core)",
31 | "Bash(/mnt/c/Users/q9951/anaconda3/Scripts/pip.exe install -e packages/cli)",
32 | "WebFetch(domain:modelcontextprotocol.io)"
33 | ],
34 | "deny": []
35 | }
36 | }
--------------------------------------------------------------------------------
/packages/web/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "@dbrheo/web",
3 | "version": "1.0.0",
4 | "type": "module",
5 | "description": "DbRheo数据库Agent Web界面",
6 | "scripts": {
7 | "dev": "vite",
8 | "build": "tsc && vite build",
9 | "preview": "vite preview",
10 | "test": "vitest",
11 | "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0"
12 | },
13 | "dependencies": {
14 | "react": "^19.1.0",
15 | "react-dom": "^19.1.0",
16 | "@monaco-editor/react": "^4.6.0",
17 | "@tanstack/react-query": "^5.62.0",
18 | "zustand": "^5.0.2",
19 | "socket.io-client": "^4.8.1",
20 | "lucide-react": "^0.468.0",
21 | "@radix-ui/react-dialog": "^1.1.2",
22 | "@radix-ui/react-toast": "^1.2.2",
23 | "clsx": "^2.1.1",
24 | "tailwind-merge": "^2.5.4",
25 | "date-fns": "^4.1.0",
26 | "zod": "^3.24.1"
27 | },
28 | "devDependencies": {
29 | "@types/react": "^19.1.8",
30 | "@types/react-dom": "^19.1.6",
31 | "@typescript-eslint/eslint-plugin": "^8.30.1",
32 | "@typescript-eslint/parser": "^8.30.1",
33 | "@vitejs/plugin-react": "^4.3.4",
34 | "eslint": "^9.24.0",
35 | "eslint-plugin-react-hooks": "^5.2.0",
36 | "eslint-plugin-react-refresh": "^0.4.16",
37 | "typescript": "^5.3.3",
38 | "vite": "^6.0.3",
39 | "vitest": "^3.2.4",
40 | "tailwindcss": "^3.4.17",
41 | "autoprefixer": "^10.4.20",
42 | "postcss": "^8.5.1"
43 | },
44 | "engines": {
45 | "node": ">=20"
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/packages/cli/src/dbrheo_cli/ui/messages.py:
--------------------------------------------------------------------------------
1 | """
2 | Message display components
3 | Defines how each message type is rendered:
4 | - UserMessage: user input messages
5 | - AgentMessage: AI response messages
6 | - SystemMessage: system messages
7 | - ErrorMessage: error messages
8 |
9 | Corresponds to the various Message components in Gemini CLI.
10 | """
11 |
12 | from typing import Optional
13 | from .console import console
14 | from ..i18n import _
15 |
16 |
17 | # Message prefix definitions (for easy customization later)
18 | MESSAGE_PREFIXES = {
19 | 'user': '> ',
20 | 'agent': '',
21 | 'system': '# ',
22 | 'error': '✗ ',
23 | 'tool': ' → '
24 | }
25 |
26 |
27 | def show_user_message(message: str):
28 | """显示用户消息"""
29 | prefix = MESSAGE_PREFIXES['user']
30 | console.print(f"\n[bold]{prefix}{message}[/bold]")
31 | console.print() # add a blank line
32 |
33 |
34 | def show_agent_message(message: str, end: str = '\n'):
35 | """显示AI响应消息"""
36 | # Agent消息无前缀,直接显示
37 | console.print(message, end=end)
38 |
39 |
40 | def show_system_message(message: str):
41 | """显示系统消息"""
42 | prefix = MESSAGE_PREFIXES['system']
43 | console.print(f"[dim]{prefix}{message}[/dim]")
44 |
45 |
46 | def show_tool_call(tool_name: str):
47 | """显示工具调用提示"""
48 | console.print(f"\n[cyan][{_('tool_executing', tool_name=tool_name)}][/cyan]", end='')
49 |
50 |
51 | def show_error_message(message: str):
52 | """显示错误消息"""
53 | prefix = MESSAGE_PREFIXES['error']
54 | console.print(f"[error]{prefix}{message}[/error]")
55 |
56 |
57 | def show_tool_message(tool_name: str, message: str):
58 | """显示工具消息"""
59 | prefix = MESSAGE_PREFIXES['tool']
60 | console.print(f"[info]{prefix}[{tool_name}] {message}[/info]")
--------------------------------------------------------------------------------
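A short sketch (not part of the repository) of how these helpers compose into one console exchange.

    from dbrheo_cli.ui.messages import (
        show_user_message, show_tool_message, show_agent_message, show_error_message,
    )

    show_user_message("How many rows are in the users table?")  # "> " prefix, bold
    show_tool_message("sql_tool", "SELECT COUNT(*) FROM users")  # " → [sql_tool] ..." in info style
    show_agent_message("The users table contains 42 rows.")      # no prefix
    show_error_message("Connection lost")                        # "✗ " prefix in error style
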
/packages/web/src/components/database/QueryEditor.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * SQL query editor component - based on Monaco Editor
3 | * Provides SQL syntax highlighting, autocompletion, error checking, and more
4 | */
5 | import React from 'react'
6 |
7 | interface QueryEditorProps {
8 | value?: string
9 | onChange?: (value: string) => void
10 | readOnly?: boolean
11 | }
12 |
13 | export function QueryEditor({ value = '', onChange, readOnly = false }: QueryEditorProps) {
14 | return (
15 |
16 |
17 |
18 |
SQL Editor
19 |
20 |
23 |
26 |
27 |
28 |
29 |
30 |
31 |
40 |
41 |
42 |
43 | Monaco Editor integration - syntax highlighting and autocompletion to be implemented
44 |
45 |
46 |
47 | )
48 | }
49 |
50 | export default QueryEditor
51 |
--------------------------------------------------------------------------------
/log_config.yaml:
--------------------------------------------------------------------------------
1 | # DbRheo real-time logging configuration
2 | # Flexible control over the logging system's behavior
3 |
4 | # Log level: DEBUG, INFO, WARNING, ERROR
5 | level: INFO
6 |
7 | # Enabled event types
8 | enabled_types:
9 | - conversation # conversation records
10 | - tool_call # tool calls
11 | - tool_result # tool results
12 | - error # error information
13 | - system # system information
14 | # - network # network requests (optional)
15 | # - performance # performance metrics (optional)
16 |
17 | # Output configuration
18 | outputs:
19 | # Terminal output
20 | terminal:
21 | enabled: true
22 | color: true
23 | format: "[{time}] [{type}] {source} - {message}"
24 |
25 | # File output
26 | file:
27 | enabled: true
28 | path: "logs/dbrheo_realtime.log"
29 | max_size: 10485760 # 10MB
30 | format: "json" # json 或 text
31 |
32 | # WebSocket output (for the web interface)
33 | websocket:
34 | enabled: false
35 | url: "ws://localhost:8765"
36 |
37 | # Filters
38 | filters:
39 | # Filter by source
40 | # source_include: ["DatabaseChat", "SQLTool"]
41 | # source_exclude: ["DebugLogger"]
42 |
43 | # Filter by tool
44 | # tool_include: ["sql_tool", "schema_discovery"]
45 | # tool_exclude: ["web_search"]
46 |
47 | # Filter by content (regular expressions)
48 | # content_include: ["SELECT", "INSERT"]
49 | # content_exclude: ["DEBUG"]
50 |
51 | # Performance settings
52 | performance:
53 | # Queue size
54 | queue_size: 1000
55 |
56 | # Batch size (how many log entries to process at once)
57 | batch_size: 10
58 |
59 | # Batch interval (milliseconds)
60 | batch_interval: 100
61 |
62 | # Advanced features
63 | features:
64 | # Automatically truncate long messages
65 | truncate_messages: true
66 | max_message_length: 500
67 |
68 | # Record execution time
69 | track_execution_time: true
70 |
71 | # Record memory usage
72 | track_memory_usage: false
73 |
74 | # Error stack traces
75 | include_stack_trace: true
76 |
77 | # Sensitive-information filtering
78 | filter_sensitive:
79 | enabled: true
80 | patterns:
81 | - "password"
82 | - "api_key"
83 | - "secret"
--------------------------------------------------------------------------------
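A small sketch (not part of the repository) of loading this file with PyYAML and using enabled_types as a filter; the event dict shape is assumed for illustration.

    import yaml  # PyYAML is already listed in requirements.in

    with open("log_config.yaml", encoding="utf-8") as fh:
        log_cfg = yaml.safe_load(fh)

    enabled = set(log_cfg.get("enabled_types", []))
    max_len = log_cfg.get("features", {}).get("max_message_length", 500)

    def should_emit(event: dict) -> bool:
        # 'type' is an assumed key on the event record, mirroring enabled_types.
        return event.get("type") in enabled

    print(log_cfg["level"], sorted(enabled), max_len)
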
/packages/cli/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "dbrheo-cli"
7 | version = "0.2.0"
8 | description = "Command-line interface for the intelligent database agent"
9 | readme = "README.md"
10 | requires-python = ">=3.9"
11 | authors = [
12 | {name = "DbRheo Team", email = "contact@dbrheo.ai"}
13 | ]
14 | license = {text = "MIT"}
15 | keywords = ["database", "cli", "ai", "agent"]
16 | classifiers = [
17 | "Development Status :: 3 - Alpha",
18 | "Intended Audience :: Developers",
19 | "License :: OSI Approved :: MIT License",
20 | "Programming Language :: Python :: 3",
21 | "Programming Language :: Python :: 3.9",
22 | "Programming Language :: Python :: 3.10",
23 | "Programming Language :: Python :: 3.11",
24 | "Topic :: Database",
25 | "Topic :: Software Development :: Libraries :: Python Modules",
26 | ]
27 |
28 | dependencies = [
29 | "rich>=13.7.1",
30 | "pygments>=2.17.2",
31 | "click>=8.1.7",
32 | "rich-gradient>=0.1.2", # 渐变文字效果
33 | ]
34 |
35 | [project.optional-dependencies]
36 | dev = [
37 | "pytest>=7.4.0",
38 | "pytest-asyncio>=0.21.0",
39 | "black>=23.7.0",
40 | "mypy>=1.5.0",
41 | "ruff>=0.1.0",
42 | ]
43 | # Enhanced layout features (optional)
44 | enhanced = [
45 | "prompt-toolkit>=3.0.43",
46 | ]
47 |
48 | [project.scripts]
49 | dbrheo = "dbrheo_cli.main:main"
50 |
51 | [tool.setuptools.packages.find]
52 | where = ["src"]
53 |
54 | [tool.mypy]
55 | python_version = "3.9"
56 | warn_return_any = true
57 | warn_unused_configs = true
58 | disallow_untyped_defs = true
59 |
60 | [tool.black]
61 | line-length = 100
62 | target-version = ['py39']
63 |
64 | [tool.ruff]
65 | line-length = 100
66 | select = ["E", "F", "I", "N", "UP", "YTT", "B", "A", "C4", "T20", "SIM"]
67 | ignore = ["E501"]
68 | target-version = "py39"
--------------------------------------------------------------------------------
/packages/cli/src/dbrheo_cli/ui/console.py:
--------------------------------------------------------------------------------
1 | """
2 | Rich Console wrapper
3 | Global Console instance and output configuration management
4 | """
5 |
6 | from rich.console import Console as RichConsole
7 | from rich.theme import Theme
8 | import sys
9 | import locale
10 |
11 |
12 | # Define a minimal theme (only 5 colors)
13 | db_theme = Theme({
14 | "default": "default",
15 | "success": "green",
16 | "error": "red",
17 | "warning": "yellow",
18 | "info": "cyan"
19 | })
20 |
21 | # Detect terminal encoding intelligently
22 | def _detect_console_settings():
23 | """检测终端编码和兼容性设置"""
24 | settings = {
25 | 'theme': db_theme,
26 | 'force_terminal': True # Force terminal mode by default
27 | }
28 |
29 | try:
30 | # Special handling for the Windows console
31 | if sys.platform == 'win32':
32 | import ctypes
33 | # Get the console output code page
34 | codepage = ctypes.windll.kernel32.GetConsoleOutputCP()
35 |
36 | # Japanese systems (cp932) and other non-UTF-8 systems
37 | if codepage != 65001: # 65001 is UTF-8
38 | settings['legacy_windows'] = True
39 | # Optional: record the detected encoding
40 | import os
41 | os.environ.setdefault('DBRHEO_CONSOLE_ENCODING', f'cp{codepage}')
42 | else:
43 | # Unix/Linux: use locale
44 | encoding = locale.getpreferredencoding()
45 | if encoding and not encoding.lower().startswith('utf'):
46 | # Non-UTF-8 system; special handling may be needed
47 | import os
48 | os.environ.setdefault('DBRHEO_CONSOLE_ENCODING', encoding)
49 | except:
50 | # Fall back to default settings if detection fails
51 | pass
52 |
53 | return settings
54 |
55 | # Create the global Console instance (smart configuration)
56 | console = RichConsole(**_detect_console_settings())
57 |
58 |
59 | def set_no_color(no_color: bool):
60 | """设置是否禁用颜色"""
61 | global console
62 | if no_color:
63 | console = RichConsole(no_color=True)
64 | else:
65 | # Use the smart configuration
66 | console = RichConsole(**_detect_console_settings())
--------------------------------------------------------------------------------
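A brief sketch (not part of the repository) showing the theme styles defined above in use.

    from dbrheo_cli.ui.console import console, set_no_color

    console.print("[success]Connected to the database[/success]")   # green
    console.print("[warning]Long-running query detected[/warning]") # yellow
    console.print("[error]Schema discovery failed[/error]")         # red

    # Note: set_no_color() rebinds the module-level name, so callers that did
    # "from ... import console" earlier keep their original instance.
    set_no_color(True)
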
/packages/cli/src/dbrheo_cli/constants.py:
--------------------------------------------------------------------------------
1 | """
2 | Constant definitions
3 | Centralizes all hard-coded values for easier configuration and modification
4 | """
5 |
6 | import os
7 |
8 |
9 | # Environment variable names
10 | ENV_VARS = {
11 | 'DEBUG_LEVEL': 'DBRHEO_DEBUG_LEVEL',
12 | 'DEBUG_VERBOSITY': 'DBRHEO_DEBUG_VERBOSITY',
13 | 'ENABLE_LOG': 'DBRHEO_ENABLE_REALTIME_LOG',
14 | 'DB_FILE': 'DBRHEO_DB_FILE',
15 | 'NO_COLOR': 'DBRHEO_NO_COLOR',
16 | 'PAGE_SIZE': 'DBRHEO_PAGE_SIZE',
17 | 'SHOW_THOUGHTS': 'DBRHEO_SHOW_THOUGHTS',
18 | 'MAX_WIDTH': 'DBRHEO_MAX_WIDTH',
19 | 'MAX_HISTORY': 'DBRHEO_MAX_HISTORY',
20 | 'HISTORY_FILE': 'DBRHEO_HISTORY_FILE',
21 | 'MODEL': 'DBRHEO_MODEL' # Model selection
22 | }
23 |
24 | # Default configuration values
25 | DEFAULTS = {
26 | 'PAGE_SIZE': 50,
27 | 'MAX_WIDTH': 120,
28 | 'MAX_HISTORY': 1000,
29 | 'HISTORY_FILE': '~/.dbrheo_history',
30 | 'SESSION_ID_PREFIX': 'cli_session',
31 | 'DEBUG_LEVEL': 'ERROR', # Show errors only by default
32 | 'DEBUG_VERBOSITY': 'MINIMAL' # Minimal verbosity
33 | }
34 |
35 | # Command definitions
36 | COMMANDS = {
37 | 'EXIT': ['/exit', '/quit'],
38 | 'HELP': ['/help'],
39 | 'CLEAR': ['/clear'],
40 | 'DEBUG': ['/debug'],
41 | 'LANG': ['/lang', '/language'],
42 | 'MODEL': ['/model'], # Model switching
43 | 'TOKEN': ['/token'], # Token statistics
44 | 'DATABASE': ['/database', '/db'], # Database connection
45 | 'MCP': ['/mcp'] # MCP server management
46 | }
47 |
48 | # Confirmation keywords
49 | CONFIRMATION_WORDS = {
50 | 'CONFIRM': ['1', 'confirm', 'y', 'yes'],
51 | 'CANCEL': ['2', 'cancel', 'n', 'no'],
52 | 'CONFIRM_ALL': ['confirm all']
53 | }
54 |
55 | # System commands (cross-platform)
56 | SYSTEM_COMMANDS = {
57 | 'CLEAR': 'clear' if os.name == 'posix' else 'cls'
58 | }
59 |
60 | # Debug level range
61 | DEBUG_LEVEL_RANGE = (0, 5)
62 |
63 | # File paths
64 | PATHS = {
65 | 'SRC_ROOT': lambda: os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
66 | }
67 |
68 | # Supported model list
69 | SUPPORTED_MODELS = {
70 | 'gemini': 'Gemini 2.5 Flash',
71 | 'claude': 'Claude Sonnet 4',
72 | 'sonnet3.7': 'Claude 3.7',
73 | 'gpt': 'GPT-4.1',
74 | 'gpt-mini': 'GPT-5 Mini'
75 | }
--------------------------------------------------------------------------------
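A minimal sketch (not part of the repository) of combining these tables to resolve settings from the environment with default fallbacks.

    import os

    from dbrheo_cli.constants import COMMANDS, DEFAULTS, ENV_VARS

    page_size = int(os.environ.get(ENV_VARS['PAGE_SIZE'], DEFAULTS['PAGE_SIZE']))
    history_file = os.path.expanduser(
        os.environ.get(ENV_VARS['HISTORY_FILE'], DEFAULTS['HISTORY_FILE'])
    )

    def is_exit_command(text: str) -> bool:
        # '/exit' and '/quit' both terminate the session.
        return text.strip().lower() in COMMANDS['EXIT']

    print(page_size, history_file, is_exit_command("/quit"))
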
/packages/core/src/dbrheo/types/core_types.py:
--------------------------------------------------------------------------------
1 | """
2 | Core type definitions - fully aligned with Gemini CLI
3 | Converted to Python from Gemini CLI's TypeScript type definitions
4 | """
5 |
6 | from typing import Union, Optional, Dict, List, Any
7 | from dataclasses import dataclass
8 | from abc import ABC, abstractmethod
9 |
10 |
11 | @dataclass
12 | class Part:
13 | """
14 | 对应Gemini CLI的Part类型
15 | 支持文本、内联数据、函数调用和函数响应
16 | """
17 | text: Optional[str] = None
18 | inline_data: Optional[Dict[str, str]] = None # {mimeType: str, data: str}
19 | file_data: Optional[Dict[str, str]] = None # {mimeType: str, fileUri: str}
20 | function_call: Optional[Dict[str, Any]] = None
21 | function_response: Optional[Dict[str, Any]] = None
22 | # Gemini-specific extension fields
23 | video_metadata: Optional[Dict[str, Any]] = None # corresponds to videoMetadata
24 | thought: Optional[str] = None # corresponds to thought
25 | code_execution_result: Optional[Dict[str, Any]] = None # corresponds to codeExecutionResult
26 | executable_code: Optional[Dict[str, Any]] = None # corresponds to executableCode
27 |
28 |
29 | # Fully aligned with Gemini CLI's PartListUnion definition
30 | PartListUnion = Union[str, Part, List[Part]]
31 |
32 |
33 | @dataclass
34 | class Content:
35 | """对应Gemini API的Content类型"""
36 | role: str # 'user' | 'model' | 'function'
37 | parts: List[Part]
38 |
39 |
40 | class AbortSignal(ABC):
41 | """
42 | Abort signal interface - corresponds to JavaScript's AbortSignal
43 | Used to cancel long-running operations
44 | """
45 |
46 | @property
47 | @abstractmethod
48 | def aborted(self) -> bool:
49 | """是否已中止"""
50 | pass
51 |
52 | @abstractmethod
53 | def abort(self):
54 | """中止操作"""
55 | pass
56 |
57 |
58 | class SimpleAbortSignal(AbortSignal):
59 | """简单的中止信号实现"""
60 |
61 | def __init__(self):
62 | self._aborted = False
63 |
64 | @property
65 | def aborted(self) -> bool:
66 | return self._aborted
67 |
68 | def abort(self):
69 | self._aborted = True
70 |
71 | def reset(self):
72 | """重置中止状态"""
73 | self._aborted = False
74 |
--------------------------------------------------------------------------------
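A short sketch (not part of the repository) of constructing these types and driving the abort signal.

    from dbrheo.types.core_types import Content, Part, SimpleAbortSignal

    # A single user turn made of one text part.
    message = Content(role="user", parts=[Part(text="List all tables")])

    signal = SimpleAbortSignal()
    assert not signal.aborted
    signal.abort()   # cooperative cancellation: long-running loops poll .aborted
    assert signal.aborted
    signal.reset()   # reuse the same signal for the next operation
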
/packages/core/src/dbrheo/prompts(old)/optimized_database_prompt.py:
--------------------------------------------------------------------------------
1 | """
2 | [Not currently used] Backup of the optimized database agent prompt
3 | Note: the system currently uses the prompt in database_agent_prompt.py
4 | """
5 |
6 | DATABASE_AGENT_SYSTEM_PROMPT = """You are an intelligent database assistant with advanced SQL capabilities.
7 |
8 | ## Core Principles
9 | 1. **Proactive Understanding**: Don't just execute queries, understand the intent behind them
10 | 2. **Progressive Exploration**: Start simple, gather context, then build complex solutions
11 | 3. **Safety First**: Always assess risks before modifying data
12 | 4. **Efficiency**: Optimize for performance and minimal resource usage
13 |
14 | ## Your Capabilities
15 | - Schema discovery and understanding
16 | - Query optimization and performance analysis
17 | - Data integrity validation
18 | - Complex analytical queries
19 | - Safe data modifications with proper validation
20 |
21 | ## Workflow Guidelines
22 |
23 | ### When user asks about data:
24 | 1. First understand the schema if needed
25 | 2. Validate relationships before complex joins
26 | 3. Check data volumes to avoid performance issues
27 | 4. Build queries incrementally
28 |
29 | ### When modifying data:
30 | 1. Always preview affected rows first
31 | 2. Validate constraints and dependencies
32 | 3. Use transactions when appropriate
33 | 4. Provide clear impact assessments
34 |
35 | ### Error Handling:
36 | - Explain errors in user-friendly terms
37 | - Suggest alternative approaches
38 | - Learn from failures to improve
39 |
40 | ## Response Style
41 | - Be concise but thorough
42 | - Show your reasoning when helpful
43 | - Highlight important warnings
44 | - Suggest optimizations proactively
45 |
46 | Remember: You're not just a query executor, you're an intelligent database advisor."""
47 |
48 | # Prompt templates for specific scenarios
49 | QUERY_OPTIMIZATION_PROMPT = """
50 | Analyze this query for performance:
51 | {query}
52 |
53 | Consider:
54 | 1. Index usage
55 | 2. Join efficiency
56 | 3. Data volume
57 | 4. Alternative approaches
58 | """
59 |
60 | DATA_EXPLORATION_PROMPT = """
61 | User wants to explore: {topic}
62 |
63 | Steps:
64 | 1. Identify relevant tables
65 | 2. Understand relationships
66 | 3. Check data quality
67 | 4. Build insights progressively
68 | """
--------------------------------------------------------------------------------
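The templates above are plain str.format templates; a sketch (not part of the repository) of filling one in. Since the prompts(old) directory is not a valid Python package name, the template string is assumed to be copied or relocated rather than imported.

    # Copied verbatim from QUERY_OPTIMIZATION_PROMPT above; shown only to
    # illustrate that it is a plain str.format template.
    QUERY_OPTIMIZATION_PROMPT = """
    Analyze this query for performance:
    {query}

    Consider:
    1. Index usage
    2. Join efficiency
    3. Data volume
    4. Alternative approaches
    """

    prompt = QUERY_OPTIMIZATION_PROMPT.format(
        query="SELECT * FROM orders o JOIN users u ON u.id = o.user_id"
    )
    print(prompt)
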
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Project metadata and dependency configuration - database agent based on the Gemini CLI architecture
2 |
3 | [build-system]
4 | requires = ["setuptools>=61.0", "wheel"]
5 | build-backend = "setuptools.build_meta"
6 |
7 | [project]
8 | name = "dbrheo"
9 | version = "1.0.0"
10 | description = "Intelligent database agent - based on the Gemini CLI architecture"
11 | authors = [{name = "DbRheo Team", email = "team@dbrheo.com"}]
12 | license = {text = "MIT"}
13 | requires-python = ">=3.9"
14 | dependencies = [
15 | # Web framework and server (based on latest stable versions)
16 | "fastapi>=0.115.0",
17 | "uvicorn[standard]>=0.32.0",
18 | "websockets>=13.1",
19 |
20 | # Database-related (aligned with production versions)
21 | "sqlalchemy[asyncio]>=2.0.36",
22 | "asyncpg>=0.30.0",
23 | "aiomysql>=0.2.0",
24 | "aiosqlite>=0.20.0",
25 |
26 | # AI and API (aligned with the versions used by Gemini CLI)
27 | "google-generativeai>=0.8.3",
28 | "google-auth>=2.35.0",
29 | "google-auth-oauthlib>=1.2.1",
30 |
31 | # Core dependencies (production-grade versions)
32 | "pydantic>=2.10.0",
33 | "pyyaml>=6.0.2",
34 | "rich>=13.9.0",
35 | "click>=8.1.7",
36 |
37 | # Utility functions (latest stable versions)
38 | "aiofiles>=24.1.0",
39 | "httpx>=0.28.0",
40 | "python-multipart>=0.0.12",
41 |
42 | # Monitoring and telemetry (aligned with Gemini CLI's OpenTelemetry versions)
43 | "opentelemetry-api>=1.28.0", # aligned with @opentelemetry/api@1.9.0
44 | "opentelemetry-sdk>=1.28.0",
45 | "opentelemetry-exporter-otlp>=1.28.0", # aligned with @opentelemetry/exporter-*@0.52.0
46 | "opentelemetry-instrumentation-httpx>=0.49b0"
47 | ]
48 |
49 | [project.optional-dependencies]
50 | dev = [
51 | # Test framework
52 | "pytest>=8.3.0",
53 | "pytest-asyncio>=0.24.0",
54 | "pytest-cov>=6.0.0",
55 | "pytest-mock>=3.14.0",
56 |
57 | # Code quality tools
58 | "black>=24.10.0",
59 | "isort>=5.13.0",
60 | "mypy>=1.13.0",
61 | "ruff>=0.8.0",
62 |
63 | # Development tools
64 | "coverage>=7.6.0",
65 | "bandit>=1.8.0",
66 | "safety>=3.2.0"
67 | ]
68 |
69 | [tool.setuptools.packages.find]
70 | where = ["packages/core/src"]
71 | include = ["dbrheo*"]
72 |
73 | [tool.setuptools.package-dir]
74 | "" = "packages/core/src"
75 |
76 | [tool.black]
77 | line-length = 88
78 | target-version = ['py39']
79 |
80 | [tool.isort]
81 | profile = "black"
82 | line_length = 88
83 |
84 | [tool.mypy]
85 | python_version = "3.9"
86 | strict = true
87 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/utils/type_converter.py:
--------------------------------------------------------------------------------
1 | """
2 | Type conversion utilities - handle special types returned by databases
3 | Ensures all data can be serialized for the Gemini API
4 | """
5 |
6 | from decimal import Decimal
7 | from datetime import datetime, date, time
8 | from typing import Any, Dict, List, Union
9 | import json
10 |
11 |
12 | def convert_to_serializable(value: Any) -> Any:
13 | """
14 | Convert special types returned by the database into serializable basic types
15 |
16 | Supported conversions:
17 | - Decimal -> float
18 | - datetime/date/time -> ISO-format strings
19 | - bytes -> base64 string (if needed)
20 | - Nested dicts and lists are processed recursively
21 | """
22 | if value is None:
23 | return None
24 |
25 | # Convert Decimal to float
26 | if isinstance(value, Decimal):
27 | # Keep precision, but convert to float
28 | return float(value)
29 |
30 | # Convert date/time types to ISO-format strings
31 | elif isinstance(value, datetime):
32 | return value.isoformat()
33 | elif isinstance(value, date):
34 | return value.isoformat()
35 | elif isinstance(value, time):
36 | return value.isoformat()
37 |
38 | # Convert bytes (could be converted to base64 if needed)
39 | elif isinstance(value, bytes):
40 | try:
41 | # Try UTF-8 decoding
42 | return value.decode('utf-8')
43 | except UnicodeDecodeError:
44 | # If decoding fails, convert to a hex string
45 | return value.hex()
46 |
47 | # Recursively process dicts
48 | elif isinstance(value, dict):
49 | return {k: convert_to_serializable(v) for k, v in value.items()}
50 |
51 | # Recursively process lists
52 | elif isinstance(value, (list, tuple)):
53 | return [convert_to_serializable(item) for item in value]
54 |
55 | # Try to return other types as-is
56 | else:
57 | # Check whether the value is JSON-serializable
58 | try:
59 | json.dumps(value)
60 | return value
61 | except (TypeError, ValueError):
62 | # If not serializable, convert to a string
63 | return str(value)
64 |
65 |
66 | def convert_row_to_serializable(row: Dict[str, Any]) -> Dict[str, Any]:
67 | """
68 | Convert a single row of a database query result
69 | """
70 | return convert_to_serializable(row)
71 |
72 |
73 | def convert_rows_to_serializable(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
74 | """
75 | Convert multiple rows of a database query result
76 | """
77 | return [convert_row_to_serializable(row) for row in rows]
--------------------------------------------------------------------------------
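A quick sketch (not part of the repository) of the conversions these helpers perform on a typical result row.

    from datetime import datetime
    from decimal import Decimal

    from dbrheo.utils.type_converter import convert_rows_to_serializable

    rows = [{
        "id": 1,
        "price": Decimal("19.99"),                    # -> 19.99 (float)
        "created_at": datetime(2024, 1, 2, 3, 4, 5),  # -> "2024-01-02T03:04:05"
        "blob": b"\xde\xad\xbe\xef",                  # not valid UTF-8 -> "deadbeef" (hex)
    }]

    print(convert_rows_to_serializable(rows))
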
/packages/core/src/dbrheo/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | DbRheo database agent core package
3 | Exports the main API for external use - based on the Gemini CLI architecture
4 | """
5 |
6 | # Core components
7 | from .core.client import DatabaseClient
8 | from .core.chat import DatabaseChat
9 | from .core.turn import DatabaseTurn
10 | from .core.scheduler import DatabaseToolScheduler
11 | from .core.prompts import DatabasePromptManager
12 |
13 | # Tool system
14 | from .tools.sql_tool import SQLTool
15 | from .tools.schema_discovery import SchemaDiscoveryTool
16 | from .tools.registry import DatabaseToolRegistry
17 | from .tools.base import DatabaseTool
18 | from .tools.risk_evaluator import DatabaseRiskEvaluator
19 |
20 | # Adapters
21 | from .adapters.base import DatabaseAdapter
22 | from .adapters.connection_manager import DatabaseConnectionManager
23 | from .adapters.sqlite_adapter import SQLiteAdapter
24 | from .adapters.transaction_manager import DatabaseTransactionManager
25 | from .adapters.dialect_parser import SQLDialectParser
26 |
27 | # Service layer
28 | from .services.gemini_service_new import GeminiService
29 |
30 | # Monitoring and telemetry
31 | from .telemetry.tracer import DatabaseTracer
32 | from .telemetry.metrics import DatabaseMetrics
33 | from .telemetry.logger import DatabaseLogger
34 |
35 | # Configuration
36 | from .config.base import DatabaseConfig
37 |
38 | # Utility functions
39 | from .utils.retry import with_retry, RetryConfig
40 | from .utils.errors import DatabaseAgentError, ToolExecutionError
41 |
42 | # Type definitions
43 | from .types.core_types import *
44 | from .types.tool_types import *
45 |
46 | # API
47 | from .api.app import create_app
48 |
49 | __version__ = "1.0.0"
50 | __all__ = [
51 | # Core components
52 | "DatabaseClient",
53 | "DatabaseChat",
54 | "DatabaseTurn",
55 | "DatabaseToolScheduler",
56 | "DatabasePromptManager",
57 |
58 | # Tool system
59 | "SQLTool",
60 | "SchemaDiscoveryTool",
61 | "DatabaseToolRegistry",
62 | "DatabaseTool",
63 | "DatabaseRiskEvaluator",
64 |
65 | # Adapters
66 | "DatabaseAdapter",
67 | "DatabaseConnectionManager",
68 | "SQLiteAdapter",
69 | "DatabaseTransactionManager",
70 | "SQLDialectParser",
71 |
72 | # Service layer
73 | "GeminiService",
74 |
75 | # Monitoring and telemetry
76 | "DatabaseTracer",
77 | "DatabaseMetrics",
78 | "DatabaseLogger",
79 |
80 | # Configuration
81 | "DatabaseConfig",
82 |
83 | # Utility functions
84 | "with_retry",
85 | "RetryConfig",
86 | "DatabaseAgentError",
87 | "ToolExecutionError",
88 |
89 | # API
90 | "create_app"
91 | ]
92 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Git Ignore Rules - DbRheo Project
2 |
3 | # Python
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 | *.so
8 | .Python
9 | build/
10 | develop-eggs/
11 | dist/
12 | downloads/
13 | eggs/
14 | .eggs/
15 | lib/
16 | lib64/
17 | parts/
18 | sdist/
19 | var/
20 | wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
26 | # Virtual environments
27 | venv/
28 | env/
29 | ENV/
30 | .venv/
31 | .conda/
32 |
33 | # Environment variables (keep .env.example for reference)
34 | .env
35 | .env.local
36 | .env.production
37 | .env.development
38 | .env.test
39 |
40 | # IDE and Editors
41 | .vscode/
42 | .idea/
43 | *.swp
44 | *.swo
45 | *~
46 | .vim/
47 | .emacs.d/
48 | .sublime-*
49 |
50 | # AI Assistant configurations
51 | .claude/
52 | .cursor/
53 | .codeium/
54 |
55 | # Testing
56 | .coverage
57 | .pytest_cache/
58 | htmlcov/
59 | .tox/
60 | .nox/
61 | coverage.xml
62 | *.cover
63 | .hypothesis/
64 |
65 | # Logs (keep logs directory structure but ignore log files)
66 | *.log
67 | logs/*.log
68 | logs/**/*.log
69 | !logs/.gitkeep
70 |
71 | # Database files (runtime databases, not schema)
72 | *.db
73 | *.sqlite
74 | *.sqlite3
75 | mydatabase.db
76 | test.db
77 |
78 | # Node.js (for web package)
79 | node_modules/
80 | npm-debug.log*
81 | yarn-debug.log*
82 | yarn-error.log*
83 | .pnpm-debug.log*
84 | .yarn/
85 | .pnp.*
86 |
87 | # Build outputs
88 | dist/
89 | build/
90 | .next/
91 | .nuxt/
92 | out/
93 | .output/
94 |
95 | # OS specific files
96 | .DS_Store
97 | .DS_Store?
98 | ._*
99 | .Spotlight-V100
100 | .Trashes
101 | ehthumbs.db
102 | Thumbs.db
103 | Desktop.ini
104 |
105 | # Temporary files
106 | *.tmp
107 | *.temp
108 | .cache/
109 | .parcel-cache/
110 |
111 | # Development and testing data
112 | sample_data/
113 | *.xlsx
114 | *.csv
115 | !testdata/
116 | !**/sample_data.json
117 |
118 | # Personal/sensitive files
119 | session.md
120 | **/session.md
121 | dependencies_report.txt
122 |
123 | # User configuration
124 | config.yaml
125 |
126 | # DbRheo connection configurations
127 | .dbrheo/
128 | # but keep the example configuration file
129 | !.dbrheo/connections.yaml.example
130 |
131 | # SSH keys and certificates (security)
132 | *.pem
133 | *.key
134 | *.crt
135 | *.cer
136 | id_rsa*
137 | id_dsa*
138 | id_ecdsa*
139 | id_ed25519*
140 |
141 | # Database credentials and sensitive configs
142 | **/credentials.yaml
143 | **/secrets.yaml
144 | **/passwords.txt
145 |
146 | # Promotion templates (keep private for customization)
147 | promotion_template_*.md
148 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/adapters/connection_manager.py:
--------------------------------------------------------------------------------
1 | """
2 | DatabaseConnectionManager - connection pool management
3 | Manages connection pools for multiple databases, providing health checks and load balancing
4 | """
5 |
6 | from typing import Dict, Optional, Any
7 | from .base import DatabaseAdapter
8 | from ..config.base import DatabaseConfig
9 |
10 |
11 | class DatabaseConnectionManager:
12 | """
13 | Database connection pool management
14 | - Connection pools for multiple databases (the pools dict)
15 | - Connection health checks (_check_connection_health)
16 | - Load balancing and failover
17 | """
18 |
19 | def __init__(self, config: DatabaseConfig):
20 | self.config = config
21 | self.pools: Dict[str, Any] = {} # connection pools for multiple databases
22 | self.active_connections: Dict[str, DatabaseAdapter] = {}
23 |
24 | async def get_connection(self, db_name: Optional[str] = None) -> DatabaseAdapter:
25 | """
26 | Get a database connection (with connection pooling support)
27 | If the connection is unhealthy, it is automatically recreated
28 | """
29 | db_key = db_name or self.config.default_database
30 |
31 | # Check for an existing connection
32 | if db_key in self.active_connections:
33 | conn = self.active_connections[db_key]
34 | if await self._check_connection_health(conn):
35 | return conn
36 | else:
37 | # Connection is unhealthy; remove it and recreate
38 | await self._remove_connection(db_key)
39 |
40 | # Create a new connection
41 | conn = await self._create_connection(db_key)
42 | self.active_connections[db_key] = conn
43 | return conn
44 |
45 | async def _create_connection(self, db_key: str) -> DatabaseAdapter:
46 | """创建新的数据库连接"""
47 | # TODO: 根据数据库类型创建相应的适配器
48 | # 这里需要实现具体的适配器工厂逻辑
49 | connection_string = self.config.get_connection_string(db_key)
50 |
51 | # For now, return a placeholder adapter
52 | from .sqlite_adapter import SQLiteAdapter # assumes a SQLite adapter exists
53 | adapter = SQLiteAdapter(connection_string)
54 | await adapter.connect()
55 | return adapter
56 |
57 | async def _check_connection_health(self, conn: DatabaseAdapter) -> bool:
58 | """连接健康检查"""
59 | return await conn.health_check()
60 |
61 | async def _remove_connection(self, db_key: str):
62 | """移除连接"""
63 | if db_key in self.active_connections:
64 | conn = self.active_connections[db_key]
65 | try:
66 | await conn.disconnect()
67 | except:
68 | pass # ignore errors while disconnecting
69 | del self.active_connections[db_key]
70 |
71 | async def close_all_connections(self):
72 | """关闭所有连接"""
73 | for db_key in list(self.active_connections.keys()):
74 | await self._remove_connection(db_key)
75 |
--------------------------------------------------------------------------------
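A rough usage sketch (not part of the repository), assuming a configured DatabaseConfig instance is already available; note that _create_connection above still hard-codes the SQLite adapter.

    import asyncio

    from dbrheo.adapters.connection_manager import DatabaseConnectionManager

    async def run_health_check(config) -> None:
        # 'config' is assumed to be a dbrheo DatabaseConfig exposing
        # default_database and get_connection_string(); its constructor is
        # not shown in this dump.
        manager = DatabaseConnectionManager(config)
        conn = await manager.get_connection()  # creates or reuses a pooled adapter
        print("healthy:", await conn.health_check())
        await manager.close_all_connections()

    # asyncio.run(run_health_check(config))  # requires a real DatabaseConfig
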
/packages/web/src/App.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * Main Application Component - DbRheo Database Agent Web Interface
3 | * Provides chat interface, SQL editor, result display and other core features
4 | */
5 | import React from 'react'
6 |
7 | function App() {
8 | return (
9 |
10 |
11 |
12 |
13 |
14 | DbRheo - Intelligent Database Agent
15 |
16 |
17 | MVP Version - Based on Gemini CLI Architecture
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 | Welcome to DbRheo Database Agent
27 |
28 |
29 | This is the MVP version's basic interface. Core features include:
30 |
31 |
32 | - • SQLTool: Intelligent SQL execution tool
33 | - • SchemaDiscoveryTool: Database structure exploration tool
34 | - • Turn system and tool scheduling based on Gemini CLI
35 | - • Progressive database understanding and intelligent risk assessment
36 |
37 |
38 |
39 | Development Status: Currently in planning phase, basic architecture established, core features pending implementation.
40 |
41 |
42 | Recommendation: Please use the CLI interface for full functionality experience. Web interface will be enhanced in future versions.
43 |
44 |
45 |
46 |
47 |
Technology Stack
48 |
49 |
• React 19 + TypeScript
50 |
• Tailwind CSS 3.4
51 |
• Vite 6.0 + Monaco Editor
52 |
• Socket.IO + TanStack Query
53 |
54 |
55 |
56 |
57 |
58 | )
59 | }
60 |
61 | export default App
62 |
--------------------------------------------------------------------------------
/packages/cli/setup_enhanced_layout.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Setup script for the enhanced layout feature
4 | Installs dependencies and tests the functionality
5 | """
6 |
7 | import os
8 | import sys
9 | import subprocess
10 | import importlib.util
11 |
12 | def check_dependency(package_name):
13 | """检查依赖是否已安装"""
14 | spec = importlib.util.find_spec(package_name)
15 | return spec is not None
16 |
17 | def install_dependency(package_name):
18 | """安装依赖包"""
19 | try:
20 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', package_name])
21 | return True
22 | except subprocess.CalledProcessError:
23 | return False
24 |
25 | def main():
26 | print("DbRheo CLI 增强布局功能设置")
27 | print("=" * 40)
28 |
29 | # Check for prompt-toolkit
30 | if check_dependency('prompt_toolkit'):
31 | print("✓ prompt-toolkit 已安装")
32 | else:
33 | print("✗ prompt-toolkit 未安装,正在安装...")
34 | if install_dependency('prompt-toolkit>=3.0.43'):
35 | print("✓ prompt-toolkit 安装成功")
36 | else:
37 | print("✗ prompt-toolkit 安装失败")
38 | print("请手动运行: pip install prompt-toolkit>=3.0.43")
39 | sys.exit(1)
40 |
41 | # Set environment variables
42 | print("\nConfiguring enhanced layout...")
43 | os.environ['DBRHEO_ENHANCED_LAYOUT'] = 'true'
44 | print("✓ Enhanced layout mode enabled")
45 |
46 | print("\n环境变量配置:")
47 | print("DBRHEO_ENHANCED_LAYOUT=true # 启用增强布局")
48 | print("DBRHEO_INPUT_HEIGHT_MIN=3 # 输入框最小高度")
49 | print("DBRHEO_INPUT_HEIGHT_MAX=10 # 输入框最大高度")
50 | print("DBRHEO_AUTO_SCROLL=true # 自动滚动")
51 | print("DBRHEO_SHOW_SEPARATOR=true # 显示分隔线")
52 |
53 | print("\n要启用增强布局,请在运行CLI前设置环境变量:")
54 | print("Windows: set DBRHEO_ENHANCED_LAYOUT=true")
55 | print("Linux/Mac: export DBRHEO_ENHANCED_LAYOUT=true")
56 |
57 | print("\n测试增强布局...")
58 | try:
59 | # Import test
60 | from src.dbrheo_cli.ui.layout_manager import create_layout_manager, LayoutConfig
61 | from src.dbrheo_cli.app.config import CLIConfig
62 |
63 | # Create a test configuration
64 | config = CLIConfig()
65 | config.enhanced_layout = True
66 |
67 | # Test the layout manager
68 | manager = create_layout_manager(config)
69 | if manager and manager.is_available():
70 | print("✓ 增强布局管理器可用")
71 | print("✓ prompt-toolkit 集成正常")
72 | print("\n🎉 增强布局功能设置完成!")
73 | print("\n现在可以运行 CLI 并体验底部固定输入框功能")
74 | else:
75 | print("✗ 增强布局管理器不可用")
76 |
77 | # Debug information
78 | layout_config = LayoutConfig.from_env()
79 | print(f"调试: enabled={layout_config.enabled}")
80 | print(f"调试: prompt-toolkit可用={check_dependency('prompt_toolkit')}")
81 |
82 | except ImportError as e:
83 | print(f"✗ 导入错误: {e}")
84 | print("请确保在正确的目录运行此脚本")
85 |
86 | if __name__ == '__main__':
87 | main()
--------------------------------------------------------------------------------
/packages/web/src/components/database/ResultTable.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * Query result table component - displays SQL query results
3 | * Supports pagination, sorting, export, and more
4 | */
5 | import React from 'react'
6 |
7 | interface ResultTableProps {
8 | data?: any[]
9 | columns?: string[]
10 | loading?: boolean
11 | }
12 |
13 | export function ResultTable({ data = [], columns = [], loading = false }: ResultTableProps) {
14 | if (loading) {
15 | return (
16 |
22 | )
23 | }
24 |
25 | if (data.length === 0) {
26 | return (
27 |
28 |
29 | 暂无查询结果
30 |
31 |
32 | )
33 | }
34 |
35 | return (
36 |
37 |
38 |
39 |
40 | 查询结果 ({data.length} 行)
41 |
42 |
43 |
46 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 | {columns.map((column, index) => (
58 | |
62 | {column}
63 | |
64 | ))}
65 |
66 |
67 |
68 | {data.map((row, rowIndex) => (
69 |
70 | {columns.map((column, colIndex) => (
71 | |
75 | {row[column]}
76 | |
77 | ))}
78 |
79 | ))}
80 |
81 |
82 |
83 |
84 | )
85 | }
86 |
87 | export default ResultTable
88 |
--------------------------------------------------------------------------------
/packages/cli/src/dbrheo_cli/utils/api_key_checker.py:
--------------------------------------------------------------------------------
1 | """
2 | API key checker utility
3 | Verifies that the API key required by each model is configured
4 | """
5 |
6 | import os
7 | from typing import Tuple, Optional, List
8 | from ..ui.console import console
9 | from ..i18n import _
10 |
11 |
12 | def check_api_key_for_model(model: str) -> Tuple[bool, Optional[str]]:
13 | """
14 | Check whether the API key for the given model is configured
15 |
16 | Args:
17 | model: model name
18 |
19 | Returns:
20 | (is_configured, identifier of the missing key)
21 | """
22 | model_lower = model.lower()
23 |
24 | # Gemini 系列
25 | if 'gemini' in model_lower:
26 | if os.environ.get('GOOGLE_API_KEY') or os.environ.get('GEMINI_API_KEY'):
27 | return True, None
28 | return False, 'api_key_gemini'
29 |
30 | # Claude 系列
31 | elif any(name in model_lower for name in ['claude', 'sonnet', 'opus']):
32 | if os.environ.get('ANTHROPIC_API_KEY') or os.environ.get('CLAUDE_API_KEY'):
33 | return True, None
34 | return False, 'api_key_claude'
35 |
36 | # OpenAI 系列
37 | elif any(name in model_lower for name in ['gpt', 'openai', 'o1', 'o3', 'o4']):
38 | if os.environ.get('OPENAI_API_KEY'):
39 | return True, None
40 | return False, 'api_key_openai'
41 |
42 | # 未知模型,假设不需要 API Key
43 | return True, None
44 |
45 |
46 | def show_api_key_setup_guide(model: str):
47 | """
48 | Display the API key setup guide for a model
49 |
50 | Args:
51 | model: model name
52 | """
53 | has_key, key_type = check_api_key_for_model(model)
54 |
55 | if not has_key and key_type:
56 | console.print(f"\n[yellow]{_('api_key_missing', model=model)}[/yellow]")
57 | console.print(f"\n{_('api_key_setup')}")
58 | console.print(f" [cyan]{_(key_type)}[/cyan]")
59 |
60 | console.print(f"\n{_('api_key_instructions')}")
61 |
62 | # 根据模型类型显示对应的 URL
63 | model_lower = model.lower()
64 | if 'gemini' in model_lower:
65 | console.print(f" [blue]{_('api_key_gemini_url')}[/blue]")
66 | elif any(name in model_lower for name in ['claude', 'sonnet', 'opus']):
67 | console.print(f" [blue]{_('api_key_claude_url')}[/blue]")
68 | elif any(name in model_lower for name in ['gpt', 'openai', 'o1', 'o3', 'o4']):
69 | console.print(f" [blue]{_('api_key_openai_url')}[/blue]")
70 |
71 | console.print(f"\n[dim]{_('api_key_reminder')}[/dim]\n")
72 | return True
73 |
74 | return False
75 |
76 |
77 | def check_all_api_keys() -> List[str]:
78 | """
79 | Check the API key configuration for all commonly used models
80 |
81 | Returns:
82 | List of models whose API key is not configured
83 | """
84 | missing_models = []
85 |
86 | # 检查主要模型
87 | models_to_check = [
88 | ('gemini', 'Gemini'),
89 | ('claude', 'Claude'),
90 | ('gpt', 'OpenAI GPT')
91 | ]
92 |
93 | for model_key, model_name in models_to_check:
94 | has_key, _missing = check_api_key_for_model(model_key) # renamed to avoid shadowing the i18n _() helper
95 | if not has_key:
96 | missing_models.append(model_name)
97 |
98 | return missing_models
--------------------------------------------------------------------------------
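
A minimal usage sketch for the checker above; it assumes the dbrheo_cli package is importable, and the model names and placeholder key are purely illustrative:

    import os
    from dbrheo_cli.utils.api_key_checker import check_api_key_for_model

    os.environ.setdefault("OPENAI_API_KEY", "sk-demo")   # placeholder key for the example
    print(check_api_key_for_model("gpt-4o"))             # -> (True, None) once a key is present
    print(check_api_key_for_model("claude-sonnet"))      # -> (False, 'api_key_claude') if no Anthropic key is set
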
/packages/core/src/dbrheo/core/next_speaker.py:
--------------------------------------------------------------------------------
1 | """
2 | next_speaker decision logic - closely follows Gemini CLI's checkNextSpeaker
3 | The AI decides autonomously whether to continue executing or wait for user input
4 | """
5 |
6 | from typing import Optional, Dict, Any
7 | from ..types.core_types import AbortSignal
8 | from .chat import DatabaseChat
9 | from .prompts import DatabasePromptManager
10 |
11 |
12 | # JSON Schema定义
13 | NEXT_SPEAKER_SCHEMA = {
14 | "type": "object",
15 | "properties": {
16 | "next_speaker": {
17 | "type": "string",
18 | "enum": ["user", "model"],
19 | "description": "Who should speak next"
20 | },
21 | "reasoning": {
22 | "type": "string",
23 | "description": "Explanation for the decision"
24 | }
25 | },
26 | "required": ["next_speaker", "reasoning"]
27 | }
28 |
29 |
30 | async def check_next_speaker(
31 | chat: DatabaseChat,
32 | client: 'DatabaseClient',
33 | signal: AbortSignal
34 | ) -> Optional[Dict[str, Any]]:
35 | """
36 | The AI decides the next step autonomously - identical to Gemini CLI's checkNextSpeaker
37 |
38 | Decision rules (in priority order):
39 | 1. Special cases handled first:
40 | - last message is a tool result → the model continues to process the result
41 | - last message is an empty model message → the model continues to finish its response
42 | 2. AI judgement (asked via a temporary prompt):
43 | - Model continues: it clearly stated its next action
44 | - User answers: the model asked the user a question that needs an answer
45 | - User input: the current task is finished, wait for new instructions
46 | """
47 | # 调试
48 | from ..utils.debug_logger import log_info
49 | log_info("NextSpeaker", f"🤔 CHECK_NEXT_SPEAKER called")
50 |
51 | # 1. 特殊情况优先处理(与Gemini CLI逻辑一致)
52 | curated_history = chat.get_history(True)
53 | if not curated_history:
54 | return None
55 |
56 | last_message = curated_history[-1]
57 |
58 | # 工具刚执行完,AI应该继续处理结果
59 | if last_message.get('role') == 'function':
60 | return {
61 | 'next_speaker': 'model',
62 | 'reasoning': 'Function response received, model should process the result'
63 | }
64 |
65 | # 空的model消息,应该继续完成响应
66 | if (last_message.get('role') == 'model' and
67 | not any(part.get('text', '').strip() for part in last_message.get('parts', []))):
68 | return {
69 | 'next_speaker': 'model',
70 | 'reasoning': 'Empty model response, should continue'
71 | }
72 |
73 | # 2. AI智能判断(临时提示词,不保存到历史)
74 | prompt_manager = DatabasePromptManager()
75 | check_prompt = prompt_manager.get_next_speaker_prompt()
76 |
77 | # 构建临时内容(与Gemini CLI方式一致)
78 | contents = [
79 | *curated_history,
80 | {'role': 'user', 'parts': [{'text': check_prompt}]}
81 | ]
82 |
83 | # 3. 调用LLM判断(使用相同的模型和配置)
84 | try:
85 | response = await client.generate_json(
86 | contents,
87 | NEXT_SPEAKER_SCHEMA,
88 | signal,
89 | # 使用临时系统指令覆盖(不影响主对话)
90 | system_instruction="" # 清空系统指令,专注判断任务
91 | )
92 | return response
93 | except Exception as e:
94 | # 判断失败时的默认行为
95 | return {
96 | 'next_speaker': 'user',
97 | 'reasoning': f'Failed to determine next speaker: {str(e)}'
98 | }
99 |
--------------------------------------------------------------------------------
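
For reference, a decision that satisfies NEXT_SPEAKER_SCHEMA is just a two-field dict like the one below; the check uses the third-party jsonschema package purely for illustration and is not part of this module:

    from jsonschema import validate  # third-party validator, illustration only

    decision = {
        "next_speaker": "model",
        "reasoning": "The model announced it will run the schema discovery tool next.",
    }
    validate(instance=decision, schema={
        "type": "object",
        "properties": {
            "next_speaker": {"type": "string", "enum": ["user", "model"]},
            "reasoning": {"type": "string"},
        },
        "required": ["next_speaker", "reasoning"],
    })  # raises jsonschema.ValidationError if the dict does not match
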
/packages/core/src/dbrheo/types/file_types.py:
--------------------------------------------------------------------------------
1 | """
2 | Type definitions for file operations
3 | """
4 |
5 | from dataclasses import dataclass
6 | from typing import Optional, List, Dict, Any
7 | from enum import Enum
8 | from .tool_types import DatabaseConfirmationDetails
9 |
10 |
11 | class FileFormat(Enum):
12 | """支持的文件格式"""
13 | CSV = "csv"
14 | JSON = "json"
15 | EXCEL = "excel"
16 | SQL = "sql"
17 | MARKDOWN = "markdown"
18 | TEXT = "text"
19 | PARQUET = "parquet"
20 | YAML = "yaml"
21 | XML = "xml"
22 |
23 |
24 | class ApprovalMode(Enum):
25 | """文件操作的审批模式"""
26 | MANUAL = "manual" # 每次都需要确认
27 | AUTO_READ = "auto_read" # 自动允许读取
28 | AUTO_WRITE = "auto_write" # 自动允许写入(危险)
29 | AUTO_ALL = "auto_all" # 自动允许所有操作
30 |
31 |
32 | @dataclass
33 | class FileWriteConfirmationDetails(DatabaseConfirmationDetails):
34 | """文件写入确认详情"""
35 | type: str = "file_write" # 确认类型
36 | title: str = ""
37 | file_path: str = ""
38 | file_diff: Optional[str] = None # diff格式的内容差异
39 | content_preview: Optional[str] = None # 内容预览
40 | estimated_size: Optional[str] = None # 预估文件大小
41 | format: Optional[FileFormat] = None # 文件格式
42 |
43 | # 数据库相关的元信息
44 | data_source_sql: Optional[str] = None # 数据来源SQL
45 | affected_tables: Optional[List[str]] = None # 相关数据表
46 | row_count: Optional[int] = None # 导出行数
47 |
48 | # 操作选项
49 | allow_overwrite: bool = True # 是否允许覆盖
50 | append_mode: bool = False # 是否为追加模式
51 |
52 |
53 | @dataclass
54 | class FileOperationResult:
55 | """文件操作结果的详细信息"""
56 | success: bool
57 | file_path: str
58 | operation: str # read, write, append, delete
59 |
60 | # 统计信息
61 | bytes_processed: Optional[int] = None
62 | lines_processed: Optional[int] = None
63 | duration_ms: Optional[float] = None
64 |
65 | # 数据相关
66 | format: Optional[FileFormat] = None
67 | encoding: Optional[str] = None
68 | compression: Optional[str] = None # gzip, bz2, etc
69 |
70 | # 错误信息
71 | error: Optional[str] = None
72 | error_details: Optional[Dict[str, Any]] = None
73 |
74 |
75 | @dataclass
76 | class StreamingConfig:
77 | """流式处理配置"""
78 | chunk_size: int = 10000 # 每批处理的行数
79 | memory_limit_mb: int = 100 # 内存限制(MB)
80 | progress_interval: int = 1000 # 进度更新间隔(行数)
81 | enable_compression: bool = True # 是否启用压缩
82 |
83 |
84 | @dataclass
85 | class FileAnalysisResult:
86 | """文件分析结果"""
87 | file_path: str
88 | file_size: int
89 | line_count: Optional[int] = None
90 |
91 | # 格式信息
92 | detected_format: Optional[FileFormat] = None
93 | detected_encoding: Optional[str] = None
94 | has_header: Optional[bool] = None
95 |
96 | # CSV/表格特定
97 | column_count: Optional[int] = None
98 | column_names: Optional[List[str]] = None
99 | data_types: Optional[Dict[str, str]] = None
100 | null_counts: Optional[Dict[str, int]] = None
101 |
102 | # 内容摘要
103 | preview_lines: Optional[List[str]] = None
104 | sample_rows: Optional[List[Dict[str, Any]]] = None
105 |
--------------------------------------------------------------------------------
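
An illustrative construction of the confirmation type defined above; every value here is invented for the example:

    from dbrheo.types.file_types import FileFormat, FileWriteConfirmationDetails

    details = FileWriteConfirmationDetails(
        title="Export users to CSV",
        file_path="exports/users.csv",
        format=FileFormat.CSV,
        data_source_sql="SELECT id, name FROM users",
        affected_tables=["users"],
        row_count=1250,
        estimated_size="85 KB",
        allow_overwrite=False,
    )
    print(details.type, details.file_path, details.format.value)  # file_write exports/users.csv csv
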
/packages/core/src/dbrheo/utils/parameter_sanitizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Parameter sanitizer - removes fields from tool parameter schemas that the Google AI SDK does not support
3 | Modeled on Gemini CLI's sanitizeParameters implementation
4 | """
5 | import copy
6 | from typing import Dict, Any, Set
7 |
8 |
9 | def sanitize_parameters(schema: Dict[str, Any]) -> Dict[str, Any]:
10 | """
11 | Sanitize a parameter schema by removing fields the Google AI SDK does not support
12 |
13 | Args:
14 | schema: original parameter schema
15 |
16 | Returns:
17 | sanitized copy of the parameter schema
18 | """
19 | if not schema:
20 | return schema
21 |
22 | # Deep copy so that nested structures of the original schema are not mutated
23 | cleaned_schema = copy.deepcopy(schema)
24 |
25 | # 使用集合追踪已访问的对象,防止循环引用
26 | visited = set()
27 |
28 | _sanitize_parameters_recursive(cleaned_schema, visited)
29 |
30 | return cleaned_schema
31 |
32 |
33 | def _sanitize_parameters_recursive(schema: Dict[str, Any], visited: Set[int]):
34 | """
35 | 递归清理参数模式
36 |
37 | Args:
38 | schema: 当前处理的模式
39 | visited: 已访问对象的集合
40 | """
41 | # 防止循环引用
42 | schema_id = id(schema)
43 | if schema_id in visited:
44 | return
45 | visited.add(schema_id)
46 |
47 | # Remove unsupported fields
48 | unsupported_fields = [
49 | 'default', # Protocol message Schema has no "default" field
50 | 'minimum', # Protocol message Schema has no "minimum" field
51 | 'maximum', # Protocol message Schema has no "maximum" field
52 | 'minLength', # may be unsupported
53 | 'maxLength', # may be unsupported
54 | 'minItems', # may be unsupported
55 | 'maxItems', # may be unsupported
56 | 'uniqueItems', # may be unsupported
57 | 'additionalProperties', # may be unsupported
58 | '$schema', # JSON Schema metadata
59 | '$ref', # JSON Schema reference
60 | '$defs', # JSON Schema definitions
61 | ]
62 |
63 | for field in unsupported_fields:
64 | if field in schema:
65 | del schema[field]
66 |
67 | # 处理 format 字段 - 只保留 'enum' 和 'date-time'
68 | if schema.get('type') == 'string' and 'format' in schema:
69 | if schema['format'] not in ['enum', 'date-time']:
70 | del schema['format']
71 |
72 | # 递归处理 properties
73 | if 'properties' in schema and isinstance(schema['properties'], dict):
74 | for prop_name, prop_schema in schema['properties'].items():
75 | if isinstance(prop_schema, dict):
76 | _sanitize_parameters_recursive(prop_schema, visited)
77 |
78 | # 递归处理 items(数组类型)
79 | if 'items' in schema and isinstance(schema['items'], dict):
80 | _sanitize_parameters_recursive(schema['items'], visited)
81 |
82 | # 递归处理 anyOf
83 | if 'anyOf' in schema and isinstance(schema['anyOf'], list):
84 | for item in schema['anyOf']:
85 | if isinstance(item, dict):
86 | _sanitize_parameters_recursive(item, visited)
87 |
88 | # 递归处理 oneOf
89 | if 'oneOf' in schema and isinstance(schema['oneOf'], list):
90 | for item in schema['oneOf']:
91 | if isinstance(item, dict):
92 | _sanitize_parameters_recursive(item, visited)
93 |
94 | # 递归处理 allOf
95 | if 'allOf' in schema and isinstance(schema['allOf'], list):
96 | for item in schema['allOf']:
97 | if isinstance(item, dict):
98 | _sanitize_parameters_recursive(item, visited)
--------------------------------------------------------------------------------
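
A minimal sketch of the sanitizer in use (assumes the dbrheo package is importable; the schema is invented for the example):

    from dbrheo.utils.parameter_sanitizer import sanitize_parameters

    raw = {
        "type": "object",
        "properties": {
            "limit": {"type": "integer", "minimum": 1, "default": 10},
            "name": {"type": "string", "format": "uuid"},
        },
        "additionalProperties": False,
    }
    cleaned = sanitize_parameters(raw)
    assert "default" not in cleaned["properties"]["limit"]   # unsupported field removed
    assert "format" not in cleaned["properties"]["name"]     # only 'enum' and 'date-time' formats survive
    assert "additionalProperties" not in cleaned
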
/packages/core/src/dbrheo/api/app.py:
--------------------------------------------------------------------------------
1 | """
2 | FastAPI application creation and configuration
3 | Provides the full Web API service
4 | """
5 |
6 | from fastapi import FastAPI, HTTPException
7 | from fastapi.middleware.cors import CORSMiddleware
8 | from fastapi.staticfiles import StaticFiles
9 | from contextlib import asynccontextmanager
10 | import logging
11 | import os
12 | from pathlib import Path
13 |
14 | from ..config.base import DatabaseConfig
15 | from ..core.client import DatabaseClient
16 | from .dependencies import set_app_state
17 |
18 |
19 | @asynccontextmanager
20 | async def lifespan(app: FastAPI):
21 | """应用生命周期管理"""
22 | # 确保环境变量已加载
23 | if not os.getenv("GOOGLE_API_KEY"):
24 | # 尝试加载.env文件
25 | env_paths = [
26 | Path.cwd() / '.env',
27 | Path(__file__).parent.parent.parent.parent.parent / '.env',
28 | ]
29 |
30 | for env_path in env_paths:
31 | if env_path.exists():
32 | logging.info(f"Loading environment from: {env_path}")
33 | with open(env_path, 'r', encoding='utf-8') as f:
34 | for line in f:
35 | line = line.strip()
36 | if line and not line.startswith('#') and '=' in line:
37 | key, value = line.split('=', 1)
38 | key = key.strip()
39 | value = value.strip()
40 | if key not in os.environ:
41 | os.environ[key] = value
42 | break
43 |
44 | # 启动时初始化
45 | config = DatabaseConfig()
46 | client = DatabaseClient(config)
47 |
48 | set_app_state("config", config)
49 | set_app_state("client", client)
50 |
51 | logging.info("DbRheo API server started")
52 | logging.info(f"GOOGLE_API_KEY configured: {'Yes' if os.getenv('GOOGLE_API_KEY') else 'No'}")
53 |
54 | yield
55 |
56 | # Cleanup on shutdown
57 | # TODO: implement client cleanup logic
58 |
59 |
60 | logging.info("DbRheo API server stopped")
61 |
62 |
63 | def create_app() -> FastAPI:
64 | """创建FastAPI应用"""
65 |
66 | app = FastAPI(
67 | title="DbRheo API",
68 | description="智能数据库Agent API - 基于Gemini CLI架构",
69 | version="1.0.0",
70 | lifespan=lifespan
71 | )
72 |
73 | # CORS中间件
74 | app.add_middleware(
75 | CORSMiddleware,
76 | allow_origins=["http://localhost:3000"], # Web界面地址
77 | allow_credentials=True,
78 | allow_methods=["*"],
79 | allow_headers=["*"],
80 | )
81 |
82 | # 延迟导入路由以避免循环导入
83 | from .routes.chat import chat_router
84 | from .routes.database import database_router
85 | from .routes.websocket import websocket_router
86 |
87 | # 注册路由
88 | app.include_router(chat_router, prefix="/api/chat", tags=["chat"])
89 | app.include_router(database_router, prefix="/api/database", tags=["database"])
90 | app.include_router(websocket_router, prefix="/ws", tags=["websocket"])
91 |
92 | # 健康检查
93 | @app.get("/health")
94 | async def health_check():
95 | return {"status": "healthy", "service": "DbRheo API"}
96 |
97 | # 根路径
98 | @app.get("/")
99 | async def root():
100 | return {
101 | "message": "DbRheo - 智能数据库Agent API",
102 | "version": "1.0.0",
103 | "docs": "/docs"
104 | }
105 |
106 | return app
107 |
108 |
109 | # 创建应用实例供导入使用
110 | app = create_app()
111 |
--------------------------------------------------------------------------------
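
A quick sanity check of the wiring described above; create_app() only builds the application object, so this sketch (assuming the dbrheo package and its dependencies are installed) does not need a running server or an API key:

    from dbrheo.api.app import create_app

    app = create_app()
    paths = {route.path for route in app.routes}
    print("/health" in paths, any(p.startswith("/api/chat") for p in paths))  # True True
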
/packages/core/src/dbrheo/utils/retry.py:
--------------------------------------------------------------------------------
1 | """
2 | Retry mechanism - configurable retry logic
3 | Used for operations that may fail, such as network requests and database connections
4 | """
5 |
6 | import asyncio
7 | import logging
8 | from typing import TypeVar, Callable, Any, Optional, List, Type
9 | from dataclasses import dataclass
10 | from functools import wraps
11 |
12 | T = TypeVar('T')
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | @dataclass
18 | class RetryConfig:
19 | """重试配置"""
20 | max_attempts: int = 3
21 | base_delay: float = 1.0
22 | max_delay: float = 60.0
23 | exponential_base: float = 2.0
24 | jitter: bool = True
25 | retryable_exceptions: Optional[List[Type[Exception]]] = None
26 |
27 |
28 | async def with_retry(
29 | func: Callable[..., T],
30 | config: RetryConfig,
31 | *args,
32 | **kwargs
33 | ) -> T:
34 | """
35 | Async retry helper (used by the retry decorator below)
36 | Supports exponential backoff, jitter, and exception filtering
37 | """
38 | last_exception = None
39 |
40 | for attempt in range(config.max_attempts):
41 | try:
42 | if asyncio.iscoroutinefunction(func):
43 | return await func(*args, **kwargs)
44 | else:
45 | return func(*args, **kwargs)
46 |
47 | except Exception as e:
48 | last_exception = e
49 |
50 | # 检查是否为可重试的异常
51 | if config.retryable_exceptions:
52 | if not any(isinstance(e, exc_type) for exc_type in config.retryable_exceptions):
53 | logger.warning(f"Non-retryable exception: {e}")
54 | raise e
55 |
56 | # 最后一次尝试,不再重试
57 | if attempt == config.max_attempts - 1:
58 | break
59 |
60 | # 计算延迟时间
61 | delay = _calculate_delay(attempt, config)
62 |
63 | logger.warning(
64 | f"Attempt {attempt + 1}/{config.max_attempts} failed: {e}. "
65 | f"Retrying in {delay:.2f} seconds..."
66 | )
67 |
68 | await asyncio.sleep(delay)
69 |
70 | # 所有重试都失败了
71 | logger.error(f"All {config.max_attempts} attempts failed. Last error: {last_exception}")
72 | raise last_exception
73 |
74 |
75 | def retry(config: RetryConfig):
76 | """重试装饰器"""
77 | def decorator(func: Callable[..., T]) -> Callable[..., T]:
78 | @wraps(func)
79 | async def wrapper(*args, **kwargs) -> T:
80 | return await with_retry(func, config, *args, **kwargs)
81 | return wrapper
82 | return decorator
83 |
84 |
85 | def _calculate_delay(attempt: int, config: RetryConfig) -> float:
86 | """计算延迟时间(指数退避 + 抖动)"""
87 | import random
88 |
89 | # 指数退避
90 | delay = config.base_delay * (config.exponential_base ** attempt)
91 |
92 | # 限制最大延迟
93 | delay = min(delay, config.max_delay)
94 |
95 | # 添加抖动
96 | if config.jitter:
97 | delay = delay * (0.5 + random.random() * 0.5)
98 |
99 | return delay
100 |
101 |
102 | # 预定义的重试配置
103 | DEFAULT_RETRY_CONFIG = RetryConfig(
104 | max_attempts=3,
105 | base_delay=1.0,
106 | max_delay=30.0
107 | )
108 |
109 | NETWORK_RETRY_CONFIG = RetryConfig(
110 | max_attempts=5,
111 | base_delay=0.5,
112 | max_delay=60.0,
113 | retryable_exceptions=[
114 | ConnectionError,
115 | TimeoutError,
116 | OSError
117 | ]
118 | )
119 |
120 | DATABASE_RETRY_CONFIG = RetryConfig(
121 | max_attempts=3,
122 | base_delay=2.0,
123 | max_delay=30.0,
124 | retryable_exceptions=[
125 | ConnectionError,
126 | TimeoutError
127 | ]
128 | )
129 |
--------------------------------------------------------------------------------
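
A short usage sketch for the helpers above (assumes the dbrheo package is importable; the flaky fetch function is illustrative):

    import asyncio
    from dbrheo.utils.retry import retry, NETWORK_RETRY_CONFIG

    @retry(NETWORK_RETRY_CONFIG)
    async def fetch_once(url: str) -> str:
        # illustrative stand-in for a real network call that always fails
        raise ConnectionError(f"cannot reach {url}")

    async def main():
        try:
            await fetch_once("https://example.invalid")
        except ConnectionError:
            print("gave up after", NETWORK_RETRY_CONFIG.max_attempts, "attempts")

    asyncio.run(main())
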
/packages/cli/src/dbrheo_cli/ui/ascii_art.py:
--------------------------------------------------------------------------------
1 | """
2 | DbRheo ASCII art logos
3 | Provides a responsive startup banner
4 | """
5 |
6 | # DbRheo 短版本 (适合窄终端)
7 | SHORT_LOGO = r"""
8 | ____ _ ____ _
9 | | _ \| |__ | _ \| |__ ___ ___
10 | | | | | '_ \| |_) | '_ \ / _ \/ _ \
11 | | |_| | |_) | _ <| | | | __/ (_) |
12 | |____/|_.__/|_| \_\_| |_|\___|\___/
13 | """
14 |
15 | # DbRheo 长版本 (适合宽终端) - 超大尺寸
16 | LONG_LOGO = r"""
17 | ██████╗ ██████╗ ██████╗ ██╗ ██╗███████╗ ██████╗ ██████╗██╗ ██╗
18 | ██╔══██╗██╔══██╗██╔══██╗██║ ██║██╔════╝██╔═══██╗ ██╔════╝██║ ██║
19 | ██║ ██║██████╔╝██████╔╝███████║█████╗ ██║ ██║ ██║ ██║ ██║
20 | ██║ ██║██╔══██╗██╔══██╗██╔══██║██╔══╝ ██║ ██║ ██║ ██║ ██║
21 | ██████╔╝██████╔╝██║ ██║██║ ██║███████╗╚██████╔╝ ╚██████╗███████╗██║
22 | ╚═════╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝ ╚═════╝ ╚═════╝╚══════╝╚═╝
23 | """
24 |
25 | # DbRheo 倾斜版本 (斜体效果)
26 | ITALIC_LOGO = r"""
27 | ____ __ ____ __ _______ __ ____
28 | / __ \/ /_ / __ \/ /_ ___ ____ / ____/ | / / / _/
29 | / / / / __ \/ /_/ / __ \/ _ \/ __ \ / / / |/ / / /
30 | / /_/ / /_/ / _, _/ / / / __/ /_/ / / /___/ /| / _/ /
31 | /_____/_.___/_/ |_/_/ /_/\___/\____/ \____/_/ |_/ /___/
32 | """
33 |
34 | # DbRheo 超大版本 (非倾斜,更粗更大)
35 | EXTRA_LARGE_LOGO = r"""
36 | ███████╗ ██████╗ ██████╗ ██╗ ██╗ ███████╗ ██████╗ ██████╗ ██╗ ██╗
37 | ██╔═══██╗██╔══██╗ ██╔══██╗ ██║ ██║ ██╔════╝ ██╔═══██╗ ██╔════╝ ██║ ██║
38 | ██║ ██║██████╔╝ ██████╔╝ ███████║ █████╗ ██║ ██║ ██║ ██║ ██║
39 | ██║ ██║██╔══██╗ ██╔══██╗ ██╔══██║ ██╔══╝ ██║ ██║ ██║ ██║ ██║
40 | ███████╔╝██████╔╝ ██║ ██║ ██║ ██║ ███████╗ ╚██████╔╝ ╚██████╗ ███████╗ ██║
41 | ╚══════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝ ╚═════╝ ╚═════╝ ╚══════╝ ╚═╝
42 | """
43 |
44 | # DbRheo 超大版本 (适合很宽的终端)
45 | EXTRA_LOGO = r"""
46 | ·▄▄▄▄ ▄▄▄▄· ▄▄▄ ▄ .▄▄▄▄ .
47 | ██▪ ██ ▐█ ▀█▪▀▄ █·██▪▐█▀▄.▀·▪
48 | ▐█· ▐█▌▐█▀▀█▄▐▀▀▄ ██▀▐█▐▀▀▪▄ ▄█▀▄
49 | ██. ██ ██▄▪▐█▐█•█▌██▌▐▀▐█▄▄▌▐█▌.▐▌
50 | ▀▀▀▀▀• ·▀▀▀▀ .▀ ▀▀▀▀ ▀ ▀▀▀ ▀█▄▀▪
51 | Database Intelligence Assistant
52 | """
53 |
54 | # 3D 风格版本 (特殊场合使用)
55 | LOGO_3D = r"""
56 | ___ __ ___ __
57 | / _ \/ / / _ \/ / ___ ___
58 | / // / _ \/ , _/ _ \/ -_) _ \
59 | /____/_.__/_/|_/_//_/\__/\___/
60 | D a t a b a s e A g e n t
61 | """
62 |
63 | def get_logo_width(logo: str) -> int:
64 | """计算 ASCII 艺术的宽度"""
65 | lines = logo.strip().split('\n')
66 | return max(len(line) for line in lines) if lines else 0
67 |
68 | def select_logo(terminal_width: int, style: str = "default") -> str:
69 | """
70 | Select a suitable logo for the terminal width and style
71 |
72 | Args:
73 | terminal_width: terminal width in columns
74 | style: style choice - "default", "italic", "extra"
75 | """
76 | # 获取各版本宽度
77 | long_width = get_logo_width(LONG_LOGO)
78 | italic_width = get_logo_width(ITALIC_LOGO)
79 | extra_width = get_logo_width(EXTRA_LARGE_LOGO)
80 | short_width = get_logo_width(SHORT_LOGO)
81 |
82 | # 根据风格和宽度选择
83 | if style == "italic":
84 | if terminal_width >= italic_width + 10:
85 | return ITALIC_LOGO
86 | else:
87 | return SHORT_LOGO
88 | elif style == "extra":
89 | if terminal_width >= extra_width + 10:
90 | return EXTRA_LARGE_LOGO
91 | else:
92 | return ITALIC_LOGO
93 | else: # default
94 | # 优先使用长版本
95 | if terminal_width >= long_width + 10:
96 | return LONG_LOGO
97 | elif terminal_width >= short_width + 10:
98 | return SHORT_LOGO
99 | else:
100 | # 如果终端太窄,返回最简单的文字
101 | return "\n DbRheo CLI \n"
--------------------------------------------------------------------------------
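
A small sketch showing how the selector above can be driven by the real terminal size (shutil is standard library; the printed result depends on your terminal width):

    import shutil
    from dbrheo_cli.ui.ascii_art import select_logo, get_logo_width, SHORT_LOGO

    width = shutil.get_terminal_size(fallback=(80, 24)).columns
    print(select_logo(width))                     # wide terminals get LONG_LOGO, narrow ones fall back
    print(get_logo_width(SHORT_LOGO), "columns")  # width of the compact variant
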
/packages/core/src/dbrheo/__main__.py:
--------------------------------------------------------------------------------
1 | """
2 | Application entry point - supports several launch modes
3 | Can be run as a module: python -m dbrheo
4 | """
5 |
6 | import asyncio
7 | import sys
8 | import os
9 | import logging
10 | from pathlib import Path
11 |
12 | # 添加当前包到Python路径
13 | sys.path.insert(0, str(Path(__file__).parent.parent))
14 |
15 | # 加载环境变量文件
16 | def load_env_file():
17 | """加载.env文件中的环境变量"""
18 | env_paths = [
19 | Path.cwd() / '.env', # 当前工作目录
20 | Path(__file__).parent.parent.parent.parent.parent / '.env', # 项目根目录
21 | ]
22 |
23 | for env_path in env_paths:
24 | if env_path.exists():
25 | print(f"Loading environment from: {env_path}")
26 | with open(env_path, 'r', encoding='utf-8') as f:
27 | for line in f:
28 | line = line.strip()
29 | if line and not line.startswith('#') and '=' in line:
30 | key, value = line.split('=', 1)
31 | key = key.strip()
32 | value = value.strip()
33 | # 只设置未设置的环境变量
34 | if key not in os.environ:
35 | os.environ[key] = value
36 | break
37 | else:
38 | print("Warning: No .env file found")
39 |
40 | # 在导入其他模块前加载环境变量
41 | load_env_file()
42 |
43 | from dbrheo.config.base import DatabaseConfig
44 | from dbrheo.api.app import create_app
45 |
46 |
47 | def setup_logging(level: str = "INFO"):
48 | """设置日志配置"""
49 | logging.basicConfig(
50 | level=getattr(logging, level.upper()),
51 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
52 | handlers=[
53 | logging.StreamHandler(sys.stdout),
54 | logging.FileHandler("dbrheo.log")
55 | ]
56 | )
57 |
58 |
59 | def main():
60 | """主函数"""
61 | import argparse
62 |
63 | parser = argparse.ArgumentParser(description="DbRheo - 智能数据库Agent")
64 | parser.add_argument("--host", default="localhost", help="服务器主机地址")
65 | parser.add_argument("--port", type=int, default=8000, help="服务器端口")
66 | parser.add_argument("--reload", action="store_true", help="开发模式(自动重载)")
67 | parser.add_argument("--log-level", default="INFO", help="日志级别")
68 |
69 | args = parser.parse_args()
70 |
71 | # 如果环境变量中设置了 DEBUG=true,自动启用 reload
72 | if not args.reload and os.getenv('DBRHEO_DEBUG', '').lower() == 'true':
73 | args.reload = True
74 | print("Debug mode detected, enabling auto-reload")
75 |
76 | # 设置日志
77 | setup_logging(args.log_level)
78 |
79 | # 启动服务器
80 | try:
81 | import uvicorn
82 |
83 | # 如果启用了 reload,必须使用字符串格式的应用路径
84 | if args.reload:
85 | uvicorn.run(
86 | "dbrheo.api.app:app", # 模块路径字符串
87 | host=args.host,
88 | port=args.port,
89 | reload=True,
90 | reload_dirs=["packages/core/src"], # 监控的目录
91 | log_level=args.log_level.lower()
92 | )
93 | else:
94 | # 不使用 reload 时,可以直接传递应用对象
95 | app = create_app()
96 | uvicorn.run(
97 | app,
98 | host=args.host,
99 | port=args.port,
100 | reload=False,
101 | log_level=args.log_level.lower()
102 | )
103 | except ImportError:
104 | print("Error: uvicorn not installed. Please install with: pip install uvicorn")
105 | sys.exit(1)
106 | except KeyboardInterrupt:
107 | print("\nShutting down DbRheo server...")
108 | except Exception as e:
109 | print(f"Error starting server: {e}")
110 | sys.exit(1)
111 |
112 |
113 | if __name__ == "__main__":
114 | main()
115 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/core/turn.py:
--------------------------------------------------------------------------------
1 | """
2 | DatabaseTurn - Turn system implementation
3 | Manages a single conversation turn, only collecting tool calls without executing them, fully aligned with Gemini CLI
4 | """
5 |
6 | from typing import List, AsyncIterator
7 | from ..types.core_types import PartListUnion, AbortSignal
8 | from ..types.tool_types import ToolCallRequestInfo
9 | from .chat import DatabaseChat
10 | from ..utils.debug_logger import DebugLogger
11 |
12 |
13 | class DatabaseTurn:
14 | """
15 | Runs a single conversation turn - collect only, never execute
16 | - manages one conversation turn
17 | - only collects tool calls, never executes them (pending_tool_calls)
18 | - streaming response handling
19 | - event generation and forwarding
20 | """
21 |
22 | def __init__(self, chat: DatabaseChat, prompt_id: str):
23 | self.chat = chat
24 | self.prompt_id = prompt_id
25 | self.pending_tool_calls: List[ToolCallRequestInfo] = [] # 只收集
26 |
27 | async def run(self, request: PartListUnion, signal: AbortSignal) -> AsyncIterator[dict]:
28 | """
29 | Run the turn - collect tool calls but do not execute them
30 | Strictly follows Gemini CLI's turn execution model
31 | """
32 | # 1. 发送请求到LLM(包含完整历史)
33 | response_stream = self.chat.send_message_stream(request, self.prompt_id)
34 |
35 | # 2. 流式处理响应,收集工具调用
36 | chunk_count = 0
37 | async for chunk in response_stream:
38 | chunk_count += 1
39 | # 使用优化的日志
40 | if DebugLogger.get_rules()["show_chunk_details"]:
41 | DebugLogger.log_turn_event("chunk_received", chunk)
42 | # 处理文本内容
43 | if chunk.get('text'):
44 | yield {'type': 'Content', 'value': chunk['text']}
45 |
46 | # 处理思维内容
47 | if chunk.get('thought'):
48 | yield {'type': 'Thought', 'value': chunk['thought']}
49 |
50 | # 处理工具调用
51 | if chunk.get('function_calls'):
52 | for call in chunk['function_calls']:
53 | # 生成调用ID(如果没有提供)- 参考 Gemini CLI
54 | import time
55 | import random
56 | call_id = call.get('id') or f"{call['name']}-{int(time.time() * 1000)}-{random.randint(1000, 9999)}"
57 |
58 | # Key point: collect only, never execute (exactly as in Gemini CLI)
59 | tool_request = ToolCallRequestInfo(
60 | call_id=call_id,
61 | name=call['name'],
62 | args=call['args'],
63 | is_client_initiated=False,
64 | prompt_id=self.prompt_id
65 | )
66 | self.pending_tool_calls.append(tool_request)
67 | yield {'type': 'ToolCallRequest', 'value': tool_request}
68 | DebugLogger.log_turn_event("tool_request", tool_request)
69 |
70 | # 处理错误
71 | if chunk.get('type') == 'error':
72 | yield {'type': 'Error', 'value': chunk.get('error', 'Unknown error')}
73 |
74 | # 处理 token 使用信息 - 新增事件类型
75 | if chunk.get('token_usage'):
76 | # 详细调试信息
77 | from ..utils.debug_logger import log_info
78 | log_info("Turn", f"🔴 TOKEN EVENT - Turn {self.prompt_id} emitting TokenUsage event:")
79 | log_info("Turn", f" - prompt_tokens: {chunk['token_usage'].get('prompt_tokens', 0)}")
80 | log_info("Turn", f" - completion_tokens: {chunk['token_usage'].get('completion_tokens', 0)}")
81 | log_info("Turn", f" - total_tokens: {chunk['token_usage'].get('total_tokens', 0)}")
82 | # 添加调试日志
83 | DebugLogger.log_turn_event("token_usage", chunk['token_usage'])
84 | yield {'type': 'TokenUsage', 'value': chunk['token_usage']}
85 |
86 | DebugLogger.log_turn_event("summary", chunk_count)
87 |
88 | # 3. Turn is finished; pending_tool_calls are left for the scheduler to handle
89 | # Never execute tools inside a Turn!
90 |
--------------------------------------------------------------------------------
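
A minimal sketch of how a caller consumes the event stream produced by DatabaseTurn.run; FakeChat is an illustrative stand-in for DatabaseChat and is not part of the package:

    import asyncio
    from dbrheo.core.turn import DatabaseTurn

    class FakeChat:
        """Yields one text chunk and one tool call, mimicking send_message_stream."""
        async def send_message_stream(self, request, prompt_id):
            yield {'text': 'Inspecting the schema first.'}
            yield {'function_calls': [{'name': 'schema_discovery', 'args': {'table': 'users'}}]}

    async def main():
        turn = DatabaseTurn(FakeChat(), prompt_id="demo-1")
        async for event in turn.run("show me the users table", signal=None):
            print(event['type'])                       # Content, then ToolCallRequest
        print(len(turn.pending_tool_calls), "tool call(s) collected, none executed")

    asyncio.run(main())
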
/packages/core/src/dbrheo/core/compression.py:
--------------------------------------------------------------------------------
1 | """
2 | History compression - closely follows Gemini CLI's tryCompressChat
3 | Automatically compresses over-long conversation history to keep tokens within the limit
4 | """
5 |
6 | from typing import List, Optional, Dict, Any
7 | from ..types.core_types import Content
8 | from .chat import DatabaseChat
9 |
10 |
11 | async def try_compress_chat(
12 | chat: DatabaseChat,
13 | prompt_id: str,
14 | force: bool = False
15 | ) -> Optional[Dict[str, Any]]:
16 | """
17 | History compression - closely follows Gemini CLI's tryCompressChat
18 |
19 | Compression strategy:
20 | 1. Count the tokens in the current history
21 | 2. If the count exceeds the threshold (70% of the model limit), trigger compression
22 | 3. Keep the most recent 30% of the history verbatim
23 | 4. Compress the older 70% of the history into a summary
24 | 5. Update the history
25 | """
26 | curated_history = chat.get_history(True) # 获取清理后的历史
27 |
28 | if not curated_history:
29 | return None
30 |
31 | # 计算token数量
32 | token_count = await _count_tokens(curated_history)
33 | if token_count is None:
34 | return None
35 |
36 | # Compression threshold: 70% of the model's token limit (same as Gemini CLI)
37 | token_limit = _get_token_limit(chat.config.get_model())
38 | compression_threshold = 0.7 * token_limit
39 |
40 | if not force and token_count < compression_threshold:
41 | return None
42 |
43 | # 确定压缩边界:保留最近30%的历史
44 | preserve_threshold = 0.3
45 | compress_before_index = _find_index_after_fraction(
46 | curated_history, 1 - preserve_threshold
47 | )
48 |
49 | # 找到下一个用户消息作为Turn边界
50 | while (compress_before_index < len(curated_history) and
51 | curated_history[compress_before_index].get('role') != 'user'):
52 | compress_before_index += 1
53 |
54 | history_to_compress = curated_history[:compress_before_index]
55 | history_to_keep = curated_history[compress_before_index:]
56 |
57 | # 执行压缩(使用专门的压缩提示词)
58 | compressed_summary = await _compress_history_segment(
59 | history_to_compress, prompt_id
60 | )
61 |
62 | # 更新历史:压缩摘要 + 保留的详细历史
63 | chat.set_history([
64 | {'role': 'user', 'parts': [{'text': compressed_summary}]},
65 | *history_to_keep
66 | ])
67 |
68 | return {
69 | 'original_token_count': token_count,
70 | 'compressed_token_count': await _count_tokens(chat.get_history(True)),
71 | 'compression_ratio': len(history_to_compress) / len(curated_history)
72 | }
73 |
74 |
75 | async def _count_tokens(history: List[Content]) -> Optional[int]:
76 | """计算历史记录的token数量"""
77 | # TODO: 实现实际的token计数逻辑
78 | # 可以使用tiktoken或调用Gemini API的计数接口
79 | total_chars = 0
80 | for content in history:
81 | for part in content.get('parts', []):
82 | if part.get('text'):
83 | total_chars += len(part['text'])
84 |
85 | # Rough estimate: about 4 characters per token
86 | return total_chars // 4
87 |
88 |
89 | def _get_token_limit(model: str) -> int:
90 | """获取模型的token限制"""
91 | model_limits = {
92 | 'gemini-1.5-pro': 2000000,
93 | 'gemini-1.5-flash': 1000000,
94 | 'gemini-1.0-pro': 30720
95 | }
96 | return model_limits.get(model, 30720)
97 |
98 |
99 | def _find_index_after_fraction(history: List[Content], fraction: float) -> int:
100 | """找到指定比例后的索引位置"""
101 | target_index = int(len(history) * fraction)
102 | return min(target_index, len(history) - 1)
103 |
104 |
105 | async def _compress_history_segment(
106 | history_segment: List[Content],
107 | prompt_id: str
108 | ) -> str:
109 | """压缩历史片段为摘要"""
110 | # TODO: 实现实际的历史压缩逻辑
111 | # 使用专门的压缩提示词调用LLM生成摘要
112 |
113 | # 临时实现:简单的文本摘要
114 | summary_parts = []
115 | for content in history_segment:
116 | role = content.get('role', 'unknown')
117 | for part in content.get('parts', []):
118 | if part.get('text'):
119 | text = part['text'][:100] + "..." if len(part['text']) > 100 else part['text']
120 | summary_parts.append(f"{role}: {text}")
121 |
122 | return f"[压缩的对话历史摘要]\n" + "\n".join(summary_parts[-10:]) # 保留最后10条
123 |
--------------------------------------------------------------------------------
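
A small worked example of the thresholds above; the numbers follow _get_token_limit and the 0.7 / 0.3 constants in this module:

    token_limit = 1_000_000                                  # gemini-1.5-flash
    compression_threshold = 0.7 * token_limit                # 700,000 tokens triggers compression
    history_len = 40                                         # messages in the curated history
    preserve_threshold = 0.3
    compress_before_index = int(history_len * (1 - preserve_threshold))  # 28
    # messages [0, 28) are summarised; messages [28, 40) are kept verbatim
    print(compression_threshold, compress_before_index)
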
/packages/core/src/dbrheo/adapters/base.py:
--------------------------------------------------------------------------------
1 | """
2 | DatabaseAdapter base class - the core database adapter interface
3 | Provides a unified database operation interface with multi-dialect support
4 | """
5 |
6 | from abc import ABC, abstractmethod
7 | from typing import Any, Dict, List, Optional, AsyncIterator, TYPE_CHECKING
8 | from ..types.core_types import AbortSignal
9 | from .dialect_parser import SQLDialectParser, DatabaseDialect
10 |
11 | if TYPE_CHECKING:
12 | from .transaction_manager import DatabaseTransactionManager
13 |
14 |
15 | class DatabaseAdapter(ABC):
16 | """
17 | Database adapter base class
18 | - unified database operation interface
19 | - connection management abstraction
20 | - dialect translation interface
21 | """
22 |
23 | def __init__(self, connection_string: str, **kwargs):
24 | self.connection_string = connection_string
25 | self.config = kwargs
26 | self.transaction_manager: Optional["DatabaseTransactionManager"] = None
27 | self.dialect_parser: Optional[SQLDialectParser] = None
28 |
29 | @abstractmethod
30 | async def connect(self) -> None:
31 | """建立数据库连接"""
32 | pass
33 |
34 | @abstractmethod
35 | async def disconnect(self) -> None:
36 | """关闭数据库连接"""
37 | pass
38 |
39 | @abstractmethod
40 | async def execute_query(
41 | self,
42 | sql: str,
43 | params: Optional[Dict[str, Any]] = None,
44 | signal: Optional[AbortSignal] = None
45 | ) -> Dict[str, Any]:
46 | """执行查询并返回结果"""
47 | pass
48 |
49 | @abstractmethod
50 | async def execute_command(
51 | self,
52 | sql: str,
53 | params: Optional[Dict[str, Any]] = None,
54 | signal: Optional[AbortSignal] = None
55 | ) -> Dict[str, Any]:
56 | """执行命令(INSERT、UPDATE、DELETE等)"""
57 | pass
58 |
59 | @abstractmethod
60 | async def get_schema_info(self, schema_name: Optional[str] = None) -> Dict[str, Any]:
61 | """获取数据库结构信息"""
62 | pass
63 |
64 | @abstractmethod
65 | async def get_table_info(self, table_name: str) -> Dict[str, Any]:
66 | """获取表结构信息"""
67 | pass
68 |
69 | @abstractmethod
70 | async def parse_sql(self, sql: str) -> Dict[str, Any]:
71 | """解析SQL语句"""
72 | pass
73 |
74 | @abstractmethod
75 | def get_dialect(self) -> str:
76 | """获取数据库方言"""
77 | pass
78 |
79 | async def apply_limit_if_needed(self, sql: str, limit: int) -> str:
80 | """
81 | Apply a LIMIT clause intelligently (when needed)
82 | Based on SQL parsing rather than hard-coded string matching
83 | Default implementation; subclasses may override for smarter handling
84 | """
85 | # 解析SQL以检查是否已经有LIMIT
86 | try:
87 | parsed_info = await self.parse_sql(sql)
88 | sql_type = parsed_info.get('sql_type', '').upper()
89 |
90 | # 只对SELECT查询应用LIMIT
91 | # 支持适配器返回的各种SELECT标识方式
92 | if sql_type not in ['SELECT', 'QUERY']:
93 | return sql
94 |
95 | # 检查是否已经包含LIMIT(通过SQL解析而非字符串匹配)
96 | if parsed_info.get('has_limit', False):
97 | # 已经有LIMIT,不重复添加
98 | return sql
99 | else:
100 | # 添加LIMIT子句,让子类决定具体的SQL语法
101 | return self._append_limit_clause(sql, limit)
102 |
103 | except Exception:
104 | # Parsing failed; fall back to a conservative string check
105 | # This should only happen in exceptional cases
106 | if 'LIMIT' not in sql.upper():
107 | return self._append_limit_clause(sql, limit)
108 | return sql
109 |
110 | def _append_limit_clause(self, sql: str, limit: int) -> str:
111 | """
112 | Default implementation for appending a LIMIT clause
113 | Subclasses may override to handle dialect-specific syntax
114 | """
115 | # Strip a trailing semicolon, if any
116 | sql = sql.rstrip().rstrip(';')
117 | return f"{sql} LIMIT {limit}"
118 |
119 | async def health_check(self) -> bool:
120 | """连接健康检查"""
121 | try:
122 | await self.execute_query("SELECT 1")
123 | return True
124 | except Exception:
125 | return False
126 |
--------------------------------------------------------------------------------
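
A bare-bones sketch of a concrete adapter built on the interface above; MemoryAdapter and its canned responses are purely illustrative, and a real adapter would talk to an actual database:

    from typing import Any, Dict
    from dbrheo.adapters.base import DatabaseAdapter

    class MemoryAdapter(DatabaseAdapter):
        """Toy adapter that only demonstrates the required surface of DatabaseAdapter."""

        async def connect(self) -> None: ...
        async def disconnect(self) -> None: ...

        async def execute_query(self, sql, params=None, signal=None) -> Dict[str, Any]:
            return {"rows": [], "sql": sql}

        async def execute_command(self, sql, params=None, signal=None) -> Dict[str, Any]:
            return {"affected_rows": 0}

        async def get_schema_info(self, schema_name=None) -> Dict[str, Any]:
            return {"tables": []}

        async def get_table_info(self, table_name: str) -> Dict[str, Any]:
            return {"columns": []}

        async def parse_sql(self, sql: str) -> Dict[str, Any]:
            # Minimal parse: report the statement type and whether a LIMIT is already present
            words = sql.upper().split()
            return {"sql_type": words[0] if words else "", "has_limit": "LIMIT" in words}

        def get_dialect(self) -> str:
            return "memory"

    # With this stub, apply_limit_if_needed("SELECT * FROM users", 100) returns
    # "SELECT * FROM users LIMIT 100", while non-SELECT statements pass through unchanged.
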
/packages/cli/src/dbrheo_cli/handlers/input_handler.py:
--------------------------------------------------------------------------------
1 | """
2 | Input handler
3 | Handles user input, including command parsing and special key handling.
4 | """
5 |
6 | import sys
7 | import asyncio
8 | from typing import Optional
9 |
10 | from ..ui.console import console
11 | from ..app.config import CLIConfig
12 | from ..i18n import _
13 |
14 | # 尝试导入增强输入组件(可选功能)
15 | try:
16 | from ..ui.simple_multiline_input import EnhancedInputHandler
17 | ENHANCED_INPUT_AVAILABLE = True
18 | except ImportError:
19 | ENHANCED_INPUT_AVAILABLE = False
20 |
21 |
22 | class InputHandler:
23 | """
24 | User input handler
25 | - reads user input
26 | - handles special keys
27 | - manages input history (via readline)
28 | """
29 |
30 | def __init__(self, config: CLIConfig):
31 | self.config = config
32 |
33 | # 初始化增强输入处理器(如果可用)
34 | self.enhanced_handler = None
35 | if ENHANCED_INPUT_AVAILABLE:
36 | try:
37 | self.enhanced_handler = EnhancedInputHandler(config, console)
38 | from dbrheo.utils.debug_logger import log_info
39 | log_info("InputHandler", "Enhanced input mode available")
40 | except Exception as e:
41 | # 初始化失败,使用传统模式
42 | self.enhanced_handler = None
43 | from dbrheo.utils.debug_logger import log_info
44 | log_info("InputHandler", f"Enhanced input initialization failed: {e}")
45 |
46 | async def get_input(self) -> str:
47 | """
48 | Asynchronously read user input
49 | Reads input in an asyncio-compatible way
50 | """
51 | # 优先使用增强输入(如果可用)
52 | if self.enhanced_handler:
53 | try:
54 | return await self.enhanced_handler.get_input()
55 | except Exception as e:
56 | # 增强输入失败,回退到传统模式
57 | from dbrheo.utils.debug_logger import log_info
58 | log_info("InputHandler", f"Enhanced input failed, falling back: {e}")
59 | self.enhanced_handler = None
60 |
61 | # 使用传统输入模式
62 | loop = asyncio.get_event_loop()
63 |
64 | try:
65 | # 在线程池中执行input
66 | user_input = await loop.run_in_executor(
67 | None,
68 | self._blocking_input
69 | )
70 | return user_input.strip()
71 | except EOFError:
72 | # Ctrl+D
73 | raise
74 | except KeyboardInterrupt:
75 | # Ctrl+C
76 | raise
77 |
78 | def _blocking_input(self) -> str:
79 | """阻塞式输入(在线程池中执行)"""
80 | try:
81 | # 添加分隔线,让输入区域更明显
82 | if not hasattr(self, '_first_input'):
83 | self._first_input = False
84 | else:
85 | console.print() # 简洁的空行分隔
86 |
87 | # 获取输入
88 | if self.config.no_color:
89 | first_line = input("> ")
90 | else:
91 | # 使用Rich的prompt功能
92 | first_line = console.input("[bold cyan]>[/bold cyan] ")
93 |
94 | # Check whether to enter multi-line mode
95 | # Both ``` and <<< are supported as multi-line input markers
96 | if first_line.strip() in ['```', '<<<']:
97 | console.print(f"[dim]{_('multiline_mode_hint')}[/dim]")
98 | lines = []
99 | while True:
100 | try:
101 | if self.config.no_color:
102 | line = input("... ")
103 | else:
104 | line = console.input("[dim]...[/dim] ")
105 |
106 | if line.strip() in ['```', '<<<']:
107 | break
108 | lines.append(line)
109 | except EOFError:
110 | break
111 | return "\n".join(lines)
112 |
113 | return first_line
114 | except KeyboardInterrupt:
115 | # 在input时按Ctrl+C
116 | raise
117 | except EOFError:
118 | # Ctrl+D
119 | raise
--------------------------------------------------------------------------------
/packages/cli/src/dbrheo_cli/app/config.py:
--------------------------------------------------------------------------------
1 | """
2 | CLI configuration management
3 | Stays flexible, avoids hard-coding, and supports runtime configuration
4 | """
5 |
6 | import os
7 | from typing import Optional, Dict, Any
8 | from dataclasses import dataclass
9 |
10 | from ..constants import ENV_VARS, DEFAULTS
11 |
12 |
13 | @dataclass
14 | class CLIConfig:
15 | """
16 | CLI-specific configuration
17 | - command-line arguments
18 | - display settings
19 | - user preferences
20 | """
21 | # 数据库配置
22 | db_file: Optional[str] = None
23 |
24 | # 显示配置
25 | no_color: bool = False
26 | page_size: int = DEFAULTS['PAGE_SIZE']
27 | max_width: int = DEFAULTS['MAX_WIDTH']
28 |
29 | # 配置文件
30 | config_file: Optional[str] = None
31 |
32 | # 历史记录
33 | history_file: str = os.path.expanduser(DEFAULTS['HISTORY_FILE'])
34 | max_history: int = DEFAULTS['MAX_HISTORY']
35 |
36 | # Debug options
37 | show_thoughts: bool = False # whether to show the AI's thinking process
38 | show_tool_details: bool = True # whether to show tool execution details
39 |
40 | # Layout option - new, but kept backwards compatible
41 | enhanced_layout: bool = False # whether to use the enhanced layout (fixed input box at the bottom)
42 |
43 | def __post_init__(self):
44 | """初始化后处理,确保配置的合理性"""
45 | # 确保历史文件目录存在
46 | history_dir = os.path.dirname(self.history_file)
47 | if history_dir and not os.path.exists(history_dir):
48 | os.makedirs(history_dir, exist_ok=True)
49 |
50 | # 从环境变量更新配置(环境变量优先级低于命令行参数)
51 | if not self.db_file and ENV_VARS['DB_FILE'] in os.environ:
52 | self.db_file = os.environ[ENV_VARS['DB_FILE']]
53 |
54 | if ENV_VARS['NO_COLOR'] in os.environ:
55 | self.no_color = os.environ[ENV_VARS['NO_COLOR']].lower() == 'true'
56 |
57 | if ENV_VARS['PAGE_SIZE'] in os.environ:
58 | try:
59 | self.page_size = int(os.environ[ENV_VARS['PAGE_SIZE']])
60 | except ValueError:
61 | pass
62 |
63 | if ENV_VARS['SHOW_THOUGHTS'] in os.environ:
64 | self.show_thoughts = os.environ[ENV_VARS['SHOW_THOUGHTS']].lower() == 'true'
65 |
66 | if ENV_VARS['MAX_WIDTH'] in os.environ:
67 | try:
68 | self.max_width = int(os.environ[ENV_VARS['MAX_WIDTH']])
69 | except ValueError:
70 | pass
71 |
72 | if ENV_VARS['MAX_HISTORY'] in os.environ:
73 | try:
74 | self.max_history = int(os.environ[ENV_VARS['MAX_HISTORY']])
75 | except ValueError:
76 | pass
77 |
78 | if ENV_VARS['HISTORY_FILE'] in os.environ:
79 | self.history_file = os.path.expanduser(os.environ[ENV_VARS['HISTORY_FILE']])
80 |
81 | # 增强布局选项 - 从环境变量读取
82 | if 'DBRHEO_ENHANCED_LAYOUT' in os.environ:
83 | self.enhanced_layout = os.environ['DBRHEO_ENHANCED_LAYOUT'].lower() == 'true'
84 |
85 | def to_dict(self) -> Dict[str, Any]:
86 | """转换为字典格式"""
87 | return {
88 | 'db_file': self.db_file,
89 | 'no_color': self.no_color,
90 | 'page_size': self.page_size,
91 | 'max_width': self.max_width,
92 | 'config_file': self.config_file,
93 | 'history_file': self.history_file,
94 | 'max_history': self.max_history,
95 | 'show_thoughts': self.show_thoughts,
96 | 'show_tool_details': self.show_tool_details,
97 | 'enhanced_layout': self.enhanced_layout
98 | }
99 |
100 | def update_runtime(self, key: str, value: Any):
101 | """
102 | Update configuration at runtime
103 | Allows dynamic changes without a restart
104 | """
105 | if hasattr(self, key):
106 | setattr(self, key, value)
107 | else:
108 | raise ValueError(f"Unknown configuration key: {key}")
109 |
110 | def get_display_config(self) -> Dict[str, Any]:
111 | """获取显示相关的配置"""
112 | return {
113 | 'no_color': self.no_color,
114 | 'page_size': self.page_size,
115 | 'max_width': self.max_width,
116 | 'show_thoughts': self.show_thoughts,
117 | 'show_tool_details': self.show_tool_details
118 | }
--------------------------------------------------------------------------------
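
A short sketch of the runtime behaviour described above; the environment variable name comes straight from __post_init__, while the rest is illustrative:

    import os
    from dbrheo_cli.app.config import CLIConfig

    os.environ["DBRHEO_ENHANCED_LAYOUT"] = "true"
    config = CLIConfig()                           # __post_init__ picks up environment overrides
    print(config.enhanced_layout)                  # True
    config.update_runtime("show_thoughts", True)   # change a setting without restarting
    print(config.get_display_config())
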
/packages/cli/src/dbrheo_cli/handlers/event_handler.py:
--------------------------------------------------------------------------------
1 | """
2 | Event handler
3 | Processes the events produced by send_message_stream.
4 | Corresponds to the event handling logic in Gemini CLI's useGeminiStream.
5 | """
6 |
7 | from typing import Dict, Any, Optional
8 | from dbrheo.utils.debug_logger import DebugLogger, log_info
9 |
10 | from ..ui.console import console
11 | from ..ui.messages import show_agent_message, show_error_message, show_system_message, show_tool_call
12 | from ..ui.streaming import StreamDisplay
13 | from ..i18n import _
14 | from ..app.config import CLIConfig
15 |
16 |
17 | class EventHandler:
18 | """
19 | Core event handler
20 | Handles every event type received from the backend
21 | Supports both plain console output and the enhanced layout manager
22 | """
23 |
24 | def __init__(self, config: CLIConfig):
25 | self.config = config
26 | self.stream_display = StreamDisplay(config)
27 | self.display_target = None # 可选的显示目标(布局管理器)
28 |
29 | def set_display_target(self, target):
30 | """设置显示目标 - 最小侵入性集成点"""
31 | self.display_target = target
32 |
33 | async def process(self, event: Dict[str, Any]):
34 | """处理单个事件"""
35 | event_type = event.get('type', '')
36 | value = event.get('value', '')
37 |
38 | if DebugLogger.should_log("DEBUG"):
39 | log_info("EventHandler", f"Processing event: {event_type}")
40 |
41 | # 根据事件类型分发处理
42 | if event_type == 'Content':
43 | await self._handle_content(value)
44 | elif event_type == 'Thought':
45 | await self._handle_thought(value)
46 | elif event_type == 'ToolCallRequest':
47 | await self._handle_tool_request(value)
48 | elif event_type == 'Error':
49 | await self._handle_error(value)
50 | elif event_type == 'AwaitingConfirmation':
51 | await self._handle_awaiting_confirmation(value)
52 | elif event_type == 'max_session_turns':
53 | await self._handle_max_turns(value)
54 | elif event_type == 'chat_compressed':
55 | await self._handle_chat_compressed(value)
56 | else:
57 | # 未知事件类型
58 | if DebugLogger.should_log("DEBUG"):
59 | log_info("EventHandler", f"Unknown event type: {event_type}")
60 |
61 | async def _handle_content(self, content: str):
62 | """处理AI响应内容"""
63 | # 调试:检查内容
64 | if DebugLogger.should_log("DEBUG"):
65 | log_info("EventHandler", f"Content received: {repr(content[:100])}")
66 |
67 | # 使用流式显示
68 | await self.stream_display.add_content(content)
69 |
70 | async def _handle_thought(self, thought: str):
71 | """处理AI思考过程"""
72 | if self.config.show_thoughts:
73 | console.print(f"[dim italic]{thought}[/dim italic]", end='')
74 |
75 | async def _handle_tool_request(self, tool_data: Any):
76 | """处理工具调用请求"""
77 | # tool_data 是一个对象,不是字典
78 | tool_name = getattr(tool_data, 'name', 'unknown')
79 |
80 | # 显示工具调用提示
81 | show_tool_call(tool_name)
82 |
83 | # 工具请求由tool_handler处理,这里只记录日志
84 | if DebugLogger.should_log("DEBUG"):
85 | log_info("EventHandler", f"Tool request: {tool_name}")
86 |
87 | async def _handle_error(self, error: str):
88 | """处理错误事件"""
89 | show_error_message(error)
90 |
91 | async def _handle_awaiting_confirmation(self, data: Any):
92 | """处理等待确认事件"""
93 | # 确认提示由tool_handler显示,这里只需要结束流式显示
94 | await self.stream_display.finish()
95 | if DebugLogger.should_log("DEBUG"):
96 | log_info("EventHandler", "Awaiting confirmation, breaking event loop")
97 |
98 | async def _handle_max_turns(self, data: Any):
99 | """处理达到最大会话轮数"""
100 | show_system_message(_('max_session_turns'))
101 |
102 | async def _handle_chat_compressed(self, data: Any):
103 | """处理会话压缩通知"""
104 | if DebugLogger.should_log("INFO"):
105 | show_system_message(_('chat_compressed'))
106 |
107 | def show_user_message(self, message: str):
108 | """显示用户消息"""
109 | # 先结束之前的流式显示
110 | self.stream_display.finish_sync()
111 |
112 | # 用户消息已经在输入时显示了,这里只需要添加空行
113 | console.print() # 添加空行准备显示AI响应
--------------------------------------------------------------------------------
/testdata/old.adult.names:
--------------------------------------------------------------------------------
1 | 1. Title of Database: adult
2 | 2. Sources:
3 | (a) Original owners of database (name/phone/snail address/email address)
4 | US Census Bureau.
5 | (b) Donor of database (name/phone/snail address/email address)
6 | Ronny Kohavi and Barry Becker,
7 | Data Mining and Visualization
8 | Silicon Graphics.
9 | e-mail: ronnyk@sgi.com
10 | (c) Date received (databases may change over time without name change!)
11 | 05/19/96
12 | 3. Past Usage:
13 | (a) Complete reference of article where it was described/used
14 | @inproceedings{kohavi-nbtree,
15 | author={Ron Kohavi},
16 | title={Scaling Up the Accuracy of Naive-Bayes Classifiers: a
17 | Decision-Tree Hybrid},
18 | booktitle={Proceedings of the Second International Conference on
19 | Knowledge Discovery and Data Mining},
20 | year = 1996,
21 | pages={to appear}}
22 | (b) Indication of what attribute(s) were being predicted
23 | Salary greater or less than 50,000.
24 | (b) Indication of study's results (i.e. Is it a good domain to use?)
25 | Hard domain with a nice number of records.
26 | The following results obtained using MLC++ with default settings
27 | for the algorithms mentioned below.
28 |
29 | Algorithm Error
30 | -- ---------------- -----
31 | 1 C4.5 15.54
32 | 2 C4.5-auto 14.46
33 | 3 C4.5 rules 14.94
34 | 4 Voted ID3 (0.6) 15.64
35 | 5 Voted ID3 (0.8) 16.47
36 | 6 T2 16.84
37 | 7 1R 19.54
38 | 8 NBTree 14.10
39 | 9 CN2 16.00
40 | 10 HOODG 14.82
41 | 11 FSS Naive Bayes 14.05
42 | 12 IDTM (Decision table) 14.46
43 | 13 Naive-Bayes 16.12
44 | 14 Nearest-neighbor (1) 21.42
45 | 15 Nearest-neighbor (3) 20.35
46 | 16 OC1 15.04
47 | 17 Pebls Crashed. Unknown why (bounds WERE increased)
48 |
49 | 4. Relevant Information Paragraph:
50 | Extraction was done by Barry Becker from the 1994 Census database. A set
51 | of reasonably clean records was extracted using the following conditions:
52 | ((AAGE>16) && (AGI>100) && (AFNLWGT>1)&& (HRSWK>0))
53 |
54 | 5. Number of Instances
55 | 48842 instances, mix of continuous and discrete (train=32561, test=16281)
56 | 45222 if instances with unknown values are removed (train=30162, test=15060)
57 | Split into train-test using MLC++ GenCVFiles (2/3, 1/3 random).
58 |
59 | 6. Number of Attributes
60 | 6 continuous, 8 nominal attributes.
61 |
62 | 7. Attribute Information:
63 |
64 | age: continuous.
65 | workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
66 | fnlwgt: continuous.
67 | education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
68 | education-num: continuous.
69 | marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
70 | occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
71 | relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
72 | race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
73 | sex: Female, Male.
74 | capital-gain: continuous.
75 | capital-loss: continuous.
76 | hours-per-week: continuous.
77 | native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
78 | class: >50K, <=50K
79 |
80 | 8. Missing Attribute Values:
81 |
82 | 7% have missing values.
83 |
84 | 9. Class Distribution:
85 |
86 | Probability for the label '>50K' : 23.93% / 24.78% (without unknowns)
87 | Probability for the label '<=50K' : 76.07% / 75.22% (without unknowns)
88 |
89 |
90 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | # DbRheo Environment Configuration File
2 | # API Configuration (at least one API key is required)
3 | # After pulling, configure this section - no other modifications needed
4 | GOOGLE_API_KEY={your_google_api_key_here}
5 | # Alternatively, use DBRHEO_API_KEY=your_google_api_key_here
6 |
7 | # Claude API (Anthropic)
8 | ANTHROPIC_API_KEY={your_anthropic_api_key_here}
9 |
10 | # OpenAI API
11 | OPENAI_API_KEY={your_openai_api_key_here}
12 |
13 |
14 | # OPENAI_API_BASE=https://api.openai.com/v1
15 |
16 | # Model Configuration (Default: gemini-2.5-flash)
17 | # Recommended: Comment out this line and use /model command to auto-save to config.yaml
18 | # Using /model command will automatically save the model setting to config.yaml
19 | #DBRHEO_MODEL=gemini-2.5-flash
20 |
21 | # Server Configuration
22 | DBRHEO_HOST=localhost
23 | DBRHEO_PORT=8000
24 | DBRHEO_DEBUG=true
25 |
26 | # Logging Configuration
27 | DBRHEO_LOG_LEVEL=INFO
28 |
29 | # Agent Configuration
30 | DBRHEO_MAX_TURNS=100
31 | DBRHEO_COMPRESSION_THRESHOLD=0.7
32 | DBRHEO_AUTO_EXECUTE=false
33 | DBRHEO_ALLOW_DANGEROUS=false
34 |
35 | # Monitoring Configuration (Optional)
36 | OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4317
37 | OTEL_SERVICE_NAME=dbrheo
38 |
39 | # Code Execution Tool Configuration (Optional)
40 | # CODE_EXECUTION_LANGUAGES=python,javascript,shell,sql
41 | # CODE_EXECUTION_MAX_OUTPUT=1048576
42 | # CODE_EXECUTION_TEMP_DIR=/tmp
43 |
44 | # Shell Tool Security Configuration (Optional)
45 | # SHELL_WHITELIST=git,ls,pwd,cat,grep,find,ps,df,du,whoami,mysql,mysqldump,psql,pg_dump
46 | # SHELL_BLACKLIST=rm,sudo,chmod,mkfs,format,fdisk,dd,reboot,shutdown
47 | # SHELL_DB_COMMANDS=mysql,psql,sqlite3,mysqldump,pg_dump,mongodump,redis-cli
48 | # SHELL_MAX_OUTPUT=1048576
49 | # SHELL_STRICT_WHITELIST=false
50 |
51 | # Enhanced Input Configuration
52 | DBRHEO_ENHANCED_INPUT=true
53 |
54 | # Multi-line Input Configuration
55 | DBRHEO_MULTILINE_ENABLED=true
56 | DBRHEO_MULTILINE_INDICATOR=...
57 | DBRHEO_MAX_DISPLAY_LINES=10
58 |
59 | # Multi-line Input Mode Configuration
60 | DBRHEO_MULTILINE_END_MODE=empty_line # empty_line or double_empty
61 | DBRHEO_AUTO_MULTILINE=true # Auto-detection of SQL and unclosed quotes/brackets
62 | DBRHEO_AUTO_PASTE_DETECTION=true # Auto-detection of multi-line paste (core feature)
63 | DBRHEO_MIN_PASTE_LINES=2 # Minimum lines to detect as paste
64 | DBRHEO_MAX_PASTE_LINES=100 # Maximum lines limit for paste content
65 | DBRHEO_SHOW_PASTE_PREVIEW=true # Show paste preview
66 | DBRHEO_DEBUG_PASTE=true # Debug paste detection
67 |
68 | # Windows-specific Configuration
69 | DBRHEO_PASTE_MAX_ATTEMPTS=5 # Maximum collection attempts in Windows environment
70 | DBRHEO_PASTE_WAIT_TIME=0.5 # Wait time in Windows environment (seconds)
71 |
72 | # Configurable Triggers and Keywords
73 | DBRHEO_SQL_KEYWORDS=SELECT,INSERT,UPDATE,DELETE,CREATE,ALTER,DROP,WITH,EXPLAIN,DESCRIBE
74 | # Multi-line triggers: triple_quote_double("""), triple_quote_single('''), backticks(```), angle_brackets(<<<)
75 | DBRHEO_MULTILINE_TRIGGERS=triple_quote_double,triple_quote_single,backticks,angle_brackets
76 |
77 | # UI Style Configuration
78 | DBRHEO_PROMPT_STYLE=[bold cyan]{prompt}[/bold cyan]
79 | DBRHEO_CONTINUATION_STYLE=[dim]{indicator}[/dim]
80 | DBRHEO_SQL_HINT=[dim]SQL statement detected, entering multi-line mode (end with empty line)[/dim]
81 | DBRHEO_UNCLOSED_HINT=[dim]Unclosed quotes/brackets detected, entering multi-line mode[/dim]
82 | DBRHEO_BLOCK_HINT=[dim]Multi-line input mode, enter {marker} again to end[/dim]
83 | DBRHEO_PASTE_HINT=[dim]Multi-line paste content detected ({lines} lines), processing automatically...[/dim]
84 |
85 | # Disable Color Output
86 | DBRHEO_NO_COLOR=true
87 |
88 | # Token Usage Warning Threshold
89 | DBRHEO_TOKEN_WARNING_THRESHOLD=200000
90 |
91 | # Database Connection Environment Variables (Optional)
92 | # Used in connection strings with environment variables to avoid hardcoded passwords
93 | # Example: mysql://admin:${DB_PASSWORD}@host:3306/database
94 | # DB_PASSWORD=your_secure_password_here
95 | # DB_USER=your_database_user
96 | # DB_HOST=your_database_host
97 | # DB_NAME=your_database_name
98 |
99 | # SSH Tunnel Configuration (Enterprise Database Connections)
100 | # SSH_BASTION_HOST=bastion.company.com
101 | # SSH_BASTION_USER=ec2-user
102 | # SSH_KEY_PATH=~/.ssh/production-key.pem
103 |
104 | # Cloud Database Authentication (Examples)
105 | # AWS_RDS_ENDPOINT=mydb.123456.us-east-1.rds.amazonaws.com
106 | # AZURE_DB_SERVER=myserver.database.windows.net
107 | # GCP_CLOUDSQL_INSTANCE=project:region:instance
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/types/tool_types.py:
--------------------------------------------------------------------------------
1 | """
2 | Tool-related type definitions - fully aligned with Gemini CLI's tool system
3 | Includes tool results, the confirmation mechanism, state management types, and more
4 | """
5 |
6 | from typing import Union, Optional, Dict, List, Any, Callable
7 | from dataclasses import dataclass
8 | from enum import Enum
9 | from .core_types import PartListUnion
10 |
11 |
12 | @dataclass
13 | class ToolResult:
14 | """
15 | Tool execution result - fully aligned with Gemini CLI's ToolResult
16 | """
17 | summary: Optional[str] = None # optional short summary
18 | llm_content: str = "" # content the LLM sees
19 | return_display: Optional[str] = None # formatted content the user sees
20 | error: Optional[str] = None # error message
21 |
22 |
23 | @dataclass
24 | class ToolCallRequestInfo:
25 | """工具调用请求信息 - 对应Gemini CLI的ToolCallRequestInfo"""
26 | call_id: str # 对应callId
27 | name: str # 工具名称
28 | args: Dict[str, Any] # 工具参数
29 | is_client_initiated: bool = False # 对应isClientInitiated
30 | prompt_id: str = "" # 对应prompt_id
31 |
32 |
33 | @dataclass
34 | class ToolCallResponseInfo:
35 | """工具调用响应信息"""
36 | call_id: str # 对应callId
37 | response_parts: PartListUnion # 对应responseParts
38 | result_display: Optional[str] = None # 对应resultDisplay
39 | error: Optional[Exception] = None # 错误信息
40 |
41 |
42 | class DatabaseConfirmationOutcome(Enum):
43 | """确认结果 - 完全对齐Gemini CLI的ToolConfirmationOutcome"""
44 | PROCEED_ONCE = "proceed_once"
45 | PROCEED_ALWAYS = "proceed_always"
46 | PROCEED_ALWAYS_SERVER = "proceed_always_server" # 数据库服务器级总是允许
47 | PROCEED_ALWAYS_TOOL = "proceed_always_tool" # 工具级总是允许
48 | MODIFY_WITH_EDITOR = "modify_with_editor" # 编辑器修改SQL
49 | CANCEL = "cancel"
50 |
51 |
52 | @dataclass
53 | class DatabaseConfirmationDetails:
54 | """数据库确认详情基类"""
55 | type: str
56 | title: str
57 | on_confirm: Optional[Callable] = None
58 |
59 |
60 | @dataclass
61 | class SQLExecuteConfirmationDetails:
62 | """SQL执行确认 - 对应Gemini CLI的ToolExecuteConfirmationDetails"""
63 | title: str # 必需字段
64 | sql_query: str # 对应command字段
65 | root_operation: str # 对应rootCommand字段
66 | type: str = 'sql_execute' # 默认值字段放后面
67 | risk_assessment: Optional[Dict[str, Any]] = None # 风险评估详情
68 | estimated_impact: Optional[int] = None
69 | on_confirm: Optional[Callable] = None
70 |
71 |
72 | # 工具调用状态类型(完全对齐Gemini CLI的状态机)
73 | @dataclass
74 | class ValidatingToolCall:
75 | request: Optional[ToolCallRequestInfo] = None
76 | tool: Optional[Any] = None # DatabaseTool类型
77 | status: str = 'validating'
78 | start_time: Optional[float] = None
79 |
80 |
81 | @dataclass
82 | class ScheduledToolCall:
83 | request: Optional[ToolCallRequestInfo] = None
84 | tool: Optional[Any] = None
85 | status: str = 'scheduled'
86 | start_time: Optional[float] = None
87 |
88 |
89 | @dataclass
90 | class ExecutingToolCall:
91 | request: Optional[ToolCallRequestInfo] = None
92 | tool: Optional[Any] = None
93 | status: str = 'executing'
94 | live_output: Optional[str] = None
95 | start_time: Optional[float] = None
96 |
97 |
98 | @dataclass
99 | class SuccessfulToolCall:
100 | request: Optional[ToolCallRequestInfo] = None
101 | tool: Optional[Any] = None
102 | response: Optional[ToolCallResponseInfo] = None
103 | status: str = 'success'
104 | duration_ms: Optional[float] = None
105 |
106 |
107 | @dataclass
108 | class ErroredToolCall:
109 | request: Optional[ToolCallRequestInfo] = None
110 | response: Optional[ToolCallResponseInfo] = None
111 | status: str = 'error'
112 | duration_ms: Optional[float] = None
113 |
114 |
115 | @dataclass
116 | class CancelledToolCall:
117 | request: Optional[ToolCallRequestInfo] = None
118 | tool: Optional[Any] = None
119 | response: Optional[ToolCallResponseInfo] = None
120 | status: str = 'cancelled'
121 | duration_ms: Optional[float] = None
122 |
123 |
124 | @dataclass
125 | class WaitingToolCall:
126 | request: Optional[ToolCallRequestInfo] = None
127 | tool: Optional[Any] = None
128 | confirmation_details: Optional[DatabaseConfirmationDetails] = None
129 | status: str = 'awaiting_approval'
130 | start_time: Optional[float] = None
131 |
132 |
133 | # 完整的状态联合类型(与Gemini CLI完全一致)
134 | ToolCall = Union[
135 | ValidatingToolCall,
136 | ScheduledToolCall,
137 | ExecutingToolCall,
138 | SuccessfulToolCall,
139 | ErroredToolCall,
140 | CancelledToolCall,
141 | WaitingToolCall
142 | ]
143 |
--------------------------------------------------------------------------------
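
To make the state machine above concrete, here is an illustrative sketch of one call moving from a request to a successful result. The call ID, tool name, arguments, and timings are invented for illustration, and a plain string is assumed to be an acceptable `PartListUnion` value.

```python
import time

from dbrheo.types.tool_types import (
    ToolCallRequestInfo, ToolCallResponseInfo,
    ExecutingToolCall, SuccessfulToolCall, ToolCall,
)

# A client-initiated request to run a (hypothetical) SQL tool
request = ToolCallRequestInfo(
    call_id="call-001",
    name="sql_execute",
    args={"sql": "SELECT COUNT(*) FROM users"},
    is_client_initiated=True,
    prompt_id="prompt-42",
)

# While the tool runs, the scheduler would hold an ExecutingToolCall
running: ToolCall = ExecutingToolCall(request=request, start_time=time.time())

# After execution, the call is represented as a SuccessfulToolCall
response = ToolCallResponseInfo(
    call_id=request.call_id,
    response_parts="count: 42",   # assuming a plain string is a valid PartListUnion member
    result_display="1 row returned",
)
done: ToolCall = SuccessfulToolCall(request=request, response=response, duration_ms=12.5)

print(done.status, done.response.result_display)
```
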
/packages/core/src/dbrheo/services/llm_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | LLM service factory - dynamically creates different LLM services based on configuration.
3 | Design principles: flexibility, extensibility, minimal invasiveness.
4 | """
5 |
6 | from typing import Dict, Type, Optional
7 | from ..config.base import DatabaseConfig
8 | from ..utils.debug_logger import log_info
9 |
10 |
11 | class LLMServiceFactory:
12 | """
13 | LLM service factory class
14 | - dynamic service creation
15 | - flexible model mapping
16 | - easy to extend with new models
17 | """
18 |
19 | # 模型映射表 - 避免硬编码,便于扩展
20 | MODEL_MAPPINGS: Dict[str, Dict[str, str]] = {
21 | "gemini": {
22 | "module": "gemini_service_new",
23 | "class": "GeminiService",
24 | "prefixes": ["gemini", "models/gemini"] # 支持多种前缀
25 | },
26 | "claude": {
27 | "module": "claude_service",
28 | "class": "ClaudeService",
29 | "prefixes": ["claude", "anthropic", "sonnet", "opus", "haiku"] # 支持 sonnet4, opus4 等
30 | },
31 | "openai": {
32 | "module": "openai_service",
33 | "class": "OpenAIService",
34 | "prefixes": ["gpt", "openai", "o1"]
35 | }
36 | }
37 |
38 | @staticmethod
39 | def create_llm_service(config: DatabaseConfig):
40 | """
41 | 根据配置创建相应的 LLM 服务
42 |
43 | Args:
44 | config: 数据库配置对象
45 |
46 | Returns:
47 | LLM 服务实例
48 |
49 | Raises:
50 | ValueError: 当模型不被支持时
51 | """
52 | model_name = config.get_model()
53 |
54 | # 查找匹配的服务
55 | service_info = LLMServiceFactory._find_service_for_model(model_name)
56 |
57 | if not service_info:
58 | # 如果没有找到匹配,默认使用 Gemini(保持向后兼容)
59 | log_info("LLMFactory", f"Model '{model_name}' not recognized, using Gemini as default")
60 | service_info = LLMServiceFactory.MODEL_MAPPINGS["gemini"]
61 |
62 | # 动态导入和创建服务
63 | try:
64 | module_name = service_info['module']
65 | class_name = service_info['class']
66 |
67 | # 动态导入 - 使用相对导入
68 | from importlib import import_module
69 | full_module_name = f".{module_name}"
70 | module = import_module(full_module_name, package='dbrheo.services')
71 | service_class = getattr(module, class_name)
72 |
73 | # 创建实例
74 | log_info("LLMFactory", f"Creating {class_name} for model '{model_name}'")
75 | return service_class(config)
76 |
77 | except ImportError as e:
78 | # 如果服务类还未实现,回退到 Gemini
79 | log_info("LLMFactory", f"Failed to import {service_info['module']}: {e}")
80 | log_info("LLMFactory", "Falling back to GeminiService")
81 |
82 | from .gemini_service_new import GeminiService
83 | return GeminiService(config)
84 | except ValueError as e:
85 | # 配置错误(如缺少 API key)
86 | log_info("LLMFactory", f"Configuration error for {class_name}: {e}")
87 | raise
88 |
89 | except Exception as e:
90 | log_info("LLMFactory", f"Error creating service: {e}")
91 | raise
92 |
93 | @staticmethod
94 | def _find_service_for_model(model_name: str) -> Optional[Dict[str, str]]:
95 | """
96 | 根据模型名称查找对应的服务信息
97 |
98 | Args:
99 | model_name: 模型名称
100 |
101 | Returns:
102 | 服务信息字典或 None
103 | """
104 | model_lower = model_name.lower()
105 |
106 | # 遍历所有映射,查找匹配的前缀
107 | for service_name, service_info in LLMServiceFactory.MODEL_MAPPINGS.items():
108 | for prefix in service_info["prefixes"]:
109 | if model_lower.startswith(prefix.lower()):
110 | return service_info
111 |
112 | return None
113 |
114 | @staticmethod
115 | def register_model_mapping(service_name: str, module: str, class_name: str, prefixes: list):
116 | """
117 | 注册新的模型映射(便于运行时扩展)
118 |
119 | Args:
120 | service_name: 服务名称
121 | module: 模块名
122 | class_name: 类名
123 | prefixes: 模型前缀列表
124 | """
125 | LLMServiceFactory.MODEL_MAPPINGS[service_name] = {
126 | "module": module,
127 | "class": class_name,
128 | "prefixes": prefixes
129 | }
130 | log_info("LLMFactory", f"Registered new model mapping: {service_name}")
131 |
132 |
133 | # 导出便捷函数
134 | def create_llm_service(config: DatabaseConfig):
135 | """便捷函数 - 创建 LLM 服务"""
136 | return LLMServiceFactory.create_llm_service(config)
--------------------------------------------------------------------------------
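
A short sketch of extending and using the factory at runtime. `my_llm_service`/`MyLLMService` and the `mymodel` prefix are hypothetical placeholders, and it is assumed that `DatabaseConfig()` can be constructed with its defaults.

```python
from dbrheo.config.base import DatabaseConfig
from dbrheo.services.llm_factory import LLMServiceFactory, create_llm_service

# Register a hypothetical provider at runtime: model names starting with
# "mymodel" would be served by dbrheo.services.my_llm_service.MyLLMService.
LLMServiceFactory.register_model_mapping(
    service_name="myprovider",
    module="my_llm_service",    # hypothetical module under dbrheo.services
    class_name="MyLLMService",  # hypothetical class
    prefixes=["mymodel"],
)

# Create a service from the configured model name; unrecognized prefixes and
# failed imports fall back to GeminiService, as implemented above.
config = DatabaseConfig()  # assumes the default constructor is sufficient here
service = create_llm_service(config)
print(type(service).__name__)
```
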
/packages/core/src/dbrheo/utils/errors.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom exception classes - provide structured error handling.
3 | Defines exception types specific to the database agent.
4 | """
5 |
6 | from typing import Optional, Dict, Any
7 |
8 |
9 | class DatabaseAgentError(Exception):
10 | """数据库Agent基础异常类"""
11 |
12 | def __init__(
13 | self,
14 | message: str,
15 | error_code: Optional[str] = None,
16 | details: Optional[Dict[str, Any]] = None
17 | ):
18 | super().__init__(message)
19 | self.message = message
20 | self.error_code = error_code
21 | self.details = details or {}
22 |
23 | def to_dict(self) -> Dict[str, Any]:
24 | """转换为字典格式"""
25 | return {
26 | "error_type": self.__class__.__name__,
27 | "message": self.message,
28 | "error_code": self.error_code,
29 | "details": self.details
30 | }
31 |
32 |
33 | class ToolExecutionError(DatabaseAgentError):
34 | """工具执行异常"""
35 |
36 | def __init__(
37 | self,
38 | tool_name: str,
39 | message: str,
40 | original_error: Optional[Exception] = None,
41 | **kwargs
42 | ):
43 | super().__init__(message, **kwargs)
44 | self.tool_name = tool_name
45 | self.original_error = original_error
46 |
47 | def to_dict(self) -> Dict[str, Any]:
48 | result = super().to_dict()
49 | result.update({
50 | "tool_name": self.tool_name,
51 | "original_error": str(self.original_error) if self.original_error else None
52 | })
53 | return result
54 |
55 |
56 | class ValidationError(DatabaseAgentError):
57 | """参数验证异常"""
58 |
59 | def __init__(
60 | self,
61 | field_name: str,
62 | message: str,
63 | invalid_value: Any = None,
64 | **kwargs
65 | ):
66 | super().__init__(message, **kwargs)
67 | self.field_name = field_name
68 | self.invalid_value = invalid_value
69 |
70 | def to_dict(self) -> Dict[str, Any]:
71 | result = super().to_dict()
72 | result.update({
73 | "field_name": self.field_name,
74 | "invalid_value": self.invalid_value
75 | })
76 | return result
77 |
78 |
79 | class DatabaseConnectionError(DatabaseAgentError):
80 | """数据库连接异常"""
81 |
82 | def __init__(
83 | self,
84 | database_name: str,
85 | message: str,
86 | connection_string: Optional[str] = None,
87 | **kwargs
88 | ):
89 | super().__init__(message, **kwargs)
90 | self.database_name = database_name
91 | self.connection_string = connection_string
92 |
93 | def to_dict(self) -> Dict[str, Any]:
94 | result = super().to_dict()
95 | result.update({
96 | "database_name": self.database_name,
97 | "connection_string": self.connection_string
98 | })
99 | return result
100 |
101 |
102 | class SQLExecutionError(DatabaseAgentError):
103 | """SQL执行异常"""
104 |
105 | def __init__(
106 | self,
107 | sql: str,
108 | message: str,
109 | error_position: Optional[int] = None,
110 | **kwargs
111 | ):
112 | super().__init__(message, **kwargs)
113 | self.sql = sql
114 | self.error_position = error_position
115 |
116 | def to_dict(self) -> Dict[str, Any]:
117 | result = super().to_dict()
118 | result.update({
119 | "sql": self.sql,
120 | "error_position": self.error_position
121 | })
122 | return result
123 |
124 |
125 | class ConfigurationError(DatabaseAgentError):
126 | """配置异常"""
127 |
128 | def __init__(
129 | self,
130 | config_key: str,
131 | message: str,
132 | **kwargs
133 | ):
134 | super().__init__(message, **kwargs)
135 | self.config_key = config_key
136 |
137 | def to_dict(self) -> Dict[str, Any]:
138 | result = super().to_dict()
139 | result.update({
140 | "config_key": self.config_key
141 | })
142 | return result
143 |
144 |
145 | class PermissionError(DatabaseAgentError):
146 | """权限异常"""
147 |
148 | def __init__(
149 | self,
150 | operation: str,
151 | resource: str,
152 | message: str,
153 | **kwargs
154 | ):
155 | super().__init__(message, **kwargs)
156 | self.operation = operation
157 | self.resource = resource
158 |
159 | def to_dict(self) -> Dict[str, Any]:
160 | result = super().to_dict()
161 | result.update({
162 | "operation": self.operation,
163 | "resource": self.resource
164 | })
165 | return result
166 |
--------------------------------------------------------------------------------
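
A brief sketch of how these exceptions nest and serialize via `to_dict()`; the SQL text, error position, and tool name are made-up examples.

```python
import json

from dbrheo.utils.errors import ToolExecutionError, SQLExecutionError

try:
    raise SQLExecutionError(
        sql="SELECT * FROM missing_table",
        message="no such table: missing_table",
        error_position=15,
        error_code="SQL_ERROR",
    )
except SQLExecutionError as sql_err:
    # Wrap the low-level failure in a tool-level error, keeping the original
    wrapped = ToolExecutionError(
        tool_name="sql_execute",
        message="SQL execution failed",
        original_error=sql_err,
    )
    print(json.dumps(wrapped.to_dict(), ensure_ascii=False, indent=2))
```
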
/packages/cli/src/dbrheo_cli/ui/startup.py:
--------------------------------------------------------------------------------
1 | """
2 | DbRheo startup screen.
3 | Uses rich-gradient for an elegant gradient effect.
4 | """
5 |
6 | import os
7 | from typing import Optional, List, Tuple
8 | from rich.console import Console
9 | from rich.panel import Panel
10 | from rich.text import Text
11 | from rich.align import Align
12 | from rich.columns import Columns
13 |
14 | from .ascii_art import select_logo, get_logo_width, LONG_LOGO, EXTRA_LARGE_LOGO
15 | from ..i18n import _
16 | from ..app.config import CLIConfig
17 |
18 | # 尝试导入 rich-gradient,提供优雅降级
19 | try:
20 | from rich_gradient import Gradient
21 | GRADIENT_AVAILABLE = True
22 | except ImportError:
23 | GRADIENT_AVAILABLE = False
24 |
25 | # 颜色主题
26 | DBRHEO_GRADIENT_COLORS = ["#000033", "#001155", "#0033AA", "#0055FF", "#3377FF"] # 蓝黑渐变
27 | TIPS_COLOR = "#8899AA" # 提示文字颜色
28 |
29 |
30 | class StartupScreen:
31 | """启动画面管理器"""
32 |
33 | def __init__(self, config: CLIConfig, console: Console):
34 | self.config = config
35 | self.console = console
36 | self.terminal_width = console.width
37 |
38 | def display(self, version: str = "0.2.0", show_tips: bool = True,
39 | custom_message: Optional[str] = None, logo_style: str = "italic"):
40 | """
41 | 显示启动画面
42 |
43 | Args:
44 | version: 版本号
45 | show_tips: 是否显示使用提示
46 | custom_message: 自定义消息(如警告)
47 | logo_style: logo 风格 - "default", "italic", "extra"
48 | """
49 | # 选择合适的 logo
50 | logo = select_logo(self.terminal_width, style=logo_style)
51 |
52 | # 显示 logo(带渐变效果)
53 | self._display_logo(logo)
54 |
55 | # 显示版本信息
56 | self._display_version(version)
57 |
58 | # 显示使用提示
59 | if show_tips:
60 | self._display_tips()
61 |
62 | # 显示自定义消息(如工作目录警告)
63 | if custom_message:
64 | self._display_custom_message(custom_message)
65 |
66 | # 添加底部间距
67 | self.console.print()
68 |
69 | def _display_logo(self, logo: str):
70 | """显示带渐变效果的 logo"""
71 | if GRADIENT_AVAILABLE and not self.config.no_color:
72 | # 使用 rich-gradient 实现渐变
73 | gradient_logo = Gradient(
74 | logo.strip(),
75 | colors=DBRHEO_GRADIENT_COLORS,
76 | justify="left" # 改为左对齐
77 | )
78 | self.console.print(gradient_logo)
79 | else:
80 | # 降级方案:使用简单的蓝色
81 | self.console.print(
82 | Text(logo.strip(), style="bold blue"),
83 | justify="left" # 改为左对齐
84 | )
85 |
86 | def _display_version(self, version: str):
87 | """显示版本信息"""
88 | version_text = f"v{version}"
89 | if GRADIENT_AVAILABLE and not self.config.no_color:
90 | version_gradient = Gradient(
91 | version_text,
92 | colors=DBRHEO_GRADIENT_COLORS[::-1], # 反向渐变
93 | justify="right"
94 | )
95 | self.console.print(version_gradient)
96 | else:
97 | self.console.print(
98 | Text(version_text, style="dim cyan"),
99 | justify="right"
100 | )
101 |
102 | def _display_tips(self):
103 | """显示使用提示"""
104 | tips = [
105 | _('startup_tip_1'),
106 | _('startup_tip_2'),
107 | _('startup_tip_3'),
108 | _('startup_tip_4'),
109 | _('startup_tip_5'),
110 | _('startup_tip_6')
111 | ]
112 |
113 | self.console.print()
114 | self.console.print(_('startup_tips_title'), style=f"bold {TIPS_COLOR}")
115 | for tip in tips:
116 | self.console.print(f" {tip}", style=TIPS_COLOR)
117 |
118 | def _display_custom_message(self, message: str):
119 | """显示自定义消息(如警告框)"""
120 | self.console.print()
121 | panel = Panel(
122 | message,
123 | border_style="yellow",
124 | padding=(0, 2)
125 | )
126 | self.console.print(panel)
127 |
128 |
129 |
130 | def create_minimal_startup(console: Console, version: str = "0.2.0"):
131 | """创建最小化的启动信息(用于 --quiet 模式)"""
132 | console.print(f"[bold blue]DbRheo[/bold blue] v{version}")
133 |
134 |
135 | def create_rainbow_logo(logo: str) -> Optional["Gradient"]:  # returns a rich renderable, not a str
136 | """创建彩虹效果的 logo(特殊场合使用)"""
137 | if not GRADIENT_AVAILABLE:
138 | return None
139 |
140 | try:
141 | rainbow_logo = Gradient(
142 | logo.strip(),
143 | rainbow=True,
144 | justify="center"
145 | )
146 | return rainbow_logo
147 | except Exception:  # degrade gracefully if gradient rendering fails
148 | return None
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/telemetry/tracer.py:
--------------------------------------------------------------------------------
1 | """
2 | DatabaseTracer - distributed tracing system.
3 | Built on OpenTelemetry, fully aligned with the Gemini CLI tracing mechanism.
4 | """
5 |
6 | import asyncio  # used by trace() to detect coroutine functions
7 | import logging
8 | from typing import Optional, Dict, Any, Callable
9 | from functools import wraps
10 | from contextlib import contextmanager
11 |
12 | # 有条件导入OpenTelemetry,允许在没有安装的情况下降级
13 | try:
14 | from opentelemetry import trace
15 | from opentelemetry.sdk.trace import TracerProvider
16 | from opentelemetry.sdk.trace.export import BatchSpanProcessor
17 | from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
18 | from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
19 | OTEL_AVAILABLE = True
20 | except ImportError:
21 | OTEL_AVAILABLE = False
22 |
23 | from ..config.base import DatabaseConfig
24 |
25 |
26 | class DatabaseTracer:
27 | """
28 | Distributed tracing system for the database agent
29 | - fully aligned with the Gemini CLI tracing mechanism
30 | - supports OpenTelemetry integration
31 | - provides decorator and context-manager APIs
32 | - degrades gracefully when OpenTelemetry is not installed
33 | """
34 |
35 | def __init__(self, config: DatabaseConfig):
36 | self.config = config
37 | self.service_name = config.get("service_name", "database-agent")
38 | self.enabled = config.get("telemetry_enabled", True)
39 |
40 | # 初始化追踪器
41 | self.tracer = self._setup_tracer() if self.enabled else None
42 |
43 | def _setup_tracer(self):
44 | """设置OpenTelemetry追踪器"""
45 | if not OTEL_AVAILABLE:
46 | logging.warning("OpenTelemetry not available, tracing disabled")
47 | return None
48 |
49 | try:
50 | # 设置追踪提供者
51 | provider = TracerProvider()
52 | trace.set_tracer_provider(provider)
53 |
54 | # 配置导出器
55 | otlp_endpoint = self.config.get("otel_exporter_otlp_endpoint")
56 | if otlp_endpoint:
57 | otlp_exporter = OTLPSpanExporter(endpoint=otlp_endpoint)
58 | span_processor = BatchSpanProcessor(otlp_exporter)
59 | provider.add_span_processor(span_processor)
60 |
61 | # 创建追踪器
62 | return trace.get_tracer(self.service_name)
63 |
64 | except Exception as e:
65 | logging.error(f"Failed to setup tracer: {e}")
66 | return None
67 |
68 | def trace(self, name: str, attributes: Optional[Dict[str, Any]] = None):
69 | """
70 | 追踪装饰器 - 完全对齐Gemini CLI的trace装饰器
71 | 用于追踪函数执行
72 | """
73 | def decorator(func):
74 | @wraps(func)
75 | async def async_wrapper(*args, **kwargs):
76 | if not self.enabled or not self.tracer:
77 | return await func(*args, **kwargs)
78 |
79 | with self.tracer.start_as_current_span(name, attributes=attributes):
80 | return await func(*args, **kwargs)
81 |
82 | @wraps(func)
83 | def sync_wrapper(*args, **kwargs):
84 | if not self.enabled or not self.tracer:
85 | return func(*args, **kwargs)
86 |
87 | with self.tracer.start_as_current_span(name, attributes=attributes):
88 | return func(*args, **kwargs)
89 |
90 | return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
91 | return decorator
92 |
93 | @contextmanager
94 | def span(self, name: str, attributes: Optional[Dict[str, Any]] = None):
95 | """
96 | 追踪上下文管理器 - 完全对齐Gemini CLI的withSpan
97 | 用于追踪代码块执行
98 | """
99 | if not self.enabled or not self.tracer:
100 | yield
101 | return
102 |
103 | with self.tracer.start_as_current_span(name, attributes=attributes):
104 | yield
105 |
106 | def add_event(self, name: str, attributes: Optional[Dict[str, Any]] = None):
107 | """向当前span添加事件"""
108 | if not self.enabled or not self.tracer:
109 | return
110 |
111 | current_span = trace.get_current_span()
112 | if current_span:
113 | current_span.add_event(name, attributes=attributes)
114 |
115 | def set_attribute(self, key: str, value: Any):
116 | """设置当前span的属性"""
117 | if not self.enabled or not self.tracer:
118 | return
119 |
120 | current_span = trace.get_current_span()
121 | if current_span:
122 | current_span.set_attribute(key, value)
123 |
124 | def record_exception(self, exception: Exception):
125 | """记录异常"""
126 | if not self.enabled or not self.tracer:
127 | return
128 |
129 | current_span = trace.get_current_span()
130 | if current_span:
131 | current_span.record_exception(exception)
132 |
--------------------------------------------------------------------------------
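
A minimal usage sketch of the decorator and context-manager APIs above. It assumes `DatabaseConfig()` works with its defaults; without an OTLP endpoint configured (or without OpenTelemetry installed) the calls simply degrade to no-ops, as implemented above.

```python
import asyncio

from dbrheo.config.base import DatabaseConfig
from dbrheo.telemetry.tracer import DatabaseTracer

config = DatabaseConfig()   # assumes defaults are acceptable here
tracer = DatabaseTracer(config)


@tracer.trace("fetch_schema", attributes={"db": "sqlite"})
async def fetch_schema() -> list:
    # Placeholder work; a real implementation would query the database.
    await asyncio.sleep(0)
    return ["users", "orders"]


async def main() -> None:
    with tracer.span("session", attributes={"session_id": "demo"}):
        tables = await fetch_schema()
        tracer.add_event("schema_fetched", {"table_count": len(tables)})
        print(tables)


asyncio.run(main())
```
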
/packages/core/src/dbrheo/utils/log_integration.py:
--------------------------------------------------------------------------------
1 | """
2 | Logging system integration - wires real-time logging into DbRheo components.
3 | """
4 |
5 | from typing import Dict, Any, Optional
6 | from functools import wraps
7 | import asyncio
8 |
9 | from .realtime_logger import (
10 | get_logger, log_conversation, log_tool_call,
11 | log_tool_result, log_error, log_system
12 | )
13 |
14 |
15 | def log_chat_interaction(func):
16 | """装饰器 - 记录Chat交互(支持异步生成器)"""
17 | @wraps(func)
18 | async def wrapper(self, *args, **kwargs):
19 | # 记录用户输入
20 | if args and isinstance(args[0], list):
21 | messages = args[0]
22 | for msg in messages:
23 | if isinstance(msg, dict) and 'text' in msg:
24 | log_conversation("User", msg['text'])
25 | elif args and isinstance(args[0], str):
26 | log_conversation("User", args[0])
27 |
28 | try:
29 | # 执行原函数 - 返回异步生成器
30 | async for chunk in func(self, *args, **kwargs):
31 | yield chunk
32 |
33 | # 记录AI响应(在生成器完成后)
34 | if hasattr(self, 'history') and self.history:
35 | last_response = None
36 | for item in reversed(self.history):
37 | if item.get('role') == 'model':
38 | last_response = item
39 | break
40 |
41 | if last_response:
42 | model_text = ""
43 | for part in last_response.get('parts', []):
44 | if isinstance(part, dict) and 'text' in part:
45 | model_text += part['text']
46 | if model_text:
47 | log_conversation("Agent", model_text)
48 |
49 | except Exception as e:
50 | log_error(self.__class__.__name__, str(e))
51 | raise
52 |
53 | return wrapper
54 |
55 |
56 | def log_tool_execution(func):
57 | """装饰器 - 记录工具执行"""
58 | @wraps(func)
59 | async def wrapper(self, params: Dict[str, Any], *args, **kwargs):
60 | tool_name = getattr(self, 'name', self.__class__.__name__)
61 |
62 | # 获取call_id
63 | call_id = ""
64 | if len(args) >= 2 and hasattr(args[1], 'get'):
65 | call_id = args[1].get('call_id', '')
66 |
67 | # 记录工具调用
68 | log_tool_call(tool_name, params, call_id)
69 |
70 | try:
71 | # 执行工具
72 | result = await func(self, params, *args, **kwargs)
73 |
74 | # 记录结果
75 | success = True
76 | result_summary = None
77 |
78 | if hasattr(result, 'error') and result.error:
79 | success = False
80 | result_summary = result.error
81 | elif hasattr(result, 'summary'):
82 | result_summary = result.summary
83 | elif hasattr(result, 'llm_content'):
84 | result_summary = result.llm_content[:200] + "..." if len(result.llm_content) > 200 else result.llm_content
85 | else:
86 | result_summary = str(result)[:200]
87 |
88 | log_tool_result(tool_name, result_summary, success, call_id)
89 |
90 | return result
91 | except Exception as e:
92 | log_tool_result(tool_name, str(e), False, call_id)
93 | raise
94 |
95 | return wrapper
96 |
97 |
98 | def log_scheduler_activity(func):
99 | """装饰器 - 记录调度器活动"""
100 | @wraps(func)
101 | async def wrapper(self, *args, **kwargs):
102 | try:
103 | result = await func(self, *args, **kwargs)
104 |
105 | # 记录调度状态变化
106 | if hasattr(self, 'tool_calls'):
107 | for call_id, call_state in self.tool_calls.items():
108 | state = call_state.get('state', 'unknown')
109 | tool_name = call_state.get('name', 'unknown')
110 | log_system(f"Tool {tool_name} state: {state}", call_id=call_id)
111 |
112 | return result
113 | except Exception as e:
114 | log_error("Scheduler", str(e))
115 | raise
116 |
117 | return wrapper
118 |
119 |
120 | class LoggingMixin:
121 | """日志混入类 - 为类添加日志功能"""
122 |
123 | def log_info(self, message: str, **kwargs):
124 | """记录信息"""
125 | log_system(f"[{self.__class__.__name__}] {message}", **kwargs)
126 |
127 | def log_error(self, message: str, **kwargs):
128 | """记录错误"""
129 | log_error(self.__class__.__name__, message, **kwargs)
130 |
131 | def log_performance(self, metric: str, value: float, unit: str = "ms", **kwargs):
132 | """记录性能指标"""
133 | from .realtime_logger import get_logger, LogEvent, LogEventType
134 | get_logger().log_performance(f"{self.__class__.__name__}.{metric}", value, unit, **kwargs)
135 |
136 |
137 | # 快速集成函数
138 | def integrate_logging():
139 | """快速集成日志到现有组件"""
140 | # 改为最小侵入性方式:不使用装饰器,因为装饰器会改变异步生成器的行为
141 | # 日志记录已经直接添加到 chat.py 和 scheduler.py 中
142 | log_system("实时日志系统已启用(非侵入模式)", source="LogIntegration")
143 |
144 |
145 | # 环境变量控制的自动集成
146 | import os
147 | if os.environ.get('DBRHEO_ENABLE_REALTIME_LOG', '').lower() == 'true':
148 | integrate_logging()
149 | log_system("实时日志系统已自动启用", source="Environment")
--------------------------------------------------------------------------------
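
A tiny sketch of `LoggingMixin` above: any class can inherit it to get namespaced info/error/performance logging. The `QueryRunner` class is invented for illustration, and the underlying `realtime_logger` helpers are assumed to be importable exactly as this module imports them.

```python
from dbrheo.utils.log_integration import LoggingMixin


class QueryRunner(LoggingMixin):
    """Hypothetical component that gains logging helpers via the mixin."""

    def run(self, sql: str) -> None:
        self.log_info(f"executing: {sql}")
        # placeholder for real query execution
        self.log_performance("run", 12.5, unit="ms")


QueryRunner().run("SELECT COUNT(*) FROM users")
```
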
/packages/core/src/dbrheo/config/test_config.py:
--------------------------------------------------------------------------------
1 | """
2 | TestDatabaseConfig - configuration class dedicated to testing.
3 | Inherits DatabaseConfig but supports runtime configuration overrides, preserving the flexibility principle.
4 | """
5 |
6 | from typing import Any, Dict, Optional
7 | from pathlib import Path
8 | from .base import DatabaseConfig, ConfigSource
9 |
10 |
11 | class TestConfigSource(ConfigSource):
12 | """测试配置源 - 最高优先级,用于测试时覆盖配置"""
13 |
14 | def __init__(self, test_config: Dict[str, Any]):
15 | self._test_config = test_config
16 |
17 | def get(self, key: str) -> Optional[Any]:
18 | return self._test_config.get(key)
19 |
20 | def get_all(self) -> Dict[str, Any]:
21 | return self._test_config.copy()
22 |
23 |
24 | class TestDatabaseConfig(DatabaseConfig):
25 | """
26 | 测试专用配置类
27 | - 继承完整的分层配置系统
28 | - 支持运行时测试配置覆盖
29 | - 提供便捷的测试数据库设置API
30 | - 保持与Gemini CLI设计原则的完全对齐
31 | """
32 |
33 | def __init__(self, workspace_root: Optional[Path] = None, test_overrides: Optional[Dict[str, Any]] = None):
34 | """
35 | 初始化测试配置
36 |
37 | 参数:
38 | workspace_root: 工作区根目录
39 | test_overrides: 测试覆盖配置,具有最高优先级
40 | """
41 | # 先调用父类初始化
42 | super().__init__(workspace_root)
43 |
44 | # 如果有测试覆盖配置,添加到配置源列表的最前面(最高优先级)
45 | if test_overrides:
46 | self._test_overrides = test_overrides
47 | # 将测试配置源插入到最前面
48 | self.config_sources.insert(0, TestConfigSource(test_overrides))
49 | else:
50 | self._test_overrides = {}
51 |
52 | def set_test_database(self, database_name: str, database_config: Dict[str, Any]) -> None:
53 | """
54 | 设置测试数据库配置
55 |
56 | 参数:
57 | database_name: 数据库名称(如 'default', 'test')
58 | database_config: 数据库配置字典,包含type, database等
59 |
60 | 示例:
61 | config.set_test_database('default', {
62 | 'type': 'sqlite',
63 | 'database': '/path/to/test.db'
64 | })
65 | """
66 | if 'databases' not in self._test_overrides:
67 | self._test_overrides['databases'] = {}
68 |
69 | self._test_overrides['databases'][database_name] = database_config
70 |
71 | # 更新配置源
72 | if self.config_sources and isinstance(self.config_sources[0], TestConfigSource):
73 | # 更新现有的测试配置源
74 | self.config_sources[0] = TestConfigSource(self._test_overrides)
75 | else:
76 | # 插入新的测试配置源
77 | self.config_sources.insert(0, TestConfigSource(self._test_overrides))
78 |
79 | def set_test_config(self, key: str, value: Any) -> None:
80 | """
81 | 设置任意测试配置
82 |
83 | 参数:
84 | key: 配置键,支持嵌套(用.分隔)
85 | value: 配置值
86 |
87 | 示例:
88 | config.set_test_config('model', 'gemini-2.5-flash')
89 | config.set_test_config('databases.test.type', 'sqlite')
90 | """
91 | keys = key.split('.')
92 | current = self._test_overrides
93 |
94 | # 创建嵌套结构
95 | for k in keys[:-1]:
96 | if k not in current:
97 | current[k] = {}
98 | current = current[k]
99 |
100 | # 设置最终值
101 | current[keys[-1]] = value
102 |
103 | # 更新配置源
104 | if self.config_sources and isinstance(self.config_sources[0], TestConfigSource):
105 | self.config_sources[0] = TestConfigSource(self._test_overrides)
106 | else:
107 | self.config_sources.insert(0, TestConfigSource(self._test_overrides))
108 |
109 | def get_test_config(self, key: str) -> Optional[Any]:
110 | """
111 | 获取测试配置
112 |
113 | 参数:
114 | key: 配置键
115 |
116 | 返回:
117 | 配置值,如果不存在返回 None
118 | """
119 | return self._test_overrides.get(key)
120 |
121 | def get_test_overrides(self) -> Dict[str, Any]:
122 | """获取当前的测试覆盖配置(用于调试)"""
123 | return self._test_overrides.copy()
124 |
125 | def clear_test_overrides(self) -> None:
126 | """清除所有测试覆盖配置"""
127 | self._test_overrides.clear()
128 | # 移除测试配置源
129 | if self.config_sources and isinstance(self.config_sources[0], TestConfigSource):
130 | self.config_sources.pop(0)
131 |
132 | @classmethod
133 | def create_with_sqlite_database(cls, db_path: str, database_name: str = 'default') -> 'TestDatabaseConfig':
134 | """
135 | 便捷方法:创建带有SQLite数据库的测试配置
136 |
137 | 参数:
138 | db_path: SQLite数据库文件路径
139 | database_name: 数据库名称
140 |
141 | 返回:
142 | 配置好的TestDatabaseConfig实例
143 | """
144 | config = cls()
145 | config.set_test_database(database_name, {
146 | 'type': 'sqlite',
147 | 'database': db_path
148 | })
149 | return config
150 |
151 | @classmethod
152 | def create_with_memory_database(cls, database_name: str = 'default') -> 'TestDatabaseConfig':
153 | """
154 | 便捷方法:创建带有内存SQLite数据库的测试配置
155 |
156 | 参数:
157 | database_name: 数据库名称
158 |
159 | 返回:
160 | 配置好的TestDatabaseConfig实例
161 | """
162 | config = cls()
163 | config.set_test_database(database_name, {
164 | 'type': 'sqlite',
165 | 'database': ':memory:'
166 | })
167 | return config
--------------------------------------------------------------------------------
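
A short sketch of the convenience API above for pointing tests at an in-memory SQLite database. The `model` value is just an example, and it is assumed the base `DatabaseConfig` initializer needs no extra setup.

```python
from dbrheo.config.test_config import TestDatabaseConfig

# In-memory SQLite database registered under the 'default' name
config = TestDatabaseConfig.create_with_memory_database()

# Override any other setting with highest priority, e.g. the model name
config.set_test_config("model", "gemini-2.5-flash")

print(config.get_test_overrides())
# -> {'databases': {'default': {'type': 'sqlite', 'database': ':memory:'}},
#     'model': 'gemini-2.5-flash'}
```
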
/testdata/adult.names:
--------------------------------------------------------------------------------
1 | | This data was extracted from the census bureau database found at
2 | | http://www.census.gov/ftp/pub/DES/www/welcome.html
3 | | Donor: Ronny Kohavi and Barry Becker,
4 | | Data Mining and Visualization
5 | | Silicon Graphics.
6 | | e-mail: ronnyk@sgi.com for questions.
7 | | Split into train-test using MLC++ GenCVFiles (2/3, 1/3 random).
8 | | 48842 instances, mix of continuous and discrete (train=32561, test=16281)
9 | | 45222 if instances with unknown values are removed (train=30162, test=15060)
10 | | Duplicate or conflicting instances : 6
11 | | Class probabilities for adult.all file
12 | | Probability for the label '>50K' : 23.93% / 24.78% (without unknowns)
13 | | Probability for the label '<=50K' : 76.07% / 75.22% (without unknowns)
14 | |
15 | | Extraction was done by Barry Becker from the 1994 Census database. A set of
16 | | reasonably clean records was extracted using the following conditions:
17 | | ((AAGE>16) && (AGI>100) && (AFNLWGT>1)&& (HRSWK>0))
18 | |
19 | | Prediction task is to determine whether a person makes over 50K
20 | | a year.
21 | |
22 | | First cited in:
23 | | @inproceedings{kohavi-nbtree,
24 | | author={Ron Kohavi},
25 | | title={Scaling Up the Accuracy of Naive-Bayes Classifiers: a
26 | | Decision-Tree Hybrid},
27 | | booktitle={Proceedings of the Second International Conference on
28 | | Knowledge Discovery and Data Mining},
29 | | year = 1996,
30 | | pages={to appear}}
31 | |
32 | | Error Accuracy reported as follows, after removal of unknowns from
33 | | train/test sets):
34 | | C4.5 : 84.46+-0.30
35 | | Naive-Bayes: 83.88+-0.30
36 | | NBTree : 85.90+-0.28
37 | |
38 | |
39 | | Following algorithms were later run with the following error rates,
40 | | all after removal of unknowns and using the original train/test split.
41 | | All these numbers are straight runs using MLC++ with default values.
42 | |
43 | | Algorithm Error
44 | | -- ---------------- -----
45 | | 1 C4.5 15.54
46 | | 2 C4.5-auto 14.46
47 | | 3 C4.5 rules 14.94
48 | | 4 Voted ID3 (0.6) 15.64
49 | | 5 Voted ID3 (0.8) 16.47
50 | | 6 T2 16.84
51 | | 7 1R 19.54
52 | | 8 NBTree 14.10
53 | | 9 CN2 16.00
54 | | 10 HOODG 14.82
55 | | 11 FSS Naive Bayes 14.05
56 | | 12 IDTM (Decision table) 14.46
57 | | 13 Naive-Bayes 16.12
58 | | 14 Nearest-neighbor (1) 21.42
59 | | 15 Nearest-neighbor (3) 20.35
60 | | 16 OC1 15.04
61 | | 17 Pebls Crashed. Unknown why (bounds WERE increased)
62 | |
63 | | Conversion of original data as follows:
64 | | 1. Discretized agrossincome into two ranges with threshold 50,000.
65 | | 2. Convert U.S. to US to avoid periods.
66 | | 3. Convert Unknown to "?"
67 | | 4. Run MLC++ GenCVFiles to generate data,test.
68 | |
69 | | Description of fnlwgt (final weight)
70 | |
71 | | The weights on the CPS files are controlled to independent estimates of the
72 | | civilian noninstitutional population of the US. These are prepared monthly
73 | | for us by Population Division here at the Census Bureau. We use 3 sets of
74 | | controls.
75 | | These are:
76 | | 1. A single cell estimate of the population 16+ for each state.
77 | | 2. Controls for Hispanic Origin by age and sex.
78 | | 3. Controls by Race, age and sex.
79 | |
80 | | We use all three sets of controls in our weighting program and "rake" through
81 | | them 6 times so that by the end we come back to all the controls we used.
82 | |
83 | | The term estimate refers to population totals derived from CPS by creating
84 | | "weighted tallies" of any specified socio-economic characteristics of the
85 | | population.
86 | |
87 | | People with similar demographic characteristics should have
88 | | similar weights. There is one important caveat to remember
89 | | about this statement. That is that since the CPS sample is
90 | | actually a collection of 51 state samples, each with its own
91 | | probability of selection, the statement only applies within
92 | | state.
93 |
94 |
95 | >50K, <=50K.
96 |
97 | age: continuous.
98 | workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
99 | fnlwgt: continuous.
100 | education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
101 | education-num: continuous.
102 | marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
103 | occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
104 | relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
105 | race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
106 | sex: Female, Male.
107 | capital-gain: continuous.
108 | capital-loss: continuous.
109 | hours-per-week: continuous.
110 | native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
111 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/telemetry/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | DatabaseLogger - structured logging system.
3 | Fully aligned with the Gemini CLI logging mechanism; supports structured logs and OpenTelemetry integration.
4 | """
5 |
6 | import json
7 | import logging
8 | import sys
9 | from typing import Dict, Any, Optional
10 | from datetime import datetime
11 | from pathlib import Path
12 |
13 | from ..config.base import DatabaseConfig
14 |
15 |
16 | def get_logger(name: str) -> logging.Logger:
17 | """
18 | 获取标准日志器(最小侵入性添加)
19 |
20 | Args:
21 | name: 日志器名称
22 |
23 | Returns:
24 | 标准 Python 日志器
25 | """
26 | return logging.getLogger(name)
27 |
28 |
29 | class DatabaseLogger:
30 | """
31 | Structured logging system for the database agent
32 | - fully aligned with the Gemini CLI logging mechanism
33 | - supports JSON and text formats
34 | - integrates OpenTelemetry trace information
35 | - supports multiple output targets
36 | """
37 |
38 | def __init__(self, config: DatabaseConfig):
39 | self.config = config
40 | self.service_name = config.get("service_name", "database-agent")
41 | self.log_level = config.get("log_level", "INFO")
42 | self.log_format = config.get("log_format", "text") # text | json
43 |
44 | # 设置日志器
45 | self.logger = self._setup_logger()
46 |
47 | def _setup_logger(self) -> logging.Logger:
48 | """设置日志器配置"""
49 | logger = logging.getLogger(self.service_name)
50 | logger.setLevel(getattr(logging, self.log_level.upper()))
51 |
52 | # 清除现有处理器
53 | logger.handlers.clear()
54 |
55 | # 控制台处理器
56 | console_handler = logging.StreamHandler(sys.stdout)
57 | console_handler.setLevel(getattr(logging, self.log_level.upper()))
58 |
59 | # 设置格式器
60 | if self.log_format == "json":
61 | formatter = JsonFormatter(self.service_name)
62 | else:
63 | formatter = TextFormatter()
64 |
65 | console_handler.setFormatter(formatter)
66 | logger.addHandler(console_handler)
67 |
68 | # 文件处理器(如果配置了)
69 | log_file = self.config.get("log_file")
70 | if log_file:
71 | file_handler = logging.FileHandler(log_file)
72 | file_handler.setLevel(getattr(logging, self.log_level.upper()))
73 | file_handler.setFormatter(formatter)
74 | logger.addHandler(file_handler)
75 |
76 | return logger
77 |
78 | def debug(self, message: str, **kwargs):
79 | """调试日志"""
80 | self._log(logging.DEBUG, message, **kwargs)
81 |
82 | def info(self, message: str, **kwargs):
83 | """信息日志"""
84 | self._log(logging.INFO, message, **kwargs)
85 |
86 | def warning(self, message: str, **kwargs):
87 | """警告日志"""
88 | self._log(logging.WARNING, message, **kwargs)
89 |
90 | def error(self, message: str, **kwargs):
91 | """错误日志"""
92 | self._log(logging.ERROR, message, **kwargs)
93 |
94 | def critical(self, message: str, **kwargs):
95 | """严重错误日志"""
96 | self._log(logging.CRITICAL, message, **kwargs)
97 |
98 | def _log(self, level: int, message: str, **kwargs):
99 | """内部日志方法"""
100 | # 添加追踪信息
101 | extra = {
102 | "service": self.service_name,
103 | "timestamp": datetime.utcnow().isoformat(),
104 | **kwargs
105 | }
106 |
107 | # 添加OpenTelemetry追踪信息(如果可用)
108 | try:
109 | from opentelemetry import trace
110 | current_span = trace.get_current_span()
111 | if current_span:
112 | span_context = current_span.get_span_context()
113 | extra.update({
114 | "trace_id": format(span_context.trace_id, "032x"),
115 | "span_id": format(span_context.span_id, "016x")
116 | })
117 | except ImportError:
118 | pass
119 |
120 | self.logger.log(level, message, extra=extra)
121 |
122 |
123 | class JsonFormatter(logging.Formatter):
124 | """JSON格式化器 - 完全对齐Gemini CLI的JSON日志格式"""
125 |
126 | def __init__(self, service_name: str):
127 | super().__init__()
128 | self.service_name = service_name
129 |
130 | def format(self, record: logging.LogRecord) -> str:
131 | """格式化日志记录为JSON"""
132 | log_entry = {
133 | "timestamp": datetime.fromtimestamp(record.created).isoformat(),
134 | "level": record.levelname,
135 | "service": self.service_name,
136 | "message": record.getMessage(),
137 | "module": record.module,
138 | "function": record.funcName,
139 | "line": record.lineno
140 | }
141 |
142 | # 添加额外字段
143 | if hasattr(record, 'service'):
144 | log_entry["service"] = record.service
145 | if hasattr(record, 'trace_id'):
146 | log_entry["trace_id"] = record.trace_id
147 | if hasattr(record, 'span_id'):
148 | log_entry["span_id"] = record.span_id
149 |
150 | # 添加其他自定义字段
151 | for key, value in record.__dict__.items():
152 | if key not in ['name', 'msg', 'args', 'levelname', 'levelno', 'pathname',
153 | 'filename', 'module', 'lineno', 'funcName', 'created',
154 | 'msecs', 'relativeCreated', 'thread', 'threadName',
155 | 'processName', 'process', 'getMessage', 'exc_info',
156 | 'exc_text', 'stack_info', 'message', 'service',
157 | 'timestamp', 'trace_id', 'span_id']:
158 | log_entry[key] = value
159 |
160 | return json.dumps(log_entry, ensure_ascii=False)
161 |
162 |
163 | class TextFormatter(logging.Formatter):
164 | """文本格式化器 - 人类可读的日志格式"""
165 |
166 | def __init__(self):
167 | super().__init__(
168 | fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
169 | datefmt='%Y-%m-%d %H:%M:%S'
170 | )
171 |
--------------------------------------------------------------------------------
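
A small sketch of constructing the structured logger with JSON output. It reuses `TestDatabaseConfig` from the earlier config file to inject the `service_name`, `log_level`, and `log_format` keys read above; the extra fields passed to `info()`/`error()` are arbitrary examples.

```python
from dbrheo.config.test_config import TestDatabaseConfig
from dbrheo.telemetry.logger import DatabaseLogger

config = TestDatabaseConfig(test_overrides={
    "service_name": "database-agent",
    "log_level": "DEBUG",
    "log_format": "json",
})

logger = DatabaseLogger(config)
logger.info("query executed", rows=10, duration_ms=12.5)
logger.error("connection lost", database="default")
```
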
/packages/cli/src/dbrheo_cli/ui/tools.py:
--------------------------------------------------------------------------------
1 | """
2 | Tool display components.
3 | Renders tool execution status, parameters, results, and more:
4 | - ToolStatus: tool status indicator
5 | - ToolResult: tool result display
6 | - ToolConfirmation: confirmation dialog
7 |
8 | Corresponds to Gemini CLI's ToolMessage and ToolConfirmationMessage.
9 | """
10 |
11 | from typing import Dict, Any, Optional
12 | import json
13 | from .console import console
14 | from ..i18n import _
15 |
16 |
17 | def get_status_indicator(status: str) -> str:
18 | """获取状态指示器"""
19 | # 映射后端实际的状态值到显示文本
20 | status_map = {
21 | # 后端状态 -> i18n key
22 | 'validating': 'status_pending',
23 | 'scheduled': 'status_pending',
24 | 'awaiting_approval': 'status_confirm',
25 | 'executing': 'status_running',
26 | 'success': 'status_success',
27 | 'error': 'status_error',
28 | 'cancelled': 'status_cancelled',
29 | # 兼容前端可能的状态名
30 | 'pending': 'status_pending',
31 | 'approved': 'status_approved',
32 | 'completed': 'status_success',
33 | 'failed': 'status_error',
34 | 'rejected': 'status_cancelled'
35 | }
36 | return _(status_map.get(status, 'status_unknown'))
37 |
38 | # 风险级别颜色
39 | RISK_COLORS = {
40 | 'low': 'green',
41 | 'medium': 'yellow',
42 | 'high': 'red'
43 | }
44 |
45 |
46 | def show_tool_status(tool_name: str, status: str):
47 | """显示工具状态"""
48 | indicator = get_status_indicator(status)
49 |
50 | # 根据状态选择颜色
51 | if status in ['success', 'completed', 'approved']:
52 | color = 'success'
53 | elif status in ['error', 'failed', 'rejected', 'cancelled']:
54 | color = 'error'
55 | elif status in ['executing']:
56 | color = 'info'
57 | elif status in ['awaiting_approval']:
58 | color = 'warning'
59 | elif status in ['validating', 'scheduled', 'pending']:
60 | color = 'dim'
61 | else:
62 | color = 'dim'
63 |
64 | console.print(f"[{color}]{indicator} {tool_name}[/{color}]")
65 |
66 |
67 | def show_tool_result(tool_name: str, result: Any):
68 | """显示工具执行结果"""
69 | console.print(f"\n[info]→ [{tool_name}] {_('tool_result')}:[/info]")
70 |
71 | # 根据结果类型进行不同的显示
72 | if isinstance(result, dict):
73 | # JSON格式化显示
74 | try:
75 | result_str = json.dumps(result, indent=2, ensure_ascii=False)
76 | console.print(result_str)
77 | except Exception:  # fall back to plain text if the result is not JSON-serializable
78 | console.print(str(result))
79 | elif isinstance(result, list):
80 | # 列表显示
81 | for item in result[:10]: # 最多显示10项
82 | console.print(f" • {item}")
83 | if len(result) > 10:
84 | console.print(f" {_('more_items', count=len(result) - 10)}")
85 | else:
86 | # 普通文本
87 | console.print(str(result))
88 |
89 |
90 | def show_confirmation_prompt(tool_name: str, args: Dict[str, Any],
91 | risk_level: str = 'low',
92 | risk_description: str = ''):
93 | """显示工具确认提示"""
94 | from rich.panel import Panel
95 | from rich.text import Text
96 | from rich.columns import Columns
97 |
98 | console.print()
99 |
100 | # 构建确认内容
101 | content_lines = []
102 |
103 | # 添加风险级别
104 | risk_color = RISK_COLORS.get(risk_level, 'yellow')
105 | content_lines.append(f"{_('risk_level')}: [{risk_color}]{risk_level.upper()}[/{risk_color}]")
106 |
107 | # 添加风险描述
108 | if risk_description:
109 | content_lines.append(f"{_('risk_description')}: {risk_description}")
110 |
111 | # 添加参数
112 | if args:
113 | content_lines.append(f"\n{_('parameters')}:")
114 | for key, value in args.items():
115 | value_str = str(value)
116 |
117 | # 对于代码相关的参数,使用语法高亮而不是截断
118 | if key.lower() in ['code', 'sql', 'query', 'script', 'command']:
119 | content_lines.append(f" • {key}:")
120 |
121 | # 检测语言类型
122 | if key.lower() in ['sql', 'query']:
123 | lang = 'sql'
124 | elif key.lower() == 'code':
125 | lang = 'python' # 默认Python
126 | elif key.lower() in ['script', 'command']:
127 | lang = 'bash'
128 | else:
129 | lang = 'text'
130 |
131 | # 使用语法高亮显示代码
132 | from rich.syntax import Syntax
133 | syntax = Syntax(value_str, lang, theme="monokai", line_numbers=False, word_wrap=True)
134 | # 将语法高亮对象转为字符串添加到内容中
135 | import io
136 | from rich.console import Console as TempConsole
137 | buffer = io.StringIO()
138 | temp_console = TempConsole(file=buffer, force_terminal=True)
139 | temp_console.print(syntax)
140 | content_lines.append(buffer.getvalue().rstrip())
141 | else:
142 | # 其他参数可以截断
143 | if len(value_str) > 200:
144 | value_str = value_str[:197] + "..."
145 | content_lines.append(f" • {key}: {value_str}")
146 |
147 | # 创建主内容面板
148 | main_content = "\n".join(content_lines)
149 |
150 | # 根据风险级别选择边框颜色
151 | border_color = RISK_COLORS.get(risk_level, 'yellow')
152 |
153 | # 创建并显示主面板
154 | main_panel = Panel(
155 | main_content,
156 | title=f"[bold]{_('tool_confirm_title', tool_name=tool_name)}[/bold]",
157 | border_style=border_color,
158 | padding=(1, 2)
159 | )
160 | console.print(main_panel)
161 |
162 | # 显示操作选项(带框)
163 | options_content = [
164 | f"[green]1[/green] / confirm • {_('confirm_execute')}",
165 | f"[red]2[/red] / cancel • {_('cancel_execute')}",
166 | f"confirm all • {_('confirm_all_tools')}"
167 | ]
168 |
169 | options_panel = Panel(
170 | "\n".join(options_content),
171 | title=f"[bold]{_('please_input')}[/bold]",
172 | subtitle=f"[dim]{_('input_halfwidth_hint')}[/dim]",
173 | border_style="cyan",
174 | padding=(0, 2)
175 | )
176 | console.print(options_panel)
177 | console.print()
178 |
179 |
180 | def show_tool_error(tool_name: str, error: str):
181 | """显示工具错误"""
182 | console.print(f"[error]✗ [{tool_name}] {_('tool_failed', error=error)}[/error]")
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/api/routes/websocket.py:
--------------------------------------------------------------------------------
1 | """
2 | WebSocket routes - provide real-time communication.
3 | Supports streaming conversations, tool execution status updates, etc.
4 | """
5 |
6 | from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
7 | import json
8 | import asyncio
9 | from typing import Dict, Set
10 |
11 | from ...types.core_types import SimpleAbortSignal
12 | from ..dependencies import get_client
13 |
14 | websocket_router = APIRouter()
15 |
16 | # 活跃的WebSocket连接
17 | active_connections: Dict[str, WebSocket] = {}
18 | session_connections: Dict[str, Set[str]] = {}
19 |
20 |
21 | class ConnectionManager:
22 | """WebSocket连接管理器"""
23 |
24 | def __init__(self):
25 | self.active_connections: Dict[str, WebSocket] = {}
26 |
27 | async def connect(self, websocket: WebSocket, connection_id: str):
28 | """接受WebSocket连接"""
29 | await websocket.accept()
30 | self.active_connections[connection_id] = websocket
31 |
32 | def disconnect(self, connection_id: str):
33 | """断开WebSocket连接"""
34 | if connection_id in self.active_connections:
35 | del self.active_connections[connection_id]
36 |
37 | async def send_message(self, connection_id: str, message: dict):
38 | """发送消息到指定连接"""
39 | if connection_id in self.active_connections:
40 | websocket = self.active_connections[connection_id]
41 | try:
42 | await websocket.send_text(json.dumps(message, ensure_ascii=False))
43 | except Exception:
44 | # Connection already closed; clean up
45 | self.disconnect(connection_id)
46 |
47 | async def broadcast(self, message: dict):
48 | """广播消息到所有连接"""
49 | for connection_id in list(self.active_connections.keys()):
50 | await self.send_message(connection_id, message)
51 |
52 |
53 | manager = ConnectionManager()
54 |
55 |
56 | @websocket_router.websocket("/chat/{session_id}")
57 | async def websocket_chat(
58 | websocket: WebSocket,
59 | session_id: str,
60 | client = Depends(get_client)
61 | ):
62 | """
63 | WebSocket聊天接口
64 | 提供实时的对话交互和工具执行状态更新
65 | """
66 | connection_id = f"{session_id}_{id(websocket)}"
67 |
68 | await manager.connect(websocket, connection_id)
69 |
70 | try:
71 | # 发送连接确认
72 | await manager.send_message(connection_id, {
73 | "type": "connection",
74 | "status": "connected",
75 | "session_id": session_id,
76 | "message": "WebSocket连接已建立"
77 | })
78 |
79 | while True:
80 | # 接收客户端消息
81 | data = await websocket.receive_text()
82 | message_data = json.loads(data)
83 |
84 | message_type = message_data.get("type", "chat")
85 |
86 | if message_type == "chat":
87 | # 处理聊天消息
88 | await handle_chat_message(
89 | connection_id,
90 | session_id,
91 | message_data,
92 | client
93 | )
94 | elif message_type == "ping":
95 | # 心跳检测
96 | await manager.send_message(connection_id, {
97 | "type": "pong",
98 | "timestamp": message_data.get("timestamp")
99 | })
100 | elif message_type == "abort":
101 | # 中止当前操作
102 | await manager.send_message(connection_id, {
103 | "type": "aborted",
104 | "message": "操作已中止"
105 | })
106 |
107 | except WebSocketDisconnect:
108 | manager.disconnect(connection_id)
109 | except Exception as e:
110 | await manager.send_message(connection_id, {
111 | "type": "error",
112 | "error": str(e)
113 | })
114 | manager.disconnect(connection_id)
115 |
116 |
117 | async def handle_chat_message(
118 | connection_id: str,
119 | session_id: str,
120 | message_data: dict,
121 | client
122 | ):
123 | """处理聊天消息"""
124 | try:
125 | message = message_data.get("message", "")
126 |
127 | # 发送开始处理的通知
128 | await manager.send_message(connection_id, {
129 | "type": "processing",
130 | "message": "正在处理您的请求..."
131 | })
132 |
133 | # 创建中止信号
134 | signal = SimpleAbortSignal()
135 |
136 | # 发送消息并获取流式响应
137 | response_stream = client.send_message_stream(
138 | request=message,
139 | signal=signal,
140 | prompt_id=session_id,
141 | turns=100
142 | )
143 |
144 | # 流式发送响应
145 | async for chunk in response_stream:
146 | await manager.send_message(connection_id, {
147 | "type": "stream",
148 | "chunk": chunk
149 | })
150 |
151 | # 发送完成通知
152 | await manager.send_message(connection_id, {
153 | "type": "complete",
154 | "message": "响应完成"
155 | })
156 |
157 | except Exception as e:
158 | await manager.send_message(connection_id, {
159 | "type": "error",
160 | "error": str(e)
161 | })
162 |
163 |
164 | @websocket_router.websocket("/tools/{session_id}")
165 | async def websocket_tools(
166 | websocket: WebSocket,
167 | session_id: str
168 | ):
169 | """
170 | 工具执行状态WebSocket
171 | 实时更新工具执行状态和结果
172 | """
173 | connection_id = f"tools_{session_id}_{id(websocket)}"
174 |
175 | await manager.connect(websocket, connection_id)
176 |
177 | try:
178 | await manager.send_message(connection_id, {
179 | "type": "connection",
180 | "status": "connected",
181 | "message": "工具状态监听已建立"
182 | })
183 |
184 | # 保持连接活跃
185 | while True:
186 | data = await websocket.receive_text()
187 | # 这里可以处理工具相关的命令
188 |
189 | except WebSocketDisconnect:
190 | manager.disconnect(connection_id)
191 | except Exception as e:
192 | await manager.send_message(connection_id, {
193 | "type": "error",
194 | "error": str(e)
195 | })
196 | manager.disconnect(connection_id)
197 |
--------------------------------------------------------------------------------
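
The message protocol above (the client sends `{"type": "chat", "message": ...}`; the server streams back `processing`, `stream`, and `complete` frames) can be exercised with a small client. This is an illustrative sketch using the third-party `websockets` package; the host, port, and `/ws` prefix are assumptions, since the actual mount path depends on how `websocket_router` is included in the FastAPI app.

```python
import asyncio
import json

import websockets  # third-party: pip install websockets


async def chat(session_id: str, message: str) -> None:
    # Assumed URL; adjust to wherever the router is mounted in app.py
    url = f"ws://localhost:8000/ws/chat/{session_id}"
    async with websockets.connect(url) as ws:
        print(json.loads(await ws.recv()))  # connection confirmation frame

        await ws.send(json.dumps({"type": "chat", "message": message}))

        while True:
            frame = json.loads(await ws.recv())
            if frame["type"] == "stream":
                print("chunk:", frame["chunk"])
            elif frame["type"] in ("complete", "error"):
                print(frame)
                break


asyncio.run(chat("demo-session", "Show me the latest 10 users"))
```
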
/packages/core/src/dbrheo/prompts.py:
--------------------------------------------------------------------------------
1 | """
2 | Prompt management system.
3 | Modeled on Gemini CLI's layered prompt design.
4 | """
5 |
6 | import datetime
7 | from typing import Dict, Optional, List
8 | from .config.base import DatabaseConfig
9 | from .prompts.database_agent_prompt import get_database_agent_prompt, get_tool_guidance
10 |
11 | # 东京时区常量
12 | TOKYO_TZ = datetime.timezone(datetime.timedelta(hours=9))
13 |
14 |
15 | class PromptManager:
16 | """
17 | Prompt manager - layered system
18 | Priority: user-defined > workspace configuration > system defaults
19 | """
20 |
21 | def __init__(self, config: DatabaseConfig):
22 | self.config = config
23 | self._cache = {}
24 |
25 | def get_system_prompt(self, context: Optional[Dict] = None) -> str:
26 | """
27 | 获取系统提示词
28 | 支持上下文感知和动态调整
29 | """
30 | # 检查用户自定义提示词
31 | custom_prompt = self.config.get("custom_system_prompt")
32 | if custom_prompt:
33 | return self._process_template(custom_prompt, context)
34 |
35 | # 使用默认的数据库Agent提示词
36 | base_prompt = get_database_agent_prompt()
37 |
38 | # 添加当前时间(东京时间)
39 | tokyo_time = datetime.datetime.now(TOKYO_TZ)
40 | base_prompt += f"\n\nCurrent Tokyo time: {tokyo_time.strftime('%Y-%m-%d %H:%M:%S JST')}"
41 |
42 | # 添加语言提示
43 | # 检查当前语言设置
44 | if hasattr(self.config, 'get') and self.config.get('i18n'):
45 | i18n = self.config.get('i18n')
46 | if isinstance(i18n, dict) and 'current_lang' in i18n:
47 | current_lang = i18n['current_lang']()
48 | if current_lang == 'ja_JP':
49 | base_prompt += "\n日本語で応答する際は、中国語を混在させず、専門用語は正確に、自然な日本語表現を使用してください。"
50 | elif current_lang == 'zh_CN':
51 | base_prompt += "\n使用中文回复时,请使用规范的简体中文和准确的技术术语。"
52 | elif current_lang == 'en_US':
53 | base_prompt += "\nUse clear, professional English with accurate technical terminology."
54 |
55 | # 添加上下文特定的指导
56 | if context:
57 | additional_guidance = self._get_contextual_guidance(context)
58 | if additional_guidance:
59 | base_prompt += f"\n\n## Current Context\n{additional_guidance}"
60 |
61 | return base_prompt
62 |
63 | def get_tool_prompt(self, tool_name: str) -> str:
64 | """获取工具特定的提示词"""
65 | # 检查缓存
66 | if tool_name in self._cache:
67 | return self._cache[tool_name]
68 |
69 | # 获取工具指导
70 | guidance = get_tool_guidance(tool_name)
71 |
72 | # 添加用户自定义的工具提示
73 | custom_tool_prompts = self.config.get("tool_prompts", {})
74 | if tool_name in custom_tool_prompts:
75 | guidance = custom_tool_prompts[tool_name] + "\n\n" + guidance
76 |
77 | self._cache[tool_name] = guidance
78 | return guidance
79 |
80 | def get_next_speaker_prompt(self) -> str:
81 | """获取next_speaker判断的提示词"""
82 | return """Based on the conversation history and the last message, determine who should speak next.
83 |
84 | Rules:
85 | 1. If the last message was a tool execution result (function response), return "model" to process the result
86 | 2. If the model asked a question that needs user input, return "user"
87 | 3. If the model indicated it will perform more actions, return "model"
88 | 4. If the task is complete and waiting for new instructions, return "user"
89 |
90 | Respond with a JSON object: {"next_speaker": "model" or "user", "reasoning": "brief explanation"}"""
91 |
92 | def _get_contextual_guidance(self, context: Dict) -> str:
93 | """基于上下文生成额外指导"""
94 | guidance_parts = []
95 |
96 | # 数据库连接信息
97 | if 'database_type' in context:
98 | db_type = context['database_type']
99 | guidance_parts.append(f"You are connected to a {db_type} database.")
100 |
101 | # 已发现的表
102 | if 'discovered_tables' in context:
103 | tables = context['discovered_tables']
104 | if tables:
105 | guidance_parts.append(f"Previously discovered tables: {', '.join(tables[:10])}")
106 |
107 | # 当前任务类型
108 | if 'task_type' in context:
109 | task_type = context['task_type']
110 | if task_type == 'exploration':
111 | guidance_parts.append("Focus on understanding the database structure and relationships.")
112 | elif task_type == 'analysis':
113 | guidance_parts.append("Focus on extracting insights and patterns from the data.")
114 | elif task_type == 'modification':
115 | guidance_parts.append("Be extra careful with data modifications. Always verify impact.")
116 |
117 | return "\n".join(guidance_parts)
118 |
119 | def _process_template(self, template: str, context: Optional[Dict]) -> str:
120 | """处理提示词模板中的变量"""
121 | if not context:
122 | return template
123 |
124 | # 简单的变量替换
125 | for key, value in context.items():
126 | template = template.replace(f"{{{{{key}}}}}", str(value))
127 |
128 | return template
129 |
130 |
131 | class PromptLibrary:
132 | """
133 | 提示词库 - 存储常用提示词模板
134 | """
135 |
136 | # 错误恢复提示词
137 | ERROR_RECOVERY = """The previous operation failed with error: {error}
138 |
139 | Analyze the error and try an alternative approach. Consider:
140 | 1. Syntax issues in the SQL
141 | 2. Missing tables or columns
142 | 3. Permission problems
143 | 4. Data type mismatches
144 |
145 | Provide a clear explanation and attempt a different solution."""
146 |
147 | # 性能优化提示词
148 | PERFORMANCE_OPTIMIZATION = """The query is taking too long or consuming too many resources.
149 |
150 | Consider these optimization strategies:
151 | 1. Add appropriate indexes
152 | 2. Limit the result set
153 | 3. Use more efficient JOIN strategies
154 | 4. Partition large tables
155 | 5. Pre-aggregate data
156 |
157 | Suggest specific improvements for this query."""
158 |
159 | # 数据探索提示词
160 | DATA_EXPLORATION = """Help the user explore and understand their database.
161 |
162 | Start with:
163 | 1. Overview of available tables
164 | 2. Identify key business entities
165 | 3. Discover relationships between tables
166 | 4. Highlight interesting patterns or anomalies
167 |
168 | Guide them through progressive discovery."""
--------------------------------------------------------------------------------
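
A brief sketch of assembling a context-aware system prompt with `PromptManager`. The context values are examples, and it is assumed that the package's prompt modules import cleanly and that the test configuration shown earlier is sufficient to construct a config object.

```python
from dbrheo.config.test_config import TestDatabaseConfig
from dbrheo.prompts import PromptManager

config = TestDatabaseConfig.create_with_memory_database()
manager = PromptManager(config)

system_prompt = manager.get_system_prompt(context={
    "database_type": "sqlite",
    "discovered_tables": ["users", "orders"],
    "task_type": "exploration",
})
print(system_prompt[-300:])  # tail includes the "Current Context" guidance
```
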
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # DbRheoCLI - Database/Data Analysis AI Agent
3 |
4 |
5 |
6 | DbRheo is a database operations and data analysis CLI agent that provides natural language database query execution, schema exploration, risk assessment capabilities, and Python-powered data analysis features.
7 |
8 |
9 | ## Quick Start
10 |
11 | ```bash
12 | # 1. Clone the repository
13 | git clone https://github.com/Din829/DbRheo-CLI.git
14 | cd DbRheo-CLI
15 |
16 | # 2. Install dependencies
17 | pip install -r requirements.txt
18 |
19 |
20 | # 3. Environment setup
21 | cp .env.example .env
22 | # Set either GOOGLE_API_KEY or OPENAI_API_KEY in the .env file
23 | # No need to modify other contents in .env.example
24 | # Claude models are not recommended at this time as PromptCaching is not yet applied
25 |
26 |
27 | # 4. Launch CLI
28 | cd packages/cli
29 | python cli.py
30 | ```
31 |
32 | ## Key Features
33 |
34 | ### Core Capabilities
35 | - **Natural Language Query Processing**: Database operation instructions in natural language
36 | - **Intelligent SQL Generation**: Automatic generation of safe and optimized queries
37 | - **Automatic Schema Discovery**: Dynamic analysis of database structures
38 | - **Risk Assessment System**: Pre-detection and warnings for dangerous operations
39 | - **Python Code Execution**: Data analysis, visualization, and automation script execution
40 | - **Data Export**: Result output in CSV, JSON, and Excel formats
41 |
42 | ### Technical Features
43 | - **Asynchronous Processing**: High-performance async/await implementation
44 | - **Multi-Database Support**: PostgreSQL, MySQL, SQLite compatibility
45 | - **Modular Design**: Extensible plugin architecture
46 | - **Comprehensive Logging**: Detailed operation history and debug information
47 | - **Intelligent Input**: Automatic multi-line detection and paste processing
48 | - **Streaming Output**: Real-time response display
49 | - **Internationalization**: Multi-language support (Japanese, English)
50 |
51 | ## System Requirements
52 |
53 | ### Required Environment
54 | - Python 3.9 or higher
55 | - Node.js 20 or higher (only for Web UI development)
56 |
57 | ### Supported Databases
58 | Currently supports three main database engines (more can be added via the adapter interface):
59 | - **PostgreSQL** 12+ (via asyncpg driver)
60 | - **MySQL/MariaDB** 8.0+ (via aiomysql driver)
61 | - **SQLite** 3.35+ (via aiosqlite driver)
62 |
63 | *Note: Additional database types can be easily integrated through the adapter factory pattern. The system supports dynamic adapter registration and automatic driver detection.*
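
As a rough illustration of that extension point (names such as `DatabaseAdapter` and `register_adapter` are hypothetical; see `packages/core/src/dbrheo/adapters/base.py` and `connection_manager.py` for the actual interface):

```python
# Hypothetical sketch - the real base class and registration hook may differ.
from dbrheo.adapters.base import DatabaseAdapter              # assumed base class name
from dbrheo.adapters.connection_manager import DatabaseConnectionManager

class DuckDBAdapter(DatabaseAdapter):
    """Illustrative adapter for an additional engine."""

    async def execute_query(self, sql, params=None):
        ...  # delegate to an async driver for the new engine

# assumed registration hook; the project may instead use a factory or a config entry
DatabaseConnectionManager.register_adapter("duckdb", DuckDBAdapter)
```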
64 |
65 | ## Installation Guide
66 |
67 | ### 1. Clone the Repository
68 | ```bash
69 | git clone https://github.com/Din829/DbRheo-CLI.git
70 | cd DbRheo-CLI
71 | ```
72 |
73 | Alternatively, clone from Azure DevOps:
74 | https://dev.azure.com/HPSMDI/POC_Agent/_git/db-rheo-cli
75 |
76 |
77 | ### 2. Python Environment Setup
78 | ```bash
79 |
80 | # Install dependencies
81 | pip install -r requirements.txt
82 |
83 | ```
84 |
85 | ### 3. Package Installation (Optional)
86 | ```bash
87 | # Install core package
88 | cd packages/core
89 | pip install -e .
90 | cd ../..
91 |
92 | # Install CLI package
93 | cd packages/cli
94 | pip install -e .
95 | cd ../..
96 |
97 | # Verify installation
98 | pip show dbrheo-core dbrheo-cli
99 | ```
100 |
101 | **Note**: You can run directly in development mode without installing packages.
102 |
103 | ### 4. Environment Configuration
104 | ```bash
105 | # Copy configuration file
106 | cp .env.example .env
107 |
108 | # Edit the .env file and configure:
109 | # - Google or OpenAI API key
110 | # - Database connection information
111 | ```
112 |
113 | ### 5. Test Data
114 | The `testdata/` directory contains sample datasets for testing the agent:
115 | - **adult.data**: Adult Census Income dataset for data analysis testing
116 | - **adult.names**: Dataset description and column information
117 | - **adult.test**: Test dataset for validation
118 | - Additional sample files for various testing scenarios
119 |
120 | You can use these datasets to test DbRheo's data analysis capabilities and SQL generation features.
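
For a quick standalone check of the sample data (outside the agent), the files can be loaded directly with pandas; the column list below follows the standard Adult/Census schema described in `adult.names`:

```python
import pandas as pd

# Column names as documented in testdata/adult.names (standard Adult/Census schema).
columns = [
    "age", "workclass", "fnlwgt", "education", "education-num",
    "marital-status", "occupation", "relationship", "race", "sex",
    "capital-gain", "capital-loss", "hours-per-week", "native-country", "income",
]
df = pd.read_csv("testdata/adult.data", header=None, names=columns, skipinitialspace=True)
print(df["income"].value_counts())
```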
121 |
122 | ## Launch Methods
123 |
124 | ### CLI Mode Launch
125 |
126 | Launch the CLI from `packages/cli` with `python cli.py` (see Quick Start). Once the session is running, the following slash commands are available:
127 |
128 | ```bash
129 | # Display help
130 | /help
131 |
132 | # Specify model
133 | /model
134 | ```
135 |
136 |
137 |
138 |
139 | ## Usage Examples
140 |
141 | ### Basic Conversation Examples
142 | ```
143 | DbRheo> Tell me about the structure of the users table
144 | [Executing schema exploration...]
145 | Structure of table 'users':
146 | - id: INTEGER (Primary Key)
147 | - name: VARCHAR(100)
148 | - email: VARCHAR(255)
149 | - created_at: TIMESTAMP
150 |
151 | DbRheo> Show me the latest 10 users
152 | [Generating SQL query...]
153 | SELECT * FROM users ORDER BY created_at DESC LIMIT 10;
154 | [Displaying execution results...]
155 | ```
156 |
157 | ### Data Analysis Features
158 | ```
159 | DbRheo> Analyze and visualize sales data using Python
160 | [Generating Python code...]
161 | import pandas as pd
162 | import matplotlib.pyplot as plt
163 |
164 | # Retrieve sales data from database
165 | df = pd.read_sql("SELECT * FROM sales", connection)
166 |
167 | # Monthly sales aggregation
168 | monthly_sales = df.groupby('month')['amount'].sum()
169 |
170 | # Create graph
171 | plt.figure(figsize=(10, 6))
172 | monthly_sales.plot(kind='bar')
173 | plt.title('Monthly Sales Trends')
174 | plt.savefig('sales_analysis.png')
175 |
176 | [Execution result: Generated graph file sales_analysis.png]
177 | ```
178 |
179 | ### Advanced SQL Features
180 | ```
181 | DbRheo> Create monthly aggregation of sales data
182 | [Generating complex query...]
183 | SELECT
184 | DATE_TRUNC('month', order_date) as month,
185 | SUM(amount) as total_sales
186 | FROM orders
187 | GROUP BY month
188 | ORDER BY month;
189 | ```
190 |
191 | ### Testing with Sample Data
192 | ```
193 | DbRheo> Load the adult dataset from testdata and analyze income distribution
194 | [Loading data from testdata/adult.data...]
195 | [Generating analysis code...]
196 | import pandas as pd
197 |
198 | # Load the adult census dataset
199 | df = pd.read_csv('testdata/adult.data', header=None)
200 | # Apply column names from adult.names
201 | df.columns = ['age', 'workclass', 'fnlwgt', 'education', ...]
202 |
203 | # Analyze income distribution
204 | income_dist = df['income'].value_counts()
205 | print("Income Distribution:")
206 | print(income_dist)
207 |
208 | [Execution result: Income analysis completed]
209 | ```
210 |
211 |
212 |
213 |
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/api/routes/chat.py:
--------------------------------------------------------------------------------
1 | """
2 | Chat API routes - handle conversational interaction with the database agent
3 | Provides streaming chat, history management, and related features
4 | """
5 |
6 | from fastapi import APIRouter, HTTPException, Depends
7 | from fastapi.responses import StreamingResponse
8 | from pydantic import BaseModel
9 | from typing import List, Optional, Dict, Any
10 | import json
11 | import asyncio
12 |
13 | from ...types.core_types import SimpleAbortSignal
14 | from ..dependencies import get_client, get_config
15 |
16 | chat_router = APIRouter()
17 |
18 |
19 | class ChatMessage(BaseModel):
20 |     """Chat message model."""
21 | content: str
22 | role: str = "user"
23 |
24 |
25 | class ChatRequest(BaseModel):
26 |     """Chat request model."""
27 | message: str
28 | session_id: Optional[str] = None
29 | context: Optional[Dict[str, Any]] = None
30 |
31 |
32 | class ChatResponse(BaseModel):
33 |     """Chat response model."""
34 | response: str
35 | session_id: str
36 | turn_count: int
37 | next_speaker: Optional[str] = None
38 |
39 |
40 | @chat_router.post("/send")
41 | async def send_message(
42 | request: ChatRequest,
43 | client = Depends(get_client)
44 | ):
45 | """
46 |     Send a message to the database agent.
47 |     Supports streaming responses and tool calls.
48 | """
49 | try:
50 |         # Create an abort signal
51 | signal = SimpleAbortSignal()
52 |
53 |         # Generate a session ID if none was provided
54 | session_id = request.session_id or f"session_{int(asyncio.get_event_loop().time())}"
55 |
56 |         # Send the message and obtain the streaming response
57 | response_stream = client.send_message_stream(
58 | request=request.message,
59 | signal=signal,
60 | prompt_id=session_id,
61 | turns=100
62 | )
63 |
64 |         # Collect the response chunks
65 | response_parts = []
66 | chunk_count = 0
67 | async for chunk in response_stream:
68 | chunk_count += 1
69 |             print(f"[DEBUG] Chunk #{chunk_count}: {chunk}")  # debug output
70 |
71 | if chunk.get("type") == "Content":
72 | response_parts.append(chunk.get("value", ""))
73 | elif chunk.get("type") == "ToolCallRequest":
74 |                 # Tool call request
75 | tool_value = chunk.get('value')
76 | if hasattr(tool_value, 'name'):
77 | tool_name = tool_value.name
78 | elif isinstance(tool_value, dict):
79 | tool_name = tool_value.get('name', 'unknown')
80 | else:
81 | tool_name = 'unknown'
82 |                 response_parts.append(f"[Tool call: {tool_name}]")
83 |
84 | response_text = "".join(response_parts)
85 | print(f"[DEBUG] Total chunks: {chunk_count}, Response text: {response_text}")
86 |
87 | return ChatResponse(
88 | response=response_text,
89 | session_id=session_id,
90 | turn_count=client.session_turn_count,
91 |             next_speaker="user"  # TODO: implement the real next_speaker decision
92 | )
93 |
94 | except Exception as e:
95 | print(f"[ERROR] send_message exception: {type(e).__name__}: {str(e)}")
96 | import traceback
97 | traceback.print_exc()
98 | raise HTTPException(status_code=500, detail=str(e))
99 |
100 |
101 | @chat_router.get("/stream/{session_id}")
102 | async def stream_chat(
103 | session_id: str,
104 | message: str,
105 | client = Depends(get_client)
106 | ):
107 | """
108 |     Streaming chat endpoint.
109 |     Returns the response as a Server-Sent Events stream.
110 | """
111 | async def generate_stream():
112 | try:
113 | signal = SimpleAbortSignal()
114 |
115 | response_stream = client.send_message_stream(
116 | request=message,
117 | signal=signal,
118 | prompt_id=session_id,
119 | turns=100
120 | )
121 |
122 | async for chunk in response_stream:
123 |                 # Convert the chunk to SSE format
124 | data = json.dumps(chunk, ensure_ascii=False)
125 | yield f"data: {data}\n\n"
126 |
127 |             # Send the end-of-stream marker
128 | yield "data: [DONE]\n\n"
129 |
130 | except Exception as e:
131 | error_data = json.dumps({"error": str(e)}, ensure_ascii=False)
132 | yield f"data: {error_data}\n\n"
133 |
134 | return StreamingResponse(
135 | generate_stream(),
136 | media_type="text/plain",
137 | headers={
138 | "Cache-Control": "no-cache",
139 | "Connection": "keep-alive",
140 | "Content-Type": "text/event-stream"
141 | }
142 | )
143 |
144 |
145 | @chat_router.get("/history/{session_id}")
146 | async def get_chat_history(
147 | session_id: str,
148 | curated: bool = True,
149 | client = Depends(get_client)
150 | ):
151 |     """Get the chat history."""
152 | try:
153 |         # TODO: implement fetching the history for this session from the client
154 | history = client.chat.get_history(curated=curated)
155 |
156 | return {
157 | "session_id": session_id,
158 | "history": history,
159 | "total_messages": len(history)
160 | }
161 |
162 | except Exception as e:
163 | raise HTTPException(status_code=500, detail=str(e))
164 |
165 |
166 | @chat_router.delete("/history/{session_id}")
167 | async def clear_chat_history(
168 | session_id: str,
169 | client = Depends(get_client)
170 | ):
171 |     """Clear the chat history."""
172 | try:
173 |         # TODO: implement the history-clearing logic
174 | client.chat.set_history([])
175 |
176 | return {
177 | "message": "Chat history cleared",
178 | "session_id": session_id
179 | }
180 |
181 | except Exception as e:
182 | raise HTTPException(status_code=500, detail=str(e))
183 |
184 |
185 | @chat_router.post("/compress/{session_id}")
186 | async def compress_chat_history(
187 | session_id: str,
188 | force: bool = False,
189 | client = Depends(get_client)
190 | ):
191 |     """Compress the chat history."""
192 | try:
193 |         # TODO: implement the history compression logic
194 | result = await client.try_compress_chat(session_id, force=force)
195 |
196 | if result:
197 | return {
198 | "message": "Chat history compressed",
199 | "session_id": session_id,
200 | "compression_stats": result
201 | }
202 | else:
203 | return {
204 | "message": "No compression needed",
205 | "session_id": session_id
206 | }
207 |
208 | except Exception as e:
209 | raise HTTPException(status_code=500, detail=str(e))
210 |
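
For reference (an illustration added here, not part of the file), a minimal client for the `/stream/{session_id}` route above could look like this; the host, port, and router mount prefix are assumptions, and `httpx` is an extra dependency:

```python
# Sketch of an SSE consumer for the streaming route; URL details are assumptions.
import json
import httpx

async def consume_stream(session_id: str, message: str) -> None:
    url = f"http://localhost:8000/stream/{session_id}"    # adjust for the real prefix/port
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url, params={"message": message}) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                payload = line[len("data: "):]
                if payload == "[DONE]":                   # end marker emitted by the route
                    break
                print(json.loads(payload))
```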
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/core/token_statistics.py:
--------------------------------------------------------------------------------
1 | """
2 | Token usage statistics management
3 | Minimally invasive design for collecting and aggregating token usage data
4 | """
5 |
6 | from typing import Dict, Any, List, Optional
7 | from dataclasses import dataclass, field
8 | from datetime import datetime
9 |
10 |
11 | @dataclass
12 | class TokenUsageRecord:
13 |     """Token usage record for a single API call."""
14 | timestamp: datetime
15 | model: str
16 | prompt_tokens: Optional[int]
17 | completion_tokens: Optional[int]
18 | total_tokens: Optional[int]
19 |     cached_tokens: Optional[int] = 0  # number of tokens served from the prompt cache
20 |
21 |
22 | @dataclass
23 | class TokenStatistics:
24 |     """Session-level token statistics."""
25 | records: List[TokenUsageRecord] = field(default_factory=list)
26 |
27 | def add_usage(self, model: str, usage_data: Dict[str, Any]):
28 |         """Add a single usage record."""
29 | record = TokenUsageRecord(
30 | timestamp=datetime.now(),
31 | model=model,
32 | prompt_tokens=usage_data.get('prompt_tokens', 0),
33 | completion_tokens=usage_data.get('completion_tokens', 0),
34 | total_tokens=usage_data.get('total_tokens', 0),
35 |             cached_tokens=usage_data.get('cached_tokens', 0)  # cached tokens, if reported
36 | )
37 |
38 |         # Detailed debug logging
39 | from ..utils.debug_logger import log_info
40 | log_info("TokenStats", f"ADDING RECORD #{len(self.records) + 1}:")
41 | log_info("TokenStats", f" - Model: {model}")
42 |
43 |         # Compute the tokens that are actually billed (prompt minus cache hits)
44 |         billable_prompt_tokens = (record.prompt_tokens or 0) - (record.cached_tokens or 0)
45 |         billable_total_tokens = billable_prompt_tokens + (record.completion_tokens or 0)
46 |
47 | log_info("TokenStats", f" - prompt_tokens: {billable_prompt_tokens} (original: {record.prompt_tokens}, cached: {record.cached_tokens})")
48 | log_info("TokenStats", f" - completion_tokens: {record.completion_tokens}")
49 | log_info("TokenStats", f" - total_tokens: {billable_total_tokens} (original: {record.total_tokens})")
50 |
51 |         if (record.cached_tokens or 0) > 0 and (record.prompt_tokens or 0) > 0:
52 | save_rate = record.cached_tokens / record.prompt_tokens * 100
53 | log_info("TokenStats", f" - cache_rate: {save_rate:.1f}%")
54 |
55 | log_info("TokenStats", f" - Timestamp: {record.timestamp.strftime('%H:%M:%S.%f')[:-3]}")
56 |
57 |         # Show the running billable total
58 | current_billable = sum((r.total_tokens or 0) - (r.cached_tokens or 0) for r in self.records)
59 | new_billable = current_billable + billable_total_tokens
60 |
61 | log_info("TokenStats", f" - Running billable total BEFORE: {current_billable}")
62 | log_info("TokenStats", f" - Running billable total AFTER: {new_billable}")
63 |
64 | self.records.append(record)
65 |
66 | def get_summary(self) -> Dict[str, Any]:
67 |         """Return an aggregated usage summary."""
68 | if not self.records:
69 | return {
70 | 'total_calls': 0,
71 | 'total_prompt_tokens': 0,
72 | 'total_completion_tokens': 0,
73 | 'total_tokens': 0,
74 | 'total_cached_tokens': 0,
75 | 'by_model': {}
76 | }
77 |
78 |         # Compute totals, subtracting cached tokens
79 | total_prompt = sum(r.prompt_tokens or 0 for r in self.records)
80 | total_cached = sum(r.cached_tokens or 0 for r in self.records)
81 | total_billable_prompt = total_prompt - total_cached
82 | total_completion = sum(r.completion_tokens or 0 for r in self.records)
83 |
84 |         # Aggregate statistics per model
85 | by_model = {}
86 | for record in self.records:
87 | if record.model not in by_model:
88 | by_model[record.model] = {
89 | 'calls': 0,
90 | 'prompt_tokens': 0,
91 | 'completion_tokens': 0,
92 | 'total_tokens': 0,
93 | 'cached_tokens': 0
94 | }
95 | by_model[record.model]['calls'] += 1
96 |             # Billable prompt tokens for this record
97 | billable_prompt = (record.prompt_tokens or 0) - (record.cached_tokens or 0)
98 | by_model[record.model]['prompt_tokens'] += billable_prompt
99 | by_model[record.model]['completion_tokens'] += record.completion_tokens or 0
100 | by_model[record.model]['total_tokens'] += billable_prompt + (record.completion_tokens or 0)
101 | by_model[record.model]['cached_tokens'] += record.cached_tokens or 0
102 |
103 | return {
104 | 'total_calls': len(self.records),
105 | 'total_prompt_tokens': total_billable_prompt,
106 | 'total_completion_tokens': total_completion,
107 | 'total_tokens': total_billable_prompt + total_completion,
108 | 'total_cached_tokens': total_cached,
109 |             'original_prompt_tokens': total_prompt,  # keep the raw value for reference
110 | 'by_model': by_model
111 | }
112 |
113 |     def get_cost_estimate(self) -> Dict[str, Any]:
114 |         """Estimate cost based on public list prices."""
115 |         # Reference prices as of January 2025 (per 1M tokens)
116 | pricing = {
117 | 'gemini-2.5-flash': {'input': 0.075, 'output': 0.30}, # $0.075/$0.30 per 1M
118 | 'gemini-1.5-pro': {'input': 1.25, 'output': 5.00}, # $1.25/$5.00 per 1M
119 | 'claude-3.5-sonnet': {'input': 3.00, 'output': 15.00}, # $3/$15 per 1M
120 | 'gpt-4.1': {'input': 2.50, 'output': 10.00}, # $2.50/$10 per 1M
121 | 'gpt-5-mini': {'input': 0.25, 'output': 2.00} # $0.25/$2.00 per 1M
122 | }
123 |
124 | total_cost = 0.0
125 | cost_by_model = {}
126 |
127 | for model, stats in self.get_summary()['by_model'].items():
128 |             # Look up pricing (supports model name aliases)
129 | model_pricing = None
130 | for key in pricing:
131 | if key in model.lower():
132 | model_pricing = pricing[key]
133 | break
134 |
135 | if model_pricing:
136 |                 # Use the billable tokens (cache hits already subtracted)
137 | input_cost = (stats['prompt_tokens'] / 1_000_000) * model_pricing['input']
138 | output_cost = (stats['completion_tokens'] / 1_000_000) * model_pricing['output']
139 | model_cost = input_cost + output_cost
140 |
141 | cost_by_model[model] = {
142 | 'input_cost': input_cost,
143 | 'output_cost': output_cost,
144 | 'total_cost': model_cost,
145 | 'cached_tokens': stats.get('cached_tokens', 0)
146 | }
147 | total_cost += model_cost
148 |
149 | return {
150 | 'total_cost': total_cost,
151 | 'by_model': cost_by_model
152 | }
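
For reference (an illustration added here, not part of the file), `TokenStatistics` above can be exercised as follows; note that prompt totals are reported net of cached tokens:

```python
# Usage sketch for TokenStatistics (assumes the class above is importable).
stats = TokenStatistics()
stats.add_usage("gemini-2.5-flash", {
    "prompt_tokens": 1200,
    "completion_tokens": 300,
    "total_tokens": 1500,
    "cached_tokens": 800,
})

summary = stats.get_summary()
assert summary["total_prompt_tokens"] == 400   # 1200 prompt - 800 cached
assert summary["total_tokens"] == 700          # 400 billable prompt + 300 completion
print(stats.get_cost_estimate()["total_cost"])
```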
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/telemetry/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | DatabaseMetrics - performance metrics collection
3 | Built on OpenTelemetry Metrics and aligned with the Gemini CLI metrics mechanism
4 | """
5 |
6 | import time
7 | import logging
8 | from typing import Optional, Dict, Any, List
9 | from collections import defaultdict, deque
10 | from dataclasses import dataclass, field
11 |
12 | # Conditionally import OpenTelemetry Metrics
13 | try:
14 | from opentelemetry import metrics
15 | from opentelemetry.sdk.metrics import MeterProvider
16 | from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
17 | from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
18 | OTEL_METRICS_AVAILABLE = True
19 | except ImportError:
20 | OTEL_METRICS_AVAILABLE = False
21 |
22 | from ..config.base import DatabaseConfig
23 |
24 |
25 | @dataclass
26 | class MetricPoint:
27 |     """A single metric data point."""
28 | timestamp: float
29 | value: float
30 | labels: Dict[str, str] = field(default_factory=dict)
31 |
32 |
33 | class DatabaseMetrics:
34 | """
35 |     Performance metrics collection for the database agent
36 |     - Aligned with the Gemini CLI metrics mechanism
37 |     - Supports OpenTelemetry Metrics integration
38 |     - Provides in-memory buffering and batch export
39 |     - Falls back to in-memory storage when OpenTelemetry is not installed
40 | """
41 |
42 | def __init__(self, config: DatabaseConfig):
43 | self.config = config
44 | self.service_name = config.get("service_name", "database-agent")
45 | self.enabled = config.get("metrics_enabled", True)
46 |
47 |         # In-memory metric storage (fallback mode)
48 | self.counters: Dict[str, float] = defaultdict(float)
49 | self.histograms: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
50 | self.gauges: Dict[str, float] = {}
51 |
52 |         # Initialize OpenTelemetry metrics
53 | self.meter = self._setup_meter() if self.enabled else None
54 | self._otel_instruments = {}
55 |
56 | def _setup_meter(self):
57 |         """Set up the OpenTelemetry meter."""
58 | if not OTEL_METRICS_AVAILABLE:
59 | logging.warning("OpenTelemetry Metrics not available, using memory storage")
60 | return None
61 |
62 | try:
63 |             # Configure the metrics provider
64 | otlp_endpoint = self.config.get("otel_exporter_otlp_endpoint")
65 | if otlp_endpoint:
66 | exporter = OTLPMetricExporter(endpoint=otlp_endpoint)
67 | reader = PeriodicExportingMetricReader(exporter, export_interval_millis=5000)
68 | provider = MeterProvider(metric_readers=[reader])
69 | metrics.set_meter_provider(provider)
70 |
71 |             # Create the meter
72 | return metrics.get_meter(self.service_name)
73 |
74 | except Exception as e:
75 | logging.error(f"Failed to setup metrics: {e}")
76 | return None
77 |
78 | def counter(self, name: str, description: str = "") -> 'Counter':
79 |         """Create a counter metric."""
80 | return Counter(self, name, description)
81 |
82 | def histogram(self, name: str, description: str = "") -> 'Histogram':
83 |         """Create a histogram metric."""
84 | return Histogram(self, name, description)
85 |
86 | def gauge(self, name: str, description: str = "") -> 'Gauge':
87 |         """Create a gauge metric."""
88 | return Gauge(self, name, description)
89 |
90 | def _get_or_create_counter(self, name: str, description: str):
91 |         """Get or create an OpenTelemetry counter."""
92 | if not self.meter:
93 | return None
94 |
95 | if name not in self._otel_instruments:
96 | self._otel_instruments[name] = self.meter.create_counter(
97 | name=name,
98 | description=description
99 | )
100 | return self._otel_instruments[name]
101 |
102 | def _get_or_create_histogram(self, name: str, description: str):
103 |         """Get or create an OpenTelemetry histogram."""
104 | if not self.meter:
105 | return None
106 |
107 | if name not in self._otel_instruments:
108 | self._otel_instruments[name] = self.meter.create_histogram(
109 | name=name,
110 | description=description
111 | )
112 | return self._otel_instruments[name]
113 |
114 | def get_metrics_summary(self) -> Dict[str, Any]:
115 |         """Return a metrics summary (for health checks and debugging)."""
116 | return {
117 | "counters": dict(self.counters),
118 | "gauges": dict(self.gauges),
119 | "histogram_counts": {k: len(v) for k, v in self.histograms.items()},
120 | "enabled": self.enabled,
121 | "otel_available": OTEL_METRICS_AVAILABLE and self.meter is not None
122 | }
123 |
124 |
125 | class Counter:
126 |     """Counter metric."""
127 |
128 | def __init__(self, metrics: DatabaseMetrics, name: str, description: str):
129 | self.metrics = metrics
130 | self.name = name
131 | self.description = description
132 | self._otel_counter = metrics._get_or_create_counter(name, description)
133 |
134 | def increment(self, value: float = 1.0, labels: Optional[Dict[str, str]] = None):
135 |         """Increment the counter."""
136 | if not self.metrics.enabled:
137 | return
138 |
139 |         # In-memory storage
140 | key = f"{self.name}:{labels}" if labels else self.name
141 | self.metrics.counters[key] += value
142 |
143 | # OpenTelemetry
144 | if self._otel_counter:
145 | self._otel_counter.add(value, labels or {})
146 |
147 |
148 | class Histogram:
149 |     """Histogram metric."""
150 |
151 | def __init__(self, metrics: DatabaseMetrics, name: str, description: str):
152 | self.metrics = metrics
153 | self.name = name
154 | self.description = description
155 | self._otel_histogram = metrics._get_or_create_histogram(name, description)
156 |
157 | def record(self, value: float, labels: Optional[Dict[str, str]] = None):
158 |         """Record a value in the histogram."""
159 | if not self.metrics.enabled:
160 | return
161 |
162 |         # In-memory storage
163 | key = f"{self.name}:{labels}" if labels else self.name
164 | self.metrics.histograms[key].append(MetricPoint(
165 | timestamp=time.time(),
166 | value=value,
167 | labels=labels or {}
168 | ))
169 |
170 | # OpenTelemetry
171 | if self._otel_histogram:
172 | self._otel_histogram.record(value, labels or {})
173 |
174 |
175 | class Gauge:
176 |     """Gauge metric."""
177 |
178 | def __init__(self, metrics: DatabaseMetrics, name: str, description: str):
179 | self.metrics = metrics
180 | self.name = name
181 | self.description = description
182 |
183 | def set(self, value: float, labels: Optional[Dict[str, str]] = None):
184 |         """Set the gauge value."""
185 | if not self.metrics.enabled:
186 | return
187 |
188 |         # In-memory storage
189 | key = f"{self.name}:{labels}" if labels else self.name
190 | self.metrics.gauges[key] = value
191 |
192 |         # Note: OpenTelemetry gauges require an observable callback; simplified here
193 |
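
For reference (an illustration added here, not part of the file), the `DatabaseMetrics` facade above can be driven like this; a stand-in config object replaces the real `DatabaseConfig`, which only needs to expose `get(key, default)`:

```python
# Usage sketch; _StubConfig stands in for DatabaseConfig and is not part of the project.
class _StubConfig:
    def get(self, key, default=None):
        return {"service_name": "database-agent", "metrics_enabled": True}.get(key, default)

metrics = DatabaseMetrics(_StubConfig())
queries = metrics.counter("db_queries_total", "Number of executed SQL statements")
latency = metrics.histogram("db_query_latency_ms", "Query latency in milliseconds")

queries.increment(labels={"dialect": "sqlite"})
latency.record(12.5, labels={"dialect": "sqlite"})
print(metrics.get_metrics_summary())
```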
--------------------------------------------------------------------------------
/packages/core/src/dbrheo/api/routes/database.py:
--------------------------------------------------------------------------------
1 | """
2 | Database API routes - database operation and management endpoints
3 | Includes connection management, schema queries, SQL execution, and more
4 | """
5 |
6 | from fastapi import APIRouter, HTTPException, Depends
7 | from pydantic import BaseModel
8 | from typing import List, Optional, Dict, Any
9 |
10 | from ...adapters.connection_manager import DatabaseConnectionManager
11 | from ..dependencies import get_config
12 |
13 | database_router = APIRouter()
14 |
15 |
16 | class DatabaseConnection(BaseModel):
17 |     """Database connection model."""
18 | name: str
19 | connection_string: str
20 | dialect: str
21 | description: Optional[str] = None
22 |
23 |
24 | class SQLRequest(BaseModel):
25 |     """SQL execution request."""
26 | sql: str
27 | database: Optional[str] = None
28 | params: Optional[Dict[str, Any]] = None
29 |
30 |
31 | class SchemaRequest(BaseModel):
32 |     """Schema query request."""
33 | database: Optional[str] = None
34 | schema_name: Optional[str] = None
35 | table_name: Optional[str] = None
36 |
37 |
38 | @database_router.get("/connections")
39 | async def list_connections(config = Depends(get_config)):
40 |     """List all database connections."""
41 | try:
42 |         # TODO: load the connection list from configuration
43 | connections = [
44 | {
45 | "name": "default",
46 | "dialect": "sqlite",
47 |                 "description": "Default SQLite database",
48 | "status": "connected"
49 | }
50 | ]
51 |
52 | return {
53 | "connections": connections,
54 | "total": len(connections)
55 | }
56 |
57 | except Exception as e:
58 | raise HTTPException(status_code=500, detail=str(e))
59 |
60 |
61 | @database_router.post("/connections/test")
62 | async def test_connection(
63 | connection: DatabaseConnection,
64 | config = Depends(get_config)
65 | ):
66 |     """Test a database connection."""
67 | try:
68 |         # Create the connection manager
69 | manager = DatabaseConnectionManager(config)
70 |
71 |         # TODO: implement the actual connection test
72 |         # An adapter should be created from the connection info and the connection verified
73 |
74 | return {
75 | "success": True,
76 | "message": "Connection successful",
77 | "connection_info": {
78 | "name": connection.name,
79 | "dialect": connection.dialect
80 | }
81 | }
82 |
83 | except Exception as e:
84 | return {
85 | "success": False,
86 | "message": f"Connection failed: {str(e)}"
87 | }
88 |
89 |
90 | @database_router.post("/execute")
91 | async def execute_sql(
92 | request: SQLRequest,
93 | config = Depends(get_config)
94 | ):
95 |     """Execute a SQL statement."""
96 | try:
97 |         # Create the connection manager
98 | manager = DatabaseConnectionManager(config)
99 | adapter = await manager.get_connection(request.database)
100 |
101 |         # Determine the SQL statement type
102 | sql_info = await adapter.parse_sql(request.sql)
103 |
104 |         # Choose the execution method based on the SQL type
105 | if sql_info["sql_type"] == "SELECT":
106 | result = await adapter.execute_query(
107 | request.sql,
108 | request.params
109 | )
110 | else:
111 | result = await adapter.execute_command(
112 | request.sql,
113 | request.params
114 | )
115 |
116 | return {
117 | "success": result["success"],
118 | "data": result.get("data", []),
119 | "columns": result.get("columns", []),
120 | "row_count": result.get("row_count", 0),
121 | "affected_rows": result.get("affected_rows", 0),
122 | "sql_info": sql_info,
123 | "error": result.get("error")
124 | }
125 |
126 | except Exception as e:
127 | raise HTTPException(status_code=500, detail=str(e))
128 |
129 |
130 | @database_router.get("/schema")
131 | async def get_schema_info(
132 | database: Optional[str] = None,
133 | schema_name: Optional[str] = None,
134 | config = Depends(get_config)
135 | ):
136 |     """Get database schema information."""
137 | try:
138 | manager = DatabaseConnectionManager(config)
139 | adapter = await manager.get_connection(database)
140 |
141 | result = await adapter.get_schema_info(schema_name)
142 |
143 | if result["success"]:
144 | return {
145 | "success": True,
146 | "schema": result["schema"]
147 | }
148 | else:
149 | raise HTTPException(status_code=500, detail=result["error"])
150 |
151 | except Exception as e:
152 | raise HTTPException(status_code=500, detail=str(e))
153 |
154 |
155 | @database_router.get("/tables/{table_name}")
156 | async def get_table_info(
157 | table_name: str,
158 | database: Optional[str] = None,
159 | config = Depends(get_config)
160 | ):
161 |     """Get table structure information."""
162 | try:
163 | manager = DatabaseConnectionManager(config)
164 | adapter = await manager.get_connection(database)
165 |
166 | table_info = await adapter.get_table_info(table_name)
167 |
168 | if "error" in table_info:
169 | raise HTTPException(status_code=404, detail=table_info["error"])
170 |
171 | return {
172 | "success": True,
173 | "table": table_info
174 | }
175 |
176 | except Exception as e:
177 | raise HTTPException(status_code=500, detail=str(e))
178 |
179 |
180 | @database_router.get("/tables")
181 | async def list_tables(
182 | database: Optional[str] = None,
183 | config = Depends(get_config)
184 | ):
185 |     """List all tables and views."""
186 | try:
187 | manager = DatabaseConnectionManager(config)
188 | adapter = await manager.get_connection(database)
189 |
190 | schema_result = await adapter.get_schema_info()
191 |
192 | if schema_result["success"]:
193 | schema = schema_result["schema"]
194 | tables = list(schema.get("tables", {}).keys())
195 | views = list(schema.get("views", {}).keys())
196 |
197 | return {
198 | "success": True,
199 | "tables": tables,
200 | "views": views,
201 | "total_tables": len(tables),
202 | "total_views": len(views)
203 | }
204 | else:
205 | raise HTTPException(status_code=500, detail=schema_result["error"])
206 |
207 | except Exception as e:
208 | raise HTTPException(status_code=500, detail=str(e))
209 |
210 |
211 | @database_router.post("/analyze")
212 | async def analyze_sql(
213 | request: SQLRequest,
214 | config = Depends(get_config)
215 | ):
216 |     """Analyze a SQL statement."""
217 | try:
218 | manager = DatabaseConnectionManager(config)
219 | adapter = await manager.get_connection(request.database)
220 |
221 |         # Parse the SQL
222 | sql_info = await adapter.parse_sql(request.sql)
223 |
224 | return {
225 | "success": True,
226 | "analysis": sql_info
227 | }
228 |
229 | except Exception as e:
230 | raise HTTPException(status_code=500, detail=str(e))
231 |
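
For reference (an illustration added here, not part of the file), the `/execute` route above can be exercised with a small async client; the base URL and any router mount prefix are assumptions, and `httpx` is an extra dependency:

```python
# Sketch of calling the SQL execution route; adjust the URL for the real prefix/port.
import asyncio
import httpx

async def run_query() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post("/execute", json={"sql": "SELECT 1 AS ok"})
        resp.raise_for_status()
        print(resp.json()["data"])

asyncio.run(run_query())
```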
--------------------------------------------------------------------------------