├── .gitignore ├── Readme.md ├── pom.xml └── src ├── main ├── java │ └── me │ │ └── zhangjh │ │ └── chatgpt │ │ ├── client │ │ └── ChatGptService.java │ │ ├── config │ │ ├── ChatGptConfig.java │ │ └── HttpSessionWSHelper.java │ │ ├── constant │ │ ├── ModelEnum.java │ │ └── RoleEnum.java │ │ ├── dto │ │ ├── Message.java │ │ ├── request │ │ │ ├── ChatBaseRequest.java │ │ │ ├── ChatRequest.java │ │ │ ├── ImageRequest.java │ │ │ ├── TextRequest.java │ │ │ └── TranscriptionRequest.java │ │ └── response │ │ │ ├── BizException.java │ │ │ ├── ChatResponse.java │ │ │ ├── ChatRet.java │ │ │ ├── ChatStreamRet.java │ │ │ ├── CompletionUsage.java │ │ │ ├── ImageResponse.java │ │ │ ├── ImageRet.java │ │ │ ├── TextResponse.java │ │ │ ├── TextRet.java │ │ │ └── TranscriptionResponse.java │ │ ├── service │ │ └── ChatGptServiceImpl.java │ │ ├── socket │ │ └── SocketServer.java │ │ └── util │ │ └── BizHttpClientUtil.java └── resources │ ├── META-INF │ └── spring.factories │ └── application.properties └── test └── java └── me └── zhangjh └── chatgpt ├── Application.java └── ChatGptTest.java /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | /target/ 3 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | #### ChatGpt JAVA API Starter 2 | ##### 简介 3 | 这是一个基于Java开发的ChatGpt API库,非常易于接入使用。 4 | 你只需生成一个自己的openAI apiKey,依赖本三方库,即可便利地使用ChatGpt。 5 | 6 | 当前接口功能主要有文本补全、图片生成、Chat三种,Chat支持到gpt-3.5-turbo模型,也会跟随官方更新进行升级。 7 | 8 | #### 2.x版本已重构过,建议使用2.x版本,但是注意2.x版本和1.x版本不兼容,从1.x版本升级需要注意修改已接入代码 9 | 10 | ##### 如何使用? 11 | 0. 到[这里](https://beta.openai.com/docs/quickstart/build-your-application)生成一个自己的API KEY 12 | 1. 工程中加入依赖: 13 | ```xml 14 | <dependency> 15 | <groupId>me.zhangjh</groupId> 16 | <artifactId>chatgpt-starter</artifactId> 17 | <version>${最新版本}</version> 18 | </dependency> 19 | ``` 20 | 最新版本查询:https://mvnrepository.com/artifact/me.zhangjh/chatgpt-starter 21 | 22 | 2. 
将生成的apiKey加入配置文件application.properties 23 | ```properties 24 | openai.apikey=xxxxxxxxxxxxxxxxxxx 25 | ``` 26 | 或者将上述配置添加进环境变量 27 | 3. 代码中注入service 28 | 29 | ```java 30 | import org.springframework.beans.factory.annotation.Autowired; 31 | 32 | @Autowired 33 | private ChatGptService chatGptService; 34 | // 调用方法即可,其他方法不赘述 35 | TextResponse createTextCompletion(TextRequest data); 36 | ImageResponse createImageGeneration(ImageRequest imageRequest); 37 | 38 | ``` 39 | 40 | 41 | ## 我使用这个starter制作了一个微信小程序:AI文图,欢迎交流~ 42 | ![little-program](https://user-images.githubusercontent.com/3371714/219958080-f537f271-3d1b-41e1-86cf-1036d04ab6ba.jpeg) 43 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | me.zhangjh 8 | chatgpt-starter 9 | jar 10 | 2.0.6 11 | 12 | chatgpt-starter 13 | A springboot starter for chatgpt api. 14 | git@github.com:zhangjh/chatgpt-starter.git 15 | 16 | 17 | scm:git:git@github.com:zhangjh/chatgpt-starter.git 18 | scm:git:ssh://git@github.com:zhangjh/chatgpt-starter.git 19 | git@github.com:zhangjh/chatgpt-starter.git 20 | 21 | 22 | 23 | 24 | zhangjh 25 | njhxzhangjihong@126.com 26 | zhangjh.me 27 | https://zhangjh.me 28 | +8 29 | 30 | 31 | 32 | 33 | GPL 34 | https://www.gnu.org/licenses/gpl.html 35 | 36 | 37 | 38 | 39 | ossrh 40 | https://s01.oss.sonatype.org/content/repositories/snapshots 41 | 42 | 43 | ossrh 44 | https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ 45 | 46 | 47 | 48 | 49 | 11 50 | 11 51 | 52 | 53 | 54 | 55 | org.projectlombok 56 | lombok 57 | 1.18.26 58 | 59 | 60 | com.alibaba 61 | fastjson 62 | 2.0.25 63 | 64 | 65 | org.springframework 66 | spring-context 67 | 5.3.25 68 | 69 | 70 | javax.annotation 71 | javax.annotation-api 72 | 1.3.2 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | junit 81 | junit 82 | 4.13.2 83 | test 84 | 85 | 86 | org.springframework 87 | spring-test 88 | 5.3.25 89 | 
test 90 | 91 | 92 | org.springframework.boot 93 | spring-boot-test 94 | 2.7.9 95 | test 96 | 97 | 98 | 99 | me.zhangjh 100 | share 101 | 2.0.5-1 102 | 103 | 104 | 105 | org.apache.httpcomponents 106 | httpclient 107 | 4.5.14 108 | 109 | 110 | org.apache.commons 111 | commons-lang3 112 | 3.12.0 113 | 114 | 115 | ch.qos.logback 116 | logback-core 117 | 1.4.5 118 | 119 | 120 | org.slf4j 121 | log4j-over-slf4j 122 | 2.0.6 123 | 124 | 125 | org.springframework 126 | spring-webmvc 127 | 5.3.25 128 | 129 | 130 | org.apache.tomcat.embed 131 | tomcat-embed-websocket 132 | 9.0.71 133 | 134 | 135 | 136 | com.squareup.okhttp3 137 | okhttp 138 | 4.10.0 139 | 140 | 141 | kotlin-stdlib 142 | org.jetbrains.kotlin 143 | 144 | 145 | 146 | 147 | jakarta.validation 148 | jakarta.validation-api 149 | 3.0.2 150 | 151 | 152 | org.apache.commons 153 | commons-collections4 154 | 4.4 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | org.apache.maven.plugins 163 | maven-resources-plugin 164 | 3.2.0 165 | 166 | 167 | org.springframework.boot 168 | spring-boot-maven-plugin 169 | 170 | 171 | org.apache.maven.plugins 172 | maven-compiler-plugin 173 | 174 | utf-8 175 | 11 176 | 11 177 | 178 | 179 | 180 | org.sonatype.plugins 181 | nexus-staging-maven-plugin 182 | 1.6.7 183 | true 184 | 185 | ossrh 186 | https://s01.oss.sonatype.org/ 187 | true 188 | 189 | 190 | 191 | org.apache.maven.plugins 192 | maven-source-plugin 193 | 2.2.1 194 | 195 | 196 | attach-sources 197 | 198 | jar-no-fork 199 | 200 | 201 | 202 | 203 | 204 | org.apache.maven.plugins 205 | maven-gpg-plugin 206 | 1.5 207 | 208 | 209 | sign-artifacts 210 | verify 211 | 212 | sign 213 | 214 | 215 | 0x7F777CAF 216 | 0x7F777CAF 217 | 218 | 219 | 220 | 221 | 222 | org.apache.maven.plugins 223 | maven-javadoc-plugin 224 | 2.9.1 225 | 226 | -Xdoclint:none 227 | 228 | 229 | 230 | attach-javadocs 231 | 232 | jar 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- 
/src/main/java/me/zhangjh/chatgpt/client/ChatGptService.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt.client; 2 | 3 | import me.zhangjh.chatgpt.dto.request.ChatRequest; 4 | import me.zhangjh.chatgpt.dto.request.ImageRequest; 5 | import me.zhangjh.chatgpt.dto.request.TextRequest; 6 | import me.zhangjh.chatgpt.dto.request.TranscriptionRequest; 7 | import me.zhangjh.chatgpt.dto.response.ChatResponse; 8 | import me.zhangjh.chatgpt.dto.response.ImageResponse; 9 | import me.zhangjh.chatgpt.dto.response.TextResponse; 10 | import me.zhangjh.chatgpt.dto.response.TranscriptionResponse; 11 | import me.zhangjh.chatgpt.socket.SocketServer; 12 | 13 | import java.util.Map; 14 | import java.util.function.Function; 15 | 16 | /** 17 | * @author zhangjh 18 | * @date 1:41 PM 2022/12/15 19 | * @Description Entry point of the starter: one method per supported call (text completion, image generation, chat completion, streamed chat completion, audio transcription). NOTE(review): the generic type arguments of Map and Function appear to have been stripped from this listing - confirm against the repository. 20 | */ 21 | public interface ChatGptService { 22 | 23 | /** 24 | * Text completion. 25 | * @param textRequest the completion request * @param bizParams caller-supplied business parameters (the bundled ChatGptServiceImpl does not read them for this call) 26 | * @return the parsed text completion response 27 | */ 28 | TextResponse createTextCompletion(TextRequest textRequest, Map bizParams); 29 | 30 | /** 31 | * Image generation. 32 | * @param imageRequest the image generation request * @param bizParams caller-supplied business parameters (the bundled ChatGptServiceImpl does not read them for this call) 33 | * @return the parsed image generation response 34 | */ 35 | ImageResponse createImageGeneration(ImageRequest imageRequest, Map bizParams); 36 | 37 | /** Chat completion (non-streaming). * @param request the chat request * @param bizParams caller-supplied business parameters (the bundled ChatGptServiceImpl does not read them for this call) * @return the parsed chat completion response */ 38 | ChatResponse createChatCompletion(ChatRequest request, Map bizParams); 39 | 40 | /** 41 | * For weixin only: relays the streamed chat completion to a WebSocket instead of an SseEmitter. The bundled ChatGptServiceImpl forces request.stream to true, requires a non-empty "userId" entry in bizParams and sends every bizParams entry as an HTTP header. * @param socketServer the WebSocket endpoint the stream is handed to * @param bizCb caller callback handed to the streaming HTTP helper; its exact contract is not visible in this interface 42 | * */ 43 | void createChatCompletionStream(ChatRequest request, Map bizParams, SocketServer socketServer, 44 | Function bizCb); 45 | /** Audio transcription. * @param request the transcription request * @param bizParams caller-supplied business parameters (the bundled ChatGptServiceImpl does not read them for this call) * @return the parsed transcription response */ 46 | TranscriptionResponse createTranscription(TranscriptionRequest request, Map bizParams); 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/me/zhangjh/chatgpt/config/ChatGptConfig.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt.config; 2 | 3 | import me.zhangjh.chatgpt.client.ChatGptService; 4 |
import me.zhangjh.chatgpt.service.ChatGptServiceImpl;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Spring configuration that contributes the default {@link ChatGptService} bean.
 *
 * @author zhangjh
 * @date 4:09 PM 2022/12/15
 */
@Configuration
public class ChatGptConfig {

    /**
     * Creates a {@link ChatGptServiceImpl}; {@code @ConditionalOnMissingBean} lets an
     * application replace it by declaring its own {@link ChatGptService} bean.
     *
     * @return the default service implementation
     */
    @Bean
    @ConditionalOnMissingBean
    public ChatGptService chatGptService() {
        return new ChatGptServiceImpl();
    }
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/config/HttpSessionWSHelper.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.config;

import javax.servlet.http.HttpSession;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;

/**
 * WebSocket endpoint configurator that exposes the HTTP session of the handshake
 * request to the endpoint.
 *
 * @author njhxzhangjihong@126.com
 * @date 3:33 PM 2023/3/13
 */
public class HttpSessionWSHelper extends ServerEndpointConfig.Configurator {

    /**
     * Stores the handshake's {@link HttpSession}, when there is one, in the endpoint's
     * user properties under the key {@code HttpSession.class.getName()}.
     */
    @Override
    public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
        HttpSession session = (HttpSession) request.getHttpSession();
        if (session != null) {
            sec.getUserProperties().put(HttpSession.class.getName(), session);
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/constant/ModelEnum.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.constant;

import lombok.Getter;

/**
 * Model identifiers used as the {@code model} value of requests.
 *
 * @author njhxzhangjihong@126.com
 * @date 9:18 AM 2023/3/2
 */
@Getter
public enum ModelEnum {

    DAVINCI("text-davinci-003"),
    CHATGPT_TURBO("gpt-3.5-turbo"),
    CHATGPT_0301("gpt-3.5-turbo-0301"),
    ;

    /** The model id string; final because enum constants are shared singletons. */
    private final String code;

    ModelEnum(String code) {
        this.code = code;
    }
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/constant/RoleEnum.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.constant;

/**
 * Roles of a chat message author. The constants are lower-case so that
 * {@code name()} yields the literal role string.
 *
 * @author njhxzhangjihong@126.com
 * @date 9:31 AM 2023/3/2
 */
public enum RoleEnum {
    system, user, assistant
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/Message.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto;

import lombok.Data;

/**
 * One chat message: the author's role and the text content.
 *
 * @author njhxzhangjihong@126.com
 * @date 9:53 AM 2023/3/2
 */
@Data
public class Message {
    /** Role of the author; presumably one of the {@code RoleEnum} names - confirm with callers. */
    private String role;
    /** Text of the message. */
    private String content;
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/ChatBaseRequest.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.request;

import com.alibaba.fastjson.annotation.JSONField;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import me.zhangjh.chatgpt.constant.ModelEnum;

/**
 * Parameters shared by the text and chat completion requests. Snake-case JSON names
 * are declared with {@code @JSONField} where they differ from the Java field name.
 *
 * @author njhxzhangjihong@126.com
 * @date 10:07 AM 2023/3/2
 */
@Data
public class ChatBaseRequest {

    /** ID of the model to use; defaults to {@code text-davinci-003}. */
    @NotNull
    private String model = ModelEnum.DAVINCI.getCode();

    /**
     * Controls the creativity, it can be 0~1.
     * It is not recommended to use the temperature together with the top_p parameter.
     */
    private Double temperature;

    /** The maximum number of tokens to generate in the completion. */
    @JSONField(name = "max_tokens")
    private Integer maxTokens;

    /**
     * Serialized as {@code top_p}.
     * NOTE(review): declared as Integer, while the neighbouring temperature is a Double in 0~1;
     * confirm whether a fractional type was intended before changing the accessor signature.
     */
    @JSONField(name = "top_p")
    private Integer topP;

    /**
     * How many completions to generate for each prompt.
     * Defaults to 1.
     */
    private Integer n = 1;

    @JSONField(name = "frequency_penalty")
    private Double frequencyPenalty;

    @JSONField(name = "presence_penalty")
    private Double presencePenalty;

    /**
     * Up to 4 sequences where the API stops generating more text.
     */
    private String stop;

    /**
     * Whether to stream back partial progress. Defaults to false.
     */
    private Boolean stream = false;
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/ChatRequest.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.request;

import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import me.zhangjh.chatgpt.constant.ModelEnum;
import me.zhangjh.chatgpt.dto.Message;

import java.util.List;

/**
 * Chat completion request: the shared parameters plus the conversation messages.
 *
 * @author njhxzhangjihong@126.com
 * @date 9:29 AM 2023/3/2
 */
@Data
public class ChatRequest extends ChatBaseRequest {

    /** Redeclared so that chat requests default to {@code gpt-3.5-turbo} instead of the base default. */
    @NotNull
    private String model = ModelEnum.CHATGPT_TURBO.getCode();

    /**
     * The conversation so far. The element type is restored from this file's otherwise
     * unused {@code Message} import; the listing had lost the type argument.
     */
    @NotEmpty
    private List<Message> messages;
}

--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/ImageRequest.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.request;

import com.alibaba.fastjson.annotation.JSONField;
import jakarta.validation.constraints.NotNull;
import lombok.Data;

import java.util.Arrays;
import java.util.List;

/**
 * Image generation request.
 *
 * @author zhangjh
 * @date 2022/12/15
 */
@Data
public class ImageRequest {

    /** Sizes accepted by {@link #check()}; Arrays.asList tolerates a null lookup. */
    private static final List<String> VALID_SIZES = Arrays.asList("256x256", "512x512", "1024x1024");

    /** a text description of the desired images */
    @NotNull
    private String prompt;

    /** a num >=1 && <= 10, images to generate */
    private int n = 1;

    /** the size of images */
    private String size = "1024x1024";

    /** must be url or b64_json, default url */
    @JSONField(name = "response_format")
    private String responseFormat = "url";

    /**
     * Validates the request before it is sent.
     *
     * @throws IllegalArgumentException (a RuntimeException, as before) when {@code n} is outside
     *         1..10 or {@code size} is null or not one of the supported sizes
     */
    public void check() {
        if (this.n < 1 || this.n > 10) {
            throw new IllegalArgumentException("invalid image count:" + this.n);
        }
        // contains() instead of size.equals(item): a null size is reported as invalid, not as an NPE
        if (!VALID_SIZES.contains(this.size)) {
            throw new IllegalArgumentException("invalid image size:" + this.size);
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/TextRequest.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.request;

import com.alibaba.fastjson.annotation.JSONField;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import me.zhangjh.chatgpt.constant.ModelEnum;

/**
 * Text completion request.
 *
 * @author zhangjh
 * @date 3:05 PM 2022/12/15
 */
@Data
public class TextRequest extends ChatBaseRequest {

    /**
     * ID of the model to use. Defaults to the completion model {@code text-davinci-003}:
     * ChatGptServiceImpl posts this request to the /v1/completions URL, so the
     * chat model used as the previous default did not match that endpoint.
     */
    @NotNull
    private String model = ModelEnum.DAVINCI.getCode();

    /** the prompt to generate completions for */
    @NotNull
    private String prompt;

    /** the suffix that comes after a completion of inserted text */
    private String suffix;

    /** the maximum num of tokens to generate in the completion */
    @JSONField(name = "max_tokens")
    private Integer maxTokens = 2048;

    @JSONField(name = "best_of")
    private Integer bestOf = 1;

    /**
     * Rejects streaming requests, which this request type does not support.
     *
     * @throws RuntimeException when {@code stream} is true
     */
    public void check() {
        // Boolean.TRUE.equals avoids the NPE the auto-unboxing caused when stream was set to null
        if (Boolean.TRUE.equals(this.getStream())) {
            throw new RuntimeException("do not support stream");
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/request/TranscriptionRequest.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.request;

import jakarta.validation.constraints.NotNull;
import lombok.Data;

/**
 * Audio transcription request.
 *
 * @author njhxzhangjihong@126.com
 * @date 6:01 PM 2023/3/14
 */
@Data
public class TranscriptionRequest {

    /** The audio file; presumably a local path read by the multipart sender - confirm in BizHttpClientUtil. */
    private String file;

    /** ID of the model to use; defaults to {@code whisper-1}. */
    @NotNull
    private String model = "whisper-1";

    /**
     * An optional text to guide the model's style or continue a previous audio segment.
     * The prompt should match the audio language.
     * NOTE(review): described as optional yet annotated {@code @NotNull} - confirm which is intended.
     */
    @NotNull
    private String prompt;

    private Integer temperature = 0;

    /**
     * The language of the input audio. Supplying the input language in ISO-639-1 format
     * will improve accuracy and latency.
     * NOTE(review): the default is the literal name of the standard, not a language code such
     * as "en"; left unchanged because the code that consumes it is not visible here.
     */
    private String language = "ISO-639-1";
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/BizException.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.response;

/**
 * Checked exception of this library, built from a message or from a cause.
 *
 * @author njhxzhangjihong@126.com
 * @date 5:58 PM 2023/2/15
 */
public class BizException extends Exception {

    /** @param error the detail message */
    public BizException(String error) {
        super(error);
    }

    /** @param cause the underlying failure, preserved as the cause */
    public BizException(Throwable cause) {
        super(cause);
    }
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ChatResponse.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.response;

import lombok.Data;

import java.util.Date;
import java.util.List;

/**
 * Chat completion response; ChatGptServiceImpl parses the HTTP response body into it.
 *
 * @author njhxzhangjihong@126.com
 * @date 9:50 AM 2023/3/2
 */
@Data
public class ChatResponse {

    private String id;

    private String object;

    /** NOTE(review): verify how the service's numeric timestamp maps onto Date (seconds vs. milliseconds). */
    private Date created;

    private String model;

    private String errorMsg;

    /** NOTE(review): element type was lost in this listing; presumably ChatRet - confirm. */
    private List choices;

    private CompletionUsage usage;

}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ChatRet.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.response;

import lombok.Data;
import me.zhangjh.chatgpt.dto.Message;

import java.util.List;

/**
 * One chat choice: the {@link TextRet} fields plus the chat message.
 *
 * @author njhxzhangjihong@126.com
 * @date 9:51 AM 2023/3/2
 */
@Data
public class ChatRet extends TextRet {

    private Message message;

    /** NOTE(review): element type was lost in this listing; confirm it and the shape of the streamed delta. */
    private List delta;
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ChatStreamRet.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.response;

import lombok.Data;

/**
 * A {@link TextRet} carrying an additional {@code content} string; by its name it is
 * presumably one streamed chat fragment - confirm in BizHttpClientUtil.
 *
 * @author njhxzhangjihong@126.com
 * @date 5:34 PM 2023/3/13
 */
@Data
public class ChatStreamRet extends TextRet {

    private String content;
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/CompletionUsage.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.response;

import com.alibaba.fastjson.annotation.JSONField;
import lombok.Data;

/**
 * Token accounting of a completion, mapped from the snake-case JSON names.
 *
 * @author zhangjh
 * @date 3:03 PM 2022/12/15
 */
@Data
public class CompletionUsage {
    @JSONField(name = "prompt_tokens")
    private int promptTokens;

    @JSONField(name = "completion_tokens")
    private int completionTokens;

    @JSONField(name = "total_tokens")
    private int totalTokens;
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ImageResponse.java:
--------------------------------------------------------------------------------
package me.zhangjh.chatgpt.dto.response;

import lombok.Data;

import java.util.Date;
import java.util.List;

/**
 * Image generation response; ChatGptServiceImpl parses the HTTP response body into it.
 *
 * @author zhangjh
 * @date 3:13 PM 2022/12/15
 */
@Data
public class ImageResponse {

    /**
     * NOTE(review): the sibling responses name this field {@code created}; confirm that a
     * field named {@code create} is actually populated from the service's JSON.
     */
    private Date create;

    /** NOTE(review): element type was lost in this listing; presumably ImageRet - confirm. */
    private List data;

    private String errorMsg;
}
--------------------------------------------------------------------------------
/src/main/java/me/zhangjh/chatgpt/dto/response/ImageRet.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt.dto.response; 2 | 3 | import lombok.Data; 4 | 5 | /** 6 | * @author zhangjh 7 | * @date 3:13 PM 2022/12/15 8 | * @Description One image result holding its url; presumably the element type of ImageResponse.data - confirm. 9 | */ 10 | @Data 11 | public class ImageRet { 12 | private String url; 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/me/zhangjh/chatgpt/dto/response/TextResponse.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt.dto.response; 2 | 3 | import lombok.Data; 4 | 5 | import java.util.Date; 6 | import java.util.List; 7 | 8 | /** 9 | * @author zhangjh 10 | * @date 3:02 PM 2022/12/15 11 | * @Description Text completion response; ChatGptServiceImpl parses the HTTP response body into it. 12 | */ 13 | @Data 14 | public class TextResponse { 15 | private String id; 16 | 17 | private String object; 18 | /** NOTE(review): verify how the service's numeric timestamp maps onto Date (seconds vs. milliseconds). */ 19 | private Date created; 20 | 21 | private String model; 22 | 23 | private String errorMsg; 24 | /** NOTE(review): element type was lost in this listing; presumably TextRet - confirm. */ 25 | private List choices; 26 | 27 | private CompletionUsage usage; 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/me/zhangjh/chatgpt/dto/response/TextRet.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt.dto.response; 2 | 3 | import lombok.Data; 4 | 5 | /** 6 | * @author zhangjh 7 | * @date 3:02 PM 2022/12/15 8 | * @Description One completion choice (text, index, finish reason); also the base class of ChatRet and ChatStreamRet. 9 | */ 10 | @Data 11 | public class TextRet { 12 | private String text; 13 | 14 | private int index; 15 | /** NOTE(review): no explicit JSON name is declared here, unlike CompletionUsage; confirm the parser maps a snake-case key onto this field. */ 16 | private String finishReason; 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/me/zhangjh/chatgpt/dto/response/TranscriptionResponse.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt.dto.response; 2 | 3 | import lombok.Data; 4 | 5 | /** 6 | * @author
njhxzhangjihong@126.com 7 | * @date 6:09 PM 2023/3/14 8 | * @Description 9 | */ 10 | @Data 11 | public class TranscriptionResponse { 12 | 13 | private String text; 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/me/zhangjh/chatgpt/service/ChatGptServiceImpl.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt.service; 2 | 3 | import com.alibaba.fastjson.JSONObject; 4 | import lombok.extern.slf4j.Slf4j; 5 | import me.zhangjh.chatgpt.client.ChatGptService; 6 | import me.zhangjh.chatgpt.dto.request.ChatRequest; 7 | import me.zhangjh.chatgpt.dto.request.ImageRequest; 8 | import me.zhangjh.chatgpt.dto.request.TextRequest; 9 | import me.zhangjh.chatgpt.dto.request.TranscriptionRequest; 10 | import me.zhangjh.chatgpt.dto.response.ChatResponse; 11 | import me.zhangjh.chatgpt.dto.response.ImageResponse; 12 | import me.zhangjh.chatgpt.dto.response.TextResponse; 13 | import me.zhangjh.chatgpt.dto.response.TranscriptionResponse; 14 | import me.zhangjh.chatgpt.socket.SocketServer; 15 | import me.zhangjh.chatgpt.util.BizHttpClientUtil; 16 | import me.zhangjh.share.util.HttpClientUtil; 17 | import me.zhangjh.share.util.HttpRequest; 18 | import org.apache.commons.collections4.MapUtils; 19 | import org.apache.commons.lang3.StringUtils; 20 | import org.springframework.beans.factory.annotation.Value; 21 | import org.springframework.util.Assert; 22 | 23 | import javax.annotation.PostConstruct; 24 | import java.util.HashMap; 25 | import java.util.Map; 26 | import java.util.function.Function; 27 | 28 | /** 29 | * @author zhangjh 30 | * @date 3:21 PM 2022/12/15 31 | * @Description 32 | */ 33 | @Slf4j 34 | public class ChatGptServiceImpl implements ChatGptService { 35 | 36 | @Value("${openai.apikey}") 37 | private String configApiKey; 38 | 39 | private String apiKey; 40 | 41 | private final Map header = new HashMap<>(); 42 | 43 | private static final String 
TEXT_COMPLETION_URL = "https://api.openai.com/v1/completions"; 44 | 45 | private static final String IMAGE_GENERATE_URL = "https://api.openai.com/v1/images/generations"; 46 | 47 | @Value("${openai.chat.url:https://api.openai.com/v1/chat/completions}") 48 | private String chatUrl; 49 | 50 | private static final String TRANSCRIPTION_URL = "https://api.openai.com/v1/audio/transcriptions"; 51 | 52 | @PostConstruct 53 | public void init() { 54 | if(StringUtils.isEmpty(configApiKey)) { 55 | configApiKey = System.getenv("openai.apikey"); 56 | } 57 | if(StringUtils.isEmpty(apiKey)) { 58 | apiKey = configApiKey; 59 | } 60 | Assert.isTrue(StringUtils.isNotEmpty(apiKey), "openai apiKey not exist"); 61 | // openAi 62 | if(apiKey.startsWith("sk-")) { 63 | header.put("Authorization", "Bearer " + apiKey); 64 | } else { 65 | // azure 66 | header.put("api-key", apiKey); 67 | } 68 | } 69 | 70 | /** 71 | * allow set apiKey from outside, but you must define this bean yourself 72 | */ 73 | public void setApiKey(String apiKey) { 74 | this.apiKey = apiKey; 75 | } 76 | 77 | @Override 78 | public TextResponse createTextCompletion(TextRequest textRequest, Map bizParams) { 79 | // this interface must set request.stream to false 80 | textRequest.check(); 81 | 82 | HttpRequest httpRequest = new HttpRequest(TEXT_COMPLETION_URL); 83 | httpRequest.setReqData(JSONObject.toJSONString(textRequest)); 84 | httpRequest.setBizHeaderMap(this.header); 85 | String response = HttpClientUtil.sendNormally(httpRequest).toString(); 86 | return JSONObject.parseObject(response, TextResponse.class); 87 | } 88 | 89 | @Override 90 | public ImageResponse createImageGeneration(ImageRequest imageRequest, Map bizParams) { 91 | imageRequest.check(); 92 | HttpRequest httpRequest = new HttpRequest(IMAGE_GENERATE_URL); 93 | httpRequest.setReqData(JSONObject.toJSONString(imageRequest)); 94 | httpRequest.setBizHeaderMap(this.header); 95 | String response = HttpClientUtil.sendNormally(httpRequest).toString(); 96 | return 
JSONObject.parseObject(response, ImageResponse.class); 97 | } 98 | 99 | @Override 100 | public ChatResponse createChatCompletion(ChatRequest request, Map bizParams) { 101 | HttpRequest httpRequest = new HttpRequest(chatUrl); 102 | httpRequest.setReqData(JSONObject.toJSONString(request)); 103 | httpRequest.setBizHeaderMap(this.header); 104 | String response = HttpClientUtil.sendNormally(httpRequest).toString(); 105 | return JSONObject.parseObject(response, ChatResponse.class); 106 | } 107 | 108 | @Override 109 | public void createChatCompletionStream(ChatRequest request, Map bizParams, 110 | SocketServer socketServer, Function bizCb) { 112 | request.setStream(true); 113 | Assert.isTrue(MapUtils.isNotEmpty(bizParams) && StringUtils.isNotEmpty(bizParams.get("userId")), 114 | "userId为空"); 115 | this.header.putAll(bizParams); 116 | HttpRequest httpRequest = new HttpRequest(chatUrl); 117 | httpRequest.setReqData(JSONObject.toJSONString(request)); 118 | httpRequest.setBizHeaderMap(this.header); 119 | 120 | BizHttpClientUtil.sendStream(httpRequest, socketServer, bizCb); 121 | } 122 | 123 | @Override 124 | public TranscriptionResponse createTranscription(TranscriptionRequest request, Map bizParams) { 125 | HttpRequest httpRequest = new HttpRequest(TRANSCRIPTION_URL); 126 | httpRequest.setReqData(JSONObject.toJSONString(request)); 127 | httpRequest.setBizHeaderMap(this.header); 128 | String response = BizHttpClientUtil.sendFileMultiPart(httpRequest).toString(); 129 | return JSONObject.parseObject(response, TranscriptionResponse.class); 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /src/main/java/me/zhangjh/chatgpt/socket/SocketServer.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt.socket; 2 | 3 | import com.alibaba.fastjson.JSONObject; 4 | import lombok.Data; 5 | import lombok.extern.slf4j.Slf4j; 6 | import org.apache.commons.lang3.StringUtils; 7 | 
import org.springframework.util.Assert; 8 | 9 | import javax.annotation.PostConstruct; 10 | import javax.websocket.*; 11 | import javax.websocket.server.PathParam; 12 | import javax.websocket.server.ServerEndpoint; 13 | import java.io.EOFException; 14 | import java.io.IOException; 15 | import java.util.concurrent.ConcurrentHashMap; 16 | import java.util.concurrent.ConcurrentMap; 17 | 18 | /** 19 | * @author njhxzhangjihong@126.com 20 | * @date 3:03 PM 2023/3/13 21 | * @Description configure me as spring bean,you can override me to custom your sendMessage method 22 | */ 23 | @Data 24 | @Slf4j 25 | @ServerEndpoint("/socket/chatStream/{userId}") 26 | public class SocketServer { 27 | 28 | private String userId; 29 | private Session session; 30 | 31 | protected static ConcurrentMap userMap = new ConcurrentHashMap<>(); 32 | 33 | private static final String FINISH_FLAG = "[DONE]"; 34 | 35 | @PostConstruct 36 | public void init() { 37 | log.info("chatSocketServer inited"); 38 | } 39 | 40 | @OnOpen 41 | public void onOpen(Session session, @PathParam("userId") String userId) { 42 | this.session = session; 43 | this.userId = userId; 44 | log.info("onOpen...userId: {}", userId); 45 | userMap.putIfAbsent(userId, this); 46 | } 47 | 48 | @OnClose 49 | public void onClose() { 50 | log.info("onClose.."); 51 | // should notice clients to update connect status 52 | // when multiple clients connected 53 | SocketServer server = userMap.get(this.userId); 54 | if(server != null) { 55 | try { 56 | server.getSession().close(); 57 | } catch (IOException e) { 58 | log.error("socketServer close exception:", e); 59 | } 60 | } 61 | userMap.remove(this.userId); 62 | } 63 | 64 | @OnMessage 65 | public void onMessage(String msg, Session session) { 66 | // in this project, client will not send message to server 67 | } 68 | 69 | /** you must override this method to implements you biz log 70 | * message removed the beginning 'data:' 71 | * */ 72 | public void sendMessage(String userId, String 
message) { 73 | log.info("userId: {}, message: {}", userId, message); 74 | if(StringUtils.isNotEmpty(message)) { 75 | Assert.isTrue(userMap.size() != 0, "socket not connected"); 76 | Assert.isTrue(StringUtils.isNotEmpty(userId), "userId empty"); 77 | 78 | log.info("message: {}", message); 79 | sendMsgInternal(userId, message); 80 | } 81 | } 82 | 83 | private void sendMsgInternal(String userId, String message) { 84 | try { 85 | SocketServer server = userMap.get(userId); 86 | Assert.isTrue(server != null, "socketServer empty"); 87 | Session session = server.getSession(); 88 | Assert.isTrue(session.isOpen(), "socket not connected"); 89 | TextMessage textMessage = new TextMessage(); 90 | textMessage.setUserId(userId); 91 | textMessage.setMessage(message); 92 | log.info("textMessage: {}", textMessage); 93 | synchronized (session.getId()) { 94 | session.getBasicRemote().sendText(JSONObject.toJSONString(textMessage)); 95 | } 96 | } catch (IOException e) { 97 | log.error("socket sendText exception:", e); 98 | } 99 | } 100 | 101 | @OnError 102 | public void onError(Session session, Throwable t) { 103 | if(t instanceof EOFException) { 104 | log.info("socket timeout"); 105 | this.onClose(); 106 | } 107 | } 108 | 109 | @Data 110 | static class TextMessage { 111 | private String userId; 112 | private String message; 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/me/zhangjh/chatgpt/util/BizHttpClientUtil.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt.util; 2 | 3 | import com.alibaba.fastjson.JSONObject; 4 | import lombok.extern.slf4j.Slf4j; 5 | import me.zhangjh.chatgpt.dto.request.TranscriptionRequest; 6 | import me.zhangjh.chatgpt.dto.response.ChatResponse; 7 | import me.zhangjh.chatgpt.dto.response.ChatRet; 8 | import me.zhangjh.chatgpt.dto.response.ChatStreamRet; 9 | import me.zhangjh.chatgpt.socket.SocketServer; 10 | import 
me.zhangjh.share.util.HttpClientUtil; 11 | import me.zhangjh.share.util.HttpRequest; 12 | import okhttp3.*; 13 | import okhttp3.sse.EventSource; 14 | import okhttp3.sse.EventSourceListener; 15 | import okhttp3.sse.EventSources; 16 | import org.apache.commons.lang3.StringUtils; 17 | import org.jetbrains.annotations.NotNull; 18 | import org.jetbrains.annotations.Nullable; 19 | import org.springframework.util.Assert; 20 | 21 | import java.io.File; 22 | import java.io.IOException; 23 | import java.util.HashMap; 24 | import java.util.List; 25 | import java.util.Objects; 26 | import java.util.function.Function; 27 | 28 | /** 29 | * @author njhxzhangjihong@126.com 30 | * @date 1:05 AM 2023/3/17 31 | * @Description 32 | */ 33 | @Slf4j 34 | public class BizHttpClientUtil extends HttpClientUtil { 35 | 36 | private static HashMap msgMap = new HashMap<>(); 37 | 38 | private static final String FINISH_FLAG = "[DONE]"; 39 | 40 | public static void sendStream(HttpRequest httpRequest, SocketServer socketServer, Function bizCb) { 41 | EventSource.Factory factory = EventSources.createFactory(OK_HTTP_CLIENT); 42 | EventSourceListener eventSourceListener = new EventSourceListener() { 43 | @Override 44 | public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) { 45 | log.info("id: {}, type: {}, data: {}", id, type, data); 46 | // 异步结果回来后要从eventSource里取userId,因为不确定返回事件里的结果是哪次请求的 47 | String userId = eventSource.request().header("userId"); 48 | handleResponse(userId, data, socketServer, bizCb); 49 | } 50 | 51 | @Override 52 | public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) { 53 | handleException(t); 54 | } 55 | }; 56 | factory.newEventSource(buildRequest(httpRequest), eventSourceListener); 57 | } 58 | 59 | public static Object sendFileMultiPart(HttpRequest request) { 60 | TranscriptionRequest transcriptionRequest = JSONObject.parseObject(request.getReqData(), 
TranscriptionRequest.class); 61 | 62 | MediaType mediaType = MediaType.parse("multipart/form-data"); 63 | RequestBody body = new MultipartBody.Builder().setType(MultipartBody.FORM) 64 | .addFormDataPart("file", transcriptionRequest.getFile(), 65 | RequestBody.create(new File(transcriptionRequest.getFile()), mediaType)) 66 | .addFormDataPart("model", transcriptionRequest.getModel()) 67 | .build(); 68 | Request httpRequest = new Request.Builder() 69 | .url(request.getUrl()) 70 | .method("POST", body) 71 | .addHeader("Authorization", request.getBizHeaderMap().get("Authorization")) 72 | .addHeader("Content-Type", "multipart/form-data") 73 | .build(); 74 | try (Response response = OK_HTTP_CLIENT.newCall(httpRequest).execute()){ 75 | return handleResponse(Objects.requireNonNull(response.body())); 76 | } catch (IOException e) { 77 | log.error("sendFileMultiPart exception, ", e); 78 | throw new RuntimeException(e); 79 | } 80 | } 81 | 82 | private static void handleResponse(String userId, String data, 83 | SocketServer socketServer, Function bizCb) { 84 | Assert.isTrue(StringUtils.isNotEmpty(userId), "userId为空"); 85 | // if stream is true, send socket msg here 86 | // res format: data: {xxx} 87 | if(StringUtils.isNotEmpty(data)) { 88 | if(data.equals(FINISH_FLAG)) { 89 | socketServer.sendMessage(userId, data); 90 | String fullMsg = msgMap.get(userId).toString(); 91 | bizCb.apply(fullMsg); 92 | // 清空缓存 93 | msgMap.remove(userId); 94 | return; 95 | } 96 | ChatResponse chatResponse = JSONObject.parseObject(data, ChatResponse.class); 97 | List choices = chatResponse.getChoices(); 98 | StringBuilder partialMsg = msgMap.getOrDefault(userId, new StringBuilder()); 99 | 100 | for (ChatRet choice : choices) { 101 | List delta = choice.getDelta(); 102 | for (ChatStreamRet ret : delta) { 103 | if(ret == null) { 104 | continue; 105 | } 106 | String content = ret.getContent(); 107 | if(StringUtils.isNotEmpty(content)) { 108 | socketServer.sendMessage(userId, content); 109 | 
partialMsg.append(ret.getContent()); 110 | } 111 | } 112 | } 113 | msgMap.put(userId, partialMsg); 114 | } 115 | } 116 | 117 | private static String handleException(Throwable t) { 118 | log.error("sendStream exception:", t); 119 | return null; 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/main/resources/META-INF/spring.factories: -------------------------------------------------------------------------------- 1 | org.springframework.boot.autoconfigure.EnableAutoConfiguration=\ 2 | me.zhangjh.chatgpt.config.ChatGptConfig -------------------------------------------------------------------------------- /src/main/resources/application.properties: -------------------------------------------------------------------------------- 1 | openai.apikey=xxxxxxxxxxx -------------------------------------------------------------------------------- /src/test/java/me/zhangjh/chatgpt/Application.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt; 2 | 3 | import org.springframework.boot.SpringApplication; 4 | import org.springframework.boot.autoconfigure.SpringBootApplication; 5 | 6 | /** 7 | * @author zhangjh 8 | * @date 4:17 PM 2022/12/15 9 | * @Description 10 | */ 11 | @SpringBootApplication 12 | public class Application { 13 | public static void main(String[] args) { 14 | SpringApplication.run(Application.class, args); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/test/java/me/zhangjh/chatgpt/ChatGptTest.java: -------------------------------------------------------------------------------- 1 | package me.zhangjh.chatgpt; 2 | 3 | import me.zhangjh.chatgpt.client.ChatGptService; 4 | import me.zhangjh.chatgpt.constant.RoleEnum; 5 | import me.zhangjh.chatgpt.dto.Message; 6 | import me.zhangjh.chatgpt.dto.request.ChatRequest; 7 | import me.zhangjh.chatgpt.dto.request.ImageRequest; 8 | import 
me.zhangjh.chatgpt.dto.request.TextRequest; 9 | import me.zhangjh.chatgpt.dto.request.TranscriptionRequest; 10 | import me.zhangjh.chatgpt.dto.response.ChatResponse; 11 | import me.zhangjh.chatgpt.dto.response.ImageResponse; 12 | import me.zhangjh.chatgpt.dto.response.TextResponse; 13 | import me.zhangjh.chatgpt.dto.response.TranscriptionResponse; 14 | import org.junit.Test; 15 | import org.junit.runner.RunWith; 16 | import org.springframework.beans.factory.annotation.Autowired; 17 | import org.springframework.boot.test.context.SpringBootTest; 18 | import org.springframework.test.context.junit4.SpringRunner; 19 | 20 | import java.util.ArrayList; 21 | import java.util.HashMap; 22 | import java.util.List; 23 | import java.util.Map; 24 | 25 | /** 26 | * @author zhangjh 27 | * @date 4:14 PM 2022/12/15 28 | * @Description 29 | */ 30 | @RunWith(SpringRunner.class) 31 | @SpringBootTest 32 | public class ChatGptTest { 33 | 34 | @Autowired 35 | private ChatGptService chatGptService; 36 | 37 | @Test 38 | public void textCompletionTest() { 39 | TextRequest textRequest = new TextRequest(); 40 | textRequest.setPrompt("Q:写出java hello world?"); 41 | textRequest.setTemperature(0.5); 42 | textRequest.setMaxTokens(2048); 43 | textRequest.setBestOf(1); 44 | textRequest.setTopP(1); 45 | // textRequest.setPrompt("Q:将括号里的词汇翻译一下,如果是中文翻译成英文,如果是英文翻译成中文.(一只小狐狸正在吃葡萄)A:"); 46 | // TextResponse textCompletion = chatGptService.createTextCompletion(textRequest); 47 | TextResponse textCompletion = chatGptService.createTextCompletion(textRequest, null); 48 | 49 | System.out.println(textCompletion); 50 | 51 | } 52 | 53 | @Test 54 | public void imageGenerateTest() { 55 | ImageRequest imageRequest = new ImageRequest(); 56 | imageRequest.setPrompt("一只小狐狸正在吃葡萄"); 57 | ImageResponse imageGeneration = chatGptService.createImageGeneration(imageRequest, null); 58 | System.out.println(imageGeneration); 59 | } 60 | 61 | @Test 62 | public void chatTest() { 63 | ChatRequest chatRequest = new ChatRequest(); 64 
| List messages = new ArrayList<>(); 65 | Message message = new Message(); 66 | message.setRole(RoleEnum.user.name()); 67 | message.setContent("什么是斐波那契数列?"); 68 | messages.add(message); 69 | Message answer = new Message(); 70 | answer.setRole(RoleEnum.assistant.name()); 71 | answer.setContent("斐波那契数列是指:1,1,2,3,5,8,13,21,34……这样一个由数列中前两项相加得出第三项的数列。这个数列与自然界中很多事物的增长规律有关,比如植物的叶片数量、兔子的繁殖规律等等。斐波那契数列的本质是递归定义,是算法和数学中的重要概念之一,也是计算机程序设计中常用的算法之一。"); 72 | messages.add(answer); 73 | Message curMessage = new Message(); 74 | curMessage.setRole(RoleEnum.user.name()); 75 | curMessage.setContent("可以给出代码示例吗,java的?"); 76 | messages.add(curMessage); 77 | chatRequest.setMessages(messages); 78 | Map bizParams = new HashMap<>(1); 79 | bizParams.put("userId", "1"); 80 | ChatResponse chatCompletion = chatGptService.createChatCompletion(chatRequest, bizParams); 81 | System.out.println(chatCompletion); 82 | } 83 | 84 | @Test 85 | public void audioTest() { 86 | TranscriptionRequest request = new TranscriptionRequest(); 87 | request.setFile("/root/test.m4a"); 88 | TranscriptionResponse transcription = chatGptService.createTranscription(request, null); 89 | System.out.println(transcription); 90 | } 91 | } 92 | --------------------------------------------------------------------------------