├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── build.gradle
├── build.gradle.release
├── gradle
└── wrapper
│ ├── gradle-wrapper.jar
│ └── gradle-wrapper.properties
├── gradlew
├── gradlew.bat
├── pom.xml
├── settings.gradle
└── src
├── main
└── java
│ └── com
│ └── medallia
│ └── word2vec
│ ├── NormalizedWord2VecModel.java
│ ├── Searcher.java
│ ├── SearcherImpl.java
│ ├── Word2VecExamples.java
│ ├── Word2VecModel.java
│ ├── Word2VecTrainer.java
│ ├── Word2VecTrainerBuilder.java
│ ├── huffman
│ └── HuffmanCoding.java
│ ├── neuralnetwork
│ ├── CBOWModelTrainer.java
│ ├── NeuralNetworkConfig.java
│ ├── NeuralNetworkTrainer.java
│ ├── NeuralNetworkType.java
│ └── SkipGramModelTrainer.java
│ ├── thrift
│ └── Word2VecModelThrift.java
│ └── util
│ ├── AC.java
│ ├── AutoLog.java
│ ├── CallableVoid.java
│ ├── Common.java
│ ├── Compare.java
│ ├── FileUtils.java
│ ├── Format.java
│ ├── IO.java
│ ├── NDC.java
│ ├── Pair.java
│ ├── ProfilingTimer.java
│ ├── Strings.java
│ ├── ThriftUtils.java
│ └── UnicodeReader.java
└── test
├── java
└── com
│ └── medallia
│ └── word2vec
│ ├── Word2VecBinTest.java
│ └── Word2VecTest.java
└── resources
└── com
└── medallia
└── word2vec
├── cbowBasic.model
├── cbowIterations.model
├── skipGramBasic.model
├── skipGramIterations.model
├── tokensModel.bin
├── tokensModel.txt
├── word2vec.c.output.model.txt
└── word2vec.short.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | .gradle
3 | /build
4 | /out
5 | /intTestHomeDir
6 | /subprojects/*/out
7 | /intellij
8 | /buildSrc/lib
9 | /buildSrc/build
10 | /subprojects/*/build
11 | /subprojects/docs/src/samples/*/*/build
12 | /website/build
13 | /website/website.iml
14 | /website/website.ipr
15 | /website/website.iws
16 | /performanceTest/build
17 | /subprojects/*/ide
18 | /*.iml
19 | /*.ipr
20 | /*.iws
21 | /subprojects/*/*.iml
22 | /buildSrc/*.ipr
23 | /buildSrc/*.iws
24 | /buildSrc/*.iml
25 | /buildSrc/out
26 | *.classpath
27 | *.project
28 | *.settings
29 | /bin
30 | /subprojects/*/bin
31 | .DS_Store
32 | /performanceTest/lib
33 | .textmate
34 | /incoming-distributions
35 | .idea
36 | *.sublime-*
37 | .nb-gradle
38 | /target/
39 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: java
2 |
3 | jdk:
4 | - oraclejdk8
5 | - oraclejdk7
6 | - openjdk7
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2015 Medallia
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Word2vecJava
2 |
3 | [![Build Status](https://travis-ci.org/medallia/Word2VecJava.svg?branch=master)](https://travis-ci.org/medallia/Word2VecJava)
4 |
5 | This is a port of the open source C implementation of word2vec (https://code.google.com/p/word2vec/). You can browse/contribute the repository via [Github](https://github.com/medallia/Word2VecJava). Alternatively you can pull it from the central Maven repositories:
6 | ```XML
7 | <dependency>
8 |   <groupId>com.medallia.word2vec</groupId>
9 |   <artifactId>Word2VecJava</artifactId>
10 |   <version>0.10.3</version>
11 | </dependency>
12 | ```
13 |
14 | For more background information about word2vec and neural network training for the vector representation of words, please see the following papers.
15 | * http://ttic.uchicago.edu/~haotang/speech/1301.3781.pdf
16 | * http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf
17 |
18 | For a comprehensive explanation of the training process (the gradient descent formula calculation in the back propagation training), please see:
19 | * http://www-personal.umich.edu/~ronxin/pdf/w2vexp.pdf
20 |
21 | Note that this isn't a completely faithful rewrite, specifically:
22 |
23 | ### When building the vocabulary from the training file:
24 | 1. The original version does a reduction step when learning the vocabulary from the file when the vocab size hits 21 million words, removing any words that do not meet the minimum frequency threshold. This Java port has no such reduction step.
25 | 2. The original version injects a `</s>` token into the vocabulary (with a word count of 0) as a substitute for newlines in the input file. This Java port's vocabulary excludes the `</s>` token.
26 | 3. The original version does a quicksort which is not stable, so vocabulary terms with the same frequency may be ordered non-deterministically. The Java port does an explicit sort first by frequency, then by the token's lexicographical ordering.
27 |
28 | ### In partitioning the file for processing
29 | 1. The original version assumes that sentences are delimited by newline characters and injects a sentence boundary per 1000 non-filtered tokens, i.e. valid token by the vocabulary and not removed by the randomized sampling process. Java port mimics this behavior for now ...
30 | 2. When the original version encounters an empty line in the input file, it re-processes the first word of the last non-empty line with a sentence length of 0 and updates the random value. Java port omits this behavior.
31 |
32 | ### In the sampling function
33 | 1. The original C documentation indicates that the range should be between 0 and 1e-5, but the default value is 1e-3. This Java port retains that confusing information.
34 | 2. The random value generated for comparison to determine if a token should be filtered uses a float. This Java port uses double precision for twice the fun.
35 |
36 | ### In the distance function to find the nearest matches to a target query
37 | 1. The original version includes an unnecessary normalization of the vector for the input query which may lead to tiny inaccuracies. This Java port foregoes this superfluous operation.
38 | 2. The original version has an O(n * k) algorithm for finding top matches and is hardcoded to 40 matches. This Java port uses Google's lovely com.google.common.collect.Ordering.greatestOf(java.util.Iterator, int) which is O(n + k log k) and takes in arbitrary k.
39 |
40 | Note: The k-means clustering option is excluded in the Java port
41 |
42 | Please do not hesitate to peek at the source code. It should be readable, concise, and correct. Please feel free to reach out if it is not.
43 |
44 | ## Building the Project
45 | To verify that the project is building correctly, run
46 | ```bash
47 | ./gradlew build && ./gradlew test
48 | ```
49 |
50 | It should run 7 tests without any error.
51 |
52 | Note: this project requires gradle 2.2+, if you are using older version of gradle, please upgrade it and run:
53 | ```bash
54 | ./gradlew clean test
55 | ```
56 |
57 | to have a clean build and re-run the tests.
58 |
59 |
60 | ## Contact
61 | Andrew Ko (wko27code@gmail.com)
62 |
--------------------------------------------------------------------------------
/build.gradle:
--------------------------------------------------------------------------------
1 | apply plugin: 'java'
2 | apply plugin: 'maven'
3 |
4 | sourceCompatibility = 1.7
5 | targetCompatibility = 1.7
6 | group = 'com.medallia.word2vec'
7 | version = '0.10.3'
8 |
9 | repositories {
10 | mavenCentral()
11 | }
12 |
13 | dependencies {
14 | compile 'org.apache.thrift:libthrift:0.9.1'
15 | compile group: 'org.apache.commons', name: 'commons-lang3', version:'3.1'
16 | compile group: 'com.google.guava', name: 'guava', version: '18.0'
17 | compile group: 'joda-time', name: 'joda-time', version: '2.3'
18 | compile group: 'log4j', name: 'log4j', version: '1.2.17'
19 | compile group: 'commons-io', name: 'commons-io', version: '2.4'
20 | testCompile group: 'junit', name: 'junit', version: '4.11'
21 | }
22 |
23 | if (JavaVersion.current().isJava8Compatible()) {
24 | tasks.withType(Javadoc) {
25 | // disable the crazy super-strict doclint tool in Java 8
26 | //noinspection SpellCheckingInspection
27 | options.addStringOption('Xdoclint:none', '-quiet')
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/build.gradle.release:
--------------------------------------------------------------------------------
1 | apply plugin: 'signing'
2 |
3 | task javadocJar(type: Jar) {
4 | classifier = 'javadoc'
5 | from javadoc
6 | }
7 |
8 | task sourcesJar(type: Jar) {
9 | classifier = 'sources'
10 | from sourceSets.main.allSource
11 | }
12 |
13 | artifacts {
14 | archives javadocJar, sourcesJar
15 | }
16 |
17 | signing {
18 | sign configurations.archives
19 | }
20 |
21 | uploadArchives {
22 | repositories {
23 | mavenDeployer {
24 | beforeDeployment { MavenDeployment deployment -> signing.signPom(deployment) }
25 |
26 | repository(url: "https://oss.sonatype.org/service/local/staging/deploy/maven2/") {
27 | authentication(userName: "*****", password: "*****")
28 | }
29 |
30 | snapshotRepository(url: "https://oss.sonatype.org/content/repositories/snapshots/") {
31 | authentication(userName: "*****", password: "*****")
32 | }
33 |
34 | pom.project {
35 | name 'Word2VecJava'
36 | packaging 'jar'
37 | description 'Word2Vec Java Port'
38 | url 'https://github.com/medallia/Word2VecJava'
39 |
40 | scm {
41 | connection 'scm:git:git@github.com:medallia/Word2VecJava.git'
42 | developerConnection 'scm:git:git@github.com:medallia/Word2VecJava.git'
43 | url 'https://github.com/medallia/Word2VecJava.git'
44 | }
45 |
46 | licenses {
47 | license {
48 | name 'The MIT License'
49 | url 'http://opensource.org/licenses/MIT'
50 | }
51 | }
52 |
53 | developers {
54 | developer {
55 | id 'wko'
56 | name 'Andrew Ko'
57 | email 'wko@medallia.com'
58 | }
59 | developer {
60 | id 'yibinlin'
61 | name 'Yibin Lin'
62 | email 'yibin@medallia.com'
63 | }
64 | developer {
65 | id 'yfpeng'
66 | name 'Yifan Peng'
67 | email 'yfpeng@udel.edu'
68 | }
69 | developer {
70 | id 'guerda'
71 | name 'Philip Gillißen'
72 | }
73 | }
74 | }
75 | }
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/medallia/Word2VecJava/eb31fbb99ac6bbab82d7f807b3e2240edca50eb7/gradle/wrapper/gradle-wrapper.jar
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.properties:
--------------------------------------------------------------------------------
1 | #Mon Feb 02 17:12:48 PST 2015
2 | distributionBase=GRADLE_USER_HOME
3 | distributionPath=wrapper/dists
4 | zipStoreBase=GRADLE_USER_HOME
5 | zipStorePath=wrapper/dists
6 | distributionUrl=https\://services.gradle.org/distributions/gradle-2.1-bin.zip
7 |
--------------------------------------------------------------------------------
/gradlew:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ##############################################################################
4 | ##
5 | ## Gradle start up script for UN*X
6 | ##
7 | ##############################################################################
8 |
9 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
10 | DEFAULT_JVM_OPTS=""
11 |
12 | APP_NAME="Gradle"
13 | APP_BASE_NAME=`basename "$0"`
14 |
15 | # Use the maximum available, or set MAX_FD != -1 to use that value.
16 | MAX_FD="maximum"
17 |
18 | warn ( ) {
19 | echo "$*"
20 | }
21 |
22 | die ( ) {
23 | echo
24 | echo "$*"
25 | echo
26 | exit 1
27 | }
28 |
29 | # OS specific support (must be 'true' or 'false').
30 | cygwin=false
31 | msys=false
32 | darwin=false
33 | case "`uname`" in
34 | CYGWIN* )
35 | cygwin=true
36 | ;;
37 | Darwin* )
38 | darwin=true
39 | ;;
40 | MINGW* )
41 | msys=true
42 | ;;
43 | esac
44 |
45 | # For Cygwin, ensure paths are in UNIX format before anything is touched.
46 | if $cygwin ; then
47 | [ -n "$JAVA_HOME" ] && JAVA_HOME=`cygpath --unix "$JAVA_HOME"`
48 | fi
49 |
50 | # Attempt to set APP_HOME
51 | # Resolve links: $0 may be a link
52 | PRG="$0"
53 | # Need this for relative symlinks.
54 | while [ -h "$PRG" ] ; do
55 | ls=`ls -ld "$PRG"`
56 | link=`expr "$ls" : '.*-> \(.*\)$'`
57 | if expr "$link" : '/.*' > /dev/null; then
58 | PRG="$link"
59 | else
60 | PRG=`dirname "$PRG"`"/$link"
61 | fi
62 | done
63 | SAVED="`pwd`"
64 | cd "`dirname \"$PRG\"`/" >&-
65 | APP_HOME="`pwd -P`"
66 | cd "$SAVED" >&-
67 |
68 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
69 |
70 | # Determine the Java command to use to start the JVM.
71 | if [ -n "$JAVA_HOME" ] ; then
72 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
73 | # IBM's JDK on AIX uses strange locations for the executables
74 | JAVACMD="$JAVA_HOME/jre/sh/java"
75 | else
76 | JAVACMD="$JAVA_HOME/bin/java"
77 | fi
78 | if [ ! -x "$JAVACMD" ] ; then
79 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
80 |
81 | Please set the JAVA_HOME variable in your environment to match the
82 | location of your Java installation."
83 | fi
84 | else
85 | JAVACMD="java"
86 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
87 |
88 | Please set the JAVA_HOME variable in your environment to match the
89 | location of your Java installation."
90 | fi
91 |
92 | # Increase the maximum file descriptors if we can.
93 | if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then
94 | MAX_FD_LIMIT=`ulimit -H -n`
95 | if [ $? -eq 0 ] ; then
96 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
97 | MAX_FD="$MAX_FD_LIMIT"
98 | fi
99 | ulimit -n $MAX_FD
100 | if [ $? -ne 0 ] ; then
101 | warn "Could not set maximum file descriptor limit: $MAX_FD"
102 | fi
103 | else
104 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
105 | fi
106 | fi
107 |
108 | # For Darwin, add options to specify how the application appears in the dock
109 | if $darwin; then
110 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
111 | fi
112 |
113 | # For Cygwin, switch paths to Windows format before running java
114 | if $cygwin ; then
115 | APP_HOME=`cygpath --path --mixed "$APP_HOME"`
116 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
117 |
118 | # We build the pattern for arguments to be converted via cygpath
119 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
120 | SEP=""
121 | for dir in $ROOTDIRSRAW ; do
122 | ROOTDIRS="$ROOTDIRS$SEP$dir"
123 | SEP="|"
124 | done
125 | OURCYGPATTERN="(^($ROOTDIRS))"
126 | # Add a user-defined pattern to the cygpath arguments
127 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then
128 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
129 | fi
130 | # Now convert the arguments - kludge to limit ourselves to /bin/sh
131 | i=0
132 | for arg in "$@" ; do
133 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
134 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
135 |
136 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
137 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
138 | else
139 | eval `echo args$i`="\"$arg\""
140 | fi
141 | i=$((i+1))
142 | done
143 | case $i in
144 | (0) set -- ;;
145 | (1) set -- "$args0" ;;
146 | (2) set -- "$args0" "$args1" ;;
147 | (3) set -- "$args0" "$args1" "$args2" ;;
148 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
149 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
150 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
151 | (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
152 | (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
153 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
154 | esac
155 | fi
156 |
157 | # Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules
158 | function splitJvmOpts() {
159 | JVM_OPTS=("$@")
160 | }
161 | eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
162 | JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"
163 |
164 | exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
165 |
--------------------------------------------------------------------------------
/gradlew.bat:
--------------------------------------------------------------------------------
1 | @if "%DEBUG%" == "" @echo off
2 | @rem ##########################################################################
3 | @rem
4 | @rem Gradle startup script for Windows
5 | @rem
6 | @rem ##########################################################################
7 |
8 | @rem Set local scope for the variables with windows NT shell
9 | if "%OS%"=="Windows_NT" setlocal
10 |
11 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
12 | set DEFAULT_JVM_OPTS=
13 |
14 | set DIRNAME=%~dp0
15 | if "%DIRNAME%" == "" set DIRNAME=.
16 | set APP_BASE_NAME=%~n0
17 | set APP_HOME=%DIRNAME%
18 |
19 | @rem Find java.exe
20 | if defined JAVA_HOME goto findJavaFromJavaHome
21 |
22 | set JAVA_EXE=java.exe
23 | %JAVA_EXE% -version >NUL 2>&1
24 | if "%ERRORLEVEL%" == "0" goto init
25 |
26 | echo.
27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
28 | echo.
29 | echo Please set the JAVA_HOME variable in your environment to match the
30 | echo location of your Java installation.
31 |
32 | goto fail
33 |
34 | :findJavaFromJavaHome
35 | set JAVA_HOME=%JAVA_HOME:"=%
36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe
37 |
38 | if exist "%JAVA_EXE%" goto init
39 |
40 | echo.
41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
42 | echo.
43 | echo Please set the JAVA_HOME variable in your environment to match the
44 | echo location of your Java installation.
45 |
46 | goto fail
47 |
48 | :init
49 | @rem Get command-line arguments, handling Windowz variants
50 |
51 | if not "%OS%" == "Windows_NT" goto win9xME_args
52 | if "%@eval[2+2]" == "4" goto 4NT_args
53 |
54 | :win9xME_args
55 | @rem Slurp the command line arguments.
56 | set CMD_LINE_ARGS=
57 | set _SKIP=2
58 |
59 | :win9xME_args_slurp
60 | if "x%~1" == "x" goto execute
61 |
62 | set CMD_LINE_ARGS=%*
63 | goto execute
64 |
65 | :4NT_args
66 | @rem Get arguments from the 4NT Shell from JP Software
67 | set CMD_LINE_ARGS=%$
68 |
69 | :execute
70 | @rem Setup the command line
71 |
72 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
73 |
74 | @rem Execute Gradle
75 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
76 |
77 | :end
78 | @rem End local scope for the variables with windows NT shell
79 | if "%ERRORLEVEL%"=="0" goto mainEnd
80 |
81 | :fail
82 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
83 | rem the _cmd.exe /c_ return code!
84 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
85 | exit /b 1
86 |
87 | :mainEnd
88 | if "%OS%"=="Windows_NT" endlocal
89 |
90 | :omega
91 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | 4.0.0
4 | com.medallia.word2vec
5 | medallia-word2vec
6 | 0.10.2
7 |
8 |
9 | MIT License
10 | http://opensource.org/licenses/MIT
11 | repo
12 |
13 |
14 |
15 | UTF-8
16 | UTF-8
17 |
18 |
19 | ${project.artifactId}-${project.version}
20 |
21 |
22 | maven-compiler-plugin
23 | 3.1
24 |
25 | 1.8
26 | 1.8
27 |
28 |
29 |
30 | org.apache.maven.plugins
31 | maven-eclipse-plugin
32 | 2.9
33 |
34 | true
35 | true
36 |
37 |
38 |
39 |
40 |
41 |
42 | oss-sonatype
43 | oss-sonatype
44 | https://oss.sonatype.org/content/repositories/snapshots/
45 |
46 | true
47 |
48 |
49 |
50 |
51 |
52 | org.apache.commons
53 | commons-lang3
54 | 3.1
55 |
56 |
57 | com.google.guava
58 | guava
59 | 18.0
60 |
61 |
62 | commons-io
63 | commons-io
64 | 2.4
65 |
66 |
67 | junit
68 | junit
69 | 4.11
70 |
71 |
72 | log4j
73 | log4j
74 | 1.2.17
75 |
76 |
77 | joda-time
78 | joda-time
79 | 2.3
80 |
81 |
82 | org.apache.thrift
83 | libfb303
84 | 0.9.1
85 |
86 |
87 | org.apache.commons
88 | commons-math3
89 | 3.4.1
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/settings.gradle:
--------------------------------------------------------------------------------
1 | rootProject.name = 'Word2VecJava'
2 |
3 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/NormalizedWord2VecModel.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec;
2 |
3 | import com.medallia.word2vec.thrift.Word2VecModelThrift;
4 |
5 | import java.io.File;
6 | import java.io.IOException;
7 | import java.nio.ByteBuffer;
8 | import java.nio.DoubleBuffer;
9 |
10 | /**
11 | * Represents a word2vec model where all the vectors are normalized to unit length.
12 | */
13 | public class NormalizedWord2VecModel extends Word2VecModel {
14 | private NormalizedWord2VecModel(Iterable vocab, int layerSize, final DoubleBuffer vectors) {
15 | super(vocab, layerSize, vectors);
16 | normalize();
17 | }
18 |
19 | private NormalizedWord2VecModel(Iterable vocab, int layerSize, double[] vectors) {
20 | super(vocab, layerSize, vectors);
21 | normalize();
22 | }
23 |
24 | public static NormalizedWord2VecModel fromWord2VecModel(Word2VecModel model) {
25 | return new NormalizedWord2VecModel(model.vocab, model.layerSize, model.vectors.duplicate());
26 | }
27 |
28 | /** @return {@link NormalizedWord2VecModel} created from a thrift representation */
29 | public static NormalizedWord2VecModel fromThrift(final Word2VecModelThrift thrift) {
30 | return fromWord2VecModel(Word2VecModel.fromThrift(thrift));
31 | }
32 |
33 | public static NormalizedWord2VecModel fromBinFile(final File file) throws IOException {
34 | return fromWord2VecModel(Word2VecModel.fromBinFile(file));
35 | }
36 |
37 | /** Normalizes the vectors in this model */
38 | private void normalize() {
39 | for(int i = 0; i < vocab.size(); ++i) {
40 | double len = 0;
41 | for(int j = i * layerSize; j < (i + 1) * layerSize; ++j)
42 | len += vectors.get(j) * vectors.get(j);
43 | len = Math.sqrt(len);
44 |
45 | for(int j = i * layerSize; j < (i + 1) * layerSize; ++j)
46 | vectors.put(j, vectors.get(j) / len);
47 | }
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/Searcher.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec;
2 |
3 | import com.google.common.base.Function;
4 | import com.google.common.collect.ImmutableList;
5 | import com.google.common.collect.Ordering;
6 |
7 | import java.util.List;
8 |
9 | /** Provides search functionality */
10 | public interface Searcher {
11 | /** @return true if a word is inside the model's vocabulary. */
12 | boolean contains(String word);
13 |
14 | /** @return Raw word vector */
15 | ImmutableList getRawVector(String word) throws UnknownWordException;
16 |
17 | /** @return Top matches to the given word */
18 | List getMatches(String word, int maxMatches) throws UnknownWordException;
19 |
20 | /** @return Top matches to the given vector */
21 | List getMatches(final double[] vec, int maxNumMatches);
22 |
23 | /** Represents the similarity between two words */
24 | public interface SemanticDifference {
25 | /** @return Top matches to the given word which share this semantic relationship */
26 | List getMatches(String word, int maxMatches) throws UnknownWordException;
27 | }
28 |
29 | /** @return {@link SemanticDifference} between the word vectors for the given */
30 | SemanticDifference similarity(String s1, String s2) throws UnknownWordException;
31 |
32 | /** @return cosine similarity between two words. */
33 | double cosineDistance(String s1, String s2) throws UnknownWordException;
34 |
35 | /** Represents a match to a search word */
36 | public interface Match {
37 | /** @return Matching word */
38 | String match();
39 | /** @return Cosine distance of the match */
40 | double distance();
41 | /** {@link Ordering} which compares {@link Match#distance()} */
42 | Ordering ORDERING = Ordering.natural().onResultOf(new Function() {
43 | @Override public Double apply(Match match) {
44 | return match.distance();
45 | }
46 | });
47 | /** {@link Function} which forwards to {@link #match()} */
48 | Function TO_WORD = new Function() {
49 | @Override public String apply(Match result) {
50 | return result.match();
51 | }
52 | };
53 | }
54 |
55 | /** Exception when a word is unknown to the {@link Word2VecModel}'s vocabulary */
56 | public static class UnknownWordException extends Exception {
57 | UnknownWordException(String word) {
58 | super(String.format("Unknown search word '%s'", word));
59 | }
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/SearcherImpl.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec;
2 |
3 | import com.google.common.base.Function;
4 | import com.google.common.collect.ImmutableList;
5 | import com.google.common.collect.ImmutableMap;
6 | import com.google.common.collect.Iterables;
7 | import com.google.common.primitives.Doubles;
8 | import com.medallia.word2vec.util.Pair;
9 |
10 | import java.nio.DoubleBuffer;
11 | import java.util.Arrays;
12 | import java.util.List;
13 |
14 | /** Implementation of {@link Searcher} */
15 | class SearcherImpl implements Searcher {
16 | private final NormalizedWord2VecModel model;
17 | private final ImmutableMap word2vectorOffset;
18 |
19 | SearcherImpl(final NormalizedWord2VecModel model) {
20 | this.model = model;
21 |
22 | final ImmutableMap.Builder result = ImmutableMap.builder();
23 | for (int i = 0; i < model.vocab.size(); i++) {
24 | result.put(model.vocab.get(i), i * model.layerSize);
25 | }
26 |
27 | word2vectorOffset = result.build();
28 | }
29 |
30 | SearcherImpl(final Word2VecModel model) {
31 | this(NormalizedWord2VecModel.fromWord2VecModel(model));
32 | }
33 |
34 | @Override public List getMatches(String s, int maxNumMatches) throws UnknownWordException {
35 | return getMatches(getVector(s), maxNumMatches);
36 | }
37 |
38 | @Override public double cosineDistance(String s1, String s2) throws UnknownWordException {
39 | return calculateDistance(getVector(s1), getVector(s2));
40 | }
41 |
42 | @Override public boolean contains(String word) {
43 | return word2vectorOffset.containsKey(word);
44 | }
45 |
46 | @Override public List getMatches(final double[] vec, int maxNumMatches) {
47 | return Match.ORDERING.greatestOf(
48 | Iterables.transform(model.vocab, new Function() {
49 | @Override
50 | public Match apply(String other) {
51 | double[] otherVec = getVectorOrNull(other);
52 | double d = calculateDistance(otherVec, vec);
53 | return new MatchImpl(other, d);
54 | }
55 | }),
56 | maxNumMatches
57 | );
58 | }
59 |
60 | private double calculateDistance(double[] otherVec, double[] vec) {
61 | double d = 0;
62 | for (int a = 0; a < model.layerSize; a++)
63 | d += vec[a] * otherVec[a];
64 | return d;
65 | }
66 |
67 | @Override public ImmutableList getRawVector(String word) throws UnknownWordException {
68 | return ImmutableList.copyOf(Doubles.asList(getVector(word)));
69 | }
70 |
71 | /**
72 | * @return Vector for the given word
73 | * @throws UnknownWordException If word is not in the model's vocabulary
74 | */
75 | private double[] getVector(String word) throws UnknownWordException {
76 | final double[] result = getVectorOrNull(word);
77 | if(result == null)
78 | throw new UnknownWordException(word);
79 |
80 | return result;
81 | }
82 |
83 | private double[] getVectorOrNull(final String word) {
84 | final Integer index = word2vectorOffset.get(word);
85 | if(index == null)
86 | return null;
87 |
88 | final DoubleBuffer vectors = model.vectors.duplicate();
89 | double[] result = new double[model.layerSize];
90 | vectors.position(index);
91 | vectors.get(result);
92 | return result;
93 | }
94 |
95 | /** @return Vector difference from v1 to v2 */
96 | private double[] getDifference(double[] v1, double[] v2) {
97 | double[] diff = new double[model.layerSize];
98 | for (int i = 0; i < model.layerSize; i++)
99 | diff[i] = v1[i] - v2[i];
100 | return diff;
101 | }
102 |
103 | @Override public SemanticDifference similarity(String s1, String s2) throws UnknownWordException {
104 | double[] v1 = getVector(s1);
105 | double[] v2 = getVector(s2);
106 | final double[] diff = getDifference(v1, v2);
107 |
108 | return new SemanticDifference() {
109 | @Override public List getMatches(String word, int maxMatches) throws UnknownWordException {
110 | double[] target = getDifference(getVector(word), diff);
111 | return SearcherImpl.this.getMatches(target, maxMatches);
112 | }
113 | };
114 | }
115 |
116 | /** Implementation of {@link Match} */
117 | private static class MatchImpl extends Pair implements Match {
118 | private MatchImpl(String first, Double second) {
119 | super(first, second);
120 | }
121 |
122 | @Override public String match() {
123 | return first;
124 | }
125 |
126 | @Override public double distance() {
127 | return second;
128 | }
129 |
130 | @Override public String toString() {
131 | return String.format("%s [%s]", first, second);
132 | }
133 | }
134 | }
135 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/Word2VecExamples.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec;
2 |
3 | import com.google.common.base.Function;
4 | import com.google.common.collect.Lists;
5 | import com.medallia.word2vec.Searcher.Match;
6 | import com.medallia.word2vec.Searcher.UnknownWordException;
7 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
8 | import com.medallia.word2vec.neuralnetwork.NeuralNetworkType;
9 | import com.medallia.word2vec.thrift.Word2VecModelThrift;
10 | import com.medallia.word2vec.util.AutoLog;
11 | import com.medallia.word2vec.util.Common;
12 | import com.medallia.word2vec.util.Format;
13 | import com.medallia.word2vec.util.ProfilingTimer;
14 | import com.medallia.word2vec.util.Strings;
15 | import com.medallia.word2vec.util.ThriftUtils;
16 | import org.apache.commons.logging.Log;
17 | import org.apache.thrift.TException;
18 | import org.apache.commons.io.FileUtils;
19 |
20 | import java.io.BufferedReader;
21 | import java.io.File;
22 | import java.io.IOException;
23 | import java.io.InputStreamReader;
24 | import java.io.OutputStream;
25 | import java.nio.file.Files;
26 | import java.nio.file.Paths;
27 | import java.util.Arrays;
28 | import java.util.List;
29 |
30 | /** Example usages of {@link Word2VecModel} */
31 | public class Word2VecExamples {
32 | private static final Log LOG = AutoLog.getLog();
33 |
34 | /** Runs the example */
35 | public static void main(String[] args) throws IOException, TException, UnknownWordException, InterruptedException {
36 | demoWord();
37 | }
38 |
39 | /**
40 | * Trains a model and allows user to find similar words
41 | * demo-word.sh example from the open source C implementation
42 | */
43 | public static void demoWord() throws IOException, TException, InterruptedException, UnknownWordException {
44 | File f = new File("text8");
45 | if (!f.exists())
46 | throw new IllegalStateException("Please download and unzip the text8 example from http://mattmahoney.net/dc/text8.zip");
47 | List read = Common.readToList(f);
48 | List> partitioned = Lists.transform(read, new Function>() {
49 | @Override
50 | public List apply(String input) {
51 | return Arrays.asList(input.split(" "));
52 | }
53 | });
54 |
55 | Word2VecModel model = Word2VecModel.trainer()
56 | .setMinVocabFrequency(5)
57 | .useNumThreads(20)
58 | .setWindowSize(8)
59 | .type(NeuralNetworkType.CBOW)
60 | .setLayerSize(200)
61 | .useNegativeSamples(25)
62 | .setDownSamplingRate(1e-4)
63 | .setNumIterations(5)
64 | .setListener(new TrainingProgressListener() {
65 | @Override public void update(Stage stage, double progress) {
66 | System.out.println(String.format("%s is %.2f%% complete", Format.formatEnum(stage), progress * 100));
67 | }
68 | })
69 | .train(partitioned);
70 |
71 | // Writes model to a thrift file
72 | try (ProfilingTimer timer = ProfilingTimer.create(LOG, "Writing output to file")) {
73 | FileUtils.writeStringToFile(new File("text8.model"), ThriftUtils.serializeJson(model.toThrift()));
74 | }
75 |
76 | // Alternatively, you can write the model to a bin file that's compatible with the C
77 | // implementation.
78 | try(final OutputStream os = Files.newOutputStream(Paths.get("text8.bin"))) {
79 | model.toBinFile(os);
80 | }
81 |
82 | interact(model.forSearch());
83 | }
84 |
85 | /** Loads a model and allows user to find similar words */
86 | public static void loadModel() throws IOException, TException, UnknownWordException {
87 | final Word2VecModel model;
88 | try (ProfilingTimer timer = ProfilingTimer.create(LOG, "Loading model")) {
89 | String json = Common.readFileToString(new File("text8.model"));
90 | model = Word2VecModel.fromThrift(ThriftUtils.deserializeJson(new Word2VecModelThrift(), json));
91 | }
92 | interact(model.forSearch());
93 | }
94 |
95 | /** Example using Skip-Gram model */
96 | public static void skipGram() throws IOException, TException, InterruptedException, UnknownWordException {
97 | List read = Common.readToList(new File("sents.cleaned.word2vec.txt"));
98 | List> partitioned = Lists.transform(read, new Function>() {
99 | @Override
100 | public List apply(String input) {
101 | return Arrays.asList(input.split(" "));
102 | }
103 | });
104 |
105 | Word2VecModel model = Word2VecModel.trainer()
106 | .setMinVocabFrequency(100)
107 | .useNumThreads(20)
108 | .setWindowSize(7)
109 | .type(NeuralNetworkType.SKIP_GRAM)
110 | .useHierarchicalSoftmax()
111 | .setLayerSize(300)
112 | .useNegativeSamples(0)
113 | .setDownSamplingRate(1e-3)
114 | .setNumIterations(5)
115 | .setListener(new TrainingProgressListener() {
116 | @Override public void update(Stage stage, double progress) {
117 | System.out.println(String.format("%s is %.2f%% complete", Format.formatEnum(stage), progress * 100));
118 | }
119 | })
120 | .train(partitioned);
121 |
122 | try (ProfilingTimer timer = ProfilingTimer.create(LOG, "Writing output to file")) {
123 | FileUtils.writeStringToFile(new File("300layer.20threads.5iter.model"), ThriftUtils.serializeJson(model.toThrift()));
124 | }
125 |
126 | interact(model.forSearch());
127 | }
128 |
129 | private static void interact(Searcher searcher) throws IOException, UnknownWordException {
130 | try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in))) {
131 | while (true) {
132 | System.out.print("Enter word or sentence (EXIT to break): ");
133 | String word = br.readLine();
134 | if (word.equals("EXIT")) {
135 | break;
136 | }
137 | List matches = searcher.getMatches(word, 20);
138 | System.out.println(Strings.joinObjects("\n", matches));
139 | }
140 | }
141 | }
142 | }
143 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/Word2VecModel.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec;
2 |
3 | import java.io.File;
4 | import java.io.FileInputStream;
5 | import java.io.IOException;
6 | import java.io.OutputStream;
7 | import java.nio.ByteBuffer;
8 | import java.nio.ByteOrder;
9 | import java.nio.DoubleBuffer;
10 | import java.nio.FloatBuffer;
11 | import java.nio.MappedByteBuffer;
12 | import java.nio.channels.FileChannel;
13 | import java.nio.charset.Charset;
14 | import java.util.ArrayList;
15 | import java.util.List;
16 |
17 | import com.google.common.annotations.VisibleForTesting;
18 | import com.google.common.base.Preconditions;
19 | import com.google.common.collect.ImmutableList;
20 | import com.google.common.collect.Lists;
21 | import com.google.common.primitives.Doubles;
22 | import com.medallia.word2vec.thrift.Word2VecModelThrift;
23 | import com.medallia.word2vec.util.Common;
24 | import com.medallia.word2vec.util.ProfilingTimer;
25 | import com.medallia.word2vec.util.AC;
26 |
27 |
28 | /**
29 | * Represents the Word2Vec model, containing vectors for each word
30 | *
31 | * Instances of this class are obtained via:
 *
 * <ul>
 * <li>{@link #trainer()}</li>
 * <li>{@link #fromThrift(Word2VecModelThrift)}</li>
 * </ul>
 *
37 | * @see {@link #forSearch()}
38 | */
39 | public class Word2VecModel {
40 | final List vocab;
41 | final int layerSize;
42 | final DoubleBuffer vectors;
43 | private final static long ONE_GB = 1024 * 1024 * 1024;
44 |
45 | Word2VecModel(Iterable vocab, int layerSize, DoubleBuffer vectors) {
46 | this.vocab = ImmutableList.copyOf(vocab);
47 | this.layerSize = layerSize;
48 | this.vectors = vectors;
49 | }
50 |
51 | Word2VecModel(Iterable vocab, int layerSize, double[] vectors) {
52 | this(vocab, layerSize, DoubleBuffer.wrap(vectors));
53 | }
54 |
55 | /** @return Vocabulary */
56 | public Iterable getVocab() {
57 | return vocab;
58 | }
59 |
60 | /** @return Layer size */
61 | public int getLayerSize() {
62 | return layerSize;
63 | }
64 |
65 | /** @return {@link Searcher} for searching */
66 | public Searcher forSearch() {
67 | return new SearcherImpl(this);
68 | }
69 |
70 | /** @return Serializable thrift representation */
71 | public Word2VecModelThrift toThrift() {
72 | double[] vectorsArray;
73 | if(vectors.hasArray()) {
74 | vectorsArray = vectors.array();
75 | } else {
76 | vectorsArray = new double[vectors.limit()];
77 | vectors.position(0);
78 | vectors.get(vectorsArray);
79 | }
80 |
81 | return new Word2VecModelThrift()
82 | .setVocab(vocab)
83 | .setLayerSize(layerSize)
84 | .setVectors(Doubles.asList(vectorsArray));
85 | }
86 |
87 | /** @return {@link Word2VecModel} created from a thrift representation */
88 | public static Word2VecModel fromThrift(Word2VecModelThrift thrift) {
89 | return new Word2VecModel(
90 | thrift.getVocab(),
91 | thrift.getLayerSize(),
92 | Doubles.toArray(thrift.getVectors()));
93 | }
94 |
95 | /**
96 | * @return {@link Word2VecModel} read from a file in the text output format of the Word2Vec C
97 | * open source project.
98 | */
99 | public static Word2VecModel fromTextFile(File file) throws IOException {
100 | List lines = Common.readToList(file);
101 | return fromTextFile(file.getAbsolutePath(), lines);
102 | }
103 |
104 | /**
105 | * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with the default
106 | * ByteOrder.LITTLE_ENDIAN and no ProfilingTimer
107 | */
108 | public static Word2VecModel fromBinFile(File file)
109 | throws IOException {
110 | return fromBinFile(file, ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE);
111 | }
112 |
113 | /**
114 | * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with no ProfilingTimer
115 | */
116 | public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder)
117 | throws IOException {
118 | return fromBinFile(file, byteOrder, ProfilingTimer.NONE);
119 | }
120 |
121 | /**
122 | * @return {@link Word2VecModel} created from the binary representation output
123 | * by the open source C version of word2vec using the given byte order.
124 | */
125 | public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer)
126 | throws IOException {
127 |
128 | try (
129 | final FileInputStream fis = new FileInputStream(file);
130 | final AC ac = timer.start("Loading vectors from bin file")
131 | ) {
132 | final FileChannel channel = fis.getChannel();
133 | timer.start("Reading gigabyte #1");
134 | MappedByteBuffer buffer =
135 | channel.map(
136 | FileChannel.MapMode.READ_ONLY,
137 | 0,
138 | Math.min(channel.size(), Integer.MAX_VALUE));
139 | buffer.order(byteOrder);
140 | int bufferCount = 1;
141 | // Java's NIO only allows memory-mapping up to 2GB. To work around this problem, we re-map
142 | // every gigabyte. To calculate offsets correctly, we have to keep track how many gigabytes
143 | // we've already skipped. That's what this is for.
144 |
145 | StringBuilder sb = new StringBuilder();
146 | char c = (char) buffer.get();
147 | while (c != '\n') {
148 | sb.append(c);
149 | c = (char) buffer.get();
150 | }
151 | String firstLine = sb.toString();
152 | int index = firstLine.indexOf(' ');
153 | Preconditions.checkState(index != -1,
154 | "Expected a space in the first line of file '%s': '%s'",
155 | file.getAbsolutePath(), firstLine);
156 |
157 | final int vocabSize = Integer.parseInt(firstLine.substring(0, index));
158 | final int layerSize = Integer.parseInt(firstLine.substring(index + 1));
159 | timer.appendToLog(String.format(
160 | "Loading %d vectors with dimensionality %d",
161 | vocabSize,
162 | layerSize));
163 |
164 | List vocabs = new ArrayList(vocabSize);
165 | DoubleBuffer vectors = ByteBuffer.allocateDirect(vocabSize * layerSize * 8).asDoubleBuffer();
166 |
167 | long lastLogMessage = System.currentTimeMillis();
168 | final float[] floats = new float[layerSize];
169 | for (int lineno = 0; lineno < vocabSize; lineno++) {
170 | // read vocab
171 | sb.setLength(0);
172 | c = (char) buffer.get();
173 | while (c != ' ') {
174 | // ignore newlines in front of words (some binary files have newline,
175 | // some don't)
176 | if (c != '\n') {
177 | sb.append(c);
178 | }
179 | c = (char) buffer.get();
180 | }
181 | vocabs.add(sb.toString());
182 |
183 | // read vector
184 | final FloatBuffer floatBuffer = buffer.asFloatBuffer();
185 | floatBuffer.get(floats);
186 | for (int i = 0; i < floats.length; ++i) {
187 | vectors.put(lineno * layerSize + i, floats[i]);
188 | }
189 | buffer.position(buffer.position() + 4 * layerSize);
190 |
191 | // print log
192 | final long now = System.currentTimeMillis();
193 | if (now - lastLogMessage > 1000) {
194 | final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0;
195 | timer.appendToLog(
196 | String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage));
197 | lastLogMessage = now;
198 | }
199 |
200 | // remap file
201 | if (buffer.position() > ONE_GB) {
202 | final int newPosition = (int) (buffer.position() - ONE_GB);
203 | final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE);
204 | timer.endAndStart(
205 | "Reading gigabyte #%d. Start: %d, size: %d",
206 | bufferCount,
207 | ONE_GB * bufferCount,
208 | size);
209 | buffer = channel.map(
210 | FileChannel.MapMode.READ_ONLY,
211 | ONE_GB * bufferCount,
212 | size);
213 | buffer.order(byteOrder);
214 | buffer.position(newPosition);
215 | bufferCount += 1;
216 | }
217 | }
218 | timer.end();
219 |
220 | return new Word2VecModel(vocabs, layerSize, vectors);
221 | }
222 | }
223 |
224 | /**
225 | * Saves the model as a bin file that's compatible with the C version of Word2Vec
226 | */
227 | public void toBinFile(final OutputStream out) throws IOException {
228 | final Charset cs = Charset.forName("UTF-8");
229 | final String header = String.format("%d %d\n", vocab.size(), layerSize);
230 | out.write(header.getBytes(cs));
231 |
232 | final double[] vector = new double[layerSize];
233 | final ByteBuffer buffer = ByteBuffer.allocate(4 * layerSize);
234 | buffer.order(ByteOrder.LITTLE_ENDIAN); // The C version uses this byte order.
235 | for(int i = 0; i < vocab.size(); ++i) {
236 | out.write(String.format("%s ", vocab.get(i)).getBytes(cs));
237 |
238 | vectors.position(i * layerSize);
239 | vectors.get(vector);
240 | buffer.clear();
241 | for(int j = 0; j < layerSize; ++j)
242 | buffer.putFloat((float)vector[j]);
243 | out.write(buffer.array());
244 |
245 | out.write('\n');
246 | }
247 |
248 | out.flush();
249 | }
250 |
251 | /**
252 | * @return {@link Word2VecModel} from the lines of the file in the text output format of the
253 | * Word2Vec C open source project.
254 | */
255 | @VisibleForTesting
256 | static Word2VecModel fromTextFile(String filename, List lines) throws IOException {
257 | List vocab = Lists.newArrayList();
258 | List vectors = Lists.newArrayList();
259 | int vocabSize = Integer.parseInt(lines.get(0).split(" ")[0]);
260 | int layerSize = Integer.parseInt(lines.get(0).split(" ")[1]);
261 |
262 | Preconditions.checkArgument(
263 | vocabSize == lines.size() - 1,
264 | "For file '%s', vocab size is %s, but there are %s word vectors in the file",
265 | filename,
266 | vocabSize,
267 | lines.size() - 1
268 | );
269 |
270 | for (int n = 1; n < lines.size(); n++) {
271 | String[] values = lines.get(n).split(" ");
272 | vocab.add(values[0]);
273 |
274 | // Sanity check
275 | Preconditions.checkArgument(
276 | layerSize == values.length - 1,
277 | "For file '%s', on line %s, layer size is %s, but found %s values in the word vector",
278 | filename,
279 | n,
280 | layerSize,
281 | values.length - 1
282 | );
283 |
284 | for (int d = 1; d < values.length; d++) {
285 | vectors.add(Double.parseDouble(values[d]));
286 | }
287 | }
288 |
289 | Word2VecModelThrift thrift = new Word2VecModelThrift()
290 | .setLayerSize(layerSize)
291 | .setVocab(vocab)
292 | .setVectors(vectors);
293 | return fromThrift(thrift);
294 | }
295 |
296 | /** @return {@link Word2VecTrainerBuilder} for training a model */
297 | public static Word2VecTrainerBuilder trainer() {
298 | return new Word2VecTrainerBuilder();
299 | }
300 | }
301 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/Word2VecTrainer.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec;
2 |
3 | import com.google.common.base.Optional;
4 | import com.google.common.base.Predicate;
5 | import com.google.common.collect.HashMultiset;
6 | import com.google.common.collect.ImmutableMultiset;
7 | import com.google.common.collect.ImmutableSortedMultiset;
8 | import com.google.common.collect.Iterables;
9 | import com.google.common.collect.Multiset;
10 | import com.google.common.collect.Multisets;
11 | import com.google.common.primitives.Doubles;
12 | import com.medallia.word2vec.util.AC;
13 | import com.medallia.word2vec.util.ProfilingTimer;
14 | import org.apache.commons.logging.Log;
15 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
16 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener.Stage;
17 | import com.medallia.word2vec.huffman.HuffmanCoding;
18 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode;
19 | import com.medallia.word2vec.neuralnetwork.NeuralNetworkConfig;
20 | import com.medallia.word2vec.neuralnetwork.NeuralNetworkTrainer.NeuralNetworkModel;
21 |
22 | import java.util.List;
23 | import java.util.Map;
24 |
25 | /** Responsible for training a word2vec model */
26 | class Word2VecTrainer {
27 | private final int minFrequency;
28 | private final Optional> vocab;
29 | private final NeuralNetworkConfig neuralNetworkConfig;
30 |
31 | Word2VecTrainer(
32 | Integer minFrequency,
33 | Optional> vocab,
34 | NeuralNetworkConfig neuralNetworkConfig) {
35 | this.vocab = vocab;
36 | this.minFrequency = minFrequency;
37 | this.neuralNetworkConfig = neuralNetworkConfig;
38 | }
39 |
40 | /** @return {@link Multiset} containing unique tokens and their counts */
41 | private static Multiset count(Iterable tokens) {
42 | Multiset counts = HashMultiset.create();
43 | for (String token : tokens)
44 | counts.add(token);
45 | return counts;
46 | }
47 |
48 | /** @return Tokens with their count, sorted by frequency decreasing, then lexicographically ascending */
49 | private ImmutableMultiset filterAndSort(final Multiset counts) {
50 | // This isn't terribly efficient, but it is deterministic
51 | // Unfortunately, Guava's multiset doesn't give us a clean way to order both by count and element
52 | return Multisets.copyHighestCountFirst(
53 | ImmutableSortedMultiset.copyOf(
54 | Multisets.filter(
55 | counts,
56 | new Predicate() {
57 | @Override
58 | public boolean apply(String s) {
59 | return counts.count(s) >= minFrequency;
60 | }
61 | }
62 | )
63 | )
64 | );
65 |
66 | }
67 |
68 | /** Train a model using the given data */
69 | Word2VecModel train(Log log, TrainingProgressListener listener, Iterable> sentences) throws InterruptedException {
70 | try (ProfilingTimer timer = ProfilingTimer.createLoggingSubtasks(log, "Training word2vec")) {
71 | final Multiset counts;
72 |
73 | try (AC ac = timer.start("Acquiring word frequencies")) {
74 | listener.update(Stage.ACQUIRE_VOCAB, 0.0);
75 | counts = (vocab.isPresent())
76 | ? vocab.get()
77 | : count(Iterables.concat(sentences));
78 | }
79 |
80 | final ImmutableMultiset vocab;
81 | try (AC ac = timer.start("Filtering and sorting vocabulary")) {
82 | listener.update(Stage.FILTER_SORT_VOCAB, 0.0);
83 | vocab = filterAndSort(counts);
84 | }
85 |
86 | final Map huffmanNodes;
87 | try (AC task = timer.start("Create Huffman encoding")) {
88 | huffmanNodes = new HuffmanCoding(vocab, listener).encode();
89 | }
90 |
91 | final NeuralNetworkModel model;
92 | try (AC task = timer.start("Training model %s", neuralNetworkConfig)) {
93 | model = neuralNetworkConfig.createTrainer(vocab, huffmanNodes, listener).train(sentences);
94 | }
95 |
96 | return new Word2VecModel(vocab.elementSet(), model.layerSize(), Doubles.concat(model.vectors()));
97 | }
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/Word2VecTrainerBuilder.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec;
2 |
3 | import com.google.common.base.MoreObjects;
4 | import com.google.common.base.Optional;
5 | import com.google.common.base.Preconditions;
6 | import com.google.common.collect.Multiset;
7 | import com.medallia.word2vec.util.AutoLog;
8 | import org.apache.commons.logging.Log;
9 | import com.medallia.word2vec.neuralnetwork.NeuralNetworkConfig;
10 | import com.medallia.word2vec.neuralnetwork.NeuralNetworkType;
11 |
12 | import java.util.List;
13 | import java.util.Map;
14 |
15 | /**
16 | * Builder pattern for training a new {@link Word2VecModel}
17 | *
18 | * This is a port of the open source C implementation of word2vec
19 | *
20 | * Note that this isn't a completely faithful rewrite, specifically:
21 | *
 * <ul>
 * <li>When building the vocabulary from the training file:
 * <ul>
 * <li>The original version does a reduction step when learning the vocabulary from the file
 * when the vocab size hits 21 million words, removing any words that do not meet the
 * minimum frequency threshold. This Java port has no such reduction step.</li>
 * <li>The original version injects a &lt;/s&gt; token into the vocabulary (with a word count of 0)
 * as a substitute for newlines in the input file. This Java port's vocabulary excludes the token.</li>
 * </ul>
 * </li>
 * <li>In partitioning the file for processing
 * <ul>
 * <li>The original version assumes that sentences are delimited by newline characters and injects a sentence
 * boundary per 1000 non-filtered tokens, i.e. valid token by the vocabulary and not removed by the randomized
 * sampling process. Java port mimics this behavior for now ...</li>
 * <li>When the original version encounters an empty line in the input file, it re-processes the first word of the
 * last non-empty line with a sentence length of 0 and updates the random value. Java port omits this behavior.</li>
 * </ul>
 * </li>
 * <li>In the sampling function
 * <ul>
 * <li>The original C documentation indicates that the range should be between 0 and 1e-5, but the default value is 1e-3.
 * This Java port retains that confusing information.</li>
 * <li>The random value generated for comparison to determine if a token should be filtered uses a float.
 * This Java port uses double precision for twice the fun</li>
 * </ul>
 * </li>
 * <li>In the distance function to find the nearest matches to a target query
 * <ul>
 * <li>The original version includes an unnecessary normalization of the vector for the input query which
 * may lead to tiny inaccuracies. This Java port foregoes this superfluous operation.</li>
 * <li>The original version has an O(n * k) algorithm for finding top matches and is hardcoded to 40 matches.
 * This Java port uses Google's lovely {@link com.google.common.collect.Ordering#greatestOf(java.util.Iterator, int)}
 * which is O(n + k log k) and takes in arbitrary k.</li>
 * </ul>
 * </li>
 * <li>The k-means clustering option is excluded in the Java port</li>
 * </ul>
 *
58 | * Please do not hesitate to peek at the source code.
59 | *
60 | * It should be readable, concise, and correct.
61 | *
62 | * I ask you to reach out if it is not.
63 | */
64 | public class Word2VecTrainerBuilder {
65 | private static final Log LOG = AutoLog.getLog();
66 |
67 | private Integer layerSize;
68 | private Integer windowSize;
69 | private Integer numThreads;
70 | private NeuralNetworkType type;
71 | private int negativeSamples;
72 | private boolean useHierarchicalSoftmax;
73 | private Multiset vocab;
74 | private Integer minFrequency;
75 | private Double initialLearningRate;
76 | private Double downSampleRate;
77 | private Integer iterations;
78 | private TrainingProgressListener listener;
79 |
80 | Word2VecTrainerBuilder() {
81 | }
82 |
83 | /**
84 | * Size of the layers in the neural network model
85 | *
86 | * Defaults to 100
87 | */
88 | public Word2VecTrainerBuilder setLayerSize(int layerSize) {
89 | Preconditions.checkArgument(layerSize > 0, "Value must be positive");
90 | this.layerSize = layerSize;
91 | return this;
92 | }
93 |
94 | /**
95 | * Size of the window to consider
96 | *
97 | * Default window size is 5 tokens
98 | */
99 | public Word2VecTrainerBuilder setWindowSize(int windowSize) {
100 | Preconditions.checkArgument(windowSize > 0, "Value must be positive");
101 | this.windowSize = windowSize;
102 | return this;
103 | }
104 |
105 | /**
106 | * Specify number of threads to use for parallelization
107 | *
119 | * By default, word2vec uses the {@link NeuralNetworkType#SKIP_GRAM}
120 | */
121 | public Word2VecTrainerBuilder type(NeuralNetworkType type) {
122 | this.type = Preconditions.checkNotNull(type);
123 | return this;
124 | }
125 |
126 | /**
127 | * Specify to use hierarchical softmax
128 | *
129 | * By default, word2vec does not use hierarchical softmax
130 | */
131 | public Word2VecTrainerBuilder useHierarchicalSoftmax() {
132 | this.useHierarchicalSoftmax = true;
133 | return this;
134 | }
135 |
136 | /**
137 | * Number of negative samples to use
138 | * Common values are between 5 and 10
139 | *
140 | * Defaults to 0
141 | */
142 | public Word2VecTrainerBuilder useNegativeSamples(int negativeSamples) {
143 | Preconditions.checkArgument(negativeSamples >= 0, "Value must be non-negative");
144 | this.negativeSamples = negativeSamples;
145 | return this;
146 | }
147 |
148 | /**
149 | * Use a pre-built vocabulary
150 | *
151 | * If this is not specified, word2vec will attempt to learn a vocabulary from the training data
152 | * @param vocab {@link Map} from token to frequency
153 | */
154 | public Word2VecTrainerBuilder useVocab(Multiset vocab) {
155 | this.vocab = Preconditions.checkNotNull(vocab);
156 | return this;
157 | }
158 |
159 | /**
160 | * Specify the minimum frequency for a valid token to be considered
161 | * part of the vocabulary
162 | *
163 | * Defaults to 5
164 | */
165 | public Word2VecTrainerBuilder setMinVocabFrequency(int minFrequency) {
166 | Preconditions.checkArgument(minFrequency >= 0, "Value must be non-negative");
167 | this.minFrequency = minFrequency;
168 | return this;
169 | }
170 |
171 | /**
172 | * Set the starting learning rate
173 | *
174 | * Default is 0.025 for skip-gram and 0.05 for CBOW
175 | */
176 | public Word2VecTrainerBuilder setInitialLearningRate(double initialLearningRate) {
177 | Preconditions.checkArgument(initialLearningRate >= 0, "Value must be non-negative");
178 | this.initialLearningRate = initialLearningRate;
179 | return this;
180 | }
181 |
182 | /**
183 | * Set threshold for occurrence of words. Those that appear with higher frequency in the training data,
184 | * e.g. stopwords, will be randomly removed
185 | *
259 | * Note that this is called in a separate thread from the processing thread
260 | * @param stage Current {@link Stage} of processing
261 | * @param progress Progress of the current stage as a double value between 0 and 1
262 | */
263 | void update(Stage stage, double progress);
264 | }
265 | }
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/huffman/HuffmanCoding.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.huffman;
2 |
3 | import com.google.common.base.Preconditions;
4 | import com.google.common.collect.ImmutableMap;
5 | import com.google.common.collect.ImmutableMultiset;
6 | import com.google.common.collect.Multiset.Entry;
7 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
8 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener.Stage;
9 |
10 | import java.util.ArrayList;
11 | import java.util.Map;
12 |
13 | /**
14 | * Word2Vec library relies on a Huffman encoding scheme
15 | *
16 | * Note that the generated codes and the index of the parents are both used in the
17 | * hierarchical softmax portions of the neural network training phase
18 | *
19 | */
20 | public class HuffmanCoding {
21 | /** Node */
22 | public static class HuffmanNode {
23 | /** Array of 0's and 1's */
24 | public final byte[] code;
25 | /** Array of parent node index offsets */
26 | public final int[] point;
27 | /** Index of the Huffman node */
28 | public final int idx;
29 | /** Frequency of the token */
30 | public final int count;
31 |
32 | private HuffmanNode(byte[] code, int[] point, int idx, int count) {
33 | this.code = code;
34 | this.point = point;
35 | this.idx = idx;
36 | this.count = count;
37 | }
38 | }
39 |
40 | private final ImmutableMultiset vocab;
41 | private final TrainingProgressListener listener;
42 |
43 | /**
44 | * @param vocab {@link Multiset} of tokens, sorted by frequency descending
45 | * @param listener Progress listener
46 | */
47 | public HuffmanCoding(ImmutableMultiset vocab, TrainingProgressListener listener) {
48 | this.vocab = vocab;
49 | this.listener = listener;
50 | }
51 |
52 | /**
53 | * @return {@link Map} from each given token to a {@link HuffmanNode}
54 | */
55 | public Map encode() throws InterruptedException {
56 | final int numTokens = vocab.elementSet().size();
57 |
58 | int[] parentNode = new int[numTokens * 2 + 1];
59 | byte[] binary = new byte[numTokens * 2 + 1];
60 | long[] count = new long[numTokens * 2 + 1];
61 | int i = 0;
62 | for (Entry e : vocab.entrySet()) {
63 | count[i] = e.getCount();
64 | i++;
65 | }
66 | Preconditions.checkState(i == numTokens, "Expected %s to match %s", i, numTokens);
67 | for (i = numTokens; i < count.length; i++)
68 | count[i] = (long)1e15;
69 |
70 | createTree(numTokens, count, binary, parentNode);
71 |
72 | return encode(binary, parentNode);
73 | }
74 |
75 | /**
76 | * Populate the count, binary, and parentNode arrays with the Huffman tree
77 | * This uses the linear time method assuming that the count array is sorted
78 | */
79 | private void createTree(int numTokens, long[] count, byte[] binary, int[] parentNode) throws InterruptedException {
80 | int min1i;
81 | int min2i;
82 | int pos1 = numTokens - 1;
83 | int pos2 = numTokens;
84 |
85 | // Construct the Huffman tree by adding one node at a time
86 | for (int a = 0; a < numTokens - 1; a++) {
87 | // First, find two smallest nodes 'min1, min2'
88 | if (pos1 >= 0) {
89 | if (count[pos1] < count[pos2]) {
90 | min1i = pos1;
91 | pos1--;
92 | } else {
93 | min1i = pos2;
94 | pos2++;
95 | }
96 | } else {
97 | min1i = pos2;
98 | pos2++;
99 | }
100 |
101 | if (pos1 >= 0) {
102 | if (count[pos1] < count[pos2]) {
103 | min2i = pos1;
104 | pos1--;
105 | } else {
106 | min2i = pos2;
107 | pos2++;
108 | }
109 | } else {
110 | min2i = pos2;
111 | pos2++;
112 | }
113 |
114 | int newNodeIdx = numTokens + a;
115 | count[newNodeIdx] = count[min1i] + count[min2i];
116 | parentNode[min1i] = newNodeIdx;
117 | parentNode[min2i] = newNodeIdx;
118 | binary[min2i] = 1;
119 |
120 | if (a % 1_000 == 0) {
121 | if (Thread.currentThread().isInterrupted())
122 | throw new InterruptedException("Interrupted while encoding huffman tree");
123 | listener.update(Stage.CREATE_HUFFMAN_ENCODING, (0.5 * a) / numTokens);
124 | }
125 | }
126 | }
127 |
128 | /** @return Ordered map from each token to its {@link HuffmanNode}, ordered by frequency descending */
129 | private Map encode(byte[] binary, int[] parentNode) throws InterruptedException {
130 | int numTokens = vocab.elementSet().size();
131 |
132 | // Now assign binary code to each unique token
133 | ImmutableMap.Builder result = ImmutableMap.builder();
134 | int nodeIdx = 0;
135 | for (Entry e : vocab.entrySet()) {
136 | int curNodeIdx = nodeIdx;
137 | ArrayList code = new ArrayList<>();
138 | ArrayList points = new ArrayList<>();
139 | while (true) {
140 | code.add(binary[curNodeIdx]);
141 | points.add(curNodeIdx);
142 | curNodeIdx = parentNode[curNodeIdx];
143 | if (curNodeIdx == numTokens * 2 - 2)
144 | break;
145 | }
146 | int codeLen = code.size();
147 | final int count = e.getCount();
148 | final byte[] rawCode = new byte[codeLen];
149 | final int[] rawPoints = new int[codeLen + 1];
150 |
151 | rawPoints[0] = numTokens - 2;
152 | for (int i = 0; i < codeLen; i++) {
153 | rawCode[codeLen - i - 1] = code.get(i);
154 | rawPoints[codeLen - i] = points.get(i) - numTokens;
155 | }
156 |
157 | String token = e.getElement();
158 | result.put(token, new HuffmanNode(rawCode, rawPoints, nodeIdx, count));
159 |
160 | if (nodeIdx % 1_000 == 0) {
161 | if (Thread.currentThread().isInterrupted())
162 | throw new InterruptedException("Interrupted while encoding huffman tree");
163 | listener.update(Stage.CREATE_HUFFMAN_ENCODING, 0.5 + (0.5 * nodeIdx) / numTokens);
164 | }
165 |
166 | nodeIdx++;
167 | }
168 |
169 | return result.build();
170 | }
171 | }
172 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/neuralnetwork/CBOWModelTrainer.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.neuralnetwork;
2 |
3 | import com.google.common.collect.Multiset;
4 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
5 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode;
6 |
7 | import java.util.List;
8 | import java.util.Map;
9 |
/**
 * Trainer for neural network using continuous bag of words
 *
 * <p>For each position in a sentence, the (randomly shrunk) window of surrounding
 * context words is averaged into a single hidden vector, which is then trained
 * against the center word via hierarchical softmax and/or negative sampling.
 */
class CBOWModelTrainer extends NeuralNetworkTrainer {

  CBOWModelTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener) {
    super(config, counts, huffmanNodes, listener);
  }

  /** {@link Worker} for {@link CBOWModelTrainer} */
  private class CBOWWorker extends Worker {
    private CBOWWorker(int randomSeed, int iter, Iterable> batch) {
      super(randomSeed, iter, batch);
    }

    @Override void trainSentence(List sentence) {
      int sentenceLength = sentence.size();

      for (int sentencePosition = 0; sentencePosition < sentenceLength; sentencePosition++) {
        String word = sentence.get(sentencePosition);
        HuffmanNode huffmanNode = huffmanNodes.get(word);

        // Reset the hidden-layer accumulator and its error vector
        for (int c = 0; c < layer1_size; c++)
          neu1[c] = 0;
        for (int c = 0; c < layer1_size; c++)
          neu1e[c] = 0;

        nextRandom = incrementRandom(nextRandom);
        // Random window shrink: b is in [0, window); the double mod keeps the
        // result non-negative even though nextRandom may be negative.
        int b = (int)((nextRandom % window) + window) % window;

        // in -> hidden
        // Average the input vectors of the context words (positions within
        // [sentencePosition - window + b, sentencePosition + window - b],
        // excluding the center word itself).
        int cw = 0;
        for (int a = b; a < window * 2 + 1 - b; a++) {
          if (a == window)
            continue;
          int c = sentencePosition - window + a;
          if (c < 0 || c >= sentenceLength)
            continue;
          int idx = huffmanNodes.get(sentence.get(c)).idx;
          for (int d = 0; d < layer1_size; d++) {
            neu1[d] += syn0[idx][d];
          }

          cw++;
        }

        // No context words fell inside the sentence: nothing to train on
        if (cw == 0)
          continue;

        for (int c = 0; c < layer1_size; c++)
          neu1[c] /= cw;

        if (config.useHierarchicalSoftmax) {
          // One binary-classification step per node on the center word's Huffman path
          for (int d = 0; d < huffmanNode.code.length; d++) {
            double f = 0;
            int l2 = huffmanNode.point[d];
            // Propagate hidden -> output
            for (int c = 0; c < layer1_size; c++)
              f += neu1[c] * syn1[l2][c];
            // Outside the precomputed sigmoid range: skip this node entirely
            if (f <= -MAX_EXP || f >= MAX_EXP)
              continue;
            else
              f = EXP_TABLE[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
            // 'g' is the gradient multiplied by the learning rate
            double g = (1 - huffmanNode.code[d] - f) * alpha;
            // Propagate errors output -> hidden
            for (int c = 0; c < layer1_size; c++)
              neu1e[c] += g * syn1[l2][c];
            // Learn weights hidden -> output
            for (int c = 0; c < layer1_size; c++)
              syn1[l2][c] += g * neu1[c];
          }
        }

        handleNegativeSampling(huffmanNode);

        // hidden -> in
        // Distribute the accumulated error back onto every context word's input vector
        for (int a = b; a < window * 2 + 1 - b; a++) {
          if (a == window)
            continue;
          int c = sentencePosition - window + a;
          if (c < 0 || c >= sentenceLength)
            continue;
          int idx = huffmanNodes.get(sentence.get(c)).idx;
          for (int d = 0; d < layer1_size; d++)
            syn0[idx][d] += neu1e[d];
        }
      }
    }
  }

  @Override Worker createWorker(int randomSeed, int iter, Iterable> batch) {
    return new CBOWWorker(randomSeed, iter, batch);
  }
}
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkConfig.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.neuralnetwork;
2 |
3 | import com.google.common.collect.ImmutableMultiset;
4 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
5 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode;
6 |
7 | import java.util.Map;
8 |
9 | /** Fixed configuration for training the neural network */
10 | public class NeuralNetworkConfig {
11 | final int numThreads;
12 | final int iterations;
13 | final NeuralNetworkType type;
14 | final int layerSize;
15 | final int windowSize;
16 | final int negativeSamples;
17 | final boolean useHierarchicalSoftmax;
18 |
19 | final double initialLearningRate;
20 | final double downSampleRate;
21 |
22 | /** Constructor */
23 | public NeuralNetworkConfig(
24 | NeuralNetworkType type,
25 | int numThreads,
26 | int iterations,
27 | int layerSize,
28 | int windowSize,
29 | int negativeSamples,
30 | double downSampleRate,
31 | double initialLearningRate,
32 | boolean useHierarchicalSoftmax) {
33 | this.type = type;
34 | this.iterations = iterations;
35 | this.numThreads = numThreads;
36 | this.layerSize = layerSize;
37 | this.windowSize = windowSize;
38 | this.negativeSamples = negativeSamples;
39 | this.useHierarchicalSoftmax = useHierarchicalSoftmax;
40 | this.initialLearningRate = initialLearningRate;
41 | this.downSampleRate = downSampleRate;
42 | }
43 |
44 | /** @return {@link NeuralNetworkTrainer} */
45 | public NeuralNetworkTrainer createTrainer(ImmutableMultiset vocab, Map huffmanNodes, TrainingProgressListener listener) {
46 | return type.createTrainer(this, vocab, huffmanNodes, listener);
47 | }
48 |
49 | @Override public String toString() {
50 | return String.format("%s with %s threads, %s iterations[%s layer size, %s window, %s hierarchical softmax, %s negative samples, %s initial learning rate, %s down sample rate]",
51 | type.name(),
52 | numThreads,
53 | iterations,
54 | layerSize,
55 | windowSize,
56 | useHierarchicalSoftmax ? "using" : "not using",
57 | negativeSamples,
58 | initialLearningRate,
59 | downSampleRate
60 | );
61 | }
62 | }
63 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.neuralnetwork;
2 |
3 | import com.google.common.collect.Iterables;
4 | import com.google.common.collect.Multiset;
5 | import com.google.common.util.concurrent.Futures;
6 | import com.google.common.util.concurrent.ListenableFuture;
7 | import com.google.common.util.concurrent.ListeningExecutorService;
8 | import com.google.common.util.concurrent.MoreExecutors;
9 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
10 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener.Stage;
11 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode;
12 | import com.medallia.word2vec.util.CallableVoid;
13 |
14 | import java.util.ArrayList;
15 | import java.util.Iterator;
16 | import java.util.List;
17 | import java.util.Map;
18 | import java.util.concurrent.ExecutionException;
19 | import java.util.concurrent.Executors;
20 | import java.util.concurrent.atomic.AtomicInteger;
21 |
/**
 * Parent class for training word2vec's neural network.
 *
 * <p>Holds the shared model state (input vectors {@code syn0}, hierarchical-softmax
 * weights {@code syn1}, negative-sampling weights {@code syn1neg}) and drives a pool
 * of {@link Worker} threads over partitions of the training sentences. Subclasses
 * supply the architecture-specific {@link Worker#trainSentence} update rule.
 */
public abstract class NeuralNetworkTrainer {
  /** Sentences longer than this are broken into multiple chunks */
  private static final int MAX_SENTENCE_LENGTH = 1_000;

  /** Boundary for maximum exponent allowed */
  static final int MAX_EXP = 6;

  /** Size of the pre-cached exponent table */
  static final int EXP_TABLE_SIZE = 1_000;
  // Precomputed sigmoid values for arguments in [-MAX_EXP, MAX_EXP)
  static final double[] EXP_TABLE = new double[EXP_TABLE_SIZE];
  static {
    for (int i = 0; i < EXP_TABLE_SIZE; i++) {
      // Precompute the exp() table
      EXP_TABLE[i] = Math.exp((i / (double)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
      // Precompute f(x) = x / (x + 1)
      EXP_TABLE[i] /= EXP_TABLE[i] + 1;
    }
  }

  // Size of the unigram sampling table used for negative sampling
  private static final int TABLE_SIZE = (int)1e8;

  private final TrainingProgressListener listener;

  final NeuralNetworkConfig config;
  // Maps each vocabulary token to its Huffman coding data (code, path, index, count)
  final Map huffmanNodes;
  private final int vocabSize;
  final int layer1_size;
  final int window;
  /**
   * In the C version, this includes the token that replaces a newline character
   */
  int numTrainedTokens;

  /* The following includes shared state that is updated per worker thread */

  /**
   * To be precise, this is the number of words in the training data that exist in the vocabulary
   * which have been processed so far. It includes words that are discarded from sampling.
   * Note that each word is processed once per iteration.
   */
  protected final AtomicInteger actualWordCount;
  /** Learning rate, affects how fast values in the layers get updated */
  volatile double alpha;
  /**
   * This contains the outer layers of the neural network
   * First dimension is the vocab, second is the layer
   */
  final double[][] syn0;
  /** This contains hidden layers of the neural network */
  final double[][] syn1;
  /** This is used for negative sampling */
  private final double[][] syn1neg;
  /** Used for negative sampling */
  private final int[] table;
  long startNano;

  NeuralNetworkTrainer(NeuralNetworkConfig config, Multiset vocab, Map huffmanNodes, TrainingProgressListener listener) {
    this.config = config;
    this.huffmanNodes = huffmanNodes;
    this.listener = listener;
    this.vocabSize = huffmanNodes.size();
    this.numTrainedTokens = vocab.size();
    this.layer1_size = config.layerSize;
    this.window = config.windowSize;

    this.actualWordCount = new AtomicInteger();
    this.alpha = config.initialLearningRate;

    this.syn0 = new double[vocabSize][layer1_size];
    this.syn1 = new double[vocabSize][layer1_size];
    this.syn1neg = new double[vocabSize][layer1_size];
    this.table = new int[TABLE_SIZE];

    initializeSyn0();
    initializeUnigramTable();
  }

  /**
   * Fills {@code table} so that each token occupies a share of slots proportional
   * to count^0.75; sampling a uniform index from the table then draws tokens with
   * that smoothed unigram distribution.
   */
  private void initializeUnigramTable() {
    long trainWordsPow = 0;
    double power = 0.75;

    for (HuffmanNode node : huffmanNodes.values()) {
      trainWordsPow += Math.pow(node.count, power);
    }

    Iterator nodeIter = huffmanNodes.values().iterator();
    HuffmanNode last = nodeIter.next();
    // d1 is the cumulative probability mass covered so far
    double d1 = Math.pow(last.count, power) / trainWordsPow;
    int i = 0;
    for (int a = 0; a < TABLE_SIZE; a++) {
      table[a] = i;
      // Advance to the next token once this one's share of the table is filled
      if (a / (double)TABLE_SIZE > d1) {
        i++;
        // If we run out of tokens (rounding), keep reusing the last one
        HuffmanNode next = nodeIter.hasNext()
            ? nodeIter.next()
            : last;

        d1 += Math.pow(next.count, power) / trainWordsPow;

        last = next;
      }
    }
  }

  /** Deterministically initializes the input vectors to small values in (-0.5/layerSize, 0.5/layerSize). */
  private void initializeSyn0() {
    long nextRandom = 1;
    for (int a = 0; a < huffmanNodes.size(); a++) {
      // Consume a random for fun
      // Actually we do this to use up the injected token
      nextRandom = incrementRandom(nextRandom);
      for (int b = 0; b < layer1_size; b++) {
        nextRandom = incrementRandom(nextRandom);
        // Low 16 bits mapped to [0, 1), shifted to be centered on 0 and scaled down
        syn0[a][b] = (((nextRandom & 0xFFFF) / (double)65_536) - 0.5) / layer1_size;
      }
    }
  }

  /** @return Next random value to use */
  // Linear congruential step; presumably chosen to match the reference C
  // word2vec implementation's generator — TODO confirm against word2vec.c
  static long incrementRandom(long r) {
    return r * 25_214_903_917L + 11;
  }

  /** Represents a neural network model */
  public interface NeuralNetworkModel {
    /** Size of the layers */
    int layerSize();
    /** Resulting vectors */
    double[][] vectors();
  }

  /**
   * Trains the model over the given sentences using {@code config.numThreads}
   * workers for {@code config.iterations} passes.
   *
   * @return Trained NN model
   * @throws InterruptedException if training is interrupted
   * @throws IllegalStateException if a worker fails; the cause is the worker's exception
   */
  public NeuralNetworkModel train(Iterable> sentences) throws InterruptedException {
    ListeningExecutorService ex = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(config.numThreads));

    int numSentences = Iterables.size(sentences);
    numTrainedTokens += numSentences;

    // Partition the sentences evenly amongst the threads
    Iterable>> partitioned = Iterables.partition(sentences, numSentences / config.numThreads + 1);

    try {
      listener.update(Stage.TRAIN_NEURAL_NETWORK, 0.0);
      for (int iter = config.iterations; iter > 0; iter--) {
        List tasks = new ArrayList<>();
        int i = 0;
        for (final List> batch : partitioned) {
          // The partition index doubles as the worker's random seed
          tasks.add(createWorker(i, iter, batch));
          i++;
        }

        List> futures = new ArrayList<>(tasks.size());
        for (CallableVoid task : tasks)
          futures.add(ex.submit(task));
        try {
          // Block until every worker of this iteration has finished
          Futures.allAsList(futures).get();
        } catch (ExecutionException e) {
          throw new IllegalStateException("Error training neural network", e.getCause());
        }
      }
      ex.shutdown();
    } finally {
      // Ensures outstanding workers are cancelled if anything above threw
      ex.shutdownNow();
    }

    return new NeuralNetworkModel() {
      @Override public int layerSize() {
        return config.layerSize;
      }

      @Override public double[][] vectors() {
        return syn0;
      }
    };
  }

  /** @return {@link Worker} to process the given sentences */
  abstract Worker createWorker(int randomSeed, int iter, Iterable> batch);

  /** Worker thread that updates the neural network model */
  abstract class Worker extends CallableVoid {
    private static final int LEARNING_RATE_UPDATE_FREQUENCY = 10_000;

    long nextRandom;
    final int iter;
    final Iterable> batch;

    /**
     * The number of words observed in the training data for this worker that exist
     * in the vocabulary. It includes words that are discarded from sampling.
     */
    int wordCount;
    /** Value of wordCount the last time alpha was updated */
    int lastWordCount;

    // Per-worker scratch buffers: hidden-layer activation and its error vector
    final double[] neu1 = new double[layer1_size];
    final double[] neu1e = new double[layer1_size];

    Worker(int randomSeed, int iter, Iterable> batch) {
      this.nextRandom = randomSeed;
      this.iter = iter;
      this.batch = batch;
    }

    @Override public void run() throws InterruptedException {
      for (List sentence : batch) {
        // Keep only in-vocabulary words, optionally down-sampling frequent ones
        List filteredSentence = new ArrayList<>(sentence.size());
        for (String s : sentence) {
          if (!huffmanNodes.containsKey(s))
            continue;

          wordCount++;
          if (config.downSampleRate > 0) {
            HuffmanNode huffmanNode = huffmanNodes.get(s);
            // Keep-probability shrinks as the word's frequency grows
            double random = (Math.sqrt(huffmanNode.count / (config.downSampleRate * numTrainedTokens)) + 1)
                * (config.downSampleRate * numTrainedTokens) / huffmanNode.count;
            nextRandom = incrementRandom(nextRandom);
            if (random < (nextRandom & 0xFFFF) / (double)65_536) {
              continue;
            }
          }

          filteredSentence.add(s);
        }

        // Increment word count one extra for the injected token
        // Turns out if you don't do this, the produced word vectors aren't as tasty
        wordCount++;

        // Very long sentences are trained in MAX_SENTENCE_LENGTH-sized chunks
        Iterable> partitioned = Iterables.partition(filteredSentence, MAX_SENTENCE_LENGTH);
        for (List chunked : partitioned) {
          if (Thread.currentThread().isInterrupted())
            throw new InterruptedException("Interrupted while training word2vec model");

          if (wordCount - lastWordCount > LEARNING_RATE_UPDATE_FREQUENCY) {
            updateAlpha(iter);
          }
          trainSentence(chunked);
        }
      }

      // Flush this worker's remaining word count into the shared total
      actualWordCount.addAndGet(wordCount - lastWordCount);
    }

    /**
     * Degrades the learning rate (alpha) steadily towards 0
     * @param iter Only used for debugging
     */
    private void updateAlpha(int iter) {
      int currentActual = actualWordCount.addAndGet(wordCount - lastWordCount);
      lastWordCount = wordCount;

      // NOTE(review): config.iterations * numTrainedTokens is int arithmetic and
      // could overflow for very large corpora — confirm expected corpus sizes
      // Degrade the learning rate linearly towards 0 but keep a minimum
      alpha = config.initialLearningRate * Math.max(
          1 - currentActual / (double)(config.iterations * numTrainedTokens),
          0.0001
      );

      listener.update(
          Stage.TRAIN_NEURAL_NETWORK,
          currentActual / (double) (config.iterations * numTrainedTokens + 1)
      );
    }

    /**
     * Performs one round of negative sampling for the given target word:
     * one positive example (the word itself, d == 0) plus
     * {@code config.negativeSamples} draws from the unigram table as negatives.
     * No-ops effectively when negativeSamples is 0 beyond the positive update.
     */
    void handleNegativeSampling(HuffmanNode huffmanNode) {
      for (int d = 0; d <= config.negativeSamples; d++) {
        int target;
        final int label;
        if (d == 0) {
          target = huffmanNode.idx;
          label = 1;
        } else {
          nextRandom = incrementRandom(nextRandom);
          // Double mod keeps the table index non-negative
          target = table[(int) (((nextRandom >> 16) % TABLE_SIZE) + TABLE_SIZE) % TABLE_SIZE];
          if (target == 0)
            target = (int)(((nextRandom % (vocabSize - 1)) + vocabSize - 1) % (vocabSize - 1)) + 1;
          // Never use the positive word as its own negative sample
          if (target == huffmanNode.idx)
            continue;
          label = 0;
        }
        int l2 = target;
        double f = 0;
        for (int c = 0; c < layer1_size; c++)
          f += neu1[c] * syn1neg[l2][c];
        final double g;
        // Saturate the sigmoid outside [-MAX_EXP, MAX_EXP]; otherwise read it from the table
        if (f > MAX_EXP)
          g = (label - 1) * alpha;
        else if (f < -MAX_EXP)
          g = (label - 0) * alpha;
        else
          g = (label - EXP_TABLE[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
        for (int c = 0; c < layer1_size; c++)
          neu1e[c] += g * syn1neg[l2][c];
        for (int c = 0; c < layer1_size; c++)
          syn1neg[l2][c] += g * neu1[c];
      }
    }

    /** Update the model with the given raw sentence */
    abstract void trainSentence(List unfiltered);
  }
}
324 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkType.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.neuralnetwork;
2 |
3 | import com.google.common.collect.Multiset;
4 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
5 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode;
6 |
7 | import java.util.Map;
8 |
9 | /**
10 | * Supported types for the neural network
11 | */
12 | public enum NeuralNetworkType {
13 | /** Faster, slightly better accuracy for frequent words */
14 | CBOW {
15 | @Override NeuralNetworkTrainer createTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener) {
16 | return new CBOWModelTrainer(config, counts, huffmanNodes, listener);
17 | }
18 |
19 | @Override public double getDefaultInitialLearningRate() {
20 | return 0.05;
21 | }
22 | },
23 | /** Slower, better for infrequent words */
24 | SKIP_GRAM {
25 | @Override NeuralNetworkTrainer createTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener) {
26 | return new SkipGramModelTrainer(config, counts, huffmanNodes, listener);
27 | }
28 |
29 | @Override public double getDefaultInitialLearningRate() {
30 | return 0.025;
31 | }
32 | },
33 | ;
34 |
35 | /** @return Default initial learning rate */
36 | public abstract double getDefaultInitialLearningRate();
37 |
38 | /** @return New {@link NeuralNetworkTrainer} */
39 | abstract NeuralNetworkTrainer createTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener);
40 | }
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/neuralnetwork/SkipGramModelTrainer.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.neuralnetwork;
2 |
3 | import com.google.common.collect.Multiset;
4 | import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
5 | import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode;
6 |
7 | import java.util.List;
8 | import java.util.Map;
9 |
10 | /**
11 | * Trainer for neural network using skip gram
12 | */
13 | class SkipGramModelTrainer extends NeuralNetworkTrainer {
14 |
15 | SkipGramModelTrainer(NeuralNetworkConfig config, Multiset counts, Map huffmanNodes, TrainingProgressListener listener) {
16 | super(config, counts, huffmanNodes, listener);
17 | }
18 |
19 | /** {@link Worker} for {@link SkipGramModelTrainer} */
20 | private class SkipGramWorker extends Worker {
21 | private SkipGramWorker(int randomSeed, int iter, Iterable> batch) {
22 | super(randomSeed, iter, batch);
23 | }
24 |
25 | @Override void trainSentence(List sentence) {
26 | int sentenceLength = sentence.size();
27 |
28 | for (int sentencePosition = 0; sentencePosition < sentenceLength; sentencePosition++) {
29 | String word = sentence.get(sentencePosition);
30 | HuffmanNode huffmanNode = huffmanNodes.get(word);
31 |
32 | for (int c = 0; c < layer1_size; c++)
33 | neu1[c] = 0;
34 | for (int c = 0; c < layer1_size; c++)
35 | neu1e[c] = 0;
36 | nextRandom = incrementRandom(nextRandom);
37 |
38 | int b = (int)(((nextRandom % window) + nextRandom) % window);
39 |
40 | for (int a = b; a < window * 2 + 1 - b; a++) {
41 | if (a == window)
42 | continue;
43 | int c = sentencePosition - window + a;
44 |
45 | if (c < 0 || c >= sentenceLength)
46 | continue;
47 | for (int d = 0; d < layer1_size; d++)
48 | neu1e[d] = 0;
49 |
50 | int l1 = huffmanNodes.get(sentence.get(c)).idx;
51 |
52 | if (config.useHierarchicalSoftmax) {
53 | for (int d = 0; d < huffmanNode.code.length; d++) {
54 | double f = 0;
55 | int l2 = huffmanNode.point[d];
56 | // Propagate hidden -> output
57 | for (int e = 0; e < layer1_size; e++)
58 | f += syn0[l1][e] * syn1[l2][e];
59 |
60 | if (f <= -MAX_EXP || f >= MAX_EXP)
61 | continue;
62 | else
63 | f = EXP_TABLE[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
64 | // 'g' is the gradient multiplied by the learning rate
65 | double g = (1 - huffmanNode.code[d] - f) * alpha;
66 |
67 | // Propagate errors output -> hidden
68 | for (int e = 0; e < layer1_size; e++)
69 | neu1e[e] += g * syn1[l2][e];
70 | // Learn weights hidden -> output
71 | for (int e = 0; e < layer1_size; e++)
72 | syn1[l2][e] += g * syn0[l1][e];
73 | }
74 | }
75 |
76 | handleNegativeSampling(huffmanNode);
77 |
78 | // Learn weights input -> hidden
79 | for (int d = 0; d < layer1_size; d++) {
80 | syn0[l1][d] += neu1e[d];
81 | }
82 | }
83 | }
84 | }
85 | }
86 |
87 | @Override Worker createWorker(int randomSeed, int iter, Iterable> batch) {
88 | return new SkipGramWorker(randomSeed, iter, batch);
89 | }
90 | }
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/thrift/Word2VecModelThrift.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Autogenerated by Thrift Compiler (0.9.1)
3 | *
4 | * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
5 | * @generated
6 | */
7 | package com.medallia.word2vec.thrift;
8 |
9 | import org.apache.commons.lang3.builder.HashCodeBuilder;
10 | import org.apache.thrift.EncodingUtils;
11 | import org.apache.thrift.protocol.TTupleProtocol;
12 | import org.apache.thrift.scheme.IScheme;
13 | import org.apache.thrift.scheme.SchemeFactory;
14 | import org.apache.thrift.scheme.StandardScheme;
15 | import org.apache.thrift.scheme.TupleScheme;
16 |
17 | import java.util.ArrayList;
18 | import java.util.BitSet;
19 | import java.util.Collections;
20 | import java.util.EnumMap;
21 | import java.util.EnumSet;
22 | import java.util.HashMap;
23 | import java.util.List;
24 | import java.util.Map;
25 |
26 | public class Word2VecModelThrift implements org.apache.thrift.TBase, java.io.Serializable, Cloneable, Comparable {
  // Wire-level struct descriptor used by the Thrift serialization schemes.
  private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Word2VecModelThrift");

  // Per-field descriptors: wire name, Thrift type, and field id from the IDL.
  private static final org.apache.thrift.protocol.TField VOCAB_FIELD_DESC = new org.apache.thrift.protocol.TField("vocab", org.apache.thrift.protocol.TType.LIST, (short)1);
  private static final org.apache.thrift.protocol.TField LAYER_SIZE_FIELD_DESC = new org.apache.thrift.protocol.TField("layerSize", org.apache.thrift.protocol.TType.I32, (short)2);
  private static final org.apache.thrift.protocol.TField VECTORS_FIELD_DESC = new org.apache.thrift.protocol.TField("vectors", org.apache.thrift.protocol.TType.LIST, (short)3);

  // Registry mapping each serialization scheme (standard vs. tuple protocol)
  // to the factory that instantiates it for this struct.
  private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>();
  static {
    schemes.put(StandardScheme.class, new Word2VecModelThriftStandardSchemeFactory());
    schemes.put(TupleScheme.class, new Word2VecModelThriftTupleSchemeFactory());
  }

  // Struct data. "optional" mirrors the IDL: unset is represented by null for
  // the lists and by the __isset_bitfield bit for the primitive layerSize.
  private List vocab; // optional
  private int layerSize; // optional
  private List vectors; // optional
42 |
  /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */
  public enum _Fields implements org.apache.thrift.TFieldIdEnum {
    VOCAB((short)1, "vocab"),
    LAYER_SIZE((short)2, "layerSize"),
    VECTORS((short)3, "vectors");

    // Lookup table from field name to constant, populated once at class load.
    private static final Map byName = new HashMap();

    static {
      for (_Fields field : EnumSet.allOf(_Fields.class)) {
        byName.put(field.getFieldName(), field);
      }
    }

    /**
     * Find the _Fields constant that matches fieldId, or null if its not found.
     */
    public static _Fields findByThriftId(int fieldId) {
      switch(fieldId) {
        case 1: // VOCAB
          return VOCAB;
        case 2: // LAYER_SIZE
          return LAYER_SIZE;
        case 3: // VECTORS
          return VECTORS;
        default:
          return null;
      }
    }

    /**
     * Find the _Fields constant that matches fieldId, throwing an exception
     * if it is not found.
     */
    public static _Fields findByThriftIdOrThrow(int fieldId) {
      _Fields fields = findByThriftId(fieldId);
      if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!");
      return fields;
    }

    /**
     * Find the _Fields constant that matches name, or null if its not found.
     */
    public static _Fields findByName(String name) {
      return byName.get(name);
    }

    // Thrift field id and IDL name for this constant.
    private final short _thriftId;
    private final String _fieldName;

    _Fields(short thriftId, String fieldName) {
      _thriftId = thriftId;
      _fieldName = fieldName;
    }

    public short getThriftFieldId() {
      return _thriftId;
    }

    public String getFieldName() {
      return _fieldName;
    }
  }
106 |
  // isset id assignments
  // Bit position (within __isset_bitfield) tracking whether the primitive
  // layerSize field has been assigned; object fields use null instead.
  private static final int __LAYERSIZE_ISSET_ID = 0;
  private byte __isset_bitfield = 0;
  private static _Fields optionals[] = {_Fields.VOCAB, _Fields.LAYER_SIZE, _Fields.VECTORS};
  // Reflection-style metadata describing each field's requirement level and
  // value type; registered globally so generic Thrift tooling can inspect it.
  public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap;
  static {
    Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class);
    tmpMap.put(_Fields.VOCAB, new org.apache.thrift.meta_data.FieldMetaData("vocab", org.apache.thrift.TFieldRequirementType.OPTIONAL,
        new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST,
            new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))));
    tmpMap.put(_Fields.LAYER_SIZE, new org.apache.thrift.meta_data.FieldMetaData("layerSize", org.apache.thrift.TFieldRequirementType.OPTIONAL,
        new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)));
    tmpMap.put(_Fields.VECTORS, new org.apache.thrift.meta_data.FieldMetaData("vectors", org.apache.thrift.TFieldRequirementType.OPTIONAL,
        new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST,
            new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.DOUBLE))));
    metaDataMap = Collections.unmodifiableMap(tmpMap);
    org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(Word2VecModelThrift.class, metaDataMap);
  }
125 |
  /** Default constructor; all fields start unset. */
  public Word2VecModelThrift() {
  }
128 |
  /**
   * Performs a deep copy on other.
   * The list contents themselves (String/Double) are immutable, so copying the
   * lists is sufficient for a deep copy.
   */
  public Word2VecModelThrift(Word2VecModelThrift other) {
    __isset_bitfield = other.__isset_bitfield;
    if (other.isSetVocab()) {
      List __this__vocab = new ArrayList(other.vocab);
      this.vocab = __this__vocab;
    }
    this.layerSize = other.layerSize;
    if (other.isSetVectors()) {
      List __this__vectors = new ArrayList(other.vectors);
      this.vectors = __this__vectors;
    }
  }
144 |
  /** @return A deep copy of this struct, delegating to the copy constructor */
  public Word2VecModelThrift deepCopy() {
    return new Word2VecModelThrift(this);
  }
148 |
  /** Resets every field to its unset/default state. */
  @Override
  public void clear() {
    this.vocab = null;
    setLayerSizeIsSet(false);
    this.layerSize = 0;
    this.vectors = null;
  }
156 |
  /** @return Number of vocab entries, or 0 when the field is unset */
  public int getVocabSize() {
    return (this.vocab == null) ? 0 : this.vocab.size();
  }

  /** @return Iterator over the vocab list, or null when the field is unset */
  public java.util.Iterator getVocabIterator() {
    return (this.vocab == null) ? null : this.vocab.iterator();
  }

  /** Appends a token to the vocab list, creating the list on first use. */
  public void addToVocab(String elem) {
    if (this.vocab == null) {
      this.vocab = new ArrayList();
    }
    this.vocab.add(elem);
  }

  public List getVocab() {
    return this.vocab;
  }

  /** Sets the vocab list; returns this for call chaining. */
  public Word2VecModelThrift setVocab(List vocab) {
    this.vocab = vocab;
    return this;
  }

  /** Marks the vocab field as unset. */
  public void unsetVocab() {
    this.vocab = null;
  }

  /** Returns true if field vocab is set (has been assigned a value) and false otherwise */
  public boolean isSetVocab() {
    return this.vocab != null;
  }

  // Only a false value has an effect: it clears the field (unset is encoded as null).
  public void setVocabIsSet(boolean value) {
    if (!value) {
      this.vocab = null;
    }
  }
195 |
  public int getLayerSize() {
    return this.layerSize;
  }

  /** Sets layerSize and flags it as set; returns this for call chaining. */
  public Word2VecModelThrift setLayerSize(int layerSize) {
    this.layerSize = layerSize;
    setLayerSizeIsSet(true);
    return this;
  }

  /** Marks layerSize as unset by clearing its bit in the isset bitfield. */
  public void unsetLayerSize() {
    __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __LAYERSIZE_ISSET_ID);
  }

  /** Returns true if field layerSize is set (has been assigned a value) and false otherwise */
  public boolean isSetLayerSize() {
    return EncodingUtils.testBit(__isset_bitfield, __LAYERSIZE_ISSET_ID);
  }

  // Primitive field: set/unset state is tracked via the isset bitfield bit.
  public void setLayerSizeIsSet(boolean value) {
    __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __LAYERSIZE_ISSET_ID, value);
  }
218 |
  /** @return Number of vector entries, or 0 when the field is unset */
  public int getVectorsSize() {
    return (this.vectors == null) ? 0 : this.vectors.size();
  }

  /** @return Iterator over the vectors list, or null when the field is unset */
  public java.util.Iterator getVectorsIterator() {
    return (this.vectors == null) ? null : this.vectors.iterator();
  }

  /** Appends a value to the vectors list, creating the list on first use. */
  public void addToVectors(double elem) {
    if (this.vectors == null) {
      this.vectors = new ArrayList();
    }
    this.vectors.add(elem);
  }

  public List getVectors() {
    return this.vectors;
  }

  /** Sets the vectors list; returns this for call chaining. */
  public Word2VecModelThrift setVectors(List vectors) {
    this.vectors = vectors;
    return this;
  }

  /** Marks the vectors field as unset. */
  public void unsetVectors() {
    this.vectors = null;
  }

  /** Returns true if field vectors is set (has been assigned a value) and false otherwise */
  public boolean isSetVectors() {
    return this.vectors != null;
  }

  // Only a false value has an effect: it clears the field (unset is encoded as null).
  public void setVectorsIsSet(boolean value) {
    if (!value) {
      this.vectors = null;
    }
  }
257 |
  /**
   * Generic field setter used by Thrift tooling; a null value unsets the field.
   */
  public void setFieldValue(_Fields field, Object value) {
    switch (field) {
    case VOCAB:
      if (value == null) {
        unsetVocab();
      } else {
        setVocab((List)value);
      }
      break;

    case LAYER_SIZE:
      if (value == null) {
        unsetLayerSize();
      } else {
        setLayerSize((Integer)value);
      }
      break;

    case VECTORS:
      if (value == null) {
        unsetVectors();
      } else {
        setVectors((List)value);
      }
      break;

    }
  }
286 |
  /**
   * Generic field getter used by Thrift tooling; primitives are boxed.
   * @throws IllegalStateException for an unrecognized field constant
   */
  public Object getFieldValue(_Fields field) {
    switch (field) {
    case VOCAB:
      return getVocab();

    case LAYER_SIZE:
      return Integer.valueOf(getLayerSize());

    case VECTORS:
      return getVectors();

    }
    throw new IllegalStateException();
  }
301 |
  /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */
  public boolean isSet(_Fields field) {
    if (field == null) {
      throw new IllegalArgumentException();
    }

    switch (field) {
    case VOCAB:
      return isSetVocab();
    case LAYER_SIZE:
      return isSetLayerSize();
    case VECTORS:
      return isSetVectors();
    }
    // Unreachable for the known constants; guards against future enum additions
    throw new IllegalStateException();
  }
318 |
319 | @Override
320 | public boolean equals(Object that) {
321 | if (that == null)
322 | return false;
323 | if (that instanceof Word2VecModelThrift)
324 | return this.equals((Word2VecModelThrift)that);
325 | return false;
326 | }
327 |
  /**
   * Field-by-field equality: two instances are equal when each field is either unset in
   * both or set in both with equal values. The {@code true &&} noise is generated
   * boilerplate (it carries required-field flags for structs that have them).
   */
  public boolean equals(Word2VecModelThrift that) {
    if (that == null)
      return false;

    boolean this_present_vocab = true && this.isSetVocab();
    boolean that_present_vocab = true && that.isSetVocab();
    if (this_present_vocab || that_present_vocab) {
      // set in only one of the two -> unequal
      if (!(this_present_vocab && that_present_vocab))
        return false;
      if (!this.vocab.equals(that.vocab))
        return false;
    }

    boolean this_present_layerSize = true && this.isSetLayerSize();
    boolean that_present_layerSize = true && that.isSetLayerSize();
    if (this_present_layerSize || that_present_layerSize) {
      if (!(this_present_layerSize && that_present_layerSize))
        return false;
      if (this.layerSize != that.layerSize)
        return false;
    }

    boolean this_present_vectors = true && this.isSetVectors();
    boolean that_present_vectors = true && that.isSetVectors();
    if (this_present_vectors || that_present_vectors) {
      if (!(this_present_vectors && that_present_vectors))
        return false;
      if (!this.vectors.equals(that.vectors))
        return false;
    }

    return true;
  }
361 |
  /**
   * Hash built with commons-lang {@link HashCodeBuilder}: mixes in a presence flag per
   * field, then the field value when present — consistent with
   * {@link #equals(Word2VecModelThrift)}.
   */
  @Override
  public int hashCode() {
    HashCodeBuilder builder = new HashCodeBuilder();

    boolean present_vocab = true && (isSetVocab());
    builder.append(present_vocab);
    if (present_vocab)
      builder.append(vocab);

    boolean present_layerSize = true && (isSetLayerSize());
    builder.append(present_layerSize);
    if (present_layerSize)
      builder.append(layerSize);

    boolean present_vectors = true && (isSetVectors());
    builder.append(present_vectors);
    if (present_vectors)
      builder.append(vectors);

    return builder.toHashCode();
  }
383 |
  /**
   * Orders instances field by field (vocab, layerSize, vectors): for each field,
   * an unset value sorts before a set one, and two set values are compared via
   * {@link org.apache.thrift.TBaseHelper}.
   */
  @Override
  public int compareTo(Word2VecModelThrift other) {
    if (!getClass().equals(other.getClass())) {
      // different subclasses: fall back to a stable class-name ordering
      return getClass().getName().compareTo(other.getClass().getName());
    }

    int lastComparison = 0;

    lastComparison = Boolean.valueOf(isSetVocab()).compareTo(other.isSetVocab());
    if (lastComparison != 0) {
      return lastComparison;
    }
    if (isSetVocab()) {
      lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.vocab, other.vocab);
      if (lastComparison != 0) {
        return lastComparison;
      }
    }
    lastComparison = Boolean.valueOf(isSetLayerSize()).compareTo(other.isSetLayerSize());
    if (lastComparison != 0) {
      return lastComparison;
    }
    if (isSetLayerSize()) {
      lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.layerSize, other.layerSize);
      if (lastComparison != 0) {
        return lastComparison;
      }
    }
    lastComparison = Boolean.valueOf(isSetVectors()).compareTo(other.isSetVectors());
    if (lastComparison != 0) {
      return lastComparison;
    }
    if (isSetVectors()) {
      lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.vectors, other.vectors);
      if (lastComparison != 0) {
        return lastComparison;
      }
    }
    return 0;
  }
424 |
  /** Resolves a Thrift wire field id (1=vocab, 2=layerSize, 3=vectors) to its {@code _Fields} constant. */
  public _Fields fieldForId(int fieldId) {
    return _Fields.findByThriftId(fieldId);
  }

  /** Deserializes this struct from the protocol, dispatching to the scheme registered for the protocol type. */
  public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException {
    schemes.get(iprot.getScheme()).getScheme().read(iprot, this);
  }

  /** Serializes this struct to the protocol, dispatching to the scheme registered for the protocol type. */
  public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException {
    schemes.get(oprot.getScheme()).getScheme().write(oprot, this);
  }
436 |
  /** Human-readable dump of the set fields only, e.g. {@code Word2VecModelThrift(vocab:[...], layerSize:N, ...)}. */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder("Word2VecModelThrift(");
    boolean first = true;

    if (isSetVocab()) {
      // first possible field: no ", " separator needed before it
      sb.append("vocab:");
      if (this.vocab == null) {
        sb.append("null");
      } else {
        sb.append(this.vocab);
      }
      first = false;
    }
    if (isSetLayerSize()) {
      if (!first) sb.append(", ");
      sb.append("layerSize:");
      sb.append(this.layerSize);
      first = false;
    }
    if (isSetVectors()) {
      if (!first) sb.append(", ");
      sb.append("vectors:");
      if (this.vectors == null) {
        sb.append("null");
      } else {
        sb.append(this.vectors);
      }
      first = false;
    }
    sb.append(")");
    return sb.toString();
  }
470 |
  /** Thrift validation hook; this struct has no required fields, so there is nothing to check. */
  public void validate() throws org.apache.thrift.TException {
    // check for required fields
    // check for sub-struct validity
  }

  /** Java serialization support: writes the struct via Thrift's compact protocol instead of default field serialization. */
  private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException {
    try {
      write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out)));
    } catch (org.apache.thrift.TException te) {
      throw new java.io.IOException(te);
    }
  }

  /** Java deserialization support: reads the struct back via Thrift's compact protocol. */
  private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException {
    try {
      // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call the default constructor.
      __isset_bitfield = 0;
      read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in)));
    } catch (org.apache.thrift.TException te) {
      throw new java.io.IOException(te);
    }
  }
493 |
  /** Factory registered for standard (field-tagged) Thrift protocols. */
  private static class Word2VecModelThriftStandardSchemeFactory implements SchemeFactory {
    public Word2VecModelThriftStandardScheme getScheme() {
      return new Word2VecModelThriftStandardScheme();
    }
  }
499 |
  /**
   * Standard (field-tagged) wire format for {@code Word2VecModelThrift}: each field is
   * written with its id and type, so fields may arrive in any order and unknown fields
   * can be skipped for forward compatibility.
   */
  private static class Word2VecModelThriftStandardScheme extends StandardScheme {

    /** Reads fields until a STOP marker; type-mismatched or unknown fields are skipped. */
    public void read(org.apache.thrift.protocol.TProtocol iprot, Word2VecModelThrift struct) throws org.apache.thrift.TException {
      org.apache.thrift.protocol.TField schemeField;
      iprot.readStructBegin();
      while (true)
      {
        schemeField = iprot.readFieldBegin();
        if (schemeField.type == org.apache.thrift.protocol.TType.STOP) {
          break;
        }
        switch (schemeField.id) {
          case 1: // VOCAB
            if (schemeField.type == org.apache.thrift.protocol.TType.LIST) {
              {
                org.apache.thrift.protocol.TList _list0 = iprot.readListBegin();
                // NOTE(review): raw ArrayList — generic parameters appear stripped in this dump
                struct.vocab = new ArrayList(_list0.size);
                for (int _i1 = 0; _i1 < _list0.size; ++_i1)
                {
                  String _elem2;
                  _elem2 = iprot.readString();
                  struct.vocab.add(_elem2);
                }
                iprot.readListEnd();
              }
              struct.setVocabIsSet(true);
            } else {
              // wrong wire type: skip the value to stay in sync with the stream
              org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type);
            }
            break;
          case 2: // LAYER_SIZE
            if (schemeField.type == org.apache.thrift.protocol.TType.I32) {
              struct.layerSize = iprot.readI32();
              struct.setLayerSizeIsSet(true);
            } else {
              org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type);
            }
            break;
          case 3: // VECTORS
            if (schemeField.type == org.apache.thrift.protocol.TType.LIST) {
              {
                org.apache.thrift.protocol.TList _list3 = iprot.readListBegin();
                struct.vectors = new ArrayList(_list3.size);
                for (int _i4 = 0; _i4 < _list3.size; ++_i4)
                {
                  double _elem5;
                  _elem5 = iprot.readDouble();
                  struct.vectors.add(_elem5);
                }
                iprot.readListEnd();
              }
              struct.setVectorsIsSet(true);
            } else {
              org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type);
            }
            break;
          default:
            // unknown field id: skip for forward compatibility
            org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type);
        }
        iprot.readFieldEnd();
      }
      iprot.readStructEnd();

      // check for required fields of primitive type, which can't be checked in the validate method
      struct.validate();
    }

    /** Writes only the set fields, each tagged with its id and type, ending with a field-stop marker. */
    public void write(org.apache.thrift.protocol.TProtocol oprot, Word2VecModelThrift struct) throws org.apache.thrift.TException {
      struct.validate();

      oprot.writeStructBegin(STRUCT_DESC);
      // the outer null check is generated belt-and-braces; isSet implies non-null here
      if (struct.vocab != null) {
        if (struct.isSetVocab()) {
          oprot.writeFieldBegin(VOCAB_FIELD_DESC);
          {
            oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, struct.vocab.size()));
            for (String _iter6 : struct.vocab)
            {
              oprot.writeString(_iter6);
            }
            oprot.writeListEnd();
          }
          oprot.writeFieldEnd();
        }
      }
      if (struct.isSetLayerSize()) {
        oprot.writeFieldBegin(LAYER_SIZE_FIELD_DESC);
        oprot.writeI32(struct.layerSize);
        oprot.writeFieldEnd();
      }
      if (struct.vectors != null) {
        if (struct.isSetVectors()) {
          oprot.writeFieldBegin(VECTORS_FIELD_DESC);
          {
            oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.DOUBLE, struct.vectors.size()));
            for (double _iter7 : struct.vectors)
            {
              oprot.writeDouble(_iter7);
            }
            oprot.writeListEnd();
          }
          oprot.writeFieldEnd();
        }
      }
      oprot.writeFieldStop();
      oprot.writeStructEnd();
    }

  }
609 |
  /** Factory registered for tuple (compact, position-based) Thrift protocols. */
  private static class Word2VecModelThriftTupleSchemeFactory implements SchemeFactory {
    public Word2VecModelThriftTupleScheme getScheme() {
      return new Word2VecModelThriftTupleScheme();
    }
  }
615 |
  /**
   * Compact tuple wire format: a presence bitset (bit 0=vocab, 1=layerSize, 2=vectors)
   * followed by the set fields' values in declaration order. Reader and writer must
   * agree on the field order, unlike the standard scheme.
   */
  private static class Word2VecModelThriftTupleScheme extends TupleScheme {

    @Override
    public void write(org.apache.thrift.protocol.TProtocol prot, Word2VecModelThrift struct) throws org.apache.thrift.TException {
      TTupleProtocol oprot = (TTupleProtocol) prot;
      BitSet optionals = new BitSet();
      if (struct.isSetVocab()) {
        optionals.set(0);
      }
      if (struct.isSetLayerSize()) {
        optionals.set(1);
      }
      if (struct.isSetVectors()) {
        optionals.set(2);
      }
      oprot.writeBitSet(optionals, 3);
      if (struct.isSetVocab()) {
        {
          // lists are written as a size followed by the raw elements (no per-list type header)
          oprot.writeI32(struct.vocab.size());
          for (String _iter8 : struct.vocab)
          {
            oprot.writeString(_iter8);
          }
        }
      }
      if (struct.isSetLayerSize()) {
        oprot.writeI32(struct.layerSize);
      }
      if (struct.isSetVectors()) {
        {
          oprot.writeI32(struct.vectors.size());
          for (double _iter9 : struct.vectors)
          {
            oprot.writeDouble(_iter9);
          }
        }
      }
    }

    @Override
    public void read(org.apache.thrift.protocol.TProtocol prot, Word2VecModelThrift struct) throws org.apache.thrift.TException {
      TTupleProtocol iprot = (TTupleProtocol) prot;
      BitSet incoming = iprot.readBitSet(3);
      if (incoming.get(0)) {
        {
          org.apache.thrift.protocol.TList _list10 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32());
          // NOTE(review): raw ArrayList — generic parameters appear stripped in this dump
          struct.vocab = new ArrayList(_list10.size);
          for (int _i11 = 0; _i11 < _list10.size; ++_i11)
          {
            String _elem12;
            _elem12 = iprot.readString();
            struct.vocab.add(_elem12);
          }
        }
        struct.setVocabIsSet(true);
      }
      if (incoming.get(1)) {
        struct.layerSize = iprot.readI32();
        struct.setLayerSizeIsSet(true);
      }
      if (incoming.get(2)) {
        {
          org.apache.thrift.protocol.TList _list13 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.DOUBLE, iprot.readI32());
          struct.vectors = new ArrayList(_list13.size);
          for (int _i14 = 0; _i14 < _list13.size; ++_i14)
          {
            double _elem15;
            _elem15 = iprot.readDouble();
            struct.vectors.add(_elem15);
          }
        }
        struct.setVectorsIsSet(true);
      }
    }
  }
691 |
692 | }
693 |
694 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/AC.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
/** Extension of {@link AutoCloseable} where {@link #close()} does not throw any exception */
public interface AC extends AutoCloseable {
	/** Narrows {@link AutoCloseable#close()} to a no-checked-exception contract. */
	@Override void close();

	/** {@link AC} that does nothing; handy as a null-object default */
	AC NOTHING = new AC() {
		@Override public void close() { }
	};
}
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/AutoLog.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
3 | import com.google.common.base.Preconditions;
4 | import org.apache.commons.logging.Log;
5 | import org.apache.commons.logging.LogFactory;
6 | import org.apache.log4j.Logger;
7 | import org.apache.log4j.varia.NullAppender;
8 |
/**
 * Creates loggers based on the caller's class.
 */
public final class AutoLog {
	/** Prevents initialization. */
	private AutoLog() {
	}

	/** @return {@link org.apache.commons.logging.Log} based on the caller's class */
	public static Log getLog() {
		// distance 2: skip this frame so the logger is named after our caller's class
		return getLog(2);
	}

	/** Make sure there is at least one appender to avoid a warning printed on stderr */
	private static class InitializeOnDemand {
		// initialization-on-demand holder: init() runs exactly once, on first access of INIT
		private static final boolean INIT = init();
		private static boolean init() {
			if (!Logger.getRootLogger().getAllAppenders().hasMoreElements())
				Logger.getRootLogger().addAppender(new NullAppender());
			return true;
		}
	}

	/** @return {@link org.apache.commons.logging.Log} based on the stacktrace distance to
	 * the original caller. 1= the caller to this method. 2 = the caller to the caller... etc*/
	public static Log getLog(int distance) {
		Preconditions.checkState(InitializeOnDemand.INIT);
		String callerClassName = Common.myCaller(distance).getClassName();
		try {
			return LogFactory.getLog(Class.forName(callerClassName));
		} catch (ClassNotFoundException t) {
			// should be unreachable: the name came from a live stack frame
			String err = "Class.forName on " + callerClassName + " failed";
			System.err.println(err);
			throw new IllegalStateException(err, t);
		}
	}
}
46 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/CallableVoid.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
3 | import java.util.concurrent.Callable;
4 |
/**
 * Utility base implementation of {@link Callable} with a {@link Void} return type.
 * Subclasses implement {@link #run()}; {@link #call()} always returns null.
 * (Generic parameter restored: {@code call()} returns {@link Void}, so the
 * implemented interface is {@code Callable<Void>}, not the raw type.)
 */
public abstract class CallableVoid implements Callable<Void> {

	@Override public final Void call() throws Exception {
		run();
		return null;
	}

	/** Do the actual work here instead of using {@link #call()} */
	protected abstract void run() throws Exception;

}
17 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/Common.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
3 | import org.apache.commons.io.FilenameUtils;
4 | import org.apache.commons.io.IOUtils;
5 |
6 | import java.io.BufferedReader;
7 | import java.io.ByteArrayOutputStream;
8 | import java.io.File;
9 | import java.io.FileInputStream;
10 | import java.io.FileNotFoundException;
11 | import java.io.IOException;
12 | import java.io.InputStream;
13 | import java.io.ObjectOutputStream;
14 | import java.io.Reader;
15 | import java.io.Serializable;
16 | import java.io.StringWriter;
17 | import java.net.URL;
18 | import java.util.ArrayList;
19 | import java.util.Collections;
20 | import java.util.List;
21 | import java.util.zip.GZIPInputStream;
22 |
23 | /**
24 | * Simple utilities that in no way deserve their own class.
25 | */
26 | public class Common {
27 | /**
28 | * @param distance use 1 for our caller, 2 for their caller, etc...
29 | * @return the stack trace element from where the calling method was invoked
30 | */
31 | public static StackTraceElement myCaller(int distance) {
32 | // 0 here, 1 our caller, 2 their caller
33 | int index = distance + 1;
34 | try {
35 | StackTraceElement st[] = new Throwable().getStackTrace();
36 | // hack: skip synthetic caster methods
37 | if (st[index].getLineNumber() == 1) return st[index + 1];
38 | return st[index];
39 | } catch (Throwable t) {
40 | return new StackTraceElement("[unknown]","-","-",0);
41 | }
42 | }
43 |
44 | /** Serialize the given object into the given stream */
45 | public static void serialize(Serializable obj, ByteArrayOutputStream bout) {
46 | try {
47 | ObjectOutputStream out = new ObjectOutputStream(bout);
48 | out.writeObject(obj);
49 | out.close();
50 | } catch (IOException e) {
51 | throw new IllegalStateException("Could not serialize " + obj, e);
52 | }
53 | }
54 |
55 | /**
56 | * Read the file line for line and return the result in a list
57 | * @throws IOException upon failure in reading, note that we wrap the underlying IOException with the file name
58 | */
59 | public static List readToList(File f) throws IOException {
60 | try (final Reader reader = asReaderUTF8Lenient(new FileInputStream(f))) {
61 | return readToList(reader);
62 | } catch (IOException ioe) {
63 | throw new IllegalStateException(String.format("Failed to read %s: %s", f.getAbsolutePath(), ioe), ioe);
64 | }
65 | }
66 | /** Read the Reader line for line and return the result in a list */
67 | public static List readToList(Reader r) throws IOException {
68 | try ( BufferedReader in = new BufferedReader(r) ) {
69 | List l = new ArrayList<>();
70 | String line = null;
71 | while ((line = in.readLine()) != null)
72 | l.add(line);
73 | return Collections.unmodifiableList(l);
74 | }
75 | }
76 |
77 | /** Wrap the InputStream in a Reader that reads UTF-8. Invalid content will be replaced by unicode replacement glyph. */
78 | public static Reader asReaderUTF8Lenient(InputStream in) {
79 | return new UnicodeReader(in, "utf-8");
80 | }
81 |
82 | /** Read the contents of the given file into a string */
83 | public static String readFileToString(File f) throws IOException {
84 | StringWriter sw = new StringWriter();
85 | IO.copyAndCloseBoth(Common.asReaderUTF8Lenient(new FileInputStream(f)), sw);
86 | return sw.toString();
87 | }
88 |
89 | /** @return true if i is an even number */
90 | public static boolean isEven(int i) { return (i&1)==0; }
91 | /** @return true if i is an odd number */
92 | public static boolean isOdd(int i) { return !isEven(i); }
93 |
94 | /** Read the lines (as UTF8) of the resource file fn from the package of the given class into a (unmodifiable) list of strings
95 | * @throws IOException */
96 | public static List readResource(Class> clazz, String fn) throws IOException {
97 | try (final Reader reader = asReaderUTF8Lenient(getResourceAsStream(clazz, fn))) {
98 | return readToList(reader);
99 | }
100 | }
101 |
102 | /** Get an input stream to read the raw contents of the given resource, remember to close it :) */
103 | public static InputStream getResourceAsStream(Class> clazz, String fn) throws IOException {
104 | InputStream stream = clazz.getResourceAsStream(fn);
105 | if (stream == null) {
106 | throw new IOException("resource \"" + fn + "\" relative to " + clazz + " not found.");
107 | }
108 | return unpackStream(stream, fn);
109 | }
110 |
111 | /** Get a file to read the raw contents of the given resource :) */
112 | public static File getResourceAsFile(Class> clazz, String fn) throws IOException {
113 | URL url = clazz.getResource(fn);
114 | if (url == null || url.getFile() == null) {
115 | throw new IOException("resource \"" + fn + "\" relative to " + clazz + " not found.");
116 | }
117 | return new File(url.getFile());
118 | }
119 |
120 | /**
121 | * @throws IOException if {@code is} is null or if an {@link IOException} is thrown when reading from {@code is}
122 | */
123 | public static InputStream unpackStream(InputStream is, String fn) throws IOException {
124 | if (is == null)
125 | throw new FileNotFoundException("InputStream is null for " + fn);
126 |
127 | switch (FilenameUtils.getExtension(fn).toLowerCase()) {
128 | case "gz":
129 | return new GZIPInputStream(is);
130 | default:
131 | return is;
132 | }
133 | }
134 |
135 | /** Read the lines (as UTF8) of the resource file fn from the package of the given class into a string */
136 | public static String readResourceToStringChecked(Class> clazz, String fn) throws IOException {
137 | try (InputStream stream = getResourceAsStream(clazz, fn)) {
138 | return IOUtils.toString(asReaderUTF8Lenient(stream));
139 | }
140 | }
141 | }
142 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/Compare.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
/** Utility class for general comparison and equality operations. */
public class Compare {
	/**
	 * {@link NullPointerException}-safe compare method; nulls are less than non-nulls.
	 * (Generic bound restored: it was stripped in this dump.)
	 * @return negative, zero, or positive per {@link Comparable#compareTo}; null sorts first
	 */
	public static <X extends Comparable<X>> int compare(X x1, X x2) {
		if (x1 == null) return x2 == null ? 0 : -1;
		return x2 == null ? 1 : x1.compareTo(x2);
	}
}
13 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/FileUtils.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
3 | import java.io.File;
4 | import java.io.FileFilter;
5 | import java.io.IOException;
6 | import java.nio.file.Paths;
7 | import java.util.Arrays;
8 | import java.util.List;
9 | import java.util.UUID;
10 |
11 | import com.google.common.base.Function;
12 |
13 | import com.google.common.base.Strings;
14 |
15 | /**
16 | * Collection of file-related utilities.
17 | */
18 | public final class FileUtils {
19 |
20 | public static final int ONE_KB = 1<<10;
21 | public static final int ONE_MB = 1<<20;
22 |
23 | public static final Function FILE_TO_NAME = new Function() {
24 | @Override public String apply(File file) {
25 | return file.getName();
26 | }
27 | };
28 |
29 | /**
30 | * Returns a subdirectory of a given directory; the subdirectory is expected to already exist.
31 | * @param parent the directory in which to find the specified subdirectory
32 | * @param item the name of the subdirectory
33 | * @return the subdirectory having the specified name; null if no such directory exists or
34 | * exists but is a regular file.
35 | */
36 | public static File getDir(File parent, String item) {
37 | File dir = new File(parent, item);
38 | return (dir.exists() && dir.isDirectory()) ? dir : null;
39 | }
40 |
41 | /** @return File for the specified directory; creates the directory if necessary. */
42 | private static File getOrCreateDir(File f) throws IOException {
43 | if (!f.exists()) {
44 | if (!f.mkdirs()) {
45 | throw new IOException(f.getName() + ": Unable to create directory: " + f.getAbsolutePath());
46 | }
47 | } else if (!f.isDirectory()) {
48 | throw new IOException(f.getName() + ": Exists and is not a directory: " + f.getAbsolutePath());
49 | }
50 | return f;
51 | }
52 |
53 | /**
54 | * @return File for the specified directory, creates dirName directory if necessary.
55 | * @throws IOException if there is an error during directory creation or if a non-directory file with the desired name already
56 | * exists.
57 | */
58 | public static File getOrCreateDir(String dirName) throws IOException {
59 | return getOrCreateDir(new File(dirName));
60 | }
61 |
62 | /**
63 | * @return File for the specified directory; parent must already exist, creates dirName subdirectory if necessary.
64 | * @throws IOException if there is an error during directory creation or if a non-directory file with the desired name already
65 | * exists.
66 | */
67 | public static File getOrCreateDir(File parent, String dirName) throws IOException {
68 | return getOrCreateDir(new File(parent, dirName));
69 | }
70 |
71 | /**
72 | * @return File for the specified directory; parent must already exist, creates all intermediate dirNames subdirectories if necessary.
73 | * @throws IOException if there is an error during directory creation or if a non-directory file with the desired name already
74 | * exists.
75 | */
76 | public static File getOrCreateDir(File parent, String ... dirNames) throws IOException {
77 | return getOrCreateDir(Paths.get(parent.getPath(), dirNames).toFile());
78 | }
79 |
80 | /**
81 | * Deletes a file or directory.
82 | * If the file is a directory it recursively deletes it.
83 | * @param file file to be deleted
84 | * @return true if all the files where deleted successfully.
85 | */
86 | public static boolean deleteRecursive(final File file) {
87 | boolean result = true;
88 | if (file.isDirectory()) {
89 | for (final File inner : file.listFiles()) {
90 | result &= deleteRecursive(inner);
91 | }
92 | }
93 | return result & file.delete();
94 | }
95 |
96 | /** Utility class; don't instantiate. */
97 | private FileUtils() {
98 | throw new AssertionError("Do not instantiate.");
99 | }
100 |
101 | /** @return A random temporary folder that can be used for file-system operations testing */
102 | public static File getRandomTemporaryFolder(String prefix, String suffix) {
103 | return new File(System.getProperty("java.io.tmpdir"), Strings.nullToEmpty(prefix) + UUID.randomUUID().toString() + Strings.nullToEmpty(suffix));
104 | }
105 | }
106 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/Format.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
3 | /**
4 | * Created by yibin on 2/2/15.
5 | */
6 | public class Format {
7 | /** @see {@link Strings#formatEnum(Enum)} */
8 | public static String formatEnum(Enum> enumValue) {
9 | return Strings.formatEnum(enumValue);
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/IO.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
3 | import com.google.common.base.Function;
4 | import com.google.common.base.Preconditions;
5 | import org.apache.commons.io.IOUtils;
6 |
7 | import java.io.ByteArrayInputStream;
8 | import java.io.ByteArrayOutputStream;
9 | import java.io.Closeable;
10 | import java.io.File;
11 | import java.io.FileOutputStream;
12 | import java.io.IOException;
13 | import java.io.InputStream;
14 | import java.io.ObjectInputStream;
15 | import java.io.ObjectOutputStream;
16 | import java.io.ObjectStreamClass;
17 | import java.io.OutputStream;
18 | import java.io.Reader;
19 | import java.io.Writer;
20 | import java.lang.reflect.Field;
21 | import java.util.Collection;
22 | import java.util.Collections;
23 | import java.util.Comparator;
24 | import java.util.zip.GZIPInputStream;
25 | import java.util.zip.GZIPOutputStream;
26 |
27 | /**
28 | * Static utility functions related to IO.
29 | */
30 | public final class IO {
31 |
32 | /** a Comparator that orders {@link File} objects as by the last modified date */
33 | public static final Comparator FILE_LAST_MODIFIED_COMPARATOR = new Comparator() {
34 | @Override public int compare(File o1, File o2) {
35 | return Long.compare(o1.lastModified(), o2.lastModified());
36 | }};
37 |
	// 8K. Same as BufferedOutputStream, so writes are written-through with no extra copying
	private static final int DEFAULT_BUFFER_SIZE = 1024 * 8;

	/** Utility class; not instantiable. */
	private IO() { }
41 |
42 | /** @return the oldest file in the given collection using the last modified date or return null if the collection is empty */
43 | public static File getOldestFileOrNull(Collection files) {
44 | if (files.isEmpty())
45 | return null;
46 |
47 | return Collections.min(files, FILE_LAST_MODIFIED_COMPARATOR);
48 | }
49 |
	/** Copy input to output and close the output stream before returning */
	public static long copyAndCloseOutput(InputStream input, OutputStream output) throws IOException {
		// try-with-resources guarantees output.close() even when copy() throws
		try (OutputStream outputStream = output) {
			return copy(input, outputStream);
		}
	}

	/** Copy input to output and close both the input and output streams before returning */
	public static long copyAndCloseBoth(InputStream input, OutputStream output) throws IOException {
		try (InputStream inputStream = input) {
			return copyAndCloseOutput(inputStream, output);
		}
	}

	/** Similar to {@link IOUtils#toByteArray(InputStream)} but closes the stream. */
	public static byte[] toByteArray(InputStream is) throws IOException {
		ByteArrayOutputStream bao = new ByteArrayOutputStream();
		IO.copyAndCloseBoth(is, bao);
		return bao.toByteArray();
	}
70 |
71 | /** Copy input to output; neither stream is closed */
72 | public static long copy(InputStream input, OutputStream output) throws IOException {
73 | byte[] buffer = new byte[DEFAULT_BUFFER_SIZE];
74 | long count = 0;
75 | int n = 0;
76 | while (-1 != (n = input.read(buffer))) {
77 | output.write(buffer, 0, n);
78 | count += n;
79 | }
80 | return count;
81 | }
82 |
	/** Copy input to output; neither stream is closed */
	public static int copy(Reader input, Writer output) throws IOException {
		char[] buffer = new char[DEFAULT_BUFFER_SIZE];
		// NOTE(review): count is int here but long in the stream overload; overflows past ~2G chars
		int count = 0;
		int n = 0;
		while (-1 != (n = input.read(buffer))) {
			output.write(buffer, 0, n);
			count += n;
		}
		return count;
	}

	/** Copy input to output and close the output stream before returning */
	public static int copyAndCloseOutput(Reader input, Writer output) throws IOException {
		try {
			return copy(input, output);
		} finally {
			// close in finally so a copy failure still releases the writer
			output.close();
		}

	}
	/** Copy input to output and close both the input and output streams before returning */
	public static int copyAndCloseBoth(Reader input, Writer output) throws IOException {
		try {
			return copyAndCloseOutput(input, output);
		} finally {
			input.close();
		}
	}
112 |
113 | /**
114 | * Copy the data from the given {@link InputStream} to a temporary file and call the given
115 | * {@link Function} with it; after the function returns the file is deleted.
116 | */
117 | public static X runWithFile(InputStream stream, Function function) throws IOException {
118 | File f = File.createTempFile("run-with-file", null);
119 | try {
120 | try (FileOutputStream out = new FileOutputStream(f)) {
121 | IOUtils.copy(stream, out);
122 | }
123 | return function.apply(f);
124 | } finally {
125 | f.delete();
126 | }
127 | }
128 |
129 | /** @return the compressed (gzip) version of the given bytes */
130 | public static byte[] gzip(byte[] in) {
131 | try {
132 | ByteArrayOutputStream bos = new ByteArrayOutputStream();
133 | GZIPOutputStream gz = new GZIPOutputStream(bos);
134 | gz.write(in);
135 | gz.close();
136 | return bos.toByteArray();
137 | } catch (IOException e) {
138 | throw new RuntimeException("Failed to compress bytes", e);
139 | }
140 | }
141 |
142 | /** @return the decompressed version of the given (compressed with gzip) bytes */
143 | public static byte[] gunzip(byte[] in) {
144 | try {
145 | GZIPInputStream gis = new GZIPInputStream(new ByteArrayInputStream(in));
146 | ByteArrayOutputStream bos = new ByteArrayOutputStream();
147 | IO.copyAndCloseOutput(gis, bos);
148 | return bos.toByteArray();
149 | } catch (IOException e) {
150 | throw new RuntimeException("Failed to decompress data", e);
151 | }
152 | }
153 |
/**
 * Serializes and gzip-compresses the given object.
 *
 * @param object the object to compress; must be {@link java.io.Serializable}
 * @return the compressed (gzip) version of the given object
 * @throws RuntimeException if serialization or compression fails
 */
public static byte[] gzipObject(Object object) {
	try (
		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		GZIPOutputStream zos = new GZIPOutputStream(bos);
		ObjectOutputStream oos = new ObjectOutputStream(zos))
	{
		oos.writeObject(object);
		zos.close(); // Terminate gzip: write the trailer so bos is complete before toByteArray()
		// NOTE(review): this assumes writeObject() leaves no data buffered in oos
		// at the moment zos is closed; otherwise the implicit oos.close() on
		// try-with-resources exit would write to an already-finished gzip stream
		// — confirm against the serialization implementation in use
		return bos.toByteArray();
	} catch (IOException e) {
		throw new RuntimeException("Failed to compress bytes", e);
	}
}
168 |
169 | /** @return the decompressed version of the given (compressed with gzip) bytes */
170 | public static Object gunzipObject(byte[] in) {
171 | try {
172 | return new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(in))).readObject();
173 | } catch (IOException | ClassNotFoundException e) {
174 | throw new RuntimeException("Failed to decompress data", e);
175 | }
176 | }
177 |
178 | /** ObjectInputStream which doesn't care too much about serialVersionUIDs. Horrible :)
179 | */
180 | public static ObjectInputStream gullibleObjectInputStream(InputStream is) throws IOException {
181 | return new ObjectInputStream(is) {
182 | @Override
183 | protected ObjectStreamClass readClassDescriptor() throws IOException, ClassNotFoundException {
184 | ObjectStreamClass oc = super.readClassDescriptor();
185 | try {
186 | Class> c = Class.forName(oc.getName());
187 | // interfaces do not have fields
188 | if (!c.isInterface()) {
189 | Field f = oc.getClass().getDeclaredField("suid");
190 | f.setAccessible(true);
191 | f.set(oc, ObjectStreamClass.lookup(c).getSerialVersionUID());
192 | }
193 | } catch (Exception e) {
194 | System.err.println("Couldn't fake class descriptor for "+oc+": "+ e);
195 | }
196 | return oc;
197 | }
198 | };
199 | }
200 |
201 | /** Close, ignoring exceptions */
202 | public static void close(Closeable stream) {
203 | try {
204 | stream.close();
205 | } catch (IOException e) {
206 | }
207 | }
208 |
209 | /**
210 | * Creates a directory if it does not exist.
211 | *
212 | *
213 | * It will recursively create all the absolute path specified in the input
214 | * parameter.
215 | *
216 | * @param directory
217 | * the {@link File} containing the directory structure that is
218 | * going to be created.
219 | * @return a {@link File} pointing to the created directory
220 | * @throws IOException if there was en error while creating the directory.
221 | */
222 | public static File createDirIfNotExists(File directory) throws IOException {
223 | if (!directory.isDirectory()) {
224 | if (!directory.mkdirs()) {
225 | throw new IOException("Failed to create directory: " + directory.getAbsolutePath());
226 | }
227 | }
228 | return directory;
229 | }
230 |
231 | /**
232 | * @return true if the {@link File} is null, does not exist, or we were able to delete it; false
233 | * if the file exists and could not be deleted.
234 | */
235 | public static boolean deleteIfPresent(File f) {
236 | return f == null || !f.exists() || f.delete();
237 | }
238 |
239 |
240 | /**
241 | * Stores the given contents into a temporary file
242 | * @param fileContents the raw contents to store in the temporary file
243 | * @param namePrefix the desired file name prefix (must be at least 3 characters long)
244 | * @param extension the desired extension including the '.' character (use null for '.tmp')
245 | * @return a {@link File} reference to the newly created temporary file
246 | * @throws IOException if the temporary file creation fails
247 | */
248 | public static File createTempFile(byte[] fileContents, String namePrefix, String extension) throws IOException {
249 | Preconditions.checkNotNull(fileContents, "file contents missing");
250 | File tempFile = File.createTempFile(namePrefix, extension);
251 | try (FileOutputStream fos = new FileOutputStream(tempFile)) {
252 | fos.write(fileContents);
253 | }
254 | return tempFile;
255 | }
256 | }
257 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/NDC.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
3 |
4 | /** Helper to create {@link org.apache.log4j.NDC} for nested diagnostic contexts */
5 | public class NDC implements AC {
6 | private final int size;
7 |
8 | /** Push all the contexts given and pop them when auto-closed */
9 | public static NDC push(String... context) {
10 | return new NDC(context);
11 | }
12 |
13 | /** Construct an {@link AutoCloseable} {@link NDC} with the given contexts */
14 | private NDC(String... context) {
15 | for (String c : context) {
16 | org.apache.log4j.NDC.push("[" + c + "]");
17 | }
18 | this.size = context.length;
19 | }
20 |
21 | @Override
22 | public void close() {
23 | for (int i = 0; i < size; i++) {
24 | org.apache.log4j.NDC.pop();
25 | }
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/main/java/com/medallia/word2vec/util/Pair.java:
--------------------------------------------------------------------------------
1 | package com.medallia.word2vec.util;
2 |
3 | import com.google.common.base.Function;
4 | import com.google.common.base.Predicate;
5 | import com.google.common.base.Predicates;
6 | import com.google.common.collect.FluentIterable;
7 | import com.google.common.collect.ImmutableList;
8 | import com.google.common.collect.ImmutableSet;
9 | import com.google.common.collect.Iterables;
10 | import com.google.common.collect.Lists;
11 | import com.google.common.collect.Ordering;
12 | import com.google.common.collect.Sets;
13 |
14 | import java.io.Serializable;
15 | import java.util.ArrayList;
16 | import java.util.Collection;
17 | import java.util.Comparator;
18 | import java.util.HashSet;
19 | import java.util.Iterator;
20 | import java.util.List;
21 | import java.util.Map;
22 | import java.util.Objects;
23 | import java.util.Set;
24 |
25 | /**
26 | * Simple class for storing two arbitrary objects in one.
27 | *
28 |  * @param <K> the type of the first value
29 |  * @param <V> the type of the second value
30 | */
31 | public class Pair implements Map.Entry, Serializable {
32 |
33 | /** @see Serializable */
34 | private static final long serialVersionUID = 1L;
35 |
36 | /** The first item in the pair. */
37 | public final K first;
38 | /** The second item in the pair. */
39 | public final V second;
40 |
/**
 * Creates a new instance of Pair.
 *
 * @param first the first item of the pair (may be null)
 * @param second the second item of the pair (may be null)
 */
protected Pair(K first, V second) {
	this.first = first;
	this.second = second;
}
46 |
47 | /** Type-inferring constructor */
48 | public static Pair cons(X x, Y y) { return new Pair(x,y); }
49 |
50 | /** Type-inferring constructor for pairs of the same type, which can optionally be swapped */
51 | public static Pair cons(X x, X y, boolean swapped) { return swapped ? new Pair(y, x) : new Pair(x, y); }
52 |
@Override
public int hashCode() {
	// Compute by hand instead of using Encoding.combineHashes for improved performance.
	// Distinct multipliers (13, 17) mean (a, b) and (b, a) usually hash differently;
	// null slots contribute 0, consistent with Objects.equals in equals().
	return (first == null ? 0 : first.hashCode() * 13) + (second == null ? 0 : second.hashCode() * 17);
}
58 |
59 | @Override
60 | public boolean equals(Object o) {
61 | if (o == this)
62 | return true;
63 | if (o == null || !getClass().equals(o.getClass()))
64 | return false;
65 |
66 | Pair,?> op = (Pair,?>) o;
67 | return Objects.equals(op.first, first) && Objects.equals(op.second, second);
68 | }
69 |
/** @return {@link #first}; needed because String Templates have 'first' as a reserved word. */
public K getOne() { return first; }
/** @return {@link #first}, the first item of the pair */
public K getFirst() { return first; }
/** @return {@link #second}, the second item of the pair */
public V getSecond() { return second; }
76 |
77 | @Override public String toString() {
78 | return "Pair<"+first+","+second+">";
79 | }
80 |
81 | /** @return a list with the two elements from this pair, regardless of whether they are null or not */
82 | public List