├── LICENSE ├── README.md ├── pom.xml └── src └── main └── scala └── org └── apache └── spark └── mllib └── topicModeling ├── GibbsLDAOptimizer.scala ├── GibbsLDASampler.scala ├── LDA.scala ├── LDAExample.scala ├── LDAModel.scala ├── LDAOptimizer.scala ├── OnlineHDP.scala └── OnlineLDAOptimizer.scala /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Topic Modeling on Apache Spark 2 | This package contains a set of distributed text modeling algorithms implemented on Spark, including: 3 | 4 | - **Online LDA**: an early version of the implementation was merged into MLlib (PR #4419), and several extensions (e.g., predict) are added here 5 | 6 | - **Gibbs sampling LDA**: the implementation is adapted from Spark PRs(#1405 and #4807) and JIRA SPARK-5556 (https://github.com/witgo/spark/tree/lda_Gibbs, https://github.com/EntilZha/spark/tree/LDA-Refactor, https://github.com/witgo/zen/tree/lda_opt/ml, etc.), with several extensions (e.g., support for MLlib interface, predict and in-place state update) added 7 | 8 | - **Online HDP (hierarchical Dirichlet process)**: implemented based on the paper "Online Variational Inference for the Hierarchical Dirichlet Process" (Chong Wang, John Paisley and David M. 
Blei) 9 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | org.apache.spark.mllib 5 | topicModel 6 | 1.0-SNAPSHOT 7 | 8 | 9 | 1.6 10 | 1.6 11 | UTF-8 12 | 2.10.4 13 | 14 | 15 | 16 | 17 | 18 | 19 | net.alchim31.maven 20 | scala-maven-plugin 21 | 3.1.5 22 | 23 | 24 | org.apache.maven.plugins 25 | maven-compiler-plugin 26 | 2.0.2 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | net.alchim31.maven 35 | scala-maven-plugin 36 | 37 | 38 | scala-compile-first 39 | process-resources 40 | 41 | add-source 42 | compile 43 | 44 | 45 | 46 | scala-test-compile 47 | process-test-resources 48 | 49 | testCompile 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | maven-assembly-plugin 58 | 2.4 59 | 60 | 61 | jar-with-dependencies 62 | 63 | 64 | 65 | 66 | make-assembly 67 | package 68 | 69 | single 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | org.scala-lang 81 | scala-library 82 | ${scala.version} 83 | 84 | 85 | 86 | org.apache.spark 87 | spark-core_2.10 88 | 1.3.1 89 | 90 | 91 | com.github.scopt 92 | scopt_2.10 93 | 3.2.0 94 | 95 | 96 | org.apache.spark 97 | spark-mllib_2.10 98 | 1.3.1 99 | 100 | 101 | org.scalanlp 102 | breeze_2.10 103 | 0.11.2 104 | 105 | 106 | junit 107 | junit 108 | 109 | 110 | org.apache.commons 111 | commons-math3 112 | 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/topicModeling/GibbsLDAOptimizer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.mllib.topicModeling 19 | 20 | import java.util.Random 21 | 22 | import org.apache.hadoop.io.NullWritable 23 | import org.apache.hadoop.mapred.TextOutputFormat 24 | import org.apache.spark.graphx._ 25 | import org.apache.spark.graphx.impl.GraphImpl 26 | import org.apache.spark.{HashPartitioner, Logging, Partitioner} 27 | import org.apache.spark.rdd.RDD 28 | import org.apache.spark.serializer.KryoRegistrator 29 | import org.apache.spark.storage.StorageLevel 30 | import org.apache.spark.SparkContext 31 | import org.apache.hadoop.fs._ 32 | 33 | import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV, StorageVector, 34 | sum => brzSum, norm => brzNorm, DenseMatrix => BDM, Matrix => BM, CSCMatrix => BSM} 35 | 36 | import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV, 37 | DenseMatrix => SDM, SparseMatrix => SSM, Vector => SV, Matrix} 38 | 39 | import GibbsLDAOptimizer._ 40 | 41 | import scala.reflect.ClassTag 42 | import org.json4s.DefaultFormats 43 | import org.json4s.JsonDSL._ 44 | import org.json4s.jackson.JsonMethods._ 45 | import org.apache.hadoop.io.{NullWritable, Text} 46 | 47 | /** 48 | * 49 | * Adapted from Spark PRs(#1405 and #4807) and JIRA SPARK-5556 (https://github.com/witgo/spark/tree/lda_Gibbs, 50 | * https://github.com/EntilZha/spark/tree/LDA-Refactor, https://github.com/witgo/zen/tree/lda_opt/ml, etc.), 51 | * with several extensions (e.g., support for MLlib interface, predict and in-place state update) added 52 | */ 53 | class GibbsLDAOptimizer private[topicModeling]( 54 | private var alphaAS: Float, 55 | private var storageLevel: StorageLevel, 56 | private var sampler:GibbsLDASampler = new GibbsLDAAliasSampler, 57 | var edgePartitioner:String = "none", 58 | var printPerplexity:Boolean = false) 59 | extends LDAOptimizer with Serializable with Logging { 60 | 61 | def this() = this(0.1f, StorageLevel.MEMORY_AND_DISK) 62 | 63 | @transient private var corpus:Graph[VD, ED] = null 64 | private var alpha = 0.01f 65 | private var beta = 0.01f 66 | private var numTopics = 0 67 | private var numTerms = 0 68 | private var numTokens = 0L 69 | private var checkpointInterval = Int.MaxValue 70 | 71 | /** 72 | * Initializer for the optimizer. LDA passes the common parameters to the optimizer and 73 | * the internal structure can be initialized properly. 
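 *
 * A minimal wiring sketch (illustrative only: the `setK`, `setOptimizer` and `run` calls on
 * `LDA` are assumed from the MLlib-style interface mentioned in the README and are not
 * defined in this file):
 * {{{
 * // docs: RDD[(Long, Vector)] of per-document term counts, with document ids >= 0
 * val optimizer = new GibbsLDAOptimizer()
 *   .setSampler("alias")           // one of "alias", "sparse", "light", "fast"
 *   .setCheckpointInterval(10)
 * val model = new LDA().setK(20).setOptimizer(optimizer).run(docs)
 * }}}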
74 | */ 75 | override def initialize(docs: RDD[(Long, SV)], lda: LDA): LDAOptimizer = { 76 | alpha = lda.getAlpha.toFloat 77 | beta = lda.getBeta.toFloat 78 | numTopics = lda.getK 79 | numTerms = docs.first()._2.size 80 | seed = lda.getSeed.toInt 81 | setCheckpointInterval(lda.getCheckpointInterval) 82 | corpus = initializeCorpus(docs, lda.getK, storageLevel, edgePartitioner) 83 | numTokens = corpus.edges.map(e => e.attr.size.toDouble).sum().toLong 84 | totalTopicCounter = collectTotalTopicCounter(corpus, numTopics, numTokens) 85 | this 86 | } 87 | 88 | private var lastSampledCorpus:Option[Graph[VD, ED]] = None 89 | 90 | private def termVertices = corpus.vertices.filter(t => t._1>=0) 91 | 92 | private def docVertices = corpus.vertices.filter(t => t._1<0) 93 | 94 | /** 95 | * run an iteration 96 | * @return 97 | */ 98 | def next(): LDAOptimizer = { 99 | 100 | gibbsSampling() 101 | 102 | if (printPerplexity) { 103 | println(s"Perplexity of $innerIter-th is ${perplexity}") 104 | } 105 | this 106 | } 107 | 108 | private def gibbsSampling(): Unit = { 109 | val sampledCorpus = sampler.sampleTokens(corpus, totalTopicCounter, innerIter + seed, 110 | numTokens, numTopics, numTerms, alpha, alphaAS, beta) 111 | sampledCorpus.edges.persist(storageLevel) 112 | sampledCorpus.edges.count() 113 | 114 | val counterCorpus = updateCounter(sampledCorpus, numTopics) 115 | checkpoint(counterCorpus, innerIter, checkpointInterval) 116 | counterCorpus.vertices.persist(storageLevel) 117 | counterCorpus.vertices.count() 118 | 119 | totalTopicCounter = collectTotalTopicCounter(counterCorpus, numTopics, numTokens) 120 | 121 | corpus.edges.unpersist(false) 122 | corpus.vertices.unpersist(false) 123 | lastSampledCorpus.map(_.edges.unpersist(false)) 124 | lastSampledCorpus = Some(sampledCorpus) 125 | corpus = counterCorpus 126 | 127 | innerIter += 1 128 | } 129 | 130 | /** 131 | * get model of current iteration 132 | * @param iterationTimes 133 | * @return 134 | */ 135 | override def getLDAModel(iterationTimes: Array[Double]): LDAModel = { 136 | val ldaModel:GibbsLDAModel = saveModel(1) 137 | ldaModel 138 | } 139 | 140 | def setCheckpointInterval(cpInterval: Int): this.type = { 141 | this.checkpointInterval = cpInterval 142 | this 143 | } 144 | 145 | def getCheckpointInterval(): Int = this.checkpointInterval 146 | 147 | def setAlphaAS(alphaAS: Float): this.type = { 148 | this.alphaAS = alphaAS 149 | this 150 | } 151 | 152 | def getAlphaAS():Float = this.alphaAS 153 | 154 | def setStorageLevel(newStorageLevel: StorageLevel): this.type = { 155 | this.storageLevel = newStorageLevel 156 | this 157 | } 158 | 159 | def getStorageLevel(): StorageLevel = this.storageLevel 160 | 161 | def setSeed(newSeed: Int): this.type = { 162 | this.seed = newSeed 163 | this 164 | } 165 | 166 | def getSeed():Int = this.seed 167 | 168 | def setSampler(sampler:GibbsLDASampler): this.type = { 169 | this.sampler = sampler 170 | this 171 | } 172 | 173 | def getSampler(): GibbsLDASampler = { 174 | this.sampler 175 | } 176 | 177 | def setSampler(sampler:String): this.type = { 178 | this.sampler = 179 | sampler.toLowerCase match { 180 | case "alias" => new GibbsLDAAliasSampler 181 | case "sparse" => new GibbsLDASparseSampler 182 | case "light" => new GibbsLDALightSampler 183 | case "fast" => new GibbsLDAFastSampler 184 | case _ => 185 | throw new IllegalArgumentException(s"Only alias, sparse, light are supported but got $sampler.") 186 | } 187 | this 188 | } 189 | 190 | def getCorpus = corpus 191 | 192 | @transient private var seed = new 
Random().nextInt() 193 | @transient private var innerIter = 1 194 | @transient private var totalTopicCounter: BDV[Count] = null 195 | 196 | // scalastyle:off 197 | /** 198 | * p(w)=\sum_{k}{p(k|d)*p(w|k)}= 199 | * \sum_{k}{\frac{{n}_{kw}+{\beta }_{w}} {{n}_{k}+\bar{\beta }} \frac{{n}_{kd}+{\alpha }_{k}} {\sum{{n}_{k}}+\bar{\alpha }}}= 200 | * \sum_{k} \frac{{\alpha }_{k}{\beta }_{w} + {n}_{kw}{\alpha }_{k} + {n}_{kd}{\beta }_{w} + {n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta }} \frac{1}{\sum{{n}_{k}}+\bar{\alpha }}} 201 | * \exp^{-(\sum{\log(p(w))})/N} 202 | * N is total token number within the corpus 203 | */ 204 | // scalastyle:on 205 | def perplexity(): Double = { 206 | val totalTopicCounter = this.totalTopicCounter 207 | val numTopics = this.numTopics 208 | val numTerms = this.numTerms 209 | val alpha = this.alpha 210 | val beta = this.beta 211 | 212 | val totalSize = brzSum(totalTopicCounter) & 0x00000000FFFFFFFFL 213 | var totalProb = 0D 214 | // \frac{{\alpha }_{k}{\beta }_{w}}{{n}_{k}+\bar{\beta }} 215 | totalTopicCounter.activeIterator.foreach { case (topic, cn) => 216 | totalProb += alpha * beta / (cn + numTerms * beta) 217 | } 218 | val termProb = corpus.mapVertices { (vid, counter) => 219 | var probDist = 0D 220 | if (vid >= 0) { 221 | val termTopicCounter = counter 222 | // \frac{{n}_{kw}{\alpha }_{k}}{{n}_{k}+\bar{\beta }} 223 | termTopicCounter.activeIterator.filter(_._2 > 0).foreach { case (topic, cn) => 224 | probDist += cn * alpha / 225 | (totalTopicCounter(topic) + numTerms * beta) 226 | } 227 | } else { 228 | val docTopicCounter = counter 229 | // \frac{{n}_{kd}{\beta }_{w}}{{n}_{k}+\bar{\beta }} 230 | docTopicCounter.activeIterator.filter(_._2 > 0).foreach { case (topic, cn) => 231 | probDist += cn * beta / 232 | (totalTopicCounter(topic) + numTerms * beta) 233 | } 234 | } 235 | (counter, probDist) 236 | }.mapTriplets { triplet => 237 | val (termTopicCounter, termProb) = triplet.srcAttr 238 | val (docTopicCounter, docProb) = triplet.dstAttr 239 | val docSize = docTopicCounter.sum 240 | val docTermSize = triplet.attr.length 241 | var prob = 0D 242 | // \frac{{n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta}} 243 | docTopicCounter.activeIterator.filter(_._2 > 0).foreach { case (topic, cn) => 244 | prob += ((termTopicCounter(topic) / (totalTopicCounter(topic) + numTerms * beta))*cn) 245 | } 246 | prob += docProb + termProb + totalProb 247 | prob += prob / (docSize + numTopics * alpha) 248 | docTermSize * Math.log(prob) 249 | }.edges.map(t => t.attr).sum() 250 | math.exp(-1 * termProb / totalSize) 251 | } 252 | 253 | /** 254 | * Save the term-topic related model 255 | * @param totalIter 256 | */ 257 | def saveModel(totalIter: Int = 1): GibbsLDAModel = { 258 | var termTopicCounter: RDD[(VertexId, VD)] = null 259 | for (iter <- 1 to totalIter) { 260 | logInfo(s"Save TopicModel (Iteration $iter/$totalIter)") 261 | var previousTermTopicCounter = termTopicCounter 262 | gibbsSampling() 263 | val newTermTopicCounter = termVertices 264 | termTopicCounter = Option(termTopicCounter).map(_.join(newTermTopicCounter).map { 265 | case (term, (a, b)) => 266 | var c: VD = null 267 | if(a.isInstanceOf[BSV[Count]] && b.isInstanceOf[BSV[Count]]) { 268 | c = a.asInstanceOf[BSV[Count]] :+ b.asInstanceOf[BSV[Count]] 269 | } else if(a.isInstanceOf[BDV[Count]] && b.isInstanceOf[BDV[Count]]){ 270 | c = a.asInstanceOf[BDV[Count]] :+ b.asInstanceOf[BDV[Count]] 271 | } else if(a.isInstanceOf[BDV[Count]]) { 272 | c = a.asInstanceOf[BDV[Count]] :+ b.asInstanceOf[BSV[Count]].toDenseVector 273 | } else { 274 | c = 
a.asInstanceOf[BSV[Count]].toDenseVector :+ b.asInstanceOf[BDV[Count]] 275 | } 276 | (term, c) 277 | }).getOrElse(newTermTopicCounter) 278 | 279 | termTopicCounter.persist(storageLevel).count() 280 | Option(previousTermTopicCounter).foreach(_.unpersist(blocking = false)) 281 | previousTermTopicCounter = termTopicCounter 282 | } 283 | val rand = new Random() 284 | val ttc = termTopicCounter.mapValues(c => { 285 | if (c.isInstanceOf[BDV[Count]]) { 286 | val dv = c.asInstanceOf[BDV[Count]] 287 | val nc = new BDV[Count](dv.data.map (v => { 288 | val mid = v.toDouble / totalIter 289 | val l = math.floor(mid) 290 | if (rand.nextDouble() > mid - l) { 291 | l 292 | } else { 293 | l + 1 294 | } 295 | }.asInstanceOf[Count])) 296 | nc.asInstanceOf[StorageVector[Count]] 297 | } else { 298 | val sv = c.asInstanceOf[BSV[Count]] 299 | val nc = new BSV[Count](sv.index.slice(0, sv.used), sv.data.slice(0, sv.used).map(v => { 300 | val mid = v.toDouble / totalIter 301 | val l = math.floor(mid) 302 | if (rand.nextDouble() > mid - l) { 303 | l 304 | } else { 305 | l + 1 306 | } 307 | }.asInstanceOf[Count]), c.length) 308 | nc 309 | } 310 | }) 311 | 312 | ttc.persist(storageLevel) 313 | val gtc = ttc.map(_._2).aggregate(BDV.zeros[Count](numTopics))(_ :+= _,_ :+= _) 314 | new GibbsLDAModel(gtc, ttc, corpus.vertices, numTerms, alpha, beta, alphaAS) 315 | } 316 | } 317 | 318 | object GibbsLDAOptimizer { 319 | 320 | private[topicModeling] type DocId = VertexId 321 | private[topicModeling] type WordId = VertexId 322 | private[topicModeling] type Count = Int 323 | private[topicModeling] type ED = Array[Count] 324 | private[topicModeling] type VD = StorageVector[Count] 325 | 326 | def checkpoint(corpus: Graph[VD, ED], innerIter: Int, checkpointInterval: Int): Unit = { 327 | if (innerIter % checkpointInterval == 0 && corpus.edges.sparkContext.getCheckpointDir.isDefined) { 328 | corpus.checkpoint() 329 | } 330 | } 331 | 332 | def collectTotalTopicCounter(graph: Graph[VD, ED], numTopics: Int, numTokens: Long): BDV[Count] = { 333 | val globalTopicCounter = collectGlobalCounter(graph, numTopics) 334 | val totalSize = brzSum(globalTopicCounter) & 0x00000000FFFFFFFFL 335 | println("numTokens:"+numTokens+" brzSum:"+totalSize) 336 | assert(totalSize == numTokens) 337 | globalTopicCounter 338 | } 339 | 340 | def updateCounter(graph: Graph[VD, ED], numTopics: Int): Graph[VD, ED] = { 341 | val newCounter = graph.aggregateMessages[BSV[Count]](ctx => { 342 | val topics = ctx.attr 343 | val vector = BSV.zeros[Count](numTopics) 344 | for (topic <- topics) { 345 | vector(topic) += 1 346 | } 347 | ctx.sendToDst(vector) 348 | ctx.sendToSrc(vector) 349 | }, _ + _, TripletFields.EdgeOnly) 350 | .mapValues(sparseVector => { 351 | val storageVector:VD = 352 | if (sparseVector.activeSize > sparseVector.length / 2) { 353 | sparseVector.toDenseVector 354 | } else { 355 | sparseVector 356 | } 357 | storageVector 358 | }) 359 | // GraphImpl.fromExistingRDDs(newCounter, graph.edges) 360 | GraphImpl(newCounter, graph.edges) 361 | } 362 | 363 | def collectGlobalCounter(graph: Graph[VD, ED], numTopics: Int): BDV[Count] = { 364 | graph.vertices.filter(t => t._1 >= 0).map(_._2).aggregate(BDV.zeros[Count](numTopics))((a, b) => { 365 | a :+= b 366 | }, _ :+= _) 367 | } 368 | 369 | def initializeCorpus( 370 | docs: RDD[(GibbsLDAOptimizer.DocId, SV)], 371 | numTopics: Int, 372 | storageLevel: StorageLevel, edgePartitioner:String): Graph[VD, ED] = { 373 | val edges = docs.mapPartitionsWithIndex((pid, iter) => { 374 | val gen = new Random(pid) 375 | 
iter.flatMap { 376 | case (docId, doc) => 377 | initializeEdges(gen, doc, docId, numTopics) 378 | } 379 | }) 380 | edges.persist(storageLevel) 381 | val corpus: Graph[VD, ED] = edgePartitioner match { 382 | case "none" => 383 | Graph.fromEdges(edges, null, storageLevel, storageLevel) 384 | case "degree" => 385 | val degreeCorpus = Graph.fromEdges(edges, null, storageLevel, storageLevel) 386 | val degrees = degreeCorpus.outerJoinVertices(degreeCorpus.degrees) { (vid, data, deg) => deg.getOrElse(0) } 387 | val numPartitions = edges.partitions.size 388 | val partitionStrategy = new DBHPartitioner(numPartitions) 389 | val newEdges = degrees.triplets.map { e => 390 | (partitionStrategy.getPartition(e), Edge(e.srcId, e.dstId, e.attr)) 391 | }.partitionBy(new HashPartitioner(numPartitions)).map(_._2) 392 | Graph.fromEdges(newEdges, null, storageLevel, storageLevel) 393 | case "docIdCluster" => 394 | val newEdges = edges.map { e => 395 | (e.dstId, Edge(e.srcId, e.dstId, e.attr)) 396 | }.partitionBy(new HashPartitioner(docs.partitions.size)).map(_._2) 397 | Graph.fromEdges(newEdges, null, storageLevel, storageLevel) 398 | 399 | case _ => 400 | throw new IllegalArgumentException(s"invalid values of edgePartitioner are none, degree and docIdCluster, but got $edgePartitioner") 401 | } 402 | 403 | val resultCorpus = updateCounter(corpus, numTopics).cache() 404 | resultCorpus.vertices.count() 405 | resultCorpus.edges.count() 406 | corpus.unpersist(false) 407 | edges.unpersist(false) 408 | docs.unpersist(false) 409 | resultCorpus 410 | } 411 | 412 | private def initializeEdges( 413 | gen: Random, 414 | doc: SV, 415 | docId: DocId, 416 | numTopics: Int): Iterator[Edge[ED]] = { 417 | assert(docId >= 0) 418 | val newDocId: DocId = -(docId + 1L) 419 | 420 | doc.toBreeze.activeIterator.filter(_._2 > 0).map {case (termId, counter) => 421 | val topics = new Array[Int](counter.toInt) 422 | for (i <- 0 until counter.toInt) { 423 | topics(i) = gen.nextInt(numTopics) 424 | } 425 | Edge(termId, newDocId, topics) 426 | } 427 | } 428 | } 429 | 430 | class GibbsLDAModel ( 431 | private[topicModeling] val gtc: BDV[Count], 432 | @transient private[topicModeling] val ttc: RDD[(VertexId, VD)], 433 | @transient private[topicModeling] val vertices: RDD[(VertexId, VD)], 434 | val numTerms: Int, 435 | val alpha: Float, 436 | val beta: Float, 437 | val alphaAS: Float) extends org.apache.spark.mllib.topicModeling.LDAModel with Serializable { 438 | 439 | private val numTopics = gtc.size 440 | 441 | private lazy val numTokens = brzSum(gtc) & 0x00000000FFFFFFFFL 442 | 443 | /** Number of topics */ 444 | def k: Int = numTopics 445 | 446 | /** Vocabulary size (number of terms or terms in the vocabulary) */ 447 | def vocabSize: Int = numTerms 448 | 449 | /** 450 | * Inferred topics, where each topic is represented by a distribution over terms. 451 | * This is a matrix of size vocabSize x k, where each column is a topic. 452 | * No guarantees are given about the ordering of the topics. 
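 *
 * Reading a single entry (a sketch; `model`, `termId` and `topicId` are placeholder names,
 * and the entries are raw topic-assignment counts rather than normalized probabilities):
 * {{{
 * val countOfTermInTopic = model.topicsMatrix(termId, topicId)
 * }}}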
453 | */ 454 | def topicsMatrix: Matrix = { 455 | val matrix = BDM.zeros[Double](numTerms, numTopics) 456 | 457 | val ttcArray = Array.fill(numTerms.toInt) { 458 | BSV.zeros[Count](numTopics) 459 | } 460 | ttc.collect().foreach { case (termId, vector) => 461 | ttcArray(termId.toInt) :+= vector 462 | } 463 | 464 | for (termId <- 0 until numTerms) { 465 | val sv = ttcArray(termId) 466 | for (topicId <- 0 until numTopics) { 467 | matrix(termId, topicId) = sv(topicId) 468 | } 469 | } 470 | GibbsLDAUtils.fromBreezeMatrix(matrix) 471 | } 472 | 473 | /** 474 | * Return the topics described by weighted terms. 475 | * 476 | * This limits the number of terms per topic. 477 | * This is approximate; it may not return exactly the top-weighted terms for each topic. 478 | * To get a more precise set of top terms, increase maxTermsPerTopic. 479 | * 480 | * @param maxTermsPerTopic Maximum number of terms to collect for each topic. 481 | * @return Array over topics. Each topic is represented as a pair of matching arrays: 482 | * (term indices, term weights in topic). 483 | * Each topic's terms are sorted in order of decreasing weight. 484 | */ 485 | def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { 486 | val ttcArray = Array.fill(numTerms.toInt) { 487 | BSV.zeros[Count](numTopics) 488 | } 489 | ttc.collect().foreach { case (termId, vector) => 490 | ttcArray(termId.toInt) :+= vector 491 | } 492 | 493 | (0 until numTopics).map(topicId => { 494 | val terms = (0 until numTerms).map(termId => (termId, ttcArray(termId)(topicId).toDouble)) 495 | .sortBy(_._2)(Ordering[Double].reverse) 496 | .take(maxTermsPerTopic) 497 | .map(_._1).toArray 498 | val weights = terms.map(ttcArray(_)(topicId).toDouble) 499 | val wsum = weights.sum 500 | if (wsum > 1E-5) { 501 | (0 until weights.length).foreach(weights(_) /= wsum) 502 | } 503 | (terms, weights) 504 | }).toArray 505 | } 506 | 507 | /** 508 | * For each document in the training set, return the distribution over topics for that document 509 | * ("theta_doc"). 
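 *
 * Example call (a sketch; `newDocs` and `optimizer` are placeholder names). Note that
 * `burnIn` must be smaller than `totalIter`, since document-topic counts are only
 * accumulated after the burn-in iterations:
 * {{{
 * val topicDistributions = model.predict(newDocs, optimizer, totalIter = 25, burnIn = 22)
 * }}}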
510 | * 511 | * @return RDD of (document ID, topic distribution) pairs 512 | */ 513 | def predict(documents: RDD[(Long, SV)], 514 | optimizer: GibbsLDAOptimizer, 515 | totalIter: Int = 25, 516 | burnIn: Int = 22): RDD[(Long, SV)] = { 517 | 518 | val previousCorpus: Graph[VD, ED] = initializeCorpus(documents, numTopics, 519 | optimizer.getStorageLevel(), optimizer.edgePartitioner) 520 | 521 | var corpus = previousCorpus.outerJoinVertices(ttc) { (vid, c, v) => 522 | if (vid >= 0) { 523 | assert(v.isDefined) 524 | } 525 | v.getOrElse(c) 526 | } 527 | corpus.persist(optimizer.getStorageLevel) 528 | corpus.vertices.count() 529 | corpus.edges.count() 530 | previousCorpus.edges.unpersist(blocking = false) 531 | previousCorpus.vertices.unpersist(blocking = false) 532 | 533 | var docTopicCounter: RDD[(VertexId, VD)] = null 534 | 535 | for(i <- 1 to totalIter) { 536 | val previousCorpus = corpus 537 | 538 | val sampledCorpus = optimizer.getSampler().sampleTokens(corpus, gtc, i + optimizer.getSeed(), 539 | numTokens, numTopics, numTerms, alpha, alphaAS, beta) 540 | sampledCorpus.persist(optimizer.getStorageLevel) 541 | 542 | corpus = updateCounter(sampledCorpus, numTopics) 543 | checkpoint(corpus, i, optimizer.getCheckpointInterval) 544 | corpus.persist(optimizer.getStorageLevel) 545 | 546 | previousCorpus.edges.unpersist(false) 547 | previousCorpus.vertices.unpersist(false) 548 | 549 | sampledCorpus.edges.unpersist(false) 550 | sampledCorpus.vertices.unpersist(false) 551 | 552 | if (i > burnIn) { 553 | var previousDocTopicCounter = docTopicCounter 554 | val newTermTopicCounter = corpus.vertices.filter(t => t._1 < 0) 555 | docTopicCounter = Option(docTopicCounter).map(_.join(newTermTopicCounter).map { 556 | case (docId, (a, b)) => { 557 | var c: VD = null 558 | if(a.isInstanceOf[BSV[Count]] && b.isInstanceOf[BSV[Count]]) { 559 | c = a.asInstanceOf[BSV[Count]] :+ b.asInstanceOf[BSV[Count]] 560 | } else if(a.isInstanceOf[BDV[Count]] && b.isInstanceOf[BDV[Count]]){ 561 | c = a.asInstanceOf[BDV[Count]] :+ b.asInstanceOf[BDV[Count]] 562 | } else if(a.isInstanceOf[BDV[Count]]) { 563 | c = a.asInstanceOf[BDV[Count]] :+ b.asInstanceOf[BSV[Count]].toDenseVector 564 | } else { 565 | c = a.asInstanceOf[BSV[Count]].toDenseVector :+ b.asInstanceOf[BDV[Count]] 566 | } 567 | (docId, c) 568 | } 569 | }).getOrElse(newTermTopicCounter) 570 | 571 | docTopicCounter.persist(optimizer.getStorageLevel).count() 572 | Option(previousDocTopicCounter).foreach(_.unpersist(blocking = false)) 573 | previousDocTopicCounter = docTopicCounter 574 | } 575 | } 576 | docTopicCounter.map { case (docId, v) => 577 | if(v.isInstanceOf[BDV[Count]]) { 578 | val dv = v.asInstanceOf[BDV[Count]] 579 | val norm = brzNorm(dv, 1) 580 | (docId, GibbsLDAUtils.fromBreezeConv[Double](dv.map(_.toDouble) / norm )) 581 | } else { 582 | val sv = v.asInstanceOf[BSV[Count]] 583 | val norm = brzNorm(sv, 1) 584 | (docId, GibbsLDAUtils.fromBreezeConv[Double](sv.map(_.toDouble) / norm)) 585 | } 586 | } 587 | } 588 | 589 | /** 590 | * For each document in the training set, return the distribution over topics for that document 591 | * 592 | * @return RDD of (document ID, topic distribution) pairs 593 | */ 594 | def docTopicDistributions: RDD[(VertexId, VD)] = { 595 | vertices.filter(t => t._1<0) 596 | } 597 | 598 | /** 599 | * For each term in the training set, return the distribution over topics for that term 600 | * 601 | * @return RDD of (term ID, topic distribution) pairs 602 | */ 603 | def termTopicDistributions: RDD[(VertexId, VD)] = { 604 | 
vertices.filter(t => t._1>=0) 605 | } 606 | } 607 | 608 | private[topicModeling] object GibbsLDAUtils { 609 | 610 | private def _conv[T1: ClassTag, T2: ClassTag](data: Array[T1]): Array[T2] = { 611 | data.map(_.asInstanceOf[T2]).array 612 | } 613 | 614 | def fromBreezeConv[T: ClassTag](breezeVector: BV[T]): SV = { 615 | implicit val conv: Array[T] => Array[Double] = _conv[T, Double] 616 | 617 | breezeVector match { 618 | case v: BDV[T] => 619 | if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { 620 | new SDV(v.data) 621 | } else { 622 | new SDV(v.toArray) // Can't use underlying array directly, so make a new one 623 | } 624 | case v: BSV[T] => 625 | if (v.index.length == v.used) { 626 | new SSV(v.length, v.index, _conv[T, Double](v.data)) 627 | } else { 628 | new SSV(v.length, v.index.slice(0, v.used), v.data.slice(0, v.used)) 629 | } 630 | case v: BV[T] => 631 | sys.error("Unsupported Breeze vector type: " + v.getClass.getName) 632 | } 633 | } 634 | 635 | def binarySearchInterval( 636 | index: Array[Float], 637 | key: Float, 638 | begin: Int, 639 | end: Int, 640 | greater: Boolean): Int = { 641 | if (begin == end) { 642 | return if (greater) end else begin - 1 643 | } 644 | var b = begin 645 | var e = end - 1 646 | 647 | var mid: Int = (e + b) >> 1 648 | while (b <= e) { 649 | mid = (e + b) >> 1 650 | val v = index(mid) 651 | if (scala.math.abs(key-v)<=1e-6) { 652 | return mid 653 | } 654 | else if (v > key) { 655 | e = mid - 1 656 | } 657 | else { 658 | b = mid + 1 659 | } 660 | } 661 | val v = index(mid) 662 | mid = if ((greater && v >= key) || (!greater && v <= key)) { 663 | mid 664 | } 665 | else if (greater) { 666 | mid + 1 667 | } 668 | else { 669 | mid - 1 670 | } 671 | 672 | if (greater) { 673 | if (mid < end) assert(index(mid) >= key) 674 | if (mid > 0) assert(index(mid - 1) <= key) 675 | } else { 676 | if (mid > 0) assert(index(mid) <= key) 677 | if (mid < end - 1) assert(index(mid + 1) >= key) 678 | } 679 | mid 680 | } 681 | 682 | /** 683 | * Creates a Matrix instance from a breeze matrix. 684 | * @param breeze a breeze matrix 685 | * @return a Matrix instance 686 | */ 687 | def fromBreezeMatrix(breeze: BM[Double]): Matrix = { 688 | breeze match { 689 | case dm: BDM[Double] => 690 | require(dm.majorStride == dm.rows, 691 | "Do not support stride size different from the number of rows.") 692 | new SDM(dm.rows, dm.cols, dm.data) 693 | case sm: BSM[Double] => 694 | new SSM(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) 695 | case _ => 696 | throw new UnsupportedOperationException( 697 | s"Do not support conversion from type ${breeze.getClass.getName}.") 698 | } 699 | } 700 | 701 | /** 702 | * Creates a vector instance from a breeze vector. 
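 *
 * For example (the literal values are illustrative only):
 * {{{
 * val v = GibbsLDAUtils.fromBreezeVector(breeze.linalg.DenseVector(0.1, 0.2, 0.7))
 * }}}
 * @param breezeVector a breeze dense or sparse vector of Double
 * @return the equivalent MLlib vector (a DenseVector or SparseVector)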
703 | */ 704 | def fromBreezeVector(breezeVector: BV[Double]): SV = { 705 | breezeVector match { 706 | case v: BDV[Double] => 707 | if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { 708 | new SDV(v.data) 709 | } else { 710 | new SDV(v.toArray) // Can't use underlying array directly, so make a new one 711 | } 712 | case v: BSV[Double] => 713 | if (v.index.length == v.used) { 714 | new SSV(v.length, v.index, v.data) 715 | } else { 716 | new SSV(v.length, v.index.slice(0, v.used), v.data.slice(0, v.used)) 717 | } 718 | case v: BV[_] => 719 | sys.error("Unsupported Breeze vector type: " + v.getClass.getName) 720 | } 721 | } 722 | 723 | def saveLDAModel(sc: SparkContext, ldamodel: GibbsLDAModel, path: String) = { 724 | val metadataPath = new Path(path, "metadata").toUri.toString 725 | val ttcPath = new Path(path, "ttc").toUri.toString 726 | val verticesPath = new Path(path, "vertices").toUri.toString 727 | 728 | val metaData = compact(render 729 | (("alpha" -> ldamodel.alpha) ~ ("beta" -> ldamodel.beta) ~ 730 | ("alphaAS" -> ldamodel.alphaAS) ~ ("numTerms" -> ldamodel.numTerms) ~ 731 | ("numTopics" -> ldamodel.k))) 732 | 733 | sc.parallelize(Seq(metaData)).saveAsTextFile(metadataPath) 734 | ldamodel.ttc.map { case (id, vector) => 735 | val list = vector.activeIterator.toList.sortWith((a, b) => a._2>b._2) 736 | (NullWritable.get(), new Text(id + "\t" + list.map(item => item._1 + ":" + item._2).mkString("\t"))) 737 | }.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](ttcPath) 738 | 739 | ldamodel.vertices.map { case (id, vector) => 740 | val list = vector.activeIterator.toList.sortWith((a, b) => a._2>b._2) 741 | (NullWritable.get(), new Text(id + "\t" + list.map(item => item._1 + ":" + item._2).mkString("\t"))) 742 | }.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](verticesPath) 743 | } 744 | 745 | def loadLDAModel(sc: SparkContext, path: String): GibbsLDAModel = { 746 | val metadataPath = new Path(path, "metadata").toUri.toString 747 | val ttcPath = new Path(path, "ttc").toUri.toString 748 | val verticesPath = new Path(path, "vertices").toUri.toString 749 | 750 | implicit val formats = DefaultFormats 751 | val metaData = parse(sc.textFile(metadataPath).first()) 752 | val alpha = (metaData \ "alpha").extract[Float] 753 | val beta = (metaData \ "beta").extract[Float] 754 | val alphaAS = (metaData \ "alphaAS").extract[Float] 755 | val numTerms = (metaData \ "numTerms").extract[Int] 756 | val numTopics = (metaData \ "numTopics").extract[Int] 757 | 758 | val ttc = sc.textFile(ttcPath).map { line => 759 | val sv = BSV.zeros[Count](numTopics) 760 | val arr = line.split("\t") 761 | arr.tail.foreach { sub => 762 | val Array(index, value) = sub.split(":") 763 | sv(index.toInt) = value.toInt 764 | } 765 | sv.compact() 766 | (arr.head.toLong, sv.asInstanceOf[StorageVector[Count]]) 767 | }.cache() 768 | 769 | val vertices = sc.textFile(verticesPath).map { line => 770 | val sv = BSV.zeros[Count](numTopics) 771 | val arr = line.split("\t") 772 | arr.tail.foreach { sub => 773 | val Array(index, value) = sub.split(":") 774 | sv(index.toInt) = value.toInt 775 | } 776 | sv.compact() 777 | (arr.head.toLong, sv.asInstanceOf[StorageVector[Count]]) 778 | } 779 | 780 | val gtc = ttc.map(_._2).aggregate(BDV.zeros[Count](numTopics))(_ :+= _,_ :+= _) 781 | new GibbsLDAModel(gtc, ttc, vertices, numTerms, alpha, beta, alphaAS) 782 | } 783 | } 784 | 785 | /** 786 | * Degree-Based Hashing, the paper: 787 | * http://nips.cc/Conferences/2014/Program/event.php?ID=4569 788 | * @param partitions 789 
| */ 790 | private class DBHPartitioner(partitions: Int) extends Partitioner { 791 | val mixingPrime: Long = 1125899906842597L 792 | 793 | def numPartitions = partitions 794 | 795 | def getPartition(key: Any): Int = { 796 | val edge = key.asInstanceOf[EdgeTriplet[Int, ED]] 797 | val srcDeg = edge.srcAttr 798 | val dstDeg = edge.dstAttr 799 | val srcId = edge.srcId 800 | val dstId = edge.dstId 801 | val minId = if (srcDeg < dstDeg) srcId else dstId 802 | getPartition(minId) 803 | } 804 | 805 | def getPartition(idx: Int): PartitionID = { 806 | (math.abs(idx * mixingPrime) % partitions).toInt 807 | } 808 | 809 | def getPartition(idx: Long): PartitionID = { 810 | (math.abs(idx * mixingPrime) % partitions).toInt 811 | } 812 | 813 | override def equals(other: Any): Boolean = other match { 814 | case h: DBHPartitioner => 815 | h.numPartitions == numPartitions 816 | case _ => 817 | false 818 | } 819 | 820 | override def hashCode: Int = numPartitions 821 | } 822 | 823 | private[topicModeling] class LDAKryoRegistrator extends KryoRegistrator { 824 | def registerClasses(kryo: com.esotericsoftware.kryo.Kryo) { 825 | val gkr = new GraphKryoRegistrator 826 | gkr.registerClasses(kryo) 827 | 828 | kryo.register(classOf[BSV[GibbsLDAOptimizer.Count]]) 829 | kryo.register(classOf[BSV[Double]]) 830 | 831 | kryo.register(classOf[BDV[GibbsLDAOptimizer.Count]]) 832 | kryo.register(classOf[BDV[Double]]) 833 | 834 | kryo.register(classOf[SV]) 835 | kryo.register(classOf[SSV]) 836 | kryo.register(classOf[SDV]) 837 | 838 | kryo.register(classOf[GibbsLDAOptimizer.ED]) 839 | kryo.register(classOf[GibbsLDAOptimizer.VD]) 840 | 841 | kryo.register(classOf[Random]) 842 | kryo.register(classOf[GibbsLDAOptimizer]) 843 | kryo.register(classOf[GibbsLDAModel]) 844 | } 845 | } 846 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/topicModeling/GibbsLDASampler.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package org.apache.spark.mllib.topicModeling 18 | 19 | import java.lang.ref.SoftReference 20 | import java.util 21 | import org.apache.spark.Logging 22 | import org.apache.spark.graphx.impl.GraphImpl 23 | import org.apache.spark.graphx.{TripletFields, VertexId, Graph} 24 | import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV, sum => brzSum, StorageVector} 25 | 26 | import java.util.Random 27 | import org.apache.spark.util.random.XORShiftRandom 28 | import org.apache.spark.util.collection.AppendOnlyMap 29 | import scala.reflect.ClassTag 30 | 31 | trait GibbsLDASampler { 32 | type VD = GibbsLDAOptimizer.VD 33 | type ED = GibbsLDAOptimizer.ED 34 | type Count = GibbsLDAOptimizer.Count 35 | 36 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED], 37 | totalTopicCounter: BDV[Count], 38 | innerIter: Long, 39 | numTokens: Long, 40 | numTopics: Int, 41 | numTerms: Int, 42 | alpha: Float, 43 | alphaAS: Float, 44 | beta: Float): Graph[VD, ED] 45 | } 46 | 47 | //private[topicModeling] object TermAliasTableCache { 48 | // 49 | // type VD = GibbsLDAOptimizer.VD 50 | // type ED = GibbsLDAOptimizer.ED 51 | // type Count = GibbsLDAOptimizer.Count 52 | // 53 | // private var content = new AppendOnlyMap[Int, SoftReference[(GibbsAliasTable, Float)]]() 54 | // 55 | // private def wSparse( 56 | // totalTopicCounter: BDV[Count], 57 | // termTopicCounter: VD, 58 | // numTokens: Long, 59 | // numTerms: Int, 60 | // alpha: Float, 61 | // alphaAS: Float, 62 | // beta: Float): (Float, BSV[Float]) = { 63 | // val numTopics = totalTopicCounter.length 64 | // val alphaSum = alpha * numTopics 65 | // val termSum = numTokens - 1f + alphaAS * numTopics 66 | // val betaSum = numTerms * beta 67 | // val w = BSV.zeros[Float](numTopics) 68 | // var sum = 0.0f 69 | // termTopicCounter.activeIterator.filter(_._2 > 0).foreach { t => 70 | // val topic = t._1 71 | // val count = t._2 72 | // val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) / 73 | // ((totalTopicCounter(topic) + betaSum) * termSum) 74 | // w(topic) = last 75 | // sum += last 76 | // } 77 | // (sum, w) 78 | // } 79 | // 80 | // def get(totalTopicCounter: BDV[Count], 81 | // termTopicCounter: VD, 82 | // termId: VertexId, 83 | // numTokens: Long, 84 | // numTerms: Int, 85 | // alpha: Float, 86 | // alphaAS: Float, 87 | // beta: Float): (GibbsAliasTable, Float) = { 88 | // if(content(termId.toInt) == null) { 89 | // synchronized { 90 | // if(content(termId.toInt) == null) { 91 | // val sv = wSparse(totalTopicCounter, termTopicCounter, 92 | // numTokens, numTerms, alpha, alphaAS, beta) 93 | // 94 | // val table = new GibbsAliasTable(sv._2.activeSize) 95 | // GibbsAliasTable.generateAlias(sv._2, sv._1, table) 96 | // 97 | // content(termId.toInt) = new SoftReference((table, sv._1)) 98 | // } 99 | // } 100 | // } 101 | // content(termId.toInt).get 102 | // } 103 | // 104 | // def clear() = { 105 | // content = new AppendOnlyMap[Int, SoftReference[(GibbsAliasTable, Float)]]() 106 | // } 107 | //} 108 | 109 | private[topicModeling] class GibbsAliasTable(initUsed: Int) extends Serializable { 110 | 111 | private var _l: Array[Int] = new Array[Int](initUsed) 112 | private var _h: Array[Int] = new Array[Int](initUsed) 113 | private var _p: Array[Float] = new Array[Float](initUsed) 114 | private var _used = initUsed 115 | 116 | def l: Array[Int] = _l 117 | 118 | def h: Array[Int] = _h 119 | 120 | def p: Array[Float] = _p 121 | 122 | def used: Int = _used 123 | 124 | def length: Int = size 125 | 126 | def size: Int = 
l.length 127 | 128 | def sampleAlias(gen: Random): Int = { 129 | val bin = gen.nextInt(_used) 130 | val prob = _p(bin) 131 | if (_used * prob > gen.nextFloat()) { 132 | _l(bin) 133 | } else { 134 | _h(bin) 135 | } 136 | } 137 | 138 | private[GibbsAliasTable] def reset(newSize: Int): this.type = { 139 | if (_l.length < newSize) { 140 | _l = new Array[Int](newSize) 141 | _h = new Array[Int](newSize) 142 | _p = new Array[Float](newSize) 143 | } 144 | _used = newSize 145 | this 146 | } 147 | } 148 | 149 | private[topicModeling] object GibbsAliasTable { 150 | def generateAlias(sv: BV[Float]): GibbsAliasTable = { 151 | val used = sv.activeSize 152 | val sum = brzSum(sv) 153 | val probs = sv.activeIterator.slice(0, used) 154 | generateAlias(probs, sum, used) 155 | } 156 | 157 | def generateAlias(probs: Iterator[(Int, Float)], sum: Float, used: Int): GibbsAliasTable = { 158 | val table = new GibbsAliasTable(used) 159 | generateAlias(probs, sum, used, table) 160 | } 161 | 162 | def generateAlias( 163 | probs: Iterator[(Int, Float)], 164 | sum: Float, 165 | used: Int, 166 | table: GibbsAliasTable): GibbsAliasTable = { 167 | table.reset(used) 168 | val pMean = 1.0f / used 169 | 170 | val lq = new util.ArrayDeque[(Int, Float)]() 171 | val hq = new util.ArrayDeque[(Int, Float)]() 172 | 173 | probs.slice(0, used).foreach { pair => 174 | val i = pair._1 175 | val pi = pair._2 / sum 176 | if (pi < pMean) { 177 | if(pMean-pi<=1e-6) { 178 | lq.addFirst(i, pi) 179 | } else { 180 | lq.add((i, pi)) 181 | } 182 | } else { 183 | if(pi-pMean<=1e-6) { 184 | hq.addFirst(i, pi) 185 | } else { 186 | hq.add((i, pi)) 187 | } 188 | } 189 | } 190 | 191 | var offset = 0 192 | while (!lq.isEmpty & !hq.isEmpty) { 193 | val (i, pi) = lq.removeLast() 194 | val (h, ph) = hq.removeLast() 195 | table.l(offset) = i 196 | table.h(offset) = h 197 | table.p(offset) = pi 198 | val pd = ph - (pMean - pi) 199 | if (pd >= pMean) { 200 | if(pd-pMean<=1e-6) { 201 | hq.addFirst(h, pd) 202 | } else { 203 | hq.add((h, pd)) 204 | } 205 | } else { 206 | if(pMean-pd<=1e-6) { 207 | lq.addFirst(h, pd) 208 | } else { 209 | lq.add((h, pd)) 210 | } 211 | } 212 | offset += 1 213 | } 214 | while (!hq.isEmpty) { 215 | val (h, ph) = hq.removeLast() 216 | // assert(ph - pMean < 1e-4) 217 | table.l(offset) = h 218 | table.h(offset) = h 219 | table.p(offset) = ph 220 | offset += 1 221 | } 222 | 223 | while (!lq.isEmpty) { 224 | val (i, pi) = lq.removeLast() 225 | // assert(pMean - pi < 1e-4) 226 | table.l(offset) = i 227 | table.h(offset) = i 228 | table.p(offset) = pi 229 | offset += 1 230 | } 231 | table 232 | } 233 | 234 | def generateAlias(sv: BV[Float], sum: Float): GibbsAliasTable = { 235 | val used = sv.activeSize 236 | val probs = sv.activeIterator.slice(0, used) 237 | generateAlias(probs, sum, used) 238 | } 239 | 240 | def generateAlias(sv: BV[Float], sum: Float, table: GibbsAliasTable): GibbsAliasTable = { 241 | val used = sv.activeSize 242 | val probs = sv.activeIterator.slice(0, used) 243 | generateAlias(probs, sum, used, table) 244 | } 245 | 246 | } 247 | 248 | class GibbsLDAAliasSampler extends GibbsLDASampler with Logging with Serializable{ 249 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED], 250 | totalTopicCounter: BDV[Count], 251 | innerIter: Long, 252 | numTokens: Long, 253 | numTopics: Int, 254 | numTerms: Int, 255 | alpha: Float, 256 | alphaAS: Float, 257 | beta: Float): Graph[VD, ED] = { 258 | val parts = graph.edges.partitions.size 259 | 260 | val newGraph = graph.mapTriplets( 261 | (pid, iter) => { 262 | val gen = new 
XORShiftRandom(parts * innerIter + pid) 263 | // table is a per term data structure 264 | // in GraphX, edges in a partition are clustered by source IDs (term id in this case) 265 | // so, use below simple cache to avoid calculating table each time 266 | val lastTable = new GibbsAliasTable(numTopics) 267 | var lastVid: VertexId = -1 268 | var lastWSum = 0.0f 269 | val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) 270 | val dData = new Array[Float](numTopics.toInt) 271 | val t = GibbsAliasTable.generateAlias(dv._2, dv._1) 272 | val tSum = dv._1 273 | 274 | iter.map { 275 | triplet => 276 | val termId = triplet.srcId 277 | val docId = triplet.dstId 278 | val termTopicCounter = triplet.srcAttr 279 | val docTopicCounter = triplet.dstAttr 280 | val topics = triplet.attr.clone() 281 | 282 | for (i <- 0 until topics.length) { 283 | val currentTopic = topics(i) 284 | dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData, 285 | currentTopic, numTokens, numTerms, alpha, alphaAS, beta) 286 | 287 | if (lastVid != termId || gen.nextDouble() < 1e-4) { 288 | lastWSum = wordTable(lastTable, totalTopicCounter, termTopicCounter, 289 | termId, numTokens, numTerms, alpha, alphaAS, beta) 290 | lastVid = termId 291 | } 292 | 293 | val newTopic = tokenSampling(gen, t, tSum, lastTable, termTopicCounter, lastWSum, 294 | docTopicCounter, dData, currentTopic) 295 | 296 | if (newTopic != currentTopic) { 297 | topics(i) = newTopic 298 | docTopicCounter(currentTopic) -= 1 299 | docTopicCounter(newTopic) += 1 300 | // if (docTopicCounter(currentTopic) == 0) docTopicCounter.compact() 301 | 302 | termTopicCounter(currentTopic) -= 1 303 | termTopicCounter(newTopic) += 1 304 | // if (termTopicCounter(currentTopic) == 0) termTopicCounter.compact() 305 | 306 | totalTopicCounter(currentTopic) -= 1 307 | totalTopicCounter(newTopic) += 1 308 | } 309 | } 310 | topics 311 | } 312 | }, TripletFields.All) 313 | GraphImpl(newGraph.vertices.mapValues(t => null), newGraph.edges) 314 | } 315 | 316 | // scalastyle:off 317 | /** 318 | * the formula is 319 | * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} 320 | */ 321 | // scalastyle:on 322 | private def tDense( 323 | totalTopicCounter: BDV[Count], 324 | numTokens: Long, 325 | numTerms: Int, 326 | alpha: Float, 327 | alphaAS: Float, 328 | beta: Float): (Float, BDV[Float]) = { 329 | val numTopics = totalTopicCounter.length 330 | val t = BDV.zeros[Float](numTopics) 331 | val alphaSum = alpha * numTopics 332 | val termSum = numTokens - 1F + alphaAS * numTopics 333 | val betaSum = numTerms * beta 334 | var sum = 0.0f 335 | for (topic <- 0 until numTopics) { 336 | val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) / 337 | ((totalTopicCounter(topic) + betaSum) * termSum) 338 | t(topic) = last 339 | sum += last 340 | } 341 | (sum, t) 342 | } 343 | 344 | // scalastyle:off 345 | /** 346 | * the formula is: 347 | * d = \frac{{n}_{kd} ^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} 348 | * = \frac{{n}_{kd} ^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta}) } 349 | */ 350 | // scalastyle:on 351 | private def dSparse( 352 | totalTopicCounter: BDV[Count], 353 | termTopicCounter: VD, 354 | docTopicCounter: VD, 355 | d: Array[Float], 356 | currentTopic: Int, 357 | numTokens: Long, 358 | numTerms: Int, 359 | alpha: Float, 360 | alphaAS: 
Float,
361 | beta: Float): Unit = {
362 | // val termSum = numTokens - 1D + alphaAS * numTopics
363 | val betaSum = numTerms * beta
364 | var sum = 0.0f
365 | var i = 0
366 |
367 | docTopicCounter.activeIterator.filter(_._2 > 0).foreach { t =>
368 | val topic = t._1
369 | val count = t._2
370 | val adjustment = if (currentTopic == topic) -1F else 0
371 | val last = (count + adjustment) * (termTopicCounter(topic) + adjustment + beta) /
372 | (totalTopicCounter(topic) + adjustment + betaSum)
373 | // val lastD = (count + adjustment) * termSum * (termTopicCounter(topic) + adjustment + beta) /
374 | // ((totalTopicCounter(topic) + adjustment + betaSum) * termSum)
375 |
376 | sum += last
377 | d(i) = sum
378 | i += 1
379 | }
380 | }
381 |
382 | private def wordTable(
383 | table: GibbsAliasTable,
384 | totalTopicCounter: BDV[Count],
385 | termTopicCounter: VD,
386 | termId: VertexId,
387 | numTokens: Long,
388 | numTerms: Int,
389 | alpha: Float,
390 | alphaAS: Float,
391 | beta: Float): Float = {
392 | val sv = wSparse(totalTopicCounter, termTopicCounter,
393 | numTokens, numTerms, alpha, alphaAS, beta)
394 | GibbsAliasTable.generateAlias(sv._2, sv._1, table)
395 | sv._1
396 | }
397 |
398 | // scalastyle:off
399 | /**
400 | * Uses both a Gibbs sampler and a Metropolis-Hastings sampler
401 | * Complexity is O(1)
402 | * Use formula (3) from the Gibbs sampler paper: Rethinking LDA: Why Priors Matter
403 | * \frac{{n}_{kw}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta}} \frac{{n}_{kd} ^{-di}+ \bar{\alpha} \frac{{n}_{k}^{-di} + \acute{\alpha}}{\sum{n}_{k} +\bar{\acute{\alpha}}}}{\sum{n}_{kd}^{-di} +\bar{\alpha}}
404 | * = t + w + d
405 | * t: the global part
406 | * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
407 | * w: term related
408 | * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
409 | * d: product of doc and term
410 | * d = \frac{{n}_{kd}^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
411 | * = \frac{{n}_{kd}^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta}) }
412 | * where:
413 | * \bar{\beta}=\sum_{w}{\beta}_{w}
414 | * \bar{\alpha}=\sum_{k}{\alpha}_{k}
415 | * \bar{\acute{\alpha}}=\sum_{k}\acute{\alpha}_{k}
416 | *
417 | * {n}_{kd} number of tokens in document d that are assigned to topic k
418 | * {n}_{kw} number of tokens with word w (across all docs) that are assigned to topic k
419 | * {n}_{k} number of tokens across all docs that are assigned to topic k
420 | * -di subtract the topic of the current token
421 | */
422 | // scalastyle:on
423 | private def tokenSampling(gen: Random,
424 | t: GibbsAliasTable,
425 | tSum: Float,
426 | w: GibbsAliasTable,
427 | termTopicCounter: VD,
428 | wSum: Float,
429 | docTopicCounter: VD,
430 | dData: Array[Float],
431 | currentTopic: Int): Int = {
432 | val used = docTopicCounter.activeSize
433 | val dSum = dData(used - 1)
434 | val distSum = tSum + wSum + dSum
435 | val genSum = gen.nextFloat() * distSum
436 | if (genSum < dSum) {
437 | val dGenSum = gen.nextFloat() * dSum
438 | val pos = GibbsLDAUtils.binarySearchInterval(dData, dGenSum, 0, used, true)
439 | docTopicCounter.indexAt(pos)
440 | } else if (genSum < (dSum + wSum)) {
441 |
sampleSV(gen, w, termTopicCounter, currentTopic) 442 | } else { 443 | t.sampleAlias(gen) 444 | } 445 | } 446 | 447 | // scalastyle:off 448 | /** 449 | * the formula is: 450 | * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} 451 | */ 452 | // scalastyle:on 453 | private def wSparse( 454 | totalTopicCounter: BDV[Count], 455 | termTopicCounter: VD, 456 | numTokens: Long, 457 | numTerms: Int, 458 | alpha: Float, 459 | alphaAS: Float, 460 | beta: Float): (Float, BSV[Float]) = { 461 | val numTopics = totalTopicCounter.length 462 | val alphaSum = alpha * numTopics 463 | val termSum = numTokens - 1F + alphaAS * numTopics 464 | val betaSum = numTerms * beta 465 | val w = BSV.zeros[Float](numTopics) 466 | var sum = 0.0f 467 | termTopicCounter.activeIterator.filter(_._2 > 0).foreach { t => 468 | val topic = t._1 469 | val count = t._2 470 | val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) / 471 | ((totalTopicCounter(topic) + betaSum) * termSum) 472 | w(topic) = last 473 | sum += last 474 | } 475 | (sum, w) 476 | } 477 | 478 | private def sampleSV(gen: Random, table: GibbsAliasTable, sv: VD, currentTopic: Int): Int = { 479 | var docTopic = table.sampleAlias(gen) 480 | 481 | while(docTopic == currentTopic) { 482 | 483 | val svCounter = sv(currentTopic) 484 | 485 | if ((svCounter == 1 && table.used > 1) || 486 | (svCounter > 1 && gen.nextFloat() < 1.0f / svCounter)) { 487 | docTopic = table.sampleAlias(gen) 488 | } else { 489 | return docTopic 490 | } 491 | } 492 | docTopic 493 | } 494 | } 495 | 496 | class GibbsLDAFastSampler extends GibbsLDASampler with Serializable with Logging { 497 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED], 498 | totalTopicCounter: BDV[Count], 499 | innerIter: Long, 500 | numTokens: Long, 501 | numTopics: Int, 502 | numTerms: Int, 503 | alpha: Float, 504 | alphaAS: Float, 505 | beta: Float): Graph[VD, ED] = { 506 | val parts = graph.edges.partitions.size 507 | val nweGraph = graph.mapTriplets( 508 | (pid, iter) => { 509 | val gen = new Random(parts * innerIter + pid) 510 | // table is a per term data structure 511 | // in GraphX, edges in a partition are clustered by source IDs (term id in this case) 512 | // so, use below simple cache to avoid calculating table each time 513 | val lastTable = new GibbsAliasTable(numTopics.toInt) 514 | var lastVid: VertexId = -1 515 | var lastWSum = 0.0f 516 | val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) 517 | val dData = new Array[Float](numTopics.toInt) 518 | val t = GibbsAliasTable.generateAlias(dv._2, dv._1) 519 | val tSum = dv._1 520 | iter.map { 521 | triplet => 522 | val termId = triplet.srcId 523 | val docId = triplet.dstId 524 | val termTopicCounter = triplet.srcAttr 525 | val docTopicCounter = triplet.dstAttr 526 | val topics = triplet.attr.clone() 527 | for (i <- 0 until topics.length) { 528 | val currentTopic = topics(i) 529 | dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData, 530 | currentTopic, numTokens, numTerms, alpha, alphaAS, beta) 531 | if (lastVid != termId) { 532 | lastWSum = wordTable(lastTable, totalTopicCounter, termTopicCounter, 533 | termId, numTokens, numTerms, alpha, alphaAS, beta) 534 | lastVid = termId 535 | } 536 | 537 | val newTopic = tokenSampling(gen, t, tSum, lastTable, termTopicCounter, lastWSum, 538 | docTopicCounter, dData, currentTopic) 539 | 540 | if (newTopic != currentTopic) { 541 | topics(i) = newTopic 542 | } 543 | } 544 
| 545 | topics 546 | } 547 | }, TripletFields.All) 548 | GraphImpl(nweGraph.vertices.mapValues(t => null), nweGraph.edges) 549 | 550 | } 551 | 552 | private def tokenSampling( 553 | gen: Random, 554 | t: GibbsAliasTable, 555 | tSum: Float, 556 | w: GibbsAliasTable, 557 | termTopicCounter: VD, 558 | wSum: Float, 559 | docTopicCounter: VD, 560 | dData: Array[Float], 561 | currentTopic: Int): Int = { 562 | val used = docTopicCounter.activeSize 563 | val dSum = dData(docTopicCounter.activeSize - 1) 564 | val distSum = tSum + wSum + dSum 565 | val genSum = gen.nextFloat() * distSum 566 | if (genSum < dSum) { 567 | val dGenSum = gen.nextFloat() * dSum 568 | val pos = GibbsLDAUtils.binarySearchInterval(dData, dGenSum, 0, used, true) 569 | docTopicCounter.indexAt(pos) 570 | } else if (genSum < (dSum + wSum)) { 571 | sampleSV(gen, w, termTopicCounter, currentTopic) 572 | } else { 573 | t.sampleAlias(gen) 574 | } 575 | } 576 | 577 | /** 578 | * dense part in the decomposed sampling formula: 579 | * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta}) 580 | * ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} 581 | */ 582 | private def tDense( 583 | totalTopicCounter: BDV[Count], 584 | numTokens: Long, 585 | numTerms: Int, 586 | alpha: Float, 587 | alphaAS: Float, 588 | beta: Float): (Float, BDV[Float]) = { 589 | val numTopics = totalTopicCounter.length 590 | val t = BDV.zeros[Float](numTopics) 591 | val alphaSum = alpha * numTopics 592 | val termSum = numTokens - 1F + alphaAS * numTopics 593 | val betaSum = numTerms * beta 594 | var sum = 0.0f 595 | for (topic <- 0 until numTopics) { 596 | val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) / 597 | ((totalTopicCounter(topic) + betaSum) * termSum) 598 | t(topic) = last 599 | sum += last 600 | } 601 | (sum, t) 602 | } 603 | 604 | /** 605 | * word related sparse part in the decomposed sampling formula: 606 | * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta}) 607 | * ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} 608 | */ 609 | private def wSparse( 610 | totalTopicCounter: BDV[Count], 611 | termTopicCounter: VD, 612 | numTokens: Long, 613 | numTerms: Int, 614 | alpha: Float, 615 | alphaAS: Float, 616 | beta: Float): (Float, BSV[Float]) = { 617 | val numTopics = totalTopicCounter.length 618 | val alphaSum = alpha * numTopics 619 | val termSum = numTokens - 1F + alphaAS * numTopics 620 | val betaSum = numTerms * beta 621 | val w = BSV.zeros[Float](numTopics) 622 | var sum = 0.0f 623 | termTopicCounter.activeIterator.filter(_._2 > 0).foreach { t => 624 | val topic = t._1 625 | val count = t._2 626 | val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) / 627 | ((totalTopicCounter(topic) + betaSum) * termSum) 628 | w(topic) = last 629 | sum += last 630 | } 631 | (sum, w) 632 | } 633 | 634 | /** 635 | * doc related sparse part in the decomposed sampling formula: 636 | * d = \frac{{n}_{kd} ^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})} 637 | * {({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} 638 | * = \frac{{n}_{kd} ^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta}) } 639 | */ 640 | private def dSparse( 641 | totalTopicCounter: BDV[Count], 642 | termTopicCounter: VD, 643 | docTopicCounter: VD, 644 | d: Array[Float], 645 | currentTopic: Int, 646 | numTokens: Long, 647 | numTerms: Int, 648 | alpha: Float, 649 | alphaAS: Float, 650 | beta: Float): Unit = { 651 | val data = 
docTopicCounter.data 652 | val used = docTopicCounter.activeSize 653 | 654 | // val termSum = numTokens - 1D + alphaAS * numTopics 655 | val betaSum = numTerms * beta 656 | var sum = 0.0f 657 | var i = 0 658 | docTopicCounter.activeIterator.filter(_._2 > 0).foreach { t => 659 | val topic = t._1 660 | val count = t._2 661 | val adjustment = if (currentTopic == topic) -1F else 0 662 | val last = (count + adjustment) * (termTopicCounter(topic) + adjustment + beta) / 663 | (totalTopicCounter(topic) + adjustment + betaSum) 664 | // val lastD = (count + adjustment) * termSum * (termTopicCounter(topic) + adjustment + beta) / 665 | // ((totalTopicCounter(topic) + adjustment + betaSum) * termSum) 666 | sum += last 667 | d(i) = sum 668 | i += 1 669 | } 670 | } 671 | 672 | private def wordTable( 673 | table: GibbsAliasTable, 674 | totalTopicCounter: BDV[Count], 675 | termTopicCounter: VD, 676 | termId: VertexId, 677 | numTokens: Long, 678 | numTerms: Int, 679 | alpha: Float, 680 | alphaAS: Float, 681 | beta: Float): Float = { 682 | val sv = wSparse(totalTopicCounter, termTopicCounter, 683 | numTokens, numTerms, alpha, alphaAS, beta) 684 | GibbsAliasTable.generateAlias(sv._2, sv._1, table) 685 | sv._1 686 | } 687 | 688 | private def sampleSV( 689 | gen: Random, 690 | table: GibbsAliasTable, 691 | sv: VD, 692 | currentTopic: Int, 693 | currentTopicCounter: Int = 0, 694 | numSampling: Int = 0): Int = { 695 | val docTopic = table.sampleAlias(gen) 696 | if (docTopic == currentTopic && numSampling < 16) { 697 | val svCounter = if (currentTopicCounter == 0) sv(currentTopic) else currentTopicCounter 698 | // TODO: not sure it is correct or not? 699 | // discard it if the newly sampled topic is current topic 700 | if ((svCounter == 1 && table.used > 1) || 701 | /* the sampled topic that contains current token and other tokens */ 702 | (svCounter > 1 && gen.nextDouble() < 1.0 / svCounter) 703 | /* the sampled topic has 1/svCounter probability that belongs to current token */ ) { 704 | return sampleSV(gen, table, sv, currentTopic, svCounter, numSampling + 1) 705 | } 706 | } 707 | docTopic 708 | } 709 | } 710 | 711 | class GibbsLDALightSampler extends GibbsLDASampler with Logging with Serializable { 712 | 713 | private def sampleSV( 714 | gen: Random, 715 | table: GibbsAliasTable, 716 | sv: VD, 717 | currentTopic: Int, 718 | currentTopicCounter: Int = 0, 719 | numSampling: Int = 0): Int = { 720 | val docTopic = table.sampleAlias(gen) 721 | if (docTopic == currentTopic && numSampling < 16) { 722 | val svCounter = if (currentTopicCounter == 0) sv(currentTopic) else currentTopicCounter 723 | // TODO: not sure it is correct or not? 
724 | // discard it if the newly sampled topic is current topic 725 | if ((svCounter == 1 && table.used > 1) || 726 | /* the sampled topic that contains current token and other tokens */ 727 | (svCounter > 1 && gen.nextDouble() < 1.0 / svCounter) 728 | /* the sampled topic has 1/svCounter probability that belongs to current token */ ) { 729 | return sampleSV(gen, table, sv, currentTopic, svCounter, numSampling + 1) 730 | } 731 | } 732 | docTopic 733 | } 734 | 735 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED], 736 | totalTopicCounter: BDV[Count], 737 | innerIter: Long, 738 | numTokens: Long, 739 | numTopics: Int, 740 | numTerms: Int, 741 | alpha: Float, 742 | alphaAS: Float, 743 | beta: Float): Graph[VD, ED] = { 744 | val parts = graph.edges.partitions.length 745 | val nweGraph = graph.mapTriplets( 746 | (pid, iter) => { 747 | val gen = new Random(parts * innerIter + pid) 748 | val docTableCache = new AppendOnlyMap[VertexId, SoftReference[(Float, GibbsAliasTable)]]() 749 | 750 | // table is a per term data structure 751 | // in GraphX, edges in a partition are clustered by source IDs (term id in this case) 752 | // so, use below simple cache to avoid calculating table each time 753 | val lastTable = new GibbsAliasTable(numTopics.toInt) 754 | var lastVid: VertexId = -1 755 | var lastWSum = 0.0f 756 | 757 | val p = tokenTopicProb(totalTopicCounter, beta, alpha, 758 | alphaAS, numTokens, numTerms) _ 759 | val dPFun = docProb(totalTopicCounter, alpha, alphaAS, numTokens) _ 760 | val wPFun = wordProb(totalTopicCounter, numTerms, beta) _ 761 | 762 | var dD: GibbsAliasTable = null 763 | var dDSum: Float = 0.0f 764 | var wD: GibbsAliasTable = null 765 | var wDSum: Float = 0.0f 766 | 767 | iter.map { 768 | triplet => 769 | val termId = triplet.srcId 770 | val docId = triplet.dstId 771 | val termTopicCounter = triplet.srcAttr 772 | val docTopicCounter = triplet.dstAttr 773 | val topics = triplet.attr.clone() 774 | 775 | if (dD == null || gen.nextDouble() < 1e-6) { 776 | var dv = dDense(totalTopicCounter, alpha, alphaAS, numTokens) 777 | dDSum = dv._1 778 | dD = GibbsAliasTable.generateAlias(dv._2, dDSum) 779 | 780 | dv = wDense(totalTopicCounter, numTerms, beta) 781 | wDSum = dv._1 782 | wD = GibbsAliasTable.generateAlias(dv._2, wDSum) 783 | } 784 | val (dSum, d) = docTopicCounter.synchronized { 785 | docTable(x => x == null || x.get() == null || gen.nextDouble() < 1e-2, 786 | docTableCache, docTopicCounter, docId) 787 | } 788 | val (wSum, w) = termTopicCounter.synchronized { 789 | if (lastVid != termId || gen.nextDouble() < 1e-4) { 790 | lastWSum = wordTable(lastTable, totalTopicCounter, termTopicCounter, termId, numTerms, beta) 791 | lastVid = termId 792 | } 793 | (lastWSum, lastTable) 794 | } 795 | for (i <- topics.indices) { 796 | var docProposal = gen.nextDouble() < 0.5 797 | var maxSampling = 8 798 | while (maxSampling > 0) { 799 | maxSampling -= 1 800 | docProposal = !docProposal 801 | val currentTopic = topics(i) 802 | var proposalTopic = -1 803 | val q = if (docProposal) { 804 | if (gen.nextFloat() < dDSum / (dSum - 1.0f + dDSum)) { 805 | proposalTopic = dD.sampleAlias(gen) 806 | } 807 | else { 808 | proposalTopic = docTopicCounter.synchronized { 809 | sampleSV(gen, d, docTopicCounter, currentTopic) 810 | } 811 | } 812 | dPFun 813 | } else { 814 | val table = if (gen.nextDouble() < wSum / (wSum + wDSum)) w else wD 815 | proposalTopic = table.sampleAlias(gen) 816 | wPFun 817 | } 818 | 819 | val newTopic = docTopicCounter.synchronized { 820 | termTopicCounter.synchronized { 821 | 
tokenSampling(gen, docTopicCounter, termTopicCounter, docProposal, 822 | currentTopic, proposalTopic, q, p) 823 | } 824 | } 825 | 826 | assert(newTopic >= 0 && newTopic < numTopics) 827 | if (newTopic != currentTopic) { 828 | topics(i) = newTopic 829 | docTopicCounter.synchronized { 830 | docTopicCounter(currentTopic) -= 1 831 | docTopicCounter(newTopic) += 1 832 | } 833 | termTopicCounter.synchronized { 834 | termTopicCounter(currentTopic) -= 1 835 | termTopicCounter(newTopic) += 1 836 | } 837 | totalTopicCounter(currentTopic) -= 1 838 | totalTopicCounter(newTopic) += 1 839 | } 840 | } 841 | } 842 | topics 843 | } 844 | }, TripletFields.All) 845 | GraphImpl(nweGraph.vertices.mapValues(t => null), nweGraph.edges) 846 | } 847 | 848 | // scalastyle:off 849 | private def tokenTopicProb( 850 | totalTopicCounter: BDV[Count], 851 | beta: Float, 852 | alpha: Float, 853 | alphaAS: Float, 854 | numTokens: Long, 855 | numTerms: Int)(docTopicCounter: VD, 856 | termTopicCounter: VD, 857 | topic: Int, 858 | isAdjustment: Boolean): Float = { 859 | val numTopics = docTopicCounter.length 860 | val adjustment = if (isAdjustment) -1 else 0 861 | val ratio = (totalTopicCounter(topic) + adjustment + alphaAS) / 862 | (numTokens - 1 + alphaAS * numTopics) 863 | val asPrior = ratio * (alpha * numTopics) 864 | // constant terms are removed (docLen - 1 + alpha * numTopics) 865 | (termTopicCounter(topic) + adjustment + beta) * 866 | (docTopicCounter(topic) + adjustment + asPrior) / 867 | (totalTopicCounter(topic) + adjustment + (numTerms * beta)) 868 | 869 | // original formula: Rethinking LDA: Why Priors Matter formula (3) 870 | // val docLen = brzSum(docTopicCounter) 871 | // (termTopicCounter(topic) + adjustment + beta) * (docTopicCounter(topic) + adjustment + asPrior) / 872 | // ((totalTopicCounter(topic) + adjustment + (numTerms * beta)) * (docLen - 1 + alpha * numTopics)) 873 | } 874 | 875 | // scalastyle:on 876 | 877 | private def wordProb( 878 | totalTopicCounter: BDV[Count], 879 | numTerms: Int, 880 | beta: Float)(termTopicCounter: VD, topic: Int, isAdjustment: Boolean): Float = { 881 | (termTopicCounter(topic) + beta) / (totalTopicCounter(topic) + beta * numTerms) 882 | } 883 | 884 | private def docProb( 885 | totalTopicCounter: BDV[Count], 886 | alpha: Float, 887 | alphaAS: Float, 888 | numTokens: Long)(docTopicCounter: VD, topic: Int, isAdjustment: Boolean): Float = { 889 | val adjustment = if (isAdjustment) -1 else 0 890 | val numTopics = totalTopicCounter.length 891 | val ratio = (totalTopicCounter(topic) + alphaAS) / 892 | (numTokens - 1 + alphaAS * numTopics) 893 | val asPrior = ratio * (alpha * numTopics) 894 | docTopicCounter(topic) + adjustment + asPrior 895 | } 896 | 897 | /** 898 | * \frac{{n}_{kw}}{{n}_{k}+\bar{\beta}} 899 | */ 900 | private def wSparse( 901 | totalTopicCounter: BDV[Count], 902 | termTopicCounter: VD, 903 | numTerms: Int, 904 | beta: Float): (Float, BV[Float]) = { 905 | val numTopics = termTopicCounter.length 906 | val termSum = beta * numTerms 907 | val w = BSV.zeros[Float](numTopics) 908 | 909 | var sum = 0.0f 910 | termTopicCounter.activeIterator.foreach { t => 911 | val topic = t._1 912 | val count = t._2 913 | if (count > 0) { 914 | val last = count / (totalTopicCounter(topic) + termSum) 915 | w(topic) = last 916 | sum += last 917 | } 918 | } 919 | (sum, w) 920 | } 921 | 922 | /** 923 | * \frac{{\beta}_{w}}{{n}_{k}+\bar{\beta}} 924 | */ 925 | private def wDense( 926 | totalTopicCounter: BDV[Count], 927 | numTerms: Int, 928 | beta: Float): (Float, BV[Float]) = { 929 | val 
numTopics = totalTopicCounter.length 930 | val t = BDV.zeros[Float](numTopics) 931 | val termSum = beta * numTerms 932 | var sum = 0.0f 933 | for (topic <- 0 until numTopics) { 934 | val last = beta / (totalTopicCounter(topic) + termSum) 935 | t(topic) = last 936 | sum += last 937 | } 938 | (sum, t) 939 | } 940 | 941 | private def dSparse(docTopicCounter: VD): (Float, BV[Float]) = { 942 | val numTopics = docTopicCounter.length 943 | val d = BSV.zeros[Float](numTopics) 944 | var sum = 0.0f 945 | docTopicCounter.activeIterator.foreach { t => 946 | val topic = t._1 947 | val count = t._2 948 | if (count > 0) { 949 | val last = count 950 | d(topic) = last 951 | sum += last 952 | } 953 | } 954 | (sum, d) 955 | } 956 | 957 | 958 | private def dDense( 959 | totalTopicCounter: BDV[Count], 960 | alpha: Float, 961 | alphaAS: Float, 962 | numTokens: Long): (Float, BV[Float]) = { 963 | val numTopics = totalTopicCounter.length 964 | val asPrior = BDV.zeros[Float](numTopics) 965 | 966 | var sum = 0.0f 967 | for (topic <- 0 until numTopics) { 968 | val ratio = (totalTopicCounter(topic) + alphaAS) / 969 | (numTokens - 1 + alphaAS * numTopics) 970 | val last = ratio * (alpha * numTopics) 971 | asPrior(topic) = last 972 | sum += last 973 | } 974 | (sum, asPrior) 975 | } 976 | 977 | private def docTable( 978 | updateFunc: SoftReference[(Float, GibbsAliasTable)] => Boolean, 979 | cacheMap: AppendOnlyMap[VertexId, SoftReference[(Float, GibbsAliasTable)]], 980 | docTopicCounter: VD, 981 | docId: VertexId): (Float, GibbsAliasTable) = { 982 | val cacheD = cacheMap(docId) 983 | if (!updateFunc(cacheD)) { 984 | cacheD.get 985 | } else { 986 | docTopicCounter.synchronized { 987 | val sv = dSparse(docTopicCounter) 988 | val d = (sv._1, GibbsAliasTable.generateAlias(sv._2, sv._1)) 989 | cacheMap.update(docId, new SoftReference(d)) 990 | d 991 | } 992 | } 993 | } 994 | 995 | private def wordTable( 996 | table: GibbsAliasTable, 997 | totalTopicCounter: BDV[Count], 998 | termTopicCounter: VD, 999 | termId: VertexId, 1000 | numTerms: Int, 1001 | beta: Float): Float = { 1002 | val sv = wSparse(totalTopicCounter, termTopicCounter, numTerms, beta) 1003 | GibbsAliasTable.generateAlias(sv._2, sv._1, table) 1004 | sv._1 1005 | } 1006 | 1007 | // scalastyle:off 1008 | /** 1009 | * use both Gibbs sampler and Metropolis Hastings sampler 1010 | * Complexity is O(1) 1011 | * 1. use term related portions of the Gibbs sampler LDA formula 1012 | * LightLDA: Big Topic Models on Modest Compute Clusters, formula(6): 1013 | * ( \frac{{n}_{kd}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta }} ) 1014 | * 2. 
use probability sampled from step 1 as Proposal q(.), use Metropolis Hastings sampler Sampling asymmetric transcendental equation 1015 | * reference paper: Rethinking LDA: Why Priors Matter, formula(3) 1016 | * \frac{{n}_{kw}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta}} \frac{{n}_{kd} ^{-di}+ \bar{\alpha} \frac{{n}_{k}^{-di} + \acute{\alpha}}{\sum{n}_{k} +\bar{\acute{\alpha}}}}{\sum{n}_{kd}^{-di} +\bar{\alpha}} 1017 | * 1018 | * where 1019 | * \bar{\beta}=\sum_{w}{\beta}_{w} 1020 | * \bar{\alpha}=\sum_{k}{\alpha}_{k} 1021 | * \bar{\acute{\alpha}}=\bar{\acute{\alpha}}=\sum_{k}\acute{\alpha} 1022 | * {n}_{kd} number of tokens in document d that are assigned to topic k 1023 | * {n}_{kw} number of tokens with word w (across all docs) that are assigned to topic k 1024 | * {n}_{k} number of tokens across all docs that are assigned to topic k 1025 | */ 1026 | // scalastyle:on 1027 | def tokenSampling( 1028 | gen: Random, 1029 | docTopicCounter: VD, 1030 | termTopicCounter: VD, 1031 | docProposal: Boolean, 1032 | currentTopic: Int, 1033 | proposalTopic: Int, 1034 | q: (VD, Int, Boolean) => Float, 1035 | p: (VD, VD, Int, Boolean) => Float): Int = { 1036 | if (proposalTopic == currentTopic) return proposalTopic 1037 | val cp = p(docTopicCounter, termTopicCounter, currentTopic, true) 1038 | val np = p(docTopicCounter, termTopicCounter, proposalTopic, false) 1039 | val vd = if (docProposal) docTopicCounter else termTopicCounter 1040 | val cq = q(vd, currentTopic, true) 1041 | val nq = q(vd, proposalTopic, false) 1042 | 1043 | val pi = (np * cq) / (cp * nq) 1044 | if (gen.nextDouble() < 1e-32) { 1045 | println(s"Pi: ${pi}") 1046 | println(s"($np * $cq) / ($cp * $nq)") 1047 | } 1048 | 1049 | if (gen.nextDouble() < math.min(1.0, pi)) proposalTopic else currentTopic 1050 | } 1051 | } 1052 | 1053 | class GibbsLDASparseSampler extends GibbsLDASampler with Logging with Serializable { 1054 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED], 1055 | totalTopicCounter: BDV[Count], 1056 | innerIter: Long, 1057 | numTokens: Long, 1058 | numTopics: Int, 1059 | numTerms: Int, 1060 | alpha: Float, 1061 | alphaAS: Float, 1062 | beta: Float): Graph[VD, ED] = { 1063 | val parts = graph.edges.partitions.size 1064 | val newGraph = graph.mapTriplets( 1065 | (pid, iter) => { 1066 | val gen = new XORShiftRandom(parts * innerIter + pid) 1067 | val d = BDV.zeros[Float](numTopics) 1068 | var lastTermId:VertexId = -1 1069 | var lastWordTable:BSV[Float] = null 1070 | var tCache: BDV[Float] = null 1071 | iter.map { 1072 | triplet => 1073 | val termId = triplet.srcId 1074 | val termTopicCounter = triplet.srcAttr 1075 | val docTopicCounter = triplet.dstAttr 1076 | val topics = triplet.attr.clone() 1077 | termTopicCounter.synchronized { 1078 | docTopicCounter.synchronized { 1079 | if (lastTermId != termId || gen.nextDouble() < 1e-4) { 1080 | lastWordTable = w(totalTopicCounter, termTopicCounter, termId, 1081 | numTokens, numTerms, alpha, beta, alphaAS) 1082 | lastTermId = termId 1083 | } 1084 | if (tCache == null || gen.nextDouble() < 1e-7) { 1085 | tCache = this.t(totalTopicCounter, numTokens, numTerms, 1086 | numTopics, alpha, beta, alphaAS) 1087 | } 1088 | val t = tCache 1089 | var i = 0 1090 | while (i < topics.length) { 1091 | val currentTopic = topics(i) 1092 | this.d(totalTopicCounter, termTopicCounter, docTopicCounter, d, 1093 | currentTopic, numTokens, numTerms, numTopics, beta, alphaAS) 1094 | val newTopic = multinomialDistSampler(gen, docTopicCounter, d, lastWordTable, t) 1095 | if (currentTopic != newTopic) { 
1096 | topics(i) = newTopic 1097 | docTopicCounter(currentTopic) -= 1 1098 | docTopicCounter(newTopic) += 1 1099 | //if (docTopicCounter(currentTopic) == 0) docTopicCounter.compact() 1100 | 1101 | termTopicCounter(currentTopic) -= 1 1102 | termTopicCounter(newTopic) += 1 1103 | //if (termTopicCounter(currentTopic) == 0) termTopicCounter.compact() 1104 | 1105 | totalTopicCounter(currentTopic) -= 1 1106 | totalTopicCounter(newTopic) += 1 1107 | } 1108 | i += 1 1109 | } 1110 | } 1111 | } 1112 | topics 1113 | } 1114 | }, TripletFields.All) 1115 | GraphImpl(newGraph.vertices.mapValues(t => null), newGraph.edges) 1116 | } 1117 | 1118 | private def w( 1119 | totalTopicCounter: BDV[Count], 1120 | termTopicCounter: VD, 1121 | termId: VertexId, 1122 | numTokens: Long, 1123 | numTerms: Int, 1124 | alpha: Float, 1125 | beta: Float, 1126 | alphaAS: Float): BSV[Float] = { 1127 | val numTopics = totalTopicCounter.length 1128 | val alphaSum = alpha * numTopics 1129 | val termSum = numTokens - 1F + alphaAS * numTopics 1130 | val betaSum = numTerms * beta 1131 | val length = termTopicCounter.length 1132 | val used = termTopicCounter.activeSize 1133 | val index = 1134 | termTopicCounter match { 1135 | case idx:BSV[_] => idx.index.slice(0, used) 1136 | case idx:BDV[_] => (0 until idx.length).toArray 1137 | } 1138 | val data = termTopicCounter.data 1139 | val w = new Array[Float](used) 1140 | 1141 | var lastSum = 0F 1142 | var i = 0 1143 | 1144 | while (i < used) { 1145 | val topic = index(i) 1146 | val count = data(i) 1147 | val lastW = count * alphaSum * (totalTopicCounter(topic) + alphaAS) / 1148 | ((totalTopicCounter(topic) + betaSum) * termSum) 1149 | lastSum += lastW 1150 | w(i) = lastSum 1151 | i += 1 1152 | } 1153 | new BSV[Float](index, w, used, length) 1154 | } 1155 | 1156 | private def t( 1157 | totalTopicCounter: BDV[Count], 1158 | numTokens: Long, 1159 | numTerms: Int, 1160 | numTopics: Int, 1161 | alpha: Float, 1162 | beta: Float, 1163 | alphaAS: Float): BDV[Float] = { 1164 | val t = BDV.zeros[Float](numTopics) 1165 | val alphaSum = alpha * numTopics 1166 | val termSum = numTokens - 1F + alphaAS * numTopics 1167 | val betaSum = numTerms * beta 1168 | 1169 | var lastSum = 0F 1170 | for (topic <- 0 until numTopics) { 1171 | val lastT = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) / 1172 | ((totalTopicCounter(topic) + betaSum) * termSum) 1173 | lastSum += lastT 1174 | t(topic) = lastSum 1175 | } 1176 | t 1177 | } 1178 | 1179 | private def d( 1180 | totalTopicCounter: BDV[Count], 1181 | termTopicCounter: VD, 1182 | docTopicCounter: VD, 1183 | d: BDV[Float], 1184 | currentTopic: Int, 1185 | numTokens: Long, 1186 | numTerms: Int, 1187 | numTopics: Int, 1188 | beta: Float, 1189 | alphaAS: Float): Unit = { 1190 | val used = docTopicCounter.activeSize 1191 | val index = 1192 | docTopicCounter match { 1193 | case idx:BSV[_] => idx.index 1194 | case idx:BDV[_] => (0 until idx.length).toArray 1195 | } 1196 | val data = docTopicCounter.data 1197 | 1198 | // val termSum = numTokens - 1D + alphaAS * numTopics 1199 | val betaSum = numTerms * beta 1200 | var i = 0 1201 | var lastSum = 0F 1202 | 1203 | while (i < used) { 1204 | val topic = index(i) 1205 | val count: Float = data(i) 1206 | val adjustment = if (currentTopic == topic) -1F else 0 1207 | // val lastD = count * termSum * (termTopicCounter(topic) + beta) / 1208 | // ((totalTopicCounter(topic) + betaSum) * termSum) 1209 | 1210 | val lastD = (count + adjustment) * (termTopicCounter(topic) + adjustment + beta) / 1211 | (totalTopicCounter(topic) + 
adjustment + betaSum) 1212 | lastSum += lastD 1213 | d(topic) = lastSum 1214 | i += 1 1215 | } 1216 | d(numTopics - 1) = lastSum 1217 | } 1218 | 1219 | /** 1220 | * A multinomial distribution sampler, using roulette method to sample an Int back. 1221 | */ 1222 | private def multinomialDistSampler( 1223 | gen: Random, 1224 | docTopicCounter: VD, 1225 | d: BDV[Float], 1226 | w: BSV[Float], 1227 | t: BDV[Float]): Int = { 1228 | val numTopics = d.length 1229 | val lastSum = t(numTopics - 1) + w.data(w.used - 1) + d(numTopics - 1) 1230 | val distSum = gen.nextFloat() * lastSum 1231 | val fun = index(docTopicCounter, d, w, t) _ 1232 | val topic = binarySearchInterval[Float](fun, distSum, 0, numTopics, true) 1233 | math.min(topic, numTopics - 1) 1234 | } 1235 | 1236 | private def index( 1237 | docTopicCounter: VD, 1238 | d: BDV[Float], 1239 | w: BSV[Float], 1240 | t: BDV[Float])(i: Int) = { 1241 | val lastDS = binarySearchDenseVector(i, docTopicCounter, d) 1242 | val lastWS = binarySearchSparseVector(i, w) 1243 | val lastTS = t(i) 1244 | lastDS + lastWS + lastTS 1245 | } 1246 | 1247 | private[topicModeling] def binarySearchInterval[K]( 1248 | index: Int => K, 1249 | key: K, 1250 | begin: Int, 1251 | end: Int, 1252 | greater: Boolean)(implicit ord: Ordering[K], ctag: ClassTag[K]): Int = { 1253 | if (begin == end) { 1254 | return if (greater) end else begin - 1 1255 | } 1256 | var b = begin 1257 | var e = end - 1 1258 | 1259 | var mid: Int = (e + b) >> 1 1260 | while (b <= e) { 1261 | mid = (e + b) >> 1 1262 | val v = index(mid) 1263 | if (ord.lt(v, key)) { 1264 | b = mid + 1 1265 | } 1266 | else if (ord.gt(v, key)) { 1267 | e = mid - 1 1268 | } 1269 | else { 1270 | return mid 1271 | } 1272 | } 1273 | 1274 | val v = index(mid) 1275 | mid = if ((greater && ord.gteq(v, key)) || (!greater && ord.lteq(v, key))) { 1276 | mid 1277 | } 1278 | else if (greater) { 1279 | mid + 1 1280 | } 1281 | else { 1282 | mid - 1 1283 | } 1284 | 1285 | if (greater) { 1286 | if (mid < end) assert(ord.gteq(index(mid), key)) 1287 | if (mid > 0) assert(ord.lteq(index(mid - 1), key)) 1288 | } else { 1289 | if (mid > 0) assert(ord.lteq(index(mid), key)) 1290 | if (mid < end - 1) assert(ord.gteq(index(mid + 1), key)) 1291 | } 1292 | mid 1293 | } 1294 | 1295 | private[topicModeling] def binarySearchArray[K]( 1296 | index: Array[K], 1297 | key: K, 1298 | begin: Int, 1299 | end: Int, 1300 | greater: Boolean)(implicit ord: Ordering[K], ctag: ClassTag[K]): Int = { 1301 | binarySearchInterval(i => index(i), key, begin, end, greater) 1302 | } 1303 | 1304 | private[topicModeling] def binarySearchSparseVector(topic: Int, sv: BSV[Float]) = { 1305 | val index = sv.index 1306 | val used = sv.used 1307 | val data = sv.data 1308 | val pos = binarySearchArray(index, topic, 0, used, false) 1309 | if (pos > -1) data(pos) else 0F 1310 | } 1311 | 1312 | private[topicModeling] def binarySearchDenseVector[V]( 1313 | topic: Int, 1314 | sv: StorageVector[V], 1315 | dv: BDV[Float]): Float = { 1316 | val index = 1317 | sv match { 1318 | case sv:BDV[_] => (0 until sv.length).toArray 1319 | case sv:BSV[_] => sv.index 1320 | } 1321 | val used = sv.activeSize 1322 | val pos = binarySearchArray(index, topic, 0, used, false) 1323 | if (pos > -1) dv(index(pos)) else 0F 1324 | } 1325 | } 1326 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/topicModeling/LDA.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software 
Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.mllib.topicModeling 19 | 20 | import org.apache.spark.Logging 21 | import org.apache.spark.annotation.Experimental 22 | import org.apache.spark.api.java.JavaPairRDD 23 | import org.apache.spark.mllib.linalg.Vector 24 | import org.apache.spark.rdd.RDD 25 | 26 | 27 | /** 28 | * :: Experimental :: 29 | * 30 | * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. 31 | * 32 | * Terminology: 33 | * - "word" = "term": an element of the vocabulary 34 | * - "token": instance of a term appearing in a document 35 | * - "topic": multinomial distribution over words representing some concept 36 | * 37 | * Adapted from the MLlib LDA implementation; it supports both online LDA and Gibbs sampling LDA 38 | */ 39 | @Experimental 40 | class LDA private ( 41 | private var k: Int, 42 | private var maxIterations: Int, 43 | private var docConcentration: Double, 44 | private var topicConcentration: Double, 45 | private var seed: Long, 46 | private var checkpointInterval: Int, 47 | private var ldaOptimizer: LDAOptimizer) extends Logging { 48 | 49 | def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, 50 | seed = new scala.util.Random().nextLong(), checkpointInterval = 10, ldaOptimizer = new OnlineLDAOptimizer) 51 | 52 | /** 53 | * Number of topics to infer. I.e., the number of soft cluster centers. 54 | */ 55 | def getK: Int = k 56 | 57 | /** 58 | * Number of topics to infer. I.e., the number of soft cluster centers. 59 | * (default = 10) 60 | */ 61 | def setK(k: Int): this.type = { 62 | require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") 63 | this.k = k 64 | this 65 | } 66 | 67 | /** 68 | * Concentration parameter (commonly named "alpha") for the prior placed on documents' 69 | * distributions over topics ("theta"). 70 | * 71 | * This is the parameter to a symmetric Dirichlet distribution. 72 | */ 73 | def getDocConcentration: Double = this.docConcentration 74 | 75 | /** 76 | * Concentration parameter (commonly named "alpha") for the prior placed on documents' 77 | * distributions over topics ("theta"). 78 | * 79 | * This is the parameter to a symmetric Dirichlet distribution, where larger values 80 | * mean more smoothing (more regularization). 81 | * 82 | * If set to -1, then docConcentration is set automatically. 83 | * (default = -1 = automatic) 84 | * 85 | * Optimizer-specific parameter settings: 86 | * - Online 87 | * - Value should be >= 0 88 | * - default = (1.0 / k), following the implementation from 89 | * [[https://github.com/Blei-Lab/onlineldavb]].
90 | */ 91 | def setDocConcentration(docConcentration: Double): this.type = { 92 | this.docConcentration = docConcentration 93 | this 94 | } 95 | 96 | /** Alias for [[getDocConcentration]] */ 97 | def getAlpha: Double = getDocConcentration 98 | 99 | /** Alias for [[setDocConcentration()]] */ 100 | def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) 101 | 102 | /** 103 | * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' 104 | * distributions over terms. 105 | * 106 | * This is the parameter to a symmetric Dirichlet distribution. 107 | * 108 | * Note: The topics' distributions over terms are called "beta" in the original LDA paper 109 | * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. 110 | */ 111 | def getTopicConcentration: Double = this.topicConcentration 112 | 113 | /** 114 | * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' 115 | * distributions over terms. 116 | * 117 | * This is the parameter to a symmetric Dirichlet distribution. 118 | * 119 | * Note: The topics' distributions over terms are called "beta" in the original LDA paper 120 | * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. 121 | * 122 | * If set to -1, then topicConcentration is set automatically. 123 | * (default = -1 = automatic) 124 | * 125 | * Optimizer-specific parameter settings: 126 | * - Online 127 | * - Value should be >= 0 128 | * - default = (1.0 / k), following the implementation from 129 | * [[https://github.com/Blei-Lab/onlineldavb]]. 130 | */ 131 | def setTopicConcentration(topicConcentration: Double): this.type = { 132 | this.topicConcentration = topicConcentration 133 | this 134 | } 135 | 136 | /** Alias for [[getTopicConcentration]] */ 137 | def getBeta: Double = getTopicConcentration 138 | 139 | /** Alias for [[setTopicConcentration()]] */ 140 | def setBeta(beta: Double): this.type = setTopicConcentration(beta) 141 | 142 | /** 143 | * Maximum number of iterations for learning. 144 | */ 145 | def getMaxIterations: Int = maxIterations 146 | 147 | /** 148 | * Maximum number of iterations for learning. 149 | * (default = 20) 150 | */ 151 | def setMaxIterations(maxIterations: Int): this.type = { 152 | this.maxIterations = maxIterations 153 | this 154 | } 155 | 156 | /** Random seed */ 157 | def getSeed: Long = seed 158 | 159 | /** Random seed */ 160 | def setSeed(seed: Long): this.type = { 161 | this.seed = seed 162 | this 163 | } 164 | 165 | /** 166 | * Period (in iterations) between checkpoints. 167 | */ 168 | def getCheckpointInterval: Int = checkpointInterval 169 | 170 | /** 171 | * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery 172 | * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be 173 | * important when LDA is run for many iterations. If the checkpoint directory is not set in 174 | * [[org.apache.spark.SparkContext]], this setting is ignored. 
175 | * 176 | * @see [[org.apache.spark.SparkContext#setCheckpointDir]] 177 | */ 178 | def setCheckpointInterval(checkpointInterval: Int): this.type = { 179 | this.checkpointInterval = checkpointInterval 180 | this 181 | } 182 | 183 | 184 | /** LDAOptimizer used to perform the actual calculation */ 185 | def getOptimizer: LDAOptimizer = ldaOptimizer 186 | 187 | /** 188 | * LDAOptimizer used to perform the actual calculation (default = OnlineLDAOptimizer) 189 | */ 190 | def setOptimizer(optimizer: LDAOptimizer): this.type = { 191 | this.ldaOptimizer = optimizer 192 | this 193 | } 194 | 195 | /** 196 | * Set the LDAOptimizer used to perform the actual calculation by algorithm name. 197 | * Only "online", "gibbs" are supported. 198 | */ 199 | def setOptimizer(optimizerName: String): this.type = { 200 | this.ldaOptimizer = 201 | optimizerName.toLowerCase match { 202 | case "online" => new OnlineLDAOptimizer 203 | case "gibbs" => new GibbsLDAOptimizer() 204 | case other => 205 | throw new IllegalArgumentException(s"Only online, gibbs are supported but got $other.") 206 | } 207 | this 208 | } 209 | 210 | /** 211 | * Learn an LDA model using the given dataset. 212 | * 213 | * @param documents RDD of documents, which are term (word) count vectors paired with IDs. 214 | * The term count vectors are "bags of words" with a fixed-size vocabulary 215 | * (where the vocabulary size is the length of the vector). 216 | * Document IDs must be unique and >= 0. 217 | * @return Inferred LDA model 218 | */ 219 | def run(documents: RDD[(Long, Vector)]): LDAModel = { 220 | val state = ldaOptimizer.initialize(documents, this) 221 | var iter = 0 222 | val iterationTimes = Array.fill[Double](maxIterations)(0) 223 | while (iter < maxIterations) { 224 | logInfo(s"Starting iteration $iter/$maxIterations") 225 | val start = System.nanoTime() 226 | state.next() 227 | val elapsedSeconds = (System.nanoTime() - start) / 1e9 228 | iterationTimes(iter) = elapsedSeconds 229 | iter += 1 230 | } 231 | state.getLDAModel(iterationTimes) 232 | } 233 | 234 | /** Java-friendly version of [[run()]] */ 235 | def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { 236 | run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/topicModeling/LDAExample.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | package org.apache.spark.mllib.topicModeling 19 | 20 | import java.text.BreakIterator 21 | 22 | import org.apache.log4j.{Level, Logger} 23 | import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} 24 | import org.apache.spark.rdd.RDD 25 | import org.apache.spark.{SparkConf, SparkContext} 26 | import scopt.OptionParser 27 | 28 | import scala.collection.mutable 29 | import scala.reflect.runtime.universe._ 30 | 31 | /** 32 | * Abstract class for parameter case classes. 33 | * This overrides the [[toString]] method to print all case class fields by name and value. 34 | * @tparam T Concrete parameter class. 35 | */ 36 | abstract class AbstractParams[T: TypeTag] { 37 | 38 | private def tag: TypeTag[T] = typeTag[T] 39 | 40 | /** 41 | * Finds all case class fields in concrete class instance, and outputs them in JSON-style format: 42 | * { 43 | * [field name]:\t[field value]\n 44 | * [field name]:\t[field value]\n 45 | * ... 46 | * } 47 | */ 48 | override def toString: String = { 49 | val tpe = tag.tpe 50 | val allAccessors = tpe.declarations.collect { 51 | case m: MethodSymbol if m.isCaseAccessor => m 52 | } 53 | val mirror = runtimeMirror(getClass.getClassLoader) 54 | val instanceMirror = mirror.reflect(this) 55 | allAccessors.map { f => 56 | val paramName = f.name.toString 57 | val fieldMirror = instanceMirror.reflectField(f) 58 | val paramValue = fieldMirror.get 59 | s" $paramName:\t$paramValue" 60 | }.mkString("{\n", ",\n", "\n}") 61 | } 62 | } 63 | 64 | /** 65 | * An example Latent Dirichlet Allocation (LDA) app. Run with 66 | * {{{ 67 | * ./bin/run-example mllib.LDAExample [options] 68 | * }}} 69 | * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 70 | */ 71 | object LDAExample { 72 | 73 | private case class Params( 74 | input: Seq[String] = Seq.empty, 75 | k: Int = 20, 76 | maxIterations: Int = 10, 77 | maxInnerIterations: Int = 5, 78 | docConcentration: Double = 0.01, 79 | topicConcentration: Double = 0.01, 80 | vocabSize: Int = 10000, 81 | stopwordFile: String = "", 82 | checkpointDir: Option[String] = None, 83 | checkpointInterval: Int = 10, 84 | optimizer:String = "em", 85 | gibbsSampler:String = "alias", 86 | gibbsAlphaAS:Double = 0.1, 87 | gibbsPrintPerplexity:Boolean = false, 88 | gibbsEdgePartitioner:String = "none", 89 | partitions:Int = 2, 90 | logLevel:String = "info", 91 | psMasterAddr:String = null 92 | ) extends AbstractParams[Params] 93 | 94 | def main(args: Array[String]) { 95 | val defaultParams = Params() 96 | 97 | val parser = new OptionParser[Params]("LDAExample") { 98 | head("LDAExample: an example LDA app for plain text data.") 99 | opt[Int]("k") 100 | .text(s"number of topics. default: ${defaultParams.k}") 101 | .action((x, c) => c.copy(k = x)) 102 | opt[Int]("maxIterations") 103 | .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}") 104 | .action((x, c) => c.copy(maxIterations = x)) 105 | opt[Int]("maxInnerIterations") 106 | .text(s"number of inner iterations of learning. default: ${defaultParams.maxInnerIterations}") 107 | .action((x, c) => c.copy(maxInnerIterations = x)) 108 | opt[Double]("docConcentration") 109 | .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." + 110 | s" default: ${defaultParams.docConcentration}") 111 | .action((x, c) => c.copy(docConcentration = x)) 112 | opt[Double]("topicConcentration") 113 | .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." 
+ 114 | s" default: ${defaultParams.topicConcentration}") 115 | .action((x, c) => c.copy(topicConcentration = x)) 116 | opt[Int]("vocabSize") 117 | .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" + 118 | s" default: ${defaultParams.vocabSize}") 119 | .action((x, c) => c.copy(vocabSize = x)) 120 | opt[String]("stopwordFile") 121 | .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." + 122 | s" default: ${defaultParams.stopwordFile}") 123 | .action((x, c) => c.copy(stopwordFile = x)) 124 | opt[String]("checkpointDir") 125 | .text(s"Directory for checkpointing intermediate results." + 126 | s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." + 127 | s" default: ${defaultParams.checkpointDir}") 128 | .action((x, c) => c.copy(checkpointDir = Some(x))) 129 | opt[Int]("checkpointInterval") 130 | .text(s"Iterations between each checkpoint. Only used if checkpointDir is set." + 131 | s" default: ${defaultParams.checkpointInterval}") 132 | .action((x, c) => c.copy(checkpointInterval = x)) 133 | opt[String]("optimizer") 134 | .text(s"available optimizer are online and gibbs, default: ${defaultParams.optimizer}") 135 | .action((x, c) => c.copy(optimizer = x)) 136 | opt[String]("gibbs.sampler") 137 | .text(s"sampler for gibbs optimizer, available options are alias, sparse, light and fast, default: ${defaultParams.gibbsSampler}") 138 | .action((x, c) => c.copy(gibbsSampler = x)) 139 | opt[Double]("gibbs.alphaAS") 140 | .text(s"alphaAS for gibbs optimizer, default: ${defaultParams.gibbsAlphaAS}") 141 | .action((x, c) => c.copy(gibbsAlphaAS = x)) 142 | opt[Boolean]("gibbs.printPerplexity") 143 | .text(s"print perplexity for gibbs optimizer, default: ${defaultParams.gibbsPrintPerplexity}") 144 | .action((x, c) => c.copy(gibbsPrintPerplexity = x)) 145 | opt[String]("gibbs.edgePartitioner") 146 | .text(s"edge partitioner for gibbs optimizer, available options are none, degree, and even, default: ${defaultParams.gibbsEdgePartitioner}}") 147 | .action((x, c) => c.copy(gibbsEdgePartitioner = x)) 148 | opt[Int]("partitions") 149 | .text(s"Minimum edge partitions, default: ${defaultParams.partitions}") 150 | .action((x, c) => c.copy(partitions = x)) 151 | opt[String]("logLevel") 152 | .text(s"Log level, default: ${defaultParams.logLevel}") 153 | .action((x, c) => c.copy(logLevel = x)) 154 | opt[String]("psMasterAddr") 155 | .text(s"psMaster address, default: ${defaultParams.psMasterAddr}") 156 | .action((x, c) => c.copy(psMasterAddr = x)) 157 | arg[String]("...") 158 | .text("input paths (directories) to plain text corpora." 
+ 159 | " Each text file line should hold 1 document.") 160 | .unbounded() 161 | .required() 162 | .action((x, c) => c.copy(input = c.input :+ x)) 163 | } 164 | 165 | parser.parse(args, defaultParams).map { params => 166 | run(params) 167 | }.getOrElse { 168 | parser.showUsageAsError 169 | sys.exit(1) 170 | } 171 | } 172 | 173 | // private def createOptimizer(params: Params, lineRdd: RDD[Int], columnRdd: RDD[Int]):LDAOptimizer = { 174 | private def createOptimizer(params: Params):LDAOptimizer = { 175 | params.optimizer match { 176 | case "online" => val optimizer = new OnlineLDAOptimizer 177 | optimizer 178 | case "gibbs" => 179 | val optimizer = new GibbsLDAOptimizer 180 | optimizer.setSampler(params.gibbsSampler) 181 | optimizer.printPerplexity = params.gibbsPrintPerplexity 182 | optimizer.edgePartitioner = params.gibbsEdgePartitioner 183 | optimizer.setAlphaAS(params.gibbsAlphaAS.toFloat) 184 | optimizer 185 | case _ => 186 | throw new IllegalArgumentException(s"available optimizers are em, online and gibbs, but got ${params.optimizer}") 187 | } 188 | } 189 | 190 | /** 191 | * run LDA 192 | * @param params 193 | */ 194 | private def run(params: Params) { 195 | val conf = new SparkConf().setAppName(s"LDAExample with $params") 196 | val sc = new SparkContext(conf) 197 | 198 | val logLevel = Level.toLevel(params.logLevel, Level.INFO) 199 | Logger.getRootLogger.setLevel(logLevel) 200 | println(s"Setting log level to $logLevel") 201 | 202 | // Load documents, and prepare them for LDA. 203 | val preprocessStart = System.nanoTime() 204 | val (corpus, actualCorpusSize, vocabArray, actualNumTokens) = 205 | preprocess(sc, params.input, params.vocabSize, params.partitions, params.stopwordFile) 206 | val actualVocabSize = vocabArray.size 207 | val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9 208 | 209 | println() 210 | println(s"Corpus summary:") 211 | println(s"\t Training set size: $actualCorpusSize documents") 212 | println(s"\t Vocabulary size: $actualVocabSize terms") 213 | println(s"\t Training set size: $actualNumTokens tokens") 214 | println(s"\t Preprocessing time: $preprocessElapsed sec") 215 | println() 216 | 217 | // Run LDA. 218 | val lda = new LDA() 219 | lda.setK(params.k) 220 | .setMaxIterations(params.maxIterations) 221 | .setDocConcentration(params.docConcentration) 222 | .setTopicConcentration(params.topicConcentration) 223 | .setCheckpointInterval(params.checkpointInterval) 224 | .setOptimizer(createOptimizer(params)) 225 | 226 | if (params.checkpointDir.nonEmpty) { 227 | sc.setCheckpointDir(params.checkpointDir.get) 228 | } 229 | 230 | val startTime = System.nanoTime() 231 | val ldaModel = lda.run(corpus) 232 | val elapsed = (System.nanoTime() - startTime) / 1e9 233 | 234 | println(s"Finished training LDA model using ${lda.getOptimizer.getClass.getName}") 235 | 236 | // Print the topics, showing the top-weighted terms for each topic. 237 | val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) 238 | 239 | val topics = topicIndices.map { case (terms, termWeights) => 240 | terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) } 241 | } 242 | println(s"${params.k} topics:") 243 | topics.zipWithIndex.foreach { case (topic, i) => 244 | println(s"TOPIC $i") 245 | topic.foreach { case (term, weight) => 246 | println(s"$term\t$weight") 247 | } 248 | println() 249 | } 250 | sc.stop() 251 | } 252 | 253 | /** 254 | * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors. 
255 | * @return (corpus, vocabulary as array, total token count in corpus) 256 | */ 257 | private def preprocess( 258 | sc: SparkContext, 259 | paths: Seq[String], 260 | vocabSize: Int, 261 | partitions:Int, 262 | stopwordFile: String): (RDD[(Long, Vector)], Long, Array[String], Long) = { 263 | 264 | // Get dataset of document texts 265 | // One document per line in each text file. If the input consists of many small files, 266 | // this can result in a large number of small partitions, which can degrade performance. 267 | // In this case, consider using coalesce() to create fewer, larger partitions. 268 | val textRDD: RDD[String] = sc.textFile(paths.mkString(","), Math.max(1, partitions)).coalesce(partitions) 269 | 270 | 271 | // Split text into words 272 | val tokenizer = new SimpleTokenizer(sc, stopwordFile) 273 | val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => 274 | id -> tokenizer.getWords(text) 275 | } 276 | tokenized.cache() 277 | 278 | // Counts words: RDD[(word, wordCount)] 279 | val wordCounts: RDD[(String, Long)] = tokenized 280 | .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } 281 | .reduceByKey(_ + _) 282 | wordCounts.cache() 283 | val fullVocabSize = wordCounts.count() 284 | // Select vocab 285 | // (vocab: Map[word -> id], total tokens after selecting vocab) 286 | val (vocab: Map[String, Int], selectedTokenCount: Long) = { 287 | val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { 288 | // Use all terms 289 | wordCounts.collect().sortBy(-_._2) 290 | } else { 291 | // Sort terms to select vocab 292 | wordCounts.sortBy(_._2, ascending = false).take(vocabSize) 293 | } 294 | (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) 295 | } 296 | 297 | val documents = tokenized.map { case (id, tokens) => 298 | // Filter tokens by vocabulary, and create word count vector representation of document. 299 | val wc = new mutable.HashMap[Int, Int]() 300 | tokens.foreach { term => 301 | if (vocab.contains(term)) { 302 | val termIndex = vocab(term) 303 | wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 304 | } 305 | } 306 | val indices = wc.keys.toArray.sorted 307 | val values = indices.map(i => wc(i).toDouble) 308 | 309 | val sb = Vectors.sparse(vocab.size, indices, values) 310 | (id, sb) 311 | }.filter(_._2.asInstanceOf[SparseVector].values.length > 0).cache 312 | val corpusSize = documents.count 313 | tokenized.unpersist(false) 314 | 315 | val vocabArray = new Array[String](vocab.size) 316 | vocab.foreach { case (term, i) => vocabArray(i) = term } 317 | 318 | (documents, corpusSize, vocabArray, selectedTokenCount) 319 | } 320 | } 321 | 322 | /** 323 | * Simple Tokenizer. 324 | * 325 | * TODO: Formalize the interface, and make this a public class in mllib.feature 326 | */ 327 | private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { 328 | 329 | private val stopwords: Set[String] = if (stopwordFile.isEmpty) { 330 | Set.empty[String] 331 | } else { 332 | val stopwordText = sc.textFile(stopwordFile).collect() 333 | stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet 334 | } 335 | 336 | // Matches sequences of Unicode letters 337 | private val allWordRegex = "^(\\p{L}*)$".r 338 | 339 | // Ignore words shorter than this length. 340 | private val minWordLength = 3 341 | 342 | def getWords(text: String): IndexedSeq[String] = { 343 | 344 | val words = new mutable.ArrayBuffer[String]() 345 | 346 | // Use Java BreakIterator to tokenize text into words. 
347 | val wb = BreakIterator.getWordInstance 348 | wb.setText(text) 349 | 350 | // current,end index start,end of each word 351 | var current = wb.first() 352 | var end = wb.next() 353 | while (end != BreakIterator.DONE) { 354 | // Convert to lowercase 355 | val word: String = text.substring(current, end).toLowerCase 356 | // Remove short words and strings that aren't only letters 357 | word match { 358 | case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => 359 | words += w 360 | case _ => 361 | } 362 | 363 | current = end 364 | try { 365 | end = wb.next() 366 | } catch { 367 | case e: Exception => 368 | // Ignore remaining text in line. 369 | // This is a known bug in BreakIterator (for some Java versions), 370 | // which fails when it sees certain characters. 371 | end = BreakIterator.DONE 372 | } 373 | } 374 | words 375 | } 376 | 377 | } 378 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/topicModeling/LDAModel.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.mllib.topicModeling 19 | 20 | import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} 21 | import org.apache.spark.annotation.Experimental 22 | import org.apache.spark.graphx.{EdgeContext, Graph, VertexId} 23 | import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors} 24 | import org.apache.spark.rdd.RDD 25 | import scala.collection.mutable.ArrayBuffer 26 | 27 | /** 28 | * :: Experimental :: 29 | * 30 | * Latent Dirichlet Allocation (LDA) model. 31 | * 32 | * This abstraction permits for different underlying representations, 33 | * including local and distributed data structures. 34 | 35 | * Adapted from MLlib LDAModel implementation 36 | */ 37 | @Experimental 38 | abstract class LDAModel { 39 | 40 | /** Number of topics */ 41 | def k: Int 42 | 43 | /** Vocabulary size (number of terms or terms in the vocabulary) */ 44 | def vocabSize: Int 45 | 46 | /** 47 | * Inferred topics, where each topic is represented by a distribution over terms. 48 | * This is a matrix of size vocabSize x k, where each column is a topic. 49 | * No guarantees are given about the ordering of the topics. 50 | */ 51 | def topicsMatrix: Matrix 52 | 53 | /** 54 | * Return the topics described by weighted terms. 55 | * 56 | * This limits the number of terms per topic. 57 | * This is approximate; it may not return exactly the top-weighted terms for each topic. 58 | * To get a more precise set of top terms, increase maxTermsPerTopic. 59 | * 60 | * @param maxTermsPerTopic Maximum number of terms to collect for each topic. 
61 | * @return Array over topics. Each topic is represented as a pair of matching arrays: 62 | * (term indices, term weights in topic). 63 | * Each topic's terms are sorted in order of decreasing weight. 64 | */ 65 | def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] 66 | 67 | /** 68 | 69 | * Return the topics described by weighted terms. 70 | * 71 | * WARNING: If vocabSize and k are large, this can return a large object! 72 | * 73 | * @return Array over topics. Each topic is represented as a pair of matching arrays: 74 | * (term indices, term weights in topic). 75 | * Each topic's terms are sorted in order of decreasing weight. 76 | */ 77 | def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) 78 | 79 | /* TODO (once LDA can be trained with Strings or given a dictionary) 80 | * Return the topics described by weighted terms. 81 | * 82 | * This is similar to [[describeTopics()]] but returns String values for terms. 83 | * If this model was trained using Strings or was given a dictionary, then this method returns 84 | * terms as text. Otherwise, this method returns terms as term indices. 85 | * 86 | * This limits the number of terms per topic. 87 | * This is approximate; it may not return exactly the top-weighted terms for each topic. 88 | * To get a more precise set of top terms, increase maxTermsPerTopic. 89 | * 90 | * @param maxTermsPerTopic Maximum number of terms to collect for each topic. 91 | * @return Array over topics. Each topic is represented as a pair of matching arrays: 92 | * (terms, term weights in topic) where terms are either the actual term text 93 | * (if available) or the term indices. 94 | * Each topic's terms are sorted in order of decreasing weight. 95 | */ 96 | // def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])] 97 | 98 | /* TODO (once LDA can be trained with Strings or given a dictionary) 99 | * Return the topics described by weighted terms. 100 | * 101 | * This is similar to [[describeTopics()]] but returns String values for terms. 102 | * If this model was trained using Strings or was given a dictionary, then this method returns 103 | * terms as text. Otherwise, this method returns terms as term indices. 104 | * 105 | * WARNING: If vocabSize and k are large, this can return a large object! 106 | * 107 | * @return Array over topics. Each topic is represented as a pair of matching arrays: 108 | * (terms, term weights in topic) where terms are either the actual term text 109 | * (if available) or the term indices. 110 | * Each topic's terms are sorted in order of decreasing weight. 111 | */ 112 | // def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] = 113 | // describeTopicsAsStrings(vocabSize) 114 | 115 | /* TODO 116 | * Compute the log likelihood of the observed tokens, given the current parameter estimates: 117 | * log P(docs | topics, topic distributions for docs, alpha, eta) 118 | * 119 | * Note: 120 | * - This excludes the prior. 121 | * - Even with the prior, this is NOT the same as the data log likelihood given the 122 | * hyperparameters. 123 | * 124 | * @param documents RDD of documents, which are term (word) count vectors paired with IDs. 125 | * The term count vectors are "bags of words" with a fixed-size vocabulary 126 | * (where the vocabulary size is the length of the vector). 127 | * This must use the same vocabulary (ordering of term counts) as in training. 128 | * Document IDs must be unique and >= 0. 
129 | * @return Estimated log likelihood of the data under this model 130 | */ 131 | // def logLikelihood(documents: RDD[(Long, Vector)]): Double 132 | 133 | /* TODO 134 | * Compute the estimated topic distribution for each document. 135 | * This is often called 'theta' in the literature. 136 | * 137 | * @param documents RDD of documents, which are term (word) count vectors paired with IDs. 138 | * The term count vectors are "bags of words" with a fixed-size vocabulary 139 | * (where the vocabulary size is the length of the vector). 140 | * This must use the same vocabulary (ordering of term counts) as in training. 141 | * Document IDs must be unique and >= 0. 142 | * @return Estimated topic distribution for each document. 143 | * The returned RDD may be zipped with the given RDD, where each returned vector 144 | * is a multinomial distribution over topics. 145 | */ 146 | // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] 147 | 148 | } 149 | 150 | /** 151 | * :: Experimental :: 152 | * 153 | * Local LDA model. 154 | * This model stores only the inferred topics. 155 | * 156 | * @param topics Inferred topics (vocabSize x k matrix). 157 | */ 158 | @Experimental 159 | class LocalLDAModel private[topicModeling] ( 160 | private val topics: Matrix) extends LDAModel with Serializable { 161 | 162 | override def k: Int = topics.numCols 163 | 164 | override def vocabSize: Int = topics.numRows 165 | 166 | override def topicsMatrix: Matrix = topics 167 | 168 | override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { 169 | val brzTopics = topics.toBreeze.toDenseMatrix 170 | Range(0, k).map { topicIndex => 171 | val topic = normalize(brzTopics(::, topicIndex), 1.0) 172 | val (termWeights, terms) = 173 | topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip 174 | (terms.toArray, termWeights.toArray) 175 | }.toArray 176 | } 177 | 178 | // TODO 179 | // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? 180 | 181 | // TODO: 182 | // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? 183 | 184 | } 185 | 186 | /** 187 | * :: Experimental :: 188 | * 189 | * Online LDA Model with an interface supporting prediction 190 | * on document distribution 191 | * @param topics Inferred topics (vocabSize x k matrix). 192 | */ 193 | @Experimental 194 | class OnlineLDAModel( 195 | private val topics: Matrix, 196 | private val alpha: Double, 197 | private val gammaShape: Double) extends LDAModel with Serializable { 198 | 199 | override def k: Int = topics.numCols 200 | 201 | override def vocabSize: Int = topics.numRows 202 | 203 | // vocabSize x k, where each column is a topic. 204 | override def topicsMatrix: Matrix = topics 205 | 206 | override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { 207 | val brzTopics = topics.toBreeze.toDenseMatrix 208 | Range(0, k).map { topicIndex => 209 | val topic = normalize(brzTopics(::, topicIndex), 1.0) 210 | val (termWeights, terms) = 211 | topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip 212 | (terms.toArray, termWeights.toArray) 213 | }.toArray 214 | } 215 | 216 | /** 217 | * For each document in the training set, return the distribution over topics for that document 218 | * ("theta_doc"). 
219 | * 220 | * @return RDD of (document ID, topic distribution) pairs 221 | */ 222 | def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { 223 | val (_, _, gammaArray) = OnlineLDAOptimizer.inference( 224 | k, vocabSize, (topics.transpose).toBreeze.toDenseMatrix, alpha, gammaShape, documents) 225 | val result = gammaArray.map(p => { 226 | (p._1, Vectors.fromBreeze(p._2)) 227 | }) 228 | result 229 | } 230 | 231 | /** 232 | * For each document in the training set, return the distribution over topics for that document 233 | * ("theta_doc"). 234 | * 235 | * @return RDD of (document ID, topic distribution) pairs 236 | */ 237 | def predict(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { 238 | topicDistributions(documents) 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/topicModeling/LDAOptimizer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.mllib.topicModeling 19 | 20 | import org.apache.spark.annotation.Experimental 21 | import org.apache.spark.mllib.linalg.Vector 22 | import org.apache.spark.rdd.RDD 23 | 24 | /** 25 | * :: Experimental :: 26 | * 27 | * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can 28 | * hold optimizer-specific parameters for users to set. 29 | */ 30 | @Experimental 31 | trait LDAOptimizer { 32 | 33 | /* 34 | DEVELOPERS NOTE: 35 | 36 | An LDAOptimizer contains an algorithm for LDA and performs the actual computation, which 37 | stores internal data structure (Graph or Matrix) and other parameters for the algorithm. 38 | The interface is isolated to improve the extensibility of LDA. 39 | */ 40 | 41 | /** 42 | * Initializer for the optimizer. LDA passes the common parameters to the optimizer and 43 | * the internal structure can be initialized properly. 
44 | */ 45 | private[topicModeling] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer 46 | 47 | /** 48 | * run an iteration 49 | * @return 50 | */ 51 | private[topicModeling] def next(): LDAOptimizer 52 | 53 | /** 54 | * get model of current iteration 55 | * @param iterationTimes 56 | * @return 57 | */ 58 | private[topicModeling] def getLDAModel(iterationTimes: Array[Double]): LDAModel 59 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/topicModeling/OnlineHDP.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.topicModeling 2 | 3 | import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV, _} 4 | import breeze.numerics.{abs, digamma, exp, _} 5 | import org.apache.spark.mllib.linalg.{SparseVector, Vector} 6 | import org.apache.spark.rdd.RDD 7 | 8 | import scala.collection.mutable.ArrayBuffer 9 | 10 | class SuffStats( 11 | val T: Int, 12 | val Wt: Int, 13 | val m_chunksize: Int) extends Serializable { 14 | var m_var_sticks_ss = BDV.zeros[Double](T) 15 | var m_var_beta_ss = BDM.zeros[Double](T, Wt) 16 | 17 | def set_zero: Unit = { 18 | m_var_sticks_ss = BDV.zeros(T) 19 | m_var_beta_ss = BDM.zeros(T, Wt) 20 | } 21 | } 22 | 23 | object OnlineHDPOptimizer extends Serializable { 24 | val rhot_bound = 0.0 25 | 26 | def log_normalize(v: BDV[Double]): (BDV[Double], Double) = { 27 | val log_max = 100.0 28 | val max_val = v.toArray.max 29 | val log_shift = log_max - log(v.size + 1.0) - max_val 30 | val tot: Double = sum(exp(v + log_shift)) 31 | val log_norm = log(tot) - log_shift 32 | (v - log_norm, log_norm) 33 | } 34 | 35 | def log_normalize(m: BDM[Double]): (BDM[Double], BDV[Double]) = { 36 | val log_max = 100.0 37 | // get max for every row 38 | val max_val: BDV[Double] = m(*, ::).map(v => max(v)) 39 | val log_shift: BDV[Double] = log_max - log(m.cols + 1.0) - max_val 40 | 41 | val m_shift: BDM[Double] = exp(m(::, *) + log_shift) 42 | val tot: BDV[Double] = sum(m_shift(*, ::)) 43 | 44 | val log_norm: BDV[Double] = log(tot) - log_shift 45 | (m(::, *) - log_norm, log_norm) 46 | } 47 | 48 | def expect_log_sticks(m: BDM[Double]): BDV[Double] = { 49 | // """ 50 | // For stick-breaking hdp, return the E[log(sticks)] 51 | // """ 52 | val column = sum(m(::, *)) 53 | 54 | val dig_sum: BDV[Double] = digamma(column.toDenseVector) 55 | val ElogW: BDV[Double] = digamma(m(0, ::).inner) - dig_sum 56 | val Elog1_W: BDV[Double] = digamma(m(1, ::).inner) - dig_sum 57 | // 58 | val n = m.cols + 1 59 | val Elogsticks = BDV.zeros[Double](n) 60 | Elogsticks(0 until n - 1) := ElogW(0 until n - 1) 61 | val cs = accumulate(Elog1_W) 62 | 63 | Elogsticks(1 to n - 1) := Elogsticks(1 to n - 1) + cs 64 | Elogsticks 65 | } 66 | 67 | /** 68 | * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation 69 | * uses digamma which is accurate but expensive. 70 | */ 71 | private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { 72 | val rowSum = sum(alpha(breeze.linalg.*, ::)) 73 | val digAlpha = digamma(alpha) 74 | val digRowSum = digamma(rowSum) 75 | val result = digAlpha(::, breeze.linalg.*) - digRowSum 76 | result 77 | } 78 | 79 | } 80 | 81 | /** 82 | * Implemented based on the paper "Online Variational Inference for the Hierarchical Dirichlet Process" (Chong Wang, John Paisley and David M. 
Blei) 83 | */ 84 | 85 | class OnlineHDPOptimizer( 86 | val corpus: RDD[(Long, Vector)], 87 | val chunkSize: Int = 256, 88 | val m_kappa: Double = 1.0, 89 | var m_tau: Double = 64.0, 90 | val m_K: Int = 15, 91 | val m_T: Int = 150, 92 | val m_alpha: Double = 1, 93 | val m_gamma: Double = 1, 94 | val m_eta: Double = 0.01, 95 | val m_scale: Double = 1.0, 96 | val m_var_converge: Double = 0.0001, 97 | val iterations: Int = 10 98 | ) extends Serializable { 99 | 100 | 101 | val lda_alpha: Double = 1D 102 | val lda_beta: Double = 1D 103 | val m_W: Int = corpus.first()._2.size 104 | var m_D: Double = corpus.count().toDouble 105 | 106 | val m_var_sticks = BDM.zeros[Double](2, m_T - 1) // 2 * T - 1 107 | m_var_sticks(0, ::) := 1.0 108 | m_var_sticks(1, ::) := new BDV[Double]((m_T - 1 to 1 by -1).map(_.toDouble).toArray).t 109 | var m_varphi_ss: BDV[Double] = BDV.zeros[Double](m_T) // T 110 | 111 | // T * W 112 | val m_lambda: BDM[Double] = BDM.rand(m_T, m_W) * (m_D.toDouble) * 100.0 / (m_T * m_W).toDouble - m_eta 113 | 114 | // T * W 115 | val m_Elogbeta = OnlineHDPOptimizer.dirichletExpectation(m_lambda + m_eta) 116 | 117 | m_tau = m_tau + 1 118 | var m_updatect = 0 119 | var m_status_up_to_date = true 120 | 121 | val m_timestamp: BDV[Int] = BDV.zeros[Int](m_W) 122 | val m_r = collection.mutable.MutableList[Double](0) 123 | var m_lambda_sum = sum(m_lambda(*, ::)) // row sum 124 | 125 | val rhot_bound = 0.0 126 | 127 | 128 | def update(docs: RDD[(Long, Vector)]): Unit = { 129 | for (i <- 1 to iterations) { 130 | val chunk = docs 131 | update_chunk(chunk) 132 | } 133 | } 134 | 135 | 136 | def update_chunk(chunk: RDD[(Long, Vector)], update: Boolean = true): (Double, Int) = { 137 | // Find the unique words in this chunk... 138 | val unique_words = scala.collection.mutable.Map[Int, Int]() 139 | val raw_word_list = ArrayBuffer[Int]() 140 | chunk.collect().foreach(doc => { 141 | doc._2.foreachActive { case (id, count) => 142 | if (count > 0 && !unique_words.contains(id)) { 143 | unique_words += (id -> unique_words.size) 144 | raw_word_list.append(id) 145 | } 146 | } 147 | }) 148 | 149 | val word_list = raw_word_list.toList 150 | 151 | val Wt = word_list.length // length of words in these documents 152 | 153 | // ...and do the lazy updates on the necessary columns of lambda 154 | // rw = np.array([self.m_r[t] for t in self.m_timestamp[word_list]]) 155 | // self.m_lambda[:, word_list] *= np.exp(self.m_r[-1] - rw) 156 | // self.m_Elogbeta[:, word_list] = \ 157 | // sp.psi(self.m_eta + self.m_lambda[:, word_list]) - \ 158 | // sp.psi(self.m_W * self.m_eta + self.m_lambda_sum[:, np.newaxis]) 159 | 160 | 161 | val rw: BDV[Double] = new BDV(word_list.map(id => m_timestamp(id)).map(t => m_r(t)).toArray) 162 | 163 | val exprw: BDV[Double] = exp(rw.map(d => m_r.last - d)) 164 | 165 | val wordsMatrix = m_lambda(::, word_list).toDenseMatrix 166 | for (row <- 0 until wordsMatrix.rows) { 167 | wordsMatrix(row, ::) := (wordsMatrix(row, ::).t :* exprw).t 168 | } 169 | m_lambda(::, word_list) := wordsMatrix 170 | 171 | for (id <- word_list) { 172 | m_Elogbeta(::, id) := digamma(m_lambda(::, id) + m_eta) - digamma(m_lambda_sum + m_W * m_eta) 173 | } 174 | 175 | val ss = new SuffStats(m_T, Wt, chunk.count().toInt) 176 | 177 | val Elogsticks_1st: BDV[Double] = OnlineHDPOptimizer.expect_log_sticks(m_var_sticks) // global sticks 178 | 179 | // run variational inference on some new docs 180 | var score = 0.0 181 | var count = 0D 182 | chunk.collect().foreach(doc => 183 | if (doc._2.size > 0) { 184 | val doc_word_ids = 
doc._2.asInstanceOf[SparseVector].indices 185 | val doc_word_counts = doc._2.asInstanceOf[SparseVector].values 186 | val dict = unique_words.toMap 187 | val wl = doc_word_ids.toList 188 | 189 | val doc_score = doc_e_step(doc, ss, Elogsticks_1st, 190 | word_list, dict, wl, 191 | new BDV[Double](doc_word_counts), m_var_converge) 192 | count += sum(doc_word_counts) 193 | score += doc_score 194 | } 195 | ) 196 | if (update) { 197 | update_lambda(ss, word_list) 198 | } 199 | 200 | (score, count.toInt) 201 | } 202 | 203 | 204 | def update_lambda(sstats: SuffStats, word_list: List[Int]): Unit = { 205 | m_status_up_to_date = false 206 | // rhot will be between 0 and 1, and says how much to weight 207 | // the information we got from this mini-chunk. 208 | var rhot = m_scale * pow(m_tau + m_updatect, -m_kappa) 209 | if (rhot < rhot_bound) 210 | rhot = rhot_bound 211 | 212 | // Update appropriate columns of lambda based on documents. 213 | // T * Wt T * Wt T * Wt 214 | m_lambda(::, word_list) := (m_lambda(::, word_list).toDenseMatrix * (1 - rhot)) + sstats.m_var_beta_ss * rhot * m_D / sstats.m_chunksize.toDouble 215 | m_lambda_sum = (1 - rhot) * m_lambda_sum + sum(sstats.m_var_beta_ss(*, ::)) * rhot * m_D / sstats.m_chunksize.toDouble 216 | 217 | m_updatect += 1 218 | m_timestamp(word_list) := m_updatect 219 | m_r += (m_r.last + log(1 - rhot)) 220 | 221 | // T 222 | m_varphi_ss = (1.0 - rhot) * m_varphi_ss + rhot * sstats.m_var_sticks_ss * m_D.toDouble / sstats.m_chunksize.toDouble 223 | // update top level sticks 224 | // 2 * T - 1 225 | m_var_sticks(0, ::) := (m_varphi_ss(0 to m_T - 2) + 1.0).t 226 | val var_phi_sum = flipud(m_varphi_ss(1 to m_varphi_ss.length - 1)) // T - 1 227 | m_var_sticks(1, ::) := (flipud(accumulate(var_phi_sum)) + m_gamma).t 228 | 229 | } 230 | 231 | 232 | def doc_e_step(doc: (Long, Vector), 233 | ss: SuffStats, 234 | Elogsticks_1st: BDV[Double], 235 | word_list: List[Int], 236 | unique_words: Map[Int, Int], 237 | doc_word_ids: List[Int], 238 | doc_word_counts: BDV[Double], 239 | var_converge: Double): Double = { 240 | 241 | val chunkids = doc_word_ids.map(id => unique_words(id)) 242 | 243 | val Elogbeta_doc: BDM[Double] = m_Elogbeta(::, doc_word_ids).toDenseMatrix // T * Wt 244 | // very similar to the hdp equations, 2 * K - 1 245 | val v = BDM.zeros[Double](2, m_K - 1) 246 | v(0, ::) := 1.0 247 | v(1, ::) := m_alpha 248 | 249 | var Elogsticks_2nd = OnlineHDPOptimizer.expect_log_sticks(v) 250 | 251 | // back to the uniform 252 | var phi: BDM[Double] = BDM.ones[Double](doc_word_ids.size, m_K) * 1.0 / m_K.toDouble // Wt * K 253 | 254 | var likelihood = 0.0 255 | var old_likelihood = -1e200 256 | val converge = 1.0 257 | val eps = 1e-100 258 | 259 | var iter = 0 260 | val max_iter = 100 261 | 262 | var var_phi_out: BDM[Double] = BDM.ones[Double](1, 1) 263 | 264 | // not yet support second level optimization yet, to be done in the future 265 | while (iter < max_iter && (converge < 0.0 || converge > var_converge)) { 266 | 267 | // var_phi 268 | val (log_var_phi: BDM[Double], var_phi: BDM[Double]) = 269 | if (iter < 3) { 270 | val element = Elogbeta_doc.copy // T * Wt 271 | for (i <- 0 to element.rows - 1) { 272 | element(i, ::) := (element(i, ::).t :* doc_word_counts).t 273 | } 274 | var var_phi: BDM[Double] = phi.t * element.t // K * Wt * Wt * T => K * T 275 | val (log_var_phi, log_norm) = OnlineHDPOptimizer.log_normalize(var_phi) 276 | var_phi = exp(log_var_phi) 277 | (log_var_phi, var_phi) 278 | } 279 | else { 280 | val element = Elogbeta_doc.copy 281 | for (i <- 0 to 
element.rows - 1) { 282 | element(i, ::) := (element(i, ::).t :* doc_word_counts).t 283 | } 284 | val product: BDM[Double] = phi.t * element.t 285 | for (i <- 0 until product.rows) { 286 | product(i, ::) := (product(i, ::).t + Elogsticks_1st).t 287 | } 288 | 289 | var var_phi: BDM[Double] = product 290 | val (log_var_phi, log_norm) = OnlineHDPOptimizer.log_normalize(var_phi) 291 | var_phi = exp(log_var_phi) 292 | (log_var_phi, var_phi) 293 | } 294 | 295 | val (log_phi, log_norm) = 296 | // phi 297 | if (iter < 3) { 298 | phi = (var_phi * Elogbeta_doc).t 299 | val (log_phi, log_norm) = OnlineHDPOptimizer.log_normalize(phi) 300 | phi = exp(log_phi) 301 | (log_phi, log_norm) 302 | } 303 | else { 304 | // K * T T * Wt 305 | val product: BDM[Double] = (var_phi * Elogbeta_doc).t 306 | for (i <- 0 until product.rows) { 307 | product(i, ::) := (product(i, ::).t + Elogsticks_2nd).t 308 | } 309 | phi = product 310 | val (log_phi, log_norm) = OnlineHDPOptimizer.log_normalize(phi) 311 | phi = exp(log_phi) 312 | (log_phi, log_norm) 313 | } 314 | 315 | 316 | // v 317 | val phi_all = phi.copy 318 | for (i <- 0 until phi_all.cols) { 319 | phi_all(::, i) := (phi_all(::, i)) :* doc_word_counts 320 | } 321 | 322 | v(0, ::) := sum(phi_all(::, m_K - 1)) + 1.0 323 | val selected = phi_all(::, 1 until m_K) 324 | val t_sum = sum(selected(::, *)).toDenseVector 325 | val phi_cum = flipud(t_sum) 326 | v(1, ::) := (flipud(accumulate(phi_cum)) + m_alpha).t 327 | Elogsticks_2nd = OnlineHDPOptimizer.expect_log_sticks(v) 328 | 329 | likelihood = 0.0 330 | // compute likelihood 331 | // var_phi part/ C in john's notation 332 | 333 | val diff = log_var_phi.copy 334 | for (i <- 0 to diff.rows - 1) { 335 | diff(i, ::) := (Elogsticks_1st :- diff(i, ::).t).t 336 | } 337 | 338 | likelihood += sum(diff :* var_phi) 339 | 340 | // v part/ v in john's notation, john's beta is alpha here 341 | val log_alpha = log(m_alpha) 342 | likelihood += (m_K - 1) * log_alpha 343 | val dig_sum = (digamma(sum(v(::, *)))).toDenseVector 344 | val vCopy = v.copy 345 | for (i <- 0 until v.cols) { 346 | vCopy(::, i) := BDV[Double](1.0, m_alpha) - vCopy(::, i) 347 | } 348 | 349 | val dv = digamma(v) 350 | for (i <- 0 until v.rows) { 351 | dv(i, ::) := dv(i, ::) - dig_sum.t 352 | } 353 | 354 | likelihood += sum(vCopy :* dv) 355 | likelihood -= sum(lgamma(sum(v(::, *)))) - sum(lgamma(v)) 356 | 357 | // Z part 358 | val log_phiCopy = log_phi.copy 359 | for (i <- 0 until log_phiCopy.rows) { 360 | log_phiCopy(i, ::) := (Elogsticks_2nd - log_phiCopy(i, ::).t).t 361 | } 362 | likelihood += sum(log_phiCopy :* phi) 363 | 364 | // X part, the data part 365 | val Elogbeta_docCopy = Elogbeta_doc.copy 366 | for (i <- 0 until Elogbeta_docCopy.rows) { 367 | Elogbeta_docCopy(i, ::) := (Elogbeta_docCopy(i, ::).t :* doc_word_counts).t 368 | } 369 | 370 | likelihood += sum(phi.t :* (var_phi * Elogbeta_docCopy)) 371 | 372 | val converge = (likelihood - old_likelihood) / abs(old_likelihood) 373 | old_likelihood = likelihood 374 | 375 | if (converge < -0.000001) 376 | println("likelihood is decreasing!") 377 | 378 | iter += 1 379 | var_phi_out = var_phi 380 | } 381 | 382 | // update the suff_stat ss 383 | // this time it only contains information from one doc 384 | val sumPhiOut = sum(var_phi_out(::, *)) 385 | ss.m_var_sticks_ss += sumPhiOut.toDenseVector 386 | 387 | val phiCopy = phi.copy.t 388 | for (i <- 0 until phi.rows) { 389 | phiCopy(i, ::) := (phiCopy(i, ::).t :* doc_word_counts).t 390 | } 391 | 392 | val middleResult: BDM[Double] = var_phi_out.t * phiCopy 393 | for 
(i <- 0 until chunkids.size) { 394 | ss.m_var_beta_ss(::, chunkids(i)) := ss.m_var_beta_ss(::, chunkids(i)) + middleResult(::, i) 395 | } 396 | 397 | return likelihood 398 | } 399 | 400 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/topicModeling/OnlineLDAOptimizer.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.topicModeling 2 | 3 | import java.util.Random 4 | 5 | import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum, max} 6 | import breeze.numerics.{abs, digamma, exp, log, lgamma} 7 | import breeze.stats.distributions.{Gamma, RandBasis} 8 | import org.apache.spark.mllib.linalg._ 9 | import org.apache.spark.rdd.RDD 10 | import org.apache.spark.util.collection.AppendOnlyMap 11 | import scala.collection.mutable.ArrayBuffer 12 | 13 | /** 14 | * :: Experimental :: 15 | * 16 | * An Optimizer based on OnlineLDAOptimizer, that can also output the the document~topic 17 | * distribution 18 | * 19 | * An early version of the implementation was merged into MLlib (PR #4419), and several extensions (e.g., predict) are added here 20 | * 21 | */ 22 | object OnlineLDAOptimizer{ 23 | def inference( 24 | k: Int, 25 | vocabSize: Int, 26 | lambda: BDM[Double], 27 | alpha: Double, 28 | gammaShape: Double, 29 | batch: RDD[(Long, Vector)]): 30 | (BDM[Double], RDD[BDM[Double]], RDD[(Long, BDV[Double])]) = { 31 | 32 | val Elogbeta = dirichletExpectation(lambda) 33 | val expElogbeta = exp(Elogbeta) 34 | 35 | val statsAndgammaArray: (RDD[(BDM[Double], Array[(Long, BDV[Double])])]) = 36 | batch.mapPartitions { docs => 37 | val stat = BDM.zeros[Double](k, vocabSize) 38 | val gammaList = new ArrayBuffer[(Long, BDV[Double])] 39 | docs.foreach { doc => 40 | val termCounts = doc._2 41 | val (ids: List[Int], cts: Array[Double]) = termCounts match { 42 | case v: DenseVector => ((0 until v.size).toList, v.values) 43 | case v: SparseVector => (v.indices.toList, v.values) 44 | case v => throw new IllegalArgumentException("Online LDA does not support vector type " 45 | + v.getClass) 46 | } 47 | 48 | // Initialize the variational distribution q(theta|gamma) for the mini-batch 49 | var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K 50 | var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K 51 | var expElogthetad = exp(Elogthetad) // 1 * K 52 | val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids 53 | 54 | var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids 55 | var meanchange = 1D 56 | val ctsVector = new BDV[Double](cts).t // 1 * ids 57 | 58 | // Iterate between gamma and phi until convergence 59 | while (meanchange > 1e-3) { 60 | val lastgamma = gammad 61 | // 1*K 1 * ids ids * k 62 | gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha 63 | Elogthetad = digamma(gammad) - digamma(sum(gammad)) 64 | expElogthetad = exp(Elogthetad) 65 | phinorm = expElogthetad * expElogbetad + 1e-100 66 | meanchange = sum(abs(gammad - lastgamma)) / k 67 | } 68 | 69 | gammaList += Tuple2(doc._1, gammad.t) 70 | val m1 = expElogthetad.t 71 | val m2 = (ctsVector / phinorm).t.toDenseVector 72 | var i = 0 73 | while (i < ids.size) { 74 | stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i) 75 | i += 1 76 | } 77 | } 78 | Iterator((stat, gammaList.toArray)) 79 | } 80 | 81 | val stats = statsAndgammaArray.map(_._1) 82 | val gammaArray = statsAndgammaArray.flatMap(_._2) 83 | 84 | (expElogbeta, stats, gammaArray) 
85 | } 86 | 87 | /** 88 | * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation 89 | * uses digamma which is accurate but expensive. 90 | */ 91 | def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { 92 | val rowSum = sum(alpha(breeze.linalg.*, ::)) 93 | val digAlpha = digamma(alpha) 94 | val digRowSum = digamma(rowSum) 95 | val result = digAlpha(::, breeze.linalg.*) - digRowSum 96 | result 97 | } 98 | } 99 | 100 | final class OnlineLDAOptimizer extends LDAOptimizer with Serializable{ 101 | // LDA common parameters 102 | private var k: Int = 0 103 | private var corpusSize: Long = 0 104 | private var vocabSize: Int = 0 105 | 106 | /** alias for docConcentration */ 107 | private var alpha: Double = 0 108 | 109 | /** (private[lda] for debugging) Get docConcentration */ 110 | private[topicModeling] def getAlpha: Double = alpha 111 | 112 | /** alias for topicConcentration */ 113 | private var eta: Double = 0 114 | 115 | /** (private[lda] for debugging) Get topicConcentration */ 116 | private[topicModeling] def getEta: Double = eta 117 | 118 | private var randomGenerator: java.util.Random = null 119 | 120 | // Online LDA specific parameters 121 | // Learning rate is: (tau0 + t)^{-kappa} 122 | private var tau0: Double = 1024 123 | private var kappa: Double = 0.51 124 | private var miniBatchFraction: Double = 0.05 125 | 126 | // internal data structure 127 | private var docs: RDD[(Long, Vector)] = null 128 | 129 | /** Dirichlet parameter for the posterior over topics */ 130 | private var lambda: BDM[Double] = null 131 | 132 | /** (private[lda] for debugging) Get parameter for topics */ 133 | private[topicModeling] def getLambda: BDM[Double] = lambda 134 | 135 | /** Current iteration (count of invocations of [[next()]]) */ 136 | private var iteration: Int = 0 137 | private var gammaShape: Double = 100 138 | 139 | /** 140 | * A (positive) learning parameter that downweights early iterations. Larger values make early 141 | * iterations count less. 142 | */ 143 | def getTau0: Double = this.tau0 144 | 145 | /** 146 | * A (positive) learning parameter that downweights early iterations. Larger values make early 147 | * iterations count less. 148 | * Default: 1024, following the original Online LDA paper. 149 | */ 150 | def setTau0(tau0: Double): this.type = { 151 | require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") 152 | this.tau0 = tau0 153 | this 154 | } 155 | 156 | /** 157 | * Learning rate: exponential decay rate 158 | */ 159 | def getKappa: Double = this.kappa 160 | 161 | /** 162 | * Learning rate: exponential decay rate---should be between 163 | * (0.5, 1.0] to guarantee asymptotic convergence. 164 | * Default: 0.51, based on the original Online LDA paper. 165 | */ 166 | def setKappa(kappa: Double): this.type = { 167 | require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa") 168 | this.kappa = kappa 169 | this 170 | } 171 | 172 | /** 173 | * Mini-batch fraction, which sets the fraction of document sampled and used in each iteration 174 | */ 175 | def getMiniBatchFraction: Double = this.miniBatchFraction 176 | 177 | /** 178 | * Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in 179 | * each iteration. 180 | * 181 | * Note that this should be adjusted in synch with 182 | * so the entire corpus is used. Specifically, set both so that 183 | * maxIterations * miniBatchFraction >= 1. 184 | * 185 | * Default: 0.05, i.e., 5% of total documents. 
186 | */ 187 | def setMiniBatchFraction(miniBatchFraction: Double): this.type = { 188 | require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0, 189 | s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction") 190 | this.miniBatchFraction = miniBatchFraction 191 | this 192 | } 193 | 194 | /** 195 | * (private[lda]) 196 | * Set the Dirichlet parameter for the posterior over topics. 197 | * This is only used for testing now. In the future, it can help support training stop/resume. 198 | */ 199 | private[topicModeling] def setLambda(lambda: BDM[Double]): this.type = { 200 | this.lambda = lambda 201 | this 202 | } 203 | 204 | /** 205 | * (private[lda]) 206 | * Used for random initialization of the variational parameters. 207 | * Larger value produces values closer to 1.0. 208 | * This is only used for testing currently. 209 | */ 210 | private[topicModeling] def setGammaShape(shape: Double): this.type = { 211 | this.gammaShape = shape 212 | this 213 | } 214 | 215 | override def initialize(docs: RDD[(Long, Vector)], lda: LDA): this.type = { 216 | this.k = lda.getK 217 | this.corpusSize = docs.count() 218 | this.vocabSize = docs.first()._2.size 219 | this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration 220 | this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration 221 | this.randomGenerator = new Random(lda.getSeed) 222 | 223 | this.docs = docs 224 | this.lambda = getGammaMatrix(k, vocabSize) 225 | 226 | this.iteration = 0 227 | this 228 | } 229 | 230 | def initialize(corpusSize: Long, vocabSize: Int, lda: LDA): this.type ={ 231 | this.k = lda.getK 232 | this.corpusSize = corpusSize 233 | this.vocabSize = vocabSize 234 | this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration 235 | this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration 236 | this.randomGenerator = new Random(lda.getSeed) 237 | 238 | // Initialize the variational distribution q(beta|lambda) 239 | this.lambda = getGammaMatrix(k, vocabSize) 240 | this.iteration = 0 241 | this 242 | } 243 | 244 | override def next(): this.type = { 245 | val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong()) 246 | submitMiniBatch(batch) 247 | } 248 | 249 | /** 250 | * Submit a subset (like 1%, decide by the miniBatchFraction) of the corpus to the Online LDA 251 | * model, and it will update the topic distribution adaptively for the terms appearing in the 252 | * subset. 253 | */ 254 | private[topicModeling] def submitMiniBatch(batch: RDD[(Long, Vector)]): this.type = { 255 | if (batch.isEmpty()) return this 256 | iteration += 1 257 | 258 | val (expElogbeta, stats, _) = 259 | OnlineLDAOptimizer.inference(k, vocabSize, lambda, alpha, gammaShape, batch) 260 | val statsSum: BDM[Double] = stats.reduce(_ += _) 261 | val batchResult = statsSum :* expElogbeta 262 | 263 | // Note that this is an optimization to avoid batch.count 264 | update(batchResult, iteration, batch.count().toInt) 265 | this 266 | } 267 | 268 | override def getLDAModel(iterationTimes: Array[Double]): LDAModel = { 269 | new OnlineLDAModel(Matrices.fromBreeze(lambda).transpose, this.alpha, this.gammaShape) 270 | } 271 | 272 | /** 273 | * Update lambda based on the batch submitted. batchSize can be different for each iteration. 274 | */ 275 | private[topicModeling] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = { 276 | // weight of the mini-batch. 
277 | val weight = math.pow(getTau0 + iter, -getKappa) 278 | 279 | // Update lambda based on documents. 280 | lambda = lambda * (1 - weight) + 281 | (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight 282 | } 283 | 284 | /** 285 | * Get a random matrix to initialize lambda 286 | */ 287 | private def getGammaMatrix(row: Int, col: Int): BDM[Double] = { 288 | val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister( 289 | randomGenerator.nextLong())) 290 | val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)(randBasis) 291 | val temp = gammaRandomGenerator.sample(row * col).toArray 292 | new BDM[Double](col, row, temp).t 293 | } 294 | 295 | def perplexity(docs: RDD[(Long, Vector)]): Double = { 296 | val alphaVector = Vectors.dense(Array.fill(k)(alpha)) 297 | val brzAlpha = alphaVector.toBreeze.toDenseVector 298 | 299 | val Elogbeta = OnlineLDAOptimizer.dirichletExpectation(lambda) 300 | 301 | val (_, _, gammaArray) = OnlineLDAOptimizer.inference(k, vocabSize, lambda, alpha, gammaShape, docs) 302 | 303 | var score = docs.join(gammaArray).map { case (id: Long, (termCounts: Vector, gammad: BDV[Double])) => 304 | var docScore = 0.0D 305 | 306 | val Elogthetad: BDV[Double] = digamma(gammad) - digamma(sum(gammad)) 307 | 308 | // E[log p(doc | theta, beta)] 309 | termCounts.foreachActive { case (idx, count) => 310 | val x = Elogthetad + Elogbeta(::, idx) 311 | val a = max(x) 312 | docScore += count * (a + log(sum(exp(x :- a)))) 313 | } 314 | // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector 315 | docScore += sum((brzAlpha - gammad) :* Elogthetad) 316 | docScore += sum(lgamma(gammad) - lgamma(brzAlpha)) 317 | docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) 318 | 319 | docScore 320 | }.sum() 321 | 322 | // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar 323 | score += sum((eta - lambda) :* Elogbeta) 324 | score += sum(lgamma(lambda) - lgamma(eta)) 325 | 326 | val sumEta = eta * vocabSize 327 | score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) 328 | 329 | math.exp(-1 * score / vocabSize) 330 | } 331 | } 332 | 333 | --------------------------------------------------------------------------------
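
The LDAModel classes above expose the learned topics as a vocabSize x k matrix, and OnlineLDAModel can additionally predict per-document topic distributions. Below is a minimal sketch of exercising those public methods on a toy corpus; the driver object name, the hand-built topics matrix, the alpha and gammaShape values, and the local[2] master are all made up for illustration, since in practice the model would come from a trained optimizer rather than being constructed by hand.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.topicModeling.OnlineLDAModel

// Hypothetical driver object; all values below are made up for illustration.
object InspectTopicsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("lda-model-sketch").setMaster("local[2]"))

    // A vocabSize x k (4 x 2) topics matrix, column-major: the first 4 values are topic 0.
    // In practice this matrix would come from a trained optimizer via getLDAModel.
    val topics = Matrices.dense(4, 2, Array(
      10.0, 5.0, 1.0, 1.0, // topic 0
      1.0, 1.0, 8.0, 6.0)) // topic 1

    val model = new OnlineLDAModel(topics, /* alpha */ 0.5, /* gammaShape */ 100.0)

    // Top-weighted terms per topic: (term indices, weights), weights in decreasing order.
    model.describeTopics(maxTermsPerTopic = 2).zipWithIndex.foreach {
      case ((terms, weights), i) =>
        println(s"topic $i: " + terms.zip(weights).mkString(", "))
    }

    // Per-document topic distributions ("theta_doc") for a tiny bag-of-words corpus.
    val docs = sc.parallelize(Seq(
      (0L, Vectors.sparse(4, Array(0, 1), Array(3.0, 1.0))),
      (1L, Vectors.sparse(4, Array(2, 3), Array(2.0, 2.0)))))
    model.predict(docs).collect().foreach { case (id, theta) =>
      println(s"doc $id -> $theta")
    }

    sc.stop()
  }
}

Note that predict delegates to topicDistributions, which returns the per-document variational gamma produced by inference, so the resulting vectors are Dirichlet parameters over topics rather than normalized probabilities.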
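
OnlineLDAOptimizer.inference is the shared E-step used by submitMiniBatch, perplexity, and OnlineLDAModel.topicDistributions: given a k x vocabSize lambda it returns exp(E[log beta]), per-partition sufficient statistics, and the per-document variational gamma. The sketch below calls it directly with a random lambda standing in for trained topics; the object name, hyperparameter values, and corpus are assumptions chosen only for illustration.

import breeze.linalg.{DenseMatrix => BDM}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.topicModeling.OnlineLDAOptimizer

// Hypothetical driver object; lambda and the mini-batch are made up for illustration.
object InferenceSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("online-lda-inference-sketch").setMaster("local[2]"))

    val k = 2
    val vocabSize = 4
    // A random k x vocabSize lambda stands in for a (partially) trained topic matrix.
    val lambda = BDM.rand(k, vocabSize) + 1.0

    val batch = sc.parallelize(Seq(
      (0L, Vectors.sparse(vocabSize, Array(0, 1), Array(2.0, 1.0))),
      (1L, Vectors.sparse(vocabSize, Array(2, 3), Array(1.0, 3.0)))))

    // E-step only: exp(E[log beta]), per-partition sufficient statistics, and the
    // per-document variational gamma; alpha = 0.5 and gammaShape = 100 are arbitrary here.
    val (expElogbeta, _, gammaArray) =
      OnlineLDAOptimizer.inference(k, vocabSize, lambda, 0.5, 100.0, batch)

    gammaArray.collect().foreach { case (id, gamma) => println(s"doc $id: gamma = $gamma") }
    println(s"expElogbeta is ${expElogbeta.rows} x ${expElogbeta.cols}")

    sc.stop()
  }
}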
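
The M-step in OnlineLDAOptimizer.update blends the previous lambda with the mini-batch sufficient statistics using the decaying weight rho_t = (tau0 + t)^(-kappa), after scaling the statistics by corpusSize / batchSize so the mini-batch acts as a proxy for the whole corpus. The following standalone Breeze sketch reproduces that single blending step with made-up sizes and values; per the Scaladoc above, kappa in (0.5, 1.0] is what guarantees asymptotic convergence, and a larger tau0 down-weights early mini-batches more strongly.

import breeze.linalg.{DenseMatrix => BDM}

// Standalone illustration of the blending step in OnlineLDAOptimizer.update;
// all sizes and values below are made up.
object LambdaUpdateSketch {
  def main(args: Array[String]): Unit = {
    val k = 2
    val vocabSize = 4
    val corpusSize = 1000.0        // D: total number of documents
    val eta = 1.0 / k              // topic concentration (defaults to 1/k above)
    val tau0 = 1024.0
    val kappa = 0.51

    var lambda = BDM.rand(k, vocabSize) // previous variational topics
    val stat = BDM.rand(k, vocabSize)   // sufficient statistics from one mini-batch
    val batchSize = 50                  // documents in that mini-batch
    val iter = 1

    // rho_t = (tau0 + t)^(-kappa): larger tau0 down-weights early iterations more.
    val weight = math.pow(tau0 + iter, -kappa)

    // Same blending as update(): scale the mini-batch statistics up to the full corpus,
    // add the prior eta, and mix with the previous lambda.
    lambda = lambda * (1 - weight) +
      (stat * (corpusSize / batchSize.toDouble) + eta) * weight

    println(s"rho_t = $weight, lambda(0, 0) = ${lambda(0, 0)}")
  }
}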
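
OnlineHDPOptimizer is constructed directly from an RDD corpus, with per-document and corpus-level truncations m_K and m_T and learning-rate parameters m_kappa and m_tau; update runs the configured number of iterations, each treating the corpus as one chunk. The toy sketch below uses small truncation levels and a made-up six-term vocabulary purely to keep the example cheap; the driver object name and all values are illustrative assumptions.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.topicModeling.OnlineHDPOptimizer

// Hypothetical driver object; the corpus and truncation levels are made up.
object OnlineHDPSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("online-hdp-sketch").setMaster("local[2]"))

    // Toy bag-of-words corpus: (doc id, term-count vector) over a 6-term vocabulary.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.sparse(6, Array(0, 1, 2), Array(2.0, 1.0, 1.0))),
      (1L, Vectors.sparse(6, Array(3, 4, 5), Array(1.0, 3.0, 1.0))),
      (2L, Vectors.sparse(6, Array(0, 2, 4), Array(1.0, 1.0, 2.0)))))

    // Small truncation levels keep the toy run cheap; the defaults are m_K = 15, m_T = 150.
    val hdp = new OnlineHDPOptimizer(corpus, m_K = 3, m_T = 5, iterations = 2)

    // Each iteration treats the corpus as one chunk: a variational E-step per document,
    // then a weighted update of the global sticks and the T x W lambda matrix.
    hdp.update(corpus)

    println(s"lambda is ${hdp.m_lambda.rows} x ${hdp.m_lambda.cols}")
    sc.stop()
  }
}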