├── .gitignore ├── LICENSE ├── README.md ├── common ├── __init__.py ├── document.py ├── document_test.py ├── lda.proto ├── model.py ├── model_test.py ├── ordered_sparse_topic_histogram.py ├── ordered_sparse_topic_histogram_test.py ├── recordio.py ├── recordio_test.py ├── vocabulary.py └── vocabulary_test.py ├── inference ├── __init__.py ├── multi_chain_gibbs_sampler.py ├── multi_chain_gibbs_sampler_test.py ├── sparselda_gibbs_sampler.py └── sparselda_gibbs_sampler_test.py ├── lda_inferencer.py ├── lda_trainer.py ├── lda_trainer.sh ├── testdata ├── corpus │ └── document1.dat ├── lda_model │ ├── lda.global_topic_hist │ ├── lda.hyper_params │ └── lda.word_topic_hist ├── recordio.dat └── vocabulary.dat └── training ├── __init__.py ├── model_evaluator.py ├── model_evaluator_test.py ├── sparselda_train_gibbs_sampler.py ├── sparselda_train_gibbs_sampler_test.py ├── topic_words_stat.py └── topic_words_stat_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | # C extensions 4 | *.so 5 | 6 | # Packages 7 | *.egg 8 | *.egg-info 9 | dist 10 | build 11 | eggs 12 | parts 13 | bin 14 | var 15 | sdist 16 | develop-eggs 17 | .installed.cfg 18 | lib 19 | lib64 20 | 21 | # Installer logs 22 | pip-log.txt 23 | 24 | # Unit test / coverage reports 25 | .coverage 26 | .tox 27 | nosetests.xml 28 | 29 | # Translations 30 | *.mo 31 | 32 | # Mr Developer 33 | .mr.developer.cfg 34 | .project 35 | .pydevproject 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## python-sparselda
2 | ================
3 | python-sparselda is a Latent Dirichlet Allocation (LDA) topic modeling package based on the SparseLDA Gibbs sampling inference algorithm. It is written for Python 2.6 or newer; Python 3.x is not supported.
4 | 
5 | Frankly, python-sparselda is just a mini project; we hope it can help you better understand the standard LDA and SparseLDA algorithms. RTFSC for more details. Have fun.
6 | 
7 | Please use the GitHub issue tracker for python-sparselda at:
8 | https://github.com/fandywang/python-sparselda/issues
9 | 
10 | ## Usage
11 | ================
12 | ### 1. Install Google Protocol Buffers
13 | python-sparselda serializes and persistently stores the LDA model and checkpoints using protobuf, so you should install it first.
14 | 
15 |     wget https://protobuf.googlecode.com/files/protobuf-2.5.0.tar.bz2
16 |     tar -jxvf protobuf-2.5.0.tar.bz2
17 |     cd protobuf-2.5.0
18 |     ./configure
19 |     make
20 |     sudo make install
21 |     cd python
22 |     python ./setup.py build
23 |     sudo python ./setup.py install
24 | 
25 |     cd python-sparselda/common
26 |     protoc -I=. --python_out=. lda.proto
27 | 
28 | ### 2. Training
29 | #### 2.1 Command line
30 | Usage: python lda_trainer.py [options].
31 | 
32 |     Options:
33 |       -h, --help            show this help message and exit
34 |       --corpus_dir=CORPUS_DIR
35 |                             the corpus directory.
36 |       --vocabulary_file=VOCABULARY_FILE
37 |                             the vocabulary file.
38 |       --num_topics=NUM_TOPICS
39 |                             the number of topics.
40 |       --topic_prior=TOPIC_PRIOR
41 |                             the topic prior alpha.
42 |       --word_prior=WORD_PRIOR
43 |                             the word prior beta.
44 |       --total_iterations=TOTAL_ITERATIONS
45 |                             the total number of iterations.
46 |       --model_dir=MODEL_DIR
47 |                             the model directory.
48 |       --save_model_interval=SAVE_MODEL_INTERVAL
49 |                             the interval to save the lda model.
50 |       --topic_word_accumulated_prob_threshold=TOPIC_WORD_ACCUMULATED_PROB_THRESHOLD
51 |                             the accumulated_prob_threshold of topic top words.
52 |       --save_checkpoint_interval=SAVE_CHECKPOINT_INTERVAL
53 |                             the interval to save the checkpoint.
54 |       --checkpoint_dir=CHECKPOINT_DIR
55 |                             the checkpoint directory.
56 |       --compute_loglikelihood_interval=COMPUTE_LOGLIKELIHOOD_INTERVAL
57 |                             the interval to compute the loglikelihood.
58 | 
59 | #### 2.2 Input corpus format
60 | The corpus for training/estimating the model has the line format as follows:
61 | 
62 |     [document1]
63 |     [document2]
64 |     ...
65 |     [documentM]
66 | 
67 | in which each line is one document. [documenti] is the ith document of the dataset, consisting of a list of Ni words/terms.
68 | 
69 |     [documenti] = [wordi1]\t[wordi2]\t...\t[wordiNi]
70 | 
71 | in which all [wordij] <i=1...M, j=1...Ni> are text strings, separated by the tab character.
72 | 
73 | **Note that** the terms document and word here are abstract and should not only be understood as normal text documents.
74 | This is because LDA can be used to discover the underlying topic structures of any kind of discrete data. Therefore,
75 | python-sparselda is not limited to text and natural language processing but can also be applied to other kinds of data,
76 | such as images.
77 | 
78 | Also, keep in mind that for text/Web data collections, you should first preprocess the data (e.g., word segmentation,
79 | stopword and rare-word removal, stemming, etc.) before estimating with python-sparselda. A minimal preprocessing sketch is shown below.
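For illustration only, here is a minimal sketch of producing a corpus file in the format above. It is not part of the package; the input file `raw_docs.txt`, the whitespace tokenization, and the existing `corpus/` output directory are placeholder assumptions, so substitute your own segmenter and stopword filtering:

    #!/usr/bin/env python
    # write_corpus.py: emit one tab-separated document per line (a sketch).
    corpus_fp = open('corpus/documents.dat', 'w')
    for line in open('raw_docs.txt'):
        tokens = line.strip().split()  # assumption: replace with a real tokenizer
        if tokens:
            corpus_fp.write('\t'.join(tokens) + '\n')
    corpus_fp.close()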
80 | 
81 | #### 2.3 Input vocabulary format
82 | The vocabulary for training/estimating the model has the line format as follows:
83 | 
84 |     [word1]
85 |     [word2]
86 |     ...
87 |     [wordV]
88 | 
89 | in which each line is a unique word. Only words that appear in the vocabulary will be considered for parameter estimation.
90 | 
91 | #### 2.4 Outputs
92 | ##### 1) LDA Model
93 | It includes three files.
94 | * lda.word_topic_hist: This file contains the word-topic histograms, i.e., N(word|topic).
95 | * lda.global_topic_hist: This file contains the global topic histogram, i.e., N(topic).
96 | * lda.hyper_params: This file contains the hyperparameters, i.e., alpha and beta.
97 | 
98 | ##### 2) Checkpoint
99 | Every `--save_checkpoint_interval` iterations, the lda_trainer dumps the current checkpoint for fault tolerance.
100 | The checkpoint mainly includes two types of files.
101 | * LDA Model: See above.
102 | * Corpus: This directory contains the serialized documents.
103 | 
104 | ##### 3) Topic words
105 | * lda.topic_words: This file contains the most likely words of each topic. The number of top words per topic depends on `--topic_word_accumulated_prob_threshold`.
106 | 
107 | ### 3. Inference
108 | Please refer to the example: lda_inferencer.py.
109 | 
110 | **Note that** we strongly recommend using the `MultiChainGibbsSampler` class to trade off efficiency against effectiveness; a usage sketch follows.
111 | 
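As an illustration only — the `MultiChainGibbsSampler` constructor arguments and the inference method shown here are assumptions patterned on the rest of this package, and lda_inferencer.py remains the authoritative example:

    #!/usr/bin/env python
    # inference_sketch.py: hypothetical multi-chain inference usage.
    from common.model import Model
    from common.vocabulary import Vocabulary
    from inference.multi_chain_gibbs_sampler import MultiChainGibbsSampler

    model = Model(20)
    model.load('testdata/lda_model')            # N(w|z), N(z), hyperparams
    vocabulary = Vocabulary()
    vocabulary.load('testdata/vocabulary.dat')

    # Assumed parameters: more chains and iterations stabilize the
    # inferred distribution at a higher CPU cost.
    sampler = MultiChainGibbsSampler(model, vocabulary,
                                     num_markov_chains=5,
                                     total_iterations=50,
                                     burn_in_iterations=20)
    topic_dist = sampler.infer_topics(['macbook', 'ipad'])  # hypothetical method
    print topic_dist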
112 | ### 4. Evaluation
113 | Instead of manual evaluation, we want to evaluate topic quality automatically, and filter out the few meaningless topics to enhance the inference effect.
114 | 
115 | ## TODO
116 | ================
117 | 1. Hyperparameters optimization.
118 | 2. Memory optimization.
119 | 3. More experiments.
120 | 4. Data and model parallelization.
121 | 
122 | ## References
123 | ================
124 | 1. D. Blei, A. Ng, and M. Jordan. [Latent Dirichlet allocation](http://www.cs.princeton.edu/~blei/papers/BleiNgJordan2003.pdf). Journal of Machine Learning Research, 2003.
125 | 2. Gregor Heinrich. [Parameter estimation for text analysis](http://www.arbylon.net/publications/text-est.pdf). Technical Note, 2004.
126 | 3. Griffiths, T. L., & Steyvers, M. [Finding scientific topics](http://www.pnas.org/content/101/suppl.1/5228.full.pdf). Proceedings of the National Academy of Sciences (PNAS), 2004.
127 | 4. I. Porteous, D. Newman, A. Ihler, A. Asuncion, P. Smyth, and M. Welling. [Fast collapsed Gibbs sampling for latent Dirichlet allocation](http://www.ics.uci.edu/~asuncion/pubs/KDD_08.pdf). In SIGKDD, 2008.
128 | 5. Limin Yao, David Mimno, Andrew McCallum. [Efficient methods for topic model inference on streaming document collections](https://www.cs.umass.edu/~mimno/papers/fast-topic-model.pdf). In SIGKDD, 2009.
129 | 6. Newman et al. [Distributed Inference for Latent Dirichlet Allocation](http://www.csee.ogi.edu/~zak/cs506-pslc/dist_lda.pdf). In NIPS, 2007.
130 | 7. X. Wei, W. Bruce Croft. [LDA-based document models for ad hoc retrieval](http://www.bradblock.com/LDA_Based_Document_Models_for_Ad_hoc_Retrieval.pdf). In Proc. SIGIR, 2006.
131 | 8. Rickjin. [LDA 数学八卦](http://vdisk.weibo.com/s/q0sGh/1360334108?utm_source=weibolife). Technical Note, 2013.
132 | 9. Yi Wang, Hongjie Bai, Matt Stanton, Wen-Yen Chen, and Edward Y. Chang. [PLDA: Parallel Latent Dirichlet Allocation for Large-scale Applications](http://plda.googlecode.com/files/aaim.pdf). In AAIM, 2009.
133 | 
134 | ## Links
135 | ===============
136 | Here are some pointers to other implementations of LDA.
137 | 
138 | 1. [LDA-C](http://www.cs.princeton.edu/~blei/lda-c/index.html): A C implementation of variational EM for latent Dirichlet allocation (LDA), a topic model for text or other discrete data.
139 | 2. [GibbsLDA++](http://gibbslda.sourceforge.net/): A C/C++ implementation of Latent Dirichlet Allocation (LDA) using the Gibbs sampling technique for parameter estimation and inference.
140 | 3. [plda/plda+](https://code.google.com/p/plda/): A parallel C++ implementation of Latent Dirichlet Allocation (LDA).
141 | 4. [Mr. LDA](https://github.com/lintool/Mr.LDA): A Latent Dirichlet Allocation topic modeling package based on the Variational Bayesian learning approach, using MapReduce and Hadoop, developed by a cloud computing research team at the University of Maryland, College Park.
142 | 5. [Yahoo_LDA](https://github.com/sudar/Yahoo_LDA): The Y!LDA topic modelling framework. It provides a fast C++ implementation of the inferencing algorithm which can use both multi-core parallelism and multi-machine parallelism using a Hadoop cluster. It can infer about a thousand topics on a million-document corpus while running for a thousand iterations on an eight-core machine in one day.
143 | 6. [Mahout](https://cwiki.apache.org/confluence/display/MAHOUT/Latent+Dirichlet+Allocation): Mahout's goal is to build scalable machine learning libraries.
144 | 7. [MALLET](http://mallet.cs.umass.edu/): A Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
145 | 8. [ompi-lda](https://code.google.com/p/ompi-lda/): An OpenMP- and MPI-based parallel implementation of LDA.
146 | 9. [lda-go](https://code.google.com/p/lda-go/): Gibbs sampling training and inference of the Latent Dirichlet Allocation model, written in Google's Go programming language.
147 | 10. [Matlab Topic Modeling Toolbox](http://psiexp.ss.uci.edu/research/programs_data/toolbox.htm)
148 | 11. [lda-j](http://www.arbylon.net/projects/): A Java version of LDA-C and a short Java version of Gibbs sampling for LDA.
149 | 
150 | ## Copyright and license
151 | ==============================
152 | Copyright(c) 2013 python-sparselda project.
153 | 
154 | Licensed under the Apache License, Version 2.0 (the "License");
155 | you may not use this work except in compliance with the License.
156 | You may obtain a copy of the License in the LICENSE file, or at:
157 | 
158 | http://www.apache.org/licenses/LICENSE-2.0
159 | 
160 | Unless required by applicable law or agreed to in writing, software
161 | distributed under the License is distributed on an "AS IS" BASIS,
162 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
163 | See the License for the specific language governing permissions and
164 | limitations under the License.
165 | 
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ankazhao/python-sparselda/f84d05ee99899ceadb8371dc52cc76f5aa0da934/common/__init__.py
--------------------------------------------------------------------------------
/common/document.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import random
20 | 
21 | from lda_pb2 import DocumentPB
22 | from model import Model
23 | from ordered_sparse_topic_histogram import OrderedSparseTopicHistogram
24 | from vocabulary import Vocabulary
25 | 
26 | class Word(object):
27 | 
28 |     def __init__(self, id, topic):
29 |         self.id = id
30 |         self.topic = topic
31 | 
32 |     def __str__(self):
33 |         return '<id: %d, topic: %d>' % (self.id, self.topic)
34 | 
35 | 
36 | class Document(object):
37 | 
38 |     def __init__(self, num_topics):
39 |         self.num_topics = num_topics
40 |         self.words = []  # word occurrences of the document,
41 |                          # item fmt: Word
42 |         self.doc_topic_hist = OrderedSparseTopicHistogram(num_topics)  # N(z|d)
43 | 
44 |     def parse_from_tokens(self, doc_tokens, rand, vocabulary, model = None):
45 |         """Parse the text document from tokens. Only tokens in vocabulary
46 |         and model will be considered.
47 | """ 48 | self.words = [] 49 | self.doc_topic_hist = OrderedSparseTopicHistogram(self.num_topics) 50 | 51 | for token in doc_tokens: 52 | word_index = vocabulary.word_index(token) 53 | if (word_index != -1 and 54 | (model == None or model.has_word(word_index))): 55 | # initialize a random topic for cur word 56 | topic = rand.randint(0, self.num_topics - 1) 57 | self.words.append(Word(word_index, topic)) 58 | self.doc_topic_hist.increase_topic(topic, 1) 59 | 60 | def serialize_to_string(self): 61 | """Serialize document to DocumentPB string. 62 | """ 63 | document_pb = DocumentPB() 64 | for word in self.words: 65 | word_pb = document_pb.words.add() 66 | word_pb.id = word.id 67 | word_pb.topic = word.topic 68 | return document_pb.SerializeToString() 69 | 70 | def parse_from_string(self, document_str): 71 | """Parse document from DocumentPB serialized string. 72 | """ 73 | self.words = [] 74 | self.doc_topic_hist = OrderedSparseTopicHistogram(self.num_topics) 75 | self.document_pb = DocumentPB() 76 | self.document_pb.ParseFromString(document_str) 77 | for word_pb in self.document_pb.words: 78 | self.words.append(Word(word_pb.id, word_pb.topic)) 79 | self.increase_topic(word_pb.topic, 1) 80 | 81 | def num_words(self): 82 | return len(self.words) 83 | 84 | def get_topic_count(self, topic): 85 | """Returns N(z|d). 86 | """ 87 | return self.doc_topic_hist.count(topic) 88 | 89 | def increase_topic(self, topic, count = 1): 90 | """Adds count to current topic, and returns the updated count. 91 | """ 92 | return self.doc_topic_hist.increase_topic(topic, count) 93 | 94 | def decrease_topic(self, topic, count = 1): 95 | """Subtracts count from current topic, and returns the updated count. 96 | """ 97 | return self.doc_topic_hist.decrease_topic(topic, count) 98 | 99 | def __str__(self): 100 | """Outputs a human-readable representation of the model. 101 | """ 102 | document_str = [] 103 | for word in self.words: 104 | document_str.append(str(word)) 105 | document_str.append(str(self.doc_topic_hist)) 106 | return '\n'.join(document_str) 107 | 108 | -------------------------------------------------------------------------------- /common/document_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 
19 | import logging
20 | import random
21 | import unittest
22 | 
23 | from document import Document
24 | from model import Model
25 | from vocabulary import Vocabulary
26 | 
27 | class DocumentTest(unittest.TestCase):
28 | 
29 |     def setUp(self):
30 |         self.document = Document(20)
31 |         self.vocabulary = Vocabulary()
32 |         self.vocabulary.load("../testdata/vocabulary.dat")
33 | 
34 |         self.model = Model(20)
35 |         self.model.load('../testdata/lda_model')
36 | 
37 |         self.doc_tokens = ['macbook', 'ipad',  # exist in vocabulary and model
38 |                            'mac os x', 'chrome',  # only exist in vocabulary
39 |                            'nokia', 'null']  # in neither
40 | 
41 |     def test_parse_from_tokens(self):
42 |         # initialize document during lda training.
43 |         self.document.parse_from_tokens(
44 |             self.doc_tokens, random, self.vocabulary)
45 | 
46 |         self.assertEqual(4, self.document.num_words())
47 |         topic_hist = self.document.doc_topic_hist
48 |         for i in xrange(len(topic_hist.non_zeros) - 1):
49 |             self.assertGreaterEqual(topic_hist.non_zeros[i].count,
50 |                                     topic_hist.non_zeros[i + 1].count)
51 |         logging.info(str(self.document))
52 | 
53 |         # initialize document during lda inference.
54 |         self.document.parse_from_tokens(
55 |             self.doc_tokens, random, self.vocabulary, self.model)
56 |         self.assertEqual(2, self.document.num_words())
57 |         for i in xrange(len(self.document.doc_topic_hist.non_zeros) - 1):
58 |             self.assertGreaterEqual(self.document.doc_topic_hist.non_zeros[i].count,
59 |                                     self.document.doc_topic_hist.non_zeros[i + 1].count)
60 |         # print str(self.document)
61 | 
62 |     def test_serialize_and_parse(self):
63 |         self.document.parse_from_tokens(
64 |             self.doc_tokens, random, self.vocabulary)
65 | 
66 |         test_doc = Document(20)
67 |         test_doc.parse_from_string(self.document.serialize_to_string())
68 | 
69 |         self.assertEqual(self.document.num_words(), test_doc.num_words())
70 |         self.assertEqual(str(self.document), str(test_doc))
71 | 
72 |     def test_increase_decrease_topic(self):
73 |         self.document.parse_from_tokens(
74 |             self.doc_tokens, random, self.vocabulary, self.model)
75 |         self.document.increase_topic(0, 5)
76 |         self.document.increase_topic(4, 5)
77 |         self.document.increase_topic(9, 5)
78 |         topic_hist = self.document.doc_topic_hist
79 |         for i in xrange(len(topic_hist.non_zeros) - 1):
80 |             self.assertGreaterEqual(topic_hist.non_zeros[i].count,
81 |                                     topic_hist.non_zeros[i + 1].count)
82 | 
83 |         self.document.decrease_topic(4, 4)
84 |         self.document.decrease_topic(9, 3)
85 |         for i in xrange(len(topic_hist.non_zeros) - 1):
86 |             self.assertGreaterEqual(topic_hist.non_zeros[i].count,
87 |                                     topic_hist.non_zeros[i + 1].count)
88 | 
89 | if __name__ == '__main__':
90 |     unittest.main()
91 | 
92 | 
--------------------------------------------------------------------------------
/common/lda.proto:
--------------------------------------------------------------------------------
1 | package lda;
2 | 
3 | // Copyright(c) 2013 python-sparselda project.
4 | // Author: Lifeng Wang (ofandywang@gmail.com)
5 | //
6 | // Licensed under the Apache License, Version 2.0 (the "License");
7 | // you may not use this file except in compliance with the License.
8 | // You may obtain a copy of the License at
9 | //
10 | //     http://www.apache.org/licenses/LICENSE-2.0
11 | //
12 | // Unless required by applicable law or agreed to in writing, software
13 | // distributed under the License is distributed on an "AS IS" BASIS,
14 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | // See the License for the specific language governing permissions and
16 | // limitations under the License.
17 | //
18 | // Using Google Protocol Buffers (protobuf) to serialize the corpus and the LDA model.
19 | //
20 | // See https://developers.google.com/protocol-buffers/docs/pythontutorial for
21 | // more details.
22 | 
23 | message WordPB {
24 |     optional int32 id = 1;  // index of current word
25 |     optional int32 topic = 2;  // topic assignment to current word
26 | }
27 | 
28 | message DocumentPB {
29 |     repeated WordPB words = 1;
30 | }
31 | 
32 | message NonZeroPB {
33 |     optional int32 topic = 1;
34 |     optional int64 count = 2;
35 | }
36 | 
37 | // the sparse topic histogram
38 | message SparseTopicHistogramPB {
39 |     repeated NonZeroPB non_zeros = 1;
40 | }
41 | 
42 | // N(w|z)
43 | message WordTopicHistogramPB {
44 |     optional int32 word = 1;
45 |     optional SparseTopicHistogramPB sparse_topic_hist = 2;
46 | }
47 | 
48 | // N(z), the dense topic histogram.
49 | message GlobalTopicHistogramPB {
50 |     repeated int64 topic_counts = 1;
51 | }
52 | 
53 | // Dirichlet prior
54 | message HyperParamsPB {
55 |     optional double topic_prior = 1;
56 |     optional double word_prior = 2;
57 | }
58 | 
--------------------------------------------------------------------------------
/common/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import logging
20 | import os
21 | 
22 | from lda_pb2 import GlobalTopicHistogramPB
23 | from lda_pb2 import HyperParamsPB
24 | from lda_pb2 import WordTopicHistogramPB
25 | from lda_pb2 import SparseTopicHistogramPB
26 | from ordered_sparse_topic_histogram import OrderedSparseTopicHistogram
27 | from recordio import RecordReader
28 | from recordio import RecordWriter
29 | from vocabulary import Vocabulary
30 | 
31 | class HyperParams(object):
32 | 
33 |     # TODO(fandywang): optimize the hyper_params.
34 |     # Because we find that an asymmetric Dirichlet prior over the document-
35 |     # topic distributions has substantial advantages over a symmetric prior,
36 |     # while an asymmetric prior over topic-word distributions provides no
37 |     # real benefit.
38 |     #
39 |     # See 'Hanna Wallach, David Mimno, and Andrew McCallum. 2009.
40 |     # Rethinking LDA: Why priors matter. In Proceedings of NIPS-09,
41 |     # Vancouver, BC.' for more details.
42 |     def __init__(self, topic_prior = 0.01, word_prior = 0.1):
43 |         self.topic_prior = topic_prior
44 |         self.word_prior = word_prior
45 | 
46 |     def serialize_to_string(self):
47 |         hyper_params_pb = HyperParamsPB()
48 |         hyper_params_pb.topic_prior = self.topic_prior
49 |         hyper_params_pb.word_prior = self.word_prior
50 |         return hyper_params_pb.SerializeToString()
51 | 
52 |     def parse_from_string(self, hyper_params_str):
53 |         hyper_params_pb = HyperParamsPB()
54 |         hyper_params_pb.ParseFromString(hyper_params_str)
55 |         self.topic_prior = hyper_params_pb.topic_prior
56 |         self.word_prior = hyper_params_pb.word_prior
57 | 
58 |     def __str__(self):
59 |         return '<topic_prior: %f, word_prior: %f>' % (
60 |                 self.topic_prior, self.word_prior)
61 | 
62 | 
63 | class Model(object):
64 |     """Model implements the sparselda model.
65 |     It includes the following parts:
66 |         0. num_topics, represents |K|.
67 |         1. global_topic_hist, represents N(z).
68 |         2. word_topic_hist, represents N(w|z).
69 |         3. hyper_params
70 |            3.1 topic_prior, represents the dirichlet prior of topic \alpha.
71 |            3.2 word_prior, represents the dirichlet prior of word \beta.
72 |     """
73 |     GLOABLE_TOPIC_HIST_FILENAME = "lda.global_topic_hist"
74 |     WORD_TOPIC_HIST_FILENAME = "lda.word_topic_hist"
75 |     HYPER_PARAMS_FILENAME = "lda.hyper_params"
76 | 
77 |     def __init__(self, num_topics, topic_prior = 0.1, word_prior = 0.01):
78 |         self.num_topics = num_topics
79 | 
80 |         self.global_topic_hist = [0] * self.num_topics  # item fmt: N(z)
81 |         self.word_topic_hist = {}  # item fmt: w -> N(w|z)
82 | 
83 |         self.hyper_params = HyperParams()
84 |         self.hyper_params.topic_prior = topic_prior  # alpha, default symmetrical
85 |         self.hyper_params.word_prior = word_prior  # beta, default symmetrical
86 | 
87 |     def save(self, model_dir):
88 |         logging.info('Save lda model to %s.' % model_dir)
89 |         if not os.path.exists(model_dir):
90 |             os.mkdir(model_dir)
91 | 
92 |         self._save_word_topic_hist(model_dir + "/" +
93 |                 self.__class__.WORD_TOPIC_HIST_FILENAME)
94 |         self._save_global_topic_hist(model_dir + "/" +
95 |                 self.__class__.GLOABLE_TOPIC_HIST_FILENAME)
96 |         self._save_hyper_params(model_dir + "/" +
97 |                 self.__class__.HYPER_PARAMS_FILENAME)
98 | 
99 |     def load(self, model_dir):
100 |         logging.info('Load lda model from %s.'
% model_dir) 101 | assert self._load_global_topic_hist(model_dir + "/" + 102 | self.__class__.GLOABLE_TOPIC_HIST_FILENAME) 103 | self.num_topics = len(self.global_topic_hist) 104 | assert self._load_word_topic_hist(model_dir + "/" + 105 | self.__class__.WORD_TOPIC_HIST_FILENAME) 106 | assert self._load_hyper_params(model_dir + "/" + 107 | self.__class__.HYPER_PARAMS_FILENAME) 108 | 109 | def _save_global_topic_hist(self, filename): 110 | fp = open(filename, 'wb') 111 | record_writer = RecordWriter(fp) 112 | global_topic_hist_pb = GlobalTopicHistogramPB() 113 | for topic_count in self.global_topic_hist: 114 | global_topic_hist_pb.topic_counts.append(topic_count) 115 | record_writer.write(global_topic_hist_pb.SerializeToString()) 116 | fp.close() 117 | 118 | def _save_word_topic_hist(self, filename): 119 | fp = open(filename, 'wb') 120 | record_writer = RecordWriter(fp) 121 | for word, ordered_sparse_topic_hist in self.word_topic_hist.iteritems(): 122 | word_topic_hist_pb = WordTopicHistogramPB() 123 | word_topic_hist_pb.word = word 124 | word_topic_hist_pb.sparse_topic_hist.ParseFromString( 125 | ordered_sparse_topic_hist.serialize_to_string()) 126 | record_writer.write(word_topic_hist_pb.SerializeToString()) 127 | fp.close() 128 | 129 | def _save_hyper_params(self, filename): 130 | fp = open(filename, 'wb') 131 | record_writer = RecordWriter(fp) 132 | record_writer.write(self.hyper_params.serialize_to_string()) 133 | fp.close() 134 | 135 | def _load_global_topic_hist(self, filename): 136 | logging.info('Loading global_topic_hist vector N(z).') 137 | self.global_topic_hist = [] 138 | 139 | fp = open(filename, "rb") 140 | record_reader = RecordReader(fp) 141 | blob = record_reader.read() 142 | fp.close() 143 | if blob == None: 144 | logging.error('GlobalTopicHist is nil, file %s' % filename) 145 | return False 146 | 147 | global_topic_hist_pb = GlobalTopicHistogramPB() 148 | global_topic_hist_pb.ParseFromString(blob) 149 | for topic_count in global_topic_hist_pb.topic_counts: 150 | self.global_topic_hist.append(topic_count) 151 | return True 152 | 153 | def _load_word_topic_hist(self, filename): 154 | logging.info('Loading word_topic_hist matrix N(w|z).') 155 | self.word_topic_hist.clear() 156 | 157 | fp = open(filename, "rb") 158 | record_reader = RecordReader(fp) 159 | while True: 160 | blob = record_reader.read() 161 | if blob == None: 162 | break 163 | 164 | word_topic_hist_pb = WordTopicHistogramPB() 165 | word_topic_hist_pb.ParseFromString(blob) 166 | 167 | ordered_sparse_topic_hist = \ 168 | OrderedSparseTopicHistogram(self.num_topics) 169 | ordered_sparse_topic_hist.parse_from_string( 170 | word_topic_hist_pb.sparse_topic_hist.SerializeToString()) 171 | self.word_topic_hist[word_topic_hist_pb.word] = \ 172 | ordered_sparse_topic_hist 173 | fp.close() 174 | return (len(self.word_topic_hist) > 0) 175 | 176 | def _load_hyper_params(self, filename): 177 | logging.info('Loading hyper_params topic_prior and word_prior.') 178 | fp = open(filename, "rb") 179 | record_reader = RecordReader(fp) 180 | blob = record_reader.read() 181 | fp.close() 182 | if blob == None: 183 | logging.error('HyperParams is nil, file %s' % filename) 184 | return False 185 | 186 | self.hyper_params.parse_from_string(blob) 187 | return True 188 | 189 | def has_word(self, word): 190 | return word in self.word_topic_hist 191 | 192 | def get_word_topic_dist(self, vocab_size): 193 | """Returns topic-word distributions matrix p(w|z), indexed by word. 
194 | """ 195 | word_topic_dist = {} 196 | word_prior_sum = self.hyper_params.word_prior * vocab_size 197 | 198 | # TODO(fandywang): only cache sub-matrix p(w|z) of frequency words. 199 | for word_id, ordered_sparse_topic_hist in self.word_topic_hist.iteritems(): 200 | dense_topic_dist = [] 201 | for topic in xrange(self.num_topics): 202 | dense_topic_dist.append(self.hyper_params.word_prior / 203 | (word_prior_sum + self.global_topic_hist[topic])) 204 | for non_zero in ordered_sparse_topic_hist.non_zeros: 205 | dense_topic_dist[non_zero.topic] = \ 206 | (self.hyper_params.word_prior + non_zero.count) / \ 207 | (word_prior_sum + self.global_topic_hist[non_zero.topic]) 208 | word_topic_dist[word_id] = dense_topic_dist 209 | 210 | return word_topic_dist 211 | 212 | def __str__(self): 213 | """Outputs a human-readable representation of the model. 214 | """ 215 | model_str = [] 216 | model_str.append('NumTopics: %d' % self.num_topics) 217 | model_str.append('GlobalTopicHist: %s' % str(self.global_topic_hist)) 218 | model_str.append('WordTopicHist: ') 219 | for word, ordered_sparse_topic_hist in self.word_topic_hist.iteritems(): 220 | model_str.append('word: %d' % word) 221 | model_str.append('topic_hist: %s' % str(ordered_sparse_topic_hist)) 222 | model_str.append('HyperParams: ') 223 | model_str.append(str(self.hyper_params)) 224 | return '\n'.join(model_str) 225 | 226 | -------------------------------------------------------------------------------- /common/model_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | import os 20 | import unittest 21 | from model import Model 22 | from ordered_sparse_topic_histogram import OrderedSparseTopicHistogram 23 | 24 | class ModelTest(unittest.TestCase): 25 | 26 | def setUp(self): 27 | self.model = Model(20) 28 | 29 | # initialize self.model.global_topic_hist and 30 | # self.model.word_topic_hist 31 | for i in xrange(10): 32 | ordered_sparse_topic_hist = OrderedSparseTopicHistogram(20) 33 | for j in xrange(10 + i): 34 | ordered_sparse_topic_hist.increase_topic(j, j + 1) 35 | self.model.global_topic_hist[j] += j + 1 36 | self.model.word_topic_hist[i] = ordered_sparse_topic_hist 37 | 38 | def test_save_and_load(self): 39 | model_dir = '../testdata/lda_model' 40 | self.model.save(model_dir) 41 | self.assertTrue(os.path.exists(model_dir)) 42 | 43 | new_model = Model(20) 44 | new_model.load(model_dir) 45 | 46 | self.assertEqual(new_model.num_topics, self.model.num_topics) 47 | self.assertEqual(len(new_model.word_topic_hist), 48 | len(self.model.word_topic_hist)) 49 | 50 | for word, new_sparse_topic_hist in new_model.word_topic_hist.iteritems(): 51 | self.assertTrue(word in self.model.word_topic_hist) 52 | sparse_topic_hist = self.model.word_topic_hist[word] 53 | self.assertEqual(new_sparse_topic_hist.size(), 54 | sparse_topic_hist.size()) 55 | 56 | for j in xrange(new_sparse_topic_hist.size()): 57 | self.assertEqual(new_sparse_topic_hist.non_zeros[j].topic, 58 | sparse_topic_hist.non_zeros[j].topic) 59 | self.assertEqual(new_sparse_topic_hist.non_zeros[j].count, 60 | sparse_topic_hist.non_zeros[j].count) 61 | 62 | self.assertEqual(new_model.hyper_params.topic_prior, 63 | self.model.hyper_params.topic_prior) 64 | self.assertEqual(new_model.hyper_params.word_prior, 65 | self.model.hyper_params.word_prior) 66 | 67 | # print self.model 68 | 69 | def test_has_word(self): 70 | self.assertTrue(self.model.has_word(0)) 71 | self.assertTrue(self.model.has_word(2)) 72 | self.assertTrue(self.model.has_word(4)) 73 | self.assertTrue(self.model.has_word(6)) 74 | self.assertTrue(self.model.has_word(8)) 75 | self.assertFalse(self.model.has_word(10)) 76 | self.assertFalse(self.model.has_word(12)) 77 | self.assertFalse(self.model.has_word(14)) 78 | self.assertFalse(self.model.has_word(16)) 79 | self.assertFalse(self.model.has_word(18)) 80 | 81 | def test_get_word_topic_dist(self): 82 | word_topic_dist = self.model.get_word_topic_dist(10) 83 | self.assertTrue(len(word_topic_dist)) 84 | # print word_topic_dist 85 | 86 | if __name__ == '__main__': 87 | unittest.main() 88 | 89 | -------------------------------------------------------------------------------- /common/ordered_sparse_topic_histogram.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 
19 | from lda_pb2 import SparseTopicHistogramPB
20 | 
21 | class NonZero(object):
22 | 
23 |     def __init__(self, topic, count = 0):
24 |         self.topic = topic
25 |         self.count = count
26 | 
27 |     def __str__(self):
28 |         return '<topic: %d, count: %d>' % (self.topic, self.count)
29 | 
30 | 
31 | class OrderedSparseTopicHistogram(object):
32 |     """OrderedSparseTopicHistogram implements the class of sparse topic
33 |     histogram, which maintains the topics in descending order by their counts.
34 |     """
35 |     def __init__(self, num_topics):
36 |         self.non_zeros = []  # item fmt: NonZero
37 |         self.num_topics = num_topics
38 | 
39 |     def size(self):
40 |         """Returns the size of the sparse sequence 'self.non_zeros'.
41 |         """
42 |         return len(self.non_zeros)
43 | 
44 |     def serialize_to_string(self):
45 |         """Serialize the OrderedSparseTopicHistogram to SparseTopicHistogramPB
46 |         string.
47 |         """
48 |         sparse_topic_hist = SparseTopicHistogramPB()
49 |         for non_zero in self.non_zeros:
50 |             non_zero_pb = sparse_topic_hist.non_zeros.add()
51 |             non_zero_pb.topic = non_zero.topic
52 |             non_zero_pb.count = non_zero.count
53 |         return sparse_topic_hist.SerializeToString()
54 | 
55 |     def parse_from_string(self, sparse_topic_hist_str):
56 |         """Parse OrderedSparseTopicHistogram from SparseTopicHistogramPB
57 |         serialized string.
58 |         """
59 |         self.non_zeros = []
60 |         sparse_topic_hist = SparseTopicHistogramPB()
61 |         sparse_topic_hist.ParseFromString(sparse_topic_hist_str)
62 |         for non_zero_pb in sparse_topic_hist.non_zeros:
63 |             self.non_zeros.append(NonZero(non_zero_pb.topic, non_zero_pb.count))
64 | 
65 |     def count(self, topic):
66 |         """Returns the count of the given topic.
67 |         """
68 |         for non_zero in self.non_zeros:
69 |             if non_zero.topic == topic:
70 |                 return non_zero.count
71 |         return 0
72 | 
73 |     def increase_topic(self, topic, count = 1):
74 |         """Adds count to the topic, and returns the updated count.
75 |         """
76 |         assert (topic >= 0 and topic < self.num_topics and count > 0)
77 | 
78 |         index = -1
79 |         for i, non_zero in enumerate(self.non_zeros):
80 |             if non_zero.topic == topic:
81 |                 non_zero.count += count
82 |                 index = i
83 |                 break
84 | 
85 |         if index == -1:
86 |             self.non_zeros.append(NonZero(topic, count))
87 |             index = len(self.non_zeros) - 1
88 | 
89 |         # ensure that topics are sorted by their counts.
90 |         non_zero = self.non_zeros[index]
91 |         while index > 0 and non_zero.count > self.non_zeros[index - 1].count:
92 |             self.non_zeros[index] = self.non_zeros[index - 1]
93 |             index -= 1
94 |         self.non_zeros[index] = non_zero
95 |         return non_zero.count
96 | 
97 |     def decrease_topic(self, topic, count = 1):
98 |         """Subtracts count from the topic, and returns the updated count.
99 |         """
100 |         assert (topic >= 0 and topic < self.num_topics and count > 0)
101 | 
102 |         index = -1
103 |         for i, non_zero in enumerate(self.non_zeros):
104 |             if non_zero.topic == topic:
105 |                 non_zero.count -= count
106 |                 assert non_zero.count >= 0
107 |                 index = i
108 |                 break
109 | 
110 |         assert index != -1
111 | 
112 |         # ensure that topics are sorted by their counts.
113 |         non_zero = self.non_zeros[index]
114 |         while (index < len(self.non_zeros) - 1 and
115 |                non_zero.count < self.non_zeros[index + 1].count):
116 |             self.non_zeros[index] = self.non_zeros[index + 1]
117 |             index += 1
118 |         if non_zero.count == 0:
119 |             del self.non_zeros[index:]
120 |         else:
121 |             self.non_zeros[index] = non_zero
122 |         return non_zero.count
123 | 
124 |     def __str__(self):
125 |         """Outputs a human-readable representation of the histogram.
126 | """ 127 | topic_hist_str = [] 128 | for non_zero in self.non_zeros: 129 | topic_hist_str.append(str(non_zero)) 130 | return '\n'.join(topic_hist_str) 131 | -------------------------------------------------------------------------------- /common/ordered_sparse_topic_histogram_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | import unittest 20 | 21 | from ordered_sparse_topic_histogram import OrderedSparseTopicHistogram 22 | 23 | class OrderedSparseTopicHistogramTest(unittest.TestCase): 24 | 25 | def setUp(self): 26 | self.num_topics = 20 27 | self.ordered_sparse_topic_hist = \ 28 | OrderedSparseTopicHistogram(self.num_topics) 29 | for i in xrange(10): 30 | self.ordered_sparse_topic_hist.increase_topic(i, i + 1) 31 | 32 | def test_ordered_sparse_topic_hist(self): 33 | self.assertEqual(10, len(self.ordered_sparse_topic_hist.non_zeros)) 34 | for i in xrange(len(self.ordered_sparse_topic_hist.non_zeros)): 35 | self.assertEqual(10 - i - 1, 36 | self.ordered_sparse_topic_hist.non_zeros[i].topic) 37 | self.assertEqual(10 - i, 38 | self.ordered_sparse_topic_hist.non_zeros[i].count) 39 | 40 | def test_num_topics(self): 41 | self.assertEqual(self.num_topics, 42 | self.ordered_sparse_topic_hist.num_topics) 43 | 44 | def test_size(self): 45 | self.assertEqual(10, self.ordered_sparse_topic_hist.size()) 46 | 47 | def test_serialize_and_parse(self): 48 | blob = self.ordered_sparse_topic_hist.serialize_to_string() 49 | 50 | sparse_topic_hist = OrderedSparseTopicHistogram(self.num_topics) 51 | sparse_topic_hist.parse_from_string(blob) 52 | 53 | self.assertEqual(sparse_topic_hist.size(), 54 | self.ordered_sparse_topic_hist.size()) 55 | self.assertEqual(str(sparse_topic_hist), 56 | str(self.ordered_sparse_topic_hist)) 57 | 58 | def test_count(self): 59 | for i in xrange(10): 60 | self.assertEqual(i + 1, self.ordered_sparse_topic_hist.count(i)) 61 | for i in xrange(10, 20): 62 | self.assertEqual(0, self.ordered_sparse_topic_hist.count(i)) 63 | 64 | def test_increase_topic(self): 65 | for i in xrange(20): 66 | if i < 10: 67 | self.assertEqual(2 * (i + 1), 68 | self.ordered_sparse_topic_hist.increase_topic(i, i + 1)) 69 | else: 70 | self.assertEqual(i + 1, 71 | self.ordered_sparse_topic_hist.increase_topic(i, i + 1)) 72 | 73 | for j in xrange(len(self.ordered_sparse_topic_hist.non_zeros) - 1): 74 | self.assertGreaterEqual( 75 | self.ordered_sparse_topic_hist.non_zeros[j].count, 76 | self.ordered_sparse_topic_hist.non_zeros[j + 1].count) 77 | 78 | self.assertEqual(2, self.ordered_sparse_topic_hist.count(0)) 79 | self.assertEqual(12, self.ordered_sparse_topic_hist.count(5)) 80 | self.assertEqual(11, self.ordered_sparse_topic_hist.count(10)) 81 | self.assertEqual(16, self.ordered_sparse_topic_hist.count(15)) 82 | 
self.assertEqual(20, self.ordered_sparse_topic_hist.increase_topic(15, 4)) 83 | 84 | def test_decrease_topic(self): 85 | self.assertEqual(6, self.ordered_sparse_topic_hist.count(5)) 86 | self.assertEqual(7, self.ordered_sparse_topic_hist.count(6)) 87 | self.assertEqual(5, self.ordered_sparse_topic_hist.decrease_topic(5, 1)) 88 | self.assertEqual(3, self.ordered_sparse_topic_hist.decrease_topic(6, 4)) 89 | self.assertEqual(10, self.ordered_sparse_topic_hist.size()) 90 | self.assertEqual(5, self.ordered_sparse_topic_hist.count(5)) 91 | self.assertEqual(3, self.ordered_sparse_topic_hist.count(6)) 92 | 93 | for i in xrange(len(self.ordered_sparse_topic_hist.non_zeros) - 1): 94 | self.assertGreaterEqual( 95 | self.ordered_sparse_topic_hist.non_zeros[i].count, 96 | self.ordered_sparse_topic_hist.non_zeros[i + 1].count) 97 | 98 | self.assertEqual(0, self.ordered_sparse_topic_hist.decrease_topic(6, 3)) 99 | self.assertEqual(9, self.ordered_sparse_topic_hist.size()) 100 | for i in xrange(len(self.ordered_sparse_topic_hist.non_zeros) - 1): 101 | self.assertGreaterEqual( 102 | self.ordered_sparse_topic_hist.non_zeros[i].count, 103 | self.ordered_sparse_topic_hist.non_zeros[i + 1].count) 104 | 105 | if __name__ == '__main__': 106 | unittest.main() 107 | 108 | -------------------------------------------------------------------------------- /common/recordio.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | import logging 20 | import struct 21 | 22 | # Maximum record size, default = 64MB 23 | SANITY_CHECK_BYTES = 64 * 1024 * 1024 24 | 25 | class RecordWriter(object): 26 | """Write string records to a stream. 27 | 28 | Max record size is 64MB for the sake of sanity. 29 | """ 30 | 31 | def __init__(self, fp): 32 | """Initialize a Writer from the file pointer fp. 33 | """ 34 | self.fp = fp 35 | if ('w' not in self.fp.mode and 36 | 'a' not in self.fp.mode and 37 | '+' not in self.fp.mode): 38 | logging.error("""Filehandle supplied to RecordWriter does not 39 | appear to be writeable.""") 40 | 41 | def write(self, blob): 42 | """Append the blob to the current RecordWriter. 43 | 44 | Returns True on success, False on any filesystem failure. 45 | """ 46 | if not isinstance(blob, str): 47 | logging.error('Invalid type, blob (type = %s) not StringType.' 48 | % type(blob)) 49 | return False 50 | 51 | blob_len = len(blob) 52 | global SANITY_CHECK_BYTES 53 | if blob_len > SANITY_CHECK_BYTES: 54 | logging.error('Record size %d exceeded.' % blob_len) 55 | return False 56 | 57 | self.fp.write(struct.pack('>L', blob_len)) 58 | self.fp.write(blob) 59 | 60 | return True 61 | 62 | class RecordReader(object): 63 | """Read string records from a RecordWriter stream. 
64 | """ 65 | 66 | def __init__(self, fp): 67 | """Initialize a Reader from the file pointer fp. 68 | """ 69 | self.fp = fp 70 | if (('w' in self.fp.mode or 'a' in self.fp.mode) and 71 | '+' not in self.fp.mode): 72 | logging.error("""Filehandle supplied to RecordReader does not 73 | appear to be readable.""") 74 | 75 | def read(self): 76 | """Read s single record from this stream. Updates the file position 77 | on both success and failure (unless no data is available, in which case 78 | the file position is unchanged and None is returned.) 79 | 80 | Returns string blob or None if no data available. 81 | """ 82 | blob = self.fp.read(4) 83 | if len(blob) == 0: 84 | logging.debug('%s has no data (current offset = %d).' 85 | % (self.fp.name, self.fp.tell())) 86 | self.fp.seek(self.fp.tell()) 87 | return None 88 | 89 | if len(blob) != 4: 90 | logging.error('Expected 4 bytes, but got %d.' % len(blob)) 91 | return None 92 | 93 | blob_len = struct.unpack('>L', blob)[0] 94 | global SANITY_CHECK_BYTES 95 | if blob_len > SANITY_CHECK_BYTES: 96 | logging.error('Record size %d exceeded.' % blob_len) 97 | 98 | read_blob = self.fp.read(blob_len) 99 | if len(read_blob) != blob_len: 100 | logging.error('Premature end of stream.') 101 | return read_blob 102 | 103 | -------------------------------------------------------------------------------- /common/recordio_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
22 | 
23 | import unittest
24 | from recordio import RecordWriter
25 | from recordio import RecordReader
26 | from lda_pb2 import WordTopicHistogramPB
27 | 
28 | class RecordIOTest(unittest.TestCase):
29 | 
30 |     def setUp(self):
31 |         pass
32 | 
33 |     def test_read_and_write_normal(self):
34 |         fp = open('../testdata/recordio.dat', 'wb')
35 |         record_writer = RecordWriter(fp)
36 |         self.assertFalse(record_writer.write(111))
37 |         self.assertFalse(record_writer.write(111.89))
38 |         self.assertFalse(record_writer.write(True))
39 |         self.assertTrue(record_writer.write('111'))
40 |         self.assertTrue(record_writer.write('89'))
41 |         self.assertTrue(record_writer.write('apple'))
42 |         self.assertTrue(record_writer.write('ipad'))
43 |         fp.close()
44 | 
45 |         fp = open('../testdata/recordio.dat', 'rb')
46 |         record_reader = RecordReader(fp)
47 |         self.assertEqual('111', record_reader.read())
48 |         self.assertEqual('89', record_reader.read())
49 |         self.assertEqual('apple', record_reader.read())
50 |         self.assertEqual('ipad', record_reader.read())
51 |         self.assertIsNone(record_reader.read())
52 |         fp.close()
53 | 
54 |     def test_read_and_write_pb(self):
55 |         fp = open('../testdata/recordio.dat', 'wb')
56 |         record_writer = RecordWriter(fp)
57 |         for i in xrange(20):
58 |             word_topic_hist = WordTopicHistogramPB()
59 |             word_topic_hist.word = i
60 |             for j in xrange(20):
61 |                 non_zero = word_topic_hist.sparse_topic_hist.non_zeros.add()
62 |                 non_zero.topic = j
63 |                 non_zero.count = j + 1
64 |             self.assertTrue(
65 |                 record_writer.write(word_topic_hist.SerializeToString()))
66 |         fp.close()
67 | 
68 |         fp = open('../testdata/recordio.dat', 'rb')
69 |         record_reader = RecordReader(fp)
70 |         i = 0
71 |         while True:
72 |             blob = record_reader.read()
73 |             if blob is None:
74 |                 break
75 |             word_topic_hist = WordTopicHistogramPB()
76 |             word_topic_hist.ParseFromString(blob)
77 |             self.assertEqual(i, word_topic_hist.word)
78 |             sparse_topic_hist = word_topic_hist.sparse_topic_hist
79 |             self.assertEqual(20, len(sparse_topic_hist.non_zeros))
80 |             for j in xrange(len(sparse_topic_hist.non_zeros)):
81 |                 self.assertEqual(j, sparse_topic_hist.non_zeros[j].topic)
82 |                 self.assertEqual(j + 1, sparse_topic_hist.non_zeros[j].count)
83 |             i += 1
84 |         self.assertEqual(20, i)
85 |         fp.close()
86 | 
87 | if __name__ == '__main__':
88 |     unittest.main()
89 | 
90 | 
--------------------------------------------------------------------------------
/common/vocabulary.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | class Vocabulary(object):
20 |     """Vocabulary maintains the mapping between words and integer ids.
21 |     """
22 |     def __init__(self):
23 |         self.word_map = {} # item fmt: word -> index
24 |         self.words = [] # item fmt: word, default index
25 | 
26 |     def load(self, filename):
27 |         """Read words from filename.
28 |         line fmt: word [\t count]
29 |         """
30 |         self.word_map.clear()
31 |         self.words = []
32 |         fp = open(filename, 'r')
33 |         for line in fp.readlines():
34 |             line = line.decode('gbk')
35 |             fields = line.strip().split('\t')
36 |             if fields and fields[0] and fields[0] not in self.word_map:
37 |                 self.word_map[fields[0]] = len(self.words)
38 |                 self.words.append(fields[0])
39 |         fp.close()
40 | 
41 |     def has_word(self, word):
42 |         return (word in self.word_map)
43 | 
44 |     def word_index(self, word):
45 |         return self.word_map.get(word, -1)
46 | 
47 |     def word(self, index):
48 |         assert index >= 0 and index < len(self.words)
49 |         return self.words[index]
50 | 
51 |     def size(self):
52 |         return len(self.words)
53 | 
--------------------------------------------------------------------------------
/common/vocabulary_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import unittest
20 | from vocabulary import Vocabulary
21 | 
22 | class VocabularyTest(unittest.TestCase):
23 | 
24 |     def setUp(self):
25 |         self.vocabulary = Vocabulary()
26 |         self.vocabulary.load("../testdata/vocabulary.dat")
27 | 
28 |     def test_has_word(self):
29 |         self.assertTrue(self.vocabulary.has_word('ipad'))
30 |         self.assertTrue(self.vocabulary.has_word('iphone'))
31 |         self.assertTrue(self.vocabulary.has_word('macbook'))
32 |         self.assertFalse(self.vocabulary.has_word('nokia'))
33 |         self.assertFalse(self.vocabulary.has_word('thinkpad'))
34 | 
35 |     def test_word_index(self):
36 |         self.assertEqual(0, self.vocabulary.word_index('ipad'))
37 |         self.assertEqual(1, self.vocabulary.word_index('iphone'))
38 |         self.assertEqual(2, self.vocabulary.word_index('macbook'))
39 |         self.assertEqual(-1, self.vocabulary.word_index('nokia'))
40 |         self.assertEqual(-1, self.vocabulary.word_index('thinkpad'))
41 | 
42 |     def test_word(self):
43 |         self.assertEqual('ipad', self.vocabulary.word(0))
44 |         self.assertEqual('iphone', self.vocabulary.word(1))
45 |         self.assertEqual('macbook', self.vocabulary.word(2))
46 | 
47 |     def test_size(self):
48 |         self.assertEqual(17, self.vocabulary.size())
49 | 
50 | if __name__ == '__main__':
51 |     unittest.main()
52 | 
--------------------------------------------------------------------------------
/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ankazhao/python-sparselda/f84d05ee99899ceadb8371dc52cc76f5aa0da934/inference/__init__.py
--------------------------------------------------------------------------------
/inference/multi_chain_gibbs_sampler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import random
20 | import sys
21 | 
22 | sys.path.append('..')
23 | from common.document import Document
24 | from common.ordered_sparse_topic_histogram import OrderedSparseTopicHistogram
25 | from common.model import Model
26 | from common.vocabulary import Vocabulary
27 | from sparselda_gibbs_sampler import SparseLDAGibbsSampler
28 | 
29 | class MultiChainGibbsSampler(SparseLDAGibbsSampler):
30 |     """MultiChainGibbsSampler implements the multi-markov-chain based SparseLDA
31 |     gibbs sampling inference algorithm.
32 | 
33 |     See 'X. Wei, W. Bruce Croft. LDA-based document models for ad hoc retrieval.
34 |     In Proc. SIGIR. 2006.' for more details.
35 |     """
36 | 
37 |     def __init__(self, model, vocabulary, num_markov_chains,
38 |                  total_iterations, burn_in_iterations):
39 |         super(MultiChainGibbsSampler, self).__init__(model, vocabulary,
40 |             total_iterations, burn_in_iterations)
41 |         self.num_markov_chains = num_markov_chains
42 | 
43 |     def infer_topics(self, doc_tokens):
44 |         """Infer the topics embedded in the given document, which is
45 |         represented as a token sequence named 'doc_tokens'.
46 | 
47 |         Returns the sparse topic distribution p(z|d) as a dict keyed by
48 |         topic id, such as {1 : 0.87, 6 : 0.23, 4 : 0.17, 15 : 0.1}
49 |         """
50 |         rand = random.Random()
51 |         rand.seed(hash(str(doc_tokens)))
52 | 
53 |         accumulated_topic_dist = {}
54 |         for i in xrange(self.num_markov_chains):
55 |             topic_dist = self._inference_one_chain(doc_tokens, rand)
56 | 
57 |             for topic, prob in topic_dist.iteritems():
58 |                 if topic in accumulated_topic_dist:
59 |                     accumulated_topic_dist[topic] += prob
60 |                 else:
61 |                     accumulated_topic_dist[topic] = prob
62 | 
63 |         topic_dist = self._l1normalize_distribution(accumulated_topic_dist)
64 |         # note: topic_dist is unordered; sort items by value for a ranking.
65 |         return topic_dist
66 | 
--------------------------------------------------------------------------------
/inference/multi_chain_gibbs_sampler_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
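18 | 
19 | # A minimal sketch (not project code) of the chain averaging performed
20 | # by MultiChainGibbsSampler.infer_topics(): per-chain p(z|d) estimates
21 | # are accumulated, then L1-normalized:
22 | #
23 | #   acc = {}
24 | #   for chain in xrange(num_markov_chains):
25 | #       for topic, prob in run_one_chain(doc_tokens).iteritems():
26 | #           acc[topic] = acc.get(topic, 0.0) + prob
27 | #   total = sum(acc.itervalues())
28 | #   topic_dist = dict((z, p / total) for z, p in acc.iteritems())
29 | #
30 | # where run_one_chain stands in for the inherited _inference_one_chain.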
31 | 
32 | import unittest
33 | import sys
34 | 
35 | sys.path.append('..')
36 | from common.model import Model
37 | from common.vocabulary import Vocabulary
38 | from multi_chain_gibbs_sampler import MultiChainGibbsSampler
39 | 
40 | class MultiChainGibbsSamplerTest(unittest.TestCase):
41 | 
42 |     def setUp(self):
43 |         model = Model(20)
44 |         model.load('../testdata/lda_model')
45 |         vocabulary = Vocabulary()
46 |         vocabulary.load('../testdata/vocabulary.dat')
47 |         self.multi_chain_gibbs_sampler = \
48 |             MultiChainGibbsSampler(model, vocabulary, 10, 10, 5)
49 | 
50 |     def test_infer_topics(self):
51 |         doc_tokens = []
52 |         doc_topic_dist = self.multi_chain_gibbs_sampler.infer_topics(doc_tokens)
53 |         self.assertEqual(0, len(doc_topic_dist))
54 | 
55 |         doc_tokens = ['apple', 'ipad']
56 |         doc_topic_dist = self.multi_chain_gibbs_sampler.infer_topics(doc_tokens)
57 |         print doc_topic_dist
58 |         self.assertEqual(5, len(doc_topic_dist))
59 |         self.assertTrue(0 in doc_topic_dist)
60 |         self.assertEqual(0.05, doc_topic_dist[0])
61 |         self.assertTrue(1 in doc_topic_dist)
62 |         self.assertEqual(0.32, doc_topic_dist[1])
63 |         self.assertTrue(3 in doc_topic_dist)
64 |         self.assertEqual(0.14, doc_topic_dist[3])
65 | 
66 |         doc_tokens = ['apple', 'ipad', 'apple', 'null', 'nokia', 'macbook']
67 |         doc_topic_dist = self.multi_chain_gibbs_sampler.infer_topics(doc_tokens)
68 |         print doc_topic_dist
69 |         self.assertEqual(6, len(doc_topic_dist))
70 | 
71 | if __name__ == '__main__':
72 |     unittest.main()
73 | 
74 | 
--------------------------------------------------------------------------------
/inference/sparselda_gibbs_sampler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import logging
20 | import random
21 | import sys
22 | 
23 | sys.path.append('..')
24 | from common.document import Document
25 | from common.model import Model
26 | from common.vocabulary import Vocabulary
27 | 
28 | class SparseLDAGibbsSampler(object):
29 |     """SparseLDAGibbsSampler implements the SparseLDA gibbs sampling inference
30 |     algorithm. In the gibbs sampling formula:
31 | 
32 |     (0) p(z|w) --> p(z|d) * p(w|z) --> [alpha(z) + N(z|d)] * p(w|z)
33 | 
34 |     (1) s(z, w) = alpha(z) * p(w|z)
35 |     (2) r(z, w, d) = N(z|d) * p(w|z)
36 |     (3) p(w|z) = [beta + N(w|z)] / [beta * V + N(z)]
37 | 
38 |     The process divides the full sampling mass p(z|w) into two buckets,
39 |     where s(z, w) is a smoothing-only bucket, and r(z, w, d) is a
40 |     document-topic bucket.
41 | 
42 |     To achieve time efficiency, the topic distribution matrix p(w|z) is
43 |     pre-computed and cached.
44 | 
45 |     See 'Limin Yao, David Mimno, Andrew McCallum. Efficient methods for topic
46 |     model inference on streaming document collections, In SIGKDD, 2009.' for
47 |     more details.
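48 | 
49 |     Note the sparsity here: s(z, w) is pre-computed and cached per word,
50 |     and r(z, w, d) is accumulated over only the topics with N(z|d) > 0,
51 |     so a sampling step usually touches far fewer than num_topics entries.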
52 |     """
53 | 
54 |     def __init__(self, model, vocabulary, total_iterations,
55 |                  burn_in_iterations):
56 |         self.model = model
57 |         self.vocabulary = vocabulary
58 |         self.total_iterations = total_iterations
59 |         self.burn_in_iterations = burn_in_iterations
60 | 
61 |         # cache p(w|z), indexed by word.
62 |         # item fmt: word -> dense topic distribution.
63 |         # TODO(fandywang): only cache the submatrix p(w|z) of frequent words
64 |         # for time and memory efficiency, in other words, high scalability.
65 |         self.word_topic_dist = self.model.get_word_topic_dist(vocabulary.size())
66 |         self.smoothing_only_sum = {} # cache s(z, w)
67 |         self.__init_smoothing_only_sum()
68 | 
69 |     def __init_smoothing_only_sum(self):
70 |         """Compute and cache the smoothing_only_sum s(z, w).
71 |         """
72 |         for word_id in self.model.word_topic_hist.keys():
73 |             cur_sum = 0.0
74 |             topic_dist = self.word_topic_dist[word_id]
75 |             for topic, prob in enumerate(topic_dist):
76 |                 cur_sum += self.model.hyper_params.topic_prior * prob
77 |             self.smoothing_only_sum[word_id] = cur_sum
78 | 
79 |     def infer_topics(self, doc_tokens):
80 |         """Infer the topics embedded in the given document, which is
81 |         represented as a token sequence named 'doc_tokens'.
82 | 
83 |         Returns the topic distribution p(z|d) as a dict keyed by topic id,
84 |         such as {1 : 0.87, 6 : 0.23, 4 : 0.17, 15 : 0.1}
85 |         """
86 |         rand = random.Random()
87 |         rand.seed(hash(str(doc_tokens)))
88 | 
89 |         doc_topic_dist = self._inference_one_chain(doc_tokens, rand)
90 |         # note: doc_topic_dist is unordered; sort its items by value if a
91 |         # ranked list is needed.
92 |         return doc_topic_dist
93 | 
94 |     # TODO(fandywang): infer topic words later.
95 |     def infer_topic_words(self, doc_tokens):
96 |         """Infer the topic words embedded in the given document, which is
97 |         represented as a token sequence named 'doc_tokens'.
98 | 
99 |         Returns the dict of topic words with their probabilities p(w|d) =
100 |         sum_z p(z|d) * p(w|z),
101 |         such as {'apple' : 0.87, 'iphone' : 0.23, 'ipad': 0.17, 'nokia' : 0.1}
102 |         """
103 |         doc_topic_words_dist = {}
104 |         doc_topic_dist = self.infer_topics(doc_tokens)
105 |         # TODO(fandywang): accumulate p(z|d) * p(w|z) per word here.
106 |         for topic, prob in doc_topic_dist.iteritems():
107 |             pass
108 |         return doc_topic_words_dist
109 | 
110 |     def _inference_one_chain(self, doc_tokens, rand):
111 |         """Infer topics with one markov chain.
112 | 
113 |         Returns the sparse topics p(z|d).
114 |         """
115 |         document = Document(self.model.num_topics)
116 |         document.parse_from_tokens(doc_tokens, rand,
117 |                                    self.vocabulary, self.model)
118 |         if document.num_words() == 0:
119 |             return dict()
120 | 
121 |         accumulated_topic_hist = {}
122 |         for i in xrange(self.total_iterations):
123 |             # one iteration
124 |             for word in document.words:
125 |                 # --
126 |                 document.decrease_topic(word.topic, 1)
127 | 
128 |                 new_topic = self._sample_word_topic(document, word.id, rand)
129 |                 assert new_topic is not None
130 |                 word.topic = new_topic
131 |                 # ++
132 |                 document.increase_topic(new_topic, 1)
133 | 
134 |             if i >= self.burn_in_iterations:
135 |                 for non_zero in document.doc_topic_hist.non_zeros:
136 |                     if non_zero.topic in accumulated_topic_hist:
137 |                         accumulated_topic_hist[non_zero.topic] += non_zero.count
138 |                     else:
139 |                         accumulated_topic_hist[non_zero.topic] = non_zero.count
140 | 
141 |         topic_dist = self._l1normalize_distribution(accumulated_topic_hist)
142 |         return topic_dist
143 | 
144 |     def _sample_word_topic(self, doc, word, rand):
145 |         """Sample a new topic for the current word.
146 | 
147 |         Returns the new topic.
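148 | 
149 |         The draw is u ~ uniform(0, s + r): if u falls below the
150 |         document-topic mass r, walk the sparse doc-topic bucket; otherwise
151 |         walk the smoothing-only bucket s(z, w) = alpha(z) * p(w|z).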
152 |         """
153 |         doc_topic_bucket, doc_topic_sum = \
154 |             self._compute_doc_topic_bucket(doc, word)
155 | 
156 |         total_mass = self.smoothing_only_sum[word] + doc_topic_sum
157 |         sample = rand.uniform(0.0, total_mass)
158 | 
159 |         # sample in document topic bucket
160 |         if sample < doc_topic_sum:
161 |             for topic_prob in doc_topic_bucket:
162 |                 sample -= topic_prob[1]
163 |                 if sample <= 0:
164 |                     return topic_prob[0]
165 |         else: # sample in smoothing only bucket
166 |             sample -= doc_topic_sum
167 |             topic_dist = self.word_topic_dist[word]
168 |             for topic, prob in enumerate(topic_dist):
169 |                 # s(z, w) = alpha(z) * p(w|z), so weight by the topic prior.
170 |                 sample -= self.model.hyper_params.topic_prior * prob
171 |                 if sample <= 0:
172 |                     return topic
173 |         logging.error('Sample word topic failed, residual: %f, total: %f.'
174 |                       % (sample, total_mass))
175 |         return None
176 | 
177 |     def _compute_doc_topic_bucket(self, doc, word):
178 |         """Compute the document topic bucket r(z, w, d).
179 | 
180 |         Returns document-topic distributions and their sum.
181 |         """
182 |         doc_topic_bucket = []
183 |         doc_topic_sum = 0.0
184 | 
185 |         dense_topic_dist = self.word_topic_dist[word]
186 |         for non_zero in doc.doc_topic_hist.non_zeros:
187 |             doc_topic_bucket.append([non_zero.topic,
188 |                 non_zero.count * dense_topic_dist[non_zero.topic]])
189 |             doc_topic_sum += doc_topic_bucket[-1][1]
190 |         return doc_topic_bucket, doc_topic_sum
191 | 
192 |     def _l1normalize_distribution(self, topic_dict):
193 |         """Returns the l1-normalized topic distributions.
194 |         """
195 |         topic_dist = {}
196 |         weight_sum = 0
197 |         for topic, weight in topic_dict.iteritems():
198 |             weight_sum += weight
199 |         if weight_sum == 0:
200 |             logging.warning('The sum of topic weight is zero.')
201 |             return topic_dist
202 |         for topic, weight in topic_dict.iteritems():
203 |             topic_dist[topic] = float(weight) / weight_sum
204 |         return topic_dist
205 | 
206 | 
--------------------------------------------------------------------------------
/inference/sparselda_gibbs_sampler_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
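18 | 
19 | # Note: SparseLDAGibbsSampler seeds its RNG with hash(str(doc_tokens)),
20 | # so inference is deterministic for a given document; that is why the
21 | # exact probabilities below can be asserted.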
22 | 
23 | import unittest
24 | import sys
25 | 
26 | sys.path.append('..')
27 | from common.model import Model
28 | from common.vocabulary import Vocabulary
29 | from sparselda_gibbs_sampler import SparseLDAGibbsSampler
30 | 
31 | class SparseLDAGibbsSamplerTest(unittest.TestCase):
32 | 
33 |     def setUp(self):
34 |         model = Model(20)
35 |         model.load('../testdata/lda_model')
36 |         vocabulary = Vocabulary()
37 |         vocabulary.load('../testdata/vocabulary.dat')
38 |         self.sparselda_gibbs_sampler = \
39 |             SparseLDAGibbsSampler(model, vocabulary, 10, 5)
40 | 
41 |     def test_infer_topics(self):
42 |         doc_tokens = []
43 |         doc_topic_dist = self.sparselda_gibbs_sampler.infer_topics(doc_tokens)
44 |         self.assertEqual(0, len(doc_topic_dist))
45 | 
46 |         doc_tokens = ['apple', 'ipad']
47 |         doc_topic_dist = self.sparselda_gibbs_sampler.infer_topics(doc_tokens)
48 |         self.assertEqual(3, len(doc_topic_dist))
49 |         self.assertTrue(1 in doc_topic_dist)
50 | 
51 |         doc_tokens = ['apple', 'ipad', 'apple', 'null', 'nokia', 'macbook']
52 |         doc_topic_dist = self.sparselda_gibbs_sampler.infer_topics(doc_tokens)
53 |         self.assertEqual(4, len(doc_topic_dist))
54 |         self.assertTrue(0 in doc_topic_dist)
55 |         self.assertEqual(0.1, doc_topic_dist[0])
56 |         self.assertTrue(2 in doc_topic_dist)
57 |         self.assertEqual(0.3, doc_topic_dist[2])
58 | 
59 | if __name__ == '__main__':
60 |     unittest.main()
61 | 
62 | 
--------------------------------------------------------------------------------
/lda_inferencer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import logging
20 | import optparse
21 | 
22 | from common.model import Model
23 | from common.vocabulary import Vocabulary
24 | from inference.multi_chain_gibbs_sampler import MultiChainGibbsSampler
25 | 
26 | def main(args):
27 |     model = Model(0)
28 |     model.load(args.model_dir)
29 |     vocabulary = Vocabulary()
30 |     vocabulary.load(args.vocabulary_file)
31 |     multi_chain_gibbs_sampler = MultiChainGibbsSampler(model, vocabulary,
32 |         args.num_markov_chains, args.total_iterations,
33 |         args.burn_in_iterations)
34 | 
35 |     fp = open(args.document_file, 'r')
36 |     for doc_str in fp.readlines():
37 |         doc_str = doc_str.decode('gbk')
38 |         doc_tokens = doc_str.strip().split('\t')
39 |         topic_dist = multi_chain_gibbs_sampler.infer_topics(doc_tokens)
40 |         print doc_str
41 |         print topic_dist
42 |     fp.close()
43 | 
44 | if __name__ == '__main__':
45 |     parser = optparse.OptionParser('usage: python lda_inferencer.py -h.')
46 |     parser.add_option('--model_dir', help = 'the lda model directory.')
47 |     parser.add_option('--vocabulary_file', help = 'the vocabulary file.')
48 |     parser.add_option('--document_file',
49 |         help = 'the document file in gbk, line fmt: w1 \t w2 \t w3 \t...
.') 50 | parser.add_option('--num_markov_chains', type = int, default = 5, 51 | help = 'the num of markov chains.') 52 | parser.add_option('--total_iterations', type = int, default = 50, 53 | help = 'the num of total_iterations.') 54 | parser.add_option('--burn_in_iterations', type = int, default = 20, 55 | help = 'the num of burn_in iteration.') 56 | 57 | (options, args) = parser.parse_args() 58 | logging.info('Parameters : %s' % str(options)) 59 | 60 | main(options) 61 | 62 | -------------------------------------------------------------------------------- /lda_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | import logging 20 | import optparse 21 | import os 22 | import random 23 | 24 | from common.model import Model 25 | from common.vocabulary import Vocabulary 26 | from training.sparselda_train_gibbs_sampler import SparseLDATrainGibbsSampler 27 | from training.model_evaluator import ModelEvaluator 28 | from training.topic_words_stat import TopicWordsStat 29 | 30 | def main(args): 31 | model = Model(args.num_topics, args.topic_prior, args.word_prior) 32 | vocabulary = Vocabulary() 33 | vocabulary.load(args.vocabulary_file) 34 | sparselda_train_gibbs_sampler = SparseLDATrainGibbsSampler( 35 | model, vocabulary) 36 | sparselda_train_gibbs_sampler.load_corpus(args.corpus_dir) 37 | 38 | rand = random.Random() 39 | 40 | for i in xrange(args.total_iterations): 41 | logging.info('sparselda trainer, gibbs sampling iteration %d.' 42 | % (i + 1)) 43 | sparselda_train_gibbs_sampler.gibbs_sampling(rand) 44 | 45 | # dump lda model 46 | if i == 0 or (i + 1) % args.save_model_interval == 0: 47 | logging.info('iteration %d start saving lda model.' % (i + 1)) 48 | sparselda_train_gibbs_sampler.save_model(args.model_dir, i + 1) 49 | topic_words_stat = TopicWordsStat(model, vocabulary) 50 | topic_words_stat.save( 51 | '%s/topic_top_words.%d' % (args.model_dir, i + 1), 52 | args.topic_word_accumulated_prob_threshold) 53 | logging.info('iteration %d save lda model ok.' % (i + 1)) 54 | 55 | # dump checkpoint 56 | if i == 0 or (i + 1) % args.save_checkpoint_interval == 0: 57 | logging.info('iteration %d start saving checkpoint.' % (i + 1)) 58 | sparselda_train_gibbs_sampler.save_checkpoint( 59 | args.checkpoint_dir, i + 1) 60 | logging.info('iteration %d save checkpoint ok.' % (i + 1)) 61 | 62 | # compute the loglikelihood 63 | if i == 0 or (i + 1) % args.compute_loglikelihood_interval == 0: 64 | logging.info('iteration %d start computing loglikelihood.' % (i + 1)) 65 | model_evaluator = ModelEvaluator(model, vocabulary) 66 | ll = model_evaluator.compute_loglikelihood( 67 | sparselda_train_gibbs_sampler.documents) 68 | logging.info('iteration %d loglikelihood is %f.' 
                     % (i + 1, ll))
69 | 
70 | if __name__ == '__main__':
71 |     parser = optparse.OptionParser('usage: python lda_trainer.py -h.')
72 |     parser.add_option('--corpus_dir',
73 |         help = 'the corpus directory, line fmt: w1 \t w2 \t w3 ... .')
74 |     parser.add_option('--vocabulary_file',
75 |         help = 'the vocabulary file, line fmt: w [\tfreq].')
76 |     parser.add_option('--num_topics', type = int, help = 'the num of topics.')
77 |     parser.add_option('--topic_prior', type = float, default = 0.1,
78 |         help = 'the topic prior alpha.')
79 |     parser.add_option('--word_prior', type = float, default = 0.01,
80 |         help = 'the word prior beta.')
81 |     parser.add_option('--total_iterations', type = int, default = 10000,
82 |         help = 'the total iteration.')
83 |     parser.add_option('--model_dir', help = 'the model directory.')
84 |     parser.add_option('--save_model_interval', type = int, default = 100,
85 |         help = 'the interval to save lda model.')
86 |     parser.add_option('--topic_word_accumulated_prob_threshold',
87 |         type = float, default = 0.5,
88 |         help = 'the accumulated_prob_threshold of topic words.')
89 |     parser.add_option('--save_checkpoint_interval', type = int, default = 10,
90 |         help = 'the interval to save checkpoint.')
91 |     parser.add_option('--checkpoint_dir', help = 'the checkpoint directory.')
92 |     parser.add_option('--compute_loglikelihood_interval', type = int,
93 |         default = 10, help = 'the interval to compute loglikelihood.')
94 | 
95 |     (options, args) = parser.parse_args()
96 |     logging.basicConfig(filename = os.path.join(os.getcwd(), 'log.txt'),
97 |         level = logging.DEBUG, filemode = 'a',
98 |         format = '%(asctime)s - %(levelname)s: %(message)s')
99 |     logging.info('Parameters : %s' % str(options))
100 | 
101 |     main(options)
102 | 
--------------------------------------------------------------------------------
/lda_trainer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | nohup python lda_trainer.py \
4 |     --corpus_dir=testdata/corpus \
5 |     --vocabulary_file=testdata/vocabulary.dat \
6 |     --num_topics=500 \
7 |     --topic_prior=0.1 \
8 |     --word_prior=0.01 \
9 |     --save_model_interval=10 \
10 |     --model_dir=sparselda_models \
11 |     --save_checkpoint_interval=10 \
12 |     --checkpoint_dir=sparselda_checkpoints \
13 |     --total_iterations=10000 \
14 |     --compute_loglikelihood_interval=10 \
15 |     --topic_word_accumulated_prob_threshold=0.5 \
16 |     > train.log 2>&1 &
17 | 
--------------------------------------------------------------------------------
/testdata/corpus/document1.dat:
--------------------------------------------------------------------------------
1 | apple steve jobs ipad ipod
2 | nokia iphone apple ios
3 | macbook air macbook pro mac os x itunes
4 | chrome browser google andriod os
5 | chrome
6 | macbook air
7 | apple
--------------------------------------------------------------------------------
/testdata/lda_model/lda.global_topic_hist:
--------------------------------------------------------------------------------
(binary file: serialized global topic histogram; content not shown)
--------------------------------------------------------------------------------
/testdata/lda_model/lda.hyper_params:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ankazhao/python-sparselda/f84d05ee99899ceadb8371dc52cc76f5aa0da934/testdata/lda_model/lda.hyper_params
--------------------------------------------------------------------------------
/testdata/lda_model/lda.word_topic_hist:
--------------------------------------------------------------------------------
(binary file: serialized WordTopicHistogramPB records; content not shown)
--------------------------------------------------------------------------------
/testdata/recordio.dat:
--------------------------------------------------------------------------------
(binary file: length-prefixed test records written by common/recordio_test.py; content not shown)
--------------------------------------------------------------------------------
/testdata/vocabulary.dat:
--------------------------------------------------------------------------------
1 | ipad
2 | iphone
3 | macbook
4 | macbook pro
5 | macbook air
6 | air
7 | apple
8 | steve jobs
9 | itunes
10 | ipod
11 | ipad
12 | iphone
13 | mac os x
14 | chrome
15 | google
16 | glass
17 | andriod
18 | ios
19 | os
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ankazhao/python-sparselda/f84d05ee99899ceadb8371dc52cc76f5aa0da934/training/__init__.py
--------------------------------------------------------------------------------
/training/model_evaluator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import logging
20 | import math
21 | import os
22 | import random
23 | import sys
24 | 
25 | sys.path.append('..')
26 | from common.document import Document
27 | from common.model import Model
28 | from common.vocabulary import Vocabulary
29 | 
30 | class ModelEvaluator(object):
31 |     """ModelEvaluator evaluates the lda model's quality via the corpus loglikelihood.
32 |     """
33 | 
34 |     def __init__(self, model, vocabulary):
35 |         self.model = model
36 |         self.vocabulary = vocabulary
37 | 
38 |         # cache matrix p(w|z), indexed by word.
39 |         self.word_topic_dist = \
40 |             self.model.get_word_topic_dist(self.vocabulary.size())
41 | 
42 |     def compute_loglikelihood(self, documents):
43 |         """Compute and return the loglikelihood of documents.
44 | 
45 |         p(D|M) = p(d1)p(d2)...
46 | 
47 |         p(d) = p(w1)p(w2)...
48 |              = sum_z {p(z|d)p(w1|z)} * sum_z {p(z|d)p(w2|z)} * ...
49 | 
50 |         log(p(d)) = log(sum_z p(z|d)p(w1|z)) + log(sum_z p(z|d)p(w2|z)) + ...
51 | 
52 |         p(D|M) -> log(p(D|M)) = log(p(d1)) + log(p(d2)) + ...
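 | 
 |         e.g. with a 2-topic model, a word w in document d contributes
 |         log(p(z=0|d) * p(w|z=0) + p(z=1|d) * p(w|z=1)) to log(p(d)).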
53 | """ 54 | loglikelihood = 0.0 55 | for document in documents: 56 | doc_dense_topic_dist = \ 57 | self._compute_doc_topic_distribution(document) 58 | doc_loglikelihood = 0.0 59 | for word in document.words: 60 | word_topic_dist = self.word_topic_dist.get(word.id) 61 | if word_topic_dist is None: 62 | continue 63 | word_prob_sum = 0.0 64 | for topic, prob in enumerate(word_topic_dist): 65 | word_prob_sum += prob * doc_dense_topic_dist[topic] 66 | doc_loglikelihood += math.log(word_prob_sum) 67 | loglikelihood += doc_loglikelihood 68 | return loglikelihood 69 | 70 | def _compute_doc_topic_distribution(self, document): 71 | dense_topic_hist = [0] * self.model.num_topics 72 | topic_hist_sum = 0 73 | for non_zero in document.doc_topic_hist.non_zeros: 74 | dense_topic_hist[non_zero.topic] = non_zero.count 75 | topic_hist_sum += non_zero.count 76 | 77 | dense_topic_dist = [] 78 | denominator = \ 79 | self.model.hyper_params.topic_prior * self.model.num_topics + \ 80 | topic_hist_sum 81 | for i in xrange(self.model.num_topics): 82 | dense_topic_dist.append( 83 | (self.model.hyper_params.topic_prior + dense_topic_hist[i]) 84 | / denominator) 85 | return dense_topic_dist 86 | 87 | -------------------------------------------------------------------------------- /training/model_evaluator_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | 7 | import random 8 | import unittest 9 | import sys 10 | 11 | sys.path.append('..') 12 | from common.document import Document 13 | from common.model import Model 14 | from common.vocabulary import Vocabulary 15 | from model_evaluator import ModelEvaluator 16 | 17 | class ModelEvaluatorTest(unittest.TestCase): 18 | 19 | def setUp(self): 20 | self.model = Model(20) 21 | self.model.load('../testdata/lda_model') 22 | self.vocabulary = Vocabulary() 23 | self.vocabulary.load('../testdata/vocabulary.dat') 24 | 25 | self.model_evaluator = ModelEvaluator(self.model, self.vocabulary) 26 | 27 | def test_compute_loglikelihood(self): 28 | doc_tokens = ['macbook', 'ipad', # exist in vocabulary and model 29 | 'mac os x', 'chrome', # only exist in vocabulary 30 | 'nokia', 'null'] # inexistent 31 | document = Document(self.model.num_topics) 32 | rand = random.Random() 33 | rand.seed(0) 34 | document.parse_from_tokens( 35 | doc_tokens, rand, self.vocabulary, self.model) 36 | documents = [document, document] 37 | self.assertEqual(-14.113955684239654, 38 | self.model_evaluator.compute_loglikelihood(documents)) 39 | 40 | if __name__ == '__main__': 41 | unittest.main() 42 | 43 | -------------------------------------------------------------------------------- /training/sparselda_train_gibbs_sampler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import logging
20 | import os
21 | import random
22 | import sys
23 | 
24 | sys.path.append('..')
25 | from common.document import Document
26 | from common.model import Model
27 | from common.ordered_sparse_topic_histogram import OrderedSparseTopicHistogram
28 | from common.recordio import RecordReader
29 | from common.recordio import RecordWriter
30 | from common.vocabulary import Vocabulary
31 | 
32 | class SparseLDATrainGibbsSampler(object):
33 |     """SparseLDATrainGibbsSampler implements the SparseLDA gibbs sampling
34 |     training algorithm. In the gibbs sampling formula:
35 | 
36 |     (0) p(z|w) --> p(z|d) * p(w|z)
37 |                --> (alpha(z) + N(z|d)) * p(w|z)
38 |                 =  (alpha(z) + N(z|d)) *
39 |                    (beta + N(w|z)) / (beta * |V| + N(z))
40 |                 =  alpha(z) * beta / (beta * |V| + N(z)) +
41 |                    N(z|d) * beta / (beta * |V| + N(z)) +
42 |                    (alpha(z) + N(z|d)) * N(w|z) / (beta * |V| + N(z))
43 | 
44 |     (1) s(z) = alpha(z) * beta / (beta * |V| + N(z))
45 |     (2) r(z, d) = N(z|d) * beta / (beta * |V| + N(z))
46 |     (3) q(z, w, d) = N(w|z) * (alpha(z) + N(z|d)) / (beta * |V| + N(z))
47 | 
48 |     (4) q_coefficient(z, d) = (alpha(z) + N(z|d)) / (beta * |V| + N(z))
49 | 
50 |     This process divides the full sampling mass into three buckets, where s(z)
51 |     is a smoothing-only bucket, r(z, d) is a document-topic bucket, and
52 |     q(z, w, d) is a topic-word bucket.
53 | 
54 |     The values of the three components of the normalization constant, s, r, q,
55 |     can be efficiently calculated. The constant s only changes when the
56 |     hyperparameter alpha(z) is updated. The constant r depends only on the
57 |     document-topic counts, so we can calculate it once at the beginning of
58 |     each document and then update it by subtracting and adding values for
59 |     the terms involving the old and new topic at each gibbs update. This
60 |     process takes constant time, independent of the number of topics. The
61 |     topic word constant q changes with the value of w, so we cannot as
62 |     easily recycle earlier computation. We can, however, cache the
63 |     coefficient q_coefficient for every topic, so calculating q for a given
64 |     w consists of one multiply operation for every topic such that
65 |     N(w|z) > 0.
66 | 
67 |     See 'Limin Yao, David Mimno, Andrew McCallum. Efficient methods for topic
68 |     model inference on streaming document collections, In SIGKDD, 2009.' for
69 |     more details.
70 |     """
71 | 
72 |     def __init__(self, model, vocabulary):
73 |         logging.basicConfig(filename = os.path.join(os.getcwd(),
74 |             'log.txt'), level = logging.DEBUG, filemode = 'a',
75 |             format = '%(asctime)s - %(levelname)s: %(message)s')
76 | 
77 |         self.model = model
78 |         self.vocabulary = vocabulary
79 |         self.word_prior_sum = self.model.hyper_params.word_prior * \
80 |                               self.vocabulary.size()
81 |         self.documents = [] # item fmt: common.lda_pb2.Document
82 | 
83 |         # s(z), smoothing only bucket, indexed by topic z.
84 |         self.smoothing_only_bucket = [0.0] * self.model.num_topics
85 |         self.smoothing_only_sum = 0.0
86 | 
87 |         # r(z, d), document-topic bucket, indexed by topic z.
87 | self.doc_topic_bucket = [0.0] * self.model.num_topics 88 | self.doc_topic_sum = 0.0 89 | 90 | # q(z, w, d), topic-word bucket, indexed by topic z. 91 | self.topic_word_bucket = [0.0] * self.model.num_topics 92 | self.topic_word_sum = 0.0 93 | # q_coefficient(z, d), indexed by topic z. 94 | self.topic_word_coef = [0.0] * self.model.num_topics 95 | 96 | def load_corpus(self, corpus_dir): 97 | """Load corpus from a given directory, then initialize the documents 98 | and model. 99 | Line format: token1 \t token2 \t token3 \t ... ... 100 | """ 101 | self.documents = [] 102 | rand = random.Random() 103 | 104 | logging.info('Load corpus from %s.' % corpus_dir) 105 | for root, dirs, files in os.walk(corpus_dir): 106 | for f in files: 107 | filename = os.path.join(root, f) 108 | logging.info('Load filename %s.' % filename) 109 | fp = open(filename, 'r') 110 | for doc_str in fp.readlines(): 111 | doc_str = doc_str.decode('gbk') 112 | doc_tokens = doc_str.strip().split('\t') 113 | if len(doc_tokens) < 2: 114 | continue 115 | document = Document(self.model.num_topics) 116 | document.parse_from_tokens(doc_tokens, rand, self.vocabulary) 117 | if document.num_words() < 2: 118 | continue 119 | self.documents.append(document) 120 | fp.close() 121 | 122 | logging.info('The document number is %d.' % len(self.documents)) 123 | self._initialize_model() 124 | 125 | self._compute_smoothing_only_bucket() 126 | self._initialize_topic_word_coefficient() 127 | 128 | def _initialize_model(self): 129 | self.model.global_topic_hist = [0] * self.model.num_topics 130 | self.model.word_topic_hist = {} 131 | word_topic_stat = {} 132 | 133 | for document in self.documents: 134 | for word in document.words: 135 | if word.id not in word_topic_stat: 136 | word_topic_stat[word.id] = {} 137 | if word.topic not in word_topic_stat[word.id]: 138 | word_topic_stat[word.id][word.topic] = 1 139 | else: 140 | word_topic_stat[word.id][word.topic] += 1 141 | self.model.global_topic_hist[word.topic] += 1 142 | 143 | for word_id, topic_stat in word_topic_stat.iteritems(): 144 | self.model.word_topic_hist[word_id] = \ 145 | OrderedSparseTopicHistogram(self.model.num_topics) 146 | for topic, count in topic_stat.iteritems(): 147 | self.model.word_topic_hist[word_id].increase_topic(topic, count) 148 | 149 | def save_model(self, model_dir, iteration = ''): 150 | """Save lda model to model_dir. 151 | """ 152 | if not os.path.exists(model_dir): 153 | os.mkdir(model_dir) 154 | self.model.save(model_dir + '/' + str(iteration)) 155 | 156 | def save_checkpoint(self, checkpoint_dir, iteration): 157 | """Dump the corpus and current model as checkpoint. 158 | """ 159 | if not os.path.exists(checkpoint_dir): 160 | os.mkdir(checkpoint_dir) 161 | checkpoint_dir += '/' + str(iteration) 162 | logging.info('Save checkpoint to %s.' 
                     % checkpoint_dir)
163 |         if not os.path.exists(checkpoint_dir):
164 |             os.mkdir(checkpoint_dir)
165 | 
166 |         # dump corpus
167 |         corpus_dir = checkpoint_dir + '/corpus'
168 |         if not os.path.exists(corpus_dir):
169 |             os.mkdir(corpus_dir)
170 |         c = 1
171 |         fp = open(corpus_dir + '/documents.%d' % c, 'wb')
172 |         record_writer = RecordWriter(fp)
173 |         for document in self.documents:
174 |             if c % 10000 == 0:
175 |                 fp.close()
176 |                 fp = open(corpus_dir + '/documents.%d' % c, 'wb')
177 |                 record_writer = RecordWriter(fp)
178 |             record_writer.write(document.serialize_to_string())
179 |             c += 1
180 |         fp.close()
181 | 
182 |         # dump model
183 |         self.save_model(checkpoint_dir + '/lda_model')
184 | 
185 |     def load_checkpoint(self, checkpoint_dir):
186 |         """Load checkpoint from checkpoint_dir.
187 |         """
188 |         max_iteration = -1
189 |         for sub_dir in os.listdir(checkpoint_dir):
190 |             iteration = int(sub_dir)
191 |             if iteration > max_iteration:
192 |                 max_iteration = iteration
193 |         if max_iteration == -1:
194 |             logging.warning('The checkpoint directory %s does not exist.'
195 |                             % checkpoint_dir)
196 |             return None
197 |         checkpoint_dir += '/' + str(max_iteration)
198 |         logging.info('Load checkpoint from %s.' % checkpoint_dir)
199 | 
200 |         assert self._load_corpus(checkpoint_dir + '/corpus')
201 |         assert self._load_model(checkpoint_dir + '/lda_model')
202 |         return max_iteration
203 | 
204 |     def _load_corpus(self, corpus_dir):
205 |         self.documents = []
206 |         if not os.path.exists(corpus_dir):
207 |             logging.error('The corpus directory %s does not exist.'
208 |                           % corpus_dir)
209 |             return False
210 | 
211 |         for root, dirs, files in os.walk(corpus_dir):
212 |             for f in files:
213 |                 filename = os.path.join(root, f)
214 |                 fp = open(filename, 'rb')
215 |                 record_reader = RecordReader(fp)
216 |                 while True:
217 |                     blob = record_reader.read()
218 |                     if blob is None:
219 |                         break
220 |                     document = Document(self.model.num_topics)
221 |                     document.parse_from_string(blob)
222 |                     self.documents.append(document)
223 |                 fp.close()
224 |         return True
225 | 
226 |     def _load_model(self, model_dir):
227 |         if not os.path.exists(model_dir):
228 |             logging.error('The lda model directory %s does not exist.'
229 |                           % model_dir)
230 |             return False
231 |         self.model.load(model_dir)
232 |         return True
233 | 
234 |     def gibbs_sampling(self, rand):
235 |         """Perform one iteration of Gibbs Sampling.
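 | 
 |         For each document: r(z, d) and the q coefficients are computed
 |         once up front; then for every word the old topic is removed,
 |         q(z, w, d) is rebuilt for that word, a new topic is sampled and
 |         added back; finally the coefficients are reset to their
 |         document-independent values.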
236 | """ 237 | for document in self.documents: 238 | self._compute_doc_topic_bucket(document) 239 | self._update_topic_word_coefficient(document) 240 | for word in document.words: 241 | self._remove_word_topic(document, word) 242 | self._compute_topic_word_bucket(word) 243 | new_topic = self._sample_new_topic(document, word, rand) 244 | word.topic = new_topic 245 | self._add_word_topic(document, word) 246 | self._reset_topic_word_coefficient(document) 247 | 248 | def _compute_smoothing_only_bucket(self): 249 | """s(z) = alpha(z) * beta / (beta * |V| + N(z)) 250 | """ 251 | self.smoothing_only_sum = 0.0 252 | for topic in xrange(self.model.num_topics): 253 | self.smoothing_only_bucket[topic] = \ 254 | self.model.hyper_params.topic_prior * \ 255 | self.model.hyper_params.word_prior / \ 256 | (self.word_prior_sum + self.model.global_topic_hist[topic]) 257 | self.smoothing_only_sum += self.smoothing_only_bucket[topic] 258 | 259 | def _compute_doc_topic_bucket(self, document): 260 | """r(z, d) = N(z|d) * beta / (beta * |V| + N(z)) 261 | """ 262 | self.doc_topic_sum = 0.0 263 | self.doc_topic_bucket = [0] * self.model.num_topics 264 | for non_zero in document.doc_topic_hist.non_zeros: 265 | self.doc_topic_bucket[non_zero.topic] = \ 266 | non_zero.count * self.model.hyper_params.word_prior / \ 267 | (self.word_prior_sum + 268 | self.model.global_topic_hist[non_zero.topic]) 269 | self.doc_topic_sum += self.doc_topic_bucket[non_zero.topic] 270 | 271 | def _initialize_topic_word_coefficient(self): 272 | """q_coefficient(z) = alpha(z) / (beta * |V| + N(z)), 273 | """ 274 | for topic in xrange(self.model.num_topics): 275 | self.topic_word_coef[topic] = \ 276 | self.model.hyper_params.topic_prior / \ 277 | (self.word_prior_sum + self.model.global_topic_hist[topic]) 278 | 279 | def _update_topic_word_coefficient(self, document): 280 | """q_coefficient(z, d) = (alpha(z) + N(z|d)) / (beta * |V| + N(z)) 281 | """ 282 | for non_zero in document.doc_topic_hist.non_zeros: 283 | self.topic_word_coef[non_zero.topic] = \ 284 | (self.model.hyper_params.topic_prior + non_zero.count) / \ 285 | (self.word_prior_sum + 286 | self.model.global_topic_hist[non_zero.topic]) 287 | 288 | def _reset_topic_word_coefficient(self, document): 289 | """q_coefficient(z) = alpha(z) / (beta * |V| + N(z)), 290 | """ 291 | for non_zero in document.doc_topic_hist.non_zeros: 292 | self.topic_word_coef[non_zero.topic] = \ 293 | self.model.hyper_params.topic_prior / \ 294 | (self.word_prior_sum + 295 | self.model.global_topic_hist[non_zero.topic]) 296 | 297 | def _compute_topic_word_bucket(self, word): 298 | """q(z, w, d) = N(w|z) * (alpha(z) + N(z|d)) / (beta * |V| + N(z)) 299 | = N(w|z) * q_coefficient(z, d) 300 | """ 301 | self.topic_word_sum = 0.0 302 | ordered_sparse_topic_hist = self.model.word_topic_hist[word.id] 303 | for non_zero in ordered_sparse_topic_hist.non_zeros: 304 | self.topic_word_bucket[non_zero.topic] = \ 305 | non_zero.count * self.topic_word_coef[non_zero.topic] 306 | self.topic_word_sum += self.topic_word_bucket[non_zero.topic] 307 | 308 | def _remove_word_topic(self, document, word): 309 | self.model.global_topic_hist[word.topic] -= 1 310 | self.model.word_topic_hist[word.id].decrease_topic(word.topic, 1) 311 | 312 | self.smoothing_only_sum -= self.smoothing_only_bucket[word.topic] 313 | self.doc_topic_sum -= self.doc_topic_bucket[word.topic] 314 | updated_topic_count = document.decrease_topic(word.topic, 1) 315 | 316 | self.smoothing_only_bucket[word.topic] = \ 317 | self.model.hyper_params.topic_prior * \ 318 | 
self.model.hyper_params.word_prior / \ 319 | (self.word_prior_sum + self.model.global_topic_hist[word.topic]) 320 | self.smoothing_only_sum += self.smoothing_only_bucket[word.topic] 321 | 322 | self.doc_topic_bucket[word.topic] = \ 323 | updated_topic_count * self.model.hyper_params.word_prior / \ 324 | (self.word_prior_sum + self.model.global_topic_hist[word.topic]) 325 | self.doc_topic_sum += self.doc_topic_bucket[word.topic] 326 | 327 | self.topic_word_coef[word.topic] = \ 328 | (self.model.hyper_params.topic_prior + updated_topic_count) / \ 329 | (self.word_prior_sum + self.model.global_topic_hist[word.topic]) 330 | 331 | def _add_word_topic(self, document, word): 332 | self.model.global_topic_hist[word.topic] += 1 333 | self.model.word_topic_hist[word.id].increase_topic(word.topic, 1) 334 | 335 | self.smoothing_only_sum -= self.smoothing_only_bucket[word.topic] 336 | self.doc_topic_sum -= self.doc_topic_bucket[word.topic] 337 | updated_topic_count = document.increase_topic(word.topic, 1) 338 | 339 | self.smoothing_only_bucket[word.topic] = \ 340 | self.model.hyper_params.topic_prior * \ 341 | self.model.hyper_params.word_prior / \ 342 | (self.word_prior_sum + self.model.global_topic_hist[word.topic]) 343 | self.smoothing_only_sum += self.smoothing_only_bucket[word.topic] 344 | 345 | self.doc_topic_bucket[word.topic] = \ 346 | updated_topic_count * self.model.hyper_params.word_prior / \ 347 | (self.word_prior_sum + self.model.global_topic_hist[word.topic]) 348 | self.doc_topic_sum += self.doc_topic_bucket[word.topic] 349 | 350 | self.topic_word_coef[word.topic] = \ 351 | (self.model.hyper_params.topic_prior + updated_topic_count) / \ 352 | (self.word_prior_sum + self.model.global_topic_hist[word.topic]) 353 | 354 | def _sample_new_topic(self, document, word, rand): 355 | """Sampling a new topic for current word. 356 | 357 | Returns the new topic. 358 | """ 359 | total_mass = self.smoothing_only_sum + self.doc_topic_sum + \ 360 | self.topic_word_sum 361 | sample = rand.uniform(0.0, total_mass) 362 | 363 | # In general, self.topic_word_sum >> self.smoothing_only_sum 364 | # self.topic_word_sum >> self.doc_topic_sum 365 | if sample < self.topic_word_sum: 366 | ordered_sparse_topic_hist = self.model.word_topic_hist[word.id] 367 | for non_zero in ordered_sparse_topic_hist.non_zeros: 368 | sample -= self.topic_word_bucket[non_zero.topic] 369 | if sample <= 0: 370 | return non_zero.topic 371 | else: 372 | sample -= self.topic_word_sum 373 | # self.doc_topic_bucket is sparse. 374 | if sample < self.doc_topic_sum: 375 | for non_zero in document.doc_topic_hist.non_zeros: 376 | sample -= self.doc_topic_bucket[non_zero.topic] 377 | if sample <= 0: 378 | return non_zero.topic 379 | else: 380 | sample -= self.doc_topic_sum 381 | for topic, value in enumerate(self.smoothing_only_bucket): 382 | sample -= value 383 | if sample <= 0: 384 | return topic 385 | 386 | logging.error('Sampling word topic failed.') 387 | return None 388 | 389 | -------------------------------------------------------------------------------- /training/sparselda_train_gibbs_sampler_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import random
20 | import unittest
21 | import sys
22 | 
23 | sys.path.append('..')
24 | from common.model import Model
25 | from common.vocabulary import Vocabulary
26 | from sparselda_train_gibbs_sampler import SparseLDATrainGibbsSampler
27 | 
28 | class SparseLDATrainGibbsSamplerTest(unittest.TestCase):
29 | 
30 |     def setUp(self):
31 |         self.model = Model(20)
32 |         self.vocabulary = Vocabulary()
33 |         self.vocabulary.load('../testdata/vocabulary.dat')
34 |         self.sparselda_train_gibbs_sampler = \
35 |             SparseLDATrainGibbsSampler(self.model, self.vocabulary)
36 | 
37 |     def test_load_corpus(self):
38 |         self.sparselda_train_gibbs_sampler.load_corpus('../testdata/corpus')
39 |         self.assertEqual(4, len(self.sparselda_train_gibbs_sampler.documents))
40 | 
41 |     def test_gibbs_sampling(self):
42 |         self.sparselda_train_gibbs_sampler.load_corpus('../testdata/corpus')
43 |         rand = random.Random()
44 |         for i in xrange(100):
45 |             self.sparselda_train_gibbs_sampler.gibbs_sampling(rand)
46 |             if (i + 1) % 10 == 0:
47 |                 self.sparselda_train_gibbs_sampler.save_checkpoint(
48 |                     '../testdata/checkpoint', i + 1)
49 |         self.sparselda_train_gibbs_sampler.save_model(
50 |             '../testdata/train_model', 100)
51 | 
52 |     def test_load_checkpoint(self):
53 |         cur_iteration = self.sparselda_train_gibbs_sampler.load_checkpoint(
54 |             '../testdata/checkpoint')
55 |         rand = random.Random()
56 |         for i in xrange(cur_iteration, 200):
57 |             self.sparselda_train_gibbs_sampler.gibbs_sampling(rand)
58 |             if (i + 1) % 10 == 0:
59 |                 self.sparselda_train_gibbs_sampler.save_checkpoint(
60 |                     '../testdata/checkpoint', i + 1)
61 | 
62 | if __name__ == '__main__':
63 |     unittest.main()
64 | 
65 | 
--------------------------------------------------------------------------------
/training/topic_words_stat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | 
4 | # Copyright(c) 2013 python-sparselda project.
5 | # Author: Lifeng Wang (ofandywang@gmail.com)
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | #     http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | import sys
20 | 
21 | sys.path.append('..')
22 | from common.model import Model
23 | from common.vocabulary import Vocabulary
24 | 
25 | class TopicWordsStat(object):
26 |     """TopicWordsStat exports the top words of each topic.
27 |     """
28 | 
29 |     def __init__(self, model, vocabulary):
30 |         self.model = model
31 |         self.vocabulary = vocabulary
32 | 
33 |     def save(self, topic_words_file, accumulated_prob_threshold):
34 |         """Save the topic words to file.
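35 | 
36 |         Output line fmt: topic \t N(z) \t w1 \t p(w1|z) \t w2 \t p(w2|z)
37 |         ..., where words are appended until their accumulated probability
38 |         exceeds accumulated_prob_threshold.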
35 | """ 36 | fp = open(topic_words_file, 'w') 37 | fp.write(self.get_topic_top_words(accumulated_prob_threshold)) 38 | fp.close() 39 | 40 | def get_topic_top_words(self, accumulated_prob_threshold): 41 | """Returns topics' top words. 42 | """ 43 | topic_top_words = [] 44 | sparse_topic_word_dist = self.compute_topic_word_distribution() 45 | 46 | for topic, word_probs in enumerate(sparse_topic_word_dist): 47 | top_words = [] 48 | top_words.append(str(topic)) 49 | top_words.append(str(self.model.global_topic_hist[topic])) 50 | accumulated_prob = 0.0 51 | for word_prob in word_probs: 52 | top_words.append( 53 | self.vocabulary.word(word_prob[0]).encode('gbk', 'ignore')) 54 | top_words.append(str(word_prob[1])) 55 | accumulated_prob += word_prob[1] 56 | if accumulated_prob > accumulated_prob_threshold: 57 | break 58 | topic_top_words.append('\t'.join(top_words)) 59 | 60 | return '\n'.join(topic_top_words) 61 | 62 | def compute_topic_word_distribution(self): 63 | """Compute the topic word distribution p(w|z), indexed by topic z. 64 | """ 65 | # item fmt: z -> 66 | sparse_topic_word_dist = [] 67 | for topic in xrange(self.model.num_topics): 68 | sparse_topic_word_dist.append([]) 69 | 70 | for word_id, ordered_sparse_topic_hist in \ 71 | self.model.word_topic_hist.iteritems(): 72 | for non_zero in ordered_sparse_topic_hist.non_zeros: 73 | sparse_topic_word_dist[non_zero.topic].append( 74 | [word_id, 75 | (non_zero.count + self.model.hyper_params.word_prior) / 76 | (self.model.hyper_params.word_prior * self.vocabulary.size() + 77 | self.model.global_topic_hist[non_zero.topic])]) 78 | 79 | for topic, word_probs in enumerate(sparse_topic_word_dist): 80 | word_probs.sort(cmp=lambda x,y:cmp(x[1], y[1]), reverse=True) 81 | 82 | return sparse_topic_word_dist 83 | -------------------------------------------------------------------------------- /training/topic_words_stat_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | # Copyright(c) 2013 python-sparselda project. 5 | # Author: Lifeng Wang (ofandywang@gmail.com) 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 |
19 | import unittest
20 | import sys
21 |
22 | sys.path.append('..')
23 | from common.model import Model
24 | from common.vocabulary import Vocabulary
25 | from topic_words_stat import TopicWordsStat
26 |
27 | class TopicWordsStatTest(unittest.TestCase):
28 |
29 |     def setUp(self):
30 |         self.model = Model(20)
31 |         self.model.load('../testdata/lda_model')
32 |         self.vocabulary = Vocabulary()
33 |         self.vocabulary.load('../testdata/vocabulary.dat')
34 |
35 |         self.topic_words_stat = TopicWordsStat(self.model, self.vocabulary)
36 |
37 |     def test_save(self):
38 |         self.topic_words_stat.save('../testdata/topic_top_words.dat', 0.8)
39 |
40 |     def test_get_topic_top_words(self):
41 |         print self.topic_words_stat.get_topic_top_words(0.8)
42 |
43 |     def test_compute_topic_word_distribution(self):
44 |         print self.topic_words_stat.compute_topic_word_distribution()
45 |
46 | if __name__ == '__main__':
47 |     unittest.main()
48 |
49 |
--------------------------------------------------------------------------------
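
A note on the sampler in training/sparselda_train_gibbs_sampler.py: the buckets maintained by _remove_word_topic and _add_word_topic follow the standard SparseLDA decomposition of Yao, Mimno and McCallum (KDD 2009). Assuming word_prior_sum holds word_prior * |V| (consistent with the denominator used in topic_words_stat.py), the full conditional for assigning topic t to word w in document d factors as

    p(z = t | w, d)  is proportional to  s(t) + r(t) + q(t), where

    s(t) = topic_prior * word_prior / (word_prior_sum + N(t))          # smoothing_only_bucket[t]
    r(t) = N(t|d) * word_prior / (word_prior_sum + N(t))               # doc_topic_bucket[t]
    q(t) = (topic_prior + N(t|d)) * N(w|t) / (word_prior_sum + N(t))   # presumably topic_word_coef[t] * N(w|t)

Here N(t) is model.global_topic_hist[t], N(t|d) the document's count for topic t, and N(w|t) the word's count in topic t. Only r and q are sparse (non-zero just for topics present in the document and in the word's topic histogram, respectively), which is why _sample_new_topic walks document.doc_topic_hist.non_zeros and word_topic_hist[word.id].non_zeros instead of all num_topics entries, and why the buckets are tried in descending order of total mass.

For reference, a minimal end-to-end sketch that strings together the classes exercised by the tests above. It assumes it is run from the repository root against the shipped testdata; the repository's real entry point is lda_trainer.py, so treat this only as an illustration of the API, not as the project's training pipeline:

    #!/usr/bin/env python
    #coding=utf-8

    import random

    from common.model import Model
    from common.vocabulary import Vocabulary
    from training.sparselda_train_gibbs_sampler import SparseLDATrainGibbsSampler
    from training.topic_words_stat import TopicWordsStat

    model = Model(20)  # 20 topics, matching the unit tests.
    vocabulary = Vocabulary()
    vocabulary.load('testdata/vocabulary.dat')

    # The sampler updates model.global_topic_hist and model.word_topic_hist
    # in place, so the model can be inspected directly after training.
    sampler = SparseLDATrainGibbsSampler(model, vocabulary)
    sampler.load_corpus('testdata/corpus')

    rand = random.Random()
    for i in xrange(100):  # Python 2, like the rest of the project.
        sampler.gibbs_sampling(rand)

    sampler.save_model('testdata/train_model', 100)

    # Dump each topic's top words until 80% of p(w|z) is covered.
    TopicWordsStat(model, vocabulary).save('testdata/topic_top_words.dat', 0.8)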