├── .gitignore ├── .sdkmanrc ├── LICENSE ├── README.md ├── build.gradle ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── settings.gradle └── src ├── main ├── java │ └── dev │ │ └── surly │ │ └── ai │ │ └── collab │ │ ├── SpringAICollabApplication.java │ │ ├── Team.java │ │ ├── agent │ │ ├── Agent.java │ │ ├── AgentMetadata.java │ │ ├── AgentRegistry.java │ │ ├── AgentService.java │ │ └── example │ │ │ ├── Artist.java │ │ │ ├── Biographer.java │ │ │ ├── BusinessAnalyst.java │ │ │ ├── CareerCoach.java │ │ │ ├── Chronologist.java │ │ │ ├── CodeLinguist.java │ │ │ ├── ComputerAssistant.java │ │ │ ├── DefaultLLMAgent.java │ │ │ ├── Greeter.java │ │ │ ├── HealthcarePatientAdvocate.java │ │ │ ├── JobRecruiter.java │ │ │ ├── Librarian.java │ │ │ ├── Mathematician.java │ │ │ ├── ProductManager.java │ │ │ ├── SoftwareEngineer.java │ │ │ ├── SoftwareTester.java │ │ │ └── model │ │ │ ├── CompanyDetail.java │ │ │ ├── JobRateRequest.java │ │ │ ├── JobRateResponse.java │ │ │ └── MathRequest.java │ │ ├── client │ │ └── RestClientCustomizations.java │ │ ├── controller │ │ ├── TeamController.java │ │ ├── TeamRestController.java │ │ └── model │ │ │ └── TeamForm.java │ │ ├── embedding │ │ └── HealthBenefitsDocumentEtlPipeline.java │ │ ├── exception │ │ ├── ToolInvocationException.java │ │ └── ToolNotFoundException.java │ │ ├── flow │ │ ├── Flow.java │ │ ├── FlowExecution.java │ │ ├── FlowExecutionResult.java │ │ ├── ParallelFlow.java │ │ └── SequentialFlow.java │ │ ├── log │ │ └── LoggingInterceptor.java │ │ ├── nlp │ │ └── NlpService.java │ │ ├── statemachine │ │ ├── EventPublisher.java │ │ ├── Events.java │ │ ├── SimpleStateMachine.java │ │ ├── SimpleStateMachineService.java │ │ ├── StateMachineListener.java │ │ ├── States.java │ │ └── TaskEvent.java │ │ ├── task │ │ ├── AgentTaskExecutor.java │ │ ├── Task.java │ │ ├── TaskAssignment.java │ │ ├── TaskDeconstructor.java │ │ ├── TaskError.java │ │ ├── TaskPlanner.java │ │ 
├── TaskResult.java │ │ └── TaskTiming.java │ │ ├── tool │ │ ├── Tool.java │ │ └── ToolMetadata.java │ │ ├── util │ │ └── ConversionUtils.java │ │ ├── validation │ │ ├── CompositeTaskResultValidator.java │ │ ├── TaskResultValidator.java │ │ └── Validator.java │ │ ├── vectorstore │ │ └── VectorStoreConfig.java │ │ └── workflow │ │ ├── WorkflowCoordinator.java │ │ ├── WorkflowState.java │ │ └── WorkflowStateMachine.java └── resources │ ├── application.properties │ ├── documents │ └── health-benefits.pdf │ ├── prompts │ ├── agent-company-focus.st │ ├── agent-determine-programming-language.st │ ├── agent-job-rating.st │ ├── choose-tool-args-no-format.st │ ├── choose-tool-args.st │ ├── choose-tool.st │ ├── task-planner-choose-agent.st │ └── task-planner-choose-agents.st │ └── templates │ ├── 404.html │ ├── error.html │ ├── fragments │ ├── footer.html │ └── header.html │ ├── index.html │ └── layouts │ └── default.html └── test └── java └── dev └── surly └── ai └── collab ├── agent └── AgentRegistryTest.java ├── flow └── FlowExecutionTest.java └── task └── TaskDeconstructorTest.java /.gitignore: -------------------------------------------------------------------------------- 1 | HELP.md 2 | .gradle 3 | build/ 4 | !gradle/wrapper/gradle-wrapper.jar 5 | !**/src/main/**/build/ 6 | !**/src/test/**/build/ 7 | 8 | ### STS ### 9 | .apt_generated 10 | .classpath 11 | .factorypath 12 | .project 13 | .settings 14 | .springBeans 15 | .sts4-cache 16 | bin/ 17 | !**/src/main/**/bin/ 18 | !**/src/test/**/bin/ 19 | 20 | ### IntelliJ IDEA ### 21 | .idea 22 | *.iws 23 | *.iml 24 | *.ipr 25 | out/ 26 | !**/src/main/**/out/ 27 | !**/src/test/**/out/ 28 | 29 | .env 30 | -------------------------------------------------------------------------------- /.sdkmanrc: -------------------------------------------------------------------------------- 1 | # Enable auto-env through the sdkman_auto_env config 2 | # Add key=value pairs of SDKs to use below 3 | java=21.0.3-amzn 4 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spring AI Collab 2 | 3 | An agent framework using [Spring AI](https://spring.io/projects/spring-ai). 4 | 5 | **NOTE**: This is a proof of concept and work is underway to refactor the project into a library that can be used in other Spring projects. 6 | 7 | ## Features 8 | 9 | - Support for multiple agents and tools via simple annotations. 10 | - Leverages [Spring AI](https://spring.io/projects/spring-ai) for abstractions. 11 | - Automatically selects an agent and tools based on the given task. 12 | - Web chat interface to perform tasks and optionally assign an agent. 13 | - If no agent is specified, the underlying LLM is used to choose an agent based on the task. 14 | 15 | ## Roadmap 16 | 17 | Note: Some of the roadmap features depend on chat message history, which is not yet available in Spring AI. 18 | 19 | - Process multiple tasks at once. 20 | - Compose "teams" of agents that work together to accomplish tasks.
21 | - Add JVM code generation and execution (Java, Kotlin). 22 | 23 | ## Requirements 24 | 25 | This project uses [OpenAI](https://openai.com/) as the default LLM. 26 | 27 | - Set the `OPENAI_API_KEY` environment variable. 28 | 29 | ## Build and Test 30 | 31 | To build and run tests: 32 | ```shell 33 | ./gradlew clean build 34 | ``` 35 | 36 | ## Inspired by 37 | 38 | - [Microsoft's AutoGen](https://www.microsoft.com/en-us/research/project/autogen/) 39 | - [CrewAI](https://www.crewai.com/) 40 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | import java.util.concurrent.TimeUnit 2 | 3 | plugins { 4 | id "java" 5 | id "org.springframework.boot" version "3.3.1" 6 | id "io.spring.dependency-management" version "1.1.5" 7 | } 8 | 9 | group = 'dev.surly.ai.collab' 10 | version = '0.0.1' 11 | 12 | java { 13 | toolchain { 14 | languageVersion = JavaLanguageVersion.of(21) 15 | } 16 | } 17 | 18 | allprojects { 19 | repositories { 20 | // mavenLocal() 21 | mavenCentral() 22 | maven { url 'https://jitpack.io' } 23 | maven { url 'https://repo.spring.io/milestone' } 24 | maven { url 'https://repo.spring.io/snapshot' } 25 | } 26 | } 27 | 28 | configurations.configureEach { 29 | resolutionStrategy.cacheChangingModulesFor(0, TimeUnit.SECONDS) 30 | } 31 | 32 | ext { 33 | set('springAiVersion', "1.0.0-M1") 34 | } 35 | 36 | dependencyManagement { 37 | imports { 38 | mavenBom "org.springframework.ai:spring-ai-bom:${springAiVersion}" 39 | } 40 | } 41 | 42 | dependencies { 43 | 44 | implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter' 45 | 46 | implementation 'org.springframework.ai:spring-ai-pdf-document-reader' 47 | // use the latest version of pdfbox for bug fixes 48 | implementation 'org.apache.pdfbox:pdfbox:3.0.2' 49 | 50 | developmentOnly 'org.springframework.boot:spring-boot-devtools' 51 | 52 | implementation
'org.springframework.boot:spring-boot-starter-aop' 53 | implementation 'org.springframework.boot:spring-boot-starter-actuator' 54 | implementation 'org.springframework.boot:spring-boot-starter-thymeleaf' 55 | implementation 'org.springframework.boot:spring-boot-starter-web' 56 | 57 | implementation 'org.springdoc:springdoc-openapi-starter-webmvc-ui:2.6.0' 58 | 59 | implementation 'org.springframework.statemachine:spring-statemachine-core:4.0.0' 60 | implementation 'io.cloudevents:cloudevents-core:4.0.1' 61 | 62 | implementation 'edu.stanford.nlp:stanford-corenlp:4.5.7' 63 | implementation 'edu.stanford.nlp:stanford-corenlp:4.5.7:models' 64 | 65 | implementation 'org.apache.commons:commons-text:1.12.0' 66 | implementation 'com.vladsch.flexmark:flexmark-html2md-converter:0.64.8' 67 | 68 | compileOnly 'org.projectlombok:lombok:1.18.34' 69 | annotationProcessor 'org.projectlombok:lombok:1.18.34' 70 | 71 | testImplementation 'org.springframework.boot:spring-boot-test' 72 | testRuntimeOnly 'org.junit.platform:junit-platform-launcher' 73 | testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.0' 74 | testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.10.0' 75 | } 76 | 77 | test { 78 | useJUnitPlatform() 79 | } 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thesurlydev/spring-ai-collab/689fa1650406d2cb9149715b2ac30de8f64c4108/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-8.8-bin.zip 4 | networkTimeout=10000 5 | 
validateDistributionUrl=true 6 | zipStoreBase=GRADLE_USER_HOME 7 | zipStorePath=wrapper/dists 8 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # 4 | # Copyright © 2015-2021 the original authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | ############################################################################## 20 | # 21 | # Gradle start up script for POSIX generated by Gradle. 22 | # 23 | # Important for running: 24 | # 25 | # (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is 26 | # noncompliant, but you have some other compliant shell such as ksh or 27 | # bash, then to run this script, type that shell name before the whole 28 | # command line, like: 29 | # 30 | # ksh Gradle 31 | # 32 | # Busybox and similar reduced shells will NOT work, because this script 33 | # requires all of these POSIX shell features: 34 | # * functions; 35 | # * expansions «$var», «${var}», «${var:-default}», «${var+SET}», 36 | # «${var#prefix}», «${var%suffix}», and «$( cmd )»; 37 | # * compound commands having a testable exit status, especially «case»; 38 | # * various built-in commands including «command», «set», and «ulimit». 
39 | # 40 | # Important for patching: 41 | # 42 | # (2) This script targets any POSIX shell, so it avoids extensions provided 43 | # by Bash, Ksh, etc; in particular arrays are avoided. 44 | # 45 | # The "traditional" practice of packing multiple parameters into a 46 | # space-separated string is a well documented source of bugs and security 47 | # problems, so this is (mostly) avoided, by progressively accumulating 48 | # options in "$@", and eventually passing that to Java. 49 | # 50 | # Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, 51 | # and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; 52 | # see the in-line comments for details. 53 | # 54 | # There are tweaks for specific operating systems such as AIX, CygWin, 55 | # Darwin, MinGW, and NonStop. 56 | # 57 | # (3) This script is generated from the Groovy template 58 | # https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt 59 | # within the Gradle project. 60 | # 61 | # You can find Gradle at https://github.com/gradle/gradle/. 62 | # 63 | ############################################################################## 64 | 65 | # Attempt to set APP_HOME 66 | 67 | # Resolve links: $0 may be a link 68 | app_path=$0 69 | 70 | # Need this for daisy-chained symlinks. 
71 | while 72 | APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path 73 | [ -h "$app_path" ] 74 | do 75 | ls=$( ls -ld "$app_path" ) 76 | link=${ls#*' -> '} 77 | case $link in #( 78 | /*) app_path=$link ;; #( 79 | *) app_path=$APP_HOME$link ;; 80 | esac 81 | done 82 | 83 | # This is normally unused 84 | # shellcheck disable=SC2034 85 | APP_BASE_NAME=${0##*/} 86 | # Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) 87 | APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit 88 | 89 | # Use the maximum available, or set MAX_FD != -1 to use that value. 90 | MAX_FD=maximum 91 | 92 | warn () { 93 | echo "$*" 94 | } >&2 95 | 96 | die () { 97 | echo 98 | echo "$*" 99 | echo 100 | exit 1 101 | } >&2 102 | 103 | # OS specific support (must be 'true' or 'false'). 104 | cygwin=false 105 | msys=false 106 | darwin=false 107 | nonstop=false 108 | case "$( uname )" in #( 109 | CYGWIN* ) cygwin=true ;; #( 110 | Darwin* ) darwin=true ;; #( 111 | MSYS* | MINGW* ) msys=true ;; #( 112 | NONSTOP* ) nonstop=true ;; 113 | esac 114 | 115 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 116 | 117 | 118 | # Determine the Java command to use to start the JVM. 119 | if [ -n "$JAVA_HOME" ] ; then 120 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 121 | # IBM's JDK on AIX uses strange locations for the executables 122 | JAVACMD=$JAVA_HOME/jre/sh/java 123 | else 124 | JAVACMD=$JAVA_HOME/bin/java 125 | fi 126 | if [ ! -x "$JAVACMD" ] ; then 127 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 128 | 129 | Please set the JAVA_HOME variable in your environment to match the 130 | location of your Java installation." 131 | fi 132 | else 133 | JAVACMD=java 134 | if ! command -v java >/dev/null 2>&1 135 | then 136 | die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 
137 | 138 | Please set the JAVA_HOME variable in your environment to match the 139 | location of your Java installation." 140 | fi 141 | fi 142 | 143 | # Increase the maximum file descriptors if we can. 144 | if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then 145 | case $MAX_FD in #( 146 | max*) 147 | # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. 148 | # shellcheck disable=SC2039,SC3045 149 | MAX_FD=$( ulimit -H -n ) || 150 | warn "Could not query maximum file descriptor limit" 151 | esac 152 | case $MAX_FD in #( 153 | '' | soft) :;; #( 154 | *) 155 | # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. 156 | # shellcheck disable=SC2039,SC3045 157 | ulimit -n "$MAX_FD" || 158 | warn "Could not set maximum file descriptor limit to $MAX_FD" 159 | esac 160 | fi 161 | 162 | # Collect all arguments for the java command, stacking in reverse order: 163 | # * args from the command line 164 | # * the main class name 165 | # * -classpath 166 | # * -D...appname settings 167 | # * --module-path (only if needed) 168 | # * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. 
169 | 170 | # For Cygwin or MSYS, switch paths to Windows format before running java 171 | if "$cygwin" || "$msys" ; then 172 | APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) 173 | CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) 174 | 175 | JAVACMD=$( cygpath --unix "$JAVACMD" ) 176 | 177 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 178 | for arg do 179 | if 180 | case $arg in #( 181 | -*) false ;; # don't mess with options #( 182 | /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath 183 | [ -e "$t" ] ;; #( 184 | *) false ;; 185 | esac 186 | then 187 | arg=$( cygpath --path --ignore --mixed "$arg" ) 188 | fi 189 | # Roll the args list around exactly as many times as the number of 190 | # args, so each arg winds up back in the position where it started, but 191 | # possibly modified. 192 | # 193 | # NB: a `for` loop captures its iteration list before it begins, so 194 | # changing the positional parameters here affects neither the number of 195 | # iterations, nor the values presented in `arg`. 196 | shift # remove old arg 197 | set -- "$@" "$arg" # push replacement arg 198 | done 199 | fi 200 | 201 | 202 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 203 | DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' 204 | 205 | # Collect all arguments for the java command: 206 | # * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, 207 | # and any embedded shellness will be escaped. 208 | # * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be 209 | # treated as '${Hostname}' itself on the command line. 210 | 211 | set -- \ 212 | "-Dorg.gradle.appname=$APP_BASE_NAME" \ 213 | -classpath "$CLASSPATH" \ 214 | org.gradle.wrapper.GradleWrapperMain \ 215 | "$@" 216 | 217 | # Stop when "xargs" is not available. 218 | if ! 
command -v xargs >/dev/null 2>&1 219 | then 220 | die "xargs is not available" 221 | fi 222 | 223 | # Use "xargs" to parse quoted args. 224 | # 225 | # With -n1 it outputs one arg per line, with the quotes and backslashes removed. 226 | # 227 | # In Bash we could simply go: 228 | # 229 | # readarray ARGS < <( xargs -n1 <<<"$var" ) && 230 | # set -- "${ARGS[@]}" "$@" 231 | # 232 | # but POSIX shell has neither arrays nor command substitution, so instead we 233 | # post-process each arg (as a line of input to sed) to backslash-escape any 234 | # character that might be a shell metacharacter, then use eval to reverse 235 | # that process (while maintaining the separation between arguments), and wrap 236 | # the whole thing up as a single "set" statement. 237 | # 238 | # This will of course break if any of these variables contains a newline or 239 | # an unmatched quote. 240 | # 241 | 242 | eval "set -- $( 243 | printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | 244 | xargs -n1 | 245 | sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | 246 | tr '\n' ' ' 247 | )" '"$@"' 248 | 249 | exec "$JAVACMD" "$@" 250 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 
15 | @rem 16 | 17 | @if "%DEBUG%"=="" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%"=="" set DIRNAME=. 29 | @rem This is normally unused 30 | set APP_BASE_NAME=%~n0 31 | set APP_HOME=%DIRNAME% 32 | 33 | @rem Resolve any "." and ".." in APP_HOME to make it shorter. 34 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi 35 | 36 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 37 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 38 | 39 | @rem Find java.exe 40 | if defined JAVA_HOME goto findJavaFromJavaHome 41 | 42 | set JAVA_EXE=java.exe 43 | %JAVA_EXE% -version >NUL 2>&1 44 | if %ERRORLEVEL% equ 0 goto execute 45 | 46 | echo. 1>&2 47 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 48 | echo. 1>&2 49 | echo Please set the JAVA_HOME variable in your environment to match the 1>&2 50 | echo location of your Java installation. 1>&2 51 | 52 | goto fail 53 | 54 | :findJavaFromJavaHome 55 | set JAVA_HOME=%JAVA_HOME:"=% 56 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 57 | 58 | if exist "%JAVA_EXE%" goto execute 59 | 60 | echo. 1>&2 61 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 62 | echo. 1>&2 63 | echo Please set the JAVA_HOME variable in your environment to match the 1>&2 64 | echo location of your Java installation. 
1>&2 65 | 66 | goto fail 67 | 68 | :execute 69 | @rem Setup the command line 70 | 71 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 72 | 73 | 74 | @rem Execute Gradle 75 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* 76 | 77 | :end 78 | @rem End local scope for the variables with windows NT shell 79 | if %ERRORLEVEL% equ 0 goto mainEnd 80 | 81 | :fail 82 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 83 | rem the _cmd.exe /c_ return code! 84 | set EXIT_CODE=%ERRORLEVEL% 85 | if %EXIT_CODE% equ 0 set EXIT_CODE=1 86 | if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% 87 | exit /b %EXIT_CODE% 88 | 89 | :mainEnd 90 | if "%OS%"=="Windows_NT" endlocal 91 | 92 | :omega 93 | -------------------------------------------------------------------------------- /settings.gradle: -------------------------------------------------------------------------------- 1 | pluginManagement { 2 | repositories { 3 | gradlePluginPortal() 4 | maven { url 'https://s01.oss.sonatype.org/content/repositories/snapshots' } 5 | } 6 | } 7 | 8 | rootProject.name = 'spring-ai-collab' 9 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/SpringAICollabApplication.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab; 2 | 3 | import org.springframework.boot.SpringApplication; 4 | import org.springframework.boot.autoconfigure.SpringBootApplication; 5 | 6 | @SpringBootApplication 7 | public class SpringAICollabApplication { 8 | 9 | public static void main(String[] args) { 10 | SpringApplication.run(SpringAICollabApplication.class, args); 11 | } 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/Team.java: 
-------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab; 2 | 3 | import dev.surly.ai.collab.task.AgentTaskExecutor; 4 | import dev.surly.ai.collab.task.Task; 5 | import dev.surly.ai.collab.task.TaskResult; 6 | import lombok.RequiredArgsConstructor; 7 | import lombok.extern.slf4j.Slf4j; 8 | import org.springframework.stereotype.Component; 9 | 10 | import java.util.ArrayList; 11 | import java.util.List; 12 | 13 | @Component 14 | @RequiredArgsConstructor 15 | @Slf4j 16 | public class Team { 17 | 18 | private final AgentTaskExecutor agentTaskExecutor; 19 | 20 | private final List tasks = new ArrayList<>(); 21 | 22 | public void clearTasks() { 23 | tasks.clear(); 24 | } 25 | 26 | public Team addTasks(List tasks) { 27 | clearTasks(); 28 | this.tasks.addAll(tasks); 29 | return this; 30 | } 31 | 32 | public List kickoff() { 33 | return tasks.stream() 34 | .map(agentTaskExecutor::executeTask) 35 | .toList(); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/Agent.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent; 2 | 3 | import org.springframework.stereotype.Component; 4 | 5 | import java.lang.annotation.ElementType; 6 | import java.lang.annotation.Retention; 7 | import java.lang.annotation.RetentionPolicy; 8 | import java.lang.annotation.Target; 9 | 10 | @Component 11 | @Retention(RetentionPolicy.RUNTIME) 12 | @Target(ElementType.TYPE) // Apply to classes 13 | public @interface Agent { 14 | String goal(); 15 | String background() default ""; 16 | boolean disabled() default false; 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/AgentMetadata.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent; 2 | 3 
| import dev.surly.ai.collab.tool.ToolMetadata; 4 | 5 | import java.util.Map; 6 | 7 | public record AgentMetadata(String name, String goal, String background, Map tools) { 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/AgentRegistry.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent; 2 | 3 | import dev.surly.ai.collab.tool.Tool; 4 | import dev.surly.ai.collab.tool.ToolMetadata; 5 | import jakarta.annotation.PostConstruct; 6 | import lombok.extern.slf4j.Slf4j; 7 | import org.springframework.context.ApplicationContext; 8 | import org.springframework.core.annotation.AnnotationUtils; 9 | import org.springframework.stereotype.Component; 10 | 11 | import java.lang.reflect.Method; 12 | import java.util.*; 13 | 14 | @Component 15 | @Slf4j 16 | public class AgentRegistry { 17 | 18 | private final Map allAgents = new HashMap<>(); 19 | private final ApplicationContext applicationContext; 20 | 21 | public AgentRegistry(ApplicationContext applicationContext) { 22 | this.applicationContext = applicationContext; 23 | } 24 | 25 | @PostConstruct 26 | public void initializeAgents() { 27 | Map agentBeans = applicationContext.getBeansOfType(AgentService.class); 28 | agentBeans.values().forEach(agent -> { 29 | addTools(agent); 30 | registerAgent(agent); 31 | }); 32 | } 33 | 34 | private void addTools(AgentService agent) { 35 | Arrays.stream(agent.getClass().getDeclaredMethods()) 36 | .filter(this::hasToolAnnotation) 37 | .map(this::createTool) 38 | .forEach(agent::addTool); 39 | } 40 | 41 | private boolean hasToolAnnotation(Method method) { 42 | return AnnotationUtils.findAnnotation(method, Tool.class) != null; 43 | } 44 | 45 | private ToolMetadata createTool(Method method) { 46 | Tool tool = AnnotationUtils.findAnnotation(method, Tool.class); 47 | String name = Objects.requireNonNull(tool).name() != null ? 
tool.name() : ""; 48 | String description = tool.description() != null ? tool.description() : ""; 49 | boolean disabled = tool.disabled(); 50 | return new ToolMetadata(name, description, method, disabled); 51 | } 52 | 53 | private void registerAgent(AgentService agent) { 54 | allAgents.put(agent.getName(), agent); 55 | } 56 | 57 | public AgentService getAgent(String agentName) { 58 | return allAgents.get(agentName); 59 | } 60 | 61 | public Map allAgents() { 62 | return allAgents; 63 | } 64 | 65 | public Map enabledAgents() { 66 | return allAgents.entrySet().stream() 67 | .filter(entry -> !entry.getValue().getDisabled()) 68 | .collect(HashMap::new, (m, v) -> m.put(v.getKey(), v.getValue()), Map::putAll); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/AgentService.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent; 2 | 3 | import dev.surly.ai.collab.exception.ToolInvocationException; 4 | import dev.surly.ai.collab.exception.ToolNotFoundException; 5 | import dev.surly.ai.collab.task.Task; 6 | import dev.surly.ai.collab.task.TaskResult; 7 | import dev.surly.ai.collab.task.TaskTiming; 8 | import dev.surly.ai.collab.tool.ToolMetadata; 9 | import dev.surly.ai.collab.workflow.WorkflowState; 10 | import lombok.Getter; 11 | import lombok.RequiredArgsConstructor; 12 | import lombok.ToString; 13 | import lombok.extern.slf4j.Slf4j; 14 | import org.springframework.ai.chat.messages.Message; 15 | import org.springframework.ai.chat.messages.UserMessage; 16 | import org.springframework.ai.chat.model.ChatModel; 17 | import org.springframework.ai.chat.model.Generation; 18 | import org.springframework.ai.chat.prompt.Prompt; 19 | import org.springframework.ai.chat.prompt.PromptTemplate; 20 | import org.springframework.ai.chat.prompt.SystemPromptTemplate; 21 | import org.springframework.ai.converter.BeanOutputConverter; 22 
| import org.springframework.ai.openai.OpenAiImageModel; 23 | import org.springframework.beans.factory.annotation.Autowired; 24 | import org.springframework.beans.factory.annotation.Value; 25 | import org.springframework.core.io.Resource; 26 | 27 | import java.lang.reflect.Method; 28 | import java.time.ZonedDateTime; 29 | import java.time.format.DateTimeFormatter; 30 | import java.time.format.FormatStyle; 31 | import java.util.ArrayList; 32 | import java.util.HashMap; 33 | import java.util.List; 34 | import java.util.Map; 35 | 36 | @RequiredArgsConstructor 37 | @ToString 38 | @Slf4j 39 | public class AgentService { 40 | 41 | @Autowired 42 | protected OpenAiImageModel openAiImageModel; 43 | 44 | @Autowired 45 | protected ChatModel chatModel; 46 | 47 | @Value("classpath:/prompts/choose-tool.st") 48 | private Resource chooseToolUserPrompt; 49 | 50 | @Value("classpath:/prompts/choose-tool-args.st") 51 | private Resource chooseToolArgsUserPrompt; 52 | 53 | @Value("classpath:/prompts/choose-tool-args-no-format.st") 54 | private Resource chooseToolArgsUserPromptNoFormat; 55 | 56 | @Getter 57 | private final String name; 58 | @Getter 59 | private final String goal; 60 | @Getter 61 | private final String background; 62 | @Getter 63 | private final Boolean disabled; 64 | @Getter 65 | private final Map tools = new HashMap<>(); 66 | @Getter 67 | List messages = new ArrayList<>(); 68 | 69 | public boolean canPerform(WorkflowState workflowState) { 70 | // TODO implement this method 71 | return true; 72 | } 73 | 74 | public void performTask(String task, Map context) { 75 | // TODO implement this method 76 | } 77 | 78 | public Prompt createPrompt(Resource promptTemplateResource, 79 | Map promptModel) { 80 | PromptTemplate promptTemplate = new PromptTemplate(promptTemplateResource, promptModel); 81 | return promptTemplate.create(); 82 | } 83 | 84 | public String callPromptForString(Prompt prompt) { 85 | Generation generation = chatModel.call(prompt).getResult(); 86 | return 
generation.getOutput().getContent(); 87 | } 88 | 89 | public Object callPromptForBean(Prompt prompt, BeanOutputConverter beanOutputConverter) { 90 | Generation generation = chatModel.call(prompt).getResult(); 91 | String out = generation.getOutput().getContent(); 92 | return beanOutputConverter.convert(out); 93 | } 94 | 95 | public void addSystemMessage(String message) { 96 | SystemPromptTemplate systemTemplate = new SystemPromptTemplate(message); 97 | messages.add(systemTemplate.createMessage()); 98 | } 99 | 100 | public void addUserMessage(String message) { 101 | UserMessage userMessage = new UserMessage(message); 102 | messages.add(userMessage); 103 | } 104 | 105 | public AgentService() { 106 | this.name = this.getClass().getSimpleName(); 107 | if (this.getClass().isAnnotationPresent(Agent.class)) { 108 | Agent annotation = this.getClass().getAnnotation(Agent.class); 109 | this.goal = annotation.goal(); 110 | this.background = annotation.background(); 111 | this.disabled = annotation.disabled(); 112 | } else { 113 | throw new IllegalStateException("Agent annotation is required on Agent classes"); 114 | } 115 | } 116 | 117 | public void addTool(ToolMetadata toolMetadata) { 118 | tools.put(toolMetadata.name(), toolMetadata); 119 | } 120 | 121 | public TaskResult executeTask(Task task) throws ToolInvocationException { 122 | log.info("Executing task: {}", task); 123 | List timings = new ArrayList<>(); 124 | 125 | if (tools.isEmpty()) { 126 | log.info("{} agent has no tools configured, executing task via LLM", this.name); 127 | return executeTaskViaLLM(task, timings); 128 | } 129 | 130 | var startChooseTool = System.currentTimeMillis(); 131 | ToolMetadata toolMetadata = chooseTool(task); 132 | timings.add(new TaskTiming("chooseTool", System.currentTimeMillis() - startChooseTool)); 133 | if (toolMetadata.name().equals("__NO_TOOL__")) { 134 | return executeTaskViaLLM(task, timings); 135 | } 136 | 137 | var startGetArgs = System.currentTimeMillis(); 138 | Object args = 
null; 139 | Class returnType = toolMetadata.getReturnType(); 140 | if (returnType != null) { 141 | if (returnType.isPrimitive() || "java.lang.String".equals(returnType.getName())) { 142 | args = getArgsAsString(task, toolMetadata); 143 | } else { 144 | args = getArgsAsObject(task, returnType, toolMetadata); 145 | } 146 | } 147 | timings.add(new TaskTiming("getArgs", System.currentTimeMillis() - startGetArgs)); 148 | 149 | try { 150 | var startInvokeTool = System.currentTimeMillis(); 151 | Object toolResult = invokeTool(toolMetadata.method(), args); 152 | timings.add(new TaskTiming("invokeTool", System.currentTimeMillis() - startInvokeTool)); 153 | TaskResult tr = new TaskResult(task, this.name, toolMetadata.name(), toolResult, timings); 154 | log.info("TaskResult: {}", tr); 155 | return tr; 156 | } catch (Exception e) { 157 | throw new ToolInvocationException("Error invoking tool: " + toolMetadata + " for task: " + task, e); 158 | } 159 | } 160 | 161 | private Object getArgsAsString(Task task, ToolMetadata toolMetadata) { 162 | Prompt prompt = createPrompt(chooseToolArgsUserPromptNoFormat, Map.of( 163 | "task", task.getDescription(), 164 | "signature", toolMetadata.getMethodArgsAsString() 165 | )); 166 | return callPromptForString(prompt); 167 | } 168 | 169 | private Object getArgsAsObject(Task task, Class returnType, ToolMetadata toolMetadata) { 170 | BeanOutputConverter outputConverter = new BeanOutputConverter(returnType); 171 | Prompt prompt = createPrompt(chooseToolArgsUserPrompt, Map.of( 172 | "task", task.getDescription(), 173 | "signature", toolMetadata.getMethodArgsAsString(), 174 | "format", outputConverter.getFormat() 175 | )); 176 | return callPromptForBean(prompt, outputConverter); 177 | } 178 | 179 | 180 | private T invokeTool(Method method, Object args) throws Exception { 181 | if (args == null) { 182 | log.info("Invoking method: {}", method.toString()); 183 | T result = (T) method.invoke(this); 184 | return result; 185 | } else { 186 | 
log.info("Invoking method: {} with args: {}", method.toString(), args); 187 | log.info("args type: {}", args.getClass().getName()); 188 | T result = (T) method.invoke(this, args); 189 | return result; 190 | } 191 | } 192 | 193 | public void addDateContext() { 194 | StringBuilder dateContext = new StringBuilder(); 195 | dateContext 196 | .append("The date and time right now is: ") 197 | .append(ZonedDateTime.now().format(DateTimeFormatter.ofLocalizedDateTime(FormatStyle.FULL))) 198 | .append(". Use this date to answer any questions related to the current date and time."); 199 | addSystemMessage(dateContext.toString()); 200 | } 201 | 202 | public void addMathInstructions() { 203 | addSystemMessage("If you perform any math calculations, please return the resulting number and nothing else. Do not return a sentence or any other text."); 204 | } 205 | 206 | protected TaskResult executeTaskViaLLM(Task task, List timings) { 207 | 208 | var startLLM = System.currentTimeMillis(); 209 | if (this.background != null && !this.background.isEmpty()) { 210 | addSystemMessage(this.background); 211 | } 212 | 213 | addDateContext(); 214 | addMathInstructions(); 215 | 216 | var taskDescription = task.getDescription(); 217 | addUserMessage(taskDescription); 218 | 219 | Prompt prompt = new Prompt(messages); 220 | String data = callPromptForString(prompt); 221 | 222 | timings.add(new TaskTiming("executeViaLLM", System.currentTimeMillis() - startLLM)); 223 | 224 | return new TaskResult(task, this.name, null, data, timings); 225 | } 226 | 227 | private ToolMetadata chooseTool(Task task) { 228 | StringBuilder toolList = new StringBuilder(); 229 | tools.values().stream() 230 | .map(toolMetadata -> toolMetadata.name() + ": " + toolMetadata.description() + "\r\n") 231 | .forEach(toolList::append); 232 | Prompt prompt = createPrompt(chooseToolUserPrompt, Map.of( 233 | "task", task.getDescription(), 234 | "tools", toolList.toString() 235 | )); 236 | String toolName = callPromptForString(prompt); 
237 | if (toolName.equals("__NO_TOOL__")) { 238 | log.warn("No suitable tool found for task: {}", task.getDescription()); 239 | return new ToolMetadata("__NO_TOOL__", null, null, false); 240 | } 241 | ToolMetadata tool = tools.get(toolName); 242 | log.info("Chosen tool: {}", tool); 243 | if (tool == null) { 244 | throw new ToolNotFoundException("No tool found with name: " + toolName); 245 | } 246 | return tool; 247 | } 248 | } 249 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/Artist.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.Agent; 4 | import dev.surly.ai.collab.agent.AgentService; 5 | import dev.surly.ai.collab.tool.Tool; 6 | import org.springframework.ai.image.Image; 7 | import org.springframework.ai.image.ImagePrompt; 8 | import org.springframework.ai.openai.OpenAiImageOptions; 9 | 10 | @Agent(goal = "Generate images based on the users' input.", 11 | background = "You are an expert at interpreting and generating images based on the users' input") 12 | public class Artist extends AgentService { 13 | 14 | @Tool(name = "ImageGenerator", description = "Given a user's input, generate an image based on the input") 15 | public Image generateImage(String imageDescription) { 16 | var imageResponse = openAiImageModel.call( 17 | new ImagePrompt(imageDescription, 18 | OpenAiImageOptions.builder() 19 | .withQuality("hd") 20 | .withN(1) 21 | .withHeight(1024) 22 | .withWidth(1024) 23 | .withResponseFormat("url") 24 | .withModel("dall-e-3") 25 | .build())); 26 | return imageResponse.getResult().getOutput(); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/Biographer.java: -------------------------------------------------------------------------------- 1 | package 
dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.AgentService; 4 | import dev.surly.ai.collab.agent.Agent; 5 | 6 | @Agent( 7 | goal = "Answer questions about who a person is.", 8 | background = """ 9 | You are an expert at providing detailed information about a person. 10 | You can provide detailed information about a person's life, achievements, and other relevant information. 11 | Answer in a detailed manner, providing as much information as possible in two to three paragraphs using markdown format. 12 | If you don't know who the person is or don't have any information on the person, just reply with: 'I don't know' 13 | """) 14 | public class Biographer extends AgentService { 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/BusinessAnalyst.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.Agent; 4 | import dev.surly.ai.collab.agent.AgentService; 5 | import dev.surly.ai.collab.agent.example.model.CompanyDetail; 6 | import dev.surly.ai.collab.tool.Tool; 7 | import lombok.RequiredArgsConstructor; 8 | import lombok.extern.slf4j.Slf4j; 9 | import org.springframework.ai.chat.model.Generation; 10 | import org.springframework.ai.chat.prompt.Prompt; 11 | import org.springframework.ai.chat.prompt.PromptTemplate; 12 | import org.springframework.ai.converter.BeanOutputConverter; 13 | import org.springframework.beans.factory.annotation.Value; 14 | import org.springframework.core.io.Resource; 15 | 16 | import java.util.Map; 17 | 18 | @Agent( 19 | goal = "Provide business analysis of companies" 20 | ) 21 | @RequiredArgsConstructor 22 | @Slf4j 23 | public class BusinessAnalyst extends AgentService { 24 | 25 | @Value("classpath:/prompts/agent-company-focus.st") 26 | private Resource companyFocusUserPrompt; 27 | 28 | @Tool(name = 
"GetCompanyFocus", description = "Given the name of a company, return the focus of the company") 29 | public String getCompanyFocus(String companyName) { 30 | Prompt prompt = createPrompt(companyFocusUserPrompt, Map.of( 31 | "companyName", companyName 32 | )); 33 | return callPromptForString(prompt); 34 | } 35 | 36 | @Tool(name = "GetCompanyDetail", description = "Given a company name, get the details about the company including website URL") 37 | public CompanyDetail getCompanyDetails(String name) { 38 | var outputConverter = new BeanOutputConverter<>(CompanyDetail.class); 39 | 40 | String userMessage = 41 | """ 42 | Get the details including website url and address for the company: {name}. 43 | Only provide the stock ticker if the company is public. 44 | {format} 45 | """; 46 | 47 | PromptTemplate promptTemplate = new PromptTemplate(userMessage, Map.of("name", name, "format", 48 | outputConverter.getFormat())); 49 | Prompt prompt = promptTemplate.create(); 50 | 51 | log.info("Prompt: {}", prompt.toString()); 52 | 53 | Generation generation = chatModel.call(prompt).getResult(); 54 | 55 | CompanyDetail detail = outputConverter.convert(generation.getOutput().getContent()); 56 | log.info("CompanyDetail: {}", detail); 57 | return detail; 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/CareerCoach.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.AgentService; 4 | import dev.surly.ai.collab.agent.Agent; 5 | 6 | @Agent(goal = "Provide guidance and support to an individual in their career development") 7 | public class CareerCoach extends AgentService { 8 | 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/Chronologist.java: 
-------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.Agent; 4 | import dev.surly.ai.collab.agent.AgentService; 5 | import dev.surly.ai.collab.tool.Tool; 6 | import org.springframework.ai.chat.prompt.Prompt; 7 | 8 | import java.time.LocalDate; 9 | import java.time.LocalDateTime; 10 | import java.time.format.TextStyle; 11 | import java.util.Locale; 12 | 13 | @Agent(goal = "Answer questions about time", 14 | background = """ 15 | You are a professional chronologist and can answer questions about time. 16 | You can also perform various time-related tasks such as conversions and formatting. 17 | """) 18 | public class Chronologist extends AgentService { 19 | 20 | @Tool(name = "CurrentTime", description = "Return the current time in the format HH:mm:ss") 21 | public String currentTime() { 22 | LocalDateTime currentDate = LocalDateTime.now(); 23 | return currentDate.toLocalTime().format(java.time.format.DateTimeFormatter.ofPattern("HH:mm:ss")); 24 | } 25 | 26 | @Tool(name = "CurrentTimeInLocation", description = "Return the current time for a location in the format HH:mm:ss") 27 | public String currentTime(String location) { 28 | addDateContext(); 29 | addSystemMessage("Return the current time for a location in the format HH:mm:ss"); 30 | addUserMessage("What is the current time in " + location + "?"); 31 | Prompt prompt = new Prompt(getMessages()); 32 | return callPromptForString(prompt); 33 | } 34 | 35 | @Tool(name = "CurrentDate", description = "Return the current date in the format yyyy-MM-dd") 36 | public String currentDate() { 37 | LocalDate currentDate = LocalDate.now(); 38 | return currentDate.format(java.time.format.DateTimeFormatter.ofPattern("yyyy-MM-dd")); 39 | } 40 | 41 | @Tool(name = "CurrentDay", description = "Return the current day of the week") 42 | public String currentDay() { 43 | LocalDate currentDate = LocalDate.now(); 44 | return 
currentDate.getDayOfWeek().getDisplayName(TextStyle.FULL, Locale.ENGLISH); 45 | } 46 | 47 | @Tool(name = "CurrentMonth", description = "Return the name of the month") 48 | public String currentMonth() { 49 | LocalDate currentDate = LocalDate.now(); 50 | return currentDate.getMonth().getDisplayName(TextStyle.FULL, Locale.ENGLISH); 51 | } 52 | 53 | @Tool(name = "CurrentYear", description = "Return the current year") 54 | public Integer currentYear() { 55 | LocalDate currentDate = LocalDate.now(); 56 | return currentDate.getYear(); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/CodeLinguist.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.Agent; 4 | import dev.surly.ai.collab.agent.AgentService; 5 | import dev.surly.ai.collab.tool.Tool; 6 | import org.springframework.ai.chat.prompt.Prompt; 7 | import org.springframework.beans.factory.annotation.Value; 8 | import org.springframework.core.io.Resource; 9 | 10 | import java.util.Map; 11 | 12 | @Agent(goal = "Determine programming language from code snippet", 13 | background = """ 14 | You are an expert software engineer in all major programming languages and are adept in determining the programming 15 | language from a given code snippet. 
16 | """ 17 | ) 18 | public class CodeLinguist extends AgentService { 19 | 20 | @Value("classpath:/prompts/agent-determine-programming-language.st") 21 | private Resource determineProgrammingLanguagePrompt; 22 | 23 | @Tool(name = "DetermineLanguage", description = "Determine the programming language of a given code snippet") 24 | public String determineLanguage(String code) { 25 | Prompt prompt = createPrompt(determineProgrammingLanguagePrompt, Map.of( 26 | "code", code 27 | )); 28 | return callPromptForString(prompt); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/ComputerAssistant.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.AgentService; 4 | import dev.surly.ai.collab.agent.Agent; 5 | import dev.surly.ai.collab.tool.Tool; 6 | import lombok.extern.slf4j.Slf4j; 7 | 8 | import java.io.BufferedReader; 9 | import java.io.File; 10 | import java.io.IOException; 11 | import java.io.InputStreamReader; 12 | import java.util.ArrayList; 13 | import java.util.Arrays; 14 | import java.util.List; 15 | import java.util.Map; 16 | 17 | @Agent(goal = "Provide useful answers about the host computer including cpus, files and directories") 18 | @Slf4j 19 | public class ComputerAssistant extends AgentService { 20 | 21 | @Tool(name = "DirectoryReader", description = "Given a directory, list all the files in the directory.") 22 | public List readDirectory(String path) { 23 | File f = new File(path); 24 | List out = new ArrayList<>(); 25 | if (f.exists()) { 26 | String[] list = f.list(); 27 | if (list != null) { 28 | out = Arrays.asList(list); 29 | } 30 | } 31 | return out; 32 | } 33 | 34 | @Tool(name = "CPU Analyzer", description = "Describe the number of cpus") 35 | public int cpuInfo() { 36 | return Runtime.getRuntime().availableProcessors(); 37 | } 38 | 39 
| @Tool(name = "RAM memory analyzer", description = "Describe the total RAM") 40 | public String ramMemory() { 41 | String out = "Unable to get RAM info"; 42 | try { 43 | String command = "grep MemTotal /proc/meminfo"; 44 | Process process = Runtime.getRuntime().exec(new String[]{"bash", "-c", command}); 45 | BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); 46 | 47 | String line = reader.readLine(); // Read only the first line 48 | if (line != null) { 49 | String[] parts = line.split("\\s+"); 50 | long totalMemoryKb = Long.parseLong(parts[1]); 51 | out = String.format("Total RAM: %d GB", totalMemoryKb / 1024 / 1024); 52 | } 53 | reader.close(); 54 | 55 | } catch (IOException e) { 56 | log.error("Error getting RAM memory", e); 57 | } 58 | return out; 59 | } 60 | 61 | @Tool(name = "JVM Memory Analyzer", description = "Describe the memory available to the JVM") 62 | public Map memoryInfo() { 63 | Runtime runtime = Runtime.getRuntime(); 64 | return Map.of( 65 | "total", runtime.totalMemory(), 66 | "free", runtime.freeMemory(), 67 | "used", runtime.totalMemory() - runtime.freeMemory(), 68 | "max", runtime.maxMemory() 69 | ); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/DefaultLLMAgent.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.AgentService; 4 | import dev.surly.ai.collab.agent.Agent; 5 | 6 | @Agent(goal = "Be helpful by accomplishing a wide variety of tasks and answering questions") 7 | public class DefaultLLMAgent extends AgentService { 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/Greeter.java: -------------------------------------------------------------------------------- 1 | package 
dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.AgentService; 4 | import dev.surly.ai.collab.agent.Agent; 5 | import dev.surly.ai.collab.tool.Tool; 6 | /** Minimal example agent: its single "SayHello" tool always returns the fixed string "Hello, World!". */ 7 | @Agent(goal = "Say hello", background = "You are a friendly person and greet everyone you encounter") 8 | public class Greeter extends AgentService { 9 | @Tool(name = "SayHello", description = "Be friendly and say hello") 10 | public String sayHello() { 11 | return "Hello, World!"; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/HealthcarePatientAdvocate.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.Agent; 4 | import dev.surly.ai.collab.agent.AgentService; 5 | import dev.surly.ai.collab.tool.Tool; 6 | import lombok.RequiredArgsConstructor; 7 | import lombok.extern.slf4j.Slf4j; 8 | import org.springframework.ai.chat.client.ChatClient; 9 | import org.springframework.ai.document.Document; 10 | import org.springframework.ai.openai.OpenAiChatModel; 11 | import org.springframework.ai.vectorstore.SearchRequest; 12 | import org.springframework.ai.vectorstore.VectorStore; 13 | import org.springframework.context.annotation.Profile; 14 | import org.springframework.stereotype.Component; 15 | import org.springframework.web.bind.annotation.RequestMapping; 16 | import org.springframework.web.bind.annotation.RestController; 17 | 18 | import java.util.List; 19 | import java.util.stream.Collectors; 20 | /** Retrieval-augmented agent that answers health-insurance benefit questions from documents in the VectorStore; only registered when the "healthcare" Spring profile is active. NOTE(review): @Component is redundant next to @RestController, which is already a component stereotype. */ 21 | @Agent(goal = "Provide assistance to patients with healthcare needs") 22 | @RequiredArgsConstructor 23 | @RestController 24 | @RequestMapping("/api/agents/healthcare-advocate") 25 | @Slf4j 26 | @Component 27 | @Profile("healthcare") 28 | public class HealthcarePatientAdvocate extends AgentService { 29 | 30 | private final VectorStore vectorStore; 31 | private final
OpenAiChatModel openAiChatModel; 32 | /** Answers a benefits question. It fetches the single closest document (topK = 1) from the vector store and places its content in the system prompt's DOCUMENTS section. It then returns the text of the chat model's reply. The question is used both as the similarity-search query and as the user message. */ 33 | @Tool(name = "Healthcare benefits query interface", description = "Provide information about healthcare insurance benefits") 34 | public String getHealthcareBenefitsInfo(String healthcareBenefitsQuestion) { 35 | /* Retrieve only the closest match; if nothing matches, content below is an empty string. */ 36 | List similarDocuments = vectorStore.similaritySearch( 37 | SearchRequest.query(healthcareBenefitsQuestion).withTopK(1) 38 | ); 39 | String content = similarDocuments.stream() 40 | .map(Document::getContent) 41 | .collect(Collectors.joining(System.lineSeparator())); 42 | 43 | var systemPromptTemplate = """ 44 | You are a helpful assistant, conversing with a user about health benefits available to them through Providence HealthPlan insurance. 45 | Use the information from the DOCUMENTS section to provide accurate answers. If unsure or if the answer 46 | isn't found in the DOCUMENTS section, simply state that you don't know the answer and do not mention 47 | the DOCUMENTS section. 48 | 49 | ## DOCUMENTS: 50 | 51 | {documents} 52 | """; 53 | /* NOTE(review): a new ChatClient is created on every call; it could be built once from openAiChatModel and reused. */ 54 | return ChatClient.create(openAiChatModel) 55 | .prompt() 56 | .system(sysSpec -> sysSpec.text(systemPromptTemplate).param("documents", content)) 57 | .user(healthcareBenefitsQuestion) 58 | .call() 59 | .content(); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/JobRecruiter.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.Agent; 4 | import dev.surly.ai.collab.agent.AgentService; 5 | import dev.surly.ai.collab.agent.example.model.JobRateRequest; 6 | import dev.surly.ai.collab.agent.example.model.JobRateResponse; 7 | import dev.surly.ai.collab.tool.Tool; 8 | import lombok.RequiredArgsConstructor; 9 | import lombok.extern.slf4j.Slf4j; 10 | import org.springframework.ai.chat.model.Generation; 11 | import org.springframework.ai.chat.prompt.Prompt; 12 |
import org.springframework.ai.chat.prompt.PromptTemplate; 13 | import org.springframework.ai.converter.BeanOutputConverter; 14 | import org.springframework.beans.factory.annotation.Value; 15 | import org.springframework.core.io.Resource; 16 | import org.springframework.web.bind.annotation.PostMapping; 17 | import org.springframework.web.bind.annotation.RequestBody; 18 | import org.springframework.web.bind.annotation.RequestMapping; 19 | import org.springframework.web.bind.annotation.RestController; 20 | 21 | import java.util.Map; 22 | /** Job-search agent; its JobRater tool is also exposed as POST /api/agents/job-recruiter/rate-job. */ 23 | @Agent(goal = "Provide guidance and support to an individual in their search for a job") 24 | @RequiredArgsConstructor 25 | @RestController 26 | @RequestMapping("/api/agents/job-recruiter") 27 | @Slf4j 28 | public class JobRecruiter extends AgentService { 29 | 30 | @Value("classpath:/prompts/agent-job-rating.st") 31 | private Resource jobRaterPrompt; 32 | /** Rates a job posting against the request's qualifications and interests. It fills the classpath:/prompts/agent-job-rating.st template, including the BeanOutputConverter's format instructions, then calls the chat model and parses the reply into a JobRateResponse. */ 33 | @PostMapping("/rate-job") 34 | @Tool(name = "JobRater", description = "Given a job description, rate the job based on how well it matches the user's skills and interests") 35 | public JobRateResponse rateJob(@RequestBody JobRateRequest jobRateRequest) { 36 | var outputConverter = new BeanOutputConverter<>(JobRateResponse.class); 37 | /* NOTE(review): Map.of rejects null values, so a request with a null jobDescription fails here with a NullPointerException. */ 38 | PromptTemplate promptTemplate = new PromptTemplate(jobRaterPrompt, 39 | Map.of( 40 | "jobDescription", jobRateRequest.jobDescription(), 41 | "qualifications", jobRateRequest.qualificationsForPrompt(), 42 | "interests", jobRateRequest.interestsForPrompt(), 43 | "format", outputConverter.getFormat() 44 | ) 45 | ); 46 | Prompt prompt = promptTemplate.create(); 47 | /* NOTE(review): the full prompt, including the user's qualifications and interests, is logged at INFO. */ 48 | log.info("Prompt: {}", prompt.toString()); 49 | /* chatModel is not declared in this class; presumably inherited from AgentService - confirm. */ 50 | Generation generation = chatModel.call(prompt).getResult(); 51 | 52 | JobRateResponse response = outputConverter.convert(generation.getOutput().getContent()); 53 | log.info("JobRateResponse: {}", response); 54 | return response; 55 | } 56 | } 57 |
-------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/Librarian.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.AgentService; 4 | import dev.surly.ai.collab.agent.Agent; 5 | /** Prompt-only agent: it declares no @Tool methods, so its behaviour presumably comes entirely from the @Agent goal and background. */ 6 | @Agent(goal = "Answer questions about books and authors", 7 | background = "You are a helpful librarian and can answer a variety of questions about books and authors") 8 | public class Librarian extends AgentService { 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/Mathematician.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.Agent; 4 | import dev.surly.ai.collab.agent.AgentService; 5 | import dev.surly.ai.collab.agent.example.model.MathRequest; 6 | import dev.surly.ai.collab.tool.Tool; 7 | import org.springframework.web.bind.annotation.PostMapping; 8 | import org.springframework.web.bind.annotation.RequestBody; 9 | import org.springframework.web.bind.annotation.RequestMapping; 10 | import org.springframework.web.bind.annotation.RestController; 11 | 12 | import java.util.List; 13 | import java.util.stream.Collectors; 14 | /** Arithmetic tools over MathRequest.numbers(); only the Adder is also exposed over HTTP (POST /api/agents/mathematician/add). */ 15 | @Agent(goal = "Answer mathematical questions and solve problems") 16 | @RestController 17 | @RequestMapping("/api/agents/mathematician") 18 | public class Mathematician extends AgentService { 19 | /** Sum of all numbers; 0 for an empty list. */ 20 | @Tool(name = "Adder", description = "Add a list of numbers together") 21 | @PostMapping("/add") 22 | public double add(@RequestBody MathRequest request) { 23 | double sum = 0; 24 | for (double num : request.numbers()) { 25 | sum += num; 26 | } 27 | return sum; 28 | } 29 | /** Left-to-right subtraction (first - second - ...); 0 for an empty list. */ 30 | @Tool(name = "Subtractor", description = "Subtract a list of numbers") 31 |
public double subtract(MathRequest request) { 32 | if (request.numbers().isEmpty()) return 0; 33 | double result = request.numbers().getFirst(); 34 | for (int i = 1; i < request.numbers().size(); i++) { 35 | result -= request.numbers().get(i); 36 | } 37 | return result; 38 | } 39 | /** Product of all numbers. NOTE(review): returns 0 for an empty list, whereas the empty product is conventionally 1; confirm this is intended. */ 40 | @Tool(name = "Multiplier", description = "Multiply a list of numbers together") 41 | public double multiply(MathRequest request) { 42 | if (request.numbers().isEmpty()) return 0; 43 | double result = 1; 44 | for (double num : request.numbers()) { 45 | result *= num; 46 | } 47 | return result; 48 | } 49 | /** Left-to-right division (first / second / ...); NaN for an empty list or as soon as a divisor is 0. */ 50 | @Tool(name = "Divider", description = "Divide a list of numbers") 51 | public double divide(MathRequest request) { 52 | if (request.numbers().isEmpty()) return Double.NaN; 53 | double result = request.numbers().getFirst(); 54 | for (int i = 1; i < request.numbers().size(); i++) { 55 | if (request.numbers().get(i) == 0) { 56 | return Double.NaN; // Return NaN if division by zero is attempted 57 | } 58 | result /= request.numbers().get(i); 59 | } 60 | return result; 61 | } 62 | /** Element-wise square of the numbers. */ 63 | @Tool(name = "Square", description = "Square each of the list of numbers") 64 | public List square(MathRequest request) { 65 | return request.numbers().stream() 66 | .map(num -> num * num) 67 | .collect(Collectors.toList()); 68 | } 69 | /** Element-wise square root; Math.sqrt yields NaN for negative inputs. */ 70 | @Tool(name = "SquareRoot", description = "Calculate the square root of each number") 71 | public List squareRoot(MathRequest request) { 72 | return request.numbers().stream() 73 | .map(Math::sqrt) 74 | .collect(Collectors.toList()); 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/ProductManager.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.AgentService; 4 | import dev.surly.ai.collab.agent.Agent; 5 | /** Prompt-only agent (no @Tool methods) playing a product-manager role. */ 6 | @Agent(goal= "Define the vision,
strategy, and roadmap for a product, and orchestrating the cross-functional team " + 7 | "efforts to build and enhance the product to meet customer needs and business goals.") 8 | public class ProductManager extends AgentService { 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/SoftwareEngineer.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import com.fasterxml.jackson.core.JsonProcessingException; 4 | import com.fasterxml.jackson.core.util.DefaultIndenter; 5 | import com.fasterxml.jackson.core.util.DefaultPrettyPrinter; 6 | import com.fasterxml.jackson.databind.JsonNode; 7 | import com.fasterxml.jackson.databind.ObjectMapper; 8 | import com.fasterxml.jackson.databind.ObjectWriter; 9 | import com.github.victools.jsonschema.generator.*; 10 | import com.github.victools.jsonschema.module.jackson.JacksonModule; 11 | import dev.surly.ai.collab.agent.Agent; 12 | import dev.surly.ai.collab.agent.AgentService; 13 | import dev.surly.ai.collab.tool.Tool; 14 | import lombok.extern.slf4j.Slf4j; 15 | 16 | import java.lang.reflect.Type; 17 | /** Code-oriented agent; its SchemaGenerator tool turns a fully qualified class name into a JSON schema. */ 18 | @Agent(goal = "Interpret and answer questions about software code") 19 | @Slf4j 20 | public class SoftwareEngineer extends AgentService { 21 | /** Loads the named class without initializing it (Class.forName with initialize = false) and returns its JSON schema. Returns the string "Class not found" if the class cannot be loaded. NOTE(review): className comes from tool input, so any class visible to this class loader can be loaded. */ @Tool(name = "SchemaGenerator", description = "Given the fully qualified name of a class, generate a JSON schema for it") 22 | public String generateSchema(String className) { 23 | log.info("Generating schema for: {}", className); 24 | Class aClass; 25 | try { 26 | aClass = Class.forName(className, false, this.getClass().getClassLoader()); 27 | } catch (ClassNotFoundException e) { 28 | log.error("Class not found", e); 29 | return "Class not found"; 30 | } 31 | return generateSchema(aClass); 32 | } 33 | 34 | /* 35 | Adapted from Spring AI's BeanOutputParser: builds a Draft 2020-12 JSON schema for clazz (victools generator with the Jackson module) and pretty-prints it inside a Markdown code fence. 36 | */ 37 | private String generateSchema(Class clazz) { 38 |
JacksonModule jacksonModule = new JacksonModule(); 40 | SchemaGeneratorConfigBuilder configBuilder = 41 | (new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON)) 42 | .with(jacksonModule); 43 | SchemaGeneratorConfig config = configBuilder.build(); 44 | SchemaGenerator generator = new SchemaGenerator(config); 45 | JsonNode jsonNode = generator.generateSchema(clazz, new Type[0]); 46 | ObjectWriter objectWriter = (new ObjectMapper()) 47 | .writer((new DefaultPrettyPrinter()) 48 | .withObjectIndenter((new DefaultIndenter()) 49 | .withLinefeed(System.lineSeparator()))); 50 | 51 | String jsonSchema; 52 | try { 53 | jsonSchema = objectWriter.writeValueAsString(jsonNode); 54 | } catch (JsonProcessingException e) { 55 | throw new RuntimeException("Could not pretty print json schema for " + clazz, e); 56 | } 57 | return String.format("\n```%s```\n", jsonSchema); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/SoftwareTester.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example; 2 | 3 | import dev.surly.ai.collab.agent.Agent; 4 | import dev.surly.ai.collab.agent.AgentService; 5 | import dev.surly.ai.collab.tool.Tool; 6 | import lombok.extern.slf4j.Slf4j; 7 | /** Placeholder agent: its TestWriter tool is not implemented yet and always returns "TODO". */ 8 | @Agent(goal = "Test software", 9 | background = "You are an expert software engineer in all major programming languages.
Test software for bugs and issues.") 10 | @Slf4j 11 | public class SoftwareTester extends AgentService { 12 | @Tool(name ="TestWriter", description = "Write comprehensive test cases for a given software class, method, or function") 13 | public String writeTests(String code) { 14 | return "TODO"; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/model/CompanyDetail.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example.model; 2 | /** Company description with every field held as a String. */ 3 | public record CompanyDetail(String name, 4 | String websiteLink, 5 | String stockTicker, 6 | String numberOfEmployees, 7 | String summary, 8 | String location 9 | ) {} 10 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/model/JobRateRequest.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example.model; 2 | 3 | import java.util.List; 4 | import java.util.StringJoiner; 5 | 6 | /** 7 | * Input for the JobRater tool: a job description plus the candidate's qualifications and interests. 8 | */ 9 | public record JobRateRequest(String jobDescription, List<String> qualifications, List<String> interests) { 10 | 11 | /** Returns the qualifications rendered as a "- item" bullet list for the prompt template. */ 12 | public String qualificationsForPrompt() { 13 | return join(qualifications); 14 | } 15 | 16 | /** Returns the interests rendered as a "- item" bullet list for the prompt template. */ 17 | public String interestsForPrompt() { 18 | return join(interests); 19 | } 20 | 21 | /** 22 | * Renders items as "\n- a\n- b\n"; a null or empty list yields "\n- \n". 23 | * A new StringJoiner is created on every call because StringJoiner is mutable and not thread-safe. 24 | * A shared (e.g. static) instance would keep every item from every earlier call and request, 25 | * so one list's items would leak into another's output. 26 | */ 27 | private static String join(List<String> list) { 28 | StringJoiner joiner = new StringJoiner("\n- ", "\n- ", "\n"); 29 | if (list != null) { 30 | list.forEach(joiner::add); 31 | } 32 | return joiner.toString(); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/model/JobRateResponse.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example.model; 2 | 3 | import java.util.List; 4 | /** Structured JobRater result; JobRecruiter parses the model's reply into this record with a BeanOutputConverter. */ 5 |
public record JobRateResponse(Integer rating, 6 | List detractingFactors, 7 | List enhancingFactors, 8 | String roleDescription, 9 | String companyDescription, 10 | List responsibilities, 11 | List requirements) { 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/agent/example/model/MathRequest.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent.example.model; 2 | 3 | import java.util.List; 4 | 5 | /** Operands for the Mathematician tools. */ 6 | public record MathRequest(List numbers) { 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/client/RestClientCustomizations.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.client; 2 | 3 | import dev.surly.ai.collab.log.LoggingInterceptor; 4 | import org.springframework.boot.web.client.ClientHttpRequestFactories; 5 | import org.springframework.boot.web.client.ClientHttpRequestFactorySettings; 6 | import org.springframework.boot.web.client.RestClientCustomizer; 7 | import org.springframework.context.annotation.Bean; 8 | import org.springframework.stereotype.Component; 9 | 10 | import java.time.Duration; 11 | 12 | @Component 13 | public class RestClientCustomizations { 14 | 15 | /** 16 | * Sets a 5 s connect timeout and a 60 s read timeout on the RestClient request factory and adds the LoggingInterceptor. 17 | * The long read timeout keeps slow Spring AI ChatClient (LLM) calls from timing out.
18 | * 19 | * @return a customizer that installs these timeout settings and the LoggingInterceptor on a RestClient.Builder 20 | */ 21 | @Bean 22 | public RestClientCustomizer restClientCustomizer() { 23 | return restClientBuilder -> restClientBuilder 24 | .requestFactory(ClientHttpRequestFactories.get( 25 | ClientHttpRequestFactorySettings.DEFAULTS 26 | .withConnectTimeout(Duration.ofSeconds(5)) 27 | .withReadTimeout(Duration.ofSeconds(60)) 28 | )) 29 | .requestInterceptor(new LoggingInterceptor()); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/controller/TeamController.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.controller; 2 | 3 | import dev.surly.ai.collab.Team; 4 | import dev.surly.ai.collab.agent.AgentRegistry; 5 | import dev.surly.ai.collab.controller.model.TeamForm; 6 | import dev.surly.ai.collab.task.Task; 7 | import dev.surly.ai.collab.task.TaskResult; 8 | import lombok.RequiredArgsConstructor; 9 | import lombok.extern.slf4j.Slf4j; 10 | import org.springframework.stereotype.Controller; 11 | import org.springframework.ui.Model; 12 | import org.springframework.web.bind.annotation.GetMapping; 13 | import org.springframework.web.bind.annotation.ModelAttribute; 14 | import org.springframework.web.bind.annotation.PostMapping; 15 | 16 | import java.util.List; 17 | /** Server-rendered UI: shows the "index" view with the enabled agents and runs a submitted task through the Team. NOTE(review): if Team is a singleton that retains added tasks, tasks from earlier requests would be re-run; confirm addTasks/kickoff semantics. */ 18 | @Controller 19 | @RequiredArgsConstructor 20 | @Slf4j 21 | public class TeamController { 22 | private final Team team; 23 | private final AgentRegistry agentRegistry; 24 | /** Renders an empty form plus the enabled agents. */ 25 | @GetMapping("/") 26 | public String team(Model model) { 27 | model.addAttribute("teamForm", new TeamForm()); 28 | model.addAttribute("agents", agentRegistry.enabledAgents()); 29 | return "index"; 30 | } 31 | /** Runs the submitted task and re-renders the page with its result. */ 32 | @PostMapping("/") 33 | public String executeTask(@ModelAttribute TeamForm teamForm, Model model) { 34 | log.info("Given task: {}", teamForm); 35 | Task task = teamForm.toTask(); 36 | List taskResults = team 37 | .addTasks(List.of(task)) 38 | .kickoff();
39 | /* Exactly one task was submitted, so its result is presumably the first element; getFirst() throws if the list is empty. */ 40 | TaskResult taskResult = taskResults.getFirst(); 41 | model.addAttribute("taskResult", taskResult); 42 | model.addAttribute("agents", agentRegistry.enabledAgents()); 43 | return "index"; 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/controller/TeamRestController.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.controller; 2 | 3 | import dev.surly.ai.collab.Team; 4 | import dev.surly.ai.collab.agent.AgentMetadata; 5 | import dev.surly.ai.collab.agent.AgentRegistry; 6 | import dev.surly.ai.collab.task.Task; 7 | import dev.surly.ai.collab.task.TaskResult; 8 | import lombok.RequiredArgsConstructor; 9 | import lombok.extern.slf4j.Slf4j; 10 | import org.springframework.web.bind.annotation.GetMapping; 11 | import org.springframework.web.bind.annotation.RequestMapping; 12 | import org.springframework.web.bind.annotation.RequestParam; 13 | import org.springframework.web.bind.annotation.RestController; 14 | 15 | import java.util.List; 16 | /** JSON API: GET /api?task=... runs a single task; GET /api/agents lists agent metadata. */ 17 | @RequiredArgsConstructor 18 | @RestController 19 | @RequestMapping("/api") 20 | @Slf4j 21 | public class TeamRestController { 22 | 23 | private final Team team; 24 | private final AgentRegistry agentRegistry; 25 | /** Runs the given task text through the Team and returns the first (presumably only) result. */ 26 | @GetMapping() 27 | public TaskResult team(@RequestParam String task) { 28 | log.info("Given task: {}", task); 29 | List taskResults = team 30 | .addTasks(List.of(new Task(task))) 31 | .kickoff(); 32 | TaskResult taskResult = taskResults.getFirst(); 33 | log.info("Task result:\n\n {}\n\n", taskResult); 34 | return taskResult; 35 | } 36 | /** Metadata for the agents returned by agentRegistry.allAgents(); note the web UI lists enabledAgents() instead. */ 37 | @GetMapping("/agents") 38 | public List agents() { 39 | return agentRegistry.allAgents().values().stream() 40 | .map(agent -> new AgentMetadata(agent.getName(), agent.getGoal(), agent.getBackground(), agent.getTools())) 41 | .toList(); 42 | } 43 | } 44 |
-------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/controller/model/TeamForm.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.controller.model; 2 | 3 | import dev.surly.ai.collab.task.Task; 4 | import lombok.Data; 5 | 6 | /** 7 | * Form-backing bean for the team page: the task text plus an optional agent name. 8 | */ 9 | @Data 10 | public class TeamForm { 11 | private String task; 12 | private String agent; 13 | 14 | /** 15 | * Builds a {@link Task} from this form. 16 | * 17 | * @return a task bound to {@code agent} when one was chosen, otherwise a task built from the text alone (agent null or empty) 18 | */ 19 | public Task toTask() { 20 | boolean agentChosen = agent != null && !agent.isEmpty(); 21 | return agentChosen ? new Task(task, agent) : new Task(task); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/embedding/HealthBenefitsDocumentEtlPipeline.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.embedding; 2 | 3 | import jakarta.annotation.PostConstruct; 4 | import lombok.RequiredArgsConstructor; 5 | import lombok.extern.slf4j.Slf4j; 6 | import org.springframework.ai.reader.pdf.PagePdfDocumentReader; 7 | import org.springframework.ai.vectorstore.VectorStore; 8 | import org.springframework.beans.factory.annotation.Value; 9 | import org.springframework.context.annotation.Profile; 10 | import org.springframework.core.io.Resource; 11 | import org.springframework.stereotype.Component; 12 | 13 | /** 14 | * Startup ETL for the "healthcare" profile: loads the bundled health-benefits PDF into the vector store. 15 | * NOTE(review): runs each time the bean is initialized; if the VectorStore is persistent, the same pages are presumably added again on every start. 16 | */ 17 | @Component 18 | @Slf4j 19 | @RequiredArgsConstructor 20 | @Profile("healthcare") 21 | public class HealthBenefitsDocumentEtlPipeline { 22 | 23 | private final VectorStore vectorStore; 24 | 25 | @Value("classpath:documents/health-benefits.pdf") 26 | Resource healthBenefitsPdfFile; 27 | 28 | /** Reads the PDF into page-based Documents with PagePdfDocumentReader and adds them all to the vector store in one call. */ 29 | @PostConstruct 30 | public void run() { 31 | log.info("Running health benefits document ETL pipeline"); 32 | vectorStore.add(new PagePdfDocumentReader(healthBenefitsPdfFile).get()); 33 | log.info("Health benefits document ETL pipeline complete"); 34 | } 35 | } 36 |
-------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/exception/ToolInvocationException.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.exception; 2 | 3 | public class ToolInvocationException extends RuntimeException { 4 | public ToolInvocationException(String msg) { 5 | super(msg) ; 6 | } 7 | 8 | public ToolInvocationException(String msg, Exception e) { 9 | super(msg, e); 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/exception/ToolNotFoundException.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.exception; 2 | 3 | public class ToolNotFoundException extends RuntimeException { 4 | public ToolNotFoundException(String msg) { 5 | super(msg); 6 | } 7 | 8 | public ToolNotFoundException(String msg, Exception e) { 9 | super(msg, e); 10 | } 11 | } -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/flow/Flow.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.flow; 2 | 3 | import dev.surly.ai.collab.task.Task; 4 | import dev.surly.ai.collab.task.TaskError; 5 | import dev.surly.ai.collab.task.TaskResult; 6 | import org.slf4j.Logger; 7 | 8 | import java.util.List; 9 | 10 | public interface Flow { 11 | void addTask(Task task); 12 | List execute(); 13 | 14 | default TaskResult logAndReturnTaskError(Logger log, Task task, Throwable t) { 15 | var taskError = new TaskError("Error executing task " + task, t); 16 | log.error(taskError.message(), taskError.throwable()); 17 | return new TaskResult(task, "unknown", "unknown", taskError); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- 
/src/main/java/dev/surly/ai/collab/flow/FlowExecution.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.flow; 2 | 3 | import dev.surly.ai.collab.task.TaskResult; 4 | import lombok.Getter; 5 | 6 | import java.util.List; 7 | import java.util.Objects; 8 | import java.util.UUID; 9 | 10 | @Getter 11 | public class FlowExecution { 12 | 13 | private final UUID id; 14 | private final Flow flow; 15 | 16 | public FlowExecution(Flow flow) { 17 | this.id = UUID.randomUUID(); 18 | this.flow = flow; 19 | } 20 | 21 | public FlowExecutionResult execute() { 22 | List results = flow.execute(); 23 | return new FlowExecutionResult(results); 24 | } 25 | 26 | @Override public String toString() { 27 | return "FlowExecution(id=" + this.getId() + ", flow=" + this.getFlow() + ")"; 28 | } 29 | 30 | @Override 31 | public boolean equals(Object o) { 32 | if (this == o) return true; 33 | if (o == null || getClass() != o.getClass()) return false; 34 | FlowExecution that = (FlowExecution) o; 35 | return Objects.equals(id, that.id); 36 | } 37 | 38 | @Override 39 | public int hashCode() { 40 | return Objects.hashCode(id); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/flow/FlowExecutionResult.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.flow; 2 | 3 | import dev.surly.ai.collab.task.TaskResult; 4 | 5 | import java.util.List; 6 | 7 | public record FlowExecutionResult(List taskResults) { 8 | public void printResults() { 9 | taskResults.forEach(System.out::println); 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/flow/ParallelFlow.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.flow; 2 | 3 | import 
dev.surly.ai.collab.task.AgentTaskExecutor; 4 | import dev.surly.ai.collab.task.Task; 5 | import dev.surly.ai.collab.task.TaskResult; 6 | import lombok.extern.slf4j.Slf4j; 7 | 8 | import java.util.ArrayList; 9 | import java.util.List; 10 | import java.util.concurrent.*; 11 | 12 | @Slf4j 13 | public class ParallelFlow implements Flow { 14 | private final List tasks = new ArrayList<>(); 15 | private final AgentTaskExecutor agentTaskExecutor; 16 | private final ExecutorService executorService = Executors.newCachedThreadPool(); 17 | 18 | public ParallelFlow(AgentTaskExecutor agentTaskExecutor) { 19 | this.agentTaskExecutor = agentTaskExecutor; 20 | } 21 | 22 | @Override 23 | public void addTask(Task task) { 24 | tasks.add(task); 25 | } 26 | 27 | @Override 28 | public List execute() { 29 | CompletionService completionService = new ExecutorCompletionService<>(executorService); 30 | for (Task task : tasks) { 31 | completionService.submit(() -> agentTaskExecutor.executeTask(task)); 32 | } 33 | List results = new ArrayList<>(); 34 | for (Task task : tasks) { 35 | TaskResult taskResult; 36 | try { 37 | taskResult = completionService.take().get(); 38 | } catch (InterruptedException | ExecutionException e) { 39 | taskResult = logAndReturnTaskError(log, task, e); 40 | } 41 | results.add(taskResult); 42 | } 43 | return results; 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/flow/SequentialFlow.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.flow; 2 | 3 | import dev.surly.ai.collab.task.AgentTaskExecutor; 4 | import dev.surly.ai.collab.task.Task; 5 | import dev.surly.ai.collab.task.TaskResult; 6 | import lombok.extern.slf4j.Slf4j; 7 | 8 | import java.util.ArrayList; 9 | import java.util.List; 10 | 11 | @Slf4j 12 | public class SequentialFlow implements Flow { 13 | 14 | private final List tasks = new ArrayList<>(); 15 | 
private final AgentTaskExecutor agentTaskExecutor; 16 | 17 | public SequentialFlow(AgentTaskExecutor agentTaskExecutor) { 18 | this.agentTaskExecutor = agentTaskExecutor; 19 | } 20 | 21 | @Override 22 | public void addTask(Task task) { 23 | tasks.add(task); 24 | } 25 | 26 | @Override 27 | public List execute() { 28 | return tasks.stream() 29 | .map(task -> { 30 | try { 31 | return agentTaskExecutor.executeTask(task); 32 | } catch (Exception e) { 33 | return logAndReturnTaskError(log, task, e); 34 | } 35 | }).toList(); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/log/LoggingInterceptor.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.log; 2 | 3 | import com.fasterxml.jackson.databind.ObjectMapper; 4 | import com.fasterxml.jackson.databind.SerializationFeature; 5 | import lombok.extern.slf4j.Slf4j; 6 | import org.jetbrains.annotations.NotNull; 7 | import org.springframework.http.HttpHeaders; 8 | import org.springframework.http.HttpRequest; 9 | import org.springframework.http.HttpStatusCode; 10 | import org.springframework.http.client.ClientHttpRequestExecution; 11 | import org.springframework.http.client.ClientHttpRequestInterceptor; 12 | import org.springframework.http.client.ClientHttpResponse; 13 | import org.springframework.util.StreamUtils; 14 | 15 | import java.io.ByteArrayInputStream; 16 | import java.io.IOException; 17 | import java.io.InputStream; 18 | import java.nio.charset.StandardCharsets; 19 | 20 | @Slf4j 21 | public class LoggingInterceptor implements ClientHttpRequestInterceptor { 22 | @Override 23 | public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException { 24 | 25 | if (!log.isDebugEnabled()) { 26 | return execution.execute(request, body); 27 | } 28 | 29 | String rawRequestBody = new String(body, StandardCharsets.UTF_8); 30 
| try { 31 | ObjectMapper mapper = new ObjectMapper(); 32 | Object json = mapper.readValue(rawRequestBody, Object.class); 33 | String prettyJson = mapper 34 | .enable(SerializationFeature.INDENT_OUTPUT) 35 | .writeValueAsString(json); 36 | logRequestBody(prettyJson); 37 | } catch (Exception e) { 38 | logRequestBody(rawRequestBody); 39 | } 40 | 41 | ClientHttpResponse response = execution.execute(request, body); 42 | if (response.getStatusCode() == HttpStatusCode.valueOf(400)) { 43 | log.error("400: {}", response.getStatusText()); 44 | } 45 | 46 | byte[] responseBody = logResponseBody(response); 47 | 48 | return new BufferedClientHttpResponse(response, responseBody); 49 | } 50 | 51 | private void logRequestBody(String prettyJson) { 52 | log.debug("REQUEST:{}{}", System.lineSeparator(), prettyJson); 53 | } 54 | 55 | private byte[] logResponseBody(ClientHttpResponse response) throws IOException { 56 | byte[] responseBody = StreamUtils.copyToByteArray(response.getBody()); 57 | String bodyAsString = new String(responseBody, StandardCharsets.UTF_8); 58 | log.debug("RESPONSE:{}{}", System.lineSeparator(), bodyAsString); 59 | return responseBody; 60 | } 61 | 62 | static class BufferedClientHttpResponse implements ClientHttpResponse { 63 | private final ClientHttpResponse originalResponse; 64 | private final byte[] body; 65 | 66 | public BufferedClientHttpResponse(ClientHttpResponse originalResponse, byte[] body) { 67 | this.originalResponse = originalResponse; 68 | this.body = body; 69 | } 70 | 71 | @Override 72 | public @NotNull HttpStatusCode getStatusCode() throws IOException { 73 | return originalResponse.getStatusCode(); 74 | } 75 | 76 | @Override 77 | public @NotNull String getStatusText() throws IOException { 78 | return originalResponse.getStatusText(); 79 | } 80 | 81 | @Override 82 | public void close() { 83 | originalResponse.close(); 84 | } 85 | 86 | @Override 87 | public @NotNull InputStream getBody() throws IOException { 88 | return new 
ByteArrayInputStream(body); 89 | } 90 | 91 | @Override 92 | public @NotNull HttpHeaders getHeaders() { 93 | return originalResponse.getHeaders(); 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/nlp/NlpService.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.nlp; 2 | 3 | import edu.stanford.nlp.ling.CoreAnnotations; 4 | import edu.stanford.nlp.pipeline.Annotation; 5 | import edu.stanford.nlp.pipeline.StanfordCoreNLP; 6 | import edu.stanford.nlp.trees.Tree; 7 | import edu.stanford.nlp.trees.TreeCoreAnnotations; 8 | import edu.stanford.nlp.util.CoreMap; 9 | import lombok.extern.slf4j.Slf4j; 10 | import org.springframework.stereotype.Service; 11 | 12 | import java.util.ArrayList; 13 | import java.util.List; 14 | import java.util.Properties; 15 | 16 | @Slf4j 17 | @Service 18 | public class NlpService { 19 | 20 | public List getSubtasks(String text) { 21 | 22 | log.info("Extracting sub-tasks from: {}", text); 23 | 24 | // Set up pipeline properties 25 | Properties props = new Properties(); 26 | props.setProperty("annotators", "tokenize,ssplit,pos,lemma,parse"); 27 | StanfordCoreNLP pipeline = new StanfordCoreNLP(props); 28 | 29 | // Create an empty Annotation just with the given text 30 | Annotation document = new Annotation(text); 31 | 32 | // Run all Annotators on this text 33 | pipeline.annotate(document); 34 | 35 | List subTasks = new ArrayList<>(); 36 | 37 | // Iterate over all of the sentences found 38 | List sentences = document.get(CoreAnnotations.SentencesAnnotation.class); 39 | for (CoreMap sentence : sentences) { 40 | // Parse the sentence 41 | Tree parseTree = sentence.get(TreeCoreAnnotations.TreeAnnotation.class); 42 | 43 | // Extract actionable phrases as sub-tasks 44 | extractActionPhrases(parseTree, subTasks); 45 | } 46 | 47 | return subTasks; 48 | } 49 | 50 | private void 
extractActionPhrases(Tree parseTree, List subTasks) { 51 | for (Tree subtree : parseTree) { 52 | if (subtree.label().value().equals("VP")) { // VP stands for Verb Phrase 53 | StringBuilder taskBuilder = new StringBuilder(); 54 | for (Tree leaf : subtree.getLeaves()) { 55 | if (!taskBuilder.isEmpty()) { 56 | taskBuilder.append(" "); 57 | } 58 | taskBuilder.append(leaf.toString()); 59 | } 60 | String potentialTask = taskBuilder.toString(); 61 | if (!potentialTask.isEmpty() && potentialTask.split(" ").length > 2) { // Filter very short phrases 62 | subTasks.add(potentialTask); 63 | } 64 | } 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/statemachine/EventPublisher.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.statemachine; 2 | 3 | import lombok.RequiredArgsConstructor; 4 | import org.springframework.context.ApplicationEventPublisher; 5 | import org.springframework.stereotype.Component; 6 | 7 | @Component 8 | @RequiredArgsConstructor 9 | public class EventPublisher { 10 | private final ApplicationEventPublisher applicationEventPublisher; 11 | 12 | public void publishEvent(TaskEvent taskEvent) { 13 | applicationEventPublisher.publishEvent(taskEvent); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/statemachine/Events.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.statemachine; 2 | 3 | public enum Events { 4 | START, 5 | FINISH 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/statemachine/SimpleStateMachine.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.statemachine; 2 | 3 | import 
org.springframework.context.annotation.Bean; 4 | import org.springframework.context.annotation.Configuration; 5 | import org.springframework.statemachine.action.Action; 6 | import org.springframework.statemachine.config.EnableStateMachine; 7 | import org.springframework.statemachine.config.EnumStateMachineConfigurerAdapter; 8 | import org.springframework.statemachine.config.builders.StateMachineConfigurationConfigurer; 9 | import org.springframework.statemachine.config.builders.StateMachineStateConfigurer; 10 | import org.springframework.statemachine.config.builders.StateMachineTransitionConfigurer; 11 | 12 | @Configuration 13 | @EnableStateMachine(name = "simpleStateMachine1") 14 | public class SimpleStateMachine extends EnumStateMachineConfigurerAdapter { 15 | 16 | @Bean 17 | public Action initAction() { 18 | return ctx -> System.out.println(ctx.getTarget().getId()); 19 | } 20 | 21 | @Bean 22 | public Action executeAction() { 23 | return ctx -> System.out.println("Do " + ctx.getTarget().getId()); 24 | } 25 | 26 | @Bean 27 | public Action completedAction() { 28 | return ctx -> System.out.println("Completed " + ctx.getTarget().getId()); 29 | } 30 | 31 | @Override 32 | public void configure(StateMachineStateConfigurer states) throws Exception { 33 | states 34 | .withStates() 35 | .initial(States.INITIAL) 36 | .state(States.IN_PROGRESS, executeAction()) 37 | .end(States.COMPLETED) 38 | .state(States.COMPLETED, completedAction()); 39 | } 40 | 41 | @Override 42 | public void configure(StateMachineTransitionConfigurer transitions) throws Exception { 43 | transitions 44 | .withExternal().source(States.INITIAL).target(States.IN_PROGRESS).event(Events.START).action(initAction()) 45 | .and() 46 | .withExternal().source(States.IN_PROGRESS).target(States.COMPLETED).event(Events.FINISH); 47 | } 48 | 49 | @Override 50 | public void configure(StateMachineConfigurationConfigurer config) throws Exception { 51 | config 52 | .withConfiguration() 53 | .machineId("simpleStateMachine") 
54 | .autoStartup(true); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/statemachine/SimpleStateMachineService.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.statemachine; 2 | 3 | import io.cloudevents.CloudEvent; 4 | import lombok.RequiredArgsConstructor; 5 | import org.springframework.context.event.EventListener; 6 | import org.springframework.statemachine.StateMachine; 7 | import org.springframework.stereotype.Component; 8 | 9 | @RequiredArgsConstructor 10 | @Component 11 | public class SimpleStateMachineService { 12 | private final StateMachine simpleStateMachine; 13 | 14 | @EventListener 15 | public void onApplicationEvent(TaskEvent taskEvent) { 16 | CloudEvent ce = taskEvent.getCloudEvent(); 17 | switch(ce.getType()) { 18 | case "START": 19 | simpleStateMachine.sendEvent(Events.START); 20 | break; 21 | case "FINISH": 22 | simpleStateMachine.sendEvent(Events.FINISH); 23 | break; 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/statemachine/StateMachineListener.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.statemachine; 2 | 3 | import org.springframework.statemachine.listener.StateMachineListenerAdapter; 4 | import org.springframework.statemachine.state.State; 5 | import org.springframework.stereotype.Component; 6 | 7 | @Component 8 | public class StateMachineListener extends StateMachineListenerAdapter { 9 | @Override 10 | public void stateChanged(State from, State to) { 11 | System.out.printf("Transitioned from %s to %s%n", from == null ? 
12 | "none" : from.getId(), to.getId()); 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/statemachine/States.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.statemachine; 2 | 3 | public enum States { 4 | INITIAL, 5 | IN_PROGRESS, 6 | COMPLETED 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/statemachine/TaskEvent.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.statemachine; 2 | 3 | import io.cloudevents.CloudEvent; 4 | import io.cloudevents.core.v1.CloudEventBuilder; 5 | import org.springframework.context.ApplicationEvent; 6 | 7 | import java.net.URI; 8 | import java.time.OffsetDateTime; 9 | import java.util.UUID; 10 | 11 | public class TaskEvent extends ApplicationEvent { 12 | 13 | private final CloudEvent cloudEvent; 14 | 15 | public TaskEvent(Object source, Events eventType, byte[] data) { 16 | super(source); 17 | this.cloudEvent = new CloudEventBuilder() 18 | .withId(UUID.randomUUID().toString()) 19 | .withType(eventType.name()) 20 | .withSource(URI.create("https://surly.dev/collab/task")) 21 | .withTime(OffsetDateTime.now()) 22 | .withData(data) 23 | .build(); 24 | } 25 | 26 | public CloudEvent getCloudEvent() { 27 | return cloudEvent; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/task/AgentTaskExecutor.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.task; 2 | 3 | import dev.surly.ai.collab.agent.AgentRegistry; 4 | import dev.surly.ai.collab.exception.ToolInvocationException; 5 | import lombok.RequiredArgsConstructor; 6 | import lombok.extern.slf4j.Slf4j; 7 | import 
org.jetbrains.annotations.NotNull; 8 | import org.springframework.ai.chat.model.ChatModel; 9 | import org.springframework.stereotype.Component; 10 | 11 | import java.util.List; 12 | 13 | @Component 14 | @Slf4j 15 | @RequiredArgsConstructor 16 | public class AgentTaskExecutor { 17 | 18 | private final ChatModel chatModel; 19 | private final TaskPlanner taskPlanner; 20 | private final AgentRegistry agentRegistry; 21 | 22 | public TaskResult executeTask(Task task) throws ToolInvocationException { 23 | log.info("Executing task: {}", task); 24 | return taskPlanner.chooseAgent(chatModel, task) 25 | .map(agentRegistry::getAgent) 26 | .map(agent -> agent.executeTask(task)) 27 | .orElseThrow(() -> handleNoToolAvailable(task)); 28 | } 29 | 30 | public List executeTasks(ChatModel chatModel, List tasks) throws ToolInvocationException { 31 | return tasks.stream() 32 | .map(task -> taskPlanner.chooseAgent(chatModel, task) 33 | .map(agentRegistry::getAgent) 34 | .map(agent -> agent.executeTask(task)) 35 | .orElseThrow(() -> handleNoToolAvailable(task)) 36 | ) 37 | .toList(); 38 | } 39 | 40 | private @NotNull ToolInvocationException handleNoToolAvailable(Task task) { 41 | log.error("No tool available to execute task: {}", task); 42 | return new ToolInvocationException("No tool available to execute task"); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/task/Task.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.task; 2 | 3 | import lombok.Data; 4 | 5 | @Data 6 | public class Task { 7 | private final String description; 8 | private final String agent; 9 | 10 | public Task(String description, String agent) { 11 | this.description = description; 12 | this.agent = agent; 13 | } 14 | 15 | public Task(String description) { 16 | this(description, null); 17 | } 18 | } 19 | 
--------------------------------------------------------------------------------
/src/main/java/dev/surly/ai/collab/task/TaskAssignment.java:
--------------------------------------------------------------------------------
package dev.surly.ai.collab.task;


/**
 * A task paired with the name of the agent chosen to run it.
 *
 * @param task      the task
 * @param agentName the chosen agent, or {@code null} when none could be chosen
 */
public record TaskAssignment(Task task, String agentName) {
}

--------------------------------------------------------------------------------
/src/main/java/dev/surly/ai/collab/task/TaskDeconstructor.java:
--------------------------------------------------------------------------------
package dev.surly.ai.collab.task;

import dev.surly.ai.collab.nlp.NlpService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.List;

/** Breaks tasks into finer-grained sub-tasks using {@link NlpService}. */
@RequiredArgsConstructor
@Component
@Slf4j
public class TaskDeconstructor {

    private final NlpService nlpService;

    /**
     * Extracts sub-tasks from every task description and returns them as new tasks.
     *
     * <ul>
     *   <li>If nothing is extracted, the input list is returned unchanged.</li>
     *   <li>If exactly one phrase is extracted, it is returned.</li>
     *   <li>Otherwise the first extracted phrase (across all inputs) is dropped.
     *       NOTE(review): presumably because the first verb phrase tends to span the whole
     *       sentence; confirm.</li>
     * </ul>
     *
     * @param tasks the tasks to deconstruct
     * @return the sub-tasks, or {@code tasks} itself when none were found
     */
    public List<Task> deconstruct(List<Task> tasks) {
        List<Task> subtasks = tasks.stream()
                .map(Task::getDescription)
                .flatMap(s -> nlpService.getSubtasks(s).stream())
                .map(Task::new)
                .toList();

        List<Task> result;
        if (subtasks.isEmpty()) {
            result = tasks;
        } else if (subtasks.size() == 1) {
            result = subtasks;
        } else {
            result = subtasks.subList(1, subtasks.size());
        }

        StringBuilder subtasksOutput = new StringBuilder();
        for (int i = 0; i < result.size(); i++) {
            Task task = result.get(i);
            subtasksOutput.append(i + 1).append(". ").append(task.getDescription()).append("\n");
        }
        log.info("Subtasks:\n{}\n", subtasksOutput);

        return result;
    }
}

--------------------------------------------------------------------------------
/src/main/java/dev/surly/ai/collab/task/TaskError.java:
--------------------------------------------------------------------------------
package dev.surly.ai.collab.task;

/**
 * Describes why a task failed.
 *
 * @param message   human-readable description
 * @param throwable the underlying cause, if any
 */
public record TaskError(String message, Throwable throwable) {
}

--------------------------------------------------------------------------------
/src/main/java/dev/surly/ai/collab/task/TaskPlanner.java:
--------------------------------------------------------------------------------
package dev.surly.ai.collab.task;

import dev.surly.ai.collab.agent.AgentRegistry;
import dev.surly.ai.collab.agent.AgentService;
import dev.surly.ai.collab.tool.ToolMetadata;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Decides which enabled agent should handle a task.
 * The chat model is only consulted when more than one agent is available.
 */
@Component
@Slf4j
@RequiredArgsConstructor
public class TaskPlanner {

    private final ChatModel chatModel;
    private final AgentRegistry agentRegistry;

    @Value("classpath:/prompts/task-planner-choose-agents.st")
    private Resource chooseAgentsUserPrompt;

    @Value("classpath:/prompts/task-planner-choose-agent.st")
    private Resource chooseAgentUserPrompt;

    /**
     * Pairs each task with the agent chosen for it using the injected chat model.
     *
     * @param tasks the tasks to assign
     * @return one assignment per task, in order; the agent name is {@code null} when none was chosen
     */
    public List<TaskAssignment> assign(List<Task> tasks) {
        return tasks.stream()
                .map(task -> new TaskAssignment(task, chooseAgent(chatModel, task).orElse(null)))
                .toList();
    }

    /**
     * Given a task, determine which agent is most capable of accomplishing the task.
     * TODO add tools to prompt to aid in decisioning
     *
     * @param chatModel the model asked to pick between several enabled agents
     * @param task      the task to route; if it names an agent, that name is used only when enabled
     * @return the chosen agent name (whitespace-stripped), or empty when:
     *     <ul>
     *       <li>the task names an agent that is not enabled;</li>
     *       <li>no agents are enabled; or</li>
     *       <li>the model's answer is blank.</li>
     *     </ul>
     */
    public Optional<String> chooseAgent(ChatModel chatModel, Task task) {
        long start = System.currentTimeMillis();
        Map<String, AgentService> agents = agentRegistry.enabledAgents();

        log.info("Found {} enabled agents", agents.size());

        if (task.getAgent() != null) {
            log.info("Task specified agent: {}", task.getAgent());
            return agents.containsKey(task.getAgent()) ? Optional.of(task.getAgent()) : Optional.empty();
        }

        if (agents.isEmpty()) {
            log.warn("No agents available");
            return Optional.empty();
        }

        if (agents.size() == 1) {
            var agent = agents.entrySet().iterator().next();
            log.warn("Only one agent available: {}", agent);
            return Optional.ofNullable(agent.getKey());
        }

        StringBuilder agentList = new StringBuilder();
        for (Map.Entry<String, AgentService> entry : agents.entrySet()) {
            agentList.append(entry.getKey()).append(": ").append(entry.getValue().getGoal()).append("\r\n");
        }

        PromptTemplate promptTemplate = new PromptTemplate(chooseAgentUserPrompt, Map.of(
                "task", task.getDescription(),
                "agents", agentList.toString()
        ));
        Prompt prompt = promptTemplate.create();

        var generation = chatModel.call(prompt).getResult();
        String content = generation.getOutput().getContent();
        long elapsed = System.currentTimeMillis() - start;

        // Model output often carries trailing newlines/spaces that would never match a registry key.
        String selected = content == null ? null : content.strip();
        log.info("Selected Agent: {} in {} ms", selected, elapsed);

        if (selected == null || selected.isEmpty()) {
            return Optional.empty();
        }
        if (!agents.containsKey(selected)) {
            log.warn("Model selected an agent that is not enabled: {}", selected);
        }
        return Optional.of(selected);
    }

    /**
     * Asks the model to assign all tasks at once, describing each agent's goal and enabled tools.
     * The model's answer is parsed with {@link MapOutputConverter}.
     */
    private Map<String, Object> chooseAgents(ChatModel chatModel, List<Task> tasks) {
        Map<String, AgentService> agents = agentRegistry.enabledAgents();

        if (agents.isEmpty()) {
            log.warn("No agents available");
            return Map.of();
        }

        log.info("Found {} enabled agents", agents.size());
        agents.forEach((k, v) -> log.info("Agent: {}, Goal: {}", k, v.getGoal()));

        var outputConverter = new MapOutputConverter();

        StringBuilder agentList = new StringBuilder();
        for (Map.Entry<String, AgentService> entry : agents.entrySet()) {
            var agentName = entry.getKey();
            var agent = entry.getValue();
            var agentGoal = agent.getGoal();
            var tools = agent.getTools();
            StringBuilder toolList = generateToolListForPrompt(tools);
            agentList.append(agentName).append(": ").append(agentGoal)
                    .append("\r\n")
                    .append(toolList);
        }

        StringBuilder taskList = new StringBuilder();
        for (Task task : tasks) {
            taskList.append(task.getDescription()).append("\r\n");
        }

        PromptTemplate promptTemplate = new PromptTemplate(chooseAgentsUserPrompt, Map.of(
                "tasks", taskList.toString(),
                "agents", agentList.toString(),
                "format", outputConverter.getFormat()
        ));
        Prompt prompt = promptTemplate.create();

        var generation = chatModel.call(prompt).getResult();
        String content = generation.getOutput().getContent();

        return outputConverter.convert(content);
    }

    /** Renders the enabled tools as a header line followed by one "- name: description" line each. */
    private @NotNull StringBuilder generateToolListForPrompt(Map<String, ToolMetadata> tools) {
        StringBuilder toolList = new StringBuilder();
        // The header ends with a line break so the first tool starts on its own line.
        toolList.append("The tools available from this agent are: ").append("\r\n");
        for (Map.Entry<String, ToolMetadata> toolEntry : tools.entrySet()) {
            var toolName = toolEntry.getKey();
            var toolDesc = toolEntry.getValue().description();
            boolean toolDisabled = toolEntry.getValue().disabled();
            if (!toolDisabled) {
                toolList.append("- ").append(toolName).append(": ").append(toolDesc).append("\r\n");
            }
        }
        return toolList;
    }
}
-------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/task/TaskResult.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.task; 2 | 3 | import com.fasterxml.jackson.core.JsonProcessingException; 4 | import com.fasterxml.jackson.databind.ObjectMapper; 5 | import com.fasterxml.jackson.databind.SerializationFeature; 6 | import dev.surly.ai.collab.util.ConversionUtils; 7 | import lombok.Data; 8 | import lombok.Getter; 9 | import org.springframework.ai.image.Image; 10 | 11 | import java.util.List; 12 | 13 | @Getter 14 | @Data 15 | public class TaskResult { 16 | private final Task task; 17 | private final String agentName; 18 | private final String toolName; 19 | private final Object data; 20 | private final String dataType; 21 | private TaskError taskError; 22 | private List timings; 23 | 24 | public TaskResult(Task task, String agentName, String toolName, Object data, List timings) { 25 | this.task = task; 26 | this.agentName = agentName; 27 | this.toolName = toolName; 28 | this.data = data; 29 | if (data != null) { 30 | this.dataType = data.getClass().getName(); 31 | } else { 32 | this.dataType = null; 33 | } 34 | this.timings = timings; 35 | } 36 | 37 | public TaskResult(Task task, String agentName, String toolName, TaskError taskError) { 38 | this.task = task; 39 | this.agentName = agentName; 40 | this.toolName = toolName; 41 | this.data = null; 42 | this.dataType = null; 43 | this.taskError = taskError; 44 | } 45 | 46 | public Long getDuration() { 47 | if (timings == null || timings.isEmpty()) { 48 | return 0L; 49 | } 50 | return timings.stream() 51 | .map(TaskTiming::timeMs) 52 | .reduce(0L, Long::sum); 53 | } 54 | 55 | public Object display() { 56 | if (data instanceof Image) { 57 | return data; 58 | } else if (data instanceof String raw) { 59 | var markdown = ConversionUtils.convertToMarkdown(raw); 60 | return 
ConversionUtils.convertToHtml(markdown); 61 | } else { 62 | String prettyJson; 63 | try { 64 | ObjectMapper mapper = new ObjectMapper() 65 | .enable(SerializationFeature.INDENT_OUTPUT); // Enable pretty print 66 | prettyJson = mapper.writeValueAsString(data); 67 | } catch (JsonProcessingException e) { 68 | throw new RuntimeException(e); 69 | } 70 | return prettyJson; 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/task/TaskTiming.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.task; 2 | 3 | public record TaskTiming(String label, Long timeMs) { 4 | } 5 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/tool/Tool.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.tool; 2 | 3 | import java.lang.annotation.ElementType; 4 | import java.lang.annotation.Retention; 5 | import java.lang.annotation.RetentionPolicy; 6 | import java.lang.annotation.Target; 7 | 8 | @Retention(RetentionPolicy.RUNTIME) 9 | @Target(ElementType.METHOD) 10 | public @interface Tool { 11 | String name(); 12 | String description() default ""; 13 | boolean disabled() default false; 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/tool/ToolMetadata.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.tool; 2 | 3 | import com.fasterxml.jackson.annotation.JsonIgnore; 4 | 5 | import java.lang.reflect.Method; 6 | import java.util.Arrays; 7 | 8 | public record ToolMetadata(String name, String description, @JsonIgnore Method method, boolean disabled) { 9 | @JsonIgnore 10 | public Class getReturnType() { 11 | if (method.getParameterCount() == 0) { 12 | return null; 13 | } 14 
| return Arrays.stream(method.getParameterTypes()).findFirst().orElseThrow(); 15 | } 16 | 17 | @JsonIgnore 18 | public String getMethodArgsAsString() { 19 | // Get the parameter types 20 | Class[] parameterTypes = method.getParameterTypes(); 21 | 22 | // Build the string representation of the parameter types 23 | StringBuilder parametersString = new StringBuilder("("); 24 | for (int i = 0; i < parameterTypes.length; i++) { 25 | parametersString.append(parameterTypes[i].getTypeName()); 26 | if (i < parameterTypes.length - 1) { 27 | parametersString.append(", "); 28 | } 29 | } 30 | parametersString.append(")"); 31 | 32 | return parametersString.toString(); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/util/ConversionUtils.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.util; 2 | 3 | import com.vladsch.flexmark.html.HtmlRenderer; 4 | import com.vladsch.flexmark.html2md.converter.FlexmarkHtmlConverter; 5 | import com.vladsch.flexmark.parser.Parser; 6 | import org.apache.commons.text.StringEscapeUtils; 7 | 8 | import java.util.regex.Matcher; 9 | import java.util.regex.Pattern; 10 | 11 | public class ConversionUtils { 12 | 13 | private static final Pattern HTML_ENTITY_PATTERN = Pattern.compile("&[a-zA-Z0-9#]+;"); 14 | 15 | public static String convertToMarkdown(String html) { 16 | if (html == null) { 17 | return null; 18 | } 19 | 20 | // first, determine if the html is escaped. 
21 | Matcher matcher = HTML_ENTITY_PATTERN.matcher(html); 22 | if (matcher.find()) { 23 | html = StringEscapeUtils.unescapeHtml4(html); 24 | } 25 | 26 | html = html.replaceAll("\\*", "\n"); 27 | 28 | return FlexmarkHtmlConverter.builder().build().convert(html); 29 | } 30 | 31 | public static String convertToHtml(String markdownContent) { 32 | if (markdownContent == null) { 33 | return null; 34 | } 35 | Parser parser = Parser.builder().build(); 36 | HtmlRenderer renderer = HtmlRenderer.builder().build(); 37 | String rawHtml = renderer.render(parser.parse(markdownContent)); 38 | return rawHtml.replaceFirst("```", "
").replace("```", "
"); 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/validation/CompositeTaskResultValidator.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.validation; 2 | 3 | import dev.surly.ai.collab.task.TaskResult; 4 | 5 | import java.util.ArrayList; 6 | import java.util.List; 7 | import java.util.function.Predicate; 8 | 9 | public class CompositeTaskResultValidator implements TaskResultValidator { 10 | 11 | private final List> predicates = new ArrayList<>(); 12 | 13 | public CompositeTaskResultValidator addPredicate(Predicate predicate) { 14 | predicates.add(predicate); 15 | return this; 16 | } 17 | 18 | @Override 19 | public boolean validate(TaskResult taskResult) { 20 | return predicates.stream().allMatch(predicate -> predicate.test(taskResult)); // All predicates returned true 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/validation/TaskResultValidator.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.validation; 2 | 3 | import dev.surly.ai.collab.task.TaskResult; 4 | 5 | public interface TaskResultValidator extends Validator { 6 | boolean validate(TaskResult taskResult); 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/validation/Validator.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.validation; 2 | 3 | public interface Validator { 4 | boolean validate(T input); 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/vectorstore/VectorStoreConfig.java: -------------------------------------------------------------------------------- 1 | package 
dev.surly.ai.collab.vectorstore; 2 | 3 | import org.springframework.ai.embedding.EmbeddingModel; 4 | import org.springframework.ai.vectorstore.SimpleVectorStore; 5 | import org.springframework.ai.vectorstore.VectorStore; 6 | import org.springframework.context.annotation.Bean; 7 | import org.springframework.context.annotation.Configuration; 8 | 9 | @Configuration 10 | public class VectorStoreConfig { 11 | @Bean 12 | VectorStore vectorStore(EmbeddingModel openAiEmbeddingModel) { 13 | return new SimpleVectorStore(openAiEmbeddingModel); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/workflow/WorkflowCoordinator.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.workflow; 2 | 3 | import dev.surly.ai.collab.agent.AgentRegistry; 4 | import org.springframework.stereotype.Component; 5 | 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | 9 | @Component 10 | public class WorkflowCoordinator { 11 | private final AgentRegistry agentRegistry; 12 | private final WorkflowStateMachine workflowStateMachine; 13 | private final Map context; 14 | 15 | public WorkflowCoordinator(AgentRegistry agentRegistry, WorkflowStateMachine workflowStateMachine) { 16 | this.agentRegistry = agentRegistry; 17 | this.workflowStateMachine = workflowStateMachine; 18 | this.context = new HashMap<>(); 19 | } 20 | 21 | public void executeWorkflow(String complexTask) { 22 | while (workflowStateMachine.getCurrentState() != WorkflowState.COMPLETED) { 23 | agentRegistry.enabledAgents().values().forEach(agent -> { 24 | if (agent.canPerform(workflowStateMachine.getCurrentState())) { 25 | agent.performTask(complexTask, context); 26 | } 27 | }); 28 | workflowStateMachine.transition(); 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/workflow/WorkflowState.java: 
-------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.workflow; 2 | 3 | public enum WorkflowState { 4 | DATA_COLLECTION, 5 | DATA_ANALYSIS, 6 | COMPLETED 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/dev/surly/ai/collab/workflow/WorkflowStateMachine.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.workflow; 2 | 3 | import org.springframework.stereotype.Component; 4 | 5 | @Component 6 | public class WorkflowStateMachine { 7 | private WorkflowState currentState; 8 | 9 | public WorkflowStateMachine() { 10 | this.currentState = WorkflowState.DATA_COLLECTION; 11 | } 12 | 13 | public void transition() { 14 | switch (currentState) { 15 | case DATA_COLLECTION: 16 | currentState = WorkflowState.DATA_ANALYSIS; 17 | break; 18 | case DATA_ANALYSIS: 19 | currentState = WorkflowState.COMPLETED; 20 | break; 21 | case COMPLETED: 22 | break; 23 | } 24 | } 25 | 26 | public WorkflowState getCurrentState() { 27 | return currentState; 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/resources/application.properties: -------------------------------------------------------------------------------- 1 | server.port=8081 2 | 3 | #management.endpoints.web.exposure.include=* 4 | #management.endpoint.health.show-details=always 5 | 6 | 7 | # set to false to disable devtools 8 | spring.devtools.add-properties=true 9 | 10 | server.error.whitelabel.enabled=false 11 | server.error.path=/error 12 | 13 | spring.mvc.hiddenmethod.filter.enabled=true 14 | 15 | spring.threads.virtual.enabled=true 16 | 17 | # Spring AI with OpenAI 18 | # OpenAI Chat properties: https://docs.spring.io/spring-ai/reference/api/clients/openai-chat.html#_chat_properties 19 | spring.ai.openai.api-key=${OPENAI_API_KEY} 20 | spring.ai.openai.chat.options.model=gpt-4o 21 | 
spring.ai.openai.chat.options.user=spring-ai-collab 22 | spring.ai.openai.chat.options.topP=0.1 23 | 24 | logging.level.dev.surly.ai.collab.log=debug -------------------------------------------------------------------------------- /src/main/resources/documents/health-benefits.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thesurlydev/spring-ai-collab/689fa1650406d2cb9149715b2ac30de8f64c4108/src/main/resources/documents/health-benefits.pdf -------------------------------------------------------------------------------- /src/main/resources/prompts/agent-company-focus.st: -------------------------------------------------------------------------------- 1 | Given the name of a company, return a summary of the company's focus in two to three paragraphs. 2 | 3 | The company name is: {companyName} 4 | -------------------------------------------------------------------------------- /src/main/resources/prompts/agent-determine-programming-language.st: -------------------------------------------------------------------------------- 1 | Given the following code snippet, return the name of the programming language the code is written in. 2 | Return a single word that is the name of the programming language. 3 | Look for any keywords or syntax that are unique to the programming language. 4 | Look for any words that are unique to libraries or frameworks of the programming language. 5 | Think carefully and make an educated guess if you're not sure. 6 | Return just the name of the programming language and nothing else. No extra text or punctuation. 7 | 8 | The code snippet is: 9 | {code} 10 | 11 | -------------------------------------------------------------------------------- /src/main/resources/prompts/agent-job-rating.st: -------------------------------------------------------------------------------- 1 | You are a helpful job recruiter tasked with finding job leads for a job seeker. 
2 | 3 | Given a job description determine the rating from 1 to 10. 4 | 1 meaning the job description is not a good fit and 10 meaning the job description is a perfect fit. 5 | 6 | The rating is determined by: 7 | 1. Comparing the job description to a list of desired attributes for a job. 8 | 2. Comparing the job description to a list of experience the job seeker has. 9 | 3. Comparing the job description to a list of interests the job seeker has. 10 | 11 | The following is the job description: 12 | 13 | {jobDescription} 14 | 15 | If the job description contains a qualification the job seeker has, the rating should increase. 16 | If the job seeker does not have a requirement the job description has, the rating should decrease. 17 | If the job seeker has at least one but not all the skills or experience for a requirement, the rating should be unaffected. 18 | 19 | The following is a list of qualifications the job seeker has: 20 | {qualifications} 21 | 22 | The following is a list of interests the job seeker has: 23 | {interests} 24 | 25 | You should return the list of requirements taken directly from the job description. 26 | You should return a description of the job role taken directly from the job description. 27 | You should return a description of the company taken directly from the job description. 28 | You should return the rating along with detracting and enhancing factors in the following format: 29 | 30 | {format} 31 | -------------------------------------------------------------------------------- /src/main/resources/prompts/choose-tool-args-no-format.st: -------------------------------------------------------------------------------- 1 | Given a task and the signature of a method extract the values from the task necessary to populate the method arguments. 2 | 3 | If the method requires a URL, extract the URL from the task including query parameters. 
4 | 5 | If the method requires a String, then: 6 | - extract the portion of task that makes sense for the tool description. 7 | - just return a String and nothing else in your response. 8 | - never include quotes or brackets in your response. 9 | - If the task description contains ticks or quotes, then prioritize returning just the characters inside the ticks or quotes. 10 | 11 | If the method has no arguments return an empty list. 12 | 13 | Only return the method arguments and nothing else. 14 | 15 | The task is: 16 | {task} 17 | 18 | The method signature is: 19 | {signature} 20 | -------------------------------------------------------------------------------- /src/main/resources/prompts/choose-tool-args.st: -------------------------------------------------------------------------------- 1 | Given a task and the signature of a method extract the values from the task necessary to populate the method arguments. 2 | If the method requires a URL, extract the URL from the task including query parameters. 3 | Return just the method arguments. 4 | If the method has no arguments return an empty list. 5 | 6 | The task is: 7 | {task} 8 | 9 | The method signature is: 10 | {signature} 11 | 12 | {format} -------------------------------------------------------------------------------- /src/main/resources/prompts/choose-tool.st: -------------------------------------------------------------------------------- 1 | Given a task and a list of tools, determine which of the tools is most capable of accomplishing the task. 2 | Return just the name of the tool and nothing else. 3 | If no tool is capable of accomplishing the task, return "__NO_TOOL__" and nothing else. 4 | 5 | The task is: 6 | {task} 7 | 8 | Each tool has a name and a description. 
Here are the list of tools: 9 | {tools} 10 | -------------------------------------------------------------------------------- /src/main/resources/prompts/task-planner-choose-agent.st: -------------------------------------------------------------------------------- 1 | Given a task and a list of agents, determine which of the agents is most capable of accomplishing the task. 2 | Return just the name of the agent with no extra commentary. 3 | If no agent is capable of accomplishing the task, just return "DefaultLLMAgent" and nothing else. 4 | 5 | The task is: 6 | {task} 7 | 8 | Each agent has a name and a goal. Here are the list of agents: 9 | {agents} 10 | 11 | -------------------------------------------------------------------------------- /src/main/resources/prompts/task-planner-choose-agents.st: -------------------------------------------------------------------------------- 1 | Given a list of tasks and available agents, determine which agent is most capable of accomplishing the task. 2 | If you are unable to determine which agent is most capable of accomplishing the task, return None. 3 | 4 | The tasks are: 5 | {tasks} 6 | 7 | Each agent has a name, a goal, and a list of tools available from the agent. Here are the list of agents: 8 | {agents} 9 | 10 | The map should be keyed by the task and the value should be the agent most capable of accomplishing the task. 11 | It's extremely important that the returned map size should equal the number of tasks. 12 | {format} 13 | -------------------------------------------------------------------------------- /src/main/resources/templates/404.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | not found 4 | 5 | 6 |
7 |

404 Not Found

8 |

The page you were looking for is not found.

9 |
10 | 11 | -------------------------------------------------------------------------------- /src/main/resources/templates/error.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | error 4 | 5 | 6 |
7 |

500 Internal Server Error

8 |
9 | 10 | -------------------------------------------------------------------------------- /src/main/resources/templates/fragments/footer.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 |
6 | © 2024 Shane Witbeck. surly.dev 7 |
8 | 9 | -------------------------------------------------------------------------------- /src/main/resources/templates/fragments/header.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | -------------------------------------------------------------------------------- /src/main/resources/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Spring AI Collab 4 | 5 | 6 |
7 |
8 | 9 |
10 | 11 |
12 | 13 |
14 |
15 | 16 |
17 | 18 |
19 | 23 |
24 |
25 | 26 |
27 |
28 |
29 | 30 |
31 |
32 |
33 | 34 |
35 | 36 |

Agent: agent

37 |

Tool: tool

38 |

Type: type

39 | 40 |
41 | 42 | display task result as pretty json 44 |
display task result as paragraph
46 | generate image 48 | 49 |
50 |

Timings (duration)

51 | 52 | 53 | 54 | 55 | 56 |
labeltime
57 |
58 |
59 | 60 | -------------------------------------------------------------------------------- /src/main/resources/templates/layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Layout Title 5 | 6 | 7 | 8 | 9 | 10 |
11 |
12 | 13 |
14 |

Layout content

15 |
16 | 17 |
18 |
19 |
20 |
21 | 22 | 23 | -------------------------------------------------------------------------------- /src/test/java/dev/surly/ai/collab/agent/AgentRegistryTest.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.agent; 2 | 3 | import org.junit.jupiter.api.Test; 4 | import org.springframework.ai.chat.messages.Message; 5 | import org.springframework.ai.chat.prompt.PromptTemplate; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.boot.test.context.SpringBootTest; 8 | import org.stringtemplate.v4.ST; 9 | 10 | import java.util.Arrays; 11 | import java.util.List; 12 | import java.util.Map; 13 | 14 | import static org.junit.jupiter.api.Assertions.*; 15 | 16 | @SpringBootTest 17 | public class AgentRegistryTest { 18 | 19 | @Autowired AgentRegistry agentRegistry; 20 | 21 | @Test 22 | public void registryContainsAgents() { 23 | Map allAgents = agentRegistry.allAgents(); 24 | assertFalse(allAgents.isEmpty()); 25 | } 26 | 27 | @Test 28 | public void testPromptRendering() { 29 | String templateString = "The items are:\n{items :{item | - {item}\n }}"; 30 | List itemList = Arrays.asList("apple", "banana", "cherry"); 31 | PromptTemplate promptTemplate = new PromptTemplate(templateString); 32 | Message message = promptTemplate.createMessage(Map.of("items", itemList)); 33 | 34 | String expected = "The items are:\n" + 35 | "- apple\n" + 36 | "- banana\n" + 37 | "- cherry\n"; 38 | 39 | assertEquals(expected, message.getContent()); 40 | } 41 | 42 | @Test 43 | public void testPromptRendering2() { 44 | String templateString = "The items are:\n{items:{item| - {item}\n}}"; 45 | List itemList = Arrays.asList("apple", "banana", "cherry"); 46 | PromptTemplate promptTemplate = new PromptTemplate(templateString); 47 | Message message = promptTemplate.createMessage(Map.of("items", itemList)); 48 | 49 | String expected = "The items are:\n" + 50 | "- apple\n" + 51 | "- banana\n" + 52 | "- 
cherry\n"; 53 | 54 | assertEquals(expected, message.getContent()); 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/test/java/dev/surly/ai/collab/flow/FlowExecutionTest.java: -------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.flow; 2 | 3 | import dev.surly.ai.collab.task.AgentTaskExecutor; 4 | import dev.surly.ai.collab.task.Task; 5 | import org.junit.jupiter.api.Disabled; 6 | import org.junit.jupiter.api.Test; 7 | import org.springframework.beans.factory.annotation.Autowired; 8 | import org.springframework.boot.test.context.SpringBootTest; 9 | 10 | @SpringBootTest 11 | public class FlowExecutionTest { 12 | @Autowired AgentTaskExecutor agentTaskExecutor; 13 | 14 | 15 | @Test 16 | @Disabled("Requires OpenAI API key") 17 | public void test() { 18 | 19 | Flow flow = new ParallelFlow(agentTaskExecutor); 20 | 21 | Task sayHiTask = new Task("say hi"); 22 | flow.addTask(sayHiTask); 23 | 24 | Task task = new Task(""" 25 | Identify programming language of the following code snippet: 26 | ``` 27 | @Tool(name ="TestWriter", description = "Write comprehensive test cases for a given software class, method, or function") 28 | public String writeTests(String code) { 29 | return "TODO"; 30 | } 31 | ``` 32 | """); 33 | flow.addTask(task); 34 | 35 | Task whoWroteTask = new Task("Who wrote the book Without Remorse?"); 36 | flow.addTask(whoWroteTask); 37 | 38 | Task whoIsTask = new Task("Who is Barack Obama?"); 39 | flow.addTask(whoIsTask); 40 | 41 | Task whoIsTask2 = new Task("What is 2 plus 2?"); 42 | flow.addTask(whoIsTask2); 43 | 44 | FlowExecution flowExecution = new FlowExecution(flow); 45 | 46 | FlowExecutionResult result = flowExecution.execute(); 47 | result.printResults(); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/test/java/dev/surly/ai/collab/task/TaskDeconstructorTest.java: 
-------------------------------------------------------------------------------- 1 | package dev.surly.ai.collab.task; 2 | 3 | import org.junit.jupiter.api.Test; 4 | import org.springframework.beans.factory.annotation.Autowired; 5 | import org.springframework.boot.test.context.SpringBootTest; 6 | 7 | import java.util.List; 8 | 9 | import static org.junit.jupiter.api.Assertions.*; 10 | 11 | @SpringBootTest 12 | public class TaskDeconstructorTest { 13 | 14 | @Autowired 15 | TaskDeconstructor taskDeconstructor; 16 | 17 | @Test 18 | public void testDeconstruct_No_Subtasks() { 19 | var task = new Task("Give me information about Alphabet"); 20 | var subtasks = taskDeconstructor.deconstruct(List.of(task)); 21 | assertFalse(subtasks.isEmpty()); 22 | assertEquals(1, subtasks.size()); 23 | } 24 | 25 | @Test 26 | public void testDeconstruct_Scrape_No_Subtasks() { 27 | var task = new Task("scrape yahoo.com"); 28 | var subtasks = taskDeconstructor.deconstruct(List.of(task)); 29 | assertFalse(subtasks.isEmpty()); 30 | assertEquals(1, subtasks.size()); 31 | } 32 | 33 | @Test 34 | public void testDeconstruct_Subtasks() { 35 | var task = new Task("Search for the top 5 results for Java, then scrape each page"); 36 | var subtasks = taskDeconstructor.deconstruct(List.of(task)); 37 | assertEquals(2, subtasks.size()); 38 | assertEquals("Search for the top 5 results for Java", subtasks.getFirst().getDescription()); 39 | assertEquals("scrape each page", subtasks.get(1).getDescription()); 40 | } 41 | } 42 | --------------------------------------------------------------------------------