├── project
│   ├── build.properties
│   └── plugins.sbt
├── README.md
├── src
│   ├── test
│   │   └── scala
│   │       └── org
│   │           └── apache
│   │               └── spark
│   │                   └── ml
│   │                       └── embedding
│   │                           └── Word2VecSuite.scala
│   └── main
│       └── scala
│           └── org
│               └── apache
│                   └── spark
│                       ├── ml
│                       │   └── embedding
│                       │       └── Word2Vec.scala
│                       └── mllib
│                           └── embedding
│                               └── Word2Vec.scala
└── LICENSE

/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 0.13.5
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | logLevel := Level.Warn
2 | 
3 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.3")
4 | 
5 | resolvers += Classpaths.sbtPluginReleases
6 | 
7 | resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven"
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spark-Word2Vec
2 | Spark-Word2Vec creates vector representations of words in a text corpus. It is based on the word2vec implementation in Spark MLlib, with several optimizations that make the algorithm more scalable and accurate.
3 | 
4 | # Highlights
5 | + Both the CBOW and Skip-gram models are supported.
6 | + The model can be trained with either hierarchical softmax or negative sampling.
7 | + The sub-sampling trick can be used to achieve both faster training and significantly better representations of uncommon words.
8 | 
9 | # Examples
10 | ## Scala API
11 | ```scala
12 | import org.apache.spark.ml.embedding.Word2Vec
13 | import org.apache.spark.ml.linalg.Vector
14 | import org.apache.spark.sql.{Row, SparkSession}
15 | 
16 | val spark = SparkSession
17 |   .builder
18 |   .appName("Word2Vec example")
19 |   .master("local[*]")
20 |   .getOrCreate()
21 | 
22 | // Input data: each row is a bag of words from a sentence or document.
23 | val documentDF = spark.createDataFrame(Seq(
24 |   "Hi I heard about Spark".split(" "),
25 |   "I wish Java could use case classes".split(" "),
26 |   "Logistic regression models are neat".split(" ")
27 | ).map(Tuple1.apply)).toDF("text")
28 | 
29 | // Learn a mapping from words to Vectors.
30 | val word2Vec = new Word2Vec()
31 |   .setInputCol("text")
32 |   .setOutputCol("result")
33 |   .setVectorSize(3)
34 |   .setMinCount(0)
35 | val model = word2Vec.fit(documentDF)
36 | 
37 | val result = model.transform(documentDF)
38 | result.collect().foreach { case Row(text: Seq[_], features: Vector) =>
39 |   println(s"Text: [${text.mkString(", ")}] => \nVector: $features\n") }
40 | 
41 | spark.stop()
42 | ```
43 | 
44 | # Requirements
45 | Spark-Word2Vec is built against Spark 2.1.1.
46 | 
47 | # Build From Source
48 | ```bash
49 | sbt package
50 | ```
51 | 
52 | # License
53 | Spark-Word2Vec is available under the Apache License 2.0.
54 | 
55 | # Contact & Feedback
56 | If you encounter bugs, feel free to submit an issue or pull request. You can also email:
57 | + hibayesian (hibayesian@gmail.com).
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/embedding/Word2VecSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.ml.embedding
2 | 
3 | /*
4 |  * Licensed to the Apache Software Foundation (ASF) under one or more
5 |  * contributor license agreements.
See the NOTICE file distributed with 6 | * this work for additional information regarding copyright ownership. 7 | * The ASF licenses this file to You under the Apache License, Version 2.0 8 | * (the "License"); you may not use this file except in compliance with 9 | * the License. You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 18 | */ 19 | 20 | import org.apache.spark.sql.{Row, SparkSession} 21 | import org.apache.spark.ml.linalg.Vector 22 | 23 | object Word2VecSuite { 24 | def main(args: Array[String]) { 25 | val spark = SparkSession 26 | .builder 27 | .appName("Word2Vec example") 28 | .master("local[*]") 29 | .getOrCreate() 30 | 31 | // $example on$ 32 | // Input data: Each row is a bag of words from a sentence or document. 33 | val documentDF = spark.createDataFrame(Seq( 34 | "Hi I heard about Spark".split(" "), 35 | "I wish Java could use case classes".split(" "), 36 | "Logistic regression models are neat".split(" ") 37 | ).map(Tuple1.apply)).toDF("text") 38 | 39 | // Learn a mapping from words to Vectors. 40 | val word2Vec = new Word2Vec() 41 | .setInputCol("text") 42 | .setOutputCol("result") 43 | .setVectorSize(3) 44 | .setMinCount(0) 45 | val model = word2Vec.fit(documentDF) 46 | 47 | val result = model.transform(documentDF) 48 | result.collect().foreach { case Row(text: Seq[_], features: Vector) => 49 | println(s"Text: [${text.mkString(", ")}] => \nVector: $features\n") } 50 | // $example off$ 51 | 52 | spark.stop() 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/embedding/Word2Vec.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.embedding 19 | 20 | import org.apache.hadoop.fs.Path 21 | import org.apache.spark.ml.linalg.{BLAS, Vector, VectorUDT, Vectors} 22 | import org.apache.spark.ml.param._ 23 | import org.apache.spark.ml.param.shared._ 24 | import org.apache.spark.ml.util._ 25 | import org.apache.spark.ml.{Estimator, Model} 26 | import org.apache.spark.mllib.embedding 27 | import org.apache.spark.mllib.linalg.VectorImplicits._ 28 | import org.apache.spark.sql.functions._ 29 | import org.apache.spark.sql.types._ 30 | import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} 31 | 32 | /** 33 | * Params for [[Word2Vec]] and [[Word2VecModel]]. 34 | */ 35 | private[embedding] trait Word2VecBase extends Params 36 | with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed { 37 | 38 | /** 39 | * The dimension of the code that you want to transform from words. 40 | * Default: 100 41 | * @group param 42 | */ 43 | final val vectorSize = new IntParam( 44 | this, "vectorSize", "the dimension of codes after transforming from words (> 0)", 45 | ParamValidators.gt(0)) 46 | setDefault(vectorSize -> 100) 47 | 48 | /** @group getParam */ 49 | def getVectorSize: Int = $(vectorSize) 50 | 51 | /** 52 | * The window size (context words from [-window, window]). 53 | * Default: 5 54 | * @group expertParam 55 | */ 56 | final val windowSize = new IntParam( 57 | this, "windowSize", "the window size (context words from [-window, window]) (> 0)", 58 | ParamValidators.gt(0)) 59 | setDefault(windowSize -> 5) 60 | 61 | /** @group expertGetParam */ 62 | def getWindowSize: Int = $(windowSize) 63 | 64 | /** 65 | * Number of partitions for sentences of words. 66 | * Default: 1 67 | * @group param 68 | */ 69 | final val numPartitions = new IntParam( 70 | this, "numPartitions", "number of partitions for sentences of words (> 0)", 71 | ParamValidators.gt(0)) 72 | setDefault(numPartitions -> 1) 73 | 74 | /** @group getParam */ 75 | def getNumPartitions: Int = $(numPartitions) 76 | 77 | /** 78 | * The minimum number of times a token must appear to be included in the word2vec model's 79 | * vocabulary. 80 | * Default: 5 81 | * @group param 82 | */ 83 | final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + 84 | "appear to be included in the word2vec model's vocabulary (>= 0)", ParamValidators.gtEq(0)) 85 | setDefault(minCount -> 5) 86 | 87 | /** @group getParam */ 88 | def getMinCount: Int = $(minCount) 89 | 90 | /** 91 | * Sets the maximum length (in words) of each sentence in the input data. 92 | * Any sentence longer than this threshold will be divided into chunks of 93 | * up to `maxSentenceLength` size. 94 | * Default: 1000 95 | * @group param 96 | */ 97 | final val maxSentenceLength = new IntParam(this, "maxSentenceLength", "Maximum length " + 98 | "(in words) of each sentence in the input data. 
Any sentence longer than this threshold will " + 99 | "be divided into chunks up to the size (> 0)", ParamValidators.gt(0)) 100 | setDefault(maxSentenceLength -> 1000) 101 | 102 | /** @group getParam */ 103 | def getMaxSentenceLength: Int = $(maxSentenceLength) 104 | 105 | /** 106 | * Use continues bag-of-words model. 107 | * Default: 0 108 | * @group param 109 | */ 110 | final val cbow = new IntParam(this, "cbow", "Use continues bag-of-words model", 111 | ParamValidators.inArray(Array(0, 1))) 112 | setDefault(cbow -> 0) 113 | 114 | /** @group getParam */ 115 | def getCBOW: Int = $(cbow) 116 | 117 | /** 118 | * Use hierarchical softmax method to train the model. 119 | * Default: 1 120 | * @group param 121 | */ 122 | final val hs = new IntParam(this, "hs", "Use hierarchical softmax method to train the model", 123 | ParamValidators.inArray(Array(0, 1))) 124 | setDefault(hs -> 1) 125 | 126 | /** @group getParam */ 127 | def getHS: Int = $(hs) 128 | 129 | /** 130 | * Use negative sampling method to train the model. 131 | * Default: 0 132 | * @group param 133 | */ 134 | final val negative = new IntParam(this, "negative", "Use negative sampling method to train the model", 135 | ParamValidators.inArray(Array(0, 1))) 136 | setDefault(negative -> 0) 137 | 138 | /** @group getParam */ 139 | def getNegative: Int = $(negative) 140 | 141 | /** 142 | * Use sub-sampling trick to improve the performance. 143 | * Default: 0 144 | * @group param 145 | */ 146 | final val sample = new DoubleParam(this, "sample", "Use sub-sampling trick to improve the performance", 147 | ParamValidators.inRange(0, 1, true, true)) 148 | setDefault(sample -> 0) 149 | 150 | /** @group getParam */ 151 | def getSample: Double = $(sample) 152 | 153 | setDefault(stepSize -> 0.025) 154 | setDefault(maxIter -> 1) 155 | 156 | /** 157 | * Validate and transform the input schema. 158 | */ 159 | protected def validateAndTransformSchema(schema: StructType): StructType = { 160 | val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) 161 | SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) 162 | SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) 163 | } 164 | } 165 | 166 | /** 167 | * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further 168 | * natural language processing or machine learning process. 
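 *
 * A minimal usage sketch (the input column name, the `documentDF` DataFrame and the parameter
 * values below are illustrative only, not defaults of this class):
 * {{{
 *   val w2v = new Word2Vec()
 *     .setInputCol("text")     // column of Seq[String]
 *     .setOutputCol("result")  // averaged sentence vectors
 *     .setVectorSize(100)
 *     .setCBOW(1)              // 1 = CBOW, 0 = Skip-gram
 *     .setHS(0)                // disable hierarchical softmax ...
 *     .setNegative(1)          // ... and train with negative sampling instead
 *     .setSample(1e-3)         // sub-sample very frequent words
 *   val model = w2v.fit(documentDF)
 * }}}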
169 | */ 170 | final class Word2Vec (override val uid: String) 171 | extends Estimator[Word2VecModel] with Word2VecBase with DefaultParamsWritable { 172 | 173 | def this() = this(Identifiable.randomUID("w2v")) 174 | 175 | /** @group setParam */ 176 | def setInputCol(value: String): this.type = set(inputCol, value) 177 | 178 | /** @group setParam */ 179 | def setOutputCol(value: String): this.type = set(outputCol, value) 180 | 181 | /** @group setParam */ 182 | def setVectorSize(value: Int): this.type = set(vectorSize, value) 183 | 184 | /** @group expertSetParam */ 185 | def setWindowSize(value: Int): this.type = set(windowSize, value) 186 | 187 | /** @group setParam */ 188 | def setStepSize(value: Double): this.type = set(stepSize, value) 189 | 190 | /** @group setParam */ 191 | def setNumPartitions(value: Int): this.type = set(numPartitions, value) 192 | 193 | /** @group setParam */ 194 | def setMaxIter(value: Int): this.type = set(maxIter, value) 195 | 196 | /** @group setParam */ 197 | def setSeed(value: Long): this.type = set(seed, value) 198 | 199 | /** @group setParam */ 200 | def setMinCount(value: Int): this.type = set(minCount, value) 201 | 202 | /** @group setParam */ 203 | def setMaxSentenceLength(value: Int): this.type = set(maxSentenceLength, value) 204 | 205 | /** @group getParam */ 206 | def setCBOW(value: Int): this.type = set(cbow, value) 207 | 208 | /** @group getParam */ 209 | def setHS(value: Int): this.type = set(hs, value) 210 | 211 | /** @group getParam */ 212 | def setNegative(value: Int): this.type = set(negative, value) 213 | 214 | /** @group getParam */ 215 | def setSample(value: Double): this.type = set(sample, value) 216 | 217 | override def fit(dataset: Dataset[_]): Word2VecModel = { 218 | transformSchema(dataset.schema, logging = true) 219 | val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) 220 | val wordVectors = new embedding.Word2Vec() 221 | .setLearningRate($(stepSize)) 222 | .setMinCount($(minCount)) 223 | .setNumIterations($(maxIter)) 224 | .setNumPartitions($(numPartitions)) 225 | .setSeed($(seed)) 226 | .setVectorSize($(vectorSize)) 227 | .setWindowSize($(windowSize)) 228 | .setMaxSentenceLength($(maxSentenceLength)) 229 | .setCBOW($(cbow)) 230 | .setHS($(hs)) 231 | .setNegative($(negative)) 232 | .setSample($(sample)) 233 | .fit(input) 234 | copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) 235 | } 236 | 237 | override def transformSchema(schema: StructType): StructType = { 238 | validateAndTransformSchema(schema) 239 | } 240 | 241 | override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) 242 | } 243 | 244 | object Word2Vec extends DefaultParamsReadable[Word2Vec] { 245 | 246 | override def load(path: String): Word2Vec = super.load(path) 247 | } 248 | 249 | /** 250 | * Model fitted by [[Word2Vec]]. 251 | */ 252 | class Word2VecModel private[ml] ( 253 | override val uid: String, 254 | @transient private val wordVectors: embedding.Word2VecModel) 255 | extends Model[Word2VecModel] with Word2VecBase with MLWritable { 256 | 257 | import Word2VecModel._ 258 | 259 | /** 260 | * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and 261 | * and the vector the DenseVector that it is mapped to. 
262 | */ 263 | @transient lazy val getVectors: DataFrame = { 264 | val spark = SparkSession.builder().getOrCreate() 265 | val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble))) 266 | spark.createDataFrame(wordVec.toSeq).toDF("word", "vector") 267 | } 268 | 269 | /** 270 | * Find "num" number of words closest in similarity to the given word, not 271 | * including the word itself. Returns a dataframe with the words and the 272 | * cosine similarities between the synonyms and the given word. 273 | */ 274 | def findSynonyms(word: String, num: Int): DataFrame = { 275 | val spark = SparkSession.builder().getOrCreate() 276 | spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") 277 | } 278 | 279 | /** 280 | * Find "num" number of words whose vector representation most similar to the supplied vector. 281 | * If the supplied vector is the vector representation of a word in the model's vocabulary, 282 | * that word will be in the results. Returns a dataframe with the words and the cosine 283 | * similarities between the synonyms and the given word vector. 284 | */ 285 | def findSynonyms(vec: Vector, num: Int): DataFrame = { 286 | val spark = SparkSession.builder().getOrCreate() 287 | spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity") 288 | } 289 | 290 | /** @group setParam */ 291 | def setInputCol(value: String): this.type = set(inputCol, value) 292 | 293 | /** @group setParam */ 294 | def setOutputCol(value: String): this.type = set(outputCol, value) 295 | 296 | /** 297 | * Transform a sentence column to a vector column to represent the whole sentence. The transform 298 | * is performed by averaging all word vectors it contains. 299 | */ 300 | override def transform(dataset: Dataset[_]): DataFrame = { 301 | transformSchema(dataset.schema, logging = true) 302 | val vectors = wordVectors.getVectors 303 | .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) 304 | .map(identity) // mapValues doesn't return a serializable map (SI-7005) 305 | val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors) 306 | val d = $(vectorSize) 307 | val word2Vec = udf { sentence: Seq[String] => 308 | if (sentence.isEmpty) { 309 | Vectors.sparse(d, Array.empty[Int], Array.empty[Double]) 310 | } else { 311 | val sum = Vectors.zeros(d) 312 | sentence.foreach { word => 313 | bVectors.value.get(word).foreach { v => 314 | BLAS.axpy(1.0, v, sum) 315 | } 316 | } 317 | BLAS.scal(1.0 / sentence.size, sum) 318 | sum 319 | } 320 | } 321 | dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) 322 | } 323 | 324 | override def transformSchema(schema: StructType): StructType = { 325 | validateAndTransformSchema(schema) 326 | } 327 | 328 | override def copy(extra: ParamMap): Word2VecModel = { 329 | val copied = new Word2VecModel(uid, wordVectors) 330 | copyValues(copied, extra).setParent(parent) 331 | } 332 | 333 | override def write: MLWriter = new Word2VecModelWriter(this) 334 | } 335 | 336 | object Word2VecModel extends MLReadable[Word2VecModel] { 337 | 338 | private[Word2VecModel] 339 | class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter { 340 | 341 | private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float]) 342 | 343 | override protected def saveImpl(path: String): Unit = { 344 | DefaultParamsWriter.saveMetadata(instance, path, sc) 345 | val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq) 346 | val dataPath = new Path(path, "data").toString 347 | 
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) 348 | } 349 | } 350 | 351 | private class Word2VecModelReader extends MLReader[Word2VecModel] { 352 | 353 | private val className = classOf[Word2VecModel].getName 354 | 355 | override def load(path: String): Word2VecModel = { 356 | val metadata = DefaultParamsReader.loadMetadata(path, sc, className) 357 | val dataPath = new Path(path, "data").toString 358 | val data = sparkSession.read.parquet(dataPath) 359 | .select("wordIndex", "wordVectors") 360 | .head() 361 | val wordIndex = data.getAs[Map[String, Int]](0) 362 | val wordVectors = data.getAs[Seq[Float]](1).toArray 363 | val oldModel = new embedding.Word2VecModel(wordIndex, wordVectors) 364 | val model = new Word2VecModel(metadata.uid, oldModel) 365 | DefaultParamsReader.getAndSetParams(model, metadata) 366 | model 367 | } 368 | } 369 | 370 | override def read: MLReader[Word2VecModel] = new Word2VecModelReader 371 | 372 | override def load(path: String): Word2VecModel = super.load(path) 373 | } 374 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/embedding/Word2Vec.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.mllib.embedding 19 | 20 | import java.lang.{Iterable => JavaIterable} 21 | 22 | import com.github.fommil.netlib.BLAS.{getInstance => blas} 23 | import org.apache.spark.SparkContext 24 | import org.apache.spark.api.java.JavaRDD 25 | import org.apache.spark.internal.Logging 26 | import org.apache.spark.mllib.linalg.{Vector, Vectors} 27 | import org.apache.spark.mllib.util.{Loader, Saveable} 28 | import org.apache.spark.rdd._ 29 | import org.apache.spark.sql.SparkSession 30 | import org.apache.spark.util.Utils 31 | import org.apache.spark.util.random.XORShiftRandom 32 | import org.json4s.DefaultFormats 33 | import org.json4s.JsonDSL._ 34 | import org.json4s.jackson.JsonMethods._ 35 | 36 | import scala.collection.JavaConverters._ 37 | import scala.collection.mutable 38 | 39 | /** 40 | * Entry in vocabulary 41 | */ 42 | private case class VocabWord( 43 | var word: String, 44 | var cn: Int, 45 | var point: Array[Int], 46 | var code: Array[Int], 47 | var codeLen: Int 48 | ) 49 | 50 | /** 51 | * Word2Vec creates vector representation of words in a text corpus. 52 | * The algorithm first constructs a vocabulary from the corpus 53 | * and then learns vector representation of words in the vocabulary. 54 | * The vector representation can be used as features in 55 | * natural language processing and machine learning algorithms. 
56 | * 57 | * Two models cbow and skip-gram are used in our implementation. 58 | * Both hierarchical softmax and negative sampling methods are 59 | * supported to train the model. The variable names in the implementation 60 | * matches the original C implementation. 61 | * 62 | * For original C implementation, see https://code.google.com/p/word2vec/ 63 | * For research papers, see 64 | * Efficient Estimation of Word Representations in Vector Space 65 | * and 66 | * Distributed Representations of Words and Phrases and their Compositionality. 67 | */ 68 | class Word2Vec extends Serializable with Logging { 69 | 70 | private var vectorSize = 100 71 | private var learningRate = 0.025 72 | private var numPartitions = 1 73 | private var numIterations = 1 74 | private var seed = Utils.random.nextLong() 75 | private var minCount = 5 76 | private var maxSentenceLength = 1000 77 | private var cbow = 0 78 | private var negative = 0 79 | private var sample = 0.0 80 | private var hs = 1 81 | private val tableSize = 1e8.toInt 82 | 83 | /** 84 | * Sets vector size (default: 100). 85 | */ 86 | def setVectorSize(vectorSize: Int): this.type = { 87 | require(vectorSize > 0, s"vectorSize must be greater than 0 but got $vectorSize") 88 | this.vectorSize = vectorSize 89 | this 90 | } 91 | 92 | /** 93 | * Sets initial learning rate (default: 0.025). 94 | */ 95 | def setLearningRate(learningRate: Double): this.type = { 96 | require(learningRate > 0 && learningRate <= 1, s"learningRate must be between 0 and 1 but got $learningRate") 97 | this.learningRate = learningRate 98 | this 99 | } 100 | 101 | /** 102 | * Sets number of partitions (default: 1). Use a small number for accuracy. 103 | */ 104 | def setNumPartitions(numPartitions: Int): this.type = { 105 | require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") 106 | this.numPartitions = numPartitions 107 | this 108 | } 109 | 110 | /** 111 | * Sets number of iterations (default: 1), which should be smaller than or equal to number of 112 | * partitions. 113 | */ 114 | def setNumIterations(numIterations: Int): this.type = { 115 | require(numIterations > 0, s"numIterations must be greater than 0 but got $numIterations") 116 | this.numIterations = numIterations 117 | this 118 | } 119 | 120 | /** 121 | * Sets random seed (default: a random long integer). 122 | */ 123 | def setSeed(seed: Long): this.type = { 124 | this.seed = seed 125 | this 126 | } 127 | 128 | /** 129 | * Sets the window of words (default: 5) 130 | */ 131 | def setWindowSize(window: Int): this.type = { 132 | require(window > 0, s"window must be greater than 0 but got $window") 133 | this.window = window 134 | this 135 | } 136 | 137 | /** 138 | * Sets minCount, the minimum number of times a token must appear to be included in the word2vec 139 | * model's vocabulary (default: 5). 140 | */ 141 | def setMinCount(minCount: Int): this.type = { 142 | require(minCount >= 0, s"minCount must be greater than or equal to 0 but got $minCount") 143 | this.minCount = minCount 144 | this 145 | } 146 | 147 | /** 148 | * Sets the maximum length (in words) of each sentence in the input data. 
149 | * Any sentence longer than this threshold will be divided into chunks of 150 | * up to `maxSentenceLength` size (default: 1000) 151 | */ 152 | def setMaxSentenceLength(maxSentenceLength: Int): this.type = { 153 | require(maxSentenceLength > 0, s"maxSentenceLength must be greater than 0 but got $maxSentenceLength") 154 | this.maxSentenceLength = maxSentenceLength 155 | this 156 | } 157 | 158 | /** 159 | * Sets cbow. Use continues bag-of-words model (default: 0). 160 | */ 161 | def setCBOW(cbow: Int): this.type = { 162 | require(cbow == 0 || cbow == 1, s"cbow must be equal to 0 or 1 but got $cbow") 163 | this.cbow = cbow 164 | this 165 | } 166 | 167 | /** 168 | * Set sample. Use sub-sampling trick to improve the performance (default: 0). 169 | */ 170 | def setSample(sample: Double): this.type = { 171 | require(sample >= 0 && sample <= 1, s"sample must be between 0 and 1 but got $sample") 172 | this.sample = sample 173 | this 174 | } 175 | 176 | /** 177 | * Set hs. Use hierarchical softmax method to train the model (default: 1). 178 | */ 179 | def setHS(hs: Int): this.type = { 180 | require(hs == 0 || hs == 1, s"hs must be equal to 0 or 1 but got $hs") 181 | this.hs = hs 182 | this 183 | } 184 | 185 | /** 186 | * Set negative. Use negative sampling method to train the model (default: 0). 187 | */ 188 | def setNegative(negative: Int): this.type = { 189 | require(negative >= 0, s"negative must be greater than or equal to 0 but got $negative") 190 | this.negative = negative 191 | this 192 | } 193 | 194 | private val EXP_TABLE_SIZE = 1000 195 | private val MAX_EXP = 6 196 | private val MAX_CODE_LENGTH = 40 197 | private val MAX_SENTENCE_LENGTH = 1000 198 | 199 | /** context words from [-window, window] */ 200 | private var window = 5 201 | 202 | private var trainWordsCount = 0L 203 | private var vocabSize = 0 204 | @transient private var vocab: Array[VocabWord] = null 205 | @transient private var vocabHash = mutable.HashMap.empty[String, Int] 206 | @transient private var table: Array[Int] = null 207 | 208 | 209 | private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = { 210 | val words = dataset.flatMap(x => x) 211 | 212 | vocab = words.map(w => (w, 1)) 213 | .reduceByKey(_ + _) 214 | .filter(_._2 >= minCount) 215 | .map(x => VocabWord( 216 | x._1, 217 | x._2, 218 | new Array[Int](MAX_CODE_LENGTH), 219 | new Array[Int](MAX_CODE_LENGTH), 220 | 0)) 221 | .collect() 222 | .sortWith((a, b) => a.cn > b.cn) 223 | 224 | vocabSize = vocab.length 225 | require(vocabSize > 0, "The vocabulary size should be > 0. 
You may need to check " + 226 | "the setting of minCount, which could be large enough to remove all your words in sentences.") 227 | 228 | var a = 0 229 | while (a < vocabSize) { 230 | vocabHash += vocab(a).word -> a 231 | trainWordsCount += vocab(a).cn 232 | a += 1 233 | } 234 | logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount") 235 | } 236 | 237 | private def createExpTable(): Array[Float] = { 238 | val expTable = new Array[Float](EXP_TABLE_SIZE) 239 | var i = 0 240 | while (i < EXP_TABLE_SIZE) { 241 | val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP) 242 | expTable(i) = (tmp / (tmp + 1.0)).toFloat 243 | i += 1 244 | } 245 | expTable 246 | } 247 | 248 | private def createBinaryTree(): Unit = { 249 | val count = new Array[Long](vocabSize * 2 - 1) 250 | val binary = new Array[Int](vocabSize * 2 - 1) 251 | val parentNode = new Array[Int](vocabSize * 2 - 1) 252 | val code = new Array[Int](MAX_CODE_LENGTH) 253 | val point = new Array[Int](MAX_CODE_LENGTH) 254 | var a = 0 255 | while (a < vocabSize) { 256 | count(a) = vocab(a).cn 257 | a += 1 258 | } 259 | while (a < 2 * vocabSize - 1) { 260 | count(a) = 1e9.toInt 261 | a += 1 262 | } 263 | var pos1 = vocabSize - 1 264 | var pos2 = vocabSize 265 | 266 | var min1i = 0 267 | var min2i = 0 268 | 269 | a = 0 270 | while (a < vocabSize - 1) { 271 | if (pos1 >= 0) { 272 | if (count(pos1) < count(pos2)) { 273 | min1i = pos1 274 | pos1 -= 1 275 | } else { 276 | min1i = pos2 277 | pos2 += 1 278 | } 279 | } else { 280 | min1i = pos2 281 | pos2 += 1 282 | } 283 | if (pos1 >= 0) { 284 | if (count(pos1) < count(pos2)) { 285 | min2i = pos1 286 | pos1 -= 1 287 | } else { 288 | min2i = pos2 289 | pos2 += 1 290 | } 291 | } else { 292 | min2i = pos2 293 | pos2 += 1 294 | } 295 | count(vocabSize + a) = count(min1i) + count(min2i) 296 | parentNode(min1i) = vocabSize + a 297 | parentNode(min2i) = vocabSize + a 298 | binary(min2i) = 1 299 | a += 1 300 | } 301 | // Now assign binary code to each vocabulary word 302 | var i = 0 303 | a = 0 304 | while (a < vocabSize) { 305 | var b = a 306 | i = 0 307 | while (b != vocabSize * 2 - 2) { 308 | code(i) = binary(b) 309 | point(i) = b 310 | i += 1 311 | b = parentNode(b) 312 | } 313 | vocab(a).codeLen = i 314 | vocab(a).point(0) = vocabSize - 2 315 | b = 0 316 | while (b < i) { 317 | vocab(a).code(i - b - 1) = code(b) 318 | vocab(a).point(i - b) = point(b) - vocabSize 319 | b += 1 320 | } 321 | a += 1 322 | } 323 | } 324 | 325 | private def initUnigramTable(): Unit = { 326 | var a = 0 327 | val power = 0.75 328 | var trainWordsPow = 0.0 329 | table = new Array[Int](tableSize) 330 | 331 | while (a < vocabSize) { 332 | trainWordsPow += Math.pow(vocab(a).cn, power) 333 | a += 1 334 | } 335 | 336 | var i = 0 337 | var d1 = Math.pow(vocab(i).cn, power) / trainWordsPow 338 | a = 0 339 | while (a < tableSize) { 340 | table(a) = i 341 | if (a.toDouble / tableSize > d1) { 342 | i += 1 343 | d1 += Math.pow(vocab(i).cn, power) / trainWordsPow 344 | } 345 | if (i >= vocabSize) { 346 | i = vocabSize - 1 347 | } 348 | a += 1 349 | } 350 | } 351 | 352 | /** 353 | * Computes the vector representation of each word in vocabulary. 
354 | * @param dataset an RDD of sentences, each sentence is expressed as an iterable collection of words 355 | * @return a Word2VecModel 356 | */ 357 | def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { 358 | 359 | learnVocab(dataset) 360 | 361 | createBinaryTree() 362 | 363 | if (negative > 0) { 364 | initUnigramTable() 365 | } else if (hs == 0) { 366 | throw new RuntimeException(s"negative and hs can not both be equal to 0.") 367 | } 368 | 369 | val sc = dataset.context 370 | 371 | val expTable = sc.broadcast(createExpTable()) 372 | val bcVocab = sc.broadcast(vocab) 373 | val bcVocabHash = sc.broadcast(vocabHash) 374 | val bcTable = sc.broadcast(table) 375 | 376 | // each partition is a collection of sentences, 377 | // will be translated into arrays of Index integer 378 | val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter => 379 | // Each sentence will map to 0 or more Array[Int] 380 | sentenceIter.flatMap { sentence => 381 | // Sentence of words, some of which map to a word index 382 | val wordIndexes = sentence.flatMap(bcVocabHash.value.get) 383 | // break wordIndexes into trunks of maxSentenceLength when has more 384 | wordIndexes.grouped(maxSentenceLength).map(_.toArray) 385 | } 386 | } 387 | 388 | val newSentences = sentences.repartition(numPartitions).cache() 389 | val initRandom = new XORShiftRandom(seed) 390 | 391 | if (vocabSize.toLong * vectorSize >= Int.MaxValue) { 392 | throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + 393 | " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " + 394 | "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.") 395 | } 396 | 397 | val syn0Global = 398 | Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) 399 | val syn1Global = new Array[Float](vocabSize * vectorSize) 400 | val syn1NegGlobal = new Array[Float](vocabSize * vectorSize) 401 | var alpha = learningRate 402 | 403 | for (k <- 1 to numIterations) { 404 | val bcSyn0Global = sc.broadcast(syn0Global) 405 | val bcSyn1Global = sc.broadcast(syn1Global) 406 | val bcSyn1NegGlobal = sc.broadcast(syn1NegGlobal) 407 | 408 | val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => 409 | val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) 410 | val syn0Modify = new Array[Int](vocabSize) 411 | val syn1Modify = new Array[Int](vocabSize) 412 | val syn1NegModify = new Array[Int](vocabSize) 413 | val sen = new Array[Int](MAX_SENTENCE_LENGTH) 414 | 415 | val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, bcSyn1NegGlobal.value, 0L, 0L)) { 416 | case ((syn0, syn1, syn1Neg, lastWordCount, wordCount), sentence) => 417 | var sentenceLength = 0 418 | 419 | // The sub-sampling trick randomly discards frequent words while keeping the ranking same 420 | var sentencePosition = 0 421 | while (sentencePosition < sentence.length && sentencePosition < MAX_SENTENCE_LENGTH) { 422 | val word = sentence(sentencePosition) 423 | if (sample > 0) { 424 | val ran = Math.sqrt(bcVocab.value(word).cn / (sample * trainWordsCount) + 1) * 425 | (sample * trainWordsCount) / bcVocab.value(word).cn 426 | if (ran >= random.nextFloat()) { 427 | sen(sentenceLength) = word 428 | sentenceLength += 1 429 | } 430 | } else { 431 | sen(sentenceLength) = word 432 | sentenceLength += 1 433 | } 434 | sentencePosition += 1 435 | } 436 | 437 | var lwc = lastWordCount 438 | var wc = wordCount 439 | if (wordCount - lastWordCount > 
10000) { 440 | lwc = wordCount 441 | alpha = 442 | learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) 443 | if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 444 | logInfo("wordCount = " + wordCount + ", alpha = " + alpha) 445 | } 446 | wc += sentence.length 447 | sentencePosition = 0 448 | while (sentencePosition < sentenceLength) { 449 | val word = sen(sentencePosition) 450 | val b = random.nextInt(window) 451 | val neu1 = new Array[Float](vectorSize) 452 | 453 | if (cbow == 1) { 454 | // Train CBOW 455 | val neu1e = new Array[Float](vectorSize) 456 | var a = b 457 | while (a < window * 2 + 1 - b) { 458 | if (a != window) { 459 | val c = sentencePosition - window + a 460 | if (c >= 0 && c < sentence.length) { 461 | val lastWord = sen(c) 462 | val l1 = lastWord * vectorSize 463 | blas.saxpy(vectorSize, 1, syn0, l1, 1, neu1, 0, 1) 464 | } 465 | } 466 | a += 1 467 | } 468 | 469 | if (hs == 1) { 470 | // Hierarchical softmax 471 | var d = 0 472 | while (d < bcVocab.value(word).codeLen) { 473 | val inner = bcVocab.value(word).point(d) 474 | val l2 = inner * vectorSize 475 | // Propagate hidden -> output 476 | var f = blas.sdot(vectorSize, neu1, 0, 1, syn1, l2, 1) 477 | if (f > -MAX_EXP && f < MAX_EXP) { 478 | val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt 479 | f = expTable.value(ind) 480 | val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat 481 | blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) 482 | blas.saxpy(vectorSize, g, neu1, 0, 1, syn1, l2, 1) 483 | syn1Modify(inner) += 1 484 | } 485 | d += 1 486 | } 487 | } 488 | 489 | if (negative > 0) { 490 | // Negative sampling 491 | var target = -1 492 | var label = -1 493 | var d = 0 494 | while (d < negative + 1) { 495 | var isContinued = false 496 | if (d == 0) { 497 | target = word 498 | label = 1 499 | } else { 500 | target = bcTable.value(random.nextInt(tableSize)) 501 | if (target == word) { 502 | isContinued = true 503 | } 504 | if (!isContinued.equals(true)) { 505 | label = 0 506 | } 507 | } 508 | if (!isContinued.equals(true)) { 509 | val l2 = target * vectorSize 510 | val f = blas.sdot(vectorSize, neu1, 0, 1, syn1Neg, l2, 1) 511 | var g = 0.0 512 | if (f > MAX_EXP) { 513 | g = (label - 1) * alpha 514 | } else if (f < -MAX_EXP) { 515 | g = (label - 0) * alpha 516 | } else { 517 | val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt 518 | g = (label - expTable.value(ind)) * alpha 519 | } 520 | blas.saxpy(vectorSize, g.toFloat, syn1Neg, l2, 1, neu1e, 0, 1) 521 | blas.saxpy(vectorSize, g.toFloat, neu1, 0, 1, syn1Neg, l2, 1) 522 | syn1NegModify(target) += 1 523 | } 524 | d += 1 525 | } 526 | } 527 | 528 | // Hidden -> in 529 | a = b 530 | while (a < window * 2 + 1 - b) { 531 | if (a != window) { 532 | val c = sentencePosition - window + a 533 | if (c >= 0 && c < sentenceLength) { 534 | val lastWord = sen(c) 535 | val l1 = lastWord * vectorSize 536 | blas.saxpy(vectorSize, 1, neu1e, 0, 1, syn0, l1, 1) 537 | syn0Modify(lastWord) += 1 538 | } 539 | } 540 | a += 1 541 | } 542 | sentencePosition += 1 543 | } else { 544 | // Train Skip-gram 545 | var a = b 546 | while (a < window * 2 + 1 - b) { 547 | if (a != window) { 548 | val c = sentencePosition - window + a 549 | if (c >= 0 && c < sentenceLength) { 550 | val lastWord = sen(c) 551 | val l1 = lastWord * vectorSize 552 | val neu1e = new Array[Float](vectorSize) 553 | 554 | if (hs == 1) { 555 | // Hierarchical softmax 556 | var d = 0 557 | while (d < bcVocab.value(word).codeLen) { 558 | val inner = 
bcVocab.value(word).point(d) 559 | val l2 = inner * vectorSize 560 | // Propagate hidden -> output 561 | var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) 562 | if (f > -MAX_EXP && f < MAX_EXP) { 563 | val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt 564 | f = expTable.value(ind) 565 | val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat 566 | blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) 567 | blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) 568 | syn1Modify(inner) += 1 569 | } 570 | d += 1 571 | } 572 | } 573 | 574 | if (negative > 0) { 575 | // Negative sampling 576 | var target = -1 577 | var label = -1 578 | var d = 0 579 | while (d < negative + 1) { 580 | var isContinued = false 581 | if (d == 0) { 582 | target = word 583 | label = 1 584 | } else { 585 | target = bcTable.value(random.nextInt(tableSize)) 586 | if (target == word) { 587 | isContinued = true 588 | } 589 | if (!isContinued.equals(true)) { 590 | label = 0 591 | } 592 | } 593 | if (!isContinued.equals(true)) { 594 | val l2 = target * vectorSize 595 | val f = blas.sdot(vectorSize, syn0, l1, 1, syn1Neg, l2, 1) 596 | var g = 0.0 597 | if (f > MAX_EXP) { 598 | g = (label - 1) * alpha 599 | } else if (f < -MAX_EXP) { 600 | g = (label - 0) * alpha 601 | } else { 602 | val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt 603 | g = (label - expTable.value(ind)) * alpha 604 | } 605 | blas.saxpy(vectorSize, g.toFloat, syn1Neg, l2, 1, neu1e, 0, 1) 606 | blas.saxpy(vectorSize, g.toFloat, syn0, l1, 1, syn1Neg, l2, 1) 607 | syn1NegModify(target) += 1 608 | } 609 | d += 1 610 | } 611 | } 612 | 613 | syn0Modify(lastWord) += 1 614 | blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) 615 | } 616 | } 617 | a += 1 618 | } 619 | sentencePosition += 1 620 | } 621 | } 622 | (syn0, syn1, syn1Neg, lwc, wc) 623 | } 624 | val syn0Local = model._1 625 | val syn1Local = model._2 626 | val syn1NegLocal = model._3 627 | // Only output modified vectors. 
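            // Modified rows of syn0, syn1 and syn1Neg are emitted under disjoint key ranges so one
            // reduceByKey can aggregate all three weight matrices at once:
            //   [0, vocabSize)                 -> syn0 rows
            //   [vocabSize, 2 * vocabSize)     -> syn1 rows (hierarchical softmax)
            //   [2 * vocabSize, 3 * vocabSize) -> syn1Neg rows (negative sampling)
            // The driver copies the aggregated rows back into the corresponding global arrays.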
628 | Iterator.tabulate(vocabSize) { index => 629 | if (syn0Modify(index) > 0) { 630 | Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) 631 | } else { 632 | None 633 | } 634 | }.flatten ++ Iterator.tabulate(vocabSize) { index => 635 | if (syn1Modify(index) > 0) { 636 | Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) 637 | } else { 638 | None 639 | } 640 | }.flatten ++ Iterator.tabulate(vocabSize) { index => 641 | if (syn1NegModify(index) > 0) { 642 | Some((index + 2 * vocabSize, syn1NegLocal.slice(index * vectorSize, (index + 1) * vectorSize))) 643 | } else { 644 | None 645 | } 646 | }.flatten 647 | } 648 | 649 | val synAgg = partial.reduceByKey { case (v1, v2) => 650 | blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) 651 | v1 652 | }.collect() 653 | var i = 0 654 | while (i < synAgg.length) { 655 | val index = synAgg(i)._1 656 | if (index < vocabSize) { 657 | Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize) 658 | } else if (index >= vocabSize && index < 2 * vocabSize) { 659 | Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize) 660 | } else { 661 | Array.copy(synAgg(i)._2, 0, syn1NegGlobal, (index - 2 * vocabSize) * vectorSize, vectorSize) 662 | } 663 | i += 1 664 | } 665 | bcSyn0Global.unpersist(false) 666 | bcSyn1Global.unpersist(false) 667 | bcSyn1NegGlobal.unpersist(false) 668 | } 669 | newSentences.unpersist() 670 | expTable.destroy(false) 671 | bcVocab.destroy(false) 672 | bcVocabHash.destroy(false) 673 | 674 | val wordArray = vocab.map(_.word) 675 | new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) 676 | } 677 | 678 | /** 679 | * Computes the vector representation of each word in vocabulary (Java version). 680 | * @param dataset a JavaRDD of words 681 | * @return a Word2VecModel 682 | */ 683 | def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = { 684 | fit(dataset.rdd.map(_.asScala)) 685 | } 686 | } 687 | 688 | /** 689 | * Word2Vec model 690 | * @param wordIndex maps each word to an index, which can retrieve the corresponding 691 | * vector from wordVectors 692 | * @param wordVectors array of length numWords * vectorSize, vector corresponding 693 | * to the word mapped with index i can be retrieved by the slice 694 | * (i * vectorSize, i * vectorSize + vectorSize) 695 | */ 696 | class Word2VecModel private[spark] ( 697 | private[spark] val wordIndex: Map[String, Int], 698 | private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable { 699 | 700 | val numWords = wordIndex.size 701 | // vectorSize: Dimension of each word's vector. 702 | val vectorSize = wordVectors.length / numWords 703 | 704 | // wordList: Ordered list of words obtained from wordIndex. 705 | private val wordList: Array[String] = { 706 | val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip 707 | wl.toArray 708 | } 709 | 710 | // wordVecNorms: Array of length numWords, each value being the Euclidean norm 711 | // of the wordVector. 
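  // These norms are consumed by findSynonyms below: the sgemv call there yields raw dot products
  // against the (normalized) query vector, and dividing by wordVecNorms turns them into cosine
  // similarities without recomputing each word vector's length on every query.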
712 | private val wordVecNorms: Array[Double] = { 713 | val wordVecNorms = new Array[Double](numWords) 714 | var i = 0 715 | while (i < numWords) { 716 | val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) 717 | wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) 718 | i += 1 719 | } 720 | wordVecNorms 721 | } 722 | 723 | def this(model: Map[String, Array[Float]]) = { 724 | this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) 725 | } 726 | 727 | override protected def formatVersion = "1.0" 728 | 729 | def save(sc: SparkContext, path: String): Unit = { 730 | Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors) 731 | } 732 | 733 | /* private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { 734 | require(v1.length == v2.length, "Vectors should have the same length") 735 | val n = v1.length 736 | val norm1 = blas.snrm2(n, v1, 1) 737 | val norm2 = blas.snrm2(n, v2, 1) 738 | if (norm1 == 0 || norm2 == 0) return 0.0 739 | blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2 740 | }*/ 741 | 742 | /** 743 | * Transforms a word to its vector representation 744 | * @param word a word 745 | * @return vector representation of word 746 | */ 747 | def transform(word: String): Vector = { 748 | wordIndex.get(word) match { 749 | case Some(ind) => 750 | val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize) 751 | Vectors.dense(vec.map(_.toDouble)) 752 | case None => 753 | throw new IllegalStateException(s"$word not in vocabulary") 754 | } 755 | } 756 | 757 | /** 758 | * Find synonyms of a word 759 | * @param word a word 760 | * @param num number of synonyms to find 761 | * @return array of (word, cosineSimilarity) 762 | */ 763 | def findSynonyms(word: String, num: Int): Array[(String, Double)] = { 764 | val vector = transform(word) 765 | findSynonyms(vector, num) 766 | } 767 | 768 | /** 769 | * Find synonyms of the vector representation of a word, possibly 770 | * including any words in the model vocabulary whose vector respresentation 771 | * is the supplied vector. 772 | * @param vector vector representation of a word 773 | * @param num number of synonyms to find 774 | * @return array of (word, cosineSimilarity) 775 | */ 776 | def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { 777 | findSynonyms(vector, num, None) 778 | } 779 | 780 | 781 | /** 782 | * Find synonyms of the vector representation of a word, rejecting 783 | * words identical to the value of wordOpt, if one is supplied. 
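   * When wordOpt is supplied, the top num + 1 candidates are scored and the query word itself is
   * filtered out, so at most num synonyms are returned.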
784 | * @param vector vector representation of a word 785 | * @param num number of synonyms to find 786 | * @param wordOpt optionally, a word to reject from the results list 787 | * @return array of (word, cosineSimilarity) 788 | */ 789 | private def findSynonyms( 790 | vector: Vector, 791 | num: Int, 792 | wordOpt: Option[String]): Array[(String, Double)] = { 793 | require(num > 0, "Number of similar words should > 0") 794 | // TODO: optimize top-k 795 | val fVector = vector.toArray.map(_.toFloat) 796 | val cosineVec = Array.fill[Float](numWords)(0) 797 | val alpha: Float = 1 798 | val beta: Float = 0 799 | // Normalize input vector before blas.sgemv to avoid Inf value 800 | val vecNorm = blas.snrm2(vectorSize, fVector, 1) 801 | if (vecNorm != 0.0f) { 802 | blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1) 803 | } 804 | blas.sgemv( 805 | "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) 806 | 807 | val cosVec = cosineVec.map(_.toDouble) 808 | var ind = 0 809 | while (ind < numWords) { 810 | val norm = wordVecNorms(ind) 811 | if (norm == 0.0) { 812 | cosVec(ind) = 0.0 813 | } else { 814 | cosVec(ind) /= norm 815 | } 816 | ind += 1 817 | } 818 | 819 | val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2) 820 | 821 | val filtered = wordOpt match { 822 | case Some(w) => scored.take(num + 1).filter(tup => w != tup._1) 823 | case None => scored 824 | } 825 | 826 | filtered.take(num).toArray 827 | } 828 | 829 | /** 830 | * Returns a map of words to their vector representations. 831 | */ 832 | def getVectors: Map[String, Array[Float]] = { 833 | wordIndex.map { case (word, ind) => 834 | (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) 835 | } 836 | } 837 | } 838 | 839 | object Word2VecModel extends Loader[Word2VecModel] { 840 | 841 | private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = { 842 | model.keys.zipWithIndex.toMap 843 | } 844 | 845 | private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = { 846 | require(model.nonEmpty, "Word2VecMap should be non-empty") 847 | val (vectorSize, numWords) = (model.head._2.size, model.size) 848 | val wordList = model.keys.toArray 849 | val wordVectors = new Array[Float](vectorSize * numWords) 850 | var i = 0 851 | while (i < numWords) { 852 | Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize) 853 | i += 1 854 | } 855 | wordVectors 856 | } 857 | 858 | private object SaveLoadV1_0 { 859 | 860 | val formatVersionV1_0 = "1.0" 861 | 862 | val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel" 863 | 864 | case class Data(word: String, vector: Array[Float]) 865 | 866 | def load(sc: SparkContext, path: String): Word2VecModel = { 867 | val spark = SparkSession.builder().sparkContext(sc).getOrCreate() 868 | val dataFrame = spark.read.parquet(Loader.dataPath(path)) 869 | // Check schema explicitly since erasure makes it hard to use match-case for checking. 
870 | Loader.checkSchema[Data](dataFrame.schema) 871 | 872 | val dataArray = dataFrame.select("word", "vector").collect() 873 | val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap 874 | new Word2VecModel(word2VecMap) 875 | } 876 | 877 | def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { 878 | val spark = SparkSession.builder().sparkContext(sc).getOrCreate() 879 | 880 | val vectorSize = model.values.head.length 881 | val numWords = model.size 882 | val metadata = compact(render( 883 | ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ 884 | ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) 885 | sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) 886 | 887 | // We want to partition the model in partitions smaller than 888 | // spark.kryoserializer.buffer.max 889 | val bufferSize = Utils.byteStringAsBytes( 890 | spark.conf.get("spark.kryoserializer.buffer.max", "64m")) 891 | // We calculate the approximate size of the model 892 | // We only calculate the array size, considering an 893 | // average string size of 15 bytes, the formula is: 894 | // (floatSize * vectorSize + 15) * numWords 895 | val approxSize = (4L * vectorSize + 15) * numWords 896 | val nPartitions = ((approxSize / bufferSize) + 1).toInt 897 | val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } 898 | spark.createDataFrame(dataArray).repartition(nPartitions).write.parquet(Loader.dataPath(path)) 899 | } 900 | } 901 | 902 | override def load(sc: SparkContext, path: String): Word2VecModel = { 903 | 904 | val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) 905 | implicit val formats = DefaultFormats 906 | val expectedVectorSize = (metadata \ "vectorSize").extract[Int] 907 | val expectedNumWords = (metadata \ "numWords").extract[Int] 908 | val classNameV1_0 = SaveLoadV1_0.classNameV1_0 909 | (loadedClassName, loadedVersion) match { 910 | case (classNameV1_0, "1.0") => 911 | val model = SaveLoadV1_0.load(sc, path) 912 | val vectorSize = model.getVectors.values.head.length 913 | val numWords = model.getVectors.size 914 | require(expectedVectorSize == vectorSize, 915 | s"Word2VecModel requires each word to be mapped to a vector of size " + 916 | s"$expectedVectorSize, got vector of size $vectorSize") 917 | require(expectedNumWords == numWords, 918 | s"Word2VecModel requires $expectedNumWords words, but got $numWords") 919 | model 920 | case _ => throw new Exception( 921 | s"Word2VecModel.load did not recognize model with (className, format version):" + 922 | s"($loadedClassName, $loadedVersion). Supported:\n" + 923 | s" ($classNameV1_0, 1.0)") 924 | } 925 | } 926 | } 927 | --------------------------------------------------------------------------------