├── .gitignore ├── LICENSE ├── README.md ├── img ├── 1.jpg ├── 1.png ├── 2.jpg ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── datacall.jpg └── zccbbg.jpg ├── pom.xml └── src ├── main ├── java │ └── com │ │ └── cyl │ │ └── ctrbt │ │ ├── ChatgptRobotBackApplication.java │ │ ├── controller │ │ └── DingTalkController.java │ │ ├── openai │ │ ├── ChatGPT.java │ │ ├── ChatGPTStream.java │ │ ├── ChatGPTStrreamUtil.java │ │ ├── ChatGPTUtil.java │ │ ├── api │ │ │ └── Api.java │ │ ├── entity │ │ │ ├── BaseResponse.java │ │ │ ├── billing │ │ │ │ ├── CreditGrantsResponse.java │ │ │ │ ├── Datum.java │ │ │ │ ├── Grants.java │ │ │ │ └── Usage.java │ │ │ └── chat │ │ │ │ ├── ChatChoice.java │ │ │ │ ├── ChatCompletion.java │ │ │ │ ├── ChatCompletionResponse.java │ │ │ │ └── Message.java │ │ ├── exception │ │ │ └── ChatException.java │ │ └── listener │ │ │ ├── AbstractStreamListener.java │ │ │ ├── ConsoleStreamListener.java │ │ │ └── SseStreamListener.java │ │ ├── util │ │ ├── ChatContextHolder.java │ │ ├── Proxys.java │ │ └── SseHelper.java │ │ └── websocket │ │ ├── MyWebSocketInterceptor.java │ │ ├── MyWebsocketHandler.java │ │ ├── WebSocketConfiguration.java │ │ ├── WebSocketServer.java │ │ └── bean │ │ └── WebSocketBean.java └── resources │ └── application.yml └── test ├── java └── com │ └── cyl │ └── ctrbt │ ├── ChatgptRobotBackApplicationTests.java │ ├── openai │ └── OpenAiTest.java │ └── websocket │ ├── MyWebSocketClient.java │ └── WebSocketTest.java └── resources └── application.yml /.gitignore: -------------------------------------------------------------------------------- 1 | HELP.md 2 | target/ 3 | !.mvn/wrapper/maven-wrapper.jar 4 | !**/src/main/**/target/ 5 | !**/src/test/**/target/ 6 | 7 | ### STS ### 8 | .apt_generated 9 | .classpath 10 | .factorypath 11 | .project 12 | .settings 13 | .springBeans 14 | .sts4-cache 15 | 16 | ### IntelliJ IDEA ### 17 | .idea 18 | *.iws 19 | *.iml 20 | *.ipr 21 | 22 | ### NetBeans ### 23 | /nbproject/private/ 
24 | /nbbuild/ 25 | /dist/ 26 | /nbdist/ 27 | /.nb-gradle/ 28 | build/ 29 | !**/src/main/**/build/ 30 | !**/src/test/**/build/ 31 | 32 | ### VS Code ### 33 | .vscode/ 34 | application-dev.yml 35 | .mvn/ 36 | mvnw 37 | mvnw.cmd -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 字节叔叔 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 介绍 2 | - 本项目是一个基于Spring Boot的后端服务,用于实时接收ChatGPT的消息,并通过WebSocket的方式实时推送给前端。 3 | - 本项目还可以帮助你将GPT机器人集成到钉钉群聊中,通过@机器人进行聊天交互。 4 | ### 前端页面截图: 5 | - ![pc端](img/1.jpg) 6 | - ![手机端](img/2.jpg) 7 | ### 钉钉使用截图: 8 | - ![写代码](img/5.png) 9 | - ![入职介绍](img/6.png) 10 | - ![放臭屁](img/7.png) 11 | - ![迟到](img/8.png) 12 | 13 | ## 功能特性 14 | 15 | | 功能 | 特性 | 16 | | :---------: | :------: | 17 | | GPT 3.5 | 支持 | 18 | | GPT 4.0 | 支持 | 19 | | GPT 4.0-32k | 支持 | 20 | | 流式对话 | 支持 | 21 | | 阻塞式对话 | 支持 | 22 | | 上下文 | 支持 | 23 | | 计算Token | 即将支持 | 24 | | 多KEY轮询 | 支持 | 25 | | 代理 | 支持 | 26 | | 反向代理 | 支持 | 27 | 28 | 29 | ## 使用前提 30 | * 拥有OpenAI账号,并且已创建好`api_key`,注册相关事项可以参考[此文章](https://juejin.cn/post/7173447848292253704)。访问[这里](https://beta.openai.com/account/api-keys)申请个人密钥。 31 | * 在钉钉开发者后台创建机器人,并配置应用程序回调。 32 | 33 | 34 | ## 使用教程 35 | 36 | ### 钉钉创建机器人 37 | 38 | 创建步骤参考文档:[企业内部开发机器人](https://open.dingtalk.com/document/robots/enterprise-created-chatbot),或者根据如下步骤进行配置。 39 | 40 | 1. 创建机器人。 41 | ![image_20221209_163616](img/1.png) 42 | 43 | > `📢 注意:`目前创建机器人时,若名字为`chatgpt`可能会被钉钉限制,请使用其他名字命名。 44 | 45 | 步骤比较简单,这里就不赘述了。 46 | 47 | 2. 配置机器人回调接口。 48 | ![image_20221209_163652](img/2.png) 49 | 50 | 创建完毕之后,点击机器人开发管理,然后配置将要部署服务的服务器出口IP,以及将要给服务配置的域名。 51 | 52 | 3. 发布机器人。 53 | ![image_20221209_163709](img/3.png) 54 | 55 | 点击版本管理与发布,然后点击上线,这时就能在钉钉群聊中添加这个机器人了。 56 | 57 | 4.
群聊添加机器人。 58 | 59 | ![image_20221209_163724](img/4.png) 60 | 61 | ## 前端项目地址 62 | * github: https://github.com/zccbbg/chatgpt-vue 63 | * gitee: https://gitee.com/zccbbg/chatgpt-vue 64 | ## 关于我们 65 | * 开发团队成立5年,我们专注于前端开发与后端架构,有一颗热爱开源的心,致力于打造企业级的通用产品设计UI体系,让项目更直观、更高效、更简单。未来我们将持续关注UI交互,持续推出高质量的交互产品。 66 | * 这五年我主要作为ISV,开发对接淘宝、拼多多、抖音、美团等平台的订单处理应用,日处理订单300万条。因为要熟悉业务,我也开过淘宝和拼多多店铺,并运营过一个网易严选的品牌。我们的公众号会陆续更新一些我一边撸代码一边做客服的经历,也会更新一些我的读书笔记,以及编程、创业、生活中踩坑的文章。另外还会开放一些米哈游、博世、企查查、同程、阿里、京东、拼多多等大厂、中厂或外企的内推岗位! 67 | ## 技术支持 68 | * 关注公众号“编写美好前程”,回复“支持”,即可加入群聊。
69 | * -------------------------------------------------------------------------------- /img/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/1.jpg -------------------------------------------------------------------------------- /img/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/1.png -------------------------------------------------------------------------------- /img/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/2.jpg -------------------------------------------------------------------------------- /img/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/2.png -------------------------------------------------------------------------------- /img/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/3.png -------------------------------------------------------------------------------- /img/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/4.png -------------------------------------------------------------------------------- /img/5.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/5.png -------------------------------------------------------------------------------- /img/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/6.png -------------------------------------------------------------------------------- /img/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/7.png -------------------------------------------------------------------------------- /img/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/8.png -------------------------------------------------------------------------------- /img/datacall.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/datacall.jpg -------------------------------------------------------------------------------- /img/zccbbg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zccbbg/chatgpt-springboot-service/80ceb9a5d7dcd0904cf8563390fe258761b93936/img/zccbbg.jpg -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 4.0.0 5 | 6 | org.springframework.boot 7 | spring-boot-starter-parent 8 | 2.2.13.RELEASE 9 | 10 | 11 | com.cyl 12 | chatgpt-springboot-service 13 | 0.0.1-SNAPSHOT 14 | chatgpt-springboot-service 15 | 
chatgpt-springboot-service 16 | 17 | 1.8 18 | 2.9.0 19 | ${java.version} 20 | ${java.version} 21 | 22 | 23 | 24 | org.projectlombok 25 | lombok 26 | 1.18.26 27 | compile 28 | true 29 | 30 | 31 | 32 | com.aliyun 33 | dingtalk 34 | 1.5.39 35 | 36 | 37 | com.squareup.retrofit2 38 | retrofit 39 | ${retrofit2.version} 40 | 41 | 42 | com.squareup.retrofit2 43 | converter-jackson 44 | ${retrofit2.version} 45 | 46 | 47 | com.squareup.retrofit2 48 | adapter-rxjava2 49 | ${retrofit2.version} 50 | 51 | 52 | com.squareup.okhttp3 53 | okhttp-sse 54 | 3.14.9 55 | 56 | 57 | com.squareup.okhttp3 58 | logging-interceptor 59 | 3.14.9 60 | 61 | 62 | com.alibaba 63 | fastjson 64 | 2.0.26 65 | 66 | 67 | com.aliyun 68 | alibaba-dingtalk-service-sdk 69 | 2.0.0 70 | 71 | 72 | 73 | org.springframework.boot 74 | spring-boot-starter-websocket 75 | 76 | 77 | org.springframework.boot 78 | spring-boot-starter-web 79 | 80 | 81 | 82 | org.projectlombok 83 | lombok 84 | true 85 | 86 | 87 | org.springframework.boot 88 | spring-boot-starter-test 89 | test 90 | 91 | 92 | cn.hutool 93 | hutool-all 94 | 5.8.12 95 | 96 | 97 | 98 | org.java-websocket 99 | Java-WebSocket 100 | 1.5.1 101 | test 102 | 103 | 104 | 105 | 106 | 107 | 108 | org.springframework.boot 109 | spring-boot-maven-plugin 110 | 111 | 112 | 113 | org.projectlombok 114 | lombok 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/ChatgptRobotBackApplication.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt; 2 | 3 | import org.springframework.boot.SpringApplication; 4 | import org.springframework.boot.autoconfigure.SpringBootApplication; 5 | 6 | @SpringBootApplication 7 | public class ChatgptRobotBackApplication { 8 | 9 | public static void main(String[] args) { 10 | SpringApplication.run(ChatgptRobotBackApplication.class, args); 11 | } 12 | 13 | } 14 | 
-------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/controller/DingTalkController.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.controller; 2 | 3 | import cn.hutool.json.JSONObject; 4 | import cn.hutool.json.JSONUtil; 5 | import com.cyl.ctrbt.openai.ChatGPTUtil; 6 | import com.cyl.ctrbt.openai.entity.chat.Message; 7 | import com.dingtalk.api.DefaultDingTalkClient; 8 | import com.dingtalk.api.DingTalkClient; 9 | import com.dingtalk.api.request.OapiRobotSendRequest; 10 | import com.dingtalk.api.response.OapiRobotSendResponse; 11 | import com.taobao.api.ApiException; 12 | import org.springframework.beans.factory.annotation.Autowired; 13 | import org.springframework.web.bind.annotation.*; 14 | 15 | import java.util.stream.Collectors; 16 | 17 | @RequestMapping("/ding-talk") 18 | @RestController 19 | public class DingTalkController { 20 | 21 | @Autowired 22 | private ChatGPTUtil chatGPTUtil; 23 | 24 | @RequestMapping("/receive") 25 | public String helloRobots(@RequestBody(required = false) JSONObject json) { 26 | System.out.println(JSONUtil.toJsonStr(json)); 27 | String content = json.getJSONObject("text").get("content").toString().replaceAll(" ", ""); 28 | String sessionWebhook = json.getStr("sessionWebhook"); 29 | DingTalkClient client = new DefaultDingTalkClient(sessionWebhook); 30 | if ("text".equals(json.getStr("msgtype"))) { 31 | text(client,content); 32 | } 33 | return null; 34 | } 35 | 36 | private void text(DingTalkClient client,String content) { 37 | try { 38 | OapiRobotSendRequest request = new OapiRobotSendRequest(); 39 | request.setMsgtype("text"); 40 | OapiRobotSendRequest.Text text = new OapiRobotSendRequest.Text(); 41 | Message message = chatGPTUtil.chat(content, "dingtalk"); 42 | text.setContent(message.getContent()); 43 | request.setText(text); 44 | OapiRobotSendResponse response = client.execute(request); 45 | 
System.out.println(response.getBody()); 46 | } catch (ApiException e) { 47 | e.printStackTrace(); 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/ChatGPT.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai; 2 | 3 | import cn.hutool.core.util.RandomUtil; 4 | import cn.hutool.http.ContentType; 5 | import cn.hutool.http.Header; 6 | import com.alibaba.fastjson.JSON; 7 | import com.cyl.ctrbt.openai.api.Api; 8 | import com.cyl.ctrbt.openai.entity.BaseResponse; 9 | import com.cyl.ctrbt.openai.entity.billing.CreditGrantsResponse; 10 | import com.cyl.ctrbt.openai.entity.chat.ChatCompletion; 11 | import com.cyl.ctrbt.openai.entity.chat.ChatCompletionResponse; 12 | import com.cyl.ctrbt.openai.entity.chat.Message; 13 | import com.cyl.ctrbt.openai.exception.ChatException; 14 | import io.reactivex.Single; 15 | import lombok.*; 16 | import lombok.extern.slf4j.Slf4j; 17 | import okhttp3.OkHttpClient; 18 | import okhttp3.Request; 19 | import okhttp3.Response; 20 | import retrofit2.Retrofit; 21 | import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; 22 | import retrofit2.converter.jackson.JacksonConverterFactory; 23 | 24 | import java.math.BigDecimal; 25 | import java.net.Proxy; 26 | import java.util.Arrays; 27 | import java.util.List; 28 | import java.util.Objects; 29 | import java.util.concurrent.TimeUnit; 30 | 31 | 32 | /** 33 | * open ai 客户端 34 | * 35 | * @author plexpt 36 | */ 37 | 38 | @Slf4j 39 | @Getter 40 | @Setter 41 | @Builder 42 | @AllArgsConstructor 43 | @NoArgsConstructor 44 | public class ChatGPT { 45 | /** 46 | * keys 47 | */ 48 | private String apiKey; 49 | 50 | private List apiKeyList; 51 | /** 52 | * 自定义api host使用builder的方式构造client 53 | */ 54 | @Builder.Default 55 | private String apiHost = Api.DEFAULT_API_HOST; 56 | private Api apiClient; 57 | private OkHttpClient okHttpClient; 58 | /** 59 | * 超时 
默认300 60 | */ 61 | @Builder.Default 62 | private long timeout = 300; 63 | /** 64 | * okhttp 代理 65 | */ 66 | @Builder.Default 67 | private Proxy proxy = Proxy.NO_PROXY; 68 | 69 | 70 | /** 71 | * 初始化 72 | */ 73 | public ChatGPT init() { 74 | OkHttpClient.Builder client = new OkHttpClient.Builder(); 75 | client.addInterceptor(chain -> { 76 | Request original = chain.request(); 77 | String key = apiKey; 78 | if (apiKeyList != null && !apiKeyList.isEmpty()) { 79 | key = RandomUtil.randomEle(apiKeyList); 80 | } 81 | 82 | Request request = original.newBuilder() 83 | .header(Header.AUTHORIZATION.getValue(), "Bearer " + key) 84 | .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue()) 85 | .method(original.method(), original.body()) 86 | .build(); 87 | return chain.proceed(request); 88 | }).addInterceptor(chain -> { 89 | Request original = chain.request(); 90 | Response response = chain.proceed(original); 91 | if (!response.isSuccessful()) { 92 | String errorMsg = response.body().string(); 93 | 94 | log.error("请求异常:{}", errorMsg); 95 | BaseResponse baseResponse = JSON.parseObject(errorMsg, BaseResponse.class); 96 | if (Objects.nonNull(baseResponse.getError())) { 97 | log.error(baseResponse.getError().getMessage()); 98 | throw new ChatException(baseResponse.getError().getMessage()); 99 | } 100 | throw new ChatException("error"); 101 | } 102 | return response; 103 | }); 104 | 105 | client.connectTimeout(timeout, TimeUnit.SECONDS); 106 | client.writeTimeout(timeout, TimeUnit.SECONDS); 107 | client.readTimeout(timeout, TimeUnit.SECONDS); 108 | if (Objects.nonNull(proxy)) { 109 | client.proxy(proxy); 110 | } 111 | OkHttpClient httpClient = client.build(); 112 | this.okHttpClient = httpClient; 113 | 114 | 115 | this.apiClient = new Retrofit.Builder() 116 | .baseUrl(this.apiHost) 117 | .client(okHttpClient) 118 | .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) 119 | .addConverterFactory(JacksonConverterFactory.create()) 120 | .build() 121 | 
.create(Api.class); 122 | 123 | return this; 124 | } 125 | 126 | 127 | /** 128 | * 最新版的GPT-3.5 chat completion 更加贴近官方网站的问答模型 129 | * 130 | * @param chatCompletion 问答参数 131 | * @return 答案 132 | */ 133 | public ChatCompletionResponse chatCompletion(ChatCompletion chatCompletion) { 134 | Single chatCompletionResponse = 135 | this.apiClient.chatCompletion(chatCompletion); 136 | return chatCompletionResponse.blockingGet(); 137 | } 138 | 139 | /** 140 | * 简易版 141 | * 142 | * @param messages 问答参数 143 | */ 144 | public ChatCompletionResponse chatCompletion(List messages) { 145 | ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).build(); 146 | return this.chatCompletion(chatCompletion); 147 | } 148 | 149 | /** 150 | * 直接问 151 | */ 152 | public String chat(String message) { 153 | ChatCompletion chatCompletion = ChatCompletion.builder() 154 | .messages(Arrays.asList(Message.of(message))) 155 | .build(); 156 | ChatCompletionResponse response = this.chatCompletion(chatCompletion); 157 | return response.getChoices().get(0).getMessage().getContent(); 158 | } 159 | 160 | /** 161 | * 余额查询 162 | * 163 | * @return 164 | */ 165 | public CreditGrantsResponse creditGrants() { 166 | Single creditGrants = this.apiClient.creditGrants(); 167 | return creditGrants.blockingGet(); 168 | } 169 | 170 | 171 | /** 172 | * 余额查询 173 | * 174 | * @return 175 | */ 176 | public BigDecimal balance() { 177 | Single creditGrants = apiClient.creditGrants(); 178 | CreditGrantsResponse response = creditGrants.blockingGet(); 179 | 180 | return response.getTotalAvailable(); 181 | } 182 | 183 | /** 184 | * 余额查询 185 | * 186 | * @return 187 | */ 188 | public static BigDecimal balance(String key) { 189 | ChatGPT chatGPT = ChatGPT.builder() 190 | .apiKey(key) 191 | .build() 192 | .init(); 193 | 194 | return chatGPT.balance(); 195 | } 196 | 197 | } 198 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/ChatGPTStream.java: 
-------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai; 2 | 3 | import cn.hutool.core.util.RandomUtil; 4 | import cn.hutool.http.ContentType; 5 | import com.cyl.ctrbt.openai.api.Api; 6 | import com.cyl.ctrbt.openai.entity.chat.ChatCompletion; 7 | import com.cyl.ctrbt.openai.entity.chat.Message; 8 | import com.fasterxml.jackson.databind.ObjectMapper; 9 | import lombok.AllArgsConstructor; 10 | import lombok.Builder; 11 | import lombok.Data; 12 | import lombok.NoArgsConstructor; 13 | import lombok.extern.slf4j.Slf4j; 14 | import okhttp3.MediaType; 15 | import okhttp3.OkHttpClient; 16 | import okhttp3.Request; 17 | import okhttp3.RequestBody; 18 | import okhttp3.sse.EventSource; 19 | import okhttp3.sse.EventSourceListener; 20 | import okhttp3.sse.EventSources; 21 | 22 | import java.net.Proxy; 23 | import java.util.List; 24 | import java.util.Objects; 25 | import java.util.concurrent.TimeUnit; 26 | 27 | 28 | /** 29 | * open ai 客户端 30 | * 31 | * @author plexpt 32 | */ 33 | 34 | @Slf4j 35 | @Data 36 | @Builder 37 | @NoArgsConstructor 38 | @AllArgsConstructor 39 | public class ChatGPTStream { 40 | 41 | private String apiKey; 42 | private List apiKeyList; 43 | 44 | private OkHttpClient okHttpClient; 45 | /** 46 | * 连接超时 47 | */ 48 | @Builder.Default 49 | private long timeout = 90; 50 | 51 | /** 52 | * 网络代理 53 | */ 54 | @Builder.Default 55 | private Proxy proxy = Proxy.NO_PROXY; 56 | /** 57 | * 反向代理 58 | */ 59 | @Builder.Default 60 | private String apiHost = Api.DEFAULT_API_HOST; 61 | 62 | /** 63 | * 初始化 64 | */ 65 | public ChatGPTStream init() { 66 | OkHttpClient.Builder client = new OkHttpClient.Builder(); 67 | client.connectTimeout(timeout, TimeUnit.SECONDS); 68 | client.writeTimeout(timeout, TimeUnit.SECONDS); 69 | client.readTimeout(timeout, TimeUnit.SECONDS); 70 | if (Objects.nonNull(proxy)) { 71 | client.proxy(proxy); 72 | } 73 | 74 | okHttpClient = client.build(); 75 | 76 | return this; 77 | } 78 | 79 | 80 | 
/** 81 | * 流式输出 82 | */ 83 | public void streamChatCompletion(ChatCompletion chatCompletion, 84 | EventSourceListener eventSourceListener) { 85 | 86 | chatCompletion.setStream(true); 87 | 88 | try { 89 | EventSource.Factory factory = EventSources.createFactory(okHttpClient); 90 | ObjectMapper mapper = new ObjectMapper(); 91 | String requestBody = mapper.writeValueAsString(chatCompletion); 92 | String key = apiKey; 93 | if (apiKeyList != null && !apiKeyList.isEmpty()) { 94 | key = RandomUtil.randomEle(apiKeyList); 95 | } 96 | 97 | 98 | Request request = new Request.Builder() 99 | .url(apiHost + "v1/chat/completions") 100 | .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), 101 | requestBody)) 102 | .header("Authorization", "Bearer " + key) 103 | .build(); 104 | factory.newEventSource(request, eventSourceListener); 105 | 106 | } catch (Exception e) { 107 | log.error("请求出错:{}", e); 108 | } 109 | } 110 | 111 | /** 112 | * 流式输出 113 | */ 114 | public void streamChatCompletion(List messages, 115 | EventSourceListener eventSourceListener) { 116 | ChatCompletion chatCompletion = ChatCompletion.builder() 117 | .messages(messages) 118 | .stream(true) 119 | .build(); 120 | streamChatCompletion(chatCompletion, eventSourceListener); 121 | } 122 | 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/ChatGPTStrreamUtil.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai; 2 | 3 | import com.cyl.ctrbt.openai.entity.chat.ChatCompletion; 4 | import com.cyl.ctrbt.openai.entity.chat.Message; 5 | import com.cyl.ctrbt.openai.listener.ConsoleStreamListener; 6 | import com.cyl.ctrbt.util.Proxys; 7 | import lombok.extern.slf4j.Slf4j; 8 | import org.springframework.beans.factory.annotation.Value; 9 | import org.springframework.stereotype.Component; 10 | import org.springframework.util.StringUtils; 11 | 12 | import 
javax.annotation.PostConstruct; 13 | import java.net.Proxy; 14 | import java.util.Arrays; 15 | 16 | @Slf4j 17 | @Component 18 | public class ChatGPTStrreamUtil { 19 | @Value("${openai.secret_key}") 20 | private String token; 21 | 22 | private ChatGPTStream chatGPTStream; 23 | 24 | @Value("${proxy.ip}") 25 | private String proxyIp; 26 | @Value("${proxy.port}") 27 | private Integer proxyPort; 28 | 29 | @PostConstruct 30 | public void init(){ 31 | //如果在国内访问,使用这个 32 | if(!StringUtils.isEmpty(proxyIp)){ 33 | Proxy proxy = Proxys.http(proxyIp, proxyPort); 34 | chatGPTStream = ChatGPTStream.builder() 35 | .apiKey(token) 36 | .timeout(900) 37 | .proxy(proxy) 38 | .apiHost("https://api.openai.com/") //代理地址 39 | .build() 40 | .init(); 41 | }else{ 42 | chatGPTStream = ChatGPTStream.builder() 43 | .apiKey(token) 44 | .timeout(900) 45 | .apiHost("https://api.openai.com/") //代理地址 46 | .build() 47 | .init(); 48 | } 49 | } 50 | public void chat(String userMessage,String user) { 51 | ConsoleStreamListener listener = new ConsoleStreamListener(); 52 | Message message = Message.of(userMessage); 53 | ChatCompletion chatCompletion = ChatCompletion.builder() 54 | .user(user) 55 | .messages(Arrays.asList(message)) 56 | .build(); 57 | chatGPTStream.streamChatCompletion(chatCompletion, listener); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/ChatGPTUtil.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai; 2 | 3 | import com.cyl.ctrbt.openai.entity.billing.CreditGrantsResponse; 4 | import com.cyl.ctrbt.openai.entity.chat.ChatCompletion; 5 | import com.cyl.ctrbt.openai.entity.chat.ChatCompletionResponse; 6 | import com.cyl.ctrbt.openai.entity.chat.Message; 7 | import com.cyl.ctrbt.util.Proxys; 8 | import lombok.extern.slf4j.Slf4j; 9 | import org.springframework.beans.factory.annotation.Value; 10 | import 
org.springframework.stereotype.Component; 11 | import org.springframework.util.StringUtils; 12 | 13 | import javax.annotation.PostConstruct; 14 | import java.net.Proxy; 15 | import java.util.Arrays; 16 | 17 | @Slf4j 18 | @Component 19 | public class ChatGPTUtil { 20 | @Value("${openai.secret_key}") 21 | private String token; 22 | 23 | private ChatGPT chatGPT; 24 | 25 | @Value("${proxy.ip}") 26 | private String proxyIp; 27 | @Value("${proxy.port}") 28 | private Integer proxyPort; 29 | 30 | @PostConstruct 31 | public void init(){ 32 | if(!StringUtils.isEmpty(proxyIp)){ 33 | //如果在国内访问,使用这个,在application.yml里面配置 34 | Proxy proxy = Proxys.http(proxyIp, proxyPort); 35 | chatGPT = ChatGPT.builder() 36 | .apiKey(token) 37 | .timeout(600) 38 | .proxy(proxy) 39 | .apiHost("https://api.openai.com/") //代理地址 40 | .build() 41 | .init(); 42 | }else{ 43 | chatGPT = ChatGPT.builder() 44 | .apiKey(token) 45 | .timeout(600) 46 | .apiHost("https://api.openai.com/") //代理地址 47 | .build() 48 | .init(); 49 | } 50 | 51 | } 52 | public Message chat(String userMessage,String user) { 53 | Message message = Message.of(userMessage); 54 | 55 | ChatCompletion chatCompletion = ChatCompletion.builder() 56 | .model(ChatCompletion.Model.GPT_3_5_TURBO.getName()) 57 | .user(user) 58 | .messages(Arrays.asList(message)) 59 | .maxTokens(3000) 60 | .temperature(0.9) 61 | .build(); 62 | ChatCompletionResponse response = chatGPT.chatCompletion(chatCompletion); 63 | return response.getChoices().get(0).getMessage(); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/api/Api.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai.api; 2 | 3 | import com.cyl.ctrbt.openai.entity.billing.CreditGrantsResponse; 4 | import com.cyl.ctrbt.openai.entity.chat.ChatCompletion; 5 | import com.cyl.ctrbt.openai.entity.chat.ChatCompletionResponse; 6 | import io.reactivex.Single; 7 
| import retrofit2.http.Body; 8 | import retrofit2.http.GET; 9 | import retrofit2.http.POST; 10 | 11 | 12 | /** 13 | * 14 | */ 15 | public interface Api { 16 | 17 | String DEFAULT_API_HOST = "https://api.openai.com/"; 18 | 19 | 20 | /** 21 | * chat 22 | */ 23 | @POST("v1/chat/completions") 24 | Single chatCompletion(@Body ChatCompletion chatCompletion); 25 | 26 | 27 | /** 28 | * 余额查询 29 | */ 30 | @GET("dashboard/billing/credit_grants") 31 | Single creditGrants(); 32 | 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/entity/BaseResponse.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai.entity; 2 | 3 | import lombok.Data; 4 | 5 | import java.util.List; 6 | 7 | /** 8 | * @author plexpt 9 | */ 10 | @Data 11 | public class BaseResponse { 12 | private String object; 13 | private List data; 14 | private Error error; 15 | 16 | 17 | @Data 18 | public class Error { 19 | private String message; 20 | private String type; 21 | private String param; 22 | private String code; 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/entity/billing/CreditGrantsResponse.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai.entity.billing; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.Data; 5 | 6 | import java.math.BigDecimal; 7 | 8 | /** 9 | * 余额查询接口返回值 10 | * 11 | * @author plexpt 12 | */ 13 | @Data 14 | public class CreditGrantsResponse { 15 | private String object; 16 | /** 17 | * 总金额:美元 18 | */ 19 | @JsonProperty("total_granted") 20 | private BigDecimal totalGranted; 21 | /** 22 | * 总使用金额:美元 23 | */ 24 | @JsonProperty("total_used") 25 | private BigDecimal totalUsed; 26 | /** 27 | * 总剩余金额:美元 28 | */ 29 | @JsonProperty("total_available") 30 | private 
BigDecimal totalAvailable; 31 | /** 32 | * 余额明细 33 | */ 34 | private Grants grants; 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/entity/billing/Datum.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai.entity.billing; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.Data; 5 | 6 | import java.math.BigDecimal; 7 | 8 | /** 9 | * @author plexpt 10 | */ 11 | @Data 12 | public class Datum { 13 | private String object; 14 | private String id; 15 | /** 16 | * 赠送金额:美元 17 | */ 18 | @JsonProperty("grant_amount") 19 | private BigDecimal grantAmount; 20 | /** 21 | * 使用金额:美元 22 | */ 23 | @JsonProperty("used_amount") 24 | private BigDecimal usedAmount; 25 | /** 26 | * 生效时间戳 27 | */ 28 | @JsonProperty("effective_at") 29 | private Long effectiveAt; 30 | /** 31 | * 过期时间戳 32 | */ 33 | @JsonProperty("expires_at") 34 | private Long expiresAt; 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/entity/billing/Grants.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai.entity.billing; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.Data; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * @author plexpt 10 | */ 11 | @Data 12 | public class Grants { 13 | private String object; 14 | @JsonProperty("data") 15 | private List data; 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/openai/entity/billing/Usage.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai.entity.billing; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.Data; 5 | 6 | /** 7 | * @author plexpt 8 | */ 9 | 
@Data
public class Usage {
    /**
     * Tokens counted for the request prompt.
     */
    @JsonProperty("prompt_tokens")
    private long promptTokens;
    /**
     * Tokens counted for the generated completion.
     */
    @JsonProperty("completion_tokens")
    private long completionTokens;
    /**
     * Total tokens counted for the request.
     */
    @JsonProperty("total_tokens")
    private long totalTokens;
}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/entity/chat/ChatChoice.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.openai.entity.chat;

import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;

/**
 * One candidate answer inside a {@link ChatCompletionResponse}.
 *
 * @author plexpt
 */
@Data
public class ChatChoice {
    private long index;
    /**
     * Incremental message fragment; populated when the request was sent with {@code stream = true}.
     */
    @JsonProperty("delta")
    private Message delta;
    /**
     * Complete message; populated when the request was sent with {@code stream = false}.
     */
    @JsonProperty("message")
    private Message message;
    @JsonProperty("finish_reason")
    private String finishReason;
}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/entity/chat/ChatCompletion.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.openai.entity.chat;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
import lombok.extern.slf4j.Slf4j;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

/**
 * Request body of the chat completion endpoint.
 * <p>
 * {@code null} fields are omitted from the serialized JSON.
 *
 * @author plexpt
 */
@Data
@Builder
@Slf4j
@AllArgsConstructor
@NoArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ChatCompletion implements Serializable {

    /**
     * Model name; defaults to {@code gpt-3.5-turbo}.
     */
    @NonNull
    @Builder.Default
    private String model = Model.GPT_3_5_TURBO.getName();

    /**
     * Conversation messages sent to the model.
     */
    @NonNull
    private List<Message> messages;

    /**
     * Sampling temperature between 0 and 2: higher values make the output more
     * varied, lower values more conservative.
     * <p>
     * Avoid tuning this together with {@link #topP}.
     */
    @Builder.Default
    private double temperature = 0.9;

    /**
     * Nucleus-sampling probability mass between 0 and 1; 0.9 is suggested.
     * Avoid tuning this together with {@link #temperature}.
     */
    @JsonProperty("top_p")
    @Builder.Default
    private double topP = 0.9;


    /**
     * Number of completions to generate.
     */
    @Builder.Default
    private Integer n = 1;


    /**
     * Whether the answer is streamed incrementally. Defaults to {@code false}.
     */
    @Builder.Default
    private boolean stream = false;
    /**
     * Stop sequences.
     */
    private List<String> stop;
    /**
     * Upper bound on generated tokens: up to 4096 for GPT-3.5, up to 32k for GPT-4 (32k variants).
     */
    @JsonProperty("max_tokens")
    private Integer maxTokens;


    @JsonProperty("presence_penalty")
    private double presencePenalty;

    /**
     * Range -2.0 to 2.0.
     */
    @JsonProperty("frequency_penalty")
    private double frequencyPenalty;

    /**
     * Token-id to bias mapping.
     */
    @JsonProperty("logit_bias")
    private Map<String, Integer> logitBias;
    /**
     * Unique identifier of the end user, forwarded to OpenAI.
     */
    private String user;


    /**
     * Known chat model names.
     */
    @Getter
    @AllArgsConstructor
    public enum Model {
        /**
         * gpt-3.5-turbo
         */
        GPT_3_5_TURBO("gpt-3.5-turbo"),
        /**
         * Dated snapshot model; not recommended.
         */
        GPT_3_5_TURBO_0301("gpt-3.5-turbo-0301"),
        /**
         * GPT-4.
         */
        GPT_4("gpt-4"),
        /**
         * Dated snapshot model; not recommended.
         */
        GPT_4_0314("gpt-4-0314"),
        /**
         * GPT-4 with the extended (32k) context window.
         */
        GPT_4_32K("gpt-4-32k"),
        /**
         * Dated snapshot model; not recommended.
         */
        GPT_4_32K_0314("gpt-4-32k-0314"),
        ;
        private final String name;
    }

}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/entity/chat/ChatCompletionResponse.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.openai.entity.chat;

import com.cyl.ctrbt.openai.entity.billing.Usage;
import lombok.Data;

import java.util.List;

/**
 * Response of the chat completion endpoint (also the shape of each streamed chunk).
 *
 * @author plexpt
 */
@Data
public class ChatCompletionResponse {
    private String id;
    private String object;
    private long created;
    private String model;
    private List<ChatChoice> choices;
    private Usage usage;
}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/entity/chat/Message.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.openai.entity.chat;

import lombok.*;

/**
 * A single chat message (role plus text).
 *
 * @author plexpt
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class Message {
    /**
     * Speaker role. Three roles are currently supported for setting up the conversation, see
     * https://platform.openai.com/docs/guides/chat/introduction
     */
    private String role;
    private String content;

    /**
     * Creates a message with the {@code user} role.
     */
    public static Message of(String content) {

        return new Message(Role.USER.getValue(), content);
    }

    /**
     * Creates a message with the {@code system} role.
     */
    public static Message ofSystem(String content) {

        return new Message(Role.SYSTEM.getValue(), content);
    }

    /**
     * Creates a message with the {@code assistant} role.
     */
    public static Message ofAssistant(String content) {

        return new Message(Role.ASSISTANT.getValue(), content);
    }

    /**
     * Wire values of the supported roles.
     */
    @Getter
    @AllArgsConstructor
    public enum Role {

        SYSTEM("system"),
        USER("user"),
        ASSISTANT("assistant"),
        ;
        private final String value;
    }

}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/exception/ChatException.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.openai.exception;

/**
 * Custom exception class for chat-related errors.
 *
 * @author plexpt
 */
public class ChatException extends RuntimeException {


    /**
     * Constructs a new ChatException with the specified detail message.
     *
     * @param msg the detail message (retrievable later via {@link #getMessage()})
     */
    public ChatException(String msg) {
        super(msg);
    }

    /**
     * Constructs a new ChatException that preserves the underlying failure.
     *
     * @param msg   the detail message
     * @param cause the underlying cause (retrievable via {@link #getCause()})
     */
    public ChatException(String msg, Throwable cause) {
        super(msg, cause);
    }


}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/listener/AbstractStreamListener.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.openai.listener;

import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSON;
import com.cyl.ctrbt.openai.entity.chat.ChatChoice;
import com.cyl.ctrbt.openai.entity.chat.ChatCompletionResponse;
import com.cyl.ctrbt.openai.entity.chat.Message;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;

import java.io.IOException;
import java.util.List;
import java.util.function.Consumer;

/**
 * Base OkHttp {@link EventSourceListener} for streamed chat completions.
 * <p>
 * Each event is parsed as a {@link ChatCompletionResponse} chunk. The text in the first
 * choice's {@code delta} is appended to {@link #lastMessage} and passed to
 * {@link #onMsg(String)}. The literal event {@code [DONE]} ends the stream and hands
 * the accumulated text to {@link #onComplate}.
 *
 * @author plexpt
 */
@Slf4j
public abstract class AbstractStreamListener extends EventSourceListener {

    /**
     * Event payload that marks the end of the stream.
     */
    private static final String DONE_MARKER = "[DONE]";

    /**
     * Error-body fragment indicating the account was banned.
     */
    private static final String FORBIDDEN_TEXT = "Your access was terminated due to violation of our policies";

    /**
     * Error-body fragment indicating the model is overloaded (the request should be retried).
     */
    private static final String OVERLOADED_TEXT = "That model is currently overloaded with other requests.";

    /**
     * Full answer text accumulated from all deltas received so far.
     */
    protected String lastMessage = "";


    /**
     * Called once with the complete accumulated answer when {@code [DONE]} arrives.
     * Defaults to a no-op. The misspelled name is kept because the Lombok-generated
     * getter/setter are part of the public API.
     */
    @Setter
    @Getter
    protected Consumer<String> onComplate = s -> {

    };

    /**
     * Called for every non-empty text fragment as it arrives.
     *
     * @param message the new fragment (typically a few characters)
     */
    public abstract void onMsg(String message);

    /**
     * Called when the stream fails.
     *
     * @param throwable the cause reported by OkHttp; may be {@code null} when the server
     *                  answered with an unsuccessful HTTP status
     * @param response  the error response body, or {@code ""} if none could be read
     */
    public abstract void onError(Throwable throwable, String response);

    @Override
    public void onOpen(EventSource eventSource, Response response) {
        // do nothing
    }

    @Override
    public void onClosed(EventSource eventSource) {
        // do nothing
    }

    @Override
    public void onEvent(EventSource eventSource, String id, String type, String data) {
        if (DONE_MARKER.equals(data)) {
            onComplate.accept(lastMessage);
            return;
        }

        ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
        if (response == null) {
            return;
        }
        List<ChatChoice> choices = response.getChoices();
        if (choices == null || choices.isEmpty()) {
            return;
        }
        // A chunk may carry no delta (or a delta without content), e.g. the one that only
        // reports finish_reason; skip it instead of throwing a NullPointerException.
        Message delta = choices.get(0).getDelta();
        String text = delta == null ? null : delta.getContent();

        if (text != null) {
            lastMessage += text;
            onMsg(text);
        }
    }


    @Override
    public void onFailure(EventSource eventSource, Throwable throwable, Response response) {

        try {
            // Pass the throwable as the exception argument: a "{}" placeholder would be
            // printed literally because this call resolves to error(String, Throwable).
            log.error("Stream connection error", throwable);

            String responseText = readBody(response);

            log.error("response:{}", responseText);

            if (StrUtil.contains(responseText, FORBIDDEN_TEXT)) {
                log.error("Chat session has been terminated due to policy violation");
                log.error("检测到号被封了");
            }

            if (StrUtil.contains(responseText, OVERLOADED_TEXT)) {
                log.error("检测到官方超载了,赶紧优化你的代码,做重试吧");
            }

            // Always reached now: a missing or unreadable body no longer skips onError,
            // which left consumers (e.g. an SSE emitter) waiting forever.
            this.onError(throwable, responseText);

        } catch (Exception e) {
            log.warn("onFailure error", e);

        } finally {
            eventSource.cancel();
        }
    }

    /**
     * Reads the error body of a failed response.
     *
     * @param response the response handed to {@link #onFailure}; may be {@code null}
     * @return the body text, or {@code ""} when there is no response, no body, or it cannot be read
     */
    private static String readBody(Response response) {
        if (response == null) {
            return "";
        }
        ResponseBody body = response.body();
        if (body == null) {
            return "";
        }
        try {
            return body.string();
        } catch (IOException e) {
            log.warn("Failed to read error response body", e);
            return "";
        }
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/listener/ConsoleStreamListener.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.openai.listener;

import lombok.extern.slf4j.Slf4j;

/**
 * Console stream listener for manual testing: prints each fragment to stdout.
 *
 * @author plexpt
 */
@Slf4j
public class ConsoleStreamListener extends AbstractStreamListener {


    @Override
    public void onMsg(String message) {
        System.out.print(message);
    }

    @Override
    public void onError(Throwable throwable, String response) {
        // Intentionally empty: AbstractStreamListener#onFailure has already logged the failure.
    }

}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/listener/SseStreamListener.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.openai.listener;


import com.cyl.ctrbt.util.SseHelper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

/**
 * Relays streamed fragments to an HTTP client through a Spring {@link SseEmitter}.
 *
 * @author plexpt
 */
@Slf4j
@RequiredArgsConstructor
public class SseStreamListener extends AbstractStreamListener {

    final SseEmitter sseEmitter;


    @Override
    public void onMsg(String message) {
        SseHelper.send(sseEmitter, message);
    }

    @Override
    public void onError(Throwable throwable, String response) {
        // Close the emitter so the client does not keep waiting on a failed stream.
        SseHelper.complete(sseEmitter);
    }

}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/util/ChatContextHolder.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.util;


import com.cyl.ctrbt.openai.entity.chat.Message;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * In-memory chat history, one message list per conversation id.
 * <p>
 * The id-to-history map is synchronized because this static state is shared by every
 * caller, potentially on several request/WebSocket threads at once. The per-id lists are
 * plain {@link ArrayList}s, so one conversation should not be appended to from several
 * threads at the same time.
 */
public class ChatContextHolder {

    private static final Map<String, List<Message>> context =
            Collections.synchronizedMap(new HashMap<>());


    /**
     * Returns the history of a conversation, creating an empty one on first use.
     *
     * @param id conversation id
     * @return the live, mutable history list (never {@code null})
     */
    public static List<Message> get(String id) {
        // computeIfAbsent on a synchronized map is atomic, so two threads cannot create
        // competing lists for the same id (the old get/put pair could lose one).
        return context.computeIfAbsent(id, key -> new ArrayList<>());
    }


    /**
     * Appends a user message to a conversation.
     * <p>
     * The message is given the {@code user} role: a message without a role is rejected by
     * the chat API once it is sent as part of the history.
     *
     * @param id  conversation id
     * @param msg message text
     */
    public static void add(String id, String msg) {
        add(id, Message.of(msg));
    }


    /**
     * Appends a message to a conversation, creating the history if needed.
     *
     * @param id      conversation id
     * @param message message to append
     */
    public static void add(String id, Message message) {
        get(id).add(message);
    }

    /**
     * Drops the whole history of a conversation.
     *
     * @param id conversation id
     */
    public static void remove(String id) {
        context.remove(id);
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/util/Proxys.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.util;

import lombok.experimental.UtilityClass;

import java.net.InetSocketAddress;
import java.net.Proxy;


/**
 * Factory helpers for {@link Proxy} instances.
 */
@UtilityClass
public class Proxys {


    /**
     * Builds an HTTP proxy.
     *
     * @param ip   proxy host name or IP address
     * @param port proxy port
     * @return an HTTP {@link Proxy} pointing at {@code ip:port}
     */
    public static Proxy http(String ip, int port) {
        return proxyOf(Proxy.Type.HTTP, ip, port);
    }

    /**
     * Builds a SOCKS (v5) proxy.
     *
     * @param ip   proxy host name or IP address
     * @param port proxy port
     * @return a SOCKS {@link Proxy} pointing at {@code ip:port}
     */
    public static Proxy socks5(String ip, int port) {
        return proxyOf(Proxy.Type.SOCKS, ip, port);
    }

    /**
     * Shared construction for both proxy kinds.
     */
    private static Proxy proxyOf(Proxy.Type type, String ip, int port) {
        InetSocketAddress address = new InetSocketAddress(ip, port);
        return new Proxy(type, address);
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/util/SseHelper.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.util;

import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

/**
 * Best-effort wrappers around {@link SseEmitter} that never throw.
 * <p>
 * Failures are typical when the client has already disconnected or the emitter was
 * already completed. They are logged at debug level instead of being silently discarded.
 */
@Slf4j
@UtilityClass
public class SseHelper {


    /**
     * Completes the emitter, ignoring (but logging) any failure.
     *
     * @param sseEmitter the emitter to complete
     */
    public void complete(SseEmitter sseEmitter) {

        try {
            sseEmitter.complete();
        } catch (Exception e) {
            log.debug("Failed to complete SseEmitter", e);
        }
    }

    /**
     * Sends one event, ignoring (but logging) any failure.
     *
     * @param sseEmitter the target emitter
     * @param data       the event payload
     */
    public void send(SseEmitter sseEmitter, Object data) {

        try {
            sseEmitter.send(data);
        } catch (Exception e) {
            log.debug("Failed to send SSE data", e);
        }
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/websocket/MyWebSocketInterceptor.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.websocket;

import cn.hutool.core.util.StrUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

import java.net.InetSocketAddress;
import java.util.Map;

/**
 * Handshake interceptor that performs a simple token check.
 * <p>
 * The token is read from the {@code token} header, falling back to the {@code token}
 * query parameter. Only a matching token lets the handshake succeed.
 */
public class MyWebSocketInterceptor extends HttpSessionHandshakeInterceptor {

    /**
     * Name of the header / query parameter carrying the token.
     */
    private static final String TOKEN_NAME = "token";

    /**
     * Accepted token. NOTE(review): hard-coded demo credential - move it to configuration
     * before any real deployment.
     */
    private static final String EXPECTED_TOKEN = "token-123456";

    private final Logger logger = LoggerFactory.getLogger(getClass());

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        logger.info("[MyWebSocketInterceptor#BeforeHandshake] Request from {}", remoteHost(request));
        if (request instanceof ServletServerHttpRequest) {
            ServletServerHttpRequest serverHttpRequest = (ServletServerHttpRequest) request;
            String token = serverHttpRequest.getServletRequest().getHeader(TOKEN_NAME);
            if (StrUtil.isEmpty(token)) {
                token = serverHttpRequest.getServletRequest().getParameter(TOKEN_NAME);
            }
            if (!EXPECTED_TOKEN.equals(token)) {
                // Report the rejection explicitly instead of answering a bare non-upgrade response.
                response.setStatusCode(HttpStatus.UNAUTHORIZED);
                return false;
            }
        }
        // NOTE(review): non-servlet requests skip the token check, as in the original code.
        return super.beforeHandshake(request, response, wsHandler, attributes);
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex) {
        logger.info("[MyWebSocketInterceptor#afterHandshake] Request from {}", remoteHost(request));
    }

    /**
     * Host string of the peer, or {@code "unknown"} when the container does not expose the
     * remote address ({@code getRemoteAddress()} may return {@code null}).
     */
    private static String remoteHost(ServerHttpRequest request) {
        InetSocketAddress address = request.getRemoteAddress();
        return address == null ? "unknown" : address.getHostString();
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/websocket/MyWebsocketHandler.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.websocket;

import cn.hutool.core.lang.UUID;
import com.cyl.ctrbt.openai.ChatGPTUtil;
import com.cyl.ctrbt.openai.entity.chat.Message;
import com.cyl.ctrbt.websocket.bean.WebSocketBean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;

import java.time.LocalDateTime;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * WebSocket handler that answers each text message with a ChatGPT reply.
 * <p>
 * Every connection gets a random client id. That id is passed to
 * {@link ChatGPTUtil#chat} as the user, presumably so each client keeps its own
 * conversation (TODO confirm against ChatGPTUtil).
 */
@Component
public class MyWebsocketHandler extends AbstractWebSocketHandler {

    @Autowired
    private ChatGPTUtil chatGPTUtil;

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * Live connections keyed by WebSocket session id.
     */
    private static final Map<String, WebSocketBean> webSocketBeanMap = new ConcurrentHashMap<>();

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        // Register the session once the connection is fully established.
        WebSocketBean webSocketBean = new WebSocketBean();
        webSocketBean.setWebSocketSession(session);
        webSocketBean.setClientId(UUID.fastUUID().toString());
        webSocketBeanMap.put(session.getId(), webSocketBean);
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        webSocketBeanMap.remove(session.getId());
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        logger.error("session {}", session.getId(), exception);
        if (session.isOpen()) {
            session.close();
        }
        webSocketBeanMap.remove(session.getId());
    }

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        WebSocketBean webSocketBean = webSocketBeanMap.get(session.getId());
        if (webSocketBean == null) {
            // The session was already deregistered (closed or errored); nothing to answer.
            logger.warn("Ignoring message from unregistered session {}", session.getId());
            return;
        }
        webSocketBean.setLastMessageTime(LocalDateTime.now());
        String user = webSocketBean.getClientId();
        logger.info("Received message from client[ID:{}]; Content is [{}].", user, message.getPayload());

        Message returnMessage = chatGPTUtil.chat(message.getPayload(), user);
        // TextMessage rejects a null payload, so guard against an empty reply.
        if (returnMessage == null || returnMessage.getContent() == null) {
            logger.warn("No reply content for client[ID:{}]", user);
            return;
        }
        session.sendMessage(new TextMessage(returnMessage.getContent()));
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/websocket/WebSocketConfiguration.java:
--------------------------------------------------------------------------------
package com.cyl.ctrbt.websocket;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * WebSocket wiring.
 * <p>
 * Registers the Spring handler at {@code /websocket}, protected by the token interceptor.
 * It also exports {@code @ServerEndpoint} beans such as {@link WebSocketServer}.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfiguration implements WebSocketConfigurer {

    @Autowired
    private MyWebsocketHandler myWebsocketHandler;

    /**
     * Token-checking handshake interceptor.
     */
    @Bean
    public MyWebSocketInterceptor webSocketInterceptor() {
        return new MyWebSocketInterceptor();
    }

    /**
     * Registers {@code @ServerEndpoint}-annotated beans with the container's standard WebSocket runtime.
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry webSocketHandlerRegistry) {
        // NOTE(review): "*" accepts handshakes from any origin; restrict this in production.
        webSocketHandlerRegistry.addHandler(myWebsocketHandler, "/websocket")
                .setAllowedOrigins("*")
                .addInterceptors(webSocketInterceptor());
    }
}
-------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/websocket/WebSocketServer.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.websocket; 2 | 3 | import lombok.extern.slf4j.Slf4j; 4 | import org.springframework.stereotype.Component; 5 | import org.springframework.stereotype.Service; 6 | 7 | import javax.websocket.*; 8 | import javax.websocket.server.PathParam; 9 | import javax.websocket.server.ServerEndpoint; 10 | import java.io.IOException; 11 | import java.util.concurrent.CopyOnWriteArraySet; 12 | 13 | @Component 14 | @Slf4j 15 | @Service 16 | @ServerEndpoint("/api/websocket/{sid}") 17 | public class WebSocketServer { 18 | //静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。 19 | private static int onlineCount = 0; 20 | //concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。 21 | private static CopyOnWriteArraySet webSocketSet = new CopyOnWriteArraySet<>(); 22 | 23 | //与某个客户端的连接会话,需要通过它来给客户端发送数据 24 | private Session session; 25 | 26 | //接收sid 27 | private String sid = ""; 28 | 29 | /** 30 | * 连接建立成功调用的方法 31 | */ 32 | @OnOpen 33 | public void onOpen(Session session, @PathParam("sid") String sid) { 34 | this.session = session; 35 | webSocketSet.add(this); //加入set中 36 | this.sid = sid; 37 | addOnlineCount(); //在线数加1 38 | try { 39 | sendMessage("conn_success"); 40 | log.info("有新窗口开始监听:" + sid + ",当前在线人数为:" + getOnlineCount()); 41 | } catch (IOException e) { 42 | log.error("websocket IO Exception"); 43 | } 44 | } 45 | 46 | /** 47 | * 连接关闭调用的方法 48 | */ 49 | @OnClose 50 | public void onClose() { 51 | webSocketSet.remove(this); //从set中删除 52 | subOnlineCount(); //在线数减1 53 | //断开连接情况下,更新主板占用情况为释放 54 | log.info("释放的sid为:"+sid); 55 | //这里写你 释放的时候,要处理的业务 56 | log.info("有一连接关闭!当前在线人数为" + getOnlineCount()); 57 | 58 | } 59 | 60 | /** 61 | * 收到客户端消息后调用的方法 62 | * @ Param message 客户端发送过来的消息 63 | */ 64 | @OnMessage 65 | public void onMessage(String message, Session session) { 66 | 
log.info("收到来自窗口" + sid + "的信息:" + message); 67 | //群发消息 68 | for (WebSocketServer item : webSocketSet) { 69 | try { 70 | item.sendMessage(message); 71 | } catch (IOException e) { 72 | e.printStackTrace(); 73 | } 74 | } 75 | } 76 | 77 | /** 78 | * @ Param session 79 | * @ Param error 80 | */ 81 | @OnError 82 | public void onError(Session session, Throwable error) { 83 | log.error("发生错误"); 84 | error.printStackTrace(); 85 | } 86 | 87 | /** 88 | * 实现服务器主动推送 89 | */ 90 | public void sendMessage(String message) throws IOException { 91 | this.session.getBasicRemote().sendText(message); 92 | } 93 | 94 | /** 95 | * 群发自定义消息 96 | */ 97 | public static void sendInfo(String message, @PathParam("sid") String sid) throws IOException { 98 | log.info("推送消息到窗口" + sid + ",推送内容:" + message); 99 | 100 | for (WebSocketServer item : webSocketSet) { 101 | try { 102 | //这里可以设定只推送给这个sid的,为null则全部推送 103 | if (sid == null) { 104 | // item.sendMessage(message); 105 | } else if (item.sid.equals(sid)) { 106 | item.sendMessage(message); 107 | } 108 | } catch (IOException e) { 109 | continue; 110 | } 111 | } 112 | } 113 | 114 | public static synchronized int getOnlineCount() { 115 | return onlineCount; 116 | } 117 | 118 | public static synchronized void addOnlineCount() { 119 | WebSocketServer.onlineCount++; 120 | } 121 | 122 | public static synchronized void subOnlineCount() { 123 | WebSocketServer.onlineCount--; 124 | } 125 | 126 | public static CopyOnWriteArraySet getWebSocketSet() { 127 | return webSocketSet; 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /src/main/java/com/cyl/ctrbt/websocket/bean/WebSocketBean.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.websocket.bean; 2 | 3 | import lombok.Data; 4 | import org.springframework.web.socket.WebSocketSession; 5 | 6 | import java.time.LocalDateTime; 7 | 8 | @Data 9 | public class WebSocketBean { 10 | // websocket 11 | 
private WebSocketSession webSocketSession; 12 | // 客户端id 13 | private String clientId; 14 | // 最后更新时间 15 | private LocalDateTime lastMessageTime; 16 | } 17 | -------------------------------------------------------------------------------- /src/main/resources/application.yml: -------------------------------------------------------------------------------- 1 | server: 2 | port: 8081 3 | logging: 4 | level: 5 | root: info 6 | openai: 7 | secret_key: 测试前请先设置您的SECRET KEY,查看地址:https://platform.openai.com/account/api-keys 8 | proxy: #如果在国内访问,使用这个 9 | ip: 127.0.0.1 10 | port: 33210 -------------------------------------------------------------------------------- /src/test/java/com/cyl/ctrbt/ChatgptRobotBackApplicationTests.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt; 2 | 3 | import org.junit.jupiter.api.Test; 4 | import org.springframework.boot.test.context.SpringBootTest; 5 | 6 | @SpringBootTest 7 | class ChatgptRobotBackApplicationTests { 8 | 9 | @Test 10 | void contextLoads() { 11 | } 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/test/java/com/cyl/ctrbt/openai/OpenAiTest.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.openai; 2 | 3 | import com.cyl.ctrbt.ChatgptRobotBackApplication; 4 | import com.cyl.ctrbt.openai.entity.chat.Message; 5 | import org.junit.Test; 6 | import org.junit.runner.RunWith; 7 | import org.springframework.beans.factory.annotation.Autowired; 8 | import org.springframework.boot.test.context.SpringBootTest; 9 | import org.springframework.test.context.ActiveProfiles; 10 | import org.springframework.test.context.junit4.SpringRunner; 11 | 12 | @RunWith(SpringRunner.class) 13 | @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT, classes = ChatgptRobotBackApplication.class) 14 | @ActiveProfiles("dev") 15 | public class OpenAiTest { 16 | 
@Autowired 17 | private ChatGPTUtil chatGPTUtil; 18 | 19 | @Test 20 | public void testChatGPT(){ 21 | Message message = chatGPTUtil.chat("单身狗如何过情人节?", "单身狗"); 22 | System.out.println(message.getContent()); 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /src/test/java/com/cyl/ctrbt/websocket/MyWebSocketClient.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.websocket; 2 | 3 | import org.java_websocket.client.WebSocketClient; 4 | import org.java_websocket.handshake.ServerHandshake; 5 | import org.slf4j.Logger; 6 | import org.slf4j.LoggerFactory; 7 | 8 | import java.net.URI; 9 | 10 | public class MyWebSocketClient extends WebSocketClient { 11 | 12 | private Logger logger = LoggerFactory.getLogger(getClass()); 13 | 14 | public MyWebSocketClient(URI serverUri) { 15 | super(serverUri); 16 | } 17 | 18 | @Override 19 | public void onOpen(ServerHandshake serverHandshake) { 20 | logger.info("[MyWebSocketClient#onOpen]The WebSocket connection is open."); 21 | } 22 | 23 | @Override 24 | public void onMessage(String s) { 25 | logger.info("[MyWebSocketClient#onMessage]The client has received the message from server." 
+ 26 | "The Content is [" + s + "]"); 27 | } 28 | 29 | @Override 30 | public void onClose(int i, String s, boolean b) { 31 | logger.info("[MyWebSocketClient#onClose]The WebSocket connection is close."); 32 | } 33 | 34 | @Override 35 | public void onError(Exception e) { 36 | logger.info("[MyWebSocketClient#onError]The WebSocket connection is error."); 37 | } 38 | } -------------------------------------------------------------------------------- /src/test/java/com/cyl/ctrbt/websocket/WebSocketTest.java: -------------------------------------------------------------------------------- 1 | package com.cyl.ctrbt.websocket; 2 | 3 | import org.junit.Test; 4 | 5 | import java.net.URI; 6 | import java.util.Timer; 7 | import java.util.TimerTask; 8 | import java.util.concurrent.atomic.AtomicInteger; 9 | 10 | public class WebSocketTest { 11 | private static final AtomicInteger count = new AtomicInteger(0); 12 | 13 | @Test 14 | public void test() { 15 | URI uri = URI.create("ws://127.0.0.1:8081/websocket"); //注意协议号为ws 16 | MyWebSocketClient client = new MyWebSocketClient(uri); 17 | client.addHeader("token", "token-123456"); //这里为header添加了token,实现简单的校验 18 | try { 19 | client.connectBlocking(); //在连接成功之前会一直阻塞 20 | 21 | while (true) { 22 | if (client.isOpen()) { 23 | client.send("先有鸡还是先有蛋?"); 24 | } 25 | Thread.sleep(1000000); 26 | } 27 | } catch (InterruptedException e) { 28 | e.printStackTrace(); 29 | } 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/test/resources/application.yml: -------------------------------------------------------------------------------- 1 | logging: 2 | level: 3 | root: info 4 | openai: 5 | secret_key: 测试前请先设置您的SECRET KEY,查看地址:https://platform.openai.com/account/api-keys 6 | --------------------------------------------------------------------------------