├── README.md
├── backend
│   ├── api
│   │   ├── core
│   │   │   ├── config.py
│   │   │   └── security.py
│   │   ├── deps
│   │   │   └── auth.py
│   │   ├── main.py
│   │   ├── models
│   │   │   ├── analysis.py
│   │   │   ├── chat.py
│   │   │   ├── document.py
│   │   │   ├── search.py
│   │   │   └── user.py
│   │   └── routers
│   │       ├── analysis.py
│   │       ├── chat.py
│   │       ├── documents.py
│   │       └── search.py
│   ├── embeddings
│   │   ├── batch_processor.py
│   │   └── model.py
│   ├── processors
│   │   ├── base.py
│   │   ├── docx_processor.py
│   │   ├── excel_processor.py
│   │   ├── html_processor.py
│   │   └── pdf_processor.py
│   ├── services
│   │   ├── document_service.py
│   │   ├── graphrag_service.py
│   │   ├── llm_service.py
│   │   ├── search_service.py
│   │   └── vector_store.py
│   └── worker
│       ├── main.py
│       └── tasks
│           └── document_processor.py
├── configs
│   ├── api.yaml
│   ├── ollama.yaml
│   ├── qdrant.yaml
│   ├── redis.yaml
│   └── worker.yaml
├── docker-compose.yml
├── docker
│   ├── .env
│   ├── Dockerfile.api
│   ├── Dockerfile.frontend
│   └── Dockerfile.worker
├── frontend
│   ├── app.py
│   ├── components
│   │   ├── auth.py
│   │   └── ui.py
│   └── pages
│       ├── analysis.py
│       ├── chat.py
│       ├── documents.py
│       ├── home.py
│       ├── search.py
│       └── system_status.py
├── requirements.txt
└── scripts
    ├── install.sh
    └── start_services.sh

/README.md:
--------------------------------------------------------------------------------
# DeepSeek Local Knowledge Base System

A high-performance local knowledge base system built on the DeepSeek large language model, supporting multiple document formats, local retrieval, question answering, and analysis.

## Features

- **Local DeepSeek deployment**: no external APIs required, keeping your data private
- **Multi-format document processing**: PDF, Word, Excel, TXT, HTML, and more
- **Efficient retrieval**: semantic search across 5,000+ documents
- **GraphRAG enhancement**: graph-based retrieval and reasoning
- **Multi-modal interaction**: chat, search, analysis, and prediction workflows
- **High-performance architecture**: Redis + Qdrant storage layer
- **Containerized deployment**: simple Docker-based setup

## System Requirements

- **Operating system**: Ubuntu 20.04 LTS or later
- **Hardware**:
  - NVIDIA GPU (RTX 3080 or better recommended)
  - CUDA 11.7+
  - 16 GB+ RAM (32 GB+ recommended)
  - 20 GB+ free disk space
- **Software**:
  - Docker and Docker Compose
  - NVIDIA Container Toolkit

## Quick Start

### Installation

1. Clone the repository:

   ```bash
   git clone https://github.com/your-username/knowledge-base-system.git
   cd knowledge-base-system
   ```

2. Run the installation script:

   ```bash
   bash scripts/install.sh
   ```

3. Start the services:

   ```bash
   bash scripts/start_services.sh
   ```

### Accessing the System

- Frontend UI: http://localhost:8501
- API docs: http://localhost:8000/docs

## Usage Guide

### Uploading Documents

1. Open the frontend UI
2. Go to the "Documents" page
3. Click "Upload Document" and select the files to upload
4. Add optional metadata
5. Click "Upload and Index"

### Chat Q&A

1. Go to the "Chat" page
2. Type your question in the input box
3. The system retrieves relevant content from your knowledge base and generates an answer

### Search

1. Go to the "Search" page
2. Enter a search query
3. Review the matching document fragments

### Data Analysis

1. Go to the "Analysis" page
2. Select a document or enter text to analyze
3. Review the generated analysis report and visualizations
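### Calling the API Directly

Everything the frontend does goes through the REST API documented at http://localhost:8000/docs. The sketch below drives the same upload, chat, and search flows from Python; it assumes the `requests` package is installed and that you already hold a valid JWT access token for the `Authorization` header (token issuance is handled by the auth layer and is not shown here). Field names follow the request models under `backend/api/models/`.

```python
import json
import requests

API = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer <your-access-token>"}  # assumed: a JWT accepted by the API's auth dependency

# Upload and index a document (multipart form: file + optional JSON metadata)
with open("report.pdf", "rb") as f:
    resp = requests.post(
        f"{API}/api/documents/upload",
        headers=HEADERS,
        files={"file": ("report.pdf", f, "application/pdf")},
        data={"metadata": json.dumps({"category": "reports"})},
    )
print(resp.status_code, resp.json())  # 202 Accepted, with document metadata and a task_id

# Ask a question against the knowledge base (RAG context enabled)
resp = requests.post(
    f"{API}/api/chat/",
    headers=HEADERS,
    json={"message": "Summarize the key points of the latest report", "use_rag": True, "max_context_chunks": 5},
)
reply = resp.json()
print(reply["answer"])
for source in reply.get("sources", []):
    print("-", source["metadata"].get("document_id"), source["text"][:80])

# Semantic search over the indexed chunks
resp = requests.post(
    f"{API}/api/search/",
    headers=HEADERS,
    json={"query": "quarterly revenue trend", "limit": 5, "use_hybrid": True},
)
for hit in resp.json()["results"]:
    print(f"{hit['score']:.3f}", hit["text"][:80])
```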
## Project Structure

```
knowledge-base-system/
├── docker/      - Docker configuration
├── backend/     - backend services
│   ├── api/         - FastAPI application
│   ├── services/    - core services
│   ├── processors/  - document processors
│   └── embeddings/  - embedding module
├── frontend/    - Streamlit frontend
├── configs/     - configuration files
└── scripts/     - install and startup scripts
```

## Troubleshooting

### Common Issues

**Services fail to start**
- Check that the Docker daemon is running
- Check that the GPU driver and CUDA are installed correctly
- Inspect the logs: `docker-compose logs -f`

**Documents cannot be uploaded**
- Check that the file format is supported
- Check that the file size is within the limit
- Inspect the API service logs

**Model loading errors**
- Make sure the DeepSeek model has been downloaded correctly
- Check that there is enough GPU memory
- Inspect the Ollama service logs

## Contributing and Support

Bug reports and improvement suggestions are welcome; please use GitHub Issues.

## License

This project is released under the MIT License. See the LICENSE file for details.

--------------------------------------------------------------------------------
/backend/api/core/config.py:
--------------------------------------------------------------------------------
import os
from typing import List, Union, Optional, Dict, Any
from pydantic import BaseModel, validator

class ServerSettings(BaseModel):
    """服务器配置模型"""
    host: str
    port: int
    workers: int
    log_level: str
    debug: bool
    reload: bool

class CorsSettings(BaseModel):
    """CORS配置模型"""
    allowed_origins: List[str]
    allowed_methods: List[str]
    allowed_headers: List[str]

class SecuritySettings(BaseModel):
    """安全配置模型"""
    api_key_header: str
    api_key: str
    jwt_secret: str
    token_expire_minutes: int

class RateLimitSettings(BaseModel):
    """速率限制配置模型"""
    enabled: bool
    max_requests: int
    time_window_seconds: int

class Settings(BaseModel):
    """全局配置设置模型"""
    server: ServerSettings
    cors: CorsSettings
    security: SecuritySettings
    rate_limiting: RateLimitSettings

    @classmethod
    def from_yaml(cls, config_dict: Dict[str, Any]) -> "Settings":
        """从YAML配置字典创建设置对象"""
        return cls(
            server=ServerSettings(**config_dict.get("server", {})),
            cors=CorsSettings(**config_dict.get("cors", {})),
            security=SecuritySettings(**config_dict.get("security", {})),
            rate_limiting=RateLimitSettings(**config_dict.get("rate_limiting", {}))
        )

    class Config:
        env_file = ".env"
--------------------------------------------------------------------------------
/backend/api/core/security.py:
--------------------------------------------------------------------------------
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
from passlib.context import CryptContext
from ..models.user import User, TokenData

# 密码处理上下文
10 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") 11 | 12 | # OAuth2配置 13 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") 14 | 15 | # API密钥配置 16 | api_key_scheme = APIKeyHeader(name="X-API-Key") 17 | 18 | # 从配置加载设置 19 | try: 20 | import yaml 21 | with open("configs/api.yaml", "r") as f: 22 | config = yaml.safe_load(f) 23 | SECRET_KEY = config["security"]["jwt_secret"] 24 | ALGORITHM = "HS256" 25 | ACCESS_TOKEN_EXPIRE_MINUTES = config["security"]["token_expire_minutes"] 26 | API_KEY = config["security"]["api_key"] 27 | except: 28 | # 默认值,实际生产中应该从环境变量或安全存储加载 29 | SECRET_KEY = "your-jwt-secret-here" 30 | ALGORITHM = "HS256" 31 | ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24小时 32 | API_KEY = "your-api-key-here" 33 | 34 | def verify_password(plain_password: str, hashed_password: str) -> bool: 35 | """验证密码""" 36 | return pwd_context.verify(plain_password, hashed_password) 37 | 38 | def get_password_hash(password: str) -> str: 39 | """获取密码哈希""" 40 | return pwd_context.hash(password) 41 | 42 | def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: 43 | """创建访问令牌""" 44 | to_encode = data.copy() 45 | 46 | # 设置过期时间 47 | if expires_delta: 48 | expire = datetime.utcnow() + expires_delta 49 | else: 50 | expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) 51 | 52 | to_encode.update({"exp": expire}) 53 | 54 | # 编码JWT 55 | encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) 56 | return encoded_jwt 57 | 58 | def decode_token(token: str) -> TokenData: 59 | """解码令牌""" 60 | try: 61 | # 解码JWT 62 | payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) 63 | username = payload.get("sub") 64 | exp = payload.get("exp") 65 | 66 | if username is None: 67 | raise HTTPException( 68 | status_code=status.HTTP_401_UNAUTHORIZED, 69 | detail="无效的认证凭证", 70 | headers={"WWW-Authenticate": "Bearer"}, 71 | ) 72 | 73 | return TokenData(username=username, exp=exp) 74 | except jwt.PyJWTError: 75 | raise HTTPException( 76 | status_code=status.HTTP_401_UNAUTHORIZED, 77 | detail="无效的认证凭证", 78 | headers={"WWW-Authenticate": "Bearer"}, 79 | ) -------------------------------------------------------------------------------- /backend/api/deps/auth.py: -------------------------------------------------------------------------------- 1 | from fastapi import Depends, HTTPException, status, WebSocket 2 | from fastapi.security import OAuth2PasswordBearer, APIKeyHeader 3 | from typing import Optional 4 | from ..core.security import decode_token, API_KEY 5 | from ..models.user import User, TokenData 6 | 7 | # OAuth2配置 8 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") 9 | 10 | # API密钥配置 11 | api_key_scheme = APIKeyHeader(name="X-API-Key") 12 | 13 | # 模拟用户数据库(实际应用中应从数据库获取) 14 | fake_users_db = { 15 | "admin": { 16 | "id": "user_001", 17 | "username": "admin", 18 | "email": "admin@example.com", 19 | "full_name": "Admin User", 20 | "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", # "password" 21 | "disabled": False, 22 | "role": "admin" 23 | }, 24 | "user": { 25 | "id": "user_002", 26 | "username": "user", 27 | "email": "user@example.com", 28 | "full_name": "Normal User", 29 | "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", # "password" 30 | "disabled": False, 31 | "role": "user" 32 | } 33 | } 34 | 35 | async def get_user_by_token(token_data: TokenData) -> User: 36 | """通过令牌数据获取用户""" 37 | if token_data.username not in fake_users_db: 38 | raise HTTPException( 39 
| status_code=status.HTTP_404_NOT_FOUND, 40 | detail="用户不存在" 41 | ) 42 | 43 | user_data = fake_users_db[token_data.username] 44 | return User(**user_data) 45 | 46 | async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: 47 | """获取当前用户""" 48 | # 尝试解码令牌 49 | token_data = decode_token(token) 50 | 51 | # 获取用户 52 | user = await get_user_by_token(token_data) 53 | 54 | if user.disabled: 55 | raise HTTPException( 56 | status_code=status.HTTP_403_FORBIDDEN, 57 | detail="用户已禁用" 58 | ) 59 | 60 | return user 61 | 62 | async def validate_api_key(api_key: str = Depends(api_key_scheme)) -> bool: 63 | """验证API密钥""" 64 | if api_key != API_KEY: 65 | raise HTTPException( 66 | status_code=status.HTTP_401_UNAUTHORIZED, 67 | detail="无效的API密钥" 68 | ) 69 | return True 70 | 71 | async def get_token_from_websocket(websocket: WebSocket) -> Optional[str]: 72 | """从WebSocket连接获取令牌""" 73 | # 尝试从查询参数获取令牌 74 | token = websocket.query_params.get("token") 75 | 76 | # 如果查询参数中没有令牌,尝试从头部获取 77 | if not token: 78 | # 从Authorization头部获取 79 | auth_header = websocket.headers.get("authorization") 80 | if auth_header and auth_header.startswith("Bearer "): 81 | token = auth_header.replace("Bearer ", "") 82 | 83 | return token -------------------------------------------------------------------------------- /backend/api/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from fastapi import FastAPI, Depends, HTTPException 4 | from fastapi.middleware.cors import CORSMiddleware 5 | import yaml 6 | from .routers import documents, search, chat, analysis 7 | from .core.config import Settings 8 | 9 | # 加载配置 10 | config_path = os.getenv("API_CONFIG_PATH", "configs/api.yaml") 11 | with open(config_path, "r") as f: 12 | config = yaml.safe_load(f) 13 | 14 | settings = Settings(**config) 15 | 16 | # 设置日志 17 | logging.basicConfig( 18 | level=getattr(logging, settings.server.log_level.upper()), 19 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 20 | ) 21 | logger = logging.getLogger(__name__) 22 | 23 | # 初始化FastAPI应用 24 | app = FastAPI( 25 | title="知识库系统API", 26 | description="高性能本地知识库系统的API", 27 | version="1.0.0", 28 | ) 29 | 30 | # 设置CORS 31 | app.add_middleware( 32 | CORSMiddleware, 33 | allow_origins=settings.cors.allowed_origins, 34 | allow_credentials=True, 35 | allow_methods=settings.cors.allowed_methods, 36 | allow_headers=settings.cors.allowed_headers, 37 | ) 38 | 39 | # 包含路由器 40 | app.include_router(documents.router, prefix="/api/documents", tags=["documents"]) 41 | app.include_router(search.router, prefix="/api/search", tags=["search"]) 42 | app.include_router(chat.router, prefix="/api/chat", tags=["chat"]) 43 | app.include_router(analysis.router, prefix="/api/analysis", tags=["analysis"]) 44 | 45 | @app.get("/api/health") 46 | async def health_check(): 47 | return {"status": "healthy"} 48 | 49 | @app.get("/") 50 | async def root(): 51 | return { 52 | "message": "知识库系统API正在运行", 53 | "docs_url": "/docs", 54 | "version": app.version 55 | } 56 | 57 | if __name__ == "__main__": 58 | import uvicorn 59 | uvicorn.run( 60 | "main:app", 61 | host=settings.server.host, 62 | port=settings.server.port, 63 | workers=settings.server.workers 64 | ) 65 | -------------------------------------------------------------------------------- /backend/api/models/analysis.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Dict, Any 3 | 4 | 
class DocumentAnalysisRequest(BaseModel): 5 | """文档分析请求模型""" 6 | document_id: str 7 | focus_areas: str = "关键数据点和趋势" 8 | instructions: Optional[str] = None 9 | 10 | class DocumentAnalysisResponse(BaseModel): 11 | """文档分析响应模型""" 12 | document_id: str 13 | analysis: str 14 | summary: str 15 | 16 | class TextAnalysisRequest(BaseModel): 17 | """文本分析请求模型""" 18 | text: str 19 | focus_areas: str = "关键点和主题" 20 | instructions: Optional[str] = None 21 | 22 | class TextAnalysisResponse(BaseModel): 23 | """文本分析响应模型""" 24 | analysis: str 25 | summary: str 26 | 27 | class PredictionRequest(BaseModel): 28 | """预测请求模型""" 29 | historical_data: str 30 | target: str 31 | context: Optional[str] = None 32 | instructions: Optional[str] = None 33 | 34 | class PredictionFactor(BaseModel): 35 | """预测影响因素模型""" 36 | name: str 37 | impact: str # 'positive', 'negative', 'neutral' 38 | weight: float # 0.0 to 1.0 39 | 40 | class PredictionResponse(BaseModel): 41 | """预测响应模型""" 42 | prediction: str 43 | confidence: str # 'high', 'medium', 'low' 44 | factors: List[PredictionFactor] = [] -------------------------------------------------------------------------------- /backend/api/models/chat.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Dict, Any 3 | 4 | class ChatRequest(BaseModel): 5 | """聊天请求模型""" 6 | message: str 7 | system_prompt: Optional[str] = None 8 | use_rag: Optional[bool] = True 9 | user_specific: Optional[bool] = True 10 | document_ids: Optional[List[str]] = None 11 | max_context_chunks: int = 5 12 | 13 | class ChatSource(BaseModel): 14 | """聊天上下文来源模型""" 15 | id: str 16 | text: str 17 | metadata: Dict[str, Any] 18 | 19 | class ChatResponse(BaseModel): 20 | """聊天响应模型""" 21 | message_id: str 22 | answer: str 23 | sources: Optional[List[ChatSource]] = [] 24 | 25 | class ChatHistoryItem(BaseModel): 26 | """聊天历史记录项模型""" 27 | role: str # 'user' 或 'assistant' 28 | content: str 29 | message_id: Optional[str] = None 30 | timestamp: Optional[float] = None 31 | 32 | class ChatSessionResponse(BaseModel): 33 | """聊天会话响应模型""" 34 | session_id: str 35 | history: List[ChatHistoryItem] 36 | created_at: float 37 | updated_at: float -------------------------------------------------------------------------------- /backend/api/models/document.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Dict, Any 3 | from datetime import datetime 4 | from enum import Enum 5 | 6 | class DocumentStatus(str, Enum): 7 | PROCESSING = "processing" 8 | CHUNKING = "chunking" 9 | EMBEDDING = "embedding" 10 | INDEXING = "indexing" 11 | INDEXED = "indexed" 12 | ERROR = "error" 13 | 14 | class DocumentMetadata(BaseModel): 15 | """文档元数据模型""" 16 | id: str 17 | filename: str 18 | file_extension: str 19 | mime_type: str 20 | file_size: int 21 | file_hash: str 22 | upload_date: datetime 23 | user_id: str 24 | status: DocumentStatus 25 | chunks_count: Optional[int] = 0 26 | text_length: Optional[int] = 0 27 | processing_error: Optional[str] = None 28 | extracted_metadata: Optional[Dict[str, Any]] = None 29 | custom_metadata: Optional[Dict[str, Any]] = None 30 | 31 | class DocumentResponse(BaseModel): 32 | """文档响应模型""" 33 | document: DocumentMetadata 34 | message: Optional[str] = None 35 | 36 | class DocumentListResponse(BaseModel): 37 | """文档列表响应模型""" 38 | documents: List[DocumentMetadata] 39 | total: int 40 | limit: int 41 | 
offset: int 42 | 43 | class DocumentStatusResponse(BaseModel): 44 | """任务状态响应模型""" 45 | task_id: str 46 | status: Dict[str, Any] -------------------------------------------------------------------------------- /backend/api/models/search.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Dict, Any 3 | 4 | class SearchRequest(BaseModel): 5 | """搜索请求模型""" 6 | query: str 7 | user_id: Optional[str] = None 8 | limit: int = 10 9 | offset: int = 0 10 | use_hybrid: bool = True 11 | filters: Optional[Dict[str, Any]] = None 12 | 13 | class SearchResult(BaseModel): 14 | """搜索结果项模型""" 15 | id: str 16 | score: float 17 | text: str 18 | metadata: Dict[str, Any] 19 | 20 | class SearchResponse(BaseModel): 21 | """搜索响应模型""" 22 | results: List[SearchResult] 23 | count: int 24 | 25 | class RelatedQueryResponse(BaseModel): 26 | """相关查询响应模型""" 27 | queries: List[str] 28 | count: int -------------------------------------------------------------------------------- /backend/api/models/user.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field, EmailStr 2 | from typing import Optional, List, Dict, Any 3 | from datetime import datetime 4 | 5 | class User(BaseModel): 6 | """用户数据模型""" 7 | id: str 8 | username: str 9 | email: EmailStr 10 | full_name: Optional[str] = None 11 | disabled: Optional[bool] = False 12 | role: str = "user" 13 | created_at: datetime = Field(default_factory=datetime.now) 14 | last_login: Optional[datetime] = None 15 | 16 | class UserCreate(BaseModel): 17 | """创建用户请求模型""" 18 | username: str 19 | email: EmailStr 20 | password: str 21 | full_name: Optional[str] = None 22 | 23 | class UserUpdate(BaseModel): 24 | """更新用户请求模型""" 25 | email: Optional[EmailStr] = None 26 | full_name: Optional[str] = None 27 | password: Optional[str] = None 28 | disabled: Optional[bool] = None 29 | role: Optional[str] = None 30 | 31 | class UserInDB(User): 32 | """数据库存储的用户模型,包含哈希密码""" 33 | hashed_password: str 34 | 35 | class Token(BaseModel): 36 | """认证令牌模型""" 37 | access_token: str 38 | token_type: str 39 | 40 | class TokenData(BaseModel): 41 | """认证令牌数据模型""" 42 | username: Optional[str] = None 43 | exp: Optional[int] = None -------------------------------------------------------------------------------- /backend/api/routers/analysis.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Depends, Query 2 | from typing import List, Optional, Dict, Any 3 | import logging 4 | from ..deps.auth import get_current_user 5 | from ..models.user import User 6 | from ..models.analysis import ( 7 | DocumentAnalysisRequest, 8 | DocumentAnalysisResponse, 9 | TextAnalysisRequest, 10 | TextAnalysisResponse, 11 | PredictionRequest, 12 | PredictionResponse 13 | ) 14 | from ...services.llm_service import LLMService 15 | from ...services.vector_store import VectorStore 16 | from ...services.search_service import SearchService 17 | 18 | router = APIRouter() 19 | logger = logging.getLogger(__name__) 20 | 21 | # 初始化服务 22 | vector_store = VectorStore() 23 | llm_service = LLMService() 24 | search_service = SearchService(vector_store, llm_service) 25 | 26 | @router.post("/document", response_model=DocumentAnalysisResponse) 27 | async def analyze_document( 28 | request: DocumentAnalysisRequest, 29 | current_user: User = Depends(get_current_user) 30 | ): 31 | """分析文档内容""" 32 | try: 
33 | # 创建文档过滤器 34 | filters = {"document_id": request.document_id} 35 | 36 | # 获取文档内容 37 | doc_chunks = search_service.search( 38 | query="", # 空查询,获取所有内容 39 | user_id=current_user.id, 40 | limit=100, # 限制返回块数 41 | filters=filters 42 | ) 43 | 44 | # 提取文档文本 45 | doc_text = " ".join([chunk["text"] for chunk in doc_chunks]) 46 | 47 | # 生成分析提示 48 | analysis_prompt = f"""请分析以下文档内容,重点关注{request.focus_areas}。请给出关键点、主要观点和结论。 49 | 50 | 文档内容: 51 | {doc_text[:8000]} # 防止提示过长 52 | 53 | 分析要求: {request.instructions if request.instructions else "提供全面分析"}""" 54 | 55 | # 调用LLM生成分析 56 | response = llm_service.generate(analysis_prompt) 57 | analysis = response["response"] 58 | 59 | return { 60 | "document_id": request.document_id, 61 | "analysis": analysis, 62 | "summary": analysis[:200] + "..." if len(analysis) > 200 else analysis 63 | } 64 | 65 | except Exception as e: 66 | logger.error(f"文档分析错误: {str(e)}") 67 | raise HTTPException(status_code=500, detail="文档分析失败") 68 | 69 | @router.post("/text", response_model=TextAnalysisResponse) 70 | async def analyze_text( 71 | request: TextAnalysisRequest, 72 | current_user: User = Depends(get_current_user) 73 | ): 74 | """分析任意文本内容""" 75 | try: 76 | # 生成分析提示 77 | analysis_prompt = f"""请分析以下文本内容,重点关注{request.focus_areas}。请给出关键点、主要观点和结论。 78 | 79 | 文本内容: 80 | {request.text[:8000]} # 防止提示过长 81 | 82 | 分析要求: {request.instructions if request.instructions else "提供全面分析"}""" 83 | 84 | # 调用LLM生成分析 85 | response = llm_service.generate(analysis_prompt) 86 | analysis = response["response"] 87 | 88 | return { 89 | "analysis": analysis, 90 | "summary": analysis[:200] + "..." if len(analysis) > 200 else analysis 91 | } 92 | 93 | except Exception as e: 94 | logger.error(f"文本分析错误: {str(e)}") 95 | raise HTTPException(status_code=500, detail="文本分析失败") 96 | 97 | @router.post("/predict", response_model=PredictionResponse) 98 | async def predict( 99 | request: PredictionRequest, 100 | current_user: User = Depends(get_current_user) 101 | ): 102 | """基于历史数据进行预测""" 103 | try: 104 | # 准备预测提示 105 | prediction_prompt = f"""根据以下历史数据和上下文,请预测{request.target}。 106 | 107 | 历史数据: 108 | {request.historical_data[:4000]} # 防止提示过长 109 | 110 | 上下文信息: 111 | {request.context[:4000] if request.context else "无额外上下文信息"} 112 | 113 | 预测目标: {request.target} 114 | 预测说明: {request.instructions if request.instructions else "提供合理预测"}""" 115 | 116 | # 调用LLM生成预测 117 | response = llm_service.generate(prediction_prompt) 118 | prediction_result = response["response"] 119 | 120 | return { 121 | "prediction": prediction_result, 122 | "confidence": "中等", # 假设的置信度,实际上需要更复杂的计算 123 | "factors": [] # 影响因素,可以通过二次分析提取 124 | } 125 | 126 | except Exception as e: 127 | logger.error(f"预测错误: {str(e)}") 128 | raise HTTPException(status_code=500, detail="生成预测失败") 129 | -------------------------------------------------------------------------------- /backend/api/routers/chat.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Depends, WebSocket, WebSocketDisconnect 2 | from fastapi.responses import StreamingResponse 3 | from typing import List, Optional, Dict, Any 4 | import logging 5 | import json 6 | import uuid 7 | import asyncio 8 | from ..deps.auth import get_current_user, get_token_from_websocket 9 | from ..models.user import User 10 | from ..models.chat import ChatRequest, ChatResponse 11 | from ...services.search_service import SearchService 12 | from ...services.llm_service import LLMService 13 | from ...services.vector_store import VectorStore 14 | 
from ...services.graphrag_service import GraphRAGService 15 | 16 | router = APIRouter() 17 | logger = logging.getLogger(__name__) 18 | 19 | # 初始化服务 20 | vector_store = VectorStore() 21 | llm_service = LLMService() 22 | search_service = SearchService(vector_store, llm_service) 23 | graphrag_service = GraphRAGService(vector_store, llm_service) 24 | 25 | # 存储活跃的WebSocket连接 26 | active_connections = {} 27 | 28 | @router.post("/", response_model=ChatResponse) 29 | async def chat_completion( 30 | request: ChatRequest, 31 | current_user: User = Depends(get_current_user) 32 | ): 33 | """标准聊天请求""" 34 | try: 35 | # 确定是否需要RAG 36 | use_rag = request.use_rag if request.use_rag is not None else True 37 | 38 | if use_rag: 39 | # 使用GraphRAG提供上下文 40 | context = graphrag_service.get_context_for_query( 41 | query=request.message, 42 | user_id=current_user.id if request.user_specific else None, 43 | document_ids=request.document_ids, 44 | max_results=request.max_context_chunks 45 | ) 46 | 47 | # 生成带上下文的提示 48 | augmented_prompt = f"""请根据以下上下文和相关信息回答问题。如果上下文信息不足以回答问题,请明确指出。 49 | 50 | 上下文信息: 51 | {context} 52 | 53 | 用户问题: {request.message}""" 54 | 55 | # 调用LLM生成回复 56 | response = llm_service.generate(augmented_prompt, system_prompt=request.system_prompt) 57 | answer = response["response"] 58 | else: 59 | # 直接调用LLM 60 | response = llm_service.generate(request.message, system_prompt=request.system_prompt) 61 | answer = response["response"] 62 | 63 | return { 64 | "message_id": str(uuid.uuid4()), 65 | "answer": answer, 66 | "sources": [{"id": s["id"], "text": s["text"], "metadata": s["metadata"]} for s in context] if use_rag else [] 67 | } 68 | 69 | except Exception as e: 70 | logger.error(f"聊天处理错误: {str(e)}") 71 | raise HTTPException(status_code=500, detail="处理聊天请求失败") 72 | 73 | @router.post("/stream", response_class=StreamingResponse) 74 | async def stream_chat_completion( 75 | request: ChatRequest, 76 | current_user: User = Depends(get_current_user) 77 | ): 78 | """流式聊天响应""" 79 | try: 80 | # 确定是否需要RAG 81 | use_rag = request.use_rag if request.use_rag is not None else True 82 | context = [] 83 | 84 | if use_rag: 85 | # 使用GraphRAG提供上下文 86 | context = graphrag_service.get_context_for_query( 87 | query=request.message, 88 | user_id=current_user.id if request.user_specific else None, 89 | document_ids=request.document_ids, 90 | max_results=request.max_context_chunks 91 | ) 92 | 93 | # 生成带上下文的提示 94 | augmented_prompt = f"""请根据以下上下文和相关信息回答问题。如果上下文信息不足以回答问题,请明确指出。 95 | 96 | 上下文信息: 97 | {context} 98 | 99 | 用户问题: {request.message}""" 100 | prompt = augmented_prompt 101 | else: 102 | prompt = request.message 103 | 104 | # 创建一个流式响应生成器 105 | async def generate(): 106 | # 发送元数据和来源信息 107 | message_id = str(uuid.uuid4()) 108 | metadata = { 109 | "message_id": message_id, 110 | "sources": [{"id": s["id"], "text": s["text"], "metadata": s["metadata"]} for s in context] if use_rag else [] 111 | } 112 | yield f"data: {json.dumps({'type': 'metadata', 'content': metadata})}\n\n" 113 | 114 | # 流式生成内容 115 | for chunk in llm_service.generate_stream(prompt, system_prompt=request.system_prompt): 116 | yield f"data: {json.dumps({'type': 'content', 'content': chunk})}\n\n" 117 | 118 | # 发送完成信号 119 | yield f"data: {json.dumps({'type': 'done'})}\n\n" 120 | 121 | return StreamingResponse( 122 | generate(), 123 | media_type="text/event-stream" 124 | ) 125 | 126 | except Exception as e: 127 | logger.error(f"流式聊天处理错误: {str(e)}") 128 | raise HTTPException(status_code=500, detail="处理聊天请求失败") 129 | 
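The `/api/chat/stream` endpoint above returns Server-Sent Events: each event is a `data:` line carrying a JSON object whose `type` is `metadata` (message id and RAG sources), `content` (an incremental piece of the answer), or `done`. Below is a minimal client sketch for that event format; it assumes the API is reachable at http://localhost:8000, that a valid bearer token is at hand, and that the `requests` package is installed. None of these are defined in this file.

```python
import json
import requests

def stream_chat(message: str, token: str, base_url: str = "http://localhost:8000"):
    """Consume the SSE stream produced by POST /api/chat/stream."""
    resp = requests.post(
        f"{base_url}/api/chat/stream",
        headers={"Authorization": f"Bearer {token}"},
        json={"message": message, "use_rag": True},
        stream=True,  # keep the connection open and iterate over event lines
    )
    resp.raise_for_status()

    sources = []
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip the blank separator lines between events
        event = json.loads(line[len("data: "):])
        if event["type"] == "metadata":
            sources = event["content"]["sources"]  # context chunks used for the answer
        elif event["type"] == "content":
            print(event["content"], end="", flush=True)  # incremental answer text
        elif event["type"] == "done":
            break
    print()
    return sources
```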
-------------------------------------------------------------------------------- /backend/api/routers/documents.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends, Query 2 | from fastapi.responses import JSONResponse 3 | from typing import List, Optional, Dict, Any 4 | import json 5 | import logging 6 | import tempfile 7 | import os 8 | import uuid 9 | from ..deps.auth import get_current_user 10 | from ..models.user import User 11 | from ..models.document import DocumentResponse, DocumentListResponse, DocumentStatusResponse 12 | from ...services.document_service import DocumentService 13 | from ...services.vector_store import VectorStore 14 | 15 | router = APIRouter() 16 | logger = logging.getLogger(__name__) 17 | 18 | # 初始化服务 19 | vector_store = VectorStore() 20 | document_service = DocumentService(vector_store) 21 | 22 | @router.post("/upload", response_model=DocumentResponse) 23 | async def upload_document( 24 | file: UploadFile = File(...), 25 | metadata: Optional[str] = Form(None), 26 | current_user: User = Depends(get_current_user) 27 | ): 28 | """上传文档进行处理和索引""" 29 | try: 30 | # 解析元数据 31 | metadata_dict = json.loads(metadata) if metadata else {} 32 | 33 | # 创建临时文件 34 | with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file: 35 | # 写入上传的文件内容 36 | content = await file.read() 37 | temp_file.write(content) 38 | temp_file_path = temp_file.name 39 | 40 | try: 41 | # 创建任务ID 42 | task_id = str(uuid.uuid4()) 43 | 44 | # 提交处理任务 45 | doc_metadata = document_service.process_document( 46 | file=open(temp_file_path, 'rb'), 47 | filename=file.filename, 48 | user_id=current_user.id, 49 | metadata=metadata_dict 50 | ) 51 | 52 | return JSONResponse( 53 | status_code=202, # Accepted 54 | content={"message": "文档上传已接受处理", "document": doc_metadata, "task_id": task_id} 55 | ) 56 | finally: 57 | # 删除临时文件 58 | os.unlink(temp_file_path) 59 | 60 | except ValueError as e: 61 | raise HTTPException(status_code=400, detail=str(e)) 62 | except Exception as e: 63 | logger.error(f"上传文档错误: {str(e)}") 64 | raise HTTPException(status_code=500, detail=f"文档上传失败: {str(e)}") 65 | 66 | @router.get("/{document_id}", response_model=DocumentResponse) 67 | async def get_document( 68 | document_id: str, 69 | current_user: User = Depends(get_current_user) 70 | ): 71 | """获取文档状态和元数据""" 72 | try: 73 | doc_metadata = document_service.get_document_metadata(document_id) 74 | 75 | # 检查授权 76 | if doc_metadata.get("user_id") != current_user.id: 77 | raise HTTPException(status_code=403, detail="无权访问此文档") 78 | 79 | return {"document": doc_metadata} 80 | 81 | except ValueError as e: 82 | raise HTTPException(status_code=404, detail=str(e)) 83 | except Exception as e: 84 | logger.error(f"获取文档 {document_id} 错误: {str(e)}") 85 | raise HTTPException(status_code=500, detail=f"获取文档失败: {str(e)}") 86 | 87 | @router.get("/", response_model=DocumentListResponse) 88 | async def list_documents( 89 | limit: int = Query(100, ge=1, le=1000), 90 | offset: int = Query(0, ge=0), 91 | category: Optional[str] = None, 92 | status: Optional[str] = None, 93 | current_user: User = Depends(get_current_user) 94 | ): 95 | """列出当前用户的所有文档""" 96 | try: 97 | # 准备过滤器 98 | filters = {"user_id": current_user.id} 99 | if category: 100 | filters["category"] = category 101 | if status: 102 | filters["status"] = status 103 | 104 | # 获取文档列表 105 | documents, total_count = document_service.get_all_documents( 106 | 
filters=filters, 107 | limit=limit, 108 | offset=offset 109 | ) 110 | 111 | return {"documents": documents, "total": total_count, "limit": limit, "offset": offset} 112 | 113 | except Exception as e: 114 | logger.error(f"列出文档错误: {str(e)}") 115 | raise HTTPException(status_code=500, detail=f"获取文档列表失败: {str(e)}") 116 | 117 | @router.delete("/{document_id}") 118 | async def delete_document( 119 | document_id: str, 120 | current_user: User = Depends(get_current_user) 121 | ): 122 | """删除文档及其索引""" 123 | try: 124 | # 检查文档所有权 125 | doc_metadata = document_service.get_document_metadata(document_id) 126 | if doc_metadata.get("user_id") != current_user.id: 127 | raise HTTPException(status_code=403, detail="无权删除此文档") 128 | 129 | # 删除文档 130 | document_service.delete_document(document_id) 131 | 132 | return {"message": f"文档 {document_id} 已成功删除"} 133 | 134 | except ValueError as e: 135 | raise HTTPException(status_code=404, detail=str(e)) 136 | except Exception as e: 137 | logger.error(f"删除文档 {document_id} 错误: {str(e)}") 138 | raise HTTPException(status_code=500, detail=f"删除文档失败: {str(e)}") 139 | 140 | @router.post("/{document_id}/reindex") 141 | async def reindex_document( 142 | document_id: str, 143 | current_user: User = Depends(get_current_user) 144 | ): 145 | """重新索引文档""" 146 | try: 147 | # 检查文档所有权 148 | doc_metadata = document_service.get_document_metadata(document_id) 149 | if doc_metadata.get("user_id") != current_user.id: 150 | raise HTTPException(status_code=403, detail="无权重新索引此文档") 151 | 152 | # 创建重新索引任务 153 | task_id = document_service.reindex_document(document_id) 154 | 155 | return {"message": f"文档 {document_id} 重新索引任务已创建", "task_id": task_id} 156 | 157 | except ValueError as e: 158 | raise HTTPException(status_code=404, detail=str(e)) 159 | except Exception as e: 160 | logger.error(f"重新索引文档 {document_id} 错误: {str(e)}") 161 | raise HTTPException(status_code=500, detail=f"重新索引文档失败: {str(e)}") 162 | 163 | @router.get("/task/{task_id}", response_model=DocumentStatusResponse) 164 | async def get_task_status( 165 | task_id: str, 166 | current_user: User = Depends(get_current_user) 167 | ): 168 | """获取文档处理任务状态""" 169 | try: 170 | task_status = document_service.get_task_status(task_id) 171 | 172 | # 检查任务所有权 173 | if task_status.get("user_id") and task_status.get("user_id") != current_user.id: 174 | raise HTTPException(status_code=403, detail="无权查看此任务状态") 175 | 176 | return {"task_id": task_id, "status": task_status} 177 | 178 | except ValueError as e: 179 | raise HTTPException(status_code=404, detail=str(e)) 180 | except Exception as e: 181 | logger.error(f"获取任务 {task_id} 状态错误: {str(e)}") 182 | raise HTTPException(status_code=500, detail=f"获取任务状态失败: {str(e)}") 183 | -------------------------------------------------------------------------------- /backend/api/routers/search.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Depends, Query 2 | from typing import List, Optional, Dict, Any 3 | import logging 4 | from ..deps.auth import get_current_user 5 | from ..models.user import User 6 | from ..models.search import SearchRequest, SearchResponse 7 | from ...services.search_service import SearchService 8 | from ...services.llm_service import LLMService 9 | from ...services.vector_store import VectorStore 10 | 11 | router = APIRouter() 12 | logger = logging.getLogger(__name__) 13 | 14 | # 初始化服务 15 | vector_store = VectorStore() 16 | llm_service = LLMService() 17 | search_service = SearchService(vector_store, 
llm_service) 18 | 19 | @router.post("/", response_model=SearchResponse) 20 | async def search_documents( 21 | request: SearchRequest, 22 | current_user: User = Depends(get_current_user) 23 | ): 24 | """搜索文档库""" 25 | try: 26 | results = search_service.search( 27 | query=request.query, 28 | user_id=current_user.id, 29 | limit=request.limit, 30 | use_hybrid=request.use_hybrid, 31 | filters=request.filters 32 | ) 33 | 34 | return {"results": results, "count": len(results)} 35 | 36 | except Exception as e: 37 | logger.error(f"搜索错误: {str(e)}") 38 | raise HTTPException(status_code=500, detail="搜索执行失败") 39 | 40 | @router.get("/documents/{document_id}", response_model=SearchResponse) 41 | async def search_within_document( 42 | document_id: str, 43 | query: str = Query(..., min_length=1), 44 | limit: int = Query(10, ge=1, le=100), 45 | current_user: User = Depends(get_current_user) 46 | ): 47 | """在特定文档内搜索""" 48 | try: 49 | # 创建文档过滤器 50 | filters = {"document_id": document_id} 51 | 52 | results = search_service.search( 53 | query=query, 54 | user_id=current_user.id, 55 | limit=limit, 56 | filters=filters 57 | ) 58 | 59 | return {"results": results, "count": len(results)} 60 | 61 | except Exception as e: 62 | logger.error(f"在文档内搜索错误: {str(e)}") 63 | raise HTTPException(status_code=500, detail="搜索执行失败") 64 | -------------------------------------------------------------------------------- /backend/embeddings/batch_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import List, Optional, Dict, Any 5 | import numpy as np 6 | from concurrent.futures import ThreadPoolExecutor 7 | from .model import get_embeddings 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class BatchProcessor: 12 | """批量处理嵌入向量的工具类""" 13 | 14 | def __init__(self, config_path: str = "configs/worker.yaml"): 15 | # 加载配置 16 | with open(config_path, "r") as f: 17 | self.config = yaml.safe_load(f) 18 | 19 | # 获取批处理设置 20 | batch_settings = self.config["embedding"] 21 | self.batch_size = batch_settings["batch_size"] 22 | self.max_workers = batch_settings["max_workers"] 23 | 24 | logger.info(f"嵌入批处理器初始化完成,批大小: {self.batch_size}, 最大工作线程: {self.max_workers}") 25 | 26 | def process_in_batches(self, texts: List[str]) -> List[List[float]]: 27 | """ 28 | 批量处理文本嵌入 29 | 30 | Args: 31 | texts: 要处理的文本列表 32 | 33 | Returns: 34 | List[List[float]]: 嵌入向量列表 35 | """ 36 | all_embeddings = [] 37 | 38 | # 根据批大小拆分文本 39 | for i in range(0, len(texts), self.batch_size): 40 | batch = texts[i:i+self.batch_size] 41 | logger.info(f"处理批次 {i//self.batch_size + 1}/{(len(texts) + self.batch_size - 1)//self.batch_size}, 大小: {len(batch)}") 42 | 43 | # 获取当前批次的嵌入 44 | batch_embeddings = get_embeddings(batch) 45 | all_embeddings.extend(batch_embeddings) 46 | 47 | return all_embeddings 48 | 49 | def process_in_parallel(self, texts: List[str]) -> List[List[float]]: 50 | """ 51 | 并行批量处理文本嵌入 52 | 53 | Args: 54 | texts: 要处理的文本列表 55 | 56 | Returns: 57 | List[List[float]]: 嵌入向量列表 58 | """ 59 | # 根据批大小拆分文本 60 | batches = [texts[i:i+self.batch_size] for i in range(0, len(texts), self.batch_size)] 61 | logger.info(f"拆分为 {len(batches)} 个批次进行并行处理") 62 | 63 | # 并行处理 64 | all_embeddings = [] 65 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 66 | # 提交所有批次任务 67 | futures = [executor.submit(get_embeddings, batch) for batch in batches] 68 | 69 | # 收集结果 70 | for future in futures: 71 | result = future.result() 72 | all_embeddings.extend(result) 73 | 74 | # 
确保结果顺序与输入一致 75 | return all_embeddings 76 | -------------------------------------------------------------------------------- /backend/embeddings/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import List, Optional, Dict, Any 5 | import numpy as np 6 | from ..services.llm_service import LLMService 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | # 全局LLM服务实例 11 | _llm_service = None 12 | 13 | def get_llm_service() -> LLMService: 14 | """获取LLM服务单例""" 15 | global _llm_service 16 | if _llm_service is None: 17 | config_path = os.getenv("OLLAMA_CONFIG_PATH", "configs/ollama.yaml") 18 | _llm_service = LLMService(config_path) 19 | return _llm_service 20 | 21 | def get_embeddings(texts: List[str]) -> List[List[float]]: 22 | """ 23 | 获取文本列表的向量嵌入 24 | 25 | Args: 26 | texts: 文本列表 27 | 28 | Returns: 29 | List[List[float]]: 嵌入向量列表 30 | """ 31 | try: 32 | llm_service = get_llm_service() 33 | embeddings = [] 34 | 35 | for text in texts: 36 | # 截断长文本 37 | max_length = 8192 # 大多数嵌入模型的最大输入长度 38 | if len(text) > max_length: 39 | text = text[:max_length] 40 | 41 | # 获取嵌入 42 | embedding = llm_service.get_embedding(text) 43 | embeddings.append(embedding) 44 | 45 | return embeddings 46 | 47 | except Exception as e: 48 | logger.error(f"获取嵌入向量失败: {str(e)}") 49 | # 返回零向量作为回退方案 50 | return [[0.0] * 768] * len(texts) # 假设维度为768 51 | -------------------------------------------------------------------------------- /backend/processors/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from abc import ABC, abstractmethod 4 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO, Type 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class DocumentProcessor(ABC): 9 | """文档处理器基类,定义处理接口""" 10 | 11 | @abstractmethod 12 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 13 | """ 14 | 处理文档文件 15 | 16 | Args: 17 | file: 文件对象 18 | 19 | Returns: 20 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 21 | """ 22 | pass 23 | 24 | # 存储注册的处理器 25 | _PROCESSORS: Dict[str, Type[DocumentProcessor]] = {} 26 | 27 | def register_processor(extensions: List[str]): 28 | """文档处理器注册装饰器""" 29 | def decorator(processor_class: Type[DocumentProcessor]): 30 | for ext in extensions: 31 | _PROCESSORS[ext.lower()] = processor_class 32 | return processor_class 33 | return decorator 34 | 35 | def get_document_processor(extension: str) -> DocumentProcessor: 36 | """ 37 | 根据文件扩展名获取适当的处理器 38 | 39 | Args: 40 | extension: 文件扩展名(包含点,如'.pdf') 41 | 42 | Returns: 43 | DocumentProcessor: 文档处理器实例 44 | 45 | Raises: 46 | ValueError: 如果找不到适用于该扩展名的处理器 47 | """ 48 | ext = extension.lower() 49 | if ext not in _PROCESSORS: 50 | raise ValueError(f"未找到支持的处理器: {ext}") 51 | 52 | # 实例化处理器 53 | return _PROCESSORS[ext]() 54 | -------------------------------------------------------------------------------- /backend/processors/docx_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO 4 | import tempfile 5 | import docx 6 | import re 7 | from .base import DocumentProcessor, register_processor 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | @register_processor(extensions=['.docx', '.doc']) 12 | class DocxProcessor(DocumentProcessor): 13 | """Word文档处理器""" 14 | 15 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 16 | """ 17 
| 处理Word文件,提取文本和元数据 18 | 19 | Args: 20 | file: Word文件对象 21 | 22 | Returns: 23 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 24 | """ 25 | # 创建临时文件 26 | with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as tmp: 27 | # 写入数据 28 | tmp.write(file.read()) 29 | file.seek(0) # 重置文件指针 30 | tmp_path = tmp.name 31 | 32 | try: 33 | # 打开Word文档 34 | doc = docx.Document(tmp_path) 35 | 36 | # 提取元数据 37 | core_properties = doc.core_properties 38 | metadata = { 39 | "title": core_properties.title if hasattr(core_properties, 'title') else "", 40 | "author": core_properties.author if hasattr(core_properties, 'author') else "", 41 | "comments": core_properties.comments if hasattr(core_properties, 'comments') else "", 42 | "keywords": core_properties.keywords if hasattr(core_properties, 'keywords') else "", 43 | "subject": core_properties.subject if hasattr(core_properties, 'subject') else "", 44 | "last_modified_by": core_properties.last_modified_by if hasattr(core_properties, 'last_modified_by') else "", 45 | "created": str(core_properties.created) if hasattr(core_properties, 'created') else "", 46 | "modified": str(core_properties.modified) if hasattr(core_properties, 'modified') else "", 47 | "paragraph_count": len(doc.paragraphs), 48 | "section_count": len(doc.sections), 49 | } 50 | 51 | # 提取文本 52 | extracted_text = "" 53 | 54 | # 提取标题 55 | if doc.paragraphs and doc.paragraphs[0].style.name.startswith('Heading'): 56 | metadata["document_title"] = doc.paragraphs[0].text 57 | 58 | # 提取段落 59 | for para in doc.paragraphs: 60 | # 添加段落文本 61 | if para.text.strip(): 62 | extracted_text += para.text + "\n\n" 63 | 64 | # 提取表格内容 65 | table_count = 0 66 | for table in doc.tables: 67 | table_count += 1 68 | extracted_text += f"\n--- 表格 {table_count} ---\n" 69 | 70 | for i, row in enumerate(table.rows): 71 | if i == 0: 72 | # 表头 73 | extracted_text += "| " 74 | for cell in row.cells: 75 | extracted_text += cell.text + " | " 76 | extracted_text += "\n" 77 | extracted_text += "|" + "---|" * len(row.cells) + "\n" 78 | else: 79 | # 表体 80 | extracted_text += "| " 81 | for cell in row.cells: 82 | extracted_text += cell.text + " | " 83 | extracted_text += "\n" 84 | 85 | extracted_text += "\n" 86 | 87 | metadata["table_count"] = table_count 88 | 89 | # 清理文本 90 | extracted_text = self._clean_text(extracted_text) 91 | 92 | return extracted_text, metadata 93 | 94 | except Exception as e: 95 | logger.error(f"处理Word文档时出错: {str(e)}") 96 | raise Exception(f"Word文档处理失败: {str(e)}") 97 | finally: 98 | # 清理临时文件 99 | if os.path.exists(tmp_path): 100 | os.unlink(tmp_path) 101 | 102 | def _clean_text(self, text: str) -> str: 103 | """清理提取的文本""" 104 | # 删除多余的空格 105 | text = re.sub(r'\s+', ' ', text) 106 | 107 | # 删除多余的换行符 108 | text = re.sub(r'\n\s*\n', '\n\n', text) 109 | 110 | # 保持段落结构 111 | text = re.sub(r'(\w) *\n *(\w)', r'\1 \2', text) 112 | 113 | return text 114 | -------------------------------------------------------------------------------- /backend/processors/excel_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO 4 | import tempfile 5 | import pandas as pd 6 | from .base import DocumentProcessor, register_processor 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | @register_processor(extensions=['.xlsx', '.xls']) 11 | class ExcelProcessor(DocumentProcessor): 12 | """Excel文档处理器""" 13 | 14 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 15 | """ 16 | 
处理Excel文件,提取文本和元数据 17 | 18 | Args: 19 | file: Excel文件对象 20 | 21 | Returns: 22 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 23 | """ 24 | # 创建临时文件 25 | with tempfile.NamedTemporaryFile(suffix='.xlsx', delete=False) as tmp: 26 | # 写入数据 27 | tmp.write(file.read()) 28 | file.seek(0) # 重置文件指针 29 | tmp_path = tmp.name 30 | 31 | try: 32 | # 使用pandas读取Excel文件 33 | excel_file = pd.ExcelFile(tmp_path) 34 | sheet_names = excel_file.sheet_names 35 | 36 | # 提取元数据 37 | metadata = { 38 | "sheet_count": len(sheet_names), 39 | "sheet_names": sheet_names, 40 | "file_path": tmp_path, 41 | } 42 | 43 | # 提取文本 44 | extracted_text = "" 45 | 46 | # 遍历所有工作表 47 | for sheet_index, sheet_name in enumerate(sheet_names): 48 | df = pd.read_excel(excel_file, sheet_name=sheet_name) 49 | 50 | # 添加工作表标题 51 | extracted_text += f"\n--- 工作表: {sheet_name} ---\n\n" 52 | 53 | # 处理列名(表头) 54 | header_row = "| " + " | ".join(str(col) for col in df.columns) + " |\n" 55 | separator = "|" + "---|" * len(df.columns) + "\n" 56 | extracted_text += header_row + separator 57 | 58 | # 处理数据行 59 | for _, row in df.iterrows(): 60 | row_text = "| " + " | ".join(str(cell) if str(cell) != "nan" else "" for cell in row) + " |\n" 61 | extracted_text += row_text 62 | 63 | extracted_text += "\n" 64 | 65 | # 添加工作表级元数据 66 | metadata[f"sheet_{sheet_index}_rows"] = len(df) 67 | metadata[f"sheet_{sheet_index}_columns"] = len(df.columns) 68 | 69 | # 统计总行数和列数 70 | total_cells = sum(metadata.get(f"sheet_{i}_rows", 0) * metadata.get(f"sheet_{i}_columns", 0) 71 | for i in range(len(sheet_names))) 72 | metadata["total_cells"] = total_cells 73 | 74 | return extracted_text, metadata 75 | 76 | except Exception as e: 77 | logger.error(f"处理Excel文件时出错: {str(e)}") 78 | raise Exception(f"Excel处理失败: {str(e)}") 79 | finally: 80 | # 清理临时文件 81 | if os.path.exists(tmp_path): 82 | os.unlink(tmp_path) -------------------------------------------------------------------------------- /backend/processors/html_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO 4 | import tempfile 5 | from bs4 import BeautifulSoup 6 | import re 7 | from .base import DocumentProcessor, register_processor 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | @register_processor(extensions=['.html', '.htm']) 12 | class HtmlProcessor(DocumentProcessor): 13 | """HTML文档处理器""" 14 | 15 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 16 | """ 17 | 处理HTML文件,提取文本和元数据 18 | 19 | Args: 20 | file: HTML文件对象 21 | 22 | Returns: 23 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 24 | """ 25 | # 读取文件内容 26 | content = file.read().decode('utf-8', errors='replace') 27 | 28 | try: 29 | # 使用BeautifulSoup解析HTML 30 | soup = BeautifulSoup(content, 'html.parser') 31 | 32 | # 提取元数据 33 | metadata = { 34 | "title": self._get_title(soup), 35 | "description": self._get_meta_content(soup, "description"), 36 | "keywords": self._get_meta_content(soup, "keywords"), 37 | "author": self._get_meta_content(soup, "author"), 38 | "links_count": len(soup.find_all('a')), 39 | "images_count": len(soup.find_all('img')), 40 | "tables_count": len(soup.find_all('table')), 41 | "scripts_count": len(soup.find_all('script')), 42 | "styles_count": len(soup.find_all('style')), 43 | } 44 | 45 | # 提取并清理文本 46 | extracted_text = self._extract_text(soup) 47 | 48 | return extracted_text, metadata 49 | 50 | except Exception as e: 51 | logger.error(f"处理HTML文件时出错: {str(e)}") 52 | raise Exception(f"HTML处理失败: 
{str(e)}") 53 | 54 | def _get_title(self, soup: BeautifulSoup) -> str: 55 | """获取HTML标题""" 56 | title_tag = soup.find('title') 57 | return title_tag.get_text() if title_tag else "" 58 | 59 | def _get_meta_content(self, soup: BeautifulSoup, meta_name: str) -> str: 60 | """获取指定名称的meta标签内容""" 61 | meta_tag = soup.find('meta', attrs={'name': meta_name}) 62 | if meta_tag and meta_tag.get('content'): 63 | return meta_tag.get('content') 64 | return "" 65 | 66 | def _extract_text(self, soup: BeautifulSoup) -> str: 67 | """提取并清理HTML文本内容""" 68 | # 删除脚本和样式标签 69 | for script_or_style in soup(["script", "style"]): 70 | script_or_style.decompose() 71 | 72 | # 提取文本 73 | text = soup.get_text() 74 | 75 | # 处理标题 76 | for heading in soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']): 77 | heading_level = int(heading.name[1]) 78 | heading_text = heading.get_text().strip() 79 | text = text.replace(heading_text, f"\n{'#' * heading_level} {heading_text}\n") 80 | 81 | # 处理列表 82 | for ul in soup.find_all('ul'): 83 | for li in ul.find_all('li'): 84 | li_text = li.get_text().strip() 85 | text = text.replace(li_text, f"- {li_text}") 86 | 87 | # 处理有序列表 88 | for ol in soup.find_all('ol'): 89 | for i, li in enumerate(ol.find_all('li')): 90 | li_text = li.get_text().strip() 91 | text = text.replace(li_text, f"{i+1}. {li_text}") 92 | 93 | # 清理空白 94 | lines = (line.strip() for line in text.splitlines()) 95 | chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) 96 | text = '\n'.join(chunk for chunk in chunks if chunk) 97 | 98 | return text -------------------------------------------------------------------------------- /backend/processors/pdf_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, Tuple, List, Optional, BinaryIO 4 | import tempfile 5 | import fitz # PyMuPDF 6 | import re 7 | from .base import DocumentProcessor, register_processor 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | @register_processor(extensions=['.pdf']) 12 | class PDFProcessor(DocumentProcessor): 13 | """PDF文档处理器""" 14 | 15 | def process(self, file: BinaryIO) -> Tuple[str, Dict[str, Any]]: 16 | """ 17 | 处理PDF文件,提取文本和元数据 18 | 19 | Args: 20 | file: PDF文件对象 21 | 22 | Returns: 23 | Tuple[str, Dict[str, Any]]: 提取的文本和元数据 24 | """ 25 | # 创建临时文件 26 | with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as tmp: 27 | # 写入数据 28 | tmp.write(file.read()) 29 | file.seek(0) # 重置文件指针 30 | tmp_path = tmp.name 31 | 32 | try: 33 | # 打开PDF文档 34 | doc = fitz.open(tmp_path) 35 | 36 | # 提取元数据 37 | metadata = { 38 | "title": doc.metadata.get("title", ""), 39 | "author": doc.metadata.get("author", ""), 40 | "subject": doc.metadata.get("subject", ""), 41 | "keywords": doc.metadata.get("keywords", ""), 42 | "creator": doc.metadata.get("creator", ""), 43 | "producer": doc.metadata.get("producer", ""), 44 | "creation_date": doc.metadata.get("creationDate", ""), 45 | "modification_date": doc.metadata.get("modDate", ""), 46 | "page_count": len(doc), 47 | } 48 | 49 | # 提取文本 50 | extracted_text = "" 51 | for page_num, page in enumerate(doc): 52 | # 获取页面文本 53 | page_text = page.get_text() 54 | 55 | # 清理文本 56 | page_text = self._clean_text(page_text) 57 | 58 | # 添加页码标记 59 | extracted_text += f"\n--- 页 {page_num + 1} ---\n{page_text}\n" 60 | 61 | # 提取目录(TOC) 62 | toc = doc.get_toc() 63 | if toc: 64 | toc_data = [] 65 | for level, title, page in toc: 66 | toc_data.append({ 67 | "level": level, 68 | "title": title, 69 | "page": page 70 | 
}) 71 | metadata["toc"] = toc_data 72 | 73 | # 处理图像 74 | image_count = 0 75 | for page_num, page in enumerate(doc): 76 | image_list = page.get_images(full=True) 77 | image_count += len(image_list) 78 | 79 | metadata["image_count"] = image_count 80 | 81 | return extracted_text, metadata 82 | 83 | except Exception as e: 84 | logger.error(f"处理PDF时出错: {str(e)}") 85 | raise Exception(f"PDF处理失败: {str(e)}") 86 | finally: 87 | # 清理临时文件 88 | if os.path.exists(tmp_path): 89 | os.unlink(tmp_path) 90 | 91 | def _clean_text(self, text: str) -> str: 92 | """清理提取的文本""" 93 | # 删除多余的空格 94 | text = re.sub(r'\s+', ' ', text) 95 | 96 | # 删除多余的换行符 97 | text = re.sub(r'\n\s*\n', '\n\n', text) 98 | 99 | # 删除页眉页脚(假设出现在每页的前几行和后几行) 100 | lines = text.split('\n') 101 | if len(lines) > 6: 102 | # 保留中间部分,去掉可能的页眉页脚 103 | text = '\n'.join(lines) 104 | 105 | return text 106 | -------------------------------------------------------------------------------- /backend/services/document_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import Dict, List, Optional, Any, BinaryIO, Tuple 5 | import uuid 6 | import hashlib 7 | import mimetypes 8 | from datetime import datetime 9 | import redis 10 | import json 11 | from .vector_store import VectorStore 12 | from ..processors.base import get_document_processor 13 | from ..embeddings.model import get_embeddings 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | class DocumentService: 18 | def __init__(self, vector_store: VectorStore, config_path: str = "configs/worker.yaml", 19 | redis_config_path: str = "configs/redis.yaml"): 20 | # 加载配置 21 | with open(config_path, "r") as f: 22 | self.config = yaml.safe_load(f) 23 | 24 | with open(redis_config_path, "r") as f: 25 | redis_config = yaml.safe_load(f) 26 | 27 | # 初始化Redis客户端 28 | self.redis = redis.Redis( 29 | host=redis_config["redis"]["host"], 30 | port=redis_config["redis"]["port"], 31 | db=redis_config["redis"]["db"], 32 | password=redis_config["redis"]["password"], 33 | decode_responses=True 34 | ) 35 | 36 | # 存储向量存储引用 37 | self.vector_store = vector_store 38 | 39 | # 加载文档处理设置 40 | self.doc_settings = self.config["document_processing"] 41 | self.supported_formats = self.doc_settings["supported_formats"] 42 | self.chunk_size = self.doc_settings["chunk_size"] 43 | self.chunk_overlap = self.doc_settings["chunk_overlap"] 44 | 45 | logger.info(f"文档服务初始化完成,支持格式: {self.supported_formats}") 46 | 47 | def process_document(self, file: BinaryIO, filename: str, user_id: str, 48 | metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: 49 | """处理文档文件,提取文本并准备索引""" 50 | # 验证文件格式 51 | ext = os.path.splitext(filename)[1].lower() 52 | if ext not in self.supported_formats: 53 | raise ValueError(f"不支持的文件格式: {ext}. 
支持的格式: {self.supported_formats}") 54 | 55 | # 生成文档ID和文件哈希 56 | doc_id = str(uuid.uuid4()) 57 | file_content = file.read() 58 | file_hash = hashlib.sha256(file_content).hexdigest() 59 | file.seek(0) 60 | 61 | # 确定MIME类型 62 | mime_type, _ = mimetypes.guess_type(filename) 63 | if not mime_type: 64 | mime_type = "application/octet-stream" 65 | 66 | # 创建文档元数据 67 | doc_metadata = { 68 | "id": doc_id, 69 | "filename": filename, 70 | "file_extension": ext, 71 | "mime_type": mime_type, 72 | "file_size": len(file_content), 73 | "file_hash": file_hash, 74 | "upload_date": datetime.now().isoformat(), 75 | "user_id": user_id, 76 | "status": "processing", 77 | "chunks_count": 0 78 | } 79 | 80 | # 添加自定义元数据 81 | if metadata: 82 | doc_metadata.update({"custom_metadata": metadata}) 83 | 84 | # 临时存储文档元数据 85 | self.redis.set( 86 | f"doc:{doc_id}:metadata", 87 | json.dumps(doc_metadata), 88 | ex=3600 # 1小时过期 89 | ) 90 | 91 | try: 92 | # 获取适当的处理器 93 | processor = get_document_processor(ext) 94 | 95 | # 提取文本和元数据 96 | extracted_text, extracted_metadata = processor.process(file) 97 | 98 | # 更新元数据 99 | doc_metadata.update({ 100 | "extracted_metadata": extracted_metadata, 101 | "text_length": len(extracted_text), 102 | "status": "chunking" 103 | }) 104 | 105 | # 更新Redis 106 | self.redis.set( 107 | f"doc:{doc_id}:metadata", 108 | json.dumps(doc_metadata), 109 | ex=3600 110 | ) 111 | 112 | # 分块文档 113 | chunks = self._chunk_document(extracted_text, doc_id) 114 | 115 | # 更新元数据 116 | doc_metadata.update({ 117 | "chunks_count": len(chunks), 118 | "status": "embedding" 119 | }) 120 | 121 | # 更新Redis 122 | self.redis.set( 123 | f"doc:{doc_id}:metadata", 124 | json.dumps(doc_metadata), 125 | ex=3600 126 | ) 127 | 128 | # 处理文档嵌入和索引 129 | self._process_embeddings_and_index(chunks, doc_id, doc_metadata) 130 | 131 | # 更新最终状态 132 | doc_metadata["status"] = "indexed" 133 | self.redis.set( 134 | f"doc:{doc_id}:metadata", 135 | json.dumps(doc_metadata), 136 | ex=3600 137 | ) 138 | 139 | # 返回元数据 140 | return doc_metadata 141 | 142 | except Exception as e: 143 | # 更新元数据错误 144 | doc_metadata.update({ 145 | "status": "error", 146 | "processing_error": str(e) 147 | }) 148 | 149 | # 更新Redis 150 | self.redis.set( 151 | f"doc:{doc_id}:metadata", 152 | json.dumps(doc_metadata), 153 | ex=3600 154 | ) 155 | 156 | logger.error(f"处理文档 {filename} (ID: {doc_id}) 错误: {str(e)}") 157 | raise Exception(f"文档处理失败: {str(e)}") 158 | 159 | def _chunk_document(self, text: str, doc_id: str) -> List[Dict[str, Any]]: 160 | """将文档文本分成块""" 161 | chunks = [] 162 | 163 | # 基于字符的分块,带重叠 164 | start = 0 165 | chunk_id = 0 166 | 167 | while start < len(text): 168 | # 计算结束位置 169 | end = min(start + self.chunk_size, len(text)) 170 | 171 | # 如果不是最后一块,尝试找到自然断点 172 | if end < len(text): 173 | # 尝试找到句子断点(句号、问号、感叹号) 174 | for i in range(min(100, end - start)): 175 | if text[end - i - 1] in ['.', '?', '!', '\n'] and text[end - i] in [' ', '\n']: 176 | end = end - i 177 | break 178 | 179 | # 提取块文本 180 | chunk_text = text[start:end] 181 | 182 | # 创建块元数据 183 | chunk = { 184 | "chunk_id": f"{doc_id}_{chunk_id}", 185 | "document_id": doc_id, 186 | "text": chunk_text, 187 | "start_char": start, 188 | "end_char": end, 189 | "length": len(chunk_text) 190 | } 191 | 192 | chunks.append(chunk) 193 | 194 | # 移动到下一块位置,考虑重叠 195 | start = end - self.chunk_overlap 196 | chunk_id += 1 197 | 198 | # 确保我们有进展 199 | if start >= end: 200 | start = end 201 | 202 | return chunks 203 | 204 | def _process_embeddings_and_index(self, chunks: List[Dict[str, Any]], 205 | doc_id: str, doc_metadata: 
Dict[str, Any]): 206 | """处理文档块嵌入和索引""" 207 | # 提取文本列表 208 | texts = [chunk["text"] for chunk in chunks] 209 | 210 | # 批量获取嵌入 211 | batch_size = self.config["embedding"]["batch_size"] 212 | all_embeddings = [] 213 | 214 | for i in range(0, len(texts), batch_size): 215 | batch_texts = texts[i:i+batch_size] 216 | batch_embeddings = get_embeddings(batch_texts) 217 | all_embeddings.extend(batch_embeddings) 218 | 219 | # 准备向量和有效载荷 220 | vectors = all_embeddings 221 | payloads = chunks 222 | chunk_ids = [chunk["chunk_id"] for chunk in chunks] 223 | 224 | # 存储在向量存储中 225 | self.vector_store.add_documents( 226 | vectors=vectors, 227 | payloads=payloads, 228 | ids=chunk_ids 229 | ) 230 | 231 | # 存储文档元数据 232 | doc_vector = all_embeddings[0] if all_embeddings else [0.0] * 768 # 默认向量大小 233 | self.vector_store.add_documents( 234 | vectors=[doc_vector], 235 | payloads=[doc_metadata], 236 | ids=[doc_id], 237 | collection_name=self.vector_store.metadata_collection 238 | ) 239 | 240 | def get_document_metadata(self, document_id: str) -> Dict[str, Any]: 241 | """ 242 | 获取文档元数据 243 | 244 | Args: 245 | document_id: 文档ID 246 | 247 | Returns: 248 | Dict[str, Any]: 文档元数据 249 | 250 | Raises: 251 | ValueError: 如果文档不存在 252 | """ 253 | try: 254 | # 从Redis检查缓存 255 | cached_metadata = self.redis.get(f"doc:{document_id}:metadata") 256 | if cached_metadata: 257 | return json.loads(cached_metadata) 258 | 259 | # 如果缓存中没有,从向量存储中获取 260 | results = self.vector_store.query( 261 | query_vector=[0.0] * 768, # 使用空向量 262 | limit=1, 263 | filter_={"id": document_id}, 264 | collection_name=self.vector_store.metadata_collection 265 | ) 266 | 267 | if not results: 268 | raise ValueError(f"文档 {document_id} 不存在") 269 | 270 | # 返回文档元数据 271 | return results[0]["payload"] 272 | 273 | except Exception as e: 274 | logger.error(f"获取文档元数据错误: {str(e)}") 275 | raise Exception(f"获取文档元数据失败: {str(e)}") 276 | 277 | def get_all_documents(self, filters: Dict[str, Any], limit: int = 100, offset: int = 0) -> Tuple[List[Dict[str, Any]], int]: 278 | """ 279 | 获取所有文档的元数据 280 | 281 | Args: 282 | filters: 过滤条件 283 | limit: 最大返回数量 284 | offset: 偏移量 285 | 286 | Returns: 287 | Tuple[List[Dict[str, Any]], int]: 文档元数据列表和总数 288 | """ 289 | try: 290 | # 从向量存储中获取所有文档 291 | results = self.vector_store.query( 292 | query_vector=[0.0] * 768, # 使用空向量 293 | limit=limit + offset, 294 | filter_=filters, 295 | collection_name=self.vector_store.metadata_collection 296 | ) 297 | 298 | # 获取总数 299 | total_count = len(results) 300 | 301 | # 应用偏移和限制 302 | results = results[offset:offset+limit] 303 | 304 | # 提取元数据 305 | documents = [result["payload"] for result in results] 306 | 307 | return documents, total_count 308 | 309 | except Exception as e: 310 | logger.error(f"获取所有文档错误: {str(e)}") 311 | raise Exception(f"获取文档列表失败: {str(e)}") 312 | 313 | def delete_document(self, document_id: str) -> bool: 314 | """ 315 | 删除文档及其索引 316 | 317 | Args: 318 | document_id: 文档ID 319 | 320 | Returns: 321 | bool: 操作是否成功 322 | 323 | Raises: 324 | ValueError: 如果文档不存在 325 | """ 326 | try: 327 | # 检查文档是否存在 328 | doc_metadata = self.get_document_metadata(document_id) 329 | 330 | # 删除文档块 331 | self.vector_store.client.delete( 332 | collection_name=self.vector_store.default_collection, 333 | points_selector=rest.Filter( 334 | must=[ 335 | rest.FieldCondition( 336 | key="document_id", 337 | match=rest.MatchValue(value=document_id) 338 | ) 339 | ] 340 | ) 341 | ) 342 | 343 | # 删除文档元数据 344 | self.vector_store.client.delete( 345 | collection_name=self.vector_store.metadata_collection, 346 | 
points_selector=rest.Filter( 347 | must=[ 348 | rest.FieldCondition( 349 | key="id", 350 | match=rest.MatchValue(value=document_id) 351 | ) 352 | ] 353 | ) 354 | ) 355 | 356 | # 删除Redis缓存 357 | self.redis.delete(f"doc:{document_id}:metadata") 358 | 359 | # 创建一个删除文档的后台任务来清理相关资源 360 | task_id = str(uuid.uuid4()) 361 | task_data = { 362 | "type": "document_cleanup", 363 | "task_id": task_id, 364 | "document_id": document_id, 365 | "created_at": time.time() 366 | } 367 | 368 | # 将任务添加到队列 369 | self.redis.rpush("task_queue", json.dumps(task_data)) 370 | 371 | return True 372 | 373 | except ValueError as e: 374 | raise e 375 | except Exception as e: 376 | logger.error(f"删除文档错误: {str(e)}") 377 | raise Exception(f"删除文档失败: {str(e)}") 378 | 379 | def reindex_document(self, document_id: str) -> str: 380 | """ 381 | 重新索引文档 382 | 383 | Args: 384 | document_id: 文档ID 385 | 386 | Returns: 387 | str: 任务ID 388 | 389 | Raises: 390 | ValueError: 如果文档不存在 391 | """ 392 | try: 393 | # 检查文档是否存在 394 | doc_metadata = self.get_document_metadata(document_id) 395 | 396 | # 创建一个重新索引任务 397 | task_id = str(uuid.uuid4()) 398 | task_data = { 399 | "type": "indexing", 400 | "task_id": task_id, 401 | "document_ids": [document_id], 402 | "user_id": doc_metadata.get("user_id", ""), 403 | "rebuild_all": False, 404 | "created_at": time.time() 405 | } 406 | 407 | # 将任务添加到队列 408 | self.redis.rpush("task_queue", json.dumps(task_data)) 409 | 410 | # 更新文档状态 411 | doc_metadata["status"] = "reindexing" 412 | self.redis.set( 413 | f"doc:{document_id}:metadata", 414 | json.dumps(doc_metadata), 415 | ex=3600 416 | ) 417 | 418 | # 存储任务状态 419 | self.redis.hset( 420 | f"task:{task_id}", 421 | mapping={ 422 | "status": "queued", 423 | "type": "indexing", 424 | "document_ids": json.dumps([document_id]), 425 | "created_at": time.time(), 426 | "user_id": doc_metadata.get("user_id", "") 427 | } 428 | ) 429 | 430 | return task_id 431 | 432 | except ValueError as e: 433 | raise e 434 | except Exception as e: 435 | logger.error(f"重新索引文档错误: {str(e)}") 436 | raise Exception(f"重新索引文档失败: {str(e)}") 437 | 438 | def get_task_status(self, task_id: str) -> Dict[str, Any]: 439 | """ 440 | 获取任务状态 441 | 442 | Args: 443 | task_id: 任务ID 444 | 445 | Returns: 446 | Dict[str, Any]: 任务状态信息 447 | 448 | Raises: 449 | ValueError: 如果任务不存在 450 | """ 451 | try: 452 | # 从Redis获取任务状态 453 | task_data = self.redis.hgetall(f"task:{task_id}") 454 | 455 | if not task_data: 456 | raise ValueError(f"任务 {task_id} 不存在") 457 | 458 | # 转换某些字段 459 | if "document_ids" in task_data and task_data["document_ids"]: 460 | task_data["document_ids"] = json.loads(task_data["document_ids"]) 461 | 462 | if "created_at" in task_data and task_data["created_at"]: 463 | task_data["created_at"] = float(task_data["created_at"]) 464 | 465 | if "completed_at" in task_data and task_data["completed_at"]: 466 | task_data["completed_at"] = float(task_data["completed_at"]) 467 | 468 | return task_data 469 | 470 | except ValueError as e: 471 | raise e 472 | except Exception as e: 473 | logger.error(f"获取任务状态错误: {str(e)}") 474 | raise Exception(f"获取任务状态失败: {str(e)}") -------------------------------------------------------------------------------- /backend/services/graphrag_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import Dict, List, Optional, Any 5 | import networkx as nx 6 | import numpy as np 7 | from .vector_store import VectorStore 8 | from .llm_service import LLMService 9 | 10 | logger = 
logging.getLogger(__name__) 11 | 12 | class GraphRAGService: 13 | def __init__(self, vector_store: VectorStore, llm_service: LLMService, 14 | config_path: str = "configs/worker.yaml"): 15 | # 加载配置 16 | with open(config_path, "r") as f: 17 | self.config = yaml.safe_load(f) 18 | 19 | # 存储服务引用 20 | self.vector_store = vector_store 21 | self.llm_service = llm_service 22 | 23 | # 加载GraphRAG设置 24 | self.rag_settings = self.config["graphrag"] 25 | self.similarity_threshold = self.rag_settings["similarity_threshold"] 26 | self.max_neighbors = self.rag_settings["max_neighbors"] 27 | self.max_hops = self.rag_settings["max_hops"] 28 | 29 | # 初始化图结构 30 | self.graph = nx.Graph() 31 | 32 | logger.info("GraphRAG服务初始化完成") 33 | 34 | def update_graph(self, document_id: str = None): 35 | """ 36 | 更新知识图谱,可以指定文档ID更新特定文档,或不指定更新全部 37 | 38 | Args: 39 | document_id (str, optional): 特定文档ID,为None时更新所有文档 40 | """ 41 | try: 42 | # 查询条件 43 | filter_dict = {} 44 | if document_id: 45 | filter_dict = {"document_id": document_id} 46 | 47 | # 获取所有块 48 | all_chunks = self.vector_store.query( 49 | query_vector=[0.0] * 768, # 使用空向量,将获取所有文档而不是相似匹配 50 | limit=10000, # 大量限制,实际上取决于数据库能返回的最大记录数 51 | filter_=filter_dict 52 | ) 53 | 54 | # 清理图形(如果更新特定文档) 55 | if document_id: 56 | # 删除与特定文档相关的节点 57 | nodes_to_remove = [node for node in self.graph.nodes if self.graph.nodes[node].get("document_id") == document_id] 58 | self.graph.remove_nodes_from(nodes_to_remove) 59 | else: 60 | # 重建整个图 61 | self.graph.clear() 62 | 63 | # 为每个块创建节点 64 | for chunk in all_chunks: 65 | chunk_id = chunk["id"] 66 | doc_id = chunk["payload"]["document_id"] 67 | 68 | # 添加节点 69 | self.graph.add_node( 70 | chunk_id, 71 | document_id=doc_id, 72 | text=chunk["payload"]["text"], 73 | embedding=chunk["payload"].get("embedding", []) 74 | ) 75 | 76 | # 计算相似度并创建边 77 | chunk_ids = list(self.graph.nodes) 78 | for i, chunk_id in enumerate(chunk_ids): 79 | # 获取当前块的向量嵌入 80 | if not self.graph.nodes[chunk_id].get("embedding"): 81 | # 如果没有嵌入,获取文本并生成嵌入 82 | text = self.graph.nodes[chunk_id]["text"] 83 | embedding = self.llm_service.get_embedding(text) 84 | self.graph.nodes[chunk_id]["embedding"] = embedding 85 | 86 | current_embedding = self.graph.nodes[chunk_id]["embedding"] 87 | current_doc_id = self.graph.nodes[chunk_id]["document_id"] 88 | 89 | # 为效率考虑,只与同一文档内的块或特殊关系块计算相似度 90 | for j, other_id in enumerate(chunk_ids): 91 | if i == j: 92 | continue 93 | 94 | other_doc_id = self.graph.nodes[other_id]["document_id"] 95 | 96 | # 同一文档内相邻块自动连接 97 | if current_doc_id == other_doc_id and abs(i - j) == 1: 98 | self.graph.add_edge(chunk_id, other_id, weight=0.9) 99 | continue 100 | 101 | # 计算其他相似块 102 | if not self.graph.nodes[other_id].get("embedding"): 103 | # 如果没有嵌入,获取文本并生成嵌入 104 | text = self.graph.nodes[other_id]["text"] 105 | embedding = self.llm_service.get_embedding(text) 106 | self.graph.nodes[other_id]["embedding"] = embedding 107 | 108 | other_embedding = self.graph.nodes[other_id]["embedding"] 109 | 110 | # 计算余弦相似度 111 | similarity = self._cosine_similarity(current_embedding, other_embedding) 112 | 113 | # 如果相似度高于阈值,添加边 114 | if similarity > self.similarity_threshold: 115 | self.graph.add_edge(chunk_id, other_id, weight=similarity) 116 | 117 | logger.info(f"知识图谱更新完成,节点数: {self.graph.number_of_nodes()}, 边数: {self.graph.number_of_edges()}") 118 | 119 | except Exception as e: 120 | logger.error(f"更新知识图谱时出错: {str(e)}") 121 | raise Exception(f"知识图谱更新失败: {str(e)}") 122 | 123 | def get_context_for_query(self, query: str, user_id: Optional[str] = None, 124 | document_ids: 
Optional[List[str]] = None, 125 | max_results: int = 5) -> List[Dict[str, Any]]: 126 | """ 127 | 为查询获取上下文信息 128 | 129 | Args: 130 | query: 用户查询 131 | user_id: 可选用户ID过滤 132 | document_ids: 可选文档ID列表过滤 133 | max_results: 最大返回结果数量 134 | 135 | Returns: 136 | List[Dict[str, Any]]: 相关上下文信息列表 137 | """ 138 | try: 139 | # 生成查询嵌入 140 | query_embedding = self.llm_service.get_embedding(query) 141 | 142 | # 准备过滤器 143 | filter_dict = {} 144 | if user_id: 145 | filter_dict["user_id"] = user_id 146 | if document_ids: 147 | filter_dict["document_id"] = document_ids 148 | 149 | # 首先进行向量搜索找到最相关的入口点 150 | initial_results = self.vector_store.query( 151 | query_vector=query_embedding, 152 | limit=3, # 获取少量高质量入口点 153 | filter_=filter_dict 154 | ) 155 | 156 | # 扩展结果集 157 | expanded_results = set() 158 | for result in initial_results: 159 | chunk_id = result["id"] 160 | expanded_results.add(chunk_id) 161 | 162 | # 图遍历扩展 163 | if chunk_id in self.graph: 164 | # 获取邻居节点 165 | neighbors = self._get_relevant_neighbors(chunk_id, query_embedding, hops=self.max_hops) 166 | expanded_results.update(neighbors) 167 | 168 | # 将扩展结果转换回详细信息 169 | detailed_results = [] 170 | for chunk_id in expanded_results: 171 | # 从图中获取节点信息 172 | if chunk_id in self.graph: 173 | node_data = self.graph.nodes[chunk_id] 174 | if "text" in node_data: 175 | # 计算与查询的相似度 176 | similarity = self._cosine_similarity( 177 | query_embedding, 178 | node_data.get("embedding", self.llm_service.get_embedding(node_data["text"])) 179 | ) 180 | 181 | detailed_results.append({ 182 | "id": chunk_id, 183 | "text": node_data["text"], 184 | "score": similarity, 185 | "metadata": { 186 | "document_id": node_data.get("document_id", ""), 187 | } 188 | }) 189 | 190 | # 排序并限制结果数量 191 | detailed_results.sort(key=lambda x: x["score"], reverse=True) 192 | return detailed_results[:max_results] 193 | 194 | except Exception as e: 195 | logger.error(f"获取查询上下文时出错: {str(e)}") 196 | # 回退到简单向量搜索 197 | try: 198 | fallback_results = self.vector_store.query( 199 | query_vector=query_embedding, 200 | limit=max_results, 201 | filter_=filter_dict 202 | ) 203 | 204 | # 格式化结果 205 | return [{ 206 | "id": result["id"], 207 | "text": result["payload"]["text"], 208 | "score": result["score"], 209 | "metadata": { 210 | "document_id": result["payload"].get("document_id", ""), 211 | } 212 | } for result in fallback_results] 213 | except: 214 | logger.error("回退搜索也失败") 215 | return [] 216 | 217 | def _get_relevant_neighbors(self, start_node: str, query_embedding: List[float], 218 | hops: int = 2) -> List[str]: 219 | """获取与查询相关的邻居节点""" 220 | # BFS搜索相关节点 221 | visited = set([start_node]) 222 | queue = [(start_node, 0)] # (node, hop) 223 | relevant_nodes = [] 224 | 225 | while queue: 226 | node, hop_count = queue.pop(0) 227 | 228 | # 如果超过最大跳数,停止 229 | if hop_count >= hops: 230 | continue 231 | 232 | # 获取邻居节点 233 | if node not in self.graph: 234 | continue 235 | 236 | neighbors = list(self.graph.neighbors(node)) 237 | 238 | # 计算邻居节点与查询的相似度 239 | neighbor_similarities = [] 240 | for neighbor in neighbors: 241 | if neighbor in visited: 242 | continue 243 | 244 | # 获取向量嵌入 245 | if "embedding" not in self.graph.nodes[neighbor]: 246 | text = self.graph.nodes[neighbor]["text"] 247 | self.graph.nodes[neighbor]["embedding"] = self.llm_service.get_embedding(text) 248 | 249 | embedding = self.graph.nodes[neighbor]["embedding"] 250 | 251 | # 计算相似度 252 | similarity = self._cosine_similarity(query_embedding, embedding) 253 | neighbor_similarities.append((neighbor, similarity)) 254 | 255 | # 按相似度排序 256 | 
neighbor_similarities.sort(key=lambda x: x[1], reverse=True) 257 | 258 | # 取前N个相关邻居 259 | for neighbor, similarity in neighbor_similarities[:self.max_neighbors]: 260 | if similarity > self.similarity_threshold and neighbor not in visited: 261 | visited.add(neighbor) 262 | queue.append((neighbor, hop_count + 1)) 263 | relevant_nodes.append(neighbor) 264 | 265 | return relevant_nodes 266 | 267 | def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: 268 | """计算余弦相似度""" 269 | vec1 = np.array(vec1) 270 | vec2 = np.array(vec2) 271 | 272 | norm1 = np.linalg.norm(vec1) 273 | norm2 = np.linalg.norm(vec2) 274 | 275 | if norm1 == 0 or norm2 == 0: 276 | return 0.0 277 | 278 | return np.dot(vec1, vec2) / (norm1 * norm2) 279 | -------------------------------------------------------------------------------- /backend/services/llm_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | import httpx 5 | from typing import Dict, List, Optional, Any 6 | import json 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class LLMService: 11 | def __init__(self, config_path: str = "configs/ollama.yaml"): 12 | # 加载配置 13 | with open(config_path, "r") as f: 14 | self.config = yaml.safe_load(f) 15 | 16 | self.host = self.config["ollama"]["host"] 17 | self.port = self.config["ollama"]["port"] 18 | self.timeout = self.config["ollama"]["timeout"] 19 | self.base_url = f"http://{self.host}:{self.port}" 20 | self.default_model = self.config["models"]["default"] 21 | self.embeddings_model = self.config["models"]["embeddings"] 22 | 23 | # 初始化客户端 24 | self.client = httpx.Client(timeout=self.timeout) 25 | 26 | # 加载推理参数 27 | self.inference_params = self.config["inference"] 28 | 29 | logger.info(f"LLM服务初始化完成,使用模型: {self.default_model}") 30 | 31 | def generate(self, prompt: str, system_prompt: Optional[str] = None, **kwargs) -> Dict[str, Any]: 32 | """生成LLM响应""" 33 | model = kwargs.pop("model", self.default_model) 34 | 35 | # 准备请求数据 36 | request_data = { 37 | "model": model, 38 | "prompt": prompt, 39 | **self.inference_params, 40 | **kwargs 41 | } 42 | 43 | if system_prompt: 44 | request_data["system"] = system_prompt 45 | 46 | try: 47 | # 请求Ollama API 48 | response = self.client.post( 49 | f"{self.base_url}/api/generate", 50 | json=request_data 51 | ) 52 | response.raise_for_status() 53 | 54 | return response.json() 55 | except httpx.HTTPError as e: 56 | logger.error(f"Ollama API调用失败: {str(e)}") 57 | raise Exception(f"生成响应失败: {str(e)}") 58 | 59 | def generate_stream(self, prompt: str, system_prompt: Optional[str] = None, **kwargs): 60 | """生成流式LLM响应""" 61 | model = kwargs.pop("model", self.default_model) 62 | 63 | # 准备请求数据 64 | request_data = { 65 | "model": model, 66 | "prompt": prompt, 67 | "stream": True, 68 | **self.inference_params, 69 | **kwargs 70 | } 71 | 72 | if system_prompt: 73 | request_data["system"] = system_prompt 74 | 75 | try: 76 | with httpx.stream("POST", f"{self.base_url}/api/generate", 77 | json=request_data, timeout=self.timeout) as response: 78 | response.raise_for_status() 79 | 80 | for line in response.iter_lines(): 81 | if line: 82 | try: 83 | chunk = json.loads(line) 84 | if "response" in chunk: 85 | yield chunk["response"] 86 | except json.JSONDecodeError: 87 | logger.warning(f"无法解析JSON响应: {line}") 88 | except httpx.HTTPError as e: 89 | logger.error(f"Ollama流式API调用失败: {str(e)}") 90 | raise Exception(f"生成流式响应失败: {str(e)}") 91 | 92 | def get_embedding(self, text: str, model: Optional[str] = 
None) -> List[float]: 93 | """获取文本的向量嵌入""" 94 | model = model or self.embeddings_model 95 | 96 | try: 97 | response = self.client.post( 98 | f"{self.base_url}/api/embeddings", 99 | json={"model": model, "prompt": text} 100 | ) 101 | response.raise_for_status() 102 | result = response.json() 103 | 104 | return result["embedding"] 105 | except httpx.HTTPError as e: 106 | logger.error(f"获取向量嵌入失败: {str(e)}") 107 | raise Exception(f"生成向量嵌入失败: {str(e)}") 108 | -------------------------------------------------------------------------------- /backend/services/search_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import Dict, List, Optional, Any 5 | import redis 6 | import json 7 | import time 8 | from .vector_store import VectorStore 9 | from .llm_service import LLMService 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | class SearchService: 14 | def __init__(self, vector_store: VectorStore, llm_service: LLMService, 15 | redis_config_path: str = "configs/redis.yaml"): 16 | # 加载Redis配置 17 | with open(redis_config_path, "r") as f: 18 | redis_config = yaml.safe_load(f) 19 | 20 | # 初始化Redis客户端 21 | self.redis = redis.Redis( 22 | host=redis_config["redis"]["host"], 23 | port=redis_config["redis"]["port"], 24 | db=redis_config["redis"]["db"], 25 | password=redis_config["redis"]["password"], 26 | decode_responses=True 27 | ) 28 | 29 | # 缓存TTL 30 | self.cache_ttl = redis_config["cache"]["ttl"] 31 | 32 | # 存储服务 33 | self.vector_store = vector_store 34 | self.llm_service = llm_service 35 | 36 | logger.info("搜索服务初始化完成") 37 | 38 | def search(self, query: str, user_id: Optional[str] = None, limit: int = 10, 39 | use_hybrid: bool = True, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: 40 | """使用查询字符串搜索向量存储""" 41 | # 生成缓存键 42 | cache_key = f"search:{hash(query)}:{user_id or 'all'}:{limit}:{use_hybrid}:{hash(str(filters))}" 43 | 44 | # 检查缓存 45 | cached_results = self.redis.get(cache_key) 46 | if cached_results: 47 | logger.info(f"缓存命中: {query}") 48 | return json.loads(cached_results) 49 | 50 | try: 51 | # 生成查询嵌入 52 | start_time = time.time() 53 | query_embedding = self.llm_service.get_embedding(query) 54 | embedding_time = time.time() - start_time 55 | logger.debug(f"生成查询嵌入耗时: {embedding_time:.3f}秒") 56 | 57 | # 准备过滤器 58 | search_filter = self._prepare_filter(user_id, filters) 59 | 60 | # 执行向量搜索 61 | start_time = time.time() 62 | semantic_results = self.vector_store.query( 63 | query_vector=query_embedding, 64 | limit=limit, 65 | filter_=search_filter 66 | ) 67 | vector_time = time.time() - start_time 68 | logger.debug(f"向量搜索完成,耗时: {vector_time:.3f}秒") 69 | 70 | # 格式化结果 71 | results = self._format_search_results(semantic_results) 72 | 73 | # 缓存结果 74 | self.redis.set(cache_key, json.dumps(results), ex=self.cache_ttl) 75 | 76 | return results 77 | 78 | except Exception as e: 79 | logger.error(f"搜索查询'{query}'时出错: {str(e)}") 80 | raise Exception(f"搜索失败: {str(e)}") 81 | 82 | def _prepare_filter(self, user_id: Optional[str] = None, 83 | additional_filters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: 84 | """准备向量搜索的过滤器""" 85 | filter_dict = {} 86 | filter_conditions = [] 87 | 88 | # 添加用户过滤器 89 | if user_id: 90 | filter_conditions.append({ 91 | "key": "user_id", 92 | "match": { 93 | "value": user_id 94 | } 95 | }) 96 | 97 | # 添加额外过滤器 98 | if additional_filters: 99 | for key, value in additional_filters.items(): 100 | if isinstance(value, list): 101 | # 处理列表值(OR条件) 102 | 
or_conditions = [] 103 | for val in value: 104 | or_conditions.append({ 105 | "key": key, 106 | "match": { 107 | "value": val 108 | } 109 | }) 110 | 111 | if or_conditions: 112 | filter_conditions.append({ 113 | "should": or_conditions 114 | }) 115 | else: 116 | # 处理单个值 117 | filter_conditions.append({ 118 | "key": key, 119 | "match": { 120 | "value": value 121 | } 122 | }) 123 | 124 | # 将所有条件与AND逻辑组合 125 | if filter_conditions: 126 | filter_dict = { 127 | "must": filter_conditions 128 | } 129 | 130 | return filter_dict 131 | 132 | def _format_search_results(self, vector_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 133 | """将向量搜索结果格式化为标准格式""" 134 | formatted_results = [] 135 | 136 | for result in vector_results: 137 | # 提取基本信息 138 | result_id = result["id"] 139 | score = result["score"] 140 | payload = result["payload"] 141 | 142 | # 创建格式化结果 143 | formatted_result = { 144 | "id": result_id, 145 | "score": score, 146 | "text": payload.get("text", ""), 147 | "metadata": { 148 | "document_id": payload.get("document_id", ""), 149 | "chunk_id": payload.get("chunk_id", ""), 150 | "filename": payload.get("filename", ""), 151 | "position": { 152 | "start": payload.get("start_char", 0), 153 | "end": payload.get("end_char", 0) 154 | } 155 | } 156 | } 157 | 158 | formatted_results.append(formatted_result) 159 | 160 | return formatted_results 161 | -------------------------------------------------------------------------------- /backend/services/vector_store.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from typing import Dict, List, Optional, Any 5 | from qdrant_client import QdrantClient 6 | from qdrant_client.http import models as rest 7 | import numpy as np 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class VectorStore: 12 | def __init__(self, config_path: str = "configs/qdrant.yaml"): 13 | # 加载配置 14 | with open(config_path, "r") as f: 15 | self.config = yaml.safe_load(f) 16 | 17 | # 初始化Qdrant客户端 18 | self.client = QdrantClient( 19 | host=self.config["qdrant"]["host"], 20 | port=self.config["qdrant"]["port"], 21 | prefer_grpc=self.config["qdrant"]["prefer_grpc"], 22 | timeout=self.config["qdrant"]["timeout"] 23 | ) 24 | 25 | # 获取集合配置 26 | self.default_collection = self.config["collections"]["default"]["name"] 27 | self.metadata_collection = self.config["collections"]["metadata"]["name"] 28 | 29 | # 初始化集合(如果不存在) 30 | self._initialize_collections() 31 | 32 | logger.info(f"向量存储初始化完成: {self.default_collection}, {self.metadata_collection}") 33 | 34 | def _initialize_collections(self): 35 | """初始化向量集合""" 36 | collections = [collection.name for collection in self.client.get_collections().collections] 37 | 38 | # 初始化文档集合 39 | if self.default_collection not in collections: 40 | doc_config = self.config["collections"]["default"] 41 | self.client.create_collection( 42 | collection_name=self.default_collection, 43 | vectors_config=rest.VectorsConfig( 44 | size=doc_config["vector_size"], 45 | distance=rest.Distance[doc_config["distance"]] 46 | ), 47 | optimizers_config=rest.OptimizersConfigDiff( 48 | deleted_threshold=doc_config["optimizers"]["deleted_threshold"], 49 | vacuum_min_vector_number=doc_config["optimizers"]["vacuum_min_vector_number"] 50 | ), 51 | hnsw_config=rest.HnswConfigDiff( 52 | m=doc_config["index"]["m"], 53 | ef_construct=doc_config["index"]["ef_construct"] 54 | ) 55 | ) 56 | logger.info(f"创建集合 {self.default_collection}") 57 | 58 | # 初始化元数据集合 59 | if self.metadata_collection not 
in collections: 60 | meta_config = self.config["collections"]["metadata"] 61 | self.client.create_collection( 62 | collection_name=self.metadata_collection, 63 | vectors_config=rest.VectorsConfig( 64 | size=meta_config["vector_size"], 65 | distance=rest.Distance[meta_config["distance"]] 66 | ) 67 | ) 68 | logger.info(f"创建集合 {self.metadata_collection}") 69 | 70 | def add_documents(self, vectors: List[List[float]], payloads: List[Dict[str, Any]], 71 | ids: Optional[List[str]] = None, collection_name: Optional[str] = None) -> List[str]: 72 | """添加文档向量和元数据""" 73 | collection_name = collection_name or self.default_collection 74 | 75 | # 如果没有提供ID,则生成 76 | if ids is None: 77 | ids = [str(i) for i in range(len(vectors))] 78 | 79 | # 转换为numpy数组进行验证 80 | vectors_np = np.array(vectors, dtype=np.float32) 81 | 82 | try: 83 | # 创建点批次 84 | points = [ 85 | rest.PointStruct( 86 | id=str(id_), 87 | vector=vector.tolist(), 88 | payload=payload 89 | ) 90 | for id_, vector, payload in zip(ids, vectors_np, payloads) 91 | ] 92 | 93 | # 插入批次 94 | self.client.upsert( 95 | collection_name=collection_name, 96 | points=points 97 | ) 98 | 99 | logger.info(f"向{collection_name}添加了{len(points)}个向量") 100 | return ids 101 | 102 | except Exception as e: 103 | logger.error(f"向{collection_name}添加向量失败: {str(e)}") 104 | raise Exception(f"添加向量失败: {str(e)}") 105 | 106 | def query(self, query_vector: List[float], limit: int = 5, 107 | filter_: Optional[Dict[str, Any]] = None, collection_name: Optional[str] = None) -> List[Dict[str, Any]]: 108 | """搜索相似向量""" 109 | collection_name = collection_name or self.default_collection 110 | 111 | try: 112 | # 转换查询向量为numpy数组 113 | query_vector_np = np.array(query_vector, dtype=np.float32) 114 | 115 | # 创建过滤器 116 | search_filter = None 117 | if filter_: 118 | search_filter = rest.Filter(**filter_) 119 | 120 | # 执行搜索 121 | results = self.client.search( 122 | collection_name=collection_name, 123 | query_vector=query_vector_np.tolist(), 124 | limit=limit, 125 | query_filter=search_filter, 126 | with_payload=True, 127 | with_vectors=False 128 | ) 129 | 130 | # 格式化结果 131 | formatted_results = [ 132 | { 133 | "id": str(result.id), 134 | "score": float(result.score), 135 | "payload": result.payload 136 | } 137 | for result in results 138 | ] 139 | 140 | return formatted_results 141 | 142 | except Exception as e: 143 | logger.error(f"查询{collection_name}失败: {str(e)}") 144 | raise Exception(f"查询向量失败: {str(e)}") 145 | -------------------------------------------------------------------------------- /backend/worker/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | import time 5 | import redis 6 | import json 7 | from typing import Dict, Any, Optional 8 | import signal 9 | import sys 10 | from concurrent.futures import ThreadPoolExecutor 11 | import threading 12 | from ..services.document_service import DocumentService 13 | from ..services.vector_store import VectorStore 14 | from ..services.llm_service import LLMService 15 | from ..services.graphrag_service import GraphRAGService 16 | 17 | # 加载配置 18 | config_path = os.getenv("WORKER_CONFIG_PATH", "configs/worker.yaml") 19 | with open(config_path, "r") as f: 20 | config = yaml.safe_load(f) 21 | 22 | # 加载Redis配置 23 | redis_config_path = os.getenv("REDIS_CONFIG_PATH", "configs/redis.yaml") 24 | with open(redis_config_path, "r") as f: 25 | redis_config = yaml.safe_load(f) 26 | 27 | # 设置日志 28 | logging.basicConfig( 29 | level=getattr(logging, 
config["worker"]["log_level"].upper()), 30 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 31 | ) 32 | logger = logging.getLogger(__name__) 33 | 34 | # 初始化服务 35 | vector_store = VectorStore() 36 | llm_service = LLMService() 37 | document_service = DocumentService(vector_store) 38 | graphrag_service = GraphRAGService(vector_store, llm_service) 39 | 40 | # 初始化Redis 41 | redis_client = redis.Redis( 42 | host=redis_config["redis"]["host"], 43 | port=redis_config["redis"]["port"], 44 | db=redis_config["redis"]["db"], 45 | password=redis_config["redis"]["password"], 46 | decode_responses=True 47 | ) 48 | 49 | # 创建线程池 50 | thread_pool = ThreadPoolExecutor(max_workers=config["worker"]["threads"]) 51 | 52 | # 运行状态标志 53 | running = True 54 | 55 | def process_document_task(task_data: Dict[str, Any]): 56 | """处理文档任务""" 57 | try: 58 | document_id = task_data.get("document_id") 59 | task_id = task_data.get("task_id") 60 | 61 | logger.info(f"处理文档任务: {task_id}, 文档ID: {document_id}") 62 | 63 | # 更新任务状态 64 | redis_client.hset(f"task:{task_id}", "status", "processing") 65 | 66 | # 实际处理文档的逻辑 67 | # 这里通常会从存储中加载临时文件,然后调用document_service的处理方法 68 | 69 | # 如果文档处理已在API层完成,这里可以进行额外的优化或后处理 70 | # 例如,更新知识图谱 71 | graphrag_service.update_graph(document_id) 72 | 73 | # 更新任务状态为完成 74 | redis_client.hset(f"task:{task_id}", "status", "completed") 75 | redis_client.hset(f"task:{task_id}", "completed_at", time.time()) 76 | 77 | logger.info(f"文档任务 {task_id} 处理完成") 78 | 79 | except Exception as e: 80 | logger.error(f"处理文档任务失败: {str(e)}") 81 | 82 | # 更新任务状态为失败 83 | if task_id: 84 | redis_client.hset(f"task:{task_id}", "status", "failed") 85 | redis_client.hset(f"task:{task_id}", "error", str(e)) 86 | 87 | def process_embedding_task(task_data: Dict[str, Any]): 88 | """处理嵌入任务""" 89 | try: 90 | document_id = task_data.get("document_id") 91 | chunk_ids = task_data.get("chunk_ids", []) 92 | task_id = task_data.get("task_id") 93 | 94 | logger.info(f"处理嵌入任务: {task_id}, 文档ID: {document_id}, 块数: {len(chunk_ids)}") 95 | 96 | # 更新任务状态 97 | redis_client.hset(f"task:{task_id}", "status", "processing") 98 | 99 | # 这里实现嵌入处理逻辑 100 | # 通常是从数据库加载块内容,然后生成嵌入 101 | 102 | # 更新任务状态为完成 103 | redis_client.hset(f"task:{task_id}", "status", "completed") 104 | redis_client.hset(f"task:{task_id}", "completed_at", time.time()) 105 | 106 | logger.info(f"嵌入任务 {task_id} 处理完成") 107 | 108 | except Exception as e: 109 | logger.error(f"处理嵌入任务失败: {str(e)}") 110 | 111 | # 更新任务状态为失败 112 | if task_id: 113 | redis_client.hset(f"task:{task_id}", "status", "failed") 114 | redis_client.hset(f"task:{task_id}", "error", str(e)) 115 | 116 | def process_indexing_task(task_data: Dict[str, Any]): 117 | """处理索引任务""" 118 | try: 119 | document_ids = task_data.get("document_ids", []) 120 | rebuild_all = task_data.get("rebuild_all", False) 121 | task_id = task_data.get("task_id") 122 | 123 | logger.info(f"处理索引任务: {task_id}, 全量重建: {rebuild_all}, 文档数: {len(document_ids)}") 124 | 125 | # 更新任务状态 126 | redis_client.hset(f"task:{task_id}", "status", "processing") 127 | 128 | # 如果是全量重建 129 | if rebuild_all: 130 | graphrag_service.update_graph() 131 | else: 132 | # 为指定文档更新图结构 133 | for doc_id in document_ids: 134 | graphrag_service.update_graph(doc_id) 135 | 136 | # 更新任务状态为完成 137 | redis_client.hset(f"task:{task_id}", "status", "completed") 138 | redis_client.hset(f"task:{task_id}", "completed_at", time.time()) 139 | 140 | logger.info(f"索引任务 {task_id} 处理完成") 141 | 142 | except Exception as e: 143 | logger.error(f"处理索引任务失败: {str(e)}") 144 | 145 | # 更新任务状态为失败 146 | if task_id: 
147 | redis_client.hset(f"task:{task_id}", "status", "failed") 148 | redis_client.hset(f"task:{task_id}", "error", str(e)) 149 | 150 | def poll_tasks(): 151 | """轮询并处理任务队列中的任务""" 152 | task_types = { 153 | "document": process_document_task, 154 | "embedding": process_embedding_task, 155 | "indexing": process_indexing_task 156 | } 157 | 158 | while running: 159 | try: 160 | # 从任务队列中获取任务 161 | result = redis_client.blpop("task_queue", timeout=1) 162 | if result is None: 163 | continue 164 | 165 | # 解析任务数据 166 | _, task_json = result 167 | task_data = json.loads(task_json) 168 | 169 | # 获取任务类型 170 | task_type = task_data.get("type") 171 | task_id = task_data.get("task_id") 172 | 173 | if task_type in task_types: 174 | # 提交任务到线程池 175 | thread_pool.submit(task_types[task_type], task_data) 176 | else: 177 | logger.warning(f"未知任务类型: {task_type}, 任务ID: {task_id}") 178 | 179 | except Exception as e: 180 | logger.error(f"轮询任务时出错: {str(e)}") 181 | time.sleep(1) # 避免过度消耗CPU 182 | 183 | def handle_signal(sig, frame): 184 | """处理终止信号""" 185 | global running 186 | logger.info(f"收到信号 {sig},准备关闭...") 187 | running = False 188 | 189 | # 关闭线程池 190 | thread_pool.shutdown(wait=True) 191 | 192 | # 关闭连接 193 | redis_client.close() 194 | 195 | logger.info("工作进程已安全关闭") 196 | sys.exit(0) 197 | 198 | def main(): 199 | """主函数""" 200 | # 注册信号处理器 201 | signal.signal(signal.SIGINT, handle_signal) 202 | signal.signal(signal.SIGTERM, handle_signal) 203 | 204 | logger.info("知识库工作进程启动") 205 | 206 | # 启动任务轮询 207 | poll_thread = threading.Thread(target=poll_tasks) 208 | poll_thread.daemon = True 209 | poll_thread.start() 210 | 211 | # 主线程保持运行 212 | try: 213 | while running: 214 | time.sleep(1) 215 | except KeyboardInterrupt: 216 | handle_signal(signal.SIGINT, None) 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /backend/worker/tasks/document_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Any, List, BinaryIO 4 | import tempfile 5 | import time 6 | import uuid 7 | from ...services.document_service import DocumentService 8 | from ...services.vector_store import VectorStore 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class DocumentProcessorTask: 13 | """文档处理任务类""" 14 | 15 | def __init__(self, document_service: DocumentService): 16 | self.document_service = document_service 17 | 18 | def process(self, file_path: str, filename: str, user_id: str, 19 | metadata: Dict[str, Any] = None) -> Dict[str, Any]: 20 | """ 21 | 处理文档文件 22 | 23 | Args: 24 | file_path: 临时文件路径 25 | filename: 原始文件名 26 | user_id: 用户ID 27 | metadata: 可选的元数据 28 | 29 | Returns: 30 | Dict[str, Any]: 处理结果 31 | """ 32 | start_time = time.time() 33 | logger.info(f"开始处理文档: {filename}") 34 | 35 | try: 36 | # 打开文件 37 | with open(file_path, 'rb') as file: 38 | # 处理文档 39 | result = self.document_service.process_document( 40 | file=file, 41 | filename=filename, 42 | user_id=user_id, 43 | metadata=metadata 44 | ) 45 | 46 | processing_time = time.time() - start_time 47 | logger.info(f"文档处理完成: {filename}, 耗时: {processing_time:.2f}秒") 48 | 49 | return { 50 | "success": True, 51 | "document_id": result["id"], 52 | "processing_time": processing_time, 53 | "chunks_count": result["chunks_count"] 54 | } 55 | 56 | except Exception as e: 57 | logger.error(f"处理文档{filename}时出错: {str(e)}") 58 | 59 | processing_time = time.time() - start_time 60 | return { 61 | "success": False, 62 | 
"error": str(e), 63 | "processing_time": processing_time 64 | } 65 | finally: 66 | # 删除临时文件 67 | if os.path.exists(file_path): 68 | try: 69 | os.unlink(file_path) 70 | except Exception as e: 71 | logger.warning(f"删除临时文件{file_path}失败: {str(e)}") -------------------------------------------------------------------------------- /configs/api.yaml: -------------------------------------------------------------------------------- 1 | server: 2 | host: '0.0.0.0' 3 | port: 8000 4 | workers: 4 5 | log_level: 'info' 6 | debug: false 7 | reload: false 8 | 9 | cors: 10 | allowed_origins: 11 | - '*' 12 | allowed_methods: 13 | - 'GET' 14 | - 'POST' 15 | - 'PUT' 16 | - 'DELETE' 17 | allowed_headers: 18 | - '*' 19 | 20 | security: 21 | api_key_header: 'X-API-Key' 22 | api_key: 'your-api-key-here' # 开发环境默认密钥,生产环境应更改 23 | jwt_secret: 'your-jwt-secret-here' # 开发环境默认密钥,生产环境应更改 24 | token_expire_minutes: 1440 # 24小时 25 | 26 | rate_limiting: 27 | enabled: true 28 | max_requests: 100 29 | time_window_seconds: 60 30 | -------------------------------------------------------------------------------- /configs/ollama.yaml: -------------------------------------------------------------------------------- 1 | ollama: 2 | host: 'ollama' 3 | port: 11434 4 | timeout: 120 5 | 6 | models: 7 | default: 'deepseek-v3' 8 | embeddings: 'deepseek-embeddings' 9 | available: 10 | - 'deepseek-v3' 11 | - 'deepseek-r1' 12 | - 'deepseek-embeddings' 13 | 14 | inference: 15 | temperature: 0.7 16 | top_p: 0.9 17 | top_k: 40 18 | max_tokens: 2048 19 | stop_sequences: [] 20 | -------------------------------------------------------------------------------- /configs/qdrant.yaml: -------------------------------------------------------------------------------- 1 | qdrant: 2 | host: 'qdrant' 3 | port: 6333 4 | grpc_port: 6334 5 | prefer_grpc: true 6 | timeout: 30 7 | 8 | collections: 9 | default: 10 | name: 'documents' 11 | vector_size: 768 12 | distance: 'Cosine' 13 | optimizers: 14 | deleted_threshold: 0.2 15 | vacuum_min_vector_number: 1000 16 | index: 17 | m: 16 18 | ef_construct: 100 19 | 20 | metadata: 21 | name: 'metadata' 22 | vector_size: 768 23 | distance: 'Cosine' 24 | -------------------------------------------------------------------------------- /configs/redis.yaml: -------------------------------------------------------------------------------- 1 | redis: 2 | host: 'redis' 3 | port: 6379 4 | db: 0 5 | password: 'redispassword' 6 | timeout: 5 7 | socket_timeout: 5 8 | socket_connect_timeout: 5 9 | 10 | cache: 11 | ttl: 3600 # 默认缓存1小时 12 | search_ttl: 1800 # 搜索结果缓存30分钟 13 | document_ttl: 86400 # 文档元数据缓存1天 14 | 15 | task_queue: 16 | name: 'task_queue' 17 | max_retries: 3 18 | retry_delay: 60 # 秒 19 | 20 | sessions: 21 | ttl: 86400 # 会话缓存1天 22 | -------------------------------------------------------------------------------- /configs/worker.yaml: -------------------------------------------------------------------------------- 1 | worker: 2 | threads: 4 3 | log_level: 'info' 4 | queue_check_interval: 1 # 秒 5 | 6 | document_processing: 7 | supported_formats: 8 | - '.pdf' 9 | - '.docx' 10 | - '.doc' 11 | - '.txt' 12 | - '.html' 13 | - '.xlsx' 14 | - '.xls' 15 | - '.pptx' 16 | - '.ppt' 17 | - '.md' 18 | chunk_size: 1000 19 | chunk_overlap: 200 20 | storage: 21 | path: '/app/data/uploads' 22 | max_size_mb: 100 23 | 24 | embedding: 25 | batch_size: 10 26 | max_workers: 4 27 | timeout: 60 28 | 29 | graphrag: 30 | enabled: true 31 | similarity_threshold: 0.75 32 | max_neighbors: 5 33 | max_hops: 2 34 | update_interval: 3600 # 自动更新图结构的间隔(秒) 35 | 
-------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | api: 5 | build: 6 | context: . 7 | dockerfile: docker/Dockerfile.api 8 | ports: 9 | - "${API_PORT:-8000}:8000" 10 | volumes: 11 | - ./configs:/app/configs 12 | - ./data:/app/data 13 | depends_on: 14 | - redis 15 | - qdrant 16 | - ollama 17 | environment: 18 | - API_CONFIG_PATH=/app/configs/api.yaml 19 | - REDIS_CONFIG_PATH=/app/configs/redis.yaml 20 | - OLLAMA_CONFIG_PATH=/app/configs/ollama.yaml 21 | - QDRANT_CONFIG_PATH=/app/configs/qdrant.yaml 22 | restart: unless-stopped 23 | networks: 24 | - kb-network 25 | 26 | worker: 27 | build: 28 | context: . 29 | dockerfile: docker/Dockerfile.worker 30 | volumes: 31 | - ./configs:/app/configs 32 | - ./data:/app/data 33 | depends_on: 34 | - redis 35 | - qdrant 36 | - ollama 37 | environment: 38 | - WORKER_CONFIG_PATH=/app/configs/worker.yaml 39 | - REDIS_CONFIG_PATH=/app/configs/redis.yaml 40 | - OLLAMA_CONFIG_PATH=/app/configs/ollama.yaml 41 | - QDRANT_CONFIG_PATH=/app/configs/qdrant.yaml 42 | restart: unless-stopped 43 | networks: 44 | - kb-network 45 | 46 | frontend: 47 | build: 48 | context: . 49 | dockerfile: docker/Dockerfile.frontend 50 | ports: 51 | - "${FRONTEND_PORT:-8501}:8501" 52 | volumes: 53 | - ./configs:/app/configs 54 | - ./data:/app/data 55 | depends_on: 56 | - api 57 | environment: 58 | - STREAMLIT_SERVER_PORT=8501 59 | restart: unless-stopped 60 | networks: 61 | - kb-network 62 | 63 | redis: 64 | image: redis:6.2 65 | ports: 66 | - "6379:6379" 67 | volumes: 68 | - redis-data:/data 69 | command: redis-server --requirepass ${REDIS_PASSWORD:-redispassword} 70 | restart: unless-stopped 71 | networks: 72 | - kb-network 73 | 74 | qdrant: 75 | image: qdrant/qdrant:latest 76 | ports: 77 | - "6333:6333" 78 | - "6334:6334" 79 | volumes: 80 | - qdrant-data:/qdrant/storage 81 | restart: unless-stopped 82 | networks: 83 | - kb-network 84 | 85 | ollama: 86 | image: ollama/ollama:latest 87 | ports: 88 | - "11434:11434" 89 | volumes: 90 | - ollama-data:/root/.ollama 91 | environment: 92 | - OLLAMA_HOST=0.0.0.0 93 | - OLLAMA_MODELS=/root/.ollama/models 94 | restart: unless-stopped 95 | deploy: 96 | resources: 97 | reservations: 98 | devices: 99 | - driver: nvidia 100 | count: 1 101 | capabilities: [gpu] 102 | networks: 103 | - kb-network 104 | 105 | networks: 106 | kb-network: 107 | driver: bridge 108 | 109 | volumes: 110 | redis-data: 111 | qdrant-data: 112 | ollama-data: 113 | -------------------------------------------------------------------------------- /docker/.env: -------------------------------------------------------------------------------- 1 | # API服务配置 2 | API_HOST=0.0.0.0 3 | API_PORT=8000 4 | API_WORKERS=4 5 | API_LOG_LEVEL=info 6 | 7 | # 前端服务配置 8 | FRONTEND_HOST=0.0.0.0 9 | FRONTEND_PORT=8501 10 | 11 | # Worker配置 12 | WORKER_THREADS=4 13 | WORKER_LOG_LEVEL=info 14 | 15 | # Redis配置 16 | REDIS_HOST=redis 17 | REDIS_PORT=6379 18 | REDIS_PASSWORD=redispassword 19 | REDIS_DB=0 20 | 21 | # Qdrant配置 22 | QDRANT_HOST=qdrant 23 | QDRANT_PORT=6333 24 | QDRANT_GRPC_PORT=6334 25 | 26 | # Ollama配置 27 | OLLAMA_HOST=ollama 28 | OLLAMA_PORT=11434 29 | OLLAMA_TIMEOUT=120 30 | 31 | # DeepSeek模型配置 32 | DEEPSEEK_MODEL=deepseek-v3 33 | EMBEDDINGS_MODEL=deepseek-embeddings 34 | -------------------------------------------------------------------------------- /docker/Dockerfile.api: 
-------------------------------------------------------------------------------- 1 | FROM python:3.9-slim 2 | 3 | # 设置工作目录 4 | WORKDIR /app 5 | 6 | # 安装依赖 7 | COPY requirements.txt . 8 | RUN pip install --no-cache-dir -r requirements.txt 9 | 10 | # 复制应用代码 11 | COPY . . 12 | 13 | # 设置环境变量 14 | ENV PYTHONPATH=/app 15 | ENV API_CONFIG_PATH=/app/configs/api.yaml 16 | ENV REDIS_CONFIG_PATH=/app/configs/redis.yaml 17 | ENV OLLAMA_CONFIG_PATH=/app/configs/ollama.yaml 18 | 19 | # 设置端口 20 | EXPOSE 8000 21 | 22 | # 启动命令 23 | CMD ["uvicorn", "backend.api.main:app", "--host", "0.0.0.0", "--port", "8000"] 24 | docker/Dockerfile.worker 25 | 26 | FROM python:3.9-slim 27 | 28 | # 设置工作目录 29 | WORKDIR /app 30 | 31 | # 安装依赖 32 | COPY requirements.txt . 33 | RUN pip install --no-cache-dir -r requirements.txt 34 | 35 | # 复制应用代码 36 | COPY . . 37 | 38 | # 设置环境变量 39 | ENV PYTHONPATH=/app 40 | ENV WORKER_CONFIG_PATH=/app/configs/worker.yaml 41 | ENV REDIS_CONFIG_PATH=/app/configs/redis.yaml 42 | ENV OLLAMA_CONFIG_PATH=/app/configs/ollama.yaml 43 | 44 | # 启动命令 45 | CMD ["python", "-m", "backend.worker.main"] 46 | -------------------------------------------------------------------------------- /docker/Dockerfile.frontend: -------------------------------------------------------------------------------- 1 | FROM python:3.9-slim 2 | 3 | # 设置工作目录 4 | WORKDIR /app 5 | 6 | # 安装依赖 7 | COPY requirements.txt . 8 | RUN pip install --no-cache-dir -r requirements.txt 9 | 10 | # 复制应用代码 11 | COPY . . 12 | 13 | # 设置环境变量 14 | ENV PYTHONPATH=/app 15 | 16 | # 设置端口 17 | EXPOSE 8501 18 | 19 | # 启动命令 20 | CMD ["streamlit", "run", "frontend/app.py", "--server.port=8501", "--server.address=0.0.0.0"] 21 | -------------------------------------------------------------------------------- /docker/Dockerfile.worker: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /frontend/app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import requests 3 | import json 4 | import os 5 | import time 6 | import pandas as pd 7 | import plotly.express as px 8 | from pathlib import Path 9 | import yaml 10 | from io import BytesIO 11 | 12 | 13 | # 在app.py适当位置添加以下内容 14 | 15 | # 导入页面组件 16 | from pages import home, documents, chat, search, analysis, system_status 17 | 18 | 19 | 20 | # 加载配置 21 | config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../configs/api.yaml") 22 | with open(config_path, "r") as f: 23 | config = yaml.safe_load(f) 24 | 25 | # API基础URL 26 | API_URL = f"http://{config['server']['host']}:{config['server']['port']}" 27 | 28 | # 设置页面配置 29 | st.set_page_config( 30 | page_title="DeepSeek本地知识库系统", 31 | page_icon="📚", 32 | layout="wide", 33 | initial_sidebar_state="expanded" 34 | ) 35 | 36 | # 添加CSS样式 37 | st.markdown(""" 38 | 80 | """, unsafe_allow_html=True) 81 | 82 | # 导航设置 83 | def main(): 84 | # 侧边栏菜单 85 | st.sidebar.markdown("## 📚 DeepSeek知识库") 86 | page = st.sidebar.radio( 87 | "导航", 88 | ["首页", "文档管理", "聊天问答", "搜索", "数据分析", "系统状态"] 89 | ) 90 | 91 | # 根据页面选择渲染对应内容 92 | if page == "首页": 93 | home.render() 94 | elif page == "文档管理": 95 | documents.render() 96 | elif page == "聊天问答": 97 | chat.render() 98 | elif page == "搜索": 99 | search.render() 100 | elif page == "数据分析": 101 | analysis.render() 102 | elif page == "系统状态": 103 | system_status.render() 104 | # 侧边栏菜单 105 | # st.sidebar.markdown("## 📚 DeepSeek知识库") 106 | # page = st.sidebar.radio( 
107 | # "导航", 108 | # ["首页", "文档管理", "聊天问答", "搜索", "数据分析", "系统状态"] 109 | # ) 110 | 111 | # # 根据页面选择渲染对应内容 112 | # if page == "首页": 113 | # render_home() 114 | # elif page == "文档管理": 115 | # render_document_manager() 116 | # elif page == "聊天问答": 117 | # render_chat() 118 | # elif page == "搜索": 119 | # render_search() 120 | # elif page == "数据分析": 121 | # render_analysis() 122 | # elif page == "系统状态": 123 | # render_system_status() 124 | 125 | # 首页 126 | def render_home(): 127 | st.markdown('
{subtitle}
', unsafe_allow_html=True) 9 | 10 | def display_info_card(label: str, value: str, delta: Optional[str] = None, color: str = "#EFF6FF"): 11 | """显示信息卡片""" 12 | st.markdown(f'大小: {doc.get('大小', 'Unknown')}
23 |上传时间: {doc.get('上传时间', 'Unknown')}
24 |分类: {doc.get('分类', 'Uncategorized')}
25 |状态: {doc.get('状态', 'Unknown')}
26 |
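
components/ui.py 中的展示组件通过 st.markdown(..., unsafe_allow_html=True) 渲染内联 HTML。下面是一个示意性的最小实现:display_info_card 的签名与文档卡片的字段名取自上文,而 display_document_card 这个函数名以及具体的 HTML 结构均为假设,可能与仓库中的实际实现不同:

```python
# 示意实现:仅说明这两个组件如何渲染上述字段,HTML 结构与样式为假设
from typing import Any, Dict, Optional

import streamlit as st

def display_info_card(label: str, value: str, delta: Optional[str] = None, color: str = "#EFF6FF"):
    """显示信息卡片(示意实现,样式为假设)"""
    delta_html = f"<span>{delta}</span>" if delta else ""
    st.markdown(
        f'<div style="background-color:{color};padding:1rem;border-radius:0.5rem;">'
        f"<p>{label}</p><h3>{value}</h3>{delta_html}</div>",
        unsafe_allow_html=True,
    )

def display_document_card(doc: Dict[str, Any]):
    """显示文档信息卡片(函数名为假设,字段名与上文一致)"""
    st.markdown(
        "<div>"
        f"<p>大小: {doc.get('大小', 'Unknown')}</p>"
        f"<p>上传时间: {doc.get('上传时间', 'Unknown')}</p>"
        f"<p>分类: {doc.get('分类', 'Uncategorized')}</p>"
        f"<p>状态: {doc.get('状态', 'Unknown')}</p>"
        "</div>",
        unsafe_allow_html=True,
    )
```
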