├── .env.sample ├── .gitignore ├── .python-version ├── README.md ├── README_CN.md ├── alembic.ini ├── alembic ├── README ├── env.py ├── script.py.mako └── versions │ └── .gitkeep ├── config.ini ├── controller ├── __init__.py ├── chat_controller.py ├── chat_window_controller.py ├── login_controller.py ├── mcp_controller.py └── user_controller.py ├── core ├── __init__.py ├── common │ ├── __init__.py │ ├── container.py │ └── logger.py ├── llm │ ├── __init__.py │ ├── azure_open_ai.py │ ├── doubao_open_ai.py │ ├── llm_chat.py │ ├── llm_manager.py │ ├── llm_message.py │ ├── moonshot_open_ai.py │ └── qwen_open_ai.py └── mcp │ ├── __init__.py │ ├── convert_mcp_tools.py │ └── server │ ├── __init__.py │ └── server_loader.py ├── dao ├── __init__.py ├── chat_window_dao.py ├── mcp_server_dao.py └── user_dao.py ├── dto ├── __init__.py ├── chat_dto.py ├── chat_window_dto.py ├── global_response.py ├── mcp_server_dto.py └── user_dto.py ├── exception ├── exception.py └── exception_dict.py ├── extensions ├── __init__.py └── ext_database.py ├── main.py ├── model ├── __init__.py ├── chat_window.py ├── mcp_server.py └── user.py ├── pyproject.toml ├── service ├── __init__.py ├── chat_service.py ├── chat_window_service.py ├── mcp_config_service.py └── user_service.py ├── utils ├── db_utils.py └── result_utils.py └── uv.lock /.env.sample: -------------------------------------------------------------------------------- 1 | # 业务数据库配置 2 | DATABASE_HOST=localhost 3 | DATABASE_PORT=5432 4 | DATABASE_NAME=your_database_name 5 | DATABASE_USERNAME=your_username 6 | DATABASE_PASSWORD=your_password 7 | 8 | # AZURE OPEN AI 配置 9 | AZURE_API_KEY=your_azure_api_key 10 | AZURE_ENDPOINT=your_azure_endpoint 11 | AZURE_API_VERSION=your_api_version 12 | DEPLOYMENT_NAME=your_deployment_name 13 | 14 | # 通义千问配置 15 | DASHSCOPE_API_KEY=your_dashscope_api_key 16 | QWEN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 17 | QWEN_MODEL=your_qwen_model 18 | 19 | # 豆包配置 20 | DOUBAO_API_KEY=your_doubao_api_key 
21 | DOUBAO_BASE_URL=https://ark.cn-beijing.volces.com/api/v3 22 | DOUBAO_MODEL=your_doubao_model 23 | 24 | # 月之暗面配置 25 | MOONSHOT_API_KEY=your_moonshot_api_key 26 | MOONSHOT_BASE_URL=https://api.moonshot.cn/v1 27 | MOONSHOT_MODEL=your_moonshot_model 28 | 29 | # deepseek配置 30 | 31 | # claude配置 32 | 33 | # chatgpt配置 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #macos 2 | .DS_Store 3 | # Pycharm 4 | .idea 5 | # Python-generated files 6 | __pycache__/ 7 | *.py[oc] 8 | build/ 9 | dist/ 10 | wheels/ 11 | *.egg-info 12 | 13 | # Virtual environments 14 | .venv 15 | # App 16 | app.log 17 | 18 | # .env 19 | .env 20 | 21 | # Alembic 22 | alembic/versions/* 23 | !alembic/versions/.gitkeep -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 
23 | 24 | Efflux is an LLM-based Agent chat client featuring streaming responses and chat history management. As an MCP Host, it leverages the Model Context Protocol to connect with various MCP Servers, enabling standardized tool invocation and data access for large language models. 25 | 26 | ### Key Features 27 | - Rapid Agent construction 28 | - Dynamic MCP tool loading and invocation 29 | - Support for multiple large language models 30 | - Real-time streaming chat responses 31 | - Chat history management 32 | 33 | ### Online Demo 34 | - 🏠 [Efflux Homepage](https://www.efflux.ai) 35 | - 🚀 [Interactive Demo](https://www.efflux.ai/demo) 36 | 37 | 38 | ### Requirements 39 | - Python 3.12+ 40 | - PostgreSQL 41 | - uv (Python package & environment manager), installable via `pip install uv` 42 | 43 | ### Quick Start 44 | 45 | 1. Clone the project 46 | ```bash 47 | git clone git@github.com:isoftstone-data-intelligence-ai/efflux-backend.git 48 | cd efflux-backend 49 | ``` 50 | 51 | 2. Install uv 52 | ```bash 53 | pip install uv 54 | ``` 55 | 56 | 3. Reload dependencies 57 | ```bash 58 | uv sync --reinstall 59 | ``` 60 | 61 | 4. Activate virtual environment 62 | ```bash 63 | # Activate virtual environment 64 | source .venv/bin/activate # MacOS/Linux 65 | 66 | # Deactivate when needed 67 | deactivate 68 | ``` 69 | 70 | 5. Configure environment variables 71 | ```bash 72 | # Copy environment variable template 73 | cp .env.sample .env 74 | 75 | # Edit .env file, configure: 76 | # 1. Database connection info (DATABASE_NAME, DATABASE_USERNAME, DATABASE_PASSWORD) 77 | # 2. At least one LLM configuration (e.g., Azure OpenAI, Qwen, Doubao, or Moonshot) 78 | ``` 79 | 80 | 6. 
Select the LLM 81 | ```bash 82 | # Edit core/common/container.py file 83 | # Find the llm registration section, replace with any of the following models (Qwen by default): 84 | # - QwenLlm: Qwen 85 | # - AzureLlm: Azure OpenAI 86 | # - DoubaoLlm: Doubao 87 | # - MoonshotLlm: Moonshot 88 | 89 | # Example: Using Azure OpenAI 90 | from core.llm.azure_open_ai import AzureLlm 91 | # ... 92 | llm = providers.Singleton(AzureLlm) 93 | ``` 94 | 95 | 7. Start PostgreSQL database 96 | ```bash 97 | # Method 1: If PostgreSQL is installed locally 98 | # Simply start your local PostgreSQL service 99 | 100 | # Method 2: Using Docker (example) 101 | docker run -d --name local-postgres \ 102 | -e POSTGRES_DB=your_database_name \ 103 | -e POSTGRES_USER=your_username \ 104 | -e POSTGRES_PASSWORD=your_password \ 105 | -p 5432:5432 \ 106 | postgres 107 | 108 | # Note: Ensure database connection info matches the configuration in your .env file 109 | ``` 110 | 111 | 8. Initialize database 112 | ```bash 113 | # Create a new version and generate a migration file in alembic/versions 114 | alembic revision --autogenerate -m "initial migration" 115 | 116 | # Preview SQL to be executed: 117 | alembic upgrade head --sql 118 | 119 | # If preview looks good, execute migration 120 | alembic upgrade head 121 | ``` 122 | 123 | 9. Initialize LLM template data 124 | ```bash 125 | # Run initialization script 126 | python scripts/init_llm_templates.py 127 | ``` 128 | 129 | 10. 
Start the service 130 | ```bash 131 | python -m uvicorn main:app --host 0.0.0.0 --port 8000 132 | ``` 133 | 134 | ### Acknowledgments 135 | 136 | This project utilizes the following excellent open-source projects and technologies: 137 | 138 | - [@modelcontextprotocol/mcp](https://modelcontextprotocol.io) - Standardized open protocol for LLM data interaction 139 | - [@langchain-ai/langchain](https://github.com/langchain-ai/langchain) - LLM application development framework 140 | - [@sqlalchemy/sqlalchemy](https://github.com/sqlalchemy/sqlalchemy) - Python SQL toolkit and ORM framework 141 | - [@pydantic/pydantic](https://github.com/pydantic/pydantic) - Data validation and settings management 142 | - [@tiangolo/fastapi](https://github.com/tiangolo/fastapi) - Modern, fast web framework 143 | - [@aio-libs/aiohttp](https://github.com/aio-libs/aiohttp) - Async HTTP client/server framework 144 | - [@sqlalchemy/alembic](https://github.com/sqlalchemy/alembic) - Database migration tool for SQLAlchemy 145 | - [@astral-sh/uv](https://github.com/astral-sh/uv) - Ultra-fast Python package manager 146 | - [@python-colorlog/colorlog](https://github.com/python-colorlog/colorlog) - Colored log output tool 147 | - [@jlowin/fastmcp](https://github.com/jlowin/fastmcp) - Python framework for building MCP servers 148 | - [@langchain-ai/langgraph](https://github.com/langchain-ai/langgraph) - Framework for building stateful multi-agent LLM applications 149 | 150 | Thanks to the developers and maintainers of these projects for their contributions to the open-source community. 
151 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 25 | 26 | 27 | Efflux 是一个基于大语言模型的 Agent 智能体对话客户端,提供流式会话响应和完整的对话历史管理。通过集成 MCP 协议,系统可作为 MCP Host 连接不同的 MCP Servers,为模型提供标准化的工具调用和数据访问能力。 28 | 29 | 30 | 31 | 32 | 33 | ### 主要特性 34 | - 快速构建 Agent 智能体 35 | - MCP 工具动态加载与调用 36 | - 支持多种大语言模型接入 37 | - 实时流式对话响应 38 | - 会话历史记录管理 39 | 40 | 41 | 42 | 43 | 44 | ### 在线体验 45 | - 🏠 [Efflux 官网](https://www.efflux.ai) 46 | - 🚀 [立即体验](https://www.efflux.ai/demo) 47 | 48 | 49 | 50 | ### 环境要求 51 | - Python 3.12+ 52 | - PostgreSQL 53 | - uv 包和项目管理工具,可通过`pip install uv`安装 54 | 55 | ### 快速开始 56 | 57 | 1. 克隆项目 58 | ```bash 59 | git clone git@github.com:isoftstone-data-intelligence-ai/efflux-backend.git 60 | cd efflux-backend 61 | ``` 62 | 63 | 2. 安装 uv 64 | ```bash 65 | pip install uv 66 | ``` 67 | 68 | 3. 重载依赖项 69 | ```bash 70 | uv sync --reinstall 71 | ``` 72 | 73 | 4. 激活虚拟环境 74 | ```bash 75 | # 激活虚拟环境 76 | source .venv/bin/activate # MacOS/Linux 77 | 78 | # 退出虚拟环境(当需要时) 79 | deactivate 80 | ``` 81 | 82 | 5. 配置环境变量 83 | ```bash 84 | # 复制环境变量模板 85 | cp .env.sample .env 86 | 87 | # 编辑 .env 文件,需要配置: 88 | # 1. 数据库连接信息(DATABASE_NAME、DATABASE_USERNAME、DATABASE_PASSWORD) 89 | # 2. 至少一个大语言模型的配置(如 Azure OpenAI、通义千问、豆包或月之暗面) 90 | ``` 91 | 92 | 6. 选择使用的大模型 93 | ```bash 94 | # 编辑 core/common/container.py 文件 95 | # 找到 llm 的注册部分,根据需要替换为以下任一模型(默认使用通义千问): 96 | # - QwenLlm:通义千问 97 | # - AzureLlm:Azure OpenAI 98 | # - DoubaoLlm:豆包 99 | # - MoonshotLlm:月之暗面 100 | 101 | # 示例:使用 Azure OpenAI 102 | from core.llm.azure_open_ai import AzureLlm 103 | # ... 104 | llm = providers.Singleton(AzureLlm) 105 | ``` 106 | 107 | 7. 
启动 postgres 数据库 108 | ```bash 109 | # 方法一:如果你本地已安装 PostgreSQL 110 | # 直接启动本地的 PostgreSQL 服务即可 111 | 112 | # 方法二:使用 Docker 启动(示例) 113 | docker run -d --name local-postgres \ 114 | -e POSTGRES_DB=your_database_name \ 115 | -e POSTGRES_USER=your_username \ 116 | -e POSTGRES_PASSWORD=your_password \ 117 | -p 5432:5432 \ 118 | postgres 119 | 120 | # 注意:无论使用哪种方式,请确保数据库的连接信息与 .env 文件中的配置保持一致 121 | ``` 122 | 123 | 8. 初始化数据库 124 | ```bash 125 | # 创建一个新的版本,并在alembic/versions下创建一个修改数据结构版本的py文件 126 | alembic revision --autogenerate -m "initial migration" 127 | 128 | # 预览将要执行的 SQL: 129 | alembic upgrade head --sql 130 | 131 | # 如果预览没有问题,执行迁移 132 | alembic upgrade head 133 | ``` 134 | 135 | 9. 初始化 LLM 模板数据 136 | ```bash 137 | # 运行初始化脚本 138 | python scripts/init_llm_templates.py 139 | ``` 140 | 141 | 10. 启动服务 142 | ```bash 143 | python -m uvicorn main:app --host 0.0.0.0 --port 8000 144 | ``` 145 | 146 | 147 | ### 致谢 148 | 149 | 本项目使用了以下优秀的开源项目和技术: 150 | 151 | - [@modelcontextprotocol/mcp](https://modelcontextprotocol.io) - 标准化的 LLM 数据交互开放协议 152 | - [@langchain-ai/langchain](https://github.com/langchain-ai/langchain) - LLM 应用开发框架 153 | - [@sqlalchemy/sqlalchemy](https://github.com/sqlalchemy/sqlalchemy) - Python SQL 工具包和 ORM 框架 154 | - [@pydantic/pydantic](https://github.com/pydantic/pydantic) - 数据验证和设置管理 155 | - [@tiangolo/fastapi](https://github.com/tiangolo/fastapi) - 现代、快速的 Web 框架 156 | - [@aio-libs/aiohttp](https://github.com/aio-libs/aiohttp) - 异步 HTTP 客户端/服务器框架 157 | - [@sqlalchemy/alembic](https://github.com/sqlalchemy/alembic) - SQLAlchemy 的数据库迁移工具 158 | - [@astral-sh/uv](https://github.com/astral-sh/uv) - 极速 Python 包管理器 159 | - [@python-colorlog/colorlog](https://github.com/python-colorlog/colorlog) - 彩色日志输出工具 160 | - [@jlowin/fastmcp](https://github.com/jlowin/fastmcp) - 快速构建 MCP 服务器的 Python 框架 161 | - [@langchain-ai/langgraph](https://github.com/langchain-ai/langgraph) - 构建状态化多智能体 LLM 应用的框架 162 | 163 | 感谢这些项目的开发者和维护者为开源社区做出的贡献。 164 | 165 | 166 | 167 | 
-------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | # Use forward slashes (/) also on windows to provide an os agnostic path 6 | script_location = alembic 7 | 8 | # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s 9 | # Uncomment the line below if you want the files to be prepended with date and time 10 | # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file 11 | # for all available tokens 12 | # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s 13 | 14 | # sys.path path, will be prepended to sys.path if present. 15 | # defaults to the current working directory. 16 | prepend_sys_path = . 17 | 18 | # timezone to use when rendering the date within the migration file 19 | # as well as the filename. 20 | # If specified, requires the python>=3.9 or backports.zoneinfo library. 21 | # Any required deps can installed by adding `alembic[tz]` to the pip requirements 22 | # string value is passed to ZoneInfo() 23 | # leave blank for localtime 24 | # timezone = 25 | 26 | # max length of characters to apply to the "slug" field 27 | # truncate_slug_length = 40 28 | 29 | # set to 'true' to run the environment during 30 | # the 'revision' command, regardless of autogenerate 31 | # revision_environment = false 32 | 33 | # set to 'true' to allow .pyc and .pyo files without 34 | # a source .py file to be detected as revisions in the 35 | # versions/ directory 36 | # sourceless = false 37 | 38 | # version location specification; This defaults 39 | # to alembic/versions. When using multiple version 40 | # directories, initial revisions must be specified with --version-path. 
41 | # The path separator used here should be the separator specified by "version_path_separator" below. 42 | # version_locations = %(here)s/bar:%(here)s/bat:alembic/versions 43 | 44 | # version path separator; As mentioned above, this is the character used to split 45 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. 46 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 47 | # Valid values for version_path_separator are: 48 | # 49 | # version_path_separator = : 50 | # version_path_separator = ; 51 | # version_path_separator = space 52 | # version_path_separator = newline 53 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 54 | 55 | # set to 'true' to search source files recursively 56 | # in each "version_locations" directory 57 | # new in Alembic version 1.10 58 | # recursive_version_locations = false 59 | 60 | # the output encoding used when revision files 61 | # are written from script.py.mako 62 | # output_encoding = utf-8 63 | 64 | sqlalchemy.url = 65 | 66 | 67 | [post_write_hooks] 68 | # post_write_hooks defines scripts or Python functions that are run 69 | # on newly generated revision scripts. 
See the documentation for further 70 | # detail and examples 71 | 72 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 73 | # hooks = black 74 | # black.type = console_scripts 75 | # black.entrypoint = black 76 | # black.options = -l 79 REVISION_SCRIPT_FILENAME 77 | 78 | # lint with attempts to fix using "ruff" - use the exec runner, execute a binary 79 | # hooks = ruff 80 | # ruff.type = exec 81 | # ruff.executable = %(here)s/.venv/bin/ruff 82 | # ruff.options = --fix REVISION_SCRIPT_FILENAME 83 | 84 | # Logging configuration 85 | [loggers] 86 | keys = root,sqlalchemy,alembic 87 | 88 | [handlers] 89 | keys = console 90 | 91 | [formatters] 92 | keys = generic 93 | 94 | [logger_root] 95 | level = WARNING 96 | handlers = console 97 | qualname = 98 | 99 | [logger_sqlalchemy] 100 | level = WARNING 101 | handlers = 102 | qualname = sqlalchemy.engine 103 | 104 | [logger_alembic] 105 | level = DEBUG 106 | handlers = 107 | qualname = alembic 108 | 109 | [handler_console] 110 | class = StreamHandler 111 | args = (sys.stderr,) 112 | level = NOTSET 113 | formatter = generic 114 | 115 | [formatter_generic] 116 | format = %(levelname)-5.5s [%(name)s] %(message)s 117 | datefmt = %H:%M:%S 118 | -------------------------------------------------------------------------------- /alembic/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. 
-------------------------------------------------------------------------------- /alembic/env.py: -------------------------------------------------------------------------------- 1 | from logging.config import fileConfig 2 | 3 | import asyncio 4 | from sqlalchemy.ext.asyncio import create_async_engine 5 | from sqlalchemy import engine_from_config 6 | from sqlalchemy import pool 7 | 8 | from alembic import context 9 | import os 10 | from dotenv import load_dotenv 11 | 12 | config = context.config 13 | 14 | if config.config_file_name is not None: 15 | fileConfig(config.config_file_name) 16 | 17 | # 加载环境变量 18 | load_dotenv() 19 | 20 | # 从环境变量构建连接字符串 21 | DATABASE_URL = ( 22 | f"postgresql+asyncpg://{os.environ['DATABASE_USERNAME']}:{os.environ['DATABASE_PASSWORD']}" 23 | f"@{os.environ['DATABASE_HOST']}:{os.environ['DATABASE_PORT']}/{os.environ['DATABASE_NAME']}" 24 | ) 25 | 26 | print(DATABASE_URL) 27 | 28 | # Configure the URL of the database 29 | config.set_main_option('sqlalchemy.url', DATABASE_URL) 30 | 31 | from extensions.ext_database import Base, engine 32 | from model.user import User 33 | from model.mcp_server import McpServer 34 | from model.chat_window import ChatWindow 35 | 36 | 37 | target_metadata = Base.metadata 38 | 39 | def run_migrations_offline() -> None: 40 | 41 | url = config.get_main_option("sqlalchemy.url") 42 | context.configure( 43 | url=url, 44 | target_metadata=target_metadata, 45 | literal_binds=True, 46 | dialect_opts={"paramstyle": "named"}, 47 | ) 48 | 49 | with context.begin_transaction(): 50 | context.run_migrations() 51 | 52 | async def debug_transaction(sync_connection): 53 | context.configure( 54 | connection=sync_connection, 55 | target_metadata=target_metadata, 56 | transactional_ddl=True 57 | ) 58 | print("Configured context") 59 | context.run_migrations() 60 | print("Migrations executed") 61 | 62 | async def run_migrations_online() -> None: 63 | """Run migrations in 'online' mode.""" 64 | connectable = create_async_engine( 65 | 
DATABASE_URL, 66 | poolclass=pool.NullPool, 67 | ) 68 | 69 | async with connectable.connect() as connection: 70 | await connection.run_sync( 71 | lambda sync_connection: context.configure( 72 | connection=sync_connection, 73 | target_metadata=target_metadata, 74 | compare_type=True, # 检查字段类型的变化 75 | transactional_ddl=True # 确保事务性 DDL 76 | ) 77 | ) 78 | 79 | async with connection.begin() as trans: 80 | try: 81 | await connection.run_sync(lambda sync_connection: context.run_migrations()) 82 | await trans.commit() # 明确提交事务 83 | except Exception: 84 | await trans.rollback() # 出现问题回滚事务 85 | raise 86 | 87 | def run_migrations() -> None: 88 | """Run the migrations in an async way.""" 89 | asyncio.run(run_migrations_online()) 90 | 91 | 92 | if context.is_offline_mode(): 93 | run_migrations_offline() 94 | else: 95 | run_migrations() 96 | -------------------------------------------------------------------------------- /alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from typing import Sequence, Union 9 | 10 | from alembic import op 11 | import sqlalchemy as sa 12 | ${imports if imports else ""} 13 | 14 | # revision identifiers, used by Alembic. 
15 | revision: str = ${repr(up_revision)} 16 | down_revision: Union[str, None] = ${repr(down_revision)} 17 | branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} 18 | depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} 19 | 20 | 21 | def upgrade() -> None: 22 | ${upgrades if upgrades else "pass"} 23 | 24 | 25 | def downgrade() -> None: 26 | ${downgrades if downgrades else "pass"} 27 | -------------------------------------------------------------------------------- /alembic/versions/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isoftstone-data-intelligence-ai/efflux-backend/c25742b0c76663265e42fdc67385a4bf78ee15c3/alembic/versions/.gitkeep -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | [SECURITY] 2 | SECRET_KEY = "1101faa7d00266ed0edd81a44ca33661b79a5c891bfae134f77a2320a5b1c1ea" 3 | ALGORITHM = "HS256" 4 | ACCESS_TOKEN_EXPIRE_MINUTES = 30 -------------------------------------------------------------------------------- /controller/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isoftstone-data-intelligence-ai/efflux-backend/c25742b0c76663265e42fdc67385a4bf78ee15c3/controller/__init__.py -------------------------------------------------------------------------------- /controller/chat_controller.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends 2 | from fastapi.responses import StreamingResponse 3 | from core.common.container import Container 4 | from dto.chat_dto import ChatDTO 5 | from service.chat_service import ChatService 6 | 7 | router = APIRouter(prefix="/chat", tags=["Chat"]) 8 | 9 | # 从容器中获取注册在容器中的 ChatService 实例 10 | def get_chat_service() -> ChatService: 11 | """ 
12 | 获取 ChatService 实例 13 | 14 | 通过依赖注入,从容器中获取注册的 ChatService 实例。 15 | 16 | Returns: 17 | ChatService: 用于处理会话的服务实例 18 | """ 19 | return Container.chat_service() 20 | 21 | 22 | @router.post("/stream_agent", summary="流式返回会话") 23 | async def stream_response(chat_dto: ChatDTO, chat_service: ChatService = Depends(get_chat_service)): 24 | """ 25 | 模型会话接口 - 返回流式响应 26 | 27 | 接收用户请求数据并通过 ChatService 处理会话逻辑,返回流式响应。 28 | 29 | Args: 30 | chat_dto (ChatDTO): 包含会话请求数据的对象 31 | chat_service (ChatService): 会话服务实例(通过依赖注入获取) 32 | 33 | Returns: 34 | StreamingResponse: 流式响应对象,媒体类型为 text/event-stream 35 | """ 36 | return StreamingResponse( 37 | chat_service.agent_stream(chat_dto), 38 | media_type="text/event-stream" 39 | ) 40 | 41 | @router.post("/normal_chat", summary="普通会话") 42 | async def normal_chat(chat_dto: ChatDTO, chat_service: ChatService = Depends(get_chat_service)): 43 | return await chat_service.normal_chat(chat_dto) -------------------------------------------------------------------------------- /controller/chat_window_controller.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends 2 | from core.common.container import Container 3 | from dto.global_response import GlobalResponse 4 | from service.chat_window_service import ChatWindowService 5 | from utils import result_utils 6 | 7 | router = APIRouter(prefix="/chat_window", tags=["ChatWindow"]) 8 | 9 | 10 | def get_chat_window_service(): 11 | return Container.chat_window_service() 12 | 13 | 14 | @router.get("/chat_window_list/{user_id}", summary="用户会话列表") 15 | async def get_chat_window_list(user_id: int, chat_window_service: ChatWindowService = Depends(get_chat_window_service)) \ 16 | -> GlobalResponse: 17 | result = await chat_window_service.get_user_chat_windows(user_id) 18 | return result_utils.build_response(result) 19 | -------------------------------------------------------------------------------- /controller/login_controller.py: 
-------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone, timedelta 2 | from asyncpg.pgproto.pgproto import timedelta 3 | from fastapi import APIRouter, HTTPException, Depends, status 4 | from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm 5 | from core.common.logger import get_logger 6 | from core.common.container import Container 7 | from dto.user_dto import Token 8 | from service.user_service import UserService 9 | from passlib.context import CryptContext 10 | import configparser 11 | import jwt 12 | 13 | router = APIRouter(prefix="/auth", tags=["Authentication"]) 14 | # 显示创建日志 15 | logger = get_logger(__name__) 16 | 17 | # config 18 | config_file = "config.ini" 19 | config = configparser.ConfigParser() 20 | config.read(config_file) 21 | # security 22 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") 23 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") 24 | 25 | SECRET_KEY = config['SECURITY']['SECRET_KEY'] 26 | ALGORITHM = config['SECURITY']['ALGORITHM'] 27 | ACCESS_TOKEN_EXPIRE_MINUTES = config['SECURITY']['ACCESS_TOKEN_EXPIRE_MINUTES'] 28 | 29 | 30 | # 从容器中获取注册在容器中的user Service 31 | def get_user_service() -> UserService: 32 | us = Container.user_service() # 33 | return us 34 | 35 | 36 | # 密码校验 37 | def verify_password(plain_password, hashed_password): 38 | return pwd_context.verify(plain_password, hashed_password) 39 | 40 | # 密码hash 41 | def get_password_hash(password): 42 | return pwd_context.hash(password) 43 | 44 | 45 | # 验证用户 46 | async def authenticate_user(username: str, 47 | password: str, 48 | user_service: UserService = Depends(get_user_service)): 49 | user = await user_service.get_user_by_name(username) 50 | if not user: 51 | return False 52 | if not verify_password(password, user.password): 53 | return False 54 | return user 55 | 56 | 57 | # 创建令牌 58 | async def create_access_token(data: dict, expires_delta: timedelta | None = None): 
59 | to_encode = data.copy() 60 | if expires_delta: 61 | expire = datetime.now(timezone.utc) + expires_delta 62 | else: 63 | expire = datetime.now(timezone.utc) + timedelta(minutes=15) 64 | to_encode.update({"exp": expire}) 65 | return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) 66 | 67 | 68 | # 登录成功,颁发令牌 69 | async def login_for_access_token(username: str, password: str) -> Token: 70 | user = await authenticate_user(username, password) 71 | if not user: 72 | raise HTTPException( 73 | status_code=status.HTTP_401_UNAUTHORIZED, 74 | detail="Could not validate credentials", 75 | headers={"WWW-Authenticate": "Bearer"}, 76 | ) 77 | access_token_expires = timedelta(minutes=int(ACCESS_TOKEN_EXPIRE_MINUTES)) 78 | access_token = await create_access_token( 79 | data={"sub": user.name}, expires_delta=access_token_expires 80 | ) 81 | return Token(access_token=access_token, token_type="bearer") 82 | 83 | 84 | # 登录接口 85 | @router.post("/login", summary="登录") 86 | async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> Token: 87 | return await login_for_access_token(form_data.username, form_data.password) 88 | 89 | 90 | # 登出接口 需要token 91 | @router.post("/logout", summary="登出") 92 | async def get_current_user(token: str = Depends(oauth2_scheme), user_service: UserService = Depends(get_user_service)): 93 | credentials_exception = HTTPException( 94 | status_code=status.HTTP_401_UNAUTHORIZED, 95 | detail="Could not validate credentials", 96 | headers={"WWW-Authenticate": "Bearer"}, 97 | ) 98 | try: 99 | payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) 100 | username: str = payload.get("sub") 101 | if username is None: 102 | raise credentials_exception 103 | except jwt.JWTError: 104 | raise credentials_exception 105 | user = await user_service.get_user_by_name(username) 106 | if user is None: 107 | raise credentials_exception 108 | return user 109 | 110 | @router.post("/logout", summary="登出") 111 | async def logout(current_user = 
Depends(get_current_user)): 112 | # 由于使用了JWT,服务器端不需要存储token状态 113 | # 客户端只需要删除本地存储的token即可 114 | # 这里可以添加一些额外的清理工作,比如记录日志等 115 | logger.info(f"User {current_user.name} logged out") 116 | return {"message": "Logged out successfully"} 117 | -------------------------------------------------------------------------------- /controller/mcp_controller.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends 2 | from core.common.container import Container 3 | from dto.global_response import GlobalResponse 4 | from dto.mcp_server_dto import MCPServerDTO, CreateMCPServerDTO 5 | from service.mcp_config_service import MCPConfigService 6 | from utils import result_utils 7 | 8 | router = APIRouter(prefix="/mcp", tags=["MCPServer"]) 9 | 10 | 11 | def get_mcp_config_service() -> MCPConfigService: 12 | mcp_config_service = Container.mcp_config_service() # 直接获取已注册的服务 13 | return mcp_config_service 14 | 15 | 16 | @router.get("/mcp_server_list/{user_id}") 17 | async def mcp_server_list(user_id: int, mcp_config_service: MCPConfigService = Depends(get_mcp_config_service)) \ 18 | -> GlobalResponse: 19 | mcp_servers = await mcp_config_service.get_user_servers(user_id) 20 | return result_utils.build_response(mcp_servers) 21 | 22 | 23 | @router.get("/mcp_server/{_id}") 24 | async def get_server( 25 | _id: int, 26 | mcp_config_service: MCPConfigService = Depends(get_mcp_config_service) 27 | ) -> GlobalResponse: 28 | server = await mcp_config_service.get_server(_id) 29 | return result_utils.build_response(server) 30 | 31 | 32 | @router.post("/mcp_server") 33 | async def add_server( 34 | server: CreateMCPServerDTO, 35 | mcp_config_service: MCPConfigService = Depends(get_mcp_config_service) 36 | ) -> GlobalResponse: 37 | new_server = await mcp_config_service.add_server(server) 38 | return result_utils.build_response(new_server) 39 | 40 | 41 | @router.put("/mcp_server") 42 | async def update_server( 43 | server: MCPServerDTO, 
44 | mcp_config_service: MCPConfigService = Depends(get_mcp_config_service) 45 | ) -> GlobalResponse: 46 | updated_server = await mcp_config_service.update_server(server) 47 | return result_utils.build_response(updated_server) 48 | 49 | 50 | @router.delete("/mcp_server/{_id}") 51 | async def delete_server( 52 | _id: int, 53 | mcp_config_service: MCPConfigService = Depends(get_mcp_config_service) 54 | ) -> GlobalResponse: 55 | await mcp_config_service.delete_server(_id) 56 | return result_utils.build_response(None) 57 | -------------------------------------------------------------------------------- /controller/user_controller.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends 2 | from dto.user_dto import UserResult, UserInit 3 | from core.common.container import Container 4 | from service.user_service import UserService 5 | from typing import List 6 | from core.common.logger import get_logger 7 | from model.user import User 8 | 9 | router = APIRouter(prefix="/user", tags=["Users"]) 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | def get_user_service() -> UserService: 15 | return Container.user_service() # 直接获取已注册的测试服务 16 | 17 | 18 | def convert_to_user_result(users: List[User]) -> List[UserResult]: 19 | return [UserResult(id=user.id, name=user.name, email=user.email) for user in users] 20 | 21 | 22 | @router.get("/users", summary="用户列表", response_model=List[UserResult]) 23 | async def list_users(user_service: UserService = Depends(get_user_service)): 24 | """ 25 | List all users 用户列表 26 | """ 27 | users: List[User] = await user_service.get_users() 28 | user_result = convert_to_user_result(users) 29 | return user_result 30 | 31 | 32 | @router.post("/user", summary="用户注册", response_model=UserResult) 33 | async def create_user(user: UserInit, user_service: UserService = Depends(get_user_service)): 34 | """ 35 | 创建用户 36 | :user: 用户初始信息 37 | """ 38 | rs = await 
from dependency_injector import containers, providers

from extensions.ext_database import DatabaseProvider
from dao.user_dao import UserDAO
from dao.chat_window_dao import ChatWindowDAO
from dao.mcp_server_dao import MCPServerDAO
from service.chat_service import ChatService
from service.chat_window_service import ChatWindowService
from service.mcp_config_service import MCPConfigService
from service.user_service import UserService
from core.llm.qwen_open_ai import QwenLlm


class Container(containers.DeclarativeContainer):
    """Dependency-injection container wiring DAOs, services and the LLM.

    Every provider below is a Singleton, so one shared instance serves the
    whole application for the container's lifetime.
    """

    # Database session provider; its session_factory is handed to every DAO.
    database_provider = providers.Singleton(DatabaseProvider)

    # User DAO.
    user_dao = providers.Singleton(UserDAO, session_factory=database_provider.provided.session_factory)
    # User service built on the user DAO.
    user_service = providers.Singleton(UserService, user_dao=user_dao)

    # MCP server DAO.
    mcp_server_dao = providers.Singleton(MCPServerDAO, session_factory=database_provider.provided.session_factory)
    # MCP configuration service built on the MCP server DAO.
    mcp_config_service = providers.Singleton(MCPConfigService, mcp_server_dao=mcp_server_dao)

    # Chat-window DAO.
    chat_window_dao = providers.Singleton(ChatWindowDAO, session_factory=database_provider.provided.session_factory)
    # Chat-window service built on the chat-window DAO.
    chat_window_service = providers.Singleton(ChatWindowService, chat_window_dao=chat_window_dao)

    # LLM backend.
    # NOTE(review): hard-wired to QwenLlm here — confirm this is intended
    # rather than selecting an implementation through LLMManager.
    llm = providers.Singleton(QwenLlm)

    # Chat service combining the LLM, MCP configuration and chat-window persistence.
    chat_service = providers.Singleton(ChatService, llm=llm, mcp_config_service=mcp_config_service, chat_window_dao=chat_window_dao)
["console", "file"], # 使用与根日志器相同的处理器 48 | "level": "WARNING", # SQLAlchemy 日志级别 49 | "propagate": False, # 禁止传播到根日志器 50 | }, 51 | "sqlalchemy.engine.Engine": { # 配置 sqlalchemy.engine.Engine 日志 52 | "handlers": ["console", "file"], # 使用与根日志器相同的处理器 53 | "level": "WARNING", # sqlalchemy.engine.Engine 日志级别 54 | "propagate": False, # 禁止传播到根日志器 55 | }, 56 | "uvicorn.access": { # FastAPI 的访问日志 57 | "handlers": ["console", "file"], 58 | "level": "INFO", 59 | "propagate": False, 60 | }, 61 | }, 62 | } 63 | 64 | # 应用日志配置 65 | logging.config.dictConfig(LOGGING_CONFIG) 66 | 67 | # 日志获取函数 68 | def get_logger(name): 69 | return logging.getLogger(name) 70 | 71 | def logger(cls): 72 | cls.log = get_logger(cls.__name__) 73 | return cls -------------------------------------------------------------------------------- /core/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isoftstone-data-intelligence-ai/efflux-backend/c25742b0c76663265e42fdc67385a4bf78ee15c3/core/llm/__init__.py -------------------------------------------------------------------------------- /core/llm/azure_open_ai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from langchain_core.language_models import LanguageModelLike 3 | from core.llm.llm_chat import LLMChat 4 | from langchain_openai import AzureChatOpenAI 5 | from typing import Union 6 | from dotenv import load_dotenv 7 | 8 | # 加载环境变量 9 | load_dotenv() 10 | 11 | 12 | class AzureLlm(LLMChat): 13 | """ 14 | AzureLlm 类 - 封装 Azure OpenAI 服务作为语言模型 15 | 16 | 该类继承自 LLMChat,提供了与 Azure OpenAI 集成的具体实现。 17 | 它通过读取环境变量获取 Azure OpenAI 服务的配置信息,并创建语言模型实例。 18 | """ 19 | ENABLED = True # 指定该模型是否启用 20 | 21 | def is_enable(self) -> bool: 22 | """ 23 | 检查当前模型是否启用 24 | 25 | Returns: 26 | bool: 返回 True 表示模型启用 27 | """ 28 | return True 29 | 30 | def name(self) -> Union[str, None]: 31 | """ 32 | 获取模型的名称 33 | 34 | Returns: 35 | Union[str, None]: 
模型名称,返回 "AzureLlm" 表示当前模型。 36 | """ 37 | return "AzureLlm" 38 | 39 | def get_llm_model(self) -> Union[LanguageModelLike, None]: 40 | """ 41 | 获取 Azure OpenAI 服务的语言模型实例 42 | 43 | 通过环境变量加载 Azure OpenAI 服务的配置,创建一个 AzureChatOpenAI 实例。 44 | 45 | Returns: 46 | Union[LanguageModelLike, None]: 返回一个 AzureChatOpenAI 实例,用于流式生成响应。 47 | """ 48 | return AzureChatOpenAI( 49 | deployment_name=os.environ['DEPLOYMENT_NAME'], # Azure 部署名称 50 | openai_api_key=os.environ['AZURE_API_KEY'], # Azure OpenAI API 密钥 51 | azure_endpoint=os.environ['AZURE_ENDPOINT'], # Azure 服务端点 52 | openai_api_version=os.environ['AZURE_API_VERSION'], # Azure OpenAI API 版本 53 | temperature=0.7, # 模型生成文本的随机性参数 54 | streaming=True, # 启用流式响应 55 | ) 56 | -------------------------------------------------------------------------------- /core/llm/doubao_open_ai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from langchain_core.language_models import LanguageModelLike 3 | from core.llm.llm_chat import LLMChat 4 | from langchain_openai import ChatOpenAI 5 | from typing import Union 6 | from dotenv import load_dotenv 7 | 8 | # 加载环境变量 9 | load_dotenv() 10 | 11 | 12 | class DoubaoLlm(LLMChat): 13 | """ 14 | DoubaoLlm 类 - 封装豆包自研语言模型 15 | 16 | 该类继承自 LLMChat,提供了与豆包自研语言模型(Doubao LLM)的集成实现。 17 | 它通过环境变量加载配置信息,并创建对应的语言模型实例。 18 | """ 19 | ENABLED = True # 指定该模型是否启用 20 | 21 | def is_enable(self) -> bool: 22 | """ 23 | 检查当前模型是否启用 24 | 25 | Returns: 26 | bool: 返回 True 表示模型启用 27 | """ 28 | return True 29 | 30 | def name(self) -> Union[str, None]: 31 | """ 32 | 获取模型的名称 33 | 34 | Returns: 35 | Union[str, None]: 模型名称,返回 "DoubaoLlm" 表示当前模型。 36 | """ 37 | return "DoubaoLlm" 38 | 39 | def get_llm_model(self) -> Union[LanguageModelLike, None]: 40 | """ 41 | 获取豆包语言模型实例 42 | 43 | 通过环境变量加载豆包语言模型的配置,创建一个 ChatOpenAI 实例。 44 | 45 | Returns: 46 | Union[LanguageModelLike, None]: 返回一个 ChatOpenAI 实例,用于流式生成响应。 47 | """ 48 | return ChatOpenAI( 49 | api_key=os.environ['DOUBAO_API_KEY'], # 豆包 API 
from abc import ABC, abstractmethod
from typing import Union, AsyncGenerator, Any, List, Optional
from langchain_core.language_models import LanguageModelLike
from core.llm.llm_message import LLMMessage
from langchain_core.tools import BaseTool
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
from core.common.logger import get_logger

logger = get_logger(__name__)


class LLMChat(ABC):
    """Abstract base class defining the chat interface over a language model.

    Attributes:
        llm_model (LanguageModelLike): the concrete language-model instance,
            obtained from the subclass's get_llm_model() at construction time.
    """

    def __init__(self):
        """Initialize the chat wrapper by building the underlying model."""
        self.llm_model = self.get_llm_model()

    @abstractmethod
    def is_enable(self) -> bool:
        """Report whether this LLM backend is enabled.

        Returns:
            bool: True when the backend is available, False otherwise.
        """

    @abstractmethod
    def name(self) -> Union[str, None]:
        """Return the backend's registry name.

        Returns:
            Union[str, None]: the backend name, or None when unspecified.
        """

    @abstractmethod
    def get_llm_model(self) -> Union[LanguageModelLike, None]:
        """Build/return the underlying language-model instance.

        Returns:
            Union[LanguageModelLike, None]: the model, or None when unavailable.
        """

    async def stream_chat(self, inputs: Union[dict[str, Any], Any], tools: List[BaseTool], callback=None) -> \
            AsyncGenerator[LLMMessage, None]:
        """Stream a ReAct-agent conversation, yielding message and tool-call events.

        Args:
            inputs (Union[dict[str, Any], Any]): user input forwarded to the agent graph.
            tools (List[BaseTool]): tools the agent may invoke.
            callback (callable, optional): awaited once at the end with the
                collected messages / tool calls / tool errors.

        Yields:
            LLMMessage: content chunks (type="message") and tool-call
                descriptions (type="tool_call").
        """
        # Accumulators handed to `callback` when the stream finishes.
        collected_data = {
            "messages": [],     # received assistant content chunks
            "tool_calls": [],   # tool invocation records
            "tool_errors": []   # tool invocation errors
        }

        # Build a ReAct agent combining the model and the given tools.
        graph = create_react_agent(model=self.llm_model, tools=tools)

        async for chunk in graph.astream(inputs, stream_mode=["messages", "values"]):
            # "messages" stream: incremental assistant content
            if isinstance(chunk, tuple) and chunk[0] == "messages":
                message_chunk = chunk[1][0]  # Get the message content
                if isinstance(message_chunk, AIMessageChunk):
                    if message_chunk.content != '':
                        # collect the chat chunk, then forward it downstream
                        collected_data["messages"].append(message_chunk.content)
                        yield LLMMessage(content=message_chunk.content, type="message")
            elif isinstance(chunk, dict) and "messages" in chunk:
                # Print a newline after the complete message
                print("newline\n", flush=True)
            # "values" stream: full state snapshots — used to detect tool calls
            elif isinstance(chunk, tuple) and chunk[0] == "values":
                message = chunk[1]['messages'][-1]
                if isinstance(message, AIMessage) and message.tool_calls:
                    # header line announcing the tool-call section
                    yield LLMMessage(content="Tool Calls: ", type="tool_call")
                    for tc in message.tool_calls:
                        # record the tool call
                        collected_data["tool_calls"].append(tc)
                        # tool name
                        yield LLMMessage(content=tc.get('name', 'Tool'), type="tool_call")
                        if tc.get("error"):
                            # record and forward the tool error
                            collected_data["tool_errors"].append(tc.get("error"))
                            yield LLMMessage(content=tc.get('error'), type="tool_call")
                        # tool arguments (string form or key/value pairs)
                        yield LLMMessage(content=' Args:', type="tool_call")
                        args = tc.get("args")
                        if isinstance(args, str):
                            yield LLMMessage(content=f' {args}', type="tool_call")
                        elif isinstance(args, dict):
                            for arg, value in args.items():
                                yield LLMMessage(content=f' {arg}: {value}', type="tool_call")

        # Hand the accumulated stream data to the caller-supplied callback.
        if callback:
            await callback(collected_data)

    async def normal_chat(self, model_id: int, inputs: str) -> str:
        """Run a single non-streaming completion for `inputs`.

        Args:
            model_id (int): NOTE(review): currently unused — the instance's
                own llm_model is always invoked; confirm intended semantics.
            inputs (str): the user's message text.

        Returns:
            str: the model's reply content.
        """
        result = self.llm_model.invoke([HumanMessage(content=inputs)])
        return result.content
import os
from langchain_core.language_models import LanguageModelLike
from core.llm.llm_chat import LLMChat
from langchain_openai import ChatOpenAI
from typing import Union
from dotenv import load_dotenv

# Load DASHSCOPE/QWEN settings from .env into the process environment.
load_dotenv()


class QwenLlm(LLMChat):
    """LLMChat implementation for the Qwen model via DashScope's
    OpenAI-compatible endpoint.

    Attributes:
        ENABLED (bool): flag advertising that this backend is available.
    """

    ENABLED = True

    def is_enable(self) -> bool:
        """This backend is always enabled.

        Returns:
            bool: always True.
        """
        return True

    def name(self) -> Union[str, None]:
        """Identifier used to look this model up in the LLM registry.

        Returns:
            str: the string "QwenLlm".
        """
        return "QwenLlm"

    def get_llm_model(self) -> Union[LanguageModelLike, None]:
        """Build a streaming ChatOpenAI client pointed at the Qwen endpoint.

        Reads DASHSCOPE_API_KEY, QWEN_BASE_URL and QWEN_MODEL from the
        environment.

        Returns:
            LanguageModelLike: the configured streaming model instance.
        """
        conf = {
            "api_key": os.environ['DASHSCOPE_API_KEY'],
            "base_url": os.environ['QWEN_BASE_URL'],
            "model": os.environ['QWEN_MODEL'],
            "streaming": True,
        }
        return ChatOpenAI(**conf)
from pydantic import BaseModel, Field
from typing import Dict, Optional, List
import json
from core.common.logger import get_logger


logger = get_logger(__name__)
# Path of the MCP server configuration file, relative to the working directory.
config_path = "./server_config.json"


class StdioServerParameters(BaseModel):
    """Parameters needed to launch an MCP server as a stdio child process."""
    # Server name — the key under "mcpServers" in the config file.
    # NOTE(review): load_config_from_file() below leaves this as None — confirm.
    name: Optional[str] = None
    # Executable to run.
    command: str
    # Command-line arguments for the executable.
    args: list[str] = Field(default_factory=list)
    # Extra environment variables for the child process, if any.
    env: Optional[Dict[str, str]] = None


async def load_config_from_file(server_name: str) -> StdioServerParameters:
    """ Load the server configuration from a JSON file. """
    try:
        # debug
        logger.debug(f"Loading config from {config_path}")

        # Read the configuration file
        with open(config_path, "r") as config_file:
            config = json.load(config_file)

        # Retrieve the server configuration
        server_config = config.get("mcpServers", {}).get(server_name)
        if not server_config:
            error_msg = f"Server '{server_name}' not found in configuration file."
            logger.error(error_msg)
            raise ValueError(error_msg)

        # Construct the server parameters
        result = StdioServerParameters(
            command=server_config["command"],
            args=server_config.get("args", []),
            env=server_config.get("env"),
        )

        # debug
        logger.debug(f"Loaded config from file: command='{result.command}', args={result.args}, env={result.env}")

        # return result
        return result

    except FileNotFoundError:
        # Missing config file: log and re-raise with a clearer message.
        error_msg = f"Configuration file not found: {config_path}"
        logger.error(error_msg)
        raise FileNotFoundError(error_msg)
    except json.JSONDecodeError as e:
        # Malformed JSON: log and re-raise with the original position info.
        error_msg = f"Invalid JSON in configuration file: {e.msg}"
        logger.error(error_msg)
        raise json.JSONDecodeError(error_msg, e.doc, e.pos)
    except ValueError as e:
        # Unknown server name (raised above): log and propagate unchanged.
        logger.error(str(e))
        raise
""" 66 | try: 67 | # debug 68 | logger.debug(f"Loading config from {config_path}") 69 | 70 | # Read the configuration file 71 | with open(config_path, "r") as config_file: 72 | config = json.load(config_file) 73 | 74 | # Retrieve the server configuration 75 | server_configs = config.get("mcpServers", {}) 76 | result = [] 77 | for server_name, server_config in server_configs.items(): 78 | result.append( 79 | StdioServerParameters( 80 | name=server_name, 81 | command=server_config["command"], 82 | args=server_config.get("args", []), 83 | env=server_config.get("env"), 84 | ) 85 | ) 86 | return result 87 | except FileNotFoundError: 88 | # error 89 | error_msg = f"Configuration file not found: {config_path}" 90 | logger.error(error_msg) 91 | raise FileNotFoundError(error_msg) 92 | except json.JSONDecodeError as e: 93 | # json error 94 | error_msg = f"Invalid JSON in configuration file: {e.msg}" 95 | logger.error(error_msg) 96 | raise json.JSONDecodeError(error_msg, e.doc, e.pos) 97 | except ValueError as e: 98 | # error 99 | logger.error(str(e)) 100 | raise 101 | 102 | 103 | async def add_config(server: StdioServerParameters) -> StdioServerParameters: 104 | """ add a server configuration to a JSON file. 
""" 105 | try: 106 | # debug 107 | logger.debug(f"Adding config to {config_path}") 108 | 109 | with open(config_path, "r") as config_file: 110 | config = json.load(config_file) 111 | 112 | server_param = {"command":server.command, "args":server.args} 113 | config["mcpServers"][server.server_name] = server_param 114 | config = json.dumps(config, indent=2, ensure_ascii=False) 115 | 116 | with open(config_path, "w") as config_file: 117 | config_file.write(config) 118 | return server 119 | 120 | except FileNotFoundError: 121 | # error 122 | error_msg = f"Configuration file not found: {config_path}" 123 | logger.error(error_msg) 124 | raise FileNotFoundError(error_msg) 125 | except json.JSONDecodeError as e: 126 | # json error 127 | error_msg = f"Invalid JSON in configuration file: {e.msg}" 128 | logger.error(error_msg) 129 | raise json.JSONDecodeError(error_msg, e.doc, e.pos) 130 | except ValueError as e: 131 | # error 132 | logger.error(str(e)) 133 | raise 134 | 135 | 136 | async def edit_config(server: StdioServerParameters) -> Optional[StdioServerParameters]: 137 | """ edit a server configuration to a JSON file. 
""" 138 | try: 139 | # debug 140 | logger.debug(f"Updating config to {config_path}") 141 | 142 | with open(config_path, "r") as config_file: 143 | config = json.load(config_file) 144 | 145 | if server.server_name not in config["mcpServers"].keys(): 146 | return None 147 | 148 | if "command" in config["mcpServers"][server.server_name].keys(): 149 | config["mcpServers"][server.server_name]["command"] = server.command 150 | if "args" in config["mcpServers"][server.server_name].keys(): 151 | config["mcpServers"][server.server_name]["args"] = server.args 152 | if "env" in config["mcpServers"][server.server_name].keys(): 153 | config["mcpServers"][server.server_name]["env"] = server.env 154 | 155 | config = json.dumps(config, indent=2, ensure_ascii=False) 156 | with open(config_path, "w") as config_file: 157 | config_file.write(config) 158 | 159 | return server 160 | 161 | except FileNotFoundError: 162 | # error 163 | error_msg = f"Configuration file not found: {config_path}" 164 | logger.error(error_msg) 165 | raise FileNotFoundError(error_msg) 166 | except json.JSONDecodeError as e: 167 | # json error 168 | error_msg = f"Invalid JSON in configuration file: {e.msg}" 169 | logger.error(error_msg) 170 | raise json.JSONDecodeError(error_msg, e.doc, e.pos) 171 | except ValueError as e: 172 | # error 173 | logger.error(str(e)) 174 | raise 175 | 176 | 177 | async def delete_config(server_name: str) -> Optional[str]: 178 | """ delete a server configuration in a JSON file. 
""" 179 | try: 180 | # debug 181 | logger.debug(f"Deleting a config in {config_path}") 182 | 183 | with open(config_path, "r") as config_file: 184 | config = json.load(config_file) 185 | 186 | removed_config = config["mcpServers"].pop(server_name, None) 187 | if not removed_config: 188 | return None 189 | 190 | config = json.dumps(config, indent=2, ensure_ascii=False) 191 | with open(config_path, "w") as config_file: 192 | config_file.write(config) 193 | 194 | return server_name 195 | 196 | except FileNotFoundError: 197 | # error 198 | error_msg = f"Configuration file not found: {config_path}" 199 | logger.error(error_msg) 200 | raise FileNotFoundError(error_msg) 201 | except json.JSONDecodeError as e: 202 | # json error 203 | error_msg = f"Invalid JSON in configuration file: {e.msg}" 204 | logger.error(error_msg) 205 | raise json.JSONDecodeError(error_msg, e.doc, e.pos) 206 | except ValueError as e: 207 | # error 208 | logger.error(str(e)) 209 | raise 210 | 211 | -------------------------------------------------------------------------------- /dao/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isoftstone-data-intelligence-ai/efflux-backend/c25742b0c76663265e42fdc67385a4bf78ee15c3/dao/__init__.py -------------------------------------------------------------------------------- /dao/chat_window_dao.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from sqlalchemy.future import select 3 | from datetime import datetime 4 | from model.chat_window import ChatWindow 5 | 6 | 7 | # 会话记录DAO 8 | class ChatWindowDAO: 9 | def __init__(self, session_factory): 10 | self._session_factory = session_factory 11 | 12 | async def create_chat_window(self, user_id, summary, content: Optional[str] = None): 13 | async with self._session_factory() as session: 14 | new_chat_window = ChatWindow( 15 | user_id=user_id, 16 | 
summary=summary, 17 | content=content, 18 | created_at=datetime.now(), 19 | updated_at=datetime.now() 20 | ) 21 | session.add(new_chat_window) 22 | await session.commit() 23 | return new_chat_window 24 | 25 | async def delete_chat_window(self, chat_window_id): 26 | async with self._session_factory() as session: 27 | chat_window = session.query(ChatWindow).get(chat_window_id) 28 | if chat_window: 29 | session.delete(chat_window) 30 | await session.commit() 31 | return chat_window_id 32 | 33 | async def update_chat_window(self, chat_window_id, summary, content): 34 | async with self._session_factory() as session: 35 | result = await session.execute( 36 | select(ChatWindow).where(ChatWindow.id == chat_window_id) 37 | ) 38 | chat_window = result.scalar_one_or_none() 39 | 40 | if chat_window: 41 | if summary: 42 | chat_window.summary = summary 43 | chat_window.content = content 44 | chat_window.updated_at = datetime.now() 45 | await session.commit() 46 | return chat_window_id 47 | 48 | async def get_user_chat_windows(self, user_id) -> List[ChatWindow]: 49 | async with self._session_factory() as session: 50 | result = await session.execute( 51 | select(ChatWindow).where(ChatWindow.user_id == user_id).order_by(ChatWindow.id.desc())) 52 | return result.scalars().all() 53 | 54 | async def get_chat_window_by_id(self, chat_window_id: int) -> ChatWindow: 55 | async with self._session_factory() as session: 56 | result = await session.execute( 57 | select(ChatWindow).where(ChatWindow.id == chat_window_id) 58 | ) 59 | return result.scalar_one_or_none() 60 | -------------------------------------------------------------------------------- /dao/mcp_server_dao.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.future import select 2 | from sqlalchemy import update, delete 3 | from model.mcp_server import McpServer 4 | from typing import Optional, List, Dict, Any 5 | from datetime import datetime 6 | 7 | class MCPServerDAO: 8 | def 
__init__(self, session_factory): 9 | self._session_factory = session_factory 10 | print('mcp server dao init') 11 | 12 | async def get_all_servers(self) -> List[McpServer]: 13 | """获取所有 MCP 服务器""" 14 | async with self._session_factory() as session: 15 | result = await session.execute(select(McpServer)) 16 | return result.scalars().all() 17 | 18 | async def get_server_by_id(self, id: int) -> Optional[McpServer]: 19 | """根据服务器ID获取 MCP 服务器""" 20 | async with self._session_factory() as session: 21 | result = await session.execute( 22 | select(McpServer).where(McpServer.id == id) 23 | ) 24 | return result.scalar_one_or_none() 25 | 26 | async def get_servers_by_user_id(self, user_id: int) -> List[McpServer]: 27 | """根据用户ID获取该用户的所有 MCP 服务器""" 28 | async with self._session_factory() as session: 29 | result = await session.execute( 30 | select(McpServer).where(McpServer.user_id == user_id) 31 | ) 32 | return result.scalars().all() 33 | 34 | async def create_server(self, 35 | user_id: int, 36 | server_name: str, 37 | command: str, 38 | args: List[str], 39 | env: Optional[Dict[str, Any]] = None) -> McpServer: 40 | """创建新的 MCP 服务器""" 41 | async with self._session_factory() as session: 42 | new_server = McpServer( 43 | user_id=user_id, 44 | server_name=server_name, 45 | command=command, 46 | args=args, 47 | env=env, 48 | created_at=datetime.now(), 49 | updated_at=datetime.now() 50 | ) 51 | session.add(new_server) 52 | await session.commit() 53 | return new_server 54 | 55 | async def update_server(self, 56 | id: int, 57 | server_name: Optional[str] = None, 58 | command: Optional[str] = None, 59 | args: Optional[List[str]] = None, 60 | env: Optional[Dict[str, Any]] = None) -> Optional[McpServer]: 61 | """更新 MCP 服务器信息""" 62 | async with self._session_factory() as session: 63 | update_data = {} 64 | if server_name is not None: 65 | update_data['server_name'] = server_name 66 | if command is not None: 67 | update_data['command'] = command 68 | if args is not None: 69 | 
from sqlalchemy.future import select
from model.user import User
from typing import List

class UserDAO:
    """Data-access object for User rows; one async session per call."""

    def __init__(self, session_factory):
        # Async session factory injected by the DI container.
        self._session_factory = session_factory
        print('user dao init')

    async def get_all_users(self) -> List[User]:
        """Return every user in the table."""
        async with self._session_factory() as session:
            rows = await session.execute(select(User))
            # scalars() unwraps the row tuples into User instances.
            return rows.scalars().all()

    async def create_user(self, name: str, email: str, password: str):
        """Insert a new user and return the persisted instance."""
        async with self._session_factory() as session:
            user = User(name=name, email=email, password=password)
            session.add(user)
            await session.commit()
            return user

    async def get_user_by_name(self, user_name: str) -> User:
        """Return the first user whose name matches, or None when absent."""
        async with self._session_factory() as session:
            rows = await session.execute(select(User).where(User.name == user_name))
            return rows.scalars().first()
from pydantic import BaseModel
from typing import Optional

# DTO carrying one chat request from the controller layer.
class ChatDTO(BaseModel):
    # Conversation (chat window) id; None starts a new conversation.
    chat_id: Optional[int] = None
    # Optional system prompt (system_message) for the conversation.
    prompt: Optional[str] = None
    # Id of an already-loaded MCP server selected to run tools.
    server_id: Optional[int] = None
    # The user's message text.
    query: str
    # Prior conversation turns, when the caller supplies them.
    history: Optional[dict] = None
    # NOTE(review): purpose of `api` and `card` is not evident here — confirm with callers.
    api: Optional[str] = None
    card: Optional[str] = None
    # Id of the requesting user; defaults to 1 — presumably a dev default, verify.
    user_id: int = 1
class GlobalResponse(JSONResponse):
    """Uniform API response envelope.

    JSON body fields:
      code        - outer status code
      message     - short description of `code`
      sub_code    - business-level status code
      sub_message - business-level status message
      data        - optional payload
    """

    # Default human-readable description used when no message is supplied.
    message: str = "success"

    def __init__(
        self,
        code: int,
        sub_code: int,
        sub_message: str,
        data: Optional[Any] = None,
        json_encoder: Optional[Type[json.JSONEncoder]] = None,
        message: Optional[str] = None,
        **kwargs
    ):
        content = {
            "code": code,
            # Fix: previously the class-level "success" was emitted even for
            # error envelopes; callers may now override it. Default behavior
            # (message omitted) is unchanged.
            "message": message if message is not None else self.message,
            "sub_code": sub_code,
            "sub_message": sub_message,
            "data": data,
        }

        # Round-trip through the custom encoder so non-JSON-native values
        # (e.g. datetime) are serialized consistently.
        if json_encoder:
            content = json.loads(json_encoder().encode(content))

        super().__init__(content=content, **kwargs)
# Base API exception; specific business exceptions should subclass this.
class BaseAPIException(HTTPException):
    """Base API exception carrying an ExceptionType code/message pair."""

    def __init__(
        self,
        status_code: int = ExceptionType.UNKNOWN_ERROR.code,
        detail: str = ExceptionType.UNKNOWN_ERROR.message
    ):
        super().__init__(status_code=status_code, detail=detail)


async def global_exception_handler(request, exception):
    """Translate a generic HTTPException into the uniform error envelope."""
    return result_utils.build_error_response(
        sub_code=exception.status_code,
        sub_message=exception.detail
    )


async def validate_exception_handler(request, exception):
    """Translate a request-validation failure into the uniform envelope.

    Fix: removed the unused `err = exception.errors()[0]` lookup, which also
    raised IndexError on an empty error list; the full list is returned as-is.
    """
    return result_utils.build_validation_error_response(
        sub_code=400,
        errors=exception.errors()
    )


# Maps exception classes to their handlers; passed to FastAPI at app creation.
global_exception_handlers = {
    HTTPException: global_exception_handler,
    RequestValidationError: validate_exception_handler
}
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from dependency_injector.providers import Singleton
import os
from dotenv import load_dotenv

# Load DATABASE_* settings from .env (see .env.sample).
load_dotenv()

# Build the asyncpg connection string from the environment.
# A KeyError here means a required DATABASE_* variable is missing.
DATABASE_URL = (
    f"postgresql+asyncpg://{os.environ['DATABASE_USERNAME']}:{os.environ['DATABASE_PASSWORD']}"
    f"@{os.environ['DATABASE_HOST']}:{os.environ['DATABASE_PORT']}/{os.environ['DATABASE_NAME']}"
)

# Async engine shared by the whole application.
engine = create_async_engine(DATABASE_URL, future=True, echo=False)

# Session factory. async_sessionmaker is the SQLAlchemy 2.0 replacement for
# sessionmaker(class_=AsyncSession); expire_on_commit=False keeps ORM objects
# usable after commit, which matters in async code.
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False
)

# Declarative base all ORM models inherit from.
Base = declarative_base()


class DatabaseProvider(Singleton):
    """Singleton provider exposing the session factory for dependency injection."""
    session_factory = async_session_factory
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from exception.exception import global_exception_handlers
from controller import chat_controller, user_controller, login_controller, mcp_controller, chat_window_controller

# Application instance wired with the project-wide exception handlers.
app = FastAPI(exception_handlers=global_exception_handlers)

# CORS middleware: currently wide open (all origins, methods and headers).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register every feature router.
for _module in (
    chat_controller,
    user_controller,
    login_controller,
    mcp_controller,
    chat_window_controller,
):
    app.include_router(_module.router)


@app.on_event("startup")
async def init():
    """Startup hook; resource initialization is currently handled elsewhere."""
    print("init app")


@app.get("/")
async def root():
    """Liveness endpoint."""
    return {"message": "Hello World"}
class McpServer(Base):
    """ORM model for an MCP server registered by a user."""

    __tablename__ = 'mcp_servers'

    id = Column(BigInteger, primary_key=True, autoincrement=True, index=True)
    user_id = Column(BigInteger, nullable=False, index=True)
    server_name = Column(String(100), nullable=False)
    command = Column(String(100), nullable=False)
    args = Column(ARRAY(TEXT), nullable=False)
    env = Column(JSON, nullable=True)
    created_at = Column(TIMESTAMP(timezone=True), nullable=True, default=datetime.now)
    updated_at = Column(TIMESTAMP(timezone=True), nullable=True, default=datetime.now, onupdate=datetime.now)

    def model_dump(self) -> dict:
        """Serialize the row to a plain dict (timestamps excluded)."""
        return {
            field: getattr(self, field)
            for field in ("id", "user_id", "server_name", "command", "args", "env")
        }

    def to_mcp_config(self) -> dict:
        """Convert this row into the standard MCP server-config mapping.

        Returns:
            dict shaped as {server_name: {"command": ..., "args": [...], "env": {...}}};
            the "env" key is present only when env is non-empty.
        """
        entry = {"command": self.command, "args": self.args}
        if self.env:
            entry["env"] = self.env
        return {self.server_name: entry}
class User(Base):
    """ORM model mapped to the `account` table."""

    __tablename__ = 'account'

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, unique=True, index=True)
    # Hashed password (the service layer hashes with bcrypt before storing).
    password = Column(String)
    email = Column(String, unique=True, index=True)
    # NOTE(review): column type is Time (time-of-day only); presumably a full
    # DateTime/TIMESTAMP was intended for a creation time — confirm before use.
    create_time = Column(Time)

    def __repr__(self):
        return f"User(id={self.id}, name={self.name}, email={self.email}, create_time={self.create_time})"

    # str(user) and repr(user) intentionally produce the same text.
    __str__ = __repr__
class ChatService:
    """
    ChatService - chat services backed by a language model.

    Talks to the LLM, dynamically loads MCP tools and supports streaming chat,
    while keeping a per-user dialogue history for more natural multi-turn
    conversations.
    """

    def __init__(self, llm: LLMChat, mcp_config_service: MCPConfigService, chat_window_dao: ChatWindowDAO):
        """
        Args:
            llm (LLMChat): language-model manager handling chat logic.
            mcp_config_service (MCPConfigService): provides MCP-server configuration.
            chat_window_dao: persists chat windows and dialogue history.
        """
        self.llm = llm
        self.chat_window_dao = chat_window_dao
        self.mcp_config_service = mcp_config_service
        self.user_history_dict = {}  # in-memory per-user history (process-local)

    async def agent_stream(self, chat_dto: ChatDTO) -> AsyncGenerator[str, None]:
        """
        LangChain streaming agent.

        Dynamically loads the selected MCP server's standardized tools and
        streams the model's reply chunk by chunk.

        Args:
            chat_dto (ChatDTO): request carrying user id, query, prompt, etc.

        Yields:
            str: one JSON-encoded message chunk per iteration.
        """
        # Create a new chat window when the request carries no chat_id.
        # NOTE(review): this branch uses truthiness (`not chat_id`) while the
        # callback below checks `chat_id is None`; a chat_id of 0 would create
        # a window here yet persist under id 0 below — confirm ids start at 1.
        if not chat_dto.chat_id:
            new_chat_window_id = await self.create_chat(chat_dto.user_id, chat_dto.query)
        # Tool list for the langchain agent; empty unless an MCP server is selected.
        tools = []
        if chat_dto.server_id:
            # Loading the config from a JSON file used to be supported; kept for reference:
            # server_params = await load_config_from_file(chat_dto.server_name)

            # Load the MCP-server configuration from the database.
            server_params = await self.mcp_config_service.load_mcp_server_config(chat_dto.server_id)

            # Fetch all tools of the MCP server and convert them to langchain agent tools.
            tools = await convert_mcp_to_langchain_tools([server_params])

        # Callback collecting the data the model returns during the stream.
        async def data_callback(collected_data):
            user_id = chat_dto.user_id
            user_query = chat_dto.query
            print("--->messages:", ''.join(collected_data["messages"]))
            if "messages" in collected_data and collected_data["messages"]:
                assistant_reply = ''.join(collected_data["messages"])

                # Initialize or extend this user's in-memory history.
                if user_id not in self.user_history_dict:
                    self.user_history_dict[user_id] = []
                self.user_history_dict[user_id].append({
                    "user": chat_dto.query,
                    "assistant": assistant_reply
                })
                # Keep only the 3 most recent turns.
                self.user_history_dict[user_id] = self.user_history_dict[user_id][-3:]
                # Persist the turn into the (new or existing) chat window.
                if chat_dto.chat_id is None:
                    chat_window_id = new_chat_window_id
                else:
                    chat_window_id = chat_dto.chat_id
                await self.update_chat_window(chat_window_id, user_query, assistant_reply)
            for tool_call in collected_data["tool_calls"]:
                print("--->tool_call:", tool_call)

        # Build the model input (system prompt + recent history + current query).
        inputs = await self.load_inputs(chat_dto)

        # Stream the model response, forwarding each chunk as one JSON line.
        async for chunk in self.llm.stream_chat(inputs=inputs, tools=tools, callback=data_callback):
            yield json.dumps(chunk.model_dump()) + "\n"

    async def load_inputs(self, chat_dto: ChatDTO) -> dict:
        """
        Assemble the model input: optional system prompt, the last 3 history
        turns for this user, and the current question.

        Args:
            chat_dto (ChatDTO): the chat request.

        Returns:
            dict: {"messages": [(role, content), ...]}.
        """
        # Message list in (role, content) tuple form.
        messages = []
        # Optional system prompt first.
        if chat_dto.prompt:
            messages.append(("system", chat_dto.prompt))

        # Up to the 3 most recent history turns for this user.
        if chat_dto.user_id in self.user_history_dict:
            for record in self.user_history_dict[chat_dto.user_id][-3:]:
                messages.append(("user", record["user"]))
                messages.append(("assistant", record["assistant"]))

        # Finally the current question.
        messages.append(("user", chat_dto.query))

        # Debug: print exactly what will be sent to the model.
        print(">>> 模型输入 messages:")
        for role, content in messages:
            print(f"{role}: {content}")

        return {"messages": messages}

    async def create_chat(self, user_id: int, query: str, chat_messages: Optional[List[str]] = None) -> int:
        """Create a new chat window and return its id; the query doubles as summary."""
        # Model-generated summaries are disabled for now; kept for reference:
        # summary_prompt = "根据会话记录总结出本次会话的概要"
        # summary_query = query + reply
        # summary = self.normal_chat(ChatDTO(prompt=summary_prompt, query=summary_query))
        summary = query
        # Persist the new chat window.
        new_chat_window = await self.chat_window_dao.create_chat_window(user_id=user_id, summary=summary)
        return new_chat_window.id

    async def get_chat_by_id(self, chat_window_id: int) -> ChatWindow:
        """Fetch a chat window by primary key."""
        return await self.chat_window_dao.get_chat_window_by_id(chat_window_id)

    async def update_chat_window(self, chat_window_id: int, query: str, reply: Optional[str] = None):
        """Append one user/assistant turn to an existing chat window's content."""
        # Existing window row.
        old_chat_window: ChatWindow = await self.chat_window_dao.get_chat_window_by_id(chat_window_id)

        # Build the new content entries for this turn.
        content_user = ContentDTO(type="text", text=query)
        content_assistant = ContentDTO(type="text", text=reply)

        # Messages carry a *list* of content parts.
        chat_message_user = ChatMessageDTO(
            role="user",
            content=[content_user]
        )
        chat_message_assistant = ChatMessageDTO(
            role="assistant",
            content=[content_assistant]
        )

        # Serialize to plain dicts so the JSON column can store them.
        new_content = [
            chat_message_user.model_dump(),
            chat_message_assistant.model_dump()
        ]

        # Merge with any previously stored turns.
        if old_chat_window.content is None:
            update_content = new_content
        else:
            old_content = old_chat_window.content
            old_content.extend(new_content)
            update_content = old_content

        # Persist (summary left unchanged).
        await self.chat_window_dao.update_chat_window(
            chat_window_id=chat_window_id,
            summary=None,
            content=update_content
        )

    async def normal_chat(self, chat_dto: ChatDTO) -> str:
        """Plain non-streaming chat, no tools."""
        return await self.llm.normal_chat(1, chat_dto.query)
class ChatWindowService:
    """Service layer for reading a user's chat windows."""

    def __init__(self, chat_window_dao: ChatWindowDAO):
        self.chat_window_dao = chat_window_dao

    async def get_user_chat_windows(self, user_id: int) -> List[ChatWindowDTO]:
        """Return all chat windows belonging to `user_id` as DTOs."""
        models = await self.chat_window_dao.get_user_chat_windows(user_id)
        return await self.convert_models_to_chat_windows(models)

    async def convert_models_to_chat_windows(self, chat_windows: List[ChatWindow]) -> List[ChatWindowDTO]:
        """Map a list of ORM rows to DTOs."""
        return [self.convert_model_to_chat_window_dto(window) for window in chat_windows]

    @staticmethod
    def convert_model_to_chat_window_dto(chat_window: ChatWindow) -> ChatWindowDTO:
        """Map one ORM row to its DTO; a missing content blob becomes []."""
        return ChatWindowDTO(
            id=chat_window.id,
            user_id=chat_window.user_id,
            summary=chat_window.summary,
            chat_messages=chat_window.content or [],
            created_at=chat_window.created_at,
            updated_at=chat_window.updated_at
        )
class MCPConfigService:
    """CRUD and config-loading service for user-registered MCP servers."""

    def __init__(self, mcp_server_dao: MCPServerDAO):
        self.mcp_server_dao = mcp_server_dao

    # Not a staticmethod: the service is registered in the DI container, so an
    # instance always exists.
    def to_dto(self, server: Optional[McpServer]) -> Optional[MCPServerDTO]:
        """Convert an McpServer row (or None) to its DTO (or None)."""
        if not server:
            return None
        # model_dump() yields a plain dict we can unpack into the DTO.
        return MCPServerDTO(**server.model_dump())

    async def get_user_servers(self, user_id: int) -> List[MCPServerDTO]:
        """All MCP servers registered by one user."""
        servers = await self.mcp_server_dao.get_servers_by_user_id(user_id)
        return [self.to_dto(server) for server in servers]

    async def get_server(self, server_id: int) -> Optional[MCPServerDTO]:
        """One server by id.

        Raises:
            BaseAPIException(RESOURCE_NOT_FOUND): when no such server exists.
        """
        server = await self.mcp_server_dao.get_server_by_id(server_id)
        if server is None:
            raise BaseAPIException(
                status_code=ExceptionType.RESOURCE_NOT_FOUND.code,
                detail=ExceptionType.RESOURCE_NOT_FOUND.message
            )
        return self.to_dto(server)

    @staticmethod
    def _validate_name_and_command(server_name: str, command: str) -> None:
        """Reject blank (whitespace-only) server names / commands."""
        if not server_name.strip():
            raise BaseAPIException(
                status_code=ExceptionType.INVALID_PARAM.code,
                detail="MCP server name 不能为空"
            )
        if not command.strip():
            raise BaseAPIException(
                status_code=ExceptionType.INVALID_PARAM.code,
                detail="MCP server command 不能为空"
            )

    async def _ensure_unique_name(self, user_id: int, server_name: str) -> None:
        """Reject a name already used by another server of the same user."""
        existing_servers = await self.get_user_servers(user_id)
        if any(server.server_name == server_name for server in existing_servers):
            raise BaseAPIException(
                status_code=ExceptionType.DUPLICATE_SERVER_NAME.code,
                detail=ExceptionType.DUPLICATE_SERVER_NAME.message
            )

    async def add_server(self, mcp_server: CreateMCPServerDTO) -> MCPServerDTO:
        """Register a new MCP server for a user."""
        self._validate_name_and_command(mcp_server.server_name, mcp_server.command)
        await self._ensure_unique_name(mcp_server.user_id, mcp_server.server_name)

        new_server = await self.mcp_server_dao.create_server(
            user_id=mcp_server.user_id,
            server_name=mcp_server.server_name,
            command=mcp_server.command,
            args=mcp_server.args,
            env=mcp_server.env
        )
        return self.to_dto(new_server)

    async def update_server(self, mcp_server: MCPServerDTO) -> Optional[MCPServerDTO]:
        """Update an existing server; validates fields and name uniqueness."""
        self._validate_name_and_command(mcp_server.server_name, mcp_server.command)

        # get_server raises RESOURCE_NOT_FOUND itself when the id is unknown,
        # so the previous `if not current_server` re-check was unreachable.
        current_server = await self.get_server(mcp_server.id)

        # Only check for a name clash when the name actually changed.
        if current_server.server_name != mcp_server.server_name:
            await self._ensure_unique_name(mcp_server.user_id, mcp_server.server_name)

        updated_server = await self.mcp_server_dao.update_server(
            id=mcp_server.id,
            server_name=mcp_server.server_name,
            command=mcp_server.command,
            args=mcp_server.args,
            env=mcp_server.env
        )
        if not updated_server:
            raise BaseAPIException(
                status_code=ExceptionType.SERVER_UPDATE_FAILED.code,
                detail=ExceptionType.SERVER_UPDATE_FAILED.message
            )
        return self.to_dto(updated_server)

    async def delete_server(self, server_id: int) -> bool:
        """Delete a server; raises if it does not exist or deletion fails."""
        # get_server raises RESOURCE_NOT_FOUND for an unknown id, so no extra
        # existence check is needed here (the old `if not server` was dead code).
        await self.get_server(server_id)

        success = await self.mcp_server_dao.delete_server(server_id)
        if not success:
            raise BaseAPIException(
                status_code=ExceptionType.SERVER_DELETE_FAILED.code,
                detail=ExceptionType.SERVER_DELETE_FAILED.message
            )
        return success

    async def load_mcp_server_config(self, server_id: int) -> StdioServerParameters:
        """Load a server's configuration from the DB as StdioServerParameters."""
        try:
            mcp_server = await self.get_server(server_id)  # fix: `mcp_sever` typo

            result = StdioServerParameters(
                command=mcp_server.command,
                args=mcp_server.args,
                # Extra guard: normalize a stored empty {} env to None.
                env=mcp_server.env if mcp_server.env else None,
            )

            logger.debug(f"Loaded config from DB: command='{result.command}', args={result.args}, env={result.env}")

            return result
        except Exception as e:
            logger.error(str(e))
            raise
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that renders datetime values as ISO-8601 strings."""

    def default(self, obj):
        # Only datetimes get special treatment; everything else falls back to
        # the base class (which raises TypeError for unknown types).
        return obj.isoformat() if isinstance(obj, datetime) else super().default(obj)


def build_response(data: Optional[Any] = None) -> GlobalResponse:
    """Build the uniform success envelope.

    Args:
        data: BaseModel, dict, list or None.

    Returns:
        GlobalResponse: code/sub_code 200 with `data` serialized.
    """
    # Normalize Pydantic models (and lists that may contain them) to dicts.
    if isinstance(data, BaseModel):
        payload = data.model_dump()
    elif isinstance(data, list):
        payload = [item.model_dump() if isinstance(item, BaseModel) else item for item in data]
    else:
        payload = data

    return GlobalResponse(
        code=200,
        sub_code=200,
        sub_message="success",
        data=payload,
        json_encoder=DateTimeEncoder
    )
def build_error_response(sub_code: int, sub_message: str, data: Optional[Any] = None) -> GlobalResponse:
    """Build the uniform error envelope (outer code 500).

    Args:
        sub_code: business error code.
        sub_message: business error message.
        data: optional extra error payload.

    Returns:
        GlobalResponse: the unified error response.
    """
    envelope = dict(
        code=500,
        sub_code=sub_code,
        sub_message=sub_message,
        data=data,
        json_encoder=DateTimeEncoder,
    )
    return GlobalResponse(**envelope)


def build_validation_error_response(sub_code: int, errors: list) -> GlobalResponse:
    """Build the uniform envelope for request-validation failures (outer code 400).

    Args:
        sub_code: business sub-status code.
        errors: list of validation errors.

    Returns:
        GlobalResponse: the unified validation-error response.
    """
    envelope = dict(
        code=400,
        sub_code=sub_code,
        sub_message="参数验证错误",
        data=errors,
        json_encoder=DateTimeEncoder,
    )
    return GlobalResponse(**envelope)