├── .gitignore
├── README.md
├── README_EN.md
├── cmd
└── main.go
├── config
├── config.go
├── config.yaml.example
└── types.go
├── docker-compose.yml
├── docs
├── AgentChat.png
├── AgentDebug.png
├── CONFIG_README.md
└── QUICKSTART.md
├── go.mod
├── go.sum
├── init.sql
├── internal
├── component
│ ├── embedding
│ │ ├── embedding.go
│ │ ├── ollama.go
│ │ └── openai.go
│ ├── indexer
│ │ └── milvus
│ │ │ ├── indexer.go
│ │ │ └── types.go
│ ├── llm
│ │ ├── llm.go
│ │ ├── ollama
│ │ │ ├── call_option.go
│ │ │ └── ollama.go
│ │ └── openai
│ │ │ └── openai.go
│ ├── parser
│ │ └── pdf
│ │ │ └── docconv_parser.go
│ └── retriever
│ │ └── milvus
│ │ ├── multi_retriever.go
│ │ ├── retriever.go
│ │ └── types.go
├── controller
│ ├── agent_controller.go
│ ├── conversation_controller.go
│ ├── file_controller.go
│ ├── kb_controller.go
│ ├── model_controller.go
│ └── user_controller.go
├── dao
│ ├── agent.go
│ ├── file_dao.go
│ ├── history
│ │ ├── conv_dao.go
│ │ └── msg_dao.go
│ ├── kb_dao.go
│ ├── model_dao.go
│ └── user_dao.go
├── database
│ ├── milvus.go
│ └── mysql.go
├── middleware
│ ├── auth.go
│ ├── cors.go
│ └── jwt.go
├── model
│ ├── agent.go
│ ├── chat.go
│ ├── conversation.go
│ ├── file.go
│ ├── knowledge.go
│ ├── model.go
│ └── user.go
├── router
│ └── router.go
├── service
│ ├── agent_service.go
│ ├── conversation_service.go
│ ├── file_service.go
│ ├── history_service.go
│ ├── kb_service.go
│ ├── model_service.go
│ └── user_service.go
├── storage
│ ├── local.go
│ ├── minio.go
│ ├── oss.go
│ └── storage.go
└── utils
│ ├── agent_utils.go
│ ├── context.go
│ ├── convert.go
│ ├── convert_float.go
│ ├── hitory_utils.go
│ ├── pagination.go
│ ├── uuid.go
│ └── validate_sort.go
└── pkgs
├── consts
└── milvus_const.go
├── errcode
└── errcode.go
└── response
└── response.go
/.gitignore:
--------------------------------------------------------------------------------
1 | # 编译生成的二进制文件
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 | *.out
8 |
9 | # 测试文件
10 | *.test
11 | *.out
12 | coverage.txt
13 |
14 | # Go 专用
15 | *.o
16 | *.a
17 | *.so
18 | _obj
19 | _test
20 | *.[568vq]
21 | [568vq].out
22 | *.cgo1.go
23 | *.cgo2.c
24 | _cgo_defun.c
25 | _cgo_gotypes.go
26 | _cgo_export.*
27 | _testmain.go
28 | go.work
29 |
30 | # 依赖管理
31 | vendor/
32 | Godeps/
33 | go-build
34 |
35 | # 日志和临时文件
36 | logs/
37 | *.log
38 | tmp/
39 | temp/
40 |
41 | # 特定 IDE 和编辑器
42 | .idea/
43 | .vscode/
44 | *.swp
45 | *.swo
46 | *~
47 | .DS_Store
48 |
49 | # 环境和配置文件(注意不要排除 config 目录)
50 | .env
51 | .envrc
52 | env.sh
53 | .env.local
54 | .env.development.local
55 | .env.test.local
56 | .env.production.local
57 |
58 | # 实际的配置文件(不排除示例文件)
59 | config.yaml
60 | config.yml
61 | config.json
62 | config.toml
63 | config.ini
64 | config.hcl
65 | config.props
66 | config.properties
67 | config/config.yaml
68 | config/config.yml
69 | config/config.json
70 | config/config.toml
71 |
72 | # 明确说明不排除 config 目录和示例文件
73 | !config/
74 | !*.example
75 | !*.example.*
76 | !*.sample
77 | !*.sample.*
78 | !*.template
79 | !*.template.*
80 |
81 | # 敏感信息和凭证
82 | *.pem
83 | *.key
84 | *.crt
85 | *.p12
86 | *.pfx
87 | *.der
88 | *.csr
89 | *.cer
90 | *.keystore
91 | *.jks
92 | *.p7b
93 | *.p7c
94 | *.secret
95 | *.password
96 | *.credentials
97 | *secret*
98 | *password*
99 | *credential*
100 | *token*
101 | *apikey*
102 | *api_key*
103 |
104 | # 其他可能的敏感文件
105 | *.bak
106 | *.backup
107 |
108 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AI-Cloud-Go 基于Golang开发的LLM应用系统
2 |
3 | [English](README_EN.md) | [中文](README.md)
4 |
5 | ## 项目简介
6 | AI-Cloud-Go 是一个基于 Go 语言实现的 LLM 应用开发系统,提供文件存储、用户管理、知识库管理、模型管理、Agent 等功能,采用现代化的技术栈和架构设计。系统支持多种存储后端,并集成了向量数据库以支持智能检索功能。
7 |
8 | 前端界面展示: [AI-Cloud-Frontend](https://github.com/RaspberryCola/AI-Cloud-Frontend)
11 |
12 |
13 | ## 功能说明
14 |
15 | ### 部分功能展示
16 |
17 | **Agent配置**
18 |
19 | 
20 |
21 | **Agent Chat**
22 | 
23 |
24 | **已经实现:**
25 |
26 | - [x] 用户模块:支持用户注册、登录、认证
27 | - [x] 多存储后端:支持本地存储、MinIO、阿里云OSS等多种存储方式
28 | - [x] 文件模块:支持文件上传、下载、管理
29 | - [x] 知识库模块:支持创建和管理知识库,支持导入云盘文件或上传新文件
30 | - [x] 模型模块:支持创建和管理自定义LLM模型和Embedding模型
31 | - [x] Agent模块:支持创建和管理Agent
32 | - [x] 支持自定义LLM、知识库、MCP
33 | - [x] 对话界面、历史对话
34 |
35 | **未来优化**
36 |
37 | - 知识库模块:多文件上传/优化解析状态处理/支持Rerank
38 | - Agent模块:自定义Tool(HTTP工具)/优化LLM的参数配置/跨知识库检索的Rerank实现
39 | - 模型管理:添加常用模型预设:OpenAI,DeepSeek,火山引擎等/支持Rerank模型
40 |
41 | ## 技术栈
42 | - 后端框架:Gin
43 | - 数据库:MySQL
44 | - 向量数据库:Milvus
45 | - 对象存储:MinIO/阿里云OSS
46 | - 认证:JWT
47 | - LLM框架:Eino
48 | - 其他:跨域中间件、自定义中间件等
49 |
50 | ## 目录结构
51 | ```
52 | .
53 | ├── cmd/ # 主程序入口
54 | ├── config/ # 配置文件
55 | ├── docs/ # 文档目录
56 | │ ├── CONFIG_README.md # 配置指南
57 | │ └── QUICKSTART.md # 快速启动指南
58 | ├── internal/ # 内部包
59 | │ ├── component/ # 大模型相关服务
60 | │ ├── controller/ # 控制器层
61 | │ ├── service/ # 业务逻辑层
62 | │ ├── dao/ # 数据访问层
63 | │ ├── middleware/ # 中间件
64 | │ ├── router/ # 路由配置
65 | │ ├── database/ # 数据库(MySQL/Milvus...)
66 | │ ├── model/ # 数据模型
67 | │ ├── storage/ # 后端存储实现(Minio/OSS...)
68 | │ └── utils/ # 工具函数
69 | ├── pkgs/ # 公共包
70 | ├── docker-compose.yml # Docker配置文件
71 | ├── go.mod # Go 模块文件
72 | └── go.sum # 依赖版本锁定文件
73 | ```
74 |
75 | ## 使用说明
76 | 1. 克隆项目
77 | ```bash
78 | git clone https://github.com/RaspberryCola/AI-Cloud-Go.git
79 | cd AI-Cloud-Go
80 | ```
81 |
82 | 2. 安装依赖
83 | ```bash
84 | go mod download
85 | ```
86 |
87 | 3. 配置环境
88 | - 确保已安装并启动 MySQL
89 | - 确保已安装并启动 Milvus(如需使用向量检索功能)
90 | - 配置存储后端(本地存储/MinIO/阿里云OSS)
91 |
92 | 可以通过Docker快速配置:
93 | ```bash
94 | docker-compose up -d
95 | ```
96 | 4. 修改配置信息
97 | - 修改 `config/config.yaml` 中的相关配置
98 |
99 | 5. 运行项目
100 | ```bash
101 | go run cmd/main.go
102 | ```
103 |
104 | # AI-Cloud-Go Docker 完整环境配置
105 |
106 | > 更多详情请看 [docs 文件夹](/docs/)
107 |
108 | ## 前置条件
109 | - 已安装 Docker 和 Docker Compose
110 | - 基本了解 Docker 的使用
111 |
112 | ## 服务组件
113 | 本配置包含以下服务:
114 | - **MySQL**: 数据库服务器,配置 `ai_cloud` 数据库
115 | - **MinIO**: 对象存储服务,兼容 S3 协议,配置 `ai-cloud` 存储桶
116 | - **Milvus**: 向量数据库,用于存储和检索文档向量
117 |
118 | ## 启动步骤
119 |
120 | 1. 环境配置
121 | - 修改 `config/config.yaml` 文件,配置各服务连接信息和LLM的API密钥信息
122 | ```yaml
123 | llm:
124 | api_key: "your-llm-api-key"
125 | model: "deepseek-chat"
126 | base_url: "https://api.deepseek.com/v1"
127 | ```
128 | ⚠️语言模型配置后续将会移动到统一的模型服务管理中
129 | 2. 启动 Docker 容器
130 | ```bash
131 | docker-compose up -d
132 | ```
133 |
134 | 3. 检查服务状态
135 | ```bash
136 | docker ps
137 | ```
138 |
139 | 4. 运行项目
140 | ```bash
141 | go run cmd/main.go
142 | ```
143 |
144 | 5. 停止容器
145 | ```bash
146 | docker-compose down
147 | ```
148 |
149 | ## 配置详情
150 |
151 | ### MySQL
152 | - 主机: localhost
153 | - 端口: 3306
154 | - 用户名: root
155 | - 密码: 123456
156 | - 数据库: ai_cloud
157 |
158 | ### MinIO (对象存储)
159 | - 端点: localhost:9000
160 | - 管理控制台: http://localhost:9001
161 | - 访问密钥: minioadmin
162 | - 密钥: minioadmin
163 | - 存储桶: ai-cloud
164 |
165 | ### Milvus (向量数据库)
166 | - 配置路径: `milvus.address` 在 `config.yaml` 中
167 | - 默认地址: localhost:19530
168 | - 管理界面: 需要额外安装Attu (Milvus官方GUI工具)
169 | - 在端口9091只提供监控信息: http://localhost:9091/webui (Milvus 2.5.0+版本)
170 | - 完整管理界面需安装Attu: `docker run -p 8000:3000 -e MILVUS_URL=localhost:19530 zilliz/attu:latest`
171 | - 访问Attu: http://localhost:8000
172 |
173 | ### 环境配置
174 | 项目使用 `config/config.yaml` 文件配置服务连接和第三方 AI 模型的访问,主要包括:
175 |
176 | 1. **语言模型配置**
177 | - 用于知识库问答和智能处理
178 | - 默认使用 DeepSeek 的 deepseek-chat 模型
179 |
180 | 2. **Milvus配置**
181 | - 用于向量存储和检索
182 | - 配置项: `milvus.address`
183 |
184 | ## 故障排除
185 |
186 | ### 常见问题
187 |
188 | 1. **程序启动卡住**
189 | - 检查 Milvus 是否正常启动,查看日志 `docker logs milvus-standalone`
190 | - 检查 `config/config.yaml` 文件中的 Milvus 地址配置是否正确
191 | - 检查 `config/config.yaml` 文件是否配置了正确的 API 密钥
192 | - 确保 MySQL 中已创建 ai_cloud 数据库
193 |
194 | 2. **MinIO 连接问题**
195 | - 检查 MinIO 服务是否正常运行
196 | - 验证 config.yaml 中的 MinIO 配置是否与实际运行环境一致
197 |
198 | 3. **向量数据库操作失败**
199 | - 检查 Milvus 服务状态
200 | - 确认向量维度设置是否与模型输出一致
201 | - 确认 config.yaml 中的 milvus.address 配置是否正确
202 |
203 | 4. **Milvus 管理界面访问失败**
204 | - Milvus 2.5以上版本在端口9091提供简易WebUI: http://localhost:9091/webui
205 | - 完整的管理界面需要安装Attu工具: `docker run -p 8000:3000 -e MILVUS_URL=localhost:19530 zilliz/attu:latest`
206 |
207 | ## 开发调试
208 |
209 | ### API 测试
210 | 项目启动后,可通过以下端点进行测试:
211 | - 用户注册: POST http://localhost:8080/api/users/register
212 | - 用户登录: POST http://localhost:8080/api/users/login
213 | - 文件上传: POST http://localhost:8080/api/files/upload
214 | - 更多 API 详见代码中的路由配置
215 |
216 | ### 服务地址
217 | - 应用后端: http://localhost:8080
218 | - MinIO 控制台: http://localhost:9001
219 | - Milvus 管理界面: http://localhost:9091
220 | - Ollama 服务: http://localhost:11434 (如果启用)
221 |
222 | ## 注意事项
223 | - 如果您已经在本地运行了 MySQL,可以直接使用 `mysql -u root -p < init.sql` 创建 ai_cloud 数据库
224 | - 首次启动时,Milvus 会自动创建必要的集合和索引,可能需要一些时间
225 | - 使用生产环境时,请替换配置文件中的示例 API 密钥为您自己的有效密钥
226 | - 配置 config.yaml 时确保数据库和存储服务的连接信息与 Docker 环境一致
227 | - Go主程序使用端口8080,Ollama服务使用端口11434,避免端口冲突
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
1 | # AI-Cloud-Go — A Golang-Based Cloud Drive & Knowledge Base System
2 |
3 | [English](README_EN.md) | [中文](README.md)
4 |
5 | ## Overview
6 | AI-Cloud-Go is a cloud drive and knowledge base system built using the Go programming language. It offers features such as file storage, user management, knowledge base management, and model management, leveraging modern technology stacks and architecture design. The system supports multiple storage backends and integrates with a vector database to enable intelligent search capabilities.
7 |
8 | Frontend repository: [AI-Cloud-Frontend](https://github.com/RaspberryCola/AI-Cloud-Frontend)
11 |
12 | ## Key Features
13 |
14 | ### Completed:
15 | - [x] User system: Supports registration, login, and authentication
16 | - [x] Cloud file system: Supports file upload, download, and management
17 | - [x] Knowledge base management: Create and manage knowledge bases, import files from cloud or upload new ones
18 | - [x] Model management: Manage custom LLM and Embedding models
19 | - [x] Multiple storage backends: Supports local storage, MinIO, Alibaba Cloud OSS, etc.
20 | - [x] Vector search: Integrated with Milvus vector database for smart document retrieval
21 |
22 | ### In Development:
23 | - [ ] Agent functionality
24 |
25 | ## Tech Stack
26 | - Backend Framework: Gin
27 | - Database: MySQL
28 | - Vector Database: Milvus
29 | - Object Storage: MinIO / Alibaba Cloud OSS
30 | - Authentication: JWT (JSON Web Token)
31 | - LLM Framework: Eino
32 | - Others: CORS middleware, custom middleware, etc.
33 |
34 | ## Directory Structure
35 | ```
36 | .
37 | ├── cmd/ # Main application entry point
38 | ├── config/ # Configuration files
39 | ├── docs/ # Documentation
40 | │ ├── CONFIG_README.md # Configuration guide
41 | │ └── QUICKSTART.md # Quick start guide
42 | ├── internal/ # Internal packages
43 | │ ├── component/ # Large model related services
44 | │ ├── controller/ # Controller layer
45 | │ ├── service/ # Business logic layer
46 | │ ├── dao/ # Data access layer
47 | │ ├── middleware/ # Middleware components
48 | │ ├── router/ # Routing configuration
49 | │ ├── database/ # Database (MySQL/Milvus...)
50 | │ ├── model/ # Data models
51 | │ ├── storage/ # Storage backend implementations (MinIO/OSS...)
52 | │ └── utils/ # Utility functions
53 | ├── pkgs/ # Shared packages
54 | ├── docker-compose.yml # Docker configuration
55 | ├── go.mod # Go module file
56 | └── go.sum # Dependency version lock
57 | ```
58 |
59 | ## Usage Instructions
60 |
61 | 1. Clone the project:
62 | ```bash
63 | git clone https://github.com/RaspberryCola/AI-Cloud-Go.git
64 | cd AI-Cloud-Go
65 | ```
66 |
67 | 2. Install dependencies:
68 | ```bash
69 | go mod download
70 | ```
71 |
72 | 3. Environment setup:
73 | - Ensure MySQL is installed and running
74 | - Ensure Milvus is installed if using vector search
75 | - Configure storage backend (local/MinIO/Alibaba Cloud OSS)
76 |
77 | You can use Docker Compose for quick setup:
78 | ```bash
79 | docker-compose up -d
80 | ```
81 |
82 | 4. Modify configuration:
83 | - Update `config/config.yaml` accordingly
84 |
85 | 5. Run the project:
86 | ```bash
87 | go run cmd/main.go
88 | ```
89 |
90 | # AI-Cloud-Go — Docker Full Environment Setup
91 |
92 | > For more details, see the [docs folder](/docs/)
93 |
94 | ## Prerequisites
95 | - Docker and Docker Compose installed
96 | - Basic understanding of Docker usage
97 |
98 | ## Services Included
99 | This setup includes the following services:
100 | - **MySQL**: Database server, configured with the `ai_cloud` database
101 | - **MinIO**: Object storage service compatible with S3 protocol, configured with bucket `ai-cloud`
102 | - **Milvus**: Vector database for storing and retrieving document vectors
103 |
104 | ## Setup Steps
105 |
106 | 1. **Environment Configuration**
107 | - Edit `config/config.yaml` to configure service connection info and LLM API keys
108 | ```yaml
109 | llm:
110 | api_key: "your-llm-api-key"
111 | model: "deepseek-chat"
112 | base_url: "https://api.deepseek.com/v1"
113 | ```
114 | ⚠️ LLM configuration will be moved into unified model management in the future
115 |
116 | 2. **Start Docker Containers**
117 | ```bash
118 | docker-compose up -d
119 | ```
120 |
121 | 3. **Check Service Status**
122 | ```bash
123 | docker ps
124 | ```
125 |
126 | 4. **Run the Application**
127 | ```bash
128 | go run cmd/main.go
129 | ```
130 |
131 | 5. **Stop Containers**
132 | ```bash
133 | docker-compose down
134 | ```
135 |
136 | ## Configuration Details
137 |
138 | ### MySQL
139 | - Host: localhost
140 | - Port: 3306
141 | - Username: root
142 | - Password: 123456
143 | - Database: ai_cloud
144 |
145 | ### MinIO (Object Storage)
146 | - Endpoint: localhost:9000
147 | - Management Console: http://localhost:9001
148 | - Access Key: minioadmin
149 | - Secret Key: minioadmin
150 | - Bucket: ai-cloud
151 |
152 | ### Milvus (Vector Database)
153 | - Config path: `milvus.address` in `config.yaml`
154 | - Default address: localhost:19530
155 | - Admin UI: Requires installing Attu (official Milvus GUI tool)
156 | - Lightweight web UI on port 9091: http://localhost:9091/webui (for Milvus 2.5.0+)
157 | - Full admin interface via Attu:
158 | ```bash
159 | docker run -p 8000:3000 -e MILVUS_URL=localhost:19530 zilliz/attu:latest
160 | ```
161 | - Visit Attu at: http://localhost:8000
162 |
163 | ### Environment Configuration
164 | The system uses `config/config.yaml` to configure service connections and AI model access, including:
165 |
166 | 1. **Language Model Configuration**
167 | - Used for Q&A and intelligent processing
168 | - Default: DeepSeek's `deepseek-chat` model
169 |
170 | 2. **Milvus Configuration**
171 | - For vector storage and retrieval
172 | - Config: `milvus.address`
173 |
174 | ## Troubleshooting
175 |
176 | ### Common Issues
177 |
178 | 1. **Application Hangs on Startup**
179 | - Check if Milvus started correctly: `docker logs milvus-standalone`
180 | - Confirm Milvus address in `config.yaml` is correct
181 | - Ensure valid API key is set in config.yaml
182 | - Make sure the `ai_cloud` database exists in MySQL
183 |
184 | 2. **MinIO Connection Problems**
185 | - Verify MinIO is running
186 | - Confirm config.yaml matches actual MinIO environment settings
187 |
188 | 3. **Vector DB Operations Fail**
189 | - Check Milvus service status
190 | - Ensure vector dimension matches model output
191 | - Confirm `milvus.address` is properly set
192 |
193 | 4. **Cannot Access Milvus Web UI**
194 | - Milvus 2.5+ provides basic UI at http://localhost:9091/webui
195 | - For full UI, install Attu with:
196 | ```bash
197 | docker run -p 8000:3000 -e MILVUS_URL=localhost:19530 zilliz/attu:latest
198 | ```
199 | - Access at: http://localhost:8000
200 |
201 | ## Development and Debugging
202 |
203 | ### API Testing
204 | After starting the app, you can test APIs like:
205 | - User Registration: POST http://localhost:8080/api/users/register
206 | - User Login: POST http://localhost:8080/api/users/login
207 | - File Upload: POST http://localhost:8080/api/files/upload
208 | - See code for more endpoints
209 |
210 | ### Service URLs
211 | - App Backend: http://localhost:8080
212 | - MinIO Console: http://localhost:9001
213 | - Milvus Admin UI: http://localhost:9091
214 | - Ollama Service: http://localhost:11434 (if enabled)
215 |
216 | ## Notes
217 | - If you already have MySQL running locally, create the db via:
218 | ```bash
219 | mysql -u root -p < init.sql
220 | ```
221 | - Milvus automatically creates collections/indexes on first launch (may take time)
222 | - Replace sample API keys in config.yaml with your own credentials before production
223 | - Ensure config.yaml reflects Docker service addresses
224 | - Go server uses port 8080; Ollama uses 11434 – avoid conflicts
225 |
226 |
227 | ## Third-Party Licenses
228 |
229 | This project uses the following third-party libraries:
230 |
231 | - **eino-history**
232 | - Source: [https://github.com/HildaM/eino-history](https://github.com/HildaM/eino-history)
233 | - License: [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0)
234 | - Description: A chat history management library for Eino large model framework.
235 |
236 | Modifications have been made to the original code to suit the needs of this project.
--------------------------------------------------------------------------------
/cmd/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
import (
	"context"
	"fmt"
	"log"

	"ai-cloud/config"
	_ "ai-cloud/internal/component/embedding"
	"ai-cloud/internal/controller"
	"ai-cloud/internal/dao"
	"ai-cloud/internal/dao/history"
	"ai-cloud/internal/database"
	"ai-cloud/internal/middleware"
	"ai-cloud/internal/router"
	"ai-cloud/internal/service"

	"github.com/gin-gonic/gin"
)
17 |
18 | func main() {
19 | config.InitConfig()
20 | ctx := context.Background()
21 |
22 | db, _ := database.GetDB()
23 |
24 | userDao := dao.NewUserDao(db)
25 | userService := service.NewUserService(userDao)
26 | userController := controller.NewUserController(userService)
27 | fileDao := dao.NewFileDao(db)
28 | fileService := service.NewFileService(fileDao)
29 | fileController := controller.NewFileController(fileService)
30 |
31 | milvusClient, _ := database.InitMilvus(ctx)
32 | defer milvusClient.Close()
33 |
34 | modelDao := dao.NewModelDao(db)
35 | modelService := service.NewModelService(modelDao)
36 | modelController := controller.NewModelController(modelService)
37 |
38 | kbDao := dao.NewKnowledgeBaseDao(db)
39 | kbService := service.NewKBService(kbDao, fileService, modelDao)
40 | kbController := controller.NewKBController(kbService, fileService)
41 |
42 | msgDao := history.NewMsgDao(db)
43 | convDao := history.NewConvDao(db)
44 | historyService := service.NewHistoryService(convDao, msgDao)
45 |
46 | agentDao := dao.NewAgentDao(db)
47 | agentService := service.NewAgentService(agentDao, modelService, kbService, kbDao, modelDao, historyService)
48 | agentController := controller.NewAgentController(agentService)
49 |
50 | // 创建ConversationService和ConversationController
51 | conversationService := service.NewConversationService(agentService, historyService)
52 | conversationController := controller.NewConversationController(conversationService)
53 |
54 | r := gin.Default()
55 | // 配置跨域
56 | r.Use(middleware.SetupCORS())
57 | // 配置路由
58 | router.SetUpRouters(r, userController, fileController, kbController, modelController, agentController, conversationController)
59 |
60 | r.Run(":8080")
61 | }
62 |
--------------------------------------------------------------------------------
/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "log"
5 |
6 | "github.com/fsnotify/fsnotify"
7 | "github.com/spf13/viper"
8 | )
9 |
10 | var AppConfigInstance *AppConfig
11 |
12 | // InitConfig 初始化配置
13 | func InitConfig() {
14 | // 初始化 AppConfigInstance
15 | AppConfigInstance = &AppConfig{}
16 |
17 | // 加载配置
18 | viper.SetConfigName("config")
19 | viper.SetConfigType("yaml")
20 | viper.AddConfigPath("./config")
21 |
22 | if err := viper.ReadInConfig(); err != nil {
23 | log.Fatalf("Error reading config file: %v", err)
24 | }
25 |
26 | // 监听配置变化
27 | viper.WatchConfig()
28 | viper.OnConfigChange(func(e fsnotify.Event) {
29 | if err := viper.Unmarshal(AppConfigInstance); err != nil {
30 | log.Printf("loadConfig failed, unmarshal config err: %v", err)
31 | }
32 | })
33 |
34 | // 解析配置
35 | if err := viper.Unmarshal(AppConfigInstance); err != nil {
36 | log.Fatalf("Unable to decode into struct: %v", err)
37 | }
38 | }
39 |
40 | // GetConfig 获取配置
41 | func GetConfig() *AppConfig {
42 | return AppConfigInstance
43 | }
44 |
--------------------------------------------------------------------------------
/config/config.yaml.example:
--------------------------------------------------------------------------------
1 | server:
2 | port: "8080"
3 |
4 | # mysql配置
5 | database:
6 | host: "localhost"
7 | port: "3306"
8 | user: "root"
9 | password: "123456"
10 | name: "ai_cloud"
11 |
12 | jwt:
13 | secret: "your-jwt-secret"
14 | expiration_hours: 24
15 |
16 | # 后端文件存储服务
17 | storage:
18 | type: "minio" # local, oss, minio
19 | local:
20 | base_dir: "./storage_data"
21 | oss:
22 | endpoint: "your-oss-endpoint"
23 | bucket: "your-oss-bucket"
24 | access_key_id: "your-access-key-id"
25 | access_key_secret: "your-access-key-secret"
26 | minio:
27 | endpoint: "localhost:9000"
28 | bucket: "ai-cloud"
29 | access_key_id: "minioadmin"
30 | access_key_secret: "minioadmin"
31 | use_ssl: false
32 | region: ""
33 |
34 | # Milvus向量数据库配置
35 | milvus:
36 | address: "localhost:19530"
37 | # collection_name: "text_chunks"
38 | # vector_dimension: 1024
39 | index_type: "IVF_FLAT"
40 | metric_type: "COSINE"
41 | nlist: 128
42 | # 搜索参数
43 | nprobe: 16
44 | # 字段最大长度配置
45 | id_max_length: "64"
46 | content_max_length: "65535"
47 | doc_id_max_length: "64"
48 | doc_name_max_length: "256"
49 | kb_id_max_length: "64"
50 |
51 | rag:
52 | chunk_size: 1500
53 | overlap_size: 500
54 |
55 | cors:
56 | allow_origins:
57 | - "*"
58 | allow_methods:
59 | - "GET"
60 | - "POST"
61 | - "PUT"
62 | - "PATCH"
63 | - "DELETE"
64 | - "OPTIONS"
65 | allow_headers:
66 | - "Origin"
67 | - "Content-Type"
68 | - "Accept"
69 | - "Authorization"
70 | expose_headers:
71 | - "Content-Length"
72 | allow_credentials: true
73 | max_age: "12h"
74 |
75 | # LLM配置
76 | llm:
77 | api_key: "your-llm-api-key"
78 | model: "deepseek-chat"
79 | base_url: "https://api.deepseek.com/v1"
80 | max_tokens: 10240
81 | temperature: 0.7
--------------------------------------------------------------------------------
/config/types.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "github.com/milvus-io/milvus-sdk-go/v2/entity"
5 | )
6 |
// ServerConfig holds the HTTP server settings (the `server` config section).
type ServerConfig struct {
	// Port is the listen port, kept as a string.
	Port string `mapstructure:"port"`
}
11 |
// DatabaseConfig holds the database connection settings (the `database`
// config section).
type DatabaseConfig struct {
	// Host and Port locate the database server; Port is kept as a string.
	Host string `mapstructure:"host"`
	Port string `mapstructure:"port"`
	// User and Password are the login credentials.
	User     string `mapstructure:"user"`
	Password string `mapstructure:"password"`
	// Name is the database (schema) name.
	Name string `mapstructure:"name"`
}
20 |
// JWTConfig holds the JWT settings (the `jwt` config section).
type JWTConfig struct {
	// Secret is the configured JWT secret.
	Secret string `mapstructure:"secret"`
	// ExpirationHours is the configured expiration, in hours.
	ExpirationHours int `mapstructure:"expiration_hours"`
}
26 |
// MinioConfig holds the MinIO object-storage settings (the `storage.minio`
// config section).
type MinioConfig struct {
	// Endpoint is the MinIO address, e.g. "localhost:9000".
	Endpoint string `mapstructure:"endpoint"`
	// Bucket is the bucket name.
	Bucket string `mapstructure:"bucket"`
	// AccessKeyID and AccessKeySecret are the access credentials.
	AccessKeyID     string `mapstructure:"access_key_id"`
	AccessKeySecret string `mapstructure:"access_key_secret"`
	// UseSSL is the `use_ssl` switch.
	UseSSL bool `mapstructure:"use_ssl"`
	// Region is the bucket region; it may be empty.
	Region string `mapstructure:"region"`
}
36 |
// MilvusConfig holds the Milvus vector-database settings (the `milvus`
// config section).
type MilvusConfig struct {
	// Address is the Milvus service address, e.g. "localhost:19530".
	Address string `mapstructure:"address"`
	// CollectionName and VectorDimension map the `collection_name` and
	// `vector_dimension` keys.
	CollectionName  string `mapstructure:"collection_name"`
	VectorDimension int    `mapstructure:"vector_dimension"`
	// IndexType selects the index kind and MetricType the distance metric;
	// see GetMilvusIndex and GetMetricType for the recognised values.
	IndexType  string `mapstructure:"index_type"`
	MetricType string `mapstructure:"metric_type"`
	// Nlist is the cluster count used for IVF indexes.
	Nlist int `mapstructure:"nlist"`

	// Search parameter.
	Nprobe int `mapstructure:"nprobe"`

	// Maximum field lengths, kept as strings.
	IDMaxLength      string `mapstructure:"id_max_length"`
	ContentMaxLength string `mapstructure:"content_max_length"`
	DocIDMaxLength   string `mapstructure:"doc_id_max_length"`
	DocNameMaxLength string `mapstructure:"doc_name_max_length"`
	KbIDMaxLength    string `mapstructure:"kb_id_max_length"`
}
54 |
55 | // GetMetricType 获取类型
56 | func (m *MilvusConfig) GetMetricType() entity.MetricType {
57 | // 获取配置的度量类型
58 | var metricType entity.MetricType
59 | switch m.MetricType {
60 | case "L2":
61 | metricType = entity.L2 // 欧几里得距离:测量向量间的直线距离,适合图像特征等数值型向量
62 | case "IP":
63 | metricType = entity.IP // 内积距离:适合已归一化的向量,计算效率高
64 | default:
65 | metricType = entity.COSINE // 余弦相似度:测量向量方向的相似性,适合文本语义搜索
66 | }
67 | return metricType
68 | }
69 |
70 | // GetMilvusIndex 根据配置构建索引
71 | func (m *MilvusConfig) GetMilvusIndex() (entity.Index, error) {
72 | // 选择索引类型的距离度量方式
73 | metricType := m.GetMetricType()
74 |
75 | // 创建索引
76 | var (
77 | idx entity.Index
78 | err error
79 | )
80 | if m.Nlist <= 0 {
81 | m.Nlist = 128 // 为空,取默认值
82 | }
83 |
84 | switch m.IndexType {
85 | case "IVF_FLAT":
86 | // IVF_FLAT: 倒排文件索引 + 原始向量存储
87 | // 优点:搜索精度高;缺点:内存占用较大
88 | // nlist: 聚类数量,值越大精度越高但速度越慢,通常设置为 sqrt(n) 到 4*sqrt(n),其中n为向量数量
89 | idx, err = entity.NewIndexIvfFlat(metricType, m.Nlist)
90 | case "IVF_SQ8":
91 | // IVF_SQ8: 倒排文件索引 + 标量量化压缩存储(8位)
92 | // 优点:比IVF_FLAT节省内存;缺点:轻微精度损失
93 | // nlist: 与IVF_FLAT相同,根据数据规模调整
94 | idx, err = entity.NewIndexIvfSQ8(metricType, m.Nlist)
95 | case "HNSW":
96 | // HNSW: 层次可导航小世界图索引,高效且精确但内存占用大
97 | // M: 每个节点的最大边数,影响图的连通性和构建/查询性能
98 | // - 值越大,构建越慢,内存占用越大,但查询越精确
99 | // - 通常取值范围为8-64之间,默认值8在大多数场景下平衡了性能和精度
100 | // efConstruction: 构建索引时每层搜索的候选邻居数量
101 | // - 值越大,构建越慢,索引质量越高
102 | // - 通常取值范围为40-800,默认值40在大多数场景下表现良好
103 | // 注:这两个参数需要根据数据特性和性能要求综合调优,目前使用经验值
104 | idx, err = entity.NewIndexHNSW(metricType, 8, 40) // M=8, efConstruction=40
105 | default:
106 | // 默认使用IVF_FLAT,兼顾搜索精度和性能
107 | idx, err = entity.NewIndexIvfFlat(metricType, m.Nlist)
108 | }
109 | return idx, err
110 | }
111 |
112 | // StorageConfig 存储配置
113 | type StorageConfig struct {
114 | Type string `mapstructure:"type"` // local/oss/minio
115 | Local LocalConfig `mapstructure:"local"`
116 | OSS OSSConfig `mapstructure:"oss"`
117 | Minio MinioConfig `mapstructure:"minio"`
118 | }
119 |
// LocalConfig holds the local-disk storage settings (the `storage.local`
// config section).
type LocalConfig struct {
	// BaseDir is the root directory for stored files (e.g. /data/storage).
	BaseDir string `mapstructure:"base_dir"`
}
124 |
// OSSConfig holds the OSS object-storage settings (the `storage.oss` config
// section).
type OSSConfig struct {
	// Endpoint is the OSS endpoint and Bucket the bucket name.
	Endpoint string `mapstructure:"endpoint"`
	Bucket   string `mapstructure:"bucket"`
	// AccessKeyID and AccessKeySecret are the access credentials.
	AccessKeyID     string `mapstructure:"access_key_id"`
	AccessKeySecret string `mapstructure:"access_key_secret"`
}
132 |
// CORSConfig holds the cross-origin settings (the `cors` config section).
type CORSConfig struct {
	// Allow/expose lists, one entry per origin, method or header.
	AllowOrigins  []string `mapstructure:"allow_origins"`
	AllowMethods  []string `mapstructure:"allow_methods"`
	AllowHeaders  []string `mapstructure:"allow_headers"`
	ExposeHeaders []string `mapstructure:"expose_headers"`
	// AllowCredentials is the `allow_credentials` switch.
	AllowCredentials bool `mapstructure:"allow_credentials"`
	// MaxAge is a duration written as a string (e.g. "12h") so it is easy to
	// configure.
	MaxAge string `mapstructure:"max_age"`
}
142 |
// RAGConfig holds the chunking settings (the `rag` config section).
type RAGConfig struct {
	// ChunkSize is the configured chunk size.
	ChunkSize int `mapstructure:"chunk_size"`
	// OverlapSize is the configured overlap size.
	OverlapSize int `mapstructure:"overlap_size"`
}
148 |
// LLMConfig holds the language-model settings (the `llm` config section).
type LLMConfig struct {
	// APIKey is the API key for the model service.
	APIKey string `mapstructure:"api_key"`
	// Model is the model name, e.g. "deepseek-chat".
	Model string `mapstructure:"model"`
	// BaseURL is the base URL of the model API.
	BaseURL string `mapstructure:"base_url"`
	// MaxTokens and Temperature map the `max_tokens` and `temperature` keys.
	MaxTokens   int     `mapstructure:"max_tokens"`
	Temperature float32 `mapstructure:"temperature"`
}
157 |
158 | // AppConfig 应用配置
159 | type AppConfig struct {
160 | Server ServerConfig `mapstructure:"server"`
161 | Database DatabaseConfig `mapstructure:"database"`
162 | JWT JWTConfig `mapstructure:"jwt"`
163 | Storage StorageConfig `mapstructure:"storage"`
164 | CORS CORSConfig `mapstructure:"cors"`
165 | RAG RAGConfig `mapstructure:"rag"`
166 | LLM LLMConfig `mapstructure:"llm"`
167 | Milvus MilvusConfig `mapstructure:"milvus"`
168 | }
169 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 |
3 | services:
4 | # MySQL service - create ai_cloud database
5 | mysql:
6 | image: mysql:8.0
7 | ports:
8 | - "3306:3306"
9 | volumes:
10 | - mysql_data:/var/lib/mysql
11 | - ./init.sql:/docker-entrypoint-initdb.d/init.sql
12 | environment:
13 | MYSQL_ROOT_PASSWORD: 123456
14 | MYSQL_DATABASE: ai_cloud
15 | networks:
16 | - ai-cloud-network
17 | restart: always
18 |
19 | # MinIO service
20 | minio:
21 | image: minio/minio
22 | ports:
23 | - "9000:9000"
24 | - "9001:9001"
25 | volumes:
26 | - minio_data:/data
27 | environment:
28 | MINIO_ROOT_USER: minioadmin
29 | MINIO_ROOT_PASSWORD: minioadmin
30 | command: server /data --console-address ":9001"
31 | networks:
32 | - ai-cloud-network
33 | healthcheck:
34 | test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
35 | interval: 30s
36 | timeout: 20s
37 | retries: 3
38 |
39 | # Create MinIO bucket
40 | minio-init:
41 | image: minio/mc
42 | depends_on:
43 | - minio
44 | entrypoint: >
45 | /bin/sh -c "
46 | /usr/bin/mc config host add myminio http://minio:9000 minioadmin minioadmin;
47 | /usr/bin/mc mb myminio/ai-cloud;
48 | exit 0;
49 | "
50 | networks:
51 | - ai-cloud-network
52 |
53 | # Milvus standalone service
54 | etcd:
55 | container_name: milvus-etcd
56 | image: quay.io/coreos/etcd:v3.5.5
57 | environment:
58 | - ETCD_AUTO_COMPACTION_MODE=revision
59 | - ETCD_AUTO_COMPACTION_RETENTION=1000
60 | - ETCD_QUOTA_BACKEND_BYTES=4294967296
61 | - ETCD_SNAPSHOT_COUNT=50000
62 | volumes:
63 | - etcd_data:/etcd
64 | command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
65 | networks:
66 | - ai-cloud-network
67 |
68 | minio-for-milvus:
69 | container_name: milvus-minio
70 | image: minio/minio:RELEASE.2023-03-20T20-16-18Z
71 | environment:
72 | MINIO_ACCESS_KEY: minioadmin
73 | MINIO_SECRET_KEY: minioadmin
74 | volumes:
75 | - minio_data_for_milvus:/minio_data
76 | command: minio server /minio_data
77 | healthcheck:
78 | test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
79 | interval: 30s
80 | timeout: 20s
81 | retries: 3
82 | networks:
83 | - ai-cloud-network
84 |
85 | standalone:
86 | container_name: milvus-standalone
87 | image: milvusdb/milvus:v2.4.1
88 | command: ["milvus", "run", "standalone"]
89 | environment:
90 | ETCD_ENDPOINTS: etcd:2379
91 | MINIO_ADDRESS: minio-for-milvus:9000
92 | volumes:
93 | - milvus_data:/var/lib/milvus
94 | ports:
95 | - "19530:19530"
96 | - "9091:9091"
97 | depends_on:
98 | - etcd
99 | - minio-for-milvus
100 | networks:
101 | - ai-cloud-network
102 |
103 | # Attu - Milvus管理界面
104 | attu:
105 | container_name: milvus-attu
106 | image: zilliz/attu:latest
107 | environment:
108 | MILVUS_URL: standalone:19530
109 | ports:
110 | - "8000:3000"
111 | depends_on:
112 | - standalone
113 | networks:
114 | - ai-cloud-network
115 |
116 | volumes:
117 | mysql_data:
118 | minio_data:
119 | etcd_data:
120 | minio_data_for_milvus:
121 | milvus_data:
122 |
123 | networks:
124 | ai-cloud-network:
125 | driver: bridge
--------------------------------------------------------------------------------
/docs/AgentChat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RaspberryCola/AI-Cloud-Go/aa509c7429f0a408319274896c2b8af8ca01e5be/docs/AgentChat.png
--------------------------------------------------------------------------------
/docs/AgentDebug.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RaspberryCola/AI-Cloud-Go/aa509c7429f0a408319274896c2b8af8ca01e5be/docs/AgentDebug.png
--------------------------------------------------------------------------------
/docs/CONFIG_README.md:
--------------------------------------------------------------------------------
1 | # AI-Cloud-Go 配置指南
2 |
3 | ## 配置文件结构
4 |
5 | 配置文件(`config/config.yaml`)结构如下:
6 |
7 | ```yaml
8 | server:
9 | port: "8080"
10 |
11 | database:
12 | host: "localhost"
13 | port: "3306"
14 | user: "root"
15 | password: "123456"
16 | name: "ai_cloud"
17 |
18 | jwt:
19 | secret: "your-jwt-secret"
20 | expiration_hours: 24
21 |
22 | storage:
23 | type: "minio" # local, oss, minio
24 | local:
25 | base_dir: "./storage_data"
26 | oss:
27 | endpoint: "oss-endpoint"
28 | bucket: "bucket-name"
29 | access_key_id: ""
30 | access_key_secret: ""
31 | minio:
32 | endpoint: "localhost:9000"
33 | bucket: "ai-cloud"
34 | access_key_id: "minioadmin"
35 | access_key_secret: "minioadmin"
36 | use_ssl: false
37 | region: ""
38 |
39 | # Milvus向量数据库配置
40 | milvus:
41 | address: "localhost:19530" # Milvus服务地址
42 | index_type: "IVF_FLAT" # 索引类型 (IVF_FLAT, IVF_SQ8, HNSW)
43 | metric_type: "COSINE" # 距离计算方式 (COSINE, L2, IP)
44 | nlist: 128 # IVF索引聚类数量
45 | # 搜索参数
46 | nprobe: 16 # 搜索时检查的聚类数量,值越大结果越精确但越慢
47 | # 字段最大长度配置
48 | id_max_length: "64" # ID字段最大长度
49 | content_max_length: "65535" # 内容字段最大长度
50 | doc_id_max_length: "64" # 文档ID字段最大长度
51 | doc_name_max_length: "256" # 文档名称字段最大长度
52 | kb_id_max_length: "64" # 知识库ID字段最大长度
53 |
54 | rag:
55 | chunk_size: 1500
56 | overlap_size: 500
57 |
58 | cors:
59 | # CORS配置...
60 |
61 | # 语言模型配置(后续移除,通过统一的模块管理)
62 | llm:
63 | api_key: "your-llm-api-key"
64 | model: "deepseek-chat"
65 | base_url: "https://api.deepseek.com/v1"
66 | max_tokens: 4096
67 | temperature: 0.7
68 | ```
69 |
70 | ## 配置使用
71 |
72 | 在Go代码中,可以通过以下方式访问配置:
73 |
74 | ```go
75 | import "ai-cloud/config"
76 |
77 | func main() {
78 | // 初始化配置(在应用启动时调用一次)
79 | config.InitConfig()
80 |
81 | // 获取配置实例
82 | cfg := config.GetConfig()
83 |
84 | // 访问配置项
85 | port := cfg.Server.Port
86 | embeddingService := cfg.Embedding.Service
87 | llmModel := cfg.LLM.Model
88 | milvusAddress := cfg.Milvus.Address
89 |
90 | // ...
91 | }
92 | ```
93 |
94 | ## Milvus向量数据库配置
95 |
96 | 配置Milvus服务连接地址和向量集合参数:
97 |
98 | ```yaml
99 | # Milvus向量数据库配置
100 | milvus:
101 | address: "localhost:19530" # Milvus服务地址
102 | index_type: "IVF_FLAT" # 索引类型 (IVF_FLAT, IVF_SQ8, HNSW)
103 | metric_type: "COSINE" # 距离计算方式 (COSINE, L2, IP)
104 | nlist: 128 # IVF索引聚类数量
105 | # 搜索参数
106 | nprobe: 16 # 搜索时检查的聚类数量,值越大结果越精确但越慢
107 | # 字段最大长度配置
108 | id_max_length: "64" # ID字段最大长度
109 | content_max_length: "65535" # 内容字段最大长度
110 | doc_id_max_length: "64" # 文档ID字段最大长度
111 | doc_name_max_length: "256" # 文档名称字段最大长度
112 | kb_id_max_length: "64" # 知识库ID字段最大长度
113 | ```
114 |
115 | 此配置在初始化Milvus客户端和创建集合时使用:
116 |
117 | ```go
118 | // 初始化Milvus客户端
119 | milvusClient, err := client.NewClient(ctx, client.Config{
120 | Address: config.GetConfig().Milvus.Address,
121 | })
122 |
123 | // 使用配置创建集合
124 | milvusConfig := config.GetConfig().Milvus
125 | address := milvusConfig.Address
126 | // ...
127 | ```
128 |
129 | ## 语言模型配置
130 |
131 | 配置LLM服务:
132 |
133 | ```yaml
134 | llm:
135 | api_key: "your-api-key"
136 | model: "deepseek-chat" # 或其他支持的模型
137 | base_url: "https://api.deepseek.com/v1"
138 | max_tokens: 4096
139 | temperature: 0.7
140 | ```
141 |
142 | ## 从环境变量迁移到配置文件
143 |
144 | 如果您之前使用`.env`文件配置项目,请按照以下对应关系迁移到`config.yaml`:
145 |
146 | | 环境变量 | 配置文件路径 |
147 | |---------|------------|
148 | | `LLM_API_KEY` | `llm.api_key` |
149 | | `LLM_MODEL` | `llm.model` |
150 | | `LLM_BASE_URL` | `llm.base_url` |
151 | | `MILVUS_ADDRESS` | `milvus.address` |
152 |
153 | ## 注意事项
154 |
155 | 1. 配置文件中的敏感信息(如API密钥)不应提交到版本控制系统
156 | 2. 可以考虑使用环境变量覆盖配置文件中的敏感信息
157 | 3. 为不同环境(开发、测试、生产)准备不同的配置文件
--------------------------------------------------------------------------------
/docs/QUICKSTART.md:
--------------------------------------------------------------------------------
1 | # AI-Cloud-Go 快速启动指南
2 |
3 | 这份指南将帮助您快速设置并运行AI-Cloud-Go系统,包括各种支持的嵌入服务配置。
4 |
5 | ## 前置条件
6 |
7 | - Go 1.23+(`go.mod` 中声明的版本为 `go 1.23.4`)
8 | - Docker和Docker Compose
9 | - Git
10 | - 确保以下端口未被占用: 8080, 3306, 9000, 9001, 19530, 9091, 8000(Attu 管理界面), 11434(Ollama)
11 |
12 | ## 步骤一:获取代码
13 |
14 | ```bash
15 | git clone https://github.com/RaspberryCola/AI-Cloud-Go.git
16 | cd AI-Cloud-Go
17 | ```
18 |
19 | ## 步骤二:设置环境
20 |
21 | ### 配置系统参数
22 |
23 | 1. 确保项目根目录存在`config`文件夹,如果不存在请创建
24 | 2. 在`config`文件夹中创建或修改`config.yaml`文件:
25 |
26 | ```yaml
27 | server:
28 | port: "8080"
29 |
30 | database:
31 | host: "localhost"
32 | port: "3306"
33 | user: "root"
34 | password: "123456"
35 | name: "ai_cloud"
36 |
37 | storage:
38 | type: "minio" # local, oss, minio
39 | minio:
40 | endpoint: "localhost:9000"
41 | bucket: "ai-cloud"
42 | access_key_id: "minioadmin"
43 | access_key_secret: "minioadmin"
44 | use_ssl: false
45 | region: ""
46 |
47 | # Milvus向量数据库配置
48 | milvus:
49 | address: "localhost:19530"
50 | index_type: "IVF_FLAT"
51 | metric_type: "COSINE"
52 | nlist: 128
53 | # 搜索参数
54 | nprobe: 16
55 | # 字段最大长度配置
56 | id_max_length: "64"
57 | content_max_length: "65535"
58 | doc_id_max_length: "64"
59 | doc_name_max_length: "256"
60 | kb_id_max_length: "64"
61 |
62 | # 语言模型配置(后续会移除到统一的模型管理中)
63 | llm:
64 | api_key: "your-llm-api-key" # 替换为您的语言模型API密钥
65 | model: "deepseek-chat"
66 | base_url: "https://api.deepseek.com/v1"
67 | max_tokens: 4096
68 | temperature: 0.7
69 |
70 | rag:
71 | chunk_size: 1500
72 | overlap_size: 500
73 | ```
74 |
75 | ## 步骤三:启动基础服务
76 |
77 | ### 启动所有Docker服务
78 |
79 | 确保项目根目录包含`docker-compose.yml`文件,然后运行:
80 |
81 | ```bash
82 | # 启动MySQL, MinIO, Milvus服务
83 | docker-compose up -d
84 | ```
85 |
86 | 或者,如果您只想启动特定服务:
87 |
88 | ```bash
89 | # 仅启动MySQL和MinIO
90 | docker-compose up -d mysql-init minio minio-init
91 |
92 | # 仅启动Milvus相关服务
93 | docker-compose up -d etcd minio-for-milvus standalone
94 | ```
95 |
96 | ### 检查服务状态
97 |
98 | ```bash
99 | # 列出所有容器
100 | docker ps
101 |
102 | # 检查MySQL是否正常初始化
103 | docker logs mysql-init
104 |
105 | # 检查MinIO是否正常启动
106 | curl http://localhost:9000
107 |
108 | # 检查Milvus是否正常启动
109 | docker logs milvus-standalone
110 | ```
111 |
112 | 确保所有服务都正常运行,没有错误信息。
113 |
114 | ## 步骤四:设置Ollama (如果使用Ollama作为嵌入服务)
115 |
116 | 如果您选择使用Ollama作为嵌入服务:
117 |
118 | 1. 从[Ollama官网](https://ollama.com/download)下载并安装Ollama
119 |
120 | 2. 拉取嵌入模型:
121 | ```bash
122 | ollama pull mxbai-embed-large
123 | ```
124 |
125 | 3. 启动Ollama服务:
126 | ```bash
127 | OLLAMA_HOST="0.0.0.0" OLLAMA_ORIGINS="*" ollama serve
128 | ```
129 |
130 | 您也可以创建一个启动脚本`start-ollama.sh`:
131 | ```bash
132 | #!/bin/bash
133 | OLLAMA_HOST="0.0.0.0" OLLAMA_ORIGINS="*" OLLAMA_KEEP_ALIVE="24h" ollama serve
134 | ```
135 |
136 | 然后运行`chmod +x start-ollama.sh`和`./start-ollama.sh`
137 |
138 | 4. 在前端的模型服务中添加上一步拉取的嵌入模型。
139 |
140 | 5. 验证Ollama服务是否正常运行:
141 | ```bash
142 | curl http://localhost:11434/api/tags
143 | ```
144 | 应返回包含`mxbai-embed-large`的模型列表。
145 |
146 | ## 步骤五:确认Milvus配置
147 |
148 | Milvus的连接地址和向量集合参数在配置文件中指定:
149 |
150 | ```yaml
151 | # Milvus向量数据库配置
152 | milvus:
153 | address: "localhost:19530"
154 | index_type: "IVF_FLAT"
155 | metric_type: "COSINE"
156 | nlist: 128
157 | # 搜索参数
158 | nprobe: 16
159 | # 字段最大长度配置
160 | id_max_length: "64"
161 | content_max_length: "65535"
162 | doc_id_max_length: "64"
163 | doc_name_max_length: "256"
164 | kb_id_max_length: "64"
165 | ```
166 |
167 | 如果您使用的是自定义的Milvus部署或远程Milvus服务,请相应地修改地址。您也可以根据需要调整以下参数:
168 |
169 | - `index_type`: 索引类型,支持IVF_FLAT、IVF_SQ8、HNSW等
170 | - `metric_type`: 距离计算方式,支持COSINE、L2、IP等
171 | - `nlist`: IVF索引的聚类数量
172 | - `nprobe`: 搜索时检查的聚类数量,值越大结果越精确但查询越慢
173 | - `*_max_length`: 各字段的最大长度设置,特别是处理大文档时可能需要调整content_max_length
174 |
175 | 验证Milvus是否正常运行:
176 |
177 | ```bash
178 | # 检查Milvus容器状态
179 | docker ps | grep milvus
180 |
181 | # 查看Milvus日志
182 | docker logs milvus-standalone
183 | ```
184 |
185 | ## 步骤六:启动AI-Cloud-Go
186 |
187 | ### 下载依赖
188 |
189 | ```bash
190 | go mod download
191 | ```
192 |
193 | ### 运行应用
194 |
195 | ```bash
196 | go run cmd/main.go
197 | ```
198 |
199 | 应用将在http://localhost:8080运行。您应该看到类似以下的输出:
200 |
201 | ```
202 | [GIN-debug] Listening and serving HTTP on :8080
203 | ```
204 |
205 | ## 步骤七:验证安装
206 |
207 | ### 检查API服务
208 |
209 | ```bash
210 | # 检查健康状态
211 | curl http://localhost:8080/api/health
212 |
213 | # 使用Swagger查看API文档
214 | # 在浏览器中访问: http://localhost:8080/swagger/index.html
215 | ```
216 |
217 | ### 测试嵌入服务
218 |
219 | 根据您选择的嵌入服务,验证其正常工作:
220 |
221 | ```bash
222 | # 如果使用OpenAI
223 | curl -H "Authorization: Bearer 您的API密钥" https://api.openai.com/v1/embeddings -d '{"model":"text-embedding-3-large", "input":"测试文本"}'
224 |
225 | # 如果使用Ollama
226 | curl -X POST http://localhost:11434/api/embed -d '{"model":"mxbai-embed-large", "input":"测试文本"}'
227 | ```
228 |
229 | ### 注册和登录
230 |
231 | 要使用系统的大多数功能,您需要先注册并登录以获取JWT令牌:
232 |
233 | 1. 注册用户:
234 | ```bash
235 | curl -X POST http://localhost:8080/api/users/register -H "Content-Type: application/json" -d '{"username":"testuser", "password":"testpassword", "email":"test@example.com"}'
236 | ```
237 |
238 | 2. 登录并获取Token:
239 | ```bash
240 | curl -X POST http://localhost:8080/api/users/login -H "Content-Type: application/json" -d '{"username":"testuser", "password":"testpassword"}'
241 | ```
242 |
243 | 复制返回的`token`值,供后续API调用使用。
244 |
245 | ## 使用知识库功能
246 |
247 | 使用您在登录时获得的JWT令牌:
248 |
249 | 1. 创建知识库:
250 | ```bash
251 | curl -X POST http://localhost:8080/api/kb/create \
252 | -H "Authorization: Bearer 您的JWT令牌" \
253 | -H "Content-Type: application/json" \
254 | -d '{"name":"测试知识库","description":"这是一个测试知识库"}'
255 | ```
256 |
257 | 2. 上传文档到知识库(使用multipart/form-data):
258 | ```bash
259 | curl -X POST http://localhost:8080/api/kb/addNew \
260 | -H "Authorization: Bearer 您的JWT令牌" \
261 | -F "kb_id=知识库ID" \
262 | -F "file=@/path/to/your/document.pdf"
263 | ```
264 |
265 | 3. 查询知识库:
266 | ```bash
267 | curl -X POST http://localhost:8080/api/kb/retrieve \
268 | -H "Authorization: Bearer 您的JWT令牌" \
269 | -H "Content-Type: application/json" \
270 | -d '{"kb_id":"知识库ID","query":"您的问题","top_k":3}'
271 | ```
272 |
273 | ## 故障排除
274 |
275 | ### 初始化问题
276 | - 如果启动时显示找不到配置文件,请确保`config/config.yaml`文件存在并格式正确
277 | - 如果显示找不到`init.sql`,请确认项目根目录中有此文件
278 |
279 | ### MySQL连接问题
280 | - 错误:`dial tcp 127.0.0.1:3306: connect: connection refused`
281 | - 解决:确保MySQL容器已启动,运行`docker ps | grep mysql`
282 |
283 | ### Milvus连接问题
284 | - 错误:`无法连接到Milvus`
285 | - 解决:
286 | 1. 确保Milvus容器正在运行:`docker ps | grep milvus`
287 | 2. 检查config.yaml中的milvus.address配置是否正确
288 | 3. 如果使用自定义Milvus部署,确保端口映射正确
289 |
290 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module ai-cloud
2 |
3 | go 1.23.4
4 |
5 | require (
6 | code.sajari.com/docconv/v2 v2.0.0-pre.4
7 | github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible
8 | github.com/bytedance/sonic v1.13.2
9 | github.com/cloudwego/eino v0.3.30
10 | github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20250328102648-b47e7f1587fa
11 | github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20250328102648-b47e7f1587fa
12 | github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250328102648-b47e7f1587fa
13 | github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250331101427-906b8d194a99
14 | github.com/cloudwego/eino-ext/components/tool/mcp v0.0.0-20250507115047-b20720df8528
15 | github.com/cloudwego/eino-ext/libs/acl/openai v0.0.0-20250422092704-54e372e1fa3d
16 | github.com/fsnotify/fsnotify v1.9.0
17 | github.com/gin-contrib/cors v1.7.3
18 | github.com/gin-contrib/sse v0.1.0
19 | github.com/gin-gonic/gin v1.10.0
20 | github.com/golang-jwt/jwt/v5 v5.2.1
21 | github.com/google/uuid v1.6.0
22 | github.com/mark3labs/mcp-go v0.26.0
23 | github.com/milvus-io/milvus-sdk-go/v2 v2.4.2
24 | github.com/minio/minio-go/v7 v7.0.84
25 | github.com/ollama/ollama v0.5.12
26 | github.com/spf13/viper v1.20.1
27 | golang.org/x/crypto v0.34.0
28 | gorm.io/driver/mysql v1.5.7
29 | gorm.io/gorm v1.25.12
30 |
31 | )
32 |
33 | require (
34 | github.com/JalfResi/justext v0.0.0-20170829062021-c0282dea7198 // indirect
35 | github.com/PuerkitoBio/goquery v1.8.1 // indirect
36 | github.com/advancedlogic/GoOse v0.0.0-20191112112754-e742535969c1 // indirect
37 | github.com/andybalholm/cascadia v1.3.2 // indirect
38 | github.com/araddon/dateparse v0.0.0-20200409225146-d820a6159ab1 // indirect
39 | github.com/aymerick/douceur v0.2.0 // indirect
40 | github.com/bytedance/sonic/loader v0.2.4 // indirect
41 | github.com/cloudwego/base64x v0.1.5 // indirect
42 | github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20250117061805-cd80d1780d76 // indirect
43 | github.com/cockroachdb/errors v1.9.1 // indirect
44 | github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
45 | github.com/cockroachdb/redact v1.1.3 // indirect
46 | github.com/dustin/go-humanize v1.0.1 // indirect
47 | github.com/fatih/set v0.2.1 // indirect
48 | github.com/gabriel-vasile/mimetype v1.4.8 // indirect
49 | github.com/getkin/kin-openapi v0.118.0 // indirect
50 | github.com/getsentry/sentry-go v0.12.0 // indirect
51 | github.com/gigawattio/window v0.0.0-20180317192513-0f5467e35573 // indirect
52 | github.com/go-ini/ini v1.67.0 // indirect
53 | github.com/go-openapi/jsonpointer v0.21.0 // indirect
54 | github.com/go-openapi/swag v0.23.0 // indirect
55 | github.com/go-playground/locales v0.14.1 // indirect
56 | github.com/go-playground/universal-translator v0.18.1 // indirect
57 | github.com/go-playground/validator/v10 v10.23.0 // indirect
58 | github.com/go-resty/resty/v2 v2.3.0 // indirect
59 | github.com/go-sql-driver/mysql v1.7.1 // indirect
60 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
61 | github.com/goccy/go-json v0.10.4 // indirect
62 | github.com/gogo/protobuf v1.3.2 // indirect
63 | github.com/golang/protobuf v1.5.4 // indirect
64 | github.com/goph/emperror v0.17.2 // indirect
65 | github.com/gorilla/css v1.0.1 // indirect
66 | github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
67 | github.com/invopop/yaml v0.3.1 // indirect
68 | github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7 // indirect
69 | github.com/jinzhu/inflection v1.0.0 // indirect
70 | github.com/jinzhu/now v1.1.5 // indirect
71 | github.com/josharian/intern v1.0.0 // indirect
72 | github.com/json-iterator/go v1.1.12 // indirect
73 | github.com/klauspost/compress v1.17.11 // indirect
74 | github.com/klauspost/cpuid/v2 v2.2.9 // indirect
75 | github.com/kr/pretty v0.3.1 // indirect
76 | github.com/kr/text v0.2.0 // indirect
77 | github.com/leodido/go-urn v1.4.0 // indirect
78 | github.com/levigross/exp-html v0.0.0-20120902181939-8df60c69a8f5 // indirect
79 | github.com/mailru/easyjson v0.9.0 // indirect
80 | github.com/mattn/go-isatty v0.0.20 // indirect
81 | github.com/mattn/go-runewidth v0.0.14 // indirect
82 | github.com/meguminnnnnnnnn/go-openai v0.0.0-20250408071642-761325becfd6 // indirect
83 | github.com/microcosm-cc/bluemonday v1.0.27 // indirect
84 | github.com/milvus-io/milvus-proto/go-api/v2 v2.5.6 // indirect
85 | github.com/minio/md5-simd v1.1.2 // indirect
86 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
87 | github.com/modern-go/reflect2 v1.0.2 // indirect
88 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
89 | github.com/nikolalohinski/gonja v1.5.3 // indirect
90 | github.com/olekukonko/tablewriter v0.0.5 // indirect
91 | github.com/otiai10/gosseract/v2 v2.2.4 // indirect
92 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect
93 | github.com/perimeterx/marshmallow v1.1.5 // indirect
94 | github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect
95 | github.com/pkg/errors v0.9.1 // indirect
96 | github.com/richardlehane/mscfb v1.0.3 // indirect
97 | github.com/richardlehane/msoleps v1.0.3 // indirect
98 | github.com/rivo/uniseg v0.2.0 // indirect
99 | github.com/rogpeppe/go-internal v1.13.1 // indirect
100 | github.com/rs/xid v1.6.0 // indirect
101 | github.com/sagikazarmark/locafero v0.7.0 // indirect
102 | github.com/sashabaranov/go-openai v1.37.0 // indirect
103 | github.com/sirupsen/logrus v1.9.3 // indirect
104 | github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect
105 | github.com/sourcegraph/conc v0.3.0 // indirect
106 | github.com/spf13/afero v1.12.0 // indirect
107 | github.com/spf13/cast v1.7.1 // indirect
108 | github.com/spf13/pflag v1.0.6 // indirect
109 | github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
110 | github.com/stretchr/objx v0.5.2 // indirect
111 | github.com/subosito/gotenv v1.6.0 // indirect
112 | github.com/tidwall/gjson v1.17.1 // indirect
113 | github.com/tidwall/match v1.1.1 // indirect
114 | github.com/tidwall/pretty v1.2.0 // indirect
115 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
116 | github.com/ugorji/go/codec v1.2.12 // indirect
117 | github.com/yargevad/filepathx v1.0.0 // indirect
118 | github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
119 | go.uber.org/atomic v1.11.0 // indirect
120 | go.uber.org/multierr v1.11.0 // indirect
121 | golang.org/x/arch v0.14.0 // indirect
122 | golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
123 | golang.org/x/net v0.35.0 // indirect
124 | golang.org/x/sync v0.11.0 // indirect
125 | golang.org/x/sys v0.30.0 // indirect
126 | golang.org/x/text v0.22.0 // indirect
127 | golang.org/x/time v0.9.0 // indirect
128 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect
129 | google.golang.org/grpc v1.67.3 // indirect
130 | google.golang.org/protobuf v1.36.3 // indirect
131 | gopkg.in/yaml.v3 v3.0.1 // indirect
132 | )
133 |
134 | replace nhooyr.io/websocket => github.com/coder/websocket v1.8.7
135 |
--------------------------------------------------------------------------------
/init.sql:
--------------------------------------------------------------------------------
1 | CREATE DATABASE IF NOT EXISTS `ai_cloud`;
2 | USE ai_cloud;
3 |
4 | -- You can add more initialization steps here if needed, such as:
5 | -- CREATE TABLE statements
6 | -- INSERT initial data
7 | -- CREATE USER and GRANT privileges
--------------------------------------------------------------------------------
/internal/component/embedding/embedding.go:
--------------------------------------------------------------------------------
1 | package embedding
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "context"
6 | "fmt"
7 | einoEmbedding "github.com/cloudwego/eino/components/embedding"
8 | "time"
9 | )
10 |
// Provider identifiers. The embedder implementations in this package register
// themselves under these names, and NewEmbeddingService looks them up by
// cfg.Server.
const (
	ProviderOpenAI = "openai"
	ProviderOllama = "ollama"
)
15 |
type (
	// EmbeddingOption is a functional option that mutates an EmbeddingOptions
	// value while an embedding service is being constructed.
	EmbeddingOption func(*EmbeddingOptions)

	// EmbeddingOptions gathers the optional settings collected from the
	// EmbeddingOption values passed to a constructor.
	EmbeddingOptions struct {
		// Timeout stays nil unless WithTimeout was applied.
		Timeout *time.Duration
	}
)

// WithTimeout returns an option that stores a pointer to timeout in
// EmbeddingOptions.Timeout.
func WithTimeout(timeout time.Duration) EmbeddingOption {
	setTimeout := func(target *EmbeddingOptions) {
		target.Timeout = &timeout
	}
	return setTimeout
}
27 |
// EmbeddingService is the common interface implemented by every vector
// embedding back end in this package.
type EmbeddingService interface {
	// New builds a configured service from cfg and opts. NewEmbeddingService
	// calls it on the value registered for cfg.Server.
	New(ctx context.Context, cfg *model.Model, opts ...EmbeddingOption) (EmbeddingService, error)
	// EmbedStrings converts the given texts into their vector representations.
	EmbedStrings(ctx context.Context, texts []string, opts ...einoEmbedding.Option) ([][]float64, error)
	// GetDimension returns the dimension of the embedding vectors.
	GetDimension() int
}
36 |
// embeddingMap maps a provider name to the EmbeddingService value registered
// for it.
var embeddingMap = make(map[string]EmbeddingService)

// register stores embeddingService under name, overwriting any earlier entry.
// NOTE(review): the map has no lock; the callers visible in this package are
// init functions — confirm nothing writes to it concurrently later.
func register(name string, embeddingService EmbeddingService) {
	embeddingMap[name] = embeddingService
}
42 |
43 | func NewEmbeddingService(ctx context.Context, cfg *model.Model, opts ...EmbeddingOption) (EmbeddingService, error) {
44 | if cfg == nil {
45 | return nil, fmt.Errorf("embedding config is nil")
46 | }
47 |
48 | if cfg.Server == "" {
49 | return nil, fmt.Errorf("embedding config server is empty")
50 | }
51 |
52 | // 获取实例
53 | if embedding, ok := embeddingMap[cfg.Server]; ok {
54 | return embedding.New(ctx, cfg, opts...)
55 | }
56 | return nil, fmt.Errorf("不支持的嵌入服务提供者: %s", cfg.Type, opts)
57 | }
58 |
--------------------------------------------------------------------------------
/internal/component/embedding/ollama.go:
--------------------------------------------------------------------------------
1 | package embedding
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "context"
6 | "fmt"
7 | einoEmbedding "github.com/cloudwego/eino/components/embedding"
8 | "time"
9 |
10 | "github.com/ollama/ollama/api"
11 | "net/http"
12 | "net/url"
13 | )
14 |
// init registers a zero-value ollamaEmbedder under ProviderOllama so that
// NewEmbeddingService can call New on it.
func init() {
	register(ProviderOllama, &ollamaEmbedder{})
}
18 |
// OllamaEmbeddingConfig holds the settings an ollamaEmbedder is built from.
type OllamaEmbeddingConfig struct {
	// BaseURL is parsed with url.Parse and passed to api.NewClient.
	BaseURL string
	// Model is sent as the Model field of every embed request.
	Model string
	// Dimension is dereferenced and returned by GetDimension.
	Dimension *int
	// Timeout is taken from EmbeddingOptions.Timeout.
	Timeout *time.Duration
	// HTTPClient, when non-nil, is used instead of a newly created client.
	HTTPClient *http.Client
}
26 |
// ollamaEmbedder implements EmbeddingService using the Ollama API client.
type ollamaEmbedder struct {
	// cli is the API client created in New.
	cli *api.Client
	// conf is the configuration captured in New.
	conf *OllamaEmbeddingConfig
}
31 |
32 | // TODO:添加默认超时时间防止报错
33 | func (o *ollamaEmbedder) New(ctx context.Context, cfg *model.Model, opts ...EmbeddingOption) (EmbeddingService, error) {
34 | // 检查配置
35 | if err := checkCfg(cfg); err != nil {
36 | return nil, err
37 | }
38 | // 处理选项
39 | options := &EmbeddingOptions{}
40 | for _, opt := range opts {
41 | opt(options)
42 | }
43 |
44 | config := &OllamaEmbeddingConfig{
45 | BaseURL: cfg.BaseURL,
46 | Model: cfg.ModelName,
47 | Dimension: &cfg.Dimension,
48 | Timeout: options.Timeout,
49 | }
50 |
51 | // 构造 client
52 | var httpClient *http.Client
53 | if config.HTTPClient != nil {
54 | httpClient = config.HTTPClient
55 | } else {
56 | httpClient = &http.Client{Timeout: *config.Timeout}
57 | }
58 |
59 | // 构造url
60 | baseURL, err := url.Parse(config.BaseURL)
61 | if err != nil {
62 | return nil, fmt.Errorf("invalid base URL: %w", err)
63 | }
64 |
65 | // 创建 client
66 | cli := api.NewClient(baseURL, httpClient)
67 |
68 | return &ollamaEmbedder{
69 | cli: cli,
70 | conf: config,
71 | }, nil
72 | }
73 |
74 | func (o *ollamaEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...einoEmbedding.Option) (
75 | embeddings [][]float64, err error) {
76 | req := &api.EmbedRequest{
77 | Model: o.conf.Model,
78 | Input: texts,
79 | }
80 | resp, err := o.cli.Embed(ctx, req)
81 | if err != nil {
82 | return nil, err
83 | }
84 |
85 | embeddings = make([][]float64, len(resp.Embeddings))
86 | for i, d := range resp.Embeddings {
87 | res := make([]float64, len(d))
88 | for j, emb := range d {
89 | res[j] = float64(emb)
90 | }
91 | embeddings[i] = res
92 | }
93 |
94 | return embeddings, nil
95 | }
96 |
// GetType returns the provider identifier ProviderOllama.
func (o *ollamaEmbedder) GetType() string {
	return ProviderOllama
}
100 |
101 | // TODO:处理Callback
102 | //func (e *ollamaEmbedder) IsCallbacksEnabled() bool {
103 | // return true
104 | //}
105 |
// GetDimension returns the embedding dimension stored in the configuration.
// It dereferences o.conf.Dimension, so it panics on an embedder whose conf is
// nil, such as the zero value registered in init.
func (o *ollamaEmbedder) GetDimension() int {
	return *o.conf.Dimension
}
109 |
110 | func checkCfg(cfg *model.Model) error {
111 | if cfg.BaseURL == "" {
112 | return fmt.Errorf("ollama base URL cannot be empty")
113 | }
114 |
115 | if _, err := url.Parse(cfg.BaseURL); err != nil {
116 | return fmt.Errorf("invalid ollama base URL: %w", err)
117 | }
118 |
119 | if cfg.ModelName == "" {
120 | return fmt.Errorf("ollama model name cannot be empty")
121 | }
122 |
123 | if cfg.Dimension <= 0 {
124 | return fmt.Errorf("ollama embedding dimension must be positive")
125 | }
126 |
127 | return nil
128 | }
129 |
--------------------------------------------------------------------------------
/internal/component/embedding/openai.go:
--------------------------------------------------------------------------------
1 | package embedding
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "context"
6 | "fmt"
7 | "github.com/cloudwego/eino-ext/components/embedding/openai"
8 | "github.com/cloudwego/eino/components/embedding"
9 | "net/url"
10 | "time"
11 | )
12 |
// init registers a zero-value openaiEmbedder under ProviderOpenAI so that
// NewEmbeddingService can call New on it.
func init() {
	register(ProviderOpenAI, &openaiEmbedder{})
}
16 |
// OpenAIEmbeddingConfig holds the settings an openaiEmbedder is built from.
type OpenAIEmbeddingConfig struct {
	// BaseURL, APIKey and Model are forwarded to openai.NewEmbedder.
	BaseURL string
	APIKey  string
	Model   string
	// Timeout is taken from EmbeddingOptions.Timeout.
	Timeout *time.Duration
	// Dimension is dereferenced and returned by GetDimension.
	Dimension *int
}
// openaiEmbedder implements EmbeddingService by wrapping the eino-ext OpenAI
// embedder.
type openaiEmbedder struct {
	// conf is the configuration captured in New.
	conf *OpenAIEmbeddingConfig
	// embedder is the underlying embedder that EmbedStrings delegates to.
	embedder *openai.Embedder
}
28 |
29 | func (o *openaiEmbedder) New(ctx context.Context, cfg *model.Model, opts ...EmbeddingOption) (EmbeddingService, error) {
30 | if err := checkOpenAICfg(cfg); err != nil {
31 | return nil, err
32 | }
33 |
34 | options := &EmbeddingOptions{}
35 | for _, opt := range opts {
36 | opt(options)
37 | }
38 |
39 | config := &OpenAIEmbeddingConfig{
40 | BaseURL: cfg.BaseURL,
41 | APIKey: cfg.APIKey,
42 | Model: cfg.ModelName,
43 | Timeout: options.Timeout,
44 | Dimension: &cfg.Dimension,
45 | }
46 |
47 | embeder, err := openai.NewEmbedder(ctx, &openai.EmbeddingConfig{
48 | APIKey: config.APIKey,
49 | BaseURL: config.BaseURL,
50 | Model: config.Model,
51 | Timeout: *options.Timeout,
52 | Dimensions: &cfg.Dimension,
53 | })
54 | if err != nil {
55 | return nil, err
56 | }
57 | return &openaiEmbedder{
58 | conf: config,
59 | embedder: embeder,
60 | }, nil
61 | }
62 |
63 | func (s *openaiEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
64 | return s.embedder.EmbedStrings(ctx, texts, opts...)
65 | }
66 |
// GetDimension returns the embedding dimension stored in the configuration.
// It dereferences o.conf.Dimension, so it panics on an embedder whose conf is
// nil, such as the zero value registered in init.
func (o *openaiEmbedder) GetDimension() int {
	return *o.conf.Dimension
}
70 |
// GetType returns the provider identifier ProviderOpenAI.
func (o *openaiEmbedder) GetType() string {
	return ProviderOpenAI
}
74 |
75 | func checkOpenAICfg(cfg *model.Model) error {
76 | if cfg.BaseURL == "" {
77 | return fmt.Errorf("opnai base URL cannot be empty")
78 | }
79 |
80 | if _, err := url.Parse(cfg.BaseURL); err != nil {
81 | return fmt.Errorf("invalid openai base URL: %w", err)
82 | }
83 |
84 | if cfg.ModelName == "" {
85 | return fmt.Errorf("openai model name cannot be empty")
86 | }
87 |
88 | if cfg.Dimension <= 0 {
89 | return fmt.Errorf("openai embedding dimension must be positive")
90 | }
91 |
92 | return nil
93 | }
94 |
--------------------------------------------------------------------------------
/internal/component/indexer/milvus/indexer.go:
--------------------------------------------------------------------------------
1 | package milvus
2 |
3 | import (
4 | "ai-cloud/config"
5 | "ai-cloud/internal/utils"
6 | "ai-cloud/pkgs/consts"
7 | "context"
8 | "fmt"
9 | "github.com/bytedance/sonic"
10 | "github.com/cloudwego/eino/components/embedding"
11 | "github.com/cloudwego/eino/components/indexer"
12 | "github.com/cloudwego/eino/schema"
13 | "github.com/milvus-io/milvus-sdk-go/v2/client"
14 | "github.com/milvus-io/milvus-sdk-go/v2/entity"
15 | "strconv"
16 | "strings"
17 | )
18 |
// MilvusIndexerConfig carries the dependencies and settings of a MilvusIndexer.
// Its check method rejects a nil Client, a nil Embedding, an empty Collection
// and a zero Dimension.
type MilvusIndexerConfig struct {
	// Collection is the name of the Milvus collection to write to.
	Collection string
	// Dimension is the vector dimension used when the collection is created.
	Dimension int
	// Embedding is the default embedder used by Store.
	Embedding embedding.Embedder
	// Client is the Milvus client used for every operation.
	Client client.Client
}
25 |
// MilvusIndexer stores documents and their vectors in a Milvus collection.
type MilvusIndexer struct {
	// config is a copy of the configuration passed to NewMilvusIndexer.
	config MilvusIndexerConfig
}
29 |
30 | func NewMilvusIndexer(ctx context.Context, conf *MilvusIndexerConfig) (*MilvusIndexer, error) {
31 | // 检查配置
32 | if err := conf.check(); err != nil {
33 | return nil, fmt.Errorf("[NewMilvusIndexer] invalid config: %w", err)
34 | }
35 |
36 | // 检查Collection是否存在
37 | exists, err := conf.Client.HasCollection(ctx, conf.Collection)
38 | if err != nil {
39 | return nil, fmt.Errorf("[NewMilvusIndexer] check milvus collection failed : %w", err)
40 | }
41 | if !exists {
42 | if err := conf.createCollection(ctx, conf.Collection, conf.Dimension); err != nil {
43 | return nil, fmt.Errorf("[NewMilvusIndexer] create collection failed: %w", err)
44 | }
45 | }
46 |
47 | // 加载Collection
48 | err = conf.Client.LoadCollection(ctx, conf.Collection, false)
49 | if err != nil {
50 | return nil, fmt.Errorf("[NewMilvusIndexer] failed to load collection: %w", err)
51 | }
52 |
53 | return &MilvusIndexer{
54 | config: *conf,
55 | }, nil
56 | }
57 |
58 | func (m *MilvusIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
59 |
60 | // 如果有opts则用opts中的配置(允许在Store的时候更换Embedder配置)
61 | co := indexer.GetCommonOptions(&indexer.Options{ //提供默认值选项
62 | SubIndexes: nil,
63 | Embedding: m.config.Embedding,
64 | }, opts...)
65 |
66 | embedder := co.Embedding
67 | if embedder == nil {
68 | return nil, fmt.Errorf("[Indexer.Store] embedding not provided")
69 | }
70 | // 获取文档内容部分
71 | texts := make([]string, 0, len(docs))
72 | for _, doc := range docs {
73 | texts = append(texts, doc.Content)
74 | }
75 | // 向量化
76 | vectors := make([][]float64, len(texts)) // 预分配结果切片
77 |
78 | for i, text := range texts {
79 | // 每次只embed一个文本
80 | vec, err := embedder.EmbedStrings(ctx, []string{text})
81 | if err != nil {
82 | return nil, fmt.Errorf("[Indexer.Store] failed to embed text at index %d: %w", i, err)
83 | }
84 |
85 | // 确保返回的向量是我们期望的单个结果
86 | if len(vec) != 1 {
87 | return nil, fmt.Errorf("[Indexer.Store] unexpected number of vectors returned: %d", len(vec))
88 | }
89 |
90 | vectors[i] = vec[0]
91 | }
92 | if len(vectors) != len(docs) {
93 | return nil, fmt.Errorf("[Indexer.Store] embedding vector length mismatch")
94 | }
95 | rows, err := DocumentConvert(ctx, docs, vectors)
96 | if err != nil {
97 | return nil, err
98 | }
99 |
100 | results, err := m.config.Client.InsertRows(ctx, m.config.Collection, "", rows)
101 | if err != nil {
102 | return nil, err
103 | }
104 | if err := m.config.Client.Flush(ctx, m.config.Collection, false); err != nil {
105 | return nil, err
106 | }
107 | ids = make([]string, results.Len())
108 | for idx := 0; idx < results.Len(); idx++ {
109 | ids[idx], err = results.GetAsString(idx)
110 | if err != nil {
111 | return nil, fmt.Errorf("[Indexer.Store] failed to get id: %w", err)
112 | }
113 | }
114 | return ids, nil
115 |
116 | }
// GetType returns the component type name "Milvus".
func (m *MilvusIndexer) GetType() string {
	return "Milvus"
}
120 | func DocumentConvert(ctx context.Context, docs []*schema.Document, vectors [][]float64) ([]interface{}, error) {
121 |
122 | em := make([]defaultSchema, 0, len(docs))
123 | rows := make([]interface{}, 0, len(docs))
124 |
125 | for _, doc := range docs {
126 | // 从原始 MetaData 中拿出结构化字段
127 | kbID, ok := doc.MetaData["kb_id"].(string)
128 | if !ok {
129 | return nil, fmt.Errorf("invalid type for kb_id")
130 | }
131 | docID, ok := doc.MetaData["document_id"].(string)
132 | if !ok {
133 | return nil, fmt.Errorf("invalid type for document_id")
134 | }
135 |
136 | // 构造要序列化到 Metadata 字段里的 map,排除 kb_id 和 document_id
137 | metaCopy := make(map[string]any, len(doc.MetaData))
138 | for k, v := range doc.MetaData {
139 | if k == "kb_id" || k == "document_id" {
140 | continue
141 | }
142 | metaCopy[k] = v
143 | }
144 | metadataBytes, err := sonic.Marshal(metaCopy)
145 | if err != nil {
146 | return nil, fmt.Errorf("[DocumentConvert] failed to marshal metadata: %w", err)
147 | }
148 |
149 | em = append(em, defaultSchema{
150 | ID: doc.ID,
151 | Content: doc.Content,
152 | KBID: kbID,
153 | DocumentID: docID,
154 | Vector: nil, // 后面统一填充
155 | Metadata: metadataBytes, // 只包含剩下的字段
156 | })
157 | }
158 |
159 | // 填充向量并生成 rows
160 | for idx, vec := range vectors {
161 | em[idx].Vector = utils.ConvertFloat64ToFloat32Embedding(vec)
162 | rows = append(rows, &em[idx])
163 | }
164 | return rows, nil
165 | }
166 |
167 | func (m *MilvusIndexerConfig) createCollection(ctx context.Context, collectionName string, dimension int) error {
168 | // 获取 Milvus 配置
169 | milvusConfig := config.GetConfig().Milvus
170 | // 创建集合Schema
171 | s := &entity.Schema{
172 | CollectionName: collectionName,
173 | Description: "存储文档分块和向量",
174 | AutoID: false,
175 | Fields: []*entity.Field{
176 | {
177 | Name: consts.FieldNameID,
178 | DataType: entity.FieldTypeVarChar,
179 | PrimaryKey: true,
180 | AutoID: false,
181 | TypeParams: map[string]string{
182 | "max_length": milvusConfig.IDMaxLength,
183 | },
184 | },
185 | {
186 | Name: consts.FieldNameContent,
187 | DataType: entity.FieldTypeVarChar,
188 | TypeParams: map[string]string{
189 | "max_length": milvusConfig.ContentMaxLength,
190 | },
191 | },
192 | {
193 | Name: consts.FieldNameDocumentID,
194 | DataType: entity.FieldTypeVarChar,
195 | TypeParams: map[string]string{
196 | "max_length": milvusConfig.DocIDMaxLength,
197 | },
198 | },
199 | {
200 | Name: consts.FieldNameKBID,
201 | DataType: entity.FieldTypeVarChar,
202 | TypeParams: map[string]string{
203 | "max_length": milvusConfig.KbIDMaxLength,
204 | },
205 | },
206 | {
207 | Name: consts.FieldNameVector,
208 | DataType: entity.FieldTypeFloatVector,
209 | TypeParams: map[string]string{
210 | "dim": strconv.Itoa(dimension),
211 | },
212 | },
213 | {
214 | Name: consts.FieldNameMetadata,
215 | DataType: entity.FieldTypeJSON,
216 | },
217 | },
218 | }
219 |
220 | // 创建集合
221 | if err := m.Client.CreateCollection(ctx, s, 1); err != nil {
222 | return fmt.Errorf("[NewMilvusIndexer.createCollection] 创建集合失败: %w", err)
223 | }
224 |
225 | // 创建索引
226 | idx, err := milvusConfig.GetMilvusIndex()
227 | if err != nil {
228 | return fmt.Errorf("[NewMilvusIndexer.createCollection] 从配置中获取索引类型失败: %w", err)
229 | }
230 |
231 | if err := m.Client.CreateIndex(ctx, collectionName, consts.FieldNameVector, idx, false); err != nil {
232 | return fmt.Errorf("[NewMilvusIndexer.createCollection] 创建索引失败: %w", err)
233 | }
234 | return nil
235 | }
236 |
237 | func (m *MilvusIndexerConfig) check() error {
238 | if m.Client == nil {
239 | return fmt.Errorf("[NewMilvusIndexer] milvus client is nil")
240 | }
241 | if m.Embedding == nil {
242 | return fmt.Errorf("[NewMilvusIndexer] embedding is nil")
243 | }
244 | if m.Collection == "" {
245 | return fmt.Errorf("[NewMilvusIndexer] collection is empty")
246 | }
247 | if m.Dimension == 0 {
248 | return fmt.Errorf("[NewMilvusIndexer] embedding dimension is zero")
249 | }
250 | return nil
251 | }
252 |
// IsCallbacksEnabled always reports true.
// NOTE(review): presumably consulted by the eino callback machinery — confirm.
func (m *MilvusIndexerConfig) IsCallbacksEnabled() bool {
	return true
}
256 |
257 | func DeleteDos(client client.Client, docIDs []string, collectionName string) error {
258 | expr := fmt.Sprintf("%s in [\"%s\"]", consts.FieldNameDocumentID, strings.Join(docIDs, "\",\""))
259 | if err := client.Delete(context.Background(), collectionName, "", expr); err != nil {
260 | return fmt.Errorf("[MilvusIndexer.DeleteDos] failed to delete documents: %w", err)
261 | }
262 | return nil
263 | }
264 |
--------------------------------------------------------------------------------
/internal/component/indexer/milvus/types.go:
--------------------------------------------------------------------------------
1 | package milvus
2 |
// defaultSchema is one row of the Milvus collection as built by
// DocumentConvert; the milvus tags name the target columns.
type defaultSchema struct {
	ID         string    `json:"id" milvus:"name:id"`
	Content    string    `json:"content" milvus:"name:content"`
	DocumentID string    `json:"document_id" milvus:"name:document_id"`
	KBID       string    `json:"kb_id" milvus:"name:kb_id"`
	Vector     []float32 `json:"vector" milvus:"name:vector"`
	Metadata   []byte    `json:"metadata" milvus:"name:metadata"` // JSON-encoded extra metadata, e.g. DocumentName, Index
}
12 |
--------------------------------------------------------------------------------
/internal/component/llm/llm.go:
--------------------------------------------------------------------------------
1 | package llmfactory
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "context"
6 | "errors"
7 | "fmt"
8 | "strings"
9 | "time"
10 |
11 | eino_model "github.com/cloudwego/eino/components/model"
12 | ollama_api "github.com/ollama/ollama/api"
13 |
14 | // 假设这些是你的 Ollama 和 OpenAI 客户端包
15 | "ai-cloud/internal/component/llm/ollama"
16 | "ai-cloud/internal/component/llm/openai"
17 | )
18 |
const (
	// defaultLLMTimeout is the timeout GetLLMClient puts on every client it builds.
	defaultLLMTimeout = 60 * time.Second
	// serverOllama and serverOpenAI are the lower-cased server names GetLLMClient accepts.
	serverOllama = "ollama"
	serverOpenAI = "openai"
	// modelTypeLLM is the model type a configuration must have to be used as a chat model.
	modelTypeLLM = "llm"
)
25 |
26 | // GetLLMClient 使用 传入的 配置基础,并允许通过 clientDefaultOpts 设置客户端级别的默认调用选项。
27 | func GetLLMClient(ctx context.Context, cfg *model.Model) (eino_model.ToolCallingChatModel, error) {
28 | // 检查Model配置
29 | // TODO: 考虑通过check函数实现
30 | if cfg == nil {
31 | return nil, errors.New("input model configuration is nil")
32 | }
33 | if cfg.Type != modelTypeLLM {
34 | return nil, fmt.Errorf("model type is '%s', but expected '%s'", cfg.Type, modelTypeLLM)
35 | }
36 |
37 | // 2. 返回对应的server
38 | switch strings.ToLower(cfg.Server) {
39 | case serverOllama:
40 | ollamaCfg := &ollama.ChatModelConfig{
41 | BaseURL: cfg.BaseURL,
42 | Model: cfg.ModelName, // 使用最终确定的模型名称
43 | Timeout: defaultLLMTimeout,
44 | Options: &ollama_api.Options{}, // 设置包含默认调用参数的 Options
45 | }
46 | return ollama.NewChatModel(ctx, ollamaCfg)
47 |
48 | case serverOpenAI:
49 | openAICfg := &openai.ChatModelConfig{
50 | APIKey: cfg.APIKey,
51 | Model: cfg.ModelName,
52 | Timeout: defaultLLMTimeout,
53 | BaseURL: cfg.BaseURL,
54 | }
55 | return openai.NewChatModel(ctx, openAICfg)
56 |
57 | default:
58 | return nil, fmt.Errorf("unsupported LLM server type: '%s'", cfg.Server)
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/internal/component/llm/ollama/call_option.go:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2024 CloudWeGo Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package ollama
18 |
19 | import (
20 | "github.com/cloudwego/eino/components/model"
21 | )
22 |
// options holds the Ollama-specific call options; Seed is nil unless WithSeed is used.
type options struct {
	Seed *int
}
26 |
27 | func WithSeed(seed int) model.Option {
28 | return model.WrapImplSpecificOptFn(func(o *options) {
29 | o.Seed = &seed
30 | })
31 | }
32 |
--------------------------------------------------------------------------------
/internal/component/llm/openai/openai.go:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2024 CloudWeGo Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package openai
18 |
19 | import (
20 | "context"
21 | "net/http"
22 | "time"
23 |
24 | "github.com/cloudwego/eino/callbacks"
25 | "github.com/cloudwego/eino/components"
26 | "github.com/cloudwego/eino/components/model"
27 | "github.com/cloudwego/eino/schema"
28 |
29 | "github.com/cloudwego/eino-ext/libs/acl/openai"
30 | )
31 |
// Compile-time check that *ChatModel implements model.ToolCallingChatModel.
var _ model.ToolCallingChatModel = (*ChatModel)(nil)
33 |
// ChatModelConfig holds the settings that NewChatModel copies into the
// underlying openai.Config.
type ChatModelConfig struct {
	// APIKey is your authentication key
	// Use OpenAI API key or Azure API key depending on the service
	// Required
	APIKey string `json:"api_key"`

	// Timeout specifies the maximum duration to wait for API responses
	// If HTTPClient is set, Timeout will not be used.
	// Optional. Default: no timeout
	Timeout time.Duration `json:"timeout"`

	// HTTPClient specifies the client to send HTTP requests.
	// If HTTPClient is set, Timeout will not be used.
	// Optional. Default &http.Client{Timeout: Timeout}
	HTTPClient *http.Client `json:"http_client"`

	// The following three fields are only required when using Azure OpenAI Service, otherwise they can be ignored.
	// For more details, see: https://learn.microsoft.com/en-us/azure/ai-services/openai/

	// ByAzure indicates whether to use Azure OpenAI Service
	// Required for Azure
	ByAzure bool `json:"by_azure"`

	// BaseURL is the Azure OpenAI endpoint URL
	// Format: https://{YOUR_RESOURCE_NAME}.openai.azure.com. YOUR_RESOURCE_NAME is the name of your resource that you have created on Azure.
	// Required for Azure
	// NOTE(review): GetLLMClient in this repository sets BaseURL without ByAzure,
	// so it is presumably also honored for non-Azure endpoints — confirm.
	BaseURL string `json:"base_url"`

	// APIVersion specifies the Azure OpenAI API version
	// Required for Azure
	APIVersion string `json:"api_version"`

	// The following fields correspond to OpenAI's chat completion API parameters
	// Ref: https://platform.openai.com/docs/api-reference/chat/create

	// Model specifies the ID of the model to use
	// Required
	Model string `json:"model"`

	// MaxTokens limits the maximum number of tokens that can be generated in the chat completion
	// Optional. Default: model's maximum
	MaxTokens *int `json:"max_tokens,omitempty"`

	// Temperature specifies what sampling temperature to use
	// Generally recommend altering this or TopP but not both.
	// Range: 0.0 to 2.0. Higher values make output more random
	// Optional. Default: 1.0
	Temperature *float32 `json:"temperature,omitempty"`

	// TopP controls diversity via nucleus sampling
	// Generally recommend altering this or Temperature but not both.
	// Range: 0.0 to 1.0. Lower values make output more focused
	// Optional. Default: 1.0
	TopP *float32 `json:"top_p,omitempty"`

	// Stop sequences where the API will stop generating further tokens
	// Optional. Example: []string{"\n", "User:"}
	Stop []string `json:"stop,omitempty"`

	// PresencePenalty prevents repetition by penalizing tokens based on presence
	// Range: -2.0 to 2.0. Positive values increase likelihood of new topics
	// Optional. Default: 0
	PresencePenalty *float32 `json:"presence_penalty,omitempty"`

	// ResponseFormat specifies the format of the model's response
	// Optional. Use for structured outputs
	ResponseFormat *openai.ChatCompletionResponseFormat `json:"response_format,omitempty"`

	// Seed enables deterministic sampling for consistent outputs
	// Optional. Set for reproducible results
	Seed *int `json:"seed,omitempty"`

	// FrequencyPenalty prevents repetition by penalizing tokens based on frequency
	// Range: -2.0 to 2.0. Positive values decrease likelihood of repetition
	// Optional. Default: 0
	FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`

	// LogitBias modifies likelihood of specific tokens appearing in completion
	// Optional. Map token IDs to bias values from -100 to 100
	LogitBias map[string]int `json:"logit_bias,omitempty"`

	// User unique identifier representing end-user
	// Optional. Helps OpenAI monitor and detect abuse
	User *string `json:"user,omitempty"`
}
119 |
// Compile-time check that *ChatModel implements model.ChatModel.
var _ model.ChatModel = (*ChatModel)(nil)

// ChatModel wraps an eino-ext OpenAI client; every method delegates to cli.
type ChatModel struct {
	cli *openai.Client
}
125 |
126 | func NewChatModel(ctx context.Context, config *ChatModelConfig) (*ChatModel, error) {
127 | var nConf *openai.Config
128 | if config != nil {
129 | var httpClient *http.Client
130 |
131 | if config.HTTPClient != nil {
132 | httpClient = config.HTTPClient
133 | } else {
134 | httpClient = &http.Client{Timeout: config.Timeout}
135 | }
136 |
137 | nConf = &openai.Config{
138 | ByAzure: config.ByAzure,
139 | BaseURL: config.BaseURL,
140 | APIVersion: config.APIVersion,
141 | APIKey: config.APIKey,
142 | HTTPClient: httpClient,
143 | Model: config.Model,
144 | MaxTokens: config.MaxTokens,
145 | Temperature: config.Temperature,
146 | TopP: config.TopP,
147 | Stop: config.Stop,
148 | PresencePenalty: config.PresencePenalty,
149 | ResponseFormat: config.ResponseFormat,
150 | Seed: config.Seed,
151 | FrequencyPenalty: config.FrequencyPenalty,
152 | LogitBias: config.LogitBias,
153 | User: config.User,
154 | }
155 | }
156 | cli, err := openai.NewClient(ctx, nConf)
157 | if err != nil {
158 | return nil, err
159 | }
160 |
161 | return &ChatModel{
162 | cli: cli,
163 | }, nil
164 | }
165 |
166 | func (cm *ChatModel) Generate(ctx context.Context, in []*schema.Message, opts ...model.Option) (
167 | outMsg *schema.Message, err error) {
168 | ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
169 | return cm.cli.Generate(ctx, in, opts...)
170 | }
171 |
172 | func (cm *ChatModel) Stream(ctx context.Context, in []*schema.Message, opts ...model.Option) (outStream *schema.StreamReader[*schema.Message], err error) {
173 | ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
174 | return cm.cli.Stream(ctx, in, opts...)
175 | }
176 |
177 | func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
178 | cli, err := cm.cli.WithToolsForClient(tools)
179 | if err != nil {
180 | return nil, err
181 | }
182 | return &ChatModel{cli: cli}, nil
183 | }
184 |
// BindTools forwards tools to the underlying client's BindTools and returns its error.
func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error {
	return cm.cli.BindTools(tools)
}
188 |
// BindForcedTools forwards tools to the underlying client's BindForcedTools and returns its error.
func (cm *ChatModel) BindForcedTools(tools []*schema.ToolInfo) error {
	return cm.cli.BindForcedTools(tools)
}
192 |
// typ is the component type name reported by GetType.
const typ = "OpenAI"

// GetType returns the component type name "OpenAI"; Generate and Stream pass
// it to callbacks.EnsureRunInfo.
func (cm *ChatModel) GetType() string {
	return typ
}
198 |
// IsCallbacksEnabled reports whatever the underlying client's IsCallbacksEnabled returns.
func (cm *ChatModel) IsCallbacksEnabled() bool {
	return cm.cli.IsCallbacksEnabled()
}
202 |
--------------------------------------------------------------------------------
/internal/component/parser/pdf/docconv_parser.go:
--------------------------------------------------------------------------------
1 | /*
2 | 基于 docconv 库的 pdf 解析器;
3 | 实现了Eino 组件接口的 Parse 方法。
4 | */
5 |
6 | package pdf
7 |
8 | import (
9 | "context"
10 | "fmt"
11 | "io"
12 |
13 | "github.com/cloudwego/eino/components/document/parser"
14 | "github.com/cloudwego/eino/schema"
15 |
16 | "code.sajari.com/docconv/v2"
17 | )
18 |
// options is the implementation-specific option struct for this parser.
// toPages is a pointer so WithToPages can replace the parser's default.
type options struct {
	toPages *bool
}
24 |
25 | func WithToPages(toPages bool) parser.Option {
26 | return parser.WrapImplSpecificOptFn(func(opts *options) {
27 | opts.toPages = &toPages
28 | })
29 | }
30 |
// Config is the construction-time configuration for DocconvPDFParser.
type Config struct {
	// ToPages is copied into the parser as its default toPages flag.
	ToPages bool
}
// DocconvPDFParser parses PDF input with the docconv library.
type DocconvPDFParser struct {
	// ToPages is the default toPages flag; WithToPages overrides it per Parse call.
	ToPages bool
}
37 |
38 | func NewDocconvPDFParser(ctx context.Context, config *Config) (*DocconvPDFParser, error) {
39 | if config == nil {
40 | config = &Config{}
41 | }
42 | return &DocconvPDFParser{ToPages: config.ToPages}, nil
43 | }
44 |
45 | func (pp *DocconvPDFParser) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
46 | // 1. 处理通用选项
47 | commonOpts := parser.GetCommonOptions(nil, opts...)
48 |
49 | specificOpts := parser.GetImplSpecificOptions(&options{
50 | toPages: &pp.ToPages,
51 | }, opts...)
52 |
53 | // 3. 实现解析逻辑
54 | fmt.Println("开始解析PDF文档...")
55 | res, meta, err := docconv.ConvertPDF(reader)
56 | if err != nil {
57 | fmt.Printf("PDF解析错误: %v\n", err)
58 | return nil, fmt.Errorf("PDF解析失败: %w", err)
59 | }
60 |
61 | fmt.Printf("PDF解析完成,文本长度: %d字符\n", len(res))
62 | fmt.Printf("PDF元数据: %+v\n", meta)
63 |
64 | // 检查解析结果是否为空
65 | if len(res) < 100 { // 至少需要100个字符才算有效
66 | fmt.Println("PDF解析结果太短或为空")
67 | if len(res) == 0 {
68 | return nil, fmt.Errorf("PDF解析结果为空,可能是扫描PDF或无文本内容")
69 | }
70 | }
71 |
72 | if *specificOpts.toPages {
73 | fmt.Println("待处理分页")
74 | }
75 |
76 | return []*schema.Document{{
77 | Content: res,
78 | MetaData: commonOpts.ExtraMeta,
79 | }}, nil
80 | }
81 |
--------------------------------------------------------------------------------
/internal/component/retriever/milvus/multi_retriever.go:
--------------------------------------------------------------------------------
1 | package milvus
2 |
3 | import (
4 | "ai-cloud/internal/component/embedding"
5 | "ai-cloud/internal/dao"
6 | "ai-cloud/internal/database"
7 | "context"
8 | "fmt"
9 | eretriever "github.com/cloudwego/eino/components/retriever"
10 | "github.com/cloudwego/eino/schema"
11 | "log"
12 | "sort"
13 | "time"
14 | )
15 |
// MultiKBRetriever is a custom retriever that queries several knowledge bases.
type MultiKBRetriever struct {
	KBIDs    []string             // knowledge bases to search
	UserID   uint                 // every knowledge base must belong to this user
	KBDao    dao.KnowledgeBaseDao // loads knowledge-base records
	ModelDao dao.ModelDao         // loads each knowledge base's embedding model
	Ctx      context.Context      // used for model lookup and embedding-service creation
	TopK     int                  // maximum number of documents Retrieve returns
}
25 |
26 | func (m MultiKBRetriever) Retrieve(ctx context.Context, query string, opts ...eretriever.Option) ([]*schema.Document, error) {
27 | // 如果没有提供知识库ID,则返回空结果
28 | if len(m.KBIDs) == 0 {
29 | return []*schema.Document{}, nil
30 | }
31 |
32 | // 保存所有文档结果
33 | allDocuments := []*schema.Document{}
34 |
35 | // 对每个知识库进行检索
36 | for _, kbID := range m.KBIDs {
37 | // 获取知识库信息
38 | kb, err := m.KBDao.GetKBByID(kbID)
39 | if err != nil {
40 | return nil, fmt.Errorf("knowledge base not found: %w", err)
41 | }
42 | if kb.UserID != m.UserID {
43 | return nil, fmt.Errorf("userID mismatch: %w", err)
44 | }
45 |
46 | // 获取Embedding模型
47 | embedModel, err := m.ModelDao.GetByID(m.Ctx, m.UserID, kb.EmbedModelID)
48 | if err != nil {
49 | return nil, fmt.Errorf("failed to retrieve embedding model: %w", err)
50 | }
51 |
52 | // 创建Embedding服务
53 | embeddingService, err := embedding.NewEmbeddingService(
54 | m.Ctx,
55 | embedModel,
56 | embedding.WithTimeout(30*time.Second),
57 | )
58 | if err != nil {
59 | return nil, fmt.Errorf("failed to initialize embedding service: %w", err)
60 | }
61 |
62 | // 创建当前知识库的Retriever
63 | retrieverConf := &MilvusRetrieverConfig{
64 | Client: database.GetMilvusClient(),
65 | Embedding: embeddingService,
66 | Collection: kb.MilvusCollection,
67 | KBIDs: []string{kbID},
68 | SearchFields: nil,
69 | TopK: 3,
70 | ScoreThreshold: 0,
71 | }
72 |
73 | retriever, err := NewMilvusRetriever(ctx, retrieverConf)
74 | if err != nil {
75 | return nil, fmt.Errorf("failed to create retriever for kb %s: %w", kbID, err)
76 | }
77 |
78 | // 执行检索
79 | docs, err := retriever.Retrieve(ctx, query)
80 | if err != nil {
81 | return nil, fmt.Errorf("failed to retrieve from kb %s: %w", kbID, err)
82 | }
83 |
84 | // 将结果添加到总结果中
85 | allDocuments = append(allDocuments, docs...)
86 | }
87 |
88 | // TODO:先简单按照分数返回。这是不合理的!需要用rerank!
89 | sort.Slice(allDocuments, func(i, j int) bool {
90 | scoreI, okI := allDocuments[i].MetaData["score"].(float64)
91 | scoreJ, okJ := allDocuments[j].MetaData["score"].(float64)
92 |
93 | // 如果score不存在或类型不正确,则将其视为最低优先级
94 | if !okI {
95 | return false
96 | }
97 | if !okJ {
98 | return true
99 | }
100 |
101 | return scoreI > scoreJ // 降序排序
102 | })
103 |
104 | if len(allDocuments) > m.TopK {
105 | allDocuments = allDocuments[:m.TopK]
106 | }
107 | log.Printf("[Multi Retriever] Retrieved %d documents from %d knowledge bases", len(allDocuments), len(m.KBIDs))
108 | return allDocuments, nil
109 | }
110 |
// GetType returns the component type name "MultiKnowledgeBaseRetriever".
func (m *MultiKBRetriever) GetType() string {
	return "MultiKnowledgeBaseRetriever"
}
114 |
--------------------------------------------------------------------------------
/internal/component/retriever/milvus/retriever.go:
--------------------------------------------------------------------------------
1 | package milvus
2 |
3 | import (
4 | "ai-cloud/config"
5 | "ai-cloud/internal/utils"
6 | "ai-cloud/pkgs/consts"
7 | "context"
8 | "fmt"
9 | "github.com/bytedance/sonic"
10 | "github.com/cloudwego/eino/components/embedding"
11 | "github.com/cloudwego/eino/components/retriever"
12 | "github.com/cloudwego/eino/schema"
13 | "github.com/milvus-io/milvus-sdk-go/v2/client"
14 | "github.com/milvus-io/milvus-sdk-go/v2/entity"
15 | "strings"
16 | )
17 |
// MilvusRetrieverConfig configures a MilvusRetriever. check validates the
// required fields and fills in defaults for SearchFields and TopK.
type MilvusRetrieverConfig struct {
	Client         client.Client      // Required
	Embedding      embedding.Embedder // Required
	Collection     string             // Required
	KBIDs          []string           // Required: at least one knowledge base should be queried (NOTE(review): not enforced by check)
	SearchFields   []string           // Optional, defaults to defaultSearchFields
	TopK           int                // Optional, default is 5
	ScoreThreshold float64            // Optional, default is 0
}
27 |
// MilvusRetriever retrieves documents from a Milvus collection; it holds a
// copy of the configuration passed to NewMilvusRetriever.
type MilvusRetriever struct {
	config MilvusRetrieverConfig
}
31 |
32 | func NewMilvusRetriever(ctx context.Context, conf *MilvusRetrieverConfig) (*MilvusRetriever, error) {
33 | // 检查必要配置,设置默认值
34 | if err := conf.check(); err != nil {
35 | return nil, fmt.Errorf("[NewMilvusRetriever] check config failed : %w", err)
36 | }
37 | // 检查Collection是否存在
38 | exists, err := conf.Client.HasCollection(ctx, conf.Collection)
39 | if err != nil {
40 | return nil, fmt.Errorf("[NewMilvusRetriever] check milvus collection failed : %w", err)
41 | }
42 | if !exists {
43 | return nil, fmt.Errorf("[NewMilvusRetirever] collection %s not exists", conf.Collection)
44 | }
45 |
46 | // 检查是否load,没load的话load
47 | collection, err := conf.Client.DescribeCollection(ctx, conf.Collection)
48 | if err != nil {
49 | return nil, fmt.Errorf("[NewRetriever] failed to describe collection: %w", err)
50 | }
51 |
52 | if !collection.Loaded {
53 | err = conf.Client.LoadCollection(ctx, conf.Collection, false)
54 | if err != nil {
55 | return nil, fmt.Errorf("[NewMilvusRetriever] failed to load collection: %w", err)
56 | }
57 | }
58 |
59 | return &MilvusRetriever{
60 | config: *conf,
61 | }, nil
62 | }
63 |
// FieldNameVector is the name of the vector column.
// NOTE(review): Retrieve in this file uses consts.FieldNameVector instead; this
// exported variable looks unused here — confirm before removing.
var FieldNameVector = "vector"
65 |
66 | func (m *MilvusRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
67 | // retrieve的时候指定参数
68 | co := retriever.GetCommonOptions(&retriever.Options{
69 | SubIndex: nil,
70 | TopK: &m.config.TopK,
71 | ScoreThreshold: &m.config.ScoreThreshold,
72 | Embedding: m.config.Embedding,
73 | }, opts...)
74 |
75 | emb := co.Embedding
76 | vectors, err := emb.EmbedStrings(ctx, []string{query})
77 | if err != nil {
78 | return nil, fmt.Errorf("[MilvusRetriver.Retrieve] embedding has error: %w", err)
79 | }
80 | // 检查结果数量是否正确
81 | if len(vectors) != 1 {
82 | return nil, fmt.Errorf("[MilvusRetriver.Retrieve] invalid return length of vector, got=%d, expected=1", len(vectors))
83 | }
84 |
85 | vector := utils.ConvertFloat64ToFloat32Embedding(vectors[0])
86 |
87 | // 构造查询条件和参数
88 | kbIDs := m.config.KBIDs
89 | var expr string
90 | if len(kbIDs) > 0 {
91 | quotedIDs := make([]string, len(kbIDs))
92 | for i, id := range kbIDs {
93 | quotedIDs[i] = fmt.Sprintf(`"%s"`, id)
94 | }
95 | expr = fmt.Sprintf("%s in [%s]", consts.FieldNameKBID, strings.Join(quotedIDs, ","))
96 | } else {
97 | expr = "0 == 1"
98 | }
99 |
100 | var results []client.SearchResult
101 | sp, _ := entity.NewIndexIvfFlatSearchParam(config.GetConfig().Milvus.Nprobe)
102 | metricType := config.GetConfig().Milvus.GetMetricType()
103 | results, err = m.config.Client.Search(
104 | ctx,
105 | m.config.Collection, // 集合名称:指定要搜索的Milvus集合
106 | []string{}, // 分区名称:空表示搜索所有分区
107 | expr, // 过滤表达式:限制搜索范围,这里只搜索指定知识库ID的文档
108 | m.config.SearchFields, // 输出字段:指定返回结果中包含哪些字段
109 | []entity.Vector{entity.FloatVector(vector)}, // 查询向量:将输入向量转换为Milvus向量格式
110 | consts.FieldNameVector, // 向量字段名:指定在哪个字段上执行向量搜索(对应Index)
111 | metricType, // 度量类型:如何计算向量相似度(如余弦相似度、欧几里得距离等)
112 | m.config.TopK, // 返回数量:返回的最相似结果数量
113 | sp, // 搜索参数:索引特定的搜索参数,如nprobe(探测聚类数)
114 | )
115 |
116 | documents := make([]*schema.Document, 0, len(results))
117 | for _, result := range results {
118 | if result.Err != nil {
119 | return nil, fmt.Errorf("[MilvusRetriver.Retrieve] search result has error: %w", result.Err)
120 | }
121 | if result.IDs == nil || result.Fields == nil {
122 | return nil, fmt.Errorf("[MilvusRetriver.Retrieve] search result has no ids or fields")
123 | }
124 | document, err := DocumentConverter(ctx, result)
125 | if err != nil {
126 | return nil, fmt.Errorf("[MilvusRetriver.Retrieve] failed to convert search result to schema.Document: %w", err)
127 | }
128 | documents = append(documents, document...)
129 | }
130 | return documents, nil
131 | }
132 |
133 | // defaultDocumentConverter returns the default document converter
134 | func DocumentConverter(ctx context.Context, doc client.SearchResult) ([]*schema.Document, error) {
135 | var err error
136 | result := make([]*schema.Document, doc.IDs.Len(), doc.IDs.Len())
137 | for i := range result {
138 | result[i] = &schema.Document{
139 | MetaData: make(map[string]any),
140 | }
141 | }
142 |
143 | importantMetaFields := map[string]bool{
144 | consts.FieldNameDocumentID: true,
145 | consts.FieldNameKBID: true,
146 | }
147 |
148 | for _, field := range doc.Fields {
149 | switch field.Name() {
150 | case consts.FieldNameID:
151 | for i, document := range result {
152 | document.ID, err = doc.IDs.GetAsString(i)
153 | if err != nil {
154 | return nil, fmt.Errorf("failed to get id: %w", err)
155 | }
156 | }
157 | case consts.FieldNameContent:
158 | for i, document := range result {
159 | document.Content, err = field.GetAsString(i)
160 | if err != nil {
161 | return nil, fmt.Errorf("failed to get content: %w", err)
162 | }
163 | }
164 | case consts.FieldNameMetadata:
165 | for i := range result {
166 | val, _ := field.Get(i)
167 | bytes, ok := val.([]byte)
168 | if !ok {
169 | return nil, fmt.Errorf("metadata field is not []byte")
170 | }
171 | var meta map[string]any
172 | if err := sonic.Unmarshal(bytes, &meta); err != nil {
173 | return nil, fmt.Errorf("unmarshal metadata failed: %w", err)
174 | }
175 | for k, v := range meta {
176 | result[i].MetaData[k] = v
177 | }
178 | }
179 | default:
180 | if importantMetaFields[field.Name()] {
181 | for i := range result {
182 | val, err := field.GetAsString(i)
183 | if err != nil {
184 | return nil, fmt.Errorf("get field %s failed: %w", field.Name(), err)
185 | }
186 | result[i].MetaData[field.Name()] = val
187 | }
188 | }
189 | }
190 | }
191 |
192 | for i := range result {
193 | if i >= len(doc.Scores) {
194 | continue
195 | }
196 | result[i].MetaData["score"] = doc.Scores[i]
197 | }
198 |
199 | return result, nil
200 | }
201 |
// GetType returns the component type name "Milvus".
func (m *MilvusRetriever) GetType() string {
	return "Milvus"
}
205 |
206 | // 检查必要配置
207 | func (m *MilvusRetrieverConfig) check() error {
208 | if m.Client == nil {
209 | return fmt.Errorf("[NewMilvusRetriever] milvus client is nil")
210 | }
211 | if m.Embedding == nil {
212 | return fmt.Errorf("[NewMilvusRetriever] embedding is nil")
213 | }
214 | if m.Collection == "" {
215 | return fmt.Errorf("[NewMilvusRetriever] collection is empty")
216 | }
217 | if m.SearchFields == nil {
218 | m.SearchFields = defaultSearchFields // 默认搜索字段
219 | }
220 | if m.TopK == 0 {
221 | m.TopK = 5 // 默认返回结果数量
222 | }
223 | if m.ScoreThreshold == 0 {
224 | m.ScoreThreshold = 0 // 默认相似度阈值
225 | }
226 |
227 | return nil
228 | }
229 |
--------------------------------------------------------------------------------
/internal/component/retriever/milvus/types.go:
--------------------------------------------------------------------------------
1 | package milvus
2 |
3 | import "ai-cloud/pkgs/consts"
4 |
var (
	// defaultSearchFields is the list of fields used when
	// MilvusRetrieverConfig.SearchFields is nil.
	defaultSearchFields = []string{
		consts.FieldNameID,
		consts.FieldNameContent,
		consts.FieldNameKBID,
		consts.FieldNameDocumentID,
		consts.FieldNameMetadata,
	}
)
14 |
--------------------------------------------------------------------------------
/internal/controller/conversation_controller.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "ai-cloud/internal/service"
6 | "ai-cloud/internal/utils"
7 | "ai-cloud/pkgs/errcode"
8 | "ai-cloud/pkgs/response"
9 | "errors"
10 | "io"
11 | "log"
12 |
13 | "github.com/gin-contrib/sse"
14 | "github.com/gin-gonic/gin"
15 | "github.com/google/uuid"
16 | )
17 |
// ConversationController exposes the conversation HTTP handlers; each handler
// delegates to svc.
type ConversationController struct {
	svc service.ConversationService
}
21 |
22 | func NewConversationController(svc service.ConversationService) *ConversationController {
23 | return &ConversationController{svc: svc}
24 | }
25 |
26 | // DebugStreamAgent 调试模式,不保存历史
27 | func (c *ConversationController) DebugStreamAgent(ctx *gin.Context) {
28 | userID, err := utils.GetUserIDFromContext(ctx)
29 | if err != nil {
30 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user")
31 | return
32 | }
33 |
34 | var req model.DebugRequest
35 | if err := ctx.ShouldBindJSON(&req); err != nil {
36 | response.ParamError(ctx, errcode.ParamBindError, "Parameter error: "+err.Error())
37 | return
38 | }
39 |
40 | // 调用debug模式流式处理
41 | sr, err := c.svc.DebugStreamAgent(ctx.Request.Context(), userID, req.AgentID, req.Message)
42 | if err != nil {
43 | response.InternalError(ctx, errcode.InternalServerError, "Agent execution failed: "+err.Error())
44 | return
45 | }
46 |
47 | // 设置SSE响应头
48 | ctx.Writer.Header().Set("Content-Type", "text/event-stream")
49 | ctx.Writer.Header().Set("Cache-Control", "no-cache")
50 | ctx.Writer.Header().Set("Connection", "keep-alive")
51 | ctx.Writer.Header().Set("Transfer-Encoding", "chunked")
52 |
53 | // 传输流
54 | sessionID := uuid.NewString()
55 | done := make(chan struct{})
56 | defer func() {
57 | sr.Close()
58 | close(done)
59 | log.Printf("[Debug Stream] Finish Stream with ID: %s\n", sessionID)
60 | }()
61 |
62 | // 流式响应
63 | ctx.Stream(func(w io.Writer) bool {
64 | select {
65 | case <-ctx.Request.Context().Done():
66 | log.Printf("[Debug Stream] Context done for session ID: %s\n", sessionID)
67 | return false
68 | case <-done:
69 | return false
70 | default:
71 | msg, err := sr.Recv()
72 | if errors.Is(err, io.EOF) {
73 | log.Printf("[Debug Stream] EOF received for session ID: %s\n", sessionID)
74 | return false
75 | }
76 | if err != nil {
77 | log.Printf("[Debug Stream] Error receiving message: %v\n", err)
78 | return false
79 | }
80 |
81 | // 发送SSE事件
82 | sse.Encode(w, sse.Event{
83 | Data: []byte(msg.Content),
84 | })
85 |
86 | // 立即刷新响应
87 | ctx.Writer.Flush()
88 | return true
89 | }
90 | })
91 | }
92 |
93 | // CreateConversation 创建新会话
94 | func (c *ConversationController) CreateConversation(ctx *gin.Context) {
95 | userID, err := utils.GetUserIDFromContext(ctx)
96 | if err != nil {
97 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user")
98 | return
99 | }
100 |
101 | var req model.CreateConvRequest
102 | if err := ctx.ShouldBindJSON(&req); err != nil {
103 | response.ParamError(ctx, errcode.ParamBindError, "Parameter error: "+err.Error())
104 | return
105 | }
106 |
107 | // 创建会话
108 | convID, err := c.svc.CreateConversation(ctx.Request.Context(), userID, req.AgentID)
109 | if err != nil {
110 | log.Printf("[Conversation Create] Error creating conversation: %v\n", err)
111 | response.InternalError(ctx, errcode.InternalServerError, "Failed to create conversation")
112 | return
113 | }
114 |
115 | response.SuccessWithMessage(ctx, "Conversation created successfully", gin.H{"conv_id": convID})
116 | }
117 |
118 | // StreamConversation 会话模式,保存历史
119 | func (c *ConversationController) StreamConversation(ctx *gin.Context) {
120 | userID, err := utils.GetUserIDFromContext(ctx)
121 | if err != nil {
122 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user")
123 | return
124 | }
125 |
126 | var req model.ConvRequest
127 | if err := ctx.ShouldBindJSON(&req); err != nil {
128 | response.ParamError(ctx, errcode.ParamBindError, "Parameter error: "+err.Error())
129 | return
130 | }
131 |
132 | if req.ConvID == "" {
133 | convID, err := c.svc.CreateConversation(ctx.Request.Context(), userID, req.AgentID)
134 | if err != nil {
135 | log.Printf("[Conversation Create] Error creating conversation: %v\n", err)
136 | response.InternalError(ctx, errcode.InternalServerError, "Failed to create conversation")
137 | return
138 | }
139 | req.ConvID = convID
140 | }
141 |
142 | // 调用会话模式流式处理
143 | sr, err := c.svc.StreamAgentWithConversation(ctx.Request.Context(), userID, req.AgentID, req.ConvID, req.Message)
144 | if err != nil {
145 | log.Printf("[Conversation Stream] Error running agent: %v\n", err)
146 | response.InternalError(ctx, errcode.InternalServerError, "Agent execution failed")
147 | return
148 | }
149 |
150 | // 设置SSE响应头
151 | ctx.Writer.Header().Set("Content-Type", "text/event-stream")
152 | ctx.Writer.Header().Set("Cache-Control", "no-cache")
153 | ctx.Writer.Header().Set("Connection", "keep-alive")
154 | ctx.Writer.Header().Set("Transfer-Encoding", "chunked")
155 |
156 | // 传输流
157 | done := make(chan struct{})
158 | defer func() {
159 | sr.Close()
160 | close(done)
161 | log.Printf("[Conversation Stream] Finish Stream with ConvID: %s\n", req.ConvID)
162 | }()
163 |
164 | // 流式响应
165 | ctx.Stream(func(w io.Writer) bool {
166 | select {
167 | case <-ctx.Request.Context().Done():
168 | log.Printf("[Conversation Stream] Context done for ConvID: %s\n", req.ConvID)
169 | return false
170 | case <-done:
171 | return false
172 | default:
173 | msg, err := sr.Recv()
174 | if errors.Is(err, io.EOF) {
175 | log.Printf("[Conversation Stream] EOF received for ConvID: %s\n", req.ConvID)
176 | return false
177 | }
178 | if err != nil {
179 | log.Printf("[Conversation Stream] Error receiving message: %v\n", err)
180 | return false
181 | }
182 |
183 | // 发送SSE事件
184 | sse.Encode(w, sse.Event{
185 | Data: []byte(msg.Content),
186 | })
187 |
188 | // 立即刷新响应
189 | ctx.Writer.Flush()
190 | return true
191 | }
192 | })
193 | }
194 |
195 | // ListConversations 获取用户所有会话
196 | func (c *ConversationController) ListConversations(ctx *gin.Context) {
197 | userID, err := utils.GetUserIDFromContext(ctx)
198 | if err != nil {
199 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user")
200 | return
201 | }
202 |
203 | // 分页参数
204 | page := utils.StringToInt(ctx.DefaultQuery("page", "1"))
205 | size := utils.StringToInt(ctx.DefaultQuery("size", "10"))
206 |
207 | // 获取会话列表
208 | convs, count, err := c.svc.ListConversations(ctx.Request.Context(), userID, page, size)
209 | if err != nil {
210 | response.InternalError(ctx, errcode.InternalServerError, "Failed to list conversations: "+err.Error())
211 | return
212 | }
213 |
214 | // 返回分页数据
215 | response.PageSuccess(ctx, convs, count)
216 | }
217 |
218 | // ListAgentConversations 获取特定Agent的会话
219 | func (c *ConversationController) ListAgentConversations(ctx *gin.Context) {
220 | userID, err := utils.GetUserIDFromContext(ctx)
221 | if err != nil {
222 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "Failed to get user")
223 | return
224 | }
225 |
226 | // 获取AgentID
227 | agentID := ctx.Query("agent_id")
228 | if agentID == "" {
229 | response.ParamError(ctx, errcode.ParamBindError, "Agent ID is required")
230 | return
231 | }
232 |
233 | // 分页参数
234 | page := utils.StringToInt(ctx.DefaultQuery("page", "1"))
235 | size := utils.StringToInt(ctx.DefaultQuery("size", "10"))
236 |
237 | // 获取会话列表
238 | convs, count, err := c.svc.ListAgentConversations(ctx.Request.Context(), userID, agentID, page, size)
239 | if err != nil {
240 | response.InternalError(ctx, errcode.InternalServerError, "Failed to list agent conversations: "+err.Error())
241 | return
242 | }
243 |
244 | // 返回分页数据
245 | response.PageSuccess(ctx, convs, count)
246 | }
247 |
248 | // GetConversationHistory 获取会话历史消息
249 | func (c *ConversationController) GetConversationHistory(ctx *gin.Context) {
250 | // 获取会话ID
251 | convID := ctx.Query("conv_id")
252 | if convID == "" {
253 | response.ParamError(ctx, errcode.ParamBindError, "Conversation ID is required")
254 | return
255 | }
256 |
257 | // 限制参数
258 | limit := utils.StringToInt(ctx.DefaultQuery("limit", "50"))
259 |
260 | // 获取历史消息
261 | msgs, err := c.svc.GetConversationHistory(ctx.Request.Context(), convID, limit)
262 | if err != nil {
263 | response.InternalError(ctx, errcode.InternalServerError, "Failed to get conversation history: "+err.Error())
264 | return
265 | }
266 |
267 | // 返回历史消息
268 | response.SuccessWithMessage(ctx, "Conversation history retrieved successfully", gin.H{"messages": msgs})
269 | }
270 |
271 | // DeleteConversation 删除会话
272 | func (c *ConversationController) DeleteConversation(ctx *gin.Context) {
273 | // 获取会话ID
274 | convID := ctx.Query("conv_id")
275 | if convID == "" {
276 | response.ParamError(ctx, errcode.ParamBindError, "Conversation ID is required")
277 | return
278 | }
279 |
280 | // 删除会话
281 | err := c.svc.DeleteConversation(ctx.Request.Context(), convID)
282 | if err != nil {
283 | log.Printf("[Conversation Delete] Error deleting conversation: %v\n", err)
284 | response.InternalError(ctx, errcode.InternalServerError, "Failed to delete conversation: "+err.Error())
285 | return
286 | }
287 |
288 | // 返回成功消息
289 | response.SuccessWithMessage(ctx, "Conversation deleted successfully", nil)
290 | }
291 |
--------------------------------------------------------------------------------
/internal/controller/file_controller.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "ai-cloud/internal/service"
6 | "ai-cloud/internal/utils"
7 | "ai-cloud/pkgs/errcode"
8 | "ai-cloud/pkgs/response"
9 | "fmt"
10 | "net/http"
11 | "strconv"
12 |
13 | "github.com/gin-gonic/gin"
14 | )
15 |
// FileController groups the file-management HTTP handlers. Each handler
// delegates the actual work to the wrapped service.FileService.
type FileController struct {
	fileService service.FileService // service the handlers delegate to
}
19 |
20 | func NewFileController(fileService service.FileService) *FileController {
21 | return &FileController{fileService: fileService}
22 | }
23 |
24 | func (fc *FileController) Upload(ctx *gin.Context) {
25 | // 1. 获取用户ID
26 | userID, err := utils.GetUserIDFromContext(ctx)
27 | if err != nil {
28 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
29 | return
30 | }
31 |
32 | // 2. 解析表单文件
33 | fileHeader, err := ctx.FormFile("file")
34 | if err != nil {
35 | response.ParamError(ctx, errcode.ParamBindError, "上传失败")
36 | return
37 | }
38 |
39 | // 获取文件内容
40 | file, err := fileHeader.Open()
41 | if err != nil {
42 |
43 | response.ParamError(ctx, errcode.FileParseFailed, "上传失败")
44 | return
45 | }
46 | defer file.Close()
47 |
48 | // 4. 获取父目录ID(可选参数)
49 | parentID := ctx.PostForm("parent_id") // 空字符串表示根目录
50 |
51 | // 调用 Service 层处理文件上传
52 | _, err = fc.fileService.UploadFile(userID, fileHeader, file, parentID)
53 | if err != nil {
54 | response.InternalError(ctx, errcode.FileUploadFailed, "上传失败")
55 | return
56 | }
57 | response.SuccessWithMessage(ctx, "文件上传成功", nil)
58 | }
59 |
60 | func (fc *FileController) PageList(ctx *gin.Context) {
61 | // 获取用户ID并验证
62 | userID, err := utils.GetUserIDFromContext(ctx)
63 | if err != nil {
64 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
65 | return
66 | }
67 | // 获取父目录,处理根目录情况
68 | parentID := ctx.Query("parent_id")
69 | var parentIDPtr *string
70 | if parentID != "" {
71 | parentIDPtr = &parentID
72 | }
73 |
74 | page, pageSize, err := utils.ParsePaginationParams(ctx)
75 | if err != nil {
76 | response.ParamError(ctx, errcode.ParamBindError, "分页参数错误")
77 | return
78 | }
79 |
80 | sort := ctx.DefaultQuery("sort", "name:asc")
81 | if err := utils.ValidateSortParameter(sort, []string{"name", "updated_at"}); err != nil {
82 | response.ParamError(ctx, errcode.ParamValidateError, "排序参数错误")
83 | return
84 | }
85 |
86 | total, files, err := fc.fileService.PageList(userID, parentIDPtr, page, pageSize, sort)
87 | if err != nil {
88 | response.InternalError(ctx, errcode.FileListFailed, "获取文件列表失败")
89 | return
90 | }
91 |
92 | response.PageSuccess(ctx, files, total)
93 | }
94 |
95 | func (fc *FileController) Download(ctx *gin.Context) {
96 | fileID := ctx.Query("file_id")
97 |
98 | userID, err := utils.GetUserIDFromContext(ctx)
99 | if err != nil {
100 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
101 | return
102 | }
103 |
104 | fileMeta, fileData, err := fc.fileService.DownloadFile(fileID)
105 |
106 | if err != nil {
107 | response.InternalError(ctx, errcode.FileNotFound, "文件不存在")
108 | return
109 | }
110 |
111 | if userID != fileMeta.UserID {
112 | response.UnauthorizedError(ctx, errcode.ForbiddenError, "权限不足")
113 | return
114 | }
115 | ctx.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", fileMeta.Name))
116 | ctx.Header("Content-Type", fileMeta.MIMEType)
117 | ctx.Header("Content-Length", strconv.FormatInt(fileMeta.Size, 10))
118 | ctx.Data(http.StatusOK, fileMeta.MIMEType, fileData)
119 | }
120 |
121 | func (fc *FileController) Delete(ctx *gin.Context) {
122 | fileID := ctx.Query("file_id")
123 | if fileID == "" {
124 | response.ParamError(ctx, errcode.ParamValidateError, "参数错误")
125 | return
126 | }
127 | userID, err := utils.GetUserIDFromContext(ctx)
128 | if err != nil {
129 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户未认证")
130 | return
131 | }
132 |
133 | if err := fc.fileService.DeleteFileOrFolder(userID, fileID); err != nil {
134 | response.InternalError(ctx, errcode.FileDeleteFailed, "删除失败")
135 | return
136 | }
137 |
138 | response.SuccessWithMessage(ctx, "删除成功", nil)
139 | }
140 |
141 | func (fc *FileController) CreateFolder(ctx *gin.Context) {
142 | //var req models.CreateFolderRequest
143 | //if err := ctx.ShouldBind(&req); err != nil {
144 | // ctx.JSON(http.StatusBadRequest, common.Error(100, "参数错误"))
145 | // return
146 | //}
147 | req := model.CreateFolderReq{}
148 | err := ctx.ShouldBindJSON(&req)
149 | if err != nil {
150 | response.ParamError(ctx, errcode.ParamBindError, "参数错误")
151 | return
152 | }
153 |
154 | if req.ParentID != nil && *req.ParentID == "" {
155 | req.ParentID = nil
156 | }
157 |
158 | userID, err := utils.GetUserIDFromContext(ctx)
159 | if err != nil {
160 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
161 | return
162 | }
163 |
164 | err = fc.fileService.CreateFolder(userID, req.Name, req.ParentID)
165 | if err != nil {
166 | response.InternalError(ctx, errcode.InternalServerError, "文件夹创建失败")
167 | return
168 | }
169 | response.SuccessWithMessage(ctx, "创建成功", nil)
170 | }
171 |
172 | // BatchMove 批量移动文件/文件夹
173 | func (fc *FileController) BatchMove(ctx *gin.Context) {
174 | var req model.BatchMoveRequest
175 | if err := ctx.ShouldBindJSON(&req); err != nil {
176 | response.ParamError(ctx, errcode.ParamBindError, "参数错误")
177 | return
178 | }
179 |
180 | // 获取当前用户ID
181 | userID, err := utils.GetUserIDFromContext(ctx)
182 | if err != nil {
183 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
184 | return
185 | }
186 |
187 | // 执行批量移动
188 | if err := fc.fileService.BatchMoveFiles(userID, req.FileIDs, req.TargetParentID); err != nil {
189 | response.InternalError(ctx, errcode.InternalServerError, "移动失败")
190 | return
191 | }
192 |
193 | response.SuccessWithMessage(ctx, "移动成功", nil)
194 | }
195 |
196 | func (fc *FileController) Search(ctx *gin.Context) {
197 | userID, err := utils.GetUserIDFromContext(ctx)
198 | if err != nil {
199 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户校验失败")
200 | return
201 | }
202 |
203 | key := ctx.Query("key")
204 |
205 | page, pageSize, err := utils.ParsePaginationParams(ctx)
206 | if err != nil {
207 | response.ParamError(ctx, errcode.ParamValidateError, "分页参数错误")
208 | return
209 | }
210 |
211 | sort := ctx.DefaultQuery("sort", "name:asc")
212 | if err := utils.ValidateSortParameter(sort, []string{"name", "upadated_at"}); err != nil {
213 | response.ParamError(ctx, errcode.ParamValidateError, "排序参数错误")
214 | }
215 |
216 | total, files, err := fc.fileService.SearchList(userID, key, page, pageSize, sort)
217 | if err != nil {
218 | response.InternalError(ctx, errcode.FileSearchFailed, "搜索文件失败")
219 | }
220 |
221 | response.PageSuccess(ctx, files, total)
222 |
223 | }
224 |
225 | func (fc *FileController) Rename(ctx *gin.Context) {
226 | var req model.RenameRequest
227 | if err := ctx.ShouldBindJSON(&req); err != nil {
228 | response.ParamError(ctx, errcode.ParamBindError, "参数错误")
229 | return
230 | }
231 |
232 | userID, err := utils.GetUserIDFromContext(ctx)
233 | if err != nil {
234 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
235 | return
236 | }
237 |
238 | if err = fc.fileService.Rename(userID, req.FileID, req.NewName); err != nil {
239 | response.InternalError(ctx, errcode.InternalServerError, fmt.Sprintf("重命名失败 %s", err))
240 | return
241 | }
242 |
243 | response.SuccessWithMessage(ctx, "重命名成功", nil)
244 |
245 | }
246 |
247 | // GetPath 获取文件的完整路径
248 | func (fc *FileController) GetPath(ctx *gin.Context) {
249 | fileID := ctx.Query("file_id")
250 | if fileID == "" {
251 | response.ParamError(ctx, errcode.ParamValidateError, "文件ID不能为空")
252 | return
253 | }
254 |
255 | // 获取文件路径
256 | path, err := fc.fileService.GetFilePath(fileID)
257 | if err != nil {
258 | response.InternalError(ctx, errcode.FileNotFound, "获取文件路径失败")
259 | return
260 | }
261 |
262 | response.SuccessWithMessage(ctx, "获取文件路径成功", gin.H{
263 | "path": path,
264 | })
265 | }
266 |
267 | // GetIDPath 获取文件的ID路径
268 | func (fc *FileController) GetIDPath(ctx *gin.Context) {
269 | fileID := ctx.Query("file_id")
270 | if fileID == "" {
271 | response.ParamError(ctx, errcode.ParamValidateError, "文件ID不能为空")
272 | return
273 | }
274 |
275 | // 获取文件ID路径
276 | path, err := fc.fileService.GetFileIDPath(fileID)
277 | if err != nil {
278 | response.InternalError(ctx, errcode.FileNotFound, "获取文件路径失败")
279 | return
280 | }
281 |
282 | response.SuccessWithMessage(ctx, "获取文件ID路径成功", gin.H{
283 | "id_path": path,
284 | })
285 | }
286 |
--------------------------------------------------------------------------------
/internal/controller/kb_controller.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "ai-cloud/internal/service"
6 | "ai-cloud/internal/utils"
7 | "ai-cloud/pkgs/errcode"
8 | "ai-cloud/pkgs/response"
9 | "encoding/json"
10 | "fmt"
11 | "path/filepath"
12 | "strings"
13 |
14 | "github.com/gin-gonic/gin"
15 | )
16 |
// KBController groups the knowledge-base HTTP handlers. It relies on a
// knowledge-base service and on the file service, which the handlers use to
// upload and look up files.
type KBController struct {
	kbService   service.KBService   // knowledge-base operations
	fileService service.FileService // file upload and lookup
}
21 |
22 | func NewKBController(kbService service.KBService, fileService service.FileService) *KBController {
23 | return &KBController{kbService: kbService, fileService: fileService}
24 | }
25 |
26 | func (kc *KBController) Create(ctx *gin.Context) {
27 | // 1. 获取用户ID
28 | userID, err := utils.GetUserIDFromContext(ctx)
29 | if err != nil {
30 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
31 | return
32 | }
33 |
34 | var req model.CreateKBRequest
35 | if err := ctx.ShouldBindJSON(&req); err != nil {
36 | response.ParamError(ctx, errcode.ParamBindError, "参数错误")
37 | return
38 | }
39 |
40 | if err := kc.kbService.CreateKB(userID, req.Name, req.Description, req.EmbedModelID); err != nil {
41 | response.InternalError(ctx, errcode.InternalServerError, "创建失败: "+err.Error())
42 | return
43 | }
44 |
45 | response.SuccessWithMessage(ctx, "创建知识库成功", nil)
46 | }
47 |
48 | // 删除知识库
49 | func (kc *KBController) Delete(ctx *gin.Context) {
50 | // 获取用户ID并验证
51 | userID, err := utils.GetUserIDFromContext(ctx)
52 | if err != nil {
53 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
54 | return
55 | }
56 |
57 | kbID := ctx.Query("kb_id")
58 |
59 | // 删除知识库
60 | if err := kc.kbService.DeleteKB(userID, kbID); err != nil {
61 | response.InternalError(ctx, errcode.InternalServerError, err.Error())
62 | return
63 | }
64 | response.SuccessWithMessage(ctx, "删除知识库成功", nil)
65 | }
66 |
67 | func (kc *KBController) PageList(ctx *gin.Context) {
68 | // 获取用户ID并验证
69 | userID, err := utils.GetUserIDFromContext(ctx)
70 | if err != nil {
71 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
72 | return
73 | }
74 |
75 | page, pageSize, err := utils.ParsePaginationParams(ctx)
76 | if err != nil {
77 | response.ParamError(ctx, errcode.ParamBindError, "分页参数错误")
78 | return
79 | }
80 |
81 | total, kbs, err := kc.kbService.PageList(userID, page, pageSize)
82 | if err != nil {
83 | response.InternalError(ctx, errcode.InternalServerError, "获取知识库列表失败")
84 | return
85 | }
86 |
87 | response.PageSuccess(ctx, kbs, total)
88 | }
89 |
90 | func (kc *KBController) DocPage(ctx *gin.Context) {
91 | // 获取用户ID并验证
92 | userID, err := utils.GetUserIDFromContext(ctx)
93 | if err != nil {
94 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
95 | return
96 | }
97 | page, pageSize, err := utils.ParsePaginationParams(ctx)
98 | if err != nil {
99 | response.ParamError(ctx, errcode.ParamBindError, "分页参数错误")
100 | return
101 | }
102 | kbID := ctx.Query("kb_id")
103 | total, docs, err := kc.kbService.DocList(userID, kbID, page, pageSize)
104 | if err != nil {
105 | response.InternalError(ctx, errcode.InternalServerError, "获取列表失败")
106 | fmt.Printf(err.Error())
107 | return
108 | }
109 | response.PageSuccess(ctx, docs, total)
110 | }
111 | func (kc *KBController) AddExistFile(ctx *gin.Context) {
112 | // 获取用户ID并验证
113 | userID, err := utils.GetUserIDFromContext(ctx)
114 | if err != nil {
115 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
116 | return
117 | }
118 | req := model.AddFileRequest{}
119 |
120 | if err := ctx.ShouldBindJSON(&req); err != nil {
121 | response.ParamError(ctx, errcode.ParamBindError, "参数错误")
122 | return
123 | }
124 | file, err := kc.fileService.GetFileByID(req.FileID)
125 | if err != nil {
126 | response.InternalError(ctx, errcode.InternalServerError, "获取文件信息失败")
127 | return
128 | }
129 |
130 | // 添加文件到知识库
131 | doc, err := kc.kbService.CreateDocument(userID, req.KBID, file)
132 | if err != nil {
133 | response.InternalError(ctx, errcode.InternalServerError, "添加文件到知识库失败")
134 | return
135 | }
136 | // 处理解析文档
137 | doc.Status = 1 //正在处理文档
138 |
139 | if err = kc.kbService.ProcessDocument(ctx, userID, req.KBID, doc); err != nil {
140 | response.InternalError(ctx, errcode.InternalServerError, err.Error())
141 | return
142 | }
143 | response.SuccessWithMessage(ctx, "添加文件到知识库成功", nil)
144 | }
145 |
146 | // 上传新的文件到知识库
147 | func (kc *KBController) AddNewFile(ctx *gin.Context) {
148 | userID, err := utils.GetUserIDFromContext(ctx)
149 | if err != nil {
150 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
151 | return
152 | }
153 |
154 | // 获取知识库ID
155 | kbID := ctx.PostForm("kb_id")
156 | if kbID == "" {
157 | response.ParamError(ctx, errcode.ParamBindError, "知识库ID不能为空")
158 | return
159 | }
160 |
161 | fileHeader, err := ctx.FormFile("file")
162 | if err != nil {
163 | response.ParamError(ctx, errcode.ParamBindError, "文件上传失败")
164 | return
165 | }
166 |
167 | // 检查文件大小
168 | if fileHeader.Size > 20*1024*1024 { // 20MB限制
169 | response.ParamError(ctx, errcode.ParamBindError, "文件大小不能超过20MB")
170 | return
171 | }
172 |
173 | file, err := fileHeader.Open()
174 | if err != nil {
175 | response.ParamError(ctx, errcode.ParamBindError, "文件打开失败")
176 | return
177 | }
178 | defer file.Close()
179 |
180 | folderID, err := kc.fileService.InitKnowledgeDir(userID)
181 | if err != nil {
182 | response.InternalError(ctx, errcode.InternalServerError, "初始化知识库目录失败"+err.Error())
183 | return
184 | }
185 | fileID, err := kc.fileService.UploadFile(userID, fileHeader, file, folderID)
186 | if err != nil {
187 | response.InternalError(ctx, errcode.InternalServerError, "文件上传失败")
188 | return
189 | }
190 |
191 | // 将文件添加到知识库中
192 | f, err := kc.fileService.GetFileByID(fileID)
193 | if err != nil || f == nil { // 添加对 nil 的检查
194 | response.InternalError(ctx, errcode.InternalServerError, "获取文件信息失败")
195 | return
196 | }
197 |
198 | // 文档名称长度检查
199 | if len(f.Name) > 200 {
200 | // 截断文件名
201 | nameBase := filepath.Base(f.Name)
202 | nameExt := filepath.Ext(nameBase)
203 | nameWithoutExt := strings.TrimSuffix(nameBase, nameExt)
204 | if len(nameWithoutExt) > 195 {
205 | nameWithoutExt = nameWithoutExt[:195]
206 | }
207 | f.Name = nameWithoutExt + nameExt
208 | }
209 |
210 | doc, err := kc.kbService.CreateDocument(userID, kbID, f)
211 | if err != nil {
212 | response.InternalError(ctx, errcode.InternalServerError, "添加文件到知识库失败")
213 | return
214 | }
215 |
216 | doc.Status = 1 // 正在处理文档
217 | if err = kc.kbService.ProcessDocument(ctx, userID, kbID, doc); err != nil {
218 | response.InternalError(ctx, errcode.InternalServerError, "处理文档失败: "+err.Error())
219 | return
220 | }
221 | response.SuccessWithMessage(ctx, "添加文件到知识库成功", nil)
222 | }
223 |
224 | func (kc *KBController) Retrieve(ctx *gin.Context) {
225 | // 1. 获取用户ID
226 | userID, err := utils.GetUserIDFromContext(ctx)
227 | if err != nil {
228 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
229 | return
230 | }
231 |
232 | // 2. 解析请求参数
233 | var req model.RetrieveRequest
234 |
235 | if err := ctx.ShouldBindJSON(&req); err != nil {
236 | response.ParamError(ctx, errcode.ParamBindError, "参数错误")
237 | return
238 | }
239 |
240 | // 3. 调用服务层检索
241 | docs, err := kc.kbService.Retrieve(ctx, userID, req.KBID, req.Query, req.TopK)
242 | if err != nil {
243 | response.InternalError(ctx, errcode.InternalServerError, err.Error())
244 | return
245 | }
246 |
247 | // 4. 返回结果
248 | response.Success(ctx, docs)
249 | }
250 |
251 | func (kc *KBController) Chat(ctx *gin.Context) {
252 | userID, err := utils.GetUserIDFromContext(ctx)
253 | if err != nil {
254 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
255 | return
256 | }
257 |
258 | // 2. 解析请求参数
259 | var req model.ChatRequest
260 | if err := ctx.ShouldBindJSON(&req); err != nil {
261 | response.ParamError(ctx, errcode.ParamBindError, "参数错误")
262 | return
263 | }
264 |
265 | // 3. 调用服务层处理
266 | resp, err := kc.kbService.RAGQuery(ctx, userID, req.Query, req.KBs)
267 | if err != nil {
268 | response.InternalError(ctx, errcode.InternalServerError, err.Error())
269 | return
270 | }
271 |
272 | // 4. 返回结果
273 | response.Success(ctx, resp)
274 |
275 | }
276 |
277 | func (kc *KBController) ChatStream(ctx *gin.Context) {
278 | // 1. 获取用户ID
279 | userID, err := utils.GetUserIDFromContext(ctx)
280 | if err != nil {
281 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
282 | return
283 | }
284 |
285 | // 2. 解析请求参数
286 | var req model.ChatRequest
287 | if err := ctx.ShouldBindJSON(&req); err != nil {
288 | response.ParamError(ctx, errcode.ParamBindError, "参数错误")
289 | return
290 | }
291 |
292 | // 3. 设置响应头
293 | ctx.Writer.Header().Set("Content-Type", "text/event-stream")
294 | ctx.Writer.Header().Set("Cache-Control", "no-cache")
295 | ctx.Writer.Header().Set("Connection", "keep-alive")
296 |
297 | // 4. 调用服务层获取流式响应
298 | responseChan, err := kc.kbService.RAGQueryStream(ctx.Request.Context(), userID, req.Query, req.KBs)
299 | if err != nil {
300 | ctx.SSEvent("error", err.Error())
301 | return
302 | }
303 |
304 | // 5. 发送流式响应
305 | for r := range responseChan {
306 | data, _ := json.Marshal(r)
307 | ctx.Writer.Write([]byte("data: " + string(data) + "\n\n"))
308 | ctx.Writer.Flush()
309 | }
310 | }
311 |
312 | func (kc *KBController) GetKBDetail(ctx *gin.Context) {
313 | // 获取用户ID并验证
314 | userID, err := utils.GetUserIDFromContext(ctx)
315 | if err != nil {
316 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
317 | return
318 | }
319 |
320 | // 获取知识库ID
321 | kbID := ctx.Query("kb_id")
322 | if kbID == "" {
323 | response.ParamError(ctx, errcode.ParamBindError, "知识库ID不能为空")
324 | return
325 | }
326 |
327 | // 获取知识库详情
328 | kb, err := kc.kbService.GetKBDetail(userID, kbID)
329 | if err != nil {
330 | response.InternalError(ctx, errcode.InternalServerError, "获取知识库详情失败")
331 | return
332 | }
333 |
334 | response.Success(ctx, kb)
335 | }
336 |
337 | func (kc *KBController) DeleteDocs(ctx *gin.Context) {
338 |
339 | // 获取用户ID并验证
340 | userID, err := utils.GetUserIDFromContext(ctx)
341 | if err != nil {
342 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败")
343 | return
344 | }
345 |
346 | req := model.BatchDeleteDocsReq{}
347 |
348 | if err = ctx.ShouldBindJSON(&req); err != nil {
349 | response.ParamError(ctx, errcode.ParamBindError, "参数错误")
350 | return
351 | }
352 |
353 | docIDs := req.DocIDs
354 | kbID := req.KBID
355 |
356 | if len(docIDs) == 0 {
357 | response.SuccessWithMessage(ctx, "删除知识库成功", nil)
358 | return
359 | }
360 |
361 | if err := kc.kbService.DeleteDocs(userID, kbID, docIDs); err != nil {
362 | response.InternalError(ctx, errcode.InternalServerError, "删除文档失败")
363 | return
364 | }
365 | response.SuccessWithMessage(ctx, "删除知识库成功", nil)
366 | }
367 |
--------------------------------------------------------------------------------
/internal/controller/model_controller.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "ai-cloud/internal/service"
6 | "ai-cloud/internal/utils"
7 | "ai-cloud/pkgs/errcode"
8 | "ai-cloud/pkgs/response"
9 | "fmt"
10 | "github.com/gin-gonic/gin"
11 | )
12 |
// ModelController groups the model-management HTTP handlers. Each handler
// delegates the actual work to the wrapped service.ModelService.
type ModelController struct {
	svc service.ModelService // service the handlers delegate to
}
16 |
17 | func NewModelController(svc service.ModelService) *ModelController {
18 | return &ModelController{svc: svc}
19 | }
20 |
21 | func (c *ModelController) CreateModel(ctx *gin.Context) {
22 | // 获取用户ID并验证
23 | userID, err := utils.GetUserIDFromContext(ctx)
24 | if err != nil {
25 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败")
26 | return
27 | }
28 |
29 | var req model.CreateModelRequest
30 | if err := ctx.ShouldBindJSON(&req); err != nil {
31 | response.ParamError(ctx, errcode.ParamBindError, "参数错误:"+err.Error())
32 | return
33 | }
34 |
35 | m := &model.Model{
36 | ID: utils.GenerateUUID(),
37 | UserID: userID,
38 | Type: req.Type,
39 | ShowName: req.ShowName,
40 | Server: req.Server,
41 | BaseURL: req.BaseURL,
42 | ModelName: req.ModelName,
43 | APIKey: req.APIKey,
44 | // embedding
45 | Dimension: req.Dimension,
46 | // llm
47 | MaxOutputLength: req.MaxOutputLength,
48 | Function: req.Function,
49 | // common
50 | MaxTokens: req.MaxTokens,
51 | }
52 |
53 | if err := c.svc.CreateModel(ctx.Request.Context(), m); err != nil {
54 | response.InternalError(ctx, errcode.InternalServerError, err.Error())
55 | fmt.Println("创建模型失败:", err)
56 | return
57 | }
58 |
59 | response.SuccessWithMessage(ctx, "创建模型成功", nil)
60 | }
61 |
62 | // TODO:修改返回格式
63 | func (c *ModelController) UpdateModel(ctx *gin.Context) {
64 | // 获取用户ID并验证
65 | userID, err := utils.GetUserIDFromContext(ctx)
66 | if err != nil {
67 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败")
68 | return
69 | }
70 |
71 | var req model.UpdateModelRequest
72 | if err := ctx.ShouldBindJSON(&req); err != nil {
73 | response.ParamError(ctx, errcode.ParamBindError, "参数错误:"+err.Error())
74 | return
75 | }
76 |
77 | m := &model.Model{
78 | ID: req.ID,
79 | UserID: userID,
80 | ShowName: req.ShowName,
81 | Server: req.Server,
82 | BaseURL: req.BaseURL,
83 | ModelName: req.ModelName,
84 | APIKey: req.APIKey,
85 | // embedding
86 | Dimension: req.Dimension,
87 | // llm
88 | MaxOutputLength: req.MaxOutputLength,
89 | Function: req.Function,
90 | // common
91 | MaxTokens: req.MaxTokens,
92 | }
93 |
94 | if err := c.svc.UpdateModel(ctx.Request.Context(), m); err != nil {
95 | response.InternalError(ctx, errcode.InternalServerError, "更新模型失败:"+err.Error())
96 | return
97 | }
98 | response.SuccessWithMessage(ctx, "更新模型成功", nil)
99 | }
100 |
101 | func (c *ModelController) DeleteModel(ctx *gin.Context) {
102 | // 获取用户ID并验证
103 | userID, err := utils.GetUserIDFromContext(ctx)
104 | if err != nil {
105 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "用户验证失败")
106 | return
107 | }
108 | kbID := ctx.Query("model_id")
109 |
110 | if err := c.svc.DeleteModel(ctx.Request.Context(), userID, kbID); err != nil {
111 | response.InternalError(ctx, errcode.InternalServerError, "删除模型失败:"+err.Error())
112 | return
113 | }
114 |
115 | response.SuccessWithMessage(ctx, "删除模型成功", nil)
116 | }
117 |
118 | func (c *ModelController) GetModel(ctx *gin.Context) {
119 | // 获取用户ID并验证
120 | userID, err := utils.GetUserIDFromContext(ctx)
121 | if err != nil {
122 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败")
123 | return
124 | }
125 |
126 | modelID := ctx.Query("model_id")
127 |
128 | m, err := c.svc.GetModel(ctx.Request.Context(), userID, modelID)
129 | if err != nil {
130 | response.InternalError(ctx, errcode.InternalServerError, "获取模型失败:"+err.Error())
131 | return
132 | }
133 | response.SuccessWithMessage(ctx, "获取模型成功", m)
134 | }
135 |
136 | func (c *ModelController) PageModels(ctx *gin.Context) {
137 | userID, err := utils.GetUserIDFromContext(ctx)
138 | if err != nil {
139 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败")
140 | return
141 | }
142 |
143 | var req model.PageModelRequest
144 |
145 | if err := ctx.ShouldBindQuery(&req); err != nil {
146 | response.ParamError(ctx, errcode.ParamBindError, "参数错误:"+err.Error())
147 | return
148 | }
149 |
150 | models, count, err := c.svc.PageModels(ctx.Request.Context(), userID, req.Type, req.Page, req.Size)
151 | if err != nil {
152 | response.InternalError(ctx, errcode.InternalServerError, "获取模型列表失败:"+err.Error())
153 | return
154 | }
155 |
156 | response.PageSuccess(ctx, models, count)
157 | }
158 |
159 | func (c *ModelController) ListModels(ctx *gin.Context) {
160 | userID, err := utils.GetUserIDFromContext(ctx)
161 | if err != nil {
162 | response.UnauthorizedError(ctx, errcode.UnauthorizedError, "获取用户失败")
163 | return
164 | }
165 |
166 | modelType := ctx.Query("type")
167 | models, err := c.svc.ListModels(ctx.Request.Context(), userID, modelType)
168 | if err != nil {
169 | response.InternalError(ctx, errcode.InternalServerError, "获取模型列表失败:"+err.Error())
170 | return
171 | }
172 |
173 | response.SuccessWithMessage(ctx, "获取模型列表成功", models)
174 | }
175 |
--------------------------------------------------------------------------------
/internal/controller/user_controller.go:
--------------------------------------------------------------------------------
1 | package controller
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "ai-cloud/internal/service"
6 | "ai-cloud/pkgs/errcode"
7 | "ai-cloud/pkgs/response"
8 | "github.com/gin-gonic/gin"
9 | )
10 |
// UserController groups the user HTTP handlers (registration and login).
// Each handler delegates the actual work to the wrapped service.UserService.
type UserController struct {
	userService service.UserService // service the handlers delegate to
}
14 |
15 | func NewUserController(userService service.UserService) *UserController {
16 | return &UserController{userService: userService}
17 | }
18 |
19 | func (c *UserController) Register(ctx *gin.Context) {
20 | var req model.User
21 |
22 | if err := ctx.ShouldBindJSON(&req); err != nil {
23 | response.ParamError(ctx, errcode.ParamBindError, "用户注册参数错误: "+err.Error())
24 | return
25 | }
26 |
27 | if err := c.userService.Register(&req); err != nil {
28 | response.InternalError(ctx, errcode.InternalServerError, "注册失败: "+err.Error())
29 | return
30 | }
31 |
32 | response.SuccessWithMessage(ctx, req.Username+"注册成功", nil)
33 | }
34 |
35 | func (c *UserController) Login(ctx *gin.Context) {
36 | var req model.UserNameLoginReq
37 | if err := ctx.ShouldBindJSON(&req); err != nil {
38 | response.ParamError(ctx, errcode.ParamBindError, "用户名或密码错误")
39 | return
40 | }
41 |
42 | loginResponse, err := c.userService.Login(&req)
43 | if err != nil {
44 | response.InternalError(ctx, errcode.InternalServerError, "登录失败")
45 | return
46 | }
47 |
48 | response.SuccessWithMessage(ctx, "登录成功", gin.H{
49 | "access_token": loginResponse.AccessToken,
50 | "expires_in": loginResponse.ExpiresIn,
51 | "token_type": loginResponse.TokenType,
52 | })
53 | }
54 |
--------------------------------------------------------------------------------
/internal/dao/agent.go:
--------------------------------------------------------------------------------
1 | package dao
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "context"
6 | "errors"
7 | "gorm.io/gorm"
8 | )
9 |
// AgentDao is the persistence interface for agents. Every method other than
// Create and Update takes the user ID explicitly; Update reads the owner
// from agent.UserID.
type AgentDao interface {
	// Create stores the given agent.
	Create(ctx context.Context, agent *model.Agent) error
	// Update changes an existing agent identified by agent.ID and agent.UserID.
	Update(ctx context.Context, agent *model.Agent) error
	// Delete removes the agent with the given ID that belongs to userID.
	Delete(ctx context.Context, userID uint, agentID string) error
	// GetByID loads the agent with the given ID that belongs to userID.
	GetByID(ctx context.Context, userID uint, agentID string) (*model.Agent, error)
	// List returns agents for userID.
	List(ctx context.Context, userID uint) ([]*model.Agent, error)
	// Page returns agents for userID for the given page and size, together
	// with an int64 count (presumably the total number of matches — confirm
	// against the implementation).
	Page(ctx context.Context, userID uint, page, size int) ([]*model.Agent, int64, error)
}
18 |
19 | type agentDao struct {
20 | db *gorm.DB
21 | }
22 |
23 | func NewAgentDao(db *gorm.DB) AgentDao {
24 | return &agentDao{db: db}
25 | }
26 |
27 | func (d *agentDao) Create(ctx context.Context, agent *model.Agent) error {
28 | return d.db.WithContext(ctx).Create(agent).Error
29 | }
30 |
31 | func (d *agentDao) Update(ctx context.Context, agent *model.Agent) error {
32 | // Check if the agent belongs to the user
33 | var count int64
34 | if err := d.db.WithContext(ctx).Model(&model.Agent{}).Where("id = ? AND user_id = ?", agent.ID, agent.UserID).Count(&count).Error; err != nil {
35 | return err
36 | }
37 | if count == 0 {
38 | return errors.New("agent not found or no permission")
39 | }
40 |
41 | // Only update specific fields
42 | return d.db.WithContext(ctx).Model(agent).
43 | Select("Name", "Description", "AgentSchema", "UpdatedAt").
44 | Updates(agent).Error
45 | }
46 |
47 | func (d *agentDao) Delete(ctx context.Context, userID uint, agentID string) error {
48 | result := d.db.WithContext(ctx).Where("id = ? AND user_id = ?", agentID, userID).Delete(&model.Agent{})
49 | if result.RowsAffected == 0 {
50 | return errors.New("agent not found or no permission")
51 | }
52 | return result.Error
53 | }
54 |
55 | func (d *agentDao) GetByID(ctx context.Context, userID uint, agentID string) (*model.Agent, error) {
56 | var agent model.Agent
57 | err := d.db.WithContext(ctx).Where("id = ? AND user_id = ?", agentID, userID).First(&agent).Error
58 | if err != nil {
59 | if errors.Is(err, gorm.ErrRecordNotFound) {
60 | return nil, errors.New("agent not found or no permission")
61 | }
62 | return nil, err
63 | }
64 | return &agent, nil
65 | }
66 |
67 | func (d *agentDao) List(ctx context.Context, userID uint) ([]*model.Agent, error) {
68 | var agents []*model.Agent
69 | if err := d.db.WithContext(ctx).Where("user_id = ?", userID).Find(&agents).Error; err != nil {
70 | return nil, err
71 | }
72 | return agents, nil
73 | }
74 |
75 | func (d *agentDao) Page(ctx context.Context, userID uint, page, size int) ([]*model.Agent, int64, error) {
76 | var agents []*model.Agent
77 | var count int64
78 |
79 | db := d.db.WithContext(ctx).Model(&model.Agent{}).Where("user_id = ?", userID).Order("updated_at desc")
80 |
81 | err := db.Count(&count).Error
82 | if err != nil {
83 | return nil, 0, err
84 | }
85 |
86 | err = db.Offset((page - 1) * size).Limit(size).Find(&agents).Error
87 | return agents, count, err
88 | }
89 |
--------------------------------------------------------------------------------
/internal/dao/file_dao.go:
--------------------------------------------------------------------------------
1 | package dao
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "errors"
6 | "fmt"
7 | "gorm.io/gorm"
8 | "strings"
9 | )
10 |
11 | // FileDao 定义了文件操作的接口
12 | type FileDao interface {
13 | CreateFile(file *model.File) error
14 | GetFilesByParentID(userID uint, parentID *string) ([]model.File, error)
15 | GetFileMetaByFileID(id string) (*model.File, error)
16 | DeleteFile(id string) error
17 | ListFiles(userID uint, parentID *string, page int, pageSize int, sort string) ([]model.File, error)
18 | CountFilesByParentID(parentID *string, userID uint) (int64, error)
19 | UpdateFile(file *model.File) error
20 | CountFilesByKeyword(key string, userID uint) (int64, error)
21 | GetFilesByKeyword(userID uint, key string, page int, pageSize int, sort string) ([]model.File, error)
22 | GetDocumentDir(userID uint) (*model.File, error)
23 | }
24 |
25 | // fileDao 实现了FileDao接口,提供文件相关操作
26 | type fileDao struct {
27 | db *gorm.DB
28 | }
29 |
30 | // NewFileDao 创建并返回一个新的FileDao实例
31 | func NewFileDao(db *gorm.DB) FileDao {
32 | return &fileDao{db: db}
33 | }
34 |
35 | // CreateFile 创建一个新的文件记录
36 | func (fd *fileDao) CreateFile(file *model.File) error {
37 | if fd.db == nil {
38 | return errors.New("数据库未初始化")
39 | }
40 | return fd.db.Create(file).Error
41 | }
42 |
43 | // GetFilesByParentID 根据父ID获取文件列表
44 | func (fd *fileDao) GetFilesByParentID(userID uint, parentID *string) ([]model.File, error) {
45 | var files []model.File
46 | query := fd.db.Where("user_id = ?", userID)
47 |
48 | if parentID == nil {
49 | query = query.Where("parent_id IS NULL")
50 | } else {
51 | query = query.Where("parent_id = ?", *parentID)
52 | }
53 |
54 | if err := query.Find(&files).Error; err != nil {
55 | return nil, err
56 | }
57 | return files, nil
58 | }
59 |
60 | // GetFileMetaByFileID 根据文件ID获取文件元信息
61 | func (fd *fileDao) GetFileMetaByFileID(id string) (*model.File, error) {
62 | var file model.File
63 | result := fd.db.Where("id = ?", id).First(&file)
64 | if result.Error != nil {
65 | if errors.Is(result.Error, gorm.ErrRecordNotFound) {
66 | return nil, nil
67 | }
68 | return nil, result.Error
69 | }
70 | return &file, nil
71 | }
72 |
73 | // DeleteFile 根据文件ID删除文件记录
74 | func (fd *fileDao) DeleteFile(id string) error {
75 | if err := fd.db.Where("id = ?", id).Delete(&model.File{}).Error; err != nil {
76 | return err
77 | }
78 | return nil
79 | }
80 |
81 | // ListFiles 列出文件列表,根据指定的排序方式和分页参数
82 | func (fd *fileDao) ListFiles(userID uint, parentID *string, page int, pageSize int, sort string) ([]model.File, error) {
83 | var files []model.File
84 | query := fd.db.Model(model.File{}).Where("user_id = ?", userID)
85 |
86 | if parentID == nil {
87 | query = query.Where("parent_id IS NULL")
88 | } else {
89 | query = query.Where("parent_id = ?", *parentID)
90 | }
91 | query = query.Order("is_dir desc")
92 |
93 | sortClauses := strings.Split(sort, ",")
94 | for _, clause := range sortClauses {
95 | parts := strings.Split(clause, ":")
96 | filed, order := parts[0], parts[1]
97 | query = query.Order(fmt.Sprintf("%s %s", filed, order))
98 | }
99 | //处理分页
100 | offset := (page - 1) * pageSize
101 | query = query.Offset(offset).Limit(pageSize)
102 |
103 | if err := query.Find(&files).Error; err != nil {
104 | return nil, err
105 | }
106 | return files, nil
107 | }
108 |
109 | func (fd *fileDao) GetFilesByKeyword(userID uint, key string, page int, pageSize int, sort string) ([]model.File, error) {
110 | var files []model.File
111 | query := fd.db.Model(&model.File{}).Where("user_id=?", userID).Where("name LIKE ?", "%"+key+"%")
112 |
113 | query = query.Order("is_dir desc")
114 | sortClauses := strings.Split(sort, ",")
115 | for _, clause := range sortClauses {
116 | parts := strings.Split(clause, ":")
117 | filed, order := parts[0], parts[1]
118 | query = query.Order(fmt.Sprintf("%s %s", filed, order))
119 | }
120 | //处理分页
121 | offset := (page - 1) * pageSize
122 | query = query.Offset(offset).Limit(pageSize)
123 |
124 | if err := query.Find(&files).Error; err != nil {
125 | return nil, err
126 | }
127 | return files, nil
128 | }
129 |
130 | // CountFilesByParentID 计算指定父ID下的文件数量
131 | func (fd *fileDao) CountFilesByParentID(parentID *string, userID uint) (int64, error) {
132 | var total int64
133 | query := fd.db.Model(&model.File{}).Where("user_id = ?", userID)
134 |
135 | if parentID == nil {
136 | query = query.Where("parent_id IS NULL")
137 | } else {
138 | query = query.Where("parent_id = ?", parentID)
139 | }
140 | if err := query.Count(&total).Error; err != nil {
141 | return 0, err
142 | }
143 | return total, nil
144 | }
145 |
146 | func (fd *fileDao) CountFilesByKeyword(key string, userID uint) (int64, error) {
147 | var total int64
148 | query := fd.db.Model(&model.File{}).
149 | Where("user_id = ?", userID).
150 | Where("name like ?", "%"+key+"%")
151 | if err := query.Count(&total).Error; err != nil {
152 | return 0, err
153 | }
154 | return total, nil
155 | }
156 |
157 | // UpdateFile 更新文件信息
158 | func (fd *fileDao) UpdateFile(file *model.File) error {
159 | if fd.db == nil {
160 | return errors.New("数据库未初始化")
161 | }
162 | return fd.db.Save(file).Error
163 | }
164 |
165 | func (fd *fileDao) GetDocumentDir(userID uint) (*model.File, error) {
166 | // 初始化结构体
167 | file := &model.File{}
168 | err := fd.db.Where("user_id = ? AND name = ? AND is_dir = ? AND parent_id IS NULL",
169 | userID, "知识库文件", true).First(file).Error
170 | if err != nil {
171 | return nil, err // 直接返回错误,包括 gorm.ErrRecordNotFound
172 | }
173 | return file, nil
174 | }
175 |
--------------------------------------------------------------------------------
/internal/dao/history/conv_dao.go:
--------------------------------------------------------------------------------
1 | /*
2 | * This file contains modified code from the eino-history project.
3 | * Original source: https://github.com/HildaM/eino-history
4 | * Licensed under the Apache License, Version 2.0.
5 | * Modifications are made by RaspberryCola.
6 | */
7 |
8 | package history
9 |
10 | import (
11 | "ai-cloud/internal/model"
12 | "context"
13 | "fmt"
14 |
15 | "gorm.io/gorm"
16 | )
17 |
18 | type ConvDao interface {
19 | Create(ctx context.Context, conv *model.Conversation) error
20 | Update(ctx context.Context, conv *model.Conversation) error
21 | Delete(ctx context.Context, convID string) error
22 | GetByID(ctx context.Context, convID string) (*model.Conversation, error)
23 | FirstOrCreate(ctx context.Context, conv *model.Conversation) error
24 | Page(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error)
25 | PageByAgent(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error)
26 | Archive(ctx context.Context, convID string) error
27 | UnArchive(ctx context.Context, convID string) error
28 | Pin(ctx context.Context, convID string) error
29 | UnPin(ctx context.Context, convID string) error
30 | GetDB() *gorm.DB
31 | }
32 |
33 | type convDao struct {
34 | db *gorm.DB
35 | }
36 |
37 | // NewConvDao 创建一个ConvDao
38 | func NewConvDao(db *gorm.DB) ConvDao {
39 | return &convDao{db: db}
40 | }
41 |
42 | func (d *convDao) GetDB() *gorm.DB {
43 | return d.db
44 | }
45 |
46 | // Create 创建一个会话
47 | func (d *convDao) Create(ctx context.Context, conv *model.Conversation) error {
48 | err := d.db.WithContext(ctx).Create(conv).Error
49 | if err != nil {
50 | return fmt.Errorf("failed to create conversation: %w", err)
51 | }
52 | return nil
53 | }
54 |
55 | // Update 更新一个会话
56 | func (d *convDao) Update(ctx context.Context, conv *model.Conversation) error {
57 | err := d.db.WithContext(ctx).Save(conv).Error
58 | if err != nil {
59 | return fmt.Errorf("failed to update conversation: %w", err)
60 | }
61 | return nil
62 | }
63 |
64 | // Delete 删除一个会话
65 | func (d *convDao) Delete(ctx context.Context, convID string) error {
66 | err := d.db.WithContext(ctx).Delete(&model.Conversation{}, "conv_id = ?", convID).Error
67 | if err != nil {
68 | return fmt.Errorf("failed to delete conversation: %w", err)
69 | }
70 | return nil
71 | }
72 |
73 | // GetByID 根据ID获取一个会话
74 | func (d *convDao) GetByID(ctx context.Context, convID string) (*model.Conversation, error) {
75 | var conv model.Conversation
76 | err := d.db.WithContext(ctx).Where("conv_id = ?", convID).First(&conv).Error
77 | if err != nil {
78 | return nil, fmt.Errorf("failed to get conversation: %w", err)
79 | }
80 | return &conv, nil
81 | }
82 |
83 | // FirstOrCreate 根据ID获取一个会话,如果会话不存在则创建一个
84 | func (d *convDao) FirstOrCreate(ctx context.Context, conv *model.Conversation) error {
85 | err := d.db.WithContext(ctx).Where("conv_id = ?", conv.ConvID).FirstOrCreate(&conv).Error
86 | if err != nil {
87 | return fmt.Errorf("failed to get conversation: %w", err)
88 | }
89 | return nil
90 | }
91 |
92 | // Page 分页获取会话
93 | func (d *convDao) Page(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error) {
94 | var convs []*model.Conversation
95 | var total int64
96 |
97 | db := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("user_id = ?", userID).Order("updated_at DESC") // 按照更新时间降序排序
98 |
99 | err := db.Count(&total).Error
100 | if err != nil {
101 | return nil, 0, fmt.Errorf("failed to count conversations: %w", err)
102 | }
103 | err = db.Offset((page - 1) * size).Limit(size).Find(&convs).Error
104 | return convs, total, err
105 | }
106 |
107 | // PageByAgent 按Agent分页获取会话
108 | func (d *convDao) PageByAgent(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error) {
109 | var convs []*model.Conversation
110 | var total int64
111 |
112 | db := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("user_id = ? AND agent_id = ?", userID, agentID).Order("updated_at DESC") // 按照更新时间降序排序
113 |
114 | err := db.Count(&total).Error
115 | if err != nil {
116 | return nil, 0, fmt.Errorf("failed to count conversations: %w", err)
117 | }
118 | err = db.Offset((page - 1) * size).Limit(size).Find(&convs).Error
119 | return convs, total, err
120 | }
121 |
122 | // Archive 归档一个会话
123 | func (d *convDao) Archive(ctx context.Context, convID string) error {
124 | err := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("conv_id = ?", convID).Update("is_archived", true).Error
125 | if err != nil {
126 | return fmt.Errorf("failed to archive conversation: %w", err)
127 | }
128 | return nil
129 | }
130 |
131 | // UnArchive 取消归档一个会话
132 | func (d *convDao) UnArchive(ctx context.Context, convID string) error {
133 | err := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("conv_id = ?", convID).Update("is_archived", false).Error
134 | if err != nil {
135 | return fmt.Errorf("failed to unarchive conversation: %w", err)
136 | }
137 | return nil
138 | }
139 |
140 | // Pin 置顶一个会话
141 | func (d *convDao) Pin(ctx context.Context, convID string) error {
142 | err := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("conv_id = ?", convID).Update("is_pinned", true).Error
143 | if err != nil {
144 | return fmt.Errorf("failed to pin conversation: %w", err)
145 | }
146 | return nil
147 | }
148 |
149 | // UnPin 取消置顶一个会话
150 | func (d *convDao) UnPin(ctx context.Context, convID string) error {
151 | err := d.db.WithContext(ctx).Model(&model.Conversation{}).Where("conv_id = ?", convID).Update("is_pinned", false).Error
152 | if err != nil {
153 | return fmt.Errorf("failed to unpin conversation: %w", err)
154 | }
155 | return nil
156 | }
157 |
--------------------------------------------------------------------------------
/internal/dao/history/msg_dao.go:
--------------------------------------------------------------------------------
1 | package history
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "context"
6 | "fmt"
7 | "github.com/google/uuid"
8 | "gorm.io/gorm"
9 | )
10 |
11 | type MsgDao interface {
12 | GetDB() *gorm.DB
13 | Create(ctx context.Context, msg *model.Message) error
14 | Update(ctx context.Context, msg *model.Message) error
15 | Delete(ctx context.Context, msgID string) error
16 | GetByID(ctx context.Context, msgID string) (*model.Message, error)
17 | ListByConvID(ctx context.Context, convID string) ([]*model.Message, error)
18 | List(ctx context.Context, convID string, offset, limit int) ([]*model.Message, int64, error)
19 | UpdateStatus(ctx context.Context, msgID, status string) error
20 | UpdateTokenCount(ctx context.Context, msgID string, tokenCount int)
21 | SetContextEdge(ctx context.Context, msgID string, isContextEdge bool) error
22 | SetVariant(ctx context.Context, msgID string, isVariant bool) error
23 | }
24 |
25 | type msgDao struct {
26 | db *gorm.DB
27 | }
28 |
29 | func NewMsgDao(db *gorm.DB) MsgDao {
30 | return &msgDao{db: db}
31 | }
32 | func (d *msgDao) GetDB() *gorm.DB {
33 | return d.db
34 | }
35 |
36 | func (d *msgDao) Create(ctx context.Context, msg *model.Message) error {
37 | if len(msg.MsgID) == 0 {
38 | msg.MsgID = uuid.NewString()
39 | }
40 | err := d.db.WithContext(ctx).Create(msg).Error
41 | if err != nil {
42 | return fmt.Errorf("failed to create message: %w", err)
43 | }
44 | return nil
45 | }
46 |
47 | func (d *msgDao) Update(ctx context.Context, msg *model.Message) error {
48 | if err := d.db.WithContext(ctx).Save(msg).Error; err != nil {
49 | return fmt.Errorf("failed to update message: %w", err)
50 | }
51 | return nil
52 | }
53 | func (d *msgDao) Delete(ctx context.Context, msgID string) error {
54 | if err := d.db.WithContext(ctx).Delete(&model.Message{}, "msg_id = ?", msgID).Error; err != nil {
55 | return fmt.Errorf("failed to delete message: %w", err)
56 | }
57 | return nil
58 | }
59 |
60 | // GetByID 据ID获取消息
61 | func (d *msgDao) GetByID(ctx context.Context, msgID string) (*model.Message, error) {
62 | var msg model.Message
63 | if err := d.db.WithContext(ctx).Where("msg_id = ?", msgID).First(&msg).Error; err != nil {
64 | return nil, fmt.Errorf("failed to get message: %w", err)
65 | }
66 | return &msg, nil
67 | }
68 |
69 | // ListByConvID 根据会话ID获取消息
70 | func (d *msgDao) ListByConvID(ctx context.Context, convID string) ([]*model.Message, error) {
71 | var msgs []*model.Message
72 | if err := d.db.WithContext(ctx).Where("conv_id = ?", convID).
73 | Order("order_seq ASC").
74 | Find(&msgs).Error; err != nil {
75 | return nil, fmt.Errorf("failed to list messages: %w", err)
76 | }
77 | return msgs, nil
78 | }
79 |
80 | // Page 获取历史消息
81 | func (d *msgDao) List(ctx context.Context, convID string, offset, limit int) ([]*model.Message, int64, error) {
82 | var msgs []*model.Message
83 | var total int64
84 | db := d.db.WithContext(ctx).Model(&model.Message{}).Where("conv_id = ?", convID).
85 | Order("order_seq ASC")
86 | err := db.Count(&total).Error
87 | if err != nil {
88 | return nil, 0, fmt.Errorf("failed to count messages: %w", err)
89 | }
90 | err = db.Offset(offset).Limit(limit).Find(&msgs).Error
91 | if err != nil {
92 | return nil, 0, fmt.Errorf("failed to list messages: %w", err)
93 | }
94 | return msgs, total, err
95 | }
96 |
97 | // UpdateStatus 更新消息状态
98 | func (d *msgDao) UpdateStatus(ctx context.Context, msgID, status string) error {
99 | err := d.db.WithContext(ctx).Model(&model.Message{}).Where("msg_id = ?", msgID).Update("status", status).Error
100 | if err != nil {
101 | return fmt.Errorf("failed to update message status: %w", err)
102 | }
103 | return nil
104 | }
105 |
106 | // UpdateTokenCount 更新消息的token数量
107 | func (d *msgDao) UpdateTokenCount(ctx context.Context, msgID string, tokenCount int) {
108 | err := d.db.WithContext(ctx).Model(&model.Message{}).Where("msg_id = ?", msgID).Update("token_count", tokenCount).Error
109 | if err != nil {
110 | fmt.Printf("failed to update message token count: %v", err)
111 | }
112 | return
113 | }
114 |
115 | // SetContextEdge 设置消息为上下文边界
116 | func (d *msgDao) SetContextEdge(ctx context.Context, msgID string, isContextEdge bool) error {
117 | err := d.db.WithContext(ctx).Model(&model.Message{}).Where("msg_id = ?", msgID).Update("is_context_edge", isContextEdge).Error
118 | if err != nil {
119 | return fmt.Errorf("failed to update message is_context_edge: %w", err)
120 | }
121 | return nil
122 | }
123 |
124 | // SetVariant 设置消息为变体消息
125 | func (d *msgDao) SetVariant(ctx context.Context, msgID string, isVariant bool) error {
126 | err := d.db.WithContext(ctx).Model(&model.Message{}).Where("msg_id = ?", msgID).Update("is_variant", isVariant).Error
127 | if err != nil {
128 | return fmt.Errorf("failed to update message is_variant: %w", err)
129 | }
130 | return nil
131 | }
132 |
--------------------------------------------------------------------------------
/internal/dao/kb_dao.go:
--------------------------------------------------------------------------------
1 | package dao
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "fmt"
6 | "gorm.io/gorm"
7 | )
8 |
9 | type KnowledgeBaseDao interface {
10 | GetDB() *gorm.DB
11 | // 知识库相关
12 | CreateKB(kb *model.KnowledgeBase) error // 创建知识库
13 | DeleteKB(id string) error // 删除知识库
14 | CountKBs(userID uint) (int64, error) // 统计知识库数量
15 | ListKBs(userID uint, page int, pageSize int) ([]model.KnowledgeBase, error) // 获取知识库列表
16 | GetKBByID(kb_id string) (*model.KnowledgeBase, error) // 获取知识库
17 |
18 | // 文档相关
19 | CreateDocument(doc *model.Document) error // 创建文档
20 | UpdateDocument(doc *model.Document) error // 更新文档
21 | CountDocs(id string) (int64, error) // 统计文档数量
22 | ListDocs(id string, page int, size int) ([]model.Document, error) // 获取文档列表
23 | GetAllDocsByKBID(kbID string) ([]model.Document, error) // 获取知识库下所有文档
24 | DeleteDocsByKBID(kbID string) error // 删除知识库下所有文档
25 | BatchDeleteDocs(userID uint, docIDs []string) error // 批量删除文档
26 | }
27 |
28 | type kbDao struct {
29 | db *gorm.DB
30 | }
31 |
32 | func NewKnowledgeBaseDao(db *gorm.DB) KnowledgeBaseDao { return &kbDao{db: db} }
33 |
34 | func (kd *kbDao) GetDB() *gorm.DB {
35 | return kd.db
36 | }
37 |
38 | func (kd *kbDao) CreateKB(kb *model.KnowledgeBase) error {
39 | result := kd.db.Create(kb)
40 | if result.Error != nil {
41 | return result.Error
42 | }
43 | return nil
44 | }
45 |
46 | func (kd *kbDao) GetKBByID(kb_id string) (*model.KnowledgeBase, error) {
47 | kb := &model.KnowledgeBase{}
48 | if err := kd.db.Where("id = ?", kb_id).First(kb).Error; err != nil {
49 | return nil, err
50 | }
51 | return kb, nil
52 | }
53 |
54 | func (kd *kbDao) CountKBs(userID uint) (int64, error) {
55 | var total int64
56 | query := kd.db.Model(&model.KnowledgeBase{}).Where("user_id = ?", userID)
57 |
58 | if err := query.Count(&total).Error; err != nil {
59 | return 0, err
60 | }
61 | return total, nil
62 | }
63 | func (kd *kbDao) ListKBs(userID uint, page int, pageSize int) ([]model.KnowledgeBase, error) {
64 | var kbs []model.KnowledgeBase
65 | query := kd.db.Where("user_id = ?", userID).Order("created_at desc")
66 |
67 | offset := (page - 1) * pageSize
68 | query = query.Offset(offset).Limit(pageSize)
69 |
70 | if err := query.Find(&kbs).Error; err != nil {
71 | return nil, err
72 | }
73 | return kbs, nil
74 | }
75 | func (kd *kbDao) CountDocs(kbID string) (int64, error) {
76 | var total int64
77 | query := kd.db.Model(&model.Document{}).Where("knowledge_base_id = ?", kbID)
78 | if err := query.Count(&total).Error; err != nil {
79 | return 0, err
80 | }
81 | return total, nil
82 | }
83 |
84 | func (kd *kbDao) ListDocs(kbID string, page int, size int) ([]model.Document, error) {
85 | var docs []model.Document
86 | query := kd.db.Where("knowledge_base_id = ?", kbID).Order("created_at asc")
87 |
88 | offset := (page - 1) * size
89 | query = query.Offset(offset).Limit(size)
90 | if err := query.Find(&docs).Error; err != nil {
91 | return nil, err
92 | }
93 | return docs, nil
94 | }
95 | func (kd *kbDao) DeleteKB(id string) error {
96 | return kd.db.Where("id = ?", id).Delete(&model.KnowledgeBase{}).Error
97 | }
98 |
99 | func (kd *kbDao) CreateDocument(doc *model.Document) error {
100 | return kd.db.Create(doc).Error
101 | }
102 |
103 | func (kd *kbDao) UpdateDocument(doc *model.Document) error {
104 | if err := kd.db.Save(doc).Error; err != nil {
105 | return fmt.Errorf("更新文档失败: %w", err)
106 | }
107 | return nil
108 | }
109 |
110 | func (kd *kbDao) GetAllDocsByKBID(kbID string) ([]model.Document, error) {
111 | var docs []model.Document
112 | if err := kd.db.Where("knowledge_base_id = ?", kbID).Find(&docs).Error; err != nil {
113 | return nil, fmt.Errorf("获取文档失败: %w", err)
114 | }
115 | return docs, nil
116 | }
117 |
118 | func (kd *kbDao) DeleteDocsByKBID(kbID string) error {
119 | if err := kd.db.Where("knowledge_base_id = ?", kbID).Delete(&model.Document{}).Error; err != nil {
120 | return fmt.Errorf("删除文档失败: %w", err)
121 | }
122 | return nil
123 | }
124 |
125 | func (kd *kbDao) BatchDeleteDocs(userID uint, docIDs []string) error {
126 | if len(docIDs) == 0 {
127 | return nil
128 | }
129 | res := kd.db.Where("id IN (?) AND user_id = ?", docIDs, userID).Delete(&model.Document{})
130 | if res.Error != nil {
131 | return fmt.Errorf("db删除错误:%w", res.Error)
132 | }
133 |
134 | if res.RowsAffected != int64(len(docIDs)) {
135 | return fmt.Errorf("expected to delete %d records, but deleted %d", len(docIDs), res.RowsAffected)
136 | }
137 | return nil
138 | }
139 |
--------------------------------------------------------------------------------
/internal/dao/model_dao.go:
--------------------------------------------------------------------------------
1 | package dao
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "context"
6 | "errors"
7 | "gorm.io/gorm"
8 | )
9 |
10 | type ModelDao interface {
11 | Create(ctx context.Context, m *model.Model) error
12 | Update(ctx context.Context, m *model.Model) error
13 | Delete(ctx context.Context, userID uint, modelID string) error
14 | GetByID(ctx context.Context, userID uint, modelID string) (*model.Model, error)
15 | List(ctx context.Context, userID uint, modelType string) ([]*model.Model, error)
16 | Page(ctx context.Context, userID uint, modelType string, page, size int) ([]*model.Model, int64, error)
17 | }
18 |
19 | type modelDao struct {
20 | db *gorm.DB
21 | }
22 |
23 | func NewModelDao(db *gorm.DB) ModelDao {
24 | return &modelDao{db: db}
25 | }
26 |
27 | func (d *modelDao) Create(ctx context.Context, m *model.Model) error {
28 | return d.db.WithContext(ctx).Create(m).Error
29 | }
30 |
31 | func (d *modelDao) Update(ctx context.Context, m *model.Model) error {
32 | // 检查模型是否属于该用户
33 | var count int64
34 | if err := d.db.WithContext(ctx).Model(&model.Model{}).Where("id = ? AND user_id = ?", m.ID, m.UserID).Count(&count).Error; err != nil {
35 | return err
36 | }
37 | if count == 0 {
38 | return errors.New("模型不存在或无权限")
39 | }
40 |
41 | // 只更新允许修改的字段,排除CreatedAt
42 | return d.db.WithContext(ctx).Model(m).
43 | Select(
44 | "ShowName", "Server", "BaseURL", "ModelName", "APIKey",
45 | "Dimension", "MaxOutputLength", "Function", "MaxTokens",
46 | ).
47 | Updates(m).Error
48 | }
49 |
50 | func (d *modelDao) Delete(ctx context.Context, userID uint, id string) error {
51 | result := d.db.WithContext(ctx).Where("id = ? AND user_id = ?", id, userID).Delete(&model.Model{})
52 | if result.RowsAffected == 0 {
53 | return errors.New("模型不存在或无权限")
54 | }
55 | return result.Error
56 | }
57 |
58 | func (d *modelDao) GetByID(ctx context.Context, userID uint, id string) (*model.Model, error) {
59 | var m model.Model
60 | err := d.db.WithContext(ctx).Where("id = ? AND user_id = ?", id, userID).First(&m).Error
61 | if err != nil {
62 | if errors.Is(err, gorm.ErrRecordNotFound) {
63 | return nil, errors.New("模型不存在或无权限")
64 | }
65 | return nil, err
66 | }
67 | return &m, nil
68 | }
69 |
70 | func (d *modelDao) Page(ctx context.Context, userID uint, modelType string, page, size int) ([]*model.Model, int64, error) {
71 | var models []*model.Model
72 | var count int64
73 |
74 | db := d.db.WithContext(ctx).Model(&model.Model{}).Where("user_id = ?", userID)
75 | if modelType != "" {
76 | db = db.Where("type = ?", modelType)
77 | }
78 |
79 | err := db.Count(&count).Offset((page - 1) * size).Limit(size).Find(&models).Error
80 | return models, count, err
81 | }
82 |
83 | func (d *modelDao) List(ctx context.Context, userID uint, modelType string) ([]*model.Model, error) {
84 | var models []*model.Model
85 | db := d.db.WithContext(ctx).Where("user_id = ?", userID)
86 | if modelType != "" {
87 | db = db.Where("type = ?", modelType)
88 | }
89 | if err := db.Find(&models).Error; err != nil {
90 | return nil, err
91 | }
92 | return models, nil
93 | }
94 |
--------------------------------------------------------------------------------
/internal/dao/user_dao.go:
--------------------------------------------------------------------------------
1 | package dao
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 |
6 | "gorm.io/gorm"
7 | )
8 |
9 | type UserDao interface {
10 | CheckFieldExists(field string, value interface{}) (bool, error)
11 | CreateUser(user *model.User) error
12 | GetUserByName(name string) (user *model.User, err error)
13 | }
14 |
15 | type userDao struct {
16 | db *gorm.DB
17 | }
18 |
19 | func NewUserDao(db *gorm.DB) UserDao {
20 | return &userDao{db: db}
21 | }
22 |
23 | // CheckFieldExists 检查字段是否存在
24 | func (ud *userDao) CheckFieldExists(field string, value interface{}) (bool, error) {
25 | var count int64
26 | if err := ud.db.Model(&model.User{}).Where(field+" = ?", value).Count(&count).Error; err != nil {
27 | return false, err
28 | }
29 | return count > 0, nil
30 | }
31 |
32 | func (ud *userDao) CreateUser(user *model.User) error {
33 | result := ud.db.Create(user)
34 |
35 | if result.Error != nil {
36 | return result.Error
37 | }
38 | return nil
39 | }
40 |
41 | func (ud *userDao) GetUserByName(name string) (*model.User, error) {
42 | var user model.User
43 | result := ud.db.Model(&model.User{}).Where("username = ?", name).First(&user)
44 | if result.Error != nil {
45 | return nil, result.Error
46 | }
47 | return &user, nil
48 | }
49 |
--------------------------------------------------------------------------------
/internal/database/milvus.go:
--------------------------------------------------------------------------------
1 | package database
2 |
3 | import (
4 | "ai-cloud/config"
5 | "context"
6 | "github.com/milvus-io/milvus-sdk-go/v2/client"
7 | "sync"
8 | )
9 |
10 | var (
11 | once sync.Once
12 | instance client.Client
13 | err error
14 | )
15 |
16 | // InitMilvus 初始化
17 | func InitMilvus(ctx context.Context) (client.Client, error) {
18 |
19 | once.Do(func() {
20 | instance, err = client.NewClient(ctx, client.Config{
21 | Address: config.GetConfig().Milvus.Address,
22 | })
23 |
24 | })
25 | return instance, err
26 | }
27 |
28 | func GetMilvusClient() client.Client {
29 | return instance
30 | }
31 |
--------------------------------------------------------------------------------
/internal/database/mysql.go:
--------------------------------------------------------------------------------
1 | package database
2 |
3 | import (
4 | "ai-cloud/config"
5 | "ai-cloud/internal/model"
6 | "fmt"
7 | "gorm.io/driver/mysql"
8 | "gorm.io/gorm"
9 | "sync"
10 | )
11 |
12 | var (
13 | db *gorm.DB
14 | dbOnce sync.Once
15 | dbErr error
16 | )
17 |
18 | // GetDB 获取数据库单例
19 | func GetDB() (*gorm.DB, error) {
20 | dbOnce.Do(func() {
21 | // 构造 DSN
22 | dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
23 | config.AppConfigInstance.Database.User,
24 | config.AppConfigInstance.Database.Password,
25 | config.AppConfigInstance.Database.Host,
26 | config.AppConfigInstance.Database.Port,
27 | config.AppConfigInstance.Database.Name,
28 | )
29 |
30 | // 连接数据库
31 | db, dbErr = gorm.Open(mysql.Open(dsn), &gorm.Config{})
32 | if dbErr != nil {
33 | return
34 | }
35 |
36 | // 自动迁移, 创建表结构
37 | if err := db.AutoMigrate(
38 | &model.User{},
39 | &model.File{},
40 | &model.KnowledgeBase{},
41 | &model.Document{},
42 | &model.Model{},
43 | &model.Agent{},
44 | // 会话记录相关
45 | &model.Conversation{},
46 | &model.Message{},
47 | &model.Attachment{},
48 | &model.MessageAttachment{},
49 | ); err != nil {
50 | dbErr = err
51 | return
52 | }
53 | })
54 |
55 | return db, dbErr
56 | }
57 |
--------------------------------------------------------------------------------
/internal/middleware/auth.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "ai-cloud/pkgs/errcode"
5 | "errors"
6 | "net/http"
7 | "strings"
8 |
9 | "github.com/gin-gonic/gin"
10 | "github.com/golang-jwt/jwt/v5"
11 | )
12 |
13 | const (
14 | AuthHeaderKey = "Authorization"
15 | AuthBearerType = "Bearer"
16 | )
17 |
18 | func JWTAuth() gin.HandlerFunc {
19 | return func(c *gin.Context) {
20 | // 获取 Authorization 头
21 | authHeader := c.GetHeader(AuthHeaderKey)
22 | if authHeader == "" {
23 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
24 | "code": errcode.TokenMissing,
25 | "message": "需要认证令牌",
26 | })
27 | return
28 | }
29 |
30 | // 分割 Bearer 和令牌
31 | parts := strings.SplitN(authHeader, " ", 2)
32 | if len(parts) != 2 || parts[0] != AuthBearerType {
33 | c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
34 | "code": errcode.TokenInvalid,
35 | "message": "令牌格式错误",
36 | })
37 | return
38 | }
39 |
40 | // 解析验证令牌
41 | tokenString := parts[1]
42 | claims, err := ParseToken(tokenString)
43 | if err != nil {
44 | status := http.StatusUnauthorized
45 | code := errcode.TokenInvalid
46 | message := "无效令牌"
47 |
48 | // 新版错误处理
49 | switch {
50 | case errors.Is(err, jwt.ErrTokenExpired):
51 | message = "令牌已过期"
52 | code = errcode.TokenExpired
53 | status = http.StatusForbidden
54 | case errors.Is(err, jwt.ErrTokenMalformed):
55 | message = "令牌格式错误"
56 | case errors.Is(err, jwt.ErrSignatureInvalid):
57 | message = "签名验证失败"
58 | }
59 |
60 | c.AbortWithStatusJSON(status, gin.H{
61 | "code": code,
62 | "message": message,
63 | })
64 | return
65 | }
66 |
67 | // 将用户ID存入上下文
68 | c.Set("user_id", claims.UserID)
69 | c.Next()
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/internal/middleware/cors.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "ai-cloud/config"
5 | "github.com/gin-contrib/cors"
6 | "github.com/gin-gonic/gin"
7 | "time"
8 | )
9 |
10 | // SetupCORS 封装CORS配置
11 | func SetupCORS() gin.HandlerFunc {
12 | corsConfig := config.GetConfig().CORS
13 |
14 | maxAge, err := time.ParseDuration(corsConfig.MaxAge)
15 | if err != nil {
16 | maxAge = 12 * time.Hour
17 | }
18 |
19 | return cors.New(cors.Config{
20 | AllowOrigins: corsConfig.AllowOrigins, // 允许所有域名
21 | AllowMethods: corsConfig.AllowMethods, // 允许的HTTP方法
22 | AllowHeaders: corsConfig.AllowHeaders, // 允许的请求头
23 | ExposeHeaders: corsConfig.ExposeHeaders, // 暴露的响应头
24 | AllowCredentials: corsConfig.AllowCredentials, // 允许携带凭证(如Cookie)
25 | MaxAge: maxAge, // 预检请求缓存时间
26 | })
27 | }
28 |
--------------------------------------------------------------------------------
/internal/middleware/jwt.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "ai-cloud/config"
5 | "time"
6 |
7 | "github.com/golang-jwt/jwt/v5"
8 | )
9 |
10 | // Claims is the set of claims carried by the JWTs this package issues and parses.
11 | // UserID is the user's identifier, encoded under the "user_id" key.
12 | // RegisteredClaims embeds the standard JWT claims such as expiry and issued-at time.
13 | type Claims struct {
14 | UserID uint `json:"user_id"`
15 | jwt.RegisteredClaims
16 | }
17 |
18 | // GenerateToken 为给定的用户ID生成JWT。
19 | // userID 是要编码到JWT中的用户ID。
20 | // 函数返回生成的JWT字符串和可能的错误。
21 | func GenerateToken(userID uint) (string, error) {
22 | cfg := config.AppConfigInstance.JWT
23 | expirationTime := time.Now().Add(time.Duration(cfg.ExpirationHours) * time.Hour)
24 |
25 | claims := &Claims{
26 | UserID: userID,
27 | RegisteredClaims: jwt.RegisteredClaims{
28 | ExpiresAt: jwt.NewNumericDate(expirationTime),
29 | IssuedAt: jwt.NewNumericDate(time.Now()),
30 | NotBefore: jwt.NewNumericDate(time.Now()),
31 | },
32 | }
33 |
34 | token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
35 | return token.SignedString([]byte(cfg.Secret))
36 | }
37 |
38 | // ParseToken 解析JWT并验证其有效性。
39 | // tokenString 是待解析的JWT字符串。
40 | // 函数返回解析出的Claims指针和可能的错误。
41 | func ParseToken(tokenString string) (*Claims, error) {
42 | cfg := config.AppConfigInstance.JWT
43 |
44 | token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
45 | return []byte(cfg.Secret), nil
46 | })
47 |
48 | if err != nil {
49 | return nil, err
50 | }
51 |
52 | if claims, ok := token.Claims.(*Claims); ok && token.Valid {
53 | return claims, nil
54 | }
55 |
56 | return nil, jwt.ErrInvalidKey
57 | }
58 |
--------------------------------------------------------------------------------
/internal/model/agent.go:
--------------------------------------------------------------------------------
1 | package model
2 |
3 | import (
4 | "github.com/cloudwego/eino/schema"
5 | "time"
6 | )
7 |
8 | // Agent is the persisted agent record. AgentSchema is stored in a JSON column as a string; the agent service unmarshals it into the AgentSchema type.
9 | type Agent struct {
10 | ID string `gorm:"primaryKey;type:char(36)"`
11 | UserID uint `gorm:"index"`
12 | Name string `gorm:"not null"`
13 | Description string `gorm:"type:text"`
14 | AgentSchema string `gorm:"type:json"`
15 | CreatedAt time.Time `gorm:"autoCreateTime"`
16 | UpdatedAt time.Time `gorm:"autoUpdateTime"`
17 | }
18 |
19 | // AgentSchema is the agent configuration: LLM settings, MCP servers, tools, system prompt and knowledge bases.
20 | type AgentSchema struct {
21 | LLMConfig LLMConfig `json:"llm_config"`
22 | MCP MCPConfig `json:"mcp"`
23 | Tools ToolsConfig `json:"tools"`
24 | Prompt string `json:"prompt"`
25 | Knowledge KnowledgeConfig `json:"knowledge"`
26 | }
27 |
28 | // TODO: refine this section later to make it more comprehensive.
29 | // LLMConfig configures the LLM model associated with an agent.
30 | type LLMConfig struct {
31 | ModelID string `json:"model_id"`
32 | Temperature float64 `json:"temperature"`
33 | TopP float64 `json:"top_p"`
34 | MaxOutputLength int `json:"max_output_length"`
35 | Thinking bool `json:"thinking"`
36 | }
37 |
38 | // MCPConfig lists the MCP SSE servers of an agent.
39 | type MCPConfig struct {
40 | Servers []string `json:"servers"`
41 | }
42 |
43 | // ToolsConfig lists the IDs of tools associated with an agent (given that MCP exists, the Tools module is not implemented yet).
44 | type ToolsConfig struct {
45 | ToolIDs []string `json:"tool_ids"`
46 | }
47 |
48 | // KnowledgeConfig lists the knowledge-base IDs associated with an agent and the TopK value handed to the retriever.
49 | type KnowledgeConfig struct {
50 | KnowledgeIDs []string `json:"knowledge_ids"`
51 | TopK int `json:"top_k"`
52 | }
53 |
54 | // CreateAgentRequest is the request body for creating an agent; name is required.
55 | type CreateAgentRequest struct {
56 | Name string `json:"name" binding:"required"`
57 | Description string `json:"description"`
58 | }
59 |
60 | // UpdateAgentRequest is the request body for updating an agent; id is required, the remaining fields mirror AgentSchema.
61 | type UpdateAgentRequest struct {
62 | ID string `json:"id" binding:"required"`
63 | Name string `json:"name"`
64 | Description string `json:"description"`
65 | LLMConfig LLMConfig `json:"llm_config"`
66 | MCP MCPConfig `json:"mcp"`
67 | Tools ToolsConfig `json:"tools"`
68 | Prompt string `json:"prompt"`
69 | Knowledge KnowledgeConfig `json:"knowledge"`
70 | }
71 |
// PageAgentRequest holds the paging parameters bound from the form/query values; page defaults to 1 and size to 10.
72 | type PageAgentRequest struct {
73 | Page int `form:"page,default=1"`
74 | Size int `form:"size,default=10"`
75 | }
76 |
// UserMessage is the input of the agent graph: the required user query plus optional prior messages.
77 | type UserMessage struct {
78 | Query string `json:"query" binding:"required"`
79 | History []*schema.Message `json:"history"`
80 | }
81 |
// ExecuteAgentRequest is the request body for executing an agent; all three fields are required.
// NOTE(review): both id and agent_id are mandatory — how they differ is not visible here; confirm against the controller.
82 | type ExecuteAgentRequest struct {
83 | ID string `json:"id" binding:"required"`
84 | AgentID string `json:"agent_id" binding:"required"`
85 | Message UserMessage `json:"message" binding:"required"`
86 | }
87 |
--------------------------------------------------------------------------------
/internal/model/chat.go:
--------------------------------------------------------------------------------
1 | package model
2 |
3 | import "github.com/cloudwego/eino/schema"
4 |
// ChatResponse carries a response string together with the documents listed as its references.
5 | type ChatResponse struct {
6 | Response string `json:"response"`
7 | References []*schema.Document `json:"references"`
8 | }
9 |
// ChatRequest is a chat request body: the query and a list of "kbs" strings (presumably knowledge-base IDs — confirm against the controller).
10 | type ChatRequest struct {
11 | Query string `json:"query"`
12 | KBs []string `json:"kbs"`
13 | }
14 |
15 | // ChatStreamResponse is an OpenAI-compatible streaming response chunk.
16 | type ChatStreamResponse struct {
17 | ID string `json:"id"`
18 | Object string `json:"object"`
19 | Created int64 `json:"created"`
20 | Model string `json:"model"`
21 | Choices []ChatStreamChoice `json:"choices"`
22 | }
23 |
// ChatStreamChoice is one entry of ChatStreamResponse.Choices. FinishReason is a pointer, so it encodes as JSON null when unset.
24 | type ChatStreamChoice struct {
25 | Delta ChatStreamDelta `json:"delta"`
26 | Index int `json:"index"`
27 | FinishReason *string `json:"finish_reason"`
28 | }
29 |
// ChatStreamDelta is the incremental payload of a streamed choice; empty fields are omitted from the JSON.
30 | type ChatStreamDelta struct {
31 | Role string `json:"role,omitempty"`
32 | Content string `json:"content,omitempty"`
33 | }
34 |
--------------------------------------------------------------------------------
/internal/model/conversation.go:
--------------------------------------------------------------------------------
1 | /*
2 | * This file contains modified code from the eino-history project.
3 | * Original source: https://github.com/HildaM/eino-history
4 | * Licensed under the Apache License, Version 2.0.
5 | * Modifications are made by RaspberryCola.
6 | */
7 |
8 | package model
9 |
10 | import "encoding/json"
11 |
12 | // Conversation maps to the conversations table. ConvID is unique; UserID and AgentID are indexed.
13 | type Conversation struct {
14 | ID uint64 `gorm:"primaryKey;column:id"`
15 | ConvID string `gorm:"uniqueIndex;column:conv_id;type:varchar(255)"`
16 | UserID uint `gorm:"index;column:user_id"`
17 | AgentID string `gorm:"index;column:agent_id;type:varchar(255)"`
18 | Title string `gorm:"column:title;type:varchar(255)"`
19 | CreatedAt int64 `gorm:"column:created_at"`
20 | UpdatedAt int64 `gorm:"column:updated_at"`
21 | Settings json.RawMessage `gorm:"column:settings;type:json"`
22 | IsArchived bool `gorm:"column:is_archived;default:0"`
23 | IsPinned bool `gorm:"column:is_pinned;default:0"`
24 | }
25 |
26 | // TableName returns the table name for Conversation.
27 | func (Conversation) TableName() string {
28 | return "conversations"
29 | }
30 |
31 | // Message maps to the messages table. MsgID is unique; Role and Status are MySQL enum columns.
32 | type Message struct {
33 | ID uint64 `gorm:"primaryKey;column:id"`
34 | MsgID string `gorm:"uniqueIndex;column:msg_id;type:varchar(255)"`
35 | UserID uint `gorm:"index;column:user_id"`
36 | ConvID string `gorm:"column:conv_id;type:varchar(255)"`
37 | ParentID string `gorm:"column:parent_id;type:varchar(255);default:''"`
38 | Role string `gorm:"column:role;type:enum('user','assistant','system','function')"`
39 | Content string `gorm:"column:content;type:text"`
40 | CreatedAt int64 `gorm:"column:created_at"`
41 | OrderSeq int `gorm:"column:order_seq;default:0"`
42 | TokenCount int `gorm:"column:token_count;default:0"`
43 | Status string `gorm:"column:status;type:enum('sent','pending','error');default:'sent'"`
44 | Metadata json.RawMessage `gorm:"column:metadata;type:json"`
45 | IsContextEdge bool `gorm:"column:is_context_edge;default:0"`
46 | IsVariant bool `gorm:"column:is_variant;default:0"`
47 | }
48 |
49 | // TableName returns the table name for Message.
50 | func (Message) TableName() string {
51 | return "messages"
52 | }
53 |
54 | // Attachment maps to the attachments table. AttachID is unique; AttachmentType and StorageType are MySQL enum columns.
55 | type Attachment struct {
56 | ID uint64 `gorm:"primaryKey;column:id"`
57 | AttachID string `gorm:"uniqueIndex;column:attach_id;type:varchar(255)"`
58 | UserID uint `gorm:"index;column:user_id"`
59 | MessageID string `gorm:"column:message_id;type:varchar(255)"`
60 | AttachmentType string `gorm:"column:attachment_type;type:enum('file','image','code','audio','video')"`
61 | FileName string `gorm:"column:file_name;type:varchar(255)"`
62 | FileSize int64 `gorm:"column:file_size"`
63 | StorageType string `gorm:"column:storage_type;type:enum('path','blob','cloud')"`
64 | StoragePath string `gorm:"column:storage_path;type:varchar(1024)"`
65 | Thumbnail []byte `gorm:"column:thumbnail;type:mediumblob"`
66 | Vectorized bool `gorm:"column:vectorized;default:0"`
67 | DataSummary string `gorm:"column:data_summary;type:text"`
68 | MimeType string `gorm:"column:mime_type;type:varchar(255)"`
69 | CreatedAt int64 `gorm:"column:created_at"`
70 | }
71 |
72 | // TableName returns the table name for Attachment.
73 | func (Attachment) TableName() string {
74 | return "attachments"
75 | }
76 |
77 | // MessageAttachment maps to the message_attachments table, which links messages to attachments.
78 | type MessageAttachment struct {
79 | ID uint64 `gorm:"primaryKey;column:id"`
80 | MessageID string `gorm:"column:message_id;type:varchar(255)"`
81 | AttachmentID string `gorm:"column:attachment_id;type:varchar(255)"`
82 | }
83 |
84 | // TableName returns the table name for MessageAttachment.
85 | func (MessageAttachment) TableName() string {
86 | return "message_attachments"
87 | }
88 |
89 | // DebugRequest carries the agent ID and the message; both are required.
90 | type DebugRequest struct {
91 | AgentID string `json:"agent_id" binding:"required"`
92 | Message string `json:"message" binding:"required"`
93 | }
94 |
95 | // CreateConvRequest is the request body for creating a conversation; agent_id is required.
96 | type CreateConvRequest struct {
97 | AgentID string `json:"agent_id" binding:"required"`
98 | }
99 |
100 | // ConvRequest is the request body for sending a message in a conversation; agent_id and message are required, conv_id is optional.
101 | type ConvRequest struct {
102 | AgentID string `json:"agent_id" binding:"required"`
103 | Message string `json:"message" binding:"required"`
104 | ConvID string `json:"conv_id"`
105 | }
106 |
--------------------------------------------------------------------------------
/internal/model/file.go:
--------------------------------------------------------------------------------
1 | package model
2 |
3 | import "time"
4 |
// File is the persisted record of a file or directory owned by a user.
5 | type File struct {
6 | ID string `gorm:"primaryKey;type:char(36)"` // UUID
7 | UserID uint `gorm:"index"` // owner user ID
8 | Name string `gorm:"not null"` // file name
9 | Size int64 // file size
10 | Hash string `gorm:"index;size:64"` // file hash (SHA-256)
11 | MIMEType string // MIME type
12 | IsDir bool `gorm:"default:false"` // whether this entry is a directory
13 | ParentID *string `gorm:"type:char(36);index"` // parent directory ID
14 | StorageType string `gorm:"default:'local'"` // storage type: local/oss
15 | StorageKey string // unique storage identifier (path or OSS key)
16 | CreatedAt time.Time `gorm:"autoCreateTime"` // creation time
17 | UpdatedAt time.Time `gorm:"autoUpdateTime"` // last update time
18 | }
19 |
// CreateFolderReq is the request body for creating a folder; parent_id is optional.
20 | type CreateFolderReq struct {
21 | Name string `json:"name"`
22 | ParentID *string `json:"parent_id,omitempty"`
23 | }
24 |
// BatchMoveRequest is the request body for moving several files to a target parent.
// NOTE(review): FileIDs is bound from the JSON key "files_pid", which reads like a parent ID — confirm against the client before renaming.
25 | type BatchMoveRequest struct {
26 | FileIDs []string `json:"files_pid" binding:"required"`
27 | TargetParentID string `json:"target_pid"`
28 | }
29 |
// RenameRequest is the request body for renaming a file; both fields are required.
30 | type RenameRequest struct {
31 | FileID string `json:"file_id" binding:"required"`
32 | NewName string `json:"new_name" binding:"required"`
33 | }
34 |
--------------------------------------------------------------------------------
/internal/model/knowledge.go:
--------------------------------------------------------------------------------
1 | package model
2 |
3 | import (
4 | "time"
5 | )
6 |
7 | // KnowledgeBase is the persisted record of a knowledge base.
8 | type KnowledgeBase struct {
9 | ID string `gorm:"primaryKey;type:char(36)"` // UUID
10 | Name string `gorm:"not null"` // knowledge base name
11 | Description string // knowledge base description
12 | UserID uint `gorm:"index"` // creator's user ID
13 | EmbedModelID string `gorm:"index"` // ID of the associated embedding model
14 | MilvusCollection string `gorm:"not null"` // name of the corresponding Milvus collection
15 | CreatedAt time.Time `gorm:"autoCreateTime"`
16 | UpdatedAt time.Time `gorm:"autoUpdateTime"`
17 | }
18 |
19 | // Document is a document that belongs to a knowledge base.
20 | type Document struct {
21 | ID string `gorm:"primaryKey;type:char(36)"` // UUID
22 | UserID uint `gorm:"index"` // owning user
23 | KnowledgeBaseID string `gorm:"index"` // ID of the owning knowledge base
24 | FileID string `gorm:"index"` // ID of the associated file
25 | Title string // document title
26 | DocType string // document type (pdf/txt/md)
27 | Status int // processing status (0: pending, 1: processing, 2: done, 3: failed)
28 | CreatedAt time.Time `gorm:"autoCreateTime"`
29 | UpdatedAt time.Time `gorm:"autoUpdateTime"`
30 | }
31 |
32 | // Chunk is a document chunk stored in Milvus.
33 | type Chunk struct {
34 | ID string `json:"id"`
35 | Content string `json:"content"` // chunk content
36 | KBID string `json:"kb_id"` // knowledge base ID (for knowledge-base-level retrieval)
37 | DocumentID string `json:"document_id"` // document ID
38 | DocumentName string `json:"document_name"` // document name
39 | Index int `json:"index"` // position of the chunk within the document
40 | Embeddings []float32 `json:"embeddings"` // chunk embedding vector
41 | Score float32 `json:"score"` // score returned with retrieval results
42 | }
43 |
// AddFileRequest is the request body pairing a file ID with a knowledge-base ID.
44 | type AddFileRequest struct {
45 | FileID string `json:"file_id"`
46 | KBID string `json:"kb_id"`
47 | }
48 |
// BatchDeleteDocsReq is the request body for deleting several documents; doc_ids is required.
49 | type BatchDeleteDocsReq struct {
50 | KBID string `json:"kb_id"`
51 | DocIDs []string `json:"doc_ids" binding:"required"`
52 | }
53 |
// CreateKBRequest is the request body for creating a knowledge base; name and embed_model_id are required.
54 | type CreateKBRequest struct {
55 | Name string `json:"name" binding:"required"`
56 | Description string `json:"description"`
57 | EmbedModelID string `json:"embed_model_id" binding:"required"`
58 | }
59 |
// RetrieveRequest is the request body for a retrieval query against one knowledge base.
60 | type RetrieveRequest struct {
61 | KBID string `json:"kb_id"`
62 | Query string `json:"query"`
63 | TopK int `json:"top_k"`
64 | }
65 |
--------------------------------------------------------------------------------
/internal/model/model.go:
--------------------------------------------------------------------------------
1 | package model
2 |
3 | import "time"
4 |
// Model is the persisted configuration of an embedding or LLM model owned by a user.
5 | type Model struct {
6 | // Basic information
7 | ID string `gorm:"primaryKey;type:char(36)"`
8 | UserID uint `gorm:"index"` // user ID
9 | Type string `gorm:"not null"` // model type: embedding/llm
10 | ShowName string `gorm:"not null"` // display name
11 | Server string `gorm:"not null"` // model provider: openai/ollama
12 | BaseURL string `gorm:"not null"` // API base URL
13 | ModelName string `gorm:"not null"` // model identifier, e.g. deepseek-chat, text-embedding-v3
14 | APIKey string // access key; usually not needed for ollama
15 |
16 | // Embedding model fields
17 | Dimension int // vector dimension (required for embedding models)
18 |
19 | // LLM model fields
20 | MaxOutputLength int `gorm:"default:4096"`
21 | Function bool `gorm:"default:false"`
22 |
23 | // Common fields
24 | MaxTokens int `gorm:"default:1024"` // limit on the maximum input length
25 | CreatedAt time.Time `gorm:"autoCreateTime"`
26 | UpdatedAt time.Time `gorm:"autoUpdateTime"`
27 | }
28 |
// CreateModelRequest is the request body for creating a model. type must be "embedding" or "llm"; base_url must be a valid URL.
29 | type CreateModelRequest struct {
30 | // Basic information
31 | Type string `json:"type" binding:"required,oneof=embedding llm"`
32 | ShowName string `json:"name" binding:"required"`
33 | Server string `json:"server" binding:"required"`
34 | BaseURL string `json:"base_url" binding:"required,url"`
35 | ModelName string `json:"model" binding:"required"`
36 | APIKey string `json:"api_key"`
37 |
38 | // Embedding
39 | Dimension int `json:"dimension"`
40 |
41 | // LLM
42 | MaxOutputLength int `json:"max_output_length"`
43 | Function bool `json:"function"`
44 |
45 | // Common fields
46 | MaxTokens int `json:"max_tokens"`
47 | }
48 |
// PageModelRequest holds the paging parameters for listing models, with an optional type value; page defaults to 1 and size to 10.
49 | type PageModelRequest struct {
50 | Type string `form:"type"`
51 | Page int `form:"page,default=1"`
52 | Size int `form:"size,default=10"`
53 | }
54 |
// UpdateModelRequest is the request body for updating a model; only id is required.
55 | type UpdateModelRequest struct {
56 | ID string `json:"id" binding:"required"`
57 | // Basic information
58 | ShowName string `json:"name"`
59 | Server string `json:"server"`
60 | BaseURL string `json:"base_url"`
61 | ModelName string `json:"model"`
62 | APIKey string `json:"api_key"`
63 |
64 | // Embedding
65 | Dimension int `json:"dimension"`
66 |
67 | // LLM
68 | MaxOutputLength int `json:"max_output_length"`
69 | Function bool `json:"function"`
70 |
71 | // Common fields
72 | MaxTokens int `json:"max_tokens"`
73 | }
74 |
--------------------------------------------------------------------------------
/internal/model/user.go:
--------------------------------------------------------------------------------
1 | package model
2 |
3 | import (
4 | "time"
5 | )
6 |
// User is the persisted user account. Username, Phone and Email each carry a unique index.
7 | type User struct {
8 | ID uint `gorm:"primaryKey"`
9 | Username string `gorm:"uniqueIndex;size:50;not null"`
10 | Phone string `gorm:"uniqueIndex;size:20;not null"` // phone number field (added later)
11 | Email string `gorm:"uniqueIndex;size:100"`
12 | Password string `gorm:"not null"`
13 | CreatedAt time.Time `gorm:"autoCreateTime"`
14 | UpdatedAt time.Time `gorm:"autoUpdateTime"`
15 | }
16 |
// UserRegisterReq is the registration request body: username 3-20 chars, password 6-30 chars, a valid email and an E.164 phone number.
17 | type UserRegisterReq struct {
18 | Username string `json:"username" binding:"required,min=3,max=20"`
19 | Password string `json:"password" binding:"required,min=6,max=30"`
20 | Email string `json:"email" binding:"required,email"`
21 | Phone string `json:"phone" binding:"required,e164"`
22 | }
23 |
// LoginResponse is the login response body: the access token, its expires_in value and the token type.
24 | type LoginResponse struct {
25 | AccessToken string `json:"access_token"`
26 | ExpiresIn int `json:"expires_in"`
27 | TokenType string `json:"token_type"`
28 | }
// UserNameLoginReq is the request body for logging in with a username.
//
// The password length bounds match UserRegisterReq (6-30 characters): a
// stricter minimum here would reject, at validation time, accounts that were
// registered with a 6- or 7-character password.
type UserNameLoginReq struct {
	Username string `json:"username" binding:"required,min=3,max=20"`
	Password string `json:"password" binding:"required,min=6,max=30"`
}
33 |
// UserPhoneLogin is the request body for logging in with a phone number
// (E.164 format).
//
// The password length bounds match UserRegisterReq (6-30 characters): a
// stricter minimum here would reject, at validation time, accounts that were
// registered with a 6- or 7-character password.
type UserPhoneLogin struct {
	Phone    string `json:"phone" binding:"required,e164"`
	Password string `json:"password" binding:"required,min=6,max=30"`
}
38 |
--------------------------------------------------------------------------------
/internal/router/router.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "ai-cloud/internal/controller"
5 | "ai-cloud/internal/middleware"
6 |
7 | "github.com/gin-gonic/gin"
8 | )
9 |
10 | func SetUpRouters(r *gin.Engine, uc *controller.UserController, fc *controller.FileController, kc *controller.KBController, mc *controller.ModelController, ac *controller.AgentController, cc *controller.ConversationController) {
11 | api := r.Group("/api")
12 | {
13 |
14 | publicUser := api.Group("/users")
15 | {
16 | publicUser.POST("/register", uc.Register)
17 | publicUser.POST("/login", uc.Login)
18 | }
19 |
20 | auth := api.Group("files")
21 | auth.Use(middleware.JWTAuth())
22 | {
23 | auth.POST("/upload", fc.Upload)
24 | auth.GET("/page", fc.PageList)
25 | auth.GET("/download", fc.Download)
26 | auth.DELETE("/delete", fc.Delete)
27 | auth.POST("/folder", fc.CreateFolder)
28 | auth.POST("/move", fc.BatchMove)
29 | auth.GET("/search", fc.Search)
30 | auth.PUT("/rename", fc.Rename)
31 | auth.GET("/path", fc.GetPath)
32 | auth.GET("/idPath", fc.GetIDPath)
33 | }
34 | kb := api.Group("knowledge")
35 | kb.Use(middleware.JWTAuth())
36 | {
37 | // KB
38 | kb.POST("/create", kc.Create)
39 | kb.DELETE("/delete", kc.Delete)
40 | kb.POST("/add", kc.AddExistFile)
41 | kb.POST("/addNew", kc.AddNewFile)
42 | kb.GET("/page", kc.PageList)
43 | kb.GET("/detail", kc.GetKBDetail)
44 | // Doc
45 | kb.GET("/docPage", kc.DocPage)
46 | kb.POST("/docDelete", kc.DeleteDocs)
47 | // RAG
48 | kb.POST("/retrieve", kc.Retrieve)
49 | kb.POST("/chat", kc.Chat)
50 | kb.POST("/stream", kc.ChatStream)
51 | }
52 | model := api.Group("model")
53 | model.Use(middleware.JWTAuth())
54 | {
55 | model.POST("/create", mc.CreateModel)
56 | model.PUT("/update", mc.UpdateModel)
57 | model.DELETE("/delete", mc.DeleteModel)
58 | model.GET("/get", mc.GetModel)
59 | model.GET("/page", mc.PageModels)
60 | model.GET("/list", mc.ListModels)
61 | }
62 | agent := api.Group("agent")
63 | agent.Use(middleware.JWTAuth())
64 | {
65 | agent.POST("/create", ac.CreateAgent)
66 | agent.POST("/update", ac.UpdateAgent)
67 | agent.DELETE("/delete", ac.DeleteAgent)
68 | agent.GET("/get", ac.GetAgent)
69 | agent.GET("/page", ac.PageAgents)
70 | agent.POST("/execute/:id", ac.ExecuteAgent)
71 | agent.POST("/stream", ac.StreamExecuteAgent)
72 | }
73 | conv := api.Group("chat")
74 | conv.Use(middleware.JWTAuth())
75 | {
76 | // 调试模式,不保存历史
77 | conv.POST("/debug", cc.DebugStreamAgent)
78 | // 会话相关功能
79 | conv.POST("/create", cc.CreateConversation)
80 | conv.POST("/stream", cc.StreamConversation)
81 | conv.GET("/list", cc.ListConversations)
82 | conv.GET("/list/agent", cc.ListAgentConversations)
83 | conv.GET("/history", cc.GetConversationHistory)
84 | conv.DELETE("/delete", cc.DeleteConversation)
85 | }
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/internal/service/agent_service.go:
--------------------------------------------------------------------------------
1 | package service
2 |
3 | import (
4 | llmfactory "ai-cloud/internal/component/llm"
5 | mretriever "ai-cloud/internal/component/retriever/milvus"
6 | "ai-cloud/internal/dao"
7 | "ai-cloud/internal/model"
8 | "context"
9 | "encoding/json"
10 | "errors"
11 | "fmt"
12 | "time"
13 |
14 | mcpp "github.com/cloudwego/eino-ext/components/tool/mcp"
15 | "github.com/cloudwego/eino/components/prompt"
16 | "github.com/cloudwego/eino/components/tool"
17 | "github.com/cloudwego/eino/compose"
18 | "github.com/cloudwego/eino/flow/agent/react"
19 | "github.com/cloudwego/eino/schema"
20 | "github.com/mark3labs/mcp-go/client"
21 | "github.com/mark3labs/mcp-go/mcp"
22 | )
23 |
// Node keys used by buildGraph when adding nodes and edges to the agent graph.
24 | const (
25 | InputToQuery = "InputToQuery"
26 | InputToHistory = "InputToHistory"
27 | ChatTemplate = "ChatTemplate"
28 | ChatModel = "ChatModel"
29 | Retriever = "Retriever"
30 | Agent = "Agent"
31 | )
32 |
// AgentService exposes agent CRUD/listing plus blocking (ExecuteAgent) and streaming (StreamExecuteAgent) execution of an agent for a user.
33 | type AgentService interface {
34 | CreateAgent(ctx context.Context, agent *model.Agent) error
35 | UpdateAgent(ctx context.Context, agent *model.Agent) error
36 | DeleteAgent(ctx context.Context, userID uint, agentID string) error
37 | GetAgent(ctx context.Context, userID uint, agentID string) (*model.Agent, error)
38 | ListAgents(ctx context.Context, userID uint) ([]*model.Agent, error)
39 | PageAgents(ctx context.Context, userID uint, page, size int) ([]*model.Agent, int64, error)
40 | ExecuteAgent(ctx context.Context, userID uint, agentID string, msg model.UserMessage) (string, error)
41 | StreamExecuteAgent(ctx context.Context, userID uint, agentID string, msg model.UserMessage) (*schema.StreamReader[*schema.Message], error)
42 | }
43 |
// agentService is the AgentService implementation returned by NewAgentService; it holds the DAOs and services injected there.
44 | type agentService struct {
45 | dao dao.AgentDao
46 | modelSvc ModelService
47 | kbSvc KBService
48 | kbDao dao.KnowledgeBaseDao
49 | modelDao dao.ModelDao
50 | historySvc HistoryService
51 | }
52 |
53 | func NewAgentService(dao dao.AgentDao, modelSvc ModelService, kbSvc KBService, kbDao dao.KnowledgeBaseDao, modelDao dao.ModelDao, historySvc HistoryService) AgentService {
54 | return &agentService{
55 | dao: dao,
56 | modelSvc: modelSvc,
57 | kbSvc: kbSvc,
58 | kbDao: kbDao,
59 | modelDao: modelDao,
60 | historySvc: historySvc,
61 | }
62 | }
63 |
// CreateAgent stores agent by delegating to the agent DAO's Create.
64 | func (s *agentService) CreateAgent(ctx context.Context, agent *model.Agent) error {
65 | return s.dao.Create(ctx, agent)
66 | }
67 |
// UpdateAgent saves agent by delegating to the agent DAO's Update.
68 | func (s *agentService) UpdateAgent(ctx context.Context, agent *model.Agent) error {
69 | return s.dao.Update(ctx, agent)
70 | }
71 |
// DeleteAgent delegates to the agent DAO's Delete with the given user and agent IDs.
72 | func (s *agentService) DeleteAgent(ctx context.Context, userID uint, agentID string) error {
73 | return s.dao.Delete(ctx, userID, agentID)
74 | }
75 |
// GetAgent delegates to the agent DAO's GetByID with the given user and agent IDs.
76 | func (s *agentService) GetAgent(ctx context.Context, userID uint, agentID string) (*model.Agent, error) {
77 | return s.dao.GetByID(ctx, userID, agentID)
78 | }
79 |
// ListAgents delegates to the agent DAO's List for the given user ID.
80 | func (s *agentService) ListAgents(ctx context.Context, userID uint) ([]*model.Agent, error) {
81 | return s.dao.List(ctx, userID)
82 | }
83 |
// PageAgents delegates to the agent DAO's Page and returns its page of agents plus the int64 count it reports.
84 | func (s *agentService) PageAgents(ctx context.Context, userID uint, page, size int) ([]*model.Agent, int64, error) {
85 | return s.dao.Page(ctx, userID, page, size)
86 | }
87 |
88 | func (s *agentService) ExecuteAgent(ctx context.Context, userID uint, agentID string, msg model.UserMessage) (string, error) {
89 | // Retrieve the agent
90 | agent, err := s.dao.GetByID(ctx, userID, agentID)
91 | if err != nil {
92 | return "", err
93 | }
94 |
95 | // Parse the agent schema
96 | var agentSchema model.AgentSchema
97 | if err := json.Unmarshal([]byte(agent.AgentSchema), &agentSchema); err != nil {
98 | return "", err
99 | }
100 |
101 | graph, err := s.buildGraph(ctx, userID, agentSchema)
102 | if err != nil {
103 | return "", fmt.Errorf("buildGraph失败:%w", err)
104 | }
105 |
106 | runner, err := graph.Compile(ctx, compose.WithGraphName("EinoAgent"), compose.WithNodeTriggerMode(compose.AllPredecessor))
107 |
108 | if err != nil {
109 | return "", err
110 | }
111 |
112 | res, err := runner.Invoke(ctx, &msg)
113 | if err != nil {
114 | return "", err
115 | }
116 | return res.String(), nil
117 | }
118 |
119 | func (s *agentService) StreamExecuteAgent(ctx context.Context, userID uint, agentID string, msg model.UserMessage) (*schema.StreamReader[*schema.Message], error) {
120 | // 1.获取Agent配置
121 | agent, err := s.dao.GetByID(ctx, userID, agentID)
122 | if err != nil {
123 | return nil, err
124 | }
125 |
126 | var agentSchema model.AgentSchema
127 | if err := json.Unmarshal([]byte(agent.AgentSchema), &agentSchema); err != nil {
128 | return nil, err
129 | }
130 |
131 | // 2.构建Graph
132 | graph, err := s.buildGraph(ctx, userID, agentSchema)
133 | if err != nil {
134 | return nil, fmt.Errorf("failed to build agent graph:%w", err)
135 | }
136 |
137 | // 3.构建runner
138 | runner, err := graph.Compile(ctx, compose.WithGraphName("EinoAgent"), compose.WithNodeTriggerMode(compose.AllPredecessor))
139 | if err != nil {
140 | return nil, fmt.Errorf("failed to compile agent graph: %w", err)
141 | }
142 |
143 | // 执行stream
144 | sr, err := runner.Stream(ctx, &msg)
145 | if err != nil {
146 | return nil, fmt.Errorf("failed to stream: %w", err)
147 | }
148 |
149 | return sr, nil
150 | }
151 |
152 | func (s *agentService) buildGraph(ctx context.Context, userID uint, agentSchema model.AgentSchema) (*compose.Graph[*model.UserMessage, *schema.Message], error) {
153 | // 1. 创建LLM
154 | llmModelCfg, err := s.modelSvc.GetModel(ctx, userID, agentSchema.LLMConfig.ModelID)
155 | if err != nil {
156 | return nil, fmt.Errorf("failed to create get model:%w", err)
157 | }
158 | llm, err := llmfactory.GetLLMClient(ctx, llmModelCfg)
159 | if err != nil {
160 | return nil, fmt.Errorf("failed to create llm client:%w", err)
161 | }
162 |
163 | // 2. 构建 Retriever
164 | multiRetriever := mretriever.MultiKBRetriever{
165 | KBIDs: agentSchema.Knowledge.KnowledgeIDs,
166 | UserID: userID,
167 | KBDao: s.kbDao,
168 | ModelDao: s.modelDao,
169 | Ctx: ctx,
170 | TopK: agentSchema.Knowledge.TopK, // 默认返回前5个最相关的文档
171 | }
172 |
173 | // 3. 构建Tools
174 | tools := []tool.BaseTool{}
175 | // 3.1 加载MCPTools
176 | for _, serverURL := range agentSchema.MCP.Servers {
177 | cli, err := client.NewSSEMCPClient(serverURL)
178 | err = cli.Start(ctx)
179 | if err != nil {
180 | return nil, fmt.Errorf("failed to create mcp client: %w", err)
181 | }
182 | initRequest := mcp.InitializeRequest{}
183 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
184 | initRequest.Params.ClientInfo = mcp.Implementation{
185 | Name: "example-client",
186 | Version: "1.0.0",
187 | }
188 |
189 | _, err = cli.Initialize(ctx, initRequest)
190 |
191 | if err != nil {
192 | return nil, err
193 | }
194 | // 获取 mcpp 工具
195 | mcppTools, err := mcpp.GetTools(ctx, &mcpp.Config{Cli: cli})
196 | if err != nil {
197 | return nil, fmt.Errorf("failed to get mcpp tools: %w", err)
198 | }
199 | tools = append(tools, mcppTools...)
200 | }
201 | // 3.2 加载系统和用户自定义Tools
202 |
203 | // 4. 构建提示词
204 | promptTemplate := prompt.FromMessages(
205 | schema.FString,
206 | schema.SystemMessage(agentSchema.Prompt),
207 | schema.MessagesPlaceholder("history", true),
208 | schema.UserMessage("用户消息:{query}\n 参考信息:{documents}"),
209 | )
210 |
211 | // 5. 实现图编排
212 | graph := compose.NewGraph[*model.UserMessage, *schema.Message]()
213 | _ = graph.AddLambdaNode(InputToQuery, compose.InvokableLambdaWithOption(inputToQueryLambda), compose.WithNodeName("UserMessageToQuery"))
214 | _ = graph.AddChatTemplateNode(ChatTemplate, promptTemplate)
215 | _ = graph.AddRetrieverNode(Retriever, multiRetriever, compose.WithOutputKey("documents"))
216 | _ = graph.AddLambdaNode(InputToHistory, compose.InvokableLambdaWithOption(inputToHistoryLambda), compose.WithNodeName("UserMessageToHistory"))
217 |
218 | // 根据是否有工具决定使用Agent还是直接使用ChatModel
219 | if len(tools) > 0 {
220 | // 有工具时使用Agent
221 | agentConfig := &react.AgentConfig{
222 | ToolCallingModel: llm,
223 | MaxStep: 10,
224 | ToolsConfig: compose.ToolsNodeConfig{
225 | Tools: tools,
226 | },
227 | }
228 |
229 | agt, err := react.NewAgent(ctx, agentConfig)
230 | if err != nil {
231 | return nil, fmt.Errorf("failed to create agent: %w", err)
232 | }
233 | if agt == nil {
234 | return nil, errors.New("react.NewAgent returned a nil agent instance")
235 | }
236 |
237 | agentLambda, _ := compose.AnyLambda(agt.Generate, agt.Stream, nil, nil)
238 |
239 | _ = graph.AddLambdaNode(Agent, agentLambda, compose.WithNodeName("Agent"))
240 |
241 | _ = graph.AddEdge(compose.START, InputToQuery)
242 | _ = graph.AddEdge(compose.START, InputToHistory)
243 | _ = graph.AddEdge(InputToQuery, Retriever)
244 | _ = graph.AddEdge(Retriever, ChatTemplate)
245 | _ = graph.AddEdge(InputToHistory, ChatTemplate)
246 | _ = graph.AddEdge(ChatTemplate, Agent)
247 | _ = graph.AddEdge(Agent, compose.END)
248 | } else {
249 | // 没有工具时直接使用ChatModel
250 | _ = graph.AddChatModelNode(ChatModel, llm)
251 |
252 | _ = graph.AddEdge(compose.START, InputToQuery)
253 | _ = graph.AddEdge(compose.START, InputToHistory)
254 | _ = graph.AddEdge(InputToQuery, Retriever)
255 | _ = graph.AddEdge(Retriever, ChatTemplate)
256 | _ = graph.AddEdge(InputToHistory, ChatTemplate)
257 | _ = graph.AddEdge(ChatTemplate, ChatModel)
258 | _ = graph.AddEdge(ChatModel, compose.END)
259 | }
260 |
261 | return graph, nil
262 | }
263 |
264 | // inputToQueryLambda component initialization function of node 'InputToQuery' in graph 'EinoAgent'
265 | func inputToQueryLambda(ctx context.Context, input *model.UserMessage, opts ...any) (output string, err error) {
266 | return input.Query, nil
267 | }
268 |
269 | // inputToHistoryLambda component initialization function of node 'InputToHistory' in graph 'EinoAgent'
270 | func inputToHistoryLambda(ctx context.Context, input *model.UserMessage, opts ...any) (output map[string]any, err error) {
271 | return map[string]any{
272 | "query": input.Query,
273 | "history": input.History,
274 | "date": time.Now().Format(time.DateTime),
275 | }, nil
276 | }
277 |
--------------------------------------------------------------------------------
/internal/service/conversation_service.go:
--------------------------------------------------------------------------------
1 | package service
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "context"
6 | "errors"
7 | "fmt"
8 | "io"
9 | "log"
10 | "time"
11 |
12 | "github.com/cloudwego/eino/schema"
13 | "github.com/google/uuid"
14 | )
15 |
// defaultConvTitle is the title prefix given to newly created conversations.
var defaultConvTitle = "新对话"

// ConversationService exposes agent chat in two modes (stateless debug and
// persisted conversation) plus conversation management operations.
type ConversationService interface {
	// DebugStreamAgent runs the agent in debug mode: a temporary session whose history is not saved.
	DebugStreamAgent(ctx context.Context, userID uint, agentID string, message string) (*schema.StreamReader[*schema.Message], error)

	// StreamAgentWithConversation runs the agent in conversation mode: the conversation is created or fetched and history is recorded.
	StreamAgentWithConversation(ctx context.Context, userID uint, agentID string, convID string, message string) (*schema.StreamReader[*schema.Message], error)

	// CreateConversation creates a new conversation and returns its ID.
	CreateConversation(ctx context.Context, userID uint, agentID string) (string, error)

	// DeleteConversation deletes a conversation.
	DeleteConversation(ctx context.Context, convID string) error

	// ListConversations lists all conversations of a user, paginated.
	ListConversations(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error)

	// ListAgentConversations lists a user's conversations with a specific agent, paginated.
	ListAgentConversations(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error)

	// GetConversationHistory returns the history messages of a conversation.
	GetConversationHistory(ctx context.Context, convID string, limit int) ([]*schema.Message, error)
}

// conversationService is the ConversationService implementation; it delegates
// agent execution to agentSvc and persistence to historySvc.
type conversationService struct {
	agentSvc   AgentService
	historySvc HistoryService
}
45 |
46 | func NewConversationService(agentSvc AgentService, historySvc HistoryService) ConversationService {
47 | return &conversationService{
48 | agentSvc: agentSvc,
49 | historySvc: historySvc,
50 | }
51 | }
52 |
53 | // DebugStreamAgent 调试模式:临时会话,不保存历史
54 | func (s *conversationService) DebugStreamAgent(ctx context.Context, userID uint, agentID string, message string) (*schema.StreamReader[*schema.Message], error) {
55 | // 创建用户消息,不含历史
56 | userMsg := model.UserMessage{
57 | Query: message,
58 | History: []*schema.Message{},
59 | }
60 |
61 | // 调用无状态的StreamExecuteAgent
62 | return s.agentSvc.StreamExecuteAgent(ctx, userID, agentID, userMsg)
63 | }
64 |
65 | // StreamAgentWithConversation 会话模式:记录历史
66 | func (s *conversationService) StreamAgentWithConversation(ctx context.Context, userID uint, agentID string, convID string, message string) (*schema.StreamReader[*schema.Message], error) {
67 | // 确保会话存在
68 | conv := &model.Conversation{
69 | ConvID: convID,
70 | UserID: userID,
71 | AgentID: agentID,
72 | CreatedAt: time.Now().Unix(),
73 | UpdatedAt: time.Now().Unix(),
74 | }
75 | err := s.historySvc.CreateConversation(ctx, conv)
76 | if err != nil {
77 | // 可能是会话已存在,忽略错误
78 | log.Printf("[StreamAgentWithConversation] 创建会话失败: %w", err)
79 | return nil, fmt.Errorf("获取会话失败: %w", err)
80 | }
81 |
82 | // 先获取历史消息
83 | historyMsgs, err := s.historySvc.GetHistory(ctx, convID, 50)
84 | if err != nil {
85 | log.Printf("[StreamAgentWithConversation] 获取历史消息失败: %w", err)
86 | return nil, fmt.Errorf("获取历史消息失败: %w", err)
87 | }
88 |
89 | // 保存用户消息
90 | userSchemaMsg := &schema.Message{
91 | Role: schema.User,
92 | Content: message,
93 | }
94 | err = s.historySvc.SaveMessage(ctx, userSchemaMsg, convID)
95 | if err != nil {
96 | log.Printf("[StreamAgentWithConversation] 保存用户消息失败: %w", err)
97 | return nil, fmt.Errorf("保存用户消息失败: %w", err)
98 | }
99 |
100 | // 创建用户消息,包含历史
101 | userMsg := model.UserMessage{
102 | Query: message,
103 | History: historyMsgs,
104 | }
105 |
106 | // 调用Agent处理
107 | sr, err := s.agentSvc.StreamExecuteAgent(ctx, userID, agentID, userMsg)
108 | if err != nil {
109 | log.Printf("[StreamAgentWithConversation] 运行Agent失败: %w", err)
110 | return nil, fmt.Errorf("运行Agent失败: %w", err)
111 | }
112 |
113 | // 复制流
114 | srs := sr.Copy(2)
115 |
116 | // 创建一个独立的上下文用于保存消息,不依赖于请求上下文
117 | saveCtx := context.Background()
118 |
119 | // 后台处理:记录完整回复
120 | go func() {
121 | fullMsgs := make([]*schema.Message, 0)
122 |
123 | defer func() {
124 | srs[1].Close()
125 | // 合并消息
126 | if len(fullMsgs) > 0 {
127 | fullMsg, err := schema.ConcatMessages(fullMsgs)
128 | if err != nil {
129 | fmt.Println("合并消息失败:", err.Error())
130 | return
131 | }
132 |
133 | // 使用独立上下文保存消息
134 | err = s.historySvc.SaveMessage(saveCtx, fullMsg, convID)
135 | if err != nil {
136 | fmt.Println("保存消息失败:", err.Error())
137 | }
138 |
139 | // 更新会话最后更新时间
140 | conv.UpdatedAt = time.Now().Unix()
141 | _ = s.historySvc.UpdateConversation(saveCtx, conv)
142 | }
143 | }()
144 |
145 | outer:
146 | for {
147 | select {
148 | case <-ctx.Done():
149 | fmt.Println("上下文已关闭:", ctx.Err())
150 | return
151 | default:
152 | chunk, err := srs[1].Recv()
153 | if err != nil {
154 | if errors.Is(err, io.EOF) {
155 | break outer
156 | }
157 | fmt.Println("接收消息块错误:", err.Error())
158 | return
159 | }
160 |
161 | fullMsgs = append(fullMsgs, chunk)
162 | }
163 | }
164 | }()
165 |
166 | return srs[0], nil
167 | }
168 |
169 | // CreateConversation 创建新会话
170 | func (s *conversationService) CreateConversation(ctx context.Context, userID uint, agentID string) (string, error) {
171 | convID := uuid.NewString()
172 |
173 | conv := &model.Conversation{
174 | ConvID: convID,
175 | UserID: userID,
176 | AgentID: agentID,
177 | Title: defaultConvTitle + time.Now().String(),
178 | CreatedAt: time.Now().Unix(),
179 | UpdatedAt: time.Now().Unix(),
180 | }
181 |
182 | err := s.historySvc.CreateConversation(ctx, conv)
183 | if err != nil {
184 | return "", fmt.Errorf("[CreateConversation] 创建会话失败: %w", err)
185 | }
186 |
187 | return convID, nil
188 | }
189 |
190 | // ListConversations 列出用户所有会话
191 | func (s *conversationService) ListConversations(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error) {
192 | return s.historySvc.ListConversations(ctx, userID, page, size)
193 | }
194 |
195 | // ListAgentConversations 列出特定Agent的会话
196 | func (s *conversationService) ListAgentConversations(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error) {
197 | return s.historySvc.ListConversationsByAgent(ctx, userID, agentID, page, size)
198 | }
199 |
200 | // GetConversationHistory 获取会话历史消息
201 | func (s *conversationService) GetConversationHistory(ctx context.Context, convID string, limit int) ([]*schema.Message, error) {
202 | return s.historySvc.GetHistory(ctx, convID, limit)
203 | }
204 |
205 | // DeleteConversation 删除会话
206 | func (s *conversationService) DeleteConversation(ctx context.Context, convID string) error {
207 | return s.historySvc.DeleteConversation(ctx, convID)
208 | }
209 |
--------------------------------------------------------------------------------
/internal/service/history_service.go:
--------------------------------------------------------------------------------
1 | package service
2 |
3 | import (
4 | hisdao "ai-cloud/internal/dao/history"
5 | "ai-cloud/internal/model"
6 | "ai-cloud/internal/utils"
7 | "context"
8 | "fmt"
9 |
10 | "github.com/cloudwego/eino/schema"
11 | )
12 |
// HistoryService stores and retrieves conversations and their messages.
type HistoryService interface {
	// SaveMessage stores mess as a message of conversation convID.
	SaveMessage(ctx context.Context, mess *schema.Message, convID string) error
	// GetHistory returns messages of conversation convID, bounded by limit.
	GetHistory(ctx context.Context, convID string, limit int) ([]*schema.Message, error)
	// CreateConversation creates the conversation record.
	CreateConversation(ctx context.Context, conv *model.Conversation) error
	// UpdateConversation updates the conversation record.
	UpdateConversation(ctx context.Context, conv *model.Conversation) error
	// DeleteConversation deletes the conversation and its messages.
	DeleteConversation(ctx context.Context, convID string) error
	// ArchiveConversation archives the conversation.
	ArchiveConversation(ctx context.Context, convID string) error
	// UnArchiveConversation reverts archiving of the conversation.
	UnArchiveConversation(ctx context.Context, convID string) error
	// PinConversation pins the conversation.
	PinConversation(ctx context.Context, convID string) error
	// UnPinConversation unpins the conversation.
	UnPinConversation(ctx context.Context, convID string) error
	// ListConversations returns one page of a user's conversations and the total count.
	ListConversations(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error)
	// ListConversationsByAgent returns one page of a user's conversations with one agent and the total count.
	ListConversationsByAgent(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error)
}

// history is the HistoryService implementation built on the conversation and
// message DAOs.
type history struct {
	convDao hisdao.ConvDao
	msgDao  hisdao.MsgDao
}
30 |
31 | // NewHistoryService 创建历史记录服务
32 | func NewHistoryService(convDao hisdao.ConvDao, msgDao hisdao.MsgDao) HistoryService {
33 | return &history{
34 | convDao: convDao,
35 | msgDao: msgDao,
36 | }
37 | }
38 |
39 | // SaveMessage 保存消息
40 | func (s *history) SaveMessage(ctx context.Context, mess *schema.Message, convID string) error {
41 | err := s.msgDao.Create(ctx, &model.Message{
42 | Role: string(mess.Role),
43 | Content: mess.Content,
44 | ConvID: convID,
45 | })
46 | if err != nil {
47 | return fmt.Errorf("failed to save message: %w", err)
48 | }
49 | return nil
50 | }
51 |
52 | // GetHistory 获取对话的历史消息
53 | func (s *history) GetHistory(ctx context.Context, convID string, limit int) ([]*schema.Message, error) {
54 | if limit == 0 {
55 | limit = 50
56 | }
57 | _, err := s.convDao.GetByID(ctx, convID)
58 | if err != nil {
59 | return nil, fmt.Errorf("failed to get conversation: %w", err)
60 | }
61 |
62 | msgs, _, err := s.msgDao.List(ctx, convID, 0, limit)
63 | if err != nil {
64 | return nil, fmt.Errorf("failed to get messages: %w", err)
65 | }
66 |
67 | return utils.MessageList2ChatHistory(msgs), nil
68 | }
69 |
70 | // CreateConversation 创建会话
71 | func (s *history) CreateConversation(ctx context.Context, conv *model.Conversation) error {
72 | if err := s.convDao.FirstOrCreate(ctx, conv); err != nil {
73 | return fmt.Errorf("failed to create conversation: %w", err)
74 | }
75 | return nil
76 | }
77 |
78 | // UpdateConversation 更新会话
79 | func (s *history) UpdateConversation(ctx context.Context, conv *model.Conversation) error {
80 | if err := s.convDao.Update(ctx, conv); err != nil {
81 | return fmt.Errorf("failed to update conversation: %w", err)
82 | }
83 | return nil
84 | }
85 |
86 | // ArchiveConversation 归档会话
87 | func (s *history) ArchiveConversation(ctx context.Context, convID string) error {
88 | if err := s.convDao.Archive(ctx, convID); err != nil {
89 | return fmt.Errorf("failed to archive conversation: %w", err)
90 | }
91 | return nil
92 | }
93 |
94 | // UnArchiveConversation 取消归档会话
95 | func (s *history) UnArchiveConversation(ctx context.Context, convID string) error {
96 | if err := s.convDao.UnArchive(ctx, convID); err != nil {
97 | return fmt.Errorf("failed to unarchive conversation: %w", err)
98 | }
99 | return nil
100 | }
101 |
102 | // PinConversation 置顶会话
103 | func (s *history) PinConversation(ctx context.Context, convID string) error {
104 | if err := s.convDao.Pin(ctx, convID); err != nil {
105 | return fmt.Errorf("failed to pin conversation: %w", err)
106 | }
107 | return nil
108 | }
109 |
110 | // UnPinConversation 取消置顶会话
111 | func (s *history) UnPinConversation(ctx context.Context, convID string) error {
112 | if err := s.convDao.UnPin(ctx, convID); err != nil {
113 | return fmt.Errorf("failed to unpin conversation: %w", err)
114 | }
115 | return nil
116 | }
117 |
118 | // ListConversations 获取对话列表
119 | func (s *history) ListConversations(ctx context.Context, userID uint, page, size int) ([]*model.Conversation, int64, error) {
120 | convs, count, err := s.convDao.Page(ctx, userID, page, size)
121 | if err != nil {
122 | return nil, 0, fmt.Errorf("failed to list conversations: %w", err)
123 | }
124 | return convs, count, nil
125 | }
126 |
127 | // ListConversationsByAgent 按Agent获取对话列表
128 | func (s *history) ListConversationsByAgent(ctx context.Context, userID uint, agentID string, page, size int) ([]*model.Conversation, int64, error) {
129 | convs, count, err := s.convDao.PageByAgent(ctx, userID, agentID, page, size)
130 | if err != nil {
131 | return nil, 0, fmt.Errorf("failed to list conversations by agent: %w", err)
132 | }
133 | return convs, count, nil
134 | }
135 |
136 | // DeleteConversation 删除会话
137 | func (s *history) DeleteConversation(ctx context.Context, convID string) error {
138 | // 先获取会话的所有消息
139 | msgs, err := s.msgDao.ListByConvID(ctx, convID)
140 | if err != nil {
141 | return fmt.Errorf("failed to list messages for conversation: %w", err)
142 | }
143 |
144 | // 删除会话的所有消息
145 | for _, msg := range msgs {
146 | if err := s.msgDao.Delete(ctx, msg.MsgID); err != nil {
147 | return fmt.Errorf("failed to delete message: %w", err)
148 | }
149 | }
150 |
151 | // 删除会话
152 | if err := s.convDao.Delete(ctx, convID); err != nil {
153 | return fmt.Errorf("failed to delete conversation: %w", err)
154 | }
155 | return nil
156 | }
157 |
--------------------------------------------------------------------------------
/internal/service/model_service.go:
--------------------------------------------------------------------------------
1 | package service
2 |
3 | import (
4 | "ai-cloud/internal/dao"
5 | "ai-cloud/internal/model"
6 | "context"
7 | )
8 |
9 | type ModelService interface {
10 | CreateModel(ctx context.Context, m *model.Model) error
11 | UpdateModel(ctx context.Context, m *model.Model) error
12 | DeleteModel(ctx context.Context, userID uint, id string) error
13 | GetModel(ctx context.Context, userID uint, id string) (*model.Model, error)
14 | ListModels(ctx context.Context, userID uint, modelType string) ([]*model.Model, error)
15 | PageModels(ctx context.Context, userID uint, modelType string, page, size int) ([]*model.Model, int64, error)
16 | }
17 |
18 | type modelService struct {
19 | dao dao.ModelDao
20 | }
21 |
22 | func NewModelService(dao dao.ModelDao) ModelService {
23 | return &modelService{dao: dao}
24 | }
25 |
26 | func (s *modelService) CreateModel(ctx context.Context, m *model.Model) error {
27 | return s.dao.Create(ctx, m)
28 | }
29 |
30 | func (s *modelService) UpdateModel(ctx context.Context, m *model.Model) error {
31 | return s.dao.Update(ctx, m)
32 | }
33 |
34 | func (s *modelService) DeleteModel(ctx context.Context, userID uint, id string) error {
35 | return s.dao.Delete(ctx, userID, id)
36 | }
37 |
38 | func (s *modelService) GetModel(ctx context.Context, userID uint, id string) (*model.Model, error) {
39 | return s.dao.GetByID(ctx, userID, id)
40 | }
41 |
42 | func (s *modelService) ListModels(ctx context.Context, userID uint, modelType string) ([]*model.Model, error) {
43 | return s.dao.List(ctx, userID, modelType)
44 | }
45 |
46 | func (s *modelService) PageModels(ctx context.Context, userID uint, modelType string, page, size int) ([]*model.Model, int64, error) {
47 | return s.dao.Page(ctx, userID, modelType, page, size)
48 | }
49 |
--------------------------------------------------------------------------------
/internal/service/user_service.go:
--------------------------------------------------------------------------------
1 | package service
2 |
3 | import (
4 | "ai-cloud/config"
5 | "ai-cloud/internal/dao"
6 | "ai-cloud/internal/middleware"
7 | "ai-cloud/internal/model"
8 | "errors"
9 | "golang.org/x/crypto/bcrypt"
10 | )
11 |
// UserService handles user registration and login.
type UserService interface {
	// Register creates a new user account from user.
	Register(user *model.User) error
	// Login verifies the credentials in req and returns an access token response.
	Login(req *model.UserNameLoginReq) (*model.LoginResponse, error)
}

// userService is the UserService implementation backed by a UserDao.
type userService struct {
	userDao dao.UserDao
}
20 |
21 | func NewUserService(userDao dao.UserDao) UserService {
22 | return &userService{userDao: userDao}
23 | }
24 |
25 | func (s *userService) Register(user *model.User) error {
26 | // 检查
27 | usernameExists, err := s.userDao.CheckFieldExists("username", user.Username)
28 | if err != nil {
29 | return err
30 | }
31 | if usernameExists {
32 | return errors.New("用户名已注册")
33 | }
34 |
35 | phoneExists, err := s.userDao.CheckFieldExists("phone", user.Phone)
36 | if err != nil {
37 | return err
38 | }
39 | if phoneExists {
40 | return errors.New("手机号已注册")
41 | }
42 |
43 | hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
44 | if err != nil {
45 | return errors.New("密码加密失败")
46 | }
47 |
48 | newUser := &model.User{
49 | Username: user.Username,
50 | Phone: user.Phone,
51 | Password: string(hashedPassword),
52 | Email: user.Email,
53 | }
54 | err = s.userDao.CreateUser(newUser)
55 | if err != nil {
56 | return err
57 | }
58 | return nil
59 | }
60 |
61 | func (s *userService) Login(req *model.UserNameLoginReq) (*model.LoginResponse, error) {
62 | user, err := s.userDao.GetUserByName(req.Username)
63 | if err != nil {
64 | return nil, err
65 | }
66 |
67 | if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
68 | return nil, errors.New("用户名或密码错误")
69 | }
70 |
71 | accessToken, err := middleware.GenerateToken(user.ID)
72 | if err != nil {
73 | return nil, errors.New("系统错误")
74 | }
75 |
76 | return &model.LoginResponse{
77 | AccessToken: accessToken,
78 | ExpiresIn: config.AppConfigInstance.JWT.ExpirationHours * 3600,
79 | TokenType: "Bearer",
80 | }, nil
81 | }
82 |
--------------------------------------------------------------------------------
/internal/storage/local.go:
--------------------------------------------------------------------------------
1 | package storage
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "path/filepath"
7 | )
8 |
// LocalStorage is a storage driver that keeps objects as regular files below
// a base directory on the local filesystem.
type LocalStorage struct {
	baseDir string // root directory of the store (e.g. ./storage_data)
}

// NewLocalStorage creates baseDir (and any missing parents) and returns a
// driver rooted there. The underlying OS error is wrapped, so callers can
// inspect it with errors.Is / errors.As.
func NewLocalStorage(baseDir string) (*LocalStorage, error) {
	if err := os.MkdirAll(baseDir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create local storage dir: %w", err)
	}
	return &LocalStorage{baseDir: baseDir}, nil
}

// resolve maps key to a path below baseDir. Keys that would resolve outside
// baseDir (for example through ".." segments) are rejected, so that a crafted
// key cannot read, overwrite or delete arbitrary files.
func (s *LocalStorage) resolve(key string) (string, error) {
	fullPath := filepath.Join(s.baseDir, key)
	rel, err := filepath.Rel(s.baseDir, fullPath)
	escapes := err != nil ||
		rel == ".." ||
		(len(rel) >= 3 && rel[:3] == ".."+string(filepath.Separator))
	if escapes {
		return "", fmt.Errorf("invalid storage key %q: resolves outside the storage directory", key)
	}
	return fullPath, nil
}

// Upload writes data to the file addressed by key, creating parent
// directories as needed. contentType is not used by the local driver.
func (s *LocalStorage) Upload(data []byte, key string, contentType string) error {
	fullPath, err := s.resolve(key)
	if err != nil {
		return err
	}
	// Make sure the parent directory exists.
	if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
		return fmt.Errorf("failed to create parent dir: %w", err)
	}
	return os.WriteFile(fullPath, data, 0644)
}

// Download returns the content of the file addressed by key.
func (s *LocalStorage) Download(key string) ([]byte, error) {
	fullPath, err := s.resolve(key)
	if err != nil {
		return nil, err
	}
	return os.ReadFile(fullPath)
}

// Delete removes the file addressed by key.
func (s *LocalStorage) Delete(key string) error {
	fullPath, err := s.resolve(key)
	if err != nil {
		return err
	}
	return os.Remove(fullPath)
}

// GetURL returns the filesystem path of key (baseDir joined with key); the
// local driver has no real URL.
func (s *LocalStorage) GetURL(key string) (string, error) {
	return s.resolve(key)
}
49 |
--------------------------------------------------------------------------------
/internal/storage/minio.go:
--------------------------------------------------------------------------------
1 | package storage
2 |
3 | import (
4 | "ai-cloud/config"
5 | "bytes"
6 | "context"
7 | "fmt"
8 | "io"
9 | "net/url"
10 | "strings"
11 | "time"
12 |
13 | "github.com/minio/minio-go/v7"
14 | "github.com/minio/minio-go/v7/pkg/credentials"
15 | )
16 |
// MinioStorage is a storage driver backed by a single MinIO bucket.
type MinioStorage struct {
	client *minio.Client // MinIO client used for all requests
	bucket string        // name of the bucket objects are stored in
}
21 |
22 | // NewMinioStorage 创建新的 Minio 存储实例
23 | func NewMinioStorage(cfg config.MinioConfig) (Driver, error) {
24 | // 设置中国时区
25 | loc, err := time.LoadLocation("Asia/Shanghai")
26 | if err != nil {
27 | return nil, fmt.Errorf("failed to load timezone: %v", err)
28 | }
29 | time.Local = loc
30 |
31 | // 初始化 Minio 客户端
32 | client, err := minio.New(cfg.Endpoint, &minio.Options{
33 | Creds: credentials.NewStaticV4(cfg.AccessKeyID, cfg.AccessKeySecret, ""),
34 | Secure: cfg.UseSSL,
35 | Region: cfg.Region,
36 | })
37 | if err != nil {
38 | return nil, fmt.Errorf("failed to create minio client: %v", err)
39 | }
40 |
41 | // 检查 bucket 是否存在
42 | exists, err := client.BucketExists(context.Background(), cfg.Bucket)
43 | if err != nil {
44 | return nil, fmt.Errorf("failed to check bucket existence: %v", err)
45 | }
46 |
47 | // 如果 bucket 不存在,创建它
48 | if !exists {
49 | err = client.MakeBucket(context.Background(), cfg.Bucket, minio.MakeBucketOptions{
50 | Region: cfg.Region,
51 | })
52 | if err != nil {
53 | return nil, fmt.Errorf("failed to create bucket: %v", err)
54 | }
55 | }
56 |
57 | return &MinioStorage{
58 | client: client,
59 | bucket: cfg.Bucket,
60 | }, nil
61 | }
62 |
63 | // Upload 上传文件到 Minio
64 | func (m *MinioStorage) Upload(data []byte, key string, contentType string) error {
65 | reader := bytes.NewReader(data)
66 | _, err := m.client.PutObject(
67 | context.Background(),
68 | m.bucket,
69 | key,
70 | reader,
71 | int64(len(data)),
72 | minio.PutObjectOptions{
73 | ContentType: contentType, // 例如 "application/pdf"
74 | },
75 | )
76 | if err != nil {
77 | return fmt.Errorf("failed to upload file: %v", err)
78 | }
79 | return nil
80 | }
81 |
82 | // Download 从 Minio 下载文件
83 | func (m *MinioStorage) Download(key string) ([]byte, error) {
84 | obj, err := m.client.GetObject(
85 | context.Background(),
86 | m.bucket,
87 | key,
88 | minio.GetObjectOptions{},
89 | )
90 | if err != nil {
91 | return nil, fmt.Errorf("failed to get object: %v", err)
92 | }
93 | defer obj.Close()
94 |
95 | data, err := io.ReadAll(obj)
96 | if err != nil {
97 | return nil, fmt.Errorf("failed to read object data: %v", err)
98 | }
99 | return data, nil
100 | }
101 |
102 | // Delete 从 Minio 删除文件
103 | func (m *MinioStorage) Delete(key string) error {
104 | err := m.client.RemoveObject(context.Background(), m.bucket, key, minio.RemoveObjectOptions{})
105 | if err != nil {
106 | return fmt.Errorf("failed to delete object: %v", err)
107 | }
108 | return nil
109 | }
110 |
111 | func (m *MinioStorage) GetURL(key string) (string, error) {
112 | // 设置响应头,强制浏览器下载文件
113 | reqParams := make(url.Values)
114 | reqParams.Set("response-content-disposition", "attachment")
115 |
116 | // 生成预签名URL,有效期1小时
117 | expiry := time.Hour * 1
118 | presignedURL, err := m.client.PresignedGetObject(
119 | context.Background(),
120 | m.bucket,
121 | key,
122 | expiry,
123 | reqParams, // 关键:传递自定义参数
124 | )
125 | if err != nil {
126 | return "", fmt.Errorf("failed to generate presigned URL: %v", err)
127 | }
128 |
129 | return presignedURL.String(), nil
130 | }
131 |
132 | // CreateDirectory 创建目录(通过上传空对象实现)
133 | func (m *MinioStorage) CreateDirectory(dirPath string) error {
134 | // 确保路径以 / 结尾
135 | if !strings.HasSuffix(dirPath, "/") {
136 | dirPath = dirPath + "/"
137 | }
138 |
139 | // 上传一个空对象来表示目录
140 | _, err := m.client.PutObject(context.Background(), m.bucket, dirPath, bytes.NewReader([]byte{}), 0, minio.PutObjectOptions{})
141 | if err != nil {
142 | return fmt.Errorf("failed to create directory: %v", err)
143 | }
144 | return nil
145 | }
146 |
--------------------------------------------------------------------------------
/internal/storage/oss.go:
--------------------------------------------------------------------------------
1 | package storage
2 |
3 | import (
4 | "ai-cloud/config"
5 | "bytes"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/aliyun/aliyun-oss-go-sdk/oss"
10 | )
11 |
// OSSStorage is a storage driver backed by an Alibaba Cloud OSS bucket.
type OSSStorage struct {
	bucket *oss.Bucket // OSS bucket handle
}
16 |
17 | // NewOSSStorage 初始化OSS存储
18 | func NewOSSStorage(cfg config.OSSConfig) (*OSSStorage, error) {
19 | // 创建OSS客户端
20 | client, err := oss.New(cfg.Endpoint, cfg.AccessKeyID, cfg.AccessKeySecret)
21 | if err != nil {
22 | return nil, fmt.Errorf("failed to create OSS client: %v", err)
23 | }
24 |
25 | // 获取Bucket实例
26 | bucket, err := client.Bucket(cfg.Bucket)
27 | if err != nil {
28 | return nil, fmt.Errorf("failed to get OSS bucket: %v", err)
29 | }
30 |
31 | return &OSSStorage{bucket: bucket}, nil
32 | }
33 |
34 | // Upload 上传文件到OSS
35 | func (s *OSSStorage) Upload(data []byte, key string, contentType string) error {
36 | return s.bucket.PutObject(key, bytes.NewReader(data))
37 | }
38 |
39 | // Download 从OSS下载文件
40 | func (s *OSSStorage) Download(key string) ([]byte, error) {
41 | reader, err := s.bucket.GetObject(key)
42 | if err != nil {
43 | return nil, fmt.Errorf("failed to download from OSS: %v", err)
44 | }
45 | defer reader.Close()
46 |
47 | buf := new(bytes.Buffer)
48 | if _, err := buf.ReadFrom(reader); err != nil {
49 | return nil, fmt.Errorf("failed to read OSS data: %v", err)
50 | }
51 | return buf.Bytes(), nil
52 | }
53 |
54 | // Delete 删除OSS文件
55 | func (s *OSSStorage) Delete(key string) error {
56 | return s.bucket.DeleteObject(key)
57 | }
58 |
59 | // GetURL 生成带签名的临时访问URL(有效期1小时)
60 | func (s *OSSStorage) GetURL(key string) (string, error) {
61 | expired := time.Now().Add(1 * time.Hour)
62 | return s.bucket.SignURL(key, oss.HTTPGet, int64(expired.Unix()))
63 | }
64 |
--------------------------------------------------------------------------------
/internal/storage/storage.go:
--------------------------------------------------------------------------------
1 | package storage
2 |
3 | import (
4 | "ai-cloud/config"
5 | "fmt"
6 | )
7 |
// Driver is the interface every storage backend implements.
type Driver interface {
	Upload(data []byte, key string, contentType string) error // store data under key
	Download(key string) ([]byte, error)                      // fetch the content stored under key
	Delete(key string) error                                  // remove the object stored under key
	GetURL(key string) (string, error)                        // return an access URL for key
}
15 |
16 | // NewDriver 根据配置初始化存储驱动
17 | func NewDriver(cfg config.StorageConfig) (Driver, error) {
18 | switch cfg.Type {
19 | case "local":
20 | return NewLocalStorage(cfg.Local.BaseDir)
21 | case "oss":
22 | return NewOSSStorage(cfg.OSS)
23 | case "minio":
24 | return NewMinioStorage(cfg.Minio)
25 | default:
26 | return nil, fmt.Errorf("unsupported storage type: %s", cfg.Type)
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/internal/utils/agent_utils.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "strings"
5 | )
6 |
// ReplaceKnowledgePlaceholder returns prompt with every occurrence of the
// literal "{{Knowledge}}" replaced by knowledge.
func ReplaceKnowledgePlaceholder(prompt string, knowledge string) string {
	const placeholder = "{{Knowledge}}"
	return strings.ReplaceAll(prompt, placeholder, knowledge)
}
11 |
// FormatRetrievalResults renders retrieval results as text for insertion
// into a prompt. Each result contributes its "content" string (if present
// and a string) followed by "\nSource: <source>" (if "source" is present and
// a string); results are separated by a blank line.
func FormatRetrievalResults(results []map[string]interface{}) string {
	entries := make([]string, len(results))
	for idx := range results {
		entry := ""
		if text, isString := results[idx]["content"].(string); isString {
			entry = text
		}
		if origin, isString := results[idx]["source"].(string); isString {
			entry += "\nSource: " + origin
		}
		entries[idx] = entry
	}
	return strings.Join(entries, "\n\n")
}
34 |
--------------------------------------------------------------------------------
/internal/utils/context.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "errors"
5 | "github.com/gin-gonic/gin"
6 | )
7 |
8 | // 定义上下文键名(避免硬编码)
9 | const UserIDKey = "user_id"
10 |
11 | func GetUserIDFromContext(c *gin.Context) (uint, error) {
12 | // 从上下文中获取值
13 | userIDVal, exists := c.Get(UserIDKey)
14 | if !exists {
15 | return 0, errors.New("上下文中未找到用户ID")
16 | }
17 |
18 | // 类型断言
19 | userID, ok := userIDVal.(uint)
20 | if !ok {
21 | return 0, errors.New("用户ID类型错误")
22 | }
23 |
24 | return userID, nil
25 | }
26 |
--------------------------------------------------------------------------------
/internal/utils/convert.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "strconv"
5 | )
6 |
// StringToInt parses s as a base-10 int and returns 0 when s cannot be
// parsed (including out-of-range values).
func StringToInt(s string) int {
	if parsed, err := strconv.Atoi(s); err == nil {
		return parsed
	}
	return 0
}
15 |
// StringToInt64 parses s as a base-10 int64 and returns 0 when s cannot be
// parsed (including out-of-range values).
func StringToInt64(s string) int64 {
	if parsed, err := strconv.ParseInt(s, 10, 64); err == nil {
		return parsed
	}
	return 0
}
24 |
// IntToString returns the base-10 text form of i.
func IntToString(i int) string {
	return strconv.FormatInt(int64(i), 10)
}
29 |
// Int64ToString returns the base-10 text form of i.
func Int64ToString(i int64) string {
	digits := strconv.AppendInt(nil, i, 10)
	return string(digits)
}
34 |
--------------------------------------------------------------------------------
/internal/utils/convert_float.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
// ConvertFloat64ToFloat32Embeddings converts a batch of float64 vectors into
// float32 vectors of the same shape. The result (and each inner vector) is
// always a freshly allocated, non-nil slice.
func ConvertFloat64ToFloat32Embeddings(embeddings [][]float64) [][]float32 {
	converted := make([][]float32, len(embeddings))
	for row := 0; row < len(embeddings); row++ {
		src := embeddings[row]
		dst := make([]float32, len(src))
		for col := 0; col < len(src); col++ {
			dst[col] = float32(src[col])
		}
		converted[row] = dst
	}
	return converted
}
14 |
// ConvertFloat64ToFloat32Embedding converts one float64 vector into a
// float32 vector of the same length. The result is always a freshly
// allocated, non-nil slice.
func ConvertFloat64ToFloat32Embedding(embedding []float64) []float32 {
	converted := make([]float32, len(embedding))
	for pos := 0; pos < len(embedding); pos++ {
		converted[pos] = float32(embedding[pos])
	}
	return converted
}
22 |
--------------------------------------------------------------------------------
/internal/utils/hitory_utils.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "ai-cloud/internal/model"
5 | "github.com/cloudwego/eino/schema"
6 | )
7 |
8 | func MessageList2ChatHistory(mess []*model.Message) (history []*schema.Message) {
9 | for _, m := range mess {
10 | history = append(history, message2MessagesTemplate(m))
11 | }
12 | return
13 | }
14 |
15 | func message2MessagesTemplate(mess *model.Message) *schema.Message {
16 | return &schema.Message{
17 | Role: schema.RoleType(mess.Role),
18 | Content: mess.Content,
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/internal/utils/pagination.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "fmt"
5 | "github.com/gin-gonic/gin"
6 | "strconv"
7 | )
8 |
9 | func ParsePaginationParams(ctx *gin.Context) (page int, pageSize int, err error) {
10 | pageStr := ctx.DefaultQuery("page", "1")
11 | page, err = strconv.Atoi(pageStr)
12 | if err != nil || page < 1 {
13 | return 0, 0, fmt.Errorf("无效的页码参数")
14 |
15 | }
16 |
17 | pageSizeStr := ctx.DefaultQuery("page_size", "10")
18 | pageSize, err = strconv.Atoi(pageSizeStr)
19 | if err != nil || pageSize < 1 {
20 | return 0, 0, fmt.Errorf("无效的页面大小参数")
21 | }
22 |
23 | return page, pageSize, nil
24 | }
25 |
--------------------------------------------------------------------------------
/internal/utils/uuid.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import "github.com/google/uuid"
4 |
5 | func GenerateUUID() string {
6 | return uuid.New().String()
7 | }
8 |
--------------------------------------------------------------------------------
/internal/utils/validate_sort.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | )
7 |
// ValidateSortParameter validates a sort expression of the form
// "field:order[,field:order...]".
//
// Every clause must contain exactly one field and one order separated by the
// first ":"; the field must be listed in allowedFields and the order must be
// "asc" or "desc". A clause without ":" (including an empty sort string) is
// reported as an error instead of panicking, and anything after a second ":"
// becomes part of the order and is therefore rejected.
func ValidateSortParameter(sort string, allowedFields []string) error {
	for _, clause := range strings.Split(sort, ",") {
		field, order, found := strings.Cut(clause, ":")
		if !found {
			return fmt.Errorf("无效的排序参数格式:%s", clause)
		}
		if !contains(allowedFields, field) {
			return fmt.Errorf("无效的排序字段:%s", field)
		}
		if order != "asc" && order != "desc" {
			return fmt.Errorf("无效的排序方向:%s", order)
		}
	}
	return nil
}

// contains reports whether item is an element of slice.
func contains(slice []string, item string) bool {
	for _, s := range slice {
		if s == item {
			return true
		}
	}
	return false
}
32 |
--------------------------------------------------------------------------------
/pkgs/consts/milvus_const.go:
--------------------------------------------------------------------------------
1 | package consts
2 |
// Field name constants.
const (
	// FieldNameID is the name of the ID field.
	FieldNameID = "id"
	// FieldNameContent is the name of the content field.
	FieldNameContent = "content"
	// FieldNameDocumentID is the name of the document ID field.
	FieldNameDocumentID = "document_id"
	// FieldNameDocumentName is the name of the document name field.
	FieldNameDocumentName = "document_name"
	// FieldNameKBID is the name of the knowledge base ID field.
	FieldNameKBID = "kb_id"
	// FieldNameChunkIndex is the name of the chunk index field.
	FieldNameChunkIndex = "chunk_index"
	// FieldNameVector is the name of the vector field.
	FieldNameVector = "vector"
	// FieldNameMetadata is the name of the metadata field.
	FieldNameMetadata = "metadata"
)
22 |
--------------------------------------------------------------------------------
/pkgs/errcode/errcode.go:
--------------------------------------------------------------------------------
1 | package errcode
2 |
// Numeric error codes, grouped into ranges by module.
const (
	// ---- Base error codes (10000-19999) ----

	// General system errors.
	SuccessCode         = 0     // specially reserved success code
	InternalServerError = 10001 // internal server error
	DatabaseError       = 10002 // database operation error
	CacheError          = 10003 // cache service error
	RateLimitExceeded   = 10004 // request rate limit exceeded

	// Request parameter errors.
	ParamBindError     = 10101 // parameter binding error
	ParamValidateError = 10102 // parameter validation failed

	// Authentication and authorization errors.
	UnauthorizedError = 10201 // identity not authenticated
	ForbiddenError    = 10202 // no access permission
	TokenExpired      = 10203 // token expired
	TokenInvalid      = 10204 // token invalid
	TokenMissing      = 10205 // token missing

	// ---- Business error codes ----

	// User module (20000-20999).
	UserNotFound      = 20001 // user does not exist
	UserAlreadyExists = 20002 // user already exists
	PasswordMismatch  = 20003 // wrong password
	UserDisabled      = 20004 // user is disabled
	EmailInvalid      = 20005 // malformed email address

	// File module (21000-21999).
	FileNotFound     = 21001 // file does not exist
	FileUploadFailed = 21002 // file upload failed
	FileDeleteFailed = 21003 // file deletion failed
	FileSizeExceeded = 21004 // file size limit exceeded
	FileTypeInvalid  = 21005 // invalid file type
	FileParseFailed  = 21006 // file parsing failed
	FileListFailed   = 21007 // fetching the file list failed
	FileSearchFailed = 21008 // file search failed

	// Order module (22000-22999): may be added later.
)
43 |
--------------------------------------------------------------------------------
/pkgs/response/response.go:
--------------------------------------------------------------------------------
1 | package response
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "net/http"
6 | )
7 |
// Response is the unified JSON envelope written by the response helpers in
// this package.
type Response struct {
	// Code is the business status code.
	Code int `json:"code"`
	// Message is the human-readable hint accompanying the code.
	Message string `json:"message"`
	// Data carries the response payload.
	Data interface{} `json:"data"`
}
14 |
// PageData is the payload shape used for paginated results.
type PageData struct {
	// Total is the total number of records.
	Total int64 `json:"total"`
	// List holds the list of records.
	List interface{} `json:"list"`
}
20 |
// Status codes for the Code field of Response.
// NOTE(review): only SUCCESS and ERROR are used by the helpers visible in this
// file; the meaning of the remaining codes is inferred from their names and
// should be confirmed against their callers.
const (
	SUCCESS      = 0   // used by the Success* helpers
	ERROR        = 1   // used by the Error helper
	ERROR_PARAM  = 2   // presumably a parameter error
	ERROR_AUTH   = 401 // presumably an authentication error
	ERROR_SERVER = 500 // presumably a server-side error
)
28 |
29 | // Success 成功响应
30 | func Success(c *gin.Context, data interface{}) {
31 | c.JSON(http.StatusOK, Response{
32 | Code: SUCCESS,
33 | Message: "success",
34 | Data: data,
35 | })
36 | }
37 |
38 | // SuccessWithMessage 成功响应带自定义消息
39 | func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
40 | c.JSON(http.StatusOK, Response{
41 | Code: SUCCESS,
42 | Message: message,
43 | Data: data,
44 | })
45 | }
46 |
47 | // PageSuccess 分页数据响应
48 | func PageSuccess(c *gin.Context, list interface{}, total int64) {
49 | c.JSON(http.StatusOK, Response{
50 | Code: SUCCESS,
51 | Message: "success",
52 | Data: PageData{
53 | List: list,
54 | Total: total,
55 | },
56 | })
57 | }
58 |
59 | // Error 错误响应
60 | func Error(c *gin.Context, message string) {
61 | c.JSON(http.StatusOK, Response{
62 | Code: ERROR,
63 | Message: message,
64 | Data: nil,
65 | })
66 | }
67 |
68 | //// ErrorWithCode 错误响应带自定义错误码
69 | //func ErrorWithCode(c *gin.Context, code int, message string) {
70 | // c.JSON(http.StatusOK, Response{
71 | // Code: code,
72 | // Message: message,
73 | // Data: nil,
74 | // })
75 | //}
76 |
77 | // 错误响应
78 | func ErrorCustom(c *gin.Context, httpCode, code int, msg string, data interface{}) {
79 | c.JSON(httpCode, Response{
80 | Code: code,
81 | Message: msg,
82 | Data: data,
83 | })
84 | }
85 | func ParamError(c *gin.Context, code int, msg string) {
86 | c.JSON(http.StatusBadRequest, Response{
87 | Code: code,
88 | Message: msg,
89 | })
90 | }
91 |
92 | func UnauthorizedError(c *gin.Context, code int, msg string) {
93 | c.JSON(http.StatusUnauthorized, Response{
94 | Code: code,
95 | Message: msg,
96 | })
97 | }
98 |
99 | func InternalError(c *gin.Context, code int, msg string) {
100 | c.JSON(http.StatusInternalServerError, Response{
101 | Code: code,
102 | Message: msg,
103 | })
104 | }
105 |
--------------------------------------------------------------------------------