messages;
31 | /**
32 | * 使用什么取样温度,0到2之间。越高越奔放。越低越保守。
33 | *
34 | * 不要同时改这个和topP
35 | */
36 | @Builder.Default
37 | private double temperature = 0.9;
38 |
39 | /**
40 | * 0-1
41 | * 建议0.9
42 | * 不要同时改这个和temperature
43 | */
44 | @JsonProperty("top_p")
45 | @Builder.Default
46 | private double topP = 0.9;
47 |
48 |
49 | /**
50 | * 结果数。
51 | */
52 | @Builder.Default
53 | private Integer n = 1;
54 |
55 |
56 | /**
57 | * 是否流式输出.
58 | * default:false
59 | */
60 | @Builder.Default
61 | private boolean stream = false;
62 | /**
63 | * 停用词
64 | */
65 | private List stop;
66 | /**
67 | * 3.5 最大支持4096
68 | * 4.0 最大32k
69 | */
70 | @JsonProperty("max_tokens")
71 | private Integer maxTokens;
72 |
73 |
74 | @JsonProperty("presence_penalty")
75 | private double presencePenalty;
76 |
77 | /**
78 | * -2.0 ~~ 2.0
79 | */
80 | @JsonProperty("frequency_penalty")
81 | private double frequencyPenalty;
82 |
83 | @JsonProperty("logit_bias")
84 | private Map logitBias;
85 | /**
86 | * 用户唯一值,确保接口不被重复调用
87 | */
88 | private String user;
89 |
90 |
91 | @Getter
92 | @AllArgsConstructor
93 | public enum Model {
94 | /**
95 | * gpt-3.5-turbo
96 | */
97 | GPT_3_5_TURBO("gpt-3.5-turbo"),
98 | /**
99 | * 临时模型,不建议使用
100 | */
101 | GPT_3_5_TURBO_0301("gpt-3.5-turbo-0301"),
102 | /**
103 | * GPT4.0
104 | */
105 | GPT_4("gpt-4"),
106 | /**
107 | * 临时模型,不建议使用
108 | */
109 | GPT_4_0314("gpt-4-0314"),
110 | /**
111 | * GPT4.0 超长上下文
112 | */
113 | GPT_4_32K("gpt-4-32k"),
114 | /**
115 | * 临时模型,不建议使用
116 | */
117 | GPT_4_32K_0314("gpt-4-32k-0314"),
118 | ;
119 | private String name;
120 | }
121 |
122 | }
123 |
124 |
125 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/entity/chat/ChatCompletionResponse.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.openai.entity.chat;
2 |
3 | import com.cyl.ctrbt.openai.entity.billing.Usage;
4 | import lombok.Data;
5 |
6 | import java.util.List;
7 |
8 | /**
9 | * chat答案类
10 | *
11 | * @author plexpt
12 | */
13 | @Data
14 | public class ChatCompletionResponse {
15 | private String id;
16 | private String object;
17 | private long created;
18 | private String model;
19 | private List choices;
20 | private Usage usage;
21 | }
22 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/entity/chat/Message.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.openai.entity.chat;
2 |
3 | import lombok.*;
4 |
5 | /**
6 | * @author plexpt
7 | */
8 | @Data
9 | @AllArgsConstructor
10 | @NoArgsConstructor
11 | @Builder
12 | public class Message {
13 | /**
14 | * 目前支持三中角色参考官网,进行情景输入:https://platform.openai.com/docs/guides/chat/introduction
15 | */
16 | private String role;
17 | private String content;
18 |
19 | public static Message of(String content) {
20 |
21 | return new Message(Role.USER.getValue(), content);
22 | }
23 |
24 | public static Message ofSystem(String content) {
25 |
26 | return new Message(Role.SYSTEM.getValue(), content);
27 | }
28 |
29 | public static Message ofAssistant(String content) {
30 |
31 | return new Message(Role.ASSISTANT.getValue(), content);
32 | }
33 |
34 | @Getter
35 | @AllArgsConstructor
36 | public enum Role {
37 |
38 | SYSTEM("system"),
39 | USER("user"),
40 | ASSISTANT("assistant"),
41 | ;
42 | private String value;
43 | }
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/exception/ChatException.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.openai.exception;
2 |
/**
 * Unchecked exception for chat-related errors.
 *
 * @author plexpt
 */
public class ChatException extends RuntimeException {

    private static final long serialVersionUID = 1L;

    /**
     * Constructs a new ChatException with the specified detail message.
     *
     * @param msg the detail message (saved for later retrieval by {@link #getMessage()})
     */
    public ChatException(String msg) {
        super(msg);
    }

    /**
     * Constructs a new ChatException with the specified detail message and cause, so the
     * underlying failure (and its stack trace) is not lost when wrapping.
     *
     * @param msg   the detail message
     * @param cause the underlying cause; may be {@code null}
     */
    public ChatException(String msg, Throwable cause) {
        super(msg, cause);
    }
}
22 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/listener/AbstractStreamListener.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.openai.listener;
2 |
3 | import cn.hutool.core.util.StrUtil;
4 | import com.alibaba.fastjson.JSON;
5 | import com.cyl.ctrbt.openai.entity.chat.ChatChoice;
6 | import com.cyl.ctrbt.openai.entity.chat.ChatCompletionResponse;
7 | import com.cyl.ctrbt.openai.entity.chat.Message;
8 | import lombok.Getter;
9 | import lombok.Setter;
10 | import lombok.SneakyThrows;
11 | import lombok.extern.slf4j.Slf4j;
12 | import okhttp3.Response;
13 | import okhttp3.sse.EventSource;
14 | import okhttp3.sse.EventSourceListener;
15 |
16 | import java.util.List;
17 | import java.util.Objects;
18 | import java.util.function.Consumer;
19 |
20 | /**
21 | * EventSource listener for chat-related events.
22 | *
23 | * @author plexpt
24 | */
25 | @Slf4j
26 | public abstract class AbstractStreamListener extends EventSourceListener {
27 |
28 | protected String lastMessage = "";
29 |
30 |
31 | /**
32 | * Called when all new message are received.
33 | *
34 | * @param message the new message
35 | */
36 | @Setter
37 | @Getter
38 | protected Consumer onComplate = s -> {
39 |
40 | };
41 |
42 | /**
43 | * Called when a new message is received.
44 | * 收到消息 单个字
45 | *
46 | * @param message the new message
47 | */
48 | public abstract void onMsg(String message);
49 |
50 | /**
51 | * Called when an error occurs.
52 | * 出错时调用
53 | *
54 | * @param throwable the throwable that caused the error
55 | * @param response the response associated with the error, if any
56 | */
57 | public abstract void onError(Throwable throwable, String response);
58 |
59 | @Override
60 | public void onOpen(EventSource eventSource, Response response) {
61 | // do nothing
62 | }
63 |
64 | @Override
65 | public void onClosed(EventSource eventSource) {
66 | // do nothing
67 | }
68 |
69 | @Override
70 | public void onEvent(EventSource eventSource, String id, String type, String data) {
71 | if (data.equals("[DONE]")) {
72 | onComplate.accept(lastMessage);
73 | return;
74 | }
75 |
76 | ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
77 | // 读取Json
78 | List choices = response.getChoices();
79 | if (choices == null || choices.isEmpty()) {
80 | return;
81 | }
82 | Message delta = choices.get(0).getDelta();
83 | String text = delta.getContent();
84 |
85 | if (text != null) {
86 | lastMessage += text;
87 |
88 | onMsg(text);
89 |
90 | }
91 |
92 | }
93 |
94 |
95 | @SneakyThrows
96 | @Override
97 | public void onFailure(EventSource eventSource, Throwable throwable, Response response) {
98 |
99 | try {
100 | log.error("Stream connection error: {}", throwable);
101 |
102 | String responseText = "";
103 |
104 | if (Objects.nonNull(response)) {
105 | responseText = response.body().string();
106 | }
107 |
108 | log.error("response:{}", responseText);
109 |
110 | String forbiddenText = "Your access was terminated due to violation of our policies";
111 |
112 | if (StrUtil.contains(responseText, forbiddenText)) {
113 | log.error("Chat session has been terminated due to policy violation");
114 | log.error("检测到号被封了");
115 | }
116 |
117 | String overloadedText = "That model is currently overloaded with other requests.";
118 |
119 | if (StrUtil.contains(responseText, overloadedText)) {
120 | log.error("检测到官方超载了,赶紧优化你的代码,做重试吧");
121 | }
122 |
123 | this.onError(throwable, responseText);
124 |
125 | } catch (Exception e) {
126 | log.warn("onFailure error:{}", e);
127 | // do nothing
128 |
129 | } finally {
130 | eventSource.cancel();
131 | }
132 | }
133 | }
134 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/listener/ConsoleStreamListener.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.openai.listener;
2 |
3 | import lombok.extern.slf4j.Slf4j;
4 |
5 | /**
6 | * 控制台测试
7 | * Console Stream Test Listener
8 | *
9 | * @author plexpt
10 | */
11 | @Slf4j
12 | public class ConsoleStreamListener extends AbstractStreamListener {
13 |
14 |
15 | @Override
16 | public void onMsg(String message) {
17 | System.out.print(message);
18 | }
19 |
20 | @Override
21 | public void onError(Throwable throwable, String response) {
22 |
23 | }
24 |
25 | }
26 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/openai/listener/SseStreamListener.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.openai.listener;
2 |
3 |
4 | import com.cyl.ctrbt.util.SseHelper;
5 | import lombok.RequiredArgsConstructor;
6 | import lombok.extern.slf4j.Slf4j;
7 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
8 |
9 | /**
10 | * sse
11 | *
12 | * @author plexpt
13 | */
14 | @Slf4j
15 | @RequiredArgsConstructor
16 | public class SseStreamListener extends AbstractStreamListener {
17 |
18 | final SseEmitter sseEmitter;
19 |
20 |
21 | @Override
22 | public void onMsg(String message) {
23 | SseHelper.send(sseEmitter, message);
24 | }
25 |
26 | @Override
27 | public void onError(Throwable throwable, String response) {
28 | SseHelper.complete(sseEmitter);
29 | }
30 |
31 | }
32 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/util/ChatContextHolder.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.util;
2 |
3 |
4 | import com.cyl.ctrbt.openai.entity.chat.Message;
5 |
6 | import java.util.ArrayList;
7 | import java.util.HashMap;
8 | import java.util.List;
9 | import java.util.Map;
10 |
11 | public class ChatContextHolder {
12 |
13 | private static Map> context = new HashMap<>();
14 |
15 |
16 | /**
17 | * 获取对话历史
18 | *
19 | * @param id
20 | * @return
21 | */
22 | public static List get(String id) {
23 | List messages = context.get(id);
24 | if (messages == null) {
25 | messages = new ArrayList<>();
26 | context.put(id, messages);
27 | }
28 |
29 | return messages;
30 | }
31 |
32 |
33 | /**
34 | * 添加对话
35 | *
36 | * @param id
37 | * @return
38 | */
39 | public static void add(String id, String msg) {
40 |
41 | Message message = Message.builder().content(msg).build();
42 | add(id, message);
43 | }
44 |
45 |
46 | /**
47 | * 添加对话
48 | *
49 | * @param id
50 | * @return
51 | */
52 | public static void add(String id, Message message) {
53 | List messages = context.get(id);
54 | if (messages == null) {
55 | messages = new ArrayList<>();
56 | context.put(id, messages);
57 | }
58 | messages.add(message);
59 | }
60 |
61 | /**
62 | * 清除对话
63 | * @param id
64 | */
65 | public static void remove(String id) {
66 | context.remove(id);
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/util/Proxys.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.util;
2 |
3 | import lombok.experimental.UtilityClass;
4 |
5 | import java.net.InetSocketAddress;
6 | import java.net.Proxy;
7 |
8 |
9 | @UtilityClass
10 | public class Proxys {
11 |
12 |
13 | /**
14 | * http 代理
15 | *
16 | * @param ip
17 | * @param port
18 | * @return
19 | */
20 | public static Proxy http(String ip, int port) {
21 | return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(ip, port));
22 | }
23 |
24 | /**
25 | * socks5 代理
26 | *
27 | * @param ip
28 | * @param port
29 | * @return
30 | */
31 | public static Proxy socks5(String ip, int port) {
32 | return new Proxy(Proxy.Type.SOCKS, new InetSocketAddress(ip, port));
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/util/SseHelper.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.util;
2 |
3 | import lombok.experimental.UtilityClass;
4 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
5 |
6 | @UtilityClass
7 | public class SseHelper {
8 |
9 |
10 | public void complete(SseEmitter sseEmitter) {
11 |
12 | try {
13 | sseEmitter.complete();
14 | } catch (Exception e) {
15 |
16 | }
17 | }
18 |
19 | public void send(SseEmitter sseEmitter, Object data) {
20 |
21 | try {
22 | sseEmitter.send(data);
23 | } catch (Exception e) {
24 |
25 | }
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/websocket/MyWebSocketInterceptor.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.websocket;
2 |
3 | import cn.hutool.core.util.StrUtil;
4 | import org.slf4j.Logger;
5 | import org.slf4j.LoggerFactory;
6 | import org.springframework.http.server.ServerHttpRequest;
7 | import org.springframework.http.server.ServerHttpResponse;
8 | import org.springframework.http.server.ServletServerHttpRequest;
9 | import org.springframework.web.socket.WebSocketHandler;
10 | import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
11 |
12 | import java.util.Map;
13 |
14 | public class MyWebSocketInterceptor extends HttpSessionHandshakeInterceptor {
15 |
16 | private final Logger logger = LoggerFactory.getLogger(getClass());
17 |
18 | @Override
19 | public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception {
20 | logger.info("[MyWebSocketInterceptor#BeforeHandshake] Request from " + request.getRemoteAddress().getHostString());
21 | if (request instanceof ServletServerHttpRequest) {
22 | ServletServerHttpRequest serverHttpRequest = (ServletServerHttpRequest) request;
23 | String token = serverHttpRequest.getServletRequest().getHeader("token");
24 | if (StrUtil.isEmpty(token)) {
25 | token = serverHttpRequest.getServletRequest().getParameter("token");
26 | }
27 | //这里做一个简单的鉴权,只有符合条件的鉴权才能握手成功
28 | if ("token-123456".equals(token)) {
29 | return super.beforeHandshake(request, response, wsHandler, attributes);
30 | } else {
31 | return false;
32 | }
33 | }
34 | return super.beforeHandshake(request, response, wsHandler, attributes);
35 | }
36 |
37 | @Override
38 | public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex) {
39 | logger.info("[MyWebSocketInterceptor#afterHandshake] Request from " + request.getRemoteAddress().getHostString());
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/websocket/MyWebsocketHandler.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.websocket;
2 |
3 | import cn.hutool.core.lang.UUID;
4 | import cn.hutool.json.JSONUtil;
5 | import com.cyl.ctrbt.openai.ChatGPTStrreamUtil;
6 | import com.cyl.ctrbt.openai.ChatGPTUtil;
7 | import com.cyl.ctrbt.openai.entity.chat.Message;
8 | import com.cyl.ctrbt.websocket.bean.WebSocketBean;
9 | import org.slf4j.Logger;
10 | import org.slf4j.LoggerFactory;
11 | import org.springframework.beans.factory.annotation.Autowired;
12 | import org.springframework.stereotype.Component;
13 | import org.springframework.web.socket.CloseStatus;
14 | import org.springframework.web.socket.PongMessage;
15 | import org.springframework.web.socket.TextMessage;
16 | import org.springframework.web.socket.WebSocketSession;
17 | import org.springframework.web.socket.handler.AbstractWebSocketHandler;
18 |
19 | import java.util.Map;
20 | import java.util.concurrent.ConcurrentHashMap;
21 | import java.util.concurrent.atomic.AtomicInteger;
22 | @Component
23 | public class MyWebsocketHandler extends AbstractWebSocketHandler {
24 |
25 | @Autowired
26 | private ChatGPTUtil chatGPTUtil;
27 |
28 | private final Logger logger = LoggerFactory.getLogger(getClass());
29 |
30 | private static final Map webSocketBeanMap;
31 | private static final AtomicInteger clientIdMaker; //仅用用于标识客户端编号
32 |
33 | static {
34 | webSocketBeanMap = new ConcurrentHashMap<>();
35 | clientIdMaker = new AtomicInteger(0);
36 | }
37 |
38 | @Override
39 | public void afterConnectionEstablished(WebSocketSession session) throws Exception {
40 | //当WebSocket连接正式建立后,将该Session加入到Map中进行管理
41 | WebSocketBean webSocketBean = new WebSocketBean();
42 | webSocketBean.setWebSocketSession(session);
43 | webSocketBean.setClientId(UUID.fastUUID().toString());
44 | webSocketBeanMap.put(session.getId(), webSocketBean);
45 | }
46 |
47 | @Override
48 | public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
49 | //当连接关闭后,从Map中移除session实例
50 | webSocketBeanMap.remove(session.getId());
51 | }
52 |
53 | @Override
54 | public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
55 | logger.error("session {}", session.getId(), exception);
56 | //传输过程中出现了错误
57 | if (session.isOpen()) {
58 | session.close();
59 | }
60 | webSocketBeanMap.remove(session.getId());
61 | }
62 |
63 | @Override
64 | protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
65 | String user = webSocketBeanMap.get(session.getId()).getClientId();
66 | //处理接收到的消息
67 | logger.info("Received message from client[ID:" + user +
68 | "]; Content is [" + message.getPayload() + "].");
69 | TextMessage textMessage;
70 | Message returnMessage = chatGPTUtil.chat(message.getPayload(), user);
71 | textMessage= new TextMessage(returnMessage.getContent());
72 | session.sendMessage(textMessage);
73 | }
74 |
75 | @Override
76 | protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {
77 | super.handlePongMessage(session, message);
78 | }
79 | }
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/websocket/WebSocketConfiguration.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.websocket;
2 |
3 | import org.springframework.beans.factory.annotation.Autowired;
4 | import org.springframework.context.annotation.Bean;
5 | import org.springframework.context.annotation.Configuration;
6 | import org.springframework.web.socket.config.annotation.EnableWebSocket;
7 | import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
8 | import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
9 | import org.springframework.web.socket.server.standard.ServerEndpointExporter;
10 |
11 | @Configuration
12 | @EnableWebSocket
13 | public class WebSocketConfiguration implements WebSocketConfigurer {
14 |
15 | @Bean
16 | public MyWebSocketInterceptor webSocketInterceptor() {
17 | return new MyWebSocketInterceptor();
18 | }
19 |
20 | @Bean
21 | public ServerEndpointExporter serverEndpointExporter() {
22 | return new ServerEndpointExporter();
23 | }
24 | @Autowired
25 | private MyWebsocketHandler myWebsocketHandler;
26 |
27 | @Override
28 | public void registerWebSocketHandlers(WebSocketHandlerRegistry webSocketHandlerRegistry) {
29 | webSocketHandlerRegistry
30 | .addHandler(myWebsocketHandler, "/websocket")
31 | .setAllowedOrigins("*")
32 | .addInterceptors(webSocketInterceptor());
33 | }
34 | }
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/websocket/WebSocketServer.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.websocket;
2 |
3 | import lombok.extern.slf4j.Slf4j;
4 | import org.springframework.stereotype.Component;
5 | import org.springframework.stereotype.Service;
6 |
7 | import javax.websocket.*;
8 | import javax.websocket.server.PathParam;
9 | import javax.websocket.server.ServerEndpoint;
10 | import java.io.IOException;
11 | import java.util.concurrent.CopyOnWriteArraySet;
12 |
13 | @Component
14 | @Slf4j
15 | @Service
16 | @ServerEndpoint("/api/websocket/{sid}")
17 | public class WebSocketServer {
18 | //静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
19 | private static int onlineCount = 0;
20 | //concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。
21 | private static CopyOnWriteArraySet webSocketSet = new CopyOnWriteArraySet<>();
22 |
23 | //与某个客户端的连接会话,需要通过它来给客户端发送数据
24 | private Session session;
25 |
26 | //接收sid
27 | private String sid = "";
28 |
29 | /**
30 | * 连接建立成功调用的方法
31 | */
32 | @OnOpen
33 | public void onOpen(Session session, @PathParam("sid") String sid) {
34 | this.session = session;
35 | webSocketSet.add(this); //加入set中
36 | this.sid = sid;
37 | addOnlineCount(); //在线数加1
38 | try {
39 | sendMessage("conn_success");
40 | log.info("有新窗口开始监听:" + sid + ",当前在线人数为:" + getOnlineCount());
41 | } catch (IOException e) {
42 | log.error("websocket IO Exception");
43 | }
44 | }
45 |
46 | /**
47 | * 连接关闭调用的方法
48 | */
49 | @OnClose
50 | public void onClose() {
51 | webSocketSet.remove(this); //从set中删除
52 | subOnlineCount(); //在线数减1
53 | //断开连接情况下,更新主板占用情况为释放
54 | log.info("释放的sid为:"+sid);
55 | //这里写你 释放的时候,要处理的业务
56 | log.info("有一连接关闭!当前在线人数为" + getOnlineCount());
57 |
58 | }
59 |
60 | /**
61 | * 收到客户端消息后调用的方法
62 | * @ Param message 客户端发送过来的消息
63 | */
64 | @OnMessage
65 | public void onMessage(String message, Session session) {
66 | log.info("收到来自窗口" + sid + "的信息:" + message);
67 | //群发消息
68 | for (WebSocketServer item : webSocketSet) {
69 | try {
70 | item.sendMessage(message);
71 | } catch (IOException e) {
72 | e.printStackTrace();
73 | }
74 | }
75 | }
76 |
77 | /**
78 | * @ Param session
79 | * @ Param error
80 | */
81 | @OnError
82 | public void onError(Session session, Throwable error) {
83 | log.error("发生错误");
84 | error.printStackTrace();
85 | }
86 |
87 | /**
88 | * 实现服务器主动推送
89 | */
90 | public void sendMessage(String message) throws IOException {
91 | this.session.getBasicRemote().sendText(message);
92 | }
93 |
94 | /**
95 | * 群发自定义消息
96 | */
97 | public static void sendInfo(String message, @PathParam("sid") String sid) throws IOException {
98 | log.info("推送消息到窗口" + sid + ",推送内容:" + message);
99 |
100 | for (WebSocketServer item : webSocketSet) {
101 | try {
102 | //这里可以设定只推送给这个sid的,为null则全部推送
103 | if (sid == null) {
104 | // item.sendMessage(message);
105 | } else if (item.sid.equals(sid)) {
106 | item.sendMessage(message);
107 | }
108 | } catch (IOException e) {
109 | continue;
110 | }
111 | }
112 | }
113 |
114 | public static synchronized int getOnlineCount() {
115 | return onlineCount;
116 | }
117 |
118 | public static synchronized void addOnlineCount() {
119 | WebSocketServer.onlineCount++;
120 | }
121 |
122 | public static synchronized void subOnlineCount() {
123 | WebSocketServer.onlineCount--;
124 | }
125 |
126 | public static CopyOnWriteArraySet getWebSocketSet() {
127 | return webSocketSet;
128 | }
129 | }
130 |
--------------------------------------------------------------------------------
/src/main/java/com/cyl/ctrbt/websocket/bean/WebSocketBean.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.websocket.bean;
2 |
3 | import lombok.Data;
4 | import org.springframework.web.socket.WebSocketSession;
5 |
6 | import java.time.LocalDateTime;
7 |
8 | @Data
9 | public class WebSocketBean {
10 | // websocket
11 | private WebSocketSession webSocketSession;
12 | // 客户端id
13 | private String clientId;
14 | // 最后更新时间
15 | private LocalDateTime lastMessageTime;
16 | }
17 |
--------------------------------------------------------------------------------
/src/main/resources/application.yml:
--------------------------------------------------------------------------------
1 | server:
2 | port: 8081
3 | logging:
4 | level:
5 | root: info
6 | openai:
7 | secret_key: 测试前请先设置您的SECRET KEY,查看地址:https://platform.openai.com/account/api-keys
8 | proxy: #如果在国内访问,使用这个
9 | ip: 127.0.0.1
10 | port: 33210
--------------------------------------------------------------------------------
/src/test/java/com/cyl/ctrbt/ChatgptRobotBackApplicationTests.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt;
2 |
3 | import org.junit.jupiter.api.Test;
4 | import org.springframework.boot.test.context.SpringBootTest;
5 |
6 | @SpringBootTest
7 | class ChatgptRobotBackApplicationTests {
8 |
9 | @Test
10 | void contextLoads() {
11 | }
12 |
13 | }
14 |
--------------------------------------------------------------------------------
/src/test/java/com/cyl/ctrbt/openai/OpenAiTest.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.openai;
2 |
3 | import com.cyl.ctrbt.ChatgptRobotBackApplication;
4 | import com.cyl.ctrbt.openai.entity.chat.Message;
5 | import org.junit.Test;
6 | import org.junit.runner.RunWith;
7 | import org.springframework.beans.factory.annotation.Autowired;
8 | import org.springframework.boot.test.context.SpringBootTest;
9 | import org.springframework.test.context.ActiveProfiles;
10 | import org.springframework.test.context.junit4.SpringRunner;
11 |
12 | @RunWith(SpringRunner.class)
13 | @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT, classes = ChatgptRobotBackApplication.class)
14 | @ActiveProfiles("dev")
15 | public class OpenAiTest {
16 | @Autowired
17 | private ChatGPTUtil chatGPTUtil;
18 |
19 | @Test
20 | public void testChatGPT(){
21 | Message message = chatGPTUtil.chat("单身狗如何过情人节?", "单身狗");
22 | System.out.println(message.getContent());
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/src/test/java/com/cyl/ctrbt/websocket/MyWebSocketClient.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.websocket;
2 |
3 | import org.java_websocket.client.WebSocketClient;
4 | import org.java_websocket.handshake.ServerHandshake;
5 | import org.slf4j.Logger;
6 | import org.slf4j.LoggerFactory;
7 |
8 | import java.net.URI;
9 |
10 | public class MyWebSocketClient extends WebSocketClient {
11 |
12 | private Logger logger = LoggerFactory.getLogger(getClass());
13 |
14 | public MyWebSocketClient(URI serverUri) {
15 | super(serverUri);
16 | }
17 |
18 | @Override
19 | public void onOpen(ServerHandshake serverHandshake) {
20 | logger.info("[MyWebSocketClient#onOpen]The WebSocket connection is open.");
21 | }
22 |
23 | @Override
24 | public void onMessage(String s) {
25 | logger.info("[MyWebSocketClient#onMessage]The client has received the message from server." +
26 | "The Content is [" + s + "]");
27 | }
28 |
29 | @Override
30 | public void onClose(int i, String s, boolean b) {
31 | logger.info("[MyWebSocketClient#onClose]The WebSocket connection is close.");
32 | }
33 |
34 | @Override
35 | public void onError(Exception e) {
36 | logger.info("[MyWebSocketClient#onError]The WebSocket connection is error.");
37 | }
38 | }
--------------------------------------------------------------------------------
/src/test/java/com/cyl/ctrbt/websocket/WebSocketTest.java:
--------------------------------------------------------------------------------
1 | package com.cyl.ctrbt.websocket;
2 |
3 | import org.junit.Test;
4 |
5 | import java.net.URI;
6 | import java.util.Timer;
7 | import java.util.TimerTask;
8 | import java.util.concurrent.atomic.AtomicInteger;
9 |
10 | public class WebSocketTest {
11 | private static final AtomicInteger count = new AtomicInteger(0);
12 |
13 | @Test
14 | public void test() {
15 | URI uri = URI.create("ws://127.0.0.1:8081/websocket"); //注意协议号为ws
16 | MyWebSocketClient client = new MyWebSocketClient(uri);
17 | client.addHeader("token", "token-123456"); //这里为header添加了token,实现简单的校验
18 | try {
19 | client.connectBlocking(); //在连接成功之前会一直阻塞
20 |
21 | while (true) {
22 | if (client.isOpen()) {
23 | client.send("先有鸡还是先有蛋?");
24 | }
25 | Thread.sleep(1000000);
26 | }
27 | } catch (InterruptedException e) {
28 | e.printStackTrace();
29 | }
30 | }
31 |
32 | }
33 |
--------------------------------------------------------------------------------
/src/test/resources/application.yml:
--------------------------------------------------------------------------------
1 | logging:
2 | level:
3 | root: info
4 | openai:
5 | secret_key: 测试前请先设置您的SECRET KEY,查看地址:https://platform.openai.com/account/api-keys
6 |
--------------------------------------------------------------------------------