├── .gitignore ├── .mvn └── wrapper │ ├── maven-wrapper.jar │ └── maven-wrapper.properties ├── LICENSE ├── README.md ├── mvnw ├── mvnw.cmd ├── pom.xml └── src ├── main └── java │ └── io │ └── github │ └── rbajek │ └── rasa │ └── sdk │ ├── ActionExecutor.java │ ├── CollectingDispatcher.java │ ├── VersionChecker.java │ ├── action │ ├── Action.java │ └── form │ │ ├── AbstractFormAction.java │ │ └── slot │ │ └── mapper │ │ ├── AbstractSlotMapping.java │ │ ├── EntitySlotMapping.java │ │ ├── IntentSlotMapping.java │ │ ├── TextSlotMapping.java │ │ └── TriggerIntentSlotMapping.java │ ├── dto │ ├── ActionRequest.java │ ├── ActionResponse.java │ ├── Domain.java │ ├── Tracker.java │ └── event │ │ ├── AbstractEvent.java │ │ ├── ActionExecuted.java │ │ ├── ActionExecutionRejected.java │ │ ├── ActionReverted.java │ │ ├── AgentUttered.java │ │ ├── AllSlotsReset.java │ │ ├── BotUttered.java │ │ ├── ConversationPaused.java │ │ ├── ConversationResumed.java │ │ ├── FollowupAction.java │ │ ├── Form.java │ │ ├── FormValidation.java │ │ ├── ReminderCancelled.java │ │ ├── ReminderScheduled.java │ │ ├── Restarted.java │ │ ├── SlotSet.java │ │ ├── StoryExported.java │ │ ├── UserUtteranceReverted.java │ │ └── UserUttered.java │ ├── exception │ ├── ActionExecutionRejectionException.java │ └── RasaException.java │ └── util │ ├── CollectionsUtils.java │ ├── SerializationUtils.java │ └── StringUtils.java └── test ├── java └── io │ └── github │ └── rbajek │ └── rasa │ └── sdk │ ├── ActionExecutorTest.java │ ├── action │ └── form │ │ └── AbstractFormActionTest.java │ └── repository │ └── databuilder │ └── tracker │ ├── EntityBuilder.java │ ├── FormBuilder.java │ ├── IntentBuilder.java │ ├── MessageBuilder.java │ └── TrackerBuilder.java └── resources └── log4j2.xml /.gitignore: -------------------------------------------------------------------------------- 1 | .classpath 2 | .project 3 | *.iml 4 | .settings/ 5 | target/ 
-------------------------------------------------------------------------------- /.mvn/wrapper/maven-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rbajek/rasa-java-sdk/fb0eff74375ed44edc03547636d0da0f4e7cf6c4/.mvn/wrapper/maven-wrapper.jar -------------------------------------------------------------------------------- /.mvn/wrapper/maven-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.6.1/apache-maven-3.6.1-bin.zip 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rasa Java-SDK 2 | 3 | Java SDK for the development of custom actions for [Rasa](https://rasa.com/). 4 | 5 | When you want your Rasa bot to do something advanced (e.g. call an external API, execute some business logic, etc.), you should use a [custom action](https://rasa.com/docs/rasa/core/actions/#custom-actions). 6 | 7 | With this SDK you can easily create custom actions written in Java. You can focus only on the business logic; everything else is handled by the SDK. 8 | 9 | In order to run a Rasa custom action, you need to have an action server.
Make sure your action server is running and the URL for the action server is correct in the `endpoints.yml` file (see [custom action](https://rasa.com/docs/rasa/core/actions/#custom-actions)): 10 | 11 | For example: 12 | 13 | ```yml 14 | action_endpoint: 15 | url: "http://localhost:5055/webhook" 16 | ``` 17 | 18 | ### Java version 19 | 20 | The SDK is compatible with Java 1.8+ 21 | 22 | ### Maven Repository 23 | 24 | The SDK is available in the Maven Central Repository: 25 | 26 | ```xml 27 | <dependency> 28 | <groupId>io.github.rbajek</groupId> 29 | <artifactId>rasa-java-sdk</artifactId> 30 | <version>1.0.1</version> 31 | </dependency> 32 | ``` 33 | 34 | ## Compatibility with Rasa 35 | 36 | | SDK version | compatible Rasa version | 37 | |----------------|-----------------------------------| 38 | | `1.0.x` | `>=1.4.x` | 39 | 40 | ## Usage 41 | 42 | Let's assume that we have a restaurant bot and when the user says "show me a Mexican restaurant", our bot should return the restaurant from the database. 43 | 44 | So, we can create a custom action (let's call it "action_check_restaurants") which might look like this: 45 | 46 | ```java 47 | import io.github.rbajek.rasa.sdk.CollectingDispatcher; 48 | import io.github.rbajek.rasa.sdk.action.Action; 49 | import io.github.rbajek.rasa.sdk.dto.Domain; 50 | import io.github.rbajek.rasa.sdk.dto.Tracker; 51 | import io.github.rbajek.rasa.sdk.dto.event.AbstractEvent; 52 | import io.github.rbajek.rasa.sdk.dto.event.SlotSet; 53 | 54 | import java.util.Arrays; 55 | import java.util.List; 56 | 57 | public class ActionCheckRestaurants implements Action { 58 | 59 | @Override 60 | public String name() { 61 | return "action_check_restaurants"; 62 | } 63 | 64 | private String readFromRestaurantDatabase(String cuisine) { 65 | //TODO should be implemented 66 | return null; 67 | } 68 | 69 | @Override 70 | public List<AbstractEvent> run(CollectingDispatcher collectingDispatcher, Tracker tracker, Domain domain) { 71 | String cuisine = tracker.getSlotValue("cuisine", String.class); 72 | 73 | // read from database 74 | String restaurant =
readFromRestaurantDatabase(cuisine); 75 | 76 | // return result of the action 77 | return Arrays.asList(new SlotSet("matches", restaurant)); 78 | } 79 | } 80 | ``` 81 | Currently, the SDK supports two types of custom actions: 82 | - **general** - which corresponds to [Rasa Custom Action](https://rasa.com/docs/rasa/core/actions/#custom-actions). To use this kind of action, create a Java class which implements the ``io.github.rbajek.rasa.sdk.action.Action`` interface. [Here](https://github.com/rbajek/rasa-java-action-server/blob/master/src/main/java/io/github/rbajek/rasa/action/server/action/custom/joke/ActionJoke.java) you can find an example (based on the [original example](https://rasa.com/docs/rasa/user-guide/running-rasa-with-docker/#creating-a-custom-action)). 83 | - **forms** - which corresponds to [Rasa Forms](https://rasa.com/docs/rasa/core/forms/). This kind of action should extend the ``io.github.rbajek.rasa.sdk.action.form.AbstractFormAction``. [Here](https://github.com/rbajek/rasa-java-action-server/blob/master/src/main/java/io/github/rbajek/rasa/action/server/action/custom/form/restaurant/RestaurantFormAction.java) you can find an example (which implements the functionality of the [Restaurant Form](https://blog.rasa.com/building-contextual-assistants-with-rasa-formaction/)). 84 | 85 | Afterwards, we have to register our action within the `ActionExecutor` (which is part of the SDK) and run it. The response should be returned to Rasa in JSON format. 86 | 87 | To run the custom action, Rasa needs the action server, which exposes a [REST API](https://rasa.com/docs/rasa/api/action-server/), which can be called to run a custom action. So, we need to have a REST endpoint in our system which can consume Rasa's JSON request, run the custom action, and return the response in JSON format. 88 | 89 | A simple REST endpoint which can handle requests from Rasa is shown below:
90 | 91 | ```java 92 | package io.example; 93 | 94 | import io.github.rbajek.rasa.sdk.ActionExecutor; 95 | import io.github.rbajek.rasa.sdk.dto.ActionRequest; 96 | import io.github.rbajek.rasa.sdk.dto.ActionResponse; 97 | 98 | import javax.ws.rs.GET; 99 | import javax.ws.rs.POST; 100 | import javax.ws.rs.Path; 101 | import javax.ws.rs.Produces; 102 | import javax.ws.rs.core.MediaType; 103 | 104 | @Path("/webhook") 105 | public class RasaWebhook { 106 | 107 | @POST 108 | @Produces(MediaType.APPLICATION_JSON) 109 | public ActionResponse handleAction(ActionRequest request) { 110 | // create instance of the action executor 111 | ActionExecutor actionExecutor = new ActionExecutor(); 112 | 113 | // register custom action 114 | actionExecutor.registerAction(new ActionCheckRestaurants()); 115 | 116 | // run custom action and return result 117 | return actionExecutor.run(request); 118 | } 119 | } 120 | ``` 121 | 122 | Last but not least, we need to set the URL of our endpoint in the "endpoints.yml": 123 | 124 | ```yml 125 | action_endpoint: 126 | url: "http://<host>:<port>/webhook" 127 | ``` 128 | 129 | ### Action Server 130 | 131 | To run custom actions, a REST endpoint which can handle requests from Rasa is required. A system which does this is called an `Action Server`. You can use your existing system as the action server or you can use the [Rasa Java Action Server](https://github.com/rbajek/rasa-java-action-server) as a starting point for your server. 132 | 133 | #### 1. Using your existing system 134 | 135 | If you already have your own system which can expose a REST API, you can simply use it. 136 | 137 | 1. Add the dependency to the Java SDK: 138 | 139 | The SDK is available in the Maven Central Repository: 140 | 141 | ```xml 142 | <dependency> 143 | <groupId>io.github.rbajek</groupId> 144 | <artifactId>rasa-java-sdk</artifactId> 145 | <version>1.0.1</version> 146 | </dependency> 147 | ``` 148 | 2. Expose a REST endpoint which can receive calls from Rasa.
This endpoint should handle POST requests and map the input JSON request to 149 | 150 | ``` 151 | io.github.rbajek.rasa.sdk.dto.ActionRequest 152 | ``` 153 | 154 | register your custom action within the 155 | 156 | ``` 157 | io.github.rbajek.rasa.sdk.ActionExecutor 158 | ``` 159 | 160 | run it, and return the response in JSON format 161 | 162 | #### 2. Rasa Java Action Server 163 | 164 | If you don't have your own system or if you would like to start from scratch, then you can use the already created [Rasa Java Action Server](https://github.com/rbajek/rasa-java-action-server) as a starting point. 165 | -------------------------------------------------------------------------------- /mvnw: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # ---------------------------------------------------------------------------- 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License.
19 | # ---------------------------------------------------------------------------- 20 | 21 | # ---------------------------------------------------------------------------- 22 | # Maven2 Start Up Batch script 23 | # 24 | # Required ENV vars: 25 | # ------------------ 26 | # JAVA_HOME - location of a JDK home dir 27 | # 28 | # Optional ENV vars 29 | # ----------------- 30 | # M2_HOME - location of maven2's installed home dir 31 | # MAVEN_OPTS - parameters passed to the Java VM when running Maven 32 | # e.g. to debug Maven itself, use 33 | # set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 34 | # MAVEN_SKIP_RC - flag to disable loading of mavenrc files 35 | # ---------------------------------------------------------------------------- 36 | 37 | if [ -z "$MAVEN_SKIP_RC" ] ; then 38 | 39 | if [ -f /etc/mavenrc ] ; then 40 | . /etc/mavenrc 41 | fi 42 | 43 | if [ -f "$HOME/.mavenrc" ] ; then 44 | . "$HOME/.mavenrc" 45 | fi 46 | 47 | fi 48 | 49 | # OS specific support. $var _must_ be set to either true or false. 
50 | cygwin=false; 51 | darwin=false; 52 | mingw=false 53 | case "`uname`" in 54 | CYGWIN*) cygwin=true ;; 55 | MINGW*) mingw=true;; 56 | Darwin*) darwin=true 57 | # Use /usr/libexec/java_home if available, otherwise fall back to /Library/Java/Home 58 | # See https://developer.apple.com/library/mac/qa/qa1170/_index.html 59 | if [ -z "$JAVA_HOME" ]; then 60 | if [ -x "/usr/libexec/java_home" ]; then 61 | export JAVA_HOME="`/usr/libexec/java_home`" 62 | else 63 | export JAVA_HOME="/Library/Java/Home" 64 | fi 65 | fi 66 | ;; 67 | esac 68 | 69 | if [ -z "$JAVA_HOME" ] ; then 70 | if [ -r /etc/gentoo-release ] ; then 71 | JAVA_HOME=`java-config --jre-home` 72 | fi 73 | fi 74 | 75 | if [ -z "$M2_HOME" ] ; then 76 | ## resolve links - $0 may be a link to maven's home 77 | PRG="$0" 78 | 79 | # need this for relative symlinks 80 | while [ -h "$PRG" ] ; do 81 | ls=`ls -ld "$PRG"` 82 | link=`expr "$ls" : '.*-> \(.*\)$'` 83 | if expr "$link" : '/.*' > /dev/null; then 84 | PRG="$link" 85 | else 86 | PRG="`dirname "$PRG"`/$link" 87 | fi 88 | done 89 | 90 | saveddir=`pwd` 91 | 92 | M2_HOME=`dirname "$PRG"`/.. 93 | 94 | # make it fully qualified 95 | M2_HOME=`cd "$M2_HOME" && pwd` 96 | 97 | cd "$saveddir" 98 | # echo Using m2 at $M2_HOME 99 | fi 100 | 101 | # For Cygwin, ensure paths are in UNIX format before anything is touched 102 | if $cygwin ; then 103 | [ -n "$M2_HOME" ] && 104 | M2_HOME=`cygpath --unix "$M2_HOME"` 105 | [ -n "$JAVA_HOME" ] && 106 | JAVA_HOME=`cygpath --unix "$JAVA_HOME"` 107 | [ -n "$CLASSPATH" ] && 108 | CLASSPATH=`cygpath --path --unix "$CLASSPATH"` 109 | fi 110 | 111 | # For Mingw, ensure paths are in UNIX format before anything is touched 112 | if $mingw ; then 113 | [ -n "$M2_HOME" ] && 114 | M2_HOME="`(cd "$M2_HOME"; pwd)`" 115 | [ -n "$JAVA_HOME" ] && 116 | JAVA_HOME="`(cd "$JAVA_HOME"; pwd)`" 117 | # TODO classpath? 118 | fi 119 | 120 | if [ -z "$JAVA_HOME" ]; then 121 | javaExecutable="`which javac`" 122 | if [ -n "$javaExecutable" ] && ! 
[ "`expr \"$javaExecutable\" : '\([^ ]*\)'`" = "no" ]; then 123 | # readlink(1) is not available as standard on Solaris 10. 124 | readLink=`which readlink` 125 | if [ ! `expr "$readLink" : '\([^ ]*\)'` = "no" ]; then 126 | if $darwin ; then 127 | javaHome="`dirname \"$javaExecutable\"`" 128 | javaExecutable="`cd \"$javaHome\" && pwd -P`/javac" 129 | else 130 | javaExecutable="`readlink -f \"$javaExecutable\"`" 131 | fi 132 | javaHome="`dirname \"$javaExecutable\"`" 133 | javaHome=`expr "$javaHome" : '\(.*\)/bin'` 134 | JAVA_HOME="$javaHome" 135 | export JAVA_HOME 136 | fi 137 | fi 138 | fi 139 | 140 | if [ -z "$JAVACMD" ] ; then 141 | if [ -n "$JAVA_HOME" ] ; then 142 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 143 | # IBM's JDK on AIX uses strange locations for the executables 144 | JAVACMD="$JAVA_HOME/jre/sh/java" 145 | else 146 | JAVACMD="$JAVA_HOME/bin/java" 147 | fi 148 | else 149 | JAVACMD="`which java`" 150 | fi 151 | fi 152 | 153 | if [ ! -x "$JAVACMD" ] ; then 154 | echo "Error: JAVA_HOME is not defined correctly." >&2 155 | echo " We cannot execute $JAVACMD" >&2 156 | exit 1 157 | fi 158 | 159 | if [ -z "$JAVA_HOME" ] ; then 160 | echo "Warning: JAVA_HOME environment variable is not set." 
161 | fi 162 | 163 | CLASSWORLDS_LAUNCHER=org.codehaus.plexus.classworlds.launcher.Launcher 164 | 165 | # traverses directory structure from process work directory to filesystem root 166 | # first directory with .mvn subdirectory is considered project base directory 167 | find_maven_basedir() { 168 | 169 | if [ -z "$1" ] 170 | then 171 | echo "Path not specified to find_maven_basedir" 172 | return 1 173 | fi 174 | 175 | basedir="$1" 176 | wdir="$1" 177 | while [ "$wdir" != '/' ] ; do 178 | if [ -d "$wdir"/.mvn ] ; then 179 | basedir=$wdir 180 | break 181 | fi 182 | # workaround for JBEAP-8937 (on Solaris 10/Sparc) 183 | if [ -d "${wdir}" ]; then 184 | wdir=`cd "$wdir/.."; pwd` 185 | fi 186 | # end of workaround 187 | done 188 | echo "${basedir}" 189 | } 190 | 191 | # concatenates all lines of a file 192 | concat_lines() { 193 | if [ -f "$1" ]; then 194 | echo "$(tr -s '\n' ' ' < "$1")" 195 | fi 196 | } 197 | 198 | BASE_DIR=`find_maven_basedir "$(pwd)"` 199 | if [ -z "$BASE_DIR" ]; then 200 | exit 1; 201 | fi 202 | 203 | ########################################################################################## 204 | # Extension to allow automatically downloading the maven-wrapper.jar from Maven-central 205 | # This allows using the maven wrapper in projects that prohibit checking in binary data. 206 | ########################################################################################## 207 | if [ -r "$BASE_DIR/.mvn/wrapper/maven-wrapper.jar" ]; then 208 | if [ "$MVNW_VERBOSE" = true ]; then 209 | echo "Found .mvn/wrapper/maven-wrapper.jar" 210 | fi 211 | else 212 | if [ "$MVNW_VERBOSE" = true ]; then 213 | echo "Couldn't find .mvn/wrapper/maven-wrapper.jar, downloading it ..." 
214 | fi 215 | jarUrl="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.4.2/maven-wrapper-0.4.2.jar" 216 | while IFS="=" read key value; do 217 | case "$key" in (wrapperUrl) jarUrl="$value"; break ;; 218 | esac 219 | done < "$BASE_DIR/.mvn/wrapper/maven-wrapper.properties" 220 | if [ "$MVNW_VERBOSE" = true ]; then 221 | echo "Downloading from: $jarUrl" 222 | fi 223 | wrapperJarPath="$BASE_DIR/.mvn/wrapper/maven-wrapper.jar" 224 | 225 | if command -v wget > /dev/null; then 226 | if [ "$MVNW_VERBOSE" = true ]; then 227 | echo "Found wget ... using wget" 228 | fi 229 | wget "$jarUrl" -O "$wrapperJarPath" 230 | elif command -v curl > /dev/null; then 231 | if [ "$MVNW_VERBOSE" = true ]; then 232 | echo "Found curl ... using curl" 233 | fi 234 | curl -o "$wrapperJarPath" "$jarUrl" 235 | else 236 | if [ "$MVNW_VERBOSE" = true ]; then 237 | echo "Falling back to using Java to download" 238 | fi 239 | javaClass="$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.java" 240 | if [ -e "$javaClass" ]; then 241 | if [ ! -e "$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.class" ]; then 242 | if [ "$MVNW_VERBOSE" = true ]; then 243 | echo " - Compiling MavenWrapperDownloader.java ..." 244 | fi 245 | # Compiling the Java class 246 | ("$JAVA_HOME/bin/javac" "$javaClass") 247 | fi 248 | if [ -e "$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.class" ]; then 249 | # Running the downloader 250 | if [ "$MVNW_VERBOSE" = true ]; then 251 | echo " - Running MavenWrapperDownloader.java ..." 
252 | fi 253 | ("$JAVA_HOME/bin/java" -cp .mvn/wrapper MavenWrapperDownloader "$MAVEN_PROJECTBASEDIR") 254 | fi 255 | fi 256 | fi 257 | fi 258 | ########################################################################################## 259 | # End of extension 260 | ########################################################################################## 261 | 262 | export MAVEN_PROJECTBASEDIR=${MAVEN_BASEDIR:-"$BASE_DIR"} 263 | if [ "$MVNW_VERBOSE" = true ]; then 264 | echo $MAVEN_PROJECTBASEDIR 265 | fi 266 | MAVEN_OPTS="$(concat_lines "$MAVEN_PROJECTBASEDIR/.mvn/jvm.config") $MAVEN_OPTS" 267 | 268 | # For Cygwin, switch paths to Windows format before running java 269 | if $cygwin; then 270 | [ -n "$M2_HOME" ] && 271 | M2_HOME=`cygpath --path --windows "$M2_HOME"` 272 | [ -n "$JAVA_HOME" ] && 273 | JAVA_HOME=`cygpath --path --windows "$JAVA_HOME"` 274 | [ -n "$CLASSPATH" ] && 275 | CLASSPATH=`cygpath --path --windows "$CLASSPATH"` 276 | [ -n "$MAVEN_PROJECTBASEDIR" ] && 277 | MAVEN_PROJECTBASEDIR=`cygpath --path --windows "$MAVEN_PROJECTBASEDIR"` 278 | fi 279 | 280 | WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain 281 | 282 | exec "$JAVACMD" \ 283 | $MAVEN_OPTS \ 284 | -classpath "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.jar" \ 285 | "-Dmaven.home=${M2_HOME}" "-Dmaven.multiModuleProjectDirectory=${MAVEN_PROJECTBASEDIR}" \ 286 | ${WRAPPER_LAUNCHER} $MAVEN_CONFIG "$@" 287 | -------------------------------------------------------------------------------- /mvnw.cmd: -------------------------------------------------------------------------------- 1 | @REM ---------------------------------------------------------------------------- 2 | @REM Licensed to the Apache Software Foundation (ASF) under one 3 | @REM or more contributor license agreements. See the NOTICE file 4 | @REM distributed with this work for additional information 5 | @REM regarding copyright ownership. 
The ASF licenses this file 6 | @REM to you under the Apache License, Version 2.0 (the 7 | @REM "License"); you may not use this file except in compliance 8 | @REM with the License. You may obtain a copy of the License at 9 | @REM 10 | @REM http://www.apache.org/licenses/LICENSE-2.0 11 | @REM 12 | @REM Unless required by applicable law or agreed to in writing, 13 | @REM software distributed under the License is distributed on an 14 | @REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | @REM KIND, either express or implied. See the License for the 16 | @REM specific language governing permissions and limitations 17 | @REM under the License. 18 | @REM ---------------------------------------------------------------------------- 19 | 20 | @REM ---------------------------------------------------------------------------- 21 | @REM Maven2 Start Up Batch script 22 | @REM 23 | @REM Required ENV vars: 24 | @REM JAVA_HOME - location of a JDK home dir 25 | @REM 26 | @REM Optional ENV vars 27 | @REM M2_HOME - location of maven2's installed home dir 28 | @REM MAVEN_BATCH_ECHO - set to 'on' to enable the echoing of the batch commands 29 | @REM MAVEN_BATCH_PAUSE - set to 'on' to wait for a key stroke before ending 30 | @REM MAVEN_OPTS - parameters passed to the Java VM when running Maven 31 | @REM e.g. 
to debug Maven itself, use 32 | @REM set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 33 | @REM MAVEN_SKIP_RC - flag to disable loading of mavenrc files 34 | @REM ---------------------------------------------------------------------------- 35 | 36 | @REM Begin all REM lines with '@' in case MAVEN_BATCH_ECHO is 'on' 37 | @echo off 38 | @REM set title of command window 39 | title %0 40 | @REM enable echoing my setting MAVEN_BATCH_ECHO to 'on' 41 | @if "%MAVEN_BATCH_ECHO%" == "on" echo %MAVEN_BATCH_ECHO% 42 | 43 | @REM set %HOME% to equivalent of $HOME 44 | if "%HOME%" == "" (set "HOME=%HOMEDRIVE%%HOMEPATH%") 45 | 46 | @REM Execute a user defined script before this one 47 | if not "%MAVEN_SKIP_RC%" == "" goto skipRcPre 48 | @REM check for pre script, once with legacy .bat ending and once with .cmd ending 49 | if exist "%HOME%\mavenrc_pre.bat" call "%HOME%\mavenrc_pre.bat" 50 | if exist "%HOME%\mavenrc_pre.cmd" call "%HOME%\mavenrc_pre.cmd" 51 | :skipRcPre 52 | 53 | @setlocal 54 | 55 | set ERROR_CODE=0 56 | 57 | @REM To isolate internal variables from possible post scripts, we use another setlocal 58 | @setlocal 59 | 60 | @REM ==== START VALIDATION ==== 61 | if not "%JAVA_HOME%" == "" goto OkJHome 62 | 63 | echo. 64 | echo Error: JAVA_HOME not found in your environment. >&2 65 | echo Please set the JAVA_HOME variable in your environment to match the >&2 66 | echo location of your Java installation. >&2 67 | echo. 68 | goto error 69 | 70 | :OkJHome 71 | if exist "%JAVA_HOME%\bin\java.exe" goto init 72 | 73 | echo. 74 | echo Error: JAVA_HOME is set to an invalid directory. >&2 75 | echo JAVA_HOME = "%JAVA_HOME%" >&2 76 | echo Please set the JAVA_HOME variable in your environment to match the >&2 77 | echo location of your Java installation. >&2 78 | echo. 79 | goto error 80 | 81 | @REM ==== END VALIDATION ==== 82 | 83 | :init 84 | 85 | @REM Find the project base dir, i.e. the directory that contains the folder ".mvn". 
86 | @REM Fallback to current working directory if not found. 87 | 88 | set MAVEN_PROJECTBASEDIR=%MAVEN_BASEDIR% 89 | IF NOT "%MAVEN_PROJECTBASEDIR%"=="" goto endDetectBaseDir 90 | 91 | set EXEC_DIR=%CD% 92 | set WDIR=%EXEC_DIR% 93 | :findBaseDir 94 | IF EXIST "%WDIR%"\.mvn goto baseDirFound 95 | cd .. 96 | IF "%WDIR%"=="%CD%" goto baseDirNotFound 97 | set WDIR=%CD% 98 | goto findBaseDir 99 | 100 | :baseDirFound 101 | set MAVEN_PROJECTBASEDIR=%WDIR% 102 | cd "%EXEC_DIR%" 103 | goto endDetectBaseDir 104 | 105 | :baseDirNotFound 106 | set MAVEN_PROJECTBASEDIR=%EXEC_DIR% 107 | cd "%EXEC_DIR%" 108 | 109 | :endDetectBaseDir 110 | 111 | IF NOT EXIST "%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config" goto endReadAdditionalConfig 112 | 113 | @setlocal EnableExtensions EnableDelayedExpansion 114 | for /F "usebackq delims=" %%a in ("%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config") do set JVM_CONFIG_MAVEN_PROPS=!JVM_CONFIG_MAVEN_PROPS! %%a 115 | @endlocal & set JVM_CONFIG_MAVEN_PROPS=%JVM_CONFIG_MAVEN_PROPS% 116 | 117 | :endReadAdditionalConfig 118 | 119 | SET MAVEN_JAVA_EXE="%JAVA_HOME%\bin\java.exe" 120 | set WRAPPER_JAR="%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.jar" 121 | set WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain 122 | 123 | set DOWNLOAD_URL="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.4.2/maven-wrapper-0.4.2.jar" 124 | FOR /F "tokens=1,2 delims==" %%A IN (%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties) DO ( 125 | IF "%%A"=="wrapperUrl" SET DOWNLOAD_URL=%%B 126 | ) 127 | 128 | @REM Extension to allow automatically downloading the maven-wrapper.jar from Maven-central 129 | @REM This allows using the maven wrapper in projects that prohibit checking in binary data. 130 | if exist %WRAPPER_JAR% ( 131 | echo Found %WRAPPER_JAR% 132 | ) else ( 133 | echo Couldn't find %WRAPPER_JAR%, downloading it ... 
134 | echo Downloading from: %DOWNLOAD_URL% 135 | powershell -Command "(New-Object Net.WebClient).DownloadFile('%DOWNLOAD_URL%', '%WRAPPER_JAR%')" 136 | echo Finished downloading %WRAPPER_JAR% 137 | ) 138 | @REM End of extension 139 | 140 | %MAVEN_JAVA_EXE% %JVM_CONFIG_MAVEN_PROPS% %MAVEN_OPTS% %MAVEN_DEBUG_OPTS% -classpath %WRAPPER_JAR% "-Dmaven.multiModuleProjectDirectory=%MAVEN_PROJECTBASEDIR%" %WRAPPER_LAUNCHER% %MAVEN_CONFIG% %* 141 | if ERRORLEVEL 1 goto error 142 | goto end 143 | 144 | :error 145 | set ERROR_CODE=1 146 | 147 | :end 148 | @endlocal & set ERROR_CODE=%ERROR_CODE% 149 | 150 | if not "%MAVEN_SKIP_RC%" == "" goto skipRcPost 151 | @REM check for post script, once with legacy .bat ending and once with .cmd ending 152 | if exist "%HOME%\mavenrc_post.bat" call "%HOME%\mavenrc_post.bat" 153 | if exist "%HOME%\mavenrc_post.cmd" call "%HOME%\mavenrc_post.cmd" 154 | :skipRcPost 155 | 156 | @REM pause the script if MAVEN_BATCH_PAUSE is set to 'on' 157 | if "%MAVEN_BATCH_PAUSE%" == "on" pause 158 | 159 | if "%MAVEN_TERMINATE_CMD%" == "on" exit %ERROR_CODE% 160 | 161 | exit /B %ERROR_CODE% 162 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | io.github.rbajek 6 | rasa-java-sdk 7 | 1.0.1 8 | 9 | rasa-java-sdk 10 | Java SDK for the development of custom actions for Rasa 11 | https://github.com/rbajek/rasa-java-sdk 12 | 13 | 14 | 15 | Apache License, Version 2.0 16 | http://www.apache.org/licenses/LICENSE-2.0.txt 17 | repo 18 | 19 | 20 | 21 | 22 | 23 | Rafał Bajek 24 | raf.bajek@gmail.com 25 | 26 | 27 | 28 | 29 | scm:git:https://github.com/rbajek/rasa-java-sdk.git 30 | scm:git:git@github.com:rbajek/rasa-java-sdk.git 31 | https://github.com/rbajek/rasa-java-sdk 32 | 33 | 34 | 35 | 36 | ossrh 37 | https://oss.sonatype.org/content/repositories/snapshots 38 | 39 | 40 | ossrh 41 | 
https://oss.sonatype.org/service/local/staging/deploy/maven2/ 42 | 43 | 44 | 45 | 46 | 47 | Github 48 | https://github.com/rbajek/rasa-java-sdk/issues 49 | 50 | 51 | 52 | 53 | UTF-8 54 | UTF-8 55 | 1.8 56 | 57 | 1.7.26 58 | 2.12.0 59 | 1.18.4 60 | 5.5.0 61 | 2.10.0 62 | 63 | 3.8.1 64 | 3.1.0 65 | 3.1.1 66 | 2.22.0 67 | 1.6.8 68 | 1.6 69 | 70 | false 71 | ${skipTests} 72 | ${skipTests} 73 | 74 | 75 | 76 | 77 | org.projectlombok 78 | lombok 79 | ${lombok.version} 80 | provided 81 | 82 | 83 | 84 | 85 | com.fasterxml.jackson.core 86 | jackson-databind 87 | ${fasterxml.jackson.version} 88 | 89 | 90 | 91 | 92 | 93 | org.slf4j 94 | slf4j-api 95 | ${slf4j.version} 96 | 97 | 98 | 99 | 100 | org.apache.logging.log4j 101 | log4j-api 102 | ${log4j.version} 103 | 104 | 105 | org.apache.logging.log4j 106 | log4j-core 107 | ${log4j.version} 108 | 109 | 110 | org.apache.logging.log4j 111 | log4j-slf4j-impl 112 | ${log4j.version} 113 | 114 | 115 | 116 | 117 | org.junit.jupiter 118 | junit-jupiter-api 119 | ${junit.version} 120 | test 121 | 122 | 123 | 124 | org.junit.jupiter 125 | junit-jupiter-engine 126 | ${junit.version} 127 | test 128 | 129 | 130 | 131 | 132 | ${project.name} 133 | 134 | 135 | 136 | maven-compiler-plugin 137 | ${maven-compiler-plugin.version} 138 | 139 | ${java.version} 140 | ${java.version} 141 | 142 | 143 | true 144 | 145 | 146 | 147 | 148 | org.apache.maven.plugins 149 | maven-surefire-plugin 150 | ${maven-surefire-plugin.version} 151 | 152 | ${skipUTs} 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | release 162 | 163 | 164 | 165 | org.apache.maven.plugins 166 | maven-source-plugin 167 | ${maven-source-plugin.version} 168 | 169 | 170 | attach-sources 171 | 172 | jar-no-fork 173 | 174 | 175 | 176 | 177 | 178 | 179 | org.apache.maven.plugins 180 | maven-javadoc-plugin 181 | ${maven-javadoc-plugin.version} 182 | 183 | 184 | attach-javadocs 185 | 186 | jar 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 
204 | 205 | org.apache.maven.plugins 206 | maven-gpg-plugin 207 | ${maven-gpg-plugin.version} 208 | 209 | 210 | sign-artifacts 211 | verify 212 | 213 | sign 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/ActionExecutor.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk; 2 | 3 | import io.github.rbajek.rasa.sdk.action.Action; 4 | import io.github.rbajek.rasa.sdk.dto.ActionRequest; 5 | import io.github.rbajek.rasa.sdk.dto.ActionResponse; 6 | import io.github.rbajek.rasa.sdk.dto.event.AbstractEvent; 7 | import io.github.rbajek.rasa.sdk.exception.RasaException; 8 | import io.github.rbajek.rasa.sdk.util.StringUtils; 9 | import org.slf4j.Logger; 10 | import org.slf4j.LoggerFactory; 11 | 12 | import java.util.*; 13 | 14 | /** 15 | * Action executor 16 | * 17 | * @author Rafał Bajek 18 | */ 19 | public class ActionExecutor { 20 | private static final Logger LOGGER = LoggerFactory.getLogger(ActionExecutor.class); 21 | 22 | private Map actions = new HashMap<>(); 23 | 24 | public void registerAction(Action action) { 25 | if(StringUtils.isNullOrEmpty(action.name())) { 26 | throw new RasaException("An action must implement a name"); 27 | } 28 | this.actions.put(action.name(), action); 29 | LOGGER.info("Registered action for '{}'.", action.name()); 30 | } 31 | 32 | private void validateEvents(List events, String actionName) { 33 | Iterator eventsIterator = events.iterator(); 34 | while (eventsIterator.hasNext()) { 35 | AbstractEvent event = eventsIterator.next(); 36 | if(StringUtils.isNullOrEmpty(event.getEvent())) { 37 | LOGGER.error("Your action '{}' returned an event without the 'event' property. Event will be ignored! 
Event: {}", actionName, event); 38 | eventsIterator.remove(); 39 | } 40 | } 41 | } 42 | 43 | public ActionResponse run(ActionRequest actionRequest) { 44 | // Check for version of Rasa. 45 | VersionChecker.checkVersionCompatibility(actionRequest.getVersion()); 46 | 47 | if(StringUtils.isNotNullOrEmpty(actionRequest.getNextAction())) { 48 | LOGGER.debug("Received request to run '{}'", actionRequest.getNextAction()); 49 | Action action = actions.get(actionRequest.getNextAction()); 50 | if(action == null) { 51 | throw new RasaException("No registered Action found for name '"+actionRequest.getNextAction()+"'."); 52 | } 53 | 54 | CollectingDispatcher dispatcher = new CollectingDispatcher(); 55 | List events = action.run(dispatcher, actionRequest.getTracker(), actionRequest.getDomain()); 56 | if(events == null) { 57 | // make sure the action did not just return "null"... 58 | events = Collections.emptyList(); 59 | } 60 | validateEvents(events, actionRequest.getNextAction()); 61 | LOGGER.debug("Finished running '{}'", actionRequest.getNextAction()); 62 | ActionResponse actionResponse = new ActionResponse(); 63 | actionResponse.setEvents(events); 64 | // Rasa API require list of key-value pair objects 65 | actionResponse.setResponses(Arrays.asList(dispatcher.getMessages())); 66 | 67 | return actionResponse; 68 | } 69 | LOGGER.warn("Received an action call without an action."); 70 | return null; 71 | } 72 | 73 | public List getRegisteredActionNames() { 74 | return new ArrayList<>(this.actions.keySet()); 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/CollectingDispatcher.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk; 2 | 3 | import lombok.Getter; 4 | 5 | import java.util.HashMap; 6 | import java.util.List; 7 | import java.util.Map; 8 | 9 | /** 10 | * Send messages back to user 11 | * 12 | * @author Rafał 
Bajek
 */
@Getter
public class CollectingDispatcher {

    // Keys under which the individual parts of the outgoing message are stored.
    private static final String MSG_TEXT_KEY = "text";
    private static final String MSG_ELEMENTS_KEY = "elements";
    private static final String MSG_BUTTONS_KEY = "buttons";
    private static final String MSG_ATTACHMENT_KEY = "attachment";
    private static final String MSG_TEMPLATE_KEY = "template";
    private static final String MSG_CUSTOM_KEY = "custom";
    private static final String MSG_IMAGE_KEY = "image";

    // Single shared message map: every utter* method below writes into it, so a
    // later call overwrites any key an earlier call (or its kwargs) already set.
    // ActionExecutor.run() wraps this one map into a single-element response list.
    private Map messages = new HashMap<>();

    /**
     * Sends a message with custom elements to the output channel.
     *
     * @param elements array of custom elements
     */
    public void utterElements(Object[] elements) {
        utterElements(elements, null);
    }

    /**
     * Sends a message with custom elements to the output channel.
     * The "text" entry is explicitly reset to null.
     *
     * @param elements array of custom elements
     * @param kwargs additional entries merged into the message after the elements
     *               were set, so they may override them (optional, may be null)
     */
    public void utterElements(Object[] elements, Map kwargs) {
        messages.put(MSG_TEXT_KEY, null);
        messages.put(MSG_ELEMENTS_KEY, elements);
        if(kwargs != null) {
            messages.putAll(kwargs);
        }
    }

    /**
     * Send a text to the output channel
     *
     * @param text a text message
     */
    public void utterMessage(String text) {
        utterMessage(text, null);
    }

    /**
     * Send a text to the output channel
     *
     * @param text a text message
     * @param kwargs additional entries merged into the message after the text
     *               was set, so they may override it (optional, may be null)
     */
    public void utterMessage(String text, Map kwargs) {
        messages.put(MSG_TEXT_KEY, text);
        if(kwargs != null) {
            messages.putAll(kwargs);
        }
    }

    /**
     * Sends a message with buttons to the output channel.
     *
     * @param text a text message
     * @param buttons list of map of buttons
     */
    public void utterButtonMessage(String text, List> buttons) {
        utterButtonMessage(text, buttons, null);
    }

    /**
     * Sends a message with buttons to the output channel.
     *
     * @param text a text message
     * @param buttons list of map of buttons
     * @param kwargs additional entries merged into the message after text and
     *               buttons were set, so they may override them (optional, may be null)
     */
    public void utterButtonMessage(String text, List> buttons, Map kwargs) {
        messages.put(MSG_TEXT_KEY, text);
        messages.put(MSG_BUTTONS_KEY, buttons);
        if(kwargs != null) {
            messages.putAll(kwargs);
        }
    }

    /**
     * Send a message to the client with attachments.
     *
     * @param attachment a attachment
     */
    public void utterAttachment(String attachment) {
        utterAttachment(attachment, null);
    }

    /**
     * Send a message to the client with attachments.
     * The "text" entry is explicitly reset to null.
     *
     * @param attachment a attachment
     * @param kwargs additional entries merged into the message after the
     *               attachment was set, so they may override it (optional, may be null)
     */
    public void utterAttachment(String attachment, Map kwargs) {
        messages.put(MSG_TEXT_KEY, null);
        messages.put(MSG_ATTACHMENT_KEY, attachment);
        if(kwargs != null) {
            messages.putAll(kwargs);
        }
    }

    /**
     * Sends a message template with buttons to the output channel.
     *
     * @param template a template
     * @param buttons list of map of buttons
     */
    public void utterButtonTemplate(String template, List> buttons) {
        utterButtonTemplate(template, buttons, null);
    }

    /**
     * Sends a message template with buttons to the output channel.
     *
     * @param template a template
     * @param buttons list of map of buttons
     * @param kwargs additional entries merged into the message after template
     *               and buttons were set, so they may override them (optional, may be null)
     */
    public void utterButtonTemplate(String template, List> buttons, Map kwargs) {
        messages.put(MSG_TEMPLATE_KEY, template);
        messages.put(MSG_BUTTONS_KEY, buttons);
        if(kwargs != null) {
            messages.putAll(kwargs);
        }
    }

    /**
     * Send a message to the client based on a template.
     *
     * @param template a template
     */
    public void utterTemplate(String template) {
        utterTemplate(template, null);
    }

    /**
     * Send a message to the client based on a template.
     *
     * @param template a template
     * @param kwargs additional entries merged into the message after the
     *               template was set, so they may override it (optional, may be null)
     */
    public void utterTemplate(String template, Map kwargs) {
        messages.put(MSG_TEMPLATE_KEY, template);
        if(kwargs != null) {
            messages.putAll(kwargs);
        }
    }

    /**
     * Sends custom json to the output channel.
     *
     * @param jsonMessage a JSON message
     */
    public void utterCustomJson(String jsonMessage) {
        utterCustomJson(jsonMessage, null);
    }

    /**
     * Sends custom json to the output channel.
     * The JSON is stored as the given string; it is not parsed here.
     *
     * @param jsonMessage a JSON message
     * @param kwargs additional entries merged into the message after the JSON
     *               was set, so they may override it (optional, may be null)
     */
    public void utterCustomJson(String jsonMessage, Map kwargs) {
        messages.put(MSG_CUSTOM_KEY, jsonMessage);
        if(kwargs != null) {
            messages.putAll(kwargs);
        }
    }

    /**
     * Sends url of image attachment to the output channel.
     *
     * @param image an URL of image
     */
    public void utterImageUrl(String image) {
        utterImageUrl(image, null);
    }

    /**
     * Sends url of image attachment to the output channel.
200 | * 201 | * @param image an URL of image 202 | * @param kwargs map to an URLs of images (optional) 203 | */ 204 | public void utterImageUrl(String image, Map kwargs) { 205 | messages.put(MSG_IMAGE_KEY, image); 206 | if(kwargs != null) { 207 | messages.putAll(kwargs); 208 | } 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/VersionChecker.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk; 2 | 3 | import io.github.rbajek.rasa.sdk.util.StringUtils; 4 | import org.slf4j.Logger; 5 | import org.slf4j.LoggerFactory; 6 | 7 | /** 8 | * @author Rafał Bajek 9 | */ 10 | class VersionChecker { 11 | 12 | private static final Logger LOGGER = LoggerFactory.getLogger(VersionChecker.class); 13 | 14 | public static final String SUPPORTED_VERSION = "1.4.0"; 15 | 16 | /** 17 | *

Check if the version of rasa and rasa_sdk are compatible.

18 | * 19 | *

The version check relies on the version string being formatted as 20 | * 'x.y.z' and compares whether the numbers x and y are the same for both 21 | * rasa and rasa_sdk.

22 | * 23 | *

Currently, only warning is logging

24 | * 25 | * @param rasaVersion A string containing the version of rasa that is making the call to the action server. 26 | */ 27 | public static void checkVersionCompatibility(String rasaVersion) { 28 | 29 | // check for versions of Rasa that are too old to report their version number 30 | if(StringUtils.isNullOrEmpty(rasaVersion)) { 31 | LOGGER.warn("You are using an old version of rasa which might not be compatible with this version of rasa_sdk ({}).\n" + 32 | "To ensure compatibility use the same version for both, modulo the last number, i.e. using version A.B.x " + 33 | "the numbers A and B should be identical for both rasa and rasa_sdk.", SUPPORTED_VERSION); 34 | return; 35 | } 36 | 37 | String[] rasa = rasaVersion.split("\\."); 38 | String[] sdk = SUPPORTED_VERSION.split("\\."); 39 | 40 | if(rasa[0].equals(sdk[0]) == false || rasa[1].equals(sdk[1]) == false) { 41 | LOGGER.warn("Your versions of rasa and rasa_sdk might not be compatible. You are currently running rasa version {} " + 42 | "and rasa_sdk version {}.\nTo ensure compatibility use the same version for both, modulo the last number, " + 43 | "i.e. using version A.B.x the numbers A and B should be identical for both rasa and rasa_sdk.", rasaVersion, SUPPORTED_VERSION ); 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/action/Action.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.action; 2 | 3 | import io.github.rbajek.rasa.sdk.CollectingDispatcher; 4 | import io.github.rbajek.rasa.sdk.dto.Domain; 5 | import io.github.rbajek.rasa.sdk.dto.Tracker; 6 | import io.github.rbajek.rasa.sdk.dto.event.AbstractEvent; 7 | 8 | import java.util.List; 9 | import java.util.Map; 10 | 11 | /** 12 | * Next action to be taken in response to a dialogue state. 
*
 * @author Rafał Bajek
 */
public interface Action {

    /**
     * Unique identifier of this action. ActionExecutor uses it as the registry
     * key and rejects actions whose name is null or empty.
     *
     * @return a name of this action
     */
    String name();

    /**
     * Execute the side effects of this action
     *
     * @param dispatcher the dispatcher which is used to send messages back to the user.
     *                   Use {@link CollectingDispatcher#utterMessage(String, Map)} or any other method.
     * @param tracker the state tracker for the current user. You can access slot values using tracker.getSlots().get(slotName) (see: {@link Tracker}),
     *                the most recent user message is tracker.getLatestMessage().getText() (see {@link Tracker}) and any other property.
     * @param domain the bot's domain
     * @return A list of {@link AbstractEvent} instances that is returned through the endpoint
     *         (ActionExecutor treats a null return value like an empty list)
     */
    List run(CollectingDispatcher dispatcher, Tracker tracker, Domain domain);
}
-------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/action/form/AbstractFormAction.java: --------------------------------------------------------------------------------
package io.github.rbajek.rasa.sdk.action.form;

import io.github.rbajek.rasa.sdk.CollectingDispatcher;
import io.github.rbajek.rasa.sdk.action.Action;
import io.github.rbajek.rasa.sdk.action.form.slot.mapper.AbstractSlotMapping;
import io.github.rbajek.rasa.sdk.action.form.slot.mapper.EntitySlotMapping;
import io.github.rbajek.rasa.sdk.action.form.slot.mapper.IntentSlotMapping;
import io.github.rbajek.rasa.sdk.action.form.slot.mapper.TriggerIntentSlotMapping;
import io.github.rbajek.rasa.sdk.dto.Domain;
import io.github.rbajek.rasa.sdk.dto.Tracker;
import io.github.rbajek.rasa.sdk.dto.event.AbstractEvent;
import io.github.rbajek.rasa.sdk.dto.event.Form;
import io.github.rbajek.rasa.sdk.dto.event.SlotSet;
import io.github.rbajek.rasa.sdk.exception.ActionExecutionRejectionException;
import io.github.rbajek.rasa.sdk.exception.RasaException;
import io.github.rbajek.rasa.sdk.util.CollectionsUtils;
import io.github.rbajek.rasa.sdk.util.SerializationUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

/**
 * An abstract form action class
 *
 * @author Rafał Bajek
 */
public abstract class AbstractFormAction implements Action {

    private static final Logger LOGGER = LoggerFactory.getLogger(AbstractFormAction.class);

    /**
     * This slot is used to store information needed to do the form handling
     */
    protected static final String REQUESTED_SLOT = "requested_slot";

    /**
     * Unique identifier of the form
     */
    private final String formName;

    /**
     * Map of slot validators
     */
    private final Map slotValidatorMap = new HashMap<>();

    /**
     * Stores the form name and lets the subclass register its slot validators.
     *
     * NOTE: registerSlotsValidators(...) is an overridable method invoked from
     * this constructor, i.e. before the subclass constructor body and subclass
     * field initializers have run. Implementations must not rely on their own
     * instance state.
     *
     * @param formName unique identifier of the form
     */
    public AbstractFormAction(String formName) {
        this.formName = formName;
        registerSlotsValidators(slotValidatorMap);
    }

    /**
     * A list of required slots that the form has to fill.
     *
     * Use "tracker" to request different list of slots
     * depending on the state of the dialogue
     *
     * @param tracker a {@link Tracker} object
     * @return list of required slots
     */
    protected abstract List requiredSlots(Tracker tracker);

    /**
     * Define what the form has to do
     * after all required slots are filled
     *
     * @param dispatcher a {@link CollectingDispatcher} object
     * @return list of events
     */
    protected abstract List submit(CollectingDispatcher dispatcher);

    /**
     * Register a slot validator
     *
     * Called once from the constructor with the (initially empty) validator map
     * of this form.
     *
     * @param slotValidatorMap map of slot validators which should be filled out in a particular class
     */
    protected abstract void registerSlotsValidators(Map slotValidatorMap);

    /**
     *

A Map to mapping required slots.

81 | * 82 | *

Options:

83 | *
    84 | *
  • an extracted entity
  • 85 | *
  • intent: value pairs
  • 86 | *
  • trigger_intent: value pairs
  • 87 | *
  • a whole message
  • 88 | *
89 | * 90 | *

or a list of them, where the first match will be picked

91 | * 92 | *

Empty map is converted to a mapping of 93 | * the slot to the extracted entity with the same name

94 | * 95 | * @return Map of slots with mappings 96 | */ 97 | protected Map> slotMappings() { 98 | return Collections.emptyMap(); 99 | } 100 | 101 | /** 102 | * Check whether user intent matches intent conditions 103 | * 104 | * @param requestedSlotMapping requested slot mapping 105 | * @param tracker a tracker object 106 | * @return true - if the user intent matches intent conditions. Otherwise - false 107 | */ 108 | private boolean intentIsDesired(AbstractSlotMapping requestedSlotMapping, Tracker tracker) { 109 | List mappingIntents = requestedSlotMapping.getIntent(); 110 | List mappingNotIntents = requestedSlotMapping.getNotIntent(); 111 | String intent = tracker.getLatestMessage().getIntent() != null ? tracker.getLatestMessage().getIntent().getName() : null; 112 | 113 | boolean intentNotBlacklisted = CollectionsUtils.isEmpty(mappingIntents) && mappingNotIntents.contains(intent) == false; 114 | 115 | return intentNotBlacklisted || mappingIntents.contains(intent); 116 | } 117 | 118 | /** 119 | * Logs the values of all required slots before submitting the form. 
120 | * 121 | * @param tracker a {@link Tracker} object 122 | */ 123 | private void logFormSlots(Tracker tracker) { 124 | List requiredSlots = requiredSlots(tracker); 125 | if(CollectionsUtils.isNotEmpty(requiredSlots) && tracker.hasSlots()) { 126 | StringBuilder slotValues = new StringBuilder(); 127 | requiredSlots.forEach(slotName -> { 128 | slotValues.append("\n").append("\t").append(slotName).append(": ").append(tracker.getSlotValue(slotName)); 129 | }); 130 | LOGGER.debug("No slots left to request, all required slots are filled:{}", slotValues); 131 | } 132 | LOGGER.debug("There are no any slots to requests"); 133 | } 134 | 135 | /** 136 | * Activate form if the form is called for the first time 137 | * 138 | * If activating, validate any required slots that were filled before 139 | * form activation and return "Form" event with the name of the form, as well 140 | * 141 | * @param dispatcher a {@link CollectingDispatcher} object 142 | * @param tracker a {@link Tracker} object 143 | * @param domain a {@link Domain} object 144 | * @return list of events 145 | */ 146 | List activateFormIfRequired(CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 147 | List events = new ArrayList<>(); 148 | 149 | if(tracker.hasActiveForm()){ 150 | LOGGER.debug("The form '{}' is active", tracker.getActiveForm().getName()); 151 | } else { 152 | LOGGER.debug("There is no active form"); 153 | } 154 | 155 | if(tracker.hasActiveForm() && this.name().equals(tracker.getActiveForm().getName())) { 156 | return new ArrayList<>(); 157 | } else { 158 | LOGGER.debug("Activated the form '{}'", this.name()); 159 | events.add(new Form(this.name())); 160 | 161 | //collect values of required slots filled before activation 162 | Map preFilledSlots = new HashMap<>(); 163 | List requiredSlots = requiredSlots(tracker); 164 | if(CollectionsUtils.isNotEmpty(requiredSlots)) { 165 | requiredSlots.forEach(slotName -> { 166 | if (tracker.hasSlots() && !shouldRequestSlot(tracker, slotName)) { 
167 | preFilledSlots.put(slotName, tracker.getSlotValue(slotName)); 168 | } 169 | }); 170 | } 171 | 172 | if(!preFilledSlots.isEmpty()) { 173 | LOGGER.debug("Validating pre-filled required slots: {}", preFilledSlots); 174 | events.addAll(validateSlots(preFilledSlots, dispatcher, tracker, domain)); 175 | } else { 176 | LOGGER.debug("No pre-filled required slots to validate."); 177 | } 178 | } 179 | return events; 180 | } 181 | 182 | /** 183 | * Return a list of events from "validate(...)" method if validation is required: 184 | * - the form is active 185 | * - the form is called after "action_listen" 186 | * - form validation was not cancelled 187 | * 188 | * @param dispatcher a {@link CollectingDispatcher} object 189 | * @param tracker a {@link Tracker} object 190 | * @param domain a {@link Domain} object 191 | * @return list of events 192 | */ 193 | List validateIfRequired(CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 194 | if("action_listen".equals(tracker.getLatestActionName()) && (tracker.getActiveForm() == null || tracker.getActiveForm().shouldValidate(true))) { 195 | LOGGER.debug("Validating user input '{}'", tracker.getLatestMessage()); 196 | return validate(dispatcher, tracker, domain); 197 | } 198 | LOGGER.debug("Skipping validation"); 199 | return Collections.emptyList(); 200 | } 201 | 202 | /** 203 | * Extract and validate value of requested slot. 204 | * 205 | * If nothing was extracted reject execution of the form action. 
206 | * Subclass this method to add custom validation and rejection logic 207 | * 208 | * @param dispatcher a {@link CollectingDispatcher} object 209 | * @param tracker a {@link Tracker} object 210 | * @param domain a {@link Domain} object 211 | * @return list of events 212 | */ 213 | List validate(CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 214 | // extract other slots that were not requested 215 | // but set by corresponding entity or trigger intent mapping 216 | 217 | Map slotValues = extractOtherSlots(dispatcher, tracker, domain); 218 | 219 | // extract requested slot 220 | if(tracker.hasSlotValue(REQUESTED_SLOT)) { 221 | Object slotToFill = tracker.getSlotValue(REQUESTED_SLOT); 222 | slotValues.putAll(extractRequestedSlot(dispatcher, tracker, domain)); 223 | 224 | if(CollectionsUtils.isEmpty(slotValues)) { 225 | // reject to execute the form action 226 | // if some slot was requested but nothing was extracted 227 | // it will allow other policies to predict another action 228 | throw new ActionExecutionRejectionException("Failed to extract slot '" + slotToFill + "' with action '" + name() + "'"); 229 | } 230 | } 231 | 232 | LOGGER.debug("Validating extracted slots: {}", slotValues); 233 | return validateSlots(slotValues, dispatcher, tracker, domain); 234 | } 235 | 236 | /** 237 | * Request the next slot and utter template if needed 238 | * 239 | * @param dispatcher a {@link CollectingDispatcher} object 240 | * @param tracker a {@link Tracker} object 241 | * @param domain a {@link Domain} object 242 | * @return an event 243 | */ 244 | private AbstractEvent requestNextSlot(CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 245 | List requiredSlots = requiredSlots(tracker); 246 | if(requiredSlots != null) { 247 | for (String slotName : requiredSlots) { 248 | if(shouldRequestSlot(tracker, slotName)) { 249 | LOGGER.debug("Request next slot '{}'", slotName); 250 | dispatcher.utterTemplate("utter_ask_" + slotName, 
tracker.hasSlots() ? tracker.getSlots() : Collections.emptyMap()); 251 | return new SlotSet(REQUESTED_SLOT, slotName); 252 | } 253 | }; 254 | } 255 | // no more required slots to fill 256 | return null; 257 | } 258 | 259 | /** 260 | * Return "Form" event with null as name to deactivate the form and reset the requested slot 261 | * 262 | * @return list of events 263 | */ 264 | protected List deactivate() { 265 | LOGGER.debug("Deactivating the form '{}'", name()); 266 | return Arrays.asList(new Form(null), new SlotSet(REQUESTED_SLOT, null)); 267 | } 268 | 269 | /** 270 | * Extract entities for given name 271 | * 272 | * @param entityName an entity name 273 | * @return an extracted value of the given entity name 274 | */ 275 | private String getEntityValue(String entityName, Tracker tracker) { 276 | List entityValues = tracker.getLatestEntityValues(entityName); 277 | return CollectionsUtils.isNotEmpty(entityValues) ? entityValues.get(0) : null; 278 | } 279 | 280 | /** 281 | * Extract the values of the other slots if they are set by corresponding entities from the user input 282 | * else return None 283 | * 284 | * @param dispatcher a {@link CollectingDispatcher} object 285 | * @param tracker a {@link Tracker} object 286 | * @param domain a {@link Domain} object 287 | * @return map of the other slots with values 288 | */ 289 | Map extractOtherSlots(CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 290 | Map slotValues = new HashMap<>(); 291 | 292 | List requiredSlots = requiredSlots(tracker); 293 | if(requiredSlots != null && !requiredSlots.isEmpty()) { 294 | // look for other slots 295 | for (String slotName : requiredSlots) { 296 | if(!slotName.equals(tracker.getSlotValue(REQUESTED_SLOT))) { 297 | List otherSlotMappings = getMappingsForSlot(slotName); 298 | for (AbstractSlotMapping otherSlotMapping : otherSlotMappings) { 299 | // check whether the slot should be filled by entity with the same name 300 | 301 | boolean shouldFillEntitySlot = 
otherSlotMapping.isEntitySlotMappingType() && 302 | ((EntitySlotMapping) otherSlotMapping).getEntity().equals(slotName) && 303 | intentIsDesired(otherSlotMapping, tracker); 304 | 305 | // check whether the slot should be filled from trigger intent mapping 306 | boolean shouldFillTriggerSlot = (!tracker.hasActiveForm() || name().equals(tracker.getActiveForm().getName()) == false) && 307 | otherSlotMapping.isTriggerIntentSlotMappingType() && 308 | intentIsDesired(otherSlotMapping, tracker); 309 | 310 | Object value = null; 311 | if (shouldFillEntitySlot) { 312 | value = getEntityValue(slotName, tracker); 313 | } else if (shouldFillTriggerSlot) { 314 | value = ((TriggerIntentSlotMapping) otherSlotMapping).getValue(); 315 | } 316 | 317 | if(value != null) { 318 | LOGGER.debug("Extracted '{}' for extra slot '{}'", value, slotName); 319 | slotValues.put(slotName, value); 320 | break; // this slot is done; keep checking the remaining required slots 321 | } 322 | 323 | } 324 | } 325 | } 326 | } 327 | return slotValues; 328 | } 329 | 330 | /** 331 | * Extract the value of requested slot from a user input 332 | * 333 | * @param dispatcher a {@link CollectingDispatcher} object 334 | * @param tracker a {@link Tracker} object 335 | * @param domain a {@link Domain} object 336 | * @return map of the requested slot with extracted value 337 | */ 338 | protected Map extractRequestedSlot(CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 339 | if(tracker.hasSlotValue(REQUESTED_SLOT) == false) { 340 | return Collections.emptyMap(); 341 | } 342 | 343 | String slotToFill = tracker.getSlotValue(REQUESTED_SLOT, String.class); 344 | LOGGER.debug("Trying to extract requested slot '{}' ...", slotToFill); 345 | 346 | //get mapping for requested slot 347 | List requestedSlotMappings = getMappingsForSlot(slotToFill); 348 | for (AbstractSlotMapping requestedSlotMapping : requestedSlotMappings) { 349 | LOGGER.debug("Got mapping '{}'", requestedSlotMapping); 350 | if(intentIsDesired(requestedSlotMapping, tracker)) { 351 | Object
value = null; 352 | switch(requestedSlotMapping.getType()) { 353 | case ENTITY: 354 | value = getEntityValue(EntitySlotMapping.class.cast(requestedSlotMapping).getEntity(), tracker); 355 | break; 356 | 357 | case INTENT: 358 | value = IntentSlotMapping.class.cast(requestedSlotMapping).getValue(); 359 | break; 360 | 361 | case TEXT: 362 | value = tracker.getLatestMessage().getText(); 363 | break; 364 | case TRIGGER_INTENT: break; 365 | default: 366 | throw new RasaException("Provided slot mapping type ('" + requestedSlotMapping.getType() + "') is not supported"); 367 | } 368 | 369 | if(value != null) { 370 | LOGGER.debug("Successfully extracted '{}' for requested slot '{}'", value, slotToFill); 371 | Map resultMap = new HashMap<>(); 372 | resultMap.put(slotToFill, value); 373 | return resultMap; 374 | } 375 | } 376 | } 377 | LOGGER.debug("Failed to extract requested slot '{}'", slotToFill); 378 | return Collections.emptyMap(); 379 | } 380 | 381 | /** 382 | * Get mappings for requested slot. 383 | * 384 | * If no mapping is configured, map requested slot to an entity with the same name 385 | * 386 | * @param slotToFill a slot which should be filled 387 | * @return list of slot mappings 388 | */ 389 | private List getMappingsForSlot(String slotToFill) { 390 | List requestedSlotMappings = slotMappings().getOrDefault(slotToFill, Arrays.asList(EntitySlotMapping.builder(slotToFill).build())); 391 | 392 | // check provided slot mappings 393 | requestedSlotMappings.forEach(requestedSlotMapping -> { 394 | if(requestedSlotMapping.getType() == null) { 395 | throw new RasaException("Provided incompatible slot mapping"); 396 | } 397 | }); 398 | 399 | return requestedSlotMappings; 400 | } 401 | 402 | /** 403 | * Validate slots using helper validation functions. 404 | * 405 | * Call particular validator for each slot, value pair to be validated. 406 | * If the particular validator is not implemented, set the slot to the value.
407 | * 408 | * @param slotsMap map of slots with values 409 | * @param dispatcher a {@link CollectingDispatcher} object 410 | * @param tracker a {@link Tracker} object 411 | * @param domain a {@link Domain} object 412 | * @return list of event 413 | */ 414 | private List validateSlots(Map slotsMap, CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 415 | Map validationOutput = new HashMap<>(); 416 | slotsMap.forEach((slotName, slotValue) -> { 417 | if(this.slotValidatorMap.containsKey(slotName)) { 418 | Map result = this.slotValidatorMap.get(slotName).validateAndConvert(slotValue, dispatcher, tracker, domain); 419 | validationOutput.putAll(result); 420 | } 421 | }); 422 | 423 | slotsMap.putAll(validationOutput); 424 | 425 | // validation succeed, set slots to extracted values 426 | List events = new ArrayList<>(); 427 | slotsMap.forEach((slotName, slotValue) -> { 428 | events.add(new SlotSet(slotName, slotValue)); 429 | }); 430 | return events; 431 | } 432 | 433 | /** 434 | * Check whether form action should request given slot 435 | * 436 | * @param tracker a {@link Tracker} object 437 | * @param slotName a slot name 438 | * @return true - if the given slot should be requested. Otherwise - false 439 | */ 440 | private boolean shouldRequestSlot(Tracker tracker, String slotName) { 441 | return tracker.hasSlotValue(slotName) == false; 442 | } 443 | 444 | @Override 445 | public String name() { 446 | return this.formName; 447 | } 448 | 449 | /** 450 | * Execute the side effects of this form. 
451 | * 452 | * Steps: 453 | * - activate if needed 454 | * - validate user input if needed 455 | * - set validated slots 456 | * - utter_ask_{slot} template with the next required slot 457 | * - submit the form if all required slots are set 458 | * - deactivate the form 459 | * @return list of events 460 | */ 461 | @Override 462 | public List run(CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 463 | // activate the form 464 | List events = activateFormIfRequired(dispatcher, tracker, domain); 465 | 466 | // validate user input 467 | events.addAll(validateIfRequired(dispatcher, tracker, domain)); 468 | 469 | // check that the form wasn't deactivated in validation 470 | if(events.stream().filter(event -> event instanceof Form).noneMatch(event -> Form.class.cast(event).isNotActive())) { 471 | // create temp tracker with populated slots from `validate` method 472 | 473 | // deep clone using JSON serialization 474 | Tracker tempTracker = SerializationUtils.deepClone(tracker); //JsonParser.parse(JsonParser.jsonAsString(tracker), Tracker.class); 475 | events.forEach(event -> { 476 | if(SlotSet.class.isInstance(event)) { 477 | SlotSet slotSetEvent = SlotSet.class.cast(event); 478 | tempTracker.addSlot(slotSetEvent.getName(), slotSetEvent.getValue()); 479 | } 480 | }); 481 | 482 | AbstractEvent nextSlotEvent = requestNextSlot(dispatcher, tempTracker, domain); 483 | if(nextSlotEvent != null) { 484 | // request next slot 485 | events.add(nextSlotEvent); 486 | } else { 487 | // there is nothing more to request, so we can submit 488 | logFormSlots(tempTracker); 489 | LOGGER.debug("Submitting the form '{}'", name()); 490 | List submitEvents = submit(dispatcher); 491 | if(CollectionsUtils.isNotEmpty(submitEvents)) { 492 | events.addAll(submitEvents); 493 | } 494 | // deactivate the form after submission 495 | events.addAll(deactivate()); 496 | } 497 | } 498 | return events; 499 | } 500 | 501 | /** 502 | * A validator slot interface 503 | * 504 | * @author Rafał 
Bajek 505 | */ 506 | public interface ValidateSlot { 507 | 508 | /** 509 | * Validate slot value and set a new value if required 510 | * 511 | * @param value current slot value 512 | * @param dispatcher dispatcher object 513 | * @param tracker tracker object 514 | * @param domain domain object 515 | * @return Map of slots value, where: key=slotName, value=slotValue 516 | */ 517 | Map validateAndConvert(Object value, CollectingDispatcher dispatcher, Tracker tracker, Domain domain); 518 | } 519 | } 520 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/action/form/slot/mapper/AbstractSlotMapping.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.action.form.slot.mapper; 2 | 3 | import io.github.rbajek.rasa.sdk.util.CollectionsUtils; 4 | import io.github.rbajek.rasa.sdk.exception.RasaException; 5 | import lombok.Getter; 6 | import lombok.ToString; 7 | 8 | import java.util.ArrayList; 9 | import java.util.Collections; 10 | import java.util.List; 11 | 12 | @Getter 13 | @ToString 14 | public abstract class AbstractSlotMapping { 15 | 16 | //================================================= 17 | // Class fields 18 | //================================================= 19 | 20 | protected final SlotMappingType type; 21 | protected List intent; 22 | protected List notIntent; 23 | 24 | //================================================= 25 | // Constructors 26 | //================================================= 27 | 28 | public AbstractSlotMapping(SlotMappingType type) { 29 | this.type = type; 30 | } 31 | 32 | //================================================= 33 | // Class methods 34 | //================================================= 35 | 36 | public boolean isEntitySlotMappingType() { 37 | return SlotMappingType.ENTITY == type; 38 | } 39 | 40 | public boolean isTriggerIntentSlotMappingType() { 41 | return 
SlotMappingType.TRIGGER_INTENT == type; 42 | } 43 | 44 | //================================================= 45 | // Builder 46 | //================================================= 47 | 48 | public abstract static class AbstractBuilder { 49 | protected final T instance; 50 | 51 | public AbstractBuilder(T instance) { 52 | this.instance = instance; 53 | } 54 | 55 | public B intent(String intent) { 56 | if(this.instance.intent == null) { 57 | this.instance.intent = new ArrayList<>(); 58 | } 59 | this.instance.intent.add(intent); 60 | return (B) this; 61 | } 62 | 63 | public B intent(List intents) { 64 | if(this.instance.intent == null) { 65 | this.instance.intent = new ArrayList<>(); 66 | } 67 | this.instance.intent.addAll(intents); 68 | return (B) this; 69 | } 70 | 71 | public B notIntent(String notIntent) { 72 | if(this.instance.notIntent == null) { 73 | this.instance.notIntent = new ArrayList<>(); 74 | } 75 | this.instance.notIntent.add(notIntent); 76 | return (B) this; 77 | } 78 | 79 | public B notIntent(List notIntents) { 80 | if(this.instance.notIntent == null) { 81 | this.instance.notIntent = new ArrayList<>(); 82 | } 83 | this.instance.notIntent.addAll(notIntents); 84 | return (B) this; 85 | } 86 | 87 | public T build() { 88 | if(CollectionsUtils.isNotEmpty(this.instance.intent) && CollectionsUtils.isNotEmpty(this.instance.notIntent)) { 89 | throw new RasaException("Providing both intent '" + this.instance.intent + "' and notIntent '" + this.instance.notIntent + "' is not supported"); 90 | } 91 | 92 | if(this.instance.intent == null) { 93 | this.instance.intent = Collections.emptyList(); 94 | } 95 | if(this.instance.notIntent == null) { 96 | this.instance.notIntent = Collections.emptyList(); 97 | } 98 | return this.instance; 99 | } 100 | } 101 | 102 | //================================================= 103 | // Inner Types 104 | //================================================= 105 | 106 | @Getter 107 | public enum SlotMappingType { 108 | ENTITY 
("from_entity"), 109 | INTENT ("from_intent"), 110 | TEXT ("from_text"), 111 | TRIGGER_INTENT ("from_trigger_intent"); 112 | 113 | private final String value; 114 | 115 | SlotMappingType(String value) { 116 | this.value = value; 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/action/form/slot/mapper/EntitySlotMapping.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.action.form.slot.mapper; 2 | 3 | import lombok.Getter; 4 | import lombok.ToString; 5 | 6 | @Getter 7 | @ToString(callSuper = true) 8 | public class EntitySlotMapping extends AbstractSlotMapping { 9 | 10 | private final String entity; 11 | 12 | private EntitySlotMapping(String entity) { 13 | super(SlotMappingType.ENTITY); 14 | this.entity = entity; 15 | } 16 | 17 | public static Builder builder(String entity) { 18 | return new Builder(entity); 19 | } 20 | 21 | public static class Builder extends AbstractSlotMapping.AbstractBuilder { 22 | 23 | public Builder(String entity) { 24 | super(new EntitySlotMapping(entity)); 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/action/form/slot/mapper/IntentSlotMapping.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.action.form.slot.mapper; 2 | 3 | import io.github.rbajek.rasa.sdk.exception.RasaException; 4 | import lombok.Getter; 5 | import lombok.ToString; 6 | 7 | @Getter 8 | @ToString(callSuper = true) 9 | public class IntentSlotMapping extends AbstractSlotMapping { 10 | 11 | protected T value; 12 | 13 | public IntentSlotMapping() { 14 | this(SlotMappingType.INTENT); 15 | } 16 | 17 | public IntentSlotMapping(SlotMappingType type) { 18 | super(type); 19 | if(SlotMappingType.INTENT != type && 
SlotMappingType.TRIGGER_INTENT != type) { 20 | throw new RasaException("Slot mapping type should be one of the: " + SlotMappingType.INTENT.getValue() + " or " + SlotMappingType.TRIGGER_INTENT + " but is: " + type); 21 | } 22 | } 23 | 24 | public static Builder builder() { 25 | return new Builder(new IntentSlotMapping()); 26 | } 27 | 28 | public static class Builder extends AbstractSlotMapping.AbstractBuilder, Builder> { 29 | 30 | public Builder(IntentSlotMapping intentSlotMapping) { 31 | super(intentSlotMapping); 32 | } 33 | 34 | public Builder value(T value) { 35 | this.instance.value = value; 36 | return this; 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/action/form/slot/mapper/TextSlotMapping.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.action.form.slot.mapper; 2 | 3 | import lombok.ToString; 4 | 5 | @ToString(callSuper = true) 6 | public class TextSlotMapping extends AbstractSlotMapping { 7 | 8 | public TextSlotMapping() { 9 | super(SlotMappingType.TEXT); 10 | } 11 | 12 | public static Builder builder() { 13 | return new Builder(); 14 | } 15 | 16 | public static class Builder extends AbstractSlotMapping.AbstractBuilder { 17 | 18 | public Builder() { 19 | super(new TextSlotMapping()); 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/action/form/slot/mapper/TriggerIntentSlotMapping.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.action.form.slot.mapper; 2 | 3 | import lombok.ToString; 4 | 5 | @ToString(callSuper = true) 6 | public class TriggerIntentSlotMapping extends IntentSlotMapping { 7 | 8 | public TriggerIntentSlotMapping() { 9 | super(SlotMappingType.TRIGGER_INTENT); 10 | } 11 | 12 | public static 
Builder builder() { 13 | return new Builder(new TriggerIntentSlotMapping()); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/ActionRequest.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.Getter; 5 | import lombok.Setter; 6 | import lombok.ToString; 7 | 8 | @Getter @Setter @ToString 9 | public class ActionRequest { 10 | 11 | @JsonProperty("next_action") 12 | private String nextAction; 13 | 14 | @JsonProperty("sender_id") 15 | private String senderId; 16 | 17 | private Tracker tracker; 18 | 19 | private Domain domain; 20 | 21 | private String version; 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/ActionResponse.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto; 2 | 3 | import io.github.rbajek.rasa.sdk.dto.event.AbstractEvent; 4 | import lombok.Getter; 5 | import lombok.Setter; 6 | import lombok.ToString; 7 | 8 | import java.util.List; 9 | import java.util.Map; 10 | 11 | @Getter @Setter @ToString 12 | public class ActionResponse { 13 | 14 | private List events; 15 | private List> responses; 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/Domain.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.*; 5 | 6 | import java.util.List; 7 | import java.util.Map; 8 | 9 | @Getter @Setter @ToString 10 | public class Domain { 11 | 12 | private Config config; 13 | private List> intents; 14 | private List entities; 
15 | private Map slots; 16 | private Map> templates; 17 | private List actions; 18 | 19 | @Getter @Setter @ToString 20 | public static class Intent { 21 | @JsonProperty("use_entities") 22 | private Object use_entities; 23 | } 24 | 25 | @Getter @Setter @ToString 26 | public static class Slot { 27 | @JsonProperty("auto_fill") 28 | private Boolean autoFill; 29 | 30 | @JsonProperty("initial_value") 31 | private String initialValue; 32 | 33 | private String type; 34 | 35 | private List values; 36 | } 37 | 38 | @Getter @Setter @ToString 39 | public static class Template { 40 | private String image; 41 | private String text; 42 | } 43 | 44 | @Getter @Setter @ToString 45 | public static class Config { 46 | @JsonProperty("store_entities_as_slots") 47 | private Boolean storeEntitiesAsSlots; 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/Tracker.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto; 2 | 3 | import io.github.rbajek.rasa.sdk.util.StringUtils; 4 | import com.fasterxml.jackson.annotation.JsonProperty; 5 | import lombok.Getter; 6 | import lombok.Setter; 7 | import lombok.ToString; 8 | 9 | import java.util.Collections; 10 | import java.util.HashMap; 11 | import java.util.List; 12 | import java.util.Map; 13 | import java.util.stream.Collectors; 14 | 15 | @Getter @Setter @ToString 16 | public class Tracker { 17 | 18 | //================================================= 19 | // Class fields 20 | //================================================= 21 | 22 | @JsonProperty("conversation_id") 23 | private String conversationId; 24 | 25 | private Map slots; 26 | 27 | @JsonProperty("sender_id") 28 | private String senderId; 29 | 30 | @JsonProperty("latest_message") 31 | private Message latestMessage; 32 | 33 | @JsonProperty("followup_action") 34 | private String followupAction; 35 | 36 | private Boolean 
paused; 37 | 38 | private List events; 39 | 40 | @JsonProperty("latest_input_channel") 41 | private String latestInputChannel; 42 | 43 | @JsonProperty("latest_action_name") 44 | private String latestActionName; 45 | 46 | @JsonProperty("active_form") 47 | private Form activeForm; 48 | 49 | //================================================= 50 | // Class methods 51 | //================================================= 52 | 53 | public boolean hasActiveForm() { 54 | return this.activeForm != null && StringUtils.isNotNullOrEmpty(this.activeForm.name); 55 | } 56 | 57 | /** 58 | * Get entity values found for the passed entity name in latest msg. 59 | * 60 | * @param entityName an entity name 61 | * @return list of values of the given entity name (empty if there is no latest message or no matching entity) 62 | */ 63 | public List getLatestEntityValues(String entityName) { 64 | if(latestMessage == null || latestMessage.entities == null) { 65 | return Collections.emptyList(); 66 | } 67 | 68 | return latestMessage.entities.stream() 69 | .filter(entity -> entityName.equals(entity.getEntity())) 70 | .map(Entity::getValue) 71 | .collect(Collectors.toList()); 72 | } 73 | 74 | public void addSlot(String slotName, Object value) { 75 | if(this.slots == null) { 76 | this.slots = new HashMap<>(); 77 | } 78 | this.slots.put(slotName, value); 79 | } 80 | 81 | public Map getSlots() { 82 | return this.slots != null ? this.slots : Collections.emptyMap(); 83 | } 84 | 85 | public boolean hasSlots() { 86 | return this.slots != null ?
this.slots.isEmpty() == false : false; 87 | } 88 | 89 | public Object getSlotValue(String slotName) { 90 | return getSlotValue(slotName, Object.class); 91 | } 92 | 93 | public T getSlotValue(String slotName, Class type) { 94 | if(this.slots != null && this.slots.containsKey(slotName)) { 95 | return type.cast(this.slots.get(slotName)); 96 | } 97 | return null; 98 | } 99 | 100 | public boolean hasSlotValue(String slotName) { 101 | if(this.slots != null && this.slots.containsKey(slotName)) { 102 | return this.slots.get(slotName) != null; 103 | } 104 | return false; 105 | } 106 | 107 | //================================================= 108 | // Inner Types 109 | //================================================= 110 | 111 | @Getter @Setter @ToString 112 | public static class Message { 113 | 114 | private List entities; 115 | 116 | private Intent intent; 117 | 118 | @JsonProperty("intent_ranking") 119 | private List intentRanking; 120 | 121 | private String text; 122 | } 123 | 124 | @Getter @Setter @ToString 125 | public static class Event { 126 | private String event; 127 | private Long timestamp; 128 | } 129 | 130 | @Getter @Setter @ToString 131 | public static class Form { 132 | private String name; 133 | private Boolean validate; 134 | private Boolean rejected; 135 | @JsonProperty("trigger_message") 136 | private Message triggerMessage; 137 | 138 | public boolean shouldValidate(Boolean defaultValue) { 139 | return validate != null ? 
validate : defaultValue; 140 | } 141 | } 142 | 143 | @Getter @Setter @ToString 144 | public static class Intent { 145 | private Double confidence; 146 | private String name; 147 | } 148 | 149 | @Getter @Setter @ToString 150 | public static class Entity { 151 | private Integer start; 152 | private Integer end; 153 | private String value; 154 | private String entity; 155 | private Double confidence; 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/AbstractEvent.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import lombok.EqualsAndHashCode; 4 | import lombok.Getter; 5 | 6 | import java.sql.Timestamp; 7 | 8 | @Getter 9 | @EqualsAndHashCode 10 | public abstract class AbstractEvent { 11 | protected final String event; 12 | protected final Long timestamp; 13 | 14 | public AbstractEvent(String event, Timestamp timestamp) { 15 | this.event = event; 16 | this.timestamp = timestamp != null ? 
timestamp.getTime() : null; 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/ActionExecuted.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import java.sql.Timestamp; 4 | 5 | public class ActionExecuted extends AbstractEvent { 6 | 7 | //----------------------------------------------- 8 | // Fields 9 | //----------------------------------------------- 10 | 11 | private String name; 12 | private String policy; 13 | private Float confidence; 14 | 15 | //----------------------------------------------- 16 | // Constructors 17 | //----------------------------------------------- 18 | 19 | public ActionExecuted() { 20 | this(null); 21 | } 22 | 23 | public ActionExecuted(Timestamp timestamp) { 24 | super("action", timestamp); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/ActionExecutionRejected.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import lombok.Getter; 4 | 5 | import java.sql.Timestamp; 6 | 7 | @Getter 8 | public class ActionExecutionRejected extends AbstractEvent { 9 | 10 | //----------------------------------------------- 11 | // Fields 12 | //----------------------------------------------- 13 | 14 | private String name; 15 | private String policy; 16 | private Float confidence; 17 | 18 | //----------------------------------------------- 19 | // Constructors 20 | //----------------------------------------------- 21 | 22 | public ActionExecutionRejected() { 23 | this(null); 24 | } 25 | 26 | public ActionExecutionRejected(Timestamp timestamp) { 27 | super("action_execution_rejected", timestamp); 28 | } 29 | } 30 |
-------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/ActionReverted.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import java.sql.Timestamp; 4 | 5 | public class ActionReverted extends AbstractEvent { 6 | 7 | //----------------------------------------------- 8 | // Constructors 9 | //----------------------------------------------- 10 | 11 | public ActionReverted() { 12 | this(null); 13 | } 14 | 15 | public ActionReverted(Timestamp timestamp) { 16 | super("undo", timestamp); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/AgentUttered.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import lombok.Getter; 4 | 5 | import java.sql.Timestamp; 6 | 7 | @Getter 8 | public class AgentUttered extends AbstractEvent { 9 | 10 | //----------------------------------------------- 11 | // Fields 12 | //----------------------------------------------- 13 | 14 | private String text; 15 | private Object data; 16 | 17 | //----------------------------------------------- 18 | // Constructors 19 | //----------------------------------------------- 20 | 21 | public AgentUttered() { 22 | this(null); 23 | } 24 | 25 | public AgentUttered(Timestamp timestamp) { 26 | super("agent", timestamp); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/AllSlotsReset.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import java.sql.Timestamp; 4 | 5 | public class AllSlotsReset extends AbstractEvent { 6 | 7 | 
//----------------------------------------------- 8 | // Constructors 9 | //----------------------------------------------- 10 | 11 | public AllSlotsReset() { 12 | this(null); 13 | } 14 | 15 | public AllSlotsReset(Timestamp timestamp) { 16 | super("reset_slots", timestamp); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/BotUttered.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import lombok.Getter; 4 | 5 | import java.sql.Timestamp; 6 | 7 | /** 8 | * Not used for now 9 | */ 10 | @Getter 11 | public class BotUttered extends AbstractEvent { 12 | 13 | //----------------------------------------------- 14 | // Fields 15 | //----------------------------------------------- 16 | 17 | private String text; 18 | private Object data; 19 | private Metadata metadata; 20 | 21 | //----------------------------------------------- 22 | // Constructors 23 | //----------------------------------------------- 24 | 25 | public BotUttered() { 26 | this(null); 27 | } 28 | 29 | public BotUttered(Timestamp timestamp) { 30 | super("bot", timestamp); 31 | } 32 | 33 | //----------------------------------------------- 34 | // Inner types 35 | //----------------------------------------------- 36 | 37 | public static class Metadata { 38 | 39 | } 40 | 41 | //----------------------------------------------- 42 | // Getters/Setters 43 | //----------------------------------------------- 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/ConversationPaused.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import java.sql.Timestamp; 4 | 5 | public class ConversationPaused extends AbstractEvent { 6 | 7 | 
//----------------------------------------------- 8 | // Constructors 9 | //----------------------------------------------- 10 | 11 | public ConversationPaused() { 12 | this(null); 13 | } 14 | 15 | public ConversationPaused(Timestamp timestamp) { 16 | super("pause", timestamp); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/ConversationResumed.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import java.sql.Timestamp; 4 | 5 | public class ConversationResumed extends AbstractEvent { 6 | 7 | //----------------------------------------------- 8 | // Constructors 9 | //----------------------------------------------- 10 | 11 | public ConversationResumed() { 12 | this(null); 13 | } 14 | 15 | public ConversationResumed(Timestamp timestamp) { 16 | super("resume", timestamp); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/FollowupAction.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import lombok.Getter; 4 | 5 | import java.sql.Timestamp; 6 | 7 | @Getter 8 | public class FollowupAction extends AbstractEvent { 9 | 10 | //----------------------------------------------- 11 | // Fields 12 | //----------------------------------------------- 13 | 14 | private String name; 15 | 16 | //----------------------------------------------- 17 | // Constructors 18 | //----------------------------------------------- 19 | 20 | public FollowupAction() { 21 | this(null); 22 | } 23 | 24 | public FollowupAction(Timestamp timestamp) { 25 | super("followup", timestamp); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- 
/src/main/java/io/github/rbajek/rasa/sdk/dto/event/Form.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import io.github.rbajek.rasa.sdk.util.StringUtils; 4 | import lombok.EqualsAndHashCode; 5 | import lombok.Getter; 6 | import lombok.ToString; 7 | 8 | import java.beans.Transient; 9 | import java.sql.Timestamp; 10 | 11 | @Getter 12 | @EqualsAndHashCode(callSuper = true) 13 | @ToString 14 | public class Form extends AbstractEvent { 15 | 16 | //----------------------------------------------- 17 | // Fields 18 | //----------------------------------------------- 19 | 20 | private final String name; 21 | 22 | //----------------------------------------------- 23 | // Constructors 24 | //----------------------------------------------- 25 | 26 | public Form(String name) { 27 | this(name, null); 28 | } 29 | 30 | public Form(String name, Timestamp timestamp) { 31 | super("form", timestamp); 32 | this.name = name; 33 | } 34 | 35 | //----------------------------------------------- 36 | // Class methods 37 | //----------------------------------------------- 38 | 39 | @Transient 40 | public boolean isActive() { 41 | return StringUtils.isNotNullOrEmpty(this.name); 42 | } 43 | 44 | @Transient 45 | public boolean isNotActive() { 46 | return !isActive(); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/FormValidation.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import lombok.Getter; 4 | 5 | import java.sql.Timestamp; 6 | 7 | @Getter 8 | public class FormValidation extends AbstractEvent { 9 | 10 | //----------------------------------------------- 11 | // Fields 12 | //----------------------------------------------- 13 | 14 | private Object validate; 15 | 16 | 
//----------------------------------------------- 17 | // Constructors 18 | //----------------------------------------------- 19 | 20 | public FormValidation() { 21 | this(null); 22 | } 23 | 24 | public FormValidation(Timestamp timestamp) { 25 | super("form_validation", timestamp); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/ReminderCancelled.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import lombok.Getter; 4 | 5 | import java.sql.Timestamp; 6 | 7 | @Getter 8 | public class ReminderCancelled extends AbstractEvent { 9 | 10 | //----------------------------------------------- 11 | // Fields 12 | //----------------------------------------------- 13 | 14 | private String action; 15 | private String name; 16 | 17 | //----------------------------------------------- 18 | // Constructors 19 | //----------------------------------------------- 20 | 21 | public ReminderCancelled() { 22 | this(null); 23 | } 24 | 25 | public ReminderCancelled(Timestamp timestamp) { 26 | super("cancel_reminder", timestamp); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/ReminderScheduled.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import com.fasterxml.jackson.annotation.JsonProperty; 4 | import lombok.Getter; 5 | 6 | import java.sql.Timestamp; 7 | import java.time.LocalDate; 8 | 9 | @Getter 10 | public class ReminderScheduled extends AbstractEvent { 11 | 12 | //----------------------------------------------- 13 | // Fields 14 | //----------------------------------------------- 15 | 16 | private String action; 17 | @JsonProperty("date_time") 18 | private LocalDate date; 19 | private String 
name; 20 | @JsonProperty("kill_on_user_msg") 21 | private String killOnUserMsg; 22 | 23 | //----------------------------------------------- 24 | // Constructors 25 | //----------------------------------------------- 26 | 27 | public ReminderScheduled() { 28 | this(null); 29 | } 30 | 31 | public ReminderScheduled(Timestamp timestamp) { 32 | super("reminder", timestamp); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/Restarted.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import java.sql.Timestamp; 4 | 5 | public class Restarted extends AbstractEvent { 6 | 7 | //----------------------------------------------- 8 | // Constructors 9 | //----------------------------------------------- 10 | 11 | public Restarted() { 12 | this(null); 13 | } 14 | 15 | public Restarted(Timestamp timestamp) { 16 | super("restart", timestamp); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/SlotSet.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import lombok.EqualsAndHashCode; 4 | import lombok.Getter; 5 | import lombok.ToString; 6 | 7 | import java.sql.Timestamp; 8 | 9 | @Getter 10 | @EqualsAndHashCode(callSuper = true) 11 | @ToString 12 | public class SlotSet extends AbstractEvent { 13 | 14 | //----------------------------------------------- 15 | // Fields 16 | //----------------------------------------------- 17 | 18 | private final String name; 19 | private final Object value; 20 | 21 | //----------------------------------------------- 22 | // Constructors 23 | //----------------------------------------------- 24 | 25 | public SlotSet(String name, Object value) { 26 | this(name, value, null); 
27 | } 28 | 29 | public SlotSet(String name, Object value, Timestamp timestamp) { 30 | super("slot", timestamp); 31 | this.name = name; 32 | this.value = value; 33 | } 34 | 35 | //----------------------------------------------- 36 | // Getters/Setters 37 | //----------------------------------------------- 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/StoryExported.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import java.sql.Timestamp; 4 | 5 | public class StoryExported extends AbstractEvent { 6 | 7 | //----------------------------------------------- 8 | // Constructors 9 | //----------------------------------------------- 10 | 11 | public StoryExported() { 12 | this(null); 13 | } 14 | 15 | public StoryExported(Timestamp timestamp) { 16 | super("export", timestamp); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/UserUtteranceReverted.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import java.sql.Timestamp; 4 | 5 | public class UserUtteranceReverted extends AbstractEvent { 6 | 7 | //----------------------------------------------- 8 | // Constructors 9 | //----------------------------------------------- 10 | 11 | public UserUtteranceReverted() { 12 | this(null); 13 | } 14 | 15 | public UserUtteranceReverted(Timestamp timestamp) { 16 | super("rewind", timestamp); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/io/github/rbajek/rasa/sdk/dto/event/UserUttered.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.dto.event; 2 | 3 | import 
/**
 * Specialisation of {@link RasaException} that only forwards a message to its
 * superclass; it adds no state of its own.
 *
 * <p>NOTE(review): judging by the name it signals that the execution of an
 * action was rejected -- confirm against the code that throws it.
 *
 * @author Rafał Bajek
 */
public class ActionExecutionRejectionException extends RasaException {

    /**
     * Creates the exception with a detail message.
     *
     * @param message detail message passed on to {@link RasaException}
     */
    public ActionExecutionRejectionException(String message) {
        super(message);
    }
}
/**
 * Null-safe emptiness checks for {@link Collection}s and {@link Map}s.
 *
 * <p>All methods are static and treat a {@code null} argument the same way as
 * an empty container.
 */
public class CollectionsUtils {

    /**
     * Checks whether a collection has no elements.
     *
     * @param collection collection to inspect, may be {@code null}
     * @return {@code true} if {@code collection} is {@code null} or contains no elements
     */
    public static boolean isEmpty(Collection<?> collection) {
        return collection == null || collection.isEmpty();
    }

    /**
     * Checks whether a collection has at least one element.
     *
     * @param collection collection to inspect, may be {@code null}
     * @return {@code true} if {@code collection} is non-null and contains at least one element
     */
    public static boolean isNotEmpty(Collection<?> collection) {
        return !isEmpty(collection);
    }

    /**
     * Checks whether a map has no entries.
     *
     * @param map map to inspect, may be {@code null}
     * @return {@code true} if {@code map} is {@code null} or contains no entries
     */
    public static boolean isEmpty(Map<?, ?> map) {
        return map == null || map.isEmpty();
    }

    /**
     * Checks whether a map has at least one entry.
     *
     * @param map map to inspect, may be {@code null}
     * @return {@code true} if {@code map} is non-null and contains at least one entry
     */
    public static boolean isNotEmpty(Map<?, ?> map) {
        return !isEmpty(map);
    }
}
/**
 * Null-safe emptiness checks for strings. Whitespace is not trimmed, so a
 * string made up of blanks only counts as non-empty.
 */
public class StringUtils {

    /**
     * Checks whether a string is absent or has length zero.
     *
     * @param str string to inspect, may be {@code null}
     * @return {@code true} if {@code str} is {@code null} or {@code ""}
     */
    public static boolean isNullOrEmpty(String str) {
        if (str == null) {
            return true;
        }
        return str.length() == 0;
    }

    /**
     * Checks whether a string is present and has at least one character.
     *
     * @param str string to inspect, may be {@code null}
     * @return {@code true} if {@code str} is neither {@code null} nor {@code ""}
     */
    public static boolean isNotNullOrEmpty(String str) {
        return str != null && !str.isEmpty();
    }
}
io.github.rbajek.rasa.sdk.dto.event.AbstractEvent; 9 | import io.github.rbajek.rasa.sdk.dto.event.SlotSet; 10 | import org.junit.jupiter.api.Assertions; 11 | import org.junit.jupiter.api.Test; 12 | 13 | import java.util.Arrays; 14 | import java.util.List; 15 | 16 | import static org.junit.jupiter.api.Assertions.assertEquals; 17 | import static org.junit.jupiter.api.Assertions.assertTrue; 18 | 19 | class ActionExecutorTest { 20 | 21 | private static final String ACTION_NAME = "custom_action"; 22 | 23 | @Test 24 | void run() { 25 | ActionExecutor executor = new ActionExecutor(); 26 | executor.registerAction(new CustomAction()); 27 | 28 | ActionRequest actionRequest = new ActionRequest(); 29 | actionRequest.setNextAction(ACTION_NAME); 30 | actionRequest.setVersion(VersionChecker.SUPPORTED_VERSION); 31 | ActionResponse events = executor.run(actionRequest); 32 | assertTrue(events.getEvents().size() == 1); 33 | Assertions.assertEquals(new SlotSet("test", "test"), events.getEvents().get(0)); 34 | } 35 | 36 | private static class CustomAction implements Action { 37 | 38 | @Override 39 | public String name() { 40 | return ACTION_NAME; 41 | } 42 | 43 | public String someCommonFeature() { 44 | return "test"; 45 | } 46 | 47 | @Override 48 | public List run(CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 49 | return Arrays.asList(new SlotSet("test", someCommonFeature())); 50 | } 51 | } 52 | } -------------------------------------------------------------------------------- /src/test/java/io/github/rbajek/rasa/sdk/action/form/AbstractFormActionTest.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.action.form; 2 | 3 | import io.github.rbajek.rasa.sdk.CollectingDispatcher; 4 | import io.github.rbajek.rasa.sdk.action.form.slot.mapper.*; 5 | import io.github.rbajek.rasa.sdk.dto.Domain; 6 | import io.github.rbajek.rasa.sdk.dto.Tracker; 7 | import 
io.github.rbajek.rasa.sdk.dto.event.AbstractEvent; 8 | import io.github.rbajek.rasa.sdk.dto.event.Form; 9 | import io.github.rbajek.rasa.sdk.dto.event.SlotSet; 10 | import io.github.rbajek.rasa.sdk.exception.ActionExecutionRejectionException; 11 | import io.github.rbajek.rasa.sdk.repository.databuilder.tracker.EntityBuilder; 12 | import io.github.rbajek.rasa.sdk.repository.databuilder.tracker.FormBuilder; 13 | import io.github.rbajek.rasa.sdk.repository.databuilder.tracker.MessageBuilder; 14 | import io.github.rbajek.rasa.sdk.repository.databuilder.tracker.TrackerBuilder; 15 | import org.junit.jupiter.api.Assertions; 16 | import org.junit.jupiter.api.Test; 17 | 18 | import java.util.*; 19 | 20 | import static org.junit.jupiter.api.Assertions.assertEquals; 21 | import static org.junit.jupiter.api.Assertions.assertTrue; 22 | 23 | class AbstractFormActionTest { 24 | 25 | /** 26 | * Test default extraction of a slot value from entity with the same name 27 | */ 28 | @Test 29 | void extractRequestedSlotDefault() { 30 | AbstractFormAction form = new AbstractFormAction("defaultFormName") { 31 | @Override 32 | protected List requiredSlots(Tracker tracker) { 33 | return null; 34 | } 35 | 36 | @Override 37 | protected List submit(CollectingDispatcher dispatcher) { 38 | return null; 39 | } 40 | 41 | @Override 42 | protected void registerSlotsValidators(Map slotValidatorMap) { 43 | 44 | } 45 | }; 46 | 47 | Tracker tracker = TrackerBuilder.builder() 48 | .senderId("default") 49 | .addSlot("requested_slot", "some_slot") 50 | .latestMessage(MessageBuilder.builder() 51 | .addEntity(EntityBuilder.builder() 52 | .entity("some_slot") 53 | .value("some_value") 54 | .build()) 55 | .build()) 56 | .paused(false) 57 | .latestActionName("action_listen") 58 | .build(); 59 | 60 | Map slotValues = form.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 61 | 62 | Map expected = new HashMap<>(); 63 | expected.put("some_slot", "some_value"); 64 | 65 | assertEquals(expected, 
slotValues); 66 | } 67 | 68 | /** 69 | * Test extraction of a slot value from entity with the different name and any intent 70 | */ 71 | @Test 72 | void extractRequestedSlotFromEntityNoIntent() { 73 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 74 | @Override 75 | protected List requiredSlots(Tracker tracker) { 76 | return null; 77 | } 78 | 79 | @Override 80 | protected List submit(CollectingDispatcher dispatcher) { 81 | return null; 82 | } 83 | 84 | @Override 85 | protected void registerSlotsValidators(Map slotValidatorMap) { 86 | 87 | } 88 | 89 | @Override 90 | protected Map> slotMappings() { 91 | Map> slotMappingMap = new HashMap<>(); 92 | 93 | slotMappingMap.put("some_slot", Arrays.asList(EntitySlotMapping.builder("some_entity").build())); 94 | 95 | return slotMappingMap; 96 | } 97 | }; 98 | 99 | Tracker tracker = TrackerBuilder.builder() 100 | .senderId("default") 101 | .addSlot("requested_slot", "some_slot") 102 | .latestMessage(MessageBuilder.builder() 103 | .addEntity(EntityBuilder.builder() 104 | .entity("some_entity") 105 | .value("some_value") 106 | .build()) 107 | .build()) 108 | .paused(false) 109 | .latestActionName("action_listen") 110 | .build(); 111 | 112 | Map slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 113 | 114 | Map expected = new HashMap<>(); 115 | expected.put("some_slot", "some_value"); 116 | 117 | assertEquals(expected, slotValues); 118 | } 119 | 120 | /** 121 | * Test extraction of a slot value from entity with the different name and certain intent 122 | */ 123 | @Test 124 | void extractRequestedSlotFromEntityWithIntent() { 125 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 126 | @Override 127 | protected List requiredSlots(Tracker tracker) { 128 | return null; 129 | } 130 | 131 | @Override 132 | protected List submit(CollectingDispatcher dispatcher) { 133 | return null; 134 | } 135 | 136 | @Override 137 | protected void 
registerSlotsValidators(Map slotValidatorMap) { 138 | 139 | } 140 | 141 | @Override 142 | protected Map> slotMappings() { 143 | Map> slotMappingMap = new HashMap<>(); 144 | 145 | slotMappingMap.put("some_slot", Arrays.asList(EntitySlotMapping.builder("some_entity").intent("some_intent").build())); 146 | 147 | return slotMappingMap; 148 | } 149 | }; 150 | 151 | Tracker tracker = TrackerBuilder.builder() 152 | .senderId("default") 153 | .addSlot("requested_slot", "some_slot") 154 | .latestMessage(MessageBuilder.builder() 155 | .addEntity(EntityBuilder.builder() 156 | .entity("some_entity") 157 | .value("some_value") 158 | .build()) 159 | .intent("some_intent", 1.0) 160 | .build()) 161 | .paused(false) 162 | .latestActionName("action_listen") 163 | .build(); 164 | 165 | Map slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 166 | 167 | Map expected = new HashMap<>(); 168 | expected.put("some_slot", "some_value"); 169 | 170 | assertEquals(expected, slotValues); 171 | } 172 | 173 | /** 174 | * Test extraction of a slot value from entity with the different name and certain intent 175 | */ 176 | @Test 177 | void extractRequestedSlotFromEntityWithNotIntent() { 178 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 179 | @Override 180 | protected List requiredSlots(Tracker tracker) { 181 | return null; 182 | } 183 | 184 | @Override 185 | protected List submit(CollectingDispatcher dispatcher) { 186 | return null; 187 | } 188 | 189 | @Override 190 | protected void registerSlotsValidators(Map slotValidatorMap) { 191 | 192 | } 193 | 194 | @Override 195 | protected Map> slotMappings() { 196 | Map> slotMappingMap = new HashMap<>(); 197 | 198 | slotMappingMap.put("some_slot", Arrays.asList(EntitySlotMapping.builder("some_entity").notIntent("some_intent").build())); 199 | 200 | return slotMappingMap; 201 | } 202 | }; 203 | 204 | Tracker tracker = TrackerBuilder.builder() 205 | .senderId("default") 206 | 
.addSlot("requested_slot", "some_slot") 207 | .latestMessage(MessageBuilder.builder() 208 | .addEntity(EntityBuilder.builder() 209 | .entity("some_entity") 210 | .value("some_value") 211 | .build()) 212 | .intent("some_intent", 1.0) 213 | .build()) 214 | .paused(false) 215 | .latestActionName("action_listen") 216 | .build(); 217 | 218 | Map slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 219 | 220 | // check that the value was extracted for correct intent 221 | assertTrue(slotValues.isEmpty()); 222 | 223 | tracker = TrackerBuilder.builder() 224 | .senderId("default") 225 | .addSlot("requested_slot", "some_slot") 226 | .latestMessage(MessageBuilder.builder() 227 | .addEntity(EntityBuilder.builder() 228 | .entity("some_entity") 229 | .value("some_value") 230 | .build()) 231 | .intent("some_other_intent", 1.0) 232 | .build()) 233 | .paused(false) 234 | .latestActionName("action_listen") 235 | .build(); 236 | 237 | slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 238 | 239 | Map expected = new HashMap<>(); 240 | expected.put("some_slot", "some_value"); 241 | 242 | // check that the value was not extracted for incorrect intent 243 | assertEquals(expected, slotValues); 244 | 245 | } 246 | 247 | /** 248 | * Test extraction of a slot value from certain intent 249 | */ 250 | @Test 251 | void extractRequestedSlotFromIntent() { 252 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 253 | @Override 254 | protected List requiredSlots(Tracker tracker) { 255 | return null; 256 | } 257 | 258 | @Override 259 | protected List submit(CollectingDispatcher dispatcher) { 260 | return null; 261 | } 262 | 263 | @Override 264 | protected void registerSlotsValidators(Map slotValidatorMap) { 265 | 266 | } 267 | 268 | @Override 269 | protected Map> slotMappings() { 270 | Map> slotMappingMap = new HashMap<>(); 271 | 272 | slotMappingMap.put("some_slot", 
Arrays.asList(IntentSlotMapping.builder().intent("some_intent").value("some_value").build())); 273 | 274 | return slotMappingMap; 275 | } 276 | }; 277 | 278 | Tracker tracker = TrackerBuilder.builder() 279 | .senderId("default") 280 | .addSlot("requested_slot", "some_slot") 281 | .latestMessage(MessageBuilder.builder() 282 | .intent("some_intent", 1.0) 283 | .build()) 284 | .paused(false) 285 | .latestActionName("action_listen") 286 | .build(); 287 | 288 | Map slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 289 | 290 | Map expected = new HashMap<>(); 291 | expected.put("some_slot", "some_value"); 292 | 293 | // check that the value was not extracted for incorrect intent 294 | assertEquals(expected, slotValues); 295 | } 296 | 297 | /** 298 | * Test extraction of a slot value from certain intent 299 | */ 300 | @Test 301 | void extractRequestedSlotFromNotIntent() { 302 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 303 | @Override 304 | protected List requiredSlots(Tracker tracker) { 305 | return null; 306 | } 307 | 308 | @Override 309 | protected List submit(CollectingDispatcher dispatcher) { 310 | return null; 311 | } 312 | 313 | @Override 314 | protected void registerSlotsValidators(Map slotValidatorMap) { 315 | 316 | } 317 | 318 | @Override 319 | protected Map> slotMappings() { 320 | Map> slotMappingMap = new HashMap<>(); 321 | 322 | slotMappingMap.put("some_slot", Arrays.asList(IntentSlotMapping.builder().notIntent("some_intent").value("some_value").build())); 323 | 324 | return slotMappingMap; 325 | } 326 | }; 327 | 328 | Tracker tracker = TrackerBuilder.builder() 329 | .senderId("default") 330 | .addSlot("requested_slot", "some_slot") 331 | .latestMessage(MessageBuilder.builder() 332 | .intent("some_intent", 1.0) 333 | .build()) 334 | .paused(false) 335 | .latestActionName("action_listen") 336 | .build(); 337 | 338 | Map slotValues = customFormAction.extractRequestedSlot(new 
CollectingDispatcher(), tracker, null); 339 | 340 | // check that the value was extracted for correct intent 341 | assertTrue(slotValues.isEmpty()); 342 | 343 | tracker = TrackerBuilder.builder() 344 | .senderId("default") 345 | .addSlot("requested_slot", "some_slot") 346 | .latestMessage(MessageBuilder.builder() 347 | .intent("some_other_intent", 1.0) 348 | .build()) 349 | .paused(false) 350 | .latestActionName("action_listen") 351 | .build(); 352 | 353 | slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 354 | 355 | Map expected = new HashMap<>(); 356 | expected.put("some_slot", "some_value"); 357 | 358 | // check that the value was not extracted for incorrect intent 359 | assertEquals(expected, slotValues); 360 | } 361 | 362 | /** 363 | * Test extraction of a slot value from text with any intent 364 | */ 365 | @Test 366 | void extractRequestedSlotFromTextNoIntent() { 367 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 368 | @Override 369 | protected List requiredSlots(Tracker tracker) { 370 | return null; 371 | } 372 | 373 | @Override 374 | protected List submit(CollectingDispatcher dispatcher) { 375 | return null; 376 | } 377 | 378 | @Override 379 | protected void registerSlotsValidators(Map slotValidatorMap) { 380 | 381 | } 382 | 383 | @Override 384 | protected Map> slotMappings() { 385 | Map> slotMappingMap = new HashMap<>(); 386 | 387 | slotMappingMap.put("some_slot", Arrays.asList(TextSlotMapping.builder().build())); 388 | 389 | return slotMappingMap; 390 | } 391 | }; 392 | 393 | Tracker tracker = TrackerBuilder.builder() 394 | .senderId("default") 395 | .addSlot("requested_slot", "some_slot") 396 | .latestMessage(MessageBuilder.builder() 397 | .text("some_text") 398 | .build()) 399 | .paused(false) 400 | .latestActionName("action_listen") 401 | .build(); 402 | 403 | Map slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 404 | 405 | Map 
expected = new HashMap<>(); 406 | expected.put("some_slot", "some_text"); 407 | 408 | // check that the value was not extracted for incorrect intent 409 | assertEquals(expected, slotValues); 410 | } 411 | 412 | /** 413 | * Test extraction of a slot value from text with certain intent 414 | */ 415 | @Test 416 | void extractRequestedSlotFromTextWithIntent() { 417 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 418 | @Override 419 | protected List requiredSlots(Tracker tracker) { 420 | return null; 421 | } 422 | 423 | @Override 424 | protected List submit(CollectingDispatcher dispatcher) { 425 | return null; 426 | } 427 | 428 | @Override 429 | protected void registerSlotsValidators(Map slotValidatorMap) { 430 | 431 | } 432 | 433 | @Override 434 | protected Map> slotMappings() { 435 | Map> slotMappingMap = new HashMap<>(); 436 | 437 | slotMappingMap.put("some_slot", Arrays.asList(TextSlotMapping.builder().intent("some_intent").build())); 438 | 439 | return slotMappingMap; 440 | } 441 | }; 442 | 443 | Tracker tracker = TrackerBuilder.builder() 444 | .senderId("default") 445 | .addSlot("requested_slot", "some_slot") 446 | .latestMessage(MessageBuilder.builder() 447 | .text("some_text") 448 | .intent("some_intent", 1.0) 449 | .build()) 450 | .paused(false) 451 | .latestActionName("action_listen") 452 | .build(); 453 | 454 | Map slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 455 | 456 | Map expected = new HashMap<>(); 457 | expected.put("some_slot", "some_text"); 458 | 459 | // check that the value was not extracted for incorrect intent 460 | assertEquals(expected, slotValues); 461 | 462 | tracker = TrackerBuilder.builder() 463 | .senderId("default") 464 | .addSlot("requested_slot", "some_slot") 465 | .latestMessage(MessageBuilder.builder() 466 | .text("some_text") 467 | .intent("some_other_intent", 1.0) 468 | .build()) 469 | .paused(false) 470 | .latestActionName("action_listen") 471 | .build(); 
472 | 473 | slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 474 | // check that the value was not extracted for incorrect intent 475 | assertTrue(slotValues.isEmpty()); 476 | } 477 | 478 | /** 479 | * Test extraction of a slot value from text with certain intent 480 | */ 481 | @Test 482 | void extractRequestedSlotFromTextWithNotIntent() { 483 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 484 | @Override 485 | protected List requiredSlots(Tracker tracker) { 486 | return null; 487 | } 488 | 489 | @Override 490 | protected List submit(CollectingDispatcher dispatcher) { 491 | return null; 492 | } 493 | 494 | @Override 495 | protected void registerSlotsValidators(Map slotValidatorMap) { 496 | 497 | } 498 | 499 | @Override 500 | protected Map> slotMappings() { 501 | Map> slotMappingMap = new HashMap<>(); 502 | 503 | slotMappingMap.put("some_slot", Arrays.asList(TextSlotMapping.builder().notIntent("some_intent").build())); 504 | 505 | return slotMappingMap; 506 | } 507 | }; 508 | 509 | Tracker tracker = TrackerBuilder.builder() 510 | .senderId("default") 511 | .addSlot("requested_slot", "some_slot") 512 | .latestMessage(MessageBuilder.builder() 513 | .text("some_text") 514 | .intent("some_intent", 1.0) 515 | .build()) 516 | .paused(false) 517 | .latestActionName("action_listen") 518 | .build(); 519 | 520 | Map slotValues = customFormAction.extractRequestedSlot(new CollectingDispatcher(), tracker, null); 521 | // check that the value was extracted for correct intent 522 | assertTrue(slotValues.isEmpty()); 523 | 524 | tracker = TrackerBuilder.builder() 525 | .senderId("default") 526 | .addSlot("requested_slot", "some_slot") 527 | .latestMessage(MessageBuilder.builder() 528 | .text("some_text") 529 | .intent("some_other_intent", 1.0) 530 | .build()) 531 | .paused(false) 532 | .latestActionName("action_listen") 533 | .build(); 534 | 535 | slotValues = customFormAction.extractRequestedSlot(new 
CollectingDispatcher(), tracker, null); 536 | 537 | Map expected = new HashMap<>(); 538 | expected.put("some_slot", "some_text"); 539 | 540 | // check that the value was not extracted for incorrect intent 541 | assertEquals(expected, slotValues); 542 | } 543 | 544 | /** 545 | * Test extraction of a slot value from trigger intent 546 | */ 547 | @Test 548 | void extractTriggerSlots() { 549 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 550 | @Override 551 | protected List requiredSlots(Tracker tracker) { 552 | return Arrays.asList("some_slot"); 553 | } 554 | 555 | @Override 556 | protected List submit(CollectingDispatcher dispatcher) { 557 | return null; 558 | } 559 | 560 | @Override 561 | protected void registerSlotsValidators(Map slotValidatorMap) { 562 | 563 | } 564 | 565 | @Override 566 | protected Map> slotMappings() { 567 | Map> slotMappingMap = new HashMap<>(); 568 | 569 | slotMappingMap.put("some_slot", Arrays.asList(TriggerIntentSlotMapping.builder().intent("trigger_intent").value("some_value").build())); 570 | 571 | return slotMappingMap; 572 | } 573 | }; 574 | 575 | Tracker tracker = TrackerBuilder.builder() 576 | .senderId("default") 577 | .latestMessage(MessageBuilder.builder() 578 | .intent("trigger_intent", 1.0) 579 | .build()) 580 | .paused(false) 581 | .latestActionName("action_listen") 582 | .build(); 583 | 584 | Map slotValues = customFormAction.extractOtherSlots(new CollectingDispatcher(), tracker, null); 585 | 586 | Map expected = new HashMap<>(); 587 | expected.put("some_slot", "some_value"); 588 | // check that the value was extracted for correct intent 589 | assertEquals(expected, slotValues); 590 | 591 | tracker = TrackerBuilder.builder() 592 | .senderId("default") 593 | .latestMessage(MessageBuilder.builder() 594 | .intent("other_intent", 1.0) 595 | .build()) 596 | .paused(false) 597 | .latestActionName("action_listen") 598 | .build(); 599 | 600 | slotValues = customFormAction.extractOtherSlots(new 
CollectingDispatcher(), tracker, null); 601 | // check that the value was not extracted for incorrect intent 602 | assertTrue(slotValues.isEmpty()); 603 | 604 | //============================== 605 | // tracker with active form 606 | //============================== 607 | tracker = TrackerBuilder.builder() 608 | .senderId("default") 609 | .latestMessage(MessageBuilder.builder() 610 | .intent("trigger_intent", 1.0) 611 | .build()) 612 | .paused(false) 613 | .activeForm(FormBuilder.builder().name("some_form").validate(true).rejected(false).build()) 614 | .latestActionName("action_listen") 615 | .build(); 616 | 617 | slotValues = customFormAction.extractOtherSlots(new CollectingDispatcher(), tracker, null); 618 | // check that the value was not extracted for correct intent 619 | assertTrue(slotValues.isEmpty()); 620 | } 621 | 622 | /** 623 | * Test extraction of other not requested slots values from entities with the same names 624 | */ 625 | @Test 626 | void extractOtherSlotsNoIntent() { 627 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 628 | @Override 629 | protected List requiredSlots(Tracker tracker) { 630 | return Arrays.asList("some_slot", "some_other_slot"); 631 | } 632 | 633 | @Override 634 | protected List submit(CollectingDispatcher dispatcher) { 635 | return null; 636 | } 637 | 638 | @Override 639 | protected void registerSlotsValidators(Map slotValidatorMap) { 640 | 641 | } 642 | }; 643 | 644 | Tracker tracker = TrackerBuilder.builder() 645 | .senderId("default") 646 | .addSlot("requested_slot", "some_slot") 647 | .latestMessage(MessageBuilder.builder() 648 | .addEntity(EntityBuilder.builder().entity("some_slot").value("some_value").build()) 649 | .build()) 650 | .paused(false) 651 | .latestActionName("action_listen") 652 | .build(); 653 | 654 | Map slotValues = customFormAction.extractOtherSlots(new CollectingDispatcher(), tracker, null); 655 | // check that the value was not extracted for correct intent 656 | 
assertTrue(slotValues.isEmpty()); 657 | 658 | //===================================== 659 | 660 | tracker = TrackerBuilder.builder() 661 | .senderId("default") 662 | .addSlot("requested_slot", "some_slot") 663 | .latestMessage(MessageBuilder.builder() 664 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 665 | .build()) 666 | .paused(false) 667 | .latestActionName("action_listen") 668 | .build(); 669 | 670 | slotValues = customFormAction.extractOtherSlots(new CollectingDispatcher(), tracker, null); 671 | Map expected = new HashMap<>(); 672 | expected.put("some_other_slot", "some_other_value"); 673 | // check that the value was extracted for non requested slot 674 | assertEquals(expected, slotValues); 675 | 676 | //===================================== 677 | tracker = TrackerBuilder.builder() 678 | .senderId("default") 679 | .addSlot("requested_slot", "some_slot") 680 | .latestMessage(MessageBuilder.builder() 681 | .addEntity(EntityBuilder.builder().entity("some_slot").value("some_value").build()) 682 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 683 | .build()) 684 | .paused(false) 685 | .latestActionName("action_listen") 686 | .build(); 687 | 688 | slotValues = customFormAction.extractOtherSlots(new CollectingDispatcher(), tracker, null); 689 | expected = new HashMap<>(); 690 | expected.put("some_other_slot", "some_other_value"); 691 | // check that the value was extracted only for non requested slot 692 | assertEquals(expected, slotValues); 693 | } 694 | 695 | /** 696 | * Test extraction of other not requested slots values from entities with the same names 697 | */ 698 | @Test 699 | void extractOtherSlotsWithIntent() { 700 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 701 | @Override 702 | protected List requiredSlots(Tracker tracker) { 703 | return Arrays.asList("some_slot", "some_other_slot"); 704 | } 705 | 706 | @Override 707 | 
protected List submit(CollectingDispatcher dispatcher) { 708 | return null; 709 | } 710 | 711 | @Override 712 | protected void registerSlotsValidators(Map slotValidatorMap) { 713 | 714 | } 715 | 716 | @Override 717 | protected Map> slotMappings() { 718 | Map> slotMappingMap = new HashMap<>(); 719 | 720 | slotMappingMap.put("some_other_slot", Arrays.asList(EntitySlotMapping.builder("some_other_slot").intent("some_intent").build())); 721 | 722 | return slotMappingMap; 723 | } 724 | }; 725 | 726 | Tracker tracker = TrackerBuilder.builder() 727 | .senderId("default") 728 | .addSlot("requested_slot", "some_slot") 729 | .latestMessage(MessageBuilder.builder() 730 | .intent("some_other_intent", 1.0) 731 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 732 | .build()) 733 | .paused(false) 734 | .latestActionName("action_listen") 735 | .build(); 736 | 737 | Map slotValues = customFormAction.extractOtherSlots(new CollectingDispatcher(), tracker, null); 738 | // check that the value was not extracted for incorrect intent (some_other_intent instead of some_intent) 739 | assertTrue(slotValues.isEmpty()); 740 | 741 | //========================================= 742 | tracker = TrackerBuilder.builder() 743 | .senderId("default") 744 | .addSlot("requested_slot", "some_slot") 745 | .latestMessage(MessageBuilder.builder() 746 | .intent("some_intent", 1.0) 747 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 748 | .build()) 749 | .paused(false) 750 | .latestActionName("action_listen") 751 | .build(); 752 | 753 | slotValues = customFormAction.extractOtherSlots(new CollectingDispatcher(), tracker, null); 754 | 755 | Map expected = new HashMap<>(); 756 | expected.put("some_other_slot", "some_other_value"); 757 | // check that the value was extracted only for non requested slot 758 | assertEquals(expected, slotValues); 759 | } 760 | 761 | /** 762 | * Test form validation 763 | */ 764 | @Test 765 | void validate() { 766 | 
AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 767 | @Override 768 | protected List requiredSlots(Tracker tracker) { 769 | return Arrays.asList("some_slot", "some_other_slot"); 770 | } 771 | 772 | @Override 773 | protected List submit(CollectingDispatcher dispatcher) { 774 | return null; 775 | } 776 | 777 | @Override 778 | protected void registerSlotsValidators(Map slotValidatorMap) { 779 | 780 | } 781 | }; 782 | 783 | Tracker tracker = TrackerBuilder.builder() 784 | .senderId("default") 785 | .addSlot("requested_slot", "some_slot") 786 | .latestMessage(MessageBuilder.builder() 787 | .addEntity(EntityBuilder.builder().entity("some_slot").value("some_value").build()) 788 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 789 | .build()) 790 | .paused(false) 791 | .latestActionName("action_listen") 792 | .build(); 793 | 794 | List events = customFormAction.validate(new CollectingDispatcher(), tracker, null); 795 | List expectedEvents = new ArrayList<>(); 796 | expectedEvents.add(new SlotSet("some_other_slot", "some_other_value")); 797 | expectedEvents.add(new SlotSet("some_slot", "some_value")); 798 | 799 | // check that validation succeeded: one SlotSet event per extracted entity 800 | assertTrue(events.size() == expectedEvents.size() && events.containsAll(expectedEvents)); 801 | 802 | //============================================== 803 | tracker = TrackerBuilder.builder() 804 | .senderId("default") 805 | .addSlot("requested_slot", "some_slot") 806 | .latestMessage(MessageBuilder.builder() 807 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 808 | .build()) 809 | .paused(false) 810 | .latestActionName("action_listen") 811 | .build(); 812 | 813 | events = customFormAction.validate(new CollectingDispatcher(), tracker, null); 814 | expectedEvents = new ArrayList<>(); 815 | expectedEvents.add(new SlotSet("some_other_slot", "some_other_value")); 816 | 817 | // check that validation 
succeeded 818 | assertTrue(events.size() == expectedEvents.size() && events.containsAll(expectedEvents)); 819 | 820 | //============================================== 821 | Tracker tracker1 = TrackerBuilder.builder() 822 | .senderId("default") 823 | .addSlot("requested_slot", "some_slot") 824 | .latestMessage(MessageBuilder.builder().build()) 825 | .paused(false) 826 | .latestActionName("action_listen") 827 | .build(); 828 | 829 | Assertions.assertThrows(ActionExecutionRejectionException.class, () -> { 830 | customFormAction.validate(new CollectingDispatcher(), tracker1, null); 831 | },"Failed to extract slot some_slot with action some_form"); // NOTE(review): if this is JUnit 5, the String argument is the assertion failure message, so the exception's own message is not verified here - confirm intent 832 | } 833 | 834 | /** 835 | * Test that a custom slot validator can set additional slots (some_other_slot) besides the validated one 836 | */ 837 | @Test 838 | void setSlotWithinHelper() { 839 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 840 | @Override 841 | protected List requiredSlots(Tracker tracker) { 842 | return Arrays.asList("some_slot", "some_other_slot"); 843 | } 844 | 845 | @Override 846 | protected List submit(CollectingDispatcher dispatcher) { 847 | return null; 848 | } 849 | 850 | @Override 851 | protected void registerSlotsValidators(Map slotValidatorMap) { 852 | //slotValidatorMap.put("some_slot", new SomeSlotValidator()); 853 | slotValidatorMap.put("some_slot", (value, dispatcher, tracker, domain) -> { 854 | if("some_value".equals(value)) { 855 | Map resultMap = new HashMap<>(); 856 | resultMap.put("some_slot", "validated_value"); 857 | resultMap.put("some_other_slot", "other_value"); 858 | return resultMap; 859 | } 860 | return Collections.emptyMap(); 861 | }); 862 | } 863 | }; 864 | 865 | Tracker tracker = TrackerBuilder.builder() 866 | .senderId("default") 867 | .addSlot("requested_slot", "some_slot") 868 | .latestMessage(MessageBuilder.builder() 869 | .addEntity(EntityBuilder.builder().entity("some_slot").value("some_value").build()) 870 | .build()) 871 | .paused(false) 872 | .latestActionName("action_listen") 873 | .build(); 874 | 
875 | List events = customFormAction.validate(new CollectingDispatcher(), tracker, null); 876 | List expectedEvents = new ArrayList<>(); 877 | expectedEvents.add(new SlotSet("some_other_slot", "other_value")); 878 | expectedEvents.add(new SlotSet("some_slot", "validated_value")); 879 | 880 | // check that validation succeeded and both slots returned by the validator were set 881 | assertTrue(events.size() == expectedEvents.size() && events.containsAll(expectedEvents)); 882 | } 883 | 884 | /** 885 | * Test that an extracted slot is validated by the custom validator even when no slot is requested (requested_slot is null) 886 | */ 887 | @Test 888 | void validateExtractedNoRequested() { 889 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 890 | @Override 891 | protected List requiredSlots(Tracker tracker) { 892 | return Arrays.asList("some_slot", "some_other_slot"); 893 | } 894 | 895 | @Override 896 | protected List submit(CollectingDispatcher dispatcher) { 897 | return null; 898 | } 899 | 900 | @Override 901 | protected void registerSlotsValidators(Map slotValidatorMap) { 902 | slotValidatorMap.put("some_slot", (value, dispatcher, tracker, domain) -> { 903 | if("some_value".equals(value)) { 904 | Map resultMap = new HashMap<>(); 905 | resultMap.put("some_slot", "validated_value"); 906 | return resultMap; 907 | } 908 | return Collections.emptyMap(); 909 | }); 910 | } 911 | }; 912 | 913 | Tracker tracker = TrackerBuilder.builder() 914 | .senderId("default") 915 | .addSlot("requested_slot", null) 916 | .latestMessage(MessageBuilder.builder() 917 | .addEntity(EntityBuilder.builder().entity("some_slot").value("some_value").build()) 918 | .build()) 919 | .paused(false) 920 | .latestActionName("action_listen") 921 | .build(); 922 | 923 | List events = customFormAction.validate(new CollectingDispatcher(), tracker, null); 924 | List expectedEvents = new ArrayList<>(); 925 | expectedEvents.add(new SlotSet("some_slot", "validated_value")); 926 | 927 | // check that validation succeeded 928 | assertTrue(events.size() == expectedEvents.size() && events.containsAll(expectedEvents)); 
929 | } 930 | 931 | /** 932 | * Test that prefilled slots are validated with the custom validator when the form is activated 933 | */ 934 | @Test 935 | void validatePrefilledSlots() { 936 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 937 | @Override 938 | protected List requiredSlots(Tracker tracker) { 939 | return Arrays.asList("some_slot", "some_other_slot"); 940 | } 941 | 942 | @Override 943 | protected List submit(CollectingDispatcher dispatcher) { 944 | return null; 945 | } 946 | 947 | @Override 948 | protected void registerSlotsValidators(Map slotValidatorMap) { 949 | slotValidatorMap.put("some_slot", (value, dispatcher, tracker, domain) -> { 950 | Map resultMap = new HashMap<>(); 951 | if("some_value".equals(value)) { 952 | resultMap.put("some_slot", "validated_value"); 953 | } else { 954 | resultMap.put("some_slot", null); 955 | } 956 | return resultMap; 957 | }); 958 | } 959 | }; 960 | 961 | Tracker tracker = TrackerBuilder.builder() 962 | .senderId("default") 963 | .addSlot("some_slot", "some_value") 964 | .addSlot("some_other_slot", "some_other_value") 965 | .latestMessage(MessageBuilder.builder() 966 | .addEntity(EntityBuilder.builder().entity("some_slot").value("some_bad_value").build()) 967 | .text("some text") 968 | .build()) 969 | .paused(false) 970 | .latestActionName("action_listen") 971 | .build(); 972 | 973 | List events = customFormAction.activateFormIfRequired(null, tracker, null); 974 | 975 | // check that the form was activated and prefilled slots were validated 976 | List expectedEvents = new ArrayList<>(); 977 | expectedEvents.add(new Form("some_form")); 978 | expectedEvents.add(new SlotSet("some_slot", "validated_value")); 979 | expectedEvents.add(new SlotSet("some_other_slot", "some_other_value")); 980 | 981 | assertTrue(events.size() == expectedEvents.size() && events.containsAll(expectedEvents)); 982 | 983 | //================ 984 | events.addAll(customFormAction.validateIfRequired(null, tracker, null)); 985 | 986 | // check that entities picked up in 
input overwrite prefilled slots 987 | expectedEvents = new ArrayList<>(); 988 | expectedEvents.add(new Form("some_form")); 989 | expectedEvents.add(new SlotSet("some_slot", "validated_value")); 990 | expectedEvents.add(new SlotSet("some_other_slot", "some_other_value")); 991 | expectedEvents.add(new SlotSet("some_slot", null)); 992 | 993 | assertTrue(events.size() == expectedEvents.size() && events.containsAll(expectedEvents)); // NOTE(review): size + containsAll does not verify event order or how often each event occurs 994 | } 995 | 996 | /** 997 | * Test validation results of from_trigger_intent slot mappings 998 | */ 999 | @Test 1000 | void validateTriggerSlots() { 1001 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 1002 | @Override 1003 | protected List requiredSlots(Tracker tracker) { 1004 | return Arrays.asList("some_slot"); 1005 | } 1006 | 1007 | @Override 1008 | protected List submit(CollectingDispatcher dispatcher) { 1009 | return null; 1010 | } 1011 | 1012 | @Override 1013 | protected void registerSlotsValidators(Map slotValidatorMap) { 1014 | 1015 | } 1016 | 1017 | @Override 1018 | protected Map> slotMappings() { 1019 | Map> slotMappingMap = new HashMap<>(); 1020 | 1021 | slotMappingMap.put("some_slot", Arrays.asList(TriggerIntentSlotMapping.builder().intent("trigger_intent").value("some_value").build())); 1022 | 1023 | return slotMappingMap; 1024 | } 1025 | }; 1026 | 1027 | Tracker tracker = TrackerBuilder.builder() 1028 | .senderId("default") 1029 | .latestMessage(MessageBuilder.builder() 1030 | .intent("trigger_intent", 1.0) 1031 | .build()) 1032 | .paused(false) 1033 | .latestActionName("action_listen") 1034 | .build(); 1035 | 1036 | List slotValues = customFormAction.validate(new CollectingDispatcher(), tracker, null); 1037 | 1038 | //check that the value was extracted on form activation 1039 | List expectedEvents = new ArrayList<>(); 1040 | expectedEvents.add(new SlotSet("some_slot", "some_value")); 1041 | 1042 | assertTrue(slotValues.size() == expectedEvents.size() && slotValues.containsAll(expectedEvents)); 
1043 | 1044 | //======================================= 1045 | tracker = TrackerBuilder.builder() 1046 | .senderId("default") 1047 | .latestMessage(MessageBuilder.builder() 1048 | .intent("trigger_intent", 1.0) 1049 | .build()) 1050 | .paused(false) 1051 | .activeForm(FormBuilder.builder() 1052 | .name("some_form") 1053 | .validate(true) 1054 | .rejected(false) 1055 | .triggerMessage(MessageBuilder.builder() 1056 | .intent("trigger_intent", 1.0) 1057 | .build()) 1058 | .build()) 1059 | .latestActionName("action_listen") 1060 | .build(); 1061 | 1062 | slotValues = customFormAction.validate(new CollectingDispatcher(), tracker, null); 1063 | //check that the value was not extracted after form activation 1064 | assertTrue(slotValues.isEmpty()); 1065 | 1066 | //======================================= 1067 | tracker = TrackerBuilder.builder() 1068 | .senderId("default") 1069 | .addSlot("requested_slot", "some_other_slot") 1070 | .latestMessage(MessageBuilder.builder() 1071 | .intent("some_other_intent", 1.0) 1072 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 1073 | .build()) 1074 | .paused(false) 1075 | .activeForm(FormBuilder.builder() 1076 | .name("some_form") 1077 | .validate(true) 1078 | .rejected(false) 1079 | .triggerMessage(MessageBuilder.builder() 1080 | .intent("trigger_intent", 1.0) 1081 | .build()) 1082 | .build()) 1083 | .latestActionName("action_listen") 1084 | .build(); 1085 | 1086 | slotValues = customFormAction.validate(new CollectingDispatcher(), tracker, null); 1087 | //check that validation handled the non-trigger message gracefully: only the requested entity value is set 1088 | expectedEvents = new ArrayList<>(); 1089 | expectedEvents.add(new SlotSet("some_other_slot", "some_other_value")); 1090 | 1091 | assertTrue(slotValues.size() == expectedEvents.size() && slotValues.containsAll(expectedEvents)); 1092 | } 1093 | 1094 | /** 1095 | * Test form activation (if required) 1096 | */ 1097 | @Test 1098 | void activateIfRequired() { 1099 | AbstractFormAction 
customFormAction = new AbstractFormAction("some_form") { 1100 | @Override 1101 | protected List requiredSlots(Tracker tracker) { 1102 | return Arrays.asList("some_slot", "some_other_slot"); 1103 | } 1104 | 1105 | @Override 1106 | protected List submit(CollectingDispatcher dispatcher) { 1107 | return null; 1108 | } 1109 | 1110 | @Override 1111 | protected void registerSlotsValidators(Map slotValidatorMap) { 1112 | 1113 | } 1114 | }; 1115 | 1116 | //================================================================================ 1117 | // Form should be activated 1118 | //================================================================================ 1119 | Tracker tracker = TrackerBuilder.builder() 1120 | .senderId("default") 1121 | .latestMessage(MessageBuilder.builder() 1122 | .intent("some_intent", 1.0) 1123 | .text("some text") 1124 | .build()) 1125 | .paused(false) 1126 | .latestActionName("action_listen") 1127 | .build(); 1128 | 1129 | List events = customFormAction.activateFormIfRequired(null, tracker, null); 1130 | // check that the form was activated 1131 | List expectedEvents = new ArrayList<>(); 1132 | expectedEvents.add(new Form("some_form")); 1133 | 1134 | assertTrue(events.size() == expectedEvents.size() && events.containsAll(expectedEvents)); 1135 | 1136 | //================================================================================ 1137 | // When a form is already active, it shouldn't be activated again 1138 | //================================================================================ 1139 | 1140 | tracker = TrackerBuilder.builder() 1141 | .senderId("default") 1142 | .paused(false) 1143 | .activeForm(FormBuilder.builder() 1144 | .name("some_form") 1145 | .validate(true) 1146 | .rejected(false) 1147 | .build()) 1148 | .latestActionName("action_listen") 1149 | .build(); 1150 | 1151 | events = customFormAction.activateFormIfRequired(null, tracker, null); 1152 | // check that the form was not activated again 1153 | 
assertTrue(events.isEmpty()); 1154 | } 1155 | 1156 | /** 1157 | * Test form validation (if required): performed only for an active form with validate=true 1158 | */ 1159 | @Test 1160 | void validateIfRequired() { 1161 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 1162 | @Override 1163 | protected List requiredSlots(Tracker tracker) { 1164 | return Arrays.asList("some_slot", "some_other_slot"); 1165 | } 1166 | 1167 | @Override 1168 | protected List submit(CollectingDispatcher dispatcher) { 1169 | return null; 1170 | } 1171 | 1172 | @Override 1173 | protected void registerSlotsValidators(Map slotValidatorMap) { 1174 | 1175 | } 1176 | }; 1177 | 1178 | //================================================================================ 1179 | // A form validation should be performed 1180 | //================================================================================ 1181 | Tracker tracker = TrackerBuilder.builder() 1182 | .senderId("default") 1183 | .addSlot("requested_slot", "some_slot") 1184 | .latestMessage(MessageBuilder.builder() 1185 | .addEntity(EntityBuilder.builder().entity("some_slot").value("some_value").build()) 1186 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 1187 | .build()) 1188 | .paused(false) 1189 | .activeForm(FormBuilder.builder() 1190 | .name("some_form") 1191 | .validate(true) 1192 | .rejected(false) 1193 | .build()) 1194 | .latestActionName("action_listen") 1195 | .build(); 1196 | 1197 | List events = customFormAction.validateIfRequired(new CollectingDispatcher(), tracker, null); 1198 | // check that validation was performed 1199 | List expectedEvents = new ArrayList<>(); 1200 | expectedEvents.add(new SlotSet("some_other_slot", "some_other_value")); 1201 | expectedEvents.add(new SlotSet("some_slot", "some_value")); 1202 | 1203 | assertTrue(events.size() == expectedEvents.size() && events.containsAll(expectedEvents)); 1204 | 1205 | //================================================================================ 
1206 | // A form validation should be skipped, because "validate=false" 1207 | //================================================================================ 1208 | 1209 | tracker = TrackerBuilder.builder() 1210 | .senderId("default") 1211 | .addSlot("requested_slot", "some_slot") 1212 | .latestMessage(MessageBuilder.builder() 1213 | .addEntity(EntityBuilder.builder().entity("some_slot").value("some_value").build()) 1214 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 1215 | .build()) 1216 | .paused(false) 1217 | .activeForm(FormBuilder.builder() 1218 | .name("some_form") 1219 | .validate(false) 1220 | .rejected(false) 1221 | .build()) 1222 | .latestActionName("action_listen") 1223 | .build(); 1224 | 1225 | events = customFormAction.validateIfRequired(new CollectingDispatcher(), tracker, null); 1226 | // check that validation was skipped because "validate=false" 1227 | assertTrue(events.isEmpty()); 1228 | 1229 | //================================================================================ 1230 | // A form validation should be skipped, because previous action is not action_listen 1231 | //================================================================================ 1232 | 1233 | tracker = TrackerBuilder.builder() 1234 | .senderId("default") 1235 | .addSlot("requested_slot", "some_slot") 1236 | .latestMessage(MessageBuilder.builder() 1237 | .addEntity(EntityBuilder.builder().entity("some_slot").value("some_value").build()) 1238 | .addEntity(EntityBuilder.builder().entity("some_other_slot").value("some_other_value").build()) 1239 | .build()) 1240 | .paused(false) 1241 | .activeForm(FormBuilder.builder() 1242 | .name("some_form") 1243 | .validate(false) // NOTE(review): validate is false here as well, so this scenario does not isolate the latestActionName condition - validate(true) was probably intended; confirm 1244 | .rejected(false) 1245 | .build()) 1246 | .latestActionName("some_form") 1247 | .build(); 1248 | 1249 | events = customFormAction.validateIfRequired(new CollectingDispatcher(), tracker, null); 1250 | // check that validation was skipped because previous 
action is not action_listen 1251 | assertTrue(events.isEmpty()); 1252 | } 1253 | 1254 | /** 1255 | * Test early deactivation: validate() is overridden to deactivate the form, so run() must return only the deactivation events 1256 | */ 1257 | @Test 1258 | void earlyDeactivation() { 1259 | AbstractFormAction customFormAction = new AbstractFormAction("some_form") { 1260 | @Override 1261 | protected List requiredSlots(Tracker tracker) { 1262 | return Arrays.asList("some_slot", "some_other_slot"); 1263 | } 1264 | 1265 | @Override 1266 | List validate(CollectingDispatcher dispatcher, Tracker tracker, Domain domain) { 1267 | return super.deactivate(); 1268 | } 1269 | 1270 | @Override 1271 | protected List submit(CollectingDispatcher dispatcher) { 1272 | return null; 1273 | } 1274 | 1275 | @Override 1276 | protected void registerSlotsValidators(Map slotValidatorMap) { 1277 | 1278 | } 1279 | }; 1280 | 1281 | Tracker tracker = TrackerBuilder.builder() 1282 | .senderId("default") 1283 | .addSlot("some_slot", "some_value") 1284 | .latestMessage(MessageBuilder.builder() 1285 | .intent("greet") 1286 | .build()) 1287 | .paused(false) 1288 | .activeForm(FormBuilder.builder() 1289 | .name("some_form") 1290 | .validate(true) 1291 | .rejected(false) 1292 | .build()) 1293 | .latestActionName("action_listen") 1294 | .build(); 1295 | 1296 | List events = customFormAction.run(null, tracker, null); 1297 | // check that form was deactivated before requesting next slot 1298 | List expectedEvents = new ArrayList<>(); 1299 | expectedEvents.add(new Form(null)); 1300 | expectedEvents.add(new SlotSet("requested_slot", null)); 1301 | 1302 | assertTrue(events.size() == expectedEvents.size() && events.containsAll(expectedEvents)); 1303 | } 1304 | } -------------------------------------------------------------------------------- /src/test/java/io/github/rbajek/rasa/sdk/repository/databuilder/tracker/EntityBuilder.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.repository.databuilder.tracker; 2 | 3 | 4 | import 
io.github.rbajek.rasa.sdk.dto.Tracker; 5 | 6 | /** Fluent test-data builder for {@link Tracker.Entity}. */ public class EntityBuilder { 7 | 8 | /** Creates a builder wrapping a fresh, empty entity. */ public static Builder builder() { 9 | return new Builder(); 10 | } 11 | 12 | public static class Builder { 13 | private final Tracker.Entity entity; 14 | 15 | public Builder() { 16 | this.entity = new Tracker.Entity(); 17 | } 18 | 19 | public Builder start(int value) { 20 | this.entity.setStart(value); 21 | return this; 22 | } 23 | 24 | public Builder end(int value) { 25 | this.entity.setEnd(value); 26 | return this; 27 | } 28 | 29 | public Builder value(String value) { 30 | this.entity.setValue(value); 31 | return this; 32 | } 33 | 34 | public Builder entity(String value) { 35 | this.entity.setEntity(value); 36 | return this; 37 | } 38 | 39 | public Builder confidence(double value) { 40 | this.entity.setConfidence(value); 41 | return this; 42 | } 43 | 44 | /** Returns the entity being built; repeated calls return the same instance. */ public Tracker.Entity build() { 45 | return entity; 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/test/java/io/github/rbajek/rasa/sdk/repository/databuilder/tracker/FormBuilder.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.repository.databuilder.tracker; 2 | 3 | import io.github.rbajek.rasa.sdk.dto.Tracker; 4 | 5 | /** Fluent test-data builder for {@link Tracker.Form}. */ public class FormBuilder { 6 | 7 | public static Builder builder() { 8 | return new Builder(); 9 | } 10 | 11 | public static class Builder { 12 | private final Tracker.Form form; 13 | 14 | public Builder() { 15 | this.form = new Tracker.Form(); 16 | } 17 | 18 | 19 | public Builder name(String value) { 20 | this.form.setName(value); 21 | return this; 22 | } 23 | 24 | public Builder validate(boolean value) { 25 | this.form.setValidate(value); 26 | return this; 27 | } 28 | 29 | public Builder rejected(boolean value) { 30 | this.form.setRejected(value); 31 | return this; 32 | } 33 | 34 | public Builder triggerMessage(Tracker.Message value) { 35 | 
this.form.setTriggerMessage(value); 36 | return this; 37 | } 38 | 39 | /** Returns the form being built; repeated calls return the same instance. */ public Tracker.Form build() { 40 | return this.form; 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/test/java/io/github/rbajek/rasa/sdk/repository/databuilder/tracker/IntentBuilder.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.repository.databuilder.tracker; 2 | 3 | import io.github.rbajek.rasa.sdk.dto.Tracker; 4 | 5 | /** Fluent test-data builder for {@link Tracker.Intent}. Unlike the sibling builders it has no static builder() factory; the Intent to populate is passed to the Builder constructor. */ public class IntentBuilder { 6 | 7 | public static class Builder { 8 | private final Tracker.Intent intent; 9 | 10 | /** Wraps the given intent; the setters below mutate it in place. */ public Builder(Tracker.Intent intent) { 11 | this.intent = intent; 12 | } 13 | 14 | public Builder confidence(double value) { 15 | this.intent.setConfidence(value); 16 | return this; 17 | } 18 | 19 | public Builder name(String value) { 20 | this.intent.setName(value); 21 | return this; 22 | } 23 | 24 | public Tracker.Intent build() { 25 | return this.intent; 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/test/java/io/github/rbajek/rasa/sdk/repository/databuilder/tracker/MessageBuilder.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.repository.databuilder.tracker; 2 | 3 | import io.github.rbajek.rasa.sdk.dto.Tracker; 4 | 5 | import java.util.ArrayList; 6 | 7 | /** Fluent test-data builder for {@link Tracker.Message}. */ public class MessageBuilder { 8 | 9 | public static Builder builder() { 10 | return new Builder(); 11 | } 12 | 13 | public static class Builder { 14 | private final Tracker.Message latestMessage; 15 | 16 | public Builder() { 17 | this.latestMessage = new Tracker.Message(); 18 | } 19 | 20 | /** Appends an entity, creating the entity list first when the message has none yet. */ public Builder addEntity(Tracker.Entity entity) { 21 | if(this.latestMessage.getEntities() == null) { 22 | this.latestMessage.setEntities(new ArrayList<>()); 23 | } 24 | this.latestMessage.getEntities().add(entity); 25 | return this; 26 | } 27 | 
28 | /** Sets the intent by name only; the confidence passed on is null. */ public Builder intent(String name) { 29 | return intent(name, null); 30 | } 31 | 32 | /** Replaces the message intent with a new Intent carrying the given name and confidence. */ public Builder intent(String name, Double confidence) { 33 | Tracker.Intent intent = new Tracker.Intent(); 34 | intent.setName(name); 35 | intent.setConfidence(confidence); 36 | this.latestMessage.setIntent(intent); 37 | return this; 38 | } 39 | 40 | public Builder text(String text) { 41 | this.latestMessage.setText(text); 42 | return this; 43 | } 44 | 45 | /** Returns the message being built; repeated calls return the same instance. */ public Tracker.Message build() { 46 | return this.latestMessage; 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/test/java/io/github/rbajek/rasa/sdk/repository/databuilder/tracker/TrackerBuilder.java: -------------------------------------------------------------------------------- 1 | package io.github.rbajek.rasa.sdk.repository.databuilder.tracker; 2 | 3 | import io.github.rbajek.rasa.sdk.dto.Tracker; 4 | 5 | /** Fluent test-data builder for {@link Tracker}. */ public class TrackerBuilder { 6 | 7 | public static Builder builder() { 8 | return new Builder(); 9 | } 10 | 11 | public static class Builder { 12 | private final Tracker tracker; 13 | 14 | public Builder() { 15 | this.tracker = new Tracker(); 16 | } 17 | 18 | public Builder senderId(String value) { 19 | this.tracker.setSenderId(value); 20 | return this; 21 | } 22 | 23 | public Builder addSlot(String slotName, Object value) { 24 | this.tracker.addSlot(slotName, value); 25 | return this; 26 | } 27 | 28 | public Builder latestMessage(Tracker.Message value) { 29 | this.tracker.setLatestMessage(value); 30 | return this; 31 | } 32 | 33 | public Builder followupAction(String value) { 34 | this.tracker.setFollowupAction(value); 35 | return this; 36 | } 37 | 38 | public Builder paused(boolean value) { 39 | this.tracker.setPaused(value); 40 | return this; 41 | } 42 | 43 | public Builder latestActionName(String value) { 44 | this.tracker.setLatestActionName(value); 45 | return this; 46 | } 47 | 48 | public Builder activeForm(Tracker.Form value) { 49 | 
this.tracker.setActiveForm(value); 50 | return this; 51 | } 52 | 53 | /** Returns the tracker being built; repeated calls return the same instance. */ public Tracker build() { 54 | return this.tracker; 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/test/resources/log4j2.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | --------------------------------------------------------------------------------