├── README.md ├── doc ├── img │ └── star-history-2024124.png └── test │ └── document │ ├── ConversationalRetrievalChainTest-中文.txt │ ├── ConversationalRetrievalChainTest-英文.txt │ ├── EmbeddingMemoryStoreTest.txt │ ├── 中文测试.docx │ ├── 中文测试.pdf │ ├── 中文测试.pptx │ ├── 中文测试.txt │ └── 中文测试.xlsx ├── pom.xml ├── smartFuse-baidu ├── pom.xml └── src │ ├── main │ └── java │ │ └── com │ │ └── ai │ │ └── baidu │ │ ├── client │ │ └── BaiduClient.java │ │ ├── converter │ │ └── BeanConverter.java │ │ ├── model │ │ ├── BaiduChatModel.java │ │ ├── BaiduEmbeddingModel.java │ │ └── BaiduImageModel.java │ │ └── parameter │ │ ├── BaiduChatModelParameter.java │ │ ├── BaiduImageModelParameter.java │ │ └── input │ │ ├── BaiduChatParameter.java │ │ └── BaiduImageParameter.java │ └── test │ └── java │ └── com │ └── ai │ └── baidu │ ├── chain │ ├── ConversationalChainTest.java │ └── ConversationalRetrievalChainTest.java │ ├── model │ └── ModelTest.java │ └── service │ └── ServiceTest.java ├── smartFuse-common ├── pom.xml └── src │ └── main │ └── java │ └── com │ └── ai │ └── common │ ├── resp │ ├── AiResponse.java │ ├── finish │ │ └── FinishReason.java │ └── usage │ │ └── TokenUsage.java │ └── util │ ├── Exceptions.java │ ├── PlaceHolderReplaceUtils.java │ ├── Utils.java │ └── ValidationUtils.java ├── smartFuse-domain ├── pom.xml └── src │ ├── main │ └── java │ │ └── com │ │ └── ai │ │ └── domain │ │ ├── chain │ │ ├── Chain.java │ │ └── impl │ │ │ ├── ConversationalChain.java │ │ │ └── ConversationalRetrievalChain.java │ │ ├── data │ │ ├── embedding │ │ │ ├── CosineSimilarity.java │ │ │ ├── Embedding.java │ │ │ └── EmbeddingMatch.java │ │ ├── images │ │ │ └── Image.java │ │ ├── message │ │ │ ├── AssistantMessage.java │ │ │ ├── ChatMessage.java │ │ │ ├── MessageType.java │ │ │ ├── SystemMessage.java │ │ │ └── UserMessage.java │ │ ├── moderation │ │ │ └── Moderation.java │ │ └── parameter │ │ │ └── Parameter.java │ │ ├── document │ │ ├── AbstractS3Loader.java │ │ ├── AwsCredentials.java │ │ ├── 
Document.java │ │ ├── DocumentLoaderUtils.java │ │ ├── DocumentType.java │ │ ├── FileSystemDocumentLoader.java │ │ ├── Metadata.java │ │ ├── S3DirectoryLoader.java │ │ ├── S3FileLoader.java │ │ ├── TextSegment.java │ │ ├── UrlDocumentLoader.java │ │ ├── parser │ │ │ ├── DocumentParser.java │ │ │ └── impl │ │ │ │ ├── MsOfficeDocumentParser.java │ │ │ │ ├── PdfDocumentParser.java │ │ │ │ └── TextDocumentParser.java │ │ ├── source │ │ │ ├── DocumentSource.java │ │ │ └── impl │ │ │ │ ├── FileSystemSource.java │ │ │ │ ├── S3Source.java │ │ │ │ └── UrlSource.java │ │ ├── splitter │ │ │ ├── DocumentSplitter.java │ │ │ └── impl │ │ │ │ ├── DocumentByCharacterSplitter.java │ │ │ │ ├── DocumentByLineSplitter.java │ │ │ │ ├── DocumentByParagraphSplitter.java │ │ │ │ ├── DocumentByRegexSplitter.java │ │ │ │ ├── DocumentBySentenceSplitter.java │ │ │ │ ├── DocumentByWordSplitter.java │ │ │ │ ├── DocumentSplitters.java │ │ │ │ ├── HierarchicalDocumentSplitter.java │ │ │ │ └── SegmentBuilder.java │ │ └── tokenizer │ │ │ ├── Tokenizer.java │ │ │ └── impl │ │ │ └── OpenAiTokenizer.java │ │ ├── memory │ │ ├── chat │ │ │ ├── ChatHistoryRecorder.java │ │ │ ├── ChatMemoryStore.java │ │ │ └── impl │ │ │ │ ├── SimpleChatHistoryRecorder.java │ │ │ │ └── SimpleChatMemoryStore.java │ │ └── embedding │ │ │ ├── EmbeddingMemoryStore.java │ │ │ ├── EmbeddingStoreIngestor.java │ │ │ ├── EmbeddingStoreJsonCodec.java │ │ │ ├── EmbeddingStoreJsonCodecFactory.java │ │ │ ├── EmbeddingStoreRetriever.java │ │ │ └── impl │ │ │ ├── GsonInMemoryEmbeddingStoreJsonCodec.java │ │ │ ├── SimpleEmbeddingMemoryStore.java │ │ │ ├── SimpleEmbeddingStoreIngestor.java │ │ │ └── SimpleEmbeddingStoreRetriever.java │ │ ├── model │ │ ├── AudioModel.java │ │ ├── ChatModel.java │ │ ├── EmbeddingModel.java │ │ ├── ImageModel.java │ │ ├── ModelTemplate.java │ │ ├── ModerationModel.java │ │ └── output │ │ │ ├── BigDecimalOutputParser.java │ │ │ ├── BigIntegerOutputParser.java │ │ │ ├── BooleanOutputParser.java │ │ │ ├── 
ByteOutputParser.java │ │ │ ├── DateOutputParser.java │ │ │ ├── DoubleOutputParser.java │ │ │ ├── EnumOutputParser.java │ │ │ ├── FloatOutputParser.java │ │ │ ├── IntOutputParser.java │ │ │ ├── LocalDateOutputParser.java │ │ │ ├── LocalDateTimeOutputParser.java │ │ │ ├── LocalTimeOutputParser.java │ │ │ ├── LongOutputParser.java │ │ │ └── ShortOutputParser.java │ │ ├── prompt │ │ ├── Prompt.java │ │ ├── PromptTemplate.java │ │ └── impl │ │ │ ├── SimplePrompt.java │ │ │ └── SimplePromptTemplate.java │ │ ├── service │ │ ├── AiServiceContext.java │ │ ├── AiServices.java │ │ ├── DefaultAiServices.java │ │ ├── OutputParser.java │ │ ├── ServiceOutputParser.java │ │ └── annotation │ │ │ ├── ChatConfig.java │ │ │ ├── Memory.java │ │ │ ├── MemoryId.java │ │ │ ├── Moderate.java │ │ │ ├── Prompt.java │ │ │ ├── SystemMessage.java │ │ │ ├── UserMessage.java │ │ │ └── V.java │ │ ├── spi │ │ ├── AiServicesFactory.java │ │ └── ServiceHelper.java │ │ └── tools │ │ ├── JsonSchemaProperty.java │ │ ├── ToolExecutionRequest.java │ │ ├── ToolParameters.java │ │ ├── ToolSpecification.java │ │ ├── ToolSpecifications.java │ │ └── annotation │ │ ├── P.java │ │ ├── Tool.java │ │ └── ToolMemoryId.java │ └── test │ └── java │ └── com │ └── ai │ └── domain │ ├── document │ ├── DocumentTest.java │ └── SplitterTest.java │ └── prompt │ ├── PromptAnnotationTest.java │ └── PromptTest.java ├── smartFuse-openai ├── pom.xml └── src │ ├── main │ └── java │ │ └── com │ │ └── ai │ │ └── openai │ │ ├── client │ │ └── OpenAiClient.java │ │ ├── converter │ │ └── BeanConverter.java │ │ ├── model │ │ ├── ModelConversionTemplate.java │ │ ├── OpenaiAudioModel.java │ │ ├── OpenaiChatModel.java │ │ ├── OpenaiEmbeddingModel.java │ │ ├── OpenaiImageModel.java │ │ └── OpenaiModerationModel.java │ │ └── parameter │ │ ├── OpenaiAudioModelSttParameter.java │ │ ├── OpenaiAudioModelTtsParameter.java │ │ ├── OpenaiChatModelParameter.java │ │ ├── OpenaiEmbeddingModelParameter.java │ │ ├── OpenaiImageModelParameter.java │ │ 
├── OpenaiModerationModelParameter.java │ │ └── input │ │ ├── OpenaiAudioSttParameter.java │ │ ├── OpenaiAudioTtsParameter.java │ │ ├── OpenaiChatParameter.java │ │ ├── OpenaiEmbeddingParameter.java │ │ ├── OpenaiImageParameter.java │ │ └── OpenaiModerationParameter.java │ └── test │ └── java │ └── com │ └── ai │ └── openai │ ├── chain │ ├── ConversationalChainTest.java │ └── ConversationalRetrievalChainTest.java │ ├── model │ └── ModelTest.java │ └── service │ └── ServiceTest.java └── smartFuse-spark └── pom.xml /doc/img/star-history-2024124.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainpropath/AI-SmartFuse-Framework/36842665c2767cab64bb9fa0429510d4feaffa2c/doc/img/star-history-2024124.png -------------------------------------------------------------------------------- /doc/test/document/ConversationalRetrievalChainTest-中文.txt: -------------------------------------------------------------------------------- 1 | 从前有个可爱的小姑娘,她总是带着奶奶送的红帽子,所以大家都叫她小红帽。 2 | 一天,妈妈对小红帽说: 来,小红帽,这里有一块蛋糕和一瓶酒,奶奶生病了,快给奶奶送去。路上要小心哟! 3 | 小红帽对妈妈说: 好,我会小心的。我去看奶奶啦! 说完小红帽就高兴地走了。她刚走进森林就碰到了一条狼。 4 | 小红帽说: 我要到奶奶家去。奶奶病了,我给她带了好吃的蛋糕和酒。 5 | 于是它对小红帽说: 小红帽,你看周围这些花多么美丽啊!采点给你奶奶吧,她一定会很开心的。 6 | 小红帽想: 是啊,这些花这么漂亮,奶奶一定会很高兴地。 于是她开始采花了。这时候,狼乘机跑到奶奶家,将奶奶一口就吞进了肚子,然后穿上奶奶的衣服,戴着奶奶的帽子躺在床上。 7 | 过了一会儿,小红帽来到了奶奶家,走到奶奶的床边,狼一下子扑起来,一口就把小红帽吞进了肚子,狼吃饱了,就躺到床上睡着了。 8 | 一位猎人走过,看奶奶的门是打开的,于是,他走进去看看,看见狼躺在床上,肚子还在动,于是,猎人拿起一把剪刀,把狼的肚子剪开了。 9 | 小红帽和奶奶都被救了出来,这时候猎人又搬来几块大石头,塞进狼的肚子。狼醒来之后想逃走,可是那些石头太重了,它刚站起来就跌到在地,摔死了。 10 | 奶奶吃了小红帽带来的蛋糕和酒,感觉好多了,他们高兴极了! 11 | -------------------------------------------------------------------------------- /doc/test/document/ConversationalRetrievalChainTest-英文.txt: -------------------------------------------------------------------------------- 1 | Once upon a time, there was a lovely little girl who always wore a red hat given by her grandmother, so everyone called her Little Red Riding Hood. 
2 | One day, my mother said to Little Red Riding Hood, "Come, Little Red Riding Hood, here is a cake and a bottle of wine. Grandma is sick, hurry up and bring it to her.". Be careful on the road! 3 | Little Red Riding Hood said to her mother, "Okay, I will be careful.". I'm going to see my grandmother! After finishing speaking, Little Red Riding Hood happily left. She ran into a wolf as soon as she entered the forest. 4 | Little Red Riding Hood said, "I'm going to my grandmother's house.". My grandmother is sick, and I brought her delicious cake and wine. 5 | So it said to Little Red Riding Hood: Little Red Riding Hood, look at how beautiful these flowers are around you! Pick some for your grandmother, she will definitely be very happy. 6 | Little Red Riding Hood thought to herself, "Yes, these flowers are so beautiful. Grandma will definitely be very happy.". So she started picking flowers. At this moment, the wolf took the opportunity to run to his grandmother's house and swallowed her in one gulp. He then put on his grandmother's clothes and lay in bed wearing her hat. 7 | After a while, Little Red Riding Hood arrived at her grandmother's house and walked to her bed. The wolf pounced and swallowed Little Red Riding Hood in one gulp. When the wolf was full, it lay down on the bed and fell asleep. 8 | A hunter walked by and saw that his grandmother's door was open. So, he walked in and saw the wolf lying in bed with its stomach still moving. So, the hunter picked up a pair of scissors and cut open the wolf's stomach. 9 | Little Red Riding Hood and Grandma were both rescued when the hunter brought a few large stones and stuffed them into the wolf's belly. The wolf woke up and wanted to escape, but the stones were too heavy. As soon as it stood up, it fell to the ground and died. 10 | Grandma ate the cake and wine brought by Little Red Riding Hood and felt much better. They were extremely happy! 
-------------------------------------------------------------------------------- /doc/test/document/中文测试.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainpropath/AI-SmartFuse-Framework/36842665c2767cab64bb9fa0429510d4feaffa2c/doc/test/document/中文测试.docx -------------------------------------------------------------------------------- /doc/test/document/中文测试.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainpropath/AI-SmartFuse-Framework/36842665c2767cab64bb9fa0429510d4feaffa2c/doc/test/document/中文测试.pdf -------------------------------------------------------------------------------- /doc/test/document/中文测试.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainpropath/AI-SmartFuse-Framework/36842665c2767cab64bb9fa0429510d4feaffa2c/doc/test/document/中文测试.pptx -------------------------------------------------------------------------------- /doc/test/document/中文测试.txt: -------------------------------------------------------------------------------- 1 | This is an example example example sentence.This is an example example example sentence. 
2 | 你好,你好,你好。你好,你好,你好 3 | -------------------------------------------------------------------------------- /doc/test/document/中文测试.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mainpropath/AI-SmartFuse-Framework/36842665c2767cab64bb9fa0429510d4feaffa2c/doc/test/document/中文测试.xlsx -------------------------------------------------------------------------------- /smartFuse-baidu/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | AI-SmartFuse-Framework 7 | com.ai 8 | 1.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | smartFuse-baidu 13 | 14 | 15 | UTF-8 16 | UTF-8 17 | 1.8 18 | 8 19 | 8 20 | 21 | 22 | 23 | 24 | com.ai 25 | smartFuse-common 26 | 1.0-SNAPSHOT 27 | 28 | 29 | com.ai 30 | smartFuse-domain 31 | 1.0-SNAPSHOT 32 | 33 | 34 | junit 35 | junit 36 | 4.13.2 37 | test 38 | 39 | 40 | com.ai 41 | ai-baidu 42 | 1.0 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /smartFuse-baidu/src/main/java/com/ai/baidu/client/BaiduClient.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.client; 2 | 3 | import com.ai.baidu.achieve.Configuration; 4 | import com.ai.baidu.achieve.defaults.DefaultBaiduSessionFactory; 5 | import com.ai.baidu.achieve.standard.session.AggregationSession; 6 | 7 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 8 | 9 | 10 | public class BaiduClient { 11 | 12 | private static AggregationSession aggregationSession; 13 | 14 | private static Configuration configuration; 15 | 16 | public static void SetConfiguration(Configuration configuration) { 17 | ensureNotNull(configuration, "configuration"); 18 | aggregationSession = new DefaultBaiduSessionFactory(configuration).openAggregationSession(); 19 | } 20 | 21 | public static Configuration GetConfiguration() { 22 | return configuration; 23 | } 24 | 25 | public static 
AggregationSession getAggregationSession() { 26 | return aggregationSession; 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /smartFuse-baidu/src/main/java/com/ai/baidu/converter/BeanConverter.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.converter; 2 | 3 | import com.ai.baidu.common.Usage; 4 | import com.ai.common.resp.usage.TokenUsage; 5 | 6 | 7 | public class BeanConverter { 8 | 9 | public static TokenUsage usage2tokenUsage(Usage usage) { 10 | return TokenUsage.usage(usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens()); 11 | } 12 | 13 | } 14 | -------------------------------------------------------------------------------- /smartFuse-baidu/src/main/java/com/ai/baidu/model/BaiduChatModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.model; 2 | 3 | import cn.hutool.core.bean.BeanUtil; 4 | import com.ai.baidu.achieve.standard.session.ChatSession; 5 | import com.ai.baidu.client.BaiduClient; 6 | import com.ai.baidu.converter.BeanConverter; 7 | import com.ai.baidu.endPoint.chat.Message; 8 | import com.ai.baidu.endPoint.chat.req.ChatRequest; 9 | import com.ai.baidu.endPoint.chat.resp.ChatResponse; 10 | import com.ai.baidu.parameter.BaiduChatModelParameter; 11 | import com.ai.baidu.parameter.input.BaiduChatParameter; 12 | import com.ai.common.resp.AiResponse; 13 | import com.ai.common.resp.finish.FinishReason; 14 | import com.ai.common.resp.usage.TokenUsage; 15 | import com.ai.common.util.Exceptions; 16 | import com.ai.domain.data.message.AssistantMessage; 17 | import com.ai.domain.data.message.ChatMessage; 18 | import com.ai.domain.data.message.MessageType; 19 | import com.ai.domain.data.parameter.Parameter; 20 | import com.ai.domain.model.ChatModel; 21 | 22 | import java.util.List; 23 | import java.util.stream.Collectors; 24 | 25 | import static 
com.ai.common.util.ValidationUtils.ensureNotEmpty; 26 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 27 | 28 | 29 | public class BaiduChatModel implements ChatModel { 30 | 31 | private final ChatSession chatSession = BaiduClient.getAggregationSession().getChatSession(); 32 | private Parameter parameter; 33 | 34 | public BaiduChatModel() { 35 | this(new BaiduChatModelParameter()); 36 | } 37 | 38 | public BaiduChatModel(Parameter parameter) { 39 | this.parameter = ensureNotNull(parameter, "parameter"); 40 | } 41 | 42 | public static List chatMessageList2BaiduMessageList(List chatMessages) { 43 | return chatMessages.stream().map(chatMessage -> { 44 | // 百度不支持System类型的消息 45 | if (chatMessage.type().getMessageType().equals(MessageType.SYSTEM)) 46 | throw Exceptions.runtime(String.format("Baidu API does not support [%s] type messages", MessageType.SYSTEM.getMessageType())); 47 | return Message.builder().role(chatMessage.type().getMessageType()).content(chatMessage.text()).build(); 48 | }).collect(Collectors.toList()); 49 | } 50 | 51 | public Parameter getParameter() { 52 | return parameter; 53 | } 54 | 55 | public void setParameter(Parameter parameter) { 56 | this.parameter = ensureNotNull(parameter, "parameter"); 57 | } 58 | 59 | @Override 60 | public AiResponse generate(List messages) { 61 | ensureNotEmpty(messages, "messages"); 62 | // 将消息转换为百度格式的消息 63 | List messageList = chatMessageList2BaiduMessageList(messages); 64 | // 构造请求主要参数 65 | ChatRequest chatRequest = ChatRequest.builder().messages(messageList).build(); 66 | // 填充请求配置属性 67 | BeanUtil.copyProperties(parameter.getParameter(), chatRequest); 68 | // 发起请求获取结果 69 | ChatResponse chatResponse = this.chatSession.chat(chatRequest); 70 | // 转换结果为统一返回值 71 | AssistantMessage message = AssistantMessage.message(chatResponse.getResult()); 72 | TokenUsage tokenUsage = BeanConverter.usage2tokenUsage(chatResponse.getUsage()); 73 | return AiResponse.R(message, tokenUsage, FinishReason.success()); 74 | } 75 
| 76 | } 77 | -------------------------------------------------------------------------------- /smartFuse-baidu/src/main/java/com/ai/baidu/model/BaiduEmbeddingModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.model; 2 | 3 | import com.ai.baidu.achieve.standard.session.EmbeddingSession; 4 | import com.ai.baidu.client.BaiduClient; 5 | import com.ai.baidu.endPoint.embedding.EmbeddingData; 6 | import com.ai.baidu.endPoint.embedding.req.EmbeddingRequest; 7 | import com.ai.baidu.endPoint.embedding.resp.EmbeddingResponse; 8 | import com.ai.common.resp.AiResponse; 9 | import com.ai.common.resp.finish.FinishReason; 10 | import com.ai.domain.data.embedding.Embedding; 11 | import com.ai.domain.model.EmbeddingModel; 12 | 13 | import java.util.List; 14 | import java.util.stream.Collectors; 15 | 16 | import static com.ai.common.util.ValidationUtils.ensureNotEmpty; 17 | 18 | 19 | public class BaiduEmbeddingModel implements EmbeddingModel { 20 | 21 | private final EmbeddingSession embeddingSession = BaiduClient.getAggregationSession().getEmbeddingSession(); 22 | 23 | public static Embedding embeddingData2Embedding(EmbeddingData embeddingData) { 24 | return new Embedding(embeddingData.getEmbedding(), embeddingData.getContent()); 25 | } 26 | 27 | public static List embeddingDataList2EmbeddingList(List embeddingDataList) { 28 | return embeddingDataList.stream() 29 | .map(EmbeddingData -> embeddingData2Embedding(EmbeddingData)) 30 | .collect(Collectors.toList()); 31 | } 32 | 33 | @Override 34 | public AiResponse> embed(List text) { 35 | ensureNotEmpty(text, "text"); 36 | EmbeddingResponse embedding = embeddingSession.embedding(EmbeddingRequest.baseBuild(text)); 37 | return AiResponse.R(embeddingDataList2EmbeddingList(embedding.getData()), FinishReason.success()); 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- 
/smartFuse-baidu/src/main/java/com/ai/baidu/model/BaiduImageModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.model; 2 | 3 | 4 | import cn.hutool.core.bean.BeanUtil; 5 | import com.ai.baidu.achieve.standard.session.ImageSession; 6 | import com.ai.baidu.client.BaiduClient; 7 | import com.ai.baidu.endPoint.images.ImageData; 8 | import com.ai.baidu.endPoint.images.req.ImageRequest; 9 | import com.ai.baidu.parameter.BaiduImageModelParameter; 10 | import com.ai.baidu.parameter.input.BaiduImageParameter; 11 | import com.ai.common.resp.AiResponse; 12 | import com.ai.common.resp.finish.FinishReason; 13 | import com.ai.domain.data.images.Image; 14 | import com.ai.domain.data.parameter.Parameter; 15 | import com.ai.domain.model.ImageModel; 16 | 17 | import java.util.List; 18 | import java.util.stream.Collectors; 19 | 20 | import static com.ai.common.util.ValidationUtils.ensureNotBlank; 21 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 22 | 23 | public class BaiduImageModel implements ImageModel { 24 | 25 | private final ImageSession imageSession = BaiduClient.getAggregationSession().getImageSession(); 26 | private Parameter parameter; 27 | 28 | public BaiduImageModel() { 29 | this.parameter = new BaiduImageModelParameter(); 30 | } 31 | 32 | public BaiduImageModel(Parameter parameter) { 33 | this.parameter = ensureNotNull(parameter, "parameter"); 34 | } 35 | 36 | public static List imageDataList2ImageList(List imageDataList) { 37 | return imageDataList.stream() 38 | .map(imageData -> imageData2Image(imageData)) 39 | .collect(Collectors.toList()); 40 | } 41 | 42 | public static Image imageData2Image(ImageData imageData) { 43 | return Image.b64Json(imageData.getB64Image()); 44 | } 45 | 46 | public Parameter getParameter() { 47 | return parameter; 48 | } 49 | 50 | public void setParameter(Parameter parameter) { 51 | this.parameter = ensureNotNull(parameter, "parameter"); 52 | } 53 | 54 | 
@Override 55 | public AiResponse> create(String prompt, String size, String style, int n) { 56 | ensureNotBlank(prompt, "prompt"); 57 | // 构造请求主要参数 58 | ImageRequest request = ImageRequest.builder() 59 | .prompt(prompt).size(size).style(style).n(n).build(); 60 | // 填充请求配置属性 61 | BeanUtil.copyProperties(parameter.getParameter(), request); 62 | // 发起请求获取结果 63 | List imageDataList = imageSession.text2image(request).getData(); 64 | // 转换结果为统一返回值 65 | return AiResponse.R(imageDataList2ImageList(imageDataList), FinishReason.success()); 66 | } 67 | 68 | } 69 | -------------------------------------------------------------------------------- /smartFuse-baidu/src/main/java/com/ai/baidu/parameter/BaiduChatModelParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.parameter; 2 | 3 | import com.ai.baidu.parameter.input.BaiduChatParameter; 4 | import com.ai.domain.data.parameter.Parameter; 5 | 6 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 7 | 8 | 9 | public class BaiduChatModelParameter implements Parameter { 10 | 11 | private BaiduChatParameter parameter; 12 | 13 | public BaiduChatModelParameter() { 14 | this(BaiduChatParameter.builder().build()); 15 | } 16 | 17 | public BaiduChatModelParameter(BaiduChatParameter parameter) { 18 | this.parameter = ensureNotNull(parameter, "BaiduChatParameter"); 19 | } 20 | 21 | @Override 22 | public BaiduChatParameter getParameter() { 23 | return this.parameter; 24 | } 25 | 26 | @Override 27 | public void SetParameter(BaiduChatParameter parameter) { 28 | this.parameter = parameter; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /smartFuse-baidu/src/main/java/com/ai/baidu/parameter/BaiduImageModelParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.parameter; 2 | 3 | import com.ai.baidu.parameter.input.BaiduImageParameter; 4 | import 
com.ai.domain.data.parameter.Parameter; 5 | 6 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 7 | 8 | public class BaiduImageModelParameter implements Parameter { 9 | 10 | private BaiduImageParameter parameter; 11 | 12 | public BaiduImageModelParameter() { 13 | this(BaiduImageParameter.builder().build()); 14 | } 15 | 16 | public BaiduImageModelParameter(BaiduImageParameter parameter) { 17 | this.parameter = ensureNotNull(parameter, "BaiduImageParameter"); 18 | } 19 | 20 | @Override 21 | public BaiduImageParameter getParameter() { 22 | return this.parameter; 23 | } 24 | 25 | @Override 26 | public void SetParameter(BaiduImageParameter parameter) { 27 | this.parameter = parameter; 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /smartFuse-baidu/src/main/java/com/ai/baidu/parameter/input/BaiduChatParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.parameter.input; 2 | 3 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties; 4 | import com.fasterxml.jackson.annotation.JsonInclude; 5 | import com.fasterxml.jackson.annotation.JsonProperty; 6 | import lombok.AllArgsConstructor; 7 | import lombok.Builder; 8 | import lombok.Data; 9 | import lombok.NoArgsConstructor; 10 | 11 | import java.io.Serializable; 12 | import java.util.List; 13 | 14 | @Data 15 | @Builder 16 | @NoArgsConstructor 17 | @AllArgsConstructor 18 | @JsonIgnoreProperties(ignoreUnknown = true) 19 | @JsonInclude(JsonInclude.Include.NON_NULL) 20 | public class BaiduChatParameter implements Serializable { 21 | 22 | /** 23 | * (1)较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 24 | * (2)默认0.8,范围 (0, 1.0],不能为0 25 | */ 26 | private Double temperature; 27 | 28 | /** 29 | * (1)影响输出文本的多样性,取值越大,生成文本的多样性越强 30 | * (2)默认0.8,取值范围 [0, 1.0] 31 | */ 32 | @JsonProperty("top_p") 33 | private Double topP; 34 | 35 | /** 36 | * 通过对已生成的token增加惩罚,减少重复生成的现象。说明: 37 | * (1)值越大表示惩罚越大 38 | * 
/**
 * Optional settings for a Baidu image-generation request. Every field is nullable;
 * {@code @JsonInclude(NON_NULL)} leaves unset fields out of the serialized JSON.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class BaiduImageParameter implements Serializable {


    /**
     * Number of images to generate:
     * - default is 1
     * - valid range is 1-4
     * - generating many images in one call, or calling too frequently, may cause the request to time out
     */
    private Integer n;

    /**
     * Number of iteration steps:
     * - default is 20
     * - valid range is 10-50
     */
    private Integer steps;

    /**
     * Sampling method. Default: Euler a. Accepted values:
     * - Euler
     * - Euler a
     * - DPM++ 2M
     * - DPM++ 2M Karras
     * - LMS Karras
     * - DPM++ SDE
     * - DPM++ SDE Karras
     * - DPM2 a Karras
     * - Heun
     * - DPM++ 2M SDE
     * - DPM++ 2M SDE Karras
     * - DPM2
     * - DPM2 Karras
     * - DPM2 a
     * - LMS
     */
    @JsonProperty("sampler_index")
    private String samplerIndex;

    /**
     * Random seed:
     * - when not set, a random number is generated automatically
     * - valid range is [0, 4294967295]
     * NOTE(review): the documented upper bound is larger than Integer.MAX_VALUE, so the
     * upper half of the range cannot be expressed with this field type -- confirm whether
     * a wider type is needed.
     */
    private Integer seed;

    /**
     * Prompt relevance: default is 5, valid range is 0-30.
     */
    @JsonProperty("cfg_scale")
    private Double cfgScale;

    /**
     * Unique identifier of the end user.
     */
    @JsonProperty("user_id")
    private String userId;

}
ConversationalChainTest { 19 | 20 | private ConversationalChain conversationalChain; 21 | 22 | @Before 23 | public void test_create_conversational_chain() { 24 | // 设置配置信息 25 | Configuration configuration = new Configuration(); 26 | configuration.setApiHost("https://aip.baidubce.com"); 27 | ApiData apiData = ApiData.builder() 28 | .apiKey("**************************") 29 | .secretKey("**************************") 30 | .appId("**************************") 31 | .build(); 32 | configuration.setKeyList(Arrays.asList(apiData)); 33 | configuration.setKeyStrategy(new FirstKeyStrategy()); 34 | BaiduClient.SetConfiguration(configuration); 35 | this.conversationalChain = ConversationalChain.builder() 36 | .chatModel(new BaiduChatModel()) 37 | .historyRecorder(SimpleChatHistoryRecorder.builder().build()) 38 | .build(); 39 | } 40 | 41 | @Test 42 | public void test_conversational_chain_run() { 43 | String res1 = conversationalChain.run("你好,请记住我的名字叫做小明"); 44 | System.out.println(res1);// 你好,小明!很高兴认识你。 45 | String res2 = conversationalChain.run("我的名字是什么?"); 46 | System.out.println(res2);// 你的名字是小明。 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /smartFuse-baidu/src/test/java/com/ai/baidu/chain/ConversationalRetrievalChainTest.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.chain; 2 | 3 | 4 | import com.ai.baidu.achieve.ApiData; 5 | import com.ai.baidu.achieve.Configuration; 6 | import com.ai.baidu.client.BaiduClient; 7 | import com.ai.baidu.model.BaiduChatModel; 8 | import com.ai.baidu.model.BaiduEmbeddingModel; 9 | import com.ai.core.strategy.impl.FirstKeyStrategy; 10 | import com.ai.domain.chain.impl.ConversationalRetrievalChain; 11 | import com.ai.domain.document.Document; 12 | import com.ai.domain.document.FileSystemDocumentLoader; 13 | import com.ai.domain.memory.chat.impl.SimpleChatHistoryRecorder; 14 | import 
com.ai.domain.memory.embedding.impl.SimpleEmbeddingStoreIngestor; 15 | import com.ai.domain.memory.embedding.impl.SimpleEmbeddingStoreRetriever; 16 | import org.junit.Before; 17 | import org.junit.Test; 18 | 19 | import java.io.File; 20 | import java.nio.file.Path; 21 | import java.nio.file.Paths; 22 | import java.util.Arrays; 23 | 24 | 25 | public class ConversationalRetrievalChainTest { 26 | 27 | private ConversationalRetrievalChain conversationalRetrievalChain; 28 | 29 | public static Path toPath(String fileName) { 30 | File file = new File(fileName); 31 | if (file.exists()) { 32 | try { 33 | return Paths.get(file.toURI()); 34 | } catch (Exception e) { 35 | e.printStackTrace(); 36 | } 37 | } 38 | return null; 39 | } 40 | 41 | @Before 42 | public void before() { 43 | // 设置配置信息 44 | Configuration configuration = new Configuration(); 45 | configuration.setApiHost("https://aip.baidubce.com"); 46 | ApiData apiData = ApiData.builder() 47 | .apiKey("**************************") 48 | .secretKey("**************************") 49 | .appId("**************************") 50 | .build(); 51 | configuration.setKeyList(Arrays.asList(apiData)); 52 | configuration.setKeyStrategy(new FirstKeyStrategy()); 53 | BaiduClient.SetConfiguration(configuration); 54 | // 测试文件路径 55 | String filePath = "D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\ConversationalRetrievalChainTest-中文.txt"; 56 | // "D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\ConversationalRetrievalChainTest-英文.txt"; 57 | // 创建嵌入数据导入器,这里可以设置你指定的存储器,也可以直接使用其中默认的存储器。 58 | SimpleEmbeddingStoreIngestor ingestor = SimpleEmbeddingStoreIngestor.builder().embeddingModel(new BaiduEmbeddingModel()).build(); 59 | Document document = FileSystemDocumentLoader.loadDocument(toPath(filePath)); 60 | // 将数据导入到存储器当中 61 | ingestor.ingest(document); 62 | // 获取存储器,并设置其对应的检索器,向检索器当中设置检索器检索的嵌入存储器。 63 | this.conversationalRetrievalChain = ConversationalRetrievalChain.builder() 64 | .chatModel(new BaiduChatModel()) 65 | 
.embeddingModel(new BaiduEmbeddingModel()) 66 | .historyRecorder(SimpleChatHistoryRecorder.builder().build()) 67 | .retriever(SimpleEmbeddingStoreRetriever.builder().embeddingMemoryStore(ingestor.getStore()).build()) 68 | .build(); 69 | } 70 | 71 | @Test 72 | public void test_embedding_data_retriever_with_en() { 73 | String question = "What kind of person is Little Red Riding Hood?"; 74 | String res = conversationalRetrievalChain.run(question); 75 | System.out.println(res); 76 | } 77 | 78 | @Test 79 | public void test_embedding_data_retriever_with_ch() { 80 | String question = "小红帽要去干什么?"; 81 | String res = conversationalRetrievalChain.run(question); 82 | System.out.println(res); 83 | } 84 | 85 | 86 | } 87 | -------------------------------------------------------------------------------- /smartFuse-baidu/src/test/java/com/ai/baidu/model/ModelTest.java: -------------------------------------------------------------------------------- 1 | package com.ai.baidu.model; 2 | 3 | 4 | import com.ai.baidu.achieve.ApiData; 5 | import com.ai.baidu.achieve.Configuration; 6 | import com.ai.baidu.achieve.standard.session.AggregationSession; 7 | import com.ai.baidu.client.BaiduClient; 8 | import com.ai.common.resp.AiResponse; 9 | import com.ai.core.strategy.impl.FirstKeyStrategy; 10 | import com.ai.domain.data.embedding.Embedding; 11 | import com.ai.domain.data.images.Image; 12 | import com.ai.domain.model.ChatModel; 13 | import com.ai.domain.model.EmbeddingModel; 14 | import com.ai.domain.model.ImageModel; 15 | import org.junit.Before; 16 | import org.junit.Test; 17 | 18 | import java.util.Arrays; 19 | 20 | public class ModelTest { 21 | 22 | private AggregationSession aggregationSession; 23 | 24 | @Before 25 | public void test_model_before() { 26 | // 设置配置信息 27 | Configuration configuration = new Configuration(); 28 | configuration.setApiHost("https://aip.baidubce.com"); 29 | ApiData apiData = ApiData.builder() 30 | .apiKey("**************************") 31 | 
.secretKey("**************************") 32 | .appId("**************************") 33 | .build(); 34 | configuration.setKeyList(Arrays.asList(apiData)); 35 | configuration.setKeyStrategy(new FirstKeyStrategy()); 36 | BaiduClient.SetConfiguration(configuration); 37 | } 38 | 39 | @Test 40 | public void test_chat() { 41 | ChatModel baiduChatModel = new BaiduChatModel(); 42 | String res = baiduChatModel.generate("你好"); 43 | System.out.println(res); 44 | } 45 | 46 | @Test 47 | public void test_embedding() { 48 | EmbeddingModel embeddingModel = new BaiduEmbeddingModel(); 49 | AiResponse response = embeddingModel.embed("你好"); 50 | System.out.println(response); 51 | } 52 | 53 | @Test 54 | public void test_image() { 55 | ImageModel imageModel = new BaiduImageModel(); 56 | AiResponse imageAiResponse = imageModel.create("画一幅山水画"); 57 | System.out.println(imageAiResponse.getData()); 58 | } 59 | 60 | 61 | } 62 | -------------------------------------------------------------------------------- /smartFuse-common/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | AI-SmartFuse-Framework 7 | com.ai 8 | 1.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | smartFuse-common 13 | 14 | 15 | UTF-8 16 | UTF-8 17 | 1.8 18 | 8 19 | 8 20 | 21 | 22 | 23 | 24 | cn.hutool 25 | hutool-all 26 | 5.8.18 27 | 28 | 29 | org.projectlombok 30 | lombok 31 | 1.18.26 32 | compile 33 | 34 | 35 | com.fasterxml.jackson.core 36 | jackson-databind 37 | 2.13.3 38 | 39 | 40 | junit 41 | junit 42 | 4.13.2 43 | test 44 | 45 | 46 | com.squareup.retrofit2 47 | retrofit 48 | 2.9.0 49 | 50 | 51 | com.squareup.retrofit2 52 | converter-jackson 53 | 2.9.0 54 | 55 | 56 | com.squareup.retrofit2 57 | adapter-rxjava2 58 | 2.9.0 59 | 60 | 61 | com.squareup.okhttp3 62 | okhttp-sse 63 | 4.9.3 64 | 65 | 66 | com.squareup.okhttp3 67 | logging-interceptor 68 | 4.9.3 69 | 70 | 71 | org.slf4j 72 | slf4j-api 73 | 2.0.6 74 | 75 | 76 | org.slf4j 77 | slf4j-simple 78 | 2.0.6 79 | 80 | 81 | 82 | 83 | 
-------------------------------------------------------------------------------- /smartFuse-common/src/main/java/com/ai/common/resp/AiResponse.java: -------------------------------------------------------------------------------- 1 | package com.ai.common.resp; 2 | 3 | 4 | import com.ai.common.resp.finish.FinishReason; 5 | import com.ai.common.resp.usage.TokenUsage; 6 | import lombok.AllArgsConstructor; 7 | import lombok.Data; 8 | import lombok.NoArgsConstructor; 9 | 10 | @Data 11 | @NoArgsConstructor 12 | @AllArgsConstructor 13 | public class AiResponse { 14 | 15 | private T data; 16 | private TokenUsage tokenUsage; 17 | private FinishReason finishReason; 18 | 19 | public static AiResponse R(T data, TokenUsage tokenUsage, FinishReason finishReason) { 20 | return new AiResponse<>(data, tokenUsage, finishReason); 21 | } 22 | 23 | public static AiResponse R(T data) { 24 | return R(data, null, null); 25 | } 26 | 27 | public static AiResponse R(T data, FinishReason finishReason) { 28 | return R(data, null, finishReason); 29 | } 30 | 31 | public static AiResponse R(T data, TokenUsage tokenUsage) { 32 | return R(data, tokenUsage, null); 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /smartFuse-common/src/main/java/com/ai/common/resp/finish/FinishReason.java: -------------------------------------------------------------------------------- 1 | package com.ai.common.resp.finish; 2 | 3 | 4 | import lombok.AllArgsConstructor; 5 | import lombok.Data; 6 | import lombok.NoArgsConstructor; 7 | 8 | @Data 9 | @NoArgsConstructor 10 | @AllArgsConstructor 11 | public class FinishReason { 12 | 13 | private String description; 14 | 15 | public static FinishReason Finish(String description) { 16 | return new FinishReason(description); 17 | } 18 | 19 | public static FinishReason success() { 20 | return new FinishReason("success"); 21 | } 22 | 23 | public static FinishReason error() { 24 | return new FinishReason("error"); 25 | } 26 
| 27 | public static FinishReason timeout() { 28 | return new FinishReason("timeout"); 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /smartFuse-common/src/main/java/com/ai/common/resp/usage/TokenUsage.java: -------------------------------------------------------------------------------- 1 | package com.ai.common.resp.usage; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Data; 5 | import lombok.NoArgsConstructor; 6 | 7 | @Data 8 | @NoArgsConstructor 9 | @AllArgsConstructor 10 | public class TokenUsage { 11 | 12 | private Integer inputTokenCount; 13 | private Integer outputTokenCount; 14 | private Integer totalTokenCount; 15 | 16 | public static TokenUsage usage(Integer inputTokenCount, Integer outputTokenCount, Integer totalTokenCount) { 17 | return new TokenUsage(inputTokenCount, outputTokenCount, totalTokenCount); 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /smartFuse-common/src/main/java/com/ai/common/util/Exceptions.java: -------------------------------------------------------------------------------- 1 | package com.ai.common.util; 2 | 3 | /** 4 | * 异常工具类 5 | */ 6 | public class Exceptions { 7 | 8 | public static IllegalArgumentException illegalArgument(String format, Object... args) { 9 | return new IllegalArgumentException(String.format(format, args)); 10 | } 11 | 12 | public static RuntimeException runtime(String format, Object... 
/**
 * Utility for working with {@code {{key}}} placeholders in template strings.
 *
 * <p>The class keeps no mutable state: every call creates its own {@link Matcher}, so the
 * methods can be called from several threads at once.
 */
public class PlaceHolderReplaceUtils {

    /** Matches {@code {{key}}}; group 1 is the text between the braces. */
    private static final Pattern PLACEHOLDER_PATTERN = Pattern.compile("\\{\\{(.*?)\\}\\}");

    /**
     * Replaces every {@code {{key}}} placeholder whose trimmed key has a non-null value in
     * {@code param}.
     *
     * <p>The source is scanned once, so placeholder-looking text inside a substituted value is
     * not substituted again. Placeholders without a value are left untouched.
     *
     * @param sourceString template text; returned as is when {@code null} or empty
     * @param param        placeholder values keyed by placeholder name; {@code null} or empty
     *                     means nothing is replaced
     * @return the text with all resolvable placeholders replaced
     */
    public static String replaceWithMap(String sourceString, Map<String, String> param) {
        if (sourceString == null || sourceString.isEmpty() || param == null || param.isEmpty()) {
            return sourceString;
        }

        Matcher matcher = PLACEHOLDER_PATTERN.matcher(sourceString);
        // StringBuffer rather than StringBuilder: the project targets Java 8, where
        // appendReplacement only accepts a StringBuffer.
        StringBuffer result = new StringBuffer(sourceString.length());
        while (matcher.find()) {
            String value = param.get(matcher.group(1).trim());
            String replacement = value != null ? value : matcher.group();
            // quoteReplacement keeps '$' and '\' in values from being read as group references.
            matcher.appendReplacement(result, Matcher.quoteReplacement(replacement));
        }
        matcher.appendTail(result);
        return result.toString();
    }

    /**
     * Finds the keys of all {@code {{key}}} placeholders in the given text.
     *
     * @param sourceString text to scan
     * @return the trimmed placeholder keys; empty when the text is {@code null} or empty
     */
    public static Set<String> findPlaceHolderKeys(String sourceString) {
        return findPlaceHolderKeys(sourceString, PLACEHOLDER_PATTERN);
    }

    /**
     * Finds the placeholder keys in the given text using a caller-supplied pattern.
     *
     * @param sourceString text to scan
     * @param pattern      placeholder expression; when it defines a capturing group, group 1 is
     *                     taken as the key
     * @return the trimmed placeholder keys; empty when the text is {@code null} or empty or the
     *         pattern is {@code null}
     */
    public static Set<String> findPlaceHolderKeys(String sourceString, Pattern pattern) {
        Set<String> placeHolderSet = new HashSet<>();
        if (sourceString == null || sourceString.isEmpty() || pattern == null) {
            return placeHolderSet;
        }

        Matcher matcher = pattern.matcher(sourceString);
        while (matcher.find()) {
            placeHolderSet.add(extractKey(matcher));
        }
        return placeHolderSet;
    }

    /** Returns the key of the current match: group 1 if present, else the match minus one delimiter per side. */
    private static String extractKey(Matcher matcher) {
        if (matcher.groupCount() >= 1 && matcher.group(1) != null) {
            return matcher.group(1).trim();
        }
        String match = matcher.group();
        if (match.length() < 2) {
            return match.trim();
        }
        return match.substring(1, match.length() - 1).trim();
    }

}
strings) { 31 | if (strings == null || strings.length == 0) { 32 | return false; 33 | } 34 | for (String string : strings) { 35 | if (isNullOrBlank(string)) { 36 | return false; 37 | } 38 | } 39 | return true; 40 | } 41 | 42 | public static boolean isNullOrEmpty(Collection collection) { 43 | return collection == null || collection.isEmpty(); 44 | } 45 | 46 | public static String repeat(String string, int times) { 47 | StringBuilder sb = new StringBuilder(); 48 | for (int i = 0; i < times; i++) { 49 | sb.append(string); 50 | } 51 | return sb.toString(); 52 | } 53 | 54 | public static String randomUUID() { 55 | return UUID.randomUUID().toString(); 56 | } 57 | 58 | public static String generateUUIDFrom(String input) { 59 | try { 60 | byte[] hashBytes = MessageDigest.getInstance("SHA-256").digest(input.getBytes(UTF_8)); 61 | StringBuilder sb = new StringBuilder(); 62 | for (byte b : hashBytes) sb.append(String.format("%02x", b)); 63 | return UUID.nameUUIDFromBytes(sb.toString().getBytes(UTF_8)).toString(); 64 | } catch (NoSuchAlgorithmException e) { 65 | throw new IllegalArgumentException(e); 66 | } 67 | } 68 | 69 | public static String quoted(String string) { 70 | if (string == null) { 71 | return "null"; 72 | } 73 | return "\"" + string + "\""; 74 | } 75 | 76 | public static String firstChars(String string, int numberOfChars) { 77 | if (string == null) { 78 | return null; 79 | } 80 | return string.length() > numberOfChars ? 
string.substring(0, numberOfChars) : string; 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /smartFuse-common/src/main/java/com/ai/common/util/ValidationUtils.java: -------------------------------------------------------------------------------- 1 | package com.ai.common.util; 2 | 3 | import java.util.Collection; 4 | 5 | import static com.ai.common.util.Exceptions.illegalArgument; 6 | 7 | 8 | public class ValidationUtils { 9 | 10 | public static T ensureNotNull(T object, String name) { 11 | if (object == null) { 12 | throw illegalArgument("%s cannot be null", name); 13 | } 14 | 15 | return object; 16 | } 17 | 18 | public static > T ensureNotEmpty(T collection, String name) { 19 | if (collection == null || collection.isEmpty()) { 20 | throw illegalArgument("%s cannot be null or empty", name); 21 | } 22 | 23 | return collection; 24 | } 25 | 26 | public static String ensureNotBlank(String string, String name) { 27 | if (string == null || string.trim().isEmpty()) { 28 | throw illegalArgument("%s cannot be null or blank", name); 29 | } 30 | 31 | return string; 32 | } 33 | 34 | public static void ensureTrue(boolean expression, String msg) { 35 | if (!expression) { 36 | throw illegalArgument(msg); 37 | } 38 | } 39 | 40 | public static int ensureGreaterThanZero(Integer i, String name) { 41 | if (i == null || i <= 0) { 42 | throw illegalArgument("%s must be greater than zero, but is: %s", name, i); 43 | } 44 | 45 | return i; 46 | } 47 | 48 | public static double ensureBetween(Double d, double min, double max, String name) { 49 | if (d == null || d < min || d > max) { 50 | throw illegalArgument("%s must be between %s and %s, but is: %s", name, min, max, d); 51 | } 52 | 53 | return d; 54 | } 55 | 56 | public static int ensureBetween(Integer i, int min, int max, String name) { 57 | if (i == null || i < min || i > max) { 58 | throw illegalArgument("%s must be between %s and %s, but is: %s", name, min, max, i); 59 | } 60 | 61 
| return i; 62 | } 63 | } 64 | 65 | -------------------------------------------------------------------------------- /smartFuse-domain/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | AI-SmartFuse-Framework 7 | com.ai 8 | 1.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | smartFuse-domain 13 | 14 | 15 | UTF-8 16 | UTF-8 17 | 1.8 18 | 8 19 | 8 20 | 21 | 22 | 23 | 24 | junit 25 | junit 26 | 4.13.2 27 | test 28 | 29 | 30 | com.ai 31 | smartFuse-common 32 | 1.0-SNAPSHOT 33 | 34 | 35 | org.apache.pdfbox 36 | pdfbox 37 | 2.0.29 38 | 39 | 40 | org.apache.poi 41 | poi 42 | 5.2.3 43 | 44 | 45 | org.apache.poi 46 | poi-ooxml 47 | 5.2.3 48 | 49 | 50 | com.knuddels 51 | jtokkit 52 | 0.6.1 53 | 54 | 55 | software.amazon.awssdk 56 | auth 57 | 2.20.149 58 | 59 | 60 | software.amazon.awssdk 61 | s3 62 | 2.20.149 63 | 64 | 65 | com.google.code.gson 66 | gson 67 | 2.10.1 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/chain/Chain.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.chain; 2 | 3 | /** 4 | * 链路接口,一个链条当中可以有多个链路节点。 5 | **/ 6 | public interface Chain { 7 | 8 | /** 9 | * 运行这条链路 10 | */ 11 | Output run(Input input); 12 | 13 | } 14 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/chain/impl/ConversationalChain.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.chain.impl; 2 | 3 | import com.ai.domain.chain.Chain; 4 | import com.ai.domain.data.message.AssistantMessage; 5 | import com.ai.domain.data.message.UserMessage; 6 | import com.ai.domain.memory.chat.ChatHistoryRecorder; 7 | import com.ai.domain.model.ChatModel; 8 | import lombok.Builder; 9 | import lombok.Data; 10 | 11 | /** 12 | * 纯文本聊天链 13 | **/ 14 | @Data 15 | @Builder 16 
| public class ConversationalChain implements Chain { 17 | 18 | private ChatModel chatModel; 19 | private ChatHistoryRecorder historyRecorder; 20 | 21 | @Override 22 | public String run(String s) { 23 | historyRecorder.add(UserMessage.message(s)); 24 | AssistantMessage data = chatModel.generate(historyRecorder.getCurrentMessages()).getData(); 25 | historyRecorder.add(data); 26 | return data.text(); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/chain/impl/ConversationalRetrievalChain.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.chain.impl; 2 | 3 | import com.ai.domain.chain.Chain; 4 | import com.ai.domain.data.embedding.EmbeddingMatch; 5 | import com.ai.domain.data.message.AssistantMessage; 6 | import com.ai.domain.data.message.UserMessage; 7 | import com.ai.domain.memory.chat.ChatHistoryRecorder; 8 | import com.ai.domain.memory.embedding.EmbeddingStoreRetriever; 9 | import com.ai.domain.model.ChatModel; 10 | import com.ai.domain.model.EmbeddingModel; 11 | import com.ai.domain.prompt.impl.SimplePromptTemplate; 12 | import lombok.Builder; 13 | import lombok.Data; 14 | 15 | import java.util.List; 16 | 17 | 18 | /** 19 | * 文档检索聊天链 20 | */ 21 | @Data 22 | @Builder 23 | public class ConversationalRetrievalChain implements Chain { 24 | 25 | private EmbeddingModel embeddingModel; 26 | private ChatModel chatModel; 27 | private ChatHistoryRecorder historyRecorder; 28 | @Builder.Default 29 | private SimplePromptTemplate promptTemplate = new SimplePromptTemplate("Answer the following question to the best of your ability: {{question}}\\n\\nBase your answer on the following information:\\n{{information}}", "default template"); 30 | private EmbeddingStoreRetriever retriever; 31 | 32 | @Override 33 | public String run(String s) { 34 | List relevant = retriever.findRelevant(embeddingModel.embed(s).getData()); 35 | 
StringBuilder stringBuilder = new StringBuilder(); 36 | for (EmbeddingMatch embeddingMatch : relevant) { 37 | stringBuilder.append(embeddingMatch.getEmbedding().getContent()); 38 | } 39 | promptTemplate.add("question", s); 40 | promptTemplate.add("information", stringBuilder.toString()); 41 | historyRecorder.add(UserMessage.message(promptTemplate.render())); 42 | AssistantMessage data = chatModel.generate(historyRecorder.getCurrentMessages()).getData(); 43 | historyRecorder.add(data); 44 | return data.text(); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/data/embedding/CosineSimilarity.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.data.embedding; 2 | 3 | 4 | import com.ai.common.util.Exceptions; 5 | import com.ai.common.util.ValidationUtils; 6 | 7 | /** 8 | * 相关性分析 9 | */ 10 | public class CosineSimilarity { 11 | 12 | public static double between(Embedding embeddingA, Embedding embeddingB) { 13 | ValidationUtils.ensureNotNull(embeddingA, "embeddingA"); 14 | ValidationUtils.ensureNotNull(embeddingB, "embeddingB"); 15 | double[] vectorA = embeddingA.getEmbedding(); 16 | double[] vectorB = embeddingB.getEmbedding(); 17 | if (vectorA.length != vectorB.length) { 18 | throw Exceptions.illegalArgument("Length of vector a (%s) must be equal to the length of vector b (%s)", new Object[]{vectorA.length, vectorB.length}); 19 | } else { 20 | double dotProduct = 0.0D; 21 | double normA = 0.0D; 22 | double normB = 0.0D; 23 | 24 | for (int i = 0; i < vectorA.length; ++i) { 25 | dotProduct += (vectorA[i] * vectorB[i]); 26 | normA += (vectorA[i] * vectorA[i]); 27 | normB += (vectorB[i] * vectorB[i]); 28 | } 29 | 30 | return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); 31 | } 32 | } 33 | 34 | public static double fromCosineSimilarity(double cosineSimilarity) { 35 | return (cosineSimilarity + 1.0D) / 2.0D; 36 | } 
/**
 * Holds the outcome of a completed embedding match: the matched embedding and its score.
 * Both fields are final, so an instance does not change after construction.
 */
@Data
@AllArgsConstructor
public class EmbeddingMatch {

    // Score of this match — presumably a relevance/similarity value; confirm against the retriever that creates it.
    private final Double score;

    // The embedding that was matched.
    private final Embedding embedding;

}
| public static Image from(String url, String b64Json) { 26 | return new Image(url, b64Json); 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/data/message/AssistantMessage.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.data.message; 2 | 3 | /** 4 | * @Description: AI消息 5 | **/ 6 | public class AssistantMessage extends ChatMessage { 7 | 8 | public AssistantMessage(String text, Integer order) { 9 | super(text, order); 10 | } 11 | 12 | public static AssistantMessage message(String message) { 13 | return new AssistantMessage(message, -1); 14 | } 15 | 16 | public static AssistantMessage message(String message, Integer order) { 17 | return new AssistantMessage(message, order); 18 | } 19 | 20 | @Override 21 | public MessageType type() { 22 | return MessageType.ASSISTANT; 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/data/message/ChatMessage.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.data.message; 2 | 3 | 4 | import lombok.Data; 5 | 6 | /** 7 | * 聊天消息 8 | **/ 9 | @Data 10 | public abstract class ChatMessage { 11 | 12 | protected final String text; 13 | 14 | protected final Integer order; 15 | 16 | ChatMessage(String text, Integer order) { 17 | this.text = text; 18 | this.order = order; 19 | } 20 | 21 | public abstract MessageType type(); 22 | 23 | public Integer order() { 24 | return order; 25 | } 26 | 27 | public String text() { 28 | return text; 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/data/message/MessageType.java: -------------------------------------------------------------------------------- 1 | package 
com.ai.domain.data.message; 2 | 3 | import lombok.AllArgsConstructor; 4 | import lombok.Getter; 5 | 6 | @Getter 7 | @AllArgsConstructor 8 | public enum MessageType { 9 | SYSTEM("system"), 10 | USER("user"), 11 | ASSISTANT("assistant"); 12 | 13 | private String messageType; 14 | } 15 | 16 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/data/message/SystemMessage.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.data.message; 2 | 3 | /** 4 | * @Description: 系统消息 5 | **/ 6 | public class SystemMessage extends ChatMessage { 7 | 8 | public SystemMessage(String text, Integer order) { 9 | super(text, order); 10 | } 11 | 12 | public static SystemMessage message(String message) { 13 | return new SystemMessage(message, -1); 14 | } 15 | 16 | public static SystemMessage message(String message, Integer order) { 17 | return new SystemMessage(message, order); 18 | } 19 | 20 | @Override 21 | public MessageType type() { 22 | return MessageType.SYSTEM; 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/data/message/UserMessage.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.data.message; 2 | 3 | /** 4 | * @Description: 用户消息 5 | **/ 6 | public class UserMessage extends ChatMessage { 7 | 8 | public UserMessage(String text, Integer order) { 9 | super(text, order); 10 | } 11 | 12 | public static UserMessage message(String message) { 13 | return new UserMessage(message, -1); 14 | } 15 | 16 | public static UserMessage message(String message, Integer order) { 17 | return new UserMessage(message, order); 18 | } 19 | 20 | @Override 21 | public MessageType type() { 22 | return MessageType.USER; 23 | } 24 | 25 | } 26 | 
-------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/data/moderation/Moderation.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.data.moderation; 2 | 3 | import lombok.Data; 4 | 5 | @Data 6 | public class Moderation { 7 | 8 | private Boolean flagged; 9 | private String flaggedText; 10 | private String type; 11 | 12 | public Moderation() { 13 | this.flagged = false; 14 | this.flaggedText = null; 15 | this.type = null; 16 | } 17 | 18 | public Moderation(String flaggedText, String type) { 19 | this.flagged = true; 20 | this.flaggedText = flaggedText; 21 | this.type = type; 22 | } 23 | 24 | public static Moderation flagged(String flaggedText, String type) { 25 | return new Moderation(flaggedText, type); 26 | } 27 | 28 | public static Moderation notFlagged() { 29 | return new Moderation(); 30 | } 31 | } 32 | 33 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/data/parameter/Parameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.data.parameter; 2 | 3 | /** 4 | * 参数标志,调用参数的设置 5 | **/ 6 | public interface Parameter { 7 | 8 | T getParameter(); 9 | 10 | void SetParameter(T parameter); 11 | 12 | } 13 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/AwsCredentials.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | import static com.ai.common.util.Utils.areNotNullOrBlank; 4 | import static com.ai.common.util.ValidationUtils.ensureNotBlank; 5 | 6 | 7 | public class AwsCredentials { 8 | 9 | private final String accessKeyId; 10 | private final String secretAccessKey; 11 | private String sessionToken; 12 | 13 | public 
AwsCredentials(String accessKeyId, String secretAccessKey, String sessionToken) { 14 | this.accessKeyId = ensureNotBlank(accessKeyId, "accessKeyId"); 15 | this.secretAccessKey = ensureNotBlank(secretAccessKey, "secretAccessKey"); 16 | this.sessionToken = sessionToken; 17 | } 18 | 19 | public AwsCredentials(String accessKeyId, String secretAccessKey) { 20 | this(accessKeyId, secretAccessKey, null); 21 | } 22 | 23 | 24 | public String accessKeyId() { 25 | return accessKeyId; 26 | } 27 | 28 | public String secretAccessKey() { 29 | return secretAccessKey; 30 | } 31 | 32 | public String sessionToken() { 33 | return sessionToken; 34 | } 35 | 36 | public boolean hasAccessKeyIdAndSecretKey() { 37 | return areNotNullOrBlank(accessKeyId, secretAccessKey); 38 | } 39 | 40 | public boolean hasAllCredentials() { 41 | return areNotNullOrBlank(accessKeyId, secretAccessKey, sessionToken); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/Document.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | import lombok.Data; 4 | 5 | import java.util.Objects; 6 | 7 | 8 | @Data 9 | public class Document { 10 | 11 | public static final String DOCUMENT_TYPE = "document_type"; 12 | public static final String FILE_NAME = "file_name"; 13 | public static final String ABSOLUTE_DIRECTORY_PATH = "absolute_directory_path"; 14 | public static final String URL = "url"; 15 | 16 | private final String text; 17 | private final Metadata metadata; 18 | 19 | public Document(String text, Metadata metadata) { 20 | this.text = text; 21 | this.metadata = metadata; 22 | } 23 | 24 | public static Document from(String text) { 25 | return new Document(text, new Metadata()); 26 | } 27 | 28 | public static Document from(String text, Metadata metadata) { 29 | return new Document(text, metadata); 30 | } 31 | 32 | public static Document 
document(String text) { 33 | return from(text); 34 | } 35 | 36 | public static Document document(String text, Metadata metadata) { 37 | return from(text, metadata); 38 | } 39 | 40 | public String text() { 41 | return text; 42 | } 43 | 44 | public Metadata metadata() { 45 | return metadata; 46 | } 47 | 48 | public String metadata(String key) { 49 | return metadata.get(key); 50 | } 51 | 52 | public TextSegment toTextSegment() { 53 | return TextSegment.from(text, metadata.copy().add("index", 0)); 54 | } 55 | 56 | @Override 57 | public boolean equals(Object o) { 58 | if (this == o) return true; 59 | if (o == null || getClass() != o.getClass()) return false; 60 | Document that = (Document) o; 61 | return Objects.equals(this.text, that.text) 62 | && Objects.equals(this.metadata, that.metadata); 63 | } 64 | 65 | @Override 66 | public int hashCode() { 67 | return Objects.hash(text, metadata); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/DocumentLoaderUtils.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | import com.ai.domain.document.parser.DocumentParser; 4 | import com.ai.domain.document.parser.impl.MsOfficeDocumentParser; 5 | import com.ai.domain.document.parser.impl.PdfDocumentParser; 6 | import com.ai.domain.document.parser.impl.TextDocumentParser; 7 | import com.ai.domain.document.source.DocumentSource; 8 | 9 | import java.io.InputStream; 10 | 11 | class DocumentLoaderUtils { 12 | 13 | static Document load(DocumentSource source, DocumentParser parser) { 14 | try (InputStream inputStream = source.inputStream()) { 15 | Document document = parser.parse(inputStream); 16 | source.metadata().asMap().forEach((key, value) -> document.metadata().add(key, value)); 17 | return document; 18 | } catch (Exception e) { 19 | throw new RuntimeException("Failed to load document", e); 20 | } 21 | 
} 22 | 23 | static DocumentParser parserFor(DocumentType type) { 24 | switch (type) { 25 | case TXT: 26 | case HTML: 27 | case UNKNOWN: 28 | return new TextDocumentParser(type); 29 | case PDF: 30 | return new PdfDocumentParser(); 31 | case DOC: 32 | case XLS: 33 | case PPT: 34 | return new MsOfficeDocumentParser(type); 35 | default: 36 | throw new RuntimeException(String.format("Cannot find parser for document type '%s'", type)); 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/DocumentType.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | import static java.util.Arrays.asList; 4 | 5 | public enum DocumentType { 6 | 7 | TXT(".txt"), 8 | PDF(".pdf"), 9 | HTML(".html", ".htm", ".xhtml"), 10 | DOC(".doc", ".docx"), 11 | XLS(".xls", ".xlsx"), 12 | PPT(".ppt", ".pptx"), 13 | UNKNOWN; 14 | 15 | private final Iterable supportedExtensions; 16 | 17 | DocumentType(String... 
supportedExtensions) { 18 | this.supportedExtensions = asList(supportedExtensions); 19 | } 20 | 21 | public static DocumentType of(String fileName) { 22 | 23 | for (DocumentType documentType : values()) { 24 | for (String supportedExtension : documentType.supportedExtensions) { 25 | if (fileName.toLowerCase().endsWith(supportedExtension)) { 26 | return documentType; 27 | } 28 | } 29 | } 30 | 31 | return UNKNOWN; 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/FileSystemDocumentLoader.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | import lombok.extern.slf4j.Slf4j; 4 | 5 | import java.io.IOException; 6 | import java.nio.file.Files; 7 | import java.nio.file.Path; 8 | import java.nio.file.Paths; 9 | import java.util.ArrayList; 10 | import java.util.List; 11 | import java.util.stream.Stream; 12 | 13 | import static com.ai.common.util.Exceptions.illegalArgument; 14 | import static com.ai.domain.document.DocumentLoaderUtils.parserFor; 15 | import static com.ai.domain.document.source.impl.FileSystemSource.from; 16 | import static java.nio.file.Files.isDirectory; 17 | import static java.nio.file.Files.isRegularFile; 18 | 19 | @Slf4j 20 | public class FileSystemDocumentLoader { 21 | 22 | public static Document loadDocument(Path filePath) { 23 | return loadDocument(filePath, DocumentType.of(filePath.toString())); 24 | } 25 | 26 | public static Document loadDocument(String filePath) { 27 | return loadDocument(Paths.get(filePath)); 28 | } 29 | 30 | public static Document loadDocument(Path filePath, DocumentType documentType) { 31 | if (!isRegularFile(filePath)) { 32 | throw illegalArgument("%s is not a file", filePath); 33 | } 34 | 35 | return DocumentLoaderUtils.load(from(filePath), parserFor(documentType)); 36 | } 37 | 38 | public static Document loadDocument(String filePath, DocumentType 
documentType) { 39 | return loadDocument(Paths.get(filePath), documentType); 40 | } 41 | 42 | public static List loadDocuments(Path directoryPath) { 43 | if (!isDirectory(directoryPath)) { 44 | throw illegalArgument("%s is not a directory", directoryPath); 45 | } 46 | 47 | List documents = new ArrayList<>(); 48 | 49 | try (Stream paths = Files.list(directoryPath)) { 50 | paths.filter(Files::isRegularFile) 51 | .forEach(filePath -> { 52 | try { 53 | Document document = loadDocument(filePath); 54 | documents.add(document); 55 | } catch (Exception e) { 56 | log.warn("Failed to load document from " + filePath, e); 57 | } 58 | }); 59 | } catch (IOException e) { 60 | throw new RuntimeException(e); 61 | } 62 | 63 | return documents; 64 | } 65 | 66 | public static List loadDocuments(String directoryPath) { 67 | return loadDocuments(Paths.get(directoryPath)); 68 | } 69 | } 70 |
--------------------------------------------------------------------------------
/smartFuse-domain/src/main/java/com/ai/domain/document/Metadata.java:
--------------------------------------------------------------------------------
package com.ai.domain.document;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Mutable bag of string key/value pairs. Values are stored as the
 * {@code toString()} form of whatever object is added.
 */
public class Metadata {

    /** Backing store; always a private {@link HashMap} owned by this instance. */
    private final Map<String, String> metadata;

    /** Creates an empty metadata set. */
    public Metadata() {
        this.metadata = new HashMap<>();
    }

    /**
     * Creates a metadata set pre-populated with the given entries.
     *
     * @param metadata initial entries; copied, so later changes to the argument
     *                 are not reflected here and immutable maps are accepted
     * @throws NullPointerException if {@code metadata} is {@code null}
     */
    public Metadata(Map<String, String> metadata) {
        // Defensive copy: keeps add()/remove() working when the caller hands in an
        // unmodifiable map, and stops this instance from aliasing caller state.
        this.metadata = new HashMap<>(Objects.requireNonNull(metadata, "metadata"));
    }

    /** Creates a metadata set holding the single entry {@code key -> value.toString()}. */
    public static Metadata from(String key, Object value) {
        return new Metadata().add(key, value);
    }

    /** Creates a metadata set from a copy of the given map. */
    public static Metadata from(Map<String, String> metadata) {
        return new Metadata(metadata);
    }

    /** Alias of {@link #from(String, Object)}. */
    public static Metadata metadata(String key, Object value) {
        return from(key, value);
    }

    /**
     * @param key entry key
     * @return the stored value, or {@code null} when the key is absent
     */
    public String get(String key) {
        return metadata.get(key);
    }

    /**
     * Stores {@code value.toString()} under {@code key}, replacing any previous value.
     *
     * @return this instance, for chaining
     * @throws NullPointerException if {@code value} is {@code null}
     */
    public Metadata add(String key, Object value) {
        Objects.requireNonNull(value, "value");
        metadata.put(key, value.toString());
        return this;
    }

    /**
     * Removes the entry for {@code key} if present.
     *
     * @return this instance, for chaining
     */
    public Metadata remove(String key) {
        metadata.remove(key);
        return this;
    }

    /** @return an independent metadata set with the same entries */
    public Metadata copy() {
        return new Metadata(metadata);
    }

    /** @return a fresh mutable map holding a snapshot of the entries */
    public Map<String, String> asMap() {
        return new HashMap<>(metadata);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Metadata other = (Metadata) o;
        return Objects.equals(metadata, other.metadata);
    }

    @Override
    public int hashCode() {
        return Objects.hash(metadata);
    }

    @Override
    public String toString() {
        return "Metadata {" +
                " metadata = " + metadata +
                " }";
    }
}
--------------------------------------------------------------------------------
/smartFuse-domain/src/main/java/com/ai/domain/document/S3DirectoryLoader.java:
--------------------------------------------------------------------------------
1 | package com.ai.domain.document; 2 | 3 | import com.ai.domain.document.source.impl.S3Source; 4 | import org.slf4j.Logger; 5 | import org.slf4j.LoggerFactory; 6 | import software.amazon.awssdk.core.ResponseInputStream; 7 | import software.amazon.awssdk.services.s3.S3Client; 8 | import software.amazon.awssdk.services.s3.model.*; 9 | 10 | import java.util.ArrayList; 11 | import java.util.List; 12 | import java.util.stream.Collectors; 13 | 14 | import static com.ai.domain.document.DocumentLoaderUtils.parserFor; 15 | 16 | public class S3DirectoryLoader extends AbstractS3Loader> { 17 | 18 | private static final Logger log = LoggerFactory.getLogger(S3DirectoryLoader.class); 19 | 20 | private final String prefix; 21 | 22 | private S3DirectoryLoader(Builder builder) { 23 | super(builder); 24 | this.prefix = builder.prefix; 25 | } 26 | 27 | public static Builder builder() { 28 | return new Builder(); 29 | } 30 | 31 | @Override 32 | protected List load(S3Client s3Client) { 33 | List documents = new 
ArrayList<>(); 34 | 35 | ListObjectsV2Request listObjectsV2Request = ListObjectsV2Request.builder() 36 | .bucket(bucket) 37 | .prefix(prefix) 38 | .build(); 39 | 40 | ListObjectsV2Response listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request); 41 | List filteredS3Objects = listObjectsV2Response.contents().stream() 42 | .filter(s3Object -> !s3Object.key().endsWith("/") && s3Object.size() > 0) 43 | .collect(Collectors.toList()); 44 | 45 | for (S3Object s3Object : filteredS3Objects) { 46 | String key = s3Object.key(); 47 | 48 | GetObjectRequest getObjectRequest = GetObjectRequest.builder() 49 | .bucket(bucket) 50 | .key(key) 51 | .build(); 52 | 53 | ResponseInputStream inputStream = s3Client.getObject(getObjectRequest); 54 | 55 | try { 56 | documents.add(DocumentLoaderUtils.load(new S3Source(bucket, key, inputStream), parserFor(DocumentType.of(key)))); 57 | } catch (Exception e) { 58 | log.warn("Failed to load document from S3", e); 59 | } 60 | } 61 | 62 | return documents; 63 | } 64 | 65 | public static final class Builder extends AbstractS3Loader.Builder { 66 | private String prefix = ""; 67 | 68 | /** 69 | * Set the prefix. 70 | * 71 | * @param prefix Prefix. 
72 | */ 73 | public Builder prefix(String prefix) { 74 | this.prefix = prefix; 75 | return this; 76 | } 77 | 78 | @Override 79 | public S3DirectoryLoader build() { 80 | return new S3DirectoryLoader(this); 81 | } 82 | 83 | @Override 84 | protected Builder self() { 85 | return this; 86 | } 87 | } 88 | } 89 | 90 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/S3FileLoader.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | import com.ai.domain.document.source.impl.S3Source; 4 | import software.amazon.awssdk.core.ResponseInputStream; 5 | import software.amazon.awssdk.services.s3.S3Client; 6 | import software.amazon.awssdk.services.s3.model.GetObjectRequest; 7 | import software.amazon.awssdk.services.s3.model.GetObjectResponse; 8 | import software.amazon.awssdk.services.s3.model.S3Exception; 9 | 10 | import static com.ai.common.util.ValidationUtils.ensureNotBlank; 11 | import static com.ai.domain.document.DocumentLoaderUtils.parserFor; 12 | 13 | 14 | public class S3FileLoader extends AbstractS3Loader { 15 | 16 | private final String key; 17 | 18 | private S3FileLoader(Builder builder) { 19 | super(builder); 20 | this.key = ensureNotBlank(builder.key, "key"); 21 | } 22 | 23 | public static Builder builder() { 24 | return new Builder(); 25 | } 26 | 27 | @Override 28 | protected Document load(S3Client s3Client) { 29 | try { 30 | GetObjectRequest objectRequest = GetObjectRequest.builder().bucket(bucket).key(key).build(); 31 | ResponseInputStream inputStream = s3Client.getObject(objectRequest); 32 | return DocumentLoaderUtils.load(new S3Source(bucket, key, inputStream), parserFor(DocumentType.of(key))); 33 | } catch (S3Exception e) { 34 | throw new RuntimeException("Failed to load document from s3", e); 35 | } 36 | } 37 | 38 | public static final class Builder extends AbstractS3Loader.Builder { 39 | 40 | 
private String key; 41 | 42 | public Builder key(String key) { 43 | this.key = key; 44 | return this; 45 | } 46 | 47 | @Override 48 | public S3FileLoader build() { 49 | return new S3FileLoader(this); 50 | } 51 | 52 | @Override 53 | protected Builder self() { 54 | return this; 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/TextSegment.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | import java.util.Objects; 4 | 5 | public class TextSegment { 6 | 7 | private final String text; 8 | private final Metadata metadata; 9 | 10 | public TextSegment(String text, Metadata metadata) { 11 | this.text = text; 12 | this.metadata = metadata; 13 | } 14 | 15 | public static TextSegment from(String text) { 16 | return new TextSegment(text, new Metadata()); 17 | } 18 | 19 | public static TextSegment from(String text, Metadata metadata) { 20 | return new TextSegment(text, metadata); 21 | } 22 | 23 | public static TextSegment textSegment(String text) { 24 | return from(text); 25 | } 26 | 27 | public static TextSegment textSegment(String text, Metadata metadata) { 28 | return from(text, metadata); 29 | } 30 | 31 | public String text() { 32 | return text; 33 | } 34 | 35 | public Metadata metadata() { 36 | return metadata; 37 | } 38 | 39 | public String metadata(String key) { 40 | return metadata.get(key); 41 | } 42 | 43 | @Override 44 | public boolean equals(Object o) { 45 | if (this == o) return true; 46 | if (o == null || getClass() != o.getClass()) return false; 47 | TextSegment that = (TextSegment) o; 48 | return Objects.equals(this.text, that.text) 49 | && Objects.equals(this.metadata, that.metadata); 50 | } 51 | 52 | @Override 53 | public int hashCode() { 54 | return Objects.hash(text, metadata); 55 | } 56 | 57 | @Override 58 | public String toString() { 59 | return "TextSegment{" + 60 | 
"text='" + text + '\'' + 61 | ", metadata=" + metadata + 62 | '}'; 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/UrlDocumentLoader.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | 4 | import com.ai.domain.document.source.impl.UrlSource; 5 | 6 | import java.net.MalformedURLException; 7 | import java.net.URL; 8 | 9 | import static com.ai.domain.document.DocumentLoaderUtils.parserFor; 10 | 11 | public class UrlDocumentLoader { 12 | 13 | public static Document load(URL url) { 14 | return load(url, DocumentType.of(url.toString())); 15 | } 16 | 17 | public static Document load(String url) { 18 | try { 19 | return load(new URL(url)); 20 | } catch (MalformedURLException e) { 21 | throw new RuntimeException(e); 22 | } 23 | } 24 | 25 | public static Document load(URL url, DocumentType documentType) { 26 | return DocumentLoaderUtils.load(UrlSource.from(url), parserFor(documentType)); 27 | } 28 | 29 | public static Document load(String url, DocumentType documentType) { 30 | try { 31 | return load(new URL(url), documentType); 32 | } catch (MalformedURLException e) { 33 | throw new RuntimeException(e); 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/parser/DocumentParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document.parser; 2 | 3 | import com.ai.domain.document.Document; 4 | 5 | import java.io.InputStream; 6 | 7 | /** 8 | * 解析器接口 9 | */ 10 | public interface DocumentParser { 11 | 12 | Document parse(InputStream inputStream); 13 | 14 | } 15 | -------------------------------------------------------------------------------- 
/smartFuse-domain/src/main/java/com/ai/domain/document/parser/impl/MsOfficeDocumentParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document.parser.impl; 2 | 3 | import com.ai.domain.document.Document; 4 | import com.ai.domain.document.DocumentType; 5 | import com.ai.domain.document.Metadata; 6 | import com.ai.domain.document.parser.DocumentParser; 7 | import org.apache.poi.extractor.ExtractorFactory; 8 | import org.apache.poi.extractor.POITextExtractor; 9 | 10 | import java.io.IOException; 11 | import java.io.InputStream; 12 | 13 | import static com.ai.domain.document.Document.DOCUMENT_TYPE; 14 | 15 | /** 16 | * Office 套件解析器,word\ppt\excel 17 | */ 18 | public class MsOfficeDocumentParser implements DocumentParser { 19 | 20 | private final DocumentType documentType; 21 | 22 | public MsOfficeDocumentParser(DocumentType documentType) { 23 | this.documentType = documentType; 24 | } 25 | 26 | @Override 27 | public Document parse(InputStream inputStream) { 28 | try (POITextExtractor extractor = ExtractorFactory.createExtractor(inputStream)) { 29 | return new Document(extractor.getText(), Metadata.from(DOCUMENT_TYPE, documentType)); 30 | } catch (IOException e) { 31 | throw new RuntimeException(e); 32 | } 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/parser/impl/PdfDocumentParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document.parser.impl; 2 | 3 | import com.ai.domain.document.Document; 4 | import com.ai.domain.document.Metadata; 5 | import com.ai.domain.document.parser.DocumentParser; 6 | import org.apache.pdfbox.pdmodel.PDDocument; 7 | import org.apache.pdfbox.text.PDFTextStripper; 8 | 9 | import java.io.IOException; 10 | import java.io.InputStream; 11 | 12 | import static com.ai.domain.document.Document.DOCUMENT_TYPE; 
13 | import static com.ai.domain.document.DocumentType.PDF; 14 |
/**
 * PDF content parser: extracts the plain text of a PDF read from a stream.
 */
public class PdfDocumentParser implements DocumentParser {

    /**
     * Reads a PDF from {@code inputStream} and returns its text as a
     * {@link Document} whose metadata records the document type as PDF.
     *
     * @param inputStream stream positioned at the start of the PDF; not closed here
     * @return document holding the extracted text
     * @throws RuntimeException wrapping the {@link IOException} raised while
     *                          loading, extracting or closing the PDF
     */
    @Override
    public Document parse(InputStream inputStream) {
        // try-with-resources releases the PDDocument even when extraction fails;
        // previously close() was skipped whenever getText() threw.
        try (PDDocument pdfDocument = PDDocument.load(inputStream)) {
            PDFTextStripper stripper = new PDFTextStripper();
            String content = stripper.getText(pdfDocument);
            return Document.from(content, Metadata.from(DOCUMENT_TYPE, PDF));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
--------------------------------------------------------------------------------
/smartFuse-domain/src/main/java/com/ai/domain/document/parser/impl/TextDocumentParser.java:
--------------------------------------------------------------------------------
1 | package com.ai.domain.document.parser.impl; 2 | 3 | import com.ai.domain.document.Document; 4 | import com.ai.domain.document.DocumentType; 5 | import com.ai.domain.document.Metadata; 6 | import com.ai.domain.document.parser.DocumentParser; 7 | 8 | import java.io.ByteArrayOutputStream; 9 | import java.io.InputStream; 10 | import java.nio.charset.Charset; 11 | 12 | import static com.ai.domain.document.Document.DOCUMENT_TYPE; 13 | import static java.nio.charset.StandardCharsets.UTF_8; 14 | 15 | /** 16 | * 文本文件解析器 17 | */ 18 | public class TextDocumentParser implements DocumentParser { 19 | 20 | private final DocumentType documentType; 21 | private final Charset charset; 22 | 23 | public TextDocumentParser(DocumentType documentType) { 24 | this(documentType, UTF_8); 25 | } 26 | 27 | public TextDocumentParser(DocumentType documentType, Charset charset) { 28 | this.documentType = documentType; 29 | this.charset = charset; 30 | } 31 | 32 | @Override 33 | public Document parse(InputStream inputStream) { 34 | try { 35 | ByteArrayOutputStream buffer = new ByteArrayOutputStream(); 36 | int nRead; 37 | byte[] data = new 
byte[1024]; 38 | while ((nRead = inputStream.read(data, 0, data.length)) != -1) { 39 | buffer.write(data, 0, nRead); 40 | } 41 | buffer.flush(); 42 | 43 | String text = new String(buffer.toByteArray(), charset); 44 | 45 | return Document.from(text, Metadata.from(DOCUMENT_TYPE, documentType.toString())); 46 | } catch (Exception e) { 47 | throw new RuntimeException(e); 48 | } 49 | } 50 | } 51 |
--------------------------------------------------------------------------------
/smartFuse-domain/src/main/java/com/ai/domain/document/source/DocumentSource.java:
--------------------------------------------------------------------------------
package com.ai.domain.document.source;

import com.ai.domain.document.Metadata;

import java.io.IOException;
import java.io.InputStream;

/**
 * A place a document's raw bytes can be read from, together with metadata
 * describing that place.
 */
public interface DocumentSource {

    /**
     * @return a stream over the document's raw content
     * @throws IOException if the stream cannot be opened
     */
    InputStream inputStream() throws IOException;

    /** @return metadata describing this source */
    Metadata metadata();

}
--------------------------------------------------------------------------------
/smartFuse-domain/src/main/java/com/ai/domain/document/source/impl/FileSystemSource.java:
--------------------------------------------------------------------------------
package com.ai.domain.document.source.impl;

import com.ai.domain.document.Metadata;
import com.ai.domain.document.source.DocumentSource;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

import static com.ai.domain.document.Document.ABSOLUTE_DIRECTORY_PATH;
import static com.ai.domain.document.Document.FILE_NAME;

/**
 * {@link DocumentSource} backed by a file on the local file system.
 */
public class FileSystemSource implements DocumentSource {

    /** Location of the file this source reads. */
    public final Path path;

    public FileSystemSource(Path path) {
        this.path = path;
    }

    /** Creates a source for the given path. */
    public static FileSystemSource from(Path filePath) {
        return new FileSystemSource(filePath);
    }

    /** Creates a source for the given path string. */
    public static FileSystemSource from(String filePath) {
        return new FileSystemSource(Paths.get(filePath));
    }

    /** Creates a source for the given file URI. */
    public static FileSystemSource from(URI fileUri) {
        return new FileSystemSource(Paths.get(fileUri));
    }

    /** Creates a source for the given {@link File}. */
    public static FileSystemSource from(File file) {
        return new FileSystemSource(file.toPath());
    }

    /**
     * Opens a new stream on the file; the caller is responsible for closing it.
     *
     * @throws IOException if the file cannot be opened
     */
    @Override
    public InputStream inputStream() throws IOException {
        return Files.newInputStream(path);
    }

    /**
     * Describes the file by its name and the absolute path of its directory.
     * A component the path does not have (for example the parent of a
     * file-system root) is left out instead of causing a failure.
     *
     * @return fresh metadata with the available file-name and directory entries
     */
    @Override
    public Metadata metadata() {
        Metadata metadata = new Metadata();

        Path fileName = path.getFileName();
        if (fileName != null) {
            metadata.add(FILE_NAME, fileName);
        }

        // Resolve against the working directory first: getParent() on a bare
        // relative name such as "notes.txt" is null and used to trigger an NPE.
        Path directory = path.toAbsolutePath().getParent();
        if (directory != null) {
            metadata.add(ABSOLUTE_DIRECTORY_PATH, directory);
        }

        return metadata;
    }
}
--------------------------------------------------------------------------------
/smartFuse-domain/src/main/java/com/ai/domain/document/source/impl/S3Source.java:
--------------------------------------------------------------------------------
package com.ai.domain.document.source.impl;

import com.ai.domain.document.Metadata;
import com.ai.domain.document.source.DocumentSource;

import java.io.IOException;
import java.io.InputStream;

/**
 * {@link DocumentSource} wrapping an already opened stream of an S3 object.
 */
public class S3Source implements DocumentSource {

    private static final String SOURCE = "source";

    private final String bucket;
    private final String key;
    private final InputStream inputStream;

    /**
     * @param bucket      bucket name, used only to build the source metadata
     * @param key         object key, used only to build the source metadata
     * @param inputStream stream over the object's content
     */
    public S3Source(String bucket, String key, InputStream inputStream) {
        this.bucket = bucket;
        this.key = key;
        this.inputStream = inputStream;
    }

    /** @return the stream supplied at construction (the same instance on every call) */
    @Override
    public InputStream inputStream() throws IOException {
        return inputStream;
    }

    /** @return metadata with a single {@code source} entry of the form {@code s3://bucket/key} */
    @Override
    public Metadata metadata() {
        String location = String.format("s3://%s/%s", bucket, key);
        return new Metadata().add(SOURCE, location);
    }
}
--------------------------------------------------------------------------------
/smartFuse-domain/src/main/java/com/ai/domain/document/source/impl/UrlSource.java:
--------------------------------------------------------------------------------
package com.ai.domain.document.source.impl;

import com.ai.domain.document.Document;
import com.ai.domain.document.Metadata;
import com.ai.domain.document.source.DocumentSource;

import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.net.URLConnection;

/**
 * {@link DocumentSource} that reads a document from a URL.
 */
public class UrlSource implements DocumentSource {

    private final URL url;

    public UrlSource(URL url) {
        this.url = url;
    }

    /**
     * Creates a source from a URL string.
     *
     * @throws RuntimeException wrapping {@link MalformedURLException} when the string is not a valid URL
     */
    public static UrlSource from(String url) {
        URL parsed;
        try {
            parsed = new URL(url);
        } catch (MalformedURLException e) {
            throw new RuntimeException(e);
        }
        return new UrlSource(parsed);
    }

    /** Creates a source from a URL. */
    public static UrlSource from(URL url) {
        return new UrlSource(url);
    }

    /**
     * Creates a source from a URI.
     *
     * @throws RuntimeException wrapping {@link MalformedURLException} when the URI cannot be converted to a URL
     */
    public static UrlSource from(URI uri) {
        URL converted;
        try {
            converted = uri.toURL();
        } catch (MalformedURLException e) {
            throw new RuntimeException(e);
        }
        return new UrlSource(converted);
    }

    /**
     * Opens a connection to the URL and returns its content stream.
     *
     * @throws IOException if the connection or the stream cannot be opened
     */
    @Override
    public InputStream inputStream() throws IOException {
        URLConnection urlConnection = url.openConnection();
        return urlConnection.getInputStream();
    }

    /** @return metadata with a single entry recording the URL under {@code Document.URL} */
    @Override
    public Metadata metadata() {
        return Metadata.from(Document.URL, url);
    }
}
--------------------------------------------------------------------------------
/smartFuse-domain/src/main/java/com/ai/domain/document/splitter/DocumentSplitter.java:
--------------------------------------------------------------------------------
package com.ai.domain.document.splitter;

import com.ai.domain.document.Document;
import com.ai.domain.document.TextSegment;

import java.util.List;

import static java.util.stream.Collectors.toList;

/**
 * Breaks documents into text segments.
 */
public interface DocumentSplitter {

    /**
     * Splits one document into segments.
     *
     * @param document document to split
     * @return the segments produced for {@code document}
     */
    List<TextSegment> split(Document document);

    /**
     * Splits every document with {@link #split(Document)} and concatenates the
     * results in input order.
     *
     * @param documents documents to split
     * @return all segments of all documents, as one list
     */
    default List<TextSegment> splitAll(List<Document> documents) {
        return documents.stream()
                .map(this::split)
                .flatMap(List::stream)
                .collect(toList());
    }
}

--------------------------------------------------------------------------------
/smartFuse-domain/src/main/java/com/ai/domain/document/splitter/impl/DocumentByCharacterSplitter.java:
--------------------------------------------------------------------------------
package com.ai.domain.document.splitter.impl;


import com.ai.domain.document.splitter.DocumentSplitter;
import com.ai.domain.document.tokenizer.Tokenizer;

/**
 * Hierarchical splitter whose unit of splitting is the single character.
 * It declares no default sub-splitter.
 */
public class DocumentByCharacterSplitter extends HierarchicalDocumentSplitter {

    /** Sizes in characters; no tokenizer, no explicit sub-splitter. */
    public DocumentByCharacterSplitter(int maxSegmentSizeInChars,
                                       int maxOverlapSizeInChars) {
        this(maxSegmentSizeInChars, maxOverlapSizeInChars, null, null);
    }

    /** Sizes in characters; no tokenizer, explicit sub-splitter. */
    public DocumentByCharacterSplitter(int maxSegmentSizeInChars,
                                       int maxOverlapSizeInChars,
                                       DocumentSplitter subSplitter) {
        this(maxSegmentSizeInChars, maxOverlapSizeInChars, null, subSplitter);
    }

    /** Sizes in tokens as counted by {@code tokenizer}; no explicit sub-splitter. */
    public DocumentByCharacterSplitter(int maxSegmentSizeInTokens,
                                       int maxOverlapSizeInTokens,
                                       Tokenizer tokenizer) {
        this(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, null);
    }

    /** Sizes in tokens as counted by {@code tokenizer}; explicit sub-splitter. */
    public DocumentByCharacterSplitter(int maxSegmentSizeInTokens,
                                       int maxOverlapSizeInTokens,
                                       Tokenizer tokenizer,
                                       DocumentSplitter subSplitter) {
        super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, subSplitter);
    }

    /** Splits on the empty pattern, i.e. between every pair of characters. */
    @Override
    public String[] split(String text) {
        return text.split("");
    }

    /** Parts are re-joined with no delimiter. */
    @Override
    public String joinDelimiter() {
        return "";
    }

    /** @return {@code null}: there is no finer-grained default splitter */
    @Override
    protected DocumentSplitter defaultSubSplitter() {
        return null;
    }
}
-------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/splitter/impl/DocumentByLineSplitter.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document.splitter.impl; 2 | 3 | import com.ai.domain.document.splitter.DocumentSplitter; 4 | import com.ai.domain.document.tokenizer.Tokenizer; 5 | 6 | 7 | public class DocumentByLineSplitter extends HierarchicalDocumentSplitter { 8 | 9 | public DocumentByLineSplitter(int maxSegmentSizeInChars, 10 | int maxOverlapSizeInChars) { 11 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, null); 12 | } 13 | 14 | public DocumentByLineSplitter(int maxSegmentSizeInChars, 15 | int maxOverlapSizeInChars, 16 | DocumentSplitter subSplitter) { 17 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, subSplitter); 18 | } 19 | 20 | public DocumentByLineSplitter(int maxSegmentSizeInTokens, 21 | int maxOverlapSizeInTokens, 22 | Tokenizer tokenizer) { 23 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, null); 24 | } 25 | 26 | public DocumentByLineSplitter(int maxSegmentSizeInTokens, 27 | int maxOverlapSizeInTokens, 28 | Tokenizer tokenizer, 29 | DocumentSplitter subSplitter) { 30 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, subSplitter); 31 | } 32 | 33 | @Override 34 | public String[] split(String text) { 35 | return text.split("\\s*\\R\\s*"); // additional whitespaces are ignored 36 | } 37 | 38 | @Override 39 | public String joinDelimiter() { 40 | return "\n"; 41 | } 42 | 43 | @Override 44 | protected DocumentSplitter defaultSubSplitter() { 45 | return new DocumentBySentenceSplitter(maxSegmentSize, maxOverlapSize, tokenizer); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/splitter/impl/DocumentByParagraphSplitter.java: 
-------------------------------------------------------------------------------- 1 | package com.ai.domain.document.splitter.impl; 2 | 3 | import com.ai.domain.document.splitter.DocumentSplitter; 4 | import com.ai.domain.document.tokenizer.Tokenizer; 5 | 6 | public class DocumentByParagraphSplitter extends HierarchicalDocumentSplitter { 7 | 8 | public DocumentByParagraphSplitter(int maxSegmentSizeInChars, 9 | int maxOverlapSizeInChars) { 10 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, null); 11 | } 12 | 13 | public DocumentByParagraphSplitter(int maxSegmentSizeInChars, 14 | int maxOverlapSizeInChars, 15 | DocumentSplitter subSplitter) { 16 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, subSplitter); 17 | } 18 | 19 | public DocumentByParagraphSplitter(int maxSegmentSizeInTokens, 20 | int maxOverlapSizeInTokens, 21 | Tokenizer tokenizer) { 22 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, null); 23 | } 24 | 25 | public DocumentByParagraphSplitter(int maxSegmentSizeInTokens, 26 | int maxOverlapSizeInTokens, 27 | Tokenizer tokenizer, 28 | DocumentSplitter subSplitter) { 29 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, subSplitter); 30 | } 31 | 32 | @Override 33 | public String[] split(String text) { 34 | return text.split("\\s*\\R\\s*\\R\\s*"); 35 | } 36 | 37 | @Override 38 | public String joinDelimiter() { 39 | return "\n\n"; 40 | } 41 | 42 | @Override 43 | protected DocumentSplitter defaultSubSplitter() { 44 | return new DocumentBySentenceSplitter(maxSegmentSize, maxOverlapSize, tokenizer); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/splitter/impl/DocumentByRegexSplitter.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document.splitter.impl; 2 | 3 | import com.ai.domain.document.splitter.DocumentSplitter; 4 | import 
com.ai.domain.document.tokenizer.Tokenizer; 5 | 6 | 7 | public class DocumentByRegexSplitter extends HierarchicalDocumentSplitter { 8 | 9 | private final String regex; 10 | private final String joinDelimiter; 11 | 12 | public DocumentByRegexSplitter(String regex, 13 | String joinDelimiter, 14 | int maxSegmentSizeInChars, 15 | int maxOverlapSizeInChars) { 16 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, null); 17 | this.regex = regex; 18 | this.joinDelimiter = joinDelimiter; 19 | } 20 | 21 | public DocumentByRegexSplitter(String regex, 22 | String joinDelimiter, 23 | int maxSegmentSizeInChars, 24 | int maxOverlapSizeInChars, 25 | DocumentSplitter subSplitter) { 26 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, subSplitter); 27 | this.regex = regex; 28 | this.joinDelimiter = joinDelimiter; 29 | } 30 | 31 | public DocumentByRegexSplitter(String regex, 32 | String joinDelimiter, 33 | int maxSegmentSizeInTokens, 34 | int maxOverlapSizeInTokens, 35 | Tokenizer tokenizer) { 36 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, null); 37 | this.regex = regex; 38 | this.joinDelimiter = joinDelimiter; 39 | } 40 | 41 | public DocumentByRegexSplitter(String regex, 42 | String joinDelimiter, 43 | int maxSegmentSizeInTokens, 44 | int maxOverlapSizeInTokens, 45 | Tokenizer tokenizer, 46 | DocumentSplitter subSplitter) { 47 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, subSplitter); 48 | this.regex = regex; 49 | this.joinDelimiter = joinDelimiter; 50 | } 51 | 52 | @Override 53 | public String[] split(String text) { 54 | return text.split(regex); 55 | } 56 | 57 | @Override 58 | public String joinDelimiter() { 59 | return joinDelimiter; 60 | } 61 | 62 | @Override 63 | protected DocumentSplitter defaultSubSplitter() { 64 | return null; 65 | } 66 | } 67 | -------------------------------------------------------------------------------- 
/smartFuse-domain/src/main/java/com/ai/domain/document/splitter/impl/DocumentBySentenceSplitter.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document.splitter.impl; 2 | 3 | import com.ai.domain.document.splitter.DocumentSplitter; 4 | import com.ai.domain.document.tokenizer.Tokenizer; 5 | 6 | 7 | public class DocumentBySentenceSplitter extends HierarchicalDocumentSplitter { 8 | 9 | public DocumentBySentenceSplitter(int maxSegmentSizeInChars, 10 | int maxOverlapSizeInChars) { 11 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, null); 12 | } 13 | 14 | public DocumentBySentenceSplitter(int maxSegmentSizeInChars, 15 | int maxOverlapSizeInChars, 16 | DocumentSplitter subSplitter) { 17 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, subSplitter); 18 | } 19 | 20 | public DocumentBySentenceSplitter(int maxSegmentSizeInTokens, 21 | int maxOverlapSizeInTokens, 22 | Tokenizer tokenizer) { 23 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, null); 24 | } 25 | 26 | public DocumentBySentenceSplitter(int maxSegmentSizeInTokens, 27 | int maxOverlapSizeInTokens, 28 | Tokenizer tokenizer, 29 | DocumentSplitter subSplitter) { 30 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, subSplitter); 31 | } 32 | 33 | @Override 34 | public String[] split(String text) { 35 | return text.split("\\s*[.。!?!?]\\s*"); 36 | } 37 | 38 | @Override 39 | public String joinDelimiter() { 40 | return " "; 41 | } 42 | 43 | @Override 44 | protected DocumentSplitter defaultSubSplitter() { 45 | return new DocumentByWordSplitter(maxSegmentSize, maxOverlapSize, tokenizer); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/splitter/impl/DocumentByWordSplitter.java: -------------------------------------------------------------------------------- 1 | package 
com.ai.domain.document.splitter.impl; 2 | 3 | import com.ai.domain.document.splitter.DocumentSplitter; 4 | import com.ai.domain.document.tokenizer.Tokenizer; 5 | 6 | public class DocumentByWordSplitter extends HierarchicalDocumentSplitter { 7 | 8 | public DocumentByWordSplitter(int maxSegmentSizeInChars, 9 | int maxOverlapSizeInChars) { 10 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, null); 11 | } 12 | 13 | public DocumentByWordSplitter(int maxSegmentSizeInChars, 14 | int maxOverlapSizeInChars, 15 | DocumentSplitter subSplitter) { 16 | super(maxSegmentSizeInChars, maxOverlapSizeInChars, null, subSplitter); 17 | } 18 | 19 | public DocumentByWordSplitter(int maxSegmentSizeInTokens, 20 | int maxOverlapSizeInTokens, 21 | Tokenizer tokenizer) { 22 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, null); 23 | } 24 | 25 | public DocumentByWordSplitter(int maxSegmentSizeInTokens, 26 | int maxOverlapSizeInTokens, 27 | Tokenizer tokenizer, 28 | DocumentSplitter subSplitter) { 29 | super(maxSegmentSizeInTokens, maxOverlapSizeInTokens, tokenizer, subSplitter); 30 | } 31 | 32 | @Override 33 | public String[] split(String text) { 34 | return text.split("\\s+"); // additional whitespaces are ignored 35 | } 36 | 37 | @Override 38 | public String joinDelimiter() { 39 | return " "; 40 | } 41 | 42 | @Override 43 | protected DocumentSplitter defaultSubSplitter() { 44 | return new DocumentByCharacterSplitter(maxSegmentSize, maxOverlapSize, tokenizer); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/document/splitter/impl/DocumentSplitters.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document.splitter.impl; 2 | 3 | import com.ai.domain.document.splitter.DocumentSplitter; 4 | import com.ai.domain.document.tokenizer.Tokenizer; 5 | 6 | public class DocumentSplitters { 7 | 8 | public static 
/**
 * Accumulates pieces of text into a single segment, inserting a separator
 * between pieces and checking a size budget expressed through a caller-supplied
 * size function.
 */
class SegmentBuilder {

    private final int maxSegmentSize;
    /** Typed so that {@link #sizeOf(String)} yields an {@code int} (a raw {@code Function} returns {@code Object}). */
    private final Function<String, Integer> sizeFunction;
    private final String joinSeparator;
    private final StringBuilder segmentBuilder;

    /**
     * @param maxSegmentSize upper bound used by {@link #hasSpaceFor(String)}
     * @param sizeFunction   measures a piece of text in the unit of {@code maxSegmentSize}
     * @param joinSeparator  text inserted between consecutive pieces
     */
    SegmentBuilder(int maxSegmentSize, Function<String, Integer> sizeFunction, String joinSeparator) {
        this.segmentBuilder = new StringBuilder();
        this.maxSegmentSize = maxSegmentSize;
        this.sizeFunction = sizeFunction;
        this.joinSeparator = joinSeparator;
    }

    /**
     * Tells whether {@code text} can be added without exceeding the budget.
     * When the segment already has content, the separator's size is counted too.
     * The three sizes are measured independently and summed.
     */
    boolean hasSpaceFor(String text) {
        if (isNotEmpty()) {
            return sizeOf(segmentBuilder.toString()) + sizeOf(joinSeparator) + sizeOf(text) <= maxSegmentSize;
        } else {
            return sizeOf(text) <= maxSegmentSize;
        }
    }

    private int sizeOf(String text) {
        return sizeFunction.apply(text);
    }

    /** Adds {@code text} at the end, preceded by the separator if the segment is not empty. */
    void append(String text) {
        if (isNotEmpty()) {
            segmentBuilder.append(joinSeparator);
        }
        segmentBuilder.append(text);
    }

    /** Adds {@code text} at the front, followed by the separator if the segment is not empty. */
    void prepend(String text) {
        if (isNotEmpty()) {
            segmentBuilder.insert(0, text + joinSeparator);
        } else {
            segmentBuilder.insert(0, text);
        }
    }

    /** Returns {@code true} once at least one character has been accumulated. */
    boolean isNotEmpty() {
        return segmentBuilder.length() > 0;
    }

    /** Returns the accumulated text with leading and trailing whitespace trimmed. */
    String build() {
        return segmentBuilder.toString().trim();
    }

    /** Discards everything accumulated so far. */
    void reset() {
        segmentBuilder.setLength(0);
    }
}
estimateTokenCountInToolSpecifications(singletonList(toolSpecification)); 35 | } 36 | 37 | int estimateTokenCountInToolSpecifications(Iterable toolSpecifications); 38 | 39 | int estimateTokenCountInToolExecutionRequests(Iterable toolExecutionRequests); 40 | 41 | default int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) { 42 | return estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest)); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/chat/ChatHistoryRecorder.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.chat; 2 | 3 | import com.ai.domain.data.message.ChatMessage; 4 | 5 | import java.util.List; 6 | 7 | /** 8 | * 历史聊天信息记录器 9 | **/ 10 | public interface ChatHistoryRecorder { 11 | 12 | /** 13 | * 获取当前对话信息的存储ID 14 | */ 15 | String getId(); 16 | 17 | /** 18 | * 设置当前对话信息的存储ID 19 | * 20 | * @param id 21 | */ 22 | void setId(String id); 23 | 24 | /** 25 | * 添加信息到当前对话信息列表中 26 | */ 27 | void add(ChatMessage message); 28 | 29 | /** 30 | * 得到当前对话信息列表 31 | */ 32 | List getCurrentMessages(); 33 | 34 | /** 35 | * 根据ID在存储器当中获取对话信息列表 36 | */ 37 | List getMessagesById(String id); 38 | 39 | /** 40 | * 清除当前对话信息列表 41 | */ 42 | void clear(); 43 | 44 | /** 45 | * 根据ID清除对话信息列表 46 | */ 47 | void clearById(String id); 48 | 49 | } 50 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/chat/ChatMemoryStore.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.chat; 2 | 3 | import com.ai.domain.data.message.ChatMessage; 4 | 5 | import java.util.List; 6 | 7 | /** 8 | * 历史聊天信息存储器 9 | */ 10 | public interface ChatMemoryStore { 11 | 12 | /** 13 | * 根据ID获取对话信息列表 14 | */ 15 | List getMessages(String 
memoryId); 16 | 17 | /** 18 | * 根据ID修改对话信息列表 19 | */ 20 | void updateMessages(String memoryId, List msgList); 21 | 22 | /** 23 | * 根据ID添加消息信息 24 | */ 25 | void addMessages(String memoryId, ChatMessage message); 26 | 27 | /** 28 | * 根据ID删除对话信息列表 29 | */ 30 | void deleteMessages(String memoryId); 31 | 32 | } 33 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/chat/impl/SimpleChatHistoryRecorder.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.chat.impl; 2 | 3 | import com.ai.domain.data.message.ChatMessage; 4 | import com.ai.domain.data.message.SystemMessage; 5 | import com.ai.domain.memory.chat.ChatHistoryRecorder; 6 | import com.ai.domain.memory.chat.ChatMemoryStore; 7 | import lombok.Builder; 8 | import lombok.Data; 9 | 10 | import java.util.ArrayList; 11 | import java.util.List; 12 | import java.util.Optional; 13 | 14 | import static com.ai.common.util.Utils.randomUUID; 15 | 16 | 17 | @Data 18 | @Builder 19 | public class SimpleChatHistoryRecorder implements ChatHistoryRecorder { 20 | 21 | @Builder.Default 22 | private String id = randomUUID(); 23 | @Builder.Default 24 | private Integer maxMessageNumber = 100; 25 | @Builder.Default 26 | private ChatMemoryStore memoryStore = new SimpleChatMemoryStore(); 27 | 28 | public SimpleChatHistoryRecorder() { 29 | this(randomUUID(), 30, new SimpleChatMemoryStore()); 30 | } 31 | 32 | public SimpleChatHistoryRecorder(String id, Integer maxMessageNumber, ChatMemoryStore memoryStore) { 33 | this.id = id; 34 | this.maxMessageNumber = maxMessageNumber; 35 | this.memoryStore = memoryStore; 36 | } 37 | 38 | @Override 39 | public String getId() { 40 | return this.id; 41 | } 42 | 43 | @Override 44 | public void setId(String id) { 45 | this.id = id; 46 | } 47 | 48 | @Override 49 | public void add(ChatMessage message) { 50 | List messages = this.getCurrentMessages(); 51 | // 
如果添加的是系统消息 52 | if (message instanceof SystemMessage) { 53 | // 先判断是否有相同的系统消息 54 | Optional systemMessage = findSystemMessage(messages); 55 | if (systemMessage.isPresent()) { 56 | if (systemMessage.get().equals(message)) { 57 | return; 58 | } 59 | messages.remove(systemMessage.get()); 60 | } 61 | } 62 | updatePolicy(messages); 63 | memoryStore.addMessages(this.id, message); 64 | } 65 | 66 | /** 67 | * 查找到对应的系统消息 68 | */ 69 | private Optional findSystemMessage(List messages) { 70 | return messages.stream() 71 | .filter((message) -> message instanceof SystemMessage) 72 | .map((message) -> (SystemMessage) message) 73 | .findAny(); 74 | } 75 | 76 | private void updatePolicy(List messages) { 77 | while (messages.size() > maxMessageNumber) { 78 | int messageToRemove = 0; 79 | if (messages.get(0) instanceof SystemMessage) { 80 | messageToRemove = 1; 81 | } 82 | messages.remove(messageToRemove); 83 | } 84 | } 85 | 86 | @Override 87 | public List getCurrentMessages() { 88 | return getMessagesById(this.id); 89 | } 90 | 91 | @Override 92 | public List getMessagesById(String id) { 93 | List messages = new ArrayList(this.memoryStore.getMessages(id)); 94 | updatePolicy(messages); 95 | return messages; 96 | } 97 | 98 | @Override 99 | public void clearById(String id) { 100 | this.memoryStore.deleteMessages(id); 101 | } 102 | 103 | @Override 104 | public void clear() { 105 | this.clearById(this.id); 106 | } 107 | 108 | } 109 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/chat/impl/SimpleChatMemoryStore.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.chat.impl; 2 | 3 | import com.ai.domain.data.message.ChatMessage; 4 | import com.ai.domain.memory.chat.ChatMemoryStore; 5 | import lombok.NoArgsConstructor; 6 | 7 | import java.util.ArrayList; 8 | import java.util.List; 9 | import java.util.Map; 10 | import 
java.util.concurrent.ConcurrentHashMap; 11 | 12 | /** 13 | * Openai历史信息存储器 14 | **/ 15 | @NoArgsConstructor 16 | public class SimpleChatMemoryStore implements ChatMemoryStore { 17 | 18 | private final Map> messagesByMemoryId = new ConcurrentHashMap<>(); 19 | 20 | public List getMessages(String memoryId) { 21 | return this.messagesByMemoryId.computeIfAbsent(memoryId, (ignored) -> new ArrayList()); 22 | } 23 | 24 | public void updateMessages(String memoryId, List messages) { 25 | this.messagesByMemoryId.put(memoryId, messages); 26 | } 27 | 28 | @Override 29 | public void addMessages(String memoryId, ChatMessage message) { 30 | this.messagesByMemoryId.get(memoryId).add(message); 31 | } 32 | 33 | public void deleteMessages(String memoryId) { 34 | this.messagesByMemoryId.remove(memoryId); 35 | } 36 | 37 | 38 | } 39 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/embedding/EmbeddingMemoryStore.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.embedding; 2 | 3 | import java.util.List; 4 | 5 | /** 6 | * 嵌入数据存储器,存放在内存当中 7 | */ 8 | public interface EmbeddingMemoryStore { 9 | 10 | String add(Data embedding); 11 | 12 | void add(String id, Data embedding); 13 | 14 | List addAll(List embedding); 15 | 16 | List getAllData(); 17 | 18 | } 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/embedding/EmbeddingStoreIngestor.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.embedding; 2 | 3 | 4 | import com.ai.domain.document.Document; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * 嵌入数据导入器 10 | */ 11 | public interface EmbeddingStoreIngestor { 12 | 13 | /** 14 | * 解析文档信息 15 | */ 16 | void ingest(Document 
document); 17 | 18 | void ingest(Document... documents); 19 | 20 | void ingest(List documents); 21 | 22 | } 23 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/embedding/EmbeddingStoreJsonCodec.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.embedding; 2 | 3 | 4 | import com.ai.domain.memory.embedding.impl.SimpleEmbeddingMemoryStore; 5 | 6 | public interface EmbeddingStoreJsonCodec { 7 | 8 | SimpleEmbeddingMemoryStore fromJson(String json); 9 | 10 | String toJson(EmbeddingMemoryStore store); 11 | 12 | } 13 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/embedding/EmbeddingStoreJsonCodecFactory.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.embedding; 2 | 3 | 4 | public interface EmbeddingStoreJsonCodecFactory { 5 | 6 | EmbeddingStoreJsonCodec create(); 7 | } 8 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/embedding/EmbeddingStoreRetriever.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.embedding; 2 | 3 | 4 | import com.ai.domain.data.embedding.Embedding; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * 嵌入数据检索器 10 | */ 11 | public interface EmbeddingStoreRetriever { 12 | 13 | List findRelevant(Embedding embedding, int maxResults, double minScore); 14 | 15 | List findRelevant(Embedding embedding); 16 | 17 | } 18 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/embedding/impl/GsonInMemoryEmbeddingStoreJsonCodec.java: -------------------------------------------------------------------------------- 1 | package 
com.ai.domain.memory.embedding.impl; 2 | 3 | import com.ai.domain.memory.embedding.EmbeddingMemoryStore; 4 | import com.ai.domain.memory.embedding.EmbeddingStoreJsonCodec; 5 | import com.google.gson.Gson; 6 | import com.google.gson.reflect.TypeToken; 7 | 8 | import java.lang.reflect.Type; 9 | 10 | 11 | public class GsonInMemoryEmbeddingStoreJsonCodec implements EmbeddingStoreJsonCodec { 12 | 13 | @Override 14 | public SimpleEmbeddingMemoryStore fromJson(String json) { 15 | Type type = new TypeToken() { 16 | }.getType(); 17 | return new Gson().fromJson(json, type); 18 | } 19 | 20 | @Override 21 | public String toJson(EmbeddingMemoryStore store) { 22 | return new Gson().toJson(store); 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/embedding/impl/SimpleEmbeddingMemoryStore.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.embedding.impl; 2 | 3 | import com.ai.domain.data.embedding.Embedding; 4 | import com.ai.domain.memory.embedding.EmbeddingMemoryStore; 5 | import com.ai.domain.memory.embedding.EmbeddingStoreJsonCodec; 6 | import com.ai.domain.memory.embedding.EmbeddingStoreJsonCodecFactory; 7 | import com.ai.domain.spi.ServiceHelper; 8 | 9 | import java.io.IOException; 10 | import java.nio.file.Files; 11 | import java.nio.file.Path; 12 | import java.util.ArrayList; 13 | import java.util.Collection; 14 | import java.util.List; 15 | import java.util.Map; 16 | import java.util.concurrent.ConcurrentHashMap; 17 | import java.util.stream.Collectors; 18 | 19 | import static com.ai.common.util.Utils.randomUUID; 20 | import static java.nio.file.StandardOpenOption.CREATE; 21 | import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; 22 | 23 | /** 24 | * 嵌入数据存储器,存放在内存当中 25 | */ 26 | @lombok.Data 27 | public class SimpleEmbeddingMemoryStore implements EmbeddingMemoryStore { 28 | 29 | private 
static final EmbeddingStoreJsonCodec CODEC = loadCodec(); 30 | private final Map idToEmbeddingData = new ConcurrentHashMap<>(); 31 | 32 | private static EmbeddingStoreJsonCodec loadCodec() { 33 | Collection factories = ServiceHelper.loadFactories( 34 | EmbeddingStoreJsonCodecFactory.class); 35 | for (EmbeddingStoreJsonCodecFactory factory : factories) { 36 | return factory.create(); 37 | } 38 | return new GsonInMemoryEmbeddingStoreJsonCodec(); 39 | } 40 | 41 | public static EmbeddingMemoryStore fromJson(String json) { 42 | return CODEC.fromJson(json); 43 | } 44 | 45 | public static EmbeddingMemoryStore fromFile(Path filePath) { 46 | try { 47 | String json = new String(Files.readAllBytes(filePath)); 48 | return fromJson(json); 49 | } catch (IOException e) { 50 | throw new RuntimeException(e); 51 | } 52 | } 53 | 54 | @Override 55 | public String add(Embedding embedding) { 56 | String id = randomUUID(); 57 | add(id, embedding); 58 | return id; 59 | } 60 | 61 | @Override 62 | public void add(String id, Embedding embedding) { 63 | idToEmbeddingData.put(id, embedding); 64 | } 65 | 66 | @Override 67 | public List addAll(List embeddings) { 68 | List ids = new ArrayList<>(); 69 | for (Embedding embedding : embeddings) { 70 | ids.add(add(embedding)); 71 | } 72 | return ids; 73 | } 74 | 75 | @Override 76 | public List getAllData() { 77 | return idToEmbeddingData.values().stream().collect(Collectors.toList()); 78 | } 79 | 80 | public void serializeToFile(Path filePath) { 81 | try { 82 | String json = serializeToJson(); 83 | Files.write(filePath, json.getBytes(), CREATE, TRUNCATE_EXISTING); 84 | } catch (IOException e) { 85 | throw new RuntimeException(e); 86 | } 87 | } 88 | 89 | public String serializeToJson() { 90 | return CODEC.toJson(this); 91 | } 92 | 93 | } 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- 
/smartFuse-domain/src/main/java/com/ai/domain/memory/embedding/impl/SimpleEmbeddingStoreIngestor.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.embedding.impl; 2 | 3 | 4 | import com.ai.domain.data.embedding.Embedding; 5 | import com.ai.domain.document.Document; 6 | import com.ai.domain.document.splitter.DocumentSplitter; 7 | import com.ai.domain.document.splitter.impl.DocumentSplitters; 8 | import com.ai.domain.memory.embedding.EmbeddingMemoryStore; 9 | import com.ai.domain.memory.embedding.EmbeddingStoreIngestor; 10 | import com.ai.domain.model.EmbeddingModel; 11 | import lombok.Builder; 12 | import lombok.Data; 13 | import software.amazon.awssdk.annotations.NotNull; 14 | 15 | import java.util.Arrays; 16 | import java.util.Collections; 17 | import java.util.List; 18 | 19 | /** 20 | * 嵌入数据导入器 21 | */ 22 | @Data 23 | @Builder 24 | public class SimpleEmbeddingStoreIngestor implements EmbeddingStoreIngestor { 25 | 26 | @Builder.Default 27 | private DocumentSplitter splitter = DocumentSplitters.recursive(100, 0); 28 | @NotNull 29 | private EmbeddingModel embeddingModel; 30 | @Builder.Default 31 | private EmbeddingMemoryStore store = new SimpleEmbeddingMemoryStore(); 32 | 33 | public void ingest(Document document) { 34 | this.ingest(Collections.singletonList(document)); 35 | } 36 | 37 | public void ingest(Document... 
documents) { 38 | this.ingest(Arrays.asList(documents)); 39 | } 40 | 41 | public void ingest(List documents) { 42 | for (Document document : documents) { 43 | List data = embeddingModel.embedAll(splitter.split(document)).getData(); 44 | store.addAll(data); 45 | } 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/memory/embedding/impl/SimpleEmbeddingStoreRetriever.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.memory.embedding.impl; 2 | 3 | 4 | import com.ai.domain.data.embedding.CosineSimilarity; 5 | import com.ai.domain.data.embedding.Embedding; 6 | import com.ai.domain.data.embedding.EmbeddingMatch; 7 | import com.ai.domain.memory.embedding.EmbeddingMemoryStore; 8 | import com.ai.domain.memory.embedding.EmbeddingStoreRetriever; 9 | import lombok.Builder; 10 | 11 | import java.util.*; 12 | 13 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 14 | import static java.util.Comparator.comparingDouble; 15 | 16 | /** 17 | * 嵌入数据检索器 18 | */ 19 | @Builder 20 | public class SimpleEmbeddingStoreRetriever implements EmbeddingStoreRetriever { 21 | 22 | @Builder.Default 23 | private final EmbeddingMemoryStore embeddingMemoryStore = new SimpleEmbeddingMemoryStore(); 24 | @Builder.Default 25 | private final int maxResults = 2; 26 | @Builder.Default 27 | private final Double minScore = 0.7; 28 | 29 | public List findRelevant(Embedding embedding, int maxResults, double minScore) { 30 | ensureNotNull(embedding, "embedding"); 31 | Comparator comparator = comparingDouble(EmbeddingMatch::getScore); 32 | PriorityQueue matches = new PriorityQueue<>(comparator); 33 | List allData = embeddingMemoryStore.getAllData(); 34 | for (Embedding data : allData) { 35 | double cosineSimilarity = CosineSimilarity.between(data, embedding); 36 | double score = CosineSimilarity.fromCosineSimilarity(cosineSimilarity); 37 | 
if (score >= minScore) { 38 | matches.add(new EmbeddingMatch(score, data)); 39 | if (matches.size() > maxResults) { 40 | matches.poll(); 41 | } 42 | } 43 | } 44 | ArrayList result = new ArrayList<>(matches); 45 | result.sort(comparator); 46 | Collections.reverse(result); 47 | return result; 48 | } 49 | 50 | public List findRelevant(Embedding embedding) { 51 | ensureNotNull(embedding, "embedding"); 52 | return findRelevant(embedding, this.maxResults, this.minScore); 53 | } 54 | 55 | 56 | } 57 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/AudioModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model; 2 | 3 | 4 | import com.ai.common.resp.AiResponse; 5 | import okhttp3.ResponseBody; 6 | import retrofit2.Callback; 7 | 8 | import java.io.File; 9 | 10 | /** 11 | * 文字和语音处理 12 | */ 13 | public interface AudioModel { 14 | 15 | /** 16 | * 文字转语音 17 | */ 18 | void textToSpeech(String text, Callback callback); 19 | 20 | /** 21 | * 语音转文字 22 | */ 23 | AiResponse speechToText(File speech); 24 | 25 | } 26 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/ChatModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model; 2 | 3 | 4 | import com.ai.common.resp.AiResponse; 5 | import com.ai.domain.data.message.AssistantMessage; 6 | import com.ai.domain.data.message.ChatMessage; 7 | import com.ai.domain.data.message.UserMessage; 8 | 9 | import java.util.List; 10 | 11 | import static com.ai.common.util.ValidationUtils.ensureNotBlank; 12 | import static java.util.Arrays.asList; 13 | 14 | public interface ChatModel { 15 | 16 | /** 17 | * 单次问答 18 | */ 19 | default String generate(String userMessage) { 20 | ensureNotBlank(userMessage, "userMessage"); 21 | return 
generate(UserMessage.message(userMessage)).getData().text(); 22 | } 23 | 24 | /** 25 | * 多轮问答 26 | */ 27 | default AiResponse generate(ChatMessage... messages) { 28 | return generate(asList(messages)); 29 | } 30 | 31 | /** 32 | * 多轮问答 33 | */ 34 | AiResponse generate(List messages); 35 | 36 | } 37 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/EmbeddingModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model; 2 | 3 | 4 | import com.ai.common.resp.AiResponse; 5 | import com.ai.domain.data.embedding.Embedding; 6 | import com.ai.domain.document.TextSegment; 7 | 8 | import java.util.Arrays; 9 | import java.util.List; 10 | import java.util.stream.Collectors; 11 | 12 | import static com.ai.common.util.ValidationUtils.*; 13 | 14 | public interface EmbeddingModel { 15 | 16 | /** 17 | * 对单个文本进行嵌入 18 | */ 19 | default AiResponse embed(String text) { 20 | ensureNotBlank(text, "text"); 21 | AiResponse> res = embed(Arrays.asList(text)); 22 | return AiResponse.R(res.getData().get(0), res.getTokenUsage(), res.getFinishReason()); 23 | } 24 | 25 | /** 26 | * 对多个文本进行嵌入 27 | */ 28 | AiResponse> embed(List text); 29 | 30 | /** 31 | * 对切分的文本段进行嵌入 32 | */ 33 | default AiResponse embed(TextSegment textSegment) { 34 | ensureNotNull(textSegment, "textSegment"); 35 | return embed(textSegment.text()); 36 | } 37 | 38 | /** 39 | * 对多个切分的文本段进行嵌入 40 | */ 41 | default AiResponse> embedAll(List textSegmentList) { 42 | ensureNotEmpty(textSegmentList, "textSegments"); 43 | List stringList = textSegmentList.stream() 44 | .map(TextSegment::text).collect(Collectors.toList()); 45 | return embed(stringList); 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/ImageModel.java: 
-------------------------------------------------------------------------------- 1 | package com.ai.domain.model; 2 | 3 | 4 | import com.ai.common.resp.AiResponse; 5 | import com.ai.common.resp.finish.FinishReason; 6 | import com.ai.domain.data.images.Image; 7 | 8 | import java.util.List; 9 | 10 | public interface ImageModel { 11 | 12 | /** 13 | * 根据 prompt 生成一张图片 14 | */ 15 | default AiResponse create(String prompt) { 16 | AiResponse> listAiResponse = create(prompt, 1); 17 | return AiResponse.R(listAiResponse.getData().get(0), FinishReason.success()); 18 | } 19 | 20 | /** 21 | * 根据 prompt 生成 n 张图片,n 的大小根据不同的模型而定 22 | * n 的大小根据不同的模型而定 23 | */ 24 | default AiResponse> create(String prompt, int n) { 25 | return create(prompt, null, null, n); 26 | } 27 | 28 | /** 29 | * 根据 prompt 生成 n 张图片,大小为 size,风格为 style 的图片 30 | * n/size/style 的值根据不同模型而定 31 | */ 32 | AiResponse> create(String prompt, String size, String style, int n); 33 | 34 | } 35 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/ModelTemplate.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model; 2 | 3 | public interface ModelTemplate { 4 | 5 | Resp createAiResponse(Req request); 6 | } 7 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/ModerationModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model; 2 | 3 | 4 | import com.ai.common.resp.AiResponse; 5 | import com.ai.common.resp.finish.FinishReason; 6 | import com.ai.domain.data.message.ChatMessage; 7 | import com.ai.domain.data.moderation.Moderation; 8 | 9 | import java.util.Arrays; 10 | import java.util.List; 11 | 12 | import static com.ai.common.util.ValidationUtils.ensureNotBlank; 13 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 14 | 15 
| public interface ModerationModel { 16 | 17 | /** 18 | * 对单个文本进行审核 19 | */ 20 | default AiResponse moderate(String message) { 21 | ensureNotBlank(message, "message"); 22 | AiResponse> moderate = moderate(Arrays.asList(message)); 23 | return AiResponse.R(moderate.getData().get(0), FinishReason.success()); 24 | } 25 | 26 | /** 27 | * 对多个文本进行审核 28 | */ 29 | AiResponse> moderate(List messages); 30 | 31 | /** 32 | * 对消息文本进行审核 33 | */ 34 | default AiResponse moderate(ChatMessage messages) { 35 | ensureNotNull(messages, "messages"); 36 | return moderate(messages.getText()); 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/BigDecimalOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | import java.math.BigDecimal; 6 | 7 | public class BigDecimalOutputParser implements OutputParser { 8 | 9 | @Override 10 | public BigDecimal parse(String string) { 11 | return new BigDecimal(string); 12 | } 13 | 14 | @Override 15 | public String formatInstructions() { 16 | return "floating point number"; 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/BigIntegerOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | import java.math.BigInteger; 6 | 7 | public class BigIntegerOutputParser implements OutputParser { 8 | 9 | @Override 10 | public BigInteger parse(String string) { 11 | return new BigInteger(string); 12 | } 13 | 14 | @Override 15 | public String formatInstructions() { 16 | return "integer number"; 17 | } 18 | } 19 | 
-------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/BooleanOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | public class BooleanOutputParser implements OutputParser { 6 | 7 | @Override 8 | public Boolean parse(String string) { 9 | return Boolean.parseBoolean(string); 10 | } 11 | 12 | @Override 13 | public String formatInstructions() { 14 | return "one of [true, false]"; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/ByteOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | public class ByteOutputParser implements OutputParser { 6 | 7 | @Override 8 | public Byte parse(String string) { 9 | return Byte.parseByte(string); 10 | } 11 | 12 | @Override 13 | public String formatInstructions() { 14 | return "integer number in range [-128, 127]"; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/DateOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | import java.text.ParseException; 6 | import java.text.SimpleDateFormat; 7 | import java.util.Date; 8 | 9 | public class DateOutputParser implements OutputParser { 10 | 11 | private static final String DATE_PATTERN = "yyyy-MM-dd"; 12 | private static final SimpleDateFormat SIMPLE_DATE_FORMAT = new SimpleDateFormat(DATE_PATTERN); 13 | 14 | @Override 15 | public Date parse(String string) { 16 | 
try { 17 | return SIMPLE_DATE_FORMAT.parse(string); 18 | } catch (ParseException e) { 19 | throw new RuntimeException(e); 20 | } 21 | } 22 | 23 | @Override 24 | public String formatInstructions() { 25 | return DATE_PATTERN; 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/DoubleOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | public class DoubleOutputParser implements OutputParser { 6 | 7 | @Override 8 | public Double parse(String string) { 9 | return Double.parseDouble(string); 10 | } 11 | 12 | @Override 13 | public String formatInstructions() { 14 | return "floating point number"; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/EnumOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | import com.google.gson.Gson; 5 | 6 | import java.util.Arrays; 7 | 8 | public class EnumOutputParser implements OutputParser { 9 | 10 | private final Class enumClass; 11 | 12 | public EnumOutputParser(Class enumClass) { 13 | this.enumClass = enumClass; 14 | } 15 | 16 | @Override 17 | public Enum parse(String string) { 18 | return new Gson().fromJson(string, enumClass); 19 | } 20 | 21 | @Override 22 | public String formatInstructions() { 23 | return "one of " + Arrays.toString(enumClass.getEnumConstants()); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/FloatOutputParser.java: -------------------------------------------------------------------------------- 1 | package 
com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | public class FloatOutputParser implements OutputParser { 6 | 7 | @Override 8 | public Float parse(String string) { 9 | return Float.parseFloat(string); 10 | } 11 | 12 | @Override 13 | public String formatInstructions() { 14 | return "floating point number"; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/IntOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | public class IntOutputParser implements OutputParser { 6 | 7 | @Override 8 | public Integer parse(String string) { 9 | return Integer.parseInt(string); 10 | } 11 | 12 | @Override 13 | public String formatInstructions() { 14 | return "integer number"; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/LocalDateOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | import java.time.LocalDate; 6 | 7 | import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE; 8 | 9 | public class LocalDateOutputParser implements OutputParser { 10 | 11 | @Override 12 | public LocalDate parse(String string) { 13 | return LocalDate.parse(string, ISO_LOCAL_DATE); 14 | } 15 | 16 | @Override 17 | public String formatInstructions() { 18 | return "2023-12-31"; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/LocalDateTimeOutputParser.java: -------------------------------------------------------------------------------- 1 | package 
com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | import java.time.LocalDateTime; 6 | 7 | import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME; 8 | 9 | public class LocalDateTimeOutputParser implements OutputParser { 10 | 11 | @Override 12 | public LocalDateTime parse(String string) { 13 | return LocalDateTime.parse(string, ISO_LOCAL_DATE_TIME); 14 | } 15 | 16 | @Override 17 | public String formatInstructions() { 18 | return "2023-12-31T23:59:59"; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/LocalTimeOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | import java.time.LocalTime; 6 | 7 | import static java.time.format.DateTimeFormatter.ISO_LOCAL_TIME; 8 | 9 | public class LocalTimeOutputParser implements OutputParser { 10 | 11 | @Override 12 | public LocalTime parse(String string) { 13 | return LocalTime.parse(string, ISO_LOCAL_TIME); 14 | } 15 | 16 | @Override 17 | public String formatInstructions() { 18 | return "23:59:59"; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/model/output/LongOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.model.output; 2 | 3 | import com.ai.domain.service.OutputParser; 4 | 5 | public class LongOutputParser implements OutputParser { 6 | 7 | @Override 8 | public Long parse(String string) { 9 | return Long.parseLong(string); 10 | } 11 | 12 | @Override 13 | public String formatInstructions() { 14 | return "integer number"; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- 
/**
 * Prompt template: produces prompt text, or a {@link Prompt}, from a map of values.
 **/
public interface PromptTemplate {

    /**
     * Renders the template to text using the supplied values.
     *
     * @param keys the values to render with; presumably placeholder name to replacement value - confirm against implementations
     * @return the rendered text
     */
    String render(Map keys);

    /**
     * Applies the supplied values to the template and returns the result as a {@link Prompt}.
     *
     * @param keys the values to apply; presumably placeholder name to replacement value - confirm against implementations
     * @return the resulting prompt
     */
    Prompt apply(Map keys);

}
SimplePromptTemplate promptTemplate; 15 | 16 | @Override 17 | public String text() { 18 | return text; 19 | } 20 | 21 | } 22 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/prompt/impl/SimplePromptTemplate.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.prompt.impl; 2 | 3 | import com.ai.common.util.PlaceHolderReplaceUtils; 4 | import com.ai.domain.prompt.PromptTemplate; 5 | 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | 9 | import static com.ai.common.util.ValidationUtils.ensureNotBlank; 10 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 11 | 12 | /** 13 | * @Description: 普通提示词模板 14 | **/ 15 | public class SimplePromptTemplate implements PromptTemplate { 16 | 17 | private final Map renderMap = new HashMap<>(); 18 | private String template; 19 | private String promptName; 20 | 21 | public SimplePromptTemplate(String template, String promptName) { 22 | this.template = ensureNotBlank(template, "template"); 23 | this.promptName = ensureNotBlank(promptName, "promptName"); 24 | } 25 | 26 | public static String render(String prompt, Map keys) { 27 | return PlaceHolderReplaceUtils.replaceWithMap(prompt, keys); 28 | } 29 | 30 | @Override 31 | public SimplePrompt apply(Map keys) { 32 | return new SimplePrompt(PlaceHolderReplaceUtils.replaceWithMap(template, keys), new SimplePromptTemplate(this.template, this.promptName)); 33 | } 34 | 35 | public void add(String key, String value) { 36 | this.renderMap.put(key, value); 37 | } 38 | 39 | public void addAll(Map m) { 40 | this.renderMap.putAll(m); 41 | } 42 | 43 | @Override 44 | public String render(Map keys) { 45 | return PlaceHolderReplaceUtils.replaceWithMap(template, keys); 46 | } 47 | 48 | public String render() { 49 | ensureNotNull(this.renderMap, "renderMap"); 50 | return render(this.renderMap); 51 | } 52 | 53 | public String getTemplate() { 54 | 
return template; 55 | } 56 | 57 | public void setTemplate(String template) { 58 | this.template = template; 59 | } 60 | 61 | public String getPromptName() { 62 | return promptName; 63 | } 64 | 65 | public void setPromptName(String promptName) { 66 | this.promptName = promptName; 67 | } 68 | 69 | @Override 70 | public String toString() { 71 | return "SimplePromptTemplate{" + 72 | "template='" + template + '\'' + 73 | ", promptName='" + promptName + '\'' + 74 | ", renderMap=" + renderMap + 75 | '}'; 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/AiServiceContext.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service; 2 | 3 | import com.ai.domain.memory.chat.ChatHistoryRecorder; 4 | import com.ai.domain.model.ChatModel; 5 | import com.ai.domain.model.ModerationModel; 6 | import com.ai.domain.tools.ToolSpecification; 7 | import lombok.Builder; 8 | import lombok.Data; 9 | 10 | import java.util.List; 11 | 12 | @Data 13 | @Builder 14 | public class AiServiceContext { 15 | private Class aiServiceClass; 16 | private ChatModel chatModel; 17 | private ModerationModel moderationModel; 18 | private ChatHistoryRecorder chatHistoryRecorder; 19 | private List toolSpecifications; 20 | } 21 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/AiServices.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service; 2 | 3 | import com.ai.domain.memory.chat.ChatHistoryRecorder; 4 | import com.ai.domain.model.ChatModel; 5 | import com.ai.domain.model.ModerationModel; 6 | import com.ai.domain.spi.AiServicesFactory; 7 | import com.ai.domain.spi.ServiceHelper; 8 | 9 | import java.util.Collection; 10 | 11 | public abstract class AiServices { 12 | 13 | protected AiServiceContext 
context; 14 | 15 | protected AiServices(AiServiceContext context) { 16 | this.context = context; 17 | } 18 | 19 | public static AiServices builder(Class aiService) { 20 | AiServiceContext context = AiServiceContext.builder().aiServiceClass(aiService).build(); 21 | Collection aiServicesFactories = ServiceHelper.loadFactories(AiServicesFactory.class); 22 | for (AiServicesFactory factory : aiServicesFactories) { 23 | return factory.create(context); 24 | } 25 | return new DefaultAiServices<>(context); 26 | } 27 | 28 | public AiServices chat(ChatModel chatModel) { 29 | context.setChatModel(chatModel); 30 | return this; 31 | } 32 | 33 | public AiServices memory(ChatHistoryRecorder chatHistoryRecorder) { 34 | context.setChatHistoryRecorder(chatHistoryRecorder); 35 | return this; 36 | } 37 | 38 | public AiServices moderate(ModerationModel moderationModel) { 39 | context.setModerationModel(moderationModel); 40 | return this; 41 | } 42 | 43 | public abstract T build(); 44 | 45 | public void performValidation() { 46 | if (this.context.getChatModel() == null) { 47 | throw new IllegalArgumentException("No chat model set"); 48 | } 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/OutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service; 2 | 3 | public interface OutputParser { 4 | 5 | T parse(String text); 6 | 7 | String formatInstructions(); 8 | } 9 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/ServiceOutputParser.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service; 2 | 3 | import com.ai.common.resp.AiResponse; 4 | import com.ai.domain.data.message.AssistantMessage; 5 | import com.ai.domain.model.output.*; 6 | import com.google.gson.Gson; 
7 | 8 | import java.math.BigDecimal; 9 | import java.math.BigInteger; 10 | import java.time.LocalDate; 11 | import java.time.LocalDateTime; 12 | import java.time.LocalTime; 13 | import java.util.*; 14 | 15 | import static java.util.Arrays.asList; 16 | 17 | public class ServiceOutputParser { 18 | 19 | private static final Map, OutputParser> OUTPUT_PARSERS = new HashMap<>(); 20 | 21 | static { 22 | OUTPUT_PARSERS.put(boolean.class, new BooleanOutputParser()); 23 | OUTPUT_PARSERS.put(Boolean.class, new BooleanOutputParser()); 24 | 25 | OUTPUT_PARSERS.put(byte.class, new ByteOutputParser()); 26 | OUTPUT_PARSERS.put(Byte.class, new ByteOutputParser()); 27 | 28 | OUTPUT_PARSERS.put(short.class, new ShortOutputParser()); 29 | OUTPUT_PARSERS.put(Short.class, new ShortOutputParser()); 30 | 31 | OUTPUT_PARSERS.put(int.class, new IntOutputParser()); 32 | OUTPUT_PARSERS.put(Integer.class, new IntOutputParser()); 33 | 34 | OUTPUT_PARSERS.put(long.class, new LongOutputParser()); 35 | OUTPUT_PARSERS.put(Long.class, new LongOutputParser()); 36 | 37 | OUTPUT_PARSERS.put(BigInteger.class, new BigIntegerOutputParser()); 38 | 39 | OUTPUT_PARSERS.put(float.class, new FloatOutputParser()); 40 | OUTPUT_PARSERS.put(Float.class, new FloatOutputParser()); 41 | 42 | OUTPUT_PARSERS.put(double.class, new DoubleOutputParser()); 43 | OUTPUT_PARSERS.put(Double.class, new DoubleOutputParser()); 44 | 45 | OUTPUT_PARSERS.put(BigDecimal.class, new BigDecimalOutputParser()); 46 | 47 | OUTPUT_PARSERS.put(Date.class, new DateOutputParser()); 48 | OUTPUT_PARSERS.put(LocalDate.class, new LocalDateOutputParser()); 49 | OUTPUT_PARSERS.put(LocalTime.class, new LocalTimeOutputParser()); 50 | OUTPUT_PARSERS.put(LocalDateTime.class, new LocalDateTimeOutputParser()); 51 | } 52 | 53 | public static Object parse(AiResponse response, Class returnType) { 54 | 55 | if (returnType == AiResponse.class) { 56 | return response; 57 | } 58 | 59 | AssistantMessage aiMessage = response.getData(); 60 | if (returnType == 
AssistantMessage.class) { 61 | return aiMessage; 62 | } 63 | 64 | String text = aiMessage.text(); 65 | if (returnType == String.class) { 66 | return text; 67 | } 68 | 69 | OutputParser outputParser = OUTPUT_PARSERS.get(returnType); 70 | if (outputParser != null) { 71 | return outputParser.parse(text); 72 | } 73 | 74 | if (returnType == List.class) { 75 | return asList(text.split("\n")); 76 | } 77 | 78 | if (returnType == Set.class) { 79 | return new HashSet<>(asList(text.split("\n"))); 80 | } 81 | 82 | return new Gson().fromJson(text, returnType); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/annotation/ChatConfig.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service.annotation; 2 | 3 | import java.lang.annotation.Retention; 4 | import java.lang.annotation.Target; 5 | 6 | import static java.lang.annotation.ElementType.TYPE; 7 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 8 | 9 | /** 10 | * 用于设置对话的模型,审核模型,消息存储器。 11 | */ 12 | @Target(TYPE) 13 | @Retention(RUNTIME) 14 | public @interface ChatConfig { 15 | 16 | /** 17 | * 设置要使用的存储器 18 | */ 19 | Class memory() default void.class; 20 | 21 | /** 22 | * 设置要使用的审核模型 23 | * 24 | * @return 25 | */ 26 | Class moderate() default void.class; 27 | 28 | /** 29 | * 设置要使用的对话模型 30 | */ 31 | Class chat() default void.class; 32 | 33 | } 34 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/annotation/Memory.java: -------------------------------------------------------------------------------- 1 | //package com.ai.domain.service.annotation; 2 | // 3 | //import java.lang.annotation.Retention; 4 | //import java.lang.annotation.Target; 5 | // 6 | //import static java.lang.annotation.ElementType.METHOD; 7 | //import static java.lang.annotation.RetentionPolicy.RUNTIME; 8 
| // 9 | ///** 10 | // * 用此注解标注在方法上,指定聊天所用的存储器。 11 | // */ 12 | //@Target(METHOD) 13 | //@Retention(RUNTIME) 14 | //public @interface Memory { 15 | // 16 | //} 17 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/annotation/MemoryId.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service.annotation; 2 | 3 | import java.lang.annotation.ElementType; 4 | import java.lang.annotation.Retention; 5 | import java.lang.annotation.RetentionPolicy; 6 | import java.lang.annotation.Target; 7 | 8 | /** 9 | * 此注解用来标注出用户存储器当中的存储ID,支持 String\int\long\Integer\Long作为MemoryId 10 | * 此注解生效的前提是:设置了消息存储器 11 | */ 12 | @Retention(RetentionPolicy.RUNTIME) 13 | @Target({ElementType.PARAMETER}) 14 | public @interface MemoryId { 15 | } 16 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/annotation/Moderate.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service.annotation; 2 | 3 | import java.lang.annotation.Retention; 4 | import java.lang.annotation.Target; 5 | 6 | import static java.lang.annotation.ElementType.METHOD; 7 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 8 | 9 | /** 10 | * 用此注解标注在方法上,指定聊天所用的审核模型 11 | */ 12 | @Target(METHOD) 13 | @Retention(RUNTIME) 14 | public @interface Moderate { 15 | 16 | Class value() default Void.class; 17 | 18 | } 19 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/annotation/Prompt.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service.annotation; 2 | 3 | 4 | import java.lang.annotation.Retention; 5 | import java.lang.annotation.Target; 6 | 7 | import static 
java.lang.annotation.ElementType.TYPE; 8 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 9 | 10 | /** 11 | * 此注解标注在类上,表示这是一个模板类。需要跟@UserMessage注解配合使用。 12 | */ 13 | @Target(TYPE) 14 | @Retention(RUNTIME) 15 | public @interface Prompt { 16 | 17 | String[] value(); 18 | 19 | String delimiter() default "\n"; 20 | 21 | } 22 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/annotation/SystemMessage.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service.annotation; 2 | 3 | import java.lang.annotation.Retention; 4 | import java.lang.annotation.Target; 5 | 6 | import static java.lang.annotation.ElementType.METHOD; 7 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 8 | 9 | /** 10 | * 通过此注解来标识系统信息 11 | */ 12 | @Target(METHOD) 13 | @Retention(RUNTIME) 14 | public @interface SystemMessage { 15 | 16 | String[] value() default ""; 17 | 18 | String delimiter() default "\n"; 19 | 20 | } 21 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/service/annotation/UserMessage.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service.annotation; 2 | 3 | import java.lang.annotation.Retention; 4 | import java.lang.annotation.Target; 5 | 6 | import static java.lang.annotation.ElementType.METHOD; 7 | import static java.lang.annotation.ElementType.PARAMETER; 8 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 9 | 10 | /** 11 | * 通过此注解来标识用户信息 12 | */ 13 | @Retention(RUNTIME) 14 | @Target({METHOD, PARAMETER}) 15 | public @interface UserMessage { 16 | 17 | String[] value() default ""; 18 | 19 | String delimiter() default "\n"; 20 | 21 | } 22 | -------------------------------------------------------------------------------- 
/smartFuse-domain/src/main/java/com/ai/domain/service/annotation/V.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.service.annotation; 2 | 3 | import java.lang.annotation.Retention; 4 | import java.lang.annotation.Target; 5 | 6 | import static java.lang.annotation.ElementType.PARAMETER; 7 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 8 | 9 | /** 10 | * 此注解用来表示系统消息或用户消息当中的占位符内容,如果占位符在系统消息和用户消息当中都存在,那么都会替换。 11 | */ 12 | @Target(PARAMETER) 13 | @Retention(RUNTIME) 14 | public @interface V { 15 | String value(); 16 | } 17 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/spi/AiServicesFactory.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.spi; 2 | 3 | import com.ai.domain.service.AiServiceContext; 4 | import com.ai.domain.service.AiServices; 5 | 6 | public interface AiServicesFactory { 7 | 8 | AiServices create(AiServiceContext context); 9 | } 10 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/spi/ServiceHelper.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.spi; 2 | 3 | import java.util.*; 4 | 5 | public class ServiceHelper { 6 | 7 | public static Collection loadFactories(Class clazz) { 8 | return loadFactories(clazz, null); 9 | } 10 | 11 | public static Collection loadFactories(Class clazz, ClassLoader classLoader) { 12 | List list = new ArrayList<>(); 13 | ServiceLoader factories; 14 | if (classLoader != null) { 15 | factories = ServiceLoader.load(clazz, classLoader); 16 | } else { 17 | factories = ServiceLoader.load(clazz); 18 | } 19 | if (factories.iterator().hasNext()) { 20 | factories.iterator().forEachRemaining(list::add); 21 | return list; 22 | } else { 23 | factories = ServiceLoader.load(clazz, 
ServiceHelper.class.getClassLoader()); 24 | if (factories.iterator().hasNext()) { 25 | factories.iterator().forEachRemaining(list::add); 26 | return list; 27 | } else { 28 | return Collections.emptyList(); 29 | } 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/tools/JsonSchemaProperty.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.tools; 2 | 3 | import java.util.Objects; 4 | 5 | import static java.util.Collections.singletonMap; 6 | 7 | 8 | public class JsonSchemaProperty { 9 | 10 | public static final JsonSchemaProperty STRING = type("string"); 11 | public static final JsonSchemaProperty INTEGER = type("integer"); 12 | public static final JsonSchemaProperty NUMBER = type("number"); 13 | public static final JsonSchemaProperty OBJECT = type("object"); 14 | public static final JsonSchemaProperty ARRAY = type("array"); 15 | public static final JsonSchemaProperty BOOLEAN = type("boolean"); 16 | public static final JsonSchemaProperty NULL = type("null"); 17 | 18 | private final String key; 19 | private final Object value; 20 | 21 | public JsonSchemaProperty(String key, Object value) { 22 | this.key = key; 23 | this.value = value; 24 | } 25 | 26 | public static JsonSchemaProperty from(String key, Object value) { 27 | return new JsonSchemaProperty(key, value); 28 | } 29 | 30 | public static JsonSchemaProperty property(String key, Object value) { 31 | return from(key, value); 32 | } 33 | 34 | public static JsonSchemaProperty type(String value) { 35 | return from("type", value); 36 | } 37 | 38 | public static JsonSchemaProperty description(String value) { 39 | return from("description", value); 40 | } 41 | 42 | public static JsonSchemaProperty enums(String... enumValues) { 43 | return from("enum", enumValues); 44 | } 45 | 46 | public static JsonSchemaProperty enums(Object... 
enumValues) { 47 | for (Object enumValue : enumValues) { 48 | if (!enumValue.getClass().isEnum()) { 49 | throw new RuntimeException("Value " + enumValue.getClass().getName() + " should be enum"); 50 | } 51 | } 52 | 53 | return from("enum", enumValues); 54 | } 55 | 56 | public static JsonSchemaProperty enums(Class enumClass) { 57 | if (!enumClass.isEnum()) { 58 | throw new RuntimeException("Class " + enumClass.getName() + " should be enum"); 59 | } 60 | 61 | return from("enum", enumClass.getEnumConstants()); 62 | } 63 | 64 | public static JsonSchemaProperty items(JsonSchemaProperty type) { 65 | return from("items", singletonMap(type.key, type.value)); 66 | } 67 | 68 | public String key() { 69 | return key; 70 | } 71 | 72 | public Object value() { 73 | return value; 74 | } 75 | 76 | @Override 77 | public boolean equals(Object another) { 78 | if (this == another) return true; 79 | return another instanceof JsonSchemaProperty 80 | && equalTo((JsonSchemaProperty) another); 81 | } 82 | 83 | private boolean equalTo(JsonSchemaProperty another) { 84 | return Objects.equals(key, another.key) 85 | && Objects.equals(value, another.value); 86 | } 87 | 88 | @Override 89 | public int hashCode() { 90 | int h = 5381; 91 | h += (h << 5) + Objects.hashCode(key); 92 | h += (h << 5) + Objects.hashCode(value); 93 | return h; 94 | } 95 | } 96 | 97 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/tools/ToolExecutionRequest.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.tools; 2 | 3 | import java.util.Objects; 4 | 5 | public class ToolExecutionRequest { 6 | 7 | private final String id; 8 | private final String name; 9 | private final String arguments; 10 | 11 | private ToolExecutionRequest(Builder builder) { 12 | this.id = builder.id; 13 | this.name = builder.name; 14 | this.arguments = builder.arguments; 15 | } 16 | 17 | public static Builder 
/**
 * JSON-schema style description of the parameters a tool accepts: the schema type
 * (defaults to {@code "object"}), the per-parameter property maps and the list of
 * required parameter names.
 *
 * <p>The collections handed to the builder are stored as-is (no copy), and the
 * accessors return those same instances.
 */
public class ToolParameters {

    private final String type;
    private final Map<String, Map<String, Object>> properties;
    private final List<String> required;

    private ToolParameters(Builder builder) {
        this.type = builder.type;
        this.properties = builder.properties;
        this.required = builder.required;
    }

    /** Creates a builder preset with type {@code "object"} and empty collections. */
    public static Builder builder() {
        return new Builder();
    }

    public String type() {
        return type;
    }

    /** Returns the parameter-name to schema-properties map held by this instance. */
    public Map<String, Map<String, Object>> properties() {
        return properties;
    }

    /** Returns the names of the required parameters held by this instance. */
    public List<String> required() {
        return required;
    }

    @Override
    public boolean equals(Object another) {
        if (this == another) return true;
        return another instanceof ToolParameters
                && equalTo((ToolParameters) another);
    }

    private boolean equalTo(ToolParameters another) {
        return Objects.equals(type, another.type)
                && Objects.equals(properties, another.properties)
                && Objects.equals(required, another.required);
    }

    @Override
    public int hashCode() {
        int h = 5381;
        h += (h << 5) + Objects.hashCode(type);
        h += (h << 5) + Objects.hashCode(properties);
        h += (h << 5) + Objects.hashCode(required);
        return h;
    }

    /** Fluent builder for {@link ToolParameters}. */
    public static final class Builder {

        private String type = "object";
        private Map<String, Map<String, Object>> properties = new HashMap<>();
        private List<String> required = new ArrayList<>();

        private Builder() {
        }

        public Builder type(String type) {
            this.type = type;
            return this;
        }

        /**
         * Sets the property map.
         *
         * @param properties map to use; {@code null} resets to a fresh empty map so the
         *                   built object never exposes a null collection
         */
        public Builder properties(Map<String, Map<String, Object>> properties) {
            this.properties = properties != null ? properties : new HashMap<>();
            return this;
        }

        /**
         * Sets the required parameter names.
         *
         * @param required list to use; {@code null} resets to a fresh empty list
         */
        public Builder required(List<String> required) {
            this.required = required != null ? required : new ArrayList<>();
            return this;
        }

        public ToolParameters build() {
            return new ToolParameters(this);
        }
    }
}
com.ai.domain.tools.annotation.Tool; 5 | import com.ai.domain.tools.annotation.ToolMemoryId; 6 | 7 | import java.lang.reflect.Method; 8 | import java.lang.reflect.Parameter; 9 | import java.math.BigDecimal; 10 | import java.math.BigInteger; 11 | import java.util.List; 12 | import java.util.Objects; 13 | import java.util.Set; 14 | 15 | import static com.ai.common.util.Utils.isNullOrBlank; 16 | import static com.ai.domain.tools.JsonSchemaProperty.*; 17 | import static java.util.Arrays.stream; 18 | import static java.util.stream.Collectors.toList; 19 | 20 | public class ToolSpecifications { 21 | 22 | public static List toolSpecificationsFrom(Object objectWithTools) { 23 | return stream(objectWithTools.getClass().getDeclaredMethods()) 24 | .filter(method -> method.isAnnotationPresent(Tool.class)) 25 | .map(ToolSpecifications::toolSpecificationFrom) 26 | .collect(toList()); 27 | } 28 | 29 | public static ToolSpecification toolSpecificationFrom(Method method) { 30 | Tool annotation = method.getAnnotation(Tool.class); 31 | 32 | String name = isNullOrBlank(annotation.name()) ? method.getName() : annotation.name(); 33 | String description = String.join("\n", annotation.value()); 34 | 35 | ToolSpecification.Builder builder = ToolSpecification.builder() 36 | .name(name) 37 | .description(description); 38 | 39 | for (Parameter parameter : method.getParameters()) { 40 | if (parameter.isAnnotationPresent(ToolMemoryId.class)) { 41 | continue; 42 | } 43 | builder.addParameter(parameter.getName(), toJsonSchemaProperties(parameter)); 44 | } 45 | 46 | return builder.build(); 47 | } 48 | 49 | private static Iterable toJsonSchemaProperties(Parameter parameter) { 50 | 51 | Class type = parameter.getType(); 52 | 53 | P annotation = parameter.getAnnotation(P.class); 54 | JsonSchemaProperty description = annotation == null ? 
null : description(annotation.value()); 55 | 56 | if (type == String.class) { 57 | return removeNulls(STRING, description); 58 | } 59 | 60 | if (type == boolean.class || type == Boolean.class) { 61 | return removeNulls(BOOLEAN, description); 62 | } 63 | 64 | if (type == byte.class || type == Byte.class 65 | || type == short.class || type == Short.class 66 | || type == int.class || type == Integer.class 67 | || type == long.class || type == Long.class 68 | || type == BigInteger.class) { 69 | return removeNulls(INTEGER, description); 70 | } 71 | 72 | if (type == float.class || type == Float.class 73 | || type == double.class || type == Double.class 74 | || type == BigDecimal.class) { 75 | return removeNulls(NUMBER, description); 76 | } 77 | 78 | if (type.isArray() 79 | || type == List.class 80 | || type == Set.class) { 81 | return removeNulls(ARRAY, description); 82 | } 83 | 84 | if (type.isEnum()) { 85 | return removeNulls(STRING, enums((Object[]) type.getEnumConstants()), description); 86 | } 87 | 88 | return removeNulls(OBJECT, description); 89 | } 90 | 91 | private static Iterable removeNulls(JsonSchemaProperty... 
properties) { 92 | return stream(properties) 93 | .filter(Objects::nonNull) 94 | .collect(toList()); 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/tools/annotation/P.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.tools.annotation; 2 | 3 | import java.lang.annotation.Retention; 4 | import java.lang.annotation.Target; 5 | 6 | import static java.lang.annotation.ElementType.PARAMETER; 7 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 8 | 9 | 10 | @Retention(RUNTIME) 11 | @Target({PARAMETER}) 12 | public @interface P { 13 | 14 | String value(); 15 | } 16 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/tools/annotation/Tool.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.tools.annotation; 2 | 3 | import java.lang.annotation.Retention; 4 | import java.lang.annotation.Target; 5 | 6 | import static java.lang.annotation.ElementType.METHOD; 7 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 8 | 9 | 10 | @Retention(RUNTIME) 11 | @Target({METHOD}) 12 | public @interface Tool { 13 | 14 | String name() default ""; 15 | 16 | String[] value() default ""; 17 | } 18 | -------------------------------------------------------------------------------- /smartFuse-domain/src/main/java/com/ai/domain/tools/annotation/ToolMemoryId.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.tools.annotation; 2 | 3 | import java.lang.annotation.Retention; 4 | import java.lang.annotation.Target; 5 | 6 | import static java.lang.annotation.ElementType.PARAMETER; 7 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 8 | 9 | @Retention(RUNTIME) 10 | @Target(PARAMETER) 11 | public @interface ToolMemoryId { 
12 | } 13 | -------------------------------------------------------------------------------- /smartFuse-domain/src/test/java/com/ai/domain/document/DocumentTest.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | import org.junit.Test; 4 | 5 | import java.io.File; 6 | import java.nio.file.Path; 7 | import java.nio.file.Paths; 8 | 9 | /** 10 | * 文档加载测试 11 | **/ 12 | public class DocumentTest { 13 | 14 | public static Path toPath(String fileName) { 15 | File file = new File(fileName); 16 | if (file.exists()) { 17 | try { 18 | return Paths.get(file.toURI()); 19 | } catch (Exception e) { 20 | e.printStackTrace(); 21 | } 22 | } 23 | return null; 24 | } 25 | 26 | @Test 27 | public void test_load() { 28 | String[] filePaths = { 29 | "文件路径\\中文测试.txt", 30 | "文件路径\\中文测试.docx", 31 | "文件路径\\中文测试.pdf", 32 | "文件路径\\中文测试.xlsx", 33 | "文件路径\\中文测试.pptx" 34 | }; 35 | for (String filePath : filePaths) { 36 | Document document = FileSystemDocumentLoader.loadDocument(filePath); 37 | System.out.println(document.text()); 38 | System.out.println(document.metadata()); 39 | } 40 | } 41 | 42 | @Test 43 | public void test_load_txt() { 44 | Path filePath = toPath("D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\中文测试.txt"); 45 | Document document = FileSystemDocumentLoader.loadDocument(filePath); 46 | System.out.println(document); 47 | } 48 | 49 | @Test 50 | public void test_load_word() { 51 | Path filePath = toPath("D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\中文测试.docx"); 52 | Document document = FileSystemDocumentLoader.loadDocument(filePath); 53 | System.out.println(document); 54 | } 55 | 56 | @Test 57 | public void test_load_pdf() { 58 | Path filePath = toPath("D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\中文测试.pdf"); 59 | Document document = FileSystemDocumentLoader.loadDocument(filePath); 60 | System.out.println(document); 61 | } 62 | 63 | @Test 64 | public void 
test_load_excel() { 65 | Path filePath = toPath("D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\中文测试.xlsx"); 66 | Document document = FileSystemDocumentLoader.loadDocument(filePath); 67 | System.out.println(document); 68 | } 69 | 70 | @Test 71 | public void test_load_ppt() { 72 | Path filePath = toPath("D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\中文测试.pptx"); 73 | Document document = FileSystemDocumentLoader.loadDocument(filePath); 74 | System.out.println(document); 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /smartFuse-domain/src/test/java/com/ai/domain/document/SplitterTest.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.document; 2 | 3 | import com.ai.domain.document.splitter.impl.*; 4 | import org.junit.Before; 5 | import org.junit.Test; 6 | 7 | import java.nio.file.Path; 8 | 9 | import static com.ai.domain.document.DocumentTest.toPath; 10 | 11 | /** 12 | * 分束器测试 13 | */ 14 | public class SplitterTest { 15 | 16 | private Document document; 17 | 18 | public static void showRes(String[] strings) { 19 | for (String res : strings) { 20 | System.out.println(res); 21 | } 22 | } 23 | 24 | @Before 25 | public void test_load() { 26 | Path filePath = toPath("D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\中文测试.txt"); 27 | this.document = FileSystemDocumentLoader.loadDocument(filePath); 28 | } 29 | 30 | @Test 31 | public void test_character_splitter() { 32 | DocumentByCharacterSplitter documentByCharacterSplitter = new DocumentByCharacterSplitter(10, 1); 33 | String[] split = documentByCharacterSplitter.split(document.text()); 34 | showRes(split); 35 | } 36 | 37 | @Test 38 | public void test_line_splitter() { 39 | DocumentByLineSplitter documentByLineSplitter = new DocumentByLineSplitter(10, 1); 40 | String[] split = documentByLineSplitter.split(document.text()); 41 | showRes(split); 42 | } 43 | 44 | 
@Test 45 | public void test_paragraph_splitter() { 46 | DocumentByParagraphSplitter documentByParagraphSplitter = new DocumentByParagraphSplitter(10, 0); 47 | String[] split = documentByParagraphSplitter.split(document.text()); 48 | showRes(split); 49 | } 50 | 51 | @Test 52 | public void test_sentence_splitter() { 53 | DocumentBySentenceSplitter documentBySentenceSplitter = new DocumentBySentenceSplitter(10, 0); 54 | String[] split = documentBySentenceSplitter.split(document.text()); 55 | showRes(split); 56 | } 57 | 58 | @Test 59 | public void test_word_splitter() { 60 | DocumentByWordSplitter documentByWordSplitter = new DocumentByWordSplitter(10, 0); 61 | String[] split = documentByWordSplitter.split(document.text()); 62 | showRes(split); 63 | } 64 | 65 | } 66 | -------------------------------------------------------------------------------- /smartFuse-domain/src/test/java/com/ai/domain/prompt/PromptAnnotationTest.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.prompt; 2 | 3 | 4 | public class PromptAnnotationTest { 5 | 6 | 7 | } 8 | -------------------------------------------------------------------------------- /smartFuse-domain/src/test/java/com/ai/domain/prompt/PromptTest.java: -------------------------------------------------------------------------------- 1 | package com.ai.domain.prompt; 2 | 3 | import com.ai.domain.prompt.impl.SimplePrompt; 4 | import com.ai.domain.prompt.impl.SimplePromptTemplate; 5 | import org.junit.Before; 6 | import org.junit.Test; 7 | 8 | import java.util.HashMap; 9 | import java.util.Map; 10 | 11 | /** 12 | * 提示词模板测试类 13 | **/ 14 | public class PromptTest { 15 | 16 | private SimplePromptTemplate simplePromptTemplate; 17 | 18 | @Before 19 | public void test_create_prompt_template() { 20 | // 提示词模板,需要替换的地方用{{}}包括,其中 key 为 money 21 | String promptTemplateString = "我有一辆价值{{money}}的车,它的品牌是:{{brand}}。"; 22 | // 提示词的名称,标识这个提示词是干什么的 23 | String templateName = "汽车提示词"; 24 | 
this.simplePromptTemplate = new SimplePromptTemplate(promptTemplateString, templateName); 25 | } 26 | 27 | @Test 28 | public void test_use_prompt_apply() { 29 | Map map = new HashMap<>(); 30 | map.put("money", "100万"); 31 | map.put("brand", "宝马"); 32 | // 传入一个包含关键字的Map,返回一个应用示例。其中会记录其对应的模板信息 33 | SimplePrompt apply = simplePromptTemplate.apply(map); 34 | System.out.println(apply);// SimplePrompt(text=我有一辆价值100万的车,它的品牌是:宝马。) 35 | System.out.println(apply.text());// 我有一辆价值100万的车,它的品牌是:宝马。 36 | System.out.println(apply.getPromptTemplate());// SimplePromptTemplate{template='我有一辆价值{{money}}的车,它的品牌是:{{brand}}。', promptName='汽车提示词', renderMap={}} 37 | } 38 | 39 | @Test 40 | public void test_use_prompt_render() { 41 | Map map = new HashMap<>(); 42 | map.put("money", "50万"); 43 | map.put("brand", "奔驰"); 44 | // 直接渲染,返回字符串 45 | String render = simplePromptTemplate.render(map); 46 | System.out.println(render);// 我有一辆价值50万的车,它的品牌是:奔驰。 47 | } 48 | 49 | @Test 50 | public void test_use_prompt_add_render() { 51 | // 将需要渲染的数据跟模板进行绑定 52 | simplePromptTemplate.add("money", "50万"); 53 | simplePromptTemplate.add("brand", "奔驰"); 54 | String render = simplePromptTemplate.render(); 55 | System.out.println(render); 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /smartFuse-openai/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | AI-SmartFuse-Framework 7 | com.ai 8 | 1.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | smartFuse-openai 13 | 14 | 15 | UTF-8 16 | UTF-8 17 | 1.8 18 | 8 19 | 8 20 | 21 | 22 | 23 | 24 | com.ai 25 | smartFuse-common 26 | 1.0-SNAPSHOT 27 | 28 | 29 | com.ai 30 | smartFuse-domain 31 | 1.0-SNAPSHOT 32 | 33 | 34 | junit 35 | junit 36 | 4.13.2 37 | test 38 | 39 | 40 | com.ai 41 | ai-openai 42 | 1.0 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/client/OpenAiClient.java: 
-------------------------------------------------------------------------------- 1 | package com.ai.openai.client; 2 | 3 | 4 | import com.ai.openai.achieve.Configuration; 5 | import com.ai.openai.achieve.defaults.DefaultOpenAiSessionFactory; 6 | import com.ai.openai.achieve.standard.session.AggregationSession; 7 | 8 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 9 | 10 | /** 11 | * @Description: openAi请求客户端 12 | **/ 13 | public class OpenAiClient { 14 | 15 | private static AggregationSession aggregationSession; 16 | 17 | private static Configuration configuration; 18 | 19 | public static void SetConfiguration(Configuration configuration) { 20 | ensureNotNull(configuration, "configuration"); 21 | aggregationSession = new DefaultOpenAiSessionFactory(configuration).openAggregationSession(); 22 | } 23 | 24 | public static Configuration GetConfiguration() { 25 | return configuration; 26 | } 27 | 28 | public static AggregationSession getAggregationSession() { 29 | return aggregationSession; 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/converter/BeanConverter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.converter; 2 | 3 | import com.ai.common.resp.usage.TokenUsage; 4 | import com.ai.openai.common.Usage; 5 | 6 | public class BeanConverter { 7 | 8 | public static TokenUsage usage2tokenUsage(Usage usage) { 9 | return TokenUsage.usage(usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens()); 10 | } 11 | 12 | } 13 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/model/ModelConversionTemplate.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.model; 2 | 3 | 4 | /** 5 | * TODO 考虑用模板模式优化代码结构 6 | */ 7 | public abstract class 
ModelConversionTemplate { 8 | 9 | public final void run() { 10 | constructRequestParameters(); 11 | initiateRequest(); 12 | ParseReturnParameters(); 13 | } 14 | 15 | abstract void constructRequestParameters(); 16 | 17 | abstract void initiateRequest(); 18 | 19 | abstract void ParseReturnParameters(); 20 | 21 | } 22 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/model/OpenaiAudioModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.model; 2 | 3 | 4 | import cn.hutool.core.bean.BeanUtil; 5 | import com.ai.common.resp.AiResponse; 6 | import com.ai.common.resp.finish.FinishReason; 7 | import com.ai.domain.data.parameter.Parameter; 8 | import com.ai.domain.model.AudioModel; 9 | import com.ai.openai.achieve.standard.session.AudioSession; 10 | import com.ai.openai.client.OpenAiClient; 11 | import com.ai.openai.endPoint.audio.req.SttCompletionRequest; 12 | import com.ai.openai.endPoint.audio.req.TtsCompletionRequest; 13 | import com.ai.openai.endPoint.audio.resp.SttCompletionResponse; 14 | import com.ai.openai.parameter.OpenaiAudioModelSttParameter; 15 | import com.ai.openai.parameter.OpenaiAudioModelTtsParameter; 16 | import com.ai.openai.parameter.input.OpenaiAudioSttParameter; 17 | import com.ai.openai.parameter.input.OpenaiAudioTtsParameter; 18 | import okhttp3.ResponseBody; 19 | import retrofit2.Callback; 20 | 21 | import java.io.File; 22 | 23 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 24 | import static com.ai.core.exception.Constants.NULL; 25 | 26 | /** 27 | * openai语音模型 28 | */ 29 | public class OpenaiAudioModel implements AudioModel { 30 | 31 | private final AudioSession audioSession = OpenAiClient.getAggregationSession().getAudioSession(); 32 | private Parameter sttParameter; 33 | private Parameter ttsParameter; 34 | 35 | public OpenaiAudioModel() { 36 | this(new OpenaiAudioModelSttParameter(), 
new OpenaiAudioModelTtsParameter()); 37 | } 38 | 39 | public OpenaiAudioModel(Parameter sttParameter, Parameter ttsParameter) { 40 | this.sttParameter = ensureNotNull(sttParameter, "sttParameter"); 41 | this.ttsParameter = ensureNotNull(ttsParameter, "ttsParameter"); 42 | } 43 | 44 | public Parameter getSttParameter() { 45 | return sttParameter; 46 | } 47 | 48 | public void setSttParameter(Parameter sttParameter) { 49 | this.sttParameter = ensureNotNull(sttParameter, "sttParameter"); 50 | } 51 | 52 | public Parameter getTtsParameter() { 53 | return ttsParameter; 54 | } 55 | 56 | public void setTtsParameter(Parameter ttsParameter) { 57 | this.ttsParameter = ensureNotNull(ttsParameter, "ttsParameter"); 58 | } 59 | 60 | @Override 61 | public void textToSpeech(String text, Callback callback) { 62 | // 构造请求主要参数 63 | TtsCompletionRequest request = TtsCompletionRequest.builder().input(text).build(); 64 | // 填充请求配置属性 65 | BeanUtil.copyProperties(ttsParameter.getParameter(), request); 66 | // 通过回调函数处理结果 67 | audioSession.ttsCompletions(NULL, NULL, NULL, request, callback); 68 | } 69 | 70 | @Override 71 | public AiResponse speechToText(File speech) { 72 | // 构造请求主要参数 73 | SttCompletionRequest request = SttCompletionRequest.builder().file(speech).build(); 74 | // 填充请求配置属性 75 | BeanUtil.copyProperties(sttParameter.getParameter(), request); 76 | // 发起请求获取结果 77 | SttCompletionResponse response = audioSession.sttCompletions(NULL, NULL, NULL, request); 78 | // 转换结果为统一返回值 79 | return AiResponse.R(response.getText(), FinishReason.success()); 80 | } 81 | 82 | } 83 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/model/OpenaiChatModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.model; 2 | 3 | import cn.hutool.core.bean.BeanUtil; 4 | import com.ai.common.resp.AiResponse; 5 | import com.ai.common.resp.finish.FinishReason; 6 | import 
com.ai.domain.data.message.AssistantMessage; 7 | import com.ai.domain.data.message.ChatMessage; 8 | import com.ai.domain.data.parameter.Parameter; 9 | import com.ai.domain.model.ChatModel; 10 | import com.ai.openai.achieve.standard.session.ChatSession; 11 | import com.ai.openai.client.OpenAiClient; 12 | import com.ai.openai.endPoint.chat.ChatChoice; 13 | import com.ai.openai.endPoint.chat.msg.DefaultMessage; 14 | import com.ai.openai.endPoint.chat.req.DefaultChatCompletionRequest; 15 | import com.ai.openai.endPoint.chat.resp.ChatCompletionResponse; 16 | import com.ai.openai.parameter.OpenaiChatModelParameter; 17 | import com.ai.openai.parameter.input.OpenaiChatParameter; 18 | 19 | import java.util.List; 20 | import java.util.stream.Collectors; 21 | 22 | import static com.ai.common.util.ValidationUtils.ensureNotEmpty; 23 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 24 | import static com.ai.core.exception.Constants.NULL; 25 | import static com.ai.openai.converter.BeanConverter.usage2tokenUsage; 26 | 27 | /** 28 | * openai对话聊天模型 29 | **/ 30 | public class OpenaiChatModel implements ChatModel { 31 | 32 | private final ChatSession chatSession = OpenAiClient.getAggregationSession().getChatSession(); 33 | private Parameter parameter; 34 | 35 | public OpenaiChatModel() { 36 | this(new OpenaiChatModelParameter()); 37 | } 38 | 39 | public OpenaiChatModel(Parameter parameter) { 40 | this.parameter = ensureNotNull(parameter, "parameter"); 41 | } 42 | 43 | private static List chatMessageList2DefaultMessageList(List chatMessages) { 44 | return chatMessages.stream() 45 | .map(chatMessage -> DefaultMessage.builder() 46 | .role(chatMessage.type().getMessageType()) 47 | .content(chatMessage.text()) 48 | .build()) 49 | .collect(Collectors.toList()); 50 | } 51 | 52 | public Parameter getParameter() { 53 | return parameter; 54 | } 55 | 56 | public void setParameter(Parameter parameter) { 57 | this.parameter = ensureNotNull(parameter, "parameter"); 58 | } 59 | 60 | 
@Override 61 | public AiResponse generate(List messages) { 62 | ensureNotEmpty(messages, "messages"); 63 | // 将message转换为openai模型所需的格式 64 | List defaultMessages = chatMessageList2DefaultMessageList(messages); 65 | // 构造请求主要参数 66 | DefaultChatCompletionRequest request = DefaultChatCompletionRequest.builder() 67 | .messages(defaultMessages).build(); 68 | // 填充请求配置属性 69 | BeanUtil.copyProperties(parameter.getParameter(), request); 70 | // 发送请求获取结果 71 | ChatCompletionResponse response = chatSession.chatCompletions(NULL, NULL, NULL, request); 72 | return createAiResponse(response); 73 | } 74 | 75 | private AiResponse createAiResponse(ChatCompletionResponse response) { 76 | // 获取对话内容 77 | List choices = response.getChoices(); 78 | // 得到模型的回复 79 | AssistantMessage assistantMessage = AssistantMessage.message(choices.get(choices.size() - 1).getMessage().getContent()); 80 | // 转换结果为统一返回值 81 | return AiResponse.R(assistantMessage, usage2tokenUsage(response.getUsage()), FinishReason.success()); 82 | } 83 | 84 | } 85 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/model/OpenaiEmbeddingModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.model; 2 | 3 | import cn.hutool.core.bean.BeanUtil; 4 | import com.ai.common.resp.AiResponse; 5 | import com.ai.common.resp.finish.FinishReason; 6 | import com.ai.domain.data.embedding.Embedding; 7 | import com.ai.domain.data.parameter.Parameter; 8 | import com.ai.domain.model.EmbeddingModel; 9 | import com.ai.openai.achieve.standard.session.EmbeddingSession; 10 | import com.ai.openai.client.OpenAiClient; 11 | import com.ai.openai.endPoint.embeddings.EmbeddingObject; 12 | import com.ai.openai.endPoint.embeddings.req.EmbeddingCompletionRequest; 13 | import com.ai.openai.endPoint.embeddings.resp.EmbeddingCompletionResponse; 14 | import com.ai.openai.parameter.OpenaiEmbeddingModelParameter; 15 | import 
com.ai.openai.parameter.input.OpenaiEmbeddingParameter; 16 | 17 | import java.util.List; 18 | import java.util.stream.Collectors; 19 | 20 | import static com.ai.common.util.ValidationUtils.ensureNotEmpty; 21 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 22 | import static com.ai.core.exception.Constants.NULL; 23 | import static com.ai.openai.converter.BeanConverter.usage2tokenUsage; 24 | 25 | /** 26 | * 文本嵌入 27 | **/ 28 | public class OpenaiEmbeddingModel implements EmbeddingModel { 29 | 30 | private final EmbeddingSession embeddingSession = OpenAiClient.getAggregationSession().getEmbeddingSession(); 31 | private Parameter parameter; 32 | 33 | public OpenaiEmbeddingModel() { 34 | this.parameter = new OpenaiEmbeddingModelParameter(); 35 | } 36 | 37 | public OpenaiEmbeddingModel(Parameter parameter) { 38 | this.parameter = ensureNotNull(parameter, "parameter"); 39 | } 40 | 41 | public static Embedding embeddingObj2Embedding(EmbeddingObject embeddingObject) { 42 | return new Embedding(embeddingObject.getEmbedding(), embeddingObject.getContent()); 43 | } 44 | 45 | public static List embeddingObjList2embeddingList(List embeddingObjectList) { 46 | return embeddingObjectList.stream() 47 | .map(embeddingObject -> embeddingObj2Embedding(embeddingObject)) 48 | .collect(Collectors.toList()); 49 | } 50 | 51 | public Parameter getParameter() { 52 | return parameter; 53 | } 54 | 55 | public void setParameter(Parameter parameter) { 56 | this.parameter = ensureNotNull(parameter, "parameter"); 57 | } 58 | 59 | public AiResponse> embed(List stringList) { 60 | ensureNotEmpty(stringList, "stringList"); 61 | // 构造请求主要参数 62 | EmbeddingCompletionRequest request = EmbeddingCompletionRequest.baseBuild(stringList); 63 | // 填充请求配置属性 64 | BeanUtil.copyProperties(parameter.getParameter(), request); 65 | // 发起请求获取结果 66 | EmbeddingCompletionResponse response = embeddingSession.embeddingCompletions(NULL, NULL, NULL, request); 67 | // 转换结果为统一返回值 68 | List embeddings = 
embeddingObjList2embeddingList(response.getData()); 69 | return AiResponse.R(embeddings, usage2tokenUsage(response.getUsage()), FinishReason.success()); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/model/OpenaiImageModel.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.model; 2 | 3 | import cn.hutool.core.bean.BeanUtil; 4 | import com.ai.common.resp.AiResponse; 5 | import com.ai.common.resp.finish.FinishReason; 6 | import com.ai.domain.data.images.Image; 7 | import com.ai.domain.data.parameter.Parameter; 8 | import com.ai.domain.model.ImageModel; 9 | import com.ai.openai.achieve.standard.session.ImageSession; 10 | import com.ai.openai.client.OpenAiClient; 11 | import com.ai.openai.endPoint.images.ImageObject; 12 | import com.ai.openai.endPoint.images.req.CreateImageRequest; 13 | import com.ai.openai.parameter.OpenaiImageModelParameter; 14 | import com.ai.openai.parameter.input.OpenaiImageParameter; 15 | 16 | import java.util.List; 17 | import java.util.stream.Collectors; 18 | 19 | import static com.ai.common.util.ValidationUtils.ensureNotBlank; 20 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 21 | import static com.ai.core.exception.Constants.NULL; 22 | 23 | /** 24 | * 图片生成模型 25 | **/ 26 | public class OpenaiImageModel implements ImageModel { 27 | 28 | private final ImageSession imageSession = OpenAiClient.getAggregationSession().getImageSession(); 29 | private Parameter parameter; 30 | 31 | public OpenaiImageModel() { 32 | this.parameter = new OpenaiImageModelParameter(); 33 | } 34 | 35 | public OpenaiImageModel(Parameter parameter) { 36 | this.parameter = ensureNotNull(parameter, "parameter"); 37 | } 38 | 39 | public static List imageObjList2ImageList(List imageObjectList) { 40 | return imageObjectList.stream() 41 | .map(imageObject -> imageObj2Image(imageObject)) 42 | 
.collect(Collectors.toList()); 43 | } 44 | 45 | public static Image imageObj2Image(ImageObject imageObject) { 46 | return Image.from(imageObject.getUrl(), imageObject.getB64Json()); 47 | } 48 | 49 | public Parameter getParameter() { 50 | return parameter; 51 | } 52 | 53 | public void setParameter(Parameter parameter) { 54 | this.parameter = ensureNotNull(parameter, "parameter"); 55 | } 56 | 57 | @Override 58 | public AiResponse> create(String prompt, String size, String style, int n) { 59 | ensureNotBlank(prompt, "prompt"); 60 | // 构造请求主要参数 61 | CreateImageRequest request = CreateImageRequest.builder() 62 | .prompt(prompt).size(size).style(style).n(n).build(); 63 | // 填充请求配置属性 64 | BeanUtil.copyProperties(parameter.getParameter(), request); 65 | // 发起请求获取结果 66 | List imageObjectList = imageSession.createImageCompletions(NULL, NULL, NULL, request); 67 | // 转换结果为统一返回值 68 | return AiResponse.R(imageObjList2ImageList(imageObjectList), FinishReason.success()); 69 | } 70 | 71 | } 72 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/parameter/OpenaiAudioModelSttParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter; 2 | 3 | import com.ai.domain.data.parameter.Parameter; 4 | import com.ai.openai.parameter.input.OpenaiAudioSttParameter; 5 | 6 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 7 | 8 | 9 | public class OpenaiAudioModelSttParameter implements Parameter { 10 | 11 | private OpenaiAudioSttParameter parameter; 12 | 13 | public OpenaiAudioModelSttParameter() { 14 | this(OpenaiAudioSttParameter.builder().build()); 15 | } 16 | 17 | public OpenaiAudioModelSttParameter(OpenaiAudioSttParameter parameter) { 18 | this.parameter = ensureNotNull(parameter, "OpenaiAudioSttParameter"); 19 | } 20 | 21 | @Override 22 | public OpenaiAudioSttParameter getParameter() { 23 | return this.parameter; 24 | } 25 | 26 
| @Override 27 | public void SetParameter(OpenaiAudioSttParameter parameter) { 28 | this.parameter = parameter; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/parameter/OpenaiAudioModelTtsParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter; 2 | 3 | import com.ai.domain.data.parameter.Parameter; 4 | import com.ai.openai.parameter.input.OpenaiAudioTtsParameter; 5 | 6 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 7 | 8 | 9 | public class OpenaiAudioModelTtsParameter implements Parameter { 10 | 11 | private OpenaiAudioTtsParameter parameter; 12 | 13 | public OpenaiAudioModelTtsParameter() { 14 | this(OpenaiAudioTtsParameter.builder().build()); 15 | } 16 | 17 | public OpenaiAudioModelTtsParameter(OpenaiAudioTtsParameter parameter) { 18 | this.parameter = ensureNotNull(parameter, "OpenaiAudioTtsParameter"); 19 | } 20 | 21 | @Override 22 | public OpenaiAudioTtsParameter getParameter() { 23 | return this.parameter; 24 | } 25 | 26 | 27 | @Override 28 | public void SetParameter(OpenaiAudioTtsParameter parameter) { 29 | this.parameter = parameter; 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/parameter/OpenaiChatModelParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter; 2 | 3 | import com.ai.domain.data.parameter.Parameter; 4 | import com.ai.openai.parameter.input.OpenaiChatParameter; 5 | 6 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 7 | 8 | public class OpenaiChatModelParameter implements Parameter { 9 | 10 | private OpenaiChatParameter parameter; 11 | 12 | public OpenaiChatModelParameter() { 13 | this(OpenaiChatParameter.builder().build()); 14 | } 15 | 16 | public 
OpenaiChatModelParameter(OpenaiChatParameter parameter) { 17 | this.parameter = ensureNotNull(parameter, "OpenaiChatParameter"); 18 | } 19 | 20 | @Override 21 | public OpenaiChatParameter getParameter() { 22 | return this.parameter; 23 | } 24 | 25 | @Override 26 | public void SetParameter(OpenaiChatParameter parameter) { 27 | this.parameter = parameter; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/parameter/OpenaiEmbeddingModelParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter; 2 | 3 | import com.ai.domain.data.parameter.Parameter; 4 | import com.ai.openai.parameter.input.OpenaiEmbeddingParameter; 5 | 6 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 7 | 8 | 9 | public class OpenaiEmbeddingModelParameter implements Parameter { 10 | 11 | private OpenaiEmbeddingParameter parameter; 12 | 13 | public OpenaiEmbeddingModelParameter() { 14 | this(OpenaiEmbeddingParameter.builder().build()); 15 | } 16 | 17 | public OpenaiEmbeddingModelParameter(OpenaiEmbeddingParameter parameter) { 18 | this.parameter = ensureNotNull(parameter, "OpenaiEmbeddingParameter"); 19 | } 20 | 21 | @Override 22 | public OpenaiEmbeddingParameter getParameter() { 23 | return this.parameter; 24 | } 25 | 26 | @Override 27 | public void SetParameter(OpenaiEmbeddingParameter parameter) { 28 | this.parameter = parameter; 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/parameter/OpenaiImageModelParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter; 2 | 3 | import com.ai.domain.data.parameter.Parameter; 4 | import com.ai.openai.parameter.input.OpenaiImageParameter; 5 | 6 | import static 
com.ai.common.util.ValidationUtils.ensureNotNull; 7 | 8 | public class OpenaiImageModelParameter implements Parameter { 9 | 10 | private OpenaiImageParameter parameter; 11 | 12 | public OpenaiImageModelParameter() { 13 | this(OpenaiImageParameter.builder().build()); 14 | } 15 | 16 | public OpenaiImageModelParameter(OpenaiImageParameter parameter) { 17 | this.parameter = ensureNotNull(parameter, "OpenaiImageParameter"); 18 | } 19 | 20 | @Override 21 | public OpenaiImageParameter getParameter() { 22 | return this.parameter; 23 | } 24 | 25 | @Override 26 | public void SetParameter(OpenaiImageParameter parameter) { 27 | this.parameter = parameter; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/parameter/OpenaiModerationModelParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter; 2 | 3 | 4 | import com.ai.domain.data.parameter.Parameter; 5 | import com.ai.openai.parameter.input.OpenaiModerationParameter; 6 | 7 | import static com.ai.common.util.ValidationUtils.ensureNotNull; 8 | 9 | public class OpenaiModerationModelParameter implements Parameter { 10 | 11 | private OpenaiModerationParameter parameter; 12 | 13 | public OpenaiModerationModelParameter() { 14 | this(OpenaiModerationParameter.builder().build()); 15 | } 16 | 17 | public OpenaiModerationModelParameter(OpenaiModerationParameter parameter) { 18 | this.parameter = ensureNotNull(parameter, "OpenaiModerationParameter"); 19 | } 20 | 21 | @Override 22 | public OpenaiModerationParameter getParameter() { 23 | return this.parameter; 24 | } 25 | 26 | @Override 27 | public void SetParameter(OpenaiModerationParameter parameter) { 28 | this.parameter = parameter; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- 
/smartFuse-openai/src/main/java/com/ai/openai/parameter/input/OpenaiAudioSttParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter.input; 2 | 3 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties; 4 | import com.fasterxml.jackson.annotation.JsonInclude; 5 | import com.fasterxml.jackson.annotation.JsonProperty; 6 | import lombok.*; 7 | 8 | import java.io.Serializable; 9 | 10 | @Data 11 | @Builder 12 | @NoArgsConstructor 13 | @AllArgsConstructor 14 | @JsonIgnoreProperties(ignoreUnknown = true) 15 | @JsonInclude(JsonInclude.Include.NON_NULL) 16 | public class OpenaiAudioSttParameter implements Serializable { 17 | 18 | /** 19 | * 使用的模型名 20 | */ 21 | @NonNull 22 | @Builder.Default 23 | private String model = OpenaiAudioSttParameter.Model.whisper_1.getModuleName(); 24 | 25 | /** 26 | * 音频的语言,以 ISO-639-1 格式提供输入语言将提高准确性和延迟。 27 | */ 28 | private String language; 29 | 30 | /** 31 | * 一个可选文本,用于指导模型的样式或继续上一个音频片段。提示应与音频语言匹配。 32 | */ 33 | private String prompt; 34 | 35 | /** 36 | * 脚本输出的格式 37 | */ 38 | @JsonProperty("response_format") 39 | private String responseFormat; 40 | 41 | /** 42 | * 采样温度,介于 0 和 1 之间,默认值为 0 。 43 | * 较高的值(如 0.8)将使输出更加随机,而较低的值(如 0.2)将使其更具针对性和确定性。 44 | * 如果设置为 0,模型将使用对数概率自动提高温度,直到达到某些阈值。 45 | */ 46 | private Double temperature; 47 | 48 | @Getter 49 | @AllArgsConstructor 50 | public enum Model { 51 | whisper_1("whisper-1"); 52 | private String moduleName; 53 | } 54 | 55 | 56 | } 57 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/parameter/input/OpenaiAudioTtsParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter.input; 2 | 3 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties; 4 | import com.fasterxml.jackson.annotation.JsonInclude; 5 | import com.fasterxml.jackson.annotation.JsonProperty; 6 | import 
lombok.*; 7 | 8 | import java.io.Serializable; 9 | 10 | @Data 11 | @Builder 12 | @NoArgsConstructor 13 | @AllArgsConstructor 14 | @JsonIgnoreProperties(ignoreUnknown = true) 15 | @JsonInclude(JsonInclude.Include.NON_NULL) 16 | public class OpenaiAudioTtsParameter implements Serializable { 17 | /** 18 | * 要使用的模型的 ID 19 | */ 20 | @Builder.Default 21 | private String model = Model.tts_1.getModuleName(); 22 | 23 | /** 24 | * 声音样式 25 | */ 26 | @NonNull 27 | private String voice; 28 | 29 | /** 30 | * 音频输入的格式,默认为mp3 31 | */ 32 | @JsonProperty("response_format") 33 | private String responseFormat; 34 | 35 | /** 36 | * 音频的速度,0.25 到 4.0 之中选取一个数,数字越大速度越快。默认为1。 37 | */ 38 | private String speed; 39 | 40 | @Getter 41 | @AllArgsConstructor 42 | public enum Model { 43 | tts_1("tts-1"), tts_1_hd("tts-1-hd"); 44 | private String moduleName; 45 | } 46 | 47 | /** 48 | * 声音样式 49 | */ 50 | @Getter 51 | @AllArgsConstructor 52 | public enum Voice { 53 | alloy("alloy"), 54 | echo("echo"), 55 | fable("fable"), 56 | onyx("onyx"), 57 | nova("nova"), 58 | shimmer("shimmer"); 59 | private String voiceName; 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/parameter/input/OpenaiChatParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter.input; 2 | 3 | import com.ai.openai.endPoint.chat.ResponseFormat; 4 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties; 5 | import com.fasterxml.jackson.annotation.JsonInclude; 6 | import com.fasterxml.jackson.annotation.JsonProperty; 7 | import lombok.*; 8 | 9 | import java.io.Serializable; 10 | import java.util.List; 11 | import java.util.Map; 12 | 13 | @Data 14 | @Builder 15 | @NoArgsConstructor 16 | @AllArgsConstructor 17 | @JsonIgnoreProperties(ignoreUnknown = true) 18 | @JsonInclude(JsonInclude.Include.NON_NULL) 19 | public class OpenaiChatParameter implements Serializable { 
20 | 21 | /** 22 | * 要使用的模型的 ID 23 | */ 24 | @Builder.Default 25 | private String model = OpenaiChatParameter.Model.GPT_3_5_TURBO.getModuleName(); 26 | 27 | /** 28 | * 介于 -2.0 和 2.0 之间的数字,默认值为 0 29 | * 正值会根据新标记在文本中的现有频率来惩罚新标记从而降低模型逐字重复同一行的可能性 30 | */ 31 | @JsonProperty("frequency_penalty") 32 | private double frequencyPenalty; 33 | 34 | /** 35 | * 修改指定标记出现的可能性,默认值为 null 36 | */ 37 | @JsonProperty("logit_bias") 38 | private Map logitBias; 39 | 40 | /** 41 | * 输出字符串限制;0 ~ 4096 42 | */ 43 | @JsonProperty("max_tokens") 44 | private Integer maxTokens; 45 | 46 | /** 47 | * 为每个提示生成的完成次数,默认值为 1 48 | */ 49 | private Integer n; 50 | 51 | /** 52 | * 介于 -2.0 和 2.0 之间的数字,默认值为 0 53 | * 正值会根据新标记到目前为止是否出现在文本中来惩罚它们从而增加模型谈论新主题的可能性 54 | */ 55 | @JsonProperty("presence_penalty") 56 | private double presencePenalty; 57 | 58 | /** 59 | * 指定模型必须输出的格式的对象。 60 | */ 61 | @JsonProperty("response_format") 62 | private ResponseFormat responseFormat; 63 | 64 | private Integer seed; 65 | 66 | /** 67 | * 停止输出标识,默认值为 null 68 | * 最多 4 个序列,API 将停止生成更多令牌 69 | */ 70 | private List stop; 71 | 72 | /** 73 | * 使用什么采样温度,介于 0 和 2 之间,默认值为 1 74 | * 较高的值(如 0.8)将使输出更加随机,而较低的值(如 0.2)将使其更具集中性和确定性 75 | */ 76 | private double temperature; 77 | 78 | /** 79 | * 默认值为 1 80 | * 温度采样的替代方法,称为核采样,其中模型考虑具有top_p概率质量的标记的结果。因此,0.1 表示仅考虑包含前 10% 概率质量的代币。 81 | */ 82 | @JsonProperty("top_p") 83 | private Double topP; 84 | 85 | /** 86 | * 调用标识,避免重复调用 87 | */ 88 | private String user; 89 | 90 | @Getter 91 | @AllArgsConstructor 92 | public enum Model { 93 | GPT_3_5_TURBO("gpt-3.5-turbo"), GPT_4("gpt-4"), GPT_4_32K("gpt-4-32k"), 94 | GPT_4_VISION_PREVIEW("gpt-4-vision-preview"), 95 | ; 96 | private String moduleName; 97 | } 98 | 99 | } 100 | -------------------------------------------------------------------------------- /smartFuse-openai/src/main/java/com/ai/openai/parameter/input/OpenaiEmbeddingParameter.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.parameter.input; 
/**
 * Request settings for OpenAI image generation.
 * <p>
 * Fields left {@code null} are omitted from the JSON because of {@code JsonInclude.Include.NON_NULL}.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class OpenaiImageParameter implements Serializable {

    /**
     * Model used for image generation.
     * <p>
     * Defaults to dall-e-3 (see the initializer below).
     */
    @Builder.Default
    private String model = OpenaiImageParameter.Model.DALL_E_3.getName();

    /**
     * Quality of the generated image. Creates images with finer details and greater consistency.
     */
    private String quality;

    /**
     * Format in which the generated image is returned: url, b64_json
     */
    @JsonProperty("response_format")
    private String responseFormat;

    /**
     * Unique identifier representing the end user
     */
    private String user;

    /**
     * Image generation models
     */
    @Getter
    @AllArgsConstructor
    public enum Model {
        DALL_E_2("dall-e-2"),
        DALL_E_3("dall-e-3"),
        ;
        private final String name;
    }

    /**
     * Quality of the generated image
     */
    @Getter
    @AllArgsConstructor
    public enum Quality {
        STANDARD("standard"),
        HD("hd"),
        ;
        private final String quality;
    }

    /**
     * Style of the generated image
     */
    @Getter
    @AllArgsConstructor
    public enum Style {
        VIVID("vivid"),
        NATURAL("natural"),
        ;
        private final String style;
    }

    /**
     * Return formats for the generated image
     */
    @Getter
    @AllArgsConstructor
    public enum Format {
        URL("url"),
        B64JSON("b64_json"),
        ;
        private final String format;
    }
}
private String model = ModerationRequest.Model.TEXT_MODERATION_LATEST.getName(); 20 | 21 | @Getter 22 | @AllArgsConstructor 23 | public enum Model { 24 | TEXT_MODERATION_STABLE("text-moderation-stable"), 25 | TEXT_MODERATION_LATEST("text-moderation-latest"), 26 | ; 27 | 28 | private final String name; 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /smartFuse-openai/src/test/java/com/ai/openai/chain/ConversationalChainTest.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.chain; 2 | 3 | import com.ai.core.strategy.impl.FirstKeyStrategy; 4 | import com.ai.domain.chain.impl.ConversationalChain; 5 | import com.ai.domain.memory.chat.impl.SimpleChatHistoryRecorder; 6 | import com.ai.openai.achieve.Configuration; 7 | import com.ai.openai.client.OpenAiClient; 8 | import com.ai.openai.model.OpenaiChatModel; 9 | import org.junit.Before; 10 | import org.junit.Test; 11 | 12 | import java.net.InetSocketAddress; 13 | import java.net.Proxy; 14 | import java.util.Arrays; 15 | 16 | /** 17 | * 测试链路功能 18 | **/ 19 | public class ConversationalChainTest { 20 | 21 | private ConversationalChain conversationalChain; 22 | 23 | @Before 24 | public void test_create_conversational_chain() { 25 | // 设置配置信息 26 | Configuration configuration = new Configuration(); 27 | configuration.setApiHost("https://api.openai.com"); 28 | configuration.setKeyList(Arrays.asList("************************")); 29 | configuration.setKeyStrategy(new FirstKeyStrategy()); 30 | configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890))); 31 | OpenAiClient.SetConfiguration(configuration); 32 | // 创建一个记录器,记录器是不能重用的,即不能多个chain使用同一个记录器,否则就相当于多个会话公用同一个历史聊天记录。 33 | // 但是记录器对应的存储器,及记录器当中的ChatMemoryStore是可以重用的,及存在多个记录器使用同一个存储器。 34 | // 可以在创建时指定记录器,也可以直接创建使用默认的记录器,默认存储30条消息。 35 | this.conversationalChain = ConversationalChain.builder() 36 | .chatModel(new OpenaiChatModel()) 37 
| .historyRecorder(SimpleChatHistoryRecorder.builder().build()) 38 | .build(); 39 | } 40 | 41 | @Test 42 | public void test_conversational_chain_run() { 43 | String res1 = conversationalChain.run("你好,请记住我的名字叫做小明"); 44 | System.out.println(res1);// 你好,小明!很高兴认识你。 45 | String res2 = conversationalChain.run("我的名字是什么?"); 46 | System.out.println(res2);// 你的名字是小明。 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /smartFuse-openai/src/test/java/com/ai/openai/chain/ConversationalRetrievalChainTest.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.chain; 2 | 3 | 4 | import com.ai.core.strategy.impl.FirstKeyStrategy; 5 | import com.ai.domain.chain.impl.ConversationalRetrievalChain; 6 | import com.ai.domain.document.Document; 7 | import com.ai.domain.document.FileSystemDocumentLoader; 8 | import com.ai.domain.memory.chat.impl.SimpleChatHistoryRecorder; 9 | import com.ai.domain.memory.embedding.impl.SimpleEmbeddingStoreIngestor; 10 | import com.ai.domain.memory.embedding.impl.SimpleEmbeddingStoreRetriever; 11 | import com.ai.openai.achieve.Configuration; 12 | import com.ai.openai.client.OpenAiClient; 13 | import com.ai.openai.model.OpenaiChatModel; 14 | import com.ai.openai.model.OpenaiEmbeddingModel; 15 | import org.junit.Before; 16 | import org.junit.Test; 17 | 18 | import java.io.File; 19 | import java.net.InetSocketAddress; 20 | import java.net.Proxy; 21 | import java.nio.file.Path; 22 | import java.nio.file.Paths; 23 | import java.util.ArrayList; 24 | import java.util.Arrays; 25 | import java.util.List; 26 | 27 | public class ConversationalRetrievalChainTest { 28 | 29 | private ConversationalRetrievalChain conversationalRetrievalChain; 30 | 31 | public static Path toPath(String fileName) { 32 | File file = new File(fileName); 33 | if (file.exists()) { 34 | try { 35 | return Paths.get(file.toURI()); 36 | } catch (Exception e) { 37 | e.printStackTrace(); 38 
| } 39 | } 40 | return null; 41 | } 42 | 43 | @Before 44 | public void before() { 45 | // 设置配置信息 46 | Configuration configuration = new Configuration(); 47 | configuration.setApiHost("https://api.openai.com"); 48 | configuration.setKeyList(Arrays.asList("************************")); 49 | configuration.setKeyStrategy(new FirstKeyStrategy()); 50 | configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890))); 51 | OpenAiClient.SetConfiguration(configuration); 52 | // 测试文件路径 53 | String[] filePath = {"D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\ConversationalRetrievalChainTest-中文.txt"}; 54 | // "D:\\chatGPT-api\\AI-SmartFuse-Framework\\doc\\test\\document\\ConversationalRetrievalChainTest-英文.txt"}; 55 | // 创建嵌入数据导入器,这里可以设置你指定的存储器,也可以直接使用其中默认的存储器。 56 | SimpleEmbeddingStoreIngestor ingestor = SimpleEmbeddingStoreIngestor.builder().embeddingModel(new OpenaiEmbeddingModel()).build(); 57 | List documents = new ArrayList<>(); 58 | // 导入数据并放入List当中 59 | for (String file : filePath) { 60 | documents.add(FileSystemDocumentLoader.loadDocument(toPath(file))); 61 | } 62 | // 将数据导入到存储器当中 63 | ingestor.ingest(documents); 64 | // 获取存储器,并设置其对应的检索器,向检索器当中设置检索器检索的嵌入存储器。 65 | this.conversationalRetrievalChain = ConversationalRetrievalChain.builder() 66 | .chatModel(new OpenaiChatModel()) 67 | .embeddingModel(new OpenaiEmbeddingModel()) 68 | .historyRecorder(SimpleChatHistoryRecorder.builder().build()) 69 | .retriever(SimpleEmbeddingStoreRetriever.builder().embeddingMemoryStore(ingestor.getStore()).build()) 70 | .build(); 71 | } 72 | 73 | @Test 74 | public void test_embedding_data_retriever_with_en() { 75 | String question = "What kind of person is Little Red Riding Hood?"; 76 | String res = conversationalRetrievalChain.run(question); 77 | System.out.println(res); 78 | } 79 | 80 | @Test 81 | public void test_embedding_data_retriever_with_ch() { 82 | String question = "小红帽要去干什么?"; 83 | String res = 
conversationalRetrievalChain.run(question); 84 | System.out.println(res); 85 | } 86 | 87 | 88 | } 89 | -------------------------------------------------------------------------------- /smartFuse-openai/src/test/java/com/ai/openai/model/ModelTest.java: -------------------------------------------------------------------------------- 1 | package com.ai.openai.model; 2 | 3 | 4 | import com.ai.common.resp.AiResponse; 5 | import com.ai.core.strategy.impl.FirstKeyStrategy; 6 | import com.ai.domain.data.moderation.Moderation; 7 | import com.ai.openai.achieve.Configuration; 8 | import com.ai.openai.client.OpenAiClient; 9 | import org.junit.Before; 10 | import org.junit.Test; 11 | 12 | import java.net.InetSocketAddress; 13 | import java.net.Proxy; 14 | import java.util.ArrayList; 15 | import java.util.Arrays; 16 | import java.util.List; 17 | 18 | public class ModelTest { 19 | 20 | @Before 21 | public void test_model_before() { 22 | // 设置配置信息 23 | Configuration configuration = new Configuration(); 24 | configuration.setApiHost("https://api.openai.com"); 25 | configuration.setKeyList(Arrays.asList("************************")); 26 | configuration.setKeyStrategy(new FirstKeyStrategy()); 27 | configuration.setProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890))); 28 | OpenAiClient.SetConfiguration(configuration); 29 | } 30 | 31 | @Test 32 | public void test_moderation_model() { 33 | OpenaiModerationModel openaiModerationModel = new OpenaiModerationModel(); 34 | ArrayList strings = new ArrayList<>(); 35 | strings.add("你好"); 36 | strings.add("我要杀了你"); 37 | AiResponse> moderate = openaiModerationModel.moderate(strings); 38 | System.out.println(moderate.getData()); 39 | } 40 | 41 | } 42 | -------------------------------------------------------------------------------- /smartFuse-spark/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | AI-SmartFuse-Framework 7 | com.ai 8 | 1.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 
| smartFuse-spark 13 | 14 | 15 | UTF-8 16 | UTF-8 17 | 1.8 18 | 8 19 | 8 20 | 21 | 22 | 23 | 24 | com.ai 25 | smartFuse-common 26 | 1.0-SNAPSHOT 27 | 28 | 29 | com.ai 30 | smartFuse-domain 31 | 1.0-SNAPSHOT 32 | 33 | 34 | junit 35 | junit 36 | 4.13.2 37 | test 38 | 39 | 40 | com.ai 41 | ai-spark 42 | 1.0 43 | 44 | 45 | 46 | 47 | --------------------------------------------------------------------------------