├── .gitignore ├── README.md ├── README_EN.md ├── cmd └── main.go ├── config ├── config.go ├── config.yaml.example └── types.go ├── docker-compose.yml ├── docs ├── AgentChat.png ├── AgentDebug.png ├── CONFIG_README.md └── QUICKSTART.md ├── go.mod ├── go.sum ├── init.sql ├── internal ├── component │ ├── embedding │ │ ├── embedding.go │ │ ├── ollama.go │ │ └── openai.go │ ├── indexer │ │ └── milvus │ │ │ ├── indexer.go │ │ │ └── types.go │ ├── llm │ │ ├── llm.go │ │ ├── ollama │ │ │ ├── call_option.go │ │ │ └── ollama.go │ │ └── openai │ │ │ └── openai.go │ ├── parser │ │ └── pdf │ │ │ └── docconv_parser.go │ └── retriever │ │ └── milvus │ │ ├── multi_retriever.go │ │ ├── retriever.go │ │ └── types.go ├── controller │ ├── agent_controller.go │ ├── conversation_controller.go │ ├── file_controller.go │ ├── kb_controller.go │ ├── model_controller.go │ └── user_controller.go ├── dao │ ├── agent.go │ ├── file_dao.go │ ├── history │ │ ├── conv_dao.go │ │ └── msg_dao.go │ ├── kb_dao.go │ ├── model_dao.go │ └── user_dao.go ├── database │ ├── milvus.go │ └── mysql.go ├── middleware │ ├── auth.go │ ├── cors.go │ └── jwt.go ├── model │ ├── agent.go │ ├── chat.go │ ├── conversation.go │ ├── file.go │ ├── knowledge.go │ ├── model.go │ └── user.go ├── router │ └── router.go ├── service │ ├── agent_service.go │ ├── conversation_service.go │ ├── file_service.go │ ├── history_service.go │ ├── kb_service.go │ ├── model_service.go │ └── user_service.go ├── storage │ ├── local.go │ ├── minio.go │ ├── oss.go │ └── storage.go └── utils │ ├── agent_utils.go │ ├── context.go │ ├── convert.go │ ├── convert_float.go │ ├── hitory_utils.go │ ├── pagination.go │ ├── uuid.go │ └── validate_sort.go └── pkgs ├── consts └── milvus_const.go ├── errcode └── errcode.go └── response └── response.go /.gitignore: -------------------------------------------------------------------------------- 1 | # 编译生成的二进制文件 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | *.out 8 | 9 | # 测试文件 10 | *.test 11 | 
*.out 12 | coverage.txt 13 | 14 | # Go 专用 15 | *.o 16 | *.a 17 | *.so 18 | _obj 19 | _test 20 | *.[568vq] 21 | [568vq].out 22 | *.cgo1.go 23 | *.cgo2.c 24 | _cgo_defun.c 25 | _cgo_gotypes.go 26 | _cgo_export.* 27 | _testmain.go 28 | go.work 29 | 30 | # 依赖管理 31 | vendor/ 32 | Godeps/ 33 | go-build 34 | 35 | # 日志和临时文件 36 | logs/ 37 | *.log 38 | tmp/ 39 | temp/ 40 | 41 | # 特定 IDE 和编辑器 42 | .idea/ 43 | .vscode/ 44 | *.swp 45 | *.swo 46 | *~ 47 | .DS_Store 48 | 49 | # 环境和配置文件(注意不要排除 config 目录) 50 | .env 51 | .envrc 52 | env.sh 53 | .env.local 54 | .env.development.local 55 | .env.test.local 56 | .env.production.local 57 | 58 | # 实际的配置文件(不排除示例文件) 59 | config.yaml 60 | config.yml 61 | config.json 62 | config.toml 63 | config.ini 64 | config.hcl 65 | config.props 66 | config.properties 67 | config/config.yaml 68 | config/config.yml 69 | config/config.json 70 | config/config.toml 71 | 72 | # 明确说明不排除 config 目录和示例文件 73 | !config/ 74 | !*.example 75 | !*.example.* 76 | !*.sample 77 | !*.sample.* 78 | !*.template 79 | !*.template.* 80 | 81 | # 敏感信息和凭证 82 | *.pem 83 | *.key 84 | *.crt 85 | *.p12 86 | *.pfx 87 | *.der 88 | *.csr 89 | *.cer 90 | *.keystore 91 | *.jks 92 | *.p7b 93 | *.p7c 94 | *.secret 95 | *.password 96 | *.credentials 97 | *secret* 98 | *password* 99 | *credential* 100 | *token* 101 | *apikey* 102 | *api_key* 103 | 104 | # 其他可能的敏感文件 105 | *.bak 106 | *.backup 107 | 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI-Cloud-Go 基于Golang开发的LLM应用系统 2 | 3 | [English](README_EN.md) | [中文](README.md) 4 | 5 | ## 项目简介 6 | AI-Cloud-Go 是一个基于 Go 语言实现的 LLM 应用开发系统,提供文件存储、用户管理、知识库管理、模型管理、Agent等功能,采用现代化的技术栈和架构设计。系统支持多种存储后端,并集成了向量数据库以支持智能检索功能。 7 | 8 | 前端界面展示: 9 | 10 | [AI-Cloud-Frontend](https://github.com/RaspberryCola/AI-Cloud-Frontend) 11 | 12 | 13 | ## 功能说明 14 | 15 | ### 部分功能展示 16 | 17 | **Agent配置** 18 | 19 | ![](./docs/AgentDebug.png) 20 | 21 |
**Agent Chat** 22 | ![](./docs/AgentChat.png) 23 | 24 | **已经实现:** 25 | 26 | - [x] 用户模块:支持用户注册、登录、认证 27 | - [x] 多存储后端:支持本地存储、MinIO、阿里云OSS等多种存储方式 28 | - [x] 文件模块:支持文件上传、下载、管理 29 | - [x] 知识库模块:支持创建和管理知识库,支持导入云盘文件或上传新文件 30 | - [x] 模型模块:支持创建和管理自定义LLM模型和Embedding模型 31 | - [x] Agent模块:支持创建和管理Agent 32 | - [x] 支持自定义LLM、知识库、MCP 33 | - [x] 对话界面、历史对话 34 | 35 | **未来优化** 36 | 37 | - 知识库模块:多文件上传/优化解析状态处理/支持Rerank 38 | - Agent模块:自定义Tool(HTTP工具)/优化LLM的参数配置/跨知识库检索的Rerank实现 39 | - 模型管理:添加常用模型预设:OpenAI,DeepSeek,火山引擎等/支持Rerank模型 40 | 41 | ## 技术栈 42 | - 后端框架:Gin 43 | - 数据库:MySQL 44 | - 向量数据库:Milvus 45 | - 对象存储:MinIO/阿里云OSS 46 | - 认证:JWT 47 | - LLM框架:Eino 48 | - 其他:跨域中间件、自定义中间件等 49 | 50 | ## 目录结构 51 | ``` 52 | . 53 | ├── cmd/ # 主程序入口 54 | ├── config/ # 配置文件 55 | ├── docs/ # 文档目录 56 | │ ├── CONFIG_README.md # 配置指南 57 | │ └── QUICKSTART.md # 快速启动指南 58 | ├── internal/ # 内部包 59 | │ ├── component/ # 大模型相关服务 60 | │ ├── controller/ # 控制器层 61 | │ ├── service/ # 业务逻辑层 62 | │ ├── dao/ # 数据访问层 63 | │ ├── middleware/ # 中间件 64 | │ ├── router/ # 路由配置 65 | │ ├── database/ # 数据库(MySQL/Milvus...) 66 | │ ├── model/ # 数据模型 67 | │ ├── storage/ # 后端存储实现(Minio/OSS...) 68 | │ └── utils/ # 工具函数 69 | ├── pkgs/ # 公共包 70 | ├── docker-compose.yml # Docker配置文件 71 | ├── go.mod # Go 模块文件 72 | └── go.sum # 依赖版本锁定文件 73 | ``` 74 | 75 | ## 使用说明 76 | 1. 克隆项目 77 | ```bash 78 | git clone https://github.com/RaspberryCola/AI-Cloud-Go.git 79 | cd AI-Cloud-Go 80 | ``` 81 | 82 | 2. 安装依赖 83 | ```bash 84 | go mod download 85 | ``` 86 | 87 | 3. 配置环境 88 | - 确保已安装并启动 MySQL 89 | - 确保已安装并启动 Milvus(如需使用向量检索功能) 90 | - 配置存储后端(本地存储/MinIO/阿里云OSS) 91 | 92 | 可以通过Docker快速配置: 93 | ```bash 94 | docker-compose up -d 95 | ``` 96 | 4. 修改配置信息 97 | - 修改 `config/config.yaml` 中的相关配置 98 | 99 | 5. 
运行项目 100 | ```bash 101 | go run cmd/main.go 102 | ``` 103 | 104 | # AI-Cloud-Go Docker 完整环境配置 105 | 106 | > 更多详情请看 [docs 文件夹](/docs/) 107 | 108 | ## 前置条件 109 | - 已安装 Docker 和 Docker Compose 110 | - 基本了解 Docker 的使用 111 | 112 | ## 服务组件 113 | 本配置包含以下服务: 114 | - **MySQL**: 数据库服务器,配置 `ai_cloud` 数据库 115 | - **MinIO**: 对象存储服务,兼容 S3 协议,配置 `ai-cloud` 存储桶 116 | - **Milvus**: 向量数据库,用于存储和检索文档向量 117 | 118 | ## 启动步骤 119 | 120 | 1. 环境配置 121 | - 修改 `config/config.yaml` 文件,配置各服务连接信息和LLM的API密钥信息 122 | ```yaml 123 | llm: 124 | api_key: "your-llm-api-key" 125 | model: "deepseek-chat" 126 | base_url: "https://api.deepseek.com/v1" 127 | ``` 128 | ⚠️语言模型配置后续将会移动到统一的模型服务管理中 129 | 2. 启动 Docker 容器 130 | ```bash 131 | docker-compose up -d 132 | ``` 133 | 134 | 3. 检查服务状态 135 | ```bash 136 | docker ps 137 | ``` 138 | 139 | 4. 运行项目 140 | ```bash 141 | go run cmd/main.go 142 | ``` 143 | 144 | 5. 停止容器 145 | ```bash 146 | docker-compose down 147 | ``` 148 | 149 | ## 配置详情 150 | 151 | ### MySQL 152 | - 主机: localhost 153 | - 端口: 3306 154 | - 用户名: root 155 | - 密码: 123456 156 | - 数据库: ai_cloud 157 | 158 | ### MinIO (对象存储) 159 | - 端点: localhost:9000 160 | - 管理控制台: http://localhost:9001 161 | - 访问密钥: minioadmin 162 | - 密钥: minioadmin 163 | - 存储桶: ai-cloud 164 | 165 | ### Milvus (向量数据库) 166 | - 配置路径: `milvus.address` 在 `config.yaml` 中 167 | - 默认地址: localhost:19530 168 | - 管理界面: 需要额外安装Attu (Milvus官方GUI工具) 169 | - 在端口9091只提供监控信息: http://localhost:9091/webui (Milvus 2.5.0+版本) 170 | - 完整管理界面需安装Attu: `docker run -p 8000:3000 -e MILVUS_URL=localhost:19530 zilliz/attu:latest` 171 | - 访问Attu: http://localhost:8000 172 | 173 | ### 环境配置 174 | 项目使用 `config/config.yaml` 文件配置服务连接和第三方 AI 模型的访问,主要包括: 175 | 176 | 1.**语言模型配置** 177 | - 用于知识库问答和智能处理 178 | - 默认使用 DeepSeek 的 deepseek-chat 模型 179 | 180 | 2.**Milvus配置** 181 | - 用于向量存储和检索 182 | - 配置项: `milvus.address` 183 | 184 | ## 故障排除 185 | 186 | ### 常见问题 187 | 188 | 1. 
**程序启动卡住** 189 | - 检查 Milvus 是否正常启动,查看日志 `docker logs milvus-standalone` 190 | - 检查 `config/config.yaml` 文件中的 Milvus 地址配置是否正确 191 | - 检查 `config/config.yaml` 文件是否配置了正确的 API 密钥 192 | - 确保 MySQL 中已创建 ai_cloud 数据库 193 | 194 | 2. **MinIO 连接问题** 195 | - 检查 MinIO 服务是否正常运行 196 | - 验证 config.yaml 中的 MinIO 配置是否与实际运行环境一致 197 | 198 | 3. **向量数据库操作失败** 199 | - 检查 Milvus 服务状态 200 | - 确认向量维度设置是否与模型输出一致 201 | - 确认 config.yaml 中的 milvus.address 配置是否正确 202 | 203 | 4. **Milvus 管理界面访问失败** 204 | - Milvus 2.5以上版本在端口9091提供简易WebUI: http://localhost:9091/webui 205 | - 完整的管理界面需要安装Attu工具: `docker run -p 8000:3000 -e MILVUS_URL=localhost:19530 zilliz/attu:latest` 206 | 207 | ## 开发调试 208 | 209 | ### API 测试 210 | 项目启动后,可通过以下端点进行测试: 211 | - 用户注册: POST http://localhost:8080/api/users/register 212 | - 用户登录: POST http://localhost:8080/api/users/login 213 | - 文件上传: POST http://localhost:8080/api/files/upload 214 | - 更多 API 详见代码中的路由配置 215 | 216 | ### 服务地址 217 | - 应用后端: http://localhost:8080 218 | - MinIO 控制台: http://localhost:9001 219 | - Milvus 管理界面 (Attu): http://localhost:8000,监控 WebUI: http://localhost:9091/webui 220 | - Ollama 服务: http://localhost:11434 (如果启用) 221 | 222 | ## 注意事项 223 | - 如果您已经在本地运行了 MySQL,可以直接使用 `mysql -u root -p < init.sql` 创建 ai_cloud 数据库 224 | - 首次启动时,Milvus 会自动创建必要的集合和索引,可能需要一些时间 225 | - 在生产环境中使用时,请将配置文件中的示例 API 密钥替换为您自己的有效密钥 226 | - 配置 config.yaml 时确保数据库和存储服务的连接信息与 Docker 环境一致 227 | - Go主程序使用端口8080,Ollama服务使用端口11434,避免端口冲突 -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 | # AI-Cloud-Go — A Golang-Based LLM Application System 2 | 3 | [English](README_EN.md) | [中文](README.md) 4 | 5 | ## Overview 6 | AI-Cloud-Go is an LLM application development system built with the Go programming language. It offers file storage, user management, knowledge base management, model management, and Agents, built on a modern technology stack and architecture design.
The system supports multiple storage backends and integrates with a vector database to enable intelligent search capabilities. 7 | 8 | Frontend repository: 9 | 10 | [AI-Cloud-Frontend](https://github.com/RaspberryCola/AI-Cloud-Frontend) 11 | 12 | ## Key Features 13 | 14 | ### Completed: 15 | - [x] User system: Supports registration, login, and authentication 16 | - [x] Cloud file system: Supports file upload, download, and management 17 | - [x] Knowledge base management: Create and manage knowledge bases, import files from cloud or upload new ones 18 | - [x] Model management: Manage custom LLM and Embedding models 19 | - [x] Multiple storage backends: Supports local storage, MinIO, Alibaba Cloud OSS, etc. 20 | - [x] Vector search: Integrated with Milvus vector database for smart document retrieval 21 | - [x] Agents: Create and manage Agents with custom LLMs, knowledge bases and MCP, plus a chat interface with conversation history 22 | ### Planned Improvements: 23 | - [ ] Knowledge base: multi-file upload, better parse-status handling, Rerank; Agent: custom HTTP tools, better LLM parameter configuration, cross-knowledge-base Rerank; Model management: presets for OpenAI, DeepSeek, Volcano Engine, and Rerank model support 24 | 25 | ## Tech Stack 26 | - Backend Framework: Gin 27 | - Database: MySQL 28 | - Vector Database: Milvus 29 | - Object Storage: MinIO / Alibaba Cloud OSS 30 | - Authentication: JWT (JSON Web Token) 31 | - LLM Framework: Eino 32 | - Others: CORS middleware, custom middleware, etc. 33 | 34 | ## Directory Structure 35 | ``` 36 | . 37 | ├── cmd/ # Main application entry point 38 | ├── config/ # Configuration files 39 | ├── docs/ # Documentation 40 | │ ├── CONFIG_README.md # Configuration guide 41 | │ └── QUICKSTART.md # Quick start guide 42 | ├── internal/ # Internal packages 43 | │ ├── component/ # Large model related services 44 | │ ├── controller/ # Controller layer 45 | │ ├── service/ # Business logic layer 46 | │ ├── dao/ # Data access layer 47 | │ ├── middleware/ # Middleware components 48 | │ ├── router/ # Routing configuration 49 | │ ├── database/ # Database (MySQL/Milvus...) 50 | │ ├── model/ # Data models 51 | │ ├── storage/ # Storage backend implementations (MinIO/OSS...)
52 | │ └── utils/ # Utility functions 53 | ├── pkgs/ # Shared packages 54 | ├── docker-compose.yml # Docker configuration 55 | ├── go.mod # Go module file 56 | └── go.sum # Dependency version lock 57 | ``` 58 | 59 | ## Usage Instructions 60 | 61 | 1. Clone the project: 62 | ```bash 63 | git clone https://github.com/RaspberryCola/AI-Cloud-Go.git 64 | cd AI-Cloud-Go 65 | ``` 66 | 67 | 2. Install dependencies: 68 | ```bash 69 | go mod download 70 | ``` 71 | 72 | 3. Environment setup: 73 | - Ensure MySQL is installed and running 74 | - Ensure Milvus is installed if using vector search 75 | - Configure storage backend (local/MinIO/Alibaba Cloud OSS) 76 | 77 | You can use Docker Compose for quick setup: 78 | ```bash 79 | docker-compose up -d 80 | ``` 81 | 82 | 4. Modify configuration: 83 | - Update `config/config.yaml` accordingly 84 | 85 | 5. Run the project: 86 | ```bash 87 | go run cmd/main.go 88 | ``` 89 | 90 | # AI-Cloud-Go — Docker Full Environment Setup 91 | 92 | > For more details, see the [docs folder](/docs/) 93 | 94 | ## Prerequisites 95 | - Docker and Docker Compose installed 96 | - Basic understanding of Docker usage 97 | 98 | ## Services Included 99 | This setup includes the following services: 100 | - **MySQL**: Database server, configured with the `ai_cloud` database 101 | - **MinIO**: Object storage service compatible with S3 protocol, configured with bucket `ai-cloud` 102 | - **Milvus**: Vector database for storing and retrieving document vectors 103 | 104 | ## Setup Steps 105 | 106 | 1. **Environment Configuration** 107 | - Edit `config/config.yaml` to configure service connection info and LLM API keys 108 | ```yaml 109 | llm: 110 | api_key: "your-llm-api-key" 111 | model: "deepseek-chat" 112 | base_url: "https://api.deepseek.com/v1" 113 | ``` 114 | ⚠️ LLM configuration will be moved into unified model management in the future 115 | 116 | 2. **Start Docker Containers** 117 | ```bash 118 | docker-compose up -d 119 | ``` 120 | 121 | 3. 
**Check Service Status** 122 | ```bash 123 | docker ps 124 | ``` 125 | 126 | 4. **Run the Application** 127 | ```bash 128 | go run cmd/main.go 129 | ``` 130 | 131 | 5. **Stop Containers** 132 | ```bash 133 | docker-compose down 134 | ``` 135 | 136 | ## Configuration Details 137 | 138 | ### MySQL 139 | - Host: localhost 140 | - Port: 3306 141 | - Username: root 142 | - Password: 123456 143 | - Database: ai_cloud 144 | 145 | ### MinIO (Object Storage) 146 | - Endpoint: localhost:9000 147 | - Management Console: http://localhost:9001 148 | - Access Key: minioadmin 149 | - Secret Key: minioadmin 150 | - Bucket: ai-cloud 151 | 152 | ### Milvus (Vector Database) 153 | - Config path: `milvus.address` in `config.yaml` 154 | - Default address: localhost:19530 155 | - Admin UI: Requires installing Attu (official Milvus GUI tool) 156 | - Lightweight web UI on port 9091: http://localhost:9091/webui (for Milvus 2.5.0+) 157 | - Full admin interface via Attu: 158 | ```bash 159 | docker run -p 8000:3000 -e MILVUS_URL=localhost:19530 zilliz/attu:latest 160 | ``` 161 | - Visit Attu at: http://localhost:8000 162 | 163 | ### Environment Configuration 164 | The system uses `config/config.yaml` to configure service connections and AI model access, including: 165 | 166 | 1. **Language Model Configuration** 167 | - Used for Q&A and intelligent processing 168 | - Default: DeepSeek's `deepseek-chat` model 169 | 170 | 2. **Milvus Configuration** 171 | - For vector storage and retrieval 172 | - Config: `milvus.address` 173 | 174 | ## Troubleshooting 175 | 176 | ### Common Issues 177 | 178 | 1. **Application Hangs on Startup** 179 | - Check if Milvus started correctly: `docker logs milvus-standalone` 180 | - Confirm Milvus address in `config.yaml` is correct 181 | - Ensure valid API key is set in config.yaml 182 | - Make sure the `ai_cloud` database exists in MySQL 183 | 184 | 2. 
**MinIO Connection Problems** 185 | - Verify MinIO is running 186 | - Confirm config.yaml matches actual MinIO environment settings 187 | 188 | 3. **Vector DB Operations Fail** 189 | - Check Milvus service status 190 | - Ensure vector dimension matches model output 191 | - Confirm `milvus.address` is properly set 192 | 193 | 4. **Cannot Access Milvus Web UI** 194 | - Milvus 2.5+ provides basic UI at http://localhost:9091/webui 195 | - For full UI, install Attu with: 196 | ```bash 197 | docker run -p 8000:3000 -e MILVUS_URL=localhost:19530 zilliz/attu:latest 198 | ``` 199 | - Access at: http://localhost:8000 200 | 201 | ## Development and Debugging 202 | 203 | ### API Testing 204 | After starting the app, you can test APIs like: 205 | - User Registration: POST http://localhost:8080/api/users/register 206 | - User Login: POST http://localhost:8080/api/users/login 207 | - File Upload: POST http://localhost:8080/api/files/upload 208 | - See code for more endpoints 209 | 210 | ### Service URLs 211 | - App Backend: http://localhost:8080 212 | - MinIO Console: http://localhost:9001 213 | - Milvus Admin UI: http://localhost:9091 214 | - Ollama Service: http://localhost:11434 (if enabled) 215 | 216 | ## Notes 217 | - If you already have MySQL running locally, create the db via: 218 | ```bash 219 | mysql -u root -p < init.sql 220 | ``` 221 | - Milvus automatically creates collections/indexes on first launch (may take time) 222 | - Replace sample API keys in config.yaml with your own credentials before production 223 | - Ensure config.yaml reflects Docker service addresses 224 | - Go server uses port 8080; Ollama uses 11434 – avoid conflicts 225 | 226 | 227 | ## Third-Party Licenses 228 | 229 | This project uses the following third-party libraries: 230 | 231 | - **eino-history** 232 | - Source: [https://github.com/HildaM/eino-history ](https://github.com/HildaM/eino-history ) 233 | - License: [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0 ) 234 | - Description: A 
chat history management library for Eino large model framework. 235 | 236 | Modifications have been made to the original code to suit the needs of this project. -------------------------------------------------------------------------------- /cmd/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "ai-cloud/config" 5 | _ "ai-cloud/internal/component/embedding" 6 | "ai-cloud/internal/controller" 7 | "ai-cloud/internal/dao" 8 | "ai-cloud/internal/dao/history" 9 | "ai-cloud/internal/database" 10 | "ai-cloud/internal/middleware" 11 | "ai-cloud/internal/router" 12 | "ai-cloud/internal/service" 13 | "context" 14 | 15 | "github.com/gin-gonic/gin" 16 | ) 17 | 18 | func main() { 19 | config.InitConfig() 20 | ctx := context.Background() 21 | 22 | db, _ := database.GetDB() 23 | 24 | userDao := dao.NewUserDao(db) 25 | userService := service.NewUserService(userDao) 26 | userController := controller.NewUserController(userService) 27 | fileDao := dao.NewFileDao(db) 28 | fileService := service.NewFileService(fileDao) 29 | fileController := controller.NewFileController(fileService) 30 | 31 | milvusClient, _ := database.InitMilvus(ctx) 32 | defer milvusClient.Close() 33 | 34 | modelDao := dao.NewModelDao(db) 35 | modelService := service.NewModelService(modelDao) 36 | modelController := controller.NewModelController(modelService) 37 | 38 | kbDao := dao.NewKnowledgeBaseDao(db) 39 | kbService := service.NewKBService(kbDao, fileService, modelDao) 40 | kbController := controller.NewKBController(kbService, fileService) 41 | 42 | msgDao := history.NewMsgDao(db) 43 | convDao := history.NewConvDao(db) 44 | historyService := service.NewHistoryService(convDao, msgDao) 45 | 46 | agentDao := dao.NewAgentDao(db) 47 | agentService := service.NewAgentService(agentDao, modelService, kbService, kbDao, modelDao, historyService) 48 | agentController := controller.NewAgentController(agentService) 49 | 50 | // 
创建ConversationService和ConversationController 51 | conversationService := service.NewConversationService(agentService, historyService) 52 | conversationController := controller.NewConversationController(conversationService) 53 | 54 | r := gin.Default() 55 | // 配置跨域 56 | r.Use(middleware.SetupCORS()) 57 | // 配置路由 58 | router.SetUpRouters(r, userController, fileController, kbController, modelController, agentController, conversationController) 59 | 60 | r.Run(":8080") 61 | } 62 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/fsnotify/fsnotify" 7 | "github.com/spf13/viper" 8 | ) 9 | 10 | var AppConfigInstance *AppConfig 11 | 12 | // InitConfig 初始化配置 13 | func InitConfig() { 14 | // 初始化 AppConfigInstance 15 | AppConfigInstance = &AppConfig{} 16 | 17 | // 加载配置 18 | viper.SetConfigName("config") 19 | viper.SetConfigType("yaml") 20 | viper.AddConfigPath("./config") 21 | 22 | if err := viper.ReadInConfig(); err != nil { 23 | log.Fatalf("Error reading config file: %v", err) 24 | } 25 | 26 | // 监听配置变化 27 | viper.WatchConfig() 28 | viper.OnConfigChange(func(e fsnotify.Event) { 29 | if err := viper.Unmarshal(AppConfigInstance); err != nil { 30 | log.Printf("loadConfig failed, unmarshal config err: %v", err) 31 | } 32 | }) 33 | 34 | // 解析配置 35 | if err := viper.Unmarshal(AppConfigInstance); err != nil { 36 | log.Fatalf("Unable to decode into struct: %v", err) 37 | } 38 | } 39 | 40 | // GetConfig 获取配置 41 | func GetConfig() *AppConfig { 42 | return AppConfigInstance 43 | } 44 | -------------------------------------------------------------------------------- /config/config.yaml.example: -------------------------------------------------------------------------------- 1 | server: 2 | port: "8080" 3 | 4 | # mysql配置 5 | database: 6 | host: "localhost" 7 | port: "3306" 8 | user: "root" 9 | 
password: "123456" 10 | name: "ai_cloud" 11 | 12 | jwt: 13 | secret: "your-jwt-secret" 14 | expiration_hours: 24 15 | 16 | # 后端文件存储服务 17 | storage: 18 | type: "minio" # local, oss, minio 19 | local: 20 | base_dir: "./storage_data" 21 | oss: 22 | endpoint: "your-oss-endpoint" 23 | bucket: "your-oss-bucket" 24 | access_key_id: "your-access-key-id" 25 | access_key_secret: "your-access-key-secret" 26 | minio: 27 | endpoint: "localhost:9000" 28 | bucket: "ai-cloud" 29 | access_key_id: "minioadmin" 30 | access_key_secret: "minioadmin" 31 | use_ssl: false 32 | region: "" 33 | 34 | # Milvus向量数据库配置 35 | milvus: 36 | address: "localhost:19530" 37 | # collection_name: "text_chunks" 38 | # vector_dimension: 1024 39 | index_type: "IVF_FLAT" 40 | metric_type: "COSINE" 41 | nlist: 128 42 | # 搜索参数 43 | nprobe: 16 44 | # 字段最大长度配置 45 | id_max_length: "64" 46 | content_max_length: "65535" 47 | doc_id_max_length: "64" 48 | doc_name_max_length: "256" 49 | kb_id_max_length: "64" 50 | 51 | rag: 52 | chunk_size: 1500 53 | overlap_size: 500 54 | 55 | cors: 56 | allow_origins: 57 | - "*" 58 | allow_methods: 59 | - "GET" 60 | - "POST" 61 | - "PUT" 62 | - "PATCH" 63 | - "DELETE" 64 | - "OPTIONS" 65 | allow_headers: 66 | - "Origin" 67 | - "Content-Type" 68 | - "Accept" 69 | - "Authorization" 70 | expose_headers: 71 | - "Content-Length" 72 | allow_credentials: true 73 | max_age: "12h" 74 | 75 | ## LLM配置 76 | llm: 77 | api_key: "your-llm-api-key" 78 | model: "deepseek-chat" 79 | base_url: "https://api.deepseek.com/v1" 80 | max_tokens: 10240 81 | temperature: 0.7 -------------------------------------------------------------------------------- /config/types.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/milvus-io/milvus-sdk-go/v2/entity" 5 | ) 6 | 7 | // ServerConfig holds the HTTP server settings. 8 | type ServerConfig struct { 9 | Port string `mapstructure:"port"` 10 | } 11 | 12 | // DatabaseConfig holds the MySQL connection settings. 13 | type DatabaseConfig
struct { 14 | Host string `mapstructure:"host"` 15 | Port string `mapstructure:"port"` 16 | User string `mapstructure:"user"` 17 | Password string `mapstructure:"password"` 18 | Name string `mapstructure:"name"` 19 | } 20 | 21 | // JWTConfig holds the JWT secret and token expiration (in hours). 22 | type JWTConfig struct { 23 | Secret string `mapstructure:"secret"` 24 | ExpirationHours int `mapstructure:"expiration_hours"` 25 | } 26 | 27 | // MinioConfig holds the MinIO object-storage settings. 28 | type MinioConfig struct { 29 | Endpoint string `mapstructure:"endpoint"` 30 | Bucket string `mapstructure:"bucket"` 31 | AccessKeyID string `mapstructure:"access_key_id"` 32 | AccessKeySecret string `mapstructure:"access_key_secret"` 33 | UseSSL bool `mapstructure:"use_ssl"` 34 | Region string `mapstructure:"region"` 35 | } 36 | 37 | // MilvusConfig holds the Milvus vector database connection, index, search and field-length settings. 38 | type MilvusConfig struct { 39 | Address string `mapstructure:"address"` 40 | CollectionName string `mapstructure:"collection_name"` 41 | VectorDimension int `mapstructure:"vector_dimension"` 42 | IndexType string `mapstructure:"index_type"` 43 | MetricType string `mapstructure:"metric_type"` 44 | Nlist int `mapstructure:"nlist"` 45 | // Search parameters 46 | Nprobe int `mapstructure:"nprobe"` 47 | // Maximum field lengths (kept as strings, as written in config.yaml) 48 | IDMaxLength string `mapstructure:"id_max_length"` 49 | ContentMaxLength string `mapstructure:"content_max_length"` 50 | DocIDMaxLength string `mapstructure:"doc_id_max_length"` 51 | DocNameMaxLength string `mapstructure:"doc_name_max_length"` 52 | KbIDMaxLength string `mapstructure:"kb_id_max_length"` 53 | } 54 | 55 | // GetMetricType maps the configured MetricType ("L2" or "IP") to a Milvus metric type; any other value, including empty, yields COSINE. 56 | func (m *MilvusConfig) GetMetricType() entity.MetricType { 57 | // Resolve the configured distance metric 58 | var metricType entity.MetricType 59 | switch m.MetricType { 60 | case "L2": 61 | metricType = entity.L2 // Euclidean distance: straight-line distance between vectors; suits numeric features such as image embeddings 62 | case "IP": 63 | metricType = entity.IP // Inner product: suits normalized vectors and is cheap to compute 64 | default: 65 | metricType = entity.COSINE // Cosine similarity: compares vector direction; suits semantic text search 66 | } 67 | return metricType 68 | } 69 | 70 | // GetMilvusIndex builds the Milvus index selected by IndexType (IVF_FLAT, IVF_SQ8, HNSW; anything else falls back to IVF_FLAT). NOTE(review): when Nlist <= 0 it writes the default 128 back into m.Nlist, mutating the receiver (and the shared config when called on it) - confirm this side effect is intended. 71 | func (m
*MilvusConfig) GetMilvusIndex() (entity.Index, error) { 72 | // Pick the distance metric for the index 73 | metricType := m.GetMetricType() 74 | 75 | // Build the index 76 | var ( 77 | idx entity.Index 78 | err error 79 | ) 80 | if m.Nlist <= 0 { 81 | m.Nlist = 128 // unset or invalid: fall back to the default 82 | } 83 | 84 | switch m.IndexType { 85 | case "IVF_FLAT": 86 | // IVF_FLAT: inverted-file index that stores the raw vectors 87 | // Pros: high search accuracy; cons: larger memory footprint 88 | // nlist: number of clusters; larger values are more accurate but slower, typically sqrt(n) to 4*sqrt(n) where n is the number of vectors 89 | idx, err = entity.NewIndexIvfFlat(metricType, m.Nlist) 90 | case "IVF_SQ8": 91 | // IVF_SQ8: inverted-file index with 8-bit scalar-quantized (compressed) vectors 92 | // Pros: uses less memory than IVF_FLAT; cons: slight loss of accuracy 93 | // nlist: same as for IVF_FLAT; tune it to the data size 94 | idx, err = entity.NewIndexIvfSQ8(metricType, m.Nlist) 95 | case "HNSW": 96 | // HNSW: hierarchical navigable small-world graph index; fast and accurate but memory-hungry 97 | // M: maximum number of edges per node; affects graph connectivity and build/query performance 98 | // - larger values build more slowly and use more memory, but query more accurately 99 | // - typically 8-64; the value 8 balances performance and accuracy in most scenarios 100 | // efConstruction: number of candidate neighbours searched per layer while building the index 101 | // - larger values build more slowly but produce a higher-quality index 102 | // - typically 40-800; the value 40 works well in most scenarios 103 | // Note: both parameters should be tuned to the data and performance requirements; empirical values are used for now 104 | idx, err = entity.NewIndexHNSW(metricType, 8, 40) // M=8, efConstruction=40 105 | default: 106 | // Unknown or empty IndexType: default to IVF_FLAT, balancing search accuracy and performance 107 | idx, err = entity.NewIndexIvfFlat(metricType, m.Nlist) 108 | } 109 | return idx, err 110 | } 111 | 112 | // StorageConfig selects and configures the file storage backend. 113 | type StorageConfig struct { 114 | Type string `mapstructure:"type"` // local/oss/minio 115 | Local LocalConfig `mapstructure:"local"` 116 | OSS OSSConfig `mapstructure:"oss"` 117 | Minio MinioConfig `mapstructure:"minio"` 118 | } 119 | 120 | // LocalConfig configures local-disk storage. 121 | type LocalConfig struct { 122 | BaseDir string `mapstructure:"base_dir"` // root directory for local storage (e.g. /data/storage) 123 | } 124 | 125 | // OSSConfig configures Alibaba Cloud OSS storage. 126 | type OSSConfig struct { 127 | Endpoint string `mapstructure:"endpoint"` 128 | Bucket string `mapstructure:"bucket"` 129 | AccessKeyID string `mapstructure:"access_key_id"` 130 | AccessKeySecret string `mapstructure:"access_key_secret"` 131 | } 132 | 133 | // CORSConfig holds the cross-origin (CORS) settings. 134 | type CORSConfig struct { 135
| AllowOrigins []string `mapstructure:"allow_origins"` 136 | AllowMethods []string `mapstructure:"allow_methods"` 137 | AllowHeaders []string `mapstructure:"allow_headers"` 138 | ExposeHeaders []string `mapstructure:"expose_headers"` 139 | AllowCredentials bool `mapstructure:"allow_credentials"` 140 | MaxAge string `mapstructure:"max_age"` // duration kept as a string (e.g. "12h") so it is easy to write in config 141 | } 142 | 143 | // RAGConfig holds the RAG chunk size and overlap size. 144 | type RAGConfig struct { 145 | ChunkSize int `mapstructure:"chunk_size"` 146 | OverlapSize int `mapstructure:"overlap_size"` 147 | } 148 | 149 | // LLMConfig holds the language model API settings. 150 | type LLMConfig struct { 151 | APIKey string `mapstructure:"api_key"` 152 | Model string `mapstructure:"model"` 153 | BaseURL string `mapstructure:"base_url"` 154 | MaxTokens int `mapstructure:"max_tokens"` 155 | Temperature float32 `mapstructure:"temperature"` 156 | } 157 | 158 | // AppConfig is the root application configuration decoded from config.yaml. 159 | type AppConfig struct { 160 | Server ServerConfig `mapstructure:"server"` 161 | Database DatabaseConfig `mapstructure:"database"` 162 | JWT JWTConfig `mapstructure:"jwt"` 163 | Storage StorageConfig `mapstructure:"storage"` 164 | CORS CORSConfig `mapstructure:"cors"` 165 | RAG RAGConfig `mapstructure:"rag"` 166 | LLM LLMConfig `mapstructure:"llm"` 167 | Milvus MilvusConfig `mapstructure:"milvus"` 168 | } 169 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | # MySQL service - create ai_cloud database 5 | mysql: 6 | image: mysql:8.0 7 | ports: 8 | - "3306:3306" 9 | volumes: 10 | - mysql_data:/var/lib/mysql 11 | - ./init.sql:/docker-entrypoint-initdb.d/init.sql 12 | environment: 13 | MYSQL_ROOT_PASSWORD: 123456 14 | MYSQL_DATABASE: ai_cloud 15 | networks: 16 | - ai-cloud-network 17 | restart: always 18 | 19 | # MinIO service 20 | minio: 21 | image: minio/minio 22 | ports: 23 | - "9000:9000" 24 | - "9001:9001" 25 | volumes: 26 | -
minio_data:/data 27 | environment: 28 | MINIO_ROOT_USER: minioadmin 29 | MINIO_ROOT_PASSWORD: minioadmin 30 | command: server /data --console-address ":9001" 31 | networks: 32 | - ai-cloud-network 33 | healthcheck: 34 | test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] 35 | interval: 30s 36 | timeout: 20s 37 | retries: 3 38 | 39 | # Create MinIO bucket 40 | minio-init: 41 | image: minio/mc 42 | depends_on: 43 | - minio 44 | entrypoint: > 45 | /bin/sh -c " 46 | /usr/bin/mc config host add myminio http://minio:9000 minioadmin minioadmin; 47 | /usr/bin/mc mb myminio/ai-cloud; 48 | exit 0; 49 | " 50 | networks: 51 | - ai-cloud-network 52 | 53 | # Milvus standalone service 54 | etcd: 55 | container_name: milvus-etcd 56 | image: quay.io/coreos/etcd:v3.5.5 57 | environment: 58 | - ETCD_AUTO_COMPACTION_MODE=revision 59 | - ETCD_AUTO_COMPACTION_RETENTION=1000 60 | - ETCD_QUOTA_BACKEND_BYTES=4294967296 61 | - ETCD_SNAPSHOT_COUNT=50000 62 | volumes: 63 | - etcd_data:/etcd 64 | command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd 65 | networks: 66 | - ai-cloud-network 67 | 68 | minio-for-milvus: 69 | container_name: milvus-minio 70 | image: minio/minio:RELEASE.2023-03-20T20-16-18Z 71 | environment: 72 | MINIO_ACCESS_KEY: minioadmin 73 | MINIO_SECRET_KEY: minioadmin 74 | volumes: 75 | - minio_data_for_milvus:/minio_data 76 | command: minio server /minio_data 77 | healthcheck: 78 | test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] 79 | interval: 30s 80 | timeout: 20s 81 | retries: 3 82 | networks: 83 | - ai-cloud-network 84 | 85 | standalone: 86 | container_name: milvus-standalone 87 | image: milvusdb/milvus:v2.4.1 88 | command: ["milvus", "run", "standalone"] 89 | environment: 90 | ETCD_ENDPOINTS: etcd:2379 91 | MINIO_ADDRESS: minio-for-milvus:9000 92 | volumes: 93 | - milvus_data:/var/lib/milvus 94 | ports: 95 | - "19530:19530" 96 | - "9091:9091" 97 | depends_on: 98 | - 
etcd 99 | - minio-for-milvus 100 | networks: 101 | - ai-cloud-network 102 | 103 | # Attu - Milvus管理界面 104 | attu: 105 | container_name: milvus-attu 106 | image: zilliz/attu:latest 107 | environment: 108 | MILVUS_URL: standalone:19530 109 | ports: 110 | - "8000:3000" 111 | depends_on: 112 | - standalone 113 | networks: 114 | - ai-cloud-network 115 | 116 | volumes: 117 | mysql_data: 118 | minio_data: 119 | etcd_data: 120 | minio_data_for_milvus: 121 | milvus_data: 122 | 123 | networks: 124 | ai-cloud-network: 125 | driver: bridge -------------------------------------------------------------------------------- /docs/AgentChat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RaspberryCola/AI-Cloud-Go/aa509c7429f0a408319274896c2b8af8ca01e5be/docs/AgentChat.png -------------------------------------------------------------------------------- /docs/AgentDebug.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RaspberryCola/AI-Cloud-Go/aa509c7429f0a408319274896c2b8af8ca01e5be/docs/AgentDebug.png -------------------------------------------------------------------------------- /docs/CONFIG_README.md: -------------------------------------------------------------------------------- 1 | # AI-Cloud-Go 配置指南 2 | 3 | ### 配置文件结构 4 | 5 | 配置文件(`config/config.yaml`)结构如下: 6 | 7 | ```yaml 8 | server: 9 | port: "8080" 10 | 11 | database: 12 | host: "localhost" 13 | port: "3306" 14 | user: "root" 15 | password: "123456" 16 | name: "ai_cloud" 17 | 18 | jwt: 19 | secret: "your-jwt-secret" 20 | expiration_hours: 24 21 | 22 | storage: 23 | type: "minio" # local, oss, minio 24 | local: 25 | base_dir: "./storage_data" 26 | oss: 27 | endpoint: "oss-endpoint" 28 | bucket: "bucket-name" 29 | access_key_id: "" 30 | access_key_secret: "" 31 | minio: 32 | endpoint: "localhost:9000" 33 | bucket: "ai-cloud" 34 | access_key_id: "minioadmin" 35 | access_key_secret: 
"minioadmin" 36 | use_ssl: false 37 | region: "" 38 | 39 | # Milvus向量数据库配置 40 | milvus: 41 | address: "localhost:19530" # Milvus服务地址 42 | index_type: "IVF_FLAT" # 索引类型 (IVF_FLAT, IVF_SQ8, HNSW) 43 | metric_type: "COSINE" # 距离计算方式 (COSINE, L2, IP) 44 | nlist: 128 # IVF索引聚类数量 45 | # 搜索参数 46 | nprobe: 16 # 搜索时检查的聚类数量,值越大结果越精确但越慢 47 | # 字段最大长度配置 48 | id_max_length: "64" # ID字段最大长度 49 | content_max_length: "65535" # 内容字段最大长度 50 | doc_id_max_length: "64" # 文档ID字段最大长度 51 | doc_name_max_length: "256" # 文档名称字段最大长度 52 | kb_id_max_length: "64" # 知识库ID字段最大长度 53 | 54 | rag: 55 | chunk_size: 1500 56 | overlap_size: 500 57 | 58 | cors: 59 | # CORS配置... 60 | 61 | # 语言模型配置(后续移除,通过统一的模块管理) 62 | llm: 63 | api_key: "your-llm-api-key" 64 | model: "deepseek-chat" 65 | base_url: "https://api.deepseek.com/v1" 66 | max_tokens: 4096 67 | temperature: 0.7 68 | ``` 69 | 70 | ## 配置使用 71 | 72 | 在Go代码中,可以通过以下方式访问配置: 73 | 74 | ```go 75 | import "ai-cloud/config" 76 | 77 | func main() { 78 | // 初始化配置(在应用启动时调用一次) 79 | config.InitConfig() 80 | 81 | // 获取配置实例 82 | cfg := config.GetConfig() 83 | 84 | // 访问配置项 85 | port := cfg.Server.Port 86 | embeddingService := cfg.Embedding.Service 87 | llmModel := cfg.LLM.Model 88 | milvusAddress := cfg.Milvus.Address 89 | 90 | // ... 
91 | } 92 | ``` 93 | 94 | ## Milvus向量数据库配置 95 | 96 | 配置Milvus服务连接地址和向量集合参数: 97 | 98 | ```yaml 99 | # Milvus向量数据库配置 100 | milvus: 101 | address: "localhost:19530" # Milvus服务地址 102 | index_type: "IVF_FLAT" # 索引类型 (IVF_FLAT, IVF_SQ8, HNSW) 103 | metric_type: "COSINE" # 距离计算方式 (COSINE, L2, IP) 104 | nlist: 128 # IVF索引聚类数量 105 | # 搜索参数 106 | nprobe: 16 # 搜索时检查的聚类数量,值越大结果越精确但越慢 107 | # 字段最大长度配置 108 | id_max_length: "64" # ID字段最大长度 109 | content_max_length: "65535" # 内容字段最大长度 110 | doc_id_max_length: "64" # 文档ID字段最大长度 111 | doc_name_max_length: "256" # 文档名称字段最大长度 112 | kb_id_max_length: "64" # 知识库ID字段最大长度 113 | ``` 114 | 115 | 此配置在初始化Milvus客户端和创建集合时使用: 116 | 117 | ```go 118 | // 初始化Milvus客户端 119 | milvusClient, err := client.NewClient(ctx, client.Config{ 120 | Address: config.GetConfig().Milvus.Address, 121 | }) 122 | 123 | // 使用配置创建集合 124 | milvusConfig := config.GetConfig().Milvus 125 | address := milvusConfig.Address 126 | // ... 127 | ``` 128 | 129 | ## 语言模型配置 130 | 131 | 配置LLM服务: 132 | 133 | ```yaml 134 | llm: 135 | api_key: "your-api-key" 136 | model: "deepseek-chat" # 或其他支持的模型 137 | base_url: "https://api.deepseek.com/v1" 138 | max_tokens: 4096 139 | temperature: 0.7 140 | ``` 141 | 142 | ## 从环境变量迁移到配置文件 143 | 144 | 如果您之前使用`.env`文件配置项目,请按照以下对应关系迁移到`config.yaml`: 145 | 146 | | 环境变量 | 配置文件路径 | 147 | |---------|------------| 148 | | `LLM_API_KEY` | `llm.api_key` | 149 | | `LLM_MODEL` | `llm.model` | 150 | | `LLM_BASE_URL` | `llm.base_url` | 151 | | `MILVUS_ADDRESS` | `milvus.address` | 152 | 153 | ## 注意事项 154 | 155 | 1. 配置文件中的敏感信息(如API密钥)不应提交到版本控制系统 156 | 2. 可以考虑使用环境变量覆盖配置文件中的敏感信息 157 | 3. 
为不同环境(开发、测试、生产)准备不同的配置文件 -------------------------------------------------------------------------------- /docs/QUICKSTART.md: -------------------------------------------------------------------------------- 1 | # AI-Cloud-Go 快速启动指南 2 | 3 | 这份指南将帮助您快速设置并运行AI-Cloud-Go系统,包括各种支持的嵌入服务配置。 4 | 5 | ## 前置条件 6 | 7 | - Go 1.23.4+(与 `go.mod` 中声明的 `go 1.23.4` 一致) 8 | - Docker和Docker Compose 9 | - Git 10 | - 确保以下端口未被占用: 8080, 3306, 9000, 9001, 19530, 9091, 8000 (Attu), 11434 11 | 12 | ## 步骤一:获取代码 13 | 14 | ```bash 15 | git clone https://github.com/RaspberryCola/AI-Cloud-Go.git 16 | cd AI-Cloud-Go 17 | ``` 18 | 19 | ## 步骤二:设置环境 20 | 21 | ### 配置系统参数 22 | 23 | 1. 确保项目根目录存在`config`文件夹,如果不存在请创建 24 | 2. 在`config`文件夹中创建或修改`config.yaml`文件: 25 | 26 | ```yaml 27 | server: 28 | port: "8080" 29 | 30 | database: 31 | host: "localhost" 32 | port: "3306" 33 | user: "root" 34 | password: "123456" 35 | name: "ai_cloud" 36 | 37 | storage: 38 | type: "minio" # local, oss, minio 39 | minio: 40 | endpoint: "localhost:9000" 41 | bucket: "ai-cloud" 42 | access_key_id: "minioadmin" 43 | access_key_secret: "minioadmin" 44 | use_ssl: false 45 | region: "" 46 | 47 | # Milvus向量数据库配置 48 | milvus: 49 | address: "localhost:19530" 50 | index_type: "IVF_FLAT" 51 | metric_type: "COSINE" 52 | nlist: 128 53 | # 搜索参数 54 | nprobe: 16 55 | # 字段最大长度配置 56 | id_max_length: "64" 57 | content_max_length: "65535" 58 | doc_id_max_length: "64" 59 | doc_name_max_length: "256" 60 | kb_id_max_length: "64" 61 | 62 | # 语言模型配置(后续会移除到统一的模型管理中) 63 | llm: 64 | api_key: "your-llm-api-key" # 替换为您的语言模型API密钥 65 | model: "deepseek-chat" 66 | base_url: "https://api.deepseek.com/v1" 67 | max_tokens: 4096 68 | temperature: 0.7 69 | 70 | rag: 71 | chunk_size: 1500 72 | overlap_size: 500 73 | ``` 74 | 75 | ## 步骤三:启动基础服务 76 | 77 | ### 启动所有Docker服务 78 | 79 | 确保项目根目录包含`docker-compose.yml`文件,然后运行: 80 | 81 | ```bash 82 | # 启动MySQL, MinIO, Milvus服务 83 | docker-compose up -d 84 | ``` 85 | 86 | 或者,如果您只想启动特定服务: 87 | 88 | ```bash 89 | # 仅启动MySQL和MinIO 90 | docker-compose up -d mysql-init minio
minio-init 91 | 92 | # 仅启动Milvus相关服务 93 | docker-compose up -d etcd minio-for-milvus standalone 94 | ``` 95 | 96 | ### 检查服务状态 97 | 98 | ```bash 99 | # 列出所有容器 100 | docker ps 101 | 102 | # 检查MySQL是否正常初始化 103 | docker logs mysql-init 104 | 105 | # 检查MinIO是否正常启动 106 | curl http://localhost:9000 107 | 108 | # 检查Milvus是否正常启动 109 | docker logs milvus-standalone 110 | ``` 111 | 112 | 确保所有服务都正常运行,没有错误信息。 113 | 114 | ## 步骤四:设置Ollama (如果使用Ollama作为嵌入服务) 115 | 116 | 如果您选择使用Ollama作为嵌入服务: 117 | 118 | 1. 从[Ollama官网](https://ollama.com/download)下载并安装Ollama 119 | 120 | 2. 拉取嵌入模型: 121 | ```bash 122 | ollama pull mxbai-embed-large 123 | ``` 124 | 125 | 3. 启动Ollama服务: 126 | ```bash 127 | OLLAMA_HOST="0.0.0.0" OLLAMA_ORIGINS="*" ollama serve 128 | ``` 129 | 130 | 您也可以创建一个启动脚本`start-ollama.sh`: 131 | ```bash 132 | #!/bin/bash 133 | OLLAMA_HOST="0.0.0.0" OLLAMA_ORIGINS="*" OLLAMA_KEEP_ALIVE="24h" ollama serve 134 | ``` 135 | 136 | 然后运行`chmod +x start-ollama.sh`和`./start-ollama.sh` 137 | 138 | 4. 在前端模型服务中添加该嵌入模型,需要填写:服务 `ollama`、Base URL `http://localhost:11434`、模型名称 `mxbai-embed-large`、向量维度 `1024`(维度必须为正数,且需与模型实际输出一致)。 139 | 140 | 5.
验证Ollama服务是否正常运行: 141 | ```bash 142 | curl http://localhost:11434/api/tags 143 | ``` 144 | 应返回包含`mxbai-embed-large`的模型列表。 145 | 146 | ## 步骤五:确认Milvus配置 147 | 148 | Milvus的连接地址和向量集合参数在配置文件中指定: 149 | 150 | ```yaml 151 | # Milvus向量数据库配置 152 | milvus: 153 | address: "localhost:19530" 154 | index_type: "IVF_FLAT" 155 | metric_type: "COSINE" 156 | nlist: 128 157 | # 搜索参数 158 | nprobe: 16 159 | # 字段最大长度配置 160 | id_max_length: "64" 161 | content_max_length: "65535" 162 | doc_id_max_length: "64" 163 | doc_name_max_length: "256" 164 | kb_id_max_length: "64" 165 | ``` 166 | 167 | 如果您使用的是自定义的Milvus部署或远程Milvus服务,请相应地修改地址。您也可以根据需要调整以下参数: 168 | 169 | - `index_type`: 索引类型,支持IVF_FLAT、IVF_SQ8、HNSW等 170 | - `metric_type`: 距离计算方式,支持COSINE、L2、IP等 171 | - `nlist`: IVF索引的聚类数量 172 | - `nprobe`: 搜索时检查的聚类数量,值越大结果越精确但查询越慢 173 | - `*_max_length`: 各字段的最大长度设置,特别是处理大文档时可能需要调整content_max_length 174 | 175 | 验证Milvus是否正常运行: 176 | 177 | ```bash 178 | # 检查Milvus容器状态 179 | docker ps | grep milvus 180 | 181 | # 查看Milvus日志 182 | docker logs milvus-standalone 183 | ``` 184 | 185 | ## 步骤六:启动AI-Cloud-Go 186 | 187 | ### 下载依赖 188 | 189 | ```bash 190 | go mod download 191 | ``` 192 | 193 | ### 运行应用 194 | 195 | ```bash 196 | go run cmd/main.go 197 | ``` 198 | 199 | 应用将在http://localhost:8080运行。您应该看到类似以下的输出: 200 | 201 | ``` 202 | [GIN-debug] Listening and serving HTTP on :8080 203 | ``` 204 | 205 | ## 步骤七:验证安装 206 | 207 | ### 检查API服务 208 | 209 | ```bash 210 | # 检查健康状态 211 | curl http://localhost:8080/api/health 212 | 213 | # 使用Swagger查看API文档 214 | # 在浏览器中访问: http://localhost:8080/swagger/index.html 215 | ``` 216 | 217 | ### 测试嵌入服务 218 | 219 | 根据您选择的嵌入服务,验证其正常工作: 220 | 221 | ```bash 222 | # 如果使用OpenAI 223 | curl -H "Authorization: Bearer 您的API密钥" https://api.openai.com/v1/embeddings -d '{"model":"text-embedding-3-large", "input":"测试文本"}' 224 | 225 | # 如果使用Ollama 226 | curl -X POST http://localhost:11434/api/embed -d '{"model":"mxbai-embed-large", "input":"测试文本"}' 227 | ``` 228 | 229 | ### 注册和登录 230 | 231 | 
要使用系统的大多数功能,您需要先注册并登录以获取JWT令牌: 232 | 233 | 1. 注册用户: 234 | ```bash 235 | curl -X POST http://localhost:8080/api/users/register -H "Content-Type: application/json" -d '{"username":"testuser", "password":"testpassword", "email":"test@example.com"}' 236 | ``` 237 | 238 | 2. 登录并获取Token: 239 | ```bash 240 | curl -X POST http://localhost:8080/api/users/login -H "Content-Type: application/json" -d '{"username":"testuser", "password":"testpassword"}' 241 | ``` 242 | 243 | 复制返回的`token`值,供后续API调用使用。 244 | 245 | ## 使用知识库功能 246 | 247 | 使用您在登录时获得的JWT令牌: 248 | 249 | 1. 创建知识库: 250 | ```bash 251 | curl -X POST http://localhost:8080/api/kb/create \ 252 | -H "Authorization: Bearer 您的JWT令牌" \ 253 | -H "Content-Type: application/json" \ 254 | -d '{"name":"测试知识库","description":"这是一个测试知识库"}' 255 | ``` 256 | 257 | 2. 上传文档到知识库(使用multipart/form-data): 258 | ```bash 259 | curl -X POST http://localhost:8080/api/kb/addNew \ 260 | -H "Authorization: Bearer 您的JWT令牌" \ 261 | -F "kb_id=知识库ID" \ 262 | -F "file=@/path/to/your/document.pdf" 263 | ``` 264 | 265 | 3. 查询知识库: 266 | ```bash 267 | curl -X POST http://localhost:8080/api/kb/retrieve \ 268 | -H "Authorization: Bearer 您的JWT令牌" \ 269 | -H "Content-Type: application/json" \ 270 | -d '{"kb_id":"知识库ID","query":"您的问题","top_k":3}' 271 | ``` 272 | 273 | ## 故障排除 274 | 275 | ### 初始化问题 276 | - 如果启动时显示找不到配置文件,请确保`config/config.yaml`文件存在并格式正确 277 | - 如果显示找不到`init.sql`,请确认项目根目录中有此文件 278 | 279 | ### MySQL连接问题 280 | - 错误:`dial tcp 127.0.0.1:3306: connect: connection refused` 281 | - 解决:确保MySQL容器已启动,运行`docker ps | grep mysql` 282 | 283 | ### Milvus连接问题 284 | - 错误:`无法连接到Milvus` 285 | - 解决: 286 | 1. 确保Milvus容器正在运行:`docker ps | grep milvus` 287 | 2. 检查config.yaml中的milvus.address配置是否正确 288 | 3. 
如果使用自定义Milvus部署,确保端口映射正确 289 | 290 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module ai-cloud 2 | 3 | go 1.23.4 4 | 5 | require ( 6 | code.sajari.com/docconv/v2 v2.0.0-pre.4 7 | github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible 8 | github.com/bytedance/sonic v1.13.2 9 | github.com/cloudwego/eino v0.3.30 10 | github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20250328102648-b47e7f1587fa 11 | github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20250328102648-b47e7f1587fa 12 | github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250328102648-b47e7f1587fa 13 | github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250331101427-906b8d194a99 14 | github.com/cloudwego/eino-ext/components/tool/mcp v0.0.0-20250507115047-b20720df8528 15 | github.com/cloudwego/eino-ext/libs/acl/openai v0.0.0-20250422092704-54e372e1fa3d 16 | github.com/fsnotify/fsnotify v1.9.0 17 | github.com/gin-contrib/cors v1.7.3 18 | github.com/gin-contrib/sse v0.1.0 19 | github.com/gin-gonic/gin v1.10.0 20 | github.com/golang-jwt/jwt/v5 v5.2.1 21 | github.com/google/uuid v1.6.0 22 | github.com/mark3labs/mcp-go v0.26.0 23 | github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 24 | github.com/minio/minio-go/v7 v7.0.84 25 | github.com/ollama/ollama v0.5.12 26 | github.com/spf13/viper v1.20.1 27 | golang.org/x/crypto v0.34.0 28 | gorm.io/driver/mysql v1.5.7 29 | gorm.io/gorm v1.25.12 30 | 31 | ) 32 | 33 | require ( 34 | github.com/JalfResi/justext v0.0.0-20170829062021-c0282dea7198 // indirect 35 | github.com/PuerkitoBio/goquery v1.8.1 // indirect 36 | github.com/advancedlogic/GoOse v0.0.0-20191112112754-e742535969c1 // indirect 37 | github.com/andybalholm/cascadia v1.3.2 // indirect 38 | github.com/araddon/dateparse v0.0.0-20200409225146-d820a6159ab1 // indirect 39 | github.com/aymerick/douceur v0.2.0 // 
indirect 40 | github.com/bytedance/sonic/loader v0.2.4 // indirect 41 | github.com/cloudwego/base64x v0.1.5 // indirect 42 | github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20250117061805-cd80d1780d76 // indirect 43 | github.com/cockroachdb/errors v1.9.1 // indirect 44 | github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect 45 | github.com/cockroachdb/redact v1.1.3 // indirect 46 | github.com/dustin/go-humanize v1.0.1 // indirect 47 | github.com/fatih/set v0.2.1 // indirect 48 | github.com/gabriel-vasile/mimetype v1.4.8 // indirect 49 | github.com/getkin/kin-openapi v0.118.0 // indirect 50 | github.com/getsentry/sentry-go v0.12.0 // indirect 51 | github.com/gigawattio/window v0.0.0-20180317192513-0f5467e35573 // indirect 52 | github.com/go-ini/ini v1.67.0 // indirect 53 | github.com/go-openapi/jsonpointer v0.21.0 // indirect 54 | github.com/go-openapi/swag v0.23.0 // indirect 55 | github.com/go-playground/locales v0.14.1 // indirect 56 | github.com/go-playground/universal-translator v0.18.1 // indirect 57 | github.com/go-playground/validator/v10 v10.23.0 // indirect 58 | github.com/go-resty/resty/v2 v2.3.0 // indirect 59 | github.com/go-sql-driver/mysql v1.7.1 // indirect 60 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 61 | github.com/goccy/go-json v0.10.4 // indirect 62 | github.com/gogo/protobuf v1.3.2 // indirect 63 | github.com/golang/protobuf v1.5.4 // indirect 64 | github.com/goph/emperror v0.17.2 // indirect 65 | github.com/gorilla/css v1.0.1 // indirect 66 | github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect 67 | github.com/invopop/yaml v0.3.1 // indirect 68 | github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7 // indirect 69 | github.com/jinzhu/inflection v1.0.0 // indirect 70 | github.com/jinzhu/now v1.1.5 // indirect 71 | github.com/josharian/intern v1.0.0 // indirect 72 | github.com/json-iterator/go v1.1.12 // indirect 73 | github.com/klauspost/compress v1.17.11 // 
indirect 74 | github.com/klauspost/cpuid/v2 v2.2.9 // indirect 75 | github.com/kr/pretty v0.3.1 // indirect 76 | github.com/kr/text v0.2.0 // indirect 77 | github.com/leodido/go-urn v1.4.0 // indirect 78 | github.com/levigross/exp-html v0.0.0-20120902181939-8df60c69a8f5 // indirect 79 | github.com/mailru/easyjson v0.9.0 // indirect 80 | github.com/mattn/go-isatty v0.0.20 // indirect 81 | github.com/mattn/go-runewidth v0.0.14 // indirect 82 | github.com/meguminnnnnnnnn/go-openai v0.0.0-20250408071642-761325becfd6 // indirect 83 | github.com/microcosm-cc/bluemonday v1.0.27 // indirect 84 | github.com/milvus-io/milvus-proto/go-api/v2 v2.5.6 // indirect 85 | github.com/minio/md5-simd v1.1.2 // indirect 86 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 87 | github.com/modern-go/reflect2 v1.0.2 // indirect 88 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect 89 | github.com/nikolalohinski/gonja v1.5.3 // indirect 90 | github.com/olekukonko/tablewriter v0.0.5 // indirect 91 | github.com/otiai10/gosseract/v2 v2.2.4 // indirect 92 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect 93 | github.com/perimeterx/marshmallow v1.1.5 // indirect 94 | github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect 95 | github.com/pkg/errors v0.9.1 // indirect 96 | github.com/richardlehane/mscfb v1.0.3 // indirect 97 | github.com/richardlehane/msoleps v1.0.3 // indirect 98 | github.com/rivo/uniseg v0.2.0 // indirect 99 | github.com/rogpeppe/go-internal v1.13.1 // indirect 100 | github.com/rs/xid v1.6.0 // indirect 101 | github.com/sagikazarmark/locafero v0.7.0 // indirect 102 | github.com/sashabaranov/go-openai v1.37.0 // indirect 103 | github.com/sirupsen/logrus v1.9.3 // indirect 104 | github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect 105 | github.com/sourcegraph/conc v0.3.0 // indirect 106 | github.com/spf13/afero v1.12.0 // indirect 107 | github.com/spf13/cast v1.7.1 // indirect 108 
| github.com/spf13/pflag v1.0.6 // indirect 109 | github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect 110 | github.com/stretchr/objx v0.5.2 // indirect 111 | github.com/subosito/gotenv v1.6.0 // indirect 112 | github.com/tidwall/gjson v1.17.1 // indirect 113 | github.com/tidwall/match v1.1.1 // indirect 114 | github.com/tidwall/pretty v1.2.0 // indirect 115 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 116 | github.com/ugorji/go/codec v1.2.12 // indirect 117 | github.com/yargevad/filepathx v1.0.0 // indirect 118 | github.com/yosida95/uritemplate/v3 v3.0.2 // indirect 119 | go.uber.org/atomic v1.11.0 // indirect 120 | go.uber.org/multierr v1.11.0 // indirect 121 | golang.org/x/arch v0.14.0 // indirect 122 | golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect 123 | golang.org/x/net v0.35.0 // indirect 124 | golang.org/x/sync v0.11.0 // indirect 125 | golang.org/x/sys v0.30.0 // indirect 126 | golang.org/x/text v0.22.0 // indirect 127 | golang.org/x/time v0.9.0 // indirect 128 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect 129 | google.golang.org/grpc v1.67.3 // indirect 130 | google.golang.org/protobuf v1.36.3 // indirect 131 | gopkg.in/yaml.v3 v3.0.1 // indirect 132 | ) 133 | 134 | replace nhooyr.io/websocket => github.com/coder/websocket v1.8.7 135 | -------------------------------------------------------------------------------- /init.sql: -------------------------------------------------------------------------------- 1 | CREATE DATABASE IF NOT EXISTS `ai_cloud`; 2 | USE ai_cloud; 3 | 4 | -- You can add more initialization steps here if needed, such as: 5 | -- CREATE TABLE statements 6 | -- INSERT initial data 7 | -- CREATE USER and GRANT privileges -------------------------------------------------------------------------------- /internal/component/embedding/embedding.go: -------------------------------------------------------------------------------- 1 | package embedding 2 | 3 
| import ( 4 | "ai-cloud/internal/model" 5 | "context" 6 | "fmt" 7 | einoEmbedding "github.com/cloudwego/eino/components/embedding" 8 | "time" 9 | ) 10 | 11 | const ( 12 | ProviderOpenAI = "openai" 13 | ProviderOllama = "ollama" 14 | 15 | // defaultEmbeddingTimeout is the request timeout applied when no WithTimeout option is given; 16 | // it mirrors defaultLLMTimeout in internal/component/llm. 17 | defaultEmbeddingTimeout = 60 * time.Second 18 | ) 19 | 20 | // EmbeddingOption sets an optional field of EmbeddingOptions. 21 | type EmbeddingOption func(*EmbeddingOptions) 22 | 23 | // EmbeddingOptions collects the optional settings passed to EmbeddingService.New. 24 | type EmbeddingOptions struct { 25 | // Timeout is the request timeout; nil selects defaultEmbeddingTimeout. 26 | Timeout *time.Duration 27 | } 28 | 29 | // WithTimeout sets the request timeout of the embedding client. 30 | func WithTimeout(timeout time.Duration) EmbeddingOption { 31 | return func(o *EmbeddingOptions) { 32 | o.Timeout = &timeout 33 | } 34 | } 35 | 36 | // timeoutOrDefault returns *timeout, or defaultEmbeddingTimeout when timeout is nil. 37 | func timeoutOrDefault(timeout *time.Duration) time.Duration { 38 | if timeout == nil { 39 | return defaultEmbeddingTimeout 40 | } 41 | return *timeout 42 | } 43 | 44 | // EmbeddingService is the common interface of vector-embedding providers. 45 | type EmbeddingService interface { 46 | // New returns a service configured from cfg and opts; registered values act as prototypes. 47 | New(ctx context.Context, cfg *model.Model, opts ...EmbeddingOption) (EmbeddingService, error) 48 | // EmbedStrings converts each text into its vector representation. 49 | EmbedStrings(ctx context.Context, texts []string, opts ...einoEmbedding.Option) ([][]float64, error) 50 | // GetDimension returns the dimension of the embedding vectors. 51 | GetDimension() int 52 | } 53 | 54 | // embeddingMap maps a provider name (matched against cfg.Server) to its prototype; providers fill it via register from their init functions. 55 | var embeddingMap = make(map[string]EmbeddingService) 56 | 57 | // register makes a provider prototype available under name. 58 | func register(name string, embeddingService EmbeddingService) { 59 | embeddingMap[name] = embeddingService 60 | } 61 | 62 | // NewEmbeddingService builds the embedding service selected by cfg.Server. 63 | // It fails if cfg is nil, cfg.Server is empty, or no provider is registered under cfg.Server. 64 | func NewEmbeddingService(ctx context.Context, cfg *model.Model, opts ...EmbeddingOption) (EmbeddingService, error) { 65 | if cfg == nil { 66 | return nil, fmt.Errorf("embedding config is nil") 67 | } 68 | 69 | if cfg.Server == "" { 70 | return nil, fmt.Errorf("embedding config server is empty") 71 | } 72 | 73 | // Delegate to the registered prototype, which builds a configured instance. 74 | if embedding, ok := embeddingMap[cfg.Server]; ok { 75 | return embedding.New(ctx, cfg, opts...)
76 | } 77 | return nil, fmt.Errorf("不支持的嵌入服务提供者: %s", cfg.Server) 78 | } 79 | -------------------------------------------------------------------------------- /internal/component/embedding/ollama.go: -------------------------------------------------------------------------------- 1 | package embedding 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "context" 6 | "fmt" 7 | einoEmbedding "github.com/cloudwego/eino/components/embedding" 8 | "time" 9 | 10 | "github.com/ollama/ollama/api" 11 | "net/http" 12 | "net/url" 13 | ) 14 | 15 | func init() { 16 | register(ProviderOllama, &ollamaEmbedder{}) 17 | } 18 | 19 | type OllamaEmbeddingConfig struct { 20 | BaseURL string 21 | Model string 22 | Dimension *int 23 | Timeout *time.Duration // nil selects defaultEmbeddingTimeout 24 | HTTPClient *http.Client 25 | } 26 | 27 | type ollamaEmbedder struct { 28 | cli *api.Client 29 | conf *OllamaEmbeddingConfig 30 | } 31 | 32 | // New validates cfg and returns an Ollama-backed embedder; without WithTimeout the HTTP client uses defaultEmbeddingTimeout. 33 | func (o *ollamaEmbedder) New(ctx context.Context, cfg *model.Model, opts ...EmbeddingOption) (EmbeddingService, error) { 34 | // Validate the configuration. 35 | if err := checkCfg(cfg); err != nil { 36 | return nil, err 37 | } 38 | // Apply the functional options. 39 | options := &EmbeddingOptions{} 40 | for _, opt := range opts { 41 | opt(options) 42 | } 43 | 44 | config := &OllamaEmbeddingConfig{ 45 | BaseURL: cfg.BaseURL, 46 | Model: cfg.ModelName, 47 | Dimension: &cfg.Dimension, 48 | Timeout: options.Timeout, 49 | } 50 | 51 | // Build the HTTP client; a nil Timeout falls back to the default rather than being dereferenced. 52 | var httpClient *http.Client 53 | if config.HTTPClient != nil { 54 | httpClient = config.HTTPClient 55 | } else { 56 | httpClient = &http.Client{Timeout: timeoutOrDefault(config.Timeout)} 57 | } 58 | 59 | // Parse the base URL. 60 | baseURL, err := url.Parse(config.BaseURL) 61 | if err != nil { 62 | return nil, fmt.Errorf("invalid base URL: %w", err) 63 | } 64 | 65 | // Create the Ollama API client. 66 | cli := api.NewClient(baseURL, httpClient) 67 | 68 | return &ollamaEmbedder{ 69 | cli: cli, 70 | conf: config, 71 | }, nil 72 | } 73 | 74 | func (o *ollamaEmbedder) EmbedStrings(ctx context.Context, texts []string, opts
...einoEmbedding.Option) ( 75 | embeddings [][]float64, err error) { 76 | req := &api.EmbedRequest{ 77 | Model: o.conf.Model, 78 | Input: texts, 79 | } 80 | resp, err := o.cli.Embed(ctx, req) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | embeddings = make([][]float64, len(resp.Embeddings)) 86 | for i, d := range resp.Embeddings { 87 | res := make([]float64, len(d)) 88 | for j, emb := range d { 89 | res[j] = float64(emb) 90 | } 91 | embeddings[i] = res 92 | } 93 | 94 | return embeddings, nil 95 | } 96 | 97 | func (o *ollamaEmbedder) GetType() string { 98 | return ProviderOllama 99 | } 100 | 101 | // TODO: handle callbacks. 102 | //func (e *ollamaEmbedder) IsCallbacksEnabled() bool { 103 | // return true 104 | //} 105 | 106 | func (o *ollamaEmbedder) GetDimension() int { 107 | return *o.conf.Dimension 108 | } 109 | 110 | func checkCfg(cfg *model.Model) error { 111 | if cfg.BaseURL == "" { 112 | return fmt.Errorf("ollama base URL cannot be empty") 113 | } 114 | 115 | if _, err := url.Parse(cfg.BaseURL); err != nil { 116 | return fmt.Errorf("invalid ollama base URL: %w", err) 117 | } 118 | 119 | if cfg.ModelName == "" { 120 | return fmt.Errorf("ollama model name cannot be empty") 121 | } 122 | 123 | if cfg.Dimension <= 0 { 124 | return fmt.Errorf("ollama embedding dimension must be positive") 125 | } 126 | 127 | return nil 128 | } 129 | -------------------------------------------------------------------------------- /internal/component/embedding/openai.go: -------------------------------------------------------------------------------- 1 | package embedding 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "context" 6 | "fmt" 7 | "github.com/cloudwego/eino-ext/components/embedding/openai" 8 | "github.com/cloudwego/eino/components/embedding" 9 | "net/url" 10 | "time" 11 | ) 12 | 13 | func init() { 14 | register(ProviderOpenAI, &openaiEmbedder{}) 15 | } 16 | 17 | type OpenAIEmbeddingConfig struct { 18 | BaseURL string 19 | APIKey string 20 | Model string 21 |
Timeout *time.Duration 22 | Dimension *int 23 | } 24 | type openaiEmbedder struct { 25 | conf *OpenAIEmbeddingConfig 26 | embedder *openai.Embedder 27 | } 28 | 29 | func (o *openaiEmbedder) New(ctx context.Context, cfg *model.Model, opts ...EmbeddingOption) (EmbeddingService, error) { 30 | if err := checkOpenAICfg(cfg); err != nil { 31 | return nil, err 32 | } 33 | // Apply the functional options. 34 | options := &EmbeddingOptions{} 35 | for _, opt := range opts { 36 | opt(options) 37 | } 38 | 39 | config := &OpenAIEmbeddingConfig{ 40 | BaseURL: cfg.BaseURL, 41 | APIKey: cfg.APIKey, 42 | Model: cfg.ModelName, 43 | Timeout: options.Timeout, 44 | Dimension: &cfg.Dimension, 45 | } 46 | 47 | embeder, err := openai.NewEmbedder(ctx, &openai.EmbeddingConfig{ 48 | APIKey: config.APIKey, 49 | BaseURL: config.BaseURL, 50 | Model: config.Model, 51 | Timeout: timeoutOrDefault(options.Timeout), // nil selects defaultEmbeddingTimeout 52 | Dimensions: &cfg.Dimension, 53 | }) 54 | if err != nil { 55 | return nil, err 56 | } 57 | return &openaiEmbedder{ 58 | conf: config, 59 | embedder: embeder, 60 | }, nil 61 | } 62 | 63 | func (s *openaiEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { 64 | return s.embedder.EmbedStrings(ctx, texts, opts...)
65 | } 66 | 67 | func (o *openaiEmbedder) GetDimension() int { 68 | return *o.conf.Dimension 69 | } 70 | 71 | func (o *openaiEmbedder) GetType() string { 72 | return ProviderOpenAI 73 | } 74 | 75 | func checkOpenAICfg(cfg *model.Model) error { 76 | if cfg.BaseURL == "" { 77 | return fmt.Errorf("openai base URL cannot be empty") 78 | } 79 | 80 | if _, err := url.Parse(cfg.BaseURL); err != nil { 81 | return fmt.Errorf("invalid openai base URL: %w", err) 82 | } 83 | 84 | if cfg.ModelName == "" { 85 | return fmt.Errorf("openai model name cannot be empty") 86 | } 87 | 88 | if cfg.Dimension <= 0 { 89 | return fmt.Errorf("openai embedding dimension must be positive") 90 | } 91 | 92 | return nil 93 | } 94 | -------------------------------------------------------------------------------- /internal/component/indexer/milvus/indexer.go: -------------------------------------------------------------------------------- 1 | package milvus 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "ai-cloud/internal/utils" 6 | "ai-cloud/pkgs/consts" 7 | "context" 8 | "fmt" 9 | "github.com/bytedance/sonic" 10 | "github.com/cloudwego/eino/components/embedding" 11 | "github.com/cloudwego/eino/components/indexer" 12 | "github.com/cloudwego/eino/schema" 13 | "github.com/milvus-io/milvus-sdk-go/v2/client" 14 | "github.com/milvus-io/milvus-sdk-go/v2/entity" 15 | "strconv" 16 | "strings" 17 | ) 18 | 19 | type MilvusIndexerConfig struct { 20 | Collection string 21 | Dimension int 22 | Embedding embedding.Embedder 23 | Client client.Client 24 | } 25 | 26 | type MilvusIndexer struct { 27 | config MilvusIndexerConfig 28 | } 29 | 30 | func NewMilvusIndexer(ctx context.Context, conf *MilvusIndexerConfig) (*MilvusIndexer, error) { 31 | // Validate the configuration. 32 | if err := conf.check(); err != nil { 33 | return nil, fmt.Errorf("[NewMilvusIndexer] invalid config: %w", err) 34 | } 35 | 36 | // Check whether the collection exists; create it if not. 37 | exists, err := conf.Client.HasCollection(ctx, conf.Collection) 38 | if err != nil { 39 | return nil,
fmt.Errorf("[NewMilvusIndexer] check milvus collection failed : %w", err) 40 | } 41 | if !exists { 42 | if err := conf.createCollection(ctx, conf.Collection, conf.Dimension); err != nil { 43 | return nil, fmt.Errorf("[NewMilvusIndexer] create collection failed: %w", err) 44 | } 45 | } 46 | 47 | // Load the collection. 48 | err = conf.Client.LoadCollection(ctx, conf.Collection, false) 49 | if err != nil { 50 | return nil, fmt.Errorf("[NewMilvusIndexer] failed to load collection: %w", err) 51 | } 52 | 53 | return &MilvusIndexer{ 54 | config: *conf, 55 | }, nil 56 | } 57 | 58 | func (m *MilvusIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) { 59 | 60 | // Options passed to Store override these defaults (lets a caller swap the embedder per call). 61 | co := indexer.GetCommonOptions(&indexer.Options{ // defaults 62 | SubIndexes: nil, 63 | Embedding: m.config.Embedding, 64 | }, opts...) 65 | 66 | embedder := co.Embedding 67 | if embedder == nil { 68 | return nil, fmt.Errorf("[Indexer.Store] embedding not provided") 69 | } 70 | // Collect the text content of every document. 71 | texts := make([]string, 0, len(docs)) 72 | for _, doc := range docs { 73 | texts = append(texts, doc.Content) 74 | } 75 | // Embed the texts. 76 | vectors := make([][]float64, len(texts)) // pre-sized result slice 77 | 78 | for i, text := range texts { 79 | // Embed a single text per call. 80 | vec, err := embedder.EmbedStrings(ctx, []string{text}) 81 | if err != nil { 82 | return nil, fmt.Errorf("[Indexer.Store] failed to embed text at index %d: %w", i, err) 83 | } 84 | 85 | // Exactly one vector must come back for the single input text. 86 | if len(vec) != 1 { 87 | return nil, fmt.Errorf("[Indexer.Store] unexpected number of vectors returned: %d", len(vec)) 88 | } 89 | 90 | vectors[i] = vec[0] 91 | } 92 | if len(vectors) != len(docs) { 93 | return nil, fmt.Errorf("[Indexer.Store] embedding vector length mismatch") 94 | } 95 | rows, err := DocumentConvert(ctx, docs, vectors) 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | results, err := m.config.Client.InsertRows(ctx, m.config.Collection, "", rows)
101 | if err != nil { 102 | return nil, err 103 | } 104 | if err := m.config.Client.Flush(ctx, m.config.Collection, false); err != nil { 105 | return nil, err 106 | } 107 | ids = make([]string, results.Len()) 108 | for idx := 0; idx < results.Len(); idx++ { 109 | ids[idx], err = results.GetAsString(idx) 110 | if err != nil { 111 | return nil, fmt.Errorf("[Indexer.Store] failed to get id: %w", err) 112 | } 113 | } 114 | return ids, nil 115 | } 116 | func (m *MilvusIndexer) GetType() string { return "Milvus" } 117 | 118 | // DocumentConvert pairs docs[i] with vectors[i] and returns one Milvus row (*defaultSchema) per document. 119 | // Each document's MetaData must hold string kb_id and document_id; the remaining keys are marshalled into the JSON Metadata column. 120 | func DocumentConvert(ctx context.Context, docs []*schema.Document, vectors [][]float64) ([]interface{}, error) { 121 | if len(vectors) != len(docs) { // the em[idx] indexing below needs exactly one vector per document 122 | return nil, fmt.Errorf("[DocumentConvert] got %d vectors for %d documents", len(vectors), len(docs)) 123 | } 124 | em, rows := make([]defaultSchema, 0, len(docs)), make([]interface{}, 0, len(docs)) 125 | for _, doc := range docs { 126 | // Pull the structured fields out of the original MetaData. 127 | kbID, ok := doc.MetaData["kb_id"].(string) 128 | if !ok { 129 | return nil, fmt.Errorf("invalid type for kb_id") 130 | } 131 | docID, ok := doc.MetaData["document_id"].(string) 132 | if !ok { 133 | return nil, fmt.Errorf("invalid type for document_id") 134 | } 135 | 136 | // Build the map serialized into the Metadata field, excluding kb_id and document_id. 137 | metaCopy := make(map[string]any, len(doc.MetaData)) 138 | for k, v := range doc.MetaData { 139 | if k == "kb_id" || k == "document_id" { 140 | continue 141 | } 142 | metaCopy[k] = v 143 | } 144 | metadataBytes, err := sonic.Marshal(metaCopy) 145 | if err != nil { 146 | return nil, fmt.Errorf("[DocumentConvert] failed to marshal metadata: %w", err) 147 | } 148 | 149 | em = append(em, defaultSchema{ 150 | ID: doc.ID, 151 | Content: doc.Content, 152 | KBID: kbID, 153 | DocumentID: docID, 154 | Vector: nil, // filled in below 155 | Metadata: metadataBytes, // only the remaining fields 156 | }) 157 | } 158 | 159 | // Fill in the vectors and build the rows. 160 | for idx, vec := range vectors { 161 | em[idx].Vector = utils.ConvertFloat64ToFloat32Embedding(vec) 162 | rows = append(rows, &em[idx]) 163 | } 164 | return rows, nil 165 | } 166 | 167 | func (m *MilvusIndexerConfig)
createCollection(ctx context.Context, collectionName string, dimension int) error { 168 | // Read the Milvus settings from the global config. 169 | milvusConfig := config.GetConfig().Milvus 170 | // Define the collection schema. 171 | s := &entity.Schema{ 172 | CollectionName: collectionName, 173 | Description: "存储文档分块和向量", 174 | AutoID: false, 175 | Fields: []*entity.Field{ 176 | { 177 | Name: consts.FieldNameID, 178 | DataType: entity.FieldTypeVarChar, 179 | PrimaryKey: true, 180 | AutoID: false, 181 | TypeParams: map[string]string{ 182 | "max_length": milvusConfig.IDMaxLength, 183 | }, 184 | }, 185 | { 186 | Name: consts.FieldNameContent, 187 | DataType: entity.FieldTypeVarChar, 188 | TypeParams: map[string]string{ 189 | "max_length": milvusConfig.ContentMaxLength, 190 | }, 191 | }, 192 | { 193 | Name: consts.FieldNameDocumentID, 194 | DataType: entity.FieldTypeVarChar, 195 | TypeParams: map[string]string{ 196 | "max_length": milvusConfig.DocIDMaxLength, 197 | }, 198 | }, 199 | { 200 | Name: consts.FieldNameKBID, 201 | DataType: entity.FieldTypeVarChar, 202 | TypeParams: map[string]string{ 203 | "max_length": milvusConfig.KbIDMaxLength, 204 | }, 205 | }, 206 | { 207 | Name: consts.FieldNameVector, 208 | DataType: entity.FieldTypeFloatVector, 209 | TypeParams: map[string]string{ 210 | "dim": strconv.Itoa(dimension), 211 | }, 212 | }, 213 | { 214 | Name: consts.FieldNameMetadata, 215 | DataType: entity.FieldTypeJSON, 216 | }, 217 | }, 218 | } 219 | 220 | // Create the collection with a single shard. 221 | if err := m.Client.CreateCollection(ctx, s, 1); err != nil { 222 | return fmt.Errorf("[NewMilvusIndexer.createCollection] 创建集合失败: %w", err) 223 | } 224 | 225 | // Build the vector index using the index type from the config. 226 | idx, err := milvusConfig.GetMilvusIndex() 227 | if err != nil { 228 | return fmt.Errorf("[NewMilvusIndexer.createCollection] 从配置中获取索引类型失败: %w", err) 229 | } 230 | 231 | if err := m.Client.CreateIndex(ctx, collectionName, consts.FieldNameVector, idx, false); err != nil { 232 | return fmt.Errorf("[NewMilvusIndexer.createCollection] 创建索引失败: %w", err) 233 | } 234 | return nil 235 | } 236
| 237 | func (m *MilvusIndexerConfig) check() error { 238 | if m.Client == nil { 239 | return fmt.Errorf("[NewMilvusIndexer] milvus client is nil") 240 | } 241 | if m.Embedding == nil { 242 | return fmt.Errorf("[NewMilvusIndexer] embedding is nil") 243 | } 244 | if m.Collection == "" { 245 | return fmt.Errorf("[NewMilvusIndexer] collection is empty") 246 | } 247 | if m.Dimension == 0 { 248 | return fmt.Errorf("[NewMilvusIndexer] embedding dimension is zero") 249 | } 250 | return nil 251 | } 252 | 253 | func (m *MilvusIndexerConfig) IsCallbacksEnabled() bool { 254 | return true 255 | } 256 | 257 | func DeleteDos(client client.Client, docIDs []string, collectionName string) error { 258 | expr := fmt.Sprintf("%s in [\"%s\"]", consts.FieldNameDocumentID, strings.Join(docIDs, "\",\"")) 259 | if err := client.Delete(context.Background(), collectionName, "", expr); err != nil { 260 | return fmt.Errorf("[MilvusIndexer.DeleteDos] failed to delete documents: %w", err) 261 | } 262 | return nil 263 | } 264 | -------------------------------------------------------------------------------- /internal/component/indexer/milvus/types.go: -------------------------------------------------------------------------------- 1 | package milvus 2 | 3 | // defaultSchema 4 | type defaultSchema struct { 5 | ID string `json:"id" milvus:"name:id"` 6 | Content string `json:"content" milvus:"name:content"` 7 | DocumentID string `json:"document_id" milvus:"name:document_id"` 8 | KBID string `json:"kb_id" milvus:"name:kb_id"` 9 | Vector []float32 `json:"vector" milvus:"name:vector"` 10 | Metadata []byte `json:"metadata" milvus:"name:metadata"` // 存放例如DocumentName,Index等信息 11 | } 12 | -------------------------------------------------------------------------------- /internal/component/llm/llm.go: -------------------------------------------------------------------------------- 1 | package llmfactory 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "strings" 9 | 
"time" 10 | 11 | eino_model "github.com/cloudwego/eino/components/model" 12 | ollama_api "github.com/ollama/ollama/api" 13 | 14 | // 假设这些是你的 Ollama 和 OpenAI 客户端包 15 | "ai-cloud/internal/component/llm/ollama" 16 | "ai-cloud/internal/component/llm/openai" 17 | ) 18 | 19 | const ( 20 | defaultLLMTimeout = 60 * time.Second 21 | serverOllama = "ollama" 22 | serverOpenAI = "openai" 23 | modelTypeLLM = "llm" 24 | ) 25 | 26 | // GetLLMClient 使用 传入的 配置基础,并允许通过 clientDefaultOpts 设置客户端级别的默认调用选项。 27 | func GetLLMClient(ctx context.Context, cfg *model.Model) (eino_model.ToolCallingChatModel, error) { 28 | // 检查Model配置 29 | // TODO: 考虑通过check函数实现 30 | if cfg == nil { 31 | return nil, errors.New("input model configuration is nil") 32 | } 33 | if cfg.Type != modelTypeLLM { 34 | return nil, fmt.Errorf("model type is '%s', but expected '%s'", cfg.Type, modelTypeLLM) 35 | } 36 | 37 | // 2. 返回对应的server 38 | switch strings.ToLower(cfg.Server) { 39 | case serverOllama: 40 | ollamaCfg := &ollama.ChatModelConfig{ 41 | BaseURL: cfg.BaseURL, 42 | Model: cfg.ModelName, // 使用最终确定的模型名称 43 | Timeout: defaultLLMTimeout, 44 | Options: &ollama_api.Options{}, // 设置包含默认调用参数的 Options 45 | } 46 | return ollama.NewChatModel(ctx, ollamaCfg) 47 | 48 | case serverOpenAI: 49 | openAICfg := &openai.ChatModelConfig{ 50 | APIKey: cfg.APIKey, 51 | Model: cfg.ModelName, 52 | Timeout: defaultLLMTimeout, 53 | BaseURL: cfg.BaseURL, 54 | } 55 | return openai.NewChatModel(ctx, openAICfg) 56 | 57 | default: 58 | return nil, fmt.Errorf("unsupported LLM server type: '%s'", cfg.Server) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /internal/component/llm/ollama/call_option.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2024 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package ollama 18 | 19 | import ( 20 | "github.com/cloudwego/eino/components/model" 21 | ) 22 | 23 | type options struct { 24 | Seed *int 25 | } 26 | 27 | func WithSeed(seed int) model.Option { 28 | return model.WrapImplSpecificOptFn(func(o *options) { 29 | o.Seed = &seed 30 | }) 31 | } 32 | -------------------------------------------------------------------------------- /internal/component/llm/openai/openai.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2024 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 |  */
16 | 
17 | package openai
18 | 
19 | import (
20 | 	"context"
21 | 	"net/http"
22 | 	"time"
23 | 
24 | 	"github.com/cloudwego/eino/callbacks"
25 | 	"github.com/cloudwego/eino/components"
26 | 	"github.com/cloudwego/eino/components/model"
27 | 	"github.com/cloudwego/eino/schema"
28 | 
29 | 	"github.com/cloudwego/eino-ext/libs/acl/openai"
30 | )
31 | 
32 | var _ model.ToolCallingChatModel = (*ChatModel)(nil)
33 | 
34 | // ChatModelConfig configures NewChatModel.
35 | type ChatModelConfig struct {
36 | 	// APIKey is your authentication key
37 | 	// Use OpenAI API key or Azure API key depending on the service
38 | 	// Required
39 | 	APIKey string `json:"api_key"`
40 | 
41 | 	// Timeout specifies the maximum duration to wait for API responses
42 | 	// If HTTPClient is set, Timeout will not be used.
43 | 	// Optional. Default: no timeout
44 | 	Timeout time.Duration `json:"timeout"`
45 | 
46 | 	// HTTPClient specifies the client to send HTTP requests.
47 | 	// If HTTPClient is set, Timeout will not be used.
48 | 	// Optional. Default &http.Client{Timeout: Timeout}
49 | 	HTTPClient *http.Client `json:"http_client"`
50 | 
51 | 	// The following three fields are only required when using Azure OpenAI Service, otherwise they can be ignored.
52 | 	// For more details, see: https://learn.microsoft.com/en-us/azure/ai-services/openai/
53 | 
54 | 	// ByAzure indicates whether to use Azure OpenAI Service
55 | 	// Required for Azure
56 | 	ByAzure bool `json:"by_azure"`
57 | 
58 | 	// BaseURL is the Azure OpenAI endpoint URL
59 | 	// Format: https://{YOUR_RESOURCE_NAME}.openai.azure.com. YOUR_RESOURCE_NAME is the name of your resource that you have created on Azure.
60 | 	// Required for Azure
61 | 	BaseURL string `json:"base_url"`
62 | 
63 | 	// APIVersion specifies the Azure OpenAI API version
64 | 	// Required for Azure
65 | 	APIVersion string `json:"api_version"`
66 | 
67 | 	// The following fields correspond to OpenAI's chat completion API parameters
68 | 	// Ref: https://platform.openai.com/docs/api-reference/chat/create
69 | 
70 | 	// Model specifies the ID of the model to use
71 | 	// Required
72 | 	Model string `json:"model"`
73 | 
74 | 	// MaxTokens limits the maximum number of tokens that can be generated in the chat completion
75 | 	// Optional. Default: model's maximum
76 | 	MaxTokens *int `json:"max_tokens,omitempty"`
77 | 
78 | 	// Temperature specifies what sampling temperature to use
79 | 	// Generally recommend altering this or TopP but not both.
80 | 	// Range: 0.0 to 2.0. Higher values make output more random
81 | 	// Optional. Default: 1.0
82 | 	Temperature *float32 `json:"temperature,omitempty"`
83 | 
84 | 	// TopP controls diversity via nucleus sampling
85 | 	// Generally recommend altering this or Temperature but not both.
86 | 	// Range: 0.0 to 1.0. Lower values make output more focused
87 | 	// Optional. Default: 1.0
88 | 	TopP *float32 `json:"top_p,omitempty"`
89 | 
90 | 	// Stop sequences where the API will stop generating further tokens
91 | 	// Optional. Example: []string{"\n", "User:"}
92 | 	Stop []string `json:"stop,omitempty"`
93 | 
94 | 	// PresencePenalty prevents repetition by penalizing tokens based on presence
95 | 	// Range: -2.0 to 2.0. Positive values increase likelihood of new topics
96 | 	// Optional. Default: 0
97 | 	PresencePenalty *float32 `json:"presence_penalty,omitempty"`
98 | 
99 | 	// ResponseFormat specifies the format of the model's response
100 | 	// Optional. Use for structured outputs
101 | 	ResponseFormat *openai.ChatCompletionResponseFormat `json:"response_format,omitempty"`
102 | 
103 | 	// Seed enables deterministic sampling for consistent outputs
104 | 	// Optional. Set for reproducible results
105 | 	Seed *int `json:"seed,omitempty"`
106 | 
107 | 	// FrequencyPenalty prevents repetition by penalizing tokens based on frequency
108 | 	// Range: -2.0 to 2.0. Positive values decrease likelihood of repetition
109 | 	// Optional. Default: 0
110 | 	FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
111 | 
112 | 	// LogitBias modifies likelihood of specific tokens appearing in completion
113 | 	// Optional. Map token IDs to bias values from -100 to 100
114 | 	LogitBias map[string]int `json:"logit_bias,omitempty"`
115 | 
116 | 	// User unique identifier representing end-user
117 | 	// Optional. Helps OpenAI monitor and detect abuse
118 | 	User *string `json:"user,omitempty"`
119 | }
120 | 
121 | var _ model.ChatModel = (*ChatModel)(nil)
122 | 
123 | // ChatModel is an Eino chat model backed by the eino-ext OpenAI client.
124 | type ChatModel struct {
125 | 	cli *openai.Client
126 | }
127 | 
128 | // NewChatModel builds a ChatModel from config. If config.HTTPClient is nil, an
129 | // http.Client with config.Timeout is created. A nil config is forwarded to
130 | // openai.NewClient as a nil *openai.Config.
131 | func NewChatModel(ctx context.Context, config *ChatModelConfig) (*ChatModel, error) {
132 | 	var nConf *openai.Config
133 | 	if config != nil {
134 | 		var httpClient *http.Client
135 | 
136 | 		if config.HTTPClient != nil {
137 | 			httpClient = config.HTTPClient
138 | 		} else {
139 | 			httpClient = &http.Client{Timeout: config.Timeout}
140 | 		}
141 | 
142 | 		nConf = &openai.Config{
143 | 			ByAzure:          config.ByAzure,
144 | 			BaseURL:          config.BaseURL,
145 | 			APIVersion:       config.APIVersion,
146 | 			APIKey:           config.APIKey,
147 | 			HTTPClient:       httpClient,
148 | 			Model:            config.Model,
149 | 			MaxTokens:        config.MaxTokens,
150 | 			Temperature:      config.Temperature,
151 | 			TopP:             config.TopP,
152 | 			Stop:             config.Stop,
153 | 			PresencePenalty:  config.PresencePenalty,
154 | 			ResponseFormat:   config.ResponseFormat,
155 | 			Seed:             config.Seed,
156 | 			FrequencyPenalty: config.FrequencyPenalty,
157 | 			LogitBias:        config.LogitBias,
158 | 			User:             config.User,
159 | 		}
160 | 	}
161 | 	cli, err := openai.NewClient(ctx, nConf)
162 | 	if err != nil {
163 | 		return nil, err
164 | 	}
165 | 
166 | 	return &ChatModel{
167 | 		cli: cli,
168 | 	}, nil
169 | }
170 | 
171 | // Generate attaches callback run info to ctx and delegates to the client.
172 | func (cm *ChatModel) Generate(ctx context.Context, in []*schema.Message, opts ...model.Option) (
173 | 	outMsg *schema.Message, err error) {
174 | 	ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
175 | 	return cm.cli.Generate(ctx, in, opts...)
176 | }
177 | 
178 | // Stream attaches callback run info to ctx and delegates to the client's
179 | // streaming call.
180 | func (cm *ChatModel) Stream(ctx context.Context, in []*schema.Message, opts ...model.Option) (outStream *schema.StreamReader[*schema.Message], err error) {
181 | 	ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
182 | 	return cm.cli.Stream(ctx, in, opts...)
183 | }
184 | 
185 | // WithTools returns a new ChatModel wrapping the client returned by
186 | // WithToolsForClient(tools).
187 | func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
188 | 	cli, err := cm.cli.WithToolsForClient(tools)
189 | 	if err != nil {
190 | 		return nil, err
191 | 	}
192 | 	return &ChatModel{cli: cli}, nil
193 | }
194 | 
195 | // BindTools binds tools on the existing client; unlike WithTools it returns
196 | // no new ChatModel.
197 | func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error {
198 | 	return cm.cli.BindTools(tools)
199 | }
200 | 
201 | // BindForcedTools binds tools on the existing client via the client's
202 | // BindForcedTools.
203 | func (cm *ChatModel) BindForcedTools(tools []*schema.ToolInfo) error {
204 | 	return cm.cli.BindForcedTools(tools)
205 | }
206 | 
207 | // typ is the component type passed to callbacks.EnsureRunInfo via GetType.
208 | const typ = "OpenAI"
209 | 
210 | // GetType returns "OpenAI".
211 | func (cm *ChatModel) GetType() string {
212 | 	return typ
213 | }
214 | 
215 | // IsCallbacksEnabled delegates to the underlying client.
216 | func (cm *ChatModel) IsCallbacksEnabled() bool {
217 | 	return cm.cli.IsCallbacksEnabled()
218 | }
219 | 
--------------------------------------------------------------------------------
/internal/component/parser/pdf/docconv_parser.go:
--------------------------------------------------------------------------------
1 | /*
2 | PDF parser built on the docconv library; implements the Parse method of
3 | Eino's document parser interface.
4 | */
5 | 
6 | package pdf
7 | 
8 | import (
9 | 	"context"
10 | 	"fmt"
11 | 	"io"
12 | 	"log"
13 | 	"unicode/utf8"
14 | 
15 | 	"github.com/cloudwego/eino/components/document/parser"
16 | 	"github.com/cloudwego/eino/schema"
17 | 
18 | 	"code.sajari.com/docconv/v2"
19 | )
20 | 
21 | // minTextRunes is the extracted-text length, in characters, below which the
22 | // result is logged as suspiciously short.
23 | const minTextRunes = 100
24 | 
25 | // options holds the parser-specific options accepted by
26 | // DocconvPDFParser.Parse.
27 | type options struct {
28 | 	toPages *bool
29 | }
30 | 
31 | // WithToPages overrides DocconvPDFParser.ToPages for a single Parse call.
32 | func WithToPages(toPages bool) parser.Option {
33 | 	return parser.WrapImplSpecificOptFn(func(opts *options) {
34 | 		opts.toPages = &toPages
35 | 	})
36 | }
37 | 
38 | // Config configures NewDocconvPDFParser.
39 | type Config struct {
40 | 	ToPages bool
41 | }
42 | 
43 | // DocconvPDFParser extracts the text of a PDF with docconv and returns it as
44 | // a single schema.Document.
45 | type DocconvPDFParser struct {
46 | 	// ToPages requests per-page output; page splitting is not implemented yet.
47 | 	ToPages bool
48 | }
49 | 
50 | // NewDocconvPDFParser creates a parser from config; a nil config is treated
51 | // as the zero Config. It never returns an error.
52 | func NewDocconvPDFParser(ctx context.Context, config *Config) (*DocconvPDFParser, error) {
53 | 	if config == nil {
54 | 		config = &Config{}
55 | 	}
56 | 	return &DocconvPDFParser{ToPages: config.ToPages}, nil
57 | }
58 | 
59 | // Parse converts the PDF read from reader to text and returns it as one
60 | // document whose MetaData is the ExtraMeta from the common parser options.
61 | //
62 | // It returns an error if docconv fails or extracts no text. Text shorter than
63 | // minTextRunes characters is accepted but logged as suspicious.
64 | func (pp *DocconvPDFParser) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
65 | 	// 1. Resolve the common options (used here for ExtraMeta).
66 | 	commonOpts := parser.GetCommonOptions(nil, opts...)
67 | 
68 | 	// 2. Resolve parser-specific options, defaulting to the parser's own settings.
69 | 	specificOpts := parser.GetImplSpecificOptions(&options{
70 | 		toPages: &pp.ToPages,
71 | 	}, opts...)
72 | 
73 | 	// 3. Extract the text. The error is returned rather than also logged, so
74 | 	// the caller reports it exactly once.
75 | 	log.Println("开始解析PDF文档...")
76 | 	res, meta, err := docconv.ConvertPDF(reader)
77 | 	if err != nil {
78 | 		return nil, fmt.Errorf("PDF解析失败: %w", err)
79 | 	}
80 | 
81 | 	// len(res) counts bytes, which overstates the length of CJK text. The
82 | 	// messages and the threshold are in characters, so count runes.
83 | 	textLen := utf8.RuneCountInString(res)
84 | 	log.Printf("PDF解析完成,文本长度: %d字符", textLen)
85 | 	log.Printf("PDF元数据: %+v", meta)
86 | 
87 | 	switch {
88 | 	case textLen == 0:
89 | 		return nil, fmt.Errorf("PDF解析结果为空,可能是扫描PDF或无文本内容")
90 | 	case textLen < minTextRunes:
91 | 		log.Println("PDF解析结果太短或为空")
92 | 	}
93 | 
94 | 	if *specificOpts.toPages {
95 | 		// TODO: split the text per page; the whole text is still returned as one document.
96 | 		log.Println("待处理分页")
97 | 	}
98 | 
99 | 	return []*schema.Document{{
100 | 		Content:  res,
101 | 		MetaData: commonOpts.ExtraMeta,
102 | 	}}, nil
103 | }
104 | 
--------------------------------------------------------------------------------
/internal/component/retriever/milvus/multi_retriever.go:
--------------------------------------------------------------------------------
1 | package milvus
2 | 
3 | import (
4 | 	"ai-cloud/internal/component/embedding"
5 | 	"ai-cloud/internal/dao"
6 | 	"ai-cloud/internal/database"
7 | 	"context"
8 | 	"fmt"
9 | 	eretriever "github.com/cloudwego/eino/components/retriever"
10 | 	"github.com/cloudwego/eino/schema"
11 | 	"log"
12 | 	"sort"
13 | 	"time"
14 | )
15 | 
16 | //
自定义的多知识库Retriever 17 | type MultiKBRetriever struct { 18 | KBIDs []string 19 | UserID uint 20 | KBDao dao.KnowledgeBaseDao 21 | ModelDao dao.ModelDao 22 | Ctx context.Context 23 | TopK int 24 | } 25 | 26 | func (m MultiKBRetriever) Retrieve(ctx context.Context, query string, opts ...eretriever.Option) ([]*schema.Document, error) { 27 | // 如果没有提供知识库ID,则返回空结果 28 | if len(m.KBIDs) == 0 { 29 | return []*schema.Document{}, nil 30 | } 31 | 32 | // 保存所有文档结果 33 | allDocuments := []*schema.Document{} 34 | 35 | // 对每个知识库进行检索 36 | for _, kbID := range m.KBIDs { 37 | // 获取知识库信息 38 | kb, err := m.KBDao.GetKBByID(kbID) 39 | if err != nil { 40 | return nil, fmt.Errorf("knowledge base not found: %w", err) 41 | } 42 | if kb.UserID != m.UserID { 43 | return nil, fmt.Errorf("userID mismatch: %w", err) 44 | } 45 | 46 | // 获取Embedding模型 47 | embedModel, err := m.ModelDao.GetByID(m.Ctx, m.UserID, kb.EmbedModelID) 48 | if err != nil { 49 | return nil, fmt.Errorf("failed to retrieve embedding model: %w", err) 50 | } 51 | 52 | // 创建Embedding服务 53 | embeddingService, err := embedding.NewEmbeddingService( 54 | m.Ctx, 55 | embedModel, 56 | embedding.WithTimeout(30*time.Second), 57 | ) 58 | if err != nil { 59 | return nil, fmt.Errorf("failed to initialize embedding service: %w", err) 60 | } 61 | 62 | // 创建当前知识库的Retriever 63 | retrieverConf := &MilvusRetrieverConfig{ 64 | Client: database.GetMilvusClient(), 65 | Embedding: embeddingService, 66 | Collection: kb.MilvusCollection, 67 | KBIDs: []string{kbID}, 68 | SearchFields: nil, 69 | TopK: 3, 70 | ScoreThreshold: 0, 71 | } 72 | 73 | retriever, err := NewMilvusRetriever(ctx, retrieverConf) 74 | if err != nil { 75 | return nil, fmt.Errorf("failed to create retriever for kb %s: %w", kbID, err) 76 | } 77 | 78 | // 执行检索 79 | docs, err := retriever.Retrieve(ctx, query) 80 | if err != nil { 81 | return nil, fmt.Errorf("failed to retrieve from kb %s: %w", kbID, err) 82 | } 83 | 84 | // 将结果添加到总结果中 85 | allDocuments = append(allDocuments, docs...) 
86 | } 87 | 88 | // TODO:先简单按照分数返回。这是不合理的!需要用rerank! 89 | sort.Slice(allDocuments, func(i, j int) bool { 90 | scoreI, okI := allDocuments[i].MetaData["score"].(float64) 91 | scoreJ, okJ := allDocuments[j].MetaData["score"].(float64) 92 | 93 | // 如果score不存在或类型不正确,则将其视为最低优先级 94 | if !okI { 95 | return false 96 | } 97 | if !okJ { 98 | return true 99 | } 100 | 101 | return scoreI > scoreJ // 降序排序 102 | }) 103 | 104 | if len(allDocuments) > m.TopK { 105 | allDocuments = allDocuments[:m.TopK] 106 | } 107 | log.Printf("[Multi Retriever] Retrieved %d documents from %d knowledge bases", len(allDocuments), len(m.KBIDs)) 108 | return allDocuments, nil 109 | } 110 | 111 | func (m *MultiKBRetriever) GetType() string { 112 | return "MultiKnowledgeBaseRetriever" 113 | } 114 | -------------------------------------------------------------------------------- /internal/component/retriever/milvus/retriever.go: -------------------------------------------------------------------------------- 1 | package milvus 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "ai-cloud/internal/utils" 6 | "ai-cloud/pkgs/consts" 7 | "context" 8 | "fmt" 9 | "github.com/bytedance/sonic" 10 | "github.com/cloudwego/eino/components/embedding" 11 | "github.com/cloudwego/eino/components/retriever" 12 | "github.com/cloudwego/eino/schema" 13 | "github.com/milvus-io/milvus-sdk-go/v2/client" 14 | "github.com/milvus-io/milvus-sdk-go/v2/entity" 15 | "strings" 16 | ) 17 | 18 | type MilvusRetrieverConfig struct { 19 | Client client.Client // Required 20 | Embedding embedding.Embedder // Required 21 | Collection string // Required 22 | KBIDs []string // Required 至少要查询一个知识库 23 | SearchFields []string // Optional defaultSearchFields 24 | TopK int // Optional default is 5 25 | ScoreThreshold float64 // Optional default is 0 26 | } 27 | 28 | type MilvusRetriever struct { 29 | config MilvusRetrieverConfig 30 | } 31 | 32 | func NewMilvusRetriever(ctx context.Context, conf *MilvusRetrieverConfig) (*MilvusRetriever, error) { 33 | // 
检查必要配置,设置默认值 34 | if err := conf.check(); err != nil { 35 | return nil, fmt.Errorf("[NewMilvusRetriever] check config failed : %w", err) 36 | } 37 | // 检查Collection是否存在 38 | exists, err := conf.Client.HasCollection(ctx, conf.Collection) 39 | if err != nil { 40 | return nil, fmt.Errorf("[NewMilvusRetriever] check milvus collection failed : %w", err) 41 | } 42 | if !exists { 43 | return nil, fmt.Errorf("[NewMilvusRetirever] collection %s not exists", conf.Collection) 44 | } 45 | 46 | // 检查是否load,没load的话load 47 | collection, err := conf.Client.DescribeCollection(ctx, conf.Collection) 48 | if err != nil { 49 | return nil, fmt.Errorf("[NewRetriever] failed to describe collection: %w", err) 50 | } 51 | 52 | if !collection.Loaded { 53 | err = conf.Client.LoadCollection(ctx, conf.Collection, false) 54 | if err != nil { 55 | return nil, fmt.Errorf("[NewMilvusRetriever] failed to load collection: %w", err) 56 | } 57 | } 58 | 59 | return &MilvusRetriever{ 60 | config: *conf, 61 | }, nil 62 | } 63 | 64 | var FieldNameVector = "vector" 65 | 66 | func (m *MilvusRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { 67 | // retrieve的时候指定参数 68 | co := retriever.GetCommonOptions(&retriever.Options{ 69 | SubIndex: nil, 70 | TopK: &m.config.TopK, 71 | ScoreThreshold: &m.config.ScoreThreshold, 72 | Embedding: m.config.Embedding, 73 | }, opts...) 
74 | 75 | emb := co.Embedding 76 | vectors, err := emb.EmbedStrings(ctx, []string{query}) 77 | if err != nil { 78 | return nil, fmt.Errorf("[MilvusRetriver.Retrieve] embedding has error: %w", err) 79 | } 80 | // 检查结果数量是否正确 81 | if len(vectors) != 1 { 82 | return nil, fmt.Errorf("[MilvusRetriver.Retrieve] invalid return length of vector, got=%d, expected=1", len(vectors)) 83 | } 84 | 85 | vector := utils.ConvertFloat64ToFloat32Embedding(vectors[0]) 86 | 87 | // 构造查询条件和参数 88 | kbIDs := m.config.KBIDs 89 | var expr string 90 | if len(kbIDs) > 0 { 91 | quotedIDs := make([]string, len(kbIDs)) 92 | for i, id := range kbIDs { 93 | quotedIDs[i] = fmt.Sprintf(`"%s"`, id) 94 | } 95 | expr = fmt.Sprintf("%s in [%s]", consts.FieldNameKBID, strings.Join(quotedIDs, ",")) 96 | } else { 97 | expr = "0 == 1" 98 | } 99 | 100 | var results []client.SearchResult 101 | sp, _ := entity.NewIndexIvfFlatSearchParam(config.GetConfig().Milvus.Nprobe) 102 | metricType := config.GetConfig().Milvus.GetMetricType() 103 | results, err = m.config.Client.Search( 104 | ctx, 105 | m.config.Collection, // 集合名称:指定要搜索的Milvus集合 106 | []string{}, // 分区名称:空表示搜索所有分区 107 | expr, // 过滤表达式:限制搜索范围,这里只搜索指定知识库ID的文档 108 | m.config.SearchFields, // 输出字段:指定返回结果中包含哪些字段 109 | []entity.Vector{entity.FloatVector(vector)}, // 查询向量:将输入向量转换为Milvus向量格式 110 | consts.FieldNameVector, // 向量字段名:指定在哪个字段上执行向量搜索(对应Index) 111 | metricType, // 度量类型:如何计算向量相似度(如余弦相似度、欧几里得距离等) 112 | m.config.TopK, // 返回数量:返回的最相似结果数量 113 | sp, // 搜索参数:索引特定的搜索参数,如nprobe(探测聚类数) 114 | ) 115 | 116 | documents := make([]*schema.Document, 0, len(results)) 117 | for _, result := range results { 118 | if result.Err != nil { 119 | return nil, fmt.Errorf("[MilvusRetriver.Retrieve] search result has error: %w", result.Err) 120 | } 121 | if result.IDs == nil || result.Fields == nil { 122 | return nil, fmt.Errorf("[MilvusRetriver.Retrieve] search result has no ids or fields") 123 | } 124 | document, err := DocumentConverter(ctx, result) 125 | if err != nil { 126 | 
return nil, fmt.Errorf("[MilvusRetriver.Retrieve] failed to convert search result to schema.Document: %w", err) 127 | } 128 | documents = append(documents, document...) 129 | } 130 | return documents, nil 131 | } 132 | 133 | // defaultDocumentConverter returns the default document converter 134 | func DocumentConverter(ctx context.Context, doc client.SearchResult) ([]*schema.Document, error) { 135 | var err error 136 | result := make([]*schema.Document, doc.IDs.Len(), doc.IDs.Len()) 137 | for i := range result { 138 | result[i] = &schema.Document{ 139 | MetaData: make(map[string]any), 140 | } 141 | } 142 | 143 | importantMetaFields := map[string]bool{ 144 | consts.FieldNameDocumentID: true, 145 | consts.FieldNameKBID: true, 146 | } 147 | 148 | for _, field := range doc.Fields { 149 | switch field.Name() { 150 | case consts.FieldNameID: 151 | for i, document := range result { 152 | document.ID, err = doc.IDs.GetAsString(i) 153 | if err != nil { 154 | return nil, fmt.Errorf("failed to get id: %w", err) 155 | } 156 | } 157 | case consts.FieldNameContent: 158 | for i, document := range result { 159 | document.Content, err = field.GetAsString(i) 160 | if err != nil { 161 | return nil, fmt.Errorf("failed to get content: %w", err) 162 | } 163 | } 164 | case consts.FieldNameMetadata: 165 | for i := range result { 166 | val, _ := field.Get(i) 167 | bytes, ok := val.([]byte) 168 | if !ok { 169 | return nil, fmt.Errorf("metadata field is not []byte") 170 | } 171 | var meta map[string]any 172 | if err := sonic.Unmarshal(bytes, &meta); err != nil { 173 | return nil, fmt.Errorf("unmarshal metadata failed: %w", err) 174 | } 175 | for k, v := range meta { 176 | result[i].MetaData[k] = v 177 | } 178 | } 179 | default: 180 | if importantMetaFields[field.Name()] { 181 | for i := range result { 182 | val, err := field.GetAsString(i) 183 | if err != nil { 184 | return nil, fmt.Errorf("get field %s failed: %w", field.Name(), err) 185 | } 186 | result[i].MetaData[field.Name()] = val 187 
| } 188 | } 189 | } 190 | } 191 | 192 | for i := range result { 193 | if i >= len(doc.Scores) { 194 | continue 195 | } 196 | result[i].MetaData["score"] = doc.Scores[i] 197 | } 198 | 199 | return result, nil 200 | } 201 | 202 | func (m *MilvusRetriever) GetType() string { 203 | return "Milvus" 204 | } 205 | 206 | // 检查必要配置 207 | func (m *MilvusRetrieverConfig) check() error { 208 | if m.Client == nil { 209 | return fmt.Errorf("[NewMilvusRetriever] milvus client is nil") 210 | } 211 | if m.Embedding == nil { 212 | return fmt.Errorf("[NewMilvusRetriever] embedding is nil") 213 | } 214 | if m.Collection == "" { 215 | return fmt.Errorf("[NewMilvusRetriever] collection is empty") 216 | } 217 | if m.SearchFields == nil { 218 | m.SearchFields = defaultSearchFields // 默认搜索字段 219 | } 220 | if m.TopK == 0 { 221 | m.TopK = 5 // 默认返回结果数量 222 | } 223 | if m.ScoreThreshold == 0 { 224 | m.ScoreThreshold = 0 // 默认相似度阈值 225 | } 226 | 227 | return nil 228 | } 229 | -------------------------------------------------------------------------------- /internal/component/retriever/milvus/types.go: -------------------------------------------------------------------------------- 1 | package milvus 2 | 3 | import "ai-cloud/pkgs/consts" 4 | 5 | var ( 6 | defaultSearchFields = []string{ 7 | consts.FieldNameID, 8 | consts.FieldNameContent, 9 | consts.FieldNameKBID, 10 | consts.FieldNameDocumentID, 11 | consts.FieldNameMetadata, 12 | } 13 | ) 14 | -------------------------------------------------------------------------------- /internal/controller/conversation_controller.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "ai-cloud/internal/service" 6 | "ai-cloud/internal/utils" 7 | "ai-cloud/pkgs/errcode" 8 | "ai-cloud/pkgs/response" 9 | "errors" 10 | "io" 11 | "log" 12 | 13 | "github.com/gin-contrib/sse" 14 | "github.com/gin-gonic/gin" 15 | "github.com/google/uuid" 16 | ) 17 | 18 | type 
ConversationController struct { 19 | svc service.ConversationService 20 | } 21 | 22 | func NewConversationController(svc service.ConversationService) *ConversationController { 23 | return &ConversationController{svc: svc} 24 | } 25 | 26 | // DebugStreamAgent 调试模式,不保存历史 27 | func (c *ConversationController) DebugStreamAgent(ctx *gin.Context) { 28 | userID, err := utils.GetUserIDFromContext(ctx) 29 | if err != nil { 30 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user") 31 | return 32 | } 33 | 34 | var req model.DebugRequest 35 | if err := ctx.ShouldBindJSON(&req); err != nil { 36 | response.ParamError(ctx, errcode.ParamBindError, "Parameter error: "+err.Error()) 37 | return 38 | } 39 | 40 | // 调用debug模式流式处理 41 | sr, err := c.svc.DebugStreamAgent(ctx.Request.Context(), userID, req.AgentID, req.Message) 42 | if err != nil { 43 | response.InternalError(ctx, errcode.InternalServerError, "Agent execution failed: "+err.Error()) 44 | return 45 | } 46 | 47 | // 设置SSE响应头 48 | ctx.Writer.Header().Set("Content-Type", "text/event-stream") 49 | ctx.Writer.Header().Set("Cache-Control", "no-cache") 50 | ctx.Writer.Header().Set("Connection", "keep-alive") 51 | ctx.Writer.Header().Set("Transfer-Encoding", "chunked") 52 | 53 | // 传输流 54 | sessionID := uuid.NewString() 55 | done := make(chan struct{}) 56 | defer func() { 57 | sr.Close() 58 | close(done) 59 | log.Printf("[Debug Stream] Finish Stream with ID: %s\n", sessionID) 60 | }() 61 | 62 | // 流式响应 63 | ctx.Stream(func(w io.Writer) bool { 64 | select { 65 | case <-ctx.Request.Context().Done(): 66 | log.Printf("[Debug Stream] Context done for session ID: %s\n", sessionID) 67 | return false 68 | case <-done: 69 | return false 70 | default: 71 | msg, err := sr.Recv() 72 | if errors.Is(err, io.EOF) { 73 | log.Printf("[Debug Stream] EOF received for session ID: %s\n", sessionID) 74 | return false 75 | } 76 | if err != nil { 77 | log.Printf("[Debug Stream] Error receiving message: %v\n", err) 78 | return 
false 79 | } 80 | 81 | // 发送SSE事件 82 | sse.Encode(w, sse.Event{ 83 | Data: []byte(msg.Content), 84 | }) 85 | 86 | // 立即刷新响应 87 | ctx.Writer.Flush() 88 | return true 89 | } 90 | }) 91 | } 92 | 93 | // CreateConversation 创建新会话 94 | func (c *ConversationController) CreateConversation(ctx *gin.Context) { 95 | userID, err := utils.GetUserIDFromContext(ctx) 96 | if err != nil { 97 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user") 98 | return 99 | } 100 | 101 | var req model.CreateConvRequest 102 | if err := ctx.ShouldBindJSON(&req); err != nil { 103 | response.ParamError(ctx, errcode.ParamBindError, "Parameter error: "+err.Error()) 104 | return 105 | } 106 | 107 | // 创建会话 108 | convID, err := c.svc.CreateConversation(ctx.Request.Context(), userID, req.AgentID) 109 | if err != nil { 110 | log.Printf("[Conversation Create] Error creating conversation: %v\n", err) 111 | response.InternalError(ctx, errcode.InternalServerError, "Failed to create conversation") 112 | return 113 | } 114 | 115 | response.SuccessWithMessage(ctx, "Conversation created successfully", gin.H{"conv_id": convID}) 116 | } 117 | 118 | // StreamConversation 会话模式,保存历史 119 | func (c *ConversationController) StreamConversation(ctx *gin.Context) { 120 | userID, err := utils.GetUserIDFromContext(ctx) 121 | if err != nil { 122 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user") 123 | return 124 | } 125 | 126 | var req model.ConvRequest 127 | if err := ctx.ShouldBindJSON(&req); err != nil { 128 | response.ParamError(ctx, errcode.ParamBindError, "Parameter error: "+err.Error()) 129 | return 130 | } 131 | 132 | if req.ConvID == "" { 133 | convID, err := c.svc.CreateConversation(ctx.Request.Context(), userID, req.AgentID) 134 | if err != nil { 135 | log.Printf("[Conversation Create] Error creating conversation: %v\n", err) 136 | response.InternalError(ctx, errcode.InternalServerError, "Failed to create conversation") 137 | return 138 | } 139 | 
req.ConvID = convID 140 | } 141 | 142 | // 调用会话模式流式处理 143 | sr, err := c.svc.StreamAgentWithConversation(ctx.Request.Context(), userID, req.AgentID, req.ConvID, req.Message) 144 | if err != nil { 145 | log.Printf("[Conversation Stream] Error running agent: %v\n", err) 146 | response.InternalError(ctx, errcode.InternalServerError, "Agent execution failed") 147 | return 148 | } 149 | 150 | // 设置SSE响应头 151 | ctx.Writer.Header().Set("Content-Type", "text/event-stream") 152 | ctx.Writer.Header().Set("Cache-Control", "no-cache") 153 | ctx.Writer.Header().Set("Connection", "keep-alive") 154 | ctx.Writer.Header().Set("Transfer-Encoding", "chunked") 155 | 156 | // 传输流 157 | done := make(chan struct{}) 158 | defer func() { 159 | sr.Close() 160 | close(done) 161 | log.Printf("[Conversation Stream] Finish Stream with ConvID: %s\n", req.ConvID) 162 | }() 163 | 164 | // 流式响应 165 | ctx.Stream(func(w io.Writer) bool { 166 | select { 167 | case <-ctx.Request.Context().Done(): 168 | log.Printf("[Conversation Stream] Context done for ConvID: %s\n", req.ConvID) 169 | return false 170 | case <-done: 171 | return false 172 | default: 173 | msg, err := sr.Recv() 174 | if errors.Is(err, io.EOF) { 175 | log.Printf("[Conversation Stream] EOF received for ConvID: %s\n", req.ConvID) 176 | return false 177 | } 178 | if err != nil { 179 | log.Printf("[Conversation Stream] Error receiving message: %v\n", err) 180 | return false 181 | } 182 | 183 | // 发送SSE事件 184 | sse.Encode(w, sse.Event{ 185 | Data: []byte(msg.Content), 186 | }) 187 | 188 | // 立即刷新响应 189 | ctx.Writer.Flush() 190 | return true 191 | } 192 | }) 193 | } 194 | 195 | // ListConversations 获取用户所有会话 196 | func (c *ConversationController) ListConversations(ctx *gin.Context) { 197 | userID, err := utils.GetUserIDFromContext(ctx) 198 | if err != nil { 199 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user") 200 | return 201 | } 202 | 203 | // 分页参数 204 | page := utils.StringToInt(ctx.DefaultQuery("page", "1")) 
205 | size := utils.StringToInt(ctx.DefaultQuery("size", "10")) 206 | 207 | // 获取会话列表 208 | convs, count, err := c.svc.ListConversations(ctx.Request.Context(), userID, page, size) 209 | if err != nil { 210 | response.InternalError(ctx, errcode.InternalServerError, "Failed to list conversations: "+err.Error()) 211 | return 212 | } 213 | 214 | // 返回分页数据 215 | response.PageSuccess(ctx, convs, count) 216 | } 217 | 218 | // ListAgentConversations 获取特定Agent的会话 219 | func (c *ConversationController) ListAgentConversations(ctx *gin.Context) { 220 | userID, err := utils.GetUserIDFromContext(ctx) 221 | if err != nil { 222 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user") 223 | return 224 | } 225 | 226 | // 获取AgentID 227 | agentID := ctx.Query("agent_id") 228 | if agentID == "" { 229 | response.ParamError(ctx, errcode.ParamBindError, "Agent ID is required") 230 | return 231 | } 232 | 233 | // 分页参数 234 | page := utils.StringToInt(ctx.DefaultQuery("page", "1")) 235 | size := utils.StringToInt(ctx.DefaultQuery("size", "10")) 236 | 237 | // 获取会话列表 238 | convs, count, err := c.svc.ListAgentConversations(ctx.Request.Context(), userID, agentID, page, size) 239 | if err != nil { 240 | response.InternalError(ctx, errcode.InternalServerError, "Failed to list agent conversations: "+err.Error()) 241 | return 242 | } 243 | 244 | // 返回分页数据 245 | response.PageSuccess(ctx, convs, count) 246 | } 247 | 248 | // GetConversationHistory 获取会话历史消息 249 | func (c *ConversationController) GetConversationHistory(ctx *gin.Context) { 250 | // 获取会话ID 251 | convID := ctx.Query("conv_id") 252 | if convID == "" { 253 | response.ParamError(ctx, errcode.ParamBindError, "Conversation ID is required") 254 | return 255 | } 256 | 257 | // 限制参数 258 | limit := utils.StringToInt(ctx.DefaultQuery("limit", "50")) 259 | 260 | // 获取历史消息 261 | msgs, err := c.svc.GetConversationHistory(ctx.Request.Context(), convID, limit) 262 | if err != nil { 263 | response.InternalError(ctx, 
errcode.InternalServerError, "Failed to get conversation history: "+err.Error()) 264 | return 265 | } 266 | 267 | // 返回历史消息 268 | response.SuccessWithMessage(ctx, "Conversation history retrieved successfully", gin.H{"messages": msgs}) 269 | } 270 | 271 | // DeleteConversation 删除会话 272 | func (c *ConversationController) DeleteConversation(ctx *gin.Context) { 273 | // 获取会话ID 274 | convID := ctx.Query("conv_id") 275 | if convID == "" { 276 | response.ParamError(ctx, errcode.ParamBindError, "Conversation ID is required") 277 | return 278 | } 279 | 280 | // 删除会话 281 | err := c.svc.DeleteConversation(ctx.Request.Context(), convID) 282 | if err != nil { 283 | log.Printf("[Conversation Delete] Error deleting conversation: %v\n", err) 284 | response.InternalError(ctx, errcode.InternalServerError, "Failed to delete conversation: "+err.Error()) 285 | return 286 | } 287 | 288 | // 返回成功消息 289 | response.SuccessWithMessage(ctx, "Conversation deleted successfully", nil) 290 | } 291 | -------------------------------------------------------------------------------- /internal/controller/file_controller.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "ai-cloud/internal/service" 6 | "ai-cloud/internal/utils" 7 | "ai-cloud/pkgs/errcode" 8 | "ai-cloud/pkgs/response" 9 | "fmt" 10 | "net/http" 11 | "strconv" 12 | 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | type FileController struct { 17 | fileService service.FileService 18 | } 19 | 20 | func NewFileController(fileService service.FileService) *FileController { 21 | return &FileController{fileService: fileService} 22 | } 23 | 24 | func (fc *FileController) Upload(ctx *gin.Context) { 25 | // 1. 获取用户ID 26 | userID, err := utils.GetUserIDFromContext(ctx) 27 | if err != nil { 28 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 29 | return 30 | } 31 | 32 | // 2. 
解析表单文件 33 | fileHeader, err := ctx.FormFile("file") 34 | if err != nil { 35 | response.ParamError(ctx, errcode.ParamBindError, "上传失败") 36 | return 37 | } 38 | 39 | // 获取文件内容 40 | file, err := fileHeader.Open() 41 | if err != nil { 42 | 43 | response.ParamError(ctx, errcode.FileParseFailed, "上传失败") 44 | return 45 | } 46 | defer file.Close() 47 | 48 | // 4. 获取父目录ID(可选参数) 49 | parentID := ctx.PostForm("parent_id") // 空字符串表示根目录 50 | 51 | // 调用 Service 层处理文件上传 52 | _, err = fc.fileService.UploadFile(userID, fileHeader, file, parentID) 53 | if err != nil { 54 | response.InternalError(ctx, errcode.FileUploadFailed, "上传失败") 55 | return 56 | } 57 | response.SuccessWithMessage(ctx, "文件上传成功", nil) 58 | } 59 | 60 | func (fc *FileController) PageList(ctx *gin.Context) { 61 | // 获取用户ID并验证 62 | userID, err := utils.GetUserIDFromContext(ctx) 63 | if err != nil { 64 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 65 | return 66 | } 67 | // 获取父目录,处理根目录情况 68 | parentID := ctx.Query("parent_id") 69 | var parentIDPtr *string 70 | if parentID != "" { 71 | parentIDPtr = &parentID 72 | } 73 | 74 | page, pageSize, err := utils.ParsePaginationParams(ctx) 75 | if err != nil { 76 | response.ParamError(ctx, errcode.ParamBindError, "分页参数错误") 77 | return 78 | } 79 | 80 | sort := ctx.DefaultQuery("sort", "name:asc") 81 | if err := utils.ValidateSortParameter(sort, []string{"name", "updated_at"}); err != nil { 82 | response.ParamError(ctx, errcode.ParamValidateError, "排序参数错误") 83 | return 84 | } 85 | 86 | total, files, err := fc.fileService.PageList(userID, parentIDPtr, page, pageSize, sort) 87 | if err != nil { 88 | response.InternalError(ctx, errcode.FileListFailed, "获取文件列表失败") 89 | return 90 | } 91 | 92 | response.PageSuccess(ctx, files, total) 93 | } 94 | 95 | func (fc *FileController) Download(ctx *gin.Context) { 96 | fileID := ctx.Query("file_id") 97 | 98 | userID, err := utils.GetUserIDFromContext(ctx) 99 | if err != nil { 100 | response.UnauthorizedError(ctx, 
errcode.UnauthorizedError, "用户验证失败") 101 | return 102 | } 103 | 104 | fileMeta, fileData, err := fc.fileService.DownloadFile(fileID) 105 | 106 | if err != nil { 107 | response.InternalError(ctx, errcode.FileNotFound, "文件不存在") 108 | return 109 | } 110 | 111 | if userID != fileMeta.UserID { 112 | response.UnauthorizedError(ctx, errcode.ForbiddenError, "权限不足") 113 | return 114 | } 115 | ctx.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", fileMeta.Name)) 116 | ctx.Header("Content-Type", fileMeta.MIMEType) 117 | ctx.Header("Content-Length", strconv.FormatInt(fileMeta.Size, 10)) 118 | ctx.Data(http.StatusOK, fileMeta.MIMEType, fileData) 119 | } 120 | 121 | func (fc *FileController) Delete(ctx *gin.Context) { 122 | fileID := ctx.Query("file_id") 123 | if fileID == "" { 124 | response.ParamError(ctx, errcode.ParamValidateError, "参数错误") 125 | return 126 | } 127 | userID, err := utils.GetUserIDFromContext(ctx) 128 | if err != nil { 129 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户未认证") 130 | return 131 | } 132 | 133 | if err := fc.fileService.DeleteFileOrFolder(userID, fileID); err != nil { 134 | response.InternalError(ctx, errcode.FileDeleteFailed, "删除失败") 135 | return 136 | } 137 | 138 | response.SuccessWithMessage(ctx, "删除成功", nil) 139 | } 140 | 141 | func (fc *FileController) CreateFolder(ctx *gin.Context) { 142 | //var req models.CreateFolderRequest 143 | //if err := ctx.ShouldBind(&req); err != nil { 144 | // ctx.JSON(http.StatusBadRequest, common.Error(100, "参数错误")) 145 | // return 146 | //} 147 | req := model.CreateFolderReq{} 148 | err := ctx.ShouldBindJSON(&req) 149 | if err != nil { 150 | response.ParamError(ctx, errcode.ParamBindError, "参数错误") 151 | return 152 | } 153 | 154 | if req.ParentID != nil && *req.ParentID == "" { 155 | req.ParentID = nil 156 | } 157 | 158 | userID, err := utils.GetUserIDFromContext(ctx) 159 | if err != nil { 160 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 161 | 
return 162 | } 163 | 164 | err = fc.fileService.CreateFolder(userID, req.Name, req.ParentID) 165 | if err != nil { 166 | response.InternalError(ctx, errcode.InternalServerError, "文件夹创建失败") 167 | return 168 | } 169 | response.SuccessWithMessage(ctx, "创建成功", nil) 170 | } 171 | 172 | // BatchMove 批量移动文件/文件夹 173 | func (fc *FileController) BatchMove(ctx *gin.Context) { 174 | var req model.BatchMoveRequest 175 | if err := ctx.ShouldBindJSON(&req); err != nil { 176 | response.ParamError(ctx, errcode.ParamBindError, "参数错误") 177 | return 178 | } 179 | 180 | // 获取当前用户ID 181 | userID, err := utils.GetUserIDFromContext(ctx) 182 | if err != nil { 183 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 184 | return 185 | } 186 | 187 | // 执行批量移动 188 | if err := fc.fileService.BatchMoveFiles(userID, req.FileIDs, req.TargetParentID); err != nil { 189 | response.InternalError(ctx, errcode.InternalServerError, "移动失败") 190 | return 191 | } 192 | 193 | response.SuccessWithMessage(ctx, "移动成功", nil) 194 | } 195 | 196 | func (fc *FileController) Search(ctx *gin.Context) { 197 | userID, err := utils.GetUserIDFromContext(ctx) 198 | if err != nil { 199 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户校验失败") 200 | return 201 | } 202 | 203 | key := ctx.Query("key") 204 | 205 | page, pageSize, err := utils.ParsePaginationParams(ctx) 206 | if err != nil { 207 | response.ParamError(ctx, errcode.ParamValidateError, "分页参数错误") 208 | return 209 | } 210 | 211 | sort := ctx.DefaultQuery("sort", "name:asc") 212 | if err := utils.ValidateSortParameter(sort, []string{"name", "upadated_at"}); err != nil { 213 | response.ParamError(ctx, errcode.ParamValidateError, "排序参数错误") 214 | } 215 | 216 | total, files, err := fc.fileService.SearchList(userID, key, page, pageSize, sort) 217 | if err != nil { 218 | response.InternalError(ctx, errcode.FileSearchFailed, "搜索文件失败") 219 | } 220 | 221 | response.PageSuccess(ctx, files, total) 222 | 223 | } 224 | 225 | func (fc 
*FileController) Rename(ctx *gin.Context) { 226 | var req model.RenameRequest 227 | if err := ctx.ShouldBindJSON(&req); err != nil { 228 | response.ParamError(ctx, errcode.ParamBindError, "参数错误") 229 | return 230 | } 231 | 232 | userID, err := utils.GetUserIDFromContext(ctx) 233 | if err != nil { 234 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 235 | return 236 | } 237 | 238 | if err = fc.fileService.Rename(userID, req.FileID, req.NewName); err != nil { 239 | response.InternalError(ctx, errcode.InternalServerError, fmt.Sprintf("重命名失败 %s", err)) 240 | return 241 | } 242 | 243 | response.SuccessWithMessage(ctx, "重命名成功", nil) 244 | 245 | } 246 | 247 | // GetPath 获取文件的完整路径 248 | func (fc *FileController) GetPath(ctx *gin.Context) { 249 | fileID := ctx.Query("file_id") 250 | if fileID == "" { 251 | response.ParamError(ctx, errcode.ParamValidateError, "文件ID不能为空") 252 | return 253 | } 254 | 255 | // 获取文件路径 256 | path, err := fc.fileService.GetFilePath(fileID) 257 | if err != nil { 258 | response.InternalError(ctx, errcode.FileNotFound, "获取文件路径失败") 259 | return 260 | } 261 | 262 | response.SuccessWithMessage(ctx, "获取文件路径成功", gin.H{ 263 | "path": path, 264 | }) 265 | } 266 | 267 | // GetIDPath 获取文件的ID路径 268 | func (fc *FileController) GetIDPath(ctx *gin.Context) { 269 | fileID := ctx.Query("file_id") 270 | if fileID == "" { 271 | response.ParamError(ctx, errcode.ParamValidateError, "文件ID不能为空") 272 | return 273 | } 274 | 275 | // 获取文件ID路径 276 | path, err := fc.fileService.GetFileIDPath(fileID) 277 | if err != nil { 278 | response.InternalError(ctx, errcode.FileNotFound, "获取文件路径失败") 279 | return 280 | } 281 | 282 | response.SuccessWithMessage(ctx, "获取文件ID路径成功", gin.H{ 283 | "id_path": path, 284 | }) 285 | } 286 | -------------------------------------------------------------------------------- /internal/controller/kb_controller.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 
4 | "ai-cloud/internal/model" 5 | "ai-cloud/internal/service" 6 | "ai-cloud/internal/utils" 7 | "ai-cloud/pkgs/errcode" 8 | "ai-cloud/pkgs/response" 9 | "encoding/json" 10 | "fmt" 11 | "path/filepath" 12 | "strings" 13 | 14 | "github.com/gin-gonic/gin" 15 | ) 16 | 17 | type KBController struct { 18 | kbService service.KBService 19 | fileService service.FileService 20 | } 21 | 22 | func NewKBController(kbService service.KBService, fileService service.FileService) *KBController { 23 | return &KBController{kbService: kbService, fileService: fileService} 24 | } 25 | 26 | func (kc *KBController) Create(ctx *gin.Context) { 27 | // 1. 获取用户ID 28 | userID, err := utils.GetUserIDFromContext(ctx) 29 | if err != nil { 30 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 31 | return 32 | } 33 | 34 | var req model.CreateKBRequest 35 | if err := ctx.ShouldBindJSON(&req); err != nil { 36 | response.ParamError(ctx, errcode.ParamBindError, "参数错误") 37 | return 38 | } 39 | 40 | if err := kc.kbService.CreateKB(userID, req.Name, req.Description, req.EmbedModelID); err != nil { 41 | response.InternalError(ctx, errcode.InternalServerError, "创建失败: "+err.Error()) 42 | return 43 | } 44 | 45 | response.SuccessWithMessage(ctx, "创建知识库成功", nil) 46 | } 47 | 48 | // 删除知识库 49 | func (kc *KBController) Delete(ctx *gin.Context) { 50 | // 获取用户ID并验证 51 | userID, err := utils.GetUserIDFromContext(ctx) 52 | if err != nil { 53 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 54 | return 55 | } 56 | 57 | kbID := ctx.Query("kb_id") 58 | 59 | // 删除知识库 60 | if err := kc.kbService.DeleteKB(userID, kbID); err != nil { 61 | response.InternalError(ctx, errcode.InternalServerError, err.Error()) 62 | return 63 | } 64 | response.SuccessWithMessage(ctx, "删除知识库成功", nil) 65 | } 66 | 67 | func (kc *KBController) PageList(ctx *gin.Context) { 68 | // 获取用户ID并验证 69 | userID, err := utils.GetUserIDFromContext(ctx) 70 | if err != nil { 71 | response.UnauthorizedError(ctx, 
errcode.UnauthorizedError, "用户验证失败") 72 | return 73 | } 74 | 75 | page, pageSize, err := utils.ParsePaginationParams(ctx) 76 | if err != nil { 77 | response.ParamError(ctx, errcode.ParamBindError, "分页参数错误") 78 | return 79 | } 80 | 81 | total, kbs, err := kc.kbService.PageList(userID, page, pageSize) 82 | if err != nil { 83 | response.InternalError(ctx, errcode.InternalServerError, "获取知识库列表失败") 84 | return 85 | } 86 | 87 | response.PageSuccess(ctx, kbs, total) 88 | } 89 | 90 | func (kc *KBController) DocPage(ctx *gin.Context) { 91 | // 获取用户ID并验证 92 | userID, err := utils.GetUserIDFromContext(ctx) 93 | if err != nil { 94 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 95 | return 96 | } 97 | page, pageSize, err := utils.ParsePaginationParams(ctx) 98 | if err != nil { 99 | response.ParamError(ctx, errcode.ParamBindError, "分页参数错误") 100 | return 101 | } 102 | kbID := ctx.Query("kb_id") 103 | total, docs, err := kc.kbService.DocList(userID, kbID, page, pageSize) 104 | if err != nil { 105 | response.InternalError(ctx, errcode.InternalServerError, "获取列表失败") 106 | fmt.Printf(err.Error()) 107 | return 108 | } 109 | response.PageSuccess(ctx, docs, total) 110 | } 111 | func (kc *KBController) AddExistFile(ctx *gin.Context) { 112 | // 获取用户ID并验证 113 | userID, err := utils.GetUserIDFromContext(ctx) 114 | if err != nil { 115 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 116 | return 117 | } 118 | req := model.AddFileRequest{} 119 | 120 | if err := ctx.ShouldBindJSON(&req); err != nil { 121 | response.ParamError(ctx, errcode.ParamBindError, "参数错误") 122 | return 123 | } 124 | file, err := kc.fileService.GetFileByID(req.FileID) 125 | if err != nil { 126 | response.InternalError(ctx, errcode.InternalServerError, "获取文件信息失败") 127 | return 128 | } 129 | 130 | // 添加文件到知识库 131 | doc, err := kc.kbService.CreateDocument(userID, req.KBID, file) 132 | if err != nil { 133 | response.InternalError(ctx, errcode.InternalServerError, "添加文件到知识库失败") 134 
| return 135 | } 136 | // 处理解析文档 137 | doc.Status = 1 //正在处理文档 138 | 139 | if err = kc.kbService.ProcessDocument(ctx, userID, req.KBID, doc); err != nil { 140 | response.InternalError(ctx, errcode.InternalServerError, err.Error()) 141 | return 142 | } 143 | response.SuccessWithMessage(ctx, "添加文件到知识库成功", nil) 144 | } 145 | 146 | // 上传新的文件到知识库 147 | func (kc *KBController) AddNewFile(ctx *gin.Context) { 148 | userID, err := utils.GetUserIDFromContext(ctx) 149 | if err != nil { 150 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 151 | return 152 | } 153 | 154 | // 获取知识库ID 155 | kbID := ctx.PostForm("kb_id") 156 | if kbID == "" { 157 | response.ParamError(ctx, errcode.ParamBindError, "知识库ID不能为空") 158 | return 159 | } 160 | 161 | fileHeader, err := ctx.FormFile("file") 162 | if err != nil { 163 | response.ParamError(ctx, errcode.ParamBindError, "文件上传失败") 164 | return 165 | } 166 | 167 | // 检查文件大小 168 | if fileHeader.Size > 20*1024*1024 { // 20MB限制 169 | response.ParamError(ctx, errcode.ParamBindError, "文件大小不能超过20MB") 170 | return 171 | } 172 | 173 | file, err := fileHeader.Open() 174 | if err != nil { 175 | response.ParamError(ctx, errcode.ParamBindError, "文件打开失败") 176 | return 177 | } 178 | defer file.Close() 179 | 180 | folderID, err := kc.fileService.InitKnowledgeDir(userID) 181 | if err != nil { 182 | response.InternalError(ctx, errcode.InternalServerError, "初始化知识库目录失败"+err.Error()) 183 | return 184 | } 185 | fileID, err := kc.fileService.UploadFile(userID, fileHeader, file, folderID) 186 | if err != nil { 187 | response.InternalError(ctx, errcode.InternalServerError, "文件上传失败") 188 | return 189 | } 190 | 191 | // 将文件添加到知识库中 192 | f, err := kc.fileService.GetFileByID(fileID) 193 | if err != nil || f == nil { // 添加对 nil 的检查 194 | response.InternalError(ctx, errcode.InternalServerError, "获取文件信息失败") 195 | return 196 | } 197 | 198 | // 文档名称长度检查 199 | if len(f.Name) > 200 { 200 | // 截断文件名 201 | nameBase := filepath.Base(f.Name) 202 | nameExt := 
filepath.Ext(nameBase) 203 | nameWithoutExt := strings.TrimSuffix(nameBase, nameExt) 204 | if len(nameWithoutExt) > 195 { 205 | nameWithoutExt = nameWithoutExt[:195] 206 | } 207 | f.Name = nameWithoutExt + nameExt 208 | } 209 | 210 | doc, err := kc.kbService.CreateDocument(userID, kbID, f) 211 | if err != nil { 212 | response.InternalError(ctx, errcode.InternalServerError, "添加文件到知识库失败") 213 | return 214 | } 215 | 216 | doc.Status = 1 // 正在处理文档 217 | if err = kc.kbService.ProcessDocument(ctx, userID, kbID, doc); err != nil { 218 | response.InternalError(ctx, errcode.InternalServerError, "处理文档失败: "+err.Error()) 219 | return 220 | } 221 | response.SuccessWithMessage(ctx, "添加文件到知识库成功", nil) 222 | } 223 | 224 | func (kc *KBController) Retrieve(ctx *gin.Context) { 225 | // 1. 获取用户ID 226 | userID, err := utils.GetUserIDFromContext(ctx) 227 | if err != nil { 228 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 229 | return 230 | } 231 | 232 | // 2. 解析请求参数 233 | var req model.RetrieveRequest 234 | 235 | if err := ctx.ShouldBindJSON(&req); err != nil { 236 | response.ParamError(ctx, errcode.ParamBindError, "参数错误") 237 | return 238 | } 239 | 240 | // 3. 调用服务层检索 241 | docs, err := kc.kbService.Retrieve(ctx, userID, req.KBID, req.Query, req.TopK) 242 | if err != nil { 243 | response.InternalError(ctx, errcode.InternalServerError, err.Error()) 244 | return 245 | } 246 | 247 | // 4. 返回结果 248 | response.Success(ctx, docs) 249 | } 250 | 251 | func (kc *KBController) Chat(ctx *gin.Context) { 252 | userID, err := utils.GetUserIDFromContext(ctx) 253 | if err != nil { 254 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 255 | return 256 | } 257 | 258 | // 2. 解析请求参数 259 | var req model.ChatRequest 260 | if err := ctx.ShouldBindJSON(&req); err != nil { 261 | response.ParamError(ctx, errcode.ParamBindError, "参数错误") 262 | return 263 | } 264 | 265 | // 3. 
调用服务层处理 266 | resp, err := kc.kbService.RAGQuery(ctx, userID, req.Query, req.KBs) 267 | if err != nil { 268 | response.InternalError(ctx, errcode.InternalServerError, err.Error()) 269 | return 270 | } 271 | 272 | // 4. 返回结果 273 | response.Success(ctx, resp) 274 | 275 | } 276 | 277 | func (kc *KBController) ChatStream(ctx *gin.Context) { 278 | // 1. 获取用户ID 279 | userID, err := utils.GetUserIDFromContext(ctx) 280 | if err != nil { 281 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 282 | return 283 | } 284 | 285 | // 2. 解析请求参数 286 | var req model.ChatRequest 287 | if err := ctx.ShouldBindJSON(&req); err != nil { 288 | response.ParamError(ctx, errcode.ParamBindError, "参数错误") 289 | return 290 | } 291 | 292 | // 3. 设置响应头 293 | ctx.Writer.Header().Set("Content-Type", "text/event-stream") 294 | ctx.Writer.Header().Set("Cache-Control", "no-cache") 295 | ctx.Writer.Header().Set("Connection", "keep-alive") 296 | 297 | // 4. 调用服务层获取流式响应 298 | responseChan, err := kc.kbService.RAGQueryStream(ctx.Request.Context(), userID, req.Query, req.KBs) 299 | if err != nil { 300 | ctx.SSEvent("error", err.Error()) 301 | return 302 | } 303 | 304 | // 5. 
发送流式响应 305 | for r := range responseChan { 306 | data, _ := json.Marshal(r) 307 | ctx.Writer.Write([]byte("data: " + string(data) + "\n\n")) 308 | ctx.Writer.Flush() 309 | } 310 | } 311 | 312 | func (kc *KBController) GetKBDetail(ctx *gin.Context) { 313 | // 获取用户ID并验证 314 | userID, err := utils.GetUserIDFromContext(ctx) 315 | if err != nil { 316 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 317 | return 318 | } 319 | 320 | // 获取知识库ID 321 | kbID := ctx.Query("kb_id") 322 | if kbID == "" { 323 | response.ParamError(ctx, errcode.ParamBindError, "知识库ID不能为空") 324 | return 325 | } 326 | 327 | // 获取知识库详情 328 | kb, err := kc.kbService.GetKBDetail(userID, kbID) 329 | if err != nil { 330 | response.InternalError(ctx, errcode.InternalServerError, "获取知识库详情失败") 331 | return 332 | } 333 | 334 | response.Success(ctx, kb) 335 | } 336 | 337 | func (kc *KBController) DeleteDocs(ctx *gin.Context) { 338 | 339 | // 获取用户ID并验证 340 | userID, err := utils.GetUserIDFromContext(ctx) 341 | if err != nil { 342 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败") 343 | return 344 | } 345 | 346 | req := model.BatchDeleteDocsReq{} 347 | 348 | if err = ctx.ShouldBindJSON(&req); err != nil { 349 | response.ParamError(ctx, errcode.ParamBindError, "参数错误") 350 | return 351 | } 352 | 353 | docIDs := req.DocIDs 354 | kbID := req.KBID 355 | 356 | if len(docIDs) == 0 { 357 | response.SuccessWithMessage(ctx, "删除知识库成功", nil) 358 | return 359 | } 360 | 361 | if err := kc.kbService.DeleteDocs(userID, kbID, docIDs); err != nil { 362 | response.InternalError(ctx, errcode.InternalServerError, "删除文档失败") 363 | return 364 | } 365 | response.SuccessWithMessage(ctx, "删除知识库成功", nil) 366 | } 367 | -------------------------------------------------------------------------------- /internal/controller/model_controller.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | 
"ai-cloud/internal/service" 6 | "ai-cloud/internal/utils" 7 | "ai-cloud/pkgs/errcode" 8 | "ai-cloud/pkgs/response" 9 | "fmt" 10 | "github.com/gin-gonic/gin" 11 | ) 12 | 13 | type ModelController struct { 14 | svc service.ModelService 15 | } 16 | 17 | func NewModelController(svc service.ModelService) *ModelController { 18 | return &ModelController{svc: svc} 19 | } 20 | 21 | func (c *ModelController) CreateModel(ctx *gin.Context) { 22 | // 获取用户ID并验证 23 | userID, err := utils.GetUserIDFromContext(ctx) 24 | if err != nil { 25 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败") 26 | return 27 | } 28 | 29 | var req model.CreateModelRequest 30 | if err := ctx.ShouldBindJSON(&req); err != nil { 31 | response.ParamError(ctx, errcode.ParamBindError, "参数错误:"+err.Error()) 32 | return 33 | } 34 | 35 | m := &model.Model{ 36 | ID: utils.GenerateUUID(), 37 | UserID: userID, 38 | Type: req.Type, 39 | ShowName: req.ShowName, 40 | Server: req.Server, 41 | BaseURL: req.BaseURL, 42 | ModelName: req.ModelName, 43 | APIKey: req.APIKey, 44 | // embedding 45 | Dimension: req.Dimension, 46 | // llm 47 | MaxOutputLength: req.MaxOutputLength, 48 | Function: req.Function, 49 | // common 50 | MaxTokens: req.MaxTokens, 51 | } 52 | 53 | if err := c.svc.CreateModel(ctx.Request.Context(), m); err != nil { 54 | response.InternalError(ctx, errcode.InternalServerError, err.Error()) 55 | fmt.Println("创建模型失败:", err) 56 | return 57 | } 58 | 59 | response.SuccessWithMessage(ctx, "创建模型成功", nil) 60 | } 61 | 62 | // TODO:修改返回格式 63 | func (c *ModelController) UpdateModel(ctx *gin.Context) { 64 | // 获取用户ID并验证 65 | userID, err := utils.GetUserIDFromContext(ctx) 66 | if err != nil { 67 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败") 68 | return 69 | } 70 | 71 | var req model.UpdateModelRequest 72 | if err := ctx.ShouldBindJSON(&req); err != nil { 73 | response.ParamError(ctx, errcode.ParamBindError, "参数错误:"+err.Error()) 74 | return 75 | } 76 | 77 | m := &model.Model{ 78 
| ID: req.ID, 79 | UserID: userID, 80 | ShowName: req.ShowName, 81 | Server: req.Server, 82 | BaseURL: req.BaseURL, 83 | ModelName: req.ModelName, 84 | APIKey: req.APIKey, 85 | // embedding 86 | Dimension: req.Dimension, 87 | // llm 88 | MaxOutputLength: req.MaxOutputLength, 89 | Function: req.Function, 90 | // common 91 | MaxTokens: req.MaxTokens, 92 | } 93 | 94 | if err := c.svc.UpdateModel(ctx.Request.Context(), m); err != nil { 95 | response.InternalError(ctx, errcode.InternalServerError, "更新模型失败:"+err.Error()) 96 | return 97 | } 98 | response.SuccessWithMessage(ctx, "更新模型成功", nil) 99 | } 100 | 101 | func (c *ModelController) DeleteModel(ctx *gin.Context) { 102 | // 获取用户ID并验证 103 | userID, err := utils.GetUserIDFromContext(ctx) 104 | if err != nil { 105 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败") 106 | return 107 | } 108 | kbID := ctx.Query("model_id") 109 | 110 | if err := c.svc.DeleteModel(ctx.Request.Context(), userID, kbID); err != nil { 111 | response.InternalError(ctx, errcode.InternalServerError, "删除模型失败:"+err.Error()) 112 | return 113 | } 114 | 115 | response.SuccessWithMessage(ctx, "删除模型成功", nil) 116 | } 117 | 118 | func (c *ModelController) GetModel(ctx *gin.Context) { 119 | // 获取用户ID并验证 120 | userID, err := utils.GetUserIDFromContext(ctx) 121 | if err != nil { 122 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败") 123 | return 124 | } 125 | 126 | modelID := ctx.Query("model_id") 127 | 128 | m, err := c.svc.GetModel(ctx.Request.Context(), userID, modelID) 129 | if err != nil { 130 | response.InternalError(ctx, errcode.InternalServerError, "获取模型失败:"+err.Error()) 131 | return 132 | } 133 | response.SuccessWithMessage(ctx, "获取模型成功", m) 134 | } 135 | 136 | func (c *ModelController) PageModels(ctx *gin.Context) { 137 | userID, err := utils.GetUserIDFromContext(ctx) 138 | if err != nil { 139 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败") 140 | return 141 | } 142 | 143 | var req 
model.PageModelRequest 144 | 145 | if err := ctx.ShouldBindQuery(&req); err != nil { 146 | response.ParamError(ctx, errcode.ParamBindError, "参数错误:"+err.Error()) 147 | return 148 | } 149 | 150 | models, count, err := c.svc.PageModels(ctx.Request.Context(), userID, req.Type, req.Page, req.Size) 151 | if err != nil { 152 | response.InternalError(ctx, errcode.InternalServerError, "获取模型列表失败:"+err.Error()) 153 | return 154 | } 155 | 156 | response.PageSuccess(ctx, models, count) 157 | } 158 | 159 | func (c *ModelController) ListModels(ctx *gin.Context) { 160 | userID, err := utils.GetUserIDFromContext(ctx) 161 | if err != nil { 162 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败") 163 | return 164 | } 165 | 166 | modelType := ctx.Query("type") 167 | models, err := c.svc.ListModels(ctx.Request.Context(), userID, modelType) 168 | if err != nil { 169 | response.InternalError(ctx, errcode.InternalServerError, "获取模型列表失败:"+err.Error()) 170 | return 171 | } 172 | 173 | response.SuccessWithMessage(ctx, "获取模型列表成功", models) 174 | } 175 | -------------------------------------------------------------------------------- /internal/controller/user_controller.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "ai-cloud/internal/service" 6 | "ai-cloud/pkgs/errcode" 7 | "ai-cloud/pkgs/response" 8 | "github.com/gin-gonic/gin" 9 | ) 10 | 11 | type UserController struct { 12 | userService service.UserService 13 | } 14 | 15 | func NewUserController(userService service.UserService) *UserController { 16 | return &UserController{userService: userService} 17 | } 18 | 19 | func (c *UserController) Register(ctx *gin.Context) { 20 | var req model.User 21 | 22 | if err := ctx.ShouldBindJSON(&req); err != nil { 23 | response.ParamError(ctx, errcode.ParamBindError, "用户注册参数错误: "+err.Error()) 24 | return 25 | } 26 | 27 | if err := c.userService.Register(&req); err != nil { 28 | 
response.InternalError(ctx, errcode.InternalServerError, "注册失败: "+err.Error()) 29 | return 30 | } 31 | 32 | response.SuccessWithMessage(ctx, req.Username+"注册成功", nil) 33 | } 34 | 35 | func (c *UserController) Login(ctx *gin.Context) { 36 | var req model.UserNameLoginReq 37 | if err := ctx.ShouldBindJSON(&req); err != nil { 38 | response.ParamError(ctx, errcode.ParamBindError, "用户名或密码错误") 39 | return 40 | } 41 | 42 | loginResponse, err := c.userService.Login(&req) 43 | if err != nil { 44 | response.InternalError(ctx, errcode.InternalServerError, "登录失败") 45 | return 46 | } 47 | 48 | response.SuccessWithMessage(ctx, "登录成功", gin.H{ 49 | "access_token": loginResponse.AccessToken, 50 | "expires_in": loginResponse.ExpiresIn, 51 | "token_type": loginResponse.TokenType, 52 | }) 53 | } 54 | -------------------------------------------------------------------------------- /internal/dao/agent.go: -------------------------------------------------------------------------------- 1 | package dao 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "context" 6 | "errors" 7 | "gorm.io/gorm" 8 | ) 9 | 10 | type AgentDao interface { 11 | Create(ctx context.Context, agent *model.Agent) error 12 | Update(ctx context.Context, agent *model.Agent) error 13 | Delete(ctx context.Context, userID uint, agentID string) error 14 | GetByID(ctx context.Context, userID uint, agentID string) (*model.Agent, error) 15 | List(ctx context.Context, userID uint) ([]*model.Agent, error) 16 | Page(ctx context.Context, userID uint, page, size int) ([]*model.Agent, int64, error) 17 | } 18 | 19 | type agentDao struct { 20 | db *gorm.DB 21 | } 22 | 23 | func NewAgentDao(db *gorm.DB) AgentDao { 24 | return &agentDao{db: db} 25 | } 26 | 27 | func (d *agentDao) Create(ctx context.Context, agent *model.Agent) error { 28 | return d.db.WithContext(ctx).Create(agent).Error 29 | } 30 | 31 | func (d *agentDao) Update(ctx context.Context, agent *model.Agent) error { 32 | // Check if the agent belongs to the user 33 | var 
count int64 34 | if err := d.db.WithContext(ctx).Model(&model.Agent{}).Where("id = ? AND user_id = ?", agent.ID, agent.UserID).Count(&count).Error; err != nil { 35 | return err 36 | } 37 | if count == 0 { 38 | return errors.New("agent not found or no permission") 39 | } 40 | 41 | // Only update specific fields 42 | return d.db.WithContext(ctx).Model(agent). 43 | Select("Name", "Description", "AgentSchema", "UpdatedAt"). 44 | Updates(agent).Error 45 | } 46 | 47 | func (d *agentDao) Delete(ctx context.Context, userID uint, agentID string) error { 48 | result := d.db.WithContext(ctx).Where("id = ? AND user_id = ?", agentID, userID).Delete(&model.Agent{}) 49 | if result.RowsAffected == 0 { 50 | return errors.New("agent not found or no permission") 51 | } 52 | return result.Error 53 | } 54 | 55 | func (d *agentDao) GetByID(ctx context.Context, userID uint, agentID string) (*model.Agent, error) { 56 | var agent model.Agent 57 | err := d.db.WithContext(ctx).Where("id = ? AND user_id = ?", agentID, userID).First(&agent).Error 58 | if err != nil { 59 | if errors.Is(err, gorm.ErrRecordNotFound) { 60 | return nil, errors.New("agent not found or no permission") 61 | } 62 | return nil, err 63 | } 64 | return &agent, nil 65 | } 66 | 67 | func (d *agentDao) List(ctx context.Context, userID uint) ([]*model.Agent, error) { 68 | var agents []*model.Agent 69 | if err := d.db.WithContext(ctx).Where("user_id = ?", userID).Find(&agents).Error; err != nil { 70 | return nil, err 71 | } 72 | return agents, nil 73 | } 74 | 75 | func (d *agentDao) Page(ctx context.Context, userID uint, page, size int) ([]*model.Agent, int64, error) { 76 | var agents []*model.Agent 77 | var count int64 78 | 79 | db := d.db.WithContext(ctx).Model(&model.Agent{}).Where("user_id = ?", userID).Order("updated_at desc") 80 | 81 | err := db.Count(&count).Error 82 | if err != nil { 83 | return nil, 0, err 84 | } 85 | 86 | err = db.Offset((page - 1) * size).Limit(size).Find(&agents).Error 87 | return agents, count, 
err 88 | } 89 | -------------------------------------------------------------------------------- /internal/dao/file_dao.go: -------------------------------------------------------------------------------- 1 | package dao 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "errors" 6 | "fmt" 7 | "gorm.io/gorm" 8 | "strings" 9 | ) 10 | 11 | // FileDao 定义了文件操作的接口 12 | type FileDao interface { 13 | CreateFile(file *model.File) error 14 | GetFilesByParentID(userID uint, parentID *string) ([]model.File, error) 15 | GetFileMetaByFileID(id string) (*model.File, error) 16 | DeleteFile(id string) error 17 | ListFiles(userID uint, parentID *string, page int, pageSize int, sort string) ([]model.File, error) 18 | CountFilesByParentID(parentID *string, userID uint) (int64, error) 19 | UpdateFile(file *model.File) error 20 | CountFilesByKeyword(key string, userID uint) (int64, error) 21 | GetFilesByKeyword(userID uint, key string, page int, pageSize int, sort string) ([]model.File, error) 22 | GetDocumentDir(userID uint) (*model.File, error) 23 | } 24 | 25 | // fileDao 实现了FileDao接口,提供文件相关操作 26 | type fileDao struct { 27 | db *gorm.DB 28 | } 29 | 30 | // NewFileDao 创建并返回一个新的FileDao实例 31 | func NewFileDao(db *gorm.DB) FileDao { 32 | return &fileDao{db: db} 33 | } 34 | 35 | // CreateFile 创建一个新的文件记录 36 | func (fd *fileDao) CreateFile(file *model.File) error { 37 | if fd.db == nil { 38 | return errors.New("数据库未初始化") 39 | } 40 | return fd.db.Create(file).Error 41 | } 42 | 43 | // GetFilesByParentID 根据父ID获取文件列表 44 | func (fd *fileDao) GetFilesByParentID(userID uint, parentID *string) ([]model.File, error) { 45 | var files []model.File 46 | query := fd.db.Where("user_id = ?", userID) 47 | 48 | if parentID == nil { 49 | query = query.Where("parent_id IS NULL") 50 | } else { 51 | query = query.Where("parent_id = ?", *parentID) 52 | } 53 | 54 | if err := query.Find(&files).Error; err != nil { 55 | return nil, err 56 | } 57 | return files, nil 58 | } 59 | 60 | // GetFileMetaByFileID 
根据文件ID获取文件元信息 61 | func (fd *fileDao) GetFileMetaByFileID(id string) (*model.File, error) { 62 | var file model.File 63 | result := fd.db.Where("id = ?", id).First(&file) 64 | if result.Error != nil { 65 | if errors.Is(result.Error, gorm.ErrRecordNotFound) { 66 | return nil, nil 67 | } 68 | return nil, result.Error 69 | } 70 | return &file, nil 71 | } 72 | 73 | // DeleteFile 根据文件ID删除文件记录 74 | func (fd *fileDao) DeleteFile(id string) error { 75 | if err := fd.db.Where("id = ?", id).Delete(&model.File{}).Error; err != nil { 76 | return err 77 | } 78 | return nil 79 | } 80 | 81 | // ListFiles 列出文件列表,根据指定的排序方式和分页参数 82 | func (fd *fileDao) ListFiles(userID uint, parentID *string, page int, pageSize int, sort string) ([]model.File, error) { 83 | var files []model.File 84 | query := fd.db.Model(model.File{}).Where("user_id = ?", userID) 85 | 86 | if parentID == nil { 87 | query = query.Where("parent_id IS NULL") 88 | } else { 89 | query = query.Where("parent_id = ?", *parentID) 90 | } 91 | query = query.Order("is_dir desc") 92 | 93 | sortClauses := strings.Split(sort, ",") 94 | for _, clause := range sortClauses { 95 | parts := strings.Split(clause, ":") 96 | filed, order := parts[0], parts[1] 97 | query = query.Order(fmt.Sprintf("%s %s", filed, order)) 98 | } 99 | //处理分页 100 | offset := (page - 1) * pageSize 101 | query = query.Offset(offset).Limit(pageSize) 102 | 103 | if err := query.Find(&files).Error; err != nil { 104 | return nil, err 105 | } 106 | return files, nil 107 | } 108 | 109 | func (fd *fileDao) GetFilesByKeyword(userID uint, key string, page int, pageSize int, sort string) ([]model.File, error) { 110 | var files []model.File 111 | query := fd.db.Model(&model.File{}).Where("user_id=?", userID).Where("name LIKE ?", "%"+key+"%") 112 | 113 | query = query.Order("is_dir desc") 114 | sortClauses := strings.Split(sort, ",") 115 | for _, clause := range sortClauses { 116 | parts := strings.Split(clause, ":") 117 | filed, order := parts[0], parts[1] 118 | query = 
query.Order(fmt.Sprintf("%s %s", filed, order)) 119 | } 120 | //处理分页 121 | offset := (page - 1) * pageSize 122 | query = query.Offset(offset).Limit(pageSize) 123 | 124 | if err := query.Find(&files).Error; err != nil { 125 | return nil, err 126 | } 127 | return files, nil 128 | } 129 | 130 | // CountFilesByParentID 计算指定父ID下的文件数量 131 | func (fd *fileDao) CountFilesByParentID(parentID *string, userID uint) (int64, error) { 132 | var total int64 133 | query := fd.db.Model(&model.File{}).Where("user_id = ?", userID) 134 | 135 | if parentID == nil { 136 | query = query.Where("parent_id IS NULL") 137 | } else { 138 | query = query.Where("parent_id = ?", parentID) 139 | } 140 | if err := query.Count(&total).Error; err != nil { 141 | return 0, err 142 | } 143 | return total, nil 144 | } 145 | 146 | func (fd *fileDao) CountFilesByKeyword(key string, userID uint) (int64, error) { 147 | var total int64 148 | query := fd.db.Model(&model.File{}). 149 | Where("user_id = ?", userID). 150 | Where("name like ?", "%"+key+"%") 151 | if err := query.Count(&total).Error; err != nil { 152 | return 0, err 153 | } 154 | return total, nil 155 | } 156 | 157 | // UpdateFile 更新文件信息 158 | func (fd *fileDao) UpdateFile(file *model.File) error { 159 | if fd.db == nil { 160 | return errors.New("数据库未初始化") 161 | } 162 | return fd.db.Save(file).Error 163 | } 164 | 165 | func (fd *fileDao) GetDocumentDir(userID uint) (*model.File, error) { 166 | // 初始化结构体 167 | file := &model.File{} 168 | err := fd.db.Where("user_id = ? AND name = ? AND is_dir = ? 
AND parent_id IS NULL", 169 | userID, "知识库文件", true).First(file).Error 170 | if err != nil { 171 | return nil, err // 直接返回错误,包括 gorm.ErrRecordNotFound 172 | } 173 | return file, nil 174 | } 175 | -------------------------------------------------------------------------------- /internal/dao/history/conv_dao.go: -------------------------------------------------------------------------------- 1 | /* 2 | * This file contains modified code from the eino-history project. 3 | * Original source: https://github.com/HildaM/eino-history 4 | * Licensed under the Apache License, Version 2.0. 5 | * Modifications are made by RaspberryCola. 6 | */ 7 | 8 | package history 9 | 10 | import ( 11 | "ai-cloud/internal/model" 12 | "context" 13 | "fmt" 14 | 15 | "gorm.io/gorm" 16 | ) 17 | 18 | type ConvDao interface { 19 | Create(ctx context.Context, conv *model.Conversation) error 20 | Update(ctx context.Context, conv *model.Conversation) error 21 | Delete(ctx context.Context, convID string) error 22 | GetByID(ctx context.Context, convID string) (*model.Conversation, error) 23 | FirstOrCreate(ctx context.Context, conv *model.Conversation) error 24 | Page(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error) 25 | PageByAgent(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error) 26 | Archive(ctx context.Context, convID string) error 27 | UnArchive(ctx context.Context, convID string) error 28 | Pin(ctx context.Context, convID string) error 29 | UnPin(ctx context.Context, convID string) error 30 | GetDB() *gorm.DB 31 | } 32 | 33 | type convDao struct { 34 | db *gorm.DB 35 | } 36 | 37 | // NewConvDao 创建一个ConvDao 38 | func NewConvDao(db *gorm.DB) ConvDao { 39 | return &convDao{db: db} 40 | } 41 | 42 | func (d *convDao) GetDB() *gorm.DB { 43 | return d.db 44 | } 45 | 46 | // Create 创建一个会话 47 | func (d *convDao) Create(ctx context.Context, conv *model.Conversation) error { 48 | err := 
d.db.WithContext(ctx).Create(conv).Error 49 | if err != nil { 50 | return fmt.Errorf("failed to create conversation: %w", err) 51 | } 52 | return nil 53 | } 54 | 55 | // Update 更新一个会话 56 | func (d *convDao) Update(ctx context.Context, conv *model.Conversation) error { 57 | err := d.db.WithContext(ctx).Save(conv).Error 58 | if err != nil { 59 | return fmt.Errorf("failed to update conversation: %w", err) 60 | } 61 | return nil 62 | } 63 | 64 | // Delete 删除一个会话 65 | func (d *convDao) Delete(ctx context.Context, convID string) error { 66 | err := d.db.WithContext(ctx).Delete(&model.Conversation{}, "conv_id = ?", convID).Error 67 | if err != nil { 68 | return fmt.Errorf("failed to delete conversation: %w", err) 69 | } 70 | return nil 71 | } 72 | 73 | // GetByID 根据ID获取一个会话 74 | func (d *convDao) GetByID(ctx context.Context, convID string) (*model.Conversation, error) { 75 | var conv model.Conversation 76 | err := d.db.WithContext(ctx).Where("conv_id = ?", convID).First(&conv).Error 77 | if err != nil { 78 | return nil, fmt.Errorf("failed to get conversation: %w", err) 79 | } 80 | return &conv, nil 81 | } 82 | 83 | // FirstOrCreate 根据ID获取一个会话,如果会话不存在则创建一个 84 | func (d *convDao) FirstOrCreate(ctx context.Context, conv *model.Conversation) error { 85 | err := d.db.WithContext(ctx).Where("conv_id = ?", conv.ConvID).FirstOrCreate(&conv).Error 86 | if err != nil { 87 | return fmt.Errorf("failed to get conversation: %w", err) 88 | } 89 | return nil 90 | } 91 | 92 | // Page 分页获取会话 93 | func (d *convDao) Page(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error) { 94 | var convs []*model.Conversation 95 | var total int64 96 | 97 | db := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("user_id = ?", userID).Order("updated_at DESC") // 按照更新时间降序排序 98 | 99 | err := db.Count(&total).Error 100 | if err != nil { 101 | return nil, 0, fmt.Errorf("failed to count conversations: %w", err) 102 | } 103 | err = db.Offset((page - 1) * 
size).Limit(size).Find(&convs).Error 104 | return convs, total, err 105 | } 106 | 107 | // PageByAgent 按Agent分页获取会话 108 | func (d *convDao) PageByAgent(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error) { 109 | var convs []*model.Conversation 110 | var total int64 111 | 112 | db := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("user_id = ? AND agent_id = ?", userID, agentID).Order("updated_at DESC") // 按照更新时间降序排序 113 | 114 | err := db.Count(&total).Error 115 | if err != nil { 116 | return nil, 0, fmt.Errorf("failed to count conversations: %w", err) 117 | } 118 | err = db.Offset((page - 1) * size).Limit(size).Find(&convs).Error 119 | return convs, total, err 120 | } 121 | 122 | // Archive 归档一个会话 123 | func (d *convDao) Archive(ctx context.Context, convID string) error { 124 | err := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("conv_id = ?", convID).Update("is_archived", true).Error 125 | if err != nil { 126 | return fmt.Errorf("failed to archive conversation: %w", err) 127 | } 128 | return nil 129 | } 130 | 131 | // UnArchive 取消归档一个会话 132 | func (d *convDao) UnArchive(ctx context.Context, convID string) error { 133 | err := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("conv_id = ?", convID).Update("is_archived", false).Error 134 | if err != nil { 135 | return fmt.Errorf("failed to unarchive conversation: %w", err) 136 | } 137 | return nil 138 | } 139 | 140 | // Pin 置顶一个会话 141 | func (d *convDao) Pin(ctx context.Context, convID string) error { 142 | err := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("conv_id = ?", convID).Update("is_pinned", true).Error 143 | if err != nil { 144 | return fmt.Errorf("failed to pin conversation: %w", err) 145 | } 146 | return nil 147 | } 148 | 149 | // UnPin 取消置顶一个会话 150 | func (d *convDao) UnPin(ctx context.Context, convID string) error { 151 | err := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("conv_id = ?", 
convID).Update("is_pinned", false).Error 152 | if err != nil { 153 | return fmt.Errorf("failed to unpin conversation: %w", err) 154 | } 155 | return nil 156 | } 157 | -------------------------------------------------------------------------------- /internal/dao/history/msg_dao.go: -------------------------------------------------------------------------------- 1 | package history 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "context" 6 | "fmt" 7 | "github.com/google/uuid" 8 | "gorm.io/gorm" 9 | ) 10 | 11 | type MsgDao interface { 12 | GetDB() *gorm.DB 13 | Create(ctx context.Context, msg *model.Message) error 14 | Update(ctx context.Context, msg *model.Message) error 15 | Delete(ctx context.Context, msgID string) error 16 | GetByID(ctx context.Context, msgID string) (*model.Message, error) 17 | ListByConvID(ctx context.Context, convID string) ([]*model.Message, error) 18 | List(ctx context.Context, convID string, offset, limit int) ([]*model.Message, int64, error) 19 | UpdateStatus(ctx context.Context, msgID, status string) error 20 | UpdateTokenCount(ctx context.Context, msgID string, tokenCount int) 21 | SetContextEdge(ctx context.Context, msgID string, isContextEdge bool) error 22 | SetVariant(ctx context.Context, msgID string, isVariant bool) error 23 | } 24 | 25 | type msgDao struct { 26 | db *gorm.DB 27 | } 28 | 29 | func NewMsgDao(db *gorm.DB) MsgDao { 30 | return &msgDao{db: db} 31 | } 32 | func (d *msgDao) GetDB() *gorm.DB { 33 | return d.db 34 | } 35 | 36 | func (d *msgDao) Create(ctx context.Context, msg *model.Message) error { 37 | if len(msg.MsgID) == 0 { 38 | msg.MsgID = uuid.NewString() 39 | } 40 | err := d.db.WithContext(ctx).Create(msg).Error 41 | if err != nil { 42 | return fmt.Errorf("failed to create message: %w", err) 43 | } 44 | return nil 45 | } 46 | 47 | func (d *msgDao) Update(ctx context.Context, msg *model.Message) error { 48 | if err := d.db.WithContext(ctx).Save(msg).Error; err != nil { 49 | return fmt.Errorf("failed to update 
message: %w", err) 50 | } 51 | return nil 52 | } 53 | func (d *msgDao) Delete(ctx context.Context, msgID string) error { 54 | if err := d.db.WithContext(ctx).Delete(&model.Message{}, "msg_id = ?", msgID).Error; err != nil { 55 | return fmt.Errorf("failed to delete message: %w", err) 56 | } 57 | return nil 58 | } 59 | 60 | // GetByID 据ID获取消息 61 | func (d *msgDao) GetByID(ctx context.Context, msgID string) (*model.Message, error) { 62 | var msg model.Message 63 | if err := d.db.WithContext(ctx).Where("msg_id = ?", msgID).First(&msg).Error; err != nil { 64 | return nil, fmt.Errorf("failed to get message: %w", err) 65 | } 66 | return &msg, nil 67 | } 68 | 69 | // ListByConvID 根据会话ID获取消息 70 | func (d *msgDao) ListByConvID(ctx context.Context, convID string) ([]*model.Message, error) { 71 | var msgs []*model.Message 72 | if err := d.db.WithContext(ctx).Where("conv_id = ?", convID). 73 | Order("order_seq ASC"). 74 | Find(&msgs).Error; err != nil { 75 | return nil, fmt.Errorf("failed to list messages: %w", err) 76 | } 77 | return msgs, nil 78 | } 79 | 80 | // Page 获取历史消息 81 | func (d *msgDao) List(ctx context.Context, convID string, offset, limit int) ([]*model.Message, int64, error) { 82 | var msgs []*model.Message 83 | var total int64 84 | db := d.db.WithContext(ctx).Model(&model.Message{}).Where("conv_id = ?", convID). 
85 | Order("order_seq ASC") 86 | err := db.Count(&total).Error 87 | if err != nil { 88 | return nil, 0, fmt.Errorf("failed to count messages: %w", err) 89 | } 90 | err = db.Offset(offset).Limit(limit).Find(&msgs).Error 91 | if err != nil { 92 | return nil, 0, fmt.Errorf("failed to list messages: %w", err) 93 | } 94 | return msgs, total, err 95 | } 96 | 97 | // UpdateStatus 更新消息状态 98 | func (d *msgDao) UpdateStatus(ctx context.Context, msgID, status string) error { 99 | err := d.db.WithContext(ctx).Model(&model.Message{}).Where("msg_id = ?", msgID).Update("status", status).Error 100 | if err != nil { 101 | return fmt.Errorf("failed to update message status: %w", err) 102 | } 103 | return nil 104 | } 105 | 106 | // UpdateTokenCount 更新消息的token数量 107 | func (d *msgDao) UpdateTokenCount(ctx context.Context, msgID string, tokenCount int) { 108 | err := d.db.WithContext(ctx).Model(&model.Message{}).Where("msg_id = ?", msgID).Update("token_count", tokenCount).Error 109 | if err != nil { 110 | fmt.Printf("failed to update message token count: %v", err) 111 | } 112 | return 113 | } 114 | 115 | // SetContextEdge 设置消息为上下文边界 116 | func (d *msgDao) SetContextEdge(ctx context.Context, msgID string, isContextEdge bool) error { 117 | err := d.db.WithContext(ctx).Model(&model.Message{}).Where("msg_id = ?", msgID).Update("is_context_edge", isContextEdge).Error 118 | if err != nil { 119 | return fmt.Errorf("failed to update message is_context_edge: %w", err) 120 | } 121 | return nil 122 | } 123 | 124 | // SetVariant 设置消息为变体消息 125 | func (d *msgDao) SetVariant(ctx context.Context, msgID string, isVariant bool) error { 126 | err := d.db.WithContext(ctx).Model(&model.Message{}).Where("msg_id = ?", msgID).Update("is_variant", isVariant).Error 127 | if err != nil { 128 | return fmt.Errorf("failed to update message is_variant: %w", err) 129 | } 130 | return nil 131 | } 132 | -------------------------------------------------------------------------------- /internal/dao/kb_dao.go: 
-------------------------------------------------------------------------------- 1 | package dao 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "fmt" 6 | "gorm.io/gorm" 7 | ) 8 | 9 | type KnowledgeBaseDao interface { 10 | GetDB() *gorm.DB 11 | // 知识库相关 12 | CreateKB(kb *model.KnowledgeBase) error // 创建知识库 13 | DeleteKB(id string) error // 删除知识库 14 | CountKBs(userID uint) (int64, error) // 统计知识库数量 15 | ListKBs(userID uint, page int, pageSize int) ([]model.KnowledgeBase, error) // 获取知识库列表 16 | GetKBByID(kb_id string) (*model.KnowledgeBase, error) // 获取知识库 17 | 18 | // 文档相关 19 | CreateDocument(doc *model.Document) error // 创建文档 20 | UpdateDocument(doc *model.Document) error // 更新文档 21 | CountDocs(id string) (int64, error) // 统计文档数量 22 | ListDocs(id string, page int, size int) ([]model.Document, error) // 获取文档列表 23 | GetAllDocsByKBID(kbID string) ([]model.Document, error) // 获取知识库下所有文档 24 | DeleteDocsByKBID(kbID string) error // 删除知识库下所有文档 25 | BatchDeleteDocs(userID uint, docIDs []string) error // 批量删除文档 26 | } 27 | 28 | type kbDao struct { 29 | db *gorm.DB 30 | } 31 | 32 | func NewKnowledgeBaseDao(db *gorm.DB) KnowledgeBaseDao { return &kbDao{db: db} } 33 | 34 | func (kd *kbDao) GetDB() *gorm.DB { 35 | return kd.db 36 | } 37 | 38 | func (kd *kbDao) CreateKB(kb *model.KnowledgeBase) error { 39 | result := kd.db.Create(kb) 40 | if result.Error != nil { 41 | return result.Error 42 | } 43 | return nil 44 | } 45 | 46 | func (kd *kbDao) GetKBByID(kb_id string) (*model.KnowledgeBase, error) { 47 | kb := &model.KnowledgeBase{} 48 | if err := kd.db.Where("id = ?", kb_id).First(kb).Error; err != nil { 49 | return nil, err 50 | } 51 | return kb, nil 52 | } 53 | 54 | func (kd *kbDao) CountKBs(userID uint) (int64, error) { 55 | var total int64 56 | query := kd.db.Model(&model.KnowledgeBase{}).Where("user_id = ?", userID) 57 | 58 | if err := query.Count(&total).Error; err != nil { 59 | return 0, err 60 | } 61 | return total, nil 62 | } 63 | func (kd *kbDao) ListKBs(userID 
uint, page int, pageSize int) ([]model.KnowledgeBase, error) { 64 | var kbs []model.KnowledgeBase 65 | query := kd.db.Where("user_id = ?", userID).Order("created_at desc") 66 | 67 | offset := (page - 1) * pageSize 68 | query = query.Offset(offset).Limit(pageSize) 69 | 70 | if err := query.Find(&kbs).Error; err != nil { 71 | return nil, err 72 | } 73 | return kbs, nil 74 | } 75 | func (kd *kbDao) CountDocs(kbID string) (int64, error) { 76 | var total int64 77 | query := kd.db.Model(&model.Document{}).Where("knowledge_base_id = ?", kbID) 78 | if err := query.Count(&total).Error; err != nil { 79 | return 0, err 80 | } 81 | return total, nil 82 | } 83 | 84 | func (kd *kbDao) ListDocs(kbID string, page int, size int) ([]model.Document, error) { 85 | var docs []model.Document 86 | query := kd.db.Where("knowledge_base_id = ?", kbID).Order("created_at asc") 87 | 88 | offset := (page - 1) * size 89 | query = query.Offset(offset).Limit(size) 90 | if err := query.Find(&docs).Error; err != nil { 91 | return nil, err 92 | } 93 | return docs, nil 94 | } 95 | func (kd *kbDao) DeleteKB(id string) error { 96 | return kd.db.Where("id = ?", id).Delete(&model.KnowledgeBase{}).Error 97 | } 98 | 99 | func (kd *kbDao) CreateDocument(doc *model.Document) error { 100 | return kd.db.Create(doc).Error 101 | } 102 | 103 | func (kd *kbDao) UpdateDocument(doc *model.Document) error { 104 | if err := kd.db.Save(doc).Error; err != nil { 105 | return fmt.Errorf("更新文档失败: %w", err) 106 | } 107 | return nil 108 | } 109 | 110 | func (kd *kbDao) GetAllDocsByKBID(kbID string) ([]model.Document, error) { 111 | var docs []model.Document 112 | if err := kd.db.Where("knowledge_base_id = ?", kbID).Find(&docs).Error; err != nil { 113 | return nil, fmt.Errorf("获取文档失败: %w", err) 114 | } 115 | return docs, nil 116 | } 117 | 118 | func (kd *kbDao) DeleteDocsByKBID(kbID string) error { 119 | if err := kd.db.Where("knowledge_base_id = ?", kbID).Delete(&model.Document{}).Error; err != nil { 120 | return 
fmt.Errorf("删除文档失败: %w", err) 121 | } 122 | return nil 123 | } 124 | 125 | func (kd *kbDao) BatchDeleteDocs(userID uint, docIDs []string) error { 126 | if len(docIDs) == 0 { 127 | return nil 128 | } 129 | res := kd.db.Where("id IN (?) AND user_id = ?", docIDs, userID).Delete(&model.Document{}) 130 | if res.Error != nil { 131 | return fmt.Errorf("db删除错误:%w", res.Error) 132 | } 133 | 134 | if res.RowsAffected != int64(len(docIDs)) { 135 | return fmt.Errorf("expected to delete %d records, but deleted %d", len(docIDs), res.RowsAffected) 136 | } 137 | return nil 138 | } 139 | -------------------------------------------------------------------------------- /internal/dao/model_dao.go: -------------------------------------------------------------------------------- 1 | package dao 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "context" 6 | "errors" 7 | "gorm.io/gorm" 8 | ) 9 | 10 | type ModelDao interface { 11 | Create(ctx context.Context, m *model.Model) error 12 | Update(ctx context.Context, m *model.Model) error 13 | Delete(ctx context.Context, userID uint, modelID string) error 14 | GetByID(ctx context.Context, userID uint, modelID string) (*model.Model, error) 15 | List(ctx context.Context, userID uint, modelType string) ([]*model.Model, error) 16 | Page(ctx context.Context, userID uint, modelType string, page, size int) ([]*model.Model, int64, error) 17 | } 18 | 19 | type modelDao struct { 20 | db *gorm.DB 21 | } 22 | 23 | func NewModelDao(db *gorm.DB) ModelDao { 24 | return &modelDao{db: db} 25 | } 26 | 27 | func (d *modelDao) Create(ctx context.Context, m *model.Model) error { 28 | return d.db.WithContext(ctx).Create(m).Error 29 | } 30 | 31 | func (d *modelDao) Update(ctx context.Context, m *model.Model) error { 32 | // 检查模型是否属于该用户 33 | var count int64 34 | if err := d.db.WithContext(ctx).Model(&model.Model{}).Where("id = ? 
AND user_id = ?", m.ID, m.UserID).Count(&count).Error; err != nil { 35 | return err 36 | } 37 | if count == 0 { 38 | return errors.New("模型不存在或无权限") 39 | } 40 | 41 | // 只更新允许修改的字段,排除CreatedAt 42 | return d.db.WithContext(ctx).Model(m). 43 | Select( 44 | "ShowName", "Server", "BaseURL", "ModelName", "APIKey", 45 | "Dimension", "MaxOutputLength", "Function", "MaxTokens", 46 | ). 47 | Updates(m).Error 48 | } 49 | 50 | func (d *modelDao) Delete(ctx context.Context, userID uint, id string) error { 51 | result := d.db.WithContext(ctx).Where("id = ? AND user_id = ?", id, userID).Delete(&model.Model{}) 52 | if result.RowsAffected == 0 { 53 | return errors.New("模型不存在或无权限") 54 | } 55 | return result.Error 56 | } 57 | 58 | func (d *modelDao) GetByID(ctx context.Context, userID uint, id string) (*model.Model, error) { 59 | var m model.Model 60 | err := d.db.WithContext(ctx).Where("id = ? AND user_id = ?", id, userID).First(&m).Error 61 | if err != nil { 62 | if errors.Is(err, gorm.ErrRecordNotFound) { 63 | return nil, errors.New("模型不存在或无权限") 64 | } 65 | return nil, err 66 | } 67 | return &m, nil 68 | } 69 | 70 | func (d *modelDao) Page(ctx context.Context, userID uint, modelType string, page, size int) ([]*model.Model, int64, error) { 71 | var models []*model.Model 72 | var count int64 73 | 74 | db := d.db.WithContext(ctx).Model(&model.Model{}).Where("user_id = ?", userID) 75 | if modelType != "" { 76 | db = db.Where("type = ?", modelType) 77 | } 78 | 79 | err := db.Count(&count).Offset((page - 1) * size).Limit(size).Find(&models).Error 80 | return models, count, err 81 | } 82 | 83 | func (d *modelDao) List(ctx context.Context, userID uint, modelType string) ([]*model.Model, error) { 84 | var models []*model.Model 85 | db := d.db.WithContext(ctx).Where("user_id = ?", userID) 86 | if modelType != "" { 87 | db = db.Where("type = ?", modelType) 88 | } 89 | if err := db.Find(&models).Error; err != nil { 90 | return nil, err 91 | } 92 | return models, nil 93 | } 94 | 
-------------------------------------------------------------------------------- /internal/dao/user_dao.go: -------------------------------------------------------------------------------- 1 | package dao 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | 6 | "gorm.io/gorm" 7 | ) 8 | 9 | type UserDao interface { 10 | CheckFieldExists(field string, value interface{}) (bool, error) 11 | CreateUser(user *model.User) error 12 | GetUserByName(name string) (user *model.User, err error) 13 | } 14 | 15 | type userDao struct { 16 | db *gorm.DB 17 | } 18 | 19 | func NewUserDao(db *gorm.DB) UserDao { 20 | return &userDao{db: db} 21 | } 22 | 23 | // CheckFieldExists 检查字段是否存在 24 | func (ud *userDao) CheckFieldExists(field string, value interface{}) (bool, error) { 25 | var count int64 26 | if err := ud.db.Model(&model.User{}).Where(field+" = ?", value).Count(&count).Error; err != nil { 27 | return false, err 28 | } 29 | return count > 0, nil 30 | } 31 | 32 | func (ud *userDao) CreateUser(user *model.User) error { 33 | result := ud.db.Create(user) 34 | 35 | if result.Error != nil { 36 | return result.Error 37 | } 38 | return nil 39 | } 40 | 41 | func (ud *userDao) GetUserByName(name string) (*model.User, error) { 42 | var user model.User 43 | result := ud.db.Model(&model.User{}).Where("username = ?", name).First(&user) 44 | if result.Error != nil { 45 | return nil, result.Error 46 | } 47 | return &user, nil 48 | } 49 | -------------------------------------------------------------------------------- /internal/database/milvus.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "context" 6 | "github.com/milvus-io/milvus-sdk-go/v2/client" 7 | "sync" 8 | ) 9 | 10 | var ( 11 | once sync.Once 12 | instance client.Client 13 | err error 14 | ) 15 | 16 | // InitMilvus 初始化 17 | func InitMilvus(ctx context.Context) (client.Client, error) { 18 | 19 | once.Do(func() { 20 | instance, err = 
client.NewClient(ctx, client.Config{ 21 | Address: config.GetConfig().Milvus.Address, 22 | }) 23 | 24 | }) 25 | return instance, err 26 | } 27 | 28 | func GetMilvusClient() client.Client { 29 | return instance 30 | } 31 | -------------------------------------------------------------------------------- /internal/database/mysql.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "ai-cloud/internal/model" 6 | "fmt" 7 | "gorm.io/driver/mysql" 8 | "gorm.io/gorm" 9 | "sync" 10 | ) 11 | 12 | var ( 13 | db *gorm.DB 14 | dbOnce sync.Once 15 | dbErr error 16 | ) 17 | 18 | // GetDB 获取数据库单例 19 | func GetDB() (*gorm.DB, error) { 20 | dbOnce.Do(func() { 21 | // 构造 DSN 22 | dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", 23 | config.AppConfigInstance.Database.User, 24 | config.AppConfigInstance.Database.Password, 25 | config.AppConfigInstance.Database.Host, 26 | config.AppConfigInstance.Database.Port, 27 | config.AppConfigInstance.Database.Name, 28 | ) 29 | 30 | // 连接数据库 31 | db, dbErr = gorm.Open(mysql.Open(dsn), &gorm.Config{}) 32 | if dbErr != nil { 33 | return 34 | } 35 | 36 | // 自动迁移, 创建表结构 37 | if err := db.AutoMigrate( 38 | &model.User{}, 39 | &model.File{}, 40 | &model.KnowledgeBase{}, 41 | &model.Document{}, 42 | &model.Model{}, 43 | &model.Agent{}, 44 | // 会话记录相关 45 | &model.Conversation{}, 46 | &model.Message{}, 47 | &model.Attachment{}, 48 | &model.MessageAttachment{}, 49 | ); err != nil { 50 | dbErr = err 51 | return 52 | } 53 | }) 54 | 55 | return db, dbErr 56 | } 57 | -------------------------------------------------------------------------------- /internal/middleware/auth.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "ai-cloud/pkgs/errcode" 5 | "errors" 6 | "net/http" 7 | "strings" 8 | 9 | "github.com/gin-gonic/gin" 10 | "github.com/golang-jwt/jwt/v5" 
11 | ) 12 | 13 | const ( 14 | AuthHeaderKey = "Authorization" 15 | AuthBearerType = "Bearer" 16 | ) 17 | 18 | func JWTAuth() gin.HandlerFunc { 19 | return func(c *gin.Context) { 20 | // 获取 Authorization 头 21 | authHeader := c.GetHeader(AuthHeaderKey) 22 | if authHeader == "" { 23 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ 24 | "code": errcode.TokenMissing, 25 | "message": "需要认证令牌", 26 | }) 27 | return 28 | } 29 | 30 | // 分割 Bearer 和令牌 31 | parts := strings.SplitN(authHeader, " ", 2) 32 | if len(parts) != 2 || parts[0] != AuthBearerType { 33 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ 34 | "code": errcode.TokenInvalid, 35 | "message": "令牌格式错误", 36 | }) 37 | return 38 | } 39 | 40 | // 解析验证令牌 41 | tokenString := parts[1] 42 | claims, err := ParseToken(tokenString) 43 | if err != nil { 44 | status := http.StatusUnauthorized 45 | code := errcode.TokenInvalid 46 | message := "无效令牌" 47 | 48 | // 新版错误处理 49 | switch { 50 | case errors.Is(err, jwt.ErrTokenExpired): 51 | message = "令牌已过期" 52 | code = errcode.TokenExpired 53 | status = http.StatusForbidden 54 | case errors.Is(err, jwt.ErrTokenMalformed): 55 | message = "令牌格式错误" 56 | case errors.Is(err, jwt.ErrSignatureInvalid): 57 | message = "签名验证失败" 58 | } 59 | 60 | c.AbortWithStatusJSON(status, gin.H{ 61 | "code": code, 62 | "message": message, 63 | }) 64 | return 65 | } 66 | 67 | // 将用户ID存入上下文 68 | c.Set("user_id", claims.UserID) 69 | c.Next() 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /internal/middleware/cors.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "github.com/gin-contrib/cors" 6 | "github.com/gin-gonic/gin" 7 | "time" 8 | ) 9 | 10 | // SetupCORS 封装CORS配置 11 | func SetupCORS() gin.HandlerFunc { 12 | corsConfig := config.GetConfig().CORS 13 | 14 | maxAge, err := time.ParseDuration(corsConfig.MaxAge) 15 | if err != nil { 16 | maxAge 
= 12 * time.Hour 17 | } 18 | 19 | return cors.New(cors.Config{ 20 | AllowOrigins: corsConfig.AllowOrigins, // 允许所有域名 21 | AllowMethods: corsConfig.AllowMethods, // 允许的HTTP方法 22 | AllowHeaders: corsConfig.AllowHeaders, // 允许的请求头 23 | ExposeHeaders: corsConfig.ExposeHeaders, // 暴露的响应头 24 | AllowCredentials: corsConfig.AllowCredentials, // 允许携带凭证(如Cookie) 25 | MaxAge: maxAge, // 预检请求缓存时间 26 | }) 27 | } 28 | -------------------------------------------------------------------------------- /internal/middleware/jwt.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "time" 6 | 7 | "github.com/golang-jwt/jwt/v5" 8 | ) 9 | 10 | // Claims is the JWT payload issued by this service. 11 | // UserID identifies the authenticated user; the embedded RegisteredClaims 12 | // carries the standard claims (exp, iat, nbf, ...). 13 | type Claims struct { 14 | UserID uint `json:"user_id"` 15 | jwt.RegisteredClaims 16 | } 17 | 18 | // GenerateToken returns an HS256 JWT for userID signed with the configured 19 | // secret. iat and nbf are the issuing instant, and exp is that instant plus 20 | // the configured ExpirationHours. 21 | func GenerateToken(userID uint) (string, error) { 22 | cfg := config.AppConfigInstance.JWT 23 | now := time.Now() // single instant so exp, iat and nbf agree 24 | 25 | claims := &Claims{ 26 | UserID: userID, 27 | RegisteredClaims: jwt.RegisteredClaims{ 28 | ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(cfg.ExpirationHours) * time.Hour)), 29 | IssuedAt: jwt.NewNumericDate(now), 30 | NotBefore: jwt.NewNumericDate(now), 31 | }, 32 | } 33 | 34 | token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 35 | return token.SignedString([]byte(cfg.Secret)) 36 | } 37 | 38 | // ParseToken verifies tokenString and returns its claims. Only HS256 tokens 39 | // signed with the configured secret are accepted; the jwt library's standard 40 | // time-based claim checks (exp, nbf) also apply. 41 | func ParseToken(tokenString string) (*Claims, error) { 42 | cfg := config.AppConfigInstance.JWT 43 | 44 | // The key func always hands back the HMAC secret, so pin the algorithm 45 | // instead of trusting the "alg" declared in the token header. 46 | token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { 47 | return []byte(cfg.Secret), nil 48 | }, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) 49 | 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | if claims, ok := token.Claims.(*Claims); ok && token.Valid { 55 | return claims, nil 56 | } 57 | 58 | return nil, jwt.ErrInvalidKey 59 | } 60 |
-------------------------------------------------------------------------------- /internal/model/agent.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/cloudwego/eino/schema" 7 | ) 8 | 9 | // Agent is a user-owned agent. Its AgentSchema field stores the configuration 10 | // as a JSON-encoded AgentSchema value. 11 | type Agent struct { 12 | ID string `gorm:"primaryKey;type:char(36)"` 13 | UserID uint `gorm:"index"` 14 | Name string `gorm:"not null"` 15 | Description string `gorm:"type:text"` 16 | AgentSchema string `gorm:"type:json"` 17 | CreatedAt time.Time `gorm:"autoCreateTime"` 18 | UpdatedAt time.Time `gorm:"autoUpdateTime"` 19 | } 20 | 21 | // AgentSchema is the agent configuration: model, MCP servers, tools, prompt and knowledge bases. 22 | type AgentSchema struct { 23 | LLMConfig LLMConfig `json:"llm_config"` 24 | MCP MCPConfig `json:"mcp"` 25 | Tools ToolsConfig `json:"tools"` 26 | Prompt string `json:"prompt"` 27 | Knowledge KnowledgeConfig `json:"knowledge"` 28 | } 29 | 30 | // LLMConfig selects the LLM model the agent uses and its generation options. 31 | // TODO: extend this to cover more model options. 32 | type LLMConfig struct { 33 | ModelID string `json:"model_id"` 34 | Temperature float64 `json:"temperature"` 35 | TopP float64 `json:"top_p"` 36 | MaxOutputLength int `json:"max_output_length"` 37 | Thinking bool `json:"thinking"` 38 | } 39 | 40 | // MCPConfig lists the URLs of the MCP SSE servers the agent loads tools from. 41 | type MCPConfig struct { 42 | Servers []string `json:"servers"` 43 | } 44 | 45 | // ToolsConfig lists the tool IDs attached to the agent. The Tools module is 46 | // not implemented yet because MCP servers currently cover tool loading. 47 | type ToolsConfig struct { 48 | ToolIDs []string `json:"tool_ids"` 49 | } 50 | 51 | // KnowledgeConfig lists the knowledge bases the agent retrieves from and the 52 | // number of results (TopK) to request. 53 | type KnowledgeConfig struct { 54 | KnowledgeIDs []string `json:"knowledge_ids"` 55 | TopK int `json:"top_k"` 56 | } 57 | 58 | // CreateAgentRequest is the request body for creating an agent. 59 | type CreateAgentRequest struct { 60 | Name string `json:"name" binding:"required"` 61 | Description string `json:"description"` 62 | } 63 | 64 | // UpdateAgentRequest is the request body for updating an agent's metadata and configuration. 65 | type UpdateAgentRequest struct { 66 | ID string `json:"id" binding:"required"` 67 | Name string `json:"name"` 68 | Description string `json:"description"` 69 | LLMConfig LLMConfig `json:"llm_config"` 70 | MCP MCPConfig `json:"mcp"` 71 | Tools ToolsConfig `json:"tools"` 72 | Prompt string `json:"prompt"` 73 | Knowledge KnowledgeConfig `json:"knowledge"` 74 | } 75 | 76 | // PageAgentRequest holds paging query parameters (defaults: page 1, size 10). 77 | type PageAgentRequest struct { 78 | Page int `form:"page,default=1"` 79 | Size int `form:"size,default=10"` 80 | } 81 | 82 | // UserMessage is the input of an agent run: the current query plus prior history. 83 | type UserMessage struct { 84 | Query string `json:"query" binding:"required"` 85 | History []*schema.Message `json:"history"` 86 | } 87 | 88 | // ExecuteAgentRequest is the request body for running an agent. 89 | type ExecuteAgentRequest struct { 90 | ID string `json:"id" binding:"required"` 91 | AgentID string `json:"agent_id" binding:"required"` 92 | Message UserMessage `json:"message" binding:"required"` 93 | } 94 |
-------------------------------------------------------------------------------- /internal/model/chat.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "github.com/cloudwego/eino/schema" 4 | 5 | // ChatResponse is a non-streaming chat answer plus the reference documents returned with it. 6 | type ChatResponse struct { 7 | Response string `json:"response"` 8 | References []*schema.Document `json:"references"` 9 | } 10 | 11 | // ChatRequest asks Query against the knowledge bases listed in KBs (presumably by ID). 12 | type ChatRequest struct { 13 | Query string `json:"query"` 14 | KBs []string `json:"kbs"` 15 | } 16 | 17 | // ChatStreamResponse is one chunk of an OpenAI-compatible streaming response. 18 | type ChatStreamResponse struct { 19 | ID string `json:"id"` 20 | Object string `json:"object"` 21 | Created int64 `json:"created"` 22 | Model string `json:"model"` 23 | Choices []ChatStreamChoice `json:"choices"` 24 | } 25 | 26 | // ChatStreamChoice is one choice within a ChatStreamResponse chunk. 27 | type ChatStreamChoice struct { 28 | Delta ChatStreamDelta `json:"delta"` 29 | Index int `json:"index"` 30 | FinishReason *string `json:"finish_reason"` 31 | } 32 | 33 | // ChatStreamDelta is the incremental role/content carried by a stream chunk. 34 | type ChatStreamDelta struct { 35 | Role string `json:"role,omitempty"` 36 | Content string `json:"content,omitempty"` 37 | } 38 |
-------------------------------------------------------------------------------- /internal/model/conversation.go: -------------------------------------------------------------------------------- 1 | /* 2 | * This file contains modified code from the eino-history project. 3 | * Original source: https://github.com/HildaM/eino-history 4 | * Licensed under the Apache License, Version 2.0. 5 | * Modifications are made by RaspberryCola. 6 | */ 7 | 8 | package model 9 | 10 | import "encoding/json" 11 | 12 | // Conversation is a row of the conversations table: one chat session between a user and an agent. 13 | type Conversation struct { 14 | ID uint64 `gorm:"primaryKey;column:id"` 15 | ConvID string `gorm:"uniqueIndex;column:conv_id;type:varchar(255)"` 16 | UserID uint `gorm:"index;column:user_id"` 17 | AgentID string `gorm:"index;column:agent_id;type:varchar(255)"` 18 | Title string `gorm:"column:title;type:varchar(255)"` 19 | CreatedAt int64 `gorm:"column:created_at"` // Unix seconds, as set by the conversation service 20 | UpdatedAt int64 `gorm:"column:updated_at"` // Unix seconds, as set by the conversation service 21 | Settings json.RawMessage `gorm:"column:settings;type:json"` 22 | IsArchived bool `gorm:"column:is_archived;default:0"` 23 | IsPinned bool `gorm:"column:is_pinned;default:0"` 24 | } 25 | 26 | // TableName maps Conversation to the "conversations" table. 27 | func (Conversation) TableName() string { 28 | return "conversations" 29 | } 30 | 31 | // Message is a row of the messages table: one message in the conversation identified by ConvID. 32 | type Message struct { 33 | ID uint64 `gorm:"primaryKey;column:id"` 34 | MsgID string `gorm:"uniqueIndex;column:msg_id;type:varchar(255)"` 35 | UserID uint `gorm:"index;column:user_id"` 36 | ConvID string `gorm:"column:conv_id;type:varchar(255)"` 37 | ParentID string `gorm:"column:parent_id;type:varchar(255);default:''"` 38 | Role string `gorm:"column:role;type:enum('user','assistant','system','function')"` // NOTE(review): no 'tool' value; confirm tool-role messages are never persisted 39 | Content string `gorm:"column:content;type:text"` 40 | CreatedAt int64 `gorm:"column:created_at"` 41 | OrderSeq int `gorm:"column:order_seq;default:0"` 42 | TokenCount int `gorm:"column:token_count;default:0"` 43 | Status string `gorm:"column:status;type:enum('sent','pending','error');default:'sent'"` 44 | Metadata json.RawMessage `gorm:"column:metadata;type:json"` 45 | IsContextEdge bool `gorm:"column:is_context_edge;default:0"` 46 | IsVariant bool `gorm:"column:is_variant;default:0"` 47 | } 48 | 49 | // TableName maps Message to the "messages" table. 50 | func (Message) TableName() string { 51 | return "messages" 52 | } 53 | 54 | // Attachment is a row of the attachments table: a file attached to a message. 55 | type Attachment struct { 56 | ID uint64 `gorm:"primaryKey;column:id"` 57 | AttachID string `gorm:"uniqueIndex;column:attach_id;type:varchar(255)"` 58 | UserID uint `gorm:"index;column:user_id"` 59 | MessageID string `gorm:"column:message_id;type:varchar(255)"` 60 | AttachmentType string `gorm:"column:attachment_type;type:enum('file','image','code','audio','video')"` 61 | FileName string `gorm:"column:file_name;type:varchar(255)"` 62 | FileSize int64 `gorm:"column:file_size"` 63 | StorageType string `gorm:"column:storage_type;type:enum('path','blob','cloud')"` 64 | StoragePath string `gorm:"column:storage_path;type:varchar(1024)"` 65 | Thumbnail []byte `gorm:"column:thumbnail;type:mediumblob"` 66 | Vectorized bool `gorm:"column:vectorized;default:0"` 67 | DataSummary string `gorm:"column:data_summary;type:text"` 68 | MimeType string `gorm:"column:mime_type;type:varchar(255)"` 69 | CreatedAt int64 `gorm:"column:created_at"` 70 | } 71 | 72 | // TableName maps Attachment to the "attachments" table. 73 | func (Attachment) TableName() string { 74 | return "attachments" 75 | } 76 | 77 | // MessageAttachment is a row of the message-attachment link table. 78 | type MessageAttachment struct { 79 | ID uint64 `gorm:"primaryKey;column:id"` 80 | MessageID string `gorm:"column:message_id;type:varchar(255)"` 81 | AttachmentID string `gorm:"column:attachment_id;type:varchar(255)"` 82 | } 83 | 84 | // TableName maps MessageAttachment to the "message_attachments" table. 85 | func (MessageAttachment) TableName() string { 86 | return "message_attachments" 87 | } 88 | 89 | // DebugRequest is the body of a debug-mode chat (history is not saved): the agent to run and the user message. 90 | type DebugRequest struct { 91 | AgentID string `json:"agent_id" binding:"required"` 92 | Message string `json:"message" binding:"required"` 93 | } 94 | 95 | // CreateConvRequest is the request body for creating a conversation with an agent. 96 | type CreateConvRequest struct { 97 | AgentID string `json:"agent_id" binding:"required"` 98 | } 99 | 100 | // ConvRequest is the request body for sending a message in a conversation; ConvID identifies the conversation. 101 | type ConvRequest struct { 102 | AgentID string `json:"agent_id" binding:"required"` 103 | Message string `json:"message" binding:"required"` 104 | ConvID string `json:"conv_id"` 105 | } 106 |
-------------------------------------------------------------------------------- /internal/model/file.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "time" 4 | 5 | // File is a stored file or directory entry owned by a user. 6 | type File struct { 7 | ID string `gorm:"primaryKey;type:char(36)"` // UUID 8 | UserID uint `gorm:"index"` // owner user ID 9 | Name string `gorm:"not null"` // file name 10 | Size int64 // file size (presumably bytes) 11 | Hash string `gorm:"index;size:64"` // file hash (SHA-256) 12 | MIMEType string // MIME type 13 | IsDir bool `gorm:"default:false"` // whether this entry is a directory 14 | ParentID *string `gorm:"type:char(36);index"` // parent directory ID 15 | StorageType string `gorm:"default:'local'"` // storage backend: local/oss 16 | StorageKey string // unique storage identifier (path or OSS key) 17 | CreatedAt time.Time `gorm:"autoCreateTime"` // creation time 18 | UpdatedAt time.Time `gorm:"autoUpdateTime"` // last update time 19 | } 20 | 21 | // CreateFolderReq is the request body for creating a folder named Name under ParentID. 22 | type CreateFolderReq struct { 23 | Name string `json:"name"` 24 | ParentID *string `json:"parent_id,omitempty"` 25 | } 26 | 27 | // BatchMoveRequest moves the files listed in FileIDs under TargetParentID. 28 | type BatchMoveRequest struct { 29 | FileIDs []string `json:"files_pid" binding:"required"` 30 | TargetParentID string `json:"target_pid"` 31 | } 32 | 33 | // RenameRequest renames the file FileID to NewName. 34 | type RenameRequest struct { 35 | FileID string `json:"file_id" binding:"required"` 36 | NewName string `json:"new_name" binding:"required"` 37 | } 38 |
-------------------------------------------------------------------------------- /internal/model/knowledge.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // KnowledgeBase is a user's knowledge base, backed by one Milvus collection. 8 | type KnowledgeBase struct { 9 | ID string `gorm:"primaryKey;type:char(36)"` // UUID 10 | Name string `gorm:"not null"` // knowledge base name 11 | Description string // knowledge base description 12 | UserID uint `gorm:"index"` // creator's user ID 13 | EmbedModelID string `gorm:"index"` // ID of the associated embedding model 14 | MilvusCollection string `gorm:"not null"` // name of the corresponding Milvus collection 15 | CreatedAt time.Time `gorm:"autoCreateTime"` 16 | UpdatedAt time.Time `gorm:"autoUpdateTime"` 17 | } 18 | 19 | // Document is a document (file) added to a knowledge base. 20 | type Document struct { 21 | ID string `gorm:"primaryKey;type:char(36)"` // UUID 22 | UserID uint `gorm:"index"` // owning user 23 | KnowledgeBaseID string `gorm:"index"` // ID of the knowledge base it belongs to 24 | FileID string `gorm:"index"` // ID of the associated file 25 | Title string // document title 26 | DocType string // document type (pdf/txt/md) 27 | Status int // processing status (0: pending, 1: processing, 2: done, 3: failed) 28 | CreatedAt time.Time `gorm:"autoCreateTime"` 29 | UpdatedAt time.Time `gorm:"autoUpdateTime"` 30 | } 31 | 32 | // Chunk is a document fragment as stored in Milvus. 33 | type Chunk struct { 34 | ID string `json:"id"` 35 | Content string `json:"content"` // chunk text 36 | KBID string `json:"kb_id"` // knowledge base ID (for KB-level retrieval) 37 | DocumentID string `json:"document_id"` // document ID 38 | DocumentName string `json:"document_name"` // document name 39 | Index int `json:"index"` // ordinal of this chunk within the document 40 | Embeddings []float32 `json:"embeddings"` // chunk embedding vector 41 | Score float32 `json:"score"` // score returned by retrieval 42 | } 43 | 44 | // AddFileRequest adds the existing file FileID to knowledge base KBID. 45 | type AddFileRequest struct { 46 | FileID string `json:"file_id"` 47 | KBID string `json:"kb_id"` 48 | } 49 | 50 | // BatchDeleteDocsReq deletes the documents DocIDs from knowledge base KBID. 51 | type BatchDeleteDocsReq struct { 52 | KBID string `json:"kb_id"` 53 | DocIDs []string `json:"doc_ids" binding:"required"` 54 | } 55 | 56 | // CreateKBRequest is the request body for creating a knowledge base that uses the embedding model EmbedModelID. 57 | type CreateKBRequest struct { 58 | Name string `json:"name" binding:"required"` 59 | Description string `json:"description"` 60 | EmbedModelID string `json:"embed_model_id" binding:"required"` 61 | } 62 | 63 | // RetrieveRequest searches knowledge base KBID for Query; TopK bounds the number of results. 64 | type RetrieveRequest struct { 65 | KBID string `json:"kb_id"` 66 | Query string `json:"query"` 67 | TopK int `json:"top_k"` 68 | } 69 | -------------------------------------------------------------------------------- /internal/model/model.go:
-------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "time" 4 | 5 | // Model is a user-registered model endpoint (LLM or embedding) with its connection settings. 6 | type Model struct { 7 | // Basic information 8 | ID string `gorm:"primaryKey;type:char(36)"` 9 | UserID uint `gorm:"index"` // owner user ID 10 | Type string `gorm:"not null"` // model type: embedding/llm 11 | ShowName string `gorm:"not null"` // display name 12 | Server string `gorm:"not null"` // model provider: openai/ollama 13 | BaseURL string `gorm:"not null"` // API base URL 14 | ModelName string `gorm:"not null"` // model identifier, e.g. deepseek-chat, text-embedding-v3 15 | APIKey string // API key; usually not needed for ollama 16 | 17 | // Embedding-model fields 18 | Dimension int // vector dimension (required for embedding models) 19 | 20 | // LLM fields 21 | MaxOutputLength int `gorm:"default:4096"` 22 | Function bool `gorm:"default:false"` // presumably whether the model supports function calling 23 | 24 | // Common fields 25 | MaxTokens int `gorm:"default:1024"` // maximum input length 26 | CreatedAt time.Time `gorm:"autoCreateTime"` 27 | UpdatedAt time.Time `gorm:"autoUpdateTime"` 28 | } 29 | 30 | // CreateModelRequest is the request body for registering a model. 31 | type CreateModelRequest struct { 32 | // Basic information 33 | Type string `json:"type" binding:"required,oneof=embedding llm"` 34 | ShowName string `json:"name" binding:"required"` 35 | Server string `json:"server" binding:"required"` 36 | BaseURL string `json:"base_url" binding:"required,url"` 37 | ModelName string `json:"model" binding:"required"` 38 | APIKey string `json:"api_key"` 39 | 40 | // Embedding 41 | Dimension int `json:"dimension"` 42 | 43 | // LLM 44 | MaxOutputLength int `json:"max_output_length"` 45 | Function bool `json:"function"` 46 | 47 | // Common fields 48 | MaxTokens int `json:"max_tokens"` 49 | } 50 | 51 | // PageModelRequest holds paging query parameters (defaults: page 1, size 10) 52 | // and an optional Type (presumably a model-type filter). 53 | type PageModelRequest struct { 54 | Type string `form:"type"` 55 | Page int `form:"page,default=1"` 56 | Size int `form:"size,default=10"` 57 | } 58 | 59 | // UpdateModelRequest is the request body for updating the model identified by ID. 60 | type UpdateModelRequest struct { 61 | ID string `json:"id" binding:"required"` 62 | // Basic information 63 | ShowName string `json:"name"` 64 | Server string `json:"server"` 65 | BaseURL string `json:"base_url"` 66 | ModelName string `json:"model"` 67 | APIKey string `json:"api_key"` 68 | 69 | // Embedding 70 | Dimension int `json:"dimension"` 71 | 72 | // LLM 73 | MaxOutputLength int `json:"max_output_length"` 74 | Function bool `json:"function"` 75 | 76 | // Common fields 77 | MaxTokens int `json:"max_tokens"` 78 | } 79 |
-------------------------------------------------------------------------------- /internal/model/user.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // User is a registered account. 8 | type User struct { 9 | ID uint `gorm:"primaryKey"` 10 | Username string `gorm:"uniqueIndex;size:50;not null"` 11 | Phone string `gorm:"uniqueIndex;size:20;not null"` // phone number 12 | Email string `gorm:"uniqueIndex;size:100"` 13 | Password string `gorm:"not null"` 14 | CreatedAt time.Time `gorm:"autoCreateTime"` 15 | UpdatedAt time.Time `gorm:"autoUpdateTime"` 16 | } 17 | 18 | // UserRegisterReq is the registration request body. Its password bounds must 19 | // stay in sync with the login requests below. 20 | type UserRegisterReq struct { 21 | Username string `json:"username" binding:"required,min=3,max=20"` 22 | Password string `json:"password" binding:"required,min=6,max=30"` 23 | Email string `json:"email" binding:"required,email"` 24 | Phone string `json:"phone" binding:"required,e164"` 25 | } 26 | 27 | // LoginResponse is returned on successful login. 28 | type LoginResponse struct { 29 | AccessToken string `json:"access_token"` 30 | ExpiresIn int `json:"expires_in"` 31 | TokenType string `json:"token_type"` 32 | } 33 | 34 | // UserNameLoginReq is the username/password login request body. 35 | // Password bounds mirror UserRegisterReq so that any password accepted at 36 | // registration is also accepted at login. 37 | type UserNameLoginReq struct { 38 | Username string `json:"username" binding:"required,min=3,max=20"` 39 | Password string `json:"password" binding:"required,min=6,max=30"` 40 | } 41 | 42 | // UserPhoneLogin is the phone/password login request body; its Password 43 | // bounds mirror UserRegisterReq. 44 | type UserPhoneLogin struct { 45 | Phone string `json:"phone" binding:"required,e164"` 46 | Password string `json:"password" binding:"required,min=6,max=30"` 47 | } 48 |
-------------------------------------------------------------------------------- /internal/router/router.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "ai-cloud/internal/controller" 5 | "ai-cloud/internal/middleware" 6 | 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | // SetUpRouters registers every HTTP route under /api. The /api/users register 11 | // and login routes are public; every other group is guarded by JWTAuth. 12 | func SetUpRouters(r *gin.Engine, uc *controller.UserController, fc *controller.FileController, kc *controller.KBController, mc *controller.ModelController, ac *controller.AgentController, cc *controller.ConversationController) { 13 | api := r.Group("/api") 14 | { 15 | 16 | publicUser := api.Group("/users") 17 | { 18 | publicUser.POST("/register", uc.Register) 19 | publicUser.POST("/login", uc.Login) 20 | } 21 | 22 | auth := api.Group("files") 23 | auth.Use(middleware.JWTAuth()) 24 | { 25 | auth.POST("/upload", fc.Upload) 26 | auth.GET("/page", fc.PageList) 27 | auth.GET("/download", fc.Download) 28 | auth.DELETE("/delete", fc.Delete) 29 | auth.POST("/folder", fc.CreateFolder) 30 | auth.POST("/move", fc.BatchMove) 31 | auth.GET("/search", fc.Search) 32 | auth.PUT("/rename", fc.Rename) 33 | auth.GET("/path", fc.GetPath) 34 | auth.GET("/idPath", fc.GetIDPath) 35 | } 36 | kb := api.Group("knowledge") 37 | kb.Use(middleware.JWTAuth()) 38 | { 39 | // KB 40 | kb.POST("/create", kc.Create) 41 | kb.DELETE("/delete", kc.Delete) 42 | kb.POST("/add", kc.AddExistFile) 43 | kb.POST("/addNew", kc.AddNewFile) 44 | kb.GET("/page", kc.PageList) 45 | kb.GET("/detail", kc.GetKBDetail) 46 | // Doc 47 | kb.GET("/docPage", kc.DocPage) 48 | kb.POST("/docDelete", kc.DeleteDocs) 49 | // RAG 50 | kb.POST("/retrieve", kc.Retrieve) 51 | kb.POST("/chat", kc.Chat) 52 | kb.POST("/stream", kc.ChatStream) 53 | } 54 | model := api.Group("model") 55 | model.Use(middleware.JWTAuth()) 56 | { 57 | model.POST("/create", mc.CreateModel) 58 | model.PUT("/update", mc.UpdateModel) 59 | model.DELETE("/delete", mc.DeleteModel) 60 | model.GET("/get", mc.GetModel) 61 | model.GET("/page", mc.PageModels) 62 | model.GET("/list", mc.ListModels) 63 | } 64 | agent := api.Group("agent") 65 | agent.Use(middleware.JWTAuth()) 66 | { 67 | agent.POST("/create", ac.CreateAgent) 68 | agent.POST("/update", ac.UpdateAgent) 69 | agent.DELETE("/delete", ac.DeleteAgent) 70 | agent.GET("/get", ac.GetAgent) 71 | agent.GET("/page", ac.PageAgents) 72 | agent.POST("/execute/:id", ac.ExecuteAgent) 73 | agent.POST("/stream", ac.StreamExecuteAgent) 74 | } 75 | conv := api.Group("chat") 76 | conv.Use(middleware.JWTAuth()) 77 | { 78 | // Debug mode: history is not saved. 79 | conv.POST("/debug", cc.DebugStreamAgent) 80 | // Conversation endpoints (history is recorded). 81 | conv.POST("/create", cc.CreateConversation) 82 | conv.POST("/stream", cc.StreamConversation) 83 | conv.GET("/list", cc.ListConversations) 84 | conv.GET("/list/agent", cc.ListAgentConversations) 85 | conv.GET("/history", cc.GetConversationHistory) 86 | conv.DELETE("/delete", cc.DeleteConversation) 87 | } 88 | } 89 | } 90 |
-------------------------------------------------------------------------------- /internal/service/agent_service.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | llmfactory "ai-cloud/internal/component/llm" 5 | mretriever "ai-cloud/internal/component/retriever/milvus" 6 | "ai-cloud/internal/dao" 7 | "ai-cloud/internal/model" 8 | "context" 9 | "encoding/json" 10 | "errors" 11 | "fmt" 12 | "time" 13 | 14 | mcpp "github.com/cloudwego/eino-ext/components/tool/mcp" 15 | "github.com/cloudwego/eino/components/prompt" 16 | "github.com/cloudwego/eino/components/tool" 17 | "github.com/cloudwego/eino/compose" 18 | "github.com/cloudwego/eino/flow/agent/react" 19 | "github.com/cloudwego/eino/schema" 20 | "github.com/mark3labs/mcp-go/client" 21 | "github.com/mark3labs/mcp-go/mcp" 22 | ) 23 | 24 | const ( 25 | InputToQuery = "InputToQuery" 26 | InputToHistory = "InputToHistory" 27 | ChatTemplate = "ChatTemplate" 28 | ChatModel = "ChatModel" 29 | Retriever = "Retriever" 30 | Agent = "Agent" 31 | ) 32 | 33 | type AgentService interface { 34 | CreateAgent(ctx context.Context, agent *model.Agent) error 35 | UpdateAgent(ctx context.Context, agent *model.Agent) error 36 | DeleteAgent(ctx context.Context, userID uint, agentID string) error 37 | GetAgent(ctx context.Context, userID uint, agentID string) (*model.Agent, error) 38 | ListAgents(ctx context.Context, userID uint) ([]*model.Agent, error) 39 | PageAgents(ctx context.Context, userID uint, page, size int) ([]*model.Agent, int64, error) 40 | ExecuteAgent(ctx context.Context, userID uint, agentID string, msg model.UserMessage) (string, error) 41 | StreamExecuteAgent(ctx context.Context, userID uint, agentID string, msg model.UserMessage) (*schema.StreamReader[*schema.Message], error) 42 | } 43 | 44 | type agentService struct { 45 | dao dao.AgentDao 46 | modelSvc ModelService 47 | kbSvc KBService 48 | kbDao dao.KnowledgeBaseDao 49 | modelDao dao.ModelDao 50 | historySvc HistoryService 51 | } 52 | 53 | func NewAgentService(dao dao.AgentDao, modelSvc ModelService, kbSvc KBService, kbDao dao.KnowledgeBaseDao, modelDao dao.ModelDao, historySvc HistoryService) AgentService { 54 | return &agentService{ 55 | dao: dao, 56 | modelSvc: modelSvc, 57 | kbSvc: kbSvc, 58 | kbDao: kbDao, 59 | modelDao: modelDao, 60 | historySvc: historySvc, 61 | } 62 | } 63 | 64 | func (s *agentService) CreateAgent(ctx context.Context, agent *model.Agent) error { 65 | return s.dao.Create(ctx, agent) 66 | } 67 | 68 | func (s *agentService) UpdateAgent(ctx context.Context, agent *model.Agent) error { 69 | return s.dao.Update(ctx, agent) 70 | } 71 | 72 | func (s *agentService) DeleteAgent(ctx context.Context, userID uint, agentID string) error { 73 | return s.dao.Delete(ctx, userID, agentID) 74 | } 75 | 76 | func (s *agentService) GetAgent(ctx context.Context, userID uint, agentID string) (*model.Agent, error) { 77 | return s.dao.GetByID(ctx, userID, agentID) 78 | } 79 | 80 | func (s *agentService) ListAgents(ctx context.Context, userID uint) ([]*model.Agent, error) { 81 | return s.dao.List(ctx, userID) 82 | } 83 | 84 | func (s *agentService) PageAgents(ctx context.Context, userID uint, page, size int) ([]*model.Agent, int64, error) { 85 | return s.dao.Page(ctx, userID, page, size) 86 | } 87 | 88 |
func (s *agentService) ExecuteAgent(ctx context.Context, userID uint, agentID string, msg model.UserMessage) (string, error) { 89 | // Retrieve the agent 90 | agent, err := s.dao.GetByID(ctx, userID, agentID) 91 | if err != nil { 92 | return "", err 93 | } 94 | 95 | // Parse the agent schema 96 | var agentSchema model.AgentSchema 97 | if err := json.Unmarshal([]byte(agent.AgentSchema), &agentSchema); err != nil { 98 | return "", err 99 | } 100 | 101 | graph, err := s.buildGraph(ctx, userID, agentSchema) 102 | if err != nil { 103 | return "", fmt.Errorf("buildGraph失败:%w", err) 104 | } 105 | 106 | runner, err := graph.Compile(ctx, compose.WithGraphName("EinoAgent"), compose.WithNodeTriggerMode(compose.AllPredecessor)) 107 | 108 | if err != nil { 109 | return "", err 110 | } 111 | 112 | res, err := runner.Invoke(ctx, &msg) 113 | if err != nil { 114 | return "", err 115 | } 116 | return res.String(), nil 117 | } 118 | 119 | func (s *agentService) StreamExecuteAgent(ctx context.Context, userID uint, agentID string, msg model.UserMessage) (*schema.StreamReader[*schema.Message], error) { 120 | // 1.获取Agent配置 121 | agent, err := s.dao.GetByID(ctx, userID, agentID) 122 | if err != nil { 123 | return nil, err 124 | } 125 | 126 | var agentSchema model.AgentSchema 127 | if err := json.Unmarshal([]byte(agent.AgentSchema), &agentSchema); err != nil { 128 | return nil, err 129 | } 130 | 131 | // 2.构建Graph 132 | graph, err := s.buildGraph(ctx, userID, agentSchema) 133 | if err != nil { 134 | return nil, fmt.Errorf("failed to build agent graph:%w", err) 135 | } 136 | 137 | // 3.构建runner 138 | runner, err := graph.Compile(ctx, compose.WithGraphName("EinoAgent"), compose.WithNodeTriggerMode(compose.AllPredecessor)) 139 | if err != nil { 140 | return nil, fmt.Errorf("failed to compile agent graph: %w", err) 141 | } 142 | 143 | // 执行stream 144 | sr, err := runner.Stream(ctx, &msg) 145 | if err != nil { 146 | return nil, fmt.Errorf("failed to stream: %w", err) 147 | } 148 | 149 | return 
sr, nil 150 | } 151 | 152 | func (s *agentService) buildGraph(ctx context.Context, userID uint, agentSchema model.AgentSchema) (*compose.Graph[*model.UserMessage, *schema.Message], error) { 153 | // 1. 创建LLM 154 | llmModelCfg, err := s.modelSvc.GetModel(ctx, userID, agentSchema.LLMConfig.ModelID) 155 | if err != nil { 156 | return nil, fmt.Errorf("failed to create get model:%w", err) 157 | } 158 | llm, err := llmfactory.GetLLMClient(ctx, llmModelCfg) 159 | if err != nil { 160 | return nil, fmt.Errorf("failed to create llm client:%w", err) 161 | } 162 | 163 | // 2. 构建 Retriever 164 | multiRetriever := mretriever.MultiKBRetriever{ 165 | KBIDs: agentSchema.Knowledge.KnowledgeIDs, 166 | UserID: userID, 167 | KBDao: s.kbDao, 168 | ModelDao: s.modelDao, 169 | Ctx: ctx, 170 | TopK: agentSchema.Knowledge.TopK, // 默认返回前5个最相关的文档 171 | } 172 | 173 | // 3. 构建Tools 174 | tools := []tool.BaseTool{} 175 | // 3.1 加载MCPTools 176 | for _, serverURL := range agentSchema.MCP.Servers { 177 | cli, err := client.NewSSEMCPClient(serverURL) 178 | err = cli.Start(ctx) 179 | if err != nil { 180 | return nil, fmt.Errorf("failed to create mcp client: %w", err) 181 | } 182 | initRequest := mcp.InitializeRequest{} 183 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 184 | initRequest.Params.ClientInfo = mcp.Implementation{ 185 | Name: "example-client", 186 | Version: "1.0.0", 187 | } 188 | 189 | _, err = cli.Initialize(ctx, initRequest) 190 | 191 | if err != nil { 192 | return nil, err 193 | } 194 | // 获取 mcpp 工具 195 | mcppTools, err := mcpp.GetTools(ctx, &mcpp.Config{Cli: cli}) 196 | if err != nil { 197 | return nil, fmt.Errorf("failed to get mcpp tools: %w", err) 198 | } 199 | tools = append(tools, mcppTools...) 200 | } 201 | // 3.2 加载系统和用户自定义Tools 202 | 203 | // 4. 
构建提示词 204 | promptTemplate := prompt.FromMessages( 205 | schema.FString, 206 | schema.SystemMessage(agentSchema.Prompt), 207 | schema.MessagesPlaceholder("history", true), 208 | schema.UserMessage("用户消息:{query}\n 参考信息:{documents}"), 209 | ) 210 | 211 | // 5. 实现图编排 212 | graph := compose.NewGraph[*model.UserMessage, *schema.Message]() 213 | _ = graph.AddLambdaNode(InputToQuery, compose.InvokableLambdaWithOption(inputToQueryLambda), compose.WithNodeName("UserMessageToQuery")) 214 | _ = graph.AddChatTemplateNode(ChatTemplate, promptTemplate) 215 | _ = graph.AddRetrieverNode(Retriever, multiRetriever, compose.WithOutputKey("documents")) 216 | _ = graph.AddLambdaNode(InputToHistory, compose.InvokableLambdaWithOption(inputToHistoryLambda), compose.WithNodeName("UserMessageToHistory")) 217 | 218 | // 根据是否有工具决定使用Agent还是直接使用ChatModel 219 | if len(tools) > 0 { 220 | // 有工具时使用Agent 221 | agentConfig := &react.AgentConfig{ 222 | ToolCallingModel: llm, 223 | MaxStep: 10, 224 | ToolsConfig: compose.ToolsNodeConfig{ 225 | Tools: tools, 226 | }, 227 | } 228 | 229 | agt, err := react.NewAgent(ctx, agentConfig) 230 | if err != nil { 231 | return nil, fmt.Errorf("failed to create agent: %w", err) 232 | } 233 | if agt == nil { 234 | return nil, errors.New("react.NewAgent returned a nil agent instance") 235 | } 236 | 237 | agentLambda, _ := compose.AnyLambda(agt.Generate, agt.Stream, nil, nil) 238 | 239 | _ = graph.AddLambdaNode(Agent, agentLambda, compose.WithNodeName("Agent")) 240 | 241 | _ = graph.AddEdge(compose.START, InputToQuery) 242 | _ = graph.AddEdge(compose.START, InputToHistory) 243 | _ = graph.AddEdge(InputToQuery, Retriever) 244 | _ = graph.AddEdge(Retriever, ChatTemplate) 245 | _ = graph.AddEdge(InputToHistory, ChatTemplate) 246 | _ = graph.AddEdge(ChatTemplate, Agent) 247 | _ = graph.AddEdge(Agent, compose.END) 248 | } else { 249 | // 没有工具时直接使用ChatModel 250 | _ = graph.AddChatModelNode(ChatModel, llm) 251 | 252 | _ = graph.AddEdge(compose.START, InputToQuery) 253 | _ = 
graph.AddEdge(compose.START, InputToHistory) 254 | _ = graph.AddEdge(InputToQuery, Retriever) 255 | _ = graph.AddEdge(Retriever, ChatTemplate) 256 | _ = graph.AddEdge(InputToHistory, ChatTemplate) 257 | _ = graph.AddEdge(ChatTemplate, ChatModel) 258 | _ = graph.AddEdge(ChatModel, compose.END) 259 | } 260 | 261 | return graph, nil 262 | } 263 | 264 | // inputToQueryLambda component initialization function of node 'InputToQuery' in graph 'EinoAgent' 265 | func inputToQueryLambda(ctx context.Context, input *model.UserMessage, opts ...any) (output string, err error) { 266 | return input.Query, nil 267 | } 268 | 269 | // inputToHistoryLambda component initialization function of node 'InputToHistory' in graph 'EinoAgent' 270 | func inputToHistoryLambda(ctx context.Context, input *model.UserMessage, opts ...any) (output map[string]any, err error) { 271 | return map[string]any{ 272 | "query": input.Query, 273 | "history": input.History, 274 | "date": time.Now().Format(time.DateTime), 275 | }, nil 276 | } 277 | -------------------------------------------------------------------------------- /internal/service/conversation_service.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log" 10 | "time" 11 | 12 | "github.com/cloudwego/eino/schema" 13 | "github.com/google/uuid" 14 | ) 15 | 16 | var defaultConvTitle = "新对话" 17 | 18 | type ConversationService interface { 19 | // Debug模式:临时会话,不保存历史 20 | DebugStreamAgent(ctx context.Context, userID uint, agentID string, message string) (*schema.StreamReader[*schema.Message], error) 21 | 22 | // 会话模式:创建/获取会话,记录历史 23 | StreamAgentWithConversation(ctx context.Context, userID uint, agentID string, convID string, message string) (*schema.StreamReader[*schema.Message], error) 24 | 25 | // 创建新会话 26 | CreateConversation(ctx context.Context, userID uint, agentID string) (string, error) 27 | 28 | 
// 删除会话 29 | DeleteConversation(ctx context.Context, convID string) error 30 | 31 | // 列出用户所有会话 32 | ListConversations(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error) 33 | 34 | // 列出特定Agent的会话 35 | ListAgentConversations(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error) 36 | 37 | // 获取会话历史消息 38 | GetConversationHistory(ctx context.Context, convID string, limit int) ([]*schema.Message, error) 39 | } 40 | 41 | type conversationService struct { 42 | agentSvc AgentService 43 | historySvc HistoryService 44 | } 45 | 46 | func NewConversationService(agentSvc AgentService, historySvc HistoryService) ConversationService { 47 | return &conversationService{ 48 | agentSvc: agentSvc, 49 | historySvc: historySvc, 50 | } 51 | } 52 | 53 | // DebugStreamAgent 调试模式:临时会话,不保存历史 54 | func (s *conversationService) DebugStreamAgent(ctx context.Context, userID uint, agentID string, message string) (*schema.StreamReader[*schema.Message], error) { 55 | // 创建用户消息,不含历史 56 | userMsg := model.UserMessage{ 57 | Query: message, 58 | History: []*schema.Message{}, 59 | } 60 | 61 | // 调用无状态的StreamExecuteAgent 62 | return s.agentSvc.StreamExecuteAgent(ctx, userID, agentID, userMsg) 63 | } 64 | 65 | // StreamAgentWithConversation 会话模式:记录历史 66 | func (s *conversationService) StreamAgentWithConversation(ctx context.Context, userID uint, agentID string, convID string, message string) (*schema.StreamReader[*schema.Message], error) { 67 | // 确保会话存在 68 | conv := &model.Conversation{ 69 | ConvID: convID, 70 | UserID: userID, 71 | AgentID: agentID, 72 | CreatedAt: time.Now().Unix(), 73 | UpdatedAt: time.Now().Unix(), 74 | } 75 | err := s.historySvc.CreateConversation(ctx, conv) 76 | if err != nil { 77 | // 可能是会话已存在,忽略错误 78 | log.Printf("[StreamAgentWithConversation] 创建会话失败: %w", err) 79 | return nil, fmt.Errorf("获取会话失败: %w", err) 80 | } 81 | 82 | // 先获取历史消息 83 | historyMsgs, err := s.historySvc.GetHistory(ctx, 
convID, 50) 84 | if err != nil { 85 | log.Printf("[StreamAgentWithConversation] 获取历史消息失败: %w", err) 86 | return nil, fmt.Errorf("获取历史消息失败: %w", err) 87 | } 88 | 89 | // 保存用户消息 90 | userSchemaMsg := &schema.Message{ 91 | Role: schema.User, 92 | Content: message, 93 | } 94 | err = s.historySvc.SaveMessage(ctx, userSchemaMsg, convID) 95 | if err != nil { 96 | log.Printf("[StreamAgentWithConversation] 保存用户消息失败: %w", err) 97 | return nil, fmt.Errorf("保存用户消息失败: %w", err) 98 | } 99 | 100 | // 创建用户消息,包含历史 101 | userMsg := model.UserMessage{ 102 | Query: message, 103 | History: historyMsgs, 104 | } 105 | 106 | // 调用Agent处理 107 | sr, err := s.agentSvc.StreamExecuteAgent(ctx, userID, agentID, userMsg) 108 | if err != nil { 109 | log.Printf("[StreamAgentWithConversation] 运行Agent失败: %w", err) 110 | return nil, fmt.Errorf("运行Agent失败: %w", err) 111 | } 112 | 113 | // 复制流 114 | srs := sr.Copy(2) 115 | 116 | // 创建一个独立的上下文用于保存消息,不依赖于请求上下文 117 | saveCtx := context.Background() 118 | 119 | // 后台处理:记录完整回复 120 | go func() { 121 | fullMsgs := make([]*schema.Message, 0) 122 | 123 | defer func() { 124 | srs[1].Close() 125 | // 合并消息 126 | if len(fullMsgs) > 0 { 127 | fullMsg, err := schema.ConcatMessages(fullMsgs) 128 | if err != nil { 129 | fmt.Println("合并消息失败:", err.Error()) 130 | return 131 | } 132 | 133 | // 使用独立上下文保存消息 134 | err = s.historySvc.SaveMessage(saveCtx, fullMsg, convID) 135 | if err != nil { 136 | fmt.Println("保存消息失败:", err.Error()) 137 | } 138 | 139 | // 更新会话最后更新时间 140 | conv.UpdatedAt = time.Now().Unix() 141 | _ = s.historySvc.UpdateConversation(saveCtx, conv) 142 | } 143 | }() 144 | 145 | outer: 146 | for { 147 | select { 148 | case <-ctx.Done(): 149 | fmt.Println("上下文已关闭:", ctx.Err()) 150 | return 151 | default: 152 | chunk, err := srs[1].Recv() 153 | if err != nil { 154 | if errors.Is(err, io.EOF) { 155 | break outer 156 | } 157 | fmt.Println("接收消息块错误:", err.Error()) 158 | return 159 | } 160 | 161 | fullMsgs = append(fullMsgs, chunk) 162 | } 163 | } 164 | }() 165 | 166 
| return srs[0], nil 167 | } 168 | 169 | // CreateConversation 创建新会话 170 | func (s *conversationService) CreateConversation(ctx context.Context, userID uint, agentID string) (string, error) { 171 | convID := uuid.NewString() 172 | 173 | conv := &model.Conversation{ 174 | ConvID: convID, 175 | UserID: userID, 176 | AgentID: agentID, 177 | Title: defaultConvTitle + time.Now().String(), 178 | CreatedAt: time.Now().Unix(), 179 | UpdatedAt: time.Now().Unix(), 180 | } 181 | 182 | err := s.historySvc.CreateConversation(ctx, conv) 183 | if err != nil { 184 | return "", fmt.Errorf("[CreateConversation] 创建会话失败: %w", err) 185 | } 186 | 187 | return convID, nil 188 | } 189 | 190 | // ListConversations 列出用户所有会话 191 | func (s *conversationService) ListConversations(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error) { 192 | return s.historySvc.ListConversations(ctx, userID, page, size) 193 | } 194 | 195 | // ListAgentConversations 列出特定Agent的会话 196 | func (s *conversationService) ListAgentConversations(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error) { 197 | return s.historySvc.ListConversationsByAgent(ctx, userID, agentID, page, size) 198 | } 199 | 200 | // GetConversationHistory 获取会话历史消息 201 | func (s *conversationService) GetConversationHistory(ctx context.Context, convID string, limit int) ([]*schema.Message, error) { 202 | return s.historySvc.GetHistory(ctx, convID, limit) 203 | } 204 | 205 | // DeleteConversation 删除会话 206 | func (s *conversationService) DeleteConversation(ctx context.Context, convID string) error { 207 | return s.historySvc.DeleteConversation(ctx, convID) 208 | } 209 | -------------------------------------------------------------------------------- /internal/service/history_service.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | hisdao "ai-cloud/internal/dao/history" 5 | 
"ai-cloud/internal/model" 6 | "ai-cloud/internal/utils" 7 | "context" 8 | "fmt" 9 | 10 | "github.com/cloudwego/eino/schema" 11 | ) 12 | 13 | type HistoryService interface { 14 | SaveMessage(ctx context.Context, mess *schema.Message, convID string) error 15 | GetHistory(ctx context.Context, convID string, limit int) ([]*schema.Message, error) 16 | CreateConversation(ctx context.Context, conv *model.Conversation) error 17 | UpdateConversation(ctx context.Context, conv *model.Conversation) error 18 | DeleteConversation(ctx context.Context, convID string) error 19 | ArchiveConversation(ctx context.Context, convID string) error 20 | UnArchiveConversation(ctx context.Context, convID string) error 21 | PinConversation(ctx context.Context, convID string) error 22 | UnPinConversation(ctx context.Context, convID string) error 23 | ListConversations(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error) 24 | ListConversationsByAgent(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error) 25 | } 26 | type history struct { 27 | convDao hisdao.ConvDao 28 | msgDao hisdao.MsgDao 29 | } 30 | 31 | // NewHistoryService 创建历史记录服务 32 | func NewHistoryService(convDao hisdao.ConvDao, msgDao hisdao.MsgDao) HistoryService { 33 | return &history{ 34 | convDao: convDao, 35 | msgDao: msgDao, 36 | } 37 | } 38 | 39 | // SaveMessage 保存消息 40 | func (s *history) SaveMessage(ctx context.Context, mess *schema.Message, convID string) error { 41 | err := s.msgDao.Create(ctx, &model.Message{ 42 | Role: string(mess.Role), 43 | Content: mess.Content, 44 | ConvID: convID, 45 | }) 46 | if err != nil { 47 | return fmt.Errorf("failed to save message: %w", err) 48 | } 49 | return nil 50 | } 51 | 52 | // GetHistory 获取对话的历史消息 53 | func (s *history) GetHistory(ctx context.Context, convID string, limit int) ([]*schema.Message, error) { 54 | if limit == 0 { 55 | limit = 50 56 | } 57 | _, err := s.convDao.GetByID(ctx, convID) 58 | if 
err != nil { 59 | return nil, fmt.Errorf("failed to get conversation: %w", err) 60 | } 61 | 62 | msgs, _, err := s.msgDao.List(ctx, convID, 0, limit) 63 | if err != nil { 64 | return nil, fmt.Errorf("failed to get messages: %w", err) 65 | } 66 | 67 | return utils.MessageList2ChatHistory(msgs), nil 68 | } 69 | 70 | // CreateConversation 创建会话 71 | func (s *history) CreateConversation(ctx context.Context, conv *model.Conversation) error { 72 | if err := s.convDao.FirstOrCreate(ctx, conv); err != nil { 73 | return fmt.Errorf("failed to create conversation: %w", err) 74 | } 75 | return nil 76 | } 77 | 78 | // UpdateConversation 更新会话 79 | func (s *history) UpdateConversation(ctx context.Context, conv *model.Conversation) error { 80 | if err := s.convDao.Update(ctx, conv); err != nil { 81 | return fmt.Errorf("failed to update conversation: %w", err) 82 | } 83 | return nil 84 | } 85 | 86 | // ArchiveConversation 归档会话 87 | func (s *history) ArchiveConversation(ctx context.Context, convID string) error { 88 | if err := s.convDao.Archive(ctx, convID); err != nil { 89 | return fmt.Errorf("failed to archive conversation: %w", err) 90 | } 91 | return nil 92 | } 93 | 94 | // UnArchiveConversation 取消归档会话 95 | func (s *history) UnArchiveConversation(ctx context.Context, convID string) error { 96 | if err := s.convDao.UnArchive(ctx, convID); err != nil { 97 | return fmt.Errorf("failed to unarchive conversation: %w", err) 98 | } 99 | return nil 100 | } 101 | 102 | // PinConversation 置顶会话 103 | func (s *history) PinConversation(ctx context.Context, convID string) error { 104 | if err := s.convDao.Pin(ctx, convID); err != nil { 105 | return fmt.Errorf("failed to pin conversation: %w", err) 106 | } 107 | return nil 108 | } 109 | 110 | // UnPinConversation 取消置顶会话 111 | func (s *history) UnPinConversation(ctx context.Context, convID string) error { 112 | if err := s.convDao.UnPin(ctx, convID); err != nil { 113 | return fmt.Errorf("failed to unpin conversation: %w", err) 114 | } 115 | 
return nil 116 | } 117 | 118 | // ListConversations 获取对话列表 119 | func (s *history) ListConversations(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error) { 120 | convs, count, err := s.convDao.Page(ctx, userID, page, size) 121 | if err != nil { 122 | return nil, 0, fmt.Errorf("failed to list conversations: %w", err) 123 | } 124 | return convs, count, nil 125 | } 126 | 127 | // ListConversationsByAgent 按Agent获取对话列表 128 | func (s *history) ListConversationsByAgent(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error) { 129 | convs, count, err := s.convDao.PageByAgent(ctx, userID, agentID, page, size) 130 | if err != nil { 131 | return nil, 0, fmt.Errorf("failed to list conversations by agent: %w", err) 132 | } 133 | return convs, count, nil 134 | } 135 | 136 | // DeleteConversation 删除会话 137 | func (s *history) DeleteConversation(ctx context.Context, convID string) error { 138 | // 先获取会话的所有消息 139 | msgs, err := s.msgDao.ListByConvID(ctx, convID) 140 | if err != nil { 141 | return fmt.Errorf("failed to list messages for conversation: %w", err) 142 | } 143 | 144 | // 删除会话的所有消息 145 | for _, msg := range msgs { 146 | if err := s.msgDao.Delete(ctx, msg.MsgID); err != nil { 147 | return fmt.Errorf("failed to delete message: %w", err) 148 | } 149 | } 150 | 151 | // 删除会话 152 | if err := s.convDao.Delete(ctx, convID); err != nil { 153 | return fmt.Errorf("failed to delete conversation: %w", err) 154 | } 155 | return nil 156 | } 157 | -------------------------------------------------------------------------------- /internal/service/model_service.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "ai-cloud/internal/dao" 5 | "ai-cloud/internal/model" 6 | "context" 7 | ) 8 | 9 | type ModelService interface { 10 | CreateModel(ctx context.Context, m *model.Model) error 11 | UpdateModel(ctx context.Context, m *model.Model) 
error 12 | DeleteModel(ctx context.Context, userID uint, id string) error 13 | GetModel(ctx context.Context, userID uint, id string) (*model.Model, error) 14 | ListModels(ctx context.Context, userID uint, modelType string) ([]*model.Model, error) 15 | PageModels(ctx context.Context, userID uint, modelType string, page, size int) ([]*model.Model, int64, error) 16 | } 17 | 18 | type modelService struct { 19 | dao dao.ModelDao 20 | } 21 | 22 | func NewModelService(dao dao.ModelDao) ModelService { 23 | return &modelService{dao: dao} 24 | } 25 | 26 | func (s *modelService) CreateModel(ctx context.Context, m *model.Model) error { 27 | return s.dao.Create(ctx, m) 28 | } 29 | 30 | func (s *modelService) UpdateModel(ctx context.Context, m *model.Model) error { 31 | return s.dao.Update(ctx, m) 32 | } 33 | 34 | func (s *modelService) DeleteModel(ctx context.Context, userID uint, id string) error { 35 | return s.dao.Delete(ctx, userID, id) 36 | } 37 | 38 | func (s *modelService) GetModel(ctx context.Context, userID uint, id string) (*model.Model, error) { 39 | return s.dao.GetByID(ctx, userID, id) 40 | } 41 | 42 | func (s *modelService) ListModels(ctx context.Context, userID uint, modelType string) ([]*model.Model, error) { 43 | return s.dao.List(ctx, userID, modelType) 44 | } 45 | 46 | func (s *modelService) PageModels(ctx context.Context, userID uint, modelType string, page, size int) ([]*model.Model, int64, error) { 47 | return s.dao.Page(ctx, userID, modelType, page, size) 48 | } 49 | -------------------------------------------------------------------------------- /internal/service/user_service.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "ai-cloud/internal/dao" 6 | "ai-cloud/internal/middleware" 7 | "ai-cloud/internal/model" 8 | "errors" 9 | "golang.org/x/crypto/bcrypt" 10 | ) 11 | 12 | type UserService interface { 13 | Register(user *model.User) error 14 | Login(req 
*model.UserNameLoginReq) (*model.LoginResponse, error) 15 | } 16 | 17 | type userService struct { 18 | userDao dao.UserDao 19 | } 20 | 21 | func NewUserService(userDao dao.UserDao) UserService { 22 | return &userService{userDao: userDao} 23 | } 24 | 25 | func (s *userService) Register(user *model.User) error { 26 | // 检查 27 | usernameExists, err := s.userDao.CheckFieldExists("username", user.Username) 28 | if err != nil { 29 | return err 30 | } 31 | if usernameExists { 32 | return errors.New("用户名已注册") 33 | } 34 | 35 | phoneExists, err := s.userDao.CheckFieldExists("phone", user.Phone) 36 | if err != nil { 37 | return err 38 | } 39 | if phoneExists { 40 | return errors.New("手机号已注册") 41 | } 42 | 43 | hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) 44 | if err != nil { 45 | return errors.New("密码加密失败") 46 | } 47 | 48 | newUser := &model.User{ 49 | Username: user.Username, 50 | Phone: user.Phone, 51 | Password: string(hashedPassword), 52 | Email: user.Email, 53 | } 54 | err = s.userDao.CreateUser(newUser) 55 | if err != nil { 56 | return err 57 | } 58 | return nil 59 | } 60 | 61 | func (s *userService) Login(req *model.UserNameLoginReq) (*model.LoginResponse, error) { 62 | user, err := s.userDao.GetUserByName(req.Username) 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil { 68 | return nil, errors.New("用户名或密码错误") 69 | } 70 | 71 | accessToken, err := middleware.GenerateToken(user.ID) 72 | if err != nil { 73 | return nil, errors.New("系统错误") 74 | } 75 | 76 | return &model.LoginResponse{ 77 | AccessToken: accessToken, 78 | ExpiresIn: config.AppConfigInstance.JWT.ExpirationHours * 3600, 79 | TokenType: "Bearer", 80 | }, nil 81 | } 82 | -------------------------------------------------------------------------------- /internal/storage/local.go: -------------------------------------------------------------------------------- 1 | 
package storage 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | ) 8 | 9 | // LocalStorage 本地存储驱动结构体 10 | type LocalStorage struct { 11 | baseDir string // 本地存储根目录(如 ./storage_data) 12 | } 13 | 14 | // NewLocalStorage 初始化本地存储 15 | func NewLocalStorage(baseDir string) (*LocalStorage, error) { 16 | // 确保存储目录存在 17 | if err := os.MkdirAll(baseDir, 0755); err != nil { 18 | return nil, fmt.Errorf("failed to create local storage dir: %v", err) 19 | } 20 | return &LocalStorage{baseDir: baseDir}, nil 21 | } 22 | 23 | // Upload 上传文件到本地 24 | func (s *LocalStorage) Upload(data []byte, key string, contentType string) error { 25 | fullPath := filepath.Join(s.baseDir, key) 26 | // 确保父目录存在 27 | if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { 28 | return fmt.Errorf("failed to create parent dir: %v", err) 29 | } 30 | return os.WriteFile(fullPath, data, 0644) 31 | } 32 | 33 | // Download 从本地下载文件 34 | func (s *LocalStorage) Download(key string) ([]byte, error) { 35 | fullPath := filepath.Join(s.baseDir, key) 36 | return os.ReadFile(fullPath) 37 | } 38 | 39 | // Delete 删除本地文件 40 | func (s *LocalStorage) Delete(key string) error { 41 | fullPath := filepath.Join(s.baseDir, key) 42 | return os.Remove(fullPath) 43 | } 44 | 45 | // GetURL 获取本地文件路径(仅返回相对路径) 46 | func (s *LocalStorage) GetURL(key string) (string, error) { 47 | return filepath.Join(s.baseDir, key), nil 48 | } 49 | -------------------------------------------------------------------------------- /internal/storage/minio.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "bytes" 6 | "context" 7 | "fmt" 8 | "io" 9 | "net/url" 10 | "strings" 11 | "time" 12 | 13 | "github.com/minio/minio-go/v7" 14 | "github.com/minio/minio-go/v7/pkg/credentials" 15 | ) 16 | 17 | type MinioStorage struct { 18 | client *minio.Client 19 | bucket string 20 | } 21 | 22 | // NewMinioStorage 创建新的 Minio 存储实例 23 | func 
NewMinioStorage(cfg config.MinioConfig) (Driver, error) { 24 | // 设置中国时区 25 | loc, err := time.LoadLocation("Asia/Shanghai") 26 | if err != nil { 27 | return nil, fmt.Errorf("failed to load timezone: %v", err) 28 | } 29 | time.Local = loc 30 | 31 | // 初始化 Minio 客户端 32 | client, err := minio.New(cfg.Endpoint, &minio.Options{ 33 | Creds: credentials.NewStaticV4(cfg.AccessKeyID, cfg.AccessKeySecret, ""), 34 | Secure: cfg.UseSSL, 35 | Region: cfg.Region, 36 | }) 37 | if err != nil { 38 | return nil, fmt.Errorf("failed to create minio client: %v", err) 39 | } 40 | 41 | // 检查 bucket 是否存在 42 | exists, err := client.BucketExists(context.Background(), cfg.Bucket) 43 | if err != nil { 44 | return nil, fmt.Errorf("failed to check bucket existence: %v", err) 45 | } 46 | 47 | // 如果 bucket 不存在,创建它 48 | if !exists { 49 | err = client.MakeBucket(context.Background(), cfg.Bucket, minio.MakeBucketOptions{ 50 | Region: cfg.Region, 51 | }) 52 | if err != nil { 53 | return nil, fmt.Errorf("failed to create bucket: %v", err) 54 | } 55 | } 56 | 57 | return &MinioStorage{ 58 | client: client, 59 | bucket: cfg.Bucket, 60 | }, nil 61 | } 62 | 63 | // Upload 上传文件到 Minio 64 | func (m *MinioStorage) Upload(data []byte, key string, contentType string) error { 65 | reader := bytes.NewReader(data) 66 | _, err := m.client.PutObject( 67 | context.Background(), 68 | m.bucket, 69 | key, 70 | reader, 71 | int64(len(data)), 72 | minio.PutObjectOptions{ 73 | ContentType: contentType, // 例如 "application/pdf" 74 | }, 75 | ) 76 | if err != nil { 77 | return fmt.Errorf("failed to upload file: %v", err) 78 | } 79 | return nil 80 | } 81 | 82 | // Download 从 Minio 下载文件 83 | func (m *MinioStorage) Download(key string) ([]byte, error) { 84 | obj, err := m.client.GetObject( 85 | context.Background(), 86 | m.bucket, 87 | key, 88 | minio.GetObjectOptions{}, 89 | ) 90 | if err != nil { 91 | return nil, fmt.Errorf("failed to get object: %v", err) 92 | } 93 | defer obj.Close() 94 | 95 | data, err := io.ReadAll(obj) 96 
| if err != nil { 97 | return nil, fmt.Errorf("failed to read object data: %v", err) 98 | } 99 | return data, nil 100 | } 101 | 102 | // Delete 从 Minio 删除文件 103 | func (m *MinioStorage) Delete(key string) error { 104 | err := m.client.RemoveObject(context.Background(), m.bucket, key, minio.RemoveObjectOptions{}) 105 | if err != nil { 106 | return fmt.Errorf("failed to delete object: %v", err) 107 | } 108 | return nil 109 | } 110 | 111 | func (m *MinioStorage) GetURL(key string) (string, error) { 112 | // 设置响应头,强制浏览器下载文件 113 | reqParams := make(url.Values) 114 | reqParams.Set("response-content-disposition", "attachment") 115 | 116 | // 生成预签名URL,有效期1小时 117 | expiry := time.Hour * 1 118 | presignedURL, err := m.client.PresignedGetObject( 119 | context.Background(), 120 | m.bucket, 121 | key, 122 | expiry, 123 | reqParams, // 关键:传递自定义参数 124 | ) 125 | if err != nil { 126 | return "", fmt.Errorf("failed to generate presigned URL: %v", err) 127 | } 128 | 129 | return presignedURL.String(), nil 130 | } 131 | 132 | // CreateDirectory 创建目录(通过上传空对象实现) 133 | func (m *MinioStorage) CreateDirectory(dirPath string) error { 134 | // 确保路径以 / 结尾 135 | if !strings.HasSuffix(dirPath, "/") { 136 | dirPath = dirPath + "/" 137 | } 138 | 139 | // 上传一个空对象来表示目录 140 | _, err := m.client.PutObject(context.Background(), m.bucket, dirPath, bytes.NewReader([]byte{}), 0, minio.PutObjectOptions{}) 141 | if err != nil { 142 | return fmt.Errorf("failed to create directory: %v", err) 143 | } 144 | return nil 145 | } 146 | -------------------------------------------------------------------------------- /internal/storage/oss.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "bytes" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/aliyun/aliyun-oss-go-sdk/oss" 10 | ) 11 | 12 | // OSSStorage 阿里云OSS存储驱动结构体 13 | type OSSStorage struct { 14 | bucket *oss.Bucket // OSS Bucket实例 15 | } 16 | 17 | // NewOSSStorage 初始化OSS存储 
18 | func NewOSSStorage(cfg config.OSSConfig) (*OSSStorage, error) { 19 | // 创建OSS客户端 20 | client, err := oss.New(cfg.Endpoint, cfg.AccessKeyID, cfg.AccessKeySecret) 21 | if err != nil { 22 | return nil, fmt.Errorf("failed to create OSS client: %v", err) 23 | } 24 | 25 | // 获取Bucket实例 26 | bucket, err := client.Bucket(cfg.Bucket) 27 | if err != nil { 28 | return nil, fmt.Errorf("failed to get OSS bucket: %v", err) 29 | } 30 | 31 | return &OSSStorage{bucket: bucket}, nil 32 | } 33 | 34 | // Upload 上传文件到OSS 35 | func (s *OSSStorage) Upload(data []byte, key string, contentType string) error { 36 | return s.bucket.PutObject(key, bytes.NewReader(data)) 37 | } 38 | 39 | // Download 从OSS下载文件 40 | func (s *OSSStorage) Download(key string) ([]byte, error) { 41 | reader, err := s.bucket.GetObject(key) 42 | if err != nil { 43 | return nil, fmt.Errorf("failed to download from OSS: %v", err) 44 | } 45 | defer reader.Close() 46 | 47 | buf := new(bytes.Buffer) 48 | if _, err := buf.ReadFrom(reader); err != nil { 49 | return nil, fmt.Errorf("failed to read OSS data: %v", err) 50 | } 51 | return buf.Bytes(), nil 52 | } 53 | 54 | // Delete 删除OSS文件 55 | func (s *OSSStorage) Delete(key string) error { 56 | return s.bucket.DeleteObject(key) 57 | } 58 | 59 | // GetURL 生成带签名的临时访问URL(有效期1小时) 60 | func (s *OSSStorage) GetURL(key string) (string, error) { 61 | expired := time.Now().Add(1 * time.Hour) 62 | return s.bucket.SignURL(key, oss.HTTPGet, int64(expired.Unix())) 63 | } 64 | -------------------------------------------------------------------------------- /internal/storage/storage.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "ai-cloud/config" 5 | "fmt" 6 | ) 7 | 8 | // Driver 定义存储驱动接口 9 | type Driver interface { 10 | Upload(data []byte, key string, contentType string) error // 上传文件 11 | Download(key string) ([]byte, error) // 下载文件 12 | Delete(key string) error // 删除文件 13 | GetURL(key string) (string, 
error) // 获取访问URL 14 | } 15 | 16 | // NewDriver 根据配置初始化存储驱动 17 | func NewDriver(cfg config.StorageConfig) (Driver, error) { 18 | switch cfg.Type { 19 | case "local": 20 | return NewLocalStorage(cfg.Local.BaseDir) 21 | case "oss": 22 | return NewOSSStorage(cfg.OSS) 23 | case "minio": 24 | return NewMinioStorage(cfg.Minio) 25 | default: 26 | return nil, fmt.Errorf("unsupported storage type: %s", cfg.Type) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /internal/utils/agent_utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | // ReplaceKnowledgePlaceholder replaces the {{Knowledge}} placeholder in the prompt 8 | func ReplaceKnowledgePlaceholder(prompt string, knowledge string) string { 9 | return strings.Replace(prompt, "{{Knowledge}}", knowledge, -1) 10 | } 11 | 12 | // FormatRetrievalResults formats the retrieval results for insertion into the prompt 13 | func FormatRetrievalResults(results []map[string]interface{}) string { 14 | var builder strings.Builder 15 | 16 | for i, result := range results { 17 | if i > 0 { 18 | builder.WriteString("\n\n") 19 | } 20 | 21 | // Extract content and metadata 22 | if content, ok := result["content"].(string); ok { 23 | builder.WriteString(content) 24 | } 25 | 26 | // Add source if available 27 | if source, ok := result["source"].(string); ok { 28 | builder.WriteString("\nSource: " + source) 29 | } 30 | } 31 | 32 | return builder.String() 33 | } 34 | -------------------------------------------------------------------------------- /internal/utils/context.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "github.com/gin-gonic/gin" 6 | ) 7 | 8 | // 定义上下文键名(避免硬编码) 9 | const UserIDKey = "user_id" 10 | 11 | func GetUserIDFromContext(c *gin.Context) (uint, error) { 12 | // 从上下文中获取值 13 | userIDVal, 
exists := c.Get(UserIDKey) 14 | if !exists { 15 | return 0, errors.New("上下文中未找到用户ID") 16 | } 17 | 18 | // 类型断言 19 | userID, ok := userIDVal.(uint) 20 | if !ok { 21 | return 0, errors.New("用户ID类型错误") 22 | } 23 | 24 | return userID, nil 25 | } 26 | -------------------------------------------------------------------------------- /internal/utils/convert.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "strconv" 5 | ) 6 | 7 | // StringToInt 将字符串转换为整数,出错时返回默认值0 8 | func StringToInt(s string) int { 9 | i, err := strconv.Atoi(s) 10 | if err != nil { 11 | return 0 12 | } 13 | return i 14 | } 15 | 16 | // StringToInt64 将字符串转换为int64,出错时返回默认值0 17 | func StringToInt64(s string) int64 { 18 | i, err := strconv.ParseInt(s, 10, 64) 19 | if err != nil { 20 | return 0 21 | } 22 | return i 23 | } 24 | 25 | // IntToString 将整数转换为字符串 26 | func IntToString(i int) string { 27 | return strconv.Itoa(i) 28 | } 29 | 30 | // Int64ToString 将int64转换为字符串 31 | func Int64ToString(i int64) string { 32 | return strconv.FormatInt(i, 10) 33 | } 34 | -------------------------------------------------------------------------------- /internal/utils/convert_float.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | func ConvertFloat64ToFloat32Embeddings(embeddings [][]float64) [][]float32 { 4 | float32Embeddings := make([][]float32, len(embeddings)) 5 | for i, vec64 := range embeddings { 6 | vec32 := make([]float32, len(vec64)) 7 | for j, v := range vec64 { 8 | vec32[j] = float32(v) 9 | } 10 | float32Embeddings[i] = vec32 11 | } 12 | return float32Embeddings 13 | } 14 | 15 | func ConvertFloat64ToFloat32Embedding(embedding []float64) []float32 { 16 | float32Embedding := make([]float32, len(embedding)) 17 | for i, v := range embedding { 18 | float32Embedding[i] = float32(v) 19 | } 20 | return float32Embedding 21 | } 22 | 
-------------------------------------------------------------------------------- /internal/utils/hitory_utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "ai-cloud/internal/model" 5 | "github.com/cloudwego/eino/schema" 6 | ) 7 | 8 | func MessageList2ChatHistory(mess []*model.Message) (history []*schema.Message) { 9 | for _, m := range mess { 10 | history = append(history, message2MessagesTemplate(m)) 11 | } 12 | return 13 | } 14 | 15 | func message2MessagesTemplate(mess *model.Message) *schema.Message { 16 | return &schema.Message{ 17 | Role: schema.RoleType(mess.Role), 18 | Content: mess.Content, 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /internal/utils/pagination.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "github.com/gin-gonic/gin" 6 | "strconv" 7 | ) 8 | 9 | func ParsePaginationParams(ctx *gin.Context) (page int, pageSize int, err error) { 10 | pageStr := ctx.DefaultQuery("page", "1") 11 | page, err = strconv.Atoi(pageStr) 12 | if err != nil || page < 1 { 13 | return 0, 0, fmt.Errorf("无效的页码参数") 14 | 15 | } 16 | 17 | pageSizeStr := ctx.DefaultQuery("page_size", "10") 18 | pageSize, err = strconv.Atoi(pageSizeStr) 19 | if err != nil || pageSize < 1 { 20 | return 0, 0, fmt.Errorf("无效的页面大小参数") 21 | } 22 | 23 | return page, pageSize, nil 24 | } 25 | -------------------------------------------------------------------------------- /internal/utils/uuid.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "github.com/google/uuid" 4 | 5 | func GenerateUUID() string { 6 | return uuid.New().String() 7 | } 8 | -------------------------------------------------------------------------------- /internal/utils/validate_sort.go: 
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 | )
7 | 
8 | // ValidateSortParameter checks a comma-separated sort specification of the
9 | // form "field:order[,field:order...]" against a whitelist of field names.
10 | //
11 | // Each clause is split on its first ':' only. The field must appear in
12 | // allowedFields and the order must be exactly "asc" or "desc". Anything else,
13 | // including an empty clause, a clause without ':' or a clause with extra
14 | // ':'-separated parts, is rejected with an error instead of panicking.
15 | func ValidateSortParameter(sort string, allowedFields []string) error {
16 | 	for _, clause := range strings.Split(sort, ",") {
17 | 		// strings.Cut never indexes past the end: a clause without ':' yields
18 | 		// an empty order, which then fails the direction check below.
19 | 		field, order, _ := strings.Cut(clause, ":")
20 | 		if !contains(allowedFields, field) {
21 | 			return fmt.Errorf("无效的排序字段:%s", field)
22 | 		}
23 | 		if order != "asc" && order != "desc" {
24 | 			return fmt.Errorf("无效的排序方向:%s", order)
25 | 		}
26 | 	}
27 | 	return nil
28 | }
29 | 
30 | // contains reports whether item is present in slice.
31 | func contains(slice []string, item string) bool {
32 | 	for _, s := range slice {
33 | 		if s == item {
34 | 			return true
35 | 		}
36 | 	}
37 | 	return false
38 | }
39 | 
--------------------------------------------------------------------------------
/pkgs/consts/milvus_const.go:
--------------------------------------------------------------------------------
1 | package consts
2 | 
3 | // Field name constants.
4 | const (
5 | 	// FieldNameID is the ID field name.
6 | 	FieldNameID = "id"
7 | 	// FieldNameContent is the content field name.
8 | 	FieldNameContent = "content"
9 | 	// FieldNameDocumentID is the document ID field name.
10 | 	FieldNameDocumentID = "document_id"
11 | 	// FieldNameDocumentName is the document name field name.
12 | 	FieldNameDocumentName = "document_name"
13 | 	// FieldNameKBID is the knowledge base ID field name.
14 | 	FieldNameKBID = "kb_id"
15 | 	// FieldNameChunkIndex is the chunk index field name.
16 | 	FieldNameChunkIndex = "chunk_index"
17 | 	// FieldNameVector is the vector field name.
18 | 	FieldNameVector = "vector"
19 | 	// FieldNameMetadata is the metadata field name.
20 | 	FieldNameMetadata = "metadata"
21 | )
22 | 
--------------------------------------------------------------------------------
/pkgs/errcode/errcode.go:
--------------------------------------------------------------------------------
1 | package errcode
2 | 
3 | const (
4 | 	/******************** Base error codes (10000-19999) ********************/
5 | 	// General system errors
6 | 	SuccessCode         = 0     // reserved success code
7 | 	InternalServerError = 10001 // internal server error
8 | 	DatabaseError       = 10002 // database operation error
9 | 	CacheError          = 10003 // cache service error
10 | 	RateLimitExceeded   = 10004 // request rate limit exceeded
11 | 
12 | 	// Request parameters
13 | 	ParamBindError     = 10101 // parameter binding error
14 | 	ParamValidateError = 10102 // parameter validation failed
15 | 
16 | 	// Authentication and authorization
17 | 	UnauthorizedError = 10201 // not authenticated
18 | 	ForbiddenError    = 10202 // access denied
19 | 	TokenExpired      = 10203 // token expired
20 | 	TokenInvalid      = 10204 // token invalid
21 | 	TokenMissing      = 10205 // token missing
22 | 
23 | 	/******************** Business error codes ********************/
24 | 	// User module (20000-20999)
25 | 	UserNotFound      = 20001 // user does not exist
26 | 	UserAlreadyExists = 20002 // user already exists
27 | 	PasswordMismatch  = 20003 // wrong password
28 | 	UserDisabled      = 20004 // user is disabled
29 | 	EmailInvalid      = 20005 // invalid email format
30 | 
31 | 	// File module (21000-21999)
32 | 	FileNotFound     = 21001 // file does not exist
33 | 	FileUploadFailed = 21002 // file upload failed
34 | 	FileDeleteFailed = 21003 // file deletion failed
35 | 	FileSizeExceeded = 21004 // file size limit exceeded
36 | 	FileTypeInvalid  = 21005 // invalid file type
37 | 	FileParseFailed  = 21006 // file parsing failed
38 | 	FileListFailed   = 21007 // failed to retrieve file list
39 | 	FileSearchFailed = 21008 // file search failed
40 | 
41 | 	// Order module (22000-22999)
42 | 	// Reserved for future extension...
43 | )
44 | 
--------------------------------------------------------------------------------
/pkgs/response/response.go:
--------------------------------------------------------------------------------
1 | package response
2 | 
3 | import (
4 | 	"net/http"
5 | 
6 | 	"github.com/gin-gonic/gin"
7 | )
8 | 
9 | // Response is the unified response envelope written by the helpers in this package.
10 | type Response struct {
11 | 	Code    int         `json:"code"`    // business code
12 | 	Message string      `json:"message"` // human-readable message
13 | 	Data    interface{} `json:"data"`    // payload
14 | }
15 | 
16 | // PageData is the payload used for paginated list responses.
17 | type PageData struct {
18 | 	Total int64       `json:"total"` // total number of items
19 | 	List  interface{} `json:"list"`  // items of the current page
20 | }
21 | 
22 | // Business codes for Response.Code.
23 | const (
24 | 	SUCCESS      = 0
25 | 	ERROR        = 1
26 | 	ERROR_PARAM  = 2
27 | 	ERROR_AUTH   = 401
28 | 	ERROR_SERVER = 500
29 | )
30 | 
31 | // Success writes HTTP 200 with code SUCCESS, message "success" and the given data.
32 | func Success(c *gin.Context, data interface{}) {
33 | 	c.JSON(http.StatusOK, Response{
34 | 		Code:    SUCCESS,
35 | 		Message: "success",
36 | 		Data:    data,
37 | 	})
38 | }
39 | 
40 | // SuccessWithMessage writes HTTP 200 with code SUCCESS, a custom message and the given data.
41 | func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
42 | 	c.JSON(http.StatusOK, Response{
43 | 		Code:    SUCCESS,
44 | 		Message: message,
45 | 		Data:    data,
46 | 	})
47 | }
48 | 
49 | // PageSuccess writes HTTP 200 with code SUCCESS and a PageData payload
50 | // holding list and total.
51 | func PageSuccess(c *gin.Context, list interface{}, total int64) {
52 | 	c.JSON(http.StatusOK, Response{
53 | 		Code:    SUCCESS,
54 | 		Message: "success",
55 | 		Data: PageData{
56 | 			List:  list,
57 | 			Total: total,
58 | 		},
59 | 	})
60 | }
61 | 
62 | // Error writes a failure response with code ERROR and the given message.
63 | // The HTTP status is still 200; clients must inspect Code to detect failure.
64 | func Error(c *gin.Context, message string) {
65 | 	c.JSON(http.StatusOK, Response{
66 | 		Code:    ERROR,
67 | 		Message: message,
68 | 		Data:    nil,
69 | 	})
70 | }
71 | 
72 | //// ErrorWithCode writes an error response with a custom error code.
73 | //func ErrorWithCode(c *gin.Context, code int, message string) {
74 | //	c.JSON(http.StatusOK, Response{
75 | //		Code:    code,
76 | //		Message: message,
77 | //		Data:    nil,
78 | //	})
79 | //}
80 | 
81 | // ErrorCustom writes a response with an explicit HTTP status, business code,
82 | // message and data.
83 | func ErrorCustom(c *gin.Context, httpCode, code int, msg string, data interface{}) {
84 | 	c.JSON(httpCode, Response{
85 | 		Code:    code,
86 | 		Message: msg,
87 | 		Data:    data,
88 | 	})
89 | }
90 | 
91 | // ParamError writes HTTP 400 with the given business code and message; Data is null.
92 | func ParamError(c *gin.Context, code int, msg string) {
93 | 	c.JSON(http.StatusBadRequest, Response{
94 | 		Code:    code,
95 | 		Message: msg,
96 | 	})
97 | }
98 | 
99 | // UnauthorizedError writes HTTP 401 with the given business code and message; Data is null.
100 | func UnauthorizedError(c *gin.Context, code int, msg string) {
101 | 	c.JSON(http.StatusUnauthorized, Response{
102 | 		Code:    code,
103 | 		Message: msg,
104 | 	})
105 | }
106 | 
107 | // InternalError writes HTTP 500 with the given business code and message; Data is null.
108 | func InternalError(c *gin.Context, code int, msg string) {
109 | 	c.JSON(http.StatusInternalServerError, Response{
110 | 		Code:    code,
111 | 		Message: msg,
112 | 	})
113 | }
114 | 
--------------------------------------------------------------------------------