├── .gitignore
├── README.md
├── database.sql
├── front-end
├── .env.development
├── .env.production
├── .eslintrc.cjs
├── .gitignore
├── .prettierrc.json
├── README.md
├── env.d.ts
├── index.html
├── package-lock.json
├── package.json
├── public
│ └── favicon.ico
├── scripts
│ └── generate-api.js
├── src
│ ├── App.vue
│ ├── assets
│ │ ├── SrpingAI知识点.png
│ │ ├── background.jpg
│ │ ├── cover.png
│ │ └── logo.jpg
│ ├── components
│ │ ├── image
│ │ │ └── image-upload.vue
│ │ └── key-value
│ │ │ ├── key-value-input.vue
│ │ │ └── value-input.vue
│ ├── main.ts
│ ├── router
│ │ └── index.ts
│ ├── stores
│ │ └── home-store.ts
│ ├── typings
│ │ └── index.d.ts
│ ├── utils
│ │ ├── api-instance.ts
│ │ ├── common.ts
│ │ └── request.ts
│ └── views
│ │ ├── chat
│ │ ├── chat-view.vue
│ │ ├── components
│ │ │ ├── markdown-message.vue
│ │ │ ├── message-input.vue
│ │ │ ├── message-row.vue
│ │ │ ├── session-item.vue
│ │ │ └── text-loading.vue
│ │ └── store
│ │ │ └── chat-store.ts
│ │ ├── code
│ │ └── analyze
│ │ │ ├── analyze-result-view.vue
│ │ │ └── collapse-title.vue
│ │ └── login
│ │ ├── login-view.vue
│ │ └── register-view.vue
├── tsconfig.app.json
├── tsconfig.json
├── tsconfig.node.json
└── vite.config.ts
├── pom.xml
└── src
└── main
├── dto
├── AiMessage.dto
├── AiSession.dto
└── User.dto
├── java
└── io
│ └── github
│ └── qifan777
│ └── knowledge
│ ├── ServerApplication.java
│ ├── ai
│ ├── agent
│ │ ├── AbstractAgent.java
│ │ ├── Agent.java
│ │ ├── chronologist
│ │ │ └── Chronologist.java
│ │ └── computer
│ │ │ ├── ComputerAssistant.java
│ │ │ ├── CpuAnalyzer.java
│ │ │ └── DirectoryReader.java
│ ├── document
│ │ └── DocumentController.java
│ ├── message
│ │ ├── AiMessage.java
│ │ ├── AiMessageChatMemory.java
│ │ ├── AiMessageController.java
│ │ ├── AiMessageRepository.java
│ │ ├── MessageInputWrapper.java
│ │ └── dto
│ │ │ ├── AiMessageParams.java
│ │ │ └── AiMessageWrapper.java
│ └── session
│ │ ├── AiSession.java
│ │ ├── AiSessionController.java
│ │ └── AiSessionRepository.java
│ ├── code
│ ├── CodeAssistantAgent.java
│ ├── analyze
│ │ ├── AnalyzeController.java
│ │ └── AnalyzeFunction.java
│ ├── arthas
│ │ └── ArthasFunction.java
│ └── graph
│ │ ├── controller
│ │ └── CodeGraphController.java
│ │ ├── entity
│ │ ├── ClassNode.java
│ │ └── MethodNode.java
│ │ ├── repository
│ │ ├── ClassNodeRepository.java
│ │ └── MethodNodeRepository.java
│ │ └── service
│ │ ├── CodeGraphBuilder.java
│ │ └── CodeGraphService.java
│ ├── demo
│ ├── DocumentAnalyzerFunction.java
│ ├── DocumentDemoController.java
│ └── MessageDemoController.java
│ ├── graph
│ ├── GraphController.java
│ ├── chunk
│ │ ├── Chunk.java
│ │ ├── ChunkController.java
│ │ └── ChunkRepository.java
│ ├── company
│ │ ├── Company.java
│ │ ├── CompanyController.java
│ │ └── CompanyRepository.java
│ ├── form
│ │ ├── Form.java
│ │ ├── FormController.java
│ │ └── FormRepository.java
│ ├── manager
│ │ ├── Manager.java
│ │ ├── ManagerController.java
│ │ └── ManagerRepository.java
│ └── model
│ │ ├── Form10K.java
│ │ └── Form13.java
│ ├── infrastructure
│ ├── code
│ │ ├── CodeAssistantProperties.java
│ │ └── JavaParserUtils.java
│ ├── config
│ │ ├── GlobalExceptionAdvice.java
│ │ ├── LocalDateTimeConvert.java
│ │ ├── PageableConvert.java
│ │ ├── ResponseInterceptor.java
│ │ └── WebConfig.java
│ ├── jimmer
│ │ ├── BaseDateTime.java
│ │ ├── BaseDateTimeDraftInterceptor.java
│ │ ├── BaseEntity.java
│ │ ├── BaseEntityDraftInterceptor.java
│ │ ├── JimmerConfig.java
│ │ └── UUIDIdGenerator.java
│ └── model
│ │ └── QueryRequest.java
│ ├── oss
│ └── OSSController.java
│ └── user
│ ├── User.java
│ ├── UserController.java
│ └── UserRepository.java
└── resources
├── application.yml
└── prompts
└── react
├── react-tool.st
└── react.st
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | !.mvn/wrapper/maven-wrapper.jar
3 | !**/src/main/**/target/
4 | !**/src/test/**/target/
5 |
6 | ### IntelliJ IDEA ###
7 | .idea/modules.xml
8 | .idea/jarRepositories.xml
9 | .idea/compiler.xml
10 | .idea/libraries/
11 | *.iws
12 | *.iml
13 | *.ipr
14 |
15 | ### Eclipse ###
16 | .apt_generated
17 | .classpath
18 | .factorypath
19 | .project
20 | .settings
21 | .springBeans
22 | .sts4-cache
23 |
24 | ### NetBeans ###
25 | /nbproject/private/
26 | /nbbuild/
27 | /dist/
28 | /nbdist/
29 | /.nb-gradle/
30 | build/
31 | !**/src/main/**/build/
32 | !**/src/test/**/build/
33 |
34 | ### VS Code ###
35 | .vscode/
36 |
37 | ### Mac OS ###
38 | .DS_Store
39 | /src/main/resources/application-private.yml
40 | /.idea/
41 | /template/
42 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 项目介绍
2 |
3 | 本项目是一个Spring AI教学项目,内容涵盖SSE流式输出/Agent智能体/Function Call/Embedding/向量数据库/RAG/Graph RAG/历史消息/图片生成/图片理解
4 |
5 | 
6 |
7 | 
8 |
9 |
10 |
11 | [文档地址](https://www.jarcheng.top/blog/project/spring-ai/intro.html)
12 | [视频地址](https://www.bilibili.com/video/BV14y411q7RN/)
13 |
14 | ## 运行环境
15 |
16 | - Java 17
17 | - Node.js 18+
18 | - MySQL 8
19 | - DashScope API KEY(或者其他)
20 | - Redis-Stack
21 |
22 | Redis Stack在Redis的基础上扩展了向量查询功能(本项目将其用作向量数据库)
23 |
24 | ```shell
25 | docker run -d \
26 | --name redis-stack \
27 | --restart=always \
28 | -v redis-data:/data \
29 | -p 6379:6379 \
30 | -p 8001:8001 \
31 | -e REDIS_ARGS="--requirepass 123456" redis/redis-stack:latest
32 | ```
33 |
34 | - neo4j 5+
35 |
36 | 安装完neo4j后访问`localhost:7474`,默认账号和密码都是`neo4j`。首次登录时会要求修改密码,修改后请将新密码同步到application.yml中。
37 |
38 | ```shell
39 | docker run \
40 | -d \
41 | -p 7474:7474 -p 7687:7687 \
42 | -v neo4j-data:/data -v neo4j-data:/plugins \
43 | --name neo4j \
44 | -e NEO4J_apoc_export_file_enabled=true \
45 | -e NEO4J_apoc_import_file_enabled=true \
46 | -e NEO4J_apoc_import_file_use__neo4j__config=true \
47 | -e NEO4JLABS_PLUGINS=\[\"apoc\"\] \
48 | -e NEO4J_dbms_security_procedures_unrestricted=apoc.\\\* \
49 | neo4j
50 | ```
51 |
52 | ## 运行步骤
53 |
54 | ### 1.clone代码
55 |
56 | ```shell
57 | git clone https://github.com/qifan777/KnowledgeBaseChatSpringAI
58 | ```
59 |
60 | ### 2. idea打开项目
61 |
62 | ### 3. 修改配置文件
63 |
64 | 修改application.yml中的API-KEY, MySQL, Redis-Stack, Neo4j配置
65 | ### 4. 运行项目
66 |
67 | 后端运行
68 |
69 | 1. 运行ServerApplication.java
70 | 2. 在IDEA中右键`target/generated-sources/annotations`目录,选择 Mark Directory as → Generated Sources Root
71 |
72 | 前端运行,在front-end目录下
73 |
74 | - npm install
75 | - npm run api(需先启动后端,该命令会从后端下载生成的TypeScript API代码到src/apis/__generated)
76 | - npm run dev
77 |
78 |
79 | ## 联系方式
80 |
81 | 付费远程运行/安装/定制开发联系微信:ljc666max
82 |
83 | 其他关于程序运行安装报错请加QQ群:
84 |
85 | - 416765656(满)
86 | - 632067985
87 |
--------------------------------------------------------------------------------
/database.sql:
--------------------------------------------------------------------------------
1 | -- MySQL dump 10.13 Distrib 8.0.36, for Win64 (x86_64)
2 | --
3 | -- Host: 127.0.0.1 Database: knowledge_base
4 | -- ------------------------------------------------------
5 | -- Server version 8.0.36
6 |
7 | /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
8 | /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
9 | /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;
10 | /*!50503 SET NAMES utf8mb4 */;
11 | /*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;
12 | /*!40103 SET TIME_ZONE='+00:00' */;
13 | /*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */;
14 | /*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;
15 | /*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;
16 | /*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;
17 |
18 | --
19 | -- Table structure for table `ai_message`
20 | --
21 |
22 | DROP TABLE IF EXISTS `ai_message`;
23 | /*!40101 SET @saved_cs_client = @@character_set_client */;
24 | /*!50503 SET character_set_client = utf8mb4 */;
25 | CREATE TABLE `ai_message` (
26 | `id` varchar(36) NOT NULL,
27 | `created_time` datetime(6) NOT NULL,
28 | `edited_time` datetime(6) NOT NULL,
29 | `creator_id` varchar(32) NOT NULL,
30 | `editor_id` varchar(32) NOT NULL,
31 | `type` varchar(32) NOT NULL COMMENT '消息类型(用户/助手/系统)',
32 | `text_content` text NOT NULL COMMENT '消息内容',
33 | `medias` json DEFAULT NULL COMMENT '媒体内容如图片链接、语音链接',
34 | `ai_session_id` varchar(32) NOT NULL COMMENT '会话id',
35 | PRIMARY KEY (`id`),
36 | KEY `ai_message_ai_session_id_fk` (`ai_session_id`)
37 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
38 | /*!40101 SET character_set_client = @saved_cs_client */;
39 |
40 | --
41 | -- Dumping data for table `ai_message`
42 | --
43 |
44 | LOCK TABLES `ai_message` WRITE;
45 | /*!40000 ALTER TABLE `ai_message` DISABLE KEYS */;
46 | INSERT INTO `ai_message` VALUES ('05c523bf586b422e8f19d32a6bab17b6','2024-06-23 22:08:33.529844','2024-06-23 22:08:33.529844','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','ASSISTANT','今天的日期是2024年06月23日。',NULL,'b495186b9259494792a9fe4ce8807f25');
47 | INSERT INTO `ai_message` VALUES ('1d14d513ae8c4ffda0fb50df6e24ba77','2024-06-20 22:08:28.389711','2024-06-20 22:08:28.389711','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','USER','C:\\Users\\Administrator\\Desktop\\2023年工作总结及2024年工作展望-林家成.docx,这份文档的内容是什么?','[]','c4d192c4b2de48bcae585f5b9672cd1d');
48 | INSERT INTO `ai_message` VALUES ('34d65ad69e4b470794da11732fe44a7d','2024-06-23 22:08:33.337776','2024-06-23 22:08:33.337776','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','USER','今天的日期是多少?','[]','b495186b9259494792a9fe4ce8807f25');
49 | INSERT INTO `ai_message` VALUES ('ba161500bec84b388a6501b3be333b45','2024-06-20 22:08:28.649439','2024-06-20 22:08:28.649439','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','ASSISTANT','2023年的工作总结及2024年工作展望文档内容概要如下:\n\n1、**工作回顾重点:**\n - 成功运用Python脚本在一经上云项目中实现Oracle脚本批量迁移、实体批量登记等,极大提升了工作效率(速度提升70倍以上),并提前完成年度工作任务。\n - 领导新智慧工厂项目的前端与后端框架搭建,采用更高效、安全的技术栈,促进团队快速投入业务开发。重构并优化了基础功能模块,如菜单框架、租户管理等,封装为易用组件。\n - 开发了脚本依赖分析工具,通过SQL语法树解析等技术,有效辅助故障检测、代码规范检查及开发效率,与部门其他系统(如指标库、报表系统)集成,增强功能实用性。\n\n2、**个人成长与进步:**\n - 今年最大的成就是深入学习了数据开发知识,综合软件开发能力使自己更能满足用户需求,相比同行具有更全面的技能组合。\n\n3、**应对挑战的方法:**\n - 面对SQL语法解析等高难度任务,通过自学源码、复习理论基础,成功开发出工具,体现了自我驱动力和解决问题的能力。\n\n4、**反思与改进空间:**\n - 希望能在项目中实践深度学习、机器学习技术,这是未来努力的一个方向。\n\n5、**提质增效的建议:**\n - 强调代码审查和个人技能提升,鼓励使用先进技术和工具,以及创建持续学习的环境,以保持竞争力。\n\n6、**对当前开发工作的建议:**\n - 提倡技术分享会议,打破知识孤岛,促进新老员工间的技术交流和学习,平衡稳定与创新的需求。\n\n7、**其他建议:**\n - 加强跨部门合作,以新智慧工厂项目为范例,推动更多跨领域协作成果。\n\n对于2024年的展望虽未详细展开,但基于上述总结,可能包括深化技术应用、促进团队协作、继续个人技能升级等方面。',NULL,'c4d192c4b2de48bcae585f5b9672cd1d');
50 | /*!40000 ALTER TABLE `ai_message` ENABLE KEYS */;
51 | UNLOCK TABLES;
52 |
53 | --
54 | -- Table structure for table `ai_session`
55 | --
56 |
57 | DROP TABLE IF EXISTS `ai_session`;
58 | /*!40101 SET @saved_cs_client = @@character_set_client */;
59 | /*!50503 SET character_set_client = utf8mb4 */;
60 | CREATE TABLE `ai_session` (
61 | `id` varchar(36) NOT NULL,
62 | `created_time` datetime(6) NOT NULL,
63 | `edited_time` datetime(6) NOT NULL,
64 | `creator_id` varchar(32) NOT NULL,
65 | `editor_id` varchar(32) NOT NULL,
66 | `name` varchar(32) NOT NULL COMMENT '会话名称',
67 | PRIMARY KEY (`id`)
68 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
69 | /*!40101 SET character_set_client = @saved_cs_client */;
70 |
71 | --
72 | -- Dumping data for table `ai_session`
73 | --
74 |
75 | LOCK TABLES `ai_session` WRITE;
76 | /*!40000 ALTER TABLE `ai_session` DISABLE KEYS */;
77 | INSERT INTO `ai_session` VALUES ('b495186b9259494792a9fe4ce8807f25','2024-06-18 22:01:51.459293','2024-06-18 22:01:51.459293','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','新的聊天');
78 | INSERT INTO `ai_session` VALUES ('c4d192c4b2de48bcae585f5b9672cd1d','2024-06-18 22:01:45.819825','2024-06-18 22:01:45.819317','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','新的聊天');
79 | INSERT INTO `ai_session` VALUES ('dcd32fd9f7fc4d0fb9a5b7dc9ff0dbc8','2024-06-18 21:57:09.390651','2024-06-18 21:57:09.390651','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','新的聊天');
80 | /*!40000 ALTER TABLE `ai_session` ENABLE KEYS */;
81 | UNLOCK TABLES;
82 |
83 | --
84 | -- Table structure for table `user`
85 | --
86 |
87 | DROP TABLE IF EXISTS `user`;
88 | /*!40101 SET @saved_cs_client = @@character_set_client */;
89 | /*!50503 SET character_set_client = utf8mb4 */;
90 | CREATE TABLE `user` (
91 | `id` varchar(36) NOT NULL,
92 | `created_time` datetime(6) NOT NULL,
93 | `edited_time` datetime(6) NOT NULL,
94 | `nickname` varchar(20) DEFAULT NULL,
95 | `avatar` varchar(255) DEFAULT NULL,
96 | `gender` varchar(36) DEFAULT NULL,
97 | `phone` varchar(20) NOT NULL,
98 | `password` varchar(100) NOT NULL,
99 | PRIMARY KEY (`id`),
100 | UNIQUE KEY `phone` (`phone`)
101 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
102 | /*!40101 SET character_set_client = @saved_cs_client */;
103 |
104 | --
105 | -- Dumping data for table `user`
106 | --
107 |
108 | LOCK TABLES `user` WRITE;
109 | /*!40000 ALTER TABLE `user` DISABLE KEYS */;
110 | INSERT INTO `user` VALUES ('dcd256e2412f4162a6a5fcbd5cfedc84','2024-05-01 16:52:43.364225','2024-05-19 21:30:34.686818','起凡','https://my-community.oss-cn-qingdao.aliyuncs.com/20240501203628ptwondCGhItP67eb5ac72554b07800b22c542245e457.jpeg','MALE','11111111111','$2a$10$o/DHIt/eAMR175TgDV/PeeuEOpqW1N4Klft6obvs2zqBuiwMgLWOW');
111 | INSERT INTO `user` VALUES ('fe3d0d7d6eb34eb7a6b7985426cf8af7','2024-06-06 13:23:46.130879','2024-06-06 13:23:46.130817','默认用户',NULL,NULL,'13656987994','$2a$10$q7pey1P1/b3lO9nzFLKOb.ISrX7.lkktMjghwhgvNqvA.EjZZ2mg2');
112 | /*!40000 ALTER TABLE `user` ENABLE KEYS */;
113 | UNLOCK TABLES;
114 | /*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */;
115 |
116 | /*!40101 SET SQL_MODE=@OLD_SQL_MODE */;
117 | /*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */;
118 | /*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */;
119 | /*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */;
120 | /*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */;
121 | /*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */;
122 | /*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */;
123 |
124 | -- Dump completed on 2024-06-24 21:11:15
125 |
--------------------------------------------------------------------------------
/front-end/.env.development:
--------------------------------------------------------------------------------
1 | VITE_WEBSOCKET_URL=ws://localhost:8877/handshake
2 | VITE_API_PREFIX=/api
--------------------------------------------------------------------------------
/front-end/.env.production:
--------------------------------------------------------------------------------
1 | VITE_API_PREFIX=
--------------------------------------------------------------------------------
/front-end/.eslintrc.cjs:
--------------------------------------------------------------------------------
1 | /* eslint-env node */
2 | require('@rushstack/eslint-patch/modern-module-resolution')
3 |
4 | module.exports = {
5 | root: true,
6 | extends: [
7 | 'plugin:vue/vue3-essential',
8 | 'eslint:recommended',
9 | '@vue/eslint-config-typescript',
10 | '@vue/eslint-config-prettier/skip-formatting'
11 | ],
12 | parserOptions: {
13 | ecmaVersion: 'latest'
14 | },
15 | rules: {
16 | 'prettier/prettier': [
17 | 'warn',
18 | {
19 | endOfLine: 'auto'
20 | }
21 | ]
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/front-end/.gitignore:
--------------------------------------------------------------------------------
1 | ### IntelliJ IDEA ###
2 | /.idea/
3 | *.iws
4 | *.iml
5 | *.ipr
6 | # Logs
7 | logs
8 | *.log
9 | npm-debug.log*
10 | yarn-debug.log*
11 | yarn-error.log*
12 | pnpm-debug.log*
13 | lerna-debug.log*
14 |
15 | node_modules
16 | .DS_Store
17 | dist
18 | dist-ssr
19 | coverage
20 | *.local
21 |
22 | /cypress/videos/
23 | /cypress/screenshots/
24 |
25 | # Editor directories and files
26 | .vscode/*
27 | !.vscode/extensions.json
28 | .idea
29 | *.suo
30 | *.ntvs*
31 | *.njsproj
32 | *.sln
33 | *.sw?
34 |
35 | /src/apis/__generated/
36 | /scripts/ingredient-admin.tar
37 |
--------------------------------------------------------------------------------
/front-end/.prettierrc.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://json.schemastore.org/prettierrc",
3 | "semi": false,
4 | "tabWidth": 2,
5 | "singleQuote": true,
6 | "printWidth": 100,
7 | "trailingComma": "none"
8 | }
--------------------------------------------------------------------------------
/front-end/README.md:
--------------------------------------------------------------------------------
1 | # mall-admin
2 |
3 | This template should help get you started developing with Vue 3 in Vite.
4 |
5 | ## Recommended IDE Setup
6 |
7 | [VSCode](https://code.visualstudio.com/) + [Volar](https://marketplace.visualstudio.com/items?itemName=Vue.volar) (and disable Vetur) + [TypeScript Vue Plugin (Volar)](https://marketplace.visualstudio.com/items?itemName=Vue.vscode-typescript-vue-plugin).
8 |
9 | ## Type Support for `.vue` Imports in TS
10 |
11 | TypeScript cannot handle type information for `.vue` imports by default, so we replace the `tsc` CLI with `vue-tsc` for type checking. In editors, we need [TypeScript Vue Plugin (Volar)](https://marketplace.visualstudio.com/items?itemName=Vue.vscode-typescript-vue-plugin) to make the TypeScript language service aware of `.vue` types.
12 |
13 | If the standalone TypeScript plugin doesn't feel fast enough to you, Volar has also implemented a [Take Over Mode](https://github.com/johnsoncodehk/volar/discussions/471#discussioncomment-1361669) that is more performant. You can enable it by the following steps:
14 |
15 | 1. Disable the built-in TypeScript Extension
16 | 1) Run `Extensions: Show Built-in Extensions` from VSCode's command palette
17 | 2) Find `TypeScript and JavaScript Language Features`, right click and select `Disable (Workspace)`
18 | 2. Reload the VSCode window by running `Developer: Reload Window` from the command palette.
19 |
20 | ## Customize configuration
21 |
22 | See [Vite Configuration Reference](https://vitejs.dev/config/).
23 |
24 | ## Project Setup
25 |
26 | ```sh
27 | npm install
28 | ```
29 |
30 | ### Compile and Hot-Reload for Development
31 |
32 | ```sh
33 | npm run dev
34 | ```
35 |
36 | ### Type-Check, Compile and Minify for Production
37 |
38 | ```sh
39 | npm run build
40 | ```
41 |
42 | ### Lint with [ESLint](https://eslint.org/)
43 |
44 | ```sh
45 | npm run lint
46 | ```
47 |
--------------------------------------------------------------------------------
/front-end/env.d.ts:
--------------------------------------------------------------------------------
/// <reference types="vite/client" />
// The reference above pulls in Vite's client typings, which declare `import.meta.env`
// (read by src/utils/request.ts and src/router/index.ts). A bare `///` provides nothing.

// Shorthand ambient module: imports of the Element Plus zh-cn locale are typed as `any`
// (presumably because the .mjs path has no bundled declaration file — verify if upgrading).
declare module 'element-plus/dist/locale/zh-cn.mjs'
3 |
--------------------------------------------------------------------------------
/front-end/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | Vite App
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/front-end/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "uni-ai-admin",
3 | "version": "0.0.0",
4 | "private": true,
5 | "type": "module",
6 | "scripts": {
7 | "dev": "vite",
8 | "build": "run-p type-check \"build-only {@}\" --",
9 | "preview": "vite preview",
10 | "build-only": "vite build",
11 | "type-check": "vue-tsc --build --force",
12 | "lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore",
13 | "format": "prettier --write src/",
14 | "api": "node scripts/generate-api.js"
15 | },
16 | "dependencies": {
17 | "@element-plus/icons-vue": "^2.3.1",
18 | "axios": "^1.6.4",
19 | "dayjs": "^1.11.10",
20 | "element-plus": "^2.4.2",
21 | "lodash": "^4.17.21",
22 | "md-editor-v3": "^4.13.2",
23 | "pinia": "^2.1.7",
24 | "sse.js": "^2.4.1",
25 | "vue": "^3.4.21",
26 | "vue-router": "^4.2.5"
27 | },
28 | "devDependencies": {
29 | "@rushstack/eslint-patch": "^1.3.3",
30 | "@tsconfig/node18": "^18.2.2",
31 | "@types/node": "^18.19.3",
32 | "@vitejs/plugin-vue": "^4.5.2",
33 | "@vitejs/plugin-vue-jsx": "^3.1.0",
34 | "@vue/eslint-config-prettier": "^8.0.0",
35 | "@vue/eslint-config-typescript": "^12.0.0",
36 | "@vue/tsconfig": "^0.5.0",
37 | "adm-zip": "^0.5.10",
38 | "eslint": "^8.49.0",
39 | "eslint-plugin-vue": "^9.17.0",
40 | "fs-extra": "^11.1.1",
41 | "npm-run-all2": "^6.1.1",
42 | "prettier": "^3.0.3",
43 | "sass": "^1.69.7",
44 | "typescript": "~5.3.0",
45 | "uuid": "^9.0.1",
46 | "vite": "^5.0.10",
47 | "vue-tsc": "^1.8.25"
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/front-end/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/public/favicon.ico
--------------------------------------------------------------------------------
/front-end/scripts/generate-api.js:
--------------------------------------------------------------------------------
1 | /* eslint-env node */
2 | import http from 'http'
3 | import fs from 'fs'
4 | import fse from 'fs-extra'
5 | import { v4 as uuidv4 } from 'uuid'
6 | import os from 'os'
7 | import path from 'path'
8 | import AdmZip from 'adm-zip'
// Where the back-end serves the generated TypeScript client, and where it is unpacked locally.
const sourceUrl = 'http://localhost:9902/ts.zip'
const tmpFilePath = path.join(os.tmpdir(), uuidv4() + '.zip')
const generatePath = 'src/apis/__generated'

// Sub-directories of the generated code whose .ts files are post-processed by replaceInFile.
const modelPath = 'src/apis/__generated/model'
const servicePath = 'src/apis/__generated/services'

/** Log a failure and make the process exit with a non-zero status once the event loop drains. */
function fail(message, err) {
  if (err) {
    console.error(message, err)
  } else {
    console.error(message)
  }
  process.exitCode = 1
}

/** Best-effort deletion of the downloaded archive; errors are only logged. */
function removeTmpFile() {
  console.log('Removing temporary file...')
  fs.unlink(tmpFilePath, (err) => {
    if (err) {
      console.error('Error while removing temporary file:', err)
    } else {
      console.log('Temporary file removed.')
    }
  })
}

/** Replace generatePath with the content of the downloaded archive, then post-process it. */
function installArchive() {
  let zip
  try {
    // Parse the archive BEFORE deleting the existing generated code, so a bad download
    // (e.g. an HTML error page) does not leave the project without its API client.
    zip = new AdmZip(tmpFilePath)
  } catch (err) {
    fail('Downloaded file is not a valid zip archive:', err)
    removeTmpFile()
    return
  }

  if (fs.existsSync(generatePath)) {
    console.log('Removing existing generatePath...')
    fse.removeSync(generatePath)
    console.log('Existing generatePath removed.')
  }

  console.log('Unzipping the file...')
  zip.extractAllTo(generatePath, true)
  console.log('File unzipped successfully.')
  removeTmpFile()

  // An archive may lack one of these folders; readdirSync would throw on a missing path.
  for (const dir of [modelPath, servicePath]) {
    if (fs.existsSync(dir)) {
      traverseDirectory(dir)
    }
  }
}

console.log('Downloading ' + sourceUrl + '...')

http
  .get(sourceUrl, (response) => {
    if (response.statusCode !== 200) {
      // Drain the body so the socket is released; nothing is written to disk.
      response.resume()
      fail('Download failed with HTTP status ' + response.statusCode + '.')
      return
    }
    // Only create the temp file once we know a real archive is coming.
    const tmpFile = fs.createWriteStream(tmpFilePath)
    tmpFile.on('error', (err) => fail('Error while saving the downloaded file:', err))
    response.pipe(tmpFile)
    tmpFile.on('finish', () => {
      // Wait until the file descriptor is closed before reading the archive back.
      tmpFile.close(() => {
        console.log('File save success: ', tmpFilePath)
        installArchive()
      })
    })
  })
  .on('error', (err) => {
    // Typically ECONNREFUSED when the back-end is not running.
    fail('Download failed (is the back-end running at ' + sourceUrl + '?):', err)
  })
51 |
// Recursively walk `directoryPath` and run replaceInFile on every regular file with a `.ts`
// extension; other files are left untouched.
function traverseDirectory(directoryPath) {
  for (const entry of fs.readdirSync(directoryPath)) {
    const entryPath = path.join(directoryPath, entry)
    const entryStats = fs.statSync(entryPath)

    if (entryStats.isDirectory()) {
      traverseDirectory(entryPath)
      continue
    }
    if (entryStats.isFile() && path.extname(entryPath) === '.ts') {
      replaceInFile(entryPath)
    }
  }
}
67 |
// Rewrite a generated .ts file in place so its types become mutable:
//   - drop every `readonly ` modifier,
//   - ReadonlyArray -> Array and ReadonlyMap -> Map,
//   - `Map<K, V>` -> index-signature object `{ [key: K]: V }`.
// The rules are applied in this order, each to every occurrence.
function replaceInFile(filePath) {
  const rewriteRules = [
    ['readonly ', ''],
    [/ReadonlyArray/g, 'Array'],
    ['ReadonlyMap', 'Map'],
    [/Map<(\S+), (\S+)>/g, '{ [key: $1]: $2 }']
  ]
  const original = fs.readFileSync(filePath, 'utf8')
  const rewritten = rewriteRules.reduce(
    (text, [pattern, replacement]) => text.replaceAll(pattern, replacement),
    original
  )
  fs.writeFileSync(filePath, rewritten, 'utf8')
}
79 |
--------------------------------------------------------------------------------
/front-end/src/App.vue:
--------------------------------------------------------------------------------
1 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
13 |
14 |
15 |
79 |
--------------------------------------------------------------------------------
/front-end/src/assets/SrpingAI知识点.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/src/assets/SrpingAI知识点.png
--------------------------------------------------------------------------------
/front-end/src/assets/background.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/src/assets/background.jpg
--------------------------------------------------------------------------------
/front-end/src/assets/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/src/assets/cover.png
--------------------------------------------------------------------------------
/front-end/src/assets/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/src/assets/logo.jpg
--------------------------------------------------------------------------------
/front-end/src/components/image/image-upload.vue:
--------------------------------------------------------------------------------
1 |
63 |
64 |
84 |
--------------------------------------------------------------------------------
/front-end/src/components/key-value/key-value-input.vue:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
70 |
71 |
93 |
--------------------------------------------------------------------------------
/front-end/src/components/key-value/value-input.vue:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
11 |
12 |
13 |
14 |
15 |
23 |
24 | {{ tag }}
33 |
34 |
35 |
44 | + 新增值
45 |
46 |
47 |
48 |
106 |
107 |
136 |
--------------------------------------------------------------------------------
/front-end/src/main.ts:
--------------------------------------------------------------------------------
1 | import { createApp } from 'vue'
2 | import { createPinia } from 'pinia'
3 | import ElementPlus from 'element-plus'
4 | import 'element-plus/dist/index.css'
5 | import * as ElementPlusIconsVue from '@element-plus/icons-vue'
6 | import App from './App.vue'
7 | import router from './router'
8 | import { MdPreview } from 'md-editor-v3'
9 | import 'md-editor-v3/lib/preview.css'
const app = createApp(App)

// Plugins are installed in a fixed order: Pinia, then Element Plus, then the router.
app.use(createPinia()).use(ElementPlus).use(router)

// Markdown preview component, available in every template as <MdPreview>.
app.component('MdPreview', MdPreview)

// Register every Element Plus icon globally under its exported component name.
Object.entries(ElementPlusIconsVue).forEach(([iconName, iconComponent]) => {
  app.component(iconName, iconComponent)
})

app.mount('#app')
20 |
--------------------------------------------------------------------------------
/front-end/src/router/index.ts:
--------------------------------------------------------------------------------
1 | import { createRouter, createWebHashHistory } from 'vue-router'
2 | import RegisterView from '@/views/login/register-view.vue'
3 | import LoginView from '@/views/login/login-view.vue'
4 | import ChatView from '@/views/chat/chat-view.vue'
5 | import AnalyzeResultView from '@/views/code/analyze/analyze-result-view.vue'
6 |
/**
 * Application router (hash history, rooted at Vite's BASE_URL).
 *
 * - `/`         chat page
 * - `/login`    login page (named route `login`)
 * - `/register` registration page (named route `register`)
 * - `/analyze`  code-analysis result page; the `path` query parameter is passed as the `path` prop
 */
const router = createRouter({
  history: createWebHashHistory(import.meta.env.BASE_URL),
  routes: [
    { path: '/', component: ChatView },
    { path: '/login', name: 'login', component: LoginView },
    { path: '/register', name: 'register', component: RegisterView },
    {
      path: '/analyze',
      component: AnalyzeResultView,
      props: (to) => ({ path: to.query.path })
    }
  ]
})

export default router
35 |
--------------------------------------------------------------------------------
/front-end/src/stores/home-store.ts:
--------------------------------------------------------------------------------
1 | import { defineStore } from 'pinia'
2 | import { ref } from 'vue'
3 | import type { UserDto } from '@/apis/__generated/model/dto'
4 | import { api } from '@/utils/api-instance'
5 |
/**
 * Global "home" store: caches the current user's info fetched from the back-end.
 */
export const useHomeStore = defineStore('home', () => {
  // Current user; undefined until getUserInfo() resolves and again after logout().
  const userInfo = ref()

  /**
   * Loads the current user via `api.userController.userInfo()`, stores it in `userInfo`
   * and resolves with the stored (reactive) value.
   */
  async function getUserInfo() {
    const fetched = await api.userController.userInfo()
    userInfo.value = fetched
    return userInfo.value
  }

  /** Store initialisation: currently only loads the user info. */
  async function init() {
    await getUserInfo()
  }

  /** Clears the cached user info (client-side only; no request is sent). */
  function logout() {
    userInfo.value = undefined
  }

  return { userInfo, getUserInfo, init, logout }
})
20 |
--------------------------------------------------------------------------------
/front-end/src/typings/index.d.ts:
--------------------------------------------------------------------------------
1 | import type { MenuDto } from '@/apis/__generated/model/dto'
2 |
/** Whether an edit form is creating a new record or updating an existing one. */
export type EditMode = 'CREATE' | 'UPDATE'

/**
 * Row scope object: the row's data plus its index.
 * NOTE(review): presumably the default-slot scope of an Element Plus table column — confirm.
 *
 * @typeParam T type of the row data (defaults to `unknown`)
 */
export interface Scope<T = unknown> {
  row: T
  $index: number
}

/**
 * Response envelope returned by the back-end (src/utils/request.ts reads `code`, `msg` and
 * `result` from it).
 *
 * @typeParam T type of the `result` payload (defaults to `unknown`)
 */
export interface Result<T = unknown> {
  code: number
  success: boolean
  msg: string
  result: T
}
/**
 * A menu node in its `MenuRepository/SIMPLE_FETCHER` shape, optionally carrying nested child
 * nodes of the same type.
 */
export type MenuTreeDto = MenuDto['MenuRepository/SIMPLE_FETCHER'] & {
  children?: MenuTreeDto[]
}
18 |
--------------------------------------------------------------------------------
/front-end/src/utils/api-instance.ts:
--------------------------------------------------------------------------------
1 | import { Api } from '@/apis/__generated'
2 | import { request } from '@/utils/request'
/**
 * Shared client for the generated API (`@/apis/__generated`).
 *
 * Every generated call is executed through the axios `request` instance: `uri` becomes the
 * request URL and `body` the request data, so responses pass through that instance's
 * interceptors.
 */
export const api = new Api(async ({ uri, method, body }) =>
  request({ url: uri, method, data: body })
)
7 |
--------------------------------------------------------------------------------
/front-end/src/utils/common.ts:
--------------------------------------------------------------------------------
1 | import { ElMessage } from 'element-plus'
2 | import type { ValidateFieldsError } from 'async-validator/dist-types'
3 |
/**
 * Builds a `(valid, fields)` validation handler around `callback`
 * (NOTE(review): presumably passed to an Element Plus form's `validate(...)` — confirm).
 *
 * - `valid === true`: `callback` is invoked.
 * - otherwise, if `fields` is provided: every non-empty error message it contains is collected
 *   and shown in one `ElMessage.error`, joined with newlines. Nothing happens if `fields` is
 *   undefined.
 *
 * @param callback action to run once validation succeeds
 * @returns the validation handler
 */
export const assertFormValidate = (callback: () => void) => {
  return (valid: boolean, fields: ValidateFieldsError | undefined) => {
    if (valid) {
      callback()
      return
    }
    if (!fields) {
      return
    }
    const messages: string[] = []
    for (const errors of Object.values(fields)) {
      for (const { message } of errors) {
        if (message) {
          messages.push(message)
        }
      }
    }
    ElMessage.error({ message: messages.join('\n') })
  }
}
23 |
--------------------------------------------------------------------------------
/front-end/src/utils/request.ts:
--------------------------------------------------------------------------------
1 | import axios from 'axios'
2 | import { ElMessage } from 'element-plus'
3 | import router from '@/router'
4 |
const BASE_URL = import.meta.env.VITE_API_PREFIX

/**
 * Shared axios instance used for every back-end call.
 *
 * - `baseURL` comes from `VITE_API_PREFIX`.
 * - `timeout` is 600000 ms (10 minutes).
 *
 * Response interceptors:
 * - success: resolves with `res.data.result`, the payload inside the back-end envelope;
 * - error response carrying an envelope: shows `msg` as a warning when `code !== 1`, navigates
 *   to `/login` when `code === 10012`, and rejects with the envelope's `result`;
 * - cancelled request: rejects with the axios error, no message shown;
 * - no response (network error, timeout) or a non-JSON body: shows the axios error message
 *   and rejects with the axios error itself.
 */
export const request = axios.create({
  baseURL: BASE_URL,
  timeout: 600000
})
request.interceptors.response.use(
  (res) => res.data.result,
  (error) => {
    // Cancelled on purpose by the caller: nothing to report to the user.
    if (axios.isCancel(error)) {
      return Promise.reject(error)
    }
    const data = error?.response?.data
    // `response` is undefined for network failures/timeouts, and `data` may be a non-JSON body
    // (e.g. an HTML error page from a proxy). Reading `response.data.code` without this guard
    // throws a TypeError and hides the real error.
    if (!data || typeof data !== 'object') {
      ElMessage.warning({ message: error?.message || 'Request failed' })
      return Promise.reject(error)
    }
    if (data.code !== 1) {
      ElMessage.warning({ message: data.msg })
    }
    // NOTE(review): 10012 is the back-end code that forces a re-login (presumably "not logged
    // in / token expired") — confirm against the back-end error codes.
    if (data.code === 10012) {
      router.push('/login')
    }
    return Promise.reject(data.result)
  }
)
24 |
--------------------------------------------------------------------------------
/front-end/src/views/chat/chat-view.vue:
--------------------------------------------------------------------------------
1 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
AI助手
159 |
160 |
161 |
170 |
171 |
172 | 创建会话
178 |
179 |
180 |
181 |
182 |
183 |
184 |
211 |
212 |
213 |
214 |
215 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
235 | 上传文档
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 | 上传文本文件
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
386 |
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/markdown-message.vue:
--------------------------------------------------------------------------------
1 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
22 |
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/message-input.vue:
--------------------------------------------------------------------------------
1 |
26 |
27 |
28 |
52 |
53 |
54 |
77 |
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/message-row.vue:
--------------------------------------------------------------------------------
1 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
33 |
34 |
35 |
36 |
37 |
38 |
43 |
44 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
113 |
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/session-item.vue:
--------------------------------------------------------------------------------
1 |
16 |
17 |
18 |
19 |
20 |
21 |
{{ session.name }}
22 |
23 |
24 |
{{ session.messages ? session.messages.length : 0 }}条对话
25 |
{{ session.editedTime }}
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
141 |
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/text-loading.vue:
--------------------------------------------------------------------------------
1 |
2 |
8 |
9 |
10 |
51 |
--------------------------------------------------------------------------------
/front-end/src/views/chat/store/chat-store.ts:
--------------------------------------------------------------------------------
1 | import { defineStore } from 'pinia'
2 | import { ref } from 'vue'
3 | import type { AiSessionDto } from '@/apis/__generated/model/dto'
4 | import { api } from '@/utils/api-instance'
5 | import type { AiMessageInput, AiSessionInput } from '@/apis/__generated/model/static'
6 | import { ElMessageBox } from 'element-plus'
7 |
/**
 * A chat session as kept in the store: the `id`, `name` and `editedTime`
 * fields of the session fetcher DTO, plus the session's messages.
 */
export type AiSession = Pick<
  AiSessionDto['AiSessionRepository/FETCHER'],
  'id' | 'name' | 'editedTime'
> & {
  messages: AiMessage[]
}

/**
 * A chat message as kept in the store, plus a string `id`.
 * NOTE(review): `Pick` is written without type arguments here, which does not
 * type-check. The otherwise-unused `AiMessageInput` import suggests this was
 * `Pick<AiMessageInput, ...>` — restore the intended keys.
 */
export type AiMessage = Pick & {
  id: string
}
/**
 * Pinia store holding the chat session list and the currently active session,
 * with actions that call the session/message APIs and keep local state in sync.
 */
export const useChatStore = defineStore('ai-chat', () => {
  // Edit-mode flag; reset to false after handleUpdateSession saves
  const isEdit = ref(false)
  // The session currently selected in the UI (undefined when none)
  const activeSession = ref<AiSession>()
  // All known sessions; newly created sessions are inserted at the front
  const sessionList = ref<AiSession[]>([])

  /**
   * Creates a session on the server, reloads it by the returned id, puts it at
   * the front of the list and makes it active.
   */
  const handleCreateSession = async (session: AiSessionInput) => {
    const res = await api.aiSessionController.save({ body: session })
    const sessionRes = await api.aiSessionController.findById({ id: res })
    sessionList.value.unshift(sessionRes)
    activeSession.value = sessionList.value[0]
  }

  /**
   * Deletes a session on the server and removes it from the list. The session
   * that moves into the freed slot becomes active, or the new last session if
   * the deleted one was last (undefined when the list becomes empty).
   */
  const handleDeleteSession = async (session: AiSession) => {
    await api.aiSessionController.delete({ body: [session.id] })
    const index = sessionList.value.findIndex((value) => value.id === session.id)
    // Not in the local list: nothing to remove (splice(-1, 1) would drop the last session)
    if (index === -1) {
      return
    }
    sessionList.value.splice(index, 1)
    if (index == sessionList.value.length) {
      activeSession.value = sessionList.value[index - 1]
    } else {
      activeSession.value = sessionList.value[index]
    }
  }

  /** Saves the active session (e.g. after renaming) and leaves edit mode. */
  const handleUpdateSession = async () => {
    if (!activeSession.value) {
      return
    }
    await api.aiSessionController.save({
      body: { ...activeSession.value }
    })
    isEdit.value = false
  }

  /**
   * After user confirmation, deletes all messages of a session on the server,
   * then reloads the session and replaces it in the list. ElMessageBox.confirm
   * rejects when the user cancels, which aborts the whole action.
   */
  const handleClearMessage = async (sessionId: string) => {
    await ElMessageBox.confirm('是否清空会话记录?', '提示')
    await api.aiMessageController.deleteHistory({ sessionId })
    const refreshed = await api.aiSessionController.findById({ id: sessionId })
    activeSession.value = refreshed
    const index = sessionList.value.findIndex((value) => value.id === sessionId)
    // Only replace when present; index -1 would create a bogus "-1" property on the array
    if (index !== -1) {
      sessionList.value[index] = refreshed
    }
  }

  return {
    isEdit,
    activeSession,
    sessionList,
    handleUpdateSession,
    handleCreateSession,
    handleDeleteSession,
    handleClearMessage
  }
})
70 |
--------------------------------------------------------------------------------
/front-end/src/views/code/analyze/analyze-result-view.vue:
--------------------------------------------------------------------------------
1 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
58 |
--------------------------------------------------------------------------------
/front-end/src/views/code/analyze/collapse-title.vue:
--------------------------------------------------------------------------------
1 |
10 |
11 |
12 |
13 |
14 |
{{ title }}
15 |
16 |
17 |
18 |
19 |
40 |
--------------------------------------------------------------------------------
/front-end/src/views/login/login-view.vue:
--------------------------------------------------------------------------------
1 |
42 |
43 |
44 |
![背景图片]()
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
AI助手
53 |
构建你的AI助手
54 |
55 |
56 |
快速开始
57 |
登录你的账号
58 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 | 登录
75 | router.push('/register')"
81 | >
82 | 注册
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
160 |
--------------------------------------------------------------------------------
/front-end/src/views/login/register-view.vue:
--------------------------------------------------------------------------------
1 |
44 |
45 |
46 |
![背景图片]()
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
AI助手
55 |
构建你的AI助手
56 |
57 |
58 |
快速开始
59 |
创建你的账号
60 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 | 注册
78 |
79 |
80 | 登录
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
169 |
--------------------------------------------------------------------------------
/front-end/tsconfig.app.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "@vue/tsconfig/tsconfig.dom.json",
3 | "include": ["env.d.ts", "src/**/*", "src/**/*.vue"],
4 | "exclude": ["src/**/__tests__/*"],
5 | "compilerOptions": {
6 | "composite": true,
7 | "noEmit": true,
8 | "baseUrl": ".",
9 | "paths": {
10 | "@/*": ["./src/*"]
11 | },
12 | "types": ["element-plus/global"]
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/front-end/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "files": [],
3 | "references": [
4 | {
5 | "path": "./tsconfig.node.json"
6 | },
7 | {
8 | "path": "./tsconfig.app.json"
9 | }
10 | ]
11 | }
12 |
--------------------------------------------------------------------------------
/front-end/tsconfig.node.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "@tsconfig/node18/tsconfig.json",
3 | "include": [
4 | "vite.config.*",
5 | "vitest.config.*",
6 | "cypress.config.*",
7 | "nightwatch.conf.*",
8 | "playwright.config.*"
9 | ],
10 | "compilerOptions": {
11 | "composite": true,
12 | "noEmit": true,
13 | "module": "ESNext",
14 | "moduleResolution": "Bundler",
15 | "types": ["node"]
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/front-end/vite.config.ts:
--------------------------------------------------------------------------------
1 | import { fileURLToPath, URL } from 'node:url'
2 |
3 | import { defineConfig } from 'vite'
4 | import vue from '@vitejs/plugin-vue'
5 | import vueJsx from '@vitejs/plugin-vue-jsx'
6 |
/** `@` imports resolve to this project's `src` directory. */
const srcDir = fileURLToPath(new URL('./src', import.meta.url))

/**
 * Dev-server proxy: requests under `/api` go to the backend on localhost:9902
 * with the `/api` prefix removed.
 */
const apiProxy = {
  '/api': {
    target: 'http://localhost:9902',
    changeOrigin: true,
    rewrite: (path: string) => path.replace(/^\/api/, '')
  }
}

// https://vitejs.dev/config/
export default defineConfig({
  plugins: [vue(), vueJsx()],
  resolve: { alias: { '@': srcDir } },
  server: {
    host: '0.0.0.0',
    port: 5177,
    proxy: apiProxy
  }
})
27 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | io.github.qifan777
8 | dive-into-spring-ai
9 | 1.0-SNAPSHOT
10 |
11 |
12 | spring-boot-starter-parent
13 | org.springframework.boot
14 |
15 | 3.2.1
16 |
17 |
18 |
19 | 17
20 | 17
21 | UTF-8
22 | 1.0.0
23 | 1.0.0.1
24 | 0.1.9
25 | 0.8.134
26 | 5.8.25
27 | 1.37.0
28 | UTF-8
29 | UTF-8
30 |
31 |
32 |
33 | com.squareup.okhttp3
34 | okhttp
35 |
36 |
37 | com.github.javaparser
38 | javaparser-symbol-solver-core
39 | 3.26.2
40 |
41 |
42 | com.alibaba
43 | easyexcel
44 | 3.3.4
45 |
46 |
47 | commons-io
48 | commons-io
49 |
50 |
51 |
52 |
53 | org.springframework.boot
54 | spring-boot-starter-data-neo4j
55 |
56 |
57 | org.springframework.ai
58 | spring-ai-starter-vector-store-redis
59 |
60 |
61 | redis.clients
62 | jedis
63 |
64 |
65 | org.springframework.ai
66 | spring-ai-advisors-vector-store
67 |
68 |
69 | org.springframework.ai
70 | spring-ai-tika-document-reader
71 |
72 |
73 | jimmer-spring-boot-starter
74 | org.babyfish.jimmer
75 | ${jimmer.version}
76 |
77 |
78 | io.github.qifan777
79 | spring-boot-starter-oss
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 | com.alibaba.cloud.ai
104 | spring-ai-alibaba-starter-dashscope
105 | ${spring-ai-alibaba.version}
106 |
107 |
108 | org.springframework.boot
109 | spring-boot-starter-actuator
110 |
111 |
112 | org.springframework.boot
113 | spring-boot-starter-test
114 | test
115 |
116 |
117 | org.springframework
118 | spring-test
119 |
120 |
121 | mysql-connector-j
122 | com.mysql
123 |
124 |
125 | hutool-core
126 | cn.hutool
127 | ${hutool.version}
128 |
129 |
130 | cn.dev33
131 | sa-token-redis-jackson
132 | ${sa-token.version}
133 |
134 |
135 | cn.dev33
136 | sa-token-spring-boot3-starter
137 | ${sa-token.version}
138 |
139 |
140 |
141 |
142 |
143 | org.springframework.ai
144 | spring-ai-bom
145 | ${spring-ai.version}
146 | pom
147 | import
148 |
149 |
150 | io.github.qifan777
151 | uni-ai-bom
152 | ${uni-ai.version}
153 | pom
154 | import
155 |
156 |
157 |
158 |
159 |
160 | spring-milestones
161 | Spring Milestones
162 | https://repo.spring.io/milestone
163 |
164 | false
165 |
166 |
167 |
168 |
169 |
170 |
171 | org.springframework.boot
172 | spring-boot-maven-plugin
173 |
174 |
175 | org.apache.maven.plugins
176 | maven-compiler-plugin
177 |
178 |
179 |
180 | jimmer-apt
181 | org.babyfish.jimmer
182 | ${jimmer.version}
183 |
184 |
185 | lombok
186 | org.projectlombok
187 | ${lombok.version}
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
--------------------------------------------------------------------------------
/src/main/dto/AiMessage.dto:
--------------------------------------------------------------------------------
1 | export io.github.qifan777.knowledge.ai.message.AiMessage
2 |
// Input DTO for a chat message (saved by AiMessageController.save, also embedded in AiMessageWrapper).
input AiMessageInput{
    // scalar properties of AiMessage (Jimmer #allScalars macro)
    #allScalars(AiMessage)
    // the owning session, exposed by id only (read via getSessionId() in AiMessageController)
    id(session)
}
--------------------------------------------------------------------------------
/src/main/dto/AiSession.dto:
--------------------------------------------------------------------------------
1 | export io.github.qifan777.knowledge.ai.session.AiSession
2 |
// Input DTO for saving a chat session.
input AiSessionInput{
    // optional id: presumably absent when creating and present when updating — confirm in AiSessionController.save
    id?
    name
}
--------------------------------------------------------------------------------
/src/main/dto/User.dto:
--------------------------------------------------------------------------------
1 | export io.github.qifan777.knowledge.user.User
2 |
// Login input: phone number and password.
input UserLoginInput{
    phone
    password
}
// Registration input: same fields as the login input (phone number and password).
input UserRegisterInput{
    phone
    password
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ServerApplication.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge;
2 |
3 | import io.github.qifan777.knowledge.infrastructure.code.CodeAssistantProperties;
4 | import org.babyfish.jimmer.client.EnableImplicitApi;
5 | import org.springframework.boot.SpringApplication;
6 | import org.springframework.boot.autoconfigure.SpringBootApplication;
7 | import org.springframework.boot.context.properties.EnableConfigurationProperties;
8 |
9 | @SpringBootApplication
10 | @EnableImplicitApi
11 | @EnableConfigurationProperties(CodeAssistantProperties.class)
12 | public class ServerApplication {
13 | public static void main(String[] args) {
14 | SpringApplication.run(ServerApplication.class, args);
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/agent/AbstractAgent.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.agent;
2 |
3 | import lombok.extern.slf4j.Slf4j;
4 | import org.springframework.context.annotation.Description;
5 | import org.springframework.util.StringUtils;
6 |
7 | import java.util.Arrays;
8 | import java.util.List;
9 |
10 | @Slf4j
11 | public abstract class AbstractAgent {
12 |
13 | /**
14 | * 获取指定的function bean名称
15 | *
16 | * @return Function Call名称列表
17 | */
18 | public String[] getFunctions(Class>... clazz) {
19 | List> classList = Arrays.stream(clazz).filter(aClass -> aClass.isAnnotationPresent(Description.class)).toList();
20 | String[] names = new String[classList.size()];
21 | classList.stream().map(aClass -> StringUtils.uncapitalize(aClass.getSimpleName())).toList().toArray(names);
22 | return names;
23 | }
24 | /**
25 | * 获取内嵌的Function Call也就是Agent的Tools
26 | *
27 | * @return Function Call名称列表
28 | */
29 | public String[] getAgentFunctions(Class> clazz) {
30 | List> classList = Arrays.stream(clazz.getClasses()).filter(aClass -> aClass.isAnnotationPresent(Description.class)).toList();
31 | String[] names = new String[classList.size()];
32 | classList.stream().map(aClass -> StringUtils.uncapitalize(this.getClass().getSimpleName()) + "." + aClass.getSimpleName()).toList().toArray(names);
33 | return names;
34 | }
35 |
36 | }
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/agent/Agent.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.agent;
2 |
3 | import org.springframework.stereotype.Component;
4 |
5 | import java.lang.annotation.ElementType;
6 | import java.lang.annotation.Retention;
7 | import java.lang.annotation.RetentionPolicy;
8 | import java.lang.annotation.Target;
9 |
10 | @Component
11 | @Retention(RetentionPolicy.RUNTIME)
12 | @Target(ElementType.TYPE)
13 | public @interface Agent {
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/agent/chronologist/Chronologist.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.agent.chronologist;
2 |
3 | import com.fasterxml.jackson.annotation.JsonProperty;
4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription;
5 | import io.github.qifan777.knowledge.ai.agent.AbstractAgent;
6 | import io.github.qifan777.knowledge.ai.agent.Agent;
7 | import lombok.AllArgsConstructor;
8 | import org.springframework.ai.chat.client.ChatClient;
9 | import org.springframework.ai.chat.model.ChatModel;
10 | import org.springframework.context.annotation.Description;
11 | import org.springframework.stereotype.Component;
12 |
13 | import java.time.LocalDate;
14 | import java.time.LocalDateTime;
15 | import java.time.format.DateTimeFormatter;
16 | import java.util.function.Function;
17 |
18 | @Agent
19 | @Description(value = "回答用户有关于日期、时间的提问")
20 | @AllArgsConstructor
21 | public class Chronologist extends AbstractAgent implements Function {
22 | private final String SYSTEM = """
23 | 你是一个专业的编年史学家,可以回答有关时间的问题。
24 | 您还可以执行各种与时间相关的任务,如转换和格式化。
25 | """;
26 | private final ChatModel chatModel;
27 |
28 |
29 | @Override
30 | public String apply(Request request) {
31 | return ChatClient.create(chatModel)
32 | .prompt()
33 | .system(SYSTEM)
34 | .user(request.query)
35 | .toolNames(getAgentFunctions(this.getClass()))
36 | .call()
37 | .content();
38 | }
39 |
40 | public record Request(
41 | @JsonProperty(required = true) @JsonPropertyDescription(value = "用户原始的提问") String query) {
42 | }
43 |
44 | @Component
45 | @Description("获取当前的时间,格式是 HH:mm:ss")
46 | public static class CurrentTime implements Function {
47 | @Override
48 | public String apply(Request request) {
49 | LocalDateTime currentDate = LocalDateTime.now();
50 | return currentDate.toLocalTime().format(DateTimeFormatter.ofPattern("HH:mm:ss"));
51 |
52 | }
53 |
54 | public record Request() {
55 | }
56 | }
57 |
58 | @Component
59 | @Description("获取当前的日期,格式是 yyyy-MM-dd")
60 | public static class CurrentDate implements Function {
61 | @Override
62 | public String apply(Request request) {
63 | LocalDate currentDate = LocalDate.now();
64 | return currentDate.format(DateTimeFormatter.ofPattern("yyyy-MM-dd"));
65 | }
66 |
67 | public record Request() {
68 | }
69 | }
70 | }
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/agent/computer/ComputerAssistant.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.agent.computer;
2 |
3 | import com.fasterxml.jackson.annotation.JsonProperty;
4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription;
5 | import io.github.qifan777.knowledge.ai.agent.AbstractAgent;
6 | import io.github.qifan777.knowledge.ai.agent.Agent;
7 | import lombok.AllArgsConstructor;
8 | import org.springframework.ai.chat.client.ChatClient;
9 | import org.springframework.ai.chat.model.ChatModel;
10 | import org.springframework.context.annotation.Description;
11 |
12 | import java.util.function.Function;
13 |
14 | @Agent
15 | @Description("提供关于当前主机的cpu,文件,文件夹相关问题的有用回答")
16 | @AllArgsConstructor
17 | public class ComputerAssistant extends AbstractAgent implements Function {
18 | private final ChatModel chatModel;
19 |
20 | @Override
21 | public String apply(Request request) {
22 | return ChatClient.create(chatModel)
23 | .prompt()
24 | .toolNames(getFunctions(CpuAnalyzer.class, DirectoryReader.class))
25 | .user(request.query())
26 | .call()
27 | .content();
28 | }
29 |
30 | public record Request(
31 | @JsonProperty(required = true) @JsonPropertyDescription(value = "用户原始的提问") String query) {
32 | }
33 |
34 |
35 | }
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/agent/computer/CpuAnalyzer.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.agent.computer;
2 |
3 | import org.springframework.context.annotation.Description;
4 | import org.springframework.stereotype.Component;
5 |
6 | import java.util.function.Function;
7 |
8 | @Component
9 | @Description("读取CPU的数量")
10 | public class CpuAnalyzer implements Function {
11 | @Override
12 | public Integer apply(Request request) {
13 | return Runtime.getRuntime().availableProcessors();
14 | }
15 |
16 | public record Request() {
17 | }
18 | }
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/agent/computer/DirectoryReader.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.agent.computer;
2 |
3 | import com.fasterxml.jackson.annotation.JsonProperty;
4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription;
5 | import org.springframework.context.annotation.Description;
6 | import org.springframework.stereotype.Component;
7 |
8 | import java.io.File;
9 | import java.util.ArrayList;
10 | import java.util.Arrays;
11 | import java.util.List;
12 | import java.util.function.Function;
13 |
14 | @Component
15 | @Description("读取用户给定的文件夹,列出文件夹下的所有文件")
16 | public class DirectoryReader implements Function {
17 | @Override
18 | public String apply(Request request) {
19 | File f = new File(request.path);
20 | List out = new ArrayList<>();
21 | if (f.exists()) {
22 | String[] list = f.list();
23 | if (list != null) {
24 | out = Arrays.asList(list);
25 | }
26 | }
27 | return String.join(",", out);
28 | }
29 |
30 | public record Request(
31 | @JsonProperty(required = true) @JsonPropertyDescription("本机文件夹的绝对路径") String path
32 | ) {
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/document/DocumentController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.document;
2 |
3 | import lombok.AllArgsConstructor;
4 | import lombok.SneakyThrows;
5 | import org.springframework.ai.document.Document;
6 | import org.springframework.ai.reader.tika.TikaDocumentReader;
7 | import org.springframework.ai.transformer.splitter.TokenTextSplitter;
8 | import org.springframework.ai.vectorstore.VectorStore;
9 | import org.springframework.core.io.InputStreamResource;
10 | import org.springframework.web.bind.annotation.PostMapping;
11 | import org.springframework.web.bind.annotation.RequestMapping;
12 | import org.springframework.web.bind.annotation.RequestParam;
13 | import org.springframework.web.bind.annotation.RestController;
14 | import org.springframework.web.multipart.MultipartFile;
15 |
16 | import java.util.List;
17 |
18 | @RequestMapping("document")
19 | @RestController
20 | @AllArgsConstructor
21 | public class DocumentController {
22 | private final VectorStore vectorStore;
23 |
24 | /**
25 | * 嵌入文件
26 | *
27 | * @param file 待嵌入的文件
28 | * @return 是否成功
29 | */
30 | @SneakyThrows
31 | @PostMapping("embedding")
32 | public Boolean embedding(@RequestParam MultipartFile file) {
33 | // 从IO流中读取文件
34 | TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(file.getInputStream()));
35 | // 将文本内容划分成更小的块
36 | List splitDocuments = new TokenTextSplitter()
37 | .apply(tikaDocumentReader.read());
38 | // 存入向量数据库,这个过程会自动调用embeddingModel,将文本变成向量再存入。
39 | vectorStore.add(splitDocuments);
40 | return true;
41 | }
42 |
43 | }
44 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/AiMessage.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.message;
2 |
3 | import io.github.qifan777.knowledge.ai.session.AiSession;
4 | import io.github.qifan777.knowledge.infrastructure.jimmer.BaseEntity;
5 | import jakarta.validation.constraints.Null;
6 | import lombok.AllArgsConstructor;
7 | import lombok.Data;
8 | import lombok.experimental.Accessors;
9 | import org.babyfish.jimmer.sql.*;
10 | import org.springframework.ai.chat.messages.MessageType;
11 |
12 | import java.util.List;
13 |
/**
 * A stored chat message (Jimmer entity).
 */

@Entity
public interface AiMessage extends BaseEntity {

    /**
     * Message type (user / assistant / system).
     */
    MessageType type();

    /**
     * Text content of the message.
     */
    String textContent();

    /**
     * Attached media, persisted through Jimmer's {@code @Serialized} column.
     * NOTE(review): {@code jakarta.validation.constraints.Null} is a Bean Validation constraint
     * meaning "must be null". Here it seems to be used only to mark the property nullable for
     * Jimmer — confirm it is never enforced by a validator on saved messages.
     * NOTE(review): the element type argument appears to be missing (presumably List of Media).
     */
    @Serialized
    @Null
    List medias();

    /**
     * Id of the owning session, exposed without loading the association (Jimmer {@code @IdView}).
     */
    @IdView
    String sessionId();

    /**
     * Owning session, joined through column {@code ai_session_id}. {@code @OnDissociate(DELETE)}
     * makes Jimmer delete the message rather than orphan it when it is dissociated from its
     * session (e.g. when the session is deleted).
     */
    @ManyToOne
    @JoinColumn(name = "ai_session_id")
    @OnDissociate(DissociateAction.DELETE)
    AiSession session();

    /**
     * Serialized media item. AiMessageChatMemory writes the MIME primary type (e.g. "image") to
     * {@code type} and the media data's string form (read back as a URL) to {@code data}.
     */
    @Data
    @Accessors(chain = true)
    class Media {
        public String type;
        public String data;
    }
}
53 |
54 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/AiMessageChatMemory.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.message;
2 |
3 | import cn.hutool.core.collection.CollectionUtil;
4 | import io.qifan.infrastructure.common.exception.BusinessException;
5 | import lombok.AllArgsConstructor;
6 | import lombok.SneakyThrows;
7 | import org.jetbrains.annotations.NotNull;
8 | import org.springframework.ai.chat.memory.ChatMemory;
9 | import org.springframework.ai.chat.messages.*;
10 | import org.springframework.ai.content.Media;
11 | import org.springframework.http.MediaType;
12 | import org.springframework.stereotype.Service;
13 |
14 | import java.net.URL;
15 | import java.util.ArrayList;
16 | import java.util.List;
17 |
18 | @Service
19 | @AllArgsConstructor
20 | public class AiMessageChatMemory implements ChatMemory {
21 | private final AiMessageRepository messageRepository;
22 |
23 | public static AiMessage toAiMessage(Message message, String sessionId) {
24 | return AiMessageDraft.$.produce(draft -> {
25 | draft.setSessionId(sessionId)
26 | .setTextContent(message.getText())
27 | .setType(message.getMessageType())
28 | .setMedias(new ArrayList<>());
29 | if (message instanceof UserMessage userMessage &&
30 | !CollectionUtil.isEmpty(userMessage.getMedia())) {
31 | List mediaList = userMessage
32 | .getMedia()
33 | .stream()
34 | .map(media -> new AiMessage.Media()
35 | .setType(media.getMimeType().getType())
36 | .setData(media.getData().toString()))
37 | .toList();
38 | draft.setMedias(mediaList);
39 | }
40 | });
41 | }
42 |
43 | public static Message toSpringAiMessage(AiMessage aiMessage) {
44 | List mediaList = new ArrayList<>();
45 | if (!CollectionUtil.isEmpty(aiMessage.medias())) {
46 | mediaList = aiMessage.medias().stream().map(AiMessageChatMemory::toSpringAiMedia).toList();
47 | }
48 | if (aiMessage.type().equals(MessageType.ASSISTANT)) {
49 | return new AssistantMessage(aiMessage.textContent());
50 | }
51 | if (aiMessage.type().equals(MessageType.USER)) {
52 | return UserMessage.builder().text(aiMessage.textContent())
53 | .media(mediaList)
54 | .build();
55 | }
56 | if (aiMessage.type().equals(MessageType.SYSTEM)) {
57 | return new SystemMessage(aiMessage.textContent());
58 | }
59 | throw new BusinessException("不支持的消息类型");
60 | }
61 |
62 | @SneakyThrows
63 | public static Media toSpringAiMedia(AiMessage.Media media) {
64 | return Media.builder().data(new URL(media.getData()))
65 | .mimeType(new MediaType(media.getType()))
66 | .build();
67 | }
68 |
69 | /**
70 | * 不实现,手动前端发起请求保存用户的消息和大模型回复的消息
71 | */
72 | @Override
73 | public void add(@NotNull String conversationId, @NotNull List messages) {
74 | }
75 |
76 |
77 | /**
78 | * 查询会话内的消息最新n条历史记录
79 | *
80 | * @param conversationId 会话id
81 | * @return org.springframework.ai.chat.messages.Message格式的消息
82 | */
83 | @Override
84 | public @NotNull List get(@NotNull String conversationId) {
85 | return messageRepository
86 | // 查询会话内的最新10条消息
87 | .findBySessionId(conversationId, 10)
88 | .stream()
89 | // 转成Message对象
90 | .map(AiMessageChatMemory::toSpringAiMessage)
91 | .toList();
92 | }
93 |
94 | /**
95 | * 清除会话内的消息
96 | *
97 | * @param conversationId 会话id
98 | */
99 | @Override
100 | public void clear(@NotNull String conversationId) {
101 | messageRepository.deleteBySessionId(conversationId);
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/AiMessageController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.message;
2 |
3 | import com.fasterxml.jackson.databind.ObjectMapper;
4 | import io.github.qifan777.knowledge.ai.agent.Agent;
5 | import io.github.qifan777.knowledge.ai.message.dto.AiMessageInput;
6 | import io.github.qifan777.knowledge.ai.message.dto.AiMessageWrapper;
7 | import lombok.AllArgsConstructor;
8 | import lombok.SneakyThrows;
9 | import lombok.extern.slf4j.Slf4j;
10 | import org.springframework.ai.chat.client.ChatClient;
11 | import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
12 | import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor;
13 | import org.springframework.ai.chat.messages.Message;
14 | import org.springframework.ai.chat.messages.UserMessage;
15 | import org.springframework.ai.chat.model.ChatModel;
16 | import org.springframework.ai.chat.model.ChatResponse;
17 | import org.springframework.ai.chat.prompt.PromptTemplate;
18 | import org.springframework.ai.content.Media;
19 | import org.springframework.ai.reader.tika.TikaDocumentReader;
20 | import org.springframework.ai.vectorstore.VectorStore;
21 | import org.springframework.context.ApplicationContext;
22 | import org.springframework.core.io.InputStreamResource;
23 | import org.springframework.http.MediaType;
24 | import org.springframework.http.codec.ServerSentEvent;
25 | import org.springframework.util.CollectionUtils;
26 | import org.springframework.web.bind.annotation.*;
27 | import org.springframework.web.multipart.MultipartFile;
28 | import reactor.core.publisher.Flux;
29 |
30 | import java.util.Map;
31 |
32 | @RequestMapping("message")
33 | @RestController
34 | @AllArgsConstructor
35 | @Slf4j
36 | public class AiMessageController {
37 | private final AiMessageChatMemory chatMemory;
38 | private final ChatModel chatModel;
39 | // private final ImageModel imageModel;
40 | private final VectorStore vectorStore;
41 | private final ObjectMapper objectMapper;
42 | private final AiMessageRepository messageRepository;
43 | private final ApplicationContext applicationContext;
44 |
45 | @DeleteMapping("history/{sessionId}")
46 | public void deleteHistory(@PathVariable String sessionId) {
47 | chatMemory.clear(sessionId);
48 | }
49 |
50 | /**
51 | * 消息保存
52 | *
53 | * @param input 用户发送的消息/AI回复的消息
54 | */
55 | @PostMapping
56 | public void save(@RequestBody AiMessageInput input) {
57 | messageRepository.save(input.toEntity());
58 | }
59 |
60 | // @PostMapping("chat/image")
61 | // public String textToImageChat(@RequestBody AiMessageInput input) {
62 | // return imageModel.call(new ImagePrompt(input.getTextContent())).getResult().getOutput().getUrl();
63 | // }
64 |
65 | /**
66 | * 为了支持文件问答,需要同时接收json(AiMessageWrapper json体)和 MultipartFile(文件)
67 | * Content-Type 从 application/json 修改为 multipart/form-data
68 | * 之前接收请求参数是用@RequestBody, 现在使用@RequestPart 接收json字符串再手动转成AiMessageWrapper.
69 | * SpringMVC的@RequestPart是支持自动将Json字符串转换为Java对象,也就是说可以等效`@RequestBody`,
70 | * 但是由于前端FormData无法设置Part的Content-Type,所以只能手动转json字符串再转成Java对象。
71 | *
72 | * @param input 消息包含文本信息,会话id,多媒体信息(图片语言)。参考src/main/dto/AiMessage.dto
73 | * @param file 文件问答
74 | * @return SSE流
75 | */
76 | @SneakyThrows
77 | @PostMapping(value = "chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
78 | public Flux> chat(@RequestPart String input, @RequestPart(required = false) MultipartFile file) {
79 | AiMessageWrapper aiMessageWrapper = objectMapper.readValue(input, AiMessageWrapper.class);
80 | String[] functionBeanNames = new String[0];
81 | // 如果启用Agent则获取Agent的bean
82 | if (aiMessageWrapper.getParams().getEnableAgent()) {
83 | // 获取带有Agent注解的bean
84 | Map beansWithAnnotation = applicationContext.getBeansWithAnnotation(Agent.class);
85 | functionBeanNames = new String[beansWithAnnotation.size()];
86 | functionBeanNames = beansWithAnnotation.keySet().toArray(functionBeanNames);
87 | }
88 | return ChatClient.create(chatModel).prompt()
89 | // 启用文件问答
90 | .system(promptSystemSpec -> useFile(promptSystemSpec, file))
91 | .user(promptUserSpec -> toPrompt(promptUserSpec, aiMessageWrapper.getMessage()))
92 | // agent列表
93 | .toolNames(functionBeanNames)
94 | .advisors(advisorSpec -> {
95 | // 使用历史消息
96 | useChatHistory(advisorSpec, aiMessageWrapper.getMessage().getSessionId());
97 | // 使用向量数据库
98 | useVectorStore(advisorSpec, aiMessageWrapper.getParams().getEnableVectorStore());
99 | })
100 | .stream()
101 | .chatResponse()
102 | .map(chatResponse -> ServerSentEvent.builder(toJson(chatResponse))
103 | // 和前端监听的事件相对应
104 | .event("message")
105 | .build());
106 | }
107 |
108 | @SneakyThrows
109 | public String toJson(ChatResponse response) {
110 | return objectMapper.writeValueAsString(response);
111 | }
112 |
113 | public void toPrompt(ChatClient.PromptUserSpec promptUserSpec, AiMessageInput input) {
114 | // AiMessageInput转成Message
115 | Message message = AiMessageChatMemory.toSpringAiMessage(input.toEntity());
116 | if (message instanceof UserMessage userMessage &&
117 | !CollectionUtils.isEmpty(userMessage.getMedia())) {
118 | // 用户发送的图片/语言
119 | Media[] medias = new Media[userMessage.getMedia().size()];
120 | promptUserSpec.media(userMessage.getMedia().toArray(medias));
121 | }
122 | // 用户发送的文本
123 | promptUserSpec.text(message.getText());
124 | }
125 |
126 | public void useChatHistory(ChatClient.AdvisorSpec advisorSpec, String sessionId) {
127 | // 1. 如果需要存储会话和消息到数据库,自己可以实现ChatMemory接口,这里使用自己实现的AiMessageChatMemory,数据库存储。
128 | // 2. 传入会话id,MessageChatMemoryAdvisor会根据会话id去查找消息。
129 | // 3. 只需要携带最近10条消息
130 | // MessageChatMemoryAdvisor会在消息发送给大模型之前,从ChatMemory中获取会话的历史消息,然后一起发送给大模型。
131 | advisorSpec.advisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId(sessionId).build());
132 | }
133 |
134 | public void useVectorStore(ChatClient.AdvisorSpec advisorSpec, Boolean enableVectorStore) {
135 | if (!enableVectorStore) return;
136 | // question_answer_context是一个占位符,会替换成向量数据库中查询到的文档。QuestionAnswerAdvisor会替换。
137 | String promptWithContext = """
138 | {query}
139 | 下面是上下文信息
140 | ---------------------
141 | {question_answer_context}
142 | ---------------------
143 | 给定的上下文和提供的历史信息,而不是事先的知识,回复用户的意见。如果答案不在上下文中,告诉用户你不能回答这个问题。
144 | """;
145 | advisorSpec.advisors(QuestionAnswerAdvisor.builder(vectorStore)
146 | .promptTemplate(new PromptTemplate(promptWithContext))
147 | .build());
148 | }
149 |
150 | @SneakyThrows
151 | public void useFile(ChatClient.PromptSystemSpec spec, MultipartFile file) {
152 | if (file == null) return;
153 | String content = new TikaDocumentReader(new InputStreamResource(file.getInputStream())).get().get(0).getText();
154 | Message message = new PromptTemplate("""
155 | 已下内容是额外的知识,在你回答问题时可以参考下面的内容
156 | ---------------------
157 | {context}
158 | ---------------------
159 | """)
160 | .createMessage(Map.of("context", content));
161 | spec.text(message.getText());
162 | }
163 |
164 | }
165 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/AiMessageRepository.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.message;
2 |
3 | import org.babyfish.jimmer.spring.repository.JRepository;
4 |
5 | import java.util.List;
6 |
7 | public interface AiMessageRepository extends JRepository {
8 | AiMessageTable t = AiMessageTable.$;
9 |
10 | default List findBySessionId(String sessionId, int lastN) {
11 | return sql()
12 | .createQuery(t)
13 | .where(t.sessionId().eq(sessionId))
14 | .orderBy(t.createdTime().asc())
15 | .select(t)
16 | .limit(lastN)
17 | .execute();
18 | }
19 |
20 | default Integer deleteBySessionId(String sessionId) {
21 | return sql().createDelete(t)
22 | .where(t.sessionId().eq(sessionId))
23 | .execute();
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/MessageInputWrapper.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.message;
2 |
3 | import lombok.Data;
4 |
/**
 * Holder that wraps a single AiMessage entity; Lombok @Data generates its accessors.
 * NOTE(review): no usage is visible here; the chat endpoint deserializes dto.AiMessageWrapper
 * instead. Confirm whether this class is still needed.
 */
@Data
public class MessageInputWrapper {
    // The wrapped message entity.
    AiMessage message;


}
11 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/dto/AiMessageParams.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.message.dto;
2 |
3 | import lombok.Data;
4 |
/**
 * Feature switches sent with a chat request (part of AiMessageWrapper).
 * Both fields are boxed Booleans and may be null, e.g. when the client omits them.
 */
@Data
public class AiMessageParams {
    // When true, AiMessageController#chat adds a QuestionAnswerAdvisor backed by the vector store.
    Boolean enableVectorStore;
    // When true, AiMessageController#chat exposes every @Agent bean to the model as a tool.
    Boolean enableAgent;
}
10 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/dto/AiMessageWrapper.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.message.dto;
2 |
3 | import lombok.Data;
4 |
/**
 * JSON payload of the chat endpoint; AiMessageController#chat deserializes it from the "input" form part.
 */
@Data
public class AiMessageWrapper {
    // The user's message: text, session id and media (see src/main/dto/AiMessage.dto).
    AiMessageInput message;
    // Feature switches (vector store / agents) for this request.
    AiMessageParams params;
}
10 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/session/AiSession.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.session;
2 |
3 | import io.github.qifan777.knowledge.ai.message.AiMessage;
4 | import io.github.qifan777.knowledge.infrastructure.jimmer.BaseEntity;
5 | import org.babyfish.jimmer.sql.Entity;
6 | import org.babyfish.jimmer.sql.OneToMany;
7 | import org.babyfish.jimmer.sql.OrderedProp;
8 |
9 | import java.util.List;
10 |
/**
 * Chat session (conversation) entity.
 */
@Entity
public interface AiSession extends BaseEntity {

    /**
     * Session name.
     */
    String name();

    /**
     * Messages of this session (one-to-many, mapped by AiMessage.session), ordered by creation time ascending.
     */
    @OneToMany(mappedBy = "session", orderedProps = @OrderedProp(value = "createdTime"))
    List messages();
}
29 |
30 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/session/AiSessionController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.session;
2 |
3 | import io.github.qifan777.knowledge.ai.session.dto.AiSessionInput;
4 | import io.qifan.infrastructure.common.exception.BusinessException;
5 | import lombok.AllArgsConstructor;
6 | import org.babyfish.jimmer.client.FetchBy;
7 | import org.springframework.web.bind.annotation.*;
8 |
9 | import java.util.List;
10 |
11 | @RequestMapping("session")
12 | @RestController
13 | @AllArgsConstructor
14 | public class AiSessionController {
15 | private final AiSessionRepository sessionRepository;
16 |
17 | /**
18 | * 根据id查询会话
19 | * @param id 会话id
20 | * @return 会话信息
21 | */
22 | @GetMapping("{id}")
23 | public @FetchBy(value = "FETCHER", ownerType = AiSessionRepository.class) AiSession findById(@PathVariable String id) {
24 | return sessionRepository.findById(id, AiSessionRepository.FETCHER).orElseThrow(() -> new BusinessException("会话不存在"));
25 | }
26 |
27 | /**
28 | * 保存会话
29 | * @param input 会话dto参考src/main/dto/AiSession.dto
30 | * @return 创建后的id
31 | */
32 | @PostMapping("save")
33 | public String save(@RequestBody AiSessionInput input) {
34 | return sessionRepository.save(input.toEntity()).id();
35 | }
36 |
37 | /**
38 | * 查询当前登录用户的会话
39 | *
40 | * @return 会话列表
41 | */
42 | @GetMapping("user")
43 | public List<@FetchBy(value = "FETCHER", ownerType = AiSessionRepository.class) AiSession> findByUser() {
44 | return sessionRepository.findByUser();
45 | }
46 |
47 | /**
48 | * 批量删除会话
49 | * @param ids 会话id列表
50 | */
51 | @DeleteMapping
52 | public void delete(@RequestBody List ids) {
53 | sessionRepository.deleteByIds(ids);
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/session/AiSessionRepository.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.ai.session;
2 |
3 | import cn.dev33.satoken.stp.StpUtil;
4 | import io.github.qifan777.knowledge.ai.message.AiMessageFetcher;
5 | import org.babyfish.jimmer.spring.repository.JRepository;
6 |
7 | import java.util.List;
8 |
9 | public interface AiSessionRepository extends JRepository {
10 | AiSessionTable t = AiSessionTable.$;
11 | AiSessionFetcher FETCHER = AiSessionFetcher.$.allScalarFields()
12 | .messages(AiMessageFetcher.$.allScalarFields().sessionId());
13 |
14 | default List findByUser() {
15 | return sql().createQuery(t)
16 | .where(t.creatorId().eq(StpUtil.getLoginIdAsString()))
17 | .select(t.fetch(FETCHER))
18 | .execute();
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/CodeAssistantAgent.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code;
2 |
3 | import com.fasterxml.jackson.annotation.JsonProperty;
4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription;
5 | import io.github.qifan777.knowledge.ai.agent.AbstractAgent;
6 | import io.github.qifan777.knowledge.ai.agent.Agent;
7 | import io.github.qifan777.knowledge.code.analyze.AnalyzeFunction;
8 | import io.github.qifan777.knowledge.code.arthas.ArthasFunction;
9 | import lombok.AllArgsConstructor;
10 | import org.springframework.ai.chat.client.ChatClient;
11 | import org.springframework.ai.chat.model.ChatModel;
12 | import org.springframework.context.annotation.Description;
13 |
14 | import java.util.function.Function;
15 |
16 | @Description("提供有关于Java代码的评审分析,在线诊断异常相关的回答")
17 | @Agent
18 | @AllArgsConstructor
19 | public class CodeAssistantAgent extends AbstractAgent implements Function {
20 | private final ChatModel chatModel;
21 |
22 | @Override
23 | public String apply(Request request) {
24 | return ChatClient.create(chatModel)
25 | .prompt()
26 | .user(request.query())
27 | .toolNames(getFunctions(AnalyzeFunction.class, ArthasFunction.class))
28 | .call()
29 | .content();
30 | }
31 |
32 | public record Request(
33 | @JsonProperty(required = true) @JsonPropertyDescription(value = "用户原始的提问") String query) {
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/analyze/AnalyzeController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.analyze;
2 |
3 | import com.fasterxml.jackson.databind.ObjectMapper;
4 | import lombok.AllArgsConstructor;
5 | import lombok.SneakyThrows;
6 | import org.springframework.http.MediaType;
7 | import org.springframework.http.codec.ServerSentEvent;
8 | import org.springframework.web.bind.annotation.GetMapping;
9 | import org.springframework.web.bind.annotation.RequestMapping;
10 | import org.springframework.web.bind.annotation.RequestParam;
11 | import org.springframework.web.bind.annotation.RestController;
12 | import reactor.core.publisher.Flux;
13 |
14 | @RestController
15 | @RequestMapping("analyze")
16 | @AllArgsConstructor
17 | public class AnalyzeController {
18 | private final AnalyzeFunction analyzeFunction;
19 | private final ObjectMapper objectMapper;
20 |
21 | @GetMapping(produces = MediaType.TEXT_EVENT_STREAM_VALUE)
22 | public Flux> analyzeTask(@RequestParam String path) {
23 | return analyzeFunction.analyze(path)
24 | .map(content -> ServerSentEvent.builder(toJson(content))
25 | .event("message")
26 | .build());
27 | }
28 |
29 | @SneakyThrows
30 | public String toJson(AnalyzeFunction.AnalyzeResult result) {
31 | return objectMapper.writeValueAsString(result);
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/analyze/AnalyzeFunction.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.analyze;
2 |
3 | import cn.hutool.core.io.FileUtil;
4 | import com.fasterxml.jackson.annotation.JsonProperty;
5 | import com.fasterxml.jackson.annotation.JsonPropertyDescription;
6 | import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
7 | import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
8 | import io.github.qifan777.knowledge.code.graph.service.CodeGraphService;
9 | import io.github.qifan777.knowledge.infrastructure.code.CodeAssistantProperties;
10 | import io.github.qifan777.knowledge.infrastructure.code.JavaParserUtils;
11 | import lombok.AllArgsConstructor;
12 | import lombok.Data;
13 | import lombok.SneakyThrows;
14 | import lombok.extern.slf4j.Slf4j;
15 | import org.springframework.ai.chat.model.ChatModel;
16 | import org.springframework.ai.chat.prompt.PromptTemplate;
17 | import org.springframework.context.annotation.Description;
18 | import org.springframework.stereotype.Service;
19 | import reactor.core.publisher.Flux;
20 |
21 | import java.nio.charset.StandardCharsets;
22 | import java.nio.file.Path;
23 | import java.util.List;
24 | import java.util.Map;
25 | import java.util.concurrent.Executor;
26 | import java.util.concurrent.Executors;
27 | import java.util.function.Function;
28 | import java.util.stream.Collectors;
29 |
30 | @Description("评审分析给定的java文件")
31 | @Service
32 | @AllArgsConstructor
33 | @Slf4j
34 | public class AnalyzeFunction implements Function {
35 | private final CodeAssistantProperties properties;
36 | private final CodeGraphService codeGraphService;
37 | // 防止页面多次调用开启多个线程token消耗过多,demo使用单线程
38 | private final Executor executor = Executors.newSingleThreadExecutor();
39 | private final ChatModel chatModel;
40 |
41 | public record Request(@JsonProperty(required = true)
42 | @JsonPropertyDescription("java文件路径如:src/main/java/xxx.java") String path) {
43 | }
44 |
45 | /**
46 | * 重定向到下面的网页进行展示分析结果,由于一个类里面可能有很多个方法,分析的单位是按照方法来分析,因此单独做一个页面来展示分析结果。
47 | *
48 | * @param request the function argument
49 | * @return 重定向链接
50 | */
51 | @Override
52 | public String apply(Request request) {
53 | return "请在下面的网页链接查看评审结果:http://localhost:5177/#/analyze?path=" + request.path;
54 | }
55 |
56 | /**
57 | * 创建一个结果分析流,另起一个线程开启解析java文件,获取主类中的所有方法,然后分析方法调用得到分析结果,推流到前端。
58 | *
59 | * @param filePath java文件路径 如:io.qifan.github777.UserController.java
60 | * @return Flux 结果分析流
61 | */
62 | @SneakyThrows
63 | public Flux analyze(String filePath) {
64 | log.info("正在评审文件:{}", filePath);
65 | return Flux.create(fluxSink -> executor.execute(() -> {
66 | JavaParserUtils.parse(Path.of(properties.getProject().getProjectPath(), "src", "main", "java", filePath))
67 | .getResult()
68 | .map(compilationUnit -> compilationUnit.findAll(ClassOrInterfaceDeclaration.class))
69 | // 只分析主类, 可能java文件中一个类都没有,因此返回的是一个Optional。如果不使用FlatMap,会直接返回一个Optional>,
70 | // 因此需要使用flatMap,是的返回结果变成Optional
71 | .flatMap(classOrInterfaceDeclarations -> classOrInterfaceDeclarations.stream().findFirst())
72 | // 由于类可能是匿名内部类,这边过滤一下。下面可以直接使用get()。当然实际上这个地方的主类类名肯定是存在的,但是为了写法严谨我还是判断了一下
73 | .filter(classOrInterfaceDeclaration -> classOrInterfaceDeclaration.getFullyQualifiedName().isPresent())
74 | // 获取到类名,并遍历所有方法
75 | .map(classOrInterfaceDeclaration -> {
76 | String fullyQualifiedName = classOrInterfaceDeclaration
77 | .getFullyQualifiedName().get();
78 | return classOrInterfaceDeclaration
79 | .getMethods()
80 | .stream()
81 | .map(methodDeclaration -> fullyQualifiedName + "#" + methodDeclaration.getNameAsString())
82 | .toList();
83 | })
84 | .ifPresentOrElse(methodIds -> {
85 | methodIds.forEach(methodId -> {
86 | analyzeMethod(methodId).doOnNext(fluxSink::next).blockLast();
87 | });
88 | }, () -> {
89 | // 如果没有主类(Class Or Interface),则直接分析整个文件,或者不是Java文件(Mapper.xml)
90 | fluxSink.next(analyzeFile(filePath));
91 | });
92 | // 完成
93 | fluxSink.complete();
94 | }));
95 | }
96 |
97 | /**
98 | * 分析单个方法调用
99 | *
100 | * @param methodId 方法ID
101 | * @return 分析结果流
102 | */
103 | public Flux analyzeMethod(String methodId) {
104 | List childMethods = codeGraphService.findChildMethods(methodId);
105 | String prompt = new PromptTemplate("""
106 | 请你根据{methodId}的调用链评审代码,并给出你的改进建议,并且附带修改后的代码片段,用中文回答。
107 | {methodChains}
108 | """)
109 | .createMessage(Map.of("methodChains", childMethods
110 | .stream()
111 | .map(MethodNode::getContent)
112 | .distinct()
113 | .collect(Collectors.joining("\n")),
114 | "methodId", methodId))
115 | .getText();
116 | log.info("评审方法: {}", prompt);
117 | String content = childMethods.stream().filter(m -> m.getId().equals(methodId)).findFirst().orElseThrow().getContent();
118 | return chatModel.stream(prompt).map(response -> new AnalyzeResult(methodId, response, methodId.split("#")[0], content));
119 | }
120 |
121 | public AnalyzeResult analyzeFile(String filePath) {
122 | String fileContent = FileUtil.readString(Path.of(properties.getProject().getProjectPath(), filePath).toFile(), StandardCharsets.UTF_8);
123 | String prompt = new PromptTemplate("""
124 | 请你评审一下该提交文件是否有可以改进的地方,并且附带修改后的代码片段,没有请回答无,用中文回答。
125 | {content}
126 | """)
127 | .createMessage(Map.of("content", fileContent)).getText();
128 | log.info("评审文件提示词: {}", prompt);
129 | String result = chatModel.call(prompt);
130 | log.info("文件分析结果: {}", result);
131 | return new AnalyzeResult(filePath, result, filePath, fileContent);
132 | }
133 |
134 | @Data
135 | @AllArgsConstructor
136 | public static class AnalyzeResult {
137 | private String id;
138 | private String content;
139 | private String fileName;
140 | private String fileContent;
141 | }
142 |
143 |
144 | }
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/arthas/ArthasFunction.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.arthas;
2 |
3 | import com.fasterxml.jackson.annotation.JsonProperty;
4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription;
5 | import com.fasterxml.jackson.databind.JsonNode;
6 | import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
7 | import io.github.qifan777.knowledge.code.graph.service.CodeGraphService;
8 | import io.github.qifan777.knowledge.infrastructure.code.CodeAssistantProperties;
9 | import lombok.AllArgsConstructor;
10 | import lombok.Data;
11 | import lombok.experimental.Accessors;
12 | import lombok.extern.slf4j.Slf4j;
13 | import org.springframework.ai.chat.model.ChatModel;
14 | import org.springframework.ai.chat.prompt.PromptTemplate;
15 | import org.springframework.context.annotation.Description;
16 | import org.springframework.http.HttpEntity;
17 | import org.springframework.http.HttpHeaders;
18 | import org.springframework.http.HttpMethod;
19 | import org.springframework.stereotype.Service;
20 | import org.springframework.web.client.RestTemplate;
21 |
22 | import java.nio.charset.StandardCharsets;
23 | import java.util.Base64;
24 | import java.util.List;
25 | import java.util.Map;
26 | import java.util.function.Function;
27 | import java.util.stream.Collectors;
28 |
29 | @Slf4j
30 | @Service
31 | @AllArgsConstructor
32 | @Description("诊断方法出现异常的原因")
33 | public class ArthasFunction implements Function {
34 | private final CodeGraphService codeGraphService;
35 | private final CodeAssistantProperties properties;
36 | private final ChatModel chatModel;
37 |
38 | public record Request(@JsonProperty(required = true)
39 | @JsonPropertyDescription("类名") String className,
40 | @JsonProperty(required = true)
41 | @JsonPropertyDescription("类名") String methodName) {
42 | }
43 |
44 | @Override
45 | public String apply(Request request) {
46 | String methodId = request.className + "#" + request.methodName;
47 | log.info("监听目标:{}", methodId);
48 | JobResult jobResult = startWatch(request.className, request.methodName);
49 | if (jobResult == null) {
50 | return "无异常信息";
51 | }
52 | String analyzeResult = jobResult
53 | .getBody()
54 | .getResults()
55 | .stream()
56 | .filter(r -> r.getType().equals("tt"))
57 | .findFirst()
58 | .map(result -> {
59 | String methods = codeGraphService.findChildMethods(methodId).stream().map(MethodNode::getContent)
60 | .distinct()
61 | .collect(Collectors.joining("\n"));
62 | TimeFragment timeFragment = result.getTimeFragmentList().get(0);
63 | String content = new PromptTemplate("""
64 | 根据下面提供的内容分析异常原因,回答结果用中文
65 | 方法名称:{methodName}
66 | 方法调用链: {methods}
67 | 方法参数:{params}
68 | 异常信息:{exp}
69 | """)
70 | .createMessage(Map.of("methodName", timeFragment.getMethodName(),
71 | "methods", methods,
72 | "params", timeFragment.getParams().stream().map(param -> param.getObject() == null ? "" : param.getObject().toString()).collect(Collectors.joining("\n")),
73 | "exp", timeFragment.getThrowExp()))
74 | .getText();
75 | log.info("代码诊断prompt: {}", content);
76 | return chatModel.call(content);
77 | })
78 | .orElse("无异常信息");
79 | log.info("诊断结果: {}", analyzeResult);
80 | return analyzeResult;
81 | }
82 |
83 |
84 | public JobResult startWatch(String className, String method) {
85 | CodeAssistantProperties.ArthasProperties arthasProperties = properties.getArthas();
86 | RestTemplate restTemplate = new RestTemplate();
87 | String encode = Base64.getEncoder().encodeToString((arthasProperties.getUsername() + ":" + arthasProperties.getPassword()).getBytes(StandardCharsets.UTF_8));
88 | ArthasRequest arthasRequest = new ArthasRequest().setAction("exec").setCommand("tt -t " + className + " " + method + " -n 1");
89 | HttpHeaders httpHeaders = new HttpHeaders();
90 | httpHeaders.add("Authorization", "Basic " + encode);
91 | HttpEntity requestHttpEntity = new HttpEntity<>(arthasRequest, httpHeaders);
92 | return restTemplate.exchange(arthasProperties.getUrl(), HttpMethod.POST, requestHttpEntity, JobResult.class).getBody();
93 |
94 | }
95 |
96 | @Accessors(chain = true)
97 | @Data
98 | public static class ArthasRequest {
99 | private String action;
100 | private String command;
101 | }
102 |
103 | @Data
104 | public static class JobResult {
105 | @JsonProperty("body")
106 | private Body body;
107 | private String message;
108 | private String sessionId;
109 | private String state;
110 | }
111 |
112 | @Data
113 | public static class Body {
114 | @JsonProperty("command")
115 | private String command;
116 | private int jobId;
117 | private String jobStatus;
118 | @JsonProperty("results")
119 | private List results;
120 | @JsonProperty("timeExpired")
121 | private boolean timeExpired;
122 | private int timeout;
123 | }
124 |
125 | @Data
126 | public static class Result {
127 | private Effect effect;
128 | private int jobId;
129 | private boolean success;
130 | private String type;
131 | @JsonProperty("timeFragmentList")
132 | private List timeFragmentList;
133 | private int statusCode;
134 | private boolean first;
135 | }
136 |
137 | @Data
138 | public static class Effect {
139 | private int classCount;
140 | private int cost;
141 | private int listenerId;
142 | private int methodCount;
143 | }
144 |
145 | @Data
146 | public static class TimeFragment {
147 | private String className;
148 | private double cost;
149 | private int index;
150 | private String methodName;
151 | private String object;
152 | private List params;
153 | @JsonProperty("return")
154 | private boolean isReturn;
155 | private String returnObj;
156 | @JsonProperty("throw")
157 | private boolean isThrow;
158 | @JsonProperty("throwExp")
159 | private String throwExp;
160 | @JsonProperty("timestamp")
161 | private String timestamp;
162 | }
163 |
164 | @Data
165 | public static class Param {
166 | private int expand;
167 | private JsonNode object;
168 | }
169 |
170 | }
171 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/controller/CodeGraphController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.graph.controller;
2 |
3 | import io.github.qifan777.knowledge.code.graph.service.CodeGraphService;
4 | import lombok.AllArgsConstructor;
5 | import lombok.extern.slf4j.Slf4j;
6 | import org.springframework.web.bind.annotation.PostMapping;
7 | import org.springframework.web.bind.annotation.RequestMapping;
8 | import org.springframework.web.bind.annotation.RestController;
9 |
10 | @RestController
11 | @RequestMapping("code/graph")
12 | @AllArgsConstructor
13 | @Slf4j
14 | public class CodeGraphController {
15 | private final CodeGraphService codeGraphService;
16 |
17 | @PostMapping("build")
18 | public String buildGraph() {
19 | codeGraphService.buildGraph();
20 | return "SUCCESS";
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/entity/ClassNode.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.graph.entity;
2 |
3 | import lombok.Data;
4 | import lombok.experimental.Accessors;
5 | import org.springframework.data.neo4j.core.schema.Id;
6 | import org.springframework.data.neo4j.core.schema.Node;
7 | import org.springframework.data.neo4j.core.schema.Relationship;
8 |
9 | import java.util.List;
10 |
/**
 * Neo4j node for a Java class or interface in the code graph (built by CodeGraphBuilder).
 * NOTE(review): Lombok @Data equals/hashCode/toString traverse the relationship lists; on cyclic
 * graphs they can recurse without end, so avoid hashing or logging these nodes.
 */
@Node
@Data
@Accessors(chain = true)
public class ClassNode {
    // Fully qualified class name; used as the node id.
    @Id
    private String id;
    // Simple class name.
    private String name;
    // Source text of the class declaration.
    private String content;
    // Methods declared in this class (outgoing OWNS relationships).
    @Relationship(direction = Relationship.Direction.OUTGOING, type = "OWNS")
    private List ownsMethodNodes;
    // Project classes imported by this class (outgoing IMPORTS relationships).
    @Relationship(direction = Relationship.Direction.OUTGOING, type = "IMPORTS")
    private List importNodes;
}
24 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/entity/MethodNode.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.graph.entity;
2 |
3 | import lombok.Data;
4 | import lombok.experimental.Accessors;
5 | import org.springframework.data.neo4j.core.schema.Id;
6 | import org.springframework.data.neo4j.core.schema.Node;
7 | import org.springframework.data.neo4j.core.schema.Relationship;
8 |
9 | import java.util.List;
10 |
/**
 * Neo4j node for a Java method in the code graph (built by CodeGraphBuilder).
 * NOTE(review): Lombok @Data equals/hashCode/toString traverse usesMethodNodes; recursive or
 * mutually recursive methods form cycles, so avoid hashing or logging these nodes.
 */
@Node
@Data
@Accessors(chain = true)
public class MethodNode {
    // Node id "<fully qualified class name>#<method name>"; overloaded methods share one id.
    @Id
    private String id;
    // Simple method name.
    private String name;
    // Content of the comment attached to the method, or "" when there is none.
    private String comment;
    // Source text used in review/diagnosis prompts.
    private String content;
    // Methods this method calls (outgoing USES relationships).
    @Relationship(direction = Relationship.Direction.OUTGOING, type = "USES")
    private List usesMethodNodes;
}
23 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/repository/ClassNodeRepository.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.graph.repository;
2 |
3 | import io.github.qifan777.knowledge.code.graph.entity.ClassNode;
4 | import org.springframework.data.neo4j.repository.Neo4jRepository;
5 |
/**
 * Spring Data Neo4j repository for {@link ClassNode} entities.
 */
public interface ClassNodeRepository extends Neo4jRepository {
}
8 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/repository/MethodNodeRepository.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.graph.repository;
2 |
3 | import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
4 | import org.springframework.data.neo4j.repository.Neo4jRepository;
5 |
/**
 * Spring Data Neo4j repository for {@link MethodNode} entities.
 */
public interface MethodNodeRepository extends Neo4jRepository {

}
9 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/service/CodeGraphBuilder.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.graph.service;
2 |
3 | import com.github.javaparser.JavaParser;
4 | import com.github.javaparser.ast.ImportDeclaration;
5 | import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
6 | import com.github.javaparser.ast.body.MethodDeclaration;
7 | import com.github.javaparser.ast.comments.Comment;
8 | import com.github.javaparser.ast.expr.AnnotationExpr;
9 | import com.github.javaparser.ast.expr.Expression;
10 | import com.github.javaparser.ast.expr.MethodCallExpr;
11 | import io.github.qifan777.knowledge.code.graph.entity.ClassNode;
12 | import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
13 | import lombok.SneakyThrows;
14 | import lombok.extern.slf4j.Slf4j;
15 | import org.w3c.dom.Document;
16 | import org.w3c.dom.NodeList;
17 |
18 | import javax.xml.parsers.DocumentBuilder;
19 | import javax.xml.parsers.DocumentBuilderFactory;
20 | import javax.xml.transform.OutputKeys;
21 | import javax.xml.transform.Transformer;
22 | import javax.xml.transform.TransformerFactory;
23 | import javax.xml.transform.dom.DOMSource;
24 | import javax.xml.transform.stream.StreamResult;
25 | import java.io.File;
26 | import java.io.StringWriter;
27 | import java.nio.file.Files;
28 | import java.nio.file.Path;
29 | import java.util.*;
30 | import java.util.stream.Stream;
31 |
32 | @Slf4j
33 | public class CodeGraphBuilder {
34 | private final Map mapperSqlMap = new HashMap<>();
35 | private final Map classDeclarationMap = new HashMap<>();
36 | private final Map classNodeMap = new HashMap<>();
37 | private final Map methodNodeMap = new HashMap<>();
38 | private final Path projectPath;
39 | private final JavaParser javaParser;
40 |
/**
 * Creates a graph builder for one project.
 *
 * @param projectPath path of the project directory to scan
 * @param javaParser  parser used to read each Java source file
 */
public CodeGraphBuilder(String projectPath, JavaParser javaParser) {
    this.javaParser = javaParser;
    this.projectPath = Path.of(projectPath);
}
45 |
/**
 * Result of {@code buildGraph()}: every class node and every method node that was built.
 */
public record BuildContext(Collection classNodes, Collection methodNodes) {
}
48 |
49 | /**
50 | * 构建java方法关系图谱,包含类节点和类方法节点,类和类之间的引用关系,类和方法之间的归属关系,方法和方法之间的调用关系
51 | *
52 | * @return 构建好的类节点和类方法节点
53 | */
54 | @SneakyThrows
55 | public BuildContext buildGraph() {
56 | // 构建mybatis sql
57 | buildMapperSqlMap();
58 | try (Stream pathStream = Files.walk(getJavaSourcePath())) {
59 | pathStream.filter(path -> path.toFile().isFile())
60 | .flatMap(path -> getClassDeclarations(path).stream())
61 | .filter(declaration -> declaration.getFullyQualifiedName().isPresent())
62 | .forEach(declaration -> classDeclarationMap.put(declaration.getFullyQualifiedName().get(), declaration));
63 |
64 | classDeclarationMap.values().forEach(this::buildClassNode);
65 | classNodeMap
66 | .values()
67 | .forEach(classNode -> {
68 | ClassOrInterfaceDeclaration classOrInterfaceDeclaration = classDeclarationMap.get(classNode.getId());
69 | List methodDeclarations = classOrInterfaceDeclaration.findAll(MethodDeclaration.class);
70 | List ownsMethodNodes = methodDeclarations
71 | .stream()
72 | .map(methodDeclaration -> buildMethodNode(methodDeclaration.getNameAsString(), classNode.getId(), methodDeclarations))
73 | .filter(Optional::isPresent)
74 | .map(Optional::get)
75 | .toList();
76 | classNode.setOwnsMethodNodes(ownsMethodNodes);
77 | });
78 |
79 | }
80 | return new BuildContext(classNodeMap.values(), methodNodeMap.values());
81 | }
82 |
83 | /**
84 | * 获取指定路径下的java文件中的所有类声明,一个java文件中可能包含多个类
85 | *
86 | * @param path java文件路径
87 | * @return 类声明列表
88 | */
89 | @SneakyThrows
90 | private List getClassDeclarations(Path path) {
91 | return javaParser.parse(path.toFile())
92 | .getResult()
93 | .map(compilationUnit -> compilationUnit.findAll(ClassOrInterfaceDeclaration.class))
94 | .filter(list -> !list.isEmpty())
95 | .orElse(List.of());
96 | }
97 |
98 | /**
99 | * 递归构建类节点
100 | *
101 | * @param declaration 类或者接口声明
102 | * @return 类节点
103 | */
104 | @SneakyThrows
105 | private Optional buildClassNode(ClassOrInterfaceDeclaration declaration) {
106 | // 用classNodeMap缓存,避免重复构建
107 | if (classNodeMap.containsKey(declaration.getNameAsString())) {
108 | return Optional.of(classNodeMap.get(declaration.getNameAsString()));
109 | }
110 | return declaration
111 | .getFullyQualifiedName()
112 | .map(qualifiedClasName -> {
113 | ClassNode classNode = new ClassNode().setId(qualifiedClasName)
114 | .setName(declaration.getNameAsString())
115 | .setContent(declaration.toString());
116 | // 缓存类节点
117 | classNodeMap.put(qualifiedClasName, classNode);
118 | List importClassNodes = declaration
119 | .findAll(ImportDeclaration.class)
120 | .stream()
121 | // 递归构建类节点
122 | .map(importDeclaration -> Optional.ofNullable(classDeclarationMap.get(importDeclaration.getNameAsString()))
123 | .flatMap(this::buildClassNode))
124 | .filter(Optional::isPresent)
125 | .map(Optional::get)
126 | .toList();
127 | classNode.setImportNodes(importClassNodes);
128 | return classNode;
129 | });
130 | }
131 |
132 | /**
133 | * 递归构建方法节点
134 | *
135 | * @param methodName 方法名称,如login
136 | * @param className 方法所在的全限定类名,如io.qifan.xxx.UserService
137 | * @param declarations 类中的所有方法,如login、logout等
138 | * @return 方法节点
139 | */
140 | @SneakyThrows
141 | private Optional buildMethodNode(String methodName, String className, List declarations) {
142 | String methodId = className + "#" + methodName;
143 | // 用methodNodeMap缓存,避免重复构建
144 | if (methodNodeMap.containsKey(methodId)) {
145 | return Optional.of(methodNodeMap.get(methodId));
146 | }
147 | return declarations.stream()
148 | .filter(methodDeclaration -> methodDeclaration.getNameAsString().equals(methodName))
149 | .findFirst()
150 | .map(methodDeclaration -> {
151 | // 获取方法内容,如果是mapper接口则获取方法对应的sql
152 | String content = methodDeclaration.findAll(AnnotationExpr.class)
153 | .stream()
154 | .filter(a -> a.getMetaModel().getQualifiedClassName().equals("org.apache.ibatis.annotations.Mapper"))
155 | .findAny()
156 | .map(annotationExpr -> mapperSqlMap.get(methodId))
157 | .orElse(methodDeclaration.toString());
158 | MethodNode methodNode = new MethodNode()
159 | .setId(methodId)
160 | .setName(methodDeclaration.getNameAsString())
161 | .setContent(content)
162 | .setComment(methodDeclaration.getComment().map(Comment::getContent).orElse(""));
163 | // 缓存方法节点
164 | methodNodeMap.put(methodNode.getId(), methodNode);
165 | // 递归构建方法调用关系
166 | List usesMethodNodes = methodDeclaration
167 | .findAll(MethodCallExpr.class)
168 | .stream()
169 | .map(this::buildMethodNodeFromMethodCall)
170 | .filter(Optional::isPresent)
171 | .map(Optional::get)
172 | .toList();
173 | methodNode.setUsesMethodNodes(usesMethodNodes);
174 | return methodNode;
175 | });
176 | }
177 |
178 | /**
179 | * userService.login()这段代码指的是methodCall,要获取login()的方法体内容,需要先解析userService中的所有方法,然后取出login方法
180 | *
181 | * @param methodCallExpr 方法调用表达式
182 | * @return 方法节点
183 | */
184 | public Optional buildMethodNodeFromMethodCall(MethodCallExpr methodCallExpr) {
185 | return methodCallExpr
186 | .getScope()
187 | .filter(this::checkScopeExist)
188 | .flatMap(scope -> buildMethodNode(methodCallExpr.getNameAsString(), scope.calculateResolvedType().asReferenceType().getQualifiedName(), getMethodDeclarationsFromScope(scope)));
189 | }
190 |
191 | /**
192 | * userService.login(), scope指的是userService. 获取userService中的所有方法
193 | *
194 | * @param scope 方法所在的对象
195 | * @return 方法列表
196 | */
197 | public List getMethodDeclarationsFromScope(Expression scope) {
198 | return Optional.ofNullable(classDeclarationMap.get(scope.calculateResolvedType()
199 | .asReferenceType().getQualifiedName()))
200 | .map(declaration -> declaration.findAll(MethodDeclaration.class))
201 | .orElse(List.of());
202 | }
203 |
204 | /**
205 | * 初始化mybatis mapper xml,提取其中的sql将方法名称和sql语句对应起来
206 | */
207 | @SneakyThrows
208 | private void buildMapperSqlMap() {
209 | try (Stream mapper = Files.walk(getFileInResource("mapper"))) {
210 | mapper.filter(path -> path.toString().endsWith(".xml"))
211 | .forEach(file -> {
212 | Document document = parseXMLFileAsDocument(file.toFile());
213 | NodeList selectNodes = document.getDocumentElement().getElementsByTagName("select");
214 | String namespace = document.getDocumentElement().getAttribute("namespace");
215 | extractSqlFromStatement(selectNodes, namespace);
216 | NodeList deleteNodes = document.getDocumentElement().getElementsByTagName("delete");
217 | extractSqlFromStatement(deleteNodes, namespace);
218 | NodeList updateNodes = document.getDocumentElement().getElementsByTagName("update");
219 | extractSqlFromStatement(updateNodes, namespace);
220 | });
221 | } catch (Exception ignored) {
222 | log.warn("不存在mapper");
223 | }
224 | }
225 |
226 | /**
227 | * 解析mapper xml文件
228 | *
229 | * @param file mapper文件路径
230 | * @return document
231 | */
232 | @SneakyThrows
233 | private Document parseXMLFileAsDocument(File file) {
234 | DocumentBuilder builder = DocumentBuilderFactory.newInstance().newDocumentBuilder();
235 | return builder.parse(file);
236 | }
237 |
238 | /**
239 | * 将mapper中的sql语句提取出来,并和sql语句对应的方法名对应起来
240 | *
241 | * @param nodeList select/delete/update标签
242 | * @param namespace mapper的namespace
243 | */
244 | @SneakyThrows
245 | private void extractSqlFromStatement(NodeList nodeList, String namespace) {
246 | TransformerFactory tf = TransformerFactory.newInstance();
247 | Transformer transformer = tf.newTransformer();
248 | transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
249 | for (int i = 0; i < nodeList.getLength(); i++) {
250 | StringWriter writer = new StringWriter();
251 | transformer.transform(new DOMSource(nodeList.item(i)), new StreamResult(writer));
252 | String output = writer.getBuffer().toString().replaceAll("\n|\r", "");
253 | mapperSqlMap.put(namespace + "#" + nodeList.item(i).getAttributes().getNamedItem("id").getNodeValue(), output);
254 | }
255 | }
256 |
257 |
258 | private Path getJavaSourcePath() {
259 | return projectPath.resolve(Path.of("src", "main", "java"));
260 | }
261 |
262 | private Path getFileInResource(String fileName) {
263 | return projectPath.resolve(Path.of("src", "main", "resources", fileName));
264 | }
265 |
266 |
267 | private boolean checkScopeExist(Expression expression) {
268 | try {
269 | expression.calculateResolvedType().asReferenceType();
270 | } catch (Exception e) {
271 | return false;
272 | }
273 | return true;
274 | }
275 |
276 | }
277 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/service/CodeGraphService.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.code.graph.service;
2 |
3 | import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
4 | import io.github.qifan777.knowledge.code.graph.repository.ClassNodeRepository;
5 | import io.github.qifan777.knowledge.code.graph.repository.MethodNodeRepository;
6 | import io.github.qifan777.knowledge.infrastructure.code.CodeAssistantProperties;
7 | import io.github.qifan777.knowledge.infrastructure.code.JavaParserUtils;
8 | import lombok.AllArgsConstructor;
9 | import lombok.SneakyThrows;
10 | import lombok.extern.slf4j.Slf4j;
11 | import org.springframework.data.neo4j.core.Neo4jClient;
12 | import org.springframework.stereotype.Service;
13 |
14 | import java.util.ArrayList;
15 | import java.util.List;
16 |
17 | @Slf4j
18 | @Service
19 | @AllArgsConstructor
20 | public class CodeGraphService {
21 | private final ClassNodeRepository classNodeRepository;
22 | private final MethodNodeRepository methodNodeRepository;
23 | private final Neo4jClient neo4jClient;
24 | private final CodeAssistantProperties properties;
25 |
26 | @SneakyThrows
27 | public void buildGraph() {
28 | methodNodeRepository.deleteAll();
29 | classNodeRepository.deleteAll();
30 | CodeGraphBuilder.BuildContext buildContext = new CodeGraphBuilder(properties.getProject().getProjectPath(), JavaParserUtils.getJavaParser()).buildGraph();
31 | classNodeRepository.saveAll(buildContext.classNodes());
32 | log.info("类节点保存完毕: {}", classNodeRepository.count());
33 | methodNodeRepository.saveAll(buildContext.methodNodes());
34 | log.info("方法节点保存完毕: {}", methodNodeRepository.count());
35 | }
36 |
37 | public List findChildMethods(String methodId) {
38 | String cypher = """
39 | match window=(m)-[:USES*0..3]->(:MethodNode)
40 | where m.id = $methodId
41 | with nodes(window) as nodeList
42 | unwind nodeList as nodeRows
43 | return nodeRows;
44 | """;
45 | return findMethods(cypher, methodId);
46 | }
47 |
48 | public ArrayList findMethods(String cypher, String methodId) {
49 | return new ArrayList<>(neo4jClient.query(cypher)
50 | .bind(methodId).to("methodId")
51 | .fetchAs(MethodNode.class)
52 | .mappedBy((typeSystem, record) -> {
53 | MethodNode methodNode = new MethodNode();
54 | methodNode.setContent(String.valueOf(record.get(0).get("content")));
55 | methodNode.setId(String.valueOf(record.get(0).get("id")).replaceAll("\"", ""));
56 | return methodNode;
57 | })
58 | .all());
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/demo/DocumentAnalyzerFunction.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.demo;
2 |
3 | import com.fasterxml.jackson.annotation.JsonProperty;
4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription;
5 | import lombok.Data;
6 | import lombok.SneakyThrows;
7 | import lombok.extern.slf4j.Slf4j;
8 | import org.springframework.ai.reader.tika.TikaDocumentReader;
9 | import org.springframework.context.annotation.Description;
10 | import org.springframework.core.io.FileSystemResource;
11 | import org.springframework.stereotype.Service;
12 |
13 | import java.util.function.Function;
14 |
15 | /**
16 | * 通过@Description描述函数的用途,这样ai在多个函数中可以根据描述进行选择。
17 | */
18 | @Description("文档解析函数")
19 | @Service
20 | @Slf4j
21 | public class DocumentAnalyzerFunction implements Function {
22 | /**
23 | * 通过@JsonProperty声明属性名称和是否必填
24 | * 通过@JsonPropertyDescription描述属性的用途,这样ai可以提取出符合参数描述的内容。
25 | */
26 | @Data
27 | public static class Request {
28 | @JsonProperty(required = true, value = "path")
29 | @JsonPropertyDescription(value = "需要解析的本地文件路径")
30 | String path;
31 | }
32 |
33 | public record Response(String result) {
34 | }
35 |
36 | @SneakyThrows
37 | @Override
38 | public Response apply(Request request) {
39 | // ai解析用户的提问得到path参数,使用tika读取本地文件获取内容。把读取到的内容再返回给ai作为上下文去回答用户的问题。
40 | TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new FileSystemResource(request.path));
41 | return new Response(tikaDocumentReader.read().get(0).getText());
42 | }
43 |
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/demo/DocumentDemoController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.demo;
2 |
3 | import lombok.AllArgsConstructor;
4 | import lombok.SneakyThrows;
5 | import lombok.extern.slf4j.Slf4j;
6 | import org.springframework.ai.chat.client.ChatClient;
7 | import org.springframework.ai.chat.model.ChatModel;
8 | import org.springframework.ai.document.Document;
9 | import org.springframework.ai.reader.tika.TikaDocumentReader;
10 | import org.springframework.ai.transformer.splitter.TokenTextSplitter;
11 | import org.springframework.ai.vectorstore.VectorStore;
12 | import org.springframework.core.io.InputStreamResource;
13 | import org.springframework.web.bind.annotation.*;
14 | import org.springframework.web.multipart.MultipartFile;
15 |
16 | import java.util.List;
17 |
18 | @RequestMapping("demo/document")
19 | @RestController
20 | @AllArgsConstructor
21 | @Slf4j
22 | public class DocumentDemoController {
23 | private final VectorStore vectorStore;
24 | private final ChatModel chatModel;
25 |
26 | /**
27 | * 嵌入文件
28 | *
29 | * @param file 待嵌入的文件
30 | * @return 是否成功
31 | */
32 | @SneakyThrows
33 | @PostMapping("embedding")
34 | public Boolean embedding(@RequestParam MultipartFile file) {
35 | // 从IO流中读取文件
36 | TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(file.getInputStream()));
37 | // 将文本内容划分成更小的块
38 | List splitDocuments = new TokenTextSplitter()
39 | .apply(tikaDocumentReader.read());
40 | // 存入向量数据库,这个过程会自动调用embeddingModel,将文本变成向量再存入。
41 | vectorStore.add(splitDocuments);
42 | return true;
43 | }
44 |
45 | /**
46 | * 查询向量数据库
47 | *
48 | * @param query 用户的提问
49 | * @return 匹配到的文档
50 | */
51 |
52 | @GetMapping("query")
53 | public List query(@RequestParam String query) {
54 | return vectorStore.similaritySearch(query);
55 | }
56 |
57 | @GetMapping("chat")
58 | public String chat(@RequestParam String query) {
59 | return ChatClient.create(chatModel)
60 | .prompt(query)
61 | .toolNames("documentAnalyzerFunction")
62 | .call()
63 | .content();
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/demo/MessageDemoController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.demo;
2 |
3 | import com.fasterxml.jackson.databind.ObjectMapper;
4 | import io.github.qifan777.knowledge.ai.message.dto.AiMessageParams;
5 | import io.github.qifan777.knowledge.ai.message.dto.AiMessageWrapper;
6 | import lombok.AllArgsConstructor;
7 | import lombok.SneakyThrows;
8 | import org.springframework.ai.chat.client.ChatClient;
9 | import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
10 | import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor;
11 | import org.springframework.ai.chat.memory.MessageWindowChatMemory;
12 | import org.springframework.ai.chat.messages.SystemMessage;
13 | import org.springframework.ai.chat.messages.UserMessage;
14 | import org.springframework.ai.chat.model.ChatModel;
15 | import org.springframework.ai.chat.model.ChatResponse;
16 | import org.springframework.ai.chat.prompt.PromptTemplate;
17 | import org.springframework.ai.vectorstore.VectorStore;
18 | import org.springframework.http.MediaType;
19 | import org.springframework.http.codec.ServerSentEvent;
20 | import org.springframework.web.bind.annotation.*;
21 | import reactor.core.publisher.Flux;
22 |
23 | @RequestMapping("demo/message")
24 | @RestController
25 | @AllArgsConstructor
26 | public class MessageDemoController {
27 |
28 | private final ChatModel chatModel;
29 |
30 |
31 | private final ObjectMapper objectMapper;
32 | private final VectorStore vectorStore;
33 | // 模拟数据库存储会话和消息
34 | private final MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
35 | .maxMessages(10)
36 | .build();
37 |
38 | /**
39 | * 非流式问答
40 | *
41 | * @param prompt 用户提问
42 | * @return org.springframework.ai.chat.model.ChatResponse
43 | */
44 | @GetMapping("chat")
45 | public String chat(@RequestParam String prompt) {
46 | ChatClient chatClient = ChatClient.create(chatModel);
47 | return chatClient.prompt()
48 | // 输入单条提示词
49 | .user(prompt)
50 | // call代表非流式问答,返回的结果可以是ChatResponse,也可以是Entity(转成java类型),也可以是字符串直接提取回答结果。
51 | .call()
52 | .content();
53 | }
54 |
55 | /**
56 | * 流式问答
57 | *
58 | * @param prompt 用户提问
59 | * @return SSE流式响应
60 | */
61 | @GetMapping(value = "chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
62 | public Flux> chatStream(@RequestParam String prompt) {
63 | return ChatClient.create(chatModel).prompt()
64 | // 输入多条消息,可以将历史消息记录传入
65 | .messages(new SystemMessage("你是一个Java智能助手,应用你的Java知识帮助用户解决问题或者编写程序"),
66 | new UserMessage(prompt))
67 | // 流式返回
68 | .stream()
69 | // 构造SSE(ServerSendEvent)格式返回结果
70 | .chatResponse().map(chatResponse -> ServerSentEvent.builder(toJson(chatResponse))
71 | .event("message")
72 | .build());
73 | }
74 |
75 | /**
76 | * 将流式回答结果转json字符串
77 | *
78 | * @param chatResponse 流式回答结果
79 | * @return String json字符串
80 | */
81 | @SneakyThrows
82 | public String toJson(ChatResponse chatResponse) {
83 | return objectMapper.writeValueAsString(chatResponse);
84 | }
85 |
86 | /**
87 | * 调用自定义函数回答用户的提问
88 | *
89 | * @param prompt 用户的提问
90 | * @param functionName 函数名称(bean的名称,类名小写)
91 | * @return SSE流式响应
92 | */
93 | @GetMapping(value = "chat/stream/function", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
94 | public Flux> chatStreamWithFunction(@RequestParam String prompt, @RequestParam String functionName) {
95 | return ChatClient.create(chatModel).prompt()
96 | .messages(new UserMessage(prompt))
97 | // spring ai会从已注册为bean的function中查找函数,将它添加到请求中。如果成功触发就会调用函数
98 | .toolNames(functionName)
99 | .stream()
100 | .chatResponse()
101 | .map(chatResponse -> ServerSentEvent.builder(toJson(chatResponse))
102 | .event("message")
103 | .build());
104 | }
105 |
106 | /**
107 | * 从向量数据库中查找文档,并将查询的文档作为上下文回答。
108 | *
109 | * @param prompt 用户的提问
110 | * @return SSE流响应
111 | */
112 | @GetMapping(value = "chat/stream/database", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
113 | public Flux> chatStreamWithDatabase(@RequestParam String prompt) {
114 | // question_answer_context是一个占位符,会替换成向量数据库中查询到的文档。QuestionAnswerAdvisor会替换。
115 | String promptWithContext = """
116 | {query}
117 | 下面是上下文信息
118 | ---------------------
119 | {question_answer_context}
120 | ---------------------
121 | 给定的上下文和提供的历史信息,而不是事先的知识,回复用户的意见。如果答案不在上下文中,告诉用户你不能回答这个问题。
122 | """;
123 | return ChatClient.create(chatModel).prompt()
124 | .user(prompt)
125 | .advisors(QuestionAnswerAdvisor.builder(vectorStore)
126 | .promptTemplate(new PromptTemplate(promptWithContext)).build())
127 | .stream()
128 | .content()
129 | .map(chatResponse -> ServerSentEvent.builder(chatResponse)
130 | .event("message")
131 | .build());
132 | }
133 |
134 | /**
135 | * 根据会话id,从数据库中查找历史消息,并将消息作为上下文回答。
136 | *
137 | * @param prompt 用户的提问
138 | * @param sessionId 会话id
139 | * @return SSE流响应
140 | */
141 | @GetMapping(value = "chat/stream/history", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
142 | public Flux> chatStreamWithHistory(@RequestParam String prompt,
143 | @RequestParam String sessionId) {
144 | // 1. 如果需要存储会话和消息到数据库,自己可以实现ChatMemory接口,这里使用InMemoryChatMemory,内存存储。
145 | // 2. 传入会话id,MessageChatMemoryAdvisor会根据会话id去查找消息。
146 | // 3. 只需要携带最近10条消息
147 | MessageChatMemoryAdvisor messageChatMemoryAdvisor= MessageChatMemoryAdvisor
148 | .builder(chatMemory)
149 | .conversationId(sessionId)
150 | .build();
151 | return ChatClient.create(chatModel).prompt()
152 | .user(prompt)
153 | // MessageChatMemoryAdvisor会在消息发送给大模型之前,从ChatMemory中获取会话的历史消息,然后一起发送给大模型。
154 | .advisors(messageChatMemoryAdvisor)
155 | .stream()
156 | .content()
157 | .map(chatResponse -> ServerSentEvent.builder(chatResponse)
158 | .event("message")
159 | .build());
160 | }
161 |
162 | @PostMapping("ignore")
163 | public void ignore(@RequestParam AiMessageWrapper wrapper, @RequestParam AiMessageParams params) {
164 | }
165 |
166 | }
167 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/GraphController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph;
2 |
3 | import io.github.qifan777.knowledge.graph.chunk.ChunkController;
4 | import io.qifan.infrastructure.common.exception.BusinessException;
5 | import lombok.AllArgsConstructor;
6 | import lombok.extern.slf4j.Slf4j;
7 | import org.springframework.ai.chat.messages.UserMessage;
8 | import org.springframework.ai.chat.model.ChatModel;
9 | import org.springframework.ai.chat.prompt.PromptTemplate;
10 | import org.springframework.ai.embedding.EmbeddingModel;
11 | import org.springframework.data.neo4j.core.Neo4jClient;
12 | import org.springframework.web.bind.annotation.GetMapping;
13 | import org.springframework.web.bind.annotation.RequestParam;
14 | import org.springframework.web.bind.annotation.RestController;
15 |
16 | import java.util.List;
17 | import java.util.Map;
18 |
19 | @RestController("graph")
20 | @AllArgsConstructor
21 | @Slf4j
22 | public class GraphController {
23 | private final EmbeddingModel embeddingModel;
24 | private final ChatModel chatModel;
25 | private final Neo4jClient neo4jClient;
26 | private final PromptTemplate promptTemplate = new PromptTemplate("""
27 | Context information is below.
28 | ---------------------
29 | {question_answer_context}
30 | ---------------------
31 | Given the context and provided history information and not prior knowledge,
32 | reply to the user comment. If the answer is not in the context, inform
33 | the user that you can't answer the question.
34 | """);
35 |
36 | @GetMapping(value = "chunk/rag")
37 | public String chunkRag(@RequestParam String query) {
38 | List embed = ChunkController.floatsToDoubles(embeddingModel.embed(query));
39 | String result = neo4jClient.query("""
40 | CALL db.index.vector.queryNodes('form_10k_chunks', 1, $embedding)
41 | yield node, score
42 | match window=(:Chunk)-[:NEXT*0..1]->(node)-[:NEXT*0..1]->(:Chunk)
43 | with nodes(window) as chunkList, node, score
44 | unwind chunkList as chunkRows
45 | with collect(chunkRows.text) as textList, node, score
46 | return apoc.text.join(textList, " \\n ")
47 | """)
48 | .bind(embed).to("embedding")
49 | .fetchAs(String.class).first()
50 | .orElseThrow(() -> new BusinessException("未找到相似文档"));
51 | String content = promptTemplate.createMessage(Map.of("question_answer_context", result)).getText();
52 | return chatModel.call(new UserMessage(content + "\n" + query));
53 | }
54 |
55 | @GetMapping(value = "manager/rag")
56 | public String managerRag(@RequestParam String query) {
57 | List embed = ChunkController.floatsToDoubles(embeddingModel.embed(query));
58 | var result = neo4jClient.query("""
59 | CALL db.index.vector.queryNodes('form_10k_chunks', 1, $embedding)
60 | YIELD node, score
61 | MATCH (node)-[:PART_OF]->(f:Form),
62 | (f)<-[:FILED]-(com:Company),
63 | (com)<-[owns:OWNS_STOCK_IN]-(mgr:Manager)
64 | WITH node, score, mgr, owns, com
65 | ORDER BY owns.shares DESC LIMIT 5
66 | WITH collect (
67 | mgr.name +
68 | " owns " + owns.shares + " of " + com.name +
69 | " at a value of $" + apoc.number.format(owns.value) + "."
70 | ) AS investment_statements, node, score
71 | RETURN "investors: \\n" + apoc.text.join(investment_statements, "\\n") +\s
72 | "\\n" + node.text AS text
73 | """)
74 | .bind(embed).to("embedding")
75 | .fetchAs(String.class)
76 | .first()
77 | .orElseThrow(() -> new BusinessException("未找到相似文档"));
78 | String content = promptTemplate.createMessage(Map.of("question_answer_context", result)).getText();
79 | log.info("context result: {}", content);
80 | return chatModel.call(new UserMessage(content + "\n" + query));
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/chunk/Chunk.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.chunk;
2 |
3 | import lombok.Builder;
4 | import lombok.Data;
5 | import org.springframework.data.neo4j.core.schema.Id;
6 | import org.springframework.data.neo4j.core.schema.Node;
7 |
8 | import java.util.List;
9 |
10 | @Builder
11 | @Data
12 | @Node
13 | public class Chunk {
14 | @Id
15 | private String id;
16 | // 切割后的文本
17 | private String text;
18 | // item1, item1a, item7, item7a
19 | private String item;
20 | // Chunk序列号
21 | private Integer chunkSeqId;
22 | // 属于的Form
23 | private String formId;
24 | // text的embedding
25 | private List textEmbedding;
26 | }
27 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/chunk/ChunkController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.chunk;
2 |
3 | import com.fasterxml.jackson.databind.ObjectMapper;
4 | import com.fasterxml.jackson.databind.node.ObjectNode;
5 | import lombok.AllArgsConstructor;
6 | import lombok.SneakyThrows;
7 | import lombok.extern.slf4j.Slf4j;
8 | import org.springframework.ai.document.Document;
9 | import org.springframework.ai.embedding.EmbeddingModel;
10 | import org.springframework.ai.transformer.splitter.TokenTextSplitter;
11 | import org.springframework.data.neo4j.core.Neo4jClient;
12 | import org.springframework.web.bind.annotation.PostMapping;
13 | import org.springframework.web.bind.annotation.RestController;
14 |
15 | import java.io.File;
16 | import java.util.ArrayList;
17 | import java.util.List;
18 |
19 | @RestController("chunk")
20 | @AllArgsConstructor
21 | @Slf4j
22 | public class ChunkController {
23 | private final ChunkRepository chunkRepository;
24 | private final Neo4jClient neo4jClient;
25 | private final EmbeddingModel embeddingModel;
26 |
27 | /**
28 | * 创建Chunk节点,
29 | */
30 |
31 | @PostMapping("node")
32 | public void createNodes() {
33 | var fileDir = new File("F:\\workspace\\code\\learn\\sec-edgar-notebooks\\data\\sample\\form10k");
34 | File[] files = fileDir.listFiles();
35 | for (File file : files) {
36 | if (!file.getName().contains(".json")) continue;
37 | chunkRepository.saveAll(fileToChunkList(file));
38 | }
39 |
40 | }
41 |
42 | /**
43 | * 解析form10k的中的item属性切割成Chunk
44 | *
45 | * @param file form10k的json文件
46 | * @return Chunk节点
47 | */
48 | @SneakyThrows
49 | public List fileToChunkList(File file) {
50 | ObjectNode node = new ObjectMapper().readValue(file, ObjectNode.class);
51 | // 每个form10k有item1,item1a,item7,item7a四种文本信息,都需要将切割
52 | String[] items = {"item1", "item1a", "item7", "item7a"};
53 | List chunks = new ArrayList<>();
54 | for (String item : items) {
55 | String text = node.get(item).asText();
56 | // 切割文本成
57 | List documents = new TokenTextSplitter().split(new Document(text));
58 | // 最多不超过20 Chunk
59 | for (int chunkSeqId = 0; chunkSeqId < Integer.min(documents.size(), 20); chunkSeqId++) {
60 | String formId = file.getName().replace(".json", "");
61 | chunks.add(Chunk.builder()
62 | .id("%s-%s-chunk%04d".formatted(formId, item, chunkSeqId))
63 | .chunkSeqId(chunkSeqId)
64 | .formId(formId)
65 | .text(documents.get(chunkSeqId).getText())
66 | .item(item)
67 | .build());
68 | }
69 | }
70 | return chunks;
71 | }
72 |
73 | @PostMapping("link")
74 | public void createLink() {
75 | var formIds = neo4jClient
76 | .query("match (c:Chunk) return distinct c.formId as formId")
77 | .fetchAs(String.class)
78 | .all();
79 | // 每个form10k有item1,item1a,item7,item7a四种文本信息,都需要将切割后的Chunk通过NEXT关联起来
80 | formIds.forEach(formId -> {
81 | List.of("item1", "item1a", "item7", "item7a")
82 | .forEach(item -> {
83 | neo4jClient.query("""
84 | MATCH (c:Chunk) // 匹配所有的节点
85 | WHERE c.formId = $formId // 属于同一个form和同一个item的节点
86 | AND c.item = $item
87 | WITH c
88 | ORDER BY c.chunkSeqId ASC // 根据seqId排序一下节点
89 | WITH collect(c) as section_chunk_list // 转成list
90 | CALL apoc.nodes.link(section_chunk_list, "NEXT", {avoidDuplicates: true}) // 节点之间依按顺序创建连接
91 | RETURN size(section_chunk_list)
92 | """)
93 | .bind(formId).to("formId")
94 | .bind(item).to("item")
95 | .run();
96 | });
97 | });
98 | }
99 |
100 | /**
101 | * 对所有Chunk进行embedding,neo4j中支持向量索引,只有创建索引之后才可以查询相似的向量
102 | */
103 | @PostMapping("embedding")
104 | public void createEmbedding() {
105 | // 随便将一段文本转成向量,看看这个嵌入模型的向量维度是多少
106 | int dimension = embeddingModel.embed("你好").length;
107 |
108 | // 在Chunk节点创建索引,使用cosine求向量之间的相似度
109 | neo4jClient.query("""
110 | CREATE VECTOR INDEX `form_10k_chunks` IF NOT EXISTS
111 | FOR (c:Chunk) ON (c.textEmbedding)
112 | OPTIONS { indexConfig: {
113 | `vector.dimensions`: $dimensions,
114 | `vector.similarity_function`: 'cosine'
115 | }}
116 | """)
117 | .bind(dimension).to("dimensions")
118 | .run();
119 | // 对那些没有嵌入的Chunk进行embedding
120 | List waitToEmbedList = chunkRepository.findByTextEmbeddingIsNull();
121 | waitToEmbedList.forEach(chunk -> chunk.setTextEmbedding(floatsToDoubles(embeddingModel.embed(chunk.getText()))));
122 | chunkRepository.saveAll(waitToEmbedList);
123 | }
124 |
125 | public static List floatsToDoubles(float[] floats) {
126 | List result = new ArrayList<>(floats.length);
127 | for (float f : floats) {
128 | result.add((double) f);
129 | }
130 | return result;
131 | }
132 |
133 | }
134 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/chunk/ChunkRepository.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.chunk;
2 |
3 | import org.springframework.data.neo4j.repository.Neo4jRepository;
4 |
5 | import java.util.List;
6 |
7 | public interface ChunkRepository extends Neo4jRepository {
8 | List findByTextEmbeddingIsNull();
9 | }
10 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/company/Company.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.company;
2 |
3 | import lombok.Builder;
4 | import lombok.Data;
5 | import org.springframework.data.neo4j.core.schema.Id;
6 | import org.springframework.data.neo4j.core.schema.Node;
7 |
8 | import java.util.List;
9 |
10 | @Builder
11 | @Data
12 | @Node
13 | public class Company {
14 | @Id
15 | private String cusip6;
16 | private List cusips;
17 | private List names;
18 | private String name;
19 | }
20 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/company/CompanyController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.company;
2 |
3 | import com.fasterxml.jackson.databind.ObjectMapper;
4 | import io.github.qifan777.knowledge.graph.model.Form10K;
5 | import lombok.AllArgsConstructor;
6 | import lombok.SneakyThrows;
7 | import lombok.extern.slf4j.Slf4j;
8 | import org.springframework.data.neo4j.core.Neo4jClient;
9 | import org.springframework.web.bind.annotation.PostMapping;
10 | import org.springframework.web.bind.annotation.RequestMapping;
11 | import org.springframework.web.bind.annotation.RestController;
12 |
13 | import java.io.File;
14 |
15 | @RestController
16 | @RequestMapping("company")
17 | @Slf4j
18 | @AllArgsConstructor
19 | public class CompanyController {
20 | private final CompanyRepository companyRepository;
21 | private final Neo4jClient neo4jClient;
22 |
23 | @SneakyThrows
24 | @PostMapping("nodes")
25 | public void createNodes() {
26 | var fileDir = new File("F:\\workspace\\code\\learn\\sec-edgar-notebooks\\data\\sample\\form10k");
27 | File[] files = fileDir.listFiles();
28 | for (File file : files) {
29 | if (!file.getName().contains(".json")) continue;
30 | var form10K = new ObjectMapper().readValue(file, Form10K.class);
31 | var company = Company.builder().cusip6(form10K.getCusip6())
32 | .cusips(form10K.getCusip())
33 | .names(form10K.getNames())
34 | .name(form10K.getNames().get(0))
35 | .build();
36 | companyRepository.save(company);
37 | }
38 |
39 | }
40 |
41 | @PostMapping("relationship/filed")
42 | public void createFiledRelationship() {
43 | // 创建公司和form关系
44 | neo4jClient.query("""
45 | MATCH (com:Company), (form:Form)
46 | WHERE com.cusip6 = form.cusip6
47 | MERGE (com)-[:FILED]->(form)
48 | """)
49 | .run();
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/company/CompanyRepository.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.company;
2 |
3 | import org.springframework.data.neo4j.repository.Neo4jRepository;
4 |
5 | public interface CompanyRepository extends Neo4jRepository {
6 | }
7 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/form/Form.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.form;
2 |
3 | import lombok.Builder;
4 | import lombok.Data;
5 | import org.springframework.data.neo4j.core.schema.Id;
6 | import org.springframework.data.neo4j.core.schema.Node;
7 |
8 | @Builder
9 | @Data
10 | @Node
11 | public class Form {
12 | @Id
13 | private String id;
14 | private String cusip6;
15 | private String source;
16 | private String fullText;
17 | }
18 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/form/FormController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.form;
2 |
3 | import com.fasterxml.jackson.databind.ObjectMapper;
4 | import io.github.qifan777.knowledge.graph.model.Form10K;
5 | import lombok.AllArgsConstructor;
6 | import lombok.SneakyThrows;
7 | import lombok.extern.slf4j.Slf4j;
8 | import org.springframework.data.neo4j.core.Neo4jClient;
9 | import org.springframework.web.bind.annotation.PostMapping;
10 | import org.springframework.web.bind.annotation.RequestMapping;
11 | import org.springframework.web.bind.annotation.RestController;
12 |
13 | import java.io.File;
14 |
15 | @RestController
16 | @RequestMapping("form")
17 | @AllArgsConstructor
18 | @Slf4j
19 | public class FormController {
20 | private final FormRepository formRepository;
21 | private final Neo4jClient neo4jClient;
22 |
23 | @SneakyThrows
24 | @PostMapping("node")
25 | public void createNodes() {
26 | var fileDir = new File("F:\\workspace\\code\\learn\\sec-edgar-notebooks\\data\\sample\\form10k");
27 | File[] files = fileDir.listFiles();
28 | for (File file : files) {
29 | if (!file.getName().contains(".json")) continue;
30 | var form10K = new ObjectMapper().readValue(file, Form10K.class);
31 | var fullText = "About " +
32 | String.join(",", form10K.getNames()) +
33 | "..." +
34 | form10K.getItem1() +
35 | "\n" +
36 | form10K.getItem1a() +
37 | "\n" +
38 | form10K.getItem7() +
39 | "\n" +
40 | form10K.getItem7a();
41 | var formId = file.getName().replace(".json", "");
42 | var form = Form.builder().id(formId)
43 | .fullText(fullText)
44 | .cusip6(form10K.getCusip6())
45 | .source(form10K.getSource())
46 | .build();
47 | formRepository.save(form);
48 | }
49 | }
50 |
51 | @PostMapping("relationship/section")
52 | public void createSectionRelationship() {
53 | neo4jClient.query("""
54 | match (c:Chunk),(f:Form) where c.chunkSeqId=0 and f.id = c.formId
55 | merge (f)-[r:SECTION {item:c.item}] -> (c)
56 | return count(r)
57 | """)
58 | .run();
59 | }
60 |
61 | @PostMapping("relationship/part-of")
62 | public void createPartOfRelationship() {
63 | neo4jClient.query("""
64 | match (c:Chunk), (f:Form) where c.formId=f.id
65 | merge (c)-[r:PART_OF]->(f)
66 | return count(r);
67 | """)
68 | .run();
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/form/FormRepository.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.form;
2 |
3 | import org.springframework.data.neo4j.repository.Neo4jRepository;
4 |
5 | public interface FormRepository extends Neo4jRepository