├── server ├── core │ ├── test_file │ │ ├── readme2.md │ │ ├── readme.md │ │ ├── test.pdf │ │ ├── test.xlsx │ │ └── readme.html │ ├── readme.md │ ├── common │ │ ├── chat_model_test.go │ │ ├── slices_test.go │ │ ├── helper.go │ │ ├── embedding.go │ │ ├── stream.go │ │ ├── es.go │ │ └── chat_model.go │ ├── types │ │ └── consts.go │ ├── vector │ │ ├── factory.go │ │ ├── qdrant_test.go │ │ ├── interface.go │ │ ├── factory_test.go │ │ └── README.md │ ├── rerank │ │ ├── rerank_test.go │ │ └── rerank.go │ ├── retriever │ │ ├── orchestration.go │ │ └── retriever.go │ ├── indexer │ │ ├── async.go │ │ ├── parser.go │ │ ├── loader.go │ │ ├── orchestration.go │ │ ├── transformer.go │ │ ├── qa.go │ │ ├── indexer_async.go │ │ ├── qdrant_indexer.go │ │ ├── indexer.go │ │ ├── merge.go │ │ └── qdrant_indexer_custom.go │ ├── config │ │ └── config.go │ ├── grader │ │ ├── grader.go │ │ └── message.go │ ├── message.go │ └── rag.go ├── README.md ├── static │ ├── kb.png │ ├── rag.png │ ├── wx.jpg │ ├── chat.png │ ├── wachat.png │ ├── doc-list.png │ ├── indexer.png │ ├── mcp-cfg.png │ ├── mcp-use.png │ ├── chunk-edit.png │ └── retriever.png ├── internal │ ├── mcp │ │ ├── mcp.go │ │ ├── knowledgebase.go │ │ ├── retriever.go │ │ └── indexer.go │ ├── controller │ │ └── rag │ │ │ ├── rag.go │ │ │ ├── rag_new.go │ │ │ ├── rag_v1_update_chunk.go │ │ │ ├── rag_v1_chunks_list.go │ │ │ ├── rag_v1_documents_list.go │ │ │ ├── rag_v1_chat.go │ │ │ ├── rag_v1_retriever_dify.go │ │ │ ├── rag_v1_chunk_delete.go │ │ │ ├── rag_v1_documents_delete.go │ │ │ ├── rag_v1_chat_stream.go │ │ │ ├── rag_v1_indexer.go │ │ │ ├── rag_v1_retriever.go │ │ │ ├── rag_v1_kb.go │ │ │ └── rag_v1_update_chunk_content.go │ ├── model │ │ ├── gorm │ │ │ ├── migrate.go │ │ │ ├── knowledge_base.go │ │ │ ├── knowledge_documents.go │ │ │ └── knowledge_chunks.go │ │ ├── do │ │ │ ├── knowledge_documents.go │ │ │ ├── knowledge_base.go │ │ │ └── knowledge_chunks.go │ │ └── entity │ │ │ ├── knowledge_documents.go │ │ │ ├── 
knowledge_base.go │ │ │ └── knowledge_chunks.go │ ├── dao │ │ ├── db │ │ │ ├── base.go │ │ │ ├── interface.go │ │ │ ├── mysql.go │ │ │ └── sqlite.go │ │ ├── knowledge_chunks.go │ │ ├── knowledge_documents.go │ │ ├── knowledge_base.go │ │ ├── internal │ │ │ ├── knowledge_base.go │ │ │ ├── knowledge_chunks.go │ │ │ └── knowledge_documents.go │ │ └── dao.go │ ├── cmd │ │ ├── cmd.go │ │ └── middleware.go │ └── logic │ │ ├── chat │ │ ├── message.go │ │ └── chat.go │ │ ├── rag │ │ └── retriever.go │ │ └── knowledge │ │ ├── chunks.go │ │ └── documents.go ├── hack │ └── config.yaml ├── main.go ├── server │ └── server.go ├── api │ └── rag │ │ ├── v1 │ │ ├── indexer.go │ │ ├── chat.go │ │ ├── documents.go │ │ ├── retriever.go │ │ ├── chunks.go │ │ └── knowledge_base.go │ │ └── rag.go └── manifest │ └── config │ ├── config_demo.yaml │ └── config_qd_demo.yaml ├── fe ├── public │ ├── CNAME │ ├── favicon.svg │ ├── element-plus-logo-small.svg │ ├── vite.svg │ └── logo.svg ├── src │ ├── composables │ │ ├── index.ts │ │ └── dark.ts │ ├── types.ts │ ├── pages │ │ └── index.vue │ ├── modules │ │ └── router.ts │ ├── env.d.ts │ ├── styles │ │ ├── element │ │ │ ├── dark.scss │ │ │ └── index.scss │ │ ├── index.scss │ │ └── markdown.css │ ├── assets │ │ └── vue.svg │ ├── utils │ │ ├── markdown.js │ │ ├── knowledgeIdStore.js │ │ ├── format.js │ │ ├── knowledgeStore.js │ │ └── request.js │ ├── App.vue │ ├── components │ │ ├── MessageBoxDemo.vue │ │ ├── Logos.vue │ │ ├── layouts │ │ │ ├── BaseHeader.vue │ │ │ └── BaseSide.vue │ │ ├── KnowledgeNameSetting.vue_back │ │ └── HelloWorld.vue │ ├── main.ts │ ├── router │ │ └── index.ts │ ├── typed-router.d.ts │ └── components.d.ts ├── .env.development ├── .npmrc ├── eslint.config.js ├── index.html ├── tsconfig.json ├── uno.config.ts ├── README.md ├── package.json └── vite.config.ts ├── Dockerfile ├── .gitignore ├── roadmap.md ├── docker-compose.yml ├── README.md └── Makefile /server/core/test_file/readme2.md: 
-------------------------------------------------------------------------------- 1 | # 开饭了 -------------------------------------------------------------------------------- /server/README.md: -------------------------------------------------------------------------------- 1 | # go-rag api 2 | rag api 项目 3 | -------------------------------------------------------------------------------- /fe/public/CNAME: -------------------------------------------------------------------------------- 1 | vite-starter.element-plus.org 2 | -------------------------------------------------------------------------------- /server/core/test_file/readme.md: -------------------------------------------------------------------------------- 1 | # 这是一个readme文件,这里有很多内容 -------------------------------------------------------------------------------- /fe/src/composables/index.ts: -------------------------------------------------------------------------------- 1 | export * from './dark' 2 | -------------------------------------------------------------------------------- /fe/.env.development: -------------------------------------------------------------------------------- 1 | VITE_API_BASE_URL=http://localhost:8000/api -------------------------------------------------------------------------------- /fe/.npmrc: -------------------------------------------------------------------------------- 1 | shamefully-hoist=true 2 | strict-peer-dependencies=false 3 | -------------------------------------------------------------------------------- /server/static/kb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/kb.png -------------------------------------------------------------------------------- /server/static/rag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/rag.png 
-------------------------------------------------------------------------------- /server/static/wx.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/wx.jpg -------------------------------------------------------------------------------- /server/static/chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/chat.png -------------------------------------------------------------------------------- /server/static/wachat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/wachat.png -------------------------------------------------------------------------------- /server/static/doc-list.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/doc-list.png -------------------------------------------------------------------------------- /server/static/indexer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/indexer.png -------------------------------------------------------------------------------- /server/static/mcp-cfg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/mcp-cfg.png -------------------------------------------------------------------------------- /server/static/mcp-use.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/mcp-use.png -------------------------------------------------------------------------------- 
/server/static/chunk-edit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/chunk-edit.png -------------------------------------------------------------------------------- /server/static/retriever.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/static/retriever.png -------------------------------------------------------------------------------- /server/core/test_file/test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/core/test_file/test.pdf -------------------------------------------------------------------------------- /server/core/test_file/test.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangle201210/go-rag/HEAD/server/core/test_file/test.xlsx -------------------------------------------------------------------------------- /fe/src/types.ts: -------------------------------------------------------------------------------- 1 | import type { ViteSSGContext } from 'vite-ssg' 2 | 3 | export type UserModule = (ctx: ViteSSGContext) => void 4 | -------------------------------------------------------------------------------- /server/core/readme.md: -------------------------------------------------------------------------------- 1 | # go-rag 核心逻辑 2 | 此目录本应放在internal目录下,但为了便于被其他项目直接引用,将其放在core目录下 3 | 4 | 详情可以参照 [test文件](../core/rag_test.go) 5 | -------------------------------------------------------------------------------- /fe/src/pages/index.vue: -------------------------------------------------------------------------------- 1 | 5 | -------------------------------------------------------------------------------- /server/internal/mcp/mcp.go: 
-------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import "github.com/wangle201210/go-rag/server/internal/controller/rag" 4 | 5 | var c = rag.NewV1() 6 | -------------------------------------------------------------------------------- /fe/src/modules/router.ts: -------------------------------------------------------------------------------- 1 | import type { UserModule } from '~/types' 2 | 3 | export const install: UserModule = ({ router: _router }) => { 4 | // 路由守卫可以在这里添加 5 | } -------------------------------------------------------------------------------- /fe/eslint.config.js: -------------------------------------------------------------------------------- 1 | import antfu from '@antfu/eslint-config' 2 | 3 | export default antfu({ 4 | formatters: true, 5 | unocss: true, 6 | vue: true, 7 | }) 8 | -------------------------------------------------------------------------------- /fe/src/composables/dark.ts: -------------------------------------------------------------------------------- 1 | import { useDark, useToggle } from '@vueuse/core' 2 | 3 | export const isDark = useDark() 4 | export const toggleDark = useToggle(isDark) 5 | -------------------------------------------------------------------------------- /fe/src/env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | 3 | declare module '*.vue' { 4 | import type { DefineComponent } from 'vue' 5 | 6 | const component: DefineComponent 7 | export default component 8 | } 9 | -------------------------------------------------------------------------------- /fe/src/styles/element/dark.scss: -------------------------------------------------------------------------------- 1 | // only scss variables 2 | 3 | $--colors: ( 4 | 'primary': ( 5 | 'base': #589ef8, 6 | ), 7 | ); 8 | 9 | @forward 'element-plus/theme-chalk/src/dark/var.scss' with ( 10 | $colors: $--colors 11 | ); 12 | 
-------------------------------------------------------------------------------- /server/core/test_file/readme.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | golang项目搭建 4 | 5 | 6 |

1. 安装环境

7 |

2. 写代码

8 |

3. 执行

9 | 10 | -------------------------------------------------------------------------------- /server/hack/config.yaml: -------------------------------------------------------------------------------- 1 | gfcli: 2 | gen: 3 | dao: 4 | - link: "mysql:root:123456@tcp(127.0.0.1:3306)/go-rag?charset=utf8mb4&parseTime=True&loc=Local" 5 | descriptionTag: true 6 | tables: "knowledge_base,knowledge_chunks,knowledge_documents" 7 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag.go: -------------------------------------------------------------------------------- 1 | // ================================================================================= 2 | // This is auto-generated by GoFrame CLI tool only once. Fill this file as you wish. 3 | // ================================================================================= 4 | 5 | package rag 6 | -------------------------------------------------------------------------------- /server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/gogf/gf/v2/os/gctx" 5 | 6 | _ "github.com/gogf/gf/contrib/drivers/mysql/v2" 7 | _ "github.com/gogf/gf/contrib/drivers/sqlite/v2" 8 | "github.com/wangle201210/go-rag/server/internal/cmd" 9 | ) 10 | 11 | func main() { 12 | cmd.Main.Run(gctx.GetInitCtx()) 13 | } 14 | -------------------------------------------------------------------------------- /server/server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/gogf/gf/v2/os/gctx" 7 | "github.com/wangle201210/go-rag/server/internal/cmd" 8 | ) 9 | 10 | // Start 启动 go-rag 服务器(公开函数) 11 | func Start(ctx context.Context) { 12 | if ctx == nil { 13 | ctx = gctx.GetInitCtx() 14 | } 15 | cmd.Main.Run(ctx) 16 | } 17 | -------------------------------------------------------------------------------- 
/server/internal/model/gorm/migrate.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | var AllTables = []any{ 8 | &KnowledgeBase{}, 9 | &KnowledgeDocuments{}, 10 | &KnowledgeChunks{}, 11 | } 12 | 13 | // AutoMigrate 自动迁移所有GORM模型 14 | func AutoMigrate(db *gorm.DB) error { 15 | return db.AutoMigrate( 16 | AllTables..., 17 | ) 18 | } 19 | -------------------------------------------------------------------------------- /server/core/common/chat_model_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/cloudwego/eino-ext/components/model/openai" 7 | "github.com/gogf/gf/v2/frame/g" 8 | "github.com/gogf/gf/v2/os/gctx" 9 | ) 10 | 11 | func TestCfg(t *testing.T) { 12 | ctx := gctx.New() 13 | cfg := &openai.ChatModelConfig{} 14 | err := g.Cfg().MustGet(ctx, "qa").Scan(cfg) 15 | if err != nil { 16 | t.Fatal(err) 17 | return 18 | } 19 | t.Log(cfg) 20 | } 21 | -------------------------------------------------------------------------------- /fe/src/assets/vue.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_new.go: -------------------------------------------------------------------------------- 1 | // ================================================================================= 2 | // This is auto-generated by GoFrame CLI tool only once. Fill this file as you wish. 
3 | // ================================================================================= 4 | 5 | package rag 6 | 7 | import ( 8 | "github.com/wangle201210/go-rag/server/api/rag" 9 | ) 10 | 11 | type ControllerV1 struct{} 12 | 13 | func NewV1() rag.IRagV1 { 14 | return &ControllerV1{} 15 | } 16 | -------------------------------------------------------------------------------- /fe/src/utils/markdown.js: -------------------------------------------------------------------------------- 1 | import hljs from 'highlight.js'; 2 | import { marked } from 'marked'; 3 | import 'highlight.js/styles/github.css'; 4 | 5 | // Initialize marked configuration 6 | marked.setOptions({ 7 | highlight(code, lang) { 8 | const language = hljs.getLanguage(lang) ? lang : 'plaintext'; 9 | return hljs.highlight(code, { language }).value; 10 | }, 11 | langPrefix: 'hljs language-', 12 | gfm: true, 13 | breaks: true, 14 | }); 15 | 16 | export function renderMarkdown(text) { 17 | if (!text) { 18 | return ''; 19 | } 20 | return marked.parse(text); 21 | } -------------------------------------------------------------------------------- /server/internal/dao/db/base.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "fmt" 5 | 6 | gormModel "github.com/wangle201210/go-rag/server/internal/model/gorm" 7 | "gorm.io/gorm" 8 | ) 9 | 10 | type Base struct { 11 | *gorm.DB 12 | } 13 | 14 | func (b *Base) AutoMigrate() error { 15 | return b.DB.AutoMigrate(gormModel.AllTables...) 
16 | } 17 | 18 | // Ping 健康检查 19 | func (b *Base) Ping() error { 20 | if b.DB == nil { 21 | return fmt.Errorf("database not connected") 22 | } 23 | 24 | sqlDB, err := b.DB.DB() 25 | if err != nil { 26 | return err 27 | } 28 | 29 | return sqlDB.Ping() 30 | } 31 | -------------------------------------------------------------------------------- /fe/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | GO-RAG 8 | 9 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /server/core/types/consts.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | const ( 4 | FieldContent = "content" 5 | FieldContentVector = "content_vector" 6 | FieldQAContent = "qa_content" 7 | FieldQAContentVector = "qa_content_vector" 8 | FieldExtra = "ext" 9 | KnowledgeName = "_knowledge_name" 10 | 11 | RetrieverFieldKey = "_retriever_field" 12 | 13 | Title1 = "h1" 14 | Title2 = "h2" 15 | Title3 = "h3" 16 | 17 | XlsxRow = "_row" 18 | ) 19 | 20 | var ( 21 | // ExtKeys ext 里面需要存储的数据 22 | ExtKeys = []string{"_extension", "_file_name", "_source", Title1, Title2, Title3} 23 | ) 24 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_update_chunk.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 7 | "github.com/wangle201210/go-rag/server/internal/logic/knowledge" 8 | "github.com/wangle201210/go-rag/server/internal/model/entity" 9 | ) 10 | 11 | func (c *ControllerV1) UpdateChunk(ctx context.Context, req *v1.UpdateChunkReq) (res *v1.UpdateChunkRes, err error) { 12 | err = knowledge.UpdateChunkByIds(ctx, req.Ids, entity.KnowledgeChunks{ 13 | Status: req.Status, 14 | }) 15 | if err != nil { 16 | return 17 | } 18 | 19 | return 20 | } 21 | -------------------------------------------------------------------------------- /server/api/rag/v1/indexer.go: -------------------------------------------------------------------------------- 1 | package v1 2 | 3 | import ( 4 | "github.com/gogf/gf/v2/frame/g" 5 | "github.com/gogf/gf/v2/net/ghttp" 6 | ) 7 | 8 | type IndexerReq struct { 9 | g.Meta `path:"/v1/indexer" method:"post" mime:"multipart/form-data" tags:"rag"` 10 | File *ghttp.UploadFile `p:"file" type:"file" 
dc:"如果是本地文件,则直接上传文件"` 11 | URL string `p:"url" dc:"如果是网络文件则直接输入url即可"` 12 | KnowledgeName string `p:"knowledge_name" dc:"知识库名称" v:"required"` 13 | } 14 | 15 | type IndexerRes struct { 16 | g.Meta `mime:"application/json"` 17 | DocIDs []string `json:"doc_ids"` 18 | } 19 | -------------------------------------------------------------------------------- /server/core/common/slices_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/cloudwego/eino/schema" 7 | ) 8 | 9 | func TestRemoveDuplicates(t *testing.T) { 10 | docs := []*schema.Document{ 11 | { 12 | ID: "1", 13 | MetaData: map[string]any{ 14 | "foo": "bar", 15 | }, 16 | }, 17 | { 18 | ID: "2", 19 | }, 20 | { 21 | ID: "3", 22 | }, 23 | { 24 | ID: "1", 25 | }, 26 | } 27 | 28 | docs = RemoveDuplicates(docs, func(t *schema.Document) string { 29 | return t.ID 30 | }) 31 | for i, doc := range docs { 32 | t.Logf("i: %d, doc_id: %v", i, doc.ID) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /server/core/common/helper.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import "net/url" 4 | 5 | func Of[T any](v T) *T { 6 | return &v 7 | } 8 | 9 | func IsURL(str string) bool { 10 | u, err := url.Parse(str) 11 | if err != nil { 12 | return false 13 | } 14 | return u.Scheme != "" && u.Host != "" 15 | } 16 | 17 | func RemoveDuplicates[T any, K comparable](slice []T, keyFunc func(T) K) []T { 18 | encountered := make(map[K]bool) 19 | var result []T 20 | 21 | for _, v := range slice { 22 | key := keyFunc(v) 23 | if !encountered[key] { 24 | encountered[key] = true 25 | result = append(result, v) 26 | } 27 | } 28 | 29 | return result 30 | } 31 | -------------------------------------------------------------------------------- /fe/src/utils/knowledgeIdStore.js: 
-------------------------------------------------------------------------------- 1 | /** 2 | * 知识库ID管理服务 3 | * 用于在不同页面之间共享选中的知识库ID 4 | * 使用localStorage存储,确保页面刷新后仍然可用 5 | */ 6 | 7 | const STORAGE_KEY = 'go_rag_selected_knowledge_id' 8 | 9 | /** 10 | * 获取当前选中的知识库ID 11 | * 如果不存在,则返回空字符串 12 | */ 13 | export function getSelectedKnowledgeId() { 14 | return localStorage.getItem(STORAGE_KEY) || '' 15 | } 16 | 17 | /** 18 | * 设置选中的知识库ID 19 | */ 20 | export function setSelectedKnowledgeId(id) { 21 | localStorage.setItem(STORAGE_KEY, id) 22 | return id 23 | } 24 | 25 | /** 26 | * 清除选中的知识库ID 27 | */ 28 | export function clearSelectedKnowledgeId() { 29 | localStorage.removeItem(STORAGE_KEY) 30 | } -------------------------------------------------------------------------------- /fe/src/App.vue: -------------------------------------------------------------------------------- 1 | 12 | 13 | 30 | -------------------------------------------------------------------------------- /fe/src/utils/format.js: -------------------------------------------------------------------------------- 1 | export const getStatusType = (status) => { 2 | switch (status) { 3 | case 0: return 'info' // 待处理 4 | case 1: return 'warning' // 处理中 5 | case 2: return 'success' // 已完成 6 | case 3: return 'danger' // 失败 7 | default: return 'info' 8 | } 9 | } 10 | 11 | export const getStatusText = (status) => { 12 | switch (status) { 13 | case 0: return '待处理' 14 | case 1: return '处理中' 15 | case 2: return '已完成' 16 | case 3: return '失败' 17 | default: return '未知' 18 | } 19 | } 20 | 21 | export const formatDate = (date) => { 22 | if (!date) return '-' 23 | return new Date(date).toLocaleString('zh-CN') 24 | } -------------------------------------------------------------------------------- /server/core/vector/factory.go: -------------------------------------------------------------------------------- 1 | package vector 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // NewVectorStore 创建向量存储实例 8 | func NewVectorStore(cfg *Config) 
(VectorStore, error) { 9 | switch cfg.Type { 10 | case "es", "elasticsearch": 11 | if cfg.ES == nil { 12 | return nil, fmt.Errorf("es config is required when type is es") 13 | } 14 | return NewESVectorStore(cfg.ES) 15 | case "qdrant": 16 | if cfg.Qdrant == nil { 17 | return nil, fmt.Errorf("qdrant config is required when type is qdrant") 18 | } 19 | return NewQdrantVectorStore(cfg.Qdrant) 20 | default: 21 | return nil, fmt.Errorf("unsupported vector store type: %s", cfg.Type) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /fe/src/components/MessageBoxDemo.vue: -------------------------------------------------------------------------------- 1 | 19 | 20 | 25 | -------------------------------------------------------------------------------- /server/internal/dao/db/interface.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | // Database 数据库接口 4 | type Database interface { 5 | // Connect 连接管理 6 | Connect() error 7 | // DSN 构建连接DSN 8 | DSN() string 9 | // AutoMigrate 迁移表结构 10 | AutoMigrate() error 11 | // Ping 健康检查 12 | Ping() error 13 | } 14 | 15 | type Config struct { 16 | Host string 17 | Port string 18 | User string 19 | Password string 20 | Database string 21 | Charset string 22 | MaxOpenConn int 23 | MaxIdleConn int 24 | LogLevel int // 1: silent 2: Error 3:Warn 4: Info 25 | 26 | // sqlite相关配置 27 | FilePath string 28 | BusyTimeout int 29 | JournalMode string 30 | Synchronous string 31 | CacheSize int 32 | } 33 | -------------------------------------------------------------------------------- /server/internal/model/gorm/knowledge_base.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // KnowledgeBase GORM模型定义 8 | type KnowledgeBase struct { 9 | ID int64 `gorm:"primaryKey;column:id"` 10 | Name string `gorm:"column:name;type:varchar(255)"` 11 | Description string 
`gorm:"column:description;type:varchar(255)"` 12 | Category string `gorm:"column:category;type:varchar(255)"` 13 | Status int `gorm:"column:status;default:1"` 14 | CreateTime time.Time `gorm:"column:created_at"` 15 | UpdateTime time.Time `gorm:"column:updated_at"` 16 | } 17 | 18 | // TableName 设置表名 19 | func (KnowledgeBase) TableName() string { 20 | return "knowledge_base" 21 | } 22 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 多阶段构建 2 | # 阶段2: 构建后端 3 | FROM golang:1.23.9-alpine AS server-builder 4 | WORKDIR /app 5 | # 复制整个项目代码 6 | COPY . . 7 | # 复制前端构建产物到server目录 8 | # 构建后端 9 | RUN cd server && go mod tidy && go build -o go-rag-server main.go 10 | 11 | # 阶段3: 最终镜像 12 | FROM alpine:latest 13 | WORKDIR /app 14 | # 安装运行时依赖 15 | RUN apk --no-cache add ca-certificates tzdata 16 | # 设置时区 17 | ENV TZ=Asia/Shanghai 18 | # 复制后端构建产物 19 | COPY --from=server-builder /app/server/go-rag-server /app/ 20 | COPY --from=server-builder /app/server/static/ /app/static/ 21 | COPY --from=server-builder /app/server/manifest/config/config_demo.yaml /app/manifest/config/config.yaml 22 | 23 | # 暴露端口 24 | EXPOSE 8000 25 | 26 | # 启动命令 27 | CMD ["/app/go-rag-server"] -------------------------------------------------------------------------------- /fe/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "esnext", 4 | "jsx": "preserve", 5 | "lib": ["esnext", "dom"], 6 | "useDefineForClassFields": true, 7 | "baseUrl": ".", 8 | "module": "esnext", 9 | "moduleResolution": "bundler", 10 | "paths": { 11 | "~/*": ["src/*"] 12 | }, 13 | "resolveJsonModule": true, 14 | "types": [ 15 | "vite/client", 16 | "unplugin-vue-router/client" 17 | ], 18 | "strict": true, 19 | "sourceMap": true, 20 | "esModuleInterop": true, 21 | "skipLibCheck": true 22 | }, 23 | "vueCompilerOptions": { 24 | 
"target": 3 25 | }, 26 | "include": ["src/**/*.ts", "src/**/*.d.ts", "src/**/*.tsx", "src/**/*.vue"] 27 | } 28 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_chunks_list.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 7 | "github.com/wangle201210/go-rag/server/internal/logic/knowledge" 8 | "github.com/wangle201210/go-rag/server/internal/model/entity" 9 | ) 10 | 11 | func (c *ControllerV1) ChunksList(ctx context.Context, req *v1.ChunksListReq) (res *v1.ChunksListRes, err error) { 12 | chunks, total, err := knowledge.GetChunksList(ctx, entity.KnowledgeChunks{ 13 | KnowledgeDocId: req.KnowledgeDocId, 14 | }, req.Page, req.Size) 15 | if err != nil { 16 | return 17 | } 18 | return &v1.ChunksListRes{ 19 | Data: chunks, 20 | Total: total, 21 | Page: req.Page, 22 | Size: req.Size, 23 | }, nil 24 | } 25 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_documents_list.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 7 | "github.com/wangle201210/go-rag/server/internal/logic/knowledge" 8 | "github.com/wangle201210/go-rag/server/internal/model/entity" 9 | ) 10 | 11 | func (c *ControllerV1) DocumentsList(ctx context.Context, req *v1.DocumentsListReq) (res *v1.DocumentsListRes, err error) { 12 | documents, total, err := knowledge.GetDocumentsList(ctx, entity.KnowledgeDocuments{ 13 | KnowledgeBaseName: req.KnowledgeName, 14 | }, req.Page, req.Size) 15 | if err != nil { 16 | return 17 | } 18 | 19 | res = &v1.DocumentsListRes{ 20 | Data: documents, 21 | Total: total, 22 | Page: req.Page, 23 | Size: req.Size, 24 | } 25 | 26 | return 27 | } 28 | 
-------------------------------------------------------------------------------- /server/internal/model/gorm/knowledge_documents.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // KnowledgeDocuments GORM模型定义 8 | type KnowledgeDocuments struct { 9 | ID int64 `gorm:"primaryKey;column:id;autoIncrement"` 10 | KnowledgeBaseName string `gorm:"column:knowledge_base_name;type:varchar(255);not null"` 11 | FileName string `gorm:"column:file_name;type:varchar(255)"` 12 | Status int8 `gorm:"column:status;type:tinyint;not null;default:0"` 13 | CreateTime time.Time `gorm:"column:created_at;type:timestamp;autoCreateTime"` 14 | UpdateTime time.Time `gorm:"column:updated_at;type:timestamp;autoUpdateTime"` 15 | } 16 | 17 | // TableName 设置表名 18 | func (KnowledgeDocuments) TableName() string { 19 | return "knowledge_documents" 20 | } 21 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_chat.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/wangle201210/go-rag/server/api/rag/v1" 7 | "github.com/wangle201210/go-rag/server/internal/logic/chat" 8 | ) 9 | 10 | func (c *ControllerV1) Chat(ctx context.Context, req *v1.ChatReq) (res *v1.ChatRes, err error) { 11 | retriever, err := c.Retriever(ctx, &v1.RetrieverReq{ 12 | Question: req.Question, 13 | TopK: req.TopK, 14 | Score: req.Score, 15 | KnowledgeName: req.KnowledgeName, 16 | }) 17 | if err != nil { 18 | return 19 | } 20 | chatI := chat.GetChat() 21 | answer, err := chatI.GetAnswer(ctx, req.ConvID, retriever.Document, req.Question) 22 | if err != nil { 23 | return 24 | } 25 | res = &v1.ChatRes{ 26 | Answer: answer, 27 | References: retriever.Document, 28 | } 29 | return 30 | } 31 | -------------------------------------------------------------------------------- 
/server/core/rerank/rerank_test.go: -------------------------------------------------------------------------------- 1 | package rerank 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/cloudwego/eino/schema" 7 | "github.com/gogf/gf/v2/os/gctx" 8 | ) 9 | 10 | func TestRerank(t *testing.T) { 11 | rerankCfg = &Conf{ 12 | apiKey: "sk-***", 13 | Model: "BAAI/bge-reranker-v2-m3", 14 | ReturnDocuments: false, 15 | MaxChunksPerDoc: 1024, 16 | OverlapTokens: 80, 17 | url: "https://api.siliconflow.cn/v1/rerank", 18 | } 19 | ctx := gctx.New() 20 | docs := []*schema.Document{ 21 | {Content: "banana"}, 22 | {Content: "fruit"}, 23 | {Content: "apple"}, 24 | {Content: "vegetable"}, 25 | } 26 | output, err := Rerank(ctx, "水果", docs, 2) 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | for _, doc := range output { 31 | t.Logf("content: %v, score: %v", doc.Content, doc.Score()) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /server/internal/model/do/knowledge_documents.go: -------------------------------------------------------------------------------- 1 | // ================================================================================= 2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. 3 | // ================================================================================= 4 | 5 | package do 6 | 7 | import ( 8 | "github.com/gogf/gf/v2/frame/g" 9 | "github.com/gogf/gf/v2/os/gtime" 10 | ) 11 | 12 | // KnowledgeDocuments is the golang structure of table knowledge_documents for DAO operations like Where/Data. 
// KnowledgeBase is the golang structure of table knowledge_base for DAO operations like Where/Data.
type KnowledgeBase struct {
	g.Meta      `orm:"table:knowledge_base, do:true"`
	Id          interface{} // Primary key ID
	Name        interface{} // Knowledge base name
	Description interface{} // Knowledge base description
	Category    interface{} // Knowledge base category
	Status      interface{} // Status: 0 = disabled, 1 = enabled
	CreateTime  *gtime.Time // Creation time
	UpdateTime  *gtime.Time // Update time
}
"test_index" 25 | 26 | // 测试创建索引 27 | err = store.CreateIndex(ctx, indexName) 28 | if err != nil { 29 | t.Fatalf("创建索引失败: %v", err) 30 | } 31 | 32 | // 测试索引是否存在 33 | exists, err := store.IndexExists(ctx, indexName) 34 | if err != nil { 35 | t.Fatalf("检查索引失败: %v", err) 36 | } 37 | 38 | if !exists { 39 | t.Errorf("索引应该存在") 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /server/core/retriever/orchestration.go: -------------------------------------------------------------------------------- 1 | package retriever 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudwego/eino/compose" 7 | "github.com/cloudwego/eino/schema" 8 | "github.com/wangle201210/go-rag/server/core/config" 9 | ) 10 | 11 | func BuildRetriever(ctx context.Context, conf *config.Config) (r compose.Runnable[string, []*schema.Document], err error) { 12 | const ( 13 | Retriever1 = "Retriever" 14 | ) 15 | g := compose.NewGraph[string, []*schema.Document]() 16 | retriever1KeyOfRetriever, err := newRetriever(ctx, conf) 17 | if err != nil { 18 | return nil, err 19 | } 20 | _ = g.AddRetrieverNode(Retriever1, retriever1KeyOfRetriever) 21 | _ = g.AddEdge(compose.START, Retriever1) 22 | _ = g.AddEdge(Retriever1, compose.END) 23 | r, err = g.Compile(ctx, compose.WithGraphName("retriever")) 24 | if err != nil { 25 | return nil, err 26 | } 27 | return r, err 28 | } 29 | -------------------------------------------------------------------------------- /server/internal/model/do/knowledge_chunks.go: -------------------------------------------------------------------------------- 1 | // ================================================================================= 2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. 
3 | // ================================================================================= 4 | 5 | package do 6 | 7 | import ( 8 | "github.com/gogf/gf/v2/frame/g" 9 | "github.com/gogf/gf/v2/os/gtime" 10 | ) 11 | 12 | // KnowledgeChunks is the golang structure of table knowledge_chunks for DAO operations like Where/Data. 13 | type KnowledgeChunks struct { 14 | g.Meta `orm:"table:knowledge_chunks, do:true"` 15 | Id interface{} // 16 | KnowledgeDocId interface{} // 17 | ChunkId interface{} // 18 | Content interface{} // 19 | Ext interface{} // 20 | Status interface{} // 21 | CreatedAt *gtime.Time // 22 | UpdatedAt *gtime.Time // 23 | } 24 | -------------------------------------------------------------------------------- /fe/src/components/Logos.vue: -------------------------------------------------------------------------------- 1 | 14 | 15 | 32 | -------------------------------------------------------------------------------- /server/internal/dao/knowledge_chunks.go: -------------------------------------------------------------------------------- 1 | // ================================================================================= 2 | // This file is auto-generated by the GoFrame CLI tool. You may modify it as needed. 3 | // ================================================================================= 4 | 5 | package dao 6 | 7 | import ( 8 | "github.com/wangle201210/go-rag/server/internal/dao/internal" 9 | ) 10 | 11 | // knowledgeChunksDao is the data access object for the table knowledge_chunks. 12 | // You can define custom methods on it to extend its functionality as needed. 13 | type knowledgeChunksDao struct { 14 | *internal.KnowledgeChunksDao 15 | } 16 | 17 | var ( 18 | // KnowledgeChunks is a globally accessible object for table knowledge_chunks operations. 19 | KnowledgeChunks = knowledgeChunksDao{internal.NewKnowledgeChunksDao()} 20 | ) 21 | 22 | // Add your custom methods and functionality below. 
/**
 * Returns the current knowledge-base name.
 *
 * The name is read from localStorage under STORAGE_KEY. When nothing usable
 * is stored (missing or empty value), a new name is generated, persisted
 * under the same key, and returned.
 */
export function getKnowledgeName() {
  const stored = localStorage.getItem(STORAGE_KEY)
  if (stored) {
    return stored
  }

  // Nothing stored yet: create a name once and remember it for later calls.
  const generated = generateKnowledgeName()
  localStorage.setItem(STORAGE_KEY, generated)
  return generated
}
13 | type knowledgeDocumentsDao struct { 14 | *internal.KnowledgeDocumentsDao 15 | } 16 | 17 | var ( 18 | // KnowledgeDocuments is a globally accessible object for table knowledge_documents operations. 19 | KnowledgeDocuments = knowledgeDocumentsDao{internal.NewKnowledgeDocumentsDao()} 20 | ) 21 | 22 | // Add your custom methods and functionality below. 23 | -------------------------------------------------------------------------------- /server/core/common/embedding.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "context" 5 | "os" 6 | 7 | "github.com/cloudwego/eino-ext/components/embedding/openai" 8 | "github.com/cloudwego/eino/components/embedding" 9 | "github.com/wangle201210/go-rag/server/core/config" 10 | ) 11 | 12 | func NewEmbedding(ctx context.Context, conf *config.Config) (eb embedding.Embedder, err error) { 13 | econf := &openai.EmbeddingConfig{ 14 | APIKey: conf.APIKey, 15 | Model: conf.EmbeddingModel, 16 | Dimensions: Of(1024), 17 | Timeout: 0, 18 | BaseURL: conf.BaseURL, 19 | } 20 | if econf.APIKey == "" { 21 | econf.APIKey = os.Getenv("OPENAI_API_KEY") 22 | } 23 | if econf.BaseURL == "" { 24 | econf.BaseURL = os.Getenv("OPENAI_BASE_URL") 25 | } 26 | if econf.Model == "" { 27 | econf.Model = "text-embedding-3-large" 28 | } 29 | eb, err = openai.NewEmbedder(ctx, econf) 30 | if err != nil { 31 | return nil, err 32 | } 33 | return eb, nil 34 | } 35 | -------------------------------------------------------------------------------- /fe/public/favicon.svg: -------------------------------------------------------------------------------- 1 | element plus-logo-small 副本 -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_retriever_dify.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/gogf/gf/v2/frame/g" 7 | 
"github.com/wangle201210/go-rag/server/api/rag/v1" 8 | ) 9 | 10 | func (c *ControllerV1) RetrieverDify(ctx context.Context, req *v1.RetrieverDifyReq) (res *v1.RetrieverDifyRes, err error) { 11 | retriever, err := c.Retriever(ctx, &v1.RetrieverReq{ 12 | Question: req.Query, 13 | TopK: req.RetrievalSetting.TopK, 14 | Score: req.RetrievalSetting.ScoreThreshold, 15 | KnowledgeName: req.KnowledgeID, 16 | }) 17 | if err != nil { 18 | return 19 | } 20 | res = &v1.RetrieverDifyRes{} 21 | for _, document := range retriever.Document { 22 | g.Log().Infof(ctx, "content: %s, score: %f", document.Content, document.Score()) 23 | record := &v1.Record{ 24 | Content: document.Content, 25 | Score: document.Score(), 26 | } 27 | res.Records = append(res.Records, record) 28 | } 29 | return 30 | } 31 | -------------------------------------------------------------------------------- /server/core/indexer/async.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudwego/eino/compose" 7 | "github.com/cloudwego/eino/schema" 8 | "github.com/wangle201210/go-rag/server/core/config" 9 | ) 10 | 11 | func BuildIndexerAsync(ctx context.Context, conf *config.Config) (r compose.Runnable[[]*schema.Document, []string], err error) { 12 | const ( 13 | Indexer = "Indexer" 14 | QA = "QA" 15 | ) 16 | 17 | g := compose.NewGraph[[]*schema.Document, []string]() 18 | indexer2KeyOfIndexer, err := newAsyncIndexer(ctx, conf) 19 | if err != nil { 20 | return nil, err 21 | } 22 | _ = g.AddIndexerNode(Indexer, indexer2KeyOfIndexer) 23 | _ = g.AddLambdaNode(QA, compose.InvokableLambda(qa)) 24 | _ = g.AddEdge(compose.START, QA) 25 | _ = g.AddEdge(QA, Indexer) 26 | _ = g.AddEdge(Indexer, compose.END) 27 | r, err = g.Compile(ctx, compose.WithGraphName("indexer_async")) 28 | if err != nil { 29 | return nil, err 30 | } 31 | return r, err 32 | } 33 | 
-------------------------------------------------------------------------------- /fe/public/element-plus-logo-small.svg: -------------------------------------------------------------------------------- 1 | element plus-logo-small 副本 -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_chunk_delete.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/gogf/gf/v2/frame/g" 7 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 8 | "github.com/wangle201210/go-rag/server/internal/logic/knowledge" 9 | "github.com/wangle201210/go-rag/server/internal/logic/rag" 10 | ) 11 | 12 | func (c *ControllerV1) ChunkDelete(ctx context.Context, req *v1.ChunkDeleteReq) (res *v1.ChunkDeleteRes, err error) { 13 | svr := rag.GetRagSvr() 14 | 15 | chunk, err := knowledge.GetChunkById(ctx, req.Id) 16 | if err != nil { 17 | g.Log().Errorf(ctx, "DeleteDocumentAndChunks: GetChunkById failed for id %v, err: %v", req.Id, err) 18 | return 19 | } 20 | 21 | err = svr.DeleteDocument(ctx, chunk.ChunkId) 22 | if err != nil { 23 | g.Log().Errorf(ctx, "DeleteDocumentAndChunks: ES DeleteByQuery failed for docId %v, err: %v", chunk.ChunkId, err) 24 | return 25 | } 26 | 27 | err = knowledge.DeleteChunkById(ctx, req.Id) 28 | return 29 | } 30 | -------------------------------------------------------------------------------- /fe/src/styles/index.scss: -------------------------------------------------------------------------------- 1 | // import dark theme 2 | @use 'element-plus/theme-chalk/src/dark/css-vars.scss' as *; 3 | 4 | // import common styles 5 | @import './common.css'; 6 | 7 | // :root { 8 | // --ep-color-primary: red; 9 | // } 10 | 11 | body { 12 | font-family: Inter, system-ui, Avenir, 'Helvetica Neue', Helvetica, 'PingFang SC', 'Hiragino Sans GB', 13 | 'Microsoft YaHei', '微软雅黑', Arial, sans-serif; 14 | -webkit-font-smoothing: 
antialiased; 15 | -moz-osx-font-smoothing: grayscale; 16 | margin: 0; 17 | } 18 | 19 | a { 20 | color: var(--ep-color-primary); 21 | } 22 | 23 | code { 24 | border-radius: 2px; 25 | padding: 2px 4px; 26 | background-color: var(--ep-color-primary-light-9); 27 | color: var(--ep-color-primary); 28 | } 29 | 30 | #nprogress { 31 | pointer-events: none; 32 | } 33 | 34 | #nprogress .bar { 35 | background: rgb(13, 148, 136); 36 | opacity: 0.75; 37 | position: fixed; 38 | z-index: 1031; 39 | top: 0; 40 | left: 0; 41 | width: 100%; 42 | height: 2px; 43 | } 44 | -------------------------------------------------------------------------------- /server/internal/dao/knowledge_base.go: -------------------------------------------------------------------------------- 1 | // ================================================================================= 2 | // This file is auto-generated by the GoFrame CLI tool. You may modify it as needed. 3 | // ================================================================================= 4 | 5 | package dao 6 | 7 | import ( 8 | "github.com/wangle201210/go-rag/server/internal/dao/internal" 9 | ) 10 | 11 | // internalKnowledgeBaseDao is an internal type for wrapping the internal DAO implementation. 12 | type internalKnowledgeBaseDao = *internal.KnowledgeBaseDao 13 | 14 | // knowledgeBaseDao is the data access object for the table knowledge_base. 15 | // You can define custom methods on it to extend its functionality as needed. 16 | type knowledgeBaseDao struct { 17 | internalKnowledgeBaseDao 18 | } 19 | 20 | var ( 21 | // KnowledgeBase is a globally accessible object for table knowledge_base operations. 22 | KnowledgeBase = knowledgeBaseDao{ 23 | internal.NewKnowledgeBaseDao(), 24 | } 25 | ) 26 | 27 | // Add your custom methods and functionality below. 
// KnowledgeChunks is the GORM model for the knowledge_chunks table.
type KnowledgeChunks struct {
	ID             int64     `gorm:"primaryKey;column:id;autoIncrement:true"`                               // auto-increment primary key
	KnowledgeDocID int64     `gorm:"column:knowledge_doc_id;not null;index"`                                // ID of the owning knowledge_documents row (indexed)
	ChunkID        string    `gorm:"column:chunk_id;type:varchar(36);not null;uniqueIndex:uk_chunk_id"`     // unique chunk identifier, at most 36 characters
	Content        string    `gorm:"column:content;type:text"`                                              // chunk text
	Ext            string    `gorm:"column:ext;type:varchar(1024)"`                                         // extra data, at most 1024 characters; format not defined here
	Status         int8      `gorm:"column:status;type:tinyint(1);not null;default:1"`                      // status flag, defaults to 1
	CreateTime     time.Time `gorm:"column:created_at;type:timestamp;autoCreateTime"`                       // set automatically on insert
	UpdateTime     time.Time `gorm:"column:updated_at;type:timestamp;autoUpdateTime"`                       // set automatically on update

	// KnowledgeDocument is the owning document; the foreign key cascades on
	// delete and restricts on update.
	KnowledgeDocument KnowledgeDocuments `gorm:"foreignKey:KnowledgeDocID;references:ID;constraint:OnDelete:CASCADE,OnUpdate:RESTRICT"`
}

// TableName returns the database table name used for KnowledgeChunks.
func (KnowledgeChunks) TableName() string {
	return "knowledge_chunks"
}
// UnoCSS configuration for the front end.
export default defineConfig({
  // Reusable class shortcuts: `btn` is a teal button with hover and disabled
  // states; `icon-btn` is a dimmed, clickable icon that highlights on hover.
  shortcuts: [
    ['btn', 'px-4 py-1 rounded inline-block bg-teal-700 text-white cursor-pointer !outline-none hover:bg-teal-800 disabled:cursor-default disabled:bg-gray-600 disabled:opacity-50'],
    ['icon-btn', 'inline-block cursor-pointer select-none opacity-75 transition duration-200 ease-in-out hover:opacity-100 hover:text-teal-600'],
  ],
  presets: [
    presetUno(),
    presetAttributify(),
    // Icons are rendered at 1.2x scale.
    presetIcons({
      scale: 1.2,
    }),
    presetTypography(),
    // Web fonts mapped to the sans / serif / mono font families.
    presetWebFonts({
      fonts: {
        sans: 'DM Sans',
        serif: 'DM Serif Display',
        mono: 'DM Mono',
      },
    }),
  ],
  transformers: [
    transformerDirectives(),
    transformerVariantGroup(),
  ],
  // Safelist built by splitting the class string on spaces.
  safelist: 'prose prose-sm m-auto text-left'.split(' '),
})
================================================================================= 2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. 3 | // ================================================================================= 4 | 5 | package entity 6 | 7 | import ( 8 | "github.com/gogf/gf/v2/os/gtime" 9 | ) 10 | 11 | // KnowledgeDocuments is the golang structure for table knowledge_documents. 12 | type KnowledgeDocuments struct { 13 | Id int64 `json:"id" orm:"id" description:""` // 14 | KnowledgeBaseName string `json:"knowledgeBaseName" orm:"knowledge_base_name" description:""` // 15 | FileName string `json:"fileName" orm:"file_name" description:""` // 16 | Status int `json:"status" orm:"status" description:""` // 17 | CreatedAt *gtime.Time `json:"createdAt" orm:"created_at" description:""` // 18 | UpdatedAt *gtime.Time `json:"updatedAt" orm:"updated_at" description:""` // 19 | } 20 | -------------------------------------------------------------------------------- /server/core/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/cloudwego/eino-ext/components/model/openai" 5 | "github.com/elastic/go-elasticsearch/v8" 6 | "github.com/qdrant/go-client/qdrant" 7 | ) 8 | 9 | type Config struct { 10 | Client *elasticsearch.Client // ES 客户端 11 | QdrantClient *qdrant.Client // Qdrant 客户端 12 | IndexName string // index name / collection name 13 | // embedding 时使用 14 | APIKey string 15 | BaseURL string 16 | EmbeddingModel string 17 | ChatModel string 18 | } 19 | 20 | func (x *Config) GetChatModelConfig() *openai.ChatModelConfig { 21 | if x == nil { 22 | return nil 23 | } 24 | return &openai.ChatModelConfig{ 25 | APIKey: x.APIKey, 26 | BaseURL: x.BaseURL, 27 | Model: x.ChatModel, 28 | } 29 | } 30 | 31 | func (x *Config) Copy() *Config { 32 | return &Config{ 33 | Client: x.Client, 34 | QdrantClient: x.QdrantClient, 35 | IndexName: x.IndexName, 36 | // embedding 时使用 
37 | APIKey: x.APIKey, 38 | BaseURL: x.BaseURL, 39 | EmbeddingModel: x.EmbeddingModel, 40 | ChatModel: x.ChatModel, 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | go.work.sum 23 | 24 | # env file 25 | .env 26 | .idea 27 | .vscode 28 | .cursor 29 | 30 | server/manifest/config/config.yaml 31 | /server/uploads/ 32 | **/package-lock.json 33 | /server/static/fe/ 34 | **/.vite 35 | /server/go-rag-server 36 | /server/logs/ 37 | /data/ 38 | /server/manifest/config/config-docker.yaml 39 | /server/manifest/config/frpc.toml 40 | /fe/src/components.d.ts 41 | **/node_modules/ 42 | **/dist/ 43 | **/build/ 44 | .DS_Store 45 | *.md 46 | *.pdf 47 | *.yaml 48 | *.yml 49 | *.db 50 | *.sqlite3 51 | *.lock 52 | *.json 53 | /storage/ 54 | /snapshots/ 55 | /server/?_journal_mode=WAL 56 | /static/ 57 | /releases/ 58 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_documents_delete.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/gogf/gf/v2/frame/g" 7 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 8 | "github.com/wangle201210/go-rag/server/internal/logic/knowledge" 9 | 
"github.com/wangle201210/go-rag/server/internal/logic/rag" 10 | ) 11 | 12 | func (c *ControllerV1) DocumentsDelete(ctx context.Context, req *v1.DocumentsDeleteReq) (res *v1.DocumentsDeleteRes, err error) { 13 | svr := rag.GetRagSvr() 14 | 15 | ChunksList, err := knowledge.GetAllChunksByDocId(ctx, req.DocumentId, "id", "chunk_id") 16 | if err != nil { 17 | g.Log().Errorf(ctx, "DeleteDocumentAndChunks: GetAllChunksByDocId failed for id %d, err: %v", req.DocumentId, err) 18 | return 19 | } 20 | 21 | if len(ChunksList) > 0 { 22 | for _, chunk := range ChunksList { 23 | if chunk.ChunkId != "" { 24 | err = svr.DeleteDocument(ctx, chunk.ChunkId) 25 | if err != nil { 26 | g.Log().Errorf(ctx, "DeleteDocumentAndChunks: ES DeleteByQuery failed for docId %v, err: %v", chunk.ChunkId, err) 27 | return 28 | } 29 | } 30 | } 31 | } 32 | 33 | err = knowledge.DeleteDocument(ctx, req.DocumentId) 34 | return 35 | } 36 | -------------------------------------------------------------------------------- /server/internal/model/entity/knowledge_base.go: -------------------------------------------------------------------------------- 1 | // ================================================================================= 2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. 3 | // ================================================================================= 4 | 5 | package entity 6 | 7 | import ( 8 | "github.com/gogf/gf/v2/os/gtime" 9 | ) 10 | 11 | // KnowledgeBase is the golang structure for table knowledge_base. 
12 | type KnowledgeBase struct { 13 | Id int64 `json:"id" orm:"id" description:"主键ID"` // 主键ID 14 | Name string `json:"name" orm:"name" description:"知识库名称"` // 知识库名称 15 | Description string `json:"description" orm:"description" description:"知识库描述"` // 知识库描述 16 | Category string `json:"category" orm:"category" description:"知识库分类"` // 知识库分类 17 | Status int `json:"status" orm:"status" description:"状态:0-禁用,1-启用"` // 状态:0-禁用,1-启用 18 | CreateTime *gtime.Time `json:"createTime" orm:"create_time" description:"创建时间"` // 创建时间 19 | UpdateTime *gtime.Time `json:"updateTime" orm:"update_time" description:"更新时间"` // 更新时间 20 | } 21 | -------------------------------------------------------------------------------- /server/core/indexer/parser.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudwego/eino-ext/components/document/parser/html" 7 | "github.com/cloudwego/eino-ext/components/document/parser/pdf" 8 | "github.com/cloudwego/eino-ext/components/document/parser/xlsx" 9 | 10 | "github.com/cloudwego/eino/components/document/parser" 11 | "github.com/wangle201210/go-rag/server/core/common" 12 | ) 13 | 14 | func newParser(ctx context.Context) (p parser.Parser, err error) { 15 | textParser := parser.TextParser{} 16 | 17 | htmlParser, err := html.NewParser(ctx, &html.Config{ 18 | Selector: common.Of("body"), 19 | }) 20 | if err != nil { 21 | return nil, err 22 | } 23 | xlsxParser, err := xlsx.NewXlsxParser(ctx, nil) 24 | 25 | pdfParser, err := pdf.NewPDFParser(ctx, &pdf.Config{}) 26 | if err != nil { 27 | return 28 | } 29 | 30 | // 创建扩展解析器 31 | p, err = parser.NewExtParser(ctx, &parser.ExtParserConfig{ 32 | // 注册特定扩展名的解析器 33 | Parsers: map[string]parser.Parser{ 34 | ".html": htmlParser, 35 | ".pdf": pdfParser, 36 | ".xlsx": xlsxParser, 37 | }, 38 | // 设置默认解析器,用于处理未知格式 39 | FallbackParser: textParser, 40 | }) 41 | if err != nil { 42 | return nil, err 43 | } 44 | return 45 | } 
46 | -------------------------------------------------------------------------------- /roadmap.md: -------------------------------------------------------------------------------- 1 | # 项目路线图 2 | 3 | ## 当前状态 4 | - 版本:0.0.* (开发中) 5 | - 急需核心功能:`QA & embedding 异步执行`、`chunk管理` 6 | 7 | ## 版本规划 8 | 9 | ### 0.1.0 - 核心功能完善 10 | #### 异步处理 11 | - [x] QA & embedding 异步执行 12 | - 现状:同步处理,上传文档时立即执行 split→QA生成→embedding 13 | - 优化:将split后的处理流程改为异步 14 | 15 | #### 数据管理 16 | - [x] chunk管理 17 | - 问题:缺乏文档-chunk映射关系,无法编辑单个chunk 18 | - 方案: 19 | 1. ES存储chunk时同步记录映射关系到MySQL 20 | 2. 实现类似ragflow的数据集管理功能 21 | 22 | #### 检索增强 23 | - [ ] 稀疏向量(关键词)检索 24 | - 现状:仅支持稠密向量 25 | - 目标:增加稀疏向量检索路径,提升召回率 26 | 27 | #### 数据导入 28 | - [ ] excel(csv) QA对导入 29 | - 基础:eino已支持excel解析 30 | - 实现:快速接入excel/csv格式QA对导入 31 | 32 | #### 国际化 33 | - [ ] 多语言支持(i18n) 34 | - 目标:支持界面和内容的国际化 35 | 36 | ### 0.2.0 - 进阶功能 37 | #### 智能增强 38 | - [ ] Agentic RAG 39 | - 功能:通过智能体实现动态规划和自主决策 40 | 41 | #### 文档解析 42 | - [ ] 解析优化 43 | - 现状:基础pdf/txt/html解析 44 | - 优化: 45 | 1. 引入第三方API提升解析质量(如mineru) 46 | 2. 新增ppt/docx等格式支持 47 | 3. 图片解析 48 | 4. 用户自定义文档解析逻辑 49 | 50 | #### 用户系统 51 | - [ ] 添加用户体系 52 | - 问题:知识库全局可见 53 | - 方案: 54 | 1. 用户登录/鉴权 55 | 2. 知识库用户隔离 56 | 57 | - [ ] 用户配置分离 58 | - 功能: 59 | 1. 自定义模型提供商 60 | 2. 个人API_KEY管理 61 | 62 | #### 多向量库支持 63 | - [ ] 多数据库支持 64 | - 现状:只支持 es 65 | - 优化:支持postgre、milvus等向量数据库 66 | -------------------------------------------------------------------------------- /server/internal/model/entity/knowledge_chunks.go: -------------------------------------------------------------------------------- 1 | // ================================================================================= 2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. 3 | // ================================================================================= 4 | 5 | package entity 6 | 7 | import ( 8 | "github.com/gogf/gf/v2/os/gtime" 9 | ) 10 | 11 | // KnowledgeChunks is the golang structure for table knowledge_chunks. 
// ChatReq is the request for POST /v1/chat. ConvID, Question and
// KnowledgeName are required; TopK and Score are optional.
type ChatReq struct {
	g.Meta        `path:"/v1/chat" method:"post" tags:"rag"`
	ConvID        string  `json:"conv_id" v:"required"` // conversation ID
	Question      string  `json:"question" v:"required"`
	KnowledgeName string  `json:"knowledge_name" v:"required"`
	TopK          int     `json:"top_k"` // documented default: 5 (the default is not applied by this struct)
	Score         float64 `json:"score"` // documented default: 0.2 (the default is not applied by this struct)
}
-------------------------------------------------------------------------------- /server/api/rag/v1/documents.go: -------------------------------------------------------------------------------- 1 | package v1 2 | 3 | import ( 4 | "github.com/gogf/gf/v2/frame/g" 5 | "github.com/wangle201210/go-rag/server/internal/model/entity" 6 | ) 7 | 8 | const ( 9 | StatusPending Status = 0 10 | StatusIndexing Status = 1 11 | StatusActive Status = 2 12 | StatusFailed Status = 3 13 | ) 14 | 15 | type DocumentsListReq struct { 16 | g.Meta `path:"/v1/documents" method:"get" tags:"rag"` 17 | KnowledgeName string `p:"knowledge_name" dc:"knowledge_name" v:"required|length:3,50"` 18 | Page int `p:"page" dc:"page" v:"required|min:1" d:"1"` 19 | Size int `p:"size" dc:"size" v:"required|min:1|max:100" d:"10"` 20 | } 21 | 22 | type DocumentsListRes struct { 23 | g.Meta `mime:"application/json"` 24 | Data []entity.KnowledgeDocuments `json:"data"` 25 | Total int `json:"total"` 26 | Page int `json:"page"` 27 | Size int `json:"size"` 28 | } 29 | 30 | type DocumentsDeleteReq struct { 31 | g.Meta `path:"/v1/documents" method:"delete" tags:"rag" summary:"Delete a document and its chunks"` 32 | DocumentId int64 `p:"document_id" dc:"document_id" v:"required"` 33 | } 34 | 35 | type DocumentsDeleteRes struct { 36 | g.Meta `mime:"application/json"` 37 | } 38 | -------------------------------------------------------------------------------- /server/internal/dao/db/mysql.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "gorm.io/driver/mysql" 8 | "gorm.io/gorm" 9 | "gorm.io/gorm/logger" 10 | ) 11 | 12 | type Mysql struct { 13 | Base 14 | cfg *Config 15 | } 16 | 17 | func NewMysql(cfg *Config) *Mysql { 18 | if cfg.LogLevel == 0 { 19 | cfg.LogLevel = 4 20 | } 21 | return &Mysql{cfg: cfg} 22 | } 23 | 24 | func (m *Mysql) Connect() error { 25 | dsn := m.DSN() 26 | dialect := mysql.Open(dsn) 27 | config := 
&gorm.Config{ 28 | Logger: logger.Default.LogMode(logger.LogLevel(m.cfg.LogLevel)), 29 | } 30 | db, err := gorm.Open(dialect, config) 31 | if err != nil { 32 | return fmt.Errorf("failed to connect database: %v", err) 33 | } 34 | 35 | sqlDB, err := db.DB() 36 | if err != nil { 37 | return fmt.Errorf("failed to get database instance: %v", err) 38 | } 39 | 40 | sqlDB.SetMaxIdleConns(m.cfg.MaxIdleConn) 41 | sqlDB.SetMaxOpenConns(m.cfg.MaxOpenConn) 42 | 43 | sqlDB.SetConnMaxLifetime(time.Hour) 44 | 45 | m.DB = db 46 | return nil 47 | } 48 | 49 | func (m *Mysql) DSN() string { 50 | var dsn string 51 | dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=True&loc=Local", 52 | m.cfg.User, m.cfg.Password, m.cfg.Host, m.cfg.Port, m.cfg.Database, m.cfg.Charset) 53 | return dsn 54 | } 55 | -------------------------------------------------------------------------------- /server/internal/mcp/knowledgebase.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 8 | "github.com/gogf/gf/v2/frame/g" 9 | "github.com/gogf/gf/v2/os/gctx" 10 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 11 | ) 12 | 13 | type KnowledgeBaseParam struct { 14 | } 15 | 16 | func GetKnowledgeBaseTool() *protocol.Tool { 17 | tool, err := protocol.NewTool("getKnowledgeBaseList", "获取知识库列表", KnowledgeBaseParam{}) 18 | if err != nil { 19 | g.Log().Errorf(gctx.New(), "Failed to create tool: %v", err) 20 | return nil 21 | } 22 | return tool 23 | } 24 | 25 | func HandleKnowledgeBase(ctx context.Context, toolReq *protocol.CallToolRequest) (res *protocol.CallToolResult, err error) { 26 | statusOK := v1.StatusOK 27 | getList, err := c.KBGetList(ctx, &v1.KBGetListReq{ 28 | Status: &statusOK, 29 | }) 30 | if err != nil { 31 | return nil, err 32 | } 33 | list := getList.List 34 | msg := fmt.Sprintf("get %d knowledgeBase", len(list)) 35 | for _, l := range list { 36 | msg 
+= fmt.Sprintf("\n - name: %s, description: %s", l.Name, l.Description) 37 | } 38 | return &protocol.CallToolResult{ 39 | Content: []protocol.Content{ 40 | &protocol.TextContent{ 41 | Type: "text", 42 | Text: msg, 43 | }, 44 | }, 45 | }, nil 46 | } 47 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_chat_stream.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudwego/eino/schema" 7 | "github.com/gogf/gf/v2/frame/g" 8 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 9 | "github.com/wangle201210/go-rag/server/core/common" 10 | "github.com/wangle201210/go-rag/server/internal/logic/chat" 11 | ) 12 | 13 | // ChatStream retrieves documents for the question and streams the chat answer over the HTTP response; every failure is logged and returned to the caller. 14 | func (c *ControllerV1) ChatStream(ctx context.Context, req *v1.ChatStreamReq) (res *v1.ChatStreamRes, err error) { 15 | var streamReader *schema.StreamReader[*schema.Message] 16 | // Fetch the retrieval results for the question. 17 | retriever, err := c.Retriever(ctx, &v1.RetrieverReq{ 18 | Question: req.Question, 19 | TopK: req.TopK, 20 | Score: req.Score, 21 | KnowledgeName: req.KnowledgeName, 22 | }) 23 | if err != nil { 24 | g.Log().Error(ctx, err) 25 | return 26 | } 27 | // Obtain the Chat instance. 28 | chatI := chat.GetChat() 29 | // Request the streaming answer. 30 | streamReader, err = chatI.GetAnswerStream(ctx, req.ConvID, retriever.Document, req.Question) 31 | if err != nil { 32 | g.Log().Error(ctx, err) 33 | return nil, err // propagate the failure instead of reporting an empty success 34 | } 35 | defer streamReader.Close() 36 | err = common.SteamResponse(ctx, streamReader, retriever.Document) 37 | if err != nil { 38 | g.Log().Error(ctx, err) 39 | return 40 | } 41 | return &v1.ChatStreamRes{}, nil 42 | } 43 | -------------------------------------------------------------------------------- /fe/README.md: -------------------------------------------------------------------------------- 1 | # element-plus-vite-starter 2 | 3 | > A starter kit for Element Plus with Vite 4 | 5 | - 
Preview: 6 | 7 | This is an example of on-demand element-plus with [unplugin-vue-components](https://github.com/antfu/unplugin-vue-components). 8 | 9 | > If you want to import all, it may be so simple that no examples are needed. Just follow [quickstart | Docs](https://element-plus.org/zh-CN/guide/quickstart.html) and import them. 10 | 11 | If you just want an on-demand import example `manually`, you can check [unplugin-element-plus/examples/vite](https://github.com/element-plus/unplugin-element-plus/tree/main/examples/vite). 12 | 13 | If you want a Nuxt starter, see [element-plus-nuxt-starter](https://github.com/element-plus/element-plus-nuxt-starter/). 14 | 15 | ## Project setup 16 | 17 | ```bash 18 | pnpm install 19 | 20 | # npm install 21 | # yarn install 22 | ``` 23 | 24 | ### Compiles and hot-reloads for development 25 | 26 | ```bash 27 | npm run dev 28 | ``` 29 | 30 | ### Compiles and minifies for production 31 | 32 | ```bash 33 | npm run build 34 | ``` 35 | 36 | ## Usage 37 | 38 | ```bash 39 | git clone https://github.com/element-plus/element-plus-vite-starter 40 | cd element-plus-vite-starter 41 | npm i 42 | npm run dev 43 | ``` 44 | 45 | ### Custom theme 46 | 47 | See `src/styles/element/index.scss`. 
48 | -------------------------------------------------------------------------------- /fe/src/components/layouts/BaseHeader.vue: -------------------------------------------------------------------------------- 1 | 6 | 7 | 31 | 32 | 48 | -------------------------------------------------------------------------------- /fe/src/main.ts: -------------------------------------------------------------------------------- 1 | import type { UserModule } from './types' 2 | import { ViteSSG } from 'vite-ssg' 3 | 4 | // import "~/styles/element/index.scss"; 5 | 6 | // import ElementPlus from "element-plus"; 7 | // import all element css, uncommented next line 8 | // import "element-plus/dist/index.css"; 9 | 10 | // or use cdn, uncomment cdn link in `index.html` 11 | 12 | import App from './App.vue' 13 | import router from './router' 14 | 15 | import '~/styles/index.scss' 16 | 17 | import 'uno.css' 18 | // If you want to use ElMessage, import it. 19 | import 'element-plus/theme-chalk/src/message.scss' 20 | import 'element-plus/theme-chalk/src/message-box.scss' 21 | 22 | // if you do not need ssg: 23 | // import { createApp } from "vue"; 24 | 25 | // const app = createApp(App); 26 | // app.use(createRouter({ 27 | // history: createWebHistory(), 28 | // routes, 29 | // })) 30 | // // app.use(ElementPlus); 31 | // app.mount("#app"); 32 | 33 | // https://github.com/antfu/vite-ssg 34 | export const createApp = ViteSSG( 35 | App, 36 | { 37 | routes: router.options.routes, 38 | base: import.meta.env.BASE_URL, 39 | history: router.options.history, 40 | }, 41 | (ctx) => { 42 | // install all modules under `modules/` 43 | Object.values(import.meta.glob<{ install: UserModule }>('./modules/*.ts', { eager: true })) 44 | .forEach(i => i.install?.(ctx)) 45 | // ctx.app.use(Previewer) 46 | }, 47 | ) 48 | -------------------------------------------------------------------------------- /server/core/indexer/loader.go: -------------------------------------------------------------------------------- 
1 | package indexer 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudwego/eino-ext/components/document/loader/file" 7 | "github.com/cloudwego/eino-ext/components/document/loader/url" 8 | "github.com/cloudwego/eino/components/document" 9 | "github.com/cloudwego/eino/schema" 10 | "github.com/wangle201210/go-rag/server/core/common" 11 | ) 12 | 13 | // newLoader component initialization function of node 'Loader1' in graph 'rag' 14 | func newLoader(ctx context.Context) (ldr document.Loader, err error) { 15 | mldr := &multiLoader{} 16 | parser, err := newParser(ctx) 17 | if err != nil { 18 | return nil, err 19 | } 20 | fldr, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{ 21 | UseNameAsID: false, 22 | Parser: parser, 23 | }) 24 | if err != nil { 25 | return nil, err 26 | } 27 | mldr.fileLoader = fldr 28 | uldr, err := url.NewLoader(ctx, &url.LoaderConfig{}) 29 | if err != nil { 30 | return nil, err 31 | } 32 | mldr.urlLoader = uldr 33 | return mldr, nil 34 | } 35 | 36 | type multiLoader struct { 37 | fileLoader document.Loader 38 | urlLoader document.Loader 39 | } 40 | 41 | func (x *multiLoader) Load(ctx context.Context, src document.Source, opts ...document.LoaderOption) ([]*schema.Document, error) { 42 | if common.IsURL(src.URI) { 43 | return x.urlLoader.Load(ctx, src, opts...) 44 | } 45 | return x.fileLoader.Load(ctx, src, opts...) 
46 | } 47 | -------------------------------------------------------------------------------- /server/core/grader/grader.go: -------------------------------------------------------------------------------- 1 | package grader 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/cloudwego/eino/components/model" 9 | "github.com/cloudwego/eino/schema" 10 | "github.com/gogf/gf/v2/frame/g" 11 | ) 12 | 13 | type Grader struct { 14 | cm model.BaseChatModel 15 | } 16 | 17 | func NewGrader(cm model.BaseChatModel) *Grader { 18 | return &Grader{ 19 | cm: cm, 20 | } 21 | } 22 | 23 | // Retriever 检查下检索到的结果是否能够回答当前问题 24 | func (x *Grader) Retriever(ctx context.Context, docs []*schema.Document, question string) (pass bool, err error) { 25 | messages, err := retrieverMessages(docs, question) 26 | if err != nil { 27 | return 28 | } 29 | result, err := x.cm.Generate(ctx, messages) 30 | if err != nil { 31 | return false, fmt.Errorf("检查下检索到的结果是否能够回答当前问题失败: %v", err) 32 | } 33 | pass = isPass(result.Content) 34 | return 35 | } 36 | 37 | func (x *Grader) Related(ctx context.Context, doc *schema.Document, question string) (pass bool, err error) { 38 | messages, err := docRelatedMessages(doc, question) 39 | if err != nil { 40 | return 41 | } 42 | result, err := x.cm.Generate(ctx, messages) 43 | if err != nil { 44 | return false, fmt.Errorf("检查下检索到的结果是否和用户问题相关失败: %v", err) 45 | } 46 | pass = isPass(result.Content) 47 | return 48 | } 49 | 50 | func isPass(msg string) bool { 51 | g.Log().Infof(context.Background(), "isPass: %s", msg) 52 | msg = strings.ToLower(msg) 53 | return strings.Contains(msg, "yes") 54 | } 55 | -------------------------------------------------------------------------------- /fe/public/vite.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_indexer.go: 
-------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/gogf/gf/v2/frame/g" 7 | gorag "github.com/wangle201210/go-rag/server/core" 8 | "github.com/wangle201210/go-rag/server/internal/logic/knowledge" 9 | "github.com/wangle201210/go-rag/server/internal/logic/rag" 10 | "github.com/wangle201210/go-rag/server/internal/model/entity" 11 | 12 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 13 | ) 14 | 15 | // Indexer saves a pending document record for the uploaded file or the given URL, then indexes that source into the knowledge base. 16 | // It returns the IDs produced by the RAG service's Index call. 17 | func (c *ControllerV1) Indexer(ctx context.Context, req *v1.IndexerReq) (res *v1.IndexerRes, err error) { 18 | svr := rag.GetRagSvr() 19 | uri := req.URL 20 | // Without an upload req.File is nil, so the URL doubles as the document name instead of dereferencing req.File. 21 | fileName := req.URL 22 | if req.File != nil { 23 | filename, e := req.File.Save("./uploads/") 24 | if e != nil { 25 | err = e 26 | return 27 | } 28 | uri = "./uploads/" + filename 29 | fileName = req.File.Filename 30 | } 31 | 32 | documents := entity.KnowledgeDocuments{ 33 | KnowledgeBaseName: req.KnowledgeName, 34 | FileName: fileName, 35 | Status: int(v1.StatusPending), 36 | } 37 | documentsId, err := knowledge.SaveDocumentsInfo(ctx, documents) 38 | if err != nil { 39 | g.Log().Errorf(ctx, "SaveDocumentsInfo failed, err=%v", err) 40 | return 41 | } 42 | 43 | indexReq := &gorag.IndexReq{ 44 | URI: uri, 45 | KnowledgeName: req.KnowledgeName, 46 | DocumentsId: documentsId, 47 | } 48 | ids, err := svr.Index(ctx, indexReq) 49 | if err != nil { 50 | return 51 | } 52 | res = &v1.IndexerRes{ 53 | DocIDs: ids, 54 | } 55 | return 56 | } 57 | -------------------------------------------------------------------------------- /fe/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "go-rag", 3 | "type": "module", 4 | "version": "0.1.0", 5 | "private": true, 6 | "packageManager": "pnpm@9.15.0", 7 | "license": "MIT", 8 | "homepage": "https://github.com/wangle201210/go-rag", 9 | "repository": { 10 | "url": "https://github.com/wangle201210/go-rag" 11 | }, 12 | "scripts": { 13 | "dev": "vite", 14 | "build": "vite build", 15 
| "generate": "vite-ssg build", 16 | "lint": "eslint .", 17 | "preview": "vite preview", 18 | "typecheck": "vue-tsc --noEmit" 19 | }, 20 | "dependencies": { 21 | "@element-plus/icons-vue": "^2.3.1", 22 | "@vueuse/core": "^12.0.0", 23 | "axios": "^1.4.0", 24 | "dompurify": "^3.0.5", 25 | "element-plus": "^2.9.0", 26 | "highlight.js": "^11.8.0", 27 | "marked": "^4.3.0", 28 | "nprogress": "^0.2.0", 29 | "pinia": "^3.0.3", 30 | "uuid": "^9.0.0", 31 | "vue": "^3.5.13", 32 | "vue-router": "^4.5.0" 33 | }, 34 | "devDependencies": { 35 | "@antfu/eslint-config": "^3.11.2", 36 | "@iconify-json/ep": "^1.2.1", 37 | "@iconify-json/ri": "^1.2.3", 38 | "@types/node": "^20.17.10", 39 | "@unocss/eslint-plugin": "^0.65.1", 40 | "@vitejs/plugin-vue": "^5.2.1", 41 | "eslint": "^9.16.0", 42 | "eslint-plugin-format": "^0.1.3", 43 | "sass": "^1.82.0", 44 | "typescript": "^5.6.3", 45 | "unocss": "^0.65.1", 46 | "unplugin-vue-components": "^0.27.5", 47 | "unplugin-vue-router": "^0.10.9", 48 | "vite": "^6.0.3", 49 | "vite-ssg": "^0.24.2", 50 | "vue-tsc": "^2.1.10" 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /fe/src/components/layouts/BaseSide.vue: -------------------------------------------------------------------------------- 1 | 32 | 33 | 56 | -------------------------------------------------------------------------------- /server/api/rag/v1/retriever.go: -------------------------------------------------------------------------------- 1 | package v1 2 | 3 | import ( 4 | "github.com/cloudwego/eino/schema" 5 | "github.com/gogf/gf/v2/frame/g" 6 | ) 7 | 8 | type RetrieverReq struct { 9 | g.Meta `path:"/v1/retriever" method:"post" tags:"rag"` 10 | Question string `json:"question" v:"required"` 11 | TopK int `json:"top_k"` // 默认为5 12 | Score float64 `json:"score"` // 默认为0.2 13 | KnowledgeName string `json:"knowledge_name" v:"required"` 14 | } 15 | 16 | type RetrieverRes struct { 17 | g.Meta `mime:"application/json"` 18 | Document 
[]*schema.Document `json:"document"` 19 | } 20 | 21 | type RetrieverDifyReq struct { 22 | g.Meta `path:"/v1/dify/retrieval" method:"post" tags:"rag" no_wrap_resp:"true"` 23 | KnowledgeID string `json:"knowledge_id" v:"required"` 24 | Query string `json:"query" v:"required"` 25 | RetrievalSetting *RetrievalSetting `json:"retrieval_setting" v:"required"` 26 | // MetadataCondition map[string]interface{} `json:"metadata_condition"` 27 | } 28 | 29 | type RetrievalSetting struct { 30 | TopK int `json:"top_k"` 31 | ScoreThreshold float64 `json:"score_threshold"` 32 | } 33 | type RetrieverDifyRes struct { 34 | g.Meta `mime:"application/json"` 35 | Records []*Record `json:"records"` 36 | } 37 | 38 | type Record struct { 39 | Metadata *Metadata `json:"metadata"` 40 | Score float64 `json:"score"` 41 | Title string `json:"title"` 42 | Content string `json:"content"` 43 | } 44 | 45 | type Metadata struct { 46 | Path string `json:"path"` 47 | Description string `json:"description"` 48 | } 49 | -------------------------------------------------------------------------------- /server/core/vector/interface.go: -------------------------------------------------------------------------------- 1 | package vector 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudwego/eino/schema" 7 | ) 8 | 9 | // VectorStore 向量存储接口 10 | type VectorStore interface { 11 | // CreateIndex 创建索引 12 | CreateIndex(ctx context.Context, indexName string) error 13 | 14 | // IndexExists 检查索引是否存在 15 | IndexExists(ctx context.Context, indexName string) (bool, error) 16 | 17 | // DeleteDocument 删除文档 18 | DeleteDocument(ctx context.Context, indexName, documentID string) error 19 | 20 | // GetKnowledgeBaseList 获取知识库列表 21 | GetKnowledgeBaseList(ctx context.Context, indexName string) ([]string, error) 22 | 23 | // SearchDocuments 搜索文档 24 | SearchDocuments(ctx context.Context, req *SearchRequest) (*SearchResponse, error) 25 | 26 | // Close 关闭连接 27 | Close() error 28 | } 29 | 30 | // SearchRequest 搜索请求 31 | type 
SearchRequest struct { 32 | IndexName string 33 | Query interface{} // 查询条件,不同实现可能不同 34 | Size int 35 | KnowledgeName string 36 | DocIDs []string 37 | } 38 | 39 | // SearchResponse 搜索响应 40 | type SearchResponse struct { 41 | Documents []*schema.Document 42 | Total int64 43 | } 44 | 45 | // Config 向量存储配置 46 | type Config struct { 47 | Type string // 类型:es 或 qdrant 48 | IndexName string // 索引名称 49 | ES *ESConfig // ES 配置 50 | Qdrant *QdrantConfig // Qdrant 配置 51 | } 52 | 53 | // ESConfig Elasticsearch 配置 54 | type ESConfig struct { 55 | Address string 56 | Username string 57 | Password string 58 | } 59 | 60 | // QdrantConfig Qdrant 配置 61 | type QdrantConfig struct { 62 | Address string 63 | Port int 64 | APIKey string 65 | } 66 | -------------------------------------------------------------------------------- /server/api/rag/v1/chunks.go: -------------------------------------------------------------------------------- 1 | package v1 2 | 3 | import ( 4 | "github.com/gogf/gf/v2/frame/g" 5 | "github.com/wangle201210/go-rag/server/internal/model/entity" 6 | ) 7 | 8 | type ChunksListReq struct { 9 | g.Meta `path:"/v1/chunks" method:"get" tags:"rag"` 10 | KnowledgeDocId int64 `p:"knowledge_doc_id" dc:"knowledge_doc_id" v:"required"` 11 | Page int `p:"page" dc:"page" v:"required|min:1" d:"1"` 12 | Size int `p:"size" dc:"size" v:"required|min:1|max:100" d:"10"` 13 | } 14 | 15 | type ChunksListRes struct { 16 | g.Meta `mime:"application/json"` 17 | Data []entity.KnowledgeChunks `json:"data"` 18 | Total int `json:"total"` 19 | Page int `json:"page"` 20 | Size int `json:"size"` 21 | } 22 | 23 | type ChunkDeleteReq struct { 24 | g.Meta `path:"/v1/chunks" method:"delete" tags:"rag"` 25 | Id int64 `p:"id" dc:"id" v:"required"` 26 | } 27 | 28 | type ChunkDeleteRes struct { 29 | g.Meta `mime:"application/json"` 30 | } 31 | 32 | type UpdateChunkReq struct { 33 | g.Meta `path:"/v1/chunks" method:"put" tags:"rag"` 34 | Ids []int64 `p:"ids" dc:"ids" v:"required"` 35 | Status int 
`p:"status" dc:"status" v:"required|in:0,1"` 36 | } 37 | 38 | type UpdateChunkRes struct { 39 | g.Meta `mime:"application/json"` 40 | } 41 | 42 | type UpdateChunkContentReq struct { 43 | g.Meta `path:"/v1/chunks_content" method:"put" tags:"rag"` 44 | Id int64 `p:"id" dc:"id" v:"required"` 45 | Content string `p:"content" dc:"content" v:"required"` 46 | } 47 | 48 | type UpdateChunkContentRes struct { 49 | g.Meta `mime:"application/json"` 50 | } 51 | -------------------------------------------------------------------------------- /fe/vite.config.ts: -------------------------------------------------------------------------------- 1 | import path from 'node:path' 2 | import Vue from '@vitejs/plugin-vue' 3 | 4 | import Unocss from 'unocss/vite' 5 | import { ElementPlusResolver } from 'unplugin-vue-components/resolvers' 6 | import Components from 'unplugin-vue-components/vite' 7 | 8 | import { defineConfig } from 'vite' 9 | 10 | // https://vitejs.dev/config/ 11 | export default defineConfig({ 12 | resolve: { 13 | alias: { 14 | '~/': `${path.resolve(__dirname, 'src')}/`, 15 | }, 16 | }, 17 | 18 | server: { 19 | proxy: { 20 | '/api': { 21 | target: 'http://localhost:8000', 22 | changeOrigin: true, 23 | rewrite: path => path.replace(/^\/api/, '/api'), 24 | }, 25 | }, 26 | }, 27 | 28 | css: { 29 | preprocessorOptions: { 30 | scss: { 31 | additionalData: `@use "~/styles/element/index.scss" as *;`, 32 | api: 'modern-compiler', 33 | }, 34 | }, 35 | }, 36 | 37 | plugins: [ 38 | Vue(), 39 | 40 | Components({ 41 | // allow auto load markdown components under `./src/components/` 42 | extensions: ['vue', 'md'], 43 | // allow auto import and register components used in markdown 44 | include: [/\.vue$/, /\.vue\?vue/, /\.md$/], 45 | resolvers: [ 46 | ElementPlusResolver({ 47 | importStyle: 'sass', 48 | }), 49 | ], 50 | dts: 'src/components.d.ts', 51 | }), 52 | 53 | // https://github.com/antfu/unocss 54 | // see uno.config.ts for config 55 | Unocss(), 56 | ], 57 | 58 | ssr: { 59 | // 
TODO: workaround until they support native ESM 60 | noExternal: ['element-plus'], 61 | }, 62 | }) 63 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_retriever.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "sort" 7 | 8 | "github.com/gogf/gf/v2/frame/g" 9 | gorag "github.com/wangle201210/go-rag/server/core" 10 | "github.com/wangle201210/go-rag/server/internal/logic/rag" 11 | 12 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 13 | ) 14 | 15 | func (c *ControllerV1) Retriever(ctx context.Context, req *v1.RetrieverReq) (res *v1.RetrieverRes, err error) { 16 | ragSvr := rag.GetRagSvr() 17 | if req.TopK == 0 { 18 | req.TopK = 5 19 | } 20 | if req.Score == 0 { 21 | req.Score = 0.2 22 | } 23 | if req.Score < 1.0 { 24 | req.Score += 1 25 | } 26 | ragReq := &gorag.RetrieveReq{ 27 | Query: req.Question, 28 | TopK: req.TopK, 29 | Score: req.Score, 30 | KnowledgeName: req.KnowledgeName, 31 | } 32 | g.Log().Infof(ctx, "ragReq: %v", ragReq) 33 | msg, err := ragSvr.Retrieve(ctx, ragReq) 34 | if err != nil { 35 | return 36 | } 37 | for _, document := range msg { 38 | if document.MetaData != nil { 39 | delete(document.MetaData, "_dense_vector") 40 | if extValue, ok := document.MetaData["ext"]; ok && extValue != nil { 41 | if extStr, ok := extValue.(string); ok && extStr != "" { 42 | m := make(map[string]interface{}) 43 | if err = json.Unmarshal([]byte(extStr), &m); err != nil { 44 | return 45 | } 46 | document.MetaData["ext"] = m 47 | } 48 | } 49 | } 50 | } 51 | // eino 默认是把分高的排在两边,这里我xiu gai 52 | sort.Slice(msg, func(i, j int) bool { 53 | return msg[i].Score() > msg[j].Score() 54 | }) 55 | res = &v1.RetrieverRes{ 56 | Document: msg, 57 | } 58 | return 59 | } 60 | -------------------------------------------------------------------------------- /server/internal/dao/db/sqlite.go: 
-------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | 8 | "gorm.io/driver/sqlite" 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/logger" 11 | ) 12 | 13 | type SQLite struct { 14 | cfg *Config 15 | Base 16 | } 17 | 18 | func NewSqlite(cfg *Config) *SQLite { 19 | s := &SQLite{cfg: cfg} 20 | return s 21 | } 22 | 23 | func (s *SQLite) Connect() error { 24 | dsn := s.DSN() 25 | dialect := sqlite.Open(dsn) 26 | config := &gorm.Config{ 27 | Logger: logger.Default.LogMode(logger.LogLevel(s.cfg.LogLevel)), 28 | } 29 | db, err := gorm.Open(dialect, config) 30 | if err != nil { 31 | return fmt.Errorf("failed to connect database: %v", err) 32 | } 33 | 34 | sqlDB, err := db.DB() 35 | if err != nil { 36 | return fmt.Errorf("failed to get database instance: %v", err) 37 | } 38 | 39 | sqlDB.SetMaxIdleConns(s.cfg.MaxIdleConn) 40 | sqlDB.SetMaxOpenConns(s.cfg.MaxOpenConn) 41 | 42 | sqlDB.SetConnMaxLifetime(time.Hour) 43 | 44 | s.DB = db 45 | return nil 46 | } 47 | 48 | func (s *SQLite) DSN() string { 49 | var dsn string 50 | var params []string 51 | if s.cfg.JournalMode != "" { 52 | params = append(params, fmt.Sprintf("_journal_mode=%s", s.cfg.JournalMode)) 53 | } 54 | if s.cfg.Synchronous != "" { 55 | params = append(params, fmt.Sprintf("_synchronous=%s", s.cfg.Synchronous)) 56 | } 57 | if s.cfg.CacheSize > 0 { 58 | params = append(params, fmt.Sprintf("_cache_size=%d", s.cfg.CacheSize)) 59 | } 60 | if s.cfg.BusyTimeout > 0 { 61 | params = append(params, fmt.Sprintf("_busy_timeout=%d", s.cfg.BusyTimeout)) 62 | } 63 | dsn = s.cfg.FilePath 64 | if len(params) > 0 { 65 | dsn += "?" 
+ strings.Join(params, "&") 66 | } 67 | return dsn 68 | } 69 | -------------------------------------------------------------------------------- /server/core/indexer/orchestration.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudwego/eino/compose" 7 | "github.com/wangle201210/go-rag/server/core/config" 8 | ) 9 | 10 | func BuildIndexer(ctx context.Context, conf *config.Config) (r compose.Runnable[any, []string], err error) { 11 | const ( 12 | Loader1 = "Loader" 13 | Indexer2 = "Indexer" 14 | DocumentTransformer3 = "DocumentTransformer" 15 | DocAddIDAndMerge = "DocAddIDAndMerge" 16 | // QA = "QA" 17 | ) 18 | 19 | g := compose.NewGraph[any, []string]() 20 | loader1KeyOfLoader, err := newLoader(ctx) 21 | if err != nil { 22 | return nil, err 23 | } 24 | _ = g.AddLoaderNode(Loader1, loader1KeyOfLoader) 25 | indexer2KeyOfIndexer, err := newIndexer(ctx, conf) 26 | if err != nil { 27 | return nil, err 28 | } 29 | _ = g.AddIndexerNode(Indexer2, indexer2KeyOfIndexer) 30 | documentTransformer2KeyOfDocumentTransformer, err := newDocumentTransformer(ctx) 31 | if err != nil { 32 | return nil, err 33 | } 34 | _ = g.AddLambdaNode(DocAddIDAndMerge, compose.InvokableLambda(docAddIDAndMerge)) 35 | // _ = g.AddLambdaNode(QA, compose.InvokableLambda(qa)) // qa 异步 执行 36 | 37 | _ = g.AddDocumentTransformerNode(DocumentTransformer3, documentTransformer2KeyOfDocumentTransformer) 38 | _ = g.AddEdge(compose.START, Loader1) 39 | _ = g.AddEdge(Loader1, DocumentTransformer3) 40 | _ = g.AddEdge(DocumentTransformer3, DocAddIDAndMerge) 41 | _ = g.AddEdge(DocAddIDAndMerge, Indexer2) 42 | // _ = g.AddEdge(DocAddIDAndMerge, QA) 43 | // _ = g.AddEdge(QA, Indexer2) 44 | _ = g.AddEdge(Indexer2, compose.END) 45 | r, err = g.Compile(ctx, compose.WithGraphName("indexer")) 46 | if err != nil { 47 | return nil, err 48 | } 49 | return r, err 50 | } 51 | 
-------------------------------------------------------------------------------- /fe/src/router/index.ts: -------------------------------------------------------------------------------- 1 | import type { RouteRecordRaw } from 'vue-router' 2 | import { createRouter, createWebHashHistory } from 'vue-router' 3 | 4 | const routes: RouteRecordRaw[] = [ 5 | { 6 | path: '/', 7 | redirect: '/knowledge-base', 8 | }, 9 | { 10 | path: '/knowledge-base', 11 | name: 'KnowledgeBase', 12 | component: () => import('~/pages/rag/knowledge-base.vue'), 13 | meta: { 14 | title: '知识库管理', 15 | icon: 'FolderOpened', 16 | showInMenu: true, 17 | }, 18 | }, 19 | { 20 | path: '/indexer', 21 | name: 'Indexer', 22 | component: () => import('~/pages/rag/indexer.vue'), 23 | meta: { 24 | title: '文档索引', 25 | icon: 'Upload', 26 | showInMenu: true, 27 | }, 28 | }, 29 | { 30 | path: '/knowledge-documents', 31 | name: 'KnowledgeDocuments', 32 | component: () => import('~/pages/rag/knowledge-documents.vue'), 33 | meta: { 34 | title: '文档管理', 35 | icon: 'Files', 36 | showInMenu: true, 37 | }, 38 | }, 39 | { 40 | path: '/retriever', 41 | name: 'Retriever', 42 | component: () => import('~/pages/rag/retriever.vue'), 43 | meta: { 44 | title: '文档检索', 45 | icon: 'Search', 46 | showInMenu: true, 47 | }, 48 | }, 49 | { 50 | path: '/chat', 51 | name: 'Chat', 52 | component: () => import('~/pages/rag/chat.vue'), 53 | meta: { 54 | title: '智能问答', 55 | icon: 'ChatDotRound', 56 | showInMenu: true, 57 | }, 58 | }, 59 | { 60 | path: '/chunk-details/:documentId', 61 | name: 'ChunkDetails', 62 | component: () => import('~/pages/rag/chunk-details/[documentId].vue'), 63 | }, 64 | ] 65 | 66 | const router = createRouter({ 67 | history: createWebHashHistory(), 68 | routes, 69 | }) 70 | 71 | export default router -------------------------------------------------------------------------------- /fe/src/typed-router.d.ts: -------------------------------------------------------------------------------- 1 | /* eslint-disable */ 2 | 
/* prettier-ignore */ 3 | // @ts-nocheck 4 | // Generated by unplugin-vue-router. ‼️ DO NOT MODIFY THIS FILE ‼️ 5 | // It's recommended to commit this file. 6 | // Make sure to add this file to your tsconfig.json file as an "includes" or "files" entry. 7 | 8 | declare module 'vue-router/auto-routes' { 9 | import type { 10 | RouteRecordInfo, 11 | ParamValue, 12 | ParamValueOneOrMore, 13 | ParamValueZeroOrMore, 14 | ParamValueZeroOrOne, 15 | } from 'vue-router' 16 | 17 | /** 18 | * Route name map generated by unplugin-vue-router 19 | */ 20 | export interface RouteNamedMap { 21 | '/': RouteRecordInfo<'/', '/', Record, Record>, 22 | '/chat': RouteRecordInfo<'/chat', '/chat', Record, Record>, 23 | '/chunk-details/[documentId]': RouteRecordInfo<'/chunk-details/[documentId]', '/chunk-details/:documentId', { documentId: ParamValue }, { documentId: ParamValue }>, 24 | '/indexer': RouteRecordInfo<'/indexer', '/indexer', Record, Record>, 25 | '/knowledge-base': RouteRecordInfo<'/knowledge-base', '/knowledge-base', Record, Record>, 26 | '/knowledge-documents': RouteRecordInfo<'/knowledge-documents', '/knowledge-documents', Record, Record>, 27 | '/nav/1/item-1': RouteRecordInfo<'/nav/1/item-1', '/nav/1/item-1', Record, Record>, 28 | '/nav/2': RouteRecordInfo<'/nav/2', '/nav/2', Record, Record>, 29 | '/nav/4': RouteRecordInfo<'/nav/4', '/nav/4', Record, Record>, 30 | '/retriever': RouteRecordInfo<'/retriever', '/retriever', Record, Record>, 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /server/internal/cmd/cmd.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/ThinkInAIXYZ/go-mcp/server" 7 | "github.com/ThinkInAIXYZ/go-mcp/transport" 8 | "github.com/gogf/gf/v2/frame/g" 9 | "github.com/gogf/gf/v2/net/ghttp" 10 | "github.com/gogf/gf/v2/os/gcmd" 11 | "github.com/wangle201210/go-rag/server/internal/controller/rag" 12 | 
"github.com/wangle201210/go-rag/server/internal/mcp" 13 | ) 14 | 15 | var ( 16 | Main = gcmd.Command{ 17 | Name: "main", 18 | Usage: "main", 19 | Brief: "start http server", 20 | Func: func(ctx context.Context, parser *gcmd.Parser) (err error) { 21 | s := g.Server() 22 | Mcp(ctx, s) 23 | s.Group("/", func(group *ghttp.RouterGroup) { 24 | s.AddStaticPath("", "./static/fe/") 25 | s.SetIndexFiles([]string{"index.html"}) 26 | }) 27 | s.Group("/api", func(group *ghttp.RouterGroup) { 28 | group.Middleware(MiddlewareHandlerResponse, ghttp.MiddlewareCORS) 29 | group.Bind( 30 | rag.NewV1(), 31 | ) 32 | }) 33 | s.Run() 34 | return nil 35 | }, 36 | } 37 | ) 38 | 39 | func Mcp(ctx context.Context, s *ghttp.Server) { 40 | trans, handler, err := transport.NewStreamableHTTPServerTransportAndHandler() 41 | if err != nil { 42 | g.Log().Panicf(ctx, "new sse transport and hander with error: %v", err) 43 | } 44 | // new mcp server 45 | mcpServer, _ := server.NewServer(trans) 46 | mcpServer.RegisterTool(mcp.GetRetrieverTool(), mcp.HandleRetriever) 47 | mcpServer.RegisterTool(mcp.GetKnowledgeBaseTool(), mcp.HandleKnowledgeBase) 48 | // start mcp Server 49 | go func() { 50 | mcpServer.Run() 51 | }() 52 | // mcpServer.Shutdown(context.Background()) 53 | s.Group("/", func(r *ghttp.RouterGroup) { 54 | r.ALL("/mcp", func(r *ghttp.Request) { 55 | handler.HandleMCP().ServeHTTP(r.Response.Writer, r.Request) 56 | }) 57 | }) 58 | } 59 | -------------------------------------------------------------------------------- /server/core/message.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/cloudwego/eino/components/prompt" 9 | "github.com/cloudwego/eino/schema" 10 | ) 11 | 12 | var system = "你非常擅长于使用rag进行数据检索," + 13 | "你的目标是在充分理解用户的问题后进行向量化检索\n" + 14 | "现在时间{time_now}\n" + 15 | "你要优化并提取搜索的查询内容。" + 16 | "请遵循以下规则重写查询内容:\n" + 17 | "- 根据用户的问题和上下文,重写应该进行搜索的关键词\n" + 18 | "- 
如果需要使用时间,则根据当前时间给出需要查询的具体时间日期信息\n" + 19 | // "- 生成的查询关键词要选择合适的语言,考虑用户的问题类型使用最适合的语言进行搜索,例如某些问题应该保持用户的问题语言,而有一些则更适合翻译成英语或其他语言\n" + 20 | "- 保持查询简洁,查询内容通常不超过3个关键词, 最多不要超过5个关键词\n" + 21 | "- 参考Elasticsearch搜索查询习惯重写关键字。" + 22 | "- 直接返回优化后的搜索词,不要有任何额外说明。\n" + 23 | "- 尽量不要使用下面这些已使用过的关键词,因为之前使用这些关键词搜索到的结果不符合预期,已使用过的关键词:{used}\n" + 24 | "- 尽量不使用知识库名字《{knowledgeBase}》中包含的关键词\n" 25 | 26 | // createTemplate 创建并返回一个配置好的聊天模板 27 | func createTemplate() prompt.ChatTemplate { 28 | return prompt.FromMessages(schema.FString, 29 | // 系统消息模板 30 | schema.SystemMessage(system), 31 | // 用户消息模板 32 | schema.UserMessage( 33 | "如下是用户的问题: {question}"), 34 | ) 35 | } 36 | 37 | // formatMessages 格式化消息并处理错误 38 | func formatMessages(template prompt.ChatTemplate, data map[string]any) ([]*schema.Message, error) { 39 | messages, err := template.Format(context.Background(), data) 40 | if err != nil { 41 | return nil, fmt.Errorf("格式化模板失败: %w", err) 42 | } 43 | return messages, nil 44 | } 45 | 46 | func getOptimizedQueryMessages(used, question, knowledgeBase string) ([]*schema.Message, error) { 47 | template := createTemplate() 48 | data := map[string]any{ 49 | "time_now": time.Now().Format(time.RFC3339), 50 | "question": question, 51 | "used": used, 52 | "knowledgeBase": knowledgeBase, 53 | } 54 | messages, err := formatMessages(template, data) 55 | if err != nil { 56 | return nil, err 57 | } 58 | return messages, nil 59 | } 60 | -------------------------------------------------------------------------------- /server/core/vector/factory_test.go: -------------------------------------------------------------------------------- 1 | package vector 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestNewVectorStore_ES(t *testing.T) { 8 | cfg := &Config{ 9 | Type: "es", 10 | IndexName: "test-index", 11 | ES: &ESConfig{ 12 | Address: "http://localhost:9200", 13 | Username: "", 14 | Password: "", 15 | }, 16 | } 17 | 18 | store, err := NewVectorStore(cfg) 19 | if err != nil { 20 | t.Fatalf("Failed to create ES 
vector store: %v", err) 21 | } 22 | 23 | if store == nil { 24 | t.Fatal("Expected non-nil vector store") 25 | } 26 | 27 | if _, ok := store.(*ESVectorStore); !ok { 28 | t.Fatal("Expected ESVectorStore type") 29 | } 30 | } 31 | 32 | func TestNewVectorStore_Qdrant(t *testing.T) { 33 | cfg := &Config{ 34 | Type: "qdrant", 35 | IndexName: "test-index", 36 | Qdrant: &QdrantConfig{ 37 | Address: "http://localhost:6333", 38 | APIKey: "", 39 | }, 40 | } 41 | 42 | store, err := NewVectorStore(cfg) 43 | if err != nil { 44 | t.Fatalf("Failed to create Qdrant vector store: %v", err) 45 | } 46 | 47 | if store == nil { 48 | t.Fatal("Expected non-nil vector store") 49 | } 50 | 51 | if _, ok := store.(*QdrantVectorStore); !ok { 52 | t.Fatal("Expected QdrantVectorStore type") 53 | } 54 | } 55 | 56 | func TestNewVectorStore_InvalidType(t *testing.T) { 57 | cfg := &Config{ 58 | Type: "invalid", 59 | IndexName: "test-index", 60 | } 61 | 62 | _, err := NewVectorStore(cfg) 63 | if err == nil { 64 | t.Fatal("Expected error for invalid type") 65 | } 66 | } 67 | 68 | func TestNewVectorStore_MissingConfig(t *testing.T) { 69 | cfg := &Config{ 70 | Type: "es", 71 | IndexName: "test-index", 72 | // ES config is nil 73 | } 74 | 75 | _, err := NewVectorStore(cfg) 76 | if err == nil { 77 | t.Fatal("Expected error for missing ES config") 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /fe/src/utils/request.js: -------------------------------------------------------------------------------- 1 | import axios from 'axios' 2 | import { ElMessage } from 'element-plus' 3 | 4 | // 创建axios实例 5 | const request = axios.create({ 6 | baseURL: import.meta.env.VITE_API_BASE_URL || '/api', // 使用环境变量或默认值 7 | timeout: 30000, // 请求超时时间 8 | headers: { 9 | 'Content-Type': 'application/json' 10 | } 11 | }) 12 | 13 | // 请求拦截器 14 | request.interceptors.request.use( 15 | config => { 16 | // 在发送请求之前做些什么 17 | const token = localStorage.getItem('token') 18 | if (token) { 
// Response interceptor: unwraps the response payload and reports failures
// to the user through ElMessage.
request.interceptors.response.use(
  (response) => {
    const payload = response.data
    // A truthy `code` marks a business-level failure; anything else is success.
    if (!payload.code || payload.code === 0) {
      return payload
    }
    const reason = payload.message || '请求失败'
    ElMessage.error(reason)
    return Promise.reject(new Error(reason))
  },
  (error) => {
    console.error('响应错误:', error)

    // Map well-known HTTP status codes to a user-facing message.
    const status = error.response?.status
    if (status === 401) {
      // Unauthorized: drop the stored token. A redirect to the login page
      // could be added here.
      localStorage.removeItem('token')
      ElMessage.error('登录已过期,请重新登录')
    } else if (status === 403) {
      ElMessage.error('没有权限访问')
    } else if (status === 404) {
      ElMessage.error('请求的资源不存在')
    } else if (status === 500) {
      ElMessage.error('服务器内部错误')
    } else {
      ElMessage.error(error.message || '网络错误')
    }

    return Promise.reject(error)
  }
)
5000 # 忙等待超时时间(毫秒) 23 | # journal_mode: WAL # 日志模式(推荐 WAL) 24 | # synchronous: NORMAL # 同步模式(NORMAL/FULL) 25 | # cache_size: -2000 # 缓存大小(负数表示 KB) 26 | # max_open_conns: 1 # 最大连接数(SQLite 建议 1) 27 | # max_idle_conns: 1 # 最大空闲连接数(SQLite 建议 1) 28 | 29 | vector: 30 | type: "es" # 向量存储类型:es 或 qdrant 31 | indexName: "rag-test" 32 | es: 33 | address: "http://elasticsearch:9200" 34 | # username: "elastic" 35 | # password: "123456" 36 | qdrant: 37 | address: "http://qdrant:6333" 38 | # apiKey: "" 39 | 40 | embedding: 41 | apiKey: "sk-****" 42 | baseURL: "https://api.siliconflow.cn/v1" 43 | model: "BAAI/bge-m3" 44 | 45 | rerank: 46 | apiKey: "sk-****" 47 | baseURL: "https://api.siliconflow.cn/v1" 48 | model: "BAAI/bge-reranker-v2-m3" 49 | 50 | rewrite: 51 | apiKey: "sk-****" 52 | baseURL: "https://api.siliconflow.cn/v1" 53 | model: "Qwen/Qwen3-14B" # 测试下来14b速度最快 54 | 55 | qa: 56 | apiKey: "sk-****" 57 | baseURL: "https://api.siliconflow.cn/v1" 58 | model: "Qwen/Qwen3-14B" # 测试下来14b速度最快 59 | 60 | chat: 61 | apiKey: "sk-****" 62 | baseURL: "https://api.siliconflow.cn/v1" 63 | model: "deepseek-ai/DeepSeek-V3" 64 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_kb.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/wangle201210/go-rag/server/internal/dao" 7 | "github.com/wangle201210/go-rag/server/internal/model/do" 8 | 9 | "github.com/wangle201210/go-rag/server/api/rag/v1" 10 | ) 11 | 12 | func (c *ControllerV1) KBCreate(ctx context.Context, req *v1.KBCreateReq) (res *v1.KBCreateRes, err error) { 13 | insertId, err := dao.KnowledgeBase.Ctx(ctx).Data(do.KnowledgeBase{ 14 | Name: req.Name, 15 | Status: v1.StatusOK, 16 | Description: req.Description, 17 | Category: req.Category, 18 | }).InsertAndGetId() 19 | if err != nil { 20 | return nil, err 21 | } 22 | res = &v1.KBCreateRes{ 23 | Id: insertId, 24 | } 25 | 
return 26 | } 27 | 28 | func (c *ControllerV1) KBDelete(ctx context.Context, req *v1.KBDeleteReq) (res *v1.KBDeleteRes, err error) { 29 | _, err = dao.KnowledgeBase.Ctx(ctx).WherePri(req.Id).Delete() 30 | return 31 | } 32 | 33 | func (c *ControllerV1) KBGetList(ctx context.Context, req *v1.KBGetListReq) (res *v1.KBGetListRes, err error) { 34 | res = &v1.KBGetListRes{} 35 | err = dao.KnowledgeBase.Ctx(ctx).Where(do.KnowledgeBase{ 36 | Status: req.Status, 37 | Name: req.Name, 38 | Category: req.Category, 39 | }).Scan(&res.List) 40 | return 41 | } 42 | 43 | func (c *ControllerV1) KBGetOne(ctx context.Context, req *v1.KBGetOneReq) (res *v1.KBGetOneRes, err error) { 44 | res = &v1.KBGetOneRes{} 45 | err = dao.KnowledgeBase.Ctx(ctx).WherePri(req.Id).Scan(&res.KnowledgeBase) 46 | return 47 | } 48 | 49 | func (c *ControllerV1) KBUpdate(ctx context.Context, req *v1.KBUpdateReq) (res *v1.KBUpdateRes, err error) { 50 | _, err = dao.KnowledgeBase.Ctx(ctx).Data(do.KnowledgeBase{ 51 | Name: req.Name, 52 | Status: req.Status, 53 | Description: req.Description, 54 | Category: req.Category, 55 | }).WherePri(req.Id).Update() 56 | return 57 | } 58 | -------------------------------------------------------------------------------- /server/core/indexer/transformer.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown" 7 | "github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive" 8 | "github.com/cloudwego/eino/components/document" 9 | "github.com/cloudwego/eino/schema" 10 | coretypes "github.com/wangle201210/go-rag/server/core/types" 11 | ) 12 | 13 | // newDocumentTransformer component initialization function of node 'DocumentTransformer3' in graph 'rag' 14 | func newDocumentTransformer(ctx context.Context) (tfr document.Transformer, err error) { 15 | trans := &transformer{} 16 | // 递归分割 17 | 
config := &recursive.Config{ 18 | ChunkSize: 1000, // 每段内容1000字 19 | OverlapSize: 100, // 有10%的重叠 20 | Separators: []string{"\n", "。", "?", "?", "!", "!"}, 21 | } 22 | recTrans, err := recursive.NewSplitter(ctx, config) 23 | if err != nil { 24 | return nil, err 25 | } 26 | // md 文档特殊处理 27 | mdTrans, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{ 28 | Headers: map[string]string{"#": coretypes.Title1, "##": coretypes.Title2, "###": coretypes.Title3}, 29 | TrimHeaders: false, 30 | }) 31 | if err != nil { 32 | return nil, err 33 | } 34 | trans.recursive = recTrans 35 | trans.markdown = mdTrans 36 | return trans, nil 37 | } 38 | 39 | type transformer struct { 40 | markdown document.Transformer 41 | recursive document.Transformer 42 | } 43 | 44 | func (x *transformer) Transform(ctx context.Context, docs []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) { 45 | isMd := false 46 | for _, doc := range docs { 47 | // 只需要判断第一个是不是.md 48 | if doc.MetaData["_extension"] == ".md" { 49 | isMd = true 50 | break 51 | } 52 | } 53 | if isMd { 54 | return x.markdown.Transform(ctx, docs, opts...) 55 | } 56 | return x.recursive.Transform(ctx, docs, opts...) 57 | } 58 | -------------------------------------------------------------------------------- /server/api/rag/rag.go: -------------------------------------------------------------------------------- 1 | // ================================================================================= 2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. 
3 | // ================================================================================= 4 | 5 | package rag 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/wangle201210/go-rag/server/api/rag/v1" 11 | ) 12 | 13 | type IRagV1 interface { 14 | Chat(ctx context.Context, req *v1.ChatReq) (res *v1.ChatRes, err error) 15 | ChatStream(ctx context.Context, req *v1.ChatStreamReq) (res *v1.ChatStreamRes, err error) 16 | ChunksList(ctx context.Context, req *v1.ChunksListReq) (res *v1.ChunksListRes, err error) 17 | ChunkDelete(ctx context.Context, req *v1.ChunkDeleteReq) (res *v1.ChunkDeleteRes, err error) 18 | UpdateChunk(ctx context.Context, req *v1.UpdateChunkReq) (res *v1.UpdateChunkRes, err error) 19 | UpdateChunkContent(ctx context.Context, req *v1.UpdateChunkContentReq) (res *v1.UpdateChunkContentRes, err error) 20 | DocumentsList(ctx context.Context, req *v1.DocumentsListReq) (res *v1.DocumentsListRes, err error) 21 | DocumentsDelete(ctx context.Context, req *v1.DocumentsDeleteReq) (res *v1.DocumentsDeleteRes, err error) 22 | Indexer(ctx context.Context, req *v1.IndexerReq) (res *v1.IndexerRes, err error) 23 | KBCreate(ctx context.Context, req *v1.KBCreateReq) (res *v1.KBCreateRes, err error) 24 | KBUpdate(ctx context.Context, req *v1.KBUpdateReq) (res *v1.KBUpdateRes, err error) 25 | KBDelete(ctx context.Context, req *v1.KBDeleteReq) (res *v1.KBDeleteRes, err error) 26 | KBGetOne(ctx context.Context, req *v1.KBGetOneReq) (res *v1.KBGetOneRes, err error) 27 | KBGetList(ctx context.Context, req *v1.KBGetListReq) (res *v1.KBGetListRes, err error) 28 | Retriever(ctx context.Context, req *v1.RetrieverReq) (res *v1.RetrieverRes, err error) 29 | RetrieverDify(ctx context.Context, req *v1.RetrieverDifyReq) (res *v1.RetrieverDifyRes, err error) 30 | } 31 | -------------------------------------------------------------------------------- /server/internal/mcp/retriever.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | 
import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 8 | "github.com/gogf/gf/v2/frame/g" 9 | "github.com/gogf/gf/v2/os/gctx" 10 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 11 | ) 12 | 13 | type RetrieverParam struct { 14 | Question string `json:"question" description:"用户提问的问题" required:"true"` 15 | KnowledgeName string `json:"knowledge_name" description:"知识库名称,请先通过getKnowledgeBaseList获取列表后判断是否有符合用户提示词的知识库" required:"true"` 16 | TopK int `json:"top_k" description:"检索结果的数量,默认为5" required:"false"` // 默认为5 17 | Score float64 `json:"score" description:"检索结果的分数阀值,默认为0.2" required:"false"` // 默认为0.2 18 | } 19 | 20 | func GetRetrieverTool() *protocol.Tool { 21 | tool, err := protocol.NewTool("retriever", "检索知识库文档", RetrieverParam{}) 22 | if err != nil { 23 | g.Log().Errorf(gctx.New(), "Failed to create tool: %v", err) 24 | return nil 25 | } 26 | return tool 27 | } 28 | 29 | func HandleRetriever(ctx context.Context, toolReq *protocol.CallToolRequest) (res *protocol.CallToolResult, err error) { 30 | var req RetrieverParam 31 | if err := protocol.VerifyAndUnmarshal(toolReq.RawArguments, &req); err != nil { 32 | return nil, err 33 | } 34 | retriever, err := c.Retriever(ctx, &v1.RetrieverReq{ 35 | Question: req.Question, 36 | TopK: req.TopK, 37 | Score: req.Score, 38 | KnowledgeName: req.KnowledgeName, 39 | }) 40 | if err != nil { 41 | return nil, err 42 | } 43 | docs := retriever.Document 44 | msg := fmt.Sprintf("retrieve %d documents", len(docs)) 45 | for i, doc := range docs { 46 | msg += fmt.Sprintf("\n%d. 
score: %.2f, content: %s", i+1, doc.Score(), doc.Content) 47 | } 48 | return &protocol.CallToolResult{ 49 | Content: []protocol.Content{ 50 | &protocol.TextContent{ 51 | Type: "text", 52 | Text: msg, 53 | }, 54 | }, 55 | }, nil 56 | } 57 | -------------------------------------------------------------------------------- /server/api/rag/v1/knowledge_base.go: -------------------------------------------------------------------------------- 1 | package v1 2 | 3 | import ( 4 | "github.com/gogf/gf/v2/frame/g" 5 | "github.com/wangle201210/go-rag/server/internal/model/entity" 6 | ) 7 | 8 | // Status marks kb status. 9 | type Status int 10 | 11 | const ( 12 | StatusOK Status = 1 13 | StatusDisabled Status = 2 14 | ) 15 | 16 | type KBCreateReq struct { 17 | g.Meta `path:"/v1/kb" method:"post" tags:"kb" summary:"Create kb"` 18 | Name string `v:"required|length:3,50" dc:"kb name"` 19 | Description string `v:"required|length:3,200" dc:"kb description"` 20 | Category string `v:"length:3,50" dc:"kb category"` 21 | } 22 | 23 | type KBCreateRes struct { 24 | Id int64 `json:"id" dc:"kb id"` 25 | } 26 | 27 | type KBUpdateReq struct { 28 | g.Meta `path:"/v1/kb/{id}" method:"put" tags:"kb" summary:"Update kb"` 29 | Id int64 `v:"required" dc:"kb id"` 30 | Name *string `v:"length:3,50" dc:"kb name"` 31 | Description *string `v:"length:3,200" dc:"kb description"` 32 | Category *string `v:"length:3,50" dc:"kb category"` 33 | Status *Status `v:"in:1,2" dc:"kb status"` 34 | } 35 | type KBUpdateRes struct{} 36 | 37 | type KBDeleteReq struct { 38 | g.Meta `path:"/v1/kb/{id}" method:"delete" tags:"kb" summary:"Delete kb"` 39 | Id int64 `v:"required" dc:"kb id"` 40 | } 41 | type KBDeleteRes struct{} 42 | 43 | type KBGetOneReq struct { 44 | g.Meta `path:"/v1/kb/{id}" method:"get" tags:"kb" summary:"Get one kb"` 45 | Id int64 `v:"required" dc:"kb id"` 46 | } 47 | type KBGetOneRes struct { 48 | *entity.KnowledgeBase `dc:"kb"` 49 | } 50 | 51 | type KBGetListReq struct { 52 | g.Meta `path:"/v1/kb" 
method:"get" tags:"kb" summary:"Get kbs"` 53 | Name *string `v:"length:3,50" dc:"kb name"` 54 | Status *Status `v:"in:1,2" dc:"kb age"` 55 | Category *string `v:"length:3,50" dc:"kb category"` 56 | } 57 | 58 | type KBGetListRes struct { 59 | List []*entity.KnowledgeBase `json:"list" dc:"kb list"` 60 | } 61 | -------------------------------------------------------------------------------- /server/manifest/config/config_qd_demo.yaml: -------------------------------------------------------------------------------- 1 | # https://goframe.org/docs/web/server-config-file-template 2 | server: 3 | address: ":8000" 4 | openapiPath: "/api.json" 5 | swaggerPath: "/swagger" 6 | 7 | # https://goframe.org/docs/core/glog-config 8 | logger: 9 | level : "all" 10 | stdout: true 11 | path: "./logs/" # 日志文件路径。默认为空,表示关闭,仅输出到终端 12 | file: "{Y-m-d}.log" # 日志文件格式。默认为"{Y-m-d}.log" 13 | 14 | database: 15 | default: 16 | type: "sqlite" # 数据库类型(如:mariadb/tidb/mysql/sqlite) 17 | link: "sqlite::@file(~/.go-rag/chat.db)" # goframe link, mysql:root:12345678@tcp(127.0.0.1:3306)/test?loc=Local&parseTime=true / sqlite::@file(/var/data/db.sqlite3) 18 | host: "~/.go-rag/chat.db" # 地址, sqlite 文件路径,若启用 sqlite 必填 19 | busy_timeout: 5000 # 忙等待超时时间(毫秒) 20 | journal_mode: WAL # 日志模式(推荐 WAL) 21 | synchronous: NORMAL # 同步模式(NORMAL/FULL) 22 | cache_size: -2000 # 缓存大小(负数表示 KB) 23 | max_open_conns: 1 # 最大连接数(SQLite 建议 1) 24 | max_idle_conns: 1 # 最大空闲连接数(SQLite 建议 1) 25 | 26 | vector: 27 | type: "qdrant" # 向量存储类型:es 或 qdrant 28 | indexName: "rag-test" 29 | # es: 30 | # address: "http://elasticsearch:9200" 31 | # username: "elastic" 32 | # password: "123456" 33 | qdrant: 34 | address: "localhost" 35 | # port: "6334" 36 | # apiKey: "" 37 | 38 | embedding: 39 | apiKey: "sk-****" 40 | baseURL: "https://api.siliconflow.cn/v1" 41 | model: "BAAI/bge-m3" 42 | 43 | rerank: 44 | apiKey: "sk-****" 45 | baseURL: "https://api.siliconflow.cn/v1" 46 | model: "BAAI/bge-reranker-v2-m3" 47 | 48 | rewrite: 49 | apiKey: "sk-****" 50 
| baseURL: "https://api.siliconflow.cn/v1" 51 | model: "Qwen/Qwen3-14B" # 测试下来14b速度最快 52 | 53 | qa: 54 | apiKey: "sk-****" 55 | baseURL: "https://api.siliconflow.cn/v1" 56 | model: "Qwen/Qwen3-14B" # 测试下来14b速度最快 57 | 58 | chat: 59 | apiKey: "sk-****" 60 | baseURL: "https://api.siliconflow.cn/v1" 61 | model: "deepseek-ai/DeepSeek-V3" 62 | 63 | -------------------------------------------------------------------------------- /fe/src/styles/markdown.css: -------------------------------------------------------------------------------- 1 | .markdown-content { 2 | /* General styles */ 3 | line-height: 1.6; 4 | color: #333; 5 | } 6 | 7 | .markdown-content h1, 8 | .markdown-content h2, 9 | .markdown-content h3, 10 | .markdown-content h4, 11 | .markdown-content h5, 12 | .markdown-content h6 { 13 | margin-top: 24px; 14 | margin-bottom: 16px; 15 | font-weight: 600; 16 | line-height: 1.25; 17 | } 18 | 19 | .markdown-content h1 { 20 | font-size: 2em; 21 | border-bottom: 1px solid #eaecef; 22 | padding-bottom: 0.3em; 23 | } 24 | 25 | .markdown-content h2 { 26 | font-size: 1.5em; 27 | border-bottom: 1px solid #eaecef; 28 | padding-bottom: 0.3em; 29 | } 30 | 31 | .markdown-content h3 { 32 | font-size: 1.25em; 33 | } 34 | 35 | .markdown-content p { 36 | margin-bottom: 16px; 37 | } 38 | 39 | .markdown-content ul, 40 | .markdown-content ol { 41 | padding-left: 2em; 42 | margin-bottom: 16px; 43 | } 44 | 45 | .markdown-content li > p { 46 | margin-bottom: 0; 47 | } 48 | 49 | .markdown-content blockquote { 50 | margin: 0 0 16px; 51 | padding: 0 1em; 52 | color: #6a737d; 53 | border-left: 0.25em solid #dfe2e5; 54 | } 55 | 56 | .markdown-content code { 57 | padding: 0.2em 0.4em; 58 | margin: 0; 59 | font-size: 85%; 60 | background-color: rgba(27, 31, 35, 0.05); 61 | border-radius: 3px; 62 | font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, Courier, monospace; 63 | } 64 | 65 | .markdown-content pre { 66 | margin-bottom: 16px; 67 | padding: 16px; 68 | overflow: auto; 69 | 
font-size: 85%; 70 | line-height: 1.45; 71 | background-color: #f6f8fa; 72 | border-radius: 3px; 73 | } 74 | 75 | .markdown-content pre code { 76 | padding: 0; 77 | margin: 0; 78 | font-size: 100%; 79 | background-color: transparent; 80 | border: 0; 81 | } 82 | 83 | .markdown-content table { 84 | width: 100%; 85 | border-collapse: collapse; 86 | margin-bottom: 16px; 87 | } 88 | 89 | .markdown-content th, 90 | .markdown-content td { 91 | border: 1px solid #dfe2e5; 92 | padding: 6px 13px; 93 | } 94 | 95 | .markdown-content th { 96 | font-weight: 600; 97 | } -------------------------------------------------------------------------------- /server/core/indexer/qa.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/cloudwego/eino/schema" 9 | "github.com/gogf/gf/v2/frame/g" 10 | "github.com/wangle201210/go-rag/server/core/common" 11 | coretypes "github.com/wangle201210/go-rag/server/core/types" 12 | ) 13 | 14 | func qa(ctx context.Context, docs []*schema.Document) (output []*schema.Document, err error) { 15 | var knowledgeName string 16 | if value, ok := ctx.Value(coretypes.KnowledgeName).(string); ok { 17 | knowledgeName = value 18 | } else { 19 | err = fmt.Errorf("必须提供知识库名称") 20 | return 21 | } 22 | wg := &sync.WaitGroup{} 23 | for _, doc := range docs { 24 | wg.Add(1) 25 | go func(doc *schema.Document) { 26 | defer wg.Done() 27 | qaContent, e := getQAContent(ctx, doc, knowledgeName) 28 | if e != nil { 29 | g.Log().Errorf(ctx, "getQAContent failed, err=%v", e) 30 | return 31 | } 32 | // 生成QA和内容放在一个chunk的不同字段 33 | doc.MetaData[coretypes.FieldQAContent] = qaContent 34 | }(doc) 35 | } 36 | wg.Wait() 37 | return docs, nil 38 | } 39 | 40 | func getQAContent(ctx context.Context, doc *schema.Document, knowledgeName string) (qaContent string, err error) { 41 | // 已经有数据了就不要再生成了 42 | if s, ok := doc.MetaData[coretypes.FieldQAContent].(string); ok && len(s) 
> 0 { 43 | return s, nil 44 | } 45 | cm, err := common.GetQAModel(ctx, nil) 46 | if err != nil { 47 | return 48 | } 49 | generate, err := cm.Generate(ctx, []*schema.Message{ 50 | { 51 | Role: schema.System, 52 | Content: fmt.Sprintf("你是一个专业的问题生成助手,任务是从给定的文本中提取或生成可能的问题。你不需要回答这些问题,只需生成问题本身。\n"+ 53 | "知识库名字是:《%s》\n\n"+ 54 | "输出格式:\n"+ 55 | "- 每个问题占一行\n"+ 56 | "- 问题必须以问号结尾\n"+ 57 | "- 避免重复或语义相似的问题\n\n"+ 58 | "生成规则:\n"+ 59 | "- 生成的问题必须严格基于文本内容,不能脱离文本虚构。\n"+ 60 | "- 优先生成事实性问题(如谁、何时、何地、如何)。\n"+ 61 | "- 对于复杂文本,可生成多层次问题(基础事实 + 推理问题)。\n"+ 62 | "- 禁止生成主观或开放式问题(如“你认为...?”)。"+ 63 | "- 数量控制在3-5个", knowledgeName), 64 | }, 65 | { 66 | Role: schema.User, 67 | Content: doc.Content, 68 | }, 69 | }) 70 | if err != nil { 71 | return 72 | } 73 | qaContent = generate.Content 74 | return 75 | } 76 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | go-rag: 5 | image: iwangle/go-rag:v0.0.3 6 | # build: 7 | # context: . 
8 | # dockerfile: Dockerfile 9 | ports: 10 | - "8000:8000" 11 | restart: unless-stopped 12 | environment: 13 | - TZ=Asia/Shanghai 14 | - ES_HOST=elasticsearch 15 | - ES_PORT=9200 16 | - MYSQL_HOST=mysql 17 | - MYSQL_PORT=3306 18 | - MYSQL_USER=root 19 | - MYSQL_PASSWORD=123456 20 | - MYSQL_DATABASE=go-rag 21 | volumes: 22 | # 如果需要持久化配置或数据,可以添加相应的卷挂载,目前是在构建镜像时就copy过去的 23 | - ./server/manifest/config/config.yaml:/app/manifest/config/config.yaml 24 | depends_on: 25 | mysql: 26 | condition: service_healthy 27 | elasticsearch: 28 | condition: service_healthy 29 | 30 | elasticsearch: 31 | image: elasticsearch:8.11.3 32 | environment: 33 | - discovery.type=single-node 34 | - xpack.security.enabled=false 35 | - "ES_JAVA_OPTS=-Xms512m -Xmx512m" 36 | - "cluster.routing.allocation.disk.watermark.low=1gb" # 低于 1GB 停止分配 37 | - "cluster.routing.allocation.disk.watermark.high=1gb" # 低于 1GB 迁移分片 38 | - "cluster.routing.allocation.disk.watermark.flood_stage=1gb" # 低于 1GB 设为只读 39 | ports: 40 | - "9200:9200" 41 | volumes: 42 | - ./data/es_data:/usr/share/elasticsearch/data 43 | restart: unless-stopped 44 | healthcheck: 45 | test: [ "CMD-SHELL", "curl -f http://localhost:9200/_cluster/health || exit 1" ] 46 | interval: 10s 47 | timeout: 5s 48 | retries: 10 49 | 50 | mysql: 51 | image: mysql:8.0 52 | environment: 53 | - MYSQL_ROOT_PASSWORD=123456 54 | - MYSQL_DATABASE=go-rag 55 | - MYSQL_ROOT_HOST=% # 允许root从任意主机连接 56 | - MYSQL_CHARSET=utf8mb4 # 设置数据库字符集为utf8mb4 57 | ports: 58 | - "3306:3306" 59 | volumes: 60 | - ./data/mysql_data:/var/lib/mysql 61 | restart: unless-stopped 62 | healthcheck: 63 | test: [ "CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-p123456" ] 64 | interval: 10s 65 | timeout: 5s 66 | retries: 10 67 | -------------------------------------------------------------------------------- /server/internal/logic/chat/message.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | "context" 5 | 
"fmt" 6 | 7 | "github.com/cloudwego/eino/components/prompt" 8 | "github.com/cloudwego/eino/schema" 9 | "github.com/gogf/gf/v2/frame/g" 10 | ) 11 | 12 | const ( 13 | role = "你是一个专业的AI助手,能够根据提供的参考信息准确回答用户问题。" 14 | ) 15 | 16 | // createTemplate 创建并返回一个配置好的聊天模板 17 | func createTemplate() prompt.ChatTemplate { 18 | // 创建模板,使用 FString 格式 19 | return prompt.FromMessages(schema.FString, 20 | // 系统消息模板 21 | schema.SystemMessage("{role}"+ 22 | "请严格遵守以下规则:\n"+ 23 | "1. 回答必须基于提供的参考内容,不要依赖外部知识\n"+ 24 | "2. 如果参考内容中有明确答案,直接使用参考内容回答\n"+ 25 | "3. 如果参考内容不完整或模糊,可以合理推断但需说明\n"+ 26 | "4. 如果参考内容完全不相关或不存在,如实告知用户'根据现有资料无法回答'\n"+ 27 | "5. 保持回答专业、简洁、准确\n"+ 28 | "6. 必要时可引用参考内容中的具体数据或原文\n\n"+ 29 | "当前提供的参考内容:\n"+ 30 | "{docs}\n\n"+ 31 | ""), 32 | schema.MessagesPlaceholder("chat_history", true), 33 | // 用户消息模板 34 | schema.UserMessage("Question: {question}"), 35 | ) 36 | } 37 | 38 | // formatMessages 格式化消息并处理错误 39 | func formatMessages(template prompt.ChatTemplate, data map[string]any) ([]*schema.Message, error) { 40 | messages, err := template.Format(context.Background(), data) 41 | if err != nil { 42 | return nil, fmt.Errorf("格式化模板失败: %w", err) 43 | } 44 | return messages, nil 45 | } 46 | 47 | // docsMessages 将检索到的上下文和问题转换为消息列表 48 | func (x *Chat) docsMessages(ctx context.Context, convID string, docs []*schema.Document, question string) (messages []*schema.Message, err error) { 49 | chatHistory, err := x.eh.GetHistory(convID, 100) 50 | if err != nil { 51 | return 52 | } 53 | // 插入一条用户数据 54 | err = x.eh.SaveMessage(&schema.Message{ 55 | Role: schema.User, 56 | Content: question, 57 | }, convID) 58 | if err != nil { 59 | return 60 | } 61 | template := createTemplate() 62 | for i, doc := range docs { 63 | g.Log().Debugf(context.Background(), "docs[%d]: %s", i, doc.Content) 64 | } 65 | data := map[string]any{ 66 | "role": role, 67 | "question": question, 68 | "docs": docs, 69 | "chat_history": chatHistory, 70 | } 71 | messages, err = formatMessages(template, data) 72 | if err != nil { 73 | return 
74 | } 75 | return 76 | } 77 | -------------------------------------------------------------------------------- /server/internal/controller/rag/rag_v1_update_chunk_content.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/bytedance/sonic" 8 | "github.com/cloudwego/eino/schema" 9 | "github.com/gogf/gf/v2/frame/g" 10 | "github.com/gogf/gf/v2/os/gctx" 11 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 12 | gorag "github.com/wangle201210/go-rag/server/core" 13 | "github.com/wangle201210/go-rag/server/internal/logic/knowledge" 14 | "github.com/wangle201210/go-rag/server/internal/logic/rag" 15 | "github.com/wangle201210/go-rag/server/internal/model/entity" 16 | ) 17 | 18 | func (c *ControllerV1) UpdateChunkContent(ctx context.Context, req *v1.UpdateChunkContentReq) (res *v1.UpdateChunkContentRes, err error) { 19 | chunk, err := knowledge.GetChunkById(ctx, req.Id) 20 | if err != nil { 21 | g.Log().Errorf(ctx, "GetChunkById failed, err=%v", err) 22 | return 23 | } 24 | 25 | document, err := knowledge.GetDocumentById(ctx, chunk.KnowledgeDocId) 26 | if err != nil { 27 | g.Log().Errorf(ctx, "GetDocumentById failed, err=%v", err) 28 | return 29 | } 30 | 31 | knowledgeName := document.KnowledgeBaseName 32 | 33 | err = knowledge.UpdateChunkByIds(ctx, []int64{req.Id}, entity.KnowledgeChunks{ 34 | Content: req.Content, 35 | }) 36 | if err != nil { 37 | g.Log().Errorf(ctx, "UpdateChunkByIds failed, err=%v", err) 38 | return 39 | } 40 | 41 | go func() { 42 | // 等待一段时间确保数据库更新完成 43 | time.Sleep(time.Millisecond * 500) 44 | 45 | ctxN := gctx.New() 46 | defer func() { 47 | if e := recover(); e != nil { 48 | g.Log().Errorf(ctxN, "recover updateChunkContent failed, err=%v", e) 49 | } 50 | }() 51 | 52 | doc := &schema.Document{ 53 | ID: chunk.ChunkId, 54 | Content: req.Content, 55 | } 56 | 57 | if chunk.Ext != "" { 58 | extData := map[string]any{} 59 | if err := 
sonic.Unmarshal([]byte(chunk.Ext), &extData); err == nil { 60 | doc.MetaData = extData 61 | } 62 | } 63 | 64 | // 调用异步索引更新 65 | ragSvr := rag.GetRagSvr() 66 | asyncReq := &gorag.IndexAsyncReq{ 67 | Docs: []*schema.Document{doc}, 68 | KnowledgeName: knowledgeName, 69 | DocumentsId: chunk.KnowledgeDocId, 70 | } 71 | 72 | _, err = ragSvr.IndexAsync(ctxN, asyncReq) 73 | if err != nil { 74 | g.Log().Errorf(ctxN, "IndexAsync failed, err=%v", err) 75 | } else { 76 | g.Log().Infof(ctxN, "Chunk content updated and reindexed successfully, chunk_id=%d", req.Id) 77 | } 78 | }() 79 | 80 | return 81 | } 82 | -------------------------------------------------------------------------------- /server/core/vector/README.md: -------------------------------------------------------------------------------- 1 | # 向量存储抽象层 2 | 3 | 本模块提供了向量存储的抽象接口,支持多种向量数据库实现。 4 | 5 | ## 支持的向量存储 6 | 7 | - **Elasticsearch (ES)** 8 | - **Qdrant** 9 | 10 | ## 配置说明 11 | 12 | ### 使用 Elasticsearch 13 | 14 | ```yaml 15 | vector: 16 | type: "es" # 或 "elasticsearch" 17 | indexName: "rag-test" 18 | es: 19 | address: "http://elasticsearch:9200" 20 | username: "elastic" # 可选 21 | password: "123456" # 可选 22 | ``` 23 | 24 | ### 使用 Qdrant 25 | 26 | ```yaml 27 | vector: 28 | type: "qdrant" 29 | indexName: "rag-test" 30 | qdrant: 31 | address: "http://qdrant:6333" 32 | apiKey: "" # 可选,如果需要认证 33 | ``` 34 | 35 | ## 接口说明 36 | 37 | ### VectorStore 接口 38 | 39 | ```go 40 | type VectorStore interface { 41 | // 创建索引/集合 42 | CreateIndex(ctx context.Context, indexName string) error 43 | 44 | // 检查索引/集合是否存在 45 | IndexExists(ctx context.Context, indexName string) (bool, error) 46 | 47 | // 删除文档 48 | DeleteDocument(ctx context.Context, indexName, documentID string) error 49 | 50 | // 获取知识库列表 51 | GetKnowledgeBaseList(ctx context.Context, indexName string) ([]string, error) 52 | 53 | // 搜索文档 54 | SearchDocuments(ctx context.Context, req *SearchRequest) (*SearchResponse, error) 55 | 56 | // 关闭连接 57 | Close() error 58 | } 59 | ``` 60 | 61 | 
## 实现新的向量存储 62 | 63 | 要添加新的向量存储实现: 64 | 65 | 1. 在 `vector` 包中创建新文件(如 `milvus.go`) 66 | 2. 实现 `VectorStore` 接口 67 | 3. 在 `factory.go` 中添加新类型的支持 68 | 4. 更新配置结构 69 | 70 | ## 迁移说明 71 | 72 | ### 从旧配置迁移 73 | 74 | 旧配置格式: 75 | ```yaml 76 | es: 77 | address: "http://elasticsearch:9200" 78 | indexName: "rag-test" 79 | username: "elastic" 80 | password: "123456" 81 | ``` 82 | 83 | 新配置格式: 84 | ```yaml 85 | vector: 86 | type: "es" 87 | indexName: "rag-test" 88 | es: 89 | address: "http://elasticsearch:9200" 90 | username: "elastic" 91 | password: "123456" 92 | ``` 93 | 94 | ### 代码迁移 95 | 96 | 旧代码: 97 | ```go 98 | err := common.CreateIndexIfNotExists(ctx, client, indexName) 99 | err := common.DeleteDocument(ctx, client, documentID) 100 | ``` 101 | 102 | 新代码: 103 | ```go 104 | exists, err := vectorStore.IndexExists(ctx, indexName) 105 | if !exists { 106 | err = vectorStore.CreateIndex(ctx, indexName) 107 | } 108 | err := vectorStore.DeleteDocument(ctx, indexName, documentID) 109 | ``` 110 | 111 | ## TODO 112 | 113 | - [ ] 添加单元测试 114 | - [ ] 添加 Milvus 支持 115 | - [ ] 添加 Pinecone 支持 116 | -------------------------------------------------------------------------------- /server/core/grader/message.go: -------------------------------------------------------------------------------- 1 | package grader 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/cloudwego/eino/components/prompt" 8 | "github.com/cloudwego/eino/schema" 9 | ) 10 | 11 | // createRetrieverTemplate 判断检索到的文档是否足够回答用户问题 12 | func createRetrieverTemplate() prompt.ChatTemplate { 13 | // 创建模板,使用 FString 格式 14 | return prompt.FromMessages(schema.FString, 15 | // 系统消息模板 16 | schema.SystemMessage( 17 | "您是一名评估检索到的文档是否足够回答用户问题的专家。"+ 18 | "请先仔细理解用户问题"+ 19 | "如果检索到的文档足够回答用户问题,请给出 'yes',"+ 20 | "如果检索到的文档不足以回答用户问题,请给出 'no'。"+ 21 | "不要给出任何其他解释。", 22 | ), 23 | // 用户消息模板 24 | schema.UserMessage( 25 | "这是检索到的文档: \n"+ 26 | "{document} \n\n"+ 27 | "这是用户的问题: {question}"), 28 | ) 29 | } 30 | 31 | // createDocRelatedTemplate 
判断检索到的文档是否和用户问题相关 32 | func createDocRelatedTemplate() prompt.ChatTemplate { 33 | // 创建模板,使用 FString 格式 34 | return prompt.FromMessages(schema.FString, 35 | // 系统消息模板 36 | schema.SystemMessage( 37 | "您是一名评估检索到的文档是否和用户问题相关的专家。"+ 38 | "这里不需要是一个严格的测试,目标是过滤掉错误的检索。"+ 39 | "如果检索到的文档和用户问题相关,请给出 'yes',"+ 40 | "如果检索到的文档和用户问题不相关,请给出 'no'。"+ 41 | "不要给出任何其他解释。", 42 | ), 43 | // 用户消息模板 44 | schema.UserMessage( 45 | "<|start_documents|> \n"+ 46 | "{document} <|end_documents|>\n"+ 47 | "<|start_query|>{question}<|end_query|>"), 48 | ) 49 | } 50 | 51 | // formatMessages 格式化消息并处理错误 52 | func formatMessages(template prompt.ChatTemplate, data map[string]any) ([]*schema.Message, error) { 53 | messages, err := template.Format(context.Background(), data) 54 | if err != nil { 55 | return nil, fmt.Errorf("格式化模板失败: %w", err) 56 | } 57 | return messages, nil 58 | } 59 | 60 | func retrieverMessages(docs []*schema.Document, question string) ([]*schema.Message, error) { 61 | document := "" 62 | for i, doc := range docs { 63 | document += fmt.Sprintf("docs[%d]: %s", i, doc.Content) 64 | } 65 | template := createRetrieverTemplate() 66 | data := map[string]any{ 67 | "question": question, 68 | "document": document, 69 | } 70 | messages, err := formatMessages(template, data) 71 | if err != nil { 72 | return nil, err 73 | } 74 | return messages, nil 75 | } 76 | 77 | func docRelatedMessages(doc *schema.Document, question string) ([]*schema.Message, error) { 78 | template := createDocRelatedTemplate() 79 | data := map[string]any{ 80 | "question": question, 81 | "document": doc, 82 | } 83 | messages, err := formatMessages(template, data) 84 | if err != nil { 85 | return nil, err 86 | } 87 | return messages, nil 88 | } 89 | -------------------------------------------------------------------------------- /fe/src/components/KnowledgeNameSetting.vue_back: -------------------------------------------------------------------------------- 1 | 41 | 42 | 74 | 75 | 
-------------------------------------------------------------------------------- /server/internal/logic/rag/retriever.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "github.com/elastic/go-elasticsearch/v8" 5 | "github.com/gogf/gf/v2/frame/g" 6 | "github.com/gogf/gf/v2/os/gctx" 7 | "github.com/qdrant/go-client/qdrant" 8 | "github.com/wangle201210/go-rag/server/core" 9 | "github.com/wangle201210/go-rag/server/core/config" 10 | "github.com/wangle201210/go-rag/server/core/vector" 11 | ) 12 | 13 | var ragSvr = &core.Rag{} 14 | 15 | func init() { 16 | ctx := gctx.New() 17 | 18 | // 读取向量存储配置 19 | vectorType := g.Cfg().MustGet(ctx, "vector.type").String() 20 | indexName := g.Cfg().MustGet(ctx, "vector.indexName").String() 21 | 22 | // 创建向量存储配置 23 | vectorCfg := &vector.Config{ 24 | Type: vectorType, 25 | IndexName: indexName, 26 | } 27 | 28 | // 根据类型配置 29 | if vectorType == "es" || vectorType == "elasticsearch" { 30 | vectorCfg.ES = &vector.ESConfig{ 31 | Address: g.Cfg().MustGet(ctx, "vector.es.address").String(), 32 | Username: g.Cfg().MustGet(ctx, "vector.es.username").String(), 33 | Password: g.Cfg().MustGet(ctx, "vector.es.password").String(), 34 | } 35 | } else if vectorType == "qdrant" { 36 | vectorCfg.Qdrant = &vector.QdrantConfig{ 37 | Address: g.Cfg().MustGet(ctx, "vector.qdrant.address").String(), 38 | Port: g.Cfg().MustGet(ctx, "vector.qdrant.port").Int(), 39 | APIKey: g.Cfg().MustGet(ctx, "vector.qdrant.apiKey").String(), 40 | } 41 | } 42 | 43 | // 创建向量存储实例 44 | vectorStore, err := vector.NewVectorStore(vectorCfg) 45 | if err != nil { 46 | g.Log().Fatalf(ctx, "NewVectorStore failed, err=%v", err) 47 | return 48 | } 49 | 50 | // 根据类型获取对应的客户端 51 | var client *elasticsearch.Client 52 | var qdrantClient *qdrant.Client 53 | 54 | if esStore, ok := vectorStore.(*vector.ESVectorStore); ok { 55 | client = esStore.GetClient() 56 | } else if qdrantStore, ok := 
vectorStore.(*vector.QdrantVectorStore); ok { 57 | qdrantClient = qdrantStore.GetClient() 58 | } 59 | 60 | ragSvr, err = core.New(ctx, &config.Config{ 61 | Client: client, 62 | QdrantClient: qdrantClient, 63 | IndexName: indexName, 64 | APIKey: g.Cfg().MustGet(ctx, "embedding.apiKey").String(), 65 | BaseURL: g.Cfg().MustGet(ctx, "embedding.baseURL").String(), 66 | EmbeddingModel: g.Cfg().MustGet(ctx, "embedding.model").String(), 67 | ChatModel: g.Cfg().MustGet(ctx, "chat.model").String(), 68 | }) 69 | if err != nil { 70 | g.Log().Fatalf(ctx, "New of rag failed, err=%v", err) 71 | return 72 | } 73 | } 74 | 75 | func GetRagSvr() *core.Rag { 76 | return ragSvr 77 | } 78 | -------------------------------------------------------------------------------- /server/core/common/stream.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "time" 8 | 9 | "github.com/bytedance/sonic" 10 | "github.com/cloudwego/eino/schema" 11 | "github.com/gogf/gf/v2/frame/g" 12 | "github.com/gogf/gf/v2/net/ghttp" 13 | "github.com/google/uuid" 14 | ) 15 | 16 | type StreamData struct { 17 | Id string `json:"id"` // 同一个消息里面的id是相同的 18 | Created int64 `json:"created"` // 消息初始生成时间 19 | Content string `json:"content"` // 消息具体内容 20 | Document []*schema.Document `json:"document"` 21 | } 22 | 23 | func SteamResponse(ctx context.Context, streamReader *schema.StreamReader[*schema.Message], docs []*schema.Document) (err error) { 24 | // 获取HTTP响应对象 25 | httpReq := ghttp.RequestFromCtx(ctx) 26 | httpResp := httpReq.Response 27 | // 设置响应头 28 | httpResp.Header().Set("Content-Type", "text/event-stream") 29 | httpResp.Header().Set("Cache-Control", "no-cache") 30 | httpResp.Header().Set("Connection", "keep-alive") 31 | httpResp.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲 32 | httpResp.Header().Set("Access-Control-Allow-Origin", "*") 33 | sd := &StreamData{ 34 | Id: uuid.NewString(), 35 | Created: 
time.Now().Unix(), 36 | } 37 | if len(docs) > 0 { 38 | sd.Document = docs 39 | marshal, _ := sonic.Marshal(sd) 40 | writeSSEDocuments(httpResp, string(marshal)) 41 | } 42 | sd.Document = nil // 置空,发一次就够了 43 | // 处理流式响应 44 | for { 45 | chunk, err := streamReader.Recv() 46 | if err == io.EOF { 47 | break 48 | } 49 | if err != nil { 50 | writeSSEError(httpResp, err) 51 | break 52 | } 53 | if len(chunk.Content) == 0 { 54 | continue 55 | } 56 | 57 | sd.Content = chunk.Content 58 | marshal, _ := sonic.Marshal(sd) 59 | // 发送数据事件 60 | writeSSEData(httpResp, string(marshal)) 61 | } 62 | // 发送结束事件 63 | writeSSEDone(httpResp) 64 | return nil 65 | } 66 | 67 | // writeSSEData 写入SSE事件 68 | func writeSSEData(resp *ghttp.Response, data string) { 69 | if len(data) == 0 { 70 | return 71 | } 72 | // g.Log().Infof(context.Background(), "data: %s", data) 73 | resp.Writeln(fmt.Sprintf("data:%s\n", data)) 74 | resp.Flush() 75 | } 76 | 77 | func writeSSEDone(resp *ghttp.Response) { 78 | resp.Writeln(fmt.Sprintf("data:%s\n", "[DONE]")) 79 | resp.Flush() 80 | } 81 | 82 | func writeSSEDocuments(resp *ghttp.Response, data string) { 83 | resp.Writeln(fmt.Sprintf("documents:%s\n", data)) 84 | resp.Flush() 85 | } 86 | 87 | // writeSSEError 写入SSE错误 88 | func writeSSEError(resp *ghttp.Response, err error) { 89 | g.Log().Error(context.Background(), err) 90 | resp.Writeln(fmt.Sprintf("event: error\ndata: %s\n\n", err.Error())) 91 | resp.Flush() 92 | } 93 | -------------------------------------------------------------------------------- /server/internal/cmd/middleware.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "mime" 5 | "net/http" 6 | "reflect" 7 | 8 | "github.com/gogf/gf/v2/errors/gcode" 9 | "github.com/gogf/gf/v2/errors/gerror" 10 | "github.com/gogf/gf/v2/net/ghttp" 11 | "github.com/gogf/gf/v2/util/gmeta" 12 | ) 13 | 14 | const ( 15 | contentTypeEventStream = "text/event-stream" 16 | contentTypeOctetStream = 
// MiddlewareHandlerResponse is the default middleware handling handler response object and its error.
//
// After the downstream handlers run, it writes the handler result as JSON.
// Nothing is written when the handler already produced output or when the
// response Content-Type is one of streamContentType. The result is wrapped in
// ghttp.DefaultHandlerResponse (code, message, data) unless noWrapResp reports
// that the request type opted out, in which case the raw result is written.
func MiddlewareHandlerResponse(r *ghttp.Request) {
	// Run the remaining middlewares and the route handler first.
	r.Middleware.Next()

	// There's custom buffer content, it then exits current handler.
	if r.Response.BufferLength() > 0 || r.Response.Writer.BytesWritten() > 0 {
		return
	}

	// It does not output common response content if it is stream response.
	mediaType, _, _ := mime.ParseMediaType(r.Response.Header().Get("Content-Type"))
	for _, ct := range streamContentType {
		if mediaType == ct {
			return
		}
	}

	var (
		msg  string
		err  = r.GetError()
		res  = r.GetHandlerResponse()
		code = gerror.Code(err)
	)
	if err != nil {
		// An error without an explicit code is reported as an internal error.
		if code == gcode.CodeNil {
			code = gcode.CodeInternalError
		}
		msg = err.Error()
	} else {
		// No handler error: derive the code from a non-200 HTTP status, if any.
		if r.Response.Status > 0 && r.Response.Status != http.StatusOK {
			switch r.Response.Status {
			case http.StatusNotFound:
				code = gcode.CodeNotFound
			case http.StatusForbidden:
				code = gcode.CodeNotAuthorized
			default:
				code = gcode.CodeUnknown
			}
			// It creates an error as it can be retrieved by other middlewares.
			// msg is still empty at this point.
			err = gerror.NewCode(code, msg)
			r.SetError(err)
		} else {
			code = gcode.CodeOK
		}
		msg = code.Message()
	}
	// Requests tagged with no_wrap_resp get the bare handler result.
	if noWrapResp(r) {
		r.Response.WriteJson(res)
		return
	}
	r.Response.WriteJson(ghttp.DefaultHandlerResponse{
		Code:    code.Code(),
		Message: msg,
		Data:    res,
	})
}
grader 会严重影响rag的速度 35 | conf *config.Config 36 | } 37 | 38 | func New(ctx context.Context, conf *config.Config) (*Rag, error) { 39 | if len(conf.IndexName) == 0 { 40 | return nil, fmt.Errorf("indexName is empty") 41 | } 42 | // 确保 index 存在 43 | exists, err := conf.IndexExists(ctx) 44 | if err != nil { 45 | return nil, err 46 | } 47 | if !exists { 48 | err = conf.CreateIndex(ctx) 49 | if err != nil { 50 | return nil, err 51 | } 52 | } 53 | buildIndex, err := indexer.BuildIndexer(ctx, conf) 54 | if err != nil { 55 | return nil, err 56 | } 57 | buildIndexAsync, err := indexer.BuildIndexerAsync(ctx, conf) 58 | if err != nil { 59 | return nil, err 60 | } 61 | buildRetriever, err := retriever.BuildRetriever(ctx, conf) 62 | if err != nil { 63 | return nil, err 64 | } 65 | qaCtx := context.WithValue(ctx, coretypes.RetrieverFieldKey, coretypes.FieldQAContentVector) 66 | qaRetriever, err := retriever.BuildRetriever(qaCtx, conf) 67 | if err != nil { 68 | return nil, err 69 | } 70 | cm, err := common.GetChatModel(ctx, nil) 71 | if err != nil { 72 | g.Log().Error(ctx, "GetChatModel failed, err=%v", err) 73 | return nil, err 74 | } 75 | return &Rag{ 76 | idxer: buildIndex, 77 | idxerAsync: buildIndexAsync, 78 | rtrvr: buildRetriever, 79 | qaRtrvr: qaRetriever, 80 | client: conf.Client, 81 | cm: cm, 82 | conf: conf, 83 | // grader: grader.NewGrader(cm), 84 | }, nil 85 | } 86 | 87 | // GetKnowledgeBaseList 获取知识库列表 88 | func (x *Rag) GetKnowledgeBaseList(ctx context.Context) (list []string, err error) { 89 | return x.conf.GetKnowledgeBaseList(ctx) 90 | } 91 | -------------------------------------------------------------------------------- /server/core/indexer/indexer_async.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/bytedance/sonic" 8 | "github.com/cloudwego/eino-ext/components/indexer/es8" 9 | "github.com/cloudwego/eino/components/indexer" 10 | 
"github.com/cloudwego/eino/schema" 11 | "github.com/wangle201210/go-rag/server/core/common" 12 | "github.com/wangle201210/go-rag/server/core/config" 13 | coretypes "github.com/wangle201210/go-rag/server/core/types" 14 | ) 15 | 16 | // newAsyncIndexer component initialization function of node 'Indexer2' in graph 'rag' 17 | func newAsyncIndexer(ctx context.Context, conf *config.Config) (idr indexer.Indexer, err error) { 18 | embeddingIns11, err := common.NewEmbedding(ctx, conf) 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | // 根据向量存储类型创建不同的 indexer 24 | if conf.Client != nil { 25 | // ES indexer 26 | indexerConfig := &es8.IndexerConfig{ 27 | Client: conf.Client, 28 | Index: conf.IndexName, 29 | BatchSize: 10, 30 | DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) { 31 | var knowledgeName string 32 | if value, ok := ctx.Value(coretypes.KnowledgeName).(string); ok { 33 | knowledgeName = value 34 | } else { 35 | err = fmt.Errorf("必须提供知识库名称") 36 | return 37 | } 38 | if doc.MetaData != nil { 39 | // 存储ext数据 40 | marshal, _ := sonic.Marshal(getExtData(doc)) 41 | doc.MetaData[coretypes.FieldExtra] = string(marshal) 42 | } 43 | return map[string]es8.FieldValue{ 44 | coretypes.FieldContent: { 45 | Value: doc.Content, 46 | EmbedKey: coretypes.FieldContentVector, 47 | }, 48 | coretypes.FieldExtra: { 49 | Value: doc.MetaData[coretypes.FieldExtra], 50 | }, 51 | coretypes.KnowledgeName: { 52 | Value: knowledgeName, 53 | }, 54 | coretypes.FieldQAContent: { 55 | Value: doc.MetaData[coretypes.FieldQAContent], 56 | EmbedKey: coretypes.FieldQAContentVector, 57 | }, 58 | }, nil 59 | }, 60 | } 61 | indexerConfig.Embedding = embeddingIns11 62 | idr, err = es8.NewIndexer(ctx, indexerConfig) 63 | if err != nil { 64 | return nil, err 65 | } 66 | return idr, nil 67 | } else if conf.QdrantClient != nil { 68 | // Qdrant indexer 69 | idr, err = NewQdrantIndexer(ctx, &QdrantIndexerConfig{ 70 | Client: 
conf.QdrantClient, 71 | Collection: conf.IndexName, 72 | VectorDim: 1024, // 根据你的 embedding 模型调整 73 | Distance: 0, // 使用默认 Cosine 74 | Embedding: embeddingIns11, 75 | BatchSize: 10, 76 | IsAsync: true, 77 | }) 78 | if err != nil { 79 | return nil, err 80 | } 81 | return idr, nil 82 | } else { 83 | return nil, fmt.Errorf("no valid client configuration found") 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /server/core/indexer/qdrant_indexer.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | einoqdrant "github.com/cloudwego/eino-ext/components/indexer/qdrant" 8 | "github.com/cloudwego/eino/components/embedding" 9 | "github.com/cloudwego/eino/components/indexer" 10 | "github.com/cloudwego/eino/schema" 11 | "github.com/qdrant/go-client/qdrant" 12 | ) 13 | 14 | // QdrantIndexerConfig Qdrant indexer 配置 15 | type QdrantIndexerConfig struct { 16 | Client *qdrant.Client // Required: Qdrant client 17 | Collection string // Required: Collection name 18 | VectorDim int // Required: Vector dimension 19 | Distance qdrant.Distance // Required: Distance metric 20 | BatchSize int // Optional: Batch size (default: 10) 21 | Embedding embedding.Embedder // Required: Embedding component 22 | IsAsync bool // Optional: 是否异步模式(包含 QA 向量) 23 | } 24 | 25 | // QdrantIndexer Qdrant indexer 实现(包装 eino-ext 的实现) 26 | type QdrantIndexer struct { 27 | config *QdrantIndexerConfig 28 | einoIndexer indexer.Indexer // eino-ext 的 indexer 29 | asyncIndexer indexer.Indexer // 异步模式的 indexer(用于 QA 向量) 30 | } 31 | 32 | // NewQdrantIndexer 创建 Qdrant indexer,使用 eino-ext 库 33 | func NewQdrantIndexer(ctx context.Context, config *QdrantIndexerConfig) (indexer.Indexer, error) { 34 | if config.Client == nil { 35 | return nil, fmt.Errorf("qdrant client is required") 36 | } 37 | if config.Collection == "" { 38 | return nil, fmt.Errorf("collection name is 
required") 39 | } 40 | if config.Embedding == nil { 41 | return nil, fmt.Errorf("embedding component is required") 42 | } 43 | if config.BatchSize == 0 { 44 | config.BatchSize = 10 45 | } 46 | if config.Distance == 0 { 47 | config.Distance = qdrant.Distance_Cosine 48 | } 49 | 50 | // 使用 eino-ext 的 Qdrant indexer 51 | einoConfig := &einoqdrant.Config{ 52 | Client: config.Client, 53 | Collection: config.Collection, 54 | VectorDim: config.VectorDim, 55 | Distance: config.Distance, 56 | Embedding: config.Embedding, 57 | BatchSize: config.BatchSize, 58 | } 59 | 60 | einoIndexer, err := einoqdrant.NewIndexer(ctx, einoConfig) 61 | if err != nil { 62 | return nil, fmt.Errorf("failed to create eino indexer: %w", err) 63 | } 64 | 65 | idx := &QdrantIndexer{ 66 | config: config, 67 | einoIndexer: einoIndexer, 68 | } 69 | 70 | // 注意:eino-ext 的 indexer 不支持命名向量 71 | // 如果需要异步模式(QA 向量),需要使用自定义实现或创建单独的 collection 72 | // 当前实现:使用自定义逻辑处理命名向量 73 | 74 | return idx, nil 75 | } 76 | 77 | // Store 存储文档(实现 Indexer 接口) 78 | // 注意:由于 eino-ext 的 indexer 不支持命名向量,这里使用自定义实现 79 | func (idx *QdrantIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { 80 | // 使用自定义实现支持命名向量 81 | return idx.StoreWithNamedVectors(ctx, docs, opts...) 
82 | } 83 | 84 | // GetType 返回 indexer 类型 85 | func (idx *QdrantIndexer) GetType() string { 86 | return "qdrant_indexer" 87 | } 88 | -------------------------------------------------------------------------------- /server/core/common/es.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/cenkalti/backoff/v4" 9 | "github.com/elastic/go-elasticsearch/v8" 10 | "github.com/elastic/go-elasticsearch/v8/typedapi/indices/create" 11 | "github.com/elastic/go-elasticsearch/v8/typedapi/indices/exists" 12 | "github.com/elastic/go-elasticsearch/v8/typedapi/types" 13 | "github.com/gogf/gf/v2/frame/g" 14 | coretypes "github.com/wangle201210/go-rag/server/core/types" 15 | ) 16 | 17 | // createIndex create index for example in add_documents.go. 18 | // 已废弃:使用 vector.VectorStore 接口替代 19 | // Deprecated: Use vector.VectorStore.CreateIndex instead 20 | func createIndex(ctx context.Context, client *elasticsearch.Client, indexName string) error { 21 | _, err := create.NewCreateFunc(client)(indexName).Request(&create.Request{ 22 | Mappings: &types.TypeMapping{ 23 | Properties: map[string]types.Property{ 24 | coretypes.FieldContent: types.NewTextProperty(), 25 | coretypes.FieldExtra: types.NewTextProperty(), 26 | coretypes.KnowledgeName: types.NewKeywordProperty(), 27 | coretypes.FieldContentVector: &types.DenseVectorProperty{ 28 | Dims: Of(1024), // same as embedding dimensions 29 | Index: Of(true), 30 | Similarity: Of("cosine"), 31 | }, 32 | coretypes.FieldQAContentVector: &types.DenseVectorProperty{ 33 | Dims: Of(1024), // same as embedding dimensions 34 | Index: Of(true), 35 | Similarity: Of("cosine"), 36 | }, 37 | }, 38 | }, 39 | }).Do(ctx) 40 | 41 | return err 42 | } 43 | 44 | // CreateIndexIfNotExists 已废弃:使用 vector.VectorStore 接口替代 45 | // Deprecated: Use vector.VectorStore.IndexExists and CreateIndex instead 46 | func CreateIndexIfNotExists(ctx 
context.Context, client *elasticsearch.Client, indexName string) error { 47 | indexExists, err := exists.NewExistsFunc(client)(indexName).Do(ctx) 48 | if err != nil { 49 | return err 50 | } 51 | if indexExists { 52 | return nil 53 | } 54 | err = createIndex(ctx, client, indexName) 55 | return err 56 | } 57 | 58 | // DeleteDocument 删除索引中的单个文档 59 | // 已废弃:使用 vector.VectorStore 接口替代 60 | // Deprecated: Use vector.VectorStore.DeleteDocument instead 61 | func DeleteDocument(ctx context.Context, client *elasticsearch.Client, documentID string) error { 62 | return withRetry(func() error { 63 | indexName := g.Cfg().MustGet(ctx, "vector.indexName").String() 64 | res, err := client.Delete(indexName, documentID) 65 | if err != nil { 66 | return fmt.Errorf("delete document failed: %w", err) 67 | } 68 | defer res.Body.Close() 69 | 70 | if res.IsError() { 71 | return fmt.Errorf("delete document failed: %s", res.String()) 72 | } 73 | 74 | return nil 75 | }) 76 | } 77 | 78 | // withRetry 包装函数,添加重试机制 79 | func withRetry(operation func() error) error { 80 | b := backoff.NewExponentialBackOff() 81 | b.MaxElapsedTime = 30 * time.Second 82 | 83 | return backoff.Retry(operation, b) 84 | } 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-rag 2 | 基于eino+gf+vue实现知识库的rag 3 | 1. 创建知识库 4 | ![](./server/static/kb.png) 5 | 2. 选择需要使用的知识库,上传文档 6 | ![](./server/static/indexer.png) 7 | 3. 文档列表 & chunk 编辑 8 | ![](./server/static/doc-list.png) 9 | ![](./server/static/chunk-edit.png) 10 | 4. 文档检索 11 | ![](./server/static/retriever.png) 12 | 5. 对话 13 | ![](./server/static/chat.png) 14 | 6. mcp (以集成到deepchat为例) 15 | ![](./server/static/mcp-cfg.png) 16 | ![](./server/static/mcp-use.png) 17 | 7. 
基于go-rag 实现了一个兼容mac和windows的gui,直接双击即可完成所有安装,不再依赖docker安装mysql,es等依赖 18 | [wachat](https://github.com/wangle201210/wachat) 19 | ![](./server/static/rag.png) 20 | ![](./server/static/wachat.png) 21 | 22 | 23 | ## roadmap 24 | [roadmap](./roadmap.md) 25 | 26 | ## 存储层 27 | - [x] es8存储向量相关数据 28 | - [x] qdrant存储向量相关数据 29 | 30 | ## 功能列表 31 | - [x] md、pdf、html 文档解析 32 | - [x] 网页解析 33 | - [x] 文档检索 34 | - [x] 长文档自动切割(chunk) 35 | - [x] 多知识库支持 36 | - [x] chunk 编辑 37 | - [x] 自动生成 QA 对 38 | - [x] 多路召回 39 | 40 | ## 使用 41 | ### clone项目 42 | ```bash 43 | git clone https://github.com/wangle201210/go-rag.git 44 | ``` 45 | 46 | ### 使用 Docker Compose 快速启动(推荐) 47 | ```bash 48 | cd go-rag 49 | cp server/manifest/config/config_demo.yaml server/manifest/config/config.yaml 50 | # 修改配置文件中的embedding、chat、mysql、es等配置 51 | docker compose up -d 52 | # 浏览器打开 http://localhost:8000 53 | ``` 54 | 55 | ### 使用源码启动 56 | 如果有可用的es8和mysql,可以直接快速启动项目,否则需要先安装es8和mysql 57 | 需要修改`config.yaml`文件的相关配置 58 | ```bash 59 | cp server/manifest/config/config_demo.yaml server/manifest/config/config.yaml 60 | # 修改配置文件中的embedding、chat、mysql、es等配置 61 | make build # 这里会构建前后端项目 62 | make run 63 | # 浏览器打开 http://localhost:8000 64 | ```` 65 | 66 | ### 安装依赖 67 | *如果有可用的es8和mysql,可以不用安装* 68 | 安装es8 69 | ```bash 70 | docker run -d --name elasticsearch \ 71 | -e "discovery.type=single-node" \ 72 | -e "ES_JAVA_OPTS=-Xms512m -Xmx512m" \ 73 | -e "cluster.routing.allocation.disk.watermark.low=1gb" \ 74 | -e "cluster.routing.allocation.disk.watermark.high=1gb" \ 75 | -e "cluster.routing.allocation.disk.watermark.flood_stage=1gb" \ 76 | -e "xpack.security.enabled=false" \ 77 | -p 9200:9200 \ 78 | -p 9300:9300 \ 79 | elasticsearch:8.18.0 80 | ``` 81 | 安装mysql 82 | ```bash 83 | docker run -p 3306:3306 --name mysql \ 84 | -v /Users/wanna/docker/mysql/log:/var/log/mysql \ 85 | -v /Users/wanna/docker/mysql/data:/var/lib/mysql \ 86 | --restart=always \ 87 | -e MYSQL_ROOT_PASSWORD=123456 \ 88 | -e MYSQL_DATABASE=go-rag \ 89 | -d mysql:8.0 90 
| ``` 91 | 92 | ### 运行 api 项目 93 | 94 | ```bash 95 | cd server 96 | go mod tidy 97 | go run main.go 98 | ``` 99 | 100 | ### 运行前端项目 101 | 102 | ```bash 103 | cd fe 104 | npm install 105 | npm run dev 106 | ``` 107 | 108 | ## 使用Makefile构建 109 | 110 | - 构建前端并将产物复制到server/static/fe目录 `make build-fe` 111 | 112 | - 构建后端 `make build-server` 113 | 114 | - 构建整个项目(前端+后端)`make build` 115 | 116 | - 清理构建产物 `make clean` 117 | 118 | ## 联系方式 119 | 120 | > 如果使用上遇到什么问题,欢迎加微信交流(尽量使用github issue 交流) 121 | - 参与项目开发备注:go-rag 开发 122 | - 问题咨询备注:go-rag 咨询 123 | ![微信](./server/static/wx.jpg) -------------------------------------------------------------------------------- /server/core/indexer/indexer.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/bytedance/sonic" 8 | "github.com/cloudwego/eino-ext/components/indexer/es8" 9 | "github.com/cloudwego/eino/components/indexer" 10 | "github.com/cloudwego/eino/schema" 11 | "github.com/google/uuid" 12 | "github.com/wangle201210/go-rag/server/core/common" 13 | "github.com/wangle201210/go-rag/server/core/config" 14 | coretypes "github.com/wangle201210/go-rag/server/core/types" 15 | ) 16 | 17 | // newIndexer component initialization function of node 'Indexer2' in graph 'rag' 18 | func newIndexer(ctx context.Context, conf *config.Config) (idr indexer.Indexer, err error) { 19 | embeddingIns11, err := common.NewEmbedding(ctx, conf) 20 | if err != nil { 21 | return nil, err 22 | } 23 | 24 | // 根据向量存储类型创建不同的 indexer 25 | if conf.Client != nil { 26 | // ES indexer 27 | indexerConfig := &es8.IndexerConfig{ 28 | Client: conf.Client, 29 | Index: conf.IndexName, 30 | BatchSize: 10, 31 | DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) { 32 | var knowledgeName string 33 | if value, ok := ctx.Value(coretypes.KnowledgeName).(string); ok { 34 | knowledgeName = value 35 | } else { 36 
| err = fmt.Errorf("必须提供知识库名称") 37 | return 38 | } 39 | // 没传入才需要生成 40 | if len(doc.ID) == 0 { 41 | doc.ID = uuid.New().String() 42 | } 43 | if doc.MetaData != nil { 44 | // 存储ext数据 45 | marshal, _ := sonic.Marshal(getExtData(doc)) 46 | doc.MetaData[coretypes.FieldExtra] = string(marshal) 47 | } 48 | return map[string]es8.FieldValue{ 49 | coretypes.FieldContent: { 50 | Value: doc.Content, 51 | EmbedKey: coretypes.FieldContentVector, // 这里也可以考虑注释掉,等后续异步执行 52 | }, 53 | coretypes.FieldExtra: { 54 | Value: doc.MetaData[coretypes.FieldExtra], 55 | }, 56 | coretypes.KnowledgeName: { 57 | Value: knowledgeName, 58 | }, 59 | }, nil 60 | }, 61 | } 62 | indexerConfig.Embedding = embeddingIns11 63 | idr, err = es8.NewIndexer(ctx, indexerConfig) 64 | if err != nil { 65 | return nil, err 66 | } 67 | return idr, nil 68 | } else if conf.QdrantClient != nil { 69 | // Qdrant indexer 70 | idr, err = NewQdrantIndexer(ctx, &QdrantIndexerConfig{ 71 | Client: conf.QdrantClient, 72 | Collection: conf.IndexName, 73 | VectorDim: 1024, // 根据你的 embedding 模型调整 74 | Distance: 0, // 使用默认 Cosine 75 | Embedding: embeddingIns11, 76 | BatchSize: 10, 77 | IsAsync: false, 78 | }) 79 | if err != nil { 80 | return nil, err 81 | } 82 | return idr, nil 83 | } else { 84 | return nil, fmt.Errorf("no valid client configuration found") 85 | } 86 | } 87 | 88 | func getExtData(doc *schema.Document) map[string]any { 89 | if doc.MetaData == nil { 90 | return nil 91 | } 92 | res := make(map[string]any) 93 | for _, key := range coretypes.ExtKeys { 94 | if v, e := doc.MetaData[key]; e { 95 | res[key] = v 96 | } 97 | } 98 | return res 99 | } 100 | -------------------------------------------------------------------------------- /server/core/common/chat_model.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudwego/eino-ext/components/model/openai" 7 | "github.com/cloudwego/eino-ext/components/model/qwen" 8 | 
"github.com/cloudwego/eino/components/model" 9 | "github.com/gogf/gf/v2/frame/g" 10 | ) 11 | 12 | var ( 13 | embeddingModel model.BaseChatModel 14 | rerankModel model.BaseChatModel 15 | rewriteModel model.BaseChatModel 16 | qaModel model.BaseChatModel 17 | chatModel model.BaseChatModel 18 | ) 19 | 20 | func GetChatModel(ctx context.Context, cfg *openai.ChatModelConfig) (model.BaseChatModel, error) { 21 | if chatModel != nil { 22 | return chatModel, nil 23 | } 24 | if cfg == nil { 25 | cfg = &openai.ChatModelConfig{} 26 | err := g.Cfg().MustGet(ctx, "chat").Scan(cfg) 27 | if err != nil { 28 | return nil, err 29 | } 30 | } 31 | cm, err := openai.NewChatModel(ctx, cfg) 32 | if err != nil { 33 | return nil, err 34 | } 35 | chatModel = cm 36 | return cm, nil 37 | } 38 | 39 | func GetEmbeddingModel(ctx context.Context, cfg *openai.ChatModelConfig) (model.BaseChatModel, error) { 40 | if embeddingModel != nil { 41 | return embeddingModel, nil 42 | } 43 | if cfg == nil { 44 | cfg = &openai.ChatModelConfig{} 45 | err := g.Cfg().MustGet(ctx, "embedding").Scan(cfg) 46 | if err != nil { 47 | return nil, err 48 | } 49 | } 50 | cm, err := openai.NewChatModel(ctx, cfg) 51 | if err != nil { 52 | return nil, err 53 | } 54 | embeddingModel = cm 55 | return cm, nil 56 | } 57 | 58 | func GetRewriteModel(ctx context.Context, cfg *qwen.ChatModelConfig) (model.BaseChatModel, error) { 59 | if rewriteModel != nil { 60 | return rewriteModel, nil 61 | } 62 | if cfg == nil { 63 | cfg = &qwen.ChatModelConfig{} 64 | err := g.Cfg().MustGet(ctx, "rewrite").Scan(cfg) 65 | cfg.EnableThinking = Of(false) 66 | if err != nil { 67 | return nil, err 68 | } 69 | } 70 | cm, err := qwen.NewChatModel(ctx, cfg) 71 | if err != nil { 72 | return nil, err 73 | } 74 | rewriteModel = cm 75 | return cm, nil 76 | } 77 | 78 | func GetRerankModel(ctx context.Context, cfg *openai.ChatModelConfig) (model.BaseChatModel, error) { 79 | if rerankModel != nil { 80 | return rerankModel, nil 81 | } 82 | if cfg == nil { 83 | 
cfg = &openai.ChatModelConfig{} 84 | err := g.Cfg().MustGet(ctx, "rerank").Scan(cfg) 85 | if err != nil { 86 | return nil, err 87 | } 88 | } 89 | cm, err := openai.NewChatModel(ctx, cfg) 90 | if err != nil { 91 | return nil, err 92 | } 93 | rerankModel = cm 94 | return cm, nil 95 | } 96 | 97 | func GetQAModel(ctx context.Context, cfg *qwen.ChatModelConfig) (model.BaseChatModel, error) { 98 | if qaModel != nil { 99 | return qaModel, nil 100 | } 101 | if cfg == nil { 102 | cfg = &qwen.ChatModelConfig{} 103 | err := g.Cfg().MustGet(ctx, "qa").Scan(cfg) 104 | cfg.EnableThinking = Of(false) 105 | if err != nil { 106 | return nil, err 107 | } 108 | } 109 | cm, err := qwen.NewChatModel(ctx, cfg) 110 | if err != nil { 111 | return nil, err 112 | } 113 | qaModel = cm 114 | return cm, nil 115 | } 116 | -------------------------------------------------------------------------------- /server/internal/dao/internal/knowledge_base.go: -------------------------------------------------------------------------------- 1 | // ========================================================================== 2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. 3 | // ========================================================================== 4 | 5 | package internal 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/gogf/gf/v2/database/gdb" 11 | "github.com/gogf/gf/v2/frame/g" 12 | ) 13 | 14 | // KnowledgeBaseDao is the data access object for the table knowledge_base. 15 | type KnowledgeBaseDao struct { 16 | table string // table is the underlying table name of the DAO. 17 | group string // group is the database configuration group name of the current DAO. 18 | columns KnowledgeBaseColumns // columns contains all the column names of Table for convenient usage. 19 | } 20 | 21 | // KnowledgeBaseColumns defines and stores column names for the table knowledge_base. 
// Ctx creates and returns a Model for the current DAO. It automatically sets the context for the current operation.
//
// The model is built from the database returned by DB (the DAO's configuration
// group) and is bound to the DAO's table name before ctx is attached.
// NOTE(review): Safe() presumably makes chained calls return a new model
// instead of mutating this one — confirm against the gdb documentation.
func (dao *KnowledgeBaseDao) Ctx(ctx context.Context) *gdb.Model {
	return dao.DB().Model(dao.table).Safe().Ctx(ctx)
}
79 | // It commits the transaction and returns nil if function f returns nil. 80 | // 81 | // Note: Do not commit or roll back the transaction in function f, 82 | // as it is automatically handled by this function. 83 | func (dao *KnowledgeBaseDao) Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) (err error) { 84 | return dao.Ctx(ctx).Transaction(ctx, f) 85 | } 86 | -------------------------------------------------------------------------------- /fe/src/components/HelloWorld.vue: -------------------------------------------------------------------------------- 1 | 18 | 19 | 118 | 119 | 128 | -------------------------------------------------------------------------------- /server/core/retriever/retriever.go: -------------------------------------------------------------------------------- 1 | package retriever 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/bytedance/sonic" 8 | "github.com/cloudwego/eino-ext/components/retriever/es8" 9 | "github.com/cloudwego/eino-ext/components/retriever/es8/search_mode" 10 | "github.com/cloudwego/eino/components/retriever" 11 | "github.com/cloudwego/eino/schema" 12 | "github.com/elastic/go-elasticsearch/v8/typedapi/types" 13 | "github.com/wangle201210/go-rag/server/core/common" 14 | "github.com/wangle201210/go-rag/server/core/config" 15 | coretypes "github.com/wangle201210/go-rag/server/core/types" 16 | ) 17 | 18 | // newRetriever component initialization function of node 'Retriever1' in graph 'retriever' 19 | func newRetriever(ctx context.Context, conf *config.Config) (rtr retriever.Retriever, err error) { 20 | vectorField := coretypes.FieldContentVector 21 | if value, ok := ctx.Value(coretypes.RetrieverFieldKey).(string); ok { 22 | vectorField = value 23 | } 24 | 25 | embeddingIns, err := common.NewEmbedding(ctx, conf) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | // 根据客户端类型创建不同的 retriever 31 | if conf.Client != nil { 32 | // ES retriever 33 | retrieverConfig := 
&es8.RetrieverConfig{ 34 | Client: conf.Client, 35 | Index: conf.IndexName, 36 | SearchMode: search_mode.SearchModeDenseVectorSimilarity( 37 | search_mode.DenseVectorSimilarityTypeCosineSimilarity, 38 | vectorField, 39 | ), 40 | ResultParser: EsHit2Document, 41 | Embedding: embeddingIns, 42 | } 43 | rtr, err = es8.NewRetriever(ctx, retrieverConfig) 44 | if err != nil { 45 | return nil, err 46 | } 47 | return rtr, nil 48 | } else if conf.QdrantClient != nil { 49 | // Qdrant retriever 50 | rtr, err = NewQdrantRetriever(ctx, &QdrantRetrieverConfig{ 51 | Client: conf.QdrantClient, 52 | Collection: conf.IndexName, 53 | Embedding: embeddingIns, 54 | VectorField: vectorField, 55 | }) 56 | if err != nil { 57 | return nil, err 58 | } 59 | return rtr, nil 60 | } 61 | 62 | return nil, fmt.Errorf("no valid client configuration found") 63 | } 64 | 65 | func EsHit2Document(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) { 66 | doc = &schema.Document{ 67 | ID: *hit.Id_, 68 | MetaData: map[string]any{}, 69 | } 70 | 71 | var src map[string]any 72 | if err = sonic.Unmarshal(hit.Source_, &src); err != nil { 73 | return nil, err 74 | } 75 | 76 | for field, val := range src { 77 | switch field { 78 | case coretypes.FieldContent: 79 | doc.Content = val.(string) 80 | case coretypes.FieldContentVector: 81 | var v []float64 82 | for _, item := range val.([]interface{}) { 83 | v = append(v, item.(float64)) 84 | } 85 | doc.WithDenseVector(v) 86 | case coretypes.FieldQAContentVector, coretypes.FieldQAContent: 87 | // 这两个字段都不返回 88 | 89 | case coretypes.FieldExtra: 90 | if val == nil { 91 | continue 92 | } 93 | doc.MetaData[coretypes.FieldExtra] = val.(string) 94 | case coretypes.KnowledgeName: 95 | doc.MetaData[coretypes.KnowledgeName] = val.(string) 96 | default: 97 | return nil, fmt.Errorf("unexpected field=%s, val=%v", field, val) 98 | } 99 | } 100 | 101 | if hit.Score_ != nil { 102 | doc.WithScore(float64(*hit.Score_)) 103 | } 104 | 105 | return doc, nil 106 | } 107 | 
-------------------------------------------------------------------------------- /server/internal/dao/dao.go: -------------------------------------------------------------------------------- 1 | package dao 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | 9 | "github.com/gogf/gf/v2/frame/g" 10 | database "github.com/wangle201210/go-rag/server/internal/dao/db" 11 | ) 12 | 13 | var db database.Database 14 | 15 | func init() { 16 | err := InitDB() 17 | if err != nil { 18 | g.Log().Fatal(context.Background(), "database connection not initialized, err %v", err) 19 | } 20 | } 21 | 22 | // InitDB 初始化数据库连接 23 | func InitDB() error { 24 | ctx := context.Background() 25 | dbType := g.Cfg().MustGet(ctx, "database.default.type", "mysql").String() 26 | var ( 27 | cfg *database.Config 28 | err error 29 | ) 30 | switch dbType { 31 | case "mysql", "": 32 | cfg, _ = GetMysqlConfig() 33 | db = database.NewMysql(cfg) 34 | case "sqlite": 35 | cfg, err = GetSqliteConfig() 36 | if err != nil { 37 | return err 38 | } 39 | db = database.NewSqlite(cfg) 40 | default: 41 | return fmt.Errorf("unsupported database type: %s", dbType) 42 | } 43 | if err = db.Connect(); err != nil { 44 | return err 45 | } 46 | if err = db.Ping(); err != nil { 47 | return err 48 | } 49 | // 迁移表结构 50 | return db.AutoMigrate() 51 | } 52 | 53 | // ensureSQLiteFileDir 确保 SQLite 文件的父目录存在 54 | func ensureSQLiteFileDir(filePath string) error { 55 | // 获取文件的目录路径 56 | dir := filepath.Dir(filePath) 57 | 58 | // 检查目录是否存在 59 | if _, err := os.Stat(dir); os.IsNotExist(err) { 60 | // 目录不存在,创建目录(包括所有必要的父目录) 61 | if err := os.MkdirAll(dir, 0755); err != nil { 62 | return fmt.Errorf("failed to create directory %s: %v", dir, err) 63 | } 64 | g.Log().Infof(context.Background(), "Created directory for SQLite database: %s", dir) 65 | } 66 | 67 | return nil 68 | } 69 | 70 | func GetDsn() string { 71 | return db.DSN() 72 | } 73 | 74 | func GetSqliteConfig() (*database.Config, error) { 75 | ctx := 
context.Background() 76 | cfg := &database.Config{} 77 | cfg.FilePath = g.Cfg().MustGet(ctx, "database.default.host").String() 78 | if cfg.FilePath == "" { 79 | return nil, fmt.Errorf("sqlite file path is required when using sqlite") 80 | } 81 | if err := ensureSQLiteFileDir(cfg.FilePath); err != nil { 82 | return nil, fmt.Errorf("failed to create sqlite file directory: %v", err) 83 | } 84 | cfg.BusyTimeout = g.Cfg().MustGet(ctx, "database.default.busy_timeout").Int() 85 | cfg.JournalMode = g.Cfg().MustGet(ctx, "database.default.journal_mode").String() 86 | cfg.Synchronous = g.Cfg().MustGet(ctx, "database.default.synchronous").String() 87 | cfg.CacheSize = g.Cfg().MustGet(ctx, "database.default.cache_size").Int() 88 | cfg.MaxOpenConn = g.Cfg().MustGet(ctx, "database.default.max_open_conns", 1).Int() 89 | cfg.MaxIdleConn = g.Cfg().MustGet(ctx, "database.default.max_idle_conns", 1).Int() 90 | cfg.LogLevel = 4 91 | return cfg, nil 92 | } 93 | 94 | func GetMysqlConfig() (*database.Config, error) { 95 | cfg := g.DB().GetConfig() 96 | c := &database.Config{ 97 | Host: cfg.Host, 98 | Port: cfg.Port, 99 | User: cfg.User, 100 | Password: cfg.Pass, 101 | Database: cfg.Name, 102 | Charset: cfg.Charset, 103 | MaxOpenConn: cfg.MaxOpenConnCount, 104 | MaxIdleConn: cfg.MaxIdleConnCount, 105 | LogLevel: 4, 106 | } 107 | if c.MaxIdleConn == 0 { 108 | c.MaxIdleConn = 10 109 | } 110 | if c.MaxOpenConn == 0 { 111 | c.MaxOpenConn = 100 112 | } 113 | return c, nil 114 | } 115 | -------------------------------------------------------------------------------- /fe/src/components.d.ts: -------------------------------------------------------------------------------- 1 | /* eslint-disable */ 2 | // @ts-nocheck 3 | // Generated by unplugin-vue-components 4 | // Read more: https://github.com/vuejs/core/pull/3399 5 | export {} 6 | 7 | /* prettier-ignore */ 8 | declare module 'vue' { 9 | export interface GlobalComponents { 10 | BaseHeader: typeof 
import('./components/layouts/BaseHeader.vue')['default'] 11 | BaseSide: typeof import('./components/layouts/BaseSide.vue')['default'] 12 | ElAlert: typeof import('element-plus/es')['ElAlert'] 13 | ElAvatar: typeof import('element-plus/es')['ElAvatar'] 14 | ElButton: typeof import('element-plus/es')['ElButton'] 15 | ElCard: typeof import('element-plus/es')['ElCard'] 16 | ElCol: typeof import('element-plus/es')['ElCol'] 17 | ElCollapse: typeof import('element-plus/es')['ElCollapse'] 18 | ElCollapseItem: typeof import('element-plus/es')['ElCollapseItem'] 19 | ElCollapseTransition: typeof import('element-plus/es')['ElCollapseTransition'] 20 | ElConfigProvider: typeof import('element-plus/es')['ElConfigProvider'] 21 | ElDescriptions: typeof import('element-plus/es')['ElDescriptions'] 22 | ElDescriptionsItem: typeof import('element-plus/es')['ElDescriptionsItem'] 23 | ElDialog: typeof import('element-plus/es')['ElDialog'] 24 | ElDivider: typeof import('element-plus/es')['ElDivider'] 25 | ElEmpty: typeof import('element-plus/es')['ElEmpty'] 26 | ElForm: typeof import('element-plus/es')['ElForm'] 27 | ElFormItem: typeof import('element-plus/es')['ElFormItem'] 28 | ElIcon: typeof import('element-plus/es')['ElIcon'] 29 | ElInput: typeof import('element-plus/es')['ElInput'] 30 | ElInputNumber: typeof import('element-plus/es')['ElInputNumber'] 31 | ElMenu: typeof import('element-plus/es')['ElMenu'] 32 | ElMenuItem: typeof import('element-plus/es')['ElMenuItem'] 33 | ElOption: typeof import('element-plus/es')['ElOption'] 34 | ElPageHeader: typeof import('element-plus/es')['ElPageHeader'] 35 | ElPagination: typeof import('element-plus/es')['ElPagination'] 36 | ElPopover: typeof import('element-plus/es')['ElPopover'] 37 | ElRadio: typeof import('element-plus/es')['ElRadio'] 38 | ElRadioGroup: typeof import('element-plus/es')['ElRadioGroup'] 39 | ElRow: typeof import('element-plus/es')['ElRow'] 40 | ElScrollbar: typeof import('element-plus/es')['ElScrollbar'] 41 | ElSelect: typeof 
// IndexParam is the argument payload of the "Indexer_by_filepath" tool; it is
// the schema passed to protocol.NewTool and the type the handler unmarshals
// the raw call arguments into.
type IndexParam struct {
	// URI locates the content to index: either a file path (pdf, html, md, ...)
	// or a web URL.
	URI string `json:"uri" description:"文件路径" required:"true"`
	// KnowledgeName names the knowledge base the content is indexed into.
	KnowledgeName string `json:"knowledge_name" description:"知识库名字,请先通过getKnowledgeBaseList获取列表后判断是否有符合的知识库,如果没有则根据用户提示词自己生成" required:"true"`
}
protocol.NewTool("Indexer_by_filepath", "通过文件路径进行文本嵌入", IndexParam{}) 22 | if err != nil { 23 | g.Log().Errorf(gctx.New(), "Failed to create tool: %v", err) 24 | return nil 25 | } 26 | return tool 27 | } 28 | 29 | func HandleIndexerByFilePath(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) { 30 | var reqData IndexParam 31 | if err := protocol.VerifyAndUnmarshal(req.RawArguments, &reqData); err != nil { 32 | return nil, err 33 | } 34 | svr := rag.GetRagSvr() 35 | uri := reqData.URI 36 | indexReq := &gorag.IndexReq{ 37 | URI: uri, 38 | KnowledgeName: reqData.KnowledgeName, 39 | } 40 | ids, err := svr.Index(ctx, indexReq) 41 | if err != nil { 42 | return nil, err 43 | } 44 | msg := fmt.Sprintf("index file %s successfully, knowledge_name: %s, doc_ids: %v", uri, reqData.KnowledgeName, ids) 45 | return &protocol.CallToolResult{ 46 | Content: []protocol.Content{ 47 | &protocol.TextContent{ 48 | Type: "text", 49 | Text: msg, 50 | }, 51 | }, 52 | }, nil 53 | } 54 | 55 | type IndexFileParam struct { 56 | Filename string `json:"filename" description:"文件名字" required:"true"` 57 | Content string `json:"content" description:"被base64编码后的文件内容,先调用工具获取base64信息" required:"true"` // 可以是文件路径(pdf,html,md等),也可以是网址文件" required:"true"` // 可以是文件路径(pdf,html,md等),也可以是网址 58 | KnowledgeName string `json:"knowledge_name" description:"知识库名字,请先通过getKnowledgeBaseList获取列表后判断是否有符合的知识库,如果没有则根据用户提示词自己生成" required:"true"` 59 | } 60 | 61 | func GetIndexerByFileBase64ContentTool() *protocol.Tool { 62 | tool, err := protocol.NewTool("Indexer_by_base64_file_content", "获取文件base64信息后上传,然后对内容进行文本嵌入", IndexFileParam{}) 63 | if err != nil { 64 | g.Log().Errorf(gctx.New(), "Failed to create tool: %v", err) 65 | return nil 66 | } 67 | return tool 68 | } 69 | 70 | func HandleIndexerByFileBase64Content(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) { 71 | var reqData IndexFileParam 72 | if err := protocol.VerifyAndUnmarshal(req.RawArguments, 
&reqData); err != nil { 73 | return nil, err 74 | } 75 | // svr := rag.GetRagSvr() 76 | decoded, err := base64.StdEncoding.DecodeString(reqData.Content) 77 | if err != nil { 78 | return nil, err 79 | } 80 | fmt.Println(decoded) 81 | // indexReq := &gorag.IndexReq{ 82 | // URI: uri, 83 | // KnowledgeName: reqData.KnowledgeName, 84 | // } 85 | // ids, err := svr.Index(ctx, indexReq) 86 | // if err != nil { 87 | // return nil, err 88 | // } 89 | // msg := fmt.Sprintf("index file %s successfully, knowledge_name: %s, doc_ids: %v", uri, reqData.KnowledgeName, ids) 90 | return &protocol.CallToolResult{ 91 | Content: []protocol.Content{ 92 | &protocol.TextContent{ 93 | Type: "text", 94 | Text: string(decoded), 95 | }, 96 | }, 97 | }, nil 98 | } 99 | -------------------------------------------------------------------------------- /server/core/indexer/merge.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/bytedance/sonic" 9 | "github.com/cloudwego/eino-ext/components/document/loader/file" 10 | "github.com/cloudwego/eino/schema" 11 | "github.com/google/uuid" 12 | coretypes "github.com/wangle201210/go-rag/server/core/types" 13 | ) 14 | 15 | // docAddIDAndMerge component initialization function of node 'Lambda1' in graph 't' 16 | func docAddIDAndMerge(ctx context.Context, docs []*schema.Document) (output []*schema.Document, err error) { 17 | if len(docs) == 0 { 18 | return docs, nil 19 | } 20 | for _, doc := range docs { 21 | doc.ID = uuid.New().String() // 覆盖之前的id 22 | } 23 | switch docs[0].MetaData[file.MetaKeyExtension] { 24 | case ".md": 25 | return mergeMD(ctx, docs) 26 | case ".xlsx": 27 | return mergeXLSX(ctx, docs) 28 | default: 29 | return docs, nil 30 | } 31 | } 32 | 33 | func mergeMD(ctx context.Context, docs []*schema.Document) (output []*schema.Document, err error) { 34 | ndocs := make([]*schema.Document, 0, len(docs)) 35 | var nd 
*schema.Document 36 | maxLen := 512 37 | for _, doc := range docs { 38 | // 不是同一个文件的就不要放一起了 39 | if nd != nil && doc.MetaData[file.MetaKeySource] != nd.MetaData[file.MetaKeySource] { 40 | ndocs = append(ndocs, nd) 41 | nd = nil 42 | } 43 | // 两个文档长度之和大于maxLen就不要放一起了 44 | if nd != nil && len(nd.Content)+len(doc.Content) > maxLen { 45 | ndocs = append(ndocs, nd) 46 | nd = nil 47 | } 48 | // 不是同一个一级标题的就不要放一起了 49 | if nd != nil && doc.MetaData[coretypes.Title1] != nd.MetaData[coretypes.Title1] { 50 | ndocs = append(ndocs, nd) 51 | nd = nil 52 | } 53 | // 不是同一个二级标题的就不要放一起了 54 | // 如果nd的h2是nil,证明之前只有h1,且两个的h1相等,则直接合并 55 | if nd != nil && nd.MetaData[coretypes.Title2] != nil && doc.MetaData[coretypes.Title2] != nd.MetaData[coretypes.Title2] { 56 | ndocs = append(ndocs, nd) 57 | nd = nil 58 | } 59 | if nd == nil { 60 | nd = doc 61 | } else { 62 | mergeTitle(nd, doc, coretypes.Title2) 63 | mergeTitle(nd, doc, coretypes.Title3) 64 | nd.Content += doc.Content 65 | } 66 | } 67 | if nd != nil { 68 | ndocs = append(ndocs, nd) 69 | } 70 | for _, ndoc := range ndocs { 71 | ndoc.Content = getMdContentWithTitle(ndoc) 72 | } 73 | return ndocs, nil 74 | } 75 | 76 | func mergeXLSX(ctx context.Context, docs []*schema.Document) (output []*schema.Document, err error) { 77 | for _, doc := range docs { 78 | marshal, _ := sonic.Marshal(doc.MetaData[coretypes.XlsxRow]) 79 | doc.Content = string(marshal) 80 | } 81 | return docs, nil 82 | } 83 | 84 | func mergeTitle(orgDoc, addDoc *schema.Document, key string) { 85 | // 相等就不管了 86 | if orgDoc.MetaData[key] == addDoc.MetaData[key] { 87 | return 88 | } 89 | var title []string 90 | if orgDoc.MetaData[key] != nil { 91 | title = append(title, orgDoc.MetaData[key].(string)) 92 | } 93 | if addDoc.MetaData[key] != nil { 94 | title = append(title, addDoc.MetaData[key].(string)) 95 | } 96 | if len(title) > 0 { 97 | orgDoc.MetaData[key] = strings.Join(title, ",") 98 | } 99 | } 100 | 101 | func getMdContentWithTitle(doc *schema.Document) string { 102 | if 
doc.MetaData == nil { 103 | return doc.Content 104 | } 105 | title := "" 106 | list := []string{"h1", "h2", "h3", "h4", "h5", "h6"} 107 | for _, v := range list { 108 | if d, e := doc.MetaData[v].(string); e && len(d) > 0 { 109 | title += fmt.Sprintf("%s:%s ", v, d) 110 | } 111 | } 112 | if len(title) == 0 { 113 | return doc.Content 114 | } 115 | return title + "\n" + doc.Content 116 | } 117 | -------------------------------------------------------------------------------- /server/internal/dao/internal/knowledge_chunks.go: -------------------------------------------------------------------------------- 1 | // ========================================================================== 2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. 3 | // ========================================================================== 4 | 5 | package internal 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/gogf/gf/v2/database/gdb" 11 | "github.com/gogf/gf/v2/frame/g" 12 | ) 13 | 14 | // KnowledgeChunksDao is the data access object for the table knowledge_chunks. 15 | type KnowledgeChunksDao struct { 16 | table string // table is the underlying table name of the DAO. 17 | group string // group is the database configuration group name of the current DAO. 18 | columns KnowledgeChunksColumns // columns contains all the column names of Table for convenient usage. 19 | handlers []gdb.ModelHandler // handlers for customized model modification. 20 | } 21 | 22 | // KnowledgeChunksColumns defines and stores column names for the table knowledge_chunks. 23 | type KnowledgeChunksColumns struct { 24 | Id string // 25 | KnowledgeDocId string // 26 | ChunkId string // 27 | Content string // 28 | Ext string // 29 | Status string // 30 | CreatedAt string // 31 | UpdatedAt string // 32 | } 33 | 34 | // knowledgeChunksColumns holds the columns for the table knowledge_chunks. 
// Ctx creates and returns a Model for the current DAO. It automatically sets the context for the current operation.
//
// The model starts from the DAO's table on the database returned by DB, then
// every handler stored on the DAO is applied in slice order, each receiving
// the model produced by the previous one, before ctx is attached.
// NOTE(review): Safe() presumably makes chained calls return a new model
// instead of mutating this one — confirm against the gdb documentation.
func (dao *KnowledgeChunksDao) Ctx(ctx context.Context) *gdb.Model {
	model := dao.DB().Model(dao.table)
	for _, handler := range dao.handlers {
		model = handler(model)
	}
	return model.Safe().Ctx(ctx)
}
88 | // 89 | // Note: Do not commit or roll back the transaction in function f, 90 | // as it is automatically handled by this function. 91 | func (dao *KnowledgeChunksDao) Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) (err error) { 92 | return dao.Ctx(ctx).Transaction(ctx, f) 93 | } 94 | -------------------------------------------------------------------------------- /server/internal/dao/internal/knowledge_documents.go: -------------------------------------------------------------------------------- 1 | // ========================================================================== 2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. 3 | // ========================================================================== 4 | 5 | package internal 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/gogf/gf/v2/database/gdb" 11 | "github.com/gogf/gf/v2/frame/g" 12 | ) 13 | 14 | // KnowledgeDocumentsDao is the data access object for the table knowledge_documents. 15 | type KnowledgeDocumentsDao struct { 16 | table string // table is the underlying table name of the DAO. 17 | group string // group is the database configuration group name of the current DAO. 18 | columns KnowledgeDocumentsColumns // columns contains all the column names of Table for convenient usage. 19 | handlers []gdb.ModelHandler // handlers for customized model modification. 20 | } 21 | 22 | // KnowledgeDocumentsColumns defines and stores column names for the table knowledge_documents. 23 | type KnowledgeDocumentsColumns struct { 24 | Id string // 25 | KnowledgeBaseName string // 26 | FileName string // 27 | Status string // 28 | CreatedAt string // 29 | UpdatedAt string // 30 | } 31 | 32 | // knowledgeDocumentsColumns holds the columns for the table knowledge_documents. 
// Ctx creates and returns a Model for the current DAO. It automatically sets the context for the current operation.
//
// The model starts from the DAO's table on the database returned by DB, then
// every handler stored on the DAO is applied in slice order, each receiving
// the model produced by the previous one, before ctx is attached.
// NOTE(review): Safe() presumably makes chained calls return a new model
// instead of mutating this one — confirm against the gdb documentation.
func (dao *KnowledgeDocumentsDao) Ctx(ctx context.Context) *gdb.Model {
	model := dao.DB().Model(dao.table)
	for _, handler := range dao.handlers {
		model = handler(model)
	}
	return model.Safe().Ctx(ctx)
}
84 | // 85 | // Note: Do not commit or roll back the transaction in function f, 86 | // as it is automatically handled by this function. 87 | func (dao *KnowledgeDocumentsDao) Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) (err error) { 88 | return dao.Ctx(ctx).Transaction(ctx, f) 89 | } 90 | -------------------------------------------------------------------------------- /server/core/rerank/rerank.go: -------------------------------------------------------------------------------- 1 | package rerank 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io/ioutil" 8 | "net/http" 9 | 10 | "github.com/bytedance/sonic" 11 | "github.com/cloudwego/eino/schema" 12 | "github.com/gogf/gf/v2/frame/g" 13 | ) 14 | 15 | type Conf struct { 16 | Model string `json:"model"` 17 | ReturnDocuments bool `json:"return_documents"` 18 | MaxChunksPerDoc int `json:"max_chunks_per_doc"` 19 | OverlapTokens int `json:"overlap_tokens"` 20 | url string 21 | apiKey string 22 | } 23 | type Data struct { 24 | Query string `json:"query"` 25 | Documents []string `json:"documents"` 26 | TopN int `json:"top_n"` 27 | } 28 | 29 | type Req struct { 30 | *Data 31 | *Conf 32 | } 33 | 34 | type Result struct { 35 | Index int `json:"index"` 36 | RelevanceScore float64 `json:"relevance_score"` 37 | } 38 | 39 | type Resp struct { 40 | ID string `json:"id"` 41 | Results []*Result `json:"results"` 42 | } 43 | 44 | var rerankCfg *Conf 45 | 46 | func NewRerank(ctx context.Context, query string, docs []*schema.Document, topK int) (output []*schema.Document, err error) { 47 | output, err = rerank(ctx, query, docs, topK) 48 | if err != nil { 49 | return 50 | } 51 | return 52 | } 53 | 54 | func GetConf(ctx context.Context) *Conf { 55 | if rerankCfg != nil { 56 | return rerankCfg 57 | } 58 | baseUrl := g.Cfg().MustGet(ctx, "rerank.baseURL").String() 59 | apiKey := g.Cfg().MustGet(ctx, "rerank.apiKey").String() 60 | model := g.Cfg().MustGet(ctx, "rerank.model").String() 61 | url := 
fmt.Sprintf("%s/rerank", baseUrl) 62 | rerankCfg = &Conf{ 63 | apiKey: apiKey, 64 | Model: model, 65 | ReturnDocuments: false, 66 | MaxChunksPerDoc: 1024, 67 | OverlapTokens: 80, 68 | url: url, 69 | } 70 | return rerankCfg 71 | } 72 | 73 | func rerank(ctx context.Context, query string, docs []*schema.Document, topK int) (output []*schema.Document, err error) { 74 | data := &Data{ 75 | Query: query, 76 | TopN: topK, 77 | } 78 | // g.Log().Infof(ctx, "docs num: %d", len(docs)) 79 | for _, doc := range docs { 80 | data.Documents = append(data.Documents, doc.Content) 81 | } 82 | // 重排 83 | results, err := rerankDoHttp(ctx, data) 84 | if err != nil { 85 | return 86 | } 87 | // 重新组装数据 88 | for _, result := range results { 89 | doc := docs[result.Index] 90 | // g.Log().Infof(ctx, "content: %s, score_old: %f, score_new: %f", doc.Content, doc.Score(), result.RelevanceScore) 91 | doc.WithScore(result.RelevanceScore) 92 | output = append(output, docs[result.Index]) 93 | } 94 | return 95 | } 96 | 97 | func rerankDoHttp(ctx context.Context, data *Data) ([]*Result, error) { 98 | cfg := GetConf(ctx) 99 | reqData := &Req{ 100 | Data: data, 101 | Conf: cfg, 102 | } 103 | 104 | marshal, err := sonic.Marshal(reqData) 105 | if err != nil { 106 | return nil, err 107 | } 108 | payload := bytes.NewReader(marshal) 109 | request, err := http.NewRequest("POST", cfg.url, payload) 110 | if err != nil { 111 | return nil, err 112 | } 113 | request.Header.Add("Content-Type", "application/json") 114 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.apiKey)) 115 | do, err := g.Client().Do(request) 116 | if err != nil { 117 | return nil, err 118 | } 119 | defer do.Body.Close() 120 | body, err := ioutil.ReadAll(do.Body) 121 | if err != nil { 122 | return nil, err 123 | } 124 | res := Resp{} 125 | err = sonic.Unmarshal(body, &res) 126 | if err != nil { 127 | return nil, err 128 | } 129 | return res.Results, nil 130 | } 131 | 
-------------------------------------------------------------------------------- /server/core/indexer/qdrant_indexer_custom.go: -------------------------------------------------------------------------------- 1 | package indexer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/bytedance/sonic" 8 | "github.com/cloudwego/eino/components/indexer" 9 | "github.com/cloudwego/eino/schema" 10 | "github.com/gogf/gf/v2/frame/g" 11 | "github.com/google/uuid" 12 | "github.com/qdrant/go-client/qdrant" 13 | coretypes "github.com/wangle201210/go-rag/server/core/types" 14 | ) 15 | 16 | // StoreWithNamedVectors 使用命名向量存储文档到 Qdrant 17 | // 这是自定义实现,因为 eino-ext 的 indexer 不支持命名向量 18 | func (idx *QdrantIndexer) StoreWithNamedVectors(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { 19 | if len(docs) == 0 { 20 | return nil, nil 21 | } 22 | 23 | var knowledgeName string 24 | if value, ok := ctx.Value(coretypes.KnowledgeName).(string); ok { 25 | knowledgeName = value 26 | } else { 27 | return nil, fmt.Errorf("必须提供知识库名称") 28 | } 29 | 30 | g.Log().Infof(ctx, "QdrantIndexer.StoreWithNamedVectors: storing %d documents to collection %s, knowledge_name=%s", len(docs), idx.config.Collection, knowledgeName) 31 | 32 | // 准备 points 33 | points := make([]*qdrant.PointStruct, 0, len(docs)) 34 | ids := make([]string, 0, len(docs)) 35 | 36 | for _, doc := range docs { 37 | // 生成 ID 38 | if len(doc.ID) == 0 { 39 | doc.ID = uuid.New().String() 40 | } 41 | ids = append(ids, doc.ID) 42 | 43 | // 生成 embedding 44 | embeddings, err := idx.config.Embedding.EmbedStrings(ctx, []string{doc.Content}) 45 | if err != nil { 46 | g.Log().Errorf(ctx, "Failed to embed document %s: %v", doc.ID, err) 47 | return nil, fmt.Errorf("failed to embed document: %w", err) 48 | } 49 | if len(embeddings) == 0 { 50 | return nil, fmt.Errorf("embedding returned empty result") 51 | } 52 | 53 | // 转换为 float32 54 | vec32 := make([]float32, len(embeddings[0])) 55 | for i, v := range 
embeddings[0] { 56 | vec32[i] = float32(v) 57 | } 58 | 59 | // 准备 payload 60 | payload := make(map[string]*qdrant.Value) 61 | payload[coretypes.FieldContent] = &qdrant.Value{ 62 | Kind: &qdrant.Value_StringValue{StringValue: doc.Content}, 63 | } 64 | payload[coretypes.KnowledgeName] = &qdrant.Value{ 65 | Kind: &qdrant.Value_StringValue{StringValue: knowledgeName}, 66 | } 67 | 68 | // 添加额外的 metadata 69 | if doc.MetaData != nil { 70 | extData := getExtData(doc) 71 | if len(extData) > 0 { 72 | marshal, _ := sonic.Marshal(extData) 73 | payload[coretypes.FieldExtra] = &qdrant.Value{ 74 | Kind: &qdrant.Value_StringValue{StringValue: string(marshal)}, 75 | } 76 | } 77 | } 78 | 79 | // 创建命名向量(只存储 content_vector,qa_content_vector 由异步任务处理) 80 | vectors := &qdrant.Vectors{ 81 | VectorsOptions: &qdrant.Vectors_Vectors{ 82 | Vectors: &qdrant.NamedVectors{ 83 | Vectors: map[string]*qdrant.Vector{ 84 | coretypes.FieldContentVector: { 85 | Data: vec32, 86 | }, 87 | }, 88 | }, 89 | }, 90 | } 91 | 92 | // 创建 point 93 | point := &qdrant.PointStruct{ 94 | Id: &qdrant.PointId{ 95 | PointIdOptions: &qdrant.PointId_Uuid{Uuid: doc.ID}, 96 | }, 97 | Vectors: vectors, 98 | Payload: payload, 99 | } 100 | 101 | points = append(points, point) 102 | } 103 | 104 | // 批量存储到 Qdrant 105 | _, err := idx.config.Client.Upsert(ctx, &qdrant.UpsertPoints{ 106 | CollectionName: idx.config.Collection, 107 | Points: points, 108 | }) 109 | if err != nil { 110 | g.Log().Errorf(ctx, "QdrantIndexer.StoreWithNamedVectors failed: %v", err) 111 | return nil, fmt.Errorf("failed to upsert points: %w", err) 112 | } 113 | 114 | g.Log().Infof(ctx, "QdrantIndexer.StoreWithNamedVectors success: stored %d documents, IDs: %v", len(ids), ids) 115 | 116 | return ids, nil 117 | } 118 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for go-rag project 2 | 3 | .PHONY: build-fe build-server 
build docker-build clean build-all release clean-release 4 | 5 | # 默认目标 6 | all: build 7 | 8 | # 构建前端 9 | build-fe: 10 | cd fe && pnpm install && pnpm run build 11 | mkdir -p server/static/fe 12 | cp -r fe/dist/* server/static/fe/ 13 | 14 | # 构建后端 15 | build-server: 16 | cd server && go mod tidy && go build -o go-rag-server main.go 17 | 18 | # 构建整个项目 19 | build: build-fe build-server 20 | 21 | # 运行 22 | run: 23 | cd server && ./go-rag-server 24 | 25 | # 清理构建产物 26 | clean: 27 | rm -rf fe/dist 28 | rm -rf server/static/fe 29 | rm -f server/go-rag-server 30 | 31 | # 构建Docker镜像 32 | docker-build: build 33 | docker build -t go-rag:latest -f Dockerfile . 34 | 35 | run-local: 36 | cd server && go mod tidy && go run . 37 | 38 | build-linux: 39 | cd server && go mod tidy && GOOS=linux GOARCH=amd64 go build -o go-rag-server 40 | 41 | run-by-docker: 42 | docker compose -f docker-compose.yml up -d 43 | 44 | v := v0.0.3 45 | buildx: 46 | docker buildx build \ 47 | --platform linux/arm64,linux/amd64 \ 48 | -t iwangle/go-rag:$(v) \ 49 | --push \ 50 | . 51 | 52 | # 项目名称和版本 53 | APP_NAME := go-rag 54 | 55 | # 支持的平台 56 | PLATFORMS := linux/amd64 linux/arm64 darwin/amd64 darwin/arm64 windows/amd64 57 | 58 | # 多平台构建 59 | build-all: 60 | @echo "Building go-rag for multiple platforms..." 61 | @mkdir -p releases 62 | @for platform in $(PLATFORMS); do \ 63 | os=$$(echo $$platform | cut -d'/' -f1); \ 64 | arch=$$(echo $$platform | cut -d'/' -f2); \ 65 | output_name=$(APP_NAME)-$$os-$$arch; \ 66 | if [ $$os = "windows" ]; then output_name=$$output_name.exe; fi; \ 67 | echo "Building for $$os/$$arch..."; \ 68 | (cd server && GOOS=$$os GOARCH=$$arch go build -o ../releases/$$output_name .); \ 69 | done 70 | @echo "Build completed! Files are in releases/ directory" 71 | 72 | # 发布版本(构建 + 压缩) 73 | release: clean-release build-all 74 | @echo "Creating release archives..." 
75 | @cd releases && \ 76 | for platform in $(PLATFORMS); do \ 77 | os=$$(echo $$platform | cut -d'/' -f1); \ 78 | arch=$$(echo $$platform | cut -d'/' -f2); \ 79 | release_dir=$(APP_NAME); \ 80 | output_name=$(APP_NAME); \ 81 | if [ $$os = "windows" ]; then output_name=$$output_name.exe; fi; \ 82 | exe_file=$(APP_NAME)-$$os-$$arch; \ 83 | if [ $$os = "windows" ]; then exe_file=$$exe_file.exe; fi; \ 84 | archive_name=$(APP_NAME)-$$os-$$arch; \ 85 | \ 86 | echo "Preparing $$archive_name..."; \ 87 | mkdir -p $$release_dir; \ 88 | \ 89 | if [ -f $$exe_file ]; then \ 90 | cp $$exe_file $$release_dir/$$output_name; \ 91 | echo "Copied executable to $$release_dir/$$output_name"; \ 92 | else \ 93 | echo "Warning: $$exe_file not found, skipping $$platform"; \ 94 | rm -rf $$release_dir; \ 95 | continue; \ 96 | fi; \ 97 | \ 98 | if [ -d ../server/static ]; then \ 99 | cp -r ../server/static $$release_dir/; \ 100 | echo "Copied static files to $$release_dir/static/"; \ 101 | fi; \ 102 | \ 103 | if [ -f ../server/manifest/config/config_qd_demo.yaml ]; then \ 104 | cp ../server/manifest/config/config_qd_demo.yaml $$release_dir/config.yaml; \ 105 | echo "Copied config file to $$release_dir/config.yaml"; \ 106 | fi; \ 107 | \ 108 | if [ $$os = "windows" ]; then \ 109 | zip -q $$archive_name.zip -r $$release_dir; \ 110 | echo "Created $$archive_name.zip"; \ 111 | else \ 112 | tar -czf $$archive_name.tar.gz $$release_dir; \ 113 | echo "Created $$archive_name.tar.gz"; \ 114 | fi; \ 115 | rm -rf $$release_dir; \ 116 | done 117 | @echo "Release archives created!" 118 | @echo "Files in releases/:" 119 | @ls -la releases/ 120 | 121 | # 清理发布文件 122 | clean-release: 123 | @echo "Cleaning release files..." 124 | @rm -rf releases/ 125 | @echo "Release files cleaned!" 
126 | 127 | -------------------------------------------------------------------------------- /server/internal/logic/knowledge/chunks.go: -------------------------------------------------------------------------------- 1 | package knowledge 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/gogf/gf/v2/frame/g" 7 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1" 8 | "github.com/wangle201210/go-rag/server/internal/dao" 9 | "github.com/wangle201210/go-rag/server/internal/model/entity" 10 | ) 11 | 12 | // SaveChunksData 批量保存知识块数据 13 | func SaveChunksData(ctx context.Context, documentsId int64, chunks []entity.KnowledgeChunks) error { 14 | if len(chunks) == 0 { 15 | return nil 16 | } 17 | status := int(v1.StatusIndexing) 18 | 19 | // 逐个插入或更新,避免 SQLite 的 ON CONFLICT 语法问题 20 | for _, chunk := range chunks { 21 | // 先尝试查询是否存在 22 | var existing entity.KnowledgeChunks 23 | err := dao.KnowledgeChunks.Ctx(ctx).Where("chunk_id", chunk.ChunkId).Scan(&existing) 24 | 25 | if err == nil && existing.Id > 0 { 26 | // 已存在,更新(排除 id 和 created_at) 27 | _, err = dao.KnowledgeChunks.Ctx(ctx). 28 | Where("chunk_id", chunk.ChunkId). 29 | Data(g.Map{ 30 | "knowledge_doc_id": chunk.KnowledgeDocId, 31 | "content": chunk.Content, 32 | "ext": chunk.Ext, 33 | "status": chunk.Status, 34 | }). 
35 | Update() 36 | if err != nil { 37 | g.Log().Errorf(ctx, "SaveChunksData update failed for chunk_id=%s, err=%+v", chunk.ChunkId, err) 38 | status = int(v1.StatusFailed) 39 | } 40 | } else { 41 | // 不存在,插入(id 设为 0 让数据库自动分配) 42 | chunk.Id = 0 43 | _, err = dao.KnowledgeChunks.Ctx(ctx).Data(chunk).OmitEmpty().Insert() 44 | if err != nil { 45 | g.Log().Errorf(ctx, "SaveChunksData insert failed for chunk_id=%s, err=%+v", chunk.ChunkId, err) 46 | status = int(v1.StatusFailed) 47 | } 48 | } 49 | } 50 | 51 | UpdateDocumentsStatus(ctx, documentsId, status) 52 | return nil 53 | } 54 | 55 | // GetChunksList 查询知识块列表 56 | func GetChunksList(ctx context.Context, where entity.KnowledgeChunks, page, size int) (list []entity.KnowledgeChunks, total int, err error) { 57 | model := dao.KnowledgeChunks.Ctx(ctx) 58 | 59 | // 构建查询条件 60 | if where.KnowledgeDocId != 0 { 61 | model = model.Where("knowledge_doc_id", where.KnowledgeDocId) 62 | } 63 | if where.ChunkId != "" { 64 | model = model.Where("chunk_id", where.ChunkId) 65 | } 66 | 67 | // 获取总数 68 | total, err = model.Count() 69 | if err != nil { 70 | return 71 | } 72 | 73 | // 分页查询 74 | if page > 0 && size > 0 { 75 | model = model.Page(page, size) 76 | } 77 | 78 | // 按创建时间倒序 79 | model = model.OrderDesc("created_at") 80 | 81 | err = model.Scan(&list) 82 | return 83 | } 84 | 85 | // GetChunkById 根据ID查询单个知识块 86 | func GetChunkById(ctx context.Context, id int64) (chunk entity.KnowledgeChunks, err error) { 87 | err = dao.KnowledgeChunks.Ctx(ctx).Where("id", id).Scan(&chunk) 88 | return 89 | } 90 | 91 | // DeleteChunkByIds 根据ID软删除知识块 92 | func DeleteChunkById(ctx context.Context, id int64) error { 93 | _, err := dao.KnowledgeChunks.Ctx(ctx).Where("id", id).Delete() 94 | return err 95 | } 96 | 97 | // UpdateChunkById 根据ID更新知识块 98 | func UpdateChunkByIds(ctx context.Context, ids []int64, data entity.KnowledgeChunks) error { 99 | model := dao.KnowledgeChunks.Ctx(ctx).WhereIn("id", ids) 100 | if data.Content != "" { 101 | model = 
model.Data("content", data.Content) 102 | } 103 | if data.Status != 0 { 104 | model = model.Data("status", data.Status) 105 | } 106 | _, err := model.Update() 107 | return err 108 | } 109 | 110 | // GetAllChunksByDocId gets all chunks by document id 111 | func GetAllChunksByDocId(ctx context.Context, docId int64, fields ...string) (list []entity.KnowledgeChunks, err error) { 112 | model := dao.KnowledgeChunks.Ctx(ctx).Where("knowledge_doc_id", docId) 113 | if len(fields) > 0 { 114 | for _, field := range fields { 115 | model = model.Fields(field) 116 | } 117 | } 118 | err = model.Scan(&list) 119 | return 120 | } 121 | -------------------------------------------------------------------------------- /server/internal/logic/chat/chat.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | 9 | "github.com/cloudwego/eino-ext/components/model/openai" 10 | "github.com/cloudwego/eino/components/model" 11 | "github.com/cloudwego/eino/schema" 12 | "github.com/gogf/gf/v2/frame/g" 13 | "github.com/gogf/gf/v2/os/gctx" 14 | "github.com/wangle201210/chat-history/eino" 15 | "github.com/wangle201210/go-rag/server/internal/dao" 16 | ) 17 | 18 | var chat *Chat 19 | 20 | type Chat struct { 21 | cm model.BaseChatModel 22 | eh *eino.History 23 | } 24 | 25 | func GetChat() *Chat { 26 | return chat 27 | } 28 | 29 | // 暂时用不上chat功能,先不init 30 | func init() { 31 | ctx := gctx.New() 32 | c, err := newChat(&openai.ChatModelConfig{ 33 | APIKey: g.Cfg().MustGet(ctx, "chat.apiKey").String(), 34 | BaseURL: g.Cfg().MustGet(ctx, "chat.baseURL").String(), 35 | Model: g.Cfg().MustGet(ctx, "chat.model").String(), 36 | }) 37 | if err != nil { 38 | g.Log().Fatalf(ctx, "newChat failed, err=%v", err) 39 | return 40 | } 41 | // 使用 DSN 初始化,chat-history 包会根据 DSN 判断数据库类型 42 | // 对于 SQLite: file.db?_journal_mode=WAL 43 | // 对于 MySQL: user:pass@tcp(host:port)/dbname?charset=utf8mb4 44 | c.eh = 
eino.NewEinoHistory(dao.GetDsn()) 45 | chat = c 46 | } 47 | 48 | func newChat(cfg *openai.ChatModelConfig) (res *Chat, err error) { 49 | chatModel, err := openai.NewChatModel(context.Background(), cfg) 50 | if err != nil { 51 | return nil, err 52 | } 53 | return &Chat{cm: chatModel}, nil 54 | } 55 | 56 | func (x *Chat) GetAnswer(ctx context.Context, convID string, docs []*schema.Document, question string) (answer string, err error) { 57 | messages, err := x.docsMessages(ctx, convID, docs, question) 58 | if err != nil { 59 | return "", err 60 | } 61 | result, err := generate(ctx, x.cm, messages) 62 | if err != nil { 63 | return "", fmt.Errorf("生成答案失败: %w", err) 64 | } 65 | err = x.eh.SaveMessage(result, convID) 66 | if err != nil { 67 | g.Log().Error(ctx, "save assistant message err: %v", err) 68 | return 69 | } 70 | return result.Content, nil 71 | } 72 | 73 | // GetAnswerStream 流式生成答案 74 | func (x *Chat) GetAnswerStream(ctx context.Context, convID string, docs []*schema.Document, question string) (answer *schema.StreamReader[*schema.Message], err error) { 75 | messages, err := x.docsMessages(ctx, convID, docs, question) 76 | if err != nil { 77 | return 78 | } 79 | ctx = context.Background() 80 | streamData, err := stream(ctx, x.cm, messages) 81 | if err != nil { 82 | return nil, fmt.Errorf("生成答案失败: %w", err) 83 | } 84 | srs := streamData.Copy(2) 85 | go func() { 86 | // for save 87 | fullMsgs := make([]*schema.Message, 0) 88 | defer func() { 89 | srs[1].Close() 90 | fullMsg, err := schema.ConcatMessages(fullMsgs) 91 | if err != nil { 92 | g.Log().Error(ctx, "error concatenating messages: %v", err) 93 | return 94 | } 95 | err = x.eh.SaveMessage(fullMsg, convID) 96 | if err != nil { 97 | g.Log().Error(ctx, "save assistant message err: %v", err) 98 | return 99 | } 100 | }() 101 | outer: 102 | for { 103 | select { 104 | case <-ctx.Done(): 105 | fmt.Println("context done", ctx.Err()) 106 | return 107 | default: 108 | chunk, err := srs[1].Recv() 109 | if err != nil { 110 
| if errors.Is(err, io.EOF) { 111 | break outer 112 | } 113 | } 114 | fullMsgs = append(fullMsgs, chunk) 115 | } 116 | } 117 | }() 118 | 119 | return srs[0], nil 120 | 121 | } 122 | 123 | func generate(ctx context.Context, llm model.BaseChatModel, in []*schema.Message) (message *schema.Message, err error) { 124 | message, err = llm.Generate(ctx, in) 125 | if err != nil { 126 | err = fmt.Errorf("llm generate failed: %v", err) 127 | return 128 | } 129 | return 130 | } 131 | 132 | func stream(ctx context.Context, llm model.BaseChatModel, in []*schema.Message) (res *schema.StreamReader[*schema.Message], err error) { 133 | res, err = llm.Stream(ctx, in) 134 | if err != nil { 135 | err = fmt.Errorf("llm generate failed: %v", err) 136 | return 137 | } 138 | return 139 | } 140 | -------------------------------------------------------------------------------- /server/internal/logic/knowledge/documents.go: -------------------------------------------------------------------------------- 1 | package knowledge 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/gogf/gf/v2/database/gdb" 8 | "github.com/gogf/gf/v2/frame/g" 9 | "github.com/wangle201210/go-rag/server/internal/dao" 10 | "github.com/wangle201210/go-rag/server/internal/model/entity" 11 | ) 12 | 13 | const ( 14 | defaultPageSize = 10 15 | maxPageSize = 100 16 | ) 17 | 18 | // SaveDocumentsInfo 保存文档信息 19 | func SaveDocumentsInfo(ctx context.Context, documents entity.KnowledgeDocuments) (id int64, err error) { 20 | // 确保 ID 为 0,让数据库自动分配 21 | documents.Id = 0 22 | 23 | // OmitEmpty 会忽略零值字段(包括 ID=0),让数据库自动分配 ID 24 | // 这样可以兼容 MySQL 和 SQLite 的自增主键 25 | result, err := dao.KnowledgeDocuments.Ctx(ctx).Data(documents).OmitEmpty().Insert() 26 | if err != nil { 27 | g.Log().Errorf(ctx, "保存文档信息失败: %+v, 错误: %v", documents, err) 28 | return 0, fmt.Errorf("保存文档信息失败: %w", err) 29 | } 30 | 31 | id, err = result.LastInsertId() 32 | if err != nil { 33 | return 0, fmt.Errorf("获取插入ID失败: %w", err) 34 | } 35 | g.Log().Infof(ctx, 
"文档保存成功, ID: %d", id) 36 | return id, nil 37 | } 38 | 39 | // UpdateDocumentsStatus 更新文档状态 40 | func UpdateDocumentsStatus(ctx context.Context, documentsId int64, status int) error { 41 | data := g.Map{ 42 | "status": status, 43 | } 44 | 45 | _, err := dao.KnowledgeDocuments.Ctx(ctx).Where("id", documentsId).Data(data).Update() 46 | if err != nil { 47 | g.Log().Errorf(ctx, "更新文档状态失败: ID=%d, 错误: %v", documentsId, err) 48 | } 49 | 50 | return err 51 | } 52 | 53 | // GetDocumentById 根据ID获取文档信息 54 | func GetDocumentById(ctx context.Context, id int64) (document entity.KnowledgeDocuments, err error) { 55 | g.Log().Debugf(ctx, "获取文档信息: ID=%d", id) 56 | 57 | err = dao.KnowledgeDocuments.Ctx(ctx).Where("id", id).Scan(&document) 58 | if err != nil { 59 | g.Log().Errorf(ctx, "获取文档信息失败: ID=%d, 错误: %v", id, err) 60 | return document, fmt.Errorf("获取文档信息失败: %w", err) 61 | } 62 | 63 | return document, nil 64 | } 65 | 66 | // GetDocumentsList 获取文档列表 67 | func GetDocumentsList(ctx context.Context, where entity.KnowledgeDocuments, page int, pageSize int) (documents []entity.KnowledgeDocuments, total int, err error) { 68 | // 参数验证和默认值设置 69 | if page < 1 { 70 | page = 1 71 | } 72 | if pageSize < 1 { 73 | pageSize = defaultPageSize 74 | } 75 | if pageSize > maxPageSize { 76 | pageSize = maxPageSize 77 | } 78 | 79 | model := dao.KnowledgeDocuments.Ctx(ctx) 80 | if where.KnowledgeBaseName != "" { 81 | model = model.Where("knowledge_base_name", where.KnowledgeBaseName) 82 | } 83 | 84 | total, err = model.Count() 85 | if err != nil { 86 | g.Log().Errorf(ctx, "获取文档总数失败: %v", err) 87 | return nil, 0, fmt.Errorf("获取文档总数失败: %w", err) 88 | } 89 | 90 | if total == 0 { 91 | return nil, 0, nil 92 | } 93 | 94 | err = model.Page(page, pageSize). 95 | Order("created_at desc"). 
96 | Scan(&documents) 97 | if err != nil { 98 | g.Log().Errorf(ctx, "获取文档列表失败: %v", err) 99 | return nil, 0, fmt.Errorf("获取文档列表失败: %w", err) 100 | } 101 | 102 | return documents, total, nil 103 | } 104 | 105 | // DeleteDocument 删除文档及其相关数据 106 | func DeleteDocument(ctx context.Context, id int64) error { 107 | g.Log().Debugf(ctx, "删除文档: ID=%d", id) 108 | 109 | return dao.KnowledgeDocuments.Ctx(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { 110 | // 先删除文档块 111 | _, err := dao.KnowledgeChunks.Ctx(ctx).TX(tx).Where("knowledge_doc_id", id).Delete() 112 | if err != nil { 113 | g.Log().Errorf(ctx, "删除文档块失败: ID=%d, 错误: %v", id, err) 114 | return fmt.Errorf("删除文档块失败: %w", err) 115 | } 116 | 117 | // 再删除文档 118 | result, err := dao.KnowledgeDocuments.Ctx(ctx).TX(tx).Where("id", id).Delete() 119 | if err != nil { 120 | g.Log().Errorf(ctx, "删除文档失败: ID=%d, 错误: %v", id, err) 121 | return fmt.Errorf("删除文档失败: %w", err) 122 | } 123 | 124 | affected, err := result.RowsAffected() 125 | if err != nil { 126 | return fmt.Errorf("获取影响行数失败: %w", err) 127 | } 128 | if affected == 0 { 129 | return fmt.Errorf("文档不存在") 130 | } 131 | 132 | g.Log().Infof(ctx, "文档删除成功: ID=%d", id) 133 | return nil 134 | }) 135 | } 136 | -------------------------------------------------------------------------------- /fe/public/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 7 | 8 | 10 | 35 | 40 | 42 | 47 | 49 | 60 | 61 | 62 | --------------------------------------------------------------------------------