├── .gitignore ├── .gitmodules ├── README.md ├── build.gradle ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── settings.gradle └── src ├── main └── java │ └── ai │ └── test │ └── classifier_client │ ├── Classification.java │ └── ClassifierClient.java └── test ├── java ├── ClientSeleniumTest.java └── ClientTest.java └── resources ├── cart.png └── menu.png /.gitignore: -------------------------------------------------------------------------------- 1 | .classpath 2 | .project 3 | .settings 4 | bin 5 | .idea 6 | .gradle 7 | */build 8 | build 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/main/proto"] 2 | path = src/main/proto 3 | url = git@github.com:testdotai/classifier-proto.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Test.ai Classifier - Java Client 2 | 3 | The code in this directory defines a client library for use with the [gRPC-based Test.ai classifier server](https://github.com/testdotai/appium-classifier-plugin). 4 | 5 | ## Installation & Setup 6 | 7 | The library is not yet available on Maven Central. Instead, download a pre-built JAR file to import into your project from the [Releases](https://github.com/testdotai/classifier-client-java/releases) page, or use [JitPack](https://jitpack.io). If you build from source, clone with `--recursive` (or run `git submodule update --init`) first: the protobuf definitions in `src/main/proto` come from a Git submodule. 8 | 9 | ## Usage 10 | 11 | This client exposes two classes: 12 | 13 | ```java 14 | import ai.test.classifier_client.ClassifierClient; 15 | import ai.test.classifier_client.Classification; 16 | ``` 17 | 18 | The important class is `ClassifierClient`, whose constructor takes host and port parameters, so that the client can speak to the correct classifier server. Call `shutdown()` on the client when you are done with it to shut down its gRPC channel. 
The class has two important methods, each of which also has a two-argument overload that uses `ClassifierClient.DEFAULT_THRESHOLD` (0.2) and `allowWeakerMatches = false`: 19 | 20 | ```java 21 | Map<String, Classification> classifyElements(String label, Map<String, byte[]> elementImages, 22 | double confidenceThreshold, boolean allowWeakerMatches) 23 | 24 | List<WebElement> findElementsMatchingLabel(RemoteWebDriver driver, String label, 25 | double confidenceThreshold, boolean allowWeakerMatches) throws Exception 26 | ``` 27 | 28 | 1. `classifyElements` takes a label (see `lib/labels.js` in the [appium-classifier-plugin](https://github.com/testdotai/appium-classifier-plugin) repo for the supported labels), a map of Strings (ids) to byte arrays (PNG image data), a confidence threshold (1.0 = perfect confidence required for a match, 0.0 = no confidence required), and a boolean flag that tells the server whether to return potential matches even when their highest-confidence classification is a *different* label. The return value maps the ids of the matching images (a subset of the ids you passed in) to `Classification` objects (described below). 29 | 2. `findElementsMatchingLabel` is a helper method for Selenium tests (for Appium, use the [Appium classifier plugin](https://github.com/testdotai/appium-classifier-plugin) instead). It takes a driver followed by the same final parameters as `classifyElements`. The client uses the driver to find the leaf-node elements on the current page, takes a screenshot of each one, and passes the screenshots to the classifier. The return value is a list of the `WebElement`s that match the provided label. 30 | 31 | The `Classification` object is a simple container for the label the classifier chose, its confidence in that label, and its confidence in the label you originally provided (`getLabel()`, `getConfidence()`, and `getConfidenceForLabel()`). 32 | 33 | For a concrete example, check out the `ClientSeleniumTest.java` file in this repo. 
34 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | plugins { 2 | id 'java' 3 | id 'com.google.protobuf' version '0.8.8' 4 | id 'application' 5 | id 'idea' 6 | id 'maven' 7 | } 8 | 9 | group 'ai.test' 10 | version '1.0-SNAPSHOT' 11 | 12 | sourceCompatibility = 1.8 13 | 14 | repositories { 15 | maven { // The google mirror is less flaky than mavenCentral() 16 | url "https://maven-central.storage-download.googleapis.com/repos/central/data/" } 17 | mavenCentral() 18 | } 19 | 20 | def grpcVersion = '1.25.0' // CURRENT_GRPC_VERSION 21 | def protobufVersion = '3.10.1' 22 | def protocVersion = protobufVersion 23 | 24 | dependencies { 25 | implementation "io.grpc:grpc-protobuf:${grpcVersion}" 26 | implementation "io.grpc:grpc-stub:${grpcVersion}" 27 | implementation "io.grpc:grpc-okhttp:${grpcVersion}" 28 | implementation group: 'org.seleniumhq.selenium', name: 'selenium-java', version: '3.141.59' 29 | 30 | compileOnly "javax.annotation:javax.annotation-api:1.2" 31 | // testImplementation replaces testCompile, which is deprecated and removed in Gradle 7. 32 | testImplementation group: 'com.google.guava', name: 'guava', version: '28.1-jre' 33 | testImplementation group: 'junit', name: 'junit', version: '4.12' 34 | testImplementation group: 'org.hamcrest', name: 'hamcrest-junit', version: '2.0.0.0' 35 | } 36 | 37 | protobuf { 38 | protoc { 39 | artifact = "com.google.protobuf:protoc:${protocVersion}" 40 | } 41 | plugins { 42 | grpc { 43 | artifact = "io.grpc:protoc-gen-grpc-java:${grpcVersion}" 44 | } 45 | } 46 | generateProtoTasks { 47 | all()*.plugins { 48 | grpc {} 49 | } 50 | } 51 | } 52 | 53 | // Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. 
54 | sourceSets { 55 | main { 56 | java { 57 | srcDirs 'build/generated/source/proto/main/grpc' 58 | srcDirs 'build/generated/source/proto/main/java' 59 | } 60 | } 61 | } 62 | 63 | startScripts.enabled = false 64 | // mainClassName must name a class, not a package; it matches fatJar's Main-Class below. NOTE(review): ClassifierClient defines no main() yet, so neither launcher runs until one is added. 65 | task classifierClient(type: CreateStartScripts) { 66 | mainClassName = 'ai.test.classifier_client.ClassifierClient' 67 | applicationName = 'test-ai-classifier-client' 68 | outputDir = new File(project.buildDir, 'tmp') 69 | classpath = startScripts.classpath 70 | } 71 | 72 | applicationDistribution.into('bin') { 73 | from(classifierClient) 74 | fileMode = 0755 75 | } 76 | 77 | task fatJar(type: Jar) { 78 | manifest { 79 | attributes 'Main-Class': 'ai.test.classifier_client.ClassifierClient' 80 | } 81 | baseName = project.name + '-all' 82 | from { 83 | configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) } 84 | } 85 | with jar 86 | } 87 | 88 | test { 89 | outputs.upToDateWhen {false} 90 | useJUnit() 91 | testLogging { 92 | exceptionFormat = 'full' 93 | showStandardStreams = true 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/testdotai/classifier-client-java/14aae5a661843be08a1ec33cbea2f33d4fd4b999/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | #Thu Nov 21 11:37:52 PST 2019 2 | distributionBase=GRADLE_USER_HOME 3 | distributionPath=wrapper/dists 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.3-all.zip 7 | 
-------------------------------------------------------------------------------- /gradlew: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | ############################################################################## 4 | ## 5 | ## Gradle start up script for UN*X 6 | ## 7 | ############################################################################## 8 | # NOTE(review): this appears to be the stock Gradle wrapper launcher; prefer regenerating it with `gradle wrapper` over editing it by hand. 9 | # Attempt to set APP_HOME 10 | # Resolve links: $0 may be a link 11 | PRG="$0" 12 | # Need this for relative symlinks. 13 | while [ -h "$PRG" ] ; do 14 | ls=`ls -ld "$PRG"` 15 | link=`expr "$ls" : '.*-> \(.*\)$'` 16 | if expr "$link" : '/.*' > /dev/null; then 17 | PRG="$link" 18 | else 19 | PRG=`dirname "$PRG"`"/$link" 20 | fi 21 | done 22 | SAVED="`pwd`" 23 | cd "`dirname \"$PRG\"`/" >/dev/null 24 | APP_HOME="`pwd -P`" 25 | cd "$SAVED" >/dev/null 26 | 27 | APP_NAME="Gradle" 28 | APP_BASE_NAME=`basename "$0"` 29 | 30 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 31 | DEFAULT_JVM_OPTS="" 32 | 33 | # Use the maximum available, or set MAX_FD != -1 to use that value. 34 | MAX_FD="maximum" 35 | 36 | warn () { 37 | echo "$*" 38 | } 39 | 40 | die () { 41 | echo 42 | echo "$*" 43 | echo 44 | exit 1 45 | } 46 | 47 | # OS specific support (must be 'true' or 'false'). 48 | cygwin=false 49 | msys=false 50 | darwin=false 51 | nonstop=false 52 | case "`uname`" in 53 | CYGWIN* ) 54 | cygwin=true 55 | ;; 56 | Darwin* ) 57 | darwin=true 58 | ;; 59 | MINGW* ) 60 | msys=true 61 | ;; 62 | NONSTOP* ) 63 | nonstop=true 64 | ;; 65 | esac 66 | 67 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 68 | 69 | # Determine the Java command to use to start the JVM. 70 | if [ -n "$JAVA_HOME" ] ; then 71 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 72 | # IBM's JDK on AIX uses strange locations for the executables 73 | JAVACMD="$JAVA_HOME/jre/sh/java" 74 | else 75 | JAVACMD="$JAVA_HOME/bin/java" 76 | fi 77 | if [ !
-x "$JAVACMD" ] ; then 78 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 79 | 80 | Please set the JAVA_HOME variable in your environment to match the 81 | location of your Java installation." 82 | fi 83 | else 84 | JAVACMD="java" 85 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 86 | 87 | Please set the JAVA_HOME variable in your environment to match the 88 | location of your Java installation." 89 | fi 90 | 91 | # Increase the maximum file descriptors if we can. 92 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 93 | MAX_FD_LIMIT=`ulimit -H -n` 94 | if [ $? -eq 0 ] ; then 95 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 96 | MAX_FD="$MAX_FD_LIMIT" 97 | fi 98 | ulimit -n $MAX_FD 99 | if [ $? -ne 0 ] ; then 100 | warn "Could not set maximum file descriptor limit: $MAX_FD" 101 | fi 102 | else 103 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 104 | fi 105 | fi 106 | 107 | # For Darwin, add options to specify how the application appears in the dock 108 | if $darwin; then 109 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 110 | fi 111 | 112 | # For Cygwin, switch paths to Windows format before running java 113 | if $cygwin ; then 114 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 115 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 116 | JAVACMD=`cygpath --unix "$JAVACMD"` 117 | 118 | # We build the pattern for arguments to be converted via cygpath 119 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 120 | SEP="" 121 | for dir in $ROOTDIRSRAW ; do 122 | ROOTDIRS="$ROOTDIRS$SEP$dir" 123 | SEP="|" 124 | done 125 | OURCYGPATTERN="(^($ROOTDIRS))" 126 | # Add a user-defined pattern to the cygpath arguments 127 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 128 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 129 | fi 130 | # Now convert the
arguments - kludge to limit ourselves to /bin/sh 131 | i=0 132 | for arg in "$@" ; do 133 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 134 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 135 | 136 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 137 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 138 | else 139 | eval `echo args$i`="\"$arg\"" 140 | fi 141 | i=$((i+1)) 142 | done 143 | case $i in 144 | (0) set -- ;; 145 | (1) set -- "$args0" ;; 146 | (2) set -- "$args0" "$args1" ;; 147 | (3) set -- "$args0" "$args1" "$args2" ;; 148 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;; 149 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 150 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 151 | (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 152 | (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 153 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 154 | esac 155 | fi 156 | 157 | # Escape application args 158 | save () { 159 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 160 | echo " " 161 | } 162 | APP_ARGS=$(save "$@") 163 | 164 | # Collect all arguments for the java command, following the shell quoting and substitution rules 165 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 166 | 167 | # by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong 168 | if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then 169 | cd "$(dirname "$0")" 170 | fi 171 | 172 | exec "$JAVACMD" "$@" 173 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @if
"%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | @rem NOTE(review): this appears to be the stock Gradle wrapper launcher; prefer regenerating it with `gradle wrapper` over editing it by hand. 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | set DIRNAME=%~dp0 12 | if "%DIRNAME%" == "" set DIRNAME=. 13 | set APP_BASE_NAME=%~n0 14 | set APP_HOME=%DIRNAME% 15 | 16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 17 | set DEFAULT_JVM_OPTS= 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windows variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | 53 | :win9xME_args 54 | @rem Slurp the command line arguments. 
55 | set CMD_LINE_ARGS= 56 | set _SKIP=2 57 | 58 | :win9xME_args_slurp 59 | if "x%~1" == "x" goto execute 60 | 61 | set CMD_LINE_ARGS=%* 62 | 63 | :execute 64 | @rem Setup the command line 65 | 66 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 67 | 68 | @rem Execute Gradle 69 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 70 | 71 | :end 72 | @rem End local scope for the variables with windows NT shell 73 | if "%ERRORLEVEL%"=="0" goto mainEnd 74 | 75 | :fail 76 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 77 | rem the _cmd.exe /c_ return code! 78 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 79 | exit /b 1 80 | 81 | :mainEnd 82 | if "%OS%"=="Windows_NT" endlocal 83 | 84 | :omega 85 | -------------------------------------------------------------------------------- /settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'classifier_client' 2 | 3 | -------------------------------------------------------------------------------- /src/main/java/ai/test/classifier_client/Classification.java: -------------------------------------------------------------------------------- 1 | package ai.test.classifier_client; 2 | 3 | /** 4 | * Immutable result of classifying one element image. 5 | * 6 | * <p>Holds the label the classifier assigned to the image, its confidence in that label, and its 7 | * confidence in the label that was requested (sent to the server as the label hint). 8 | */ 9 | public class Classification { 10 | 11 | /** Label the classifier assigned to the image. */ 12 | private final String label; 13 | /** Classifier's confidence in {@link #label}. */ 14 | private final double confidence; 15 | /** Classifier's confidence in the label that was originally requested. 
*/ 16 | private final double confidenceForLabel; 17 | 18 | /** 19 | * @param label label the classifier assigned to the image 20 | * @param confidence confidence in {@code label} 21 | * @param confidenceForLabel confidence in the originally requested label 22 | */ 23 | public Classification(String label, double confidence, double confidenceForLabel) { 24 | this.label = label; 25 | this.confidence = confidence; 26 | this.confidenceForLabel = confidenceForLabel; 27 | } 28 | 29 | /** Returns the label the classifier assigned to the image. */ 30 | public String getLabel() { 31 | return label; 32 | } 33 | 34 | /** Returns the classifier's confidence in the originally requested label. */ 35 | public double getConfidenceForLabel() { 36 | return confidenceForLabel; 37 | } 38 | 39 | /** Returns the classifier's confidence in {@link #getLabel()}. */ 40 | public double getConfidence() { 41 | return confidence; 42 | } 43 | 44 | @Override 45 | public String toString() { 46 | return "Classification{label=" + label + ", confidence=" + confidence 47 | + ", confidenceForLabel=" + confidenceForLabel + "}"; 48 | } 49 | } 50 | 
-------------------------------------------------------------------------------- /src/main/java/ai/test/classifier_client/ClassifierClient.java: -------------------------------------------------------------------------------- 1 | package ai.test.classifier_client; 2 | 3 | import ai.test.classifier_client.ClassifierGrpc.ClassifierBlockingStub; 4 | import ai.test.classifier_client.ClassifierOuterClass.ElementClassificationRequest; 5 | import ai.test.classifier_client.ClassifierOuterClass.ElementClassificationResult; 6 | import com.google.protobuf.ByteString; 7 | import io.grpc.ManagedChannel; 8 | import io.grpc.ManagedChannelBuilder; 9 | import java.util.ArrayList; 10 | import java.util.HashMap; 11 | import java.util.LinkedHashMap; 12 | import java.util.List; 13 | import java.util.Map; 14 | import java.util.concurrent.TimeUnit; 15 | import org.openqa.selenium.By; 16 | import org.openqa.selenium.NoSuchElementException; 17 | import org.openqa.selenium.OutputType; 18 | import org.openqa.selenium.WebElement; 19 | import org.openqa.selenium.remote.RemoteWebDriver; 20 | import org.openqa.selenium.remote.RemoteWebElement; 21 | 22 | /** 23 | * Client for the gRPC-based Test.ai element classifier server. 24 | * 25 | * <p>Call {@link #shutdown()} when finished to shut down the underlying gRPC channel. 26 | */ 27 | public class ClassifierClient { 28 | 29 | /** Confidence threshold used by the overloads that do not take one. */ 30 | public static final double DEFAULT_THRESHOLD = 0.2; 31 | 32 | /** Seconds {@link #shutdown()} waits for in-flight calls before forcing the channel closed. */ 33 | private static final long SHUTDOWN_TIMEOUT_SECONDS = 5; 34 | 35 | /** Leaf elements under body, excluding script and style nodes: the candidates sent for classification. */ 36 | private static final String LEAF_ELEMENTS_XPATH = 37 | "//body//*[not(self::script) and not(self::style) and not(child::*)]"; 38 | 39 | private final ManagedChannel channel; 40 | private final ClassifierBlockingStub blockingStub; 41 | 42 | /** Creates a client for the classifier server at {@code host:port} over a plaintext (non-TLS) channel. */ 43 | public ClassifierClient(String host, int port) { 44 | this(ManagedChannelBuilder.forAddress(host, port).usePlaintext()); 45 | } 46 | 47 | /** Creates a client from a caller-configured channel builder. */ 48 | public ClassifierClient(ManagedChannelBuilder<?> channelBuilder) { 49 | channel = channelBuilder.build(); 50 | blockingStub = ClassifierGrpc.newBlockingStub(channel); 51 | } 52 | 53 | /** 54 | * Shuts down the channel, waiting up to 5 seconds for in-flight calls to finish and forcing 55 | * termination if they have not. 56 | * 57 | * @throws InterruptedException if interrupted while waiting; the channel is force-closed first 58 | */ 59 | public void shutdown() throws InterruptedException { 60 | channel.shutdown(); 61 | try { 62 | if (!channel.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { 63 | channel.shutdownNow(); 64 | } 65 | } catch (InterruptedException e) { 66 | channel.shutdownNow(); 67 | throw e; 68 | } 69 | } 70 | 71 | /** 72 | * Classifies a batch of element images against a label. 73 | * 74 | * @param label label to look for; sent to the server as the label hint 75 | * @param elementImages caller-chosen ids mapped to PNG image bytes 76 | * @param confidenceThreshold minimum confidence required for a match 77 | * @param allowWeakerMatches whether the server may also return images whose highest-confidence 78 | * label is a different one 79 | * @return the server's classifications, keyed by the ids it returned 80 | */ 81 | public Map<String, Classification> classifyElements(String label, Map<String, byte[]> elementImages, 82 | double confidenceThreshold, boolean allowWeakerMatches) { 83 | Map<String, ByteString> imageBytes = new HashMap<>(); 84 | for (Map.Entry<String, byte[]> entry : elementImages.entrySet()) { 85 | imageBytes.put(entry.getKey(), ByteString.copyFrom(entry.getValue())); 86 | } 87 | 88 | ElementClassificationRequest req = ElementClassificationRequest.newBuilder() 89 | .setLabelHint(label) 90 | .setAllowWeakerMatches(allowWeakerMatches) 91 | .setConfidenceThreshold(confidenceThreshold) 92 | .putAllElementImages(imageBytes) 93 | .build(); 94 | 95 | Map<String, ElementClassificationResult> results = 96 | blockingStub.classifyElements(req).getClassificationsMap(); 97 | Map<String, Classification> classifications = new HashMap<>(); 98 | for (Map.Entry<String, ElementClassificationResult> entry : results.entrySet()) { 99 | ElementClassificationResult result = entry.getValue(); 100 | classifications.put(entry.getKey(), 101 | new Classification(result.getLabel(), result.getConfidence(), result.getConfidenceForHint())); 102 | } 103 | return classifications; 104 | } 105 | 106 | /** Classifies with {@link #DEFAULT_THRESHOLD} and without weaker matches. */ 107 | public Map<String, Classification> classifyElements(String label, Map<String, byte[]> elementImages) { 108 | return classifyElements(label, elementImages, DEFAULT_THRESHOLD, false); 109 | } 110 | 111 | /** 112 | * Finds elements on the driver's current page whose screenshots match a label. 113 | * 114 | * <p>Every leaf element under {@code <body>} (excluding script and style nodes) is 115 | * screenshotted and passed to {@link #classifyElements(String, Map, double, boolean)}; matches 116 | * are returned in the order the driver found them. 117 | * 118 | * @throws NoSuchElementException if no leaf element could be screenshotted 119 | */ 120 | public List<WebElement> findElementsMatchingLabel(RemoteWebDriver driver, String label, 121 | double confidenceThreshold, boolean allowWeakerMatches) throws Exception { 122 | List<WebElement> els = driver.findElements(By.xpath(LEAF_ELEMENTS_XPATH)); 123 | // Insertion-ordered so results keep the driver's order instead of hash order. 124 | Map<String, WebElement> elements = new LinkedHashMap<>(); 125 | Map<String, byte[]> elementImages = new HashMap<>(); 126 | for (WebElement el : els) { 127 | String elId = ((RemoteWebElement) el).getId(); 128 | elements.put(elId, el); 129 | try { 130 | elementImages.put(elId, el.getScreenshotAs(OutputType.BYTES)); 131 | } catch (Exception ignored) { 132 | // Best effort: skip elements whose screenshot fails instead of aborting the whole search. 133 | } 134 | } 135 | if (elementImages.isEmpty()) { 136 | throw new NoSuchElementException("Didn't find any leaf node elements with valid screenshots"); 137 | } 138 | Map<String, Classification> classifications = classifyElements(label, elementImages, 139 | confidenceThreshold, allowWeakerMatches); 140 | List<WebElement> matchedEls = new ArrayList<>(); 141 | for (Map.Entry<String, WebElement> entry : elements.entrySet()) { 142 | if (classifications.containsKey(entry.getKey())) { 143 | matchedEls.add(entry.getValue()); 144 | } 145 | } 146 | return matchedEls; 147 | } 148 | 149 | /** Finds matches with {@link #DEFAULT_THRESHOLD} and without weaker matches. */ 150 | public List<WebElement> findElementsMatchingLabel(RemoteWebDriver driver, String label) throws Exception { 151 | return findElementsMatchingLabel(driver, label, DEFAULT_THRESHOLD, false); 152 | } 153 | } 154 | 
-------------------------------------------------------------------------------- /src/test/java/ClientSeleniumTest.java: -------------------------------------------------------------------------------- 1 | import ai.test.classifier_client.ClassifierClient; 2 | import java.net.MalformedURLException; 3 | import java.net.URL; 4 | import java.util.List; 5 | import org.hamcrest.collection.IsCollectionWithSize; 6 | import org.junit.After; 7 | import org.junit.Assert; 8 | import org.junit.Before; 9 | import org.junit.Test; 10 | import org.openqa.selenium.WebElement; 11 | import org.openqa.selenium.remote.DesiredCapabilities; 12 | import org.openqa.selenium.remote.RemoteWebDriver; 13 | 14 | /** 15 | * End-to-end test of {@link ClassifierClient#findElementsMatchingLabel}; expects a Selenium server at 16 | * localhost:4444 and a classifier server at 127.0.0.1:50051. 17 | */ 18 | public class ClientSeleniumTest { 19 | 20 | private RemoteWebDriver driver; 21 | private ClassifierClient classifier; 22 | 23 | @Before 24 | public void setUp() throws MalformedURLException { 25 | DesiredCapabilities caps = DesiredCapabilities.chrome(); 26 | driver = new RemoteWebDriver(new URL("http://localhost:4444/wd/hub"), caps); 27 | classifier = new ClassifierClient("127.0.0.1", 50051); 28 | } 29 | 30 | @After 31 | public void tearDown() throws InterruptedException { 32 | // Either field may still be null if setUp failed part-way; always try to release both. 33 | try { 34 | if (driver != null) { 35 | driver.quit(); 36 | } 37 | } finally { 38 | if (classifier != null) { 39 | classifier.shutdown(); 40 | } 41 | } 42 | } 43 | 44 | @Test 45 | public void testClassifierClient() throws Exception { 46 | driver.get("https://test.ai"); 47 | List<WebElement> els = classifier.findElementsMatchingLabel(driver, "twitter"); 48 | Assert.assertThat(els, IsCollectionWithSize.hasSize(1)); 49 | els.get(0).click(); 50 | // assertEquals takes (expected, actual). 51 | Assert.assertEquals("https://twitter.com/testdotai", driver.getCurrentUrl()); 52 | } 53 | }
-------------------------------------------------------------------------------- /src/test/java/ClientTest.java: -------------------------------------------------------------------------------- 1 | import ai.test.classifier_client.Classification; 2 | import ai.test.classifier_client.ClassifierClient; 3 | import com.google.common.collect.ImmutableMap; 4 | import java.io.IOException; 5 | import java.nio.file.Files; 6 | import java.nio.file.Path; 7 | import java.nio.file.Paths; 8 | import java.util.Map; 9 | import org.hamcrest.Matchers; 10 | import org.junit.AfterClass; 11 | import org.junit.Assert; 12 | import org.junit.BeforeClass; 13 | import org.junit.Test; 14 | 15 | /** Integration test of {@link ClassifierClient#classifyElements}; expects a classifier server at 127.0.0.1:50051. */ 16 | public class ClientTest { 17 | 18 | private static ClassifierClient client; 19 | 20 | private static final Path CART = Paths.get("src/test/resources/cart.png"); 21 | private static final Path MENU = Paths.get("src/test/resources/menu.png"); 22 | 23 | @BeforeClass 24 | public static void setUp() { 25 | client = new ClassifierClient("127.0.0.1", 50051); 26 | } 27 | 28 | @AfterClass 29 | public static void tearDown() throws InterruptedException { 30 | if (client != null) { 31 | client.shutdown(); 32 | } 33 | } 34 | 35 | @Test 36 | public void testClassification() throws IOException { 37 | byte[] cartBytes = Files.readAllBytes(CART); 38 | byte[] menuBytes = Files.readAllBytes(MENU); 39 | Map<String, byte[]> elementImages = ImmutableMap.of("cart", cartBytes, "menu", menuBytes); 40 | 41 | // Default threshold, no weaker matches: only the cart image should come back. 42 | Map<String, Classification> classifications = client.classifyElements("cart", elementImages); 43 | Assert.assertThat(classifications, Matchers.hasKey("cart")); 44 | Assert.assertEquals("cart", classifications.get("cart").getLabel()); 45 | Assert.assertThat(classifications, Matchers.not(Matchers.hasKey("menu"))); 46 | 47 | // Zero threshold with weaker matches allowed: both images come back, each with its own label. 48 | classifications = client.classifyElements("cart", elementImages, 0.0, true); 49 | Assert.assertThat(classifications, Matchers.hasKey("cart")); 50 | Assert.assertThat(classifications, Matchers.hasKey("menu")); 51 | Classification cartCls = classifications.get("cart"); 52 | Classification menuCls = classifications.get("menu"); 53 | Assert.assertEquals("cart", cartCls.getLabel()); 54 | Assert.assertEquals("menu", menuCls.getLabel()); 55 | Assert.assertThat(cartCls.getConfidenceForLabel(), Matchers.greaterThan(menuCls.getConfidenceForLabel())); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/test/resources/cart.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/testdotai/classifier-client-java/14aae5a661843be08a1ec33cbea2f33d4fd4b999/src/test/resources/cart.png -------------------------------------------------------------------------------- /src/test/resources/menu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/testdotai/classifier-client-java/14aae5a661843be08a1ec33cbea2f33d4fd4b999/src/test/resources/menu.png --------------------------------------------------------------------------------