├── .gitattributes ├── .gitignore ├── .mvn └── wrapper │ ├── maven-wrapper.jar │ └── maven-wrapper.properties ├── LICENSE ├── README.md ├── libs └── com │ └── github │ └── serpapi │ └── google-search-results-java │ └── 2.0.3 │ ├── google-search-results-java-2.0.3-sources.jar │ └── google-search-results-java-2.0.3.jar ├── mvnw ├── mvnw.cmd ├── pom.xml └── src ├── main └── java │ └── dev │ └── ai4j │ ├── agent │ └── tool │ │ └── webpage │ │ └── WebPageScrapperTool.java │ ├── document │ ├── loader │ │ ├── PdfFileLoader.java │ │ └── TextFileLoader.java │ └── splitter │ │ └── OverlappingDocumentSplitter.java │ ├── flows │ ├── ChatFlow.java │ └── DocumentQnAFlow.java │ ├── model │ ├── chat │ │ ├── OpenAiChatModel.java │ │ └── SimpleChatHistory.java │ ├── completion │ │ ├── OpenAiCompletionModel.java │ │ └── structured │ │ │ ├── Description.java │ │ │ └── Example.java │ ├── embedding │ │ └── OpenAiEmbeddingModel.java │ └── openai │ │ └── OpenAiModelName.java │ └── utils │ ├── Json.java │ ├── StopWatch.java │ └── Utils.java └── test └── java ├── TestIt.java ├── dev └── ai4j │ ├── document │ ├── loader │ │ ├── TextFileLoaderTest.java │ │ ├── test-file-iso-8859-1.txt │ │ └── test-file-utf8.txt │ └── splitter │ │ └── OverlappingDocumentSplitterTest.java │ ├── model │ └── chat │ │ └── SimpleChatHistoryTest.java │ └── utils │ └── TestUtils.java └── test-file.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | # Handle line endings automatically for files detected as text 2 | # and leave all files detected as binary untouched. 
3 | * text=auto 4 | 5 | # Force LF (Unix) line endings on every file whose name contains a dot, so Windows checkouts do not convert them to CRLF. NOTE(review): `*.*` also matches mvnw.cmd (Windows batch files are safest with CRLF) but not the extensionless mvnw shell script (which needs LF); consider adding `/mvnw text eol=lf` and `*.cmd text eol=crlf` after this rule. 6 | *.* text eol=lf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | !.mvn/wrapper/maven-wrapper.jar 3 | !**/src/main/**/target/ 4 | !**/src/test/**/target/ 5 | 6 | ### IntelliJ IDEA ### 7 | .idea/* 8 | .idea/modules.xml 9 | .idea/jarRepositories.xml 10 | .idea/compiler.xml 11 | .idea/libraries/ 12 | *.iws 13 | *.iml 14 | *.ipr 15 | 16 | **/ApiKeys.java 17 | 18 | ### Eclipse ### 19 | .apt_generated 20 | .classpath 21 | .factorypath 22 | .project 23 | .settings 24 | .springBeans 25 | .sts4-cache 26 | 27 | ### NetBeans ### 28 | /nbproject/private/ 29 | /nbbuild/ 30 | /dist/ 31 | /nbdist/ 32 | /.nb-gradle/ 33 | build/ 34 | !**/src/main/**/build/ 35 | !**/src/test/**/build/ 36 | 37 | ### VS Code ### 38 | .vscode/ 39 | 40 | ### Mac OS ### 41 | .DS_Store -------------------------------------------------------------------------------- /.mvn/wrapper/maven-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-for-java/ai4j/3382afa163fea1cf6c642c1b471b9f56fdba98c9/.mvn/wrapper/maven-wrapper.jar -------------------------------------------------------------------------------- /.mvn/wrapper/maven-wrapper.properties: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License.
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.9.1/apache-maven-3.9.1-bin.zip 18 | wrapperUrl=https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Migrated to [LangChain4j](https://github.com/langchain4j/langchain4j) 2 | -------------------------------------------------------------------------------- /libs/com/github/serpapi/google-search-results-java/2.0.3/google-search-results-java-2.0.3-sources.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-for-java/ai4j/3382afa163fea1cf6c642c1b471b9f56fdba98c9/libs/com/github/serpapi/google-search-results-java/2.0.3/google-search-results-java-2.0.3-sources.jar -------------------------------------------------------------------------------- /libs/com/github/serpapi/google-search-results-java/2.0.3/google-search-results-java-2.0.3.jar: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ai-for-java/ai4j/3382afa163fea1cf6c642c1b471b9f56fdba98c9/libs/com/github/serpapi/google-search-results-java/2.0.3/google-search-results-java-2.0.3.jar -------------------------------------------------------------------------------- /mvnw: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # ---------------------------------------------------------------------------- 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # ---------------------------------------------------------------------------- 20 | 21 | # ---------------------------------------------------------------------------- 22 | # Apache Maven Wrapper startup batch script, version 3.2.0 23 | # 24 | # Required ENV vars: 25 | # ------------------ 26 | # JAVA_HOME - location of a JDK home dir 27 | # 28 | # Optional ENV vars 29 | # ----------------- 30 | # MAVEN_OPTS - parameters passed to the Java VM when running Maven 31 | # e.g. 
to debug Maven itself, use 32 | # set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 33 | # MAVEN_SKIP_RC - flag to disable loading of mavenrc files 34 | # ---------------------------------------------------------------------------- 35 | 36 | if [ -z "$MAVEN_SKIP_RC" ] ; then 37 | 38 | if [ -f /usr/local/etc/mavenrc ] ; then 39 | . /usr/local/etc/mavenrc 40 | fi 41 | 42 | if [ -f /etc/mavenrc ] ; then 43 | . /etc/mavenrc 44 | fi 45 | 46 | if [ -f "$HOME/.mavenrc" ] ; then 47 | . "$HOME/.mavenrc" 48 | fi 49 | 50 | fi 51 | 52 | # OS specific support. $var _must_ be set to either true or false. 53 | cygwin=false; 54 | darwin=false; 55 | mingw=false 56 | case "$(uname)" in 57 | CYGWIN*) cygwin=true ;; 58 | MINGW*) mingw=true;; 59 | Darwin*) darwin=true 60 | # Use /usr/libexec/java_home if available, otherwise fall back to /Library/Java/Home 61 | # See https://developer.apple.com/library/mac/qa/qa1170/_index.html 62 | if [ -z "$JAVA_HOME" ]; then 63 | if [ -x "/usr/libexec/java_home" ]; then 64 | JAVA_HOME="$(/usr/libexec/java_home)"; export JAVA_HOME 65 | else 66 | JAVA_HOME="/Library/Java/Home"; export JAVA_HOME 67 | fi 68 | fi 69 | ;; 70 | esac 71 | 72 | if [ -z "$JAVA_HOME" ] ; then 73 | if [ -r /etc/gentoo-release ] ; then 74 | JAVA_HOME=$(java-config --jre-home) 75 | fi 76 | fi 77 | 78 | # For Cygwin, ensure paths are in UNIX format before anything is touched 79 | if $cygwin ; then 80 | [ -n "$JAVA_HOME" ] && 81 | JAVA_HOME=$(cygpath --unix "$JAVA_HOME") 82 | [ -n "$CLASSPATH" ] && 83 | CLASSPATH=$(cygpath --path --unix "$CLASSPATH") 84 | fi 85 | 86 | # For Mingw, ensure paths are in UNIX format before anything is touched 87 | if $mingw ; then 88 | [ -n "$JAVA_HOME" ] && [ -d "$JAVA_HOME" ] && 89 | JAVA_HOME="$(cd "$JAVA_HOME" || (echo "cannot cd into $JAVA_HOME."; exit 1); pwd)" 90 | fi 91 | 92 | if [ -z "$JAVA_HOME" ]; then 93 | javaExecutable="$(which javac)" 94 | if [ -n "$javaExecutable" ] && ! 
[ "$(expr "\"$javaExecutable\"" : '\([^ ]*\)')" = "no" ]; then 95 | # readlink(1) is not available as standard on Solaris 10. 96 | readLink=$(which readlink) 97 | if [ ! "$(expr "$readLink" : '\([^ ]*\)')" = "no" ]; then 98 | if $darwin ; then 99 | javaHome="$(dirname "\"$javaExecutable\"")" 100 | javaExecutable="$(cd "\"$javaHome\"" && pwd -P)/javac" 101 | else 102 | javaExecutable="$(readlink -f "\"$javaExecutable\"")" 103 | fi 104 | javaHome="$(dirname "\"$javaExecutable\"")" 105 | javaHome=$(expr "$javaHome" : '\(.*\)/bin') 106 | JAVA_HOME="$javaHome" 107 | export JAVA_HOME 108 | fi 109 | fi 110 | fi 111 | 112 | if [ -z "$JAVACMD" ] ; then 113 | if [ -n "$JAVA_HOME" ] ; then 114 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 115 | # IBM's JDK on AIX uses strange locations for the executables 116 | JAVACMD="$JAVA_HOME/jre/sh/java" 117 | else 118 | JAVACMD="$JAVA_HOME/bin/java" 119 | fi 120 | else 121 | JAVACMD="$(\unset -f command 2>/dev/null; \command -v java)" 122 | fi 123 | fi 124 | 125 | if [ ! -x "$JAVACMD" ] ; then 126 | echo "Error: JAVA_HOME is not defined correctly." >&2 127 | echo " We cannot execute $JAVACMD" >&2 128 | exit 1 129 | fi 130 | 131 | if [ -z "$JAVA_HOME" ] ; then 132 | echo "Warning: JAVA_HOME environment variable is not set." 133 | fi 134 | 135 | # traverses directory structure from process work directory to filesystem root 136 | # first directory with .mvn subdirectory is considered project base directory 137 | find_maven_basedir() { 138 | if [ -z "$1" ] 139 | then 140 | echo "Path not specified to find_maven_basedir" 141 | return 1 142 | fi 143 | 144 | basedir="$1" 145 | wdir="$1" 146 | while [ "$wdir" != '/' ] ; do 147 | if [ -d "$wdir"/.mvn ] ; then 148 | basedir=$wdir 149 | break 150 | fi 151 | # workaround for JBEAP-8937 (on Solaris 10/Sparc) 152 | if [ -d "${wdir}" ]; then 153 | wdir=$(cd "$wdir/.." 
|| exit 1; pwd) 154 | fi 155 | # end of workaround 156 | done 157 | printf '%s' "$(cd "$basedir" || exit 1; pwd)" 158 | } 159 | 160 | # concatenates all lines of a file 161 | concat_lines() { 162 | if [ -f "$1" ]; then 163 | # Remove \r in case we run on Windows within Git Bash 164 | # and check out the repository with auto CRLF management 165 | # enabled. Otherwise, we may read lines that are delimited with 166 | # \r\n and produce $'-Xarg\r' rather than -Xarg due to word 167 | # splitting rules. 168 | tr -s '\r\n' ' ' < "$1" 169 | fi 170 | } 171 | 172 | log() { 173 | if [ "$MVNW_VERBOSE" = true ]; then 174 | printf '%s\n' "$1" 175 | fi 176 | } 177 | 178 | BASE_DIR=$(find_maven_basedir "$(dirname "$0")") 179 | if [ -z "$BASE_DIR" ]; then 180 | exit 1; 181 | fi 182 | 183 | MAVEN_PROJECTBASEDIR=${MAVEN_BASEDIR:-"$BASE_DIR"}; export MAVEN_PROJECTBASEDIR 184 | log "$MAVEN_PROJECTBASEDIR" 185 | 186 | ########################################################################################## 187 | # Extension to allow automatically downloading the maven-wrapper.jar from Maven-central 188 | # This allows using the maven wrapper in projects that prohibit checking in binary data. 189 | ########################################################################################## 190 | wrapperJarPath="$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.jar" 191 | if [ -r "$wrapperJarPath" ]; then 192 | log "Found $wrapperJarPath" 193 | else 194 | log "Couldn't find $wrapperJarPath, downloading it ..." 
195 | 196 | if [ -n "$MVNW_REPOURL" ]; then 197 | wrapperUrl="$MVNW_REPOURL/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar" 198 | else 199 | wrapperUrl="https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar" 200 | fi 201 | while IFS="=" read -r key value; do 202 | # Remove '\r' from value to allow usage on windows as IFS does not consider '\r' as a separator ( considers space, tab, new line ('\n'), and custom '=' ) 203 | safeValue=$(echo "$value" | tr -d '\r') 204 | case "$key" in (wrapperUrl) wrapperUrl="$safeValue"; break ;; 205 | esac 206 | done < "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.properties" 207 | log "Downloading from: $wrapperUrl" 208 | 209 | if $cygwin; then 210 | wrapperJarPath=$(cygpath --path --windows "$wrapperJarPath") 211 | fi 212 | 213 | if command -v wget > /dev/null; then 214 | log "Found wget ... using wget" 215 | [ "$MVNW_VERBOSE" = true ] && QUIET="" || QUIET="--quiet" 216 | if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then 217 | wget $QUIET "$wrapperUrl" -O "$wrapperJarPath" || rm -f "$wrapperJarPath" 218 | else 219 | wget $QUIET --http-user="$MVNW_USERNAME" --http-password="$MVNW_PASSWORD" "$wrapperUrl" -O "$wrapperJarPath" || rm -f "$wrapperJarPath" 220 | fi 221 | elif command -v curl > /dev/null; then 222 | log "Found curl ... 
using curl" 223 | [ "$MVNW_VERBOSE" = true ] && QUIET="" || QUIET="--silent" 224 | if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then 225 | curl $QUIET -o "$wrapperJarPath" "$wrapperUrl" -f -L || rm -f "$wrapperJarPath" 226 | else 227 | curl $QUIET --user "$MVNW_USERNAME:$MVNW_PASSWORD" -o "$wrapperJarPath" "$wrapperUrl" -f -L || rm -f "$wrapperJarPath" 228 | fi 229 | else 230 | log "Falling back to using Java to download" 231 | javaSource="$MAVEN_PROJECTBASEDIR/.mvn/wrapper/MavenWrapperDownloader.java" 232 | javaClass="$MAVEN_PROJECTBASEDIR/.mvn/wrapper/MavenWrapperDownloader.class" 233 | # For Cygwin, switch paths to Windows format before running javac 234 | if $cygwin; then 235 | javaSource=$(cygpath --path --windows "$javaSource") 236 | javaClass=$(cygpath --path --windows "$javaClass") 237 | fi 238 | if [ -e "$javaSource" ]; then 239 | if [ ! -e "$javaClass" ]; then 240 | log " - Compiling MavenWrapperDownloader.java ..." 241 | ("$JAVA_HOME/bin/javac" "$javaSource") 242 | fi 243 | if [ -e "$javaClass" ]; then 244 | log " - Running MavenWrapperDownloader.java ..." 
245 | ("$JAVA_HOME/bin/java" -cp .mvn/wrapper MavenWrapperDownloader "$wrapperUrl" "$wrapperJarPath") || rm -f "$wrapperJarPath" 246 | fi 247 | fi 248 | fi 249 | fi 250 | ########################################################################################## 251 | # End of extension 252 | ########################################################################################## 253 | 254 | # If specified, validate the SHA-256 sum of the Maven wrapper jar file 255 | wrapperSha256Sum="" 256 | while IFS="=" read -r key value; do 257 | case "$key" in (wrapperSha256Sum) wrapperSha256Sum=$value; break ;; 258 | esac 259 | done < "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.properties" 260 | if [ -n "$wrapperSha256Sum" ]; then 261 | wrapperSha256Result=false 262 | if command -v sha256sum > /dev/null; then 263 | if echo "$wrapperSha256Sum $wrapperJarPath" | sha256sum -c > /dev/null 2>&1; then 264 | wrapperSha256Result=true 265 | fi 266 | elif command -v shasum > /dev/null; then 267 | if echo "$wrapperSha256Sum $wrapperJarPath" | shasum -a 256 -c > /dev/null 2>&1; then 268 | wrapperSha256Result=true 269 | fi 270 | else 271 | echo "Checksum validation was requested but neither 'sha256sum' or 'shasum' are available." 272 | echo "Please install either command, or disable validation by removing 'wrapperSha256Sum' from your maven-wrapper.properties." 273 | exit 1 274 | fi 275 | if [ $wrapperSha256Result = false ]; then 276 | echo "Error: Failed to validate Maven wrapper SHA-256, your Maven wrapper might be compromised." >&2 277 | echo "Investigate or delete $wrapperJarPath to attempt a clean download." >&2 278 | echo "If you updated your Maven version, you need to update the specified wrapperSha256Sum property." 
>&2 279 | exit 1 280 | fi 281 | fi 282 | 283 | MAVEN_OPTS="$(concat_lines "$MAVEN_PROJECTBASEDIR/.mvn/jvm.config") $MAVEN_OPTS" 284 | 285 | # For Cygwin, switch paths to Windows format before running java 286 | if $cygwin; then 287 | [ -n "$JAVA_HOME" ] && 288 | JAVA_HOME=$(cygpath --path --windows "$JAVA_HOME") 289 | [ -n "$CLASSPATH" ] && 290 | CLASSPATH=$(cygpath --path --windows "$CLASSPATH") 291 | [ -n "$MAVEN_PROJECTBASEDIR" ] && 292 | MAVEN_PROJECTBASEDIR=$(cygpath --path --windows "$MAVEN_PROJECTBASEDIR") 293 | fi 294 | 295 | # Provide a "standardized" way to retrieve the CLI args that will 296 | # work with both Windows and non-Windows executions. 297 | MAVEN_CMD_LINE_ARGS="$MAVEN_CONFIG $*" 298 | export MAVEN_CMD_LINE_ARGS 299 | 300 | WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain 301 | 302 | # shellcheck disable=SC2086 # safe args 303 | exec "$JAVACMD" \ 304 | $MAVEN_OPTS \ 305 | $MAVEN_DEBUG_OPTS \ 306 | -classpath "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.jar" \ 307 | "-Dmaven.multiModuleProjectDirectory=${MAVEN_PROJECTBASEDIR}" \ 308 | ${WRAPPER_LAUNCHER} $MAVEN_CONFIG "$@" 309 | -------------------------------------------------------------------------------- /mvnw.cmd: -------------------------------------------------------------------------------- 1 | @REM ---------------------------------------------------------------------------- 2 | @REM Licensed to the Apache Software Foundation (ASF) under one 3 | @REM or more contributor license agreements. See the NOTICE file 4 | @REM distributed with this work for additional information 5 | @REM regarding copyright ownership. The ASF licenses this file 6 | @REM to you under the Apache License, Version 2.0 (the 7 | @REM "License"); you may not use this file except in compliance 8 | @REM with the License. 
You may obtain a copy of the License at 9 | @REM 10 | @REM http://www.apache.org/licenses/LICENSE-2.0 11 | @REM 12 | @REM Unless required by applicable law or agreed to in writing, 13 | @REM software distributed under the License is distributed on an 14 | @REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | @REM KIND, either express or implied. See the License for the 16 | @REM specific language governing permissions and limitations 17 | @REM under the License. 18 | @REM ---------------------------------------------------------------------------- 19 | 20 | @REM ---------------------------------------------------------------------------- 21 | @REM Apache Maven Wrapper startup batch script, version 3.2.0 22 | @REM 23 | @REM Required ENV vars: 24 | @REM JAVA_HOME - location of a JDK home dir 25 | @REM 26 | @REM Optional ENV vars 27 | @REM MAVEN_BATCH_ECHO - set to 'on' to enable the echoing of the batch commands 28 | @REM MAVEN_BATCH_PAUSE - set to 'on' to wait for a keystroke before ending 29 | @REM MAVEN_OPTS - parameters passed to the Java VM when running Maven 30 | @REM e.g. 
to debug Maven itself, use 31 | @REM set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 32 | @REM MAVEN_SKIP_RC - flag to disable loading of mavenrc files 33 | @REM ---------------------------------------------------------------------------- 34 | 35 | @REM Begin all REM lines with '@' in case MAVEN_BATCH_ECHO is 'on' 36 | @echo off 37 | @REM set title of command window 38 | title %0 39 | @REM enable echoing by setting MAVEN_BATCH_ECHO to 'on' 40 | @if "%MAVEN_BATCH_ECHO%" == "on" echo %MAVEN_BATCH_ECHO% 41 | 42 | @REM set %HOME% to equivalent of $HOME 43 | if "%HOME%" == "" (set "HOME=%HOMEDRIVE%%HOMEPATH%") 44 | 45 | @REM Execute a user defined script before this one 46 | if not "%MAVEN_SKIP_RC%" == "" goto skipRcPre 47 | @REM check for pre script, once with legacy .bat ending and once with .cmd ending 48 | if exist "%USERPROFILE%\mavenrc_pre.bat" call "%USERPROFILE%\mavenrc_pre.bat" %* 49 | if exist "%USERPROFILE%\mavenrc_pre.cmd" call "%USERPROFILE%\mavenrc_pre.cmd" %* 50 | :skipRcPre 51 | 52 | @setlocal 53 | 54 | set ERROR_CODE=0 55 | 56 | @REM To isolate internal variables from possible post scripts, we use another setlocal 57 | @setlocal 58 | 59 | @REM ==== START VALIDATION ==== 60 | if not "%JAVA_HOME%" == "" goto OkJHome 61 | 62 | echo. 63 | echo Error: JAVA_HOME not found in your environment. >&2 64 | echo Please set the JAVA_HOME variable in your environment to match the >&2 65 | echo location of your Java installation. >&2 66 | echo. 67 | goto error 68 | 69 | :OkJHome 70 | if exist "%JAVA_HOME%\bin\java.exe" goto init 71 | 72 | echo. 73 | echo Error: JAVA_HOME is set to an invalid directory. >&2 74 | echo JAVA_HOME = "%JAVA_HOME%" >&2 75 | echo Please set the JAVA_HOME variable in your environment to match the >&2 76 | echo location of your Java installation. >&2 77 | echo. 78 | goto error 79 | 80 | @REM ==== END VALIDATION ==== 81 | 82 | :init 83 | 84 | @REM Find the project base dir, i.e. 
the directory that contains the folder ".mvn". 85 | @REM Fallback to current working directory if not found. 86 | 87 | set MAVEN_PROJECTBASEDIR=%MAVEN_BASEDIR% 88 | IF NOT "%MAVEN_PROJECTBASEDIR%"=="" goto endDetectBaseDir 89 | 90 | set EXEC_DIR=%CD% 91 | set WDIR=%EXEC_DIR% 92 | :findBaseDir 93 | IF EXIST "%WDIR%"\.mvn goto baseDirFound 94 | cd .. 95 | IF "%WDIR%"=="%CD%" goto baseDirNotFound 96 | set WDIR=%CD% 97 | goto findBaseDir 98 | 99 | :baseDirFound 100 | set MAVEN_PROJECTBASEDIR=%WDIR% 101 | cd "%EXEC_DIR%" 102 | goto endDetectBaseDir 103 | 104 | :baseDirNotFound 105 | set MAVEN_PROJECTBASEDIR=%EXEC_DIR% 106 | cd "%EXEC_DIR%" 107 | 108 | :endDetectBaseDir 109 | 110 | IF NOT EXIST "%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config" goto endReadAdditionalConfig 111 | 112 | @setlocal EnableExtensions EnableDelayedExpansion 113 | for /F "usebackq delims=" %%a in ("%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config") do set JVM_CONFIG_MAVEN_PROPS=!JVM_CONFIG_MAVEN_PROPS! %%a 114 | @endlocal & set JVM_CONFIG_MAVEN_PROPS=%JVM_CONFIG_MAVEN_PROPS% 115 | 116 | :endReadAdditionalConfig 117 | 118 | SET MAVEN_JAVA_EXE="%JAVA_HOME%\bin\java.exe" 119 | set WRAPPER_JAR="%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.jar" 120 | set WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain 121 | 122 | set WRAPPER_URL="https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar" 123 | 124 | FOR /F "usebackq tokens=1,2 delims==" %%A IN ("%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties") DO ( 125 | IF "%%A"=="wrapperUrl" SET WRAPPER_URL=%%B 126 | ) 127 | 128 | @REM Extension to allow automatically downloading the maven-wrapper.jar from Maven-central 129 | @REM This allows using the maven wrapper in projects that prohibit checking in binary data. 
130 | if exist %WRAPPER_JAR% ( 131 | if "%MVNW_VERBOSE%" == "true" ( 132 | echo Found %WRAPPER_JAR% 133 | ) 134 | ) else ( 135 | if not "%MVNW_REPOURL%" == "" ( 136 | SET WRAPPER_URL="%MVNW_REPOURL%/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar" 137 | ) 138 | if "%MVNW_VERBOSE%" == "true" ( 139 | echo Couldn't find %WRAPPER_JAR%, downloading it ... 140 | echo Downloading from: %WRAPPER_URL% 141 | ) 142 | 143 | powershell -Command "&{"^ 144 | "$webclient = new-object System.Net.WebClient;"^ 145 | "if (-not ([string]::IsNullOrEmpty('%MVNW_USERNAME%') -and [string]::IsNullOrEmpty('%MVNW_PASSWORD%'))) {"^ 146 | "$webclient.Credentials = new-object System.Net.NetworkCredential('%MVNW_USERNAME%', '%MVNW_PASSWORD%');"^ 147 | "}"^ 148 | "[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; $webclient.DownloadFile('%WRAPPER_URL%', '%WRAPPER_JAR%')"^ 149 | "}" 150 | if "%MVNW_VERBOSE%" == "true" ( 151 | echo Finished downloading %WRAPPER_JAR% 152 | ) 153 | ) 154 | @REM End of extension 155 | 156 | @REM If specified, validate the SHA-256 sum of the Maven wrapper jar file 157 | SET WRAPPER_SHA_256_SUM="" 158 | FOR /F "usebackq tokens=1,2 delims==" %%A IN ("%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties") DO ( 159 | IF "%%A"=="wrapperSha256Sum" SET WRAPPER_SHA_256_SUM=%%B 160 | ) 161 | IF NOT %WRAPPER_SHA_256_SUM%=="" ( 162 | powershell -Command "&{"^ 163 | "$hash = (Get-FileHash \"%WRAPPER_JAR%\" -Algorithm SHA256).Hash.ToLower();"^ 164 | "If('%WRAPPER_SHA_256_SUM%' -ne $hash){"^ 165 | " Write-Output 'Error: Failed to validate Maven wrapper SHA-256, your Maven wrapper might be compromised.';"^ 166 | " Write-Output 'Investigate or delete %WRAPPER_JAR% to attempt a clean download.';"^ 167 | " Write-Output 'If you updated your Maven version, you need to update the specified wrapperSha256Sum property.';"^ 168 | " exit 1;"^ 169 | "}"^ 170 | "}" 171 | if ERRORLEVEL 1 goto error 172 | ) 173 | 174 | @REM Provide a 
"standardized" way to retrieve the CLI args that will 175 | @REM work with both Windows and non-Windows executions. 176 | set MAVEN_CMD_LINE_ARGS=%* 177 | 178 | %MAVEN_JAVA_EXE% ^ 179 | %JVM_CONFIG_MAVEN_PROPS% ^ 180 | %MAVEN_OPTS% ^ 181 | %MAVEN_DEBUG_OPTS% ^ 182 | -classpath %WRAPPER_JAR% ^ 183 | "-Dmaven.multiModuleProjectDirectory=%MAVEN_PROJECTBASEDIR%" ^ 184 | %WRAPPER_LAUNCHER% %MAVEN_CONFIG% %* 185 | if ERRORLEVEL 1 goto error 186 | goto end 187 | 188 | :error 189 | set ERROR_CODE=1 190 | 191 | :end 192 | @endlocal & set ERROR_CODE=%ERROR_CODE% 193 | 194 | if not "%MAVEN_SKIP_RC%"=="" goto skipRcPost 195 | @REM check for post script, once with legacy .bat ending and once with .cmd ending 196 | if exist "%USERPROFILE%\mavenrc_post.bat" call "%USERPROFILE%\mavenrc_post.bat" 197 | if exist "%USERPROFILE%\mavenrc_post.cmd" call "%USERPROFILE%\mavenrc_post.cmd" 198 | :skipRcPost 199 | 200 | @REM pause the script if MAVEN_BATCH_PAUSE is set to 'on' 201 | if "%MAVEN_BATCH_PAUSE%"=="on" pause 202 | 203 | if "%MAVEN_TERMINATE_CMD%"=="on" exit %ERROR_CODE% 204 | 205 | cmd /C exit /B %ERROR_CODE% 206 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | dev.ai4j 8 | ai4j 9 | 0.3.0 10 | jar 11 | 12 | ai4j 13 | Java library for smooth integration with popular AI tools and services 14 | https://github.com/ai-for-java/ai4j 15 | 16 | 17 | 18 | Apache-2.0 19 | https://www.apache.org/licenses/LICENSE-2.0.txt 20 | repo 21 | A business-friendly OSS license 22 | 23 | 24 | 25 | 26 | https://github.com/ai-for-java/ai4j 27 | scm:git:git://github.com/ai-for-java/ai4j.git 28 | scm:git:git@github.com:ai-for-java/ai4j.git 29 | 30 | 31 | 32 | 33 | deep-learning-dynamo 34 | deeplearningdynamo@gmail.com 35 | https://github.com/deep-learning-dynamo 36 | 37 | 38 | kuraleta 39 | digital.kuraleta@gmail.com 40 | https://github.com/kuraleta 41 | 42 | 
43 | 44 | 45 | 1.8 46 | 1.8 47 | UTF-8 48 | 5.9.3 49 | 50 | 51 | 52 | 53 | 54 | dev.ai4j 55 | ai4j-core 56 | 0.3.0 57 | 58 | 59 | 60 | dev.ai4j 61 | openai4j 62 | 0.2.0 63 | 64 | 65 | 66 | 67 | org.projectlombok 68 | lombok 69 | 1.18.26 70 | provided 71 | 72 | 73 | 74 | org.apache.pdfbox 75 | pdfbox 76 | 2.0.28 77 | 78 | 79 | 80 | com.google.code.gson 81 | gson 82 | 2.10.1 83 | 84 | 85 | 86 | org.jsoup 87 | jsoup 88 | 1.16.1 89 | 90 | 91 | 92 | org.slf4j 93 | slf4j-api 94 | 2.0.7 95 | 96 | 97 | 98 | com.knuddels 99 | jtokkit 100 | 0.4.0 101 | 102 | 103 | 104 | org.junit.jupiter 105 | junit-jupiter-engine 106 | ${junit.version} 107 | test 108 | 109 | 110 | 111 | org.junit.jupiter 112 | junit-jupiter-params 113 | ${junit.version} 114 | test 115 | 116 | 117 | 118 | org.assertj 119 | assertj-core 120 | 3.24.2 121 | test 122 | 123 | 124 | 125 | 126 | 127 | 128 | ossrh 129 | https://s01.oss.sonatype.org/content/repositories/snapshots 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | org.sonatype.plugins 138 | nexus-staging-maven-plugin 139 | 1.6.13 140 | true 141 | 142 | ossrh 143 | https://s01.oss.sonatype.org/ 144 | false 145 | 146 | 147 | 148 | 149 | org.apache.maven.plugins 150 | maven-source-plugin 151 | 3.2.1 152 | 153 | 154 | attach-sources 155 | 156 | jar-no-fork 157 | 158 | 159 | 160 | 161 | 162 | 163 | org.apache.maven.plugins 164 | maven-javadoc-plugin 165 | 3.5.0 166 | 167 | 168 | attach-javadocs 169 | 170 | jar 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | sign 183 | 184 | 185 | sign 186 | 187 | 188 | 189 | 190 | 191 | org.apache.maven.plugins 192 | maven-gpg-plugin 193 | 3.0.1 194 | 195 | 196 | sign-artifacts 197 | verify 198 | 199 | sign 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/agent/tool/webpage/WebPageScrapperTool.java: -------------------------------------------------------------------------------- 1 | 
package dev.ai4j.agent.tool.webpage;

import dev.ai4j.agent.Tool;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jsoup.Jsoup;

import java.io.IOException;
import java.util.Optional;

/**
 * Agent tool that downloads a web page with Jsoup and returns its text content.
 */
@Slf4j
public class WebPageScrapperTool implements Tool {

    @Override
    public String id() {
        return "webpage-scrapper";
    }

    @Override
    public String description() {
        return "A portal to the internet. Use this when you need to get content from a specific web page."
                + " You should provide a valid URL as an input and the tool will output all the text from that web page.";
    }

    /**
     * Fetches the given page and returns its text (as produced by Jsoup's {@code Document.text()}).
     *
     * @param webPageUri URL of the page to fetch; typically supplied by an agent, so it may be invalid
     * @return the page text, or {@link Optional#empty()} if the URL is null, empty or malformed,
     *         or the page could not be fetched
     */
    @Override
    public Optional<String> execute(String webPageUri) {
        try {
            val webPage = Jsoup.connect(webPageUri).get();
            return Optional.of(webPage.text()); // TODO try html
        } catch (IOException | IllegalArgumentException e) {
            // Best-effort tool: Jsoup throws IllegalArgumentException for null/empty/malformed URLs and
            // IOException for network/HTTP failures; report both as "no result" instead of propagating.
            // TODO retry?
            log.warn("Could not fetch web page '{}'", webPageUri, e);
            return Optional.empty();
        }
    }
}

--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/document/loader/PdfFileLoader.java:
--------------------------------------------------------------------------------
package dev.ai4j.document.loader;

import dev.ai4j.document.Document;
import dev.ai4j.document.DocumentLoader;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import lombok.val;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.text.PDFTextStripper;

import java.io.File;

/**
 * Loads the full text of a PDF file into a {@link Document} using Apache PDFBox.
 */
@AllArgsConstructor
public class PdfFileLoader implements DocumentLoader {

    private final String absolutePathToPdfFile;

    /**
     * Opens the PDF at {@code absolutePathToPdfFile} and extracts all of its text.
     * I/O errors from PDFBox are rethrown unchecked via {@code @SneakyThrows}.
     *
     * @return a document whose contents are the extracted text
     */
    @Override
    @SneakyThrows
    public Document load() {
        val pdfFile = new File(absolutePathToPdfFile);
        // try-with-resources: the PDDocument is closed even if text extraction throws
        try (PDDocument pdfDocument = PDDocument.load(pdfFile)) {
            val text = new PDFTextStripper().getText(pdfDocument);
            return new Document(text);
        }
    }
}
-------------------------------------------------------------------------------- /src/main/java/dev/ai4j/document/loader/TextFileLoader.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.document.loader; 2 | 3 | import dev.ai4j.document.Document; 4 | import dev.ai4j.document.DocumentLoader; 5 | import lombok.Builder; 6 | import lombok.SneakyThrows; 7 | import lombok.val; 8 | 9 | import java.nio.charset.Charset; 10 | import java.nio.file.Files; 11 | import java.nio.file.Paths; 12 | 13 | import static java.nio.charset.StandardCharsets.UTF_8; 14 | 15 | public class TextFileLoader implements DocumentLoader { 16 | 17 | private final String absolutePathToTextFile; 18 | private final Charset charset; 19 | 20 | @Builder 21 | public TextFileLoader(String absolutePathToTextFile, Charset charset) { 22 | this.absolutePathToTextFile = absolutePathToTextFile; 23 | this.charset = (charset == null) ? UTF_8 : charset; 24 | } 25 | 26 | @Override 27 | @SneakyThrows 28 | public Document load() { 29 | val fileContents = new String(Files.readAllBytes(Paths.get(absolutePathToTextFile)), charset); 30 | return Document.from(fileContents); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/document/splitter/OverlappingDocumentSplitter.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.document.splitter; 2 | 3 | import dev.ai4j.document.Document; 4 | import dev.ai4j.document.DocumentSplitter; 5 | import lombok.Builder; 6 | import lombok.val; 7 | import lombok.var; 8 | 9 | import java.util.ArrayList; 10 | import java.util.List; 11 | 12 | public class OverlappingDocumentSplitter implements DocumentSplitter { 13 | 14 | private final int chunkSize; 15 | private final int chunkOverlap; 16 | 17 | @Builder 18 | public OverlappingDocumentSplitter(int chunkSize, int chunkOverlap) { 19 | this.chunkSize = chunkSize; 20 
| this.chunkOverlap = chunkOverlap; 21 | } 22 | 23 | @Override 24 | public List split(Document document) { 25 | if (document.contents() == null || document.contents().isEmpty()) { 26 | throw new IllegalArgumentException("Document content should not be null or empty"); 27 | } 28 | 29 | val contents = document.contents(); 30 | val contentLength = contents.length(); 31 | 32 | if (chunkSize <= 0 || chunkOverlap < 0 || chunkSize <= chunkOverlap) { 33 | throw new IllegalArgumentException(String.format("Invalid chunkSize (%s) or chunkOverlap (%s)", chunkSize, chunkOverlap)); 34 | } 35 | 36 | val result = new ArrayList(); 37 | if (contentLength <= chunkSize) { 38 | result.add(document); 39 | } else { 40 | for (var i = 0; i < contentLength - chunkOverlap; i += chunkSize - chunkOverlap) { 41 | val endIndex = Math.min(i + chunkSize, contentLength); 42 | val chunk = contents.substring(i, endIndex); 43 | result.add(Document.from(chunk)); 44 | if (endIndex == contentLength) { 45 | break; 46 | } 47 | } 48 | } 49 | 50 | return result; 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/flows/ChatFlow.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.flows; 2 | 3 | import dev.ai4j.chat.AiMessage; 4 | import dev.ai4j.chat.ChatHistory; 5 | import dev.ai4j.chat.ChatMessage; 6 | import dev.ai4j.chat.ChatModel; 7 | import lombok.Builder; 8 | 9 | import java.util.List; 10 | 11 | import static dev.ai4j.chat.UserMessage.userMessage; 12 | 13 | public class ChatFlow { 14 | 15 | private final ChatModel chatModel; 16 | private final ChatHistory chatHistory; 17 | // TODO private final PromptTemplate promptTemplate; 18 | 19 | @Builder 20 | private ChatFlow(ChatModel chatModel, ChatHistory chatHistory) { 21 | this.chatModel = chatModel; 22 | this.chatHistory = chatHistory; 23 | } 24 | 25 | public String chat(String userMessage) { 26 | 
chatHistory.add(userMessage(userMessage)); 27 | List history = chatHistory.history(); 28 | 29 | AiMessage aiMessage = chatModel.chat(history); 30 | 31 | chatHistory.add(aiMessage); 32 | return aiMessage.contents(); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/flows/DocumentQnAFlow.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.flows; 2 | 3 | import dev.ai4j.PromptTemplate; 4 | import dev.ai4j.chat.ChatModel; 5 | import dev.ai4j.document.Document; 6 | import dev.ai4j.document.DocumentLoader; 7 | import dev.ai4j.document.DocumentSplitter; 8 | import dev.ai4j.document.splitter.OverlappingDocumentSplitter; 9 | import dev.ai4j.embedding.Embedding; 10 | import dev.ai4j.embedding.EmbeddingModel; 11 | import dev.ai4j.embedding.VectorDatabase; 12 | import lombok.Builder; 13 | 14 | import java.util.Collection; 15 | import java.util.HashMap; 16 | import java.util.List; 17 | import java.util.Map; 18 | 19 | import static java.util.stream.Collectors.joining; 20 | 21 | public class DocumentQnAFlow { 22 | 23 | private static final OverlappingDocumentSplitter DEFAULT_DOCUMENT_SPLITTER 24 | = new OverlappingDocumentSplitter(1000, 200); 25 | private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("Using the information delimited by triple angle brackets, answer the following question to the best of your ability: {{question}} <<<{{information}}>>>"); 26 | 27 | private final DocumentLoader documentLoader; 28 | private final DocumentSplitter documentSplitter; 29 | private final EmbeddingModel embeddingModel; 30 | private final VectorDatabase vectorDatabase; 31 | private final PromptTemplate promptTemplate; 32 | private final ChatModel chatModel; 33 | 34 | @Builder 35 | public DocumentQnAFlow(DocumentLoader documentLoader, 36 | DocumentSplitter documentSplitter, 37 | EmbeddingModel embeddingModel, // TODO provide possibility 
to use same openapi key 38 | VectorDatabase vectorDatabase, 39 | PromptTemplate promptTemplate, 40 | ChatModel chatModel) { 41 | this.documentLoader = documentLoader; 42 | this.documentSplitter = documentSplitter == null ? DEFAULT_DOCUMENT_SPLITTER : documentSplitter; 43 | this.embeddingModel = embeddingModel; 44 | this.vectorDatabase = vectorDatabase; 45 | this.promptTemplate = promptTemplate == null ? DEFAULT_PROMPT_TEMPLATE : promptTemplate; 46 | this.chatModel = chatModel; 47 | 48 | init(); 49 | } 50 | 51 | private void init() { 52 | Document document = documentLoader.load(); 53 | List chunks = documentSplitter.split(document); 54 | Collection embeddings = embeddingModel.embed(chunks); 55 | vectorDatabase.persist(embeddings); 56 | } 57 | 58 | public String ask(String question) { 59 | Embedding questionEmbedding = embeddingModel.embed(question); 60 | 61 | List relatedEmbeddings = vectorDatabase.findRelated(questionEmbedding, 5); // TODO defaults 62 | 63 | String concatenatedEmbeddings = relatedEmbeddings.stream() 64 | .map(Embedding::contents) 65 | .collect(joining(" ")); 66 | 67 | Map parameters = new HashMap<>(); 68 | parameters.put("question", question); 69 | parameters.put("information", concatenatedEmbeddings); 70 | 71 | String prompt = promptTemplate.format(parameters); 72 | 73 | return chatModel.chat(prompt); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/model/chat/OpenAiChatModel.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.model.chat; 2 | 3 | import com.google.gson.JsonElement; 4 | import com.google.gson.JsonParser; 5 | import com.google.gson.stream.JsonReader; 6 | import com.google.gson.stream.JsonToken; 7 | import dev.ai4j.PromptTemplate; 8 | import dev.ai4j.StreamingResponseHandler; 9 | import dev.ai4j.chat.AiMessage; 10 | import dev.ai4j.chat.ChatMessage; 11 | import dev.ai4j.chat.ChatModel; 12 | import 
dev.ai4j.chat.UserMessage; 13 | import dev.ai4j.model.completion.structured.Description; 14 | import dev.ai4j.openai4j.OpenAiClient; 15 | import dev.ai4j.openai4j.chat.ChatCompletionRequest; 16 | import dev.ai4j.openai4j.chat.ChatCompletionResponse; 17 | import dev.ai4j.openai4j.chat.Message; 18 | import dev.ai4j.openai4j.chat.Role; 19 | import dev.ai4j.utils.Json; 20 | import dev.ai4j.utils.StopWatch; 21 | import lombok.Builder; 22 | import lombok.extern.slf4j.Slf4j; 23 | import lombok.val; 24 | 25 | import java.io.StringReader; 26 | import java.time.Duration; 27 | import java.util.*; 28 | 29 | import static dev.ai4j.chat.AiMessage.aiMessage; 30 | import static dev.ai4j.chat.UserMessage.userMessage; 31 | import static dev.ai4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; 32 | import static dev.ai4j.openai4j.chat.Role.*; 33 | import static dev.ai4j.utils.Json.*; 34 | import static java.util.Arrays.asList; 35 | import static java.util.stream.Collectors.toList; 36 | 37 | @Slf4j 38 | public class OpenAiChatModel implements ChatModel { // TODO all models in one "service"? 39 | 40 | private static final double DEFAULT_TEMPERATURE = 0.7; 41 | private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60); 42 | 43 | private final OpenAiClient client; 44 | private final String modelName; 45 | private final Double temperature; 46 | 47 | // TODO consider adding here an option for system prompt configuration 48 | 49 | @Builder 50 | public OpenAiChatModel(String apiKey, 51 | String modelName, 52 | Double temperature, 53 | Duration timeout) { 54 | this.client = OpenAiClient.builder() 55 | .apiKey(apiKey) 56 | .timeout(timeout == null ? DEFAULT_TIMEOUT : timeout) 57 | .build(); 58 | this.modelName = modelName == null ? GPT_3_5_TURBO : modelName; 59 | this.temperature = temperature == null ? 
DEFAULT_TEMPERATURE : temperature; 60 | } 61 | 62 | @Override 63 | public AiMessage chat(List messages) { 64 | 65 | ChatCompletionRequest request = ChatCompletionRequest.builder() 66 | .model(modelName) 67 | .messages(toOpenAiMessages(messages)) 68 | .temperature(temperature) 69 | .build(); 70 | 71 | if (log.isDebugEnabled()) { 72 | String json = toJson(request); 73 | log.debug("Sending to OpenAI:\n{}", json); 74 | } 75 | StopWatch sw = StopWatch.start(); 76 | 77 | ChatCompletionResponse response = client.chatCompletion(request).execute(); 78 | 79 | long secondsElapsed = sw.secondsElapsed(); 80 | if (log.isDebugEnabled()) { 81 | String json = toJson(response); 82 | log.debug("Received from OpenAI in {} seconds:\n{}", secondsElapsed, json); 83 | } 84 | 85 | return aiMessage(response.content()); 86 | } 87 | 88 | @Override 89 | public AiMessage chat(ChatMessage... messages) { 90 | return chat(asList(messages)); 91 | } 92 | 93 | @Override 94 | public String chat(String userMessage) { 95 | AiMessage aiMessage = chat(userMessage(userMessage)); 96 | return aiMessage.contents(); 97 | } 98 | 99 | @Override 100 | public void chat(List messages, StreamingResponseHandler handler) { 101 | ChatCompletionRequest request = ChatCompletionRequest.builder() 102 | .model(modelName) 103 | .messages(toOpenAiMessages(messages)) 104 | .temperature(temperature) 105 | .stream(true) 106 | .build(); 107 | 108 | client.chatCompletion(request) 109 | .onPartialResponse(partialResponse -> { 110 | String content = partialResponse.choices().get(0).delta().content(); 111 | if (content != null) { 112 | handler.onPartialResponse(content); 113 | } 114 | }) 115 | .onComplete(handler::onComplete) 116 | .onError(handler::onError) 117 | .execute(); 118 | } 119 | 120 | private static List toOpenAiMessages(List messages) { 121 | return messages.stream() 122 | .map(OpenAiChatModel::toOpenAiMessage) 123 | .collect(toList()); 124 | } 125 | 126 | private static Message toOpenAiMessage(ChatMessage message) { 127 
| Role role; 128 | 129 | if (message instanceof UserMessage) { 130 | role = USER; 131 | } else if (message instanceof AiMessage) { 132 | role = ASSISTANT; 133 | } else { 134 | role = SYSTEM; 135 | } 136 | 137 | return Message.builder() 138 | .role(role) 139 | .content(message.contents()) 140 | .build(); 141 | } 142 | 143 | @Override 144 | public S getOne(Class structured) { 145 | return getMultiple(structured, 1).get(0); 146 | } 147 | 148 | @Override 149 | public List getMultiple(Class structured, int n) { 150 | String description = structured.getAnnotation(Description.class).value(); 151 | String jsonStructure = generateJsonStructure(structured); 152 | Optional maybeJsonExample = generateJsonExample(structured); 153 | 154 | PromptTemplate promptTemplate = PromptTemplate.from( 155 | "Provide exactly {{number_of_examples}} example(s) of {{description}} in exactly following JSON format:\n" + 156 | "{{json_structure}}\n" + 157 | "\n" + 158 | "{{maybe_example}}" + 159 | "Do not provide any other information, just valid JSON object(s)!" 
160 | ); 161 | 162 | Map params = new HashMap<>(); 163 | params.put("number_of_examples", n); 164 | params.put("description", description); 165 | params.put("json_structure", jsonStructure); 166 | params.put("maybe_example", maybeJsonExample 167 | .map(example -> String.format("For example:\n%s\n\n", example)) 168 | .orElse("")); 169 | 170 | String prompt = promptTemplate.format(params); 171 | 172 | AiMessage aiMessage = chat(userMessage(prompt)); 173 | 174 | val jsonElements = parse(aiMessage.contents()); 175 | 176 | return jsonElements.stream() 177 | .map(jsonElement -> { 178 | try { 179 | return Json.fromJson(jsonElement.toString(), structured); 180 | } catch (Exception e) { 181 | // TODO 182 | return null; 183 | } 184 | }) 185 | .filter(Objects::nonNull) 186 | .limit(n) 187 | .collect(toList()); 188 | } 189 | 190 | private static ArrayList parse(String json) { 191 | val reader = new JsonReader(new StringReader(json)); 192 | reader.setLenient(true); 193 | 194 | val jsonElements = new ArrayList(); 195 | try { 196 | while (reader.peek() != JsonToken.END_DOCUMENT) { 197 | jsonElements.add(JsonParser.parseReader(reader)); 198 | } 199 | } catch (Exception e) { 200 | // TODO 201 | } 202 | 203 | return jsonElements; 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/model/chat/SimpleChatHistory.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.model.chat; 2 | 3 | import dev.ai4j.chat.ChatHistory; 4 | import dev.ai4j.chat.ChatMessage; 5 | import dev.ai4j.chat.SystemMessage; 6 | import dev.ai4j.chat.UserMessage; 7 | import lombok.extern.slf4j.Slf4j; 8 | import lombok.val; 9 | import lombok.var; 10 | 11 | import java.util.ArrayList; 12 | import java.util.LinkedList; 13 | import java.util.List; 14 | import java.util.Optional; 15 | 16 | @Slf4j 17 | public class SimpleChatHistory implements ChatHistory { 18 | 19 | // safety net to limit the 
cost in case user did not define it himself 20 | private static final int DEFAULT_CAPACITY_IN_TOKENS = 200; 21 | 22 | private final Optional maybeMessageFromSystem; 23 | private final LinkedList previousMessages; 24 | private final Integer capacityInTokens; 25 | private final Integer capacityInMessages; 26 | 27 | private SimpleChatHistory(Builder builder) { 28 | this.maybeMessageFromSystem = builder.maybeSystemMessage; 29 | this.previousMessages = builder.previousMessages; 30 | this.capacityInTokens = builder.capacityInTokens; 31 | this.capacityInMessages = builder.capacityInMessages; 32 | ensureCapacity(); 33 | } 34 | 35 | @Override 36 | public void add(ChatMessage chatMessage) { 37 | previousMessages.add(chatMessage); 38 | ensureCapacity(); 39 | } 40 | 41 | @Override 42 | public List history() { 43 | val messages = new ArrayList(); 44 | maybeMessageFromSystem.ifPresent(messages::add); 45 | messages.addAll(previousMessages); 46 | return messages; 47 | } 48 | 49 | private void ensureCapacity() { 50 | var currentNumberOfTokensInHistory = getCurrentNumberOfTokens(); 51 | var currentNumberOfMessagesInHistory = getCurrentNumberOfMessages(); 52 | 53 | while ((capacityInTokens != null && currentNumberOfTokensInHistory > capacityInTokens) 54 | || (capacityInMessages != null && currentNumberOfMessagesInHistory > capacityInMessages)) { 55 | 56 | val oldestMessage = previousMessages.removeFirst(); 57 | 58 | // remove all mentions of human, messageFrom 59 | 60 | log.debug("Removing the oldest message from {} '{}' ({} tokens) to comply with capacity requirements", 61 | oldestMessage instanceof UserMessage ? 
"user" : "AI", 62 | oldestMessage.contents(), 63 | oldestMessage.numberOfTokens()); 64 | 65 | currentNumberOfTokensInHistory -= oldestMessage.numberOfTokens(); 66 | currentNumberOfMessagesInHistory--; 67 | } 68 | 69 | log.debug("Current stats: { tokens: approximately {}, messages: {} }", getCurrentNumberOfTokens(), getCurrentNumberOfMessages()); 70 | } 71 | 72 | private int getCurrentNumberOfTokens() { 73 | val numberOfTokensInSystemMessage = maybeMessageFromSystem.map(ChatMessage::numberOfTokens).orElse(0); 74 | val numberOfTokensInPreviousMessages = previousMessages.stream() 75 | .map(ChatMessage::numberOfTokens) 76 | .reduce(0, Integer::sum); 77 | return numberOfTokensInSystemMessage + numberOfTokensInPreviousMessages; 78 | } 79 | 80 | private int getCurrentNumberOfMessages() { 81 | return maybeMessageFromSystem.map(m -> 1).orElse(0) + previousMessages.size(); 82 | } 83 | 84 | public static class Builder { 85 | 86 | private Optional maybeSystemMessage = Optional.empty(); 87 | private Integer capacityInTokens = DEFAULT_CAPACITY_IN_TOKENS; 88 | private Integer capacityInMessages; 89 | private LinkedList previousMessages = new LinkedList<>(); 90 | 91 | public Builder systemMessage(SystemMessage systemMessage) { 92 | this.maybeSystemMessage = Optional.ofNullable(systemMessage); 93 | return this; 94 | } 95 | 96 | public Builder systemMessage(String systemMessage) { 97 | if (systemMessage == null) { 98 | this.maybeSystemMessage = Optional.empty(); // TODO ? 
99 | return this; 100 | } 101 | 102 | return systemMessage(SystemMessage.systemMessage(systemMessage)); 103 | } 104 | 105 | public Builder capacityInTokens(Integer capacityInTokens) { 106 | this.capacityInTokens = capacityInTokens; 107 | return this; 108 | } 109 | 110 | public Builder removeCapacityRestrictionInTokens() { 111 | return capacityInTokens(null); 112 | } 113 | 114 | public Builder capacityInMessages(Integer capacityInMessages) { 115 | this.capacityInMessages = capacityInMessages; 116 | return this; 117 | } 118 | 119 | public Builder previousMessages(List previousMessages) { 120 | if (previousMessages == null) { 121 | return this; 122 | } 123 | 124 | this.previousMessages = new LinkedList<>(previousMessages); 125 | return this; 126 | } 127 | 128 | public SimpleChatHistory build() { 129 | return new SimpleChatHistory(this); 130 | } 131 | } 132 | 133 | public static Builder builder() { 134 | return new Builder(); 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/model/completion/OpenAiCompletionModel.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.model.completion; 2 | 3 | import dev.ai4j.completion.CompletionModel; 4 | import dev.ai4j.openai4j.OpenAiClient; 5 | import dev.ai4j.openai4j.chat.ChatCompletionRequest; 6 | import dev.ai4j.openai4j.completion.CompletionRequest; 7 | import dev.ai4j.utils.StopWatch; 8 | import lombok.Builder; 9 | import lombok.extern.slf4j.Slf4j; 10 | import lombok.val; 11 | 12 | import java.time.Duration; 13 | 14 | import static dev.ai4j.model.openai.OpenAiModelName.GPT_3_5_TURBO; 15 | import static dev.ai4j.utils.Json.toJson; 16 | 17 | @Slf4j 18 | public class OpenAiCompletionModel implements CompletionModel { 19 | 20 | private static final double DEFAULT_TEMPERATURE = 0.7; 21 | private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60); 22 | 23 | private final OpenAiClient client; 
24 | private final String modelName; 25 | private final Double temperature; 26 | 27 | @Builder 28 | public OpenAiCompletionModel(String apiKey, String modelName, Double temperature, Duration timeout) { 29 | this.client = OpenAiClient.builder() 30 | .apiKey(apiKey) 31 | .timeout(timeout == null ? DEFAULT_TIMEOUT : timeout) 32 | .build(); 33 | this.modelName = modelName == null ? GPT_3_5_TURBO : modelName; 34 | this.temperature = temperature == null ? DEFAULT_TEMPERATURE : temperature; 35 | } 36 | 37 | @Override 38 | public String complete(String prompt) { 39 | if (GPT_3_5_TURBO.equals(modelName)) { // TODO remove this 40 | return chatCompletion(prompt); 41 | } else { 42 | return completion(prompt); 43 | } 44 | } 45 | 46 | private String chatCompletion(String input) { 47 | 48 | val request = ChatCompletionRequest.builder() 49 | .model(modelName) 50 | .addUserMessage(input) 51 | .temperature(temperature) 52 | .build(); 53 | 54 | if (log.isDebugEnabled()) { 55 | val json = toJson(request); 56 | log.debug("Sending to OpenAI:\n{}", json); 57 | } 58 | val sw = StopWatch.start(); 59 | 60 | val response = client.chatCompletion(request).execute(); 61 | 62 | val secondsElapsed = sw.secondsElapsed(); 63 | if (log.isDebugEnabled()) { 64 | val json = toJson(response); 65 | log.debug("Received from OpenAI in {} seconds:\n{}", secondsElapsed, json); 66 | } 67 | 68 | return response.content(); 69 | } 70 | 71 | private String completion(String input) { 72 | 73 | val request = CompletionRequest.builder() 74 | .model(modelName) 75 | .prompt(input) 76 | .temperature(temperature) 77 | .build(); 78 | 79 | if (log.isDebugEnabled()) { 80 | val json = toJson(request); 81 | log.debug("Sending to OpenAI:\n{}", json); 82 | } 83 | val sw = StopWatch.start(); 84 | 85 | val completionResult = client.completion(request).execute(); 86 | 87 | val secondsElapsed = sw.secondsElapsed(); 88 | if (log.isDebugEnabled()) { 89 | val json = toJson(completionResult); 90 | log.debug("Received from OpenAI in {}
seconds:\n{}", secondsElapsed, json); 91 | } 92 | 93 | return completionResult.text(); 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/model/completion/structured/Description.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.model.completion.structured; 2 | 3 | import java.lang.annotation.Retention; 4 | 5 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 6 | 7 | @Retention(RUNTIME) 8 | public @interface Description { 9 | 10 | String value(); 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/model/completion/structured/Example.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.model.completion.structured; 2 | 3 | import java.lang.annotation.Retention; 4 | 5 | import static java.lang.annotation.RetentionPolicy.RUNTIME; 6 | 7 | @Retention(RUNTIME) 8 | public @interface Example { 9 | 10 | String[] value(); 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/model/embedding/OpenAiEmbeddingModel.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.model.embedding; 2 | 3 | import dev.ai4j.document.Document; 4 | import dev.ai4j.embedding.Embedding; 5 | import dev.ai4j.embedding.EmbeddingModel; 6 | import dev.ai4j.openai4j.OpenAiClient; 7 | import dev.ai4j.openai4j.embedding.EmbeddingRequest; 8 | import lombok.Builder; 9 | import lombok.val; 10 | 11 | import java.time.Duration; 12 | import java.util.Collection; 13 | import java.util.List; 14 | import java.util.stream.IntStream; 15 | 16 | import static dev.ai4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002; 17 | import static dev.ai4j.utils.Utils.list; 18 | import static java.util.stream.Collectors.toList; 19 | 20 | public class
OpenAiEmbeddingModel implements EmbeddingModel { 21 | 22 | private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60); 23 | 24 | private final OpenAiClient client; 25 | private final String modelName; 26 | 27 | @Builder 28 | public OpenAiEmbeddingModel(String apiKey, String modelName, Duration timeout) { 29 | this.client = OpenAiClient.builder() 30 | .apiKey(apiKey) 31 | .timeout(timeout == null ? DEFAULT_TIMEOUT : timeout) 32 | .build(); 33 | this.modelName = modelName == null ? TEXT_EMBEDDING_ADA_002 : modelName; 34 | } 35 | 36 | @Override 37 | public Embedding embed(Document document) { 38 | return embed(list(document)).iterator().next(); 39 | } 40 | 41 | @Override 42 | public Embedding embed(String text) { 43 | return embed(list(Document.from(text))).iterator().next(); 44 | } 45 | 46 | @Override 47 | public Collection embed(Collection documents) { 48 | val documentContents = documents.stream() 49 | .map(Document::contents) 50 | .collect(toList()); 51 | 52 | val embeddingRequest = EmbeddingRequest.builder() 53 | .input(documentContents) // TODO handle newlines ?
54 | .model(modelName) 55 | .build(); 56 | 57 | val openAiEmbeddings = client.embedding(embeddingRequest).execute().data(); 58 | 59 | return zip(documentContents, openAiEmbeddings); 60 | } 61 | 62 | private static List zip(List documentTexts, List openAiEmbeddings) { 63 | return IntStream.range(0, documentTexts.size()) 64 | .mapToObj(i -> new Embedding(documentTexts.get(i), toDoubles(openAiEmbeddings.get(i).embedding()))) 65 | .collect(toList()); 66 | } 67 | 68 | private static List toDoubles(List floats) { 69 | return floats.stream() 70 | .map(Float::doubleValue) 71 | .collect(toList()); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/model/openai/OpenAiModelName.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.model.openai; 2 | 3 | public class OpenAiModelName { 4 | 5 | public static final String GPT_3_5_TURBO = "gpt-3.5-turbo"; 6 | public static final String GPT_4 = "gpt-4"; 7 | public static final String GPT_4_32K = "gpt-4-32k"; 8 | public static final String CODE_DAVINCI_002 = "code-davinci-002"; 9 | public static final String TEXT_DAVINCI_002 = "text-davinci-002"; 10 | public static final String TEXT_DAVINCI_003 = "text-davinci-003"; 11 | public static final String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002"; 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/utils/Json.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.utils; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.GsonBuilder; 5 | import dev.ai4j.model.completion.structured.Description; 6 | import dev.ai4j.model.completion.structured.Example; 7 | 8 | import java.lang.reflect.Field; 9 | import java.lang.reflect.ParameterizedType; 10 | import java.util.Collection; 11 | import java.util.Optional; 12 | 13 | import static
java.util.Arrays.stream; 14 | import static java.util.stream.Collectors.joining; 15 | 16 | public class Json { 17 | 18 | private static final Gson GSON = new GsonBuilder().setPrettyPrinting().create(); 19 | 20 | public static String toJson(Object o) { 21 | return GSON.toJson(o); 22 | } 23 | 24 | public static T fromJson(String json, Class type) { 25 | return GSON.fromJson(json, type); 26 | } 27 | 28 | public static String generateJsonStructure(Class structured) { 29 | StringBuilder jsonStructure = new StringBuilder(); 30 | 31 | jsonStructure.append("{\n"); 32 | for (Field field : structured.getDeclaredFields()) { 33 | Description fieldDescription = field.getAnnotation(Description.class); 34 | if (fieldDescription == null) { 35 | throw new RuntimeException(String.format("Field %s is not annotated with @Description(\"...\")", field.getName())); 36 | } 37 | jsonStructure.append(String.format("\"%s\": // %s,\n", field.getName(), fieldDescription.value())); 38 | } 39 | if (jsonStructure.length() > 2) { jsonStructure.deleteCharAt(jsonStructure.length() - 2); } // strip the last field's trailing comma; with no fields only the 2-char opening has been written and deleting would remove '{' 40 | jsonStructure.append("}"); 41 | 42 | return jsonStructure.toString(); 43 | } 44 | 45 | public static Optional generateJsonExample(Class structured) { 46 | if (!hasExamples(structured)) { 47 | return Optional.empty(); 48 | } 49 | 50 | StringBuilder jsonExample = new StringBuilder(); 51 | 52 | jsonExample.append("{\n"); 53 | for (Field field : structured.getDeclaredFields()) { 54 | Example fieldExample = field.getAnnotation(Example.class); 55 | if (fieldExample == null) { 56 | throw new RuntimeException(String.format("Field %s is not annotated with @Example(\"...\")", field.getName())); 57 | } 58 | jsonExample.append(String.format("\"%s\": %s,\n", field.getName(), toJsonExample(field))); 59 | } 60 | jsonExample.deleteCharAt(jsonExample.length() - 2); 61 | jsonExample.append("}"); 62 | 63 | return Optional.of(jsonExample.toString()); 64 | } 65 | 66 | private static boolean hasExamples(Class structured) { 67 | return
stream(structured.getDeclaredFields()) 68 | .anyMatch(field -> field.isAnnotationPresent(Example.class)); 69 | } 70 | 71 | public static String toJsonExample(Field field) { 72 | Example fieldExample = field.getAnnotation(Example.class); 73 | String[] examples = fieldExample.value(); 74 | 75 | Class fieldType = field.getType(); 76 | boolean isStringSequence = fieldType == String[].class 77 | || isCollectionOfStrings(field); 78 | boolean wrapInQuotes = fieldType == String.class || isStringSequence; 79 | // a String[] or String-collection field must render as a JSON array even when it has a single example 80 | if (examples.length == 1 && !isStringSequence) { 81 | if (wrapInQuotes) { 82 | return "\"" + examples[0] + "\""; 83 | } 84 | return examples[0]; 85 | } 86 | 87 | return String.format("[%s]", stream(examples).map(example -> { 88 | if (wrapInQuotes) { 89 | return "\"" + example + "\""; 90 | } 91 | 92 | return example; 93 | }).collect(joining(", "))); 94 | } 95 | 96 | private static boolean isCollectionOfStrings(Field field) { 97 | if (!Collection.class.isAssignableFrom(field.getType())) { 98 | return false; 99 | } 100 | 101 | if (!(field.getGenericType() instanceof ParameterizedType)) { return false; } // raw Collection: the element type is unknown 102 | ParameterizedType genericType = (ParameterizedType) field.getGenericType(); 103 | // compare as a Type: a wildcard or type-variable argument is not a Class and would fail a (Class) cast 104 | return String.class.equals(genericType.getActualTypeArguments()[0]); 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/utils/StopWatch.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.utils; 2 | 3 | public class StopWatch { 4 | 5 | private final long startTime; 6 | 7 | public StopWatch(long currentTimeMillis) { 8 | this.startTime = currentTimeMillis; 9 | } 10 | 11 | public static StopWatch start() { 12 | return new StopWatch(System.currentTimeMillis()); 13 | } 14 | 15 | public int secondsElapsed() { 16 | return (int) ((System.currentTimeMillis() - startTime) / 1000); // divide while still a long: casting the raw millisecond delta to int first overflows after ~24.8 days 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/dev/ai4j/utils/Utils.java:
-------------------------------------------------------------------------------- 1 | package dev.ai4j.utils; 2 | 3 | import java.util.Arrays; 4 | import java.util.List; 5 | 6 | public class Utils { 7 | 8 | @SafeVarargs // safe: the varargs array is only exposed as a List of T (Arrays.asList is itself @SafeVarargs) 9 | public static List list(T... elements) { 10 | return Arrays.asList(elements); 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/test/java/TestIt.java: -------------------------------------------------------------------------------- 1 | import dev.ai4j.openai4j.OpenAiClient; 2 | import dev.ai4j.openai4j.chat.ChatCompletionRequest; 3 | import dev.ai4j.openai4j.chat.ChatCompletionResponse; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import static dev.ai4j.openai4j.Model.GPT_4; 7 | 8 | public class TestIt { 9 | 10 | @Test 11 | void test() { 12 | 13 | OpenAiClient client = new OpenAiClient(System.getenv("OPENAI_API_KEY")); 14 | 15 | ChatCompletionRequest request = ChatCompletionRequest.builder() 16 | .model(GPT_4) 17 | .addSystemMessage("You are a helpful assistant") 18 | .addUserMessage("Tell me a joke") 19 | .temperature(0.7) 20 | .build(); 21 | 22 | ChatCompletionResponse response = client.chatCompletion(request).execute(); 23 | System.out.println(response.content()); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/test/java/dev/ai4j/document/loader/TextFileLoaderTest.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.document.loader; 2 | 3 | import lombok.val; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import java.nio.file.NoSuchFileException; 7 | 8 | import static java.nio.charset.StandardCharsets.ISO_8859_1; 9 | import static org.assertj.core.api.Assertions.assertThat; 10 | import static org.assertj.core.api.Assertions.assertThatThrownBy; 11 | 12 | class TextFileLoaderTest { 13 | 14 | @Test 15 | void should_load_text_file_with_utf8_charset_by_default() { 16 | 17 | val loader =
TextFileLoader.builder() 18 | .absolutePathToTextFile(System.getProperty("user.dir") + "/src/test/java/dev/ai4j/document/loader/test-file-utf8.txt") 19 | .build(); 20 | 21 | val document = loader.load(); 22 | 23 | assertThat(document.contents()).isEqualTo("test\ncontent"); 24 | } 25 | 26 | @Test 27 | void should_load_text_file_with_specified_charset() { 28 | 29 | val loader = TextFileLoader.builder() 30 | .absolutePathToTextFile(System.getProperty("user.dir") + "/src/test/java/dev/ai4j/document/loader/test-file-iso-8859-1.txt") 31 | .charset(ISO_8859_1) 32 | .build(); 33 | 34 | val document = loader.load(); 35 | 36 | assertThat(document.contents()).isEqualTo("test\ncontent"); 37 | } 38 | 39 | @Test 40 | void should_fail_to_load_not_existing_file() { 41 | 42 | val loader = TextFileLoader.builder() 43 | .absolutePathToTextFile(System.getProperty("user.dir") + "/banana") 44 | .build(); 45 | 46 | assertThatThrownBy(loader::load).isInstanceOf(NoSuchFileException.class); 47 | } 48 | } -------------------------------------------------------------------------------- /src/test/java/dev/ai4j/document/loader/test-file-iso-8859-1.txt: -------------------------------------------------------------------------------- 1 | test 2 | content -------------------------------------------------------------------------------- /src/test/java/dev/ai4j/document/loader/test-file-utf8.txt: -------------------------------------------------------------------------------- 1 | test 2 | content -------------------------------------------------------------------------------- /src/test/java/dev/ai4j/document/splitter/OverlappingDocumentSplitterTest.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.document.splitter; 2 | 3 | import dev.ai4j.document.Document; 4 | import lombok.val; 5 | import org.junit.jupiter.api.Test; 6 | import org.junit.jupiter.params.ParameterizedTest; 7 | import org.junit.jupiter.params.provider.CsvSource; 8 | import
org.junit.jupiter.params.provider.NullAndEmptySource; 9 | 10 | import java.util.Arrays; 11 | import java.util.List; 12 | 13 | import static dev.ai4j.document.Document.from; 14 | import static org.junit.jupiter.api.Assertions.assertEquals; 15 | import static org.junit.jupiter.api.Assertions.assertThrows; 16 | 17 | class OverlappingDocumentSplitterTest { 18 | 19 | // TODO add more test variety 20 | @Test 21 | void testDocumentIsSplit() { 22 | val sut = new OverlappingDocumentSplitter(4, 2); 23 | List result = sut.split(new Document("1234567890")); 24 | 25 | List expected = Arrays.asList( 26 | from("1234"), 27 | from("3456"), 28 | from("5678"), 29 | from("7890") 30 | ); 31 | 32 | assertEquals(expected, result); 33 | } 34 | 35 | @ParameterizedTest 36 | @CsvSource({"0,-1", "-1,-1", "-1,0", "0,0", "0,1", "1,-1", "1,1", "1,2"}) 37 | void testIllegalArgumentExceptionWhenChunkSizeAndChunkOverlapMisconfigured(int chunkSize, int chunkOverlap) { 38 | val sut = new OverlappingDocumentSplitter(chunkSize, chunkOverlap); 39 | 40 | IllegalArgumentException thrown = assertThrows( 41 | IllegalArgumentException.class, 42 | () -> sut.split(new Document("any")) 43 | ); 44 | 45 | assertEquals("Invalid chunkSize (" + chunkSize + ") or chunkOverlap (" + chunkOverlap + ")", 46 | thrown.getMessage()); 47 | } 48 | 49 | @ParameterizedTest 50 | @NullAndEmptySource 51 | void testNullCase(String documentContent) { 52 | val sut = new OverlappingDocumentSplitter(4, 2); 53 | 54 | IllegalArgumentException thrown = assertThrows( 55 | IllegalArgumentException.class, 56 | () -> sut.split(new Document(documentContent)) 57 | ); 58 | 59 | assertEquals("Document content should not be null or empty", thrown.getMessage()); 60 | } 61 | } -------------------------------------------------------------------------------- /src/test/java/dev/ai4j/model/chat/SimpleChatHistoryTest.java:
-------------------------------------------------------------------------------- 1 | package dev.ai4j.model.chat; 2 | 3 | import lombok.val; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import static dev.ai4j.chat.AiMessage.aiMessage; 7 | import static dev.ai4j.chat.SystemMessage.systemMessage; 8 | import static dev.ai4j.chat.UserMessage.userMessage; 9 | import static dev.ai4j.utils.TestUtils.*; 10 | import static java.util.Arrays.asList; 11 | import static org.assertj.core.api.Assertions.assertThat; 12 | import static org.assertj.core.api.Assertions.atIndex; 13 | 14 | class SimpleChatHistoryTest { 15 | 16 | @Test 17 | void should_keep_specified_number_of_tokens_in_chat_history_1() { 18 | 19 | val messageFromSystem = systemMessageWithTokens(10); 20 | val chatHistory = SimpleChatHistory.builder() 21 | .systemMessage(messageFromSystem) 22 | .capacityInTokens(30) 23 | .build(); 24 | assertThat(chatHistory.history()) 25 | .hasSize(1) 26 | .containsExactly(messageFromSystem); 27 | 28 | val firstMessageFromHuman = messageFromHumanWithTokens(10); 29 | chatHistory.add(firstMessageFromHuman); 30 | assertThat(chatHistory.history()) 31 | .hasSize(2) 32 | .containsExactly( 33 | messageFromSystem, 34 | firstMessageFromHuman 35 | ); 36 | 37 | val firstMessageFromAi = messageFromAiWithTokens(10); 38 | chatHistory.add(firstMessageFromAi); 39 | assertThat(chatHistory.history()) 40 | .hasSize(3) 41 | .containsExactly( 42 | messageFromSystem, 43 | firstMessageFromHuman, 44 | firstMessageFromAi 45 | ); 46 | 47 | val secondMessageFromHuman = messageFromHumanWithTokens(10); 48 | chatHistory.add(secondMessageFromHuman); 49 | assertThat(chatHistory.history()) 50 | .hasSize(3) 51 | .containsExactly( 52 | messageFromSystem, 53 | // firstMessageFromHuman was removed 54 | firstMessageFromAi, 55 | secondMessageFromHuman 56 | ); 57 | 58 | val secondMessageFromAi = messageFromAiWithTokens(10); 59 | chatHistory.add(secondMessageFromAi); 60 | assertThat(chatHistory.history()) 61 | .hasSize(3) 62
| .containsExactly( 63 | messageFromSystem, 64 | // firstMessageFromAi was removed 65 | secondMessageFromHuman, 66 | secondMessageFromAi 67 | ); 68 | } 69 | 70 | @Test 71 | void should_keep_specified_number_of_tokens_in_chat_history_2() { 72 | 73 | val messageFromSystem = systemMessageWithTokens(5); 74 | val chatHistory = SimpleChatHistory.builder() 75 | .systemMessage(messageFromSystem) 76 | .capacityInTokens(20) 77 | .build(); 78 | assertThat(chatHistory.history()) 79 | .hasSize(1) 80 | .containsExactly(messageFromSystem); 81 | 82 | val firstMessageFromHuman = messageFromHumanWithTokens(10); 83 | chatHistory.add(firstMessageFromHuman); 84 | assertThat(chatHistory.history()) 85 | .hasSize(2) 86 | .containsExactly( 87 | messageFromSystem, // 5 tokens 88 | firstMessageFromHuman // 10 tokens 89 | ); 90 | 91 | val firstMessageFromAi = messageFromAiWithTokens(10); 92 | chatHistory.add(firstMessageFromAi); 93 | assertThat(chatHistory.history()) 94 | .hasSize(2) 95 | .containsExactly( 96 | messageFromSystem, // 5 tokens 97 | // firstMessageFromHuman was removed 98 | firstMessageFromAi // 10 tokens 99 | ); 100 | 101 | val secondMessageFromHuman = messageFromHumanWithTokens(5); 102 | chatHistory.add(secondMessageFromHuman); 103 | assertThat(chatHistory.history()) 104 | .hasSize(3) 105 | .containsExactly( 106 | messageFromSystem, // 5 tokens 107 | // firstMessageFromHuman was removed 108 | firstMessageFromAi, // 10 tokens 109 | secondMessageFromHuman // 5 tokens 110 | ); 111 | } 112 | 113 | @Test 114 | void should_keep_200_tokens_in_chat_history_by_default() { 115 | 116 | val messageFromSystem = systemMessageWithTokens(10); 117 | val chatHistory = SimpleChatHistory.builder() 118 | .systemMessage(messageFromSystem) 119 | // user did not configure capacityInTokens 120 | .build(); 121 | assertThat(chatHistory.history()) 122 | .hasSize(1) 123 | .containsExactly(messageFromSystem); 124 | 125 | for (int i = 0; i < 30; i++) { 126 | chatHistory.add(messageFromHumanWithTokens(10));
127 | } 128 | 129 | assertThat(chatHistory.history()) 130 | .contains(messageFromSystem, atIndex(0)) 131 | .hasSize(20); // 20 messages 10 tokens each = 200 tokens 132 | } 133 | 134 | @Test 135 | void should_keep_specified_number_of_tokens_in_history_without_message_from_system() { 136 | 137 | val chatHistory = SimpleChatHistory.builder() 138 | // user did not configure systemMessage 139 | .capacityInTokens(20) 140 | .build(); 141 | assertThat(chatHistory.history()) 142 | .hasSize(0); 143 | 144 | val firstMessageFromHuman = messageFromHumanWithTokens(10); 145 | chatHistory.add(firstMessageFromHuman); 146 | assertThat(chatHistory.history()) 147 | .hasSize(1) 148 | .containsExactly(firstMessageFromHuman); 149 | 150 | val firstMessageFromAi = messageFromAiWithTokens(10); 151 | chatHistory.add(firstMessageFromAi); 152 | assertThat(chatHistory.history()) 153 | .hasSize(2) 154 | .containsExactly( 155 | firstMessageFromHuman, 156 | firstMessageFromAi 157 | ); 158 | 159 | val secondMessageFromHuman = messageFromHumanWithTokens(10); 160 | chatHistory.add(secondMessageFromHuman); 161 | assertThat(chatHistory.history()) 162 | .hasSize(2) 163 | .containsExactly( 164 | // firstMessageFromHuman was removed 165 | firstMessageFromAi, 166 | secondMessageFromHuman 167 | ); 168 | 169 | val secondMessageFromAi = messageFromAiWithTokens(10); 170 | chatHistory.add(secondMessageFromAi); 171 | assertThat(chatHistory.history()) 172 | .hasSize(2) 173 | .containsExactly( 174 | // firstMessageFromAi was removed 175 | secondMessageFromHuman, 176 | secondMessageFromAi 177 | ); 178 | } 179 | 180 | @Test 181 | void should_keep_specified_number_of_messages_in_chat_history() { 182 | 183 | val messageFromSystem = systemMessage("does not matter how many tokens"); 184 | val chatHistory = SimpleChatHistory.builder() 185 | .systemMessage(messageFromSystem) 186 | .capacityInMessages(3) 187 | .build(); 188 | assertThat(chatHistory.history()) 189 | .hasSize(1) 190 | .containsExactly(messageFromSystem);
191 | 192 | val firstMessageFromHuman = userMessage("does not matter how many tokens"); 193 | chatHistory.add(firstMessageFromHuman); 194 | assertThat(chatHistory.history()) 195 | .hasSize(2) 196 | .containsExactly( 197 | messageFromSystem, 198 | firstMessageFromHuman 199 | ); 200 | 201 | val firstMessageFromAi = aiMessage("does not matter how many tokens"); 202 | chatHistory.add(firstMessageFromAi); 203 | assertThat(chatHistory.history()) 204 | .hasSize(3) 205 | .containsExactly( 206 | messageFromSystem, 207 | firstMessageFromHuman, 208 | firstMessageFromAi 209 | ); 210 | 211 | val secondMessageFromHuman = userMessage("does not matter how many tokens"); 212 | chatHistory.add(secondMessageFromHuman); 213 | assertThat(chatHistory.history()) 214 | .hasSize(3) 215 | .containsExactly( 216 | messageFromSystem, 217 | // firstMessageFromHuman was removed 218 | firstMessageFromAi, 219 | secondMessageFromHuman 220 | ); 221 | 222 | val secondMessageFromAi = aiMessage("does not matter how many tokens"); 223 | chatHistory.add(secondMessageFromAi); 224 | assertThat(chatHistory.history()) 225 | .hasSize(3) 226 | .containsExactly( 227 | messageFromSystem, 228 | // firstMessageFromAi was removed 229 | secondMessageFromHuman, 230 | secondMessageFromAi 231 | ); 232 | } 233 | 234 | @Test 235 | void should_keep_specified_number_of_messages_in_chat_history_without_message_from_system() { 236 | 237 | val chatHistory = SimpleChatHistory.builder() 238 | .capacityInMessages(2) 239 | .build(); 240 | assertThat(chatHistory.history()) 241 | .hasSize(0); 242 | 243 | val firstMessageFromHuman = userMessage("does not matter how many tokens"); 244 | chatHistory.add(firstMessageFromHuman); 245 | assertThat(chatHistory.history()) 246 | .hasSize(1) 247 | .containsExactly(firstMessageFromHuman); 248 | 249 | val firstMessageFromAi = aiMessage("does not matter how many tokens"); 250 | chatHistory.add(firstMessageFromAi); 251 | assertThat(chatHistory.history()) 252 | .hasSize(2) 253 | .containsExactly(
254 | firstMessageFromHuman, 255 | firstMessageFromAi 256 | ); 257 | 258 | val secondMessageFromHuman = userMessage("does not matter how many tokens"); 259 | chatHistory.add(secondMessageFromHuman); 260 | assertThat(chatHistory.history()) 261 | .hasSize(2) 262 | .containsExactly( 263 | // firstMessageFromHuman was removed 264 | firstMessageFromAi, 265 | secondMessageFromHuman 266 | ); 267 | 268 | val secondMessageFromAi = aiMessage("does not matter how many tokens"); 269 | chatHistory.add(secondMessageFromAi); 270 | assertThat(chatHistory.history()) 271 | .hasSize(2) 272 | .containsExactly( 273 | // firstMessageFromAi was removed 274 | secondMessageFromHuman, 275 | secondMessageFromAi 276 | ); 277 | } 278 | 279 | @Test 280 | void should_keep_specified_number_of_tokens_and_number_of_messages_in_chat_history_1() { 281 | 282 | // In this test we will be using messages with 10 tokens each. 283 | // We will configure capacityInTokens(20) and capacityInMessages(3): 284 | // With capacityInMessages(3) we will be able to fit 3 messages into history. 285 | // But due to capacityInTokens(20) it will keep only 2.
286 | 287 | val messageFromSystem = systemMessageWithTokens(10); 288 | val chatHistory = SimpleChatHistory.builder() 289 | .systemMessage(messageFromSystem) 290 | .capacityInTokens(20) 291 | .capacityInMessages(3) 292 | .build(); 293 | assertThat(chatHistory.history()) 294 | .hasSize(1) 295 | .containsExactly(messageFromSystem); 296 | 297 | val firstMessageFromHuman = messageFromHumanWithTokens(10); 298 | chatHistory.add(firstMessageFromHuman); 299 | assertThat(chatHistory.history()) 300 | .hasSize(2) 301 | .containsExactly( 302 | messageFromSystem, 303 | firstMessageFromHuman 304 | ); 305 | 306 | val firstMessageFromAi = messageFromAiWithTokens(10); 307 | chatHistory.add(firstMessageFromAi); 308 | assertThat(chatHistory.history()) 309 | .hasSize(2) 310 | .containsExactly( 311 | messageFromSystem, 312 | // firstMessageFromHuman was removed 313 | firstMessageFromAi 314 | ); 315 | 316 | val secondMessageFromHuman = messageFromHumanWithTokens(10); 317 | chatHistory.add(secondMessageFromHuman); 318 | assertThat(chatHistory.history()) 319 | .hasSize(2) 320 | .containsExactly( 321 | messageFromSystem, 322 | // firstMessageFromAi was removed 323 | secondMessageFromHuman 324 | ); 325 | } 326 | 327 | @Test 328 | void should_keep_specified_number_of_tokens_and_number_of_messages_in_chat_history_2() { 329 | 330 | // In this test we will be using messages with 10 tokens each. 331 | // We will configure capacityInMessages(2) and capacityInTokens(30): 332 | // With capacityInTokens(30) we will be able to fit 3 messages into history. 333 | // But due to capacityInMessages(2) it will keep only 2.
334 | 335 | val messageFromSystem = systemMessageWithTokens(10); 336 | val chatHistory = SimpleChatHistory.builder() 337 | .systemMessage(messageFromSystem) 338 | .capacityInTokens(30) 339 | .capacityInMessages(2) 340 | .build(); 341 | assertThat(chatHistory.history()) 342 | .hasSize(1) 343 | .containsExactly(messageFromSystem); 344 | 345 | val firstMessageFromHuman = messageFromHumanWithTokens(10); 346 | chatHistory.add(firstMessageFromHuman); 347 | assertThat(chatHistory.history()) 348 | .hasSize(2) 349 | .containsExactly( 350 | messageFromSystem, 351 | firstMessageFromHuman 352 | ); 353 | 354 | val firstMessageFromAi = messageFromAiWithTokens(10); 355 | chatHistory.add(firstMessageFromAi); 356 | assertThat(chatHistory.history()) 357 | .hasSize(2) 358 | .containsExactly( 359 | messageFromSystem, 360 | // firstMessageFromHuman was removed 361 | firstMessageFromAi 362 | ); 363 | 364 | val secondMessageFromHuman = messageFromHumanWithTokens(10); 365 | chatHistory.add(secondMessageFromHuman); 366 | assertThat(chatHistory.history()) 367 | .hasSize(2) 368 | .containsExactly( 369 | messageFromSystem, 370 | // firstMessageFromAi was removed 371 | secondMessageFromHuman 372 | ); 373 | } 374 | 375 | @Test 376 | void should_load_previous_messages_with_token_restriction() { 377 | 378 | val previousMessages = asList( 379 | messageFromHumanWithTokens(10), 380 | messageFromAiWithTokens(10), 381 | messageFromHumanWithTokens(10), 382 | messageFromAiWithTokens(10) 383 | ); 384 | 385 | val chatHistory = SimpleChatHistory.builder() 386 | .previousMessages(previousMessages) 387 | .capacityInTokens(30) 388 | .build(); 389 | 390 | assertThat(chatHistory.history()) 391 | .hasSize(3); 392 | } 393 | 394 | @Test 395 | void should_load_previous_messages_with_message_restriction() { 396 | 397 | val previousMessages = asList( 398 | messageFromHumanWithTokens(10), 399 | messageFromAiWithTokens(10), 400 | messageFromHumanWithTokens(10), 401 | messageFromAiWithTokens(10) 402 | ); 403 | 404 | val
chatHistory = SimpleChatHistory.builder() 405 | .previousMessages(previousMessages) 406 | .capacityInMessages(3) 407 | .build(); 408 | 409 | assertThat(chatHistory.history()) 410 | .hasSize(3); 411 | } 412 | 413 | @Test 414 | void should_keep_all_history_without_restrictions() { 415 | 416 | val chatHistory = SimpleChatHistory.builder() 417 | .removeCapacityRestrictionInTokens() 418 | .build(); 419 | 420 | for (int i = 0; i < 1000; i++) { 421 | chatHistory.add(messageFromHumanWithTokens(1000)); 422 | } 423 | 424 | assertThat(chatHistory.history()) 425 | .hasSize(1000); 426 | } 427 | } -------------------------------------------------------------------------------- /src/test/java/dev/ai4j/utils/TestUtils.java: -------------------------------------------------------------------------------- 1 | package dev.ai4j.utils; 2 | 3 | import dev.ai4j.Tokenizer; 4 | import dev.ai4j.chat.AiMessage; 5 | import dev.ai4j.chat.SystemMessage; 6 | import dev.ai4j.chat.UserMessage; 7 | import lombok.val; 8 | import org.junit.jupiter.api.Test; 9 | import org.junit.jupiter.params.ParameterizedTest; 10 | import org.junit.jupiter.params.provider.ValueSource; 11 | 12 | import java.util.ArrayList; 13 | import java.util.List; 14 | 15 | import static dev.ai4j.chat.AiMessage.aiMessage; 16 | import static dev.ai4j.chat.SystemMessage.systemMessage; 17 | import static dev.ai4j.chat.UserMessage.userMessage; 18 | import static org.assertj.core.api.Assertions.assertThat; 19 | 20 | public class TestUtils { 21 | 22 | private static final int EXTRA_TOKENS_PER_CHAT_MESSAGE = 5; 23 | 24 | @ParameterizedTest 25 | @ValueSource(ints = {5, 10, 25, 50, 100, 250, 500, 1000}) 26 | void should_create_message_from_system_with_tokens(int numberOfTokens) { 27 | val messageFromSystem = systemMessageWithTokens(numberOfTokens); 28 | assertThat(messageFromSystem.numberOfTokens()).isEqualTo(numberOfTokens); 29 | } 30 | 31 | public static SystemMessage systemMessageWithTokens(int numberOfTokens) { 32 | return
systemMessage(generateTokens(numberOfTokens - EXTRA_TOKENS_PER_CHAT_MESSAGE)); 33 | } 34 | 35 | @ParameterizedTest 36 | @ValueSource(ints = {5, 10, 25, 50, 100, 250, 500, 1000}) 37 | void should_create_message_from_human_with_tokens(int numberOfTokens) { 38 | val messageFromHuman = messageFromHumanWithTokens(numberOfTokens); 39 | assertThat(messageFromHuman.numberOfTokens()).isEqualTo(numberOfTokens); 40 | } 41 | 42 | public static UserMessage messageFromHumanWithTokens(int numberOfTokens) { 43 | return userMessage(generateTokens(numberOfTokens - EXTRA_TOKENS_PER_CHAT_MESSAGE)); 44 | } 45 | 46 | @ParameterizedTest 47 | @ValueSource(ints = {5, 10, 25, 50, 100, 250, 500, 1000}) 48 | void should_create_message_from_ai_with_tokens(int numberOfTokens) { 49 | AiMessage messageFromAi = messageFromAiWithTokens(numberOfTokens); 50 | assertThat(messageFromAi.numberOfTokens()).isEqualTo(numberOfTokens); 51 | } 52 | 53 | public static AiMessage messageFromAiWithTokens(int numberOfTokens) { 54 | return aiMessage(generateTokens(numberOfTokens - EXTRA_TOKENS_PER_CHAT_MESSAGE)); 55 | } 56 | 57 | @ParameterizedTest 58 | @ValueSource(ints = {1, 2, 5, 10, 25, 50, 100, 250, 500, 1000}) 59 | void should_generate_tokens(int numberOfTokens) { 60 | val tokenizer = new Tokenizer(); 61 | 62 | val tokens = generateTokens(numberOfTokens); 63 | 64 | assertThat(tokenizer.countTokens(tokens)).isEqualTo(numberOfTokens); 65 | } 66 | 67 | public static String generateTokens(int n) { 68 | val tokenizer = new Tokenizer(); 69 | val text = String.join(" ", repeat("one two", n)); 70 | return tokenizer.decode(tokenizer.encode(text, n)); 71 | } 72 | 73 | @Test 74 | void should_repeat_n_times() { 75 | assertThat(repeat("word", 1)) 76 | .hasSize(1) 77 | .containsExactly("word"); 78 | 79 | assertThat(repeat("word", 2)) 80 | .hasSize(2) 81 | .containsExactly("word", "word"); 82 | 83 | assertThat(repeat("word", 3)) 84 | .hasSize(3) 85 | .containsExactly("word", "word", "word"); 86 | } 87 | 88 | public static
List repeat(String s, int n) { 89 | val result = new ArrayList(); 90 | for (int i = 0; i < n; i++) { 91 | result.add(s); 92 | } 93 | return result; 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/test/java/test-file.txt: -------------------------------------------------------------------------------- 1 | test 2 | content 3 | --------------------------------------------------------------------------------