├── .gitignore
├── Readme.md
├── pom.xml
└── src
├── main
├── java
│ └── me
│ │ └── zhangjh
│ │ └── chatgpt
│ │ ├── client
│ │ └── ChatGptService.java
│ │ ├── config
│ │ ├── ChatGptConfig.java
│ │ └── HttpSessionWSHelper.java
│ │ ├── constant
│ │ ├── ModelEnum.java
│ │ └── RoleEnum.java
│ │ ├── dto
│ │ ├── Message.java
│ │ ├── request
│ │ │ ├── ChatBaseRequest.java
│ │ │ ├── ChatRequest.java
│ │ │ ├── ImageRequest.java
│ │ │ ├── TextRequest.java
│ │ │ └── TranscriptionRequest.java
│ │ └── response
│ │ │ ├── BizException.java
│ │ │ ├── ChatResponse.java
│ │ │ ├── ChatRet.java
│ │ │ ├── ChatStreamRet.java
│ │ │ ├── CompletionUsage.java
│ │ │ ├── ImageResponse.java
│ │ │ ├── ImageRet.java
│ │ │ ├── TextResponse.java
│ │ │ ├── TextRet.java
│ │ │ └── TranscriptionResponse.java
│ │ ├── service
│ │ └── ChatGptServiceImpl.java
│ │ ├── socket
│ │ └── SocketServer.java
│ │ └── util
│ │ └── BizHttpClientUtil.java
└── resources
│ ├── META-INF
│ └── spring.factories
│ └── application.properties
└── test
└── java
└── me
└── zhangjh
└── chatgpt
├── Application.java
└── ChatGptTest.java
/.gitignore:
--------------------------------------------------------------------------------
1 | /.idea/
2 | /target/
3 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | #### ChatGpt JAVA API Starter
2 | ##### 简介
3 | 这是一个基于Java开发的ChatGpt API库,非常易于接入使用。
4 | 你只需生成一个自己的 OpenAI API key,并在工程中引入本库依赖,即可便捷地使用 ChatGPT。
5 |
6 | 当前接口功能主要有文本补全、图片生成、Chat(含流式输出)和语音转写,Chat支持到gpt-3.5-turbo模型,也会跟随官方更新进行升级。
7 |
8 | #### 2.x版本已重构过,建议使用2.x版本,但是注意2.x版本和1.x版本不兼容,从1.x版本升级需要注意修改已接入代码
9 |
10 | ##### 如何使用?
11 | 0. 到[这里](https://beta.openai.com/docs/quickstart/build-your-application)生成一个自己的API KEY
12 | 1. 工程中加入依赖:
13 | ```xml
14 | <dependency>
15 |     <groupId>me.zhangjh</groupId>
16 |     <artifactId>chatgpt-starter</artifactId>
17 |     <version>${最新版本}</version>
18 | </dependency>
19 | ```
20 | 最新版本查询:https://mvnrepository.com/artifact/me.zhangjh/chatgpt-starter
21 |
22 | 2. 将生成的apiKey加入配置文件application.properties
23 | ```properties
24 | openai.apikey=xxxxxxxxxxxxxxxxxxx
25 | ```
26 | 或者将上述配置添加进环境变量(变量名同为 openai.apikey)
27 | 3. 代码中注入service
28 |
29 | ```java
30 | import org.springframework.beans.factory.annotation.Autowired;
31 |
32 | @Autowired
33 | private ChatGptService chatGptService;
34 | // 调用方法即可,其他方法不赘述
35 | TextResponse createTextCompletion(TextRequest textRequest, Map<String, String> bizParams);
36 | ImageResponse createImageGeneration(ImageRequest imageRequest, Map<String, String> bizParams);
37 |
38 | ```
39 |
40 |
41 | ## 我使用这个starter制作了一个微信小程序:AI文图,欢迎交流~
42 | 
43 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | me.zhangjh
8 | chatgpt-starter
9 | jar
10 | 2.0.6
11 |
12 | chatgpt-starter
13 | A springboot starter for chatgpt api.
14 | git@github.com:zhangjh/chatgpt-starter.git
15 |
16 |
17 | scm:git:git@github.com:zhangjh/chatgpt-starter.git
18 | scm:git:ssh://git@github.com:zhangjh/chatgpt-starter.git
19 | git@github.com:zhangjh/chatgpt-starter.git
20 |
21 |
22 |
23 |
24 | zhangjh
25 | njhxzhangjihong@126.com
26 | zhangjh.me
27 | https://zhangjh.me
28 | +8
29 |
30 |
31 |
32 |
33 | GPL
34 | https://www.gnu.org/licenses/gpl.html
35 |
36 |
37 |
38 |
39 | ossrh
40 | https://s01.oss.sonatype.org/content/repositories/snapshots
41 |
42 |
43 | ossrh
44 | https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/
45 |
46 |
47 |
48 |
49 | 11
50 | 11
51 |
52 |
53 |
54 |
55 | org.projectlombok
56 | lombok
57 | 1.18.26
58 |
59 |
60 | com.alibaba
61 | fastjson
62 | 2.0.25
63 |
64 |
65 | org.springframework
66 | spring-context
67 | 5.3.25
68 |
69 |
70 | javax.annotation
71 | javax.annotation-api
72 | 1.3.2
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 | junit
81 | junit
82 | 4.13.2
83 | test
84 |
85 |
86 | org.springframework
87 | spring-test
88 | 5.3.25
89 | test
90 |
91 |
92 | org.springframework.boot
93 | spring-boot-test
94 | 2.7.9
95 | test
96 |
97 |
98 |
99 | me.zhangjh
100 | share
101 | 2.0.5-1
102 |
103 |
104 |
105 | org.apache.httpcomponents
106 | httpclient
107 | 4.5.14
108 |
109 |
110 | org.apache.commons
111 | commons-lang3
112 | 3.12.0
113 |
114 |
115 | ch.qos.logback
116 | logback-core
117 | 1.4.5
118 |
119 |
120 | org.slf4j
121 | log4j-over-slf4j
122 | 2.0.6
123 |
124 |
125 | org.springframework
126 | spring-webmvc
127 | 5.3.25
128 |
129 |
130 | org.apache.tomcat.embed
131 | tomcat-embed-websocket
132 | 9.0.71
133 |
134 |
135 |
136 | com.squareup.okhttp3
137 | okhttp
138 | 4.10.0
139 |
140 |
141 | kotlin-stdlib
142 | org.jetbrains.kotlin
143 |
144 |
145 |
146 |
147 | jakarta.validation
148 | jakarta.validation-api
149 | 3.0.2
150 |
151 |
152 | org.apache.commons
153 | commons-collections4
154 | 4.4
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 | org.apache.maven.plugins
163 | maven-resources-plugin
164 | 3.2.0
165 |
166 |
167 | org.springframework.boot
168 | spring-boot-maven-plugin
169 |
170 |
171 | org.apache.maven.plugins
172 | maven-compiler-plugin
173 |
174 | utf-8
175 | 11
176 | 11
177 |
178 |
179 |
180 | org.sonatype.plugins
181 | nexus-staging-maven-plugin
182 | 1.6.7
183 | true
184 |
185 | ossrh
186 | https://s01.oss.sonatype.org/
187 | true
188 |
189 |
190 |
191 | org.apache.maven.plugins
192 | maven-source-plugin
193 | 2.2.1
194 |
195 |
196 | attach-sources
197 |
198 | jar-no-fork
199 |
200 |
201 |
202 |
203 |
204 | org.apache.maven.plugins
205 | maven-gpg-plugin
206 | 1.5
207 |
208 |
209 | sign-artifacts
210 | verify
211 |
212 | sign
213 |
214 |
215 | 0x7F777CAF
216 | 0x7F777CAF
217 |
218 |
219 |
220 |
221 |
222 | org.apache.maven.plugins
223 | maven-javadoc-plugin
224 | 2.9.1
225 |
226 | -Xdoclint:none
227 |
228 |
229 |
230 | attach-javadocs
231 |
232 | jar
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/client/ChatGptService.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.client;
2 |
3 | import me.zhangjh.chatgpt.dto.request.ChatRequest;
4 | import me.zhangjh.chatgpt.dto.request.ImageRequest;
5 | import me.zhangjh.chatgpt.dto.request.TextRequest;
6 | import me.zhangjh.chatgpt.dto.request.TranscriptionRequest;
7 | import me.zhangjh.chatgpt.dto.response.ChatResponse;
8 | import me.zhangjh.chatgpt.dto.response.ImageResponse;
9 | import me.zhangjh.chatgpt.dto.response.TextResponse;
10 | import me.zhangjh.chatgpt.dto.response.TranscriptionResponse;
11 | import me.zhangjh.chatgpt.socket.SocketServer;
12 |
13 | import java.util.Map;
14 | import java.util.function.Function;
15 |
/**
 * Client-facing API of the starter: text completion, image generation, chat (blocking and
 * streamed) and audio transcription against the OpenAI REST endpoints.
 *
 * <p>NOTE(review): the generic type arguments of {@code Map}/{@code Function} appear to be missing
 * in this copy of the file. The default implementation reads {@code bizParams.get("userId")} as a
 * String and merges {@code bizParams} into its HTTP headers, so it presumably expects
 * {@code Map<String, String>}; confirm against the repository before changing the signatures.
 *
 * @author zhangjh
 * @date 1:41 PM 2022/12/15
 */
public interface ChatGptService {

    /**
     * Creates a text completion.
     *
     * @param textRequest completion parameters; the default implementation rejects requests with
     *                    {@code stream} set to true
     * @param bizParams   caller-supplied business parameters (unused by the default implementation)
     * @return the parsed text-completion response
     */
    TextResponse createTextCompletion(TextRequest textRequest, Map bizParams);

    /**
     * Generates images from a prompt.
     *
     * @param imageRequest generation parameters; validated by {@code ImageRequest#check()} in the
     *                     default implementation
     * @param bizParams    caller-supplied business parameters (unused by the default implementation)
     * @return the parsed image-generation response
     */
    ImageResponse createImageGeneration(ImageRequest imageRequest, Map bizParams);

    /**
     * Creates a chat completion and waits for the full (non-streamed) answer.
     *
     * @param request   chat parameters, including the message history
     * @param bizParams caller-supplied business parameters (unused by the default implementation)
     * @return the parsed chat-completion response
     */
    ChatResponse createChatCompletion(ChatRequest request, Map bizParams);

    /**
     * Streams a chat completion and forwards the chunks to the client through the given
     * WebSocket server instead of an SseEmitter (originally built for a WeChat client).
     *
     * <p>The default implementation sets {@code request.stream} to true and requires a non-empty
     * {@code "userId"} entry in {@code bizParams}.
     *
     * @param request      chat parameters
     * @param bizParams    business parameters; must contain {@code "userId"}
     * @param socketServer endpoint used to push the streamed chunks to the client
     * @param bizCb        callback handed through to the streaming HTTP helper; its exact
     *                     contract is defined there (not visible here)
     */
    void createChatCompletionStream(ChatRequest request, Map bizParams, SocketServer socketServer,
            Function bizCb);

    /**
     * Transcribes an audio file.
     *
     * @param request   transcription parameters, including the audio file reference
     * @param bizParams caller-supplied business parameters (unused by the default implementation)
     * @return the parsed transcription response
     */
    TranscriptionResponse createTranscription(TranscriptionRequest request, Map bizParams);
}
48 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/config/ChatGptConfig.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.config;
2 |
3 | import me.zhangjh.chatgpt.client.ChatGptService;
4 | import me.zhangjh.chatgpt.service.ChatGptServiceImpl;
5 | import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
6 | import org.springframework.context.annotation.Bean;
7 | import org.springframework.context.annotation.Configuration;
8 |
9 | /**
10 | * @author zhangjh
11 | * @date 4:09 PM 2022/12/15
12 | * @Description
13 | */
14 | @Configuration
15 | public class ChatGptConfig {
16 |
17 | @Bean
18 | @ConditionalOnMissingBean
19 | public ChatGptService chatGptService() {
20 | return new ChatGptServiceImpl();
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/config/HttpSessionWSHelper.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.config;
2 |
3 | import javax.servlet.http.HttpSession;
4 | import javax.websocket.HandshakeResponse;
5 | import javax.websocket.server.HandshakeRequest;
6 | import javax.websocket.server.ServerEndpointConfig;
7 |
8 | /**
9 | * @author njhxzhangjihong@126.com
10 | * @date 3:33 PM 2023/3/13
11 | * @Description
12 | */
13 | public class HttpSessionWSHelper extends ServerEndpointConfig.Configurator {
14 |
15 | @Override
16 | public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
17 | HttpSession session = (HttpSession) request.getHttpSession();
18 | if (session != null){
19 | sec.getUserProperties().put(HttpSession.class.getName(), session);
20 | }
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/constant/ModelEnum.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.constant;
2 |
3 | import lombok.Getter;
4 |
5 | /**
6 | * @author njhxzhangjihong@126.com
7 | * @date 9:18 AM 2023/3/2
8 | * @Description
9 | */
10 | @Getter
11 | public enum ModelEnum {
12 |
13 | DAVINCI("text-davinci-003"),
14 | CHATGPT_TURBO("gpt-3.5-turbo"),
15 | CHATGPT_0301("gpt-3.5-turbo-0301"),
16 | ;
17 |
18 | private String code;
19 |
20 | ModelEnum(String code) {
21 | this.code = code;
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/constant/RoleEnum.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.constant;
2 |
/**
 * Author roles of a chat message.
 *
 * <p>The constants are intentionally lower-case, presumably so that {@code name()} yields the
 * literal role string used in requests; confirm before renaming. Declaration order is significant
 * for {@code ordinal()} and {@code values()}.
 *
 * @author njhxzhangjihong@126.com
 * @date 9:31 AM 2023/3/2
 */
public enum RoleEnum {
    system,
    user,
    assistant
}
11 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/Message.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto;
2 |
3 | import lombok.Data;
4 |
5 | /**
6 | * @author njhxzhangjihong@126.com
7 | * @date 9:53 AM 2023/3/2
8 | * @Description
9 | */
10 | @Data
11 | public class Message {
12 | private String role;
13 | private String content;
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/ChatBaseRequest.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.request;
2 |
3 | import com.alibaba.fastjson.annotation.JSONField;
4 | import jakarta.validation.constraints.NotNull;
5 | import lombok.Data;
6 |
7 | /**
8 | * @author njhxzhangjihong@126.com
9 | * @date 10:07 AM 2023/3/2
10 | * @Description
11 | */
12 | @Data
13 | public class ChatBaseRequest {
14 |
15 | /** ID of the model to use */
16 | @NotNull
17 | private String model = "text-davinci-003";
18 |
19 | /** Control the Creativity, it can be 0~1.
20 | * It’s not recommended to use the temperature with the Top_p parameter
21 | * */
22 | private Double temperature;
23 |
24 | /** the maximum num of tokens to generate in the completion */
25 | @JSONField(name = "max_tokens")
26 | private Integer maxTokens;
27 |
28 | @JSONField(name = "top_p")
29 | private Integer topP;
30 |
31 | /**
32 | * how many completions to generate for each prompt
33 | * default to 1
34 | * */
35 | private Integer n = 1;
36 |
37 | @JSONField(name = "frequency_penalty")
38 | private Double frequencyPenalty;
39 |
40 | @JSONField(name = "presence_penalty")
41 |
42 | private Double presencePenalty;
43 |
44 | /**
45 | * up to 4 sequences where the API stop generating more text.
46 | * */
47 | private String stop;
48 |
49 | /**
50 | * Whether to stream back partial progress.
51 | * */
52 | private Boolean stream = false;
53 | }
54 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/ChatRequest.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.request;
2 |
3 | import jakarta.validation.constraints.NotEmpty;
4 | import jakarta.validation.constraints.NotNull;
5 | import lombok.Data;
6 | import me.zhangjh.chatgpt.constant.ModelEnum;
7 | import me.zhangjh.chatgpt.dto.Message;
8 |
9 | import java.util.List;
10 |
11 | /**
12 | * @author njhxzhangjihong@126.com
13 | * @date 9:29 AM 2023/3/2
14 | * @Description
15 | */
16 | @Data
17 | public class ChatRequest extends ChatBaseRequest {
18 |
19 | @NotNull
20 | private String model = ModelEnum.CHATGPT_TURBO.getCode();
21 |
22 | @NotEmpty
23 | private List messages;
24 | }
25 |
26 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/ImageRequest.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.request;
2 |
3 | import com.alibaba.fastjson.annotation.JSONField;
4 | import jakarta.validation.constraints.NotNull;
5 | import lombok.Data;
6 |
7 | import java.util.Arrays;
8 | import java.util.List;
9 |
10 | /**
11 | * @author zhangjh
12 | * @date 2022/12/15
13 | * @Description
14 | */
15 | @Data
16 | public class ImageRequest {
17 |
18 | /** a text description of the desired images */
19 | @NotNull
20 | private String prompt;
21 |
22 | /** a num >=1 && <= 10, images to generate*/
23 | private int n = 1;
24 |
25 | /** the size of images */
26 | private String size = "1024x1024";
27 |
28 | /** must be url or b64_json, default url */
29 | @JSONField(name = "response_format")
30 | private String responseFormat = "url";
31 |
32 | public void check() {
33 | List validSizeList = Arrays.asList("256x256", "512x512", "1024x1024");
34 | if(validSizeList.stream().noneMatch(item -> this.size.equals(item))) {
35 | throw new RuntimeException("invalid image size:" + this.size);
36 | }
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/TextRequest.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.request;
2 |
3 | import com.alibaba.fastjson.annotation.JSONField;
4 | import jakarta.validation.constraints.NotNull;
5 | import lombok.Data;
6 | import me.zhangjh.chatgpt.constant.ModelEnum;
7 |
8 | /**
9 | * @author zhangjh
10 | * @date 3:05 PM 2022/12/15
11 | * @Description
12 | */
13 | @Data
14 | public class TextRequest extends ChatBaseRequest {
15 |
16 | /** ID of the model to use */
17 | @NotNull
18 | private String model = ModelEnum.CHATGPT_TURBO.getCode();
19 |
20 | /** the propmts to generate completions for */
21 | @NotNull
22 | private String prompt;
23 |
24 | /** the suffix that comes after a completion of inserted text */
25 | private String suffix;
26 |
27 | /** the maximum num of tokens to generate in the completion */
28 | @JSONField(name = "max_tokens")
29 | private Integer maxTokens = 2048;
30 |
31 | @JSONField(name = "best_of")
32 | private Integer bestOf = 1;
33 |
34 | public void check() {
35 | if(this.getStream()) {
36 | throw new RuntimeException("do not support stream");
37 | }
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/TranscriptionRequest.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.request;
2 |
3 | import jakarta.validation.constraints.NotNull;
4 | import lombok.Data;
5 |
6 | /**
7 | * @author njhxzhangjihong@126.com
8 | * @date 6:01 PM 2023/3/14
9 | * @Description
10 | */
11 | @Data
12 | public class TranscriptionRequest {
13 |
14 | private String file;
15 |
16 | /**
17 | * An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language.
18 | * */
19 | @NotNull
20 | private String model = "whisper-1";
21 |
22 | @NotNull
23 | private String prompt;
24 |
25 | private Integer temperature = 0;
26 |
27 | /**
28 | * The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency.
29 | * */
30 | private String language = "ISO-639-1";
31 | }
32 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/BizException.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
/**
 * Checked exception for business-level failures raised by this starter.
 *
 * @author njhxzhangjihong@126.com
 * @date 5:58 PM 2023/2/15
 */
public class BizException extends Exception {

    private static final long serialVersionUID = 1L;

    /**
     * @param error description of the failure
     */
    public BizException(String error) {
        super(error);
    }

    /**
     * @param cause underlying failure; its description becomes this exception's message
     */
    public BizException(Throwable cause) {
        super(cause);
    }

    /**
     * Wraps a lower-level failure while keeping a context message, so neither is lost.
     *
     * @param error description of the failure
     * @param cause underlying failure
     */
    public BizException(String error, Throwable cause) {
        super(error, cause);
    }
}
17 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ChatResponse.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
3 | import lombok.Data;
4 |
5 | import java.util.Date;
6 | import java.util.List;
7 |
8 | /**
9 | * @author njhxzhangjihong@126.com
10 | * @date 9:50 AM 2023/3/2
11 | * @Description
12 | */
13 | @Data
14 | public class ChatResponse {
15 |
16 | private String id;
17 |
18 | private String object;
19 |
20 | private Date created;
21 |
22 | private String model;
23 |
24 | private String errorMsg;
25 |
26 | private List choices;
27 |
28 | private CompletionUsage usage;
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ChatRet.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
3 | import lombok.Data;
4 | import me.zhangjh.chatgpt.dto.Message;
5 |
6 | import java.util.List;
7 |
8 | /**
9 | * @author njhxzhangjihong@126.com
10 | * @date 9:51 AM 2023/3/2
11 | * @Description
12 | */
13 | @Data
14 | public class ChatRet extends TextRet {
15 |
16 | private Message message;
17 |
18 | private List delta;
19 | }
20 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ChatStreamRet.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
3 | import lombok.Data;
4 |
5 | /**
6 | * @author njhxzhangjihong@126.com
7 | * @date 5:34 PM 2023/3/13
8 | * @Description
9 | */
10 | @Data
11 | public class ChatStreamRet extends TextRet {
12 |
13 | private String content;
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/CompletionUsage.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
3 | import com.alibaba.fastjson.annotation.JSONField;
4 | import lombok.Data;
5 |
6 | /**
7 | * @author zhangjh
8 | * @date 3:03 PM 2022/12/15
9 | * @Description
10 | */
11 | @Data
12 | public class CompletionUsage {
13 | @JSONField(name = "prompt_tokens")
14 | private int promptTokens;
15 |
16 | @JSONField(name = "completion_tokens")
17 | private int completionTokens;
18 |
19 | @JSONField(name = "total_tokens")
20 | private int totalTokens;
21 | }
22 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ImageResponse.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
3 | import lombok.Data;
4 |
5 | import java.util.Date;
6 | import java.util.List;
7 |
8 | /**
9 | * @author zhangjh
10 | * @date 3:13 PM 2022/12/15
11 | * @Description
12 | */
13 | @Data
14 | public class ImageResponse {
15 |
16 | private Date create;
17 |
18 | private List data;
19 |
20 | private String errorMsg;
21 | }
22 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ImageRet.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
3 | import lombok.Data;
4 |
5 | /**
6 | * @author zhangjh
7 | * @date 3:13 PM 2022/12/15
8 | * @Description
9 | */
10 | @Data
11 | public class ImageRet {
12 | private String url;
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/TextResponse.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
3 | import lombok.Data;
4 |
5 | import java.util.Date;
6 | import java.util.List;
7 |
8 | /**
9 | * @author zhangjh
10 | * @date 3:02 PM 2022/12/15
11 | * @Description
12 | */
13 | @Data
14 | public class TextResponse {
15 | private String id;
16 |
17 | private String object;
18 |
19 | private Date created;
20 |
21 | private String model;
22 |
23 | private String errorMsg;
24 |
25 | private List choices;
26 |
27 | private CompletionUsage usage;
28 | }
29 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/TextRet.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
3 | import lombok.Data;
4 |
5 | /**
6 | * @author zhangjh
7 | * @date 3:02 PM 2022/12/15
8 | * @Description
9 | */
10 | @Data
11 | public class TextRet {
12 | private String text;
13 |
14 | private int index;
15 |
16 | private String finishReason;
17 | }
18 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/TranscriptionResponse.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.dto.response;
2 |
3 | import lombok.Data;
4 |
5 | /**
6 | * @author njhxzhangjihong@126.com
7 | * @date 6:09 PM 2023/3/14
8 | * @Description
9 | */
10 | @Data
11 | public class TranscriptionResponse {
12 |
13 | private String text;
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/service/ChatGptServiceImpl.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.service;
2 |
3 | import com.alibaba.fastjson.JSONObject;
4 | import lombok.extern.slf4j.Slf4j;
5 | import me.zhangjh.chatgpt.client.ChatGptService;
6 | import me.zhangjh.chatgpt.dto.request.ChatRequest;
7 | import me.zhangjh.chatgpt.dto.request.ImageRequest;
8 | import me.zhangjh.chatgpt.dto.request.TextRequest;
9 | import me.zhangjh.chatgpt.dto.request.TranscriptionRequest;
10 | import me.zhangjh.chatgpt.dto.response.ChatResponse;
11 | import me.zhangjh.chatgpt.dto.response.ImageResponse;
12 | import me.zhangjh.chatgpt.dto.response.TextResponse;
13 | import me.zhangjh.chatgpt.dto.response.TranscriptionResponse;
14 | import me.zhangjh.chatgpt.socket.SocketServer;
15 | import me.zhangjh.chatgpt.util.BizHttpClientUtil;
16 | import me.zhangjh.share.util.HttpClientUtil;
17 | import me.zhangjh.share.util.HttpRequest;
18 | import org.apache.commons.collections4.MapUtils;
19 | import org.apache.commons.lang3.StringUtils;
20 | import org.springframework.beans.factory.annotation.Value;
21 | import org.springframework.util.Assert;
22 |
23 | import javax.annotation.PostConstruct;
24 | import java.util.HashMap;
25 | import java.util.Map;
26 | import java.util.function.Function;
27 |
28 | /**
29 | * @author zhangjh
30 | * @date 3:21 PM 2022/12/15
31 | * @Description
32 | */
33 | @Slf4j
34 | public class ChatGptServiceImpl implements ChatGptService {
35 |
36 | @Value("${openai.apikey}")
37 | private String configApiKey;
38 |
39 | private String apiKey;
40 |
41 | private final Map header = new HashMap<>();
42 |
43 | private static final String TEXT_COMPLETION_URL = "https://api.openai.com/v1/completions";
44 |
45 | private static final String IMAGE_GENERATE_URL = "https://api.openai.com/v1/images/generations";
46 |
47 | @Value("${openai.chat.url:https://api.openai.com/v1/chat/completions}")
48 | private String chatUrl;
49 |
50 | private static final String TRANSCRIPTION_URL = "https://api.openai.com/v1/audio/transcriptions";
51 |
52 | @PostConstruct
53 | public void init() {
54 | if(StringUtils.isEmpty(configApiKey)) {
55 | configApiKey = System.getenv("openai.apikey");
56 | }
57 | if(StringUtils.isEmpty(apiKey)) {
58 | apiKey = configApiKey;
59 | }
60 | Assert.isTrue(StringUtils.isNotEmpty(apiKey), "openai apiKey not exist");
61 | // openAi
62 | if(apiKey.startsWith("sk-")) {
63 | header.put("Authorization", "Bearer " + apiKey);
64 | } else {
65 | // azure
66 | header.put("api-key", apiKey);
67 | }
68 | }
69 |
70 | /**
71 | * allow set apiKey from outside, but you must define this bean yourself
72 | */
73 | public void setApiKey(String apiKey) {
74 | this.apiKey = apiKey;
75 | }
76 |
77 | @Override
78 | public TextResponse createTextCompletion(TextRequest textRequest, Map bizParams) {
79 | // this interface must set request.stream to false
80 | textRequest.check();
81 |
82 | HttpRequest httpRequest = new HttpRequest(TEXT_COMPLETION_URL);
83 | httpRequest.setReqData(JSONObject.toJSONString(textRequest));
84 | httpRequest.setBizHeaderMap(this.header);
85 | String response = HttpClientUtil.sendNormally(httpRequest).toString();
86 | return JSONObject.parseObject(response, TextResponse.class);
87 | }
88 |
89 | @Override
90 | public ImageResponse createImageGeneration(ImageRequest imageRequest, Map bizParams) {
91 | imageRequest.check();
92 | HttpRequest httpRequest = new HttpRequest(IMAGE_GENERATE_URL);
93 | httpRequest.setReqData(JSONObject.toJSONString(imageRequest));
94 | httpRequest.setBizHeaderMap(this.header);
95 | String response = HttpClientUtil.sendNormally(httpRequest).toString();
96 | return JSONObject.parseObject(response, ImageResponse.class);
97 | }
98 |
99 | @Override
100 | public ChatResponse createChatCompletion(ChatRequest request, Map bizParams) {
101 | HttpRequest httpRequest = new HttpRequest(chatUrl);
102 | httpRequest.setReqData(JSONObject.toJSONString(request));
103 | httpRequest.setBizHeaderMap(this.header);
104 | String response = HttpClientUtil.sendNormally(httpRequest).toString();
105 | return JSONObject.parseObject(response, ChatResponse.class);
106 | }
107 |
108 | @Override
109 | public void createChatCompletionStream(ChatRequest request, Map bizParams,
110 | SocketServer socketServer, Function bizCb) {
112 | request.setStream(true);
113 | Assert.isTrue(MapUtils.isNotEmpty(bizParams) && StringUtils.isNotEmpty(bizParams.get("userId")),
114 | "userId为空");
115 | this.header.putAll(bizParams);
116 | HttpRequest httpRequest = new HttpRequest(chatUrl);
117 | httpRequest.setReqData(JSONObject.toJSONString(request));
118 | httpRequest.setBizHeaderMap(this.header);
119 |
120 | BizHttpClientUtil.sendStream(httpRequest, socketServer, bizCb);
121 | }
122 |
123 | @Override
124 | public TranscriptionResponse createTranscription(TranscriptionRequest request, Map bizParams) {
125 | HttpRequest httpRequest = new HttpRequest(TRANSCRIPTION_URL);
126 | httpRequest.setReqData(JSONObject.toJSONString(request));
127 | httpRequest.setBizHeaderMap(this.header);
128 | String response = BizHttpClientUtil.sendFileMultiPart(httpRequest).toString();
129 | return JSONObject.parseObject(response, TranscriptionResponse.class);
130 | }
131 | }
132 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/socket/SocketServer.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.socket;
2 |
3 | import com.alibaba.fastjson.JSONObject;
4 | import lombok.Data;
5 | import lombok.extern.slf4j.Slf4j;
6 | import org.apache.commons.lang3.StringUtils;
7 | import org.springframework.util.Assert;
8 |
9 | import javax.annotation.PostConstruct;
10 | import javax.websocket.*;
11 | import javax.websocket.server.PathParam;
12 | import javax.websocket.server.ServerEndpoint;
13 | import java.io.EOFException;
14 | import java.io.IOException;
15 | import java.util.concurrent.ConcurrentHashMap;
16 | import java.util.concurrent.ConcurrentMap;
17 |
18 | /**
19 | * @author njhxzhangjihong@126.com
20 | * @date 3:03 PM 2023/3/13
21 | * @Description configure me as spring bean,you can override me to custom your sendMessage method
22 | */
23 | @Data
24 | @Slf4j
25 | @ServerEndpoint("/socket/chatStream/{userId}")
26 | public class SocketServer {
27 |
28 | private String userId;
29 | private Session session;
30 |
31 | protected static ConcurrentMap userMap = new ConcurrentHashMap<>();
32 |
33 | private static final String FINISH_FLAG = "[DONE]";
34 |
35 | @PostConstruct
36 | public void init() {
37 | log.info("chatSocketServer inited");
38 | }
39 |
40 | @OnOpen
41 | public void onOpen(Session session, @PathParam("userId") String userId) {
42 | this.session = session;
43 | this.userId = userId;
44 | log.info("onOpen...userId: {}", userId);
45 | userMap.putIfAbsent(userId, this);
46 | }
47 |
48 | @OnClose
49 | public void onClose() {
50 | log.info("onClose..");
51 | // should notice clients to update connect status
52 | // when multiple clients connected
53 | SocketServer server = userMap.get(this.userId);
54 | if(server != null) {
55 | try {
56 | server.getSession().close();
57 | } catch (IOException e) {
58 | log.error("socketServer close exception:", e);
59 | }
60 | }
61 | userMap.remove(this.userId);
62 | }
63 |
64 | @OnMessage
65 | public void onMessage(String msg, Session session) {
66 | // in this project, client will not send message to server
67 | }
68 |
69 | /** you must override this method to implements you biz log
70 | * message removed the beginning 'data:'
71 | * */
72 | public void sendMessage(String userId, String message) {
73 | log.info("userId: {}, message: {}", userId, message);
74 | if(StringUtils.isNotEmpty(message)) {
75 | Assert.isTrue(userMap.size() != 0, "socket not connected");
76 | Assert.isTrue(StringUtils.isNotEmpty(userId), "userId empty");
77 |
78 | log.info("message: {}", message);
79 | sendMsgInternal(userId, message);
80 | }
81 | }
82 |
83 | private void sendMsgInternal(String userId, String message) {
84 | try {
85 | SocketServer server = userMap.get(userId);
86 | Assert.isTrue(server != null, "socketServer empty");
87 | Session session = server.getSession();
88 | Assert.isTrue(session.isOpen(), "socket not connected");
89 | TextMessage textMessage = new TextMessage();
90 | textMessage.setUserId(userId);
91 | textMessage.setMessage(message);
92 | log.info("textMessage: {}", textMessage);
93 | synchronized (session.getId()) {
94 | session.getBasicRemote().sendText(JSONObject.toJSONString(textMessage));
95 | }
96 | } catch (IOException e) {
97 | log.error("socket sendText exception:", e);
98 | }
99 | }
100 |
101 | @OnError
102 | public void onError(Session session, Throwable t) {
103 | if(t instanceof EOFException) {
104 | log.info("socket timeout");
105 | this.onClose();
106 | }
107 | }
108 |
109 | @Data
110 | static class TextMessage {
111 | private String userId;
112 | private String message;
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/util/BizHttpClientUtil.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt.util;
2 |
3 | import com.alibaba.fastjson.JSONObject;
4 | import lombok.extern.slf4j.Slf4j;
5 | import me.zhangjh.chatgpt.dto.request.TranscriptionRequest;
6 | import me.zhangjh.chatgpt.dto.response.ChatResponse;
7 | import me.zhangjh.chatgpt.dto.response.ChatRet;
8 | import me.zhangjh.chatgpt.dto.response.ChatStreamRet;
9 | import me.zhangjh.chatgpt.socket.SocketServer;
10 | import me.zhangjh.share.util.HttpClientUtil;
11 | import me.zhangjh.share.util.HttpRequest;
12 | import okhttp3.*;
13 | import okhttp3.sse.EventSource;
14 | import okhttp3.sse.EventSourceListener;
15 | import okhttp3.sse.EventSources;
16 | import org.apache.commons.lang3.StringUtils;
17 | import org.jetbrains.annotations.NotNull;
18 | import org.jetbrains.annotations.Nullable;
19 | import org.springframework.util.Assert;
20 |
21 | import java.io.File;
22 | import java.io.IOException;
23 | import java.util.HashMap;
24 | import java.util.List;
25 | import java.util.Objects;
26 | import java.util.function.Function;
27 |
28 | /**
29 | * @author njhxzhangjihong@126.com
30 | * @date 1:05 AM 2023/3/17
31 | * @Description
32 | */
33 | @Slf4j
34 | public class BizHttpClientUtil extends HttpClientUtil {
35 |
36 | private static HashMap msgMap = new HashMap<>();
37 |
38 | private static final String FINISH_FLAG = "[DONE]";
39 |
40 | public static void sendStream(HttpRequest httpRequest, SocketServer socketServer, Function bizCb) {
41 | EventSource.Factory factory = EventSources.createFactory(OK_HTTP_CLIENT);
42 | EventSourceListener eventSourceListener = new EventSourceListener() {
43 | @Override
44 | public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
45 | log.info("id: {}, type: {}, data: {}", id, type, data);
46 | // 异步结果回来后要从eventSource里取userId,因为不确定返回事件里的结果是哪次请求的
47 | String userId = eventSource.request().header("userId");
48 | handleResponse(userId, data, socketServer, bizCb);
49 | }
50 |
51 | @Override
52 | public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
53 | handleException(t);
54 | }
55 | };
56 | factory.newEventSource(buildRequest(httpRequest), eventSourceListener);
57 | }
58 |
59 | public static Object sendFileMultiPart(HttpRequest request) {
60 | TranscriptionRequest transcriptionRequest = JSONObject.parseObject(request.getReqData(), TranscriptionRequest.class);
61 |
62 | MediaType mediaType = MediaType.parse("multipart/form-data");
63 | RequestBody body = new MultipartBody.Builder().setType(MultipartBody.FORM)
64 | .addFormDataPart("file", transcriptionRequest.getFile(),
65 | RequestBody.create(new File(transcriptionRequest.getFile()), mediaType))
66 | .addFormDataPart("model", transcriptionRequest.getModel())
67 | .build();
68 | Request httpRequest = new Request.Builder()
69 | .url(request.getUrl())
70 | .method("POST", body)
71 | .addHeader("Authorization", request.getBizHeaderMap().get("Authorization"))
72 | .addHeader("Content-Type", "multipart/form-data")
73 | .build();
74 | try (Response response = OK_HTTP_CLIENT.newCall(httpRequest).execute()){
75 | return handleResponse(Objects.requireNonNull(response.body()));
76 | } catch (IOException e) {
77 | log.error("sendFileMultiPart exception, ", e);
78 | throw new RuntimeException(e);
79 | }
80 | }
81 |
82 | private static void handleResponse(String userId, String data,
83 | SocketServer socketServer, Function bizCb) {
84 | Assert.isTrue(StringUtils.isNotEmpty(userId), "userId为空");
85 | // if stream is true, send socket msg here
86 | // res format: data: {xxx}
87 | if(StringUtils.isNotEmpty(data)) {
88 | if(data.equals(FINISH_FLAG)) {
89 | socketServer.sendMessage(userId, data);
90 | String fullMsg = msgMap.get(userId).toString();
91 | bizCb.apply(fullMsg);
92 | // 清空缓存
93 | msgMap.remove(userId);
94 | return;
95 | }
96 | ChatResponse chatResponse = JSONObject.parseObject(data, ChatResponse.class);
97 | List choices = chatResponse.getChoices();
98 | StringBuilder partialMsg = msgMap.getOrDefault(userId, new StringBuilder());
99 |
100 | for (ChatRet choice : choices) {
101 | List delta = choice.getDelta();
102 | for (ChatStreamRet ret : delta) {
103 | if(ret == null) {
104 | continue;
105 | }
106 | String content = ret.getContent();
107 | if(StringUtils.isNotEmpty(content)) {
108 | socketServer.sendMessage(userId, content);
109 | partialMsg.append(ret.getContent());
110 | }
111 | }
112 | }
113 | msgMap.put(userId, partialMsg);
114 | }
115 | }
116 |
117 | private static String handleException(Throwable t) {
118 | log.error("sendStream exception:", t);
119 | return null;
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/src/main/resources/META-INF/spring.factories:
--------------------------------------------------------------------------------
1 | org.springframework.boot.autoconfigure.EnableAutoConfiguration=\
2 | me.zhangjh.chatgpt.config.ChatGptConfig
--------------------------------------------------------------------------------
/src/main/resources/application.properties:
--------------------------------------------------------------------------------
1 | openai.apikey=xxxxxxxxxxx
--------------------------------------------------------------------------------
/src/test/java/me/zhangjh/chatgpt/Application.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt;
2 |
3 | import org.springframework.boot.SpringApplication;
4 | import org.springframework.boot.autoconfigure.SpringBootApplication;
5 |
6 | /**
7 | * @author zhangjh
8 | * @date 4:17 PM 2022/12/15
9 | * @Description
10 | */
11 | @SpringBootApplication
12 | public class Application {
13 | public static void main(String[] args) {
14 | SpringApplication.run(Application.class, args);
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/src/test/java/me/zhangjh/chatgpt/ChatGptTest.java:
--------------------------------------------------------------------------------
1 | package me.zhangjh.chatgpt;
2 |
3 | import me.zhangjh.chatgpt.client.ChatGptService;
4 | import me.zhangjh.chatgpt.constant.RoleEnum;
5 | import me.zhangjh.chatgpt.dto.Message;
6 | import me.zhangjh.chatgpt.dto.request.ChatRequest;
7 | import me.zhangjh.chatgpt.dto.request.ImageRequest;
8 | import me.zhangjh.chatgpt.dto.request.TextRequest;
9 | import me.zhangjh.chatgpt.dto.request.TranscriptionRequest;
10 | import me.zhangjh.chatgpt.dto.response.ChatResponse;
11 | import me.zhangjh.chatgpt.dto.response.ImageResponse;
12 | import me.zhangjh.chatgpt.dto.response.TextResponse;
13 | import me.zhangjh.chatgpt.dto.response.TranscriptionResponse;
14 | import org.junit.Test;
15 | import org.junit.runner.RunWith;
16 | import org.springframework.beans.factory.annotation.Autowired;
17 | import org.springframework.boot.test.context.SpringBootTest;
18 | import org.springframework.test.context.junit4.SpringRunner;
19 |
20 | import java.util.ArrayList;
21 | import java.util.HashMap;
22 | import java.util.List;
23 | import java.util.Map;
24 |
25 | /**
26 | * @author zhangjh
27 | * @date 4:14 PM 2022/12/15
28 | * @Description
29 | */
30 | @RunWith(SpringRunner.class)
31 | @SpringBootTest
32 | public class ChatGptTest {
33 |
34 | @Autowired
35 | private ChatGptService chatGptService;
36 |
37 | @Test
38 | public void textCompletionTest() {
39 | TextRequest textRequest = new TextRequest();
40 | textRequest.setPrompt("Q:写出java hello world?");
41 | textRequest.setTemperature(0.5);
42 | textRequest.setMaxTokens(2048);
43 | textRequest.setBestOf(1);
44 | textRequest.setTopP(1);
45 | // textRequest.setPrompt("Q:将括号里的词汇翻译一下,如果是中文翻译成英文,如果是英文翻译成中文.(一只小狐狸正在吃葡萄)A:");
46 | // TextResponse textCompletion = chatGptService.createTextCompletion(textRequest);
47 | TextResponse textCompletion = chatGptService.createTextCompletion(textRequest, null);
48 |
49 | System.out.println(textCompletion);
50 |
51 | }
52 |
53 | @Test
54 | public void imageGenerateTest() {
55 | ImageRequest imageRequest = new ImageRequest();
56 | imageRequest.setPrompt("一只小狐狸正在吃葡萄");
57 | ImageResponse imageGeneration = chatGptService.createImageGeneration(imageRequest, null);
58 | System.out.println(imageGeneration);
59 | }
60 |
61 | @Test
62 | public void chatTest() {
63 | ChatRequest chatRequest = new ChatRequest();
64 | List messages = new ArrayList<>();
65 | Message message = new Message();
66 | message.setRole(RoleEnum.user.name());
67 | message.setContent("什么是斐波那契数列?");
68 | messages.add(message);
69 | Message answer = new Message();
70 | answer.setRole(RoleEnum.assistant.name());
71 | answer.setContent("斐波那契数列是指:1,1,2,3,5,8,13,21,34……这样一个由数列中前两项相加得出第三项的数列。这个数列与自然界中很多事物的增长规律有关,比如植物的叶片数量、兔子的繁殖规律等等。斐波那契数列的本质是递归定义,是算法和数学中的重要概念之一,也是计算机程序设计中常用的算法之一。");
72 | messages.add(answer);
73 | Message curMessage = new Message();
74 | curMessage.setRole(RoleEnum.user.name());
75 | curMessage.setContent("可以给出代码示例吗,java的?");
76 | messages.add(curMessage);
77 | chatRequest.setMessages(messages);
78 | Map bizParams = new HashMap<>(1);
79 | bizParams.put("userId", "1");
80 | ChatResponse chatCompletion = chatGptService.createChatCompletion(chatRequest, bizParams);
81 | System.out.println(chatCompletion);
82 | }
83 |
84 | @Test
85 | public void audioTest() {
86 | TranscriptionRequest request = new TranscriptionRequest();
87 | request.setFile("/root/test.m4a");
88 | TranscriptionResponse transcription = chatGptService.createTranscription(request, null);
89 | System.out.println(transcription);
90 | }
91 | }
92 |
--------------------------------------------------------------------------------