├── .gitattributes
├── .gitignore
├── .mvn
└── wrapper
│ ├── maven-wrapper.jar
│ └── maven-wrapper.properties
├── LICENSE
├── README.md
├── libs
└── com
│ └── github
│ └── serpapi
│ └── google-search-results-java
│ └── 2.0.3
│ ├── google-search-results-java-2.0.3-sources.jar
│ └── google-search-results-java-2.0.3.jar
├── mvnw
├── mvnw.cmd
├── pom.xml
└── src
├── main
└── java
│ └── dev
│ └── ai4j
│ ├── agent
│ └── tool
│ │ └── webpage
│ │ └── WebPageScrapperTool.java
│ ├── document
│ ├── loader
│ │ ├── PdfFileLoader.java
│ │ └── TextFileLoader.java
│ └── splitter
│ │ └── OverlappingDocumentSplitter.java
│ ├── flows
│ ├── ChatFlow.java
│ └── DocumentQnAFlow.java
│ ├── model
│ ├── chat
│ │ ├── OpenAiChatModel.java
│ │ └── SimpleChatHistory.java
│ ├── completion
│ │ ├── OpenAiCompletionModel.java
│ │ └── structured
│ │ │ ├── Description.java
│ │ │ └── Example.java
│ ├── embedding
│ │ └── OpenAiEmbeddingModel.java
│ └── openai
│ │ └── OpenAiModelName.java
│ └── utils
│ ├── Json.java
│ ├── StopWatch.java
│ └── Utils.java
└── test
└── java
├── TestIt.java
├── dev
└── ai4j
│ ├── document
│ ├── loader
│ │ ├── TextFileLoaderTest.java
│ │ ├── test-file-iso-8859-1.txt
│ │ └── test-file-utf8.txt
│ └── splitter
│ │ └── OverlappingDocumentSplitterTest.java
│ ├── model
│ └── chat
│ │ └── SimpleChatHistoryTest.java
│ └── utils
│ └── TestUtils.java
└── test-file.txt
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Handle line endings automatically for files detected as text
2 | # and leave all files detected as binary untouched.
3 | * text=auto
4 |
5 | # Force every file whose name contains a dot to use Unix (LF) line endings, so Windows checkouts do not break them
6 | *.* text eol=lf
7 |
8 | # Exceptions and additions to the rules above (when several lines match a path, later lines override earlier ones):
9 | # the extension-less mvnw shell script needs LF endings to run under sh,
10 | /mvnw text eol=lf
11 | # Windows batch scripts such as mvnw.cmd need CRLF endings to run reliably under cmd.exe,
12 | *.cmd text eol=crlf
13 | # and jars are binary archives that must never go through end-of-line conversion.
14 | *.jar binary
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | !.mvn/wrapper/maven-wrapper.jar
3 | !**/src/main/**/target/
4 | !**/src/test/**/target/
5 |
6 | ### IntelliJ IDEA ###
7 | .idea/*
8 | .idea/modules.xml
9 | .idea/jarRepositories.xml
10 | .idea/compiler.xml
11 | .idea/libraries/
12 | *.iws
13 | *.iml
14 | *.ipr
15 |
16 | **/ApiKeys.java
17 |
18 | ### Eclipse ###
19 | .apt_generated
20 | .classpath
21 | .factorypath
22 | .project
23 | .settings
24 | .springBeans
25 | .sts4-cache
26 |
27 | ### NetBeans ###
28 | /nbproject/private/
29 | /nbbuild/
30 | /dist/
31 | /nbdist/
32 | /.nb-gradle/
33 | build/
34 | !**/src/main/**/build/
35 | !**/src/test/**/build/
36 |
37 | ### VS Code ###
38 | .vscode/
39 |
40 | ### Mac OS ###
41 | .DS_Store
--------------------------------------------------------------------------------
/.mvn/wrapper/maven-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-for-java/ai4j/3382afa163fea1cf6c642c1b471b9f56fdba98c9/.mvn/wrapper/maven-wrapper.jar
--------------------------------------------------------------------------------
/.mvn/wrapper/maven-wrapper.properties:
--------------------------------------------------------------------------------
1 | # Licensed to the Apache Software Foundation (ASF) under one
2 | # or more contributor license agreements. See the NOTICE file
3 | # distributed with this work for additional information
4 | # regarding copyright ownership. The ASF licenses this file
5 | # to you under the Apache License, Version 2.0 (the
6 | # "License"); you may not use this file except in compliance
7 | # with the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing,
12 | # software distributed under the License is distributed on an
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 | # KIND, either express or implied. See the License for the
15 | # specific language governing permissions and limitations
16 | # under the License.
17 | distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.9.1/apache-maven-3.9.1-bin.zip
18 | wrapperUrl=https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Migrated to [LangChain4j](https://github.com/langchain4j/langchain4j)
2 |
--------------------------------------------------------------------------------
/libs/com/github/serpapi/google-search-results-java/2.0.3/google-search-results-java-2.0.3-sources.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-for-java/ai4j/3382afa163fea1cf6c642c1b471b9f56fdba98c9/libs/com/github/serpapi/google-search-results-java/2.0.3/google-search-results-java-2.0.3-sources.jar
--------------------------------------------------------------------------------
/libs/com/github/serpapi/google-search-results-java/2.0.3/google-search-results-java-2.0.3.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-for-java/ai4j/3382afa163fea1cf6c642c1b471b9f56fdba98c9/libs/com/github/serpapi/google-search-results-java/2.0.3/google-search-results-java-2.0.3.jar
--------------------------------------------------------------------------------
/mvnw:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # ----------------------------------------------------------------------------
3 | # Licensed to the Apache Software Foundation (ASF) under one
4 | # or more contributor license agreements. See the NOTICE file
5 | # distributed with this work for additional information
6 | # regarding copyright ownership. The ASF licenses this file
7 | # to you under the Apache License, Version 2.0 (the
8 | # "License"); you may not use this file except in compliance
9 | # with the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing,
14 | # software distributed under the License is distributed on an
15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 | # KIND, either express or implied. See the License for the
17 | # specific language governing permissions and limitations
18 | # under the License.
19 | # ----------------------------------------------------------------------------
20 |
21 | # ----------------------------------------------------------------------------
22 | # Apache Maven Wrapper startup batch script, version 3.2.0
23 | #
24 | # Required ENV vars:
25 | # ------------------
26 | # JAVA_HOME - location of a JDK home dir
27 | #
28 | # Optional ENV vars
29 | # -----------------
30 | # MAVEN_OPTS - parameters passed to the Java VM when running Maven
31 | # e.g. to debug Maven itself, use
32 | # set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000
33 | # MAVEN_SKIP_RC - flag to disable loading of mavenrc files
34 | # ----------------------------------------------------------------------------
35 |
36 | if [ -z "$MAVEN_SKIP_RC" ] ; then
37 | # Source optional mavenrc files into this shell; the whole step is skipped when MAVEN_SKIP_RC is non-empty.
38 | if [ -f /usr/local/etc/mavenrc ] ; then
39 | . /usr/local/etc/mavenrc
40 | fi
41 | # Files are sourced in this order, so assignments made by a later file override earlier ones.
42 | if [ -f /etc/mavenrc ] ; then
43 | . /etc/mavenrc
44 | fi
45 | # The per-user file in $HOME is sourced last.
46 | if [ -f "$HOME/.mavenrc" ] ; then
47 | . "$HOME/.mavenrc"
48 | fi
49 |
50 | fi
51 |
52 | # OS specific support. $var _must_ be set to either true or false (they are later run as commands, e.g. "if $cygwin").
53 | cygwin=false;
54 | darwin=false;
55 | mingw=false
56 | case "$(uname)" in
57 | CYGWIN*) cygwin=true ;;
58 | MINGW*) mingw=true;;
59 | Darwin*) darwin=true
60 | # Use /usr/libexec/java_home if available, otherwise fall back to /Library/Java/Home
61 | # See https://developer.apple.com/library/mac/qa/qa1170/_index.html
62 | if [ -z "$JAVA_HOME" ]; then
63 | if [ -x "/usr/libexec/java_home" ]; then
64 | JAVA_HOME="$(/usr/libexec/java_home)"; export JAVA_HOME
65 | else
66 | JAVA_HOME="/Library/Java/Home"; export JAVA_HOME
67 | fi
68 | fi
69 | ;;
70 | esac
71 | # Gentoo: if JAVA_HOME is still unset, take it from "java-config --jre-home".
72 | if [ -z "$JAVA_HOME" ] ; then
73 | if [ -r /etc/gentoo-release ] ; then
74 | JAVA_HOME=$(java-config --jre-home)
75 | fi
76 | fi
77 |
78 | # For Cygwin, ensure paths are in UNIX format before anything is touched
79 | if $cygwin ; then
80 | [ -n "$JAVA_HOME" ] &&
81 | JAVA_HOME=$(cygpath --unix "$JAVA_HOME")
82 | [ -n "$CLASSPATH" ] &&
83 | CLASSPATH=$(cygpath --path --unix "$CLASSPATH")
84 | fi
85 |
86 | # For Mingw, ensure paths are in UNIX format before anything is touched
87 | if $mingw && [ -n "$JAVA_HOME" ] && [ -d "$JAVA_HOME" ] ; then
88 | # "exit" inside $(...) only leaves the substitution, so the trailing "|| exit 1" is what aborts the script; the message goes to stderr so it is never captured into JAVA_HOME.
89 | JAVA_HOME="$(cd "$JAVA_HOME" && pwd || { echo "cannot cd into $JAVA_HOME." >&2; exit 1; })" || exit 1
90 | fi
91 |
92 | if [ -z "$JAVA_HOME" ]; then # try to derive JAVA_HOME from the javac found on PATH
93 | javaExecutable="$(which javac)"
94 | if [ -n "$javaExecutable" ] && ! [ "$(expr "$javaExecutable" : '\([^ ]*\)')" = "no" ]; then # presumably guards against "no javac in ..." output from some which(1) implementations
95 | # readlink(1) is not available as standard on Solaris 10.
96 | readLink=$(which readlink)
97 | if [ ! "$(expr "$readLink" : '\([^ ]*\)')" = "no" ]; then
98 | if $darwin ; then
99 | javaHome="$(dirname "$javaExecutable")"
100 | javaExecutable="$(cd "$javaHome" && pwd -P)/javac"
101 | else
102 | javaExecutable="$(readlink -f "$javaExecutable")"
103 | fi
104 | javaHome="$(dirname "$javaExecutable")" # resolved .../bin/javac -> .../bin
105 | javaHome=$(expr "$javaHome" : '\(.*\)/bin') # keep everything before the last "/bin"
106 | JAVA_HOME="$javaHome"
107 | export JAVA_HOME
108 | fi
109 | fi
110 | fi
111 |
112 | if [ -z "$JAVACMD" ] ; then # resolve the java launcher unless the caller already supplied JAVACMD
113 | if [ -n "$JAVA_HOME" ] ; then
114 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
115 | # IBM's JDK on AIX uses strange locations for the executables
116 | JAVACMD="$JAVA_HOME/jre/sh/java"
117 | else
118 | JAVACMD="$JAVA_HOME/bin/java"
119 | fi
120 | else # no JAVA_HOME: use java from PATH; the backslashes and "unset -f" bypass any alias or function named command
121 | JAVACMD="$(\unset -f command 2>/dev/null; \command -v java)"
122 | fi
123 | fi
124 | # Whatever was chosen must be executable, otherwise abort.
125 | if [ ! -x "$JAVACMD" ] ; then
126 | echo "Error: JAVA_HOME is not defined correctly." >&2
127 | echo " We cannot execute $JAVACMD" >&2
128 | exit 1
129 | fi
130 | # Diagnostics go to stderr so they never mix with Maven output captured from stdout.
131 | if [ -z "$JAVA_HOME" ] ; then
132 | echo "Warning: JAVA_HOME environment variable is not set." >&2
133 | fi
134 |
135 | # traverses the directory structure upward from the given start directory ($1) towards the filesystem root;
136 | # the first directory containing a .mvn subdirectory is considered the project base directory
137 | find_maven_basedir() {
138 | if [ -z "$1" ]
139 | then # report on stderr: callers capture stdout as the resulting directory
140 | echo "Path not specified to find_maven_basedir" >&2
141 | return 1
142 | fi
143 | # Walk from $1 toward / (stopping before /) and print, as an absolute path, the first directory containing .mvn; print $1 (absolute) if none is found.
144 | basedir="$1"
145 | wdir="$1"
146 | while [ "$wdir" != '/' ] ; do
147 | if [ -d "$wdir"/.mvn ] ; then
148 | basedir=$wdir
149 | break
150 | fi
151 | # workaround for JBEAP-8937 (on Solaris 10/Sparc)
152 | # stop climbing when wdir is not a directory or its parent cannot be entered; otherwise wdir never changes and the loop never ends
153 | [ -d "${wdir}" ] || break
154 | wdir=$(cd "$wdir/.." && pwd) || break
155 | # end of workaround
156 | done
157 | printf '%s' "$(cd "$basedir" || exit 1; pwd)"
158 | }
159 |
160 | # concatenates all lines of a file
161 | concat_lines() {
162 | if [ -f "$1" ]; then
163 | # Remove \r in case we run on Windows within Git Bash
164 | # and check out the repository with auto CRLF management
165 | # enabled. Otherwise, we may read lines that are delimited with
166 | # \r\n and produce $'-Xarg\r' rather than -Xarg due to word
167 | # splitting rules.
168 | tr -s '\r\n' ' ' < "$1"
169 | fi
170 | }
171 |
172 | log() {
173 | if [ "$MVNW_VERBOSE" = true ]; then
174 | printf '%s\n' "$1"
175 | fi
176 | }
177 |
178 | # Locate the project root starting from the directory that holds this script; give up if none is found.
179 | BASE_DIR=$(find_maven_basedir "$(dirname "$0")")
180 | [ -n "$BASE_DIR" ] || exit 1
181 | # An explicit MAVEN_BASEDIR takes precedence over the detected directory.
182 | MAVEN_PROJECTBASEDIR=${MAVEN_BASEDIR:-"$BASE_DIR"}
183 | export MAVEN_PROJECTBASEDIR
184 | log "$MAVEN_PROJECTBASEDIR"
185 |
186 | ##########################################################################################
187 | # Extension to allow automatically downloading the maven-wrapper.jar from Maven-central
188 | # This allows using the maven wrapper in projects that prohibit checking in binary data.
189 | ##########################################################################################
190 | wrapperJarPath="$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.jar"
191 | if [ -r "$wrapperJarPath" ]; then
192 | log "Found $wrapperJarPath"
193 | else
194 | log "Couldn't find $wrapperJarPath, downloading it ..."
195 | # Default URL (MVNW_REPOURL mirror if set, else Maven Central); a wrapperUrl entry in maven-wrapper.properties overrides it.
196 | if [ -n "$MVNW_REPOURL" ]; then
197 | wrapperUrl="$MVNW_REPOURL/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar"
198 | else
199 | wrapperUrl="https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar"
200 | fi
201 | while IFS="=" read -r key value || [ -n "$key" ]; do
202 | # Strip '\r' so CRLF properties files work: read stops at '\n' and IFS is only '=', so a trailing '\r' would stay in the value. The "|| [ -n "$key" ]" keeps a last line that lacks a trailing newline.
203 | safeValue=$(echo "$value" | tr -d '\r')
204 | case "$key" in (wrapperUrl) wrapperUrl="$safeValue"; break ;;
205 | esac
206 | done < "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.properties"
207 | log "Downloading from: $wrapperUrl"
208 | # On Cygwin, hand the downloaders below a Windows-style target path.
209 | if $cygwin; then
210 | wrapperJarPath=$(cygpath --path --windows "$wrapperJarPath")
211 | fi
212 | # Prefer wget, then curl; otherwise compile and run .mvn/wrapper/MavenWrapperDownloader.java if present. A failed download removes the target file.
213 | if command -v wget > /dev/null; then
214 | log "Found wget ... using wget"
215 | [ "$MVNW_VERBOSE" = true ] && QUIET="" || QUIET="--quiet"
216 | if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then
217 | wget $QUIET "$wrapperUrl" -O "$wrapperJarPath" || rm -f "$wrapperJarPath"
218 | else
219 | wget $QUIET --http-user="$MVNW_USERNAME" --http-password="$MVNW_PASSWORD" "$wrapperUrl" -O "$wrapperJarPath" || rm -f "$wrapperJarPath"
220 | fi
221 | elif command -v curl > /dev/null; then
222 | log "Found curl ... using curl"
223 | [ "$MVNW_VERBOSE" = true ] && QUIET="" || QUIET="--silent"
224 | if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then
225 | curl $QUIET -o "$wrapperJarPath" "$wrapperUrl" -f -L || rm -f "$wrapperJarPath"
226 | else
227 | curl $QUIET --user "$MVNW_USERNAME:$MVNW_PASSWORD" -o "$wrapperJarPath" "$wrapperUrl" -f -L || rm -f "$wrapperJarPath"
228 | fi
229 | else
230 | log "Falling back to using Java to download"
231 | javaSource="$MAVEN_PROJECTBASEDIR/.mvn/wrapper/MavenWrapperDownloader.java"
232 | javaClass="$MAVEN_PROJECTBASEDIR/.mvn/wrapper/MavenWrapperDownloader.class"
233 | # For Cygwin, switch paths to Windows format before running javac
234 | if $cygwin; then
235 | javaSource=$(cygpath --path --windows "$javaSource")
236 | javaClass=$(cygpath --path --windows "$javaClass")
237 | fi
238 | if [ -e "$javaSource" ]; then
239 | if [ ! -e "$javaClass" ]; then
240 | log " - Compiling MavenWrapperDownloader.java ..."
241 | ("$JAVA_HOME/bin/javac" "$javaSource")
242 | fi
243 | if [ -e "$javaClass" ]; then
244 | log " - Running MavenWrapperDownloader.java ..." # run from the project root so the relative classpath and output path resolve even when mvnw is invoked from another directory
245 | (cd "$MAVEN_PROJECTBASEDIR" && "$JAVACMD" -cp .mvn/wrapper MavenWrapperDownloader "$wrapperUrl" .mvn/wrapper/maven-wrapper.jar) || rm -f "$wrapperJarPath"
246 | fi
247 | fi
248 | fi
249 | fi
250 | ##########################################################################################
251 | # End of extension
252 | ##########################################################################################
253 |
254 | # If specified, validate the SHA-256 sum of the Maven wrapper jar file
255 | wrapperSha256Sum=""
256 | while IFS="=" read -r key value || [ -n "$key" ]; do # "|| [ -n ... ]" keeps a last line without a trailing newline
257 | case "$key" in (wrapperSha256Sum) wrapperSha256Sum=$(printf '%s' "$value" | tr -d '\r'); break ;;
258 | esac
259 | done < "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.properties"
260 | if [ -n "$wrapperSha256Sum" ]; then
261 | wrapperSha256Result=false
262 | if command -v sha256sum > /dev/null; then # check lines use the canonical "<hash><two spaces><path>" format
263 | if echo "$wrapperSha256Sum  $wrapperJarPath" | sha256sum -c > /dev/null 2>&1; then
264 | wrapperSha256Result=true
265 | fi
266 | elif command -v shasum > /dev/null; then
267 | if echo "$wrapperSha256Sum  $wrapperJarPath" | shasum -a 256 -c > /dev/null 2>&1; then
268 | wrapperSha256Result=true
269 | fi
270 | else
271 | echo "Checksum validation was requested but neither 'sha256sum' or 'shasum' are available." >&2
272 | echo "Please install either command, or disable validation by removing 'wrapperSha256Sum' from your maven-wrapper.properties." >&2
273 | exit 1
274 | fi
275 | if [ "$wrapperSha256Result" = false ]; then
276 | echo "Error: Failed to validate Maven wrapper SHA-256, your Maven wrapper might be compromised." >&2
277 | echo "Investigate or delete $wrapperJarPath to attempt a clean download." >&2
278 | echo "If you updated your Maven version, you need to update the specified wrapperSha256Sum property." >&2
279 | exit 1
280 | fi
281 | fi
282 |
283 | MAVEN_OPTS="$(concat_lines "$MAVEN_PROJECTBASEDIR/.mvn/jvm.config") $MAVEN_OPTS" # .mvn/jvm.config (joined onto one line) comes first, followed by the caller's MAVEN_OPTS
284 |
285 | # For Cygwin, switch paths to Windows format before running java
286 | if $cygwin; then
287 | [ -n "$JAVA_HOME" ] &&
288 | JAVA_HOME=$(cygpath --path --windows "$JAVA_HOME")
289 | [ -n "$CLASSPATH" ] &&
290 | CLASSPATH=$(cygpath --path --windows "$CLASSPATH")
291 | [ -n "$MAVEN_PROJECTBASEDIR" ] &&
292 | MAVEN_PROJECTBASEDIR=$(cygpath --path --windows "$MAVEN_PROJECTBASEDIR")
293 | fi
294 |
295 | # Provide a "standardized" way to retrieve the CLI args that will
296 | # work with both Windows and non-Windows executions.
297 | MAVEN_CMD_LINE_ARGS="$MAVEN_CONFIG $*"
298 | export MAVEN_CMD_LINE_ARGS
299 | # Note: "$*" joins the arguments with the first character of IFS (a space by default), so argument boundaries are not preserved in MAVEN_CMD_LINE_ARGS.
300 | WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain
301 | # exec replaces this shell with the JVM; MAVEN_OPTS, MAVEN_DEBUG_OPTS and MAVEN_CONFIG are deliberately unquoted so each option becomes a separate argument.
302 | # shellcheck disable=SC2086 # safe args
303 | exec "$JAVACMD" \
304 | $MAVEN_OPTS \
305 | $MAVEN_DEBUG_OPTS \
306 | -classpath "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.jar" \
307 | "-Dmaven.multiModuleProjectDirectory=${MAVEN_PROJECTBASEDIR}" \
308 | ${WRAPPER_LAUNCHER} $MAVEN_CONFIG "$@"
309 |
--------------------------------------------------------------------------------
/mvnw.cmd:
--------------------------------------------------------------------------------
1 | @REM ----------------------------------------------------------------------------
2 | @REM Licensed to the Apache Software Foundation (ASF) under one
3 | @REM or more contributor license agreements. See the NOTICE file
4 | @REM distributed with this work for additional information
5 | @REM regarding copyright ownership. The ASF licenses this file
6 | @REM to you under the Apache License, Version 2.0 (the
7 | @REM "License"); you may not use this file except in compliance
8 | @REM with the License. You may obtain a copy of the License at
9 | @REM
10 | @REM http://www.apache.org/licenses/LICENSE-2.0
11 | @REM
12 | @REM Unless required by applicable law or agreed to in writing,
13 | @REM software distributed under the License is distributed on an
14 | @REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | @REM KIND, either express or implied. See the License for the
16 | @REM specific language governing permissions and limitations
17 | @REM under the License.
18 | @REM ----------------------------------------------------------------------------
19 |
@REM ----------------------------------------------------------------------------
@REM Apache Maven Wrapper startup batch script, version 3.2.0
@REM
@REM Required ENV vars:
@REM JAVA_HOME - location of a JDK home dir
@REM
@REM Optional ENV vars
@REM MAVEN_BATCH_ECHO - set to 'on' to enable the echoing of the batch commands
@REM MAVEN_BATCH_PAUSE - set to 'on' to wait for a keystroke before ending
@REM MAVEN_OPTS - parameters passed to the Java VM when running Maven
@REM e.g. to debug Maven itself, use
@REM set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000
@REM MAVEN_SKIP_RC - flag to disable loading of mavenrc files
@REM
@REM Additional ENV vars read further down in this script:
@REM MAVEN_BASEDIR - if set, used as the project base dir instead of searching upwards for ".mvn"
@REM MAVEN_DEBUG_OPTS - extra JVM parameters, placed right after MAVEN_OPTS on the java command line
@REM MAVEN_CONFIG - extra arguments handed to the wrapper launcher before this script's own arguments
@REM MVNW_REPOURL - base repository URL used to download maven-wrapper.jar when it is missing
@REM MVNW_USERNAME / MVNW_PASSWORD - credentials used for that download (applied if either is non-empty)
@REM MVNW_VERBOSE - set to 'true' to print messages about finding/downloading maven-wrapper.jar
@REM MAVEN_TERMINATE_CMD - set to 'on' to finish with a plain 'exit' (ends the cmd session) instead of 'exit /B'
@REM ----------------------------------------------------------------------------

@REM Begin all REM lines with '@' in case MAVEN_BATCH_ECHO is 'on'
@echo off
@REM set title of command window
title %0
@REM enable echoing by setting MAVEN_BATCH_ECHO to 'on'
@if "%MAVEN_BATCH_ECHO%" == "on" echo %MAVEN_BATCH_ECHO%

@REM set %HOME% to equivalent of $HOME
if "%HOME%" == "" (set "HOME=%HOMEDRIVE%%HOMEPATH%")

@REM Execute a user defined script before this one
if not "%MAVEN_SKIP_RC%" == "" goto skipRcPre
@REM check for pre script, once with legacy .bat ending and once with .cmd ending
if exist "%USERPROFILE%\mavenrc_pre.bat" call "%USERPROFILE%\mavenrc_pre.bat" %*
if exist "%USERPROFILE%\mavenrc_pre.cmd" call "%USERPROFILE%\mavenrc_pre.cmd" %*
:skipRcPre

@setlocal

set ERROR_CODE=0

@REM To isolate internal variables from possible post scripts, we use another setlocal
@setlocal

@REM ==== START VALIDATION ====
if not "%JAVA_HOME%" == "" goto OkJHome

echo.
echo Error: JAVA_HOME not found in your environment. >&2
echo Please set the JAVA_HOME variable in your environment to match the >&2
echo location of your Java installation. >&2
echo.
goto error

:OkJHome
if exist "%JAVA_HOME%\bin\java.exe" goto init

echo.
echo Error: JAVA_HOME is set to an invalid directory. >&2
echo JAVA_HOME = "%JAVA_HOME%" >&2
echo Please set the JAVA_HOME variable in your environment to match the >&2
echo location of your Java installation. >&2
echo.
goto error

@REM ==== END VALIDATION ====

:init

@REM Find the project base dir, i.e. the directory that contains the folder ".mvn".
@REM Fallback to current working directory if not found.

set MAVEN_PROJECTBASEDIR=%MAVEN_BASEDIR%
IF NOT "%MAVEN_PROJECTBASEDIR%"=="" goto endDetectBaseDir

set EXEC_DIR=%CD%
set WDIR=%EXEC_DIR%
@REM Walk up one parent directory at a time until a ".mvn" folder is found; when "cd .."
@REM no longer changes %CD% the root has been reached and the search gives up.
:findBaseDir
IF EXIST "%WDIR%"\.mvn goto baseDirFound
cd ..
IF "%WDIR%"=="%CD%" goto baseDirNotFound
set WDIR=%CD%
goto findBaseDir

:baseDirFound
set MAVEN_PROJECTBASEDIR=%WDIR%
cd "%EXEC_DIR%"
goto endDetectBaseDir

:baseDirNotFound
set MAVEN_PROJECTBASEDIR=%EXEC_DIR%
cd "%EXEC_DIR%"

:endDetectBaseDir

IF NOT EXIST "%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config" goto endReadAdditionalConfig

@REM Join every line of .mvn\jvm.config into JVM_CONFIG_MAVEN_PROPS, separated by spaces.
@REM Delayed expansion (!VAR!) is required to accumulate the value inside the FOR loop;
@REM the "endlocal & set" line then carries the result out of this local scope.
@setlocal EnableExtensions EnableDelayedExpansion
for /F "usebackq delims=" %%a in ("%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config") do set JVM_CONFIG_MAVEN_PROPS=!JVM_CONFIG_MAVEN_PROPS! %%a
@endlocal & set JVM_CONFIG_MAVEN_PROPS=%JVM_CONFIG_MAVEN_PROPS%

:endReadAdditionalConfig

@REM MAVEN_JAVA_EXE and WRAPPER_JAR store their surrounding double quotes as part of the value,
@REM which is why they are expanded without extra quotes further down.
SET MAVEN_JAVA_EXE="%JAVA_HOME%\bin\java.exe"
set WRAPPER_JAR="%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.jar"
set WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain

@REM Default download location of the wrapper jar; a "wrapperUrl" entry in
@REM .mvn\wrapper\maven-wrapper.properties overrides it.
set WRAPPER_URL="https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar"

FOR /F "usebackq tokens=1,2 delims==" %%A IN ("%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties") DO (
IF "%%A"=="wrapperUrl" SET WRAPPER_URL=%%B
)

@REM Extension to allow automatically downloading the maven-wrapper.jar from Maven-central
@REM This allows using the maven wrapper in projects that prohibit checking in binary data.
@REM The download only happens when the jar is missing; MVNW_REPOURL, when set, takes
@REM precedence over the URL chosen above.
if exist %WRAPPER_JAR% (
if "%MVNW_VERBOSE%" == "true" (
echo Found %WRAPPER_JAR%
)
) else (
if not "%MVNW_REPOURL%" == "" (
SET WRAPPER_URL="%MVNW_REPOURL%/org/apache/maven/wrapper/maven-wrapper/3.2.0/maven-wrapper-3.2.0.jar"
)
if "%MVNW_VERBOSE%" == "true" (
echo Couldn't find %WRAPPER_JAR%, downloading it ...
echo Downloading from: %WRAPPER_URL%
)

powershell -Command "&{"^
"$webclient = new-object System.Net.WebClient;"^
"if (-not ([string]::IsNullOrEmpty('%MVNW_USERNAME%') -and [string]::IsNullOrEmpty('%MVNW_PASSWORD%'))) {"^
"$webclient.Credentials = new-object System.Net.NetworkCredential('%MVNW_USERNAME%', '%MVNW_PASSWORD%');"^
"}"^
"[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; $webclient.DownloadFile('%WRAPPER_URL%', '%WRAPPER_JAR%')"^
"}"
if "%MVNW_VERBOSE%" == "true" (
echo Finished downloading %WRAPPER_JAR%
)
)
@REM End of extension

@REM If specified, validate the SHA-256 sum of the Maven wrapper jar file
@REM WRAPPER_SHA_256_SUM starts out as the literal "" so that, when no wrapperSha256Sum
@REM property exists, the "IF NOT %WRAPPER_SHA_256_SUM%==""" test below is false and the check is skipped.
SET WRAPPER_SHA_256_SUM=""
FOR /F "usebackq tokens=1,2 delims==" %%A IN ("%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties") DO (
IF "%%A"=="wrapperSha256Sum" SET WRAPPER_SHA_256_SUM=%%B
)
@REM On a hash mismatch the PowerShell snippet runs "exit 1", which the ERRORLEVEL test turns into a failure.
IF NOT %WRAPPER_SHA_256_SUM%=="" (
powershell -Command "&{"^
"$hash = (Get-FileHash \"%WRAPPER_JAR%\" -Algorithm SHA256).Hash.ToLower();"^
"If('%WRAPPER_SHA_256_SUM%' -ne $hash){"^
" Write-Output 'Error: Failed to validate Maven wrapper SHA-256, your Maven wrapper might be compromised.';"^
" Write-Output 'Investigate or delete %WRAPPER_JAR% to attempt a clean download.';"^
" Write-Output 'If you updated your Maven version, you need to update the specified wrapperSha256Sum property.';"^
" exit 1;"^
"}"^
"}"
if ERRORLEVEL 1 goto error
)

@REM Provide a "standardized" way to retrieve the CLI args that will
@REM work with both Windows and non-Windows executions.
set MAVEN_CMD_LINE_ARGS=%*

%MAVEN_JAVA_EXE% ^
%JVM_CONFIG_MAVEN_PROPS% ^
%MAVEN_OPTS% ^
%MAVEN_DEBUG_OPTS% ^
-classpath %WRAPPER_JAR% ^
"-Dmaven.multiModuleProjectDirectory=%MAVEN_PROJECTBASEDIR%" ^
%WRAPPER_LAUNCHER% %MAVEN_CONFIG% %*
if ERRORLEVEL 1 goto error
goto end

:error
set ERROR_CODE=1

:end
@REM Close the inner setlocal (the one isolating internal variables) while carrying ERROR_CODE out:
@REM %ERROR_CODE% on this line is expanded before @endlocal runs, so the mavenrc_post scripts still see it.
@endlocal & set ERROR_CODE=%ERROR_CODE%

if not "%MAVEN_SKIP_RC%"=="" goto skipRcPost
@REM check for post script, once with legacy .bat ending and once with .cmd ending
if exist "%USERPROFILE%\mavenrc_post.bat" call "%USERPROFILE%\mavenrc_post.bat"
if exist "%USERPROFILE%\mavenrc_post.cmd" call "%USERPROFILE%\mavenrc_post.cmd"
:skipRcPost

@REM pause the script if MAVEN_BATCH_PAUSE is set to 'on'
if "%MAVEN_BATCH_PAUSE%"=="on" pause

@REM MAVEN_TERMINATE_CMD=on: a plain "exit" ends the whole cmd session with ERROR_CODE.
if "%MAVEN_TERMINATE_CMD%"=="on" exit %ERROR_CODE%

@REM Otherwise report ERROR_CODE as this script's exit status without closing the calling cmd session.
cmd /C exit /B %ERROR_CODE%
206 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | dev.ai4j
8 | ai4j
9 | 0.3.0
10 | jar
11 |
12 | ai4j
13 | Java library for smooth integration with popular AI tools and services
14 | https://github.com/ai-for-java/ai4j
15 |
16 |
17 |
18 | Apache-2.0
19 | https://www.apache.org/licenses/LICENSE-2.0.txt
20 | repo
21 | A business-friendly OSS license
22 |
23 |
24 |
25 |
26 | https://github.com/ai-for-java/ai4j
27 | scm:git:git://github.com/ai-for-java/ai4j.git
28 | scm:git:git@github.com:ai-for-java/ai4j.git
29 |
30 |
31 |
32 |
33 | deep-learning-dynamo
34 | deeplearningdynamo@gmail.com
35 | https://github.com/deep-learning-dynamo
36 |
37 |
38 | kuraleta
39 | digital.kuraleta@gmail.com
40 | https://github.com/kuraleta
41 |
42 |
43 |
44 |
45 | 1.8
46 | 1.8
47 | UTF-8
48 | 5.9.3
49 |
50 |
51 |
52 |
53 |
54 | dev.ai4j
55 | ai4j-core
56 | 0.3.0
57 |
58 |
59 |
60 | dev.ai4j
61 | openai4j
62 | 0.2.0
63 |
64 |
65 |
66 |
67 | org.projectlombok
68 | lombok
69 | 1.18.26
70 | provided
71 |
72 |
73 |
74 | org.apache.pdfbox
75 | pdfbox
76 | 2.0.28
77 |
78 |
79 |
80 | com.google.code.gson
81 | gson
82 | 2.10.1
83 |
84 |
85 |
86 | org.jsoup
87 | jsoup
88 | 1.16.1
89 |
90 |
91 |
92 | org.slf4j
93 | slf4j-api
94 | 2.0.7
95 |
96 |
97 |
98 | com.knuddels
99 | jtokkit
100 | 0.4.0
101 |
102 |
103 |
104 | org.junit.jupiter
105 | junit-jupiter-engine
106 | ${junit.version}
107 | test
108 |
109 |
110 |
111 | org.junit.jupiter
112 | junit-jupiter-params
113 | ${junit.version}
114 | test
115 |
116 |
117 |
118 | org.assertj
119 | assertj-core
120 | 3.24.2
121 | test
122 |
123 |
124 |
125 |
126 |
127 |
128 | ossrh
129 | https://s01.oss.sonatype.org/content/repositories/snapshots
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 | org.sonatype.plugins
138 | nexus-staging-maven-plugin
139 | 1.6.13
140 | true
141 |
142 | ossrh
143 | https://s01.oss.sonatype.org/
144 | false
145 |
146 |
147 |
148 |
149 | org.apache.maven.plugins
150 | maven-source-plugin
151 | 3.2.1
152 |
153 |
154 | attach-sources
155 |
156 | jar-no-fork
157 |
158 |
159 |
160 |
161 |
162 |
163 | org.apache.maven.plugins
164 | maven-javadoc-plugin
165 | 3.5.0
166 |
167 |
168 | attach-javadocs
169 |
170 | jar
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 | sign
183 |
184 |
185 | sign
186 |
187 |
188 |
189 |
190 |
191 | org.apache.maven.plugins
192 | maven-gpg-plugin
193 | 3.0.1
194 |
195 |
196 | sign-artifacts
197 | verify
198 |
199 | sign
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/agent/tool/webpage/WebPageScrapperTool.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.agent.tool.webpage;
2 |
import dev.ai4j.agent.Tool;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jsoup.Jsoup;

import java.io.IOException;
import java.util.Optional;
9 |
10 | public class WebPageScrapperTool implements Tool {
11 |
12 | @Override
13 | public String id() {
14 | return "webpage-scrapper";
15 | }
16 |
17 | @Override
18 | public String description() {
19 | return "A portal to the internet. Use this when you need to get content from a specific web page."
20 | + " You should provide a valid URL as an input and the tool with output all the text from that web page.";
21 | }
22 |
23 | @Override
24 | public Optional execute(String webPageUri) {
25 | try {
26 | val webPage = Jsoup.connect(webPageUri).get();
27 | return Optional.of(webPage.text()); // TODO try html
28 | } catch (IOException e) {
29 | // TODO retry?
30 | }
31 |
32 | return Optional.empty();
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/document/loader/PdfFileLoader.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.document.loader;
2 |
3 | import dev.ai4j.document.Document;
4 | import dev.ai4j.document.DocumentLoader;
5 | import lombok.AllArgsConstructor;
6 | import lombok.SneakyThrows;
7 | import lombok.val;
8 | import org.apache.pdfbox.pdmodel.PDDocument;
9 | import org.apache.pdfbox.text.PDFTextStripper;
10 |
11 | import java.io.File;
12 |
13 | @AllArgsConstructor
14 | public class PdfFileLoader implements DocumentLoader {
15 |
16 | private final String absolutePathToPdfFile;
17 |
18 | @Override
19 | @SneakyThrows
20 | public Document load() {
21 | val pdfFile = new File(absolutePathToPdfFile);
22 | val pdfDocument = PDDocument.load(pdfFile);
23 | val stripper = new PDFTextStripper();
24 | val text = stripper.getText(pdfDocument);
25 | pdfDocument.close();
26 | return new Document(text);
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/document/loader/TextFileLoader.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.document.loader;
2 |
3 | import dev.ai4j.document.Document;
4 | import dev.ai4j.document.DocumentLoader;
5 | import lombok.Builder;
6 | import lombok.SneakyThrows;
7 | import lombok.val;
8 |
9 | import java.nio.charset.Charset;
10 | import java.nio.file.Files;
11 | import java.nio.file.Paths;
12 |
13 | import static java.nio.charset.StandardCharsets.UTF_8;
14 |
15 | public class TextFileLoader implements DocumentLoader {
16 |
17 | private final String absolutePathToTextFile;
18 | private final Charset charset;
19 |
20 | @Builder
21 | public TextFileLoader(String absolutePathToTextFile, Charset charset) {
22 | this.absolutePathToTextFile = absolutePathToTextFile;
23 | this.charset = (charset == null) ? UTF_8 : charset;
24 | }
25 |
26 | @Override
27 | @SneakyThrows
28 | public Document load() {
29 | val fileContents = new String(Files.readAllBytes(Paths.get(absolutePathToTextFile)), charset);
30 | return Document.from(fileContents);
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/document/splitter/OverlappingDocumentSplitter.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.document.splitter;
2 |
3 | import dev.ai4j.document.Document;
4 | import dev.ai4j.document.DocumentSplitter;
5 | import lombok.Builder;
6 | import lombok.val;
7 | import lombok.var;
8 |
9 | import java.util.ArrayList;
10 | import java.util.List;
11 |
12 | public class OverlappingDocumentSplitter implements DocumentSplitter {
13 |
14 | private final int chunkSize;
15 | private final int chunkOverlap;
16 |
17 | @Builder
18 | public OverlappingDocumentSplitter(int chunkSize, int chunkOverlap) {
19 | this.chunkSize = chunkSize;
20 | this.chunkOverlap = chunkOverlap;
21 | }
22 |
23 | @Override
24 | public List split(Document document) {
25 | if (document.contents() == null || document.contents().isEmpty()) {
26 | throw new IllegalArgumentException("Document content should not be null or empty");
27 | }
28 |
29 | val contents = document.contents();
30 | val contentLength = contents.length();
31 |
32 | if (chunkSize <= 0 || chunkOverlap < 0 || chunkSize <= chunkOverlap) {
33 | throw new IllegalArgumentException(String.format("Invalid chunkSize (%s) or chunkOverlap (%s)", chunkSize, chunkOverlap));
34 | }
35 |
36 | val result = new ArrayList();
37 | if (contentLength <= chunkSize) {
38 | result.add(document);
39 | } else {
40 | for (var i = 0; i < contentLength - chunkOverlap; i += chunkSize - chunkOverlap) {
41 | val endIndex = Math.min(i + chunkSize, contentLength);
42 | val chunk = contents.substring(i, endIndex);
43 | result.add(Document.from(chunk));
44 | if (endIndex == contentLength) {
45 | break;
46 | }
47 | }
48 | }
49 |
50 | return result;
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/flows/ChatFlow.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.flows;
2 |
3 | import dev.ai4j.chat.AiMessage;
4 | import dev.ai4j.chat.ChatHistory;
5 | import dev.ai4j.chat.ChatMessage;
6 | import dev.ai4j.chat.ChatModel;
7 | import lombok.Builder;
8 |
9 | import java.util.List;
10 |
11 | import static dev.ai4j.chat.UserMessage.userMessage;
12 |
13 | public class ChatFlow {
14 |
15 | private final ChatModel chatModel;
16 | private final ChatHistory chatHistory;
17 | // TODO private final PromptTemplate promptTemplate;
18 |
19 | @Builder
20 | private ChatFlow(ChatModel chatModel, ChatHistory chatHistory) {
21 | this.chatModel = chatModel;
22 | this.chatHistory = chatHistory;
23 | }
24 |
25 | public String chat(String userMessage) {
26 | chatHistory.add(userMessage(userMessage));
27 | List history = chatHistory.history();
28 |
29 | AiMessage aiMessage = chatModel.chat(history);
30 |
31 | chatHistory.add(aiMessage);
32 | return aiMessage.contents();
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/flows/DocumentQnAFlow.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.flows;
2 |
3 | import dev.ai4j.PromptTemplate;
4 | import dev.ai4j.chat.ChatModel;
5 | import dev.ai4j.document.Document;
6 | import dev.ai4j.document.DocumentLoader;
7 | import dev.ai4j.document.DocumentSplitter;
8 | import dev.ai4j.document.splitter.OverlappingDocumentSplitter;
9 | import dev.ai4j.embedding.Embedding;
10 | import dev.ai4j.embedding.EmbeddingModel;
11 | import dev.ai4j.embedding.VectorDatabase;
12 | import lombok.Builder;
13 |
14 | import java.util.Collection;
15 | import java.util.HashMap;
16 | import java.util.List;
17 | import java.util.Map;
18 |
19 | import static java.util.stream.Collectors.joining;
20 |
21 | public class DocumentQnAFlow {
22 |
23 | private static final OverlappingDocumentSplitter DEFAULT_DOCUMENT_SPLITTER
24 | = new OverlappingDocumentSplitter(1000, 200);
25 | private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("Using the information delimited by triple angle brackets, answer the following question to the best of your ability: {{question}} <<<{{information}}>>>");
26 |
27 | private final DocumentLoader documentLoader;
28 | private final DocumentSplitter documentSplitter;
29 | private final EmbeddingModel embeddingModel;
30 | private final VectorDatabase vectorDatabase;
31 | private final PromptTemplate promptTemplate;
32 | private final ChatModel chatModel;
33 |
34 | @Builder
35 | public DocumentQnAFlow(DocumentLoader documentLoader,
36 | DocumentSplitter documentSplitter,
37 | EmbeddingModel embeddingModel, // TODO provide possibility to use same openapi key
38 | VectorDatabase vectorDatabase,
39 | PromptTemplate promptTemplate,
40 | ChatModel chatModel) {
41 | this.documentLoader = documentLoader;
42 | this.documentSplitter = documentSplitter == null ? DEFAULT_DOCUMENT_SPLITTER : documentSplitter;
43 | this.embeddingModel = embeddingModel;
44 | this.vectorDatabase = vectorDatabase;
45 | this.promptTemplate = promptTemplate == null ? DEFAULT_PROMPT_TEMPLATE : promptTemplate;
46 | this.chatModel = chatModel;
47 |
48 | init();
49 | }
50 |
51 | private void init() {
52 | Document document = documentLoader.load();
53 | List chunks = documentSplitter.split(document);
54 | Collection embeddings = embeddingModel.embed(chunks);
55 | vectorDatabase.persist(embeddings);
56 | }
57 |
58 | public String ask(String question) {
59 | Embedding questionEmbedding = embeddingModel.embed(question);
60 |
61 | List relatedEmbeddings = vectorDatabase.findRelated(questionEmbedding, 5); // TODO defaults
62 |
63 | String concatenatedEmbeddings = relatedEmbeddings.stream()
64 | .map(Embedding::contents)
65 | .collect(joining(" "));
66 |
67 | Map parameters = new HashMap<>();
68 | parameters.put("question", question);
69 | parameters.put("information", concatenatedEmbeddings);
70 |
71 | String prompt = promptTemplate.format(parameters);
72 |
73 | return chatModel.chat(prompt);
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/model/chat/OpenAiChatModel.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.model.chat;
2 |
3 | import com.google.gson.JsonElement;
4 | import com.google.gson.JsonParser;
5 | import com.google.gson.stream.JsonReader;
6 | import com.google.gson.stream.JsonToken;
7 | import dev.ai4j.PromptTemplate;
8 | import dev.ai4j.StreamingResponseHandler;
9 | import dev.ai4j.chat.AiMessage;
10 | import dev.ai4j.chat.ChatMessage;
11 | import dev.ai4j.chat.ChatModel;
12 | import dev.ai4j.chat.UserMessage;
13 | import dev.ai4j.model.completion.structured.Description;
14 | import dev.ai4j.openai4j.OpenAiClient;
15 | import dev.ai4j.openai4j.chat.ChatCompletionRequest;
16 | import dev.ai4j.openai4j.chat.ChatCompletionResponse;
17 | import dev.ai4j.openai4j.chat.Message;
18 | import dev.ai4j.openai4j.chat.Role;
19 | import dev.ai4j.utils.Json;
20 | import dev.ai4j.utils.StopWatch;
21 | import lombok.Builder;
22 | import lombok.extern.slf4j.Slf4j;
23 | import lombok.val;
24 |
25 | import java.io.StringReader;
26 | import java.time.Duration;
27 | import java.util.*;
28 |
29 | import static dev.ai4j.chat.AiMessage.aiMessage;
30 | import static dev.ai4j.chat.UserMessage.userMessage;
31 | import static dev.ai4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
32 | import static dev.ai4j.openai4j.chat.Role.*;
33 | import static dev.ai4j.utils.Json.*;
34 | import static java.util.Arrays.asList;
35 | import static java.util.stream.Collectors.toList;
36 |
37 | @Slf4j
38 | public class OpenAiChatModel implements ChatModel { // TODO all models in one "service"?
39 |
40 | private static final double DEFAULT_TEMPERATURE = 0.7;
41 | private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);
42 |
43 | private final OpenAiClient client;
44 | private final String modelName;
45 | private final Double temperature;
46 |
47 | // TODO consider adding here an option for system prompt configuration
48 |
49 | @Builder
50 | public OpenAiChatModel(String apiKey,
51 | String modelName,
52 | Double temperature,
53 | Duration timeout) {
54 | this.client = OpenAiClient.builder()
55 | .apiKey(apiKey)
56 | .timeout(timeout == null ? DEFAULT_TIMEOUT : timeout)
57 | .build();
58 | this.modelName = modelName == null ? GPT_3_5_TURBO : modelName;
59 | this.temperature = temperature == null ? DEFAULT_TEMPERATURE : temperature;
60 | }
61 |
62 | @Override
63 | public AiMessage chat(List messages) {
64 |
65 | ChatCompletionRequest request = ChatCompletionRequest.builder()
66 | .model(modelName)
67 | .messages(toOpenAiMessages(messages))
68 | .temperature(temperature)
69 | .build();
70 |
71 | if (log.isDebugEnabled()) {
72 | String json = toJson(request);
73 | log.debug("Sending to OpenAI:\n{}", json);
74 | }
75 | StopWatch sw = StopWatch.start();
76 |
77 | ChatCompletionResponse response = client.chatCompletion(request).execute();
78 |
79 | long secondsElapsed = sw.secondsElapsed();
80 | if (log.isDebugEnabled()) {
81 | String json = toJson(response);
82 | log.debug("Received from OpenAI in {} seconds:\n{}", secondsElapsed, json);
83 | }
84 |
85 | return aiMessage(response.content());
86 | }
87 |
88 | @Override
89 | public AiMessage chat(ChatMessage... messages) {
90 | return chat(asList(messages));
91 | }
92 |
93 | @Override
94 | public String chat(String userMessage) {
95 | AiMessage aiMessage = chat(userMessage(userMessage));
96 | return aiMessage.contents();
97 | }
98 |
99 | @Override
100 | public void chat(List messages, StreamingResponseHandler handler) {
101 | ChatCompletionRequest request = ChatCompletionRequest.builder()
102 | .model(modelName)
103 | .messages(toOpenAiMessages(messages))
104 | .temperature(temperature)
105 | .stream(true)
106 | .build();
107 |
108 | client.chatCompletion(request)
109 | .onPartialResponse(partialResponse -> {
110 | String content = partialResponse.choices().get(0).delta().content();
111 | if (content != null) {
112 | handler.onPartialResponse(content);
113 | }
114 | })
115 | .onComplete(handler::onComplete)
116 | .onError(handler::onError)
117 | .execute();
118 | }
119 |
120 | private static List toOpenAiMessages(List messages) {
121 | return messages.stream()
122 | .map(OpenAiChatModel::toOpenAiMessage)
123 | .collect(toList());
124 | }
125 |
126 | private static Message toOpenAiMessage(ChatMessage message) {
127 | Role role;
128 |
129 | if (message instanceof UserMessage) {
130 | role = USER;
131 | } else if (message instanceof AiMessage) {
132 | role = ASSISTANT;
133 | } else {
134 | role = SYSTEM;
135 | }
136 |
137 | return Message.builder()
138 | .role(role)
139 | .content(message.contents())
140 | .build();
141 | }
142 |
143 | @Override
144 | public S getOne(Class structured) {
145 | return getMultiple(structured, 1).get(0);
146 | }
147 |
148 | @Override
149 | public List getMultiple(Class structured, int n) {
150 | String description = structured.getAnnotation(Description.class).value();
151 | String jsonStructure = generateJsonStructure(structured);
152 | Optional maybeJsonExample = generateJsonExample(structured);
153 |
154 | PromptTemplate promptTemplate = PromptTemplate.from(
155 | "Provide exactly {{number_of_examples}} example(s) of {{description}} in exactly following JSON format:\n" +
156 | "{{json_structure}}\n" +
157 | "\n" +
158 | "{{maybe_example}}" +
159 | "Do not provide any other information, just valid JSON object(s)!"
160 | );
161 |
162 | Map params = new HashMap<>();
163 | params.put("number_of_examples", n);
164 | params.put("description", description);
165 | params.put("json_structure", jsonStructure);
166 | params.put("maybe_example", maybeJsonExample
167 | .map(example -> String.format("For example:\n%s\n\n", example))
168 | .orElse(""));
169 |
170 | String prompt = promptTemplate.format(params);
171 |
172 | AiMessage aiMessage = chat(userMessage(prompt));
173 |
174 | val jsonElements = parse(aiMessage.contents());
175 |
176 | return jsonElements.stream()
177 | .map(jsonElement -> {
178 | try {
179 | return Json.fromJson(jsonElement.toString(), structured);
180 | } catch (Exception e) {
181 | // TODO
182 | return null;
183 | }
184 | })
185 | .filter(Objects::nonNull)
186 | .limit(n)
187 | .collect(toList());
188 | }
189 |
190 | private static ArrayList parse(String json) {
191 | val reader = new JsonReader(new StringReader(json));
192 | reader.setLenient(true);
193 |
194 | val jsonElements = new ArrayList();
195 | try {
196 | while (reader.peek() != JsonToken.END_DOCUMENT) {
197 | jsonElements.add(JsonParser.parseReader(reader));
198 | }
199 | } catch (Exception e) {
200 | // TODO
201 | }
202 |
203 | return jsonElements;
204 | }
205 | }
206 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/model/chat/SimpleChatHistory.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.model.chat;
2 |
3 | import dev.ai4j.chat.ChatHistory;
4 | import dev.ai4j.chat.ChatMessage;
5 | import dev.ai4j.chat.SystemMessage;
6 | import dev.ai4j.chat.UserMessage;
7 | import lombok.extern.slf4j.Slf4j;
8 | import lombok.val;
9 | import lombok.var;
10 |
11 | import java.util.ArrayList;
12 | import java.util.LinkedList;
13 | import java.util.List;
14 | import java.util.Optional;
15 |
16 | @Slf4j
17 | public class SimpleChatHistory implements ChatHistory {
18 |
19 | // safety net to limit the cost in case user did not define it himself
20 | private static final int DEFAULT_CAPACITY_IN_TOKENS = 200;
21 |
22 | private final Optional maybeMessageFromSystem;
23 | private final LinkedList previousMessages;
24 | private final Integer capacityInTokens;
25 | private final Integer capacityInMessages;
26 |
27 | private SimpleChatHistory(Builder builder) {
28 | this.maybeMessageFromSystem = builder.maybeSystemMessage;
29 | this.previousMessages = builder.previousMessages;
30 | this.capacityInTokens = builder.capacityInTokens;
31 | this.capacityInMessages = builder.capacityInMessages;
32 | ensureCapacity();
33 | }
34 |
35 | @Override
36 | public void add(ChatMessage chatMessage) {
37 | previousMessages.add(chatMessage);
38 | ensureCapacity();
39 | }
40 |
41 | @Override
42 | public List history() {
43 | val messages = new ArrayList();
44 | maybeMessageFromSystem.ifPresent(messages::add);
45 | messages.addAll(previousMessages);
46 | return messages;
47 | }
48 |
49 | private void ensureCapacity() {
50 | var currentNumberOfTokensInHistory = getCurrentNumberOfTokens();
51 | var currentNumberOfMessagesInHistory = getCurrentNumberOfMessages();
52 |
53 | while ((capacityInTokens != null && currentNumberOfTokensInHistory > capacityInTokens)
54 | || (capacityInMessages != null && currentNumberOfMessagesInHistory > capacityInMessages)) {
55 |
56 | val oldestMessage = previousMessages.removeFirst();
57 |
58 | // remove all mentions of human, messageFrom
59 |
60 | log.debug("Removing the oldest message from {} '{}' ({} tokens) to comply with capacity requirements",
61 | oldestMessage instanceof UserMessage ? "user" : "AI",
62 | oldestMessage.contents(),
63 | oldestMessage.numberOfTokens());
64 |
65 | currentNumberOfTokensInHistory -= oldestMessage.numberOfTokens();
66 | currentNumberOfMessagesInHistory--;
67 | }
68 |
69 | log.debug("Current stats: { tokens: approximately {}, messages: {} }", getCurrentNumberOfTokens(), getCurrentNumberOfMessages());
70 | }
71 |
72 | private int getCurrentNumberOfTokens() {
73 | val numberOfTokensInSystemMessage = maybeMessageFromSystem.map(ChatMessage::numberOfTokens).orElse(0);
74 | val numberOfTokensInPreviousMessages = previousMessages.stream()
75 | .map(ChatMessage::numberOfTokens)
76 | .reduce(0, Integer::sum);
77 | return numberOfTokensInSystemMessage + numberOfTokensInPreviousMessages;
78 | }
79 |
80 | private int getCurrentNumberOfMessages() {
81 | return maybeMessageFromSystem.map(m -> 1).orElse(0) + previousMessages.size();
82 | }
83 |
84 | public static class Builder {
85 |
86 | private Optional maybeSystemMessage = Optional.empty();
87 | private Integer capacityInTokens = DEFAULT_CAPACITY_IN_TOKENS;
88 | private Integer capacityInMessages;
89 | private LinkedList previousMessages = new LinkedList<>();
90 |
91 | public Builder systemMessage(SystemMessage systemMessage) {
92 | this.maybeSystemMessage = Optional.ofNullable(systemMessage);
93 | return this;
94 | }
95 |
96 | public Builder systemMessage(String systemMessage) {
97 | if (systemMessage == null) {
98 | this.maybeSystemMessage = Optional.empty(); // TODO ?
99 | return this;
100 | }
101 |
102 | return systemMessage(SystemMessage.systemMessage(systemMessage));
103 | }
104 |
105 | public Builder capacityInTokens(Integer capacityInTokens) {
106 | this.capacityInTokens = capacityInTokens;
107 | return this;
108 | }
109 |
110 | public Builder removeCapacityRestrictionInTokens() {
111 | return capacityInTokens(null);
112 | }
113 |
114 | public Builder capacityInMessages(Integer capacityInMessages) {
115 | this.capacityInMessages = capacityInMessages;
116 | return this;
117 | }
118 |
119 | public Builder previousMessages(List previousMessages) {
120 | if (previousMessages == null) {
121 | return this;
122 | }
123 |
124 | this.previousMessages = new LinkedList<>(previousMessages);
125 | return this;
126 | }
127 |
128 | public SimpleChatHistory build() {
129 | return new SimpleChatHistory(this);
130 | }
131 | }
132 |
133 | public static Builder builder() {
134 | return new Builder();
135 | }
136 | }
137 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/model/completion/OpenAiCompletionModel.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.model.completion;
2 |
3 | import dev.ai4j.completion.CompletionModel;
4 | import dev.ai4j.openai4j.OpenAiClient;
5 | import dev.ai4j.openai4j.chat.ChatCompletionRequest;
6 | import dev.ai4j.openai4j.completion.CompletionRequest;
7 | import dev.ai4j.utils.StopWatch;
8 | import lombok.Builder;
9 | import lombok.extern.slf4j.Slf4j;
10 | import lombok.val;
11 |
12 | import java.time.Duration;
13 |
14 | import static dev.ai4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
15 | import static dev.ai4j.utils.Json.toJson;
16 |
17 | @Slf4j
18 | public class OpenAiCompletionModel implements CompletionModel {
19 |
20 | private static final double DEFAULT_TEMPERATURE = 0.7;
21 | private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);
22 |
23 | private final OpenAiClient client;
24 | private final String modelName;
25 | private final Double temperature;
26 |
27 | @Builder
28 | public OpenAiCompletionModel(String apiKey, String modelName, Double temperature, Duration timeout) {
29 | this.client = OpenAiClient.builder()
30 | .apiKey(apiKey)
31 | .timeout(timeout == null ? DEFAULT_TIMEOUT : timeout)
32 | .build();
33 | this.modelName = modelName == null ? GPT_3_5_TURBO : modelName;
34 | this.temperature = temperature == null ? DEFAULT_TEMPERATURE : temperature;
35 | }
36 |
37 | @Override
38 | public String complete(String prompt) {
39 | if (GPT_3_5_TURBO.equals(modelName)) { // TODO remove this
40 | return chatCompletion(prompt);
41 | } else {
42 | return completion(prompt);
43 | }
44 | }
45 |
46 | private String chatCompletion(String input) {
47 |
48 | val request = ChatCompletionRequest.builder()
49 | .model(modelName)
50 | .addUserMessage(input)
51 | .temperature(temperature)
52 | .build();
53 |
54 | if (log.isDebugEnabled()) {
55 | val json = toJson(request);
56 | log.debug("Sending to OpenAI:\n{}", json);
57 | }
58 | val sw = StopWatch.start();
59 |
60 | val response = client.chatCompletion(request).execute();
61 |
62 | val secondsElapsed = sw.secondsElapsed();
63 | if (log.isDebugEnabled()) {
64 | val json = toJson(response);
65 | log.debug("Received from OpenAI in {} seconds:\n{}", secondsElapsed, json);
66 | }
67 |
68 | return response.content();
69 | }
70 |
71 | private String completion(String input) {
72 |
73 | val request = CompletionRequest.builder()
74 | .model(modelName)
75 | .prompt(input)
76 | .temperature(temperature)
77 | .build();
78 |
79 | if (log.isDebugEnabled()) {
80 | val json = toJson(request);
81 | log.debug("Sending to OpenAI:\n{}", json);
82 | }
83 | val sw = StopWatch.start();
84 |
85 | val completionResult = client.completion(request).execute();
86 |
87 | val secondsElapsed = sw.secondsElapsed();
88 | if (log.isDebugEnabled()) {
89 | val json = toJson(completionResult);
90 | log.debug("Received from OpenAI in {} seconds:\n{}", secondsElapsed, json);
91 | }
92 |
93 | return completionResult.text();
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/model/completion/structured/Description.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.model.completion.structured;
2 |
3 | import java.lang.annotation.Retention;
4 |
5 | import static java.lang.annotation.RetentionPolicy.RUNTIME;
6 |
/**
 * Human-readable description of a field of a structured-output class.
 *
 * <p>Kept at runtime so it can be read reflectively, e.g. by {@code Json.generateJsonStructure},
 * which renders the text as a comment next to the field name.
 */
@Retention(value = RUNTIME)
public @interface Description {

    /** The description text. */
    String value();
}
12 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/model/completion/structured/Example.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.model.completion.structured;
2 |
3 | import java.lang.annotation.Retention;
4 |
5 | import static java.lang.annotation.RetentionPolicy.RUNTIME;
6 |
/**
 * One or more example values for a field of a structured-output class.
 *
 * <p>Kept at runtime so it can be read reflectively, e.g. by {@code Json.toJsonExample}, which
 * renders a single value as-is and several values as a JSON array.
 */
@Retention(value = RUNTIME)
public @interface Example {

    /** The example value(s), written as they should appear in the rendered JSON. */
    String[] value();
}
12 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/model/embedding/OpenAiEmbeddingModel.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.model.embedding;
2 |
3 | import dev.ai4j.document.Document;
4 | import dev.ai4j.embedding.Embedding;
5 | import dev.ai4j.embedding.EmbeddingModel;
6 | import dev.ai4j.openai4j.OpenAiClient;
7 | import dev.ai4j.openai4j.embedding.EmbeddingRequest;
8 | import lombok.Builder;
9 | import lombok.val;
10 |
11 | import java.time.Duration;
12 | import java.util.Collection;
13 | import java.util.List;
14 | import java.util.stream.IntStream;
15 |
16 | import static dev.ai4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
17 | import static dev.ai4j.utils.Utils.list;
18 | import static java.util.stream.Collectors.toList;
19 |
20 | public class OpenAiEmbeddingModel implements EmbeddingModel {
21 |
22 | private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);
23 |
24 | private final OpenAiClient client;
25 | private final String modelName;
26 |
27 | @Builder
28 | public OpenAiEmbeddingModel(String apiKey, String modelName, Duration timeout) {
29 | this.client = OpenAiClient.builder()
30 | .apiKey(apiKey)
31 | .timeout(timeout == null ? DEFAULT_TIMEOUT : timeout)
32 | .build();
33 | this.modelName = modelName == null ? TEXT_EMBEDDING_ADA_002 : modelName;
34 | }
35 |
36 | @Override
37 | public Embedding embed(Document document) {
38 | return embed(list(document)).iterator().next();
39 | }
40 |
41 | @Override
42 | public Embedding embed(String text) {
43 | return embed(list(Document.from(text))).iterator().next();
44 | }
45 |
46 | @Override
47 | public Collection embed(Collection documents) {
48 | val documentContents = documents.stream()
49 | .map(Document::contents)
50 | .collect(toList());
51 |
52 | val embeddingRequest = EmbeddingRequest.builder()
53 | .input(documentContents) // TODO handle newlines ?
54 | .model(modelName)
55 | .build();
56 |
57 | val openAiEmbeddings = client.embedding(embeddingRequest).execute().data();
58 |
59 | return zip(documentContents, openAiEmbeddings);
60 | }
61 |
62 | private static List zip(List documentTexts, List openAiEmbeddings) {
63 | return IntStream.range(0, documentTexts.size())
64 | .mapToObj(i -> new Embedding(documentTexts.get(i), toDoubles(openAiEmbeddings.get(i).embedding())))
65 | .collect(toList());
66 | }
67 |
68 | private static List toDoubles(List floats) {
69 | return floats.stream()
70 | .map(Float::doubleValue)
71 | .collect(toList());
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/model/openai/OpenAiModelName.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.model.openai;
2 |
/**
 * Identifiers of the OpenAI models referenced by this library, grouped by kind.
 */
public class OpenAiModelName {

    // --- chat models (chat completions endpoint) ---
    public static final String GPT_3_5_TURBO = "gpt-3.5-turbo";
    public static final String GPT_4 = "gpt-4";
    public static final String GPT_4_32K = "gpt-4-32k";

    // --- legacy completion models ---
    public static final String TEXT_DAVINCI_003 = "text-davinci-003";
    public static final String TEXT_DAVINCI_002 = "text-davinci-002";
    public static final String CODE_DAVINCI_002 = "code-davinci-002";

    // --- embedding models ---
    public static final String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002";
}
13 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/utils/Json.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.utils;
2 |
3 | import com.google.gson.Gson;
4 | import com.google.gson.GsonBuilder;
5 | import dev.ai4j.model.completion.structured.Description;
6 | import dev.ai4j.model.completion.structured.Example;
7 |
8 | import java.lang.reflect.Field;
9 | import java.lang.reflect.ParameterizedType;
10 | import java.util.Collection;
11 | import java.util.Optional;
12 |
13 | import static java.util.Arrays.stream;
14 | import static java.util.stream.Collectors.joining;
15 |
16 | public class Json {
17 |
18 | private static final Gson GSON = new GsonBuilder().setPrettyPrinting().create();
19 |
20 | public static String toJson(Object o) {
21 | return GSON.toJson(o);
22 | }
23 |
24 | public static T fromJson(String json, Class type) {
25 | return GSON.fromJson(json, type);
26 | }
27 |
28 | public static String generateJsonStructure(Class structured) {
29 | StringBuilder jsonStructure = new StringBuilder();
30 |
31 | jsonStructure.append("{\n");
32 | for (Field field : structured.getDeclaredFields()) {
33 | Description fieldDescription = field.getAnnotation(Description.class);
34 | if (fieldDescription == null) {
35 | throw new RuntimeException(String.format("Field %s is not annotated with @Description(\"...\")", field.getName()));
36 | }
37 | jsonStructure.append(String.format("\"%s\": // %s,\n", field.getName(), fieldDescription.value()));
38 | }
39 | jsonStructure.deleteCharAt(jsonStructure.length() - 2);
40 | jsonStructure.append("}");
41 |
42 | return jsonStructure.toString();
43 | }
44 |
45 | public static Optional generateJsonExample(Class structured) {
46 | if (!hasExamples(structured)) {
47 | return Optional.empty();
48 | }
49 |
50 | StringBuilder jsonExample = new StringBuilder();
51 |
52 | jsonExample.append("{\n");
53 | for (Field field : structured.getDeclaredFields()) {
54 | Example fieldExample = field.getAnnotation(Example.class);
55 | if (fieldExample == null) {
56 | throw new RuntimeException(String.format("Field %s is not annotated with @Example(\"...\")", field.getName()));
57 | }
58 | jsonExample.append(String.format("\"%s\": %s,\n", field.getName(), toJsonExample(field)));
59 | }
60 | jsonExample.deleteCharAt(jsonExample.length() - 2);
61 | jsonExample.append("}");
62 |
63 | return Optional.of(jsonExample.toString());
64 | }
65 |
66 | private static boolean hasExamples(Class structured) {
67 | return stream(structured.getDeclaredFields())
68 | .anyMatch(field -> field.isAnnotationPresent(Example.class));
69 | }
70 |
71 | public static String toJsonExample(Field field) {
72 | Example fieldExample = field.getAnnotation(Example.class);
73 | String[] examples = fieldExample.value();
74 |
75 | Class> fieldType = field.getType();
76 | boolean wrapInQuotes = fieldType == String.class
77 | || fieldType == String[].class
78 | || isCollectionOfStrings(field);
79 |
80 | if (examples.length == 1) {
81 | if (wrapInQuotes) {
82 | return "\"" + examples[0] + "\"";
83 | }
84 | return examples[0];
85 | }
86 |
87 | return String.format("[%s]", stream(examples).map(example -> {
88 | if (wrapInQuotes) {
89 | return "\"" + example + "\"";
90 | }
91 |
92 | return example;
93 | }).collect(joining(", ")));
94 | }
95 |
96 | private static boolean isCollectionOfStrings(Field field) {
97 | if (!Collection.class.isAssignableFrom(field.getType())) {
98 | return false;
99 | }
100 |
101 | ParameterizedType genericType = (ParameterizedType) field.getGenericType();
102 | Class> actualTypeArgument = (Class>) genericType.getActualTypeArguments()[0];
103 |
104 | return actualTypeArgument.equals(String.class);
105 | }
106 | }
107 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/utils/StopWatch.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.utils;
2 |
/**
 * Minimal stopwatch measuring whole elapsed seconds.
 *
 * <p>Based on {@link System#currentTimeMillis()} (wall-clock time), so measurements can be
 * skewed if the system clock is adjusted while timing.
 */
public class StopWatch {

    private final long startTime;

    /**
     * @param currentTimeMillis start instant in epoch milliseconds, as returned by
     *                          {@link System#currentTimeMillis()}
     */
    public StopWatch(long currentTimeMillis) {
        this.startTime = currentTimeMillis;
    }

    /** Returns a stopwatch started now. */
    public static StopWatch start() {
        return new StopWatch(System.currentTimeMillis());
    }

    /**
     * Returns the whole seconds elapsed since the start instant, truncated toward zero.
     */
    public int secondsElapsed() {
        // Divide in long arithmetic before narrowing: casting the raw millisecond difference to
        // int would overflow once more than ~24.8 days have elapsed.
        return (int) ((System.currentTimeMillis() - startTime) / 1000);
    }
}
19 |
--------------------------------------------------------------------------------
/src/main/java/dev/ai4j/utils/Utils.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.utils;
2 |
3 | import java.util.Arrays;
4 | import java.util.List;
5 |
/**
 * General-purpose helpers.
 */
public class Utils {

    /**
     * Returns a fixed-size list view of {@code elements} (see {@link Arrays#asList}): elements can
     * be replaced via {@code set}, but {@code add}/{@code remove} throw
     * {@link UnsupportedOperationException}.
     *
     * @param elements the elements of the list
     * @param <T>      element type
     * @return a list backed by {@code elements}
     */
    @SafeVarargs
    public static <T> List<T> list(T... elements) {
        return Arrays.asList(elements);
    }
}
12 |
--------------------------------------------------------------------------------
/src/test/java/TestIt.java:
--------------------------------------------------------------------------------
1 | import dev.ai4j.openai4j.OpenAiClient;
2 | import dev.ai4j.openai4j.chat.ChatCompletionRequest;
3 | import dev.ai4j.openai4j.chat.ChatCompletionResponse;
4 | import org.junit.jupiter.api.Test;
5 |
6 | import static dev.ai4j.openai4j.Model.GPT_4;
7 |
8 | public class TestIt {
9 |
10 | @Test
11 | void test() {
12 |
13 | OpenAiClient client = new OpenAiClient(System.getenv("OPENAI_API_KEY"));
14 |
15 | ChatCompletionRequest request = ChatCompletionRequest.builder()
16 | .model(GPT_4)
17 | .addSystemMessage("You are a helpful assistant")
18 | .addUserMessage("Tell me a joke")
19 | .temperature(0.7)
20 | .build();
21 |
22 | ChatCompletionResponse response = client.chatCompletion(request).execute();
23 | System.out.println(response.content());
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/src/test/java/dev/ai4j/document/loader/TextFileLoaderTest.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.document.loader;
2 |
3 | import lombok.val;
4 | import org.junit.jupiter.api.Test;
5 |
6 | import java.nio.file.NoSuchFileException;
7 |
8 | import static java.nio.charset.StandardCharsets.ISO_8859_1;
9 | import static org.assertj.core.api.Assertions.assertThat;
10 | import static org.assertj.core.api.Assertions.assertThatThrownBy;
11 |
12 | class TextFileLoaderTest {
13 |
14 | @Test
15 | void should_load_text_file_with_utf8_charset_by_default() {
16 |
17 | val loader = TextFileLoader.builder()
18 | .absolutePathToTextFile(System.getProperty("user.dir") + "/src/test/java/dev/ai4j/document/loader/test-file-utf8.txt")
19 | .build();
20 |
21 | val document = loader.load();
22 |
23 | assertThat(document.contents()).isEqualTo("test\ncontent");
24 | }
25 |
26 | @Test
27 | void should_load_text_file_with_specified_charset() {
28 |
29 | val loader = TextFileLoader.builder()
30 | .absolutePathToTextFile(System.getProperty("user.dir") + "/src/test/java/dev/ai4j/document/loader/test-file-iso-8859-1.txt")
31 | .charset(ISO_8859_1)
32 | .build();
33 |
34 | val document = loader.load();
35 |
36 | assertThat(document.contents()).isEqualTo("test\ncontent");
37 | }
38 |
39 | @Test
40 | void should_fail_to_load_not_existing_file() {
41 |
42 | val loader = TextFileLoader.builder()
43 | .absolutePathToTextFile(System.getProperty("user.dir") + "banana")
44 | .build();
45 |
46 | assertThatThrownBy(loader::load).isInstanceOf(NoSuchFileException.class);
47 | }
48 | }
--------------------------------------------------------------------------------
/src/test/java/dev/ai4j/document/loader/test-file-iso-8859-1.txt:
--------------------------------------------------------------------------------
1 | test
2 | content
--------------------------------------------------------------------------------
/src/test/java/dev/ai4j/document/loader/test-file-utf8.txt:
--------------------------------------------------------------------------------
1 | test
2 | content
--------------------------------------------------------------------------------
/src/test/java/dev/ai4j/document/splitter/OverlappingDocumentSplitterTest.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.document.splitter;
2 |
3 | import dev.ai4j.document.Document;
4 | import lombok.val;
5 | import org.junit.jupiter.api.Test;
6 | import org.junit.jupiter.params.ParameterizedTest;
7 | import org.junit.jupiter.params.provider.CsvSource;
8 | import org.junit.jupiter.params.provider.NullAndEmptySource;
9 |
10 | import java.util.Arrays;
11 | import java.util.List;
12 |
13 | import static dev.ai4j.document.Document.from;
14 | import static org.junit.jupiter.api.Assertions.assertEquals;
15 | import static org.junit.jupiter.api.Assertions.assertThrows;
16 | import static org.junit.jupiter.api.Assertions.assertTrue;
17 |
18 | class OverlappingDocumentSplitterTest {
19 |
20 | // TODO add more test variety
21 | @Test
22 | void testDocumentIsSplit() {
23 | val sut = new OverlappingDocumentSplitter(4, 2);
24 | List result = sut.split(new Document("1234567890"));
25 |
26 | List expected = Arrays.asList(
27 | from("1234"),
28 | from("3456"),
29 | from("5678"),
30 | from("7890")
31 | );
32 |
33 | assertEquals(expected, result);
34 | }
35 |
36 | @ParameterizedTest
37 | @CsvSource({"0,-1", "-1,-1", "-1,0", "0,0", "0,1", "1,-1", "1,1", "1,2"})
38 | void testIllegalArgumentExceptionWhenChunkSizeAndChunkOverlapMisconfigured(int chunkSize, int chunkOverlap) {
39 | val sut = new OverlappingDocumentSplitter(chunkSize, chunkOverlap);
40 |
41 | IllegalArgumentException thrown = assertThrows(
42 | IllegalArgumentException.class,
43 | () -> sut.split(new Document("any"))
44 | );
45 |
46 | assertTrue(thrown.getMessage()
47 | .contentEquals("Invalid chunkSize (" + chunkSize + ") or chunkOverlap (" + chunkOverlap + ")"));
48 | }
49 |
50 | @ParameterizedTest
51 | @NullAndEmptySource
52 | void testNullCase(String documentContent) {
53 | val sut = new OverlappingDocumentSplitter(4, 2);
54 |
55 | IllegalArgumentException thrown = assertThrows(
56 | IllegalArgumentException.class,
57 | () -> sut.split(new Document(documentContent))
58 | );
59 |
60 | assertTrue(thrown.getMessage().contentEquals("Document content should not be null or empty"));
61 | }
62 | }
--------------------------------------------------------------------------------
/src/test/java/dev/ai4j/model/chat/SimpleChatHistoryTest.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.model.chat;
2 |
3 | import lombok.val;
4 | import org.junit.jupiter.api.Test;
5 |
6 | import static dev.ai4j.chat.AiMessage.aiMessage;
7 | import static dev.ai4j.chat.SystemMessage.systemMessage;
8 | import static dev.ai4j.chat.UserMessage.userMessage;
9 | import static dev.ai4j.utils.TestUtils.*;
10 | import static java.util.Arrays.asList;
11 | import static org.assertj.core.api.Assertions.assertThat;
12 | import static org.assertj.core.api.Assertions.atIndex;
13 |
14 | class SimpleChatHistoryTest {
15 |
16 | @Test
17 | void should_keep_specified_number_of_tokens_in_chat_history_1() {
18 |
19 | val messageFromSystem = systemMessageWithTokens(10);
20 | val chatHistory = SimpleChatHistory.builder()
21 | .systemMessage(messageFromSystem)
22 | .capacityInTokens(30)
23 | .build();
24 | assertThat(chatHistory.history())
25 | .hasSize(1)
26 | .containsExactly(messageFromSystem);
27 |
28 | val firstMessageFromHuman = messageFromHumanWithTokens(10);
29 | chatHistory.add(firstMessageFromHuman);
30 | assertThat(chatHistory.history())
31 | .hasSize(2)
32 | .containsExactly(
33 | messageFromSystem,
34 | firstMessageFromHuman
35 | );
36 |
37 | val firstMessageFromAi = messageFromAiWithTokens(10);
38 | chatHistory.add(firstMessageFromAi);
39 | assertThat(chatHistory.history())
40 | .hasSize(3)
41 | .containsExactly(
42 | messageFromSystem,
43 | firstMessageFromHuman,
44 | firstMessageFromAi
45 | );
46 |
47 | val secondMessageFromHuman = messageFromHumanWithTokens(10);
48 | chatHistory.add(secondMessageFromHuman);
49 | assertThat(chatHistory.history())
50 | .hasSize(3)
51 | .containsExactly(
52 | messageFromSystem,
53 | // firstMessageFromHuman was removed
54 | firstMessageFromAi,
55 | secondMessageFromHuman
56 | );
57 |
58 | val secondMessageFromAi = messageFromAiWithTokens(10);
59 | chatHistory.add(secondMessageFromAi);
60 | assertThat(chatHistory.history())
61 | .hasSize(3)
62 | .containsExactly(
63 | messageFromSystem,
64 | // firstMessageFromAi was removed
65 | secondMessageFromHuman,
66 | secondMessageFromAi
67 | );
68 | }
69 |
70 | @Test
71 | void should_keep_specified_number_of_tokens_in_chat_history_2() {
72 |
73 | val messageFromSystem = systemMessageWithTokens(5);
74 | val chatHistory = SimpleChatHistory.builder()
75 | .systemMessage(messageFromSystem)
76 | .capacityInTokens(20)
77 | .build();
78 | assertThat(chatHistory.history())
79 | .hasSize(1)
80 | .containsExactly(messageFromSystem);
81 |
82 | val firstMessageFromHuman = messageFromHumanWithTokens(10);
83 | chatHistory.add(firstMessageFromHuman);
84 | assertThat(chatHistory.history())
85 | .hasSize(2)
86 | .containsExactly(
87 | messageFromSystem, // 5 tokens
88 | firstMessageFromHuman // 10 tokens
89 | );
90 |
91 | val firstMessageFromAi = messageFromAiWithTokens(10);
92 | chatHistory.add(firstMessageFromAi);
93 | assertThat(chatHistory.history())
94 | .hasSize(2)
95 | .containsExactly(
96 | messageFromSystem, // 5 tokens
97 | // firstMessageFromHuman was removed
98 | firstMessageFromAi // 10 tokens
99 | );
100 |
101 | val secondMessageFromHuman = messageFromAiWithTokens(5);
102 | chatHistory.add(secondMessageFromHuman);
103 | assertThat(chatHistory.history())
104 | .hasSize(3)
105 | .containsExactly(
106 | messageFromSystem, // 5 tokens
107 | // firstMessageFromHuman was removed
108 | firstMessageFromAi, // 10 tokens
109 | secondMessageFromHuman // 5 tokens
110 | );
111 | }
112 |
113 | @Test
114 | void should_keep_200_tokens_in_chat_history_by_default() {
115 |
116 | val messageFromSystem = systemMessageWithTokens(10);
117 | val chatHistory = SimpleChatHistory.builder()
118 | .systemMessage(messageFromSystem)
119 | // user did not configure maxTokensInHistory
120 | .build();
121 | assertThat(chatHistory.history())
122 | .hasSize(1)
123 | .containsExactly(messageFromSystem);
124 |
125 | for (int i = 0; i < 30; i++) {
126 | chatHistory.add(messageFromHumanWithTokens(10));
127 | }
128 |
129 | assertThat(chatHistory.history())
130 | .contains(messageFromSystem, atIndex(0))
131 | .hasSize(20); // 20 messages 10 tokens each = 200 tokens
132 | }
133 |
134 | @Test
135 | void should_keep_specified_number_of_tokens_in_history_without_message_from_system() {
136 |
137 | val chatHistory = SimpleChatHistory.builder()
138 | // user did not configure messageFromSystem
139 | .capacityInTokens(20)
140 | .build();
141 | assertThat(chatHistory.history())
142 | .hasSize(0);
143 |
144 | val firstMessageFromHuman = messageFromHumanWithTokens(10);
145 | chatHistory.add(firstMessageFromHuman);
146 | assertThat(chatHistory.history())
147 | .hasSize(1)
148 | .containsExactly(firstMessageFromHuman);
149 |
150 | val firstMessageFromAi = messageFromAiWithTokens(10);
151 | chatHistory.add(firstMessageFromAi);
152 | assertThat(chatHistory.history())
153 | .hasSize(2)
154 | .containsExactly(
155 | firstMessageFromHuman,
156 | firstMessageFromAi
157 | );
158 |
159 | val secondMessageFromHuman = messageFromHumanWithTokens(10);
160 | chatHistory.add(secondMessageFromHuman);
161 | assertThat(chatHistory.history())
162 | .hasSize(2)
163 | .containsExactly(
164 | // firstMessageFromHuman was removed
165 | firstMessageFromAi,
166 | secondMessageFromHuman
167 | );
168 |
169 | val secondMessageFromAi = messageFromAiWithTokens(10);
170 | chatHistory.add(secondMessageFromAi);
171 | assertThat(chatHistory.history())
172 | .hasSize(2)
173 | .containsExactly(
174 | // firstMessageFromAi was removed
175 | secondMessageFromHuman,
176 | secondMessageFromAi
177 | );
178 | }
179 |
180 | @Test
181 | void should_keep_specified_number_of_messages_in_chat_history() {
182 |
183 | val messageFromSystem = systemMessage("does not matter how many tokens");
184 | val chatHistory = SimpleChatHistory.builder()
185 | .systemMessage(messageFromSystem)
186 | .capacityInMessages(3)
187 | .build();
188 | assertThat(chatHistory.history())
189 | .hasSize(1)
190 | .containsExactly(messageFromSystem);
191 |
192 | val firstMessageFromHuman = userMessage("does not matter how many tokens");
193 | chatHistory.add(firstMessageFromHuman);
194 | assertThat(chatHistory.history())
195 | .hasSize(2)
196 | .containsExactly(
197 | messageFromSystem,
198 | firstMessageFromHuman
199 | );
200 |
201 | val firstMessageFromAi = aiMessage("does not matter how many tokens");
202 | chatHistory.add(firstMessageFromAi);
203 | assertThat(chatHistory.history())
204 | .hasSize(3)
205 | .containsExactly(
206 | messageFromSystem,
207 | firstMessageFromHuman,
208 | firstMessageFromAi
209 | );
210 |
211 | val secondMessageFromHuman = userMessage("does not matter how many tokens");
212 | chatHistory.add(secondMessageFromHuman);
213 | assertThat(chatHistory.history())
214 | .hasSize(3)
215 | .containsExactly(
216 | messageFromSystem,
217 | // firstMessageFromHuman was removed
218 | firstMessageFromAi,
219 | secondMessageFromHuman
220 | );
221 |
222 | val secondMessageFromAi = aiMessage("does not matter how many tokens");
223 | chatHistory.add(secondMessageFromAi);
224 | assertThat(chatHistory.history())
225 | .hasSize(3)
226 | .containsExactly(
227 | messageFromSystem,
228 | // firstMessageFromAi was removed
229 | secondMessageFromHuman,
230 | secondMessageFromAi
231 | );
232 | }
233 |
234 | @Test
235 | void should_keep_specified_number_of_messages_in_chat_history_without_message_from_system() {
236 |
237 | val chatHistory = SimpleChatHistory.builder()
238 | .capacityInMessages(2)
239 | .build();
240 | assertThat(chatHistory.history())
241 | .hasSize(0);
242 |
243 | val firstMessageFromHuman = userMessage("does not matter how many tokens");
244 | chatHistory.add(firstMessageFromHuman);
245 | assertThat(chatHistory.history())
246 | .hasSize(1)
247 | .containsExactly(firstMessageFromHuman);
248 |
249 | val firstMessageFromAi = aiMessage("does not matter how many tokens");
250 | chatHistory.add(firstMessageFromAi);
251 | assertThat(chatHistory.history())
252 | .hasSize(2)
253 | .containsExactly(
254 | firstMessageFromHuman,
255 | firstMessageFromAi
256 | );
257 |
258 | val secondMessageFromHuman = userMessage("does not matter how many tokens");
259 | chatHistory.add(secondMessageFromHuman);
260 | assertThat(chatHistory.history())
261 | .hasSize(2)
262 | .containsExactly(
263 | // firstMessageFromHuman was removed
264 | firstMessageFromAi,
265 | secondMessageFromHuman
266 | );
267 |
268 | val secondMessageFromAi = aiMessage("does not matter how many tokens");
269 | chatHistory.add(secondMessageFromAi);
270 | assertThat(chatHistory.history())
271 | .hasSize(2)
272 | .containsExactly(
273 | // firstMessageFromAi was removed
274 | secondMessageFromHuman,
275 | secondMessageFromAi
276 | );
277 | }
278 |
279 | @Test
280 | void should_keep_specified_number_of_tokens_and_number_of_messages_in_chat_history_1() {
281 |
282 | // In this test we will be using messages with 10 tokens each.
283 | // We will configure maxTokensInHistory(20) and maxMessagesInHistory(3):
284 | // With maxMessagesInHistory(3) we will be able to fit 3 messages into history.
285 | // But due to maxTokensInHistory(20) it will keep only 2.
286 |
287 | val messageFromSystem = systemMessageWithTokens(10);
288 | val chatHistory = SimpleChatHistory.builder()
289 | .systemMessage(messageFromSystem)
290 | .capacityInTokens(20)
291 | .capacityInMessages(3)
292 | .build();
293 | assertThat(chatHistory.history())
294 | .hasSize(1)
295 | .containsExactly(messageFromSystem);
296 |
297 | val firstMessageFromHuman = messageFromHumanWithTokens(10);
298 | chatHistory.add(firstMessageFromHuman);
299 | assertThat(chatHistory.history())
300 | .hasSize(2)
301 | .containsExactly(
302 | messageFromSystem,
303 | firstMessageFromHuman
304 | );
305 |
306 | val firstMessageFromAi = messageFromAiWithTokens(10);
307 | chatHistory.add(firstMessageFromAi);
308 | assertThat(chatHistory.history())
309 | .hasSize(2)
310 | .containsExactly(
311 | messageFromSystem,
312 | // firstMessageFromHuman was removed
313 | firstMessageFromAi
314 | );
315 |
316 | val secondMessageFromHuman = messageFromHumanWithTokens(10);
317 | chatHistory.add(secondMessageFromHuman);
318 | assertThat(chatHistory.history())
319 | .hasSize(2)
320 | .containsExactly(
321 | messageFromSystem,
322 | // firstMessageFromAi was removed
323 | secondMessageFromHuman
324 | );
325 | }
326 |
327 | @Test
328 | void should_keep_specified_number_of_tokens_and_number_of_messages_in_chat_history_2() {
329 |
330 | // In this test we will be using messages with 10 tokens each.
331 | // We will configure maxMessagesInHistory(2) and maxTokensInHistory(30):
332 | // With maxTokensInHistory(30) we will be able to fit 3 messages into history.
333 | // But due to maxMessagesInHistory(2) it will keep only 2.
334 |
335 | val messageFromSystem = systemMessageWithTokens(10);
336 | val chatHistory = SimpleChatHistory.builder()
337 | .systemMessage(messageFromSystem)
338 | .capacityInTokens(30)
339 | .capacityInMessages(2)
340 | .build();
341 | assertThat(chatHistory.history())
342 | .hasSize(1)
343 | .containsExactly(messageFromSystem);
344 |
345 | val firstMessageFromHuman = messageFromHumanWithTokens(10);
346 | chatHistory.add(firstMessageFromHuman);
347 | assertThat(chatHistory.history())
348 | .hasSize(2)
349 | .containsExactly(
350 | messageFromSystem,
351 | firstMessageFromHuman
352 | );
353 |
354 | val firstMessageFromAi = messageFromAiWithTokens(10);
355 | chatHistory.add(firstMessageFromAi);
356 | assertThat(chatHistory.history())
357 | .hasSize(2)
358 | .containsExactly(
359 | messageFromSystem,
360 | // firstMessageFromHuman was removed
361 | firstMessageFromAi
362 | );
363 |
364 | val secondMessageFromHuman = messageFromHumanWithTokens(10);
365 | chatHistory.add(secondMessageFromHuman);
366 | assertThat(chatHistory.history())
367 | .hasSize(2)
368 | .containsExactly(
369 | messageFromSystem,
370 | // firstMessageFromAi was removed
371 | secondMessageFromHuman
372 | );
373 | }
374 |
375 | @Test
376 | void should_load_previous_messages_with_token_restriction() {
377 |
378 | val previousMessages = asList(
379 | messageFromHumanWithTokens(10),
380 | messageFromAiWithTokens(10),
381 | messageFromHumanWithTokens(10),
382 | messageFromAiWithTokens(10)
383 | );
384 |
385 | val chatHistory = SimpleChatHistory.builder()
386 | .previousMessages(previousMessages)
387 | .capacityInTokens(30)
388 | .build();
389 |
390 | assertThat(chatHistory.history())
391 | .hasSize(3);
392 | }
393 |
394 | @Test
395 | void should_load_previous_messages_with_message_restriction() {
396 |
397 | val previousMessages = asList(
398 | messageFromHumanWithTokens(10),
399 | messageFromAiWithTokens(10),
400 | messageFromHumanWithTokens(10),
401 | messageFromAiWithTokens(10)
402 | );
403 |
404 | val chatHistory = SimpleChatHistory.builder()
405 | .previousMessages(previousMessages)
406 | .capacityInMessages(3)
407 | .build();
408 |
409 | assertThat(chatHistory.history())
410 | .hasSize(3);
411 | }
412 |
413 | @Test
414 | void should_keep_all_history_without_restrictions() {
415 |
416 | val chatHistory = SimpleChatHistory.builder()
417 | .removeCapacityRestrictionInTokens()
418 | .build();
419 |
420 | for (int i = 0; i < 1000; i++) {
421 | chatHistory.add(messageFromHumanWithTokens(1000));
422 | }
423 |
424 | assertThat(chatHistory.history())
425 | .hasSize(1000);
426 | }
427 | }
--------------------------------------------------------------------------------
/src/test/java/dev/ai4j/utils/TestUtils.java:
--------------------------------------------------------------------------------
1 | package dev.ai4j.utils;
2 |
3 | import dev.ai4j.Tokenizer;
4 | import dev.ai4j.chat.AiMessage;
5 | import dev.ai4j.chat.SystemMessage;
6 | import dev.ai4j.chat.UserMessage;
7 | import lombok.val;
8 | import org.junit.jupiter.api.Test;
9 | import org.junit.jupiter.params.ParameterizedTest;
10 | import org.junit.jupiter.params.provider.ValueSource;
11 |
12 | import java.util.ArrayList;
13 | import java.util.List;
14 |
15 | import static dev.ai4j.chat.AiMessage.aiMessage;
16 | import static dev.ai4j.chat.SystemMessage.systemMessage;
17 | import static dev.ai4j.chat.UserMessage.userMessage;
18 | import static org.assertj.core.api.Assertions.assertThat;
19 |
20 | public class TestUtils {
21 |
22 | private static final int EXTRA_TOKENS_PER_CHAT_MESSAGE = 5;
23 |
24 | @ParameterizedTest
25 | @ValueSource(ints = {5, 10, 25, 50, 100, 250, 500, 1000})
26 | void should_create_message_from_system_with_tokens(int numberOfTokens) {
27 | val messageFromSystem = systemMessageWithTokens(numberOfTokens);
28 | assertThat(messageFromSystem.numberOfTokens()).isEqualTo(numberOfTokens);
29 | }
30 |
31 | public static SystemMessage systemMessageWithTokens(int numberOfTokens) {
32 | return systemMessage(generateTokens(numberOfTokens - EXTRA_TOKENS_PER_CHAT_MESSAGE));
33 | }
34 |
35 | @ParameterizedTest
36 | @ValueSource(ints = {5, 10, 25, 50, 100, 250, 500, 1000})
37 | void should_create_message_from_human_with_tokens(int numberOfTokens) {
38 | val messageFromHuman = messageFromHumanWithTokens(numberOfTokens);
39 | assertThat(messageFromHuman.numberOfTokens()).isEqualTo(numberOfTokens);
40 | }
41 |
42 | public static UserMessage messageFromHumanWithTokens(int numberOfTokens) {
43 | return userMessage(generateTokens(numberOfTokens - EXTRA_TOKENS_PER_CHAT_MESSAGE));
44 | }
45 |
46 | @ParameterizedTest
47 | @ValueSource(ints = {5, 10, 25, 50, 100, 250, 500, 1000})
48 | void should_create_message_from_ai_with_tokens(int numberOfTokens) {
49 | AiMessage messageFromAi = messageFromAiWithTokens(numberOfTokens);
50 | assertThat(messageFromAi.numberOfTokens()).isEqualTo(numberOfTokens);
51 | }
52 |
53 | public static AiMessage messageFromAiWithTokens(int numberOfTokens) {
54 | return aiMessage(generateTokens(numberOfTokens - EXTRA_TOKENS_PER_CHAT_MESSAGE));
55 | }
56 |
57 | @ParameterizedTest
58 | @ValueSource(ints = {1, 2, 5, 10, 25, 50, 100, 250, 500, 1000})
59 | void should_generate_tokens(int numberOfTokens) {
60 | val tokenizer = new Tokenizer();
61 |
62 | val tokens = generateTokens(numberOfTokens);
63 |
64 | assertThat(tokenizer.countTokens(tokens)).isEqualTo(numberOfTokens);
65 | }
66 |
67 | public static String generateTokens(int n) {
68 | val tokenizer = new Tokenizer();
69 | val text = String.join(" ", repeat("one two", n));
70 | return tokenizer.decode(tokenizer.encode(text, n));
71 | }
72 |
73 | @Test
74 | void should_repeat_n_times() {
75 | assertThat(repeat("word", 1))
76 | .hasSize(1)
77 | .containsExactly("word");
78 |
79 | assertThat(repeat("word", 2))
80 | .hasSize(2)
81 | .containsExactly("word", "word");
82 |
83 | assertThat(repeat("word", 3))
84 | .hasSize(3)
85 | .containsExactly("word", "word", "word");
86 | }
87 |
88 | public static List repeat(String s, int n) {
89 | val result = new ArrayList();
90 | for (int i = 0; i < n; i++) {
91 | result.add(s);
92 | }
93 | return result;
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/src/test/java/test-file.txt:
--------------------------------------------------------------------------------
1 | test
2 | content
3 |
--------------------------------------------------------------------------------