├── .gitignore ├── README.md ├── database.sql ├── front-end ├── .env.development ├── .env.production ├── .eslintrc.cjs ├── .gitignore ├── .prettierrc.json ├── README.md ├── env.d.ts ├── index.html ├── package-lock.json ├── package.json ├── public │ └── favicon.ico ├── scripts │ └── generate-api.js ├── src │ ├── App.vue │ ├── assets │ │ ├── SrpingAI知识点.png │ │ ├── background.jpg │ │ ├── cover.png │ │ └── logo.jpg │ ├── components │ │ ├── image │ │ │ └── image-upload.vue │ │ └── key-value │ │ │ ├── key-value-input.vue │ │ │ └── value-input.vue │ ├── main.ts │ ├── router │ │ └── index.ts │ ├── stores │ │ └── home-store.ts │ ├── typings │ │ └── index.d.ts │ ├── utils │ │ ├── api-instance.ts │ │ ├── common.ts │ │ └── request.ts │ └── views │ │ ├── chat │ │ ├── chat-view.vue │ │ ├── components │ │ │ ├── markdown-message.vue │ │ │ ├── message-input.vue │ │ │ ├── message-row.vue │ │ │ ├── session-item.vue │ │ │ └── text-loading.vue │ │ └── store │ │ │ └── chat-store.ts │ │ ├── code │ │ └── analyze │ │ │ ├── analyze-result-view.vue │ │ │ └── collapse-title.vue │ │ └── login │ │ ├── login-view.vue │ │ └── register-view.vue ├── tsconfig.app.json ├── tsconfig.json ├── tsconfig.node.json └── vite.config.ts ├── pom.xml └── src └── main ├── dto ├── AiMessage.dto ├── AiSession.dto └── User.dto ├── java └── io │ └── github │ └── qifan777 │ └── knowledge │ ├── ServerApplication.java │ ├── ai │ ├── agent │ │ ├── AbstractAgent.java │ │ ├── Agent.java │ │ ├── chronologist │ │ │ └── Chronologist.java │ │ └── computer │ │ │ ├── ComputerAssistant.java │ │ │ ├── CpuAnalyzer.java │ │ │ └── DirectoryReader.java │ ├── document │ │ └── DocumentController.java │ ├── message │ │ ├── AiMessage.java │ │ ├── AiMessageChatMemory.java │ │ ├── AiMessageController.java │ │ ├── AiMessageRepository.java │ │ ├── MessageInputWrapper.java │ │ └── dto │ │ │ ├── AiMessageParams.java │ │ │ └── AiMessageWrapper.java │ └── session │ │ ├── AiSession.java │ │ ├── AiSessionController.java │ │ └── 
AiSessionRepository.java │ ├── code │ ├── CodeAssistantAgent.java │ ├── analyze │ │ ├── AnalyzeController.java │ │ └── AnalyzeFunction.java │ ├── arthas │ │ └── ArthasFunction.java │ └── graph │ │ ├── controller │ │ └── CodeGraphController.java │ │ ├── entity │ │ ├── ClassNode.java │ │ └── MethodNode.java │ │ ├── repository │ │ ├── ClassNodeRepository.java │ │ └── MethodNodeRepository.java │ │ └── service │ │ ├── CodeGraphBuilder.java │ │ └── CodeGraphService.java │ ├── demo │ ├── DocumentAnalyzerFunction.java │ ├── DocumentDemoController.java │ └── MessageDemoController.java │ ├── graph │ ├── GraphController.java │ ├── chunk │ │ ├── Chunk.java │ │ ├── ChunkController.java │ │ └── ChunkRepository.java │ ├── company │ │ ├── Company.java │ │ ├── CompanyController.java │ │ └── CompanyRepository.java │ ├── form │ │ ├── Form.java │ │ ├── FormController.java │ │ └── FormRepository.java │ ├── manager │ │ ├── Manager.java │ │ ├── ManagerController.java │ │ └── ManagerRepository.java │ └── model │ │ ├── Form10K.java │ │ └── Form13.java │ ├── infrastructure │ ├── code │ │ ├── CodeAssistantProperties.java │ │ └── JavaParserUtils.java │ ├── config │ │ ├── GlobalExceptionAdvice.java │ │ ├── LocalDateTimeConvert.java │ │ ├── PageableConvert.java │ │ ├── ResponseInterceptor.java │ │ └── WebConfig.java │ ├── jimmer │ │ ├── BaseDateTime.java │ │ ├── BaseDateTimeDraftInterceptor.java │ │ ├── BaseEntity.java │ │ ├── BaseEntityDraftInterceptor.java │ │ ├── JimmerConfig.java │ │ └── UUIDIdGenerator.java │ └── model │ │ └── QueryRequest.java │ ├── oss │ └── OSSController.java │ └── user │ ├── User.java │ ├── UserController.java │ └── UserRepository.java └── resources ├── application.yml └── prompts └── react ├── react-tool.st └── react.st /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | !.mvn/wrapper/maven-wrapper.jar 3 | !**/src/main/**/target/ 4 | !**/src/test/**/target/ 5 | 6 | ### IntelliJ IDEA ### 7 | .idea/modules.xml 8 
| .idea/jarRepositories.xml 9 | .idea/compiler.xml 10 | .idea/libraries/ 11 | *.iws 12 | *.iml 13 | *.ipr 14 | 15 | ### Eclipse ### 16 | .apt_generated 17 | .classpath 18 | .factorypath 19 | .project 20 | .settings 21 | .springBeans 22 | .sts4-cache 23 | 24 | ### NetBeans ### 25 | /nbproject/private/ 26 | /nbbuild/ 27 | /dist/ 28 | /nbdist/ 29 | /.nb-gradle/ 30 | build/ 31 | !**/src/main/**/build/ 32 | !**/src/test/**/build/ 33 | 34 | ### VS Code ### 35 | .vscode/ 36 | 37 | ### Mac OS ### 38 | .DS_Store 39 | /src/main/resources/application-private.yml 40 | /.idea/ 41 | /template/ 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 项目介绍 2 | 3 | 本项目使用SpringAI教学,包含了SSE流/Agent智能体/FunctionCall/Embedding/VectorDatabase/RAG/Graph RAG/历史消息/图片生成/图片理解 4 | 5 | ![项目内容](./front-end/src/assets/cover.png) 6 | 7 | ![知识点](./front-end/src/assets/SrpingAI知识点.png) 8 | 9 | 10 | 11 | [文档地址](https://www.jarcheng.top/blog/project/spring-ai/intro.html) 12 | [视频地址](https://www.bilibili.com/video/BV14y411q7RN/) 13 | 14 | ## 运行环境 15 | 16 | - Java 17 17 | - Node.js 18+ 18 | - MySQL 8 19 | - DashScope API KEY(或者其他) 20 | - Redis-Stack 21 | 22 | redis基础上拓展向量查询功能 23 | 24 | ```shell 25 | docker run -d \ 26 | --name redis-stack \ 27 | --restart=always \ 28 | -v redis-data:/data \ 29 | -p 6379:6379 \ 30 | -p 8001:8001 \ 31 | -e REDIS_ARGS="--requirepass 123456" redis/redis-stack:latest 32 | ``` 33 | 34 | - neo4j 5+ 35 | 36 | 安装完neo4j访问`localhost:7474`, 默认的账号密码都是`neo4j`和`neo4j`(首次登录时会要求修改密码)。 37 | 38 | ```shell 39 | docker run \ 40 | -d \ 41 | -p 7474:7474 -p 7687:7687 \ 42 | -v neo4j-data:/data -v neo4j-plugins:/plugins \ 43 | --name neo4j \ 44 | -e NEO4J_apoc_export_file_enabled=true \ 45 | -e NEO4J_apoc_import_file_enabled=true \ 46 | -e NEO4J_apoc_import_file_use__neo4j__config=true \ 47 | -e NEO4JLABS_PLUGINS=\[\"apoc\"\] \ 48 | -e
NEO4J_dbms_security_procedures_unrestricted=apoc.\\\* \ 49 | neo4j 50 | ``` 51 | 52 | ## 运行步骤 53 | 54 | ### 1. clone代码 55 | 56 | ```shell 57 | git clone https://github.com/qifan777/KnowledgeBaseChatSpringAI 58 | ``` 59 | 60 | ### 2. idea打开项目 61 | 62 | ### 3. 修改配置文件 63 | 64 | 修改application.yml中的API-KEY, MySQL, Redis-Stack, Neo4j配置 65 | ### 4. 运行项目 66 | 67 | 后端运行 68 | 69 | 1. 运行ServerApplication.java 70 | 2. 右键target/generated-sources/annotations目录,选择Mark Directory as → Generated Sources Root 71 | 72 | 前端运行,在front-end目录下 73 | 74 | - npm install 75 | - npm run api (先运行后端) 76 | - npm run dev 77 | 78 | 79 | ## 联系方式 80 | 81 | 付费远程运行/安装/定制开发联系微信:ljc666max 82 | 83 | 其他关于程序运行安装报错请加QQ群: 84 | 85 | - 416765656(满) 86 | - 632067985 87 | -------------------------------------------------------------------------------- /database.sql: -------------------------------------------------------------------------------- 1 | -- MySQL dump 10.13 Distrib 8.0.36, for Win64 (x86_64) 2 | -- 3 | -- Host: 127.0.0.1 Database: knowledge_base 4 | -- ------------------------------------------------------ 5 | -- Server version 8.0.36 6 | 7 | /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; 8 | /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; 9 | /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; 10 | /*!50503 SET NAMES utf8mb4 */; 11 | /*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */; 12 | /*!40103 SET TIME_ZONE='+00:00' */; 13 | /*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; 14 | /*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */; 15 | /*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */; 16 | /*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */; 17 | 18 | -- 19 | -- Table structure for table `ai_message` 20 | -- 21 | 22 | DROP TABLE IF EXISTS `ai_message`; 23 | /*!40101 SET @saved_cs_client = @@character_set_client */; 24 | /*!50503 SET character_set_client = utf8mb4 */; 25 | CREATE
TABLE `ai_message` ( 26 | `id` varchar(36) NOT NULL, 27 | `created_time` datetime(6) NOT NULL, 28 | `edited_time` datetime(6) NOT NULL, 29 | `creator_id` varchar(32) NOT NULL, 30 | `editor_id` varchar(32) NOT NULL, 31 | `type` varchar(32) NOT NULL COMMENT '消息类型(用户/助手/系统)', 32 | `text_content` text NOT NULL COMMENT '消息内容', 33 | `medias` json DEFAULT NULL COMMENT '媒体内容如图片链接、语音链接', 34 | `ai_session_id` varchar(32) NOT NULL COMMENT '会话id', 35 | PRIMARY KEY (`id`), 36 | KEY `ai_message_ai_session_id_fk` (`ai_session_id`) 37 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; 38 | /*!40101 SET character_set_client = @saved_cs_client */; 39 | 40 | -- 41 | -- Dumping data for table `ai_message` 42 | -- 43 | 44 | LOCK TABLES `ai_message` WRITE; 45 | /*!40000 ALTER TABLE `ai_message` DISABLE KEYS */; 46 | INSERT INTO `ai_message` VALUES ('05c523bf586b422e8f19d32a6bab17b6','2024-06-23 22:08:33.529844','2024-06-23 22:08:33.529844','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','ASSISTANT','今天的日期是2024年06月23日。',NULL,'b495186b9259494792a9fe4ce8807f25'); 47 | INSERT INTO `ai_message` VALUES ('1d14d513ae8c4ffda0fb50df6e24ba77','2024-06-20 22:08:28.389711','2024-06-20 22:08:28.389711','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','USER','C:\\Users\\Administrator\\Desktop\\2023年工作总结及2024年工作展望-林家成.docx,这份文档的内容是什么?','[]','c4d192c4b2de48bcae585f5b9672cd1d'); 48 | INSERT INTO `ai_message` VALUES ('34d65ad69e4b470794da11732fe44a7d','2024-06-23 22:08:33.337776','2024-06-23 22:08:33.337776','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','USER','今天的日期是多少?','[]','b495186b9259494792a9fe4ce8807f25'); 49 | INSERT INTO `ai_message` VALUES ('ba161500bec84b388a6501b3be333b45','2024-06-20 22:08:28.649439','2024-06-20 22:08:28.649439','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','ASSISTANT','2023年的工作总结及2024年工作展望文档内容概要如下:\n\n1、**工作回顾重点:**\n - 
成功运用Python脚本在一经上云项目中实现Oracle脚本批量迁移、实体批量登记等,极大提升了工作效率(速度提升70倍以上),并提前完成年度工作任务。\n - 领导新智慧工厂项目的前端与后端框架搭建,采用更高效、安全的技术栈,促进团队快速投入业务开发。重构并优化了基础功能模块,如菜单框架、租户管理等,封装为易用组件。\n - 开发了脚本依赖分析工具,通过SQL语法树解析等技术,有效辅助故障检测、代码规范检查及开发效率,与部门其他系统(如指标库、报表系统)集成,增强功能实用性。\n\n2、**个人成长与进步:**\n - 今年最大的成就是深入学习了数据开发知识,综合软件开发能力使自己更能满足用户需求,相比同行具有更全面的技能组合。\n\n3、**应对挑战的方法:**\n - 面对SQL语法解析等高难度任务,通过自学源码、复习理论基础,成功开发出工具,体现了自我驱动力和解决问题的能力。\n\n4、**反思与改进空间:**\n - 希望能在项目中实践深度学习、机器学习技术,这是未来努力的一个方向。\n\n5、**提质增效的建议:**\n - 强调代码审查和个人技能提升,鼓励使用先进技术和工具,以及创建持续学习的环境,以保持竞争力。\n\n6、**对当前开发工作的建议:**\n - 提倡技术分享会议,打破知识孤岛,促进新老员工间的技术交流和学习,平衡稳定与创新的需求。\n\n7、**其他建议:**\n - 加强跨部门合作,以新智慧工厂项目为范例,推动更多跨领域协作成果。\n\n对于2024年的展望虽未详细展开,但基于上述总结,可能包括深化技术应用、促进团队协作、继续个人技能升级等方面。',NULL,'c4d192c4b2de48bcae585f5b9672cd1d'); 50 | /*!40000 ALTER TABLE `ai_message` ENABLE KEYS */; 51 | UNLOCK TABLES; 52 | 53 | -- 54 | -- Table structure for table `ai_session` 55 | -- 56 | 57 | DROP TABLE IF EXISTS `ai_session`; 58 | /*!40101 SET @saved_cs_client = @@character_set_client */; 59 | /*!50503 SET character_set_client = utf8mb4 */; 60 | CREATE TABLE `ai_session` ( 61 | `id` varchar(36) NOT NULL, 62 | `created_time` datetime(6) NOT NULL, 63 | `edited_time` datetime(6) NOT NULL, 64 | `creator_id` varchar(32) NOT NULL, 65 | `editor_id` varchar(32) NOT NULL, 66 | `name` varchar(32) NOT NULL COMMENT '会话名称', 67 | PRIMARY KEY (`id`) 68 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; 69 | /*!40101 SET character_set_client = @saved_cs_client */; 70 | 71 | -- 72 | -- Dumping data for table `ai_session` 73 | -- 74 | 75 | LOCK TABLES `ai_session` WRITE; 76 | /*!40000 ALTER TABLE `ai_session` DISABLE KEYS */; 77 | INSERT INTO `ai_session` VALUES ('b495186b9259494792a9fe4ce8807f25','2024-06-18 22:01:51.459293','2024-06-18 22:01:51.459293','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','新的聊天'); 78 | INSERT INTO `ai_session` VALUES ('c4d192c4b2de48bcae585f5b9672cd1d','2024-06-18 22:01:45.819825','2024-06-18 
22:01:45.819317','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','新的聊天'); 79 | INSERT INTO `ai_session` VALUES ('dcd32fd9f7fc4d0fb9a5b7dc9ff0dbc8','2024-06-18 21:57:09.390651','2024-06-18 21:57:09.390651','fe3d0d7d6eb34eb7a6b7985426cf8af7','fe3d0d7d6eb34eb7a6b7985426cf8af7','新的聊天'); 80 | /*!40000 ALTER TABLE `ai_session` ENABLE KEYS */; 81 | UNLOCK TABLES; 82 | 83 | -- 84 | -- Table structure for table `user` 85 | -- 86 | 87 | DROP TABLE IF EXISTS `user`; 88 | /*!40101 SET @saved_cs_client = @@character_set_client */; 89 | /*!50503 SET character_set_client = utf8mb4 */; 90 | CREATE TABLE `user` ( 91 | `id` varchar(36) NOT NULL, 92 | `created_time` datetime(6) NOT NULL, 93 | `edited_time` datetime(6) NOT NULL, 94 | `nickname` varchar(20) DEFAULT NULL, 95 | `avatar` varchar(255) DEFAULT NULL, 96 | `gender` varchar(36) DEFAULT NULL, 97 | `phone` varchar(20) NOT NULL, 98 | `password` varchar(100) NOT NULL, 99 | PRIMARY KEY (`id`), 100 | UNIQUE KEY `phone` (`phone`) 101 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; 102 | /*!40101 SET character_set_client = @saved_cs_client */; 103 | 104 | -- 105 | -- Dumping data for table `user` 106 | -- 107 | 108 | LOCK TABLES `user` WRITE; 109 | /*!40000 ALTER TABLE `user` DISABLE KEYS */; 110 | INSERT INTO `user` VALUES ('dcd256e2412f4162a6a5fcbd5cfedc84','2024-05-01 16:52:43.364225','2024-05-19 21:30:34.686818','起凡','https://my-community.oss-cn-qingdao.aliyuncs.com/20240501203628ptwondCGhItP67eb5ac72554b07800b22c542245e457.jpeg','MALE','11111111111','$2a$10$o/DHIt/eAMR175TgDV/PeeuEOpqW1N4Klft6obvs2zqBuiwMgLWOW'); 111 | INSERT INTO `user` VALUES ('fe3d0d7d6eb34eb7a6b7985426cf8af7','2024-06-06 13:23:46.130879','2024-06-06 13:23:46.130817','默认用户',NULL,NULL,'13656987994','$2a$10$q7pey1P1/b3lO9nzFLKOb.ISrX7.lkktMjghwhgvNqvA.EjZZ2mg2'); 112 | /*!40000 ALTER TABLE `user` ENABLE KEYS */; 113 | UNLOCK TABLES; 114 | /*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */; 115 | 116 | /*!40101 SET 
SQL_MODE=@OLD_SQL_MODE */; 117 | /*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */; 118 | /*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */; 119 | /*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */; 120 | /*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */; 121 | /*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */; 122 | /*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */; 123 | 124 | -- Dump completed on 2024-06-24 21:11:15 125 | -------------------------------------------------------------------------------- /front-end/.env.development: -------------------------------------------------------------------------------- 1 | VITE_WEBSOCKET_URL=ws://localhost:8877/handshake 2 | VITE_API_PREFIX=/api -------------------------------------------------------------------------------- /front-end/.env.production: -------------------------------------------------------------------------------- 1 | VITE_API_PREFIX= -------------------------------------------------------------------------------- /front-end/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | /* eslint-env node */ 2 | require('@rushstack/eslint-patch/modern-module-resolution') 3 | 4 | module.exports = { 5 | root: true, 6 | extends: [ 7 | 'plugin:vue/vue3-essential', 8 | 'eslint:recommended', 9 | '@vue/eslint-config-typescript', 10 | '@vue/eslint-config-prettier/skip-formatting' 11 | ], 12 | parserOptions: { 13 | ecmaVersion: 'latest' 14 | }, 15 | rules: { 16 | 'prettier/prettier': [ 17 | 'warn', 18 | { 19 | endOfLine: 'auto' 20 | } 21 | ] 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /front-end/.gitignore: -------------------------------------------------------------------------------- 1 | ### IntelliJ IDEA ### 2 | /.idea/ 3 | *.iws 4 | *.iml 5 | *.ipr 6 | # Logs 7 | logs 8 | *.log 9 | npm-debug.log* 10 | yarn-debug.log* 11 | yarn-error.log* 12 | pnpm-debug.log* 13 | 
lerna-debug.log* 14 | 15 | node_modules 16 | .DS_Store 17 | dist 18 | dist-ssr 19 | coverage 20 | *.local 21 | 22 | /cypress/videos/ 23 | /cypress/screenshots/ 24 | 25 | # Editor directories and files 26 | .vscode/* 27 | !.vscode/extensions.json 28 | .idea 29 | *.suo 30 | *.ntvs* 31 | *.njsproj 32 | *.sln 33 | *.sw? 34 | 35 | /src/apis/__generated/ 36 | /scripts/ingredient-admin.tar 37 | -------------------------------------------------------------------------------- /front-end/.prettierrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://json.schemastore.org/prettierrc", 3 | "semi": false, 4 | "tabWidth": 2, 5 | "singleQuote": true, 6 | "printWidth": 100, 7 | "trailingComma": "none" 8 | } -------------------------------------------------------------------------------- /front-end/README.md: -------------------------------------------------------------------------------- 1 | # mall-admin 2 | 3 | This template should help get you started developing with Vue 3 in Vite. 4 | 5 | ## Recommended IDE Setup 6 | 7 | [VSCode](https://code.visualstudio.com/) + [Volar](https://marketplace.visualstudio.com/items?itemName=Vue.volar) (and disable Vetur) + [TypeScript Vue Plugin (Volar)](https://marketplace.visualstudio.com/items?itemName=Vue.vscode-typescript-vue-plugin). 8 | 9 | ## Type Support for `.vue` Imports in TS 10 | 11 | TypeScript cannot handle type information for `.vue` imports by default, so we replace the `tsc` CLI with `vue-tsc` for type checking. In editors, we need [TypeScript Vue Plugin (Volar)](https://marketplace.visualstudio.com/items?itemName=Vue.vscode-typescript-vue-plugin) to make the TypeScript language service aware of `.vue` types. 12 | 13 | If the standalone TypeScript plugin doesn't feel fast enough to you, Volar has also implemented a [Take Over Mode](https://github.com/johnsoncodehk/volar/discussions/471#discussioncomment-1361669) that is more performant. 
You can enable it by the following steps: 14 | 15 | 1. Disable the built-in TypeScript Extension 16 | 1) Run `Extensions: Show Built-in Extensions` from VSCode's command palette 17 | 2) Find `TypeScript and JavaScript Language Features`, right click and select `Disable (Workspace)` 18 | 2. Reload the VSCode window by running `Developer: Reload Window` from the command palette. 19 | 20 | ## Customize configuration 21 | 22 | See [Vite Configuration Reference](https://vitejs.dev/config/). 23 | 24 | ## Project Setup 25 | 26 | ```sh 27 | npm install 28 | ``` 29 | 30 | ### Compile and Hot-Reload for Development 31 | 32 | ```sh 33 | npm run dev 34 | ``` 35 | 36 | ### Type-Check, Compile and Minify for Production 37 | 38 | ```sh 39 | npm run build 40 | ``` 41 | 42 | ### Lint with [ESLint](https://eslint.org/) 43 | 44 | ```sh 45 | npm run lint 46 | ``` 47 | -------------------------------------------------------------------------------- /front-end/env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | declare module 'element-plus/dist/locale/zh-cn.mjs' 3 | -------------------------------------------------------------------------------- /front-end/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Vite App 8 | 9 | 10 |
11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /front-end/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "uni-ai-admin", 3 | "version": "0.0.0", 4 | "private": true, 5 | "type": "module", 6 | "scripts": { 7 | "dev": "vite", 8 | "build": "run-p type-check \"build-only {@}\" --", 9 | "preview": "vite preview", 10 | "build-only": "vite build", 11 | "type-check": "vue-tsc --build --force", 12 | "lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore", 13 | "format": "prettier --write src/", 14 | "api": "node scripts/generate-api.js" 15 | }, 16 | "dependencies": { 17 | "@element-plus/icons-vue": "^2.3.1", 18 | "axios": "^1.6.4", 19 | "dayjs": "^1.11.10", 20 | "element-plus": "^2.4.2", 21 | "lodash": "^4.17.21", 22 | "md-editor-v3": "^4.13.2", 23 | "pinia": "^2.1.7", 24 | "sse.js": "^2.4.1", 25 | "vue": "^3.4.21", 26 | "vue-router": "^4.2.5" 27 | }, 28 | "devDependencies": { 29 | "@rushstack/eslint-patch": "^1.3.3", 30 | "@tsconfig/node18": "^18.2.2", 31 | "@types/node": "^18.19.3", 32 | "@vitejs/plugin-vue": "^4.5.2", 33 | "@vitejs/plugin-vue-jsx": "^3.1.0", 34 | "@vue/eslint-config-prettier": "^8.0.0", 35 | "@vue/eslint-config-typescript": "^12.0.0", 36 | "@vue/tsconfig": "^0.5.0", 37 | "adm-zip": "^0.5.10", 38 | "eslint": "^8.49.0", 39 | "eslint-plugin-vue": "^9.17.0", 40 | "fs-extra": "^11.1.1", 41 | "npm-run-all2": "^6.1.1", 42 | "prettier": "^3.0.3", 43 | "sass": "^1.69.7", 44 | "typescript": "~5.3.0", 45 | "uuid": "^9.0.1", 46 | "vite": "^5.0.10", 47 | "vue-tsc": "^1.8.25" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /front-end/public/favicon.ico: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/public/favicon.ico -------------------------------------------------------------------------------- /front-end/scripts/generate-api.js: -------------------------------------------------------------------------------- 1 | /* eslint-env node */ 2 | import http from 'http' 3 | import fs from 'fs' 4 | import fse from 'fs-extra' 5 | import { v4 as uuidv4 } from 'uuid' 6 | import os from 'os' 7 | import path from 'path' 8 | import AdmZip from 'adm-zip' 9 | const sourceUrl = 'http://localhost:9902/ts.zip' 10 | const tmpFilePath = os.tmpdir() + '/' + uuidv4() + '.zip' 11 | const generatePath = 'src/apis/__generated' 12 | 13 | console.log('Downloading ' + sourceUrl + '...') 14 | 15 | const tmpFile = fs.createWriteStream(tmpFilePath) 16 | http.get(sourceUrl, (response) => { 17 | response.pipe(tmpFile) 18 | tmpFile.on('finish', () => { 19 | tmpFile.close() 20 | console.log('File save success: ', tmpFilePath) 21 | 22 | // Remove generatePath if it exists 23 | if (fs.existsSync(generatePath)) { 24 | console.log('Removing existing generatePath...') 25 | fse.removeSync(generatePath) 26 | console.log('Existing generatePath removed.') 27 | } 28 | 29 | // Unzip the file using adm-zip 30 | console.log('Unzipping the file...') 31 | const zip = new AdmZip(tmpFilePath) 32 | zip.extractAllTo(generatePath, true) 33 | console.log('File unzipped successfully.') 34 | // Remove the temporary file 35 | console.log('Removing temporary file...') 36 | fs.unlink(tmpFilePath, (err) => { 37 | if (err) { 38 | console.error('Error while removing temporary file:', err) 39 | } else { 40 | console.log('Temporary file removed.') 41 | } 42 | }) 43 | traverseDirectory(modelPath) 44 | traverseDirectory(servicePath) 45 | }) 46 | }) 47 | 48 | // 替换目录路径 49 | const modelPath = 'src/apis/__generated/model' 50 | const servicePath = 'src/apis/__generated/services' 51 | 52 | // 递归遍历目录中的所有文件 53 | function 
traverseDirectory(directoryPath) { 54 | const files = fs.readdirSync(directoryPath) 55 | 56 | files.forEach((file) => { 57 | const filePath = path.join(directoryPath, file) 58 | const stats = fs.statSync(filePath) 59 | 60 | if (stats.isDirectory()) { 61 | traverseDirectory(filePath) 62 | } else if (stats.isFile() && path.extname(filePath) === '.ts') { 63 | replaceInFile(filePath) 64 | } 65 | }) 66 | } 67 | 68 | // 替换文件中的文本 69 | function replaceInFile(filePath) { 70 | const fileContent = fs.readFileSync(filePath, 'utf8') 71 | const updatedContent = fileContent 72 | .replaceAll('readonly ', '') 73 | .replace(/ReadonlyArray/g, 'Array') 74 | .replaceAll('ReadonlyMap', 'Map') 75 | .replace(/Map<(\S+), (\S+)>/g, '{ [key: $1]: $2 }') 76 | // .replace(/query: (\S+)/g, 'query: T') 77 | fs.writeFileSync(filePath, updatedContent, 'utf8') 78 | } 79 | -------------------------------------------------------------------------------- /front-end/src/App.vue: -------------------------------------------------------------------------------- 1 | 5 | 6 | 14 | 15 | 79 | -------------------------------------------------------------------------------- /front-end/src/assets/SrpingAI知识点.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/src/assets/SrpingAI知识点.png -------------------------------------------------------------------------------- /front-end/src/assets/background.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/src/assets/background.jpg -------------------------------------------------------------------------------- /front-end/src/assets/cover.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/src/assets/cover.png
--------------------------------------------------------------------------------
/front-end/src/assets/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qifan777/dive-into-spring-ai/406b00b513ddac955e79e76b998487a2cc1fa20f/front-end/src/assets/logo.jpg
--------------------------------------------------------------------------------
/front-end/src/components/image/image-upload.vue:
--------------------------------------------------------------------------------
1 | 
63 | 
64 | 
84 | 
--------------------------------------------------------------------------------
/front-end/src/components/key-value/key-value-input.vue:
--------------------------------------------------------------------------------
1 | 
32 | 
33 | 
70 | 
71 | 
93 | 
--------------------------------------------------------------------------------
/front-end/src/components/key-value/value-input.vue:
--------------------------------------------------------------------------------
1 | 
48 | 
106 | 
107 | 
136 | 
--------------------------------------------------------------------------------
/front-end/src/main.ts:
--------------------------------------------------------------------------------
1 | import { createApp } from 'vue'
2 | import { createPinia } from 'pinia'
3 | import ElementPlus from 'element-plus'
4 | import 'element-plus/dist/index.css'
5 | import * as ElementPlusIconsVue from '@element-plus/icons-vue'
6 | import App from './App.vue'
7 | import router from './router'
8 | import { MdPreview } from 'md-editor-v3'
9 | import 'md-editor-v3/lib/preview.css'
10 | 
11 | const app = createApp(App)
12 | 
13 | // Install the store, the UI kit and the router, in that order.
14 | app.use(createPinia()).use(ElementPlus).use(router)
15 | 
16 | // Global components: the markdown previewer plus every Element Plus icon, each registered
17 | // under its export name so templates can use them without importing them.
18 | app.component('MdPreview', MdPreview)
19 | Object.entries(ElementPlusIconsVue).forEach(([iconName, icon]) => {
20 |   app.component(iconName, icon)
21 | })
22 | 
23 | app.mount('#app')
24 | 
--------------------------------------------------------------------------------
/front-end/src/router/index.ts:
--------------------------------------------------------------------------------
1 | import { createRouter, createWebHashHistory } from 'vue-router'
2 | import RegisterView from '@/views/login/register-view.vue'
3 | import LoginView from '@/views/login/login-view.vue'
4 | import ChatView from '@/views/chat/chat-view.vue'
5 | import AnalyzeResultView from '@/views/code/analyze/analyze-result-view.vue'
6 | 
7 | // Hash-mode history: the route lives after `#` in the URL.
8 | const router = createRouter({
9 |   history: createWebHashHistory(import.meta.env.BASE_URL),
10 |   routes: [
11 |     // The chat view is the landing page.
12 |     { path: '/', component: ChatView },
13 |     { path: '/login', name: 'login', component: LoginView },
14 |     { path: '/register', name: 'register', component: RegisterView },
15 |     {
16 |       path: '/analyze',
17 |       component: AnalyzeResultView,
18 |       // Hand the `?path=` query parameter to the view as its `path` prop.
19 |       props: (to) => ({ path: to.query.path })
20 |     }
21 |   ]
22 | })
23 | 
24 | export default router
25 | 
--------------------------------------------------------------------------------
/front-end/src/stores/home-store.ts:
--------------------------------------------------------------------------------
1 | import { defineStore } from 'pinia'
2 | import { ref } from 'vue'
3 | import type { UserDto } from '@/apis/__generated/model/dto'
4 | import { api } from '@/utils/api-instance'
5 | 
6 | // Store holding the logged-in user's profile.
7 | export const useHomeStore = defineStore('home', () => {
8 |   // Current user; undefined until getUserInfo() resolves, and again after logout().
9 |   const userInfo = ref()
10 | 
11 |   // Fetch the current user from the backend, cache it in userInfo and return it.
12 |   async function getUserInfo() {
13 |     const fetched = await api.userController.userInfo()
14 |     userInfo.value = fetched
15 |     return userInfo.value
16 |   }
17 | 
18 |   // Initialise the store (currently this only loads the user profile).
19 |   async function init() {
20 |     await getUserInfo()
21 |   }
22 | 
23 |   // Forget the cached user locally; no backend call is made here.
24 |   function logout() {
25 |     userInfo.value = undefined
26 |   }
27 | 
28 |   return { userInfo, getUserInfo, init, logout }
29 | })
30 | 
--------------------------------------------------------------------------------
/front-end/src/typings/index.d.ts:
--------------------------------------------------------------------------------
1 | import type { MenuDto } from '@/apis/__generated/model/dto'
2 | 
3 | export type EditMode = 'CREATE' | 'UPDATE'
4 | export interface Scope {
5 |   row: T
6 |   $index: number
7 | }
8 | 
9 | export interface Result {
10 |   code: number
11 |   success: boolean
12 |   msg: string
13 |   result: T
14 | }
15 | export type MenuTreeDto = {
16 |   children?: MenuTreeDto[]
17 | } & MenuDto['MenuRepository/SIMPLE_FETCHER']
18 | 
--------------------------------------------------------------------------------
/front-end/src/utils/api-instance.ts:
--------------------------------------------------------------------------------
1 | import { Api } from '@/apis/__generated'
2 | import { request } from '@/utils/request'
3 | // Export the shared `api` client; every generated call is sent through the axios `request`.
4 | export const api = new Api(async ({ uri, method, body }) => {
5 |   return await request({ url: uri, method, data: body })
6 | })
7 | 
--------------------------------------------------------------------------------
/front-end/src/utils/common.ts:
--------------------------------------------------------------------------------
1 | import { ElMessage } from 'element-plus'
2 | import type { ValidateFieldsError } from 'async-validator/dist-types'
3 | 
4 | export const assertFormValidate = (callback: () => void) => {
5 |   return (valid: boolean, fields: ValidateFieldsError | undefined) => {
6 |     if (valid) {
7 |       callback()
8 |     } else {
9 |       if (fields) {
10 |         const messages: string[] = []
11 |         for (const field in fields) {
12 |           fields[field].forEach((error) => {
13 |             if (error.message) {
14 |               messages.push(error.message)
15 |             }
16 |           })
17 |         }
18 |         ElMessage.error({ message: messages.join('\n') })
19 |       }
20 |     }
21 |   }
22 | }
23 | 
--------------------------------------------------------------------------------
/front-end/src/utils/request.ts:
--------------------------------------------------------------------------------
1 | import axios from 'axios'
2 | import { ElMessage } from 'element-plus'
3 | import router from '@/router'
4 | 
5 | const BASE_URL = import.meta.env.VITE_API_PREFIX
6 | // Backend code that makes the client redirect to the login page (presumably "not logged in").
7 | const LOGIN_REQUIRED_CODE = 10012
8 | 
9 | // Shared axios instance: successful responses are unwrapped to `data.result`,
10 | // failures show a warning message and reject.
11 | export const request = axios.create({
12 |   baseURL: BASE_URL,
13 |   timeout: 600000
14 | })
15 | request.interceptors.response.use(
16 |   (res) => {
17 |     return res.data.result
18 |   },
19 |   (error) => {
20 |     // Cancelled requests were aborted by the caller: reject without a message.
21 |     if (axios.isCancel(error)) {
22 |       return Promise.reject(error)
23 |     }
24 |     const data = error?.response?.data
25 |     // Network errors and timeouts carry no response (and proxies may answer with a non-JSON
26 |     // body); reading `response.data.code` would throw a TypeError, so report axios' message.
27 |     if (!data || typeof data !== 'object') {
28 |       ElMessage.warning({ message: error?.message || '请求失败' })
29 |       return Promise.reject(error)
30 |     }
31 |     // Any code other than 1 is treated as a failure: show the backend's message.
32 |     if (data.code !== 1) {
33 |       ElMessage.warning({ message: data.msg })
34 |     }
35 |     if (data.code === LOGIN_REQUIRED_CODE) {
36 |       router.push('/login')
37 |     }
38 |     return Promise.reject(data.result)
39 |   }
40 | )
41 | 
--------------------------------------------------------------------------------
/front-end/src/views/chat/chat-view.vue:
--------------------------------------------------------------------------------
1 | 
151 | 
256 | 
386 | 
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/markdown-message.vue:
--------------------------------------------------------------------------------
1 | 
6 | 
7 | 
11 | 
12 | 
22 | 
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/message-input.vue:
--------------------------------------------------------------------------------
1 | 
26 | 
27 | 
53 | 
54 | 
77 | 
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/message-row.vue:
--------------------------------------------------------------------------------
1 | 
19 | 
20 | 
21 | 
58 | 
59 | 
113 | 
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/session-item.vue:
--------------------------------------------------------------------------------
1 | 
16 | 
17 | 
41 | 
42 | 
141 | 
--------------------------------------------------------------------------------
/front-end/src/views/chat/components/text-loading.vue:
--------------------------------------------------------------------------------
1 |
9 |
10 |
51 |
--------------------------------------------------------------------------------
/front-end/src/views/chat/store/chat-store.ts:
--------------------------------------------------------------------------------
1 | import { defineStore } from 'pinia'
2 | import { ref } from 'vue'
3 | import type { AiSessionDto } from '@/apis/__generated/model/dto'
4 | import { api } from '@/utils/api-instance'
5 | import type { AiMessageInput, AiSessionInput } from '@/apis/__generated/model/static'
6 | import { ElMessageBox } from 'element-plus'
7 |
8 | export type AiSession = Pick<
9 |   AiSessionDto['AiSessionRepository/FETCHER'],
10 |   'id' | 'name' | 'editedTime'
11 | > & {
12 |   messages: AiMessage[]
13 | }
14 |
15 | // NOTE(review): the Pick key list is reconstructed from src/main/dto/AiMessage.dto
16 | // (AiMessageInput = all AiMessage scalars + session id) — confirm against the original source.
17 | export type AiMessage = Pick<AiMessageInput, 'type' | 'textContent' | 'medias' | 'sessionId'> & {
18 |   id: string
19 | }
20 | /**
21 |  * Chat state: the session list, the currently selected session, and an edit-mode flag
22 |  * that handleUpdateSession clears after saving the active session.
23 |  */
24 | export const useChatStore = defineStore('ai-chat', () => {
25 |   const isEdit = ref(false)
26 |   const activeSession = ref<AiSession>()
27 |   const sessionList = ref<AiSession[]>([])
28 |   // Create a session on the server, re-fetch it, put it at the top of the list and select it
29 |   const handleCreateSession = async (session: AiSessionInput) => {
30 |     const res = await api.aiSessionController.save({ body: session })
31 |     const sessionRes = await api.aiSessionController.findById({ id: res })
32 |     sessionList.value.unshift(sessionRes)
33 |     activeSession.value = sessionList.value[0]
34 |   }
35 |   // Delete a session on the server, remove it from the list and select a neighbour
36 |   const handleDeleteSession = async (session: AiSession) => {
37 |     await api.aiSessionController.delete({ body: [session.id] })
38 |     const index = sessionList.value.findIndex((value) => value.id === session.id)
39 |     // Not in the local list: splice(-1, 1) would remove the last (unrelated) session
40 |     if (index === -1) {
41 |       return
42 |     }
43 |     sessionList.value.splice(index, 1)
44 |     // Prefer the session that slid into `index`; otherwise the new last one
45 |     // (undefined when the list is now empty)
46 |     activeSession.value =
47 |       index < sessionList.value.length ? sessionList.value[index] : sessionList.value[index - 1]
48 |   }
49 |   // Save the active session and leave edit mode
50 |   const handleUpdateSession = async () => {
51 |     if (!activeSession.value) {
52 |       return
53 |     }
54 |     await api.aiSessionController.save({
55 |       body: { ...activeSession.value }
56 |     })
57 |     isEdit.value = false
58 |   }
59 |   // After confirmation, delete the session's message history and reload the session
60 |   const handleClearMessage = async (sessionId: string) => {
61 |     await ElMessageBox.confirm('是否清空会话记录?', '提示')
62 |     await api.aiMessageController.deleteHistory({ sessionId })
63 |     const index = sessionList.value.findIndex((value) => {
64 |       return value.id === sessionId
65 |     })
66 |     activeSession.value = await api.aiSessionController.findById({ id: sessionId })
67 |     // index === -1 would write a non-index "-1" property onto the array
68 |     if (index !== -1) {
69 |       sessionList.value[index] = activeSession.value
70 |     }
71 |   }
72 |   return {
73 |     isEdit,
74 |     activeSession,
75 |     sessionList,
76 |     handleUpdateSession,
77 |     handleCreateSession,
78 |     handleDeleteSession,
79 |     handleClearMessage
80 |   }
81 | })
82 |
--------------------------------------------------------------------------------
/front-end/src/views/code/analyze/analyze-result-view.vue:
--------------------------------------------------------------------------------
1 |
38 |
39 |
52 |
53 |
58 |
--------------------------------------------------------------------------------
/front-end/src/views/code/analyze/collapse-title.vue:
--------------------------------------------------------------------------------
1 |
10 |
11 |
18 |
19 |
40 |
--------------------------------------------------------------------------------
/front-end/src/views/login/login-view.vue:
--------------------------------------------------------------------------------
1 |
42 |
93 |
160 |
--------------------------------------------------------------------------------
/front-end/src/views/login/register-view.vue:
--------------------------------------------------------------------------------
1 |
44 |
91 |
92 |
169 |
--------------------------------------------------------------------------------
/front-end/tsconfig.app.json:
--------------------------------------------------------------------------------
1 | {
2 |   "extends": "@vue/tsconfig/tsconfig.dom.json",
3 |   "include": ["env.d.ts", "src/**/*", "src/**/*.vue"],
4 |   "exclude": ["src/**/__tests__/*"],
5 |   "compilerOptions": {
6 |     "composite": true,
7 |     "noEmit": true,
8 |     "baseUrl": ".",
9 |     "paths": {
10 |       "@/*": ["./src/*"]
11 |     },
12 |     "types": ["element-plus/global"]
13 |   }
14 | }
15 |
--------------------------------------------------------------------------------
/front-end/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |   "files": [],
3 |   "references": [
4 |     {
5 |       "path": "./tsconfig.node.json"
6 |     },
7 |     {
8 |       "path": "./tsconfig.app.json"
9 |     }
10 |   ]
11 | }
12 |
--------------------------------------------------------------------------------
/front-end/tsconfig.node.json:
--------------------------------------------------------------------------------
1 | {
2 |   "extends": "@tsconfig/node18/tsconfig.json",
3 |   "include": [
4 |     "vite.config.*",
5 |     "vitest.config.*",
6 |     "cypress.config.*",
7 |     "nightwatch.conf.*",
8 |     "playwright.config.*"
9 |   ],
10 |   "compilerOptions": {
11 |     "composite": true,
12 |     "noEmit": true,
13 |     "module": "ESNext",
14 |     "moduleResolution": "Bundler",
15 |     "types": ["node"]
16 |   }
17 | }
18 |
--------------------------------------------------------------------------------
/front-end/vite.config.ts:
--------------------------------------------------------------------------------
1 | import { fileURLToPath, URL } from 'node:url'
2 |
3 | import { defineConfig } from 'vite'
4 | import vue from '@vitejs/plugin-vue'
5 | import vueJsx from '@vitejs/plugin-vue-jsx'
6 |
7 | // https://vitejs.dev/config/
8 | export default defineConfig({
9 |   plugins: [vue(), vueJsx()],
10 |   resolve: {
11 |     alias: {
12 |       '@': fileURLToPath(new URL('./src', import.meta.url))
13 |     }
14 |   },
15 |   server: {
16 |     host: '0.0.0.0',
17 |     port: 5177,
18 |     proxy: {
19 |       '/api': {
20 |         target: 'http://localhost:9902',
21 |         changeOrigin: true,
22 |         rewrite: (path) => path.replace(/^\/api/, '')
23 |       }
24 |     }
25 |   }
26 | })
27 |
--------------------------------------------------------------------------------
/pom.xml:
-------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | io.github.qifan777 8 | dive-into-spring-ai 9 | 1.0-SNAPSHOT 10 | 11 | 12 | spring-boot-starter-parent 13 | org.springframework.boot 14 | 15 | 3.2.1 16 | 17 | 18 | 19 | 17 20 | 17 21 | UTF-8 22 | 1.0.0 23 | 1.0.0.1 24 | 0.1.9 25 | 0.8.134 26 | 5.8.25 27 | 1.37.0 28 | UTF-8 29 | UTF-8 30 | 31 | 32 | 33 | com.squareup.okhttp3 34 | okhttp 35 | 36 | 37 | com.github.javaparser 38 | javaparser-symbol-solver-core 39 | 3.26.2 40 | 41 | 42 | com.alibaba 43 | easyexcel 44 | 3.3.4 45 | 46 | 47 | commons-io 48 | commons-io 49 | 50 | 51 | 52 | 53 | org.springframework.boot 54 | spring-boot-starter-data-neo4j 55 | 56 | 57 | org.springframework.ai 58 | spring-ai-starter-vector-store-redis 59 | 60 | 61 | redis.clients 62 | jedis 63 | 64 | 65 | org.springframework.ai 66 | spring-ai-advisors-vector-store 67 | 68 | 69 | org.springframework.ai 70 | spring-ai-tika-document-reader 71 | 72 | 73 | jimmer-spring-boot-starter 74 | org.babyfish.jimmer 75 | ${jimmer.version} 76 | 77 | 78 | io.github.qifan777 79 | spring-boot-starter-oss 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | com.alibaba.cloud.ai 104 | spring-ai-alibaba-starter-dashscope 105 | ${spring-ai-alibaba.version} 106 | 107 | 108 | org.springframework.boot 109 | spring-boot-starter-actuator 110 | 111 | 112 | org.springframework.boot 113 | spring-boot-starter-test 114 | test 115 | 116 | 117 | org.springframework 118 | spring-test 119 | 120 | 121 | mysql-connector-j 122 | com.mysql 123 | 124 | 125 | hutool-core 126 | cn.hutool 127 | ${hutool.version} 128 | 129 | 130 | cn.dev33 131 | sa-token-redis-jackson 132 | ${sa-token.version} 133 | 134 | 135 | cn.dev33 136 | sa-token-spring-boot3-starter 137 | ${sa-token.version} 138 | 139 | 140 | 141 | 142 | 143 | org.springframework.ai 144 | spring-ai-bom 145 | ${spring-ai.version} 146 | pom 147 | import 148 | 
149 | 150 | io.github.qifan777 151 | uni-ai-bom 152 | ${uni-ai.version} 153 | pom 154 | import 155 | 156 | 157 | 158 | 159 | 160 | spring-milestones 161 | Spring Milestones 162 | https://repo.spring.io/milestone 163 | 164 | false 165 | 166 | 167 | 168 | 169 | 170 | 171 | org.springframework.boot 172 | spring-boot-maven-plugin 173 | 174 | 175 | org.apache.maven.plugins 176 | maven-compiler-plugin 177 | 178 | 179 | 180 | jimmer-apt 181 | org.babyfish.jimmer 182 | ${jimmer.version} 183 | 184 | 185 | lombok 186 | org.projectlombok 187 | ${lombok.version} 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /src/main/dto/AiMessage.dto: -------------------------------------------------------------------------------- 1 | export io.github.qifan777.knowledge.ai.message.AiMessage 2 | 3 | input AiMessageInput{ 4 | #allScalars(AiMessage) 5 | id(session) 6 | } -------------------------------------------------------------------------------- /src/main/dto/AiSession.dto: -------------------------------------------------------------------------------- 1 | export io.github.qifan777.knowledge.ai.session.AiSession 2 | 3 | input AiSessionInput{ 4 | id? 
5 | name 6 | } -------------------------------------------------------------------------------- /src/main/dto/User.dto: -------------------------------------------------------------------------------- 1 | export io.github.qifan777.knowledge.user.User 2 | 3 | input UserLoginInput{ 4 | phone 5 | password 6 | } 7 | input UserRegisterInput{ 8 | phone 9 | password 10 | } -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ServerApplication.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge; 2 | 3 | import io.github.qifan777.knowledge.infrastructure.code.CodeAssistantProperties; 4 | import org.babyfish.jimmer.client.EnableImplicitApi; 5 | import org.springframework.boot.SpringApplication; 6 | import org.springframework.boot.autoconfigure.SpringBootApplication; 7 | import org.springframework.boot.context.properties.EnableConfigurationProperties; 8 | 9 | @SpringBootApplication 10 | @EnableImplicitApi 11 | @EnableConfigurationProperties(CodeAssistantProperties.class) 12 | public class ServerApplication { 13 | public static void main(String[] args) { 14 | SpringApplication.run(ServerApplication.class, args); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/agent/AbstractAgent.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.agent; 2 | 3 | import lombok.extern.slf4j.Slf4j; 4 | import org.springframework.context.annotation.Description; 5 | import org.springframework.util.StringUtils; 6 | 7 | import java.util.Arrays; 8 | import java.util.List; 9 | 10 | @Slf4j 11 | public abstract class AbstractAgent { 12 | 13 | /** 14 | * 获取指定的function bean名称 15 | * 16 | * @return Function Call名称列表 17 | */ 18 | public String[] getFunctions(Class... 
clazz) { 19 | List> classList = Arrays.stream(clazz).filter(aClass -> aClass.isAnnotationPresent(Description.class)).toList(); 20 | String[] names = new String[classList.size()]; 21 | classList.stream().map(aClass -> StringUtils.uncapitalize(aClass.getSimpleName())).toList().toArray(names); 22 | return names; 23 | } 24 | /** 25 | * 获取内嵌的Function Call也就是Agent的Tools 26 | * 27 | * @return Function Call名称列表 28 | */ 29 | public String[] getAgentFunctions(Class clazz) { 30 | List> classList = Arrays.stream(clazz.getClasses()).filter(aClass -> aClass.isAnnotationPresent(Description.class)).toList(); 31 | String[] names = new String[classList.size()]; 32 | classList.stream().map(aClass -> StringUtils.uncapitalize(this.getClass().getSimpleName()) + "." + aClass.getSimpleName()).toList().toArray(names); 33 | return names; 34 | } 35 | 36 | } -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/agent/Agent.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.agent; 2 | 3 | import org.springframework.stereotype.Component; 4 | 5 | import java.lang.annotation.ElementType; 6 | import java.lang.annotation.Retention; 7 | import java.lang.annotation.RetentionPolicy; 8 | import java.lang.annotation.Target; 9 | 10 | @Component 11 | @Retention(RetentionPolicy.RUNTIME) 12 | @Target(ElementType.TYPE) 13 | public @interface Agent { 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/agent/chronologist/Chronologist.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.agent.chronologist; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription; 5 | import io.github.qifan777.knowledge.ai.agent.AbstractAgent; 
6 | import io.github.qifan777.knowledge.ai.agent.Agent; 7 | import lombok.AllArgsConstructor; 8 | import org.springframework.ai.chat.client.ChatClient; 9 | import org.springframework.ai.chat.model.ChatModel; 10 | import org.springframework.context.annotation.Description; 11 | import org.springframework.stereotype.Component; 12 | 13 | import java.time.LocalDate; 14 | import java.time.LocalDateTime; 15 | import java.time.format.DateTimeFormatter; 16 | import java.util.function.Function; 17 | 18 | @Agent 19 | @Description(value = "回答用户有关于日期、时间的提问") 20 | @AllArgsConstructor 21 | public class Chronologist extends AbstractAgent implements Function { 22 | private final String SYSTEM = """ 23 | 你是一个专业的编年史学家,可以回答有关时间的问题。 24 | 您还可以执行各种与时间相关的任务,如转换和格式化。 25 | """; 26 | private final ChatModel chatModel; 27 | 28 | 29 | @Override 30 | public String apply(Request request) { 31 | return ChatClient.create(chatModel) 32 | .prompt() 33 | .system(SYSTEM) 34 | .user(request.query) 35 | .toolNames(getAgentFunctions(this.getClass())) 36 | .call() 37 | .content(); 38 | } 39 | 40 | public record Request( 41 | @JsonProperty(required = true) @JsonPropertyDescription(value = "用户原始的提问") String query) { 42 | } 43 | 44 | @Component 45 | @Description("获取当前的时间,格式是 HH:mm:ss") 46 | public static class CurrentTime implements Function { 47 | @Override 48 | public String apply(Request request) { 49 | LocalDateTime currentDate = LocalDateTime.now(); 50 | return currentDate.toLocalTime().format(DateTimeFormatter.ofPattern("HH:mm:ss")); 51 | 52 | } 53 | 54 | public record Request() { 55 | } 56 | } 57 | 58 | @Component 59 | @Description("获取当前的日期,格式是 yyyy-MM-dd") 60 | public static class CurrentDate implements Function { 61 | @Override 62 | public String apply(Request request) { 63 | LocalDate currentDate = LocalDate.now(); 64 | return currentDate.format(DateTimeFormatter.ofPattern("yyyy-MM-dd")); 65 | } 66 | 67 | public record Request() { 68 | } 69 | } 70 | } 
-------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/agent/computer/ComputerAssistant.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.agent.computer; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription; 5 | import io.github.qifan777.knowledge.ai.agent.AbstractAgent; 6 | import io.github.qifan777.knowledge.ai.agent.Agent; 7 | import lombok.AllArgsConstructor; 8 | import org.springframework.ai.chat.client.ChatClient; 9 | import org.springframework.ai.chat.model.ChatModel; 10 | import org.springframework.context.annotation.Description; 11 | 12 | import java.util.function.Function; 13 | 14 | @Agent 15 | @Description("提供关于当前主机的cpu,文件,文件夹相关问题的有用回答") 16 | @AllArgsConstructor 17 | public class ComputerAssistant extends AbstractAgent implements Function { 18 | private final ChatModel chatModel; 19 | 20 | @Override 21 | public String apply(Request request) { 22 | return ChatClient.create(chatModel) 23 | .prompt() 24 | .toolNames(getFunctions(CpuAnalyzer.class, DirectoryReader.class)) 25 | .user(request.query()) 26 | .call() 27 | .content(); 28 | } 29 | 30 | public record Request( 31 | @JsonProperty(required = true) @JsonPropertyDescription(value = "用户原始的提问") String query) { 32 | } 33 | 34 | 35 | } -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/agent/computer/CpuAnalyzer.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.agent.computer; 2 | 3 | import org.springframework.context.annotation.Description; 4 | import org.springframework.stereotype.Component; 5 | 6 | import java.util.function.Function; 7 | 8 | @Component 9 | @Description("读取CPU的数量") 10 | public class CpuAnalyzer implements 
Function { 11 | @Override 12 | public Integer apply(Request request) { 13 | return Runtime.getRuntime().availableProcessors(); 14 | } 15 | 16 | public record Request() { 17 | } 18 | } -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/agent/computer/DirectoryReader.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.agent.computer; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription; 5 | import org.springframework.context.annotation.Description; 6 | import org.springframework.stereotype.Component; 7 | 8 | import java.io.File; 9 | import java.util.ArrayList; 10 | import java.util.Arrays; 11 | import java.util.List; 12 | import java.util.function.Function; 13 | 14 | @Component 15 | @Description("读取用户给定的文件夹,列出文件夹下的所有文件") 16 | public class DirectoryReader implements Function { 17 | @Override 18 | public String apply(Request request) { 19 | File f = new File(request.path); 20 | List out = new ArrayList<>(); 21 | if (f.exists()) { 22 | String[] list = f.list(); 23 | if (list != null) { 24 | out = Arrays.asList(list); 25 | } 26 | } 27 | return String.join(",", out); 28 | } 29 | 30 | public record Request( 31 | @JsonProperty(required = true) @JsonPropertyDescription("本机文件夹的绝对路径") String path 32 | ) { 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/document/DocumentController.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.document; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.SneakyThrows; 5 | import org.springframework.ai.document.Document; 6 | import org.springframework.ai.reader.tika.TikaDocumentReader; 7 | import 
org.springframework.ai.transformer.splitter.TokenTextSplitter; 8 | import org.springframework.ai.vectorstore.VectorStore; 9 | import org.springframework.core.io.InputStreamResource; 10 | import org.springframework.web.bind.annotation.PostMapping; 11 | import org.springframework.web.bind.annotation.RequestMapping; 12 | import org.springframework.web.bind.annotation.RequestParam; 13 | import org.springframework.web.bind.annotation.RestController; 14 | import org.springframework.web.multipart.MultipartFile; 15 | 16 | import java.util.List; 17 | 18 | @RequestMapping("document") 19 | @RestController 20 | @AllArgsConstructor 21 | public class DocumentController { 22 | private final VectorStore vectorStore; 23 | 24 | /** 25 | * 嵌入文件 26 | * 27 | * @param file 待嵌入的文件 28 | * @return 是否成功 29 | */ 30 | @SneakyThrows 31 | @PostMapping("embedding") 32 | public Boolean embedding(@RequestParam MultipartFile file) { 33 | // 从IO流中读取文件 34 | TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(file.getInputStream())); 35 | // 将文本内容划分成更小的块 36 | List splitDocuments = new TokenTextSplitter() 37 | .apply(tikaDocumentReader.read()); 38 | // 存入向量数据库,这个过程会自动调用embeddingModel,将文本变成向量再存入。 39 | vectorStore.add(splitDocuments); 40 | return true; 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/message/AiMessage.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.message; 2 | 3 | import io.github.qifan777.knowledge.ai.session.AiSession; 4 | import io.github.qifan777.knowledge.infrastructure.jimmer.BaseEntity; 5 | import jakarta.validation.constraints.Null; 6 | import lombok.AllArgsConstructor; 7 | import lombok.Data; 8 | import lombok.experimental.Accessors; 9 | import org.babyfish.jimmer.sql.*; 10 | import org.springframework.ai.chat.messages.MessageType; 11 | 12 | import 
java.util.List; 13 | 14 | /** 15 | * 历史消息 16 | */ 17 | 18 | @Entity 19 | public interface AiMessage extends BaseEntity { 20 | 21 | /** 22 | * 消息类型(用户/助手/系统) 23 | */ 24 | MessageType type(); 25 | 26 | /** 27 | * 消息内容 28 | */ 29 | String textContent(); 30 | 31 | @Serialized 32 | @Null 33 | List medias(); 34 | 35 | @IdView 36 | String sessionId(); 37 | 38 | /** 39 | * 会话 40 | */ 41 | @ManyToOne 42 | @JoinColumn(name = "ai_session_id") 43 | @OnDissociate(DissociateAction.DELETE) 44 | AiSession session(); 45 | 46 | @Data 47 | @Accessors(chain = true) 48 | class Media { 49 | public String type; 50 | public String data; 51 | } 52 | } 53 | 54 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/message/AiMessageChatMemory.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.message; 2 | 3 | import cn.hutool.core.collection.CollectionUtil; 4 | import io.qifan.infrastructure.common.exception.BusinessException; 5 | import lombok.AllArgsConstructor; 6 | import lombok.SneakyThrows; 7 | import org.jetbrains.annotations.NotNull; 8 | import org.springframework.ai.chat.memory.ChatMemory; 9 | import org.springframework.ai.chat.messages.*; 10 | import org.springframework.ai.content.Media; 11 | import org.springframework.http.MediaType; 12 | import org.springframework.stereotype.Service; 13 | 14 | import java.net.URL; 15 | import java.util.ArrayList; 16 | import java.util.List; 17 | 18 | @Service 19 | @AllArgsConstructor 20 | public class AiMessageChatMemory implements ChatMemory { 21 | private final AiMessageRepository messageRepository; 22 | 23 | public static AiMessage toAiMessage(Message message, String sessionId) { 24 | return AiMessageDraft.$.produce(draft -> { 25 | draft.setSessionId(sessionId) 26 | .setTextContent(message.getText()) 27 | .setType(message.getMessageType()) 28 | .setMedias(new ArrayList<>()); 29 | if (message 
instanceof UserMessage userMessage && 30 | !CollectionUtil.isEmpty(userMessage.getMedia())) { 31 | List mediaList = userMessage 32 | .getMedia() 33 | .stream() 34 | .map(media -> new AiMessage.Media() 35 | .setType(media.getMimeType().getType()) 36 | .setData(media.getData().toString())) 37 | .toList(); 38 | draft.setMedias(mediaList); 39 | } 40 | }); 41 | } 42 | 43 | public static Message toSpringAiMessage(AiMessage aiMessage) { 44 | List mediaList = new ArrayList<>(); 45 | if (!CollectionUtil.isEmpty(aiMessage.medias())) { 46 | mediaList = aiMessage.medias().stream().map(AiMessageChatMemory::toSpringAiMedia).toList(); 47 | } 48 | if (aiMessage.type().equals(MessageType.ASSISTANT)) { 49 | return new AssistantMessage(aiMessage.textContent()); 50 | } 51 | if (aiMessage.type().equals(MessageType.USER)) { 52 | return UserMessage.builder().text(aiMessage.textContent()) 53 | .media(mediaList) 54 | .build(); 55 | } 56 | if (aiMessage.type().equals(MessageType.SYSTEM)) { 57 | return new SystemMessage(aiMessage.textContent()); 58 | } 59 | throw new BusinessException("不支持的消息类型"); 60 | } 61 | 62 | @SneakyThrows 63 | public static Media toSpringAiMedia(AiMessage.Media media) { 64 | return Media.builder().data(new URL(media.getData())) 65 | .mimeType(new MediaType(media.getType())) 66 | .build(); 67 | } 68 | 69 | /** 70 | * 不实现,手动前端发起请求保存用户的消息和大模型回复的消息 71 | */ 72 | @Override 73 | public void add(@NotNull String conversationId, @NotNull List messages) { 74 | } 75 | 76 | 77 | /** 78 | * 查询会话内的消息最新n条历史记录 79 | * 80 | * @param conversationId 会话id 81 | * @return org.springframework.ai.chat.messages.Message格式的消息 82 | */ 83 | @Override 84 | public @NotNull List get(@NotNull String conversationId) { 85 | return messageRepository 86 | // 查询会话内的最新10条消息 87 | .findBySessionId(conversationId, 10) 88 | .stream() 89 | // 转成Message对象 90 | .map(AiMessageChatMemory::toSpringAiMessage) 91 | .toList(); 92 | } 93 | 94 | /** 95 | * 清除会话内的消息 96 | * 97 | * @param conversationId 会话id 98 | */ 99 | 
@Override 100 | public void clear(@NotNull String conversationId) { 101 | messageRepository.deleteBySessionId(conversationId); 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/ai/message/AiMessageController.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.ai.message; 2 | 3 | import com.fasterxml.jackson.databind.ObjectMapper; 4 | import io.github.qifan777.knowledge.ai.agent.Agent; 5 | import io.github.qifan777.knowledge.ai.message.dto.AiMessageInput; 6 | import io.github.qifan777.knowledge.ai.message.dto.AiMessageWrapper; 7 | import lombok.AllArgsConstructor; 8 | import lombok.SneakyThrows; 9 | import lombok.extern.slf4j.Slf4j; 10 | import org.springframework.ai.chat.client.ChatClient; 11 | import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; 12 | import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor; 13 | import org.springframework.ai.chat.messages.Message; 14 | import org.springframework.ai.chat.messages.UserMessage; 15 | import org.springframework.ai.chat.model.ChatModel; 16 | import org.springframework.ai.chat.model.ChatResponse; 17 | import org.springframework.ai.chat.prompt.PromptTemplate; 18 | import org.springframework.ai.content.Media; 19 | import org.springframework.ai.reader.tika.TikaDocumentReader; 20 | import org.springframework.ai.vectorstore.VectorStore; 21 | import org.springframework.context.ApplicationContext; 22 | import org.springframework.core.io.InputStreamResource; 23 | import org.springframework.http.MediaType; 24 | import org.springframework.http.codec.ServerSentEvent; 25 | import org.springframework.util.CollectionUtils; 26 | import org.springframework.web.bind.annotation.*; 27 | import org.springframework.web.multipart.MultipartFile; 28 | import reactor.core.publisher.Flux; 29 | 30 | import java.util.Map; 31 | 32 
| @RequestMapping("message") 33 | @RestController 34 | @AllArgsConstructor 35 | @Slf4j 36 | public class AiMessageController { 37 | private final AiMessageChatMemory chatMemory; 38 | private final ChatModel chatModel; 39 | // private final ImageModel imageModel; 40 | private final VectorStore vectorStore; 41 | private final ObjectMapper objectMapper; 42 | private final AiMessageRepository messageRepository; 43 | private final ApplicationContext applicationContext; 44 | 45 | @DeleteMapping("history/{sessionId}") 46 | public void deleteHistory(@PathVariable String sessionId) { 47 | chatMemory.clear(sessionId); 48 | } 49 | 50 | /** 51 | * 消息保存 52 | * 53 | * @param input 用户发送的消息/AI回复的消息 54 | */ 55 | @PostMapping 56 | public void save(@RequestBody AiMessageInput input) { 57 | messageRepository.save(input.toEntity()); 58 | } 59 | 60 | // @PostMapping("chat/image") 61 | // public String textToImageChat(@RequestBody AiMessageInput input) { 62 | // return imageModel.call(new ImagePrompt(input.getTextContent())).getResult().getOutput().getUrl(); 63 | // } 64 | 65 | /** 66 | * 为了支持文件问答,需要同时接收json(AiMessageWrapper json体)和 MultipartFile(文件) 67 | * Content-Type 从 application/json 修改为 multipart/form-data 68 | * 之前接收请求参数是用@RequestBody, 现在使用@RequestPart 接收json字符串再手动转成AiMessageWrapper. 
69 | * SpringMVC的@RequestPart是支持自动将Json字符串转换为Java对象,也就是说可以等效`@RequestBody`, 70 | * 但是由于前端FormData无法设置Part的Content-Type,所以只能手动转json字符串再转成Java对象。 71 | * 72 | * @param input 消息包含文本信息,会话id,多媒体信息(图片语言)。参考src/main/dto/AiMessage.dto 73 | * @param file 文件问答 74 | * @return SSE流 75 | */ 76 | @SneakyThrows 77 | @PostMapping(value = "chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE) 78 | public Flux> chat(@RequestPart String input, @RequestPart(required = false) MultipartFile file) { 79 | AiMessageWrapper aiMessageWrapper = objectMapper.readValue(input, AiMessageWrapper.class); 80 | String[] functionBeanNames = new String[0]; 81 | // 如果启用Agent则获取Agent的bean 82 | if (aiMessageWrapper.getParams().getEnableAgent()) { 83 | // 获取带有Agent注解的bean 84 | Map beansWithAnnotation = applicationContext.getBeansWithAnnotation(Agent.class); 85 | functionBeanNames = new String[beansWithAnnotation.size()]; 86 | functionBeanNames = beansWithAnnotation.keySet().toArray(functionBeanNames); 87 | } 88 | return ChatClient.create(chatModel).prompt() 89 | // 启用文件问答 90 | .system(promptSystemSpec -> useFile(promptSystemSpec, file)) 91 | .user(promptUserSpec -> toPrompt(promptUserSpec, aiMessageWrapper.getMessage())) 92 | // agent列表 93 | .toolNames(functionBeanNames) 94 | .advisors(advisorSpec -> { 95 | // 使用历史消息 96 | useChatHistory(advisorSpec, aiMessageWrapper.getMessage().getSessionId()); 97 | // 使用向量数据库 98 | useVectorStore(advisorSpec, aiMessageWrapper.getParams().getEnableVectorStore()); 99 | }) 100 | .stream() 101 | .chatResponse() 102 | .map(chatResponse -> ServerSentEvent.builder(toJson(chatResponse)) 103 | // 和前端监听的事件相对应 104 | .event("message") 105 | .build()); 106 | } 107 | 108 | @SneakyThrows 109 | public String toJson(ChatResponse response) { 110 | return objectMapper.writeValueAsString(response); 111 | } 112 | 113 | public void toPrompt(ChatClient.PromptUserSpec promptUserSpec, AiMessageInput input) { 114 | // AiMessageInput转成Message 115 | Message message = 
AiMessageChatMemory.toSpringAiMessage(input.toEntity()); 116 | if (message instanceof UserMessage userMessage && 117 | !CollectionUtils.isEmpty(userMessage.getMedia())) { 118 | // 用户发送的图片/语言 119 | Media[] medias = new Media[userMessage.getMedia().size()]; 120 | promptUserSpec.media(userMessage.getMedia().toArray(medias)); 121 | } 122 | // 用户发送的文本 123 | promptUserSpec.text(message.getText()); 124 | } 125 | 126 | public void useChatHistory(ChatClient.AdvisorSpec advisorSpec, String sessionId) { 127 | // 1. 如果需要存储会话和消息到数据库,自己可以实现ChatMemory接口,这里使用自己实现的AiMessageChatMemory,数据库存储。 128 | // 2. 传入会话id,MessageChatMemoryAdvisor会根据会话id去查找消息。 129 | // 3. 只需要携带最近10条消息 130 | // MessageChatMemoryAdvisor会在消息发送给大模型之前,从ChatMemory中获取会话的历史消息,然后一起发送给大模型。 131 | advisorSpec.advisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId(sessionId).build()); 132 | } 133 | 134 | public void useVectorStore(ChatClient.AdvisorSpec advisorSpec, Boolean enableVectorStore) { 135 | if (!enableVectorStore) return; 136 | // question_answer_context是一个占位符,会替换成向量数据库中查询到的文档。QuestionAnswerAdvisor会替换。 137 | String promptWithContext = """ 138 | {query} 139 | 下面是上下文信息 140 | --------------------- 141 | {question_answer_context} 142 | --------------------- 143 | 给定的上下文和提供的历史信息,而不是事先的知识,回复用户的意见。如果答案不在上下文中,告诉用户你不能回答这个问题。 144 | """; 145 | advisorSpec.advisors(QuestionAnswerAdvisor.builder(vectorStore) 146 | .promptTemplate(new PromptTemplate(promptWithContext)) 147 | .build()); 148 | } 149 | 150 | @SneakyThrows 151 | public void useFile(ChatClient.PromptSystemSpec spec, MultipartFile file) { 152 | if (file == null) return; 153 | String content = new TikaDocumentReader(new InputStreamResource(file.getInputStream())).get().get(0).getText(); 154 | Message message = new PromptTemplate(""" 155 | 已下内容是额外的知识,在你回答问题时可以参考下面的内容 156 | --------------------- 157 | {context} 158 | --------------------- 159 | """) 160 | .createMessage(Map.of("context", content)); 161 | spec.text(message.getText()); 162 | } 163 | 164 | } 
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/AiMessageRepository.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.ai.message;

import org.babyfish.jimmer.spring.repository.JRepository;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Jimmer repository for chat messages.
 */
public interface AiMessageRepository extends JRepository<AiMessage, String> {
    AiMessageTable t = AiMessageTable.$;

    /**
     * Returns the {@code lastN} most recent messages of a session, ordered oldest-first.
     * <p>
     * The query sorts by creation time descending so that {@code limit} keeps the newest rows (sorting
     * ascending, as before, kept the oldest {@code lastN} instead); the page is then reversed in memory so
     * callers still receive chronological order.
     *
     * @param sessionId session whose messages are loaded
     * @param lastN     maximum number of messages to return
     * @return at most {@code lastN} messages, oldest first
     */
    default List<AiMessage> findBySessionId(String sessionId, int lastN) {
        List<AiMessage> newestFirst = sql()
                .createQuery(t)
                .where(t.sessionId().eq(sessionId))
                .orderBy(t.createdTime().desc())
                .select(t)
                .limit(lastN)
                .execute();
        // copy first: the list returned by execute() is not guaranteed to be mutable
        List<AiMessage> chronological = new ArrayList<>(newestFirst);
        Collections.reverse(chronological);
        return chronological;
    }

    /**
     * Deletes every message of a session.
     *
     * @param sessionId session whose messages are deleted
     * @return the affected-row count returned by the delete statement's execute()
     */
    default Integer deleteBySessionId(String sessionId) {
        return sql().createDelete(t)
                .where(t.sessionId().eq(sessionId))
                .execute();
    }
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/MessageInputWrapper.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.ai.message;

import lombok.Data;

/**
 * Holder for a single {@link AiMessage}; Lombok generates the accessors.
 */
@Data
public class MessageInputWrapper {
    AiMessage message;
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/dto/AiMessageParams.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.ai.message.dto;

import lombok.Data;

/**
 * Optional per-request switches of a chat request. Both flags are nullable wrappers.
 */
@Data
public class AiMessageParams {
    // whether the vector-store (question-answer) advisor should be added to the request
    Boolean enableVectorStore;
    // whether the @Agent beans should be offered to the model as tools
    Boolean enableAgent;
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/message/dto/AiMessageWrapper.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.ai.message.dto;

import lombok.Data;

/**
 * Request body of the chat endpoint: the message itself plus the request switches.
 */
@Data
public class AiMessageWrapper {
    AiMessageInput message;
    AiMessageParams params;
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/session/AiSession.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.ai.session;

import io.github.qifan777.knowledge.ai.message.AiMessage;
import io.github.qifan777.knowledge.infrastructure.jimmer.BaseEntity;
import org.babyfish.jimmer.sql.Entity;
import org.babyfish.jimmer.sql.OneToMany;
import org.babyfish.jimmer.sql.OrderedProp;

import java.util.List;

/**
 * Chat session.
 */
@Entity
public interface AiSession extends BaseEntity {

    /**
     * Session name.
     */
    String name();

    /**
     * Messages of this session (one-to-many), ordered by creation time ascending.
     */
    @OneToMany(mappedBy = "session", orderedProps = @OrderedProp(value = "createdTime"))
    List<AiMessage> messages();
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/session/AiSessionController.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.ai.session;

import io.github.qifan777.knowledge.ai.session.dto.AiSessionInput;
import io.qifan.infrastructure.common.exception.BusinessException;
import lombok.AllArgsConstructor;
import org.babyfish.jimmer.client.FetchBy;
import org.springframework.web.bind.annotation.*;

import java.util.List;

@RequestMapping("session")
@RestController
@AllArgsConstructor
public class AiSessionController {
    private final AiSessionRepository sessionRepository;

    /**
     * Finds a session by id.
     *
     * @param id session id
     * @return the session, fetched with {@code AiSessionRepository.FETCHER}
     * @throws BusinessException if no session has this id
     */
    @GetMapping("{id}")
    public @FetchBy(value = "FETCHER", ownerType = AiSessionRepository.class) AiSession findById(@PathVariable String id) {
        return sessionRepository.findById(id, AiSessionRepository.FETCHER).orElseThrow(() -> new BusinessException("会话不存在"));
    }

    /**
     * Saves a session.
     *
     * @param input session DTO, see src/main/dto/AiSession.dto
     * @return id of the saved session
     */
    @PostMapping("save")
    public String save(@RequestBody AiSessionInput input) {
        return sessionRepository.save(input.toEntity()).id();
    }

    /**
     * Lists the sessions of the currently logged-in user.
     *
     * @return session list
     */
    @GetMapping("user")
    public List<@FetchBy(value = "FETCHER", ownerType = AiSessionRepository.class) AiSession> findByUser() {
        return sessionRepository.findByUser();
    }

    /**
     * Deletes sessions in batch.
     *
     * @param ids ids of the sessions to delete
     */
    @DeleteMapping
    public void delete(@RequestBody List<String> ids) {
        sessionRepository.deleteByIds(ids);
    }
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/ai/session/AiSessionRepository.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.ai.session;

import cn.dev33.satoken.stp.StpUtil;
import io.github.qifan777.knowledge.ai.message.AiMessageFetcher;
import org.babyfish.jimmer.spring.repository.JRepository;

import java.util.List;

public interface AiSessionRepository extends JRepository<AiSession, String> {
    AiSessionTable t = AiSessionTable.$;
    /**
     * Session scalar fields plus its messages (message scalar fields and session id).
     */
    AiSessionFetcher FETCHER = AiSessionFetcher.$.allScalarFields()
            .messages(AiMessageFetcher.$.allScalarFields().sessionId());

    /**
     * Sessions whose creatorId equals the login id of the current Sa-Token login.
     *
     * @return sessions fetched with {@link #FETCHER}
     */
    default List<AiSession> findByUser() {
        return sql().createQuery(t)
                .where(t.creatorId().eq(StpUtil.getLoginIdAsString()))
                .select(t.fetch(FETCHER))
                .execute();
    }
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/CodeAssistantAgent.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import io.github.qifan777.knowledge.ai.agent.AbstractAgent;
import io.github.qifan777.knowledge.ai.agent.Agent;
import io.github.qifan777.knowledge.code.analyze.AnalyzeFunction;
import io.github.qifan777.knowledge.code.arthas.ArthasFunction;
import lombok.AllArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.context.annotation.Description;

import java.util.function.Function;

/**
 * Agent for Java code questions: code review and online exception diagnosis.
 */
@Description("提供有关于Java代码的评审分析,在线诊断异常相关的回答")
@Agent
@AllArgsConstructor
public class CodeAssistantAgent extends AbstractAgent implements Function<CodeAssistantAgent.Request, String> {
    private final ChatModel chatModel;

    /**
     * Forwards the user's original question to a nested ChatClient call that may use the
     * {@link AnalyzeFunction} and {@link ArthasFunction} tools.
     *
     * @param request the user's original question
     * @return the model's answer
     */
    @Override
    public String apply(Request request) {
        return ChatClient.create(chatModel)
                .prompt()
                .user(request.query())
                .toolNames(getFunctions(AnalyzeFunction.class, ArthasFunction.class))
                .call()
                .content();
    }

    public record Request(
            @JsonProperty(required = true) @JsonPropertyDescription(value = "用户原始的提问") String query) {
    }
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/analyze/AnalyzeController.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.analyze;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

@RestController
@RequestMapping("analyze")
@AllArgsConstructor
public class AnalyzeController {
    private final AnalyzeFunction analyzeFunction;
    private final ObjectMapper objectMapper;

    /**
     * Streams the review results of one file as SSE events named {@code message}.
     *
     * @param path file path, passed unchanged to {@link AnalyzeFunction#analyze(String)}
     * @return SSE stream of JSON-serialized {@link AnalyzeFunction.AnalyzeResult}s
     */
    @GetMapping(produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> analyzeTask(@RequestParam String path) {
        return analyzeFunction.analyze(path)
                .map(content -> ServerSentEvent.builder(toJson(content))
                        .event("message")
                        .build());
    }

    @SneakyThrows
    public String toJson(AnalyzeFunction.AnalyzeResult result) {
        return objectMapper.writeValueAsString(result);
    }
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/analyze/AnalyzeFunction.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.analyze;

import cn.hutool.core.io.FileUtil;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
import io.github.qifan777.knowledge.code.graph.service.CodeGraphService;
import io.github.qifan777.knowledge.infrastructure.code.CodeAssistantProperties;
import io.github.qifan777.knowledge.infrastructure.code.JavaParserUtils;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;

import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.stream.Collectors;

@Description("评审分析给定的java文件")
@Service
@AllArgsConstructor
@Slf4j
public class AnalyzeFunction implements Function<AnalyzeFunction.Request, String> {
    private final CodeAssistantProperties properties;
    private final CodeGraphService codeGraphService;
    // Single thread on purpose (demo): repeated page calls queue up instead of running in parallel and burning tokens.
    private final Executor executor = Executors.newSingleThreadExecutor();
    private final ChatModel chatModel;

    public record Request(@JsonProperty(required = true)
                          @JsonPropertyDescription("java文件路径如:src/main/java/xxx.java") String path) {
    }

    /**
     * Tool entry point. A class may contain many methods and the review works per method, so the results are
     * shown on a dedicated page; this only returns the link to that page.
     *
     * @param request the function argument
     * @return text containing the result-page link
     */
    @Override
    public String apply(Request request) {
        // NOTE(review): the front-end host/port is hard-coded here.
        return "请在下面的网页链接查看评审结果:http://localhost:5177/#/analyze?path=" + request.path;
    }

    /**
     * Creates the review stream of one file. On the single worker thread the file is parsed, the methods of its
     * first (main) class are reviewed one after another and every partial result is pushed to the stream. If no
     * such class is found (e.g. the file is not a Java file, such as a Mapper.xml) the whole file is reviewed instead.
     * <p>
     * Any failure (unreadable file, method missing from the code graph, model error) now terminates the stream with
     * an error signal; before, the exception was lost on the worker thread and the stream never completed.
     * Remaining methods are skipped once the subscriber has cancelled.
     *
     * @param filePath path of the Java file, resolved against {@code <project>/src/main/java}
     *                 (e.g. io/github/qifan777/UserController.java)
     * @return stream of review results
     */
    @SneakyThrows
    public Flux<AnalyzeResult> analyze(String filePath) {
        log.info("正在评审文件:{}", filePath);
        return Flux.create(fluxSink -> executor.execute(() -> {
            try {
                findMainClassMethodIds(filePath).ifPresentOrElse(methodIds -> {
                    for (String methodId : methodIds) {
                        // stop spending tokens once nobody is listening any more
                        if (fluxSink.isCancelled()) {
                            return;
                        }
                        analyzeMethod(methodId).doOnNext(fluxSink::next).blockLast();
                    }
                }, () -> fluxSink.next(analyzeFile(filePath)));
                fluxSink.complete();
            } catch (Exception e) {
                log.error("评审文件失败: {}", filePath, e);
                fluxSink.error(e);
            }
        }));
    }

    /**
     * Parses the file and returns the ids ({@code fully.qualified.Class#method}) of the methods declared in its
     * first class/interface declaration. Overloads share one id, so ids are de-duplicated.
     *
     * @param filePath path relative to src/main/java
     * @return empty when the parser produced no compilation unit, the file declares no class, or the first
     * declaration has no fully qualified name
     */
    private Optional<List<String>> findMainClassMethodIds(String filePath) {
        return JavaParserUtils.parse(Path.of(properties.getProject().getProjectPath(), "src", "main", "java", filePath))
                .getResult()
                .map(compilationUnit -> compilationUnit.findAll(ClassOrInterfaceDeclaration.class))
                // only the main (first) class; a file may declare no class at all
                .flatMap(declarations -> declarations.stream().findFirst())
                // skip declarations without a fully qualified name
                .filter(declaration -> declaration.getFullyQualifiedName().isPresent())
                .map(declaration -> {
                    String fullyQualifiedName = declaration.getFullyQualifiedName().get();
                    return declaration.getMethods()
                            .stream()
                            .map(methodDeclaration -> fullyQualifiedName + "#" + methodDeclaration.getNameAsString())
                            // overloaded methods map to the same graph node: review each id once
                            .distinct()
                            .toList();
                });
    }

    /**
     * Reviews one method: the method and the methods it calls, as returned by
     * {@link CodeGraphService#findChildMethods(String)}, are sent to the model.
     *
     * @param methodId method id, e.g. io.github.xxx.UserService#login
     * @return streamed partial review results
     * @throws IllegalStateException if the method itself is not in the code graph (e.g. graph not built yet)
     */
    public Flux<AnalyzeResult> analyzeMethod(String methodId) {
        List<MethodNode> childMethods = codeGraphService.findChildMethods(methodId);
        // source of the reviewed method itself, attached to every result
        String content = childMethods.stream()
                .filter(m -> m.getId().equals(methodId))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("Method not found in code graph: " + methodId))
                .getContent();
        String prompt = new PromptTemplate("""
                请你根据{methodId}的调用链评审代码,并给出你的改进建议,并且附带修改后的代码片段,用中文回答。
                {methodChains}
                """)
                .createMessage(Map.of("methodChains", childMethods
                                .stream()
                                .map(MethodNode::getContent)
                                .distinct()
                                .collect(Collectors.joining("\n")),
                        "methodId", methodId))
                .getText();
        log.info("评审方法: {}", prompt);
        return chatModel.stream(prompt).map(response -> new AnalyzeResult(methodId, response, methodId.split("#")[0], content));
    }

    /**
     * Reviews a whole file with one blocking model call.
     * NOTE(review): {@code filePath} is resolved against the project root here, while {@link #analyze(String)}
     * parses it relative to src/main/java — confirm which form callers pass.
     *
     * @param filePath file path
     * @return the complete review result
     */
    public AnalyzeResult analyzeFile(String filePath) {
        String fileContent = FileUtil.readString(Path.of(properties.getProject().getProjectPath(), filePath).toFile(), StandardCharsets.UTF_8);
        String prompt = new PromptTemplate("""
                请你评审一下该提交文件是否有可以改进的地方,并且附带修改后的代码片段,没有请回答无,用中文回答。
                {content}
                """)
                .createMessage(Map.of("content", fileContent)).getText();
        log.info("评审文件提示词: {}", prompt);
        String result = chatModel.call(prompt);
        log.info("文件分析结果: {}", result);
        return new AnalyzeResult(filePath, result, filePath, fileContent);
    }

    @Data
    @AllArgsConstructor
    public static class AnalyzeResult {
        // method id, or the file path for whole-file reviews
        private String id;
        // review text (a streamed fragment for method reviews)
        private String content;
        // fully qualified class name (method reviews) or file path (whole-file reviews)
        private String fileName;
        // source of the reviewed method or file
        private String fileContent;
    }


}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/arthas/ArthasFunction.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.arthas;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.databind.JsonNode;
import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
import io.github.qifan777.knowledge.code.graph.service.CodeGraphService;
import io.github.qifan777.knowledge.infrastructure.code.CodeAssistantProperties;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.context.annotation.Description;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

@Slf4j
@Service
@AllArgsConstructor
@Description("诊断方法出现异常的原因")
public class ArthasFunction implements Function<ArthasFunction.Request, String> {
    private final CodeGraphService codeGraphService;
    private final CodeAssistantProperties properties;
    private final ChatModel chatModel;

    /**
     * Tool arguments. The class name must be fully qualified: it is passed to Arthas {@code tt -t} and is also
     * used to build the code-graph method id {@code className#methodName}. (The method name used to be described
     * to the model as a class name.)
     */
    public record Request(@JsonProperty(required = true)
                          @JsonPropertyDescription("全限定类名") String className,
                          @JsonProperty(required = true)
                          @JsonPropertyDescription("方法名") String methodName) {
    }

    /**
     * Diagnoses an exception of the given method: records one invocation with Arthas {@code tt}, then sends the
     * recorded method name, parameters and exception, together with the method's call chain from the code graph,
     * to the model.
     *
     * @param request class and method to watch
     * @return the model's diagnosis, or "无异常信息" when Arthas returned no usable record
     */
    @Override
    public String apply(Request request) {
        String methodId = request.className + "#" + request.methodName;
        log.info("监听目标:{}", methodId);
        JobResult jobResult = startWatch(request.className, request.methodName);
        // missing body/results used to throw a NullPointerException
        if (jobResult == null || jobResult.getBody() == null || jobResult.getBody().getResults() == null) {
            return "无异常信息";
        }
        String analyzeResult = jobResult
                .getBody()
                .getResults()
                .stream()
                .filter(r -> "tt".equals(r.getType()))
                // a tt result without a recorded time fragment has nothing to diagnose
                .filter(r -> r.getTimeFragmentList() != null && !r.getTimeFragmentList().isEmpty())
                .findFirst()
                .map(result -> diagnose(methodId, result.getTimeFragmentList().get(0)))
                .orElse("无异常信息");
        log.info("诊断结果: {}", analyzeResult);
        return analyzeResult;
    }

    /**
     * Builds the diagnosis prompt for one recorded invocation and asks the model.
     * Null fields of the record become empty strings (Map.of rejects null values).
     */
    private String diagnose(String methodId, TimeFragment timeFragment) {
        String methods = codeGraphService.findChildMethods(methodId).stream().map(MethodNode::getContent)
                .distinct()
                .collect(Collectors.joining("\n"));
        String params = timeFragment.getParams() == null ? "" : timeFragment.getParams().stream()
                .map(param -> param.getObject() == null ? "" : param.getObject().toString())
                .collect(Collectors.joining("\n"));
        String content = new PromptTemplate("""
                根据下面提供的内容分析异常原因,回答结果用中文
                方法名称:{methodName}
                方法调用链: {methods}
                方法参数:{params}
                异常信息:{exp}
                """)
                .createMessage(Map.of("methodName", Objects.toString(timeFragment.getMethodName(), ""),
                        "methods", methods,
                        "params", params,
                        "exp", Objects.toString(timeFragment.getThrowExp(), "")))
                .getText();
        log.info("代码诊断prompt: {}", content);
        return chatModel.call(content);
    }

    /**
     * Runs {@code tt -t <className> <method> -n 1} through the Arthas HTTP API (basic auth with the configured
     * credentials) and returns the parsed response body. The HTTP call is synchronous.
     *
     * @param className class name passed to tt
     * @param method    method name passed to tt
     * @return the parsed response, possibly {@code null}
     */
    public JobResult startWatch(String className, String method) {
        CodeAssistantProperties.ArthasProperties arthasProperties = properties.getArthas();
        RestTemplate restTemplate = new RestTemplate();
        String encode = Base64.getEncoder().encodeToString((arthasProperties.getUsername() + ":" + arthasProperties.getPassword()).getBytes(StandardCharsets.UTF_8));
        ArthasRequest arthasRequest = new ArthasRequest().setAction("exec").setCommand("tt -t " + className + " " + method + " -n 1");
        HttpHeaders httpHeaders = new HttpHeaders();
        httpHeaders.add("Authorization", "Basic " + encode);
        HttpEntity<ArthasRequest> requestHttpEntity = new HttpEntity<>(arthasRequest, httpHeaders);
        return restTemplate.exchange(arthasProperties.getUrl(), HttpMethod.POST, requestHttpEntity, JobResult.class).getBody();
    }

    // ---- Jackson models of the Arthas HTTP API request/response ----

    @Accessors(chain = true)
    @Data
    public static class ArthasRequest {
        private String action;
        private String command;
    }

    @Data
    public static class JobResult {
        @JsonProperty("body")
        private Body body;
        private String message;
        private String sessionId;
        private String state;
    }

    @Data
    public static class Body {
        @JsonProperty("command")
        private String command;
        private int jobId;
        private String jobStatus;
        @JsonProperty("results")
        private List<Result> results;
        @JsonProperty("timeExpired")
        private boolean timeExpired;
        private int timeout;
    }

    @Data
    public static class Result {
        private Effect effect;
        private int jobId;
        private boolean success;
        private String type;
        @JsonProperty("timeFragmentList")
        private List<TimeFragment> timeFragmentList;
        private int statusCode;
        private boolean first;
    }

    @Data
    public static class Effect {
        private int classCount;
        private int cost;
        private int listenerId;
        private int methodCount;
    }

    @Data
    public static class TimeFragment {
        private String className;
        private double cost;
        private int index;
        private String methodName;
        private String object;
        private List<Param> params;
        @JsonProperty("return")
        private boolean isReturn;
        private String returnObj;
        @JsonProperty("throw")
        private boolean isThrow;
        @JsonProperty("throwExp")
        private String throwExp;
        @JsonProperty("timestamp")
        private String timestamp;
    }

    @Data
    public static class Param {
        private int expand;
        private JsonNode object;
    }

}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/controller/CodeGraphController.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.graph.controller;

import io.github.qifan777.knowledge.code.graph.service.CodeGraphService;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("code/graph")
@AllArgsConstructor
@Slf4j
public class CodeGraphController {
    private final CodeGraphService codeGraphService;

    /**
     * Rebuilds the code graph via {@link CodeGraphService#buildGraph()}.
     *
     * @return "SUCCESS" once the build has finished
     */
    @PostMapping("build")
    public String buildGraph() {
        codeGraphService.buildGraph();
        return "SUCCESS";
    }
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/entity/ClassNode.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.graph.entity;

import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.data.neo4j.core.schema.Id;
import org.springframework.data.neo4j.core.schema.Node;
import org.springframework.data.neo4j.core.schema.Relationship;

import java.util.List;

/**
 * Neo4j node of a class or interface.
 */
@Node
@Data
@Accessors(chain = true)
public class ClassNode {
    // fully qualified class name
    @Id
    private String id;
    // simple class name
    private String name;
    // source code of the declaration
    private String content;
    // methods declared in this class
    @Relationship(direction = Relationship.Direction.OUTGOING, type = "OWNS")
    private List<MethodNode> ownsMethodNodes;
    // project classes imported by this class
    @Relationship(direction = Relationship.Direction.OUTGOING, type = "IMPORTS")
    private List<ClassNode> importNodes;
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/entity/MethodNode.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.graph.entity;

import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.data.neo4j.core.schema.Id;
import org.springframework.data.neo4j.core.schema.Node;
import org.springframework.data.neo4j.core.schema.Relationship;

import java.util.List;

/**
 * Neo4j node of a method.
 */
@Node
@Data
@Accessors(chain = true)
public class MethodNode {
    // "fully.qualified.Class#methodName"
    @Id
    private String id;
    private String name;
    private String comment;
    // method source, or the mapped SQL for MyBatis mapper methods
    private String content;
    // methods called by this method
    @Relationship(direction = Relationship.Direction.OUTGOING, type = "USES")
    private List<MethodNode> usesMethodNodes;
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/repository/ClassNodeRepository.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.graph.repository;

import io.github.qifan777.knowledge.code.graph.entity.ClassNode;
import org.springframework.data.neo4j.repository.Neo4jRepository;

public interface ClassNodeRepository extends Neo4jRepository<ClassNode, String> {
}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/repository/MethodNodeRepository.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.graph.repository;

import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
import org.springframework.data.neo4j.repository.Neo4jRepository;

public interface MethodNodeRepository extends Neo4jRepository<MethodNode, String> {

}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/service/CodeGraphBuilder.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.graph.service;

import com.github.javaparser.JavaParser;
import com.github.javaparser.ast.ImportDeclaration;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.comments.Comment;
import com.github.javaparser.ast.expr.AnnotationExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import io.github.qifan777.knowledge.code.graph.entity.ClassNode;
import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.w3c.dom.Document;
import org.w3c.dom.NodeList;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import java.io.File;
import java.io.StringWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.stream.Stream;

/**
 * Builds an in-memory code graph of a project: class nodes, method nodes, class-to-method OWNS relations and
 * method-to-method USES relations (class-to-class IMPORTS relations: see the note on buildClassNode).
 * The maps below are never cleared, so every build should use a fresh instance.
 */
@Slf4j
public class CodeGraphBuilder {
    // MyBatis XML statement tags whose SQL is extracted ("insert" used to be missing)
    private static final List<String> SQL_STATEMENT_TAGS = List.of("select", "insert", "update", "delete");
    // "namespace#statementId" -> SQL of that statement
    private final Map<String, String> mapperSqlMap = new HashMap<>();
    // fully qualified class name -> parsed declaration
    private final Map<String, ClassOrInterfaceDeclaration> classDeclarationMap = new HashMap<>();
    // fully qualified class name -> class node (also the recursion cache)
    private final Map<String, ClassNode> classNodeMap = new HashMap<>();
    // "fully.qualified.Class#method" -> method node (also the recursion cache)
    private final Map<String, MethodNode> methodNodeMap = new HashMap<>();
    private final Path projectPath;
    private final JavaParser javaParser;

    public CodeGraphBuilder(String projectPath, JavaParser javaParser) {
        this.projectPath = Path.of(projectPath);
        this.javaParser = javaParser;
    }

    /**
     * Result of a build: every class node and method node that was created.
     */
    public record BuildContext(Collection<ClassNode> classNodes, Collection<MethodNode> methodNodes) {
    }

    /**
     * Builds the graph of all classes under src/main/java.
     *
     * @return the created class and method nodes
     */
    @SneakyThrows
    public BuildContext buildGraph() {
        // MyBatis SQL first, so mapper methods can be given their SQL as content
        buildMapperSqlMap();
        try (Stream<Path> pathStream = Files.walk(getJavaSourcePath())) {
            pathStream.filter(path -> path.toFile().isFile())
                    // only Java sources can be parsed; skip other files living next to them
                    .filter(path -> path.toString().endsWith(".java"))
                    .flatMap(path -> getClassDeclarations(path).stream())
                    .filter(declaration -> declaration.getFullyQualifiedName().isPresent())
                    .forEach(declaration -> classDeclarationMap.put(declaration.getFullyQualifiedName().get(), declaration));
        }
        classDeclarationMap.values().forEach(this::buildClassNode);
        classNodeMap.values().forEach(classNode -> {
            ClassOrInterfaceDeclaration classDeclaration = classDeclarationMap.get(classNode.getId());
            List<MethodDeclaration> methodDeclarations = classDeclaration.findAll(MethodDeclaration.class);
            List<MethodNode> ownsMethodNodes = methodDeclarations.stream()
                    .map(MethodDeclaration::getNameAsString)
                    // overloads share one node; avoid duplicate OWNS entries
                    .distinct()
                    .map(methodName -> buildMethodNode(methodName, classNode.getId(), methodDeclarations))
                    .flatMap(Optional::stream)
                    .toList();
            classNode.setOwnsMethodNodes(ownsMethodNodes);
        });
        return new BuildContext(classNodeMap.values(), methodNodeMap.values());
    }

    /**
     * Returns every class/interface declaration of a Java file (a file may contain several).
     *
     * @param path Java file
     * @return declarations, empty when the parser produced no compilation unit
     */
    @SneakyThrows
    private List<ClassOrInterfaceDeclaration> getClassDeclarations(Path path) {
        return javaParser.parse(path.toFile())
                .getResult()
                .map(compilationUnit -> compilationUnit.findAll(ClassOrInterfaceDeclaration.class))
                .orElse(List.of());
    }

    /**
     * Builds (or returns the cached) class node of a declaration, recursing into imported project classes.
     * <p>
     * The cache is keyed by fully qualified name, the same key the node is stored under (the lookup used the
     * simple name before and never hit), and the node is cached before its imports are visited so cycles end.
     * <p>
     * NOTE(review): {@code declaration.findAll(ImportDeclaration.class)} searches the class's own subtree, but import
     * declarations are children of the CompilationUnit, so this list is always empty and no IMPORTS relations are
     * produced. Using {@code declaration.findCompilationUnit()} would fix that but makes the ClassNode graph cyclic;
     * confirm that persisting cyclic Lombok {@code @Data} nodes is safe before switching.
     *
     * @param declaration class or interface declaration
     * @return the class node, empty if the declaration has no fully qualified name
     */
    @SneakyThrows
    private Optional<ClassNode> buildClassNode(ClassOrInterfaceDeclaration declaration) {
        return declaration
                .getFullyQualifiedName()
                .map(qualifiedClassName -> {
                    ClassNode cached = classNodeMap.get(qualifiedClassName);
                    if (cached != null) {
                        return cached;
                    }
                    ClassNode classNode = new ClassNode().setId(qualifiedClassName)
                            .setName(declaration.getNameAsString())
                            .setContent(declaration.toString());
                    // cache before recursing
                    classNodeMap.put(qualifiedClassName, classNode);
                    List<ClassNode> importClassNodes = declaration
                            .findAll(ImportDeclaration.class)
                            .stream()
                            .map(importDeclaration -> Optional.ofNullable(classDeclarationMap.get(importDeclaration.getNameAsString()))
                                    .flatMap(this::buildClassNode))
                            .flatMap(Optional::stream)
                            .toList();
                    classNode.setImportNodes(importClassNodes);
                    return classNode;
                });
    }

    /**
     * Builds (or returns the cached) method node and, recursively, the nodes of the methods it calls.
     *
     * @param methodName   method name, e.g. login
     * @param className    fully qualified class name, e.g. io.qifan.xxx.UserService
     * @param declarations all method declarations of that class
     * @return the method node, empty if the class declares no method with that name
     */
    @SneakyThrows
    private Optional<MethodNode> buildMethodNode(String methodName, String className, List<MethodDeclaration> declarations) {
        String methodId = className + "#" + methodName;
        MethodNode cached = methodNodeMap.get(methodId);
        if (cached != null) {
            return Optional.of(cached);
        }
        return declarations.stream()
                .filter(methodDeclaration -> methodDeclaration.getNameAsString().equals(methodName))
                .findFirst()
                .map(methodDeclaration -> {
                    // Mapper methods have an XML statement keyed "namespace#id" == methodId: use its SQL as content.
                    // (The previous @Mapper check compared the AST node's meta-model class name and never matched.)
                    String content = Optional.ofNullable(mapperSqlMap.get(methodId))
                            .orElseGet(methodDeclaration::toString);
                    MethodNode methodNode = new MethodNode()
                            .setId(methodId)
                            .setName(methodDeclaration.getNameAsString())
                            .setContent(content)
                            .setComment(methodDeclaration.getComment().map(Comment::getContent).orElse(""));
                    // cache before recursing so recursive calls terminate
                    methodNodeMap.put(methodNode.getId(), methodNode);
                    List<MethodNode> usesMethodNodes = methodDeclaration
                            .findAll(MethodCallExpr.class)
                            .stream()
                            .map(this::buildMethodNodeFromMethodCall)
                            .flatMap(Optional::stream)
                            .toList();
                    methodNode.setUsesMethodNodes(usesMethodNodes);
                    return methodNode;
                });
    }

    /**
     * For a call such as {@code userService.login()}, resolves the type of the scope ({@code userService}) and
     * builds the node of the called method. Calls without a scope, or whose scope type cannot be resolved to a
     * reference type, are skipped.
     *
     * @param methodCallExpr method call expression
     * @return the called method's node, if it belongs to a project class
     */
    public Optional<MethodNode> buildMethodNodeFromMethodCall(MethodCallExpr methodCallExpr) {
        return methodCallExpr
                .getScope()
                // resolve the scope type once (symbol solving is expensive)
                .flatMap(this::resolveQualifiedTypeName)
                .flatMap(qualifiedName -> buildMethodNode(methodCallExpr.getNameAsString(), qualifiedName, findMethodDeclarations(qualifiedName)));
    }

    /**
     * For {@code userService.login()} with scope {@code userService}, returns all methods of the scope's class.
     *
     * @param scope object the method is called on; its type must be resolvable
     * @return method declarations, empty if the class is not part of the project
     */
    public List<MethodDeclaration> getMethodDeclarationsFromScope(Expression scope) {
        return findMethodDeclarations(scope.calculateResolvedType().asReferenceType().getQualifiedName());
    }

    // All method declarations of a project class, or an empty list for unknown classes.
    private List<MethodDeclaration> findMethodDeclarations(String qualifiedClassName) {
        return Optional.ofNullable(classDeclarationMap.get(qualifiedClassName))
                .map(declaration -> declaration.findAll(MethodDeclaration.class))
                .orElse(List.of());
    }

    // Qualified name of the expression's reference type, or empty when it cannot be resolved.
    private Optional<String> resolveQualifiedTypeName(Expression expression) {
        try {
            return Optional.of(expression.calculateResolvedType().asReferenceType().getQualifiedName());
        } catch (Exception e) {
            return Optional.empty();
        }
    }

    /**
     * Reads the MyBatis mapper XML files under src/main/resources/mapper and maps "namespace#id" to the SQL of
     * every select/insert/update/delete statement. A broken file is logged and skipped instead of silently
     * aborting the whole scan.
     */
    private void buildMapperSqlMap() {
        Path mapperDirectory = getFileInResource("mapper");
        if (!Files.isDirectory(mapperDirectory)) {
            log.warn("不存在mapper");
            return;
        }
        try (Stream<Path> mapper = Files.walk(mapperDirectory)) {
            mapper.filter(path -> path.toString().endsWith(".xml"))
                    .forEach(this::extractSqlFromMapperFile);
        } catch (Exception e) {
            log.warn("读取mapper目录失败: {}", mapperDirectory, e);
        }
    }

    // Extracts the SQL statements of one mapper file; failures are logged per file.
    private void extractSqlFromMapperFile(Path file) {
        try {
            Document document = parseXMLFileAsDocument(file.toFile());
            String namespace = document.getDocumentElement().getAttribute("namespace");
            for (String tag : SQL_STATEMENT_TAGS) {
                extractSqlFromStatement(document.getDocumentElement().getElementsByTagName(tag), namespace);
            }
        } catch (Exception e) {
            log.warn("解析mapper文件失败: {}", file, e);
        }
    }

    /**
     * Parses a mapper XML file. Mapper files declare a remote DTD; it is not downloaded (so the build works
     * offline) and external entities are not resolved.
     *
     * @param file mapper file
     * @return DOM document
     */
    @SneakyThrows
    private Document parseXMLFileAsDocument(File file) {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        factory.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
        factory.setFeature("http://xml.org/sax/features/external-general-entities", false);
        factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
        factory.setExpandEntityReferences(false);
        DocumentBuilder builder = factory.newDocumentBuilder();
        return builder.parse(file);
    }

    /**
     * Serializes each statement element and stores it under "namespace#id". Elements without an id are skipped.
     *
     * @param nodeList  statement elements (select/insert/update/delete)
     * @param namespace mapper namespace
     */
    @SneakyThrows
    private void extractSqlFromStatement(NodeList nodeList, String namespace) {
        TransformerFactory tf = TransformerFactory.newInstance();
        Transformer transformer = tf.newTransformer();
        transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
        for (int i = 0; i < nodeList.getLength(); i++) {
            var idAttribute = nodeList.item(i).getAttributes().getNamedItem("id");
            if (idAttribute == null) {
                // cannot be matched to a mapper method
                continue;
            }
            StringWriter writer = new StringWriter();
            transformer.transform(new DOMSource(nodeList.item(i)), new StreamResult(writer));
            // put the statement on one line; a space keeps tokens on adjacent lines from being glued together
            String output = writer.toString().replaceAll("\\s*[\\r\\n]+\\s*", " ");
            mapperSqlMap.put(namespace + "#" + idAttribute.getNodeValue(), output);
        }
    }


    private Path getJavaSourcePath() {
        return projectPath.resolve(Path.of("src", "main", "java"));
    }

    private Path getFileInResource(String fileName) {
        return projectPath.resolve(Path.of("src", "main", "resources", fileName));
    }

}
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/code/graph/service/CodeGraphService.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.code.graph.service;

import io.github.qifan777.knowledge.code.graph.entity.MethodNode;
import io.github.qifan777.knowledge.code.graph.repository.ClassNodeRepository;
import io.github.qifan777.knowledge.code.graph.repository.MethodNodeRepository;
import io.github.qifan777.knowledge.infrastructure.code.CodeAssistantProperties;
import io.github.qifan777.knowledge.infrastructure.code.JavaParserUtils;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;

/**
 * Builds the code graph (class and method nodes) of the configured project, stores it in Neo4j,
 * and queries method call chains from it.
 */
@Slf4j
@Service
@AllArgsConstructor
public class CodeGraphService {
    private final ClassNodeRepository classNodeRepository;
    private final MethodNodeRepository methodNodeRepository;
    private final Neo4jClient neo4jClient;
    private final CodeAssistantProperties properties;

    /**
     * Rebuilds the whole graph.
     * Deletes all method and class nodes, parses the project at properties.project.projectPath,
     * and saves the new nodes.
     * NOTE(review): delete and save are separate repository calls with no visible transaction, so a
     * failure midway presumably leaves a partially rebuilt graph — confirm.
     */
    @SneakyThrows
    public void buildGraph() {
        methodNodeRepository.deleteAll();
        classNodeRepository.deleteAll();
        CodeGraphBuilder.BuildContext buildContext = new CodeGraphBuilder(properties.getProject().getProjectPath(), JavaParserUtils.getJavaParser()).buildGraph();
        classNodeRepository.saveAll(buildContext.classNodes());
        log.info("类节点保存完毕: {}", classNodeRepository.count());
        methodNodeRepository.saveAll(buildContext.methodNodes());
        log.info("方法节点保存完毕: {}", methodNodeRepository.count());
    }

    /**
     * Returns the method with the given id plus every method reachable from it through 0 to 3 USES hops.
     * Each method appears at most once.
     *
     * @param methodId method node id, e.g. io.qifan.xxx.UserService#login
     * @return the method and its callees (id and content only)
     */
    public List<MethodNode> findChildMethods(String methodId) {
        // Overlapping paths share nodes, hence "distinct".
        String cypher = """
                match window=(m)-[:USES*0..3]->(:MethodNode)
                where m.id = $methodId
                with nodes(window) as nodeList
                unwind nodeList as nodeRows
                return distinct nodeRows;
                """;
        return findMethods(cypher, methodId);
    }

    /**
     * Runs a Cypher query that binds {@code $methodId} and returns method nodes in its first column.
     * Each row is mapped to a MethodNode carrying only id and content.
     *
     * @param cypher   the query
     * @param methodId value bound to {@code $methodId}
     * @return the mapped method nodes
     */
    public ArrayList<MethodNode> findMethods(String cypher, String methodId) {
        return new ArrayList<>(neo4jClient.query(cypher)
                .bind(methodId).to("methodId")
                .fetchAs(MethodNode.class)
                .mappedBy((typeSystem, record) -> {
                    var node = record.get(0);
                    var content = node.get("content");
                    var id = node.get("id");
                    MethodNode methodNode = new MethodNode();
                    // asString() yields the raw text.
                    // String.valueOf(value) would yield a quoted, escaped Cypher literal instead.
                    methodNode.setContent(content.isNull() ? null : content.asString());
                    methodNode.setId(id.isNull() ? null : id.asString());
                    return methodNode;
                })
                .all());
    }
}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/demo/DocumentAnalyzerFunction.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.demo;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import lombok.Data;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.context.annotation.Description;
import org.springframework.core.io.FileSystemResource;
import org.springframework.stereotype.Service;

import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Tool function that reads a local document with Tika and returns its text.
 * The @Description text lets the model choose among several registered functions.
 */
@Description("文档解析函数")
@Service
@Slf4j
public class DocumentAnalyzerFunction implements Function<DocumentAnalyzerFunction.Request, DocumentAnalyzerFunction.Response> {
    /**
     * Tool input.
     * {@code @JsonProperty} declares the property name and whether it is required.
     * {@code @JsonPropertyDescription} tells the model what value to extract for it.
     */
    @Data
    public static class Request {
        @JsonProperty(required = true, value = "path")
        @JsonPropertyDescription(value = "需要解析的本地文件路径")
        String path;
    }

    /** Tool output: the extracted document text. */
    public record Response(String result) {
    }

    /**
     * Reads the file at {@code request.path} and returns its text.
     * The model then uses that text as context for its answer.
     * NOTE(review): the path is chosen by the model, i.e. indirectly by the user. Any file readable by
     * the server process can be exposed; restrict it to an allowed directory before using this beyond a demo.
     */
    @SneakyThrows
    @Override
    public Response apply(Request request) {
        TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new FileSystemResource(request.path));
        // Join every returned document rather than taking get(0), which throws on an empty result.
        String text = tikaDocumentReader.read().stream()
                .map(Document::getText)
                .collect(Collectors.joining("\n"));
        return new Response(text);
    }


}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/demo/DocumentDemoController.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.demo;

import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.InputStreamResource;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.util.List;

/**
 * Demo endpoints for embedding documents, searching the vector store,
 * and letting the model read a local document through a tool.
 */
@RequestMapping("demo/document")
@RestController
@AllArgsConstructor
@Slf4j
public class DocumentDemoController {
    private final VectorStore vectorStore;
    private final ChatModel chatModel;

    /**
     * Embeds an uploaded file.
     * The file is read with Tika, split into token-sized chunks, and the chunks are added to the
     * vector store. The store invokes the embedding model itself.
     *
     * @param file the file to embed
     * @return true once the chunks have been added
     */
    @SneakyThrows
    @PostMapping("embedding")
    public Boolean embedding(@RequestParam MultipartFile file) {
        TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(file.getInputStream()));
        List<Document> splitDocuments = new TokenTextSplitter()
                .apply(tikaDocumentReader.read());
        vectorStore.add(splitDocuments);
        return true;
    }

    /**
     * Similarity search over the vector store.
     *
     * @param query the user's question
     * @return the matching documents
     */
    @GetMapping("query")
    public List<Document> query(@RequestParam String query) {
        return vectorStore.similaritySearch(query);
    }

    /**
     * Answers {@code query} with the documentAnalyzerFunction tool available,
     * so the model can read a local file mentioned in the question.
     */
    @GetMapping("chat")
    public String chat(@RequestParam String query) {
        return ChatClient.create(chatModel)
                .prompt(query)
                .toolNames("documentAnalyzerFunction")
                .call()
                .content();
    }
}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/demo/MessageDemoController.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.demo;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.qifan777.knowledge.ai.message.dto.AiMessageParams;
import io.github.qifan777.knowledge.ai.message.dto.AiMessageWrapper;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
/**
 * Demo endpoints for the main ChatClient features: blocking and streaming chat, tool calling,
 * answering from the vector store (RAG), and answering with conversation memory.
 */
@RequestMapping("demo/message")
@RestController
@AllArgsConstructor
public class MessageDemoController {

    private final ChatModel chatModel;


    private final ObjectMapper objectMapper;
    private final VectorStore vectorStore;
    // In-memory stand-in for a session/message store; keeps a window of at most 10 messages.
    private final MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
            .maxMessages(10)
            .build();

    /**
     * Blocking (non-streaming) chat.
     *
     * @param prompt the user's question
     * @return the answer text
     */
    @GetMapping("chat")
    public String chat(@RequestParam String prompt) {
        ChatClient chatClient = ChatClient.create(chatModel);
        return chatClient.prompt()
                // a single user prompt
                .user(prompt)
                // call() is the blocking variant. The result can be read as a ChatResponse, mapped to an
                // entity, or (as here) taken as the plain answer text.
                .call()
                .content();
    }

    /**
     * Streaming chat with a fixed Java-assistant system prompt.
     *
     * @param prompt the user's question
     * @return SSE stream whose event data is each ChatResponse chunk serialized to JSON
     */
    @GetMapping(value = "chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> chatStream(@RequestParam String prompt) {
        return ChatClient.create(chatModel).prompt()
                // several messages can be passed, e.g. to include earlier history
                .messages(new SystemMessage("你是一个Java智能助手,应用你的Java知识帮助用户解决问题或者编写程序"),
                        new UserMessage(prompt))
                // streaming response
                .stream()
                // wrap every chunk as a Server-Sent Event
                .chatResponse().map(chatResponse -> ServerSentEvent.builder(toJson(chatResponse))
                        .event("message")
                        .build());
    }

    /**
     * Serializes a (streamed) chat response to a JSON string.
     *
     * @param chatResponse the response chunk
     * @return its JSON representation
     */
    @SneakyThrows
    public String toJson(ChatResponse chatResponse) {
        return objectMapper.writeValueAsString(chatResponse);
    }

    /**
     * Streaming chat with a custom tool function available to the model.
     *
     * @param prompt       the user's question
     * @param functionName tool name: the bean name of the function (class name, lower-camel-cased)
     * @return SSE stream of serialized ChatResponse chunks
     */
    @GetMapping(value = "chat/stream/function", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> chatStreamWithFunction(@RequestParam String prompt, @RequestParam String functionName) {
        return ChatClient.create(chatModel).prompt()
                .messages(new UserMessage(prompt))
                // Spring AI looks the function up among the registered beans and adds it to the request;
                // the model may then call it.
                .toolNames(functionName)
                .stream()
                .chatResponse()
                .map(chatResponse -> ServerSentEvent.builder(toJson(chatResponse))
                        .event("message")
                        .build());
    }

    /**
     * Streaming chat that answers from documents retrieved from the vector store.
     *
     * @param prompt the user's question
     * @return SSE stream of answer text fragments
     */
    @GetMapping(value = "chat/stream/database", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> chatStreamWithDatabase(@RequestParam String prompt) {
        // {question_answer_context} is a placeholder that QuestionAnswerAdvisor replaces with the
        // documents found in the vector store.
        String promptWithContext = """
                {query}
                下面是上下文信息
                ---------------------
                {question_answer_context}
                ---------------------
                给定的上下文和提供的历史信息,而不是事先的知识,回复用户的意见。如果答案不在上下文中,告诉用户你不能回答这个问题。
                """;
        return ChatClient.create(chatModel).prompt()
                .user(prompt)
                .advisors(QuestionAnswerAdvisor.builder(vectorStore)
                        .promptTemplate(new PromptTemplate(promptWithContext)).build())
                .stream()
                .content()
                .map(chatResponse -> ServerSentEvent.builder(chatResponse)
                        .event("message")
                        .build());
    }

    /**
     * Streaming chat that includes the recent history of conversation {@code sessionId}.
     *
     * @param prompt    the user's question
     * @param sessionId conversation id used as the chat-memory key
     * @return SSE stream of answer text fragments
     */
    @GetMapping(value = "chat/stream/history", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> chatStreamWithHistory(@RequestParam String prompt,
                                                               @RequestParam String sessionId) {
        // chatMemory is the in-memory MessageWindowChatMemory field (window of 10 messages).
        // To persist sessions and messages in a database, provide your own ChatMemory implementation.
        MessageChatMemoryAdvisor messageChatMemoryAdvisor = MessageChatMemoryAdvisor
                .builder(chatMemory)
                .conversationId(sessionId)
                .build();
        return ChatClient.create(chatModel).prompt()
                .user(prompt)
                // Before the request is sent, the advisor loads this conversation's history from
                // chatMemory and sends it to the model together with the prompt.
                .advisors(messageChatMemoryAdvisor)
                .stream()
                .content()
                .map(chatResponse -> ServerSentEvent.builder(chatResponse)
                        .event("message")
                        .build());
    }

    /**
     * No-op endpoint.
     * NOTE(review): it appears to exist only so that AiMessageWrapper and AiMessageParams are
     * referenced by an endpoint (e.g. for API/type generation). Confirm before removing.
     */
    @PostMapping("ignore")
    public void ignore(@RequestParam AiMessageWrapper wrapper, @RequestParam AiMessageParams params) {
    }

}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/GraphController.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.graph;

import io.github.qifan777.knowledge.graph.chunk.ChunkController;
import io.qifan.infrastructure.common.exception.BusinessException;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.List;
import java.util.Map;

/**
 * Graph-RAG endpoints over the form 10-K chunk graph stored in Neo4j.
 * NOTE(review): the value of @RestController is the bean name, not a URL prefix. These endpoints are
 * therefore served at /chunk/rag and /manager/rag. Kept as is so existing clients keep working; add
 * @RequestMapping("graph") if a /graph prefix was intended.
 */
@RestController("graph")
@AllArgsConstructor
@Slf4j
public class GraphController {
    private final EmbeddingModel embeddingModel;
    private final ChatModel chatModel;
    private final Neo4jClient neo4jClient;
    private final PromptTemplate promptTemplate = new PromptTemplate("""
            Context information is below.
            ---------------------
            {question_answer_context}
            ---------------------
            Given the context and provided history information and not prior knowledge,
            reply to the user comment. If the answer is not in the context, inform
            the user that you can't answer the question.
            """);

    /**
     * Answers {@code query} from the chunk most similar to it, extended by its direct
     * predecessor and successor along NEXT.
     */
    @GetMapping(value = "chunk/rag")
    public String chunkRag(@RequestParam String query) {
        List<Double> embed = ChunkController.floatsToDoubles(embeddingModel.embed(query));
        // The variable-length pattern matches up to four windows: [node], [prev,node], [node,next]
        // and [prev,node,next]. Keep only the longest one so no chunk text is repeated in the context.
        String result = neo4jClient.query("""
                        CALL db.index.vector.queryNodes('form_10k_chunks', 1, $embedding)
                        yield node, score
                        match window=(:Chunk)-[:NEXT*0..1]->(node)-[:NEXT*0..1]->(:Chunk)
                        with node, score, window
                        order by length(window) desc
                        limit 1
                        with nodes(window) as chunkList, node, score
                        unwind chunkList as chunkRows
                        with collect(chunkRows.text) as textList, node, score
                        return apoc.text.join(textList, " \\n ")
                        """)
                .bind(embed).to("embedding")
                .fetchAs(String.class).first()
                .orElseThrow(() -> new BusinessException("未找到相似文档"));
        String content = promptTemplate.createMessage(Map.of("question_answer_context", result)).getText();
        return chatModel.call(new UserMessage(content + "\n" + query));
    }

    /**
     * Answers {@code query} from the chunk most similar to it, plus the five managers holding the
     * most shares in the company that filed that chunk's form.
     */
    @GetMapping(value = "manager/rag")
    public String managerRag(@RequestParam String query) {
        List<Double> embed = ChunkController.floatsToDoubles(embeddingModel.embed(query));
        var result = neo4jClient.query("""
                        CALL db.index.vector.queryNodes('form_10k_chunks', 1, $embedding)
                        YIELD node, score
                        MATCH (node)-[:PART_OF]->(f:Form),
                            (f)<-[:FILED]-(com:Company),
                            (com)<-[owns:OWNS_STOCK_IN]-(mgr:Manager)
                        WITH node, score, mgr, owns, com
                            ORDER BY owns.shares DESC LIMIT 5
                        WITH collect (
                            mgr.name +
                            " owns " + owns.shares + " of " + com.name +
                            " at a value of $" + apoc.number.format(owns.value) + "."
                        ) AS investment_statements, node, score
                        RETURN "investors: \\n" + apoc.text.join(investment_statements, "\\n") +\s
                            "\\n" + node.text AS text
                        """)
                .bind(embed).to("embedding")
                .fetchAs(String.class)
                .first()
                .orElseThrow(() -> new BusinessException("未找到相似文档"));
        String content = promptTemplate.createMessage(Map.of("question_answer_context", result)).getText();
        log.info("context result: {}", content);
        return chatModel.call(new UserMessage(content + "\n" + query));
    }
}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/chunk/Chunk.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.graph.chunk;

import lombok.Builder;
import lombok.Data;
import org.springframework.data.neo4j.core.schema.Id;
import org.springframework.data.neo4j.core.schema.Node;

import java.util.List;

/**
 * A slice of one form 10-K section, stored as a Neo4j node.
 */
@Builder
@Data
@Node
public class Chunk {
    @Id
    private String id;
    // text of this slice produced by the splitter
    private String text;
    // source section: item1, item1a, item7 or item7a
    private String item;
    // position of the chunk within its section, starting at 0
    private Integer chunkSeqId;
    // id of the Form this chunk belongs to
    private String formId;
    // embedding vector of text
    private List<Double> textEmbedding;
}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/chunk/ChunkController.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.graph.chunk;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;

import java.io.File;
import java.util.ArrayList;
import java.util.List;

/**
 * Builds Chunk nodes from the form 10-K JSON files, links them, and embeds them.
 * NOTE(review): @RestController("chunk") sets the bean name, not a URL prefix, so the endpoints are
 * /node, /link and /embedding at the root.
 */
@RestController("chunk")
@AllArgsConstructor
@Slf4j
public class ChunkController {
    // Directory of the sample form 10-K JSON files. NOTE(review): machine-specific absolute path.
    private static final String FORM10K_DIR = "F:\\workspace\\code\\learn\\sec-edgar-notebooks\\data\\sample\\form10k";
    // The four text sections of each form 10-K that are chunked and linked.
    private static final List<String> ITEMS = List.of("item1", "item1a", "item7", "item7a");
    // At most this many chunks are kept per section.
    private static final int MAX_CHUNKS_PER_ITEM = 20;

    private final ChunkRepository chunkRepository;
    private final Neo4jClient neo4jClient;
    private final EmbeddingModel embeddingModel;

    /**
     * Creates the Chunk nodes for every JSON file in the form 10-K directory.
     *
     * @throws IllegalStateException when the directory does not exist or cannot be listed
     */
    @PostMapping("node")
    public void createNodes() {
        File[] files = new File(FORM10K_DIR).listFiles();
        if (files == null) {
            throw new IllegalStateException("form10k directory not found: " + FORM10K_DIR);
        }
        for (File file : files) {
            if (!file.getName().contains(".json")) continue;
            chunkRepository.saveAll(fileToChunkList(file));
        }

    }

    /**
     * Splits the item1/item1a/item7/item7a sections of one form 10-K file into Chunks.
     * Sections that are missing or blank are skipped.
     *
     * @param file form 10-K JSON file; its name without ".json" is the form id
     * @return the chunks, at most MAX_CHUNKS_PER_ITEM per section
     */
    @SneakyThrows
    public List<Chunk> fileToChunkList(File file) {
        ObjectNode node = new ObjectMapper().readValue(file, ObjectNode.class);
        String formId = file.getName().replace(".json", "");
        TokenTextSplitter splitter = new TokenTextSplitter();
        List<Chunk> chunks = new ArrayList<>();
        for (String item : ITEMS) {
            // path() yields a missing node ("" text) for an absent section, where get() would return null.
            String text = node.path(item).asText();
            if (text.isBlank()) {
                continue;
            }
            List<Document> documents = splitter.split(new Document(text));
            int chunkCount = Math.min(documents.size(), MAX_CHUNKS_PER_ITEM);
            for (int chunkSeqId = 0; chunkSeqId < chunkCount; chunkSeqId++) {
                chunks.add(Chunk.builder()
                        .id("%s-%s-chunk%04d".formatted(formId, item, chunkSeqId))
                        .chunkSeqId(chunkSeqId)
                        .formId(formId)
                        .text(documents.get(chunkSeqId).getText())
                        .item(item)
                        .build());
            }
        }
        return chunks;
    }

    /**
     * For every form and section, links the chunks in chunkSeqId order with NEXT relationships.
     */
    @PostMapping("link")
    public void createLink() {
        var formIds = neo4jClient
                .query("match (c:Chunk) return distinct c.formId as formId")
                .fetchAs(String.class)
                .all();
        formIds.forEach(formId -> ITEMS.forEach(item -> neo4jClient.query("""
                        MATCH (c:Chunk) // 匹配所有的节点
                        WHERE c.formId = $formId // 属于同一个form和同一个item的节点
                        AND c.item = $item
                        WITH c
                        ORDER BY c.chunkSeqId ASC // 根据seqId排序一下节点
                        WITH collect(c) as section_chunk_list // 转成list
                        CALL apoc.nodes.link(section_chunk_list, "NEXT", {avoidDuplicates: true}) // 节点之间依按顺序创建连接
                        RETURN size(section_chunk_list)
                        """)
                .bind(formId).to("formId")
                .bind(item).to("item")
                .run()));
    }

    /**
     * Creates the cosine vector index on Chunk.textEmbedding (if absent) and embeds every chunk that
     * has no embedding yet. Neo4j can only run similarity queries once the vector index exists.
     */
    @PostMapping("embedding")
    public void createEmbedding() {
        // Embed a sample text to learn the model's vector dimension.
        int dimension = embeddingModel.embed("你好").length;

        neo4jClient.query("""
                        CREATE VECTOR INDEX `form_10k_chunks` IF NOT EXISTS
                          FOR (c:Chunk) ON (c.textEmbedding)
                          OPTIONS { indexConfig: {
                            `vector.dimensions`: $dimensions,
                            `vector.similarity_function`: 'cosine'
                          }}
                        """)
                .bind(dimension).to("dimensions")
                .run();
        List<Chunk> waitToEmbedList = chunkRepository.findByTextEmbeddingIsNull();
        waitToEmbedList.forEach(chunk -> chunk.setTextEmbedding(floatsToDoubles(embeddingModel.embed(chunk.getText()))));
        chunkRepository.saveAll(waitToEmbedList);
    }

    /**
     * Widens an embedding vector to a list of doubles (the type stored on Chunk.textEmbedding).
     */
    public static List<Double> floatsToDoubles(float[] floats) {
        List<Double> result = new ArrayList<>(floats.length);
        for (float f : floats) {
            result.add((double) f);
        }
        return result;
    }

}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/chunk/ChunkRepository.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.graph.chunk;

import org.springframework.data.neo4j.repository.Neo4jRepository;

import java.util.List;

public interface ChunkRepository extends Neo4jRepository<Chunk, String> {
    // chunks that have not been embedded yet
    List<Chunk> findByTextEmbeddingIsNull();
}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/company/Company.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.graph.company;

import lombok.Builder;
import lombok.Data;
import org.springframework.data.neo4j.core.schema.Id;
import org.springframework.data.neo4j.core.schema.Node;

import java.util.List;

/**
 * A company that filed form 10-K reports.
 */
@Builder
@Data
@Node
public class Company {
    // node id; Form nodes are linked to the company through the same cusip6 value
    @Id
    private String cusip6;
    private List<String> cusips;
    private List<String> names;
    // display name: the first entry of names
    private String name;
}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/company/CompanyController.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.graph.company;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.qifan777.knowledge.graph.model.Form10K;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.io.File;

/**
 * Creates Company nodes from the form 10-K files and links them to their forms.
 */
@RestController
@RequestMapping("company")
@Slf4j
@AllArgsConstructor
public class CompanyController {
    // Directory of the sample form 10-K JSON files. NOTE(review): machine-specific absolute path.
    private static final String FORM10K_DIR = "F:\\workspace\\code\\learn\\sec-edgar-notebooks\\data\\sample\\form10k";

    private final CompanyRepository companyRepository;
    private final Neo4jClient neo4jClient;

    /**
     * Creates or updates one Company node per form 10-K JSON file.
     *
     * @throws IllegalStateException when the directory does not exist or cannot be listed
     */
    @SneakyThrows
    @PostMapping("nodes")
    public void createNodes() {
        File[] files = new File(FORM10K_DIR).listFiles();
        if (files == null) {
            throw new IllegalStateException("form10k directory not found: " + FORM10K_DIR);
        }
        ObjectMapper objectMapper = new ObjectMapper();
        for (File file : files) {
            if (!file.getName().contains(".json")) continue;
            var form10K = objectMapper.readValue(file, Form10K.class);
            var names = form10K.getNames();
            var company = Company.builder().cusip6(form10K.getCusip6())
                    .cusips(form10K.getCusip())
                    .names(names)
                    // A filing without names no longer aborts the whole import.
                    .name(names == null || names.isEmpty() ? null : names.get(0))
                    .build();
            companyRepository.save(company);
        }

    }

    /**
     * Links every company to the forms with the same cusip6 via FILED.
     */
    @PostMapping("relationship/filed")
    public void createFiledRelationship() {
        neo4jClient.query("""
                        MATCH (com:Company), (form:Form)
                          WHERE com.cusip6 = form.cusip6
                        MERGE (com)-[:FILED]->(form)
                        """)
                .run();
    }
}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/company/CompanyRepository.java:
--------------------------------------------------------------------------------
package io.github.qifan777.knowledge.graph.company;

import org.springframework.data.neo4j.repository.Neo4jRepository;

public interface CompanyRepository extends Neo4jRepository<Company, String> {
}

--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/form/Form.java:
-------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.graph.form; 2 | 3 | import lombok.Builder; 4 | import lombok.Data; 5 | import org.springframework.data.neo4j.core.schema.Id; 6 | import org.springframework.data.neo4j.core.schema.Node; 7 | 8 | @Builder 9 | @Data 10 | @Node 11 | public class Form { 12 | @Id 13 | private String id; 14 | private String cusip6; 15 | private String source; 16 | private String fullText; 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/graph/form/FormController.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.graph.form; 2 | 3 | import com.fasterxml.jackson.databind.ObjectMapper; 4 | import io.github.qifan777.knowledge.graph.model.Form10K; 5 | import lombok.AllArgsConstructor; 6 | import lombok.SneakyThrows; 7 | import lombok.extern.slf4j.Slf4j; 8 | import org.springframework.data.neo4j.core.Neo4jClient; 9 | import org.springframework.web.bind.annotation.PostMapping; 10 | import org.springframework.web.bind.annotation.RequestMapping; 11 | import org.springframework.web.bind.annotation.RestController; 12 | 13 | import java.io.File; 14 | 15 | @RestController 16 | @RequestMapping("form") 17 | @AllArgsConstructor 18 | @Slf4j 19 | public class FormController { 20 | private final FormRepository formRepository; 21 | private final Neo4jClient neo4jClient; 22 | 23 | @SneakyThrows 24 | @PostMapping("node") 25 | public void createNodes() { 26 | var fileDir = new File("F:\\workspace\\code\\learn\\sec-edgar-notebooks\\data\\sample\\form10k"); 27 | File[] files = fileDir.listFiles(); 28 | for (File file : files) { 29 | if (!file.getName().contains(".json")) continue; 30 | var form10K = new ObjectMapper().readValue(file, Form10K.class); 31 | var fullText = "About " + 32 | String.join(",", form10K.getNames()) + 33 | 
"..." + 34 | form10K.getItem1() + 35 | "\n" + 36 | form10K.getItem1a() + 37 | "\n" + 38 | form10K.getItem7() + 39 | "\n" + 40 | form10K.getItem7a(); 41 | var formId = file.getName().replace(".json", ""); 42 | var form = Form.builder().id(formId) 43 | .fullText(fullText) 44 | .cusip6(form10K.getCusip6()) 45 | .source(form10K.getSource()) 46 | .build(); 47 | formRepository.save(form); 48 | } 49 | } 50 | 51 | @PostMapping("relationship/section") 52 | public void createSectionRelationship() { 53 | neo4jClient.query(""" 54 | match (c:Chunk),(f:Form) where c.chunkSeqId=0 and f.id = c.formId 55 | merge (f)-[r:SECTION {item:c.item}] -> (c) 56 | return count(r) 57 | """) 58 | .run(); 59 | } 60 | 61 | @PostMapping("relationship/part-of") 62 | public void createPartOfRelationship() { 63 | neo4jClient.query(""" 64 | match (c:Chunk), (f:Form) where c.formId=f.id 65 | merge (c)-[r:PART_OF]->(f) 66 | return count(r); 67 | """) 68 | .run(); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/graph/form/FormRepository.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.graph.form; 2 | 3 | import org.springframework.data.neo4j.repository.Neo4jRepository; 4 | 5 | public interface FormRepository extends Neo4jRepository { 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/graph/manager/Manager.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.graph.manager; 2 | 3 | import lombok.Data; 4 | import lombok.experimental.Accessors; 5 | import org.springframework.data.neo4j.core.schema.Id; 6 | import org.springframework.data.neo4j.core.schema.Node; 7 | 8 | @Accessors(chain = true) 9 | @Data 10 | @Node 11 | public class Manager { 12 | @Id 13 | private String cik; 14 | 
private String name; 15 | private String address; 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/graph/manager/ManagerController.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.graph.manager; 2 | 3 | import com.alibaba.excel.EasyExcel; 4 | import com.alibaba.excel.event.SyncReadListener; 5 | import com.fasterxml.jackson.core.type.TypeReference; 6 | import com.fasterxml.jackson.databind.ObjectMapper; 7 | import io.github.qifan777.knowledge.graph.model.Form13; 8 | import io.qifan.infrastructure.common.exception.BusinessException; 9 | import lombok.AllArgsConstructor; 10 | import lombok.SneakyThrows; 11 | import lombok.extern.slf4j.Slf4j; 12 | import org.springframework.data.neo4j.core.Neo4jClient; 13 | import org.springframework.web.bind.annotation.PostMapping; 14 | import org.springframework.web.bind.annotation.RequestMapping; 15 | import org.springframework.web.bind.annotation.RestController; 16 | 17 | import java.io.File; 18 | import java.util.List; 19 | import java.util.Map; 20 | 21 | @RestController 22 | @RequestMapping("manager") 23 | @AllArgsConstructor 24 | @Slf4j 25 | public class ManagerController { 26 | 27 | 28 | private final Neo4jClient neo4jClient; 29 | private final ObjectMapper jacksonObjectMapper; 30 | private final ManagerRepository managerRepository; 31 | 32 | public List readForm13List() { 33 | SyncReadListener syncReadListener = new SyncReadListener(); 34 | EasyExcel.read(new File("F:\\workspace\\code\\learn\\sec-edgar-notebooks\\data\\sample\\form13.csv"), Form13.class, syncReadListener) 35 | .sheet() 36 | .doRead(); 37 | return syncReadListener.getList() 38 | .stream() 39 | .map(o -> (Form13) o) 40 | .toList(); 41 | } 42 | 43 | @PostMapping("nodes") 44 | public void createNodes() { 45 | List form13List = readForm13List(); 46 | // 投资方可能投资了多个证券,所以会有重复的投资方记录,去重一下 47 | List cikList 
= form13List.stream().map(Form13::getManagerCik).distinct().toList(); 48 | // 每个cik代表一个投资方,映射成Manager对象 49 | List managerList = cikList.stream().map(cik -> { 50 | Form13 manager = form13List.stream() 51 | .filter(form13 -> form13.getManagerCik().equals(cik)) 52 | .findFirst() 53 | .orElseThrow(() -> new BusinessException("投资公司不存在")); 54 | return new Manager() 55 | .setCik(manager.getManagerCik()) 56 | .setName(manager.getManagerName()) 57 | .setAddress(manager.getManagerAddress()); 58 | }).toList(); 59 | managerRepository.saveAll(managerList); 60 | } 61 | 62 | @PostMapping("relationship/stock-in") 63 | public void createStockInRelationship() { 64 | List form13List = readForm13List(); 65 | form13List.forEach(form13 -> { 66 | neo4jClient.query(""" 67 | match (m:Manager {cik: $managerCik}), (com:Company {cusip6: $cusip6}) 68 | merge (m)-[owns:OWNS_STOCK_IN {reportCalendarOrQuarter: $reportCalendarOrQuarter}]->(com) 69 | on create set 70 | owns.value = $value, 71 | owns.shares = $shares 72 | """) 73 | .bindAll(toMap(form13)) 74 | .run(); 75 | }); 76 | 77 | } 78 | 79 | @SneakyThrows 80 | public Map toMap(Form13 form13) { 81 | return jacksonObjectMapper.readValue(jacksonObjectMapper.writeValueAsString(form13), new TypeReference<>() { 82 | }); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/graph/manager/ManagerRepository.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.graph.manager; 2 | 3 | import org.springframework.data.neo4j.repository.Neo4jRepository; 4 | 5 | public interface ManagerRepository extends Neo4jRepository { 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/graph/model/Form10K.java: -------------------------------------------------------------------------------- 1 | package 
io.github.qifan777.knowledge.graph.model;
2 |
3 | import lombok.Data;
4 |
5 | import java.util.List;
6 |
7 | /**
8 |  * Form 10-K: the comprehensive annual report on financial performance and operations that the U.S. Securities and Exchange Commission (SEC) requires every listed company to file;
9 |  * it covers company history, organizational structure, financial condition, earnings per share, subsidiaries, executive compensation and similar information.
10 |  */
11 | @Data
12 | public class Form10K { // filled by Jackson (ObjectMapper.readValue in FormController), so field names must match the JSON keys
13 |     /**
14 |      * Item 1: Business.
15 |      */
16 |     private String item1;
17 |     /**
18 |      * Item 1A: Risk factors.
19 |      */
20 |     private String item1a;
21 |     /**
22 |      * Item 7: Management's discussion and analysis of financial condition and results of operations.
23 |      */
24 |     private String item7;
25 |     /**
26 |      * Item 7A: Quantitative and qualitative disclosures about market risk.
27 |      */
28 |     private String item7a;
29 |     /**
30 |      * Central Index Key (CIK): used by the SEC's computer systems to identify companies and individuals that have filed disclosures with the SEC.
31 |      */
32 |     private String cik;
33 |
34 |     /**
35 |      * CUSIP codes. CUSIP was created to give every security in North America a unique code, so that similar names cannot cause mistakes during clearing.
36 |      * Note that the code is per security, and "security" covers stocks, options, futures, government bonds, corporate bonds and every other kind of security.
37 |      */
38 |     private List cusip;
39 |     /**
40 |      * The first six characters of a CUSIP identify the issuing company.
41 |      */
42 |     private String cusip6;
43 |     /**
44 |      * Company names (including aliases, hence a list).
45 |      */
46 |     private List names;
47 |     /**
48 |      * Source of this Form 10-K report (the original filing).
49 |      */
50 |     private String source;
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/graph/model/Form13.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.graph.model;
2 |
3 | import lombok.Data;
4 |
5 | /**
6 |  * One Form 13 row: which company an investment manager invested in, the number of shares held and the value of the investment.
7 |  */
8 | @Data
9 | public class Form13 { // NOTE(review): no @ExcelProperty annotations, so EasyExcel presumably maps CSV columns by field order; ManagerController also uses these field names as Cypher parameters — confirm before renaming or reordering fields
10 |     /**
11 |      * Link to the original Form 13 filing.
12 |      */
13 |     private String source;
14 |     /**
15 |      * CIK of the investment manager; see the CIK note in Form10K.
16 |      */
17 |     private String managerCik;
18 |     /**
19 |      * Name of the investment manager.
20 |      */
21 |     private String managerName;
22 |     /**
23 |      * Address of the investment manager.
24 |      */
25 |     private String managerAddress;
26 |     /**
27 |      * Date the Form 13 report was published.
28 |      */
29 |     private String reportCalendarOrQuarter;
30 |     /**
31 |      * See the CUSIP6 note in Form10K.
32 |      */
33 |     private String cusip6;
34 |     /**
35 |      * See the CUSIP note in Form10K.
36 |      */
37 |     private String cusip;
38 |     /**
39 |      * Name of the company invested in.
40 |      */
41 |     private String companyName;
42 |     /**
43 |      * Value (amount) of the investment.
44 |      */
45 |     private Double value;
46 |     /**
47 |      * Number of shares held.
48 |      */
49 |     private Double shares;
50 | }
51 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/infrastructure/code/CodeAssistantProperties.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.infrastructure.code;
2 |
3 | import lombok.Data;
4 | import org.springframework.boot.context.properties.ConfigurationProperties;
5 | import org.springframework.boot.context.properties.NestedConfigurationProperty;
6 |
7 | @ConfigurationProperties(prefix = "code-assistant") // bound from the code-assistant.* keys (see application.yml)
8 | @Data
9 | public class CodeAssistantProperties {
10 |     @NestedConfigurationProperty
11 |     private ProjectProperties project; // code-assistant.project.*
12 |     @NestedConfigurationProperty
13 |     private ArthasProperties arthas; // code-assistant.arthas.*
14 |
15 |     @Data
16 |     public static class ProjectProperties {
17 |         private String projectPath; // root of the analyzed project; JavaParserUtils resolves src/main/java under it
18 |     }
19 |
20 |     @Data
21 |     public static class ArthasProperties { // connection settings (URL and credentials) for an Arthas endpoint
22 |         private String url;
23 |         private String password;
24 |         private String username;
25 |     }
26 | }
27 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/infrastructure/code/JavaParserUtils.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.infrastructure.code;
2 |
3 | import com.github.javaparser.JavaParser;
4 | import com.github.javaparser.ParseResult;
5 | import com.github.javaparser.ParserConfiguration;
6 | import com.github.javaparser.ast.CompilationUnit;
7 | import com.github.javaparser.symbolsolver.JavaSymbolSolver;
8 | import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver;
9 | import com.github.javaparser.symbolsolver.resolution.typesolvers.JavaParserTypeSolver;
10 | import lombok.SneakyThrows;
11 | import lombok.extern.slf4j.Slf4j;
12 | import
org.springframework.beans.BeansException; 13 | import org.springframework.context.ApplicationContext; 14 | import org.springframework.context.ApplicationContextAware; 15 | import org.springframework.stereotype.Component; 16 | 17 | import java.nio.file.Path; 18 | 19 | @Component 20 | @Slf4j 21 | public class JavaParserUtils implements ApplicationContextAware { 22 | private static CodeAssistantProperties properties; 23 | 24 | public static JavaParser getJavaParser() { 25 | CombinedTypeSolver combinedTypeSolver = new CombinedTypeSolver(); 26 | combinedTypeSolver.add(new JavaParserTypeSolver(Path.of(properties.getProject().getProjectPath()).resolve(Path.of("src", "main", "java")))); 27 | ParserConfiguration parserConfiguration = new ParserConfiguration().setSymbolResolver(new JavaSymbolSolver(combinedTypeSolver)); 28 | return new JavaParser(parserConfiguration); 29 | } 30 | 31 | @SneakyThrows 32 | public static ParseResult parse(Path path) { 33 | return getJavaParser().parse(path); 34 | } 35 | 36 | @Override 37 | public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { 38 | properties = applicationContext.getBean(CodeAssistantProperties.class); 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/config/GlobalExceptionAdvice.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.config; 2 | 3 | import cn.dev33.satoken.exception.DisableServiceException; 4 | import cn.dev33.satoken.exception.NotLoginException; 5 | import cn.dev33.satoken.exception.NotRoleException; 6 | import io.qifan.infrastructure.common.constants.ResultCode; 7 | import io.qifan.infrastructure.common.exception.BusinessException; 8 | import io.qifan.infrastructure.common.exception.SystemException; 9 | import io.qifan.infrastructure.common.model.R; 10 | import 
jakarta.validation.ConstraintViolation; 11 | import jakarta.validation.ConstraintViolationException; 12 | import lombok.extern.slf4j.Slf4j; 13 | import org.springframework.http.HttpStatus; 14 | import org.springframework.http.ResponseEntity; 15 | import org.springframework.web.bind.MethodArgumentNotValidException; 16 | import org.springframework.web.bind.annotation.ExceptionHandler; 17 | import org.springframework.web.bind.annotation.RestControllerAdvice; 18 | 19 | import java.util.ArrayList; 20 | 21 | @RestControllerAdvice 22 | @Slf4j 23 | public class GlobalExceptionAdvice { 24 | 25 | @ExceptionHandler(BusinessException.class) 26 | public ResponseEntity> handleBusinessException(BusinessException e) { 27 | log.error("业务异常", e); 28 | return ResponseEntity.status(HttpStatus.BAD_REQUEST) 29 | .body(R.fail(e.getResultCode(), e.getMessage())); 30 | } 31 | 32 | @ExceptionHandler(SystemException.class) 33 | public ResponseEntity> handleSystemException(SystemException e) { 34 | log.error("系统异常", e); 35 | return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(R.fail(ResultCode.SystemError)); 36 | } 37 | 38 | @ExceptionHandler(Exception.class) 39 | public ResponseEntity> handleException(Exception e) { 40 | log.error("系统异常", e); 41 | return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(R.fail(ResultCode.SystemError)); 42 | } 43 | 44 | @ExceptionHandler(ConstraintViolationException.class) 45 | public ResponseEntity> handleValidateException(ConstraintViolationException e) { 46 | log.warn("校验异常", e); 47 | // 不合格的字段,可能有多个,只需要返回其中一个提示用户就行 48 | // 比如密码为空 49 | ArrayList> constraintViolations = new ArrayList<>( 50 | e.getConstraintViolations()); 51 | return ResponseEntity.status(HttpStatus.BAD_REQUEST) 52 | .body(R.fail(ResultCode.ValidateError, 53 | constraintViolations.get(0).getMessage())); 54 | 55 | } 56 | 57 | @ExceptionHandler(MethodArgumentNotValidException.class) 58 | public ResponseEntity> handleValidateExceptionForSpring( 59 | MethodArgumentNotValidException e) { 
60 | log.warn("校验异常", e); 61 | return ResponseEntity.status(HttpStatus.BAD_REQUEST) 62 | .body(R.fail(ResultCode.ValidateError, 63 | e.getBindingResult().getAllErrors().get(0) 64 | .getDefaultMessage())); 65 | } 66 | 67 | @ExceptionHandler(NotLoginException.class) 68 | public ResponseEntity> handleNotLogin(NotLoginException e) { 69 | log.error("未登录", e); 70 | return ResponseEntity.status(HttpStatus.BAD_REQUEST) 71 | .body(R.fail(ResultCode.Unauthorized)); 72 | } 73 | 74 | @ExceptionHandler(NotRoleException.class) 75 | public ResponseEntity> handleNotRole(NotRoleException e) { 76 | log.error("角色校验异常", e); 77 | return ResponseEntity.status(HttpStatus.BAD_REQUEST) 78 | .body(R.fail(ResultCode.NotGrant, e.getMessage())); 79 | } 80 | 81 | @ExceptionHandler(DisableServiceException.class) 82 | public ResponseEntity> handleDisabledException(DisableServiceException e) { 83 | log.error("账号封禁", e); 84 | return ResponseEntity.status(HttpStatus.BAD_REQUEST) 85 | .body(R.fail(ResultCode.StatusHasInvalid, "账号已被封禁")); 86 | } 87 | 88 | 89 | } 90 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/config/LocalDateTimeConvert.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.config; 2 | 3 | import com.fasterxml.jackson.core.JsonGenerator; 4 | import com.fasterxml.jackson.core.JsonParser; 5 | import com.fasterxml.jackson.databind.DeserializationContext; 6 | import com.fasterxml.jackson.databind.JsonDeserializer; 7 | import com.fasterxml.jackson.databind.JsonSerializer; 8 | import com.fasterxml.jackson.databind.SerializerProvider; 9 | import lombok.extern.slf4j.Slf4j; 10 | import org.springframework.boot.jackson.JsonComponent; 11 | 12 | import java.io.IOException; 13 | import java.time.LocalDateTime; 14 | import java.time.format.DateTimeFormatter; 15 | 16 | @JsonComponent 17 | @Slf4j 18 | public class 
LocalDateTimeConvert { 19 | 20 | public static class Serializer extends JsonSerializer { 21 | 22 | 23 | @Override 24 | public void serialize(LocalDateTime localDateTime, JsonGenerator jsonGenerator, 25 | SerializerProvider serializerProvider) throws IOException { 26 | DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern( 27 | "yyyy-MM-dd HH:mm:ss"); 28 | String format = dateTimeFormatter.format(localDateTime); 29 | jsonGenerator.writeString(format); 30 | } 31 | } 32 | 33 | public static class Deserializer extends JsonDeserializer { 34 | 35 | 36 | @Override 37 | public LocalDateTime deserialize(JsonParser jsonParser, 38 | DeserializationContext deserializationContext) throws IOException { 39 | String text = jsonParser.getText(); 40 | DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern( 41 | "yyyy-MM-dd HH:mm:ss"); 42 | return LocalDateTime.parse(text, dateTimeFormatter); 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/config/PageableConvert.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.config; 2 | 3 | import com.fasterxml.jackson.core.JsonGenerator; 4 | import com.fasterxml.jackson.core.JsonParser; 5 | import com.fasterxml.jackson.databind.DeserializationContext; 6 | import com.fasterxml.jackson.databind.JsonDeserializer; 7 | import com.fasterxml.jackson.databind.JsonSerializer; 8 | import com.fasterxml.jackson.databind.SerializerProvider; 9 | import io.qifan.infrastructure.common.model.PageResult; 10 | import org.springframework.boot.jackson.JsonComponent; 11 | import org.springframework.data.domain.Page; 12 | 13 | import java.io.IOException; 14 | import java.util.List; 15 | 16 | @JsonComponent 17 | public class PageableConvert { 18 | 19 | public static class Serializer extends JsonSerializer> { 20 | 21 | @Override 22 | public 
void serialize(Page page, JsonGenerator jsonGenerator, 23 | SerializerProvider serializerProvider) throws IOException { 24 | PageResult pageResult = new PageResult<>() 25 | .setNumber(page.getNumber()) 26 | .setSize(page.getSize()) 27 | .setTotalElements(page.getTotalElements()) 28 | .setTotalPages(page.getTotalPages()) 29 | .setContent((List) page.getContent()); 30 | jsonGenerator.writeObject(pageResult); 31 | } 32 | } 33 | 34 | public static class Deserializer extends JsonDeserializer> { 35 | 36 | 37 | @Override 38 | public Page deserialize(JsonParser jsonParser, 39 | DeserializationContext deserializationContext) throws IOException { 40 | return jsonParser.readValueAs(Page.class); 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/config/ResponseInterceptor.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.config; 2 | 3 | import com.fasterxml.jackson.databind.ObjectMapper; 4 | import io.qifan.infrastructure.common.model.R; 5 | import lombok.AllArgsConstructor; 6 | import lombok.SneakyThrows; 7 | import lombok.extern.slf4j.Slf4j; 8 | import org.springframework.core.MethodParameter; 9 | import org.springframework.http.MediaType; 10 | import org.springframework.http.server.ServerHttpRequest; 11 | import org.springframework.http.server.ServerHttpResponse; 12 | import org.springframework.web.bind.annotation.RestControllerAdvice; 13 | import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyAdvice; 14 | 15 | @RestControllerAdvice(basePackages = "io.github.qifan777.knowledge") 16 | @Slf4j 17 | @AllArgsConstructor 18 | public class ResponseInterceptor implements ResponseBodyAdvice { 19 | 20 | private final ObjectMapper objectMapper; 21 | 22 | @Override 23 | public boolean supports(MethodParameter returnType, Class converterType) { 24 | return true; 
25 | } 26 | 27 | @SneakyThrows 28 | @Override 29 | public Object beforeBodyWrite(Object body, MethodParameter returnType, 30 | MediaType selectedContentType, Class selectedConverterType, 31 | ServerHttpRequest request, ServerHttpResponse response) { 32 | if (body instanceof byte[]) { 33 | return body; 34 | } 35 | if (body instanceof R) { 36 | return body; 37 | } 38 | if (body instanceof String) { 39 | response.getHeaders().setContentType(MediaType.APPLICATION_JSON); 40 | return objectMapper.writeValueAsString(R.ok(body)); 41 | } 42 | if (body == null) { 43 | return null; 44 | } 45 | log.debug("响应结果:{}", body); 46 | return R.ok(body); 47 | 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/config/WebConfig.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.config; 2 | 3 | import cn.dev33.satoken.interceptor.SaInterceptor; 4 | import org.springframework.boot.autoconfigure.security.SecurityProperties; 5 | import org.springframework.boot.context.properties.EnableConfigurationProperties; 6 | import org.springframework.context.annotation.Bean; 7 | import org.springframework.context.annotation.Configuration; 8 | import org.springframework.context.annotation.Primary; 9 | import org.springframework.core.task.AsyncTaskExecutor; 10 | import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; 11 | import org.springframework.web.servlet.config.annotation.AsyncSupportConfigurer; 12 | import org.springframework.web.servlet.config.annotation.InterceptorRegistry; 13 | import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; 14 | 15 | @Configuration 16 | @EnableConfigurationProperties(SecurityProperties.class) 17 | public class WebConfig implements WebMvcConfigurer { 18 | @Primary 19 | @Bean(name = "asyncTaskExecutor") 20 | public AsyncTaskExecutor 
asyncTaskExecutor() { 21 | ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); 22 | executor.setCorePoolSize(5); // 根据需求调整 23 | executor.setMaxPoolSize(10); 24 | executor.setQueueCapacity(25); 25 | executor.setThreadNamePrefix("MyApp-Async-"); 26 | executor.initialize(); 27 | return executor; 28 | } 29 | 30 | @Override 31 | public void configureAsyncSupport(AsyncSupportConfigurer configurer) { 32 | configurer.setTaskExecutor(asyncTaskExecutor()); 33 | } 34 | 35 | @Override 36 | public void addInterceptors(InterceptorRegistry registry) { 37 | registry.addInterceptor(new SaInterceptor()).addPathPatterns("/**"); 38 | } 39 | } -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/jimmer/BaseDateTime.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.jimmer; 2 | 3 | import org.babyfish.jimmer.sql.MappedSuperclass; 4 | 5 | import java.time.LocalDateTime; 6 | 7 | @MappedSuperclass 8 | public interface BaseDateTime { 9 | 10 | LocalDateTime createdTime(); 11 | 12 | LocalDateTime editedTime(); 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/jimmer/BaseDateTimeDraftInterceptor.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.jimmer; 2 | 3 | import lombok.AllArgsConstructor; 4 | import org.babyfish.jimmer.sql.DraftInterceptor; 5 | import org.jetbrains.annotations.NotNull; 6 | import org.springframework.stereotype.Component; 7 | 8 | import java.time.LocalDateTime; 9 | 10 | @Component 11 | @AllArgsConstructor 12 | public class BaseDateTimeDraftInterceptor implements 13 | DraftInterceptor { 14 | 15 | 16 | @Override 17 | public void beforeSave(@NotNull BaseDateTimeDraft draft, BaseDateTime 
baseDateTime) { 18 | draft.setEditedTime(LocalDateTime.now()); 19 | if (baseDateTime == null) { 20 | draft.setCreatedTime(LocalDateTime.now()); 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/jimmer/BaseEntity.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.jimmer; 2 | 3 | 4 | import io.github.qifan777.knowledge.user.User; 5 | import org.babyfish.jimmer.sql.*; 6 | 7 | @MappedSuperclass 8 | public interface BaseEntity extends BaseDateTime { 9 | 10 | @Id 11 | @GeneratedValue(generatorType = UUIDIdGenerator.class) 12 | String id(); 13 | 14 | @ManyToOne 15 | @OnDissociate(DissociateAction.DELETE) 16 | User editor(); 17 | 18 | @ManyToOne 19 | @OnDissociate(DissociateAction.DELETE) 20 | User creator(); 21 | } -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/jimmer/BaseEntityDraftInterceptor.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.jimmer; 2 | 3 | import cn.dev33.satoken.stp.StpUtil; 4 | import lombok.AllArgsConstructor; 5 | import org.babyfish.jimmer.sql.DraftInterceptor; 6 | import org.jetbrains.annotations.NotNull; 7 | import org.springframework.stereotype.Component; 8 | 9 | @Component 10 | @AllArgsConstructor 11 | public class BaseEntityDraftInterceptor implements DraftInterceptor { 12 | 13 | 14 | @Override 15 | public void beforeSave(@NotNull BaseEntityDraft draft, BaseEntity baseEntity) { 16 | draft.applyEditor(user -> { 17 | user.setId(StpUtil.getLoginIdAsString()); 18 | }); 19 | if (baseEntity == null) { 20 | draft.applyCreator(user -> { 21 | user.setId(StpUtil.getLoginIdAsString()); 22 | }); 23 | } 24 | } 25 | } 
-------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/jimmer/JimmerConfig.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.jimmer; 2 | 3 | import org.babyfish.jimmer.sql.meta.DatabaseNamingStrategy; 4 | import org.babyfish.jimmer.sql.runtime.DefaultDatabaseNamingStrategy; 5 | import org.springframework.context.annotation.Bean; 6 | import org.springframework.context.annotation.Configuration; 7 | 8 | @Configuration 9 | public class JimmerConfig { 10 | 11 | @Bean 12 | public DatabaseNamingStrategy databaseNamingStrategy() { 13 | return DefaultDatabaseNamingStrategy.LOWER_CASE; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/jimmer/UUIDIdGenerator.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.jimmer; 2 | 3 | import cn.hutool.core.util.IdUtil; 4 | import org.babyfish.jimmer.sql.meta.UserIdGenerator; 5 | 6 | public class UUIDIdGenerator implements UserIdGenerator { 7 | 8 | public UUIDIdGenerator() { 9 | } 10 | 11 | public String generate(Class entityType) { 12 | return IdUtil.fastSimpleUUID(); 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/infrastructure/model/QueryRequest.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.infrastructure.model; 2 | 3 | 4 | import lombok.AllArgsConstructor; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | import lombok.extern.slf4j.Slf4j; 8 | import org.babyfish.jimmer.sql.ast.LikeMode; 9 | import org.jetbrains.annotations.Nullable; 10 | import 
org.springframework.data.domain.PageRequest; 11 | import org.springframework.data.domain.Pageable; 12 | import org.springframework.data.domain.Sort; 13 | import org.springframework.util.CollectionUtils; 14 | 15 | import java.util.List; 16 | import java.util.stream.Collectors; 17 | 18 | @Data 19 | @Slf4j 20 | public class QueryRequest { 21 | 22 | private T query; 23 | 24 | private Integer pageSize = 10; 25 | 26 | private Integer pageNum = 1; 27 | @Nullable 28 | private LikeMode likeMode; 29 | @Nullable 30 | private List sorts; 31 | 32 | public Pageable toPageable() { 33 | if (!CollectionUtils.isEmpty(sorts)) { 34 | return PageRequest.of(getPageNum() - 1, getPageSize(), 35 | Sort.by(sorts.stream() 36 | .map(QuerySort::toOrder) 37 | .collect(Collectors.toList()))); 38 | } 39 | return PageRequest.of(getPageNum() - 1, getPageSize(), 40 | Sort.by(Sort.Order.desc("createdTime"))); 41 | } 42 | 43 | public Pageable toPageable(Sort.Order... orders) { 44 | return PageRequest.of(this.getPageNum() - 1, 45 | this.getPageSize(), 46 | Sort.by(orders)); 47 | } 48 | 49 | @Data 50 | @AllArgsConstructor 51 | @NoArgsConstructor 52 | public static class QuerySort { 53 | 54 | String property; 55 | Sort.Direction direction; 56 | 57 | public Sort.Order toOrder() { 58 | return new Sort.Order(direction, property); 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/oss/OSSController.java: -------------------------------------------------------------------------------- 1 | package io.github.qifan777.knowledge.oss; 2 | 3 | import io.qifan.infrastructure.oss.service.OSSService; 4 | import lombok.AllArgsConstructor; 5 | import org.springframework.web.bind.annotation.PostMapping; 6 | import org.springframework.web.bind.annotation.RequestMapping; 7 | import org.springframework.web.bind.annotation.RequestParam; 8 | import org.springframework.web.bind.annotation.RestController; 9 | import 
org.springframework.web.multipart.MultipartFile;
10 |
11 | @RestController
12 | @RequestMapping("oss")
13 | @AllArgsConstructor
14 | public class OSSController { // file-upload endpoint delegating to OSSService (provider presumably chosen by the oss.* keys in application.yml)
15 |     private final OSSService ossService;
16 |
17 |     @PostMapping("upload") // POST /oss/upload with a multipart field named "file"
18 |     public String upload(@RequestParam MultipartFile file) {
19 |         return ossService.upload(file); // returns whatever OSSService.upload returns (presumably the stored file's URL — confirm)
20 |     }
21 | }
22 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/user/User.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.user;
2 |
3 | import io.github.qifan777.knowledge.infrastructure.jimmer.BaseDateTime;
4 | import io.github.qifan777.knowledge.infrastructure.jimmer.UUIDIdGenerator;
5 | import jakarta.validation.constraints.Null;
6 | import org.babyfish.jimmer.sql.Entity;
7 | import org.babyfish.jimmer.sql.GeneratedValue;
8 | import org.babyfish.jimmer.sql.Id;
9 | import org.babyfish.jimmer.sql.Key;
10 |
11 |
12 | @Entity
13 | public interface User extends BaseDateTime { // Jimmer entity for an application account
14 |     @Id
15 |     @GeneratedValue(generatorType = UUIDIdGenerator.class)
16 |     String id();
17 |
18 |     /**
19 |      * Phone number; the entity's @Key and the login name (looked up by UserRepository#findByPhone).
20 |      */
21 |     @Key
22 |     String phone();
23 |
24 |     /**
25 |      * Password; UserController stores a BCrypt hash here, not the plain text.
26 |      */
27 |     String password();
28 |
29 |     /**
30 |      * Avatar.
31 |      */
32 |     @Null // NOTE(review): to Bean Validation, jakarta @Null means "must be null"; here it presumably marks the property as nullable for Jimmer — confirm no validator applies it
33 |     String avatar();
34 |
35 |     /**
36 |      * Nickname.
37 |      */
38 |     @Null
39 |     String nickname();
40 |
41 |     /**
42 |      * Gender.
43 |      */
44 |     @Null
45 |     String gender();
46 | }
47 |
48 |
--------------------------------------------------------------------------------
/src/main/java/io/github/qifan777/knowledge/user/UserController.java:
--------------------------------------------------------------------------------
1 | package io.github.qifan777.knowledge.user;
2 |
3 | import cn.dev33.satoken.secure.BCrypt;
4 | import cn.dev33.satoken.stp.SaTokenInfo;
5 | import cn.dev33.satoken.stp.StpUtil;
6 | import io.github.qifan777.knowledge.user.dto.UserLoginInput;
7 | import
io.github.qifan777.knowledge.user.dto.UserRegisterInput; 8 | import io.qifan.infrastructure.common.exception.BusinessException; 9 | import lombok.AllArgsConstructor; 10 | import org.babyfish.jimmer.client.FetchBy; 11 | import org.babyfish.jimmer.sql.EnableDtoGeneration; 12 | import org.springframework.web.bind.annotation.*; 13 | 14 | import java.util.Optional; 15 | 16 | 17 | @RequestMapping("user") 18 | @RestController 19 | @AllArgsConstructor 20 | public class UserController { 21 | private final UserRepository userRepository; 22 | 23 | @GetMapping 24 | public @FetchBy(value = "FETCHER", ownerType = UserRepository.class) User userInfo() { 25 | return userRepository.findById(StpUtil.getLoginIdAsString(), UserRepository.FETCHER) 26 | .orElseThrow(() -> new BusinessException("用户信息不存在")); 27 | } 28 | 29 | @PostMapping("login") 30 | public SaTokenInfo login(@RequestBody UserLoginInput input) { 31 | User databaseUser = userRepository.findByPhone(input.getPhone()) 32 | .orElseThrow(() -> new BusinessException("用户名/密码错误")); 33 | if (!BCrypt.checkpw(input.getPassword(), databaseUser.password())) { 34 | throw new BusinessException("用户名/密码错误"); 35 | } 36 | StpUtil.login(databaseUser.id()); 37 | return StpUtil.getTokenInfo(); 38 | } 39 | 40 | @PostMapping("register") 41 | public SaTokenInfo register(@RequestBody UserRegisterInput input) { 42 | Optional byPhone = userRepository.findByPhone(input.getPhone()); 43 | if (byPhone.isPresent()) { 44 | throw new BusinessException("手机号已存在, 请登录"); 45 | } 46 | User save = userRepository.save(UserDraft.$.produce(draft -> { 47 | draft.setPhone(input.getPhone()) 48 | .setPassword(BCrypt.hashpw(input.getPassword())); 49 | })); 50 | StpUtil.login(save.id()); 51 | return StpUtil.getTokenInfo(); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/io/github/qifan777/knowledge/user/UserRepository.java: -------------------------------------------------------------------------------- 1 | 
package io.github.qifan777.knowledge.user; 2 | 3 | import org.babyfish.jimmer.spring.repository.JRepository; 4 | 5 | import java.util.Optional; 6 | 7 | public interface UserRepository extends JRepository { 8 | UserTable t = UserTable.$; 9 | UserFetcher FETCHER = UserFetcher.$.allScalarFields(); 10 | 11 | default Optional findByPhone(String phone) { 12 | return sql().createQuery(t) 13 | .where(t.phone().eq(phone)) 14 | .select(t) 15 | .fetchOptional(); 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/resources/application.yml: -------------------------------------------------------------------------------- 1 | spring: 2 | application: 3 | name: knowledge-base 4 | profiles: 5 | active: dev,private 6 | data: 7 | redis: 8 | database: 0 9 | timeout: 10s 10 | password: 123456 11 | repositories: 12 | enabled: false 13 | client-type: jedis 14 | datasource: 15 | url: jdbc:mysql://localhost:3306/knowledge_base 16 | username: root 17 | password: 123456 18 | driver-class-name: com.mysql.cj.jdbc.Driver 19 | neo4j: 20 | authentication: 21 | username: neo4j 22 | password: 12345678 23 | uri: bolt://localhost:7687 24 | ai: 25 | # 阿里灵积 26 | dash-scope: 27 | api-key: xxx 28 | chat: 29 | options: 30 | incremental-output: false 31 | model: qwen-max 32 | # 多模态图片理解需要开启 33 | # multi-model: true 34 | embedding: 35 | options: 36 | model: text-embedding-v2 37 | image: 38 | options: 39 | model: wanx-v1 40 | # 百度千帆 需要在pom.xml中放开qianfan starter的注释 41 | qian-fan: 42 | api-key: xxx 43 | secret-key: xxx 44 | # openai 需要在pom.xml中放开openai starter的注释 45 | openai: 46 | api-key: xxx 47 | # 智谱清言 同样 48 | zhipuai: 49 | api-key: xxx 50 | # kimi 同样 51 | moonshot: 52 | api-key: xxx 53 | ollama: 54 | # 填写模型所在的地址 55 | base-url: http://localhost:11434 56 | chat: 57 | model: xxx 58 | vectorstore: 59 | redis: 60 | initialize-schema: true 61 | index-name: custom-index 62 | prefix: custom-prefix 63 | code-assistant: 64 | project: 65 | project-path: 
F:\workspace\code\opensource\dive-into-spring-ai 66 | arthas: 67 | url: http://localhost:8563/api 68 | password: arthas 69 | username: arthas 70 | server: 71 | port: 9902 72 | sa-token: 73 | # token名称 (同时也是cookie名称) 74 | token-name: token 75 | # token有效期,单位s 默认30天, -1代表永不过期 76 | timeout: 2592000 77 | # token临时有效期 (指定时间内无操作就视为token过期) 单位: 秒 78 | active-timeout: -1 79 | # 是否允许同一账号并发登录 (为true时允许一起登录, 为false时新登录挤掉旧登录) 80 | is-concurrent: true 81 | # 在多人登录同一账号时,是否共用一个token (为true时所有登录共用一个token, 为false时每次登录新建一个token) 82 | is-share: true 83 | # token风格 84 | token-style: uuid 85 | # 是否输出操作日志 86 | is-log: true 87 | # 自动续签 88 | auto-renew: true 89 | oss: 90 | provider: ali_yun 91 | # 阿里云oss配置 92 | ali-yun: 93 | access-key-id: xxx 94 | access-key-secret: xxx 95 | endpoint: xxx 96 | bucket-name: xxx 97 | # 腾讯云oss配置 98 | tencent: 99 | bucket: xx 100 | region: xx 101 | secret-id: xx 102 | secret-key: xx 103 | jimmer: 104 | dialect: org.babyfish.jimmer.sql.dialect.MySqlDialect 105 | show-sql: true 106 | pretty-sql: true 107 | client: 108 | ts: 109 | path: /ts.zip 110 | openapi: 111 | path: /openapi 112 | ui-path: /openapi-ui -------------------------------------------------------------------------------- /src/main/resources/prompts/react/react-tool.st: -------------------------------------------------------------------------------- 1 | {name}: Call this tool to interact with the {name} API. What is the {name} API useful for? {description} Parameters: {parameters} 2 | -------------------------------------------------------------------------------- /src/main/resources/prompts/react/react.st: -------------------------------------------------------------------------------- 1 | Answer the following questions as best you can. 
You have access to the following APIs: 2 | 3 | {tools_text} 4 | 5 | Use the following format: 6 | 7 | Question: the input question you must answer 8 | Thought: you should always think about what to do 9 | Action: the action to take, should be one of [{tools_name_text}] 10 | Action Input: the input to the action 11 | Observation: the result of the action 12 | ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) 13 | Thought: I now know the final answer 14 | Final Answer: the final answer to the original input question 15 | 16 | Begin! 17 | 18 | Question: {query} --------------------------------------------------------------------------------