20 |
21 | {{ msg }}
22 |
23 |
24 |
25 | See
26 | element-plus for more
27 | information.
28 |
29 |
30 |
31 |
32 |
33 | El Message
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | count is: {{ count }}
42 |
43 |
44 | count is: {{ count }}
45 |
46 |
47 | count is: {{ count }}
48 |
49 |
50 | count is: {{ count }}
51 |
52 |
53 | count is: {{ count }}
54 |
55 |
56 | count is: {{ count }}
57 |
58 |
59 |
60 |
61 |
62 | Tag 1
63 |
64 |
65 | Tag 1
66 |
67 |
68 | Tag 1
69 |
70 |
71 | Tag 1
72 |
73 |
74 |
75 |
76 |
77 |
82 |
83 |
84 |
85 |
86 |
92 |
93 |
94 | For example, we can customize the primary color to 'green'.
95 |
96 |
97 | Edit
98 | components/HelloWorld.vue to test components.
99 |
100 |
101 | Edit
102 | styles/element/var.scss to test scss variables.
103 |
104 |
105 |
106 | Full Example:
107 | element-plus-vite-starter
111 | | On demand Example:
112 | unplugin-element-plus/examples/vite
116 |
117 |
118 |
119 |
128 |
--------------------------------------------------------------------------------
/server/core/retriever/retriever.go:
--------------------------------------------------------------------------------
1 | package retriever
2 |
3 | import (
4 | "context"
5 | "fmt"
6 |
7 | "github.com/bytedance/sonic"
8 | "github.com/cloudwego/eino-ext/components/retriever/es8"
9 | "github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
10 | "github.com/cloudwego/eino/components/retriever"
11 | "github.com/cloudwego/eino/schema"
12 | "github.com/elastic/go-elasticsearch/v8/typedapi/types"
13 | "github.com/wangle201210/go-rag/server/core/common"
14 | "github.com/wangle201210/go-rag/server/core/config"
15 | coretypes "github.com/wangle201210/go-rag/server/core/types"
16 | )
17 |
18 | // newRetriever component initialization function of node 'Retriever1' in graph 'retriever'
19 | func newRetriever(ctx context.Context, conf *config.Config) (rtr retriever.Retriever, err error) {
20 | vectorField := coretypes.FieldContentVector
21 | if value, ok := ctx.Value(coretypes.RetrieverFieldKey).(string); ok {
22 | vectorField = value
23 | }
24 |
25 | embeddingIns, err := common.NewEmbedding(ctx, conf)
26 | if err != nil {
27 | return nil, err
28 | }
29 |
30 | // 根据客户端类型创建不同的 retriever
31 | if conf.Client != nil {
32 | // ES retriever
33 | retrieverConfig := &es8.RetrieverConfig{
34 | Client: conf.Client,
35 | Index: conf.IndexName,
36 | SearchMode: search_mode.SearchModeDenseVectorSimilarity(
37 | search_mode.DenseVectorSimilarityTypeCosineSimilarity,
38 | vectorField,
39 | ),
40 | ResultParser: EsHit2Document,
41 | Embedding: embeddingIns,
42 | }
43 | rtr, err = es8.NewRetriever(ctx, retrieverConfig)
44 | if err != nil {
45 | return nil, err
46 | }
47 | return rtr, nil
48 | } else if conf.QdrantClient != nil {
49 | // Qdrant retriever
50 | rtr, err = NewQdrantRetriever(ctx, &QdrantRetrieverConfig{
51 | Client: conf.QdrantClient,
52 | Collection: conf.IndexName,
53 | Embedding: embeddingIns,
54 | VectorField: vectorField,
55 | })
56 | if err != nil {
57 | return nil, err
58 | }
59 | return rtr, nil
60 | }
61 |
62 | return nil, fmt.Errorf("no valid client configuration found")
63 | }
64 |
65 | func EsHit2Document(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
66 | doc = &schema.Document{
67 | ID: *hit.Id_,
68 | MetaData: map[string]any{},
69 | }
70 |
71 | var src map[string]any
72 | if err = sonic.Unmarshal(hit.Source_, &src); err != nil {
73 | return nil, err
74 | }
75 |
76 | for field, val := range src {
77 | switch field {
78 | case coretypes.FieldContent:
79 | doc.Content = val.(string)
80 | case coretypes.FieldContentVector:
81 | var v []float64
82 | for _, item := range val.([]interface{}) {
83 | v = append(v, item.(float64))
84 | }
85 | doc.WithDenseVector(v)
86 | case coretypes.FieldQAContentVector, coretypes.FieldQAContent:
87 | // 这两个字段都不返回
88 |
89 | case coretypes.FieldExtra:
90 | if val == nil {
91 | continue
92 | }
93 | doc.MetaData[coretypes.FieldExtra] = val.(string)
94 | case coretypes.KnowledgeName:
95 | doc.MetaData[coretypes.KnowledgeName] = val.(string)
96 | default:
97 | return nil, fmt.Errorf("unexpected field=%s, val=%v", field, val)
98 | }
99 | }
100 |
101 | if hit.Score_ != nil {
102 | doc.WithScore(float64(*hit.Score_))
103 | }
104 |
105 | return doc, nil
106 | }
107 |
--------------------------------------------------------------------------------
/server/internal/dao/dao.go:
--------------------------------------------------------------------------------
1 | package dao
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "os"
7 | "path/filepath"
8 |
9 | "github.com/gogf/gf/v2/frame/g"
10 | database "github.com/wangle201210/go-rag/server/internal/dao/db"
11 | )
12 |
13 | var db database.Database
14 |
15 | func init() {
16 | err := InitDB()
17 | if err != nil {
18 | g.Log().Fatal(context.Background(), "database connection not initialized, err %v", err)
19 | }
20 | }
21 |
22 | // InitDB 初始化数据库连接
23 | func InitDB() error {
24 | ctx := context.Background()
25 | dbType := g.Cfg().MustGet(ctx, "database.default.type", "mysql").String()
26 | var (
27 | cfg *database.Config
28 | err error
29 | )
30 | switch dbType {
31 | case "mysql", "":
32 | cfg, _ = GetMysqlConfig()
33 | db = database.NewMysql(cfg)
34 | case "sqlite":
35 | cfg, err = GetSqliteConfig()
36 | if err != nil {
37 | return err
38 | }
39 | db = database.NewSqlite(cfg)
40 | default:
41 | return fmt.Errorf("unsupported database type: %s", dbType)
42 | }
43 | if err = db.Connect(); err != nil {
44 | return err
45 | }
46 | if err = db.Ping(); err != nil {
47 | return err
48 | }
49 | // 迁移表结构
50 | return db.AutoMigrate()
51 | }
52 |
53 | // ensureSQLiteFileDir 确保 SQLite 文件的父目录存在
54 | func ensureSQLiteFileDir(filePath string) error {
55 | // 获取文件的目录路径
56 | dir := filepath.Dir(filePath)
57 |
58 | // 检查目录是否存在
59 | if _, err := os.Stat(dir); os.IsNotExist(err) {
60 | // 目录不存在,创建目录(包括所有必要的父目录)
61 | if err := os.MkdirAll(dir, 0755); err != nil {
62 | return fmt.Errorf("failed to create directory %s: %v", dir, err)
63 | }
64 | g.Log().Infof(context.Background(), "Created directory for SQLite database: %s", dir)
65 | }
66 |
67 | return nil
68 | }
69 |
70 | func GetDsn() string {
71 | return db.DSN()
72 | }
73 |
74 | func GetSqliteConfig() (*database.Config, error) {
75 | ctx := context.Background()
76 | cfg := &database.Config{}
77 | cfg.FilePath = g.Cfg().MustGet(ctx, "database.default.host").String()
78 | if cfg.FilePath == "" {
79 | return nil, fmt.Errorf("sqlite file path is required when using sqlite")
80 | }
81 | if err := ensureSQLiteFileDir(cfg.FilePath); err != nil {
82 | return nil, fmt.Errorf("failed to create sqlite file directory: %v", err)
83 | }
84 | cfg.BusyTimeout = g.Cfg().MustGet(ctx, "database.default.busy_timeout").Int()
85 | cfg.JournalMode = g.Cfg().MustGet(ctx, "database.default.journal_mode").String()
86 | cfg.Synchronous = g.Cfg().MustGet(ctx, "database.default.synchronous").String()
87 | cfg.CacheSize = g.Cfg().MustGet(ctx, "database.default.cache_size").Int()
88 | cfg.MaxOpenConn = g.Cfg().MustGet(ctx, "database.default.max_open_conns", 1).Int()
89 | cfg.MaxIdleConn = g.Cfg().MustGet(ctx, "database.default.max_idle_conns", 1).Int()
90 | cfg.LogLevel = 4
91 | return cfg, nil
92 | }
93 |
94 | func GetMysqlConfig() (*database.Config, error) {
95 | cfg := g.DB().GetConfig()
96 | c := &database.Config{
97 | Host: cfg.Host,
98 | Port: cfg.Port,
99 | User: cfg.User,
100 | Password: cfg.Pass,
101 | Database: cfg.Name,
102 | Charset: cfg.Charset,
103 | MaxOpenConn: cfg.MaxOpenConnCount,
104 | MaxIdleConn: cfg.MaxIdleConnCount,
105 | LogLevel: 4,
106 | }
107 | if c.MaxIdleConn == 0 {
108 | c.MaxIdleConn = 10
109 | }
110 | if c.MaxOpenConn == 0 {
111 | c.MaxOpenConn = 100
112 | }
113 | return c, nil
114 | }
115 |
--------------------------------------------------------------------------------
/fe/src/components.d.ts:
--------------------------------------------------------------------------------
1 | /* eslint-disable */
2 | // @ts-nocheck
3 | // Generated by unplugin-vue-components
4 | // Read more: https://github.com/vuejs/core/pull/3399
5 | export {}
6 |
/* prettier-ignore */
// Augments the 'vue' module so templates get type information for globally
// available components and custom properties.
declare module 'vue' {
  // Maps each component name usable in templates to the type of the export
  // that provides it (local .vue files, element-plus, or vue-router).
  export interface GlobalComponents {
    BaseHeader: typeof import('./components/layouts/BaseHeader.vue')['default']
    BaseSide: typeof import('./components/layouts/BaseSide.vue')['default']
    ElAlert: typeof import('element-plus/es')['ElAlert']
    ElAvatar: typeof import('element-plus/es')['ElAvatar']
    ElButton: typeof import('element-plus/es')['ElButton']
    ElCard: typeof import('element-plus/es')['ElCard']
    ElCol: typeof import('element-plus/es')['ElCol']
    ElCollapse: typeof import('element-plus/es')['ElCollapse']
    ElCollapseItem: typeof import('element-plus/es')['ElCollapseItem']
    ElCollapseTransition: typeof import('element-plus/es')['ElCollapseTransition']
    ElConfigProvider: typeof import('element-plus/es')['ElConfigProvider']
    ElDescriptions: typeof import('element-plus/es')['ElDescriptions']
    ElDescriptionsItem: typeof import('element-plus/es')['ElDescriptionsItem']
    ElDialog: typeof import('element-plus/es')['ElDialog']
    ElDivider: typeof import('element-plus/es')['ElDivider']
    ElEmpty: typeof import('element-plus/es')['ElEmpty']
    ElForm: typeof import('element-plus/es')['ElForm']
    ElFormItem: typeof import('element-plus/es')['ElFormItem']
    ElIcon: typeof import('element-plus/es')['ElIcon']
    ElInput: typeof import('element-plus/es')['ElInput']
    ElInputNumber: typeof import('element-plus/es')['ElInputNumber']
    ElMenu: typeof import('element-plus/es')['ElMenu']
    ElMenuItem: typeof import('element-plus/es')['ElMenuItem']
    ElOption: typeof import('element-plus/es')['ElOption']
    ElPageHeader: typeof import('element-plus/es')['ElPageHeader']
    ElPagination: typeof import('element-plus/es')['ElPagination']
    ElPopover: typeof import('element-plus/es')['ElPopover']
    ElRadio: typeof import('element-plus/es')['ElRadio']
    ElRadioGroup: typeof import('element-plus/es')['ElRadioGroup']
    ElRow: typeof import('element-plus/es')['ElRow']
    ElScrollbar: typeof import('element-plus/es')['ElScrollbar']
    ElSelect: typeof import('element-plus/es')['ElSelect']
    ElSkeleton: typeof import('element-plus/es')['ElSkeleton']
    ElSlider: typeof import('element-plus/es')['ElSlider']
    ElSpace: typeof import('element-plus/es')['ElSpace']
    ElTable: typeof import('element-plus/es')['ElTable']
    ElTableColumn: typeof import('element-plus/es')['ElTableColumn']
    ElTag: typeof import('element-plus/es')['ElTag']
    ElTooltip: typeof import('element-plus/es')['ElTooltip']
    ElUpload: typeof import('element-plus/es')['ElUpload']
    HelloWorld: typeof import('./components/HelloWorld.vue')['default']
    KnowledgeSelector: typeof import('./components/KnowledgeSelector.vue')['default']
    Logos: typeof import('./components/Logos.vue')['default']
    MessageBoxDemo: typeof import('./components/MessageBoxDemo.vue')['default']
    RouterLink: typeof import('vue-router')['RouterLink']
    RouterView: typeof import('vue-router')['RouterView']
  }
  // Types the vLoading custom property as element-plus's ElLoadingDirective.
  export interface ComponentCustomProperties {
    vLoading: typeof import('element-plus/es')['ElLoadingDirective']
  }
}
61 |
--------------------------------------------------------------------------------
/server/internal/mcp/indexer.go:
--------------------------------------------------------------------------------
1 | package mcp
2 |
3 | import (
4 | "context"
5 | "encoding/base64"
6 | "fmt"
7 |
8 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
9 | "github.com/gogf/gf/v2/frame/g"
10 | "github.com/gogf/gf/v2/os/gctx"
11 | gorag "github.com/wangle201210/go-rag/server/core"
12 | "github.com/wangle201210/go-rag/server/internal/logic/rag"
13 | )
14 |
// IndexParam is the argument schema of the "Indexer_by_filepath" tool.
type IndexParam struct {
	// URI may be a file path (pdf, html, md, ...) or a web address.
	URI string `json:"uri" description:"文件路径" required:"true"`
	// KnowledgeName is the name of the knowledge base the content is indexed into.
	KnowledgeName string `json:"knowledge_name" description:"知识库名字,请先通过getKnowledgeBaseList获取列表后判断是否有符合的知识库,如果没有则根据用户提示词自己生成" required:"true"`
}
19 |
20 | func GetIndexerByFilePathTool() *protocol.Tool {
21 | tool, err := protocol.NewTool("Indexer_by_filepath", "通过文件路径进行文本嵌入", IndexParam{})
22 | if err != nil {
23 | g.Log().Errorf(gctx.New(), "Failed to create tool: %v", err)
24 | return nil
25 | }
26 | return tool
27 | }
28 |
29 | func HandleIndexerByFilePath(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
30 | var reqData IndexParam
31 | if err := protocol.VerifyAndUnmarshal(req.RawArguments, &reqData); err != nil {
32 | return nil, err
33 | }
34 | svr := rag.GetRagSvr()
35 | uri := reqData.URI
36 | indexReq := &gorag.IndexReq{
37 | URI: uri,
38 | KnowledgeName: reqData.KnowledgeName,
39 | }
40 | ids, err := svr.Index(ctx, indexReq)
41 | if err != nil {
42 | return nil, err
43 | }
44 | msg := fmt.Sprintf("index file %s successfully, knowledge_name: %s, doc_ids: %v", uri, reqData.KnowledgeName, ids)
45 | return &protocol.CallToolResult{
46 | Content: []protocol.Content{
47 | &protocol.TextContent{
48 | Type: "text",
49 | Text: msg,
50 | },
51 | },
52 | }, nil
53 | }
54 |
// IndexFileParam is the argument schema of the
// "Indexer_by_base64_file_content" tool.
type IndexFileParam struct {
	// Filename is the name of the uploaded file.
	Filename string `json:"filename" description:"文件名字" required:"true"`
	// Content is the file content encoded as base64 (not a path or URL).
	Content string `json:"content" description:"被base64编码后的文件内容,先调用工具获取base64信息" required:"true"`
	// KnowledgeName is the name of the knowledge base the content is indexed into.
	KnowledgeName string `json:"knowledge_name" description:"知识库名字,请先通过getKnowledgeBaseList获取列表后判断是否有符合的知识库,如果没有则根据用户提示词自己生成" required:"true"`
}
60 |
61 | func GetIndexerByFileBase64ContentTool() *protocol.Tool {
62 | tool, err := protocol.NewTool("Indexer_by_base64_file_content", "获取文件base64信息后上传,然后对内容进行文本嵌入", IndexFileParam{})
63 | if err != nil {
64 | g.Log().Errorf(gctx.New(), "Failed to create tool: %v", err)
65 | return nil
66 | }
67 | return tool
68 | }
69 |
70 | func HandleIndexerByFileBase64Content(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
71 | var reqData IndexFileParam
72 | if err := protocol.VerifyAndUnmarshal(req.RawArguments, &reqData); err != nil {
73 | return nil, err
74 | }
75 | // svr := rag.GetRagSvr()
76 | decoded, err := base64.StdEncoding.DecodeString(reqData.Content)
77 | if err != nil {
78 | return nil, err
79 | }
80 | fmt.Println(decoded)
81 | // indexReq := &gorag.IndexReq{
82 | // URI: uri,
83 | // KnowledgeName: reqData.KnowledgeName,
84 | // }
85 | // ids, err := svr.Index(ctx, indexReq)
86 | // if err != nil {
87 | // return nil, err
88 | // }
89 | // msg := fmt.Sprintf("index file %s successfully, knowledge_name: %s, doc_ids: %v", uri, reqData.KnowledgeName, ids)
90 | return &protocol.CallToolResult{
91 | Content: []protocol.Content{
92 | &protocol.TextContent{
93 | Type: "text",
94 | Text: string(decoded),
95 | },
96 | },
97 | }, nil
98 | }
99 |
--------------------------------------------------------------------------------
/server/core/indexer/merge.go:
--------------------------------------------------------------------------------
1 | package indexer
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strings"
7 |
8 | "github.com/bytedance/sonic"
9 | "github.com/cloudwego/eino-ext/components/document/loader/file"
10 | "github.com/cloudwego/eino/schema"
11 | "github.com/google/uuid"
12 | coretypes "github.com/wangle201210/go-rag/server/core/types"
13 | )
14 |
15 | // docAddIDAndMerge component initialization function of node 'Lambda1' in graph 't'
16 | func docAddIDAndMerge(ctx context.Context, docs []*schema.Document) (output []*schema.Document, err error) {
17 | if len(docs) == 0 {
18 | return docs, nil
19 | }
20 | for _, doc := range docs {
21 | doc.ID = uuid.New().String() // 覆盖之前的id
22 | }
23 | switch docs[0].MetaData[file.MetaKeyExtension] {
24 | case ".md":
25 | return mergeMD(ctx, docs)
26 | case ".xlsx":
27 | return mergeXLSX(ctx, docs)
28 | default:
29 | return docs, nil
30 | }
31 | }
32 |
33 | func mergeMD(ctx context.Context, docs []*schema.Document) (output []*schema.Document, err error) {
34 | ndocs := make([]*schema.Document, 0, len(docs))
35 | var nd *schema.Document
36 | maxLen := 512
37 | for _, doc := range docs {
38 | // 不是同一个文件的就不要放一起了
39 | if nd != nil && doc.MetaData[file.MetaKeySource] != nd.MetaData[file.MetaKeySource] {
40 | ndocs = append(ndocs, nd)
41 | nd = nil
42 | }
43 | // 两个文档长度之和大于maxLen就不要放一起了
44 | if nd != nil && len(nd.Content)+len(doc.Content) > maxLen {
45 | ndocs = append(ndocs, nd)
46 | nd = nil
47 | }
48 | // 不是同一个一级标题的就不要放一起了
49 | if nd != nil && doc.MetaData[coretypes.Title1] != nd.MetaData[coretypes.Title1] {
50 | ndocs = append(ndocs, nd)
51 | nd = nil
52 | }
53 | // 不是同一个二级标题的就不要放一起了
54 | // 如果nd的h2是nil,证明之前只有h1,且两个的h1相等,则直接合并
55 | if nd != nil && nd.MetaData[coretypes.Title2] != nil && doc.MetaData[coretypes.Title2] != nd.MetaData[coretypes.Title2] {
56 | ndocs = append(ndocs, nd)
57 | nd = nil
58 | }
59 | if nd == nil {
60 | nd = doc
61 | } else {
62 | mergeTitle(nd, doc, coretypes.Title2)
63 | mergeTitle(nd, doc, coretypes.Title3)
64 | nd.Content += doc.Content
65 | }
66 | }
67 | if nd != nil {
68 | ndocs = append(ndocs, nd)
69 | }
70 | for _, ndoc := range ndocs {
71 | ndoc.Content = getMdContentWithTitle(ndoc)
72 | }
73 | return ndocs, nil
74 | }
75 |
76 | func mergeXLSX(ctx context.Context, docs []*schema.Document) (output []*schema.Document, err error) {
77 | for _, doc := range docs {
78 | marshal, _ := sonic.Marshal(doc.MetaData[coretypes.XlsxRow])
79 | doc.Content = string(marshal)
80 | }
81 | return docs, nil
82 | }
83 |
84 | func mergeTitle(orgDoc, addDoc *schema.Document, key string) {
85 | // 相等就不管了
86 | if orgDoc.MetaData[key] == addDoc.MetaData[key] {
87 | return
88 | }
89 | var title []string
90 | if orgDoc.MetaData[key] != nil {
91 | title = append(title, orgDoc.MetaData[key].(string))
92 | }
93 | if addDoc.MetaData[key] != nil {
94 | title = append(title, addDoc.MetaData[key].(string))
95 | }
96 | if len(title) > 0 {
97 | orgDoc.MetaData[key] = strings.Join(title, ",")
98 | }
99 | }
100 |
101 | func getMdContentWithTitle(doc *schema.Document) string {
102 | if doc.MetaData == nil {
103 | return doc.Content
104 | }
105 | title := ""
106 | list := []string{"h1", "h2", "h3", "h4", "h5", "h6"}
107 | for _, v := range list {
108 | if d, e := doc.MetaData[v].(string); e && len(d) > 0 {
109 | title += fmt.Sprintf("%s:%s ", v, d)
110 | }
111 | }
112 | if len(title) == 0 {
113 | return doc.Content
114 | }
115 | return title + "\n" + doc.Content
116 | }
117 |
--------------------------------------------------------------------------------
/server/internal/dao/internal/knowledge_chunks.go:
--------------------------------------------------------------------------------
1 | // ==========================================================================
2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
3 | // ==========================================================================
4 |
5 | package internal
6 |
7 | import (
8 | "context"
9 |
10 | "github.com/gogf/gf/v2/database/gdb"
11 | "github.com/gogf/gf/v2/frame/g"
12 | )
13 |
// KnowledgeChunksDao is the data access object for the table knowledge_chunks.
type KnowledgeChunksDao struct {
	table    string                 // table is the underlying table name of the DAO.
	group    string                 // group is the database configuration group name of the current DAO.
	columns  KnowledgeChunksColumns // columns contains all the column names of Table for convenient usage.
	handlers []gdb.ModelHandler     // handlers for customized model modification.
}

// KnowledgeChunksColumns defines and stores column names for the table knowledge_chunks.
type KnowledgeChunksColumns struct {
	Id             string // Name of the "id" column.
	KnowledgeDocId string // Name of the "knowledge_doc_id" column.
	ChunkId        string // Name of the "chunk_id" column.
	Content        string // Name of the "content" column.
	Ext            string // Name of the "ext" column.
	Status         string // Name of the "status" column.
	CreatedAt      string // Name of the "created_at" column.
	UpdatedAt      string // Name of the "updated_at" column.
}

// knowledgeChunksColumns holds the columns for the table knowledge_chunks.
var knowledgeChunksColumns = KnowledgeChunksColumns{
	Id:             "id",
	KnowledgeDocId: "knowledge_doc_id",
	ChunkId:        "chunk_id",
	Content:        "content",
	Ext:            "ext",
	Status:         "status",
	CreatedAt:      "created_at",
	UpdatedAt:      "updated_at",
}

// NewKnowledgeChunksDao creates and returns a new DAO object for table data access.
// The DAO is bound to the "default" configuration group and the
// "knowledge_chunks" table; the optional handlers are applied to every model
// created by Ctx.
func NewKnowledgeChunksDao(handlers ...gdb.ModelHandler) *KnowledgeChunksDao {
	return &KnowledgeChunksDao{
		group:    "default",
		table:    "knowledge_chunks",
		columns:  knowledgeChunksColumns,
		handlers: handlers,
	}
}

// DB retrieves and returns the underlying raw database management object of the current DAO.
func (dao *KnowledgeChunksDao) DB() gdb.DB {
	return g.DB(dao.group)
}

// Table returns the table name of the current DAO.
func (dao *KnowledgeChunksDao) Table() string {
	return dao.table
}

// Columns returns all column names of the current DAO.
func (dao *KnowledgeChunksDao) Columns() KnowledgeChunksColumns {
	return dao.columns
}

// Group returns the database configuration group name of the current DAO.
func (dao *KnowledgeChunksDao) Group() string {
	return dao.group
}

// Ctx creates and returns a Model for the current DAO. It automatically sets the context for the current operation.
// The registered handlers are applied in order before Safe and Ctx are called
// on the resulting model.
func (dao *KnowledgeChunksDao) Ctx(ctx context.Context) *gdb.Model {
	model := dao.DB().Model(dao.table)
	for _, handler := range dao.handlers {
		model = handler(model)
	}
	return model.Safe().Ctx(ctx)
}

// Transaction wraps the transaction logic using function f.
// It rolls back the transaction and returns the error if function f returns a non-nil error.
// It commits the transaction and returns nil if function f returns nil.
//
// Note: Do not commit or roll back the transaction in function f,
// as it is automatically handled by this function.
func (dao *KnowledgeChunksDao) Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) (err error) {
	return dao.Ctx(ctx).Transaction(ctx, f)
}
94 |
--------------------------------------------------------------------------------
/server/internal/dao/internal/knowledge_documents.go:
--------------------------------------------------------------------------------
1 | // ==========================================================================
2 | // Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
3 | // ==========================================================================
4 |
5 | package internal
6 |
7 | import (
8 | "context"
9 |
10 | "github.com/gogf/gf/v2/database/gdb"
11 | "github.com/gogf/gf/v2/frame/g"
12 | )
13 |
// KnowledgeDocumentsDao is the data access object for the table knowledge_documents.
type KnowledgeDocumentsDao struct {
	table    string                    // table is the underlying table name of the DAO.
	group    string                    // group is the database configuration group name of the current DAO.
	columns  KnowledgeDocumentsColumns // columns contains all the column names of Table for convenient usage.
	handlers []gdb.ModelHandler        // handlers for customized model modification.
}

// KnowledgeDocumentsColumns defines and stores column names for the table knowledge_documents.
type KnowledgeDocumentsColumns struct {
	Id                string // Name of the "id" column.
	KnowledgeBaseName string // Name of the "knowledge_base_name" column.
	FileName          string // Name of the "file_name" column.
	Status            string // Name of the "status" column.
	CreatedAt         string // Name of the "created_at" column.
	UpdatedAt         string // Name of the "updated_at" column.
}

// knowledgeDocumentsColumns holds the columns for the table knowledge_documents.
var knowledgeDocumentsColumns = KnowledgeDocumentsColumns{
	Id:                "id",
	KnowledgeBaseName: "knowledge_base_name",
	FileName:          "file_name",
	Status:            "status",
	CreatedAt:         "created_at",
	UpdatedAt:         "updated_at",
}

// NewKnowledgeDocumentsDao creates and returns a new DAO object for table data access.
// The DAO is bound to the "default" configuration group and the
// "knowledge_documents" table; the optional handlers are applied to every
// model created by Ctx.
func NewKnowledgeDocumentsDao(handlers ...gdb.ModelHandler) *KnowledgeDocumentsDao {
	return &KnowledgeDocumentsDao{
		group:    "default",
		table:    "knowledge_documents",
		columns:  knowledgeDocumentsColumns,
		handlers: handlers,
	}
}

// DB retrieves and returns the underlying raw database management object of the current DAO.
func (dao *KnowledgeDocumentsDao) DB() gdb.DB {
	return g.DB(dao.group)
}

// Table returns the table name of the current DAO.
func (dao *KnowledgeDocumentsDao) Table() string {
	return dao.table
}

// Columns returns all column names of the current DAO.
func (dao *KnowledgeDocumentsDao) Columns() KnowledgeDocumentsColumns {
	return dao.columns
}

// Group returns the database configuration group name of the current DAO.
func (dao *KnowledgeDocumentsDao) Group() string {
	return dao.group
}

// Ctx creates and returns a Model for the current DAO. It automatically sets the context for the current operation.
// The registered handlers are applied in order before Safe and Ctx are called
// on the resulting model.
func (dao *KnowledgeDocumentsDao) Ctx(ctx context.Context) *gdb.Model {
	model := dao.DB().Model(dao.table)
	for _, handler := range dao.handlers {
		model = handler(model)
	}
	return model.Safe().Ctx(ctx)
}

// Transaction wraps the transaction logic using function f.
// It rolls back the transaction and returns the error if function f returns a non-nil error.
// It commits the transaction and returns nil if function f returns nil.
//
// Note: Do not commit or roll back the transaction in function f,
// as it is automatically handled by this function.
func (dao *KnowledgeDocumentsDao) Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) (err error) {
	return dao.Ctx(ctx).Transaction(ctx, f)
}
90 |
--------------------------------------------------------------------------------
/server/core/rerank/rerank.go:
--------------------------------------------------------------------------------
1 | package rerank
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "io/ioutil"
8 | "net/http"
9 |
10 | "github.com/bytedance/sonic"
11 | "github.com/cloudwego/eino/schema"
12 | "github.com/gogf/gf/v2/frame/g"
13 | )
14 |
// Wire types for the rerank HTTP API.
type (
	// Conf carries the model settings sent with every rerank request. The
	// unexported url and apiKey are used to build the HTTP call and are not
	// part of the JSON payload.
	Conf struct {
		Model           string `json:"model"`
		ReturnDocuments bool   `json:"return_documents"`
		MaxChunksPerDoc int    `json:"max_chunks_per_doc"`
		OverlapTokens   int    `json:"overlap_tokens"`
		url             string
		apiKey          string
	}

	// Data is the per-call part of a rerank request: the query, the candidate
	// documents and how many results are wanted.
	Data struct {
		Query     string   `json:"query"`
		Documents []string `json:"documents"`
		TopN      int      `json:"top_n"`
	}

	// Req is the full request body; the fields of Data and Conf are embedded
	// and therefore appear at the top level of the JSON object.
	Req struct {
		*Data
		*Conf
	}

	// Result is one reranked entry: the index of the document in the request
	// and its relevance score.
	Result struct {
		Index          int     `json:"index"`
		RelevanceScore float64 `json:"relevance_score"`
	}

	// Resp is the response body of the rerank endpoint.
	Resp struct {
		ID      string    `json:"id"`
		Results []*Result `json:"results"`
	}
)
43 |
44 | var rerankCfg *Conf
45 |
46 | func NewRerank(ctx context.Context, query string, docs []*schema.Document, topK int) (output []*schema.Document, err error) {
47 | output, err = rerank(ctx, query, docs, topK)
48 | if err != nil {
49 | return
50 | }
51 | return
52 | }
53 |
54 | func GetConf(ctx context.Context) *Conf {
55 | if rerankCfg != nil {
56 | return rerankCfg
57 | }
58 | baseUrl := g.Cfg().MustGet(ctx, "rerank.baseURL").String()
59 | apiKey := g.Cfg().MustGet(ctx, "rerank.apiKey").String()
60 | model := g.Cfg().MustGet(ctx, "rerank.model").String()
61 | url := fmt.Sprintf("%s/rerank", baseUrl)
62 | rerankCfg = &Conf{
63 | apiKey: apiKey,
64 | Model: model,
65 | ReturnDocuments: false,
66 | MaxChunksPerDoc: 1024,
67 | OverlapTokens: 80,
68 | url: url,
69 | }
70 | return rerankCfg
71 | }
72 |
73 | func rerank(ctx context.Context, query string, docs []*schema.Document, topK int) (output []*schema.Document, err error) {
74 | data := &Data{
75 | Query: query,
76 | TopN: topK,
77 | }
78 | // g.Log().Infof(ctx, "docs num: %d", len(docs))
79 | for _, doc := range docs {
80 | data.Documents = append(data.Documents, doc.Content)
81 | }
82 | // 重排
83 | results, err := rerankDoHttp(ctx, data)
84 | if err != nil {
85 | return
86 | }
87 | // 重新组装数据
88 | for _, result := range results {
89 | doc := docs[result.Index]
90 | // g.Log().Infof(ctx, "content: %s, score_old: %f, score_new: %f", doc.Content, doc.Score(), result.RelevanceScore)
91 | doc.WithScore(result.RelevanceScore)
92 | output = append(output, docs[result.Index])
93 | }
94 | return
95 | }
96 |
97 | func rerankDoHttp(ctx context.Context, data *Data) ([]*Result, error) {
98 | cfg := GetConf(ctx)
99 | reqData := &Req{
100 | Data: data,
101 | Conf: cfg,
102 | }
103 |
104 | marshal, err := sonic.Marshal(reqData)
105 | if err != nil {
106 | return nil, err
107 | }
108 | payload := bytes.NewReader(marshal)
109 | request, err := http.NewRequest("POST", cfg.url, payload)
110 | if err != nil {
111 | return nil, err
112 | }
113 | request.Header.Add("Content-Type", "application/json")
114 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", cfg.apiKey))
115 | do, err := g.Client().Do(request)
116 | if err != nil {
117 | return nil, err
118 | }
119 | defer do.Body.Close()
120 | body, err := ioutil.ReadAll(do.Body)
121 | if err != nil {
122 | return nil, err
123 | }
124 | res := Resp{}
125 | err = sonic.Unmarshal(body, &res)
126 | if err != nil {
127 | return nil, err
128 | }
129 | return res.Results, nil
130 | }
131 |
--------------------------------------------------------------------------------
/server/core/indexer/qdrant_indexer_custom.go:
--------------------------------------------------------------------------------
1 | package indexer
2 |
3 | import (
4 | "context"
5 | "fmt"
6 |
7 | "github.com/bytedance/sonic"
8 | "github.com/cloudwego/eino/components/indexer"
9 | "github.com/cloudwego/eino/schema"
10 | "github.com/gogf/gf/v2/frame/g"
11 | "github.com/google/uuid"
12 | "github.com/qdrant/go-client/qdrant"
13 | coretypes "github.com/wangle201210/go-rag/server/core/types"
14 | )
15 |
16 | // StoreWithNamedVectors 使用命名向量存储文档到 Qdrant
17 | // 这是自定义实现,因为 eino-ext 的 indexer 不支持命名向量
18 | func (idx *QdrantIndexer) StoreWithNamedVectors(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) {
19 | if len(docs) == 0 {
20 | return nil, nil
21 | }
22 |
23 | var knowledgeName string
24 | if value, ok := ctx.Value(coretypes.KnowledgeName).(string); ok {
25 | knowledgeName = value
26 | } else {
27 | return nil, fmt.Errorf("必须提供知识库名称")
28 | }
29 |
30 | g.Log().Infof(ctx, "QdrantIndexer.StoreWithNamedVectors: storing %d documents to collection %s, knowledge_name=%s", len(docs), idx.config.Collection, knowledgeName)
31 |
32 | // 准备 points
33 | points := make([]*qdrant.PointStruct, 0, len(docs))
34 | ids := make([]string, 0, len(docs))
35 |
36 | for _, doc := range docs {
37 | // 生成 ID
38 | if len(doc.ID) == 0 {
39 | doc.ID = uuid.New().String()
40 | }
41 | ids = append(ids, doc.ID)
42 |
43 | // 生成 embedding
44 | embeddings, err := idx.config.Embedding.EmbedStrings(ctx, []string{doc.Content})
45 | if err != nil {
46 | g.Log().Errorf(ctx, "Failed to embed document %s: %v", doc.ID, err)
47 | return nil, fmt.Errorf("failed to embed document: %w", err)
48 | }
49 | if len(embeddings) == 0 {
50 | return nil, fmt.Errorf("embedding returned empty result")
51 | }
52 |
53 | // 转换为 float32
54 | vec32 := make([]float32, len(embeddings[0]))
55 | for i, v := range embeddings[0] {
56 | vec32[i] = float32(v)
57 | }
58 |
59 | // 准备 payload
60 | payload := make(map[string]*qdrant.Value)
61 | payload[coretypes.FieldContent] = &qdrant.Value{
62 | Kind: &qdrant.Value_StringValue{StringValue: doc.Content},
63 | }
64 | payload[coretypes.KnowledgeName] = &qdrant.Value{
65 | Kind: &qdrant.Value_StringValue{StringValue: knowledgeName},
66 | }
67 |
68 | // 添加额外的 metadata
69 | if doc.MetaData != nil {
70 | extData := getExtData(doc)
71 | if len(extData) > 0 {
72 | marshal, _ := sonic.Marshal(extData)
73 | payload[coretypes.FieldExtra] = &qdrant.Value{
74 | Kind: &qdrant.Value_StringValue{StringValue: string(marshal)},
75 | }
76 | }
77 | }
78 |
79 | // 创建命名向量(只存储 content_vector,qa_content_vector 由异步任务处理)
80 | vectors := &qdrant.Vectors{
81 | VectorsOptions: &qdrant.Vectors_Vectors{
82 | Vectors: &qdrant.NamedVectors{
83 | Vectors: map[string]*qdrant.Vector{
84 | coretypes.FieldContentVector: {
85 | Data: vec32,
86 | },
87 | },
88 | },
89 | },
90 | }
91 |
92 | // 创建 point
93 | point := &qdrant.PointStruct{
94 | Id: &qdrant.PointId{
95 | PointIdOptions: &qdrant.PointId_Uuid{Uuid: doc.ID},
96 | },
97 | Vectors: vectors,
98 | Payload: payload,
99 | }
100 |
101 | points = append(points, point)
102 | }
103 |
104 | // 批量存储到 Qdrant
105 | _, err := idx.config.Client.Upsert(ctx, &qdrant.UpsertPoints{
106 | CollectionName: idx.config.Collection,
107 | Points: points,
108 | })
109 | if err != nil {
110 | g.Log().Errorf(ctx, "QdrantIndexer.StoreWithNamedVectors failed: %v", err)
111 | return nil, fmt.Errorf("failed to upsert points: %w", err)
112 | }
113 |
114 | g.Log().Infof(ctx, "QdrantIndexer.StoreWithNamedVectors success: stored %d documents, IDs: %v", len(ids), ids)
115 |
116 | return ids, nil
117 | }
118 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Makefile for the go-rag project

# Every target here is a command name, not a file, so all of them are declared
# phony; otherwise a file or directory with the same name (e.g. "run" or
# "clean") would make make consider the target up to date and skip it.
.PHONY: all build-fe build-server build run clean docker-build run-local build-linux run-by-docker buildx build-all release clean-release

# Default target
all: build

# Build the frontend and copy the bundle into the server's static directory
build-fe:
	cd fe && pnpm install && pnpm run build
	mkdir -p server/static/fe
	cp -r fe/dist/* server/static/fe/

# Build the backend
build-server:
	cd server && go mod tidy && go build -o go-rag-server main.go

# Build the whole project
build: build-fe build-server

# Run the previously built server binary
run:
	cd server && ./go-rag-server

# Remove build artifacts
clean:
	rm -rf fe/dist
	rm -rf server/static/fe
	rm -f server/go-rag-server

# Build the Docker image
docker-build: build
	docker build -t go-rag:latest -f Dockerfile .

# Run the server from source
run-local:
	cd server && go mod tidy && go run .

# Cross-compile the server for linux/amd64
build-linux:
	cd server && go mod tidy && GOOS=linux GOARCH=amd64 go build -o go-rag-server

# Start the stack with docker compose
run-by-docker:
	docker compose -f docker-compose.yml up -d

# Image tag used by buildx; override with `make buildx v=...`
v := v0.0.3
# Build and push the multi-arch image
buildx:
	docker buildx build \
		--platform linux/arm64,linux/amd64 \
		-t iwangle/go-rag:$(v) \
		--push \
		.
# Project name (used as the binary and archive name prefix)
APP_NAME := go-rag

# Supported platforms, each written as GOOS/GOARCH
PLATFORMS := linux/amd64 linux/arm64 darwin/amd64 darwin/arm64 windows/amd64

# Multi-platform build: writes one binary per platform into releases/,
# named $(APP_NAME)-<os>-<arch> (with a .exe suffix on windows)
build-all:
	@echo "Building go-rag for multiple platforms..."
	@mkdir -p releases
	@for platform in $(PLATFORMS); do \
		os=$$(echo $$platform | cut -d'/' -f1); \
		arch=$$(echo $$platform | cut -d'/' -f2); \
		output_name=$(APP_NAME)-$$os-$$arch; \
		if [ $$os = "windows" ]; then output_name=$$output_name.exe; fi; \
		echo "Building for $$os/$$arch..."; \
		(cd server && GOOS=$$os GOARCH=$$arch go build -o ../releases/$$output_name .); \
	done
	@echo "Build completed! Files are in releases/ directory"

# Release (build + archive): for every platform, stage the binary, the static
# files and the demo config in a temporary $(APP_NAME)/ directory, pack it as
# .zip (windows) or .tar.gz (others), then remove the staging directory.
# Platforms whose binary is missing are skipped with a warning.
release: clean-release build-all
	@echo "Creating release archives..."
	@cd releases && \
	for platform in $(PLATFORMS); do \
		os=$$(echo $$platform | cut -d'/' -f1); \
		arch=$$(echo $$platform | cut -d'/' -f2); \
		release_dir=$(APP_NAME); \
		output_name=$(APP_NAME); \
		if [ $$os = "windows" ]; then output_name=$$output_name.exe; fi; \
		exe_file=$(APP_NAME)-$$os-$$arch; \
		if [ $$os = "windows" ]; then exe_file=$$exe_file.exe; fi; \
		archive_name=$(APP_NAME)-$$os-$$arch; \
		\
		echo "Preparing $$archive_name..."; \
		mkdir -p $$release_dir; \
		\
		if [ -f $$exe_file ]; then \
			cp $$exe_file $$release_dir/$$output_name; \
			echo "Copied executable to $$release_dir/$$output_name"; \
		else \
			echo "Warning: $$exe_file not found, skipping $$platform"; \
			rm -rf $$release_dir; \
			continue; \
		fi; \
		\
		if [ -d ../server/static ]; then \
			cp -r ../server/static $$release_dir/; \
			echo "Copied static files to $$release_dir/static/"; \
		fi; \
		\
		if [ -f ../server/manifest/config/config_qd_demo.yaml ]; then \
			cp ../server/manifest/config/config_qd_demo.yaml $$release_dir/config.yaml; \
			echo "Copied config file to $$release_dir/config.yaml"; \
		fi; \
		\
		if [ $$os = "windows" ]; then \
			zip -q $$archive_name.zip -r $$release_dir; \
			echo "Created $$archive_name.zip"; \
		else \
			tar -czf $$archive_name.tar.gz $$release_dir; \
			echo "Created $$archive_name.tar.gz"; \
		fi; \
		rm -rf $$release_dir; \
	done
	@echo "Release archives created!"
	@echo "Files in releases/:"
	@ls -la releases/

# Remove release artifacts
clean-release:
	@echo "Cleaning release files..."
	@rm -rf releases/
	@echo "Release files cleaned!"
126 |
127 |
--------------------------------------------------------------------------------
/server/internal/logic/knowledge/chunks.go:
--------------------------------------------------------------------------------
1 | package knowledge
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/gogf/gf/v2/frame/g"
7 | v1 "github.com/wangle201210/go-rag/server/api/rag/v1"
8 | "github.com/wangle201210/go-rag/server/internal/dao"
9 | "github.com/wangle201210/go-rag/server/internal/model/entity"
10 | )
11 |
12 | // SaveChunksData 批量保存知识块数据
13 | func SaveChunksData(ctx context.Context, documentsId int64, chunks []entity.KnowledgeChunks) error {
14 | if len(chunks) == 0 {
15 | return nil
16 | }
17 | status := int(v1.StatusIndexing)
18 |
19 | // 逐个插入或更新,避免 SQLite 的 ON CONFLICT 语法问题
20 | for _, chunk := range chunks {
21 | // 先尝试查询是否存在
22 | var existing entity.KnowledgeChunks
23 | err := dao.KnowledgeChunks.Ctx(ctx).Where("chunk_id", chunk.ChunkId).Scan(&existing)
24 |
25 | if err == nil && existing.Id > 0 {
26 | // 已存在,更新(排除 id 和 created_at)
27 | _, err = dao.KnowledgeChunks.Ctx(ctx).
28 | Where("chunk_id", chunk.ChunkId).
29 | Data(g.Map{
30 | "knowledge_doc_id": chunk.KnowledgeDocId,
31 | "content": chunk.Content,
32 | "ext": chunk.Ext,
33 | "status": chunk.Status,
34 | }).
35 | Update()
36 | if err != nil {
37 | g.Log().Errorf(ctx, "SaveChunksData update failed for chunk_id=%s, err=%+v", chunk.ChunkId, err)
38 | status = int(v1.StatusFailed)
39 | }
40 | } else {
41 | // 不存在,插入(id 设为 0 让数据库自动分配)
42 | chunk.Id = 0
43 | _, err = dao.KnowledgeChunks.Ctx(ctx).Data(chunk).OmitEmpty().Insert()
44 | if err != nil {
45 | g.Log().Errorf(ctx, "SaveChunksData insert failed for chunk_id=%s, err=%+v", chunk.ChunkId, err)
46 | status = int(v1.StatusFailed)
47 | }
48 | }
49 | }
50 |
51 | UpdateDocumentsStatus(ctx, documentsId, status)
52 | return nil
53 | }
54 |
55 | // GetChunksList 查询知识块列表
56 | func GetChunksList(ctx context.Context, where entity.KnowledgeChunks, page, size int) (list []entity.KnowledgeChunks, total int, err error) {
57 | model := dao.KnowledgeChunks.Ctx(ctx)
58 |
59 | // 构建查询条件
60 | if where.KnowledgeDocId != 0 {
61 | model = model.Where("knowledge_doc_id", where.KnowledgeDocId)
62 | }
63 | if where.ChunkId != "" {
64 | model = model.Where("chunk_id", where.ChunkId)
65 | }
66 |
67 | // 获取总数
68 | total, err = model.Count()
69 | if err != nil {
70 | return
71 | }
72 |
73 | // 分页查询
74 | if page > 0 && size > 0 {
75 | model = model.Page(page, size)
76 | }
77 |
78 | // 按创建时间倒序
79 | model = model.OrderDesc("created_at")
80 |
81 | err = model.Scan(&list)
82 | return
83 | }
84 |
85 | // GetChunkById 根据ID查询单个知识块
86 | func GetChunkById(ctx context.Context, id int64) (chunk entity.KnowledgeChunks, err error) {
87 | err = dao.KnowledgeChunks.Ctx(ctx).Where("id", id).Scan(&chunk)
88 | return
89 | }
90 |
91 | // DeleteChunkByIds 根据ID软删除知识块
92 | func DeleteChunkById(ctx context.Context, id int64) error {
93 | _, err := dao.KnowledgeChunks.Ctx(ctx).Where("id", id).Delete()
94 | return err
95 | }
96 |
97 | // UpdateChunkById 根据ID更新知识块
98 | func UpdateChunkByIds(ctx context.Context, ids []int64, data entity.KnowledgeChunks) error {
99 | model := dao.KnowledgeChunks.Ctx(ctx).WhereIn("id", ids)
100 | if data.Content != "" {
101 | model = model.Data("content", data.Content)
102 | }
103 | if data.Status != 0 {
104 | model = model.Data("status", data.Status)
105 | }
106 | _, err := model.Update()
107 | return err
108 | }
109 |
110 | // GetAllChunksByDocId gets all chunks by document id
111 | func GetAllChunksByDocId(ctx context.Context, docId int64, fields ...string) (list []entity.KnowledgeChunks, err error) {
112 | model := dao.KnowledgeChunks.Ctx(ctx).Where("knowledge_doc_id", docId)
113 | if len(fields) > 0 {
114 | for _, field := range fields {
115 | model = model.Fields(field)
116 | }
117 | }
118 | err = model.Scan(&list)
119 | return
120 | }
121 |
--------------------------------------------------------------------------------
/server/internal/logic/chat/chat.go:
--------------------------------------------------------------------------------
1 | package chat
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "io"
8 |
9 | "github.com/cloudwego/eino-ext/components/model/openai"
10 | "github.com/cloudwego/eino/components/model"
11 | "github.com/cloudwego/eino/schema"
12 | "github.com/gogf/gf/v2/frame/g"
13 | "github.com/gogf/gf/v2/os/gctx"
14 | "github.com/wangle201210/chat-history/eino"
15 | "github.com/wangle201210/go-rag/server/internal/dao"
16 | )
17 |
// chat is the package-level Chat instance; it is assigned in init and exposed
// through GetChat.
var chat *Chat

// Chat bundles a chat model with a conversation history store.
type Chat struct {
	cm model.BaseChatModel // model used to generate (or stream) answers
	eh *eino.History       // history store used to save assistant messages
}
24 |
// GetChat returns the package-level Chat instance.
func GetChat() *Chat {
	return chat
}
28 |
29 | // 暂时用不上chat功能,先不init
30 | func init() {
31 | ctx := gctx.New()
32 | c, err := newChat(&openai.ChatModelConfig{
33 | APIKey: g.Cfg().MustGet(ctx, "chat.apiKey").String(),
34 | BaseURL: g.Cfg().MustGet(ctx, "chat.baseURL").String(),
35 | Model: g.Cfg().MustGet(ctx, "chat.model").String(),
36 | })
37 | if err != nil {
38 | g.Log().Fatalf(ctx, "newChat failed, err=%v", err)
39 | return
40 | }
41 | // 使用 DSN 初始化,chat-history 包会根据 DSN 判断数据库类型
42 | // 对于 SQLite: file.db?_journal_mode=WAL
43 | // 对于 MySQL: user:pass@tcp(host:port)/dbname?charset=utf8mb4
44 | c.eh = eino.NewEinoHistory(dao.GetDsn())
45 | chat = c
46 | }
47 |
48 | func newChat(cfg *openai.ChatModelConfig) (res *Chat, err error) {
49 | chatModel, err := openai.NewChatModel(context.Background(), cfg)
50 | if err != nil {
51 | return nil, err
52 | }
53 | return &Chat{cm: chatModel}, nil
54 | }
55 |
56 | func (x *Chat) GetAnswer(ctx context.Context, convID string, docs []*schema.Document, question string) (answer string, err error) {
57 | messages, err := x.docsMessages(ctx, convID, docs, question)
58 | if err != nil {
59 | return "", err
60 | }
61 | result, err := generate(ctx, x.cm, messages)
62 | if err != nil {
63 | return "", fmt.Errorf("生成答案失败: %w", err)
64 | }
65 | err = x.eh.SaveMessage(result, convID)
66 | if err != nil {
67 | g.Log().Error(ctx, "save assistant message err: %v", err)
68 | return
69 | }
70 | return result.Content, nil
71 | }
72 |
73 | // GetAnswerStream 流式生成答案
74 | func (x *Chat) GetAnswerStream(ctx context.Context, convID string, docs []*schema.Document, question string) (answer *schema.StreamReader[*schema.Message], err error) {
75 | messages, err := x.docsMessages(ctx, convID, docs, question)
76 | if err != nil {
77 | return
78 | }
79 | ctx = context.Background()
80 | streamData, err := stream(ctx, x.cm, messages)
81 | if err != nil {
82 | return nil, fmt.Errorf("生成答案失败: %w", err)
83 | }
84 | srs := streamData.Copy(2)
85 | go func() {
86 | // for save
87 | fullMsgs := make([]*schema.Message, 0)
88 | defer func() {
89 | srs[1].Close()
90 | fullMsg, err := schema.ConcatMessages(fullMsgs)
91 | if err != nil {
92 | g.Log().Error(ctx, "error concatenating messages: %v", err)
93 | return
94 | }
95 | err = x.eh.SaveMessage(fullMsg, convID)
96 | if err != nil {
97 | g.Log().Error(ctx, "save assistant message err: %v", err)
98 | return
99 | }
100 | }()
101 | outer:
102 | for {
103 | select {
104 | case <-ctx.Done():
105 | fmt.Println("context done", ctx.Err())
106 | return
107 | default:
108 | chunk, err := srs[1].Recv()
109 | if err != nil {
110 | if errors.Is(err, io.EOF) {
111 | break outer
112 | }
113 | }
114 | fullMsgs = append(fullMsgs, chunk)
115 | }
116 | }
117 | }()
118 |
119 | return srs[0], nil
120 |
121 | }
122 |
123 | func generate(ctx context.Context, llm model.BaseChatModel, in []*schema.Message) (message *schema.Message, err error) {
124 | message, err = llm.Generate(ctx, in)
125 | if err != nil {
126 | err = fmt.Errorf("llm generate failed: %v", err)
127 | return
128 | }
129 | return
130 | }
131 |
132 | func stream(ctx context.Context, llm model.BaseChatModel, in []*schema.Message) (res *schema.StreamReader[*schema.Message], err error) {
133 | res, err = llm.Stream(ctx, in)
134 | if err != nil {
135 | err = fmt.Errorf("llm generate failed: %v", err)
136 | return
137 | }
138 | return
139 | }
140 |
--------------------------------------------------------------------------------
/server/internal/logic/knowledge/documents.go:
--------------------------------------------------------------------------------
1 | package knowledge
2 |
3 | import (
4 | "context"
5 | "fmt"
6 |
7 | "github.com/gogf/gf/v2/database/gdb"
8 | "github.com/gogf/gf/v2/frame/g"
9 | "github.com/wangle201210/go-rag/server/internal/dao"
10 | "github.com/wangle201210/go-rag/server/internal/model/entity"
11 | )
12 |
const (
	// defaultPageSize is the page size GetDocumentsList falls back to when the
	// requested pageSize is below 1.
	defaultPageSize = 10
	// maxPageSize is the upper bound GetDocumentsList applies to pageSize.
	maxPageSize = 100
)
17 |
18 | // SaveDocumentsInfo 保存文档信息
19 | func SaveDocumentsInfo(ctx context.Context, documents entity.KnowledgeDocuments) (id int64, err error) {
20 | // 确保 ID 为 0,让数据库自动分配
21 | documents.Id = 0
22 |
23 | // OmitEmpty 会忽略零值字段(包括 ID=0),让数据库自动分配 ID
24 | // 这样可以兼容 MySQL 和 SQLite 的自增主键
25 | result, err := dao.KnowledgeDocuments.Ctx(ctx).Data(documents).OmitEmpty().Insert()
26 | if err != nil {
27 | g.Log().Errorf(ctx, "保存文档信息失败: %+v, 错误: %v", documents, err)
28 | return 0, fmt.Errorf("保存文档信息失败: %w", err)
29 | }
30 |
31 | id, err = result.LastInsertId()
32 | if err != nil {
33 | return 0, fmt.Errorf("获取插入ID失败: %w", err)
34 | }
35 | g.Log().Infof(ctx, "文档保存成功, ID: %d", id)
36 | return id, nil
37 | }
38 |
39 | // UpdateDocumentsStatus 更新文档状态
40 | func UpdateDocumentsStatus(ctx context.Context, documentsId int64, status int) error {
41 | data := g.Map{
42 | "status": status,
43 | }
44 |
45 | _, err := dao.KnowledgeDocuments.Ctx(ctx).Where("id", documentsId).Data(data).Update()
46 | if err != nil {
47 | g.Log().Errorf(ctx, "更新文档状态失败: ID=%d, 错误: %v", documentsId, err)
48 | }
49 |
50 | return err
51 | }
52 |
53 | // GetDocumentById 根据ID获取文档信息
54 | func GetDocumentById(ctx context.Context, id int64) (document entity.KnowledgeDocuments, err error) {
55 | g.Log().Debugf(ctx, "获取文档信息: ID=%d", id)
56 |
57 | err = dao.KnowledgeDocuments.Ctx(ctx).Where("id", id).Scan(&document)
58 | if err != nil {
59 | g.Log().Errorf(ctx, "获取文档信息失败: ID=%d, 错误: %v", id, err)
60 | return document, fmt.Errorf("获取文档信息失败: %w", err)
61 | }
62 |
63 | return document, nil
64 | }
65 |
66 | // GetDocumentsList 获取文档列表
67 | func GetDocumentsList(ctx context.Context, where entity.KnowledgeDocuments, page int, pageSize int) (documents []entity.KnowledgeDocuments, total int, err error) {
68 | // 参数验证和默认值设置
69 | if page < 1 {
70 | page = 1
71 | }
72 | if pageSize < 1 {
73 | pageSize = defaultPageSize
74 | }
75 | if pageSize > maxPageSize {
76 | pageSize = maxPageSize
77 | }
78 |
79 | model := dao.KnowledgeDocuments.Ctx(ctx)
80 | if where.KnowledgeBaseName != "" {
81 | model = model.Where("knowledge_base_name", where.KnowledgeBaseName)
82 | }
83 |
84 | total, err = model.Count()
85 | if err != nil {
86 | g.Log().Errorf(ctx, "获取文档总数失败: %v", err)
87 | return nil, 0, fmt.Errorf("获取文档总数失败: %w", err)
88 | }
89 |
90 | if total == 0 {
91 | return nil, 0, nil
92 | }
93 |
94 | err = model.Page(page, pageSize).
95 | Order("created_at desc").
96 | Scan(&documents)
97 | if err != nil {
98 | g.Log().Errorf(ctx, "获取文档列表失败: %v", err)
99 | return nil, 0, fmt.Errorf("获取文档列表失败: %w", err)
100 | }
101 |
102 | return documents, total, nil
103 | }
104 |
105 | // DeleteDocument 删除文档及其相关数据
106 | func DeleteDocument(ctx context.Context, id int64) error {
107 | g.Log().Debugf(ctx, "删除文档: ID=%d", id)
108 |
109 | return dao.KnowledgeDocuments.Ctx(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
110 | // 先删除文档块
111 | _, err := dao.KnowledgeChunks.Ctx(ctx).TX(tx).Where("knowledge_doc_id", id).Delete()
112 | if err != nil {
113 | g.Log().Errorf(ctx, "删除文档块失败: ID=%d, 错误: %v", id, err)
114 | return fmt.Errorf("删除文档块失败: %w", err)
115 | }
116 |
117 | // 再删除文档
118 | result, err := dao.KnowledgeDocuments.Ctx(ctx).TX(tx).Where("id", id).Delete()
119 | if err != nil {
120 | g.Log().Errorf(ctx, "删除文档失败: ID=%d, 错误: %v", id, err)
121 | return fmt.Errorf("删除文档失败: %w", err)
122 | }
123 |
124 | affected, err := result.RowsAffected()
125 | if err != nil {
126 | return fmt.Errorf("获取影响行数失败: %w", err)
127 | }
128 | if affected == 0 {
129 | return fmt.Errorf("文档不存在")
130 | }
131 |
132 | g.Log().Infof(ctx, "文档删除成功: ID=%d", id)
133 | return nil
134 | })
135 | }
136 |
--------------------------------------------------------------------------------
/fe/public/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |