├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── build.gradle ├── build.gradle.release ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── pom.xml ├── settings.gradle └── src ├── main └── java │ └── com │ └── medallia │ └── word2vec │ ├── NormalizedWord2VecModel.java │ ├── Searcher.java │ ├── SearcherImpl.java │ ├── Word2VecExamples.java │ ├── Word2VecModel.java │ ├── Word2VecTrainer.java │ ├── Word2VecTrainerBuilder.java │ ├── huffman │ └── HuffmanCoding.java │ ├── neuralnetwork │ ├── CBOWModelTrainer.java │ ├── NeuralNetworkConfig.java │ ├── NeuralNetworkTrainer.java │ ├── NeuralNetworkType.java │ └── SkipGramModelTrainer.java │ ├── thrift │ └── Word2VecModelThrift.java │ └── util │ ├── AC.java │ ├── AutoLog.java │ ├── CallableVoid.java │ ├── Common.java │ ├── Compare.java │ ├── FileUtils.java │ ├── Format.java │ ├── IO.java │ ├── NDC.java │ ├── Pair.java │ ├── ProfilingTimer.java │ ├── Strings.java │ ├── ThriftUtils.java │ └── UnicodeReader.java └── test ├── java └── com │ └── medallia │ └── word2vec │ ├── Word2VecBinTest.java │ └── Word2VecTest.java └── resources └── com └── medallia └── word2vec ├── cbowBasic.model ├── cbowIterations.model ├── skipGramBasic.model ├── skipGramIterations.model ├── tokensModel.bin ├── tokensModel.txt ├── word2vec.c.output.model.txt └── word2vec.short.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | .gradle 3 | /build 4 | /out 5 | /intTestHomeDir 6 | /subprojects/*/out 7 | /intellij 8 | /buildSrc/lib 9 | /buildSrc/build 10 | /subprojects/*/build 11 | /subprojects/docs/src/samples/*/*/build 12 | /website/build 13 | /website/website.iml 14 | /website/website.ipr 15 | /website/website.iws 16 | /performanceTest/build 17 | /subprojects/*/ide 18 | /*.iml 19 | /*.ipr 20 | /*.iws 21 | /subprojects/*/*.iml 22 | /buildSrc/*.ipr 23 | /buildSrc/*.iws 24 | /buildSrc/*.iml 25 | /buildSrc/out 26 | *.classpath 27 | 
*.project 28 | *.settings 29 | /bin 30 | /subprojects/*/bin 31 | .DS_Store 32 | /performanceTest/lib 33 | .textmate 34 | /incoming-distributions 35 | .idea 36 | *.sublime-* 37 | .nb-gradle 38 | /target/ 39 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java 2 | 3 | jdk: 4 | - oraclejdk8 5 | - oraclejdk7 6 | - openjdk7 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Medallia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Word2vecJava 2 | 3 | [![Build Status](https://travis-ci.org/medallia/Word2VecJava.svg?branch=master)](https://travis-ci.org/medallia/Word2VecJava) 4 | 5 | This is a port of the open source C implementation of word2vec (https://code.google.com/p/word2vec/). You can browse/contribute the repository via [Github](https://github.com/medallia/Word2VecJava). Alternatively you can pull it from the central Maven repositories: 6 | ```XML 7 | <dependency> 8 | <groupId>com.medallia.word2vec</groupId> 9 | <artifactId>Word2VecJava</artifactId> 10 | <version>0.10.3</version> 11 | </dependency> 12 | ``` 13 | 14 | For more background information about word2vec and neural network training for the vector representation of words, please see the following papers. 15 | * http://ttic.uchicago.edu/~haotang/speech/1301.3781.pdf 16 | * http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf 17 | 18 | For comprehensive explanation of the training process (the gradient descent formula calculation in the back propagation training), please see: 19 | * http://www-personal.umich.edu/~ronxin/pdf/w2vexp.pdf 20 | 21 | Note that this isn't a completely faithful rewrite, specifically: 22 | 23 | ### When building the vocabulary from the training file: 24 | 1. The original version does a reduction step when learning the vocabulary from the file when the vocab size hits 21 million words, removing any words that do not meet the minimum frequency threshold. This Java port has no such reduction step. 25 | 2. The original version injects a `</s>` token into the vocabulary (with a word count of 0) as a substitute for newlines in the input file. This Java port's vocabulary excludes the `</s>` token. 26 | 3. The original version does a quicksort which is not stable, so vocabulary terms with the same frequency may be ordered non-deterministically. 
The Java port does an explicit sort first by frequency, then by the token's lexicographical ordering. 27 | 28 | ### In partitioning the file for processing 29 | 1. The original version assumes that sentences are delimited by newline characters and injects a sentence boundary per 1000 non-filtered tokens, i.e. valid token by the vocabulary and not removed by the randomized sampling process. Java port mimics this behavior for now ... 30 | 2. When the original version encounters an empty line in the input file, it re-processes the first word of the last non-empty line with a sentence length of 0 and updates the random value. Java port omits this behavior. 31 | 32 | ### In the sampling function 33 | 1. The original C documentation indicates that the range should be between 0 and 1e-5, but the default value is 1e-3. This Java port retains that confusing information. 34 | 2. The random value generated for comparison to determine if a token should be filtered uses a float. This Java port uses double precision for twice the fun. 35 | 36 | ### In the distance function to find the nearest matches to a target query 37 | 1. The original version includes an unnecessary normalization of the vector for the input query which may lead to tiny inaccuracies. This Java port foregoes this superfluous operation. 38 | 2. The original version has an O(n * k) algorithm for finding top matches and is hardcoded to 40 matches. This Java port uses Google's lovely com.google.common.collect.Ordering.greatestOf(java.util.Iterator, int) which is O(n + k log k) and takes in arbitrary k. 39 | 40 | Note: The k-means clustering option is excluded in the Java port 41 | 42 | Please do not hesitate to peek at the source code. It should be readable, concise, and correct. Please feel free to reach out if it is not. 
43 | 44 | ## Building the Project 45 | To verify that the project is building correctly, run 46 | ```bash 47 | ./gradlew build && ./gradlew test 48 | ``` 49 | 50 | It should run 7 tests without any error. 51 | 52 | Note: this project requires gradle 2.2+, if you are using older version of gradle, please upgrade it and run: 53 | ```bash 54 | ./gradlew clean test 55 | ``` 56 | 57 | to have a clean build and re-run the tests. 58 | 59 | 60 | ## Contact 61 | Andrew Ko (wko27code@gmail.com) 62 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'java' 2 | apply plugin: 'maven' 3 | 4 | sourceCompatibility = 1.7 5 | targetCompatibility = 1.7 6 | group = 'com.medallia.word2vec' 7 | version = '0.10.3' 8 | 9 | repositories { 10 | mavenCentral() 11 | } 12 | 13 | dependencies { 14 | compile 'org.apache.thrift:libthrift:0.9.1' 15 | compile group: 'org.apache.commons', name: 'commons-lang3', version:'3.1' 16 | compile group: 'com.google.guava', name: 'guava', version: '18.0' 17 | compile group: 'joda-time', name: 'joda-time', version: '2.3' 18 | compile group: 'log4j', name: 'log4j', version: '1.2.17' 19 | compile group: 'commons-io', name: 'commons-io', version: '2.4' 20 | testCompile group: 'junit', name: 'junit', version: '4.11' 21 | } 22 | 23 | if (JavaVersion.current().isJava8Compatible()) { 24 | tasks.withType(Javadoc) { 25 | // disable the crazy super-strict doclint tool in Java 8 26 | //noinspection SpellCheckingInspection 27 | options.addStringOption('Xdoclint:none', '-quiet') 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /build.gradle.release: -------------------------------------------------------------------------------- 1 | apply plugin: 'signing' 2 | 3 | task javadocJar(type: Jar) { 4 | classifier = 'javadoc' 5 | from javadoc 6 | } 7 | 8 | task sourcesJar(type: Jar) { 9 
| classifier = 'sources' 10 | from sourceSets.main.allSource 11 | } 12 | 13 | artifacts { 14 | archives javadocJar, sourcesJar 15 | } 16 | 17 | signing { 18 | sign configurations.archives 19 | } 20 | 21 | uploadArchives { 22 | repositories { 23 | mavenDeployer { 24 | beforeDeployment { MavenDeployment deployment -> signing.signPom(deployment) } 25 | 26 | repository(url: "https://oss.sonatype.org/service/local/staging/deploy/maven2/") { 27 | authentication(userName: "*****", password: "*****") 28 | } 29 | 30 | snapshotRepository(url: "https://oss.sonatype.org/content/repositories/snapshots/") { 31 | authentication(userName: "*****", password: "*****") 32 | } 33 | 34 | pom.project { 35 | name 'Word2VecJava' 36 | packaging 'jar' 37 | description 'Word2Vec Java Port' 38 | url 'https://github.com/medallia/Word2VecJava' 39 | 40 | scm { 41 | connection 'scm:git:git@github.com:medallia/Word2VecJava.git' 42 | developerConnection 'scm:git:git@github.com:medallia/Word2VecJava.git' 43 | url 'https://github.com/medallia/Word2VecJava.git' 44 | } 45 | 46 | licenses { 47 | license { 48 | name 'The MIT License' 49 | url 'http://opensource.org/licenses/MIT' 50 | } 51 | } 52 | 53 | developers { 54 | developer { 55 | id 'wko' 56 | name 'Andrew Ko' 57 | email 'wko@medallia.com' 58 | } 59 | developer { 60 | id 'yibinlin' 61 | name 'Yibin Lin' 62 | email 'yibin@medallia.com' 63 | } 64 | developer { 65 | id 'yfpeng' 66 | name 'Yifan Peng' 67 | email 'yfpeng@udel.edu' 68 | } 69 | developer { 70 | id 'guerda' 71 | name 'Philip Gillißen' 72 | } 73 | } 74 | } 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medallia/Word2VecJava/eb31fbb99ac6bbab82d7f807b3e2240edca50eb7/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- 
/gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | #Mon Feb 02 17:12:48 PST 2015 2 | distributionBase=GRADLE_USER_HOME 3 | distributionPath=wrapper/dists 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | distributionUrl=https\://services.gradle.org/distributions/gradle-2.1-bin.zip 7 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ############################################################################## 4 | ## 5 | ## Gradle start up script for UN*X 6 | ## 7 | ############################################################################## 8 | 9 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 10 | DEFAULT_JVM_OPTS="" 11 | 12 | APP_NAME="Gradle" 13 | APP_BASE_NAME=`basename "$0"` 14 | 15 | # Use the maximum available, or set MAX_FD != -1 to use that value. 16 | MAX_FD="maximum" 17 | 18 | warn ( ) { 19 | echo "$*" 20 | } 21 | 22 | die ( ) { 23 | echo 24 | echo "$*" 25 | echo 26 | exit 1 27 | } 28 | 29 | # OS specific support (must be 'true' or 'false'). 30 | cygwin=false 31 | msys=false 32 | darwin=false 33 | case "`uname`" in 34 | CYGWIN* ) 35 | cygwin=true 36 | ;; 37 | Darwin* ) 38 | darwin=true 39 | ;; 40 | MINGW* ) 41 | msys=true 42 | ;; 43 | esac 44 | 45 | # For Cygwin, ensure paths are in UNIX format before anything is touched. 46 | if $cygwin ; then 47 | [ -n "$JAVA_HOME" ] && JAVA_HOME=`cygpath --unix "$JAVA_HOME"` 48 | fi 49 | 50 | # Attempt to set APP_HOME 51 | # Resolve links: $0 may be a link 52 | PRG="$0" 53 | # Need this for relative symlinks. 
54 | while [ -h "$PRG" ] ; do 55 | ls=`ls -ld "$PRG"` 56 | link=`expr "$ls" : '.*-> \(.*\)$'` 57 | if expr "$link" : '/.*' > /dev/null; then 58 | PRG="$link" 59 | else 60 | PRG=`dirname "$PRG"`"/$link" 61 | fi 62 | done 63 | SAVED="`pwd`" 64 | cd "`dirname \"$PRG\"`/" >&- 65 | APP_HOME="`pwd -P`" 66 | cd "$SAVED" >&- 67 | 68 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 69 | 70 | # Determine the Java command to use to start the JVM. 71 | if [ -n "$JAVA_HOME" ] ; then 72 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 73 | # IBM's JDK on AIX uses strange locations for the executables 74 | JAVACMD="$JAVA_HOME/jre/sh/java" 75 | else 76 | JAVACMD="$JAVA_HOME/bin/java" 77 | fi 78 | if [ ! -x "$JAVACMD" ] ; then 79 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 80 | 81 | Please set the JAVA_HOME variable in your environment to match the 82 | location of your Java installation." 83 | fi 84 | else 85 | JAVACMD="java" 86 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 87 | 88 | Please set the JAVA_HOME variable in your environment to match the 89 | location of your Java installation." 90 | fi 91 | 92 | # Increase the maximum file descriptors if we can. 93 | if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then 94 | MAX_FD_LIMIT=`ulimit -H -n` 95 | if [ $? -eq 0 ] ; then 96 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 97 | MAX_FD="$MAX_FD_LIMIT" 98 | fi 99 | ulimit -n $MAX_FD 100 | if [ $? 
-ne 0 ] ; then 101 | warn "Could not set maximum file descriptor limit: $MAX_FD" 102 | fi 103 | else 104 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 105 | fi 106 | fi 107 | 108 | # For Darwin, add options to specify how the application appears in the dock 109 | if $darwin; then 110 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 111 | fi 112 | 113 | # For Cygwin, switch paths to Windows format before running java 114 | if $cygwin ; then 115 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 116 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 117 | 118 | # We build the pattern for arguments to be converted via cygpath 119 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 120 | SEP="" 121 | for dir in $ROOTDIRSRAW ; do 122 | ROOTDIRS="$ROOTDIRS$SEP$dir" 123 | SEP="|" 124 | done 125 | OURCYGPATTERN="(^($ROOTDIRS))" 126 | # Add a user-defined pattern to the cygpath arguments 127 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 128 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 129 | fi 130 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 131 | i=0 132 | for arg in "$@" ; do 133 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 134 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 135 | 136 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 137 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 138 | else 139 | eval `echo args$i`="\"$arg\"" 140 | fi 141 | i=$((i+1)) 142 | done 143 | case $i in 144 | (0) set -- ;; 145 | (1) set -- "$args0" ;; 146 | (2) set -- "$args0" "$args1" ;; 147 | (3) set -- "$args0" "$args1" "$args2" ;; 148 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;; 149 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 150 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 151 | (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 152 | (8) 
set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 153 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 154 | esac 155 | fi 156 | 157 | # Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules 158 | function splitJvmOpts() { 159 | JVM_OPTS=("$@") 160 | } 161 | eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS 162 | JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" 163 | 164 | exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" 165 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @if "%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 12 | set DEFAULT_JVM_OPTS= 13 | 14 | set DIRNAME=%~dp0 15 | if "%DIRNAME%" == "" set DIRNAME=. 16 | set APP_BASE_NAME=%~n0 17 | set APP_HOME=%DIRNAME% 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 
31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windowz variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | if "%@eval[2+2]" == "4" goto 4NT_args 53 | 54 | :win9xME_args 55 | @rem Slurp the command line arguments. 56 | set CMD_LINE_ARGS= 57 | set _SKIP=2 58 | 59 | :win9xME_args_slurp 60 | if "x%~1" == "x" goto execute 61 | 62 | set CMD_LINE_ARGS=%* 63 | goto execute 64 | 65 | :4NT_args 66 | @rem Get arguments from the 4NT Shell from JP Software 67 | set CMD_LINE_ARGS=%$ 68 | 69 | :execute 70 | @rem Setup the command line 71 | 72 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 73 | 74 | @rem Execute Gradle 75 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 76 | 77 | :end 78 | @rem End local scope for the variables with windows NT shell 79 | if "%ERRORLEVEL%"=="0" goto mainEnd 80 | 81 | :fail 82 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 83 | rem the _cmd.exe /c_ return code! 
84 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 85 | exit /b 1 86 | 87 | :mainEnd 88 | if "%OS%"=="Windows_NT" endlocal 89 | 90 | :omega 91 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | com.medallia.word2vec 5 | medallia-word2vec 6 | 0.10.2 7 | 8 | 9 | MIT License 10 | http://opensource.org/licenses/MIT 11 | repo 12 | 13 | 14 | 15 | UTF-8 16 | UTF-8 17 | 18 | 19 | ${project.artifactId}-${project.version} 20 | 21 | 22 | maven-compiler-plugin 23 | 3.1 24 | 25 | 1.8 26 | 1.8 27 | 28 | 29 | 30 | org.apache.maven.plugins 31 | maven-eclipse-plugin 32 | 2.9 33 | 34 | true 35 | true 36 | 37 | 38 | 39 | 40 | 41 | 42 | oss-sonatype 43 | oss-sonatype 44 | https://oss.sonatype.org/content/repositories/snapshots/ 45 | 46 | true 47 | 48 | 49 | 50 | 51 | 52 | org.apache.commons 53 | commons-lang3 54 | 3.1 55 | 56 | 57 | com.google.guava 58 | guava 59 | 18.0 60 | 61 | 62 | commons-io 63 | commons-io 64 | 2.4 65 | 66 | 67 | junit 68 | junit 69 | 4.11 70 | 71 | 72 | log4j 73 | log4j 74 | 1.2.17 75 | 76 | 77 | joda-time 78 | joda-time 79 | 2.3 80 | 81 | 82 | org.apache.thrift 83 | libfb303 84 | 0.9.1 85 | 86 | 87 | org.apache.commons 88 | commons-math3 89 | 3.4.1 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'Word2VecJava' 2 | 3 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/NormalizedWord2VecModel.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec; 2 | 3 | import com.medallia.word2vec.thrift.Word2VecModelThrift; 4 | 5 | import java.io.File; 6 | import java.io.IOException; 7 | import java.nio.ByteBuffer; 8 | import 
java.nio.DoubleBuffer; 9 | 10 | /** 11 | * Represents a word2vec model where all the vectors are normalized to unit length. 12 | */ 13 | public class NormalizedWord2VecModel extends Word2VecModel { 14 | private NormalizedWord2VecModel(Iterable vocab, int layerSize, final DoubleBuffer vectors) { 15 | super(vocab, layerSize, vectors); 16 | normalize(); 17 | } 18 | 19 | private NormalizedWord2VecModel(Iterable vocab, int layerSize, double[] vectors) { 20 | super(vocab, layerSize, vectors); 21 | normalize(); 22 | } 23 | 24 | public static NormalizedWord2VecModel fromWord2VecModel(Word2VecModel model) { 25 | return new NormalizedWord2VecModel(model.vocab, model.layerSize, model.vectors.duplicate()); 26 | } 27 | 28 | /** @return {@link NormalizedWord2VecModel} created from a thrift representation */ 29 | public static NormalizedWord2VecModel fromThrift(final Word2VecModelThrift thrift) { 30 | return fromWord2VecModel(Word2VecModel.fromThrift(thrift)); 31 | } 32 | 33 | public static NormalizedWord2VecModel fromBinFile(final File file) throws IOException { 34 | return fromWord2VecModel(Word2VecModel.fromBinFile(file)); 35 | } 36 | 37 | /** Normalizes the vectors in this model */ 38 | private void normalize() { 39 | for(int i = 0; i < vocab.size(); ++i) { 40 | double len = 0; 41 | for(int j = i * layerSize; j < (i + 1) * layerSize; ++j) 42 | len += vectors.get(j) * vectors.get(j); 43 | len = Math.sqrt(len); 44 | 45 | for(int j = i * layerSize; j < (i + 1) * layerSize; ++j) 46 | vectors.put(j, vectors.get(j) / len); 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/Searcher.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec; 2 | 3 | import com.google.common.base.Function; 4 | import com.google.common.collect.ImmutableList; 5 | import com.google.common.collect.Ordering; 6 | 7 | import java.util.List; 8 | 9 | /** 
Provides search functionality */ 10 | public interface Searcher { 11 | /** @return true if a word is inside the model's vocabulary. */ 12 | boolean contains(String word); 13 | 14 | /** @return Raw word vector */ 15 | ImmutableList getRawVector(String word) throws UnknownWordException; 16 | 17 | /** @return Top matches to the given word */ 18 | List getMatches(String word, int maxMatches) throws UnknownWordException; 19 | 20 | /** @return Top matches to the given vector */ 21 | List getMatches(final double[] vec, int maxNumMatches); 22 | 23 | /** Represents the similarity between two words */ 24 | public interface SemanticDifference { 25 | /** @return Top matches to the given word which share this semantic relationship */ 26 | List getMatches(String word, int maxMatches) throws UnknownWordException; 27 | } 28 | 29 | /** @return {@link SemanticDifference} between the word vectors for the given */ 30 | SemanticDifference similarity(String s1, String s2) throws UnknownWordException; 31 | 32 | /** @return cosine similarity between two words. 
*/ 33 | double cosineDistance(String s1, String s2) throws UnknownWordException; 34 | 35 | /** Represents a match to a search word */ 36 | public interface Match { 37 | /** @return Matching word */ 38 | String match(); 39 | /** @return Cosine distance of the match */ 40 | double distance(); 41 | /** {@link Ordering} which compares {@link Match#distance()} */ 42 | Ordering ORDERING = Ordering.natural().onResultOf(new Function() { 43 | @Override public Double apply(Match match) { 44 | return match.distance(); 45 | } 46 | }); 47 | /** {@link Function} which forwards to {@link #match()} */ 48 | Function TO_WORD = new Function() { 49 | @Override public String apply(Match result) { 50 | return result.match(); 51 | } 52 | }; 53 | } 54 | 55 | /** Exception when a word is unknown to the {@link Word2VecModel}'s vocabulary */ 56 | public static class UnknownWordException extends Exception { 57 | UnknownWordException(String word) { 58 | super(String.format("Unknown search word '%s'", word)); 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/SearcherImpl.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec; 2 | 3 | import com.google.common.base.Function; 4 | import com.google.common.collect.ImmutableList; 5 | import com.google.common.collect.ImmutableMap; 6 | import com.google.common.collect.Iterables; 7 | import com.google.common.primitives.Doubles; 8 | import com.medallia.word2vec.util.Pair; 9 | 10 | import java.nio.DoubleBuffer; 11 | import java.util.Arrays; 12 | import java.util.List; 13 | 14 | /** Implementation of {@link Searcher} */ 15 | class SearcherImpl implements Searcher { 16 | private final NormalizedWord2VecModel model; 17 | private final ImmutableMap word2vectorOffset; 18 | 19 | SearcherImpl(final NormalizedWord2VecModel model) { 20 | this.model = model; 21 | 22 | final ImmutableMap.Builder result = 
ImmutableMap.builder(); 23 | for (int i = 0; i < model.vocab.size(); i++) { 24 | result.put(model.vocab.get(i), i * model.layerSize); 25 | } 26 | 27 | word2vectorOffset = result.build(); 28 | } 29 | 30 | SearcherImpl(final Word2VecModel model) { 31 | this(NormalizedWord2VecModel.fromWord2VecModel(model)); 32 | } 33 | 34 | @Override public List getMatches(String s, int maxNumMatches) throws UnknownWordException { 35 | return getMatches(getVector(s), maxNumMatches); 36 | } 37 | 38 | @Override public double cosineDistance(String s1, String s2) throws UnknownWordException { 39 | return calculateDistance(getVector(s1), getVector(s2)); 40 | } 41 | 42 | @Override public boolean contains(String word) { 43 | return word2vectorOffset.containsKey(word); 44 | } 45 | 46 | @Override public List getMatches(final double[] vec, int maxNumMatches) { 47 | return Match.ORDERING.greatestOf( 48 | Iterables.transform(model.vocab, new Function() { 49 | @Override 50 | public Match apply(String other) { 51 | double[] otherVec = getVectorOrNull(other); 52 | double d = calculateDistance(otherVec, vec); 53 | return new MatchImpl(other, d); 54 | } 55 | }), 56 | maxNumMatches 57 | ); 58 | } 59 | 60 | private double calculateDistance(double[] otherVec, double[] vec) { 61 | double d = 0; 62 | for (int a = 0; a < model.layerSize; a++) 63 | d += vec[a] * otherVec[a]; 64 | return d; 65 | } 66 | 67 | @Override public ImmutableList getRawVector(String word) throws UnknownWordException { 68 | return ImmutableList.copyOf(Doubles.asList(getVector(word))); 69 | } 70 | 71 | /** 72 | * @return Vector for the given word 73 | * @throws UnknownWordException If word is not in the model's vocabulary 74 | */ 75 | private double[] getVector(String word) throws UnknownWordException { 76 | final double[] result = getVectorOrNull(word); 77 | if(result == null) 78 | throw new UnknownWordException(word); 79 | 80 | return result; 81 | } 82 | 83 | private double[] getVectorOrNull(final String word) { 84 | final Integer 
index = word2vectorOffset.get(word); 85 | if(index == null) 86 | return null; 87 | 88 | final DoubleBuffer vectors = model.vectors.duplicate(); 89 | double[] result = new double[model.layerSize]; 90 | vectors.position(index); 91 | vectors.get(result); 92 | return result; 93 | } 94 | 95 | /** @return Vector difference from v1 to v2 */ 96 | private double[] getDifference(double[] v1, double[] v2) { 97 | double[] diff = new double[model.layerSize]; 98 | for (int i = 0; i < model.layerSize; i++) 99 | diff[i] = v1[i] - v2[i]; 100 | return diff; 101 | } 102 | 103 | @Override public SemanticDifference similarity(String s1, String s2) throws UnknownWordException { 104 | double[] v1 = getVector(s1); 105 | double[] v2 = getVector(s2); 106 | final double[] diff = getDifference(v1, v2); 107 | 108 | return new SemanticDifference() { 109 | @Override public List getMatches(String word, int maxMatches) throws UnknownWordException { 110 | double[] target = getDifference(getVector(word), diff); 111 | return SearcherImpl.this.getMatches(target, maxMatches); 112 | } 113 | }; 114 | } 115 | 116 | /** Implementation of {@link Match} */ 117 | private static class MatchImpl extends Pair implements Match { 118 | private MatchImpl(String first, Double second) { 119 | super(first, second); 120 | } 121 | 122 | @Override public String match() { 123 | return first; 124 | } 125 | 126 | @Override public double distance() { 127 | return second; 128 | } 129 | 130 | @Override public String toString() { 131 | return String.format("%s [%s]", first, second); 132 | } 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/Word2VecExamples.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec; 2 | 3 | import com.google.common.base.Function; 4 | import com.google.common.collect.Lists; 5 | import com.medallia.word2vec.Searcher.Match; 6 | import 
com.medallia.word2vec.Searcher.UnknownWordException;
import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkType;
import com.medallia.word2vec.thrift.Word2VecModelThrift;
import com.medallia.word2vec.util.AutoLog;
import com.medallia.word2vec.util.Common;
import com.medallia.word2vec.util.Format;
import com.medallia.word2vec.util.ProfilingTimer;
import com.medallia.word2vec.util.Strings;
import com.medallia.word2vec.util.ThriftUtils;
import org.apache.commons.logging.Log;
import org.apache.thrift.TException;
import org.apache.commons.io.FileUtils;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

/** Example usages of {@link Word2VecModel} */
public class Word2VecExamples {
    private static final Log LOG = AutoLog.getLog();

    /** Runs the example */
    public static void main(String[] args) throws IOException, TException, UnknownWordException, InterruptedException {
        demoWord();
    }

    /**
     * Trains a model and allows user to find similar words
     * demo-word.sh example from the open source C implementation
     */
    public static void demoWord() throws IOException, TException, InterruptedException, UnknownWordException {
        File f = new File("text8");
        if (!f.exists())
            throw new IllegalStateException("Please download and unzip the text8 example from http://mattmahoney.net/dc/text8.zip");
        List<String> read = Common.readToList(f);
        // Tokenize each line on single spaces; the text8 corpus is space-delimited
        List<List<String>> partitioned = Lists.transform(read, new Function<String, List<String>>() {
            @Override
            public List<String> apply(String input) {
                return Arrays.asList(input.split(" "));
            }
        });

        Word2VecModel model = Word2VecModel.trainer()
                .setMinVocabFrequency(5)
                .useNumThreads(20)
                .setWindowSize(8)
                .type(NeuralNetworkType.CBOW)
                .setLayerSize(200)
                .useNegativeSamples(25)
                .setDownSamplingRate(1e-4)
                .setNumIterations(5)
                .setListener(new TrainingProgressListener() {
                    @Override public void update(Stage stage, double progress) {
                        System.out.println(String.format("%s is %.2f%% complete", Format.formatEnum(stage), progress * 100));
                    }
                })
                .train(partitioned);

        // Writes model to a thrift file
        try (ProfilingTimer timer = ProfilingTimer.create(LOG, "Writing output to file")) {
            FileUtils.writeStringToFile(new File("text8.model"), ThriftUtils.serializeJson(model.toThrift()));
        }

        // Alternatively, you can write the model to a bin file that's compatible with the C
        // implementation.
        try (final OutputStream os = Files.newOutputStream(Paths.get("text8.bin"))) {
            model.toBinFile(os);
        }

        interact(model.forSearch());
    }

    /** Loads a model and allows user to find similar words */
    public static void loadModel() throws IOException, TException, UnknownWordException {
        final Word2VecModel model;
        try (ProfilingTimer timer = ProfilingTimer.create(LOG, "Loading model")) {
            String json = Common.readFileToString(new File("text8.model"));
            model = Word2VecModel.fromThrift(ThriftUtils.deserializeJson(new Word2VecModelThrift(), json));
        }
        interact(model.forSearch());
    }

    /** Example using Skip-Gram model */
    public static void skipGram() throws IOException, TException, InterruptedException, UnknownWordException {
        List<String> read = Common.readToList(new File("sents.cleaned.word2vec.txt"));
        List<List<String>> partitioned = Lists.transform(read, new Function<String, List<String>>() {
            @Override
            public List<String> apply(String input) {
                return Arrays.asList(input.split(" "));
            }
        });

        Word2VecModel model = Word2VecModel.trainer()
                .setMinVocabFrequency(100)
                .useNumThreads(20)
                .setWindowSize(7)
                .type(NeuralNetworkType.SKIP_GRAM)
                .useHierarchicalSoftmax()
                .setLayerSize(300)
                .useNegativeSamples(0)
                .setDownSamplingRate(1e-3)
                .setNumIterations(5)
                .setListener(new TrainingProgressListener() {
                    @Override public void update(Stage stage, double progress) {
                        System.out.println(String.format("%s is %.2f%% complete", Format.formatEnum(stage), progress * 100));
                    }
                })
                .train(partitioned);

        try (ProfilingTimer timer = ProfilingTimer.create(LOG, "Writing output to file")) {
            FileUtils.writeStringToFile(new File("300layer.20threads.5iter.model"), ThriftUtils.serializeJson(model.toThrift()));
        }

        interact(model.forSearch());
    }

    /** Interactive loop: reads words from stdin and prints the nearest matches for each */
    private static void interact(Searcher searcher) throws IOException, UnknownWordException {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in))) {
            while (true) {
                System.out.print("Enter word or sentence (EXIT to break): ");
                String word = br.readLine();
                // readLine() returns null on end-of-stream; treat that like EXIT
                // instead of throwing a NullPointerException
                if (word == null || word.equals("EXIT")) {
                    break;
                }
                List<Match> matches = searcher.getMatches(word, 20);
                System.out.println(Strings.joinObjects("\n", matches));
            }
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/Word2VecModel.java:
--------------------------------------------------------------------------------
package com.medallia.word2vec;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.primitives.Doubles;
import com.medallia.word2vec.thrift.Word2VecModelThrift;
import com.medallia.word2vec.util.Common;
import com.medallia.word2vec.util.ProfilingTimer;
import com.medallia.word2vec.util.AC;


/**
 * Represents the Word2Vec model, containing vectors for each word
 * <p>
 * Instances of this class are obtained via:
 * <ul>
 * <li>{@link #trainer()}</li>
 * <li>{@link #fromThrift(Word2VecModelThrift)}</li>
 * </ul>
 *
 * @see {@link #forSearch()}
 */
public class Word2VecModel {
    final List<String> vocab;
    final int layerSize;
    // Flat row-major storage: vector for vocab.get(i) occupies [i * layerSize, (i + 1) * layerSize)
    final DoubleBuffer vectors;
    private final static long ONE_GB = 1024 * 1024 * 1024;

    Word2VecModel(Iterable<String> vocab, int layerSize, DoubleBuffer vectors) {
        this.vocab = ImmutableList.copyOf(vocab);
        this.layerSize = layerSize;
        this.vectors = vectors;
    }

    Word2VecModel(Iterable<String> vocab, int layerSize, double[] vectors) {
        this(vocab, layerSize, DoubleBuffer.wrap(vectors));
    }

    /** @return Vocabulary */
    public Iterable<String> getVocab() {
        return vocab;
    }

    /** @return Layer size */
    public int getLayerSize() {
        return layerSize;
    }

    /** @return {@link Searcher} for searching */
    public Searcher forSearch() {
        return new SearcherImpl(this);
    }

    /** @return Serializable thrift representation */
    public Word2VecModelThrift toThrift() {
        double[] vectorsArray;
        if (vectors.hasArray()) {
            // Heap-backed buffer (wrapped array): reuse the backing array directly
            vectorsArray = vectors.array();
        } else {
            // Direct buffer: copy its full contents out
            vectorsArray = new double[vectors.limit()];
            vectors.position(0);
            vectors.get(vectorsArray);
        }

        return new Word2VecModelThrift()
                .setVocab(vocab)
                .setLayerSize(layerSize)
                .setVectors(Doubles.asList(vectorsArray));
    }

    /** @return {@link Word2VecModel} created from a thrift representation */
    public static Word2VecModel fromThrift(Word2VecModelThrift thrift) {
        return new Word2VecModel(
                thrift.getVocab(),
                thrift.getLayerSize(),
                Doubles.toArray(thrift.getVectors()));
    }

    /**
     * @return {@link Word2VecModel} read from a file in the text output format of the Word2Vec C
     * open source project.
     */
    public static Word2VecModel fromTextFile(File file) throws IOException {
        List<String> lines = Common.readToList(file);
        return fromTextFile(file.getAbsolutePath(), lines);
    }

    /**
     * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with the default
     * ByteOrder.LITTLE_ENDIAN and no ProfilingTimer
     */
    public static Word2VecModel fromBinFile(File file)
            throws IOException {
        return fromBinFile(file, ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE);
    }

    /**
     * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with no ProfilingTimer
     */
    public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder)
            throws IOException {
        return fromBinFile(file, byteOrder, ProfilingTimer.NONE);
    }

    /**
     * @return {@link Word2VecModel} created from the binary representation output
     * by the open source C version of word2vec using the given byte order.
     */
    public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer)
            throws IOException {

        try (
                final FileInputStream fis = new FileInputStream(file);
                final AC ac = timer.start("Loading vectors from bin file")
        ) {
            final FileChannel channel = fis.getChannel();
            timer.start("Reading gigabyte #1");
            MappedByteBuffer buffer =
                    channel.map(
                            FileChannel.MapMode.READ_ONLY,
                            0,
                            Math.min(channel.size(), Integer.MAX_VALUE));
            buffer.order(byteOrder);
            // Java's NIO only allows memory-mapping up to 2GB. To work around this problem, we re-map
            // every gigabyte. To calculate offsets correctly, we have to keep track how many gigabytes
            // we've already skipped. That's what this is for.
            int bufferCount = 1;

            // Header line: "<vocabSize> <layerSize>\n"
            StringBuilder sb = new StringBuilder();
            char c = (char) buffer.get();
            while (c != '\n') {
                sb.append(c);
                c = (char) buffer.get();
            }
            String firstLine = sb.toString();
            int index = firstLine.indexOf(' ');
            Preconditions.checkState(index != -1,
                    "Expected a space in the first line of file '%s': '%s'",
                    file.getAbsolutePath(), firstLine);

            final int vocabSize = Integer.parseInt(firstLine.substring(0, index));
            final int layerSize = Integer.parseInt(firstLine.substring(index + 1));
            timer.appendToLog(String.format(
                    "Loading %d vectors with dimensionality %d",
                    vocabSize,
                    layerSize));

            List<String> vocabs = new ArrayList<String>(vocabSize);
            // Guard against silent int overflow in the buffer-size computation: for very
            // large models vocabSize * layerSize * 8 can exceed Integer.MAX_VALUE.
            final long vectorBytes = (long) vocabSize * layerSize * 8;
            Preconditions.checkArgument(vectorBytes <= Integer.MAX_VALUE,
                    "Model too large to buffer in memory: %s bytes required", vectorBytes);
            DoubleBuffer vectors = ByteBuffer.allocateDirect((int) vectorBytes).asDoubleBuffer();

            long lastLogMessage = System.currentTimeMillis();
            final float[] floats = new float[layerSize];
            for (int lineno = 0; lineno < vocabSize; lineno++) {
                // read vocab
                sb.setLength(0);
                c = (char) buffer.get();
                while (c != ' ') {
                    // ignore newlines in front of words (some binary files have newline,
                    // some don't)
                    if (c != '\n') {
                        sb.append(c);
                    }
                    c = (char) buffer.get();
                }
                vocabs.add(sb.toString());

                // read vector (stored as 4-byte floats in the C format, widened to double here)
                final FloatBuffer floatBuffer = buffer.asFloatBuffer();
                floatBuffer.get(floats);
                for (int i = 0; i < floats.length; ++i) {
                    vectors.put(lineno * layerSize + i, floats[i]);
                }
                buffer.position(buffer.position() + 4 * layerSize);

                // print log
                final long now = System.currentTimeMillis();
                if (now - lastLogMessage > 1000) {
                    final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0;
                    timer.appendToLog(
                            String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage));
                    lastLogMessage = now;
                }

                // remap file once we've consumed a gigabyte of the current mapping
                if (buffer.position() > ONE_GB) {
                    final int newPosition = (int) (buffer.position() - ONE_GB);
                    final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE);
                    timer.endAndStart(
                            "Reading gigabyte #%d. Start: %d, size: %d",
                            bufferCount,
                            ONE_GB * bufferCount,
                            size);
                    buffer = channel.map(
                            FileChannel.MapMode.READ_ONLY,
                            ONE_GB * bufferCount,
                            size);
                    buffer.order(byteOrder);
                    buffer.position(newPosition);
                    bufferCount += 1;
                }
            }
            timer.end();

            return new Word2VecModel(vocabs, layerSize, vectors);
        }
    }

    /**
     * Saves the model as a bin file that's compatible with the C version of Word2Vec
     */
    public void toBinFile(final OutputStream out) throws IOException {
        final Charset cs = Charset.forName("UTF-8");
        final String header = String.format("%d %d\n", vocab.size(), layerSize);
        out.write(header.getBytes(cs));

        final double[] vector = new double[layerSize];
        final ByteBuffer buffer = ByteBuffer.allocate(4 * layerSize);
        buffer.order(ByteOrder.LITTLE_ENDIAN); // The C version uses this byte order.
        for (int i = 0; i < vocab.size(); ++i) {
            out.write(String.format("%s ", vocab.get(i)).getBytes(cs));

            vectors.position(i * layerSize);
            vectors.get(vector);
            buffer.clear();
            // The C format stores 4-byte floats; narrow each double on the way out
            for (int j = 0; j < layerSize; ++j)
                buffer.putFloat((float) vector[j]);
            out.write(buffer.array());

            out.write('\n');
        }

        out.flush();
    }

    /**
     * @return {@link Word2VecModel} from the lines of the file in the text output format of the
     * Word2Vec C open source project.
     */
    @VisibleForTesting
    static Word2VecModel fromTextFile(String filename, List<String> lines) throws IOException {
        List<String> vocab = Lists.newArrayList();
        List<Double> vectors = Lists.newArrayList();
        // Header line: "<vocabSize> <layerSize>"
        String[] header = lines.get(0).split(" ");
        int vocabSize = Integer.parseInt(header[0]);
        int layerSize = Integer.parseInt(header[1]);

        Preconditions.checkArgument(
                vocabSize == lines.size() - 1,
                "For file '%s', vocab size is %s, but there are %s word vectors in the file",
                filename,
                vocabSize,
                lines.size() - 1
        );

        for (int n = 1; n < lines.size(); n++) {
            String[] values = lines.get(n).split(" ");
            vocab.add(values[0]);

            // Sanity check
            Preconditions.checkArgument(
                    layerSize == values.length - 1,
                    "For file '%s', on line %s, layer size is %s, but found %s values in the word vector",
                    filename,
                    n,
                    layerSize,
                    values.length - 1
            );

            for (int d = 1; d < values.length; d++) {
                vectors.add(Double.parseDouble(values[d]));
            }
        }

        Word2VecModelThrift thrift = new Word2VecModelThrift()
                .setLayerSize(layerSize)
                .setVocab(vocab)
                .setVectors(vectors);
        return fromThrift(thrift);
    }

    /** @return {@link Word2VecTrainerBuilder} for training a model */
    public static Word2VecTrainerBuilder trainer() {
        return new Word2VecTrainerBuilder();
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/Word2VecTrainer.java:
--------------------------------------------------------------------------------
package com.medallia.word2vec;

import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.ImmutableSortedMultiset;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multiset;
import com.google.common.collect.Multisets;
import com.google.common.primitives.Doubles;
import com.medallia.word2vec.util.AC;
import com.medallia.word2vec.util.ProfilingTimer;
import org.apache.commons.logging.Log;
import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener.Stage;
import com.medallia.word2vec.huffman.HuffmanCoding;
import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkConfig;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkTrainer.NeuralNetworkModel;

import java.util.List;
import java.util.Map;

/** Responsible for training a word2vec model */
class Word2VecTrainer {
    /** Minimum frequency a token must have to be included in the vocabulary */
    private final int minFrequency;
    /** Pre-built vocabulary, if supplied by the caller; otherwise learned from the sentences */
    private final Optional<Multiset<String>> vocab;
    private final NeuralNetworkConfig neuralNetworkConfig;

    Word2VecTrainer(
            Integer minFrequency,
            Optional<Multiset<String>> vocab,
            NeuralNetworkConfig neuralNetworkConfig) {
        this.vocab = vocab;
        this.minFrequency = minFrequency;
        this.neuralNetworkConfig = neuralNetworkConfig;
    }

    /** @return {@link Multiset} containing unique tokens and their counts */
    private static Multiset<String> count(Iterable<String> tokens) {
        Multiset<String> counts = HashMultiset.create();
        for (String token : tokens)
            counts.add(token);
        return counts;
    }

    /** @return Tokens with their count, sorted by frequency decreasing, then lexicographically ascending */
    private ImmutableMultiset<String> filterAndSort(final Multiset<String> counts) {
        // This isn't terribly efficient, but it is deterministic
        // Unfortunately, Guava's multiset doesn't give us a clean way to order both by count and element
        return Multisets.copyHighestCountFirst(
                ImmutableSortedMultiset.copyOf(
                        Multisets.filter(
                                counts,
                                new Predicate<String>() {
                                    @Override
                                    public boolean apply(String s) {
                                        return counts.count(s) >= minFrequency;
                                    }
                                }
                        )
                )
        );

    }

    /**
     * Train a model using the given data
     * @param log Destination for profiling output
     * @param listener Callback notified as each training {@link Stage} starts
     * @param sentences Tokenized training corpus, one token list per sentence
     * @throws InterruptedException if the training thread is interrupted
     */
    Word2VecModel train(Log log, TrainingProgressListener listener, Iterable<List<String>> sentences) throws InterruptedException {
        try (ProfilingTimer timer = ProfilingTimer.createLoggingSubtasks(log, "Training word2vec")) {
            final Multiset<String> counts;

            try (AC ac = timer.start("Acquiring word frequencies")) {
                listener.update(Stage.ACQUIRE_VOCAB, 0.0);
                counts = (vocab.isPresent())
                        ? vocab.get()
                        : count(Iterables.concat(sentences));
            }

            final ImmutableMultiset<String> vocab;
            try (AC ac = timer.start("Filtering and sorting vocabulary")) {
                listener.update(Stage.FILTER_SORT_VOCAB, 0.0);
                vocab = filterAndSort(counts);
            }

            final Map<String, HuffmanNode> huffmanNodes;
            try (AC task = timer.start("Create Huffman encoding")) {
                huffmanNodes = new HuffmanCoding(vocab, listener).encode();
            }

            final NeuralNetworkModel model;
            try (AC task = timer.start("Training model %s", neuralNetworkConfig)) {
                model = neuralNetworkConfig.createTrainer(vocab, huffmanNodes, listener).train(sentences);
            }

            return new Word2VecModel(vocab.elementSet(), model.layerSize(), Doubles.concat(model.vectors()));
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/Word2VecTrainerBuilder.java:
--------------------------------------------------------------------------------
package com.medallia.word2vec;

import com.google.common.base.MoreObjects;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.Multiset;
import com.medallia.word2vec.util.AutoLog;
import
org.apache.commons.logging.Log; 9 | import com.medallia.word2vec.neuralnetwork.NeuralNetworkConfig; 10 | import com.medallia.word2vec.neuralnetwork.NeuralNetworkType; 11 | 12 | import java.util.List; 13 | import java.util.Map; 14 | 15 | /** 16 | * Builder pattern for training a new {@link Word2VecModel} 17 | *

18 | * This is a port of the open source C implementation of word2vec 19 | *

20 | * Note that this isn't a completely faithful rewrite, specifically: 21 | *

    22 | *
  • When building the vocabulary from the training file: 23 | *
      24 | *
    • The original version does a reduction step when learning the vocabulary from the file 25 | * when the vocab size hits 21 million words, removing any words that do not meet the 26 | * minimum frequency threshold. This Java port has no such reduction step. 27 | *
    • The original version injects a </s> token into the vocabulary (with a word count of 0) 28 | * as a substitute for newlines in the input file. This Java port's vocabulary excludes the token. 29 | *
    30 | *
  • In partitioning the file for processing 31 | *
      32 | *
    • The original version assumes that sentences are delimited by newline characters and injects a sentence 33 | * boundary per 1000 non-filtered tokens, i.e. valid token by the vocabulary and not removed by the randomized 34 | * sampling process. Java port mimics this behavior for now ... 35 | *
    • When the original version encounters an empty line in the input file, it re-processes the first word of the 36 | * last non-empty line with a sentence length of 0 and updates the random value. Java port omits this behavior. 37 | *
    38 | *
  • In the sampling function 39 | *
      40 | *
    • The original C documentation indicates that the range should be between 0 and 1e-5, but the default value is 1e-3. 41 | * This Java port retains that confusing information. 42 | *
    • The random value generated for comparison to determine if a token should be filtered uses a float. 43 | * This Java port uses double precision for twice the fun 44 | *
    45 | *
  • In the distance function to find the nearest matches to a target query 46 | *
      47 | *
    • The original version includes an unnecessary normalization of the vector for the input query which 48 | * may lead to tiny inaccuracies. This Java port foregoes this superfluous operation. 49 | *
    • The original version has an O(n * k) algorithm for finding top matches and is hardcoded to 40 matches. 50 | * This Java port uses Google's lovely {@link com.google.common.collect.Ordering#greatestOf(java.util.Iterator, int)} 51 | * which is O(n + k log k) and takes in arbitrary k. 52 | *
    53 | *
  • The k-means clustering option is excluded in the Java port 54 | *
55 | * 56 | *

57 | *

58 | * Please do not hesitate to peek at the source code. 59 | *
60 | * It should be readable, concise, and correct. 61 | *
62 | * I ask you to reach out if it is not. 63 | */ 64 | public class Word2VecTrainerBuilder { 65 | private static final Log LOG = AutoLog.getLog(); 66 | 67 | private Integer layerSize; 68 | private Integer windowSize; 69 | private Integer numThreads; 70 | private NeuralNetworkType type; 71 | private int negativeSamples; 72 | private boolean useHierarchicalSoftmax; 73 | private Multiset vocab; 74 | private Integer minFrequency; 75 | private Double initialLearningRate; 76 | private Double downSampleRate; 77 | private Integer iterations; 78 | private TrainingProgressListener listener; 79 | 80 | Word2VecTrainerBuilder() { 81 | } 82 | 83 | /** 84 | * Size of the layers in the neural network model 85 | *

86 | * Defaults to 100 87 | */ 88 | public Word2VecTrainerBuilder setLayerSize(int layerSize) { 89 | Preconditions.checkArgument(layerSize > 0, "Value must be positive"); 90 | this.layerSize = layerSize; 91 | return this; 92 | } 93 | 94 | /** 95 | * Size of the window to consider 96 | *

97 | * Default window size is 5 tokens 98 | */ 99 | public Word2VecTrainerBuilder setWindowSize(int windowSize) { 100 | Preconditions.checkArgument(windowSize > 0, "Value must be positive"); 101 | this.windowSize = windowSize; 102 | return this; 103 | } 104 | 105 | /** 106 | * Specify number of threads to use for parallelization 107 | *

108 | * Defaults to {@link Runtime#availableProcessors()} 109 | */ 110 | public Word2VecTrainerBuilder useNumThreads(int numThreads) { 111 | Preconditions.checkArgument(numThreads > 0, "Value must be positive"); 112 | this.numThreads = numThreads; 113 | return this; 114 | } 115 | 116 | /** 117 | * @see {@link NeuralNetworkType} 118 | *

119 | * By default, word2vec uses the {@link NeuralNetworkType#SKIP_GRAM} 120 | */ 121 | public Word2VecTrainerBuilder type(NeuralNetworkType type) { 122 | this.type = Preconditions.checkNotNull(type); 123 | return this; 124 | } 125 | 126 | /** 127 | * Specify to use hierarchical softmax 128 | *

129 | * By default, word2vec does not use hierarchical softmax 130 | */ 131 | public Word2VecTrainerBuilder useHierarchicalSoftmax() { 132 | this.useHierarchicalSoftmax = true; 133 | return this; 134 | } 135 | 136 | /** 137 | * Number of negative samples to use 138 | * Common values are between 5 and 10 139 | *

140 | * Defaults to 0 141 | */ 142 | public Word2VecTrainerBuilder useNegativeSamples(int negativeSamples) { 143 | Preconditions.checkArgument(negativeSamples >= 0, "Value must be non-negative"); 144 | this.negativeSamples = negativeSamples; 145 | return this; 146 | } 147 | 148 | /** 149 | * Use a pre-built vocabulary 150 | *

151 | * If this is not specified, word2vec will attempt to learn a vocabulary from the training data 152 | * @param vocab {@link Map} from token to frequency 153 | */ 154 | public Word2VecTrainerBuilder useVocab(Multiset vocab) { 155 | this.vocab = Preconditions.checkNotNull(vocab); 156 | return this; 157 | } 158 | 159 | /** 160 | * Specify the minimum frequency for a valid token to be considered 161 | * part of the vocabulary 162 | *

163 | * Defaults to 5 164 | */ 165 | public Word2VecTrainerBuilder setMinVocabFrequency(int minFrequency) { 166 | Preconditions.checkArgument(minFrequency >= 0, "Value must be non-negative"); 167 | this.minFrequency = minFrequency; 168 | return this; 169 | } 170 | 171 | /** 172 | * Set the starting learning rate 173 | *

174 | * Default is 0.025 for skip-gram and 0.05 for CBOW 175 | */ 176 | public Word2VecTrainerBuilder setInitialLearningRate(double initialLearningRate) { 177 | Preconditions.checkArgument(initialLearningRate >= 0, "Value must be non-negative"); 178 | this.initialLearningRate = initialLearningRate; 179 | return this; 180 | } 181 | 182 | /** 183 | * Set threshold for occurrence of words. Those that appear with higher frequency in the training data, 184 | * e.g. stopwords, will be randomly removed 185 | *

186 | * Default is 1e-3, useful range is (0, 1e-5) 187 | */ 188 | public Word2VecTrainerBuilder setDownSamplingRate(double downSampleRate) { 189 | Preconditions.checkArgument(downSampleRate >= 0, "Value must be non-negative"); 190 | this.downSampleRate = downSampleRate; 191 | return this; 192 | } 193 | 194 | /** Set the number of iterations */ 195 | public Word2VecTrainerBuilder setNumIterations(int iterations) { 196 | Preconditions.checkArgument(iterations > 0, "Value must be positive"); 197 | this.iterations = iterations; 198 | return this; 199 | } 200 | 201 | /** Set a progress listener */ 202 | public Word2VecTrainerBuilder setListener(TrainingProgressListener listener) { 203 | this.listener = listener; 204 | return this; 205 | } 206 | 207 | /** Train the model */ 208 | public Word2VecModel train(Iterable> sentences) throws InterruptedException { 209 | this.type = MoreObjects.firstNonNull(type, NeuralNetworkType.CBOW); 210 | this.initialLearningRate = MoreObjects.firstNonNull(initialLearningRate, type.getDefaultInitialLearningRate()); 211 | if (this.numThreads == null) 212 | this.numThreads = Runtime.getRuntime().availableProcessors(); 213 | this.iterations = MoreObjects.firstNonNull(iterations, 5); 214 | this.layerSize = MoreObjects.firstNonNull(layerSize, 100); 215 | this.windowSize = MoreObjects.firstNonNull(windowSize, 5); 216 | this.downSampleRate = MoreObjects.firstNonNull(downSampleRate, 0.001); 217 | this.minFrequency = MoreObjects.firstNonNull(minFrequency, 5); 218 | this.listener = MoreObjects.firstNonNull(listener, new TrainingProgressListener() { 219 | @Override 220 | public void update(Stage stage, double progress) { 221 | System.out.println(String.format("Stage %s, progress %s%%", stage, progress)); 222 | } 223 | }); 224 | 225 | Optional> vocab = this.vocab == null 226 | ? 
Optional.>absent() 227 | : Optional.of(this.vocab); 228 | 229 | return new Word2VecTrainer( 230 | minFrequency, 231 | vocab, 232 | new NeuralNetworkConfig( 233 | type, 234 | numThreads, 235 | iterations, 236 | layerSize, 237 | windowSize, 238 | negativeSamples, 239 | downSampleRate, 240 | initialLearningRate, 241 | useHierarchicalSoftmax 242 | ) 243 | ).train(LOG, listener, sentences); 244 | } 245 | 246 | /** Listener for model training progress */ 247 | public interface TrainingProgressListener { 248 | /** Sequential stages of processing */ 249 | enum Stage { 250 | ACQUIRE_VOCAB, 251 | FILTER_SORT_VOCAB, 252 | CREATE_HUFFMAN_ENCODING, 253 | TRAIN_NEURAL_NETWORK, 254 | } 255 | 256 | /** 257 | * Called during word2vec training 258 | *

259 | * Note that this is called in a separate thread from the processing thread 260 | * @param stage Current {@link Stage} of processing 261 | * @param progress Progress of the current stage as a double value between 0 and 1 262 | */ 263 | void update(Stage stage, double progress); 264 | } 265 | } -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/huffman/HuffmanCoding.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.huffman; 2 | 3 | import com.google.common.base.Preconditions; 4 | import com.google.common.collect.ImmutableMap; 5 | import com.google.common.collect.ImmutableMultiset; 6 | import com.google.common.collect.Multiset.Entry; 7 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener; 8 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener.Stage; 9 | 10 | import java.util.ArrayList; 11 | import java.util.Map; 12 | 13 | /** 14 | * Word2Vec library relies on a Huffman encoding scheme 15 | *

16 | * Note that the generated codes and the index of the parents are both used in the 17 | * hierarchical softmax portions of the neural network training phase 18 | *

19 | */ 20 | public class HuffmanCoding { 21 | /** Node */ 22 | public static class HuffmanNode { 23 | /** Array of 0's and 1's */ 24 | public final byte[] code; 25 | /** Array of parent node index offsets */ 26 | public final int[] point; 27 | /** Index of the Huffman node */ 28 | public final int idx; 29 | /** Frequency of the token */ 30 | public final int count; 31 | 32 | private HuffmanNode(byte[] code, int[] point, int idx, int count) { 33 | this.code = code; 34 | this.point = point; 35 | this.idx = idx; 36 | this.count = count; 37 | } 38 | } 39 | 40 | private final ImmutableMultiset vocab; 41 | private final TrainingProgressListener listener; 42 | 43 | /** 44 | * @param vocab {@link Multiset} of tokens, sorted by frequency descending 45 | * @param listener Progress listener 46 | */ 47 | public HuffmanCoding(ImmutableMultiset vocab, TrainingProgressListener listener) { 48 | this.vocab = vocab; 49 | this.listener = listener; 50 | } 51 | 52 | /** 53 | * @return {@link Map} from each given token to a {@link HuffmanNode} 54 | */ 55 | public Map encode() throws InterruptedException { 56 | final int numTokens = vocab.elementSet().size(); 57 | 58 | int[] parentNode = new int[numTokens * 2 + 1]; 59 | byte[] binary = new byte[numTokens * 2 + 1]; 60 | long[] count = new long[numTokens * 2 + 1]; 61 | int i = 0; 62 | for (Entry e : vocab.entrySet()) { 63 | count[i] = e.getCount(); 64 | i++; 65 | } 66 | Preconditions.checkState(i == numTokens, "Expected %s to match %s", i, numTokens); 67 | for (i = numTokens; i < count.length; i++) 68 | count[i] = (long)1e15; 69 | 70 | createTree(numTokens, count, binary, parentNode); 71 | 72 | return encode(binary, parentNode); 73 | } 74 | 75 | /** 76 | * Populate the count, binary, and parentNode arrays with the Huffman tree 77 | * This uses the linear time method assuming that the count array is sorted 78 | */ 79 | private void createTree(int numTokens, long[] count, byte[] binary, int[] parentNode) throws InterruptedException { 80 | 
int min1i; 81 | int min2i; 82 | int pos1 = numTokens - 1; 83 | int pos2 = numTokens; 84 | 85 | // Construct the Huffman tree by adding one node at a time 86 | for (int a = 0; a < numTokens - 1; a++) { 87 | // First, find two smallest nodes 'min1, min2' 88 | if (pos1 >= 0) { 89 | if (count[pos1] < count[pos2]) { 90 | min1i = pos1; 91 | pos1--; 92 | } else { 93 | min1i = pos2; 94 | pos2++; 95 | } 96 | } else { 97 | min1i = pos2; 98 | pos2++; 99 | } 100 | 101 | if (pos1 >= 0) { 102 | if (count[pos1] < count[pos2]) { 103 | min2i = pos1; 104 | pos1--; 105 | } else { 106 | min2i = pos2; 107 | pos2++; 108 | } 109 | } else { 110 | min2i = pos2; 111 | pos2++; 112 | } 113 | 114 | int newNodeIdx = numTokens + a; 115 | count[newNodeIdx] = count[min1i] + count[min2i]; 116 | parentNode[min1i] = newNodeIdx; 117 | parentNode[min2i] = newNodeIdx; 118 | binary[min2i] = 1; 119 | 120 | if (a % 1_000 == 0) { 121 | if (Thread.currentThread().isInterrupted()) 122 | throw new InterruptedException("Interrupted while encoding huffman tree"); 123 | listener.update(Stage.CREATE_HUFFMAN_ENCODING, (0.5 * a) / numTokens); 124 | } 125 | } 126 | } 127 | 128 | /** @return Ordered map from each token to its {@link HuffmanNode}, ordered by frequency descending */ 129 | private Map encode(byte[] binary, int[] parentNode) throws InterruptedException { 130 | int numTokens = vocab.elementSet().size(); 131 | 132 | // Now assign binary code to each unique token 133 | ImmutableMap.Builder result = ImmutableMap.builder(); 134 | int nodeIdx = 0; 135 | for (Entry e : vocab.entrySet()) { 136 | int curNodeIdx = nodeIdx; 137 | ArrayList code = new ArrayList<>(); 138 | ArrayList points = new ArrayList<>(); 139 | while (true) { 140 | code.add(binary[curNodeIdx]); 141 | points.add(curNodeIdx); 142 | curNodeIdx = parentNode[curNodeIdx]; 143 | if (curNodeIdx == numTokens * 2 - 2) 144 | break; 145 | } 146 | int codeLen = code.size(); 147 | final int count = e.getCount(); 148 | final byte[] rawCode = new byte[codeLen]; 
149 | final int[] rawPoints = new int[codeLen + 1]; 150 | 151 | rawPoints[0] = numTokens - 2; 152 | for (int i = 0; i < codeLen; i++) { 153 | rawCode[codeLen - i - 1] = code.get(i); 154 | rawPoints[codeLen - i] = points.get(i) - numTokens; 155 | } 156 | 157 | String token = e.getElement(); 158 | result.put(token, new HuffmanNode(rawCode, rawPoints, nodeIdx, count)); 159 | 160 | if (nodeIdx % 1_000 == 0) { 161 | if (Thread.currentThread().isInterrupted()) 162 | throw new InterruptedException("Interrupted while encoding huffman tree"); 163 | listener.update(Stage.CREATE_HUFFMAN_ENCODING, 0.5 + (0.5 * nodeIdx) / numTokens); 164 | } 165 | 166 | nodeIdx++; 167 | } 168 | 169 | return result.build(); 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/neuralnetwork/CBOWModelTrainer.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.neuralnetwork; 2 | 3 | import com.google.common.collect.Multiset; 4 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener; 5 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode; 6 | 7 | import java.util.List; 8 | import java.util.Map; 9 | 10 | /** 11 | * Trainer for neural network using continuous bag of words 12 | */ 13 | class CBOWModelTrainer extends NeuralNetworkTrainer { 14 | 15 | CBOWModelTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener) { 16 | super(config, counts, huffmanNodes, listener); 17 | } 18 | 19 | /** {@link Worker} for {@link CBOWModelTrainer} */ 20 | private class CBOWWorker extends Worker { 21 | private CBOWWorker(int randomSeed, int iter, Iterable> batch) { 22 | super(randomSeed, iter, batch); 23 | } 24 | 25 | @Override void trainSentence(List sentence) { 26 | int sentenceLength = sentence.size(); 27 | 28 | for (int sentencePosition = 0; sentencePosition < sentenceLength; 
sentencePosition++) { 29 | String word = sentence.get(sentencePosition); 30 | HuffmanNode huffmanNode = huffmanNodes.get(word); 31 | 32 | for (int c = 0; c < layer1_size; c++) 33 | neu1[c] = 0; 34 | for (int c = 0; c < layer1_size; c++) 35 | neu1e[c] = 0; 36 | 37 | nextRandom = incrementRandom(nextRandom); 38 | int b = (int)((nextRandom % window) + window) % window; 39 | 40 | // in -> hidden 41 | int cw = 0; 42 | for (int a = b; a < window * 2 + 1 - b; a++) { 43 | if (a == window) 44 | continue; 45 | int c = sentencePosition - window + a; 46 | if (c < 0 || c >= sentenceLength) 47 | continue; 48 | int idx = huffmanNodes.get(sentence.get(c)).idx; 49 | for (int d = 0; d < layer1_size; d++) { 50 | neu1[d] += syn0[idx][d]; 51 | } 52 | 53 | cw++; 54 | } 55 | 56 | if (cw == 0) 57 | continue; 58 | 59 | for (int c = 0; c < layer1_size; c++) 60 | neu1[c] /= cw; 61 | 62 | if (config.useHierarchicalSoftmax) { 63 | for (int d = 0; d < huffmanNode.code.length; d++) { 64 | double f = 0; 65 | int l2 = huffmanNode.point[d]; 66 | // Propagate hidden -> output 67 | for (int c = 0; c < layer1_size; c++) 68 | f += neu1[c] * syn1[l2][c]; 69 | if (f <= -MAX_EXP || f >= MAX_EXP) 70 | continue; 71 | else 72 | f = EXP_TABLE[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 73 | // 'g' is the gradient multiplied by the learning rate 74 | double g = (1 - huffmanNode.code[d] - f) * alpha; 75 | // Propagate errors output -> hidden 76 | for (int c = 0; c < layer1_size; c++) 77 | neu1e[c] += g * syn1[l2][c]; 78 | // Learn weights hidden -> output 79 | for (int c = 0; c < layer1_size; c++) 80 | syn1[l2][c] += g * neu1[c]; 81 | } 82 | } 83 | 84 | handleNegativeSampling(huffmanNode); 85 | 86 | // hidden -> in 87 | for (int a = b; a < window * 2 + 1 - b; a++) { 88 | if (a == window) 89 | continue; 90 | int c = sentencePosition - window + a; 91 | if (c < 0 || c >= sentenceLength) 92 | continue; 93 | int idx = huffmanNodes.get(sentence.get(c)).idx; 94 | for (int d = 0; d < layer1_size; d++) 95 | 
syn0[idx][d] += neu1e[d]; 96 | } 97 | } 98 | } 99 | } 100 | 101 | @Override Worker createWorker(int randomSeed, int iter, Iterable> batch) { 102 | return new CBOWWorker(randomSeed, iter, batch); 103 | } 104 | } -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkConfig.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.neuralnetwork; 2 | 3 | import com.google.common.collect.ImmutableMultiset; 4 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener; 5 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode; 6 | 7 | import java.util.Map; 8 | 9 | /** Fixed configuration for training the neural network */ 10 | public class NeuralNetworkConfig { 11 | final int numThreads; 12 | final int iterations; 13 | final NeuralNetworkType type; 14 | final int layerSize; 15 | final int windowSize; 16 | final int negativeSamples; 17 | final boolean useHierarchicalSoftmax; 18 | 19 | final double initialLearningRate; 20 | final double downSampleRate; 21 | 22 | /** Constructor */ 23 | public NeuralNetworkConfig( 24 | NeuralNetworkType type, 25 | int numThreads, 26 | int iterations, 27 | int layerSize, 28 | int windowSize, 29 | int negativeSamples, 30 | double downSampleRate, 31 | double initialLearningRate, 32 | boolean useHierarchicalSoftmax) { 33 | this.type = type; 34 | this.iterations = iterations; 35 | this.numThreads = numThreads; 36 | this.layerSize = layerSize; 37 | this.windowSize = windowSize; 38 | this.negativeSamples = negativeSamples; 39 | this.useHierarchicalSoftmax = useHierarchicalSoftmax; 40 | this.initialLearningRate = initialLearningRate; 41 | this.downSampleRate = downSampleRate; 42 | } 43 | 44 | /** @return {@link NeuralNetworkTrainer} */ 45 | public NeuralNetworkTrainer createTrainer(ImmutableMultiset vocab, Map huffmanNodes, TrainingProgressListener listener) { 46 | 
return type.createTrainer(this, vocab, huffmanNodes, listener); 47 | } 48 | 49 | @Override public String toString() { 50 | return String.format("%s with %s threads, %s iterations[%s layer size, %s window, %s hierarchical softmax, %s negative samples, %s initial learning rate, %s down sample rate]", 51 | type.name(), 52 | numThreads, 53 | iterations, 54 | layerSize, 55 | windowSize, 56 | useHierarchicalSoftmax ? "using" : "not using", 57 | negativeSamples, 58 | initialLearningRate, 59 | downSampleRate 60 | ); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.neuralnetwork; 2 | 3 | import com.google.common.collect.Iterables; 4 | import com.google.common.collect.Multiset; 5 | import com.google.common.util.concurrent.Futures; 6 | import com.google.common.util.concurrent.ListenableFuture; 7 | import com.google.common.util.concurrent.ListeningExecutorService; 8 | import com.google.common.util.concurrent.MoreExecutors; 9 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener; 10 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener.Stage; 11 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode; 12 | import com.medallia.word2vec.util.CallableVoid; 13 | 14 | import java.util.ArrayList; 15 | import java.util.Iterator; 16 | import java.util.List; 17 | import java.util.Map; 18 | import java.util.concurrent.ExecutionException; 19 | import java.util.concurrent.Executors; 20 | import java.util.concurrent.atomic.AtomicInteger; 21 | 22 | /** Parent class for training word2vec's neural network */ 23 | public abstract class NeuralNetworkTrainer { 24 | /** Sentences longer than this are broken into multiple chunks */ 25 | private static final int MAX_SENTENCE_LENGTH = 1_000; 26 | 27 | 
/** Boundary for maximum exponent allowed */ 28 | static final int MAX_EXP = 6; 29 | 30 | /** Size of the pre-cached exponent table */ 31 | static final int EXP_TABLE_SIZE = 1_000; 32 | static final double[] EXP_TABLE = new double[EXP_TABLE_SIZE]; 33 | static { 34 | for (int i = 0; i < EXP_TABLE_SIZE; i++) { 35 | // Precompute the exp() table 36 | EXP_TABLE[i] = Math.exp((i / (double)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP); 37 | // Precompute f(x) = x / (x + 1) 38 | EXP_TABLE[i] /= EXP_TABLE[i] + 1; 39 | } 40 | } 41 | 42 | private static final int TABLE_SIZE = (int)1e8; 43 | 44 | private final TrainingProgressListener listener; 45 | 46 | final NeuralNetworkConfig config; 47 | final Map huffmanNodes; 48 | private final int vocabSize; 49 | final int layer1_size; 50 | final int window; 51 | /** 52 | * In the C version, this includes the token that replaces a newline character 53 | */ 54 | int numTrainedTokens; 55 | 56 | /* The following includes shared state that is updated per worker thread */ 57 | 58 | /** 59 | * To be precise, this is the number of words in the training data that exist in the vocabulary 60 | * which have been processed so far. It includes words that are discarded from sampling. 61 | * Note that each word is processed once per iteration. 
62 | */ 63 | protected final AtomicInteger actualWordCount; 64 | /** Learning rate, affects how fast values in the layers get updated */ 65 | volatile double alpha; 66 | /** 67 | * This contains the outer layers of the neural network 68 | * First dimension is the vocab, second is the layer 69 | */ 70 | final double[][] syn0; 71 | /** This contains hidden layers of the neural network */ 72 | final double[][] syn1; 73 | /** This is used for negative sampling */ 74 | private final double[][] syn1neg; 75 | /** Used for negative sampling */ 76 | private final int[] table; 77 | long startNano; 78 | 79 | NeuralNetworkTrainer(NeuralNetworkConfig config, Multiset vocab, Map huffmanNodes, TrainingProgressListener listener) { 80 | this.config = config; 81 | this.huffmanNodes = huffmanNodes; 82 | this.listener = listener; 83 | this.vocabSize = huffmanNodes.size(); 84 | this.numTrainedTokens = vocab.size(); 85 | this.layer1_size = config.layerSize; 86 | this.window = config.windowSize; 87 | 88 | this.actualWordCount = new AtomicInteger(); 89 | this.alpha = config.initialLearningRate; 90 | 91 | this.syn0 = new double[vocabSize][layer1_size]; 92 | this.syn1 = new double[vocabSize][layer1_size]; 93 | this.syn1neg = new double[vocabSize][layer1_size]; 94 | this.table = new int[TABLE_SIZE]; 95 | 96 | initializeSyn0(); 97 | initializeUnigramTable(); 98 | } 99 | 100 | private void initializeUnigramTable() { 101 | long trainWordsPow = 0; 102 | double power = 0.75; 103 | 104 | for (HuffmanNode node : huffmanNodes.values()) { 105 | trainWordsPow += Math.pow(node.count, power); 106 | } 107 | 108 | Iterator nodeIter = huffmanNodes.values().iterator(); 109 | HuffmanNode last = nodeIter.next(); 110 | double d1 = Math.pow(last.count, power) / trainWordsPow; 111 | int i = 0; 112 | for (int a = 0; a < TABLE_SIZE; a++) { 113 | table[a] = i; 114 | if (a / (double)TABLE_SIZE > d1) { 115 | i++; 116 | HuffmanNode next = nodeIter.hasNext() 117 | ? 
nodeIter.next() 118 | : last; 119 | 120 | d1 += Math.pow(next.count, power) / trainWordsPow; 121 | 122 | last = next; 123 | } 124 | } 125 | } 126 | 127 | private void initializeSyn0() { 128 | long nextRandom = 1; 129 | for (int a = 0; a < huffmanNodes.size(); a++) { 130 | // Consume a random for fun 131 | // Actually we do this to use up the injected token 132 | nextRandom = incrementRandom(nextRandom); 133 | for (int b = 0; b < layer1_size; b++) { 134 | nextRandom = incrementRandom(nextRandom); 135 | syn0[a][b] = (((nextRandom & 0xFFFF) / (double)65_536) - 0.5) / layer1_size; 136 | } 137 | } 138 | } 139 | 140 | /** @return Next random value to use */ 141 | static long incrementRandom(long r) { 142 | return r * 25_214_903_917L + 11; 143 | } 144 | 145 | /** Represents a neural network model */ 146 | public interface NeuralNetworkModel { 147 | /** Size of the layers */ 148 | int layerSize(); 149 | /** Resulting vectors */ 150 | double[][] vectors(); 151 | } 152 | 153 | /** @return Trained NN model */ 154 | public NeuralNetworkModel train(Iterable> sentences) throws InterruptedException { 155 | ListeningExecutorService ex = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(config.numThreads)); 156 | 157 | int numSentences = Iterables.size(sentences); 158 | numTrainedTokens += numSentences; 159 | 160 | // Partition the sentences evenly amongst the threads 161 | Iterable>> partitioned = Iterables.partition(sentences, numSentences / config.numThreads + 1); 162 | 163 | try { 164 | listener.update(Stage.TRAIN_NEURAL_NETWORK, 0.0); 165 | for (int iter = config.iterations; iter > 0; iter--) { 166 | List tasks = new ArrayList<>(); 167 | int i = 0; 168 | for (final List> batch : partitioned) { 169 | tasks.add(createWorker(i, iter, batch)); 170 | i++; 171 | } 172 | 173 | List> futures = new ArrayList<>(tasks.size()); 174 | for (CallableVoid task : tasks) 175 | futures.add(ex.submit(task)); 176 | try { 177 | Futures.allAsList(futures).get(); 178 | } catch 
(ExecutionException e) { 179 | throw new IllegalStateException("Error training neural network", e.getCause()); 180 | } 181 | } 182 | ex.shutdown(); 183 | } finally { 184 | ex.shutdownNow(); 185 | } 186 | 187 | return new NeuralNetworkModel() { 188 | @Override public int layerSize() { 189 | return config.layerSize; 190 | } 191 | 192 | @Override public double[][] vectors() { 193 | return syn0; 194 | } 195 | }; 196 | } 197 | 198 | /** @return {@link Worker} to process the given sentences */ 199 | abstract Worker createWorker(int randomSeed, int iter, Iterable> batch); 200 | 201 | /** Worker thread that updates the neural network model */ 202 | abstract class Worker extends CallableVoid { 203 | private static final int LEARNING_RATE_UPDATE_FREQUENCY = 10_000; 204 | 205 | long nextRandom; 206 | final int iter; 207 | final Iterable> batch; 208 | 209 | /** 210 | * The number of words observed in the training data for this worker that exist 211 | * in the vocabulary. It includes words that are discarded from sampling. 
212 | */ 213 | int wordCount; 214 | /** Value of wordCount the last time alpha was updated */ 215 | int lastWordCount; 216 | 217 | final double[] neu1 = new double[layer1_size]; 218 | final double[] neu1e = new double[layer1_size]; 219 | 220 | Worker(int randomSeed, int iter, Iterable> batch) { 221 | this.nextRandom = randomSeed; 222 | this.iter = iter; 223 | this.batch = batch; 224 | } 225 | 226 | @Override public void run() throws InterruptedException { 227 | for (List sentence : batch) { 228 | List filteredSentence = new ArrayList<>(sentence.size()); 229 | for (String s : sentence) { 230 | if (!huffmanNodes.containsKey(s)) 231 | continue; 232 | 233 | wordCount++; 234 | if (config.downSampleRate > 0) { 235 | HuffmanNode huffmanNode = huffmanNodes.get(s); 236 | double random = (Math.sqrt(huffmanNode.count / (config.downSampleRate * numTrainedTokens)) + 1) 237 | * (config.downSampleRate * numTrainedTokens) / huffmanNode.count; 238 | nextRandom = incrementRandom(nextRandom); 239 | if (random < (nextRandom & 0xFFFF) / (double)65_536) { 240 | continue; 241 | } 242 | } 243 | 244 | filteredSentence.add(s); 245 | } 246 | 247 | // Increment word count one extra for the injected token 248 | // Turns out if you don't do this, the produced word vectors aren't as tasty 249 | wordCount++; 250 | 251 | Iterable> partitioned = Iterables.partition(filteredSentence, MAX_SENTENCE_LENGTH); 252 | for (List chunked : partitioned) { 253 | if (Thread.currentThread().isInterrupted()) 254 | throw new InterruptedException("Interrupted while training word2vec model"); 255 | 256 | if (wordCount - lastWordCount > LEARNING_RATE_UPDATE_FREQUENCY) { 257 | updateAlpha(iter); 258 | } 259 | trainSentence(chunked); 260 | } 261 | } 262 | 263 | actualWordCount.addAndGet(wordCount - lastWordCount); 264 | } 265 | 266 | /** 267 | * Degrades the learning rate (alpha) steadily towards 0 268 | * @param iter Only used for debugging 269 | */ 270 | private void updateAlpha(int iter) { 271 | int currentActual = 
actualWordCount.addAndGet(wordCount - lastWordCount); 272 | lastWordCount = wordCount; 273 | 274 | // Degrade the learning rate linearly towards 0 but keep a minimum 275 | alpha = config.initialLearningRate * Math.max( 276 | 1 - currentActual / (double)(config.iterations * numTrainedTokens), 277 | 0.0001 278 | ); 279 | 280 | listener.update( 281 | Stage.TRAIN_NEURAL_NETWORK, 282 | currentActual / (double) (config.iterations * numTrainedTokens + 1) 283 | ); 284 | } 285 | 286 | void handleNegativeSampling(HuffmanNode huffmanNode) { 287 | for (int d = 0; d <= config.negativeSamples; d++) { 288 | int target; 289 | final int label; 290 | if (d == 0) { 291 | target = huffmanNode.idx; 292 | label = 1; 293 | } else { 294 | nextRandom = incrementRandom(nextRandom); 295 | target = table[(int) (((nextRandom >> 16) % TABLE_SIZE) + TABLE_SIZE) % TABLE_SIZE]; 296 | if (target == 0) 297 | target = (int)(((nextRandom % (vocabSize - 1)) + vocabSize - 1) % (vocabSize - 1)) + 1; 298 | if (target == huffmanNode.idx) 299 | continue; 300 | label = 0; 301 | } 302 | int l2 = target; 303 | double f = 0; 304 | for (int c = 0; c < layer1_size; c++) 305 | f += neu1[c] * syn1neg[l2][c]; 306 | final double g; 307 | if (f > MAX_EXP) 308 | g = (label - 1) * alpha; 309 | else if (f < -MAX_EXP) 310 | g = (label - 0) * alpha; 311 | else 312 | g = (label - EXP_TABLE[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha; 313 | for (int c = 0; c < layer1_size; c++) 314 | neu1e[c] += g * syn1neg[l2][c]; 315 | for (int c = 0; c < layer1_size; c++) 316 | syn1neg[l2][c] += g * neu1[c]; 317 | } 318 | } 319 | 320 | /** Update the model with the given raw sentence */ 321 | abstract void trainSentence(List unfiltered); 322 | } 323 | } 324 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkType.java: -------------------------------------------------------------------------------- 1 | package 
com.medallia.word2vec.neuralnetwork; 2 | 3 | import com.google.common.collect.Multiset; 4 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener; 5 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode; 6 | 7 | import java.util.Map; 8 | 9 | /** 10 | * Supported types for the neural network 11 | */ 12 | public enum NeuralNetworkType { 13 | /** Faster, slightly better accuracy for frequent words */ 14 | CBOW { 15 | @Override NeuralNetworkTrainer createTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener) { 16 | return new CBOWModelTrainer(config, counts, huffmanNodes, listener); 17 | } 18 | 19 | @Override public double getDefaultInitialLearningRate() { 20 | return 0.05; 21 | } 22 | }, 23 | /** Slower, better for infrequent words */ 24 | SKIP_GRAM { 25 | @Override NeuralNetworkTrainer createTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener) { 26 | return new SkipGramModelTrainer(config, counts, huffmanNodes, listener); 27 | } 28 | 29 | @Override public double getDefaultInitialLearningRate() { 30 | return 0.025; 31 | } 32 | }, 33 | ; 34 | 35 | /** @return Default initial learning rate */ 36 | public abstract double getDefaultInitialLearningRate(); 37 | 38 | /** @return New {@link NeuralNetworkTrainer} */ 39 | abstract NeuralNetworkTrainer createTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener); 40 | } -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/neuralnetwork/SkipGramModelTrainer.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.neuralnetwork; 2 | 3 | import com.google.common.collect.Multiset; 4 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener; 5 | import 
com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode; 6 | 7 | import java.util.List; 8 | import java.util.Map; 9 | 10 | /** 11 | * Trainer for neural network using skip gram 12 | */ 13 | class SkipGramModelTrainer extends NeuralNetworkTrainer { 14 | 15 | SkipGramModelTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener) { 16 | super(config, counts, huffmanNodes, listener); 17 | } 18 | 19 | /** {@link Worker} for {@link SkipGramModelTrainer} */ 20 | private class SkipGramWorker extends Worker { 21 | private SkipGramWorker(int randomSeed, int iter, Iterable> batch) { 22 | super(randomSeed, iter, batch); 23 | } 24 | 25 | @Override void trainSentence(List sentence) { 26 | int sentenceLength = sentence.size(); 27 | 28 | for (int sentencePosition = 0; sentencePosition < sentenceLength; sentencePosition++) { 29 | String word = sentence.get(sentencePosition); 30 | HuffmanNode huffmanNode = huffmanNodes.get(word); 31 | 32 | for (int c = 0; c < layer1_size; c++) 33 | neu1[c] = 0; 34 | for (int c = 0; c < layer1_size; c++) 35 | neu1e[c] = 0; 36 | nextRandom = incrementRandom(nextRandom); 37 | 38 | int b = (int)(((nextRandom % window) + nextRandom) % window); 39 | 40 | for (int a = b; a < window * 2 + 1 - b; a++) { 41 | if (a == window) 42 | continue; 43 | int c = sentencePosition - window + a; 44 | 45 | if (c < 0 || c >= sentenceLength) 46 | continue; 47 | for (int d = 0; d < layer1_size; d++) 48 | neu1e[d] = 0; 49 | 50 | int l1 = huffmanNodes.get(sentence.get(c)).idx; 51 | 52 | if (config.useHierarchicalSoftmax) { 53 | for (int d = 0; d < huffmanNode.code.length; d++) { 54 | double f = 0; 55 | int l2 = huffmanNode.point[d]; 56 | // Propagate hidden -> output 57 | for (int e = 0; e < layer1_size; e++) 58 | f += syn0[l1][e] * syn1[l2][e]; 59 | 60 | if (f <= -MAX_EXP || f >= MAX_EXP) 61 | continue; 62 | else 63 | f = EXP_TABLE[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 64 | // 'g' is the gradient multiplied by 
the learning rate 65 | double g = (1 - huffmanNode.code[d] - f) * alpha; 66 | 67 | // Propagate errors output -> hidden 68 | for (int e = 0; e < layer1_size; e++) 69 | neu1e[e] += g * syn1[l2][e]; 70 | // Learn weights hidden -> output 71 | for (int e = 0; e < layer1_size; e++) 72 | syn1[l2][e] += g * syn0[l1][e]; 73 | } 74 | } 75 | 76 | handleNegativeSampling(huffmanNode); 77 | 78 | // Learn weights input -> hidden 79 | for (int d = 0; d < layer1_size; d++) { 80 | syn0[l1][d] += neu1e[d]; 81 | } 82 | } 83 | } 84 | } 85 | } 86 | 87 | @Override Worker createWorker(int randomSeed, int iter, Iterable> batch) { 88 | return new SkipGramWorker(randomSeed, iter, batch); 89 | } 90 | } -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/thrift/Word2VecModelThrift.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Autogenerated by Thrift Compiler (0.9.1) 3 | * 4 | * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING 5 | * @generated 6 | */ 7 | package com.medallia.word2vec.thrift; 8 | 9 | import org.apache.commons.lang3.builder.HashCodeBuilder; 10 | import org.apache.thrift.EncodingUtils; 11 | import org.apache.thrift.protocol.TTupleProtocol; 12 | import org.apache.thrift.scheme.IScheme; 13 | import org.apache.thrift.scheme.SchemeFactory; 14 | import org.apache.thrift.scheme.StandardScheme; 15 | import org.apache.thrift.scheme.TupleScheme; 16 | 17 | import java.util.ArrayList; 18 | import java.util.BitSet; 19 | import java.util.Collections; 20 | import java.util.EnumMap; 21 | import java.util.EnumSet; 22 | import java.util.HashMap; 23 | import java.util.List; 24 | import java.util.Map; 25 | 26 | public class Word2VecModelThrift implements org.apache.thrift.TBase, java.io.Serializable, Cloneable, Comparable { 27 | private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Word2VecModelThrift"); 28 
| 29 | private static final org.apache.thrift.protocol.TField VOCAB_FIELD_DESC = new org.apache.thrift.protocol.TField("vocab", org.apache.thrift.protocol.TType.LIST, (short)1); 30 | private static final org.apache.thrift.protocol.TField LAYER_SIZE_FIELD_DESC = new org.apache.thrift.protocol.TField("layerSize", org.apache.thrift.protocol.TType.I32, (short)2); 31 | private static final org.apache.thrift.protocol.TField VECTORS_FIELD_DESC = new org.apache.thrift.protocol.TField("vectors", org.apache.thrift.protocol.TType.LIST, (short)3); 32 | 33 | private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); 34 | static { 35 | schemes.put(StandardScheme.class, new Word2VecModelThriftStandardSchemeFactory()); 36 | schemes.put(TupleScheme.class, new Word2VecModelThriftTupleSchemeFactory()); 37 | } 38 | 39 | private List vocab; // optional 40 | private int layerSize; // optional 41 | private List vectors; // optional 42 | 43 | /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ 44 | public enum _Fields implements org.apache.thrift.TFieldIdEnum { 45 | VOCAB((short)1, "vocab"), 46 | LAYER_SIZE((short)2, "layerSize"), 47 | VECTORS((short)3, "vectors"); 48 | 49 | private static final Map byName = new HashMap(); 50 | 51 | static { 52 | for (_Fields field : EnumSet.allOf(_Fields.class)) { 53 | byName.put(field.getFieldName(), field); 54 | } 55 | } 56 | 57 | /** 58 | * Find the _Fields constant that matches fieldId, or null if its not found. 59 | */ 60 | public static _Fields findByThriftId(int fieldId) { 61 | switch(fieldId) { 62 | case 1: // VOCAB 63 | return VOCAB; 64 | case 2: // LAYER_SIZE 65 | return LAYER_SIZE; 66 | case 3: // VECTORS 67 | return VECTORS; 68 | default: 69 | return null; 70 | } 71 | } 72 | 73 | /** 74 | * Find the _Fields constant that matches fieldId, throwing an exception 75 | * if it is not found. 
76 | */ 77 | public static _Fields findByThriftIdOrThrow(int fieldId) { 78 | _Fields fields = findByThriftId(fieldId); 79 | if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); 80 | return fields; 81 | } 82 | 83 | /** 84 | * Find the _Fields constant that matches name, or null if its not found. 85 | */ 86 | public static _Fields findByName(String name) { 87 | return byName.get(name); 88 | } 89 | 90 | private final short _thriftId; 91 | private final String _fieldName; 92 | 93 | _Fields(short thriftId, String fieldName) { 94 | _thriftId = thriftId; 95 | _fieldName = fieldName; 96 | } 97 | 98 | public short getThriftFieldId() { 99 | return _thriftId; 100 | } 101 | 102 | public String getFieldName() { 103 | return _fieldName; 104 | } 105 | } 106 | 107 | // isset id assignments 108 | private static final int __LAYERSIZE_ISSET_ID = 0; 109 | private byte __isset_bitfield = 0; 110 | private static _Fields optionals[] = {_Fields.VOCAB, _Fields.LAYER_SIZE, _Fields.VECTORS}; 111 | public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; 112 | static { 113 | Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); 114 | tmpMap.put(_Fields.VOCAB, new org.apache.thrift.meta_data.FieldMetaData("vocab", org.apache.thrift.TFieldRequirementType.OPTIONAL, 115 | new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, 116 | new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); 117 | tmpMap.put(_Fields.LAYER_SIZE, new org.apache.thrift.meta_data.FieldMetaData("layerSize", org.apache.thrift.TFieldRequirementType.OPTIONAL, 118 | new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); 119 | tmpMap.put(_Fields.VECTORS, new org.apache.thrift.meta_data.FieldMetaData("vectors", org.apache.thrift.TFieldRequirementType.OPTIONAL, 
120 | new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, 121 | new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.DOUBLE)))); 122 | metaDataMap = Collections.unmodifiableMap(tmpMap); 123 | org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(Word2VecModelThrift.class, metaDataMap); 124 | } 125 | 126 | public Word2VecModelThrift() { 127 | } 128 | 129 | /** 130 | * Performs a deep copy on other. 131 | */ 132 | public Word2VecModelThrift(Word2VecModelThrift other) { 133 | __isset_bitfield = other.__isset_bitfield; 134 | if (other.isSetVocab()) { 135 | List __this__vocab = new ArrayList(other.vocab); 136 | this.vocab = __this__vocab; 137 | } 138 | this.layerSize = other.layerSize; 139 | if (other.isSetVectors()) { 140 | List __this__vectors = new ArrayList(other.vectors); 141 | this.vectors = __this__vectors; 142 | } 143 | } 144 | 145 | public Word2VecModelThrift deepCopy() { 146 | return new Word2VecModelThrift(this); 147 | } 148 | 149 | @Override 150 | public void clear() { 151 | this.vocab = null; 152 | setLayerSizeIsSet(false); 153 | this.layerSize = 0; 154 | this.vectors = null; 155 | } 156 | 157 | public int getVocabSize() { 158 | return (this.vocab == null) ? 0 : this.vocab.size(); 159 | } 160 | 161 | public java.util.Iterator getVocabIterator() { 162 | return (this.vocab == null) ? 
null : this.vocab.iterator(); 163 | } 164 | 165 | public void addToVocab(String elem) { 166 | if (this.vocab == null) { 167 | this.vocab = new ArrayList(); 168 | } 169 | this.vocab.add(elem); 170 | } 171 | 172 | public List getVocab() { 173 | return this.vocab; 174 | } 175 | 176 | public Word2VecModelThrift setVocab(List vocab) { 177 | this.vocab = vocab; 178 | return this; 179 | } 180 | 181 | public void unsetVocab() { 182 | this.vocab = null; 183 | } 184 | 185 | /** Returns true if field vocab is set (has been assigned a value) and false otherwise */ 186 | public boolean isSetVocab() { 187 | return this.vocab != null; 188 | } 189 | 190 | public void setVocabIsSet(boolean value) { 191 | if (!value) { 192 | this.vocab = null; 193 | } 194 | } 195 | 196 | public int getLayerSize() { 197 | return this.layerSize; 198 | } 199 | 200 | public Word2VecModelThrift setLayerSize(int layerSize) { 201 | this.layerSize = layerSize; 202 | setLayerSizeIsSet(true); 203 | return this; 204 | } 205 | 206 | public void unsetLayerSize() { 207 | __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __LAYERSIZE_ISSET_ID); 208 | } 209 | 210 | /** Returns true if field layerSize is set (has been assigned a value) and false otherwise */ 211 | public boolean isSetLayerSize() { 212 | return EncodingUtils.testBit(__isset_bitfield, __LAYERSIZE_ISSET_ID); 213 | } 214 | 215 | public void setLayerSizeIsSet(boolean value) { 216 | __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __LAYERSIZE_ISSET_ID, value); 217 | } 218 | 219 | public int getVectorsSize() { 220 | return (this.vectors == null) ? 0 : this.vectors.size(); 221 | } 222 | 223 | public java.util.Iterator getVectorsIterator() { 224 | return (this.vectors == null) ? 
null : this.vectors.iterator(); 225 | } 226 | 227 | public void addToVectors(double elem) { 228 | if (this.vectors == null) { 229 | this.vectors = new ArrayList(); 230 | } 231 | this.vectors.add(elem); 232 | } 233 | 234 | public List getVectors() { 235 | return this.vectors; 236 | } 237 | 238 | public Word2VecModelThrift setVectors(List vectors) { 239 | this.vectors = vectors; 240 | return this; 241 | } 242 | 243 | public void unsetVectors() { 244 | this.vectors = null; 245 | } 246 | 247 | /** Returns true if field vectors is set (has been assigned a value) and false otherwise */ 248 | public boolean isSetVectors() { 249 | return this.vectors != null; 250 | } 251 | 252 | public void setVectorsIsSet(boolean value) { 253 | if (!value) { 254 | this.vectors = null; 255 | } 256 | } 257 | 258 | public void setFieldValue(_Fields field, Object value) { 259 | switch (field) { 260 | case VOCAB: 261 | if (value == null) { 262 | unsetVocab(); 263 | } else { 264 | setVocab((List)value); 265 | } 266 | break; 267 | 268 | case LAYER_SIZE: 269 | if (value == null) { 270 | unsetLayerSize(); 271 | } else { 272 | setLayerSize((Integer)value); 273 | } 274 | break; 275 | 276 | case VECTORS: 277 | if (value == null) { 278 | unsetVectors(); 279 | } else { 280 | setVectors((List)value); 281 | } 282 | break; 283 | 284 | } 285 | } 286 | 287 | public Object getFieldValue(_Fields field) { 288 | switch (field) { 289 | case VOCAB: 290 | return getVocab(); 291 | 292 | case LAYER_SIZE: 293 | return Integer.valueOf(getLayerSize()); 294 | 295 | case VECTORS: 296 | return getVectors(); 297 | 298 | } 299 | throw new IllegalStateException(); 300 | } 301 | 302 | /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ 303 | public boolean isSet(_Fields field) { 304 | if (field == null) { 305 | throw new IllegalArgumentException(); 306 | } 307 | 308 | switch (field) { 309 | case VOCAB: 310 | return isSetVocab(); 311 | case LAYER_SIZE: 312 | return 
isSetLayerSize(); 313 | case VECTORS: 314 | return isSetVectors(); 315 | } 316 | throw new IllegalStateException(); 317 | } 318 | 319 | @Override 320 | public boolean equals(Object that) { 321 | if (that == null) 322 | return false; 323 | if (that instanceof Word2VecModelThrift) 324 | return this.equals((Word2VecModelThrift)that); 325 | return false; 326 | } 327 | 328 | public boolean equals(Word2VecModelThrift that) { 329 | if (that == null) 330 | return false; 331 | 332 | boolean this_present_vocab = true && this.isSetVocab(); 333 | boolean that_present_vocab = true && that.isSetVocab(); 334 | if (this_present_vocab || that_present_vocab) { 335 | if (!(this_present_vocab && that_present_vocab)) 336 | return false; 337 | if (!this.vocab.equals(that.vocab)) 338 | return false; 339 | } 340 | 341 | boolean this_present_layerSize = true && this.isSetLayerSize(); 342 | boolean that_present_layerSize = true && that.isSetLayerSize(); 343 | if (this_present_layerSize || that_present_layerSize) { 344 | if (!(this_present_layerSize && that_present_layerSize)) 345 | return false; 346 | if (this.layerSize != that.layerSize) 347 | return false; 348 | } 349 | 350 | boolean this_present_vectors = true && this.isSetVectors(); 351 | boolean that_present_vectors = true && that.isSetVectors(); 352 | if (this_present_vectors || that_present_vectors) { 353 | if (!(this_present_vectors && that_present_vectors)) 354 | return false; 355 | if (!this.vectors.equals(that.vectors)) 356 | return false; 357 | } 358 | 359 | return true; 360 | } 361 | 362 | @Override 363 | public int hashCode() { 364 | HashCodeBuilder builder = new HashCodeBuilder(); 365 | 366 | boolean present_vocab = true && (isSetVocab()); 367 | builder.append(present_vocab); 368 | if (present_vocab) 369 | builder.append(vocab); 370 | 371 | boolean present_layerSize = true && (isSetLayerSize()); 372 | builder.append(present_layerSize); 373 | if (present_layerSize) 374 | builder.append(layerSize); 375 | 376 | boolean 
present_vectors = true && (isSetVectors()); 377 | builder.append(present_vectors); 378 | if (present_vectors) 379 | builder.append(vectors); 380 | 381 | return builder.toHashCode(); 382 | } 383 | 384 | @Override 385 | public int compareTo(Word2VecModelThrift other) { 386 | if (!getClass().equals(other.getClass())) { 387 | return getClass().getName().compareTo(other.getClass().getName()); 388 | } 389 | 390 | int lastComparison = 0; 391 | 392 | lastComparison = Boolean.valueOf(isSetVocab()).compareTo(other.isSetVocab()); 393 | if (lastComparison != 0) { 394 | return lastComparison; 395 | } 396 | if (isSetVocab()) { 397 | lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.vocab, other.vocab); 398 | if (lastComparison != 0) { 399 | return lastComparison; 400 | } 401 | } 402 | lastComparison = Boolean.valueOf(isSetLayerSize()).compareTo(other.isSetLayerSize()); 403 | if (lastComparison != 0) { 404 | return lastComparison; 405 | } 406 | if (isSetLayerSize()) { 407 | lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.layerSize, other.layerSize); 408 | if (lastComparison != 0) { 409 | return lastComparison; 410 | } 411 | } 412 | lastComparison = Boolean.valueOf(isSetVectors()).compareTo(other.isSetVectors()); 413 | if (lastComparison != 0) { 414 | return lastComparison; 415 | } 416 | if (isSetVectors()) { 417 | lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.vectors, other.vectors); 418 | if (lastComparison != 0) { 419 | return lastComparison; 420 | } 421 | } 422 | return 0; 423 | } 424 | 425 | public _Fields fieldForId(int fieldId) { 426 | return _Fields.findByThriftId(fieldId); 427 | } 428 | 429 | public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { 430 | schemes.get(iprot.getScheme()).getScheme().read(iprot, this); 431 | } 432 | 433 | public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { 434 | 
schemes.get(oprot.getScheme()).getScheme().write(oprot, this); 435 | } 436 | 437 | @Override 438 | public String toString() { 439 | StringBuilder sb = new StringBuilder("Word2VecModelThrift("); 440 | boolean first = true; 441 | 442 | if (isSetVocab()) { 443 | sb.append("vocab:"); 444 | if (this.vocab == null) { 445 | sb.append("null"); 446 | } else { 447 | sb.append(this.vocab); 448 | } 449 | first = false; 450 | } 451 | if (isSetLayerSize()) { 452 | if (!first) sb.append(", "); 453 | sb.append("layerSize:"); 454 | sb.append(this.layerSize); 455 | first = false; 456 | } 457 | if (isSetVectors()) { 458 | if (!first) sb.append(", "); 459 | sb.append("vectors:"); 460 | if (this.vectors == null) { 461 | sb.append("null"); 462 | } else { 463 | sb.append(this.vectors); 464 | } 465 | first = false; 466 | } 467 | sb.append(")"); 468 | return sb.toString(); 469 | } 470 | 471 | public void validate() throws org.apache.thrift.TException { 472 | // check for required fields 473 | // check for sub-struct validity 474 | } 475 | 476 | private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { 477 | try { 478 | write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); 479 | } catch (org.apache.thrift.TException te) { 480 | throw new java.io.IOException(te); 481 | } 482 | } 483 | 484 | private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { 485 | try { 486 | // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call the default constructor. 
487 | __isset_bitfield = 0; 488 | read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); 489 | } catch (org.apache.thrift.TException te) { 490 | throw new java.io.IOException(te); 491 | } 492 | } 493 | 494 | private static class Word2VecModelThriftStandardSchemeFactory implements SchemeFactory { 495 | public Word2VecModelThriftStandardScheme getScheme() { 496 | return new Word2VecModelThriftStandardScheme(); 497 | } 498 | } 499 | 500 | private static class Word2VecModelThriftStandardScheme extends StandardScheme { 501 | 502 | public void read(org.apache.thrift.protocol.TProtocol iprot, Word2VecModelThrift struct) throws org.apache.thrift.TException { 503 | org.apache.thrift.protocol.TField schemeField; 504 | iprot.readStructBegin(); 505 | while (true) 506 | { 507 | schemeField = iprot.readFieldBegin(); 508 | if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { 509 | break; 510 | } 511 | switch (schemeField.id) { 512 | case 1: // VOCAB 513 | if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { 514 | { 515 | org.apache.thrift.protocol.TList _list0 = iprot.readListBegin(); 516 | struct.vocab = new ArrayList(_list0.size); 517 | for (int _i1 = 0; _i1 < _list0.size; ++_i1) 518 | { 519 | String _elem2; 520 | _elem2 = iprot.readString(); 521 | struct.vocab.add(_elem2); 522 | } 523 | iprot.readListEnd(); 524 | } 525 | struct.setVocabIsSet(true); 526 | } else { 527 | org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); 528 | } 529 | break; 530 | case 2: // LAYER_SIZE 531 | if (schemeField.type == org.apache.thrift.protocol.TType.I32) { 532 | struct.layerSize = iprot.readI32(); 533 | struct.setLayerSizeIsSet(true); 534 | } else { 535 | org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); 536 | } 537 | break; 538 | case 3: // VECTORS 539 | if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { 540 | { 541 | org.apache.thrift.protocol.TList _list3 = 
iprot.readListBegin(); 542 | struct.vectors = new ArrayList(_list3.size); 543 | for (int _i4 = 0; _i4 < _list3.size; ++_i4) 544 | { 545 | double _elem5; 546 | _elem5 = iprot.readDouble(); 547 | struct.vectors.add(_elem5); 548 | } 549 | iprot.readListEnd(); 550 | } 551 | struct.setVectorsIsSet(true); 552 | } else { 553 | org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); 554 | } 555 | break; 556 | default: 557 | org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); 558 | } 559 | iprot.readFieldEnd(); 560 | } 561 | iprot.readStructEnd(); 562 | 563 | // check for required fields of primitive type, which can't be checked in the validate method 564 | struct.validate(); 565 | } 566 | 567 | public void write(org.apache.thrift.protocol.TProtocol oprot, Word2VecModelThrift struct) throws org.apache.thrift.TException { 568 | struct.validate(); 569 | 570 | oprot.writeStructBegin(STRUCT_DESC); 571 | if (struct.vocab != null) { 572 | if (struct.isSetVocab()) { 573 | oprot.writeFieldBegin(VOCAB_FIELD_DESC); 574 | { 575 | oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, struct.vocab.size())); 576 | for (String _iter6 : struct.vocab) 577 | { 578 | oprot.writeString(_iter6); 579 | } 580 | oprot.writeListEnd(); 581 | } 582 | oprot.writeFieldEnd(); 583 | } 584 | } 585 | if (struct.isSetLayerSize()) { 586 | oprot.writeFieldBegin(LAYER_SIZE_FIELD_DESC); 587 | oprot.writeI32(struct.layerSize); 588 | oprot.writeFieldEnd(); 589 | } 590 | if (struct.vectors != null) { 591 | if (struct.isSetVectors()) { 592 | oprot.writeFieldBegin(VECTORS_FIELD_DESC); 593 | { 594 | oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.DOUBLE, struct.vectors.size())); 595 | for (double _iter7 : struct.vectors) 596 | { 597 | oprot.writeDouble(_iter7); 598 | } 599 | oprot.writeListEnd(); 600 | } 601 | oprot.writeFieldEnd(); 602 | } 603 | } 604 | oprot.writeFieldStop(); 605 | 
oprot.writeStructEnd(); 606 | } 607 | 608 | } 609 | 610 | private static class Word2VecModelThriftTupleSchemeFactory implements SchemeFactory { 611 | public Word2VecModelThriftTupleScheme getScheme() { 612 | return new Word2VecModelThriftTupleScheme(); 613 | } 614 | } 615 | 616 | private static class Word2VecModelThriftTupleScheme extends TupleScheme { 617 | 618 | @Override 619 | public void write(org.apache.thrift.protocol.TProtocol prot, Word2VecModelThrift struct) throws org.apache.thrift.TException { 620 | TTupleProtocol oprot = (TTupleProtocol) prot; 621 | BitSet optionals = new BitSet(); 622 | if (struct.isSetVocab()) { 623 | optionals.set(0); 624 | } 625 | if (struct.isSetLayerSize()) { 626 | optionals.set(1); 627 | } 628 | if (struct.isSetVectors()) { 629 | optionals.set(2); 630 | } 631 | oprot.writeBitSet(optionals, 3); 632 | if (struct.isSetVocab()) { 633 | { 634 | oprot.writeI32(struct.vocab.size()); 635 | for (String _iter8 : struct.vocab) 636 | { 637 | oprot.writeString(_iter8); 638 | } 639 | } 640 | } 641 | if (struct.isSetLayerSize()) { 642 | oprot.writeI32(struct.layerSize); 643 | } 644 | if (struct.isSetVectors()) { 645 | { 646 | oprot.writeI32(struct.vectors.size()); 647 | for (double _iter9 : struct.vectors) 648 | { 649 | oprot.writeDouble(_iter9); 650 | } 651 | } 652 | } 653 | } 654 | 655 | @Override 656 | public void read(org.apache.thrift.protocol.TProtocol prot, Word2VecModelThrift struct) throws org.apache.thrift.TException { 657 | TTupleProtocol iprot = (TTupleProtocol) prot; 658 | BitSet incoming = iprot.readBitSet(3); 659 | if (incoming.get(0)) { 660 | { 661 | org.apache.thrift.protocol.TList _list10 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32()); 662 | struct.vocab = new ArrayList(_list10.size); 663 | for (int _i11 = 0; _i11 < _list10.size; ++_i11) 664 | { 665 | String _elem12; 666 | _elem12 = iprot.readString(); 667 | struct.vocab.add(_elem12); 668 | } 669 | } 670 | 
struct.setVocabIsSet(true); 671 | } 672 | if (incoming.get(1)) { 673 | struct.layerSize = iprot.readI32(); 674 | struct.setLayerSizeIsSet(true); 675 | } 676 | if (incoming.get(2)) { 677 | { 678 | org.apache.thrift.protocol.TList _list13 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.DOUBLE, iprot.readI32()); 679 | struct.vectors = new ArrayList(_list13.size); 680 | for (int _i14 = 0; _i14 < _list13.size; ++_i14) 681 | { 682 | double _elem15; 683 | _elem15 = iprot.readDouble(); 684 | struct.vectors.add(_elem15); 685 | } 686 | } 687 | struct.setVectorsIsSet(true); 688 | } 689 | } 690 | } 691 | 692 | } 693 | 694 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/AC.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | /** Extension of {@link AutoCloseable} where {@link #close()} does not throw any exception */ 4 | public interface AC extends AutoCloseable { 5 | @Override void close(); 6 | 7 | /** {@link AC} that does nothing */ 8 | AC NOTHING = new AC() { 9 | @Override public void close() { } 10 | }; 11 | } -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/AutoLog.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | import com.google.common.base.Preconditions; 4 | import org.apache.commons.logging.Log; 5 | import org.apache.commons.logging.LogFactory; 6 | import org.apache.log4j.Logger; 7 | import org.apache.log4j.varia.NullAppender; 8 | 9 | /** 10 | * Creates loggers based on the caller's class. 11 | */ 12 | public final class AutoLog { 13 | /** Prevents initialization. 
*/ 14 | private AutoLog() { 15 | } 16 | 17 | /** @return {@link org.apache.commons.logging.Log} based on the caller's class */ 18 | public static Log getLog() { 19 | return getLog(2); 20 | } 21 | 22 | /** Make sure there is at least one appender to avoid a warning printed on stderr */ 23 | private static class InitializeOnDemand { 24 | private static final boolean INIT = init(); 25 | private static boolean init() { 26 | if (!Logger.getRootLogger().getAllAppenders().hasMoreElements()) 27 | Logger.getRootLogger().addAppender(new NullAppender()); 28 | return true; 29 | } 30 | } 31 | 32 | /** @return {@link org.apache.commons.logging.Log} based on the stacktrace distance to 33 | * the original caller. 1= the caller to this method. 2 = the caller to the caller... etc*/ 34 | public static Log getLog(int distance) { 35 | Preconditions.checkState(InitializeOnDemand.INIT); 36 | String callerClassName = Common.myCaller(distance).getClassName(); 37 | try { 38 | return LogFactory.getLog(Class.forName(callerClassName)); 39 | } catch (ClassNotFoundException t) { 40 | String err = "Class.forName on " + callerClassName + " failed"; 41 | System.err.println(err); 42 | throw new IllegalStateException(err, t); 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/CallableVoid.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | import java.util.concurrent.Callable; 4 | 5 | /** Utility base implementation of Callable with a Void return type. 
*/ 6 | public abstract class CallableVoid implements Callable { 7 | 8 | @Override public final Void call() throws Exception { 9 | run(); 10 | return null; 11 | } 12 | 13 | /** Do the actual work here instead of using {@link #call()} */ 14 | protected abstract void run() throws Exception; 15 | 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/Common.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | import org.apache.commons.io.FilenameUtils; 4 | import org.apache.commons.io.IOUtils; 5 | 6 | import java.io.BufferedReader; 7 | import java.io.ByteArrayOutputStream; 8 | import java.io.File; 9 | import java.io.FileInputStream; 10 | import java.io.FileNotFoundException; 11 | import java.io.IOException; 12 | import java.io.InputStream; 13 | import java.io.ObjectOutputStream; 14 | import java.io.Reader; 15 | import java.io.Serializable; 16 | import java.io.StringWriter; 17 | import java.net.URL; 18 | import java.util.ArrayList; 19 | import java.util.Collections; 20 | import java.util.List; 21 | import java.util.zip.GZIPInputStream; 22 | 23 | /** 24 | * Simple utilities that in no way deserve their own class. 25 | */ 26 | public class Common { 27 | /** 28 | * @param distance use 1 for our caller, 2 for their caller, etc... 
29 | * @return the stack trace element from where the calling method was invoked 30 | */ 31 | public static StackTraceElement myCaller(int distance) { 32 | // 0 here, 1 our caller, 2 their caller 33 | int index = distance + 1; 34 | try { 35 | StackTraceElement st[] = new Throwable().getStackTrace(); 36 | // hack: skip synthetic caster methods 37 | if (st[index].getLineNumber() == 1) return st[index + 1]; 38 | return st[index]; 39 | } catch (Throwable t) { 40 | return new StackTraceElement("[unknown]","-","-",0); 41 | } 42 | } 43 | 44 | /** Serialize the given object into the given stream */ 45 | public static void serialize(Serializable obj, ByteArrayOutputStream bout) { 46 | try { 47 | ObjectOutputStream out = new ObjectOutputStream(bout); 48 | out.writeObject(obj); 49 | out.close(); 50 | } catch (IOException e) { 51 | throw new IllegalStateException("Could not serialize " + obj, e); 52 | } 53 | } 54 | 55 | /** 56 | * Read the file line for line and return the result in a list 57 | * @throws IOException upon failure in reading, note that we wrap the underlying IOException with the file name 58 | */ 59 | public static List readToList(File f) throws IOException { 60 | try (final Reader reader = asReaderUTF8Lenient(new FileInputStream(f))) { 61 | return readToList(reader); 62 | } catch (IOException ioe) { 63 | throw new IllegalStateException(String.format("Failed to read %s: %s", f.getAbsolutePath(), ioe), ioe); 64 | } 65 | } 66 | /** Read the Reader line for line and return the result in a list */ 67 | public static List readToList(Reader r) throws IOException { 68 | try ( BufferedReader in = new BufferedReader(r) ) { 69 | List l = new ArrayList<>(); 70 | String line = null; 71 | while ((line = in.readLine()) != null) 72 | l.add(line); 73 | return Collections.unmodifiableList(l); 74 | } 75 | } 76 | 77 | /** Wrap the InputStream in a Reader that reads UTF-8. Invalid content will be replaced by unicode replacement glyph. 
*/ 78 | public static Reader asReaderUTF8Lenient(InputStream in) { 79 | return new UnicodeReader(in, "utf-8"); 80 | } 81 | 82 | /** Read the contents of the given file into a string */ 83 | public static String readFileToString(File f) throws IOException { 84 | StringWriter sw = new StringWriter(); 85 | IO.copyAndCloseBoth(Common.asReaderUTF8Lenient(new FileInputStream(f)), sw); 86 | return sw.toString(); 87 | } 88 | 89 | /** @return true if i is an even number */ 90 | public static boolean isEven(int i) { return (i&1)==0; } 91 | /** @return true if i is an odd number */ 92 | public static boolean isOdd(int i) { return !isEven(i); } 93 | 94 | /** Read the lines (as UTF8) of the resource file fn from the package of the given class into a (unmodifiable) list of strings 95 | * @throws IOException */ 96 | public static List readResource(Class clazz, String fn) throws IOException { 97 | try (final Reader reader = asReaderUTF8Lenient(getResourceAsStream(clazz, fn))) { 98 | return readToList(reader); 99 | } 100 | } 101 | 102 | /** Get an input stream to read the raw contents of the given resource, remember to close it :) */ 103 | public static InputStream getResourceAsStream(Class clazz, String fn) throws IOException { 104 | InputStream stream = clazz.getResourceAsStream(fn); 105 | if (stream == null) { 106 | throw new IOException("resource \"" + fn + "\" relative to " + clazz + " not found."); 107 | } 108 | return unpackStream(stream, fn); 109 | } 110 | 111 | /** Get a file to read the raw contents of the given resource :) */ 112 | public static File getResourceAsFile(Class clazz, String fn) throws IOException { 113 | URL url = clazz.getResource(fn); 114 | if (url == null || url.getFile() == null) { 115 | throw new IOException("resource \"" + fn + "\" relative to " + clazz + " not found."); 116 | } 117 | return new File(url.getFile()); 118 | } 119 | 120 | /** 121 | * @throws IOException if {@code is} is null or if an {@link IOException} is thrown when reading from {@code 
is} 122 | */ 123 | public static InputStream unpackStream(InputStream is, String fn) throws IOException { 124 | if (is == null) 125 | throw new FileNotFoundException("InputStream is null for " + fn); 126 | 127 | switch (FilenameUtils.getExtension(fn).toLowerCase()) { 128 | case "gz": 129 | return new GZIPInputStream(is); 130 | default: 131 | return is; 132 | } 133 | } 134 | 135 | /** Read the lines (as UTF8) of the resource file fn from the package of the given class into a string */ 136 | public static String readResourceToStringChecked(Class clazz, String fn) throws IOException { 137 | try (InputStream stream = getResourceAsStream(clazz, fn)) { 138 | return IOUtils.toString(asReaderUTF8Lenient(stream)); 139 | } 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/Compare.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | /** Utility class for general comparison and equality operations. */ 4 | public class Compare { 5 | /** 6 | * {@link NullPointerException} safe compare method; nulls are less than non-nulls. 7 | */ 8 | public static > int compare(X x1, X x2) { 9 | if (x1 == null) return x2 == null ? 0 : -1; 10 | return x2 == null ? 1 : x1.compareTo(x2); 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/FileUtils.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | import java.io.File; 4 | import java.io.FileFilter; 5 | import java.io.IOException; 6 | import java.nio.file.Paths; 7 | import java.util.Arrays; 8 | import java.util.List; 9 | import java.util.UUID; 10 | 11 | import com.google.common.base.Function; 12 | 13 | import com.google.common.base.Strings; 14 | 15 | /** 16 | * Collection of file-related utilities. 
17 | */ 18 | public final class FileUtils { 19 | 20 | public static final int ONE_KB = 1<<10; 21 | public static final int ONE_MB = 1<<20; 22 | 23 | public static final Function FILE_TO_NAME = new Function() { 24 | @Override public String apply(File file) { 25 | return file.getName(); 26 | } 27 | }; 28 | 29 | /** 30 | * Returns a subdirectory of a given directory; the subdirectory is expected to already exist. 31 | * @param parent the directory in which to find the specified subdirectory 32 | * @param item the name of the subdirectory 33 | * @return the subdirectory having the specified name; null if no such directory exists or 34 | * exists but is a regular file. 35 | */ 36 | public static File getDir(File parent, String item) { 37 | File dir = new File(parent, item); 38 | return (dir.exists() && dir.isDirectory()) ? dir : null; 39 | } 40 | 41 | /** @return File for the specified directory; creates the directory if necessary. */ 42 | private static File getOrCreateDir(File f) throws IOException { 43 | if (!f.exists()) { 44 | if (!f.mkdirs()) { 45 | throw new IOException(f.getName() + ": Unable to create directory: " + f.getAbsolutePath()); 46 | } 47 | } else if (!f.isDirectory()) { 48 | throw new IOException(f.getName() + ": Exists and is not a directory: " + f.getAbsolutePath()); 49 | } 50 | return f; 51 | } 52 | 53 | /** 54 | * @return File for the specified directory, creates dirName directory if necessary. 55 | * @throws IOException if there is an error during directory creation or if a non-directory file with the desired name already 56 | * exists. 57 | */ 58 | public static File getOrCreateDir(String dirName) throws IOException { 59 | return getOrCreateDir(new File(dirName)); 60 | } 61 | 62 | /** 63 | * @return File for the specified directory; parent must already exist, creates dirName subdirectory if necessary. 64 | * @throws IOException if there is an error during directory creation or if a non-directory file with the desired name already 65 | * exists. 
66 | */ 67 | public static File getOrCreateDir(File parent, String dirName) throws IOException { 68 | return getOrCreateDir(new File(parent, dirName)); 69 | } 70 | 71 | /** 72 | * @return File for the specified directory; parent must already exist, creates all intermediate dirNames subdirectories if necessary. 73 | * @throws IOException if there is an error during directory creation or if a non-directory file with the desired name already 74 | * exists. 75 | */ 76 | public static File getOrCreateDir(File parent, String ... dirNames) throws IOException { 77 | return getOrCreateDir(Paths.get(parent.getPath(), dirNames).toFile()); 78 | } 79 | 80 | /** 81 | * Deletes a file or directory. 82 | * If the file is a directory it recursively deletes it. 83 | * @param file file to be deleted 84 | * @return true if all the files where deleted successfully. 85 | */ 86 | public static boolean deleteRecursive(final File file) { 87 | boolean result = true; 88 | if (file.isDirectory()) { 89 | for (final File inner : file.listFiles()) { 90 | result &= deleteRecursive(inner); 91 | } 92 | } 93 | return result & file.delete(); 94 | } 95 | 96 | /** Utility class; don't instantiate. */ 97 | private FileUtils() { 98 | throw new AssertionError("Do not instantiate."); 99 | } 100 | 101 | /** @return A random temporary folder that can be used for file-system operations testing */ 102 | public static File getRandomTemporaryFolder(String prefix, String suffix) { 103 | return new File(System.getProperty("java.io.tmpdir"), Strings.nullToEmpty(prefix) + UUID.randomUUID().toString() + Strings.nullToEmpty(suffix)); 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/Format.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | /** 4 | * Created by yibin on 2/2/15. 
5 | */ 6 | public class Format { 7 | /** @see {@link Strings#formatEnum(Enum)} */ 8 | public static String formatEnum(Enum enumValue) { 9 | return Strings.formatEnum(enumValue); 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/IO.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | import com.google.common.base.Function; 4 | import com.google.common.base.Preconditions; 5 | import org.apache.commons.io.IOUtils; 6 | 7 | import java.io.ByteArrayInputStream; 8 | import java.io.ByteArrayOutputStream; 9 | import java.io.Closeable; 10 | import java.io.File; 11 | import java.io.FileOutputStream; 12 | import java.io.IOException; 13 | import java.io.InputStream; 14 | import java.io.ObjectInputStream; 15 | import java.io.ObjectOutputStream; 16 | import java.io.ObjectStreamClass; 17 | import java.io.OutputStream; 18 | import java.io.Reader; 19 | import java.io.Writer; 20 | import java.lang.reflect.Field; 21 | import java.util.Collection; 22 | import java.util.Collections; 23 | import java.util.Comparator; 24 | import java.util.zip.GZIPInputStream; 25 | import java.util.zip.GZIPOutputStream; 26 | 27 | /** 28 | * Static utility functions related to IO. 29 | */ 30 | public final class IO { 31 | 32 | /** a Comparator that orders {@link File} objects as by the last modified date */ 33 | public static final Comparator FILE_LAST_MODIFIED_COMPARATOR = new Comparator() { 34 | @Override public int compare(File o1, File o2) { 35 | return Long.compare(o1.lastModified(), o2.lastModified()); 36 | }}; 37 | 38 | private static final int DEFAULT_BUFFER_SIZE = 1024 * 8; // 8K. 
Same as BufferedOutputStream, so writes are written-through with no extra copying 39 | 40 | private IO() { } 41 | 42 | /** @return the oldest file in the given collection using the last modified date or return null if the collection is empty */ 43 | public static File getOldestFileOrNull(Collection files) { 44 | if (files.isEmpty()) 45 | return null; 46 | 47 | return Collections.min(files, FILE_LAST_MODIFIED_COMPARATOR); 48 | } 49 | 50 | /** Copy input to output and close the output stream before returning */ 51 | public static long copyAndCloseOutput(InputStream input, OutputStream output) throws IOException { 52 | try (OutputStream outputStream = output) { 53 | return copy(input, outputStream); 54 | } 55 | } 56 | 57 | /** Copy input to output and close both the input and output streams before returning */ 58 | public static long copyAndCloseBoth(InputStream input, OutputStream output) throws IOException { 59 | try (InputStream inputStream = input) { 60 | return copyAndCloseOutput(inputStream, output); 61 | } 62 | } 63 | 64 | /** Similar to {@link IOUtils#toByteArray(InputStream)} but closes the stream. 
*/ 65 | public static byte[] toByteArray(InputStream is) throws IOException { 66 | ByteArrayOutputStream bao = new ByteArrayOutputStream(); 67 | IO.copyAndCloseBoth(is, bao); 68 | return bao.toByteArray(); 69 | } 70 | 71 | /** Copy input to output; neither stream is closed */ 72 | public static long copy(InputStream input, OutputStream output) throws IOException { 73 | byte[] buffer = new byte[DEFAULT_BUFFER_SIZE]; 74 | long count = 0; 75 | int n = 0; 76 | while (-1 != (n = input.read(buffer))) { 77 | output.write(buffer, 0, n); 78 | count += n; 79 | } 80 | return count; 81 | } 82 | 83 | /** Copy input to output; neither stream is closed */ 84 | public static int copy(Reader input, Writer output) throws IOException { 85 | char[] buffer = new char[DEFAULT_BUFFER_SIZE]; 86 | int count = 0; 87 | int n = 0; 88 | while (-1 != (n = input.read(buffer))) { 89 | output.write(buffer, 0, n); 90 | count += n; 91 | } 92 | return count; 93 | } 94 | 95 | /** Copy input to output and close the output stream before returning */ 96 | public static int copyAndCloseOutput(Reader input, Writer output) throws IOException { 97 | try { 98 | return copy(input, output); 99 | } finally { 100 | output.close(); 101 | } 102 | 103 | } 104 | /** Copy input to output and close both the input and output streams before returning */ 105 | public static int copyAndCloseBoth(Reader input, Writer output) throws IOException { 106 | try { 107 | return copyAndCloseOutput(input, output); 108 | } finally { 109 | input.close(); 110 | } 111 | } 112 | 113 | /** 114 | * Copy the data from the given {@link InputStream} to a temporary file and call the given 115 | * {@link Function} with it; after the function returns the file is deleted. 
116 | */ 117 | public static X runWithFile(InputStream stream, Function function) throws IOException { 118 | File f = File.createTempFile("run-with-file", null); 119 | try { 120 | try (FileOutputStream out = new FileOutputStream(f)) { 121 | IOUtils.copy(stream, out); 122 | } 123 | return function.apply(f); 124 | } finally { 125 | f.delete(); 126 | } 127 | } 128 | 129 | /** @return the compressed (gzip) version of the given bytes */ 130 | public static byte[] gzip(byte[] in) { 131 | try { 132 | ByteArrayOutputStream bos = new ByteArrayOutputStream(); 133 | GZIPOutputStream gz = new GZIPOutputStream(bos); 134 | gz.write(in); 135 | gz.close(); 136 | return bos.toByteArray(); 137 | } catch (IOException e) { 138 | throw new RuntimeException("Failed to compress bytes", e); 139 | } 140 | } 141 | 142 | /** @return the decompressed version of the given (compressed with gzip) bytes */ 143 | public static byte[] gunzip(byte[] in) { 144 | try { 145 | GZIPInputStream gis = new GZIPInputStream(new ByteArrayInputStream(in)); 146 | ByteArrayOutputStream bos = new ByteArrayOutputStream(); 147 | IO.copyAndCloseOutput(gis, bos); 148 | return bos.toByteArray(); 149 | } catch (IOException e) { 150 | throw new RuntimeException("Failed to decompress data", e); 151 | } 152 | } 153 | 154 | /** @return the compressed (gzip) version of the given object */ 155 | public static byte[] gzipObject(Object object) { 156 | try ( 157 | ByteArrayOutputStream bos = new ByteArrayOutputStream(); 158 | GZIPOutputStream zos = new GZIPOutputStream(bos); 159 | ObjectOutputStream oos = new ObjectOutputStream(zos)) 160 | { 161 | oos.writeObject(object); 162 | zos.close(); // Terminate gzip 163 | return bos.toByteArray(); 164 | } catch (IOException e) { 165 | throw new RuntimeException("Failed to compress bytes", e); 166 | } 167 | } 168 | 169 | /** @return the decompressed version of the given (compressed with gzip) bytes */ 170 | public static Object gunzipObject(byte[] in) { 171 | try { 172 | return new 
ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(in))).readObject(); 173 | } catch (IOException | ClassNotFoundException e) { 174 | throw new RuntimeException("Failed to decompress data", e); 175 | } 176 | } 177 | 178 | /** ObjectInputStream which doesn't care too much about serialVersionUIDs. Horrible :) 179 | */ 180 | public static ObjectInputStream gullibleObjectInputStream(InputStream is) throws IOException { 181 | return new ObjectInputStream(is) { 182 | @Override 183 | protected ObjectStreamClass readClassDescriptor() throws IOException, ClassNotFoundException { 184 | ObjectStreamClass oc = super.readClassDescriptor(); 185 | try { 186 | Class c = Class.forName(oc.getName()); 187 | // interfaces do not have fields 188 | if (!c.isInterface()) { 189 | Field f = oc.getClass().getDeclaredField("suid"); 190 | f.setAccessible(true); 191 | f.set(oc, ObjectStreamClass.lookup(c).getSerialVersionUID()); 192 | } 193 | } catch (Exception e) { 194 | System.err.println("Couldn't fake class descriptor for "+oc+": "+ e); 195 | } 196 | return oc; 197 | } 198 | }; 199 | } 200 | 201 | /** Close, ignoring exceptions */ 202 | public static void close(Closeable stream) { 203 | try { 204 | stream.close(); 205 | } catch (IOException e) { 206 | } 207 | } 208 | 209 | /** 210 | * Creates a directory if it does not exist. 211 | * 212 | *

213 | * It will recursively create all the absolute path specified in the input 214 | * parameter. 215 | * 216 | * @param directory 217 | * the {@link File} containing the directory structure that is 218 | * going to be created. 219 | * @return a {@link File} pointing to the created directory 220 | * @throws IOException if there was en error while creating the directory. 221 | */ 222 | public static File createDirIfNotExists(File directory) throws IOException { 223 | if (!directory.isDirectory()) { 224 | if (!directory.mkdirs()) { 225 | throw new IOException("Failed to create directory: " + directory.getAbsolutePath()); 226 | } 227 | } 228 | return directory; 229 | } 230 | 231 | /** 232 | * @return true if the {@link File} is null, does not exist, or we were able to delete it; false 233 | * if the file exists and could not be deleted. 234 | */ 235 | public static boolean deleteIfPresent(File f) { 236 | return f == null || !f.exists() || f.delete(); 237 | } 238 | 239 | 240 | /** 241 | * Stores the given contents into a temporary file 242 | * @param fileContents the raw contents to store in the temporary file 243 | * @param namePrefix the desired file name prefix (must be at least 3 characters long) 244 | * @param extension the desired extension including the '.' 
character (use null for '.tmp') 245 | * @return a {@link File} reference to the newly created temporary file 246 | * @throws IOException if the temporary file creation fails 247 | */ 248 | public static File createTempFile(byte[] fileContents, String namePrefix, String extension) throws IOException { 249 | Preconditions.checkNotNull(fileContents, "file contents missing"); 250 | File tempFile = File.createTempFile(namePrefix, extension); 251 | try (FileOutputStream fos = new FileOutputStream(tempFile)) { 252 | fos.write(fileContents); 253 | } 254 | return tempFile; 255 | } 256 | } 257 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/NDC.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | 4 | /** Helper to create {@link org.apache.log4j.NDC} for nested diagnostic contexts */ 5 | public class NDC implements AC { 6 | private final int size; 7 | 8 | /** Push all the contexts given and pop them when auto-closed */ 9 | public static NDC push(String... context) { 10 | return new NDC(context); 11 | } 12 | 13 | /** Construct an {@link AutoCloseable} {@link NDC} with the given contexts */ 14 | private NDC(String... 
context) { 15 | for (String c : context) { 16 | org.apache.log4j.NDC.push("[" + c + "]"); 17 | } 18 | this.size = context.length; 19 | } 20 | 21 | @Override 22 | public void close() { 23 | for (int i = 0; i < size; i++) { 24 | org.apache.log4j.NDC.pop(); 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/Pair.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | import com.google.common.base.Function; 4 | import com.google.common.base.Predicate; 5 | import com.google.common.base.Predicates; 6 | import com.google.common.collect.FluentIterable; 7 | import com.google.common.collect.ImmutableList; 8 | import com.google.common.collect.ImmutableSet; 9 | import com.google.common.collect.Iterables; 10 | import com.google.common.collect.Lists; 11 | import com.google.common.collect.Ordering; 12 | import com.google.common.collect.Sets; 13 | 14 | import java.io.Serializable; 15 | import java.util.ArrayList; 16 | import java.util.Collection; 17 | import java.util.Comparator; 18 | import java.util.HashSet; 19 | import java.util.Iterator; 20 | import java.util.List; 21 | import java.util.Map; 22 | import java.util.Objects; 23 | import java.util.Set; 24 | 25 | /** 26 | * Simple class for storing two arbitrary objects in one. 27 | * 28 | * @param the type of the first value 29 | * @param the type of the second value 30 | */ 31 | public class Pair implements Map.Entry, Serializable { 32 | 33 | /** @see Serializable */ 34 | private static final long serialVersionUID = 1L; 35 | 36 | /** The first item in the pair. */ 37 | public final K first; 38 | /** The second item in the pair. 
*/ 39 | public final V second; 40 | 41 | /** Creates a new instance of Pair */ 42 | protected Pair(K first, V second) { 43 | this.first = first; 44 | this.second = second; 45 | } 46 | 47 | /** Type-inferring constructor */ 48 | public static Pair cons(X x, Y y) { return new Pair(x,y); } 49 | 50 | /** Type-inferring constructor for pairs of the same type, which can optionally be swapped */ 51 | public static Pair cons(X x, X y, boolean swapped) { return swapped ? new Pair(y, x) : new Pair(x, y); } 52 | 53 | @Override 54 | public int hashCode() { 55 | // Compute by hand instead of using Encoding.combineHashes for improved performance 56 | return (first == null ? 0 : first.hashCode() * 13) + (second == null ? 0 : second.hashCode() * 17); 57 | } 58 | 59 | @Override 60 | public boolean equals(Object o) { 61 | if (o == this) 62 | return true; 63 | if (o == null || !getClass().equals(o.getClass())) 64 | return false; 65 | 66 | Pair op = (Pair) o; 67 | return Objects.equals(op.first, first) && Objects.equals(op.second, second); 68 | } 69 | 70 | /** @return {@link #first}; needed because String Templates have 'first' as a reserved word. */ 71 | public K getOne() { return first; } 72 | /** @return {@link #first} */ 73 | public K getFirst() { return first; } 74 | /** @return {@link #second} */ 75 | public V getSecond() { return second; } 76 | 77 | @Override public String toString() { 78 | return "Pair<"+first+","+second+">"; 79 | } 80 | 81 | /** @return a list with the two elements from this pair, regardless of whether they are null or not */ 82 | public List asList() { 83 | return Lists.newArrayList(first, second); 84 | } 85 | 86 | /** 87 | * @return a list of key/value pairs (keys are at even indices, values at odd) taken from the 88 | * given array, whose length must be even. 89 | */ 90 | @SafeVarargs 91 | public static List> fromPairs(X... 
args) { 92 | if (Common.isOdd(args.length)) 93 | throw new IllegalArgumentException("Array length must be even: " + args.length); 94 | 95 | List> l = new ArrayList<>(args.length / 2); 96 | for (int i = 0; i < args.length; i += 2) 97 | l.add(Pair.cons(args[i], args[i + 1])); 98 | return l; 99 | } 100 | 101 | /** 102 | * Converts a Map to a List of pairs. 103 | * Each entry in the map results in a Pair in the returned list. 104 | */ 105 | public static List> fromMap(Map m) { 106 | List> l = new ArrayList<>(); 107 | for (Map.Entry me : m.entrySet()) { 108 | l.add(Pair.cons(me.getKey(), me.getValue())); 109 | } 110 | return l; 111 | } 112 | 113 | private static >> C fromMapFlatten(C c, Map> m) { 114 | for (Map.Entry> me : m.entrySet()) { 115 | for (Y y : me.getValue()) 116 | c.add(Pair.cons(me.getKey(), y)); 117 | } 118 | return c; 119 | } 120 | 121 | @Override public K getKey() { return first; } 122 | @Override public V getValue() { return second; } 123 | @Override public V setValue(V value) { throw new UnsupportedOperationException(); } 124 | 125 | /** Method that allows Pair to be used directly by the Setup system (wtf) */ 126 | public String getName() { return String.valueOf(second); } 127 | 128 | /** @return a reversed version of this pair */ 129 | public Pair swapped() { return Pair.cons(second, first); } 130 | 131 | /** @return {@link Function} which performs a {@link #swapped()} */ 132 | public static Function, Pair> swappedFunction() { 133 | return new Function, Pair>() { 134 | @Override public Pair apply(Pair p) { 135 | return p.swapped(); 136 | } 137 | }; 138 | } 139 | 140 | /** 141 | * @return {@link Ordering} which compares the first value of the pairs. 
142 | * Pairs with equal first value will be considered equivalent independent of the second value 143 | */ 144 | public static > Ordering> firstComparator() { 145 | return new Ordering>() { 146 | @Override public int compare(Pair o1, Pair o2) { 147 | return Compare.compare(o1.first, o2.first); 148 | } 149 | }; 150 | } 151 | 152 | /** 153 | * @return {@link Ordering} which compares the second value of the pairs. 154 | * Pairs with equal second value will be considered equivalent independent of the first value 155 | */ 156 | public static > Ordering> secondComparator() { 157 | return new Ordering>() { 158 | @Override public int compare(Pair o1, Pair o2) { 159 | return Compare.compare(o1.second, o2.second); 160 | } 161 | }; 162 | } 163 | 164 | /** @return {@link Ordering} which compares both values of the {@link Pair}s, with the first taking precedence. */ 165 | public static , Y extends Comparable> Ordering> firstThenSecondComparator() { 166 | return new Ordering>() { 167 | @Override public int compare(Pair o1, Pair o2) { 168 | int k = Compare.compare(o1.first, o2.first); 169 | if (k == 0) k = Compare.compare(o1.second, o2.second); 170 | return k; 171 | } 172 | }; 173 | } 174 | 175 | /** @return {@link Ordering} which compares both values of the {@link Pair}s, with the second taking precedence. 
*/ 176 | public static , Y extends Comparable> Ordering> secondThenFirstComparator() { 177 | return new Ordering>() { 178 | @Override public int compare(Pair o1, Pair o2) { 179 | int k = Compare.compare(o1.second, o2.second); 180 | if (k == 0) 181 | k = Compare.compare(o1.first, o2.first); 182 | return k; 183 | } 184 | }; 185 | } 186 | 187 | /** 188 | * Pair comparator that applies the given {@link Comparator} to the first value of the pairs 189 | */ 190 | public static Comparator> firstComparator(final Comparator comp) { 191 | return new Comparator>() { 192 | @Override public int compare(Pair o1, Pair o2) { 193 | return comp.compare(o1.first, o2.first); 194 | } 195 | }; 196 | } 197 | 198 | /** 199 | * Pair comparator that applies the given {@link Comparator} to the second value of the pairs 200 | */ 201 | public static Comparator> secondComparator(final Comparator comp) { 202 | return new Comparator>() { 203 | @Override public int compare(Pair o1, Pair o2) { 204 | return comp.compare(o1.second, o2.second); 205 | } 206 | }; 207 | } 208 | 209 | /** 210 | * Pair comparator that compares both values of the pairs, with the first taking 211 | * precedence; the order is reversed for the first value only. 212 | */ 213 | public static , Y extends Comparable> Comparator> bothFirstReversedComparator() { 214 | return new Comparator>() { 215 | @Override public int compare(Pair o1, Pair o2) { 216 | int k = Compare.compare(o2.first, o1.first); 217 | if (k == 0) k = Compare.compare(o1.second, o2.second); 218 | return k; 219 | } 220 | }; 221 | } 222 | 223 | private static Map fillMap(Map m, Iterable> pairs) { 224 | for (Pair p : pairs) { 225 | m.put(p.first, p.second); 226 | } 227 | return m; 228 | } 229 | 230 | /** @return the combination of all the elements in each collection. 
For instance if the first collection is 231 | * {@code [1, 2, 3]}, and the second one is {@code [a, b]}, then the result is {@code [(1, a), (1, b), (2, a), ...]} 232 | */ 233 | @SuppressWarnings("unchecked") 234 | public static List> cartesianProduct(Collection c1, Collection c2) { 235 | return FluentIterable.from(Sets.cartesianProduct(ImmutableSet.copyOf(c1), ImmutableSet.copyOf(c2))) 236 | .transform(new Function, Pair>() { 237 | @Override public Pair apply(List objs) { 238 | X x = (X) objs.get(0); 239 | Y y = (Y) objs.get(1); 240 | return Pair.cons(x, y); 241 | } 242 | }) 243 | .toList(); 244 | } 245 | 246 | /** @return the elements at equal indices in the two lists, which must be of the same 247 | * length, as pairs. 248 | */ 249 | public static List> zip(Collection c1, Collection c2) { 250 | return zip(c1, c2, new ArrayList>(c1.size()), false); 251 | } 252 | 253 | /** @return the elements at equal indices in the two lists, which must be of the same 254 | * length, as pairs. 255 | */ 256 | public static List> zip(X[] a1, Y[] a2) { 257 | return zip(ImmutableList.copyOf(a1), ImmutableList.copyOf(a2)); 258 | } 259 | 260 | /** 261 | * @return the elements at equal indices in the two lists, which must be of 262 | * the same length, as pairs, without duplicates removed from the 263 | * first list 264 | */ 265 | public static List> zipUnique(Collection c1, Collection c2) { 266 | return zip(c1, c2, new ArrayList>(), true); 267 | } 268 | 269 | private static List> zip(Collection c1, Collection c2, List> output, boolean uniqueKeys) { 270 | int size = c1.size(); 271 | if (size != c2.size()) 272 | throw new IllegalArgumentException("Collections must be of same size: " + size + ", " + c2.size()); 273 | 274 | Set set = uniqueKeys ? 
new HashSet() : null; 275 | Iterator it1 = c1.iterator(); 276 | Iterator it2 = c2.iterator(); 277 | 278 | while (it1.hasNext() && it2.hasNext()) { 279 | X x = it1.next(); 280 | Y y = it2.next(); 281 | if (set == null || set.add(x)) 282 | output.add(Pair.cons(x, y)); 283 | } 284 | return output; 285 | } 286 | 287 | /** 288 | * @return the elements at equal indices of the two list as pairs. The number of elements in the result list 289 | * is the minimum of the given iterable 290 | * of different size only elements at indices 291 | * present on the first {@link Iterable} are used. 292 | */ 293 | public static Iterable> zipInner(final Iterable first, final Iterable second) { 294 | return new Iterable>() { 295 | @Override public Iterator> iterator() { 296 | final Iterator x = first.iterator(); 297 | final Iterator y = second.iterator(); 298 | return new Iterator>() { 299 | @Override public boolean hasNext() { 300 | return x.hasNext() && y.hasNext(); 301 | } 302 | 303 | @Override 304 | public Pair next() { 305 | return Pair.cons(x.next(), y.next()); 306 | } 307 | 308 | @Override 309 | public void remove() { 310 | x.remove(); 311 | y.remove(); 312 | } 313 | 314 | }; 315 | } 316 | 317 | }; 318 | } 319 | 320 | /** @return {@link Function} which retrieves the second of the pair */ 321 | public static Function, V> retrieveSecondFunction() { 322 | return new Function, V>() { 323 | @Override 324 | public V apply(Pair p) { 325 | return p.second; 326 | } 327 | }; 328 | } 329 | 330 | /** @return {@link Iterable} of second element in pair */ 331 | public static Iterable unzipSecond(Iterable> pairs) { 332 | return Iterables.transform(pairs, Pair.retrieveSecondFunction()); 333 | } 334 | 335 | /** @return {@link Function} which maps the value of each pair through the given {@link Function} */ 336 | public static Function, Pair> mapValues(final Function func) { 337 | return new Function, Pair>() { 338 | @Override public Pair apply(Pair p) { 339 | return Pair.cons(p.first, 
func.apply(p.second)); 340 | } 341 | }; 342 | } 343 | 344 | /** @return the first value, or null if the pair is null */ 345 | public static K firstOrNull(Pair pair) { 346 | return pair != null ? pair.first : null; 347 | } 348 | 349 | /** @return the second value, or null if the pair is null */ 350 | public static V secondOrNull(Pair pair) { 351 | return pair != null ? pair.second : null; 352 | } 353 | 354 | /** @return {@link Predicates} which filters only on the first value */ 355 | public static Predicate> getFirstPredicate(final Predicate pred) { 356 | return new Predicate>() { 357 | @Override public boolean apply(Pair pair) { 358 | return pred.apply(pair.first); 359 | } 360 | }; 361 | } 362 | 363 | /** @return {@link Predicates} which filters only on the second value */ 364 | public static Predicate> getSecondPredicate(final Predicate pred) { 365 | return new Predicate>() { 366 | @Override public boolean apply(Pair pair) { 367 | return pred.apply(pair.second); 368 | } 369 | }; 370 | } 371 | 372 | /** @return {@link Predicates} which accepts a pair only if both values are accepted */ 373 | public static Predicate> getAndPredicate(final Predicate firstPred, final Predicate secondPred) { 374 | return new Predicate>() { 375 | @Override public boolean apply(Pair pair) { 376 | return firstPred.apply(pair.first) && secondPred.apply(pair.second); 377 | } 378 | }; 379 | } 380 | 381 | /** @return {@link Predicates} which accepts a pair if either values is accepted */ 382 | public static Predicate> getOrPredicate(final Predicate firstPred, final Predicate secondPred) { 383 | return new Predicate>() { 384 | @Override public boolean apply(Pair pair) { 385 | return firstPred.apply(pair.first) || secondPred.apply(pair.second); 386 | } 387 | }; 388 | } 389 | 390 | /** 391 | * @return {@link ImmutableList} containing all values paired with their applied value 392 | * through the function 393 | */ 394 | public static ImmutableList> toPairList(Iterable values, Function func) { 
395 | ImmutableList.Builder> result = ImmutableList.builder(); 396 | for (X x : values) 397 | result.add(Pair.cons(x, func.apply(x))); 398 | return result.build(); 399 | } 400 | 401 | } 402 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/ProfilingTimer.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | import com.google.common.base.Preconditions; 4 | import com.google.common.collect.Maps; 5 | import org.apache.commons.lang3.mutable.MutableInt; 6 | import org.apache.commons.logging.Log; 7 | import org.joda.time.Duration; 8 | import org.joda.time.Period; 9 | import org.joda.time.format.PeriodFormat; 10 | 11 | import java.io.ByteArrayOutputStream; 12 | import java.io.Serializable; 13 | import java.util.Map; 14 | 15 | /** 16 | * A timer utility that can be used to keep track of the execution time of a single-threaded task 17 | * composed of several subtasks. Subtasks may in turn be composed of other subtasks in a recursive 18 | * (i.e., a tree) fashion. 19 | * 20 | *

 * <p>
 * If a (sub) task has the same name and all of its parents up to the root share the same name, then
 * aggregate information is going to be shown for that task: total time it took, number of times it
 * was executed and average time.
 * <p>
 * It is strongly recommended that this class instances be used in a try-with-resources as the actual
 * writing of the information to the logs happens on {@link #close()} and abnormal termination may
 * otherwise prevent this method from being called.
 * <p>
 * This class is thread safe. However, activity tracked in threads other than the one that created
 * the {@link ProfilingTimer} will be ignored and won't be rolled up as a tree in combination with
 * the activity from other threads.
 * <p>
 * Example usage:
 * <pre>
 36 |  *  try (ProfilingTimer timer = ProfilingTimer.start(LOG, "processing file %s", file.getName())) {
 37 |  *  	// manually starting and finishing of a task
 38 |  *  	timer.start("uncompressing file");
 39 |  *  	// ... unzipping code here ...
 40 |  *  	timer.end();
 41 |  *
 42 |  *  	// alternatively, use a try-with-resources
 43 |  *  	try (AC ac = timer.start("decrypting file")) {
 44 |  *  		// subtasks are allowed
 45 |  *  		timer.start("analyzing public/private keys");
 46 |  *  		// ... GPG stuff here ...
 47 |  *
 48 |  *  		// convenience method for sibling tasks
 49 |  *  		timer.endAndStart("actual decryption");
 50 |  *  		// ... more GPG stuff here ...
 51 |  *  		timer.end();
 52 |  *  	}
 53 |  *  }
 54 |  *
 55 |  *  // at this point all the information is written to the log, e.g., as follows
 56 |  *  // [processing file example.txt]	total time 10s
 57 |  *  // [processing file example.txt]		[uncompressing file] took 300ms
 58 |  *  // [processing file example.txt]		[decrypting file] took 9s
 59 |  *  // [processing file example.txt]			[analyzing public/private keys] took 1s
 60 |  *  // [processing file example.txt]			[actual decryption] took 8s
 61 |  * 
62 | */ 63 | public class ProfilingTimer implements AC { 64 | 65 | /** 66 | * Just in case we need to disable this feature due to excessive logging 67 | */ 68 | public static volatile boolean enabled = true; 69 | 70 | /** 71 | * When this flag is enabled we only report data about the top-level activity 72 | */ 73 | public static volatile boolean topLevelInfoOnly = true; 74 | 75 | /** 76 | * Keeps information about a task within a {@link ProfilingTimer}. Since tasks can have multiple 77 | * subtasks, this represents a tree. 78 | */ 79 | public static class ProfilingTimerNode implements Serializable { 80 | private static final long serialVersionUID = 7464244055073290781L; 81 | 82 | private static final long CLOSED = -1; 83 | 84 | private final String taskName; 85 | private String logAppendMessage = ""; 86 | private ProfilingTimerNode parent; 87 | private final Map children = Maps.newLinkedHashMap(); 88 | private final Log log; 89 | 90 | private long start = System.nanoTime(); 91 | private long totalNanos; 92 | private long count; 93 | 94 | private ProfilingTimerNode(String taskName, ProfilingTimerNode parent, Log log) { 95 | this.taskName = taskName; 96 | if (parent != null) { 97 | parent.addChild(this); 98 | } 99 | this.log = log; 100 | } 101 | 102 | private void addChild(ProfilingTimerNode child) { 103 | if (child.parent != null) { 104 | throw new IllegalStateException(String.format("Child [%s] already belongs to parent [%s], can't be added to new parent [%s]", 105 | child.taskName, child.parent.taskName, taskName)); 106 | } 107 | 108 | child.parent = this; 109 | children.put(child.taskName, child); 110 | } 111 | 112 | private void stop() { 113 | if (start != CLOSED) { 114 | totalNanos += System.nanoTime() - start; 115 | count++; 116 | start = CLOSED; 117 | if (parent == null) { 118 | try (AC ac = NDC.push(taskName)) { 119 | log(0, log); 120 | } 121 | } 122 | } 123 | } 124 | 125 | private void appendToLog(String logAppendMessage) { 126 | this.logAppendMessage += 
logAppendMessage; 127 | } 128 | 129 | private void log(int level, Log log) { 130 | writeToLog(level, totalNanos, count, parent, taskName, log, logAppendMessage); 131 | 132 | for (ProfilingTimerNode child : children.values()) { 133 | child.log(level + 1, log); 134 | } 135 | } 136 | 137 | private void merge(ProfilingTimerNode other) { 138 | Preconditions.checkState(other.start == ProfilingTimerNode.CLOSED, "Can't merge non-closed node: %s", other.taskName); 139 | Preconditions.checkState(start == ProfilingTimerNode.CLOSED, "Can't merge into non-closed nodes: %s", taskName); 140 | 141 | totalNanos += other.totalNanos; 142 | count += other.count; 143 | } 144 | } 145 | 146 | /** 147 | * Null object pattern {@link ProfilingTimer} instance that does nothing at all 148 | */ 149 | public static final ProfilingTimer NONE = new ProfilingTimer(null, null, null) { 150 | @Override public AC start(String taskName, Object... args) { return AC.NOTHING; } 151 | @Override public void end() { } 152 | @Override public void close() { } 153 | }; 154 | 155 | private final Log log; 156 | private final ThreadLocal current = new ThreadLocal<>(); 157 | private final ByteArrayOutputStream serializationOutput; 158 | 159 | /** 160 | * Starts a new profiling timer with the given process name (optional arguments can be used as in {@link String#format(String, Object...)}). 161 | * When this {@link AC} is closed the profiling information will be dumped on the given log. 162 | * 163 | * Note that this method obeys the static {@link #topLevelInfoOnly} 164 | * 165 | *

166 | * It is highly recommended to use this in a try-with-resources block so that even if there's an abrupt termination of one of the tasks, 167 | * the {@link #close()} method will always be called. Otherwise the profiling information may not make it to the log. 168 | * 169 | *

170 | * Notice that this method may return {@link #NONE} if {@link #enabled} is false. 171 | */ 172 | public static ProfilingTimer create(final Log log, final String processName, final Object... args) { 173 | return create(log, topLevelInfoOnly, null, processName, args); 174 | } 175 | 176 | /** Same as {@link #create(Log, String, Object...)} but logs subtasks as well */ 177 | public static ProfilingTimer createLoggingSubtasks(final Log log, final String processName, final Object... args) { 178 | return create(log, false, null, processName, args); 179 | } 180 | 181 | /** Same as {@link #create(Log, String, Object...)} but includes subtasks, and instead of writing to a log, it outputs the tree in serialized form */ 182 | public static ProfilingTimer createSubtasksAndSerialization(ByteArrayOutputStream serializationOutput, final String processName, final Object... args) { 183 | return create(null, false, serializationOutput, processName, args); 184 | } 185 | 186 | private static ProfilingTimer create(final Log log, boolean topLevelInfoOnly, ByteArrayOutputStream serializationOutput, final String processName, final Object... args) { 187 | // do not use ternary as it creates an annoying resource leak warning 188 | if (enabled) 189 | if (topLevelInfoOnly) 190 | return new ProfilingTimer(null, null, null) { 191 | MutableInt level = new MutableInt(0); 192 | String logAppendMessage = ""; 193 | long startNanos = System.nanoTime(); 194 | @Override public AC start(String taskName, Object... 
args) { 195 | if (level != null) 196 | level.increment(); 197 | return new AC() { 198 | @Override public void close() { 199 | level.decrement(); 200 | } 201 | }; 202 | } 203 | @Override public void end() { 204 | level.decrement(); 205 | } 206 | @Override public void close() { 207 | if (startNanos != ProfilingTimerNode.CLOSED) { 208 | String taskName = String.format(processName, args); 209 | try (AC ac = NDC.push(taskName)) { 210 | writeToLog(0, System.nanoTime() - startNanos, 1, null, taskName, log, logAppendMessage); 211 | } 212 | startNanos = ProfilingTimerNode.CLOSED; 213 | } 214 | } 215 | @Override public void appendToLog(String logAppendMessage) { 216 | if (level.intValue() == 0) 217 | this.logAppendMessage += logAppendMessage; 218 | } 219 | }; 220 | else 221 | return new ProfilingTimer(log, serializationOutput, processName, args); 222 | else 223 | return NONE; 224 | } 225 | 226 | private ProfilingTimer(Log log, ByteArrayOutputStream serializationOutput, String processName, Object... args) { 227 | this.log = log; 228 | this.serializationOutput = serializationOutput; 229 | start(processName, args); 230 | } 231 | 232 | /** 233 | * Append the given string to the log message of the current subtask 234 | */ 235 | public void appendToLog(String logAppendMessage) { 236 | ProfilingTimerNode currentNode = current.get(); 237 | if (currentNode != null) { 238 | currentNode.appendToLog(logAppendMessage); 239 | } 240 | } 241 | 242 | /** 243 | * Indicates that a new task has started. Nested tasks are supported, so this method 244 | * can potentially be called various times in a row without invoking {@link #end()}. 245 | * 246 | *

247 | * Optionally, this method can be used in a try-with-resources block so that there is no 248 | * need to manually invoking {@link #end()} when the task at hand finishes. 249 | */ 250 | public AC start(String taskName, Object... args) { 251 | final ProfilingTimerNode parent = current.get(); 252 | current.set(findOrCreateNode(String.format(taskName, args), parent)); 253 | return new AC() { 254 | @Override public void close() { 255 | // return to the parent that we had when this AC was created 256 | current.set(parent); 257 | // close all the elements in the subtree under current 258 | if (parent != null) 259 | stopAll(parent); 260 | } 261 | private void stopAll(ProfilingTimerNode current) { 262 | for (ProfilingTimerNode child : current.children.values()) { 263 | stopAll(child); 264 | child.stop(); 265 | } 266 | } 267 | }; 268 | } 269 | 270 | /** 271 | * Indicates that the most recently initiated task (via {@link #start(String, Object...)}) is now finished 272 | */ 273 | public void end() { 274 | ProfilingTimerNode currentNode = current.get(); 275 | if (currentNode != null) { 276 | currentNode.stop(); 277 | current.set(currentNode.parent); 278 | } 279 | } 280 | 281 | /** 282 | * Convenience method for when a task starts right after the previous one finished. 283 | */ 284 | public void endAndStart(String taskName, Object... args) { 285 | end(); 286 | start(taskName, args); 287 | } 288 | 289 | @Override 290 | public void close() { 291 | ProfilingTimerNode root = current.get(); 292 | while (current.get() != null) { 293 | end(); 294 | } 295 | 296 | if (root != null && serializationOutput != null) { 297 | Common.serialize(root, serializationOutput); 298 | } 299 | } 300 | 301 | /** Merges the specified tree as a child under the current node. 
*/ 302 | public void mergeTree(ProfilingTimerNode otherRoot) { 303 | ProfilingTimerNode currentNode = current.get(); 304 | Preconditions.checkNotNull(currentNode); 305 | mergeOrAddNode(currentNode, otherRoot); 306 | } 307 | 308 | private void mergeOrAddNode(ProfilingTimerNode parent, ProfilingTimerNode child) { 309 | ProfilingTimerNode nodeToBeMerged = parent.children.get(child.taskName); 310 | if (nodeToBeMerged == null) { 311 | parent.addChild(child); 312 | return; 313 | } 314 | 315 | nodeToBeMerged.merge(child); 316 | for (ProfilingTimerNode grandchild : child.children.values()) { 317 | mergeOrAddNode(nodeToBeMerged, grandchild); 318 | } 319 | } 320 | 321 | private ProfilingTimerNode findOrCreateNode(String taskName, ProfilingTimerNode parent) { 322 | ProfilingTimerNode node = null; 323 | if (parent != null) { 324 | node = parent.children.get(taskName); 325 | if (node != null) { 326 | node.start = System.nanoTime(); 327 | } 328 | } 329 | if (node == null) { 330 | node = new ProfilingTimerNode(taskName, parent, log); 331 | } 332 | return node; 333 | } 334 | 335 | /** Writes one profiling line of information to the log */ 336 | private static void writeToLog(int level, long totalNanos, long count, ProfilingTimerNode parent, String taskName, Log log, String logAppendMessage) { 337 | if (log == null) { 338 | return; 339 | } 340 | 341 | StringBuilder sb = new StringBuilder(); 342 | for (int i = 0; i < level; i++) { 343 | sb.append('\t'); 344 | } 345 | String durationText = String.format("%s%s", 346 | formatElapsed(totalNanos), 347 | count == 1 ? 348 | "" : 349 | String.format(" across %d invocations, average: %s", count, formatElapsed(totalNanos / count))); 350 | String text = parent == null ? 
351 | String.format("total time %s", durationText) : 352 | String.format("[%s] took %s", taskName, durationText); 353 | sb.append(text); 354 | sb.append(logAppendMessage); 355 | log.info(sb.toString()); 356 | } 357 | 358 | /** @return a human-readable formatted string for the given amount of nanos */ 359 | private static String formatElapsed(long nanos) { 360 | return String.format("%s (%6.3g nanoseconds)", 361 | PeriodFormat.getDefault().print(Period.millis((int)(nanos / 1000))), 362 | (double) nanos); 363 | } 364 | 365 | } 366 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/Strings.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | /** 4 | * Various utility functions for working with String objects 5 | */ 6 | public class Strings { 7 | /** @see {@link Strings#formatEnum(Enum)} */ 8 | public static String formatEnum(Enum enumValue) { 9 | return capitalizeFirstCharacterLowercaseRest(enumValue.name().replace('_', ' ')); 10 | } 11 | 12 | private static String capitalizeFirstCharacterLowercaseRest(String s) { 13 | if (!hasContent(s)) return s; 14 | return s.substring(0, 1).toUpperCase() + s.substring(1).toLowerCase(); 15 | } 16 | 17 | /** @return true if the string is not null and has non-zero trimmed length; false otherwise */ 18 | public static boolean hasContent(String s) { 19 | return hasContent(s, true); 20 | } 21 | 22 | /** 23 | * @param trim true if the string should be trimmed 24 | * @return true if the string is not null and has non-zero trimmed length; false otherwise */ 25 | public static boolean hasContent(String s, boolean trim) { 26 | return s != null && !(trim ? s.trim() : s).isEmpty(); 27 | } 28 | 29 | /** 30 | * Join the toString of each object element into a single string, with 31 | * each element separated by the given sep (which can be empty). 
32 | */ 33 | public static String joinObjects(String sep, Iterable l) { 34 | return sepList(sep, l, -1); 35 | } 36 | 37 | /** Same as sepList with no wrapping */ 38 | public static String sepList(String sep, Iterable os, int max) { 39 | return sepList(sep, null, os, max); 40 | } 41 | 42 | /** @return The concatenation of toString of the objects obtained from the iterable, separated by sep, and if max 43 | * is > 0 include no more than that number of objects. If wrap is non-null, prepend and append each object with it 44 | */ 45 | public static String sepList(String sep, String wrap, Iterable os, int max) { 46 | StringBuilder sb = new StringBuilder(); 47 | String s = ""; 48 | if (max == 0) max = -1; 49 | for (Object o : os) { 50 | sb.append(s); s = sep; 51 | if (max-- == 0) { sb.append("..."); break; } 52 | if (wrap != null) sb.append(wrap); 53 | sb.append(o); 54 | if (wrap != null) sb.append(wrap); 55 | } 56 | return sb.toString(); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/ThriftUtils.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | import org.apache.thrift.TBase; 4 | import org.apache.thrift.TDeserializer; 5 | import org.apache.thrift.TException; 6 | import org.apache.thrift.TSerializer; 7 | import org.apache.thrift.protocol.TJSONProtocol; 8 | 9 | /** Contains useful methods for using Thrift */ 10 | public final class ThriftUtils { 11 | private static final String THRIFT_CHARSET = "utf-8"; 12 | 13 | /** Serialize a JSON-encoded thrift object */ 14 | public static String serializeJson(T obj) throws TException { 15 | // Tried having a static final serializer, but it doesn't seem to be thread safe 16 | return new TSerializer(new TJSONProtocol.Factory()).toString(obj, THRIFT_CHARSET); 17 | } 18 | 19 | /** Deserialize a JSON-encoded thrift object */ 20 | public static T 
deserializeJson(T dest, String thriftJson) throws TException { 21 | // Tried having a static final deserializer, but it doesn't seem to be thread safe 22 | new TDeserializer(new TJSONProtocol.Factory()).deserialize(dest, thriftJson, THRIFT_CHARSET); 23 | return dest; 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/com/medallia/word2vec/util/UnicodeReader.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec.util; 2 | 3 | import java.io.IOException; 4 | import java.io.InputStream; 5 | import java.io.InputStreamReader; 6 | import java.io.PushbackInputStream; 7 | import java.io.Reader; 8 | import java.nio.charset.Charset; 9 | import java.nio.charset.CharsetDecoder; 10 | 11 | /** 12 | * Generic unicode textreader, which will use BOM 13 | * to identify the encoding to be used. If BOM is not found 14 | * then use a given default or system encoding. 15 | *

16 | * BOMs for different unicodes use this standard: 17 | * 00 00 FE FF = UTF-32, big-endian 18 | * FF FE 00 00 = UTF-32, little-endian 19 | * EF BB BF = UTF-8, 20 | * FE FF = UTF-16, big-endian 21 | * FF FE = UTF-16, little-endian 22 | *

23 | * This piece of code is found in:

24 | * http://koti.mbnet.fi/akini/java/unicodereader/ 25 | *

26 | * The decoding will be handled by the decoder returned by 27 | * {@link Charset#newDecoder()} if strict is set to true, 28 | * and exceptions will be propagated from the returned 29 | * {@link CharsetDecoder}. 30 | */ 31 | public class UnicodeReader extends Reader { 32 | PushbackInputStream internalIn; 33 | InputStreamReader internalIn2 = null; 34 | String defaultEnc; 35 | boolean strict; 36 | 37 | private static final int BOM_SIZE = 4; 38 | 39 | /** 40 | * Constructor 41 | * @param in inputstream to be read 42 | * @param defaultEnc default encoding if stream does not have 43 | * BOM. Give NULL to use system-level default. 44 | * @param strict invalid content will give exceptions (See {@link UnicodeReader}) 45 | */ 46 | public UnicodeReader(InputStream in, String defaultEnc, boolean strict) { 47 | internalIn = new PushbackInputStream(in, BOM_SIZE); 48 | this.defaultEnc = defaultEnc; 49 | this.strict = strict; 50 | } 51 | 52 | /** 53 | * Same as {@link #UnicodeReader(InputStream, String, boolean)}, with strict = false. 54 | */ 55 | public UnicodeReader(InputStream in, String defaultEnc) { 56 | this(in, defaultEnc, false); 57 | } 58 | 59 | /** 60 | * @return Default encoding during constructor 61 | */ 62 | public String getDefaultEncoding() { 63 | return defaultEnc; 64 | } 65 | 66 | /** 67 | * Get stream encoding or NULL if stream is uninitialized. 68 | * Call init() or read() method to initialize it. 69 | */ 70 | public String getEncoding() { 71 | if (internalIn2 == null) return null; 72 | return internalIn2.getEncoding(); 73 | } 74 | 75 | /** 76 | * Read-ahead four bytes and check for BOM. Extra bytes are 77 | * unread back to the stream, only BOM bytes are skipped. 
78 | */ 79 | protected void init() throws IOException { 80 | if (internalIn2 != null) return; 81 | 82 | String encoding; 83 | byte bom[] = new byte[BOM_SIZE]; 84 | int n, unread; 85 | n = internalIn.read(bom, 0, bom.length); 86 | 87 | if ( (bom[0] == (byte)0x00) && (bom[1] == (byte)0x00) && 88 | (bom[2] == (byte)0xFE) && (bom[3] == (byte)0xFF) ) { 89 | encoding = "UTF-32BE"; 90 | unread = n - 4; 91 | } else if ( (bom[0] == (byte)0xFF) && (bom[1] == (byte)0xFE) && 92 | (bom[2] == (byte)0x00) && (bom[3] == (byte)0x00) ) { 93 | encoding = "UTF-32LE"; 94 | unread = n - 4; 95 | } else if ( (bom[0] == (byte)0xEF) && (bom[1] == (byte)0xBB) && 96 | (bom[2] == (byte)0xBF) ) { 97 | encoding = "UTF-8"; 98 | unread = n - 3; 99 | } else if ( (bom[0] == (byte)0xFE) && (bom[1] == (byte)0xFF) ) { 100 | encoding = "UTF-16BE"; 101 | unread = n - 2; 102 | } else if ( (bom[0] == (byte)0xFF) && (bom[1] == (byte)0xFE) ) { 103 | encoding = "UTF-16LE"; 104 | unread = n - 2; 105 | } else { 106 | // Unicode BOM not found, unread all bytes 107 | encoding = defaultEnc; 108 | unread = n; 109 | } 110 | 111 | if (unread > 0) internalIn.unread(bom, (n - unread), unread); 112 | 113 | // Use given encoding 114 | if (encoding == null) { 115 | internalIn2 = new InputStreamReader(internalIn); 116 | } else if (strict) { 117 | internalIn2 = new InputStreamReader(internalIn, Charset.forName(encoding).newDecoder()); 118 | } else { 119 | internalIn2 = new InputStreamReader(internalIn, encoding); 120 | } 121 | } 122 | 123 | @Override public void close() throws IOException { 124 | init(); 125 | internalIn2.close(); 126 | } 127 | 128 | @Override public int read(char[] cbuf, int off, int len) throws IOException { 129 | init(); 130 | return internalIn2.read(cbuf, off, len); 131 | } 132 | 133 | @Override public boolean ready() throws IOException { 134 | init(); 135 | return internalIn2.ready(); 136 | } 137 | } -------------------------------------------------------------------------------- 
/src/test/java/com/medallia/word2vec/Word2VecBinTest.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec; 2 | 3 | import static org.junit.Assert.assertEquals; 4 | import static org.junit.Assert.assertTrue; 5 | 6 | import java.io.File; 7 | import java.io.IOException; 8 | import java.io.OutputStream; 9 | import java.nio.file.Files; 10 | import java.nio.file.Path; 11 | import java.util.List; 12 | 13 | import org.junit.After; 14 | import org.junit.Assert; 15 | import org.junit.Test; 16 | 17 | import com.medallia.word2vec.Searcher.UnknownWordException; 18 | import com.medallia.word2vec.util.Common; 19 | 20 | /** 21 | * Tests converting the binary models into 22 | * {@link com.medallia.word2vec.Word2VecModel}s. 23 | * 24 | * @see com.medallia.word2vec.Word2VecModel#fromBinFile(File) 25 | * @see com.medallia.word2vec.Word2VecModel#fromBinFile(File, 26 | * java.nio.ByteOrder) 27 | */ 28 | public class Word2VecBinTest { 29 | 30 | /** 31 | * Tests that the Word2VecModels created from a binary and text 32 | * representations are equivalent 33 | */ 34 | @Test 35 | public void testRead() 36 | throws IOException, UnknownWordException { 37 | File binFile = Common.getResourceAsFile( 38 | this.getClass(), 39 | "/com/medallia/word2vec/tokensModel.bin"); 40 | Word2VecModel binModel = Word2VecModel.fromBinFile(binFile); 41 | 42 | File txtFile = Common.getResourceAsFile( 43 | this.getClass(), 44 | "/com/medallia/word2vec/tokensModel.txt"); 45 | Word2VecModel txtModel = Word2VecModel.fromTextFile(txtFile); 46 | 47 | assertEquals(binModel, txtModel); 48 | } 49 | 50 | private Path tempFile = null; 51 | 52 | /** 53 | * Tests that a Word2VecModel round-trips through the bin format without changes 54 | */ 55 | @Test 56 | public void testRoundTrip() throws IOException, UnknownWordException { 57 | final String filename = "word2vec.c.output.model.txt"; 58 | final Word2VecModel model = 59 | Word2VecModel.fromTextFile(filename, 
Common.readResource(Word2VecTest.class, filename)); 60 | 61 | tempFile = Files.createTempFile( 62 | String.format("%s-", Word2VecBinTest.class.getSimpleName()), ".bin"); 63 | try (final OutputStream os = Files.newOutputStream(tempFile)) { 64 | model.toBinFile(os); 65 | } 66 | 67 | final Word2VecModel modelCopy = Word2VecModel.fromBinFile(tempFile.toFile()); 68 | assertEquals(model, modelCopy); 69 | } 70 | 71 | @After 72 | public void cleanupTempFile() throws IOException { 73 | if(tempFile != null) 74 | Files.delete(tempFile); 75 | } 76 | 77 | private void assertEquals( 78 | final Word2VecModel leftModel, 79 | final Word2VecModel rightModel) throws UnknownWordException { 80 | final Searcher leftSearcher = leftModel.forSearch(); 81 | final Searcher rightSearcher = rightModel.forSearch(); 82 | 83 | // test vocab 84 | for (String vocab : leftModel.getVocab()) { 85 | assertTrue(rightSearcher.contains(vocab)); 86 | } 87 | for (String vocab : rightModel.getVocab()) { 88 | assertTrue(leftSearcher.contains(vocab)); 89 | } 90 | // test vector 91 | for (String vocab : leftModel.getVocab()) { 92 | final List leftVector = leftSearcher.getRawVector(vocab); 93 | final List rightVector = rightSearcher.getRawVector(vocab); 94 | assertEquals(leftVector, rightVector); 95 | } 96 | } 97 | 98 | private void assertEquals( 99 | final List leftVector, 100 | final List rightVector) { 101 | Assert.assertEquals(leftVector.size(), rightVector.size()); 102 | for (int i = 0; i < leftVector.size(); i++) { 103 | double txtD = leftVector.get(i); 104 | double binD = rightVector.get(i); 105 | Assert.assertEquals(txtD, binD, 0.0001); 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/test/java/com/medallia/word2vec/Word2VecTest.java: -------------------------------------------------------------------------------- 1 | package com.medallia.word2vec; 2 | 3 | import static org.junit.Assert.assertEquals; 4 | import static 
org.junit.Assert.fail; 5 | 6 | import java.io.File; 7 | import java.io.IOException; 8 | import java.util.List; 9 | 10 | import org.apache.commons.io.FileUtils; 11 | import org.apache.thrift.TException; 12 | import org.junit.After; 13 | import org.junit.Rule; 14 | import org.junit.Test; 15 | import org.junit.rules.ExpectedException; 16 | 17 | import com.google.common.annotations.VisibleForTesting; 18 | import com.google.common.collect.ImmutableList; 19 | import com.google.common.collect.Iterables; 20 | import com.google.common.collect.Lists; 21 | import com.medallia.word2vec.Searcher.Match; 22 | import com.medallia.word2vec.Searcher.UnknownWordException; 23 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener; 24 | import com.medallia.word2vec.neuralnetwork.NeuralNetworkType; 25 | import com.medallia.word2vec.thrift.Word2VecModelThrift; 26 | import com.medallia.word2vec.util.Common; 27 | import com.medallia.word2vec.util.ThriftUtils; 28 | 29 | /** 30 | * Tests for {@link Word2VecModel} and related classes. 31 | *

32 | * Note that the implementation is expected to be deterministic if numThreads is 33 | * set to 1 34 | */ 35 | public class Word2VecTest { 36 | @Rule 37 | public ExpectedException expected = ExpectedException.none(); 38 | 39 | /** Clean up after a test run */ 40 | @After 41 | public void after() { 42 | // Unset the interrupted flag to avoid polluting other tests 43 | Thread.interrupted(); 44 | } 45 | 46 | /** Test {@link NeuralNetworkType#CBOW} */ 47 | @Test 48 | public void testCBOW() throws IOException, TException, InterruptedException { 49 | assertModelMatches("cbowBasic.model", 50 | Word2VecModel.trainer() 51 | .setMinVocabFrequency(6) 52 | .useNumThreads(1) 53 | .setWindowSize(8) 54 | .type(NeuralNetworkType.CBOW) 55 | .useHierarchicalSoftmax() 56 | .setLayerSize(25) 57 | .setDownSamplingRate(1e-3) 58 | .setNumIterations(1) 59 | .train(testData()) 60 | ); 61 | } 62 | 63 | /** Test {@link NeuralNetworkType#CBOW} with 15 iterations */ 64 | @Test 65 | public void testCBOWwith15Iterations() throws IOException, TException, InterruptedException { 66 | assertModelMatches("cbowIterations.model", 67 | Word2VecModel.trainer() 68 | .setMinVocabFrequency(5) 69 | .useNumThreads(1) 70 | .setWindowSize(8) 71 | .type(NeuralNetworkType.CBOW) 72 | .useHierarchicalSoftmax() 73 | .setLayerSize(25) 74 | .useNegativeSamples(5) 75 | .setDownSamplingRate(1e-3) 76 | .setNumIterations(15) 77 | .train(testData()) 78 | ); 79 | } 80 | 81 | /** Test {@link NeuralNetworkType#SKIP_GRAM} */ 82 | @Test 83 | public void testSkipGram() throws IOException, TException, InterruptedException { 84 | assertModelMatches("skipGramBasic.model", 85 | Word2VecModel.trainer() 86 | .setMinVocabFrequency(6) 87 | .useNumThreads(1) 88 | .setWindowSize(8) 89 | .type(NeuralNetworkType.SKIP_GRAM) 90 | .useHierarchicalSoftmax() 91 | .setLayerSize(25) 92 | .setDownSamplingRate(1e-3) 93 | .setNumIterations(1) 94 | .train(testData()) 95 | ); 96 | } 97 | 98 | /** Test {@link NeuralNetworkType#SKIP_GRAM} with 15 
iterations */ 99 | @Test 100 | public void testSkipGramWith15Iterations() throws IOException, TException, InterruptedException { 101 | assertModelMatches("skipGramIterations.model", 102 | Word2VecModel.trainer() 103 | .setMinVocabFrequency(6) 104 | .useNumThreads(1) 105 | .setWindowSize(8) 106 | .type(NeuralNetworkType.SKIP_GRAM) 107 | .useHierarchicalSoftmax() 108 | .setLayerSize(25) 109 | .setDownSamplingRate(1e-3) 110 | .setNumIterations(15) 111 | .train(testData()) 112 | ); 113 | } 114 | 115 | /** Test that we can interrupt the huffman encoding process */ 116 | @Test 117 | public void testInterruptHuffman() throws IOException, InterruptedException { 118 | expected.expect(InterruptedException.class); 119 | trainer() 120 | .type(NeuralNetworkType.SKIP_GRAM) 121 | .setNumIterations(15) 122 | .setListener(new TrainingProgressListener() { 123 | @Override public void update(Stage stage, double progress) { 124 | if (stage == Stage.CREATE_HUFFMAN_ENCODING) 125 | Thread.currentThread().interrupt(); 126 | else if (stage == Stage.TRAIN_NEURAL_NETWORK) 127 | fail("Should not have reached this stage"); 128 | } 129 | }) 130 | .train(testData()); 131 | } 132 | 133 | /** Test that we can interrupt the neural network training process */ 134 | @Test 135 | public void testInterruptNeuralNetworkTraining() throws InterruptedException, IOException { 136 | expected.expect(InterruptedException.class); 137 | trainer() 138 | .type(NeuralNetworkType.SKIP_GRAM) 139 | .setNumIterations(15) 140 | .setListener(new TrainingProgressListener() { 141 | @Override public void update(Stage stage, double progress) { 142 | if (stage == Stage.TRAIN_NEURAL_NETWORK) 143 | Thread.currentThread().interrupt(); 144 | } 145 | }) 146 | .train(testData()); 147 | } 148 | 149 | /** 150 | * Test the search results are deterministic Note the actual values may not 151 | * make sense since the model we train isn't tuned 152 | */ 153 | @Test 154 | public void testSearch() throws InterruptedException, IOException, 
UnknownWordException { 155 | Word2VecModel model = trainer() 156 | .type(NeuralNetworkType.SKIP_GRAM) 157 | .train(testData()); 158 | 159 | List matches = model.forSearch().getMatches("anarchism", 5); 160 | 161 | assertEquals( 162 | ImmutableList.of("anarchism", "feminism", "trouble", "left", "capitalism"), 163 | Lists.transform(matches, Match.TO_WORD) 164 | ); 165 | } 166 | 167 | /** 168 | * Test that the model can retrieve words by a vector. 169 | */ 170 | @Test 171 | public void testGetWordByVector() throws InterruptedException, IOException, UnknownWordException { 172 | Word2VecModel model = trainer() 173 | .type(NeuralNetworkType.SKIP_GRAM) 174 | .train(testData()); 175 | 176 | // This vector defines the word "anarchism" in the given model. 177 | double[] vectors = new double[] { 0.11410251703652753, 0.271180824514185, 0.03748515103121994, 0.20888126888511183, 0.009713531343874777, 0.4769425625416319, 0.1431890482445165, -0.1917578875330224, -0.33532561802423366, 178 | -0.08794543238607992, 0.20404593606213406, 0.26170074241479385, 0.10020961212561065, 0.11400571893146201, -0.07846426915175395, -0.19404092647187385, 0.13381991303455204, -4.6749635342694615E-4, -0.0820905789076496, 179 | -0.30157145455251866, 0.3652037905836543, -0.16466827556950117, -0.012965932276668056, 0.09896568721267748, -0.01925755122093615 }; 180 | 181 | List matches = model.forSearch().getMatches(vectors, 5); 182 | 183 | assertEquals( 184 | ImmutableList.of("anarchism", "feminism", "trouble", "left", "capitalism"), 185 | Lists.transform(matches, Match.TO_WORD) 186 | ); 187 | } 188 | 189 | /** 190 | * Test that the model can retrieve words by a vector. 
191 | */ 192 | @Test 193 | public void testGetWordByNotExistantVector() throws InterruptedException, IOException, UnknownWordException { 194 | Word2VecModel model = trainer() 195 | .type(NeuralNetworkType.SKIP_GRAM) 196 | .train(testData()); 197 | 198 | double[] vectors = new double[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 199 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 200 | 0, 0, 0, 0, 0, 0 }; 201 | 202 | List matches = model.forSearch().getMatches(vectors, 5); 203 | 204 | assertEquals( 205 | ImmutableList.of("the", "of", "and", "in", "a"), 206 | Lists.transform(matches, Match.TO_WORD) 207 | ); 208 | } 209 | 210 | /** Test reading Word2Vec C version txt output format into this library */ 211 | @Test 212 | public void testTxtModelRead() throws IOException, UnknownWordException { 213 | String filename = "word2vec.c.output.model.txt"; 214 | Word2VecModel word2VecModel = Word2VecModel.fromTextFile(filename, Common.readResource(Word2VecTest.class, filename)); 215 | assertEquals(0.9927725293757652, word2VecModel.forSearch().cosineDistance("three", "five"), 1e-5); 216 | } 217 | 218 | /** @return {@link Word2VecTrainer} which by default uses all of the supported features */ 219 | @VisibleForTesting 220 | public static Word2VecTrainerBuilder trainer() { 221 | return Word2VecModel.trainer() 222 | .setMinVocabFrequency(6) 223 | .useNumThreads(1) 224 | .setWindowSize(8) 225 | .type(NeuralNetworkType.CBOW) 226 | .useHierarchicalSoftmax() 227 | .setLayerSize(25) 228 | .setDownSamplingRate(1e-3) 229 | .setNumIterations(1); 230 | } 231 | 232 | /** @return raw test dataset. The tokens are separated by newlines. 
*/ 233 | @VisibleForTesting 234 | public static Iterable> testData() throws IOException { 235 | List lines = Common.readResource(Word2VecTest.class, "word2vec.short.txt"); 236 | Iterable> partitioned = Iterables.partition(lines, 1000); 237 | return partitioned; 238 | } 239 | 240 | private void assertModelMatches(String expectedResource, Word2VecModel model) throws TException { 241 | final String thrift; 242 | try { 243 | thrift = Common.readResourceToStringChecked(getClass(), expectedResource); 244 | } catch (IOException ioe) { 245 | String filename = "/tmp/" + expectedResource; 246 | try { 247 | FileUtils.writeStringToFile( 248 | new File(filename), 249 | ThriftUtils.serializeJson(model.toThrift()) 250 | ); 251 | } catch (IOException e) { 252 | throw new AssertionError("Could not read resource " + expectedResource + " and could not write expected output to /tmp"); 253 | } 254 | throw new AssertionError("Could not read resource " + expectedResource + " wrote to " + filename); 255 | } 256 | 257 | Word2VecModelThrift expected = ThriftUtils.deserializeJson( 258 | new Word2VecModelThrift(), 259 | thrift 260 | ); 261 | 262 | assertEquals("Mismatched vocab", expected.getVocab().size(), Iterables.size(model.getVocab())); 263 | 264 | assertEquals(expected, model.toThrift()); 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /src/test/resources/com/medallia/word2vec/tokensModel.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medallia/Word2VecJava/eb31fbb99ac6bbab82d7f807b3e2240edca50eb7/src/test/resources/com/medallia/word2vec/tokensModel.bin --------------------------------------------------------------------------------