├── LICENSE
├── README.md
├── pom.xml
└── src
└── main
└── scala
└── org
└── apache
└── spark
└── mllib
└── topicModeling
├── GibbsLDAOptimizer.scala
├── GibbsLDASampler.scala
├── LDA.scala
├── LDAExample.scala
├── LDAModel.scala
├── LDAOptimizer.scala
├── OnlineHDP.scala
└── OnlineLDAOptimizer.scala
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Topic Modeling on Apache Spark
 2 | This package contains a set of distributed topic modeling algorithms implemented on Spark, including:
3 |
 4 | - **Online LDA**: an early version of this implementation was merged into MLlib (PR #4419); several extensions (e.g., predict) have been added here
5 |
 6 | - **Gibbs sampling LDA**: adapted from Spark PRs (#1405 and #4807) and JIRA SPARK-5556 (https://github.com/witgo/spark/tree/lda_Gibbs, https://github.com/EntilZha/spark/tree/LDA-Refactor, https://github.com/witgo/zen/tree/lda_opt/ml, etc.), with several extensions added (e.g., support for the MLlib interface, predict, and in-place state updates)
7 |
8 | - **Online HDP (hierarchical Dirichlet process)**: implemented based on the paper "Online Variational Inference for the Hierarchical Dirichlet Process" (Chong Wang, John Paisley and David M. Blei)
9 |
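10 | ## Building
11 | 
12 | The build is Maven-based; `mvn package` produces an assembly jar (`jar-with-dependencies`) through the maven-assembly-plugin configured in `pom.xml`.
13 | 
14 | ## Usage
15 | 
16 | A minimal training sketch is shown below. The `LDA` entry point mirrors the MLlib API, so the builder methods used here (`setK`, `setMaxIterations`, `setOptimizer`, `run`) are assumed from that convention and the input path is only a placeholder; see `LDAExample.scala` for the exact invocation shipped with this package.
17 | 
18 | ```scala
19 | import org.apache.spark.{SparkConf, SparkContext}
20 | import org.apache.spark.mllib.linalg.Vectors
21 | import org.apache.spark.mllib.topicModeling.{GibbsLDAOptimizer, LDA}
22 | 
23 | object LDAQuickStart {
24 |   def main(args: Array[String]): Unit = {
25 |     val sc = new SparkContext(new SparkConf().setAppName("TopicModelingExample"))
26 | 
27 |     // Each document is (docId, term-count vector); the vector length is the vocabulary size.
28 |     val corpus = sc.textFile("data/docword.txt").zipWithIndex().map { case (line, docId) =>
29 |       (docId, Vectors.dense(line.split(" ").map(_.toDouble)))
30 |     }.cache()
31 | 
32 |     val model = new LDA()
33 |       .setK(20)
34 |       .setMaxIterations(50)
35 |       .setOptimizer(new GibbsLDAOptimizer().setSampler("alias"))
36 |       .run(corpus)
37 | 
38 |     // Print the top-10 weighted terms of each topic.
39 |     model.describeTopics(10).zipWithIndex.foreach { case ((terms, weights), topic) =>
40 |       println(s"topic $topic: " + terms.zip(weights).mkString(" "))
41 |     }
42 |     sc.stop()
43 |   }
44 | }
45 | ```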
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
  1 | <?xml version="1.0" encoding="UTF-8"?>
  2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
  3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  5 |   <modelVersion>4.0.0</modelVersion>
  6 |   <groupId>org.apache.spark.mllib</groupId>
  7 |   <artifactId>topicModel</artifactId>
  8 |   <version>1.0-SNAPSHOT</version>
  9 | 
 10 |   <properties>
 11 |     <maven.compiler.source>1.6</maven.compiler.source>
 12 |     <maven.compiler.target>1.6</maven.compiler.target>
 13 |     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
 14 |     <scala.version>2.10.4</scala.version>
 15 |   </properties>
 16 | 
 17 |   <build>
 18 |     <pluginManagement>
 19 |       <plugins>
 20 |         <plugin>
 21 |           <groupId>net.alchim31.maven</groupId>
 22 |           <artifactId>scala-maven-plugin</artifactId>
 23 |           <version>3.1.5</version>
 24 |         </plugin>
 25 |         <plugin>
 26 |           <groupId>org.apache.maven.plugins</groupId>
 27 |           <artifactId>maven-compiler-plugin</artifactId>
 28 |           <version>2.0.2</version>
 29 |         </plugin>
 30 |       </plugins>
 31 |     </pluginManagement>
 32 |     <plugins>
 33 |       <plugin>
 34 |         <groupId>net.alchim31.maven</groupId>
 35 |         <artifactId>scala-maven-plugin</artifactId>
 36 |         <executions>
 37 |           <execution>
 38 |             <id>scala-compile-first</id>
 39 |             <phase>process-resources</phase>
 40 |             <goals>
 41 |               <goal>add-source</goal>
 42 |               <goal>compile</goal>
 43 |             </goals>
 44 |           </execution>
 45 |           <execution>
 46 |             <id>scala-test-compile</id>
 47 |             <phase>process-test-resources</phase>
 48 |             <goals>
 49 |               <goal>testCompile</goal>
 50 |             </goals>
 51 |           </execution>
 52 |         </executions>
 53 |       </plugin>
 54 |       <plugin>
 55 |         <artifactId>maven-assembly-plugin</artifactId>
 56 |         <version>2.4</version>
 57 |         <configuration>
 58 |           <descriptorRefs>
 59 |             <descriptorRef>jar-with-dependencies</descriptorRef>
 60 |           </descriptorRefs>
 61 |         </configuration>
 62 |         <executions>
 63 |           <execution>
 64 |             <id>make-assembly</id>
 65 |             <phase>package</phase>
 66 |             <goals>
 67 |               <goal>single</goal>
 68 |             </goals>
 69 |           </execution>
 70 |         </executions>
 71 |       </plugin>
 72 |     </plugins>
 73 |   </build>
 74 | 
 75 |   <dependencies>
 76 |     <dependency>
 77 |       <groupId>org.scala-lang</groupId>
 78 |       <artifactId>scala-library</artifactId>
 79 |       <version>${scala.version}</version>
 80 |     </dependency>
 81 |     <dependency>
 82 |       <groupId>org.apache.spark</groupId>
 83 |       <artifactId>spark-core_2.10</artifactId>
 84 |       <version>1.3.1</version>
 85 |     </dependency>
 86 |     <dependency>
 87 |       <groupId>com.github.scopt</groupId>
 88 |       <artifactId>scopt_2.10</artifactId>
 89 |       <version>3.2.0</version>
 90 |     </dependency>
 91 |     <dependency>
 92 |       <groupId>org.apache.spark</groupId>
 93 |       <artifactId>spark-mllib_2.10</artifactId>
 94 |       <version>1.3.1</version>
 95 |     </dependency>
 96 |     <dependency>
 97 |       <groupId>org.scalanlp</groupId>
 98 |       <artifactId>breeze_2.10</artifactId>
 99 |       <version>0.11.2</version>
100 |       <exclusions>
101 |         <exclusion>
102 |           <groupId>junit</groupId>
103 |           <artifactId>junit</artifactId>
104 |         </exclusion>
105 |         <exclusion>
106 |           <groupId>org.apache.commons</groupId>
107 |           <artifactId>commons-math3</artifactId>
108 |         </exclusion>
109 |       </exclusions>
110 |     </dependency>
111 |   </dependencies>
112 | </project>
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/mllib/topicModeling/GibbsLDAOptimizer.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.mllib.topicModeling
19 |
20 | import java.util.Random
21 |
22 | import org.apache.hadoop.io.NullWritable
23 | import org.apache.hadoop.mapred.TextOutputFormat
24 | import org.apache.spark.graphx._
25 | import org.apache.spark.graphx.impl.GraphImpl
26 | import org.apache.spark.{HashPartitioner, Logging, Partitioner}
27 | import org.apache.spark.rdd.RDD
28 | import org.apache.spark.serializer.KryoRegistrator
29 | import org.apache.spark.storage.StorageLevel
30 | import org.apache.spark.SparkContext
31 | import org.apache.hadoop.fs._
32 |
33 | import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV, StorageVector,
34 | sum => brzSum, norm => brzNorm, DenseMatrix => BDM, Matrix => BM, CSCMatrix => BSM}
35 |
36 | import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV,
37 | DenseMatrix => SDM, SparseMatrix => SSM, Vector => SV, Matrix}
38 |
39 | import GibbsLDAOptimizer._
40 |
41 | import scala.reflect.ClassTag
42 | import org.json4s.DefaultFormats
43 | import org.json4s.JsonDSL._
44 | import org.json4s.jackson.JsonMethods._
45 | import org.apache.hadoop.io.{NullWritable, Text}
46 |
47 | /**
48 | *
 49 |  * Adapted from Spark PRs (#1405 and #4807) and JIRA SPARK-5556 (https://github.com/witgo/spark/tree/lda_Gibbs,
 50 |  * https://github.com/EntilZha/spark/tree/LDA-Refactor, https://github.com/witgo/zen/tree/lda_opt/ml, etc.),
 51 |  * with several extensions added (e.g., support for the MLlib interface, predict, and in-place state updates)
52 | */
53 | class GibbsLDAOptimizer private[topicModeling](
54 | private var alphaAS: Float,
55 | private var storageLevel: StorageLevel,
56 | private var sampler:GibbsLDASampler = new GibbsLDAAliasSampler,
57 | var edgePartitioner:String = "none",
58 | var printPerplexity:Boolean = false)
59 | extends LDAOptimizer with Serializable with Logging {
60 |
61 | def this() = this(0.1f, StorageLevel.MEMORY_AND_DISK)
62 |
63 | @transient private var corpus:Graph[VD, ED] = null
64 | private var alpha = 0.01f
65 | private var beta = 0.01f
66 | private var numTopics = 0
67 | private var numTerms = 0
68 | private var numTokens = 0L
69 | private var checkpointInterval = Int.MaxValue
70 |
71 | /**
72 | * Initializer for the optimizer. LDA passes the common parameters to the optimizer and
73 | * the internal structure can be initialized properly.
74 | */
75 | override def initialize(docs: RDD[(Long, SV)], lda: LDA): LDAOptimizer = {
76 | alpha = lda.getAlpha.toFloat
77 | beta = lda.getBeta.toFloat
78 | numTopics = lda.getK
79 | numTerms = docs.first()._2.size
80 | seed = lda.getSeed.toInt
81 | setCheckpointInterval(lda.getCheckpointInterval)
82 | corpus = initializeCorpus(docs, lda.getK, storageLevel, edgePartitioner)
83 | numTokens = corpus.edges.map(e => e.attr.size.toDouble).sum().toLong
84 | totalTopicCounter = collectTotalTopicCounter(corpus, numTopics, numTokens)
85 | this
86 | }
87 |
88 | private var lastSampledCorpus:Option[Graph[VD, ED]] = None
89 |
90 | private def termVertices = corpus.vertices.filter(t => t._1>=0)
91 |
92 | private def docVertices = corpus.vertices.filter(t => t._1<0)
93 |
94 | /**
 95 |    * Run one Gibbs sampling iteration over the corpus.
 96 |    * @return this optimizer, with its internal state advanced by one iteration
97 | */
98 | def next(): LDAOptimizer = {
99 |
100 | gibbsSampling()
101 |
102 | if (printPerplexity) {
103 |       println(s"Perplexity of the $innerIter-th iteration is ${perplexity}")
104 | }
105 | this
106 | }
107 |
108 | private def gibbsSampling(): Unit = {
109 | val sampledCorpus = sampler.sampleTokens(corpus, totalTopicCounter, innerIter + seed,
110 | numTokens, numTopics, numTerms, alpha, alphaAS, beta)
111 | sampledCorpus.edges.persist(storageLevel)
112 | sampledCorpus.edges.count()
113 |
114 | val counterCorpus = updateCounter(sampledCorpus, numTopics)
115 | checkpoint(counterCorpus, innerIter, checkpointInterval)
116 | counterCorpus.vertices.persist(storageLevel)
117 | counterCorpus.vertices.count()
118 |
119 | totalTopicCounter = collectTotalTopicCounter(counterCorpus, numTopics, numTokens)
120 |
121 | corpus.edges.unpersist(false)
122 | corpus.vertices.unpersist(false)
123 | lastSampledCorpus.map(_.edges.unpersist(false))
124 | lastSampledCorpus = Some(sampledCorpus)
125 | corpus = counterCorpus
126 |
127 | innerIter += 1
128 | }
129 |
130 | /**
131 |    * Get the model at the optimizer's current state.
132 |    * @param iterationTimes elapsed time of each iteration (not used by this optimizer)
133 |    * @return the trained model
134 | */
135 | override def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
136 | val ldaModel:GibbsLDAModel = saveModel(1)
137 | ldaModel
138 | }
139 |
140 | def setCheckpointInterval(cpInterval: Int): this.type = {
141 | this.checkpointInterval = cpInterval
142 | this
143 | }
144 |
145 | def getCheckpointInterval(): Int = this.checkpointInterval
146 |
147 | def setAlphaAS(alphaAS: Float): this.type = {
148 | this.alphaAS = alphaAS
149 | this
150 | }
151 |
152 | def getAlphaAS():Float = this.alphaAS
153 |
154 | def setStorageLevel(newStorageLevel: StorageLevel): this.type = {
155 | this.storageLevel = newStorageLevel
156 | this
157 | }
158 |
159 | def getStorageLevel(): StorageLevel = this.storageLevel
160 |
161 | def setSeed(newSeed: Int): this.type = {
162 | this.seed = newSeed
163 | this
164 | }
165 |
166 | def getSeed():Int = this.seed
167 |
168 | def setSampler(sampler:GibbsLDASampler): this.type = {
169 | this.sampler = sampler
170 | this
171 | }
172 |
173 | def getSampler(): GibbsLDASampler = {
174 | this.sampler
175 | }
176 |
177 | def setSampler(sampler:String): this.type = {
178 | this.sampler =
179 | sampler.toLowerCase match {
180 | case "alias" => new GibbsLDAAliasSampler
181 | case "sparse" => new GibbsLDASparseSampler
182 | case "light" => new GibbsLDALightSampler
183 | case "fast" => new GibbsLDAFastSampler
184 | case _ =>
185 |           throw new IllegalArgumentException(s"Only alias, sparse, light and fast samplers are supported, but got $sampler.")
186 | }
187 | this
188 | }
189 |
190 | def getCorpus = corpus
191 |
192 | @transient private var seed = new Random().nextInt()
193 | @transient private var innerIter = 1
194 | @transient private var totalTopicCounter: BDV[Count] = null
195 |
196 | // scalastyle:off
197 | /**
198 |    * Per-token likelihood:
199 |    * p(w) = \sum_{k}{p(k|d) p(w|k)}
200 |    *      = \sum_{k}{\frac{{n}_{kw}+{\beta}_{w}}{{n}_{k}+\bar{\beta}} \frac{{n}_{kd}+{\alpha}_{k}}{\sum_{k}{{n}_{kd}}+\bar{\alpha}}}
201 |    *      = \sum_{k}{\frac{{\alpha}_{k}{\beta}_{w} + {n}_{kw}{\alpha}_{k} + {n}_{kd}{\beta}_{w} + {n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta}} \frac{1}{\sum_{k}{{n}_{kd}}+\bar{\alpha}}}
202 |    * perplexity = \exp(-(\sum{\log p(w)}) / N), where N is the total number of tokens in the corpus
203 | */
204 | // scalastyle:on
205 | def perplexity(): Double = {
206 | val totalTopicCounter = this.totalTopicCounter
207 | val numTopics = this.numTopics
208 | val numTerms = this.numTerms
209 | val alpha = this.alpha
210 | val beta = this.beta
211 |
212 | val totalSize = brzSum(totalTopicCounter) & 0x00000000FFFFFFFFL
213 | var totalProb = 0D
214 | // \frac{{\alpha }_{k}{\beta }_{w}}{{n}_{k}+\bar{\beta }}
215 | totalTopicCounter.activeIterator.foreach { case (topic, cn) =>
216 | totalProb += alpha * beta / (cn + numTerms * beta)
217 | }
218 | val termProb = corpus.mapVertices { (vid, counter) =>
219 | var probDist = 0D
220 | if (vid >= 0) {
221 | val termTopicCounter = counter
222 | // \frac{{n}_{kw}{\alpha }_{k}}{{n}_{k}+\bar{\beta }}
223 | termTopicCounter.activeIterator.filter(_._2 > 0).foreach { case (topic, cn) =>
224 | probDist += cn * alpha /
225 | (totalTopicCounter(topic) + numTerms * beta)
226 | }
227 | } else {
228 | val docTopicCounter = counter
229 | // \frac{{n}_{kd}{\beta }_{w}}{{n}_{k}+\bar{\beta }}
230 | docTopicCounter.activeIterator.filter(_._2 > 0).foreach { case (topic, cn) =>
231 | probDist += cn * beta /
232 | (totalTopicCounter(topic) + numTerms * beta)
233 | }
234 | }
235 | (counter, probDist)
236 | }.mapTriplets { triplet =>
237 | val (termTopicCounter, termProb) = triplet.srcAttr
238 | val (docTopicCounter, docProb) = triplet.dstAttr
239 | val docSize = docTopicCounter.sum
240 | val docTermSize = triplet.attr.length
241 | var prob = 0D
242 | // \frac{{n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta}}
243 | docTopicCounter.activeIterator.filter(_._2 > 0).foreach { case (topic, cn) =>
244 | prob += ((termTopicCounter(topic) / (totalTopicCounter(topic) + numTerms * beta))*cn)
245 | }
246 | prob += docProb + termProb + totalProb
247 |       prob /= (docSize + numTopics * alpha)
248 | docTermSize * Math.log(prob)
249 | }.edges.map(t => t.attr).sum()
250 | math.exp(-1 * termProb / totalSize)
251 | }
252 |
253 | /**
254 |    * Build the term-topic model, averaging the term-topic counts over `totalIter` additional sampling iterations.
255 |    * @param totalIter number of additional sampling iterations to average over
256 | */
257 | def saveModel(totalIter: Int = 1): GibbsLDAModel = {
258 | var termTopicCounter: RDD[(VertexId, VD)] = null
259 | for (iter <- 1 to totalIter) {
260 | logInfo(s"Save TopicModel (Iteration $iter/$totalIter)")
261 | var previousTermTopicCounter = termTopicCounter
262 | gibbsSampling()
263 | val newTermTopicCounter = termVertices
264 | termTopicCounter = Option(termTopicCounter).map(_.join(newTermTopicCounter).map {
265 | case (term, (a, b)) =>
266 | var c: VD = null
267 | if(a.isInstanceOf[BSV[Count]] && b.isInstanceOf[BSV[Count]]) {
268 | c = a.asInstanceOf[BSV[Count]] :+ b.asInstanceOf[BSV[Count]]
269 | } else if(a.isInstanceOf[BDV[Count]] && b.isInstanceOf[BDV[Count]]){
270 | c = a.asInstanceOf[BDV[Count]] :+ b.asInstanceOf[BDV[Count]]
271 | } else if(a.isInstanceOf[BDV[Count]]) {
272 | c = a.asInstanceOf[BDV[Count]] :+ b.asInstanceOf[BSV[Count]].toDenseVector
273 | } else {
274 | c = a.asInstanceOf[BSV[Count]].toDenseVector :+ b.asInstanceOf[BDV[Count]]
275 | }
276 | (term, c)
277 | }).getOrElse(newTermTopicCounter)
278 |
279 | termTopicCounter.persist(storageLevel).count()
280 | Option(previousTermTopicCounter).foreach(_.unpersist(blocking = false))
281 | previousTermTopicCounter = termTopicCounter
282 | }
283 | val rand = new Random()
284 | val ttc = termTopicCounter.mapValues(c => {
285 | if (c.isInstanceOf[BDV[Count]]) {
286 | val dv = c.asInstanceOf[BDV[Count]]
287 | val nc = new BDV[Count](dv.data.map (v => {
288 | val mid = v.toDouble / totalIter
289 | val l = math.floor(mid)
290 | if (rand.nextDouble() > mid - l) {
291 | l
292 | } else {
293 | l + 1
294 | }
295 | }.asInstanceOf[Count]))
296 | nc.asInstanceOf[StorageVector[Count]]
297 | } else {
298 | val sv = c.asInstanceOf[BSV[Count]]
299 | val nc = new BSV[Count](sv.index.slice(0, sv.used), sv.data.slice(0, sv.used).map(v => {
300 | val mid = v.toDouble / totalIter
301 | val l = math.floor(mid)
302 | if (rand.nextDouble() > mid - l) {
303 | l
304 | } else {
305 | l + 1
306 | }
307 | }.asInstanceOf[Count]), c.length)
308 | nc
309 | }
310 | })
311 |
312 | ttc.persist(storageLevel)
313 | val gtc = ttc.map(_._2).aggregate(BDV.zeros[Count](numTopics))(_ :+= _,_ :+= _)
314 | new GibbsLDAModel(gtc, ttc, corpus.vertices, numTerms, alpha, beta, alphaAS)
315 | }
316 | }
317 |
318 | object GibbsLDAOptimizer {
319 |
320 | private[topicModeling] type DocId = VertexId
321 | private[topicModeling] type WordId = VertexId
322 | private[topicModeling] type Count = Int
323 | private[topicModeling] type ED = Array[Count]
324 | private[topicModeling] type VD = StorageVector[Count]
325 |
326 | def checkpoint(corpus: Graph[VD, ED], innerIter: Int, checkpointInterval: Int): Unit = {
327 | if (innerIter % checkpointInterval == 0 && corpus.edges.sparkContext.getCheckpointDir.isDefined) {
328 | corpus.checkpoint()
329 | }
330 | }
331 |
332 | def collectTotalTopicCounter(graph: Graph[VD, ED], numTopics: Int, numTokens: Long): BDV[Count] = {
333 | val globalTopicCounter = collectGlobalCounter(graph, numTopics)
334 | val totalSize = brzSum(globalTopicCounter) & 0x00000000FFFFFFFFL
335 | println("numTokens:"+numTokens+" brzSum:"+totalSize)
336 | assert(totalSize == numTokens)
337 | globalTopicCounter
338 | }
339 |
340 | def updateCounter(graph: Graph[VD, ED], numTopics: Int): Graph[VD, ED] = {
341 | val newCounter = graph.aggregateMessages[BSV[Count]](ctx => {
342 | val topics = ctx.attr
343 | val vector = BSV.zeros[Count](numTopics)
344 | for (topic <- topics) {
345 | vector(topic) += 1
346 | }
347 | ctx.sendToDst(vector)
348 | ctx.sendToSrc(vector)
349 | }, _ + _, TripletFields.EdgeOnly)
350 | .mapValues(sparseVector => {
351 | val storageVector:VD =
352 | if (sparseVector.activeSize > sparseVector.length / 2) {
353 | sparseVector.toDenseVector
354 | } else {
355 | sparseVector
356 | }
357 | storageVector
358 | })
359 | // GraphImpl.fromExistingRDDs(newCounter, graph.edges)
360 | GraphImpl(newCounter, graph.edges)
361 | }
362 |
363 | def collectGlobalCounter(graph: Graph[VD, ED], numTopics: Int): BDV[Count] = {
364 | graph.vertices.filter(t => t._1 >= 0).map(_._2).aggregate(BDV.zeros[Count](numTopics))((a, b) => {
365 | a :+= b
366 | }, _ :+= _)
367 | }
368 |
369 | def initializeCorpus(
370 | docs: RDD[(GibbsLDAOptimizer.DocId, SV)],
371 | numTopics: Int,
372 | storageLevel: StorageLevel, edgePartitioner:String): Graph[VD, ED] = {
373 | val edges = docs.mapPartitionsWithIndex((pid, iter) => {
374 | val gen = new Random(pid)
375 | iter.flatMap {
376 | case (docId, doc) =>
377 | initializeEdges(gen, doc, docId, numTopics)
378 | }
379 | })
380 | edges.persist(storageLevel)
381 | val corpus: Graph[VD, ED] = edgePartitioner match {
382 | case "none" =>
383 | Graph.fromEdges(edges, null, storageLevel, storageLevel)
384 | case "degree" =>
385 | val degreeCorpus = Graph.fromEdges(edges, null, storageLevel, storageLevel)
386 | val degrees = degreeCorpus.outerJoinVertices(degreeCorpus.degrees) { (vid, data, deg) => deg.getOrElse(0) }
387 | val numPartitions = edges.partitions.size
388 | val partitionStrategy = new DBHPartitioner(numPartitions)
389 | val newEdges = degrees.triplets.map { e =>
390 | (partitionStrategy.getPartition(e), Edge(e.srcId, e.dstId, e.attr))
391 | }.partitionBy(new HashPartitioner(numPartitions)).map(_._2)
392 | Graph.fromEdges(newEdges, null, storageLevel, storageLevel)
393 | case "docIdCluster" =>
394 | val newEdges = edges.map { e =>
395 | (e.dstId, Edge(e.srcId, e.dstId, e.attr))
396 | }.partitionBy(new HashPartitioner(docs.partitions.size)).map(_._2)
397 | Graph.fromEdges(newEdges, null, storageLevel, storageLevel)
398 |
399 | case _ =>
400 |         throw new IllegalArgumentException(s"edgePartitioner must be one of none, degree or docIdCluster, but got $edgePartitioner")
401 | }
402 |
403 | val resultCorpus = updateCounter(corpus, numTopics).cache()
404 | resultCorpus.vertices.count()
405 | resultCorpus.edges.count()
406 | corpus.unpersist(false)
407 | edges.unpersist(false)
408 | docs.unpersist(false)
409 | resultCorpus
410 | }
411 |
412 | private def initializeEdges(
413 | gen: Random,
414 | doc: SV,
415 | docId: DocId,
416 | numTopics: Int): Iterator[Edge[ED]] = {
417 | assert(docId >= 0)
418 | val newDocId: DocId = -(docId + 1L)
419 |
420 | doc.toBreeze.activeIterator.filter(_._2 > 0).map {case (termId, counter) =>
421 | val topics = new Array[Int](counter.toInt)
422 | for (i <- 0 until counter.toInt) {
423 | topics(i) = gen.nextInt(numTopics)
424 | }
425 | Edge(termId, newDocId, topics)
426 | }
427 | }
428 | }
429 |
430 | class GibbsLDAModel (
431 | private[topicModeling] val gtc: BDV[Count],
432 | @transient private[topicModeling] val ttc: RDD[(VertexId, VD)],
433 | @transient private[topicModeling] val vertices: RDD[(VertexId, VD)],
434 | val numTerms: Int,
435 | val alpha: Float,
436 | val beta: Float,
437 | val alphaAS: Float) extends org.apache.spark.mllib.topicModeling.LDAModel with Serializable {
438 |
439 | private val numTopics = gtc.size
440 |
441 | private lazy val numTokens = brzSum(gtc) & 0x00000000FFFFFFFFL
442 |
443 | /** Number of topics */
444 | def k: Int = numTopics
445 |
446 |   /** Vocabulary size (number of terms in the vocabulary) */
447 | def vocabSize: Int = numTerms
448 |
449 | /**
450 | * Inferred topics, where each topic is represented by a distribution over terms.
451 | * This is a matrix of size vocabSize x k, where each column is a topic.
452 | * No guarantees are given about the ordering of the topics.
453 | */
454 | def topicsMatrix: Matrix = {
455 | val matrix = BDM.zeros[Double](numTerms, numTopics)
456 |
457 | val ttcArray = Array.fill(numTerms.toInt) {
458 | BSV.zeros[Count](numTopics)
459 | }
460 | ttc.collect().foreach { case (termId, vector) =>
461 | ttcArray(termId.toInt) :+= vector
462 | }
463 |
464 | for (termId <- 0 until numTerms) {
465 | val sv = ttcArray(termId)
466 | for (topicId <- 0 until numTopics) {
467 | matrix(termId, topicId) = sv(topicId)
468 | }
469 | }
470 | GibbsLDAUtils.fromBreezeMatrix(matrix)
471 | }
472 |
473 | /**
474 | * Return the topics described by weighted terms.
475 | *
476 | * This limits the number of terms per topic.
477 | * This is approximate; it may not return exactly the top-weighted terms for each topic.
478 | * To get a more precise set of top terms, increase maxTermsPerTopic.
479 | *
480 | * @param maxTermsPerTopic Maximum number of terms to collect for each topic.
481 | * @return Array over topics. Each topic is represented as a pair of matching arrays:
482 | * (term indices, term weights in topic).
483 | * Each topic's terms are sorted in order of decreasing weight.
484 | */
485 | def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
486 | val ttcArray = Array.fill(numTerms.toInt) {
487 | BSV.zeros[Count](numTopics)
488 | }
489 | ttc.collect().foreach { case (termId, vector) =>
490 | ttcArray(termId.toInt) :+= vector
491 | }
492 |
493 | (0 until numTopics).map(topicId => {
494 | val terms = (0 until numTerms).map(termId => (termId, ttcArray(termId)(topicId).toDouble))
495 | .sortBy(_._2)(Ordering[Double].reverse)
496 | .take(maxTermsPerTopic)
497 | .map(_._1).toArray
498 | val weights = terms.map(ttcArray(_)(topicId).toDouble)
499 | val wsum = weights.sum
500 | if (wsum > 1E-5) {
501 | (0 until weights.length).foreach(weights(_) /= wsum)
502 | }
503 | (terms, weights)
504 | }).toArray
505 | }
506 |
507 | /**
508 |    * For each document in the given corpus, infer the distribution over topics for that
509 |    * document ("theta_doc").
510 | *
511 | * @return RDD of (document ID, topic distribution) pairs
512 | */
513 | def predict(documents: RDD[(Long, SV)],
514 | optimizer: GibbsLDAOptimizer,
515 | totalIter: Int = 25,
516 | burnIn: Int = 22): RDD[(Long, SV)] = {
517 |
518 | val previousCorpus: Graph[VD, ED] = initializeCorpus(documents, numTopics,
519 | optimizer.getStorageLevel(), optimizer.edgePartitioner)
520 |
521 | var corpus = previousCorpus.outerJoinVertices(ttc) { (vid, c, v) =>
522 | if (vid >= 0) {
523 | assert(v.isDefined)
524 | }
525 | v.getOrElse(c)
526 | }
527 | corpus.persist(optimizer.getStorageLevel)
528 | corpus.vertices.count()
529 | corpus.edges.count()
530 | previousCorpus.edges.unpersist(blocking = false)
531 | previousCorpus.vertices.unpersist(blocking = false)
532 |
533 | var docTopicCounter: RDD[(VertexId, VD)] = null
534 |
535 | for(i <- 1 to totalIter) {
536 | val previousCorpus = corpus
537 |
538 | val sampledCorpus = optimizer.getSampler().sampleTokens(corpus, gtc, i + optimizer.getSeed(),
539 | numTokens, numTopics, numTerms, alpha, alphaAS, beta)
540 | sampledCorpus.persist(optimizer.getStorageLevel)
541 |
542 | corpus = updateCounter(sampledCorpus, numTopics)
543 | checkpoint(corpus, i, optimizer.getCheckpointInterval)
544 | corpus.persist(optimizer.getStorageLevel)
545 |
546 | previousCorpus.edges.unpersist(false)
547 | previousCorpus.vertices.unpersist(false)
548 |
549 | sampledCorpus.edges.unpersist(false)
550 | sampledCorpus.vertices.unpersist(false)
551 |
552 | if (i > burnIn) {
553 | var previousDocTopicCounter = docTopicCounter
554 |         val newDocTopicCounter = corpus.vertices.filter(t => t._1 < 0)
555 |         docTopicCounter = Option(docTopicCounter).map(_.join(newDocTopicCounter).map {
556 | case (docId, (a, b)) => {
557 | var c: VD = null
558 | if(a.isInstanceOf[BSV[Count]] && b.isInstanceOf[BSV[Count]]) {
559 | c = a.asInstanceOf[BSV[Count]] :+ b.asInstanceOf[BSV[Count]]
560 | } else if(a.isInstanceOf[BDV[Count]] && b.isInstanceOf[BDV[Count]]){
561 | c = a.asInstanceOf[BDV[Count]] :+ b.asInstanceOf[BDV[Count]]
562 | } else if(a.isInstanceOf[BDV[Count]]) {
563 | c = a.asInstanceOf[BDV[Count]] :+ b.asInstanceOf[BSV[Count]].toDenseVector
564 | } else {
565 | c = a.asInstanceOf[BSV[Count]].toDenseVector :+ b.asInstanceOf[BDV[Count]]
566 | }
567 | (docId, c)
568 | }
569 |         }).getOrElse(newDocTopicCounter)
570 |
571 | docTopicCounter.persist(optimizer.getStorageLevel).count()
572 | Option(previousDocTopicCounter).foreach(_.unpersist(blocking = false))
573 | previousDocTopicCounter = docTopicCounter
574 | }
575 | }
576 | docTopicCounter.map { case (docId, v) =>
577 | if(v.isInstanceOf[BDV[Count]]) {
578 | val dv = v.asInstanceOf[BDV[Count]]
579 | val norm = brzNorm(dv, 1)
580 | (docId, GibbsLDAUtils.fromBreezeConv[Double](dv.map(_.toDouble) / norm ))
581 | } else {
582 | val sv = v.asInstanceOf[BSV[Count]]
583 | val norm = brzNorm(sv, 1)
584 | (docId, GibbsLDAUtils.fromBreezeConv[Double](sv.map(_.toDouble) / norm))
585 | }
586 | }
587 | }
588 |
589 | /**
590 | * For each document in the training set, return the distribution over topics for that document
591 | *
592 | * @return RDD of (document ID, topic distribution) pairs
593 | */
594 | def docTopicDistributions: RDD[(VertexId, VD)] = {
595 | vertices.filter(t => t._1<0)
596 | }
597 |
598 | /**
599 | * For each term in the training set, return the distribution over topics for that term
600 | *
601 | * @return RDD of (term ID, topic distribution) pairs
602 | */
603 | def termTopicDistributions: RDD[(VertexId, VD)] = {
604 | vertices.filter(t => t._1>=0)
605 | }
606 | }
607 |
608 | private[topicModeling] object GibbsLDAUtils {
609 |
610 | private def _conv[T1: ClassTag, T2: ClassTag](data: Array[T1]): Array[T2] = {
611 | data.map(_.asInstanceOf[T2]).array
612 | }
613 |
614 | def fromBreezeConv[T: ClassTag](breezeVector: BV[T]): SV = {
615 | implicit val conv: Array[T] => Array[Double] = _conv[T, Double]
616 |
617 | breezeVector match {
618 | case v: BDV[T] =>
619 | if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) {
620 | new SDV(v.data)
621 | } else {
622 | new SDV(v.toArray) // Can't use underlying array directly, so make a new one
623 | }
624 | case v: BSV[T] =>
625 | if (v.index.length == v.used) {
626 | new SSV(v.length, v.index, _conv[T, Double](v.data))
627 | } else {
628 | new SSV(v.length, v.index.slice(0, v.used), v.data.slice(0, v.used))
629 | }
630 | case v: BV[T] =>
631 | sys.error("Unsupported Breeze vector type: " + v.getClass.getName)
632 | }
633 | }
634 |
635 | def binarySearchInterval(
636 | index: Array[Float],
637 | key: Float,
638 | begin: Int,
639 | end: Int,
640 | greater: Boolean): Int = {
641 | if (begin == end) {
642 | return if (greater) end else begin - 1
643 | }
644 | var b = begin
645 | var e = end - 1
646 |
647 | var mid: Int = (e + b) >> 1
648 | while (b <= e) {
649 | mid = (e + b) >> 1
650 | val v = index(mid)
651 | if (scala.math.abs(key-v)<=1e-6) {
652 | return mid
653 | }
654 | else if (v > key) {
655 | e = mid - 1
656 | }
657 | else {
658 | b = mid + 1
659 | }
660 | }
661 | val v = index(mid)
662 | mid = if ((greater && v >= key) || (!greater && v <= key)) {
663 | mid
664 | }
665 | else if (greater) {
666 | mid + 1
667 | }
668 | else {
669 | mid - 1
670 | }
671 |
672 | if (greater) {
673 | if (mid < end) assert(index(mid) >= key)
674 | if (mid > 0) assert(index(mid - 1) <= key)
675 | } else {
676 | if (mid > 0) assert(index(mid) <= key)
677 | if (mid < end - 1) assert(index(mid + 1) >= key)
678 | }
679 | mid
680 | }
681 |
682 | /**
683 | * Creates a Matrix instance from a breeze matrix.
684 | * @param breeze a breeze matrix
685 | * @return a Matrix instance
686 | */
687 | def fromBreezeMatrix(breeze: BM[Double]): Matrix = {
688 | breeze match {
689 | case dm: BDM[Double] =>
690 | require(dm.majorStride == dm.rows,
691 | "Do not support stride size different from the number of rows.")
692 | new SDM(dm.rows, dm.cols, dm.data)
693 | case sm: BSM[Double] =>
694 | new SSM(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data)
695 | case _ =>
696 | throw new UnsupportedOperationException(
697 | s"Do not support conversion from type ${breeze.getClass.getName}.")
698 | }
699 | }
700 |
701 | /**
702 | * Creates a vector instance from a breeze vector.
703 | */
704 | def fromBreezeVector(breezeVector: BV[Double]): SV = {
705 | breezeVector match {
706 | case v: BDV[Double] =>
707 | if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) {
708 | new SDV(v.data)
709 | } else {
710 | new SDV(v.toArray) // Can't use underlying array directly, so make a new one
711 | }
712 | case v: BSV[Double] =>
713 | if (v.index.length == v.used) {
714 | new SSV(v.length, v.index, v.data)
715 | } else {
716 | new SSV(v.length, v.index.slice(0, v.used), v.data.slice(0, v.used))
717 | }
718 | case v: BV[_] =>
719 | sys.error("Unsupported Breeze vector type: " + v.getClass.getName)
720 | }
721 | }
722 |
723 | def saveLDAModel(sc: SparkContext, ldamodel: GibbsLDAModel, path: String) = {
724 | val metadataPath = new Path(path, "metadata").toUri.toString
725 | val ttcPath = new Path(path, "ttc").toUri.toString
726 | val verticesPath = new Path(path, "vertices").toUri.toString
727 |
728 | val metaData = compact(render
729 | (("alpha" -> ldamodel.alpha) ~ ("beta" -> ldamodel.beta) ~
730 | ("alphaAS" -> ldamodel.alphaAS) ~ ("numTerms" -> ldamodel.numTerms) ~
731 | ("numTopics" -> ldamodel.k)))
732 |
733 | sc.parallelize(Seq(metaData)).saveAsTextFile(metadataPath)
734 | ldamodel.ttc.map { case (id, vector) =>
735 | val list = vector.activeIterator.toList.sortWith((a, b) => a._2>b._2)
736 | (NullWritable.get(), new Text(id + "\t" + list.map(item => item._1 + ":" + item._2).mkString("\t")))
737 | }.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](ttcPath)
738 |
739 | ldamodel.vertices.map { case (id, vector) =>
740 | val list = vector.activeIterator.toList.sortWith((a, b) => a._2>b._2)
741 | (NullWritable.get(), new Text(id + "\t" + list.map(item => item._1 + ":" + item._2).mkString("\t")))
742 | }.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](verticesPath)
743 | }
744 |
745 | def loadLDAModel(sc: SparkContext, path: String): GibbsLDAModel = {
746 | val metadataPath = new Path(path, "metadata").toUri.toString
747 | val ttcPath = new Path(path, "ttc").toUri.toString
748 | val verticesPath = new Path(path, "vertices").toUri.toString
749 |
750 | implicit val formats = DefaultFormats
751 | val metaData = parse(sc.textFile(metadataPath).first())
752 | val alpha = (metaData \ "alpha").extract[Float]
753 | val beta = (metaData \ "beta").extract[Float]
754 | val alphaAS = (metaData \ "alphaAS").extract[Float]
755 | val numTerms = (metaData \ "numTerms").extract[Int]
756 | val numTopics = (metaData \ "numTopics").extract[Int]
757 |
758 | val ttc = sc.textFile(ttcPath).map { line =>
759 | val sv = BSV.zeros[Count](numTopics)
760 | val arr = line.split("\t")
761 | arr.tail.foreach { sub =>
762 | val Array(index, value) = sub.split(":")
763 | sv(index.toInt) = value.toInt
764 | }
765 | sv.compact()
766 | (arr.head.toLong, sv.asInstanceOf[StorageVector[Count]])
767 | }.cache()
768 |
769 | val vertices = sc.textFile(verticesPath).map { line =>
770 | val sv = BSV.zeros[Count](numTopics)
771 | val arr = line.split("\t")
772 | arr.tail.foreach { sub =>
773 | val Array(index, value) = sub.split(":")
774 | sv(index.toInt) = value.toInt
775 | }
776 | sv.compact()
777 | (arr.head.toLong, sv.asInstanceOf[StorageVector[Count]])
778 | }
779 |
780 | val gtc = ttc.map(_._2).aggregate(BDV.zeros[Count](numTopics))(_ :+= _,_ :+= _)
781 | new GibbsLDAModel(gtc, ttc, vertices, numTerms, alpha, beta, alphaAS)
782 | }
783 | }
784 |
785 | /**
786 |  * Degree-Based Hashing (DBH): each edge is assigned to the partition of its lower-degree endpoint
787 |  * (hashed with a mixing prime) to reduce vertex replication. Paper: http://nips.cc/Conferences/2014/Program/event.php?ID=4569
788 |  * @param partitions number of partitions
789 | */
790 | private class DBHPartitioner(partitions: Int) extends Partitioner {
791 | val mixingPrime: Long = 1125899906842597L
792 |
793 | def numPartitions = partitions
794 |
795 | def getPartition(key: Any): Int = {
796 | val edge = key.asInstanceOf[EdgeTriplet[Int, ED]]
797 | val srcDeg = edge.srcAttr
798 | val dstDeg = edge.dstAttr
799 | val srcId = edge.srcId
800 | val dstId = edge.dstId
801 | val minId = if (srcDeg < dstDeg) srcId else dstId
802 | getPartition(minId)
803 | }
804 |
805 | def getPartition(idx: Int): PartitionID = {
806 | (math.abs(idx * mixingPrime) % partitions).toInt
807 | }
808 |
809 | def getPartition(idx: Long): PartitionID = {
810 | (math.abs(idx * mixingPrime) % partitions).toInt
811 | }
812 |
813 | override def equals(other: Any): Boolean = other match {
814 | case h: DBHPartitioner =>
815 | h.numPartitions == numPartitions
816 | case _ =>
817 | false
818 | }
819 |
820 | override def hashCode: Int = numPartitions
821 | }
822 |
823 | private[topicModeling] class LDAKryoRegistrator extends KryoRegistrator {
824 | def registerClasses(kryo: com.esotericsoftware.kryo.Kryo) {
825 | val gkr = new GraphKryoRegistrator
826 | gkr.registerClasses(kryo)
827 |
828 | kryo.register(classOf[BSV[GibbsLDAOptimizer.Count]])
829 | kryo.register(classOf[BSV[Double]])
830 |
831 | kryo.register(classOf[BDV[GibbsLDAOptimizer.Count]])
832 | kryo.register(classOf[BDV[Double]])
833 |
834 | kryo.register(classOf[SV])
835 | kryo.register(classOf[SSV])
836 | kryo.register(classOf[SDV])
837 |
838 | kryo.register(classOf[GibbsLDAOptimizer.ED])
839 | kryo.register(classOf[GibbsLDAOptimizer.VD])
840 |
841 | kryo.register(classOf[Random])
842 | kryo.register(classOf[GibbsLDAOptimizer])
843 | kryo.register(classOf[GibbsLDAModel])
844 | }
845 | }
846 |
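847 | // ---------------------------------------------------------------------------------------------
848 | // Usage sketch (illustrative only, kept as a comment): training with GibbsLDAOptimizer and
849 | // persisting the result with GibbsLDAUtils. `sc`, `lda`, `corpus`, `numIterations` and the model
850 | // path below are placeholders, not part of this file.
851 | //
852 | //   val optimizer = new GibbsLDAOptimizer().setSampler("alias")
853 | //   optimizer.initialize(corpus, lda)            // corpus: RDD[(Long, Vector)], lda: configured LDA
854 | //   for (_ <- 1 to numIterations) optimizer.next()
855 | //   val model = optimizer.getLDAModel(Array.empty[Double]).asInstanceOf[GibbsLDAModel]
856 | //   GibbsLDAUtils.saveLDAModel(sc, model, "hdfs:///path/to/model")   // writes metadata, ttc, vertices
857 | //   val reloaded = GibbsLDAUtils.loadLDAModel(sc, "hdfs:///path/to/model")
858 | //   reloaded.describeTopics(10)                  // top-10 weighted terms per topic
859 | // ---------------------------------------------------------------------------------------------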
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/mllib/topicModeling/GibbsLDASampler.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | package org.apache.spark.mllib.topicModeling
18 |
19 | import java.lang.ref.SoftReference
20 | import java.util
21 | import org.apache.spark.Logging
22 | import org.apache.spark.graphx.impl.GraphImpl
23 | import org.apache.spark.graphx.{TripletFields, VertexId, Graph}
24 | import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV, sum => brzSum, StorageVector}
25 |
26 | import java.util.Random
27 | import org.apache.spark.util.random.XORShiftRandom
28 | import org.apache.spark.util.collection.AppendOnlyMap
29 | import scala.reflect.ClassTag
30 |
31 | trait GibbsLDASampler {
32 | type VD = GibbsLDAOptimizer.VD
33 | type ED = GibbsLDAOptimizer.ED
34 | type Count = GibbsLDAOptimizer.Count
35 |
36 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED],
37 | totalTopicCounter: BDV[Count],
38 | innerIter: Long,
39 | numTokens: Long,
40 | numTopics: Int,
41 | numTerms: Int,
42 | alpha: Float,
43 | alphaAS: Float,
44 | beta: Float): Graph[VD, ED]
45 | }
46 |
47 | //private[topicModeling] object TermAliasTableCache {
48 | //
49 | // type VD = GibbsLDAOptimizer.VD
50 | // type ED = GibbsLDAOptimizer.ED
51 | // type Count = GibbsLDAOptimizer.Count
52 | //
53 | // private var content = new AppendOnlyMap[Int, SoftReference[(GibbsAliasTable, Float)]]()
54 | //
55 | // private def wSparse(
56 | // totalTopicCounter: BDV[Count],
57 | // termTopicCounter: VD,
58 | // numTokens: Long,
59 | // numTerms: Int,
60 | // alpha: Float,
61 | // alphaAS: Float,
62 | // beta: Float): (Float, BSV[Float]) = {
63 | // val numTopics = totalTopicCounter.length
64 | // val alphaSum = alpha * numTopics
65 | // val termSum = numTokens - 1f + alphaAS * numTopics
66 | // val betaSum = numTerms * beta
67 | // val w = BSV.zeros[Float](numTopics)
68 | // var sum = 0.0f
69 | // termTopicCounter.activeIterator.filter(_._2 > 0).foreach { t =>
70 | // val topic = t._1
71 | // val count = t._2
72 | // val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) /
73 | // ((totalTopicCounter(topic) + betaSum) * termSum)
74 | // w(topic) = last
75 | // sum += last
76 | // }
77 | // (sum, w)
78 | // }
79 | //
80 | // def get(totalTopicCounter: BDV[Count],
81 | // termTopicCounter: VD,
82 | // termId: VertexId,
83 | // numTokens: Long,
84 | // numTerms: Int,
85 | // alpha: Float,
86 | // alphaAS: Float,
87 | // beta: Float): (GibbsAliasTable, Float) = {
88 | // if(content(termId.toInt) == null) {
89 | // synchronized {
90 | // if(content(termId.toInt) == null) {
91 | // val sv = wSparse(totalTopicCounter, termTopicCounter,
92 | // numTokens, numTerms, alpha, alphaAS, beta)
93 | //
94 | // val table = new GibbsAliasTable(sv._2.activeSize)
95 | // GibbsAliasTable.generateAlias(sv._2, sv._1, table)
96 | //
97 | // content(termId.toInt) = new SoftReference((table, sv._1))
98 | // }
99 | // }
100 | // }
101 | // content(termId.toInt).get
102 | // }
103 | //
104 | // def clear() = {
105 | // content = new AppendOnlyMap[Int, SoftReference[(GibbsAliasTable, Float)]]()
106 | // }
107 | //}
108 |
109 | private[topicModeling] class GibbsAliasTable(initUsed: Int) extends Serializable {
110 |
111 | private var _l: Array[Int] = new Array[Int](initUsed)
112 | private var _h: Array[Int] = new Array[Int](initUsed)
113 | private var _p: Array[Float] = new Array[Float](initUsed)
114 | private var _used = initUsed
115 |
116 | def l: Array[Int] = _l
117 |
118 | def h: Array[Int] = _h
119 |
120 | def p: Array[Float] = _p
121 |
122 | def used: Int = _used
123 |
124 | def length: Int = size
125 |
126 | def size: Int = l.length
127 |
128 | def sampleAlias(gen: Random): Int = {
129 | val bin = gen.nextInt(_used)
130 | val prob = _p(bin)
131 | if (_used * prob > gen.nextFloat()) {
132 | _l(bin)
133 | } else {
134 | _h(bin)
135 | }
136 | }
137 |
138 | private[GibbsAliasTable] def reset(newSize: Int): this.type = {
139 | if (_l.length < newSize) {
140 | _l = new Array[Int](newSize)
141 | _h = new Array[Int](newSize)
142 | _p = new Array[Float](newSize)
143 | }
144 | _used = newSize
145 | this
146 | }
147 | }
148 |
149 | private[topicModeling] object GibbsAliasTable {
150 | def generateAlias(sv: BV[Float]): GibbsAliasTable = {
151 | val used = sv.activeSize
152 | val sum = brzSum(sv)
153 | val probs = sv.activeIterator.slice(0, used)
154 | generateAlias(probs, sum, used)
155 | }
156 |
157 | def generateAlias(probs: Iterator[(Int, Float)], sum: Float, used: Int): GibbsAliasTable = {
158 | val table = new GibbsAliasTable(used)
159 | generateAlias(probs, sum, used, table)
160 | }
161 |
162 | def generateAlias(
163 | probs: Iterator[(Int, Float)],
164 | sum: Float,
165 | used: Int,
166 | table: GibbsAliasTable): GibbsAliasTable = {
167 | table.reset(used)
168 | val pMean = 1.0f / used
169 |
170 | val lq = new util.ArrayDeque[(Int, Float)]()
171 | val hq = new util.ArrayDeque[(Int, Float)]()
172 |
173 | probs.slice(0, used).foreach { pair =>
174 | val i = pair._1
175 | val pi = pair._2 / sum
176 | if (pi < pMean) {
177 | if(pMean-pi<=1e-6) {
178 | lq.addFirst(i, pi)
179 | } else {
180 | lq.add((i, pi))
181 | }
182 | } else {
183 | if(pi-pMean<=1e-6) {
184 | hq.addFirst(i, pi)
185 | } else {
186 | hq.add((i, pi))
187 | }
188 | }
189 | }
190 |
191 | var offset = 0
192 |     while (!lq.isEmpty && !hq.isEmpty) {
193 | val (i, pi) = lq.removeLast()
194 | val (h, ph) = hq.removeLast()
195 | table.l(offset) = i
196 | table.h(offset) = h
197 | table.p(offset) = pi
198 | val pd = ph - (pMean - pi)
199 | if (pd >= pMean) {
200 | if(pd-pMean<=1e-6) {
201 | hq.addFirst(h, pd)
202 | } else {
203 | hq.add((h, pd))
204 | }
205 | } else {
206 | if(pMean-pd<=1e-6) {
207 | lq.addFirst(h, pd)
208 | } else {
209 | lq.add((h, pd))
210 | }
211 | }
212 | offset += 1
213 | }
214 | while (!hq.isEmpty) {
215 | val (h, ph) = hq.removeLast()
216 | // assert(ph - pMean < 1e-4)
217 | table.l(offset) = h
218 | table.h(offset) = h
219 | table.p(offset) = ph
220 | offset += 1
221 | }
222 |
223 | while (!lq.isEmpty) {
224 | val (i, pi) = lq.removeLast()
225 | // assert(pMean - pi < 1e-4)
226 | table.l(offset) = i
227 | table.h(offset) = i
228 | table.p(offset) = pi
229 | offset += 1
230 | }
231 | table
232 | }
233 |
234 | def generateAlias(sv: BV[Float], sum: Float): GibbsAliasTable = {
235 | val used = sv.activeSize
236 | val probs = sv.activeIterator.slice(0, used)
237 | generateAlias(probs, sum, used)
238 | }
239 |
240 | def generateAlias(sv: BV[Float], sum: Float, table: GibbsAliasTable): GibbsAliasTable = {
241 | val used = sv.activeSize
242 | val probs = sv.activeIterator.slice(0, used)
243 | generateAlias(probs, sum, used, table)
244 | }
245 |
246 | }
247 |
248 | class GibbsLDAAliasSampler extends GibbsLDASampler with Logging with Serializable{
249 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED],
250 | totalTopicCounter: BDV[Count],
251 | innerIter: Long,
252 | numTokens: Long,
253 | numTopics: Int,
254 | numTerms: Int,
255 | alpha: Float,
256 | alphaAS: Float,
257 | beta: Float): Graph[VD, ED] = {
258 | val parts = graph.edges.partitions.size
259 |
260 | val newGraph = graph.mapTriplets(
261 | (pid, iter) => {
262 | val gen = new XORShiftRandom(parts * innerIter + pid)
263 |         // The alias table is a per-term data structure. In GraphX, edges within a partition are
264 |         // clustered by source ID (the term ID here), so the simple cache below avoids
265 |         // rebuilding the table for every edge.
266 | val lastTable = new GibbsAliasTable(numTopics)
267 | var lastVid: VertexId = -1
268 | var lastWSum = 0.0f
269 | val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta)
270 | val dData = new Array[Float](numTopics.toInt)
271 | val t = GibbsAliasTable.generateAlias(dv._2, dv._1)
272 | val tSum = dv._1
273 |
274 | iter.map {
275 | triplet =>
276 | val termId = triplet.srcId
277 | val docId = triplet.dstId
278 | val termTopicCounter = triplet.srcAttr
279 | val docTopicCounter = triplet.dstAttr
280 | val topics = triplet.attr.clone()
281 |
282 | for (i <- 0 until topics.length) {
283 | val currentTopic = topics(i)
284 | dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData,
285 | currentTopic, numTokens, numTerms, alpha, alphaAS, beta)
286 |
287 | if (lastVid != termId || gen.nextDouble() < 1e-4) {
288 | lastWSum = wordTable(lastTable, totalTopicCounter, termTopicCounter,
289 | termId, numTokens, numTerms, alpha, alphaAS, beta)
290 | lastVid = termId
291 | }
292 |
293 | val newTopic = tokenSampling(gen, t, tSum, lastTable, termTopicCounter, lastWSum,
294 | docTopicCounter, dData, currentTopic)
295 |
296 | if (newTopic != currentTopic) {
297 | topics(i) = newTopic
298 | docTopicCounter(currentTopic) -= 1
299 | docTopicCounter(newTopic) += 1
300 | // if (docTopicCounter(currentTopic) == 0) docTopicCounter.compact()
301 |
302 | termTopicCounter(currentTopic) -= 1
303 | termTopicCounter(newTopic) += 1
304 | // if (termTopicCounter(currentTopic) == 0) termTopicCounter.compact()
305 |
306 | totalTopicCounter(currentTopic) -= 1
307 | totalTopicCounter(newTopic) += 1
308 | }
309 | }
310 | topics
311 | }
312 | }, TripletFields.All)
313 | GraphImpl(newGraph.vertices.mapValues(t => null), newGraph.edges)
314 | }
315 |
316 | // scalastyle:off
317 | /**
318 | * the formula is
319 | * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
320 | */
321 | // scalastyle:on
322 | private def tDense(
323 | totalTopicCounter: BDV[Count],
324 | numTokens: Long,
325 | numTerms: Int,
326 | alpha: Float,
327 | alphaAS: Float,
328 | beta: Float): (Float, BDV[Float]) = {
329 | val numTopics = totalTopicCounter.length
330 | val t = BDV.zeros[Float](numTopics)
331 | val alphaSum = alpha * numTopics
332 | val termSum = numTokens - 1F + alphaAS * numTopics
333 | val betaSum = numTerms * beta
334 | var sum = 0.0f
335 | for (topic <- 0 until numTopics) {
336 | val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) /
337 | ((totalTopicCounter(topic) + betaSum) * termSum)
338 | t(topic) = last
339 | sum += last
340 | }
341 | (sum, t)
342 | }
343 |
344 | // scalastyle:off
345 | /**
346 | * the formula is:
347 | * d = \frac{{n}_{kd} ^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
348 | * = \frac{{n}_{kd} ^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta}) }
349 | */
350 | // scalastyle:on
351 | private def dSparse(
352 | totalTopicCounter: BDV[Count],
353 | termTopicCounter: VD,
354 | docTopicCounter: VD,
355 | d: Array[Float],
356 | currentTopic: Int,
357 | numTokens: Long,
358 | numTerms: Int,
359 | alpha: Float,
360 | alphaAS: Float,
361 | beta: Float): Unit = {
362 | // val termSum = numTokens - 1D + alphaAS * numTopics
363 | val betaSum = numTerms * beta
364 | var sum = 0.0f
365 | var i = 0
366 |
367 | docTopicCounter.activeIterator.filter(_._2 > 0).foreach { t =>
368 | val topic = t._1
369 | val count = t._2
370 | val adjustment = if (currentTopic == topic) -1F else 0
371 | val last = (count + adjustment) * (termTopicCounter(topic) + adjustment + beta) /
372 | (totalTopicCounter(topic) + adjustment + betaSum)
373 | // val lastD = (count + adjustment) * termSum * (termTopicCounter(topic) + adjustment + beta) /
374 | // ((totalTopicCounter(topic) + adjustment + betaSum) * termSum)
375 |
376 | sum += last
377 | d(i) = sum
378 | i += 1
379 | }
380 | }
381 |
382 | private def wordTable(
383 | table: GibbsAliasTable,
384 | totalTopicCounter: BDV[Count],
385 | termTopicCounter: VD,
386 | termId: VertexId,
387 | numTokens: Long,
388 | numTerms: Int,
389 | alpha: Float,
390 | alphaAS: Float,
391 | beta: Float): Float = {
392 | val sv = wSparse(totalTopicCounter, termTopicCounter,
393 | numTokens, numTerms, alpha, alphaAS, beta)
394 | GibbsAliasTable.generateAlias(sv._2, sv._1, table)
395 | sv._1
396 | }
397 |
398 | // scalastyle:off
399 | /**
400 |    * Uses both a Gibbs sampler and a Metropolis-Hastings sampler;
401 |    * amortized complexity per token is O(1).
402 | * Use formula (3) from the Gibbs sampler paper: Rethinking LDA: Why Priors Matter
403 | * \frac{{n}_{kw}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta}} \frac{{n}_{kd} ^{-di}+ \bar{\alpha} \frac{{n}_{k}^{-di} + \acute{\alpha}}{\sum{n}_{k} +\bar{\acute{\alpha}}}}{\sum{n}_{kd}^{-di} +\bar{\alpha}}
404 | * = t + w + d
405 | * t the global part
406 | * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
407 | * w: term related
408 | * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
409 | * d: product of doc and term
410 | * d = \frac{{n}_{kd}^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
411 |  * = \frac{{n}_{kd}^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta}) }
412 | * where:
413 | * \bar{\beta}=\sum_{w}{\beta}_{w}
414 | * \bar{\alpha}=\sum_{k}{\alpha}_{k}
415 |  * \bar{\acute{\alpha}}=\sum_{k}\acute{\alpha}
416 |  *
417 | * {n}_{kd} number of tokens in document d that are assigned to topic k
418 | * {n}_{kw} number of tokens with word w (across all docs) that are assigned to topic k
419 | * {n}_{k} number of tokens across all docs that are assigned to topic k
420 |  * -di: subtract the current token's topic assignment from the counts
421 | */
422 | // scalastyle:on
423 | private def tokenSampling(gen: Random,
424 | t: GibbsAliasTable,
425 | tSum: Float,
426 | w: GibbsAliasTable,
427 | termTopicCounter: VD,
428 | wSum: Float,
429 | docTopicCounter: VD,
430 | dData: Array[Float],
431 | currentTopic: Int): Int = {
432 | val used = docTopicCounter.activeSize
433 | val dSum = dData(used - 1)
434 | val distSum = tSum + wSum + dSum
435 | val genSum = gen.nextFloat() * distSum
436 | if (genSum < dSum) {
437 | val dGenSum = gen.nextFloat() * dSum
438 | val pos = GibbsLDAUtils.binarySearchInterval(dData, dGenSum, 0, used, true)
439 | docTopicCounter.indexAt(pos)
440 | } else if (genSum < (dSum + wSum)) {
441 | sampleSV(gen, w, termTopicCounter, currentTopic)
442 | } else {
443 | t.sampleAlias(gen)
444 | }
445 | }
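// A minimal sketch of the bucket choice performed by tokenSampling above: a single uniform draw over
// tSum + wSum + dSum decides whether the topic comes from the doc-sparse (d), word-sparse (w) or
// dense (t) component; the masses below are toy assumptions (in practice the doc part dominates).
object BucketChoiceSketch {
  def main(args: Array[String]): Unit = {
    val gen = new java.util.Random(42L)
    val (tSum, wSum, dSum) = (0.02f, 0.18f, 0.80f)
    val counts = Array(0, 0, 0)                        // how often each bucket is selected
    for (_ <- 0 until 10000) {
      val u = gen.nextFloat() * (tSum + wSum + dSum)
      val bucket = if (u < dSum) 0 else if (u < dSum + wSum) 1 else 2
      counts(bucket) += 1
    }
    println(s"d / w / t bucket frequencies: ${counts.mkString(", ")}")   // roughly 8000 / 1800 / 200
  }
}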
446 |
447 | // scalastyle:off
448 | /**
449 | * the formula is:
450 | * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
451 | */
452 | // scalastyle:on
453 | private def wSparse(
454 | totalTopicCounter: BDV[Count],
455 | termTopicCounter: VD,
456 | numTokens: Long,
457 | numTerms: Int,
458 | alpha: Float,
459 | alphaAS: Float,
460 | beta: Float): (Float, BSV[Float]) = {
461 | val numTopics = totalTopicCounter.length
462 | val alphaSum = alpha * numTopics
463 | val termSum = numTokens - 1F + alphaAS * numTopics
464 | val betaSum = numTerms * beta
465 | val w = BSV.zeros[Float](numTopics)
466 | var sum = 0.0f
467 | termTopicCounter.activeIterator.filter(_._2 > 0).foreach { t =>
468 | val topic = t._1
469 | val count = t._2
470 | val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) /
471 | ((totalTopicCounter(topic) + betaSum) * termSum)
472 | w(topic) = last
473 | sum += last
474 | }
475 | (sum, w)
476 | }
477 |
478 | private def sampleSV(gen: Random, table: GibbsAliasTable, sv: VD, currentTopic: Int): Int = {
479 | var docTopic = table.sampleAlias(gen)
480 |
481 | while(docTopic == currentTopic) {
482 |
483 | val svCounter = sv(currentTopic)
484 |
485 | if ((svCounter == 1 && table.used > 1) ||
486 | (svCounter > 1 && gen.nextFloat() < 1.0f / svCounter)) {
487 | docTopic = table.sampleAlias(gen)
488 | } else {
489 | return docTopic
490 | }
491 | }
492 | docTopic
493 | }
494 | }
495 |
496 | class GibbsLDAFastSampler extends GibbsLDASampler with Serializable with Logging {
497 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED],
498 | totalTopicCounter: BDV[Count],
499 | innerIter: Long,
500 | numTokens: Long,
501 | numTopics: Int,
502 | numTerms: Int,
503 | alpha: Float,
504 | alphaAS: Float,
505 | beta: Float): Graph[VD, ED] = {
506 | val parts = graph.edges.partitions.size
507 |     val newGraph = graph.mapTriplets(
508 | (pid, iter) => {
509 | val gen = new Random(parts * innerIter + pid)
510 |         // the alias table is a per-term data structure;
511 |         // in GraphX, edges within a partition are clustered by source ID (the term id here),
512 |         // so the simple cache below avoids recomputing the table for every edge
513 | val lastTable = new GibbsAliasTable(numTopics.toInt)
514 | var lastVid: VertexId = -1
515 | var lastWSum = 0.0f
516 | val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta)
517 | val dData = new Array[Float](numTopics.toInt)
518 | val t = GibbsAliasTable.generateAlias(dv._2, dv._1)
519 | val tSum = dv._1
520 | iter.map {
521 | triplet =>
522 | val termId = triplet.srcId
523 | val docId = triplet.dstId
524 | val termTopicCounter = triplet.srcAttr
525 | val docTopicCounter = triplet.dstAttr
526 | val topics = triplet.attr.clone()
527 | for (i <- 0 until topics.length) {
528 | val currentTopic = topics(i)
529 | dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData,
530 | currentTopic, numTokens, numTerms, alpha, alphaAS, beta)
531 | if (lastVid != termId) {
532 | lastWSum = wordTable(lastTable, totalTopicCounter, termTopicCounter,
533 | termId, numTokens, numTerms, alpha, alphaAS, beta)
534 | lastVid = termId
535 | }
536 |
537 | val newTopic = tokenSampling(gen, t, tSum, lastTable, termTopicCounter, lastWSum,
538 | docTopicCounter, dData, currentTopic)
539 |
540 | if (newTopic != currentTopic) {
541 | topics(i) = newTopic
542 | }
543 | }
544 |
545 | topics
546 | }
547 | }, TripletFields.All)
548 |     GraphImpl(newGraph.vertices.mapValues(t => null), newGraph.edges)
549 |
550 | }
551 |
552 | private def tokenSampling(
553 | gen: Random,
554 | t: GibbsAliasTable,
555 | tSum: Float,
556 | w: GibbsAliasTable,
557 | termTopicCounter: VD,
558 | wSum: Float,
559 | docTopicCounter: VD,
560 | dData: Array[Float],
561 | currentTopic: Int): Int = {
562 | val used = docTopicCounter.activeSize
563 |     val dSum = dData(used - 1)
564 | val distSum = tSum + wSum + dSum
565 | val genSum = gen.nextFloat() * distSum
566 | if (genSum < dSum) {
567 | val dGenSum = gen.nextFloat() * dSum
568 | val pos = GibbsLDAUtils.binarySearchInterval(dData, dGenSum, 0, used, true)
569 | docTopicCounter.indexAt(pos)
570 | } else if (genSum < (dSum + wSum)) {
571 | sampleSV(gen, w, termTopicCounter, currentTopic)
572 | } else {
573 | t.sampleAlias(gen)
574 | }
575 | }
576 |
577 | /**
578 | * dense part in the decomposed sampling formula:
579 | * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta})
580 | * ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
581 | */
582 | private def tDense(
583 | totalTopicCounter: BDV[Count],
584 | numTokens: Long,
585 | numTerms: Int,
586 | alpha: Float,
587 | alphaAS: Float,
588 | beta: Float): (Float, BDV[Float]) = {
589 | val numTopics = totalTopicCounter.length
590 | val t = BDV.zeros[Float](numTopics)
591 | val alphaSum = alpha * numTopics
592 | val termSum = numTokens - 1F + alphaAS * numTopics
593 | val betaSum = numTerms * beta
594 | var sum = 0.0f
595 | for (topic <- 0 until numTopics) {
596 | val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) /
597 | ((totalTopicCounter(topic) + betaSum) * termSum)
598 | t(topic) = last
599 | sum += last
600 | }
601 | (sum, t)
602 | }
603 |
604 | /**
605 | * word related sparse part in the decomposed sampling formula:
606 | * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta})
607 | * ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
608 | */
609 | private def wSparse(
610 | totalTopicCounter: BDV[Count],
611 | termTopicCounter: VD,
612 | numTokens: Long,
613 | numTerms: Int,
614 | alpha: Float,
615 | alphaAS: Float,
616 | beta: Float): (Float, BSV[Float]) = {
617 | val numTopics = totalTopicCounter.length
618 | val alphaSum = alpha * numTopics
619 | val termSum = numTokens - 1F + alphaAS * numTopics
620 | val betaSum = numTerms * beta
621 | val w = BSV.zeros[Float](numTopics)
622 | var sum = 0.0f
623 | termTopicCounter.activeIterator.filter(_._2 > 0).foreach { t =>
624 | val topic = t._1
625 | val count = t._2
626 | val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) /
627 | ((totalTopicCounter(topic) + betaSum) * termSum)
628 | w(topic) = last
629 | sum += last
630 | }
631 | (sum, w)
632 | }
633 |
634 | /**
635 | * doc related sparse part in the decomposed sampling formula:
636 | * d = \frac{{n}_{kd} ^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})}
637 | * {({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
638 | * = \frac{{n}_{kd} ^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta}) }
639 | */
640 | private def dSparse(
641 | totalTopicCounter: BDV[Count],
642 | termTopicCounter: VD,
643 | docTopicCounter: VD,
644 | d: Array[Float],
645 | currentTopic: Int,
646 | numTokens: Long,
647 | numTerms: Int,
648 | alpha: Float,
649 | alphaAS: Float,
650 | beta: Float): Unit = {
651 | val data = docTopicCounter.data
652 | val used = docTopicCounter.activeSize
653 |
654 | // val termSum = numTokens - 1D + alphaAS * numTopics
655 | val betaSum = numTerms * beta
656 | var sum = 0.0f
657 | var i = 0
658 | docTopicCounter.activeIterator.filter(_._2 > 0).foreach { t =>
659 | val topic = t._1
660 | val count = t._2
661 | val adjustment = if (currentTopic == topic) -1F else 0
662 | val last = (count + adjustment) * (termTopicCounter(topic) + adjustment + beta) /
663 | (totalTopicCounter(topic) + adjustment + betaSum)
664 | // val lastD = (count + adjustment) * termSum * (termTopicCounter(topic) + adjustment + beta) /
665 | // ((totalTopicCounter(topic) + adjustment + betaSum) * termSum)
666 | sum += last
667 | d(i) = sum
668 | i += 1
669 | }
670 | }
671 |
672 | private def wordTable(
673 | table: GibbsAliasTable,
674 | totalTopicCounter: BDV[Count],
675 | termTopicCounter: VD,
676 | termId: VertexId,
677 | numTokens: Long,
678 | numTerms: Int,
679 | alpha: Float,
680 | alphaAS: Float,
681 | beta: Float): Float = {
682 | val sv = wSparse(totalTopicCounter, termTopicCounter,
683 | numTokens, numTerms, alpha, alphaAS, beta)
684 | GibbsAliasTable.generateAlias(sv._2, sv._1, table)
685 | sv._1
686 | }
687 |
688 | private def sampleSV(
689 | gen: Random,
690 | table: GibbsAliasTable,
691 | sv: VD,
692 | currentTopic: Int,
693 | currentTopicCounter: Int = 0,
694 | numSampling: Int = 0): Int = {
695 | val docTopic = table.sampleAlias(gen)
696 | if (docTopic == currentTopic && numSampling < 16) {
697 | val svCounter = if (currentTopicCounter == 0) sv(currentTopic) else currentTopicCounter
698 |       // TODO: verify that this rejection rule is correct
699 |       // resample if the newly sampled topic equals the current topic
700 |       if ((svCounter == 1 && table.used > 1) ||
701 |         /* the current topic contains only this token, but the table offers other topics */
702 |         (svCounter > 1 && gen.nextDouble() < 1.0 / svCounter)
703 |         /* the current topic also contains other tokens; the draw hits this token with probability 1/svCounter */ ) {
704 | return sampleSV(gen, table, sv, currentTopic, svCounter, numSampling + 1)
705 | }
706 | }
707 | docTopic
708 | }
709 | }
710 |
711 | class GibbsLDALightSampler extends GibbsLDASampler with Logging with Serializable {
712 |
713 | private def sampleSV(
714 | gen: Random,
715 | table: GibbsAliasTable,
716 | sv: VD,
717 | currentTopic: Int,
718 | currentTopicCounter: Int = 0,
719 | numSampling: Int = 0): Int = {
720 | val docTopic = table.sampleAlias(gen)
721 | if (docTopic == currentTopic && numSampling < 16) {
722 | val svCounter = if (currentTopicCounter == 0) sv(currentTopic) else currentTopicCounter
723 |       // TODO: verify that this rejection rule is correct
724 |       // resample if the newly sampled topic equals the current topic
725 |       if ((svCounter == 1 && table.used > 1) ||
726 |         /* the current topic contains only this token, but the table offers other topics */
727 |         (svCounter > 1 && gen.nextDouble() < 1.0 / svCounter)
728 |         /* the current topic also contains other tokens; the draw hits this token with probability 1/svCounter */ ) {
729 | return sampleSV(gen, table, sv, currentTopic, svCounter, numSampling + 1)
730 | }
731 | }
732 | docTopic
733 | }
734 |
735 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED],
736 | totalTopicCounter: BDV[Count],
737 | innerIter: Long,
738 | numTokens: Long,
739 | numTopics: Int,
740 | numTerms: Int,
741 | alpha: Float,
742 | alphaAS: Float,
743 | beta: Float): Graph[VD, ED] = {
744 | val parts = graph.edges.partitions.length
745 |     val newGraph = graph.mapTriplets(
746 | (pid, iter) => {
747 | val gen = new Random(parts * innerIter + pid)
748 | val docTableCache = new AppendOnlyMap[VertexId, SoftReference[(Float, GibbsAliasTable)]]()
749 |
750 |         // the alias table is a per-term data structure;
751 |         // in GraphX, edges within a partition are clustered by source ID (the term id here),
752 |         // so the simple cache below avoids recomputing the table for every edge
753 | val lastTable = new GibbsAliasTable(numTopics.toInt)
754 | var lastVid: VertexId = -1
755 | var lastWSum = 0.0f
756 |
757 | val p = tokenTopicProb(totalTopicCounter, beta, alpha,
758 | alphaAS, numTokens, numTerms) _
759 | val dPFun = docProb(totalTopicCounter, alpha, alphaAS, numTokens) _
760 | val wPFun = wordProb(totalTopicCounter, numTerms, beta) _
761 |
762 | var dD: GibbsAliasTable = null
763 | var dDSum: Float = 0.0f
764 | var wD: GibbsAliasTable = null
765 | var wDSum: Float = 0.0f
766 |
767 | iter.map {
768 | triplet =>
769 | val termId = triplet.srcId
770 | val docId = triplet.dstId
771 | val termTopicCounter = triplet.srcAttr
772 | val docTopicCounter = triplet.dstAttr
773 | val topics = triplet.attr.clone()
774 |
775 | if (dD == null || gen.nextDouble() < 1e-6) {
776 | var dv = dDense(totalTopicCounter, alpha, alphaAS, numTokens)
777 | dDSum = dv._1
778 | dD = GibbsAliasTable.generateAlias(dv._2, dDSum)
779 |
780 | dv = wDense(totalTopicCounter, numTerms, beta)
781 | wDSum = dv._1
782 | wD = GibbsAliasTable.generateAlias(dv._2, wDSum)
783 | }
784 | val (dSum, d) = docTopicCounter.synchronized {
785 | docTable(x => x == null || x.get() == null || gen.nextDouble() < 1e-2,
786 | docTableCache, docTopicCounter, docId)
787 | }
788 | val (wSum, w) = termTopicCounter.synchronized {
789 | if (lastVid != termId || gen.nextDouble() < 1e-4) {
790 | lastWSum = wordTable(lastTable, totalTopicCounter, termTopicCounter, termId, numTerms, beta)
791 | lastVid = termId
792 | }
793 | (lastWSum, lastTable)
794 | }
795 | for (i <- topics.indices) {
796 | var docProposal = gen.nextDouble() < 0.5
797 | var maxSampling = 8
798 | while (maxSampling > 0) {
799 | maxSampling -= 1
800 | docProposal = !docProposal
801 | val currentTopic = topics(i)
802 | var proposalTopic = -1
803 | val q = if (docProposal) {
804 | if (gen.nextFloat() < dDSum / (dSum - 1.0f + dDSum)) {
805 | proposalTopic = dD.sampleAlias(gen)
806 | }
807 | else {
808 | proposalTopic = docTopicCounter.synchronized {
809 | sampleSV(gen, d, docTopicCounter, currentTopic)
810 | }
811 | }
812 | dPFun
813 | } else {
814 | val table = if (gen.nextDouble() < wSum / (wSum + wDSum)) w else wD
815 | proposalTopic = table.sampleAlias(gen)
816 | wPFun
817 | }
818 |
819 | val newTopic = docTopicCounter.synchronized {
820 | termTopicCounter.synchronized {
821 | tokenSampling(gen, docTopicCounter, termTopicCounter, docProposal,
822 | currentTopic, proposalTopic, q, p)
823 | }
824 | }
825 |
826 | assert(newTopic >= 0 && newTopic < numTopics)
827 | if (newTopic != currentTopic) {
828 | topics(i) = newTopic
829 | docTopicCounter.synchronized {
830 | docTopicCounter(currentTopic) -= 1
831 | docTopicCounter(newTopic) += 1
832 | }
833 | termTopicCounter.synchronized {
834 | termTopicCounter(currentTopic) -= 1
835 | termTopicCounter(newTopic) += 1
836 | }
837 | totalTopicCounter(currentTopic) -= 1
838 | totalTopicCounter(newTopic) += 1
839 | }
840 | }
841 | }
842 | topics
843 | }
844 | }, TripletFields.All)
845 |     GraphImpl(newGraph.vertices.mapValues(t => null), newGraph.edges)
846 | }
847 |
848 | // scalastyle:off
849 | private def tokenTopicProb(
850 | totalTopicCounter: BDV[Count],
851 | beta: Float,
852 | alpha: Float,
853 | alphaAS: Float,
854 | numTokens: Long,
855 | numTerms: Int)(docTopicCounter: VD,
856 | termTopicCounter: VD,
857 | topic: Int,
858 | isAdjustment: Boolean): Float = {
859 | val numTopics = docTopicCounter.length
860 | val adjustment = if (isAdjustment) -1 else 0
861 | val ratio = (totalTopicCounter(topic) + adjustment + alphaAS) /
862 | (numTokens - 1 + alphaAS * numTopics)
863 | val asPrior = ratio * (alpha * numTopics)
864 | // constant terms are removed (docLen - 1 + alpha * numTopics)
865 | (termTopicCounter(topic) + adjustment + beta) *
866 | (docTopicCounter(topic) + adjustment + asPrior) /
867 | (totalTopicCounter(topic) + adjustment + (numTerms * beta))
868 |
869 | // original formula: Rethinking LDA: Why Priors Matter formula (3)
870 | // val docLen = brzSum(docTopicCounter)
871 | // (termTopicCounter(topic) + adjustment + beta) * (docTopicCounter(topic) + adjustment + asPrior) /
872 | // ((totalTopicCounter(topic) + adjustment + (numTerms * beta)) * (docLen - 1 + alpha * numTopics))
873 | }
874 |
875 | // scalastyle:on
876 |
877 | private def wordProb(
878 | totalTopicCounter: BDV[Count],
879 | numTerms: Int,
880 | beta: Float)(termTopicCounter: VD, topic: Int, isAdjustment: Boolean): Float = {
881 | (termTopicCounter(topic) + beta) / (totalTopicCounter(topic) + beta * numTerms)
882 | }
883 |
884 | private def docProb(
885 | totalTopicCounter: BDV[Count],
886 | alpha: Float,
887 | alphaAS: Float,
888 | numTokens: Long)(docTopicCounter: VD, topic: Int, isAdjustment: Boolean): Float = {
889 | val adjustment = if (isAdjustment) -1 else 0
890 | val numTopics = totalTopicCounter.length
891 | val ratio = (totalTopicCounter(topic) + alphaAS) /
892 | (numTokens - 1 + alphaAS * numTopics)
893 | val asPrior = ratio * (alpha * numTopics)
894 | docTopicCounter(topic) + adjustment + asPrior
895 | }
896 |
897 | /**
898 | * \frac{{n}_{kw}}{{n}_{k}+\bar{\beta}}
899 | */
900 | private def wSparse(
901 | totalTopicCounter: BDV[Count],
902 | termTopicCounter: VD,
903 | numTerms: Int,
904 | beta: Float): (Float, BV[Float]) = {
905 | val numTopics = termTopicCounter.length
906 | val termSum = beta * numTerms
907 | val w = BSV.zeros[Float](numTopics)
908 |
909 | var sum = 0.0f
910 | termTopicCounter.activeIterator.foreach { t =>
911 | val topic = t._1
912 | val count = t._2
913 | if (count > 0) {
914 | val last = count / (totalTopicCounter(topic) + termSum)
915 | w(topic) = last
916 | sum += last
917 | }
918 | }
919 | (sum, w)
920 | }
921 |
922 | /**
923 | * \frac{{\beta}_{w}}{{n}_{k}+\bar{\beta}}
924 | */
925 | private def wDense(
926 | totalTopicCounter: BDV[Count],
927 | numTerms: Int,
928 | beta: Float): (Float, BV[Float]) = {
929 | val numTopics = totalTopicCounter.length
930 | val t = BDV.zeros[Float](numTopics)
931 | val termSum = beta * numTerms
932 | var sum = 0.0f
933 | for (topic <- 0 until numTopics) {
934 | val last = beta / (totalTopicCounter(topic) + termSum)
935 | t(topic) = last
936 | sum += last
937 | }
938 | (sum, t)
939 | }
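// A minimal sketch of how the word proposal is split across wSparse and wDense above: per topic,
// n_kw / (n_k + betaSum) + beta / (n_k + betaSum) = (n_kw + beta) / (n_k + betaSum), so drawing from
// the sparse table with probability wSum / (wSum + wDSum), and from the dense table otherwise, is
// equivalent to drawing from the full proposal. The counts below are toy assumptions.
object WordProposalSketch {
  def main(args: Array[String]): Unit = {
    val termTopicCounter = Array(3f, 0f, 1f)    // n_kw for one term
    val totalTopicCounter = Array(9f, 6f, 5f)   // n_k
    val beta = 0.01f
    val betaSum = 4 * beta                      // numTerms * beta, with a toy vocabulary of 4 terms
    val sparsePart = termTopicCounter.indices.map(k => termTopicCounter(k) / (totalTopicCounter(k) + betaSum))
    val densePart = totalTopicCounter.map(nk => beta / (nk + betaSum))
    val fullProposal = termTopicCounter.indices.map(k => (termTopicCounter(k) + beta) / (totalTopicCounter(k) + betaSum))
    termTopicCounter.indices.foreach(k => assert(math.abs(sparsePart(k) + densePart(k) - fullProposal(k)) < 1e-6f))
    println(s"wSum = ${sparsePart.sum}, wDSum = ${densePart.sum}")
  }
}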
940 |
941 | private def dSparse(docTopicCounter: VD): (Float, BV[Float]) = {
942 | val numTopics = docTopicCounter.length
943 | val d = BSV.zeros[Float](numTopics)
944 | var sum = 0.0f
945 | docTopicCounter.activeIterator.foreach { t =>
946 | val topic = t._1
947 | val count = t._2
948 | if (count > 0) {
949 | val last = count
950 | d(topic) = last
951 | sum += last
952 | }
953 | }
954 | (sum, d)
955 | }
956 |
957 |
958 | private def dDense(
959 | totalTopicCounter: BDV[Count],
960 | alpha: Float,
961 | alphaAS: Float,
962 | numTokens: Long): (Float, BV[Float]) = {
963 | val numTopics = totalTopicCounter.length
964 | val asPrior = BDV.zeros[Float](numTopics)
965 |
966 | var sum = 0.0f
967 | for (topic <- 0 until numTopics) {
968 | val ratio = (totalTopicCounter(topic) + alphaAS) /
969 | (numTokens - 1 + alphaAS * numTopics)
970 | val last = ratio * (alpha * numTopics)
971 | asPrior(topic) = last
972 | sum += last
973 | }
974 | (sum, asPrior)
975 | }
976 |
977 | private def docTable(
978 | updateFunc: SoftReference[(Float, GibbsAliasTable)] => Boolean,
979 | cacheMap: AppendOnlyMap[VertexId, SoftReference[(Float, GibbsAliasTable)]],
980 | docTopicCounter: VD,
981 | docId: VertexId): (Float, GibbsAliasTable) = {
982 | val cacheD = cacheMap(docId)
983 | if (!updateFunc(cacheD)) {
984 | cacheD.get
985 | } else {
986 | docTopicCounter.synchronized {
987 | val sv = dSparse(docTopicCounter)
988 | val d = (sv._1, GibbsAliasTable.generateAlias(sv._2, sv._1))
989 | cacheMap.update(docId, new SoftReference(d))
990 | d
991 | }
992 | }
993 | }
994 |
995 | private def wordTable(
996 | table: GibbsAliasTable,
997 | totalTopicCounter: BDV[Count],
998 | termTopicCounter: VD,
999 | termId: VertexId,
1000 | numTerms: Int,
1001 | beta: Float): Float = {
1002 | val sv = wSparse(totalTopicCounter, termTopicCounter, numTerms, beta)
1003 | GibbsAliasTable.generateAlias(sv._2, sv._1, table)
1004 | sv._1
1005 | }
1006 |
1007 | // scalastyle:off
1008 | /**
1009 |    * Uses both a Gibbs sampler and a Metropolis-Hastings sampler.
1010 |    * Per-token sampling complexity is O(1).
1011 |    * 1. Draw a proposal topic from the word-related (or doc-related) part of the LDA conditional,
1012 |    *    following LightLDA: Big Topic Models on Modest Compute Clusters, formula (6):
1013 |    *    ( \frac{{n}_{kw}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta }} )
1014 |    * 2. Use the resulting distribution as the proposal q(.) in a Metropolis-Hastings acceptance
1015 |    *    step against the true conditional; reference paper: Rethinking LDA: Why Priors Matter, formula (3):
1016 | * \frac{{n}_{kw}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta}} \frac{{n}_{kd} ^{-di}+ \bar{\alpha} \frac{{n}_{k}^{-di} + \acute{\alpha}}{\sum{n}_{k} +\bar{\acute{\alpha}}}}{\sum{n}_{kd}^{-di} +\bar{\alpha}}
1017 | *
1018 | * where
1019 | * \bar{\beta}=\sum_{w}{\beta}_{w}
1020 | * \bar{\alpha}=\sum_{k}{\alpha}_{k}
1021 |    * \bar{\acute{\alpha}}=\sum_{k}\acute{\alpha}
1022 | * {n}_{kd} number of tokens in document d that are assigned to topic k
1023 | * {n}_{kw} number of tokens with word w (across all docs) that are assigned to topic k
1024 | * {n}_{k} number of tokens across all docs that are assigned to topic k
1025 | */
1026 | // scalastyle:on
1027 | def tokenSampling(
1028 | gen: Random,
1029 | docTopicCounter: VD,
1030 | termTopicCounter: VD,
1031 | docProposal: Boolean,
1032 | currentTopic: Int,
1033 | proposalTopic: Int,
1034 | q: (VD, Int, Boolean) => Float,
1035 | p: (VD, VD, Int, Boolean) => Float): Int = {
1036 | if (proposalTopic == currentTopic) return proposalTopic
1037 | val cp = p(docTopicCounter, termTopicCounter, currentTopic, true)
1038 | val np = p(docTopicCounter, termTopicCounter, proposalTopic, false)
1039 | val vd = if (docProposal) docTopicCounter else termTopicCounter
1040 | val cq = q(vd, currentTopic, true)
1041 | val nq = q(vd, proposalTopic, false)
1042 |
1043 | val pi = (np * cq) / (cp * nq)
1044 | if (gen.nextDouble() < 1e-32) {
1045 | println(s"Pi: ${pi}")
1046 | println(s"($np * $cq) / ($cp * $nq)")
1047 | }
1048 |
1049 | if (gen.nextDouble() < math.min(1.0, pi)) proposalTopic else currentTopic
1050 | }
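// A minimal sketch of the Metropolis-Hastings acceptance rule applied by tokenSampling above:
// pi = (p(new) * q(current)) / (p(current) * q(new)), and the proposal is accepted with probability
// min(1, pi). The four probabilities below are toy assumptions, not values taken from the model.
object MHAcceptSketch {
  def main(args: Array[String]): Unit = {
    val gen = new java.util.Random(7L)
    val (cp, np) = (0.10, 0.30)   // true (unnormalized) conditional at the current and proposed topic
    val (cq, nq) = (0.20, 0.25)   // proposal density at the current and proposed topic
    val pi = (np * cq) / (cp * nq)
    val accepted = gen.nextDouble() < math.min(1.0, pi)
    println(s"pi = $pi, accepted = $accepted")
  }
}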
1051 | }
1052 |
1053 | class GibbsLDASparseSampler extends GibbsLDASampler with Logging with Serializable {
1054 | def sampleTokens(graph: Graph[GibbsLDAOptimizer.VD, ED],
1055 | totalTopicCounter: BDV[Count],
1056 | innerIter: Long,
1057 | numTokens: Long,
1058 | numTopics: Int,
1059 | numTerms: Int,
1060 | alpha: Float,
1061 | alphaAS: Float,
1062 | beta: Float): Graph[VD, ED] = {
1063 | val parts = graph.edges.partitions.size
1064 | val newGraph = graph.mapTriplets(
1065 | (pid, iter) => {
1066 | val gen = new XORShiftRandom(parts * innerIter + pid)
1067 | val d = BDV.zeros[Float](numTopics)
1068 | var lastTermId:VertexId = -1
1069 | var lastWordTable:BSV[Float] = null
1070 | var tCache: BDV[Float] = null
1071 | iter.map {
1072 | triplet =>
1073 | val termId = triplet.srcId
1074 | val termTopicCounter = triplet.srcAttr
1075 | val docTopicCounter = triplet.dstAttr
1076 | val topics = triplet.attr.clone()
1077 | termTopicCounter.synchronized {
1078 | docTopicCounter.synchronized {
1079 | if (lastTermId != termId || gen.nextDouble() < 1e-4) {
1080 | lastWordTable = w(totalTopicCounter, termTopicCounter, termId,
1081 | numTokens, numTerms, alpha, beta, alphaAS)
1082 | lastTermId = termId
1083 | }
1084 | if (tCache == null || gen.nextDouble() < 1e-7) {
1085 | tCache = this.t(totalTopicCounter, numTokens, numTerms,
1086 | numTopics, alpha, beta, alphaAS)
1087 | }
1088 | val t = tCache
1089 | var i = 0
1090 | while (i < topics.length) {
1091 | val currentTopic = topics(i)
1092 | this.d(totalTopicCounter, termTopicCounter, docTopicCounter, d,
1093 | currentTopic, numTokens, numTerms, numTopics, beta, alphaAS)
1094 | val newTopic = multinomialDistSampler(gen, docTopicCounter, d, lastWordTable, t)
1095 | if (currentTopic != newTopic) {
1096 | topics(i) = newTopic
1097 | docTopicCounter(currentTopic) -= 1
1098 | docTopicCounter(newTopic) += 1
1099 | //if (docTopicCounter(currentTopic) == 0) docTopicCounter.compact()
1100 |
1101 | termTopicCounter(currentTopic) -= 1
1102 | termTopicCounter(newTopic) += 1
1103 | //if (termTopicCounter(currentTopic) == 0) termTopicCounter.compact()
1104 |
1105 | totalTopicCounter(currentTopic) -= 1
1106 | totalTopicCounter(newTopic) += 1
1107 | }
1108 | i += 1
1109 | }
1110 | }
1111 | }
1112 | topics
1113 | }
1114 | }, TripletFields.All)
1115 | GraphImpl(newGraph.vertices.mapValues(t => null), newGraph.edges)
1116 | }
1117 |
1118 | private def w(
1119 | totalTopicCounter: BDV[Count],
1120 | termTopicCounter: VD,
1121 | termId: VertexId,
1122 | numTokens: Long,
1123 | numTerms: Int,
1124 | alpha: Float,
1125 | beta: Float,
1126 | alphaAS: Float): BSV[Float] = {
1127 | val numTopics = totalTopicCounter.length
1128 | val alphaSum = alpha * numTopics
1129 | val termSum = numTokens - 1F + alphaAS * numTopics
1130 | val betaSum = numTerms * beta
1131 | val length = termTopicCounter.length
1132 | val used = termTopicCounter.activeSize
1133 | val index =
1134 | termTopicCounter match {
1135 | case idx:BSV[_] => idx.index.slice(0, used)
1136 | case idx:BDV[_] => (0 until idx.length).toArray
1137 | }
1138 | val data = termTopicCounter.data
1139 | val w = new Array[Float](used)
1140 |
1141 | var lastSum = 0F
1142 | var i = 0
1143 |
1144 | while (i < used) {
1145 | val topic = index(i)
1146 | val count = data(i)
1147 | val lastW = count * alphaSum * (totalTopicCounter(topic) + alphaAS) /
1148 | ((totalTopicCounter(topic) + betaSum) * termSum)
1149 | lastSum += lastW
1150 | w(i) = lastSum
1151 | i += 1
1152 | }
1153 | new BSV[Float](index, w, used, length)
1154 | }
1155 |
1156 | private def t(
1157 | totalTopicCounter: BDV[Count],
1158 | numTokens: Long,
1159 | numTerms: Int,
1160 | numTopics: Int,
1161 | alpha: Float,
1162 | beta: Float,
1163 | alphaAS: Float): BDV[Float] = {
1164 | val t = BDV.zeros[Float](numTopics)
1165 | val alphaSum = alpha * numTopics
1166 | val termSum = numTokens - 1F + alphaAS * numTopics
1167 | val betaSum = numTerms * beta
1168 |
1169 | var lastSum = 0F
1170 | for (topic <- 0 until numTopics) {
1171 | val lastT = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) /
1172 | ((totalTopicCounter(topic) + betaSum) * termSum)
1173 | lastSum += lastT
1174 | t(topic) = lastSum
1175 | }
1176 | t
1177 | }
1178 |
1179 | private def d(
1180 | totalTopicCounter: BDV[Count],
1181 | termTopicCounter: VD,
1182 | docTopicCounter: VD,
1183 | d: BDV[Float],
1184 | currentTopic: Int,
1185 | numTokens: Long,
1186 | numTerms: Int,
1187 | numTopics: Int,
1188 | beta: Float,
1189 | alphaAS: Float): Unit = {
1190 | val used = docTopicCounter.activeSize
1191 | val index =
1192 | docTopicCounter match {
1193 | case idx:BSV[_] => idx.index
1194 | case idx:BDV[_] => (0 until idx.length).toArray
1195 | }
1196 | val data = docTopicCounter.data
1197 |
1198 | // val termSum = numTokens - 1D + alphaAS * numTopics
1199 | val betaSum = numTerms * beta
1200 | var i = 0
1201 | var lastSum = 0F
1202 |
1203 | while (i < used) {
1204 | val topic = index(i)
1205 | val count: Float = data(i)
1206 | val adjustment = if (currentTopic == topic) -1F else 0
1207 | // val lastD = count * termSum * (termTopicCounter(topic) + beta) /
1208 | // ((totalTopicCounter(topic) + betaSum) * termSum)
1209 |
1210 | val lastD = (count + adjustment) * (termTopicCounter(topic) + adjustment + beta) /
1211 | (totalTopicCounter(topic) + adjustment + betaSum)
1212 | lastSum += lastD
1213 | d(topic) = lastSum
1214 | i += 1
1215 | }
1216 | d(numTopics - 1) = lastSum
1217 | }
1218 |
1219 | /**
1220 |    * A multinomial distribution sampler that uses the roulette-wheel (inverse CDF) method to draw a topic index.
1221 | */
1222 | private def multinomialDistSampler(
1223 | gen: Random,
1224 | docTopicCounter: VD,
1225 | d: BDV[Float],
1226 | w: BSV[Float],
1227 | t: BDV[Float]): Int = {
1228 | val numTopics = d.length
1229 | val lastSum = t(numTopics - 1) + w.data(w.used - 1) + d(numTopics - 1)
1230 | val distSum = gen.nextFloat() * lastSum
1231 | val fun = index(docTopicCounter, d, w, t) _
1232 | val topic = binarySearchInterval[Float](fun, distSum, 0, numTopics, true)
1233 | math.min(topic, numTopics - 1)
1234 | }
1235 |
1236 | private def index(
1237 | docTopicCounter: VD,
1238 | d: BDV[Float],
1239 | w: BSV[Float],
1240 | t: BDV[Float])(i: Int) = {
1241 | val lastDS = binarySearchDenseVector(i, docTopicCounter, d)
1242 | val lastWS = binarySearchSparseVector(i, w)
1243 | val lastTS = t(i)
1244 | lastDS + lastWS + lastTS
1245 | }
1246 |
1247 | private[topicModeling] def binarySearchInterval[K](
1248 | index: Int => K,
1249 | key: K,
1250 | begin: Int,
1251 | end: Int,
1252 | greater: Boolean)(implicit ord: Ordering[K], ctag: ClassTag[K]): Int = {
1253 | if (begin == end) {
1254 | return if (greater) end else begin - 1
1255 | }
1256 | var b = begin
1257 | var e = end - 1
1258 |
1259 | var mid: Int = (e + b) >> 1
1260 | while (b <= e) {
1261 | mid = (e + b) >> 1
1262 | val v = index(mid)
1263 | if (ord.lt(v, key)) {
1264 | b = mid + 1
1265 | }
1266 | else if (ord.gt(v, key)) {
1267 | e = mid - 1
1268 | }
1269 | else {
1270 | return mid
1271 | }
1272 | }
1273 |
1274 | val v = index(mid)
1275 | mid = if ((greater && ord.gteq(v, key)) || (!greater && ord.lteq(v, key))) {
1276 | mid
1277 | }
1278 | else if (greater) {
1279 | mid + 1
1280 | }
1281 | else {
1282 | mid - 1
1283 | }
1284 |
1285 | if (greater) {
1286 | if (mid < end) assert(ord.gteq(index(mid), key))
1287 | if (mid > 0) assert(ord.lteq(index(mid - 1), key))
1288 | } else {
1289 | if (mid > 0) assert(ord.lteq(index(mid), key))
1290 | if (mid < end - 1) assert(ord.gteq(index(mid + 1), key))
1291 | }
1292 | mid
1293 | }
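// A brief usage sketch for binarySearchInterval above, on a toy prefix-sum array (values assumed):
// with greater = true it returns the first position whose value is >= key, which is what sampling
// from a cumulative distribution needs; with greater = false it returns the last position <= key.
//
//   val cumulative = Array(0.3f, 0.8f, 1.0f)
//   binarySearchInterval[Float](i => cumulative(i), 0.5f, 0, cumulative.length, greater = true)   // returns 1
//   binarySearchInterval[Float](i => cumulative(i), 0.5f, 0, cumulative.length, greater = false)  // returns 0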
1294 |
1295 | private[topicModeling] def binarySearchArray[K](
1296 | index: Array[K],
1297 | key: K,
1298 | begin: Int,
1299 | end: Int,
1300 | greater: Boolean)(implicit ord: Ordering[K], ctag: ClassTag[K]): Int = {
1301 | binarySearchInterval(i => index(i), key, begin, end, greater)
1302 | }
1303 |
1304 | private[topicModeling] def binarySearchSparseVector(topic: Int, sv: BSV[Float]) = {
1305 | val index = sv.index
1306 | val used = sv.used
1307 | val data = sv.data
1308 | val pos = binarySearchArray(index, topic, 0, used, false)
1309 | if (pos > -1) data(pos) else 0F
1310 | }
1311 |
1312 | private[topicModeling] def binarySearchDenseVector[V](
1313 | topic: Int,
1314 | sv: StorageVector[V],
1315 | dv: BDV[Float]): Float = {
1316 | val index =
1317 | sv match {
1318 | case sv:BDV[_] => (0 until sv.length).toArray
1319 | case sv:BSV[_] => sv.index
1320 | }
1321 | val used = sv.activeSize
1322 | val pos = binarySearchArray(index, topic, 0, used, false)
1323 | if (pos > -1) dv(index(pos)) else 0F
1324 | }
1325 | }
1326 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/mllib/topicModeling/LDA.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.mllib.topicModeling
19 |
20 | import org.apache.spark.Logging
21 | import org.apache.spark.annotation.Experimental
22 | import org.apache.spark.api.java.JavaPairRDD
23 | import org.apache.spark.mllib.linalg.Vector
24 | import org.apache.spark.rdd.RDD
25 |
26 |
27 | /**
28 | * :: Experimental ::
29 | *
30 | * Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
31 | *
32 | * Terminology:
33 | * - "word" = "term": an element of the vocabulary
34 | * - "token": instance of a term appearing in a document
35 | * - "topic": multinomial distribution over words representing some concept
36 | *
37 |  * Adapted from the MLlib LDA implementation; it supports both online LDA and Gibbs sampling LDA
38 | */
39 | @Experimental
40 | class LDA private (
41 | private var k: Int,
42 | private var maxIterations: Int,
43 | private var docConcentration: Double,
44 | private var topicConcentration: Double,
45 | private var seed: Long,
46 | private var checkpointInterval: Int,
47 | private var ldaOptimizer: LDAOptimizer) extends Logging {
48 |
49 | def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
50 | seed = new scala.util.Random().nextLong(), checkpointInterval = 10, ldaOptimizer = new OnlineLDAOptimizer)
51 |
52 | /**
53 | * Number of topics to infer. I.e., the number of soft cluster centers.
54 | */
55 | def getK: Int = k
56 |
57 | /**
58 | * Number of topics to infer. I.e., the number of soft cluster centers.
59 | * (default = 10)
60 | */
61 | def setK(k: Int): this.type = {
62 | require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k")
63 | this.k = k
64 | this
65 | }
66 |
67 | /**
68 | * Concentration parameter (commonly named "alpha") for the prior placed on documents'
69 | * distributions over topics ("theta").
70 | *
71 | * This is the parameter to a symmetric Dirichlet distribution.
72 | */
73 | def getDocConcentration: Double = this.docConcentration
74 |
75 | /**
76 | * Concentration parameter (commonly named "alpha") for the prior placed on documents'
77 | * distributions over topics ("theta").
78 | *
79 | * This is the parameter to a symmetric Dirichlet distribution, where larger values
80 | * mean more smoothing (more regularization).
81 | *
82 | * If set to -1, then docConcentration is set automatically.
83 | * (default = -1 = automatic)
84 | *
85 | * Optimizer-specific parameter settings:
86 | * - Online
87 | * - Value should be >= 0
88 | * - default = (1.0 / k), following the implementation from
89 | * [[https://github.com/Blei-Lab/onlineldavb]].
90 | */
91 | def setDocConcentration(docConcentration: Double): this.type = {
92 | this.docConcentration = docConcentration
93 | this
94 | }
95 |
96 | /** Alias for [[getDocConcentration]] */
97 | def getAlpha: Double = getDocConcentration
98 |
99 | /** Alias for [[setDocConcentration()]] */
100 | def setAlpha(alpha: Double): this.type = setDocConcentration(alpha)
101 |
102 | /**
103 | * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
104 | * distributions over terms.
105 | *
106 | * This is the parameter to a symmetric Dirichlet distribution.
107 | *
108 | * Note: The topics' distributions over terms are called "beta" in the original LDA paper
109 | * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
110 | */
111 | def getTopicConcentration: Double = this.topicConcentration
112 |
113 | /**
114 | * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
115 | * distributions over terms.
116 | *
117 | * This is the parameter to a symmetric Dirichlet distribution.
118 | *
119 | * Note: The topics' distributions over terms are called "beta" in the original LDA paper
120 | * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
121 | *
122 | * If set to -1, then topicConcentration is set automatically.
123 | * (default = -1 = automatic)
124 | *
125 | * Optimizer-specific parameter settings:
126 | * - Online
127 | * - Value should be >= 0
128 | * - default = (1.0 / k), following the implementation from
129 | * [[https://github.com/Blei-Lab/onlineldavb]].
130 | */
131 | def setTopicConcentration(topicConcentration: Double): this.type = {
132 | this.topicConcentration = topicConcentration
133 | this
134 | }
135 |
136 | /** Alias for [[getTopicConcentration]] */
137 | def getBeta: Double = getTopicConcentration
138 |
139 | /** Alias for [[setTopicConcentration()]] */
140 | def setBeta(beta: Double): this.type = setTopicConcentration(beta)
141 |
142 | /**
143 | * Maximum number of iterations for learning.
144 | */
145 | def getMaxIterations: Int = maxIterations
146 |
147 | /**
148 | * Maximum number of iterations for learning.
149 | * (default = 20)
150 | */
151 | def setMaxIterations(maxIterations: Int): this.type = {
152 | this.maxIterations = maxIterations
153 | this
154 | }
155 |
156 | /** Random seed */
157 | def getSeed: Long = seed
158 |
159 | /** Random seed */
160 | def setSeed(seed: Long): this.type = {
161 | this.seed = seed
162 | this
163 | }
164 |
165 | /**
166 | * Period (in iterations) between checkpoints.
167 | */
168 | def getCheckpointInterval: Int = checkpointInterval
169 |
170 | /**
171 | * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery
172 | * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be
173 | * important when LDA is run for many iterations. If the checkpoint directory is not set in
174 | * [[org.apache.spark.SparkContext]], this setting is ignored.
175 | *
176 | * @see [[org.apache.spark.SparkContext#setCheckpointDir]]
177 | */
178 | def setCheckpointInterval(checkpointInterval: Int): this.type = {
179 | this.checkpointInterval = checkpointInterval
180 | this
181 | }
182 |
183 |
184 | /** LDAOptimizer used to perform the actual calculation */
185 | def getOptimizer: LDAOptimizer = ldaOptimizer
186 |
187 | /**
188 |    * LDAOptimizer used to perform the actual calculation (default = OnlineLDAOptimizer)
189 | */
190 | def setOptimizer(optimizer: LDAOptimizer): this.type = {
191 | this.ldaOptimizer = optimizer
192 | this
193 | }
194 |
195 | /**
196 | * Set the LDAOptimizer used to perform the actual calculation by algorithm name.
197 | * Only "online", "gibbs" are supported.
198 | */
199 | def setOptimizer(optimizerName: String): this.type = {
200 | this.ldaOptimizer =
201 | optimizerName.toLowerCase match {
202 | case "online" => new OnlineLDAOptimizer
203 | case "gibbs" => new GibbsLDAOptimizer()
204 | case other =>
205 | throw new IllegalArgumentException(s"Only online, gibbs are supported but got $other.")
206 | }
207 | this
208 | }
209 |
210 | /**
211 | * Learn an LDA model using the given dataset.
212 | *
213 | * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
214 | * The term count vectors are "bags of words" with a fixed-size vocabulary
215 | * (where the vocabulary size is the length of the vector).
216 | * Document IDs must be unique and >= 0.
217 | * @return Inferred LDA model
218 | */
219 | def run(documents: RDD[(Long, Vector)]): LDAModel = {
220 | val state = ldaOptimizer.initialize(documents, this)
221 | var iter = 0
222 | val iterationTimes = Array.fill[Double](maxIterations)(0)
223 | while (iter < maxIterations) {
224 | logInfo(s"Starting iteration $iter/$maxIterations")
225 | val start = System.nanoTime()
226 | state.next()
227 | val elapsedSeconds = (System.nanoTime() - start) / 1e9
228 | iterationTimes(iter) = elapsedSeconds
229 | iter += 1
230 | }
231 | state.getLDAModel(iterationTimes)
232 | }
233 |
234 | /** Java-friendly version of [[run()]] */
235 | def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
236 | run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
237 | }
238 | }
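// A minimal usage sketch for the LDA API above; the tiny in-memory corpus and local[*] master are
// assumptions for illustration, any RDD[(Long, Vector)] of term-count vectors works the same way.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.topicModeling.LDA

object LDAUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("LDAUsageSketch").setMaster("local[*]"))
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 2.0, 1.0))),
      (1L, Vectors.sparse(4, Array(1, 2), Array(1.0, 3.0)))))
    val model = new LDA()
      .setK(2)
      .setMaxIterations(5)
      .setOptimizer("gibbs") // or "online"
      .run(corpus)
    model.describeTopics(maxTermsPerTopic = 3).zipWithIndex.foreach { case ((terms, weights), i) =>
      println(s"topic $i: ${terms.zip(weights).mkString(", ")}")
    }
    sc.stop()
  }
}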
239 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/mllib/topicModeling/LDAExample.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.mllib.topicModeling
19 |
20 | import java.text.BreakIterator
21 |
22 | import org.apache.log4j.{Level, Logger}
23 | import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
24 | import org.apache.spark.rdd.RDD
25 | import org.apache.spark.{SparkConf, SparkContext}
26 | import scopt.OptionParser
27 |
28 | import scala.collection.mutable
29 | import scala.reflect.runtime.universe._
30 |
31 | /**
32 | * Abstract class for parameter case classes.
33 | * This overrides the [[toString]] method to print all case class fields by name and value.
34 | * @tparam T Concrete parameter class.
35 | */
36 | abstract class AbstractParams[T: TypeTag] {
37 |
38 | private def tag: TypeTag[T] = typeTag[T]
39 |
40 | /**
41 | * Finds all case class fields in concrete class instance, and outputs them in JSON-style format:
42 | * {
43 | * [field name]:\t[field value]\n
44 | * [field name]:\t[field value]\n
45 | * ...
46 | * }
47 | */
48 | override def toString: String = {
49 | val tpe = tag.tpe
50 | val allAccessors = tpe.declarations.collect {
51 | case m: MethodSymbol if m.isCaseAccessor => m
52 | }
53 | val mirror = runtimeMirror(getClass.getClassLoader)
54 | val instanceMirror = mirror.reflect(this)
55 | allAccessors.map { f =>
56 | val paramName = f.name.toString
57 | val fieldMirror = instanceMirror.reflectField(f)
58 | val paramValue = fieldMirror.get
59 | s" $paramName:\t$paramValue"
60 | }.mkString("{\n", ",\n", "\n}")
61 | }
62 | }
63 |
64 | /**
65 | * An example Latent Dirichlet Allocation (LDA) app. Run with
66 | * {{{
67 | * ./bin/run-example mllib.LDAExample [options]
68 | * }}}
69 | * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
70 | */
71 | object LDAExample {
72 |
73 | private case class Params(
74 | input: Seq[String] = Seq.empty,
75 | k: Int = 20,
76 | maxIterations: Int = 10,
77 | maxInnerIterations: Int = 5,
78 | docConcentration: Double = 0.01,
79 | topicConcentration: Double = 0.01,
80 | vocabSize: Int = 10000,
81 | stopwordFile: String = "",
82 | checkpointDir: Option[String] = None,
83 | checkpointInterval: Int = 10,
84 |     optimizer:String = "online",
85 | gibbsSampler:String = "alias",
86 | gibbsAlphaAS:Double = 0.1,
87 | gibbsPrintPerplexity:Boolean = false,
88 | gibbsEdgePartitioner:String = "none",
89 | partitions:Int = 2,
90 | logLevel:String = "info",
91 | psMasterAddr:String = null
92 | ) extends AbstractParams[Params]
93 |
94 | def main(args: Array[String]) {
95 | val defaultParams = Params()
96 |
97 | val parser = new OptionParser[Params]("LDAExample") {
98 | head("LDAExample: an example LDA app for plain text data.")
99 | opt[Int]("k")
100 | .text(s"number of topics. default: ${defaultParams.k}")
101 | .action((x, c) => c.copy(k = x))
102 | opt[Int]("maxIterations")
103 | .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}")
104 | .action((x, c) => c.copy(maxIterations = x))
105 | opt[Int]("maxInnerIterations")
106 | .text(s"number of inner iterations of learning. default: ${defaultParams.maxInnerIterations}")
107 | .action((x, c) => c.copy(maxInnerIterations = x))
108 | opt[Double]("docConcentration")
109 | .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." +
110 | s" default: ${defaultParams.docConcentration}")
111 | .action((x, c) => c.copy(docConcentration = x))
112 | opt[Double]("topicConcentration")
113 | .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." +
114 | s" default: ${defaultParams.topicConcentration}")
115 | .action((x, c) => c.copy(topicConcentration = x))
116 | opt[Int]("vocabSize")
117 | .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" +
118 | s" default: ${defaultParams.vocabSize}")
119 | .action((x, c) => c.copy(vocabSize = x))
120 | opt[String]("stopwordFile")
121 | .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." +
122 | s" default: ${defaultParams.stopwordFile}")
123 | .action((x, c) => c.copy(stopwordFile = x))
124 | opt[String]("checkpointDir")
125 | .text(s"Directory for checkpointing intermediate results." +
126 | s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." +
127 | s" default: ${defaultParams.checkpointDir}")
128 | .action((x, c) => c.copy(checkpointDir = Some(x)))
129 | opt[Int]("checkpointInterval")
130 | .text(s"Iterations between each checkpoint. Only used if checkpointDir is set." +
131 | s" default: ${defaultParams.checkpointInterval}")
132 | .action((x, c) => c.copy(checkpointInterval = x))
133 | opt[String]("optimizer")
134 |         .text(s"available optimizers are online and gibbs, default: ${defaultParams.optimizer}")
135 | .action((x, c) => c.copy(optimizer = x))
136 | opt[String]("gibbs.sampler")
137 | .text(s"sampler for gibbs optimizer, available options are alias, sparse, light and fast, default: ${defaultParams.gibbsSampler}")
138 | .action((x, c) => c.copy(gibbsSampler = x))
139 | opt[Double]("gibbs.alphaAS")
140 | .text(s"alphaAS for gibbs optimizer, default: ${defaultParams.gibbsAlphaAS}")
141 | .action((x, c) => c.copy(gibbsAlphaAS = x))
142 | opt[Boolean]("gibbs.printPerplexity")
143 | .text(s"print perplexity for gibbs optimizer, default: ${defaultParams.gibbsPrintPerplexity}")
144 | .action((x, c) => c.copy(gibbsPrintPerplexity = x))
145 | opt[String]("gibbs.edgePartitioner")
146 |         .text(s"edge partitioner for gibbs optimizer, available options are none, degree, and even, default: ${defaultParams.gibbsEdgePartitioner}")
147 | .action((x, c) => c.copy(gibbsEdgePartitioner = x))
148 | opt[Int]("partitions")
149 | .text(s"Minimum edge partitions, default: ${defaultParams.partitions}")
150 | .action((x, c) => c.copy(partitions = x))
151 | opt[String]("logLevel")
152 | .text(s"Log level, default: ${defaultParams.logLevel}")
153 | .action((x, c) => c.copy(logLevel = x))
154 | opt[String]("psMasterAddr")
155 | .text(s"psMaster address, default: ${defaultParams.psMasterAddr}")
156 | .action((x, c) => c.copy(psMasterAddr = x))
157 | arg[String]("...")
158 | .text("input paths (directories) to plain text corpora." +
159 | " Each text file line should hold 1 document.")
160 | .unbounded()
161 | .required()
162 | .action((x, c) => c.copy(input = c.input :+ x))
163 | }
164 |
165 | parser.parse(args, defaultParams).map { params =>
166 | run(params)
167 | }.getOrElse {
168 | parser.showUsageAsError
169 | sys.exit(1)
170 | }
171 | }
172 |
173 | // private def createOptimizer(params: Params, lineRdd: RDD[Int], columnRdd: RDD[Int]):LDAOptimizer = {
174 | private def createOptimizer(params: Params):LDAOptimizer = {
175 | params.optimizer match {
176 | case "online" => val optimizer = new OnlineLDAOptimizer
177 | optimizer
178 | case "gibbs" =>
179 | val optimizer = new GibbsLDAOptimizer
180 | optimizer.setSampler(params.gibbsSampler)
181 | optimizer.printPerplexity = params.gibbsPrintPerplexity
182 | optimizer.edgePartitioner = params.gibbsEdgePartitioner
183 | optimizer.setAlphaAS(params.gibbsAlphaAS.toFloat)
184 | optimizer
185 | case _ =>
186 |         throw new IllegalArgumentException(s"available optimizers are online and gibbs, but got ${params.optimizer}")
187 | }
188 | }
189 |
190 | /**
191 | * run LDA
192 | * @param params
193 | */
194 | private def run(params: Params) {
195 | val conf = new SparkConf().setAppName(s"LDAExample with $params")
196 | val sc = new SparkContext(conf)
197 |
198 | val logLevel = Level.toLevel(params.logLevel, Level.INFO)
199 | Logger.getRootLogger.setLevel(logLevel)
200 | println(s"Setting log level to $logLevel")
201 |
202 | // Load documents, and prepare them for LDA.
203 | val preprocessStart = System.nanoTime()
204 | val (corpus, actualCorpusSize, vocabArray, actualNumTokens) =
205 | preprocess(sc, params.input, params.vocabSize, params.partitions, params.stopwordFile)
206 | val actualVocabSize = vocabArray.size
207 | val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9
208 |
209 | println()
210 | println(s"Corpus summary:")
211 | println(s"\t Training set size: $actualCorpusSize documents")
212 | println(s"\t Vocabulary size: $actualVocabSize terms")
213 |     println(s"\t Number of tokens: $actualNumTokens tokens")
214 | println(s"\t Preprocessing time: $preprocessElapsed sec")
215 | println()
216 |
217 | // Run LDA.
218 | val lda = new LDA()
219 | lda.setK(params.k)
220 | .setMaxIterations(params.maxIterations)
221 | .setDocConcentration(params.docConcentration)
222 | .setTopicConcentration(params.topicConcentration)
223 | .setCheckpointInterval(params.checkpointInterval)
224 | .setOptimizer(createOptimizer(params))
225 |
226 | if (params.checkpointDir.nonEmpty) {
227 | sc.setCheckpointDir(params.checkpointDir.get)
228 | }
229 |
230 | val startTime = System.nanoTime()
231 | val ldaModel = lda.run(corpus)
232 | val elapsed = (System.nanoTime() - startTime) / 1e9
233 |
234 |     println(s"Finished training LDA model using ${lda.getOptimizer.getClass.getName} in $elapsed sec")
235 |
236 | // Print the topics, showing the top-weighted terms for each topic.
237 | val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
238 |
239 | val topics = topicIndices.map { case (terms, termWeights) =>
240 | terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) }
241 | }
242 | println(s"${params.k} topics:")
243 | topics.zipWithIndex.foreach { case (topic, i) =>
244 | println(s"TOPIC $i")
245 | topic.foreach { case (term, weight) =>
246 | println(s"$term\t$weight")
247 | }
248 | println()
249 | }
250 | sc.stop()
251 | }
252 |
253 | /**
254 | * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors.
255 |    * @return (corpus, corpus size in documents, vocabulary as array, total token count in corpus)
256 | */
257 | private def preprocess(
258 | sc: SparkContext,
259 | paths: Seq[String],
260 | vocabSize: Int,
261 | partitions:Int,
262 | stopwordFile: String): (RDD[(Long, Vector)], Long, Array[String], Long) = {
263 |
264 | // Get dataset of document texts
265 | // One document per line in each text file. If the input consists of many small files,
266 | // this can result in a large number of small partitions, which can degrade performance.
267 | // In this case, consider using coalesce() to create fewer, larger partitions.
268 | val textRDD: RDD[String] = sc.textFile(paths.mkString(","), Math.max(1, partitions)).coalesce(partitions)
269 |
270 |
271 | // Split text into words
272 | val tokenizer = new SimpleTokenizer(sc, stopwordFile)
273 | val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
274 | id -> tokenizer.getWords(text)
275 | }
276 | tokenized.cache()
277 |
278 | // Counts words: RDD[(word, wordCount)]
279 | val wordCounts: RDD[(String, Long)] = tokenized
280 | .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
281 | .reduceByKey(_ + _)
282 | wordCounts.cache()
283 | val fullVocabSize = wordCounts.count()
284 | // Select vocab
285 | // (vocab: Map[word -> id], total tokens after selecting vocab)
286 | val (vocab: Map[String, Int], selectedTokenCount: Long) = {
287 | val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
288 | // Use all terms
289 | wordCounts.collect().sortBy(-_._2)
290 | } else {
291 | // Sort terms to select vocab
292 | wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
293 | }
294 | (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
295 | }
296 |
297 | val documents = tokenized.map { case (id, tokens) =>
298 | // Filter tokens by vocabulary, and create word count vector representation of document.
299 | val wc = new mutable.HashMap[Int, Int]()
300 | tokens.foreach { term =>
301 | if (vocab.contains(term)) {
302 | val termIndex = vocab(term)
303 | wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
304 | }
305 | }
306 | val indices = wc.keys.toArray.sorted
307 | val values = indices.map(i => wc(i).toDouble)
308 |
309 | val sb = Vectors.sparse(vocab.size, indices, values)
310 | (id, sb)
311 | }.filter(_._2.asInstanceOf[SparseVector].values.length > 0).cache
312 | val corpusSize = documents.count
313 | tokenized.unpersist(false)
314 |
315 | val vocabArray = new Array[String](vocab.size)
316 | vocab.foreach { case (term, i) => vocabArray(i) = term }
317 |
318 | (documents, corpusSize, vocabArray, selectedTokenCount)
319 | }
320 | }
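// A minimal sketch of the per-document conversion performed inside preprocess above: tokens are
// filtered by the vocabulary, counted, and packed into a sparse term-count vector. The tiny
// vocabulary and token list below are toy assumptions.
object BagOfWordsSketch {
  def main(args: Array[String]): Unit = {
    import scala.collection.mutable
    import org.apache.spark.mllib.linalg.Vectors
    val vocab = Map("spark" -> 0, "topic" -> 1, "model" -> 2)    // word -> index
    val tokens = Seq("spark", "topic", "topic", "unknownword")   // one tokenized document
    val wc = new mutable.HashMap[Int, Int]()
    tokens.foreach(t => vocab.get(t).foreach(i => wc(i) = wc.getOrElse(i, 0) + 1))
    val indices = wc.keys.toArray.sorted
    val values = indices.map(i => wc(i).toDouble)
    println(Vectors.sparse(vocab.size, indices, values))         // (3,[0,1],[1.0,2.0])
  }
}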
321 |
322 | /**
323 | * Simple Tokenizer.
324 | *
325 | * TODO: Formalize the interface, and make this a public class in mllib.feature
326 | */
327 | private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {
328 |
329 | private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
330 | Set.empty[String]
331 | } else {
332 | val stopwordText = sc.textFile(stopwordFile).collect()
333 | stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
334 | }
335 |
336 | // Matches sequences of Unicode letters
337 | private val allWordRegex = "^(\\p{L}*)$".r
338 |
339 | // Ignore words shorter than this length.
340 | private val minWordLength = 3
341 |
342 | def getWords(text: String): IndexedSeq[String] = {
343 |
344 | val words = new mutable.ArrayBuffer[String]()
345 |
346 | // Use Java BreakIterator to tokenize text into words.
347 | val wb = BreakIterator.getWordInstance
348 | wb.setText(text)
349 |
350 | // current,end index start,end of each word
351 | var current = wb.first()
352 | var end = wb.next()
353 | while (end != BreakIterator.DONE) {
354 | // Convert to lowercase
355 | val word: String = text.substring(current, end).toLowerCase
356 | // Remove short words and strings that aren't only letters
357 | word match {
358 | case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
359 | words += w
360 | case _ =>
361 | }
362 |
363 | current = end
364 | try {
365 | end = wb.next()
366 | } catch {
367 | case e: Exception =>
368 | // Ignore remaining text in line.
369 | // This is a known bug in BreakIterator (for some Java versions),
370 | // which fails when it sees certain characters.
371 | end = BreakIterator.DONE
372 | }
373 | }
374 | words
375 | }
376 |
377 | }
378 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/mllib/topicModeling/LDAModel.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.mllib.topicModeling
19 |
20 | import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
21 | import org.apache.spark.annotation.Experimental
22 | import org.apache.spark.graphx.{EdgeContext, Graph, VertexId}
23 | import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
24 | import org.apache.spark.rdd.RDD
25 | import scala.collection.mutable.ArrayBuffer
26 |
27 | /**
28 | * :: Experimental ::
29 | *
30 | * Latent Dirichlet Allocation (LDA) model.
31 | *
32 |  * This abstraction permits different underlying representations,
33 |  * including local and distributed data structures.
34 |  *
35 |  * Adapted from the MLlib LDAModel implementation.
36 | */
37 | @Experimental
38 | abstract class LDAModel {
39 |
40 | /** Number of topics */
41 | def k: Int
42 |
43 |   /** Vocabulary size (number of terms in the vocabulary) */
44 | def vocabSize: Int
45 |
46 | /**
47 | * Inferred topics, where each topic is represented by a distribution over terms.
48 | * This is a matrix of size vocabSize x k, where each column is a topic.
49 | * No guarantees are given about the ordering of the topics.
50 | */
51 | def topicsMatrix: Matrix
52 |
53 | /**
54 | * Return the topics described by weighted terms.
55 | *
56 | * This limits the number of terms per topic.
57 | * This is approximate; it may not return exactly the top-weighted terms for each topic.
58 | * To get a more precise set of top terms, increase maxTermsPerTopic.
59 | *
60 | * @param maxTermsPerTopic Maximum number of terms to collect for each topic.
61 | * @return Array over topics. Each topic is represented as a pair of matching arrays:
62 | * (term indices, term weights in topic).
63 | * Each topic's terms are sorted in order of decreasing weight.
64 | */
65 | def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])]
66 |
67 | /**
68 |    *
69 | * Return the topics described by weighted terms.
70 | *
71 | * WARNING: If vocabSize and k are large, this can return a large object!
72 | *
73 | * @return Array over topics. Each topic is represented as a pair of matching arrays:
74 | * (term indices, term weights in topic).
75 | * Each topic's terms are sorted in order of decreasing weight.
76 | */
77 | def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize)
78 |
79 | /* TODO (once LDA can be trained with Strings or given a dictionary)
80 | * Return the topics described by weighted terms.
81 | *
82 | * This is similar to [[describeTopics()]] but returns String values for terms.
83 | * If this model was trained using Strings or was given a dictionary, then this method returns
84 | * terms as text. Otherwise, this method returns terms as term indices.
85 | *
86 | * This limits the number of terms per topic.
87 | * This is approximate; it may not return exactly the top-weighted terms for each topic.
88 | * To get a more precise set of top terms, increase maxTermsPerTopic.
89 | *
90 | * @param maxTermsPerTopic Maximum number of terms to collect for each topic.
91 | * @return Array over topics. Each topic is represented as a pair of matching arrays:
92 | * (terms, term weights in topic) where terms are either the actual term text
93 | * (if available) or the term indices.
94 | * Each topic's terms are sorted in order of decreasing weight.
95 | */
96 | // def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])]
97 |
98 | /* TODO (once LDA can be trained with Strings or given a dictionary)
99 | * Return the topics described by weighted terms.
100 | *
101 | * This is similar to [[describeTopics()]] but returns String values for terms.
102 | * If this model was trained using Strings or was given a dictionary, then this method returns
103 | * terms as text. Otherwise, this method returns terms as term indices.
104 | *
105 | * WARNING: If vocabSize and k are large, this can return a large object!
106 | *
107 | * @return Array over topics. Each topic is represented as a pair of matching arrays:
108 | * (terms, term weights in topic) where terms are either the actual term text
109 | * (if available) or the term indices.
110 | * Each topic's terms are sorted in order of decreasing weight.
111 | */
112 | // def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] =
113 | // describeTopicsAsStrings(vocabSize)
114 |
115 | /* TODO
116 | * Compute the log likelihood of the observed tokens, given the current parameter estimates:
117 | * log P(docs | topics, topic distributions for docs, alpha, eta)
118 | *
119 | * Note:
120 | * - This excludes the prior.
121 | * - Even with the prior, this is NOT the same as the data log likelihood given the
122 | * hyperparameters.
123 | *
124 | * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
125 | * The term count vectors are "bags of words" with a fixed-size vocabulary
126 | * (where the vocabulary size is the length of the vector).
127 | * This must use the same vocabulary (ordering of term counts) as in training.
128 | * Document IDs must be unique and >= 0.
129 | * @return Estimated log likelihood of the data under this model
130 | */
131 | // def logLikelihood(documents: RDD[(Long, Vector)]): Double
132 |
133 | /* TODO
134 | * Compute the estimated topic distribution for each document.
135 | * This is often called 'theta' in the literature.
136 | *
137 | * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
138 | * The term count vectors are "bags of words" with a fixed-size vocabulary
139 | * (where the vocabulary size is the length of the vector).
140 | * This must use the same vocabulary (ordering of term counts) as in training.
141 | * Document IDs must be unique and >= 0.
142 | * @return Estimated topic distribution for each document.
143 | * The returned RDD may be zipped with the given RDD, where each returned vector
144 | * is a multinomial distribution over topics.
145 | */
146 | // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)]
147 |
148 | }
149 |
150 | /**
151 | * :: Experimental ::
152 | *
153 | * Local LDA model.
154 | * This model stores only the inferred topics.
155 | *
156 | * @param topics Inferred topics (vocabSize x k matrix).
157 | */
158 | @Experimental
159 | class LocalLDAModel private[topicModeling] (
160 | private val topics: Matrix) extends LDAModel with Serializable {
161 |
162 | override def k: Int = topics.numCols
163 |
164 | override def vocabSize: Int = topics.numRows
165 |
166 | override def topicsMatrix: Matrix = topics
167 |
168 | override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
169 | val brzTopics = topics.toBreeze.toDenseMatrix
170 | Range(0, k).map { topicIndex =>
171 | val topic = normalize(brzTopics(::, topicIndex), 1.0)
172 | val (termWeights, terms) =
173 | topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip
174 | (terms.toArray, termWeights.toArray)
175 | }.toArray
176 | }
177 |
178 | // TODO
179 | // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
180 |
181 | // TODO:
182 | // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
183 |
184 | }
185 |
186 | /**
187 | * :: Experimental ::
188 | *
189 |  * Online LDA model with an interface supporting prediction
190 |  * of per-document topic distributions.
191 | * @param topics Inferred topics (vocabSize x k matrix).
192 | */
193 | @Experimental
194 | class OnlineLDAModel(
195 | private val topics: Matrix,
196 | private val alpha: Double,
197 | private val gammaShape: Double) extends LDAModel with Serializable {
198 |
199 | override def k: Int = topics.numCols
200 |
201 | override def vocabSize: Int = topics.numRows
202 |
203 | // vocabSize x k, where each column is a topic.
204 | override def topicsMatrix: Matrix = topics
205 |
206 | override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
207 | val brzTopics = topics.toBreeze.toDenseMatrix
208 | Range(0, k).map { topicIndex =>
209 | val topic = normalize(brzTopics(::, topicIndex), 1.0)
210 | val (termWeights, terms) =
211 | topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip
212 | (terms.toArray, termWeights.toArray)
213 | }.toArray
214 | }
215 |
216 | /**
217 |    * For each document in the given set, return the distribution over topics for that document
218 |    * ("theta_doc").
219 | *
220 | * @return RDD of (document ID, topic distribution) pairs
221 | */
222 | def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
223 | val (_, _, gammaArray) = OnlineLDAOptimizer.inference(
224 | k, vocabSize, (topics.transpose).toBreeze.toDenseMatrix, alpha, gammaShape, documents)
225 | val result = gammaArray.map(p => {
226 | (p._1, Vectors.fromBreeze(p._2))
227 | })
228 | result
229 | }
230 |
231 | /**
232 |    * For each document in the given set, return the distribution over topics for that document
233 |    * ("theta_doc").
234 | *
235 | * @return RDD of (document ID, topic distribution) pairs
236 | */
237 | def predict(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
238 | topicDistributions(documents)
239 | }
240 | }
241 |
--------------------------------------------------------------------------------
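A small sketch of how the describeTopics output might be rendered with a term dictionary such as the vocabArray built in LDAExample. The model is assumed to come from a prior training run; printTopics and the formatting are illustrative, not part of this repository.

import org.apache.spark.mllib.topicModeling.LDAModel

object DescribeTopicsSketch {
  // describeTopics returns, per topic, matching arrays of term indices and
  // weights, sorted by decreasing weight; map the indices back to terms.
  def printTopics(model: LDAModel, vocabArray: Array[String], maxTermsPerTopic: Int = 10): Unit = {
    model.describeTopics(maxTermsPerTopic).zipWithIndex.foreach {
      case ((termIndices, termWeights), topicId) =>
        val terms = termIndices.zip(termWeights)
          .map { case (idx, weight) => f"${vocabArray(idx)} ($weight%.4f)" }
          .mkString(", ")
        println(s"TOPIC $topicId: $terms")
    }
  }
}
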
/src/main/scala/org/apache/spark/mllib/topicModeling/LDAOptimizer.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.mllib.topicModeling
19 |
20 | import org.apache.spark.annotation.Experimental
21 | import org.apache.spark.mllib.linalg.Vector
22 | import org.apache.spark.rdd.RDD
23 |
24 | /**
25 | * :: Experimental ::
26 | *
27 | * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can
28 | * hold optimizer-specific parameters for users to set.
29 | */
30 | @Experimental
31 | trait LDAOptimizer {
32 |
33 | /*
34 | DEVELOPERS NOTE:
35 |
36 |     An LDAOptimizer contains an algorithm for LDA and performs the actual computation; it
37 |     stores the internal data structures (Graph or Matrix) and other parameters of the algorithm.
38 |     The interface is kept separate from LDA to improve its extensibility.
39 | */
40 |
41 | /**
42 |    * Initializer for the optimizer. LDA passes the common parameters to the optimizer so
43 |    * that the internal structures can be initialized properly.
44 | */
45 | private[topicModeling] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer
46 |
47 | /**
48 |    * Run one iteration of the algorithm.
49 |    * @return this optimizer, advanced by one iteration
50 | */
51 | private[topicModeling] def next(): LDAOptimizer
52 |
53 | /**
54 |    * Get the model as of the current iteration.
55 |    * @param iterationTimes wall-clock time (in seconds) spent on each iteration so far
56 |    * @return the LDAModel learned so far
57 | */
58 | private[topicModeling] def getLDAModel(iterationTimes: Array[Double]): LDAModel
59 | }
--------------------------------------------------------------------------------
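A sketch of the training loop this trait implies: initialize once, call next() for a fixed number of iterations, then extract the model. Because initialize/next/getLDAModel are private[topicModeling], the sketch is assumed to live in that package; the actual driver in this repository's LDA class may differ, and maxIterations plus the timing bookkeeping are illustrative.

package org.apache.spark.mllib.topicModeling

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

object OptimizerLoopSketch {
  def run(docs: RDD[(Long, Vector)], lda: LDA, optimizer: LDAOptimizer, maxIterations: Int): LDAModel = {
    // Hand the common LDA parameters to the optimizer so it can set up its
    // internal structures (Graph or Matrix, depending on the implementation).
    var state = optimizer.initialize(docs, lda)
    val iterationTimes = new Array[Double](maxIterations)
    var iter = 0
    while (iter < maxIterations) {
      val start = System.nanoTime()
      state = state.next()                            // one optimization step
      iterationTimes(iter) = (System.nanoTime() - start) / 1e9
      iter += 1
    }
    state.getLDAModel(iterationTimes)                 // model as of the last iteration
  }
}
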
/src/main/scala/org/apache/spark/mllib/topicModeling/OnlineHDP.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.mllib.topicModeling
2 |
3 | import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV, _}
4 | import breeze.numerics.{abs, digamma, exp, _}
5 | import org.apache.spark.mllib.linalg.{SparseVector, Vector}
6 | import org.apache.spark.rdd.RDD
7 |
8 | import scala.collection.mutable.ArrayBuffer
9 |
10 | class SuffStats(
11 | val T: Int,
12 | val Wt: Int,
13 | val m_chunksize: Int) extends Serializable {
14 | var m_var_sticks_ss = BDV.zeros[Double](T)
15 | var m_var_beta_ss = BDM.zeros[Double](T, Wt)
16 |
17 | def set_zero: Unit = {
18 | m_var_sticks_ss = BDV.zeros(T)
19 | m_var_beta_ss = BDM.zeros(T, Wt)
20 | }
21 | }
22 |
23 | object OnlineHDPOptimizer extends Serializable {
24 | val rhot_bound = 0.0
25 |
26 | def log_normalize(v: BDV[Double]): (BDV[Double], Double) = {
27 | val log_max = 100.0
28 | val max_val = v.toArray.max
29 | val log_shift = log_max - log(v.size + 1.0) - max_val
30 | val tot: Double = sum(exp(v + log_shift))
31 | val log_norm = log(tot) - log_shift
32 | (v - log_norm, log_norm)
33 | }
34 |
35 | def log_normalize(m: BDM[Double]): (BDM[Double], BDV[Double]) = {
36 | val log_max = 100.0
37 | // get max for every row
38 | val max_val: BDV[Double] = m(*, ::).map(v => max(v))
39 | val log_shift: BDV[Double] = log_max - log(m.cols + 1.0) - max_val
40 |
41 | val m_shift: BDM[Double] = exp(m(::, *) + log_shift)
42 | val tot: BDV[Double] = sum(m_shift(*, ::))
43 |
44 | val log_norm: BDV[Double] = log(tot) - log_shift
45 | (m(::, *) - log_norm, log_norm)
46 | }
47 |
48 | def expect_log_sticks(m: BDM[Double]): BDV[Double] = {
49 |     //
50 |     // For the stick-breaking HDP, return E[log(sticks)].
51 |     //
52 | val column = sum(m(::, *))
53 |
54 | val dig_sum: BDV[Double] = digamma(column.toDenseVector)
55 | val ElogW: BDV[Double] = digamma(m(0, ::).inner) - dig_sum
56 | val Elog1_W: BDV[Double] = digamma(m(1, ::).inner) - dig_sum
57 | //
58 | val n = m.cols + 1
59 | val Elogsticks = BDV.zeros[Double](n)
60 | Elogsticks(0 until n - 1) := ElogW(0 until n - 1)
61 | val cs = accumulate(Elog1_W)
62 |
63 | Elogsticks(1 to n - 1) := Elogsticks(1 to n - 1) + cs
64 | Elogsticks
65 | }
66 |
67 | /**
68 | * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
69 | * uses digamma which is accurate but expensive.
70 | */
71 | private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
72 | val rowSum = sum(alpha(breeze.linalg.*, ::))
73 | val digAlpha = digamma(alpha)
74 | val digRowSum = digamma(rowSum)
75 | val result = digAlpha(::, breeze.linalg.*) - digRowSum
76 | result
77 | }
78 |
79 | }
80 |
81 | /**
82 | * Implemented based on the paper "Online Variational Inference for the Hierarchical Dirichlet Process" (Chong Wang, John Paisley and David M. Blei)
83 | */
84 |
85 | class OnlineHDPOptimizer(
86 | val corpus: RDD[(Long, Vector)],
87 | val chunkSize: Int = 256,
88 | val m_kappa: Double = 1.0,
89 | var m_tau: Double = 64.0,
90 | val m_K: Int = 15,
91 | val m_T: Int = 150,
92 | val m_alpha: Double = 1,
93 | val m_gamma: Double = 1,
94 | val m_eta: Double = 0.01,
95 | val m_scale: Double = 1.0,
96 | val m_var_converge: Double = 0.0001,
97 | val iterations: Int = 10
98 | ) extends Serializable {
99 |
100 |
101 | val lda_alpha: Double = 1D
102 | val lda_beta: Double = 1D
103 | val m_W: Int = corpus.first()._2.size
104 | var m_D: Double = corpus.count().toDouble
105 |
106 | val m_var_sticks = BDM.zeros[Double](2, m_T - 1) // 2 * T - 1
107 | m_var_sticks(0, ::) := 1.0
108 | m_var_sticks(1, ::) := new BDV[Double]((m_T - 1 to 1 by -1).map(_.toDouble).toArray).t
109 | var m_varphi_ss: BDV[Double] = BDV.zeros[Double](m_T) // T
110 |
111 | // T * W
112 | val m_lambda: BDM[Double] = BDM.rand(m_T, m_W) * (m_D.toDouble) * 100.0 / (m_T * m_W).toDouble - m_eta
113 |
114 | // T * W
115 | val m_Elogbeta = OnlineHDPOptimizer.dirichletExpectation(m_lambda + m_eta)
116 |
117 | m_tau = m_tau + 1
118 | var m_updatect = 0
119 | var m_status_up_to_date = true
120 |
121 | val m_timestamp: BDV[Int] = BDV.zeros[Int](m_W)
122 | val m_r = collection.mutable.MutableList[Double](0)
123 | var m_lambda_sum = sum(m_lambda(*, ::)) // row sum
124 |
125 | val rhot_bound = 0.0
126 |
127 |
128 | def update(docs: RDD[(Long, Vector)]): Unit = {
129 | for (i <- 1 to iterations) {
130 | val chunk = docs
131 | update_chunk(chunk)
132 | }
133 | }
134 |
135 |
136 | def update_chunk(chunk: RDD[(Long, Vector)], update: Boolean = true): (Double, Int) = {
137 | // Find the unique words in this chunk...
138 | val unique_words = scala.collection.mutable.Map[Int, Int]()
139 | val raw_word_list = ArrayBuffer[Int]()
140 | chunk.collect().foreach(doc => {
141 | doc._2.foreachActive { case (id, count) =>
142 | if (count > 0 && !unique_words.contains(id)) {
143 | unique_words += (id -> unique_words.size)
144 | raw_word_list.append(id)
145 | }
146 | }
147 | })
148 |
149 | val word_list = raw_word_list.toList
150 |
151 | val Wt = word_list.length // length of words in these documents
152 |
153 | // ...and do the lazy updates on the necessary columns of lambda
154 | // rw = np.array([self.m_r[t] for t in self.m_timestamp[word_list]])
155 | // self.m_lambda[:, word_list] *= np.exp(self.m_r[-1] - rw)
156 | // self.m_Elogbeta[:, word_list] = \
157 | // sp.psi(self.m_eta + self.m_lambda[:, word_list]) - \
158 | // sp.psi(self.m_W * self.m_eta + self.m_lambda_sum[:, np.newaxis])
159 |
160 |
161 | val rw: BDV[Double] = new BDV(word_list.map(id => m_timestamp(id)).map(t => m_r(t)).toArray)
162 |
163 | val exprw: BDV[Double] = exp(rw.map(d => m_r.last - d))
164 |
165 | val wordsMatrix = m_lambda(::, word_list).toDenseMatrix
166 | for (row <- 0 until wordsMatrix.rows) {
167 | wordsMatrix(row, ::) := (wordsMatrix(row, ::).t :* exprw).t
168 | }
169 | m_lambda(::, word_list) := wordsMatrix
170 |
171 | for (id <- word_list) {
172 | m_Elogbeta(::, id) := digamma(m_lambda(::, id) + m_eta) - digamma(m_lambda_sum + m_W * m_eta)
173 | }
174 |
175 | val ss = new SuffStats(m_T, Wt, chunk.count().toInt)
176 |
177 | val Elogsticks_1st: BDV[Double] = OnlineHDPOptimizer.expect_log_sticks(m_var_sticks) // global sticks
178 |
179 | // run variational inference on some new docs
180 | var score = 0.0
181 | var count = 0D
182 | chunk.collect().foreach(doc =>
183 | if (doc._2.size > 0) {
184 | val doc_word_ids = doc._2.asInstanceOf[SparseVector].indices
185 | val doc_word_counts = doc._2.asInstanceOf[SparseVector].values
186 | val dict = unique_words.toMap
187 | val wl = doc_word_ids.toList
188 |
189 | val doc_score = doc_e_step(doc, ss, Elogsticks_1st,
190 | word_list, dict, wl,
191 | new BDV[Double](doc_word_counts), m_var_converge)
192 | count += sum(doc_word_counts)
193 | score += doc_score
194 | }
195 | )
196 | if (update) {
197 | update_lambda(ss, word_list)
198 | }
199 |
200 | (score, count.toInt)
201 | }
202 |
203 |
204 | def update_lambda(sstats: SuffStats, word_list: List[Int]): Unit = {
205 | m_status_up_to_date = false
206 | // rhot will be between 0 and 1, and says how much to weight
207 | // the information we got from this mini-chunk.
208 | var rhot = m_scale * pow(m_tau + m_updatect, -m_kappa)
209 | if (rhot < rhot_bound)
210 | rhot = rhot_bound
211 |
212 | // Update appropriate columns of lambda based on documents.
213 | // T * Wt T * Wt T * Wt
214 | m_lambda(::, word_list) := (m_lambda(::, word_list).toDenseMatrix * (1 - rhot)) + sstats.m_var_beta_ss * rhot * m_D / sstats.m_chunksize.toDouble
215 | m_lambda_sum = (1 - rhot) * m_lambda_sum + sum(sstats.m_var_beta_ss(*, ::)) * rhot * m_D / sstats.m_chunksize.toDouble
216 |
217 | m_updatect += 1
218 | m_timestamp(word_list) := m_updatect
219 | m_r += (m_r.last + log(1 - rhot))
220 |
221 | // T
222 | m_varphi_ss = (1.0 - rhot) * m_varphi_ss + rhot * sstats.m_var_sticks_ss * m_D.toDouble / sstats.m_chunksize.toDouble
223 | // update top level sticks
224 | // 2 * T - 1
225 | m_var_sticks(0, ::) := (m_varphi_ss(0 to m_T - 2) + 1.0).t
226 | val var_phi_sum = flipud(m_varphi_ss(1 to m_varphi_ss.length - 1)) // T - 1
227 | m_var_sticks(1, ::) := (flipud(accumulate(var_phi_sum)) + m_gamma).t
228 |
229 | }
230 |
231 |
232 | def doc_e_step(doc: (Long, Vector),
233 | ss: SuffStats,
234 | Elogsticks_1st: BDV[Double],
235 | word_list: List[Int],
236 | unique_words: Map[Int, Int],
237 | doc_word_ids: List[Int],
238 | doc_word_counts: BDV[Double],
239 | var_converge: Double): Double = {
240 |
241 | val chunkids = doc_word_ids.map(id => unique_words(id))
242 |
243 | val Elogbeta_doc: BDM[Double] = m_Elogbeta(::, doc_word_ids).toDenseMatrix // T * Wt
244 | // very similar to the hdp equations, 2 * K - 1
245 | val v = BDM.zeros[Double](2, m_K - 1)
246 | v(0, ::) := 1.0
247 | v(1, ::) := m_alpha
248 |
249 | var Elogsticks_2nd = OnlineHDPOptimizer.expect_log_sticks(v)
250 |
251 | // back to the uniform
252 | var phi: BDM[Double] = BDM.ones[Double](doc_word_ids.size, m_K) * 1.0 / m_K.toDouble // Wt * K
253 |
254 | var likelihood = 0.0
255 | var old_likelihood = -1e200
256 |     var converge = 1.0 // updated below from the relative change in likelihood
257 | val eps = 1e-100
258 |
259 | var iter = 0
260 | val max_iter = 100
261 |
262 | var var_phi_out: BDM[Double] = BDM.ones[Double](1, 1)
263 |
264 |     // Second-level optimization is not supported yet; to be done in the future.
265 | while (iter < max_iter && (converge < 0.0 || converge > var_converge)) {
266 |
267 | // var_phi
268 | val (log_var_phi: BDM[Double], var_phi: BDM[Double]) =
269 | if (iter < 3) {
270 | val element = Elogbeta_doc.copy // T * Wt
271 | for (i <- 0 to element.rows - 1) {
272 | element(i, ::) := (element(i, ::).t :* doc_word_counts).t
273 | }
274 | var var_phi: BDM[Double] = phi.t * element.t // K * Wt * Wt * T => K * T
275 | val (log_var_phi, log_norm) = OnlineHDPOptimizer.log_normalize(var_phi)
276 | var_phi = exp(log_var_phi)
277 | (log_var_phi, var_phi)
278 | }
279 | else {
280 | val element = Elogbeta_doc.copy
281 | for (i <- 0 to element.rows - 1) {
282 | element(i, ::) := (element(i, ::).t :* doc_word_counts).t
283 | }
284 | val product: BDM[Double] = phi.t * element.t
285 | for (i <- 0 until product.rows) {
286 | product(i, ::) := (product(i, ::).t + Elogsticks_1st).t
287 | }
288 |
289 | var var_phi: BDM[Double] = product
290 | val (log_var_phi, log_norm) = OnlineHDPOptimizer.log_normalize(var_phi)
291 | var_phi = exp(log_var_phi)
292 | (log_var_phi, var_phi)
293 | }
294 |
295 | val (log_phi, log_norm) =
296 | // phi
297 | if (iter < 3) {
298 | phi = (var_phi * Elogbeta_doc).t
299 | val (log_phi, log_norm) = OnlineHDPOptimizer.log_normalize(phi)
300 | phi = exp(log_phi)
301 | (log_phi, log_norm)
302 | }
303 | else {
304 | // K * T T * Wt
305 | val product: BDM[Double] = (var_phi * Elogbeta_doc).t
306 | for (i <- 0 until product.rows) {
307 | product(i, ::) := (product(i, ::).t + Elogsticks_2nd).t
308 | }
309 | phi = product
310 | val (log_phi, log_norm) = OnlineHDPOptimizer.log_normalize(phi)
311 | phi = exp(log_phi)
312 | (log_phi, log_norm)
313 | }
314 |
315 |
316 | // v
317 | val phi_all = phi.copy
318 | for (i <- 0 until phi_all.cols) {
319 | phi_all(::, i) := (phi_all(::, i)) :* doc_word_counts
320 | }
321 |
322 |       val leading = phi_all(::, 0 until m_K - 1); v(0, ::) := (sum(leading(::, *)).toDenseVector + 1.0).t // 1 + column sums over the first K-1 sticks
323 | val selected = phi_all(::, 1 until m_K)
324 | val t_sum = sum(selected(::, *)).toDenseVector
325 | val phi_cum = flipud(t_sum)
326 | v(1, ::) := (flipud(accumulate(phi_cum)) + m_alpha).t
327 | Elogsticks_2nd = OnlineHDPOptimizer.expect_log_sticks(v)
328 |
329 | likelihood = 0.0
330 | // compute likelihood
331 |       // var_phi part / C in John's notation
332 |
333 | val diff = log_var_phi.copy
334 | for (i <- 0 to diff.rows - 1) {
335 | diff(i, ::) := (Elogsticks_1st :- diff(i, ::).t).t
336 | }
337 |
338 | likelihood += sum(diff :* var_phi)
339 |
340 |       // v part / v in John's notation; John's beta is alpha here
341 | val log_alpha = log(m_alpha)
342 | likelihood += (m_K - 1) * log_alpha
343 | val dig_sum = (digamma(sum(v(::, *)))).toDenseVector
344 | val vCopy = v.copy
345 | for (i <- 0 until v.cols) {
346 | vCopy(::, i) := BDV[Double](1.0, m_alpha) - vCopy(::, i)
347 | }
348 |
349 | val dv = digamma(v)
350 | for (i <- 0 until v.rows) {
351 | dv(i, ::) := dv(i, ::) - dig_sum.t
352 | }
353 |
354 | likelihood += sum(vCopy :* dv)
355 | likelihood -= sum(lgamma(sum(v(::, *)))) - sum(lgamma(v))
356 |
357 | // Z part
358 | val log_phiCopy = log_phi.copy
359 | for (i <- 0 until log_phiCopy.rows) {
360 | log_phiCopy(i, ::) := (Elogsticks_2nd - log_phiCopy(i, ::).t).t
361 | }
362 | likelihood += sum(log_phiCopy :* phi)
363 |
364 | // X part, the data part
365 | val Elogbeta_docCopy = Elogbeta_doc.copy
366 | for (i <- 0 until Elogbeta_docCopy.rows) {
367 | Elogbeta_docCopy(i, ::) := (Elogbeta_docCopy(i, ::).t :* doc_word_counts).t
368 | }
369 |
370 | likelihood += sum(phi.t :* (var_phi * Elogbeta_docCopy))
371 |
372 |       converge = (likelihood - old_likelihood) / abs(old_likelihood)
373 | old_likelihood = likelihood
374 |
375 | if (converge < -0.000001)
376 | println("likelihood is decreasing!")
377 |
378 | iter += 1
379 | var_phi_out = var_phi
380 | }
381 |
382 | // update the suff_stat ss
383 | // this time it only contains information from one doc
384 | val sumPhiOut = sum(var_phi_out(::, *))
385 | ss.m_var_sticks_ss += sumPhiOut.toDenseVector
386 |
387 | val phiCopy = phi.copy.t
388 | for (i <- 0 until phi.rows) {
389 | phiCopy(i, ::) := (phiCopy(i, ::).t :* doc_word_counts).t
390 | }
391 |
392 | val middleResult: BDM[Double] = var_phi_out.t * phiCopy
393 | for (i <- 0 until chunkids.size) {
394 | ss.m_var_beta_ss(::, chunkids(i)) := ss.m_var_beta_ss(::, chunkids(i)) + middleResult(::, i)
395 | }
396 |
397 | return likelihood
398 | }
399 |
400 | }
--------------------------------------------------------------------------------
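A quick numerical sketch of the stick-breaking helpers in OnlineHDPOptimizer: log_normalize turns unnormalized log weights into a normalized log distribution, and expect_log_sticks computes E[log(sticks)] from a 2 x (T-1) matrix of Beta variational parameters. The values below (T = 4) are illustrative only.

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
import org.apache.spark.mllib.topicModeling.OnlineHDPOptimizer

object StickBreakingSketch {
  def main(args: Array[String]): Unit = {
    // Normalize three unnormalized log weights; logNorm is the log of the normalizer.
    val (logProbs, logNorm) = OnlineHDPOptimizer.log_normalize(BDV(0.1, 1.5, -0.3))
    println(s"normalized log weights: $logProbs, log normalizer: $logNorm")

    // Row 0: first Beta parameter of each stick; row 1: second Beta parameter.
    val sticks = BDM((1.0, 1.0, 1.0),
                     (3.0, 2.0, 1.0))
    // Returns a length-T vector of E[log(stick weight)] for the T = 4 sticks.
    println(s"E[log sticks]: ${OnlineHDPOptimizer.expect_log_sticks(sticks)}")
  }
}
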
/src/main/scala/org/apache/spark/mllib/topicModeling/OnlineLDAOptimizer.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.mllib.topicModeling
2 |
3 | import java.util.Random
4 |
5 | import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum, max}
6 | import breeze.numerics.{abs, digamma, exp, log, lgamma}
7 | import breeze.stats.distributions.{Gamma, RandBasis}
8 | import org.apache.spark.mllib.linalg._
9 | import org.apache.spark.rdd.RDD
10 | import org.apache.spark.util.collection.AppendOnlyMap
11 | import scala.collection.mutable.ArrayBuffer
12 |
13 | /**
14 | * :: Experimental ::
15 | *
16 |  * An optimizer based on the MLlib OnlineLDAOptimizer that can also output the document-topic
17 |  * distribution.
18 |  *
19 |  * An early version of this implementation was merged into MLlib (PR #4419); several extensions (e.g., predict) have been added here.
20 | *
21 | */
22 | object OnlineLDAOptimizer {
23 | def inference(
24 | k: Int,
25 | vocabSize: Int,
26 | lambda: BDM[Double],
27 | alpha: Double,
28 | gammaShape: Double,
29 | batch: RDD[(Long, Vector)]):
30 | (BDM[Double], RDD[BDM[Double]], RDD[(Long, BDV[Double])]) = {
31 |
32 | val Elogbeta = dirichletExpectation(lambda)
33 | val expElogbeta = exp(Elogbeta)
34 |
35 | val statsAndgammaArray: (RDD[(BDM[Double], Array[(Long, BDV[Double])])]) =
36 | batch.mapPartitions { docs =>
37 | val stat = BDM.zeros[Double](k, vocabSize)
38 | val gammaList = new ArrayBuffer[(Long, BDV[Double])]
39 | docs.foreach { doc =>
40 | val termCounts = doc._2
41 | val (ids: List[Int], cts: Array[Double]) = termCounts match {
42 | case v: DenseVector => ((0 until v.size).toList, v.values)
43 | case v: SparseVector => (v.indices.toList, v.values)
44 | case v => throw new IllegalArgumentException("Online LDA does not support vector type "
45 | + v.getClass)
46 | }
47 |
48 | // Initialize the variational distribution q(theta|gamma) for the mini-batch
49 | var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
50 | var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K
51 | var expElogthetad = exp(Elogthetad) // 1 * K
52 | val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids
53 |
54 | var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids
55 | var meanchange = 1D
56 | val ctsVector = new BDV[Double](cts).t // 1 * ids
57 |
58 | // Iterate between gamma and phi until convergence
59 | while (meanchange > 1e-3) {
60 | val lastgamma = gammad
61 | // 1*K 1 * ids ids * k
62 | gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
63 | Elogthetad = digamma(gammad) - digamma(sum(gammad))
64 | expElogthetad = exp(Elogthetad)
65 | phinorm = expElogthetad * expElogbetad + 1e-100
66 | meanchange = sum(abs(gammad - lastgamma)) / k
67 | }
68 |
69 | gammaList += Tuple2(doc._1, gammad.t)
70 | val m1 = expElogthetad.t
71 | val m2 = (ctsVector / phinorm).t.toDenseVector
72 | var i = 0
73 | while (i < ids.size) {
74 | stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
75 | i += 1
76 | }
77 | }
78 | Iterator((stat, gammaList.toArray))
79 | }
80 |
81 | val stats = statsAndgammaArray.map(_._1)
82 | val gammaArray = statsAndgammaArray.flatMap(_._2)
83 |
84 | (expElogbeta, stats, gammaArray)
85 | }
86 |
87 | /**
88 | * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
89 | * uses digamma which is accurate but expensive.
90 | */
91 | def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
92 | val rowSum = sum(alpha(breeze.linalg.*, ::))
93 | val digAlpha = digamma(alpha)
94 | val digRowSum = digamma(rowSum)
95 | val result = digAlpha(::, breeze.linalg.*) - digRowSum
96 | result
97 | }
98 | }
99 |
100 | final class OnlineLDAOptimizer extends LDAOptimizer with Serializable {
101 | // LDA common parameters
102 | private var k: Int = 0
103 | private var corpusSize: Long = 0
104 | private var vocabSize: Int = 0
105 |
106 | /** alias for docConcentration */
107 | private var alpha: Double = 0
108 |
109 | /** (private[lda] for debugging) Get docConcentration */
110 | private[topicModeling] def getAlpha: Double = alpha
111 |
112 | /** alias for topicConcentration */
113 | private var eta: Double = 0
114 |
115 | /** (private[lda] for debugging) Get topicConcentration */
116 | private[topicModeling] def getEta: Double = eta
117 |
118 | private var randomGenerator: java.util.Random = null
119 |
120 | // Online LDA specific parameters
121 | // Learning rate is: (tau0 + t)^{-kappa}
122 | private var tau0: Double = 1024
123 | private var kappa: Double = 0.51
124 | private var miniBatchFraction: Double = 0.05
125 |
126 | // internal data structure
127 | private var docs: RDD[(Long, Vector)] = null
128 |
129 | /** Dirichlet parameter for the posterior over topics */
130 | private var lambda: BDM[Double] = null
131 |
132 | /** (private[lda] for debugging) Get parameter for topics */
133 | private[topicModeling] def getLambda: BDM[Double] = lambda
134 |
135 | /** Current iteration (count of invocations of [[next()]]) */
136 | private var iteration: Int = 0
137 | private var gammaShape: Double = 100
138 |
139 | /**
140 | * A (positive) learning parameter that downweights early iterations. Larger values make early
141 | * iterations count less.
142 | */
143 | def getTau0: Double = this.tau0
144 |
145 | /**
146 | * A (positive) learning parameter that downweights early iterations. Larger values make early
147 | * iterations count less.
148 | * Default: 1024, following the original Online LDA paper.
149 | */
150 | def setTau0(tau0: Double): this.type = {
151 | require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0")
152 | this.tau0 = tau0
153 | this
154 | }
155 |
156 | /**
157 | * Learning rate: exponential decay rate
158 | */
159 | def getKappa: Double = this.kappa
160 |
161 | /**
162 | * Learning rate: exponential decay rate---should be between
163 | * (0.5, 1.0] to guarantee asymptotic convergence.
164 | * Default: 0.51, based on the original Online LDA paper.
165 | */
166 | def setKappa(kappa: Double): this.type = {
167 | require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa")
168 | this.kappa = kappa
169 | this
170 | }
171 |
172 | /**
173 |    * Mini-batch fraction, which sets the fraction of documents sampled and used in each iteration
174 | */
175 | def getMiniBatchFraction: Double = this.miniBatchFraction
176 |
177 | /**
178 |    * Mini-batch fraction in (0, 1], which sets the fraction of documents sampled and used in
179 | * each iteration.
180 | *
181 |    * Note that this should be adjusted in sync with the number of iterations (maxIterations)
182 | * so the entire corpus is used. Specifically, set both so that
183 | * maxIterations * miniBatchFraction >= 1.
184 | *
185 | * Default: 0.05, i.e., 5% of total documents.
186 | */
187 | def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
188 | require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0,
189 | s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction")
190 | this.miniBatchFraction = miniBatchFraction
191 | this
192 | }
193 |
194 | /**
195 | * (private[lda])
196 | * Set the Dirichlet parameter for the posterior over topics.
197 | * This is only used for testing now. In the future, it can help support training stop/resume.
198 | */
199 | private[topicModeling] def setLambda(lambda: BDM[Double]): this.type = {
200 | this.lambda = lambda
201 | this
202 | }
203 |
204 | /**
205 | * (private[lda])
206 | * Used for random initialization of the variational parameters.
207 | * Larger value produces values closer to 1.0.
208 | * This is only used for testing currently.
209 | */
210 | private[topicModeling] def setGammaShape(shape: Double): this.type = {
211 | this.gammaShape = shape
212 | this
213 | }
214 |
215 | override def initialize(docs: RDD[(Long, Vector)], lda: LDA): this.type = {
216 | this.k = lda.getK
217 | this.corpusSize = docs.count()
218 | this.vocabSize = docs.first()._2.size
219 | this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration
220 | this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
221 | this.randomGenerator = new Random(lda.getSeed)
222 |
223 | this.docs = docs
224 | this.lambda = getGammaMatrix(k, vocabSize)
225 |
226 | this.iteration = 0
227 | this
228 | }
229 |
230 |   def initialize(corpusSize: Long, vocabSize: Int, lda: LDA): this.type = {
231 | this.k = lda.getK
232 | this.corpusSize = corpusSize
233 | this.vocabSize = vocabSize
234 | this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration
235 | this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
236 | this.randomGenerator = new Random(lda.getSeed)
237 |
238 | // Initialize the variational distribution q(beta|lambda)
239 | this.lambda = getGammaMatrix(k, vocabSize)
240 | this.iteration = 0
241 | this
242 | }
243 |
244 | override def next(): this.type = {
245 | val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
246 | submitMiniBatch(batch)
247 | }
248 |
249 | /**
250 |    * Submit a subset of the corpus (e.g., 1%, determined by miniBatchFraction) to the Online LDA
251 |    * model; the topic distribution is updated adaptively for the terms appearing in the
252 |    * subset.
253 | */
254 | private[topicModeling] def submitMiniBatch(batch: RDD[(Long, Vector)]): this.type = {
255 | if (batch.isEmpty()) return this
256 | iteration += 1
257 |
258 | val (expElogbeta, stats, _) =
259 | OnlineLDAOptimizer.inference(k, vocabSize, lambda, alpha, gammaShape, batch)
260 | val statsSum: BDM[Double] = stats.reduce(_ += _)
261 | val batchResult = statsSum :* expElogbeta
262 |
263 |     // The batch size (number of documents in this mini-batch) rescales the sufficient statistics in update().
264 | update(batchResult, iteration, batch.count().toInt)
265 | this
266 | }
267 |
268 | override def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
269 | new OnlineLDAModel(Matrices.fromBreeze(lambda).transpose, this.alpha, this.gammaShape)
270 | }
271 |
272 | /**
273 | * Update lambda based on the batch submitted. batchSize can be different for each iteration.
274 | */
275 | private[topicModeling] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = {
276 | // weight of the mini-batch.
277 | val weight = math.pow(getTau0 + iter, -getKappa)
278 |
279 | // Update lambda based on documents.
280 | lambda = lambda * (1 - weight) +
281 | (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight
282 | }
283 |
284 | /**
285 | * Get a random matrix to initialize lambda
286 | */
287 | private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
288 | val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
289 | randomGenerator.nextLong()))
290 | val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)(randBasis)
291 | val temp = gammaRandomGenerator.sample(row * col).toArray
292 | new BDM[Double](col, row, temp).t
293 | }
294 |
295 | def perplexity(docs: RDD[(Long, Vector)]): Double = {
296 | val alphaVector = Vectors.dense(Array.fill(k)(alpha))
297 | val brzAlpha = alphaVector.toBreeze.toDenseVector
298 |
299 | val Elogbeta = OnlineLDAOptimizer.dirichletExpectation(lambda)
300 |
301 | val (_, _, gammaArray) = OnlineLDAOptimizer.inference(k, vocabSize, lambda, alpha, gammaShape, docs)
302 |
303 | var score = docs.join(gammaArray).map { case (id: Long, (termCounts: Vector, gammad: BDV[Double])) =>
304 | var docScore = 0.0D
305 |
306 | val Elogthetad: BDV[Double] = digamma(gammad) - digamma(sum(gammad))
307 |
308 | // E[log p(doc | theta, beta)]
309 | termCounts.foreachActive { case (idx, count) =>
310 | val x = Elogthetad + Elogbeta(::, idx)
311 | val a = max(x)
312 | docScore += count * (a + log(sum(exp(x :- a))))
313 | }
314 | // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector
315 | docScore += sum((brzAlpha - gammad) :* Elogthetad)
316 | docScore += sum(lgamma(gammad) - lgamma(brzAlpha))
317 | docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
318 |
319 | docScore
320 | }.sum()
321 |
322 | // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar
323 | score += sum((eta - lambda) :* Elogbeta)
324 | score += sum(lgamma(lambda) - lgamma(eta))
325 |
326 | val sumEta = eta * vocabSize
327 | score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
328 |
329 | math.exp(-1 * score / vocabSize)
330 | }
331 | }
332 |
333 |
--------------------------------------------------------------------------------
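A hedged end-to-end sketch of driving the online optimizer directly: configure the learning-rate parameters, initialize on a corpus, run a few mini-batch iterations, then inspect training perplexity and per-document topic distributions. The corpus and the LDA configuration object are assumed to be prepared elsewhere (e.g., as in LDAExample); the iteration count is illustrative, and the timing array passed to getLDAModel is unused by this optimizer.

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.topicModeling.{LDA, OnlineLDAModel, OnlineLDAOptimizer}
import org.apache.spark.rdd.RDD

object OnlineLDASketch {
  def train(corpus: RDD[(Long, Vector)], lda: LDA, iterations: Int = 20): OnlineLDAModel = {
    val optimizer = new OnlineLDAOptimizer()
      .setTau0(1024.0)            // downweights early iterations
      .setKappa(0.51)             // learning-rate decay, in (0.5, 1.0]
      .setMiniBatchFraction(0.05) // 5% of the corpus per iteration
    optimizer.initialize(corpus, lda)
    (1 to iterations).foreach(_ => optimizer.next()) // each call samples and submits one mini-batch
    println(s"training perplexity: ${optimizer.perplexity(corpus)}")
    optimizer.getLDAModel(Array.fill(iterations)(0.0)).asInstanceOf[OnlineLDAModel]
  }

  // Per-document topic distributions ("theta_doc") for new documents.
  def infer(model: OnlineLDAModel, docs: RDD[(Long, Vector)]): RDD[(Long, Vector)] =
    model.predict(docs)
}
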