├── .github └── workflows │ └── maven-publish.yml ├── .gitignore ├── .mvn └── wrapper │ ├── maven-wrapper.jar │ └── maven-wrapper.properties ├── LICENSE ├── README-EN.md ├── README.md ├── doc ├── wenxin-doc-en.md └── wenxin-doc.md ├── mvnw ├── mvnw.cmd ├── pom.xml └── src ├── main ├── java │ └── com │ │ └── gearwenxin │ │ ├── client │ │ ├── ChatClient.java │ │ ├── ImageClient.java │ │ ├── PromptClient.java │ │ └── basic │ │ │ └── BasicChatClient.java │ │ ├── common │ │ ├── Constant.java │ │ ├── ConvertUtils.java │ │ ├── ErrorCode.java │ │ ├── FileUtils.java │ │ ├── RuntimeToolkit.java │ │ ├── StatusConst.java │ │ └── WenXinUtils.java │ │ ├── config │ │ ├── GearWenXinConfig.java │ │ ├── ModelConfig.java │ │ └── WenXinProperties.java │ │ ├── core │ │ ├── AuthEncryption.java │ │ ├── ConsumerService.java │ │ ├── ConsumerThreadMonitor.java │ │ ├── MessageHistoryManager.java │ │ └── RequestManager.java │ │ ├── entity │ │ ├── BaseProperty.java │ │ ├── BaseRequest.java │ │ ├── ClientParams.java │ │ ├── Example.java │ │ ├── FunctionCall.java │ │ ├── FunctionInfo.java │ │ ├── FunctionParameters.java │ │ ├── FunctionResponses.java │ │ ├── Message.java │ │ ├── PluginUsage.java │ │ ├── Usage.java │ │ ├── chatmodel │ │ │ ├── ChatBaseRequest.java │ │ │ ├── ChatErnieRequest.java │ │ │ └── ChatPromptRequest.java │ │ ├── enums │ │ │ ├── ModelType.java │ │ │ ├── ResponseFormatType.java │ │ │ ├── Role.java │ │ │ └── SamplerType.java │ │ ├── request │ │ │ ├── EmbeddingV1Request.java │ │ │ ├── ErnieRequest.java │ │ │ ├── ImageBaseRequest.java │ │ │ ├── PluginParams.java │ │ │ └── PromptRequest.java │ │ └── response │ │ │ ├── ChatResponse.java │ │ │ ├── ErrorResponse.java │ │ │ ├── ImageData.java │ │ │ ├── ImageResponse.java │ │ │ ├── PromptErrMessage.java │ │ │ ├── PromptResponse.java │ │ │ ├── PromptResult.java │ │ │ ├── SSEResponse.java │ │ │ ├── SearchInfo.java │ │ │ ├── SearchResult.java │ │ │ ├── TokenResponse.java │ │ │ └── plugin │ │ │ ├── PluginResponse.java │ │ │ └── knowledge 
│ │ │ ├── KnowledgeBaseMI.java │ │ │ ├── KnowledgeMIRequest.java │ │ │ ├── KnowledgeMIResponse.java │ │ │ ├── KnowledgeMIResponses.java │ │ │ └── KnowledgeMIResult.java │ │ ├── exception │ │ └── WenXinException.java │ │ ├── model │ │ ├── BasicChatModel.java │ │ ├── ChatModel.java │ │ ├── EmbeddingModel.java │ │ ├── ImageModel.java │ │ └── PromptModel.java │ │ ├── plugin │ │ └── Weather.java │ │ ├── schedule │ │ ├── BackgroundSaveManager.java │ │ ├── TaskConsumerLoop.java │ │ ├── TaskQueueManager.java │ │ ├── ThreadPoolManager.java │ │ └── entity │ │ │ ├── BlockingMap.java │ │ │ ├── ChatTask.java │ │ │ └── ModelHeader.java │ │ ├── service │ │ ├── ChatService.java │ │ ├── EmbeddingService.java │ │ ├── ImageService.java │ │ ├── MessageService.java │ │ ├── PromptService.java │ │ ├── WinXinActions.java │ │ └── impl │ │ │ └── WinXinActionsImpl.java │ │ ├── subscriber │ │ └── CommonSubscriber.java │ │ └── validator │ │ ├── ChatBaseRequestValidator.java │ │ ├── ChatErnieRequestValidator.java │ │ ├── RequestValidator.java │ │ └── RequestValidatorFactory.java └── resources │ ├── META-INF │ ├── spring.factories │ └── spring │ │ └── org.springframework.boot.autoconfigure.AutoConfiguration.imports │ └── application.yaml └── test └── java └── com └── gearwenxin └── client └── erniebot └── ErnieBotClientTest.java /.github/workflows/maven-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a package using Maven and then publish it to GitHub packages when a release is created 2 | # For more information see: https://github.com/actions/setup-java/blob/main/docs/advanced-usage.md#apache-maven-with-a-settings-path 3 | 4 | name: Maven Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | permissions: 15 | contents: read 16 | packages: write 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up JDK 17 21 | uses: actions/setup-java@v3 22 
| with: 23 | java-version: '17' 24 | distribution: 'temurin' 25 | server-id: github # Value of the distributionManagement/repository/id field of the pom.xml 26 | settings-path: ${{ github.workspace }} # location for the settings.xml file 27 | 28 | - name: Build with Maven 29 | run: mvn -B package --file pom.xml 30 | 31 | - name: Publish to GitHub Packages Apache Maven 32 | run: mvn deploy -s $GITHUB_WORKSPACE/settings.xml 33 | env: 34 | GITHUB_TOKEN: ${{ github.token }} 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | HELP.md 2 | target/ 3 | !.mvn/wrapper/maven-wrapper.jar 4 | !**/src/main/**/target/ 5 | !**/src/test/**/target/ 6 | 7 | ### STS ### 8 | .apt_generated 9 | .classpath 10 | .factorypath 11 | .project 12 | .settings 13 | .springBeans 14 | .sts4-cache 15 | 16 | ### IntelliJ IDEA ### 17 | .idea 18 | *.iws 19 | *.iml 20 | *.ipr 21 | 22 | ### NetBeans ### 23 | /nbproject/private/ 24 | /nbbuild/ 25 | /dist/ 26 | /nbdist/ 27 | /.nb-gradle/ 28 | build/ 29 | !**/src/main/**/build/ 30 | !**/src/test/**/build/ 31 | 32 | ### VS Code ### 33 | .vscode/ -------------------------------------------------------------------------------- /.mvn/wrapper/maven-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciimina/qianfan-starter/53f7ab383c15e3b88e1bc9ce2c86b72af1bff39a/.mvn/wrapper/maven-wrapper.jar -------------------------------------------------------------------------------- /.mvn/wrapper/maven-wrapper.properties: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. 
The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.8.7/apache-maven-3.8.7-bin.zip 18 | wrapperUrl=https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.1.1/maven-wrapper-3.1.1.jar -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 GMerge 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README-EN.md: -------------------------------------------------------------------------------- 1 |
2 | 中文  |  3 | EN 4 |
5 | 6 |
7 | 8 | ![gear-wenxinworkshop-starter](https://socialify.git.ci/gemingjia/gear-wenxinworkshop-starter/image?font=Inter&forks=1&issues=1&language=1&name=1&owner=1&pattern=Floating%20Cogs&pulls=1&stargazers=1&theme=Light) 9 | 10 | ![LICENSE](https://img.shields.io/github/license/gemingjia/gear-wenxinworkshop-starter?style=flat-square) 11 | ![Spring Boot](https://img.shields.io/badge/Spring%20Boot-3.1.0-brightgreen.svg) 12 | ![JDK](https://img.shields.io/badge/JDK-17.0.5-orange.svg) 13 | ![Maven](https://img.shields.io/badge/Maven-3.9-blue.svg) 14 | 15 | ![COMMIT](https://img.shields.io/github/last-commit/gemingjia/gear-wenxinworkshop-starter?style=flat-square) 16 | ![LANG](https://img.shields.io/badge/language-Java-7F52FF?style=flat-square) 17 | 18 |
19 | # Gear-WenXinWorkShop-Starter 20 | 21 | ## How to get access-token? 22 | 23 | [Apply for WenxinYiyan & WenxinQianfan Big model API qualification, get access_token, and use SpringBoot to access WenxinYiyan API](https://juejin.cn/post/7260418945721991227) 24 | 25 | 26 | 1. Go to [WenXinYiYan qualification application](https://cloud.baidu.com/product/wenxinworkshop) 27 | 28 | 2. [Fill out the questionnaire](https://cloud.baidu.com/survey/qianfan.html), and wait for approval (it took me one and a half days) 29 | 30 | 3. After approval, enter the [console](https://console.bce.baidu.com/ai/?_=#/ai/wenxinworkshop/overview/index) and click [Create Application](https://console.bce.baidu.com/ai/?_=#/ai/wenxinworkshop/app/create) 31 | 4. Open the [Application List](https://console.bce.baidu.com/ai/?_=#/ai/wenxinworkshop/app/list) in the left-hand menu and copy the `API Key` and `Secret Key` 32 | 5. Replace `[API-Key]` and `[Secret-Key]` in the link below with your `API Key` and `Secret Key`, then visit the resulting address 33 | > https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=[API-Key]&client_secret=[Secret-Key] 34 | 35 | ## 📖 Project Introduction 36 | - The spring-boot-starter of Baidu's **"Wenxin Qianfan WENXINWORKSHOP"** large model can help you quickly access Baidu's AI capabilities. You can call Baidu's Wenxin Qianfan large model with only one line of code. 37 | - Fully integrated with the official API documentation of WenxinQianfan. 38 | - Supports streaming responses for conversations.
39 | - Full API support for `ErnieBot`、`ERNIE-Bot-turbo`、`BLOOMZ-7B`、`Ernie-Bot-VilG`、`VisualGLM-6B`、`Llama-2`、`Linly-Chinese-LLaMA-2-7B`、`Linly-Chinese-LLaMA-2-13B`、`ChatGLM2-6B`、`RWKV-4-World`、`OpenLLaMA-7B`、`Falcon-7B`、`Dolly-12B`、`MPT-7B-Instruct`、`Stable-Diffusion-v1.5`、`RWKV-4-pile-14B`、`RWKV-5-World`、`RWKV-Raven-14B`、`Falcon-40B`、`MPT-30B-instruct`、`Flan-UL2`、`Cerebras-GPT-13B`、`Cerebras-GPT-6.7B`、`Pythia-12B`、`Pythia-6.9B`、`GPT-J-6B`、`GPT-NeoX-20B`、`OA-Pythia-12B-SFT-4`、`GPT4All-J`、`StableLM-Alpha-7B` 、 `StarCoder`、`Prompt Template` models (single-round conversation, continuous conversation, streaming return). 40 | - Support for more models will be added in later versions. 41 | 42 | ## 🚀 Quick Start 43 | 44 | [Project demo](https://github.com/gemingjia/springboot-wenxin-demo) 45 | 46 | ```text 47 | This version refactors almost the entire project. The package paths of the client and parameter classes have changed, so it is not fully compatible with the previous version. The methods themselves are unchanged; you only need to re-import the packages. 48 | 49 | "Bloomz7BClient" -> "BloomZ7BClient" 50 | 51 | Except "ErnieBot" and "Prompt", the receiving parameter class of the other conversational models is unified as ChatBaseRequest, and the response class is ChatResponse. 52 | The receiving parameter class of the image generation model is unified as ChatImageRequest, the response class is ImageBaseRequest, and the content is a base64-encoded image.
53 | ``` 54 | 55 | ### 1、Add Dependencies 56 | - Maven 57 | ```xml 58 | 59 | io.github.gemingjia 60 | gear-wenxinworkshop-starter 61 | 1.1.1 62 | 63 | ``` 64 | - Gradle 65 | ```gradle 66 | dependencies { 67 | implementation 'io.github.gemingjia:gear-wenxinworkshop-starter:1.1.1' 68 | } 69 | ``` 70 | 71 | ### 2、Add access-token 72 | - application.yml & application.yaml 73 | ```yaml 74 | gear: 75 | wenxin: 76 | access-token: xx.xxxxxxxxxx.xxxxxx.xxxxxxx.xxxxx-xxxx 77 | ``` 78 | - application.properties 79 | ```properties 80 | gear.wenxin.access-token=xx.xxxxxxxxxx.xxxxxx.xxxxxxx.xxxxx-xxxx 81 | ``` 82 | 83 | ### 3、Invoke Example 84 | ```java 85 | @RestController 86 | public class ChatController { 87 | 88 | // 要调用的模型的客户端 89 | @Resource 90 | private ErnieBotClient ernieBotClient; 91 | 92 | // 单次对话 93 | @PostMapping("/chat") 94 | public Mono chatSingle(String msg) { 95 | return ernieBotClient.chatSingle(msg); 96 | } 97 | 98 | // 连续对话 99 | @PostMapping("/chats") 100 | public Mono chatCont(String msg) { 101 | String chatUID = "test-user-1001"; 102 | return ernieBotClient.chatCont(msg, chatUID); 103 | } 104 | 105 | // 流式返回,单次对话 106 | @GetMapping(value = "/stream/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE) 107 | public Flux chatSingleStream(@RequestParam String msg) { 108 | Flux chatResponse = ernieBotClient.chatSingleOfStream(msg); 109 | 110 | return chatResponse.map(response -> "data: " + response.getResult() + "\n\n"); 111 | } 112 | 113 | // 流式返回,连续对话 114 | @GetMapping(value = "/stream/chats", produces = MediaType.TEXT_EVENT_STREAM_VALUE) 115 | public Flux chatContStream(@RequestParam String msg, @RequestParam String msgUid) { 116 | Flux chatResponse = ernieBotClient.chatContOfStream(msg, msgUid); 117 | 118 | return chatResponse.map(response -> "data: " + response.getResult() + "\n\n"); 119 | } 120 | 121 | // 模板对话 122 | @PostMapping("/prompt") 123 | public Mono chatSingle() { 124 | Map map = new HashMap<>(); 125 | map.put("article", 
"我看见过波澜壮阔的大海,玩赏过水平如镜的西湖,却从没看见过漓江这样的水。漓江的水真静啊,静得让你感觉不到它在流动。"); 126 | map.put("number", "20"); 127 | PromptRequest promptRequest = new PromptRequest(); 128 | promptRequest.setId(1234); 129 | promptRequest.setParamMap(map); 130 | 131 | return promptBotClient.chatPrompt(promptRequest); 132 | } 133 | 134 | } 135 | ``` 136 | 137 | ## 📑Documentation 138 | 139 |
140 | Click => 141 | Documents 142 |
143 | 144 | ## Open Source License 145 | [MIT License](LICENSE) 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 中文  |  3 | EN 4 |
5 | 6 |
7 | 8 | ![llms–nexus](https://socialify.git.ci/rainveil/wenxin-starter/image?font=Inter&forks=1&issues=1&language=1&name=1&owner=1&pattern=Floating%20Cogs&pulls=1&stargazers=1&theme=Light) 9 | 10 | ![Spring Boot](https://img.shields.io/badge/Spring%20Boot-3.1.5-brightgreen.svg) 11 | ![JDK](https://img.shields.io/badge/JDK-17.0.5-orange.svg) 12 | ![Maven](https://img.shields.io/badge/Maven-3.9-blue.svg) 13 | 14 | ![LICENSE](https://img.shields.io/github/license/rainveil/wenxin-starter?style=flat-square) 15 | ![COMMIT](https://img.shields.io/github/last-commit/rainveil/wenxin-starter?style=flat-square) 16 | ![LANG](https://img.shields.io/badge/language-Java-7F52FF?style=flat-square) 17 | 18 |
19 | # WenXin-Starter 20 | 21 | # 📢 此项目已停更,新项目请去 [AltEgo](https://github.com/altegox/AltEgo) 22 | 23 | 24 | # [ => 1.0版本链接](https://github.com/egmsia01/wenxin-starter/tree/master?tab=readme-ov-file) 25 | 26 | ## 项目简介 27 | - 百度 **“文心千帆 WENXINWORKSHOP”** 大模型的spring-boot-starter,可以帮助您快速接入百度的AI能力。 28 | - 完整对接文心千帆的官方API文档。 29 | - 支持文生图,内置对话记忆,支持对话的流式返回。 30 | - 支持单个模型的QPS控制,支持排队机制。 31 | - 即将增加插件支持。 32 | 33 | ## 快速开始 34 | 35 | [使用demo (1.x版,2.x请阅读文档) ](https://github.com/rainveil/springboot-wenxin-demo) 36 | 37 | *【基于Springboot 3.0开发,所以要求JDK版本为17及以上】* 38 | 39 | ### 1、添加依赖 40 | 41 | - Maven 42 | ```xml 43 | 44 | io.github.gemingjia 45 | wenxin-starter 46 | 2.0.0-beta4 47 | 48 | ``` 49 | - Gradle 50 | ```gradle 51 | dependencies { 52 | implementation 'io.github.gemingjia:wenxin-starter:2.0.0-beta4' 53 | } 54 | ``` 55 | 56 | ### 2、添加access-token 57 | - application.yml & application.yaml 58 | ```yaml 59 | gear: 60 | wenxin: 61 | access-token: xx.xxxxxxxxxx.xxxxxx.xxxxxxx.xxxxx-xxxx 62 | -------------或----------------- 63 | # 推荐 64 | gear: 65 | wenxin: 66 | api-key: xxxxxxxxxxxxxxxxxxx 67 | secret-key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxx 68 | ``` 69 | - application.properties 70 | ```properties 71 | gear.wenxin.access-token=xx.xxxxxxxxxx.xxxxxx.xxxxxxx.xxxxx-xxxx 72 | ``` 73 | 74 | - 模型qps设置 75 | ```yaml 76 | gear: 77 | wenxin: 78 | model-qps: 79 | # 模型名 QPS数量 80 | - Ernie 10 81 | - Lamma 10 82 | - ChatGLM 10 83 | ``` 84 | 85 | ### 3、调用示例 86 | 87 | ```java 88 | 89 | @Configuration 90 | public class ClientConfig { 91 | 92 | @Bean 93 | @Qualifier("Ernie") 94 | public ChatModel ernieClient() { 95 | 96 | ModelConfig modelConfig = new ModelConfig(); 97 | // 模型名称,需跟设置的QPS数值的名称一致 (建议与官网名称一致) 98 | modelConfig.setModelName("Ernie"); 99 | // 模型url 100 | modelConfig.setModelUrl("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"); 101 | // 单独设置某个模型的access-token, 优先级高于全局access-token, 统一使用全局的话可以不设置 102 | 
modelConfig.setAccessToken("xx.xx.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); 103 | 104 | ModelHeader modelHeader = new ModelHeader(); 105 | // 一分钟内允许的最大请求次数 106 | modelHeader.set_X_Ratelimit_Limit_Requests(100); 107 | // 一分钟内允许的最大tokens消耗,包含输入tokens和输出tokens 108 | modelHeader.set_X_Ratelimit_Limit_Tokens(2000); 109 | // 达到RPM速率限制前,剩余可发送的请求数配额,如果配额用完,将会在0-60s后刷新 110 | modelHeader.set_X_Ratelimit_Remaining_Requests(1000); 111 | // 达到TPM速率限制前,剩余可消耗的tokens数配额,如果配额用完,将会在0-60s后刷新 112 | modelHeader.set_X_Ratelimit_Remaining_Tokens(5000); 113 | 114 | modelConfig.setModelHeader(modelHeader); 115 | 116 | return new ChatClient(modelConfig); 117 | } 118 | 119 | } 120 | 121 | @RestController 122 | public class ChatController { 123 | 124 | // 要调用的模型的客户端(示例为文心) 125 | @Resource 126 | @Qualifier("Ernie") 127 | private ChatModel chatClient; 128 | 129 | /** 130 | * chatClient.chatStream(msg) 单轮流式对话 131 | * chatClient.chatStream(new ChatErnieRequest()) 单轮流式对话, 参数可调 132 | * chatClient.chatsStream(msg, msgId) 连续对话 133 | * chatClient.chatsStream(new ChatErnieRequest(), msgId) 连续对话, 参数可调 134 | */ 135 | 136 | /** 137 | * 以下两种方式均可 138 | */ 139 | // 连续对话,流式 140 | @GetMapping(value = "/stream/chats", produces = MediaType.TEXT_EVENT_STREAM_VALUE) 141 | public Flux chatSingleStream(@RequestParam String msg, @RequestParam String uid) { 142 | // 单次对话 chatClient.chatStream(msg) 143 | Flux responseFlux = chatClient.chatsStream(msg, uid); 144 | return responseFlux.map(ChatResponse::getResult); 145 | } 146 | 147 | // 连续对话,流式 148 | @GetMapping(value = "/stream/chats1", produces = MediaType.TEXT_EVENT_STREAM_VALUE) 149 | public SseEmitter chats(@RequestParam String msg, @RequestParam String uid) { 150 | SseEmitter emitter = new SseEmitter(); 151 | // 支持参数设置 ChatErnieRequest(Ernie系列模型)、ChatBaseRequest(其他模型) 152 | // 单次对话 chatClient.chatsStream(msg) 153 | chatClient.chatsStream(msg, uid).subscribe(response -> { 154 | try { 155 | emitter.send(SseEmitter.event().data(response.getResult())); 156 | } catch 
(IOException e) { 157 | throw new RuntimeException(e); 158 | } 159 | }); 160 | return emitter; 161 | } 162 | 163 | } 164 | 165 | /** 166 | * Prompt模板被百度改的有点迷,等稳定一下再做适配... 167 | */ 168 | 169 | ``` 170 | 171 | ## Star History 172 | 173 | [![Star History Chart](https://api.star-history.com/svg?repos=rainveil/wenxin-starter&type=Date)](https://star-history.com/#rainveil/wenxin-starter) 174 | 175 | ## 更新日志 176 | 177 | v2.0.0-alpha1 // 始终上传失败...建议自己拉仓库install 178 | - JDK 8专版 179 | 180 | v2.0.0 - beta4 181 | 182 | - 修复 定时任务导致的序列化问题 183 | 184 | v2.0.0 - beta3 185 | 186 | - 修复 并发场景下导致的丢对话任务的问题 187 | - 修复 网络异常情况下导致的消息错乱问题 188 | - 新增 导入导出消息的api 189 | - 新增 消息存储与获取的api 190 | - 新增 Prompt与ImageClient 191 | - 优化 整体性能 192 | - 其余改动请查看commit. 193 | 194 | v2.0.0 - beta 195 | 196 | ! 2.x 版本与 1.x 版本不兼容 197 | - 重构 SDK架构,大幅提升性能 198 | - 重构 客户端生成方式,支持自定义多模型,不再需要适配 199 | - 完善 普通chat接口现已可用 200 | 201 | ## 使用文档 202 | 203 |
204 | 点击跳转 => 205 | 使用文档 206 |
207 | 208 | ## 开源协议 209 | ```text 210 | MIT License 211 | 212 | Copyright (c) 2023 Rainveil 213 | 214 | Permission is hereby granted, free of charge, to any person obtaining a copy 215 | of this software and associated documentation files (the "Software"), to deal 216 | in the Software without restriction, including without limitation the rights 217 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 218 | copies of the Software, and to permit persons to whom the Software is 219 | furnished to do so, subject to the following conditions: 220 | 221 | The above copyright notice and this permission notice shall be included in all 222 | copies or substantial portions of the Software. 223 | 224 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 225 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 226 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 227 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 228 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 229 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 230 | SOFTWARE. 
231 | ``` 232 | -------------------------------------------------------------------------------- /doc/wenxin-doc.md: -------------------------------------------------------------------------------- 1 | # 一、前言 2 | 3 | ### 1、添加依赖 4 | 5 | - Maven 6 | ```xml 7 | 8 | io.github.gemingjia 9 | wenxin-starter 10 | 2.0.0-beta 11 | 12 | ``` 13 | - Gradle 14 | ```gradle 15 | dependencies { 16 | implementation 'io.github.gemingjia:wenxin-starter:2.0.0-beta' 17 | } 18 | ``` 19 | 20 | ### 2、添加access-token 21 | - application.yml & application.yaml 22 | ```yaml 23 | gear: 24 | wenxin: 25 | access-token: xx.xxxxxxxxxx.xxxxxx.xxxxxxx.xxxxx-xxxx 26 | -------------或----------------- 27 | # 推荐 28 | gear: 29 | wenxin: 30 | api-key: xxxxxxxxxxxxxxxxxxx 31 | secret-key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxx 32 | ``` 33 | - application.properties 34 | ```properties 35 | gear.wenxin.access-token=xx.xxxxxxxxxx.xxxxxx.xxxxxxx.xxxxx-xxxx 36 | ``` 37 | 38 | - 模型qps设置 39 | ```yaml 40 | gear: 41 | wenxin: 42 | model-qps: 43 | # 模型名 QPS数量 44 | - Ernie 10 45 | - Lamma 10 46 | - ChatGLM 10 47 | ``` 48 | 49 | # 二、参数与返回值 50 | 51 | ## ErnieBot(文心一言) 52 | 53 | ErnieBot参数建议参考 [官方文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t) 54 | 55 | **ChatErnieRequest**:**ErnieBot、Ernie4Bot、ErnieBotTurbo** 参数配置类 56 | 57 | | 变量名 | 类型 | 说明 | 58 | | ------------ | ------ | ------------------------------------------------------------ | 59 | | userId | String | 表示最终用户的唯一标识符,可以监视和检测滥用行为,防止接口恶意调用 | 60 | | content | String | 聊天文本信息。单个`content` 长度不能超过2000个字符;连续对话中,若 `content` 总长度大于2000字符,系统会依次遗忘最早的历史会话,直到 `content` 的总长度不超过2000个字符。 | 61 | | temperature | Float | (1) 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
(2) 默认 `0.95`,范围 `(0,1.0]`,不能为0
(3) 建议该参数和 `top_p` 只设置1个
(4) 建议 `top_p` 和 `temperature` 不要同时更改 | 62 | | topP | Float | (1) 影响输出文本的多样性,取值越大,生成文本的多样性越强
(2) 默认`0.8`,取值范围 `[0,1.0]`
(3) 建议该参数和 `temperature` 只设置1个
(4) 建议 `top_p` 和 `temperature` 不要同时更改 | 63 | | penaltyScore | Float | 通过对已生成的 `token` 增加惩罚,减少重复生成的现象。说明:
(1) 值越大表示惩罚越大
(2) 默认 `1.0`,取值范围 `[1.0,2.0]` | 64 | 65 | **ChatResponse**:普通对话的响应类 66 | 67 | | 变量名 | 类型 | 说明 | 68 | | ---------------- | ------- | ------------------------------------------------------------ | 69 | | id | String | 本轮对话的id | 70 | | object | String | 回包类型,chat.completion为多轮对话返回 | 71 | | created | Integer | 时间戳 | 72 | | sentenceId | Integer | 表示当前子句的序号。只有在流式接口模式下会返回该字段 | 73 | | isEnd | Boolean | 表示当前子句是否是最后一句。只有在流式接口模式下会返回该字段 | 74 | | isTruncated | Boolean | 当前生成的结果是否被截断 | 75 | | result | String | 对话返回结果 | 76 | | needClearHistory | Boolean | 表示用户输入是否存在安全风险,是否需要关闭当前会话,清理历史会话信息
true:是,表示用户输入存在安全风险,建议关闭当前会话,清理历史会话信息
false:否,表示用户输入无安全风险 | 77 | | usage | Usage | token统计信息,token数 = 汉字数+单词数_1.3 (仅为估算逻辑) | 78 | | errorCode | Integer | 错误代码,正常为null | 79 | | errorMsg | String | 错误描述信息,帮助理解和解决发生的错误,正常为null | 80 | 81 | **Usage**:`tokens`使用情况 82 | 83 | | 变量名 | 类型 | 说明 | 84 | | ---------------- | ---- | ------------ | 85 | | promptTokens | int | 问题tokens数 | 86 | | completionTokens | int | 回答tokens数 | 87 | | totalTokens | int | tokens总数 | 88 | 89 | ## 文生图 90 | 91 | 详见下方使用示例。 92 | 93 | 94 | 95 | ## Prompt模板 96 | 97 | **ChatPromptRequest**:**Prompt** 模板参数配置类 98 | 99 | | 变量名 | 类型 | 说明 | 100 | | -------- | ------------------- | -------------------------- | 101 | | id | int | prompt工程里面对应的模板id | 102 | | paramMap | Map | Map<插值变量名1,插值变量> | 103 | 104 | **PromptResponse**:**Prompt** 模板响应类 105 | 106 | | **名称** | **类型** | **描述** | 107 | | ---------------- | ---------------- | ----------------------------- | 108 | | log_id | String | 唯一的 `log id`,用于问题定位 | 109 | | result | PromptResult | 模板内容详情 | 110 | | status | Integer | 状态码,正常200 | 111 | | success | Boolean | 调用成功与否,成功为true | 112 | | errorCode | Integer | 错误代码,正常为null | 113 | | errorMsg | String | 错误信息,正常为null | 114 | | promptErrCode | String | `Prompt`错误代码 | 115 | | promptErrMessage | PromptErrMessage | `Prompt`错误信息对象 | 116 | 117 | **PromptResult**: 118 | 119 | | **名称** | **类型** | **描述** | 120 | | ----------------- | -------- | -------------------------------------------- | 121 | | templateId | String | `prompt`工程里面对应的模板id | 122 | | templateName | String | 模板名称 | 123 | | templateContent | String | 模板原始内容 | 124 | | templateVariables | String | 模板变量插值 | 125 | | content | String | 将变量插值填充到模板原始内容后得到的模板内容 | 126 | 127 | **PromptErrMessage**: 128 | 129 | | **名称** | **类型** | **描述** | 130 | | -------- | -------- | ------------ | 131 | | global | String | 错误信息描述 | 132 | 133 | 134 | ## 配置类,创建对应的Bean 135 | ```java 136 | @Configuration 137 | public class ClientConfig { 138 | 139 | @Bean 140 | // 对应的模型名称,建议与modelConfig.setModelName("Ernie")保持一致 (允许不一致) 141 | 
@Qualifier("Ernie") 142 | public ChatClient ernieClient() { 143 | ModelConfig modelConfig = new ModelConfig(); 144 | // 模型名称,必须跟设置的QPS数值的名称一致 (强烈建议与官网名称一致) 145 | modelConfig.setModelName("Ernie"); 146 | // 模型url 147 | modelConfig.setModelUrl("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"); 148 | // 单独设置某个模型的access-token, 优先级高于全局access-token, 统一使用全局的话可以不设置 149 | modelConfig.setAccessToken("xx.xx.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); 150 | 151 | // 请求头,可选,不设置则使用默认值 152 | ModelHeader modelHeader = new ModelHeader(); 153 | // 一分钟内允许的最大请求次数 154 | modelHeader.set_X_Ratelimit_Limit_Requests(100); 155 | // 一分钟内允许的最大tokens消耗,包含输入tokens和输出tokens 156 | modelHeader.set_X_Ratelimit_Limit_Tokens(2000); 157 | // 达到RPM速率限制前,剩余可发送的请求数配额,如果配额用完,将会在0-60s后刷新 158 | modelHeader.set_X_Ratelimit_Remaining_Requests(1000); 159 | // 达到TPM速率限制前,剩余可消耗的tokens数配额,如果配额用完,将会在0-60s后刷新 160 | modelHeader.set_X_Ratelimit_Remaining_Tokens(5000); 161 | 162 | modelConfig.setModelHeader(modelHeader); 163 | 164 | return new ChatClient(modelConfig); 165 | } 166 | 167 | } 168 | ``` 169 | --- 170 | 以下均以 `Webflux` 为例,`Spring mvc` 请自行调整 171 | 172 | ## 单次对话 173 | 174 | ### 非流式返回 175 | 176 | ```java 177 | @RestController 178 | public class ChatController { 179 | 180 | // 要调用的模型的客户端(示例为文心) 181 | @Resource 182 | // 与上方配置类中的 @Qualifier("Ernie") 保持一致 183 | @Qualifier("Ernie") 184 | private ChatClient chatClient; 185 | 186 | @GetMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE) 187 | public Mono chatSingle(@RequestParam String msg) { 188 | // 单次对话 chatClient.chat(msg) 189 | Mono response = chatClient.chats(msg); 190 | return response.map(ChatResponse::getResult); 191 | } 192 | 193 | } 194 | ``` 195 | 196 | ### 流式返回 197 | 198 | ```java 199 | @RestController 200 | public class ChatController { 201 | 202 | // 要调用的模型的客户端(示例为文心) 203 | @Resource 204 | @Qualifier("Ernie") 205 | private ChatClient chatClient; 206 | 207 | @GetMapping(value = "/stream/chat", produces = 
MediaType.TEXT_EVENT_STREAM_VALUE) 208 | public Flux chatSingleStream(@RequestParam String msg) { 209 | // 单次对话 chatClient.chatStream(msg) 210 | Flux responseFlux = chatClient.chatStream(msg); 211 | return responseFlux.map(ChatResponse::getResult); 212 | } 213 | 214 | } 215 | ``` 216 | 217 | ## 连续对话 218 | 219 | 连续对话记录内部已内置 220 | 221 | ### 非流式返回 222 | 223 | ```java 224 | @RestController 225 | public class ChatController { 226 | 227 | // 要调用的模型的客户端(示例为文心) 228 | @Resource 229 | @Qualifier("Ernie") 230 | private ChatClient chatClient; 231 | 232 | @GetMapping(value = "/chats", produces = MediaType.TEXT_EVENT_STREAM_VALUE) 233 | public Mono chatSingle(@RequestParam String msg, @RequestParam String uid) { 234 | // 连续对话 chatClient.chats(msg, uid) 235 | Mono response = chatClient.chats(msg, uid); 236 | return response.map(ChatResponse::getResult); 237 | } 238 | 239 | } 240 | ``` 241 | 242 | ### 流式返回 243 | 244 | ```java 245 | @RestController 246 | public class ChatController { 247 | 248 | // 要调用的模型的客户端(示例为文心) 249 | @Resource 250 | @Qualifier("Ernie") 251 | private ChatClient chatClient; 252 | 253 | @GetMapping(value = "/stream/chats", produces = MediaType.TEXT_EVENT_STREAM_VALUE) 254 | public Flux chatSingleStream(@RequestParam String msg, @RequestParam String uid) { 255 | // 连续对话 chatClient.chatsStream(msg, uid) 256 | Flux responseFlux = chatClient.chatsStream(msg, uid); 257 | return responseFlux.map(ChatResponse::getResult); 258 | } 259 | 260 | } 261 | ``` 262 | 263 | ## 文生图(Stable-Diffusion-XL) // 2.0.0-beta版本暂未支持 264 | 265 | ```java 266 | 267 | // 文生图响应类 268 | public class ImageResponse { 269 | /** 270 | * 请求的ID。 271 | */ 272 | private String id; 273 | 274 | /** 275 | * 回包类型。固定值为 "image",表示图像生成返回。 276 | */ 277 | private String object; 278 | 279 | /** 280 | * 时间戳,表示生成响应的时间。 281 | */ 282 | private int created; 283 | 284 | /** 285 | * 生成图片结果列表。 286 | */ 287 | private List data; 288 | 289 | /** 290 | * token统计信息,token数 = 汉字数 + 单词数 * 1.3 (仅为估算逻辑)。 291 | */ 292 | private Usage usage; 293
| 294 | /** 295 | * 错误代码,正常为 null 296 | */ 297 | private Integer errorCode; 298 | 299 | /** 300 | * 错误信息,正常为 null 301 | */ 302 | private String errorMsg; 303 | } 304 | 305 | public class ImageData { 306 | 307 | /** 308 | * 固定值 "image",表示图像。 309 | */ 310 | private String object; 311 | 312 | /** 313 | * 图片base64编码内容。 314 | */ 315 | private String b64Image; 316 | 317 | /** 318 | * 图片序号。 319 | */ 320 | private int index; 321 | } 322 | ``` 323 | 324 | 325 | ## Prompt模板 // 2.0.0-beta版本暂未支持 326 | 327 | ```java 328 | 329 | ``` 330 | 331 | ## 历史消息记录操作 332 | 333 | ### 导出历史消息记录 334 | 335 | 此功能为 **导出历史消息记录** ,供开发者自行保存历史消息记录。 336 | 337 | ```java 338 | @Service 339 | public class ChatService { 340 | 341 | @Resource 342 | private WinXinActions winXinActions; 343 | 344 | // 导出指定msgId的消息(json) 345 | public String exportMessages(String msgId) { 346 | return winXinActions.exportMessages(msgId); 347 | } 348 | 349 | // 导出所有消息(json) 350 | public String exportAllMessages() { 351 | return winXinActions.exportAllMessages(); 352 | } 353 | 354 | } 355 | ``` 356 | 357 | ### 导入历史消息记录 358 | 359 | ```java 360 | @Service 361 | public class ChatService { 362 | 363 | @Resource 364 | private WinXinActions winXinActions; 365 | 366 | // 初始化所有消息map 367 | public void initMessageMap(Map> map) { 368 | winXinActions.initMessageMap(map); 369 | } 370 | 371 | // 初始化指定msgId的消息 372 | public void initMessages(String msgId, Deque messageDeque) { 373 | winXinActions.initMessages(msgId, messageDeque); 374 | } 375 | 376 | } 377 | ``` -------------------------------------------------------------------------------- /mvnw: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # ---------------------------------------------------------------------------- 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. 
See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # https://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # ---------------------------------------------------------------------------- 20 | 21 | # ---------------------------------------------------------------------------- 22 | # Maven Start Up Batch script 23 | # 24 | # Required ENV vars: 25 | # ------------------ 26 | # JAVA_HOME - location of a JDK home dir 27 | # 28 | # Optional ENV vars 29 | # ----------------- 30 | # M2_HOME - location of maven2's installed home dir 31 | # MAVEN_OPTS - parameters passed to the Java VM when running Maven 32 | # e.g. to debug Maven itself, use 33 | # set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 34 | # MAVEN_SKIP_RC - flag to disable loading of mavenrc files 35 | # ---------------------------------------------------------------------------- 36 | 37 | if [ -z "$MAVEN_SKIP_RC" ] ; then 38 | 39 | if [ -f /usr/local/etc/mavenrc ] ; then 40 | . /usr/local/etc/mavenrc 41 | fi 42 | 43 | if [ -f /etc/mavenrc ] ; then 44 | . /etc/mavenrc 45 | fi 46 | 47 | if [ -f "$HOME/.mavenrc" ] ; then 48 | . "$HOME/.mavenrc" 49 | fi 50 | 51 | fi 52 | 53 | # OS specific support. $var _must_ be set to either true or false. 
54 | cygwin=false; 55 | darwin=false; 56 | mingw=false 57 | case "`uname`" in 58 | CYGWIN*) cygwin=true ;; 59 | MINGW*) mingw=true;; 60 | Darwin*) darwin=true 61 | # Use /usr/libexec/java_home if available, otherwise fall back to /Library/Java/Home 62 | # See https://developer.apple.com/library/mac/qa/qa1170/_index.html 63 | if [ -z "$JAVA_HOME" ]; then 64 | if [ -x "/usr/libexec/java_home" ]; then 65 | export JAVA_HOME="`/usr/libexec/java_home`" 66 | else 67 | export JAVA_HOME="/Library/Java/Home" 68 | fi 69 | fi 70 | ;; 71 | esac 72 | 73 | if [ -z "$JAVA_HOME" ] ; then 74 | if [ -r /etc/gentoo-release ] ; then 75 | JAVA_HOME=`java-config --jre-home` 76 | fi 77 | fi 78 | 79 | if [ -z "$M2_HOME" ] ; then 80 | ## resolve links - $0 may be a link to maven's home 81 | PRG="$0" 82 | 83 | # need this for relative symlinks 84 | while [ -h "$PRG" ] ; do 85 | ls=`ls -ld "$PRG"` 86 | link=`expr "$ls" : '.*-> \(.*\)$'` 87 | if expr "$link" : '/.*' > /dev/null; then 88 | PRG="$link" 89 | else 90 | PRG="`dirname "$PRG"`/$link" 91 | fi 92 | done 93 | 94 | saveddir=`pwd` 95 | 96 | M2_HOME=`dirname "$PRG"`/.. 97 | 98 | # make it fully qualified 99 | M2_HOME=`cd "$M2_HOME" && pwd` 100 | 101 | cd "$saveddir" 102 | # echo Using m2 at $M2_HOME 103 | fi 104 | 105 | # For Cygwin, ensure paths are in UNIX format before anything is touched 106 | if $cygwin ; then 107 | [ -n "$M2_HOME" ] && 108 | M2_HOME=`cygpath --unix "$M2_HOME"` 109 | [ -n "$JAVA_HOME" ] && 110 | JAVA_HOME=`cygpath --unix "$JAVA_HOME"` 111 | [ -n "$CLASSPATH" ] && 112 | CLASSPATH=`cygpath --path --unix "$CLASSPATH"` 113 | fi 114 | 115 | # For Mingw, ensure paths are in UNIX format before anything is touched 116 | if $mingw ; then 117 | [ -n "$M2_HOME" ] && 118 | M2_HOME="`(cd "$M2_HOME"; pwd)`" 119 | [ -n "$JAVA_HOME" ] && 120 | JAVA_HOME="`(cd "$JAVA_HOME"; pwd)`" 121 | fi 122 | 123 | if [ -z "$JAVA_HOME" ]; then 124 | javaExecutable="`which javac`" 125 | if [ -n "$javaExecutable" ] && ! 
[ "`expr \"$javaExecutable\" : '\([^ ]*\)'`" = "no" ]; then 126 | # readlink(1) is not available as standard on Solaris 10. 127 | readLink=`which readlink` 128 | if [ ! `expr "$readLink" : '\([^ ]*\)'` = "no" ]; then 129 | if $darwin ; then 130 | javaHome="`dirname \"$javaExecutable\"`" 131 | javaExecutable="`cd \"$javaHome\" && pwd -P`/javac" 132 | else 133 | javaExecutable="`readlink -f \"$javaExecutable\"`" 134 | fi 135 | javaHome="`dirname \"$javaExecutable\"`" 136 | javaHome=`expr "$javaHome" : '\(.*\)/bin'` 137 | JAVA_HOME="$javaHome" 138 | export JAVA_HOME 139 | fi 140 | fi 141 | fi 142 | 143 | if [ -z "$JAVACMD" ] ; then 144 | if [ -n "$JAVA_HOME" ] ; then 145 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 146 | # IBM's JDK on AIX uses strange locations for the executables 147 | JAVACMD="$JAVA_HOME/jre/sh/java" 148 | else 149 | JAVACMD="$JAVA_HOME/bin/java" 150 | fi 151 | else 152 | JAVACMD="`\\unset -f command; \\command -v java`" 153 | fi 154 | fi 155 | 156 | if [ ! -x "$JAVACMD" ] ; then 157 | echo "Error: JAVA_HOME is not defined correctly." >&2 158 | echo " We cannot execute $JAVACMD" >&2 159 | exit 1 160 | fi 161 | 162 | if [ -z "$JAVA_HOME" ] ; then 163 | echo "Warning: JAVA_HOME environment variable is not set." 
164 | fi 165 | 166 | CLASSWORLDS_LAUNCHER=org.codehaus.plexus.classworlds.launcher.Launcher 167 | 168 | # traverses directory structure from process work directory to filesystem root 169 | # first directory with .mvn subdirectory is considered project base directory 170 | find_maven_basedir() { 171 | 172 | if [ -z "$1" ] 173 | then 174 | echo "Path not specified to find_maven_basedir" 175 | return 1 176 | fi 177 | 178 | basedir="$1" 179 | wdir="$1" 180 | while [ "$wdir" != '/' ] ; do 181 | if [ -d "$wdir"/.mvn ] ; then 182 | basedir=$wdir 183 | break 184 | fi 185 | # workaround for JBEAP-8937 (on Solaris 10/Sparc) 186 | if [ -d "${wdir}" ]; then 187 | wdir=`cd "$wdir/.."; pwd` 188 | fi 189 | # end of workaround 190 | done 191 | echo "${basedir}" 192 | } 193 | 194 | # concatenates all lines of a file 195 | concat_lines() { 196 | if [ -f "$1" ]; then 197 | echo "$(tr -s '\n' ' ' < "$1")" 198 | fi 199 | } 200 | 201 | BASE_DIR=`find_maven_basedir "$(pwd)"` 202 | if [ -z "$BASE_DIR" ]; then 203 | exit 1; 204 | fi 205 | 206 | ########################################################################################## 207 | # Extension to allow automatically downloading the maven-wrapper.jar from Maven-central 208 | # This allows using the maven wrapper in projects that prohibit checking in binary data. 209 | ########################################################################################## 210 | if [ -r "$BASE_DIR/.mvn/wrapper/maven-wrapper.jar" ]; then 211 | if [ "$MVNW_VERBOSE" = true ]; then 212 | echo "Found .mvn/wrapper/maven-wrapper.jar" 213 | fi 214 | else 215 | if [ "$MVNW_VERBOSE" = true ]; then 216 | echo "Couldn't find .mvn/wrapper/maven-wrapper.jar, downloading it ..." 
217 | fi 218 | if [ -n "$MVNW_REPOURL" ]; then 219 | jarUrl="$MVNW_REPOURL/org/apache/maven/wrapper/maven-wrapper/3.1.0/maven-wrapper-3.1.0.jar" 220 | else 221 | jarUrl="https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.1.0/maven-wrapper-3.1.0.jar" 222 | fi 223 | while IFS="=" read key value; do 224 | case "$key" in (wrapperUrl) jarUrl="$value"; break ;; 225 | esac 226 | done < "$BASE_DIR/.mvn/wrapper/maven-wrapper.properties" 227 | if [ "$MVNW_VERBOSE" = true ]; then 228 | echo "Downloading from: $jarUrl" 229 | fi 230 | wrapperJarPath="$BASE_DIR/.mvn/wrapper/maven-wrapper.jar" 231 | if $cygwin; then 232 | wrapperJarPath=`cygpath --path --windows "$wrapperJarPath"` 233 | fi 234 | 235 | if command -v wget > /dev/null; then 236 | if [ "$MVNW_VERBOSE" = true ]; then 237 | echo "Found wget ... using wget" 238 | fi 239 | if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then 240 | wget "$jarUrl" -O "$wrapperJarPath" || rm -f "$wrapperJarPath" 241 | else 242 | wget --http-user=$MVNW_USERNAME --http-password=$MVNW_PASSWORD "$jarUrl" -O "$wrapperJarPath" || rm -f "$wrapperJarPath" 243 | fi 244 | elif command -v curl > /dev/null; then 245 | if [ "$MVNW_VERBOSE" = true ]; then 246 | echo "Found curl ... using curl" 247 | fi 248 | if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then 249 | curl -o "$wrapperJarPath" "$jarUrl" -f 250 | else 251 | curl --user $MVNW_USERNAME:$MVNW_PASSWORD -o "$wrapperJarPath" "$jarUrl" -f 252 | fi 253 | 254 | else 255 | if [ "$MVNW_VERBOSE" = true ]; then 256 | echo "Falling back to using Java to download" 257 | fi 258 | javaClass="$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.java" 259 | # For Cygwin, switch paths to Windows format before running javac 260 | if $cygwin; then 261 | javaClass=`cygpath --path --windows "$javaClass"` 262 | fi 263 | if [ -e "$javaClass" ]; then 264 | if [ ! 
-e "$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.class" ]; then 265 | if [ "$MVNW_VERBOSE" = true ]; then 266 | echo " - Compiling MavenWrapperDownloader.java ..." 267 | fi 268 | # Compiling the Java class 269 | ("$JAVA_HOME/bin/javac" "$javaClass") 270 | fi 271 | if [ -e "$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.class" ]; then 272 | # Running the downloader 273 | if [ "$MVNW_VERBOSE" = true ]; then 274 | echo " - Running MavenWrapperDownloader.java ..." 275 | fi 276 | ("$JAVA_HOME/bin/java" -cp .mvn/wrapper MavenWrapperDownloader "$MAVEN_PROJECTBASEDIR") 277 | fi 278 | fi 279 | fi 280 | fi 281 | ########################################################################################## 282 | # End of extension 283 | ########################################################################################## 284 | 285 | export MAVEN_PROJECTBASEDIR=${MAVEN_BASEDIR:-"$BASE_DIR"} 286 | if [ "$MVNW_VERBOSE" = true ]; then 287 | echo $MAVEN_PROJECTBASEDIR 288 | fi 289 | MAVEN_OPTS="$(concat_lines "$MAVEN_PROJECTBASEDIR/.mvn/jvm.config") $MAVEN_OPTS" 290 | 291 | # For Cygwin, switch paths to Windows format before running java 292 | if $cygwin; then 293 | [ -n "$M2_HOME" ] && 294 | M2_HOME=`cygpath --path --windows "$M2_HOME"` 295 | [ -n "$JAVA_HOME" ] && 296 | JAVA_HOME=`cygpath --path --windows "$JAVA_HOME"` 297 | [ -n "$CLASSPATH" ] && 298 | CLASSPATH=`cygpath --path --windows "$CLASSPATH"` 299 | [ -n "$MAVEN_PROJECTBASEDIR" ] && 300 | MAVEN_PROJECTBASEDIR=`cygpath --path --windows "$MAVEN_PROJECTBASEDIR"` 301 | fi 302 | 303 | # Provide a "standardized" way to retrieve the CLI args that will 304 | # work with both Windows and non-Windows executions. 
305 | MAVEN_CMD_LINE_ARGS="$MAVEN_CONFIG $@" 306 | export MAVEN_CMD_LINE_ARGS 307 | 308 | WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain 309 | 310 | exec "$JAVACMD" \ 311 | $MAVEN_OPTS \ 312 | $MAVEN_DEBUG_OPTS \ 313 | -classpath "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.jar" \ 314 | "-Dmaven.home=${M2_HOME}" \ 315 | "-Dmaven.multiModuleProjectDirectory=${MAVEN_PROJECTBASEDIR}" \ 316 | ${WRAPPER_LAUNCHER} $MAVEN_CONFIG "$@" -------------------------------------------------------------------------------- /mvnw.cmd: -------------------------------------------------------------------------------- 1 | @REM ---------------------------------------------------------------------------- 2 | @REM Licensed to the Apache Software Foundation (ASF) under one 3 | @REM or more contributor license agreements. See the NOTICE file 4 | @REM distributed with this work for additional information 5 | @REM regarding copyright ownership. The ASF licenses this file 6 | @REM to you under the Apache License, Version 2.0 (the 7 | @REM "License"); you may not use this file except in compliance 8 | @REM with the License. You may obtain a copy of the License at 9 | @REM 10 | @REM https://www.apache.org/licenses/LICENSE-2.0 11 | @REM 12 | @REM Unless required by applicable law or agreed to in writing, 13 | @REM software distributed under the License is distributed on an 14 | @REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | @REM KIND, either express or implied. See the License for the 16 | @REM specific language governing permissions and limitations 17 | @REM under the License. 
18 | @REM ---------------------------------------------------------------------------- 19 | 20 | @REM ---------------------------------------------------------------------------- 21 | @REM Maven Start Up Batch script 22 | @REM 23 | @REM Required ENV vars: 24 | @REM JAVA_HOME - location of a JDK home dir 25 | @REM 26 | @REM Optional ENV vars 27 | @REM M2_HOME - location of maven2's installed home dir 28 | @REM MAVEN_BATCH_ECHO - set to 'on' to enable the echoing of the batch commands 29 | @REM MAVEN_BATCH_PAUSE - set to 'on' to wait for a keystroke before ending 30 | @REM MAVEN_OPTS - parameters passed to the Java VM when running Maven 31 | @REM e.g. to debug Maven itself, use 32 | @REM set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 33 | @REM MAVEN_SKIP_RC - flag to disable loading of mavenrc files 34 | @REM ---------------------------------------------------------------------------- 35 | 36 | @REM Begin all REM lines with '@' in case MAVEN_BATCH_ECHO is 'on' 37 | @echo off 38 | @REM set title of command window 39 | title %0 40 | @REM enable echoing by setting MAVEN_BATCH_ECHO to 'on' 41 | @if "%MAVEN_BATCH_ECHO%" == "on" echo %MAVEN_BATCH_ECHO% 42 | 43 | @REM set %HOME% to equivalent of $HOME 44 | if "%HOME%" == "" (set "HOME=%HOMEDRIVE%%HOMEPATH%") 45 | 46 | @REM Execute a user defined script before this one 47 | if not "%MAVEN_SKIP_RC%" == "" goto skipRcPre 48 | @REM check for pre script, once with legacy .bat ending and once with .cmd ending 49 | if exist "%USERPROFILE%\mavenrc_pre.bat" call "%USERPROFILE%\mavenrc_pre.bat" %* 50 | if exist "%USERPROFILE%\mavenrc_pre.cmd" call "%USERPROFILE%\mavenrc_pre.cmd" %* 51 | :skipRcPre 52 | 53 | @setlocal 54 | 55 | set ERROR_CODE=0 56 | 57 | @REM To isolate internal variables from possible post scripts, we use another setlocal 58 | @setlocal 59 | 60 | @REM ==== START VALIDATION ==== 61 | if not "%JAVA_HOME%" == "" goto OkJHome 62 | 63 | echo. 
64 | echo Error: JAVA_HOME not found in your environment. >&2 65 | echo Please set the JAVA_HOME variable in your environment to match the >&2 66 | echo location of your Java installation. >&2 67 | echo. 68 | goto error 69 | 70 | :OkJHome 71 | if exist "%JAVA_HOME%\bin\java.exe" goto init 72 | 73 | echo. 74 | echo Error: JAVA_HOME is set to an invalid directory. >&2 75 | echo JAVA_HOME = "%JAVA_HOME%" >&2 76 | echo Please set the JAVA_HOME variable in your environment to match the >&2 77 | echo location of your Java installation. >&2 78 | echo. 79 | goto error 80 | 81 | @REM ==== END VALIDATION ==== 82 | 83 | :init 84 | 85 | @REM Find the project base dir, i.e. the directory that contains the folder ".mvn". 86 | @REM Fallback to current working directory if not found. 87 | 88 | set MAVEN_PROJECTBASEDIR=%MAVEN_BASEDIR% 89 | IF NOT "%MAVEN_PROJECTBASEDIR%"=="" goto endDetectBaseDir 90 | 91 | set EXEC_DIR=%CD% 92 | set WDIR=%EXEC_DIR% 93 | :findBaseDir 94 | IF EXIST "%WDIR%"\.mvn goto baseDirFound 95 | cd .. 96 | IF "%WDIR%"=="%CD%" goto baseDirNotFound 97 | set WDIR=%CD% 98 | goto findBaseDir 99 | 100 | :baseDirFound 101 | set MAVEN_PROJECTBASEDIR=%WDIR% 102 | cd "%EXEC_DIR%" 103 | goto endDetectBaseDir 104 | 105 | :baseDirNotFound 106 | set MAVEN_PROJECTBASEDIR=%EXEC_DIR% 107 | cd "%EXEC_DIR%" 108 | 109 | :endDetectBaseDir 110 | 111 | IF NOT EXIST "%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config" goto endReadAdditionalConfig 112 | 113 | @setlocal EnableExtensions EnableDelayedExpansion 114 | for /F "usebackq delims=" %%a in ("%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config") do set JVM_CONFIG_MAVEN_PROPS=!JVM_CONFIG_MAVEN_PROPS! 
%%a 115 | @endlocal & set JVM_CONFIG_MAVEN_PROPS=%JVM_CONFIG_MAVEN_PROPS% 116 | 117 | :endReadAdditionalConfig 118 | 119 | SET MAVEN_JAVA_EXE="%JAVA_HOME%\bin\java.exe" 120 | set WRAPPER_JAR="%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.jar" 121 | set WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain 122 | 123 | set DOWNLOAD_URL="https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.1.0/maven-wrapper-3.1.0.jar" 124 | 125 | FOR /F "usebackq tokens=1,2 delims==" %%A IN ("%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties") DO ( 126 | IF "%%A"=="wrapperUrl" SET DOWNLOAD_URL=%%B 127 | ) 128 | 129 | @REM Extension to allow automatically downloading the maven-wrapper.jar from Maven-central 130 | @REM This allows using the maven wrapper in projects that prohibit checking in binary data. 131 | if exist %WRAPPER_JAR% ( 132 | if "%MVNW_VERBOSE%" == "true" ( 133 | echo Found %WRAPPER_JAR% 134 | ) 135 | ) else ( 136 | if not "%MVNW_REPOURL%" == "" ( 137 | SET DOWNLOAD_URL="%MVNW_REPOURL%/org/apache/maven/wrapper/maven-wrapper/3.1.0/maven-wrapper-3.1.0.jar" 138 | ) 139 | if "%MVNW_VERBOSE%" == "true" ( 140 | echo Couldn't find %WRAPPER_JAR%, downloading it ... 
141 | echo Downloading from: %DOWNLOAD_URL% 142 | ) 143 | 144 | powershell -Command "&{"^ 145 | "$webclient = new-object System.Net.WebClient;"^ 146 | "if (-not ([string]::IsNullOrEmpty('%MVNW_USERNAME%') -and [string]::IsNullOrEmpty('%MVNW_PASSWORD%'))) {"^ 147 | "$webclient.Credentials = new-object System.Net.NetworkCredential('%MVNW_USERNAME%', '%MVNW_PASSWORD%');"^ 148 | "}"^ 149 | "[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; $webclient.DownloadFile('%DOWNLOAD_URL%', '%WRAPPER_JAR%')"^ 150 | "}" 151 | if "%MVNW_VERBOSE%" == "true" ( 152 | echo Finished downloading %WRAPPER_JAR% 153 | ) 154 | ) 155 | @REM End of extension 156 | 157 | @REM Provide a "standardized" way to retrieve the CLI args that will 158 | @REM work with both Windows and non-Windows executions. 159 | set MAVEN_CMD_LINE_ARGS=%* 160 | 161 | %MAVEN_JAVA_EXE% ^ 162 | %JVM_CONFIG_MAVEN_PROPS% ^ 163 | %MAVEN_OPTS% ^ 164 | %MAVEN_DEBUG_OPTS% ^ 165 | -classpath %WRAPPER_JAR% ^ 166 | "-Dmaven.multiModuleProjectDirectory=%MAVEN_PROJECTBASEDIR%" ^ 167 | %WRAPPER_LAUNCHER% %MAVEN_CONFIG% %* 168 | if ERRORLEVEL 1 goto error 169 | goto end 170 | 171 | :error 172 | set ERROR_CODE=1 173 | 174 | :end 175 | @endlocal & set ERROR_CODE=%ERROR_CODE% 176 | 177 | if not "%MAVEN_SKIP_RC%"=="" goto skipRcPost 178 | @REM check for post script, once with legacy .bat ending and once with .cmd ending 179 | if exist "%USERPROFILE%\mavenrc_post.bat" call "%USERPROFILE%\mavenrc_post.bat" 180 | if exist "%USERPROFILE%\mavenrc_post.cmd" call "%USERPROFILE%\mavenrc_post.cmd" 181 | :skipRcPost 182 | 183 | @REM pause the script if MAVEN_BATCH_PAUSE is set to 'on' 184 | if "%MAVEN_BATCH_PAUSE%"=="on" pause 185 | 186 | if "%MAVEN_TERMINATE_CMD%"=="on" exit %ERROR_CODE% 187 | 188 | cmd /C exit /B %ERROR_CODE% -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 4.0.0 5 
| 6 | org.springframework.boot 7 | spring-boot-starter-parent 8 | 3.1.5 9 | 10 | 11 | io.github.gemingjia 12 | wenxin-starter 13 | 2.0.0-beta6 14 | wenxin-starter 15 | A springboot start of Baidu "WENXINWORKSHOP" 16 | https://github.com/egmsia01/wenxin-starter 17 | 18 | 19 | 17 20 | UTF-8 21 | UTF-8 22 | UTF-8 23 | 24 | 25 | https://github.com/egmsia01/wenxin-starter 26 | https://github.com/egmsia01/wenxin-starter 27 | https://github.com/egmsia01/wenxin-starter.git 28 | 29 | 30 | 31 | GMerge 32 | gemingjia 33 | gemingjia0201@163.com 34 | 35 | Developer 36 | 37 | https://github.com/egmsia01 38 | +8 39 | 40 | 41 | 42 | 43 | MIT License 44 | https://opensource.org/license/mit 45 | 46 | 47 | 48 | 49 | ossrh 50 | https://s01.oss.sonatype.org/content/repositories/snapshots 51 | 52 | 53 | ossrh 54 | https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ 55 | 56 | 57 | 58 | 59 | 60 | org.springframework.boot 61 | spring-boot-starter 62 | 63 | 64 | org.yaml 65 | snakeyaml 66 | 2.2 67 | 68 | 69 | org.springframework.boot 70 | spring-boot-configuration-processor 71 | true 72 | 73 | 74 | org.springframework.boot 75 | spring-boot-autoconfigure 76 | 3.1.5 77 | 78 | 79 | org.springframework.boot 80 | spring-boot-starter-webflux 81 | 82 | 83 | org.springframework.boot 84 | spring-boot-starter-websocket 85 | 86 | 87 | org.projectlombok 88 | lombok 89 | true 90 | provided 91 | 92 | 93 | org.apache.commons 94 | commons-lang3 95 | 3.14.0 96 | 97 | 98 | io.projectreactor 99 | reactor-core 100 | 3.6.3 101 | 102 | 103 | com.google.code.gson 104 | gson 105 | 2.10.1 106 | 107 | 108 | org.apache.httpcomponents 109 | httpclient 110 | 4.5.13 111 | 112 | 113 | org.apache.httpcomponents 114 | httpcore 115 | 4.4.14 116 | 117 | 118 | commons-codec 119 | commons-codec 120 | 1.15 121 | 122 | 123 | 124 | 125 | 126 | 127 | release 128 | 129 | 130 | 131 | org.apache.maven.plugins 132 | maven-source-plugin 133 | 3.3.0 134 | 135 | 136 | package 137 | 138 | jar-no-fork 139 | 140 | 141 | 142 | 
143 | 144 | org.apache.maven.plugins 145 | maven-javadoc-plugin 146 | 3.5.0 147 | 148 | 149 | package 150 | 151 | jar 152 | 153 | 154 | zh_CN 155 | UTF-8 156 | UTF-8 157 | none 158 | 159 | 160 | 161 | 162 | 163 | org.apache.maven.plugins 164 | maven-gpg-plugin 165 | 1.6 166 | 167 | 168 | verify 169 | 170 | sign 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/client/ChatClient.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.client; 2 | 3 | import com.gearwenxin.entity.chatmodel.ChatErnieRequest; 4 | import com.gearwenxin.config.ModelConfig; 5 | import com.gearwenxin.entity.chatmodel.ChatBaseRequest; 6 | import com.gearwenxin.entity.enums.ModelType; 7 | import com.gearwenxin.entity.response.ChatResponse; 8 | import com.gearwenxin.model.ChatModel; 9 | import com.gearwenxin.schedule.entity.ChatTask; 10 | import com.gearwenxin.schedule.TaskQueueManager; 11 | import lombok.extern.slf4j.Slf4j; 12 | import reactor.core.publisher.Flux; 13 | import reactor.core.publisher.Mono; 14 | 15 | import java.util.Map; 16 | 17 | @Slf4j 18 | public class ChatClient implements ChatModel { 19 | 20 | private final ModelConfig modelConfig; 21 | 22 | private static final float defaultWeight = 0; 23 | 24 | public ChatClient(ModelConfig modelConfig) { 25 | this.modelConfig = modelConfig; 26 | } 27 | 28 | private static final TaskQueueManager taskQueueManager = TaskQueueManager.getInstance(); 29 | 30 | @Override 31 | public Mono chat(String content) { 32 | return chat(content, defaultWeight); 33 | } 34 | 35 | @Override 36 | public Mono chat(String content, float weight) { 37 | ChatErnieRequest request = new ChatErnieRequest(); 38 | request.setContent(content); 39 | return chat(request, weight); 40 | } 41 | 42 | @Override 43 | public Mono chat(T chatRequest) { 44 | return chat(chatRequest, defaultWeight); 45 | } 46 
| 47 | @Override 48 | public Mono chat(T chatRequest, float weight) { 49 | ChatTask chatTask = ChatTask.builder() 50 | .modelConfig(modelConfig) 51 | .taskType(ModelType.chat) 52 | .taskRequest(chatRequest) 53 | .taskWeight(weight) 54 | .stream(false) 55 | .build(); 56 | String taskId = taskQueueManager.addTask(chatTask); 57 | return Mono.from(taskQueueManager.getChatFuture(taskId).join()); 58 | } 59 | 60 | @Override 61 | public Flux chatStream(String content) { 62 | return chatStream(content, defaultWeight); 63 | } 64 | 65 | @Override 66 | public Flux chatStream(String content, float weight) { 67 | ChatErnieRequest request = new ChatErnieRequest(); 68 | request.setContent(content); 69 | return chatStream(request, weight); 70 | } 71 | 72 | @Override 73 | public Flux chatStream(T chatRequest) { 74 | return chatStream(chatRequest, defaultWeight); 75 | } 76 | 77 | @Override 78 | public Flux chatStream(T request, float weight) { 79 | ChatTask chatTask = ChatTask.builder() 80 | .modelConfig(modelConfig) 81 | .taskType(ModelType.chat) 82 | .taskRequest(request) 83 | .taskWeight(weight) 84 | .stream(true) 85 | .build(); 86 | String taskId = taskQueueManager.addTask(chatTask); 87 | return Flux.from(taskQueueManager.getChatFuture(taskId).join()); 88 | } 89 | 90 | @Override 91 | public Flux chatStream(Map request) { 92 | ChatTask chatTask = ChatTask.builder() 93 | .modelConfig(modelConfig) 94 | .taskType(ModelType.chat) 95 | .taskRequest(request) 96 | .taskWeight(defaultWeight) 97 | .stream(true) 98 | .jsonMode(true) 99 | .build(); 100 | String taskId = taskQueueManager.addTask(chatTask); 101 | return Flux.from(taskQueueManager.getChatFuture(taskId).join()); 102 | } 103 | 104 | @Override 105 | public Mono chats(String content, String msgUid) { 106 | return chats(content, msgUid, defaultWeight); 107 | } 108 | 109 | @Override 110 | public Mono chats(String content, String msgUid, float weight) { 111 | return chatsStream(content, msgUid, weight).next(); 112 | } 113 | 114 | 
@Override 115 | public Mono chats(T chatRequest, String msgUid) { 116 | return chats(chatRequest, msgUid, defaultWeight); 117 | } 118 | 119 | @Override 120 | public Mono chats(T chatRequest, String msgUid, float weight) { 121 | return chatsStream(chatRequest, msgUid, weight).next(); 122 | } 123 | 124 | @Override 125 | public Flux chatsStream(String content, String msgUid) { 126 | return chatsStream(content, msgUid, defaultWeight); 127 | } 128 | 129 | @Override 130 | public Flux chatsStream(String content, String msgUid, float weight) { 131 | ChatErnieRequest request = new ChatErnieRequest(); 132 | request.setContent(content); 133 | return chatsStream(request, msgUid, weight); 134 | } 135 | 136 | @Override 137 | public Flux chatsStream(T chatRequest, String msgUid) { 138 | return chatsStream(chatRequest, msgUid, defaultWeight); 139 | } 140 | 141 | @Override 142 | public Flux chatsStream(T request, String msgUid, float weight) { 143 | ChatTask chatTask = ChatTask.builder() 144 | .modelConfig(modelConfig) 145 | .taskType(ModelType.chat) 146 | .taskRequest(request) 147 | .messageId(msgUid) 148 | .taskWeight(weight) 149 | .stream(true) 150 | .build(); 151 | String taskId = taskQueueManager.addTask(chatTask); 152 | return Flux.from(taskQueueManager.getChatFuture(taskId).join()); 153 | } 154 | 155 | } 156 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/client/ImageClient.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.client; 2 | 3 | import com.gearwenxin.config.ModelConfig; 4 | import com.gearwenxin.entity.enums.ModelType; 5 | import com.gearwenxin.entity.request.ImageBaseRequest; 6 | import com.gearwenxin.entity.response.ImageResponse; 7 | import com.gearwenxin.model.ImageModel; 8 | import com.gearwenxin.schedule.TaskQueueManager; 9 | import com.gearwenxin.schedule.entity.ChatTask; 10 | import reactor.core.publisher.Mono; 11 | 12 | public 
class ImageClient implements ImageModel { 13 | 14 | private final ModelConfig modelConfig; 15 | 16 | private static final float defaultWeight = 0; 17 | 18 | TaskQueueManager taskQueueManager = TaskQueueManager.getInstance(); 19 | 20 | public ImageClient(ModelConfig modelConfig) { 21 | this.modelConfig = modelConfig; 22 | } 23 | 24 | @Override 25 | public Mono chatImage(ImageBaseRequest imageBaseRequest) { 26 | return chatImage(imageBaseRequest, defaultWeight); 27 | } 28 | 29 | @Override 30 | public Mono chatImage(ImageBaseRequest imageBaseRequest, float weight) { 31 | ChatTask chatTask = ChatTask.builder() 32 | .modelConfig(modelConfig) 33 | .taskType(ModelType.image) 34 | .taskRequest(imageBaseRequest) 35 | .taskWeight(weight) 36 | .build(); 37 | String taskId = taskQueueManager.addTask(chatTask); 38 | return Mono.from(taskQueueManager.getImageFuture(taskId).join()); 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/client/PromptClient.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.client; 2 | 3 | import com.gearwenxin.config.ModelConfig; 4 | import com.gearwenxin.entity.chatmodel.ChatPromptRequest; 5 | import com.gearwenxin.entity.enums.ModelType; 6 | import com.gearwenxin.entity.response.PromptResponse; 7 | import com.gearwenxin.model.PromptModel; 8 | import com.gearwenxin.schedule.TaskQueueManager; 9 | import com.gearwenxin.schedule.entity.ChatTask; 10 | import lombok.extern.slf4j.Slf4j; 11 | import reactor.core.publisher.Mono; 12 | 13 | @Slf4j 14 | public class PromptClient implements PromptModel { 15 | 16 | private final ModelConfig modelConfig; 17 | 18 | private static final float defaultWeight = 0; 19 | 20 | TaskQueueManager taskQueueManager = TaskQueueManager.getInstance(); 21 | 22 | public PromptClient(ModelConfig modelConfig) { 23 | this.modelConfig = modelConfig; 24 | } 25 | 26 | @Override 27 | public Mono 
chat(ChatPromptRequest chatRequest) { 28 | return chat(chatRequest, defaultWeight); 29 | } 30 | 31 | @Override 32 | public Mono chat(ChatPromptRequest chatRequest, float weight) { 33 | ChatTask chatTask = ChatTask.builder() 34 | .modelConfig(modelConfig) 35 | .taskType(ModelType.prompt) 36 | .taskRequest(chatRequest) 37 | .taskWeight(weight) 38 | .build(); 39 | String taskId = taskQueueManager.addTask(chatTask); 40 | return Mono.from(taskQueueManager.getPromptFuture(taskId).join()); 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/client/basic/BasicChatClient.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.client.basic; 2 | 3 | import com.gearwenxin.config.ModelConfig; 4 | import com.gearwenxin.entity.chatmodel.ChatBaseRequest; 5 | import com.gearwenxin.entity.chatmodel.ChatErnieRequest; 6 | import com.gearwenxin.entity.enums.ModelType; 7 | import com.gearwenxin.entity.response.ChatResponse; 8 | import com.gearwenxin.model.BasicChatModel; 9 | import com.gearwenxin.schedule.TaskQueueManager; 10 | import com.gearwenxin.schedule.entity.ChatTask; 11 | import lombok.extern.slf4j.Slf4j; 12 | import reactor.core.publisher.Flux; 13 | import reactor.core.publisher.Mono; 14 | 15 | @Slf4j 16 | public class BasicChatClient { 17 | 18 | private final ModelConfig modelConfig; 19 | 20 | private static final float defaultWeight = 0; 21 | 22 | public BasicChatClient(ModelConfig modelConfig) { 23 | this.modelConfig = modelConfig; 24 | } 25 | 26 | private static final TaskQueueManager taskQueueManager = TaskQueueManager.getInstance(); 27 | 28 | public Mono chat(String content) { 29 | return chat(content, defaultWeight); 30 | } 31 | 32 | public Mono chat(String content, float weight) { 33 | ChatErnieRequest request = new ChatErnieRequest(); 34 | request.setContent(content); 35 | return chat(request, weight); 36 | } 37 | 38 | public Mono 
chat(T chatRequest) { 39 | return chat(chatRequest, defaultWeight); 40 | } 41 | 42 | public Mono chat(T chatRequest, float weight) { 43 | ChatTask chatTask = ChatTask.builder() 44 | .modelConfig(modelConfig) 45 | .taskType(ModelType.chat) 46 | .taskRequest(chatRequest) 47 | .taskWeight(weight) 48 | .stream(false) 49 | .build(); 50 | String taskId = taskQueueManager.addTask(chatTask); 51 | return Mono.from(taskQueueManager.getChatFuture(taskId).join()); 52 | } 53 | 54 | public Flux chatStream(String content) { 55 | return chatStream(content, defaultWeight); 56 | } 57 | 58 | public Flux chatStream(String content, float weight) { 59 | ChatErnieRequest request = new ChatErnieRequest(); 60 | request.setContent(content); 61 | return chatStream(request, weight); 62 | } 63 | 64 | public Flux chatStream(T chatRequest) { 65 | return chatStream(chatRequest, defaultWeight); 66 | } 67 | 68 | public Flux chatStream(T request, float weight) { 69 | ChatTask chatTask = ChatTask.builder() 70 | .modelConfig(modelConfig) 71 | .taskType(ModelType.chat) 72 | .taskRequest(request) 73 | .taskWeight(weight) 74 | .stream(true) 75 | .build(); 76 | String taskId = taskQueueManager.addTask(chatTask); 77 | return Flux.from(taskQueueManager.getChatFuture(taskId).join()); 78 | } 79 | 80 | } 81 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/common/Constant.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.common; 2 | 3 | import java.util.Map; 4 | import java.util.concurrent.ConcurrentHashMap; 5 | 6 | /** 7 | * 模型URL 8 | * 9 | * @author Ge Mingjia 10 | * {@code @date} 2023/7/20 11 | */ 12 | public interface Constant { 13 | 14 | /** 15 | * 最大单条内容长度 16 | */ 17 | int MAX_CONTENT_LENGTH = 2000; 18 | 19 | /** 20 | * 最大所有内容总长度 21 | */ 22 | int MAX_TOTAL_LENGTH = 2000; 23 | 24 | /** 25 | * 最大system长度 26 | */ 27 | int MAX_SYSTEM_LENGTH = 1024; 28 | 29 | // 中断标志 30 | Map 
INTERRUPT_MAP = new ConcurrentHashMap<>(); 31 | 32 | boolean BASIC_MODE = false; 33 | 34 | String CHECK = "check"; 35 | 36 | String GET_ACCESS_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s"; 37 | String PROMPT_URL = "https://aip.baidubce.com/rest/2.0/wenxinworkshop/api/v1/template/info"; 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/common/ConvertUtils.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.common; 2 | 3 | import com.gearwenxin.entity.BaseRequest; 4 | import com.gearwenxin.entity.chatmodel.*; 5 | import com.gearwenxin.entity.request.ErnieRequest; 6 | import com.gearwenxin.entity.request.PromptRequest; 7 | 8 | import static com.gearwenxin.common.WenXinUtils.assertNotBlank; 9 | import static com.gearwenxin.common.WenXinUtils.assertNotNull; 10 | 11 | /** 12 | * 类型转换工具类 13 | * 14 | * @author Ge Mingjia 15 | * {@code @date} 2023/5/27 16 | */ 17 | public class ConvertUtils { 18 | 19 | public static ErnieRequest.ErnieRequestBuilder toErnieRequest(ChatErnieRequest chatRequest) { 20 | 21 | assertNotNull(chatRequest, "ChatErnieRequest is null"); 22 | assertNotBlank(chatRequest.getContent(), "ChatErnieRequest.content is null"); 23 | 24 | return ErnieRequest.builder() 25 | .userId(chatRequest.getUserId()) 26 | .messages(WenXinUtils.buildUserMessageHistory(chatRequest.getContent(), chatRequest.getName(), chatRequest.getFunctionCall())) 27 | .temperature(chatRequest.getTemperature()) 28 | .topP(chatRequest.getTopP()) 29 | .penaltyScore(chatRequest.getPenaltyScore()) 30 | .functions(chatRequest.getFunctions()) 31 | .system(chatRequest.getSystem()) 32 | .stop(chatRequest.getStop()) 33 | .disableSearch(chatRequest.getDisableSearch()) 34 | .enableCitation(chatRequest.getEnableCitation()); 35 | } 36 | 37 | public static BaseRequest.BaseRequestBuilder 
toBaseRequest(ChatBaseRequest chatRequest) { 38 | 39 | assertNotNull(chatRequest, "ChatBaseRequest is null"); 40 | assertNotBlank(chatRequest.getContent(), "ChatBaseRequest.content is null"); 41 | 42 | return BaseRequest.builder() 43 | .userId(chatRequest.getUserId()) 44 | .messages(WenXinUtils.buildUserMessageHistory(chatRequest.getContent())); 45 | } 46 | 47 | public static PromptRequest toPromptRequest(ChatPromptRequest chatRequest) { 48 | 49 | assertNotNull(chatRequest, "ChatBaseRequest is null"); 50 | assertNotNull(chatRequest.getParamMap(), "ChatPromptRequest.ParamMap is null"); 51 | 52 | return PromptRequest.builder() 53 | .id(String.valueOf(chatRequest.getId())) 54 | .paramMap(chatRequest.getParamMap()) 55 | .build(); 56 | } 57 | 58 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/common/ErrorCode.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.common; 2 | 3 | import lombok.Getter; 4 | 5 | /** 6 | * 错误码 7 | * 8 | * @author Ge Mingjia 9 | * {@code @date} 2023/7/22 10 | */ 11 | @Getter 12 | public enum ErrorCode { 13 | 14 | PARAMS_ERROR(40000, "参数错误"), 15 | REQUEST_TYPE_ERROR(40001, "不受支持的请求类"), 16 | NO_AUTH_ERROR(40101, "无权限"), 17 | SYSTEM_ERROR(50000, "系统内部异常"), 18 | OPERATION_ERROR(50001, "操作失败"), 19 | SYSTEM_NET_ERROR(50002, "系统网络异常"), 20 | WENXIN_ERROR(1, "响应异常"), 21 | SYSTEM_INPUT_ERROR(336104, "'用户输入错误' system内容不合法"), 22 | EVENT_LOOP_ERROR(50003, "事件循环异常"), 23 | CONSUMER_THREAD_START_FAILED(50004, "消费者线程启动失败"), 24 | ; 25 | 26 | /** 27 | * 状态码 28 | */ 29 | private final int code; 30 | 31 | /** 32 | * 信息 33 | */ 34 | private final String message; 35 | 36 | ErrorCode(int code, String message) { 37 | this.code = code; 38 | this.message = message; 39 | } 40 | 41 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/common/FileUtils.java: 
-------------------------------------------------------------------------------- 1 | package com.gearwenxin.common; 2 | 3 | import java.io.File; 4 | 5 | public class FileUtils { 6 | 7 | public static String getRootDir() { 8 | return System.getProperty("user.dir") + File.separator; 9 | } 10 | 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/common/RuntimeToolkit.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.common; 2 | 3 | import lombok.extern.slf4j.Slf4j; 4 | 5 | @Slf4j 6 | public class RuntimeToolkit { 7 | 8 | public static void ifOrElse(boolean condition, Runnable ifRunnable, Runnable elseRunnable) { 9 | if (condition) { 10 | ifRunnable.run(); 11 | } else { 12 | elseRunnable.run(); 13 | } 14 | } 15 | 16 | public static void threadWait(Thread thread) { 17 | try { 18 | thread.wait(); 19 | } catch (InterruptedException e) { 20 | log.error("[{}] wait error", thread.getName(), e); 21 | } 22 | } 23 | 24 | public static void threadNotify(Thread thread) { 25 | try { 26 | thread.notify(); 27 | } catch (Exception e) { 28 | log.error("[{}] notify error", thread.getName(), e); 29 | } 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/common/StatusConst.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.common; 2 | 3 | public class StatusConst { 4 | 5 | public static boolean SERVICE_STARTED = false; 6 | public static boolean JSON_MODE = false; 7 | 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/common/WenXinUtils.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.common; 2 | 3 | import com.gearwenxin.entity.FunctionCall; 4 | import 
com.gearwenxin.entity.enums.Role; 5 | import com.gearwenxin.exception.WenXinException; 6 | import com.gearwenxin.entity.Message; 7 | import org.apache.commons.lang3.StringUtils; 8 | import reactor.core.publisher.Mono; 9 | 10 | import java.util.Deque; 11 | import java.util.concurrent.ConcurrentLinkedDeque; 12 | 13 | /** 14 | * @author Ge Mingjia 15 | * {@code @date} 2023/7/23 16 | */ 17 | public class WenXinUtils { 18 | 19 | public static Deque buildUserMessageHistory(String content) { 20 | return buildUserMessageHistory(content, null, null); 21 | } 22 | 23 | /* TODO: adapt globally. Returns a new ConcurrentLinkedDeque holding a single user message built from the arguments. */ 24 | public static Deque buildUserMessageHistory(String content, String name, FunctionCall functionCall) { 25 | assertNotNull(content, "content is null"); 26 | 27 | Deque messageHistory = new ConcurrentLinkedDeque<>(); 28 | Message message = buildUserMessage(content, name, functionCall); 29 | messageHistory.offer(message); 30 | return messageHistory; 31 | } 32 | 33 | public static Message buildUserMessage(String content) { 34 | return buildUserMessage(content, null, null); 35 | } 36 | 37 | /* TODO: adapt globally. Builds a user-role Message; name is dropped when no functionCall is given and required (non-null) when one is. */ 38 | public static Message buildUserMessage(String content, String name, FunctionCall functionCall) { 39 | assertNotNull(content, "content is null"); 40 | if (functionCall == null) { 41 | return new Message(Role.user, content, null, null); 42 | } 43 | assertNotNull(name, "name is null"); 44 | 45 | return new Message(Role.user, content, name, functionCall); 46 | } 47 | 48 | /* TODO: adapt globally. Builds an assistant-role Message; only content is checked for null. */ 49 | public static Message buildAssistantMessage(String content, String name, FunctionCall functionCall) { 50 | assertNotNull(content, "content is null"); 51 | return new Message(Role.assistant, content, name, functionCall); 52 | } 53 | 54 | public static Message buildFunctionMessage(String name, String content) { 55 | assertNotNull(name, "name is null"); 56 | assertNotNull(content, "content is null"); 57 | 58 | return new Message(Role.function, content, name, null); 59 | } 60 | 61 | public static Message 
buildAssistantMessage(String content) { 62 | return buildAssistantMessage(content, null, null); 63 | } 64 | 65 | public static void assertNotBlank(String str, String message) { 66 | if (StringUtils.isBlank(str)) { 67 | throw new WenXinException(ErrorCode.PARAMS_ERROR, message); 68 | } 69 | } 70 | 71 | public static void assertNotBlankMono(String str, String message) { 72 | if (StringUtils.isBlank(str)) { 73 | throw new WenXinException(ErrorCode.PARAMS_ERROR, message); 74 | } 75 | } 76 | 77 | public static void assertNotBlank(String message, String... strings) { 78 | for (String str : strings) { 79 | if (StringUtils.isBlank(str)) { 80 | throw new WenXinException(ErrorCode.PARAMS_ERROR, message); 81 | } 82 | } 83 | } 84 | 85 | public static void assertNotNull(Object obj, String message) { 86 | if (obj == null) { 87 | throw new WenXinException(ErrorCode.PARAMS_ERROR, message); 88 | } 89 | } 90 | 91 | public static Mono assertNotNullMono(ErrorCode errorCode, String message, Object... obj) { 92 | for (Object o : obj) { 93 | if (o == null) { 94 | return Mono.error(() -> new WenXinException(errorCode, message)); 95 | } 96 | if (o instanceof String) { 97 | if (StringUtils.isBlank((String) o)) { 98 | return Mono.error(() -> new WenXinException(errorCode, message)); 99 | } 100 | } 101 | } 102 | return Mono.empty(); 103 | } 104 | 105 | public static Mono assertNotBlankMono(String message, String... 
strings) { 106 | for (String str : strings) { 107 | if (StringUtils.isBlank(str)) { 108 | return Mono.error(() -> new WenXinException(ErrorCode.PARAMS_ERROR, message)); 109 | } 110 | } 111 | return Mono.empty(); 112 | } 113 | 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/config/GearWenXinConfig.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.config; 2 | 3 | import com.gearwenxin.core.RequestManager; 4 | import com.gearwenxin.entity.Message; 5 | import com.gearwenxin.entity.response.TokenResponse; 6 | import com.gearwenxin.service.*; 7 | import jakarta.annotation.Resource; 8 | import lombok.extern.slf4j.Slf4j; 9 | import org.springframework.boot.CommandLineRunner; 10 | import org.springframework.boot.autoconfigure.AutoConfiguration; 11 | import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; 12 | import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; 13 | import org.springframework.boot.context.properties.EnableConfigurationProperties; 14 | import org.springframework.context.annotation.Bean; 15 | import org.springframework.context.annotation.ComponentScan; 16 | import org.springframework.core.annotation.Order; 17 | 18 | import java.util.*; 19 | 20 | /** 21 | * @author Ge Mingjia 22 | */ 23 | @Slf4j 24 | @Order(1) 25 | @AutoConfiguration 26 | @ComponentScan(basePackages = {"com.gearwenxin"}) 27 | @ConditionalOnClass(MessageService.class) 28 | @EnableConfigurationProperties(value = {WenXinProperties.class}) 29 | public class GearWenXinConfig implements CommandLineRunner { 30 | 31 | @Resource 32 | private WenXinProperties wenXinProperties; 33 | 34 | @Override 35 | public void run(String... 
args) { 36 | String apiKey = wenXinProperties.getApiKey(); 37 | String secretKey = wenXinProperties.getSecretKey(); 38 | String accessToken = wenXinProperties.getAccessToken(); 39 | 40 | if (apiKey == null && secretKey == null) { 41 | return; 42 | } 43 | if (accessToken != null) { 44 | log.info("[global] access-token: {}", accessToken); 45 | return; 46 | } 47 | try { 48 | RequestManager.getAccessTokenByAKSK(apiKey, secretKey).doOnNext(tokenResponse -> { 49 | if (tokenResponse != null) { 50 | Optional.ofNullable(tokenResponse.getAccessToken()).ifPresentOrElse(token -> { 51 | wenXinProperties.setAccessToken(token); 52 | log.info("[global] access-token: {}", token); 53 | }, () -> log.error(""" 54 | api-key or secret-key error! 55 | error_description: {} 56 | error: {} 57 | """, tokenResponse.getErrorDescription(), tokenResponse.getError())); 58 | } 59 | }).map(TokenResponse::getAccessToken).block(); 60 | } catch (Exception e) { 61 | log.error("get access-token error, {}", e.getMessage()); 62 | } 63 | } 64 | 65 | @Bean 66 | @ConditionalOnMissingBean 67 | public MessageService defaultMessageService() { 68 | log.warn("[default] message service, It is recommended to provide developer autonomy"); 69 | return new MessageService() { 70 | @Override 71 | public Deque getHistoryMessages(String id) { 72 | // Map> messageMap = messageHistoryManager.getChatMessageHistoryMap(); 73 | // return messageMap.computeIfAbsent(id, k -> new ArrayDeque<>()); 74 | return null; 75 | } 76 | 77 | @Override 78 | public void addHistoryMessage(String id, Message message) { 79 | // Deque historyMessages = getHistoryMessages(id); 80 | // MessageHistoryManager.addMessage(historyMessages, message); 81 | } 82 | 83 | @Override 84 | public void addHistoryMessage(Deque historyMessages, Message message) { 85 | // MessageHistoryManager.addMessage(historyMessages, message); 86 | } 87 | }; 88 | } 89 | } 90 | -------------------------------------------------------------------------------- 
/src/main/java/com/gearwenxin/config/ModelConfig.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.config; 2 | 3 | import com.gearwenxin.schedule.entity.ModelHeader; 4 | import lombok.AllArgsConstructor; 5 | import lombok.Builder; 6 | import lombok.Data; 7 | import lombok.NoArgsConstructor; 8 | 9 | import java.io.Serializable; 10 | 11 | @Data 12 | @Builder 13 | @NoArgsConstructor 14 | @AllArgsConstructor 15 | public class ModelConfig implements Serializable { 16 | 17 | /** 18 | * Task id; callers do not need to set it, it is used internally by the SDK 19 | */ 20 | private String taskId; 21 | 22 | private String modelName; 23 | 24 | private String modelUrl; 25 | 26 | private String accessToken; 27 | /* Without @Builder.Default the Lombok builder ignores the field initializer and leaves this null. */ 28 | @Builder.Default private Integer contentMaxLength = 8000; 29 | 30 | private ModelHeader modelHeader; 31 | 32 | private boolean enableStringResponse; 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/config/WenXinProperties.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.config; 2 | 3 | import com.gearwenxin.common.StatusConst; 4 | import lombok.Getter; 5 | import lombok.Setter; 6 | import org.springframework.boot.context.properties.ConfigurationProperties; 7 | import org.springframework.stereotype.Component; 8 | 9 | import java.util.List; 10 | 11 | /** 12 | * @author Ge Mingjia 13 | * {@code @date} 2023/11/1 14 | */ 15 | @Component 16 | @ConfigurationProperties("gear.wenxin") 17 | public class WenXinProperties { 18 | 19 | @Getter 20 | @Setter 21 | private String accessToken; 22 | 23 | @Getter 24 | @Setter 25 | private String apiKey; 26 | 27 | @Getter 28 | @Setter 29 | private String secretKey; 30 | 31 | @Getter 32 | @Setter 33 | private List model_qps; 34 | 35 | @Getter 36 | @Setter 37 | private Integer saveScheduledTime; 38 | 39 | private boolean basicMode; 40 | private boolean jsonMode; 41 | 42 | public List getModelQPSList() { 43 
| return model_qps; 44 | } 45 | 46 | public void setModelQPSList(List model_qps) { 47 | this.model_qps = model_qps; 48 | } 49 | /* The basicMode/jsonMode accessors are public so that configuration property binding, which skips private methods, can reach them. */ 50 | public boolean getBasicMode() { 51 | return basicMode; 52 | } 53 | 54 | public void setBasicMode(boolean basicMode) { 55 | this.basicMode = basicMode; 56 | } 57 | 58 | public boolean getJsonMode() { 59 | return jsonMode; 60 | } 61 | /* Also mirrors the flag into the global StatusConst.JSON_MODE. */ 62 | public void setJsonMode(boolean jsonMode) { 63 | this.jsonMode = jsonMode; 64 | StatusConst.JSON_MODE = jsonMode; 65 | } 66 | 67 | } 68 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/core/AuthEncryption.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.core; 2 | 3 | import org.apache.commons.codec.EncoderException; 4 | import org.apache.http.client.utils.URIBuilder; 5 | import org.apache.commons.codec.digest.HmacUtils; 6 | import org.apache.commons.codec.net.URLCodec; 7 | import org.apache.http.NameValuePair; 8 | import org.apache.http.client.utils.URLEncodedUtils; 9 | 10 | import java.net.URI; 11 | import java.net.URISyntaxException; 12 | import java.nio.charset.Charset; 13 | import java.time.Instant; 14 | import java.util.*; 15 | 16 | public class AuthEncryption { 17 | 18 | private static final String[] DEFAULT_HEADERS = {"host", "content-length", "content-type", "content-md5"}; 19 | private static final String ACCESS_KEY = ""; 20 | private static final String SECRET_KEY = ""; 21 | private static final String AUTH_VERSION = "1"; 22 | private static final String EXPIRATION_IN_SECONDS = "1800"; 23 | private static String g_signed_headers = ""; 24 | 25 | public static void main(String[] args) throws URISyntaxException { 26 | URI uri = new URIBuilder() 27 | .setScheme("http") 28 | .setHost("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions") 29 | .setPath("/path") 30 | .setParameter("AccessKey", ACCESS_KEY) 31 | .setParameter("SecretKey", 
SECRET_KEY) 32 | .build(); 33 | 34 | System.out.println(generateAuthorization(uri)); 35 | } 36 | 37 | private static String getTimestamp() { 38 | return Instant.now().toString().replace("Z", "+0000"); 39 | } 40 | 41 | private static String normalize(String string, boolean encodingSlash) { 42 | if (string == null) { 43 | return ""; 44 | } 45 | String result = null; 46 | try { 47 | result = new URLCodec().encode(string); 48 | } catch (EncoderException e) { 49 | throw new RuntimeException(e); 50 | } 51 | result = result.replace("+", "%20"); 52 | result = result.replace("*", "%2A"); 53 | result = result.replace("%7E", "~"); 54 | if (!encodingSlash) { 55 | result = result.replace("%2F", "/"); 56 | } 57 | return result; 58 | } 59 | 60 | private static String generateCanonicalUri(URI uri) { 61 | if (uri.getPath() == null) { 62 | return ""; 63 | } 64 | return normalize(uri.getPath(), true); 65 | } 66 | 67 | private static String generateCanonicalQueryString(URI uri) { 68 | List params = URLEncodedUtils.parse(uri, Charset.defaultCharset()); 69 | params.removeIf(param -> param.getName().equalsIgnoreCase("authorization")); 70 | params.sort(Comparator.comparing(NameValuePair::getName)); 71 | StringBuilder sb = new StringBuilder(); 72 | for (NameValuePair param : params) { 73 | sb.append(normalize(param.getName(), false)).append("=").append(normalize(param.getValue(), false)).append("&"); 74 | } 75 | if (sb.length() > 0) { 76 | sb.deleteCharAt(sb.length() - 1); 77 | } 78 | return sb.toString(); 79 | } 80 | 81 | private static String generateCanonicalHeaders(URI uri) { 82 | Map headers = new HashMap<>(); 83 | headers.put("host", uri.getHost()); 84 | headers.put("content-length", "0"); 85 | headers.put("content-type", "application/json"); 86 | headers.put("content-md5", ""); 87 | 88 | List keyStrList = new ArrayList<>(Arrays.asList(DEFAULT_HEADERS)); 89 | List usedHeaderStrList = new ArrayList<>(); 90 | for (String key : keyStrList) { 91 | String value = headers.get(key); 92 | if 
(value == null || value.isEmpty()) { 93 | continue; 94 | } 95 | usedHeaderStrList.add(normalize(key, false) + ":" + normalize(value, false)); 96 | } 97 | Collections.sort(usedHeaderStrList); 98 | List usedHeaderKeys = new ArrayList<>(); 99 | for (String item : usedHeaderStrList) { 100 | usedHeaderKeys.add(item.split(":")[0]); 101 | } 102 | g_signed_headers = String.join(";", usedHeaderKeys); 103 | return String.join("\n", usedHeaderStrList); 104 | } 105 | 106 | private static String generateAuthorization(URI uri) { 107 | String timestamp = getTimestamp(); 108 | String signingKeyStr = "bce-auth-v" + AUTH_VERSION + "/" + ACCESS_KEY + "/" + timestamp + "/" + EXPIRATION_IN_SECONDS; 109 | String signingKey = HmacUtils.hmacSha256Hex(SECRET_KEY, signingKeyStr); 110 | 111 | String canonicalUri = generateCanonicalUri(uri); 112 | String canonicalQueryString = generateCanonicalQueryString(uri); 113 | String canonicalHeaders = generateCanonicalHeaders(uri); 114 | 115 | String canonicalRequest = "GET" + "\n" + canonicalUri + "\n" + canonicalQueryString + "\n" + canonicalHeaders; 116 | String signature = HmacUtils.hmacSha256Hex(signingKey, canonicalRequest); 117 | 118 | return signingKeyStr + "/" + g_signed_headers + "/" + signature; 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/core/ConsumerService.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.core; 2 | 3 | import com.gearwenxin.config.WenXinProperties; 4 | import com.gearwenxin.schedule.TaskConsumerLoop; 5 | import jakarta.annotation.Resource; 6 | import lombok.extern.slf4j.Slf4j; 7 | import org.springframework.boot.CommandLineRunner; 8 | import org.springframework.core.annotation.Order; 9 | import org.springframework.stereotype.Component; 10 | 11 | import java.util.List; 12 | 13 | @Slf4j 14 | @Order(2) 15 | @Component 16 | public class ConsumerService implements 
CommandLineRunner { 17 | 18 | @Resource 19 | private TaskConsumerLoop taskConsumerLoop; 20 | 21 | @Resource 22 | private WenXinProperties wenXinProperties; 23 | 24 | @Override 25 | public void run(String... args) { 26 | // TODO: 曲线救国,初始化modelQPSList 27 | List modelQPSList = wenXinProperties.getModelQPSList(); 28 | taskConsumerLoop.setQpsList(modelQPSList); 29 | 30 | log.info("EventLoop start"); 31 | taskConsumerLoop.start(); 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/core/ConsumerThreadMonitor.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.core; 2 | 3 | import com.gearwenxin.common.Constant; 4 | import com.gearwenxin.common.ErrorCode; 5 | import com.gearwenxin.common.StatusConst; 6 | import com.gearwenxin.common.RuntimeToolkit; 7 | import com.gearwenxin.config.ModelConfig; 8 | import com.gearwenxin.entity.enums.ModelType; 9 | import com.gearwenxin.exception.WenXinException; 10 | import com.gearwenxin.schedule.TaskConsumerLoop; 11 | import com.gearwenxin.schedule.TaskQueueManager; 12 | import com.gearwenxin.schedule.entity.ChatTask; 13 | import jakarta.annotation.Resource; 14 | import lombok.extern.slf4j.Slf4j; 15 | import org.springframework.boot.CommandLineRunner; 16 | import org.springframework.core.annotation.Order; 17 | import org.springframework.stereotype.Component; 18 | 19 | import java.util.concurrent.CountDownLatch; 20 | 21 | @Slf4j 22 | @Order(3) 23 | @Component 24 | public class ConsumerThreadMonitor implements CommandLineRunner { 25 | 26 | private static final TaskQueueManager taskQueueManager = TaskQueueManager.getInstance(); 27 | 28 | @Resource 29 | private TaskConsumerLoop taskConsumerLoop; 30 | 31 | @Override 32 | public void run(String... 
args) { 33 | ChatTask checkTask = ChatTask.builder() 34 | .taskType(ModelType.check) 35 | .modelConfig(ModelConfig.builder().modelName(Constant.CHECK).build()) 36 | .build(); /* Register the latch before enqueuing the check task so the consumer loop cannot pick the task up while no latch is set (presumably the loop counts the latch down when it handles the check task - confirm in TaskConsumerLoop). */ 37 | CountDownLatch countDownLatch = new CountDownLatch(1); 38 | taskConsumerLoop.setTestCountDownLatch(countDownLatch); 39 | taskQueueManager.addTask(checkTask); 40 | try { 41 | log.info("Waiting for consumer thread to start..."); 42 | countDownLatch.await(); 43 | } catch (InterruptedException e) { /* restore the interrupt flag instead of swallowing it; the started-check below still decides the outcome */ Thread.currentThread().interrupt(); 44 | } 45 | RuntimeToolkit.ifOrElse(StatusConst.SERVICE_STARTED, 46 | () -> log.info("Consumer thread started."), 47 | () -> { 48 | throw new WenXinException(ErrorCode.CONSUMER_THREAD_START_FAILED); 49 | }); 50 | } 51 | 52 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/core/MessageHistoryManager.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.core; 2 | 3 | import com.gearwenxin.entity.Message; 4 | import com.gearwenxin.entity.enums.Role; 5 | import org.slf4j.Logger; 6 | import org.slf4j.LoggerFactory; 7 | 8 | import java.util.Deque; 9 | import java.util.LinkedList; 10 | import java.util.Map; 11 | import java.util.concurrent.ConcurrentHashMap; 12 | 13 | import static com.gearwenxin.common.Constant.MAX_TOTAL_LENGTH; 14 | import static com.gearwenxin.common.WenXinUtils.assertNotBlank; 15 | import static com.gearwenxin.common.WenXinUtils.assertNotNull; 16 | 17 | public class MessageHistoryManager { 18 | 19 | private static final Logger log = LoggerFactory.getLogger(MessageHistoryManager.class); 20 | 21 | private MessageHistoryManager() { 22 | } 23 | 24 | private static final Integer DEFAULT_MESSAGE_MAP_KRY_SIZE = 1024; 25 | 26 | private volatile int CURRENT_MAP = 0; 27 | 28 | /** 29 | * Message history records; switched back and forth when resizing 30 | */ 31 | private static final Map> chatMessageHistoryMap0; 32 | 33 | 34 | private static volatile MessageHistoryManager messageHistoryManager; 35 | 36 | static 
{ 37 | chatMessageHistoryMap0 = new ConcurrentHashMap<>(DEFAULT_MESSAGE_MAP_KRY_SIZE); 38 | } 39 | 40 | public static MessageHistoryManager getInstance() { 41 | if (messageHistoryManager == null) { 42 | synchronized (MessageHistoryManager.class) { 43 | if (messageHistoryManager == null) { 44 | messageHistoryManager = new MessageHistoryManager(); 45 | } 46 | } 47 | } 48 | return messageHistoryManager; 49 | } 50 | 51 | public Map> getChatMessageHistoryMap() { 52 | return chatMessageHistoryMap0; 53 | } 54 | 55 | public synchronized void setChatMessageHistoryMap(Map> map) { 56 | if (CURRENT_MAP == 0) { 57 | chatMessageHistoryMap0.clear(); 58 | chatMessageHistoryMap0.putAll(map); 59 | } 60 | } 61 | 62 | public Deque getMessageHistory(String msgUid) { 63 | return getChatMessageHistoryMap().get(msgUid); 64 | } 65 | 66 | /** 67 | * 向历史消息中添加消息 68 | * 69 | * @param originalHistory 历史消息队列 70 | * @param message 需添加的Message 71 | */ 72 | public static void addMessage(Deque originalHistory, Message message) { 73 | assertNotNull(originalHistory, "messagesHistory is null"); 74 | assertNotNull(message, "message is null"); 75 | assertNotBlank(message.getContent(), "message.content is null or blank"); 76 | 77 | // 复制原始历史记录,避免直接修改原始历史记录 78 | Deque updatedHistory = new LinkedList<>(originalHistory); 79 | 80 | // 验证消息规则 81 | validateMessageRule(updatedHistory, message); 82 | 83 | // 将新消息添加到历史记录中 84 | updatedHistory.offer(message); 85 | 86 | if (message.getRole() == Role.assistant) { 87 | syncHistories(originalHistory, updatedHistory); 88 | return; 89 | } 90 | 91 | // 处理超出长度的情况 92 | handleExceedingLength(updatedHistory); 93 | 94 | // 同步历史记录 95 | syncHistories(originalHistory, updatedHistory); 96 | } 97 | 98 | public static void validateMessageRule(Deque history, Message message) { 99 | if (!history.isEmpty()) { 100 | Message lastMessage = history.peekLast(); 101 | if (lastMessage != null) { 102 | // 如果当前是奇数位message,要求role值为user或function 103 | if (history.size() % 2 != 0) { 104 | if 
(message.getRole() != Role.user && message.getRole() != Role.function) { 105 | // 删除最后一条消息 106 | Message polledMessage = history.pollLast(); 107 | log.debug("remove message: {}. Odd Position role is not user or function", polledMessage); 108 | validateMessageRule(history, message); 109 | } 110 | } else { 111 | // 如果当前是偶数位message,要求role值为assistant 112 | if (message.getRole() != Role.assistant) { 113 | // 删除最后一条消息 114 | Message polledMessage = history.pollLast(); 115 | log.debug("remove message: {}. Even position role is not assistant", polledMessage); 116 | validateMessageRule(history, message); 117 | } 118 | } 119 | // 第一个message的role不能是function 120 | if (history.size() == 1 && message.getRole() == Role.function) { 121 | // 删除最后一条消息 122 | Message polledMessage = history.pollLast(); 123 | log.debug("remove message: {}. first role is function", polledMessage); 124 | validateMessageRule(history, message); 125 | } 126 | 127 | // 移除连续的相同role的user messages 128 | if (lastMessage.getRole() == Role.user && message.getRole() == Role.user) { 129 | Message polledMessage = history.pollLast(); 130 | log.debug("remove message: {}. 
Same role message", polledMessage); 131 | validateMessageRule(history, message); 132 | } 133 | } 134 | } 135 | } 136 | 137 | public static void validateMessageRule(Deque history) { 138 | if (history != null && !history.isEmpty()) { 139 | Message message = history.pollLast(); 140 | validateMessageRule(history, message); 141 | } 142 | } 143 | 144 | private static void syncHistories(Deque original, Deque updated) { 145 | // if (updated.size() <= original.size()) { 146 | original.clear(); 147 | original.addAll(updated); 148 | // } 149 | } 150 | 151 | private static void handleExceedingLength(Deque updatedHistory) { 152 | int totalLength = updatedHistory.stream() 153 | .filter(msg -> msg.getRole() == Role.user) 154 | .mapToInt(msg -> msg.getContent().length()) 155 | .sum(); 156 | 157 | while (totalLength > MAX_TOTAL_LENGTH && updatedHistory.size() > 2) { 158 | Message firstMessage = updatedHistory.poll(); 159 | Message secondMessage = updatedHistory.poll(); 160 | if (firstMessage != null && secondMessage != null) { 161 | totalLength -= (firstMessage.getContent().length() + secondMessage.getContent().length()); 162 | } else if (secondMessage != null) { 163 | updatedHistory.addFirst(secondMessage); 164 | totalLength -= secondMessage.getContent().length(); 165 | } 166 | } 167 | } 168 | 169 | } 170 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/core/RequestManager.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.core; 2 | 3 | import com.gearwenxin.common.ErrorCode; 4 | import com.gearwenxin.common.WenXinUtils; 5 | import com.gearwenxin.config.ModelConfig; 6 | import com.gearwenxin.entity.Message; 7 | import com.gearwenxin.entity.response.ChatResponse; 8 | import com.gearwenxin.entity.response.TokenResponse; 9 | import com.gearwenxin.exception.WenXinException; 10 | import com.gearwenxin.schedule.TaskQueueManager; 11 | import 
com.gearwenxin.schedule.entity.ModelHeader; 12 | import com.gearwenxin.subscriber.CommonSubscriber; 13 | import lombok.extern.slf4j.Slf4j; 14 | import org.springframework.http.HttpHeaders; 15 | import org.springframework.http.MediaType; 16 | import org.springframework.web.reactive.function.BodyInserters; 17 | import org.springframework.web.reactive.function.client.*; 18 | import reactor.core.publisher.Flux; 19 | import reactor.core.publisher.Mono; 20 | 21 | import java.net.URLEncoder; 22 | import java.nio.charset.StandardCharsets; 23 | import java.util.Map; 24 | import java.util.Deque; 25 | import java.util.Optional; 26 | import java.util.function.Consumer; 27 | import java.util.stream.Collectors; 28 | 29 | import static com.gearwenxin.common.Constant.GET_ACCESS_TOKEN_URL; 30 | import static com.gearwenxin.common.WenXinUtils.*; 31 | import static com.gearwenxin.core.MessageHistoryManager.validateMessageRule; 32 | 33 | /** 34 | * @author Ge Mingjia 35 | * {@code @date} 2023/7/21 36 | */ 37 | @Slf4j 38 | public class RequestManager { 39 | 40 | private final TaskQueueManager taskManager = TaskQueueManager.getInstance(); 41 | 42 | private static final MessageHistoryManager messageHistoryManager = MessageHistoryManager.getInstance(); 43 | private static final String ACCESS_TOKEN_PRE = "?access_token="; 44 | 45 | private static WebClient createWebClient(String baseUrl, ModelHeader header) { 46 | WebClient.Builder builder = WebClient.builder() 47 | .baseUrl(baseUrl) 48 | .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); 49 | 50 | if (header != null) { 51 | Optional.ofNullable(header.get_X_Ratelimit_Limit_Requests()) 52 | .ifPresent(value -> 53 | builder.defaultHeader("X-Ratelimit-Limit-Requests", String.valueOf(value))); 54 | Optional.ofNullable(header.get_X_Ratelimit_Limit_Tokens()) 55 | .ifPresent(value -> 56 | builder.defaultHeader("X-Ratelimit-Limit-Tokens", String.valueOf(value))); 57 | 
Optional.ofNullable(header.get_X_Ratelimit_Remaining_Requests()) 58 | .ifPresent(value -> 59 | builder.defaultHeader("X-Ratelimit-Remaining-Requests", String.valueOf(value))); 60 | Optional.ofNullable(header.get_X_Ratelimit_Remaining_Tokens()) 61 | .ifPresent(value -> 62 | builder.defaultHeader("X-Ratelimit-Remaining-Tokens", String.valueOf(value))); 63 | Optional.ofNullable(header.getAuthorization()) 64 | .ifPresent(value -> 65 | builder.defaultHeader(HttpHeaders.AUTHORIZATION, value)); 66 | } 67 | 68 | return builder.build(); 69 | } 70 | 71 | public Mono monoPost(ModelConfig config, String accessToken, Object request, Class type) { 72 | return monoPost(config, accessToken, request, type, null); 73 | } 74 | 75 | public Mono monoPost(ModelConfig config, String accessToken, Object request, Class type, 76 | String messageUid) { 77 | validateRequestParams(config.getModelUrl(), accessToken, request, type); 78 | 79 | String completeUrl = buildCompleteUrl(config, accessToken); 80 | return createWebClient(completeUrl, config.getModelHeader()) 81 | .post() 82 | .body(BodyInserters.fromValue(request)) 83 | .retrieve() 84 | .bodyToMono(type) 85 | .doOnSuccess(response -> handleSuccess(response, messageUid, config)) 86 | .doOnError(WebClientResponseException.class, handleWebClientError()); 87 | } 88 | 89 | public Flux fluxPost(ModelConfig config, String accessToken, Object request, Class type) { 90 | return fluxPost(config, accessToken, request, type, null); 91 | } 92 | 93 | public Flux fluxPost(ModelConfig config, String accessToken, Object request, Class type, 94 | String messageUid) { 95 | validateRequestParams(config.getModelUrl(), accessToken, request, type); 96 | 97 | String completeUrl = buildCompleteUrl(config, accessToken); 98 | return createWebClient(completeUrl, config.getModelHeader()) 99 | .post() 100 | .body(BodyInserters.fromValue(request)) 101 | .accept(MediaType.TEXT_EVENT_STREAM) 102 | .retrieve() 103 | .bodyToFlux(type) 104 | .doOnNext(response -> 
handleStreamingResponse(response, messageUid)) 105 | .doOnError(WebClientResponseException.class, handleWebClientError()) 106 | .doOnComplete(() -> taskManager.downModelCurrentQPS(config.getModelName())); 107 | } 108 | 109 | public Mono monoGet(ModelConfig config, String accessToken, Map paramsMap, 110 | Class type) { 111 | validateRequestParams(config.getModelUrl(), accessToken, paramsMap, type); 112 | 113 | if (!isAuthorization(config)) { 114 | paramsMap.put("access_token", accessToken); 115 | } 116 | 117 | String queryParams = buildQueryParams(paramsMap); 118 | 119 | return createWebClient(config.getModelUrl(), config.getModelHeader()) 120 | .get() 121 | .uri(uriBuilder -> uriBuilder.query(queryParams).build()) 122 | .retrieve() 123 | .bodyToMono(type) 124 | .doOnSuccess(RequestManager::handleErrResponse) 125 | .doOnError(WebClientResponseException.class, handleWebClientError()); 126 | } 127 | 128 | public Flux historyFluxPost(ModelConfig config, String token, T request, 129 | Deque messagesHistory, String msgUid) { 130 | return Flux.create(emitter -> { 131 | CommonSubscriber subscriber = new CommonSubscriber(emitter, messagesHistory, config, msgUid); 132 | fluxPost(config, token, request, ChatResponse.class).subscribe(subscriber); 133 | emitter.onDispose(subscriber); 134 | }); 135 | } 136 | 137 | public Mono historyMonoPost(ModelConfig config, String token, T request, 138 | Deque messagesHistory, String messageUid) { 139 | return monoPost(config, token, request, ChatResponse.class, messageUid) 140 | .flatMap(chatResponse -> { 141 | Message messageResult = WenXinUtils.buildAssistantMessage(chatResponse.getResult()); 142 | MessageHistoryManager.addMessage(messagesHistory, messageResult); 143 | taskManager.downModelCurrentQPS(config.getModelName()); 144 | return Mono.just(chatResponse); 145 | }); 146 | } 147 | 148 | public static Mono getAccessTokenByAKSK(String apiKey, String secretKey) { 149 | assertNotBlank("api-key或secret-key为空", apiKey, secretKey); 150 | 151 
| final String url = String.format(GET_ACCESS_TOKEN_URL, apiKey, secretKey); 152 | return createWebClient(url, null) 153 | .get() 154 | .retrieve() 155 | .bodyToMono(TokenResponse.class); 156 | } 157 | 158 | private String buildCompleteUrl(ModelConfig config, String accessToken) { 159 | return String.format("%s%s%s", config.getModelUrl(), 160 | isAuthorization(config) ? "" : ACCESS_TOKEN_PRE, accessToken); 161 | } 162 | 163 | private static void validateRequestParams(String url, String accessToken, Object request, Class type) { 164 | assertNotBlank(url, "model url is null"); 165 | assertNotNull(request, "request is null"); 166 | assertNotNull(type, "response type is null"); 167 | } 168 | 169 | private static String buildQueryParams(Map paramsMap) { 170 | return paramsMap.entrySet().stream() 171 | .map(entry -> entry.getKey() + "=" + encodeURL(entry.getValue())) 172 | .collect(Collectors.joining("&")); 173 | } 174 | 175 | private void handleSuccess(Object response, String messageUid, ModelConfig config) { 176 | if (!handleErrResponse(response, messageUid)) { 177 | taskManager.downModelCurrentQPS(config.getModelName()); 178 | } 179 | } 180 | 181 | private static void handleErrResponse(T response) { 182 | handleErrResponse(response, null); 183 | } 184 | 185 | private static boolean handleErrResponse(T response, String messageUid) { 186 | assertNotNull(response, "响应异常"); 187 | if (response instanceof ChatResponse chatResponse && chatResponse.getErrorMsg() != null) { 188 | log.error("响应存在错误: {}", chatResponse.getErrorMsg()); 189 | if (messageUid != null) { 190 | Deque messageHistory = messageHistoryManager.getMessageHistory(messageUid); 191 | validateMessageRule(messageHistory); 192 | } 193 | return true; 194 | } 195 | return false; 196 | } 197 | 198 | private void handleStreamingResponse(T response, String messageUid) { 199 | if (handleErrResponse(response, messageUid)) { 200 | return; 201 | } 202 | 203 | if (response instanceof ChatResponse chatResponse) { 204 | 
String text = chatResponse.getResult(); 205 | if (text != null && text.startsWith("data:") && text.length() > 5) { 206 | text = text.substring(5); 207 | } 208 | if (text != null && text.endsWith("\n\n")) { 209 | text = text.trim(); 210 | } 211 | chatResponse.setResult(text); 212 | } 213 | } 214 | 215 | private boolean isAuthorization(ModelConfig config) { 216 | return config.getModelHeader().getAuthorization() != null; 217 | } 218 | 219 | private static Consumer handleWebClientError() { 220 | return err -> { 221 | log.error("请求错误: {}", err.getMessage()); 222 | throw new WenXinException(ErrorCode.SYSTEM_NET_ERROR); 223 | }; 224 | } 225 | 226 | private static String encodeURL(String component) { 227 | assertNotBlank(component, "EncodeURL error!"); 228 | return URLEncoder.encode(component, StandardCharsets.UTF_8); 229 | } 230 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/BaseProperty.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import lombok.Builder; 4 | import lombok.Data; 5 | 6 | /** 7 | * @author Ge Mingjia 8 | * {@code @date} 2023/10/8 9 | */ 10 | @Data 11 | @Builder 12 | public class BaseProperty { 13 | 14 | private String url; 15 | 16 | private String tag; 17 | 18 | private String accessToken; 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/BaseRequest.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.AllArgsConstructor; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | 8 | import java.util.Deque; 9 | 10 | /** 11 | * @author Ge Mingjia 12 | * {@code @date} 2023/7/21 13 | */ 14 | @Data 15 | @NoArgsConstructor 16 | @AllArgsConstructor 17 | public class 
BaseRequest { 18 | 19 | /** 20 | * 表示最终用户的唯一标识符,可以监视和检测滥用行为,防止接口恶意调用 21 | */ 22 | @JsonProperty("user_id") 23 | private String userId; 24 | 25 | /** 26 | * 聊天上下文信息. 27 | * (1)messages成员不能为空,1个成员表示单轮对话,多个成员表示多轮对话 28 | * (2)最后一个message为当前请求的信息,前面的message为历史对话信息 29 | * (3)必须为奇数个成员,成员中message的role必须依次为user、assistant 30 | * (4)最后一个message的content长度(即此轮对话的问题)不能超过2000个字符; 31 | * 如果messages中content总长度大于2000字符,系统会依次遗忘最早的历史会话,直到content的总长度不超过2000个字符 32 | */ 33 | @JsonProperty("messages") 34 | private Deque messages; 35 | 36 | /** 37 | * 是否以流式接口的形式返回数据,默认false 38 | */ 39 | @JsonProperty("stream") 40 | private Boolean stream; 41 | 42 | public static class BaseRequestBuilder { 43 | private String userId; 44 | private Deque messages; 45 | private Boolean stream; 46 | 47 | public BaseRequestBuilder userId(String userId) { 48 | this.userId = userId; 49 | return this; 50 | } 51 | 52 | public BaseRequestBuilder messages(Deque messages) { 53 | this.messages = messages; 54 | return this; 55 | } 56 | 57 | public BaseRequestBuilder stream(Boolean stream) { 58 | this.stream = stream; 59 | return this; 60 | } 61 | 62 | public BaseRequest build() { 63 | BaseRequest baseRequest = new BaseRequest(); 64 | baseRequest.setUserId(userId); 65 | baseRequest.setMessages(messages); 66 | baseRequest.setStream(stream); 67 | 68 | return baseRequest; 69 | } 70 | } 71 | 72 | public static BaseRequestBuilder builder() { 73 | return new BaseRequestBuilder(); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/ClientParams.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Data; 5 | import lombok.NoArgsConstructor; 6 | 7 | /** 8 | * @author Ge Mingjia 9 | * {@code @date} 2023/10/3 10 | */ 11 | @Data 12 | @NoArgsConstructor 13 | @AllArgsConstructor 14 | public class ClientParams { 15 | 16 | /** 
17 | * 消息内容 18 | */ 19 | private String content; 20 | 21 | /** 22 | * 消息UID 23 | */ 24 | private String msgUid; 25 | 26 | /** 27 | * 参数类 28 | */ 29 | private T paramObj; 30 | 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/Example.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import com.gearwenxin.entity.enums.Role; 4 | import lombok.AllArgsConstructor; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | 8 | /** 9 | * @author Ge Mingjia 10 | * {@code @date} 2023/10/15 11 | */ 12 | @Data 13 | @NoArgsConstructor 14 | @AllArgsConstructor 15 | public class Example { 16 | 17 | /** 18 | * 当前支持以下: 19 | * user: 表示用户 20 | * assistant: 表示对话助手 21 | * function: 表示函数 22 | */ 23 | private Role role; 24 | 25 | /** 26 | * 对话内容,当前message存在function_call时可以为空,其他场景不能为空 27 | */ 28 | private String content; 29 | 30 | /** 31 | * message作者;当role=function时,必填,且是响应内容中function_call中的name 32 | */ 33 | private String name; 34 | 35 | /** 36 | * 函数调用,function call场景下第一轮对话的返回,第二轮对话作为历史信息在message中传入 37 | */ 38 | private FunctionCall functionCall; 39 | 40 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/FunctionCall.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Data; 5 | import lombok.NoArgsConstructor; 6 | 7 | /** 8 | * @author Ge Mingjia 9 | * {@code @date} 2023/10/15 10 | */ 11 | @Data 12 | @NoArgsConstructor 13 | @AllArgsConstructor 14 | public class FunctionCall { 15 | 16 | /** 17 | * 触发的function名 18 | */ 19 | private String name; 20 | 21 | /** 22 | * 请求参数 23 | */ 24 | private String arguments; 25 | 26 | /** 27 | * 模型思考过程 28 | */ 29 | private String thoughts; 30 | 31 | } 
-------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/FunctionInfo.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Data; 5 | import lombok.NoArgsConstructor; 6 | 7 | import java.util.List; 8 | 9 | /** 10 | * @author Ge Mingjia 11 | * {@code @date} 2023/10/15 12 | */ 13 | @Data 14 | @NoArgsConstructor 15 | @AllArgsConstructor 16 | public class FunctionInfo { 17 | 18 | /** 19 | * 函数名 20 | */ 21 | private String name; 22 | 23 | /** 24 | * 函数描述 25 | */ 26 | private String description; 27 | 28 | /** 29 | * 函数请求参数,说明: 30 | * (1)JSON Schema 格式,参考JSON Schema描述 31 | * (2)如果函数没有请求参数,parameters值格式如下: 32 | * {"type": "object","properties": {}} 33 | */ 34 | private FunctionParameters parameters; 35 | 36 | /** 37 | * 函数响应参数,JSON Schema 格式,参考JSON Schema描述 38 | */ 39 | private FunctionResponses responses; 40 | 41 | /** 42 | * function调用的一些历史示例 43 | */ 44 | private List examples; 45 | 46 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/FunctionParameters.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Data; 5 | import lombok.NoArgsConstructor; 6 | 7 | import java.util.Map; 8 | 9 | /** 10 | * @author Ge Mingjia 11 | * {@code @date} 2023/10/22 12 | */ 13 | @Data 14 | @AllArgsConstructor 15 | @NoArgsConstructor 16 | public class FunctionParameters { 17 | private String name; 18 | private String description; 19 | private Map> properties; 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/FunctionResponses.java: -------------------------------------------------------------------------------- 1 | package 
com.gearwenxin.entity; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Data; 5 | import lombok.NoArgsConstructor; 6 | 7 | import java.util.Map; 8 | 9 | /** 10 | * @author Ge Mingjia 11 | * {@code @date} 2023/10/22 12 | */ 13 | @Data 14 | @AllArgsConstructor 15 | @NoArgsConstructor 16 | public class FunctionResponses { 17 | private String type; 18 | private Map> properties; 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/Message.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import com.gearwenxin.entity.enums.Role; 5 | import lombok.AllArgsConstructor; 6 | import lombok.Data; 7 | import lombok.NoArgsConstructor; 8 | 9 | /** 10 | * @author Ge Mingjia 11 | * {@code @date} 2023/7/20 12 | */ 13 | @Data 14 | @AllArgsConstructor 15 | @NoArgsConstructor 16 | public class Message { 17 | 18 | /** 19 | * 当前支持以下: 20 | * user: 表示用户 21 | * assistant: 表示对话助手 22 | */ 23 | private Role role; 24 | 25 | /** 26 | * 对话内容,不能为空 27 | */ 28 | private String content; 29 | 30 | /** 31 | * message作者;当role=function时,必填,且是响应内容中function_call中的name 32 | */ 33 | private String name; 34 | 35 | /** 36 | * 函数调用,function call场景下第一轮对话的返回,第二轮对话作为历史信息在message中传入 37 | */ 38 | @JsonProperty("function_call") 39 | private FunctionCall functionCall; 40 | 41 | } 42 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/PluginUsage.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.Data; 5 | 6 | /** 7 | * plugin_usage说明 8 | */ 9 | @Data 10 | public class PluginUsage { 11 | 12 | /** 13 | * 插件名称,chatFile:chatfile插件消耗的tokens 14 | */ 15 | @JsonProperty("name") 16 | private 
String name; 17 | 18 | /** 19 | * 解析文档tokens 20 | */ 21 | @JsonProperty("parse_tokens") 22 | private int parseTokens; 23 | 24 | /** 25 | * 摘要文档tokens 26 | */ 27 | @JsonProperty("abstract_tokens") 28 | private int abstractTokens; 29 | 30 | /** 31 | * 检索文档tokens 32 | */ 33 | @JsonProperty("search_tokens") 34 | private int searchTokens; 35 | 36 | /** 37 | * 总tokens 38 | */ 39 | @JsonProperty("total_tokens") 40 | private int totalTokens; 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/Usage.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.AllArgsConstructor; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | 8 | import java.util.List; 9 | 10 | /** 11 | * @author Ge Mingjia 12 | * {@code @date} 2023/7/20 13 | */ 14 | @Data 15 | @NoArgsConstructor 16 | @AllArgsConstructor 17 | public class Usage { 18 | 19 | /** 20 | * 问题tokens数 21 | */ 22 | @JsonProperty("prompt_tokens") 23 | private int promptTokens; 24 | 25 | /** 26 | * 回答tokens数 27 | */ 28 | @JsonProperty("completion_tokens") 29 | private int completionTokens; 30 | 31 | /** 32 | * tokens总数 33 | */ 34 | @JsonProperty("total_tokens") 35 | private int totalTokens; 36 | 37 | /** 38 | * plugin消耗的tokens 39 | */ 40 | @JsonProperty("plugins") 41 | private List plugins; 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/chatmodel/ChatBaseRequest.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.chatmodel; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Builder; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | 8 | import java.io.Serializable; 9 | 10 | /** 11 | * @author Ge Mingjia 12 | * {@code 
@date} 2023/8/3 13 | */ 14 | @Data 15 | @Builder 16 | @AllArgsConstructor 17 | @NoArgsConstructor 18 | public class ChatBaseRequest implements Serializable { 19 | 20 | /** 21 | * 表示最终用户的唯一标识符,可以监视和检测滥用行为,防止接口恶意调用 22 | */ 23 | private String userId; 24 | 25 | /** 26 | * 聊天信息,不能为空 27 | */ 28 | private String content; 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/chatmodel/ChatErnieRequest.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.chatmodel; 2 | 3 | import com.gearwenxin.entity.FunctionCall; 4 | import com.gearwenxin.entity.FunctionInfo; 5 | import lombok.*; 6 | import lombok.extern.slf4j.Slf4j; 7 | 8 | import java.util.List; 9 | 10 | /** 11 | * @author Ge Mingjia 12 | * {@code @date} 2023/7/20 13 | *

14 | * ContBot 模型 15 | */ 16 | @Slf4j 17 | @Data 18 | @NoArgsConstructor 19 | @AllArgsConstructor 20 | @EqualsAndHashCode(callSuper = true) 21 | public class ChatErnieRequest extends ChatBaseRequest { 22 | 23 | /** 24 | * 输出更加随机,而较低的数值会使其更加集中和确定,默认0.95,范围 (0, 1.0] 25 | */ 26 | private Float temperature; 27 | 28 | /** 29 | * (影响输出文本的多样性,越大生成文本的多样性越强 30 | */ 31 | private Float topP; 32 | 33 | /** 34 | * 通过对已生成的token增加惩罚,减少重复生成的现象。 35 | */ 36 | private Float penaltyScore; 37 | 38 | /** 39 | * 一个可触发函数的描述列表 40 | */ 41 | private List functions; 42 | 43 | /** 44 | * 模型人设,主要用于人设设定,例如,你是xxx公司制作的AI助手,说明: 45 | * (1)长度限制1024个字符 46 | * (2)如果使用functions参数,不支持设定人设system 47 | */ 48 | private String system; 49 | 50 | /** 51 | * message作者;当role=function时,必填,且是响应内容中function_call中的name 52 | */ 53 | private String name; 54 | 55 | /** 56 | * 函数调用,function call场景下第一轮对话的返回,第二轮对话作为历史信息在message中传入 57 | */ 58 | private FunctionCall functionCall; 59 | 60 | /** 61 | * 生成停止标识,当模型生成结果以stop中某个元素结尾时,停止文本生成。说明: 62 | * (1)每个元素长度不超过20字符 63 | * (2)最多4个元素 64 | */ 65 | private List stop; 66 | 67 | /** 68 | * 是否强制关闭实时搜索功能,默认false,表示不关闭 69 | */ 70 | private Boolean disableSearch; 71 | 72 | /** 73 | * 是否开启上角标返回,说明: 74 | * (1)开启后,有概率触发搜索溯源信息search_info,search_info内容见响应参数介绍 75 | * (2)默认false,不开启 76 | */ 77 | private Boolean enableCitation; 78 | 79 | /** 80 | * 指定响应内容的格式,说明: 81 | * (1)可选值: 82 | * · json_object:以json格式返回,可能出现不满足效果情况 83 | * · text:以文本格式返回 84 | * (2)如果不填写参数response_format值,默认为text 85 | */ 86 | private String responseFormat; 87 | 88 | } 89 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/chatmodel/ChatPromptRequest.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.chatmodel; 2 | 3 | import lombok.*; 4 | 5 | import java.util.Map; 6 | 7 | /** 8 | * @author Ge Mingjia 9 | * {@code @date} 2023/7/26 10 | */ 11 | @Data 12 | @NoArgsConstructor 13 | 
@AllArgsConstructor 14 | public class ChatPromptRequest { 15 | 16 | /** 17 | * prompt工程里面对应的模板id 18 | */ 19 | private String id; 20 | 21 | /** 22 | * 参数map 23 | */ 24 | private Map paramMap; 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/enums/ModelType.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.enums; 2 | 3 | import lombok.Getter; 4 | 5 | /** 6 | * @author Ge Mingjia 7 | * {@code @date} 2023/10/15 8 | */ 9 | @Getter 10 | public enum ModelType { 11 | chat("chat"), 12 | prompt("prompt"), 13 | image("image"), 14 | embedding("embedding"), 15 | addTask("addTask"), 16 | check("check"),; 17 | 18 | private final String value; 19 | 20 | ModelType(String value) { 21 | this.value = value; 22 | } 23 | 24 | } 25 | 26 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/enums/ResponseFormatType.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.enums; 2 | 3 | import lombok.Getter; 4 | 5 | import java.util.Optional; 6 | 7 | @Getter 8 | public enum ResponseFormatType { 9 | 10 | /** 11 | * 以json格式返回 12 | */ 13 | json_object("json_object"), 14 | 15 | /** 16 | * 以文本格式返回 17 | */ 18 | text("text"); 19 | 20 | private final String value; 21 | 22 | ResponseFormatType(String value) { 23 | this.value = value; 24 | } 25 | 26 | public static ResponseFormatType TypeFromString(String text) { 27 | for (ResponseFormatType b : ResponseFormatType.values()) { 28 | if (b.value.equalsIgnoreCase(text)) { 29 | return b; 30 | } 31 | } 32 | return null; 33 | } 34 | 35 | public static Optional fromString(String text) { 36 | return Optional.ofNullable(TypeFromString(text)); 37 | } 38 | 39 | public static void ifPresent(String text, Runnable runnable) { 40 | fromString(text).ifPresent(result -> runnable.run()); 41 
| } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/enums/Role.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.enums; 2 | 3 | /** 4 | * @author Ge Mingjia 5 | * {@code @date} 2023/7/20 6 | */ 7 | public enum Role { 8 | 9 | /** 10 | * 用户 11 | */ 12 | user, 13 | 14 | /** 15 | * AI回复 16 | */ 17 | assistant, 18 | 19 | /** 20 | * 只存在于function call的examples中 21 | */ 22 | function 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/enums/SamplerType.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.enums; 2 | 3 | import lombok.Getter; 4 | 5 | /** 6 | * @author Ge Mingjia 7 | * {@code @date} 2023/11/4 8 | */ 9 | @Getter 10 | public enum SamplerType { 11 | Euler("Euler"), 12 | Euler_A("Euler a"), 13 | DPM_2M("DPM++ 2M"), 14 | DPM_2M_Karras("DPM++ 2M Karras"), 15 | LMS_Karras("LMS Karras"), 16 | DPM_SDE("DPM++ SDE"), 17 | DPM_SDE_Karras("DPM++ SDE Karras"), 18 | DPM2_a_Karras("DPM2 a Karras"), 19 | Heun("Heun"), 20 | DPM_2M_SDE("DPM++ 2M SDE"), 21 | DPM_2M_SDE_Karras("DPM++ 2M SDE Karras"), 22 | DPM2("DPM2"), 23 | DPM2_Karras("DPM2 Karras"), 24 | DPM2_a("DPM2 a"), 25 | LMS("LMS"); 26 | 27 | private final String value; 28 | 29 | SamplerType(String value) { 30 | this.value = value; 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/request/EmbeddingV1Request.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.request; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Data; 5 | import lombok.NoArgsConstructor; 6 | 7 | /** 8 | * @author Ge Mingjia 9 | * {@code @date} 2023/7/26 10 | */ 11 | @Data 12 | 
@AllArgsConstructor 13 | @NoArgsConstructor 14 | public class EmbeddingV1Request { 15 | 16 | /** 17 | * 内容 18 | */ 19 | private String content; 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/request/ErnieRequest.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.request; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import com.gearwenxin.entity.BaseRequest; 5 | 6 | import com.gearwenxin.entity.FunctionInfo; 7 | import com.gearwenxin.entity.Message; 8 | import lombok.*; 9 | 10 | import java.util.Deque; 11 | import java.util.List; 12 | 13 | /** 14 | * @author Ge Mingjia 15 | * {@code @date} 2023/7/20 16 | *

17 | * ContBot 模型 18 | */ 19 | @Data 20 | @NoArgsConstructor 21 | @AllArgsConstructor 22 | @ToString(callSuper = true) 23 | @EqualsAndHashCode(callSuper = true) 24 | public class ErnieRequest extends BaseRequest { 25 | 26 | /** 27 | * (1)较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 28 | * (2)默认0.95,范围 (0, 1.0],不能为0 29 | * (3)建议该参数和top_p只设置1个 30 | * (4)建议top_p和temperature不要同时更改 31 | */ 32 | @JsonProperty("temperature") 33 | private Float temperature; 34 | 35 | /** 36 | * (1)影响输出文本的多样性,取值越大,生成文本的多样性越强 37 | * (2)默认0.8,取值范围 [0, 1.0] 38 | * (3)建议该参数和temperature只设置1个 39 | * (4)建议top_p和temperature不要同时更改 40 | */ 41 | @JsonProperty("top_p") 42 | private Float topP; 43 | 44 | /** 45 | * 通过对已生成的token增加惩罚,减少重复生成的现象。说明: 46 | * (1)值越大表示惩罚越大 47 | * (2)默认1.0,取值范围:[1.0, 2.0] 48 | */ 49 | @JsonProperty("penalty_score") 50 | private Float penaltyScore; 51 | 52 | /** 53 | * 一个可触发函数的描述列表 54 | */ 55 | @JsonProperty("functions") 56 | private List functions; 57 | 58 | /** 59 | * 模型人设,主要用于人设设定,例如,你是xxx公司制作的AI助手,说明: 60 | * (1)长度限制1024个字符 61 | * (2)如果使用functions参数,不支持设定人设system 62 | */ 63 | @JsonProperty("system") 64 | private String system; 65 | 66 | /** 67 | * 生成停止标识,当模型生成结果以stop中某个元素结尾时,停止文本生成。说明: 68 | * (1)每个元素长度不超过20字符 69 | * (2)最多4个元素 70 | */ 71 | @JsonProperty("stop") 72 | private List stop; 73 | 74 | /** 75 | * 是否强制关闭实时搜索功能,默认false,表示不关闭 76 | */ 77 | @JsonProperty("disable_search") 78 | private Boolean disableSearch; 79 | 80 | /** 81 | * 是否开启上角标返回,说明: 82 | * (1)开启后,有概率触发搜索溯源信息search_info,search_info内容见响应参数介绍 83 | * (2)默认false,不开启 84 | */ 85 | @JsonProperty("enable_citation") 86 | private Boolean enableCitation; 87 | 88 | /** 89 | * 指定响应内容的格式,说明: 90 | * (1)可选值: 91 | * · json_object:以json格式返回,可能出现不满足效果情况 92 | * · text:以文本格式返回 93 | * (2)如果不填写参数response_format值,默认为text 94 | */ 95 | @JsonProperty("response_format") 96 | private String responseFormat; 97 | 98 | public static ErnieRequestBuilder builder() { 99 | return new ErnieRequestBuilder(); 100 | } 101 | 102 | public static class ErnieRequestBuilder 
extends BaseRequestBuilder { 103 | private Float temperature; 104 | private Float topP; 105 | private Float penaltyScore; 106 | private String userId; 107 | private Deque messages; 108 | private Boolean stream; 109 | private List functions; 110 | private String system; 111 | private List stop; 112 | private Boolean disableSearch; 113 | private Boolean enableCitation; 114 | private String responseFormat; 115 | 116 | public ErnieRequestBuilder temperature(Float temperature) { 117 | this.temperature = temperature; 118 | return this; 119 | } 120 | 121 | public ErnieRequestBuilder topP(Float topP) { 122 | this.topP = topP; 123 | return this; 124 | } 125 | 126 | public ErnieRequestBuilder penaltyScore(Float penaltyScore) { 127 | this.penaltyScore = penaltyScore; 128 | return this; 129 | } 130 | 131 | @Override 132 | public ErnieRequestBuilder userId(String userId) { 133 | this.userId = userId; 134 | return this; 135 | } 136 | 137 | @Override 138 | public ErnieRequestBuilder messages(Deque messages) { 139 | this.messages = messages; 140 | return this; 141 | } 142 | 143 | @Override 144 | public ErnieRequestBuilder stream(Boolean stream) { 145 | this.stream = stream; 146 | return this; 147 | } 148 | 149 | public ErnieRequestBuilder functions(List functions) { 150 | this.functions = functions; 151 | return this; 152 | } 153 | 154 | public ErnieRequestBuilder system(String system) { 155 | this.system = system; 156 | return this; 157 | } 158 | 159 | public ErnieRequestBuilder stop(List stop) { 160 | this.stop = stop; 161 | return this; 162 | } 163 | 164 | public ErnieRequestBuilder disableSearch(Boolean disableSearch) { 165 | this.disableSearch = disableSearch; 166 | return this; 167 | } 168 | 169 | public ErnieRequestBuilder enableCitation(Boolean enableCitation) { 170 | this.enableCitation = enableCitation; 171 | return this; 172 | } 173 | 174 | public ErnieRequestBuilder responseFormat(String responseFormat) { 175 | this.responseFormat = responseFormat; 176 | return this; 
177 | } 178 | 179 | @Override 180 | public ErnieRequest build() { 181 | ErnieRequest ernieRequest = new ErnieRequest(); 182 | ernieRequest.setTemperature(temperature); 183 | ernieRequest.setTopP(topP); 184 | ernieRequest.setPenaltyScore(penaltyScore); 185 | ernieRequest.setUserId(userId); 186 | ernieRequest.setMessages(messages); 187 | ernieRequest.setStream(stream); 188 | ernieRequest.setFunctions(functions); 189 | ernieRequest.setSystem(system); 190 | ernieRequest.setStop(stop); 191 | ernieRequest.setDisableSearch(disableSearch); 192 | ernieRequest.setEnableCitation(enableCitation); 193 | ernieRequest.setResponseFormat(responseFormat); 194 | 195 | return ernieRequest; 196 | } 197 | } 198 | 199 | } 200 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/request/ImageBaseRequest.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.request; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import com.gearwenxin.common.ErrorCode; 5 | import com.gearwenxin.entity.enums.SamplerType; 6 | import com.gearwenxin.exception.WenXinException; 7 | import lombok.AllArgsConstructor; 8 | import lombok.Builder; 9 | import lombok.Data; 10 | import lombok.NoArgsConstructor; 11 | import org.apache.commons.lang3.StringUtils; 12 | 13 | /** 14 | * @author Ge Mingjia 15 | * {@code @date} 2023/8/3 16 | */ 17 | @Data 18 | @Builder 19 | @AllArgsConstructor 20 | @NoArgsConstructor 21 | public class ImageBaseRequest { 22 | 23 | /** 24 | * 提示词,即用户希望图片包含的元素。长度限制为1024字符,建议中文或者英文单词总数量不超过150个 25 | */ 26 | @JsonProperty("prompt") 27 | private String prompt; 28 | 29 | /** 30 | * 反向提示词,即用户希望图片不包含的元素。长度限制为1024字符,建议中文或者英文单词总数量不超过150个 31 | */ 32 | @JsonProperty("negative_prompt") 33 | private String negativePrompt; 34 | 35 | /** 36 | * 生成图片长宽,默认值 1024x1024,取值范围如下: 37 | * ["512x512", "768x768", "768x1024", "1024x768", "576x1024", "1024x576", "1024x1024"] 
38 | * 注意:建议选择较大尺寸,结合完善的prompt,以保障图片质量。 39 | */ 40 | @JsonProperty("size") 41 | private String size = "1024x1024"; 42 | 43 | /** 44 | * 生成图片数量,说明: 45 | * · 默认值为1 46 | * · 取值范围为1-4 47 | * · 单次生成的图片较多及请求较频繁可能导致请求超时 48 | */ 49 | @JsonProperty("n") 50 | private Integer n = 1; 51 | 52 | /** 53 | * 迭代轮次,说明: 54 | * · 默认值为20 55 | * · 取值范围为10-50 56 | */ 57 | @JsonProperty("steps") 58 | private Integer steps = 20; 59 | 60 | @JsonProperty("sampler_index") 61 | private String samplerIndex = SamplerType.DPM2_a.getValue(); 62 | 63 | @JsonProperty("user_id") 64 | private String userId; 65 | 66 | public void validSelf() { 67 | 68 | // 检查content不为空 69 | if (StringUtils.isBlank(prompt)) { 70 | throw new WenXinException(ErrorCode.PARAMS_ERROR, "prompt cannot be empty"); 71 | } 72 | // 检查单个content长度 73 | if (prompt.length() > 1024) { 74 | throw new WenXinException(ErrorCode.PARAMS_ERROR, "prompt's length cannot be more than 1024"); 75 | } 76 | // 检查单个content长度 77 | if (negativePrompt.length() > 1024) { 78 | throw new WenXinException(ErrorCode.PARAMS_ERROR, "prompt's length cannot be more than 1024"); 79 | } 80 | 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/request/PluginParams.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.request; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.AllArgsConstructor; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | 8 | import java.util.List; 9 | 10 | /** 11 | * 请求参数类 12 | */ 13 | @Data 14 | @NoArgsConstructor 15 | @AllArgsConstructor 16 | public class PluginParams { 17 | 18 | /** 19 | * 查询信息 20 | * (1)成员不能为空 21 | * (2)长度不能超过1000个字符 22 | */ 23 | @JsonProperty("query") 24 | private String query; 25 | 26 | /** 27 | * 需要调用的插件,参数为插件ID,插件ID可在插件列表-插件详情中获取。 28 | * (1)最多3个插件,最少1个插件。 29 | * (2)当多个插件时,插件触发由大模型意图判断控制。 30 | * (3)当只有1个插件时,强制指定使用该插件工具。 
/**
 * Request carrying a prompt template id and the values for its parameters.
 *
 * @author Ge Mingjia
 * {@code @date} 2023/7/26
 */
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class PromptRequest {

    /**
     * Id of the corresponding template in the prompt project.
     */
    private String id;

    /**
     * Parameter map.
     * NOTE(review): declared as a raw Map here; the key/value types are not visible — confirm.
     */
    private Map paramMap;

}
com.fasterxml.jackson.annotation.JsonProperty; 4 | import com.gearwenxin.entity.FunctionCall; 5 | import com.gearwenxin.entity.Usage; 6 | import lombok.AllArgsConstructor; 7 | import lombok.Data; 8 | import lombok.NoArgsConstructor; 9 | 10 | import java.io.Serializable; 11 | 12 | /** 13 | * @author Ge Mingjia 14 | * {@code @date} 2023/7/20 15 | *

16 | * ContBot 模型 17 | */ 18 | @Data 19 | @NoArgsConstructor 20 | @AllArgsConstructor 21 | public class ChatResponse implements Serializable { 22 | 23 | /** 24 | * 本轮对话的id 25 | */ 26 | @JsonProperty("id") 27 | private String id; 28 | 29 | /** 30 | * 用于定位的log_id 31 | */ 32 | @JsonProperty("log_id") 33 | private String logId; 34 | 35 | /** 36 | * 回包类型 37 | * chat.completion:多轮对话返回 38 | */ 39 | @JsonProperty("object") 40 | private String object; 41 | 42 | /** 43 | * 时间戳 44 | */ 45 | @JsonProperty("created") 46 | private Integer created; 47 | 48 | /** 49 | * 表示当前子句的序号。只有在流式接口模式下会返回该字段 50 | */ 51 | @JsonProperty("sentence_id") 52 | private Integer sentenceId; 53 | 54 | /** 55 | * 表示当前子句是否是最后一句。只有在流式接口模式下会返回该字段 56 | */ 57 | @JsonProperty("is_end") 58 | private Boolean isEnd; 59 | 60 | /** 61 | * 当前生成的结果是否被截断 62 | */ 63 | @JsonProperty("is_truncated") 64 | private Boolean isTruncated; 65 | 66 | /** 67 | * 输出内容标识,说明: 68 | * · normal:输出内容完全由大模型生成,未触发截断、替换 69 | * · stop:输出结果命中入参stop中指定的字段后被截断 70 | * · length:达到了最大的token数,根据EB返回结果is_truncated来截断 71 | * · content_filter:输出内容被截断、兜底、替换为**等 72 | * · function_call:调用了funtion call功能 73 | */ 74 | @JsonProperty("finish_reason") 75 | private String finishReason; 76 | 77 | /** 78 | * 搜索数据,当请求参数enable_citation为true并且触发搜索时,会返回该字段 79 | */ 80 | @JsonProperty("search_info") 81 | private SearchInfo searchInfo; 82 | 83 | /** 84 | * 对话返回结果 85 | */ 86 | @JsonProperty("result") 87 | private String result; 88 | 89 | /** 90 | * 表示用户输入是否存在安全,是否关闭当前会话,清理历史回话信息 91 | * true:是,表示用户输入存在安全风险,建议关闭当前会话,清理历史会话信息 92 | * false:否,表示用户输入无安全风险 93 | */ 94 | @JsonProperty("need_clear_history") 95 | private Boolean needClearHistory; 96 | 97 | /** 98 | * token统计信息,token数 = 汉字数+单词数*1.3 (仅为估算逻辑) 99 | */ 100 | @JsonProperty("usage") 101 | private Usage usage; 102 | 103 | /** 104 | * 当need_clear_history为true时,此字段会告知第几轮对话有敏感信息,如果是当前问题,ban_round=-1 105 | */ 106 | @JsonProperty("ban_round") 107 | private Integer banRound; 108 | 109 | /** 110 | * 说明: 111 | * · 0:正常返回 112 | 
* · 其他:非正常 113 | */ 114 | @JsonProperty("flag") 115 | private Integer flag; 116 | 117 | /** 118 | * 错误代码,正常为 null 119 | */ 120 | @JsonProperty("error_code") 121 | private Integer errorCode; 122 | 123 | /** 124 | * 错误代码,正常为 null 125 | */ 126 | @JsonProperty("eb_code") 127 | private Integer ebCode; 128 | 129 | /** 130 | * 错误信息,正常为 null 131 | */ 132 | @JsonProperty("error_msg") 133 | private String errorMsg; 134 | 135 | /** 136 | * 由模型生成的函数调用,包含函数名称,和调用参数 137 | */ 138 | @JsonProperty("function_call") 139 | private FunctionCall functionCall; 140 | 141 | } 142 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/ErrorResponse.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.response; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.AllArgsConstructor; 5 | import lombok.Builder; 6 | import lombok.Data; 7 | import lombok.NoArgsConstructor; 8 | 9 | /** 10 | * @author Ge Mingjia 11 | * {@code @date} 2023/10/26 12 | */ 13 | @Data 14 | @Builder 15 | @NoArgsConstructor 16 | @AllArgsConstructor 17 | public class ErrorResponse { 18 | 19 | /** 20 | * 本轮对话的id 21 | */ 22 | @JsonProperty("id") 23 | private String id; 24 | 25 | /** 26 | * 用于定位的log_id 27 | */ 28 | @JsonProperty("log_id") 29 | private String logId; 30 | 31 | /** 32 | * 错误代码,正常为 null 33 | */ 34 | @JsonProperty("error_code") 35 | private Integer errorCode; 36 | 37 | /** 38 | * 错误代码,正常为 null 39 | */ 40 | @JsonProperty("eb_code") 41 | private Integer ebCode; 42 | 43 | /** 44 | * 错误信息,正常为 null 45 | */ 46 | @JsonProperty("error_msg") 47 | private String errorMsg; 48 | 49 | @Override 50 | public String toString() { 51 | return "error_response { " + 52 | "id: '" + id + '\'' + 53 | ", logId: '" + logId + '\'' + 54 | ", errorCode: " + errorCode + 55 | ", ebCode: " + ebCode + 56 | ", errorMsg: '" + errorMsg + '\'' + 57 | '}' + " "; 58 | } 59 | } 60 | 
-------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/ImageData.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.response; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.Data; 5 | 6 | /** 7 | * 表示生成图片的详细信息。 8 | */ 9 | @Data 10 | public class ImageData { 11 | 12 | /** 13 | * 固定值 "image",表示图像。 14 | */ 15 | private String object; 16 | 17 | /** 18 | * 图片base64编码内容。 19 | */ 20 | @JsonProperty("b64_image") 21 | private String b64Image; 22 | 23 | /** 24 | * 图片序号。 25 | */ 26 | private int index; 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/ImageResponse.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.response; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import com.gearwenxin.entity.Usage; 5 | import lombok.AllArgsConstructor; 6 | import lombok.Data; 7 | import lombok.NoArgsConstructor; 8 | 9 | import java.util.List; 10 | 11 | /** 12 | * @author Ge Mingjia 13 | * {@code @date} 2023/8/3 14 | */ 15 | @Data 16 | @AllArgsConstructor 17 | @NoArgsConstructor 18 | public class ImageResponse { 19 | 20 | /** 21 | * 请求的ID。 22 | */ 23 | private String id; 24 | 25 | /** 26 | * 回包类型。固定值为 "image",表示图像生成返回。 27 | */ 28 | private String object; 29 | 30 | /** 31 | * 时间戳,表示生成响应的时间。 32 | */ 33 | private int created; 34 | 35 | /** 36 | * 生成图片结果列表。 37 | */ 38 | private List data; 39 | 40 | /** 41 | * token统计信息,token数 = 汉字数 + 单词数 * 1.3 (仅为估算逻辑)。 42 | */ 43 | private Usage usage; 44 | 45 | /** 46 | * 错误代码,正常为 null 47 | */ 48 | @JsonProperty("error_code") 49 | private Integer errorCode; 50 | 51 | /** 52 | * 错误信息,正常为 null 53 | */ 54 | @JsonProperty("error_msg") 55 | private String errorMsg; 56 | 57 | } 58 | 
/**
 * Response of a prompt template request.
 *
 * @author Ge Mingjia
 * {@code @date} 2023/7/23
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class PromptResponse {

    // Bound to the JSON property "log_id".
    @JsonProperty("log_id")
    private String logId;

    // Rendered template data; see PromptResult.
    private PromptResult result;

    private Integer status;

    private Boolean success;

    // Bound to the JSON property "error_code".
    @JsonProperty("error_code")
    private Integer errorCode;

    // Bound to the JSON property "error_msg".
    @JsonProperty("error_msg")
    private String errorMsg;

    // Bound to the JSON property "code" (note the differing Java name).
    @JsonProperty("code")
    private String promptErrCode;

    // Bound to the JSON property "message" (note the differing Java name).
    @JsonProperty("message")
    private PromptErrMessage promptErrMessage;
}
lombok.Data; 5 | import lombok.NoArgsConstructor; 6 | 7 | /** 8 | * @author Ge Mingjia 9 | * {@code @date} 2023/7/26 10 | */ 11 | 12 | @Data 13 | @AllArgsConstructor 14 | @NoArgsConstructor 15 | public class PromptResult { 16 | 17 | /** 18 | * prompt工程里面对应的模板id 19 | */ 20 | private String templateId; 21 | 22 | /** 23 | * 模板名称 24 | */ 25 | private String templateName; 26 | 27 | /** 28 | * 模板原始内容 29 | */ 30 | private String templateContent; 31 | 32 | /** 33 | * 模板变量插值 34 | */ 35 | private String templateVariables; 36 | 37 | /** 38 | * 将变量插值填充到模板原始内容后得到的模板内容 39 | */ 40 | private String content; 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/SSEResponse.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.response; 2 | 3 | import lombok.Data; 4 | 5 | @Data 6 | public class SSEResponse { 7 | 8 | private String content; 9 | 10 | @Override 11 | public String toString() { 12 | return "data: " + content + "\n\n"; 13 | } 14 | 15 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/SearchInfo.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.response; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.AllArgsConstructor; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | 8 | import java.util.List; 9 | 10 | @Data 11 | @AllArgsConstructor 12 | @NoArgsConstructor 13 | class SearchInfo { 14 | 15 | /** 16 | * 搜索结果的列表 17 | */ 18 | @JsonProperty("search_results") 19 | private List searchResults; 20 | 21 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/SearchResult.java: -------------------------------------------------------------------------------- 1 
| package com.gearwenxin.entity.response; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Data; 5 | import lombok.NoArgsConstructor; 6 | 7 | @Data 8 | @AllArgsConstructor 9 | @NoArgsConstructor 10 | class SearchResult { 11 | 12 | /** 13 | * 搜索结果的序号 14 | */ 15 | private int index; 16 | 17 | /** 18 | * 搜索结果的url 19 | */ 20 | private String url; 21 | 22 | /** 23 | * 搜索结果的标题 24 | */ 25 | private String title; 26 | 27 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/TokenResponse.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.response; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.AllArgsConstructor; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | 8 | /** 9 | * @author Ge Mingjia 10 | */ 11 | @Data 12 | @AllArgsConstructor 13 | @NoArgsConstructor 14 | public class TokenResponse { 15 | 16 | @JsonProperty("refresh_token") 17 | private String refreshToken; 18 | 19 | @JsonProperty("expires_in") 20 | private int expiresIn; 21 | 22 | @JsonProperty("session_key") 23 | private String sessionKey; 24 | 25 | @JsonProperty("access_token") 26 | private String accessToken; 27 | 28 | @JsonProperty("scope") 29 | private String scope; 30 | 31 | @JsonProperty("session_secret") 32 | private String sessionSecret; 33 | 34 | @JsonProperty("error_description") 35 | private String errorDescription; 36 | 37 | @JsonProperty("error") 38 | private String error; 39 | 40 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/plugin/PluginResponse.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.response.plugin; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import com.gearwenxin.entity.Usage; 5 | import 
lombok.AllArgsConstructor; 6 | import lombok.Data; 7 | import lombok.NoArgsConstructor; 8 | 9 | /** 10 | * 返回结果信息类 11 | */ 12 | @Data 13 | @NoArgsConstructor 14 | @AllArgsConstructor 15 | public class PluginResponse { 16 | 17 | /** 18 | * 唯一的log id,用于问题定位 19 | */ 20 | @JsonProperty("log_id") 21 | private String logId; 22 | 23 | /** 24 | * 本轮对话的id 25 | */ 26 | @JsonProperty("id") 27 | private String id; 28 | 29 | /** 30 | * 回包类型。 31 | * chat.completion:多轮对话返回 32 | */ 33 | @JsonProperty("object") 34 | private String object; 35 | 36 | /** 37 | * 时间戳 38 | */ 39 | @JsonProperty("created") 40 | private int created; 41 | 42 | /** 43 | * 表示当前子句的序号,只有在流式接口模式下会返回该字段 44 | */ 45 | @JsonProperty("sentence_id") 46 | private int sentenceId; 47 | 48 | /** 49 | * 表示当前子句是否是最后一句,只有在流式接口模式下会返回该字段 50 | */ 51 | @JsonProperty("is_end") 52 | private boolean isEnd; 53 | 54 | /** 55 | * 插件返回结果 56 | */ 57 | @JsonProperty("result") 58 | private String result; 59 | 60 | /** 61 | * 当前生成的结果是否被截断 62 | */ 63 | @JsonProperty("is_truncated") 64 | private boolean isTruncated; 65 | 66 | /** 67 | * 表示用户输入是否存在安全,是否关闭当前会话,清理历史会话信息 68 | * true:是,表示用户输入存在安全风险,建议关闭当前会话,清理历史会话信息 69 | * false:否,表示用户输入无安全风险 70 | */ 71 | @JsonProperty("need_clear_history") 72 | private boolean needClearHistory; 73 | 74 | /** 75 | * 当need_clear_history为true时,此字段会告知第几轮对话有敏感信息,如果是当前问题,ban_round = -1 76 | */ 77 | @JsonProperty("ban_round") 78 | private int banRound; 79 | 80 | /** 81 | * token统计信息,token数 = 汉字数+单词数*1.3 (仅为估算逻辑) 82 | */ 83 | @JsonProperty("usage") 84 | private Usage usage; 85 | 86 | /** 87 | * 插件的原始请求信息 88 | */ 89 | @JsonProperty("meta_info") 90 | private T metaInfo; 91 | 92 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/plugin/knowledge/KnowledgeBaseMI.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.response.plugin.knowledge; 2 | 3 | import 
/**
 * Original request parameters of a knowledge-base query.
 *
 * @author Ge Mingjia
 * {@code @date} 2023/10/22
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class KnowledgeMIRequest {

    /**
     * User request used to query the knowledge base.
     */
    @JsonProperty("query")
    private String query;

    /**
     * Ids of the knowledge bases used.
     */
    @JsonProperty("kbIds")
    private String[] kbIds;

    /**
     * Lower bound of the similarity score between a shard and the query;
     * document shards scoring below it are not returned.
     */
    @JsonProperty("score")
    private float score;

    /**
     * Number of most relevant documents returned.
     */
    @JsonProperty("topN")
    private Integer topN;

}
/**
 * Original result returned by the knowledge base.
 *
 * @author Ge Mingjia
 * {@code @date} 2023/10/22
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class KnowledgeMIResponse {

    /**
     * Error code.
     */
    @JsonProperty("retCode")
    private Integer retCode;

    /**
     * Error message.
     */
    @JsonProperty("message")
    private String message;

    /**
     * Returned result.
     */
    @JsonProperty("result")
    private KnowledgeMIResult result;

}
@JsonProperty("shardIndex") 57 | private Integer shardIndex; 58 | 59 | /** 60 | * 分片的实际内容 61 | */ 62 | @JsonProperty("content") 63 | private String content; 64 | 65 | } 66 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/entity/response/plugin/knowledge/KnowledgeMIResult.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.entity.response.plugin.knowledge; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.AllArgsConstructor; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | 8 | /** 9 | * @author Ge Mingjia 10 | * {@code @date} 2023/10/22 11 | */ 12 | @Data 13 | @NoArgsConstructor 14 | @AllArgsConstructor 15 | public class KnowledgeMIResult { 16 | 17 | /** 18 | * bes查询耗时 19 | */ 20 | @JsonProperty("besQueryCostMilsec3") 21 | private Integer besQueryCostMilsec3; 22 | 23 | /** 24 | * db查询耗时 25 | */ 26 | @JsonProperty("dbQueryCostMilsec1") 27 | private Integer dbQueryCostMilsec1; 28 | 29 | /** 30 | * embedding查询耗时 31 | */ 32 | @JsonProperty("embeddedCostMilsec2") 33 | private Integer embeddedCostMilsec2; 34 | 35 | /** 36 | * 知识库返回的最相关文档信息 37 | */ 38 | @JsonProperty("responses") 39 | private KnowledgeMIResponses responses; 40 | 41 | /** 42 | * bos url生成耗时 43 | */ 44 | @JsonProperty("urlSignedCostMilsec4") 45 | private Integer urlSignedCostMilsec4; 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/exception/WenXinException.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.exception; 2 | 3 | import com.gearwenxin.common.ErrorCode; 4 | import lombok.Getter; 5 | 6 | /** 7 | * @author Ge Mingjia 8 | * {@code @date} 2023/7/22 9 | */ 10 | @Getter 11 | public class WenXinException extends RuntimeException { 12 | 13 | /** 14 | * 错误码 15 | */ 16 | private final int code; 17 
| 18 | public WenXinException(int code, String message) { 19 | super(message); 20 | this.code = code; 21 | } 22 | 23 | public WenXinException(ErrorCode errorCode) { 24 | super(errorCode.getMessage()); 25 | this.code = errorCode.getCode(); 26 | } 27 | 28 | public WenXinException(ErrorCode errorCode, String message) { 29 | super(message); 30 | this.code = errorCode.getCode(); 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/model/BasicChatModel.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.model; 2 | 3 | import com.gearwenxin.entity.chatmodel.ChatBaseRequest; 4 | import com.gearwenxin.entity.response.ChatResponse; 5 | import reactor.core.publisher.Flux; 6 | import reactor.core.publisher.Mono; 7 | 8 | public interface BasicChatModel { 9 | 10 | /** 单次对话 **/ 11 | Mono chat(String content); 12 | 13 | Mono chat(String content, float weight); 14 | 15 | Mono chat(T chatRequest); 16 | 17 | Mono chat(T chatRequest, float weight); 18 | 19 | Flux chatStream(String content); 20 | 21 | Flux chatStream(String content, float weight); 22 | 23 | Flux chatStream(T chatRequest); 24 | 25 | Flux chatStream(T chatRequest, float weight); 26 | 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/model/ChatModel.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.model; 2 | 3 | import com.gearwenxin.entity.chatmodel.ChatBaseRequest; 4 | import com.gearwenxin.entity.response.ChatResponse; 5 | import reactor.core.publisher.Flux; 6 | import reactor.core.publisher.Mono; 7 | 8 | import java.util.Map; 9 | 10 | public interface ChatModel { 11 | 12 | /** 单次对话 **/ 13 | Mono chat(String content); 14 | 15 | Mono chat(String content, float weight); 16 | 17 | Mono chat(T chatRequest); 18 | 19 | Mono chat(T chatRequest, 
float weight); 20 | 21 | Flux chatStream(String content); 22 | 23 | Flux chatStream(String content, float weight); 24 | 25 | Flux chatStream(T chatRequest); 26 | 27 | Flux chatStream(T chatRequest, float weight); 28 | 29 | Flux chatStream(Map chatRequest); 30 | 31 | /** 连续对话 **/ 32 | Mono chats(String content, String msgUid); 33 | 34 | Mono chats(String content, String msgUid, float weight); 35 | 36 | Mono chats(T chatRequest, String msgUid); 37 | 38 | Mono chats(T chatRequest, String msgUid, float weight); 39 | 40 | Flux chatsStream(String content, String msgUid); 41 | 42 | Flux chatsStream(String content, String msgUid, float weight); 43 | 44 | Flux chatsStream(T chatRequest, String msgUid); 45 | 46 | Flux chatsStream(T chatRequest, String msgUid, float weight); 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/model/EmbeddingModel.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.model; 2 | 3 | import com.gearwenxin.entity.chatmodel.ChatPromptRequest; 4 | import com.gearwenxin.entity.response.PromptResponse; 5 | import reactor.core.publisher.Mono; 6 | 7 | /** 8 | * @author Ge Mingjia 9 | * {@code @date} 2023/7/20 10 | */ 11 | public interface EmbeddingModel { 12 | 13 | /** 14 | * Prompt模板对话 (Get请求 不支持流式返回) 15 | * (非流式) 16 | * 17 | * @param request 请求实体类 18 | * @return ChatResponse 响应实体类 19 | */ 20 | Mono chat(ChatPromptRequest request); 21 | Mono chat(ChatPromptRequest request, float weight); 22 | 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/model/ImageModel.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.model; 2 | 3 | import com.gearwenxin.entity.request.ImageBaseRequest; 4 | import com.gearwenxin.entity.response.ImageResponse; 5 | import reactor.core.publisher.Mono; 6 | 7 | 
/** 8 | * @author Ge Mingjia 9 | * {@code @date} 2023/7/20 10 | */ 11 | public interface ImageModel { 12 | 13 | /** 14 | * 绘图 15 | * 16 | * @param imageBaseRequest 作图参数 17 | * @return ImageResponse 图片响应 18 | */ 19 | Mono chatImage(T imageBaseRequest); 20 | 21 | Mono chatImage(T imageBaseRequest, float weight); 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/model/PromptModel.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.model; 2 | 3 | import com.gearwenxin.entity.chatmodel.ChatPromptRequest; 4 | import com.gearwenxin.entity.response.PromptResponse; 5 | import reactor.core.publisher.Mono; 6 | 7 | /** 8 | * @author Ge Mingjia 9 | * {@code @date} 2023/7/20 10 | */ 11 | public interface PromptModel { 12 | 13 | /** 14 | * Prompt模板对话 (Get请求 不支持流式返回) 15 | * (非流式) 16 | * 17 | * @param request 请求实体类 18 | * @return ChatResponse 响应实体类 19 | */ 20 | Mono chat(ChatPromptRequest request); 21 | Mono chat(ChatPromptRequest request, float weight); 22 | 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/plugin/Weather.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.plugin; 2 | 3 | /** 4 | * @author Ge Mingjia 5 | * {@code @date} 2023/10/22 6 | */ 7 | public class Weather { 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/schedule/BackgroundSaveManager.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.schedule; 2 | 3 | import com.gearwenxin.config.WenXinProperties; 4 | import com.gearwenxin.schedule.entity.BlockingMap; 5 | import com.gearwenxin.schedule.entity.ChatTask; 6 | import com.google.gson.Gson; 7 | import jakarta.annotation.Resource; 8 | import 
org.springframework.scheduling.annotation.EnableScheduling; 9 | import org.springframework.scheduling.annotation.Scheduled; 10 | import org.springframework.stereotype.Component; 11 | 12 | import java.util.List; 13 | 14 | @Component 15 | @EnableScheduling 16 | public class BackgroundSaveManager { 17 | 18 | @Resource 19 | private WenXinProperties wenXinProperties; 20 | 21 | private static final Gson gson = new Gson(); 22 | 23 | private static final TaskQueueManager taskManager = TaskQueueManager.getInstance(); 24 | 25 | /** 26 | * 定时保存任务队列 27 | */ 28 | // @Scheduled(fixedDelay = 2000) 29 | // public void saveTaskQueueThread() { 30 | // BlockingMap> taskMap = taskManager.getTaskMap(); 31 | // String taskMapJson = gson.toJson(taskMap); 32 | //// SaveService.saveTaskQueue(taskMapJson); 33 | // } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/schedule/TaskConsumerLoop.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.schedule; 2 | 3 | import com.gearwenxin.common.Constant; 4 | import com.gearwenxin.common.StatusConst; 5 | import com.gearwenxin.config.ModelConfig; 6 | import com.gearwenxin.entity.chatmodel.ChatPromptRequest; 7 | import com.gearwenxin.entity.response.PromptResponse; 8 | import com.gearwenxin.schedule.entity.ChatTask; 9 | import com.gearwenxin.service.ChatService; 10 | import com.gearwenxin.service.ImageService; 11 | import com.gearwenxin.entity.chatmodel.ChatBaseRequest; 12 | import com.gearwenxin.entity.chatmodel.ChatErnieRequest; 13 | import com.gearwenxin.entity.request.ImageBaseRequest; 14 | import com.gearwenxin.entity.response.ChatResponse; 15 | import com.gearwenxin.entity.response.ImageResponse; 16 | import com.gearwenxin.service.PromptService; 17 | import jakarta.annotation.Resource; 18 | import lombok.Getter; 19 | import lombok.Setter; 20 | import lombok.extern.slf4j.Slf4j; 21 | import 
org.reactivestreams.Publisher; 22 | import org.springframework.stereotype.Component; 23 | import reactor.core.publisher.Mono; 24 | 25 | import java.util.*; 26 | import java.util.concurrent.CompletableFuture; 27 | import java.util.concurrent.CountDownLatch; 28 | import java.util.concurrent.ExecutorService; 29 | 30 | @Slf4j 31 | @Component 32 | public class TaskConsumerLoop { 33 | 34 | public static final String TAG = "TaskConsumerLoop"; 35 | public static final int DEFAULT_QPS = -1; 36 | 37 | @Getter 38 | @Setter 39 | public CountDownLatch testCountDownLatch; 40 | 41 | @Getter 42 | @Setter 43 | private List qpsList = null; 44 | 45 | @Resource 46 | private ChatService chatService; 47 | @Resource 48 | private PromptService promptService; 49 | @Resource 50 | private ImageService imageService; 51 | 52 | private static final Map MODEL_QPS_MAP = new HashMap<>(); 53 | 54 | private final TaskQueueManager taskManager = TaskQueueManager.getInstance(); 55 | 56 | public void start() { 57 | initModelQPSMap(); 58 | Set modelNames = MODEL_QPS_MAP.keySet(); 59 | modelNames.forEach(modelName -> new Thread(() -> { 60 | try { 61 | Thread.currentThread().setName(modelName + "-thread"); 62 | log.info("[{}] {}, model: {}, loop start", TAG, Thread.currentThread().getName(), modelName); 63 | // 消费事件循环处理 64 | while (true) { 65 | eventLoopProcess(modelName); 66 | } 67 | } catch (Exception e) { 68 | log.error("[{}] loop-process error, modelName: {}, thread-{}", TAG, modelName, Thread.currentThread().getName(), e); 69 | if (!Thread.currentThread().isAlive()) { 70 | log.error("[{}] {} is not alive", TAG, Thread.currentThread().getName()); 71 | } 72 | } 73 | }).start()); 74 | } 75 | 76 | public void initModelQPSMap() { 77 | if (qpsList == null || qpsList.isEmpty()) { 78 | return; 79 | } 80 | log.debug("[{}] model qps list: {}", TAG, qpsList); 81 | // 用于检测消费线程是否启动 82 | qpsList.add(Constant.CHECK + " " + DEFAULT_QPS); 83 | qpsList.forEach(s -> { 84 | String[] split = s.split(" "); 85 | 
MODEL_QPS_MAP.put(split[0], Integer.parseInt(split[1])); 86 | }); 87 | log.info("[{}] init model qps map complete", TAG); 88 | } 89 | 90 | private int getModelQPS(String modelName) { 91 | return MODEL_QPS_MAP.getOrDefault(modelName, DEFAULT_QPS); 92 | } 93 | 94 | /** 95 | * 消费事件循环处理 96 | */ 97 | public void eventLoopProcess(String modelName) { 98 | Map currentQPSMap = taskManager.getModelCurrentQPSMap(); 99 | int modelQPS = getModelQPS(modelName); 100 | // 获取到当前的QPS 101 | Integer currentQPS = currentQPSMap.get(modelName); 102 | if (currentQPS == null) { 103 | taskManager.initModelCurrentQPS(modelName); 104 | currentQPS = 0; 105 | } 106 | log.debug("[{}] [{}] current qps: {}", TAG, modelName, currentQPS); 107 | if (currentQPS < modelQPS || modelQPS == DEFAULT_QPS) { 108 | ChatTask task = taskManager.getTask(modelName); 109 | Optional.ofNullable(task).ifPresentOrElse(t -> { 110 | log.debug("[{}] [{}] task: {}", TAG, modelName, t); 111 | submitTask(t); 112 | taskManager.upModelCurrentQPS(modelName); 113 | }, () -> sleep(1500)); 114 | } else { 115 | // TODO: 待优化 116 | // RuntimeToolkit.threadWait(Thread.currentThread()); 117 | sleep(1000); 118 | } 119 | 120 | } 121 | 122 | /** 123 | * 提交任务到不同的线程池 124 | */ 125 | private void submitTask(ChatTask task) { 126 | String taskId = task.getTaskId(); 127 | ModelConfig modelConfig = task.getModelConfig(); 128 | // 根据不同的任务类型,获取不同的线程池实例 129 | ExecutorService executorService = ThreadPoolManager.getInstance(task.getTaskType()); 130 | switch (task.getTaskType()) { 131 | case chat -> { 132 | var future = CompletableFuture.supplyAsync(() -> processChatTask(task, modelConfig), executorService); 133 | taskManager.getChatFutureMap().putAndNotify(taskId, future); 134 | } 135 | case prompt -> { 136 | var future = CompletableFuture.supplyAsync(() -> processPromptTask(task, modelConfig), executorService); 137 | taskManager.getPromptFutureMap().putAndNotify(taskId, future); 138 | } 139 | case image -> { 140 | var future = 
CompletableFuture.supplyAsync(() -> processImageTask(task, modelConfig), executorService); 141 | taskManager.getImageFutureMap().putAndNotify(taskId, future); 142 | } 143 | case embedding -> { 144 | } 145 | case check -> { 146 | // 用于检查消费线程是否启动 147 | StatusConst.SERVICE_STARTED = true; 148 | getTestCountDownLatch().countDown(); 149 | // 销毁当前线程 150 | // Thread.currentThread().interrupt(); 151 | } 152 | default -> log.error("[{}] unknown task type: {}", TAG, task.getTaskType()); 153 | } 154 | } 155 | 156 | private Publisher processChatTask(ChatTask task, ModelConfig modelConfig) { 157 | if (task.isJsonMode()) { 158 | // TODO: 待实现 159 | log.warn("[{}] json mode is not implemented", TAG); 160 | return Mono.empty(); 161 | } else { 162 | // 如果包含"ernie",则使用erni的请求类 163 | ChatBaseRequest taskRequest = modelConfig.getModelName().toLowerCase().contains("ernie") ? 164 | (ChatErnieRequest) task.getTaskRequest() : (ChatBaseRequest) task.getTaskRequest(); 165 | log.debug("[{}] submit task {}, ernie: {}", TAG, task.getTaskId(), taskRequest.getClass() == ChatErnieRequest.class); 166 | return chatService.processChatRequest(taskRequest, task.getMessageId(), task.isStream(), modelConfig); 167 | } 168 | } 169 | 170 | private Mono processPromptTask(ChatTask task, ModelConfig modelConfig) { 171 | log.debug("[{}] submit task {}, type: prompt", TAG, task.getTaskId()); 172 | return promptService.promptProcess((ChatPromptRequest) task.getTaskRequest(), modelConfig); 173 | } 174 | 175 | private Mono processImageTask(ChatTask task, ModelConfig modelConfig) { 176 | log.debug("[{}] submit task {}, type: image", TAG, task.getTaskId()); 177 | return imageService.imageProcess((ImageBaseRequest) task.getTaskRequest(), modelConfig); 178 | } 179 | 180 | private void sleep(long millis) { 181 | try { 182 | Thread.sleep(millis); 183 | } catch (InterruptedException e) { 184 | log.error("[{}] thread sleep error", TAG); 185 | Thread.currentThread().interrupt(); 186 | } 187 | } 188 | 189 | } 190 | 
-------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/schedule/TaskQueueManager.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.schedule; 2 | 3 | import com.gearwenxin.common.RuntimeToolkit; 4 | import com.gearwenxin.entity.enums.ModelType; 5 | import com.gearwenxin.entity.response.ChatResponse; 6 | import com.gearwenxin.entity.response.ImageResponse; 7 | import com.gearwenxin.entity.response.PromptResponse; 8 | import com.gearwenxin.schedule.entity.BlockingMap; 9 | import com.gearwenxin.schedule.entity.ChatTask; 10 | import lombok.Getter; 11 | import lombok.extern.slf4j.Slf4j; 12 | import org.reactivestreams.Publisher; 13 | import reactor.core.publisher.Mono; 14 | 15 | import java.util.*; 16 | import java.util.concurrent.*; 17 | import java.util.concurrent.locks.Lock; 18 | import java.util.concurrent.locks.ReentrantLock; 19 | 20 | /** 21 | * @author GMerge 22 | * {@code @date} 2024/2/28 23 | */ 24 | @Slf4j 25 | public class TaskQueueManager { 26 | 27 | public static final String TAG = "TaskQueueManager"; 28 | @Getter 29 | private final BlockingMap> taskMap = new BlockingMap<>(); 30 | 31 | // 任务数量Map 32 | @Getter 33 | private final Map taskCountMap = new ConcurrentHashMap<>(); 34 | @Getter 35 | private final Map modelCurrentQPSMap = new ConcurrentHashMap<>(); 36 | 37 | // 提交的任务Map 38 | @Getter 39 | private final BlockingMap>> chatFutureMap = new BlockingMap<>(); 40 | @Getter 41 | private final BlockingMap>> imageFutureMap = new BlockingMap<>(); 42 | @Getter 43 | private final BlockingMap>> promptFutureMap = new BlockingMap<>(); 44 | 45 | private final Lock lock = new ReentrantLock(); 46 | private final Map latchMap = new ConcurrentHashMap<>(); 47 | 48 | private volatile static TaskQueueManager instance = null; 49 | 50 | @Getter 51 | private final Map consumerCountDownLatchMap = new ConcurrentHashMap<>(); 52 | 53 | private TaskQueueManager() { 
54 | } 55 | 56 | public static TaskQueueManager getInstance() { 57 | if (instance == null) { 58 | synchronized (TaskQueueManager.class) { 59 | if (instance == null) { 60 | instance = new TaskQueueManager(); 61 | } 62 | } 63 | } 64 | return instance; 65 | } 66 | 67 | public String addTask(ChatTask task) { 68 | String modelName = task.getModelConfig().getModelName(); 69 | String taskId = UUID.randomUUID().toString(); 70 | task.setTaskId(taskId); 71 | task.getModelConfig().setTaskId(taskId); 72 | List chatTaskList = taskMap.get(modelName); 73 | synchronized (this) { 74 | if (chatTaskList == null) { 75 | List list = new CopyOnWriteArrayList<>(); 76 | list.add(task); 77 | initTaskCount(modelName); 78 | taskMap.put(modelName, list); 79 | } else { 80 | chatTaskList.add(task); 81 | upTaskCount(modelName); 82 | taskMap.put(modelName, chatTaskList); 83 | } 84 | } 85 | // RuntimeToolkit.threadNotify(Thread.currentThread()); 86 | log.info("[{}] add task for [{}], count: {}", TAG, modelName, getTaskCount(modelName)); 87 | return taskId; 88 | } 89 | 90 | public synchronized ChatTask getTask(String modelName) { 91 | List list = taskMap.get(modelName); 92 | if (list == null || list.isEmpty()) { 93 | return null; 94 | } 95 | downTaskCount(modelName); 96 | return list.remove(0); 97 | } 98 | 99 | public CompletableFuture> getChatFuture(String taskId) { 100 | return chatFutureMap.getAndAwait(taskId); 101 | } 102 | 103 | public CompletableFuture> getImageFuture(String taskId) { 104 | return imageFutureMap.getAndAwait(taskId); 105 | } 106 | 107 | public CompletableFuture> getPromptFuture(String taskId) { 108 | return promptFutureMap.getAndAwait(taskId); 109 | } 110 | 111 | public Set getModelNames() { 112 | return taskMap.getMap().keySet(); 113 | } 114 | 115 | public int getTaskCount(String modelName) { 116 | return taskCountMap.get(modelName); 117 | } 118 | 119 | public synchronized void initTaskCount(String modelName) { 120 | taskCountMap.put(modelName, 1); 121 | log.debug("[{}] init 
task count for {}", TAG, modelName); 122 | } 123 | 124 | public synchronized void initModelCurrentQPS(String modelName) { 125 | modelCurrentQPSMap.put(modelName, 0); 126 | log.debug("[{}] init model current qps for {}", TAG, modelName); 127 | } 128 | 129 | public synchronized void upTaskCount(String modelName) { 130 | Integer taskCount = taskCountMap.get(modelName); 131 | if (taskCount == null) { 132 | log.error("[{}] task count map not has been init, {}", TAG, modelName); 133 | return; 134 | } 135 | taskCountMap.put(modelName, taskCount + 1); 136 | log.debug("[{}] up task count for {}, number {}", TAG, modelName, taskCount + 1); 137 | } 138 | 139 | public synchronized void upModelCurrentQPS(String modelName) { 140 | Integer currentQPS = modelCurrentQPSMap.get(modelName); 141 | modelCurrentQPSMap.put(modelName, currentQPS + 1); 142 | log.debug("[{}] up model current qps for {}, number {}", TAG, modelName, currentQPS + 1); 143 | } 144 | 145 | public synchronized void downTaskCount(String modelName) { 146 | Integer taskCount = taskCountMap.get(modelName); 147 | if (taskCount == null) { 148 | log.error("[{}] task count map not has been init, {}", TAG, modelName); 149 | return; 150 | } 151 | if (taskCount <= 0) { 152 | log.error("[{}] task count is less than 0, {}", TAG, modelName); 153 | return; 154 | } 155 | taskCountMap.put(modelName, taskCount - 1); 156 | log.debug("[{}] down task count for {}, number {}", TAG, modelName, taskCount - 1); 157 | } 158 | 159 | public synchronized void downModelCurrentQPS(String modelName) { 160 | Integer currentQPS = modelCurrentQPSMap.get(modelName); 161 | if (currentQPS == null || currentQPS <= 0) { 162 | return; 163 | } 164 | modelCurrentQPSMap.put(modelName, currentQPS - 1); 165 | log.debug("[{}] down model current qps for {}, number {}", TAG, modelName, currentQPS - 1); 166 | } 167 | 168 | } 169 | -------------------------------------------------------------------------------- 
/src/main/java/com/gearwenxin/schedule/ThreadPoolManager.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.schedule; 2 | 3 | import com.gearwenxin.entity.enums.ModelType; 4 | import lombok.extern.slf4j.Slf4j; 5 | 6 | import java.util.concurrent.ExecutorService; 7 | import java.util.concurrent.Executors; 8 | 9 | import static com.gearwenxin.entity.enums.ModelType.addTask; 10 | import static com.gearwenxin.entity.enums.ModelType.check; 11 | 12 | @Slf4j 13 | public class ThreadPoolManager { 14 | 15 | public static final String TAG = "ThreadPoolManager"; 16 | private static final int NUM_THREADS = 5; 17 | private static final int TASK_NUM_THREADS = 10; 18 | private static final ExecutorService[] executorServices = new ExecutorService[6]; 19 | 20 | public static ExecutorService getInstance(ModelType type) { 21 | int index = getIndex(type); 22 | if (executorServices[index] == null) { 23 | synchronized (ExecutorService.class) { 24 | if (executorServices[index] == null) { 25 | log.info("[{}] creat new thread pool for [{}]", TAG, type); 26 | if (type == check) { 27 | executorServices[index] = Executors.newFixedThreadPool(1); 28 | } else { 29 | executorServices[index] = Executors.newFixedThreadPool(NUM_THREADS); 30 | } 31 | } 32 | } 33 | } 34 | return executorServices[index]; 35 | } 36 | 37 | private static int getIndex(ModelType type) { 38 | return switch (type) { 39 | case chat -> 0; 40 | case image -> 1; 41 | case prompt -> 2; 42 | case embedding -> 3; 43 | case addTask -> 4; 44 | case check -> 5; 45 | }; 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/schedule/entity/BlockingMap.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.schedule.entity; 2 | 3 | import lombok.Getter; 4 | import lombok.extern.slf4j.Slf4j; 5 | 6 | import java.util.Map; 7 | import 
java.util.concurrent.ConcurrentHashMap; 8 | import java.util.concurrent.locks.Condition; 9 | import java.util.concurrent.locks.Lock; 10 | import java.util.concurrent.locks.ReentrantLock; 11 | 12 | @Getter 13 | @Slf4j 14 | public class BlockingMap { 15 | 16 | private final Map map = new ConcurrentHashMap<>(); 17 | private final Map lockMap = new ConcurrentHashMap<>(); 18 | private final Map conditionMap = new ConcurrentHashMap<>(); 19 | 20 | public V put(K key, V value) { 21 | return map.put(key, value); 22 | } 23 | 24 | public V putAndNotify(K key, V value) { 25 | Lock lock = getLockForKey(key); 26 | lock.lock(); 27 | try { 28 | V previous = map.put(key, value); 29 | Condition condition = conditionMap.remove(key); 30 | if (condition != null) { 31 | condition.signal(); 32 | } 33 | return previous; 34 | } finally { 35 | lock.unlock(); 36 | } 37 | } 38 | 39 | public V getAndAwait(K key) { 40 | Lock lock = getLockForKey(key); 41 | lock.lock(); 42 | try { 43 | while (!map.containsKey(key)) { 44 | Condition condition = getConditionForKey(key); 45 | condition.await(); 46 | } 47 | return map.get(key); 48 | } catch (InterruptedException e) { 49 | log.error("Interrupted while waiting for key: {}", key); 50 | Thread.currentThread().interrupt(); 51 | return null; 52 | } finally { 53 | lock.unlock(); 54 | } 55 | } 56 | 57 | public V get(K key) { 58 | return map.get(key); 59 | } 60 | 61 | public V get(K key, boolean delete) { 62 | if (delete) { 63 | return map.remove(key); 64 | } 65 | return map.get(key); 66 | } 67 | 68 | private Lock getLockForKey(K key) { 69 | lockMap.putIfAbsent(key, new ReentrantLock()); 70 | return lockMap.get(key); 71 | } 72 | 73 | private Condition getConditionForKey(K key) { 74 | conditionMap.putIfAbsent(key, lockMap.get(key).newCondition()); 75 | return conditionMap.get(key); 76 | } 77 | 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/schedule/entity/ChatTask.java: 
/**
 * One queued unit of work for a model: the request, the model configuration it
 * targets, and the flags that control how it is executed.
 *
 * Field order is significant: it defines the parameter order of the
 * Lombok-generated all-args constructor.
 *
 * @author GMerge
 * {@code @date} 2024/2/28
 */
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ChatTask {

    // Unique id of the task; used as the key under which the task's future is published.
    private String taskId;

    // Configuration of the target model (model name, optional access token, ...).
    private ModelConfig modelConfig;

    // Kind of task; decides which executor and which service handle it.
    private ModelType taskType;

    // Request payload; the concrete class depends on taskType and is cast by the consumer.
    private Object taskRequest;

    // presumably a scheduling weight/priority — TODO confirm against the callers
    private Float taskWeight;

    // Conversation id for multi-turn chat; null means a single-turn request without history.
    private String messageId;

    // Whether the response should be streamed.
    private boolean stream;

    // Whether JSON mode is requested.
    // NOTE(review): with @Builder the initializer is not applied unless marked
    // @Builder.Default; harmless here because false is already the boolean default.
    private boolean jsonMode = false;

}
32 | 33 | public Integer get_X_Ratelimit_Remaining_Requests() { 34 | return X_Ratelimit_Remaining_Requests; 35 | } 36 | 37 | public void set_X_Ratelimit_Remaining_Requests(Integer x_Ratelimit_Remaining_Requests) { 38 | X_Ratelimit_Remaining_Requests = x_Ratelimit_Remaining_Requests; 39 | } 40 | 41 | public Integer get_X_Ratelimit_Remaining_Tokens() { 42 | return X_Ratelimit_Remaining_Tokens; 43 | } 44 | 45 | public void set_X_Ratelimit_Remaining_Tokens(Integer x_Ratelimit_Remaining_Tokens) { 46 | X_Ratelimit_Remaining_Tokens = x_Ratelimit_Remaining_Tokens; 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/service/ChatService.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.service; 2 | 3 | import com.gearwenxin.common.*; 4 | import com.gearwenxin.config.WenXinProperties; 5 | import com.gearwenxin.core.RequestManager; 6 | import com.gearwenxin.core.MessageHistoryManager; 7 | import com.gearwenxin.entity.BaseRequest; 8 | import com.gearwenxin.entity.chatmodel.ChatBaseRequest; 9 | import com.gearwenxin.entity.chatmodel.ChatErnieRequest; 10 | import com.gearwenxin.entity.request.ErnieRequest; 11 | import com.gearwenxin.entity.response.ChatResponse; 12 | import com.gearwenxin.entity.Message; 13 | import com.gearwenxin.config.ModelConfig; 14 | import com.gearwenxin.validator.RequestValidator; 15 | import com.gearwenxin.validator.RequestValidatorFactory; 16 | import jakarta.annotation.Resource; 17 | import lombok.extern.slf4j.Slf4j; 18 | import org.reactivestreams.Publisher; 19 | import org.springframework.stereotype.Service; 20 | 21 | import java.util.*; 22 | 23 | /** 24 | * @author Ge Mingjia 25 | * {@code @date} 2023/7/20 26 | */ 27 | @Slf4j 28 | @Service 29 | public class ChatService { 30 | 31 | public static final String SERVICE_TAG = "ChatService"; 32 | 33 | @Resource 34 | private WenXinProperties wenXinProperties; 35 
| 36 | private final RequestManager requestManager = new RequestManager(); 37 | 38 | private static final MessageHistoryManager messageHistoryManager = MessageHistoryManager.getInstance(); 39 | 40 | private String retrieveAccessToken() { 41 | return wenXinProperties.getAccessToken(); 42 | } 43 | 44 | public Publisher processChatRequest(T request, String messageId, 45 | boolean useStreaming, 46 | ModelConfig modelConfig) { 47 | validateRequest(request, modelConfig); 48 | 49 | Map> chatHistoryMap = messageHistoryManager.getChatMessageHistoryMap(); 50 | boolean hasHistory = (messageId != null); 51 | String accessToken = modelConfig.getAccessToken() == null 52 | ? retrieveAccessToken() 53 | : modelConfig.getAccessToken(); 54 | 55 | Object targetRequest; 56 | 57 | if (hasHistory) { 58 | Deque messageHistory = chatHistoryMap.computeIfAbsent( 59 | messageId, key -> new ArrayDeque<>() 60 | ); 61 | targetRequest = prepareRequestWithHistory(messageHistory, useStreaming, request); 62 | Message userMessage = WenXinUtils.buildUserMessage(request.getContent()); 63 | MessageHistoryManager.addMessage(messageHistory, userMessage); 64 | 65 | log.debug("[{}] Streaming: {}, Has History: {}", SERVICE_TAG, useStreaming, true); 66 | 67 | return useStreaming ? 68 | requestManager.historyFluxPost(modelConfig, accessToken, targetRequest, messageHistory, messageId) : 69 | requestManager.historyMonoPost(modelConfig, accessToken, targetRequest, messageHistory, messageId); 70 | } else { 71 | targetRequest = prepareRequestWithoutHistory(useStreaming, request); 72 | } 73 | 74 | log.debug("[{}] Streaming: {}, Has History: {}", SERVICE_TAG, useStreaming, false); 75 | 76 | return useStreaming ? 
77 | requestManager.fluxPost(modelConfig, accessToken, targetRequest, ChatResponse.class, messageId) : 78 | requestManager.monoPost(modelConfig, accessToken, targetRequest, ChatResponse.class, messageId); 79 | } 80 | 81 | public void validateRequest(T request, ModelConfig modelConfig) { 82 | RequestValidator validator = RequestValidatorFactory.getValidator(modelConfig); 83 | validator.validate(request, modelConfig); 84 | } 85 | 86 | public static Object prepareRequestWithHistory(Deque messageHistory, 87 | boolean useStreaming, T request) { 88 | Object targetRequest = null; 89 | 90 | if (request.getClass() == ChatBaseRequest.class) { 91 | BaseRequest.BaseRequestBuilder requestBuilder = ConvertUtils.toBaseRequest(request).stream(useStreaming); 92 | if (messageHistory != null) { 93 | requestBuilder.messages(messageHistory); 94 | } 95 | targetRequest = requestBuilder.build(); 96 | } else if (request.getClass() == ChatErnieRequest.class) { 97 | ErnieRequest.ErnieRequestBuilder requestBuilder = ConvertUtils.toErnieRequest( 98 | (ChatErnieRequest) request).stream(useStreaming); 99 | if (messageHistory != null) { 100 | requestBuilder.messages(messageHistory); 101 | } 102 | targetRequest = requestBuilder.build(); 103 | } 104 | 105 | return targetRequest; 106 | } 107 | 108 | public static Object prepareRequestWithoutHistory( 109 | boolean useStreaming, T request) { 110 | 111 | return prepareRequestWithHistory(null, useStreaming, request); 112 | } 113 | } -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/service/EmbeddingService.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.service; 2 | 3 | import com.gearwenxin.config.WenXinProperties; 4 | import jakarta.annotation.Resource; 5 | import lombok.extern.slf4j.Slf4j; 6 | import org.springframework.stereotype.Service; 7 | 8 | /** 9 | * @author Ge Mingjia 10 | * {@code @date} 2023/7/20 11 | */ 12 | @Slf4j 
13 | @Service 14 | public class EmbeddingService { 15 | 16 | @Resource 17 | private WenXinProperties wenXinProperties; 18 | 19 | private String getAccessToken() { 20 | return wenXinProperties.getAccessToken(); 21 | } 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/service/ImageService.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.service; 2 | 3 | import com.gearwenxin.config.ModelConfig; 4 | import com.gearwenxin.config.WenXinProperties; 5 | import com.gearwenxin.core.RequestManager; 6 | import com.gearwenxin.entity.request.ImageBaseRequest; 7 | import com.gearwenxin.entity.response.ImageResponse; 8 | import jakarta.annotation.Resource; 9 | import lombok.extern.slf4j.Slf4j; 10 | import org.springframework.stereotype.Service; 11 | import reactor.core.publisher.Mono; 12 | 13 | import static com.gearwenxin.common.WenXinUtils.assertNotNull; 14 | 15 | /** 16 | * @author Ge Mingjia 17 | * {@code @date} 2023/7/20 18 | */ 19 | @Slf4j 20 | @Service 21 | public class ImageService { 22 | 23 | private final RequestManager requestManager = new RequestManager(); 24 | 25 | @Resource 26 | private WenXinProperties wenXinProperties; 27 | 28 | private String getAccessToken() { 29 | return wenXinProperties.getAccessToken(); 30 | } 31 | 32 | public Mono imageProcess(ImageBaseRequest imageBaseRequest, ModelConfig config) { 33 | assertNotNull(imageBaseRequest, "imageBaseRequest is null"); 34 | imageBaseRequest.validSelf(); 35 | 36 | return requestManager.monoPost(config, getAccessToken(), imageBaseRequest, ImageResponse.class); 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/com/gearwenxin/service/MessageService.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.service; 2 | 3 | import 
com.gearwenxin.entity.Message;

import java.util.Deque;

/**
 * Contract for reading and appending chat history messages.
 */
public interface MessageService {

    /**
     * Returns the message history registered under {@code id}.
     *
     * @param id identifier of the history to look up
     * @return the messages stored for that id
     */
    Deque<Message> getHistoryMessages(String id);

    /**
     * Appends {@code message} to the history identified by {@code id}.
     *
     * @param id      identifier of the history to append to
     * @param message message to append
     */
    void addHistoryMessage(String id, Message message);

    /**
     * Appends {@code message} to the supplied history deque.
     *
     * @param messagesHistory history to append to
     * @param message         message to append
     */
    void addHistoryMessage(Deque<Message> messagesHistory, Message message);

}
--------------------------------------------------------------------------------
/src/main/java/com/gearwenxin/service/PromptService.java:
--------------------------------------------------------------------------------
package com.gearwenxin.service;

import com.gearwenxin.config.ModelConfig;
import com.gearwenxin.core.RequestManager;
import com.gearwenxin.common.ConvertUtils;
import com.gearwenxin.common.ErrorCode;
import com.gearwenxin.config.WenXinProperties;
import com.gearwenxin.exception.WenXinException;

import com.gearwenxin.entity.chatmodel.ChatPromptRequest;
import com.gearwenxin.entity.request.PromptRequest;
import com.gearwenxin.entity.response.PromptResponse;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;

import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Mono;

import java.util.Map;

/**
 * Spring service that validates a prompt request and issues it as a GET call
 * through {@link RequestManager}.
 *
 * @author Ge Mingjia
 * {@code @date} 2023/7/20
 */
@Slf4j
@Service
public class PromptService {

    private final RequestManager requestManager = new RequestManager();

    @Resource
    private WenXinProperties wenXinProperties;

    /** Reads the access token from the injected {@link WenXinProperties}. */
    private String getAccessToken() {
        return wenXinProperties.getAccessToken();
    }

    /**
     * Validates the request, converts it to a {@link PromptRequest}, adds the request id
     * to the parameter map under the key {@code "id"} and performs the GET call.
     *
     * @param chatPromptRequest request carrying the id and the parameter map
     * @param config            model configuration passed on to the request manager
     * @return the response publisher produced by {@code RequestManager.monoGet}
     * @throws WenXinException with {@link ErrorCode#PARAMS_ERROR} when the request is null,
     *                         its id is null, or its parameter map is null or empty
     */
    public Mono<PromptResponse> promptProcess(ChatPromptRequest chatPromptRequest, ModelConfig config) {
        if (chatPromptRequest == null
                || chatPromptRequest.getId() == null
                || CollectionUtils.isEmpty(chatPromptRequest.getParamMap())) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "chatPromptRequest is null or id is null or paramMap is null");
        }
        PromptRequest promptRequest = ConvertUtils.toPromptRequest(chatPromptRequest);
        // `var` keeps whatever type arguments getParamMap() declares; they were lost from this listing.
        // NOTE(review): this mutates the map held by promptRequest; whether that is also the
        // caller's map depends on ConvertUtils.toPromptRequest - confirm.
        var paramMap = promptRequest.getParamMap();
        paramMap.put("id", promptRequest.getId());

        return requestManager.monoGet(config, getAccessToken(), paramMap, PromptResponse.class);
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/gearwenxin/service/WinXinActions.java:
--------------------------------------------------------------------------------
package com.gearwenxin.service;

import com.gearwenxin.entity.Message;

import java.util.Deque;
import java.util.Map;

/**
 * Operations for seeding, exporting and interrupting chat conversations.
 */
public interface WinXinActions {

    /**
     * Replaces the whole conversation-history map.
     *
     * @param map histories keyed by message uid
     */
    void initMessageMap(Map<String, Deque<Message>> map);

    /**
     * Stores {@code messageDeque} as the history of {@code msgUid}.
     *
     * @param msgUid       conversation identifier
     * @param messageDeque history to store for that conversation
     */
    void initMessages(String msgUid, Deque<Message> messageDeque);

    /**
     * Serializes the history of one conversation.
     *
     * @param msgUid conversation identifier
     * @return the serialized history, or {@code null} when none exists
     */
    String exportMessages(String msgUid);

    /**
     * Serializes every stored conversation history.
     *
     * @return the serialized histories, or {@code null} when no history map exists
     */
    String exportAllMessages();

    /**
     * Requests interruption of the conversation identified by {@code msgUid}.
     *
     * @param msgUid conversation identifier
     * @return implementation-defined flag; see the implementing class
     */
    boolean interpretChat(String msgUid);

}
--------------------------------------------------------------------------------
/src/main/java/com/gearwenxin/service/impl/WinXinActionsImpl.java:
--------------------------------------------------------------------------------
package com.gearwenxin.service.impl;

import com.gearwenxin.common.Constant;
import com.gearwenxin.core.MessageHistoryManager;
import com.gearwenxin.entity.Message;
import com.gearwenxin.service.WinXinActions;
import com.google.gson.Gson;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.Deque;
import java.util.Map;

/**
 * {@link WinXinActions} implementation backed by the shared {@link MessageHistoryManager}
 * instance and {@code Constant.INTERRUPT_MAP}; exports are serialized with Gson.
 */
@Slf4j
@Service
public class WinXinActionsImpl implements WinXinActions {

    public static final String TAG = "WinXinActions";

    private static final MessageHistoryManager messageHistoryManager = MessageHistoryManager.getInstance();

    public static final Gson gson = new Gson();

    /** Replaces the history map held by the history manager. */
    @Override
    public void initMessageMap(Map<String, Deque<Message>> map) {
        messageHistoryManager.setChatMessageHistoryMap(map);
    }

    /** Stores {@code messageDeque} under {@code msgUid} in the history manager's map. */
    @Override
    public void initMessages(String msgUid, Deque<Message> messageDeque) {
        messageHistoryManager.getChatMessageHistoryMap().put(msgUid, messageDeque);
    }

    /**
     * Serializes the history stored under {@code msgUid} to JSON.
     *
     * @return the JSON text, or {@code null} when there is no history map or no entry for the uid
     */
    @Override
    public String exportMessages(String msgUid) {
        Map<String, Deque<Message>> historyMap = messageHistoryManager.getChatMessageHistoryMap();
        // Same null guard as exportAllMessages(): without it a missing map would throw an NPE here.
        if (historyMap == null) {
            return null;
        }
        Deque<Message> messages = historyMap.get(msgUid);
        if (messages == null) {
            return null;
        }
        log.debug("[{}] export messages, msgUid: {}", TAG, msgUid);
        return gson.toJson(messages);
    }

    /**
     * Serializes the complete history map to JSON.
     *
     * @return the JSON text, or {@code null} when the history manager holds no map
     */
    @Override
    public String exportAllMessages() {
        Map<String, Deque<Message>> chatMessageHistoryMap = messageHistoryManager.getChatMessageHistoryMap();
        if (chatMessageHistoryMap == null) {
            return null;
        }
        log.debug("[{}] export all messages", TAG);
        return gson.toJson(chatMessageHistoryMap);
    }

    /**
     * Sets the interrupt flag for {@code msgUid} in {@code Constant.INTERRUPT_MAP}.
     *
     * @return {@code true} only when the value previously stored for the uid was already
     *         {@code true}; the result is derived from the value returned by {@code put}.
     *         NOTE(review): assuming INTERRUPT_MAP follows the java.util.Map contract, the first
     *         interrupt request for a conversation therefore reports {@code false} - confirm
     *         that callers expect this.
     */
    @Override
    public boolean interpretChat(String msgUid) {
        return Boolean.TRUE.equals(Constant.INTERRUPT_MAP.put(msgUid, true));
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/gearwenxin/subscriber/CommonSubscriber.java:
--------------------------------------------------------------------------------
package com.gearwenxin.subscriber;

import com.gearwenxin.common.Constant;
import com.gearwenxin.config.ModelConfig;
import com.gearwenxin.core.MessageHistoryManager;
import com.gearwenxin.entity.Message;
import com.gearwenxin.entity.response.ChatResponse;
import com.gearwenxin.schedule.TaskQueueManager;
import com.gearwenxin.service.MessageService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.Disposable;
import reactor.core.publisher.FluxSink;

import java.util.Deque;
import java.util.Optional;
import java.util.StringJoiner;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.gearwenxin.common.WenXinUtils.assertNotNull;
import static com.gearwenxin.common.WenXinUtils.buildAssistantMessage;
import static com.gearwenxin.core.MessageHistoryManager.validateMessageRule;

/**
 * Subscriber that forwards streamed {@link ChatResponse} chunks to a {@link FluxSink},
 * accumulates their {@code result} text and, on completion, appends the joined text to the
 * message history as an assistant message. The model's current-QPS counter is decremented
 * when the stream ends or the subscriber is disposed.
 *
 * @author Ge Mingjia
 * {@code @date} 2023/7/20
 */
@Slf4j
public class CommonSubscriber implements Subscriber<ChatResponse>, Disposable {

    private final TaskQueueManager taskManager = TaskQueueManager.getInstance();

    private final FluxSink<ChatResponse> emitter;
    private Subscription subscription;
    private final Deque<Message> messagesHistory;
    // NOTE(review): never read in this class; presumably also never injected, since instances
    // are built through the constructor below rather than by the container - confirm.
    @Resource
    private MessageService messageService;
    private final ModelConfig modelConfig;
    private final String msgUid;

    private final StringJoiner joiner = new StringJoiner("");

    /** Set once by {@link #dispose()}; backs {@link #isDisposed()}. */
    private final AtomicBoolean disposed = new AtomicBoolean(false);

    /** Ensures the QPS counter is decremented at most once per subscriber. */
    private final AtomicBoolean qpsReleased = new AtomicBoolean(false);

    /**
     * @param emitter         sink that receives every chunk and the terminal signal
     * @param messagesHistory history the assistant reply is appended to on completion
     * @param modelConfig     configuration of the model serving this request
     * @param msgUid          conversation identifier used to look up the interrupt flag
     */
    public CommonSubscriber(FluxSink<ChatResponse> emitter, Deque<Message> messagesHistory,
                            ModelConfig modelConfig, String msgUid) {
        this.emitter = emitter;
        this.messagesHistory = messagesHistory;
        this.modelConfig = modelConfig;
        this.msgUid = msgUid;
    }

    /** Stores the subscription and requests the first element. */
    @Override
    public void onSubscribe(Subscription subscription) {
        this.subscription = subscription;
        subscription.request(1);
        log.debug("onSubscribe");
    }

    /**
     * Handles one streamed chunk: stops when disposed or interrupted, otherwise records the
     * chunk's result text, requests the next element and forwards the chunk to the sink.
     */
    @Override
    public void onNext(ChatResponse response) {
        if (isDisposed()) {
            return;
        }
        // Interrupt the conversation. Boolean.TRUE.equals avoids the NPE that auto-unboxing
        // would raise when no flag has been stored for this uid.
        if (Boolean.TRUE.equals(Constant.INTERRUPT_MAP.get(msgUid))) {
            log.debug("interrupted");
            dispose();
            return;
        }

        assertNotNull(response, "chat response is null");

        log.debug("onNext...");

        Optional.ofNullable(response.getResult()).ifPresent(joiner::add);
        subscription.request(1);
        emitter.next(response);
    }

    /**
     * Releases the QPS slot, re-validates the history and, unless disposed, propagates the
     * error to the sink.
     */
    @Override
    public void onError(Throwable throwable) {
        releaseQps();
        validateMessageRule(messagesHistory);
        if (isDisposed()) {
            return;
        }
        log.debug("onError");
        emitter.error(throwable);
    }

    /**
     * Releases the QPS slot and, unless disposed, appends the accumulated non-blank text to the
     * history as an assistant message before completing the sink.
     */
    @Override
    public void onComplete() {
        releaseQps();
        if (isDisposed()) {
            return;
        }
        log.debug("onComplete");
        String result = joiner.toString();
        if (StringUtils.isNotBlank(result)) {
            Message message = buildAssistantMessage(result);
            MessageHistoryManager.addMessage(messagesHistory, message);
            log.debug("add message onComplete");
        }
        emitter.complete();
    }

    /**
     * Marks the subscriber as disposed, releases the QPS slot and cancels the subscription.
     * Repeated calls are no-ops.
     */
    @Override
    public void dispose() {
        if (!disposed.compareAndSet(false, true)) {
            return;
        }
        releaseQps();
        log.debug("dispose");
        // The subscription is still null when dispose() runs before onSubscribe().
        if (subscription != null) {
            subscription.cancel();
        }
    }

    /**
     * Reports whether {@link #dispose()} has been called. The {@link Disposable} default always
     * returns {@code false}, which would leave every disposed guard in this class dead.
     */
    @Override
    public boolean isDisposed() {
        return disposed.get();
    }

    /** Decrements the model's current-QPS counter, at most once for this subscriber. */
    private void releaseQps() {
        if (qpsReleased.compareAndSet(false, true)) {
            taskManager.downModelCurrentQPS(modelConfig.getModelName());
        }
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/gearwenxin/validator/ChatBaseRequestValidator.java:
--------------------------------------------------------------------------------
package com.gearwenxin.validator;

import com.gearwenxin.common.ErrorCode;
import com.gearwenxin.entity.chatmodel.ChatBaseRequest;
import com.gearwenxin.exception.WenXinException;
import com.gearwenxin.config.ModelConfig;
import org.apache.commons.lang3.StringUtils;

/**
 * Validator for {@link ChatBaseRequest}: checks that the content is present and not longer
 * than the model's configured maximum.
 */
public class ChatBaseRequestValidator implements RequestValidator {

    /**
     * @param request request to validate; cast to {@link ChatBaseRequest}
     * @param config  model configuration supplying the maximum content length
     * @throws WenXinException with {@link ErrorCode#PARAMS_ERROR} when the content is blank or too long
     */
    @Override
    public <T> void validate(T request, ModelConfig config) {
        ChatBaseRequest chatBaseRequest = (ChatBaseRequest) request;
        // content must not be blank
        if (StringUtils.isBlank(chatBaseRequest.getContent())) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "content cannot be empty");
        }
        // a single content must not exceed the configured length
        if (chatBaseRequest.getContent().length() > config.getContentMaxLength()) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "content's length cannot be more than " + config.getContentMaxLength());
        }
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/gearwenxin/validator/ChatErnieRequestValidator.java:
--------------------------------------------------------------------------------
package com.gearwenxin.validator;

import com.gearwenxin.common.ErrorCode;
import com.gearwenxin.entity.chatmodel.ChatErnieRequest;
import com.gearwenxin.exception.WenXinException;
import com.gearwenxin.config.ModelConfig;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import static com.gearwenxin.common.Constant.MAX_SYSTEM_LENGTH;

/**
 * Validator for {@link ChatErnieRequest}: checks content, sampling parameters
 * (temperature, topP, penaltyScore) and the system / functions combination.
 */
@Slf4j
public class ChatErnieRequestValidator implements RequestValidator {

    /**
     * @param request request to validate; cast to {@link ChatErnieRequest}
     * @param config  model configuration supplying the maximum content length
     * @throws WenXinException with {@link ErrorCode#PARAMS_ERROR} on the first rule that fails
     */
    @Override
    public <T> void validate(T request, ModelConfig config) {
        ChatErnieRequest chatErnieRequest = (ChatErnieRequest) request;
        // content must not be blank
        if (StringUtils.isBlank(chatErnieRequest.getContent())) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "content cannot be empty");
        }
        // a single content must not exceed the configured length
        if (chatErnieRequest.getContent().length() > config.getContentMaxLength()) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "content's length cannot be more than " + config.getContentMaxLength());
        }
        // temperature and topP should not both be set; this is only logged, not rejected
        if (chatErnieRequest.getTemperature() != null && chatErnieRequest.getTopP() != null) {
            log.warn("Temperature and topP cannot both have value");
        }
        // temperature range
        if (chatErnieRequest.getTemperature() != null
                && (chatErnieRequest.getTemperature() <= 0 || chatErnieRequest.getTemperature() > 1.0)) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "temperature should be in (0, 1]");
        }
        // topP range
        if (chatErnieRequest.getTopP() != null
                && (chatErnieRequest.getTopP() < 0 || chatErnieRequest.getTopP() > 1.0)) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "topP should be in [0, 1]");
        }
        // penaltyScore range. The guard must test penaltyScore itself: guarding on temperature
        // skipped the check when only penaltyScore was set and unboxed a null penaltyScore
        // when only temperature was set.
        if (chatErnieRequest.getPenaltyScore() != null
                && (chatErnieRequest.getPenaltyScore() < 1.0 || chatErnieRequest.getPenaltyScore() > 2.0)) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "penaltyScore should be in [1, 2]");
        }
        // system and function call are mutually exclusive
        if (StringUtils.isNotBlank(chatErnieRequest.getSystem()) && chatErnieRequest.getFunctions() != null) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "if 'function' not null, the 'system' must be null");
        }
        // system length; the message reports the same constant the check uses
        if (chatErnieRequest.getSystem() != null && chatErnieRequest.getSystem().length() > MAX_SYSTEM_LENGTH) {
            throw new WenXinException(ErrorCode.PARAMS_ERROR, "system's length cannot be more than " + MAX_SYSTEM_LENGTH);
        }
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/gearwenxin/validator/RequestValidator.java:
--------------------------------------------------------------------------------
package com.gearwenxin.validator;

import com.gearwenxin.config.ModelConfig;

/**
 * Validates a chat request against a model configuration.
 */
public interface RequestValidator {

    /**
     * Validates {@code request}; implementations throw when a rule is violated.
     *
     * @param request request object to validate
     * @param config  configuration of the target model
     * @param <T>     request type
     */
    <T> void validate(T request, ModelConfig config);
}
--------------------------------------------------------------------------------
/src/main/java/com/gearwenxin/validator/RequestValidatorFactory.java:
--------------------------------------------------------------------------------
package com.gearwenxin.validator;

import com.gearwenxin.config.ModelConfig;

import java.util.Locale;

/**
 * Chooses the {@link RequestValidator} for a model based on its name.
 */
public class RequestValidatorFactory {

    /**
     * Returns a {@link ChatErnieRequestValidator} when the model name contains "ernie"
     * (case-insensitive), otherwise a {@link ChatBaseRequestValidator}.
     *
     * @param config model configuration whose model name is inspected
     * @return a new validator instance
     */
    public static RequestValidator getValidator(ModelConfig config) {
        String modelName = config.getModelName();
        // Locale.ROOT keeps the match stable under locales with special casing rules
        // (for example Turkish, where "I" lower-cases to a dotless "ı").
        if (modelName != null && modelName.toLowerCase(Locale.ROOT).contains("ernie")) {
            return new ChatErnieRequestValidator();
        }
        return new ChatBaseRequestValidator();
    }
}
-------------------------------------------------------------------------------- /src/main/resources/META-INF/spring.factories: -------------------------------------------------------------------------------- 1 | org.springframework.boot.autoconfigure.EnableAutoConfiguration=com.gearwenxin.config.GearWenXinConfig -------------------------------------------------------------------------------- /src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports: -------------------------------------------------------------------------------- 1 | com.gearwenxin.config.GearWenXinConfig -------------------------------------------------------------------------------- /src/main/resources/application.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciimina/qianfan-starter/53f7ab383c15e3b88e1bc9ce2c86b72af1bff39a/src/main/resources/application.yaml -------------------------------------------------------------------------------- /src/test/java/com/gearwenxin/client/erniebot/ErnieBotClientTest.java: -------------------------------------------------------------------------------- 1 | package com.gearwenxin.client.erniebot; 2 | 3 | /** Empty placeholder test class; it declares no test methods yet. 4 | * @author Ge Mingjia 5 | * @date 2023/7/21 6 | */ 7 | class ErnieBotClientTest { 8 | 9 | } --------------------------------------------------------------------------------