├── .gitignore ├── LICENSE ├── README.md ├── misc ├── docker │ └── elasticsearch │ │ └── docker-compose.yml ├── drawio │ └── docment-comparison.svg └── logo.svg ├── pom.xml └── src ├── main ├── java │ └── io │ │ └── github │ │ └── cupybara │ │ └── javalangchains │ │ ├── chains │ │ ├── Chain.java │ │ ├── ChainLink.java │ │ ├── base │ │ │ ├── ApplyToStreamInputChain.java │ │ │ ├── JoinChain.java │ │ │ ├── StreamUnwrappingChain.java │ │ │ ├── StreamWrappingChain.java │ │ │ └── logging │ │ │ │ └── LoggingChain.java │ │ ├── data │ │ │ ├── reader │ │ │ │ ├── ReadDocumentsFromInMemoryPdfChain.java │ │ │ │ ├── ReadDocumentsFromPdfChain.java │ │ │ │ └── ReadDocumentsFromPdfChainBase.java │ │ │ ├── retrieval │ │ │ │ ├── ElasticsearchRetrievalChain.java │ │ │ │ ├── JdbcRetrievalChain.java │ │ │ │ ├── LuceneRetrievalChain.java │ │ │ │ └── RetrievalChain.java │ │ │ └── writer │ │ │ │ ├── WriteDocumentsToElasticsearchIndexChain.java │ │ │ │ └── WriteDocumentsToLuceneDirectoryChain.java │ │ ├── llm │ │ │ ├── LargeLanguageModelChain.java │ │ │ ├── azure │ │ │ │ ├── chat │ │ │ │ │ └── AzureOpenAiChatCompletionsChain.java │ │ │ │ └── completions │ │ │ │ │ └── AzureOpenAiCompletionsChain.java │ │ │ └── openai │ │ │ │ ├── OpenAiChain.java │ │ │ │ ├── OpenAiParameters.java │ │ │ │ ├── OpenAiResponse.java │ │ │ │ ├── chat │ │ │ │ ├── OpenAiChatCompletionsChain.java │ │ │ │ ├── OpenAiChatCompletionsChoice.java │ │ │ │ ├── OpenAiChatCompletionsParameters.java │ │ │ │ ├── OpenAiChatCompletionsRequest.java │ │ │ │ ├── OpenAiChatCompletionsResponse.java │ │ │ │ └── OpenAiChatMessage.java │ │ │ │ └── completions │ │ │ │ ├── OpenAiCompletionsChain.java │ │ │ │ ├── OpenAiCompletionsChoice.java │ │ │ │ ├── OpenAiCompletionsParameters.java │ │ │ │ ├── OpenAiCompletionsRequest.java │ │ │ │ └── OpenAiCompletionsResponse.java │ │ └── qa │ │ │ ├── AnswerWithSources.java │ │ │ ├── CombineDocumentsChain.java │ │ │ ├── MapAnswerWithSourcesChain.java │ │ │ ├── ModifyDocumentsContentChain.java │ │ │ └── 
split │ │ │ ├── JtokkitTextSplitter.java │ │ │ ├── MaxLengthBasedTextSplitter.java │ │ │ ├── SplitDocumentsChain.java │ │ │ ├── TextSplitter.java │ │ │ └── TextStreamer.java │ │ └── util │ │ ├── PromptConstants.java │ │ └── PromptTemplates.java └── resources │ └── log4j2.xml └── test ├── java └── io │ └── github │ └── cupybara │ └── javalangchains │ ├── chains │ ├── data │ │ ├── read │ │ │ ├── ReadDocumentsFromInMemoryPdfChainTest.java │ │ │ └── ReadDocumentsFromPdfChainTest.java │ │ ├── retrieval │ │ │ ├── DocumentTestUtil.java │ │ │ ├── ElasticsearchRetrievalChainIT.java │ │ │ ├── JdbcRetrievalChainIT.java │ │ │ └── LuceneRetrievalChainTest.java │ │ └── writer │ │ │ └── WriteDocumentsToElasticsearchIndexChainIT.java │ ├── llm │ │ ├── azure │ │ │ └── chat │ │ │ │ └── AzureOpenAiChatCompletionsChainIT.java │ │ └── openai │ │ │ ├── chat │ │ │ └── OpenAiChatCompletionsChainIT.java │ │ │ └── completions │ │ │ └── OpenAiCompletionsChainIT.java │ └── qa │ │ ├── MapAnswerWithSourcesChainTest.java │ │ └── split │ │ ├── SplitDocumentsChainTest.java │ │ └── TextStreamerTest.java │ └── usecases │ ├── DocumentComparisonTest.java │ └── RetrievalQaTest.java └── resources └── pdf ├── comparison ├── galactic-journey-insurance.pdf └── interplanetary-travel-insurance.pdf └── qa ├── book-of-john-1.pdf ├── book-of-john-2.pdf └── book-of-john-3.pdf /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .settings 3 | .project 4 | .classpath 5 | target 6 | /bin/ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) Harrison Chase 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the 
rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ![](misc/logo.svg) 2 | 3 | This repository aims to provide a java alternative to [hwchase17/langchain](https://github.com/hwchase17/langchain). 4 | It was born from the need to create an enterprise QA application. 
5 | 6 | - [Dependency](#dependency) 7 | - [Chains](#chains) 8 | - [Base](#base) 9 | - [Logging](#logging) 10 | - [Data](#data) 11 | - [Reader](#reader) 12 | - [Read Documents from In Memory PDF](#read-documents-from-in-memory-pdf) 13 | - [Read Documents from PDF](#read-documents-from-pdf) 14 | - [Retrieval](#retrieval) 15 | - [Retrieve Documents from Elasticsearch Index](#retrieve-documents-from-elasticsearch-index) 16 | - [Retrieve Documents from Lucene Directory](#retrieve-documents-from-lucene-directory) 17 | - [Retrieve Documents from a relational database](#retrieve-documents-from-rdbms) 18 | - [Writer](#writer) 19 | - [Write Documents to Elasticsearch Index](#write-documents-to-elasticsearch-index) 20 | - [Write Documents to Lucene Directory](#write-documents-to-lucene-directory) 21 | - [LLM](#llm) 22 | - [Azure](#azure) 23 | - [Azure Chat](#azure-chat) 24 | - [Azure Completions](#azure-completions) 25 | - [OpenAI](#openai) 26 | - [OpenAI Chat](#openai-chat) 27 | - [OpenAI Completions](#openai-completions) 28 | - [QA](#qa) 29 | - [Modify Documents](#modify-documents) 30 | - [Combine Documents](#combine-documents) 31 | - [Map LLM results to answers with sources](#map-llm-results-to-answers-with-sources) 32 | - [Split Documents](#split-documents) 33 | - [Usage behind a corporate proxy](#usage-behind-a-corporate-proxy) 34 | - [Use Cases](#use-cases) 35 | - [Document Comparison](#document-comparison) 36 | - [Retrieval Question-Answering Chain](#retrieval-question-answering-chain) 37 | 38 | ## Dependency 39 | java-langchains requires Java 8 or higher 40 | 41 | 42 | To group this repository with other related repositories in the future we lately transferred this repository to the freshly created organization [cupybara](https://github.com/cupybara). 43 | Therefore we changed the package names from *com.github.hakenadu* to *io.github.cupybara* and also changed the groupId. 
44 | The latest artifact is therefore available by using the following dependency: 45 | 46 | ```xml 47 | <dependency> 48 | <groupId>io.github.cupybara</groupId> 49 | <artifactId>java-langchains</artifactId> 50 | <version>0.6.3</version> 51 | </dependency> 52 | ``` 53 | 54 | ### deprecated older dependency 55 | 56 | Packages up to version 0.5.0 are available using the groupId com.github.hakenadu. 57 | These artifacts are no longer updated, so we do not recommend using them. 58 | Please switch to **io.github.cupybara**. 59 | 60 |
61 | old dependency 62 | 63 | ```xml 64 | <dependency> 65 | <groupId>com.github.hakenadu</groupId> 66 | <artifactId>java-langchains</artifactId> 67 | <version>0.5.0</version> 68 | </dependency> 69 | ``` 70 |
71 | 72 | 73 | ## Chains 74 | Modular components implement the [Chain](src/main/java/io/github/cupybara/javalangchains/chains/Chain.java) interface. 75 | This provides an easy way to modularize the application and enables us to reuse them for various use cases. 76 | 77 | This section describes the usage of all chains that are currently available. 78 | 79 | ### Base 80 | 81 | #### Logging 82 | The [LoggingChain](src/main/java/io/github/cupybara/javalangchains/chains/base/logging/LoggingChain.java) can be used to log the previous chain's output. 83 | Take a look at the [RetrievalQaTest](src/test/java/io/github/cupybara/javalangchains/usecases/RetrievalQaTest.java) to see some example usages (logging chains indented for improved readability): 84 | 85 | ```java 86 | final Chain qaChain = retrievalChain 87 | .chain(summarizeDocumentsChain) 88 | .chain(new ApplyToStreamInputChain<>(new LoggingChain<>(LoggingChain.defaultLogPrefix("SUMMARIZED_DOCUMENT")))) 89 | .chain(combineDocumentsChain) 90 | .chain(new LoggingChain<>(LoggingChain.defaultLogPrefix("COMBINED_DOCUMENT"))) 91 | .chain(openAiChatChain) 92 | .chain(new LoggingChain<>(LoggingChain.defaultLogPrefix("LLM_RESULT"))) 93 | .chain(mapAnswerWithSourcesChain); 94 | ``` 95 | 96 | The summarizeDocumentsChain in this example provides a Stream as an output. To log each item of the Stream the LoggingChain can be wrapped in an 97 | [ApplyToStreamInputChain](src/main/java/io/github/cupybara/javalangchains/chains/base/ApplyToStreamInputChain.java). 
98 | 99 | This example provides the following log output running the RetrievalQaTest: 100 | 101 | ``` 102 | ======================================================================================================================================================== 103 | SUMMARIZED_DOCUMENT 104 | ======================================================================================================================================================== 105 | {source=book-of-john-1.pdf, question=who is john doe?, content=John Doe is a highly skilled and experienced software engineer with a passion for problem-solving and creating innovative solutions. He has been working in the technology industry for over 15 years and has gained a reputation for his exceptional programming abilities and attention to detail.} 106 | 107 | ======================================================================================================================================================== 108 | SUMMARIZED_DOCUMENT 109 | ======================================================================================================================================================== 110 | {source=book-of-john-3.pdf, question=who is john doe?, content=John Doe is described as someone with a diverse range of hobbies and interests. Some of his notable hobbies include music production, culinary adventures, photography and travel, fitness and outdoor activities, and being a book club enthusiast. 
He is also involved in volunteering and community service, language learning, gardening, DIY projects, and astronomy.} 111 | 112 | ======================================================================================================================================================== 113 | COMBINED_DOCUMENT 114 | ======================================================================================================================================================== 115 | {question=who is john doe?, content=Content: John Doe is described as someone with a diverse range of hobbies and interests. Some of his notable hobbies include music production, culinary adventures, photography and travel, fitness and outdoor activities, and being a book club enthusiast. He is also involved in volunteering and community service, language learning, gardening, DIY projects, and astronomy. 116 | Source: book-of-john-3.pdf 117 | 118 | Content: John Doe is a highly skilled and experienced software engineer with a passion for problem-solving and creating innovative solutions. He has been working in the technology industry for over 15 years and has gained a reputation for his exceptional programming abilities and attention to detail. 119 | Source: book-of-john-1.pdf} 120 | 121 | ======================================================================================================================================================== 122 | LLM_RESULT 123 | ======================================================================================================================================================== 124 | John Doe is described as someone with a diverse range of hobbies and interests, including music production, culinary adventures, photography, travel, fitness, outdoor activities, being a book club enthusiast, volunteering, community service, language learning, gardening, DIY projects, and astronomy. 
Additionally, John Doe is a highly skilled and experienced software engineer with a passion for problem-solving and creating innovative solutions. He has been working in the technology industry for over 15 years and is known for his exceptional programming abilities and attention to detail. 125 | SOURCES: book-of-john-3.pdf, book-of-john-1.pdf 126 | ``` 127 | 128 | ### Data 129 | 130 | #### Reader 131 | 132 | ##### Read Documents from In Memory PDF 133 | See [ReadDocumentsFromInMemoryPdfChainTest](src/test/java/io/github/cupybara/javalangchains/chains/data/read/ReadDocumentsFromInMemoryPdfChainTest.java) 134 | 135 | Read the in memory pdf into a single document 136 | 137 | ```java 138 | InMemoryPdf inMemoryPdf = new InMemoryPdf( 139 | IOUtils.toByteArray(ReadDocumentsFromInMemoryPdfChainTest.class.getResourceAsStream("/pdf/qa/book-of-john-3.pdf")), 140 | "my-in-memory.pdf"); 141 | 142 | Stream> readDocuments = new ReadDocumentsFromInMemoryPdfChain().run(inMemoryPdf) 143 | 144 | // the readDocuments contains a (pdfContent, "my-in-memory.pdf") pair 145 | ``` 146 | 147 | Read documents for each page of the in memory pdf 148 | 149 | ```java 150 | InMemoryPdf inMemoryPdf = new InMemoryPdf( 151 | IOUtils.toByteArray(ReadDocumentsFromInMemoryPdfChainTest.class.getResourceAsStream("/pdf/qa/book-of-john-3.pdf")), 152 | "my-in-memory.pdf"); 153 | 154 | Stream> readDocuments = new ReadDocumentsFromInMemoryPdfChain(PdfReadMode.PAGES).run(inMemoryPdf) 155 | 156 | // the readDocuments contains (content, source) pairs for all read pdf pages (source is "my-in-memory.pdf" + the pdf page number) 157 | ``` 158 | 159 | ##### Read Documents from PDF 160 | See [ReadDocumentsFromPdfChainTest](src/test/java/io/github/cupybara/javalangchains/chains/data/read/ReadDocumentsFromPdfChainTest.java) 161 | 162 | Read each pdf in the given directory into a single document each 163 | 164 | ```java 165 | Stream> readDocuments = new ReadDocumentsFromPdfChain() 166 | 
.run(Paths.get("path/to/my/pdf/folder")) 167 | 168 | // the readDocuments contains (content, source) pairs for all read pdfs (source is the pdf filename) 169 | ``` 170 | 171 | Read each page of each pdf in the given directory into a single document each 172 | 173 | ```java 174 | Stream> readDocuments = new ReadDocumentsFromPdfChain(PdfReadMode.PAGES) 175 | .run(Paths.get("path/to/my/pdf/folder")) 176 | 177 | // the readDocuments contains (content, source) pairs for all read pdf pages (source is the pdf filename + the pdf page number) 178 | ``` 179 | 180 | #### Retrieval 181 | 182 | ##### Retrieve Documents from Elasticsearch Index 183 | See [ElasticsearchRetrievalChainIT](src/test/java/io/github/cupybara/javalangchains/chains/data/retrieval/ElasticsearchRetrievalChainIT.java) 184 | 185 | ```java 186 | RestClientBuilder restClientBuilder = RestClient.builder(new HttpHost("localhost", 9200)); 187 | 188 | Chain createElasticsearchIndexChain = new ReadDocumentsFromPdfChain() 189 | .chain(new WriteDocumentsToElasticsearchIndexChain("my-index", restClientBuilder)); 190 | 191 | Path pdfDirectoryPath = Paths.get(ElasticsearchRetrievalChainTest.class.getResource("/pdf/qa").toURI()); 192 | 193 | // create and fill elasticsearch index with read pdfs (source, content)-pairs 194 | createElasticsearchIndexChain.run(pdfDirectoryPath); 195 | 196 | // retrieve documents relevant to a specific question 197 | try (RestClient restClient = restClientBuilder.build(); 198 | ElasticsearchRetrievalChain retrievalChain = new ElasticsearchRetrievalChain("my-index", restClient, 1)) { 199 | 200 | // retrieve the most relevant documents for the passed question 201 | Stream> retrievedDocuments = retrievalChain.run("who is john doe?").collect(Collectors.toList()); 202 | 203 | // ... 
204 | } 205 | ``` 206 | 207 | ##### Retrieve Documents from Lucene Directory 208 | See [LuceneRetrievalChainTest](src/test/java/io/github/cupybara/javalangchains/chains/data/retrieval/LuceneRetrievalChainTest.java) 209 | 210 | ```java 211 | // create lucene index 212 | Directory directory = new MMapDirectory(Files.createTempDirectory("myTempDir")); 213 | 214 | // fill lucene index 215 | try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig(new StandardAnalyzer()))) { 216 | List documents = Arrays.asList("My first document", "My second document", "My third document"); 217 | 218 | for (String content : documents) { 219 | Document doc = new Document(); 220 | doc.add(new TextField(PromptConstants.CONTENT, content, Field.Store.YES)); 221 | doc.add(new StringField(PromptConstants.SOURCE, String.valueOf(documents.indexOf(content) + 1), Field.Store.YES)); 222 | indexWriter.addDocument(doc); 223 | } 224 | 225 | indexWriter.commit(); 226 | } 227 | 228 | // create retrieval chain 229 | RetrievalChain retrievalChain = new LuceneRetrievalChain(directory, 2 /* max count of retrieved documents */); 230 | 231 | // retrieve the most relevant documents for the passed question 232 | Stream> retrievedDocuments = retrievalChain.run("my question?"); 233 | ``` 234 | 235 | 236 | ##### Retrieve Documents from RDBMS 237 | See [JdbcRetrievalChainIT](src/test/java/io/github/cupybara/javalangchains/chains/data/retrieval/JdbcRetrievalChainIT.java) 238 | 239 | ```java 240 | Supplier connectionSupplier = () -> { 241 | try { 242 | return DriverManager.getConnection(connectionString, username, password); 243 | } catch (SQLException e) { 244 | throw new IllegalStateException("error creating database connection", e); 245 | } 246 | }; 247 | 248 | RetrievalChain retrievalChain = new JdbcRetrievalChain(connectionSupplier, 2 /* max count of retrieved documents */); 249 | 250 | Stream> retrievedDocuments = retrievalChain.run("my question?"); 251 | ``` 252 | 253 | #### Writer 254 
| 255 | ##### Write Documents to Elasticsearch Index 256 | ```java 257 | RestClientBuilder restClientBuilder = RestClient.builder(new HttpHost("localhost", 9200)); 258 | 259 | // this chain reads documents from a folder of pdfs and writes them to an elasticsearch index 260 | Chain fillElasticsearchIndexChain = new ReadDocumentsFromPdfChain() 261 | .chain(new WriteDocumentsToElasticsearchIndexChain("my-index", restClientBuilder)); 262 | 263 | Path pdfDirectoryPath = Paths.get(getClass().getResource("/pdf/qa").toURI()); 264 | 265 | fillElasticsearchIndexChain.run(pdfDirectoryPath); 266 | ``` 267 | 268 | ##### Write Documents to Lucene Directory 269 | ```java 270 | Path tempIndexPath = Files.createTempDirectory("lucene") 271 | 272 | // this chain reads documents from a folder of pdfs and writes them to an index directory 273 | Chain createLuceneIndexChain = new ReadDocumentsFromPdfChain() 274 | .chain(new WriteDocumentsToLuceneDirectoryChain(tempIndexPath)); 275 | 276 | Path pdfDirectoryPath = Paths.get(getClass().getResource("/pdf/qa").toURI()); 277 | 278 | Directory directory = createLuceneIndexChain.run(pdfDirectoryPath); 279 | ``` 280 | 281 | ### LLM 282 | 283 | #### Azure 284 | 285 | ##### Azure Chat 286 | See [AzureOpenAiChatCompletionsChainIT](src/test/java/io/github/cupybara/javalangchains/chains/llm/azure/chat/AzureOpenAiChatCompletionsChainIT.java) 287 | 288 | ```java 289 | AzureOpenAiChatCompletionsChain chain = new AzureOpenAiChatCompletionsChain( 290 | "my-azure-resource-name", 291 | "gpt-35-turbo", // deployment name 292 | "2023-05-15", // api version 293 | "Hello, this is ${name}", 294 | new OpenAiChatCompletionsParameters().temperature(0D), // also allows to set more parameters 295 | System.getenv("OPENAI_API_KEY"), 296 | "You are a helpful assistant who answers questions to ${name}" // optional systemTemplate 297 | ); 298 | 299 | String result = chain.run(Collections.singletonMap("name", "Manuel")); 300 | // the above outputs something like: "Hello 
Manuel, how are you" 301 | ``` 302 | 303 | ##### Azure Completions 304 | ```java 305 | AzureOpenAiCompletionsChain chain = new AzureOpenAiCompletionsChain( 306 | "my-azure-resource-name", 307 | "text-davinci-003", // deployment name 308 | "2023-05-15", // api version 309 | "Hello, this is ${name}", 310 | new OpenAiCompletionsParameters().temperature(0D), // also allows to set more parameters 311 | System.getenv("OPENAI_API_KEY"), 312 | "You are a helpful assistant who answers questions to ${name}" // optional systemTemplate 313 | ); 314 | 315 | String result = chain.run(Collections.singletonMap("name", "Manuel")); 316 | // the above outputs something like: "Hello Manuel, how are you" 317 | ``` 318 | 319 | #### OpenAI 320 | 321 | ##### OpenAI Chat 322 | See [OpenAiChatCompletionsChainIT](src/test/java/io/github/cupybara/javalangchains/chains/llm/openai/chat/OpenAiChatCompletionsChainIT.java) 323 | 324 | ```java 325 | OpenAiChatCompletionsChain chain = new OpenAiChatCompletionsChain( 326 | "Hello, this is ${name}", 327 | new OpenAiChatCompletionsParameters().model("gpt-3.5-turbo").temperature(0D), // also allows to set more parameters 328 | System.getenv("OPENAI_API_KEY"), 329 | "You are a helpful assistant who answers questions to ${name}" // optional systemTemplate 330 | ); 331 | 332 | String result = chain.run(Collections.singletonMap("name", "Manuel")); 333 | // the above outputs something like: "Hello Manuel, how are you" 334 | ``` 335 | 336 | #### OpenAI Completions 337 | ```java 338 | OpenAiCompletionsChain chain = new OpenAiCompletionsChain( 339 | "Hello, this is ${name}", 340 | new OpenAiCompletionsParameters().model("text-davinci-003").temperature(0D), // also allows to set more parameters 341 | System.getenv("OPENAI_API_KEY"), 342 | "You are a helpful assistant who answers questions to ${name}" // optional systemTemplate 343 | ); 344 | 345 | String result = chain.run(Collections.singletonMap("name", "Manuel")); 346 | // the above outputs something like: 
"Hello Manuel, how are you" 347 | ``` 348 | 349 | ### QA 350 | 351 | #### Modify Documents 352 | The ModifyDocumentsContentChain can be used for document summarization (for example). 353 | 354 | ```java 355 | // create the llm chain which is used for summarization 356 | LargeLanguageModelChain llmChain = new OpenAiChatCompletionsChain( 357 | PromptTemplates.QA_SUMMARIZE, 358 | new OpenAiChatCompletionsParameters().temperature(0D).model("gpt-3.5-turbo"), 359 | System.getenv("OPENAI_API_KEY")); 360 | 361 | // create the ModifyDocumentsContentChain which is used to apply the llm chain to each passed document 362 | ModifyDocumentsContentChain summarizeDocumentsChain = new ModifyDocumentsContentChain(llmChain); 363 | 364 | // create some example documents 365 | Map myFirstDocument = new HashMap(); 366 | myFirstDocument.put(PromptConstants.CONTENT, "this is my first document content"); 367 | myFirstDocument.put(PromptConstants.SOURCE, "this is my first document source"); 368 | // the default summarize prompt PromptTemplates.QA_SUMMARIZE also expects the question used for retrieval in the document 369 | myFirstDocument.put(PromptConstants.QUESTION, "who is John Doe?"); 370 | 371 | Map mySecondDocument = new HashMap(); 372 | mySecondDocument.put(PromptConstants.CONTENT, "this is my second document content"); 373 | mySecondDocument.put(PromptConstants.SOURCE, "this is my second document source"); 374 | mySecondDocument.put(PromptConstants.QUESTION, "how old is John Doe?"); // see comment above 375 | 376 | // input for the summarize chain is a stream of documents 377 | Stream> documents = Stream.of(myFirstDocument, mySecondDocument); 378 | 379 | // output contains the passed documents with summarized content-Value 380 | Stream> summarizedDocuments = summarizeDocumentsChain.run(documents); 381 | ``` 382 | 383 | #### Combine Documents 384 | ```java 385 | CombineDocumentsChain combineDocumentsChain = new CombineDocumentsChain(); 386 | 387 | Map myFirstDocument = new HashMap(); 
388 | myFirstDocument.put(PromptConstants.CONTENT, "this is my first document content"); 389 | myFirstDocument.put(PromptConstants.SOURCE, "this is my first document source"); 390 | 391 | Map mySecondDocument = new HashMap(); 392 | mySecondDocument.put(PromptConstants.CONTENT, "this is my second document content"); 393 | mySecondDocument.put(PromptConstants.SOURCE, "this is my second document source"); 394 | 395 | Stream> documents = Stream.of(myFirstDocument, mySecondDocument); 396 | 397 | Map combinedDocument = combineDocumentsChain.run(documents); 398 | /* 399 | * Content: this is my first document content 400 | * Source: this is my first document source 401 | * 402 | * Content: this is my second document content 403 | * Source: this is my second document source 404 | * 405 | * (stored with key "content" inside the map) 406 | */ 407 | ``` 408 | 409 | #### Map LLM results to answers with sources 410 | ```java 411 | MapAnswerWithSourcesChain mapAnswerWithSourcesChain = new MapAnswerWithSourcesChain(); 412 | 413 | AnswerWithSources answerWithSources = mapAnswerWithSourcesChain.run("The answer is bla bla bla.\nSOURCES: page 1 book xy, page 2 book ab"); 414 | 415 | System.out.println(answerWithSources.getAnswer()); // The answer is bla bla bla. 416 | System.out.println(answerWithSources.getSources()); // [page 1 book xy, page 2 book ab] 417 | 418 | ``` 419 | 420 | #### Split Documents 421 | See [SplitDocumentsChainTest](src/test/java/io/github/cupybara/javalangchains/chains/qa/split/SplitDocumentsChainTest.java) 422 | 423 | ```java 424 | 425 | // 1. Create Documents 426 | 427 | List> documents = new LinkedList<>(); 428 | 429 | Map firstDocument = new LinkedHashMap<>(); 430 | firstDocument.put(PromptConstants.SOURCE, "book of john"); 431 | firstDocument.put(PromptConstants.CONTENT, "This is a short text. 
This is another short text."); 432 | documents.add(firstDocument); 433 | 434 | Map secondDocument = new LinkedHashMap<>(); 435 | secondDocument.put(PromptConstants.SOURCE, "book of jane"); 436 | secondDocument.put(PromptConstants.CONTENT, "This is a short text."); 437 | documents.add(secondDocument); 438 | 439 | // 2. Split Documents 440 | 441 | /* 442 | * We create a TextSplitter that splits a text into partitions using a JTokkit 443 | * Encoding. We use the cl100k_base encoding (which btw is the default for 444 | * gpt-3.5-turbo) 445 | */ 446 | TextSplitter textSplitter = new JtokkitTextSplitter( 447 | Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE), 10); 448 | 449 | /* 450 | * we now instantiate the SplitDocumentsChain which will split our documents 451 | * using the above created TextSplitter on the "content" field. 452 | */ 453 | SplitDocumentsChain splitDocumentsChain = new SplitDocumentsChain(textSplitter); 454 | 455 | List> splitDocuments = splitDocumentsChain.run(documents.stream()) 456 | .collect(Collectors.toList()); 457 | 458 | // splitDocuments: [ 459 | // {content=This is a short text. , source=book of john}, 460 | // {content=This is another short text., source=book of john}, 461 | // {content=This is a short text., source=book of jane} 462 | // ] 463 | ``` 464 | 465 | ## Usage behind a corporate proxy 466 | If a chain needs to access to an external service, there will be a constructor parameter for passing the http client. 
467 | The [WebClient](https://docs.spring.io/spring-framework/reference/web/webflux-webclient.html) is used for the following chains: 468 | * [AzureOpenAiChatCompletionsChain](src/main/java/io/github/cupybara/javalangchains/chains/llm/azure/chat/AzureOpenAiChatCompletionsChain.java) 469 | * [AzureOpenAiCompletionsChain](src/main/java/io/github/cupybara/javalangchains/chains/llm/azure/completions/AzureOpenAiCompletionsChain.java) 470 | * [OpenAiChatCompletionsChain](src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/chat/OpenAiChatCompletionsChain.java) 471 | * [OpenAiCompletionsChain](src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/completions/OpenAiCompletionsChain.java) 472 | 473 | Plenty of public documentation exists on how to configure an HTTP proxy for those cases. 474 | One example is [this one from Baeldung](https://www.baeldung.com/spring-webflux-timeout). 475 | 476 | For accessing an Elasticsearch cluster the [Elasticsearch Low Level Client](https://www.elastic.co/guide/en/elasticsearch/client/java-api-client/current/java-rest-low.html) is used. 477 | The [official documentation](https://www.elastic.co/guide/en/elasticsearch/client/java-api-client/8.8/java-rest-low-usage-initialization.html) shows how to use a proxy in this case. 478 | 479 | ## Use Cases 480 | Multiple chains can be chained together to create more powerful chains for complex use cases. 481 | 482 | ### Document Comparison 483 | The [following unit test](src/test/java/io/github/cupybara/javalangchains/usecases/DocumentComparisonTest.java) shows how the existing chains may be used to compare 2 or more documents. 484 | More abstraction would be useful here. I will target that in one of the next releases and then also include example code in this README.
485 | 486 | The following diagram shows how the implementation for this use case works: 487 | 488 | ![](misc/drawio/docment-comparison.svg) 489 | 490 | 491 | ### Retrieval Question-Answering Chain 492 | The [following unit test](src/test/java/io/github/cupybara/javalangchains/usecases/RetrievalQaTest.java) provides a comprehensive solution for an information retrieval and summarization task, with the aim to provide concise, informative and relevant answers from a large set of documents. It combines multiple processes into a Question-Answering (QA) chain, each responsible for a specific task. 493 | 494 | ```java 495 | /* 496 | * take a look at src/test/resources/pdf of this repository 497 | * the pdf directory contains three documents about a fictional person named john doe 498 | * which we want to query using our retrieval based qa with sources chain 499 | */ 500 | Path pdfDirectoryPath = Paths.get(RetrievalQaTest.class.getResource("/pdf/qa").toURI()); 501 | 502 | /* 503 | * We are creating and running an initializing chain which reads documents from our pdf folder 504 | * and writes them to a lucene index directory 505 | */ 506 | Directory directory = new ReadDocumentsFromPdfChain().chain(new WriteDocumentsToLuceneDirectoryChain()).run(pdfDirectoryPath); 507 | 508 | // we use multiple OpenAI LLM chains, so we define their shared parameters first 509 | OpenAiChatCompletionsParameters openAiChatParameters = new OpenAiChatCompletionsParameters() 510 | .temperature(0D) 511 | .model("gpt-3.5-turbo"); 512 | 513 | /* 514 | * Chain 1: The retrievalChain is used to retrieve relevant documents from an 515 | * index by using bm25 similarity 516 | */ 517 | try (LuceneRetrievalChain retrievalChain = new LuceneRetrievalChain(directory /* implies a filled lucene directory */, 2)) { 518 | 519 | /* 520 | * Chain 2: The summarizeDocumentsChain is used to summarize documents to only 521 | * contain the most relevant information. 
This is achieved using an OpenAI LLM 522 | * (gpt-3.5-turbo in this case) 523 | */ 524 | ModifyDocumentsContentChain summarizeDocumentsChain = new ModifyDocumentsContentChain(new OpenAiChatCompletionsChain( 525 | PromptTemplates.QA_SUMMARIZE, openAiChatParameters, System.getenv("OPENAI_API_KEY"))); 526 | 527 | /* 528 | * Chain 3: The combineDocumentsChain is used to combine the retrieved documents 529 | * in a single prompt 530 | */ 531 | CombineDocumentsChain combineDocumentsChain = new CombineDocumentsChain(); 532 | 533 | /* 534 | * Chain 4: The openAiChatChain is used to process the combined prompt using an 535 | * OpenAI LLM (gpt-3.5-turbo in this case) 536 | */ 537 | OpenAiChatCompletionsChain openAiChatChain = new OpenAiChatCompletionsChain(PromptTemplates.QA_COMBINE, 538 | openAiChatParameters, System.getenv("OPENAI_API_KEY")); 539 | 540 | /* 541 | * Chain 5: The mapAnswerWithSourcesChain is used to map the llm string output 542 | * to a complex object using a regular expression which splits the sources and 543 | * the answer. 544 | */ 545 | MapAnswerWithSourcesChain mapAnswerWithSourcesChain = new MapAnswerWithSourcesChain(); 546 | 547 | // we combine all chain links into a self contained QA chain 548 | Chain qaChain = retrievalChain 549 | .chain(summarizeDocumentsChain) 550 | .chain(combineDocumentsChain) 551 | .chain(openAiChatChain) 552 | .chain(mapAnswerWithSourcesChain); 553 | 554 | // the QA chain can now be called with a question and delivers an answer 555 | AnswerWithSources answerWithSources = qaChain.run("who is john doe?"); 556 | 557 | /* 558 | * answerWithSources.getAnswer() provides the answer to the question based on the retrieved documents 559 | * answerWithSources.getSources() provides a list of source strings for the retrieved documents 560 | */ 561 | } 562 | ``` 563 | 564 | The QA chain performs the following tasks: 565 | 566 | 1. 
**Document Retrieval**: This step is responsible for retrieving the most relevant documents related to a given query from a large collection. It uses an index-based search algorithm to find documents containing information related to the input query. This functionality can be facilitated by any `RetrievalChain` implementation. `LuceneRetrievalChain`, which utilizes the BM25 similarity metric, is just an example used in the test case. 567 | 568 | 2. **Document Summarization**: Once relevant documents are retrieved, they need to be summarized to extract the most essential information. The `ModifyDocumentsContentChain` (the `summarizeDocumentsChain` variable in the example) uses an instance of `LargeLanguageModelChain` for this task. In the provided example, OpenAI's GPT-3.5-turbo model via `OpenAiChatCompletionsChain` is used to reduce the information to its most relevant content. 569 | 570 | 3. **Document Combination**: The `CombineDocumentsChain` combines the summarized documents into a single prompt. This forms the input to the next stage of the process. 571 | 572 | 4. **Answer Generation**: The `OpenAiChatCompletionsChain` uses the combined prompt to generate a response. Any instance of `LargeLanguageModelChain` can be used for this step. In the given example, OpenAI's GPT-3.5-turbo model is utilized. 573 | 574 | 5. **Mapping and Answer Extraction**: Finally, the `MapAnswerWithSourcesChain` maps the string output to a complex object using a regular expression, which splits the answer from the sources of information. This provides a structured output that includes both the answer to the query and the sources from which the answer was derived. 575 | 576 | In conclusion, the QA chain represents a comprehensive solution for document-based question-answering tasks, providing not only the most relevant answer but also citing the sources from which the information was retrieved. This chain is particularly useful in contexts where understanding the origin of information is as crucial as the answer itself.
-------------------------------------------------------------------------------- /misc/docker/elasticsearch/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.2" 2 | services: 3 | elasticsearch: 4 | image: docker.elastic.co/elasticsearch/elasticsearch:8.8.1 5 | container_name: elasticsearch 6 | environment: 7 | - discovery.type=single-node 8 | - ES_JAVA_OPTS=-Xms750m -Xmx750m 9 | - xpack.security.enabled=false 10 | ports: 11 | - 9200:9200 12 | -------------------------------------------------------------------------------- /misc/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | Java.langchains 4 | 5 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | io.github.cupybara 6 | java-langchains 7 | 0.6.4-SNAPSHOT 8 | 9 | Java.langchains 10 | 11 | A library which provides components to build LLM based applications in 12 | Java 13 | 14 | https://github.com/cupybara/java-langchains 15 | 16 | 17 | 18 | MIT License 19 | http://www.opensource.org/licenses/mit-license.php 20 | 21 | 22 | 23 | 24 | 25 | mseiche 26 | Manuel Seiche 27 | hakenadu91@gmail.com 28 | https://mseiche.de 29 | 30 | architect 31 | developer 32 | 33 | Europe/Berlin 34 | 35 | 36 | 37 | 38 | scm:git:git://github.com/cupybara/java-langchains.git 39 | 40 | scm:git:ssh://github.com:cupybara/java-langchains.git 41 | 42 | https://github.com/cupybara/java-langchains/tree/master 43 | v0.6.2 44 | 45 | 46 | 47 | 48 | ossrh 49 | https://s01.oss.sonatype.org/content/repositories/snapshots 50 | 51 | 52 | ossrh 53 | https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ 54 | 55 | 56 | 57 | 58 | 1.8 59 | ${java.version} 60 | ${java.version} 61 | 62 | UTF-8 63 | ${encoding} 64 | ${encoding} 65 | 66 | 1.10.0 67 | 2.15.1 68 | 0.5.0 69 | 5.9.3 70 | 2.20.0 
71 | 8.10.1 72 | 8.8.1 73 | 3.1.0 74 | 3.1.0 75 | 3.5.0 76 | 3.0.0 77 | 2.0.1 78 | 3.3.0 79 | 3.1.0 80 | 1.6.13 81 | 3.0.0-RC1 82 | 42.6.0 83 | 1.1.7 84 | 5.3.27 85 | 86 | 87 | 88 | 89 | release-sign-artifacts 90 | 91 | 92 | performRelease 93 | true 94 | 95 | 96 | 97 | 98 | 99 | org.apache.maven.plugins 100 | maven-gpg-plugin 101 | ${maven.gpg.plugin.version} 102 | 103 | 104 | sign-artifacts 105 | verify 106 | 107 | sign 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | maven-deploy-plugin 121 | ${maven.deploy.plugin.version} 122 | 123 | 124 | default-deploy 125 | deploy 126 | 127 | deploy 128 | 129 | 130 | 131 | 132 | 133 | org.apache.maven.plugins 134 | maven-release-plugin 135 | ${maven.release.plugin.version} 136 | 137 | true 138 | 139 | pom.xml 140 | 141 | false 142 | forked-path 143 | -Dgpg.passphrase=${gpg.passphrase} 144 | -Dmaven.test.skipTests=true 145 | 146 | 147 | 148 | org.apache.maven.scm 149 | maven-scm-provider-gitexe 150 | ${maven.scm.provider.gitexe.version} 151 | 152 | 153 | 154 | 155 | org.sonatype.plugins 156 | nexus-staging-maven-plugin 157 | ${nexus.staging.maven.plugin.version} 158 | true 159 | 160 | ossrh 161 | https://s01.oss.sonatype.org/ 162 | true 163 | 164 | 165 | 166 | org.apache.maven.plugins 167 | maven-source-plugin 168 | ${maven.source.plugin.version} 169 | 170 | 171 | attach-sources 172 | 173 | jar 174 | 175 | 176 | 177 | 178 | 179 | org.apache.maven.plugins 180 | maven-javadoc-plugin 181 | ${maven.javadoc.plugin.version} 182 | 183 | 184 | attach-javadocs 185 | 186 | jar 187 | 188 | 189 | 190 | 191 | 192 | org.apache.maven.plugins 193 | maven-surefire-plugin 194 | ${maven.surefire.plugin.version} 195 | 196 | 197 | 198 | 199 | 200 | 201 | org.apache.logging.log4j 202 | log4j-api 203 | ${log4j.version} 204 | 205 | 206 | org.apache.logging.log4j 207 | log4j-core 208 | ${log4j.version} 209 | 210 | 211 | org.apache.commons 212 | commons-text 213 | ${commons.text.version} 214 | 215 | 216 | 
com.fasterxml.jackson.core 217 | jackson-databind 218 | ${jackson.version} 219 | 220 | 221 | com.fasterxml.jackson.core 222 | jackson-annotations 223 | ${jackson.version} 224 | 225 | 226 | org.springframework 227 | spring-webflux 228 | ${webflux.version} 229 | 230 | 231 | io.projectreactor.netty 232 | reactor-netty 233 | ${reactor.version} 234 | 235 | 236 | 237 | 238 | org.apache.lucene 239 | lucene-core 240 | ${lucene.version} 241 | 242 | 243 | org.apache.lucene 244 | lucene-queryparser 245 | ${lucene.version} 246 | 247 | 248 | org.elasticsearch.client 249 | elasticsearch-rest-client 250 | ${elasticsearch.version} 251 | 252 | 253 | org.apache.pdfbox 254 | pdfbox 255 | ${pdfbox.version} 256 | 257 | 258 | com.knuddels 259 | jtokkit 260 | ${jtokkit.version} 261 | 262 | 263 | 264 | 265 | org.junit.jupiter 266 | junit-jupiter-api 267 | ${junit.jupiter.version} 268 | test 269 | 270 | 271 | org.postgresql 272 | postgresql 273 | ${postgresql.version} 274 | test 275 | 276 | 277 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/Chain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains; 2 | 3 | /** 4 | * Basic interface for all modular components in this repository. A 5 | * {@link Chain} accepts an input of type *I* and provides an output of type 6 | * *O*. Using the method {@link #chain(Chain)} passing another {@link Chain}, a 7 | * new {@link Chain} can be created which accepts the original chain's input and 8 | * provided the new chain's output. 
9 | * 10 | * @param the chain input type 11 | * @param the chain output type 12 | */ 13 | @FunctionalInterface 14 | public interface Chain { 15 | 16 | /** 17 | * Execute this {@link Chain} 18 | * 19 | * @param input this chain's input 20 | * @return this chain's output 21 | */ 22 | O run(I input); 23 | 24 | /** 25 | * create a new {@link Chain} connecting this instance with another passed one. 26 | * 27 | * @param type of the next chain's output 28 | * @param next the next chain which is attached to this instance 29 | * @return a new {@link Chain} consisting of the original {@link Chain} and the 30 | * passed one 31 | */ 32 | default ChainLink chain(final Chain next) { 33 | return new ChainLink<>(this, next); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/ChainLink.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains; 2 | 3 | /** 4 | * A Link Between Worlds ;-) 5 | * 6 | * @param type of the input chain's input 7 | * @param type of the input chain's output and the output chain's input 8 | * @param type of the output chain's output 9 | */ 10 | public final class ChainLink implements Chain { 11 | 12 | private final Chain inputChain; 13 | private final Chain outputChain; 14 | 15 | /** 16 | * @param inputChain {@link #inputChain} 17 | * @param outputChain {@link #outputChain} 18 | */ 19 | ChainLink(final Chain inputChain, final Chain outputChain) { 20 | this.inputChain = inputChain; 21 | this.outputChain = outputChain; 22 | } 23 | 24 | @Override 25 | public O run(final I input) { 26 | final M intermediateOutput = inputChain.run(input); 27 | 28 | final O output = outputChain.run(intermediateOutput); 29 | 30 | return output; 31 | } 32 | 33 | /** 34 | * @return true if this {@link ChainLink} is the first one of the whole chain 35 | */ 36 | public boolean isHead() { 37 | return 
!(inputChain instanceof ChainLink); 38 | } 39 | 40 | /** 41 | * @return true if this {@link ChainLink} is the final one of the whole chain 42 | */ 43 | public boolean isTail() { 44 | return !(outputChain instanceof ChainLink); 45 | } 46 | 47 | /** 48 | * @return {@link #inputChain} 49 | */ 50 | public Chain getInputChain() { 51 | return inputChain; 52 | } 53 | 54 | /** 55 | * @return {@link #outputChain} 56 | */ 57 | public Chain getOutputChain() { 58 | return outputChain; 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/base/ApplyToStreamInputChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.base; 2 | 3 | import java.util.stream.Stream; 4 | 5 | import io.github.cupybara.javalangchains.chains.Chain; 6 | 7 | /** 8 | * this chain applies another chain (which is passed as a constructor parameter) 9 | * to each item of the input stream. 
10 | * 11 | * @param the type of each item in the input stream 12 | * @param the type of each item in the output stream 13 | */ 14 | public final class ApplyToStreamInputChain implements Chain, Stream> { 15 | 16 | /** 17 | * this chain is applied to each item of the input stream 18 | */ 19 | private final Chain applyToStreamItemChain; 20 | 21 | /** 22 | * @param applyToStreamItemChain {@link #applyToStreamItemChain} 23 | */ 24 | public ApplyToStreamInputChain(final Chain applyToStreamItemChain) { 25 | this.applyToStreamItemChain = applyToStreamItemChain; 26 | } 27 | 28 | @Override 29 | public Stream run(final Stream input) { 30 | return input.map(this.applyToStreamItemChain::run); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/base/JoinChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.base; 2 | 3 | import java.util.Arrays; 4 | import java.util.List; 5 | import java.util.stream.Stream; 6 | 7 | import io.github.cupybara.javalangchains.chains.Chain; 8 | 9 | /** 10 | * This {@link Chain} is used to join multiple other chains. Their output is 11 | * provided as a {@link Stream} which will be passed as an input to subsequent 12 | * chains. 
13 | * 14 | * @param Input type of joined chains 15 | * @param Output type of joined chains 16 | */ 17 | public final class JoinChain implements Chain> { 18 | 19 | /** 20 | * the list of joined {@link Chain Chains} 21 | */ 22 | private final List> chains; 23 | 24 | /** 25 | * if true the result stream will be a parallel one 26 | */ 27 | private final boolean parallel; 28 | 29 | /** 30 | * @param parallel {@link #parallel} 31 | * @param chains {@link #chains} 32 | */ 33 | public JoinChain(final boolean parallel, final List> chains) { 34 | this.parallel = parallel; 35 | this.chains = chains; 36 | } 37 | 38 | /** 39 | * @param chains {@link #chains} 40 | */ 41 | public JoinChain(final List> chains) { 42 | this(false, chains); 43 | } 44 | 45 | /** 46 | * @param parallel {@link #parallel} 47 | * @param chains {@link #chains} 48 | */ 49 | @SafeVarargs 50 | public JoinChain(final boolean parallel, final Chain... chains) { 51 | this(parallel, Arrays.asList(chains)); 52 | } 53 | 54 | /** 55 | * @param chains {@link #chains} 56 | */ 57 | @SafeVarargs 58 | public JoinChain(final Chain... 
chains) { 59 | this(false, chains); 60 | } 61 | 62 | @Override 63 | public Stream run(final I input) { 64 | final Stream result = chains.stream().map(chain -> chain.run(input)); 65 | if (parallel) { 66 | return result.parallel(); 67 | } 68 | return result; 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/base/StreamUnwrappingChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.base; 2 | 3 | import java.util.NoSuchElementException; 4 | import java.util.stream.Stream; 5 | 6 | import io.github.cupybara.javalangchains.chains.Chain; 7 | 8 | /** 9 | * a utility chain which is used to retrieve the element from a singleton stream 10 | * 11 | * @param Type of the element in the {@link Stream} 12 | */ 13 | public final class StreamUnwrappingChain implements Chain, T> { 14 | 15 | @Override 16 | public T run(final Stream input) { 17 | return input.findAny().orElseThrow(NoSuchElementException::new); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/base/StreamWrappingChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.base; 2 | 3 | import java.util.stream.Stream; 4 | 5 | import io.github.cupybara.javalangchains.chains.Chain; 6 | 7 | /** 8 | * utility {@link Chain} which wraps the output of the previous {@link Chain} in 9 | * a {@link Stream} for processing using {@link Stream} consuming {@link Chain 10 | * Chains}. 
11 | */ 12 | public final class StreamWrappingChain implements Chain> { 13 | 14 | @Override 15 | public Stream run(final T input) { 16 | return Stream.of(input); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/base/logging/LoggingChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.base.logging; 2 | 3 | import java.util.function.Consumer; 4 | import java.util.function.Function; 5 | 6 | import io.github.cupybara.javalangchains.chains.Chain; 7 | 8 | /** 9 | * this chain is used to log an input instance 10 | * 11 | * @param the static input type 12 | */ 13 | public class LoggingChain implements Chain { 14 | 15 | private final String logPrefix; 16 | private final Consumer logConsumer; 17 | private final Function inputSerializer; 18 | 19 | /** 20 | * @param logPrefix {@link #logPrefix} 21 | * @param logConsumer {@link #logConsumer} 22 | * @param inputSerializer {@link #inputSerializer} 23 | */ 24 | public LoggingChain(final String logPrefix, final Consumer logConsumer, 25 | final Function inputSerializer) { 26 | this.logPrefix = logPrefix; 27 | this.inputSerializer = inputSerializer; 28 | this.logConsumer = logConsumer; 29 | } 30 | 31 | /** 32 | * @param logPrefix {@link #logPrefix} 33 | * @param logConsumer {@link #logConsumer} 34 | */ 35 | public LoggingChain(final String logPrefix, final Consumer logConsumer) { 36 | this(logPrefix, logConsumer, String::valueOf); 37 | } 38 | 39 | /** 40 | * @param logPrefix {@link #logPrefix} 41 | * @param inputSerializer {@link #inputSerializer} 42 | */ 43 | public LoggingChain(final String logPrefix, final Function inputSerializer) { 44 | this(logPrefix, System.out::println, inputSerializer); 45 | } 46 | 47 | /** 48 | * @param logPrefix {@link #logPrefix} 49 | */ 50 | public LoggingChain(final String logPrefix) { 51 | this(logPrefix, 
System.out::println, String::valueOf); 52 | } 53 | 54 | /** 55 | * creates an instance of the {@link LoggingChain} 56 | */ 57 | public LoggingChain() { 58 | this(""); 59 | } 60 | 61 | @Override 62 | public I run(final I input) { 63 | final String serializedInput = inputSerializer.apply(input); 64 | logConsumer.accept(String.format("%s%s", this.logPrefix, serializedInput)); 65 | return input; 66 | } 67 | 68 | /** 69 | * @param title the title for the logPrefix 70 | * @return title with delimiting lines 71 | */ 72 | public static String defaultLogPrefix(final String title) { 73 | return "\n========================================================================================================================================================\n" 74 | + title 75 | + "\n========================================================================================================================================================\n"; 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/data/reader/ReadDocumentsFromInMemoryPdfChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.reader; 2 | 3 | import java.io.IOException; 4 | import java.util.stream.Stream; 5 | 6 | import org.apache.pdfbox.Loader; 7 | 8 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromInMemoryPdfChain.InMemoryPdf; 9 | 10 | /** 11 | * Utilizes Apache PDFBox to read documents from a byte array 12 | */ 13 | public class ReadDocumentsFromInMemoryPdfChain extends ReadDocumentsFromPdfChainBase { 14 | 15 | /** 16 | * wrapper for an in memory pdf (byte array + title) 17 | */ 18 | public static class InMemoryPdf { 19 | 20 | /** 21 | * pdf data as byte array 22 | */ 23 | private final byte[] data; 24 | 25 | /** 26 | * pdf document name 27 | */ 28 | private final String name; 29 | 30 | /** 31 | * 
@param data {@link #data} 32 | * @param name {@link #name} 33 | */ 34 | public InMemoryPdf(final byte[] data, final String name) { 35 | this.data = data; 36 | this.name = name; 37 | } 38 | } 39 | 40 | /** 41 | * creates a {@link ReadDocumentsFromInMemoryPdfChain} 42 | * 43 | * @param readMode {@link #readMode} 44 | * @param parallel {@link #parallel} 45 | */ 46 | public ReadDocumentsFromInMemoryPdfChain(final PdfReadMode readMode, final boolean parallel) { 47 | super(readMode, parallel); 48 | } 49 | 50 | /** 51 | * creates a {@link ReadDocumentsFromInMemoryPdfChain} 52 | * 53 | * @param readMode {@link #readMode} 54 | */ 55 | public ReadDocumentsFromInMemoryPdfChain(final PdfReadMode readMode) { 56 | this(readMode, false); 57 | } 58 | 59 | /** 60 | * creates a {@link ReadDocumentsFromInMemoryPdfChain} which reads the whole pdf 61 | * as a document 62 | */ 63 | public ReadDocumentsFromInMemoryPdfChain() { 64 | this(PdfReadMode.WHOLE); 65 | } 66 | 67 | @Override 68 | protected Stream loadPdDocuments(final InMemoryPdf input) throws IOException { 69 | return Stream.of(new PdDocumentWrapper(Loader.loadPDF(input.data), input.name)); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/data/reader/ReadDocumentsFromPdfChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.reader; 2 | 3 | import java.io.IOException; 4 | import java.nio.file.Files; 5 | import java.nio.file.Path; 6 | import java.util.stream.Stream; 7 | 8 | import org.apache.pdfbox.Loader; 9 | 10 | /** 11 | * Utilizes Apache PDFBox to read documents from a PDF or a folder of PDFs 12 | */ 13 | public class ReadDocumentsFromPdfChain extends ReadDocumentsFromPdfChainBase { 14 | 15 | /** 16 | * creates a {@link ReadDocumentsFromPdfChain} 17 | * 18 | * @param readMode {@link #readMode} 19 | * @param parallel {@link 
#parallel} 20 | */ 21 | public ReadDocumentsFromPdfChain(final PdfReadMode readMode, final boolean parallel) { 22 | super(readMode, parallel); 23 | } 24 | 25 | /** 26 | * creates a {@link ReadDocumentsFromPdfChain} 27 | * 28 | * @param readMode {@link #readMode} 29 | */ 30 | public ReadDocumentsFromPdfChain(final PdfReadMode readMode) { 31 | this(readMode, false); 32 | } 33 | 34 | /** 35 | * creates a {@link ReadDocumentsFromPdfChain} which reads the whole pdf as a 36 | * document 37 | */ 38 | public ReadDocumentsFromPdfChain() { 39 | this(PdfReadMode.WHOLE); 40 | } 41 | 42 | @Override 43 | protected Stream loadPdDocuments(final Path input) throws IOException { 44 | return Files.walk(input).filter(Files::isRegularFile) 45 | .filter(path -> path.toString().toLowerCase().endsWith(".pdf")).map(path -> { 46 | try { 47 | return new PdDocumentWrapper(Loader.loadPDF(path.toFile()), path.getFileName().toString()); 48 | } catch (final IOException ioException) { 49 | throw new IllegalStateException("could not read document from " + path); 50 | } 51 | }); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/data/reader/ReadDocumentsFromPdfChainBase.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.reader; 2 | 3 | import java.io.IOException; 4 | import java.util.LinkedHashMap; 5 | import java.util.LinkedList; 6 | import java.util.List; 7 | import java.util.Map; 8 | import java.util.stream.Stream; 9 | 10 | import org.apache.logging.log4j.LogManager; 11 | import org.apache.pdfbox.pdmodel.PDDocument; 12 | import org.apache.pdfbox.text.PDFTextStripper; 13 | 14 | import io.github.cupybara.javalangchains.chains.Chain; 15 | import io.github.cupybara.javalangchains.util.PromptConstants; 16 | 17 | /** 18 | * provides base functionality for all pdf reading chains 19 | * 20 | * @param input 
type to read pdfs from 21 | */ 22 | public abstract class ReadDocumentsFromPdfChainBase implements Chain>> { 23 | 24 | /** 25 | * this enum is used to configure how each pdf content is read into a string 26 | */ 27 | public enum PdfReadMode { 28 | /** 29 | * Reads the whole document into a string 30 | */ 31 | WHOLE, 32 | 33 | /** 34 | * Reads each document page by page: provides a list of documents for each 35 | * document and adds "p. ${pageIndex}" to each "source" field 36 | */ 37 | PAGES; 38 | } 39 | 40 | /** 41 | * (PDDocument, PDF-Name) pair 42 | */ 43 | protected class PdDocumentWrapper { 44 | private final PDDocument pdDocument; 45 | private final String pdDocumentName; 46 | 47 | /** 48 | * creates an instance of PdDocumentWrapper 49 | * 50 | * @param pdDocument {@link #pdDocument} 51 | * @param pdDocumentName {@link #pdDocumentName} 52 | */ 53 | protected PdDocumentWrapper(final PDDocument pdDocument, final String pdDocumentName) { 54 | this.pdDocument = pdDocument; 55 | this.pdDocumentName = pdDocumentName; 56 | } 57 | } 58 | 59 | /** 60 | * @see PdfReadMode 61 | */ 62 | private final PdfReadMode readMode; 63 | 64 | /** 65 | * if true the reading is done in parallel 66 | */ 67 | private final boolean parallel; 68 | 69 | /** 70 | * creates a {@link ReadDocumentsFromPdfChainBase} 71 | * 72 | * @param readMode {@link #readMode} 73 | * @param parallel {@link #parallel} 74 | */ 75 | protected ReadDocumentsFromPdfChainBase(final PdfReadMode readMode, final boolean parallel) { 76 | this.readMode = readMode; 77 | this.parallel = parallel; 78 | } 79 | 80 | /** 81 | * load a pdf from an input instance 82 | * 83 | * @param input input instance 84 | * @return {@link PDDocument} 85 | * 86 | * @throws IOException on error loading pdf 87 | */ 88 | protected abstract Stream loadPdDocuments(I input) throws IOException; 89 | 90 | @Override 91 | public Stream> run(final I input) { 92 | final Stream> documents; 93 | 94 | try { 95 | documents = 
loadPdDocuments(input).flatMap(this::createDocumentFromPdDocumentWrapper); 96 | } catch (final IOException ioException) { 97 | throw new IllegalStateException("error loading pdf for input " + input, ioException); 98 | } 99 | 100 | if (parallel) { 101 | return documents.parallel(); 102 | } 103 | 104 | return documents; 105 | } 106 | 107 | private Stream> createDocumentFromPdDocumentWrapper(final PdDocumentWrapper pdDocumentWrapper) { 108 | try { 109 | switch (readMode) { 110 | case WHOLE: 111 | return Stream.of(createDocumentFromWholePdf(pdDocumentWrapper)); 112 | case PAGES: 113 | return createDocumentsFromPdfPages(pdDocumentWrapper).stream(); 114 | default: 115 | throw new IllegalStateException("unsupported readMode " + readMode); 116 | } 117 | } catch (final IOException innerIoException) { 118 | throw new IllegalStateException("could not create documents", innerIoException); 119 | } finally { 120 | try { 121 | pdDocumentWrapper.pdDocument.close(); 122 | } catch (final IOException ioException) { 123 | throw new IllegalStateException("could not close PDDocument", ioException); 124 | } 125 | } 126 | } 127 | 128 | private Map createDocumentFromWholePdf(final PdDocumentWrapper pdDocumentWrapper) 129 | throws IOException { 130 | 131 | final PDFTextStripper textStripper = new PDFTextStripper(); 132 | 133 | final String content = textStripper.getText(pdDocumentWrapper.pdDocument); 134 | 135 | final Map document = new LinkedHashMap<>(); 136 | document.put(PromptConstants.CONTENT, content); 137 | document.put(PromptConstants.SOURCE, pdDocumentWrapper.pdDocumentName); 138 | 139 | LogManager.getLogger().info("successfully read document {}", pdDocumentWrapper.pdDocumentName); 140 | 141 | return document; 142 | } 143 | 144 | private List> createDocumentsFromPdfPages(final PdDocumentWrapper pdDocumentWrapper) 145 | throws IOException { 146 | final PDFTextStripper textStripper = new PDFTextStripper(); 147 | 148 | final List> documents = new LinkedList<>(); 149 | for (int 
pageIndex = 0; pageIndex < pdDocumentWrapper.pdDocument.getNumberOfPages(); pageIndex++) { 150 | textStripper.setStartPage(pageIndex + 1); 151 | textStripper.setEndPage(pageIndex + 1); 152 | 153 | final String pageContent; 154 | try { 155 | pageContent = textStripper.getText(pdDocumentWrapper.pdDocument); 156 | } catch (final IOException innerIoException) { 157 | throw new IllegalStateException("error reading page with index " + pageIndex, innerIoException); 158 | } 159 | 160 | final int pageNumber = pageIndex + 1; 161 | 162 | final Map pageDocument = new LinkedHashMap<>(); 163 | pageDocument.put(PromptConstants.CONTENT, pageContent); 164 | pageDocument.put(PromptConstants.SOURCE, 165 | String.format("%s p.%d", pdDocumentWrapper.pdDocumentName, pageNumber)); 166 | 167 | LogManager.getLogger().info("successfully read page {} of document {}", pageNumber, 168 | pdDocumentWrapper.pdDocumentName); 169 | 170 | documents.add(pageDocument); 171 | } 172 | 173 | return documents; 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/data/retrieval/ElasticsearchRetrievalChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.retrieval; 2 | 3 | import java.io.Closeable; 4 | import java.io.IOException; 5 | import java.io.InputStream; 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | import java.util.Map.Entry; 9 | import java.util.Optional; 10 | import java.util.Spliterator; 11 | import java.util.Spliterators; 12 | import java.util.function.BiFunction; 13 | import java.util.function.Function; 14 | import java.util.stream.Stream; 15 | import java.util.stream.StreamSupport; 16 | 17 | import org.apache.http.HttpHost; 18 | import org.apache.lucene.search.Query; 19 | import org.elasticsearch.client.Request; 20 | import org.elasticsearch.client.Response; 21 | import 
org.elasticsearch.client.RestClient; 22 | 23 | import com.fasterxml.jackson.core.type.TypeReference; 24 | import com.fasterxml.jackson.databind.ObjectMapper; 25 | import com.fasterxml.jackson.databind.node.ArrayNode; 26 | import com.fasterxml.jackson.databind.node.ObjectNode; 27 | 28 | import io.github.cupybara.javalangchains.util.PromptConstants; 29 | 30 | /** 31 | * This {@link RetrievalChain} retrieves documents from an elasticsearch index 32 | */ 33 | public class ElasticsearchRetrievalChain extends RetrievalChain implements Closeable { 34 | 35 | /** 36 | * elasticsearch index name 37 | */ 38 | private final String index; 39 | 40 | /** 41 | * elasticsearch low level {@link RestClient} 42 | */ 43 | private final RestClient restClient; 44 | 45 | /** 46 | * this {@link Function} accepts the user's question and provides the 47 | * {@link Query} which is executed against the Elasticsearch _search API 48 | */ 49 | private final Function queryCreator; 50 | 51 | /** 52 | * Consumes an elasticsearch hit and the question and creates a document as a 53 | * result 54 | */ 55 | private final BiFunction> documentCreator; 56 | 57 | /** 58 | * {@link ObjectMapper} used for query creation and document deserialization 59 | */ 60 | private final ObjectMapper objectMapper; 61 | 62 | /** 63 | * Creates an instance of {@link ElasticsearchRetrievalChain} 64 | * 65 | * @param index {@link #index} 66 | * @param restClient {@link #restClient} 67 | * @param maxDocumentCount {@link #getMaxDocumentCount()} 68 | * @param objectMapper {@link #objectMapper} 69 | * @param queryCreator {@link #queryCreator} 70 | * @param documentCreator {@link #documentCreator} 71 | */ 72 | public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount, 73 | final ObjectMapper objectMapper, final Function queryCreator, 74 | final BiFunction> documentCreator) { 75 | super(maxDocumentCount); 76 | this.index = index; 77 | this.restClient = restClient; 78 | 
this.objectMapper = objectMapper; 79 | this.queryCreator = queryCreator; 80 | this.documentCreator = documentCreator; 81 | } 82 | 83 | /** 84 | * Creates an instance of {@link ElasticsearchRetrievalChain} 85 | * 86 | * @param index {@link #index} 87 | * @param restClient {@link #restClient} 88 | * @param maxDocumentCount {@link #getMaxDocumentCount()} 89 | * @param objectMapper {@link #objectMapper} 90 | * @param queryCreator {@link #queryCreator} 91 | */ 92 | public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount, 93 | final ObjectMapper objectMapper, final Function queryCreator) { 94 | this(index, restClient, maxDocumentCount, objectMapper, queryCreator, defaultDocumentCreator(objectMapper)); 95 | } 96 | 97 | /** 98 | * Creates an instance of {@link ElasticsearchRetrievalChain} 99 | * 100 | * @param index {@link #index} 101 | * @param restClient {@link #restClient} 102 | * @param maxDocumentCount {@link #getMaxDocumentCount} 103 | * @param objectMapper {@link #objectMapper} 104 | */ 105 | public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount, 106 | final ObjectMapper objectMapper) { 107 | this(index, restClient, maxDocumentCount, objectMapper, question -> createQuery(objectMapper, question)); 108 | } 109 | 110 | /** 111 | * Creates an instance of {@link ElasticsearchRetrievalChain} 112 | * 113 | * @param index {@link #index} 114 | * @param restClient {@link #restClient} 115 | * @param maxDocumentCount {@link #getMaxDocumentCount} 116 | */ 117 | public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount) { 118 | this(index, restClient, maxDocumentCount, new ObjectMapper()); 119 | } 120 | 121 | /** 122 | * Creates an instance of {@link ElasticsearchRetrievalChain} 123 | * 124 | * @param index {@link #index} 125 | * @param restClient {@link #restClient} 126 | */ 127 | public 
ElasticsearchRetrievalChain(final String index, final RestClient restClient) { 128 | this(index, restClient, 4); 129 | } 130 | 131 | /** 132 | * Creates an instance of {@link ElasticsearchRetrievalChain} 133 | * 134 | * @param index {@link #index} 135 | */ 136 | public ElasticsearchRetrievalChain(final String index) { 137 | this(index, RestClient.builder(new HttpHost("localhost", 9200)).build()); 138 | } 139 | 140 | @Override 141 | public Stream> run(final String input) { 142 | final ObjectNode query = queryCreator.apply(input); 143 | 144 | final String requestJson = objectMapper.createObjectNode().put("size", getMaxDocumentCount()) 145 | .set("query", query).toString(); 146 | 147 | final Request searchRequest = new Request("GET", String.format("/%s/_search", index)); 148 | searchRequest.setJsonEntity(requestJson); 149 | 150 | final Response searchResponse; 151 | try { 152 | searchResponse = restClient.performRequest(searchRequest); 153 | } catch (final IOException ioException) { 154 | throw new IllegalStateException("error executing search with request " + requestJson, ioException); 155 | } 156 | 157 | final ObjectNode response; 158 | try (final InputStream responseInputStream = searchResponse.getEntity().getContent()) { 159 | response = (ObjectNode) objectMapper.readTree(responseInputStream); 160 | } catch (final IOException ioException) { 161 | throw new IllegalStateException("error parsing search response", ioException); 162 | } 163 | 164 | final ArrayNode hits = Optional.of(response).map(o -> o.get("hits")).map(ObjectNode.class::cast) 165 | .map(o -> o.get("hits")).map(ArrayNode.class::cast).orElse(null); 166 | 167 | if (hits == null) { 168 | return Stream.empty(); 169 | } 170 | 171 | return StreamSupport.stream(Spliterators.spliteratorUnknownSize(hits.iterator(), Spliterator.ORDERED), false) 172 | .map(ObjectNode.class::cast).map(hitNode -> documentCreator.apply(hitNode, input)); 173 | } 174 | 175 | @Override 176 | public void close() throws IOException { 177 
| this.restClient.close(); 178 | } 179 | 180 | /** 181 | * @param objectMapper {@link ObjectMapper} used for {@link ObjectNode} creation 182 | * @param question the question used for retrieval 183 | * @return {"match": {"content": question}} 184 | */ 185 | private static ObjectNode createQuery(final ObjectMapper objectMapper, final String question) { 186 | final ObjectNode query = objectMapper.createObjectNode(); 187 | query.putObject("match").put(PromptConstants.CONTENT, question); 188 | return query; 189 | } 190 | 191 | /** 192 | * creates the default {@link #queryCreator} 193 | * 194 | * @param objectMapper the {@link ObjectMapper} used for json operations 195 | * @return {@link BiFunction} which consumes a hit node and the question and 196 | * produces a document consisting of all (key, value)-pairs of the hit's 197 | * _source object 198 | */ 199 | public static BiFunction> defaultDocumentCreator( 200 | final ObjectMapper objectMapper) { 201 | return (hitObjectNode, question) -> { 202 | final ObjectNode source = (ObjectNode) hitObjectNode.get("_source"); 203 | 204 | final Map sourceMap = objectMapper.convertValue(source, 205 | new TypeReference>() { 206 | // noop 207 | }); 208 | 209 | final Map document = new HashMap<>(); 210 | document.put(PromptConstants.QUESTION, question); 211 | 212 | for (final Entry sourceEntry : sourceMap.entrySet()) { 213 | document.put(sourceEntry.getKey(), sourceEntry.getValue().toString()); 214 | } 215 | 216 | return document; 217 | }; 218 | } 219 | } 220 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/data/retrieval/JdbcRetrievalChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.retrieval; 2 | 3 | import java.sql.Connection; 4 | import java.sql.PreparedStatement; 5 | import java.sql.ResultSet; 6 | import java.sql.ResultSetMetaData; 7 | import 
java.sql.SQLException; 8 | import java.util.ArrayList; 9 | import java.util.Arrays; 10 | import java.util.Collections; 11 | import java.util.HashMap; 12 | import java.util.List; 13 | import java.util.Map; 14 | import java.util.function.Function; 15 | import java.util.function.Supplier; 16 | import java.util.stream.Stream; 17 | 18 | import org.apache.commons.lang3.tuple.Pair; 19 | 20 | import io.github.cupybara.javalangchains.util.PromptConstants; 21 | 22 | public class JdbcRetrievalChain extends RetrievalChain { 23 | /** 24 | * supplier for lazy connection creation on chain invocation 25 | */ 26 | private final Supplier connectionSupplier; 27 | /** 28 | * this {@link Function} accepts the user's question and provides the 29 | * corresponding SQL statement to execute 30 | */ 31 | private final Function>> queryBuilder; 32 | /** 33 | * transforms a {@link ResultSet} to a document. default implementation in {@link #documentFromResultSet(ResultSet)} 34 | */ 35 | private final DocumentCreator documentCreator; 36 | 37 | /** 38 | * Creates an instance of {@link JdbcRetrievalChain} 39 | * 40 | * @param connectionSupplier {@link #connectionSupplier} 41 | * @param documentCreator {@link #documentCreator} 42 | * @param queryBuilder {@link #queryBuilder} 43 | * @param maxDocumentCount {@link RetrievalChain#getMaxDocumentCount()} 44 | */ 45 | public JdbcRetrievalChain(Supplier connectionSupplier, Function>> queryBuilder, DocumentCreator documentCreator, int maxDocumentCount) { 46 | super(maxDocumentCount); 47 | this.connectionSupplier = connectionSupplier; 48 | this.documentCreator = documentCreator; 49 | this.queryBuilder = queryBuilder; 50 | } 51 | 52 | /** 53 | * Creates an instance of {@link JdbcRetrievalChain} using {@link #createQuery(String, String, String)} 54 | * for SQL statement creation. 
55 | * 56 | * @param connectionSupplier {@link #connectionSupplier} 57 | * @param table Name of the document table used for query creation 58 | * @param maxDocumentCount {@link RetrievalChain#getMaxDocumentCount()} 59 | */ 60 | public JdbcRetrievalChain(Supplier connectionSupplier, String table, String contentColumn, int maxDocumentCount) { 61 | this(connectionSupplier, (question) -> createQuery(question, table, contentColumn), JdbcRetrievalChain::documentFromResultSet, maxDocumentCount); 62 | } 63 | 64 | /** 65 | * Creates an instance of {@link JdbcRetrievalChain} using {@link #createQuery(String, String, String)} 66 | * for SQL statement creation and `content`, `source` as the result columns and `Documents` as the table. 67 | * 68 | * @param connectionSupplier {@link #connectionSupplier} 69 | * @param maxDocumentCount {@link RetrievalChain#getMaxDocumentCount()} 70 | */ 71 | public JdbcRetrievalChain(Supplier connectionSupplier, int maxDocumentCount) { 72 | this(connectionSupplier, "Documents", "content", maxDocumentCount); 73 | } 74 | 75 | @Override 76 | public Stream> run(String input) { 77 | Connection connection = connectionSupplier.get(); 78 | 79 | Pair> query = queryBuilder.apply(input); 80 | final String sql = query.getLeft(); 81 | final List params = query.getRight(); 82 | 83 | try (PreparedStatement statement = connection.prepareStatement(sql)) { 84 | statement.setMaxRows(getMaxDocumentCount()); 85 | for (int i = 0; i < params.size(); i++) { 86 | statement.setObject(i + 1, params.get(i)); 87 | } 88 | ResultSet resultSet = statement.executeQuery(); 89 | List> queryResult = new ArrayList<>(); 90 | while (resultSet.next()) { 91 | Map documentMap = documentCreator.create(resultSet); 92 | documentMap.put(PromptConstants.QUESTION, input); 93 | queryResult.add(documentMap); 94 | } 95 | return queryResult.stream(); 96 | } catch (SQLException e) { 97 | throw new IllegalStateException("error creating / executing database statement", e); 98 | } 99 | } 100 | 101 | 
/** 102 | * Transforms a {@link ResultSet} entry to a document containing the corresponding prompt info. 103 | * 104 | * @param resultSet JDBC {@link ResultSet} 105 | * @return transformed document map 106 | * @throws SQLException if a column cannot be retrieved from the result set 107 | */ 108 | private static Map documentFromResultSet(ResultSet resultSet) throws SQLException { 109 | ResultSetMetaData metaData = resultSet.getMetaData(); 110 | 111 | Map documentMap = new HashMap<>(); 112 | 113 | for(int i = 1; i <= metaData.getColumnCount(); i++) { 114 | String columnName = metaData.getColumnName(i); 115 | Object value = resultSet.getObject(i); 116 | documentMap.put(columnName, value.toString()); 117 | } 118 | 119 | return documentMap; 120 | } 121 | 122 | /** 123 | * Internal query creator that acts as a default when the user doesn't supply a customized function. 124 | * Creates a SQL statement using a content likeness query. 125 | * 126 | * @param question Input / question of the user 127 | * @param contentColumn Name of the column containing the document content 128 | * @return a {@link Pair} of the SQL and parameters to bind 129 | */ 130 | private static Pair> createQuery(final String question, final String table, final String contentColumn) { 131 | final String query = String.format("SELECT * FROM %s WHERE %s LIKE ANY (?)", table, contentColumn); 132 | final String[] splitQuestion = Arrays.stream(question.split(question)).map(t -> String.format("%%%s%%", t)).toArray(String[]::new); 133 | final List params = Collections.singletonList(splitQuestion); 134 | return Pair.of(query, params); 135 | } 136 | 137 | /** 138 | * Wrapper interface for Lambdas that act as document creators for a JDBC {@link ResultSet}. 139 | * Advancing the {@link ResultSet} is not necessary as it is done by the {@link JdbcRetrievalChain}. 
140 | */ 141 | @FunctionalInterface 142 | public interface DocumentCreator { 143 | Map create(final ResultSet resultSet) throws SQLException; 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/data/retrieval/LuceneRetrievalChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.retrieval; 2 | 3 | import java.io.Closeable; 4 | import java.io.IOException; 5 | import java.util.Arrays; 6 | import java.util.LinkedHashMap; 7 | import java.util.Map; 8 | import java.util.function.Function; 9 | import java.util.stream.Collectors; 10 | import java.util.stream.Stream; 11 | 12 | import org.apache.lucene.analysis.standard.StandardAnalyzer; 13 | import org.apache.lucene.document.Document; 14 | import org.apache.lucene.index.DirectoryReader; 15 | import org.apache.lucene.index.IndexReader; 16 | import org.apache.lucene.index.IndexableField; 17 | import org.apache.lucene.queryparser.classic.ParseException; 18 | import org.apache.lucene.queryparser.classic.QueryParser; 19 | import org.apache.lucene.search.IndexSearcher; 20 | import org.apache.lucene.search.Query; 21 | import org.apache.lucene.search.TopDocs; 22 | import org.apache.lucene.search.similarities.BM25Similarity; 23 | import org.apache.lucene.store.Directory; 24 | 25 | import io.github.cupybara.javalangchains.util.PromptConstants; 26 | 27 | /** 28 | * This {@link RetrievalChain} retrieves documents from a lucene index 29 | */ 30 | public class LuceneRetrievalChain extends RetrievalChain implements Closeable { 31 | 32 | private final Function queryCreator; 33 | private final Function> documentCreator; 34 | 35 | private final IndexReader indexReader; 36 | private final IndexSearcher indexSearcher; 37 | 38 | /** 39 | * Creates an instance of {@link LuceneRetrievalChain} 40 | * 41 | * @param indexDirectory Lucene Index {@link 
Directory} 42 | * @param maxDocumentCount maximal count of retrieved documents 43 | * @param queryCreator this {@link Function} accepts the user's question and 44 | * provides the {@link Query} which is executed against 45 | * the Lucene {@link Directory} 46 | * @param documentCreator this {@link Function} accepts a lucene 47 | * {@link Document} and provides a {@link Map} of key 48 | * value pairs for subsequent chains 49 | */ 50 | public LuceneRetrievalChain(final Directory indexDirectory, final int maxDocumentCount, 51 | final Function queryCreator, final Function> documentCreator) { 52 | super(maxDocumentCount); 53 | this.queryCreator = queryCreator; 54 | this.documentCreator = documentCreator; 55 | 56 | try { 57 | this.indexReader = DirectoryReader.open(indexDirectory); 58 | } catch (final IOException ioException) { 59 | throw new IllegalStateException("could not open indexReader", ioException); 60 | } 61 | 62 | this.indexSearcher = new IndexSearcher(indexReader); 63 | this.indexSearcher.setSimilarity(new BM25Similarity()); // TODO: Parameterize 64 | } 65 | 66 | /** 67 | * Creates an instance of {@link LuceneRetrievalChain}. Uses 68 | * {@link #createDocument(Document)} to map all lucene document fields into the 69 | * output {@link Map}. 70 | * 71 | * @param indexDirectory Lucene Index {@link Directory} 72 | * @param maxDocumentCount maximal count of retrieved documents 73 | * @param queryCreator this {@link Function} accepts the user's question and 74 | * provides the {@link Query} which is executed against 75 | * the Lucene {@link Directory} 76 | */ 77 | public LuceneRetrievalChain(final Directory indexDirectory, final int maxDocumentCount, 78 | final Function queryCreator) { 79 | this(indexDirectory, maxDocumentCount, queryCreator, LuceneRetrievalChain::createDocument); 80 | } 81 | 82 | /** 83 | * Creates an instance of {@link LuceneRetrievalChain}. 
Uses 84 | * {@link #createQuery(String)} to provide a default {@link Query} using a 85 | * {@link StandardAnalyzer} targeting the field 86 | * {@link PromptConstants#CONTENT}.. Uses {@link #createDocument(Document)} to 87 | * map all lucene document fields into the output {@link Map}. 88 | * 89 | * @param indexDirectory Lucene Index {@link Directory} 90 | * @param maxDocumentCount maximal count of retrieved documents 91 | */ 92 | public LuceneRetrievalChain(final Directory indexDirectory, final int maxDocumentCount) { 93 | this(indexDirectory, maxDocumentCount, LuceneRetrievalChain::createQuery, LuceneRetrievalChain::createDocument); 94 | } 95 | 96 | /** 97 | * Creates an instance of {@link LuceneRetrievalChain} with a maximum of 4 98 | * retrieved documents. Uses {@link #createQuery(String)} to provide a default 99 | * {@link Query} using a {@link StandardAnalyzer} targeting the field 100 | * {@link PromptConstants#CONTENT}.. Uses {@link #createDocument(Document)} to 101 | * map all lucene document fields into the output {@link Map}. 
102 | * 103 | * @param indexDirectory Lucene Index {@link Directory} 104 | */ 105 | public LuceneRetrievalChain(final Directory indexDirectory) { 106 | this(indexDirectory, 4); 107 | } 108 | 109 | @Override 110 | public Stream> run(final String input) { 111 | final Query query = queryCreator.apply(input); 112 | 113 | final TopDocs topDocs; 114 | try { 115 | topDocs = indexSearcher.search(query, this.getMaxDocumentCount()); 116 | } catch (final IOException ioException) { 117 | throw new IllegalStateException("error processing search for query " + query, ioException); 118 | } 119 | 120 | return Arrays.stream(topDocs.scoreDocs).map(hit -> { 121 | try { 122 | return indexSearcher.doc(hit.doc); 123 | } catch (final IOException ioException) { 124 | throw new IllegalStateException("could not process document " + hit.doc, ioException); 125 | } 126 | }).map(this.documentCreator).map(document -> { 127 | final Map mappedDocument = new LinkedHashMap<>(document); 128 | mappedDocument.put(PromptConstants.QUESTION, input); 129 | return mappedDocument; 130 | }); 131 | } 132 | 133 | @Override 134 | public void close() throws IOException { 135 | this.indexReader.close(); 136 | } 137 | 138 | private static Map createDocument(final Document document) { 139 | return document.getFields().stream() 140 | .collect(Collectors.toMap(IndexableField::name, IndexableField::stringValue)); 141 | } 142 | 143 | private static Query createQuery(final String searchTerm) { 144 | final StandardAnalyzer analyzer = new StandardAnalyzer(); 145 | final QueryParser queryParser = new QueryParser(PromptConstants.CONTENT, analyzer); 146 | try { 147 | return queryParser.parse(searchTerm); 148 | } catch (final ParseException parseException) { 149 | throw new IllegalStateException("could not create query for searchTerm " + searchTerm, parseException); 150 | } 151 | } 152 | } 153 | -------------------------------------------------------------------------------- 
/src/main/java/io/github/cupybara/javalangchains/chains/data/retrieval/RetrievalChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.retrieval; 2 | 3 | import java.util.Map; 4 | import java.util.stream.Stream; 5 | 6 | import io.github.cupybara.javalangchains.chains.Chain; 7 | 8 | /** 9 | * {@link Chain} which is utilized for retrieving documents in a QA context 10 | */ 11 | public abstract class RetrievalChain implements Chain>> { 12 | 13 | /** 14 | * maximum count of retrieved documents 15 | */ 16 | private final int maxDocumentCount; 17 | 18 | /** 19 | * @param maxDocumentCount {@link #maxDocumentCount} 20 | */ 21 | protected RetrievalChain(final int maxDocumentCount) { 22 | this.maxDocumentCount = maxDocumentCount; 23 | } 24 | 25 | /** 26 | * @return {@link #maxDocumentCount} 27 | */ 28 | protected final int getMaxDocumentCount() { 29 | return this.maxDocumentCount; 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/data/writer/WriteDocumentsToElasticsearchIndexChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.writer; 2 | 3 | import java.io.IOException; 4 | import java.util.Map; 5 | import java.util.function.BiConsumer; 6 | import java.util.function.Function; 7 | import java.util.stream.Stream; 8 | 9 | import org.apache.http.HttpHost; 10 | import org.apache.http.HttpStatus; 11 | import org.apache.logging.log4j.LogManager; 12 | import org.elasticsearch.client.Request; 13 | import org.elasticsearch.client.Response; 14 | import org.elasticsearch.client.RestClient; 15 | import org.elasticsearch.client.RestClientBuilder; 16 | 17 | import com.fasterxml.jackson.core.JsonProcessingException; 18 | import com.fasterxml.jackson.databind.ObjectMapper; 19 | import 
com.fasterxml.jackson.databind.node.ObjectNode; 20 | 21 | import io.github.cupybara.javalangchains.chains.Chain; 22 | import io.github.cupybara.javalangchains.chains.data.retrieval.ElasticsearchRetrievalChain; 23 | import io.github.cupybara.javalangchains.util.PromptConstants; 24 | 25 | /** 26 | * Inserts documents into an elasticsearch index 27 | */ 28 | public class WriteDocumentsToElasticsearchIndexChain implements Chain>, Void> { 29 | 30 | /** 31 | * Elasticsearch index name 32 | */ 33 | private final String index; 34 | 35 | /** 36 | * Elasticsearch Low Level {@link RestClientBuilder} 37 | */ 38 | private final RestClientBuilder restClientBuilder; 39 | 40 | /** 41 | * Optional {@link Function} which provides an ID value for a document. If set, 42 | * documents are indexed using PUT /_doc/${id} instead of POST /_doc 43 | */ 44 | private final Function, String> idProvider; 45 | 46 | /** 47 | * creates an elasticsearch index by consuming its name and an already 48 | * instantiated {@link RestClient} 49 | */ 50 | private final BiConsumer indexCreator; 51 | 52 | /** 53 | * creates the effective document json from an input document 54 | */ 55 | private final Function, String> documentJsonCreator; 56 | 57 | /** 58 | * @param index {@link #index} 59 | * @param restClientBuilder {@link #restClientBuilder} 60 | * @param idProvider {@link #idProvider} 61 | * @param indexCreator {@link #indexCreator} 62 | * @param documentJsonCreator {@link #documentJsonCreator} 63 | */ 64 | public WriteDocumentsToElasticsearchIndexChain(final String index, final RestClientBuilder restClientBuilder, 65 | final Function, String> idProvider, final BiConsumer indexCreator, 66 | final Function, String> documentJsonCreator) { 67 | this.index = index; 68 | this.restClientBuilder = restClientBuilder; 69 | this.idProvider = idProvider; 70 | this.indexCreator = indexCreator; 71 | this.documentJsonCreator = documentJsonCreator; 72 | } 73 | 74 | /** 75 | * @param index {@link #index} 76 | * @param 
restClientBuilder {@link #restClientBuilder} 77 | * @param objectMapper {@link ObjectMapper} used to create default json 78 | * operations 79 | * @param idProvider {@link #idProvider} 80 | * @param indexCreator {@link #indexCreator} 81 | */ 82 | public WriteDocumentsToElasticsearchIndexChain(final String index, final RestClientBuilder restClientBuilder, 83 | final ObjectMapper objectMapper, final Function, String> idProvider, 84 | final BiConsumer indexCreator) { 85 | this(index, restClientBuilder, idProvider, defaultIndexCreator(objectMapper), 86 | defaultDocumentJsonCreator(objectMapper)); 87 | } 88 | 89 | /** 90 | * @param index {@link #index} 91 | * @param restClientBuilder {@link #restClientBuilder} 92 | * @param objectMapper {@link ObjectMapper} used to create default json 93 | * operations 94 | * @param idProvider {@link #idProvider} 95 | */ 96 | public WriteDocumentsToElasticsearchIndexChain(final String index, final RestClientBuilder restClientBuilder, 97 | final ObjectMapper objectMapper, final Function, String> idProvider) { 98 | this(index, restClientBuilder, objectMapper, idProvider, defaultIndexCreator(objectMapper)); 99 | } 100 | 101 | /** 102 | * @param index {@link #index} 103 | * @param restClientBuilder {@link #restClientBuilder} 104 | * @param objectMapper {@link ObjectMapper} used to create default json 105 | * operations 106 | */ 107 | public WriteDocumentsToElasticsearchIndexChain(final String index, final RestClientBuilder restClientBuilder, 108 | final ObjectMapper objectMapper) { 109 | this(index, restClientBuilder, objectMapper, null); 110 | } 111 | 112 | /** 113 | * creates a {@link WriteDocumentsToElasticsearchIndexChain} with the default 114 | * {@link ObjectMapper} 115 | * 116 | * @param index {@link #index} 117 | * @param restClientBuilder {@link #restClientBuilder} 118 | */ 119 | public WriteDocumentsToElasticsearchIndexChain(final String index, final RestClientBuilder restClientBuilder) { 120 | this(index, restClientBuilder, new 
ObjectMapper()); 121 | } 122 | 123 | /** 124 | * creates a {@link WriteDocumentsToElasticsearchIndexChain} with the default 125 | * {@link HttpHost} (http://localhost:9200) and a default {@link ObjectMapper} 126 | * 127 | * @param index {@link #index} 128 | */ 129 | public WriteDocumentsToElasticsearchIndexChain(final String index) { 130 | this(index, RestClient.builder(new HttpHost("localhost", 9200))); 131 | } 132 | 133 | @Override 134 | public Void run(final Stream> input) { 135 | try (final RestClient restClient = restClientBuilder.build()) { 136 | 137 | if (this.indexCreator != null) { 138 | createIndexIfNotExists(restClient); 139 | } 140 | 141 | input.forEach(document -> { 142 | final String documentJson = documentJsonCreator.apply(document); 143 | 144 | final Request indexRequest = createIndexRequest(document); 145 | indexRequest.setJsonEntity(documentJson); 146 | try { 147 | restClient.performRequest(indexRequest); 148 | } catch (final IOException innerIoException) { 149 | throw new IllegalStateException("error writing document " + documentJson, innerIoException); 150 | } 151 | }); 152 | 153 | } catch (final IOException ioException) { 154 | throw new IllegalStateException("error writing documents to elasticsearch index", ioException); 155 | } 156 | return null; 157 | } 158 | 159 | private Request createIndexRequest(final Map document) { 160 | if (idProvider == null) { 161 | return new Request("POST", String.format("/%s/_doc", index)); 162 | } else { 163 | final String id = idProvider.apply(document); 164 | LogManager.getLogger(getClass()).debug("creating document with id {}", id); 165 | return new Request("PUT", String.format("/%s/_doc/%s", index, id)); 166 | } 167 | } 168 | 169 | /** 170 | * Checks whether an index with the name {@link #index} exists. If none exists, 171 | * it is created with default settings used for 172 | * {@link ElasticsearchRetrievalChain}. 
173 | * 174 | * @param restClient the {@link RestClient} used to perform elasticsearch 175 | * requests 176 | * @throws IOException on error 177 | */ 178 | private synchronized void createIndexIfNotExists(final RestClient restClient) throws IOException { 179 | final Request indexExistsRequest = new Request("HEAD", '/' + index); 180 | final Response indexExistsResponse = restClient.performRequest(indexExistsRequest); 181 | if (indexExistsResponse.getStatusLine().getStatusCode() == HttpStatus.SC_OK) { 182 | LogManager.getLogger(getClass()).info("index {} exists", index); 183 | return; 184 | } 185 | 186 | LogManager.getLogger(getClass()).info("creating index {} with default settings", index); 187 | 188 | this.indexCreator.accept(index, restClient); 189 | } 190 | 191 | /** 192 | * creates the default {@link Function} for creating json documents from input 193 | * documents for this chain 194 | * 195 | * @param objectMapper {@link ObjectMapper} for json operations 196 | * @return default {@link #documentJsonCreator} 197 | */ 198 | public static Function, String> defaultDocumentJsonCreator(final ObjectMapper objectMapper) { 199 | return document -> { 200 | try { 201 | return objectMapper.writeValueAsString(document); 202 | } catch (final JsonProcessingException jsonProcessingException) { 203 | throw new IllegalStateException("error creating json for document " + document, 204 | jsonProcessingException); 205 | } 206 | }; 207 | } 208 | 209 | /** 210 | * Realizes the default way of creating an elasticsearch index using the method 211 | * from 212 | * https://github.com/hwchase17/langchain/blob/master/langchain/retrievers/elastic_search_bm25.py 213 | * 214 | * @param objectMapper {@link ObjectMapper} for json operations 215 | * @return default {@link #indexCreator} 216 | */ 217 | public static BiConsumer defaultIndexCreator(final ObjectMapper objectMapper) { 218 | return (indexName, restClient) -> { 219 | final ObjectNode indexRequestBody = objectMapper.createObjectNode(); 
220 | 221 | // "settings": {...} 222 | final ObjectNode settings = indexRequestBody.putObject("settings"); 223 | 224 | // "analysis": {"analyzer": {"default": {"type": "standard"}}} 225 | settings.putObject("analysis").putObject("analyzer").putObject("default").put("type", "standard"); 226 | 227 | // "similarity": {"custom_bm25": {"type": "BM25", "k1": 2.0, "b": 0.75} 228 | settings.putObject("similarity").putObject("custom_bm25").put("type", "BM25").put("k1", 2.0).put("b", 0.75); 229 | 230 | // "mappings": {"properties": {"content": {"type": "text", "similarity": 231 | // "custom_bm25"}}} 232 | indexRequestBody.putObject("mappings").putObject("properties").putObject(PromptConstants.CONTENT) 233 | .put("type", "text").put("similarity", "custom_bm25"); 234 | 235 | final String indexRequestBodyJson = indexRequestBody.toString(); 236 | final Request indexRequest = new Request("PUT", '/' + indexName); 237 | indexRequest.setJsonEntity(indexRequestBodyJson); 238 | try { 239 | restClient.performRequest(indexRequest); 240 | } catch (final IOException ioException) { 241 | throw new IllegalStateException("error creating index " + indexRequestBodyJson, ioException); 242 | } 243 | }; 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/data/writer/WriteDocumentsToLuceneDirectoryChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.writer; 2 | 3 | import java.io.IOException; 4 | import java.nio.file.Files; 5 | import java.nio.file.Path; 6 | import java.util.Map; 7 | import java.util.stream.Stream; 8 | 9 | import org.apache.lucene.analysis.standard.StandardAnalyzer; 10 | import org.apache.lucene.document.Document; 11 | import org.apache.lucene.document.Field; 12 | import org.apache.lucene.document.TextField; 13 | import org.apache.lucene.index.IndexWriter; 14 | import 
org.apache.lucene.index.IndexWriterConfig; 15 | import org.apache.lucene.store.Directory; 16 | import org.apache.lucene.store.MMapDirectory; 17 | 18 | import io.github.cupybara.javalangchains.chains.Chain; 19 | import io.github.cupybara.javalangchains.util.PromptConstants; 20 | 21 | /** 22 | * Stores documents in a lucene {@link Directory} 23 | */ 24 | public class WriteDocumentsToLuceneDirectoryChain implements Chain>, Directory> { 25 | 26 | /** 27 | * The directory {@link Path} used to store the created index data 28 | */ 29 | private final Path directoryOutputPath; 30 | 31 | /** 32 | * @param directoryOutputPath {@link #directoryOutputPath} 33 | */ 34 | public WriteDocumentsToLuceneDirectoryChain(final Path directoryOutputPath) { 35 | this.directoryOutputPath = directoryOutputPath; 36 | } 37 | 38 | /** 39 | * creates a {@link WriteDocumentsToLuceneDirectoryChain} with a default temp 40 | * directory path 41 | * 42 | * @throws IOException on error creating the temp dir 43 | */ 44 | public WriteDocumentsToLuceneDirectoryChain() throws IOException { 45 | this(Files.createTempDirectory("lucene")); 46 | } 47 | 48 | @Override 49 | public Directory run(final Stream> input) { 50 | final Directory indexDirectory; 51 | try { 52 | indexDirectory = new MMapDirectory(this.directoryOutputPath); 53 | } catch (final IOException ioException) { 54 | throw new IllegalStateException("error creating index", ioException); 55 | } 56 | 57 | final StandardAnalyzer analyzer = new StandardAnalyzer(); 58 | final IndexWriterConfig config = new IndexWriterConfig(analyzer); 59 | 60 | try (final IndexWriter indexWriter = new IndexWriter(indexDirectory, config)) { 61 | 62 | input.forEach(document -> { 63 | final Document doc = new Document(); 64 | doc.add(new TextField(PromptConstants.CONTENT, document.get(PromptConstants.CONTENT), Field.Store.YES)); 65 | doc.add(new TextField(PromptConstants.SOURCE, document.get(PromptConstants.SOURCE), Field.Store.YES)); 66 | try { 67 | 
indexWriter.addDocument(doc); 68 | } catch (final IOException innerIoException) { 69 | throw new IllegalStateException("error writing document: " + document, innerIoException); 70 | } 71 | }); 72 | 73 | indexWriter.commit(); 74 | } catch (final IOException ioException) { 75 | throw new IllegalStateException("error creating writer", ioException); 76 | } 77 | 78 | return indexDirectory; 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/llm/LargeLanguageModelChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm; 2 | 3 | import java.util.Map; 4 | 5 | import io.github.cupybara.javalangchains.chains.Chain; 6 | 7 | /** 8 | * Parent of all {@link Chain Chains} which allow passing input to a large 9 | * language model. Accepts a document of key value pairs and provides the LLM 10 | * output. 11 | */ 12 | public abstract class LargeLanguageModelChain implements Chain, String> { 13 | 14 | /** 15 | * The template which contains placeholders in the form ${myPlaceholder} that 16 | * are replaced for input documents before creating a request to a LLM. 
17 | */ 18 | private final String promptTemplate; 19 | 20 | /** 21 | * creates an instance of the {@link LargeLanguageModelChain} 22 | * 23 | * @param promptTemplate {@link #promptTemplate} 24 | */ 25 | protected LargeLanguageModelChain(final String promptTemplate) { 26 | this.promptTemplate = promptTemplate; 27 | } 28 | 29 | /** 30 | * @return {@link #promptTemplate} 31 | */ 32 | protected final String getPromptTemplate() { 33 | return promptTemplate; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/llm/azure/chat/AzureOpenAiChatCompletionsChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.azure.chat; 2 | 3 | import java.net.URI; 4 | 5 | import org.springframework.http.MediaType; 6 | import org.springframework.web.reactive.function.BodyInserters; 7 | import org.springframework.web.reactive.function.client.WebClient; 8 | import org.springframework.web.reactive.function.client.WebClient.ResponseSpec; 9 | import org.springframework.web.util.UriComponentsBuilder; 10 | 11 | import com.fasterxml.jackson.databind.ObjectMapper; 12 | 13 | import io.github.cupybara.javalangchains.chains.llm.openai.chat.OpenAiChatCompletionsChain; 14 | import io.github.cupybara.javalangchains.chains.llm.openai.chat.OpenAiChatCompletionsParameters; 15 | import io.github.cupybara.javalangchains.chains.llm.openai.chat.OpenAiChatCompletionsRequest; 16 | 17 | /** 18 | * {@link OpenAiChatCompletionsChain} adopted for usage of Azure OpenAI Services 19 | */ 20 | public final class AzureOpenAiChatCompletionsChain extends OpenAiChatCompletionsChain { 21 | 22 | private final URI requestUri; 23 | 24 | /** 25 | * @param resourceName Name of the azure resource 26 | * @param deploymentName Name of the azure openai service deployment 27 | * @param apiVersion The target API Version 28 | * @param 
promptTemplate The template which contains placeholders in the form 29 | * ${myPlaceholder} that are replaced for input documents 30 | * before creating a request to a LLM. 31 | * @param parameters The {@link OpenAiChatCompletionsParameters} allows to 32 | * finetune requests to the OpenAI API; may be {@code null}, but must not set a model 33 | * @param apiKey The API-Key used for Authentication (passed using the 34 | * "api-key" Header) 35 | * @param systemTemplate The template for the system role which contains 36 | * placeholders in the form ${myPlaceholder} that are 37 | * replaced for input documents before creating a request 38 | * to a LLM. 39 | * @param objectMapper The {@link ObjectMapper} used for body serialization 40 | * and deserialization 41 | * @param webClient The {@link WebClient} used for executing requests to 42 | * the OpenAI API 43 | */ 44 | public AzureOpenAiChatCompletionsChain(final String resourceName, final String deploymentName, 45 | final String apiVersion, final String promptTemplate, final OpenAiChatCompletionsParameters parameters, 46 | final String apiKey, final String systemTemplate, final ObjectMapper objectMapper, 47 | final WebClient webClient) { 48 | super(promptTemplate, parameters, apiKey, systemTemplate, objectMapper, webClient); 49 | // parameters are optional, so only reject an explicitly configured model 50 | if (parameters != null && parameters.getModel() != null) { 51 | throw new IllegalArgumentException( 52 | "the model parameter cannot be used for the Azure OpenAI Services, it is passed via deploymentName instead"); 53 | } 54 | 55 | this.requestUri = UriComponentsBuilder.newInstance().scheme("https") 56 | .host(String.format("%s.openai.azure.com", resourceName)).queryParam("api-version", apiVersion) 57 | .path(String.format("/openai/deployments/%s/chat/completions", deploymentName)).build().toUri(); 58 | } 59 | 60 | /** 61 | * @param resourceName Name of the azure resource 62 | * @param deploymentName Name of the azure openai service deployment 63 | * @param apiVersion The target API Version 64 | * @param promptTemplate The template which contains
placeholders in the form 65 | * ${myPlaceholder} that are replaced for input documents 66 | * before creating a request to a LLM. 67 | * @param parameters The {@link OpenAiChatCompletionsParameters} allows to 68 | * finetune requests to the OpenAI API 69 | * @param apiKey The API-Key used for Authentication (passed using the 70 | * "api-key" Header) 71 | * @param systemTemplate The template for the system role which contains 72 | * placeholders in the form ${myPlaceholder} that are 73 | * replaced for input documents before creating a request 74 | */ 75 | public AzureOpenAiChatCompletionsChain(final String resourceName, final String deploymentName, 76 | final String apiVersion, final String promptTemplate, final OpenAiChatCompletionsParameters parameters, 77 | final String apiKey, final String systemTemplate) { 78 | this(resourceName, deploymentName, apiVersion, promptTemplate, parameters, apiKey, systemTemplate, 79 | createDefaultObjectMapper(), createDefaultWebClient()); 80 | } 81 | 82 | /** 83 | * @param resourceName Name of the azure resource 84 | * @param deploymentName Name of the azure openai service deployment 85 | * @param apiVersion The target API Version 86 | * @param promptTemplate The template which contains placeholders in the form 87 | * ${myPlaceholder} that are replaced for input documents 88 | * before creating a request to a LLM. 
89 | * @param parameters The {@link OpenAiChatCompletionsParameters} allows to 90 | * finetune requests to the OpenAI API 91 | * @param apiKey The API-Key used for Authentication (passed using the 92 | * "api-key" Header) 93 | */ 94 | public AzureOpenAiChatCompletionsChain(final String resourceName, final String deploymentName, 95 | final String apiVersion, final String promptTemplate, final OpenAiChatCompletionsParameters parameters, 96 | final String apiKey) { 97 | this(resourceName, deploymentName, apiVersion, promptTemplate, parameters, apiKey, null); 98 | } 99 | 100 | @Override 101 | protected ResponseSpec createResponseSpec(final OpenAiChatCompletionsRequest request, final WebClient webClient, 102 | final ObjectMapper objectMapper) { 103 | return webClient.post().uri(requestUri).contentType(MediaType.APPLICATION_JSON).header("api-key", getApiKey()) 104 | .body(BodyInserters.fromValue(requestToBody(request, objectMapper))).retrieve(); 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/llm/azure/completions/AzureOpenAiCompletionsChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.azure.completions; 2 | 3 | import java.net.URI; 4 | 5 | import org.springframework.http.MediaType; 6 | import org.springframework.web.reactive.function.BodyInserters; 7 | import org.springframework.web.reactive.function.client.WebClient; 8 | import org.springframework.web.reactive.function.client.WebClient.ResponseSpec; 9 | import org.springframework.web.util.UriComponentsBuilder; 10 | 11 | import com.fasterxml.jackson.databind.ObjectMapper; 12 | 13 | import io.github.cupybara.javalangchains.chains.llm.openai.chat.OpenAiChatCompletionsParameters; 14 | import io.github.cupybara.javalangchains.chains.llm.openai.completions.OpenAiCompletionsChain; 15 | import 
io.github.cupybara.javalangchains.chains.llm.openai.completions.OpenAiCompletionsParameters; 16 | import io.github.cupybara.javalangchains.chains.llm.openai.completions.OpenAiCompletionsRequest; 17 | 18 | /** 19 | * {@link OpenAiCompletionsChain} adopted for usage of Azure OpenAI Services 20 | */ 21 | public final class AzureOpenAiCompletionsChain extends OpenAiCompletionsChain { 22 | 23 | private final URI requestUri; 24 | 25 | /** 26 | * @param resourceName Name of the azure resource 27 | * @param deploymentName Name of the azure openai service deployment 28 | * @param apiVersion The target API Version 29 | * @param promptTemplate The template which contains placeholders in the form 30 | * ${myPlaceholder} that are replaced for input documents 31 | * before creating a request to a LLM. 32 | * @param parameters The {@link OpenAiCompletionsParameters} allows to 33 | * finetune requests to the OpenAI API; may be {@code null}, but must not set a model 34 | * @param apiKey The API-Key used for Authentication (passed using the 35 | * "api-key" Header) 36 | * @param objectMapper The {@link ObjectMapper} used for body serialization 37 | * and deserialization 38 | * @param webClient The {@link WebClient} used for executing requests to 39 | * the OpenAI API 40 | */ 41 | public AzureOpenAiCompletionsChain(final String resourceName, final String deploymentName, final String apiVersion, 42 | final String promptTemplate, final OpenAiCompletionsParameters parameters, final String apiKey, 43 | final ObjectMapper objectMapper, final WebClient webClient) { 44 | super(promptTemplate, parameters, apiKey, objectMapper, webClient); 45 | // parameters are optional, so only reject an explicitly configured model 46 | if (parameters != null && parameters.getModel() != null) { 47 | throw new IllegalArgumentException( 48 | "the model parameter cannot be used for the Azure OpenAI Services, it is passed via deploymentName instead"); 49 | } 50 | 51 | this.requestUri = UriComponentsBuilder.newInstance().scheme("https") 52 | .host(String.format("%s.openai.azure.com", resourceName)).queryParam("api-version", apiVersion) 53 |
.path(String.format("/openai/deployments/%s/completions", deploymentName)).build().toUri(); 54 | } 55 | 56 | /** 57 | * @param resourceName Name of the azure resource 58 | * @param deploymentName Name of the azure openai service deployment 59 | * @param apiVersion The target API Version 60 | * @param promptTemplate The template which contains placeholders in the form 61 | * ${myPlaceholder} that are replaced for input documents 62 | * before creating a request to a LLM. 63 | * @param parameters The {@link OpenAiChatCompletionsParameters} allows to 64 | * finetune requests to the OpenAI API 65 | * @param apiKey The API-Key used for Authentication (passed using the 66 | * "api-key" Header) 67 | */ 68 | public AzureOpenAiCompletionsChain(final String resourceName, final String deploymentName, final String apiVersion, 69 | final String promptTemplate, final OpenAiCompletionsParameters parameters, final String apiKey) { 70 | this(resourceName, deploymentName, apiVersion, promptTemplate, parameters, apiKey, createDefaultObjectMapper(), 71 | createDefaultWebClient()); 72 | } 73 | 74 | @Override 75 | protected ResponseSpec createResponseSpec(final OpenAiCompletionsRequest request, final WebClient webClient, 76 | final ObjectMapper objectMapper) { 77 | return webClient.post().uri(requestUri).contentType(MediaType.APPLICATION_JSON).header("api-key", getApiKey()) 78 | .body(BodyInserters.fromValue(requestToBody(request, objectMapper))).retrieve(); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/OpenAiChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.openai; 2 | 3 | import java.util.Map; 4 | 5 | import org.springframework.http.HttpHeaders; 6 | import org.springframework.http.MediaType; 7 | import org.springframework.web.reactive.function.BodyInserters; 8 | 
import org.springframework.web.reactive.function.client.WebClient; 9 | import org.springframework.web.reactive.function.client.WebClient.ResponseSpec; 10 | import org.springframework.web.util.UriComponentsBuilder; 11 | 12 | import com.fasterxml.jackson.annotation.JsonInclude; 13 | import com.fasterxml.jackson.core.JsonProcessingException; 14 | import com.fasterxml.jackson.databind.ObjectMapper; 15 | 16 | import io.github.cupybara.javalangchains.chains.llm.LargeLanguageModelChain; 17 | 18 | /** 19 | * {@link LargeLanguageModelChain} for usage with the OpenAI /completions API 20 | * 21 | * @param

the static parameter type 22 | * @param the static request type 23 | * @param the static response type 24 | */ 25 | public abstract class OpenAiChain

, I extends P, O extends OpenAiResponse> 26 | extends LargeLanguageModelChain { 27 | 28 | /** 29 | * The request path (/v1/chat/completions or /v1/completions) 30 | */ 31 | private final String requestPath; 32 | 33 | /** 34 | * The response type class 35 | */ 36 | private final Class responseClass; 37 | 38 | /** 39 | * The {@link OpenAiParameters} allows to tune requests to the OpenAI API 40 | */ 41 | private final P parameters; 42 | 43 | /** 44 | * The API-Key used for Authentication 45 | */ 46 | private final String apiKey; 47 | 48 | /** 49 | * The {@link ObjectMapper} used for body serialization and deserialization 50 | */ 51 | private final ObjectMapper objectMapper; 52 | 53 | /** 54 | * The {@link WebClient} used for executing requests to the OpenAI API 55 | */ 56 | private final WebClient webClient; 57 | 58 | /** 59 | * @param promptTemplate {@link #getPromptTemplate()} 60 | * @param requestPath {@link #requestPath} 61 | * @param responseClass {@link #responseClass} 62 | * @param parameters {@link #parameters}r 63 | * @param apiKey {@link #apiKey} 64 | * @param objectMapper {@link #objectMapper} 65 | * @param webClient {@link #webClient} 66 | */ 67 | protected OpenAiChain(final String promptTemplate, final String requestPath, final Class responseClass, 68 | final P parameters, final String apiKey, final ObjectMapper objectMapper, final WebClient webClient) { 69 | super(promptTemplate); 70 | this.requestPath = requestPath; 71 | this.responseClass = responseClass; 72 | this.parameters = parameters; 73 | this.apiKey = apiKey; 74 | this.objectMapper = objectMapper; 75 | this.webClient = webClient; 76 | } 77 | 78 | /** 79 | * @param promptTemplate {@link #getPromptTemplate()} 80 | * @param requestPath {@link #requestPath} 81 | * @param responseClass {@link #responseClass} 82 | * @param parameters {@link #parameters} 83 | * @param apiKey {@link #apiKey} 84 | */ 85 | protected OpenAiChain(final String promptTemplate, final String requestPath, final Class 
responseClass, 86 | final P parameters, final String apiKey) { 87 | this(promptTemplate, requestPath, responseClass, parameters, apiKey, createDefaultObjectMapper(), 88 | createDefaultWebClient()); 89 | } 90 | 91 | /** 92 | * creates the request entity from the current document 93 | * 94 | * @param input the current document 95 | * @return the request entity 96 | */ 97 | protected abstract I createRequest(final Map input); 98 | 99 | /** 100 | * creates the chain output from the response entity 101 | * 102 | * @param response the response entity 103 | * @return this chain's output 104 | */ 105 | protected abstract String createOutput(O response); 106 | 107 | @Override 108 | public String run(final Map input) { 109 | final I request = createRequest(input); 110 | if (parameters != null) { 111 | request.copyFrom(parameters); 112 | } 113 | 114 | return createResponseSpec(request, webClient, objectMapper).bodyToMono(String.class) 115 | .map(responseBody -> bodyToResponse(responseBody, objectMapper)).map(this::createOutput).block(); 116 | } 117 | 118 | /** 119 | * executes the request to the OpenAI API. Protected so that it may be 120 | * overridden for other OpenAI API Providers. 
121 | * 122 | * @param request the request entity 123 | * @param webClient the {@link WebClient} to use for the request 124 | * @param objectMapper the {@link ObjectMapper} used for body serialization 125 | * @return the {@link ResponseSpec} 126 | */ 127 | protected ResponseSpec createResponseSpec(final I request, final WebClient webClient, 128 | final ObjectMapper objectMapper) { 129 | return this.webClient.post() 130 | .uri(UriComponentsBuilder.newInstance().scheme("https").host("api.openai.com").path(requestPath).build() 131 | .toUri()) 132 | .contentType(MediaType.APPLICATION_JSON).header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey) 133 | .body(BodyInserters.fromValue(requestToBody(request, objectMapper))).retrieve(); 134 | } 135 | 136 | /** 137 | * Serializes the request entity 138 | * 139 | * @param request the request entity to serialize 140 | * @param objectMapper {@link ObjectMapper} used for serialization 141 | * @return serialized the serialized request body 142 | */ 143 | protected String requestToBody(final I request, final ObjectMapper objectMapper) { 144 | try { 145 | return objectMapper.writeValueAsString(request); 146 | } catch (final JsonProcessingException jsonProcessingException) { 147 | throw new IllegalStateException("error creating request body", jsonProcessingException); 148 | } 149 | } 150 | 151 | private O bodyToResponse(final String responseBody, final ObjectMapper objectMapper) { 152 | try { 153 | return objectMapper.readValue(responseBody, this.responseClass); 154 | } catch (final JsonProcessingException jsonProcessingException) { 155 | throw new IllegalStateException("error deserializing responseBody " + responseBody, 156 | jsonProcessingException); 157 | } 158 | } 159 | 160 | /** 161 | * @return {@link #apiKey}p 162 | */ 163 | protected final String getApiKey() { 164 | return apiKey; 165 | } 166 | 167 | /** 168 | * @return a default configured {@link ObjectMapper} 169 | */ 170 | public static ObjectMapper createDefaultObjectMapper() 
{ 171 | return new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL); 172 | } 173 | 174 | /** 175 | * @return a default configured {@link WebClient} 176 | */ 177 | public static WebClient createDefaultWebClient() { 178 | return WebClient.create(); 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/OpenAiParameters.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.openai; 2 | 3 | import com.fasterxml.jackson.annotation.JsonIgnore; 4 | import com.fasterxml.jackson.annotation.JsonProperty; 5 | 6 | /** 7 | * Contains the intersection of parameters for the /chat/completions and 8 | * /completions api 9 | * 10 | * @param The type of this Parameter Class for typed fluent api return 11 | * values 12 | */ 13 | public abstract class OpenAiParameters> { 14 | 15 | /** 16 | * The base type for correctly typed fluent api return values 17 | */ 18 | @JsonIgnore 19 | private final Class typeClass; 20 | 21 | /** 22 | *

From 23 | * https://github.com/openai/openai-openapi/blob/master/openapi.yaml

24 | * 25 | * The maximum number of tokens allowed for the generated answer. By default, 26 | * the number of tokens the model can return will be (4096 - prompt tokens). 27 | */ 28 | @JsonProperty("max_tokens") 29 | private Integer maxTokens; 30 | 31 | /** 32 | *

From 33 | * https://github.com/openai/openai-openapi/blob/master/openapi.yaml

34 | * 35 | * ID of the model to use. Currently, only `gpt-3.5-turbo` and 36 | * `gpt-3.5-turbo-0301` are supported. 37 | */ 38 | private String model; 39 | 40 | /** 41 | *

From 42 | * https://github.com/openai/openai-openapi/blob/master/openapi.yaml

43 | * 44 | * How many completions to generate for each prompt. 45 | ** 46 | * Note:** Because this parameter generates many completions, it can quickly 47 | * consume your token quota. Use carefully and ensure that you have reasonable 48 | * settings for `max_tokens` and `stop`. 49 | */ 50 | @JsonProperty 51 | private Integer n; 52 | 53 | /** 54 | *

From 55 | * https://github.com/openai/openai-openapi/blob/master/openapi.yaml

56 | * 57 | * What sampling temperature to use, between 0 and 2. Higher values like 0.8 58 | * will make the output more random, while lower values like 0.2 will make it 59 | * more focused and deterministic. 60 | * 61 | * We generally recommend altering this or `top_p` but not both. 62 | */ 63 | @JsonProperty 64 | private Double temperature; 65 | 66 | /** 67 | * @param typeClass {@link #typeClass} 68 | */ 69 | protected OpenAiParameters(final Class typeClass) { 70 | this.typeClass = typeClass; 71 | } 72 | 73 | /** 74 | * @return {@link #maxTokens} 75 | */ 76 | public Integer getMaxTokens() { 77 | return maxTokens; 78 | } 79 | 80 | /** 81 | * @param maxTokens {@link #maxTokens} 82 | */ 83 | public void setMaxTokens(final Integer maxTokens) { 84 | this.maxTokens = maxTokens; 85 | } 86 | 87 | /** 88 | * @param maxTokens {@link #maxTokens} 89 | * @return this 90 | */ 91 | public T maxTokens(final Integer maxTokens) { 92 | setMaxTokens(maxTokens); 93 | return this.typeClass.cast(this); 94 | } 95 | 96 | /** 97 | * @return {@link #model} 98 | */ 99 | public String getModel() { 100 | return model; 101 | } 102 | 103 | /** 104 | * @param model {@link #model} 105 | */ 106 | public void setModel(final String model) { 107 | this.model = model; 108 | } 109 | 110 | /** 111 | * @param model {@link #model} 112 | * @return this 113 | */ 114 | public T model(final String model) { 115 | setModel(model); 116 | return this.typeClass.cast(this); 117 | } 118 | 119 | /** 120 | * @return {@link #n} 121 | */ 122 | public Integer getN() { 123 | return n; 124 | } 125 | 126 | /** 127 | * @param n {@link #n} 128 | */ 129 | public void setN(final Integer n) { 130 | this.n = n; 131 | } 132 | 133 | /** 134 | * @param n {@link #n} 135 | * @return this 136 | */ 137 | public T n(final Integer n) { 138 | setN(n); 139 | return this.typeClass.cast(this); 140 | } 141 | 142 | /** 143 | * @return {@link #temperature} 144 | */ 145 | public Double getTemperature() { 146 | return temperature; 147 | } 148 | 
149 | /** 150 | * @param temperature {@link #temperature} 151 | */ 152 | public void setTemperature(final Double temperature) { 153 | this.temperature = temperature; 154 | } 155 | 156 | /** 157 | * @param temperature {@link #temperature} 158 | * @return this 159 | */ 160 | public T temperature(final Double temperature) { 161 | this.setTemperature(temperature); 162 | return this.typeClass.cast(this); 163 | } 164 | 165 | /** 166 | * copies parameter values from another instance of {@link OpenAiParameters} 167 | * 168 | * @param parameters the source {@link OpenAiParameters} 169 | */ 170 | public void copyFrom(final T parameters) { 171 | this.setMaxTokens(parameters.getMaxTokens()); 172 | this.setModel(parameters.getModel()); 173 | this.setN(parameters.getN()); 174 | this.setTemperature(parameters.getTemperature()); 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/OpenAiResponse.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.openai; 2 | 3 | import java.util.List; 4 | 5 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties; 6 | 7 | /** 8 | * Model class for the response body of an OpenAI request 9 | */ 10 | @JsonIgnoreProperties(ignoreUnknown = true) 11 | public abstract class OpenAiResponse { 12 | 13 | /** 14 | * All contained choice instances of the response. 
15 | */ 16 | private final List choices; 17 | 18 | /** 19 | * @param choices {@link #choices} 20 | */ 21 | protected OpenAiResponse(final List choices) { 22 | this.choices = choices; 23 | } 24 | 25 | /** 26 | * @return {@link #choices} 27 | */ 28 | public List getChoices() { 29 | return choices; 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/chat/OpenAiChatCompletionsChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.openai.chat; 2 | 3 | import java.util.LinkedList; 4 | import java.util.List; 5 | import java.util.Map; 6 | 7 | import org.apache.commons.text.StringSubstitutor; 8 | import org.springframework.web.reactive.function.client.WebClient; 9 | 10 | import com.fasterxml.jackson.databind.ObjectMapper; 11 | 12 | import io.github.cupybara.javalangchains.chains.llm.openai.OpenAiChain; 13 | 14 | /** 15 | * {@link OpenAiChain} for usage with the OpenAI /chat/completions API 16 | */ 17 | public class OpenAiChatCompletionsChain extends 18 | OpenAiChain { 19 | 20 | /** 21 | * The template for the system role which contains placeholders in the form 22 | * ${myPlaceholder} that are replaced for input documents before creating a 23 | * request to a LLM. 
24 | */ 25 | private final String systemTemplate; 26 | 27 | /** 28 | * @param promptTemplate {@link #getPromptTemplate()} 29 | * @param parameters {@link #parameters}r 30 | * @param apiKey {@link #apiKey} 31 | * @param systemTemplate {@link #systemTemplate} 32 | * @param objectMapper {@link #objectMapper} 33 | * @param webClient {@link #webClient} 34 | */ 35 | public OpenAiChatCompletionsChain(final String promptTemplate, final OpenAiChatCompletionsParameters parameters, 36 | final String apiKey, final String systemTemplate, final ObjectMapper objectMapper, 37 | final WebClient webClient) { 38 | super(promptTemplate, "/v1/chat/completions", OpenAiChatCompletionsResponse.class, parameters, apiKey, 39 | objectMapper, webClient); 40 | this.systemTemplate = systemTemplate; 41 | } 42 | 43 | /** 44 | * @param promptTemplate {@link #getPromptTemplate()} 45 | * @param parameters {@link #parameters} 46 | * @param apiKey {@link #apiKey} 47 | * @param systemTemplate {@link #systemTemplate}s 48 | */ 49 | public OpenAiChatCompletionsChain(final String promptTemplate, final OpenAiChatCompletionsParameters parameters, 50 | final String apiKey, final String systemTemplate) { 51 | this(promptTemplate, parameters, apiKey, systemTemplate, createDefaultObjectMapper(), createDefaultWebClient()); 52 | } 53 | 54 | /** 55 | * @param promptTemplate {@link #getPromptTemplate()} 56 | * @param parameters {@link #parameters} 57 | * @param apiKey {@link #apiKey} 58 | */ 59 | public OpenAiChatCompletionsChain(final String promptTemplate, final OpenAiChatCompletionsParameters parameters, 60 | final String apiKey) { 61 | this(promptTemplate, parameters, apiKey, null); 62 | } 63 | 64 | @Override 65 | protected OpenAiChatCompletionsRequest createRequest(Map input) { 66 | final List messages = new LinkedList<>(); 67 | if (systemTemplate != null) { 68 | messages.add(new OpenAiChatMessage("system", new StringSubstitutor(input).replace(systemTemplate))); 69 | } 70 | messages.add(new 
OpenAiChatMessage("user", new StringSubstitutor(input).replace(getPromptTemplate()))); 71 | 72 | return new OpenAiChatCompletionsRequest(messages); 73 | } 74 | 75 | @Override 76 | protected String createOutput(final OpenAiChatCompletionsResponse response) { 77 | return response.getChoices().get(0).getMessage().getContent(); 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/chat/OpenAiChatCompletionsChoice.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.openai.chat; 2 | 3 | import com.fasterxml.jackson.annotation.JsonCreator; 4 | import com.fasterxml.jackson.annotation.JsonCreator.Mode; 5 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties; 6 | import com.fasterxml.jackson.annotation.JsonProperty; 7 | 8 | /** 9 | * Model class for choices in an OpenAI /chat/completions response 10 | */ 11 | @JsonIgnoreProperties(ignoreUnknown = true) 12 | public final class OpenAiChatCompletionsChoice { 13 | 14 | /** 15 | * the {@link OpenAiChatMessage} for this response choice 16 | */ 17 | private final OpenAiChatMessage message; 18 | 19 | /** 20 | * @param message {@link #message} 21 | */ 22 | @JsonCreator(mode = Mode.PROPERTIES) 23 | public OpenAiChatCompletionsChoice(final @JsonProperty("message") OpenAiChatMessage message) { 24 | this.message = message; 25 | } 26 | 27 | /** 28 | * @return {@link #message} 29 | */ 30 | public OpenAiChatMessage getMessage() { 31 | return message; 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/chat/OpenAiChatCompletionsParameters.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.openai.chat; 2 | 3 | import 
java.util.Map; 4 | 5 | import com.fasterxml.jackson.annotation.JsonProperty; 6 | 7 | import io.github.cupybara.javalangchains.chains.llm.openai.OpenAiParameters; 8 | 9 | /** 10 | * Parameters for calling an OpenAI Chat Model 11 | * 12 | * https://platform.openai.com/docs/api-reference/chat/create 13 | */ 14 | public class OpenAiChatCompletionsParameters extends OpenAiParameters { 15 | 16 | /** 17 | *

From 18 | * https://github.com/openai/openai-openapi/blob/master/openapi.yaml

19 | * 20 | * Number between -2.0 and 2.0. Positive values penalize new tokens based on 21 | * their existing frequency in the text so far, decreasing the model's 22 | * likelihood to repeat the same line verbatim. Serialized as {@code frequency_penalty}. 23 | */ 24 | @JsonProperty("frequency_penalty") 25 | private Double frequencePenalty; 26 | 27 | /** 28 | *

From 29 | * https://github.com/openai/openai-openapi/blob/master/openapi.yaml

30 | * 31 | * Modify the likelihood of specified tokens appearing in the completion. 32 | * Accepts a json object that maps tokens (specified by their token ID in the 33 | * tokenizer) to an associated bias value from -100 to 100. Mathematically, the 34 | * bias is added to the logits generated by the model prior to sampling. The 35 | * exact effect will vary per model, but values between -1 and 1 should decrease 36 | * or increase likelihood of selection; values like -100 or 100 should result in 37 | * a ban or exclusive selection of the relevant token. 38 | */ 39 | private Map logitBias; 40 | 41 | /** 42 | *

From 43 | * https://github.com/openai/openai-openapi/blob/master/openapi.yaml

44 | * 45 | * Number between -2.0 and 2.0. Positive values penalize new tokens based on 46 | * whether they appear in the text so far, increasing the model's likelihood to 47 | * talk about new topics. Serialized as {@code presence_penalty}. 48 | */ 49 | @JsonProperty("presence_penalty") 50 | private Double presencePenalty; 51 | 52 | /** 53 | * creates an instance of {@link OpenAiChatCompletionsParameters} 54 | */ 55 | public OpenAiChatCompletionsParameters() { 56 | super(OpenAiChatCompletionsParameters.class); 57 | } 58 | 59 | /** 60 | * @return {@link #frequencePenalty} 61 | */ 62 | public Double getFrequencePenalty() { 63 | return frequencePenalty; 64 | } 65 | 66 | /** 67 | * @param frequencePenalty {@link #frequencePenalty} 68 | */ 69 | public void setFrequencePenalty(final Double frequencePenalty) { 70 | this.frequencePenalty = frequencePenalty; 71 | } 72 | 73 | /** 74 | * @param frequencePenalty {@link #frequencePenalty} 75 | * @return this 76 | */ 77 | public OpenAiChatCompletionsParameters frequencePenalty(final Double frequencePenalty) { 78 | this.setFrequencePenalty(frequencePenalty); 79 | return this; 80 | } 81 | 82 | /** 83 | * @return {@link #logitBias} 84 | */ 85 | public Map getLogitBias() { 86 | return logitBias; 87 | } 88 | 89 | /** 90 | * @param logitBias {@link #logitBias} 91 | */ 92 | public void setLogitBias(final Map logitBias) { 93 | this.logitBias = logitBias; 94 | } 95 | 96 | /** 97 | * @param logitBias {@link #logitBias} 98 | * @return this 99 | */ 100 | public OpenAiChatCompletionsParameters logitBias(final Map logitBias) { 101 | setLogitBias(logitBias); 102 | return this; 103 | } 104 | 105 | /** 106 | * @return {@link #presencePenalty} 107 | */ 108 | public Double getPresencePenalty() { 109 | return presencePenalty; 110 | } 111 | 112 | /** 113 | * @param presencePenalty {@link #presencePenalty} 114 | */ 115 | public void setPresencePenalty(final Double presencePenalty) { 116 | this.presencePenalty = presencePenalty; 117 | } 118 | 119 | /** 120 | * @param presencePenalty {@link
#presencePenalty}
121 |      * @return this
122 |      */
123 |     public OpenAiChatCompletionsParameters presencePenalty(final Double presencePenalty) {
124 |         this.setPresencePenalty(presencePenalty);
125 |         return this;
126 |     }
127 | 
128 |     /**
129 |      * copies parameter values from another instance of {@link OpenAiChatCompletionsParameters}
130 |      * 
131 |      * @param parameters the source {@link OpenAiChatCompletionsParameters}
132 |      */
133 |     @Override
134 |     public void copyFrom(final OpenAiChatCompletionsParameters parameters) {
135 |         super.copyFrom(parameters);
136 | 
137 |         this.setFrequencePenalty(parameters.getFrequencePenalty());
138 |         this.setLogitBias(parameters.getLogitBias());
139 |         this.setPresencePenalty(parameters.getPresencePenalty());
140 |     }
141 | }
142 | 
--------------------------------------------------------------------------------
/src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/chat/OpenAiChatCompletionsRequest.java:
--------------------------------------------------------------------------------
1 | package io.github.cupybara.javalangchains.chains.llm.openai.chat;
2 | 
3 | import java.util.List;
4 | 
5 | /**
6 |  * Model class for the OpenAI /chat/completions request body. Extends {@link OpenAiChatCompletionsParameters} and adds the conversation messages.
7 |  */
8 | public final class OpenAiChatCompletionsRequest extends OpenAiChatCompletionsParameters {
9 | 
10 |     /**
11 |      * The {@link OpenAiChatMessage} instances of the conversation
12 |      */
13 |     private final List messages;
14 | 
15 |     /**
16 |      * @param messages {@link #messages}; the passed list is stored as is (no copy is made)
17 |      */
18 |     public OpenAiChatCompletionsRequest(final List messages) {
19 |         this.messages = messages;
20 |     }
21 | 
22 |     /**
23 |      * @return {@link #messages}
24 |      */
25 |     public List getMessages() {
26 |         return messages;
27 |     }
28 | }
29 | 
--------------------------------------------------------------------------------
/src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/chat/OpenAiChatCompletionsResponse.java:
--------------------------------------------------------------------------------
1 | package io.github.cupybara.javalangchains.chains.llm.openai.chat;
2 | 
3 | import java.util.List;
4 | 
5 | import com.fasterxml.jackson.annotation.JsonCreator;
6 | import com.fasterxml.jackson.annotation.JsonCreator.Mode;
7 | 
8 | import io.github.cupybara.javalangchains.chains.llm.openai.OpenAiResponse;
9 | 
10 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
11 | import com.fasterxml.jackson.annotation.JsonProperty;
12 | 
13 | /**
14 |  * Model class for the response body of an OpenAI /chat/completions request. Unknown JSON properties are ignored.
15 |  */
16 | @JsonIgnoreProperties(ignoreUnknown = true)
17 | public final class OpenAiChatCompletionsResponse extends OpenAiResponse {
18 | 
19 |     /**
20 |      * @param choices {@link #getChoices()}, bound to the JSON property "choices" and handed to the super constructor
21 |      */
22 |     @JsonCreator(mode = Mode.PROPERTIES)
23 |     public OpenAiChatCompletionsResponse(final @JsonProperty("choices") List choices) {
24 |         super(choices);
25 |     }
26 | }
27 | 
--------------------------------------------------------------------------------
/src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/chat/OpenAiChatMessage.java:
--------------------------------------------------------------------------------
1 | package io.github.cupybara.javalangchains.chains.llm.openai.chat;
2 | 
3 | import com.fasterxml.jackson.annotation.JsonCreator;
4 | import com.fasterxml.jackson.annotation.JsonProperty;
5 | import com.fasterxml.jackson.annotation.JsonCreator.Mode;
6 | 
7 | /**
8 |  * Immutable model class for request and response messages of an OpenAI /chat/completions
9 |  * request
10 |  */
11 | public final class OpenAiChatMessage {
12 | 
13 |     /**
14 |      * The message role: system|user|assistant
15 |      */
16 |     private final String role;
17 | 
18 |     /**
19 |      * Message text
20 |      */
21 |     private final String content;
22 | 
23 |     /**
24 |      * @param role {@link #role}, bound to the JSON property "role"
25 |      * @param content {@link #content}, bound to the JSON property "content"
26 |      */
27 |     @JsonCreator(mode = Mode.PROPERTIES)
28 |     public OpenAiChatMessage(final @JsonProperty("role") String role, final @JsonProperty("content") String content) {
29 |         this.role = role;
30 |         this.content = content;
31 |     }
32 | 
33 |     /**
34 |      * @return {@link #role}
35 |      */
36 |     public String getRole() {
37 |         return role;
38 |     }
39 | 
40 |     /**
41 |      * @return {@link #content}
42 |      */
43 |     public String getContent() {
44 |         return content;
45 |     }
46 | }
47 | 
--------------------------------------------------------------------------------
/src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/completions/OpenAiCompletionsChain.java:
--------------------------------------------------------------------------------
1 | package io.github.cupybara.javalangchains.chains.llm.openai.completions;
2 | 
3 | import java.util.Map;
4 | 
5 | import org.apache.commons.text.StringSubstitutor;
6 | import org.springframework.web.reactive.function.client.WebClient;
7 | 
8 | import com.fasterxml.jackson.databind.ObjectMapper;
9 | 
10 | import io.github.cupybara.javalangchains.chains.llm.openai.OpenAiChain;
11 | 
12 | /**
13 |  * {@link OpenAiChain} for usage with the OpenAI /completions API (request path "/v1/completions")
14 |  */
15 | public class OpenAiCompletionsChain
16 |         extends OpenAiChain {
17 | 
18 |     /**
19 |      * @param promptTemplate {@link #getPromptTemplate()}
20 |      * @param parameters {@link #parameters}
21 |      * @param apiKey {@link #apiKey}
22 |      * @param objectMapper {@link #objectMapper}
23 |      * @param webClient {@link #webClient}
24 |      */
25 |     public OpenAiCompletionsChain(final String promptTemplate, final OpenAiCompletionsParameters parameters,
26 |             final String apiKey, final ObjectMapper objectMapper, final WebClient webClient) {
27 |         super(promptTemplate, "/v1/completions", OpenAiCompletionsResponse.class, parameters, apiKey, objectMapper,
28 |                 webClient);
29 |     }
30 | 
31 |     /**
32 |      * Uses the default object mapper and web client created by createDefaultObjectMapper() and createDefaultWebClient()
33 |      * @param promptTemplate {@link #getPromptTemplate()}
34 |      * @param parameters {@link #parameters}
35 |      * @param apiKey {@link #apiKey}
36 |      */
37 |     public OpenAiCompletionsChain(final String promptTemplate, final OpenAiCompletionsParameters parameters,
38 |             final String apiKey) {
39 |         this(promptTemplate, parameters, apiKey, createDefaultObjectMapper(), createDefaultWebClient());
40 |     }
41 |     /** Creates the request whose prompt is the prompt template with its placeholders substituted from the input map. */
42 |     @Override
43 |     protected OpenAiCompletionsRequest createRequest(final Map input) {
44 |         return new OpenAiCompletionsRequest(new StringSubstitutor(input).replace(getPromptTemplate()));
45 |     }
46 |     /** Returns the text of the first choice of the response; the choices list is not checked for being empty. */
47 |     @Override
48 |     protected String createOutput(final OpenAiCompletionsResponse response) {
49 |         return response.getChoices().get(0).getText();
50 |     }
51 | }
52 | 
--------------------------------------------------------------------------------
/src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/completions/OpenAiCompletionsChoice.java:
--------------------------------------------------------------------------------
1 | package io.github.cupybara.javalangchains.chains.llm.openai.completions;
2 | 
3 | import com.fasterxml.jackson.annotation.JsonCreator;
4 | import com.fasterxml.jackson.annotation.JsonCreator.Mode;
5 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
6 | import com.fasterxml.jackson.annotation.JsonProperty;
7 | 
8 | /**
9 |  * Model class for choices in an OpenAI /completions response. Unknown JSON properties are ignored.
10 |  */
11 | @JsonIgnoreProperties(ignoreUnknown = true)
12 | public final class OpenAiCompletionsChoice {
13 | 
14 |     /**
15 |      * the completion result
16 |      */
17 |     private final String text;
18 | 
19 |     /**
20 |      * @param text {@link #text}, bound to the JSON property "text"
21 |      */
22 |     @JsonCreator(mode = Mode.PROPERTIES)
23 |     public OpenAiCompletionsChoice(final @JsonProperty("text") String text) {
24 |         this.text = text;
25 |     }
26 | 
27 |     /**
28 |      * @return {@link #text}
29 |      */
30 |     public String getText() {
31 |         return text;
32 |     }
33 | }
34 | 
--------------------------------------------------------------------------------
/src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/completions/OpenAiCompletionsParameters.java:
--------------------------------------------------------------------------------
1 | package
io.github.cupybara.javalangchains.chains.llm.openai.completions;
2 | 
3 | import java.util.Set;
4 | 
5 | import io.github.cupybara.javalangchains.chains.llm.openai.OpenAiParameters;
6 | 
7 | /**
8 |  * Parameters for calling an OpenAI Completions Model
9 |  * 
10 |  * https://platform.openai.com/docs/api-reference/completions/create
11 |  */
12 | public class OpenAiCompletionsParameters extends OpenAiParameters {
13 | 
14 |     /**
15 |      * From
16 |      * https://github.com/openai/openai-openapi/blob/master/openapi.yaml
17 |      * 
18 |      * Up to 4 sequences where the API will stop generating further tokens. The
19 |      * returned text will not contain the stop sequence.
20 |      */
21 |     private Set stop;
22 | 
23 |     /**
24 |      * Creates an instance of {@link OpenAiCompletionsParameters}
25 |      */
26 |     public OpenAiCompletionsParameters() {
27 |         super(OpenAiCompletionsParameters.class);
28 |     }
29 | 
30 |     /**
31 |      * @return {@link #stop}
32 |      */
33 |     public Set getStop() {
34 |         return stop;
35 |     }
36 | 
37 |     /**
38 |      * @param stop {@link #stop}
39 |      */
40 |     public void setStop(final Set stop) {
41 |         this.stop = stop;
42 |     }
43 | 
44 |     /**
45 |      * @param stop {@link #stop}
46 |      * @return this
47 |      */
48 |     public OpenAiCompletionsParameters stop(final Set stop) {
49 |         setStop(stop);
50 |         return this;
51 |     }
52 | 
53 |     /**
54 |      * copies parameter values from another instance of
55 |      * {@link OpenAiCompletionsParameters}. In addition to the values copied by the
56 |      * super class the completions specific {@link #stop} value is copied, in the
57 |      * same way OpenAiChatCompletionsParameters copies its own fields.
58 |      * 
59 |      * @param parameters the source {@link OpenAiCompletionsParameters}
60 |      */
61 |     @Override
62 |     public void copyFrom(final OpenAiCompletionsParameters parameters) {
63 |         super.copyFrom(parameters);
64 | 
65 |         // the same Set instance is shared with the source, it is not cloned
66 |         this.setStop(parameters.getStop());
67 |     }
68 | }
69 | 
--------------------------------------------------------------------------------
/src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/completions/OpenAiCompletionsRequest.java:
--------------------------------------------------------------------------------
1 | package io.github.cupybara.javalangchains.chains.llm.openai.completions;
2 | 
3 | /**
4 |  * Model class for the OpenAI /completions request body. Extends
5 |  * {@link OpenAiCompletionsParameters} and adds the prompt.
6 |  */
7 | public final class OpenAiCompletionsRequest extends OpenAiCompletionsParameters {
8 | 
9 |     /**
10 |      * The prompt for the model
11 |      */
12 |     private final String prompt;
13 | 
14 |     /**
15 |      * @param prompt {@link #prompt}
16 |      */
17 |     public OpenAiCompletionsRequest(final String prompt) {
18 |         this.prompt = prompt;
19 |     }
20 | 
21 |     /**
22 |      * @return {@link #prompt}
23 |      */
24 |     public String getPrompt() {
25 |         return prompt;
26 |     }
27 | }
28 | 
--------------------------------------------------------------------------------
/src/main/java/io/github/cupybara/javalangchains/chains/llm/openai/completions/OpenAiCompletionsResponse.java:
--------------------------------------------------------------------------------
1 | package
io.github.cupybara.javalangchains.chains.llm.openai.completions; 2 | 3 | import java.util.List; 4 | 5 | import com.fasterxml.jackson.annotation.JsonCreator; 6 | import com.fasterxml.jackson.annotation.JsonCreator.Mode; 7 | 8 | import io.github.cupybara.javalangchains.chains.llm.openai.OpenAiResponse; 9 | 10 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties; 11 | import com.fasterxml.jackson.annotation.JsonProperty; 12 | 13 | /** 14 | * Model class for the response body of an OpenAI /completions request 15 | */ 16 | @JsonIgnoreProperties(ignoreUnknown = true) 17 | public final class OpenAiCompletionsResponse extends OpenAiResponse { 18 | 19 | /** 20 | * @param choices {@link #getChoices()} 21 | */ 22 | @JsonCreator(mode = Mode.PROPERTIES) 23 | public OpenAiCompletionsResponse(final @JsonProperty("choices") List choices) { 24 | super(choices); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/qa/AnswerWithSources.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa; 2 | 3 | import java.util.Collections; 4 | import java.util.List; 5 | import java.util.stream.Collectors; 6 | 7 | /** 8 | * Model class for QA output with sources 9 | */ 10 | public final class AnswerWithSources { 11 | 12 | private final String answer; 13 | private final List sources; 14 | 15 | /** 16 | * Creates an instance of {@link AnswerWithSources} 17 | * 18 | * @param answer {@link #answer} 19 | * @param sources {@link #sources} 20 | */ 21 | public AnswerWithSources(final String answer, final List sources) { 22 | this.answer = answer; 23 | this.sources = sources; 24 | } 25 | 26 | /** 27 | * Creates an instance of {@link AnswerWithSources} 28 | * 29 | * @param answer {@link #answer} 30 | */ 31 | public AnswerWithSources(final String answer) { 32 | this(answer, Collections.emptyList()); 33 | } 34 
| 35 | /** 36 | * @return {@link #answer} 37 | */ 38 | public String getAnswer() { 39 | return answer; 40 | } 41 | 42 | /** 43 | * @return {@link #sources} 44 | */ 45 | public List getSources() { 46 | return sources; 47 | } 48 | 49 | @Override 50 | public String toString() { 51 | return String.format("%s (%s)", getAnswer(), getSources().stream().collect(Collectors.joining(", "))); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/qa/CombineDocumentsChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | import java.util.concurrent.atomic.AtomicReference; 6 | import java.util.stream.Collectors; 7 | import java.util.stream.Stream; 8 | 9 | import org.apache.commons.text.StringSubstitutor; 10 | 11 | import io.github.cupybara.javalangchains.chains.Chain; 12 | import io.github.cupybara.javalangchains.util.PromptConstants; 13 | import io.github.cupybara.javalangchains.util.PromptTemplates; 14 | 15 | /** 16 | * This {@link Chain} is used to combine multiple retrieved documents into one 17 | * prompt which can then be used to target a LLM in subsequent steps. 18 | */ 19 | public class CombineDocumentsChain implements Chain>, Map> { 20 | 21 | /** 22 | * The template for each single document which contains placeholders in the form 23 | * ${myPlaceholder} that are replaced for each the keys of each input document. 
24 | */ 25 | private final String documentPromptTemplate; 26 | 27 | /** 28 | * creates an instance of the {@link CombineDocumentsChain} 29 | * 30 | * @param documentPromptTemplate {@link #documentPromptTemplate} 31 | */ 32 | public CombineDocumentsChain(final String documentPromptTemplate) { 33 | this.documentPromptTemplate = documentPromptTemplate; 34 | } 35 | 36 | /** 37 | * creates an instance of the {@link CombineDocumentsChain} 38 | */ 39 | public CombineDocumentsChain() { 40 | this(PromptTemplates.QA_DOCUMENT); 41 | } 42 | 43 | @Override 44 | public Map run(final Stream> input) { 45 | final AtomicReference questionRef = new AtomicReference<>(); 46 | 47 | final String combinedContent = input.map(document -> { 48 | if (questionRef.get() == null) { 49 | questionRef.set(document.get(PromptConstants.QUESTION)); 50 | } 51 | return this.createDocumentPrompt(document); 52 | }).collect(Collectors.joining("\n\n")); 53 | 54 | final Map result = new HashMap<>(); 55 | result.put(PromptConstants.QUESTION, questionRef.get()); 56 | result.put(PromptConstants.CONTENT, combinedContent); 57 | return result; 58 | } 59 | 60 | private String createDocumentPrompt(final Map document) { 61 | return new StringSubstitutor(document).replace(documentPromptTemplate); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/qa/MapAnswerWithSourcesChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa; 2 | 3 | import java.util.Arrays; 4 | import java.util.regex.Matcher; 5 | import java.util.regex.Pattern; 6 | import java.util.stream.Collectors; 7 | 8 | import io.github.cupybara.javalangchains.chains.Chain; 9 | 10 | /** 11 | * Splits answers with sources from a QA chain. 12 | * 13 | *

Examples

14 | * 15 | *
16 |  *  I don't know who James Anderson is.
17 |  *  SOURCES:
18 |  * 
19 | * 20 | *
21 |  *  There are two different John Does mentioned, one is a scientist and humanitarian, and the other is a traveler and author of a travel memoir. It is not clear if they are the same person or two different people with the same name.
22 |  *  SOURCES: 1, 3
23 |  * 
24 | */ 25 | public class MapAnswerWithSourcesChain implements Chain { 26 | 27 | /** 28 | * this {@link Pattern} is used to retrieve sources from a qa result string 29 | */ 30 | private final Pattern retrieveSourcesPattern; 31 | 32 | /** 33 | * @param retrieveSourcesPattern {@link #retrieveSourcesPattern} 34 | */ 35 | public MapAnswerWithSourcesChain(final Pattern retrieveSourcesPattern) { 36 | this.retrieveSourcesPattern = retrieveSourcesPattern; 37 | } 38 | 39 | /** 40 | * @param retrieveSourcesRegex used to create the 41 | * {@link #retrieveSourcesPattern} 42 | */ 43 | public MapAnswerWithSourcesChain(final String retrieveSourcesRegex) { 44 | this.retrieveSourcesPattern = Pattern.compile(retrieveSourcesRegex, Pattern.CASE_INSENSITIVE | Pattern.DOTALL); 45 | } 46 | 47 | /** 48 | * creates an instance of {@link MapAnswerWithSourcesChain} with a default regex 49 | * to retrieve sources 50 | */ 51 | public MapAnswerWithSourcesChain() { 52 | this("(.*?)(?:Source(?:s)?:\\s*)(.*)"); 53 | } 54 | 55 | @Override 56 | public AnswerWithSources run(final String input) { 57 | final Matcher matcher = retrieveSourcesPattern.matcher(input); 58 | 59 | if (matcher.find()) { 60 | final String content = matcher.group(1).trim(); 61 | final String sources = matcher.group(2).trim(); 62 | return new AnswerWithSources(content, 63 | Arrays.stream(sources.split(",")).map(String::trim).distinct().collect(Collectors.toList())); 64 | } else { 65 | return new AnswerWithSources(input); 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/qa/ModifyDocumentsContentChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | import java.util.stream.Stream; 6 | 7 | import org.apache.logging.log4j.LogManager; 8 | 9 | import 
io.github.cupybara.javalangchains.chains.Chain; 10 | import io.github.cupybara.javalangchains.util.PromptConstants; 11 | 12 | /** 13 | * {@link Chain} that takes documents as input and modifies their 14 | * {@link PromptConstants#CONTENT} entry using a llm chain that is passed as a 15 | * constructor param. 16 | */ 17 | public class ModifyDocumentsContentChain implements Chain>, Stream>> { 18 | 19 | /** 20 | * this {@link Chain} is applied each document (a LLM Chain for example) 21 | */ 22 | private final Chain, String> documentChain; 23 | 24 | /** 25 | * if true the {@link #documentChain} is called for each document on another 26 | * thread to increase overall performance 27 | */ 28 | private final boolean parallel; 29 | 30 | /** 31 | * @param documentChain {@link #documentChain} 32 | * @param parallel {@link #parallel} 33 | */ 34 | public ModifyDocumentsContentChain(final Chain, String> documentChain, final boolean parallel) { 35 | this.documentChain = documentChain; 36 | this.parallel = parallel; 37 | } 38 | 39 | /** 40 | * @param documentChain {@link #documentChain} 41 | */ 42 | public ModifyDocumentsContentChain(final Chain, String> documentChain) { 43 | this(documentChain, true); 44 | } 45 | 46 | @Override 47 | public Stream> run(final Stream> input) { 48 | final Stream> stream = input.map(document -> { 49 | LogManager.getLogger(getClass()).trace("pre modification: {}", document); 50 | final String mappedContent = documentChain.run(document); 51 | LogManager.getLogger(getClass()).trace("post modification: {}", mappedContent); 52 | 53 | final Map mappedDocument = new HashMap<>(document); 54 | mappedDocument.put(PromptConstants.CONTENT, mappedContent); 55 | return mappedDocument; 56 | }); 57 | 58 | if (parallel) { 59 | return stream.parallel(); 60 | } 61 | 62 | return stream; 63 | } 64 | } 65 | -------------------------------------------------------------------------------- 
/src/main/java/io/github/cupybara/javalangchains/chains/qa/split/JtokkitTextSplitter.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa.split; 2 | 3 | import com.knuddels.jtokkit.api.Encoding; 4 | 5 | /** 6 | * This {@link TextSplitter} splits documents based on their token count. For 7 | * that purpose jtokkit is 8 | * utilized. 9 | */ 10 | public final class JtokkitTextSplitter extends MaxLengthBasedTextSplitter { 11 | 12 | /** 13 | * the {@link Encoding} used for token counting 14 | */ 15 | private final Encoding encoding; 16 | 17 | /** 18 | * creates an instance of {@link JtokkitTextSplitter} 19 | * 20 | * @param encoding {@link #encoding} 21 | * @param maxTokens max amount of tokens for each chunk 22 | * @param textStreamer the {@link TextStreamer} used for streaming the base text 23 | */ 24 | public JtokkitTextSplitter(final Encoding encoding, final int maxTokens, final TextStreamer textStreamer) { 25 | super(maxTokens, textStreamer); 26 | this.encoding = encoding; 27 | } 28 | 29 | /** 30 | * creates an instance of {@link JtokkitTextSplitter} with sentence based text 31 | * streaming 32 | * 33 | * @param encoding {@link #encoding} 34 | * @param maxTokens max amount of tokens for each chunk 35 | */ 36 | public JtokkitTextSplitter(final Encoding encoding, final int maxTokens) { 37 | this(encoding, maxTokens, new TextStreamer()); 38 | } 39 | 40 | @Override 41 | protected int getLength(final String textPart) { 42 | return encoding.countTokens(textPart); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/qa/split/MaxLengthBasedTextSplitter.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa.split; 2 | 3 | import java.util.LinkedList; 4 | import java.util.List; 5 | import 
java.util.concurrent.atomic.AtomicReference; 6 | 7 | /** 8 | * abstract base class for {@link TextSplitter} implementations that use a max 9 | * length for deciding when to split 10 | */ 11 | public abstract class MaxLengthBasedTextSplitter implements TextSplitter { 12 | 13 | /** 14 | * max length for each chunk in the split 15 | */ 16 | private final int maxLength; 17 | 18 | /** 19 | * the {@link TextStreamer} used for streaming the base text (sentence by 20 | * sentence is the default) 21 | */ 22 | private final TextStreamer textStreamer; 23 | 24 | /** 25 | * @param maxLength {@link #maxLength} 26 | * @param textStreamer {@link #textStreamer} 27 | */ 28 | protected MaxLengthBasedTextSplitter(final int maxLength, final TextStreamer textStreamer) { 29 | this.maxLength = maxLength; 30 | this.textStreamer = textStreamer; 31 | } 32 | 33 | /** 34 | * creates a {@link MaxLengthBasedTextSplitter} using sentence wise text 35 | * streaming 36 | * 37 | * @param maxLength {@link #maxLength} 38 | */ 39 | protected MaxLengthBasedTextSplitter(final int maxLength) { 40 | this(maxLength, new TextStreamer()); 41 | } 42 | 43 | /** 44 | * provide the length value for a text part 45 | * 46 | * @param textPart the text part which needs to be measured 47 | * @return the length for the passed textPart 48 | */ 49 | protected abstract int getLength(String textPart); 50 | 51 | @Override 52 | public final List split(final String text) { 53 | final List split = new LinkedList<>(); 54 | 55 | final AtomicReference partition = new AtomicReference<>(""); 56 | 57 | this.textStreamer.stream(text).forEach(textPart -> { 58 | final String newPartition = partition.get() + textPart; 59 | 60 | final int newPartitionLength = getLength(newPartition); 61 | if (newPartitionLength > maxLength) { 62 | 63 | // the current textPart must be part of the next chunk 64 | if (textPart.length() == newPartition.length()) { 65 | throw new IllegalStateException( 66 | "Text partition " + textPart + " is too long. 
Try to use another TextStreamer."); 67 | } 68 | split.add(partition.get()); 69 | partition.set(textPart); 70 | 71 | } else if (newPartitionLength == maxLength) { 72 | 73 | // the current textPart is part of the current chunk but max length is reached 74 | split.add(newPartition); 75 | partition.set(""); 76 | 77 | } else { 78 | // the current textPart is part of the current chunk 79 | partition.set(newPartition); 80 | } 81 | }); 82 | 83 | final String trailingText = partition.get(); 84 | if (trailingText.length() > 0) { 85 | split.add(trailingText); 86 | } 87 | 88 | return split; 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/qa/split/SplitDocumentsChain.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa.split; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | import java.util.stream.Stream; 6 | 7 | import io.github.cupybara.javalangchains.chains.Chain; 8 | import io.github.cupybara.javalangchains.util.PromptConstants; 9 | 10 | /** 11 | * This {@link Chain} is used to split long documents into chunks. All document 12 | * keys are copied except for the {@link PromptConstants#CONTENT} which is 13 | * split. 14 | */ 15 | public class SplitDocumentsChain implements Chain>, Stream>> { 16 | 17 | /** 18 | * This {@link TextSplitter} is used to create one or more documents from an 19 | * input document based on the {@link PromptConstants#CONTENT} key. 
20 | */ 21 | private final TextSplitter textSplitter; 22 | 23 | /** 24 | * creates an instance of the {@link SplitDocumentsChain} 25 | * 26 | * @param textSplitter {@link #textSplitter} 27 | */ 28 | public SplitDocumentsChain(final TextSplitter textSplitter) { 29 | this.textSplitter = textSplitter; 30 | } 31 | 32 | @Override 33 | public Stream> run(final Stream> input) { 34 | return input.flatMap(this::splitDocument); 35 | } 36 | 37 | private Stream> splitDocument(final Map document) { 38 | final String content = document.get(PromptConstants.CONTENT); 39 | 40 | return this.textSplitter.split(content).stream().map(contentPart -> { 41 | final Map documentPart = new HashMap<>(document); 42 | documentPart.put(PromptConstants.CONTENT, contentPart); 43 | return documentPart; 44 | }); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/qa/split/TextSplitter.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa.split; 2 | 3 | import java.util.List; 4 | 5 | /** 6 | * Implementations are used by the {@link SplitDocumentsChain}. A 7 | * {@link TextSplitter} takes an input text and creates a {@link List} with one 8 | * or more result strings based on the original text. 
9 | */ 10 | @FunctionalInterface 11 | public interface TextSplitter { 12 | 13 | /** 14 | * Splits a text into one or more subtexts 15 | * 16 | * @param text text to split 17 | * @return {@link List} with text partitions 18 | */ 19 | List split(String text); 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/chains/qa/split/TextStreamer.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa.split; 2 | 3 | import java.text.BreakIterator; 4 | import java.util.Iterator; 5 | import java.util.NoSuchElementException; 6 | import java.util.Spliterator; 7 | import java.util.Spliterators; 8 | import java.util.function.Supplier; 9 | import java.util.stream.Stream; 10 | import java.util.stream.StreamSupport; 11 | 12 | /** 13 | * streams a text using a {@link BreakIterator} 14 | */ 15 | public class TextStreamer { 16 | 17 | /** 18 | * creates the {@link BreakIterator} used for streaming 19 | */ 20 | private final Supplier breakIteratorSupplier; 21 | 22 | /** 23 | * creates a {@link TextStreamer} using a custom {@link BreakIterator} 24 | * 25 | * @param breakIteratorSupplier {@link #breakIteratorSupplier} 26 | */ 27 | public TextStreamer(final Supplier breakIteratorSupplier) { 28 | this.breakIteratorSupplier = breakIteratorSupplier; 29 | } 30 | 31 | /** 32 | * creates a {@link TextStreamer} which streams sentences 33 | */ 34 | public TextStreamer() { 35 | this(BreakIterator::getSentenceInstance); 36 | } 37 | 38 | /** 39 | * creates a stream of text partitions 40 | * 41 | * @param text partitionized text 42 | * @return {@link Stream} of text partitions 43 | */ 44 | public Stream stream(final String text) { 45 | final BreakIterator breakIterator = breakIteratorSupplier.get(); 46 | breakIterator.setText(text); 47 | 48 | final Iterator breakIteratorAdapter = new Iterator() { 49 | int start = 
breakIterator.first(); 50 | int end = breakIterator.next(); 51 | 52 | @Override 53 | public boolean hasNext() { 54 | return end != BreakIterator.DONE; 55 | } 56 | 57 | @Override 58 | public String next() { 59 | if (end == BreakIterator.DONE) { 60 | throw new NoSuchElementException("No more words"); 61 | } 62 | 63 | final String textPartition = text.substring(start, end); 64 | start = end; 65 | end = breakIterator.next(); 66 | return textPartition; 67 | } 68 | }; 69 | 70 | return StreamSupport.stream(Spliterators.spliteratorUnknownSize(breakIteratorAdapter, Spliterator.ORDERED), 71 | false); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/util/PromptConstants.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.util; 2 | 3 | /** 4 | * Utility Class which holds constants for prompt placeholders 5 | */ 6 | public final class PromptConstants { 7 | 8 | /** 9 | * placeholder for the question in qa context 10 | */ 11 | public static final String QUESTION = "question"; 12 | 13 | /** 14 | * placeholder for text content in qa context 15 | */ 16 | public static final String CONTENT = "content"; 17 | 18 | /** 19 | * placeholder for sources in qa context 20 | */ 21 | public static final String SOURCE = "source"; 22 | 23 | private PromptConstants() { 24 | // not instantiated 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/io/github/cupybara/javalangchains/util/PromptTemplates.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.util; 2 | 3 | // @formatter:off 4 | 5 | /** 6 | * this utility class holds templates for various prompts 7 | */ 8 | public final class PromptTemplates { 9 | 10 | /** 11 | * 
https://github.com/hwchase17/langchain/blob/master/langchain/chains/qa_with_sources/map_reduce_prompt.py#L4 12 | */ 13 | public static final String QA_SUMMARIZE = String.format( 14 | "Use the following portion of a long document to see if any of the text is relevant to answer the question.\n" 15 | + "Return any relevant text verbatim.\n" 16 | + "${content}\n" 17 | + "Question: ${%s}\n" 18 | + "Relevant text, if any:", 19 | PromptConstants.QUESTION); 20 | 21 | /** 22 | * https://github.com/hwchase17/langchain/blob/master/langchain/chains/qa_with_sources/stuff_prompt.py#LL4C15-L38C13 23 | */ 24 | public static final String QA_COMBINE = String.format( 25 | "Given the following extracted parts of a long document and a question, create a final answer with references (\"SOURCES\").\n" 26 | + "If you don't know the answer, just say that you don't know. Don't try to make up an answer.\n" 27 | + "ALWAYS return a \"SOURCES\" part in your answer.\n\n" 28 | + "QUESTION: Which state/country's law governs the interpretation of the contract?\n" 29 | + "=========\n" 30 | + "Content: This Agreement is governed by English law and the parties submit to the exclusive jurisdiction of the English courts in relation to any dispute (contractual or non-contractual) concerning this Agreement save that either party may apply to any court for an injunction or other relief to protect its Intellectual Property Rights.\n" 31 | + "Source: 28-pl\n" 32 | + "Content: No Waiver. Failure or delay in exercising any right or remedy under this Agreement shall not constitute a waiver of such (or any other) right or remedy.\n\n11.7 Severability. The invalidity, illegality or unenforceability of any term (or part of a term) of this Agreement shall not affect the continuation in force of the remainder of the term (if any) and this Agreement.\n\n11.8 No Agency. 
Except as expressly stated otherwise, nothing in this Agreement shall create an agency, partnership or joint venture of any kind between the parties.\n\n11.9 No Third-Party Beneficiaries.\n" 33 | + "Source: 30-pl\n" 34 | + "Content: (b) if Google believes, in good faith, that the Distributor has violated or caused Google to violate any Anti-Bribery Laws (as defined in Clause 8.5) or that such a violation is reasonably likely to occur,\n" 35 | + "Source: 4-pl\n" 36 | + "=========\n" 37 | + "FINAL ANSWER: This Agreement is governed by English law.\n" 38 | + "SOURCES: 28-pl\n\n" 39 | + "QUESTION: What did the president say about Michael Jackson?\n" 40 | + "=========\n" 41 | + "Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland.\n" 42 | + "Source: 0-pl\n" 43 | + "Content: And we won’t stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet’s use this moment to reset. 
Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves.\n" 44 | + "Source: 24-pl\n" 45 | + "Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. 
\n\nBut I want you to know that we are going to be okay.\n" 46 | + "Source: 5-pl\n" 47 | + "Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. 
\n\nWell I know this nation.\n" 48 | + "Source: 34-pl\n" 49 | + "=========\n" 50 | + "FINAL ANSWER: The president did not mention Michael Jackson.\n" 51 | + "SOURCES:\n\n" 52 | + "QUESTION: ${%s}\n" 53 | + "=========\n" 54 | + "${%s}\n" 55 | + "=========\n" 56 | + "FINAL ANSWER:", 57 | PromptConstants.QUESTION, PromptConstants.CONTENT); 58 | 59 | /** 60 | * https://github.com/hwchase17/langchain/blob/master/langchain/chains/qa_with_sources/stuff_prompt.py#L41 61 | */ 62 | public static final String QA_DOCUMENT = String.format("Content: ${%s}\nSource: ${%s}", PromptConstants.CONTENT, PromptConstants.SOURCE); 63 | 64 | 65 | /** 66 | * Based on {@link #QA_COMBINE} instructs an LLM to create information snippets used for document comparison 67 | */ 68 | public static final String QA_COMPARE = String.format( 69 | "Given the following extracted parts of a long document and a question, create a final answer which provides information about the question or criterion in the extracted document parts.\n" 70 | + "Your answer will be used for document comparison, so do not leave out potentially interesting key points.\n" 71 | + "If you don't know the answer, just say that you don't know. Don't try to make up an answer.\n" 72 | + "=========\n" 73 | + "QUESTION: ${%s}\n" 74 | + "=========\n" 75 | + "${%s}", 76 | PromptConstants.QUESTION, PromptConstants.CONTENT); 77 | 78 | /** 79 | * Instructs an LLM to compare a set of documents 80 | */ 81 | public static final String COMPARE_MULTIPLE_DOCUMENTS = String.format( 82 | "${%s}\n" 83 | + "\n" 84 | + "Compare the provided content and sources based on the given question: ${%s}." 
85 | + "\n" 86 | + "Line out differences and similarities between the different sources.\n" 87 | + "Name the source names in your comparison.", 88 | PromptConstants.CONTENT, PromptConstants.QUESTION); 89 | 90 | private PromptTemplates() { 91 | // not instantiated 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/main/resources/log4j2.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/data/read/ReadDocumentsFromInMemoryPdfChainTest.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.read; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertEquals; 4 | import static org.junit.jupiter.api.Assertions.assertNotNull; 5 | 6 | import java.io.IOException; 7 | import java.net.URISyntaxException; 8 | import java.util.List; 9 | import java.util.Map; 10 | import java.util.stream.Collectors; 11 | 12 | import org.apache.pdfbox.io.IOUtils; 13 | import org.junit.jupiter.api.BeforeAll; 14 | import org.junit.jupiter.api.Test; 15 | 16 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromInMemoryPdfChain; 17 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromInMemoryPdfChain.InMemoryPdf; 18 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromPdfChainBase.PdfReadMode; 19 | import io.github.cupybara.javalangchains.util.PromptConstants; 20 | 21 | /** 22 | * Unit Tests for the {@link ReadDocumentsFromInMemoryPdfChain} 23 | */ 24 | class ReadDocumentsFromInMemoryPdfChainTest { 25 | 26 | private static InMemoryPdf inMemoryPdf; 27 | 28 | @BeforeAll 29 | static void setupBeforeAll() throws URISyntaxException, IOException { 30 |
inMemoryPdf = new InMemoryPdf( 31 | IOUtils.toByteArray( 32 | ReadDocumentsFromInMemoryPdfChainTest.class.getResourceAsStream("/pdf/qa/book-of-john-3.pdf")), 33 | "my-in-memory.pdf"); 34 | } 35 | 36 | @Test 37 | void testReadWhole() { 38 | final List> documents = new ReadDocumentsFromInMemoryPdfChain(PdfReadMode.WHOLE) 39 | .run(inMemoryPdf).collect(Collectors.toList()); 40 | assertEquals(1, documents.size(), "incorrect number of read documents"); 41 | 42 | final Map doc = documents.get(0); 43 | assertNotNull(doc.get(PromptConstants.CONTENT), "got no content for doc"); 44 | assertEquals("my-in-memory.pdf", doc.get(PromptConstants.SOURCE), "got wrong source for doc"); 45 | } 46 | 47 | @Test 48 | void testReadPages() { 49 | final List> documents = new ReadDocumentsFromInMemoryPdfChain(PdfReadMode.PAGES) 50 | .run(inMemoryPdf).collect(Collectors.toList()); 51 | assertEquals(2, documents.size(), "incorrect number of read document pages"); 52 | 53 | final Map doc1 = documents.get(0); 54 | assertNotNull(doc1.get(PromptConstants.CONTENT), "got no content for doc1"); 55 | assertEquals("my-in-memory.pdf p.1", doc1.get(PromptConstants.SOURCE), "got wrong source for doc1"); 56 | 57 | final Map doc2 = documents.get(1); 58 | assertNotNull(doc2.get(PromptConstants.CONTENT), "got no content for doc2"); 59 | assertEquals("my-in-memory.pdf p.2", doc2.get(PromptConstants.SOURCE), "got wrong source for doc2"); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/data/read/ReadDocumentsFromPdfChainTest.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.read; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertEquals; 4 | import static org.junit.jupiter.api.Assertions.assertFalse; 5 | import static org.junit.jupiter.api.Assertions.assertNotNull; 6 | 7 | import 
java.net.URISyntaxException; 8 | import java.nio.file.Path; 9 | import java.nio.file.Paths; 10 | import java.util.List; 11 | import java.util.Map; 12 | import java.util.stream.Collectors; 13 | 14 | import org.junit.jupiter.api.BeforeAll; 15 | import org.junit.jupiter.api.Test; 16 | 17 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromPdfChain; 18 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromPdfChainBase.PdfReadMode; 19 | import io.github.cupybara.javalangchains.util.PromptConstants; 20 | 21 | /** 22 | * Unit Tests for the {@link ReadDocumentsFromPdfChain} 23 | */ 24 | class ReadDocumentsFromPdfChainTest { 25 | 26 | private static Path pdfDirectory; 27 | 28 | @BeforeAll 29 | static void setupBeforeAll() throws URISyntaxException { 30 | pdfDirectory = Paths.get(ReadDocumentsFromPdfChainTest.class.getResource("/pdf/qa").toURI()); 31 | } 32 | 33 | @Test 34 | void testReadWhole() { 35 | final List> documents = new ReadDocumentsFromPdfChain(PdfReadMode.WHOLE).run(pdfDirectory) 36 | .collect(Collectors.toList()); 37 | assertEquals(3, documents.size(), "incorrect number of read documents"); 38 | 39 | final Map doc1 = documents.get(0); 40 | assertNotNull(doc1.get(PromptConstants.CONTENT), "got no content for doc1"); 41 | assertEquals("book-of-john-1.pdf", doc1.get(PromptConstants.SOURCE), "got wrong source for doc1"); 42 | 43 | final Map doc2 = documents.get(1); 44 | assertNotNull(doc2.get(PromptConstants.CONTENT), "got no content for doc2"); 45 | assertEquals("book-of-john-2.pdf", doc2.get(PromptConstants.SOURCE), "got wrong source for doc2"); 46 | 47 | final Map doc3 = documents.get(2); 48 | assertNotNull(doc3.get(PromptConstants.CONTENT), "got no content for doc3"); 49 | assertEquals("book-of-john-3.pdf", doc3.get(PromptConstants.SOURCE), "got wrong source for doc3"); 50 | } 51 | 52 | @Test 53 | void testReadPages() { 54 | final List> documents = new
ReadDocumentsFromPdfChain(PdfReadMode.PAGES).run(pdfDirectory) 55 | .collect(Collectors.toList()); 56 | assertEquals(4, documents.size(), "incorrect number of read document pages"); 57 | 58 | final Map doc1 = documents.get(0); 59 | assertNotNull(doc1.get(PromptConstants.CONTENT), "got no content for doc1"); 60 | assertFalse(doc1.get(PromptConstants.CONTENT).trim().isEmpty(), "got empty content for doc1"); 61 | assertEquals("book-of-john-1.pdf p.1", doc1.get(PromptConstants.SOURCE), "got wrong source for doc1"); 62 | 63 | final Map doc2 = documents.get(1); 64 | assertNotNull(doc2.get(PromptConstants.CONTENT), "got no content for doc2"); 65 | assertFalse(doc2.get(PromptConstants.CONTENT).trim().isEmpty(), "got empty content for doc2"); 66 | assertEquals("book-of-john-2.pdf p.1", doc2.get(PromptConstants.SOURCE), "got wrong source for doc2"); 67 | 68 | final Map doc3 = documents.get(2); 69 | assertNotNull(doc3.get(PromptConstants.CONTENT), "got no content for doc3"); 70 | assertFalse(doc3.get(PromptConstants.CONTENT).trim().isEmpty(), "got empty content for doc3"); 71 | assertEquals("book-of-john-3.pdf p.1", doc3.get(PromptConstants.SOURCE), "got wrong source for doc3"); 72 | 73 | final Map doc4 = documents.get(3); 74 | assertNotNull(doc4.get(PromptConstants.CONTENT), "got no content for doc4"); 75 | assertFalse(doc4.get(PromptConstants.CONTENT).trim().isEmpty(), "got empty content for doc4"); 76 | assertEquals("book-of-john-3.pdf p.2", doc4.get(PromptConstants.SOURCE), "got wrong source for doc4"); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/data/retrieval/DocumentTestUtil.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.retrieval; 2 | 3 | import java.util.Arrays; 4 | import java.util.List; 5 | 6 | public class DocumentTestUtil { 7 | // @formatter:off 8 | public 
static final String DOCUMENT_1 = 9 | "Subject: John Doe's Biography\n" 10 | + "Dear Reader,\n" 11 | + "I am delighted to present to you the biography of John Doe, a remarkable individual who has left an indelible mark on society. Born and raised in a small town, John displayed an insatiable curiosity and a thirst for knowledge from a young age. He excelled academically, earning scholarships that allowed him to attend prestigious universities and pursue his passion for scientific research.\n" 12 | + "Throughout his career, John made groundbreaking discoveries in the field of medicine, revolutionizing treatment options for previously incurable diseases. His relentless dedication and tireless efforts have saved countless lives and earned him numerous accolades, including the Nobel Prize in Medicine.\n" 13 | + "However, John's impact extends far beyond his professional accomplishments. He is known for his philanthropic endeavors, establishing charitable foundations that provide support and opportunities to underprivileged communities. John's compassion and commitment to social justice have inspired many to follow in his footsteps.\n" 14 | + "In his personal life, John is a devoted family man. He cherishes the time spent with his loving wife and children, always prioritizing their well-being amidst his demanding schedule. Despite his remarkable success, John remains humble and grounded, never forgetting his roots and always seeking ways to uplift those around him.\n" 15 | + "In conclusion, John Doe is not only a brilliant scientist and humanitarian but also a role model for future generations. 
His unwavering determination, kindness, and pursuit of excellence make him a true legend.\n" 16 | + "Sincerely,\n" 17 | + "Jane Doe"; 18 | 19 | public static final String DOCUMENT_2 = 20 | "Subject: Invitation to John Doe's Art Exhibition\n" 21 | + "Dear Art Enthusiast,\n" 22 | + "We are pleased to invite you to a remarkable art exhibition featuring the mesmerizing works of John Doe. Renowned for his unique style and ability to capture the essence of emotions on canvas, John has curated a collection that will leave you awe-struck.\n" 23 | + "Drawing inspiration from his diverse life experiences, John's art tells compelling stories and invites viewers to delve into the depths of their imagination. Each stroke of the brush reveals a glimpse into his creative mind, conveying a range of emotions that resonate with the observer.\n" 24 | + "The exhibition will be held at the prestigious XYZ Art Gallery on [date] at [time]. It promises to be an evening filled with artistic brilliance, where you will have the opportunity to meet John Doe in person and gain insights into his creative process. Light refreshments will be served, providing a delightful ambiance for engaging discussions with fellow art enthusiasts.\n" 25 | + "Kindly RSVP by [RSVP date] to ensure your attendance at this exclusive event. We look forward to your presence and sharing this unforgettable artistic journey with you.\n" 26 | + "Yours sincerely,\n" 27 | + "Jane Doe"; 28 | 29 | public static final String DOCUMENT_3 = 30 | "Subject: John Doe's Travel Memoir - Exploring the Unknown\n" 31 | + "Dear Adventurers,\n" 32 | + "Prepare to embark on an extraordinary journey as we delve into the captivating travel memoir of John Doe. Throughout his life, John has traversed the globe, seeking out the hidden gems and immersing himself in diverse cultures. 
His memoir is a testament to the transformative power of travel and the profound impact it can have on one's perspective.\n" 33 | + "From the bustling streets of Tokyo to the serene beaches of Bali, John's vivid descriptions transport readers to each destination, allowing them to experience the sights, sounds, and flavors firsthand. With a keen eye for detail and a genuine curiosity for the world, he uncovers the untold stories that lie beneath the surface, providing a fresh and unique perspective.\n" 34 | + "Through his encounters with locals, John unearths the beauty of human connection and the universal language of kindness. He shares anecdotes that will make you laugh, moments that will leave you in awe, and reflections that will inspire you to embark on your own adventures.\n" 35 | + "This travel memoir not only serves as a guide to off-the-beaten-path destinations but also as a reminder of the inherent beauty and diversity of our planet. It encourages readers to step out of their comfort zones, embrace new experiences, and celebrate the richness of different cultures.\n" 36 | + "Whether you are an avid traveler or an armchair explorer, John Doe's memoir is a captivating read that will ignite your wanderlust and leave you yearning for new horizons. 
Join him on this literary expedition and discover the world through his eyes.\n" 37 | + "Happy reading,\n" 38 | + "Jane Doe"; 39 | 40 | public static final List DOCUMENTS = Arrays.asList(DocumentTestUtil.DOCUMENT_1, DocumentTestUtil.DOCUMENT_2, DocumentTestUtil.DOCUMENT_3); 41 | // @formatter:on 42 | 43 | private DocumentTestUtil() { 44 | 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/data/retrieval/ElasticsearchRetrievalChainIT.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.retrieval; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertEquals; 4 | import static org.junit.jupiter.api.Assertions.assertTrue; 5 | 6 | import java.io.IOException; 7 | import java.net.URISyntaxException; 8 | import java.nio.file.Path; 9 | import java.nio.file.Paths; 10 | import java.util.List; 11 | import java.util.Map; 12 | import java.util.stream.Collectors; 13 | 14 | import org.apache.http.HttpHost; 15 | import org.elasticsearch.client.RestClient; 16 | import org.elasticsearch.client.RestClientBuilder; 17 | import org.junit.jupiter.api.BeforeAll; 18 | import org.junit.jupiter.api.Disabled; 19 | import org.junit.jupiter.api.Test; 20 | 21 | import io.github.cupybara.javalangchains.chains.Chain; 22 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromPdfChain; 23 | import io.github.cupybara.javalangchains.chains.data.writer.WriteDocumentsToElasticsearchIndexChain; 24 | import io.github.cupybara.javalangchains.util.PromptConstants; 25 | 26 | /** 27 | * Tests for the {@link ElasticsearchRetrievalChain} 28 | */ 29 | @Disabled 30 | class ElasticsearchRetrievalChainIT { 31 | 32 | private static RestClientBuilder restClientBuilder; 33 | 34 | @BeforeAll 35 | static void beforeAll() throws URISyntaxException { 36 | restClientBuilder = RestClient.builder(new 
HttpHost("localhost", 9200)); 37 | 38 | final Chain createElasticsearchIndexChain = new ReadDocumentsFromPdfChain() 39 | .chain(new WriteDocumentsToElasticsearchIndexChain("my-index", restClientBuilder)); 40 | 41 | final Path pdfDirectoryPath = Paths.get(ElasticsearchRetrievalChainIT.class.getResource("/pdf/qa").toURI()); 42 | 43 | // creates and fills the elasticsearch index "my-index" 44 | createElasticsearchIndexChain.run(pdfDirectoryPath); 45 | } 46 | 47 | @Test 48 | void testRun() throws IOException { 49 | try (final RestClient restClient = restClientBuilder.build(); 50 | final ElasticsearchRetrievalChain retrievalChain = new ElasticsearchRetrievalChain("my-index", 51 | restClient, 1)) { 52 | 53 | final List> retrievedDocuments = retrievalChain.run("who is john doe?") 54 | .collect(Collectors.toList()); 55 | assertEquals(1, retrievedDocuments.size(), "incorrect number of retrieved documents"); 56 | 57 | final Map document = retrievedDocuments.get(0); 58 | 59 | assertTrue(document.containsKey(PromptConstants.QUESTION), "document does not contain question key"); 60 | assertTrue(document.containsKey(PromptConstants.SOURCE), "document does not contain source key"); 61 | assertTrue(document.containsKey(PromptConstants.CONTENT), "document does not contain content key"); 62 | 63 | assertEquals("who is john doe?", document.get(PromptConstants.QUESTION), "wrong question in document"); 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/data/retrieval/JdbcRetrievalChainIT.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.retrieval; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertEquals; 4 | 5 | import java.sql.Connection; 6 | import java.sql.DriverManager; 7 | import java.sql.PreparedStatement; 8 | import java.sql.ResultSet; 9 | import 
java.sql.SQLException; 10 | import java.sql.Statement; 11 | import java.util.List; 12 | import java.util.Map; 13 | import java.util.function.Supplier; 14 | import java.util.stream.Collectors; 15 | 16 | import org.junit.jupiter.api.BeforeAll; 17 | import org.junit.jupiter.api.Disabled; 18 | import org.junit.jupiter.api.Test; 19 | 20 | import io.github.cupybara.javalangchains.util.PromptConstants; 21 | 22 | @Disabled 23 | public class JdbcRetrievalChainIT { 24 | private static Supplier connectionSupplier; 25 | 26 | @BeforeAll 27 | static void setup() throws SQLException { 28 | final String connectionString = "jdbc:postgresql://localhost:5432/"; 29 | final String username = "postgres"; 30 | final String password = "admin"; 31 | 32 | Connection connection = DriverManager.getConnection(connectionString, username, password); 33 | 34 | Statement setupStatement = connection.createStatement(); 35 | 36 | ResultSet dbResult = setupStatement 37 | .executeQuery("SELECT datname FROM pg_catalog.pg_database WHERE datname='langchain_test'"); 38 | if (dbResult.next()) { 39 | setupStatement.execute("DROP DATABASE langchain_test"); 40 | } 41 | 42 | setupStatement.execute("CREATE DATABASE langchain_test"); 43 | 44 | setupStatement.close(); 45 | 46 | connection.setCatalog("langchain_test"); 47 | 48 | Statement createTableStatement = connection.createStatement(); 49 | if (connection.getMetaData().getTables("langchain_test", null, null, new String[] { "TABLE" }).next()) { 50 | createTableStatement.execute("DROP TABLE Documents"); 51 | } 52 | createTableStatement.execute( 53 | "CREATE TABLE Documents (source VARCHAR PRIMARY KEY, content VARCHAR, additional_attribute INTEGER)"); 54 | createTableStatement.close(); 55 | for (int i = 0; i < DocumentTestUtil.DOCUMENTS.size(); i++) { 56 | String content = DocumentTestUtil.DOCUMENTS.get(i); 57 | PreparedStatement seedStatement = connection 58 | .prepareStatement("INSERT INTO Documents(source, content, additional_attribute) VALUES (?, ?, 1)"); 59 
| seedStatement.setString(1, Integer.toString(i)); 60 | seedStatement.setString(2, content); 61 | seedStatement.execute(); 62 | seedStatement.close(); 63 | } 64 | 65 | connectionSupplier = () -> connection; 66 | } 67 | 68 | @Test 69 | void testRun() throws SQLException { 70 | JdbcRetrievalChain jdbcRetrievalChain = new JdbcRetrievalChain(connectionSupplier, 1); 71 | 72 | final List> retrievedDocuments = jdbcRetrievalChain.run("who is john doe?") 73 | .collect(Collectors.toList()); 74 | assertEquals(1, retrievedDocuments.size(), "incorrect number of retrieved documents"); 75 | 76 | Map document = retrievedDocuments.get(0); 77 | assertEquals("0", document.get("source")); 78 | assertEquals("1", document.get("additional_attribute")); 79 | assertEquals(DocumentTestUtil.DOCUMENT_1, document.get(PromptConstants.CONTENT)); 80 | assertEquals("who is john doe?", document.get(PromptConstants.QUESTION)); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/data/retrieval/LuceneRetrievalChainTest.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.retrieval; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertEquals; 4 | import static org.junit.jupiter.api.Assertions.assertFalse; 5 | import static org.junit.jupiter.api.Assertions.assertTrue; 6 | 7 | import java.io.File; 8 | import java.io.IOException; 9 | import java.nio.file.Files; 10 | import java.nio.file.Path; 11 | import java.util.Comparator; 12 | import java.util.List; 13 | import java.util.Map; 14 | import java.util.stream.Collectors; 15 | 16 | import org.apache.lucene.analysis.standard.StandardAnalyzer; 17 | import org.apache.lucene.document.Document; 18 | import org.apache.lucene.document.Field; 19 | import org.apache.lucene.document.StringField; 20 | import org.apache.lucene.document.TextField; 21 | import 
org.apache.lucene.index.IndexWriter; 22 | import org.apache.lucene.index.IndexWriterConfig; 23 | import org.apache.lucene.store.Directory; 24 | import org.apache.lucene.store.MMapDirectory; 25 | import org.junit.jupiter.api.AfterAll; 26 | import org.junit.jupiter.api.BeforeAll; 27 | import org.junit.jupiter.api.Test; 28 | 29 | import io.github.cupybara.javalangchains.util.PromptConstants; 30 | 31 | /** 32 | * Tests for the {@link LuceneRetrievalChain} 33 | */ 34 | class LuceneRetrievalChainTest { 35 | 36 | private static Path tempDirPath; 37 | private static Directory directory; 38 | 39 | @BeforeAll 40 | static void beforeAll() throws IOException { 41 | tempDirPath = Files.createTempDirectory("lucene"); 42 | directory = new MMapDirectory(tempDirPath); 43 | fillDirectory(directory); 44 | } 45 | 46 | static void fillDirectory(final Directory indexDirectory) throws IOException { 47 | final StandardAnalyzer analyzer = new StandardAnalyzer(); 48 | final IndexWriterConfig config = new IndexWriterConfig(analyzer); 49 | try (final IndexWriter indexWriter = new IndexWriter(indexDirectory, config)) { 50 | for (final String content : DocumentTestUtil.DOCUMENTS) { 51 | final Document doc = new Document(); 52 | doc.add(new TextField(PromptConstants.CONTENT, content, Field.Store.YES)); 53 | doc.add(new StringField(PromptConstants.SOURCE, 54 | String.valueOf(DocumentTestUtil.DOCUMENTS.indexOf(content) + 1), Field.Store.YES)); 55 | indexWriter.addDocument(doc); 56 | } 57 | 58 | indexWriter.commit(); 59 | } 60 | } 61 | 62 | @AfterAll 63 | static void afterAll() throws IOException { 64 | directory.close(); 65 | /* Files.walk keeps directory handles open until the stream is closed, so close it explicitly; reverse order deletes children before their parent directory */ try (final java.util.stream.Stream<Path> paths = Files.walk(tempDirPath)) { paths.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete); } 66 | } 67 | 68 | @Test 69 | void testRun() throws IOException { 70 | try (final LuceneRetrievalChain retrievalChain = new LuceneRetrievalChain(directory, 2)) { 71 | final String question = "what kind of art does john make?"; 72 | 73 | final List> documents =
retrievalChain.run(question).collect(Collectors.toList()); 74 | assertFalse(documents.isEmpty(), "no documents retrieved"); 75 | 76 | final Map mostRelevantDocument = documents.get(0); 77 | assertTrue(mostRelevantDocument.containsKey(PromptConstants.SOURCE), "source key is missing"); 78 | assertEquals("2", mostRelevantDocument.get(PromptConstants.SOURCE), "invalid source"); 79 | 80 | assertTrue(mostRelevantDocument.containsKey(PromptConstants.CONTENT), "content key is missing"); 81 | assertEquals(DocumentTestUtil.DOCUMENT_2, mostRelevantDocument.get(PromptConstants.CONTENT), 82 | "invalid content"); 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/data/writer/WriteDocumentsToElasticsearchIndexChainIT.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.data.writer; 2 | 3 | import java.net.URISyntaxException; 4 | import java.nio.file.Path; 5 | import java.nio.file.Paths; 6 | import java.util.Base64; 7 | 8 | import org.apache.http.HttpHost; 9 | import org.elasticsearch.client.RestClient; 10 | import org.junit.jupiter.api.Test; 11 | 12 | import com.fasterxml.jackson.databind.ObjectMapper; 13 | 14 | import io.github.cupybara.javalangchains.chains.Chain; 15 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromPdfChain; 16 | import io.github.cupybara.javalangchains.util.PromptConstants; 17 | 18 | /** 19 | * tests for the {@link WriteDocumentsToElasticsearchIndexChain} 20 | */ 21 | class WriteDocumentsToElasticsearchIndexChainIT { 22 | 23 | @Test 24 | void testRun() throws URISyntaxException { 25 | Chain fillElasticsearchIndexChain = new ReadDocumentsFromPdfChain() 26 | .chain(new WriteDocumentsToElasticsearchIndexChain("my-index")); 27 | 28 | Path pdfDirectoryPath = Paths 29 | 
.get(WriteDocumentsToElasticsearchIndexChainIT.class.getResource("/pdf/qa").toURI()); 30 | 31 | fillElasticsearchIndexChain.run(pdfDirectoryPath); 32 | } 33 | 34 | @Test 35 | void testRunSpecificId() throws URISyntaxException { 36 | WriteDocumentsToElasticsearchIndexChain writeChain = new WriteDocumentsToElasticsearchIndexChain("my-index", 37 | RestClient.builder(new HttpHost("localhost", 9200)), new ObjectMapper(), 38 | doc -> Base64.getEncoder().encodeToString(doc.get(PromptConstants.SOURCE).getBytes())); 39 | 40 | Chain fillElasticsearchIndexChain = new ReadDocumentsFromPdfChain().chain(writeChain); 41 | 42 | Path pdfDirectoryPath = Paths 43 | .get(WriteDocumentsToElasticsearchIndexChainIT.class.getResource("/pdf/qa").toURI()); 44 | 45 | fillElasticsearchIndexChain.run(pdfDirectoryPath); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/llm/azure/chat/AzureOpenAiChatCompletionsChainIT.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.azure.chat; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertNotNull; 4 | import static org.junit.jupiter.api.Assertions.assertTrue; 5 | 6 | import java.util.Collections; 7 | import java.util.Map; 8 | 9 | import org.apache.logging.log4j.LogManager; 10 | import org.apache.logging.log4j.Logger; 11 | import org.junit.jupiter.api.Test; 12 | 13 | import io.github.cupybara.javalangchains.chains.Chain; 14 | import io.github.cupybara.javalangchains.chains.llm.openai.chat.OpenAiChatCompletionsParameters; 15 | 16 | /** 17 | * Integration Tests for the {@link AzureOpenAiChatCompletionsChain} 18 | */ 19 | class AzureOpenAiChatCompletionsChainIT { 20 | 21 | private static final Logger LOGGER = LogManager.getLogger(); 22 | 23 | @Test 24 | void testRun() { 25 | final AzureOpenAiChatCompletionsChain chain = new 
AzureOpenAiChatCompletionsChain( 26 | "my-azure-resource-name", "gpt-35-turbo", "2023-05-15", 27 | "Hello, this is ${name}. What was my name again?", 28 | new OpenAiChatCompletionsParameters(), System.getenv("AZURE_OPENAI_API_KEY")); 29 | 30 | final String name = "Manuel"; 31 | final String result = chain.run(Collections.singletonMap("name", name)); 32 | LOGGER.info(result); 33 | 34 | /* failure message names the chain actually under test */ assertNotNull(result, "got no result from AzureOpenAiChatCompletionsChain"); 35 | assertTrue(result.contains(name), "The answer did not contain the name"); 36 | } 37 | 38 | @Test 39 | void testChainedRun() { 40 | final OpenAiChatCompletionsParameters parameters = new OpenAiChatCompletionsParameters(); 41 | 42 | final Chain, String> chain = new AzureOpenAiChatCompletionsChain( 43 | "my-azure-resource-name", "gpt-35-turbo", "2023-05-15", 44 | "Hello, this is ${name}. What was my name again?", 45 | parameters, System.getenv("AZURE_OPENAI_API_KEY")) 46 | .chain(prev -> Collections.singletonMap("result", prev)) 47 | .chain(new AzureOpenAiChatCompletionsChain( 48 | "my-azure-resource-name", "gpt-35-turbo", "2023-05-15", 49 | "What was the question for the following answer: ${result}", 50 | parameters, System.getenv("AZURE_OPENAI_API_KEY"))); 51 | 52 | final String result = chain.run(Collections.singletonMap("name", "Manuel")); 53 | LOGGER.info(result); 54 | 55 | assertNotNull(result, "got no result from chain"); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/llm/openai/chat/OpenAiChatCompletionsChainIT.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.openai.chat; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertNotNull; 4 | import static org.junit.jupiter.api.Assertions.assertTrue; 5 | 6 | import java.util.Collections; 7 | import java.util.Map; 8 | 9 | import
org.apache.logging.log4j.LogManager; 10 | import org.apache.logging.log4j.Logger; 11 | import org.junit.jupiter.api.Test; 12 | 13 | import io.github.cupybara.javalangchains.chains.Chain; 14 | 15 | /** 16 | * Integration Tests for the {@link OpenAiChatCompletionsChain} 17 | */ 18 | class OpenAiChatCompletionsChainIT { 19 | 20 | private static final Logger LOGGER = LogManager.getLogger(); 21 | 22 | @Test 23 | void testRun() { 24 | final OpenAiChatCompletionsParameters parameters = new OpenAiChatCompletionsParameters(); 25 | parameters.setModel("gpt-3.5-turbo"); 26 | parameters.setTemperature(0D); 27 | 28 | final OpenAiChatCompletionsChain chain = new OpenAiChatCompletionsChain( 29 | "Hello, this is ${name}. What was my name again?", parameters, System.getenv("OPENAI_API_KEY"), 30 | "You are a helpful assistant who answers questions to ${name}"); 31 | 32 | final String name = "Manuel"; 33 | final String result = chain.run(Collections.singletonMap("name", name)); 34 | LOGGER.info(result); 35 | 36 | assertNotNull(result, "got no result from OpenAiChatCompletionsChain"); 37 | assertTrue(result.contains(name), "The answer did not contain the name"); 38 | } 39 | 40 | @Test 41 | void testChainedRun() { 42 | final OpenAiChatCompletionsParameters parameters = new OpenAiChatCompletionsParameters(); 43 | parameters.setModel("gpt-3.5-turbo"); 44 | 45 | final Chain, String> chain = new OpenAiChatCompletionsChain( 46 | "Hello, this is ${name}. 
What is your name?", parameters, System.getenv("OPENAI_API_KEY")) 47 | .chain(prev -> Collections.singletonMap("result", prev)) 48 | .chain(new OpenAiChatCompletionsChain("What was the question for the following answer: ${result}", 49 | parameters, System.getenv("OPENAI_API_KEY"))); 50 | 51 | final String result = chain.run(Collections.singletonMap("name", "Manuel")); 52 | LOGGER.info(result); 53 | 54 | assertNotNull(result, "got no result from chain"); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/llm/openai/completions/OpenAiCompletionsChainIT.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.llm.openai.completions; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertNotNull; 4 | import static org.junit.jupiter.api.Assertions.assertTrue; 5 | 6 | import java.util.Collections; 7 | import java.util.Map; 8 | 9 | import org.apache.logging.log4j.LogManager; 10 | import org.apache.logging.log4j.Logger; 11 | import org.junit.jupiter.api.Test; 12 | 13 | import io.github.cupybara.javalangchains.chains.Chain; 14 | 15 | /** 16 | * Integration Tests for the {@link OpenAiCompletionsChain} 17 | */ 18 | class OpenAiCompletionsChainIT { 19 | 20 | private static final Logger LOGGER = LogManager.getLogger(); 21 | 22 | @Test 23 | void testRun() { 24 | final OpenAiCompletionsParameters parameters = new OpenAiCompletionsParameters(); 25 | parameters.setModel("text-davinci-003"); 26 | 27 | final OpenAiCompletionsChain chain = new OpenAiCompletionsChain( 28 | "Hello, this is ${name}. 
What was my name again?", parameters, System.getenv("OPENAI_API_KEY")); 29 | 30 | final String name = "Manuel"; 31 | final String result = chain.run(Collections.singletonMap("name", name)); 32 | LOGGER.info(result); 33 | 34 | assertNotNull(result, "got no result from OpenAiCompletionsChain"); 35 | assertTrue(result.contains(name), "The answer did not contain the name"); 36 | } 37 | 38 | @Test 39 | void testChainedRun() { 40 | final OpenAiCompletionsParameters parameters = new OpenAiCompletionsParameters(); 41 | parameters.setModel("text-davinci-003"); 42 | 43 | final Chain, String> chain = new OpenAiCompletionsChain( 44 | "Hello, this is ${name}. What is your name?", parameters, System.getenv("OPENAI_API_KEY")) 45 | .chain(prev -> Collections.singletonMap("result", prev)) 46 | .chain(new OpenAiCompletionsChain("What was the question for the following answer: ${result}", 47 | parameters, System.getenv("OPENAI_API_KEY"))); 48 | 49 | final String result = chain.run(Collections.singletonMap("name", "Manuel")); 50 | LOGGER.info(result); 51 | 52 | assertNotNull(result, "got no result from chain"); 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/qa/MapAnswerWithSourcesChainTest.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertEquals; 4 | import static org.junit.jupiter.api.Assertions.assertNotNull; 5 | import static org.junit.jupiter.api.Assertions.assertTrue; 6 | 7 | import org.junit.jupiter.api.Test; 8 | 9 | /** 10 | * unit tests for the class {@link MapAnswerWithSourcesChain} 11 | */ 12 | class MapAnswerWithSourcesChainTest { 13 | 14 | @Test 15 | void testRunSources() { 16 | final AnswerWithSources answerWithSources = new MapAnswerWithSourcesChain() 17 | .run("This is my test content. 
Sources: source-1, source-2"); 18 | assertNotNull(answerWithSources, "got no answer"); 19 | assertEquals("This is my test content.", answerWithSources.getAnswer(), "wrong answer"); 20 | assertEquals(2, answerWithSources.getSources().size(), "wrong sources count"); 21 | } 22 | 23 | @Test 24 | void testRunSource() { 25 | final AnswerWithSources answerWithSources = new MapAnswerWithSourcesChain() 26 | .run("This is my test content. Source: source-1"); 27 | assertNotNull(answerWithSources, "got no answer"); 28 | assertEquals("This is my test content.", answerWithSources.getAnswer(), "wrong answer"); 29 | assertEquals(1, answerWithSources.getSources().size(), "wrong sources count"); 30 | } 31 | 32 | @Test 33 | void testRunNoSource() { 34 | final AnswerWithSources answerWithSources = new MapAnswerWithSourcesChain().run("This is my test content."); 35 | assertNotNull(answerWithSources, "got no answer"); 36 | assertEquals("This is my test content.", answerWithSources.getAnswer(), "wrong answer"); 37 | assertTrue(answerWithSources.getSources().isEmpty(), "got sources bot there are none"); 38 | } 39 | 40 | @Test 41 | void testRunSourcesMultiline() { 42 | final AnswerWithSources answerWithSources = new MapAnswerWithSourcesChain() 43 | .run("This is my test content.\nSOURCES: source-1, source-2"); 44 | assertNotNull(answerWithSources, "got no answer"); 45 | assertEquals("This is my test content.", answerWithSources.getAnswer(), "wrong answer"); 46 | assertEquals(2, answerWithSources.getSources().size(), "wrong sources count"); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/qa/split/SplitDocumentsChainTest.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa.split; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertEquals; 4 | import static 
org.junit.jupiter.api.Assertions.assertNotNull; 5 | 6 | import java.util.LinkedHashMap; 7 | import java.util.LinkedList; 8 | import java.util.List; 9 | import java.util.Map; 10 | import java.util.stream.Collectors; 11 | 12 | import org.junit.jupiter.api.BeforeAll; 13 | import org.junit.jupiter.api.Test; 14 | 15 | import com.knuddels.jtokkit.Encodings; 16 | import com.knuddels.jtokkit.api.EncodingType; 17 | 18 | import io.github.cupybara.javalangchains.util.PromptConstants; 19 | 20 | /** 21 | * Unit tests for the {@link SplitDocumentsChain} 22 | */ 23 | class SplitDocumentsChainTest { 24 | 25 | private static List> documents; 26 | 27 | @BeforeAll 28 | static void beforeAll() { 29 | documents = new LinkedList<>(); 30 | 31 | final Map firstDocument = new LinkedHashMap<>(); 32 | firstDocument.put(PromptConstants.SOURCE, "book of john"); 33 | firstDocument.put(PromptConstants.CONTENT, "This is a short text. This is another short text."); 34 | documents.add(firstDocument); 35 | 36 | final Map secondDocument = new LinkedHashMap<>(); 37 | secondDocument.put(PromptConstants.SOURCE, "book of jane"); 38 | secondDocument.put(PromptConstants.CONTENT, "This is a short text."); 39 | documents.add(secondDocument); 40 | } 41 | 42 | /** 43 | * tests the {@link SplitDocumentsChain} using a {@link JtokkitTextSplitter} 44 | */ 45 | @Test 46 | void testSplitDocumentsByTokenCount() { 47 | 48 | /* 49 | * We create a TextSplitter that splits a text into partitions using a JTokkit 50 | * Encoding. We use the cl100k_base encoding which is the default for 51 | * gpt-3.5-turbo. 52 | */ 53 | final TextSplitter tiktokenTextSplitter = new JtokkitTextSplitter( 54 | Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE), 10); 55 | 56 | /* 57 | * we now instantiate the SplitDocumentsChain which will split our documents 58 | * using the above created TextSplitter on the "content" field. 
59 | */ 60 | final SplitDocumentsChain splitDocumentsChain = new SplitDocumentsChain(tiktokenTextSplitter); 61 | 62 | // perform the split and collect the items 63 | final List> splitDocuments = splitDocumentsChain.run(documents.stream()) 64 | .collect(Collectors.toList()); 65 | 66 | System.out.println(splitDocuments); 67 | assertNotNull(splitDocuments, "null result"); 68 | assertEquals(3, splitDocuments.size(), "wrong result size"); 69 | 70 | final Map firstDocument = splitDocuments.get(0); 71 | assertEquals("This is a short text. ", firstDocument.get(PromptConstants.CONTENT), 72 | "wrong first chunk of split document"); 73 | assertEquals("book of john", firstDocument.get(PromptConstants.SOURCE), "wrong source for firstDocument"); 74 | 75 | final Map secondDocument = splitDocuments.get(1); 76 | assertEquals("This is another short text.", secondDocument.get(PromptConstants.CONTENT), 77 | "wrong second chunk of split document"); 78 | assertEquals("book of john", secondDocument.get(PromptConstants.SOURCE), "wrong source for secondDocument"); 79 | 80 | final Map thirdDocument = splitDocuments.get(2); 81 | assertEquals("This is a short text.", thirdDocument.get(PromptConstants.CONTENT), 82 | "wrong content for thirdDocument"); 83 | assertEquals("book of jane", thirdDocument.get(PromptConstants.SOURCE), "wrong source for thirdDocument"); 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/chains/qa/split/TextStreamerTest.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.chains.qa.split; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertEquals; 4 | import static org.junit.jupiter.api.Assertions.assertNotNull; 5 | 6 | import java.util.List; 7 | import java.util.stream.Collectors; 8 | 9 | import org.junit.jupiter.api.Test; 10 | 11 | /** 12 | * Unit Tests for the {@link TextStreamer} 13 
| */ 14 | class TextStreamerTest { 15 | 16 | private static final String TEXT_TO_SPLIT = "Hi there. This is an example text\nused for unit testing."; 17 | 18 | /** 19 | * Tests the no args constructor version of the {@link TextStreamer} 20 | */ 21 | @Test 22 | void testStreamSentences() { 23 | final List split = new TextStreamer() // no args constructor => sentences 24 | .stream(TEXT_TO_SPLIT).collect(Collectors.toList()); 25 | assertNotNull(split, "got null result"); 26 | assertEquals(2, split.size(), "wrong result count (2 sentences)"); 27 | assertEquals("Hi there. ", split.get(0), "first sentence is wrong"); 28 | assertEquals("This is an example text\nused for unit testing.", split.get(1), "second sentence is wrong"); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/usecases/DocumentComparisonTest.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.usecases; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertNotNull; 4 | 5 | import java.io.File; 6 | import java.io.IOException; 7 | import java.net.URISyntaxException; 8 | import java.nio.file.Files; 9 | import java.nio.file.Path; 10 | import java.nio.file.Paths; 11 | import java.util.Comparator; 12 | import java.util.LinkedHashMap; 13 | import java.util.LinkedHashSet; 14 | import java.util.LinkedList; 15 | import java.util.List; 16 | import java.util.Map; 17 | import java.util.Set; 18 | 19 | import org.apache.lucene.analysis.standard.StandardAnalyzer; 20 | import org.apache.lucene.queryparser.classic.ParseException; 21 | import org.apache.lucene.queryparser.classic.QueryParser; 22 | import org.apache.lucene.search.BooleanClause; 23 | import org.apache.lucene.search.BooleanQuery; 24 | import org.apache.lucene.search.Query; 25 | import org.apache.lucene.store.Directory; 26 | import org.junit.jupiter.api.AfterAll; 27 | import 
org.junit.jupiter.api.BeforeAll; 28 | import org.junit.jupiter.api.Test; 29 | 30 | import com.knuddels.jtokkit.Encodings; 31 | import com.knuddels.jtokkit.api.EncodingType; 32 | 33 | import io.github.cupybara.javalangchains.chains.Chain; 34 | import io.github.cupybara.javalangchains.chains.base.JoinChain; 35 | import io.github.cupybara.javalangchains.chains.base.StreamUnwrappingChain; 36 | import io.github.cupybara.javalangchains.chains.base.StreamWrappingChain; 37 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromPdfChain; 38 | import io.github.cupybara.javalangchains.chains.data.retrieval.LuceneRetrievalChain; 39 | import io.github.cupybara.javalangchains.chains.data.writer.WriteDocumentsToLuceneDirectoryChain; 40 | import io.github.cupybara.javalangchains.chains.llm.openai.chat.OpenAiChatCompletionsChain; 41 | import io.github.cupybara.javalangchains.chains.llm.openai.chat.OpenAiChatCompletionsParameters; 42 | import io.github.cupybara.javalangchains.chains.qa.CombineDocumentsChain; 43 | import io.github.cupybara.javalangchains.chains.qa.ModifyDocumentsContentChain; 44 | import io.github.cupybara.javalangchains.chains.qa.split.JtokkitTextSplitter; 45 | import io.github.cupybara.javalangchains.chains.qa.split.SplitDocumentsChain; 46 | import io.github.cupybara.javalangchains.util.PromptConstants; 47 | import io.github.cupybara.javalangchains.util.PromptTemplates; 48 | 49 | /** 50 | * tests for a complete document comparing {@link Chain} 51 | * 52 | * we'll read two insurance policies and compare them by querying 53 | */ 54 | class DocumentComparisonTest { 55 | 56 | private static Path tempIndexPath; 57 | private static Directory directory; 58 | private static Set pdfSources; 59 | 60 | @BeforeAll 61 | static void beforeAll() throws IOException, URISyntaxException { 62 | tempIndexPath = Files.createTempDirectory("lucene"); 63 | pdfSources = new LinkedHashSet<>(); 64 | 65 | /* 66 | * We are also using a chain to create the lucene index 
directory 67 | */ 68 | final Chain createLuceneIndexChain = new ReadDocumentsFromPdfChain() 69 | // utility chain for storing all different pdf sources (multi compare) 70 | .chain(readDocuments -> readDocuments.map(doc -> { 71 | pdfSources.add(doc.get(PromptConstants.SOURCE)); 72 | return doc; 73 | })) 74 | // Optional Chain: split pdfs based on a max token count of 500 75 | .chain(new SplitDocumentsChain(new JtokkitTextSplitter( 76 | Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE), 500))) 77 | // Mandatory Chain: write split pdf documents to a lucene directory 78 | .chain(new WriteDocumentsToLuceneDirectoryChain(tempIndexPath)); 79 | 80 | final Path pdfDirectoryPath = Paths.get(RetrievalQaTest.class.getResource("/pdf/comparison").toURI()); 81 | 82 | directory = createLuceneIndexChain.run(pdfDirectoryPath); 83 | } 84 | 85 | @AfterAll 86 | static void afterAll() throws IOException { 87 | directory.close(); 88 | Files.walk(tempIndexPath).sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete); 89 | } 90 | 91 | @Test 92 | void testDocumentComparison() throws IOException { 93 | final List>> retrievalChains = new LinkedList<>(); 94 | for (final String pdfSource : pdfSources) { 95 | retrievalChains.add(createRetrievalChain(pdfSource)); 96 | } 97 | 98 | final Chain documentComparisonChain = new JoinChain<>(true /* parallel */, retrievalChains) 99 | .chain(new CombineDocumentsChain(PromptTemplates.QA_DOCUMENT)) 100 | .chain(new OpenAiChatCompletionsChain(PromptTemplates.COMPARE_MULTIPLE_DOCUMENTS, 101 | new OpenAiChatCompletionsParameters().temperature(0.8).model("gpt-3.5-turbo"), 102 | System.getenv("OPENAI_API_KEY"))); 103 | 104 | final String result = documentComparisonChain.run("how are claims processed?"); 105 | assertNotNull(result, "got null result"); 106 | 107 | System.out.println(result); 108 | } 109 | 110 | private Chain> createRetrievalChain(final String pdfSource) { 111 | final LuceneRetrievalChain retrievalChain
= new LuceneRetrievalChain(directory, 4, 112 | content -> createQuery(pdfSource, content)); 113 | 114 | final ModifyDocumentsContentChain summarizeDocumentsChain = new ModifyDocumentsContentChain( 115 | new OpenAiChatCompletionsChain(PromptTemplates.QA_SUMMARIZE, 116 | new OpenAiChatCompletionsParameters().temperature(0.8).model("gpt-3.5-turbo"), 117 | System.getenv("OPENAI_API_KEY"))); 118 | 119 | final CombineDocumentsChain combineDocumentsChain = new CombineDocumentsChain(); 120 | 121 | final ModifyDocumentsContentChain compareChain = new ModifyDocumentsContentChain( 122 | new OpenAiChatCompletionsChain(PromptTemplates.QA_COMPARE, 123 | new OpenAiChatCompletionsParameters().temperature(0.8).model("gpt-3.5-turbo"), 124 | System.getenv("OPENAI_API_KEY"))); 125 | 126 | // @formatter:off 127 | return retrievalChain 128 | .chain(summarizeDocumentsChain) 129 | .chain(combineDocumentsChain) 130 | .chain(new StreamWrappingChain<>()) 131 | .chain(compareChain) 132 | .chain(new StreamUnwrappingChain<>()) 133 | .chain(llmOutput -> { 134 | final Map document = new LinkedHashMap<>(llmOutput); 135 | document.put(PromptConstants.SOURCE, pdfSource); 136 | return document; 137 | }); 138 | // @formatter:on 139 | } 140 | 141 | private Query createQuery(final String source, final String searchTerm) { 142 | final StandardAnalyzer analyzer = new StandardAnalyzer(); 143 | 144 | final QueryParser contentQueryParser = new QueryParser(PromptConstants.CONTENT, analyzer); 145 | final QueryParser sourceQueryParser = new QueryParser(PromptConstants.SOURCE, analyzer); 146 | 147 | try { 148 | final Query contentQuery = contentQueryParser.parse(searchTerm); 149 | final Query sourceQuery = sourceQueryParser.parse(source); 150 | 151 | // @formatter:off 152 | return new BooleanQuery.Builder() 153 | .add(sourceQuery, BooleanClause.Occur.FILTER) 154 | .add(contentQuery, BooleanClause.Occur.MUST) 155 | .build(); 156 | // @formatter:on 157 | } catch (final ParseException parseException) { 158 | throw 
new IllegalStateException("error creating query", parseException); 159 | } 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /src/test/java/io/github/cupybara/javalangchains/usecases/RetrievalQaTest.java: -------------------------------------------------------------------------------- 1 | package io.github.cupybara.javalangchains.usecases; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertFalse; 4 | import static org.junit.jupiter.api.Assertions.assertNotNull; 5 | 6 | import java.io.File; 7 | import java.io.IOException; 8 | import java.net.URISyntaxException; 9 | import java.nio.file.Files; 10 | import java.nio.file.Path; 11 | import java.nio.file.Paths; 12 | import java.util.Comparator; 13 | 14 | import org.apache.logging.log4j.LogManager; 15 | import org.apache.lucene.store.Directory; 16 | import org.junit.jupiter.api.AfterAll; 17 | import org.junit.jupiter.api.BeforeAll; 18 | import org.junit.jupiter.api.Test; 19 | 20 | import com.knuddels.jtokkit.Encodings; 21 | import com.knuddels.jtokkit.api.EncodingType; 22 | 23 | import io.github.cupybara.javalangchains.chains.Chain; 24 | import io.github.cupybara.javalangchains.chains.base.ApplyToStreamInputChain; 25 | import io.github.cupybara.javalangchains.chains.base.logging.LoggingChain; 26 | import io.github.cupybara.javalangchains.chains.data.reader.ReadDocumentsFromPdfChain; 27 | import io.github.cupybara.javalangchains.chains.data.retrieval.LuceneRetrievalChain; 28 | import io.github.cupybara.javalangchains.chains.data.writer.WriteDocumentsToLuceneDirectoryChain; 29 | import io.github.cupybara.javalangchains.chains.llm.openai.chat.OpenAiChatCompletionsChain; 30 | import io.github.cupybara.javalangchains.chains.llm.openai.chat.OpenAiChatCompletionsParameters; 31 | import io.github.cupybara.javalangchains.chains.qa.AnswerWithSources; 32 | import io.github.cupybara.javalangchains.chains.qa.CombineDocumentsChain; 33 | import 
io.github.cupybara.javalangchains.chains.qa.MapAnswerWithSourcesChain; 34 | import io.github.cupybara.javalangchains.chains.qa.ModifyDocumentsContentChain; 35 | import io.github.cupybara.javalangchains.chains.qa.split.JtokkitTextSplitter; 36 | import io.github.cupybara.javalangchains.chains.qa.split.SplitDocumentsChain; 37 | import io.github.cupybara.javalangchains.util.PromptTemplates; 38 | 39 | /** 40 | * tests for a complete qa {@link Chain} 41 | * 42 | * we'll read documents from our demo john doe pdfs at src/test/resources/pdf 43 | * and then ask questions about the protagonist. 44 | */ 45 | class RetrievalQaTest { 46 | 47 | private static Path tempIndexPath; 48 | private static Directory directory; 49 | 50 | @BeforeAll 51 | static void beforeAll() throws IOException, URISyntaxException { 52 | tempIndexPath = Files.createTempDirectory("lucene"); 53 | 54 | /* 55 | * We are also using a chain to create the lucene index directory 56 | */ 57 | final Chain createLuceneIndexChain = new ReadDocumentsFromPdfChain() 58 | // Optional Chain: split pdfs based on a max token count of 1000 59 | .chain(new SplitDocumentsChain(new JtokkitTextSplitter( 60 | Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE), 1000))) 61 | // Mandatory Chain: write split pdf documents to a lucene directory 62 | .chain(new WriteDocumentsToLuceneDirectoryChain(tempIndexPath)); 63 | 64 | final Path pdfDirectoryPath = Paths.get(RetrievalQaTest.class.getResource("/pdf/qa").toURI()); 65 | 66 | directory = createLuceneIndexChain.run(pdfDirectoryPath); 67 | } 68 | 69 | @AfterAll 70 | static void afterAll() throws IOException { 71 | directory.close(); 72 | Files.walk(tempIndexPath).sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete); 73 | } 74 | 75 | @Test 76 | void testQa() throws IOException { 77 | final OpenAiChatCompletionsParameters openAiChatParameters = new OpenAiChatCompletionsParameters() 78 | .temperature(0D).model("gpt-3.5-turbo"); 79 | 80 | /* 
81 | * Chain 1: The retrievalChain is used to retrieve relevant documents from an 82 | * index by using bm25 similarity 83 | */ 84 | try (final LuceneRetrievalChain retrievalChain = new LuceneRetrievalChain(directory, 2)) { 85 | 86 | /* 87 | * Chain 2: The summarizeDocumentsChain is used to summarize documents to only 88 | * contain the most relevant information. This is achieved using an OpenAI LLM 89 | * (gpt-3.5-turbo in this case) 90 | */ 91 | final ModifyDocumentsContentChain summarizeDocumentsChain = new ModifyDocumentsContentChain( 92 | new OpenAiChatCompletionsChain(PromptTemplates.QA_SUMMARIZE, openAiChatParameters, 93 | System.getenv("OPENAI_API_KEY"))); 94 | 95 | /* 96 | * Chain 3: The combineDocumentsChain is used to combine the retrieved documents 97 | * in a single prompt 98 | */ 99 | final CombineDocumentsChain combineDocumentsChain = new CombineDocumentsChain(); 100 | 101 | /* 102 | * Chain 4: The openAiChatChain is used to process the combined prompt using an 103 | * OpenAI LLM (gpt-3.5-turbo in this case) 104 | */ 105 | final OpenAiChatCompletionsChain openAiChatChain = new OpenAiChatCompletionsChain( 106 | PromptTemplates.QA_COMBINE, openAiChatParameters, System.getenv("OPENAI_API_KEY")); 107 | 108 | /* 109 | * Chain 5: The mapAnswerWithSourcesChain is used to map the llm string output 110 | * to a complex object using a regular expression which splits the sources and 111 | * the answer. 
112 | */ 113 | final MapAnswerWithSourcesChain mapAnswerWithSourcesChain = new MapAnswerWithSourcesChain(); 114 | 115 | // @formatter:off 116 | // we combine all chain links into a self contained QA chain 117 | final Chain qaChain = retrievalChain 118 | .chain(summarizeDocumentsChain) 119 | .chain(new ApplyToStreamInputChain<>(new LoggingChain<>(LoggingChain.defaultLogPrefix("SUMMARIZED_DOCUMENT")))) 120 | .chain(combineDocumentsChain) 121 | .chain(new LoggingChain<>(LoggingChain.defaultLogPrefix("COMBINED_DOCUMENT"))) 122 | .chain(openAiChatChain) 123 | .chain(new LoggingChain<>(LoggingChain.defaultLogPrefix("LLM_RESULT"))) 124 | .chain(mapAnswerWithSourcesChain); 125 | // @formatter:on 126 | 127 | // the QA chain can now be called with a question and delivers an answer 128 | final AnswerWithSources answerToValidQuestion = qaChain.run("who is john doe?"); 129 | assertNotNull(answerToValidQuestion, "no answer provided"); 130 | assertFalse(answerToValidQuestion.getSources().isEmpty(), "no sources"); 131 | LogManager.getLogger().info("answer to valid question: {}", answerToValidQuestion); 132 | } 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /src/test/resources/pdf/comparison/galactic-journey-insurance.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cupybara/java-langchains/28a4ad79c72801e4df5dbe6cc5456b31686c487f/src/test/resources/pdf/comparison/galactic-journey-insurance.pdf -------------------------------------------------------------------------------- /src/test/resources/pdf/comparison/interplanetary-travel-insurance.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cupybara/java-langchains/28a4ad79c72801e4df5dbe6cc5456b31686c487f/src/test/resources/pdf/comparison/interplanetary-travel-insurance.pdf 
-------------------------------------------------------------------------------- /src/test/resources/pdf/qa/book-of-john-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cupybara/java-langchains/28a4ad79c72801e4df5dbe6cc5456b31686c487f/src/test/resources/pdf/qa/book-of-john-1.pdf -------------------------------------------------------------------------------- /src/test/resources/pdf/qa/book-of-john-2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cupybara/java-langchains/28a4ad79c72801e4df5dbe6cc5456b31686c487f/src/test/resources/pdf/qa/book-of-john-2.pdf -------------------------------------------------------------------------------- /src/test/resources/pdf/qa/book-of-john-3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cupybara/java-langchains/28a4ad79c72801e4df5dbe6cc5456b31686c487f/src/test/resources/pdf/qa/book-of-john-3.pdf --------------------------------------------------------------------------------