├── README.md
├── backend
│   ├── api
│   │   ├── core
│   │   │   ├── config.py
│   │   │   └── security.py
│   │   ├── deps
│   │   │   └── auth.py
│   │   ├── main.py
│   │   ├── models
│   │   │   ├── analysis.py
│   │   │   ├── chat.py
│   │   │   ├── document.py
│   │   │   ├── search.py
│   │   │   └── user.py
│   │   └── routers
│   │       ├── analysis.py
│   │       ├── chat.py
│   │       ├── documents.py
│   │       └── search.py
│   ├── embeddings
│   │   ├── batch_processor.py
│   │   └── model.py
│   ├── processors
│   │   ├── base.py
│   │   ├── docx_processor.py
│   │   ├── excel_processor.py
│   │   ├── html_processor.py
│   │   └── pdf_processor.py
│   ├── services
│   │   ├── document_service.py
│   │   ├── graphrag_service.py
│   │   ├── llm_service.py
│   │   ├── search_service.py
│   │   └── vector_store.py
│   └── worker
│       ├── main.py
│       └── tasks
│           └── document_processor.py
├── configs
│   ├── api.yaml
│   ├── ollama.yaml
│   ├── qdrant.yaml
│   ├── redis.yaml
│   └── worker.yaml
├── docker-compose.yml
├── docker
│   ├── .env
│   ├── Dockerfile.api
│   ├── Dockerfile.frontend
│   └── Dockerfile.worker
├── frontend
│   ├── app.py
│   ├── components
│   │   ├── auth.py
│   │   └── ui.py
│   └── pages
│       ├── analysis.py
│       ├── chat.py
│       ├── documents.py
│       ├── home.py
│       ├── search.py
│       └── system_status.py
├── requirements.txt
└── scripts
    ├── install.sh
    └── start_services.sh

/README.md:
--------------------------------------------------------------------------------
# DeepSeek Local Knowledge Base System

A high-performance local knowledge base system built on the DeepSeek large language model, supporting multiple document formats, local retrieval, question answering, and analysis.

## Features

- **Local DeepSeek deployment**: no external API calls, so your data stays private
- **Multi-format document processing**: PDF, Word, Excel, TXT, HTML, and more
- **Efficient retrieval**: semantic search across 5000+ documents
- **GraphRAG enhancement**: graph-based retrieval and reasoning
- **Multiple interaction modes**: chat, search, and analysis
- **High-performance architecture**: Redis + Qdrant storage backend
- **Containerized deployment**: simple Docker-based setup

## System Requirements

- **Operating system**: Ubuntu 20.04 LTS or later
- **Hardware**:
  - NVIDIA GPU (RTX 3080 or better recommended)
  - CUDA 11.7+
  - 16 GB RAM (32 GB+ recommended)
  - 20 GB+ free disk space
- **Software**:
  - Docker and Docker Compose
  - NVIDIA Container Toolkit

## Quick Start

### Installation

1. Clone the repository:

   ```bash
   git clone https://github.com/your-username/knowledge-base-system.git
   cd knowledge-base-system
   ```

2. Run the installation script:

   ```bash
   bash scripts/install.sh
   ```

3. Start the services:

   ```bash
   bash scripts/start_services.sh
   ```

### Accessing the system

- Frontend UI: http://localhost:8501
- API docs: http://localhost:8000/docs
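To confirm the backend is reachable before uploading anything, you can call the API's health endpoint (`GET /api/health`, defined in `backend/api/main.py`). A minimal check using the Python `requests` library, assuming the default port above:

```python
import requests

# GET /api/health returns {"status": "healthy"} (see backend/api/main.py).
resp = requests.get("http://localhost:8000/api/health", timeout=5)
resp.raise_for_status()
print(resp.json())  # expected: {'status': 'healthy'}
```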
## Usage Guide

### Uploading documents

1. Open the frontend UI
2. Go to the "Document Management" page
3. Click "Upload Document" and select the files to upload
4. Add optional metadata
5. Click "Upload and Index"

### Chat Q&A

1. Go to the "Chat Q&A" page
2. Type your question into the input box
3. The system retrieves relevant information from your knowledge base and generates an answer

### Search

1. Go to the "Search" page
2. Enter a search query
3. Review the matching document fragments

### Data analysis

1. Go to the "Data Analysis" page
2. Select a document or enter text to analyze
3. Review the generated analysis report and visualizations

## Project Structure

```
knowledge-base-system/
├── docker/      - Docker configuration files
├── backend/     - Backend services
│   ├── api/         - FastAPI application
│   ├── services/    - Core services
│   ├── processors/  - Document processors
│   └── embeddings/  - Embedding module
├── frontend/    - Streamlit frontend
├── configs/     - Configuration files
└── scripts/     - Installation and startup scripts
```

## Troubleshooting

### Common issues

**Services fail to start**
- Check that the Docker daemon is running
- Check that the GPU driver and CUDA are installed correctly
- Inspect the logs: `docker-compose logs -f`

**Documents cannot be uploaded**
- Check that the file format is supported
- Check that the file size is within the limit
- Inspect the API service logs

**Model loading errors**
- Make sure the DeepSeek model has been downloaded correctly
- Check that GPU memory is sufficient
- Inspect the Ollama service logs

## Contributing and Support

Bug reports and suggestions are welcome via GitHub Issues.

## License

This project is licensed under the MIT License; see the LICENSE file for details.
--------------------------------------------------------------------------------
/backend/api/core/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Union, Optional, Dict, Any
3 | from pydantic import BaseModel, validator
4 | 
5 | class ServerSettings(BaseModel):
6 |     """服务器配置模型"""
7 |     host: str
8 |     port: int
9 |     workers: int
10 |     log_level: str
11 |     debug: bool
12 |     reload: bool
13 | 
14 | class CorsSettings(BaseModel):
15 |     """CORS配置模型"""
16 |     allowed_origins: List[str]
17 |     allowed_methods: List[str]
18 |     allowed_headers: List[str]
19 | 
20 | class SecuritySettings(BaseModel):
21 |     """安全配置模型"""
22 |     api_key_header: str
23 |     api_key: str
24 |     jwt_secret: str
25 |     token_expire_minutes: int
26 | 
27 | class RateLimitSettings(BaseModel):
28 |     """速率限制配置模型"""
29 |     enabled: bool
30 |     max_requests: int
31 |     time_window_seconds: int
32 | 
33 | class Settings(BaseModel):
34 |     """全局配置设置模型"""
35 |     server: ServerSettings
36 |     cors: CorsSettings
37 |     security: SecuritySettings
38 |     rate_limiting: RateLimitSettings
39 | 
40 |     @classmethod
41 |     def from_yaml(cls, config_dict: Dict[str, Any]) -> "Settings":
42 |         """从YAML配置字典创建设置对象"""
43 |         return cls(
44 |             server=ServerSettings(**config_dict.get("server", {})),
45 |             cors=CorsSettings(**config_dict.get("cors", {})),
46 |             security=SecuritySettings(**config_dict.get("security", {})),
47 |             rate_limiting=RateLimitSettings(**config_dict.get("rate_limiting", {}))
48 |         )
49 | 
50 |     class Config:
51 |         env_file = ".env"
--------------------------------------------------------------------------------
/backend/api/core/security.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timedelta
2 | from typing import Optional, Dict, Any
3 | import jwt
4 | from fastapi import Depends, HTTPException, status
5 | from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
6 | from passlib.context import CryptContext
7 | from ..models.user import User, TokenData
8 | 
9 | # 密码处理上下文
10 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") 11 | 12 | # OAuth2配置 13 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") 14 | 15 | # API密钥配置 16 | api_key_scheme = APIKeyHeader(name="X-API-Key") 17 | 18 | # 从配置加载设置 19 | try: 20 | import yaml 21 | with open("configs/api.yaml", "r") as f: 22 | config = yaml.safe_load(f) 23 | SECRET_KEY = config["security"]["jwt_secret"] 24 | ALGORITHM = "HS256" 25 | ACCESS_TOKEN_EXPIRE_MINUTES = config["security"]["token_expire_minutes"] 26 | API_KEY = config["security"]["api_key"] 27 | except: 28 | # 默认值,实际生产中应该从环境变量或安全存储加载 29 | SECRET_KEY = "your-jwt-secret-here" 30 | ALGORITHM = "HS256" 31 | ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24小时 32 | API_KEY = "your-api-key-here" 33 | 34 | def verify_password(plain_password: str, hashed_password: str) -> bool: 35 | """验证密码""" 36 | return pwd_context.verify(plain_password, hashed_password) 37 | 38 | def get_password_hash(password: str) -> str: 39 | """获取密码哈希""" 40 | return pwd_context.hash(password) 41 | 42 | def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: 43 | """创建访问令牌""" 44 | to_encode = data.copy() 45 | 46 | # 设置过期时间 47 | if expires_delta: 48 | expire = datetime.utcnow() + expires_delta 49 | else: 50 | expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) 51 | 52 | to_encode.update({"exp": expire}) 53 | 54 | # 编码JWT 55 | encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) 56 | return encoded_jwt 57 | 58 | def decode_token(token: str) -> TokenData: 59 | """解码令牌""" 60 | try: 61 | # 解码JWT 62 | payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) 63 | username = payload.get("sub") 64 | exp = payload.get("exp") 65 | 66 | if username is None: 67 | raise HTTPException( 68 | status_code=status.HTTP_401_UNAUTHORIZED, 69 | detail="无效的认证凭证", 70 | headers={"WWW-Authenticate": "Bearer"}, 71 | ) 72 | 73 | return TokenData(username=username, exp=exp) 74 | except jwt.PyJWTError: 75 | raise HTTPException( 76 | status_code=status.HTTP_401_UNAUTHORIZED, 77 | detail="无效的认证凭证", 78 | headers={"WWW-Authenticate": "Bearer"}, 79 | ) -------------------------------------------------------------------------------- /backend/api/deps/auth.py: -------------------------------------------------------------------------------- 1 | from fastapi import Depends, HTTPException, status, WebSocket 2 | from fastapi.security import OAuth2PasswordBearer, APIKeyHeader 3 | from typing import Optional 4 | from ..core.security import decode_token, API_KEY 5 | from ..models.user import User, TokenData 6 | 7 | # OAuth2配置 8 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") 9 | 10 | # API密钥配置 11 | api_key_scheme = APIKeyHeader(name="X-API-Key") 12 | 13 | # 模拟用户数据库(实际应用中应从数据库获取) 14 | fake_users_db = { 15 | "admin": { 16 | "id": "user_001", 17 | "username": "admin", 18 | "email": "admin@example.com", 19 | "full_name": "Admin User", 20 | "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", # "password" 21 | "disabled": False, 22 | "role": "admin" 23 | }, 24 | "user": { 25 | "id": "user_002", 26 | "username": "user", 27 | "email": "user@example.com", 28 | "full_name": "Normal User", 29 | "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", # "password" 30 | "disabled": False, 31 | "role": "user" 32 | } 33 | } 34 | 35 | async def get_user_by_token(token_data: TokenData) -> User: 36 | """通过令牌数据获取用户""" 37 | if token_data.username not in fake_users_db: 38 | raise HTTPException( 39 
| status_code=status.HTTP_404_NOT_FOUND, 40 | detail="用户不存在" 41 | ) 42 | 43 | user_data = fake_users_db[token_data.username] 44 | return User(**user_data) 45 | 46 | async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: 47 | """获取当前用户""" 48 | # 尝试解码令牌 49 | token_data = decode_token(token) 50 | 51 | # 获取用户 52 | user = await get_user_by_token(token_data) 53 | 54 | if user.disabled: 55 | raise HTTPException( 56 | status_code=status.HTTP_403_FORBIDDEN, 57 | detail="用户已禁用" 58 | ) 59 | 60 | return user 61 | 62 | async def validate_api_key(api_key: str = Depends(api_key_scheme)) -> bool: 63 | """验证API密钥""" 64 | if api_key != API_KEY: 65 | raise HTTPException( 66 | status_code=status.HTTP_401_UNAUTHORIZED, 67 | detail="无效的API密钥" 68 | ) 69 | return True 70 | 71 | async def get_token_from_websocket(websocket: WebSocket) -> Optional[str]: 72 | """从WebSocket连接获取令牌""" 73 | # 尝试从查询参数获取令牌 74 | token = websocket.query_params.get("token") 75 | 76 | # 如果查询参数中没有令牌,尝试从头部获取 77 | if not token: 78 | # 从Authorization头部获取 79 | auth_header = websocket.headers.get("authorization") 80 | if auth_header and auth_header.startswith("Bearer "): 81 | token = auth_header.replace("Bearer ", "") 82 | 83 | return token -------------------------------------------------------------------------------- /backend/api/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from fastapi import FastAPI, Depends, HTTPException 4 | from fastapi.middleware.cors import CORSMiddleware 5 | import yaml 6 | from .routers import documents, search, chat, analysis 7 | from .core.config import Settings 8 | 9 | # 加载配置 10 | config_path = os.getenv("API_CONFIG_PATH", "configs/api.yaml") 11 | with open(config_path, "r") as f: 12 | config = yaml.safe_load(f) 13 | 14 | settings = Settings(**config) 15 | 16 | # 设置日志 17 | logging.basicConfig( 18 | level=getattr(logging, settings.server.log_level.upper()), 19 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 20 | ) 21 | logger = logging.getLogger(__name__) 22 | 23 | # 初始化FastAPI应用 24 | app = FastAPI( 25 | title="知识库系统API", 26 | description="高性能本地知识库系统的API", 27 | version="1.0.0", 28 | ) 29 | 30 | # 设置CORS 31 | app.add_middleware( 32 | CORSMiddleware, 33 | allow_origins=settings.cors.allowed_origins, 34 | allow_credentials=True, 35 | allow_methods=settings.cors.allowed_methods, 36 | allow_headers=settings.cors.allowed_headers, 37 | ) 38 | 39 | # 包含路由器 40 | app.include_router(documents.router, prefix="/api/documents", tags=["documents"]) 41 | app.include_router(search.router, prefix="/api/search", tags=["search"]) 42 | app.include_router(chat.router, prefix="/api/chat", tags=["chat"]) 43 | app.include_router(analysis.router, prefix="/api/analysis", tags=["analysis"]) 44 | 45 | @app.get("/api/health") 46 | async def health_check(): 47 | return {"status": "healthy"} 48 | 49 | @app.get("/") 50 | async def root(): 51 | return { 52 | "message": "知识库系统API正在运行", 53 | "docs_url": "/docs", 54 | "version": app.version 55 | } 56 | 57 | if __name__ == "__main__": 58 | import uvicorn 59 | uvicorn.run( 60 | "main:app", 61 | host=settings.server.host, 62 | port=settings.server.port, 63 | workers=settings.server.workers 64 | ) 65 | -------------------------------------------------------------------------------- /backend/api/models/analysis.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Dict, Any 3 | 4 | 
class DocumentAnalysisRequest(BaseModel): 5 | """文档分析请求模型""" 6 | document_id: str 7 | focus_areas: str = "关键数据点和趋势" 8 | instructions: Optional[str] = None 9 | 10 | class DocumentAnalysisResponse(BaseModel): 11 | """文档分析响应模型""" 12 | document_id: str 13 | analysis: str 14 | summary: str 15 | 16 | class TextAnalysisRequest(BaseModel): 17 | """文本分析请求模型""" 18 | text: str 19 | focus_areas: str = "关键点和主题" 20 | instructions: Optional[str] = None 21 | 22 | class TextAnalysisResponse(BaseModel): 23 | """文本分析响应模型""" 24 | analysis: str 25 | summary: str 26 | 27 | class PredictionRequest(BaseModel): 28 | """预测请求模型""" 29 | historical_data: str 30 | target: str 31 | context: Optional[str] = None 32 | instructions: Optional[str] = None 33 | 34 | class PredictionFactor(BaseModel): 35 | """预测影响因素模型""" 36 | name: str 37 | impact: str # 'positive', 'negative', 'neutral' 38 | weight: float # 0.0 to 1.0 39 | 40 | class PredictionResponse(BaseModel): 41 | """预测响应模型""" 42 | prediction: str 43 | confidence: str # 'high', 'medium', 'low' 44 | factors: List[PredictionFactor] = [] -------------------------------------------------------------------------------- /backend/api/models/chat.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Dict, Any 3 | 4 | class ChatRequest(BaseModel): 5 | """聊天请求模型""" 6 | message: str 7 | system_prompt: Optional[str] = None 8 | use_rag: Optional[bool] = True 9 | user_specific: Optional[bool] = True 10 | document_ids: Optional[List[str]] = None 11 | max_context_chunks: int = 5 12 | 13 | class ChatSource(BaseModel): 14 | """聊天上下文来源模型""" 15 | id: str 16 | text: str 17 | metadata: Dict[str, Any] 18 | 19 | class ChatResponse(BaseModel): 20 | """聊天响应模型""" 21 | message_id: str 22 | answer: str 23 | sources: Optional[List[ChatSource]] = [] 24 | 25 | class ChatHistoryItem(BaseModel): 26 | """聊天历史记录项模型""" 27 | role: str # 'user' 或 'assistant' 28 | content: str 29 | message_id: Optional[str] = None 30 | timestamp: Optional[float] = None 31 | 32 | class ChatSessionResponse(BaseModel): 33 | """聊天会话响应模型""" 34 | session_id: str 35 | history: List[ChatHistoryItem] 36 | created_at: float 37 | updated_at: float -------------------------------------------------------------------------------- /backend/api/models/document.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Dict, Any 3 | from datetime import datetime 4 | from enum import Enum 5 | 6 | class DocumentStatus(str, Enum): 7 | PROCESSING = "processing" 8 | CHUNKING = "chunking" 9 | EMBEDDING = "embedding" 10 | INDEXING = "indexing" 11 | INDEXED = "indexed" 12 | ERROR = "error" 13 | 14 | class DocumentMetadata(BaseModel): 15 | """文档元数据模型""" 16 | id: str 17 | filename: str 18 | file_extension: str 19 | mime_type: str 20 | file_size: int 21 | file_hash: str 22 | upload_date: datetime 23 | user_id: str 24 | status: DocumentStatus 25 | chunks_count: Optional[int] = 0 26 | text_length: Optional[int] = 0 27 | processing_error: Optional[str] = None 28 | extracted_metadata: Optional[Dict[str, Any]] = None 29 | custom_metadata: Optional[Dict[str, Any]] = None 30 | 31 | class DocumentResponse(BaseModel): 32 | """文档响应模型""" 33 | document: DocumentMetadata 34 | message: Optional[str] = None 35 | 36 | class DocumentListResponse(BaseModel): 37 | """文档列表响应模型""" 38 | documents: List[DocumentMetadata] 39 | total: int 40 | limit: int 41 | 
offset: int 42 | 43 | class DocumentStatusResponse(BaseModel): 44 | """任务状态响应模型""" 45 | task_id: str 46 | status: Dict[str, Any] -------------------------------------------------------------------------------- /backend/api/models/search.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Dict, Any 3 | 4 | class SearchRequest(BaseModel): 5 | """搜索请求模型""" 6 | query: str 7 | user_id: Optional[str] = None 8 | limit: int = 10 9 | offset: int = 0 10 | use_hybrid: bool = True 11 | filters: Optional[Dict[str, Any]] = None 12 | 13 | class SearchResult(BaseModel): 14 | """搜索结果项模型""" 15 | id: str 16 | score: float 17 | text: str 18 | metadata: Dict[str, Any] 19 | 20 | class SearchResponse(BaseModel): 21 | """搜索响应模型""" 22 | results: List[SearchResult] 23 | count: int 24 | 25 | class RelatedQueryResponse(BaseModel): 26 | """相关查询响应模型""" 27 | queries: List[str] 28 | count: int -------------------------------------------------------------------------------- /backend/api/models/user.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field, EmailStr 2 | from typing import Optional, List, Dict, Any 3 | from datetime import datetime 4 | 5 | class User(BaseModel): 6 | """用户数据模型""" 7 | id: str 8 | username: str 9 | email: EmailStr 10 | full_name: Optional[str] = None 11 | disabled: Optional[bool] = False 12 | role: str = "user" 13 | created_at: datetime = Field(default_factory=datetime.now) 14 | last_login: Optional[datetime] = None 15 | 16 | class UserCreate(BaseModel): 17 | """创建用户请求模型""" 18 | username: str 19 | email: EmailStr 20 | password: str 21 | full_name: Optional[str] = None 22 | 23 | class UserUpdate(BaseModel): 24 | """更新用户请求模型""" 25 | email: Optional[EmailStr] = None 26 | full_name: Optional[str] = None 27 | password: Optional[str] = None 28 | disabled: Optional[bool] = None 29 | role: Optional[str] = None 30 | 31 | class UserInDB(User): 32 | """数据库存储的用户模型,包含哈希密码""" 33 | hashed_password: str 34 | 35 | class Token(BaseModel): 36 | """认证令牌模型""" 37 | access_token: str 38 | token_type: str 39 | 40 | class TokenData(BaseModel): 41 | """认证令牌数据模型""" 42 | username: Optional[str] = None 43 | exp: Optional[int] = None -------------------------------------------------------------------------------- /backend/api/routers/analysis.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Depends, Query 2 | from typing import List, Optional, Dict, Any 3 | import logging 4 | from ..deps.auth import get_current_user 5 | from ..models.user import User 6 | from ..models.analysis import ( 7 | DocumentAnalysisRequest, 8 | DocumentAnalysisResponse, 9 | TextAnalysisRequest, 10 | TextAnalysisResponse, 11 | PredictionRequest, 12 | PredictionResponse 13 | ) 14 | from ...services.llm_service import LLMService 15 | from ...services.vector_store import VectorStore 16 | from ...services.search_service import SearchService 17 | 18 | router = APIRouter() 19 | logger = logging.getLogger(__name__) 20 | 21 | # 初始化服务 22 | vector_store = VectorStore() 23 | llm_service = LLMService() 24 | search_service = SearchService(vector_store, llm_service) 25 | 26 | @router.post("/document", response_model=DocumentAnalysisResponse) 27 | async def analyze_document( 28 | request: DocumentAnalysisRequest, 29 | current_user: User = Depends(get_current_user) 30 | ): 31 | """分析文档内容""" 32 | try: 
33 | # 创建文档过滤器 34 | filters = {"document_id": request.document_id} 35 | 36 | # 获取文档内容 37 | doc_chunks = search_service.search( 38 | query="", # 空查询,获取所有内容 39 | user_id=current_user.id, 40 | limit=100, # 限制返回块数 41 | filters=filters 42 | ) 43 | 44 | # 提取文档文本 45 | doc_text = " ".join([chunk["text"] for chunk in doc_chunks]) 46 | 47 | # 生成分析提示 48 | analysis_prompt = f"""请分析以下文档内容,重点关注{request.focus_areas}。请给出关键点、主要观点和结论。 49 | 50 | 文档内容: 51 | {doc_text[:8000]} # 防止提示过长 52 | 53 | 分析要求: {request.instructions if request.instructions else "提供全面分析"}""" 54 | 55 | # 调用LLM生成分析 56 | response = llm_service.generate(analysis_prompt) 57 | analysis = response["response"] 58 | 59 | return { 60 | "document_id": request.document_id, 61 | "analysis": analysis, 62 | "summary": analysis[:200] + "..." if len(analysis) > 200 else analysis 63 | } 64 | 65 | except Exception as e: 66 | logger.error(f"文档分析错误: {str(e)}") 67 | raise HTTPException(status_code=500, detail="文档分析失败") 68 | 69 | @router.post("/text", response_model=TextAnalysisResponse) 70 | async def analyze_text( 71 | request: TextAnalysisRequest, 72 | current_user: User = Depends(get_current_user) 73 | ): 74 | """分析任意文本内容""" 75 | try: 76 | # 生成分析提示 77 | analysis_prompt = f"""请分析以下文本内容,重点关注{request.focus_areas}。请给出关键点、主要观点和结论。 78 | 79 | 文本内容: 80 | {request.text[:8000]} # 防止提示过长 81 | 82 | 分析要求: {request.instructions if request.instructions else "提供全面分析"}""" 83 | 84 | # 调用LLM生成分析 85 | response = llm_service.generate(analysis_prompt) 86 | analysis = response["response"] 87 | 88 | return { 89 | "analysis": analysis, 90 | "summary": analysis[:200] + "..." if len(analysis) > 200 else analysis 91 | } 92 | 93 | except Exception as e: 94 | logger.error(f"文本分析错误: {str(e)}") 95 | raise HTTPException(status_code=500, detail="文本分析失败") 96 | 97 | @router.post("/predict", response_model=PredictionResponse) 98 | async def predict( 99 | request: PredictionRequest, 100 | current_user: User = Depends(get_current_user) 101 | ): 102 | """基于历史数据进行预测""" 103 | try: 104 | # 准备预测提示 105 | prediction_prompt = f"""根据以下历史数据和上下文,请预测{request.target}。 106 | 107 | 历史数据: 108 | {request.historical_data[:4000]} # 防止提示过长 109 | 110 | 上下文信息: 111 | {request.context[:4000] if request.context else "无额外上下文信息"} 112 | 113 | 预测目标: {request.target} 114 | 预测说明: {request.instructions if request.instructions else "提供合理预测"}""" 115 | 116 | # 调用LLM生成预测 117 | response = llm_service.generate(prediction_prompt) 118 | prediction_result = response["response"] 119 | 120 | return { 121 | "prediction": prediction_result, 122 | "confidence": "中等", # 假设的置信度,实际上需要更复杂的计算 123 | "factors": [] # 影响因素,可以通过二次分析提取 124 | } 125 | 126 | except Exception as e: 127 | logger.error(f"预测错误: {str(e)}") 128 | raise HTTPException(status_code=500, detail="生成预测失败") 129 | -------------------------------------------------------------------------------- /backend/api/routers/chat.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Depends, WebSocket, WebSocketDisconnect 2 | from fastapi.responses import StreamingResponse 3 | from typing import List, Optional, Dict, Any 4 | import logging 5 | import json 6 | import uuid 7 | import asyncio 8 | from ..deps.auth import get_current_user, get_token_from_websocket 9 | from ..models.user import User 10 | from ..models.chat import ChatRequest, ChatResponse 11 | from ...services.search_service import SearchService 12 | from ...services.llm_service import LLMService 13 | from ...services.vector_store import VectorStore 14 | 
from ...services.graphrag_service import GraphRAGService 15 | 16 | router = APIRouter() 17 | logger = logging.getLogger(__name__) 18 | 19 | # 初始化服务 20 | vector_store = VectorStore() 21 | llm_service = LLMService() 22 | search_service = SearchService(vector_store, llm_service) 23 | graphrag_service = GraphRAGService(vector_store, llm_service) 24 | 25 | # 存储活跃的WebSocket连接 26 | active_connections = {} 27 | 28 | @router.post("/", response_model=ChatResponse) 29 | async def chat_completion( 30 | request: ChatRequest, 31 | current_user: User = Depends(get_current_user) 32 | ): 33 | """标准聊天请求""" 34 | try: 35 | # 确定是否需要RAG 36 | use_rag = request.use_rag if request.use_rag is not None else True 37 | 38 | if use_rag: 39 | # 使用GraphRAG提供上下文 40 | context = graphrag_service.get_context_for_query( 41 | query=request.message, 42 | user_id=current_user.id if request.user_specific else None, 43 | document_ids=request.document_ids, 44 | max_results=request.max_context_chunks 45 | ) 46 | 47 | # 生成带上下文的提示 48 | augmented_prompt = f"""请根据以下上下文和相关信息回答问题。如果上下文信息不足以回答问题,请明确指出。 49 | 50 | 上下文信息: 51 | {context} 52 | 53 | 用户问题: {request.message}""" 54 | 55 | # 调用LLM生成回复 56 | response = llm_service.generate(augmented_prompt, system_prompt=request.system_prompt) 57 | answer = response["response"] 58 | else: 59 | # 直接调用LLM 60 | response = llm_service.generate(request.message, system_prompt=request.system_prompt) 61 | answer = response["response"] 62 | 63 | return { 64 | "message_id": str(uuid.uuid4()), 65 | "answer": answer, 66 | "sources": [{"id": s["id"], "text": s["text"], "metadata": s["metadata"]} for s in context] if use_rag else [] 67 | } 68 | 69 | except Exception as e: 70 | logger.error(f"聊天处理错误: {str(e)}") 71 | raise HTTPException(status_code=500, detail="处理聊天请求失败") 72 | 73 | @router.post("/stream", response_class=StreamingResponse) 74 | async def stream_chat_completion( 75 | request: ChatRequest, 76 | current_user: User = Depends(get_current_user) 77 | ): 78 | """流式聊天响应""" 79 | try: 80 | # 确定是否需要RAG 81 | use_rag = request.use_rag if request.use_rag is not None else True 82 | context = [] 83 | 84 | if use_rag: 85 | # 使用GraphRAG提供上下文 86 | context = graphrag_service.get_context_for_query( 87 | query=request.message, 88 | user_id=current_user.id if request.user_specific else None, 89 | document_ids=request.document_ids, 90 | max_results=request.max_context_chunks 91 | ) 92 | 93 | # 生成带上下文的提示 94 | augmented_prompt = f"""请根据以下上下文和相关信息回答问题。如果上下文信息不足以回答问题,请明确指出。 95 | 96 | 上下文信息: 97 | {context} 98 | 99 | 用户问题: {request.message}""" 100 | prompt = augmented_prompt 101 | else: 102 | prompt = request.message 103 | 104 | # 创建一个流式响应生成器 105 | async def generate(): 106 | # 发送元数据和来源信息 107 | message_id = str(uuid.uuid4()) 108 | metadata = { 109 | "message_id": message_id, 110 | "sources": [{"id": s["id"], "text": s["text"], "metadata": s["metadata"]} for s in context] if use_rag else [] 111 | } 112 | yield f"data: {json.dumps({'type': 'metadata', 'content': metadata})}\n\n" 113 | 114 | # 流式生成内容 115 | for chunk in llm_service.generate_stream(prompt, system_prompt=request.system_prompt): 116 | yield f"data: {json.dumps({'type': 'content', 'content': chunk})}\n\n" 117 | 118 | # 发送完成信号 119 | yield f"data: {json.dumps({'type': 'done'})}\n\n" 120 | 121 | return StreamingResponse( 122 | generate(), 123 | media_type="text/event-stream" 124 | ) 125 | 126 | except Exception as e: 127 | logger.error(f"流式聊天处理错误: {str(e)}") 128 | raise HTTPException(status_code=500, detail="处理聊天请求失败") 129 | 
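The `/stream` endpoint above emits Server-Sent Events: each line is `data: {...}` with a JSON payload whose `type` is `metadata` (message id and RAG sources), `content` (a generated chunk), or `done`. A minimal client sketch for consuming it, assuming the `requests` library, the router mounted at `http://localhost:8000/api/chat` as in `backend/api/main.py`, and an already-issued bearer token (token issuance is not shown in this dump):

```python
import json
import requests

STREAM_URL = "http://localhost:8000/api/chat/stream"  # prefix from backend/api/main.py
TOKEN = "<access-token>"  # hypothetical: supply a JWT accepted by get_current_user

payload = {"message": "Summarize the key risks in the uploaded contracts", "use_rag": True}

with requests.post(STREAM_URL, json=payload, stream=True,
                   headers={"Authorization": f"Bearer {TOKEN}"}) as resp:
    resp.raise_for_status()
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue  # skip blank SSE separator lines
        event = json.loads(raw[len("data: "):])
        if event["type"] == "metadata":
            print("sources:", [s["id"] for s in event["content"]["sources"]])
        elif event["type"] == "content":
            # chunk shape depends on llm_service.generate_stream
            print(event["content"], end="", flush=True)
        elif event["type"] == "done":
            break
```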
-------------------------------------------------------------------------------- /backend/api/routers/documents.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends, Query 2 | from fastapi.responses import JSONResponse 3 | from typing import List, Optional, Dict, Any 4 | import json 5 | import logging 6 | import tempfile 7 | import os 8 | import uuid 9 | from ..deps.auth import get_current_user 10 | from ..models.user import User 11 | from ..models.document import DocumentResponse, DocumentListResponse, DocumentStatusResponse 12 | from ...services.document_service import DocumentService 13 | from ...services.vector_store import VectorStore 14 | 15 | router = APIRouter() 16 | logger = logging.getLogger(__name__) 17 | 18 | # 初始化服务 19 | vector_store = VectorStore() 20 | document_service = DocumentService(vector_store) 21 | 22 | @router.post("/upload", response_model=DocumentResponse) 23 | async def upload_document( 24 | file: UploadFile = File(...), 25 | metadata: Optional[str] = Form(None), 26 | current_user: User = Depends(get_current_user) 27 | ): 28 | """上传文档进行处理和索引""" 29 | try: 30 | # 解析元数据 31 | metadata_dict = json.loads(metadata) if metadata else {} 32 | 33 | # 创建临时文件 34 | with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file: 35 | # 写入上传的文件内容 36 | content = await file.read() 37 | temp_file.write(content) 38 | temp_file_path = temp_file.name 39 | 40 | try: 41 | # 创建任务ID 42 | task_id = str(uuid.uuid4()) 43 | 44 | # 提交处理任务 45 | doc_metadata = document_service.process_document( 46 | file=open(temp_file_path, 'rb'), 47 | filename=file.filename, 48 | user_id=current_user.id, 49 | metadata=metadata_dict 50 | ) 51 | 52 | return JSONResponse( 53 | status_code=202, # Accepted 54 | content={"message": "文档上传已接受处理", "document": doc_metadata, "task_id": task_id} 55 | ) 56 | finally: 57 | # 删除临时文件 58 | os.unlink(temp_file_path) 59 | 60 | except ValueError as e: 61 | raise HTTPException(status_code=400, detail=str(e)) 62 | except Exception as e: 63 | logger.error(f"上传文档错误: {str(e)}") 64 | raise HTTPException(status_code=500, detail=f"文档上传失败: {str(e)}") 65 | 66 | @router.get("/{document_id}", response_model=DocumentResponse) 67 | async def get_document( 68 | document_id: str, 69 | current_user: User = Depends(get_current_user) 70 | ): 71 | """获取文档状态和元数据""" 72 | try: 73 | doc_metadata = document_service.get_document_metadata(document_id) 74 | 75 | # 检查授权 76 | if doc_metadata.get("user_id") != current_user.id: 77 | raise HTTPException(status_code=403, detail="无权访问此文档") 78 | 79 | return {"document": doc_metadata} 80 | 81 | except ValueError as e: 82 | raise HTTPException(status_code=404, detail=str(e)) 83 | except Exception as e: 84 | logger.error(f"获取文档 {document_id} 错误: {str(e)}") 85 | raise HTTPException(status_code=500, detail=f"获取文档失败: {str(e)}") 86 | 87 | @router.get("/", response_model=DocumentListResponse) 88 | async def list_documents( 89 | limit: int = Query(100, ge=1, le=1000), 90 | offset: int = Query(0, ge=0), 91 | category: Optional[str] = None, 92 | status: Optional[str] = None, 93 | current_user: User = Depends(get_current_user) 94 | ): 95 | """列出当前用户的所有文档""" 96 | try: 97 | # 准备过滤器 98 | filters = {"user_id": current_user.id} 99 | if category: 100 | filters["category"] = category 101 | if status: 102 | filters["status"] = status 103 | 104 | # 获取文档列表 105 | documents, total_count = document_service.get_all_documents( 106 | 
filters=filters, 107 | limit=limit, 108 | offset=offset 109 | ) 110 | 111 | return {"documents": documents, "total": total_count, "limit": limit, "offset": offset} 112 | 113 | except Exception as e: 114 | logger.error(f"列出文档错误: {str(e)}") 115 | raise HTTPException(status_code=500, detail=f"获取文档列表失败: {str(e)}") 116 | 117 | @router.delete("/{document_id}") 118 | async def delete_document( 119 | document_id: str, 120 | current_user: User = Depends(get_current_user) 121 | ): 122 | """删除文档及其索引""" 123 | try: 124 | # 检查文档所有权 125 | doc_metadata = document_service.get_document_metadata(document_id) 126 | if doc_metadata.get("user_id") != current_user.id: 127 | raise HTTPException(status_code=403, detail="无权删除此文档") 128 | 129 | # 删除文档 130 | document_service.delete_document(document_id) 131 | 132 | return {"message": f"文档 {document_id} 已成功删除"} 133 | 134 | except ValueError as e: 135 | raise HTTPException(status_code=404, detail=str(e)) 136 | except Exception as e: 137 | logger.error(f"删除文档 {document_id} 错误: {str(e)}") 138 | raise HTTPException(status_code=500, detail=f"删除文档失败: {str(e)}") 139 | 140 | @router.post("/{document_id}/reindex") 141 | async def reindex_document( 142 | document_id: str, 143 | current_user: User = Depends(get_current_user) 144 | ): 145 | """重新索引文档""" 146 | try: 147 | # 检查文档所有权 148 | doc_metadata = document_service.get_document_metadata(document_id) 149 | if doc_metadata.get("user_id") != current_user.id: 150 | raise HTTPException(status_code=403, detail="无权重新索引此文档") 151 | 152 | # 创建重新索引任务 153 | task_id = document_service.reindex_document(document_id) 154 | 155 | return {"message": f"文档 {document_id} 重新索引任务已创建", "task_id": task_id} 156 | 157 | except ValueError as e: 158 | raise HTTPException(status_code=404, detail=str(e)) 159 | except Exception as e: 160 | logger.error(f"重新索引文档 {document_id} 错误: {str(e)}") 161 | raise HTTPException(status_code=500, detail=f"重新索引文档失败: {str(e)}") 162 | 163 | @router.get("/task/{task_id}", response_model=DocumentStatusResponse) 164 | async def get_task_status( 165 | task_id: str, 166 | current_user: User = Depends(get_current_user) 167 | ): 168 | """获取文档处理任务状态""" 169 | try: 170 | task_status = document_service.get_task_status(task_id) 171 | 172 | # 检查任务所有权 173 | if task_status.get("user_id") and task_status.get("user_id") != current_user.id: 174 | raise HTTPException(status_code=403, detail="无权查看此任务状态") 175 | 176 | return {"task_id": task_id, "status": task_status} 177 | 178 | except ValueError as e: 179 | raise HTTPException(status_code=404, detail=str(e)) 180 | except Exception as e: 181 | logger.error(f"获取任务 {task_id} 状态错误: {str(e)}") 182 | raise HTTPException(status_code=500, detail=f"获取任务状态失败: {str(e)}") 183 | -------------------------------------------------------------------------------- /backend/api/routers/search.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Depends, Query 2 | from typing import List, Optional, Dict, Any 3 | import logging 4 | from ..deps.auth import get_current_user 5 | from ..models.user import User 6 | from ..models.search import SearchRequest, SearchResponse 7 | from ...services.search_service import SearchService 8 | from ...services.llm_service import LLMService 9 | from ...services.vector_store import VectorStore 10 | 11 | router = APIRouter() 12 | logger = logging.getLogger(__name__) 13 | 14 | # 初始化服务 15 | vector_store = VectorStore() 16 | llm_service = LLMService() 17 | search_service = SearchService(vector_store, 
llm_service) 18 | 19 | @router.post("/", response_model=SearchResponse) 20 | async def search_documents( 21 | request: SearchRequest, 22 | current_user: User = Depends(get_current_user) 23 | ): 24 | """搜索文档库""" 25 | try: 26 | results = search_service.search( 27 | query=request.query, 28 | user_id=current_user.id, 29 | limit=request.limit, 30 | use_hybrid=request.use_hybrid, 31 | filters=request.filters 32 | ) 33 | 34 | return {"results": results, "count": len(results)} 35 | 36 | except Exception as e: 37 | logger.error(f"搜索错误: {str(e)}") 38 | raise HTTPException(status_code=500, detail="搜索执行失败") 39 | 40 | @router.get("/documents/{document_id}", response_model=SearchResponse) 41 | async def search_within_document( 42 | document_id: str, 43 | query: str = Query(..., min_length=1), 44 | limit: int = Query(10, ge=1, le=100), 45 | current_user: User = Depends(get_current_user) 46 | ): 47 | """在特定文档内搜索""" 48 | try: 49 | # 创建文档过滤器 50 | filters = {"document_id": document_id} 51 | 52 | results = search_service.search( 53 | query=query, 54 | user_id=current_user.id, 55 | limit=limit, 56 | filters=filters 57 | ) 58 | 59 | return {"results": results, "count": len(results)} 60 | 61 | except Exception as e: 62 | logger.error(f"在文档内搜索错误: {str(e)}") 63 | raise HTTPException(status_code=500, detail="搜索执行失败") 64 | -------------------------------------------------------------------------------- /backend/embeddings/batch_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import List, Optional, Dict, Any 5 | import numpy as np 6 | from concurrent.futures import ThreadPoolExecutor 7 | from .model import get_embeddings 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class BatchProcessor: 12 | """批量处理嵌入向量的工具类""" 13 | 14 | def __init__(self, config_path: str = "configs/worker.yaml"): 15 | # 加载配置 16 | with open(config_path, "r") as f: 17 | self.config = yaml.safe_load(f) 18 | 19 | # 获取批处理设置 20 | batch_settings = self.config["embedding"] 21 | self.batch_size = batch_settings["batch_size"] 22 | self.max_workers = batch_settings["max_workers"] 23 | 24 | logger.info(f"嵌入批处理器初始化完成,批大小: {self.batch_size}, 最大工作线程: {self.max_workers}") 25 | 26 | def process_in_batches(self, texts: List[str]) -> List[List[float]]: 27 | """ 28 | 批量处理文本嵌入 29 | 30 | Args: 31 | texts: 要处理的文本列表 32 | 33 | Returns: 34 | List[List[float]]: 嵌入向量列表 35 | """ 36 | all_embeddings = [] 37 | 38 | # 根据批大小拆分文本 39 | for i in range(0, len(texts), self.batch_size): 40 | batch = texts[i:i+self.batch_size] 41 | logger.info(f"处理批次 {i//self.batch_size + 1}/{(len(texts) + self.batch_size - 1)//self.batch_size}, 大小: {len(batch)}") 42 | 43 | # 获取当前批次的嵌入 44 | batch_embeddings = get_embeddings(batch) 45 | all_embeddings.extend(batch_embeddings) 46 | 47 | return all_embeddings 48 | 49 | def process_in_parallel(self, texts: List[str]) -> List[List[float]]: 50 | """ 51 | 并行批量处理文本嵌入 52 | 53 | Args: 54 | texts: 要处理的文本列表 55 | 56 | Returns: 57 | List[List[float]]: 嵌入向量列表 58 | """ 59 | # 根据批大小拆分文本 60 | batches = [texts[i:i+self.batch_size] for i in range(0, len(texts), self.batch_size)] 61 | logger.info(f"拆分为 {len(batches)} 个批次进行并行处理") 62 | 63 | # 并行处理 64 | all_embeddings = [] 65 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 66 | # 提交所有批次任务 67 | futures = [executor.submit(get_embeddings, batch) for batch in batches] 68 | 69 | # 收集结果 70 | for future in futures: 71 | result = future.result() 72 | all_embeddings.extend(result) 73 | 74 | # 
确保结果顺序与输入一致 75 | return all_embeddings 76 | -------------------------------------------------------------------------------- /backend/embeddings/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import List, Optional, Dict, Any 5 | import numpy as np 6 | from ..services.llm_service import LLMService 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | # 全局LLM服务实例 11 | _llm_service = None 12 | 13 | def get_llm_service() -> LLMService: 14 | """获取LLM服务单例""" 15 | global _llm_service 16 | if _llm_service is None: 17 | config_path = os.getenv("OLLAMA_CONFIG_PATH", "configs/ollama.yaml") 18 | _llm_service = LLMService(config_path) 19 | return _llm_service 20 | 21 | def get_embeddings(texts: List[str]) -> List[List[float]]: 22 | """ 23 | 获取文本列表的向量嵌入 24 | 25 | Args: 26 | texts: 文本列表 27 | 28 | Returns: 29 | List[List[float]]: 嵌入向量列表 30 | """ 31 | try: 32 | llm_service = get_llm_service() 33 | embeddings = [] 34 | 35 | for text in texts: 36 | # 截断长文本 37 | max_length = 8192 # 大多数嵌入模型的最大输入长度 38 | if len(text) > max_length: 39 | text = text[:max_length] 40 | 41 | # 获取嵌入 42 | embedding = llm_service.get_embedding(text) 43 | embeddings.append(embedding) 44 | 45 | return embeddings 46 | 47 | except Exception as e: 48 | logger.error(f"获取嵌入向量失败: {str(e)}") 49 | # 返回零向量作为回退方案 50 | return [[0.0] * 768] * len(texts) # 假设维度为768 51 | -------------------------------------------------------------------------------- /backend/processors/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from abc import ABC, abstractmethod 4 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO, Type 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class DocumentProcessor(ABC): 9 | """文档处理器基类,定义处理接口""" 10 | 11 | @abstractmethod 12 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 13 | """ 14 | 处理文档文件 15 | 16 | Args: 17 | file: 文件对象 18 | 19 | Returns: 20 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 21 | """ 22 | pass 23 | 24 | # 存储注册的处理器 25 | _PROCESSORS: Dict[str, Type[DocumentProcessor]] = {} 26 | 27 | def register_processor(extensions: List[str]): 28 | """文档处理器注册装饰器""" 29 | def decorator(processor_class: Type[DocumentProcessor]): 30 | for ext in extensions: 31 | _PROCESSORS[ext.lower()] = processor_class 32 | return processor_class 33 | return decorator 34 | 35 | def get_document_processor(extension: str) -> DocumentProcessor: 36 | """ 37 | 根据文件扩展名获取适当的处理器 38 | 39 | Args: 40 | extension: 文件扩展名(包含点,如'.pdf') 41 | 42 | Returns: 43 | DocumentProcessor: 文档处理器实例 44 | 45 | Raises: 46 | ValueError: 如果找不到适用于该扩展名的处理器 47 | """ 48 | ext = extension.lower() 49 | if ext not in _PROCESSORS: 50 | raise ValueError(f"未找到支持的处理器: {ext}") 51 | 52 | # 实例化处理器 53 | return _PROCESSORS[ext]() 54 | -------------------------------------------------------------------------------- /backend/processors/docx_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO 4 | import tempfile 5 | import docx 6 | import re 7 | from .base import DocumentProcessor, register_processor 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | @register_processor(extensions=['.docx', '.doc']) 12 | class DocxProcessor(DocumentProcessor): 13 | """Word文档处理器""" 14 | 15 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 16 | """ 17 
| 处理Word文件,提取文本和元数据 18 | 19 | Args: 20 | file: Word文件对象 21 | 22 | Returns: 23 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 24 | """ 25 | # 创建临时文件 26 | with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as tmp: 27 | # 写入数据 28 | tmp.write(file.read()) 29 | file.seek(0) # 重置文件指针 30 | tmp_path = tmp.name 31 | 32 | try: 33 | # 打开Word文档 34 | doc = docx.Document(tmp_path) 35 | 36 | # 提取元数据 37 | core_properties = doc.core_properties 38 | metadata = { 39 | "title": core_properties.title if hasattr(core_properties, 'title') else "", 40 | "author": core_properties.author if hasattr(core_properties, 'author') else "", 41 | "comments": core_properties.comments if hasattr(core_properties, 'comments') else "", 42 | "keywords": core_properties.keywords if hasattr(core_properties, 'keywords') else "", 43 | "subject": core_properties.subject if hasattr(core_properties, 'subject') else "", 44 | "last_modified_by": core_properties.last_modified_by if hasattr(core_properties, 'last_modified_by') else "", 45 | "created": str(core_properties.created) if hasattr(core_properties, 'created') else "", 46 | "modified": str(core_properties.modified) if hasattr(core_properties, 'modified') else "", 47 | "paragraph_count": len(doc.paragraphs), 48 | "section_count": len(doc.sections), 49 | } 50 | 51 | # 提取文本 52 | extracted_text = "" 53 | 54 | # 提取标题 55 | if doc.paragraphs and doc.paragraphs[0].style.name.startswith('Heading'): 56 | metadata["document_title"] = doc.paragraphs[0].text 57 | 58 | # 提取段落 59 | for para in doc.paragraphs: 60 | # 添加段落文本 61 | if para.text.strip(): 62 | extracted_text += para.text + "\n\n" 63 | 64 | # 提取表格内容 65 | table_count = 0 66 | for table in doc.tables: 67 | table_count += 1 68 | extracted_text += f"\n--- 表格 {table_count} ---\n" 69 | 70 | for i, row in enumerate(table.rows): 71 | if i == 0: 72 | # 表头 73 | extracted_text += "| " 74 | for cell in row.cells: 75 | extracted_text += cell.text + " | " 76 | extracted_text += "\n" 77 | extracted_text += "|" + "---|" * len(row.cells) + "\n" 78 | else: 79 | # 表体 80 | extracted_text += "| " 81 | for cell in row.cells: 82 | extracted_text += cell.text + " | " 83 | extracted_text += "\n" 84 | 85 | extracted_text += "\n" 86 | 87 | metadata["table_count"] = table_count 88 | 89 | # 清理文本 90 | extracted_text = self._clean_text(extracted_text) 91 | 92 | return extracted_text, metadata 93 | 94 | except Exception as e: 95 | logger.error(f"处理Word文档时出错: {str(e)}") 96 | raise Exception(f"Word文档处理失败: {str(e)}") 97 | finally: 98 | # 清理临时文件 99 | if os.path.exists(tmp_path): 100 | os.unlink(tmp_path) 101 | 102 | def _clean_text(self, text: str) -> str: 103 | """清理提取的文本""" 104 | # 删除多余的空格 105 | text = re.sub(r'\s+', ' ', text) 106 | 107 | # 删除多余的换行符 108 | text = re.sub(r'\n\s*\n', '\n\n', text) 109 | 110 | # 保持段落结构 111 | text = re.sub(r'(\w) *\n *(\w)', r'\1 \2', text) 112 | 113 | return text 114 | -------------------------------------------------------------------------------- /backend/processors/excel_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO 4 | import tempfile 5 | import pandas as pd 6 | from .base import DocumentProcessor, register_processor 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | @register_processor(extensions=['.xlsx', '.xls']) 11 | class ExcelProcessor(DocumentProcessor): 12 | """Excel文档处理器""" 13 | 14 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 15 | """ 16 | 
处理Excel文件,提取文本和元数据 17 | 18 | Args: 19 | file: Excel文件对象 20 | 21 | Returns: 22 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 23 | """ 24 | # 创建临时文件 25 | with tempfile.NamedTemporaryFile(suffix='.xlsx', delete=False) as tmp: 26 | # 写入数据 27 | tmp.write(file.read()) 28 | file.seek(0) # 重置文件指针 29 | tmp_path = tmp.name 30 | 31 | try: 32 | # 使用pandas读取Excel文件 33 | excel_file = pd.ExcelFile(tmp_path) 34 | sheet_names = excel_file.sheet_names 35 | 36 | # 提取元数据 37 | metadata = { 38 | "sheet_count": len(sheet_names), 39 | "sheet_names": sheet_names, 40 | "file_path": tmp_path, 41 | } 42 | 43 | # 提取文本 44 | extracted_text = "" 45 | 46 | # 遍历所有工作表 47 | for sheet_index, sheet_name in enumerate(sheet_names): 48 | df = pd.read_excel(excel_file, sheet_name=sheet_name) 49 | 50 | # 添加工作表标题 51 | extracted_text += f"\n--- 工作表: {sheet_name} ---\n\n" 52 | 53 | # 处理列名(表头) 54 | header_row = "| " + " | ".join(str(col) for col in df.columns) + " |\n" 55 | separator = "|" + "---|" * len(df.columns) + "\n" 56 | extracted_text += header_row + separator 57 | 58 | # 处理数据行 59 | for _, row in df.iterrows(): 60 | row_text = "| " + " | ".join(str(cell) if str(cell) != "nan" else "" for cell in row) + " |\n" 61 | extracted_text += row_text 62 | 63 | extracted_text += "\n" 64 | 65 | # 添加工作表级元数据 66 | metadata[f"sheet_{sheet_index}_rows"] = len(df) 67 | metadata[f"sheet_{sheet_index}_columns"] = len(df.columns) 68 | 69 | # 统计总行数和列数 70 | total_cells = sum(metadata.get(f"sheet_{i}_rows", 0) * metadata.get(f"sheet_{i}_columns", 0) 71 | for i in range(len(sheet_names))) 72 | metadata["total_cells"] = total_cells 73 | 74 | return extracted_text, metadata 75 | 76 | except Exception as e: 77 | logger.error(f"处理Excel文件时出错: {str(e)}") 78 | raise Exception(f"Excel处理失败: {str(e)}") 79 | finally: 80 | # 清理临时文件 81 | if os.path.exists(tmp_path): 82 | os.unlink(tmp_path) -------------------------------------------------------------------------------- /backend/processors/html_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO 4 | import tempfile 5 | from bs4 import BeautifulSoup 6 | import re 7 | from .base import DocumentProcessor, register_processor 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | @register_processor(extensions=['.html', '.htm']) 12 | class HtmlProcessor(DocumentProcessor): 13 | """HTML文档处理器""" 14 | 15 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 16 | """ 17 | 处理HTML文件,提取文本和元数据 18 | 19 | Args: 20 | file: HTML文件对象 21 | 22 | Returns: 23 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 24 | """ 25 | # 读取文件内容 26 | content = file.read().decode('utf-8', errors='replace') 27 | 28 | try: 29 | # 使用BeautifulSoup解析HTML 30 | soup = BeautifulSoup(content, 'html.parser') 31 | 32 | # 提取元数据 33 | metadata = { 34 | "title": self._get_title(soup), 35 | "description": self._get_meta_content(soup, "description"), 36 | "keywords": self._get_meta_content(soup, "keywords"), 37 | "author": self._get_meta_content(soup, "author"), 38 | "links_count": len(soup.find_all('a')), 39 | "images_count": len(soup.find_all('img')), 40 | "tables_count": len(soup.find_all('table')), 41 | "scripts_count": len(soup.find_all('script')), 42 | "styles_count": len(soup.find_all('style')), 43 | } 44 | 45 | # 提取并清理文本 46 | extracted_text = self._extract_text(soup) 47 | 48 | return extracted_text, metadata 49 | 50 | except Exception as e: 51 | logger.error(f"处理HTML文件时出错: {str(e)}") 52 | raise Exception(f"HTML处理失败: 
{str(e)}") 53 | 54 | def _get_title(self, soup: BeautifulSoup) -> str: 55 | """获取HTML标题""" 56 | title_tag = soup.find('title') 57 | return title_tag.get_text() if title_tag else "" 58 | 59 | def _get_meta_content(self, soup: BeautifulSoup, meta_name: str) -> str: 60 | """获取指定名称的meta标签内容""" 61 | meta_tag = soup.find('meta', attrs={'name': meta_name}) 62 | if meta_tag and meta_tag.get('content'): 63 | return meta_tag.get('content') 64 | return "" 65 | 66 | def _extract_text(self, soup: BeautifulSoup) -> str: 67 | """提取并清理HTML文本内容""" 68 | # 删除脚本和样式标签 69 | for script_or_style in soup(["script", "style"]): 70 | script_or_style.decompose() 71 | 72 | # 提取文本 73 | text = soup.get_text() 74 | 75 | # 处理标题 76 | for heading in soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']): 77 | heading_level = int(heading.name[1]) 78 | heading_text = heading.get_text().strip() 79 | text = text.replace(heading_text, f"\n{'#' * heading_level} {heading_text}\n") 80 | 81 | # 处理列表 82 | for ul in soup.find_all('ul'): 83 | for li in ul.find_all('li'): 84 | li_text = li.get_text().strip() 85 | text = text.replace(li_text, f"- {li_text}") 86 | 87 | # 处理有序列表 88 | for ol in soup.find_all('ol'): 89 | for i, li in enumerate(ol.find_all('li')): 90 | li_text = li.get_text().strip() 91 | text = text.replace(li_text, f"{i+1}. {li_text}") 92 | 93 | # 清理空白 94 | lines = (line.strip() for line in text.splitlines()) 95 | chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) 96 | text = '\n'.join(chunk for chunk in chunks if chunk) 97 | 98 | return text -------------------------------------------------------------------------------- /backend/processors/pdf_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO 4 | import tempfile 5 | import fitz # PyMuPDF 6 | import re 7 | from .base import DocumentProcessor, register_processor 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | @register_processor(extensions=['.pdf']) 12 | class PDFProcessor(DocumentProcessor): 13 | """PDF文档处理器""" 14 | 15 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 16 | """ 17 | 处理PDF文件,提取文本和元数据 18 | 19 | Args: 20 | file: PDF文件对象 21 | 22 | Returns: 23 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 24 | """ 25 | # 创建临时文件 26 | with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as tmp: 27 | # 写入数据 28 | tmp.write(file.read()) 29 | file.seek(0) # 重置文件指针 30 | tmp_path = tmp.name 31 | 32 | try: 33 | # 打开PDF文档 34 | doc = fitz.open(tmp_path) 35 | 36 | # 提取元数据 37 | metadata = { 38 | "title": doc.metadata.get("title", ""), 39 | "author": doc.metadata.get("author", ""), 40 | "subject": doc.metadata.get("subject", ""), 41 | "keywords": doc.metadata.get("keywords", ""), 42 | "creator": doc.metadata.get("creator", ""), 43 | "producer": doc.metadata.get("producer", ""), 44 | "creation_date": doc.metadata.get("creationDate", ""), 45 | "modification_date": doc.metadata.get("modDate", ""), 46 | "page_count": len(doc), 47 | } 48 | 49 | # 提取文本 50 | extracted_text = "" 51 | for page_num, page in enumerate(doc): 52 | # 获取页面文本 53 | page_text = page.get_text() 54 | 55 | # 清理文本 56 | page_text = self._clean_text(page_text) 57 | 58 | # 添加页码标记 59 | extracted_text += f"\n--- 页 {page_num + 1} ---\n{page_text}\n" 60 | 61 | # 提取目录(TOC) 62 | toc = doc.get_toc() 63 | if toc: 64 | toc_data = [] 65 | for level, title, page in toc: 66 | toc_data.append({ 67 | "level": level, 68 | "title": title, 69 | "page": page 70 | 
}) 71 | metadata["toc"] = toc_data 72 | 73 | # 处理图像 74 | image_count = 0 75 | for page_num, page in enumerate(doc): 76 | image_list = page.get_images(full=True) 77 | image_count += len(image_list) 78 | 79 | metadata["image_count"] = image_count 80 | 81 | return extracted_text, metadata 82 | 83 | except Exception as e: 84 | logger.error(f"处理PDF时出错: {str(e)}") 85 | raise Exception(f"PDF处理失败: {str(e)}") 86 | finally: 87 | # 清理临时文件 88 | if os.path.exists(tmp_path): 89 | os.unlink(tmp_path) 90 | 91 | def _clean_text(self, text: str) -> str: 92 | """清理提取的文本""" 93 | # 删除多余的空格 94 | text = re.sub(r'\s+', ' ', text) 95 | 96 | # 删除多余的换行符 97 | text = re.sub(r'\n\s*\n', '\n\n', text) 98 | 99 | # 删除页眉页脚(假设出现在每页的前几行和后几行) 100 | lines = text.split('\n') 101 | if len(lines) > 6: 102 | # 保留中间部分,去掉可能的页眉页脚 103 | text = '\n'.join(lines) 104 | 105 | return text 106 | -------------------------------------------------------------------------------- /backend/services/document_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import Dict, List, Optional, Any, BinaryIO, Tuple 5 | import uuid 6 | import hashlib 7 | import mimetypes 8 | from datetime import datetime 9 | import redis 10 | import json 11 | from .vector_store import VectorStore 12 | from ..processors.base import get_document_processor 13 | from ..embeddings.model import get_embeddings 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | class DocumentService: 18 | def __init__(self, vector_store: VectorStore, config_path: str = "configs/worker.yaml", 19 | redis_config_path: str = "configs/redis.yaml"): 20 | # 加载配置 21 | with open(config_path, "r") as f: 22 | self.config = yaml.safe_load(f) 23 | 24 | with open(redis_config_path, "r") as f: 25 | redis_config = yaml.safe_load(f) 26 | 27 | # 初始化Redis客户端 28 | self.redis = redis.Redis( 29 | host=redis_config["redis"]["host"], 30 | port=redis_config["redis"]["port"], 31 | db=redis_config["redis"]["db"], 32 | password=redis_config["redis"]["password"], 33 | decode_responses=True 34 | ) 35 | 36 | # 存储向量存储引用 37 | self.vector_store = vector_store 38 | 39 | # 加载文档处理设置 40 | self.doc_settings = self.config["document_processing"] 41 | self.supported_formats = self.doc_settings["supported_formats"] 42 | self.chunk_size = self.doc_settings["chunk_size"] 43 | self.chunk_overlap = self.doc_settings["chunk_overlap"] 44 | 45 | logger.info(f"文档服务初始化完成,支持格式: {self.supported_formats}") 46 | 47 | def process_document(self, file: BinaryIO, filename: str, user_id: str, 48 | metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: 49 | """处理文档文件,提取文本并准备索引""" 50 | # 验证文件格式 51 | ext = os.path.splitext(filename)[1].lower() 52 | if ext not in self.supported_formats: 53 | raise ValueError(f"不支持的文件格式: {ext}. 
支持的格式: {self.supported_formats}") 54 | 55 | # 生成文档ID和文件哈希 56 | doc_id = str(uuid.uuid4()) 57 | file_content = file.read() 58 | file_hash = hashlib.sha256(file_content).hexdigest() 59 | file.seek(0) 60 | 61 | # 确定MIME类型 62 | mime_type, _ = mimetypes.guess_type(filename) 63 | if not mime_type: 64 | mime_type = "application/octet-stream" 65 | 66 | # 创建文档元数据 67 | doc_metadata = { 68 | "id": doc_id, 69 | "filename": filename, 70 | "file_extension": ext, 71 | "mime_type": mime_type, 72 | "file_size": len(file_content), 73 | "file_hash": file_hash, 74 | "upload_date": datetime.now().isoformat(), 75 | "user_id": user_id, 76 | "status": "processing", 77 | "chunks_count": 0 78 | } 79 | 80 | # 添加自定义元数据 81 | if metadata: 82 | doc_metadata.update({"custom_metadata": metadata}) 83 | 84 | # 临时存储文档元数据 85 | self.redis.set( 86 | f"doc:{doc_id}:metadata", 87 | json.dumps(doc_metadata), 88 | ex=3600 # 1小时过期 89 | ) 90 | 91 | try: 92 | # 获取适当的处理器 93 | processor = get_document_processor(ext) 94 | 95 | # 提取文本和元数据 96 | extracted_text, extracted_metadata = processor.process(file) 97 | 98 | # 更新元数据 99 | doc_metadata.update({ 100 | "extracted_metadata": extracted_metadata, 101 | "text_length": len(extracted_text), 102 | "status": "chunking" 103 | }) 104 | 105 | # 更新Redis 106 | self.redis.set( 107 | f"doc:{doc_id}:metadata", 108 | json.dumps(doc_metadata), 109 | ex=3600 110 | ) 111 | 112 | # 分块文档 113 | chunks = self._chunk_document(extracted_text, doc_id) 114 | 115 | # 更新元数据 116 | doc_metadata.update({ 117 | "chunks_count": len(chunks), 118 | "status": "embedding" 119 | }) 120 | 121 | # 更新Redis 122 | self.redis.set( 123 | f"doc:{doc_id}:metadata", 124 | json.dumps(doc_metadata), 125 | ex=3600 126 | ) 127 | 128 | # 处理文档嵌入和索引 129 | self._process_embeddings_and_index(chunks, doc_id, doc_metadata) 130 | 131 | # 更新最终状态 132 | doc_metadata["status"] = "indexed" 133 | self.redis.set( 134 | f"doc:{doc_id}:metadata", 135 | json.dumps(doc_metadata), 136 | ex=3600 137 | ) 138 | 139 | # 返回元数据 140 | return doc_metadata 141 | 142 | except Exception as e: 143 | # 更新元数据错误 144 | doc_metadata.update({ 145 | "status": "error", 146 | "processing_error": str(e) 147 | }) 148 | 149 | # 更新Redis 150 | self.redis.set( 151 | f"doc:{doc_id}:metadata", 152 | json.dumps(doc_metadata), 153 | ex=3600 154 | ) 155 | 156 | logger.error(f"处理文档 {filename} (ID: {doc_id}) 错误: {str(e)}") 157 | raise Exception(f"文档处理失败: {str(e)}") 158 | 159 | def _chunk_document(self, text: str, doc_id: str) -> List[Dict[str, Any]]: 160 | """将文档文本分成块""" 161 | chunks = [] 162 | 163 | # 基于字符的分块,带重叠 164 | start = 0 165 | chunk_id = 0 166 | 167 | while start < len(text): 168 | # 计算结束位置 169 | end = min(start + self.chunk_size, len(text)) 170 | 171 | # 如果不是最后一块,尝试找到自然断点 172 | if end < len(text): 173 | # 尝试找到句子断点(句号、问号、感叹号) 174 | for i in range(min(100, end - start)): 175 | if text[end - i - 1] in ['.', '?', '!', '\n'] and text[end - i] in [' ', '\n']: 176 | end = end - i 177 | break 178 | 179 | # 提取块文本 180 | chunk_text = text[start:end] 181 | 182 | # 创建块元数据 183 | chunk = { 184 | "chunk_id": f"{doc_id}_{chunk_id}", 185 | "document_id": doc_id, 186 | "text": chunk_text, 187 | "start_char": start, 188 | "end_char": end, 189 | "length": len(chunk_text) 190 | } 191 | 192 | chunks.append(chunk) 193 | 194 | # 移动到下一块位置,考虑重叠 195 | start = end - self.chunk_overlap 196 | chunk_id += 1 197 | 198 | # 确保我们有进展 199 | if start >= end: 200 | start = end 201 | 202 | return chunks 203 | 204 | def _process_embeddings_and_index(self, chunks: List[Dict[str, Any]], 205 | doc_id: str, doc_metadata: 
Dict[str, Any]): 206 | """处理文档块嵌入和索引""" 207 | # 提取文本列表 208 | texts = [chunk["text"] for chunk in chunks] 209 | 210 | # 批量获取嵌入 211 | batch_size = self.config["embedding"]["batch_size"] 212 | all_embeddings = [] 213 | 214 | for i in range(0, len(texts), batch_size): 215 | batch_texts = texts[i:i+batch_size] 216 | batch_embeddings = get_embeddings(batch_texts) 217 | all_embeddings.extend(batch_embeddings) 218 | 219 | # 准备向量和有效载荷 220 | vectors = all_embeddings 221 | payloads = chunks 222 | chunk_ids = [chunk["chunk_id"] for chunk in chunks] 223 | 224 | # 存储在向量存储中 225 | self.vector_store.add_documents( 226 | vectors=vectors, 227 | payloads=payloads, 228 | ids=chunk_ids 229 | ) 230 | 231 | # 存储文档元数据 232 | doc_vector = all_embeddings[0] if all_embeddings else [0.0] * 768 # 默认向量大小 233 | self.vector_store.add_documents( 234 | vectors=[doc_vector], 235 | payloads=[doc_metadata], 236 | ids=[doc_id], 237 | collection_name=self.vector_store.metadata_collection 238 | ) 239 | 240 | def get_document_metadata(self, document_id: str) -> Dict[str, Any]: 241 | """ 242 | 获取文档元数据 243 | 244 | Args: 245 | document_id: 文档ID 246 | 247 | Returns: 248 | Dict[str, Any]: 文档元数据 249 | 250 | Raises: 251 | ValueError: 如果文档不存在 252 | """ 253 | try: 254 | # 从Redis检查缓存 255 | cached_metadata = self.redis.get(f"doc:{document_id}:metadata") 256 | if cached_metadata: 257 | return json.loads(cached_metadata) 258 | 259 | # 如果缓存中没有,从向量存储中获取 260 | results = self.vector_store.query( 261 | query_vector=[0.0] * 768, # 使用空向量 262 | limit=1, 263 | filter_={"id": document_id}, 264 | collection_name=self.vector_store.metadata_collection 265 | ) 266 | 267 | if not results: 268 | raise ValueError(f"文档 {document_id} 不存在") 269 | 270 | # 返回文档元数据 271 | return results[0]["payload"] 272 | 273 | except Exception as e: 274 | logger.error(f"获取文档元数据错误: {str(e)}") 275 | raise Exception(f"获取文档元数据失败: {str(e)}") 276 | 277 | def get_all_documents(self, filters: Dict[str, Any], limit: int = 100, offset: int = 0) -> Tuple[List[Dict[str, Any]], int]: 278 | """ 279 | 获取所有文档的元数据 280 | 281 | Args: 282 | filters: 过滤条件 283 | limit: 最大返回数量 284 | offset: 偏移量 285 | 286 | Returns: 287 | Tuple[List[Dict[str, Any]], int]: 文档元数据列表和总数 288 | """ 289 | try: 290 | # 从向量存储中获取所有文档 291 | results = self.vector_store.query( 292 | query_vector=[0.0] * 768, # 使用空向量 293 | limit=limit + offset, 294 | filter_=filters, 295 | collection_name=self.vector_store.metadata_collection 296 | ) 297 | 298 | # 获取总数 299 | total_count = len(results) 300 | 301 | # 应用偏移和限制 302 | results = results[offset:offset+limit] 303 | 304 | # 提取元数据 305 | documents = [result["payload"] for result in results] 306 | 307 | return documents, total_count 308 | 309 | except Exception as e: 310 | logger.error(f"获取所有文档错误: {str(e)}") 311 | raise Exception(f"获取文档列表失败: {str(e)}") 312 | 313 | def delete_document(self, document_id: str) -> bool: 314 | """ 315 | 删除文档及其索引 316 | 317 | Args: 318 | document_id: 文档ID 319 | 320 | Returns: 321 | bool: 操作是否成功 322 | 323 | Raises: 324 | ValueError: 如果文档不存在 325 | """ 326 | try: 327 | # 检查文档是否存在 328 | doc_metadata = self.get_document_metadata(document_id) 329 | 330 | # 删除文档块 331 | self.vector_store.client.delete( 332 | collection_name=self.vector_store.default_collection, 333 | points_selector=rest.Filter( 334 | must=[ 335 | rest.FieldCondition( 336 | key="document_id", 337 | match=rest.MatchValue(value=document_id) 338 | ) 339 | ] 340 | ) 341 | ) 342 | 343 | # 删除文档元数据 344 | self.vector_store.client.delete( 345 | collection_name=self.vector_store.metadata_collection, 346 | 
points_selector=rest.Filter( 347 | must=[ 348 | rest.FieldCondition( 349 | key="id", 350 | match=rest.MatchValue(value=document_id) 351 | ) 352 | ] 353 | ) 354 | ) 355 | 356 | # 删除Redis缓存 357 | self.redis.delete(f"doc:{document_id}:metadata") 358 | 359 | # 创建一个删除文档的后台任务来清理相关资源 360 | task_id = str(uuid.uuid4()) 361 | task_data = { 362 | "type": "document_cleanup", 363 | "task_id": task_id, 364 | "document_id": document_id, 365 | "created_at": time.time() 366 | } 367 | 368 | # 将任务添加到队列 369 | self.redis.rpush("task_queue", json.dumps(task_data)) 370 | 371 | return True 372 | 373 | except ValueError as e: 374 | raise e 375 | except Exception as e: 376 | logger.error(f"删除文档错误: {str(e)}") 377 | raise Exception(f"删除文档失败: {str(e)}") 378 | 379 | def reindex_document(self, document_id: str) -> str: 380 | """ 381 | 重新索引文档 382 | 383 | Args: 384 | document_id: 文档ID 385 | 386 | Returns: 387 | str: 任务ID 388 | 389 | Raises: 390 | ValueError: 如果文档不存在 391 | """ 392 | try: 393 | # 检查文档是否存在 394 | doc_metadata = self.get_document_metadata(document_id) 395 | 396 | # 创建一个重新索引任务 397 | task_id = str(uuid.uuid4()) 398 | task_data = { 399 | "type": "indexing", 400 | "task_id": task_id, 401 | "document_ids": [document_id], 402 | "user_id": doc_metadata.get("user_id", ""), 403 | "rebuild_all": False, 404 | "created_at": time.time() 405 | } 406 | 407 | # 将任务添加到队列 408 | self.redis.rpush("task_queue", json.dumps(task_data)) 409 | 410 | # 更新文档状态 411 | doc_metadata["status"] = "reindexing" 412 | self.redis.set( 413 | f"doc:{document_id}:metadata", 414 | json.dumps(doc_metadata), 415 | ex=3600 416 | ) 417 | 418 | # 存储任务状态 419 | self.redis.hset( 420 | f"task:{task_id}", 421 | mapping={ 422 | "status": "queued", 423 | "type": "indexing", 424 | "document_ids": json.dumps([document_id]), 425 | "created_at": time.time(), 426 | "user_id": doc_metadata.get("user_id", "") 427 | } 428 | ) 429 | 430 | return task_id 431 | 432 | except ValueError as e: 433 | raise e 434 | except Exception as e: 435 | logger.error(f"重新索引文档错误: {str(e)}") 436 | raise Exception(f"重新索引文档失败: {str(e)}") 437 | 438 | def get_task_status(self, task_id: str) -> Dict[str, Any]: 439 | """ 440 | 获取任务状态 441 | 442 | Args: 443 | task_id: 任务ID 444 | 445 | Returns: 446 | Dict[str, Any]: 任务状态信息 447 | 448 | Raises: 449 | ValueError: 如果任务不存在 450 | """ 451 | try: 452 | # 从Redis获取任务状态 453 | task_data = self.redis.hgetall(f"task:{task_id}") 454 | 455 | if not task_data: 456 | raise ValueError(f"任务 {task_id} 不存在") 457 | 458 | # 转换某些字段 459 | if "document_ids" in task_data and task_data["document_ids"]: 460 | task_data["document_ids"] = json.loads(task_data["document_ids"]) 461 | 462 | if "created_at" in task_data and task_data["created_at"]: 463 | task_data["created_at"] = float(task_data["created_at"]) 464 | 465 | if "completed_at" in task_data and task_data["completed_at"]: 466 | task_data["completed_at"] = float(task_data["completed_at"]) 467 | 468 | return task_data 469 | 470 | except ValueError as e: 471 | raise e 472 | except Exception as e: 473 | logger.error(f"获取任务状态错误: {str(e)}") 474 | raise Exception(f"获取任务状态失败: {str(e)}") -------------------------------------------------------------------------------- /backend/services/graphrag_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import Dict, List, Optional, Any 5 | import networkx as nx 6 | import numpy as np 7 | from .vector_store import VectorStore 8 | from .llm_service import LLMService 9 | 10 | logger = 
logging.getLogger(__name__) 11 | 12 | class GraphRAGService: 13 | def __init__(self, vector_store: VectorStore, llm_service: LLMService, 14 | config_path: str = "configs/worker.yaml"): 15 | # 加载配置 16 | with open(config_path, "r") as f: 17 | self.config = yaml.safe_load(f) 18 | 19 | # 存储服务引用 20 | self.vector_store = vector_store 21 | self.llm_service = llm_service 22 | 23 | # 加载GraphRAG设置 24 | self.rag_settings = self.config["graphrag"] 25 | self.similarity_threshold = self.rag_settings["similarity_threshold"] 26 | self.max_neighbors = self.rag_settings["max_neighbors"] 27 | self.max_hops = self.rag_settings["max_hops"] 28 | 29 | # 初始化图结构 30 | self.graph = nx.Graph() 31 | 32 | logger.info("GraphRAG服务初始化完成") 33 | 34 | def update_graph(self, document_id: str = None): 35 | """ 36 | 更新知识图谱,可以指定文档ID更新特定文档,或不指定更新全部 37 | 38 | Args: 39 | document_id (str, optional): 特定文档ID,为None时更新所有文档 40 | """ 41 | try: 42 | # 查询条件 43 | filter_dict = {} 44 | if document_id: 45 | filter_dict = {"document_id": document_id} 46 | 47 | # 获取所有块 48 | all_chunks = self.vector_store.query( 49 | query_vector=[0.0] * 768, # 使用空向量,将获取所有文档而不是相似匹配 50 | limit=10000, # 大量限制,实际上取决于数据库能返回的最大记录数 51 | filter_=filter_dict 52 | ) 53 | 54 | # 清理图形(如果更新特定文档) 55 | if document_id: 56 | # 删除与特定文档相关的节点 57 | nodes_to_remove = [node for node in self.graph.nodes if self.graph.nodes[node].get("document_id") == document_id] 58 | self.graph.remove_nodes_from(nodes_to_remove) 59 | else: 60 | # 重建整个图 61 | self.graph.clear() 62 | 63 | # 为每个块创建节点 64 | for chunk in all_chunks: 65 | chunk_id = chunk["id"] 66 | doc_id = chunk["payload"]["document_id"] 67 | 68 | # 添加节点 69 | self.graph.add_node( 70 | chunk_id, 71 | document_id=doc_id, 72 | text=chunk["payload"]["text"], 73 | embedding=chunk["payload"].get("embedding", []) 74 | ) 75 | 76 | # 计算相似度并创建边 77 | chunk_ids = list(self.graph.nodes) 78 | for i, chunk_id in enumerate(chunk_ids): 79 | # 获取当前块的向量嵌入 80 | if not self.graph.nodes[chunk_id].get("embedding"): 81 | # 如果没有嵌入,获取文本并生成嵌入 82 | text = self.graph.nodes[chunk_id]["text"] 83 | embedding = self.llm_service.get_embedding(text) 84 | self.graph.nodes[chunk_id]["embedding"] = embedding 85 | 86 | current_embedding = self.graph.nodes[chunk_id]["embedding"] 87 | current_doc_id = self.graph.nodes[chunk_id]["document_id"] 88 | 89 | # 为效率考虑,只与同一文档内的块或特殊关系块计算相似度 90 | for j, other_id in enumerate(chunk_ids): 91 | if i == j: 92 | continue 93 | 94 | other_doc_id = self.graph.nodes[other_id]["document_id"] 95 | 96 | # 同一文档内相邻块自动连接 97 | if current_doc_id == other_doc_id and abs(i - j) == 1: 98 | self.graph.add_edge(chunk_id, other_id, weight=0.9) 99 | continue 100 | 101 | # 计算其他相似块 102 | if not self.graph.nodes[other_id].get("embedding"): 103 | # 如果没有嵌入,获取文本并生成嵌入 104 | text = self.graph.nodes[other_id]["text"] 105 | embedding = self.llm_service.get_embedding(text) 106 | self.graph.nodes[other_id]["embedding"] = embedding 107 | 108 | other_embedding = self.graph.nodes[other_id]["embedding"] 109 | 110 | # 计算余弦相似度 111 | similarity = self._cosine_similarity(current_embedding, other_embedding) 112 | 113 | # 如果相似度高于阈值,添加边 114 | if similarity > self.similarity_threshold: 115 | self.graph.add_edge(chunk_id, other_id, weight=similarity) 116 | 117 | logger.info(f"知识图谱更新完成,节点数: {self.graph.number_of_nodes()}, 边数: {self.graph.number_of_edges()}") 118 | 119 | except Exception as e: 120 | logger.error(f"更新知识图谱时出错: {str(e)}") 121 | raise Exception(f"知识图谱更新失败: {str(e)}") 122 | 123 | def get_context_for_query(self, query: str, user_id: Optional[str] = None, 124 | document_ids: 
Optional[List[str]] = None, 125 | max_results: int = 5) -> List[Dict[str, Any]]: 126 | """ 127 | 为查询获取上下文信息 128 | 129 | Args: 130 | query: 用户查询 131 | user_id: 可选用户ID过滤 132 | document_ids: 可选文档ID列表过滤 133 | max_results: 最大返回结果数量 134 | 135 | Returns: 136 | List[Dict[str, Any]]: 相关上下文信息列表 137 | """ 138 | try: 139 | # 生成查询嵌入 140 | query_embedding = self.llm_service.get_embedding(query) 141 | 142 | # 准备过滤器 143 | filter_dict = {} 144 | if user_id: 145 | filter_dict["user_id"] = user_id 146 | if document_ids: 147 | filter_dict["document_id"] = document_ids 148 | 149 | # 首先进行向量搜索找到最相关的入口点 150 | initial_results = self.vector_store.query( 151 | query_vector=query_embedding, 152 | limit=3, # 获取少量高质量入口点 153 | filter_=filter_dict 154 | ) 155 | 156 | # 扩展结果集 157 | expanded_results = set() 158 | for result in initial_results: 159 | chunk_id = result["id"] 160 | expanded_results.add(chunk_id) 161 | 162 | # 图遍历扩展 163 | if chunk_id in self.graph: 164 | # 获取邻居节点 165 | neighbors = self._get_relevant_neighbors(chunk_id, query_embedding, hops=self.max_hops) 166 | expanded_results.update(neighbors) 167 | 168 | # 将扩展结果转换回详细信息 169 | detailed_results = [] 170 | for chunk_id in expanded_results: 171 | # 从图中获取节点信息 172 | if chunk_id in self.graph: 173 | node_data = self.graph.nodes[chunk_id] 174 | if "text" in node_data: 175 | # 计算与查询的相似度 176 | similarity = self._cosine_similarity( 177 | query_embedding, 178 | node_data.get("embedding", self.llm_service.get_embedding(node_data["text"])) 179 | ) 180 | 181 | detailed_results.append({ 182 | "id": chunk_id, 183 | "text": node_data["text"], 184 | "score": similarity, 185 | "metadata": { 186 | "document_id": node_data.get("document_id", ""), 187 | } 188 | }) 189 | 190 | # 排序并限制结果数量 191 | detailed_results.sort(key=lambda x: x["score"], reverse=True) 192 | return detailed_results[:max_results] 193 | 194 | except Exception as e: 195 | logger.error(f"获取查询上下文时出错: {str(e)}") 196 | # 回退到简单向量搜索 197 | try: 198 | fallback_results = self.vector_store.query( 199 | query_vector=query_embedding, 200 | limit=max_results, 201 | filter_=filter_dict 202 | ) 203 | 204 | # 格式化结果 205 | return [{ 206 | "id": result["id"], 207 | "text": result["payload"]["text"], 208 | "score": result["score"], 209 | "metadata": { 210 | "document_id": result["payload"].get("document_id", ""), 211 | } 212 | } for result in fallback_results] 213 | except: 214 | logger.error("回退搜索也失败") 215 | return [] 216 | 217 | def _get_relevant_neighbors(self, start_node: str, query_embedding: List[float], 218 | hops: int = 2) -> List[str]: 219 | """获取与查询相关的邻居节点""" 220 | # BFS搜索相关节点 221 | visited = set([start_node]) 222 | queue = [(start_node, 0)] # (node, hop) 223 | relevant_nodes = [] 224 | 225 | while queue: 226 | node, hop_count = queue.pop(0) 227 | 228 | # 如果超过最大跳数,停止 229 | if hop_count >= hops: 230 | continue 231 | 232 | # 获取邻居节点 233 | if node not in self.graph: 234 | continue 235 | 236 | neighbors = list(self.graph.neighbors(node)) 237 | 238 | # 计算邻居节点与查询的相似度 239 | neighbor_similarities = [] 240 | for neighbor in neighbors: 241 | if neighbor in visited: 242 | continue 243 | 244 | # 获取向量嵌入 245 | if "embedding" not in self.graph.nodes[neighbor]: 246 | text = self.graph.nodes[neighbor]["text"] 247 | self.graph.nodes[neighbor]["embedding"] = self.llm_service.get_embedding(text) 248 | 249 | embedding = self.graph.nodes[neighbor]["embedding"] 250 | 251 | # 计算相似度 252 | similarity = self._cosine_similarity(query_embedding, embedding) 253 | neighbor_similarities.append((neighbor, similarity)) 254 | 255 | # 按相似度排序 256 | 
neighbor_similarities.sort(key=lambda x: x[1], reverse=True) 257 | 258 | # 取前N个相关邻居 259 | for neighbor, similarity in neighbor_similarities[:self.max_neighbors]: 260 | if similarity > self.similarity_threshold and neighbor not in visited: 261 | visited.add(neighbor) 262 | queue.append((neighbor, hop_count + 1)) 263 | relevant_nodes.append(neighbor) 264 | 265 | return relevant_nodes 266 | 267 | def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: 268 | """计算余弦相似度""" 269 | vec1 = np.array(vec1) 270 | vec2 = np.array(vec2) 271 | 272 | norm1 = np.linalg.norm(vec1) 273 | norm2 = np.linalg.norm(vec2) 274 | 275 | if norm1 == 0 or norm2 == 0: 276 | return 0.0 277 | 278 | return np.dot(vec1, vec2) / (norm1 * norm2) 279 | -------------------------------------------------------------------------------- /backend/services/llm_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | import httpx 5 | from typing import Dict, List, Optional, Any 6 | import json 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class LLMService: 11 | def __init__(self, config_path: str = "configs/ollama.yaml"): 12 | # 加载配置 13 | with open(config_path, "r") as f: 14 | self.config = yaml.safe_load(f) 15 | 16 | self.host = self.config["ollama"]["host"] 17 | self.port = self.config["ollama"]["port"] 18 | self.timeout = self.config["ollama"]["timeout"] 19 | self.base_url = f"http://{self.host}:{self.port}" 20 | self.default_model = self.config["models"]["default"] 21 | self.embeddings_model = self.config["models"]["embeddings"] 22 | 23 | # 初始化客户端 24 | self.client = httpx.Client(timeout=self.timeout) 25 | 26 | # 加载推理参数 27 | self.inference_params = self.config["inference"] 28 | 29 | logger.info(f"LLM服务初始化完成,使用模型: {self.default_model}") 30 | 31 | def generate(self, prompt: str, system_prompt: Optional[str] = None, **kwargs) -> Dict[str, Any]: 32 | """生成LLM响应""" 33 | model = kwargs.pop("model", self.default_model) 34 | 35 | # 准备请求数据 36 | request_data = { 37 | "model": model, 38 | "prompt": prompt, 39 | **self.inference_params, 40 | **kwargs 41 | } 42 | 43 | if system_prompt: 44 | request_data["system"] = system_prompt 45 | 46 | try: 47 | # 请求Ollama API 48 | response = self.client.post( 49 | f"{self.base_url}/api/generate", 50 | json=request_data 51 | ) 52 | response.raise_for_status() 53 | 54 | return response.json() 55 | except httpx.HTTPError as e: 56 | logger.error(f"Ollama API调用失败: {str(e)}") 57 | raise Exception(f"生成响应失败: {str(e)}") 58 | 59 | def generate_stream(self, prompt: str, system_prompt: Optional[str] = None, **kwargs): 60 | """生成流式LLM响应""" 61 | model = kwargs.pop("model", self.default_model) 62 | 63 | # 准备请求数据 64 | request_data = { 65 | "model": model, 66 | "prompt": prompt, 67 | "stream": True, 68 | **self.inference_params, 69 | **kwargs 70 | } 71 | 72 | if system_prompt: 73 | request_data["system"] = system_prompt 74 | 75 | try: 76 | with httpx.stream("POST", f"{self.base_url}/api/generate", 77 | json=request_data, timeout=self.timeout) as response: 78 | response.raise_for_status() 79 | 80 | for line in response.iter_lines(): 81 | if line: 82 | try: 83 | chunk = json.loads(line) 84 | if "response" in chunk: 85 | yield chunk["response"] 86 | except json.JSONDecodeError: 87 | logger.warning(f"无法解析JSON响应: {line}") 88 | except httpx.HTTPError as e: 89 | logger.error(f"Ollama流式API调用失败: {str(e)}") 90 | raise Exception(f"生成流式响应失败: {str(e)}") 91 | 92 | def get_embedding(self, text: str, model: Optional[str] = 
None) -> List[float]: 93 | """获取文本的向量嵌入""" 94 | model = model or self.embeddings_model 95 | 96 | try: 97 | response = self.client.post( 98 | f"{self.base_url}/api/embeddings", 99 | json={"model": model, "prompt": text} 100 | ) 101 | response.raise_for_status() 102 | result = response.json() 103 | 104 | return result["embedding"] 105 | except httpx.HTTPError as e: 106 | logger.error(f"获取向量嵌入失败: {str(e)}") 107 | raise Exception(f"生成向量嵌入失败: {str(e)}") 108 | -------------------------------------------------------------------------------- /backend/services/search_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import Dict, List, Optional, Any 5 | import redis 6 | import json 7 | import time 8 | from .vector_store import VectorStore 9 | from .llm_service import LLMService 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | class SearchService: 14 | def __init__(self, vector_store: VectorStore, llm_service: LLMService, 15 | redis_config_path: str = "configs/redis.yaml"): 16 | # 加载Redis配置 17 | with open(redis_config_path, "r") as f: 18 | redis_config = yaml.safe_load(f) 19 | 20 | # 初始化Redis客户端 21 | self.redis = redis.Redis( 22 | host=redis_config["redis"]["host"], 23 | port=redis_config["redis"]["port"], 24 | db=redis_config["redis"]["db"], 25 | password=redis_config["redis"]["password"], 26 | decode_responses=True 27 | ) 28 | 29 | # 缓存TTL 30 | self.cache_ttl = redis_config["cache"]["ttl"] 31 | 32 | # 存储服务 33 | self.vector_store = vector_store 34 | self.llm_service = llm_service 35 | 36 | logger.info("搜索服务初始化完成") 37 | 38 | def search(self, query: str, user_id: Optional[str] = None, limit: int = 10, 39 | use_hybrid: bool = True, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: 40 | """使用查询字符串搜索向量存储""" 41 | # 生成缓存键 42 | cache_key = f"search:{hash(query)}:{user_id or 'all'}:{limit}:{use_hybrid}:{hash(str(filters))}" 43 | 44 | # 检查缓存 45 | cached_results = self.redis.get(cache_key) 46 | if cached_results: 47 | logger.info(f"缓存命中: {query}") 48 | return json.loads(cached_results) 49 | 50 | try: 51 | # 生成查询嵌入 52 | start_time = time.time() 53 | query_embedding = self.llm_service.get_embedding(query) 54 | embedding_time = time.time() - start_time 55 | logger.debug(f"生成查询嵌入耗时: {embedding_time:.3f}秒") 56 | 57 | # 准备过滤器 58 | search_filter = self._prepare_filter(user_id, filters) 59 | 60 | # 执行向量搜索 61 | start_time = time.time() 62 | semantic_results = self.vector_store.query( 63 | query_vector=query_embedding, 64 | limit=limit, 65 | filter_=search_filter 66 | ) 67 | vector_time = time.time() - start_time 68 | logger.debug(f"向量搜索完成,耗时: {vector_time:.3f}秒") 69 | 70 | # 格式化结果 71 | results = self._format_search_results(semantic_results) 72 | 73 | # 缓存结果 74 | self.redis.set(cache_key, json.dumps(results), ex=self.cache_ttl) 75 | 76 | return results 77 | 78 | except Exception as e: 79 | logger.error(f"搜索查询'{query}'时出错: {str(e)}") 80 | raise Exception(f"搜索失败: {str(e)}") 81 | 82 | def _prepare_filter(self, user_id: Optional[str] = None, 83 | additional_filters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: 84 | """准备向量搜索的过滤器""" 85 | filter_dict = {} 86 | filter_conditions = [] 87 | 88 | # 添加用户过滤器 89 | if user_id: 90 | filter_conditions.append({ 91 | "key": "user_id", 92 | "match": { 93 | "value": user_id 94 | } 95 | }) 96 | 97 | # 添加额外过滤器 98 | if additional_filters: 99 | for key, value in additional_filters.items(): 100 | if isinstance(value, list): 101 | # 处理列表值(OR条件) 102 | 
or_conditions = [] 103 | for val in value: 104 | or_conditions.append({ 105 | "key": key, 106 | "match": { 107 | "value": val 108 | } 109 | }) 110 | 111 | if or_conditions: 112 | filter_conditions.append({ 113 | "should": or_conditions 114 | }) 115 | else: 116 | # 处理单个值 117 | filter_conditions.append({ 118 | "key": key, 119 | "match": { 120 | "value": value 121 | } 122 | }) 123 | 124 | # 将所有条件与AND逻辑组合 125 | if filter_conditions: 126 | filter_dict = { 127 | "must": filter_conditions 128 | } 129 | 130 | return filter_dict 131 | 132 | def _format_search_results(self, vector_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 133 | """将向量搜索结果格式化为标准格式""" 134 | formatted_results = [] 135 | 136 | for result in vector_results: 137 | # 提取基本信息 138 | result_id = result["id"] 139 | score = result["score"] 140 | payload = result["payload"] 141 | 142 | # 创建格式化结果 143 | formatted_result = { 144 | "id": result_id, 145 | "score": score, 146 | "text": payload.get("text", ""), 147 | "metadata": { 148 | "document_id": payload.get("document_id", ""), 149 | "chunk_id": payload.get("chunk_id", ""), 150 | "filename": payload.get("filename", ""), 151 | "position": { 152 | "start": payload.get("start_char", 0), 153 | "end": payload.get("end_char", 0) 154 | } 155 | } 156 | } 157 | 158 | formatted_results.append(formatted_result) 159 | 160 | return formatted_results 161 | -------------------------------------------------------------------------------- /backend/services/vector_store.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import Dict, List, Optional, Any 5 | from qdrant_client import QdrantClient 6 | from qdrant_client.http import models as rest 7 | import numpy as np 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class VectorStore: 12 | def __init__(self, config_path: str = "configs/qdrant.yaml"): 13 | # 加载配置 14 | with open(config_path, "r") as f: 15 | self.config = yaml.safe_load(f) 16 | 17 | # 初始化Qdrant客户端 18 | self.client = QdrantClient( 19 | host=self.config["qdrant"]["host"], 20 | port=self.config["qdrant"]["port"], 21 | prefer_grpc=self.config["qdrant"]["prefer_grpc"], 22 | timeout=self.config["qdrant"]["timeout"] 23 | ) 24 | 25 | # 获取集合配置 26 | self.default_collection = self.config["collections"]["default"]["name"] 27 | self.metadata_collection = self.config["collections"]["metadata"]["name"] 28 | 29 | # 初始化集合(如果不存在) 30 | self._initialize_collections() 31 | 32 | logger.info(f"向量存储初始化完成: {self.default_collection}, {self.metadata_collection}") 33 | 34 | def _initialize_collections(self): 35 | """初始化向量集合""" 36 | collections = [collection.name for collection in self.client.get_collections().collections] 37 | 38 | # 初始化文档集合 39 | if self.default_collection not in collections: 40 | doc_config = self.config["collections"]["default"] 41 | self.client.create_collection( 42 | collection_name=self.default_collection, 43 | vectors_config=rest.VectorsConfig( 44 | size=doc_config["vector_size"], 45 | distance=rest.Distance[doc_config["distance"]] 46 | ), 47 | optimizers_config=rest.OptimizersConfigDiff( 48 | deleted_threshold=doc_config["optimizers"]["deleted_threshold"], 49 | vacuum_min_vector_number=doc_config["optimizers"]["vacuum_min_vector_number"] 50 | ), 51 | hnsw_config=rest.HnswConfigDiff( 52 | m=doc_config["index"]["m"], 53 | ef_construct=doc_config["index"]["ef_construct"] 54 | ) 55 | ) 56 | logger.info(f"创建集合 {self.default_collection}") 57 | 58 | # 初始化元数据集合 59 | if self.metadata_collection not 
in collections: 60 | meta_config = self.config["collections"]["metadata"] 61 | self.client.create_collection( 62 | collection_name=self.metadata_collection, 63 | vectors_config=rest.VectorsConfig( 64 | size=meta_config["vector_size"], 65 | distance=rest.Distance[meta_config["distance"]] 66 | ) 67 | ) 68 | logger.info(f"创建集合 {self.metadata_collection}") 69 | 70 | def add_documents(self, vectors: List[List[float]], payloads: List[Dict[str, Any]], 71 | ids: Optional[List[str]] = None, collection_name: Optional[str] = None) -> List[str]: 72 | """添加文档向量和元数据""" 73 | collection_name = collection_name or self.default_collection 74 | 75 | # 如果没有提供ID,则生成 76 | if ids is None: 77 | ids = [str(i) for i in range(len(vectors))] 78 | 79 | # 转换为numpy数组进行验证 80 | vectors_np = np.array(vectors, dtype=np.float32) 81 | 82 | try: 83 | # 创建点批次 84 | points = [ 85 | rest.PointStruct( 86 | id=str(id_), 87 | vector=vector.tolist(), 88 | payload=payload 89 | ) 90 | for id_, vector, payload in zip(ids, vectors_np, payloads) 91 | ] 92 | 93 | # 插入批次 94 | self.client.upsert( 95 | collection_name=collection_name, 96 | points=points 97 | ) 98 | 99 | logger.info(f"向{collection_name}添加了{len(points)}个向量") 100 | return ids 101 | 102 | except Exception as e: 103 | logger.error(f"向{collection_name}添加向量失败: {str(e)}") 104 | raise Exception(f"添加向量失败: {str(e)}") 105 | 106 | def query(self, query_vector: List[float], limit: int = 5, 107 | filter_: Optional[Dict[str, Any]] = None, collection_name: Optional[str] = None) -> List[Dict[str, Any]]: 108 | """搜索相似向量""" 109 | collection_name = collection_name or self.default_collection 110 | 111 | try: 112 | # 转换查询向量为numpy数组 113 | query_vector_np = np.array(query_vector, dtype=np.float32) 114 | 115 | # 创建过滤器 116 | search_filter = None 117 | if filter_: 118 | search_filter = rest.Filter(**filter_) 119 | 120 | # 执行搜索 121 | results = self.client.search( 122 | collection_name=collection_name, 123 | query_vector=query_vector_np.tolist(), 124 | limit=limit, 125 | query_filter=search_filter, 126 | with_payload=True, 127 | with_vectors=False 128 | ) 129 | 130 | # 格式化结果 131 | formatted_results = [ 132 | { 133 | "id": str(result.id), 134 | "score": float(result.score), 135 | "payload": result.payload 136 | } 137 | for result in results 138 | ] 139 | 140 | return formatted_results 141 | 142 | except Exception as e: 143 | logger.error(f"查询{collection_name}失败: {str(e)}") 144 | raise Exception(f"查询向量失败: {str(e)}") 145 | -------------------------------------------------------------------------------- /backend/worker/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | import time 5 | import redis 6 | import json 7 | from typing import Dict, Any, Optional 8 | import signal 9 | import sys 10 | from concurrent.futures import ThreadPoolExecutor 11 | import threading 12 | from ..services.document_service import DocumentService 13 | from ..services.vector_store import VectorStore 14 | from ..services.llm_service import LLMService 15 | from ..services.graphrag_service import GraphRAGService 16 | 17 | # 加载配置 18 | config_path = os.getenv("WORKER_CONFIG_PATH", "configs/worker.yaml") 19 | with open(config_path, "r") as f: 20 | config = yaml.safe_load(f) 21 | 22 | # 加载Redis配置 23 | redis_config_path = os.getenv("REDIS_CONFIG_PATH", "configs/redis.yaml") 24 | with open(redis_config_path, "r") as f: 25 | redis_config = yaml.safe_load(f) 26 | 27 | # 设置日志 28 | logging.basicConfig( 29 | level=getattr(logging, 
config["worker"]["log_level"].upper()), 30 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 31 | ) 32 | logger = logging.getLogger(__name__) 33 | 34 | # 初始化服务 35 | vector_store = VectorStore() 36 | llm_service = LLMService() 37 | document_service = DocumentService(vector_store) 38 | graphrag_service = GraphRAGService(vector_store, llm_service) 39 | 40 | # 初始化Redis 41 | redis_client = redis.Redis( 42 | host=redis_config["redis"]["host"], 43 | port=redis_config["redis"]["port"], 44 | db=redis_config["redis"]["db"], 45 | password=redis_config["redis"]["password"], 46 | decode_responses=True 47 | ) 48 | 49 | # 创建线程池 50 | thread_pool = ThreadPoolExecutor(max_workers=config["worker"]["threads"]) 51 | 52 | # 运行状态标志 53 | running = True 54 | 55 | def process_document_task(task_data: Dict[str, Any]): 56 | """处理文档任务""" 57 | try: 58 | document_id = task_data.get("document_id") 59 | task_id = task_data.get("task_id") 60 | 61 | logger.info(f"处理文档任务: {task_id}, 文档ID: {document_id}") 62 | 63 | # 更新任务状态 64 | redis_client.hset(f"task:{task_id}", "status", "processing") 65 | 66 | # 实际处理文档的逻辑 67 | # 这里通常会从存储中加载临时文件,然后调用document_service的处理方法 68 | 69 | # 如果文档处理已在API层完成,这里可以进行额外的优化或后处理 70 | # 例如,更新知识图谱 71 | graphrag_service.update_graph(document_id) 72 | 73 | # 更新任务状态为完成 74 | redis_client.hset(f"task:{task_id}", "status", "completed") 75 | redis_client.hset(f"task:{task_id}", "completed_at", time.time()) 76 | 77 | logger.info(f"文档任务 {task_id} 处理完成") 78 | 79 | except Exception as e: 80 | logger.error(f"处理文档任务失败: {str(e)}") 81 | 82 | # 更新任务状态为失败 83 | if task_id: 84 | redis_client.hset(f"task:{task_id}", "status", "failed") 85 | redis_client.hset(f"task:{task_id}", "error", str(e)) 86 | 87 | def process_embedding_task(task_data: Dict[str, Any]): 88 | """处理嵌入任务""" 89 | try: 90 | document_id = task_data.get("document_id") 91 | chunk_ids = task_data.get("chunk_ids", []) 92 | task_id = task_data.get("task_id") 93 | 94 | logger.info(f"处理嵌入任务: {task_id}, 文档ID: {document_id}, 块数: {len(chunk_ids)}") 95 | 96 | # 更新任务状态 97 | redis_client.hset(f"task:{task_id}", "status", "processing") 98 | 99 | # 这里实现嵌入处理逻辑 100 | # 通常是从数据库加载块内容,然后生成嵌入 101 | 102 | # 更新任务状态为完成 103 | redis_client.hset(f"task:{task_id}", "status", "completed") 104 | redis_client.hset(f"task:{task_id}", "completed_at", time.time()) 105 | 106 | logger.info(f"嵌入任务 {task_id} 处理完成") 107 | 108 | except Exception as e: 109 | logger.error(f"处理嵌入任务失败: {str(e)}") 110 | 111 | # 更新任务状态为失败 112 | if task_id: 113 | redis_client.hset(f"task:{task_id}", "status", "failed") 114 | redis_client.hset(f"task:{task_id}", "error", str(e)) 115 | 116 | def process_indexing_task(task_data: Dict[str, Any]): 117 | """处理索引任务""" 118 | try: 119 | document_ids = task_data.get("document_ids", []) 120 | rebuild_all = task_data.get("rebuild_all", False) 121 | task_id = task_data.get("task_id") 122 | 123 | logger.info(f"处理索引任务: {task_id}, 全量重建: {rebuild_all}, 文档数: {len(document_ids)}") 124 | 125 | # 更新任务状态 126 | redis_client.hset(f"task:{task_id}", "status", "processing") 127 | 128 | # 如果是全量重建 129 | if rebuild_all: 130 | graphrag_service.update_graph() 131 | else: 132 | # 为指定文档更新图结构 133 | for doc_id in document_ids: 134 | graphrag_service.update_graph(doc_id) 135 | 136 | # 更新任务状态为完成 137 | redis_client.hset(f"task:{task_id}", "status", "completed") 138 | redis_client.hset(f"task:{task_id}", "completed_at", time.time()) 139 | 140 | logger.info(f"索引任务 {task_id} 处理完成") 141 | 142 | except Exception as e: 143 | logger.error(f"处理索引任务失败: {str(e)}") 144 | 145 | # 更新任务状态为失败 146 | if task_id: 
147 | redis_client.hset(f"task:{task_id}", "status", "failed") 148 | redis_client.hset(f"task:{task_id}", "error", str(e)) 149 | 150 | def poll_tasks(): 151 | """轮询并处理任务队列中的任务""" 152 | task_types = { 153 | "document": process_document_task, 154 | "embedding": process_embedding_task, 155 | "indexing": process_indexing_task 156 | } 157 | 158 | while running: 159 | try: 160 | # 从任务队列中获取任务 161 | result = redis_client.blpop("task_queue", timeout=1) 162 | if result is None: 163 | continue 164 | 165 | # 解析任务数据 166 | _, task_json = result 167 | task_data = json.loads(task_json) 168 | 169 | # 获取任务类型 170 | task_type = task_data.get("type") 171 | task_id = task_data.get("task_id") 172 | 173 | if task_type in task_types: 174 | # 提交任务到线程池 175 | thread_pool.submit(task_types[task_type], task_data) 176 | else: 177 | logger.warning(f"未知任务类型: {task_type}, 任务ID: {task_id}") 178 | 179 | except Exception as e: 180 | logger.error(f"轮询任务时出错: {str(e)}") 181 | time.sleep(1) # 避免过度消耗CPU 182 | 183 | def handle_signal(sig, frame): 184 | """处理终止信号""" 185 | global running 186 | logger.info(f"收到信号 {sig},准备关闭...") 187 | running = False 188 | 189 | # 关闭线程池 190 | thread_pool.shutdown(wait=True) 191 | 192 | # 关闭连接 193 | redis_client.close() 194 | 195 | logger.info("工作进程已安全关闭") 196 | sys.exit(0) 197 | 198 | def main(): 199 | """主函数""" 200 | # 注册信号处理器 201 | signal.signal(signal.SIGINT, handle_signal) 202 | signal.signal(signal.SIGTERM, handle_signal) 203 | 204 | logger.info("知识库工作进程启动") 205 | 206 | # 启动任务轮询 207 | poll_thread = threading.Thread(target=poll_tasks) 208 | poll_thread.daemon = True 209 | poll_thread.start() 210 | 211 | # 主线程保持运行 212 | try: 213 | while running: 214 | time.sleep(1) 215 | except KeyboardInterrupt: 216 | handle_signal(signal.SIGINT, None) 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /backend/worker/tasks/document_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, List, BinaryIO 4 | import tempfile 5 | import time 6 | import uuid 7 | from ...services.document_service import DocumentService 8 | from ...services.vector_store import VectorStore 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class DocumentProcessorTask: 13 | """文档处理任务类""" 14 | 15 | def __init__(self, document_service: DocumentService): 16 | self.document_service = document_service 17 | 18 | def process(self, file_path: str, filename: str, user_id: str, 19 | metadata: Dict[str, Any] = None) -> Dict[str, Any]: 20 | """ 21 | 处理文档文件 22 | 23 | Args: 24 | file_path: 临时文件路径 25 | filename: 原始文件名 26 | user_id: 用户ID 27 | metadata: 可选的元数据 28 | 29 | Returns: 30 | Dict[str, Any]: 处理结果 31 | """ 32 | start_time = time.time() 33 | logger.info(f"开始处理文档: {filename}") 34 | 35 | try: 36 | # 打开文件 37 | with open(file_path, 'rb') as file: 38 | # 处理文档 39 | result = self.document_service.process_document( 40 | file=file, 41 | filename=filename, 42 | user_id=user_id, 43 | metadata=metadata 44 | ) 45 | 46 | processing_time = time.time() - start_time 47 | logger.info(f"文档处理完成: {filename}, 耗时: {processing_time:.2f}秒") 48 | 49 | return { 50 | "success": True, 51 | "document_id": result["id"], 52 | "processing_time": processing_time, 53 | "chunks_count": result["chunks_count"] 54 | } 55 | 56 | except Exception as e: 57 | logger.error(f"处理文档{filename}时出错: {str(e)}") 58 | 59 | processing_time = time.time() - start_time 60 | return { 61 | "success": False, 62 | 
"error": str(e), 63 | "processing_time": processing_time 64 | } 65 | finally: 66 | # 删除临时文件 67 | if os.path.exists(file_path): 68 | try: 69 | os.unlink(file_path) 70 | except Exception as e: 71 | logger.warning(f"删除临时文件{file_path}失败: {str(e)}") -------------------------------------------------------------------------------- /configs/api.yaml: -------------------------------------------------------------------------------- 1 | server: 2 | host: '0.0.0.0' 3 | port: 8000 4 | workers: 4 5 | log_level: 'info' 6 | debug: false 7 | reload: false 8 | 9 | cors: 10 | allowed_origins: 11 | - '*' 12 | allowed_methods: 13 | - 'GET' 14 | - 'POST' 15 | - 'PUT' 16 | - 'DELETE' 17 | allowed_headers: 18 | - '*' 19 | 20 | security: 21 | api_key_header: 'X-API-Key' 22 | api_key: 'your-api-key-here' # 开发环境默认密钥,生产环境应更改 23 | jwt_secret: 'your-jwt-secret-here' # 开发环境默认密钥,生产环境应更改 24 | token_expire_minutes: 1440 # 24小时 25 | 26 | rate_limiting: 27 | enabled: true 28 | max_requests: 100 29 | time_window_seconds: 60 30 | -------------------------------------------------------------------------------- /configs/ollama.yaml: -------------------------------------------------------------------------------- 1 | ollama: 2 | host: 'ollama' 3 | port: 11434 4 | timeout: 120 5 | 6 | models: 7 | default: 'deepseek-v3' 8 | embeddings: 'deepseek-embeddings' 9 | available: 10 | - 'deepseek-v3' 11 | - 'deepseek-r1' 12 | - 'deepseek-embeddings' 13 | 14 | inference: 15 | temperature: 0.7 16 | top_p: 0.9 17 | top_k: 40 18 | max_tokens: 2048 19 | stop_sequences: [] 20 | -------------------------------------------------------------------------------- /configs/qdrant.yaml: -------------------------------------------------------------------------------- 1 | qdrant: 2 | host: 'qdrant' 3 | port: 6333 4 | grpc_port: 6334 5 | prefer_grpc: true 6 | timeout: 30 7 | 8 | collections: 9 | default: 10 | name: 'documents' 11 | vector_size: 768 12 | distance: 'Cosine' 13 | optimizers: 14 | deleted_threshold: 0.2 15 | vacuum_min_vector_number: 1000 16 | index: 17 | m: 16 18 | ef_construct: 100 19 | 20 | metadata: 21 | name: 'metadata' 22 | vector_size: 768 23 | distance: 'Cosine' 24 | -------------------------------------------------------------------------------- /configs/redis.yaml: -------------------------------------------------------------------------------- 1 | redis: 2 | host: 'redis' 3 | port: 6379 4 | db: 0 5 | password: 'redispassword' 6 | timeout: 5 7 | socket_timeout: 5 8 | socket_connect_timeout: 5 9 | 10 | cache: 11 | ttl: 3600 # 默认缓存1小时 12 | search_ttl: 1800 # 搜索结果缓存30分钟 13 | document_ttl: 86400 # 文档元数据缓存1天 14 | 15 | task_queue: 16 | name: 'task_queue' 17 | max_retries: 3 18 | retry_delay: 60 # 秒 19 | 20 | sessions: 21 | ttl: 86400 # 会话缓存1天 22 | -------------------------------------------------------------------------------- /configs/worker.yaml: -------------------------------------------------------------------------------- 1 | worker: 2 | threads: 4 3 | log_level: 'info' 4 | queue_check_interval: 1 # 秒 5 | 6 | document_processing: 7 | supported_formats: 8 | - '.pdf' 9 | - '.docx' 10 | - '.doc' 11 | - '.txt' 12 | - '.html' 13 | - '.xlsx' 14 | - '.xls' 15 | - '.pptx' 16 | - '.ppt' 17 | - '.md' 18 | chunk_size: 1000 19 | chunk_overlap: 200 20 | storage: 21 | path: '/app/data/uploads' 22 | max_size_mb: 100 23 | 24 | embedding: 25 | batch_size: 10 26 | max_workers: 4 27 | timeout: 60 28 | 29 | graphrag: 30 | enabled: true 31 | similarity_threshold: 0.75 32 | max_neighbors: 5 33 | max_hops: 2 34 | update_interval: 3600 # 自动更新图结构的间隔(秒) 35 | 
-------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | api: 5 | build: 6 | context: . 7 | dockerfile: docker/Dockerfile.api 8 | ports: 9 | - "${API_PORT:-8000}:8000" 10 | volumes: 11 | - ./configs:/app/configs 12 | - ./data:/app/data 13 | depends_on: 14 | - redis 15 | - qdrant 16 | - ollama 17 | environment: 18 | - API_CONFIG_PATH=/app/configs/api.yaml 19 | - REDIS_CONFIG_PATH=/app/configs/redis.yaml 20 | - OLLAMA_CONFIG_PATH=/app/configs/ollama.yaml 21 | - QDRANT_CONFIG_PATH=/app/configs/qdrant.yaml 22 | restart: unless-stopped 23 | networks: 24 | - kb-network 25 | 26 | worker: 27 | build: 28 | context: . 29 | dockerfile: docker/Dockerfile.worker 30 | volumes: 31 | - ./configs:/app/configs 32 | - ./data:/app/data 33 | depends_on: 34 | - redis 35 | - qdrant 36 | - ollama 37 | environment: 38 | - WORKER_CONFIG_PATH=/app/configs/worker.yaml 39 | - REDIS_CONFIG_PATH=/app/configs/redis.yaml 40 | - OLLAMA_CONFIG_PATH=/app/configs/ollama.yaml 41 | - QDRANT_CONFIG_PATH=/app/configs/qdrant.yaml 42 | restart: unless-stopped 43 | networks: 44 | - kb-network 45 | 46 | frontend: 47 | build: 48 | context: . 49 | dockerfile: docker/Dockerfile.frontend 50 | ports: 51 | - "${FRONTEND_PORT:-8501}:8501" 52 | volumes: 53 | - ./configs:/app/configs 54 | - ./data:/app/data 55 | depends_on: 56 | - api 57 | environment: 58 | - STREAMLIT_SERVER_PORT=8501 59 | restart: unless-stopped 60 | networks: 61 | - kb-network 62 | 63 | redis: 64 | image: redis:6.2 65 | ports: 66 | - "6379:6379" 67 | volumes: 68 | - redis-data:/data 69 | command: redis-server --requirepass ${REDIS_PASSWORD:-redispassword} 70 | restart: unless-stopped 71 | networks: 72 | - kb-network 73 | 74 | qdrant: 75 | image: qdrant/qdrant:latest 76 | ports: 77 | - "6333:6333" 78 | - "6334:6334" 79 | volumes: 80 | - qdrant-data:/qdrant/storage 81 | restart: unless-stopped 82 | networks: 83 | - kb-network 84 | 85 | ollama: 86 | image: ollama/ollama:latest 87 | ports: 88 | - "11434:11434" 89 | volumes: 90 | - ollama-data:/root/.ollama 91 | environment: 92 | - OLLAMA_HOST=0.0.0.0 93 | - OLLAMA_MODELS=/root/.ollama/models 94 | restart: unless-stopped 95 | deploy: 96 | resources: 97 | reservations: 98 | devices: 99 | - driver: nvidia 100 | count: 1 101 | capabilities: [gpu] 102 | networks: 103 | - kb-network 104 | 105 | networks: 106 | kb-network: 107 | driver: bridge 108 | 109 | volumes: 110 | redis-data: 111 | qdrant-data: 112 | ollama-data: 113 | -------------------------------------------------------------------------------- /docker/.env: -------------------------------------------------------------------------------- 1 | # API服务配置 2 | API_HOST=0.0.0.0 3 | API_PORT=8000 4 | API_WORKERS=4 5 | API_LOG_LEVEL=info 6 | 7 | # 前端服务配置 8 | FRONTEND_HOST=0.0.0.0 9 | FRONTEND_PORT=8501 10 | 11 | # Worker配置 12 | WORKER_THREADS=4 13 | WORKER_LOG_LEVEL=info 14 | 15 | # Redis配置 16 | REDIS_HOST=redis 17 | REDIS_PORT=6379 18 | REDIS_PASSWORD=redispassword 19 | REDIS_DB=0 20 | 21 | # Qdrant配置 22 | QDRANT_HOST=qdrant 23 | QDRANT_PORT=6333 24 | QDRANT_GRPC_PORT=6334 25 | 26 | # Ollama配置 27 | OLLAMA_HOST=ollama 28 | OLLAMA_PORT=11434 29 | OLLAMA_TIMEOUT=120 30 | 31 | # DeepSeek模型配置 32 | DEEPSEEK_MODEL=deepseek-v3 33 | EMBEDDINGS_MODEL=deepseek-embeddings 34 | -------------------------------------------------------------------------------- /docker/Dockerfile.api: 
-------------------------------------------------------------------------------- 1 | FROM python:3.9-slim 2 | 3 | # 设置工作目录 4 | WORKDIR /app 5 | 6 | # 安装依赖 7 | COPY requirements.txt . 8 | RUN pip install --no-cache-dir -r requirements.txt 9 | 10 | # 复制应用代码 11 | COPY . . 12 | 13 | # 设置环境变量 14 | ENV PYTHONPATH=/app 15 | ENV API_CONFIG_PATH=/app/configs/api.yaml 16 | ENV REDIS_CONFIG_PATH=/app/configs/redis.yaml 17 | ENV OLLAMA_CONFIG_PATH=/app/configs/ollama.yaml 18 | 19 | # 设置端口 20 | EXPOSE 8000 21 | 22 | # 启动命令 23 | CMD ["uvicorn", "backend.api.main:app", "--host", "0.0.0.0", "--port", "8000"] 24 | docker/Dockerfile.worker 25 | 26 | FROM python:3.9-slim 27 | 28 | # 设置工作目录 29 | WORKDIR /app 30 | 31 | # 安装依赖 32 | COPY requirements.txt . 33 | RUN pip install --no-cache-dir -r requirements.txt 34 | 35 | # 复制应用代码 36 | COPY . . 37 | 38 | # 设置环境变量 39 | ENV PYTHONPATH=/app 40 | ENV WORKER_CONFIG_PATH=/app/configs/worker.yaml 41 | ENV REDIS_CONFIG_PATH=/app/configs/redis.yaml 42 | ENV OLLAMA_CONFIG_PATH=/app/configs/ollama.yaml 43 | 44 | # 启动命令 45 | CMD ["python", "-m", "backend.worker.main"] 46 | -------------------------------------------------------------------------------- /docker/Dockerfile.frontend: -------------------------------------------------------------------------------- 1 | FROM python:3.9-slim 2 | 3 | # 设置工作目录 4 | WORKDIR /app 5 | 6 | # 安装依赖 7 | COPY requirements.txt . 8 | RUN pip install --no-cache-dir -r requirements.txt 9 | 10 | # 复制应用代码 11 | COPY . . 12 | 13 | # 设置环境变量 14 | ENV PYTHONPATH=/app 15 | 16 | # 设置端口 17 | EXPOSE 8501 18 | 19 | # 启动命令 20 | CMD ["streamlit", "run", "frontend/app.py", "--server.port=8501", "--server.address=0.0.0.0"] 21 | -------------------------------------------------------------------------------- /docker/Dockerfile.worker: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /frontend/app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import requests 3 | import json 4 | import os 5 | import time 6 | import pandas as pd 7 | import plotly.express as px 8 | from pathlib import Path 9 | import yaml 10 | from io import BytesIO 11 | 12 | 13 | # 在app.py适当位置添加以下内容 14 | 15 | # 导入页面组件 16 | from pages import home, documents, chat, search, analysis, system_status 17 | 18 | 19 | 20 | # 加载配置 21 | config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../configs/api.yaml") 22 | with open(config_path, "r") as f: 23 | config = yaml.safe_load(f) 24 | 25 | # API基础URL 26 | API_URL = f"http://{config['server']['host']}:{config['server']['port']}" 27 | 28 | # 设置页面配置 29 | st.set_page_config( 30 | page_title="DeepSeek本地知识库系统", 31 | page_icon="📚", 32 | layout="wide", 33 | initial_sidebar_state="expanded" 34 | ) 35 | 36 | # 添加CSS样式 37 | st.markdown(""" 38 | 80 | """, unsafe_allow_html=True) 81 | 82 | # 导航设置 83 | def main(): 84 | # 侧边栏菜单 85 | st.sidebar.markdown("## 📚 DeepSeek知识库") 86 | page = st.sidebar.radio( 87 | "导航", 88 | ["首页", "文档管理", "聊天问答", "搜索", "数据分析", "系统状态"] 89 | ) 90 | 91 | # 根据页面选择渲染对应内容 92 | if page == "首页": 93 | home.render() 94 | elif page == "文档管理": 95 | documents.render() 96 | elif page == "聊天问答": 97 | chat.render() 98 | elif page == "搜索": 99 | search.render() 100 | elif page == "数据分析": 101 | analysis.render() 102 | elif page == "系统状态": 103 | system_status.render() 104 | # 侧边栏菜单 105 | # st.sidebar.markdown("## 📚 DeepSeek知识库") 106 | # page = st.sidebar.radio( 
107 | # "导航", 108 | # ["首页", "文档管理", "聊天问答", "搜索", "数据分析", "系统状态"] 109 | # ) 110 | 111 | # # 根据页面选择渲染对应内容 112 | # if page == "首页": 113 | # render_home() 114 | # elif page == "文档管理": 115 | # render_document_manager() 116 | # elif page == "聊天问答": 117 | # render_chat() 118 | # elif page == "搜索": 119 | # render_search() 120 | # elif page == "数据分析": 121 | # render_analysis() 122 | # elif page == "系统状态": 123 | # render_system_status() 124 | 125 | # 首页 126 | def render_home(): 127 | st.markdown('
DeepSeek本地知识库系统
', unsafe_allow_html=True) 128 | 129 | # 系统概览 130 | st.markdown('
系统概览
', unsafe_allow_html=True) 131 | 132 | col1, col2, col3 = st.columns(3) 133 | 134 | with col1: 135 | st.markdown('
', unsafe_allow_html=True) 136 | st.metric(label="文档总数", value="123") 137 | st.markdown('
', unsafe_allow_html=True) 138 | 139 | with col2: 140 | st.markdown('
', unsafe_allow_html=True) 141 | st.metric(label="向量数据量", value="5.3 GB") 142 | st.markdown('
', unsafe_allow_html=True) 143 | 144 | with col3: 145 | st.markdown('
', unsafe_allow_html=True) 146 | st.metric(label="当日查询次数", value="457") 147 | st.markdown('
', unsafe_allow_html=True) 148 | 149 | # 快速导航 150 | st.markdown('
快速导航
', unsafe_allow_html=True) 151 | 152 | col1, col2 = st.columns(2) 153 | 154 | with col1: 155 | if st.button("📄 上传新文档", use_container_width=True): 156 | st.session_state.page = "文档管理" 157 | st.experimental_rerun() 158 | 159 | with col2: 160 | if st.button("💬 开始聊天", use_container_width=True): 161 | st.session_state.page = "聊天问答" 162 | st.experimental_rerun() 163 | 164 | # 最近活动 165 | st.markdown('
最近活动
', unsafe_allow_html=True) 166 | 167 | # 模拟活动数据 168 | activities = [ 169 | {"时间": "2025-02-28 14:32", "活动": "上传文档", "详情": "财务报表Q4.pdf"}, 170 | {"时间": "2025-02-28 13:45", "活动": "搜索查询", "详情": "2024年销售预测"}, 171 | {"时间": "2025-02-28 11:20", "活动": "数据分析", "详情": "市场趋势分析"}, 172 | {"时间": "2025-02-28 10:05", "活动": "聊天会话", "详情": "5条消息交互"}, 173 | {"时间": "2025-02-27 16:50", "活动": "上传文档", "详情": "战略规划2025.docx"} 174 | ] 175 | 176 | st.table(activities) 177 | 178 | # 文档管理页面 179 | def render_document_manager(): 180 | st.markdown('
文档管理
', unsafe_allow_html=True) 181 | 182 | tab1, tab2 = st.tabs(["上传文档", "文档列表"]) 183 | 184 | with tab1: 185 | st.markdown('
上传新文档
', unsafe_allow_html=True) 186 | 187 | # 文件上传界面 188 | uploaded_file = st.file_uploader("选择要上传的文件", 189 | type=["pdf", "docx", "txt", "xlsx", "html"]) 190 | 191 | # 元数据输入 192 | with st.expander("添加文档元数据"): 193 | col1, col2 = st.columns(2) 194 | with col1: 195 | title = st.text_input("标题") 196 | category = st.selectbox("分类", ["财务", "技术", "营销", "法律", "人力资源", "其他"]) 197 | with col2: 198 | tags = st.text_input("标签(用逗号分隔)") 199 | importance = st.slider("重要性", 1, 5, 3) 200 | 201 | # 上传按钮 202 | if st.button("上传并索引"): 203 | if uploaded_file is not None: 204 | with st.spinner("正在处理文档..."): 205 | # 模拟上传处理 206 | progress_bar = st.progress(0) 207 | for i in range(100): 208 | time.sleep(0.05) 209 | progress_bar.progress(i + 1) 210 | 211 | # 模拟成功消息 212 | st.success(f"文档 '{uploaded_file.name}' 上传成功并已添加到索引!") 213 | else: 214 | st.error("请先选择要上传的文件") 215 | 216 | with tab2: 217 | st.markdown('
文档列表
', unsafe_allow_html=True) 218 | 219 | # 搜索和过滤 220 | col1, col2, col3 = st.columns([2, 1, 1]) 221 | with col1: 222 | search_term = st.text_input("搜索文档", placeholder="输入关键词...") 223 | with col2: 224 | category_filter = st.selectbox("分类筛选", ["全部", "财务", "技术", "营销", "法律", "人力资源", "其他"]) 225 | with col3: 226 | sort_option = st.selectbox("排序方式", ["上传时间", "名称", "大小", "重要性"]) 227 | 228 | # 模拟文档列表 229 | documents = [ 230 | {"id": "doc1", "名称": "财务报表Q4.pdf", "大小": "2.3 MB", "上传时间": "2025-02-28", "分类": "财务", "状态": "已索引"}, 231 | {"id": "doc2", "名称": "产品规格说明.docx", "大小": "1.1 MB", "上传时间": "2025-02-27", "分类": "技术", "状态": "已索引"}, 232 | {"id": "doc3", "名称": "营销策略2025.pptx", "大小": "5.4 MB", "上传时间": "2025-02-26", "分类": "营销", "状态": "处理中"}, 233 | {"id": "doc4", "名称": "客户调研报告.xlsx", "大小": "3.7 MB", "上传时间": "2025-02-25", "分类": "营销", "状态": "已索引"}, 234 | {"id": "doc5", "名称": "法律合同模板.docx", "大小": "0.5 MB", "上传时间": "2025-02-24", "分类": "法律", "状态": "已索引"} 235 | ] 236 | 237 | # 显示文档列表 238 | st.dataframe(documents, use_container_width=True) 239 | 240 | # 批量操作 241 | col1, col2 = st.columns(2) 242 | with col1: 243 | if st.button("删除选中文档"): 244 | st.warning("请确认是否删除选中的文档?") 245 | with col2: 246 | if st.button("重新索引"): 247 | with st.spinner("正在重新索引..."): 248 | time.sleep(2) 249 | st.success("重新索引完成!") 250 | 251 | # 聊天问答页面 252 | def render_chat(): 253 | st.markdown('
知识库问答
', unsafe_allow_html=True) 254 | 255 | # 初始化会话历史 256 | if "messages" not in st.session_state: 257 | st.session_state.messages = [] 258 | 259 | # 显示聊天设置 260 | with st.sidebar.expander("聊天设置", expanded=False): 261 | use_rag = st.checkbox("使用知识库增强", value=True) 262 | if use_rag: 263 | search_scope = st.radio("搜索范围", ["全部文档", "选定文档"]) 264 | if search_scope == "选定文档": 265 | # 模拟文档选择 266 | selected_docs = st.multiselect( 267 | "选择文档", 268 | ["财务报表Q4.pdf", "产品规格说明.docx", "营销策略2025.pptx", "客户调研报告.xlsx", "法律合同模板.docx"] 269 | ) 270 | 271 | model = st.selectbox("模型", ["DeepSeek-V3", "DeepSeek-R1"]) 272 | temperature = st.slider("随机性", 0.0, 1.0, 0.7, 0.1) 273 | max_tokens = st.slider("最大生成长度", 256, 4096, 2048, 128) 274 | 275 | # 显示对话历史 276 | for message in st.session_state.messages: 277 | if message["role"] == "user": 278 | st.markdown(f'
{message["content"]}
', unsafe_allow_html=True) 279 | else: 280 | st.markdown(f'
{message["content"]}
', unsafe_allow_html=True) 281 | 282 | # 显示来源(如果有) 283 | if "sources" in message and message["sources"]: 284 | st.markdown('
', unsafe_allow_html=True) 285 | st.markdown("**参考来源:**") 286 | for source in message["sources"]: 287 | st.markdown(f"- {source['title']} (P{source['page']})") 288 | st.markdown('
', unsafe_allow_html=True) 289 | 290 | # 输入框 291 | user_input = st.text_area("输入您的问题", height=100) 292 | 293 | col1, col2 = st.columns([1, 5]) 294 | with col1: 295 | if st.button("发送", use_container_width=True): 296 | if user_input: 297 | # 添加用户消息 298 | st.session_state.messages.append({"role": "user", "content": user_input}) 299 | 300 | # 模拟API调用 301 | with st.spinner("思考中..."): 302 | time.sleep(1) # 模拟响应延迟 303 | 304 | # 模拟回复 305 | bot_reply = { 306 | "role": "assistant", 307 | "content": "根据我们的财务报表分析,2024年第四季度销售额比第三季度增长了15%,主要得益于新产品线的推出和假日促销活动的成功。总收入达到了1.2亿元,超过了预期目标8%。", 308 | "sources": [ 309 | {"title": "财务报表Q4.pdf", "page": 12}, 310 | {"title": "销售预测分析.xlsx", "page": 3} 311 | ] 312 | } 313 | 314 | # 添加回复 315 | st.session_state.messages.append(bot_reply) 316 | 317 | # 刷新界面显示新消息 318 | st.experimental_rerun() 319 | with col2: 320 | if st.button("清空对话", use_container_width=True): 321 | st.session_state.messages = [] 322 | st.experimental_rerun() 323 | 324 | # 搜索页面 325 | def render_search(): 326 | st.markdown('
知识库搜索
', unsafe_allow_html=True) 327 | 328 | # 搜索输入 329 | search_query = st.text_input("输入搜索查询", placeholder="例如:2024年销售预测...") 330 | 331 | col1, col2, col3 = st.columns([1, 1, 2]) 332 | with col1: 333 | search_mode = st.radio("搜索模式", ["语义搜索", "关键词搜索", "混合搜索"]) 334 | with col2: 335 | filter_category = st.multiselect("按分类筛选", ["财务", "技术", "营销", "法律", "人力资源"]) 336 | with col3: 337 | date_range = st.date_input("日期范围", []) 338 | 339 | search_button = st.button("搜索", use_container_width=True) 340 | 341 | # 如果搜索按钮被点击且有查询 342 | if search_button and search_query: 343 | with st.spinner("正在搜索..."): 344 | time.sleep(1) # 模拟搜索延迟 345 | 346 | # 模拟搜索结果 347 | search_results = [ 348 | { 349 | "title": "财务报表Q4.pdf", 350 | "page": 12, 351 | "text": "...2024年第四季度销售额比第三季度增长了15%,主要得益于新产品线的推出和假日促销活动的成功。总收入达到了1.2亿元,超过了预期目标8%...", 352 | "score": 0.92, 353 | "date": "2025-02-28" 354 | }, 355 | { 356 | "title": "销售预测分析.xlsx", 357 | "page": 3, 358 | "text": "...基于历史数据分析,我们预计2025年销售增长率将保持在12-15%之间,累计销售额预计达到4.8亿元...", 359 | "score": 0.87, 360 | "date": "2025-02-25" 361 | }, 362 | { 363 | "title": "营销策略2025.pptx", 364 | "page": 8, 365 | "text": "...针对2024年销售数据,新的数字营销策略将增加社交媒体投入30%,预计带来额外15%的销售增长...", 366 | "score": 0.81, 367 | "date": "2025-02-26" 368 | } 369 | ] 370 | 371 | # 显示搜索结果 372 | st.markdown(f"### 搜索结果: 找到 {len(search_results)} 条匹配项") 373 | 374 | for result in search_results: 375 | with st.expander(f"{result['title']} (相关度: {result['score']:.2f})"): 376 | st.markdown(f"**位置**: 第{result['page']}页") 377 | st.markdown(f"**更新日期**: {result['date']}") 378 | st.markdown("**内容片段**:") 379 | st.markdown(f"
{result['text']}
", unsafe_allow_html=True) 380 | 381 | col1, col2 = st.columns(2) 382 | with col1: 383 | st.button(f"打开文档 {result['title']}", key=f"open_{result['title']}") 384 | with col2: 385 | st.button(f"提问相关问题 {result['title']}", key=f"ask_{result['title']}") 386 | 387 | # 数据分析页面 388 | def render_analysis(): 389 | st.markdown('
数据分析
', unsafe_allow_html=True) 390 | 391 | tab1, tab2, tab3 = st.tabs(["文档分析", "文本分析", "预测分析"]) 392 | 393 | with tab1: 394 | st.markdown("### 文档分析") 395 | 396 | # 选择文档 397 | selected_doc = st.selectbox( 398 | "选择要分析的文档", 399 | ["财务报表Q4.pdf", "产品规格说明.docx", "营销策略2025.pptx", "客户调研报告.xlsx"] 400 | ) 401 | 402 | # 分析选项 403 | col1, col2 = st.columns(2) 404 | with col1: 405 | focus_areas = st.multiselect( 406 | "分析重点领域", 407 | ["关键数据指标", "趋势分析", "风险评估", "机会识别", "竞争分析"], 408 | ["关键数据指标", "趋势分析"] 409 | ) 410 | with col2: 411 | analysis_depth = st.select_slider( 412 | "分析深度", 413 | options=["基础概述", "标准分析", "深度分析"] 414 | ) 415 | 416 | # 分析按钮 417 | if st.button("开始文档分析", use_container_width=True): 418 | with st.spinner("正在分析文档..."): 419 | time.sleep(2) # 模拟分析延迟 420 | 421 | # 模拟分析结果 422 | st.success("分析完成!") 423 | 424 | st.markdown("#### 分析摘要") 425 | st.markdown(""" 426 | 该财务报表显示2024年第四季度业绩良好,销售额同比增长15%,总收入达1.2亿元,超预期8%。 427 | 主要增长来自新产品线(贡献35%)和假日促销活动(贡献25%)。 428 | 运营成本控制良好,同比下降3%,主要得益于数字化转型项目带来的效率提升。 429 | """) 430 | 431 | st.markdown("#### 关键发现") 432 | key_findings = [ 433 | "销售增长:Q4销售额同比增长15%,环比增长8%", 434 | "成本控制:运营成本同比下降3%", 435 | "利润率:毛利率提升2.5个百分点至38.5%", 436 | "区域表现:华东区表现最佳,增长22%", 437 | "挑战:供应链延迟导致某些产品线库存不足" 438 | ] 439 | for finding in key_findings: 440 | st.markdown(f"- {finding}") 441 | 442 | # 模拟可视化 443 | data = { 444 | "季度": ["Q1", "Q2", "Q3", "Q4"], 445 | "销售额": [78, 85, 102, 120], 446 | "成本": [52, 55, 65, 73], 447 | "利润": [26, 30, 37, 47] 448 | } 449 | df = pd.DataFrame(data) 450 | 451 | # 绘制图表 452 | fig = px.line(df, x="季度", y=["销售额", "成本", "利润"], 453 | title="2024年季度业绩趋势", 454 | labels={"value": "金额(百万元)", "variable": "指标"}) 455 | st.plotly_chart(fig, use_container_width=True) 456 | 457 | with tab2: 458 | st.markdown("### 文本分析") 459 | 460 | # 文本输入 461 | text_input = st.text_area( 462 | "输入要分析的文本", 463 | height=200, 464 | placeholder="粘贴需要分析的文本内容..." 
465 | ) 466 | 467 | # 分析类型 468 | analysis_type = st.multiselect( 469 | "选择分析类型", 470 | ["情感分析", "关键词提取", "主题识别", "摘要生成", "实体识别"], 471 | ["情感分析", "关键词提取", "摘要生成"] 472 | ) 473 | 474 | # 分析按钮 475 | if st.button("分析文本", use_container_width=True): 476 | if text_input: 477 | with st.spinner("分析中..."): 478 | time.sleep(1.5) # 模拟分析延迟 479 | 480 | # 模拟分析结果 481 | st.success("文本分析完成!") 482 | 483 | col1, col2 = st.columns(2) 484 | 485 | with col1: 486 | if "情感分析" in analysis_type: 487 | st.markdown("#### 情感分析") 488 | st.progress(75) 489 | st.markdown("总体情感:**积极** (75%)") 490 | 491 | if "主题识别" in analysis_type: 492 | st.markdown("#### 主题识别") 493 | topics = ["业务增长", "市场扩张", "产品创新"] 494 | for topic in topics: 495 | st.markdown(f"- {topic}") 496 | 497 | with col2: 498 | if "关键词提取" in analysis_type: 499 | st.markdown("#### 关键词") 500 | keywords = ["销售增长", "市场份额", "产品线", "创新", "客户满意度"] 501 | for kw in keywords: 502 | st.markdown(f"- {kw}") 503 | 504 | if "实体识别" in analysis_type: 505 | st.markdown("#### 识别的实体") 506 | entities = [ 507 | {"text": "华东区", "type": "地理位置"}, 508 | {"text": "新产品A", "type": "产品"}, 509 | {"text": "2024年", "type": "时间"} 510 | ] 511 | for entity in entities: 512 | st.markdown(f"- {entity['text']} ({entity['type']})") 513 | 514 | if "摘要生成" in analysis_type: 515 | st.markdown("#### 自动摘要") 516 | st.markdown(""" 517 | 该文本主要讨论了公司2024年第四季度的业绩表现,重点关注销售增长和新产品线的成功。 518 | 文本表明公司在华东地区取得了显著增长,新产品线表现超出预期。 519 | 同时也提到了一些运营成本优化和未来增长策略。 520 | """) 521 | else: 522 | st.error("请输入要分析的文本") 523 | 524 | with tab3: 525 | st.markdown("### 预测分析") 526 | 527 | # 预测配置 528 | col1, col2 = st.columns(2) 529 | with col1: 530 | prediction_target = st.selectbox( 531 | "预测目标", 532 | ["销售额预测", "市场份额预测", "客户增长预测", "成本预测"] 533 | ) 534 | time_horizon = st.selectbox( 535 | "时间范围", 536 | ["下个季度", "未来6个月", "下一财年", "未来3年"] 537 | ) 538 | 539 | with col2: 540 | data_source = st.multiselect( 541 | "数据来源", 542 | ["历史销售数据", "市场研究报告", "竞争对手分析", "宏观经济指标", "客户反馈"], 543 | ["历史销售数据", "市场研究报告"] 544 | ) 545 | confidence_level = st.slider("置信度要求", 75, 99, 90) 546 | 547 | # 上传自定义数据(可选) 548 | upload_custom = st.checkbox("上传自定义数据") 549 | if upload_custom: 550 | custom_data = st.file_uploader("上传CSV或Excel文件", type=["csv", "xlsx"]) 551 | 552 | # 预测按钮 553 | if st.button("生成预测", use_container_width=True): 554 | with st.spinner("生成预测分析..."): 555 | time.sleep(2) # 模拟分析延迟 556 | 557 | # 模拟预测结果 558 | st.success("预测分析完成!") 559 | 560 | st.markdown("#### 预测摘要") 561 | st.markdown(""" 562 | 基于历史数据分析和当前市场趋势,预测2025年第一季度销售额将达到1.32亿元,同比增长17%,环比增长10%。 563 | 预测置信区间为1.26亿元至1.38亿元(90%置信度)。 564 | 增长主要来源预计是新产品线持续的市场渗透和线上渠道的扩展。 565 | """) 566 | 567 | # 预测图表 568 | forecast_data = { 569 | "时间": ["2024 Q1", "2024 Q2", "2024 Q3", "2024 Q4", "2025 Q1 (预测)"], 570 | "销售额": [78, 85, 102, 120, 132], 571 | "下限": [78, 85, 102, 120, 126], 572 | "上限": [78, 85, 102, 120, 138] 573 | } 574 | df = pd.DataFrame(forecast_data) 575 | 576 | fig = px.line(df, x="时间", y="销售额", 577 | title="销售额预测(含90%置信区间)", 578 | labels={"销售额": "金额(百万元)", "时间": "季度"}) 579 | 580 | # 添加置信区间 581 | fig.add_scatter(x=df["时间"], y=df["上限"], mode="lines", line=dict(width=0), showlegend=False) 582 | fig.add_scatter(x=df["时间"], y=df["下限"], mode="lines", line=dict(width=0), 583 | fill="tonexty", fillcolor="rgba(0,100,255,0.2)", name="90%置信区间") 584 | 585 | st.plotly_chart(fig, use_container_width=True) 586 | 587 | # 影响因素 588 | st.markdown("#### 关键影响因素") 589 | factors = [ 590 | {"因素": "新产品推出", "影响": "正面", "权重": 35}, 591 | {"因素": "市场竞争加剧", "影响": "负面", "权重": 20}, 592 | {"因素": "线上渠道扩展", "影响": "正面", "权重": 25}, 593 | {"因素": 
"季节性波动", "影响": "中性", "权重": 10}, 594 | {"因素": "宏观经济环境", "影响": "正面", "权重": 10} 595 | ] 596 | 597 | st.dataframe(factors, use_container_width=True) 598 | 599 | # 建议措施 600 | st.markdown("#### 建议措施") 601 | recommendations = [ 602 | "增加新产品线的营销投入,重点关注市场接受度高的产品", 603 | "加强线上销售渠道建设,优化用户体验", 604 | "密切监控竞争对手动态,调整差异化策略", 605 | "提前备货应对季节性需求高峰,优化供应链弹性", 606 | "建立动态预测模型,每月更新销售预测" 607 | ] 608 | 609 | for i, rec in enumerate(recommendations): 610 | st.markdown(f"{i+1}. {rec}") 611 | 612 | # 系统状态页面 613 | def render_system_status(): 614 | st.markdown('
系统状态
', unsafe_allow_html=True) 615 | 616 | col1, col2, col3 = st.columns(3) 617 | 618 | with col1: 619 | st.markdown('
', unsafe_allow_html=True) 620 | st.metric(label="CPU使用率", value="42%", delta="-5%") 621 | st.markdown('
', unsafe_allow_html=True) 622 | 623 | with col2: 624 | st.markdown('
', unsafe_allow_html=True) 625 | st.metric(label="内存使用率", value="68%", delta="3%") 626 | st.markdown('
', unsafe_allow_html=True) 627 | 628 | with col3: 629 | st.markdown('
', unsafe_allow_html=True) 630 | st.metric(label="GPU利用率", value="76%", delta="12%") 631 | st.markdown('
', unsafe_allow_html=True) 632 | 633 | # 系统组件状态 634 | st.markdown('
系统组件状态
', unsafe_allow_html=True) 635 | 636 | components = [ 637 | {"组件": "API服务", "状态": "运行中", "健康度": 100, "响应时间": "45ms"}, 638 | {"组件": "DeepSeek (Ollama)", "状态": "运行中", "健康度": 100, "响应时间": "350ms"}, 639 | {"组件": "Qdrant向量数据库", "状态": "运行中", "健康度": 98, "响应时间": "32ms"}, 640 | {"组件": "Redis缓存", "状态": "运行中", "健康度": 100, "响应时间": "5ms"}, 641 | {"组件": "工作进程", "状态": "运行中", "健康度": 100, "响应时间": "N/A"}, 642 | {"组件": "前端服务", "状态": "运行中", "健康度": 100, "响应时间": "N/A"} 643 | ] 644 | 645 | # 显示组件状态表格 646 | df = pd.DataFrame(components) 647 | st.dataframe(df, use_container_width=True) 648 | 649 | # 性能监控 650 | st.markdown('
性能监控
', unsafe_allow_html=True) 651 | 652 | tab1, tab2, tab3 = st.tabs(["资源使用", "请求统计", "索引状态"]) 653 | 654 | with tab1: 655 | # 模拟资源使用数据 656 | time_points = [f"{i}:00" for i in range(24)] 657 | cpu_usage = [25, 28, 27, 25, 23, 22, 25, 35, 45, 55, 65, 68, 70, 72, 68, 65, 60, 58, 55, 48, 42, 38, 32, 27] 658 | memory_usage = [45, 45, 45, 45, 45, 45, 45, 48, 55, 65, 68, 70, 72, 75, 73, 72, 70, 68, 65, 60, 55, 50, 48, 45] 659 | gpu_usage = [10, 10, 10, 10, 10, 10, 10, 35, 55, 65, 75, 82, 85, 88, 85, 80, 75, 70, 65, 50, 40, 30, 20, 15] 660 | 661 | usage_data = {"时间": time_points, "CPU使用率": cpu_usage, "内存使用率": memory_usage, "GPU使用率": gpu_usage} 662 | usage_df = pd.DataFrame(usage_data) 663 | 664 | fig = px.line(usage_df, x="时间", y=["CPU使用率", "内存使用率", "GPU使用率"], 665 | title="24小时资源使用趋势", 666 | labels={"value": "使用率 (%)", "variable": "资源类型"}) 667 | st.plotly_chart(fig, use_container_width=True) 668 | 669 | with tab2: 670 | # 模拟请求统计数据 671 | categories = ["文档上传", "文档检索", "聊天请求", "搜索查询", "分析任务"] 672 | today_counts = [24, 156, 432, 287, 63] 673 | yesterday_counts = [18, 142, 389, 253, 58] 674 | 675 | request_data = {"请求类型": categories, "今日请求数": today_counts, "昨日请求数": yesterday_counts} 676 | request_df = pd.DataFrame(request_data) 677 | 678 | fig = px.bar(request_df, x="请求类型", y=["今日请求数", "昨日请求数"], 679 | title="请求统计", 680 | barmode="group", 681 | labels={"value": "请求数", "variable": "时间"}) 682 | st.plotly_chart(fig, use_container_width=True) 683 | 684 | # 响应时间分布 685 | response_times = [ 686 | {"响应时间": "<100ms", "占比": 35}, 687 | {"响应时间": "100-300ms", "占比": 42}, 688 | {"响应时间": "300-500ms", "占比": 15}, 689 | {"响应时间": "500ms-1s", "占比": 6}, 690 | {"响应时间": ">1s", "占比": 2} 691 | ] 692 | 693 | response_df = pd.DataFrame(response_times) 694 | fig = px.pie(response_df, values="占比", names="响应时间", title="响应时间分布") 695 | st.plotly_chart(fig, use_container_width=True) 696 | 697 | with tab3: 698 | # 索引状态 699 | st.markdown("#### 向量索引状态") 700 | 701 | index_stats = { 702 | "总向量数": "1,523,482", 703 | "索引大小": "5.3 GB", 704 | "最后更新时间": "2025-02-28 14:32:45", 705 | "平均检索时间": "27ms", 706 | "索引分段数": "8", 707 | "删除向量占比": "3.2%" 708 | } 709 | 710 | for key, value in index_stats.items(): 711 | st.metric(label=key, value=value) 712 | 713 | if st.button("优化索引"): 714 | with st.spinner("正在优化索引..."): 715 | time.sleep(3) 716 | st.success("索引优化完成!检索性能提升约8%") 717 | 718 | if __name__ == "__main__": 719 | main() 720 | -------------------------------------------------------------------------------- /frontend/components/auth.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import requests 3 | import json 4 | from typing import Dict, Any, Optional 5 | 6 | class AuthManager: 7 | """认证管理器""" 8 | 9 | def __init__(self, api_url: str): 10 | """初始化认证管理器""" 11 | self.api_url = api_url 12 | 13 | # 初始化会话状态 14 | if "user" not in st.session_state: 15 | st.session_state.user = None 16 | if "token" not in st.session_state: 17 | st.session_state.token = None 18 | 19 | def login_form(self) -> bool: 20 | """显示登录表单并处理登录""" 21 | st.markdown('
用户登录
', unsafe_allow_html=True) 22 | 23 | # 登录表单 24 | with st.form("login_form"): 25 | username = st.text_input("用户名") 26 | password = st.text_input("密码", type="password") 27 | submit = st.form_submit_button("登录") 28 | 29 | if submit: 30 | if self._try_login(username, password): 31 | return True 32 | 33 | # 模拟登录 34 | st.markdown("### 快速测试登录") 35 | col1, col2 = st.columns(2) 36 | with col1: 37 | if st.button("管理员登录"): 38 | if self._try_login("admin", "password"): 39 | return True 40 | with col2: 41 | if st.button("普通用户登录"): 42 | if self._try_login("user", "password"): 43 | return True 44 | 45 | return False 46 | 47 | def _try_login(self, username: str, password: str) -> bool: 48 | """尝试登录""" 49 | try: 50 | # 模拟API调用 51 | # 实际应调用登录API 52 | # response = requests.post( 53 | # f"{self.api_url}/api/auth/token", 54 | # data={"username": username, "password": password} 55 | # ) 56 | # if response.status_code == 200: 57 | # data = response.json() 58 | # st.session_state.token = data["access_token"] 59 | # st.session_state.user = data["user"] 60 | # return True 61 | 62 | # 模拟成功登录 63 | if (username == "admin" and password == "password") or (username == "user" and password == "password"): 64 | # 模拟用户信息 65 | st.session_state.token = "fake_token_123456" 66 | st.session_state.user = { 67 | "id": "user_001" if username == "admin" else "user_002", 68 | "username": username, 69 | "role": "admin" if username == "admin" else "user", 70 | "full_name": "Admin User" if username == "admin" else "Normal User" 71 | } 72 | return True 73 | else: 74 | st.error("用户名或密码错误") 75 | return False 76 | 77 | except Exception as e: 78 | st.error(f"登录失败: {str(e)}") 79 | return False 80 | 81 | def logout(self): 82 | """退出登录""" 83 | st.session_state.token = None 84 | st.session_state.user = None 85 | 86 | def is_logged_in(self) -> bool: 87 | """检查是否已登录""" 88 | return st.session_state.token is not None 89 | 90 | def get_current_user(self) -> Optional[Dict[str, Any]]: 91 | """获取当前用户""" 92 | return st.session_state.user 93 | 94 | def show_user_info(self): 95 | """显示用户信息""" 96 | user = self.get_current_user() 97 | if user: 98 | st.sidebar.markdown(f"### 欢迎, {user.get('full_name', user.get('username'))}") 99 | st.sidebar.markdown(f"角色: {user.get('role', 'user')}") 100 | 101 | if st.sidebar.button("退出登录"): 102 | self.logout() 103 | st.experimental_rerun() -------------------------------------------------------------------------------- /frontend/components/ui.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from typing import Dict, List, Any, Optional 3 | 4 | def display_header(title: str, subtitle: Optional[str] = None): 5 | """显示页面标题和副标题""" 6 | st.markdown(f'
{title}
', unsafe_allow_html=True) 7 | if subtitle: 8 | st.markdown(f'

{subtitle}

', unsafe_allow_html=True) 9 | 10 | def display_info_card(label: str, value: str, delta: Optional[str] = None, color: str = "#EFF6FF"): 11 | """显示信息卡片""" 12 | st.markdown(f'
', unsafe_allow_html=True) 13 | st.metric(label=label, value=value, delta=delta) 14 | st.markdown('
', unsafe_allow_html=True) 15 | 16 | def display_document_card(doc: Dict[str, Any], on_click_callback=None): 17 | """显示文档卡片""" 18 | with st.container(): 19 | st.markdown(f""" 20 |
21 |

{doc.get('名称', 'Untitled')}

22 |

大小: {doc.get('大小', 'Unknown')}

23 |

上传时间: {doc.get('上传时间', 'Unknown')}

24 |

分类: {doc.get('分类', 'Uncategorized')}

25 |

状态: {doc.get('状态', 'Unknown')}

26 |
27 | """, unsafe_allow_html=True) 28 | 29 | if on_click_callback: 30 | if st.button("查看详情", key=f"view_{doc.get('id', '')}"): 31 | on_click_callback(doc) 32 | 33 | def display_chat_message(message: Dict[str, Any]): 34 | """显示聊天消息""" 35 | role = message.get("role", "unknown") 36 | content = message.get("content", "") 37 | 38 | if role == "user": 39 | st.markdown(f'
{content}
', unsafe_allow_html=True) 40 | else: 41 | st.markdown(f'
{content}
', unsafe_allow_html=True) 42 | 43 | # 显示来源(如果有) 44 | sources = message.get("sources", []) 45 | if sources: 46 | st.markdown('
', unsafe_allow_html=True) 47 | st.markdown("**参考来源:**") 48 | for source in sources: 49 | st.markdown(f"- {source.get('title', 'Unknown')} (P{source.get('page', '?')})") 50 | st.markdown('
', unsafe_allow_html=True) 51 | 52 | def display_search_result(result: Dict[str, Any]): 53 | """显示搜索结果""" 54 | with st.expander(f"{result.get('title', 'Untitled')} (相关度: {result.get('score', 0):.2f})"): 55 | st.markdown(f"**位置**: 第{result.get('page', '?')}页") 56 | st.markdown(f"**更新日期**: {result.get('date', 'Unknown')}") 57 | st.markdown("**内容片段**:") 58 | st.markdown(f"
{result.get('text', '')}
", unsafe_allow_html=True) 59 | 60 | col1, col2 = st.columns(2) 61 | with col1: 62 | st.button(f"打开文档", key=f"open_{result.get('title', '')}") 63 | with col2: 64 | st.button(f"提问相关问题", key=f"ask_{result.get('title', '')}") 65 | 66 | def get_status_color(status: str) -> str: 67 | """根据状态返回颜色""" 68 | status_colors = { 69 | "已索引": "green", 70 | "处理中": "orange", 71 | "错误": "red", 72 | "等待中": "blue" 73 | } 74 | return status_colors.get(status, "gray") -------------------------------------------------------------------------------- /frontend/pages/analysis.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import time 4 | import plotly.express as px 5 | import requests 6 | import json 7 | 8 | def render(): 9 | st.markdown('
数据分析
', unsafe_allow_html=True) 10 | 11 | tab1, tab2, tab3 = st.tabs(["文档分析", "文本分析", "预测分析"]) 12 | 13 | with tab1: 14 | st.markdown("### 文档分析") 15 | 16 | # 选择文档 17 | selected_doc = st.selectbox( 18 | "选择要分析的文档", 19 | ["财务报表Q4.pdf", "产品规格说明.docx", "营销策略2025.pptx", "客户调研报告.xlsx"] 20 | ) 21 | 22 | # 分析选项 23 | col1, col2 = st.columns(2) 24 | with col1: 25 | focus_areas = st.multiselect( 26 | "分析重点领域", 27 | ["关键数据指标", "趋势分析", "风险评估", "机会识别", "竞争分析"], 28 | ["关键数据指标", "趋势分析"] 29 | ) 30 | with col2: 31 | analysis_depth = st.select_slider( 32 | "分析深度", 33 | options=["基础概述", "标准分析", "深度分析"] 34 | ) 35 | 36 | # 分析按钮 37 | if st.button("开始文档分析", use_container_width=True): 38 | with st.spinner("正在分析文档..."): 39 | time.sleep(2) # 模拟分析延迟 40 | 41 | # 模拟API请求 42 | # focus_areas_str = ", ".join(focus_areas) 43 | # response = requests.post( 44 | # f"{API_URL}/api/analysis/document", 45 | # json={"document_id": "doc1", "focus_areas": focus_areas_str, "instructions": analysis_depth} 46 | # ) 47 | # analysis_result = response.json() 48 | 49 | # 模拟分析结果 50 | st.success("分析完成!") 51 | 52 | st.markdown("#### 分析摘要") 53 | st.markdown(""" 54 | 该财务报表显示2024年第四季度业绩良好,销售额同比增长15%,总收入达1.2亿元,超预期8%。 55 | 主要增长来自新产品线(贡献35%)和假日促销活动(贡献25%)。 56 | 运营成本控制良好,同比下降3%,主要得益于数字化转型项目带来的效率提升。 57 | """) 58 | 59 | st.markdown("#### 关键发现") 60 | key_findings = [ 61 | "销售增长:Q4销售额同比增长15%,环比增长8%", 62 | "成本控制:运营成本同比下降3%", 63 | "利润率:毛利率提升2.5个百分点至38.5%", 64 | "区域表现:华东区表现最佳,增长22%", 65 | "挑战:供应链延迟导致某些产品线库存不足" 66 | ] 67 | for finding in key_findings: 68 | st.markdown(f"- {finding}") 69 | 70 | # 模拟可视化 71 | data = { 72 | "季度": ["Q1", "Q2", "Q3", "Q4"], 73 | "销售额": [78, 85, 102, 120], 74 | "成本": [52, 55, 65, 73], 75 | "利润": [26, 30, 37, 47] 76 | } 77 | df = pd.DataFrame(data) 78 | 79 | # 绘制图表 80 | fig = px.line(df, x="季度", y=["销售额", "成本", "利润"], 81 | title="2024年季度业绩趋势", 82 | labels={"value": "金额(百万元)", "variable": "指标"}) 83 | st.plotly_chart(fig, use_container_width=True) 84 | 85 | with tab2: 86 | st.markdown("### 文本分析") 87 | 88 | # 文本输入 89 | text_input = st.text_area( 90 | "输入要分析的文本", 91 | height=200, 92 | placeholder="粘贴需要分析的文本内容..." 
93 | ) 94 | 95 | # 分析类型 96 | analysis_type = st.multiselect( 97 | "选择分析类型", 98 | ["情感分析", "关键词提取", "主题识别", "摘要生成", "实体识别"], 99 | ["情感分析", "关键词提取", "摘要生成"] 100 | ) 101 | 102 | # 分析按钮 103 | if st.button("分析文本", use_container_width=True): 104 | if text_input: 105 | with st.spinner("分析中..."): 106 | time.sleep(1.5) # 模拟分析延迟 107 | 108 | # 模拟API请求 109 | # focus_areas = ", ".join(analysis_type) 110 | # response = requests.post( 111 | # f"{API_URL}/api/analysis/text", 112 | # json={"text": text_input, "focus_areas": focus_areas} 113 | # ) 114 | # analysis_result = response.json() 115 | 116 | # 模拟分析结果 117 | st.success("文本分析完成!") 118 | 119 | col1, col2 = st.columns(2) 120 | 121 | with col1: 122 | if "情感分析" in analysis_type: 123 | st.markdown("#### 情感分析") 124 | st.progress(75) 125 | st.markdown("总体情感:**积极** (75%)") 126 | 127 | if "主题识别" in analysis_type: 128 | st.markdown("#### 主题识别") 129 | topics = ["业务增长", "市场扩张", "产品创新"] 130 | for topic in topics: 131 | st.markdown(f"- {topic}") 132 | 133 | with col2: 134 | if "关键词提取" in analysis_type: 135 | st.markdown("#### 关键词") 136 | keywords = ["销售增长", "市场份额", "产品线", "创新", "客户满意度"] 137 | for kw in keywords: 138 | st.markdown(f"- {kw}") 139 | 140 | if "实体识别" in analysis_type: 141 | st.markdown("#### 识别的实体") 142 | entities = [ 143 | {"text": "华东区", "type": "地理位置"}, 144 | {"text": "新产品A", "type": "产品"}, 145 | {"text": "2024年", "type": "时间"} 146 | ] 147 | for entity in entities: 148 | st.markdown(f"- {entity['text']} ({entity['type']})") 149 | 150 | if "摘要生成" in analysis_type: 151 | st.markdown("#### 自动摘要") 152 | st.markdown(""" 153 | 该文本主要讨论了公司2024年第四季度的业绩表现,重点关注销售增长和新产品线的成功。 154 | 文本表明公司在华东地区取得了显著增长,新产品线表现超出预期。 155 | 同时也提到了一些运营成本优化和未来增长策略。 156 | """) 157 | else: 158 | st.error("请输入要分析的文本") 159 | 160 | with tab3: 161 | st.markdown("### 预测分析") 162 | 163 | # 预测配置 164 | col1, col2 = st.columns(2) 165 | with col1: 166 | prediction_target = st.selectbox( 167 | "预测目标", 168 | ["销售额预测", "市场份额预测", "客户增长预测", "成本预测"] 169 | ) 170 | time_horizon = st.selectbox( 171 | "时间范围", 172 | ["下个季度", "未来6个月", "下一财年", "未来3年"] 173 | ) 174 | 175 | with col2: 176 | data_source = st.multiselect( 177 | "数据来源", 178 | ["历史销售数据", "市场研究报告", "竞争对手分析", "宏观经济指标", "客户反馈"], 179 | ["历史销售数据", "市场研究报告"] 180 | ) 181 | confidence_level = st.slider("置信度要求", 75, 99, 90) 182 | 183 | # 上传自定义数据(可选) 184 | upload_custom = st.checkbox("上传自定义数据") 185 | if upload_custom: 186 | custom_data = st.file_uploader("上传CSV或Excel文件", type=["csv", "xlsx"]) 187 | 188 | # 预测按钮 189 | if st.button("生成预测", use_container_width=True): 190 | with st.spinner("生成预测分析..."): 191 | time.sleep(2) # 模拟分析延迟 192 | 193 | # 模拟API请求 194 | # historical_data = "历史销售数据..." if not custom_data else "上传的数据..." 
195 | # context = f"数据来源: {', '.join(data_source)}; 置信度: {confidence_level}%" 196 | # response = requests.post( 197 | # f"{API_URL}/api/analysis/predict", 198 | # json={ 199 | # "historical_data": historical_data, 200 | # "target": prediction_target, 201 | # "context": context 202 | # } 203 | # ) 204 | # prediction_result = response.json() 205 | 206 | # 模拟预测结果 207 | st.success("预测分析完成!") 208 | 209 | st.markdown("#### 预测摘要") 210 | st.markdown(""" 211 | 基于历史数据分析和当前市场趋势,预测2025年第一季度销售额将达到1.32亿元,同比增长17%,环比增长10%。 212 | 预测置信区间为1.26亿元至1.38亿元(90%置信度)。 213 | 增长主要来源预计是新产品线持续的市场渗透和线上渠道的扩展。 214 | """) 215 | 216 | # 预测图表 217 | forecast_data = { 218 | "时间": ["2024 Q1", "2024 Q2", "2024 Q3", "2024 Q4", "2025 Q1 (预测)"], 219 | "销售额": [78, 85, 102, 120, 132], 220 | "下限": [78, 85, 102, 120, 126], 221 | "上限": [78, 85, 102, 120, 138] 222 | } 223 | df = pd.DataFrame(forecast_data) 224 | 225 | fig = px.line(df, x="时间", y="销售额", 226 | title="销售额预测(含90%置信区间)", 227 | labels={"销售额": "金额(百万元)", "时间": "季度"}) 228 | 229 | # 添加置信区间 230 | fig.add_scatter(x=df["时间"], y=df["上限"], mode="lines", line=dict(width=0), showlegend=False) 231 | fig.add_scatter(x=df["时间"], y=df["下限"], mode="lines", line=dict(width=0), 232 | fill="tonexty", fillcolor="rgba(0,100,255,0.2)", name="90%置信区间") 233 | 234 | st.plotly_chart(fig, use_container_width=True) 235 | 236 | # 影响因素 237 | st.markdown("#### 关键影响因素") 238 | factors = [ 239 | {"因素": "新产品推出", "影响": "正面", "权重": 35}, 240 | {"因素": "市场竞争加剧", "影响": "负面", "权重": 20}, 241 | {"因素": "线上渠道扩展", "影响": "正面", "权重": 25}, 242 | {"因素": "季节性波动", "影响": "中性", "权重": 10}, 243 | {"因素": "宏观经济环境", "影响": "正面", "权重": 10} 244 | ] 245 | 246 | st.dataframe(factors, use_container_width=True) 247 | 248 | # 建议措施 249 | st.markdown("#### 建议措施") 250 | recommendations = [ 251 | "增加新产品线的营销投入,重点关注市场接受度高的产品", 252 | "加强线上销售渠道建设,优化用户体验", 253 | "密切监控竞争对手动态,调整差异化策略", 254 | "提前备货应对季节性需求高峰,优化供应链弹性", 255 | "建立动态预测模型,每月更新销售预测" 256 | ] 257 | 258 | for i, rec in enumerate(recommendations): 259 | st.markdown(f"{i+1}. {rec}") -------------------------------------------------------------------------------- /frontend/pages/chat.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import time 4 | import json 5 | import uuid 6 | 7 | def render(): 8 | st.markdown('
知识库问答
', unsafe_allow_html=True) 9 | 10 | # 初始化会话历史 11 | if "messages" not in st.session_state: 12 | st.session_state.messages = [] 13 | 14 | # 显示聊天设置 15 | with st.sidebar.expander("聊天设置", expanded=False): 16 | use_rag = st.checkbox("使用知识库增强", value=True) 17 | if use_rag: 18 | search_scope = st.radio("搜索范围", ["全部文档", "选定文档"]) 19 | if search_scope == "选定文档": 20 | # 模拟文档选择 21 | selected_docs = st.multiselect( 22 | "选择文档", 23 | ["财务报表Q4.pdf", "产品规格说明.docx", "营销策略2025.pptx", "客户调研报告.xlsx", "法律合同模板.docx"] 24 | ) 25 | 26 | model = st.selectbox("模型", ["DeepSeek-V3", "DeepSeek-R1"]) 27 | temperature = st.slider("随机性", 0.0, 1.0, 0.7, 0.1) 28 | max_tokens = st.slider("最大生成长度", 256, 4096, 2048, 128) 29 | 30 | # 显示对话历史 31 | for message in st.session_state.messages: 32 | if message["role"] == "user": 33 | st.markdown(f'
{message["content"]}
', unsafe_allow_html=True) 34 | else: 35 | st.markdown(f'
{message["content"]}
', unsafe_allow_html=True) 36 | 37 | # 显示来源(如果有) 38 | if "sources" in message and message["sources"]: 39 | st.markdown('
', unsafe_allow_html=True) 40 | st.markdown("**参考来源:**") 41 | for source in message["sources"]: 42 | st.markdown(f"- {source['title']} (P{source['page']})") 43 | st.markdown('
', unsafe_allow_html=True) 44 | 45 | # 输入框 46 | user_input = st.text_area("输入您的问题", height=100) 47 | 48 | col1, col2 = st.columns([1, 5]) 49 | with col1: 50 | if st.button("发送", use_container_width=True): 51 | if user_input: 52 | # 添加用户消息 53 | st.session_state.messages.append({"role": "user", "content": user_input}) 54 | 55 | # 模拟API调用 56 | with st.spinner("思考中..."): 57 | time.sleep(1) # 模拟响应延迟 58 | 59 | # 模拟回复 60 | bot_reply = { 61 | "role": "assistant", 62 | "content": "根据我们的财务报表分析,2024年第四季度销售额比第三季度增长了15%,主要得益于新产品线的推出和假日促销活动的成功。总收入达到了1.2亿元,超过了预期目标8%。", 63 | "sources": [ 64 | {"title": "财务报表Q4.pdf", "page": 12}, 65 | {"title": "销售预测分析.xlsx", "page": 3} 66 | ] 67 | } 68 | 69 | # 添加回复 70 | st.session_state.messages.append(bot_reply) 71 | 72 | # 刷新界面显示新消息 73 | st.experimental_rerun() 74 | with col2: 75 | if st.button("清空对话", use_container_width=True): 76 | st.session_state.messages = [] 77 | st.experimental_rerun() -------------------------------------------------------------------------------- /frontend/pages/documents.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import time 4 | import requests 5 | import json 6 | from datetime import datetime 7 | 8 | def render(): 9 | st.markdown('
文档管理
', unsafe_allow_html=True) 10 | 11 | tab1, tab2 = st.tabs(["上传文档", "文档列表"]) 12 | 13 | with tab1: 14 | st.markdown('
上传新文档
', unsafe_allow_html=True) 15 | 16 | # 文件上传界面 17 | uploaded_file = st.file_uploader("选择要上传的文件", 18 | type=["pdf", "docx", "txt", "xlsx", "html"]) 19 | 20 | # 元数据输入 21 | with st.expander("添加文档元数据"): 22 | col1, col2 = st.columns(2) 23 | with col1: 24 | title = st.text_input("标题") 25 | category = st.selectbox("分类", ["财务", "技术", "营销", "法律", "人力资源", "其他"]) 26 | with col2: 27 | tags = st.text_input("标签(用逗号分隔)") 28 | importance = st.slider("重要性", 1, 5, 3) 29 | 30 | # 上传按钮 31 | if st.button("上传并索引"): 32 | if uploaded_file is not None: 33 | with st.spinner("正在处理文档..."): 34 | # 准备元数据 35 | metadata = { 36 | "title": title if title else uploaded_file.name, 37 | "category": category, 38 | "tags": tags.split(",") if tags else [], 39 | "importance": importance, 40 | "upload_time": datetime.now().isoformat() 41 | } 42 | 43 | # 模拟上传请求 44 | # 注意:实际代码中应该调用API上传文档 45 | # files = {"file": uploaded_file.getbuffer()} 46 | # data = {"metadata": json.dumps(metadata)} 47 | # response = requests.post(f"{API_URL}/api/documents/upload", files=files, data=data) 48 | 49 | # 模拟上传进度 50 | progress_bar = st.progress(0) 51 | for i in range(100): 52 | time.sleep(0.05) 53 | progress_bar.progress(i + 1) 54 | 55 | # 模拟成功消息 56 | st.success(f"文档 '{uploaded_file.name}' 上传成功并已添加到索引!") 57 | else: 58 | st.error("请先选择要上传的文件") 59 | 60 | with tab2: 61 | st.markdown('
文档列表
', unsafe_allow_html=True) 62 | 63 | # 搜索和过滤 64 | col1, col2, col3 = st.columns([2, 1, 1]) 65 | with col1: 66 | search_term = st.text_input("搜索文档", placeholder="输入关键词...") 67 | with col2: 68 | category_filter = st.selectbox("分类筛选", ["全部", "财务", "技术", "营销", "法律", "人力资源", "其他"]) 69 | with col3: 70 | sort_option = st.selectbox("排序方式", ["上传时间", "名称", "大小", "重要性"]) 71 | 72 | # 模拟文档列表 73 | documents = [ 74 | {"id": "doc1", "名称": "财务报表Q4.pdf", "大小": "2.3 MB", "上传时间": "2025-02-28", "分类": "财务", "状态": "已索引"}, 75 | {"id": "doc2", "名称": "产品规格说明.docx", "大小": "1.1 MB", "上传时间": "2025-02-27", "分类": "技术", "状态": "已索引"}, 76 | {"id": "doc3", "名称": "营销策略2025.pptx", "大小": "5.4 MB", "上传时间": "2025-02-26", "分类": "营销", "状态": "处理中"}, 77 | {"id": "doc4", "名称": "客户调研报告.xlsx", "大小": "3.7 MB", "上传时间": "2025-02-25", "分类": "营销", "状态": "已索引"}, 78 | {"id": "doc5", "名称": "法律合同模板.docx", "大小": "0.5 MB", "上传时间": "2025-02-24", "分类": "法律", "状态": "已索引"} 79 | ] 80 | 81 | # 过滤文档 82 | if search_term: 83 | documents = [doc for doc in documents if search_term.lower() in doc["名称"].lower()] 84 | 85 | if category_filter != "全部": 86 | documents = [doc for doc in documents if doc["分类"] == category_filter] 87 | 88 | # 排序文档 89 | if sort_option == "上传时间": 90 | documents.sort(key=lambda x: x["上传时间"], reverse=True) 91 | elif sort_option == "名称": 92 | documents.sort(key=lambda x: x["名称"]) 93 | elif sort_option == "大小": 94 | # 转换大小字符串为数字进行排序 95 | def size_to_mb(size_str): 96 | value = float(size_str.split()[0]) 97 | return value 98 | documents.sort(key=lambda x: size_to_mb(x["大小"]), reverse=True) 99 | 100 | # 显示文档列表 101 | st.dataframe(documents, use_container_width=True) 102 | 103 | # 批量操作 104 | col1, col2 = st.columns(2) 105 | with col1: 106 | if st.button("删除选中文档"): 107 | st.warning("请确认是否删除选中的文档?") 108 | with col2: 109 | if st.button("重新索引"): 110 | with st.spinner("正在重新索引..."): 111 | time.sleep(2) 112 | st.success("重新索引完成!") -------------------------------------------------------------------------------- /frontend/pages/home.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import plotly.express as px 4 | 5 | def render(): 6 | st.markdown('
DeepSeek本地知识库系统
', unsafe_allow_html=True) 7 | 8 | # 系统概览 9 | st.markdown('
系统概览
', unsafe_allow_html=True) 10 | 11 | col1, col2, col3 = st.columns(3) 12 | 13 | with col1: 14 | st.markdown('
', unsafe_allow_html=True) 15 | st.metric(label="文档总数", value="123") 16 | st.markdown('
', unsafe_allow_html=True) 17 | 18 | with col2: 19 | st.markdown('
', unsafe_allow_html=True) 20 | st.metric(label="向量数据量", value="5.3 GB") 21 | st.markdown('
', unsafe_allow_html=True) 22 | 23 | with col3: 24 | st.markdown('
', unsafe_allow_html=True) 25 | st.metric(label="当日查询次数", value="457") 26 | st.markdown('
', unsafe_allow_html=True) 27 | 28 | # 快速导航 29 | st.markdown('
快速导航
', unsafe_allow_html=True) 30 | 31 | col1, col2 = st.columns(2) 32 | 33 | with col1: 34 | if st.button("📄 上传新文档", use_container_width=True): 35 | st.session_state.page = "文档管理" 36 | st.experimental_rerun() 37 | 38 | with col2: 39 | if st.button("💬 开始聊天", use_container_width=True): 40 | st.session_state.page = "聊天问答" 41 | st.experimental_rerun() 42 | 43 | # 最近活动 44 | st.markdown('
最近活动
', unsafe_allow_html=True) 45 | 46 | # 模拟活动数据 47 | activities = [ 48 | {"时间": "2025-02-28 14:32", "活动": "上传文档", "详情": "财务报表Q4.pdf"}, 49 | {"时间": "2025-02-28 13:45", "活动": "搜索查询", "详情": "2024年销售预测"}, 50 | {"时间": "2025-02-28 11:20", "活动": "数据分析", "详情": "市场趋势分析"}, 51 | {"时间": "2025-02-28 10:05", "活动": "聊天会话", "详情": "5条消息交互"}, 52 | {"时间": "2025-02-27 16:50", "活动": "上传文档", "详情": "战略规划2025.docx"} 53 | ] 54 | 55 | st.table(activities) 56 | 57 | # 性能概览 58 | st.markdown('
系统性能
', unsafe_allow_html=True) 59 | 60 | # 模拟性能数据 61 | dates = pd.date_range(start='2025-02-20', end='2025-02-28') 62 | query_times = [120, 115, 118, 105, 98, 92, 85, 88, 90] 63 | 64 | performance_data = pd.DataFrame({ 65 | "日期": dates, 66 | "平均查询时间(ms)": query_times 67 | }) 68 | 69 | fig = px.line(performance_data, x='日期', y='平均查询时间(ms)', 70 | title='知识库查询性能趋势', 71 | labels={"平均查询时间(ms)": "毫秒"}) 72 | st.plotly_chart(fig, use_container_width=True) -------------------------------------------------------------------------------- /frontend/pages/search.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import time 3 | import pandas as pd 4 | import requests 5 | import json 6 | 7 | def render(): 8 | st.markdown('
知识库搜索
', unsafe_allow_html=True) 9 | 10 | # 搜索输入 11 | search_query = st.text_input("输入搜索查询", placeholder="例如:2024年销售预测...") 12 | 13 | col1, col2, col3 = st.columns([1, 1, 2]) 14 | with col1: 15 | search_mode = st.radio("搜索模式", ["语义搜索", "关键词搜索", "混合搜索"]) 16 | with col2: 17 | filter_category = st.multiselect("按分类筛选", ["财务", "技术", "营销", "法律", "人力资源"]) 18 | with col3: 19 | date_range = st.date_input("日期范围", []) 20 | 21 | search_button = st.button("搜索", use_container_width=True) 22 | 23 | # 如果搜索按钮被点击且有查询 24 | if search_button and search_query: 25 | with st.spinner("正在搜索..."): 26 | time.sleep(1) # 模拟搜索延迟 27 | 28 | # 模拟搜索请求 29 | # 注意:实际代码中应该调用API搜索 30 | # filters = {"category": filter_category} if filter_category else {} 31 | # if search_mode == "关键词搜索": 32 | # use_hybrid = True 33 | # else: 34 | # use_hybrid = False 35 | # response = requests.post( 36 | # f"{API_URL}/api/search", 37 | # json={"query": search_query, "use_hybrid": use_hybrid, "filters": filters, "limit": 10} 38 | # ) 39 | # search_results = response.json()["results"] 40 | 41 | # 模拟搜索结果 42 | search_results = [ 43 | { 44 | "title": "财务报表Q4.pdf", 45 | "page": 12, 46 | "text": "...2024年第四季度销售额比第三季度增长了15%,主要得益于新产品线的推出和假日促销活动的成功。总收入达到了1.2亿元,超过了预期目标8%...", 47 | "score": 0.92, 48 | "date": "2025-02-28" 49 | }, 50 | { 51 | "title": "销售预测分析.xlsx", 52 | "page": 3, 53 | "text": "...基于历史数据分析,我们预计2025年销售增长率将保持在12-15%之间,累计销售额预计达到4.8亿元...", 54 | "score": 0.87, 55 | "date": "2025-02-25" 56 | }, 57 | { 58 | "title": "营销策略2025.pptx", 59 | "page": 8, 60 | "text": "...针对2024年销售数据,新的数字营销策略将增加社交媒体投入30%,预计带来额外15%的销售增长...", 61 | "score": 0.81, 62 | "date": "2025-02-26" 63 | } 64 | ] 65 | 66 | # 过滤结果(如果应用了分类过滤器) 67 | if filter_category: 68 | # 模拟分类过滤 69 | if "财务" in filter_category: 70 | search_results = [r for r in search_results if "财务" in r["title"] or "销售" in r["title"]] 71 | if "营销" in filter_category and not any(r for r in search_results if "营销" in r["title"]): 72 | search_results.append({ 73 | "title": "营销策略报告.docx", 74 | "page": 5, 75 | "text": "...营销部门提出了创新的数字营销策略,瞄准电子商务平台和社交媒体...", 76 | "score": 0.76, 77 | "date": "2025-02-20" 78 | }) 79 | 80 | # 显示搜索结果 81 | st.markdown(f"### 搜索结果: 找到 {len(search_results)} 条匹配项") 82 | 83 | for result in search_results: 84 | with st.expander(f"{result['title']} (相关度: {result['score']:.2f})"): 85 | st.markdown(f"**位置**: 第{result['page']}页") 86 | st.markdown(f"**更新日期**: {result['date']}") 87 | st.markdown("**内容片段**:") 88 | st.markdown(f"
{result['text']}
", unsafe_allow_html=True) 89 | 90 | col1, col2 = st.columns(2) 91 | with col1: 92 | st.button(f"打开文档 {result['title']}", key=f"open_{result['title']}") 93 | with col2: 94 | st.button(f"提问相关问题 {result['title']}", key=f"ask_{result['title']}") -------------------------------------------------------------------------------- /frontend/pages/system_status.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import plotly.express as px 4 | import time 5 | import requests 6 | import json 7 | 8 | def render(): 9 | st.markdown('
系统状态
', unsafe_allow_html=True) 10 | 11 | col1, col2, col3 = st.columns(3) 12 | 13 | with col1: 14 | st.markdown('
', unsafe_allow_html=True) 15 | st.metric(label="CPU使用率", value="42%", delta="-5%") 16 | st.markdown('
', unsafe_allow_html=True) 17 | 18 | with col2: 19 | st.markdown('
', unsafe_allow_html=True) 20 | st.metric(label="内存使用率", value="68%", delta="3%") 21 | st.markdown('
', unsafe_allow_html=True) 22 | 23 | with col3: 24 | st.markdown('
', unsafe_allow_html=True) 25 | st.metric(label="GPU利用率", value="76%", delta="12%") 26 | st.markdown('
', unsafe_allow_html=True) 27 | 28 | # 系统组件状态 29 | st.markdown('
系统组件状态
', unsafe_allow_html=True) 30 | 31 | # 模拟API请求获取系统状态 32 | # response = requests.get(f"{API_URL}/api/system/status") 33 | # components = response.json()["components"] 34 | 35 | # 模拟组件状态数据 36 | components = [ 37 | {"组件": "API服务", "状态": "运行中", "健康度": 100, "响应时间": "45ms"}, 38 | {"组件": "DeepSeek (Ollama)", "状态": "运行中", "健康度": 100, "响应时间": "350ms"}, 39 | {"组件": "Qdrant向量数据库", "状态": "运行中", "健康度": 98, "响应时间": "32ms"}, 40 | {"组件": "Redis缓存", "状态": "运行中", "健康度": 100, "响应时间": "5ms"}, 41 | {"组件": "工作进程", "状态": "运行中", "健康度": 100, "响应时间": "N/A"}, 42 | {"组件": "前端服务", "状态": "运行中", "健康度": 100, "响应时间": "N/A"} 43 | ] 44 | 45 | # 显示组件状态表格 46 | df = pd.DataFrame(components) 47 | st.dataframe(df, use_container_width=True) 48 | 49 | # 性能监控 50 | st.markdown('
性能监控
', unsafe_allow_html=True) 51 | 52 | tab1, tab2, tab3 = st.tabs(["资源使用", "请求统计", "索引状态"]) 53 | 54 | with tab1: 55 | # 模拟资源使用数据 56 | time_points = [f"{i}:00" for i in range(24)] 57 | cpu_usage = [25, 28, 27, 25, 23, 22, 25, 35, 45, 55, 65, 68, 70, 72, 68, 65, 60, 58, 55, 48, 42, 38, 32, 27] 58 | memory_usage = [45, 45, 45, 45, 45, 45, 45, 48, 55, 65, 68, 70, 72, 75, 73, 72, 70, 68, 65, 60, 55, 50, 48, 45] 59 | gpu_usage = [10, 10, 10, 10, 10, 10, 10, 35, 55, 65, 75, 82, 85, 88, 85, 80, 75, 70, 65, 50, 40, 30, 20, 15] 60 | 61 | usage_data = {"时间": time_points, "CPU使用率": cpu_usage, "内存使用率": memory_usage, "GPU使用率": gpu_usage} 62 | usage_df = pd.DataFrame(usage_data) 63 | 64 | fig = px.line(usage_df, x="时间", y=["CPU使用率", "内存使用率", "GPU使用率"], 65 | title="24小时资源使用趋势", 66 | labels={"value": "使用率 (%)", "variable": "资源类型"}) 67 | st.plotly_chart(fig, use_container_width=True) 68 | 69 | # 自动刷新选项 70 | auto_refresh = st.checkbox("自动刷新数据(每60秒)") 71 | if auto_refresh: 72 | st.info("已启用自动刷新,数据将每60秒更新一次") 73 | # 在实际应用中,可以使用st.experimental_rerun()来自动刷新 74 | 75 | with tab2: 76 | # 模拟请求统计数据 77 | categories = ["文档上传", "文档检索", "聊天请求", "搜索查询", "分析任务"] 78 | today_counts = [24, 156, 432, 287, 63] 79 | yesterday_counts = [18, 142, 389, 253, 58] 80 | 81 | request_data = {"请求类型": categories, "今日请求数": today_counts, "昨日请求数": yesterday_counts} 82 | request_df = pd.DataFrame(request_data) 83 | 84 | fig = px.bar(request_df, x="请求类型", y=["今日请求数", "昨日请求数"], 85 | title="请求统计", 86 | barmode="group", 87 | labels={"value": "请求数", "variable": "时间"}) 88 | st.plotly_chart(fig, use_container_width=True) 89 | 90 | # 响应时间分布 91 | response_times = [ 92 | {"响应时间": "<100ms", "占比": 35}, 93 | {"响应时间": "100-300ms", "占比": 42}, 94 | {"响应时间": "300-500ms", "占比": 15}, 95 | {"响应时间": "500ms-1s", "占比": 6}, 96 | {"响应时间": ">1s", "占比": 2} 97 | ] 98 | 99 | response_df = pd.DataFrame(response_times) 100 | fig = px.pie(response_df, values="占比", names="响应时间", title="响应时间分布") 101 | st.plotly_chart(fig, use_container_width=True) 102 | 103 | with tab3: 104 | # 索引状态 105 | st.markdown("#### 向量索引状态") 106 | 107 | index_stats = { 108 | "总向量数": "1,523,482", 109 | "索引大小": "5.3 GB", 110 | "最后更新时间": "2025-02-28 14:32:45", 111 | "平均检索时间": "27ms", 112 | "索引分段数": "8", 113 | "删除向量占比": "3.2%" 114 | } 115 | 116 | # 显示统计信息 117 | col1, col2, col3 = st.columns(3) 118 | 119 | for i, (key, value) in enumerate(index_stats.items()): 120 | with [col1, col2, col3][i % 3]: 121 | st.metric(label=key, value=value) 122 | 123 | # 增量索引统计 124 | st.markdown("#### 索引任务统计") 125 | 126 | # 模拟索引任务数据 127 | task_data = { 128 | "任务类型": ["全量重建", "增量索引", "文档添加", "文档删除", "文档更新"], 129 | "昨日次数": [1, 24, 18, 3, 5], 130 | "平均耗时(秒)": [320, 45, 12, 8, 25] 131 | } 132 | task_df = pd.DataFrame(task_data) 133 | st.dataframe(task_df, use_container_width=True) 134 | 135 | # 优化按钮 136 | if st.button("优化索引"): 137 | with st.spinner("正在优化索引..."): 138 | time.sleep(3) 139 | st.success("索引优化完成!检索性能提升约8%") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 后端依赖 2 | fastapi==0.103.2 3 | uvicorn==0.23.2 4 | pydantic==2.4.2 5 | python-multipart==0.0.6 6 | httpx==0.25.0 7 | redis==4.6.0 8 | qdrant-client==1.5.3 9 | PyJWT==2.8.0 10 | python-dotenv==1.0.0 11 | numpy==1.24.3 12 | networkx==3.1 13 | PyYAML==6.0.1 14 | 15 | # 文档处理 16 | PyMuPDF==1.22.5 17 | python-docx==0.8.11 18 | openpyxl==3.1.2 19 | markdown==3.4.4 20 | beautifulsoup4==4.12.2 21 | 22 | # 前端 23 | streamlit==1.27.2 24 | 
pandas==2.0.3 25 | plotly==5.17.0 26 | 27 | # 工具和辅助库 28 | tqdm==4.66.1 29 | tenacity==8.2.3 30 | ujson==5.8.0 -------------------------------------------------------------------------------- /scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 高性能本地知识库系统安装脚本 4 | echo "DeepSeek本地知识库系统安装脚本" 5 | echo "================================" 6 | 7 | # 确认CUDA环境 8 | echo "正在检查CUDA环境..." 9 | if command -v nvidia-smi &> /dev/null; then 10 | echo "发现NVIDIA GPU:" 11 | nvidia-smi 12 | else 13 | echo "警告: 未检测到NVIDIA GPU。此系统需要NVIDIA GPU以获得最佳性能。" 14 | read -p "是否继续安装? (y/n): " continue_install 15 | if [[ "$continue_install" != "y" && "$continue_install" != "Y" ]]; then 16 | echo "安装已取消。" 17 | exit 1 18 | fi 19 | fi 20 | 21 | # 检查Docker和Docker Compose 22 | echo "检查Docker环境..." 23 | if command -v docker &> /dev/null; then 24 | docker_version=$(docker --version) 25 | echo "Docker已安装: $docker_version" 26 | else 27 | echo "错误: 未安装Docker。请先安装Docker。" 28 | echo "可参考: https://docs.docker.com/engine/install/ubuntu/" 29 | exit 1 30 | fi 31 | 32 | if command -v docker-compose &> /dev/null; then 33 | compose_version=$(docker-compose --version) 34 | echo "Docker Compose已安装: $compose_version" 35 | else 36 | echo "错误: 未安装Docker Compose。请先安装Docker Compose。" 37 | echo "可参考: https://docs.docker.com/compose/install/" 38 | exit 1 39 | fi 40 | 41 | # 创建必要的目录 42 | echo "创建项目目录..." 43 | mkdir -p data/uploads 44 | mkdir -p data/models 45 | mkdir -p logs 46 | 47 | # 设置权限 48 | echo "设置目录权限..." 49 | chmod -R 755 data 50 | chmod -R 755 logs 51 | 52 | # 确认Ollama GPU访问 53 | echo "配置Ollama GPU访问..." 54 | if command -v nvidia-smi &> /dev/null; then 55 | echo "确保nvidia-container-toolkit已安装" 56 | if ! command -v nvidia-container-toolkit &> /dev/null; then 57 | echo "安装nvidia-container-toolkit..." 58 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) 59 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - 60 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 61 | sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit 62 | sudo systemctl restart docker 63 | fi 64 | fi 65 | 66 | # 拉取必要的Docker镜像 67 | echo "拉取必要的Docker镜像..." 68 | docker pull redis:6.2 69 | docker pull qdrant/qdrant:latest 70 | docker pull ollama/ollama:latest 71 | 72 | # 构建自定义镜像 73 | echo "构建项目Docker镜像..." 74 | docker-compose build 75 | 76 | # 设置DeepSeek模型 77 | echo "准备下载DeepSeek模型(这可能需要一些时间)..." 78 | read -p "是否现在下载DeepSeek模型? (y/n): " download_model 79 | if [[ "$download_model" == "y" || "$download_model" == "Y" ]]; then 80 | echo "启动Ollama服务..." 81 | docker-compose up -d ollama 82 | echo "等待Ollama服务启动..." 83 | sleep 10 84 | 85 | echo "下载DeepSeek-V3模型..." 86 | docker exec -it knowledge-base-system_ollama_1 ollama pull deepseek/deepseek-v3 87 | 88 | echo "下载DeepSeek嵌入模型..." 89 | docker exec -it knowledge-base-system_ollama_1 ollama pull deepseek/deepseek-embeddings 90 | 91 | echo "停止临时Ollama服务..." 92 | docker-compose stop ollama 93 | else 94 | echo "跳过模型下载。请在系统启动后手动下载模型。" 95 | fi 96 | 97 | # 完成安装 98 | echo "安装完成!" 
99 | echo "使用以下命令启动系统:" 100 | echo "bash scripts/start_services.sh" 101 | -------------------------------------------------------------------------------- /scripts/start_services.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 高性能本地知识库系统启动脚本 4 | echo "DeepSeek本地知识库系统启动脚本" 5 | echo "================================" 6 | 7 | # 检查Docker状态 8 | echo "检查Docker服务..." 9 | if ! systemctl is-active --quiet docker; then 10 | echo "Docker服务未运行,正在启动..." 11 | sudo systemctl start docker 12 | fi 13 | 14 | # 检查NVIDIA GPU状态 15 | if command -v nvidia-smi &> /dev/null; then 16 | echo "检查GPU状态:" 17 | nvidia-smi | head -n 10 18 | else 19 | echo "警告: 未检测到NVIDIA GPU。系统性能可能受到影响。" 20 | fi 21 | 22 | # 启动服务 23 | echo "启动所有服务..." 24 | docker-compose up -d 25 | 26 | # 等待服务就绪 27 | echo "等待服务就绪..." 28 | sleep 10 29 | 30 | # 检查服务状态 31 | echo "检查服务状态:" 32 | docker-compose ps 33 | 34 | # 检查Ollama模型 35 | echo "检查DeepSeek模型状态..." 36 | docker exec -it knowledge-base-system_ollama_1 ollama list 37 | 38 | # 显示访问信息 39 | echo "================================" 40 | echo "系统已启动!" 41 | echo "- API服务: http://localhost:8000" 42 | echo "- 前端界面: http://localhost:8501" 43 | echo "- Qdrant管理界面: http://localhost:6333/dashboard" 44 | echo "================================" 45 | echo "使用以下命令查看日志:" 46 | echo "docker-compose logs -f" 47 | echo "使用以下命令停止服务:" 48 | echo "docker-compose down" 49 | echo "================================" 50 | 51 | --------------------------------------------------------------------------------
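
补充说明:以下是一个可选的启动后健康检查脚本示例(不属于仓库文件,仅作参考)。它假设各服务使用 start_services.sh 输出信息中列出的默认地址(API 服务 http://localhost:8000/docs、前端界面 http://localhost:8501、Qdrant 管理界面 http://localhost:6333/dashboard),在 `bash scripts/start_services.sh` 执行完成后运行,可快速确认各入口是否可访问:

```bash
#!/bin/bash
# 启动后健康检查示例(假设使用上述默认端口,非仓库内正式脚本)
set -u

check() {
    # 对指定 URL 做一次简单的 HTTP 可达性检查
    local name="$1" url="$2"
    if curl -fsS -o /dev/null --max-time 5 "$url"; then
        echo "[OK]   ${name} (${url})"
    else
        echo "[FAIL] ${name} (${url}) 不可访问,请运行 docker-compose logs -f 查看日志"
    fi
}

check "API服务文档"    "http://localhost:8000/docs"
check "前端界面"       "http://localhost:8501"
check "Qdrant管理界面" "http://localhost:6333/dashboard"
```

该示例仅检查 HTTP 可达性,不验证 DeepSeek 模型是否已加载;模型状态仍可按 start_services.sh 中的方式,通过容器内的 `ollama list` 查看。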